Pytorch中神经网络的核心就是autograd包。在这里将会简要介绍autograd,并在后续的章节介绍怎么训练神经网络。
Tensor
当设置Tensor的属性.requires_grad为true,pytorch会跟踪tensor上所有的operations。当完成计算,可以调用.backward()自动计算梯度。Tensor的梯度将会累加到grad属性。
为了停止一个tensor被跟踪,可以调用.detach()将tensor和计算历史分离,这样在接下来的计算就不会被跟踪。
为了阻止跟踪和使用内存,你可以用with torch.no_grad():将相关代码放进block中。当你评估一个模型的时候会特别有用,因为模型可能会有带required_grad=True的可训练参数,但事实上我们并不需要梯度。
还有一个非常重要的类:Function。Tensor和Function相互连接,并建立起非循环图,通过编码实现一个完整的计算过程。每一个通过Function函数计算得到的tensor都有一个.grad_fn属性(用户自己创建的tensor的.grad_fn是None)。
创建一个tensor并设置requires_grad=True来跟踪计算:
x = torch.ones(2, 2, requires_grad=True)
print(x)
输出结果:
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
做一个operation:
y = x + 2
print(y)
输出结果:
tensor([[3., 3.],
[3., 3.]], grad_fn=<AddBackward0>)
因为y是加法操作返回的结果,所以它有grad_fn。
.requires_grad_()可以改变Tensor的requires_grad标志,输入参数如果没有设置,则会使用默认参数False。
a = torch.randn(2, 2)
print(a.grad_fn)
a = ((a * 3) / (a - 1))
print(a.grad_fn)
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
a = ((a * 3) / (a - 1))
print(a.grad_fn)
上面的例子可以仔细分析分析,一个tensor默认requires_grad属性是False,设置成True后,才会有grad_fn属性:
None
None
False
True
<DivBackward0 object at 0x000001C0CFE76580>
Gradients
现在可以反向传播计算了,首先举一个简单的例子,网络输出一个标量:
x = torch.randn(3, requires_grad=True)
y = x * 2
z = y.sum()
z.backward()
print(x.grad)
输出结果:
tensor([2., 2., 2.])
至于为什么输出这个结果,可以把pytorch代码转换成数学公式:
dz/dx1 = d(y1+y2+y3)/dx1
= d(2*x1+2*x2+2*x3)/dx1
= 2
dz/dx2 = d(y1+y2+y3)/dx2
= d(2*x1+2*x2+2*x3)/dx2
= 2
dz/dx3 = d(y1+y2+y3)/dx3
= d(2*x1+2*x2+2*x3)/dx3
= 2
现在把问题考虑的更复杂一些,网络最终输出的结果是一个向量:
x = torch.randn(3, requires_grad=True)
y = x * 2
y.backward()
print(x.grad)
输出结果:
RuntimeError: grad can be implicitly created only for scalar outputs
注意这个时候backward()函数需要传入参数,参数大小与x相同:
x = torch.randn(3, requires_grad=True)
y = x * 2
v = torch.ones(3, dtype=torch.float)
y.backward(v)
print(x.grad)
输出结果:
tensor([2., 2., 2.])
其实这里关于backward函数输入参数v我并不是很明白,这个参数可以看做是每个输出对某个输入变量偏导的权重,通过输入参数v实现偏导的加权求和:
v*dy/dx1 = v*[dy1/dx1 dy2/dx1 dy3/dx1]
= v1*dy1/dx1 + v2*dy2/dx1 + v3*dy3/dx1
但是如何设置v的值我还没搞清楚,还有就是数学部分的推导不太好编写,这里只能实现大致的想法,有一些表述上的错误,大家可以提出来,我有时间会尽量更正。
结语
本节代码可以从github上下载,github上提供了cpp和python两个版本的入门程序,仅供大家参考。
网友评论