美文网首页
Pytorch笔记7-训练、验证、测试模型

Pytorch笔记7-训练、验证、测试模型

作者: 江湾青年 | 来源:发表于2024-07-18 21:20 被阅读0次

训练

  • 在训练之前,可以先定一个train_one_epoch()函数用于进行一个epoch的训练。这个函数包括使用train_loader中的每一个batch进行训练的训练部分;
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()  # 切换模型到训练模式
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # 后向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 计算batch内损失
        running_loss += loss.item() * inputs.size(0)
    # 计算epoch内损失
    epoch_loss = running_loss / len(train_loader.dataset)
    return epoch_loss

enumerate()

  • 注:有时候在dataloader外面经常会套一个enumerate()函数,enumerate()函数用于在遍历可迭代对象时,同时获得元素的索引和值。它的使用并不是强制性的,取决于是否需要跟踪当前批次的索引。如果不需要索引,仅仅需要遍历数据,那么可以直接迭代DataLoader而不使用enumerate()

  • 举例:

for batch_idx, batch_data in enumerate(train_loader):
    # 将数据移动到GPU
    inputs, labels = batch_data
    inputs, labels = inputs.to(device), labels.to(device)
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    # 后向传播
    optimizer.zero_grad()  # 清零所有参数的梯度
    loss.backward()        # 计算梯度
    optimizer.step()       # 更新参数
    # 使用batch_idx
    if batch_idx % 10 == 0:  # 每10个批次打印一次损失
        print(f'Batch [{batch_idx}], Loss: {loss.item():.4f}')

验证

  • 如果有验证集,可以编写validate_one_epoch()函数用于实现对验证集中的每个批次进行验证的验证部分
# 定义验证函数
def validate_one_epoch(model, valid_loader, criterion, device):
    model.eval()  # 切换到评估模式
    running_loss = 0.0
    # 在验证过程中不需要计算梯度
    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)    # 计算平均损失
    epoch_loss = running_loss / len(valid_loader.dataset)
    return epoch_loss

在每个epoch中进行训练+验证

num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    valid_loss = validate_one_epoch(model, valid_loader, criterion, device)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}')

测试(推理)

使用训练好的模型进行推理,其实validation部分就是推理,因此代码和validate_one_epoch比较类似

# 设置模型为评估模式
model.eval()
# 进行推理
with torch.no_grad():  # 在推理过程中不需要计算梯度
    outputs = model(new_inputs)
# 输出结果
print(outputs)

相关文章

网友评论

      本文标题:Pytorch笔记7-训练、验证、测试模型

      本文链接:https://www.haomeiwen.com/subject/gkvkhjtx.html