Pytorch Custom Function用法

Author: 昵称己存在 | Published 2020-04-10 22:35

torch.autograd.Function

https://pytorch.org/docs/master/notes/extending.html

import torch
from torch.autograd import Function

class Exp(Function):

    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        # Save the output for use in backward; tensors must be stored
        # via save_for_backward rather than as plain ctx attributes.
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        # d/dx exp(x) = exp(x), so the saved forward output is reused here.
        return grad_output * result
  • Since PyTorch 1.3, both forward() and backward() must be decorated with @staticmethod.
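A minimal usage sketch for the Exp class above. Note that a custom Function is invoked through its `.apply` classmethod, never by instantiating the class or calling `forward()` directly:

```python
import torch
from torch.autograd import Function

class Exp(Function):
    # Same Exp as defined above, repeated so this example runs standalone.
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

x = torch.tensor([0.0, 1.0], requires_grad=True)
y = Exp.apply(x)          # invoke via .apply, not Exp()(x)
y.sum().backward()
# The analytic gradient of exp is exp itself.
print(torch.allclose(x.grad, x.exp()))  # True
```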

backward(ctx, *grad_outputs)

  • The first parameter is always ctx.
  • The number of grad_output arguments after ctx equals the number of values returned by forward(): one gradient flows back per forward output.
  • Each argument is the gradient w.r.t. (with respect to) the corresponding output of forward(), and each returned value should be the gradient w.r.t. the corresponding input of forward() — so backward() must return as many values as forward() took arguments.
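A small sketch of this arity rule, using a hypothetical Function (here called `SplitScale`, not from the original article) with one input and two outputs: backward() then receives two grad_output arguments and returns exactly one gradient:

```python
import torch
from torch.autograd import Function

class SplitScale(Function):
    # Hypothetical example: forward takes 1 Tensor and returns 2 Tensors,
    # so backward receives 2 grad_outputs and must return 1 gradient.
    @staticmethod
    def forward(ctx, x):
        return 2 * x, 3 * x

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        # One grad_output per forward output; one return per forward input.
        return 2 * grad_a + 3 * grad_b

x = torch.tensor(1.0, requires_grad=True)
a, b = SplitScale.apply(x)
(a + b).backward()
# d(2x + 3x)/dx = 5
print(x.grad)  # tensor(5.)
```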

Example: see the ReLU implementation in the PyTorch tutorial "Defining New autograd Functions".

When the function needs non-Tensor arguments passed to forward():

class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        # ctx is a context object that can be used to stash information
        # for backward computation
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None
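A quick check of MulConstant in use: the constant is a non-Tensor argument, so backward() returns None in its position, and the gradient flowing to the tensor input is simply the constant:

```python
import torch
from torch.autograd import Function

class MulConstant(Function):
    # Same MulConstant as above, repeated so this example runs standalone.
    @staticmethod
    def forward(ctx, tensor, constant):
        # Non-Tensor state can be stashed directly on ctx.
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # Two forward arguments -> two returns; None for the constant.
        return grad_output * ctx.constant, None

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = MulConstant.apply(x, 5.0)
y.sum().backward()
print(x.grad)  # tensor([5., 5.])
```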


Original link: https://www.haomeiwen.com/subject/qtbqmhtx.html