美文网首页
Datawhale 零基础入门CV赛事-Task4 模型训练与验

Datawhale 零基础入门CV赛事-Task4 模型训练与验

作者: 顾子豪 | 来源:发表于2020-05-30 23:53 被阅读0次

    学习目标

    • 理解验证集的作用,并使用训练集和验证集完成训练

    • 学会使用Pytorch环境下的模型读取和加载,并了解调参流程
      数据集划分

    • 训练集

    • 用来训练模型内参数的数据集,Classfier直接根据训练集来调整自身获得更好的分类效果

    • 验证集

    • 用于在训练过程中检验模型的状态,收敛情况。验证集通常用于调整超参数,根据几组模型验证集上的表现决定哪组超参数拥有最好的性能。
      同时验证集在训练过程中还可以用来监控模型是否发生过拟合,一般来说验证集表现稳定后,若继续训练,训练集表现还会继续上升,但是验证集会出现不升反降的情况,这样一般就发生了过拟合。所以验证集也用来判断何时停止训练
      测试集
      测试集用来评价模型泛化能力,即之前模型使用验证集确定了超参数,使用训练集调整了参数,最后使用一个从没有见过的数据集来判断这个模型是否Work。

    模型训练与验证

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=10, 
        shuffle=True, 
        num_workers=10, 
    )
        
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=10, 
        shuffle=False, 
        num_workers=10, 
    )
    
    model = SVHN_Model1()
    criterion = nn.CrossEntropyLoss (size_average=False)
    optimizer = torch.optim.Adam(model.parameters(), 0.001)
    best_loss = 1000.0
    for epoch in range(20):
        print('Epoch: ', epoch)
    
        train(train_loader, model, criterion, optimizer, epoch)
        val_loss = validate(val_loader, model, criterion)
        
        # 记录下验证集精度
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), './model.pt')
    
    • 其中每个Epoch的训练代码如下:
    def train(train_loader, model, criterion, optimizer, epoch):
        # 切换模型为训练模式
        model.train()
    
        for i, (input, target) in enumerate(train_loader):
            c0, c1, c2, c3, c4, c5 = model(data[0])
            loss = criterion(c0, data[1][:, 0]) + \
                    criterion(c1, data[1][:, 1]) + \
                    criterion(c2, data[1][:, 2]) + \
                    criterion(c3, data[1][:, 3]) + \
                    criterion(c4, data[1][:, 4]) + \
                    criterion(c5, data[1][:, 5])
            loss /= 6
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    • 其中每个Epoch的验证代码如下:
    def validate(val_loader, model, criterion):
        # 切换模型为预测模型
        model.eval()
        val_loss = []
    
        # 不记录模型梯度信息
        with torch.no_grad():
            for i, (input, target) in enumerate(val_loader):
                c0, c1, c2, c3, c4, c5 = model(data[0])
                loss = criterion(c0, data[1][:, 0]) + \
                        criterion(c1, data[1][:, 1]) + \
                        criterion(c2, data[1][:, 2]) + \
                        criterion(c3, data[1][:, 3]) + \
                        criterion(c4, data[1][:, 4]) + \
                        criterion(c5, data[1][:, 5])
                loss /= 6
                val_loss.append(loss.item())
        return np.mean(val_loss)
    

    模型保存与加载

    • 在Pytorch中模型的保存和加载非常简单,比较常见的做法是保存和加载模型参数:
    torch.save(model_object.state_dict(), 'model.pt')
    model.load_state_dict(torch.load(' model.pt'))
    

    相关文章

      网友评论

          本文标题:Datawhale 零基础入门CV赛事-Task4 模型训练与验

          本文链接:https://www.haomeiwen.com/subject/btbkzhtx.html