美文网首页从零开始机器学习首页投稿(暂停使用,暂停投稿)程序员
从零开始机器学习-3 教会你的AI识别特定的物体(上)

从零开始机器学习-3 教会你的AI识别特定的物体(上)

作者: 养薛定谔的猫 | 来源:发表于2018-03-01 15:11 被阅读656次

    本文由 沈庆阳 所有,转载请与作者取得联系!
    在上一节中,我们成功地运行了TensorFlow的Object Detection API,并且可以使用该程序识别日常生活中一些常见的物体。那么对于没有预置在程序中的物体模型,则需要通过特定的训练来使你的AI程序识别该物体。

    关于COCO:
    COCO又称MS COCO,全称为:Common Objects in Context,即对日常生活中的常见物体进行语境标注的数据集。
    官方网址:http://cocodataset.org/#home
    

    TensorFlow的Object Detection API提供了一些预先训练好的模型。模型的识别精准度和识别速度往往不可兼得,如果想要达到更高的识别精准度则需要花更长的时间进行训练,我们往往需要在精准度和速度之间做取舍。
    训练特定的目标模型需要我们收集关于该物体(也称目标)的足够多的图片(通常在500~1000上下)。在我们收集完足够多的图片之后,再一张张框出图片中识别的目标,标上标签,并分为训练(Train)和测试(Test)两个分组,然后将这些数据生成TF Record文件。

    训练前的准备

    要想训练我们自己的模型,第一件事就是找到待识别物体的足够多数量的照片。对于收集这些照片的途径主要有:
    1、Google或是Baidu的图片搜索引擎
    2、Image Net等计算机视觉系统相关图像库
    对于图像的收集,笔者建议先搜寻Image Net等图像库,如果图像库中没有需要的物体图像,则通过Google或是Baidu等图片搜索引擎进行搜索。由于Image Net拥有完整的数据集,因此此处采用Google等图片搜索引擎搜索图片,而后手动标签的方法。

    收集互联网上的图片

    该教程选择以钢笔为检测的目标。因此,我们需要收集500张左右的照片作为训练用的数据集。
    经过2到3个小时的收集过程,最终收集到了足够多的图片。


    搜集到的含有Pen的图片

    标记收集到的图片

    对于已经收集到的图片,我们需要标记出所有图片中对应的待识别物体的范围。
    标记(Label)图片中的对象,在这里使用一个名为LabelImg的Github上的开源项目。

    LabelImg项目地址:https://github.com/tzutalin/labelImg
    

    将当前工作目录切换到我们想要存放labelImg工具的目录,然后使用git clone命令将仓库复制下来。

    root@ubuntu:~/Dev/tool# git clone https://github.com/tzutalin/labelImg
    正克隆到 'labelImg'...
    remote: Counting objects: 1138, done.
    remote: Compressing objects: 100% (16/16), done.
    remote: Total 1138 (delta 6), reused 14 (delta 6), pack-reused 1116
    接收对象中: 100% (1138/1138), 232.25 MiB | 5.54 MiB/s, 完成.
    处理 delta 中: 100% (643/643), 完成.
    检查连接... 完成。
    

    对于使用源代码编译安装的方式,Github的项目页面中有相应介绍。Ubuntu Linux中对应有Python2 + Qt4和Python3 + Qt5两种安装方式。这里选用Python3+Qt5的方式。

    Python 2 + Qt4

    sudo apt-get install pyqt4-dev-tools
    sudo pip install lxml
    make qt4py2
    python labelImg.py
    python labelImg.py [IMAGE_PATH] [PRE-DEFINED CLASS FILE]
    

    Python 3 + Qt5

    sudo apt-get install pyqt5-dev-tools
    sudo pip3 install lxml
    make qt5py3
    python3 labelImg.py
    python3 labelImg.py [IMAGE_PATH] [PRE-DEFINED CLASS FILE]
    

    使用python3 labelImg.py命令打开labelImg的UI界面。


    LabelImg的UI界面

    刚打开LabelImg的界面,空空如也。因为上批量标记图片,因此我们点击Open Dir来打开存放所有图片的文件夹。打开后,界面如下:


    打开文件夹之后的界面
    上图中,最中间本身灰色的区域显示出了文件夹中的图片。右下角的File List列出了该文件夹下面的所有的图片信息。
    通过点击Create RectBox按钮(快捷键为W),出现了如下图的横竖直线:
    定位线

    通过拖动定位线,标记出图中待标记物体(笔)的方框,并输入待标记物体的名称,点击OK。


    输入目标名称
    对于一张图片中有多个待检测目标的情况,此时则需要再次使用方框讲物体标注。
    常用快捷键
    下一张:快捷键D
    上一张:快捷键A
    创建矩形框:快捷键W
    保存:快捷键Ctrl+S
    

    在标注完毕之后,我们在突然的相同文件夹下会产生同名的xml文件。该xml文件内容如下:

    <annotation>
        <folder>Train</folder>
        <filename>jy5t4wr.jpg</filename>
        <path>/home/jack/桌面/Pen/Train/jy5t4wr.jpg</path>
        <source>
            <database>Unknown</database>
        </source>
        <size>
            <width>1024</width>
            <height>707</height>
            <depth>3</depth>
        </size>
        <segmented>0</segmented>
        <object>
            <name>pen</name>
            <pose>Unspecified</pose>
            <truncated>0</truncated>
            <difficult>0</difficult>
            <bndbox>
                <xmin>204</xmin>
                <ymin>160</ymin>
                <xmax>673</xmax>
                <ymax>616</ymax>
            </bndbox>
        </object>
    </annotation>
    

    在上述xml文件中,我们可以看到有图像文件的内容信息和Object标签的相关信息。
    至此,经过近6个小时的时间,我们的准备工作进行完毕。

    创建TF Record

    TF Record是一种二进制文件。传统的图像与标签往往是分为不同文件存放的(如jpg是图片格式,xml中包含标签等),而在TF Record中每一张输入图像和与其相关的标签则是存放在一个文件中的。TF Record并不对数据进行压缩,所以可以被快速加载到内存中,从而进行大量数据流的读取操作。
    在创建TF Record之前,我们将上一节中标注好的图片与其标注信息分别放入两个文件夹中,一个是test一个是train文件夹。其中,Test文件夹中的图片约为所有图片数量的10%。准备完成之后,将test和train文件夹放入images文件夹中。同时在形同的目录层级下新建名为data和training的文件夹。


    文件夹结构

    新建一个名为xml_to_csv.py的脚本:

    import os
    import glob
    import pandas as pd
    import xml.etree.ElementTree as ET
    
    
    def xml_to_csv(path):
        xml_list = []
        for xml_file in glob.glob(path + '/*.xml'):
            tree = ET.parse(xml_file)
            root = tree.getroot()
            for member in root.findall('object'):
                value = (root.find('filename').text,
                         int(root.find('size')[0].text),
                         int(root.find('size')[1].text),
                         member[0].text,
                         int(member[4][0].text),
                         int(member[4][1].text),
                         int(member[4][2].text),
                         int(member[4][3].text)
                         )
                xml_list.append(value)
        column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
        xml_df = pd.DataFrame(xml_list, columns=column_name)
        return xml_df
    
    
    def main():
        for directory in ['train','test']:
            image_path = os.path.join(os.getcwd(), 'images/{}'.format(directory))
            xml_df = xml_to_csv(image_path)
            xml_df.to_csv('data/{}_labels.csv'.format(directory), index=None)
            print('Successfully converted xml to csv.')
    
    
    main()
    

    该脚本来自https://github.com/datitran/raccoon_dataset,并经过修改
    通过控制台,使用python xml_to_csv.py或python3 xml_to_csv.py运行该脚本。
    当控制台出现如下提示时,脚本执行成功,并且在data目录下会出现test_labels.csv和train_labels.csv两个文件。

    root@jack:~/object_detection# python3 xml_to_csv.py 
    Successfully converted xml to csv.
    Successfully converted xml to csv.
    

    我们成功地讲xml标签转换为csv标签,那么下一步就是生成TF Record了。新建一个名为generate_tfrecord.py的脚本,输入以下内容:

    """
    Usage:
      # From tensorflow/models/
      # Create train data:
      python generate_tfrecord.py --csv_input=data/train_labels.csv  --output_path=train.record
    
      # Create test data:
      python generate_tfrecord.py --csv_input=data/test_labels.csv  --output_path=test.record
    """
    from __future__ import division
    from __future__ import print_function
    from __future__ import absolute_import
    
    import os
    import io
    import pandas as pd
    import tensorflow as tf
    
    from PIL import Image
    from object_detection.utils import dataset_util
    from collections import namedtuple, OrderedDict
    
    flags = tf.app.flags
    flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
    flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
    FLAGS = flags.FLAGS
    
    
    # TO-DO replace this with label map
    def class_text_to_int(row_label):
        if row_label == 'pen':
            return 1
        else:
            None
    
    
    def split(df, group):
        data = namedtuple('data', ['filename', 'object'])
        gb = df.groupby(group)
        return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
    
    
    def create_tf_example(group, path):
        with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = Image.open(encoded_jpg_io)
        width, height = image.size
    
        filename = group.filename.encode('utf8')
        image_format = b'jpg'
        xmins = []
        xmaxs = []
        ymins = []
        ymaxs = []
        classes_text = []
        classes = []
    
        for index, row in group.object.iterrows():
            xmins.append(row['xmin'] / width)
            xmaxs.append(row['xmax'] / width)
            ymins.append(row['ymin'] / height)
            ymaxs.append(row['ymax'] / height)
            classes_text.append(row['class'].encode('utf8'))
            classes.append(class_text_to_int(row['class']))
    
        tf_example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': dataset_util.int64_feature(height),
            'image/width': dataset_util.int64_feature(width),
            'image/filename': dataset_util.bytes_feature(filename),
            'image/source_id': dataset_util.bytes_feature(filename),
            'image/encoded': dataset_util.bytes_feature(encoded_jpg),
            'image/format': dataset_util.bytes_feature(image_format),
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
            'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
            'image/object/class/label': dataset_util.int64_list_feature(classes),
        }))
        return tf_example
    
    
    def main(_):
        writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
        path = os.path.join(os.getcwd(), 'images')
        examples = pd.read_csv(FLAGS.csv_input)
        grouped = split(examples, 'filename')
        for group in grouped:
            tf_example = create_tf_example(group, path)
            writer.write(tf_example.SerializeToString())
    
        writer.close()
        output_path = os.path.join(os.getcwd(), FLAGS.output_path)
        print('Successfully created the TFRecords: {}'.format(output_path))
    
    
    if __name__ == '__main__':
        tf.app.run()
    

    注意第31行的代码中,我们需要讲row_label变量替换为我们自己定义的标签名称,此处我使用的是pen标签。

    ...
    29:# TO-DO replace this with label map
    30:def class_text_to_int(row_label):
    31:    if row_label == 'pen':
    ...
    

    另外,在第20行,导入了object_detection.utils中的包

    from object_detection.utils import dataset_util
    

    因此我们需要讲tensorflow的models安装到计算机中。通过cd命令进入到~/Dev/tensorflow/models/research# 目录,通过

    sudo python3 setup.py install
    

    命令来安装tensorflow的相关python包。
    安装完成之后,通过如下命令生成TF Record文件,对于train和test需要分别运行两次。

      python3 generate_tfrecord.py --csv_input=data/train_labels.csv  --output_path=data/train.record
    
      python3 generate_tfrecord.py --csv_input=data/test_labels.csv  --output_path=data/test.record
    
    成功生成了TF Record文件

    觉得写的不错的朋友可以点一个 喜欢♥ ~
    谢谢你的支持!

    相关文章

      网友评论

      • 一水寒Gd:安装object_detection成功 运行model_builder_test.py成功
        但还是 from object_detection.utils import dataset_util 报错 找不到 object_detection 楼主遇到过吗
      • f396978aa4ea:最后的结果是啥,就生成了record文件啊
        养薛定谔的猫:@EVILcreative 还没写完😳😳😳所以写了一个上 ,明天把下给写了

      本文标题:从零开始机器学习-3 教会你的AI识别特定的物体(上)

      本文链接:https://www.haomeiwen.com/subject/oputtftx.html