pytorch中forward方法是如何调用的

前言

TensorFlow 的API确实有点乱,1版本和2版本函数变化比较大,最不舒服的一点是,多个不同函数可以实现同一个功能。但是基于TF的keras 用起来倒挺方便。

今天试着用了用pytorch,发现定义网络时连接网络的forward函数在代码中都不需要被调用,很是困惑。搜了一下才知道,自己编写的网络的类继承nn.Module 类,nn.Module 里面有__call__ 方法,forward函数是在传入数据时由__call__ 方法调用的。

调用过程

pytorch通常的用法是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
# ......

def forward(self, x):
# ......
return x

data = ..... #输入数据
# 实例化一个对象
module = Module()
# 前向传播
module(data)
# 而不是使用下面的
# module.forward(data)

实际上

1
module(data)

就等价于:

1
module.forward(data)

我们在pycharm中把光标放在nn.Module上,按Ctrl+B,可以定位到nn.Module类的源码,在里面找__call__ ,可以找到这一行代码:

1
__call__ : Callable[..., Any] = _call_impl

看不懂这行代码的可以看这篇文章,

它其实相当于:

1
__call__ = _call_impl

也就将给_call_impl 函数赋给了__call__ 方法。

找到 _call_impl 函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def _call_impl(self, *input, **kwargs):
for hook in itertools.chain(
_global_forward_pre_hooks.values(),
self._forward_pre_hooks.values()):
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
....
....

可以看到里面调用的forward函数。我们知道__call__ 方法会在对象被调用时自动执行,所以说在你调用你定义的模型类创建的对象时,就像这样:

1
2
net = Net()
net(x) # 调用对象

它就会自动调用forward函数,并把 x 参数传给forward来执行。


参考链接:

https://blog.csdn.net/xu380393916/article/details/97280035