> library(pacman)
> p_load(data.table, mlr3, mlr3learners, mlr3viz, ggplot2, rchallenge, 
+        DataExplorer, dplyr)


> # Load the German credit data set shipped with rchallenge, then show its structure
> data(list = "german", package = "rchallenge")
> str(object = german)
## 'data.frame':    1000 obs. of  21 variables:
##  $ status                 : Factor w/ 4 levels "no checking account",..: 1 2 4 1 1 4 4 2 4 2 ...
##  $ duration               : int  6 48 12 42 24 36 24 36 12 30 ...
##  $ credit_history         : Factor w/ 5 levels "delay in paying off in the past",..: 5 3 5 3 4 3 3 3 3 5 ...
##  $ purpose                : Factor w/ 11 levels "others","car (new)",..: 4 4 7 3 1 7 3 2 4 1 ...
##  $ amount                 : int  1169 5951 2096 7882 4870 9055 2835 6948 3059 5234 ...
##  $ savings                : Factor w/ 5 levels "unknown/no savings account",..: 5 1 1 1 1 5 3 1 4 1 ...
##  $ employment_duration    : Factor w/ 5 levels "unemployed","< 1 yr",..: 5 3 4 4 3 3 5 3 4 1 ...
##  $ installment_rate       : Ord.factor w/ 4 levels ">= 35"<"25 <= ... < 35"<..: 4 2 2 2 3 2 3 2 2 4 ...
##  $ personal_status_sex    : Factor w/ 4 levels "male : divorced/separated",..: 3 2 3 3 3 3 3 3 1 4 ...
##  $ other_debtors          : Factor w/ 3 levels "none","co-applicant",..: 1 1 1 3 1 1 1 1 1 1 ...
##  $ present_residence      : Ord.factor w/ 4 levels "< 1 yr"<"1 <= ... < 4 yrs"<..: 4 2 3 4 4 4 4 2 4 2 ...
##  $ property               : Factor w/ 4 levels "unknown / no property",..: 1 1 1 2 4 4 2 3 1 3 ...
##  $ age                    : int  67 22 49 45 53 35 53 35 61 28 ...
##  $ other_installment_plans: Factor w/ 3 levels "bank","stores",..: 3 3 3 3 3 3 3 3 3 3 ...
##  $ housing                : Factor w/ 3 levels "for free","rent",..: 2 2 2 3 3 3 2 1 2 2 ...
##  $ number_credits         : Ord.factor w/ 4 levels "1"<"2-3"<"4-5"<..: 2 1 1 1 2 1 1 1 1 2 ...
##  $ job                    : Factor w/ 4 levels "unemployed/unskilled - non-resident",..: 3 3 2 3 3 2 3 4 2 4 ...
##  $ people_liable          : Factor w/ 2 levels "0 to 2","3 or more": 1 1 2 2 2 2 1 1 1 1 ...
##  $ telephone              : Factor w/ 2 levels "no","yes (under customer name)": 2 1 1 1 1 2 1 2 1 1 ...
##  $ foreign_worker         : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 1 1 1 1 ...
##  $ credit_risk            : Factor w/ 2 levels "good","bad": 1 2 1 1 2 1 1 1 1 2 ...


> # Per-column overview: missingness, top level counts for factors,
> # quantiles and a mini histogram for numeric columns
> skimr::skim(data = german)

Table: Data summary

