ResNet Experiments on the MNIST Dataset

Author: 火星小熊熊 | Published 2018-02-27 23:38
    Model:
        1. 30 residual blocks (batch normalization in every block)
        2. Adam optimizer, with the learning rate decayed exponentially
        3. 256 channels in every layer
        4. No pooling

    Results:
        Training is fairly fast: about 2 minutes on 8 GPUs reaches 98.3% test accuracy

    Conclusion:
        GPU memory usage of the network grows linearly with the batch size
        (a quick way to sanity-check this is sketched below)
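
    A minimal standalone sketch of that check (my addition; it assumes a CUDA
    device and PyTorch 1.1+ for the peak-memory API): peak allocated memory
    should grow roughly in proportion to the batch size, since activation
    storage dominates the weights for a network like this.

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(1, 256, kernel_size=3, padding=1).cuda()
    for bs in (64, 128, 256, 512):
        torch.cuda.reset_peak_memory_stats()
        x = torch.randn(bs, 1, 28, 28, device='cuda')
        y = conv(x)  # activation size scales with bs, so peak memory does too
        print(bs, torch.cuda.max_memory_allocated() // 2**20, 'MiB')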
    
    from __future__ import print_function
    import argparse
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.utils.data
    from torchvision import datasets, transforms
    from torch.autograd import Variable  # pre-0.4 PyTorch API, kept from the original
    
    
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=512, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    
    args = parser.parse_args()
    
    print(args)
    
    # DataLoader options: pinned memory speeds up host-to-GPU copies.
    kwargs = {'num_workers': 1, 'pin_memory': True}

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            './data', train=True, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean / std
            ])
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs  # the original built kwargs but never passed it to the loader
    )
    
    
    test_data = datasets.MNIST(
        './data', train=False, download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    )

    # Fixed evaluation subset: the first 2000 test images. Normalize the same
    # way as the training data (the original only divided by 255, which did not
    # match the training-time Normalize transform).
    test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000] / 255
    test_x = ((test_x - 0.1307) / 0.3081).cuda()

    test_y = test_data.test_labels[:2000].cuda()
       
    class Residual(nn.Module):
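        # Bottleneck residual block: a 1x1 conv halves the channel count, a 3x3
        # conv processes the features, and a 1x1 conv restores the channel count
        # before the skip connection is added.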
        def __init__(self, use_bn=True, input_channels=256, out_channels=256):
            super(Residual, self).__init__()
            self.use_bn = use_bn
            self.out_channels   = out_channels
            self.input_channels = input_channels
            self.mid_channels   = input_channels // 2
    
            # 1x1 conv down to half the channels, for a cheaper 3x3 conv.
            self.down_channel = nn.Conv2d(input_channels, self.mid_channels, kernel_size=1)
            self.AcFunc = nn.ReLU()
            if use_bn:
                self.bn_0 = nn.BatchNorm2d(num_features=self.mid_channels)
                self.bn_1 = nn.BatchNorm2d(num_features=self.mid_channels)
                self.bn_2 = nn.BatchNorm2d(num_features=self.out_channels)

            self.conv = nn.Conv2d(self.mid_channels, self.mid_channels, kernel_size=3, padding=1)

            # 1x1 conv back up to the output channel count.
            self.up_channel = nn.Conv2d(self.mid_channels, out_channels, kernel_size=1)

            # 1x1 projection on the skip path when channel counts differ.
            if input_channels != out_channels:
                self.trans = nn.Conv2d(input_channels, out_channels, kernel_size=1)
        
        def forward(self, inputs):
            x = self.down_channel(inputs)
            if self.use_bn:
                x = self.bn_0(x)
            x = self.AcFunc(x)
    
            x = self.conv(x)
            if self.use_bn:
                x = self.bn_1(x)
            x = self.AcFunc(x)
    
            x = self.up_channel(x)
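
            # Unlike the canonical ResNet block, batch norm (bn_2) is applied
            # after the residual addition rather than inside the branch.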
    
            if self.input_channels != self.out_channels:
                x += self.trans(inputs)
            else:
                x += inputs
    
            if self.use_bn:
                x = self.bn_2(x)
            
            return self.AcFunc(x)
    
    class Net(nn.Module):
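        # Full network: a 1x1 conv stem lifts the input to 256 channels, then
        # `pre_layer` 3x3 conv layers, `residual_stack` bottleneck blocks, and a
        # single linear classifier. With no pooling, feature maps stay 28x28.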
        def __init__(self, residual_stack=30, use_bn=True, pre_layer=5):
            super(Net, self).__init__()
            self.preprocess = nn.Sequential(
                nn.Conv2d(1, 256, kernel_size=1),
                nn.BatchNorm2d(num_features=256),
                nn.ReLU()
            )
    
            for i in range(pre_layer):
                self.preprocess.add_module(
                    name='pre_layer' + str(i), module=nn.Conv2d(256, 256, kernel_size=3, padding=1)
                )

                if use_bn:
                    self.preprocess.add_module(
                        name='bn_pre_layer' + str(i), module=nn.BatchNorm2d(num_features=256)
                    )

                self.preprocess.add_module(
                    name='relu' + str(i), module=nn.ReLU()
                )

            self.residual_blocks = nn.Sequential()

            for i in range(residual_stack):
                self.residual_blocks.add_module(
                    name='residual' + str(i), module=Residual(use_bn=use_bn)
                )
            
            # No pooling anywhere, so the classifier sees the full 28 x 28 x 256
            # feature map (200704 values) and holds about 2M weights.
            self.out_layer = nn.Linear(in_features=28 * 28 * 256, out_features=10)
        
        def forward(self, inputs):
            x = self.preprocess(inputs)
            x = self.residual_blocks(x)
            # Flatten [N, 256, 28, 28] -> [N, 200704] before the classifier.
            x = self.out_layer(x.view(x.size(0), -1))
            return x
    
    net = Net(residual_stack=30)
    # DataParallel splits each batch across all visible GPUs (8 in the run
    # reported above) and gathers the outputs on the first GPU.
    net = torch.nn.DataParallel(net)
    net.cuda()

    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    loss_function = nn.CrossEntropyLoss()
    
    def adjust_learning_rate(optimizer, epoch, lr):
        # Exponential decay: multiply by 0.9 every 2 epochs, i.e. the rate
        # roughly halves every 13 epochs (0.9 ** 6.6 ~= 0.5).
        LR = lr * (0.9 ** (epoch / 2))
        for param_group in optimizer.param_groups:
            param_group['lr'] = LR
    
    def train(epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = Variable(data.cuda()), Variable(target.cuda())
            optimizer.zero_grad()
            output = net(data)
            loss = loss_function(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx > 0 and batch_idx % 10 == 0:
                # Evaluate on the fixed 2000-image subset; switch to eval mode so
                # batch norm uses its running statistics, then switch back.
                net.eval()
                predict_output = net(test_x)
                net.train()
                pred_y = torch.max(predict_output, 1)[1].data.squeeze()
                accuracy = float((pred_y == test_y).sum()) / test_y.size(0)
                print('epoch:{}, batch_idx:{}, percentage:{:.4f}, accuracy:{}'.format(
                    epoch, batch_idx, float(batch_idx) / len(train_loader), accuracy))
    
        adjust_learning_rate(optimizer, epoch, args.lr)
    
    if __name__ == '__main__':
        for epoch in range(100):
            train(epoch)
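
    The loop above reports accuracy only on the first 2000 test images. For a
    final number, a full 10000-image pass in eval mode is more trustworthy; a
    minimal sketch (my addition; it uses the torch.no_grad()/.item() API from
    PyTorch 0.4+ and Python 3, whereas the listing sticks to the older
    Variable/volatile style):

    def evaluate(net, test_data, batch_size=512):
        loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
        net.eval()  # batch norm uses running statistics
        correct = 0
        with torch.no_grad():
            for data, target in loader:
                data, target = data.cuda(), target.cuda()
                pred = net(data).max(1)[1]
                correct += (pred == target).sum().item()
        net.train()
        return correct / len(test_data)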
    
    
