Time: 2019-08-04
循环一个epoch
# 循环一个batch
# Single-epoch training run: build the model, data loader, and optimizer,
# then sweep once over every batch in the training set.
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

total_loss = 0.0       # running sum of per-batch loss values
total_correct = 0      # running count of correct predictions

for images, labels in train_loader:          # unpack each batch directly
    preds = network(images)                  # forward pass on one batch
    loss = F.cross_entropy(preds, labels)    # loss for this batch
    optimizer.zero_grad()                    # gradients must be cleared before backward()
    loss.backward()                          # compute gradients
    optimizer.step()                         # apply the parameter update for this batch
    total_loss += loss.item()
    total_correct += get_num_correct(preds, labels)

print("epoch: ", 0, "total_correct: ", total_correct, "loss: ", total_loss)
训练多个epoch
# Train for multiple epochs.
# FIX: the model, data loader, and optimizer must be created ONCE, outside
# the epoch loop. The original re-created all three inside the loop, which
# reset the learned weights and the Adam moment estimates every epoch, so
# the network never improved beyond a single epoch of training.
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

for epoch in range(6):
    # Per-epoch accumulators are the only things reset each epoch.
    total_loss = 0.0
    total_correct = 0
    for batch in train_loader:
        images, labels = batch
        preds = network(images)                  # forward pass on one batch
        loss = F.cross_entropy(preds, labels)    # compute the loss
        optimizer.zero_grad()                    # clear gradients accumulated by the previous step
        loss.backward()                          # compute gradients
        optimizer.step()                         # update parameters for this batch
        total_loss += loss.item()
        total_correct += get_num_correct(preds, labels)
    print("epoch: ", epoch, "total_correct: ", total_correct, "loss: ", total_loss)
END.
网友评论