The contagion source: requires_grad
The most important data structure in PyTorch is the tensor. Every tensor has a key attribute, requires_grad, which determines whether gradients need to be computed for that tensor during the network's training iterations. This attribute is "contagious": as long as some tensor has requires_grad set to True, any tensor derived from it will also have requires_grad set to True (unless handled specially).
import torch
x = torch.ones(2, 2, requires_grad=True)
print(x, x.requires_grad)
y = torch.randn_like(x)          # requires_grad not set, so it defaults to False
print(y, y.requires_grad)
z = x + y                        # z is derived from x, so it "catches" requires_grad=True
print(z, z.requires_grad)
------------------out:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True) True
tensor([[ 0.9801, 0.6068],
        [-0.8965, 1.4186]]) False
tensor([[1.9801, 1.6068],
        [0.1035, 2.4186]], grad_fn=<ThAddBackward>) True
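The "unless handled specially" caveat above mainly means explicitly switching autograd off. As a minimal sketch (not part of the original example), the torch.no_grad() context manager blocks the contagion:
import torch
x = torch.ones(2, 2, requires_grad=True)
with torch.no_grad():                 # autograd is switched off inside this block
    w = x * 2
print(w.requires_grad)                # False: the "contagion" is blocked
v = x * 2                             # outside the block the usual rule applies
print(v.requires_grad)                # True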
The identity marker (trace graph): grad_fn
Since we called requires_grad the contagion source above, note that a PyTorch tensor also carries another marker: grad_fn. When requires_grad is True, grad_fn records which operation produced the current tensor. Chaining all of these identity markers together gives you the trace graph of the tensor's computation.
import torch
import torch.nn as nn
x = torch.ones(2, 2, requires_grad=True)
print(x, x.requires_grad)
y = torch.randn_like(x)
print(y, y.requires_grad)
z = x + y                        # addition -> Add backward node
print(z, z.requires_grad)
m = z * z                        # elementwise multiplication -> Mul backward node
print(m)
m = m.view([2, 2, 1, 1])         # reshape to (N, C, H, W) for the convolution -> View backward node
print(m)
conv = nn.Conv2d(2, 2, 1)        # 1x1 convolution, 2 input channels, 2 output channels
n = conv(m)                      # convolution -> Conv2d backward node
print(n)
--------------------out:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True) True
tensor([[-0.6495, 0.6366],
        [-0.4593, -1.5528]]) False
tensor(..., grad_fn=<ThAddBackward>) True
tensor(..., grad_fn=<ThMulBackward>)
tensor(..., grad_fn=<ViewBackward>)
tensor(..., grad_fn=<ThnnConv2DBackward>)
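To actually see the trace graph these markers form, you can walk backwards from the final tensor through grad_fn.next_functions. This is only a rough sketch: the exact backward class names and the ordering of next_functions vary between PyTorch versions (newer releases print AddBackward0 and so on instead of ThAddBackward).
import torch
import torch.nn as nn
x = torch.ones(2, 2, requires_grad=True)
z = x + torch.randn_like(x)
m = (z * z).view([2, 2, 1, 1])
n = nn.Conv2d(2, 2, 1)(m)
# walk backwards from n, always following the first parent node
fn = n.grad_fn
while fn is not None:
    print(type(fn).__name__)     # e.g. ConvolutionBackward0, ViewBackward0, MulBackward0, ...
    parents = [f for f, _ in fn.next_functions if f is not None]
    fn = parents[0] if parents else None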
Destroying the evidence: detach
If you call the detach method on a tensor, the tensor it returns has requires_grad set to False, and its identity marker (grad_fn) is therefore erased as well.
import torch
import torch.nn as nn
x = torch.ones(2, 2, requires_grad=True)
print(x, x.requires_grad)
y = torch.randn_like(x)
print(y, y.requires_grad)
z = x + y
print(z, z.requires_grad)
m = z * z
print(m)
m = m.detach()                   # cut the history: requires_grad becomes False, grad_fn is gone
p = m.view([2, 2, 1, 1])
print(p)
conv = nn.Conv2d(2, 2, 1)
n = conv(p)                      # n still has a grad_fn, but only via conv's own parameters
print(n)
----------------------out:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True) True
tensor([[-1.2428, 0.2750],
        [ 1.0678, -0.7844]]) False
tensor([[-0.2428, 1.2750],
        [ 2.0678, 0.2156]], grad_fn=<ThAddBackward>) True
tensor([[0.0590, 1.6255],
        [4.2759, 0.0465]], grad_fn=<ThMulBackward>)
tensor([[[[0.0590]],
         [[1.6255]]],
        [[[4.2759]],
         [[0.0465]]]]) False
tensor(..., grad_fn=<ThnnConv2DBackward>)
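To confirm that detach really wipes the trail, backpropagate through n: the convolution's own parameters still receive gradients (that is where n's grad_fn comes from), but nothing reaches x. Below is a minimal self-contained sketch of that check, not part of the original example:
import torch
import torch.nn as nn
x = torch.ones(2, 2, requires_grad=True)
m = (x * 3).detach()                  # detach cuts the history that leads back to x
conv = nn.Conv2d(2, 2, 1)
n = conv(m.view([2, 2, 1, 1]))        # input shape (N=2, C=2, H=1, W=1)
n.sum().backward()
print(conv.weight.grad is not None)   # True:  conv's own parameters are still in the graph
print(x.grad)                         # None:  nothing flows back past detach()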