美文网首页
torch.optim优化器

torch.optim优化器

作者: 星光下的胖子 | 来源:发表于2020-05-21 19:57 被阅读0次

    使用流程:

    # torch.optim优化器的使用
    '''
    步骤:
    1、自定义神经网络模型
        1)初始化模型参数
        2)重载前向传播forward()
    2、定义优化器,并使用优化器实现参数优化
        1)创建优化器实例(有多种优化器类型,常用SGD和Adam)
        2)优化参数:
            1.清空梯度
            2.前向传播
            3.计算loss
            4.反向传播(根据loss来计算梯度)
            5.参数更新(根据梯度来更新参数)
    '''
    import torch
    from torch import optim
    import matplotlib.pyplot as plt
    
    # 1、自定义模型
    class Line(torch.nn.Module):
        def __init__(self):
            super(Line, self).__init__()
            # 初始化模型参数,并设置为Parameter类型
            self.w = torch.nn.Parameter(torch.rand(1))
            self.b = torch.nn.Parameter(torch.rand(1))
    
        # 重载前向传播:数据输入到模型中,返回结果
        def forward(self, x):
            return self.w * x + self.b
    
    
    if __name__ == '__main__':
        # 构建数据
        xs = torch.arange(0, 1, 0.01)
        ys = 3 * xs + 4 + 0.01 * torch.rand(100)
    
        # 模型实例化
        line = Line()
    
        # 2、定义优化器
        # SGD继承Optimizer
        opt = optim.SGD(line.parameters(), lr=0.1)
        # opt = optim.Adam(line.parameters(), lr=0.1)
    
        # 模型训练
        plt.ion()
        for epoch in range(30):
            for x, y in zip(xs, ys):
                # 梯度清空
                # 优化器optim的梯度是被累积的而不是被替换掉,每次先调用zero_grad()清空梯度
                opt.zero_grad()
                # 前向传播(父类Module的__call__方法调用forward()方法)
                h = line(x)
                # 计算loss
                loss = (h - y) ** 2
                # 反向传播:计算当前张量loss的梯度,并保存到buffer中
                # loss必须是torch.tensor(张量)数据类型
                loss.backward()
                # 参数更新:根据梯度来更新参数
                # step()函数从buffer中获取梯度
                opt.step()
    
            # plt绘图
            print(line.w.item(), line.b.item(), loss.item())
            # plt.cla()
            plt.plot(xs, ys, ".")
            plt.plot(xs, [line.w * x + line.b for x in xs])
            plt.pause(0.01)
        plt.ioff()
        plt.show()
    

    拟合过程可视化:


    拟合过程

    相关文章

      网友评论

          本文标题:torch.optim优化器

          本文链接:https://www.haomeiwen.com/subject/dfgtahtx.html