Saving TensorFlow Models

Author: 井底蛙蛙呱呱呱 | Published: 2019-04-17 19:41

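    When you train a model with TensorFlow, saving it is essential: otherwise the model would have to be retrained before every prediction. The arguments passed to the saver determine how the model is saved.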

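    TensorFlow uses tf.train.Saver() to save and restore models. Its two most commonly used parameters are:

    • var_list: the variables to save;
    • max_to_keep: the number of checkpoints to keep. The default is 5, i.e. max_to_keep=5 keeps the 5 most recent checkpoints. If you want to save a checkpoint after every epoch and keep all of them, set max_to_keep to None or 0.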
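The effect of max_to_keep can be illustrated with a small pure-Python simulation. This is only a sketch of the retention rule, not TensorFlow's implementation: the function name `simulate_retention` and the prefix `model.ckpt` are hypothetical. It relies on the fact that `saver.save(sess, save_path, global_step)` names each checkpoint `<save_path>-<global_step>`.

```python
from collections import deque


def simulate_retention(save_path, steps, max_to_keep=5):
    """Return the checkpoint prefixes that would remain after saving at each step.

    Mimics the retention rule of tf.train.Saver: only the most recent
    `max_to_keep` checkpoints survive; None or 0 means nothing is deleted.
    """
    # deque(maxlen=None) is unbounded, which matches "keep everything".
    kept = deque(maxlen=max_to_keep or None)
    for step in steps:
        # Saver.save appends "-<global_step>" to the save path.
        kept.append("{}-{}".format(save_path, step))
    return list(kept)


if __name__ == "__main__":
    steps = range(10, 80, 10)  # save every 10 epochs: 10, 20, ..., 70
    print(simulate_retention("model.ckpt", steps))                 # default: last 5
    print(simulate_retention("model.ckpt", steps, max_to_keep=0))  # keeps all 7
```

With the default setting, the checkpoints for epochs 10 and 20 would already have been deleted by the time epoch 70 is saved, which is why a larger value (or None/0) is needed when every checkpoint should be kept.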

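    Usually var_list is left unspecified, which saves all of the model's variables, roughly equivalent to var_list=tf.global_variables(). Sometimes, however, you only want to save selected variables, for example only the trainable_variables and the batch normalization parameters: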

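    import tensorflow as tf

    # Collect the batch norm statistics; see https://stackoverflow.com/questions/45800871/tensorflow-save-restore-batch-norm
    bn_vars = [var for var in tf.global_variables()
               if 'moving_variance' in var.name or 'moving_mean' in var.name]
    # Collect the trainable variables
    train_var = tf.trainable_variables()
    # Specify the variables and the number of checkpoints to keep; this typically
    # shrinks the saved files to about 1/3 of their original size
    saver = tf.train.Saver(train_var + bn_vars, max_to_keep=1000)

    # Finally, decide when to save the model
    with tf.Session() as sess:
        ...
        if epochs % 10 == 0:
            # save_dir is used as the checkpoint path prefix; epochs is the global_step
            saver.save(sess, save_dir, global_step=epochs)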
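To see which variables such a selection keeps and which it drops, the same filter can be applied to plain names. The following is a minimal sketch, not TensorFlow code: the (name, trainable) pairs stand in for tf.global_variables(), and the function name and the variable names are hypothetical, written in the style that inspect_checkpoint prints.

```python
def select_names_to_save(variables):
    """Pick variable names with the same rule as the Saver snippet.

    `variables` is a list of (name, trainable) pairs standing in for
    tf.global_variables(); this stand-in is an assumption for illustration.
    """
    trainable = [name for name, is_trainable in variables if is_trainable]
    bn_stats = [name for name, _ in variables
                if "moving_mean" in name or "moving_variance" in name]
    return trainable + bn_stats


if __name__ == "__main__":
    # Hypothetical variable names for illustration only.
    variables = [
        ("vgg_16/conv1/conv1_1/biases:0", True),
        ("vgg_16/conv1/conv1_1/biases/Momentum:0", False),
        ("vgg_16/conv1/conv1_1/biases/ExponentialMovingAverage:0", False),
        ("vgg_16/conv1/BatchNorm/moving_mean:0", False),
        ("vgg_16/conv1/BatchNorm/moving_variance:0", False),
        ("global_step:0", False),
    ]
    for name in select_names_to_save(variables):
        print(name)
```

Optimizer slots (Momentum) and moving averages of the weights are not selected, which is where the size reduction comes from. Note that global_step is not selected either, so it would have to be added to the list explicitly if training is to be resumed from the saved step.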
    

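    In the end, four files are usually generated in the target directory (see the referenced Zhihu answer):

    • checkpoint: the checkpoint state file keeps the list of saved checkpoints and can be used to quickly find the most recent one;
    • vggfcn_mom1to2_22w.ckpt-10000.data-00000-of-00001: the data file stores the values of all saved variables, i.e. the network weights;
    • vggfcn_mom1to2_22w.ckpt-10000.index: the index file provides an index into the data file. Its core content is a table of entries keyed by tensor name with BundleEntry values; a BundleEntry mainly records the type, shape, offset, and checksum of a weight tensor. The index file consists of data blocks, an index block, a footer, and so on. Building it mainly involves the BundleWriter, TableBuilder, and BlockBuilder classes; besides serializing BundleEntry, it also covers encoding and optimizing tensor names (for example dropping repeated prefixes) and snappy compression of the data blocks;
    • vggfcn_mom1to2_22w.ckpt-10000.meta: the meta file is a binary serialization of a MetaGraphDef and stores data about the network structure, including graph_def and saver_def.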

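    The last three (data, index, meta) are needed to reload the model from the checkpoint alone. If the graph is rebuilt in code and restored with tf.train.Saver(), the data and index files are sufficient; the meta file is only required when the graph itself is loaded with tf.train.import_meta_graph.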
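Before restoring, it can help to confirm that the files for a given checkpoint prefix are actually present. The helper below is a minimal sketch using only the standard library and the file naming shown in the list; the function name `missing_checkpoint_files`, its `need_meta` flag, and the demo prefix are hypothetical.

```python
import glob
import os
import tempfile


def missing_checkpoint_files(prefix, need_meta=True):
    """Return a list describing which files for checkpoint `prefix` are absent."""
    missing = []
    if not os.path.isfile(prefix + ".index"):
        missing.append(prefix + ".index")
    # Variable values may be split across several shards:
    # <prefix>.data-00000-of-00002, <prefix>.data-00001-of-00002, ...
    if not glob.glob(glob.escape(prefix) + ".data-*-of-*"):
        missing.append(prefix + ".data-*-of-*")
    # The .meta file only matters when the graph is loaded from the checkpoint.
    if need_meta and not os.path.isfile(prefix + ".meta"):
        missing.append(prefix + ".meta")
    return missing


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        prefix = os.path.join(tmp, "demo.ckpt-10000")  # hypothetical prefix
        for suffix in (".index", ".data-00000-of-00001"):
            open(prefix + suffix, "wb").close()
        print(missing_checkpoint_files(prefix))                   # only .meta is missing
        print(missing_checkpoint_files(prefix, need_meta=False))  # nothing is missing
```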

    Restoring is fairly simple. slim can also be used for restoring, which is not covered in detail here:

    # The graph must already be built so that Saver() can find its variables
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "save/model.ckpt")
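The path passed to saver.restore is a checkpoint prefix, not a file name; when checkpoints were saved with a global step, the prefix ends in `-<step>`. TensorFlow provides tf.train.latest_checkpoint(checkpoint_dir) to look up the newest prefix from the `checkpoint` file. The sketch below reads that text file directly in plain Python to show what it contains. The helper name `read_latest_prefix` and the sample contents are hypothetical, and the regex parse is a simplification, not TensorFlow's actual parser.

```python
import os
import re
import tempfile

# The state file holds lines such as: model_checkpoint_path: "model.ckpt-20"
_LATEST = re.compile(r'^model_checkpoint_path:\s*"(.*)"\s*$', re.MULTILINE)


def read_latest_prefix(checkpoint_dir):
    """Return the newest checkpoint prefix recorded in <dir>/checkpoint, or None."""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file, "r") as f:
        match = _LATEST.search(f.read())
    if match is None:
        return None
    path = match.group(1)
    # Relative entries are relative to the checkpoint directory.
    return path if os.path.isabs(path) else os.path.join(checkpoint_dir, path)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        with open(os.path.join(tmp, "checkpoint"), "w") as f:
            f.write('model_checkpoint_path: "model.ckpt-20"\n'
                    'all_model_checkpoint_paths: "model.ckpt-10"\n'
                    'all_model_checkpoint_paths: "model.ckpt-20"\n')
        print(read_latest_prefix(tmp))  # the directory joined with model.ckpt-20
```

In real code, passing the result of tf.train.latest_checkpoint to saver.restore is the more robust choice; the sketch only makes the lookup visible.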
    

    Finally, here is how to inspect the parameters stored in a ckpt file:

    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
    checkpoint_path = "vggfcn_mom1to2_model.ckpt-335000"
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='',all_tensors=True)
    
    # Output:
    tensor_name:  global_step
    335001.0
    tensor_name:  tower_0/softmax_cross_entropy_loss/value/avg
    0.20133862
    tensor_name:  tower_0/total_loss/avg
    0.23763658
    tensor_name:  vgg_16/conv1/conv1_1/biases
    [ 0.02797224  0.07390981 -0.00385985  0.00785956 -0.01232299 -0.01137739
      0.0253244  -0.07548576  0.00300218 -0.01330253  0.00169791 -0.0064972
     -0.04955265  0.23835294  0.00348639 -0.00088801 -0.00086359  0.04471491
      0.02329434 -0.08720596 -0.00614685  0.04078082  0.00396819  0.03650536
     -0.00419048  0.02792025 -0.01622533  0.00038355 -0.02241365  0.04451815
      0.14487104  0.09795988  0.02575901  0.01830461 -0.06216403  0.14097025
     -0.00921478 -0.0047136   0.03897408  0.05741399  0.00417633  0.03451509
     -0.01117235  0.0383427  -0.00806538 -0.00271318 -0.03132046 -0.06175949
      0.01164324 -0.00922603  0.00940489 -0.08562332  0.0114077   0.00084499
      0.0093553  -0.01873754 -0.01944772  0.02312844 -0.08838678 -0.01441999
      0.10067106  0.02811741 -0.00628062  0.1013862 ]
    tensor_name:  vgg_16/conv1/conv1_1/biases/ExponentialMovingAverage
    [ 0.02784752  0.07357726 -0.00385987  0.00785797 -0.01232291 -0.01145722
      0.02526607 -0.07584356  0.00298003 -0.01330372  0.0016973  -0.00649742
     -0.0496439   0.23820773  0.00345324 -0.00088819 -0.00087755  0.04459094
      0.02319813 -0.08743148 -0.00614733  0.04070283  0.00393652  0.03630615
     -0.00418991  0.02786526 -0.016225    0.00033586 -0.02241365  0.04432626
      0.14477569  0.09788124  0.02568062  0.01826548 -0.06237179  0.14037122
     -0.00921435 -0.0047203   0.03881586  0.0567292   0.00412849  0.03450596
     -0.01117271  0.0383288  -0.00806614 -0.00271349 -0.03151926 -0.06203282
      0.01159994 -0.00922479  0.00936983 -0.08602356  0.01136672  0.00076801
      0.0092632  -0.01873754 -0.01944782  0.02309002 -0.08878716 -0.01442014
      0.10042226  0.02782144 -0.00628062  0.10099234]
    tensor_name:  vgg_16/conv1/conv1_1/biases/Momentum
    [ 1.1582047e-03 -3.2689439e-03  0.0000000e+00 -3.7641104e-05
      0.0000000e+00 -2.8069332e-04 -3.4572899e-05 -8.5439772e-04
     ...
    

    Saving and restoring models in .pb format, and converting .ckpt files to .pb

    # coding: utf-8
    
    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    import os
    import numpy as np
    import cv2
    
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    
    
    def restore_pb(img_data, pb_file_path=None):
        '''Restore a frozen model from a .pb file and run inference.
        '''
        if not pb_file_path:
            pb_file_path = "test/inceptionV3_0.pb"
        with tf.Graph().as_default():
            graph_def = tf.GraphDef()
            with tf.gfile.GFile(pb_file_path, "rb") as f:
                graph_def.ParseFromString(f.read())
            # Imported nodes are placed under the default name scope "import/"
            return_elems = ['tower_0/InceptionV3/predictions/Softmax:0']
            predictions = tf.import_graph_def(graph_def, return_elements=return_elems)

            with tf.Session() as sess:
                # print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))
                # for node in sess.graph.as_graph_def().node:
                #     print(node.name)
                # No-op for a frozen graph, which contains no variables
                sess.run(tf.global_variables_initializer())
                image = sess.graph.get_tensor_by_name('import/x:0')
                # label = sess.graph.get_tensor_by_name('y:0')

                predictions = sess.run(predictions, feed_dict={image: img_data})
                print('pb predictions:', predictions[0])
    
    
    def restore_ckpt(img_data):
        '''Restore a model from .ckpt files and run inference.
        '''
        ckpt_meta = 'test/inception_mom1to2_rand5.ckpt-0.meta'
        ckpt = 'test/inception_mom1to2_rand5.ckpt-0'
        graph = tf.Graph()
        with graph.as_default():
            config = tf.ConfigProto(allow_soft_placement=True)
            with tf.Session(graph=graph, config=config) as sess:
                # Rebuild the graph from the .meta file, then load the weights
                saver = tf.train.import_meta_graph(ckpt_meta)
                saver.restore(sess, ckpt)
                # print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))
                image = graph.get_tensor_by_name('x:0')
                preds = graph.get_tensor_by_name('tower_0/InceptionV3/predictions/Softmax:0')
                predictions = sess.run(preds, feed_dict={image: img_data})
                print('ckpt predictions:', predictions)
    
    
    def ckpt2pb():
        '''Convert .ckpt files to a frozen .pb file.
        '''
        # Output node names (without the ":0" suffix) and file paths
        output_nodes = ['tower_0/InceptionV3/predictions/Softmax']
        ckpt_meta = 'test/inception_mom1to2_rand5.ckpt-0.meta'
        ckpt = 'test/inception_mom1to2_rand5.ckpt-0'
        pb_file_path = 'test/inceptionV3_ckpt2pb.pb'

        graph = tf.Graph()
        with graph.as_default():
            config = tf.ConfigProto(allow_soft_placement=True)

            with tf.Session(graph=graph, config=config) as sess:
                # Restore the model from the checkpoint
                saver = tf.train.import_meta_graph(ckpt_meta)
                saver.restore(sess, ckpt)

                # Freeze the variables into constants and write the graph to a .pb file
                graph_def = graph.as_graph_def()
                constant_graph = graph_util.convert_variables_to_constants(sess, graph_def, output_nodes)
                with tf.gfile.GFile(pb_file_path, mode='wb') as f:
                    f.write(constant_graph.SerializeToString())
    
    
    img_data = cv2.imread('tumor_009_19327_175114_0.jpeg')
    img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)
    img_data = (np.array(img_data).astype(np.float32)) / 256.0
    img_data = np.reshape(img_data, [-1, 256, 256, 3])
    
    restore_pb(img_data)
    restore_ckpt(img_data)
    ckpt2pb()
    restore_pb(img_data, pb_file_path='test/inceptionV3_ckpt2pb.pb')
    

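    References:
    Tensorflow框架实现中的“三”种图 (The "three" kinds of graphs in the TensorFlow framework implementation)
    TensorFlow 保存模型为 PB 文件 (Saving a TensorFlow model as a PB file)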
