Object Detection API (4): Freeze

Author: AI小白龙 | Published 2018-04-30 20:02

    Object Detection API (4): Freeze Model (Model Export)

    Blog: https://blog.csdn.net/qq_34106574

    Jianshu: https://www.jianshu.com/u/fb86cd4f8bf8

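    The previous section trained and tested a model on custom TFRecord data. This section exports the trained model as a .pb file so that programs can call it easily; a later section will show how to call the trained model from a C++ program using OpenCV.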

    1. Export the model:

    The object_detection directory also provides export_inference_graph.py. Run it directly with the following command:

    python ../research/object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path train/faster_rcnn_resnet101_coco.config --trained_checkpoint_prefix train/model.ckpt-1738 --output_directory out
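The step number in --trained_checkpoint_prefix (1738 here) belongs to this particular training run, so it changes every time you train. Below is a minimal sketch, not part of the Object Detection API, that picks the newest model.ckpt-N prefix in train/ by numeric step and assembles the same command as an argument list. The helper names latest_checkpoint_prefix and export_command are hypothetical, and the sketch assumes TF 1.x checkpoint naming (model.ckpt-N.index). If the "checkpoint" bookkeeping file is present, tf.train.latest_checkpoint('train') is an alternative.

```python
import os
import re

# Matches a TF 1.x checkpoint index file such as "model.ckpt-1738.index"
_CKPT_INDEX = re.compile(r'^model\.ckpt-(\d+)\.index$')


def latest_checkpoint_prefix(train_dir):
    """Return '<train_dir>/model.ckpt-N' for the largest step N, or None."""
    steps = [int(m.group(1))
             for m in (_CKPT_INDEX.match(name) for name in os.listdir(train_dir))
             if m]
    if not steps:
        return None
    # Compare numerically: as strings, '500' would sort after '1738'
    return os.path.join(train_dir, 'model.ckpt-%d' % max(steps))


def export_command(pipeline_config, checkpoint_prefix, output_dir,
                   script='../research/object_detection/export_inference_graph.py'):
    """Build the export command shown above as an argument list."""
    return ['python', script,
            '--input_type', 'image_tensor',
            '--pipeline_config_path', pipeline_config,
            '--trained_checkpoint_prefix', checkpoint_prefix,
            '--output_directory', output_dir]
```

Run from the same working directory as the command above, e.g. subprocess.run(export_command('train/faster_rcnn_resnet101_coco.config', latest_checkpoint_prefix('train'), 'out'), check=True).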

    After the export finishes, the out directory contains frozen_inference_graph.pb, model.ckpt.data-00000-of-00001, model.ckpt.index and model.ckpt.meta, as shown in the figure below:
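Only frozen_inference_graph.pb is loaded by the calling code in the following steps. A quick sanity check before moving on might look like this; it is a sketch assuming the output directory is out as in the export command, and find_frozen_graph is a hypothetical helper, not an API function.

```python
import os


def find_frozen_graph(out_dir, name='frozen_inference_graph.pb'):
    """Return the path of the exported frozen graph; raise if missing or empty."""
    path = os.path.join(out_dir, name)
    if not os.path.isfile(path):
        # List what *is* there to make a failed or misdirected export obvious
        found = sorted(os.listdir(out_dir)) if os.path.isdir(out_dir) else []
        raise FileNotFoundError('%s not found; %s contains: %s' % (name, out_dir, found))
    if os.path.getsize(path) == 0:
        raise ValueError('%s is empty; the export probably failed' % path)
    return path
```

For example, find_frozen_graph('out') returns 'out/frozen_inference_graph.pb' once the export has succeeded.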

     

    2. Test by calling the exported model:

    (1) Create a mymodel folder in the object_detection directory, put the .pb file in the data directory, and put the test images in the test_images directory.

     

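    (2) Modify object_detection_tutorial.ipynb so that it loads your frozen graph, label map and number of classes, and reads your own test images.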

    (3) Run the following command in the object_detection directory:

    jupyter notebook object_detection_tutorial.ipynb

     

    (4) The test results are as follows:
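For step (2), the changes amount to pointing the notebook's path constants at your own files. This is a sketch under stated assumptions: the variable names mirror the 2018-era tutorial notebook as far as I recall (check them against your copy), the layout follows step (1) with the label map assumed to sit next to the .pb in data/, and my_label_map.pbtxt and list_test_images are hypothetical names. Listing the folder avoids renaming your images to the image1.jpg, image2.jpg pattern the notebook assumes by default.

```python
import os

# Assumed layout relative to the object_detection directory (step (1)):
# frozen graph and label map in data/, test images in test_images/.
# 'my_label_map.pbtxt' is a hypothetical file name; use your own.
PATH_TO_CKPT = os.path.join('data', 'frozen_inference_graph.pb')
PATH_TO_LABELS = os.path.join('data', 'my_label_map.pbtxt')
NUM_CLASSES = 1  # number of classes the model was trained on
PATH_TO_TEST_IMAGES_DIR = 'test_images'


def list_test_images(image_dir, exts=('.jpg', '.jpeg', '.png')):
    """Collect every image in image_dir (sorted by file name),
    instead of a hard-coded image1.jpg ... imageN.jpg list."""
    return [os.path.join(image_dir, name)
            for name in sorted(os.listdir(image_dir))
            if name.lower().endswith(exts)]

# In the notebook: TEST_IMAGE_PATHS = list_test_images(PATH_TO_TEST_IMAGES_DIR)
```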

    3. Calling method 2:

    Open PyCharm, create a project, and run the following program:

     

    import cv2
    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (tf.GraphDef, tf.gfile, tf.Session)
    from object_detection.utils import label_map_util
    from object_detection.utils import visualization_utils as vis_util


    class TOD(object):
        """Run a frozen Object Detection API graph on a single image."""

        def __init__(self):
            # Exported frozen graph and the label map used during training
            self.PATH_TO_CKPT = r'XXXX\models-master\object_detection\train\frozen_inference_graph.pb'
            self.PATH_TO_LABELS = r'XXXX\models-master\object_detection\train\my_label_map.pbtxt'
            self.NUM_CLASSES = 1
            self.detection_graph = self._load_model()
            self.category_index = self._load_label_map()

        def _load_model(self):
            # Deserialize the frozen GraphDef and import it into a fresh graph
            detection_graph = tf.Graph()
            with detection_graph.as_default():
                od_graph_def = tf.GraphDef()
                with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                    serialized_graph = fid.read()
                    od_graph_def.ParseFromString(serialized_graph)
                    tf.import_graph_def(od_graph_def, name='')
            return detection_graph

        def _load_label_map(self):
            # Build {class_id: {'id': ..., 'name': ...}} for the visualization helper
            label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
            categories = label_map_util.convert_label_map_to_categories(
                label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True)
            category_index = label_map_util.create_category_index(categories)
            return category_index

        def detect(self, image):
            # cv2.imread returns BGR, but the models are trained on RGB images
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            with self.detection_graph.as_default():
                with tf.Session(graph=self.detection_graph) as sess:
                    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                    image_np_expanded = np.expand_dims(image_rgb, axis=0)
                    image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                    boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                    scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                    classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                    num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                    # Actual detection.
                    (boxes, scores, classes, num_detections) = sess.run(
                        [boxes, scores, classes, num_detections],
                        feed_dict={image_tensor: image_np_expanded})
            # Draw the results on the original BGR image so cv2.imshow shows true colors
            vis_util.visualize_boxes_and_labels_on_image_array(
                image,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                self.category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            cv2.imshow("Result", image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()


    if __name__ == '__main__':
        image = cv2.imread('test.jpg')
        if image is None:
            raise IOError("Could not read test.jpg")
        detector = TOD()
        detector.detect(image)
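detect() only draws the results. To use the detections in your own code, note that detection_boxes holds normalized [ymin, xmin, ymax, xmax] coordinates (hence use_normalized_coordinates=True above), and the outputs are padded, with only the first num_detections rows being real. The following sketch keeps detections above a score threshold and converts them to pixel corners; boxes_to_pixels and the 0.5 default threshold are illustrative choices, not part of the API, and the inputs are assumed to be the arrays returned by sess.run.

```python
import numpy as np


def boxes_to_pixels(boxes, scores, classes, image_shape,
                    num_detections=None, min_score=0.5):
    """Return a list of (class_id, score, (left, top, right, bottom)) tuples.

    boxes are normalized [ymin, xmin, ymax, xmax]; the batch dimension
    of size 1 is flattened away.
    """
    boxes = np.asarray(boxes).reshape(-1, 4)
    scores = np.asarray(scores).reshape(-1)
    classes = np.asarray(classes).reshape(-1).astype(np.int32)
    if num_detections is not None:
        # Only the first num_detections rows are real; the rest is padding
        n = int(np.asarray(num_detections).reshape(-1)[0])
        boxes, scores, classes = boxes[:n], scores[:n], classes[:n]
    height, width = image_shape[:2]
    results = []
    for (ymin, xmin, ymax, xmax), score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue
        # Scale normalized coordinates to pixels (truncating to int)
        results.append((int(cls), float(score),
                        (int(xmin * width), int(ymin * height),
                         int(xmax * width), int(ymax * height))))
    return results
```

Inside detect(), after sess.run, a call such as boxes_to_pixels(boxes, scores, classes, image.shape, num_detections) gives boxes that can be passed straight to cv2.rectangle.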

     


     

     

    Note: for more content and the source code, follow the WeChat official account ML_Study.

    Copyright notice: this is an original article by the author; please contact the author for permission before reposting. https://blog.csdn.net/qq_34106574/article/category/7628923