Name german
Number of rows 1000
Number of columns 21
Column type frequency:
factor 18
numeric 3
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
status 0 1 FALSE 4 ...: 394, no : 274, ...: 269, 0<=: 63
credit_history 0 1 FALSE 5 no : 530, all: 293, exi: 88, cri: 49
purpose 0 1 FALSE 10 fur: 280, oth: 234, car: 181, car: 103
savings 0 1 FALSE 5 unk: 603, ...: 183, ...: 103, 100: 63
employment_duration 0 1 FALSE 5 1 <: 339, >= : 253, 4 <: 174, < 1: 172
installment_rate 0 1 TRUE 4 < 2: 476, 25 : 231, 20 : 157, >= : 136
personal_status_sex 0 1 FALSE 4 mal: 548, fem: 310, fem: 92, mal: 50
other_debtors 0 1 FALSE 3 non: 907, gua: 52, co-: 41
present_residence 0 1 TRUE 4 >= : 413, 1 <: 308, 4 <: 149, < 1: 130
property 0 1 FALSE 4 bui: 332, unk: 282, car: 232, rea: 154
other_installment_plans 0 1 FALSE 3 non: 814, ban: 139, sto: 47
housing 0 1 FALSE 3 ren: 713, for: 179, own: 108
number_credits 0 1 TRUE 4 1: 633, 2-3: 333, 4-5: 28, >= : 6
job 0 1 FALSE 4 ski: 630, uns: 200, man: 148, une: 22
people_liable 0 1 FALSE 2 0 t: 845, 3 o: 155
telephone 0 1 FALSE 2 no: 596, yes: 404
foreign_worker 0 1 FALSE 2 no: 963, yes: 37
credit_risk 0 1 FALSE 2 goo: 700, bad: 300

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
duration 0 1 20.90 12.06 4 12.0 18.0 24.00 72 ▇▇▂▁▁
amount 0 1 3271.26 2822.74 250 1365.5 2319.5 3972.25 18424 ▇▂▁▁▁
age 0 1 35.55 11.38 19 27.0 33.0 42.00 75 ▇▆▃▁▁




> # Work on a copy so that `german` keeps its full, unshortened factor labels
> german.s <- german
> # Flag every factor column (ordered factors included); vapply() guarantees a
> # named logical vector regardless of the input
> fac <- vapply(german, is.factor, logical(1))
> # Shorten the level labels of a factor so that the bar charts stay readable.
> # Each label is first cut to at most `width` characters by
> # mlr3misc::str_trunc() (which appends "..." to truncated labels), and the
> # result is then shrunk towards `minlength` characters by abbreviate().
> # abbreviate() may keep a label longer than `minlength` to keep labels unique.
> # The defaults (16 / 12) reproduce the original behaviour.
> # x: a factor; returns the same factor with shortened levels.
> short_level <- function(x, width = 16, minlength = 12) {
+   truncated <- mlr3misc::str_trunc(levels(x), width, "...")
+   levels(x) <- abbreviate(truncated, minlength = minlength)
+   x
+ }
> # Replace the factor columns of the copy with their shortened versions
> german.s[fac] <- Map(short_level, german[fac])
> # Bar charts of the discrete variables, laid out in a 2 x 3 grid
> plot_bar(data = german.s, nrow = 2, ncol = 3, ggtheme = theme_bw())
> # Histograms of the continuous variables, all in one row
> plot_histogram(data = german.s, nrow = 1, ggtheme = theme_bw())
> # Box plots of the continuous variables split by the target class
> plot_boxplot(data = german.s, by = "credit_risk", nrow = 1,
+              ggtheme = theme_bw())


3.1 创建分类任务


> # Classification task over the full data set with credit_risk as the target.
> # NOTE(review): `positive` is not given, so mlr3 chooses the positive class
> # itself (presumably the first level, "good") - confirm this is intended.
> task <- TaskClassif$new(
+   id = "GermanCredit",
+   backend = german,
+   target = "credit_risk"
+ )

3.2 创建学习器


> # List the keys of all learners registered in the mlr_learners dictionary
> print(mlr_learners)
## <DictionaryLearner> with 28 stored values
## Keys: classif.cv_glmnet, classif.debug, classif.featureless, classif.glmnet, classif.kknn,
##   classif.lda, classif.log_reg, classif.multinom, classif.naive_bayes, classif.qda,
##   classif.ranger, classif.rpart, classif.svm, classif.xgboost, regr.cv_glmnet, regr.featureless,
##   regr.glmnet, regr.kknn, regr.km, regr.lm, regr.ranger, regr.rpart, regr.svm, regr.xgboost,
##   surv.cv_glmnet, surv.glmnet, surv.ranger, surv.xgboost


> # Logistic regression learner (fits stats::glm with family = "binomial")
> learner.glm <- mlr_learners$get("classif.log_reg")

3.3 拆分训练集和测试集

> # Fix the RNG seed so the random split below can be reproduced
> set.seed(123)
> # Random 80/20 split of the task's row ids into training and test sets
> dtrain <- sample(task$row_ids, 0.8 * task$nrow)
> dtest <- setdiff(task$row_ids, dtrain)

3.4 训练模型

> # Fit the logistic regression on the training rows only
> learner.glm$train(task = task, row_ids = dtrain)

3.5 查看模型信息

