mmd_hook
MMDetection3D:注册模型
train.py的开头中,已经开始注册必要的模块了
from mmcv import Config
from mmdet import __version__
from mmdet.datasets import build_dataset
from mmdet.apis import (train_detector, init_dist, get_root_logger,
set_random_seed)
from mmdet.models import build_detector
查看mmdet文件夹下的__init__.py,以及datasets、apis、models下的__init__.py文件。以models为例,可以发现 mmdet.models.__init__.py 的内容如下:
from .backbones import * # noqa: F401,F403 from .necks import * # noqa: F401,F403 from .roi_extractors import * # noqa: F401,F403 from .anchor_heads import * # noqa: F401,F403 from .shared_heads import * # noqa: F401,F403 from .bbox_heads import * # noqa: F401,F403 from .mask_heads import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 from .detectors import * # noqa: F401,F403 from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
LOSSES, DETECTORS) from .builder import (build_backbone, build_neck, build_roi_extractor,
build_shared_head, build_head, build_loss,
build_detector)
__all__ = [ 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES', 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor', 'build_shared_head', 'build_head', 'build_loss', 'build_detector' ]
这个文件的第一行导入了backbones下的__init__.py,看一下里面的内容: mmdet.models.backbones.__init__.py
from .resnet import ResNet, make_res_layer
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .hrnet import HRNet
__all__ = ['ResNet', 'make_res_layer', 'ResNeXt', 'SSDVGG', 'HRNet']
这里又导入了resnet、resnext等几个卷积神经网络,那么以resnet为例,看一下里面都有什么。mmdet.models.backbones.resnet.py 第13、14行:
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
其中又从registry中导入了BACKBONES,那么再来看看registry和它的BACKBONES。mmdet.models.registry.py:
from mmdet.utils import Registry
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
那么这个Registry又是何方神圣?意欲何为?看看去。mmdet.utils.__init__.py:
from .registry import Registry, build_from_cfg
__all__ = ['Registry', 'build_from_cfg']
顺藤摸瓜,找到registry.py mmdet.utils.registry.py 代码稍长,分两段看吧。只看主要代码,能帮助理解其机制的代码,删除部分不影响理解的代码,全文都是。 Registry:
import inspect
import mmcv
class Registry(object):
    """Lookup table that maps class names to classes.

    Classes are added with :meth:`register_module`, which can be applied as
    a class decorator, and are stored under their ``__name__``.

    Args:
        name (str): Label of this registry; it is used in ``repr`` output
            and in the duplicate-registration error message.
    """

    def __init__(self, name):
        self._name = name
        self._module_dict = {}

    def __repr__(self):
        registered_names = list(self._module_dict.keys())
        prefix = self.__class__.__name__
        suffix = '(name={}, items={})'.format(self._name, registered_names)
        return prefix + suffix

    @property
    def name(self):
        """Label given to this registry at construction time."""
        return self._name

    @property
    def module_dict(self):
        """The underlying ``{class name: class}`` dictionary."""
        return self._module_dict

    def get(self, key):
        """Return the class registered as ``key``, or ``None`` if absent."""
        return self._module_dict.get(key)

    def _register_module(self, module_class):
        """Store ``module_class`` under its ``__name__``.

        Raises:
            TypeError: If ``module_class`` is not a class.
            KeyError: If a class with the same name is already registered.
        """
        if not inspect.isclass(module_class):
            actual_type = type(module_class)
            raise TypeError(
                'module must be a class, but got {}'.format(actual_type))

        key = module_class.__name__
        table = self._module_dict
        if key in table:
            raise KeyError(
                '{} is already registered in {}'.format(key, self.name))
        table[key] = module_class

    def register_module(self, cls):
        """Register ``cls`` and hand it back unchanged (decorator-friendly)."""
        self._register_module(cls)
        return cls
这段代码呢,生成了一个字典,里面包含了模块名字,以后模块都要挂在这个名字下。此时我们反过头来再看registry.py中的代码,其实是生成了各个主要部分,并向外提供了接口。 这段代码到现在,暂时没有了下文,我们再来看build_from_cfg函数:
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
            The dict itself is not modified; a shallow copy is used.
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            They only fill in keys that ``cfg`` does not already provide.

    Returns:
        obj: The constructed object.

    Raises:
        KeyError: If "type" is a string that is not found in ``registry``.
        TypeError: If "type" is neither a string nor a class.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None

    args = cfg.copy()
    obj_type = args.pop('type')

    # Keep the requested name (obj_type) separate from the resolved class
    # (obj_cls) so that a failed lookup reports the name that was asked for
    # instead of the ``None`` returned by the registry.
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))

    if default_args is not None:
        for name, value in default_args.items():
            # Values given explicitly in cfg win over the defaults.
            args.setdefault(name, value)

    # The remaining config entries become the constructor's keyword args.
    return obj_cls(**args)
