Tensorflow Object Detection API是Tensorflow官方发布的一个建立在TensorFlow之上的开源框架,可以轻松构建,训练和部署对象检测模型。TensorFlow官方使用TensorFlow Slim项目框实现了近年来提出的多种优秀的深度卷积神经网络框架。
Tensorflow Object Detection API可以选择的模型:
- Single Shot Multibox Detector (SSD) with MobileNet,
- SSD with Inception V2,
- Region-Based Fully Convolutional Networks (R-FCN) with Resnet 101,
- Faster RCNN with Resnet 101,
- Faster RCNN with Inception Resnet v2
Github:https://github.com/tensorflow/models/tree/master/object_detection
在本文中,我们实现了在Windows环境下运行该框架的流程。在此之前我们要使用相关的卷积模型,需要自行编译作者指定的Caffe,不同的框架使用的Caffe版本也不尽相同。而基于其他深度学习框架的代码受制于作者水平的不同,可用性与效率也不尽相同,因此TOD API在Tensorflow上提供了了一套标准化的编写模式,既有利于使用,也有为编写其他模型提供了例子。
环境
- Windows 10
- Python 3.6
- Tensorflow-gpu 1.2
- CUDA Toolkit 8与 cuDNN v5
首先我们安装Tensorflow,最新的版本为1.2。在python 3.5+使用Tensorflow非常的简单,不需要过多的流程,只需要使用pip进行安装,所有相关的依赖就会自动安装完成。
# For CPU
pip install tensorflow
# For GPU
pip install tensorflow-gpu
其次官方要求下列包,我们一同使用pip进行安装。
pip install pillow
pip install lxml
pip install jupyter
pip install matplotlib
Tensorflow Object Detection API使用Protobufs来配置模型和训练参数。 在使用框架之前,必须编译Protobuf库。对于protobuf,在Linux下我们可以使用apt-get安装,在Windows下我们可以直接下载已经编译好的版本,这里我们选择下载列表中的protoc-3.3.0-win32.zip。
Github:https://github.com/google/protobuf/releases
我们将bin文件夹加入到环境变量中,然后在CMD执行protco命令,可以看到protobuf要求输入文件。
protoc.jpg接下来我们切换到models目录下,使用protoc命令编译.proto文件
# From tensorflow/models/
protoc object_detection/protos/*.proto --python_out=.
我们可以看见.proto文件已经被编译为了.py文件。
proto.jpg官方提供了一个object_detection_tutorial.ipynb文件,这个Demo会自动下载并执行最小最快的模型Single Shot Multibox Detector (SSD) with MobileNet。检测结果如下:
1.png 2.png为了方便在项目中使用,我们重写了一个Python文件,其中网络模型可以从下面的地址下载,每一个模型都有一个frozen_inference_graph.pb文件。代码与运行结果如下:
Tensorflow detection model:
https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md
# coding:utf8
import os
import sys
import cv2
import numpy as np
import tensorflow as tf
sys.path.append("..")
from utils import label_map_util
from utils import visualization_utils as vis_util
class TOD(object):
def __init__(self):
# Path to frozen detection graph. This is the actual model that is used for the object detection.
self.PATH_TO_CKPT = 'frozen_inference_graph.pb'
# List of the strings that is used to add correct label for each box.
self.PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
self.NUM_CLASSES = 90
self.detection_graph = self._load_model()
self.category_index = self._load_label_map()
def _load_model(self):
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
return detection_graph
def _load_label_map(self):
label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
return category_index
def detect(self, image):
with self.detection_graph.as_default():
with tf.Session(graph=self.detection_graph) as sess:
# Expand dimensions since the model expects images to have shape: [1, None, None, 3]
image_np_expanded = np.expand_dims(image, axis=0)
image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
# Each box represents a part of the image where a particular object was detected.
boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
# Each score represent how level of confidence for each of the objects.
# Score is shown on the result image, together with the class label.
scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
# Actual detection.
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
# Visualization of the results of a detection.
vis_util.visualize_boxes_and_labels_on_image_array(
image,
np.squeeze(boxes),
np.squeeze(classes).astype(np.int32),
np.squeeze(scores),
self.category_index,
use_normalized_coordinates=True,
line_thickness=8)
while True:
cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
cv2.imshow("detection", image)
if cv2.waitKey(110) & 0xff == 27:
break
if __name__ == '__main__':
image = cv2.imread('dog.jpg')
detecotr = TOD()
detecotr.detect(image)
test.jpg
网友评论
你们的这个python文件不是一个好的detect script,非常容易识别成错误的object。detect同一个图片时,识别出来的结果跟eval任务中有很大的区别。归根结底是由于你们读入图片和预处理图片时与TF框架要求的不一致导致的。