用CIFAR.10的数据集进行训练和测试
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
# 定义参数
EPOCH = 10 # 训练epoch次数
BATCH_SIZE = 64 # 批训练的数量
LR = 0.001 # 学习率
DOWNLOAD_MNIST = False # 设置True 可以自动下载数据
# MNIST数据集下载
train_data = datasets.CIFAR10(root='/nas/cifar10/',
train=True, # 这里是训练集
transform=transforms.ToTensor(),
download=True
)
test_data = datasets.CIFAR10(root='/nas/cifar10/',
train=False, # 测试集
transform=transforms.ToTensor(),
download=True
)
用DataLoader对数据分批,分成训练集和测试集
# Cifar-10的标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 使用DataLoader进行分批
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)
采用Torch中的ResNet进行数据训练
# ResNet Model
model = torchvision.models.resnet50(pretrained=False)
#model = torchvision.models.densenet161(pretrained=False)
#损失函数:这里用交叉熵
criterion = nn.CrossEntropyLoss()
#优化器 这里用SGD
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#device : GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# 训练
for epoch in range(EPOCH):
start_time = time.time()
for i, data in enumerate(train_loader):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
# 前向传播
outputs = model(inputs)
# 计算损失函数
loss = criterion(outputs, labels)
# 清空上一轮梯度
optimizer.zero_grad()
# 反向传播
loss.backward()
# 参数更新
optimizer.step()
print('epoch{} loss:{:.4f} time:{:.4f}'.format(epoch+1, loss.item(), time.time()-start_time))
用CPU跑出来的结果,一个epoch一个小时,惨不忍睹。。。。。
epoch1 loss:3.6198 time:3435.6966
epoch2 loss:1.5368 time:3536.9948
epoch3 loss:1.9480 time:3657.2960
epoch4 loss:1.7615 time:4115.1650
epoch5 loss:1.4777 time:4802.4395
epoch6 loss:1.1862 time:2313.2824
epoch7 loss:1.6558 time:2310.4057
epoch8 loss:0.8733 time:2334.7355
epoch9 loss:1.0813 time:2411.8487
epoch10 loss:0.4476 time:2400.0543
用测试数据进行测试
# 测试
model = torch.load('cifar10_resnet.pt')
model.eval()
print(model)
correct = 0
total = 0
for data in test_loader:
images, labels = data
images, labels = images.to(device), labels.to(device)
# 前向传播
out = model(images)
_, predicted = torch.max(out.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
#输出识别准确率
print('10000测试图像 准确率:{:.4f}%'.format(100 * correct / total))
最终识别准确率:
10000测试图像 准确率:63.8800%
网友评论