
pytorch api: torch.nn.utils.clip_grad

Author: 魏鹏飞 | Published 2020-04-26 11:52

    1. torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2)

    Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.

    Parameters
    • parameters (Iterable[Tensor] or Tensor) – an iterable of Tensors or a single Tensor that will have gradients normalized

    • max_norm (float or int) – max norm of the gradients

    • norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.

    Returns

    Total norm of the parameters (viewed as a single vector).
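
    For illustration, a small example of the call and its return value (the gradient below is a made-up stand-in, not from the original post):

    import torch

    p = torch.zeros(3, requires_grad=True)
    p.grad = torch.tensor([3.0, 4.0, 0.0])  # gradient with L2 norm 5.0

    # Rescale the gradients so their total L2 norm is at most 1.0.
    total = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
    print(total)   # tensor(5.) -- the norm measured before clipping
    print(p.grad)  # approximately [0.6, 0.8, 0.0], i.e. norm ~1.0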

    2. torch.nn.utils.clip_grad_value_(parameters, clip_value)

    Clips gradient of an iterable of parameters at specified value.

    Gradients are modified in-place.

    Parameters
• parameters (Iterable[Tensor] or Tensor) – an iterable of Tensors or a single Tensor that will have gradients clipped

    • clip_value (float or int) – maximum allowed value of the gradients. The gradients are clipped in the range [-clip_value, clip_value]
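
    Likewise, a small example of value clipping (again with a made-up gradient):

    import torch

    p = torch.zeros(3, requires_grad=True)
    p.grad = torch.tensor([-2.0, 0.3, 5.0])

    # Clamp every gradient element into [-1.0, 1.0], in place.
    torch.nn.utils.clip_grad_value_([p], clip_value=1.0)
    print(p.grad)  # tensor([-1.0000,  0.3000,  1.0000])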

    SOURCE CODE

    import warnings
    import torch
    from torch._six import inf
    
    def clip_grad_norm_(parameters, max_norm, norm_type=2):
        r"""Clips gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Arguments:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the parameters (viewed as a single vector).
        """
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = list(filter(lambda p: p.grad is not None, parameters))
        max_norm = float(max_norm)
        norm_type = float(norm_type)
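        # With norm_type == inf the total norm is the largest absolute gradient
        # entry; otherwise it is the p-norm of the per-parameter gradient norms.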
        if norm_type == inf:
            total_norm = max(p.grad.detach().abs().max() for p in parameters)
        else:
            total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
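        # Rescale every gradient by max_norm / total_norm only when the total
        # norm exceeds max_norm; the 1e-6 term guards against division by zero.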
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for p in parameters:
                p.grad.detach().mul_(clip_coef)
        return total_norm
    
    def clip_grad_norm(parameters, max_norm, norm_type=2):
        r"""Clips gradient norm of an iterable of parameters.
    
     .. warning::
     This method is now deprecated in favor of
     :func:`torch.nn.utils.clip_grad_norm_`.
     """
        warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
                      "of torch.nn.utils.clip_grad_norm_.", stacklevel=2)
        return clip_grad_norm_(parameters, max_norm, norm_type)
    
    def clip_grad_value_(parameters, clip_value):
        r"""Clips gradient of an iterable of parameters at specified value.

        Gradients are modified in-place.

        Arguments:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            clip_value (float or int): maximum allowed value of the gradients.
                The gradients are clipped in the range
                :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
        """
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        clip_value = float(clip_value)
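        # Clamp each gradient tensor element-wise into [-clip_value, clip_value].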
        for p in filter(lambda p: p.grad is not None, parameters):
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    
    

    Usage
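
    A minimal sketch of where gradient clipping typically sits in a training step (the model, loss, optimizer, and data below are illustrative placeholders, not from the original post):

    import torch
    import torch.nn as nn

    # Placeholder model, loss, and optimizer for illustration.
    model = nn.Linear(10, 1)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    x = torch.randn(32, 10)
    y = torch.randn(32, 1)

    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()  # populates .grad on every parameter

    # Clip between backward() and step() so the optimizer sees clipped gradients.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Or clip element-wise instead:
    # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

    optimizer.step()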

    Reference:
    https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm
