美文网首页
GazeML深度模型导出

GazeML深度模型导出

作者: SpikeKing | 来源:发表于2020-04-17 21:25 被阅读0次

    工程:GazeML,https://github.com/swook/GazeML

    修改 GazeML/src/core/model.py:以下导出代码放在该文件最后的 inference_generator() 方法中。

    单文件工程:https://github.com/parai/dms

    导出模型:

                # Save the model: freeze the session graph and export it.
                # NOTE(review): when pasted inside the `while True` loop of
                # inference_generator, this runs for every batch; it only
                # needs to run once, so guard it with a flag.
                print('fetches:', fetches)  # to get output tensors' name
                sess = self._tensorflow_session
                from tensorflow.python.framework import graph_util
                # Replace variables with constants, keeping only the subgraph
                # needed to compute these output nodes.
                constant_graph = graph_util.convert_variables_to_constants(
                    sess, sess.graph_def,
                    ['hourglass/hg_2/after/hmap/conv/BiasAdd',  # heatmaps
                     'upscale/mul',  # landmarks
                     'radius/out/fc/BiasAdd',  # radius
                     'Video/fifo_queue_DequeueMany',  # frame_index, eye, eye_index
                     ])
                with tf.gfile.FastGFile('./gaze.pb', mode='wb') as f:
                    f.write(constant_graph.SerializeToString())

                from tensorflow.python.platform import gfile

                # Read the frozen graph back; `with` closes the file handle.
                with gfile.FastGFile('./gaze.pb', "rb") as f:
                    graph_def = tf.GraphDef()
                    graph_def.ParseFromString(f.read())

                # Rewrite ops that take reference (variable) inputs into their
                # value-typed counterparts, because after freezing those inputs
                # are constants.
                for node in graph_def.node:
                    if node.op == 'RefSwitch':
                        node.op = 'Switch'
                        # 'moving_*' inputs are presumably batch-norm moving
                        # statistics; point them at their '/read' node.
                        for index, input_name in enumerate(node.input):
                            if 'moving_' in input_name:
                                node.input[index] = input_name + '/read'
                    elif node.op == 'AssignSub':
                        node.op = 'Sub'
                        if 'use_locking' in node.attr:
                            del node.attr['use_locking']

                # Check that the patched graph imports cleanly. Use a fresh
                # graph so the live session graph is not modified.
                with tf.Graph().as_default():
                    tf.import_graph_def(graph_def, name='')
                tf.train.write_graph(graph_def, './', 'good_frozen.pb', as_text=False)
                tf.train.write_graph(graph_def, './', 'good_frozen.pbtxt', as_text=True)
    

    修改后 inference_generator() 的完整代码:

        def inference_generator(self):
            """Perform inference on test data and yield a batch of output."""
            self.initialize_if_not(training=False)
            self.checkpoint.load_all()  # Load available weights
    
            # TODO: Make more generic by not picking first source
            data_source = next(iter(self._train_data.values()))
            while True:
                fetches = dict(self.output_tensors['train'], **data_source.output_tensors)
                start_time = time.time()
                outputs = self._tensorflow_session.run(
                    fetches=fetches,
                    feed_dict={
                        self.is_training: False,
                        self.use_batch_statistics: True,
                    },
                )
                outputs['inference_time'] = 1e3 * (time.time() - start_time)
    
                # 存储模型
                print('fetches:', fetches)  # to get output tensors' name
                sess = self._tensorflow_session
                from tensorflow.python.framework import graph_util
                constant_graph = graph_util.convert_variables_to_constants(
                    sess, sess.graph_def,
                    ['hourglass/hg_2/after/hmap/conv/BiasAdd',  # heatmaps
                     'upscale/mul',  # landmarks
                     'radius/out/fc/BiasAdd',  # radius
                     'Video/fifo_queue_DequeueMany',  # frame_index, eye, eye_index
                     ])
                with tf.gfile.FastGFile('./gaze.pb', mode='wb') as f:
                    f.write(constant_graph.SerializeToString())
    
                from tensorflow.python.platform import gfile
    
                f = gfile.FastGFile('./gaze.pb', "rb")
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
    
                for node in graph_def.node:
                    if node.op == 'RefSwitch':
                        node.op = 'Switch'
                        for index in range(len(node.input)):
                            if 'moving_' in node.input[index]:
                                node.input[index] = node.input[index] + '/read'
                    elif node.op == 'AssignSub':
                        node.op = 'Sub'
                        if 'use_locking' in node.attr: del node.attr['use_locking']
    
                # import graph into session
                tf.import_graph_def(graph_def, name='')
                tf.train.write_graph(graph_def, './', 'good_frozen.pb', as_text=False)
                tf.train.write_graph(graph_def, './', 'good_frozen.pbtxt', as_text=True)
    
                yield outputs
    

    相关文章

      网友评论

          本文标题:GazeML深度模型导出

          本文链接:https://www.haomeiwen.com/subject/xhvuvhtx.html