主要记录零售商品分类遇到的问题。
train代码如下:
def train_model():
    """Train ``model`` for ``args.num_epochs`` epochs, checkpointing improvements.

    Uses module-level objects defined elsewhere: ``model``, ``args``
    (``num_epochs`` and the checkpoint directory ``model``),
    ``train_dataloader``, ``train_set``, ``device``, ``criterion`` and
    ``optimizer``.

    For every epoch the batch losses are summed into ``train_loss`` and the
    accuracy over ``train_set`` is computed. A checkpoint
    ``Classification_<epoch>.pth`` is written to ``args.model`` whenever the
    accuracy exceeds the best accuracy seen so far or the summed loss drops
    below the lowest loss seen so far. A one-line epoch summary is appended
    to ``out.txt``.
    """
    print("start train")
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    # Best-so-far trackers live OUTSIDE the epoch loop. Resetting them every
    # epoch (as before) made every epoch "improve" on accuracy, and the loss
    # test `train_loss < 0` could never succeed.
    max_acc = 0.0
    last_loss = float("inf")
    for epoch in range(args.num_epochs):
        # Re-enter training mode each epoch in case a validation step
        # (see the disabled `val()` call below) switched the model's mode.
        model.train()
        train_loss = 0
        correct_total = 0
        for i, sample_batch in enumerate(train_dataloader):
            print("epoch %d:\t [%d/%d]" % (epoch, i, len(train_dataloader)))
            data_batch = sample_batch["image"].to(device)
            label_batch = sample_batch["label"].to(device)
            output = model(data_batch)
            # The loss needs integer (long) class-index targets.
            loss = criterion(output, label_batch.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            predict = torch.argmax(output, 1)
            # .item() turns the one-element tensors into Python numbers so
            # the running totals do not keep references to the autograd graph.
            correct_total += torch.eq(predict, label_batch).sum().item()
            train_loss += loss.item()
        train_acc = correct_total / len(train_set)
        if train_acc > max_acc or train_loss < last_loss:
            print("save model at epoch {}".format(epoch))
            # Format only the file name; formatting the joined path would
            # break if the directory name itself contained braces.
            ckpt_path = os.path.join(args.model, "Classification_{}.pth".format(epoch))
            torch.save(model.state_dict(), ckpt_path)
            # Keep both trackers as best-so-far; never let one metric's
            # improvement overwrite the other's record with a worse value.
            max_acc = max(max_acc, train_acc)
            last_loss = min(last_loss, train_loss)
        print("max_acc: %.5f\t last_loss: %.5f" % (max_acc, last_loss))
        # val_loss, val_acc = val()
        with open("out.txt", "a") as f:
            print("[%d/%d]\t train_loss:%.5f\t train_acc:%.5f\t " %
                  (epoch, args.num_epochs, train_loss, train_acc), file=f)
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
- 训练（train）之前需要调用 `model.train()`，将模型切换到训练模式。
- model 和 data 都需要调用 `.to(device)`，放到同一个设备上。
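- 关于 `criterion = nn.CrossEntropyLoss()` 的传入参数：
  - 第一个参数是模型输出 `output`，即每个样本在各类别上的未归一化得分（logits）组成的向量。`CrossEntropyLoss` 内部会做 log_softmax，所以模型末尾不需要再接 softmax。
  - 第二个参数 `label` 是 ground truth。每个样本对应一个类别索引（标量），类型需为 long，所以代码中写成 `label_batch.long()`。
- 统计准确率和损失时：
  `correct_total += torch.eq(predict, label_batch).sum().item()`
  `train_loss += loss.item()`
  当需要把只含一个元素的张量取出为 Python 数值参与累加等计算时，需要使用 `.item()`。如果直接累加 `loss` 张量，会一直保留计算图，导致显存不断增长。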
- 需要将 print 输出到文件时，使用 `print("...", file=f)`，其中 `f` 是 `open()` 返回的文件对象。推荐用 `with open(...) as f:` 打开，确保文件被关闭。
网友评论