美文网首页
PyTorch系列二:线性回归

PyTorch系列二:线性回归

作者: 八宝粥BBZ | 来源:发表于2019-02-01 15:59 被阅读2次

    1.介绍

    线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法,运用十分广泛。其表达形式为y = w'x+e,e为误差服从均值为0的正态分布。
    回归分析中,只包括一个自变量和一个因变量,且二者的关系可用一条直线近似表示,这种回归分析称为一元线性回归分析。如果回归分析中包括两个或两个以上的自变量,且因变量和自变量之间是线性关系,则称为多元线性回归分析。

    2.模型训练

    # -*- coding: utf-8 -*-
    
    import torch
    from torch import nn, optim
    from torch.autograd import Variable
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    num_epoches   = 10000
    learning_rate = 1e-3 
    
    class LinearRegression(nn.Module):
        """线性回归模型定义"""
    
        def __init__(self):
            super(LinearRegression, self).__init__()
            self.linear = nn.Linear(1, 1)
    
        def forward(self, x):
            # 前向传播
            output = self.linear(x)
            return output
    
    # 模型初始化
    model = LinearRegression()
    
    # 定义loss和优化函数
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    
    # 输入数据
    x_train = np.array([1, 2, 3, 4, 5], dtype=np.float32).reshape(-1, 1)
    y_train = np.array([2, 4, 6, 8, 10], dtype=np.float32).reshape(-1, 1)
    
    # 将np.array转换成Tensor
    x_train = torch.from_numpy(x_train)
    y_train = torch.from_numpy(y_train)
    
    # 模型训练
    for epoch in range(num_epoches):
        inputs = Variable(x_train)
        target = Variable(y_train)
        # forward
        out  = model(inputs)
        loss = criterion(out, target)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # print loss
        if (epoch+1) % 100 == 0:
            print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1, num_epoches, loss.item()))
    
    # 模型保存
    torch.save(model.state_dict(), './Linear_Regression.model')
    
    # 模型评估
    model.eval()
    predict = model(Variable(x_train))
    predict = predict.data.numpy()
    
    # 画图
    plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
    plt.plot(x_train.numpy(), predict, label='Fitting Line')
    plt.legend()
    plt.show()
    

    3.拟合图

    image.png

    相关文章

      网友评论

          本文标题:PyTorch系列二:线性回归

          本文链接:https://www.haomeiwen.com/subject/udjdsqtx.html