美文网首页
2.nn.module的使用

2.nn.module的使用

作者: 三角绿毛怪 | 来源:发表于2020-10-14 13:41 被阅读0次
    import torch
    from torch import nn
    from torch import optim#优化器类
    # 1. torch.optim.SGD(参数,学习率)
    # 2. torch.optim.Adam(参数,学习率)
    import numpy as np
    from matplotlib import pyplot as plt
    
    # 1. 定义数据
    x = torch.rand([50, 1])
    y = x * 3 + 0.8
    
    
    # 2 .定义模型
    class Lr(nn.Module):
        def __init__(self):
            #__init__需要调用super方法,继承父类的属性和方法
            super(Lr, self).__init__()
            #nn.Linear()为全链接层,参数分别是输入的数量和输出的数量
            self.linear = nn.Linear(1, 1)
    
        def forward(self, x):
            #前向传播的过程
            #实际上是Lr的实例,将x传入,用里面的forward方法传入参数,得到输出
            out = self.linear(x)
            return out
    
    
    # 2. 实例化模型,loss,和优化器
    model = Lr()
    #定义损失函数
    criterion = nn.MSELoss()
    #定义优化的方法为随机梯度下降,学习率为1e-3
    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    # 3. 训练模型
    for i in range(20000):
        #这里的四个步骤是优化类的使用方法
        out = model(x)  # 3.1 获取预测值
        loss = criterion(y, out)  # 3.2 计算损失
        optimizer.zero_grad()  # 3.3 梯度归零
        loss.backward()  # 3.4 计算梯度
        optimizer.step()  # 3.5 更新梯度
        if (i + 1) % 2000 == 0:
            print('Epoch[{}/{}], loss: {:.6f}'.format(i, 30000, loss.data))
    
    
    

    相关文章

      网友评论

          本文标题:2.nn.module的使用

          本文链接:https://www.haomeiwen.com/subject/yuxwpktx.html