PyTorch code for computing the F1 score

Author: 默写年华Antifragile | Published: 2019-09-26 16:30
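import torch

nb_classes = 9

# model_ft, dataloaders and device are assumed to be defined earlier
# (trained model, dict of DataLoaders, torch.device).
# Rows index the true class, columns index the predicted class.
confusion_matrix = torch.zeros(nb_classes, nb_classes)
with torch.no_grad():
    for i, (inputs, classes) in enumerate(dataloaders['val']):
        inputs = inputs.to(device)
        classes = classes.to(device)
        outputs = model_ft(inputs)
        # Index of the highest score along dim 1 is the predicted class.
        _, preds = torch.max(outputs, 1)
        for t, p in zip(classes.view(-1), preds.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1

print(confusion_matrix)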

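# Rows are true labels and columns are predictions, so:
# precision = TP / (samples predicted as the class) -> divide by column sums
precision = confusion_matrix.diag() / confusion_matrix.sum(0)
print(precision)

# recall (per-class accuracy) = TP / (samples of the class) -> divide by row sums
recall = confusion_matrix.diag() / confusion_matrix.sum(1)
print(recall)

# Per-class F1. A class that is never predicted or never occurs gives 0/0 = NaN,
# which is replaced by 0 here.
f1 = 2 * precision * recall / (precision + recall)
f1[torch.isnan(f1)] = 0

# f1 is a 1-D tensor, so average it directly to get the macro F1.
mean = f1.mean()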
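
The confusion matrix above is filled as confusion_matrix[true, predicted], so row sums count the samples of each true class and column sums count the samples predicted as each class. The following is a minimal plain-Python sketch for checking that convention by hand on a small matrix. The helper name per_class_metrics and the 3x3 example matrix are made up for illustration and are not part of the original post.

```python
# Hypothetical helper: per-class precision, recall and F1 from a confusion
# matrix given as a list of lists, with cm[true][predicted] counts.
def per_class_metrics(cm):
    n = len(cm)
    metrics = []
    for k in range(n):
        tp = cm[k][k]
        row_sum = sum(cm[k])                        # samples whose true label is k
        col_sum = sum(cm[r][k] for r in range(n))   # samples predicted as k
        precision = tp / col_sum if col_sum else 0.0
        recall = tp / row_sum if row_sum else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics.append((precision, recall, f1))
    return metrics


# Made-up 3-class example: rows are true labels, columns are predictions.
cm = [
    [2, 1, 0],
    [0, 3, 1],
    [0, 0, 3],
]

metrics = per_class_metrics(cm)
macro_f1 = sum(f1 for _, _, f1 in metrics) / len(metrics)

for k, (p, r, f1) in enumerate(metrics):
    print(f"class {k}: precision={p:.3f} recall={r:.3f} f1={f1:.3f}")
print(f"macro F1 = {macro_f1:.3f}")
```

For class 0 in this example, 2 of the 3 true samples are recovered (recall 2/3) and both samples predicted as class 0 are correct (precision 1.0), which gives an F1 of 0.8. Classes with an empty row or column are assigned 0 instead of dividing by zero.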

Link to this article: https://www.haomeiwen.com/subject/zdsiyctx.html