79 - Predictive Analytics in R: Logistic Regression

Author: wonphen | Published 2020-10-04 17:01
> library(pacman)
> p_load(dplyr, caret, readr, DataExplorer)

1. Read the Data

> heart <- read_table2("data_set/heart.dat", col_names = F)
> 
> profile_missing(heart)
## # A tibble: 14 x 3
##    feature num_missing pct_missing
##    <fct>         <int>       <dbl>
##  1 X1                0           0
##  2 X2                0           0
##  3 X3                0           0
##  4 X4                0           0
##  5 X5                0           0
##  6 X6                0           0
##  7 X7                0           0
##  8 X8                0           0
##  9 X9                0           0
## 10 X10               0           0
## 11 X11               0           0
## 12 X12               0           0
## 13 X13               0           0
## 14 X14               0           0
> summary(heart)
##        X1              X2               X3              X4              X5              X6        
##  Min.   :29.00   Min.   :0.0000   Min.   :1.000   Min.   : 94.0   Min.   :126.0   Min.   :0.0000  
##  1st Qu.:48.00   1st Qu.:0.0000   1st Qu.:3.000   1st Qu.:120.0   1st Qu.:213.0   1st Qu.:0.0000  
##  Median :55.00   Median :1.0000   Median :3.000   Median :130.0   Median :245.0   Median :0.0000  
##  Mean   :54.43   Mean   :0.6778   Mean   :3.174   Mean   :131.3   Mean   :249.7   Mean   :0.1481  
##  3rd Qu.:61.00   3rd Qu.:1.0000   3rd Qu.:4.000   3rd Qu.:140.0   3rd Qu.:280.0   3rd Qu.:0.0000  
##  Max.   :77.00   Max.   :1.0000   Max.   :4.000   Max.   :200.0   Max.   :564.0   Max.   :1.0000  
##        X7              X8              X9              X10            X11             X12        
##  Min.   :0.000   Min.   : 71.0   Min.   :0.0000   Min.   :0.00   Min.   :1.000   Min.   :0.0000  
##  1st Qu.:0.000   1st Qu.:133.0   1st Qu.:0.0000   1st Qu.:0.00   1st Qu.:1.000   1st Qu.:0.0000  
##  Median :2.000   Median :153.5   Median :0.0000   Median :0.80   Median :2.000   Median :0.0000  
##  Mean   :1.022   Mean   :149.7   Mean   :0.3296   Mean   :1.05   Mean   :1.585   Mean   :0.6704  
##  3rd Qu.:2.000   3rd Qu.:166.0   3rd Qu.:1.0000   3rd Qu.:1.60   3rd Qu.:2.000   3rd Qu.:1.0000  
##  Max.   :2.000   Max.   :202.0   Max.   :1.0000   Max.   :6.20   Max.   :3.000   Max.   :3.0000  
##       X13             X14       
##  Min.   :3.000   Min.   :1.000  
##  1st Qu.:3.000   1st Qu.:1.000  
##  Median :3.000   Median :1.000  
##  Mean   :4.696   Mean   :1.444  
##  3rd Qu.:7.000   3rd Qu.:2.000  
##  Max.   :7.000   Max.   :2.000

The last column is the dependent variable, indicating whether the patient has heart disease: 1 means "no" and 2 means "yes".

> names(heart) <- c("age", "sex", "chestpain", "restbp", "chol", "sugar", "ecg", "maxhr", 
+                  "angina", "dep", "exercise", "pluor", "thal", "output")
> 
> heart$output <- heart$output - 1

Convert the categorical variables to factors:

> heart <- heart %>% 
+   mutate(across(c("chestpain", "ecg", "thal", "exercise", "output"), .fns = as.factor))
> # Drop the "spec" attribute that readr attaches to the tibble
> attr(heart, which = "spec") <- NULL
> str(heart)
## tibble [270 × 14] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ age      : num [1:270] 70 67 57 64 74 65 56 59 60 63 ...
##  $ sex      : num [1:270] 1 0 1 1 0 1 1 1 1 0 ...
##  $ chestpain: Factor w/ 4 levels "1","2","3","4": 4 3 2 4 2 4 3 4 4 4 ...
##  $ restbp   : num [1:270] 130 115 124 128 120 120 130 110 140 150 ...
##  $ chol     : num [1:270] 322 564 261 263 269 177 256 239 293 407 ...
##  $ sugar    : num [1:270] 0 0 0 0 0 0 1 0 0 0 ...
##  $ ecg      : Factor w/ 3 levels "0","1","2": 3 3 1 1 3 1 3 3 3 3 ...
##  $ maxhr    : num [1:270] 109 160 141 105 121 140 142 142 170 154 ...
##  $ angina   : num [1:270] 0 0 0 1 1 0 1 1 0 0 ...
##  $ dep      : num [1:270] 2.4 1.6 0.3 0.2 0.2 0.4 0.6 1.2 1.2 4 ...
##  $ exercise : Factor w/ 3 levels "1","2","3": 2 2 1 2 1 1 2 2 2 2 ...
##  $ pluor    : num [1:270] 3 0 0 1 1 0 1 1 2 3 ...
##  $ thal     : Factor w/ 3 levels "3","6","7": 1 3 3 3 1 3 2 3 3 3 ...
##  $ output   : Factor w/ 2 levels "0","1": 2 1 2 1 1 1 2 2 2 2 ...

2. Data Preprocessing

2.1 Split the Data into Training and Test Sets

> set.seed(123)
> ind <- createDataPartition(heart$output, p = 0.85, list = F)
> dtrain <- heart[ind, ]
> dtest <- heart[-ind, ]
> 
> dim(dtrain)
## [1] 230  14
> dim(dtest)
## [1] 40 14
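
Because `createDataPartition()` samples within each level of `output`, the class balance should be preserved in both splits. A quick sketch to verify, using the `heart`, `dtrain`, and `dtest` objects built above:

```r
# Class proportions in the full data and in each split should be close
prop.table(table(heart$output))
prop.table(table(dtrain$output))
prop.table(table(dtest$output))
```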

2.2 Fit a Logistic Regression Model

> set.seed(123)
> fit.logit <- train(output ~ ., data = dtrain, method = "glm")
> summary(fit.logit$finalModel)
## 
## Call:
## NULL
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -2.36808  -0.42385  -0.08916   0.26064   2.89012  
## 
## Coefficients:
##               Estimate Std. Error z value Pr(>|z|)    
## (Intercept) -8.798e+00  3.756e+00  -2.342 0.019167 *  
## age         -1.739e-02  3.088e-02  -0.563 0.573473    
## sex          2.581e+00  7.263e-01   3.554 0.000379 ***
## chestpain2   1.447e+00  1.034e+00   1.400 0.161620    
## chestpain3   5.875e-01  8.565e-01   0.686 0.492806    
## chestpain4   2.566e+00  8.675e-01   2.958 0.003094 ** 
## restbp       3.294e-02  1.403e-02   2.348 0.018881 *  
## chol         1.088e-02  5.821e-03   1.870 0.061525 .  
## sugar       -1.132e+00  7.868e-01  -1.439 0.150087    
## ecg1        -1.174e+01  1.455e+03  -0.008 0.993564    
## ecg2         4.843e-01  4.769e-01   1.015 0.309935    
## maxhr       -2.576e-02  1.365e-02  -1.888 0.059067 .  
## angina       6.441e-01  5.407e-01   1.191 0.233572    
## dep          3.533e-01  2.688e-01   1.314 0.188869    
## exercise2    1.550e+00  6.072e-01   2.553 0.010684 *  
## exercise3    4.731e-01  1.153e+00   0.410 0.681527    
## pluor        1.500e+00  3.441e-01   4.361  1.3e-05 ***
## thal6       -2.093e+00  1.249e+00  -1.676 0.093769 .  
## thal7        1.586e+00  5.134e-01   3.090 0.002001 ** 
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 315.9  on 229  degrees of freedom
## Residual deviance: 126.6  on 211  degrees of freedom
## AIC: 164.6
## 
## Number of Fisher Scoring iterations: 14

The coefficient summary shows that, once the other input features are present, high-p-value variables such as age contribute little to the model.
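
One way to check a high-p-value predictor is a likelihood-ratio test comparing the full model against one without it. A sketch, refitting with `glm()` directly on the `dtrain` data built above:

```r
# Full model vs. the same model with age removed
full    <- glm(output ~ .,       data = dtrain, family = binomial)
reduced <- glm(output ~ . - age, data = dtrain, family = binomial)

# A large p-value here means dropping age barely changes the deviance
anova(reduced, full, test = "Chisq")
```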

> # Training-set accuracy
> train.hat <- ifelse(fit.logit$finalModel$fitted.values > 0.5, 1, 0)
> confusionMatrix(as.factor(train.hat), dtrain$output)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 120  13
##          1   8  89
##                                           
##                Accuracy : 0.9087          
##                  95% CI : (0.8638, 0.9426)
##     No Information Rate : 0.5565          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.8141          
##                                           
##  Mcnemar's Test P-Value : 0.3827          
##                                           
##             Sensitivity : 0.9375          
##             Specificity : 0.8725          
##          Pos Pred Value : 0.9023          
##          Neg Pred Value : 0.9175          
##              Prevalence : 0.5565          
##          Detection Rate : 0.5217          
##    Detection Prevalence : 0.5783          
##       Balanced Accuracy : 0.9050          
##                                           
##        'Positive' Class : 0               
## 
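
The manual `ifelse()` cut on `fitted.values` above is equivalent to asking caret for class probabilities and thresholding at 0.5; written that way (a sketch), the cutoff also becomes easy to tune later:

```r
# Probability of class "1" for each training row
p1 <- predict(fit.logit, newdata = dtrain, type = "prob")[, "1"]

# Hard class labels at a chosen cutoff (0.5 here, but adjustable)
train.hat2 <- factor(ifelse(p1 > 0.5, 1, 0), levels = c(0, 1))
confusionMatrix(train.hat2, dtrain$output)
```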

2.3 Test-Set Performance

> test.hat <- predict(fit.logit, newdata = dtest, type = "raw")
> confusionMatrix(test.hat, dtest$output)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  0  1
##          0 18  7
##          1  4 11
##                                          
##                Accuracy : 0.725          
##                  95% CI : (0.5611, 0.854)
##     No Information Rate : 0.55           
##     P-Value [Acc > NIR] : 0.01789        
##                                          
##                   Kappa : 0.4359         
##                                          
##  Mcnemar's Test P-Value : 0.54649        
##                                          
##             Sensitivity : 0.8182         
##             Specificity : 0.6111         
##          Pos Pred Value : 0.7200         
##          Neg Pred Value : 0.7333         
##              Prevalence : 0.5500         
##          Detection Rate : 0.4500         
##    Detection Prevalence : 0.6250         
##       Balanced Accuracy : 0.7146         
##                                          
##        'Positive' Class : 0              
## 

The model's test-set accuracy (72.5%) is well below its training-set accuracy (90.9%), a sign of overfitting. Moreover, common sense suggests that the probability of heart disease should rise with age, yet age is not significant in the model, which hints at collinearity among the predictors.
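
Collinearity can be probed with variance inflation factors. A sketch, assuming the `car` package is installed (with factor predictors, `car::vif()` reports generalized VIFs):

```r
# Refit with glm() so car::vif() can inspect the model matrix;
# GVIF^(1/(2*Df)) values much above ~2 suggest problematic collinearity
library(car)
vif(glm(output ~ ., data = dtrain, family = binomial))
```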

3. Regularization

> set.seed(123)
> fit.glmnet <- train(output ~ ., data = dtrain, method = "glmnet")
> 
> fit.glmnet$bestTune
##   alpha     lambda
## 6  0.55 0.05095497
> # Training-set performance
> confusionMatrix(predict(fit.glmnet, newdata = dtrain, type = "raw"), dtrain$output)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 119  18
##          1   9  84
##                                           
##                Accuracy : 0.8826          
##                  95% CI : (0.8338, 0.9212)
##     No Information Rate : 0.5565          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.76            
##                                           
##  Mcnemar's Test P-Value : 0.1237          
##                                           
##             Sensitivity : 0.9297          
##             Specificity : 0.8235          
##          Pos Pred Value : 0.8686          
##          Neg Pred Value : 0.9032          
##              Prevalence : 0.5565          
##          Detection Rate : 0.5174          
##    Detection Prevalence : 0.5957          
##       Balanced Accuracy : 0.8766          
##                                           
##        'Positive' Class : 0               
## 
> # Test-set performance
> confusionMatrix(predict(fit.glmnet, newdata = dtest, type = "raw"), dtest$output)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  0  1
##          0 19  5
##          1  3 13
##                                           
##                Accuracy : 0.8             
##                  95% CI : (0.6435, 0.9095)
##     No Information Rate : 0.55            
##     P-Value [Acc > NIR] : 0.0008833       
##                                           
##                   Kappa : 0.5918          
##                                           
##  Mcnemar's Test P-Value : 0.7236736       
##                                           
##             Sensitivity : 0.8636          
##             Specificity : 0.7222          
##          Pos Pred Value : 0.7917          
##          Neg Pred Value : 0.8125          
##              Prevalence : 0.5500          
##          Detection Rate : 0.4750          
##    Detection Prevalence : 0.6000          
##       Balanced Accuracy : 0.7929          
##                                           
##        'Positive' Class : 0               
## 

Test-set accuracy rises to 80%, and the gap from the training-set accuracy (88.3%) is much smaller than before, so the regularized model generalizes better than the unpenalized logistic regression.
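
caret's default glmnet grid is coarse, so a finer, cross-validated grid may be worth trying. A sketch in which `fit.glmnet2` and the grid values are arbitrary illustrative choices, not recommendations:

```r
# Hypothetical finer elastic-net grid, tuned by 10-fold cross-validation
grid <- expand.grid(alpha  = seq(0, 1, by = 0.1),
                    lambda = 10^seq(-3, 0, length.out = 30))

set.seed(123)
fit.glmnet2 <- train(output ~ ., data = dtrain, method = "glmnet",
                     tuneGrid = grid,
                     trControl = trainControl(method = "cv", number = 10))
fit.glmnet2$bestTune
```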

> # Inspect the regression coefficients (column 6 of the lambda path)
> coef(fit.glmnet$finalModel)[, 6]
##  (Intercept)          age          sex   chestpain2   chestpain3   chestpain4       restbp         chol 
## -0.202328200  0.000000000  0.000000000  0.000000000  0.000000000  0.293731204  0.000000000  0.000000000 
##        sugar         ecg1         ecg2        maxhr       angina          dep    exercise2    exercise3 
##  0.000000000  0.000000000  0.000000000 -0.003020824  0.065690202  0.041143732  0.000000000  0.000000000 
##        pluor        thal6        thal7 
##  0.100495948  0.000000000  0.385824282

Some coefficients are exactly zero: the elastic-net penalty has effectively removed those features from the model.
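
Note that indexing the coefficient matrix by column, as above, picks one point on the lambda path rather than the tuned penalty. To read the coefficients exactly at the cross-validated lambda, pass `s` explicitly (a sketch):

```r
# Coefficients evaluated at the tuned lambda from caret's search
coef(fit.glmnet$finalModel, s = fit.glmnet$bestTune$lambda)
```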

4. ROC Curve

> # Note: ROSE::roc.curve() expects (response, predicted) -- the arguments
> # below are reversed, and type = "raw" supplies hard class labels
> ROSE::roc.curve(predict(fit.glmnet, newdata = dtest, type = "raw"), dtest$output)
## Area under the curve (AUC): 0.802
[Figure: ROC curve]
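
Because `type = "raw"` yields hard class labels, the "curve" above is really a single operating point. An AUC computed from predicted probabilities is more informative; a sketch assuming the `pROC` package is installed:

```r
# Probability of the positive class ("1") on the test set
probs <- predict(fit.glmnet, newdata = dtest, type = "prob")[, "1"]

# ROC and AUC from continuous scores rather than 0/1 labels
roc.obj <- pROC::roc(response = dtest$output, predictor = probs)
roc.obj$auc
```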

Original article: https://www.haomeiwen.com/subject/chpnuktx.html