MNIST is a handwritten-digit dataset of 70,000 grayscale 28x28 images spanning ten classes, which makes it a convenient dataset for training a neural network on a multi-class classification problem. An official download interface is provided, and the data is split into a training set of 60,000 images and a test set of 10,000 images.

The hidden layers of the network I built use LeakyReLU as the activation function. After the data is standardized to zero mean and unit variance, some of the values become negative; with ReLU, negative inputs produce a zero gradient, whereas LeakyReLU keeps a small nonzero derivative even for negative inputs, which helps gradient descent. The output layer uses softmax and the loss is cross entropy; in PyTorch these are packaged together as CrossEntropyLoss. This is also the first neural network I have written, and the results are decent.
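As a quick illustration of the reasoning above (a minimal sketch of mine, not part of the original post), the snippet below compares the gradient that ReLU and LeakyReLU pass back for a negative input: ReLU blocks it entirely, while LeakyReLU keeps a small slope (0.01 by default).

import torch
import torch.nn as nn

x = torch.tensor([-1.5], requires_grad=True)
y = torch.tensor([-1.5], requires_grad=True)

nn.ReLU()(x).sum().backward()
nn.LeakyReLU()(y).sum().backward()   # default negative_slope is 0.01

print(x.grad)   # tensor([0.])     -> gradient is killed for a negative input
print(y.grad)   # tensor([0.0100]) -> small but nonzero gradient survives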
import torch
from torch import nn, optim
from torchvision import datasets, transforms

# Load the dataset
batch_size = 200
learning_rate = 0.01
epochs = 10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('Mnist_data', train=True, download=True,
                   transform=transforms.Compose([
                       # Convert the images to tensors
                       transforms.ToTensor(),
                       # Normalize with the official mean and std so gradient descent needs fewer steps
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('Mnist_data', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to Mnist_data\MNIST\raw\train-images-idx3-ubyte.gz
100.1%
Extracting Mnist_data\MNIST\raw\train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to Mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz
113.5%
Extracting Mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to Mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz
100.4%
Extracting Mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to Mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz
180.4%
Extracting Mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Processing...
Done!
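The 0.1307 and 0.3081 used in Normalize are the commonly quoted mean and standard deviation of the MNIST training pixels. If you want to verify them yourself, a rough sketch (my addition, not part of the original post; it loads all 60,000 images into memory) would be:

raw = datasets.MNIST('Mnist_data', train=True, download=True,
                     transform=transforms.ToTensor())
# Stack all training images into one tensor and compute mean/std over every pixel
pixels = torch.stack([img for img, _ in raw])
print(pixels.mean().item(), pixels.std().item())   # roughly 0.1307 and 0.3081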
print(type(train_loader))
print(train_loader.dataset)

# Inspect the structure of one batch
for i in enumerate(train_loader):
    print(i)
    break
<class 'torch.utils.data.dataloader.DataLoader'>
Dataset MNIST
Number of datapoints: 60000
Root location: Mnist_data
Split: Train
(0, [tensor([[[[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
...,
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242]]],
[[[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
...,
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242]]],
[[[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
...,
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242]]],
...,
[[[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
...,
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242]]],
[[[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
...,
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242]]],
[[[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
...,
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242]]]]), tensor([2, 1, 4, 4, 3, 7, 4, 1, 0, 0, 1, 9, 7, 9, 7, 9, 7, 4, 8, 9, 7, 3, 6, 7,
9, 9, 7, 1, 9, 0, 5, 3, 4, 7, 1, 5, 5, 8, 8, 9, 2, 4, 7, 8, 1, 9, 3, 4,
8, 9, 3, 5, 1, 2, 9, 3, 7, 5, 7, 8, 1, 3, 4, 2, 0, 0, 4, 9, 7, 3, 8, 6,
1, 5, 2, 1, 1, 9, 7, 4, 7, 5, 3, 6, 5, 8, 6, 4, 8, 5, 2, 9, 6, 4, 6, 8,
4, 2, 5, 1, 6, 3, 7, 7, 8, 7, 7, 0, 5, 7, 0, 9, 2, 5, 0, 0, 0, 8, 5, 3,
8, 7, 8, 7, 2, 9, 1, 7, 9, 9, 4, 7, 6, 6, 7, 0, 6, 3, 6, 0, 6, 6, 9, 2,
2, 8, 8, 0, 9, 7, 1, 6, 1, 1, 8, 3, 2, 7, 4, 7, 9, 8, 3, 0, 3, 9, 1, 5,
3, 2, 8, 2, 5, 9, 7, 1, 5, 0, 8, 5, 6, 2, 3, 8, 6, 8, 7, 8, 2, 0, 3, 3,
3, 4, 7, 7, 2, 4, 3, 8])])
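The dump above is hard to read; a lighter-weight way to check what a batch looks like (a small sketch of mine, not from the original post) is to pull a single batch and print only its shapes:

images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([200, 1, 28, 28]) -> 200 single-channel 28x28 images
print(labels.shape)   # torch.Size([200])            -> one class index per image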
class First_nn(nn.Module):
    def __init__(self):
        super(First_nn, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 200),
            # The normalized data can be negative, so use LeakyReLU to avoid zero gradients
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x):
        return self.model(x)
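The network returns raw scores (logits) of shape [batch, 10]; as noted in the introduction, CrossEntropyLoss applies log-softmax internally, so no explicit softmax layer is needed. A quick CPU sanity check of both points (my addition, not part of the original post):

import torch.nn.functional as F

tmp_net = First_nn()
dummy = torch.randn(4, 28*28)          # a fake batch of 4 flattened images
logits = tmp_net(dummy)
print(logits.shape)                    # torch.Size([4, 10])

targets = torch.tensor([3, 1, 0, 7])
ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))      # True: CrossEntropyLoss = log_softmax + NLL loss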
# Move the model to the GPU
device = torch.device('cuda:0')
# net = First_nn().cuda()
net = First_nn().to(device)

# Stochastic gradient descent
# Error: optimizer = optim.SGD(net.parameters(), lr=learning_rate).to(device)  # an optimizer has no .to()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
# softmax activation + cross-entropy loss, packaged in one module
loss_func = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Flatten each [1, 28, 28] image into a 784-dimensional vector
        data = data.view(data.size(0), 1*28*28)
        data = data.to(device)
        target = target.cuda()
        logits = net(data)
        loss = loss_func(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (batch_idx+1) % 100 == 0:
            print('Train epoch {} [{}/{}({:.2f}%)]\t loss:{}'.format(
                epoch, (batch_idx+1)*len(data), len(train_loader.dataset),
                100.*(batch_idx+1)*len(data)/len(train_loader.dataset), loss.item()))
        if batch_idx == 0:
            print('Train epoch {} [{}/{}({}%)]\t loss:{}'.format(
                epoch, batch_idx*len(data), len(train_loader.dataset),
                batch_idx/len(train_loader.dataset), loss.item()))

    # Evaluate on the test set after each epoch
    test_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        data = data.view(data.size(0), 28*28)
        data = data.to(device)
        target = target.cuda()
        logits = net(data)
        test_loss += loss_func(logits, target).item()
        predict = logits.argmax(dim=1)
        correct += predict.eq(target).float().sum().item()

    test_loss /= len(test_loader.dataset)
    print('epoch:{} test_data Average_loss : {}\t Accuracy: {}/{}({:2f})%'.format(
        epoch, test_loss, correct, len(test_loader.dataset), 100.*correct/len(test_loader.dataset)))
Train epoch 0 [0/60000(0.0%)] loss:2.306537389755249
Train epoch 0 [20000/60000(33.33%)] loss:2.035033702850342
Train epoch 0 [40000/60000(66.67%)] loss:1.2957611083984375
Train epoch 0 [60000/60000(100.00%)] loss:0.7836695909500122
epoch:0 test_data Average_loss : 0.003627978050708771 Accuracy: 8302.0/10000(83.020000)%
Train epoch 1 [0/60000(0.0%)] loss:0.7361741065979004
Train epoch 1 [20000/60000(33.33%)] loss:0.49886584281921387
Train epoch 1 [40000/60000(66.67%)] loss:0.49449071288108826
Train epoch 1 [60000/60000(100.00%)] loss:0.3523230254650116
epoch:1 test_data Average_loss : 0.002026969975233078 Accuracy: 8896.0/10000(88.960000)%
Train epoch 2 [0/60000(0.0%)] loss:0.4014418125152588
Train epoch 2 [20000/60000(33.33%)] loss:0.3810940682888031
Train epoch 2 [40000/60000(66.67%)] loss:0.44071638584136963
Train epoch 2 [60000/60000(100.00%)] loss:0.3752475082874298
epoch:2 test_data Average_loss : 0.0016781002402305603 Accuracy: 9025.0/10000(90.250000)%
Train epoch 3 [0/60000(0.0%)] loss:0.35575544834136963
Train epoch 3 [20000/60000(33.33%)] loss:0.35479432344436646
Train epoch 3 [40000/60000(66.67%)] loss:0.37836602330207825
Train epoch 3 [60000/60000(100.00%)] loss:0.34169018268585205
epoch:3 test_data Average_loss : 0.0015032748386263847 Accuracy: 9110.0/10000(91.100000)%
Train epoch 4 [0/60000(0.0%)] loss:0.2739598751068115
Train epoch 4 [20000/60000(33.33%)] loss:0.38402724266052246
Train epoch 4 [40000/60000(66.67%)] loss:0.30347561836242676
Train epoch 4 [60000/60000(100.00%)] loss:0.2502043843269348
epoch:4 test_data Average_loss : 0.0014004800230264663 Accuracy: 9194.0/10000(91.940000)%
Train epoch 5 [0/60000(0.0%)] loss:0.24062280356884003
Train epoch 5 [20000/60000(33.33%)] loss:0.31439653038978577
Train epoch 5 [40000/60000(66.67%)] loss:0.2230578362941742
Train epoch 5 [60000/60000(100.00%)] loss:0.21400272846221924
epoch:5 test_data Average_loss : 0.001314392538368702 Accuracy: 9229.0/10000(92.290000)%
Train epoch 6 [0/60000(0.0%)] loss:0.25108906626701355
Train epoch 6 [20000/60000(33.33%)] loss:0.3200721740722656
Train epoch 6 [40000/60000(66.67%)] loss:0.2467733919620514
Train epoch 6 [60000/60000(100.00%)] loss:0.26089826226234436
epoch:6 test_data Average_loss : 0.0012480686396360396 Accuracy: 9290.0/10000(92.900000)%
Train epoch 7 [0/60000(0.0%)] loss:0.2835906147956848
Train epoch 7 [20000/60000(33.33%)] loss:0.2755715250968933
Train epoch 7 [40000/60000(66.67%)] loss:0.2002713531255722
Train epoch 7 [60000/60000(100.00%)] loss:0.20591193437576294
epoch:7 test_data Average_loss : 0.0011820814564824104 Accuracy: 9322.0/10000(93.220000)%
Train epoch 8 [0/60000(0.0%)] loss:0.31660526990890503
Train epoch 8 [20000/60000(33.33%)] loss:0.202341690659523
Train epoch 8 [40000/60000(66.67%)] loss:0.2645059823989868
Train epoch 8 [60000/60000(100.00%)] loss:0.21705631911754608
epoch:8 test_data Average_loss : 0.00112670366615057 Accuracy: 9345.0/10000(93.450000)%
Train epoch 9 [0/60000(0.0%)] loss:0.2496471405029297
Train epoch 9 [20000/60000(33.33%)] loss:0.18550094962120056
Train epoch 9 [40000/60000(66.67%)] loss:0.2601899802684784
Train epoch 9 [60000/60000(100.00%)] loss:0.2624129354953766
epoch:9 test_data Average_loss : 0.0010824317336082458 Accuracy: 9382.0/10000(93.820000)%
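One thing worth noting: the test loop above still builds the autograd graph and leaves the network in training mode. A common refinement (my suggestion, not part of the original post) is to wrap evaluation in torch.no_grad() and switch to eval mode, which saves memory and time without changing the results for this particular architecture:

net.eval()                       # evaluation mode (matters once dropout/batchnorm are added)
test_loss, correct = 0, 0
with torch.no_grad():            # no gradients are needed when only measuring accuracy
    for data, target in test_loader:
        data = data.view(data.size(0), 28*28).to(device)
        target = target.to(device)
        logits = net(data)
        test_loss += loss_func(logits, target).item()
        correct += logits.argmax(dim=1).eq(target).float().sum().item()
net.train()                      # switch back before any further training
print(test_loss / len(test_loader.dataset), correct / len(test_loader.dataset))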