Computing the F1 score in PyTorch (code)


Author: 默写年华Antifragile | Published: 2019-09-26 16:30
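    import torch

    nb_classes = 9

    # confusion_matrix[t, p] counts validation samples of true class t predicted as class p.
    # model_ft, dataloaders and device are assumed to be defined earlier in the script.
    confusion_matrix = torch.zeros(nb_classes, nb_classes)
    with torch.no_grad():
        for i, (inputs, classes) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            classes = classes.to(device)
            outputs = model_ft(inputs)
            _, preds = torch.max(outputs, 1)  # predicted class = index of the largest output
            for t, p in zip(classes.view(-1), preds.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1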
    
    print(confusion_matrix)
    
    # Rows are true labels and columns are predictions, so:
    # recall (per-class accuracy) = TP / row sum, precision = TP / column sum
    recall = confusion_matrix.diag()/confusion_matrix.sum(1)
    print(recall)
    
    precision = confusion_matrix.diag()/confusion_matrix.sum(0)
    print(precision)
    
    # Per-class F1; a class that never appears or is never predicted gives 0/0 = NaN here
    f1 = 2*precision*recall/(precision+recall)
    
    # Macro-averaged F1 (f1 is already a 1-D tensor, so .diagonal() would fail)
    mean = f1.mean()
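
The same computation can be checked on a small toy example without a model or dataloader. The sketch below is one possible way to do it, not the author's code: the helper names `confusion_matrix_from_labels` and `per_class_f1`, the `eps` guard against division by zero, and the six toy labels are all assumptions made for illustration. It builds the confusion matrix in one vectorized step with `torch.bincount` instead of a Python loop, and keeps the same convention as above (rows are true classes, columns are predicted classes).

```python
import torch


def confusion_matrix_from_labels(targets, preds, num_classes):
    """Hypothetical helper: rows index the true class, columns the predicted class."""
    # Encode each (true, pred) pair as one flat index, then count occurrences.
    flat = targets.long() * num_classes + preds.long()
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes).float()


def per_class_f1(cm, eps=1e-12):
    """Hypothetical helper: per-class precision, recall and F1 from a confusion matrix."""
    tp = cm.diag()
    # Column sums are the number of predictions per class; row sums are the true counts.
    precision = tp / cm.sum(0).clamp(min=eps)
    recall = tp / cm.sum(1).clamp(min=eps)
    # eps is an assumed guard so that empty classes yield 0 instead of NaN.
    f1 = 2 * precision * recall / (precision + recall).clamp(min=eps)
    return precision, recall, f1


# Toy labels (made up for illustration), 3 classes.
targets = torch.tensor([0, 0, 1, 1, 2, 2])
preds = torch.tensor([0, 1, 1, 1, 2, 0])

cm = confusion_matrix_from_labels(targets, preds, num_classes=3)
precision, recall, f1 = per_class_f1(cm)
macro_f1 = f1.mean()  # unweighted mean over classes

print(cm)
print(precision, recall, f1)
print(macro_f1.item())
```

Clamping the denominators is only one way to handle classes with no samples or no predictions; depending on the task, it may be preferable to leave those entries as NaN and exclude them from the mean.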
    


Article link: https://www.haomeiwen.com/subject/zdsiyctx.html