炒个冷饭……以后不能再不会写了!!
import torch
import torch.nn as nn
import torch.tensor
import torch.nn.functional as F
import torch.optim as optim
import os
from torchvision import datasets, transforms
# hyper parameters
BATCH_SIZE = 5  # samples per mini-batch, used by both the train and test loaders
HIDDEN_SIZE = 512  # width of the single hidden Linear layer
WIDTH = 28  # MNIST image width in pixels
HEIGHT = 28  # MNIST image height in pixels
PIC_SIZE = WIDTH * HEIGHT  # length of one flattened image vector (784)
LEARNING_RATE = 1e-5  # Adam learning rate
# Use the GPU only when CUDA is actually available, so the script also runs
# on CPU-only machines instead of crashing on the first .cuda() call.
USE_CUDA = torch.cuda.is_available()
SAVE_PATH = './model/dnn'  # directory that receives per-epoch checkpoints
# data loader
# MNIST train/test splits, downloaded into ./data on first use; ToTensor()
# converts each image to a tensor.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=True, download=True,
        transform=transforms.Compose([transforms.ToTensor()]),  # convert images to tensors
    ),
    batch_size=BATCH_SIZE,
    # Reshuffle every epoch so mini-batches are not drawn in the same fixed
    # order on every pass over the training set.
    shuffle=True,
)
# Evaluation order does not matter, so the test loader keeps the default
# sequential order.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=False, download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=BATCH_SIZE,
)
class NN(torch.nn.Module):
    """Feed-forward network that flattens its input and applies layers in order.

    Args:
        models: An ``nn.ModuleList`` or any iterable of ``nn.Module`` layers.
            They are applied sequentially in ``forward``.
    """

    def __init__(self, models):
        super(NN, self).__init__()
        # Always hold the layers in a ModuleList. A plain Python list would
        # not be registered as submodules, and its weights would be missing
        # from parameters(), .cuda() and state_dict().
        self.models = torch.nn.ModuleList(models)

    def forward(self, input):
        """Reshape ``input`` to ``(-1, PIC_SIZE)`` and run it through every layer.

        Args:
            input: Image tensor whose element count is a multiple of
                ``PIC_SIZE``, such as the MNIST batches produced by the loaders.

        Returns:
            The output of the last layer in ``self.models``.
        """
        x = input.reshape([-1, PIC_SIZE])
        for layer in self.models:
            x = layer(x)
        return x
# Layers that NN.forward applies in order:
# flattened image (PIC_SIZE) -> Linear -> ReLU -> Linear -> 10 outputs.
models = nn.ModuleList(
    [
        nn.Linear(PIC_SIZE, HIDDEN_SIZE),
        nn.ReLU(),
        nn.Linear(HIDDEN_SIZE, 10),
    ]
)
model = NN(models)
print(model)

# Move the parameters to the default CUDA device when enabled.
# This uses a single GPU; it does not spread work across several.
model = model.cuda() if USE_CUDA else model

# Adam over every parameter registered on the model.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
def train(epoch):
    """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

    Args:
        epoch: Epoch number; used only in the progress log line.

    Prints the current mini-batch cross-entropy loss every 100 steps
    (step 0 is skipped).
    """
    # Training mode. It makes no difference for Linear/ReLU, but is needed if
    # dropout or batch-norm layers are ever added.
    model.train()
    for i, (source, label) in enumerate(train_loader):
        if USE_CUDA:
            source, label = source.cuda(), label.cuda()
        out = model(source)
        loss = F.cross_entropy(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0 and i != 0:
            # .item() logs a plain float instead of the tensor repr
            # (which includes device and grad_fn noise).
            print('Epoch {}, step {} | loss: {}'.format(epoch, i, loss.item()))
def test():
    """Return the classification accuracy of ``model`` on ``test_loader``.

    A prediction is the argmax over the model outputs. Each batch contributes
    its real number of samples, so a smaller final batch is counted correctly.

    Returns:
        Fraction of correctly classified test samples, as a float.
    """
    model.eval()
    correct = 0
    total = 0
    # Inference only: don't build autograd graphs for the forward passes.
    with torch.no_grad():
        for source, label in test_loader:
            if USE_CUDA:
                source, label = source.cuda(), label.cuda()
            pred = model(source).argmax(dim=1)
            correct += (pred == label).sum().item()
            total += label.size(0)
    return float(correct) / float(total)
if __name__ == '__main__':
    # Create the checkpoint directory once, up front (no-op if it exists).
    os.makedirs(SAVE_PATH, exist_ok=True)
    # Epochs are numbered 1..9 (range's end is exclusive).
    for i in range(1, 10):
        train(i)
        print('Epoch {} | Accuracy {}'.format(i, test()))
        # Pickles the whole module object (not just its state_dict) into
        # SAVE_PATH/checkpoint_<epoch>.pt.
        torch.save(model, os.path.join(SAVE_PATH, 'checkpoint_{}.pt'.format(i)))
网友评论