> # Coefficient table of the glm object stored inside the trained learner
> learner.glm$model %>% summary()
## Call:
## stats::glm(formula = task$formula(), family = "binomial", data = task$data(), 
##     model = FALSE)
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -2.1291  -0.6626  -0.3606   0.6362   2.9196  
## Coefficients:
##                                                             Estimate Std. Error z value Pr(>|z|)    
## (Intercept)                                                1.459e+00  1.167e+00   1.250 0.211187    
## age                                                       -7.533e-03  1.054e-02  -0.715 0.474839    
## amount                                                     1.612e-04  5.063e-05   3.184 0.001454 ** 
## credit_historycritical account/other credits elsewhere     4.302e-01  6.555e-01   0.656 0.511645    
## credit_historyno credits taken/all credits paid back duly -3.057e-01  5.156e-01  -0.593 0.553257    
## credit_historyexisting credits paid back duly till now    -7.092e-01  5.416e-01  -1.310 0.190329    
## credit_historyall credits at this bank paid back duly     -1.007e+00  5.005e-01  -2.013 0.044099 *  
## duration                                                   2.253e-02  1.070e-02   2.105 0.035260 *  
## employment_duration< 1 yr                                  4.608e-01  5.113e-01   0.901 0.367468    
## employment_duration1 <= ... < 4 yrs                       -1.508e-04  4.859e-01   0.000 0.999752    
## employment_duration4 <= ... < 7 yrs                       -4.023e-01  5.248e-01  -0.767 0.443343    
## employment_duration>= 7 yrs                                4.474e-02  4.851e-01   0.092 0.926509    
## foreign_workeryes                                         -1.380e+00  6.794e-01  -2.031 0.042206 *  
## housingrent                                               -5.602e-01  2.728e-01  -2.054 0.040020 *  
## housingown                                                -7.413e-01  5.572e-01  -1.330 0.183366    
## installment_rate.L                                         7.886e-01  2.492e-01   3.164 0.001554 ** 
## installment_rate.Q                                         6.319e-02  2.265e-01   0.279 0.780229    
## installment_rate.C                                         1.790e-02  2.312e-01   0.077 0.938308    
## jobunskilled - resident                                    5.529e-02  7.848e-01   0.070 0.943837    
## jobskilled employee/official                               2.578e-01  7.585e-01   0.340 0.733955    
## jobmanager/self-empl/highly qualif. employee               6.324e-02  7.740e-01   0.082 0.934883    
## number_credits.L                                           1.444e-01  7.801e-01   0.185 0.853098    
## number_credits.Q                                           4.038e-02  6.684e-01   0.060 0.951826    
## number_credits.C                                           2.060e-01  5.314e-01   0.388 0.698229    
## other_debtorsco-applicant                                  1.748e-01  4.884e-01   0.358 0.720509    
## other_debtorsguarantor                                    -9.562e-01  4.758e-01  -2.009 0.044489 *  
## other_installment_plansstores                              1.843e-01  4.897e-01   0.376 0.706703    
## other_installment_plansnone                               -7.498e-01  2.747e-01  -2.729 0.006350 ** 
## people_liable3 or more                                     3.275e-01  2.819e-01   1.162 0.245299    
## personal_status_sexfemale : non-single or male : single   -2.046e-01  4.727e-01  -0.433 0.665179    
## personal_status_sexmale : married/widowed                 -8.314e-01  4.675e-01  -1.778 0.075325 .  
## personal_status_sexfemale : single                        -2.847e-01  5.464e-01  -0.521 0.602363    
## present_residence.L                                        1.002e-01  2.524e-01   0.397 0.691418    
## present_residence.Q                                       -5.788e-01  2.334e-01  -2.479 0.013161 *  
## present_residence.C                                        3.899e-01  2.298e-01   1.697 0.089786 .  
## propertycar or other                                       5.373e-01  2.953e-01   1.819 0.068884 .  
## propertybuilding soc. savings agr. / life insurance        4.601e-01  2.750e-01   1.673 0.094288 .  
## propertyreal estate                                        1.031e+00  4.942e-01   2.086 0.037015 *  
## purposecar (new)                                          -1.997e+00  4.505e-01  -4.433 9.28e-06 ***
## purposecar (used)                                         -7.719e-01  2.948e-01  -2.618 0.008838 ** 
## purposefurniture/equipment                                -9.267e-01  2.824e-01  -3.282 0.001032 ** 
## purposeradio/television                                    4.374e-01  8.456e-01   0.517 0.604946    
## purposedomestic appliances                                -2.571e-01  6.137e-01  -0.419 0.675191    
## purposerepairs                                            -3.256e-01  4.619e-01  -0.705 0.480833    
## purposevacation                                           -8.983e-01  1.324e+00  -0.678 0.497555    
## purposeretraining                                         -5.016e-01  3.885e-01  -1.291 0.196597    
## purposebusiness                                           -1.330e+00  8.233e-01  -1.615 0.106332    
## savings... < 100 DM                                       -4.541e-01  3.539e-01  -1.283 0.199398    
## savings100 <= ... < 500 DM                                -2.827e-01  4.183e-01  -0.676 0.499153    
## savings500 <= ... < 1000 DM                               -1.635e+00  6.632e-01  -2.466 0.013677 *  
## savings... >= 1000 DM                                     -1.038e+00  3.056e-01  -3.397 0.000682 ***
## status... < 0 DM                                          -4.535e-01  2.477e-01  -1.831 0.067076 .  
## status0<= ... < 200 DM                                    -1.313e+00  4.367e-01  -3.007 0.002641 ** 
## status... >= 200 DM / salary for at least 1 year          -1.838e+00  2.669e-01  -6.886 5.73e-12 ***
## telephoneyes (under customer name)                        -4.866e-01  2.333e-01  -2.086 0.036975 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## (Dispersion parameter for binomial family taken to be 1)
##     Null deviance: 972.25  on 799  degrees of freedom
## Residual deviance: 694.72  on 745  degrees of freedom
## AIC: 804.72
## Number of Fisher Scoring iterations: 5

参数 importance = "permutation" 表示用置换(permutation)法计算各特征的重要性,之后即可据此对特征做重要性排序。

> # Random forest (ranger) with permutation-based variable importance enabled
> learner.rf <- lrn("classif.ranger", importance = "permutation")
> learner.rf$train(task = task, row_ids = dtrain)
> # Extract the importance scores; the surrounding parentheses also print them
> (imp <- learner.rf$importance())
##                  status                duration                  amount          credit_history 
##            3.693883e-02            1.648268e-02            1.562386e-02            1.163516e-02 
##                 savings                property        installment_rate                     age 
##            7.036720e-03            6.318439e-03            5.071133e-03            5.020740e-03 
##     employment_duration other_installment_plans                 purpose           other_debtors 
##            4.320076e-03            4.002393e-03            3.483755e-03            3.380466e-03 
##          number_credits                 housing               telephone     personal_status_sex 
##            2.506178e-03            1.951486e-03            1.670096e-03            1.243216e-03 
##                     job           people_liable       present_residence          foreign_worker 
##            9.502581e-04            6.896296e-04            5.660758e-04            9.835257e-05


> # Horizontal bar chart of the importance scores, features ordered by score
> data.frame(Feature = names(imp), Importance = unname(imp)) %>%
+   ggplot(aes(x = Importance, y = reorder(Feature, Importance))) +
+   geom_col() +
+   theme_bw() +
+   labs(title = "特征重要性排序", y = "") +
+   theme(plot.title = element_text(hjust = 0.5))


> # Predict the held-out rows with both fitted models
> pred.glm <- learner.glm$predict(task = task, row_ids = dtest)  # logistic regression
> pred.rf <- learner.rf$predict(task = task, row_ids = dtest)    # random forest


> # Confusion matrix of the logistic regression (rows: response, columns: truth)
> print(pred.glm$confusion)
##         truth
## response good bad
##     good  117  30
##     bad    20  33
> # Confusion matrix of the random forest (rows: response, columns: truth)
> print(pred.rf$confusion)
##         truth
## response good bad
##     good  119  36
##     bad    18  27


> # Switch the already-trained glm learner to probability output and re-predict
> learner.glm$predict_type <- "prob"
> pred.glm2 <- learner.glm$predict(task = task, row_ids = dtest)
> # First rows of the predicted class-probability matrix (one column per class)
> utils::head(pred.glm2$data$prob, n = 6L)
##           good         bad
## [1,] 0.2587434 0.741256592
## [2,] 0.9936570 0.006342955
## [3,] 0.8750556 0.124944439
## [4,] 0.8562178 0.143782236
## [5,] 0.3046338 0.695366174
## [6,] 0.4106285 0.589371548


