本博文主要介绍了VGGNET的网络结构,并在cifar10数据集上实现了
VGGNET详解
VGG Net由牛津大学的视觉几何组(Visual Geometry Group)和 Google DeepMind公司的研究员一起研发的的深度卷积神经网络,在 ILSVRC 2014 上取得了第二名的成绩,将 Top-5错误率降到7.3%。它主要的贡献是展示出网络的深度(depth)是算法优良性能的关键部分。
VGGNET的网络结构如下图所示,VGGNET包含多层网络,深度从11层到19层不等,较为常用的是VGG16和VGG19,接下来我们以VGG16为例,即下图中的D,介绍VGGNET。

- 输入尺寸为
的图片,用64个
的卷积核作两次卷积和ReLU,卷积后的尺寸变为
。
- 池化层,使用
,池化单元大小为
,池化后尺寸变为
。
- 输入尺寸为
,使用128个
的卷积核作两次卷积和ReLU,尺寸改变为
。
- 池化层,使用
,池化单元大小为
,池化后尺寸变为
。
- 输入尺寸为
,使用256个
的卷积核作三次卷积和ReLU,尺寸改变为
。
- 池化层,使用
,池化单元大小为
,池化后尺寸变为
。
- 输入尺寸为
,使用512个
的卷积核作三次卷积和ReLU,尺寸改变为
。
- 池化层,使用
,池化单元大小为
,池化后尺寸变为
。
- 输入尺寸为
,使用512个
的卷积核作三次卷积和ReLU,尺寸改变为
。
- 池化层,使用
,池化单元大小为
,池化后尺寸变为
。
- 与两层1x1x4096,一层1x1x1000进行全连接+ReLU(共三层)。
- 通过softmax输出1000个预测结果。
VGGNET的特点
-
VGGNET全部使用
的卷积核和
的池化核,通过不断加深网络深度来提升性能。作者认为,两个
卷积层的串联相当于1个
的卷积层,3个
的卷积层串联相当于1个7*7的卷积层,即3个
卷积层的感受野大小相当于1个
的卷积层。但是3个
的卷积层参数量只有
的一半左右,同时前者可以有3个非线性操作,而后者只有1个非线性操作,这样使得前者对于特征的学习能力更强。
-
VGGNet的卷积层有一个显著的特点:特征图的空间分辨率单调递减,特征图的通道数单调递增。
代码实现
import torch.nn as nn
cfg = {
'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
class VGG(nn.Module):
def __init__(self, vgg_name):
super(VGG, self).__init__()
self.features = self._make_layers(cfg[vgg_name])
self.classifier = nn.Linear(512, 10)
def forward(self, x):
out = self.features(x)
out = out.view(out.size(0), -1)
out = self.classifier(out)
return out
def _make_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
nn.BatchNorm2d(x),
nn.ReLU(inplace=True)]
in_channels = x
layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
return nn.Sequential(*layers)
def VGG11():
return VGG('VGG11')
def VGG13():
return VGG('VGG13')
def VGG16():
return VGG('VGG16')
def VGG19():
return VGG('VGG19')
网友评论