数据集（CIFAR-10）的介绍请参见下面的链接：
本文使用 PyTorch 框架实现。
train.py
"""
# author: shiyipaisizuo
# contact: shiyipaisizuo@gmail.com
# file: train.py
# time: 2018/8/14 09:43
# license: MIT
"""
import argparse
import os
import time
import torch
import torchvision
from torch import nn, optim
from torchvision import transforms
# Device configuration: use CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Command-line options for training. Paths are joined by plain string
# concatenation later, so --path and --model_path should end with '/'.
parser = argparse.ArgumentParser("""Image classifical!""")
parser.add_argument('--path', type=str, default='../data/cifar10/',
                    help="""image dir path default: '../data/cifar10/'.""")
parser.add_argument('--epochs', type=int, default=50,
                    help="""Epoch default:50.""")
parser.add_argument('--batch_size', type=int, default=256,
                    help="""Batch_size default:256.""")
parser.add_argument('--lr', type=float, default=0.0001,
                    help="""learning_rate. Default=0.0001""")
parser.add_argument('--num_classes', type=int, default=10,
                    help="""num classes""")
parser.add_argument('--model_path', type=str, default='../../model/pytorch/',
                    help="""Save model path""")
parser.add_argument('--model_name', type=str, default='cifar10.pth',
                    help="""Model name.""")
parser.add_argument('--display_epoch', type=int, default=5,
                    help="""Report loss and test accuracy every N epochs. Default:5.""")
args = parser.parse_args()
# Create the checkpoint directory up front so torch.save cannot fail on a missing path.
os.makedirs(args.model_path, exist_ok=True)

# Training-time preprocessing, including random data augmentation.
transform = transforms.Compose([
    transforms.Resize(32),  # shorter side -> 32 px (no-op for 32x32 CIFAR-10 images)
    transforms.RandomHorizontalFlip(p=0.75),  # mirror left-right with probability 0.75
    transforms.RandomCrop(24),  # random 24x24 crop
    transforms.ColorJitter(brightness=1, contrast=2, saturation=3, hue=0),  # random photometric jitter
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet mean/std
])

# Evaluation preprocessing: same geometry (32 -> 24x24) and normalization as
# training, but deterministic, so reported accuracy is not perturbed by
# random flips/crops/colour jitter.
test_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(24),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Load data
train_datasets = torchvision.datasets.CIFAR10(root=args.path,
                                              transform=transform,
                                              download=True,
                                              train=True)
train_loader = torch.utils.data.DataLoader(dataset=train_datasets,
                                           batch_size=args.batch_size,
                                           shuffle=True)

test_datasets = torchvision.datasets.CIFAR10(root=args.path,
                                             transform=test_transform,
                                             download=True,
                                             train=False)
# Order does not matter for accuracy, so the test set is not shuffled.
test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                          batch_size=args.batch_size,
                                          shuffle=False)
def _evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Switches the model to eval mode and disables autograd, so no computation
    graph is built for the evaluation forward passes. Returns 0.0 for an
    empty loader instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total if total else 0.0


