在Windows下使用Tensorflow Object Det

作者: Daisy丶 | 来源:发表于2017-07-09 19:20 被阅读7458次

    Tensorflow Object Detection API是Tensorflow官方发布的一个建立在TensorFlow之上的开源框架,可以轻松构建,训练和部署对象检测模型。TensorFlow官方使用TensorFlow Slim项目框实现了近年来提出的多种优秀的深度卷积神经网络框架。

    Tensorflow Object Detection API可以选择的模型:

    • Single Shot Multibox Detector (SSD) with MobileNet,
    • SSD with Inception V2,
    • Region-Based Fully Convolutional Networks (R-FCN) with Resnet 101,
    • Faster RCNN with Resnet 101,
    • Faster RCNN with Inception Resnet v2

    Githubhttps://github.com/tensorflow/models/tree/master/object_detection

    在本文中,我们实现了在Windows环境下运行该框架的流程。在此之前我们要使用相关的卷积模型,需要自行编译作者指定的Caffe,不同的框架使用的Caffe版本也不尽相同。而基于其他深度学习框架的代码受制于作者水平的不同,可用性与效率也不尽相同,因此TOD API在Tensorflow上提供了了一套标准化的编写模式,既有利于使用,也有为编写其他模型提供了例子。

    环境

    • Windows 10
    • Python 3.6
    • Tensorflow-gpu 1.2
    • CUDA Toolkit 8与 cuDNN v5

    首先我们安装Tensorflow,最新的版本为1.2。在python 3.5+使用Tensorflow非常的简单,不需要过多的流程,只需要使用pip进行安装,所有相关的依赖就会自动安装完成。

    # For CPU
    pip install tensorflow
    # For GPU
    pip install tensorflow-gpu
    

    其次官方要求下列包,我们一同使用pip进行安装。

    pip install pillow
    pip install lxml
    pip install jupyter
    pip install matplotlib
    

    Tensorflow Object Detection API使用Protobufs来配置模型和训练参数。 在使用框架之前,必须编译Protobuf库。对于protobuf,在Linux下我们可以使用apt-get安装,在Windows下我们可以直接下载已经编译好的版本,这里我们选择下载列表中的protoc-3.3.0-win32.zip。

    Githubhttps://github.com/google/protobuf/releases

    我们将bin文件夹加入到环境变量中,然后在CMD执行protco命令,可以看到protobuf要求输入文件。

    protoc.jpg

    接下来我们切换到models目录下,使用protoc命令编译.proto文件

    # From tensorflow/models/
    protoc object_detection/protos/*.proto --python_out=.
    

    我们可以看见.proto文件已经被编译为了.py文件。

    proto.jpg

    官方提供了一个object_detection_tutorial.ipynb文件,这个Demo会自动下载并执行最小最快的模型Single Shot Multibox Detector (SSD) with MobileNet。检测结果如下:

    1.png 2.png

    为了方便在项目中使用,我们重写了一个Python文件,其中网络模型可以从下面的地址下载,每一个模型都有一个frozen_inference_graph.pb文件。代码与运行结果如下:

    Tensorflow detection model:
    https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md

    # coding:utf8
    import os
    import sys
    import cv2
    import numpy as np
    import tensorflow as tf
    sys.path.append("..")
    
    from utils import label_map_util
    from utils import visualization_utils as vis_util
    
    
    class TOD(object):
        def __init__(self):
            # Path to frozen detection graph. This is the actual model that is used for the object detection.
            self.PATH_TO_CKPT = 'frozen_inference_graph.pb'
    
            # List of the strings that is used to add correct label for each box.
            self.PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
    
            self.NUM_CLASSES = 90
    
            self.detection_graph = self._load_model()
            self.category_index = self._load_label_map()
    
        def _load_model(self):
            detection_graph = tf.Graph()
            with detection_graph.as_default():
                od_graph_def = tf.GraphDef()
                with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                    serialized_graph = fid.read()
                    od_graph_def.ParseFromString(serialized_graph)
                    tf.import_graph_def(od_graph_def, name='')
            return detection_graph
    
        def _load_label_map(self):
            label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
            categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True)
            category_index = label_map_util.create_category_index(categories)
            return category_index
    
        def detect(self, image):
            with self.detection_graph.as_default():
                with tf.Session(graph=self.detection_graph) as sess:
                    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                    image_np_expanded = np.expand_dims(image, axis=0)
                    image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                    # Each box represents a part of the image where a particular object was detected.
                    boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                    # Each score represent how level of confidence for each of the objects.
                    # Score is shown on the result image, together with the class label.
                    scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                    classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                    num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                    # Actual detection.
                    (boxes, scores, classes, num_detections) = sess.run(
                        [boxes, scores, classes, num_detections],
                        feed_dict={image_tensor: image_np_expanded})
                    # Visualization of the results of a detection.
                    vis_util.visualize_boxes_and_labels_on_image_array(
                        image,
                        np.squeeze(boxes),
                        np.squeeze(classes).astype(np.int32),
                        np.squeeze(scores),
                        self.category_index,
                        use_normalized_coordinates=True,
                        line_thickness=8)
    
            while True:
                cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
                cv2.imshow("detection", image)
                if cv2.waitKey(110) & 0xff == 27:
                    break
    
    
    if __name__ == '__main__':
        image = cv2.imread('dog.jpg')
        detecotr = TOD()
        detecotr.detect(image)
    
    
    test.jpg

    相关文章

      网友评论

      • 0d1968c972e6:tensorflow.python.framework.errors_impl.NotFoundError: NewRandomAccessFile failed to Create/Open: data\mscoco_label_map.pbtxt 什么意思请问
      • 1d7e4a095a43:您好我使用protoc命令编译.proto文件时候出现了object_detection/protos/*.proto: No such file or directory 请问这是怎么回事
        Daisy丶:@裴_13ec 可以试试用绝对路径
        1d7e4a095a43:@Alnxl 在window下没有实现呢!我在ubuntu下按照步骤就能解决
        Alnxl:我也遇到这个问题,请问解决了没:joy:
      • 五师傅:“为了方便在项目中使用,我们重写了一个Python文件”
        你们的这个python文件不是一个好的detect script,非常容易识别成错误的object。detect同一个图片时,识别出来的结果跟eval任务中有很大的区别。归根结底是由于你们读入图片和预处理图片时与TF框架要求的不一致导致的。

      本文标题:在Windows下使用Tensorflow Object Det

      本文链接:https://www.haomeiwen.com/subject/tynshxtx.html