美文网首页Pytorch
Pytorch实现Resnet50用来解决cifar10图像分类

Pytorch实现Resnet50用来解决cifar10图像分类

作者: Lornatang | 来源:发表于2018-08-17 16:24 被阅读1623次

    数据集介绍点下面的链接:

    数据介绍点我

    数据下载点这里

    使用的是Pytorch框架

    代码地址

    train.py

    """
    # author: shiyipaisizuo
    # contact: shiyipaisizuo@gmail.com
    # file: train.py
    # time: 2018/8/14 09:43
    # license: MIT
    """
    
    import argparse
    import os
    
    import time
    import torch
    import torchvision
    from torch import nn, optim
    from torchvision import transforms
    
    # Device configuration: use the GPU when CUDA is available, else the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Command-line options of the training script.
    parser = argparse.ArgumentParser("""Image classifical!""")
    parser.add_argument('--path', type=str, default='../data/cifar10/',
                        help="""image dir path default: '../data/cifar10/'.""")
    parser.add_argument('--epochs', type=int, default=50,
                        help="""Epoch default:50.""")
    parser.add_argument('--batch_size', type=int, default=256,
                        help="""Batch_size default:256.""")
    parser.add_argument('--lr', type=float, default=0.0001,
                        help="""learing_rate. Default=0.0001""")
    parser.add_argument('--num_classes', type=int, default=10,
                        help="""num classes""")
    parser.add_argument('--model_path', type=str, default='../../model/pytorch/',
                        help="""Save model path""")
    parser.add_argument('--model_name', type=str, default='cifar10.pth',
                        help="""Model name.""")
    # Evaluate and print progress every `display_epoch` epochs.
    parser.add_argument('--display_epoch', type=int, default=5)

    args = parser.parse_args()

    # Make sure the checkpoint directory exists. exist_ok=True replaces the
    # separate os.path.exists() test, which could race with another process
    # creating the same directory between the check and the makedirs call.
    os.makedirs(args.model_path, exist_ok=True)
    
    # Training-time preprocessing, including random augmentation.
    transform = transforms.Compose([
        transforms.Resize(32),  # resize the image to 32 * 32
        transforms.RandomHorizontalFlip(p=0.75),  # mirror horizontally with probability 0.75
        transforms.RandomCrop(24),  # cut a random 24 * 24 patch out of the image
        transforms.ColorJitter(brightness=1, contrast=2, saturation=3, hue=0),  # random brightness/contrast/saturation jitter
        transforms.ToTensor(),  # convert the image to a Tensor
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel normalisation
    ])

    # Evaluation-time preprocessing: same size and normalisation as training
    # but without any randomness, so the reported accuracy is reproducible and
    # is not degraded by flips, random crops or colour jitter.
    test_transform = transforms.Compose([
        transforms.Resize(32),  # resize the image to 32 * 32
        transforms.CenterCrop(24),  # deterministic 24 * 24 crop from the centre
        transforms.ToTensor(),  # convert the image to a Tensor
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel normalisation
    ])


    # Load data
    train_datasets = torchvision.datasets.CIFAR10(root=args.path,
                                                  transform=transform,
                                                  download=True,
                                                  train=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_datasets,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    test_datasets = torchvision.datasets.CIFAR10(root=args.path,
                                                 transform=test_transform,
                                                 download=True,
                                                 train=False)

    # Sample order does not influence the accuracy, so no shuffling is needed.
    test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                              batch_size=args.batch_size,
                                              shuffle=False)
    
    
    def train():
        """Fine-tune a pretrained ResNet-18 on CIFAR-10 and save the model.

        Reads the module-level ``args``, ``device``, ``train_datasets``,
        ``train_loader`` and ``test_loader``. Every ``args.display_epoch``
        epochs the last batch loss and the accuracy on ``test_loader`` are
        printed. After the final epoch the whole model object is written with
        ``torch.save`` to ``args.model_path`` / ``args.model_name``.
        """
        print(f"Train numbers:{len(train_datasets)}")

        # Build the network. To resume from an earlier run, load the saved
        # checkpoint with torch.load(...) here instead.
        # NOTE: the keyword is `pretrained`; the former `predicted=True` was
        # forwarded to the ResNet constructor and raised a TypeError. Newer
        # torchvision releases prefer the `weights=` argument.
        model = torchvision.models.resnet18(pretrained=True)
        # Identity pooling and a new classification head for our classes.
        # NOTE(review): presumably chosen because 24 * 24 inputs reach the last
        # stage as 1 * 1 feature maps, keeping 512 features for `fc` - confirm.
        model.avgpool = nn.AvgPool2d(1, 1)
        model.fc = nn.Linear(512, args.num_classes)
        # Move to the target device only after the layers were replaced;
        # otherwise the freshly created `fc` layer stays on the CPU and the
        # forward pass fails on CUDA with a device mismatch.
        model = model.to(device)
        print(model)

        # Loss function
        criterion = nn.CrossEntropyLoss().to(device)
        # Optimization
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-8)

        for epoch in range(1, args.epochs + 1):
            model.train()
            # start time of this epoch
            start = time.time()
            for images, labels in train_loader:
                images = images.to(device)
                labels = labels.to(device)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if epoch % args.display_epoch == 0:
                end = time.time()
                # The time is this epoch's duration scaled by display_epoch,
                # i.e. an estimate for the span since the previous report.
                print(f"Epoch [{epoch}/{args.epochs}], "
                      f"Loss: {loss.item():.8f}, "
                      f"Time: {(end-start) * args.display_epoch:.1f}sec!")

                model.eval()

                correct_prediction = 0
                total = 0
                # No gradients are needed for evaluation; skipping the autograd
                # graph saves memory and time.
                with torch.no_grad():
                    for images, labels in test_loader:
                        images = images.to(device)
                        labels = labels.to(device)
                        outputs = model(images)
                        # index of the highest score = predicted class
                        predicted = outputs.argmax(dim=1)
                        total += labels.size(0)
                        correct_prediction += (predicted == labels).sum().item()

                # `.4f` (four decimals) instead of the former `4f`, which only
                # set a minimum width of four characters.
                print(f"Acc: {(correct_prediction / total):.4f}")

        # Save the model checkpoint. os.path.join also works when
        # --model_path is given without a trailing separator.
        checkpoint_path = os.path.join(args.model_path, args.model_name)
        torch.save(model, checkpoint_path)
        print(f"Model save to {checkpoint_path}.")
    
    
    # Start training only when the file is executed as a script, not on import.
    if __name__ == '__main__':
        train()
    
    

    prediction.py

    """
    # author: shiyipaisizuo
    # contact: shiyipaisizuo@gmail.com
    # file: prediction.py
    # time: 2018/8/14 09:35
    # license: MIT
    """
    
    import argparse
    import os
    
    import torch
    import torchvision
    from torchvision import transforms
    
    # Device configuration: use the GPU when CUDA is available, else the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Command-line options of the test script.
    parser = argparse.ArgumentParser("""Image classifical!""")
    parser.add_argument('--path', type=str, default='../data/cifar10/',
                        help="""image dir path default: '../data/cifar10/'.""")
    # The help text used to claim a default of 100 although the default is 256.
    parser.add_argument('--batch_size', type=int, default=256,
                        help="""Batch_size default:256.""")
    parser.add_argument('--num_classes', type=int, default=10,
                        help="""num classes""")
    parser.add_argument('--model_path', type=str, default='../../model/pytorch/',
                        help="""Save model path""")
    parser.add_argument('--model_name', type=str, default='cifar10.pth',
                        help="""Model name.""")

    args = parser.parse_args()

    # Make sure the model directory exists. exist_ok=True replaces the
    # separate os.path.exists() test and its check-then-create race.
    os.makedirs(args.model_path, exist_ok=True)
    
    # Evaluation-time preprocessing. The crop is taken from the centre instead
    # of a random position so repeated runs report the same accuracy.
    transform = transforms.Compose([
        transforms.Resize(32),  # resize the image to 32 * 32
        transforms.CenterCrop(24),  # deterministic 24 * 24 crop from the centre
        transforms.ToTensor(),  # convert the image to a Tensor
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # per-channel normalisation
    ])

    # Load data
    test_datasets = torchvision.datasets.CIFAR10(root=args.path,
                                                 download=True,
                                                 transform=transform,
                                                 train=False)


    # Sample order does not influence the accuracy, so no shuffling is needed.
    test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                              batch_size=args.batch_size,
                                              shuffle=False)
    
    
    def test():
        """Load the saved model and print its accuracy on the CIFAR-10 test set.

        Reads the module-level ``args``, ``device``, ``test_datasets`` and
        ``test_loader``. The checkpoint is expected at
        ``args.model_path`` / ``args.model_name`` and must contain a whole
        model object as written by ``torch.save(model, ...)``.
        """
        print(f"test numbers: {len(test_datasets)}.")
        # Load model. map_location places the weights on the selected device
        # directly, covering both the CUDA and the CPU-only case.
        # NOTE: torch.load unpickles arbitrary Python objects - only load
        # checkpoints from a trusted source. Recent PyTorch releases may need
        # weights_only=False to load a fully pickled model - confirm for the
        # installed version.
        checkpoint_path = os.path.join(args.model_path, args.model_name)
        model = torch.load(checkpoint_path, map_location=device)
        model.eval()

        correct_prediction = 0
        total = 0
        # Inference only: disable autograd bookkeeping to save memory and time.
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                # index of the highest score = predicted class
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct_prediction += (predicted == labels).sum().item()

        # `.4f` (four decimals) instead of the former `4f`, which only set a
        # minimum width of four characters.
        print(f"Acc: {(correct_prediction / total):.4f}")
    
    
    # Run the evaluation only when the file is executed as a script.
    if __name__ == '__main__':
        test()
    
    

    validation.py

    """
    # author: shiyipaisizuo
    # contact: shiyipaisizuo@gmail.com
    # file: validation.py
    # time: 2018/8/14 09:43
    # license: MIT
    """
    
    import argparse
    import os
    
    import torch
    import torchvision
    from torchvision import transforms
    
    # Device configuration: use the GPU when CUDA is available, else the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Command-line options of the validation script.
    parser = argparse.ArgumentParser("""Image classifical!""")
    parser.add_argument('--path', type=str, default='../data/cifar10/',
                        help="""image dir path default: '../data/cifar10/'.""")
    parser.add_argument('--batch_size', type=int, default=1,
                        help="""Batch_size default:1.""")
    parser.add_argument('--model_path', type=str, default='../../model/pytorch/',
                        help="""Save model path""")
    parser.add_argument('--model_name', type=str, default='cifar10.pth',
                        help="""Model name.""")

    args = parser.parse_args()

    # Make sure the model directory exists. exist_ok=True replaces the
    # separate os.path.exists() test and its check-then-create race.
    os.makedirs(args.model_path, exist_ok=True)
    
    # Inference-time preprocessing. It has to match what the model saw during
    # training: train.py normalises with mean [0.485, 0.456, 0.406] and std
    # [0.229, 0.224, 0.225], so the former 0.5 / 0.5 values fed the network
    # differently scaled inputs. The crop is centred so that the prediction for
    # a given image is the same on every run.
    transform = transforms.Compose([
        transforms.Resize(32),  # resize the image to 32 (shorter side)
        transforms.CenterCrop(24),  # deterministic 24 * 24 crop from the centre
        transforms.ToTensor(),  # convert the image to a Tensor
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same normalisation as in training
    ])
    # Load data: one sub-folder per class below <path>/val/.
    val_datasets = torchvision.datasets.ImageFolder(root=args.path + 'val/',
                                                    transform=transform)


    # Keep the loader in dataset order: val() pairs every prediction with
    # val_datasets.imgs by position, and shuffling would attach predictions to
    # the wrong file names.
    val_loader = torch.utils.data.DataLoader(dataset=val_datasets,
                                             batch_size=args.batch_size,
                                             shuffle=False)
    # CIFAR-10 class name -> class index as used by the trained model
    item = {'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}


    def val():
        """Print the predicted class for every image below ``<path>/val/``.

        Reads the module-level ``args``, ``device``, ``val_datasets``,
        ``val_loader`` and ``item``. The folder labels produced by
        ``ImageFolder`` are ignored; only the model output is reported, one
        line per image in the form ``"<n>.(<file>) is <class>!"``.
        """
        # Load model. map_location places the weights on the selected device
        # directly, covering both the CUDA and the CPU-only case.
        # NOTE: torch.load unpickles arbitrary Python objects - only load
        # checkpoints from a trusted source.
        checkpoint_path = os.path.join(args.model_path, args.model_name)
        model = torch.load(checkpoint_path, map_location=device)
        model.eval()

        # class index -> class name, built once instead of once per batch
        index_to_name = {index: name for name, index in item.items()}

        # Running position inside val_datasets.imgs; valid because the loader
        # yields the samples in dataset order.
        position = 0
        # Inference only: disable autograd bookkeeping to save memory and time.
        with torch.no_grad():
            for images, _ in val_loader:
                images = images.to(device)
                outputs = model(images)
                # index of the highest score = predicted class
                predicted = outputs.argmax(dim=1)

                # Report every image of the batch, not only the first one.
                for class_index in predicted.tolist():
                    # imgs holds (path, folder_label) tuples; take the path
                    # itself rather than slicing the tuple's string form.
                    file = val_datasets.imgs[position][0]
                    position += 1
                    print(f"{position}.({file}) is {index_to_name[class_index]}!")
    
    
    # Run the validation only when the file is executed as a script.
    if __name__ == '__main__':
        val()
    
    

    验证情况

    1.(../data/cifar10/val/bird/bird.jpg) is bird!
    2.(../data/cifar10/val/bird/bird2.jpg) is bird!
    3.(../data/cifar10/val/bird/bird2的副本 2.jpg) is bird!
    4.(../data/cifar10/val/bird/bird2的副本.jpg) is bird!
    5.(../data/cifar10/val/bird/bird的副本 2.jpg) is bird!
    6.(../data/cifar10/val/bird/bird的副本.jpg) is plane!
    7.(../data/cifar10/val/plane/plane.jpg) is plane!
    8.(../data/cifar10/val/plane/plane2.jpg) is ship!
    9.(../data/cifar10/val/plane/plane2的副本 2.jpg) is plane!
    10.(../data/cifar10/val/plane/plane2的副本.jpg) is plane!
    11.(../data/cifar10/val/plane/plane的副本 2.jpg) is bird!
    12.(../data/cifar10/val/plane/plane的副本.jpg) is plane!
    

    使用说明

    • train:

    python train.py

    • test:

    python prediction.py

    • val:

    python validation.py

    Acc: 0.984.

    LICENSE: MIT

    相关文章

      网友评论

        本文标题:Pytorch实现Resnet50用来解决cifar10图像分类

        本文链接:https://www.haomeiwen.com/subject/gcltiftx.html