美文网首页
pytorch下利用RNN实现mnist数据集的分类 简易代码

pytorch下利用RNN实现mnist数据集的分类 简易代码

作者: melo4 | 来源:发表于2018-03-08 21:37 被阅读0次

    使用的模型为 LSTM:把每张 28×28 的 MNIST 图片按行看作长度为 28 的序列,每个时间步输入一行(28 个像素),最后用最后一个时间步的输出做 10 类分类。

    超参数(parameters):

    EPOCH = 1  # number of passes over the training loader
    BATCH_SIZE = 64  # mini-batch size handed to the DataLoader
    TIME_STEP = 28 # rnn time step / image height
    INPUT_SIZE = 28 # rnn input size / image width
    LR = 0.01  # learning rate passed to the Adam optimizer
    DOWNLOAD_MNIST = False  # passed as `download=` to the MNIST training dataset

    RNN 模型搭建:

    class RNN(nn.Module):
        """Single-layer LSTM followed by a linear read-out producing 10 scores.

        The LSTM is batch-first, takes INPUT_SIZE features per time step and
        has 64 hidden units; only the output of the last time step is fed to
        the linear layer.
        """

        def __init__(self):
            super(RNN, self).__init__()
            # Layers are created in the same order as before (LSTM, then
            # Linear) so parameter initialisation draws are unchanged.
            lstm_options = dict(
                input_size=INPUT_SIZE,
                hidden_size=64,
                num_layers=1,
                batch_first=True,
            )
            self.rnn = nn.LSTM(**lstm_options)
            self.out = nn.Linear(64, 10)

        def forward(self, x):
            """Run x through the LSTM and classify from the final time step."""
            # Passing None lets nn.LSTM fall back to its default initial state.
            sequence_out, _final_state = self.rnn(x, None)
            last_step = sequence_out[:, -1, :]
            return self.out(last_step)
    rnn = RNN()

    # Cross-entropy on the model's output scores; Adam updates every model
    # parameter with step size LR.
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
    
    

    完整代码:

    import torch
    from torch import nn
    from torch.autograd import Variable
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms
    import torch.utils.data as Data
    import matplotlib.pyplot as plt
    
    # Hyper Parameters
    EPOCH = 1  # number of passes over the training loader
    BATCH_SIZE = 64  # mini-batch size handed to the DataLoader
    TIME_STEP = 28  # rnn time step / image height
    INPUT_SIZE = 28 # rnn input size / image width
    LR = 0.01  # learning rate passed to the Adam optimizer
    DOWNLOAD_MNIST = False  # passed as `download=` to the MNIST training dataset
    
    # Training split of MNIST, stored under ./mnist.
    train_data = dsets.MNIST(
        root='./mnist',
        train=True,
        transform=transforms.ToTensor(),
        download=DOWNLOAD_MNIST,
    )

    # Shuffled mini-batches of BATCH_SIZE samples, loaded by 2 worker processes.
    # NOTE(review): on platforms that spawn worker processes, num_workers > 0
    # normally requires the script body to sit under
    # `if __name__ == '__main__':` -- confirm for the target platform.
    train_loader = Data.DataLoader(
        dataset=train_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
    )

    # Test split; it is read directly as tensors below rather than through a
    # DataLoader.
    test_data = dsets.MNIST(
        root='./mnist',
        train=False,
        transform=transforms.ToTensor(),
    )

    # Keep only the first 2000 test samples for quick evaluation. The raw pixel
    # tensor is taken straight from the dataset and divided by 255 by hand.
    # `.data` / `.targets` replace the deprecated `test_data` / `test_labels`
    # attributes, and the `Variable(..., volatile=True)` wrapper is dropped
    # because `volatile` was removed from PyTorch and no longer has any effect.
    # Slicing before the float conversion avoids converting unused samples.
    test_x = test_data.data[:2000].type(torch.FloatTensor) / 255.
    test_y = test_data.targets[:2000].numpy().squeeze()
    
    class RNN(nn.Module):
        """Single-layer LSTM followed by a linear read-out producing 10 scores.

        The LSTM is batch-first, takes INPUT_SIZE features per time step and
        has 64 hidden units; only the output of the last time step is fed to
        the linear layer.
        """

        def __init__(self):
            super(RNN, self).__init__()
            # Layers are created in the same order as before (LSTM, then
            # Linear) so parameter initialisation draws are unchanged.
            lstm_options = dict(
                input_size=INPUT_SIZE,
                hidden_size=64,
                num_layers=1,
                batch_first=True,
            )
            self.rnn = nn.LSTM(**lstm_options)
            self.out = nn.Linear(64, 10)

        def forward(self, x):
            """Run x through the LSTM and classify from the final time step."""
            # Passing None lets nn.LSTM fall back to its default initial state.
            sequence_out, _final_state = self.rnn(x, None)
            last_step = sequence_out[:, -1, :]
            return self.out(last_step)
    rnn = RNN()

    # Cross-entropy on the model's output scores; Adam updates every model
    # parameter with step size LR.
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
    
    # Train for EPOCH passes over train_loader; every 50 steps report the
    # current training loss and the accuracy on the held-out test_x / test_y.
    for epoch in range(EPOCH):
        for step, (x, y) in enumerate(train_loader):
            # reshape x to (batch, time_step, input_size)
            b_x = x.view(-1, TIME_STEP, INPUT_SIZE)
            b_y = y

            output = rnn(b_x)
            loss = loss_func(output, b_y)
            optimizer.zero_grad()   # clear gradients left over from the previous step
            loss.backward()
            optimizer.step()

            if step % 50 == 0:
                # Evaluation needs no gradients; no_grad() skips graph building.
                with torch.no_grad():
                    test_output = rnn(test_x)
                # Index of the largest score along dim 1 is the predicted class.
                pred_y = torch.max(test_output, 1)[1].numpy().squeeze()
                accuracy = float((pred_y == test_y).sum()) / test_y.size
                # '%.4f' (four decimals) replaces the mistyped '%4.f', which
                # printed the loss with no decimals; loss.item() replaces
                # loss.data[0], which fails on a 0-dim loss tensor.
                print(
                    'Epoch: ', epoch,
                    '| train loss: %.4f' % loss.item(),
                    '| test accuracy: ', accuracy,
                )
    # print 10 predictions from test data
    with torch.no_grad():
        test_output = rnn(test_x[:10].view(-1, TIME_STEP, INPUT_SIZE))
    pred_y = torch.max(test_output, 1)[1].numpy().squeeze()
    print(pred_y, 'prediction number')
    print(test_y[:10], 'real number')
    
    

    相关文章

      网友评论

          本文标题:pytorch下利用RNN实现mnist数据集的分类 简易代码

          本文链接:https://www.haomeiwen.com/subject/xbhufftx.html