美文网首页
tensorflow将ckpt模型转为pb模型

tensorflow将ckpt模型转为pb模型

作者: 水击长空 | 来源:发表于2019-02-12 15:02 被阅读0次

    获取原网络中的所有节点

    在训练代码中定义好图之后加入以下代码:

    for node in tf.get_default_graph().as_graph_def().node:

        print(node.name)

    主要是要查看最后一个节点的名字

    模型转化

    不再重新建图时, 使用tf.train.import_meta_graph

    def freeze_graph(input_checkpoint,output_graph):

        '''

        :param input_checkpoint:ckpt模型路径

        :param output_graph: pb模型保存路径

        '''

        output_node_names = " " # 填入第一步得到的最后一个节点名

        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

        with tf.Session() as sess:

            saver.restore(sess, input_checkpoint) #恢复图并得到数据

            output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定

                sess=sess,

                input_graph_def=sess.graph_def,# 等于:sess.graph_def

                output_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开

            with tf.gfile.GFile(output_graph, "wb") as f: #保存模型

                f.write(output_graph_def.SerializeToString()) #序列化输出

            print("%d ops in the final graph." % len(output_graph_def.node)) # 统计图中总的操作节点数

    或者修改前传代码,使用tf.train.Saver()

    在前传代码里,restore模型

    restorer = tf.train.Saver(tf.global_variables())

    ckpt = tf.train.get_checkpoint_state(' ') # 填入ckpt模型所在文件夹路径

    model_path = ckpt.model_checkpoint_path # 读取checkpoint文件里的第一行

    with tf.Session() as sess:

        # Create a saver.

        sess.run(tf.local_variables_initializer())

        sess.run(tf.global_variables_initializer())

        try:

            restorer.restore(sess, model_path)

            print(model_path.split('/')[-1] + " restored!")

        except IOError:

            print("checkpoints not found.")

        output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定

            sess=sess,

            input_graph_def=sess.graph_def,  # 等于:sess.graph_def

            output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(out_pb_path, "wb") as f:  # 保存模型

            f.write(output_graph_def.SerializeToString())  # 序列化输出

        print("%d ops in the final graph." % len(output_graph_def.node))

      # 统计图中总的操作节点数

    从pb模型中读取节点

    #coding:utf-8

    import tensorflow as tf

    from tensorflow.python.framework import graph_util

    tf.reset_default_graph()  # 重置计算图

    output_graph_path = 'model/model_tfnew.pb'

    with tf.Session() as sess:

        tf.global_variables_initializer().run()

        output_graph_def = tf.GraphDef()

        # 获得默认的图

        graph = tf.get_default_graph()

        with open(output_graph_path, "rb") as f:

            output_graph_def.ParseFromString(f.read())

            _ = tf.import_graph_def(output_graph_def, name="")

            # 得到当前图有几个操作节点

            print("%d ops in the final graph." % len(output_graph_def.node))

            tensor_name = [tensor.name for tensor in output_graph_def.node]

            print(tensor_name)

            print('---------------------------')

            # 在log_graph文件夹下生产日志文件,可以在tensorboard中可视化模型

            summaryWriter = tf.summary.FileWriter('log_graph/', graph)

            for op in graph.get_operations():

                # print出tensor的name和值

                print(op.name, op.values())

    参考:https://blog.csdn.net/u010397980/article/details/84889174

               https://blog.csdn.net/guyuealian/article/details/82218092

    相关文章

      网友评论

          本文标题:tensorflow将ckpt模型转为pb模型

          本文链接:https://www.haomeiwen.com/subject/qehkeqtx.html