An Introduction to ROC for Models That Predict a Three-Class Variable

Author: jamesjin63 | Published 2022-03-11 17:46

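    Most of us are familiar with logistic regression, where the outcome y is a binary variable. To evaluate the predictions we build a 2×2 confusion matrix, compute sensitivity, specificity and the ROC curve, and judge how accurately the model predicts.

    But if y has three classes, the confusion matrix becomes 3×3. Which metrics should be used to evaluate the model then?

    Answer: macro-average and micro-average.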
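
    For reference, the two averages can be written as follows for K classes. The notation is introduced here for illustration and is not from the original post: TP_k, FP_k and FN_k are the one-vs-all true positives, false positives and false negatives of class k.

```latex
\begin{aligned}
P_k &= \frac{TP_k}{TP_k + FP_k}, \qquad R_k = \frac{TP_k}{TP_k + FN_k} \\
\text{Precision}_{\text{macro}} &= \frac{1}{K}\sum_{k=1}^{K} P_k, \qquad
\text{Recall}_{\text{macro}} = \frac{1}{K}\sum_{k=1}^{K} R_k \\
\text{Precision}_{\text{micro}} &= \frac{\sum_{k=1}^{K} TP_k}{\sum_{k=1}^{K} (TP_k + FP_k)}, \qquad
\text{Recall}_{\text{micro}} = \frac{\sum_{k=1}^{K} TP_k}{\sum_{k=1}^{K} (TP_k + FN_k)}
\end{aligned}
```

    Macro-averaging gives every class equal weight, while micro-averaging pools the counts and so gives every observation equal weight. When each observation has exactly one true and one predicted class, micro precision and micro recall both equal the overall accuracy.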

    Below we show how to build a model that predicts a three-class variable and how to evaluate its accuracy.

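    1. Model building

    Using the three-class variable Species in the iris dataset, we fit a multinomial logistic regression model that predicts Species from the flower measurements, with one newly added variable, xv.
    First we split the iris dataset into a Training set and a Testing set; Training is used to fit the model.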

    # https://stackoverflow.com/questions/59205776/random-forest-svm-and-multinomial-logistic-regression-with-r
    
    library(tidyverse)
    library(randomForest)
    set.seed(123)
    head(iris)
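    # add a random two-level factor xv as a new predictor
    df = iris %>% mutate(xv = as.factor(ifelse(rnorm(150, 3, 4) < 3, "Yes", "No")))
    ## split data: 70% training, 30% testing
    # round() guards against floating-point truncation of the group sizes
    split1 = sample(c(rep(0, round(0.7 * nrow(df))), rep(1, round(0.3 * nrow(df)))))
    train <- df[split1 == 0, ]
    test <- df[split1 == 1, ]
    
    ## Multinomial logistic regression model
    library(nnet)
    fit1 = multinom(Species ~ ., data = train)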
    summary(fit1)
    
    

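    The output of fit1 is read like that of a binary logistic regression, except that there is one more outcome category: each non-reference class gets its own set of coefficients, whose exponentiated values are interpreted as odds ratios (OR) against the reference class.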
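
    As a minimal sketch of that interpretation, the coefficients can be exponentiated to obtain odds ratios. The model below is fitted on the full iris data purely so the snippet is self-contained; fit_demo and or are names chosen for this illustration, not objects from the post.

```r
library(nnet)

# Fit on the full iris data purely for illustration (the post fits on `train`).
fit_demo <- multinom(Species ~ ., data = iris, trace = FALSE)

# coef() has one row per non-reference class (versicolor, virginica);
# setosa, the first factor level, is the reference class.
# Exponentiating the log-odds coefficients gives odds ratios per one-unit
# increase in each predictor, relative to setosa.
or <- exp(coef(fit_demo))
or
```

    Because the iris classes are almost perfectly separable, the resulting odds ratios are extreme; on less separable data they are easier to read.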

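    2. Observed vs. predicted values: the confusion matrix

    With fit1 built, we predict on the testing data and then cross-tabulate the actual values against the predicted values.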

    ## Model Prediction
    pre=predict(fit1,test)
    dfpre=tibble(actual=test$Species,predicted=pre)
    table(dfpre)
    
                predicted
    actual       setosa versicolor virginica
      setosa         13          0         0
      versicolor      0         13         0
      virginica       0          1        18
    

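    3. Performance Measures

    Next we analyze this matrix. A few basic quantities have to be computed from it first; they feed into the accuracy, precision, F1 and related measures that follow.
    Source: https://www.r-bloggers.com/2016/03/computing-classification-evaluation-metrics-in-r/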

    ## basic variables
    cm = as.matrix(table(dfpre)) # confusion matrix: rows = actual, columns = predicted
    n = sum(cm) # number of instances
    nc = nrow(cm) # number of classes
    diag = diag(cm) # number of correctly classified instances per class 
    rowsums = apply(cm, 1, sum) # number of instances per class
    colsums = apply(cm, 2, sum) # number of predictions per class
    p = rowsums / n # distribution of instances over the actual classes
    q = colsums / n # distribution of instances over the predicted classes
    
    ## Accuracy
    accuracy = sum(diag) / n 
    accuracy 
    
    ## Per-class precision, recall and F1
    precision = diag / colsums 
    recall = diag / rowsums 
    f1 = 2 * precision * recall / (precision + recall) 
    data.frame(precision, recall, f1) 
    
    ## Macro
    macroPrecision = mean(precision)
    macroRecall = mean(recall)
    macroF1 = mean(f1)
    data.frame(macroPrecision, macroRecall, macroF1)
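
    The micro-averaged counterparts can be sketched in the same way. The snippet below re-enters the confusion matrix printed in section 2 so that it runs on its own; tp, fp and fn are names introduced here for the one-vs-all counts.

```r
# Confusion matrix copied from the table in section 2
# (rows = actual, columns = predicted)
classes <- c("setosa", "versicolor", "virginica")
cm <- matrix(c(13,  0,  0,
                0, 13,  0,
                0,  1, 18),
             nrow = 3, byrow = TRUE,
             dimnames = list(actual = classes, predicted = classes))

tp <- diag(cm)            # true positives per class
fp <- colSums(cm) - tp    # predicted as class k but actually another class
fn <- rowSums(cm) - tp    # actually class k but predicted as another class

## Micro: pool the counts over classes before dividing
microPrecision <- sum(tp) / sum(tp + fp)
microRecall    <- sum(tp) / sum(tp + fn)
microF1 <- 2 * microPrecision * microRecall / (microPrecision + microRecall)
data.frame(microPrecision, microRecall, microF1)
```

    All three values equal 44/45, the overall accuracy, which is expected when every observation has exactly one actual and one predicted class.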
    
    

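    The calculation above is rather tedious. Is there a way to get everything in one call? Yes, and it follows.

    3.1 Performance Measures in one call

    Here we use the Evaluate function to produce the output; its code is available at the link that follows. Source:https://github.com/saidbleik/Evaluation/blob/master/eval.R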

    # Evaluate() is defined in eval.R from the link above; load it before this call
    results = Evaluate(actual=test$Species, predicted=pre)
    results
    ## output
    $ConfusionMatrix
                Predicted
    Actual       setosa versicolor virginica
      setosa         13          0         0
      versicolor      0         13         0
      virginica       0          1        18
    
    $Metrics
                                    setosa versicolor virginica
    Accuracy                     0.9777778  0.9777778 0.9777778
    Precision                    1.0000000  0.9285714 1.0000000
    Recall                       1.0000000  1.0000000 0.9473684
    F1                           1.0000000  0.9629630 0.9729730
    MacroAvgPrecision            0.9761905  0.9761905 0.9761905
    MacroAvgRecall               0.9824561  0.9824561 0.9824561
    MacroAvgF1                   0.9786453  0.9786453 0.9786453
    AvgAccuracy                  0.9851852  0.9851852 0.9851852
    MicroAvgPrecision            0.9777778  0.9777778 0.9777778
    MicroAvgRecall               0.9777778  0.9777778 0.9777778
    MicroAvgF1                   0.9777778  0.9777778 0.9777778
    MajorityClassAccuracy        0.4222222  0.4222222 0.4222222
    MajorityClassPrecision       0.0000000  0.0000000 0.4222222
    MajorityClassRecall          0.0000000  0.0000000 1.0000000
    MajorityClassF1              0.0000000  0.0000000 0.5937500
    Kappa                        0.9662162  0.9662162 0.9662162
    RandomGuessAccuracy          0.3333333  0.3333333 0.3333333
    RandomGuessPrecision         0.2888889  0.2888889 0.4222222
    RandomGuessRecall            0.3333333  0.3333333 0.3333333
    RandomGuessF1                0.3095238  0.3095238 0.3725490
    RandomWeightedGuessAccuracy  0.3451852  0.3451852 0.3451852
    RandomWeightedGuessPrecision 0.2888889  0.2888889 0.4222222
    RandomWeightedGuessRecall    0.2888889  0.2888889 0.4222222
    RandomWeightedGuessF1        0.2888889  0.2888889 0.4222222
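
    Two of the less familiar rows, Kappa and AvgAccuracy, can be reproduced by hand from the confusion matrix. This is a sketch under the standard definitions (Cohen's kappa, and the mean of the one-vs-all accuracies); it has not been checked line by line against eval.R, and the variable names are chosen for this illustration.

```r
# Confusion matrix copied from the output above
# (rows = actual, columns = predicted)
classes <- c("setosa", "versicolor", "virginica")
cm <- matrix(c(13,  0,  0,
                0, 13,  0,
                0,  1, 18),
             nrow = 3, byrow = TRUE,
             dimnames = list(actual = classes, predicted = classes))

n <- sum(cm)
p <- rowSums(cm) / n   # distribution over the actual classes
q <- colSums(cm) / n   # distribution over the predicted classes

# Cohen's kappa: observed agreement corrected for chance agreement
po <- sum(diag(cm)) / n
pe <- sum(p * q)
kappa_hat <- (po - pe) / (1 - pe)

# Average accuracy: mean of the three one-vs-all accuracies
tp <- diag(cm)
tn <- n - rowSums(cm) - colSums(cm) + tp
avgAccuracy <- mean((tp + tn) / n)

c(Kappa = kappa_hat, AvgAccuracy = avgAccuracy)
```

    The results agree with the Kappa (0.9662162) and AvgAccuracy (0.9851852) rows of the output.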
    

    4. ROC Curves Across Multi-Class Classifications

    We can also plot the micro-average and macro-average ROC curves, which indicate the overall discriminative ability of the three-class classifier. This takes a few steps:

    1. So far the predictions were Species class labels; for ROC analysis we need the predicted probability of each class instead.
    2. Dummy coding: convert the Species labels in the testing set into dummy (0/1) variables.
    3. Compute the macro/micro averages and plot the ROC curves.
      Source:https://mran.microsoft.com/snapshot/2018-02-12/web/packages/multiROC/vignettes/my-vignette.html
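
    The three steps can be sketched with nnet and base R alone, which may help to see what is being computed. Everything in this snippet is an assumption made for illustration: the split by idx, the helper auc_rank (a rank-based, Mann-Whitney AUC), the macro AUC taken as a plain mean of the per-class AUCs, and the micro AUC taken from the pooled label/score pairs. The values reported by multiROC may differ, since it works from the ROC curves themselves.

```r
library(nnet)

set.seed(123)
idx   <- sample(nrow(iris), 105)   # hypothetical 70/30 split for this sketch
train <- iris[idx, ]
test  <- iris[-idx, ]
fit   <- multinom(Species ~ ., data = train, trace = FALSE)

# Step 1: class probabilities instead of class labels
prob <- predict(fit, test, type = "probs")

# Step 2: 0/1 indicator (dummy) columns for the true class
truth <- sapply(levels(test$Species),
                function(k) as.integer(test$Species == k))

# Step 3: one-vs-all AUC via the rank (Mann-Whitney) formula
auc_rank <- function(y, score) {
  r  <- rank(score)                # average ranks handle ties
  n1 <- sum(y == 1)
  n0 <- sum(y == 0)
  (sum(r[y == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}
auc_class <- sapply(colnames(truth),
                    function(k) auc_rank(truth[, k], prob[, k]))
auc_macro <- mean(auc_class)                              # equal weight per class
auc_micro <- auc_rank(as.vector(truth), as.vector(prob))  # pool all label/score pairs

round(c(auc_class, macro = auc_macro, micro = auc_micro), 4)
```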

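    One concept needs to be introduced here: one-vs-all confusion matrices.
    The three-class variable is recoded as setosa versus non-setosa, which yields a ROC curve for setosa; the other two classes are handled the same way.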
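
    As an illustration, the 3×3 matrix from section 2 can be collapsed into three 2×2 one-vs-all matrices. The matrix is re-entered here so the snippet runs on its own, and one_vs_all is a name chosen for this sketch.

```r
# Confusion matrix copied from section 2 (rows = actual, columns = predicted)
classes <- c("setosa", "versicolor", "virginica")
cm <- matrix(c(13,  0,  0,
                0, 13,  0,
                0,  1, 18),
             nrow = 3, byrow = TRUE,
             dimnames = list(actual = classes, predicted = classes))

# For each class k, treat k as "positive" and everything else as "other"
one_vs_all <- lapply(classes, function(k) {
  tp <- cm[k, k]                 # actual k, predicted k
  fn <- sum(cm[k, ]) - tp        # actual k, predicted other
  fp <- sum(cm[, k]) - tp        # actual other, predicted k
  tn <- sum(cm) - tp - fn - fp   # actual other, predicted other
  matrix(c(tp, fn, fp, tn), nrow = 2, byrow = TRUE,
         dimnames = list(actual = c(k, "other"), predicted = c(k, "other")))
})
names(one_vs_all) <- classes
one_vs_all$virginica
```

    For virginica this gives 18 true positives, 1 false negative, 0 false positives and 26 true negatives.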

    library(multiROC)
    # step 2: dummy-code the true Species labels
    actual=dummies::dummy.data.frame(test %>% select(Species),               
                                     sep = "_",            
                                     dummy.classes = "factor"  )
    
    # step 1: predicted class probabilities instead of class labels
    predicted=predict(fit1,test,type = "probs")
    
    
    # multi_roc() expects columns named <class>_true and <class>_pred_<method>
    test_data=cbind(actual,predicted)
    colnames(test_data)=c("setosa_true","versicolor_true" ,"virginica_true",
                          "setosa_pred_m1","versicolor_pred_m1","virginica_pred_m1")
    res <- multi_roc(test_data, force_diag=TRUE)
    res
    
    

    res holds the information we need. Next we extract the Specificity and Sensitivity of each group from res and plot the ROC curves.

    #### ggplot ROC
    n_method <- length(unique(res$Methods))
    n_group <- length(unique(res$Groups))
    res_df <- data.frame(Specificity= numeric(0), Sensitivity= numeric(0), Group = character(0), AUC = numeric(0), Method = character(0))
    for (i in 1:n_method) {
      for (j in 1:n_group) {
        temp_data_1 <- data.frame(Specificity=res$Specificity[[i]][j],
                                  Sensitivity=res$Sensitivity[[i]][j],
                                  Group=unique(res$Groups)[j],
                                  AUC=res$AUC[[i]][j],
                                  Method = unique(res$Methods)[i])
        colnames(temp_data_1) <- c("Specificity", "Sensitivity", "Group", "AUC", "Method")
        res_df <- rbind(res_df, temp_data_1)
        
      }
      temp_data_2 <- data.frame(Specificity=res$Specificity[[i]][n_group+1],
                                Sensitivity=res$Sensitivity[[i]][n_group+1],
                                Group= "Macro",
                                AUC=res$AUC[[i]][n_group+1],
                                Method = unique(res$Methods)[i])
      temp_data_3 <- data.frame(Specificity=res$Specificity[[i]][n_group+2],
                                Sensitivity=res$Sensitivity[[i]][n_group+2],
                                Group= "Micro",
                                AUC=res$AUC[[i]][n_group+2],
                                Method = unique(res$Methods)[i])
      colnames(temp_data_2) <- c("Specificity", "Sensitivity", "Group", "AUC", "Method")
      colnames(temp_data_3) <- c("Specificity", "Sensitivity", "Group", "AUC", "Method")
      res_df <- rbind(res_df, temp_data_2)
      res_df <- rbind(res_df, temp_data_3)
    }
    
    ggplot(res_df, aes(x = 1-Specificity, y=Sensitivity)) + 
      geom_path(aes(color = Group, linetype=Method)) + 
      geom_segment(aes(x = 0, y = 0, xend = 1, yend = 1), colour='grey', linetype = 'dotdash') + 
      theme_bw() + 
      theme(plot.title = element_text(hjust = 0.5), 
            legend.justification=c(1, 0), 
            legend.position=c(.95, .05), 
            legend.title=element_blank(), 
            legend.background = element_rect(fill=NULL, size=0.5, linetype="solid", colour ="black"))
    
    ggsave("ROC-SVM.pdf",width = 16,height = 12,dpi=500)
    

    Finally, here are the RF and SVM models.

    #### 2. SVM
    library(e1071)
    fitsvm = svm(Species ~ ., data = train, probability = TRUE)
    summary(fitsvm)
    
    
    #### 3. RF
    library(randomForest)
    fitrf = randomForest(Species ~ .,
                         data = train,
                         ntree = 300, # parameter setting
                         mtry = 2,    # must not exceed the number of predictors (5 here)
                         importance = TRUE,
                         proximity = TRUE)
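
    To feed these two models into the same ROC workflow, their class probabilities are needed. The sketch below shows one way to extract them; it uses its own split of iris (idx is an assumption of this snippet, not the split from the post) so that it runs on its own. Note that the SVM probability columns are not guaranteed to follow the factor level order, hence the explicit reordering.

```r
library(e1071)
library(randomForest)

set.seed(123)
idx   <- sample(nrow(iris), 105)   # hypothetical 70/30 split for this sketch
train <- iris[idx, ]
test  <- iris[-idx, ]
lev   <- levels(train$Species)

# SVM: probabilities come back as an attribute of the predicted factor
fitsvm   <- svm(Species ~ ., data = train, probability = TRUE)
pred_svm <- predict(fitsvm, test, probability = TRUE)
prob_svm <- attr(pred_svm, "probabilities")[, lev]   # reorder columns by factor level

# Random forest: type = "prob" returns the per-class vote fractions
fitrf   <- randomForest(Species ~ ., data = train, ntree = 300)
prob_rf <- predict(fitrf, test, type = "prob")[, lev]

# Column names in the <class>_pred_<method> pattern used in section 4
colnames(prob_svm) <- paste0(lev, "_pred_SVM")
colnames(prob_rf)  <- paste0(lev, "_pred_RF")
head(cbind(prob_svm, prob_rf))
```

    Binding these columns next to the dummy-coded true labels gives one data frame with two methods, which is the layout the section 4 code loops over.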
    
    

    Reference:
    Performance Measures for Multi-Class Problems--
    https://www.datascienceblog.net/post/machine-learning/performance-measures-multi-class-problems/
