美文网首页
pytorch代码常见问题及注意点

pytorch代码常见问题及注意点

作者: callme周小伦 | 来源:发表于2019-05-31 17:33 被阅读0次

主要记录零售商品分类遇到的问题。
train代码如下:

def train_model():

    model.train()
    print("start train")
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    for epoch in range(args.num_epochs):
        train_loss =0
        correct_total=0
        max_acc =0
        last_loss =0
        
        for i, sample_batch in enumerate(train_dataloader):
            print("epoch %d:\t [%d/%d]" %(epoch, i, len(train_dataloader)))
            data_batch  = sample_batch["image"].to(device)
            label_batch = sample_batch["label"].to(device)

            output = model(data_batch)
            loss = criterion(output, label_batch.long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            predict= torch.argmax(output, 1)
            correct_total += torch.eq(predict, label_batch).sum().item()
            train_loss += loss.item()

        train_acc = correct_total / len(train_set)

        if train_acc > max_acc or train_loss < last_loss:
            print("save model at epoch {}".format(epoch))
            torch.save(model.state_dict(), os.path.join(args.model, "Classification_{}.pth").format(epoch))
            max_acc = train_acc
            last_loss = train_loss
            print("max_acc: %.5f\t last_loss: %.5f"%(max_acc, last_loss))

        # val_loss, val_acc = val()
        f = open("out.txt", "a")
        print("[%d/%d]\t train_loss:%.5f\t train_acc:%.5f\t "%
              (epoch, args.num_epochs, train_loss, train_acc), file=f)
        f.close()
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))

  1. train之前需要model.train()
  2. 对于model、data都需要进行to(device)
  3. 关于criterion = nn.CrossEntropyLoss()传入参数,第一个参数是模型输出output,是向量,而第二个参数label是groundtruth,是一个标量。
  4. correct_total += torch.eq(predict, label_batch).sum().item()
    train_loss += loss.item()
    当我们需要将模型中变量提取出来参与计算时,需要.item()
  5. 需要将print输出到文件时:print("",file=f)

相关文章

网友评论

      本文标题:pytorch代码常见问题及注意点

      本文链接:https://www.haomeiwen.com/subject/jppgtctx.html