美文网首页
pytorch基础五(定义自动求导函数)

pytorch基础五(定义自动求导函数)

作者: 永远学习中 | 来源:发表于2018-12-02 12:23 被阅读0次

    本人学习pytorch主要参考官方文档莫烦Python中的pytorch视频教程。
    后文主要是对pytorch官网的文档的总结。
    代码来自pytorch官网

    import torch
    # 通过继承torch.autograd.Function类,并实现forward 和 backward函数
    class MyReLU(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            """
            在forward函数中,接收包含输入的Tensor并返回包含输出的Tensor。
            ctx是环境变量,用于提供反向传播是需要的信息。可通过ctx.save_for_backward方法缓存数据。
            """
            ctx.save_for_backward(input)
            return input.clamp(min=0)
    
        @staticmethod
        def backward(ctx, grad_output):
            """
            在backward函数中,接收包含了损失梯度的Tensor,
            我们需要根据输入计算损失的梯度。
            """
            input, = ctx.saved_tensors
            grad_input = grad_output.clone()
            grad_input[input < 0] = 0
            return grad_input
    
    dtype = torch.float
    device = torch.device("cpu")
    N, D_in, H, D_out = 64, 1000, 100, 10
    x = torch.randn(N, D_in, device=device, dtype=dtype)
    y = torch.randn(N, D_out, device=device, dtype=dtype)
    w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
    w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)
    learning_rate = 1e-6
    for t in range(500):
        relu = MyReLU.apply
        y_pred = relu(x.mm(w1)).mm(w2)
        loss = (y_pred - y).pow(2).sum()
        print(t, loss.item())
        loss.backward()
        with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad
            w1.grad.zero_()
            w2.grad.zero_()
    

    相关文章

      网友评论

          本文标题:pytorch基础五(定义自动求导函数)

          本文链接:https://www.haomeiwen.com/subject/ctrjcqtx.html