美文网首页TensorFlow阿群的参考资料
保存tensorflow模型为pb文件

保存tensorflow模型为pb文件

作者: 夕一啊 | 来源:发表于2018-09-17 11:35 被阅读322次

    通常训练模型的时候是保存ckpt方便接着训练,但是上线可以保存为pb模型,加载的时候不需要重新定义模型,只用输入输出来调用模型。

    import tensorflow as tf
    from tensorflow.python.saved_model import builder as saved_model_builder
    from tensorflow.python.saved_model import (signature_constants, signature_def_utils, tag_constants, utils)
    
    class model():
        def __init__(self):
            self.a = tf.placeholder(tf.float32, [None])
            self.w = tf.Variable(tf.constant(2.0, shape=[1]), name="w")
            b = tf.Variable(tf.constant(0.5, shape=[1]), name="b")
            self.y = self.a * self.w + b
    
    #模型保存为ckpt
    def save_model(): 
        m = model()
        session = tf.Session()
        session.run(tf.global_variables_initializer())
        update = tf.assign(m.w, [10])
        session.run(update)
        predict_y = session.run(m.y,feed_dict={m.a:[3.0]})
        print(predict_y)
    
        saver = tf.train.Saver()
        saver.save(session,"model/model.ckpt")
        session.close()
    
    #加载ckpt模型
    def load_model():
        m = model()
        session = tf.Session()
        saver = tf.train.Saver()
        saver.restore(session, "model/model.ckpt")
        predict_y = session.run(m.y, feed_dict={m.a: [3.0]})
        print(predict_y)
        return session,m
    
    #保存为pb模型
    def export_model(session,m):
    
       #只需要修改这一段,定义输入输出,其他保持默认即可
        model_signature = signature_def_utils.build_signature_def(
            inputs={"input": utils.build_tensor_info(m.a)},
            outputs={
                "output": utils.build_tensor_info(m.y)},
    
            method_name=signature_constants.PREDICT_METHOD_NAME)
    
        export_path = "pb_model/1"
        print("Export the model to {}".format(export_path))
    
        try:
            legacy_init_op = tf.group(
                tf.tables_initializer(), name='legacy_init_op')
            builder = saved_model_builder.SavedModelBuilder(export_path)
            builder.add_meta_graph_and_variables(
                session, [tag_constants.SERVING],
                clear_devices=True,
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        model_signature,
                },
                legacy_init_op=legacy_init_op)
    
            builder.save()
        except Exception as e:
            print("Fail to export saved model, exception: {}".format(e))
    
    #加载pb模型
    def load_pb():
        session = tf.Session(graph=tf.Graph())
        model_file_path = "pb_model/1"
        meta_graph = tf.saved_model.loader.load(session, [tf.saved_model.tag_constants.SERVING], model_file_path)
    
        model_graph_signature = list(meta_graph.signature_def.items())[0][1]
        output_tensor_names = []
        output_op_names = []
        for output_item in model_graph_signature.outputs.items():
            output_op_name = output_item[0]
            output_op_names.append(output_op_name)
            output_tensor_name = output_item[1].name
            output_tensor_names.append(output_tensor_name)
        print("load model finish!")
        sentences = {}
        for test_x in [[1],[2],[3],[4],[5]]:
    
            sentences["input"] = test_x
            feed_dict_map = {}
            for input_item in model_graph_signature.inputs.items():
                input_op_name = input_item[0]
                input_tensor_name = input_item[1].name
                feed_dict_map[input_tensor_name] = sentences[input_op_name]
            predict_y = session.run(output_tensor_names, feed_dict=feed_dict_map)
            print("predict pb y:",predict_y)
    
    if __name__ == "__main__":
        save_model()     
        session, m = load_model()
        export_model(session, m)
        load_pb()
    
    

    save_model 和load_model两个函数要分开执行,第一次注释掉load,只save,第二次load的时候注释掉save。因为声明模型的时候都是用默认图,变量命名会依次是0,1,load的时候名字对应不上。

    保存好的pb模型路径文件格式为


    image.png

    相关文章

      网友评论

      • 阿群1986:graph2 =tf.Graph()
        sess2=tf.Session(graph2)
        夕一啊:@阿群1986 是的,这样也可以。https://zhuanlan.zhihu.com/p/32887066 这里还有另一种保存pb模型的方式,而且图和参数保存在一个pb文件里,看起来更简洁。好像有很多api都可以保存为pb不知道它们的差异在哪
        阿群1986:load model用with graph2.as_default(): 新图就不会冲突
      • 阿群1986:#加载ckpt模型
        def load_model():
        … # m = model() 这一行被注释掉了,有问题
        … session = tf.Session()
        … # saver = tf.train.Saver()
        … saver = tf.train.import_meta_graph("model/model.ckpt.meta")
        … saver.restore(session, "model/model.ckpt")
        … predict_y = session.run(m.y, feed_dict={m.a: [3.0]})
        … print(predict_y)
        … return session,m
        阿群1986:好的
        夕一啊:谢谢提醒,已经纠正过来了,saver那行也有问题,是读图的模式了,也改过来了

      本文标题:保存tensorflow模型为pb文件

      本文链接:https://www.haomeiwen.com/subject/jtiknftx.html