美文网首页
Pytorch笔记7-训练、验证、测试模型

Pytorch笔记7-训练、验证、测试模型

作者: 江湾青年 | 来源:发表于2024-07-18 21:20 被阅读0次

    训练

    • 在训练之前,可以先定一个train_one_epoch()函数用于进行一个epoch的训练。这个函数包括使用train_loader中的每一个batch进行训练的训练部分;
    def train_one_epoch(model, train_loader, criterion, optimizer, device):
        model.train()  # 切换模型到训练模式
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # 前向传播
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # 后向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # 计算batch内损失
            running_loss += loss.item() * inputs.size(0)
        # 计算epoch内损失
        epoch_loss = running_loss / len(train_loader.dataset)
        return epoch_loss
    

    enumerate()

    • 注:有时候在dataloader外面经常会套一个enumerate()函数,enumerate()函数用于在遍历可迭代对象时,同时获得元素的索引和值。它的使用并不是强制性的,取决于是否需要跟踪当前批次的索引。如果不需要索引,仅仅需要遍历数据,那么可以直接迭代DataLoader而不使用enumerate()

    • 举例:

    for batch_idx, batch_data in enumerate(train_loader):
        # 将数据移动到GPU
        inputs, labels = batch_data
        inputs, labels = inputs.to(device), labels.to(device)
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # 后向传播
        optimizer.zero_grad()  # 清零所有参数的梯度
        loss.backward()        # 计算梯度
        optimizer.step()       # 更新参数
        # 使用batch_idx
        if batch_idx % 10 == 0:  # 每10个批次打印一次损失
            print(f'Batch [{batch_idx}], Loss: {loss.item():.4f}')
    

    验证

    • 如果有验证集,可以编写validate_one_epoch()函数用于实现对验证集中的每个批次进行验证的验证部分
    # 定义验证函数
    def validate_one_epoch(model, valid_loader, criterion, device):
        model.eval()  # 切换到评估模式
        running_loss = 0.0
        # 在验证过程中不需要计算梯度
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * inputs.size(0)    # 计算平均损失
        epoch_loss = running_loss / len(valid_loader.dataset)
        return epoch_loss
    

    在每个epoch中进行训练+验证

    num_epochs = 10
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        valid_loss = validate_one_epoch(model, valid_loader, criterion, device)
        print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}')
    

    测试(推理)

    使用训练好的模型进行推理,其实validation部分就是推理,因此代码和validate_one_epoch比较类似

    # 设置模型为评估模式
    model.eval()
    # 进行推理
    with torch.no_grad():  # 在推理过程中不需要计算梯度
        outputs = model(new_inputs)
    # 输出结果
    print(outputs)
    

    相关文章

      网友评论

          本文标题:Pytorch笔记7-训练、验证、测试模型

          本文链接:https://www.haomeiwen.com/subject/gkvkhjtx.html