tf.Graph()
表示实例化一个类,一个用于tensorflow计算和表示用的数据流图。
tf.Graph().as_defaults()
表示将这个类实例,也就是新生成的图作为整个TensorFlow运行环境的默认图。这儿设置的原因是,如果有多线程多个tf.Graph(),就要有一个默认图的概念。
声明情况大体有三种
1.tensor:通过张量本身直接出graph
#-*- coding: utf-8 -*-
import tensorflow as tf
c = tf.constant(4.0)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
c_out = sess.run(c)
print(c_out)
print(c.graph == tf.get_default_graph())
print(c.graph)
print(tf.get_default_graph())
输出
4.0
True
<tensorflow.python.framework.ops.Graph object at 0x7f382f9ef110>
<tensorflow.python.framework.ops.Graph object at 0x7f382f9ef110>
2.通过声明一个默认的,然后定义张量内容,在后面可以调用或保存
# -*- coding: utf-8 -*-
import tensorflow as tf
g = tf.Graph()
with g.as_default():
c = tf.constant(4.0)
sess = tf.Session(graph=g)
c_out = sess.run(c)
print(c_out)
print(g)
print(tf.get_default_graph())
输出
4.0
<tensorflow.python.framework.ops.Graph object at 0x7f65f1cb2fd0>
<tensorflow.python.framework.ops.Graph object at 0x7f65de447c90>
3.通过多个声明,在后面通过变量名来分别调用
# -*- coding: utf-8 -*-
import tensorflow as tf
g1 = tf.Graph()
with g1.as_default():
c1 = tf.constant(4.0)
g2 = tf.Graph()
with g2.as_default():
c2 = tf.constant(20.0)
with tf.Session(graph=g1) as sess1:
print(sess1.run(c1))
with tf.Session(graph=g2) as sess2:
print(sess2.run(c2))
输出
4.0
20.0
对graph的操作大体有三种
1.保存
# -*- coding: utf-8 -*-
import tensorflow as tf
g1 = tf.Graph()
with g1.as_default():
# 需要加上名称,在读取pb文件的时候,是通过name和下标来取得对应的tensor的
c1 = tf.constant(4.0, name='c1')
g2 = tf.Graph()
with g2.as_default():
c2 = tf.constant(20.0)
with tf.Session(graph=g1) as sess1:
print(sess1.run(c1))
with tf.Session(graph=g2) as sess2:
print(sess2.run(c2))
# g1的图定义,包含pb的path, pb文件名,是否是文本默认False
tf.train.write_graph(g1.as_graph_def(),'.','graph.pb',False)
输出
4.0
20.0
1
2
并且在当前文件夹下面生成graph.pb文件
2.从pb文件中调用
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.python.platform import gfile
#load graph
with gfile.FastGFile("./graph.pb",'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
sess = tf.Session()
c1_tensor = sess.graph.get_tensor_by_name("c1:0")
c1 = sess.run(c1_tensor)
print(c1)
输出
4.0
1
3.穿插调用
# -*- coding: utf-8 -*-
import tensorflow as tf
g1 = tf.Graph()
with g1.as_default():
# 声明的变量有名称是一个好的习惯,方便以后使用
c1 = tf.constant(4.0, name="c1")
g2 = tf.Graph()
with g2.as_default():
c2 = tf.constant(20.0, name="c2")
with tf.Session(graph=g2) as sess1:
# 通过名称和下标来得到相应的值
c1_list = tf.import_graph_def(g1.as_graph_def(), return_elements = ["c1:0"], name = '')
print(sess1.run(c1_list[0]+c2))
输出
24.0
tf.GraphDef()
创建一个空的GraphDef对象,用于记录TensorFlow的计算图上的节点信息。
tf.gfile.GFile()
打开pb模型。tf.gfile.GFile('path','rb')
ParasFromString()
得到模型中的计算图和数据
tf.import_graph_def()
将图从GraphDef导入到当前的默认图中
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name=None,
op_dict=None,
producer_op_list=None
)
参数:
-
graph_def
: 包含要导入到默认图中的操作的GraphDef proto。 -
input_map
: 将graph_def中的输入名称(作为字符串)映射到张量对象的字典。输入图中指定的输入张量的值将被重新映射到相应的张量值。 -
return_elements
: 在graph_def中包含操作名的字符串列表,将作为operationobject返回;和/或graph_def中的张量名称,它们将作为张量对象返回。 -
name
: (可选.) 将前缀放在graph_def中名称前面的前缀。注意,这并不适用于导入的函数名。默认为"import"
. -
op_dict
: (可选.) 已弃用,请勿使用 -
producer_op_list
: (可选.) 一个OpList原型,带有(可能是剥离的)图表生产者使用的OpDefs列表。如果提供了,那么根据producer_op_list的默认值,在graph_def中无法识别的ops attrs将被删除。这将允许稍后的二进制文件生成更多的graphdef被早期的二进制文件所接受。
返回:
从导入的图中得到的与return_element中的名称相对应的操作和/或张量对象的列表。
读取tf模型整体流程
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(tf_model_path, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
with detection_graph.as_default():
sess = tf.Session(graph=detection_graph)
return sess, detection_graph
查看pb模型tensor名
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
print(tensor_name)
调用tf的graph进行预测
def tf_inference(sess, detection_graph, img_arr):
'''
Receive an image array and run inference
:param sess: tensorflow session.
:param detection_graph: tensorflow graph.
:param img_arr: 3D numpy array, RGB order.
:return:
'''
image_tensor = detection_graph.get_tensor_by_name('data_1:0')
detection_bboxes = detection_graph.get_tensor_by_name('loc_branch_concat_1/concat:0')
detection_scores = detection_graph.get_tensor_by_name('cls_branch_concat_1/concat:0')
# image_np_expanded = np.expand_dims(img_arr, axis=0)
bboxes, scores = sess.run([detection_bboxes, detection_scores],
feed_dict={image_tensor: img_arr})
return bboxes, scores
网友评论