美文网首页
深入探究PyTorch中为什么model(image)会自动调用

深入探究PyTorch中为什么model(image)会自动调用

作者: CPinging | 来源:发表于2021-04-20 21:49 被阅读0次

主要是对题目的这个问题太好奇了,于是就看了一下源码,并有了如下的总结。

一、问题起因

经常写PyTorch模型的人会写:output = model(images)来进行前项传播,但是有没有仔细想过为啥这个image传入之后就能自动调用forward呢?

二、探究

于是我追踪了源码并阅读了一些资料,有了如下总结:

首先,model()是一个类,例如这里用alexnet为例子:

class AlexNet(nn.Module):

    def __init__(self, num_classes=200):
        # generate fater class init
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

在该代码中我们定义了forward函数,以及构建了AlexNet的class类。值得注意的是,该类继承了父类Module,这里插入一些class的知识:

代码中的def __init__(self, num_classes=200):为定义构造函数,并且super(AlexNet, self).__init__()是继承父类的构造函数__init__(),从而使得AlexNet中包含了父类的一些变量以及方法。

言归正传,我们进入Module:

在该父类中,我们能看到如下的一个func:

def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in itertools.chain(
                        _global_backward_hooks.values(),
                        self._backward_hooks.values()):
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

    __call__ : Callable[..., Any] = _call_impl

这里__call__ : Callable[..., Any] = _call_impl代表call函数调用了_call_impl,而在_call_impl中,我们能得到如下的函数执行顺序:

image.png

其中有result = self.forward(*input, **kwargs),从而使得所有的model函数在调用的时候变会调用call,于是调用forward。

关于call的详细解释参考:https://blog.csdn.net/dss_dssssd/article/details/83750838

后面会在解析一些关于hook的知识,希望大家Follow关注。

相关文章

网友评论

      本文标题:深入探究PyTorch中为什么model(image)会自动调用

      本文链接:https://www.haomeiwen.com/subject/unvahltx.html