pytorch_nn_module__call__
PyTorch中nn.Module类中__call__方法介绍
在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下:
__call__ : Callable[…, Any] = _call_impl
forward: Callable[…, Any] = _forward_unimplemented
在PyTorch中nn.Module类是所有神经网络模块的基类,你的网络也应该继承这个类,需要重载__init__和forward函数。以下是仿照PyTorch中Module和AlexNet类实现写的假的实现的测试代码:
from typing import Callable, Any, List
def _forward_unimplemented(self, *input: Any) -> None:
"Should be overridden by all subclasses"
print("_forward_unimplemented")
raise NotImplementedError
class Module:
def __init__(self):
print("Module.__init__")
forward: Callable[..., Any] = _forward_unimplemented
def _call_impl(self, *input, **kwargs):
print("Module._call_impl")
result = self.forward(*input, **kwargs)
return result
__call__: Callable[..., Any] = _call_impl
def cpu(self):
print("Module.cpu")
class AlexNet(Module):
def __init__(self):
print("AlexNet.__init__")
super(AlexNet, self).__init__()
def forward(self, x):
print("AlexNet.forward")
return x
model = AlexNet()
x: List[int] = [1, 2, 3, 4]
print("result:", model(x))
model.cpu()
print("test finish")
执行model(x)语句时,会调用AlexNet的forward函数,是因为AlexNet的父类Module中的__call__函数:首先Module中有__call__方法,因此model(x)这条语句可以正常执行。Module中并没有直接给出__call__的实现体,而是__call__后紧跟冒号,此冒号表示类型注解;后面的Callable和Any是typing模块中的,Callable表示可调用类型,即等号右边应该是一个可调用类型,此处指的是_call_impl;Any是一种特殊的类型,它与所有类型兼容;Callable[…, Any]表示_call_impl可接受任意数量的参数并返回Any。这里__call__实际指向了_call_impl函数,因此调用__call__实际是调用_call_impl。
typing模块的介绍参考:https://blog.csdn.net/fengbingchun/article/details/122288737
_call_impl函数体内会调用forward,Module中的forward的实现方式与__call__相同,但是_forward_unimplemented函数并没有实现体,调用它会触发Error即NotImplementedError。因此在子类AlexNet中一定要给出forward的具体实现,否则调用的将是_forward_unimplemented。
测试代码执行结果如下:
forward,则执行结果如下:
All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.