美文网首页
mmdetection组件构成与注册表分析

mmdetection组件构成与注册表分析

作者: 寒夏凉秋 | 来源:发表于2020-02-18 17:55 被阅读0次

    mmdetection 使用模块化设计,将一般的目标检测算法分成了几个不同的模块,在使用时只需要在配置文件中声明各个模块使用的组件名称和参数,就可以像搭建积木一样搭建一个完整的目标检测模型;

    基本组件

    mmdetection的组件大多数以类的形式定义:

    • BACKBONES 对应目标检测模型的主干网络,用以对图片进行特征抽取.如常用的Resnet,ResNeXt,HRNet等.
    • NECKS 对主干网络产生的特征图做一些特定的处理,最常见的就是fpn多尺度抽取信息.现有(FPN,BFP,HRFPN等)
    • Heads 目标检测的头部,包含了目标检测的主要算法逻辑,包括bbox的产生,回归target的计算,loss的计算等
    • LOSS 损失函数的定义
    • DETECTOR 前面所介绍的组件搭建而成的一个整体,通过加载detector来运行整体算法
    • PIPELINES 数据增强管道类.定义了数据预处理和后处理部分

    mmdetection中提供了类似注册表的实现方式,对各个组件进行注册和使用:
    首先我们来看Registry类的定义:
    mmdet/utils/registry.py

    class Registry(object):
        #初始化name是什么组件,组件里面是一个dict,保存name跟它的具体类
        def __init__(self, name):
            self._name = name
            self._module_dict = dict()
    
        def __repr__(self):
            format_str = self.__class__.__name__ + '(name={}, items={})'.format(
                self._name, list(self._module_dict.keys()))
            return format_str
    
        @property
        def name(self):
            return self._name
    
        @property
        def module_dict(self):
            return self._module_dict
    
        def get(self, key):
            return self._module_dict.get(key, None)
    
        #把组件类与类名注册到注册表中,方便从config文件构建类
        def _register_module(self, module_class):
            """Register a module.
    
            Args:
                module (:obj:`nn.Module`): Module to be registered.
            """
            if not inspect.isclass(module_class):
                raise TypeError('module must be a class, but got {}'.format(
                    type(module_class)))
            module_name = module_class.__name__
            if module_name in self._module_dict:
                raise KeyError('{} is already registered in {}'.format(
                    module_name, self.name))
            self._module_dict[module_name] = module_class
    
        def register_module(self, cls):
            self._register_module(cls)
            return cls
    

    我们看到Registry类其实底层保存一个dict,用于保存组件名字跟具体的类.方便从注册表中找到相应的类进行初始化.接着,定义了全局注册表:
    mmdet/models/registry.py

    BACKBONES = Registry('backbone')
    NECKS = Registry('neck')
    ROI_EXTRACTORS = Registry('roi_extractor')
    SHARED_HEADS = Registry('shared_head')
    HEADS = Registry('head')
    LOSSES = Registry('loss')
    DETECTORS = Registry('detector')
    

    我们来看,注册表如何使用:
    如果我们自定义了一个resnet的backbone类,我们将这样使用Registry类的register_module装饰函数,将resnet注册到BACKBONES注册表中;

    @BACKBONES.register_module
    class ResNet(nn.Module):
    

    那么我们该如何从config中构建起一个类呢:

    #mmdet/models/builder.py
    def build(cfg, registry, default_args=None):
        if isinstance(cfg, list):
            modules = [
                build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
            ]
            return nn.Sequential(*modules)
        else:
            return build_from_cfg(cfg, registry, default_args)
    
    def build_backbone(cfg):
        return build(cfg, BACKBONES)
    
    ##mmdet/utils/registry.py
    def build_from_cfg(cfg, registry, default_args=None):
        """Build a module from config dict.
    
        Args:
            cfg (dict): Config dict. It should at least contain the key "type".
            registry (:obj:`Registry`): The registry to search the type from.
            default_args (dict, optional): Default initialization arguments.
    
        Returns:
            obj: The constructed object.
        """
        #type即注册表中类名字,代表了要从注册表中根据type的name来获得类
        assert isinstance(cfg, dict) and 'type' in cfg
        assert isinstance(default_args, dict) or default_args is None
        args = cfg.copy()
        obj_type = args.pop('type')
        if mmcv.is_str(obj_type):
            obj_cls = registry.get(obj_type)
            if obj_cls is None:
                raise KeyError('{} is not in the {} registry'.format(
                    obj_type, registry.name))
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError('type must be a str or valid type, but got {}'.format(
                type(obj_type)))
        if default_args is not None:
            for name, value in default_args.items():
                args.setdefault(name, value)
        ##进行类的实例化,并传入config中的参数
        return obj_cls(**args)
    

    build_from_cfg函数的作用是,根据config文件中的type与传入的注册表来获取需要实例化的具体类,然后再将config中的参数传入类初始化函数中,得到一个实例化的组件类.

    resnet为例,整体流程如下所示:

    • (1)resnet类编写完成后,用@BACKBONES.register_module装饰器将自身注册到BACKBONES注册表中.

    • (2)在config中定义backbone,并指明了具体参数

    #config/faster_rcnn_r50_fpn_1x.py
    backbone=dict(
            type='ResNet',
            depth=50,
            num_stages=4,
            out_indices=(0, 1, 2, 3),
            frozen_stages=1,
            style='pytorch')
    
    • (3)通过build_from_cfg()函数,传入的分别是backbone这个dict和BACKBONES注册表类
    • (4)通过'type'为ResNet找到resnet的类,并初始化参数depth,num_stages,out_indices.

    mmdetection这样通过注册表的方式实现了数据与实现的分离;能更好地对组件进行抽象.

    相关文章

      网友评论

          本文标题:mmdetection组件构成与注册表分析

          本文链接:https://www.haomeiwen.com/subject/ztqnfhtx.html