美文网首页
使用pytorch建立全连接神经网络

使用pytorch建立全连接神经网络

作者: 万州客 | 来源:发表于2022-06-29 09:56 被阅读0次

全连接,卷积,循环是最基础的三种神经网络了,需要掌握。

一,代码

import torch
from torch.autograd import Variable
from torch import tensor, Tensor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import datasets, transforms

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt


#定义超参数
batch_size = 64
learning_rate = 1e-2
num_epoches = 20


# 简单的三层全连接神经网络
class SimpleNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


# 增加了激活函数的三层全连接神经网络
class ActivationNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.ReLU(True))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.ReLU(True))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


# 增加了批标准化的三层全连接神经网络
class BatchNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


data_tf = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5])]
)

train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = SimpleNet(28 * 28, 300, 100, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

model.eval()
eval_loss = 0
eval_acc = 0

for data in test_loader:
    img, label = data
    img = img.view(img.size(0), -1)
    # volatile已弃用,要改为torch.no_grad()
    # img = Variable(img, volatile=True)
    # label = Variable(label, volatile=True)
    with torch.no_grad():
        img = Variable(img)
        label = Variable(label)
    out = model(img)
    loss = criterion(out, label)
    eval_loss += loss.item() * label.size(0)
    _, pred = torch.max(out, 1)
    num_correct = (pred == label).sum()
    eval_acc += num_correct.item()
    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss/(len(test_dataset)), eval_acc/(len(test_dataset))))

相关文章

  • 使用pytorch建立全连接神经网络

    全连接,卷积,循环是最基础的三种神经网络了,需要掌握。 一,代码

  • 图像语义分割实践(三)网络搭建与实现

    众所周知,神经网络搭建常用基础模块有卷积,池化,归一,激活,全连接等等。如果使用Pytorch进行网络的搭建时,除...

  • PyTorch Classification

     PyTorch 通过简单的途径来使用神经网络进行事物的分类. 更多可以查看官网 :* PyTorch 官网 建立...

  • 【Note】MV-机器学习系列 之 神经网络 PyTorch

    一、PyTorch 简介 1、Why PyTorch? PyTorch 的优势是建立的神经网络是动态的,比如 RN...

  • CNN

    卷积神经网络(Convolution Neural Network) 基于全连接层和CNN的神经网络示意图 全连接...

  • 全连接卷积神经网络 FCN

    (一)全连接卷积神经网络(FCN) (1) 全连接卷积神经网络简介 FCN是深度神经网络用于语义分割的奠基性工作,...

  • Pytorch学习之全连接识别MNIST数字

    Pytorch之全连接识别MNIST数字 导入库 设置超参数 数据预处理方法 数据集下载及获取 模型建立 确定损失...

  • 机器学习:卷积神经网络

    和全连接神经网络的主要差别 全连接神经网络:  每个神经元的输入数据,都使用了上一层的所有神经元的输出数据,每个神...

  • pytorch模型转keras模型

    1. 概述 使用pytorch建立的模型,有时想把pytorch建立好的模型装换为keras,本人使用Tensor...

  • 卷积神经网络

    CNN 一、卷积神经网络结构 1.全连接神经网络 2.卷积神经网络 全连接层存在的问题:数据的形状被“忽视”了例如...

网友评论

      本文标题:使用pytorch建立全连接神经网络

      本文链接:https://www.haomeiwen.com/subject/ljstbrtx.html