损失函数(criterion)
- 通过实例化各种损失函数类进行定义,一般实例化名为criterion
import torch.nn as nn
# 均方误差,适用于回归任务
criterion = nn.MSELoss()
# 交叉熵损失,用于多分类任务
criterion = nn.CrossEntropyLoss()
# 二元交叉熵损失,用于二分类任务,输出层神经元个数为1
criterion = nn.BCELoss()
优化器(optimizer)
-
优化器在PyTorch中是用来管理模型参数和梯度的。包括最基本的SGD、SGD with Momentum、AdaGrad、RMSprop、Adam。
-
Adam是目前最常用的优化算法之一,结合了动量和RMSprop的优点。通过一下代码实例化一个基于Adam的optimizer
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=0.001) # model为实例化后的模型
在训练中使用损失函数和优化器
- 向前传播
# 输入数据经过模型,得到outputs
outputs = model(inputs)
# outputs和labels进行损失计算
loss = criterion(outputs, labels)
- 后向传播
# 梯度清零
# 每次反向传播时,我们希望计算的是当前批次数据所对应的梯度
# 如果不清零梯度,当前批次的梯度会被之前批次的梯度污染,导致梯度计算不准确
optimizer.zero_grad()
# 计算损失相对于每个参数的梯度
loss.backward()
# 根据当前的梯度更新模型的参数
optimizer.step()
网友评论