The Object Detection API in TensorFlow

Author: Mouse_HH | Published 2017-10-19 19:26

    Preparing Inputs

    Heads-up: this is a code-heavy post.

    The Tensorflow Object Detection API reads its input data in the TFRecord file format, and it ships two example conversion scripts, create_pascal_tf_record.py and create_pet_tf_record.py. Here we read through create_pascal_tf_record.py in detail.

    If you already know how to read and write TFRecord files, feel free to skip ahead.
    First, the license:

    # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # 
    

    The script's job is to convert the PASCAL VOC dataset into a TFRecord file.
    Usage:

      ./create_pascal_tf_record --data_dir=/home/user/VOCdevkit \
          --year=VOC2012 \
          --output_path=/home/user/pascal.record
    

    Imports

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import hashlib
    import io
    import logging
    import os
    
    from lxml import etree
    import PIL.Image
    import tensorflow as tf
    
    from object_detection.utils import dataset_util
    from object_detection.utils import label_map_util
    

    Up to this point the script is only importing libraries; install any you are missing and move on.

    flags

    flags = tf.app.flags
    flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
    flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
                        'merged set.')
    flags.DEFINE_string('annotations_dir', 'Annotations',
                        '(Relative) path to annotations directory.')
    flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
    flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
    flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
                        'Path to label map proto')
    flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                         'difficult instances')
    FLAGS = flags.FLAGS
    

    TensorFlow's flags module plays the same role as parsing argv by hand. The basic pattern is flags.DEFINE_<type>('name', default_value, 'description'). For more on flags, see the (Chinese-language) tutorial "tensorflow 学习(三)使用flags定义命令行参数".
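tf.app.flags (TensorFlow 1.x) behaves much like the standard library's argparse. As an analogy only (this is not the API the script uses), the same flag definitions could be sketched like this:

```python
import argparse

# Mirror the script's flag definitions with argparse, as an analogy to tf.app.flags.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='',
                    help='Root directory to raw PASCAL VOC dataset.')
parser.add_argument('--set', default='train',
                    help='Convert training set, validation set or merged set.')
parser.add_argument('--annotations_dir', default='Annotations',
                    help='(Relative) path to annotations directory.')
parser.add_argument('--year', default='VOC2007', help='Desired challenge year.')
parser.add_argument('--output_path', default='', help='Path to output TFRecord')
parser.add_argument('--label_map_path', default='data/pascal_label_map.pbtxt',
                    help='Path to label map proto')
parser.add_argument('--ignore_difficult_instances', action='store_true',
                    help='Whether to ignore difficult instances')

# Parsing a sample command line gives an object analogous to FLAGS.
FLAGS = parser.parse_args(['--data_dir', '/home/user/VOCdevkit',
                           '--year', 'VOC2012'])
print(FLAGS.year)  # flag supplied on the command line
print(FLAGS.set)   # flag left at its default
```

Flags that are not supplied keep their defaults, exactly as with tf.app.flags.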

    dict_to_tf_example

    SETS = ['train', 'val', 'trainval', 'test']
    YEARS = ['VOC2007', 'VOC2012', 'merged']
    
    def dict_to_tf_example(data,
                           dataset_directory,
                           label_map_dict,
                           ignore_difficult_instances=False,
                           image_subdirectory='JPEGImages'):
     
      img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
      full_path = os.path.join(dataset_directory, img_path)
      with tf.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()
      encoded_jpg_io = io.BytesIO(encoded_jpg)
      image = PIL.Image.open(encoded_jpg_io)
      if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
      key = hashlib.sha256(encoded_jpg).hexdigest()
    
      width = int(data['size']['width'])
      height = int(data['size']['height'])
    
      xmin = []
      ymin = []
      xmax = []
      ymax = []
      classes = []
      classes_text = []
      truncated = []
      poses = []
      difficult_obj = []
      for obj in data['object']:
        difficult = bool(int(obj['difficult']))
        if ignore_difficult_instances and difficult:
          continue
    
        difficult_obj.append(int(difficult))
    
        xmin.append(float(obj['bndbox']['xmin']) / width)
        ymin.append(float(obj['bndbox']['ymin']) / height)
        xmax.append(float(obj['bndbox']['xmax']) / width)
        ymax.append(float(obj['bndbox']['ymax']) / height)
        classes_text.append(obj['name'].encode('utf8'))
        classes.append(label_map_dict[obj['name']])
        truncated.append(int(obj['truncated']))
        poses.append(obj['pose'].encode('utf8'))
    
      example = tf.train.Example(features=tf.train.Features(feature={
          'image/height': dataset_util.int64_feature(height),
          'image/width': dataset_util.int64_feature(width),
          'image/filename': dataset_util.bytes_feature(
              data['filename'].encode('utf8')),
          'image/source_id': dataset_util.bytes_feature(
              data['filename'].encode('utf8')),
          'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
          'image/encoded': dataset_util.bytes_feature(encoded_jpg),
          'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
          'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
          'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
          'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
          'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
          'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
          'image/object/class/label': dataset_util.int64_list_feature(classes),
          'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
          'image/object/truncated': dataset_util.int64_list_feature(truncated),
          'image/object/view': dataset_util.bytes_list_feature(poses),
      }))
      return example
    

    This section defines a function, dict_to_tf_example, which converts one PASCAL VOC XML annotation into a tf.Example.
    Its parameters are:

    • data: the annotation information for one image. In PASCAL VOC, each image's annotations live in a matching XML file; in main, data is obtained by loading that XML into a nested dict with dataset_util.recursive_parse_xml_to_dict;
    • dataset_directory: the root directory of the PASCAL VOC dataset;
    • label_map_dict: maps each class name to an integer id; it is loaded from the label map file at the default path;
    • ignore_difficult_instances: whether to skip objects marked "difficult" in the dataset. The default is fine;
    • image_subdirectory: the subfolder of the PASCAL dataset that holds the images. The default is also fine.

    After assembling the image's absolute path (full_path), the script reads the raw JPEG bytes with GFile and opens them with PIL, just to verify that the file really is a JPEG; it then computes the SHA-256 digest of the bytes as a key.
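The path assembly and the SHA-256 key can be illustrated without TensorFlow; the directory and file names below are made up for the example:

```python
import hashlib
import os

# Hypothetical values standing in for FLAGS.data_dir and the parsed XML dict.
dataset_directory = '/home/user/VOCdevkit'
data = {'folder': 'VOC2012', 'filename': '2007_000027.jpg'}
image_subdirectory = 'JPEGImages'

# The same joins as in dict_to_tf_example.
img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
full_path = os.path.join(dataset_directory, img_path)
print(full_path)  # /home/user/VOCdevkit/VOC2012/JPEGImages/2007_000027.jpg

# The key is simply the SHA-256 hex digest of the raw JPEG bytes.
encoded_jpg = b'\xff\xd8\xff'  # stand-in for the real file contents
key = hashlib.sha256(encoded_jpg).hexdigest()
print(len(key))  # 64 hex characters
```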

    Next, the bounding-box coordinates carried over in data are normalized (x/width, y/height) and appended to the lists. Credit where due: dataset_util.recursive_parse_xml_to_dict, which goes from XML to dict, makes this very convenient.
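A rough stdlib approximation of what dataset_util.recursive_parse_xml_to_dict does (the real helper works on lxml elements; this simplified sketch keeps only the behavior the example needs: leaves become text, and repeated 'object' tags become a list):

```python
import xml.etree.ElementTree as ET

def parse_xml_to_dict(xml):
    """Simplified sketch of recursive_parse_xml_to_dict."""
    if len(xml) == 0:                      # leaf node: return its text
        return {xml.tag: xml.text}
    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:                              # an image may contain several objects
            result.setdefault('object', []).append(child_result['object'])
    return {xml.tag: result}

voc_xml = """
<annotation>
  <folder>VOC2012</folder>
  <filename>2007_000027.jpg</filename>
  <size><width>486</width><height>500</height><depth>3</depth></size>
  <object>
    <name>person</name>
    <bndbox><xmin>174</xmin><ymin>101</ymin><xmax>349</xmax><ymax>351</ymax></bndbox>
  </object>
</annotation>
"""
data = parse_xml_to_dict(ET.fromstring(voc_xml))['annotation']
# Normalize exactly as dict_to_tf_example does.
xmin = float(data['object'][0]['bndbox']['xmin']) / int(data['size']['width'])
print(xmin)
```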

    Finally it creates a tf.train.Example instance, example, packs all the collected fields into it, and returns it.

    def main(_):
      if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))
      if FLAGS.year not in YEARS:
        raise ValueError('year must be in : {}'.format(YEARS))
    
      data_dir = FLAGS.data_dir
      years = ['VOC2007', 'VOC2012']
      if FLAGS.year != 'merged':
        years = [FLAGS.year]
    
      writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    
      label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
    
      for year in years:
        logging.info('Reading from PASCAL %s dataset.', year)
        examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                                     'aeroplane_' + FLAGS.set + '.txt')
        annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)
        examples_list = dataset_util.read_examples_list(examples_path)
        for idx, example in enumerate(examples_list):
          if idx % 100 == 0:
            logging.info('On image %d of %d', idx, len(examples_list))
          path = os.path.join(annotations_dir, example + '.xml')
          with tf.gfile.GFile(path, 'r') as fid:
            xml_str = fid.read()
          xml = etree.fromstring(xml_str)
          data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
    
          tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                          FLAGS.ignore_difficult_instances)
          writer.write(tf_example.SerializeToString())
    
      writer.close()
    
    
    if __name__ == '__main__':
      tf.app.run()
    

    main validates the flags, walks the example list for each requested year, converts each annotation with dict_to_tf_example, and serializes every example into the output TFRecord file.
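What does the resulting file actually contain? A TFRecord file is a sequence of framed records: a little-endian uint64 length, a 4-byte CRC of the length, the serialized bytes, and a 4-byte CRC of the data. A toy round-trip of just the framing (the CRC fields are zeroed and not verified here; real readers and writers use masked CRC32C):

```python
import io
import struct

def write_record(stream, data):
    # Framing only: length (uint64 LE), 4-byte length CRC, payload, 4-byte data CRC.
    # Real TFRecord writers fill the CRC fields with masked CRC32C checksums.
    stream.write(struct.pack('<Q', len(data)))
    stream.write(b'\x00' * 4)
    stream.write(data)
    stream.write(b'\x00' * 4)

def read_records(stream):
    while True:
        header = stream.read(8)
        if not header:
            return
        (length,) = struct.unpack('<Q', header)
        stream.read(4)                 # skip length CRC
        yield stream.read(length)
        stream.read(4)                 # skip data CRC

buf = io.BytesIO()
for payload in (b'first example', b'second example'):
    write_record(buf, payload)
buf.seek(0)
print(list(read_records(buf)))  # [b'first example', b'second example']
```

In the real script each payload is tf_example.SerializeToString(), i.e. a serialized tf.train.Example protobuf.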

        Original post: https://www.haomeiwen.com/subject/rixyuxtx.html