美文网首页
pytorch关于y=ax+b线性回归demo

pytorch关于y=ax+b线性回归demo

作者: 郭百度 | 来源:发表于2022-12-30 15:11 被阅读0次
    image.png
    import random, torch
    import numpy as np
    import matplotlib.pyplot as plt
    
    class LinearRegression2(torch.nn.Module):
        def __init__(self):
            # y = a*x + b
            super(LinearRegression2, self).__init__()
            self.a = torch.rand(1, requires_grad=True)  # 参数a
            self.b = torch.rand(1, requires_grad=True)  # 参数b
            self.__parameters = dict(a=self.a, b=self.b)  # 参数字典
    
        def forward(self, inputs):
            # 这里根据公式y[...]=ax[...]+b forward拟合
            return self.a * inputs + self.b
    
        def parameters(self):
            # 全部输出参数字典
            for name, param in self.__parameters.items():
                yield param
    
    def pltDraw(x, a, b):
        # plt.clf()  # 清除图像
        plt.plot(x.data.numpy(), a * x.data.numpy() + b, "r",alpha=0.1)  # f(x)的图像
        plt.title('y=%s*x+%s' % (a, b))
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.pause(0.5)
    
    # 生成 x 与 y模拟数据,其中加入1-25随机噪声
    x = np.arange(50)
    y = np.array([2 * x[i] + 30 + random.randint(1, 25) for i in range(len(x))])
    x = torch.from_numpy(x.astype(np.float32))
    y = torch.from_numpy(y.astype(np.float32))
    plt.scatter(x, y) # 数据散点图
    
    plt.ion() #打开plt交互模式
    
    #线性回归正式开始
    net = LinearRegression2() #初始化网络
    learningRate = 0.001 #设定学习率
    optimizer = torch.optim.Adam(net.parameters(), lr=learningRate, weight_decay=0.005)
    loss_op = torch.nn.MSELoss(reduction='sum')
    
    for i in range(1, 200001):  # 20万次拟合
        out = net.forward(x) # 向前传播
        loss = loss_op(y, out) # 计算out矩阵与y矩阵 损失
        optimizer.zero_grad()  # 清空梯度
        loss.backward()  # 向后传播,计算梯度∇f
        optimizer.step()  # 更新参数
        loss_numpy = loss.cpu().detach().numpy()  # 得到损失的numpy值
    
        if i % 5000 == 0:  # 每5000次打印一下损失
            print(i, loss_numpy)
            a = net.a.cpu().detach().numpy()
            b = net.b.cpu().detach().numpy()
            print(a, b)
            pltDraw(x, a, b)
    
    plt.ioff()
    print(a,b)
    plt.show()
    
    

    下面是每500次画一条alpha=0.05 线的拟合过程


    image.png

    相关文章

      网友评论

          本文标题:pytorch关于y=ax+b线性回归demo

          本文链接:https://www.haomeiwen.com/subject/xbytcdtx.html