PyTorch中nn.Module类中__call__方法介绍

在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下:

__call__ : Callable[…, Any] = _call_impl
forward: Callable[…, Any] = _forward_unimplemented

在PyTorch中nn.Module类是所有神经网络模块的基类,你的网络也应该继承这个类,需要重载__init__和forward函数。以下是仿照PyTorch中Module和AlexNet类实现写的假的实现的测试代码:

from typing import Callable, Any, List
 
def _forward_unimplemented(self, *input: Any) -> None:
    "Should be overridden by all subclasses"
    print("_forward_unimplemented")
    raise NotImplementedError
 
class Module:
    def __init__(self):
        print("Module.__init__")
 
    forward: Callable[..., Any] = _forward_unimplemented
 
    def _call_impl(self, *input, **kwargs):
        print("Module._call_impl")
        result = self.forward(*input, **kwargs)
        return result
 
    __call__: Callable[..., Any] = _call_impl
 
    def cpu(self):
        print("Module.cpu")
 
class AlexNet(Module):
    def __init__(self):
        print("AlexNet.__init__")
        super(AlexNet, self).__init__()
 
    def forward(self, x):
        print("AlexNet.forward")
        return x
 
model = AlexNet()
x: List[int] = [1, 2, 3, 4]
print("result:", model(x))
 
model.cpu()
 
print("test finish")

执行model(x)语句时,会调用AlexNet的forward函数,是因为AlexNet的父类Module中的__call__函数:首先Module中有__call__方法,因此model(x)这条语句可以正常执行。Module中并没有直接给出__call__的实现体,而是__call__后紧跟冒号,此冒号表示类型注解;后面的Callable和Any是typing模块中的,Callable表示可调用类型,即等号右边应该是一个可调用类型,此处指的是_call_impl;Any是一种特殊的类型,它与所有类型兼容;Callable[…, Any]表示_call_impl可接受任意数量的参数并返回Any。这里__call__实际指向了_call_impl函数,因此调用__call__实际是调用_call_impl。

  typing模块的介绍参考:https://blog.csdn.net/fengbingchun/article/details/122288737

  _call_impl函数体内会调用forward,Module中的forward的实现方式与__call__相同,但是_forward_unimplemented函数并没有实现体,调用它会触发Error即NotImplementedError。因此在子类AlexNet中一定要给出forward的具体实现,否则调用的将是_forward_unimplemented。

测试代码执行结果如下:

forward,则执行结果如下: