美文网首页
Cart Algorithm and R Programming

Cart Algorithm and R Programming

作者: RouyiDing | 来源:发表于2016-07-08 17:01 被阅读168次

    首先,要明白Cart生成算法。Cart生成算法的核心是以基尼系数(Gini Index)最小化为准则生成分类树。理解下Gini Index,它用来衡量Pure程度,即一个节点中包含y因变量值的差异程度。Gini Index越小,说明y的值越一致,分类效果好,选择这样的特征作为节点,树的效率才高。

    Cart算法的基本思路(递归过程):
    Step 1: 选定training data,遍历每一个特征A,对每个特征A可取的值a,根据A=a测试是否划分为两部分,并计算Gini Index。

    Step 2: 在step 1中计算得到的Gini Index中,选择最小的Gini Index对应的A=a作为最有特征与最优切分点,由此training data被分配到了两个子节点中。

    Step 3: 重复以上步骤,直到满足停止条件。

    R 中的rpart package能够实现Cart 算法。

    R code:

    # raw data has 4521 rows and 17 columns; the last column is y
    bank <- read.csv("C:/working/summer/机器学习/决策树/bank/bank.csv",header=TRUE,sep=';')
    
    # seprate as training set & valication set
    bank_train <- bank[1:4000,]
    bank_test <- bank[4001:4521,1:16]
    bank_test1 <- bank[4001:4521,]
    
    # build tree
    library(rpart)
    fit <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact  
                 +day+month+duration+campaign+pdays+previous+poutcome,method="class",  
                 data=bank_train)  # method=class represent build classification tree
    plot(fit, uniform = TRUE,main="Classification Tree for Bank")
    text(fit,use.n = TRUE,all=TRUE)
    
    #######################################################################################################
    
    #use validation data to test the accuracy
    result <- predict(fit, bank_test,type = "class")
    
    #use a function to calculate accuracy rate
    source("C:/working/summer/机器学习/决策树/accurate rate.r")
    count_result(result,bank_test1)
    
    #######################################################################################################
    
    # deal with missing value
    # na.action 默认保留自变量缺失的观测值,删除因变量缺失的观测值
    # 但是不明白怎么保留自变量缺失的观测值??这样保留了怎么建的树?
    summary(bank) #The 4th, 9th,16th column have unknown value
    n <- nrow(bank)
    for (i in 1:n){
      if (bank[i,4]=="unknown"){
        bank[i,4]=NA
      }
      if (bank[i,9]=="unknown"){
        bank[i,9]=NA
      }
      if (bank[i,16]=="unknown"){
        bank[i,16]=NA
      }
    }
    
    fit2 <- rpart(y~.,method = "class", data=bank_train,na.action=na.rpart)  
    plot(fit,,use.n=TRUE,all=TRUE)  
    text(fit,use.n = TRUE,all=TRUE)
    result2 <- predict(fit2,bank_test,type="class")
    count_result(result2,bank_test1)
    
    ########################################################################################################
    fit3 <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact+day+month+duration+campaign+
                    pdays+previous+poutcome,method="class",data=bank_train,na.action=na.rpart,
                  control=rpart.control(minsplit=40,cp=0.001))   # minsplit越大树越简单,它表示当分类小到这个值时就停止
    result3 <- predict(fit3,bank_test,type="class")  
    count_result(result3,bank_test1)
    plot(fit3,use.n=TRUE,all=TRUE)
    

    count_result function 用来计算分类的正确率

    count_result <- function(result,data_test){
      n <- length(result)
      count_right<-0
      i <-1
      for (i in 1:n){
        if (result[i]==data_test[i,17]){
          count_right=count_right+1
        }
      }
      print(count_right/n)
    

    剪枝:

    library(rpart)
    fit <- rpart(y~age+job+marital+education+default+balance+housing+loan+contact  
                 +day+month+duration+campaign+pdays+previous+poutcome,method="class",  
                 data=bank_train,control=rpart.control(minsplit=140,cp=0.001))  # method=class represent build classification tree
    plot(fit, uniform = TRUE,main="Classification Tree for Bank")
    text(fit,use.n = TRUE,all=TRUE)
    
    # more beautiful plot
    library(rpart.plot)
    rpart.plot(fit, branch=1, branch.type=2, type=1, extra=102,  
               shadow.col="gray", box.col="green",  
               border.col="blue", split.col="red",  
               split.cex=1.2, main="Kyphosis决策树");  
    
    # prune
    printcp(fit)
    fit$cptable
    fit2 <- prune(fit, cp= fit$cptable[which.min(fit$cptable[,"xerror"]),"CP"]) 
    rpart.plot(fit2, branch=1, branch.type=2, type=1, extra=102,  
               shadow.col="gray", box.col="green",  
               border.col="blue", split.col="red",  
               split.cex=1.2, main="Kyphosis决策树");
    

    剪枝前:4层

    Paste_Image.png

    剪枝后:3层

    Paste_Image.png

    相关文章

      网友评论

          本文标题:Cart Algorithm and R Programming

          本文链接:https://www.haomeiwen.com/subject/pzkrjttx.html