Registry

Registry 类可以提供一种完全相似的对外装饰函数来管理构建不同的组件,例如 backbones、head 和 necks 等等,Registry 类内部其实维护的是一个全局 key-value 对。通过 Registry 类,用户可以通过字符串方式实例化任何想要的模块。

Registry 类最大好处是:解耦性强、可扩展性强,代码更易理解

回到 Registry 类本身,有如下几种用法:

# 0. 先构建一个全局的 CATS 注册器类
CATS = mmcv.Registry('cat')

# 通过装饰器方式作用在想要加入注册器的具体类中
#===============================================================
# 1. 不需要传入任何参数,此时默认实例化的配置字符串是 str (类名)
@CATS.register_module()
class BritishShorthair:
    pass
# 类实例化
CATS.get('BritishShorthair')(**args)

#==============================================================
# 2.传入指定 str,实例化时候只需要传入对应相同 str 即可
@CATS.register_module(name='Siamese')
class SiameseCat:
    pass
# 类实例化
CATS.get('Siamese')(**args)

#===============================================================
# 3.如果出现同名 Registry Key,可以选择报错或者强制覆盖
# 如果指定了 force=True,那么不会报错
# 此时 Registry 的 Key 中,Siamese2Cat 类会覆盖 SiameseCat 类
# 否则会报错
@CATS.register_module(name='Siamese',force=True)
class Siamese2Cat:
    pass
# 类实例化
CATS.get('Siamese')(**args)

#==============================================================
# 4. 可以直接注册类
class Munchkin:
    pass
CATS.register_module(Munchkin)

# 类实例化
CATS.get('Munchkin')(**args)

(1) 最简实现

# 方便起见,此处并未使用类方式构建,而是直接采用全局变量

_module_dict = dict()

# 定义装饰器函数
def register_module(name):
    def _register(cls):
        _module_dict[name] = cls
        return cls

    return _register

# 装饰器用法
@register_module('one_class')
class OneTest(object):
    pass

@register_module('two_class')
class TwoTest(object):
    pass

进行简单测试:

if __name__ == '__main__':
    # 通过注册类名实现自动实例化功能
    one_test = _module_dict['one_class']()
    print(one_test)

# 输出
<__main__.OneTest object at 0x7f1d7c5acee0>

可以发现只要将所定义的简单装饰器函数作用到类名上,然后内部采用 _module_dict 保存信息即可

(2) 实现无需传入参数,自动根据类名初始化类

_module_dict = dict()

def register_module(module_name=None):
    def _register(cls):
        name = module_name
        # 如果 module_name 没有给,则自动获取
        if module_name is None:
            name = cls.__name__
        _module_dict[name] = cls
        return cls

    return _register

@register_module('one_class')
class OneTest(object):
    pass

@register_module()
class TwoTest(object):
    pass

进行简单测试:

if __name__ == '__main__':
    one_test = _module_dict['one_class']
    # 方便起见,此处仅仅打印了类对象,而没有实例化。如果要实例化,只需要 one_test() 即可
    print(one_test)
    two_test = _module_dict['TwoTest']
    print(two_test)

# 输出
<class '__main__.OneTest '>
<class '__main__.TwoTest'>

Registry 类实现

基于上面的理解,此时再来看 MMCV 实现就会非常简单了,核心逻辑如下:

class Registry:
    def __init__(self, name):
        # 可实现注册类细分功能
        self._name = name 
        # 内部核心内容,维护所有的已经注册好的 class
        self._module_dict = dict()

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        # 最核心代码
        self._module_dict[module_name] = module_class

    # 装饰器函数
    def register_module(self, name=None, force=False, module=None):
        if module is not None:
            # 如果已经是 module,那就知道 增加到字典中即可
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # 最标准用法
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls
        return _register

在 MMCV 中所有的类实例化都是通过 build_from_cfg 函数实现,做的事情非常简单,就是给定 module_name,然后从 self._module_dict 提取即可。

def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type') # 注册 str 类名
    if is_str(obj_type):
        # 相当于 self._module_dict[obj_type]
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')

    # 如果已经实例化了,那就直接返回
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    # 最终初始化对于类,并且返回,就完成了一个类的实例化过程
    return obj_cls(**args)

一个完整的使用例子如下:

CONVERTERS = Registry('converter')

@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_from_cfg(converter_cfg,CONVERTERS)


mmdetection模型构建及Registry注册器机制

mmdetection封装的很好,很方便使用,比如我想训练的话只需如下的一条指令。在train.py中,通过build_detector来构建模型,

python tools/train.py  configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py