def train():
    """Fine-tune a ResNet-18 on CIFAR-10 and periodically report test accuracy.

    Hyper-parameters and paths come from the module-level ``args``; data comes
    from the module-level ``train_loader`` / ``test_loader``. After every epoch
    the whole model object is pickled to ``args.model_path + args.model_name``.
    """
    print(f"Train numbers:{len(train_datasets)}")

    # Start from ImageNet-pretrained weights. (Newer torchvision releases prefer
    # the ``weights=`` argument and emit a deprecation warning for ``pretrained``.)
    model = torchvision.models.resnet18(pretrained=True)
    # For 24x24 crops the final feature map is expected to be 1x1 (assuming the
    # standard stride-32 ResNet-18 downsampling), so this pooling keeps the
    # flattened feature size at 512.
    model.avgpool = nn.AvgPool2d(1, 1)
    model.fc = nn.Linear(512, args.num_classes)
    # Move to the device only AFTER replacing layers; otherwise the freshly
    # created ``fc`` would stay on the CPU while inputs live on the GPU.
    model = model.to(device)
    print(model)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-8)
    checkpoint = args.model_path + args.model_name

    for epoch in range(1, args.epochs + 1):
        model.train()  # re-enable training mode after any evaluation pass
        start = time.time()
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % args.display_epoch == 0:
            end = time.time()
            # (end - start) covers only this epoch's training loop; multiplying
            # by display_epoch extrapolates it to the whole reporting window.
            print(f"Epoch [{epoch}/{args.epochs}], "
                  f"Loss: {loss.item():.8f}, "
                  f"Time: {(end-start) * args.display_epoch:.1f}sec!")
            print(f"Acc: {_evaluate(model, test_loader):.4f}")

        # Save the model checkpoint (whole pickled module, read back by torch.load).
        torch.save(model, checkpoint)
    print(f"Model save to {checkpoint}.")


if __name__ == '__main__':
    train()
prediction.py
"""
# author: shiyipaisizuo
# contact: shiyipaisizuo@gmail.com
# file: prediction.py
# time: 2018/8/14 09:35
# license: MIT
"""
import argparse
import os
import torch
import torchvision
from torchvision import transforms
# Device configuration: use CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Command-line options for evaluating a saved checkpoint.
parser = argparse.ArgumentParser("""Image classifical!""")
parser.add_argument('--path', type=str, default='../data/cifar10/',
                    help="""image dir path default: '../data/cifar10/'.""")
parser.add_argument('--batch_size', type=int, default=256,
                    help="""Batch_size default:256.""")
# Not read by this script; kept so the CLI mirrors train.py.
parser.add_argument('--num_classes', type=int, default=10,
                    help="""num classes""")
parser.add_argument('--model_path', type=str, default='../../model/pytorch/',
                    help="""Save model path""")
parser.add_argument('--model_name', type=str, default='cifar10.pth',
                    help="""Model name.""")
args = parser.parse_args()
# Ensure the model directory exists (mirrors train.py; loading happens in test()).
os.makedirs(args.model_path, exist_ok=True)

# Deterministic evaluation preprocessing. Geometry (32 -> 24x24) and
# normalization match the training pipeline, but the crop is centred rather
# than random so repeated runs report the same accuracy.
transform = transforms.Compose([
    transforms.Resize(32),  # shorter side -> 32 px
    transforms.CenterCrop(24),  # central 24x24 crop
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet mean/std
])

# Load data
test_datasets = torchvision.datasets.CIFAR10(root=args.path,
                                             download=True,
                                             transform=transform,
                                             train=False)
# Order is irrelevant for accuracy, so no shuffling.
test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                          batch_size=args.batch_size,
                                          shuffle=False)
def test():
    """Load the saved checkpoint and print its top-1 accuracy on the CIFAR-10 test split.

    Uses the module-level ``args``, ``device``, ``test_datasets`` and ``test_loader``.
    """
    print(f"test numbers: {len(test_datasets)}.")

    # torch.load unpickles a whole nn.Module, so only load checkpoints from a
    # trusted source. map_location=device puts the weights on the GPU when one
    # is available and lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects full-model pickles; pass weights_only=False there.
    model = torch.load(args.model_path + args.model_name, map_location=device)
    model.eval()

    correct_prediction = 0
    total = 0
    # Evaluation needs no gradients; this avoids building autograd graphs.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            _, predicted = torch.max(model(images), 1)
            total += labels.size(0)
            correct_prediction += (predicted == labels).sum().item()

    accuracy = correct_prediction / total if total else 0.0
    print(f"Acc: {accuracy:.4f}")


if __name__ == '__main__':
    test()
