训练
- 在训练之前,可以先定一个train_one_epoch()函数用于进行一个epoch的训练。这个函数包括使用train_loader中的每一个batch进行训练的训练部分;
def train_one_epoch(model, train_loader, criterion, optimizer, device):
model.train() # 切换模型到训练模式
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 后向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 计算batch内损失
running_loss += loss.item() * inputs.size(0)
# 计算epoch内损失
epoch_loss = running_loss / len(train_loader.dataset)
return epoch_loss
enumerate()
-
注:有时候在dataloader外面经常会套一个enumerate()函数,enumerate()函数用于在遍历可迭代对象时,同时获得元素的索引和值。它的使用并不是强制性的,取决于是否需要跟踪当前批次的索引。如果不需要索引,仅仅需要遍历数据,那么可以直接迭代DataLoader而不使用enumerate()
-
举例:
for batch_idx, batch_data in enumerate(train_loader):
# 将数据移动到GPU
inputs, labels = batch_data
inputs, labels = inputs.to(device), labels.to(device)
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 后向传播
optimizer.zero_grad() # 清零所有参数的梯度
loss.backward() # 计算梯度
optimizer.step() # 更新参数
# 使用batch_idx
if batch_idx % 10 == 0: # 每10个批次打印一次损失
print(f'Batch [{batch_idx}], Loss: {loss.item():.4f}')
验证
- 如果有验证集,可以编写validate_one_epoch()函数用于实现对验证集中的每个批次进行验证的验证部分
# 定义验证函数
def validate_one_epoch(model, valid_loader, criterion, device):
model.eval() # 切换到评估模式
running_loss = 0.0
# 在验证过程中不需要计算梯度
with torch.no_grad():
for inputs, labels in valid_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
running_loss += loss.item() * inputs.size(0) # 计算平均损失
epoch_loss = running_loss / len(valid_loader.dataset)
return epoch_loss
在每个epoch中进行训练+验证
num_epochs = 10
for epoch in range(num_epochs):
train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
valid_loss = validate_one_epoch(model, valid_loader, criterion, device)
print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}')
测试(推理)
使用训练好的模型进行推理,其实validation部分就是推理,因此代码和validate_one_epoch比较类似
# 设置模型为评估模式
model.eval()
# 进行推理
with torch.no_grad(): # 在推理过程中不需要计算梯度
outputs = model(new_inputs)
# 输出结果
print(outputs)
网友评论