美文网首页
PyTorch 基础(4) 线性回归

PyTorch 基础(4) 线性回归

作者: sixfold_yuan | 来源:发表于2017-05-31 09:56 被阅读0次

给定一个数据点集合X和对应的目标值y,线性模型的目标就是找到一条使用向量w和位移b描述的线,来尽可能地近似每个样本X[i]y[i]。用数学符号来表示就是:


并最小化所有数据点上的平方误差

实际上线性模型是最简单但也可能是最有用的神经网络。一个神经网络就是一个由节点(神经元)和有向边组成的集合。我们一般把一些节点组成层,每一层使用下一层的节点作为输入,并输出给上面层使用。为了计算一个节点值,我们将输入节点值做加权和,然后再加上一个激活函数。对于线性回归而言,它是一个两层神经网络,其中第一层是(下图橙色点)输入,每个节点对应输入数据点的一个维度,第二层是单输出节点(下图绿色点),它使用身份函数(f(x)=x)作为激活函数。
线性回归

创建数据集

使用如下方法来生成数据,这里噪音服从均值0和标准差为0.01的正态分布
y[i] = 2 * X[i][0] - 3.4 * X[i][1] + 4.2 + noise

from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np 
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader


num_inputs = 2
num_examples = 1000

true_w = [2, -3.4]
true_b = 4.2

x = torch.randn(num_examples, num_inputs)
y = true_w[0] * x[:, 0] + true_w[1] * x[:, 1] + true_b

y = y + torch.randn(y.size()) * 0.01

注意到X的每一行是一个长度为2的向量,而y的每一行是一个长度为1的向量(标量)。

数据读取

使用torch.utils.data模块来读取数据

dataset = TensorDataset(x, y)
trainloader = DataLoader(dataset, batch_size=256, shuffle=True)

for data, label in trainloader:
    print(data, label)
    break

定义模型

pytorch有大量预定义的层,我们只需要关注使用哪些层来构建模型。例如线性模型就是使用对应的Linear层

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(2,1)
        print(self.fc.weight)

    def forward(self, x):
        x = self.fc(x)
        return x

net = Net() 

损失函数

平方误差函数

criterion = nn.MSELoss()

优化

使用SGD优化算法,学习率设为0.1

optimizer = optim.SGD(net.parameters(), lr=0.1)

训练

epochs = 100
for epoch in range(epochs):
    total_loss = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        data = Variable(inputs)
        label = Variable(labels).float()

        optimizer.zero_grad()

        out = net(data)

        loss = criterion(out, label)

        
        loss.backward()
        optimizer.step()

        total_loss = total_loss + loss.data[0]
    print("Epoch %d, average loss: %f" % (epoch, total_loss/num_examples))


params = list(net.parameters())
print(params[0])
print(params[1])

参考资料

pytorch官网
动手学深度学习

相关文章

  • PyTorch 基础(4) 线性回归

    给定一个数据点集合X和对应的目标值y,线性模型的目标就是找到一条使用向量w和位移b描述的线,来尽可能地近似每个样本...

  • 第一次打卡

    线性回归主要内容包括: 线性回归的基本要素线性回归模型从零开始的实现线性回归模型使用pytorch的简洁实现线性回...

  • 动手学深度学习(一) 线性回归

    线性回归 主要内容包括: 线性回归的基本要素 线性回归模型从零开始的实现 线性回归模型使用pytorch的简洁实现...

  • 线性回归

    线性回归 主要内容包括: 线性回归的基本要素 线性回归模型从零开始的实现 线性回归模型使用pytorch的简洁实现...

  • 第一天-线性回归,Softmax与分类模型,多层感知机

    线性回归 主要内容包括: 线性回归的基本要素 线性回归模型从零开始的实现 线性回归模型使用pytorch的简洁实现...

  • Pytorch实现共享单车数量预测

    之前分享过Pytorch实现简单线性回归算法的内容:Pytorch实现简单的线性回归算法,这次分享一下用pytor...

  • 【机器学习快速入门教程4】线性回归

    章节4:线性回归 本章节,我们将介绍线性回归问题,机器学习中最基础的问题。 线性回归 线性回归是指在一组数据中拟合...

  • 动手学深度学习-01打卡

    线性回归 主要内容包括:1.线性回归的基本要素2.线性回归模型从零开始的实现3.线性回归模型使用pytorch的简...

  • 从回归到临床模型(一)

    一.回归基础知识 二.线性回归 1.拟合线性模型 2.简单线性模型 3.多项式回归 4.多元线性回归 5.回归诊断...

  • 「动手学深度学习」线性回归

    1. 主要内容 线性回归的基本要素 线性回归模型从零开始的实现 线性回归模型使用PyTorch的简洁实现 2. 线...

网友评论

      本文标题:PyTorch 基础(4) 线性回归

      本文链接:https://www.haomeiwen.com/subject/krvdfxtx.html