美文网首页
(三)Pytorch神经网络

(三)Pytorch神经网络

作者: 计算机视觉__掉队选手 | 来源:发表于2019-04-06 17:28 被阅读0次

一个完整的神经网络训练总体流程:
1.定义神经网络
2.输入数据进行迭代
3.损失函数计算损失
4.梯度反向传播
5.更新网络权重参数

定义神经网络

import torch
import torch.nn.functional as F
#第一种定义方式
class Net(torch.nn.Module):
    def __init__(self,n_feature,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden = torch.nn.Linear(n_feature,n_hidden)
        self.out = torch.nn.Linear(n_hidden,n_output)
   def forward(self,x):
        x = F.relu(self.hidden(x))
        x = self.out(x)
        return x
net1 = Net(n_feature=2,n_hidden=10,n_output=2)
print(net)
#第二种定义方式
net2 = torch.nn.Sequential(
    torch.nn.Linear(2,10),
    torch.nn.ReLU(),
    torch.nn.Linear(10,2)
    )
print(net2)

输入数据进行迭代

n_data = torch.ones(100,2)
x0 = torch.normal(2*n_data,1)
y0 = torch.zeros(100)
x1 = torch.normal(-2*n_data,1)
y1 = torch.ones(100)
x = torch.cat((x0,x1),0).type(torch.FloatTensor)
y = torch.cat((y0,y1),).type(torch.LongTensor)
x,y = Variable(x),Variable(y)
out=net(x)

损失函数计算损失

损失函数包括L1损失函数、MSE损失函数、交叉熵损失函数等

loss = torch.nn.CrossEntropyLoss()
loss(out,y)

梯度下降反向传播

optim = torch.nn.optim.SGD(net.parameters(),lr=0.02)
optim.zero_grad()
loss.backward()
#更新参数
optim.step()

参考链接:https://github.com/MorvanZhou/PyTorch-Tutorial/tree/master/tutorial-contents

相关文章

  • Pytorch教程

    Pytorch 神经网络基础 1.1 Pytorch & Numpy 1.1.1 用Torch还是Numpy To...

  • 【Note】MV-机器学习系列 之 神经网络 PyTorch

    一、PyTorch 简介 1、Why PyTorch? PyTorch 的优势是建立的神经网络是动态的,比如 RN...

  • Pytorch 任务六

    PyTorch理解更多神经网络优化方法

  • pytorch神经网络拟合y = x^2

    pytorch神经网络拟合 在windows的环境下,使用pytorch拟合。实验中仅使用了三个神经元去拟合该函数...

  • 深度学习之PyTorch

    PyTorch常用导入的包 Dataset(数据集) nn.Module(模组) 在Pytorch里编写神经网络,...

  • PyTorch 训练

     PyTorch 训练与加速神经网络训练. 更多可以查看官网 :* PyTorch 官网 批训练 Torch 中提...

  • Pytorch-nlp开源工具(一)

    摘要:本分主要分享Pytorch NLP开源工具, PyTorch-NLP或torchnlp简称为神经网络层,文本...

  • (三)Pytorch神经网络

    一个完整的神经网络训练总体流程:1.定义神经网络2.输入数据进行迭代3.损失函数计算损失4.梯度反向传播5.更新网...

  • PyTorch极简教程

    PyTorch:深度学习框架,神经网络界的Numpy。所以PyTorch的使用方式基本与Numpy一致。 Tens...

  • PyTorch Classification

     PyTorch 通过简单的途径来使用神经网络进行事物的分类. 更多可以查看官网 :* PyTorch 官网 建立...

网友评论

      本文标题:(三)Pytorch神经网络

      本文链接:https://www.haomeiwen.com/subject/lcwriqtx.html