build_detector的定义如下,最后通过build_from_cfg来构建模型,这里看到了让人困惑的Registry.

from mmdet.cv_core.utils import Registry, build_from_cfg
from torch import nn

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules, is is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)

一、Registry是干什么的

Registry完成了从字符串到类的映射,这样模型信息、训练时的参数信息,只需要写入到一个配置文件里,然后使用注册器来实例化即可。

二、如何实现

通过装饰器来实现。在mmcv/mmcv/registry.py中,我们看到了Registry类。其中完成字符串到类的映射,实际上就是下面的成员函数来实现的,核心代码就一句,将要注册的类添加到字典里,key为类的名字(字符串)。下面通过一个小例子,

def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        self._module_dict[module_name] = module_class

来看看它的构建过程。在导入下面这个文件时,首先创建FRUIT实例,接着通过装饰器(这里是用成员函数装饰类)来注册Apple类,调用register_module,然后调用_register(注意:参数cls即为类Apple),最后调用_register_module完成Apple的添加。完成后,FRUIT就有了个字典成员:['Apple']=APPle。在build_from_cfg中,传入模型参数,即可通过FRUIT构建Apple的实例化对象。

class Registry():

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def _register_module(self, module_class, module_name, force):
        self._module_dict[module_name] = module_class
        print('self._module_dict',self._module_dict)


    def register_module(self, name=None, force=False, module=None):
        print('register module ...')
        def _register(cls):
            print('cls ', cls)
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

FRUIT = Registry('fruit')

@FRUIT.register_module()
class Apple():
    def __init__(self, name):
        self.name = name

运行结果:

register module ...
cls  <class '__main__.Apple'>
self._module_dict {None: <class '__main__.Apple'>}

三、Registry在mmdetection中是如何构建模型的

我们来看一下构建模型的流程:

​ 1、在train.py中通过build_detector构建模型,其中cfg.model, cfg.train_cfg如下,包括模型信息和训练信息。

model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

2、最关键的部分来了。首先通过build_detector构建模型, 其中传入的DETECTORS是Registry的实例,在该实例中,包含了所有已经实现的检测器,如图。那么它是在哪里实现添加这些检测的类的呢?

def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

看了前面那个小例子我们就能猜到,一定是在这些检测类上,用Registry对其进行了注册,看看faster rcnn的实现,证明了我们的猜想。这样只要

在定义这些类时,对其进行注册,那么就会自动加入到DETECTORS这个实例的成员字典里,非常的巧妙。当我们想实例化某个检测网络时,传入其字符名称即可。

既然都看到这里了,就进一步看看网络时如何继续构建的吧。mmdetection将网络分成了几个部分,backbone,head,neck等。在TwoStageDetector(

faster rcnn的基类)中,可以看到分别构建了这几个部分。head, neck, loss等,同样是通过Registry来注册实现的。最后就是将这几个部分组合起来即可。

