美文网首页
TensorFlow模型读取问题

TensorFlow模型读取问题

作者: 一技破万法 | 来源:发表于2020-06-26 15:45 被阅读0次

tf.Graph()

表示实例化一个类,一个用于tensorflow计算和表示用的数据流图。

tf.Graph().as_defaults()

表示将这个类实例,也就是新生成的图作为整个TensorFlow运行环境的默认图。这儿设置的原因是,如果有多线程多个tf.Graph(),就要有一个默认图的概念。

声明情况大体有三种

1.tensor:通过张量本身直接出graph

#-*- coding: utf-8 -*-  
import tensorflow as tf

c = tf.constant(4.0)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
c_out = sess.run(c)
print(c_out)
print(c.graph == tf.get_default_graph())
print(c.graph)
print(tf.get_default_graph())

输出

4.0
True
<tensorflow.python.framework.ops.Graph object at 0x7f382f9ef110>
<tensorflow.python.framework.ops.Graph object at 0x7f382f9ef110>

2.通过声明一个默认的,然后定义张量内容,在后面可以调用或保存

# -*- coding: utf-8 -*-  
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    c = tf.constant(4.0)

sess = tf.Session(graph=g)
c_out = sess.run(c)
print(c_out)
print(g)
print(tf.get_default_graph())

输出

4.0
<tensorflow.python.framework.ops.Graph object at 0x7f65f1cb2fd0>
<tensorflow.python.framework.ops.Graph object at 0x7f65de447c90>

3.通过多个声明,在后面通过变量名来分别调用

# -*- coding: utf-8 -*-  
import tensorflow as tf

g1 = tf.Graph()
with g1.as_default():
    c1 = tf.constant(4.0)

g2 = tf.Graph()
with g2.as_default():
    c2 = tf.constant(20.0)

with tf.Session(graph=g1) as sess1:
    print(sess1.run(c1))
with tf.Session(graph=g2) as sess2:
    print(sess2.run(c2))

输出

4.0
20.0

对graph的操作大体有三种
1.保存

# -*- coding: utf-8 -*-  
import tensorflow as tf

g1 = tf.Graph()
with g1.as_default():
    # 需要加上名称,在读取pb文件的时候,是通过name和下标来取得对应的tensor的
    c1 = tf.constant(4.0, name='c1')

g2 = tf.Graph()
with g2.as_default():
    c2 = tf.constant(20.0)

with tf.Session(graph=g1) as sess1:
    print(sess1.run(c1))
with tf.Session(graph=g2) as sess2:
    print(sess2.run(c2))

# g1的图定义,包含pb的path, pb文件名,是否是文本默认False
tf.train.write_graph(g1.as_graph_def(),'.','graph.pb',False)

输出

4.0
20.0
1
2
并且在当前文件夹下面生成graph.pb文件

2.从pb文件中调用

# -*- coding: utf-8 -*-  
import tensorflow as tf
from tensorflow.python.platform import gfile

#load graph
with gfile.FastGFile("./graph.pb",'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

sess = tf.Session()
c1_tensor = sess.graph.get_tensor_by_name("c1:0")
c1 = sess.run(c1_tensor)
print(c1)

输出

4.0
1

3.穿插调用

# -*- coding: utf-8 -*-  
import tensorflow as tf

g1 = tf.Graph()
with g1.as_default():
    # 声明的变量有名称是一个好的习惯,方便以后使用
    c1 = tf.constant(4.0, name="c1")

g2 = tf.Graph()
with g2.as_default():
    c2 = tf.constant(20.0, name="c2")

with tf.Session(graph=g2) as sess1:
    # 通过名称和下标来得到相应的值
    c1_list = tf.import_graph_def(g1.as_graph_def(), return_elements = ["c1:0"], name = '')
    print(sess1.run(c1_list[0]+c2))

输出

24.0



tf.GraphDef()

创建一个空的GraphDef对象,用于记录TensorFlow的计算图上的节点信息。

tf.gfile.GFile()

打开pb模型。tf.gfile.GFile('path','rb')

ParasFromString()

得到模型中的计算图和数据

tf.import_graph_def()

将图从GraphDef导入到当前的默认图中

tf.import_graph_def(
    graph_def,
    input_map=None,
    return_elements=None,
    name=None,
    op_dict=None,
    producer_op_list=None
)

参数:

  • graph_def: 包含要导入到默认图中的操作的GraphDef proto。
  • input_map: 将graph_def中的输入名称(作为字符串)映射到张量对象的字典。输入图中指定的输入张量的值将被重新映射到相应的张量值。
  • return_elements: 在graph_def中包含操作名的字符串列表,将作为operationobject返回;和/或graph_def中的张量名称,它们将作为张量对象返回。
  • name: (可选.) 将前缀放在graph_def中名称前面的前缀。注意,这并不适用于导入的函数名。默认为"import".
  • op_dict: (可选.) 已弃用,请勿使用
  • producer_op_list: (可选.) 一个OpList原型,带有(可能是剥离的)图表生产者使用的OpDefs列表。如果提供了,那么根据producer_op_list的默认值,在graph_def中无法识别的ops attrs将被删除。这将允许稍后的二进制文件生成更多的graphdef被早期的二进制文件所接受。

返回:
从导入的图中得到的与return_element中的名称相对应的操作和/或张量对象的列表。




读取tf模型整体流程

    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(tf_model_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
            with detection_graph.as_default():
                sess = tf.Session(graph=detection_graph)
                return sess, detection_graph

查看pb模型tensor名

            tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
            for tensor_name in tensor_name_list:
                print(tensor_name)

调用tf的graph进行预测

def tf_inference(sess, detection_graph, img_arr):
    '''
    Receive an image array and run inference
    :param sess: tensorflow session.
    :param detection_graph: tensorflow graph.
    :param img_arr: 3D numpy array, RGB order.
    :return:
    '''
    image_tensor = detection_graph.get_tensor_by_name('data_1:0')
    
    
    detection_bboxes = detection_graph.get_tensor_by_name('loc_branch_concat_1/concat:0')
    detection_scores = detection_graph.get_tensor_by_name('cls_branch_concat_1/concat:0')
    # image_np_expanded = np.expand_dims(img_arr, axis=0)
    bboxes, scores = sess.run([detection_bboxes, detection_scores],
                            feed_dict={image_tensor: img_arr})

    return bboxes, scores

相关文章

网友评论

      本文标题:TensorFlow模型读取问题

      本文链接:https://www.haomeiwen.com/subject/gopyfktx.html