Registry
Registry
Registry 类可以提供一种完全相似的对外装饰函数来管理构建不同的组件,例如 backbones、head 和 necks 等等,Registry 类内部其实维护的是一个全局 key-value 对。通过 Registry 类,用户可以通过字符串方式实例化任何想要的模块。
Registry 类最大好处是:解耦性强、可扩展性强,代码更易理解。
回到 Registry 类本身,有如下几种用法:
# 0. 先构建一个全局的 CATS 注册器类
CATS = mmcv.Registry('cat')
# 通过装饰器方式作用在想要加入注册器的具体类中
#===============================================================
# 1. 不需要传入任何参数,此时默认实例化的配置字符串是 str (类名)
@CATS.register_module()
class BritishShorthair:
pass
# 类实例化
CATS.get('BritishShorthair')(**args)
#==============================================================
# 2.传入指定 str,实例化时候只需要传入对应相同 str 即可
@CATS.register_module(name='Siamese')
class SiameseCat:
pass
# 类实例化
CATS.get('Siamese')(**args)
#===============================================================
# 3.如果出现同名 Registry Key,可以选择报错或者强制覆盖
# 如果指定了 force=True,那么不会报错
# 此时 Registry 的 Key 中,Siamese2Cat 类会覆盖 SiameseCat 类
# 否则会报错
@CATS.register_module(name='Siamese',force=True)
class Siamese2Cat:
pass
# 类实例化
CATS.get('Siamese')(**args)
#==============================================================
# 4. 可以直接注册类
class Munchkin:
pass
CATS.register_module(Munchkin)
# 类实例化
CATS.get('Munchkin')(**args)
(1) 最简实现
# 方便起见,此处并未使用类方式构建,而是直接采用全局变量
= dict()
_module_dict
# 定义装饰器函数
def register_module(name):
def _register(cls):
= cls
_module_dict[name] return cls
return _register
# 装饰器用法
@register_module('one_class')
class OneTest(object):
pass
@register_module('two_class')
class TwoTest(object):
pass
进行简单测试:
if __name__ == '__main__':
# 通过注册类名实现自动实例化功能
= _module_dict['one_class']()
one_test print(one_test)
# 输出
<__main__.OneTest object at 0x7f1d7c5acee0>
可以发现只要将所定义的简单装饰器函数作用到类名上,然后内部采用
_module_dict
保存信息即可
(2) 实现无需传入参数,自动根据类名初始化类
= dict()
_module_dict
def register_module(module_name=None):
def _register(cls):
= module_name
name # 如果 module_name 没有给,则自动获取
if module_name is None:
= cls.__name__
name = cls
_module_dict[name] return cls
return _register
@register_module('one_class')
class OneTest(object):
pass
@register_module()
class TwoTest(object):
pass
进行简单测试:
if __name__ == '__main__':
= _module_dict['one_class']
one_test # 方便起见,此处仅仅打印了类对象,而没有实例化。如果要实例化,只需要 one_test() 即可
print(one_test)
= _module_dict['TwoTest']
two_test print(two_test)
# 输出
<class '__main__.OneTest '>
<class '__main__.TwoTest'>
Registry 类实现
基于上面的理解,此时再来看 MMCV 实现就会非常简单了,核心逻辑如下:
class Registry:
def __init__(self, name):
# 可实现注册类细分功能
self._name = name
# 内部核心内容,维护所有的已经注册好的 class
self._module_dict = dict()
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
= module_class.__name__
module_name if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
# 最核心代码
self._module_dict[module_name] = module_class
# 装饰器函数
def register_module(self, name=None, force=False, module=None):
if module is not None:
# 如果已经是 module,那就知道 增加到字典中即可
self._register_module(
=module, module_name=name, force=force)
module_classreturn module
# 最标准用法
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
=cls, module_name=name, force=force)
module_classreturn cls
return _register
在 MMCV 中所有的类实例化都是通过 build_from_cfg
函数实现,做的事情非常简单,就是给定 module_name
,然后从
self._module_dict
提取即可。
def build_from_cfg(cfg, registry, default_args=None):
= cfg.copy()
args
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
= args.pop('type') # 注册 str 类名
obj_type if is_str(obj_type):
# 相当于 self._module_dict[obj_type]
= registry.get(obj_type)
obj_cls if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
# 如果已经实例化了,那就直接返回
elif inspect.isclass(obj_type):
= obj_type
obj_cls else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
# 最终初始化对于类,并且返回,就完成了一个类的实例化过程
return obj_cls(**args)
一个完整的使用例子如下:
= Registry('converter')
CONVERTERS
@CONVERTERS.register_module()
class Converter1(object):
def __init__(self, a, b):
self.a = a
self.b = b
= dict(type='Converter1', a=a_value, b=b_value)
converter_cfg = build_from_cfg(converter_cfg,CONVERTERS) converter
mmdetection模型构建及Registry注册器机制
mmdetection封装的很好,很方便使用,比如我想训练的话只需如下的一条指令。在train.py中,通过build_detector来构建模型,
python tools/train.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py
build_detector的定义如下,最后通过build_from_cfg来构建模型,这里看到了让人困惑的Registry.
from mmdet.cv_core.utils import Registry, build_from_cfg
from torch import nn
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
def build(cfg, registry, default_args=None):
"""Build a module.
Args:
cfg (dict, list[dict]): The config of modules, is is either a dict
or a list of configs.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)
一、Registry是干什么的
Registry完成了从字符串到类的映射,这样模型信息、训练时的参数信息,只需要写入到一个配置文件里,然后使用注册器来实例化即可。
二、如何实现
通过装饰器来实现。在mmcv/mmcv/registry.py中,我们看到了Registry类。其中完成字符串到类的映射,实际上就是下面的成员函数来实现的,核心代码就一句,将要注册的类添加到字典里,key为类的名字(字符串)。下面通过一个小例子,
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
self._module_dict[module_name] = module_class
来看看它的构建过程。在导入下面这个文件时,首先创建FRUIT实例,接着通过装饰器(这里是用成员函数装饰类)来注册Apple类,调用register_module,然后调用_register(注意:参数cls即为类Apple),最后调用_register_module完成Apple的添加。完成后,FRUIT就有了个字典成员:['Apple']=APPle。在build_from_cfg中,传入模型参数,即可通过FRUIT构建Apple的实例化对象。
class Registry():
def __init__(self, name):
self._name = name
self._module_dict = dict()
def _register_module(self, module_class, module_name, force):
self._module_dict[module_name] = module_class
print('self._module_dict',self._module_dict)
def register_module(self, name=None, force=False, module=None):
print('register module ...')
def _register(cls):
print('cls ', cls)
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
FRUIT = Registry('fruit')
@FRUIT.register_module()
class Apple():
def __init__(self, name):
self.name = name
运行结果:
register module ...
cls <class '__main__.Apple'>
self._module_dict {None: <class '__main__.Apple'>}
三、Registry在mmdetection中是如何构建模型的
我们来看一下构建模型的流程:
1、在train.py中通过build_detector构建模型,其中cfg.model, cfg.train_cfg如下,包括模型信息和训练信息。
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
2、最关键的部分来了。首先通过build_detector构建模型, 其中传入的DETECTORS是Registry的实例,在该实例中,包含了所有已经实现的检测器,如图。那么它是在哪里实现添加这些检测的类的呢?
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
看了前面那个小例子我们就能猜到,一定是在这些检测类上,用Registry对其进行了注册,看看faster rcnn的实现,证明了我们的猜想。这样只要
在定义这些类时,对其进行注册,那么就会自动加入到DETECTORS这个实例的成员字典里,非常的巧妙。当我们想实例化某个检测网络时,传入其字符名称即可。
既然都看到这里了,就进一步看看网络时如何继续构建的吧。mmdetection将网络分成了几个部分,backbone,head,neck等。在TwoStageDetector(
faster rcnn的基类)中,可以看到分别构建了这几个部分。head, neck, loss等,同样是通过Registry来注册实现的。最后就是将这几个部分组合起来即可。
@DETECTORS.register_module()
class TwoStageDetector(BaseDetector):
"""Base class for two-stage detectors.
Two-stage detectors typically consisting of a region proposal network and a
task-specific regression head.
"""
def __init__(self,
backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(TwoStageDetector, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
if rpn_head is not None:
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
rpn_head_ = rpn_head.copy()
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
self.rpn_head = build_head(rpn_head_)
if roi_head is not None:
# update train and test cfg here for now
# TODO: refactor assigner & sampler
rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
roi_head.update(train_cfg=rcnn_train_cfg)
roi_head.update(test_cfg=test_cfg.rcnn)
self.roi_head = build_head(roi_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
简单理解mmdetection中的registry类
注册器类(Registry)
在mmdetection中,将会使用该类构建9个注册类实例,其实就是对类做一个划分管理。Python 装饰器的特性就是 被装饰对象(比如 ResNet 类)被定义的时候就立刻运行,从而将 ResNet 注册进 BACKBONES。
比如,backbone 作为一族(vgg,resnet等)
文件:mmdet.py
= Registry('backbone')
BACKBONES = Registry('neck')
NECKS = Registry('roi_extractor')
ROI_EXTRACTORS = Registry('shared_head')
SHARED_HEADS = Registry('head')
HEADS = Registry('loss')
LOSSES = Registry('detector') DETECTORS
文件:mmdet.py
= Registry('dataset')
DATASETS = Registry('pipeline') PIPELINES
每一个实例,都是存放属于这一簇的类,将来通过get key方式获取,key 来自于config文件.
mmdetection在构建模型的过程中,一直是通过key 去查找对应的类(在注册器中),找到对应的类,然后实例化,最终将配置描述的模型,构建出来.
举个栗子:
= 'vgg'
key = BACKBONES.get(key)
VGG
= 'bce'
key = LOSSES .get(key) BCE
Registry 类
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import inspect
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __repr__(self):
= self.__class__.__name__ + '(name={}, items={})'.format(
format_str self._name, list(self._module_dict.keys()))
return format_str
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
return self._module_dict.get(key, None)
def _register_module(self, module_class):
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isclass(module_class):
raise TypeError('module must be a class, but got {}'.format(
type(module_class)))
= module_class.__name__
module_name if module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
self.name))
module_name, self._module_dict[module_name] = module_class
def register_module(self, cls):
self._register_module(cls)
return cls
举个栗子:
在mmdetection的代码中,将一个类注册(插入)到(某一个)注册器里面,是直接写在类的声明上方.
= Registry('animal')
ANIMAL
@ANIMAL.register_module
class Dog(object):
def __init__(self):
pass
def run(self):
print('running dog')
# ANIMAL.register_module(Dog)
= ANIMAL.get('Dog')
dog
= dog()
d d.run()
等价写法:
= Registry('animal')
ANIMAL
class Dog(object):
def __init__(self):
pass
def run(self):
print('running dog')
ANIMAL.register_module(Dog)
= ANIMAL.get('Dog')
dog
= dog()
d d.run()
两者输出结果皆为:
running dog
mmcv之Registry类解读(增删改查)
前言
本文主要介绍mmcv的Registry类。建议读者先配置下mmcv环境:mmcv源码安装。我相信读者大多数对于Registry类有点儿迷,主要涉及python中装饰器的知识。因此,本文尽量做到面面俱到,会简要介绍一部分装饰器的用法。
1、Registry作用
Registry类可以简单理解为一个字典,举个例子,在mmdetection中,比如说创建了名为dataset的注册器对象,则注册器dataset中包含(CocoDataset类,VOCDataset类,Lvis类);同理,detector注册器对象中包含(FasterRcnn类,SSD类,YOLO类等)。因此,Registry对象完全可以理解为一个字典,里面存储着同系列的类。
2、源码分析
Registry虽说是一个字典,但是得实现增删改查的功能。增即往字典中添加新的类;查即查询字典中是否有这个类。那么在Registry类中如何实现这些功能呢?
2.1.初始化部分
class Registry:
"""A registry to map strings to classes.
Args:
name (str): Registry name.
"""
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __len__(self):
return len(self._module_dict)
def __contains__(self, key):
return self.get(key) is not None
def __repr__(self):
format_str = self.__class__.__name__ + \
f'(name={self._name}, ' \
f'items={self._module_dict})'
return format_str
这部分比较简单,就是传入了一个name并内部定义了一个self._module_dict字典。
2.2.查
查找self._module_dict存在一个某个类 实现也比较简单:
def get(self, key):
return self._module_dict.get(key, None)
主要借助get方法,若有key则返回对应的value;若无key则返回None。
2.3.增
增的方法mmdetection中提供了两种方式,区别是方法_register_module()是否指定了module参数:
该函数主要往self._module_dict中添加类。注意,往字典里面添加的是类。以下代码包含了上图中两种方式。这里我截取了核心代码:
def _register_module(self, module_class, module_name=None, force=False):
if module_name is None:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
self._module_dict[name] = module_class
def register_module(self, name=None, force=False, module=None):
# 若指定module,则执行if语句,执行完后完成module类添加
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
# 若没有指定module,则执行_register函数。
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
我将分两小节来介绍这两种方式。
2.3.1 指定module参数
现在我们想往字典self._module_dict字典中添加新类。最容易想到方法就是下面这样:
if __name__ == '__main__':
backbones = Registry('backbone')
class MobileNet:
pass
backbones.register_module(module=MobileNet)
print(backbones)
即直接指定参数module=MobileNet。内部通过self._module_dict[name]=module_class完成注册。
2.3.2 不指定module参数
上节提供方法完全可以,但是在利用mmdetection拓展新模型的时候,如果每次创建完一个类之后,然后通过上述方法注册,着实不方便。势必会影响mmdetection拓展性。而装饰器可以很方便给类拓展新功能,装饰器有机会我会单独出一篇文章, 这里简单记住装饰器用法:funB = funA(funB),即被装饰函数funB,经过装饰器funA的装饰,中间可能发生了一些其他事情,最终funA的return funB。 首先看用法:比如我想注册ResNet。
if __name__ == '__main__':
backbones = Registry('backbone')
@backbones.register_module()
class ResNet:
pass
print(backbones)
这里内部实质上经过了下面函数:
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
在这个过程中,funB相当于cls。而_register函数相当于funA。中间往self._module_dict字典中注册了类cls。然后return cls。即funB。