object_detection API Source Code Reading Notes (16 - via c


Author: yanghedada | Published: 2018-10-31 22:22

    This walkthrough of the source starts from train.py; eval.py comes afterwards.

    train.py

    In train.py, the config file is split into three parts:

      model_config = configs['model']
      train_config = configs['train_config']
      input_config = configs['train_input_config']
    

    model_config holds the configuration used to build the model:

      model_fn = functools.partial(
          model_builder.build,
          model_config=model_config,
          is_training=True)
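functools.partial does not call model_builder.build here; it only binds model_config and is_training so the builder can be invoked later. A minimal self-contained sketch of that deferral, using a hypothetical stand-in for build that records when it is called (the string 'faster_rcnn' stands in for the real config proto):

```python
import functools

# Hypothetical stand-in for model_builder.build: it records each call so we
# can see exactly when "building" happens.
build_calls = []


def build(model_config, is_training):
    build_calls.append((model_config, is_training))
    return {'model_config': model_config, 'is_training': is_training}


# Same pattern as train.py: bind the arguments now, build later.
model_fn = functools.partial(build, model_config='faster_rcnn', is_training=True)
calls_before = len(build_calls)  # 0: nothing has been built yet

model = model_fn()  # the builder actually runs here
calls_after = len(build_calls)  # 1
```

Passing a zero-argument callable instead of a finished model lets the caller decide when the model is created.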
    

    In model_builder.py, build() picks the kind of model (the meta-architecture):

    def build(model_config, is_training):
      """Builds a DetectionModel based on the model config.
    
      Args:
        model_config: A model.proto object containing the config for the desired
          DetectionModel.
        is_training: True if this model is being built for training purposes.
    
      Returns:
        DetectionModel based on the config.
    
      Raises:
        ValueError: On invalid meta architecture or model.
      """
      if not isinstance(model_config, model_pb2.DetectionModel):
        raise ValueError('model_config not of type model_pb2.DetectionModel.')
      meta_architecture = model_config.WhichOneof('model')
      if meta_architecture == 'ssd':
        return _build_ssd_model(model_config.ssd, is_training)
      if meta_architecture == 'faster_rcnn':
        return _build_faster_rcnn_model(model_config.faster_rcnn, is_training)
      raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
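build() is a dispatch on the protobuf oneof named model: WhichOneof('model') returns the name of whichever field is set, or None. The sketch below replaces the if-chain with a registry dict, keeping the same behaviour while making a new meta-architecture a one-line addition. It uses plain dicts in place of model_pb2.DetectionModel, and which_oneof plus the two builder functions are hypothetical stubs.

```python
# Hypothetical stand-ins for the real builders in model_builder.py.
def _build_ssd_model(ssd_config, is_training):
    return ('ssd', ssd_config, is_training)


def _build_faster_rcnn_model(faster_rcnn_config, is_training):
    return ('faster_rcnn', faster_rcnn_config, is_training)


# Registry: oneof field name -> builder function.
_META_ARCH_BUILDERS = {
    'ssd': _build_ssd_model,
    'faster_rcnn': _build_faster_rcnn_model,
}


def which_oneof(model_config):
    """Mimic WhichOneof('model') on a dict holding at most one key."""
    keys = list(model_config)
    if len(keys) > 1:
        raise ValueError('a oneof can hold at most one field')
    return keys[0] if keys else None


def build(model_config, is_training):
    meta_architecture = which_oneof(model_config)
    builder = _META_ARCH_BUILDERS.get(meta_architecture)
    if builder is None:
        raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
    return builder(model_config[meta_architecture], is_training)
```

For example, build({'faster_rcnn': {'num_classes': 90}}, True) routes to the Faster R-CNN stub, while an empty config raises ValueError just like the original.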
    
    

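    If you choose faster_rcnn, _build_faster_rcnn_model in model_builder.py reads every parameter it needs to construct the Faster R-CNN model from model_config.faster_rcnn.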

    If you are interested, protos/model_pb2.py contains many of model_config's default values.

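    At this point the model-building function is ready. Note that model_fn is only a functools.partial: the model itself is constructed when model_fn() is later called in trainer.py.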

    Back in train.py:

      train_config = configs['train_config']
    

    This turns out to be the configuration for trainer.py; it is consumed by the train function in trainer.py.

    The defaults declared in protos/train_pb2.py are as follows:

    Field                         No.  Type     Label     Default
    batch_size                      1  uint32   optional  32
    data_augmentation_options       2  message  repeated  []
    sync_replicas                   3  bool     optional  False
    keep_checkpoint_every_n_hours   4  uint32   optional  1000
    optimizer                       5  message  optional  None
    gradient_clipping_by_norm       6  float    optional  0
    fine_tune_checkpoint            7  string   optional  ""
    from_detection_checkpoint       8  bool     optional  False
    num_steps                       9  uint32   optional  0
    startup_delay_steps            10  float    optional  15
    bias_grad_multiplier           11  float    optional  0
    freeze_variables               12  string   repeated  []
    replicas_to_aggregate          13  int32    optional  1
    batch_queue_capacity           14  int32    optional  150
    num_batch_queue_threads        15  int32    optional  8
    prefetch_queue_capacity        16  int32    optional  5
    merge_multiple_label_boxes     17  bool     optional  False
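gradient_clipping_by_norm defaults to 0. My understanding is that trainer.py only clips when this value is positive, and then clips each gradient tensor to that L2 norm separately (not by the global norm across all tensors). The pure-Python sketch below models that assumed behaviour; gradients are flat lists of floats and clip_gradients_by_norm is a hypothetical name.

```python
import math


def clip_gradients_by_norm(grads, max_norm):
    """Clip each gradient (a flat list of floats) to L2 norm <= max_norm.

    Assumed semantics: max_norm <= 0 (the default is 0) disables clipping.
    """
    if max_norm <= 0:
        return [list(g) for g in grads]
    clipped = []
    for g in grads:
        norm = math.sqrt(sum(x * x for x in g))
        # Rescale only gradients whose norm exceeds the threshold.
        scale = max_norm / norm if norm > max_norm else 1.0
        clipped.append([x * scale for x in g])
    return clipped
```

With max_norm = 1.0, a gradient [3.0, 4.0] (norm 5) is scaled down to about [0.6, 0.8], while [0.3, 0.4] (norm 0.5) passes through unchanged.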
    

    Next, input_config:

      input_config = configs['train_input_config']
    

    It is consumed in builders/input_reader_builder.py.

    The defaults declared in protos/input_reader_pb2.py are:

    Field                   No.  Type     Label     Default
    label_map_path            1  string   optional  ""
    shuffle                   2  bool     optional  True
    queue_capacity            3  uint32   optional  2000
    min_after_dequeue         4  uint32   optional  1000
    num_epochs                5  uint32   optional  0
    num_readers               6  uint32   optional  8
    load_instance_masks       7  bool     optional  False
    tf_record_input_reader    8  message  optional  None
    external_input_reader     9  message  optional  None
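Several of these fields (shuffle, queue_capacity, min_after_dequeue, num_epochs) describe a shuffling input queue. The sketch below is a conceptual pure-Python model of such a queue, not TensorFlow's implementation: it only starts emitting random elements once more than min_after_dequeue records are buffered, and it treats num_epochs = 0 as "read the data indefinitely", which is my understanding of that field. shuffled_records is a hypothetical name.

```python
import itertools
import random


def shuffled_records(records, shuffle=True, queue_capacity=2000,
                     min_after_dequeue=1000, num_epochs=0, seed=None):
    """Conceptual model of a random-shuffle input queue (not TF's code).

    records: a finite list of examples making up one epoch.
    """
    if not 0 <= min_after_dequeue < queue_capacity:
        raise ValueError('need 0 <= min_after_dequeue < queue_capacity')
    # Assumption: num_epochs == 0 means "repeat the data forever".
    if num_epochs:
        source = itertools.chain.from_iterable(itertools.repeat(records, num_epochs))
    else:
        source = itertools.cycle(records)
    if not shuffle:
        yield from source
        return
    rng = random.Random(seed)
    buffer = []
    for record in source:
        buffer.append(record)
        # Dequeue a random element only while more than min_after_dequeue
        # items are buffered, so the buffer never exceeds queue_capacity.
        if len(buffer) > min_after_dequeue:
            yield buffer.pop(rng.randrange(len(buffer)))
    # Finite number of epochs: drain what is left, still in random order.
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))
```

A larger min_after_dequeue mixes records from further apart in the input, at the cost of more memory and a longer fill-up before the first example comes out.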
    

    eval.py

    eval.py likewise splits the config into three parts:

      model_config = configs['model']
      eval_config = configs['eval_config']
      if FLAGS.eval_training_data:
        input_config = configs['train_input_config']
      else:
        input_config = configs['eval_input_config']
    

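    Compared with train.py, the new part here is

      eval_config = configs['eval_config']
    

    a newly added config that controls how evaluation is run and how its metrics are computed.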

    evaluator.py consumes the parameters from this config.

    At the top of that file, the supported evaluation metrics are mapped to evaluator classes:

    EVAL_METRICS_CLASS_DICT = {
        'pascal_voc_metrics':
            object_detection_evaluation.PascalDetectionEvaluator,
        'weighted_pascal_voc_metrics':
            object_detection_evaluation.WeightedPascalDetectionEvaluator,
        'open_images_metrics':
            object_detection_evaluation.OpenImagesDetectionEvaluator
    }
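The metrics_set field of eval_config (default "pascal_voc_metrics") is the key used to look an evaluator class up in this dict. A stand-alone sketch of that lookup, with stub classes in place of the real object_detection_evaluation evaluators; build_evaluators and the stub class names are hypothetical, and raising ValueError on an unknown key is my assumption about how evaluator.py handles a bad metrics_set.

```python
# Stub evaluator classes standing in for the ones in object_detection_evaluation.
class StubEvaluator:
    def __init__(self, categories):
        self.categories = categories


class PascalStub(StubEvaluator):
    pass


class WeightedPascalStub(StubEvaluator):
    pass


class OpenImagesStub(StubEvaluator):
    pass


# Same shape as EVAL_METRICS_CLASS_DICT: metrics_set string -> evaluator class.
EVAL_METRICS_CLASS_DICT = {
    'pascal_voc_metrics': PascalStub,
    'weighted_pascal_voc_metrics': WeightedPascalStub,
    'open_images_metrics': OpenImagesStub,
}


def build_evaluators(metrics_set, categories):
    """Instantiate the evaluator selected by eval_config.metrics_set (sketch)."""
    if metrics_set not in EVAL_METRICS_CLASS_DICT:
        raise ValueError('Metric not found: {}'.format(metrics_set))
    return [EVAL_METRICS_CLASS_DICT[metrics_set](categories=categories)]
```

Keeping the mapping in one dict means supporting a new metric only requires registering its class under a new metrics_set name.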
    

    The defaults for eval_config declared in protos/eval_pb2.py:

    Field                     No.  Type     Label     Default
    num_visualizations          1  uint32   optional  10
    num_examples                2  uint32   optional  5000
    eval_interval_secs          3  uint32   optional  300
    max_evals                   4  uint32   optional  0
    save_graph                  5  bool     optional  False
    visualization_export_dir    6  string   optional  ""
    eval_master                 7  string   optional  ""
    metrics_set                 8  string   optional  "pascal_voc_metrics"
    export_path                 9  string   optional  ""
    ignore_groundtruth         10  bool     optional  False
    use_moving_averages        11  bool     optional  False
    eval_instance_masks        12  bool     optional  False
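eval_interval_secs (default 300) and max_evals (default 0) control how often and how many times evaluation runs. My reading is that max_evals = 0 means "no limit", and the sketch below assumes that. It is also simplified: it sleeps a fixed interval between runs rather than waiting for a new checkpoint. repeated_evaluation is a hypothetical helper, and sleep is injectable so the loop can be exercised without waiting.

```python
import time


def repeated_evaluation(evaluate_once, eval_interval_secs=300, max_evals=0,
                        sleep=time.sleep):
    """Yield one result per evaluation run (hypothetical helper).

    Assumes max_evals == 0 means "no limit", and simply sleeps a fixed
    eval_interval_secs between runs instead of waiting for a new checkpoint.
    """
    runs = 0
    while True:
        yield evaluate_once()
        runs += 1
        if max_evals and runs >= max_evals:
            return
        sleep(eval_interval_secs)
```

With max_evals = 3 the generator yields three results and sleeps twice; with the default of 0 it keeps yielding until the caller stops consuming it.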
    

    Every one of these default hyperparameter values can be overridden in the config file.
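In the config file you only write the fields you want to change; every field you leave out keeps its declared default. A dataclass-based sketch of that merge, where TrainConfigSketch and apply_overrides are hypothetical and only a handful of TrainConfig fields are modelled. It also rejects misspelled field names, much as the protobuf text-format parser refuses fields a message does not define.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class TrainConfigSketch:
    """A few TrainConfig fields with the defaults listed in train_pb2.py."""
    batch_size: int = 32
    num_steps: int = 0
    gradient_clipping_by_norm: float = 0.0
    fine_tune_checkpoint: str = ''
    from_detection_checkpoint: bool = False
    startup_delay_steps: float = 15.0


def apply_overrides(defaults, overrides):
    """Return a copy of `defaults` with the fields named in `overrides` replaced.

    Unknown field names are rejected instead of being silently ignored.
    """
    known = {f.name for f in fields(defaults)}
    unknown = sorted(set(overrides) - known)
    if unknown:
        raise ValueError('Unknown TrainConfig fields: {}'.format(unknown))
    return replace(defaults, **overrides)


# Roughly what `train_config { batch_size: 24 num_steps: 200000 }` expresses.
config = apply_overrides(TrainConfigSketch(), {'batch_size': 24, 'num_steps': 200000})
```

Here config.batch_size becomes 24 and config.num_steps 200000, while fields such as startup_delay_steps keep their defaults.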


          Article link: https://www.haomeiwen.com/subject/dwpztqtx.html