美文网首页
tensorflow模型保存

tensorflow模型保存

作者: 井底蛙蛙呱呱呱 | 来源:发表于2019-04-17 19:41 被阅读0次

在使用TensorFlow训练模型时,为了避免每次预测都要重新训练模型,模型保存必不可少。而在模型保存时,使用不同的参数可采用不同的保存模式。

TensorFlow使用tf.train.Saver()来保存或重载模型。其常用参数有两个个:

  • var_list 指定要保存的变量名;
  • max_to_keep 设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0。

一般情况下我们不会指定var_list参数,表示我们保存模型的所有参数,也即var_list=tf.global_variables(),但是有时候我们可能只想保存某些指定的变量,譬如仅保存trainable_variables 和 batch normlization参数:

# 提取batch nrom参数, 参考https://stackoverflow.com/questions/45800871/tensorflow-save-restore-batch-norm
bn_vars = [var for var in tf.global_variables() if ('moving_variance' in var.name) or ('moving_mean' in var.name)]
# 获取可训练参数
train_var = tf.trainable_variables()
# 保存时指定变量和ckpt个数,此时一般可以将保存文件大小压缩至原大小1/3
saver = tf.train.Saver(train_var+bn_vars, max_to_keep=1000)  

# 最后,指定是什么时候保存模型
with tf.Session() as sess:
    ...
    if epochs%10==0:
        saver.save(sess, save_dir, epochs)

最后,一般会在指定文件夹下生存4个文件(参考知乎回答):

  • checkpoint Checkpoint保存断点文件列表,可以用来迅速查找最近一次的断点文件;
  • vggfcn_mom1to2_22w.ckpt-10000.data-00000-of-00001 数据(data)文件保存所有变量的值,即网络权值;
  • vggfcn_mom1to2_22w.ckpt-10000.index index文件为数据文件提供索引,存储的核心内容是以tensor name为键以BundleEntry为值的表格entries,BundleEntry主要内容是权值的类型、形状、偏移、校验和等信息。Index文件由data block/index block/Footer等组成,构建时主要涉及BundleWriter、TableBuilder、BlockBuilder几个类,除了BundleEntry的序列化,还涉及了tensor name的编码及优化(比如丢弃重复的前缀)和data block的snappy压缩;
  • vggfcn_mom1to2_22w.ckpt-10000.meta meta文件是MetaGraphDef序列化的二进制文件,保存了网络结构相关的数据,包括graph_def和saver_def等。

其中后三个(data,index,meta)是重载模型必须的。

重载较为简单,也可使用slim进行重载,这里不进行详细介绍:

saver = tf.train.Saver()  
with tf.Session() as sess:  
        saver.restore(sess, "save/model.ckpt")  

最后,来看看怎么查看ckpt文件中的参数:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
checkpoint_path = "vggfcn_mom1to2_model.ckpt-335000"
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='',all_tensors=True)

# 输出:
tensor_name:  global_step
335001.0
tensor_name:  tower_0/softmax_cross_entropy_loss/value/avg
0.20133862
tensor_name:  tower_0/total_loss/avg
0.23763658
tensor_name:  vgg_16/conv1/conv1_1/biases
[ 0.02797224  0.07390981 -0.00385985  0.00785956 -0.01232299 -0.01137739
  0.0253244  -0.07548576  0.00300218 -0.01330253  0.00169791 -0.0064972
 -0.04955265  0.23835294  0.00348639 -0.00088801 -0.00086359  0.04471491
  0.02329434 -0.08720596 -0.00614685  0.04078082  0.00396819  0.03650536
 -0.00419048  0.02792025 -0.01622533  0.00038355 -0.02241365  0.04451815
  0.14487104  0.09795988  0.02575901  0.01830461 -0.06216403  0.14097025
 -0.00921478 -0.0047136   0.03897408  0.05741399  0.00417633  0.03451509
 -0.01117235  0.0383427  -0.00806538 -0.00271318 -0.03132046 -0.06175949
  0.01164324 -0.00922603  0.00940489 -0.08562332  0.0114077   0.00084499
  0.0093553  -0.01873754 -0.01944772  0.02312844 -0.08838678 -0.01441999
  0.10067106  0.02811741 -0.00628062  0.1013862 ]
