工程:GazeML,https://github.com/swook/GazeML
修改文件 GazeML/src/core/model.py：将下面的导出代码放在文件最后的 `inference_generator` 函数中、`yield outputs` 之前（见下方全部代码）。
单文件工程:https://github.com/parai/dms
导出模型:
# Save (freeze) the model.
# Print the fetch dict so the names of the output tensors listed below can be found.
print('fetches:', fetches)  # to get output tensors' name
sess = self._tensorflow_session
from tensorflow.python.framework import graph_util

# Replace variables with constants, keeping only the ops needed to compute
# the listed output nodes.
constant_graph = graph_util.convert_variables_to_constants(
    sess, sess.graph_def,
    ['hourglass/hg_2/after/hmap/conv/BiasAdd',  # heatmaps
     'upscale/mul',  # landmarks
     'radius/out/fc/BiasAdd',  # radius
     'Video/fifo_queue_DequeueMany',  # frame_index, eye, eye_index
     ])
with tf.gfile.FastGFile('./gaze.pb', mode='wb') as f:
    f.write(constant_graph.SerializeToString())

# Re-read the frozen graph; the context manager closes the file handle
# (the previous version leaked it).
graph_def = tf.GraphDef()
with tf.gfile.FastGFile('./gaze.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# Patch ops that still expect ref-typed (variable) inputs; after freezing,
# those inputs are constants and the GraphDef would fail to import.
# NOTE(review): the 'moving_' inputs are presumably batch-norm
# moving_mean/moving_variance; '<var>/read' is the variable's read node.
for node in graph_def.node:
    if node.op == 'RefSwitch':
        node.op = 'Switch'
        for index, input_name in enumerate(node.input):
            if 'moving_' in input_name:
                node.input[index] = input_name + '/read'
    elif node.op in ('AssignSub', 'AssignAdd'):
        # AssignAdd has the same ref-input problem as AssignSub.
        node.op = 'Sub' if node.op == 'AssignSub' else 'Add'
        if 'use_locking' in node.attr:
            del node.attr['use_locking']

# Check that the patched GraphDef imports cleanly. Use a fresh graph so a
# second copy of the model is not added to the graph the live session uses.
with tf.Graph().as_default():
    tf.import_graph_def(graph_def, name='')
tf.train.write_graph(graph_def, './', 'good_frozen.pb', as_text=False)
tf.train.write_graph(graph_def, './', 'good_frozen.pbtxt', as_text=True)
修改后的 `inference_generator` 函数全部代码：
def inference_generator(self):
    """Perform inference on test data and yield a batch of output.

    Runs the inference fetches in an endless loop, yielding one dict of
    fetched values per batch. After the first batch, the model is exported
    once as a frozen graph:

    * ``./gaze.pb``: raw output of ``convert_variables_to_constants``.
    * ``./good_frozen.pb`` / ``./good_frozen.pbtxt``: the same graph with
      ``RefSwitch``/``AssignSub``/``AssignAdd`` ops patched so it can be
      re-imported.

    The export used to run on every batch. That re-froze the model, rewrote
    the files and imported another copy of the graph each iteration, so the
    graph grew without bound. It now runs once.

    Yields:
        dict: the fetched values, plus ``'inference_time'`` (milliseconds
        spent in ``session.run``).
    """
    self.initialize_if_not(training=False)
    self.checkpoint.load_all()  # Load available weights

    # TODO: Make more generic by not picking first source
    data_source = next(iter(self._train_data.values()))
    exported = False
    while True:
        fetches = dict(self.output_tensors['train'], **data_source.output_tensors)
        start_time = time.time()
        outputs = self._tensorflow_session.run(
            fetches=fetches,
            feed_dict={
                self.is_training: False,
                self.use_batch_statistics: True,
            },
        )
        outputs['inference_time'] = 1e3 * (time.time() - start_time)

        if not exported:
            # Save (freeze) the model -- only once, after the first batch.
            print('fetches:', fetches)  # to get output tensors' name
            sess = self._tensorflow_session
            from tensorflow.python.framework import graph_util

            # Replace variables with constants, keeping only the ops needed
            # to compute the listed output nodes.
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def,
                ['hourglass/hg_2/after/hmap/conv/BiasAdd',  # heatmaps
                 'upscale/mul',  # landmarks
                 'radius/out/fc/BiasAdd',  # radius
                 'Video/fifo_queue_DequeueMany',  # frame_index, eye, eye_index
                 ])
            with tf.gfile.FastGFile('./gaze.pb', mode='wb') as f:
                f.write(constant_graph.SerializeToString())

            # Re-read the frozen graph; the context manager closes the handle.
            graph_def = tf.GraphDef()
            with tf.gfile.FastGFile('./gaze.pb', 'rb') as f:
                graph_def.ParseFromString(f.read())

            # Patch ops that still expect ref-typed (variable) inputs; after
            # freezing, those inputs are constants and import would fail.
            # NOTE(review): 'moving_' inputs are presumably batch-norm
            # moving_mean/moving_variance; '<var>/read' is the read node.
            for node in graph_def.node:
                if node.op == 'RefSwitch':
                    node.op = 'Switch'
                    for index, input_name in enumerate(node.input):
                        if 'moving_' in input_name:
                            node.input[index] = input_name + '/read'
                elif node.op in ('AssignSub', 'AssignAdd'):
                    node.op = 'Sub' if node.op == 'AssignSub' else 'Add'
                    if 'use_locking' in node.attr:
                        del node.attr['use_locking']

            # Check that the patched GraphDef imports cleanly. Use a fresh
            # graph so the running model's graph is not modified.
            with tf.Graph().as_default():
                tf.import_graph_def(graph_def, name='')
            tf.train.write_graph(graph_def, './', 'good_frozen.pb', as_text=False)
            tf.train.write_graph(graph_def, './', 'good_frozen.pbtxt', as_text=True)
            exported = True

        yield outputs
网友评论