CNN Training Loop

Author: 钢笔先生 | Published 2019-08-04 16:34

Time: 2019-08-04

Looping over a single epoch
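
The loop below relies on a Network class, a train_set dataset, and a get_num_correct helper that are assumed to be defined earlier in the series but are not shown in this post. For reference, here is a minimal sketch of that setup; Fashion-MNIST and the exact architecture are assumptions for illustration, not the original author's code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Training data: Fashion-MNIST (assumed; not shown in the original post)
train_set = torchvision.datasets.FashionMNIST(
  root='./data', train=True, download=True,
  transform=transforms.ToTensor()
)

# A minimal CNN; any nn.Module mapping (N, 1, 28, 28) -> (N, 10) works here
class Network(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
    self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=60)
    self.out = nn.Linear(in_features=60, out_features=10)

  def forward(self, t):
    t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
    t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)
    t = t.reshape(-1, 12 * 4 * 4)
    t = F.relu(self.fc1(t))
    t = F.relu(self.fc2(t))
    return self.out(t)

# Number of predictions in a batch that match the labels
def get_num_correct(preds, labels):
  return preds.argmax(dim=1).eq(labels).sum().item()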

# Train for one epoch, iterating over the batches
network = Network()

train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

total_loss = 0.0
total_correct = 0

for batch in train_loader:
  images, labels = batch

  preds = network(images) # forward pass on one batch
  loss = F.cross_entropy(preds, labels) # compute the loss

  # zero the gradients first: PyTorch accumulates gradients by default
  optimizer.zero_grad()

  loss.backward() # compute the gradients
  optimizer.step() # update the weights using this batch's gradients

  total_loss += loss.item()
  total_correct += get_num_correct(preds, labels)

print("epoch: ", 0, "total_correct: ", total_correct, "loss: ", total_loss)
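
Because total_correct is accumulated over every batch, dividing by the size of the training set gives the epoch's training accuracy:

print("accuracy: ", total_correct / len(train_set))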

Training for multiple epochs

# Create the network, loader, and optimizer once, outside the epoch loop,
# so the weights learned in one epoch carry over to the next
network = Network()

train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

for epoch in range(6):
  # reset the running totals at the start of each epoch
  total_loss = 0.0
  total_correct = 0

  for batch in train_loader:
    images, labels = batch

    preds = network(images) # forward pass on one batch
    loss = F.cross_entropy(preds, labels) # compute the loss

    # zero the gradients first: PyTorch accumulates gradients by default
    optimizer.zero_grad()

    loss.backward() # compute the gradients
    optimizer.step() # update the weights using this batch's gradients

    total_loss += loss.item()
    total_correct += get_num_correct(preds, labels)

  print("epoch: ", epoch, "total_correct: ", total_correct, "loss: ", total_loss)
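
One caveat on the loss tally: F.cross_entropy returns the mean loss over a batch, so total_loss above is a sum of per-batch averages. To weight every sample equally even when the last batch is smaller than the rest, the accumulation line can be changed to:

    total_loss += loss.item() * images.shape[0]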

END.
