Exporting the GazeML Deep Model

Author: SpikeKing | Published: 2020-04-17 21:25

Project: GazeML, https://github.com/swook/GazeML

File to modify: GazeML/src/core/model.py. The export code goes at the end of the file, inside the inference_generator method.

Single-file project: https://github.com/parai/dms

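Exporting the model (print fetches to look up the output tensor names, freeze the variables into constants, then patch the batch-norm ops so the frozen graph can be loaded again):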

            # Save the model
            print('fetches:', fetches)  # print to find the output tensors' names
            sess = self._tensorflow_session
            from tensorflow.python.framework import graph_util
            # Freeze: replace every variable with a constant holding its current value
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def,
                ['hourglass/hg_2/after/hmap/conv/BiasAdd',  # heatmaps
                 'upscale/mul',  # landmarks
                 'radius/out/fc/BiasAdd',  # radius
                 'Video/fifo_queue_DequeueMany',  # frame_index, eye, eye_index
                 ])
            with tf.gfile.FastGFile('./gaze.pb', mode='wb') as f:
                f.write(constant_graph.SerializeToString())

            # Read the frozen graph back (the with-block closes the file)
            graph_def = tf.GraphDef()
            with tf.gfile.FastGFile('./gaze.pb', mode='rb') as f:
                graph_def.ParseFromString(f.read())

            # Batch-norm fix: after freezing, moving_mean/moving_variance are
            # constants, not variable references, so rewrite the ops that need refs
            for node in graph_def.node:
                if node.op == 'RefSwitch':
                    node.op = 'Switch'
                    for index in range(len(node.input)):
                        if 'moving_' in node.input[index]:
                            node.input[index] = node.input[index] + '/read'
                elif node.op == 'AssignSub':
                    node.op = 'Sub'
                    if 'use_locking' in node.attr:
                        del node.attr['use_locking']

            # Check that the patched graph imports cleanly; use a fresh graph
            # so the live inference graph does not grow on every iteration
            with tf.Graph().as_default():
                tf.import_graph_def(graph_def, name='')
            tf.train.write_graph(graph_def, './', 'good_frozen.pb', as_text=False)
            tf.train.write_graph(graph_def, './', 'good_frozen.pbtxt', as_text=True)
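convert_variables_to_constants expects node names, while the printed fetches show tensor names of the form node:index. One node can also feed several fetched tensors, which is why 'Video/fifo_queue_DequeueMany' is listed once for frame_index, eye and eye_index. The helper below is a minimal pure-Python sketch (our own, not part of TensorFlow or GazeML) that turns tensor names into the node-name list; the example strings are hypothetical stand-ins for what the print shows, and in the session they could come from `[t.name for t in fetches.values()]`, assuming every value in fetches is a tf.Tensor.

```python
from typing import Iterable, List, Tuple


def split_tensor_name(tensor_name: str) -> Tuple[str, int]:
    """Split a tensor name such as 'upscale/mul:0' into ('upscale/mul', 0).

    A name without a numeric ':<index>' suffix is treated as output 0.
    """
    node, sep, index = tensor_name.rpartition(':')
    if not sep or not index.isdigit():
        return tensor_name, 0
    return node, int(index)


def output_node_names(tensor_names: Iterable[str]) -> List[str]:
    """Collect unique node names in first-seen order.

    Several tensors may come from one node (e.g. a DequeueMany op with
    several outputs), so repeated node names are dropped.
    """
    seen = set()
    names = []
    for tensor_name in tensor_names:
        node, _ = split_tensor_name(tensor_name)
        if node not in seen:
            seen.add(node)
            names.append(node)
    return names


# Hypothetical tensor names, shaped like the ones printed from fetches
fetched = [
    'hourglass/hg_2/after/hmap/conv/BiasAdd:0',
    'upscale/mul:0',
    'radius/out/fc/BiasAdd:0',
    'Video/fifo_queue_DequeueMany:0',
    'Video/fifo_queue_DequeueMany:1',
    'Video/fifo_queue_DequeueMany:2',
]
print(output_node_names(fetched))
# ['hourglass/hg_2/after/hmap/conv/BiasAdd', 'upscale/mul',
#  'radius/out/fc/BiasAdd', 'Video/fifo_queue_DequeueMany']
```

The resulting list is what goes into the third argument of convert_variables_to_constants.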

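Full code of inference_generator (the export runs on every pass through the loop, so the .pb files are rewritten for each batch):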

    def inference_generator(self):
        """Perform inference on test data and yield a batch of output."""
        self.initialize_if_not(training=False)
        self.checkpoint.load_all()  # Load available weights

        # TODO: Make more generic by not picking first source
        data_source = next(iter(self._train_data.values()))
        while True:
            fetches = dict(self.output_tensors['train'], **data_source.output_tensors)
            start_time = time.time()
            outputs = self._tensorflow_session.run(
                fetches=fetches,
                feed_dict={
                    self.is_training: False,
                    self.use_batch_statistics: True,
                },
            )
            outputs['inference_time'] = 1e3 * (time.time() - start_time)

            # Save the model
            print('fetches:', fetches)  # print to find the output tensors' names
            sess = self._tensorflow_session
            from tensorflow.python.framework import graph_util
            # Freeze: replace every variable with a constant holding its current value
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def,
                ['hourglass/hg_2/after/hmap/conv/BiasAdd',  # heatmaps
                 'upscale/mul',  # landmarks
                 'radius/out/fc/BiasAdd',  # radius
                 'Video/fifo_queue_DequeueMany',  # frame_index, eye, eye_index
                 ])
            with tf.gfile.FastGFile('./gaze.pb', mode='wb') as f:
                f.write(constant_graph.SerializeToString())

            # Read the frozen graph back (the with-block closes the file)
            graph_def = tf.GraphDef()
            with tf.gfile.FastGFile('./gaze.pb', mode='rb') as f:
                graph_def.ParseFromString(f.read())

            # Batch-norm fix: after freezing, moving_mean/moving_variance are
            # constants, not variable references, so rewrite the ops that need refs
            for node in graph_def.node:
                if node.op == 'RefSwitch':
                    node.op = 'Switch'
                    for index in range(len(node.input)):
                        if 'moving_' in node.input[index]:
                            node.input[index] = node.input[index] + '/read'
                elif node.op == 'AssignSub':
                    node.op = 'Sub'
                    if 'use_locking' in node.attr:
                        del node.attr['use_locking']

            # Check that the patched graph imports cleanly; use a fresh graph
            # so the live inference graph does not grow on every iteration
            with tf.Graph().as_default():
                tf.import_graph_def(graph_def, name='')
            tf.train.write_graph(graph_def, './', 'good_frozen.pb', as_text=False)
            tf.train.write_graph(graph_def, './', 'good_frozen.pbtxt', as_text=True)

            yield outputs
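The RefSwitch/AssignSub loop can be exercised without TensorFlow by modeling only the three NodeDef fields it touches (op, input, attr). The Node class and the node names in the usage lines are hypothetical stand-ins, not TensorFlow's NodeDef or GazeML's real graph; this is a sketch under that assumption, showing what the rewrite does to each kind of node.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    """Hypothetical stand-in for a NodeDef: only the fields the patch uses."""
    name: str
    op: str
    input: List[str] = field(default_factory=list)
    attr: Dict[str, object] = field(default_factory=dict)


def patch_batchnorm_ops(nodes: List[Node]) -> None:
    """Apply the same in-place rewrite as the export code."""
    for node in nodes:
        if node.op == 'RefSwitch':
            # RefSwitch expects a variable reference, which a frozen constant
            # cannot provide; Switch works on plain values instead.
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            # AssignSub updates a variable in place; Sub only computes a - b
            # and has no use_locking attribute, so drop it.
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


# Hypothetical nodes resembling a frozen batch-norm layer
nodes = [
    Node('bn/cond/Switch', 'RefSwitch', ['bn/moving_mean', 'bn/cond/pred_id']),
    Node('bn/AssignMovingAvg', 'AssignSub', ['bn/moving_mean', 'bn/mul'],
         {'use_locking': False, 'T': 'float'}),
    Node('conv/BiasAdd', 'BiasAdd', ['conv/Conv2D', 'conv/bias/read']),
]
patch_batchnorm_ops(nodes)
for n in nodes:
    print(n.name, n.op, n.input, n.attr)
```

Only the batch-norm update path changes; ordinary ops such as BiasAdd pass through untouched.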


Article link: https://www.haomeiwen.com/subject/xhvuvhtx.html