美文网首页PyTorch
Pytorch中requires_grad_(), detach

Pytorch中requires_grad_(), detach

作者: SnailTyan | 来源:发表于2020-06-01 10:15 被阅读0次

    文章作者:Tyan
    博客:noahsnail.com  |  CSDN  |  简书

    0. 测试环境

    Python 3.6.9, Pytorch 1.5.0

    1. 基本概念

    Tensor是一个多维矩阵,其中包含所有的元素为同一数据类型。默认数据类型为torch.float32

    • 示例一
    >>> a = torch.tensor([1.0])
    >>> a.data
    tensor([1.])
    >>> a.grad
    >>> a.requires_grad
    False
    >>> a.dtype
    torch.float32
    >>> a.item()
    1.0
    >>> type(a.item())
    <class 'float'>
    

    Tensor中只有一个数字时,使用torch.Tensor.item()可以得到一个Python数字。requires_gradTrue时,表示需要计算Tensor的梯度。requires_grad=False可以用来冻结部分网络,只更新另一部分网络的参数。

    • 示例二
    >>> a = torch.tensor([1.0, 2.0])
    >>> b = a.data
    >>> id(b)
    139808984381768
    >>> id(a)
    139811772112328
    >>> b.grad
    >>> a.grad
    >>> b[0] = 5.0
    >>> b
    tensor([5., 2.])
    >>> a
    tensor([5., 2.])
    

    a.data返回的是一个新的Tensor对象ba, bid不同,说明二者不是同一个Tensor,但ba共享数据的存储空间,即二者的数据部分指向同一块内存,因此修改b的元素时,a的元素也对应修改。

    2. requires_grad_()与detach()

    >>> a = torch.tensor([1.0, 2.0])
    >>> a.data
    tensor([1., 2.])
    >>> a.grad
    >>> a.requires_grad
    False
    >>> a.requires_grad_()
    tensor([1., 2.], requires_grad=True)
    >>> c = a.pow(2).sum()
    >>> c.backward()
    >>> a.grad
    tensor([2., 4.])
    >>> b = a.detach()
    >>> b.grad
    >>> b.requires_grad
    False
    >>> b
    tensor([1., 2.])
    >>> b[0] = 6
    >>> b
    tensor([6., 2.])
    >>> a
    tensor([6., 2.], requires_grad=True)
    
    • requires_grad_()

    requires_grad_()函数会改变Tensorrequires_grad属性并返回Tensor,修改requires_grad的操作是原位操作(in place)。其默认参数为requires_grad=Truerequires_grad=True时,自动求导会记录对Tensor的操作,requires_grad_()的主要用途是告诉自动求导开始记录对Tensor的操作。

    • detach()

    detach()函数会返回一个新的Tensor对象b,并且新Tensor是与当前的计算图分离的,其requires_grad属性为False,反向传播时不会计算其梯度。ba共享数据的存储空间,二者指向同一块内存。

    :共享内存空间只是共享的数据部分,a.gradb.grad是不同的。

    3. torch.no_grad()

    torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。

    >>> a = torch.tensor([1.0, 2.0], requires_grad=True)
    >>> with torch.no_grad():
    ...     b = n.pow(2).sum()
    ...
    >>> b
    tensor(5.)
    >>> b.requires_grad
    False
    >>> c = a.pow(2).sum()
    >>> c.requires_grad
    True
    

    上面的例子中,当arequires_grad=True时,不使用torch.no_grad()c.requires_gradTrue,使用torch.no_grad()时,b.requires_gradFalse,当不需要进行反向传播时(推断)或不需要计算梯度(网络输入)时,requires_grad=True会占用更多的计算资源及存储资源。

    4. 总结

    requires_grad_()会修改Tensorrequires_grad属性。

    detach()会返回一个与计算图分离的新Tensor,新Tensor不会在反向传播中计算梯度,会在特定场合使用。

    torch.no_grad()更节省计算资源和存储资源,其作用域范围内的操作不会构建计算图,常用在网络推断中。

    References

    1. https://pytorch.org/docs/stable/tensors.html
    2. https://pytorch.org/docs/stable/tensors.html#torch.Tensor.requires_grad_
    3. https://pytorch.org/docs/stable/autograd.html#torch.Tensor.detach
    4. https://pytorch.org/docs/master/generated/torch.no_grad.html

    相关文章

      网友评论

        本文标题:Pytorch中requires_grad_(), detach

        本文链接:https://www.haomeiwen.com/subject/yadjzhtx.html