PyTorch Learning (18): Fine-tuning a Pretrained Model

Author: 侠之大者_7d3f | Published 2019-01-08 10:18

    Training results

    (training result screenshots)

    Complete project

    • Project directory structure


      (project directory structure screenshot)
    • Code

    import torch
    import torch.optim as optim
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision.datasets import ImageFolder
    import torchvision.models as models
    import torchvision.transforms as transforms
    import numpy as np
    import copy
    
    
    # ---------------------------------------------------------
    # Load the pretrained AlexNet model
    model = models.alexnet(pretrained=True)
    # Replace the last fully connected layer for 2-class classification
    model.classifier[6] = nn.Linear(in_features=4096, out_features=2)
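    # Optional variant (not used in this post): freeze the pretrained convolutional
    # features so that only the new classifier head is updated during fine-tuning:
    # for param in model.features.parameters():
    #     param.requires_grad = False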
    
    
    # -------------------------Dataset----------------------------------------------------
    
    transform = transforms.Compose([transforms.Resize((227,227)),
                                    transforms.ToTensor()])
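    # Note: the pretrained AlexNet was trained on ImageNet-normalized inputs, so adding the
    # standard ImageNet normalization after ToTensor() usually helps, e.g.
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])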
    
    train_dataset = ImageFolder(root='./data/train', transform=transform)
    val_dataset = ImageFolder(root='./data/val', transform=transform)
    
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, num_workers=4, shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)
    
    
    # ------------------Optimizer and loss function--------------------------------------------------
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    loss_fc = nn.CrossEntropyLoss()
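    # StepLR decays the learning rate by a factor of gamma=0.1 every 20 epochs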
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    
    
    # --------------------Use GPU if available--------------------------------------------------
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    # -------------------Training-------------------------------------------------------------
    
    epoch_nums = 50
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(epoch_nums):
        running_loss = 0.0
        correct = 0
        total = 0
    
        for i, sample_batch in enumerate(train_dataloader):
            inputs = sample_batch[0]
            labels = sample_batch[1]
    
            inputs = inputs.to(device)
            labels = labels.to(device)
    
            model.train()
            optimizer.zero_grad()
            # forward
            outputs = model(inputs)
            # loss
            loss = loss_fc(outputs, labels)
    
            loss.backward()
            optimizer.step()
    
            # Accumulate the running loss and validate every 10 mini-batches
            running_loss += loss.item()
            if i % 10 == 9:
                correct = 0
                total = 0
                model.eval()
                with torch.no_grad():
                    for images_test, labels_test in val_dataloader:
                        images_test = images_test.to(device)
                        labels_test = labels_test.to(device)
                        outputs_test = model(images_test)
                        _, prediction = torch.max(outputs_test, 1)
                        correct += (prediction == labels_test).sum().item()
                        total += labels_test.size(0)
                accuracy = correct/total
                print('[{}, {}] running loss={:.5f}, accuracy={:.5f}'.format(epoch + 1, i + 1, running_loss/10, accuracy))
                running_loss = 0.0
                if accuracy > best_acc:
                    best_acc = accuracy
                    best_model_wts = copy.deepcopy(model.state_dict())

        # Step the LR scheduler once per epoch, after the epoch's optimizer updates
        scheduler.step()
    
    
    print('Training finished')
    # Save the best weights (the ./models directory must already exist)
    torch.save(best_model_wts, './models/model_50.pth')
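
    • Loading the fine-tuned model for inference

    A minimal sketch of how the saved weights might be reloaded for prediction, assuming the two-class AlexNet defined above and the weights saved to ./models/model_50.pth; the image path test.jpg is only a placeholder.

    import torch
    import torch.nn as nn
    import torchvision.models as models
    import torchvision.transforms as transforms
    from PIL import Image

    # Rebuild the same architecture and load the fine-tuned weights
    model = models.alexnet(pretrained=False)
    model.classifier[6] = nn.Linear(in_features=4096, out_features=2)
    model.load_state_dict(torch.load('./models/model_50.pth', map_location='cpu'))
    model.eval()

    # Same preprocessing as during training
    transform = transforms.Compose([transforms.Resize((227, 227)),
                                    transforms.ToTensor()])

    img = Image.open('test.jpg').convert('RGB')   # placeholder image path
    x = transform(img).unsqueeze(0)               # add a batch dimension

    with torch.no_grad():
        outputs = model(x)
        _, prediction = torch.max(outputs, 1)
    print('Predicted class index:', prediction.item())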
    