validation.py
"""
# author: shiyipaisizuo
# contact: shiyipaisizuo@gmail.com
# file: validation.py
# time: 2018/8/14 09:43
# license: MIT
"""
import argparse
import os
import torch
import torchvision
from torchvision import transforms
# Pick the accelerator: CUDA if present, else fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Command-line options, registered in order from a (flag, options) table.
parser = argparse.ArgumentParser("""Image classifical!""")
for _flag, _options in (
        ('--path', {'type': str, 'default': '../data/cifar10/',
                    'help': """image dir path default: '../data/cifar10/'."""}),
        ('--batch_size', {'type': int, 'default': 1,
                          'help': """Batch_size default:1."""}),
        ('--model_path', {'type': str, 'default': '../../model/pytorch/',
                          'help': """Save model path"""}),
        ('--model_name', {'type': str, 'default': 'cifar10.pth',
                          'help': """Model name."""}),
):
    parser.add_argument(_flag, **_options)
del _flag, _options  # keep the module namespace free of loop leftovers
args = parser.parse_args()
# Ensure the model directory exists (mirrors train.py; loading happens in val()).
os.makedirs(args.model_path, exist_ok=True)

# Deterministic inference preprocessing. The normalization must use the same
# ImageNet mean/std as train.py; the previous 0.5/0.5 values fed the network
# inputs scaled differently from what it was trained on.
transform = transforms.Compose([
    transforms.Resize(32),  # shorter side -> 32 px
    transforms.CenterCrop(24),  # central 24x24 crop (training used random 24x24 crops)
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet mean/std
])

# Load data: one sub-folder per class under <path>/val/ (ImageFolder layout).
val_datasets = torchvision.datasets.ImageFolder(root=args.path + 'val/',
                                                transform=transform)
# shuffle=False keeps batch order aligned with val_datasets.imgs, which is how
# each prediction is paired with its file name in val().
val_loader = torch.utils.data.DataLoader(dataset=val_datasets,
                                         batch_size=args.batch_size,
                                         shuffle=False)

# CIFAR-10 class name -> label index, in torchvision's CIFAR10 label order.
item = {'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}


def val():
    """Classify every image under ``args.path + 'val/'`` and print one line per image.

    Each line has the form ``"<n>.(<file path>) is <class name>!"``. Works for
    any ``--batch_size``: every prediction in a batch is reported.
    """
    # torch.load unpickles a whole nn.Module; only load trusted checkpoints.
    # map_location=device also lets a GPU-saved checkpoint load on a CPU-only host.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects full-model pickles; pass weights_only=False there.
    model = torch.load(args.model_path + args.model_name, map_location=device)
    model.eval()

    # Invert once: label index -> class name.
    idx_to_name = {v: k for k, v in item.items()}
    seen = 0  # running index into val_datasets.imgs across batches
    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)
            predicted = model(images).argmax(dim=1)
            for label in predicted.tolist():
                # imgs entries are (path, folder_class_index) tuples.
                file = val_datasets.imgs[seen][0]
                seen += 1
                print(f"{seen}.({file}) is {idx_to_name[label]}!")


if __name__ == '__main__':
    val()
验证情况
1.(../data/cifar10/val/bird/bird.jpg) is bird!
2.(../data/cifar10/val/bird/bird2.jpg) is bird!
3.(../data/cifar10/val/bird/bird2的副本 2.jpg) is bird!
4.(../data/cifar10/val/bird/bird2的副本.jpg) is bird!
5.(../data/cifar10/val/bird/bird的副本 2.jpg) is bird!
6.(../data/cifar10/val/bird/bird的副本.jpg) is plane!
7.(../data/cifar10/val/plane/plane.jpg) is plane!
8.(../data/cifar10/val/plane/plane2.jpg) is ship!
9.(../data/cifar10/val/plane/plane2的副本 2.jpg) is plane!
10.(../data/cifar10/val/plane/plane2的副本.jpg) is plane!
11.(../data/cifar10/val/plane/plane的副本 2.jpg) is bird!
12.(../data/cifar10/val/plane/plane的副本.jpg) is plane!
使用说明
- train:
python train.py
- test:
python prediction.py
- val:
python validation.py
Acc: 0.984.
网友评论