ROC Curves for Multi-class Classification, plus Notes from 2019/12/19

Author: yumiii_ | Published 2019-12-29 21:23

    How to use tensorboardX with PyTorch:
    [Pytorch] Using tensorboardX in PyTorch (repost)

    textCNN paper:
    https://arxiv.org/abs/1408.5882
    RCNN paper (local path):
    file:///Users/姓名/Downloads/9745-44425-1-PB.pdf
    Text classification survey paper (local path):
    file:///Users/姓名/Downloads/information-10-00150-v2.pdf

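    ROC curves for multi-class classification:
    Metric introduction: the most thorough explanation of AUC and ROC I have seen
    How ROC and AUC are computed, and the principle behind them
    ROC principles, with Python implementations of binary and multi-class ROC curves
    https://stackoverflow.com/questions/39685740/calculate-sklearn-roc-auc-score-for-multi-class
    scikit-learn official docs: how to plot multi-class ROC curves. Read this one! It's the best one!!!
    https://www.jianshu.com/p/00ef5b63dfc8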
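    All of the references above handle the multi-class case the same way. They binarize the labels and draw one one-vs-rest (OvR) ROC curve per class. Below is a minimal sketch on synthetic labels and scores; these are not results from this post. It assumes scikit-learn 0.22 or later, where `roc_auc_score` accepts `multi_class='ovr'`. It computes the per-class AUCs by hand and checks their plain mean against the library's macro average. It also gets the micro average by passing the one-hot label matrix.

```python
import numpy as np
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.preprocessing import label_binarize

rng = np.random.default_rng(0)
n_samples, n_classes = 500, 10

# Hypothetical data: random labels, and logits that favour the true class
y_true = rng.integers(0, n_classes, size=n_samples)
logits = rng.normal(size=(n_samples, n_classes))
logits[np.arange(n_samples), y_true] += 2.0

# Softmax, so each row sums to 1 (roc_auc_score checks this for multi-class input)
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
y_score = exp / exp.sum(axis=1, keepdims=True)

# One-hot labels, shape (n_samples, n_classes), like labels_all_roc in the post
y_bin = label_binarize(y_true, classes=list(range(n_classes)))

# Manual one-vs-rest: one binary ROC curve and AUC per class, then the plain mean
per_class_auc = []
for i in range(n_classes):
    fpr, tpr, _ = roc_curve(y_bin[:, i], y_score[:, i])
    per_class_auc.append(auc(fpr, tpr))
manual_macro = float(np.mean(per_class_auc))

# Library shortcuts: macro OvR from integer labels, micro from the one-hot matrix
macro_ovr = roc_auc_score(y_true, y_score, multi_class="ovr", average="macro")
micro = roc_auc_score(y_bin, y_score, average="micro")

print(f"manual macro={manual_macro:.4f}  sklearn macro={macro_ovr:.4f}  micro={micro:.4f}")
```

    The manual loop is what you need for plotting, because it keeps each class's `fpr`/`tpr` arrays. The one-line `roc_auc_score` calls are handy when only the number is required.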

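    Rambling:
    I think the essence of deep learning is just dimensions... I spent yet another evening wrestling with array dimensions, all just to draw an ROC curve. Clearly I'm not cut out for writing code.
    Patched and tweaked on top of an expert's code, here is today's hello world: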

        
    import numpy as np
    import torch
    import torch.nn.functional as F
    import matplotlib.pyplot as plt
    from itertools import cycle
    from sklearn import metrics
    from sklearn.metrics import roc_curve, auc
    from sklearn.preprocessing import label_binarize


    def evaluate(config, model, data_iter, test=False):
        model.eval()
        loss_total = 0.0
        predict_all = np.array([], dtype=int)
        labels_all = np.array([], dtype=int)

        ####### for ROC #######
        n_classes = len(config.class_list)  # 10 for THUCNews
        classes = list(range(n_classes))
        labels_all_roc = np.empty(shape=[0, n_classes])  # one-hot true labels
        y_score_roc = np.empty(shape=[0, n_classes])     # softmax probabilities

        with torch.no_grad():
            for texts, labels in data_iter:
                outputs = model(texts)
                loss = F.cross_entropy(outputs, labels)
                loss_total += loss.item()

                # Scores must be moved off the GPU and converted to numpy before concatenating
                outputs_softmax = F.softmax(outputs, dim=1).cpu().numpy()
                labels = labels.cpu().numpy()
                predic = torch.max(outputs, 1)[1].cpu().numpy()

                labels_all = np.append(labels_all, labels)
                predict_all = np.append(predict_all, predic)

                ######## for ROC ###########
                labels_roc = label_binarize(labels, classes=classes)  # (batch, n_classes)
                labels_all_roc = np.concatenate((labels_all_roc, labels_roc), axis=0)
                y_score_roc = np.concatenate((y_score_roc, outputs_softmax), axis=0)

        acc = metrics.accuracy_score(labels_all, predict_all)
        if test:
            fpr, tpr, roc_auc = {}, {}, {}
            # One-vs-rest ROC curve and AUC for each class
            for i in range(n_classes):
                fpr[i], tpr[i], _ = roc_curve(labels_all_roc[:, i], y_score_roc[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])
            # Micro-average: pool every (sample, class) pair into one binary problem
            fpr['micro'], tpr['micro'], _ = roc_curve(labels_all_roc.ravel(), y_score_roc.ravel())
            roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])

            lw = 2
            plt.figure()
            # Plot the micro-average curve (the original plotted class 2 under this label)
            plt.plot(fpr['micro'], tpr['micro'],
                     label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc['micro']),
                     color='deeppink', linestyle='--', linewidth=4)
            # Near-white colors (azure, seashell, mintcream, linen) replaced so every curve is visible
            colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green', 'red',
                            'peru', 'purple', 'olive', 'indigo', 'gold'])
            for i, color in zip(range(n_classes), colors):
                plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                         label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
            plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance line
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Receiver operating characteristic example')
            plt.legend(loc="lower right")
            plt.show()

            report = metrics.classification_report(labels_all, predict_all,
                                                   target_names=config.class_list, digits=4)
            confusion = metrics.confusion_matrix(labels_all, predict_all)
            return acc, loss_total / len(data_iter), report, confusion

        return acc, loss_total / len(data_iter)


    from importlib import import_module
    from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.TextRNN_Att')
    dataset = 'THUCNews'
    embedding = 'random'

    config = x.Config(dataset, embedding)
    vocab, train_data, dev_data, test_data = build_dataset(config, False)
    test_iter = build_iterator(test_data, config)
    model = x.Model(config).to(config.device)
    # Assumption: trained weights were saved to config.save_path, as the original
    # repo's test() loads them; without this the model is randomly initialized
    model.load_state_dict(torch.load(config.save_path))
    # The function defined above is evaluate(); test=True turns on the ROC plot and report
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    print(test_report)
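    The evaluate() function above computes per-class and micro-average ROC curves. The scikit-learn example recommended earlier also draws a macro-average curve. It interpolates every class's TPR onto a shared FPR grid and averages the results. Below is a minimal sketch of that step on synthetic scores. The data and the helper name `macro_average_roc` are hypothetical.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize


def macro_average_roc(y_bin, y_score):
    """Macro-average ROC: average the per-class TPR curves on a common FPR grid."""
    n_classes = y_bin.shape[1]
    fpr, tpr = {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], y_score[:, i])
    # Union of every FPR value seen in any per-class curve (np.unique also sorts it)
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    # Linearly interpolate each class's TPR at those points, then take the mean
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes
    return all_fpr, mean_tpr, auc(all_fpr, mean_tpr)


# Hypothetical data: 3 classes, noisy scores that favour the true class
rng = np.random.default_rng(1)
y_true = rng.integers(0, 3, size=300)
y_score = rng.random((300, 3))
y_score[np.arange(300), y_true] += 0.5
y_bin = label_binarize(y_true, classes=[0, 1, 2])

all_fpr, mean_tpr, macro_auc = macro_average_roc(y_bin, y_score)
print(f"macro-average AUC = {macro_auc:.3f}")
```

    In evaluate() the same helper could be called as `macro_average_roc(labels_all_roc, y_score_roc)`. Its result would then be added to the figure with one more `plt.plot(all_fpr, mean_tpr, ...)`. Macro averaging weights every class equally. Micro averaging weights every (sample, class) pair equally.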
    

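    Also, I'm completely baffled: the AUC comes out at 0.99. To rule out a mistake somewhere I even printed out every single class. What on earth? Is this legit???? Does any expert know what's going on... is it plausible or not... I'm losing my mind.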
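    One possible explanation for the 0.99 is that ROC AUC measures ranking, not accuracy. For each one-vs-rest curve, it equals the probability that a randomly chosen sample of that class scores higher for the class than a randomly chosen sample of any other class. Ties count as half. With ten classes, most of each class's negatives come from other topics that the model can often push to near-zero probability. So a reasonably accurate classifier can plausibly reach per-class AUCs near 0.99 without any bug. The sketch below checks this ranking interpretation on synthetic scores; the helper `pairwise_auc` and the data are hypothetical.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def pairwise_auc(y_true, score):
    """AUC as P(score of a random positive > score of a random negative), ties count 0.5."""
    pos = score[y_true == 1]
    neg = score[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive/negative pair
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size


rng = np.random.default_rng(2)
# Hypothetical one-vs-rest slice: 100 positives vs 900 negatives (1 class vs 9 others)
y_true = np.r_[np.ones(100, dtype=int), np.zeros(900, dtype=int)]
score = np.r_[rng.normal(2.5, 1.0, 100), rng.normal(0.0, 1.0, 900)]

print(pairwise_auc(y_true, score), roc_auc_score(y_true, score))
```

    The two numbers should agree. A per-class AUC close to 1 therefore only says that positives are almost always ranked above negatives. That is compatible with a confusion matrix that still shows errors between similar classes.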



          Article link: https://www.haomeiwen.com/subject/svcuoctx.html