@DETECTORS.register_module()
class TwoStageDetector(BaseDetector):
    """Base class for two-stage detectors.

    Two-stage detectors typically consisting of a region proposal network and a
    task-specific regression head.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(TwoStageDetector, self).__init__()
        self.backbone = build_backbone(backbone)

        if neck is not None:
            self.neck = build_neck(neck)

        if rpn_head is not None:
            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
            rpn_head_ = rpn_head.copy()
            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
            self.rpn_head = build_head(rpn_head_)

        if roi_head is not None:
            # update train and test cfg here for now
            # TODO: refactor assigner & sampler
            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
            roi_head.update(train_cfg=rcnn_train_cfg)
            roi_head.update(test_cfg=test_cfg.rcnn)
            self.roi_head = build_head(roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

简单理解mmdetection中的registry类

注册器类(Registry)

在mmdetection中,将会使用该类构建9个注册类实例,其实就是对类做一个划分管理。Python 装饰器的特性就是 被装饰对象(比如 ResNet 类)被定义的时候就立刻运行,从而将 ResNet 注册进 BACKBONES。

比如,backbone 作为一族(vgg,resnet等)

文件:mmdet.py

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

文件:mmdet.py

DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')

每一个实例,都是存放属于这一簇的类,将来通过get key方式获取,key 来自于config文件.

mmdetection在构建模型的过程中,一直是通过key 去查找对应的类(在注册器中),找到对应的类,然后实例化,最终将配置描述的模型,构建出来.

举个栗子:

key = 'vgg'
VGG = BACKBONES.get(key)

key = 'bce'
BCE = LOSSES .get(key)

Registry 类

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import inspect

class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls

举个栗子:

在mmdetection的代码中,将一个类注册(插入)到(某一个)注册器里面,是直接写在类的声明上方.

ANIMAL = Registry('animal')

@ANIMAL.register_module
class Dog(object):
    def __init__(self):
        pass

    def run(self):
        print('running dog')

# ANIMAL.register_module(Dog)

dog = ANIMAL.get('Dog')

d = dog()
d.run()

等价写法:

ANIMAL = Registry('animal')

class Dog(object):
    def __init__(self):
        pass

    def run(self):
        print('running dog')

ANIMAL.register_module(Dog)

dog = ANIMAL.get('Dog')

d = dog()
d.run()

两者输出结果皆为:

running dog


mmcv之Registry类解读(增删改查)

前言

 本文主要介绍mmcv的Registry类。建议读者先配置下mmcv环境:mmcv源码安装。我相信读者大多数对于Registry类有点儿迷,主要涉及python中装饰器的知识。因此,本文尽量做到面面俱到,会简要介绍一部分装饰器的用法。

1、Registry作用

 Registry类可以简单理解为一个字典,举个例子,在mmdetection中,比如说创建了名为dataset的注册器对象,则注册器dataset中包含(CocoDataset类,VOCDataset类,Lvis类);同理,detector注册器对象中包含(FasterRcnn类,SSD类,YOLO类等)。因此,Registry对象完全可以理解为一个字典,里面存储着同系列的类。

2、源码分析

 Registry虽说是一个字典,但是得实现增删改查的功能。增即往字典中添加新的类;查即查询字典中是否有这个类。那么在Registry类中如何实现这些功能呢?

2.1.初始化部分

class Registry:
    """A registry to map strings to classes.
Args:
    name (str): Registry name.
"""

def __init__(self, name):
    self._name = name
    self._module_dict = dict()

def __len__(self):
    return len(self._module_dict)

def __contains__(self, key):
    return self.get(key) is not None

def __repr__(self):
    format_str = self.__class__.__name__ + \
                 f'(name={self._name}, ' \
                 f'items={self._module_dict})'
    return format_str

这部分比较简单,就是传入了一个name并内部定义了一个self._module_dict字典。

2.2.查

 查找self._module_dict存在一个某个类 实现也比较简单:

def get(self, key):
    return self._module_dict.get(key, None)

 主要借助get方法,若有key则返回对应的value;若无key则返回None。

2.3.增

 增的方法mmdetection中提供了两种方式,区别是方法_register_module()是否指定了module参数:

该函数主要往self._module_dict中添加类。注意,往字典里面添加的是类。以下代码包含了上图中两种方式。这里我截取了核心代码:

def _register_module(self, module_class, module_name=None, force=False):

    if module_name is None:
        module_name = module_class.__name__
    if isinstance(module_name, str):
        module_name = [module_name]
        
    self._module_dict[name] = module_class

def register_module(self, name=None, force=False, module=None):
    # 若指定module,则执行if语句,执行完后完成module类添加
    if module is not None:
        self._register_module(
            module_class=module, module_name=name, force=force)
        return module

    # 若没有指定module,则执行_register函数。
    def _register(cls):
        self._register_module(
            module_class=cls, module_name=name, force=force)
        return cls

    return _register

 我将分两小节来介绍这两种方式。

2.3.1 指定module参数

  现在我们想往字典self._module_dict字典中添加新类。最容易想到方法就是下面这样:

if __name__ == '__main__':
    backbones = Registry('backbone')
    class MobileNet:
        pass
    backbones.register_module(module=MobileNet)
    print(backbones)

 即直接指定参数module=MobileNet。内部通过self._module_dict[name]=module_class完成注册。

2.3.2 不指定module参数

 上节提供方法完全可以,但是在利用mmdetection拓展新模型的时候,如果每次创建完一个类之后,然后通过上述方法注册,着实不方便。势必会影响mmdetection拓展性。而装饰器可以很方便给类拓展新功能,装饰器有机会我会单独出一篇文章,  这里简单记住装饰器用法:funB = funA(funB),即被装饰函数funB,经过装饰器funA的装饰,中间可能发生了一些其他事情,最终funA的return funB。  首先看用法:比如我想注册ResNet。

if __name__ == '__main__':
    backbones = Registry('backbone')
    @backbones.register_module()
    class ResNet:
        pass
    print(backbones)

 这里内部实质上经过了下面函数:

        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

在这个过程中,funB相当于cls。而_register函数相当于funA。中间往self._module_dict字典中注册了类cls。然后return cls。即funB。