KNN是一种很简单的KNN有监督的机器学习算法;既可用于分类,也可用于回归任务。
1、KNN的简单理解
1.1 算法步骤
- (1)计算输入数据与训练数据的距离(一般欧几里得距离);
- (2)从训练集中,选取距离输入数据点最近的k个数据;
- (3)对于分类任务【常见】,取这k个训练数据类别的众数;对于回归任务,取这k个训练数据值的平均数。
1.2 特点
- (1)如上步骤,KNN没有模型训练的过程。需要预测数据时,直接与训练数据集进行计算即可。
- (2)KNN算法中最重要的一个超参数就是K的选择,会在下面具体操作中介绍。
- (3)因为需要计算距离,所以需要进行数值变量标准化,以及类别变量转化。
- (4)KNN在数据量小或者维度较小的情况下效果很好,但不适用于大规模的数据(计算量大)。
2、R代码示例
示例数据--预测员工是否离职
library(modeldata)
data(attrition)
# initial dimension
dim(attrition)
# [1] 1470 31
#去因子化
attrit <- attrition %>% mutate_if(is.ordered, factor, ordered = FALSE)
churn_split <- rsample::initial_split(attrit, prop = .7,
strata = "Attrition")
churn_train <- rsample::training(churn_split)
churn_test <- rsample::testing(churn_split)
step1:数据预处理
注意两个方面(1)数值变量标准化;(2)类别变量转化
library(recipes)
blueprint <- recipe(Attrition ~ ., data = churn_train) %>%
step_nzv(all_nominal()) %>% #去除低变异的变量
step_integer(contains("Satisfaction")) %>% #类别变量转换
step_integer(WorkLifeBalance) %>% #类别变量转换
step_integer(JobInvolvement) %>% #类别变量转换
step_dummy(all_nominal(), -all_outcomes(), one_hot = TRUE) %>% #类别变量转换
step_center(all_numeric(), -all_outcomes()) %>% #中心化
step_scale(all_numeric(), -all_outcomes()) #归一化
step2:寻找最佳k值
(1)k值的grid search遍历比较
一般来说K值取奇数:当二分类任务时,不会出现投票数相同的情况
# Create a hyperparameter grid search
str(floor(seq(1, nrow(churn_train)/3, length.out = 20)))
# num [1:20] 1 18 36 54 72 90 108 126 144 162 ...
hyper_grid <- expand.grid(
k = floor(seq(1, nrow(churn_train)/3, length.out = 20)))
(2) 交叉验证设置
repeatedcv
方法相较于之前遇到的cv
,可以理解为做多(n)次k折交叉验证,然后取n次的均值作为模型性能评价。
如下设置表示做5次k折交叉验证;每次k折分为10份,采用留一法,做10次。
# Create a resampling method
cv <- trainControl(
method = "repeatedcv",
number = 10,
repeats = 5,
classProbs = TRUE,
summaryFunction = twoClassSummary)
ggplot(knn_grid)
(3) 确定最佳k值
# Fit knn model and perform grid search
knn_grid <- train(
blueprint,
data = churn_train,
method = "knn",
trControl = cv,
tuneGrid = hyper_grid,
metric = "ROC")
knn_grid$bestTune
# 198
knn_grid$results[knn_grid$results$k==198,]
# k ROC Sens Spec ROCSD SensSD SpecSD
# 12 198 0.8041737 1 0 0.05500748 0 0
ggplot(knn_grid)
step3:预测测试集
pred = predict(knn_grid, newdata = churn_test)
confusionMatrix(pred, churn_test$Attrition)
# Confusion Matrix and Statistics
#
# Reference
# Prediction No Yes
# No 370 72
# Yes 0 0
#
# Accuracy : 0.8371
# 95% CI : (0.7993, 0.8703)
# No Information Rate : 0.8371
# P-Value [Acc > NIR] : 0.5314
#
# Kappa : 0
#
# Mcnemar's Test P-Value : <2e-16
#
# Sensitivity : 1.0000
# Specificity : 0.0000
# Pos Pred Value : 0.8371
# Neg Pred Value : NaN
# Prevalence : 0.8371
# Detection Rate : 0.8371
# Detection Prevalence : 1.0000
# Balanced Accuracy : 0.5000
#
# 'Positive' Class : No
- 如上结果:将所有员工都预测为不离职,显然是有问题的。暂时不知道是这份数据不适合这个算法,还是我在执行过程中出的问题。
网友评论