美文网首页AI新时空
Pytorch +CNN训练

Pytorch +CNN训练

作者: Yankee_13 | 来源:发表于2018-12-29 13:53 被阅读0次

    献给莹莹

    1.读入数据

    利用ImageFolder读入训练数据,可以参考之前的文章

    2.构建训练流程

    1).一些准备工作

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    #设置gpu或是cpu
    print('==> Building model..')
    net = ResNet18()
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
    

    2).构建损失函数

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.01,weight_decay=5e-4)
    scheduler=MultiStepLR(optimizer,milestones=[50,100],gamma=0.1)
    
    1.损失函数

    交叉熵损失函数
    nn.CrossEntropyLoss()


    x为一维向量,label为一个数值
    2.学习率衰减

    MultiStepLR能够控制学习率,milestones表示训练的epoch里程碑,gamma表示衰减因子
    以下例子表示,训练epoch到达50时,lr乘以0.1,训练epoch到达100时,lr再乘以0.1

    from torch.optim.lr_scheduler import MultiStepLR
    scheduler=MultiStepLR(optimizer,milestones=[50,100],gamma=0.1)
    

    运行时,需要加一句:

    scheduler.step()
    

    3).定义训练部分

    # Training
    def train(epoch):
        net.train()
        for data in tqdm(trainloader, leave=False, total=len(trainloader)):
            inputs=data[0]
            targets=data[1]
            inputs=inputs.to(device)
            targets=targets.to(device)
            optimizer.zero_grad()
            #每次喂入数据前,都需要将梯度清零
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            #计算loss
            loss.backward()
            #传回反向梯度
            optimizer.step()
            #梯度传回,利用优化器将参数更新
    
    1.tqdm的使用

    tqdm的载入:

    from tqdm import tqdm
    

    举例

    for data in tqdm(trainloader, leave=False, total=len(trainloader))
    

    一些默认参数

    def __init__(self, iterable=None, desc=None, total=None, leave=True,
                     file=sys.stderr, ncols=None, mininterval=0.1,
                     maxinterval=10.0, miniters=None, ascii=None, disable=False,
                     unit='it', unit_scale=False, dynamic_ncols=False,
                     smoothing=0.3, bar_format=None, initial=0, position=None,
                     gui=False, **kwargs):
    

    关键参数解释:

    • iterable:一个迭代器,可以是torch中的dataloader
    • leave:是否保留进度条的最终形态
    • total:预期迭代的数目,一般是迭代器的长度
      详细说明来自tqdm官方文档:https://tqdm.github.io/docs/tqdm/
      注:
      如果不要求对数据的读取可视化,可以直接用enumerate
    for batch_idx, (inputs, targets) in enumerate(trainloader):
    

    4).定义测试部分

    这一部分实际和训练部分代码相似,只不过多了一个检测的标准,用来判断是否要更新权重,保存下来。

    def test(epoch):
        global best_acc
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
        #因为是测试,因此禁止梯度
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
    
        acc = 100. * correct / total
        print("Now acc is {}".format(acc))
    

    保存权重

        if acc > best_acc:
            print('Saving..')
            state = {
                'net': net.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/ckpt.t7')
            best_acc = acc
    

    迭代

    if __name__ =="__main__":
        best_acc=0
        for epoch in range(0,200):
            scheduler.step()
            train(epoch)
            test(epoch)
    

    5).继续训练或是应用训练后的模型

    print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.t7')
        net.load_state_dict(checkpoint['net'])
        #best_acc = checkpoint['acc']
        #start_epoch = checkpoint['epoch']
    

    相关文章

      网友评论

        本文标题:Pytorch +CNN训练

        本文链接:https://www.haomeiwen.com/subject/fixdlqtx.html