前言
TensorFlow 的API确实有点乱,1版本和2版本函数变化比较大,最不舒服的一点是,多个不同函数可以实现同一个功能。但是基于TF的keras 用起来倒挺方便。
今天试着用了用pytorch,发现定义网络时连接网络的forward函数在代码中都不需要被调用,很是困惑。搜了一下才知道,自己编写的网络的类继承nn.Module
类,nn.Module
里面有__call__
方法,forward函数是在传入数据时由__call__
方法调用的。
调用过程
pytorch通常的用法是:
1 | class Module(nn.Module): |
实际上
1 | module(data) |
就等价于:
1 | module.forward(data) |
我们在pycharm中把光标放在nn.Module上,按Ctrl+B,可以定位到nn.Module类的源码,在里面找__call__
,可以找到这一行代码:
1 | __call__ : Callable[..., Any] = _call_impl |
看不懂这行代码的可以看这篇文章,
它其实相当于:
1 | __call__ = _call_impl |
也就将给_call_impl
函数赋给了__call__
方法。
找到 _call_impl
函数:
1 | def _call_impl(self, *input, **kwargs): |
可以看到里面调用的forward函数。我们知道__call__
方法会在对象被调用时自动执行,所以说在你调用你定义的模型类创建的对象时,就像这样:
1 | net = Net() |
它就会自动调用forward函数,并把 x 参数传给forward来执行。
参考链接: