美文网首页
mmdetection源码阅读笔记(0)--创建模型

mmdetection源码阅读笔记(0)--创建模型

作者: EwanRenton | 来源:发表于2019-05-04 17:39 被阅读0次

    之前做天池比赛用mmdetection取得了还不错的成绩,就想仔细读读mmdetection的源码,了解下具体实现。

    这个系列,准备按照目标检测和实例分割的pipeline来写。


    训练脚本

    官方提供了分布式训练,并且推荐使用分布式训练,即使在单机器上dist_train.sh

    #!/usr/bin/env bash
    
    PYTHON=${PYTHON:-"python3"}
    
    $PYTHON -m torch.distributed.launch --nproc_per_node=$2 $(dirname "$0")/train.py $1 --launcher pytorch ${@:3}
    

    该脚本主要使用了torch.distributed.launch辅助启动工具,这个工具可以辅助在每个节点上启动多个进程process,支持Python2 和 Python3.
    更多关于分布式训练的细节可以参考pytorch 分布式训练 distributed parallel 笔记


    创建模型

    train.pymain()函数,先做了一些config文件,work_dir以及log的操作,之后调用了build_detector()来创建模型。

    build_detector()

    build_detector()定义在mmdet/models/builder.py中。
    下面是主要用到的几个函数。
    mmdet/models/builder.py

    from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
    
    def build_detector(cfg, train_cfg=None, test_cfg=None):
        return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
        
    
    

    build_detector()中有一个DETECTORS这是一个注册器,里面保存了所有支持的detector。具体的实现方式和Python装饰器有点像。
    下面以cascade_rcnn为例,看下是怎么进行注册过来的。

    1. 首先在mmdet/models/__init__.py里面from .detectors import *
    2. mmdet/models/detectors/__init__.py里面from .cascade_rcnn import CascadeRCNN
    3. mmdet/models/detectors/cascade_rcnn.py
    from ..registry import DETECTORS
    @DETECTORS.register_module
    class CascadeRCNN(BaseDetector, RPNTestMixin):
        other codes
    

    @DETECTORS.register_module这一行代码,将CascadeRCNN注册到了DETECTORS中。
    这里简单的说下@的用法,Python当解释器读到@的这样的修饰符之后,会先解析@后的内容,直接就把@下一行的函数或者类作为@后边的函数的参数,然后将返回值赋值给下一行修饰的函数对象。
    例如:

    def a():
        print("func a")
    def b():
        print("func b")
    @a
    @b
    def c():
        print("func c")
    

    python会按照自下而上的顺序把各自的函数结果作为下一个函数(上面的函数)的输入,也就是a(b(c()))
    回到我们的DETECTORS,也就是上面的操作将CascadeRCNN传给了DETECTORS.register_module
    mmdet/models/registry.py

    class Registry(object):
    
        def __init__(self, name):
            self._name = name
            self._module_dict = dict()
    
        def _register_module(self, module_class):
            """Register a module.
    
            Args:
                module (:obj:`nn.Module`): Module to be registered.
            """
            if not issubclass(module_class, nn.Module):
                raise TypeError(
                    'module must be a child of nn.Module, but got {}'.format(
                        module_class))
            module_name = module_class.__name__
            if module_name in self._module_dict:
                raise KeyError('{} is already registered in {}'.format(
                    module_name, self.name))
            self._module_dict[module_name] = module_class
    
        def register_module(self, cls):
            self._register_module(cls)
            return cls
    BACKBONES = Registry('backbone')
    NECKS = Registry('neck')
    ROI_EXTRACTORS = Registry('roi_extractor')
    HEADS = Registry('head')
    DETECTORS = Registry('detector')
    

    注册的模型被保存到了,self._module_dict中。
    再回到builder.py
    mmdet/models/builder.py

    def build(cfg, registry, default_args=None):
        if isinstance(cfg, list):
            modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
            return nn.Sequential(*modules)
        else:
            return _build_module(cfg, registry, default_args)
            
    def _build_module(cfg, registry, default_args):
        assert isinstance(cfg, dict) and 'type' in cfg
        assert isinstance(default_args, dict) or default_args is None
        args = cfg.copy()
        obj_type = args.pop('type')
        if mmcv.is_str(obj_type):
            if obj_type not in registry.module_dict:
                raise KeyError('{} is not in the {} registry'.format(
                    obj_type, registry.name))
            obj_type = registry.module_dict[obj_type]
        elif not isinstance(obj_type, type):
            raise TypeError('type must be a str or valid type, but got {}'.format(
                type(obj_type)))
        if default_args is not None:
            for name, value in default_args.items():
                args.setdefault(name, value)
        return obj_type(**args)
    

    build()中主要通过_build_module()registry.module_dict中实例化注册过的模型。


    最后

    这篇主要讲了mmdetection中的创建模型,下一篇准备以Cascade Rcnn为例看下网络的具体搭建。

    相关文章

      网友评论

          本文标题:mmdetection源码阅读笔记(0)--创建模型

          本文链接:https://www.haomeiwen.com/subject/wivonqtx.html