Framework: Pytorch 0.3.0
关于pytorch 反传计算图一点小心得:
-
计算图记录了variable的所有操作。若需要此变量反向求导,则一开始参与计算时该以variable出现。
-
直接改动variable下的tensor(即variable.data),则其改动在反传中无效。
-
计算时最好进行矩阵整体运算,单独对矩阵中某些数值进行改动,计算图代价很大。
如,我们的目标是将Mat(cos_\theta)
在j=label
处替换成phi_\theta
如下代码,如果对phi_theta[i,j]
进行单独操作,由于公式1
中的cos_theta[i,j]
是variable。因此,在计算图中,矩阵中的每个元素都得单独反传。for i in range(cos_theta.shape[0]): j=target[i].data[0] if cos_theta[i,j].data[0] >= -self.cosm: phi_theta[i,j]=self.cosm * cos_theta[i,j] - \ self.sinm * torch.sqrt(1e-6+1-cos_theta[i,j]*cos_theta[i,j]) #公式1 else: phi_theta[i,j] = cos_theta[i,j]
将代码改成
for i in range(cos_theta.shape[0]):
j=target[i].data[0]
if cos_theta[i,j].data[0] >= -self.cosm:
flagMat[i,j]= 1
else:
flagMat[i,j] = 0
flagMat=Variable(flagMat)
phi_theta=(self.cosm * cos_theta - self.sinm * torch.sqrt(1e-6+1-cos_theta*cos_theta))*flagMat\
+(1-flagMat)*cos_theta
新构造指示矩阵flatMat
与公式1
进行运算,且flatMat
不参与反传,得到的phi_\theta.grad_fn
为公式1backward
,计算图恢复正常!life will be better!
网友评论