mmd_hook
MMD3D:模型之registry.py和builder.py解读
1.引言:
本篇文章主要就是讲一下,搭建模型的思路,以及registry.py和builder.py中各个函数块的作用。
注:builder.py是在mmdet/models文件夹下,是用来创建BACKBONES、NECKS、ROI_EXTRACTORS、SHARED_HEADS、HEADS、LOSSES、DETECTORS的模型的。而关于build_dataset()(在mmdet/datasets/builder.py中),在后面讲到数据集的时候再来讲它。
在mmdet/utils文件夹下的registry.py为主要的实现过程,后面详细讲解。
先来看在mmdet/models文件夹下的registry.py,较简单,代码如下:
# -*- coding: utf-8 -*-
from mmdet.utils import Registry
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
#类的实例化,Registry是一个类,传入的是一个字符串。该字符串为Registry类的name属性值
举个例子:DETECTORS为注册表Registry的实例化对象,DETECTORS.name = 'detector',Registry类的定义在mmdet/utils/文件中。
所以,根据上面代码,我们就应该知道了,不止一个名为DETECTORS的注册表Registry,后面还会有名为NECKS、ROI_EXTRACTORS 、SHARED_HEADS 、HEADS 、LOSSES 的注册表,这些注册表下的_module_dict属性,则是用来存对应的相同类对象的,举个例子:比如DETECTORS的_module_dict下就有可能有:Faster R-CNN、Cascade R-CNN、FPN、HTC等常见的检测器,到这或许你就明白了注册表的作用咯。
而在mmdet/utils/Registry.py中,有一个类Registry的定义和一个方法:build_from_cfg()的实现。
build_from_cfg()方法的作用是从 congfig/py配置文件中获取字典数据,创建module(其实也就是一个class类),然后将这个module添加到之前创建的注册表Registry的属性_module_dict中(这是一个字典,key为类名,value为具体的类),返回值是一个实例化后的类对象。
所以,可以这样理解,从config/py配置文件中,将字典提取出来,然后为其映射成一个类,放进Registry对象的_module_dict属性中。(具体看下面的代码)
2.Registry.py文件
以下代码分三部分
2.1Part one:
inspect模块是针对模块,类,方法,功能等对象提供些有用的方法。例如可以帮助我们检查类的内容,检查方法的代码,提取和格式化方法的参数等。
# -*- coding: utf-8 -*-
import inspect
import mmcv
2.2Part two:
通过前面第一段的代码段,我们知道DETECTORS = Registry('detector')
detector是干什么的 ???
其实,DETECTORS = Registry('detector') 只是注册了一个对象名为DETECTORS ,属性name为detector的对象。然后用属性_module_dict 来保存config配置文件中的对应的字典数据所对应的class类(看第三部分代码)。请看如下类Registry的定义代码:
class Registry(object):
def __init__(self, name): #此处的self,是个对象(Object),是当前类的实例,name即为传进来的'detector'值
self._name = name
self._module_dict = dict() #定义的属性,是一个字典
def __repr__(self):
#返回一个可以用来表示对象的可打印字符串,可以理解为java中的toString()。
format_str = self.__class__.__name__ + '(name={}, items={})'.format(
self._name, list(self._module_dict.keys()))
return format_str
@property #把方法变成属性,通过self.name 就能获得name的值。
def name(self):
return self._name
#因为没有定义它的setter方法,所以是个只读属性,不能通过 self.name = newname进行修改。
@property
def module_dict(self):
#同上,通过self.module_dict可以获取属性_module_dict,也是只读的
return self._module_dict
def get(self, key):
#普通方法,获取字典中指定key的value,_module_dict是一个字典,然后就可以通过self.get(key),获取value值
return self._module_dict.get(key, None)
def _register_module(self, module_class):
#关键的一个方法,作用就是Register a module.
#在model文件夹下的py文件中,里面的class定义上面都会出现 @DETECTORS.register_module,意思就是将类当做形参,
#将类送入了方法register_module()中执行。@的具体用法看后面解释。
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isclass(module_class): #判断是否为类,是类的话,就为True,跳过判断
raise TypeError('module must be a class, but got {}'.format(
type(module_class)))
module_name = module_class.__name__ #获取类名
if module_name in self._module_dict: #看该类是否已经登记在属性_module_dict中
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class #在module中dict新增key和value。key为类名,value为类对象
def register_module(self, cls): #对上面的方法,修改了名字,添加了返回值,即返回类本身
self._register_module(cls)
return cls
note:
@的含义: Python当解释器读到@的这样的修饰符之后,会先解析@后的内容,直接就把@下一行的函数或者类作为@后边的函数的参数,然后将返回值赋值给下一行修饰的函数对象。 在网上看到一个这样的例子:
def a(x):
if x==2:
return 4
return 6
def b(x):
if x==1:
return 2
return 3
@a
@b
def c():
return 1
python会按照自下而上的顺序把各自的函数结果作为下一个函数(上面的函数)的形参输入,也就是a(b(c()))。
2.3 Part three:
以下我们通过配置文件cascade_rcnn_r50_fpn_1x.py
进行讲解
build 模型的过程。
在train中,最先执行Registry的是DETECTORS,传入的参数是配置文件中的model字典。
#在 train.py中
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
#在builder.py中
def build_detector(cfg, train_cfg=None, test_cfg=None):
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
所以,后面出现的参数cfg,指的就是配置文件中的model字典。下面是model字典的部分代码:
# model settings
model = dict(
type='CascadeRCNN',
num_stages=3,
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
我们继续往下看
先看build_from_cfg()方法的参数:
Args:
- cfg (dict): Config dict. It should at least contain the key “type”.这个cfg就是py配置文件中的字典。在py配置文件中,基本上dict都会有一个key为"type",当然也有不是的,不是的,这一步就不会执行,也就不会为他创建module。也就是这边创建成module的dict,都必须有key为"type"才可以创建(这里,我们主要讲的是注册表DETECTORS,所以此时cfg对应的是配置文件中的model字典,看上面截图)。举个例子:比如type='CascadeRCNN',后面我们会知道,这个value为"CascadeRCNN"的,其实就是models文件夹中某py文件中的类名,他们通过@DETECTORS.register_module,将类名当做形参,传入register_module。并保存下来。
- registry (:obj:Registry): The registry to search the type from.
- default_args (dict, optional): Default initialization arguments.
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
obj: The constructed object.
"""
assert isinstance(cfg, dict) and 'type' in cfg
assert isinstance(default_args, dict) or default_args is None #两个是断言,相当于判断,否的话抛出异常。
args = cfg.copy() #args相当于temp中间变量,是个字典。
obj_type = args.pop('type') #字典的pop作用:移除序列中key为‘type’的元素,并且返回该元素的值
if mmcv.is_str(obj_type):
obj_type = registry.get(obj_type) #获取obj_type的value。
#如果obj_type已经注册到注册表registry中,即在属性_module_dict中,则obj_type 不为None
if obj_type is None:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
elif not inspect.isclass(obj_type):
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():#items()返回字典的键值对用于遍历
args.setdefault(name, value)
#将default_args的键值对加入到args中,将模型和训练配置进行整合,然后送入类中返回
return obj_type(**args)
obj_type(**args),* *args是将字典unpack得到各个元素,分别与形参匹配送入函数中;看上面model的截图,所以这边,其实就是将除了’type’的所有字段,当做形参,送入了名为CascadeRCNN()的类中(type =' CascadeRCNN')。所以字典里的key就是类中的属性?继续看下面。
根据Cascade R-CNN的例子,我们在models/detectors找cascade_rcnn的py文件。
参考里面的参数时,直接打开对应的cascade_rcnn配置文件,在init中,里面的参数则
对应了配置文件中的字典名。下面两个截图分别是配置文件cascade_rcnn.py和model/detectors/cascade_rcnn.py中的类定义。
configs/cascade_rcnn.py:
# model settings
model = dict(
type='CascadeRCNN',
num_stages=3,
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
model/detectors/cascade_rcnn.py:
@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin):
def __init__(self,
num_stages,
backbone,
neck=None,
shared_head=None,
rpn_head=None,
bbox_roi_extractor=None,
bbox_head=None,
mask_roi_extractor=None,
mask_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
assert bbox_roi_extractor is not None
assert bbox_head is not None
super(CascadeRCNN, self).__init__()
注意的是,在py配置文件中,好多py文件中都有type = 'CascadeRCNN',所以有些参数和属性对不上很正常(毕竟已经设置为None了),因为这个参数可能是其他的cascade R-CNN里面的字典。
所以,我们在训练时,测试时,就要给出配置文件,配置文件可以不同,但相同type
detector等文件是相同的,毕竟已经将数据和实现完全的分离了。
注意:无论训练/检测,都会build DETECTORS;
2.4 builder.py文件
builder文件较为简单,因为train.py中,只出现了build_detector(),所以我们先记住里面的两个方法:build_detector和build()。
- build_detector:是创建一个detector,方法里调用了build()方法(所有的build_xx都是直接调用build方法,所以看懂这一个也就看懂所有了)。
- build():则是调用的Registry.py文件中的build_from_cfg()方法,这个方法我们已经在上面讲过了。
import:
# -*- coding: utf-8 -*-
from torch import nn
from mmdet.utils import build_from_cfg
#此处不会在执行registry而是直接进行sys.modules查询得到
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
LOSSES, DETECTORS)
#上面的registry是在models文件夹下,registry类的具体实现是在mmdet/utils文件夹下
只需要看一下build()的两个参数:cfg, registry
build_detector()在train.py中的调用,我们就可以知道,cfg是py配置文件中的字典, 以registry是DETECTORS为例,cfg就是model字典 (后面注册表为BACKBONES、NECKS等时,就是配置文件中的其他的字典了,不是model) 。
build()方法中,主干是一个判断结构,其实就是判断传进来的cfg是字典列表还是单独的字典,来分情况处理。(以注册表DETECTORS为例,是一个单独的字典)
- 字典列表的话:挨个调用build_from_cfg(),将其加到注册表的_module_dict中,然后再返回return nn.Sequential(*modules),这个地方的作用,有待博主继续研究一下下???
- 字典的话:直接调用build_from_cfg(),将其添加到注册表DETECTORS中(以DETECTORS为例)。
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
#build_from_cfg()返回值是一个带形参的类,返回时也就完成了实例化的过程。
]
#所以modules就是一个class类的列表
return nn.Sequential(*modules)
#nn.Sequential 一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数
else:
return build_from_cfg(cfg, registry, default_args) #Config dict
def build_detector(cfg, train_cfg=None, test_cfg=None):
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
#DETECTORS = Registry('detector'),创建一个名为DETECTORS的注册表Registry。
def build_backbone(cfg):
return build(cfg, BACKBONES)
def build_neck(cfg):
return build(cfg, NECKS)
def build_roi_extractor(cfg):
return build(cfg, ROI_EXTRACTORS)
def build_shared_head(cfg):
return build(cfg, SHARED_HEADS)
def build_head(cfg):
return build(cfg, HEADS)
def build_loss(cfg):
return build(cfg, LOSSES)
后面的几个build_XXXXX()的方法也就跟build_detector()相同咯。
还是以注册表DETECTORS为例,配置文件为cascade_rcnn_r50_fpn_1x.py来讲解:在model文件夹下的cascade_rcnn.py文件中,有类Cascade_RCNN()的定义,在配置文件中,对应的key被传入类中当做属性,这些属性被初始化的时候,调用对应的build_XXXXX(),由此创建它们对应的注册表。
再以NECK为例,调用build_neck(cfg);然后执行build(cfg, NECKS),这一步,形参用到NECKS,所以在Registry中,又多了一个名为NECKS的注册表了。然后将配置文件中,字典名为neck的,然后生成一个类(类名是neck字典中的type的值,该类在models/necks文件夹下),同时将该类添加到了注册表NECKS的_module_dict中。
#在model/detectors/cascade_rcnn.py中
if neck is not None:
self.neck = builder.build_neck(neck)
#再builder.py中
def build_neck(cfg):
return build(cfg, NECKS)
#在configs/cascade_rcnn_r50_fpn_1x.py中
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
到这,NECK的注册和数据读入,相信大家已经很清楚了,其他的注册表也是类似的。
3.总结:
搭建模型思路:
- 首先,创建一个名为DETECTORS的注册表Registry。这个注册表有属性name='detector',和属性_module_dict。_module_dict 是一个字典,专门用来存各个对象名和对应的对象。
- 其次,读取py配置文件,py配置文件是个字典,(字典里还有字典,这里面的字典,也是后面来创建模型的,道理是一样的)。根据key为'type'的字典,创建module,对于的value为其module名,然后再models文件夹下中,已经存在了这些module的类。将字典中的其他数据,作为形参,实例化这些类。并保存这些module到属性_module_dict中。
- 到这,配置文件的数据,里面的字典(含有type的字典)对应着一个类,type为类名,其他字段则为其属性(其他字段也可能是个字典,后面也有可能要再为它们搭建模型哦)。由此完成模型的搭建。
这是搭建模型的一个思路,虽然讲得篇幅很大,有点乱乱的感觉,但是看懂后,就会发现很简单。
mmdetection搭建模型用途:
mmdetection将配置文件中,字典名为:backbone、neck、roi_extractor、shared_head、head、loss、detector的字典,全部实例化成注册表(Registry),然后这些字典里的type,都被实例化成对应的类(module),并添加到注册表的属性_module_dict中,其他的字段,则为这个类的属性,由此完成模型的建立,实际上,就是将配置文件的字典数据保存到类(module)中,以便后面读取数据,加载数据。
Problem:
Importance:
总体讲的通俗易懂,但是有一点疑问。文中多次出现类似于"通过build实例化模块类,然后把实例存入model_dict"的表述。但是,结合我目前掌握的相关知识,我的理解是,模块类的注册是通过register_module方法,把模块类加入到model_dict,然后bulid方法根据cfg提供的type类名,从Register类中通过get方法从model_dict获取模块类,然后利用cfg提供的参数对模块类进行实例化