美文网首页预测Data scienceR for statistics
机器学习--有监督--支持向量机SVM

机器学习--有监督--支持向量机SVM

作者: 小贝学生信 | 来源:发表于2021-11-10 21:03 被阅读0次

支持向量机Support vector machine,SVM是一种有监督的机器学习算法,可用于分类或者回归。本次笔记以分类任务为例主要学习。

1、简单理解

  • 假设对于特征空间的样本坐标分,以样本标签为目的,存在一个超平面,将样本分为两大类,从而最大化区分两类样本。该超平面即为决策边界,而间隔就是指训练数据中最接近决策边界的样本点与决策边界之间的距离。

  • SVM过程:以间隔最大化为目的,让决策边界尽可能远离样本。然后根据这个决策边界进行样本分类预测。
    根据样本数据的复杂程度,SVM建模的过程也有所不同,具体可分为如下三类

1.1 线性可分--硬间隔

hard margin classifier,HML:此类样本数据是最简单的情况,即数据集样本可以明显地使用线性边界区分开。可参考下图,分为3个步骤

  • step1:为每个类别的分布情况绘制出外轮廓多边形;
  • step2:找出连接两个轮廓最近的两个数据点,连接;
  • step3:绘制出该连线的垂直平分线,即为最优的决策边界。而连线长度的一半即为间隔(M)

HML的特点是不允许有样本点位于间隔区内,即必须干净的划分;即不能有样本距决策边界的距离比间隔的长度还短,甚至在决策边界的另一端(误分类)

1.2 线性不可完全分--软间隔

  • 当两个类别的边界不是很明显,或者存在离群点时,如严格按照HMC,会使模型过拟合,泛化能力降低。


  • soft margin classifier,SML通过设置超参数 C表示允许间隔内存在样本的数目,优化决策边界,提高模型的泛化能力。如上所示,左图为C=0的硬间隔方法;右图为C取最大值的分类结果;
  • 针对当前数据集,可交叉验证不同C取值下的模型性能,从而确定最佳的指标。

1.3 线性不可分-核技巧

  • 上述例子数据集的决策平面都是线性的。当决策边界是一个曲线等复杂模式时,使用上述的方法无法得到满意的分类效果;

  • 针对此类情况,支持向量机可通过核方法(kernel method)寻找到合适的非线性决策边界。可分为两个步骤:



    (1)将线性不可分的数据(n个特征向量,n维)增加一个维度(enlarged kernel-induced feature space),从而成为线性可分数据(n+1)维;
    (2)在n+1 维的空间里确定合适的决策边界(线性);然后投射到原来n维空间中,即得到我们真正需要的决策平面(非线性)

  • 常见的核方法有:Radial basis function(径向基函数)、d-th degree polynomial、Hyperbolic tangent等,但通常建议使用径向基函数(extremely flexible)试一试。该方法的超参数,除了C值外,还有一个sigma,在下面代码实操中会介绍到。

2、代码实操

(1)示例数据:预测员工是否离职

library(modeldata)
data(attrition)
# initial dimension
dim(attrition)
## [1] 1470   31
library(dplyr)
df <- attrition %>%
  mutate_if(is.ordered, factor, ordered = FALSE)
# Create training (70%) and test (30%) sets
set.seed(123) # for reproducibility
library(rsample)
churn_split <- initial_split(df, prop = 0.7, strata = "Attrition")
churn_train <- training(churn_split)
churn_test <- testing(churn_split)

(2)caret包建模

  • 如上,我们使用径向基函数的核方法寻找非线性决策边界,从而建立支持向量机模型。
  • 对于超参数σ,会自动根据样本数据寻找最合适的值;对超参数C,可以通过交叉验证选择,一般备选方案为2的指数系列值(2e-2, 2e-1,2e0,2e1,2e2...)
library(caret)
set.seed(1111) # for reproducibility
# Control params for SVM
ctrl <- trainControl(
  method = "cv",
  number = 10,
  classProbs = TRUE,  #表示返回分类概率,而不是直接分类标签结果
  summaryFunction = twoClassSummary # also needed for AUC/ROC
)

churn_svm <- train(
  Attrition ~ .,
  data = churn_train,
  method = "svmRadial",
  preProcess = c("center", "scale"),
  trControl = ctrl,
  metric = "ROC", # area under ROC curve (AUC)
  tuneLength = 10) #遍历C的10次取值,即从2的-2次方到2的7次方

#如下 C取4时,模型最优
churn_svm$results %>% arrange(desc(ROC)) %>% head(1)
#     sigma C       ROC      Sens      Spec      ROCSD     SensSD     SpecSD
# 1 0.009522278 4 0.8234039 0.9791767 0.2738971 0.07462533 0.02019714 0.08679811

# Plot results
ggplot(churn_svm) 
(3)测试集验证
pred = predict(churn_svm, churn_test)
table(pred)
# pred
# No Yes 
# 415  27
table(churn_test$Attrition)
# No Yes 
# 370  72
caret::confusionMatrix(pred, churn_test$Attrition, positive="Yes")
# Accuracy : 0.871           
# 95% CI : (0.8362, 0.9008)
# No Information Rate : 0.8371          
# P-Value [Acc > NIR] : 0.02819         
# 
# Kappa : 0.3681          
# 
# Mcnemar's Test P-Value : 5.611e-09       
#                                           
#             Sensitivity : 0.29167         
#             Specificity : 0.98378         
#          Pos Pred Value : 0.77778         
#          Neg Pred Value : 0.87711         
#              Prevalence : 0.16290         
#          Detection Rate : 0.04751         
#    Detection Prevalence : 0.06109         
#       Balanced Accuracy : 0.63773         
#                                           
#        'Positive' Class : Yes

(4)衡量特征重要性

  • SVM算法本身不提供有衡量特征重要性的计算方法;
  • 可使用vip包提供的permutation test置换检验的方法,随机调整某一列(特征)值的顺序,观察预测准确率是否明显下降,从而判断特征变量的重要性。
library(vip)
prob_yes <- function(object, newdata) {
  predict(object, newdata = newdata, type = "prob")[, "Yes"]
}
# Variable importance plot
set.seed(2827) # for reproducibility
vip(churn_svm, method = "permute", nsim = 5, train = churn_train,
    target = "Attrition", metric = "auc", reference_class = "Yes",
    pred_wrapper = prob_yes)
  • pdp包观察具体某一个特征变量对于预测结果的影响
library(pdp)
features <- c("OverTime", "JobRole")
pdps <- lapply(features, function(x) {
  partial(churn_svm, pred.var = x, which.class = 2,
          prob = TRUE, plot = TRUE, plot.engine = "ggplot2") +
    coord_flip()
})
grid.arrange(grobs = pdps, nrow = 1)

image.png

相关文章

网友评论

    本文标题:机器学习--有监督--支持向量机SVM

    本文链接:https://www.haomeiwen.com/subject/hyhizltx.html