这段代码比较难懂,尤其是最后那行。我们来分析一波吧:既然难懂,就先来看它是在哪里被调用的。回到tools.train.py:
...
def parse_args():
...
parser.add_argument('config', help='train config file path')
...
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
...
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
...
因为它有个cfg参数,而build_detector也用到了cfg,可以算个线索(如果你使用IDE的话,可以看看build_from_cfg是被谁引用的,顺藤摸瓜,推荐)。再来看build_detector:它在文章第一个代码段train.py的最后一行被引入,位于mmdet.models里,而mmdet.models.__init__.py中,有
from .builder import (build_backbone, build_neck, build_roi_extractor,
build_shared_head, build_head, build_loss,
build_detector)
我们来看build_detector的具体内容吧。mmdet.models.builder.py:
from torch import nn
from mmdet.utils import build_from_cfg
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
LOSSES, DETECTORS)
def build(cfg, registry, default_args=None):
    """Build one module, or a sequence of modules, from config.

    Args:
        cfg (dict | list): A single config, or a list of configs.
        registry: Registry passed through to ``build_from_cfg``.
        default_args (dict, optional): Passed through to ``build_from_cfg``.

    Returns:
        The object built by ``build_from_cfg`` when ``cfg`` is not a list;
        otherwise an ``nn.Sequential`` wrapping one built object per entry.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    built = [build_from_cfg(item, registry, default_args) for item in cfg]
    return nn.Sequential(*built)
def build_backbone(cfg):
    """Build ``cfg`` by looking its type up in the ``BACKBONES`` registry."""
    return build(cfg, registry=BACKBONES)
def build_neck(cfg):
    """Build ``cfg`` by looking its type up in the ``NECKS`` registry."""
    return build(cfg, registry=NECKS)
def build_roi_extractor(cfg):
    """Build ``cfg`` by looking its type up in ``ROI_EXTRACTORS``."""
    return build(cfg, registry=ROI_EXTRACTORS)
def build_shared_head(cfg):
    """Build ``cfg`` by looking its type up in ``SHARED_HEADS``."""
    return build(cfg, registry=SHARED_HEADS)
def build_head(cfg):
    """Build ``cfg`` by looking its type up in the ``HEADS`` registry."""
    return build(cfg, registry=HEADS)
def build_loss(cfg):
    """Build ``cfg`` by looking its type up in the ``LOSSES`` registry."""
    return build(cfg, registry=LOSSES)
def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build a detector from ``cfg`` using the ``DETECTORS`` registry.

    ``train_cfg`` and ``test_cfg`` are forwarded as default arguments, so
    they are supplied to the detector unless ``cfg`` already sets them.
    """
    detector_defaults = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return build(cfg, DETECTORS, detector_defaults)
看到参数又传到build里,而cfg是个dict类型,所以又到了build_from_cfg,此刻我们来分析build_from_cfg:
def build_from_cfg(cfg, registry, default_args=None):
...
args = cfg.copy()
obj_type = args.pop('type')
...
return obj_type(**args)
再在你的配置文件里看到这个obj_type: configs.faster_rcnn_r50_fpn_1x.py:
model = dict(
type='FasterRCNN',
其实也就是执行了FasterRCNN(),那么,FasterRCNN又是从何而来呢?
答:在mmdet.models.__init__.py里,可以看到 from .detectors import * 这行代码,再来瞧瞧mmdet.models.detectors.__init__.py:
from .base import BaseDetector
from .single_stage import SingleStageDetector
from .two_stage import TwoStageDetector
from .rpn import RPN
from .fast_rcnn import FastRCNN
from .faster_rcnn import FasterRCNN
from .mask_rcnn import MaskRCNN
from .cascade_rcnn import CascadeRCNN
from .htc import HybridTaskCascade
from .retinanet import RetinaNet
from .fcos import FCOS
from .grid_rcnn import GridRCNN
from .mask_scoring_rcnn import MaskScoringRCNN
__all__ = [
'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN'
]
可以看到这里注册了一大堆的模型,取出faster_rcnn来看,在mmdet.models.detectors.faster_rcnn.py里:
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FasterRCNN(TwoStageDetector):
    """Detector registered in ``DETECTORS`` under the name ``FasterRCNN``.

    The constructor adds no logic of its own: every argument is forwarded
    by keyword to ``TwoStageDetector.__init__``. Compared with the parent
    signature, ``rpn_head``, ``bbox_roi_extractor``, ``bbox_head``,
    ``train_cfg`` and ``test_cfg`` are required here, and the mask-related
    arguments are not exposed.
    """

    def __init__(self,
                 backbone,
                 rpn_head,
                 bbox_roi_extractor,
                 bbox_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 shared_head=None,
                 pretrained=None):
        parent_kwargs = dict(
            backbone=backbone,
            neck=neck,
            shared_head=shared_head,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )
        super(FasterRCNN, self).__init__(**parent_kwargs)
看到这里以TwoStageDetector作为其父类,看TwoStageDetector:
import torch
import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):
    """Detector assembled from sub-configs via the ``builder`` helpers.

    Only ``backbone`` is mandatory. Every other component is built and
    attached as an attribute only when its config is not ``None``.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 shared_head=None,
                 rpn_head=None,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(TwoStageDetector, self).__init__()

        # Components are attached in a fixed order: backbone, neck,
        # shared head, RPN head, bbox branch, mask branch.
        self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)

        if shared_head is not None:
            self.shared_head = builder.build_shared_head(shared_head)

        if rpn_head is not None:
            self.rpn_head = builder.build_head(rpn_head)

        # The bbox RoI extractor is only built together with a bbox head.
        if bbox_head is not None:
            self.bbox_roi_extractor = builder.build_roi_extractor(
                bbox_roi_extractor)
            self.bbox_head = builder.build_head(bbox_head)

        if mask_head is not None:
            # Without a dedicated mask RoI extractor, the mask branch
            # reuses the bbox one and records that fact in a flag.
            self.share_roi_extractor = mask_roi_extractor is None
            if self.share_roi_extractor:
                self.mask_roi_extractor = self.bbox_roi_extractor
            else:
                self.mask_roi_extractor = builder.build_roi_extractor(
                    mask_roi_extractor)
            self.mask_head = builder.build_head(mask_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # NOTE(review): init_weights is not defined in this excerpt;
        # presumably it is inherited or defined further down - confirm.
        self.init_weights(pretrained=pretrained)

    @property
    def with_rpn(self):
        """bool: True when an ``rpn_head`` attribute exists and is not None."""
        return getattr(self, 'rpn_head', None) is not None
可以看到,在这里形成了整个模型。