美文网首页
Pytorch构建CNN模型用于图片分类

Pytorch构建CNN模型用于图片分类

作者: 吐舌小狗 | 来源:发表于2018-03-14 22:52 被阅读1614次

    1.源码

    1.1简单的例子:

    import torch.nn as nn
    import torch.nn.functional as F
    
    class CNN(nn.Module):
      def __init__(self):
          super(imsimpleCNN, self).__init__()
          self.conv1 = nn.Conv2d(1, 20, 5)
          self.conv2 = nn.Conv2d(20, 20, 5)
    
      def forward(self, x):
          x = self.conv1(x)
          x = F.relu(x)
          x = self.conv2(x)
          x = F.relu(x)
          return x
    

    使用pytorch 构建神经网络模型比较简单,大体上需要定义两个函数,一个init,一个forward,分别用来声明一些模型用到的参数和模型的结构

    该网络结构由两层构成,(1, 20, 5)分别代表的是:input_channel, output_channel, 和kernel_size, 后面还有其他的参数,比如:stride = 1, padding = 0, dilation = 1, groups = 1, bias=True
    如果没有特别的设定,这些是默认的参数。

    1.2一个升级版的例子

    class imCNN(nn.Module):
        def __init__(self):
            super(imCNN, self).__init__()
            self.layer1 = nn.Sequential(
                nn.Conv2d(
                    in_channels=1,
                    out_channels=16,
                    kernel_size=5,
                    padding=2),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )
            self.layer2 = nn.Sequential(
                nn.Conv2d(16, 32, kernel_size=5, padding=2),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )
            self.fc = nn.Linear(7 * 7 * 32, 10)
    
        def forward(self, x):
            out = self.layer1(x)
            out = self.layer2(out)
            out = out.view(out.size(0), -1)
            out = self.fc(out)
            return out
    

    输入的图片是28*28的mnist,网络有两层的结构构成,每一层的结构都基本的相似,组成的模块有nn.Conv2d(), nn.BatchNorm2d, nn.Relu(), nn.MaxPool2d(), 经过两次的pooling, 最终特征的大小变为了 [ 7, 7, 32, 10], 再经过一个全连接层,将最终的输出变为10,从而实现分类。

    2.源码分析

    2.1 如何增加信的卷积层?

    2.2 model.train(), model.eval()的作用是什么?

    eval( )时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值

    2.3为什么加上dropout,在哪里加?

    3. 实验与分析

    相关文章

      网友评论

          本文标题:Pytorch构建CNN模型用于图片分类

          本文链接:https://www.haomeiwen.com/subject/zzzfqftx.html