美文网首页
K-近邻取样

K-近邻取样

作者: 北欧森林 | 来源:发表于2021-05-07 05:05 被阅读0次
    1. 简介


      image.png
    2. 模拟数据

    # Simulate 200 points uniformly on [0, 100] x [0, 100] and attach a noisy
    # three-class label:
    #   class 1 when x1 + x2 exceeds 100 (plus noise);
    #   otherwise class 2 when 100 - 2 * x2 + noise < 2, else class 3.
    # The nested ifelse() keeps the random draws in their original order, so the
    # seed reproduces the same data.
    set.seed(888)
    df1 <- data.frame(x1 = runif(200, 0, 100), x2 = runif(200, 0, 100))
    df1$y <- with(df1, factor(1 + ifelse(
      100 - x1 - x2 + rnorm(200, sd = 10) < 0,
      0,
      ifelse(100 - 2 * x2 + rnorm(200, sd = 10) < 2, 1, 2)
    )))
    # Rows 1-150 form the training set, rows 151-200 the test set.
    df1$tag <- rep(c("train", "test"), times = c(150, 50))
    str(df1)

    # 'data.frame': 200 obs. of  4 variables:
    #  $ x1 : num  2.55 34.67 6.12 68.38 76.73 ...
    #  $ x2 : num  76.73 88.59 19.9 8.54 11.36 ...
    #  $ y  : Factor w/ 3 levels "1","2","3": 2 1 3 3 3 3 1 3 3 1 ...
    #  $ tag: chr  "train" "train" "train" "train" ...
    
    3. 查看模拟的数据
    library(ggplot2)
    # Scatter plot of the simulated data: colour shows the class label, shape
    # shows the train/test split. Uses ggplot() because qplot() is deprecated
    # as of ggplot2 3.4.0. Printed explicitly so it also renders under source().
    print(
      ggplot(df1, aes(x = x1, y = x2, colour = y, shape = tag)) +
        geom_point()
    )
    
    image.png
    4. 整理数据并训练模型
    library(class)
    # Split on the `tag` column rather than hard-coded row ranges, and select
    # columns by name rather than position, so the partition always matches how
    # the rows were labelled.
    is_train <- df1$tag == "train"
    train <- df1[is_train, c("x1", "x2")]
    train.label <- df1$y[is_train]
    test <- df1[!is_train, c("x1", "x2")]
    test.label <- df1$y[!is_train]

    # knn() breaks voting ties at random, so predictions can change between runs
    # unless the RNG state is the same.
    pred <- knn(train = train, test = test, cl = train.label, k = 6)
    pred # predicted class for each observation in the test set

    # [1] 3 1 1 3 1 3 2 1 1 3 3 3 2 1 3 1 1 2 3 1 1 1 1 1 1 2 1 1 1 1 1 1 1 3 1 1 2 3 3
    # [40] 1 1 1 2 1 2 3 1 1 1 1
    # Levels: 1 2 3
    
    5. 对拟合结果的评估观察
    # install.packages("gmodels")
    library(gmodels)
    # Cross-tabulate true test labels (rows) against kNN predictions (columns),
    # without the per-cell chi-square contributions.
    CrossTable(test.label, pred, prop.chisq = FALSE)
    
    image.png
    6. 拟合准确性的评估
    # Confusion matrix: rows are the true test labels, columns the kNN predictions.
    # Base table() gives the same counts as CrossTable()$t without re-printing
    # the cross-table (and its chi-square contributions).
    cm <- table(test.label, pred)

    # One-vs-rest counts for every class at once. Works for any number of classes,
    # assuming test.label and pred share the same factor levels (they do here).
    tp <- diag(cm)                # truly k and predicted k
    fn <- rowSums(cm) - tp        # truly k, predicted another class
    fp <- colSums(cm) - tp        # predicted k, truly another class
    tn <- sum(cm) - tp - fn - fp  # neither truly k nor predicted k

    # Counts for class "1", used for sensitivity/specificity.
    tp1 <- tp[[1]]
    tn1 <- tn[[1]]
    fn1 <- fn[[1]]
    fp1 <- fp[[1]]

    # Mean over classes of the one-vs-rest accuracy (tp + tn) / n. This is not
    # the overall accuracy sum(diag(cm)) / sum(cm), which is 45/50 = 0.9 here.
    accuracy <- mean((tp + tn) / sum(cm))

    accuracy
    #[1] 0.9333333
    
    7. 敏感性和特异性评估
    
    # Sensitivity (recall) and specificity for class "1" taken one-vs-rest;
    # wrapping each assignment in parentheses also prints the value.
    (sen1 <- tp1 / (tp1 + fn1))
    # [1] 1
    (sp1 <- tn1 / (tn1 + fp1))
    #[1] 0.9047619
    
    8. Multiclass area under the curve (AUC)
    library(pROC)
    # Multi-class AUC from pROC::multiclass.roc().
    # NOTE(review): the predictor is the hard predicted class (as an ordered
    # factor), not a per-class score, so this AUC is only a coarse summary.
    multiclass.roc(predictor = as.ordered(pred), response = test.label)
    # Call:
    #   multiclass.roc.default(response = test.label, predictor = as.ordered(pred))
    #
    # Data: as.ordered(pred) with 3 levels of test.label: 1, 2, 3.
    # Multi-class area under the curve: 0.9212
    
    9. Kappa statistic
      手动计算
    # Confusion matrix: rows are the true labels, columns the predictions.
    table <- table(test.label, pred)
    table
    # pred
    # test.label  1  2  3
    #   1 29  0  0
    #   2  2  6  2
    #   3  0  1 10

    # Cohen's kappa by hand: observed agreement p_o compared with the agreement
    # p_e expected by chance from the row and column totals.
    n_obs <- sum(table)
    p_o <- sum(diag(table)) / n_obs
    p_e <- sum(rowSums(table) * colSums(table)) / n_obs^2
    kappa_manual <- (p_o - p_e) / (1 - p_e)
    kappa_manual
    # [1] 0.8213009
    
    image.png

    自动计算kappa statistic

    # install.packages("psych")
    library(psych)
    # Each column is one "rater": the true test labels and the kNN predictions
    # (cbind() turns both factors into their integer class codes). Read the
    # unweighted kappa. The weighted kappa scores disagreements by the distance
    # between class codes, which only makes sense for ordered classes.
    cohen.kappa(cbind(test.label, pred))
    # Call: cohen.kappa1(x = x, w = w, n.obs = n.obs, alpha = alpha, levels = levels)
    #
    # Cohen Kappa and Weighted Kappa correlation coefficients and confidence boundaries
    #                   lower estimate upper
    # unweighted kappa  0.68     0.82  0.96
    # weighted kappa    0.93     0.93  0.93
    #
    # Number of subjects = 50
    
    10. 调整k值对knn模型预测准确性的影响
    # Mean one-vs-rest accuracy of kNN on the test set for k = 1, ..., N.
    #
    # For each k, kNN is fitted on train_x / train_y and its test predictions are
    # cross-tabulated against test_y. For every class, the share of test
    # observations correctly placed in or out of that class is computed; the
    # value for k is the mean of these shares over all classes. Any number of
    # classes is supported, provided train_y and test_y share the same levels.
    #
    # Arguments:
    #   N       - largest k to evaluate; N = 0 returns numeric(0).
    #   train_x - training predictors   (default: global `train`).
    #   test_x  - test predictors       (default: global `test`).
    #   train_y - training class labels (default: global `train.label`).
    #   test_y  - true test labels      (default: global `test.label`).
    # Returns a numeric vector of length N; element k is the accuracy with
    # k neighbours.
    # knn() breaks ties at random, so results depend on the RNG state.
    accuracyCal <- function(N, train_x = train, test_x = test,
                            train_y = train.label, test_y = test.label) {
      vapply(seq_len(N), function(k) {
        pred <- class::knn(train = train_x, test = test_x, cl = train_y, k = k)
        cm <- table(test_y, pred)
        tp <- diag(cm)                # truly k and predicted k
        fn <- rowSums(cm) - tp        # truly k, predicted another class
        fp <- colSums(cm) - tp        # predicted k, truly another class
        tn <- sum(cm) - tp - fn - fp  # neither truly k nor predicted k
        mean((tp + tn) / sum(cm))
      }, numeric(1))
    }
    
    # install.packages("TeachingDemos")
    library(TeachingDemos)

    # Compute the accuracy curve once and reuse it for the inset. knn() breaks
    # ties at random, so two separate accuracyCal() calls could give slightly
    # different curves, and recomputing k = 1..30 would duplicate work.
    acc_by_k <- accuracyCal(150)

    # Main panel: average accuracy for k = 1..150 with a smoothed trend.
    # (ggplot() replaces qplot(), which is deprecated as of ggplot2 3.4.0.)
    print(
      ggplot(data.frame(k = seq_along(acc_by_k), avg_accuracy = acc_by_k),
             aes(x = k, y = avg_accuracy)) +
        geom_point() +
        geom_smooth() +
        labs(x = "k values", y = "Average accuracy")
    )

    # Inset: zoom on k = 1..30, drawn with base graphics via subplot().
    # NOTE(review): grconvertX()/grconvertY() use base-graphics coordinates;
    # confirm the inset lands where intended on top of the grid-drawn ggplot.
    subplot(plot(seq_len(30), acc_by_k[seq_len(30)], col = 2,
                 xlab = "", ylab = "", cex.axis = 0.8),
            x = grconvertX(c(0, 0.75), from = "npc"),
            y = grconvertY(c(0, 0.45), from = "npc"),
            type = "fig", pars = list(mar = c(0, 0, 1.5, 1.5) + 0.1))
    
    image.png

    参考资料
    章仲恒教授丁香园课程:K-近邻取样
    Zhang Zhongheng. Introduction to machine learning: k-nearest neighbors

    相关文章

      网友评论

          本文标题:K-近邻取样

          本文链接:https://www.haomeiwen.com/subject/uilirltx.html