> library(pacman)
> p_load(dplyr, caret, readr, DataExplorer)
1. Reading in the data
> heart <- read_table2("data_set/heart.dat", col_names = F)
>
> profile_missing(heart)
## # A tibble: 14 x 3
## feature num_missing pct_missing
## <fct> <int> <dbl>
## 1 X1 0 0
## 2 X2 0 0
## 3 X3 0 0
## 4 X4 0 0
## 5 X5 0 0
## 6 X6 0 0
## 7 X7 0 0
## 8 X8 0 0
## 9 X9 0 0
## 10 X10 0 0
## 11 X11 0 0
## 12 X12 0 0
## 13 X13 0 0
## 14 X14 0 0
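DataExplorer can also give a one-line overview of the whole table; an optional quick check (not part of the original transcript):
> introduce(heart)
> plot_intro(heart)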
> summary(heart)
## X1 X2 X3 X4 X5 X6
## Min. :29.00 Min. :0.0000 Min. :1.000 Min. : 94.0 Min. :126.0 Min. :0.0000
## 1st Qu.:48.00 1st Qu.:0.0000 1st Qu.:3.000 1st Qu.:120.0 1st Qu.:213.0 1st Qu.:0.0000
## Median :55.00 Median :1.0000 Median :3.000 Median :130.0 Median :245.0 Median :0.0000
## Mean :54.43 Mean :0.6778 Mean :3.174 Mean :131.3 Mean :249.7 Mean :0.1481
## 3rd Qu.:61.00 3rd Qu.:1.0000 3rd Qu.:4.000 3rd Qu.:140.0 3rd Qu.:280.0 3rd Qu.:0.0000
## Max. :77.00 Max. :1.0000 Max. :4.000 Max. :200.0 Max. :564.0 Max. :1.0000
## X7 X8 X9 X10 X11 X12
## Min. :0.000 Min. : 71.0 Min. :0.0000 Min. :0.00 Min. :1.000 Min. :0.0000
## 1st Qu.:0.000 1st Qu.:133.0 1st Qu.:0.0000 1st Qu.:0.00 1st Qu.:1.000 1st Qu.:0.0000
## Median :2.000 Median :153.5 Median :0.0000 Median :0.80 Median :2.000 Median :0.0000
## Mean :1.022 Mean :149.7 Mean :0.3296 Mean :1.05 Mean :1.585 Mean :0.6704
## 3rd Qu.:2.000 3rd Qu.:166.0 3rd Qu.:1.0000 3rd Qu.:1.60 3rd Qu.:2.000 3rd Qu.:1.0000
## Max. :2.000 Max. :202.0 Max. :1.0000 Max. :6.20 Max. :3.000 Max. :3.0000
## X13 X14
## Min. :3.000 Min. :1.000
## 1st Qu.:3.000 1st Qu.:1.000
## Median :3.000 Median :1.000
## Mean :4.696 Mean :1.444
## 3rd Qu.:7.000 3rd Qu.:2.000
## Max. :7.000 Max. :2.000
The last column is the response variable, indicating whether the patient has heart disease: 1 means "no" and 2 means "yes".
> names(heart) <- c("age", "sex", "chestpain", "restbp", "chol", "sugar", "ecg", "maxhr",
+ "angina", "dep", "exercise", "pluor", "thal", "output")
>
> heart$output <- heart$output - 1
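A quick optional check (not in the original run) confirms that the recoding produced a 0/1 variable:
> table(heart$output)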
Convert the categorical variables to factors:
> heart <- heart %>%
+ mutate(across(c("chestpain", "ecg", "thal", "exercise", "output"), .fns = as.factor))
> # drop the spec attribute added by readr
> attr(heart, which = "spec") <- NULL
> str(heart)
## tibble [270 × 14] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
## $ age : num [1:270] 70 67 57 64 74 65 56 59 60 63 ...
## $ sex : num [1:270] 1 0 1 1 0 1 1 1 1 0 ...
## $ chestpain: Factor w/ 4 levels "1","2","3","4": 4 3 2 4 2 4 3 4 4 4 ...
## $ restbp : num [1:270] 130 115 124 128 120 120 130 110 140 150 ...
## $ chol : num [1:270] 322 564 261 263 269 177 256 239 293 407 ...
## $ sugar : num [1:270] 0 0 0 0 0 0 1 0 0 0 ...
## $ ecg : Factor w/ 3 levels "0","1","2": 3 3 1 1 3 1 3 3 3 3 ...
## $ maxhr : num [1:270] 109 160 141 105 121 140 142 142 170 154 ...
## $ angina : num [1:270] 0 0 0 1 1 0 1 1 0 0 ...
## $ dep : num [1:270] 2.4 1.6 0.3 0.2 0.2 0.4 0.6 1.2 1.2 4 ...
## $ exercise : Factor w/ 3 levels "1","2","3": 2 2 1 2 1 1 2 2 2 2 ...
## $ pluor : num [1:270] 3 0 0 1 1 0 1 1 2 3 ...
## $ thal : Factor w/ 3 levels "3","6","7": 1 3 3 3 1 3 2 3 3 3 ...
## $ output : Factor w/ 2 levels "0","1": 2 1 2 1 1 1 2 2 2 2 ...
2. Data preprocessing
2.1 Splitting the data into training and test sets
> set.seed(123)
> ind <- createDataPartition(heart$output, p = 0.85, list = F)
> dtrain <- heart[ind, ]
> dtest <- heart[-ind, ]
>
> dim(dtrain)
## [1] 230 14
> dim(dtest)
## [1] 40 14
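createDataPartition samples within each level of the outcome, so the class proportions should be roughly preserved in both subsets; an optional check (not in the original transcript):
> prop.table(table(dtrain$output))
> prop.table(table(dtest$output))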
2.2 Modeling with logistic regression
> set.seed(123)
> fit.logit <- train(output ~ ., data = dtrain, method = "glm")
> summary(fit.logit$finalModel)
##
## Call:
## NULL
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.36808 -0.42385 -0.08916 0.26064 2.89012
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -8.798e+00 3.756e+00 -2.342 0.019167 *
## age -1.739e-02 3.088e-02 -0.563 0.573473
## sex 2.581e+00 7.263e-01 3.554 0.000379 ***
## chestpain2 1.447e+00 1.034e+00 1.400 0.161620
## chestpain3 5.875e-01 8.565e-01 0.686 0.492806
## chestpain4 2.566e+00 8.675e-01 2.958 0.003094 **
## restbp 3.294e-02 1.403e-02 2.348 0.018881 *
## chol 1.088e-02 5.821e-03 1.870 0.061525 .
## sugar -1.132e+00 7.868e-01 -1.439 0.150087
## ecg1 -1.174e+01 1.455e+03 -0.008 0.993564
## ecg2 4.843e-01 4.769e-01 1.015 0.309935
## maxhr -2.576e-02 1.365e-02 -1.888 0.059067 .
## angina 6.441e-01 5.407e-01 1.191 0.233572
## dep 3.533e-01 2.688e-01 1.314 0.188869
## exercise2 1.550e+00 6.072e-01 2.553 0.010684 *
## exercise3 4.731e-01 1.153e+00 0.410 0.681527
## pluor 1.500e+00 3.441e-01 4.361 1.3e-05 ***
## thal6 -2.093e+00 1.249e+00 -1.676 0.093769 .
## thal7 1.586e+00 5.134e-01 3.090 0.002001 **
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 315.9 on 229 degrees of freedom
## Residual deviance: 126.6 on 211 degrees of freedom
## AIC: 164.6
##
## Number of Fisher Scoring iterations: 14
The coefficient table suggests that, with the other predictors already in the model, high p-value variables such as age contribute little to the fit.
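caret can also rank the predictors; for a glm this ranking is based on the absolute value of the coefficient z-statistics. An optional check (output not shown in the original transcript):
> varImp(fit.logit)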
> # training-set accuracy
> train.hat <- ifelse(fit.logit$finalModel$fitted.values > 0.5, 1, 0)
> confusionMatrix(as.factor(train.hat), dtrain$output)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 120 13
## 1 8 89
##
## Accuracy : 0.9087
## 95% CI : (0.8638, 0.9426)
## No Information Rate : 0.5565
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.8141
##
## Mcnemar's Test P-Value : 0.3827
##
## Sensitivity : 0.9375
## Specificity : 0.8725
## Pos Pred Value : 0.9023
## Neg Pred Value : 0.9175
## Prevalence : 0.5565
## Detection Rate : 0.5217
## Detection Prevalence : 0.5783
## Balanced Accuracy : 0.9050
##
## 'Positive' Class : 0
##
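The same training-set confusion matrix can also be obtained through caret's predict method, which for a two-class glm applies the same 0.5 cutoff (an equivalent sketch, not part of the original transcript):
> confusionMatrix(predict(fit.logit, newdata = dtrain, type = "raw"), dtrain$output)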
2.3 Performance on the test set
> test.hat <- predict(fit.logit, newdata = dtest, type = "raw")
> confusionMatrix(test.hat, dtest$output)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 18 7
## 1 4 11
##
## Accuracy : 0.725
## 95% CI : (0.5611, 0.854)
## No Information Rate : 0.55
## P-Value [Acc > NIR] : 0.01789
##
## Kappa : 0.4359
##
## Mcnemar's Test P-Value : 0.54649
##
## Sensitivity : 0.8182
## Specificity : 0.6111
## Pos Pred Value : 0.7200
## Neg Pred Value : 0.7333
## Prevalence : 0.5500
## Detection Rate : 0.4500
## Detection Prevalence : 0.6250
## Balanced Accuracy : 0.7146
##
## 'Positive' Class : 0
##
The model's accuracy on the test set is far lower than on the training set, which points to overfitting. In addition, common sense suggests that the risk of heart disease should increase with age, yet age is not significant in the model, hinting at possible collinearity among the predictors.
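One way to probe the suspected collinearity (not run in the original transcript; assumes the car package is installed) is to compute generalized variance inflation factors on a plain glm fit of the same formula:
> # GVIF values well above 5-10 would suggest problematic collinearity
> car::vif(glm(output ~ ., data = dtrain, family = binomial))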
3. Regularization
> set.seed(123)
> fit.glmnet <- train(output ~ ., data = dtrain, method = "glmnet")
>
> fit.glmnet$bestTune
## alpha lambda
## 6 0.55 0.05095497
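By default, train evaluates only a small built-in grid of alpha and lambda values (three of each here). If finer control over the tuning is wanted, the resampling scheme and grid can be specified explicitly; a sketch with illustrative grid values (not part of the original run):
> ctrl <- trainControl(method = "cv", number = 10)
> grid <- expand.grid(alpha = seq(0, 1, by = 0.1),
+                     lambda = 10^seq(-4, 0, length.out = 20))
> set.seed(123)
> fit.glmnet.cv <- train(output ~ ., data = dtrain, method = "glmnet",
+                        trControl = ctrl, tuneGrid = grid)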
> # training-set performance
> confusionMatrix(predict(fit.glmnet, newdata = dtrain, type = "raw"), dtrain$output)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 119 18
## 1 9 84
##
## Accuracy : 0.8826
## 95% CI : (0.8338, 0.9212)
## No Information Rate : 0.5565
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.76
##
## Mcnemar's Test P-Value : 0.1237
##
## Sensitivity : 0.9297
## Specificity : 0.8235
## Pos Pred Value : 0.8686
## Neg Pred Value : 0.9032
## Prevalence : 0.5565
## Detection Rate : 0.5174
## Detection Prevalence : 0.5957
## Balanced Accuracy : 0.8766
##
## 'Positive' Class : 0
##
> # test-set performance
> confusionMatrix(predict(fit.glmnet, newdata = dtest, type = "raw"), dtest$output)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 19 5
## 1 3 13
##
## Accuracy : 0.8
## 95% CI : (0.6435, 0.9095)
## No Information Rate : 0.55
## P-Value [Acc > NIR] : 0.0008833
##
## Kappa : 0.5918
##
## Mcnemar's Test P-Value : 0.7236736
##
## Sensitivity : 0.8636
## Specificity : 0.7222
## Pos Pred Value : 0.7917
## Neg Pred Value : 0.8125
## Prevalence : 0.5500
## Detection Rate : 0.4750
## Detection Prevalence : 0.6000
## Balanced Accuracy : 0.7929
##
## 'Positive' Class : 0
##
The training and test accuracies (88.3% and 80%) are now much closer together, and the test accuracy is clearly better than the 72.5% obtained with plain logistic regression.
> # inspect the regression coefficients
> coef(fit.glmnet$finalModel)[, 6]
## (Intercept) age sex chestpain2 chestpain3 chestpain4 restbp chol
## -0.202328200 0.000000000 0.000000000 0.000000000 0.000000000 0.293731204 0.000000000 0.000000000
## sugar ecg1 ecg2 maxhr angina dep exercise2 exercise3
## 0.000000000 0.000000000 0.000000000 -0.003020824 0.065690202 0.041143732 0.000000000 0.000000000
## pluor thal6 thal7
## 0.100495948 0.000000000 0.385824282
Some of the coefficients are exactly zero, which means those features have effectively been removed from the model.
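Note that coef(fit.glmnet$finalModel)[, 6] takes the sixth column along glmnet's lambda path, which is not guaranteed to match the tuned lambda. A safer way (not in the original transcript) is to request the coefficients at the tuned value directly:
> coef(fit.glmnet$finalModel, s = fit.glmnet$bestTune$lambda)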
4. ROC curve
> ROSE::roc.curve(predict(fit.glmnet, newdata = dtest, type = "raw"), dtest$output)
## Area under the curve (AUC): 0.802
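Two caveats about this call: ROSE::roc.curve expects the true labels as its first argument and the predictions as the second, and passing hard class labels collapses the ROC curve to a single operating point. A sketch using predicted class probabilities instead (not part of the original run):
> prob <- predict(fit.glmnet, newdata = dtest, type = "prob")
> # second column = predicted probability of class "1" (heart disease)
> ROSE::roc.curve(dtest$output, prob[, 2])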
