美文网首页
tensorflow的saved_model存取模型

tensorflow的saved_model存取模型

作者: 62ba53cbc93c | 来源:发表于2018-07-14 14:46 被阅读0次

    一种工程级方便的存取模型的方法,saved_model
    通过存取一个简单的模型来作为示范
    首先是模型定义

    import tensorflow as tf
    import numpy as np
    
    
    W = tf.get_variable(name="demo", initializer=tf.ones([10, 32],dtype=tf.float32))
    x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    
    y = tf.matmul(x, W)
    y_ = np.ones(shape=[10, 32], dtype=np.float32) # 使用np来创造两个label
    
    cost = tf.nn.sigmoid_cross_entropy_with_logits(logits=y, labels=y_, name=None)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
    

    这里定义了一个简单的矩阵乘, 然后我们来简单的训练几步

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        feed_dict = {x: np.ones([10, 10])}
        for i in range(100):
            sess.run(train_op, feed_dict=feed_dict)
        print(sess.run(y, feed_dict=feed_dict))
    

    现在我们想把这个模型存储起来,传统的做法是用ckpt来做,现在tensorflow提供一种更强大简便的方法

    首先构建两个字典,inputs 和 outputs, 把要存入的变量放入字典中
    其中 tf.saved_model.utils.build_tensor_info是把变量变成可缓存对象的函数

        saved_model_dir = "save_model"
        signature_key = 'test_signature'
        input_key = 'input_x'
        output_key = 'output'
    
        # x 为输入tensor
        inputs = {input_key: tf.saved_model.utils.build_tensor_info(x)}
        # y 为最终需要的输出结果tensor
        outputs = {output_key: tf.saved_model.utils.build_tensor_info(y)}
    

    然后把两个字典打包放入 signature 中

    signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs,
            outputs=outputs,
            method_name=signature_key)
    

    然后建立SavedModelBuilder,并以signature的形式添加要存储的变量

    builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
    builder.add_meta_graph_and_variables(
            sess=sess,
            tags=['test_saved_model'],
            signature_def_map={signature_key: signature},
            clear_devices=True)
        builder.save()
    

    saved_model_dir 是要存模型的文件夹,可以是一个不存在的目录名,save之后,包括图结构,变量的内容,都会被存入到新创建的 saved_model_dir 目录内,下图就是存好的模型


    下面我们来取出一个训练好的模型
    用 tf.saved_model.loader.load 从 模型文件夹中取出模型
    其中tags字段是['test_saved_model'], 与存模型时候指定的字段相同
    把模型导入到session之后, 取出signature 就从signature中取出存入的变量了

    saved_model_dir = "save_model"
    signature_key = 'test_signature'
    input_key = 'input_x'
    output_key = 'output'
    
    with tf.Session() as sess1:
    
        meta_graph_def = tf.saved_model.loader.load(sess1, ['test_saved_model'], saved_model_dir)
        signature = meta_graph_def.signature_def
        x_tensor_name = signature[signature_key].inputs[input_key].name
        y_tensor_name = signature[signature_key].outputs[output_key].name
        print(x_tensor_name)
        print(y_tensor_name)
        x = sess1.graph.get_tensor_by_name(x_tensor_name)
        y = sess1.graph.get_tensor_by_name(y_tensor_name)
        feed_dict = {x: np.ones([1, 10])}
        print(sess1.run(y, feed_dict=feed_dict))
    

    我们看到,首先我们从signature 和 inputs/outputs都是一种字典的封装,把tensor_name存入到了字典中
    传统的导入 需要用get_tensor_by_name , 这样就需要记录tensor的name熟悉,很麻烦。
    通过signature,我们可以指定变量的别名,方便存取。

    另外,存模型和变量的时候,会把全部的模型图存入,并不是只存我们指定几个变量,而signature只是方便我们存取想要使用的变量。

    一个坑,使用tf.Session的时候,切记默认图和指定图的区别。tf.Session()会导入默认图的结构, 而导入模型是需要依附于sess的图, 在默认图中导入模型,如果默认图定义了其他计算图,会导致图冲突,模型导不进去。!!!

    相关文章

      网友评论

          本文标题:tensorflow的saved_model存取模型

          本文链接:https://www.haomeiwen.com/subject/lxrspftx.html