美文网首页R绘图R语言
R机器学习的Tidymodel流水线编程

R机器学习的Tidymodel流水线编程

作者: jamesjin63 | 来源:发表于2020-07-06 22:45 被阅读0次

    Tidymodels: tidy machine learning in R

    在处理数据时,有简洁的工具包,tidyverse应运而生,极大地简化数据处理流程,让数据处理变得简洁,清晰。
    但是在处理完数据后,需要对数据进行建模分析,预测与拟合,这个过程随着模型的不同而变的多元化,尤其是机器学习应用。加速了模型构建的流程化与简洁化。
    Caret的出现,让此项工作变得简洁明了。但是还是有些缺点。


    image.png

    上图基于Wickham和Grolemund撰写的《 R for Data Science》一书。
    本文中的版本详细解释了tidymodels每个程序包涵盖的步骤。在模型构建及预测过程中,tidymodels的流畅与简洁,让你体验纵享丝滑般的感受。

    在模型构建过程中,需要涉及的数据预处理及模型参数调整,这些步骤都含括在以下程序包中:

    • rsample - 数据分离重采样
    • recipes - 数据转换处理
    • parnip - 模型构建框架
    • yardstick - 模型效果评估

    下图说明了tidymodels建模步骤:


    image.png

    数据iris

    下面我们将通过iris数据来举例说明。
    首先,我们将iris数据分成训练和测试集,通过initial_split()函数实现数据拆分,可以根据prop参数,指定分离比例。分离数据后,我们可以通过training() 与testing() 函数,获取训练集和测试集的数据。

    library(tidymodels)
    
    # split
    iris_split <- initial_split(iris, prop = 0.6)
    iris_split
    
    # get training data
    iris_split %>%
      training() %>%
      glimpse()
    
    ## Observations: 90
    ## Variables: 5
    ## $ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.9, 5.4, 4…
    ## $ Sepal.Width  <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 3.1, 3.7, 3…
    ## $ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.5, 1.5, 1…
    ## $ Petal.Width  <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.1, 0.2, 0…
    ## $ Species      <fct> setosa, setosa, setosa, setosa, setosa, setosa, set…
    

    数据预处理

    recipes 包提供了多种函数,可以对数据进行预处理。包括数据的标准化,数据的相关性重复,变成亚分类变量等。

    • step_corr() - 消除相关性较高的影响
    • step_center() - 以0为中心标准化
    • step_scale() - 以1为中心标准化

    recipe还有一个好处就是,在指定数据处理时,可以用all_predictors()来指定对所有协变量进行归一化。然后all_outcomes()可以指定y。
    可以打印recipe的详细信息。里面记录了骤删除了Petal.Length变量。

    在处理完train数据后,test数据可以用bake函数进行相似的处理。然后输出为dataframe。train数据从iris_recipe输出为dataframe,可以用juice()

    # train data
    iris_recipe <- training(iris_split) %>%
      recipe(Species ~.) %>%
      step_corr(all_predictors()) %>%
      step_center(all_predictors(), -all_outcomes()) %>%
      step_scale(all_predictors(), -all_outcomes()) %>%
      prep()
      
    iris_recipe
    ## Data Recipe
    ## 
    ## Inputs:
    ## 
    ##       role #variables
    ##    outcome          1
    ##  predictor          4
    ## 
    ## Training data contained 90 data points and no missing data.
    ## 
    ## Operations:
    ## 
    ## Correlation filter removed Petal.Length [trained]
    ## Centering for Sepal.Length, Sepal.Width, Petal.Width [trained]
    ## Scaling for Sepal.Length, Sepal.Width, Petal.Width [trained]
    
    # test data
    iris_testing <- iris_recipe %>%
      bake(testing(iris_split)) 
    
    glimpse(iris_testing)
    ## Observations: 60
    ## Variables: 4
    ## $ Sepal.Length <dbl> -1.597601746, -1.138960096, 0.007644027, -0.7949788…
    ## $ Sepal.Width  <dbl> -0.41010139, 0.71517681, 2.06551064, 1.61539936, 0.…
    ## $ Petal.Width  <dbl> -1.2085003, -1.2085003, -1.2085003, -1.0796318, -1.…
    ## $ Species      <fct> setosa, setosa, setosa, setosa, setosa, setosa, set…
    

    数据建模

    在R里面,有很多关于机器学习的包,rangerrandomForest都有针对各自包的定义的参数及说明,很不方便,没有统一标准。
    tidymodels的出现,将这些机器学习的包整合到一在接口,而不是重新开发机器学习的包。更准确的说,tidymodels提供了一组用于定义模型的函数和参数。然后根据请求的建模包对模型进行拟合。
    现在我们准备根据我们的数据,建一个随机森林模型。rand_forest()函数来定义,我们的模型然后mode参数定义分类还是回归问题。mode = "classification"因为本研究是分类问题。trees可以设定节点的数。然后set_engine()很重要,可以指定我们运行的模型的引擎,可以是glm、rf等。然后用fit()函数,加载我们要拟合的数据。

    # ranger
    iris_ranger <- rand_forest(trees = 100, mode = "classification") %>%
      set_engine("ranger") %>%
      fit(Species ~ ., data = iris_training)
    
    # randomForest
    iris_rf <-  rand_forest(trees = 100, mode = "classification") %>%
      set_engine("randomForest") %>%
      fit(Species ~ ., data = iris_training)
    

    总的来说,模型构建的步骤分为三部,选定模型, set_engine 然后 fit数据。流水线式操作。

    预测

    针对arsnip的predict()函数,可以返回tibble数据格式。默认情况下,预测变量称为.pred_class。在示例中,test的数据是bake以后的--数据预处理后的testing data。然后我们将其合并入test数据集中。

    predict(iris_ranger, iris_testing)
    
    iris_ranger %>%
      predict(iris_testing) %>%
      bind_cols(iris_testing)
     
    iris_ranger
    
    ## Observations: 60
    ## Variables: 5
    ## $ .pred_class  <fct> setosa, setosa, setosa, setosa, setosa, setosa, set…
    ## $ Sepal.Length <dbl> -1.597601746, -1.138960096, 0.007644027, -0.7949788…
    ## $ Sepal.Width  <dbl> -0.41010139, 0.71517681, 2.06551064, 1.61539936, 0.…
    ## $ Petal.Width  <dbl> -1.2085003, -1.2085003, -1.2085003, -1.0796318, -1.…
    ## $ Species      <fct> setosa, setosa, setosa, setosa, setosa, setosa, set…
    
    
    iris_ranger %>%
      predict(iris_testing, type = "prob") %>%
      glimpse()
      
    ## Observations: 60
    ## Variables: 3
    ## $ .pred_setosa     <dbl> 0.677480159, 0.978293651, 0.783250000, 0.983972…
    ## $ .pred_versicolor <dbl> 0.295507937, 0.011706349, 0.150833333, 0.001111…
    ## $ .pred_virginica  <dbl> 0.02701190, 0.01000000, 0.06591667, 0.01491667,…
    
    

    该模型预测的结果为分类变量,当然有时候会根据需要,预测每个类别的概率,所以可以通过predict函数中的 type参数来输出为概率。

    模型评估

    使用metrics()函数来衡量模型的性能。它将自动选择适合给定模型类型的指标。
    该函数需要一个包含实际结果(真相)和模型预测值(估计值)的tibble数据。

    iris_ranger %>%
      predict(iris_testing) %>%
      bind_cols(iris_testing) %>%
      metrics(truth = Species, estimate = .pred_class)
    
      
    ## # A tibble: 2 x 3
    ##   .metric  .estimator .estimate
    ##   <chr>    <chr>          <dbl>
    ## 1 accuracy multiclass     0.917
    ## 2 kap      multiclass     0.874
    
    iris_rf %>%
      predict(iris_testing) %>%
      bind_cols(iris_testing) %>%
      metrics(truth = Species, estimate = .pred_class)
      
    ## # A tibble: 2 x 3
    ##   .metric  .estimator .estimate
    ##   <chr>    <chr>          <dbl>
    ## 1 accuracy multiclass     0.883
    ## 2 kap      multiclass     0.824
    
    

    绘制分类结果的图

    iris_probs%>%
      gain_curve(Species, .pred_setosa:.pred_virginica) %>%
      autoplot()
    
    iris_probs%>%
      roc_curve(Species, .pred_setosa:.pred_virginica) %>%
      autoplot()
    

    参考

    相关文章

      网友评论

        本文标题:R机器学习的Tidymodel流水线编程

        本文链接:https://www.haomeiwen.com/subject/vejnqktx.html