-
grad
该属性值默认为None,在第一次调用backward()方法后,该值会附上一个数值,并且grad属性值会在之后的每次调用backward()方法时进行累计,这就是为什么我们在训练网络时,每次迭代计算backward()之前需要进行zero_grad的操作的原因。 -
requires_grad
若当前tensor需要被计算梯度,则该属性值需要为True,否则该属性值为False。但是需要注意的是tensor是否需要被计算梯度和tensor的grad属性是否有值并不是等价的,具体可以看如下is_leaf这个属性的介绍。 -
backward()
计算当前tensor关于图中叶子节点的梯度值,需要注意的是上述提到过的该梯度值是会累计到对应的叶子节点中的。 -
register_hook()
该方法在backward()方法之前调用,通过这个钩子可以在backward()方法的执行中,人为的改变grad属性的值,具体使用方法看如下例子。 -
detach()/detach_()
detach()方法把计算图中的某个值分离出来,并重新生成一个tensor。
detach_()方法将某个值从计算图中分离出来,并生成叶子节点,具体使用方法看如下代码。 -
is_leaf
叶子节点的特性:
在调用tensor的backward()方法时,只有叶子节点的grad属性会被赋值。如果要对非叶子节点的grad属性进行赋值,则需要在使用backward()方法之前调用retain_grad()方法。
哪些节点被称为叶子节点:
所有requires_grad属性值为false的tensor都被称为叶子节点。当某个tensor_grad属性为True时,当且仅当该tensor是被用户直接创建,没有被作用于其他任何操作作用于该tensor的情况下,该节点也被称为叶子节点。
import torch
def test_not_leaf_get_grad(t):
t.retain_grad()
def print_isleaf(t):
print('*'*5)
print('{} is leaf {}!'.format(t,t.is_leaf))
def print_grad(t):
print('{} grad is {}'.format(t,t.grad))
def double_grad(grad):
grad=grad*2
return grad
if __name__ == "__main__":
#init
input_ = torch.randn(1,requires_grad=True)
output = input_ * input_
output2 = output*2
#test retain_grad()
test_not_leaf_get_grad(output)
#test register_hook()
output2.register_hook(double_grad)
output2.backward()
#test leaf
print_isleaf(input_)
print_grad(input_)
#test no leaf
print_isleaf(output)
print_grad(output)
# output
*****
tensor([],requires_grad=True) is leaf True !
tensor([],requires_grad=True) grad is tensor([2.])
*****
tensor([],grad_fn=<MulBackward0>) is leaf False!
tensor([],grad_fn=<MulBackward0>) grad is None
## result for retain_grad
*****
tensor([],requires_grad=True) is leaf True !
tensor([],requires_grad=True) grad is tensor([2.])
*****
tensor([],grad_fn=<MulBackward0>) is leaf False!
tensor([],grad_fn=<MulBackward0>) grad is tensor([1.])
## result for register_hook
*****
tensor([],requires_grad=True) is leaf True !
tensor([1.1],requires_grad=True) grad is tensor([8.8.],grad_fn=<CloneBackward>)
*****
tensor([],grad_fn=<MulBackward0>) is leaf False!
tensor([],grad_fn=<MulBackward0>) grad is tensor([4.])
## test detach()
*****
tensor([],requires_grad=True) is leaf True !
tensor([1.1],requires_grad=True) grad is tensor([8.8.],grad_fn=<CloneBackward>)
*****
tensor([],grad_fn=<MulBackward0>) is leaf False!
tensor([],grad_fn=<MulBackward0>) grad is tensor([4.])
*****
tensor([]) is leaf True !
tensor([]) grad is tensor([4.])
网友评论