> # Single holdout split: 2/3 of the rows train the model, 1/3 evaluate it
> resampling <- rsmp("holdout", ratio = 2/3)
> res <- resample(task = task, learner = learner.glm, resampling = resampling)
> # Aggregated score; the default measure reported here is classif.ce
> res$aggregate()
## classif.ce 
##  0.2402402


> # 10-fold cross-validation of the same learner
> resampling <- rsmp("cv", folds = 10)
> res2 <- resample(task = task, learner = learner.glm, resampling = resampling)
> # Classification error averaged over the 10 folds
> res2$aggregate()
## classif.ce 
##      0.249


> # List the available resampling strategies
> print(mlr_resamplings)
## <DictionaryResampling> with 8 stored values
## Keys: bootstrap, custom, cv, holdout, insample, loo, repeated_cv, subsampling
> # List the available performance measures
> print(mlr_measures)
## <DictionaryMeasure> with 54 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.bbrier, classif.ce, classif.costs,
##   classif.dor, classif.fbeta, classif.fdr, classif.fn, classif.fnr, classif.fomr, classif.fp,
##   classif.fpr, classif.logloss, classif.mbrier, classif.mcc, classif.npv, classif.ppv,
##   classif.prauc, classif.precision, classif.recall, classif.sensitivity, classif.specificity,
##   classif.tn, classif.tnr, classif.tp, classif.tpr, debug, oob_error, regr.bias, regr.ktau,
##   regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse, regr.msle, regr.pbias,
##   regr.rae, regr.rmse, regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho,
##   regr.sse, selected_features, time_both, time_predict, time_train


> # With large data or many models, consider running the resampling in parallel.
> # "multisession" runs background R sessions and works on every OS and inside
> # RStudio, whereas forking ("multicore") is unavailable on Windows and RStudio.
> future::plan("multisession")
> # Compare both learners with probability predictions (needed for AUC);
> # `resampling` is the 10-fold CV object created above.
> learners <- lrns(c("classif.log_reg", "classif.ranger"), predict_type = "prob")
> bm.design <- benchmark_grid(
+   tasks = task,
+   learners = learners,
+   resamplings = resampling
+ )
> bmr <- benchmark(bm.design)


> # Misclassification error and AUC for each learner, aggregated over the folds
> measures <- msrs(c("classif.ce", "classif.auc"))
> performances <- bmr$aggregate(measures = measures)
> performances[, .(learner_id, classif.ce, classif.auc)]
##         learner_id classif.ce classif.auc
## 1: classif.log_reg      0.242   0.7870299
## 2:  classif.ranger      0.227   0.8122784




> # Hyperparameter ids exposed by the ranger learner
> print(learner.rf$param_set$ids())
##  [1] "num.trees"                    "mtry"                         "importance"                  
##  [4] "write.forest"                 "min.node.size"                "replace"                     
##  [7] "sample.fraction"              "class.weights"                "splitrule"                   
## [10] "num.random.splits"            "split.select.weights"         "always.split.variables"      
## [13] "respect.unordered.factors"    "scale.permutation.importance" "keep.inbag"                  
## [16] "holdout"                      "num.threads"                  "save.memory"                 
## [19] "verbose"                      "oob.error"                    "max.depth"                   
## [22] "alpha"                        "min.prop"                     "regularization.factor"       
## [25] "regularization.usedepth"      "seed"                         "minprop"                     
## [28] "predict.all"                  "se.method"


> # Three ranger configurations: no hyperparameters overridden ("med"),
> # many trees with a large mtry ("high"), very few trees with a small mtry ("low")
> rf.med <- lrn("classif.ranger", id = "med", predict_type = "prob")
> rf.high <- lrn("classif.ranger", id = "high", predict_type = "prob",
+                mtry = 11, num.trees = 1000)
> rf.low <- lrn("classif.ranger", id = "low", predict_type = "prob",
+               mtry = 2, num.trees = 5)
> learners <- list(rf.med, rf.high, rf.low)
> bm.design2 <- benchmark_grid(tasks = task, learners = learners,
+                              resamplings = resampling)
> bmr2 <- benchmark(bm.design2)
> # Visualise the benchmark result with mlr3viz's default autoplot
> autoplot(bmr2)