tensor_name:  vgg_16/conv1/conv1_1/biases/ExponentialMovingAverage
[ 0.02784752  0.07357726 -0.00385987  0.00785797 -0.01232291 -0.01145722
  0.02526607 -0.07584356  0.00298003 -0.01330372  0.0016973  -0.00649742
 -0.0496439   0.23820773  0.00345324 -0.00088819 -0.00087755  0.04459094
  0.02319813 -0.08743148 -0.00614733  0.04070283  0.00393652  0.03630615
 -0.00418991  0.02786526 -0.016225    0.00033586 -0.02241365  0.04432626
  0.14477569  0.09788124  0.02568062  0.01826548 -0.06237179  0.14037122
 -0.00921435 -0.0047203   0.03881586  0.0567292   0.00412849  0.03450596
 -0.01117271  0.0383288  -0.00806614 -0.00271349 -0.03151926 -0.06203282
  0.01159994 -0.00922479  0.00936983 -0.08602356  0.01136672  0.00076801
  0.0092632  -0.01873754 -0.01944782  0.02309002 -0.08878716 -0.01442014
  0.10042226  0.02782144 -0.00628062  0.10099234]
tensor_name:  vgg_16/conv1/conv1_1/biases/Momentum
[ 1.1582047e-03 -3.2689439e-03  0.0000000e+00 -3.7641104e-05
  0.0000000e+00 -2.8069332e-04 -3.4572899e-05 -8.5439772e-04
 ...

.pb格式的模型保存、重载以及.ckpt格式文件转换为.pb格式

# coding: utf-8

import tensorflow as tf
from tensorflow.python.framework import graph_util
import os
import numpy as np
import cv2

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def restore_pb(img_data, pb_file_path=None):
    '''restore model from .pb to infer
    '''
    if not pb_file_path:
        pb_file_path = "test/inceptionV3_0.pb"
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.FastGFile(pb_file_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            return_elems = ['tower_0/InceptionV3/predictions/Softmax:0']
            predictions = tf.import_graph_def(graph_def, return_elements=return_elems)

        with tf.Session() as sess:
            # print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))
            # for node in sess.graph.node:
            #     print(node.name)
            sess.run(tf.global_variables_initializer())
            image = sess.graph.get_tensor_by_name('import/x:0')
            #label = sess.graph.get_tensor_by_name('y:0')

            predictions = sess.run(predictions, feed_dict={image: img_data})
            print('pb predictions:', predictions[0])


def restore_ckpt(img_data):
    '''restore model from .ckpt to infer
    '''
    cpkt_meta = 'test/inception_mom1to2_rand5.ckpt-0.meta'
    ckpt = 'test/inception_mom1to2_rand5.ckpt-0'
    with tf.Graph().as_default():
        graph = tf.Graph()
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(graph=graph, config=config) as sess:
            saver = tf.train.import_meta_graph(cpkt_meta)
            saver.restore(sess, ckpt)
            # print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))
            image = graph.get_tensor_by_name('x:0')
            preds = graph.get_tensor_by_name('tower_0/InceptionV3/predictions/Softmax:0')
            predictions = sess.run(preds, feed_dict={image: img_data})
            print('ckpt predictions:', predictions)


def ckpt2pb():
    '''convert .ckpt to .pb file
    '''
    # some path
    output_nodes = ['tower_0/InceptionV3/predictions/Softmax']
    cpkt_meta = 'test/inception_mom1to2_rand5.ckpt-0.meta'
    ckpt = 'test/inception_mom1to2_rand5.ckpt-0'
    pb_file_path = 'test/inceptionV3_ckpt2pb.pb'

    with tf.Graph().as_default():
        graph = tf.Graph()
        config = tf.ConfigProto(allow_soft_placement=True)

        with tf.Session(graph=graph, config=config) as sess:
            # restored model from ckpt
            saver = tf.train.import_meta_graph(cpkt_meta)
            saver.restore(sess, ckpt)

            # save freeze graph into .pb file
            graph_def = tf.get_default_graph().as_graph_def()
            constant_graph = graph_util.convert_variables_to_constants(sess, graph_def, output_nodes)
            with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
                f.write(constant_graph.SerializeToString())


img_data = cv2.imread('tumor_009_19327_175114_0.jpeg')
img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)
img_data = (np.array(img_data).astype(np.float32)) / 256.0
img_data = np.reshape(img_data, [-1, 256, 256, 3])

restore_pb(img_data)
restore_ckpt(img_data)
ckpt2pb()
restore_pb(img_data, pb_file_path='test/inceptionV3_ckpt2pb.pb')

参考:
Tensorflow框架实现中的“三”种图
TensorFlow 保存模型为 PB 文件

相关文章

网友评论

      本文标题:tensorflow模型保存

      本文链接:https://www.haomeiwen.com/subject/ydhowqtx.html