(一)NiN (network in network) 网络中的网络
(1)NiN简介
LeNet、AlexNet和VGG都有一个共同的设计模式:通过一系列的卷积层与汇聚层来提取空间结构特征;然后通过全连接层对特征的表征进行处理。
AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。
或者,可以想象在这个过程的早期使用全连接层。然而,如果使用了全连接层,可能会完全放弃表征的空间结构。
网络中的网络(NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机。也就是使用了多个1*1的卷积核。同时他认为全连接层占据了大量的内存,所以整个网络结构中没有使用全连接层。
NiN网络虽然并不出名,但是他的思想和后来的很多网络的设计思想类似。我们可以了解了解。
(2)NiN块
和VGG一样都提出了快的概念,但是NiN块中第一层是卷积层,但是后面两层使用的是11的卷积核。这两个卷积层充当带有ReLU激活函数的逐像素全连接层。从另一个角度说,11的卷积核是将不同通道的信息进行了融合。
(3)NiN模型
前面有多个NiN块组成,每个NiN块后面都增加了一个最大池化层,stride=2,将图片进行压缩。最后通过一个全局平均池化层(AdaptiveAvgPool),将输出的维度控制为(-1,10).一个通道就想到于一个类别,这就等同于全连接层的输出层。
(二)代码实现
import torch
from torch import nn
from torchvision import transforms
import torchvision
from torch.utils import data
from d2l import torch as d2l
import numpy as np
import matplotlib.pyplot as plt
def nin_block(in_channels,out_channels,kernel_size,stride,padding):
return nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size= kernel_size, stride= stride,padding = padding),nn.ReLU(),
nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),
nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),
)
net = nn.Sequential(
nin_block(1,96,kernel_size= 11,stride=4,padding= 0),
nn.MaxPool2d(kernel_size= 3,stride=2),
nin_block(96,256,kernel_size= 5,stride=1,padding= 2),
nn.MaxPool2d(kernel_size= 3,stride=2),
nin_block(256,384,kernel_size= 3,stride=1,padding= 1),
nn.MaxPool2d(kernel_size= 3,stride=2),
nn.Dropout(0.5),
nin_block(384,10,kernel_size= 3,stride=1,padding= 1),
nn.AdaptiveAvgPool2d(output_size=(1,1)),
nn.Flatten()
)
x = torch.rand(size=(1,1,224,224))
for layer in net:
x = layer(x)
print(layer.__class__.__name__,"\t\toutput shape:",x.shape)
# 现在使用mnist数据集测试一下结果
def load_data_fashion_mnist(batch_size, resize=None):
"""下载或者加载Fashion-MNIST数据集"""
trans = [transforms.ToTensor()]
if resize:
# 需要把图片拉长,正常时不会这么做的
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans) # 这是一步可以去掉的操作,这个就是把多个图像处理的步骤整合到一起
mnist_train = torchvision.datasets.FashionMNIST(
root="../data/",
train=True,
transform=trans,
download=False # 要是没下载过就选择true
)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data/",
train=False,
transform=trans,
download=False # 要是没下载过就选择true
)
return (data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=0),
data.DataLoader(mnist_test,batch_size=batch_size,shuffle=True,num_workers=0))
batch_size = 32
learning_rate = 0.05
epochs = 3
train_iter,test_iter = load_data_fashion_mnist(batch_size,resize=(224))
d2l.train_ch6(net,train_iter,test_iter,epochs,lr=learning_rate,device=d2l.try_gpu())
网友评论