
Multi-class ROC Curves and a 2019/12/19 Memo

Author: yumiii_ | Published 2019-12-29 21:23

How to use tensorboardX under the PyTorch framework:
[Pytorch] PyTorch: using tensorboardX (repost)
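
Since the memo links to a tensorboardX guide, here is a minimal sketch of the usual logging pattern, assuming only the tensorboardX package itself; the log directory name and the looped values are made up for illustration.

import math
from tensorboardX import SummaryWriter

writer = SummaryWriter('runs/demo')       # hypothetical log directory
for step in range(100):
    dummy_loss = math.exp(-step / 20.0)   # placeholder value, not from a real model
    writer.add_scalar('train/loss', dummy_loss, step)
writer.close()
# view with: tensorboard --logdir runs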

textCNN paper:
https://arxiv.org/abs/1408.5882
RCNN paper (local path):
file:///Users/姓名/Downloads/9745-44425-1-PB.pdf
Text-classification survey paper (local path):
file:///Users/姓名/Downloads/information-10-00150-v2.pdf

Multi-class ROC Curves
Metric background: the clearest explanation of AUC and ROC I have seen
How ROC and AUC are computed, and the principles behind them
An introduction to ROC, with Python implementations of binary and multi-class ROC curves
https://stackoverflow.com/questions/39685740/calculate-sklearn-roc-auc-score-for-multi-class
scikit-learn official docs: how to plot multi-class ROC curves. Definitely read this one; it is the best of the bunch! (a minimal sketch of the idea follows below)
https://www.jianshu.com/p/00ef5b63dfc8
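
The common thread in these links: binarize the true labels one-vs-rest, then feed the per-class scores into roc_curve / roc_auc_score. A minimal, self-contained sketch with toy data (the class count and random scores are made up here, not the THUCNews setup below):

import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score

n_classes = 10
rng = np.random.default_rng(0)
y_true = rng.integers(0, n_classes, size=200)           # toy integer labels
y_score = rng.dirichlet(np.ones(n_classes), size=200)   # toy per-class probabilities (rows sum to 1)

y_true_bin = label_binarize(y_true, classes=list(range(n_classes)))  # one-vs-rest indicator matrix
print('micro AUC:', roc_auc_score(y_true_bin, y_score, average='micro'))
print('macro AUC:', roc_auc_score(y_true_bin, y_score, average='macro'))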

Rambling:
I feel like the essence of deep learning is just dimensions... I spent yet another evening fiddling with tensor dimensions just to plot an ROC curve, which probably shows I am not cut out for writing code.
Patched and tweaked on top of someone else's code, here is today's hello world:

    
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from itertools import cycle
def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)

    #######for roc#######
    predict_all_roc = np.empty(shape=[0,10])
    labels_all_roc = np.empty(shape=[0,10])
    y_score_roc = np.empty(shape=[0,10])

    with torch.no_grad():
        for texts, labels in data_iter:

            outputs = model(texts)
            outputs_softmax = F.softmax(outputs,dim=1)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss

            labels = labels.data.cpu().numpy()

            predic = torch.max(outputs.data, 1)[1].cpu().numpy()

            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)


            ########for roc###########
            labels_roc = label_binarize(labels, classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            labels_all_roc = np.concatenate((labels_all_roc, labels_roc), axis=0)
            predic_roc = label_binarize(predic, classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            predict_all_roc = np.concatenate((predict_all_roc,predic_roc),axis=0)
            y_score_roc = np.concatenate((y_score_roc, outputs_softmax.cpu().numpy()), axis=0)


    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        # labels_all = label_binarize(labels_all,classes=[0,1,2,3,4,5,6,7,8,9])
        fpr = {}
        tpr = {}
        roc_auc = {}
        for i in range(10):
            fpr[i], tpr[i], _ = roc_curve(labels_all_roc[:, i], y_score_roc[:,i])
            roc_auc[i] = auc(fpr[i], tpr[i])
        fpr['micro'],tpr['micro'],_ = roc_curve(labels_all_roc.ravel(),y_score_roc.ravel())
        roc_auc['micro'] = auc(fpr['micro'],tpr['micro'])
        lw = 2
        plt.figure()
        plt.plot(fpr['micro'], tpr['micro'],
                 label='micro-average ROC curve (area = {0:0.2f})'
                       ''.format(roc_auc['micro']),
                 color='deeppink', linestyle='--', linewidth=4)
        colors = cycle(['aqua', 'darkorange', 'cornflowerblue','azure','seashell','peru','mintcream','linen','indigo','gold'])

        for i, color in zip(range(10), colors):
            plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                     label='ROC curve of class {0} (area = {1:0.2f})'
                           ''.format(i, roc_auc[i]))
        plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic example')
        plt.legend(loc="lower right")
        plt.show()
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion

    return acc, loss_total / len(data_iter)
# ---- driver: run evaluate() on the test set ----
from importlib import import_module
from utils import build_dataset, build_iterator, get_time_dif
x = import_module('models.TextRNN_Att')
dataset = 'THUCNews'
embedding= 'random'

config = x.Config(dataset, embedding)
# print(vars(config))
vocab, train_data, dev_data, test_data = build_dataset(config,False)
#
test_iter = build_iterator(test_data, config)
model = x.Model(config).to(config.device)
# print(vars(model))
acc, test_loss, report, confusion = evaluate(config, model, test_iter, test=True)

Also, I am genuinely baffled: the AUC comes out at 0.99. To make sure nothing was off somewhere I even printed it for every class. Is that plausible?? Does anyone know what is going on here? It is driving me crazy.
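
One way to sanity-check that 0.99 is to compare the hand-rolled micro-average against sklearn's own roc_auc_score on the same arrays; a sketch, assuming labels_all_roc and y_score_roc are returned (or saved) from evaluate() above:

from sklearn.metrics import roc_auc_score

# labels_all_roc: binarized true labels, y_score_roc: softmax scores, both built inside evaluate()
print('micro AUC:', roc_auc_score(labels_all_roc, y_score_roc, average='micro'))
print('macro AUC:', roc_auc_score(labels_all_roc, y_score_roc, average='macro'))
# Note: a high one-vs-rest AUC is not by itself suspicious; ROC is threshold-free, and with
# 10 classes most negatives in each one-vs-rest split are easy to rank below the positives.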


