Training LightGBM Models with Spark-Scala

Author: 程序员的隐秘角落 | Published: 2021-12-29 10:27

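    Spark-Scala can use LightGBM for both distributed training and distributed prediction, and it supports the full range of parameter settings.
    Models can be saved, and a model saved in LightGBM's native format can be loaded from Python and other language bindings, and vice versa.
    Note that when training a LightGBM model from Spark-Scala, the training set must be a DataFrame: use org.apache.spark.ml.feature.VectorAssembler to combine the feature columns into a single features vector column, with label as a separate column.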
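
The input layout described above can be tried on a toy DataFrame. This is a minimal sketch, not the article's code: the column names f1 and f2, the two sample rows, and the local-mode session are assumptions made for illustration, and it assumes Spark is on the classpath. It also reads back the slot names that VectorAssembler records in the column metadata, which is what later allows categorical features to be specified by name.

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, SparkSession}

object AssembleSketch {
  // Combine the (hypothetical) feature columns f1 and f2 into one vector column.
  def assemble(raw: DataFrame): DataFrame =
    new VectorAssembler()
      .setInputCols(Array("f1", "f2"))
      .setOutputCol("features")
      .transform(raw)
      .select("features", "label")

  // Read the slot names that VectorAssembler stored in the column metadata.
  def slotNames(assembled: DataFrame): Seq[String] =
    AttributeGroup.fromStructField(assembled.schema("features"))
      .attributes
      .map(_.toSeq.flatMap(_.name))
      .getOrElse(Seq.empty)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]") // local mode, only for trying the sketch out
      .appName("assemble-sketch")
      .getOrCreate()
    import spark.implicits._

    // Toy rows: two feature columns plus an integer label.
    val raw = Seq((17.0, 10.0, 0), (9.0, 12.0, 1)).toDF("f1", "f2", "label")
    val assembled = assemble(raw)
    assembled.show()
    println(slotNames(assembled).mkString(", "))
    spark.stop()
  }
}
```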

    I. Environment Setup

    To use LightGBM from Spark-Scala, add the following dependencies to the pom file.

    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_${scala.version}</artifactId>
        <version>${spark.version}</version>
        <!-- exclude the pmml-model dependency bundled with spark-mllib -->
        <exclusions>
            <exclusion>
                <groupId>org.jpmml</groupId>
                <artifactId>pmml-model</artifactId>
            </exclusion>
        </exclusions>
    </dependency>
    
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>jpmml-sparkml</artifactId>
        <version>1.3.4</version>
    </dependency>
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>jpmml-lightgbm</artifactId>
        <version>1.3.4</version>
    </dependency>
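
The example code in this article imports com.microsoft.ml.spark.lightgbm, which ships in Microsoft's MMLSpark library. The jpmml artifacts listed above handle PMML conversion and do not contain that package, so a dependency along the following lines is likely needed as well. The Scala suffix and version shown here are assumptions, not taken from the original article; choose the release that matches your Spark and Scala versions.

```xml
<!-- Assumed coordinates: adjust the Scala suffix and version to your cluster -->
<dependency>
    <groupId>com.microsoft.ml.spark</groupId>
    <artifactId>mmlspark_2.11</artifactId>
    <version>0.18.1</version>
</dependency>
```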
    

    II. Example Code

    Using a binary classification problem as the example, the code below walks through the following familiar steps.

    • 1. Prepare the data
    • 2. Define the model
    • 3. Train the model
    • 4. Evaluate the model
    • 5. Use the model
    • 6. Save the model

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType, IntegerType}
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.ml.attribute.Attribute
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
    import com.microsoft.ml.spark.{lightgbm=>lgb}
    import com.google.gson.{JsonObject, JsonParser}
    import scala.collection.JavaConverters._
    
    object LgbDemo extends Serializable {
        
        def printlog(info:String): Unit ={
            val dt = new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new java.util.Date)
            println("=========="*8+dt)
            println(info+"\n")
        }
        
        def main(args:Array[String]):Unit= {
    
    
        /*================================================================================*/
        //  Step 1: load data
        /*================================================================================*/
        printlog("step1: preparing data ...")
    
        // Load the data
        val spark = SparkSession.builder().getOrCreate()
        val dfdata_raw = spark.read.option("header","true")
            .option("delimiter", "\t")
            .option("inferschema", "true")
            .option("nullValue","")
            .csv("data/breast_cancer.csv")
    
        dfdata_raw.sample(false,0.1,1).printSchema 
    
        // Assemble the feature columns into a single features vector
        val feature_cols = dfdata_raw.columns.filter(!Array("label").contains(_)) 
        val cate_cols = Array("mean_radius","mean_texture") 
    
    
        val vectorAssembler = new VectorAssembler().
          setInputCols(feature_cols).
          setOutputCol("features")
    
        val dfdata = vectorAssembler.transform(dfdata_raw).select("features", "label")
        val Array(dftrain,dfval)  = dfdata.randomSplit(Array(0.7, .3), 666)
    
        // Feature names are stored in the schema metadata, so categorical features can be specified by name
        println(dfdata.schema("features").metadata)
        dfdata.show(10) 
    
        /*================================================================================*/
        //  Step 2: define model
        /*================================================================================*/
        printlog("step2: defining model ...")
    
        val lgbclassifier = new lgb.LightGBMClassifier()
          .setNumIterations(100)
          .setLearningRate(0.1)
          .setNumLeaves(31)
          .setMinSumHessianInLeaf(0.001)
          .setMaxDepth(-1)
          .setBoostFromAverage(false)
          .setFeatureFraction(1.0)
          .setMaxBin(255)
          .setLambdaL1(0.0)
          .setLambdaL2(0.0)
          .setBaggingFraction(1.0)
          .setBaggingFreq(0)
          .setBaggingSeed(1)
          .setBoostingType("gbdt") // other options: rf, dart, goss
          .setCategoricalSlotNames(cate_cols)
          .setObjective("binary") //binary, multiclass
          .setFeaturesCol("features") 
          .setLabelCol("label")
    
        println(lgbclassifier.explainParams) 
    
    
        /*================================================================================*/
        //  Step 3: train model
        /*================================================================================*/
        printlog("step3: training model ...")
    
        val lgbmodel = lgbclassifier.fit(dftrain)
    
        val feature_importances = lgbmodel.getFeatureImportances("gain")
        val arr = feature_cols.zip(feature_importances).sortBy[Double](t=> -t._2)
        val dfimportance = spark.createDataFrame(arr).toDF("feature_name","feature_importance(gain)")
    
        dfimportance.show(100)
    
    
        /*================================================================================*/
        //  Step 4: evaluate model
        /*================================================================================*/
        printlog("step4: evaluating model ...")
    
        val evaluator = new BinaryClassificationEvaluator()
          .setLabelCol("label")
          .setRawPredictionCol("rawPrediction")
          .setMetricName("areaUnderROC")
    
        val dftrain_result = lgbmodel.transform(dftrain)
        val dfval_result = lgbmodel.transform(dfval)
    
        val train_auc  = evaluator.evaluate(dftrain_result)
        val val_auc = evaluator.evaluate(dfval_result)
        println(s"train_auc = ${train_auc}")
        println(s"val_auc = ${val_auc}")
    
    
        /*================================================================================*/
        //  Step 5: use model
        /*================================================================================*/
        printlog("step5: using model ...")
    
        // Batch prediction
        val dfpredict = lgbmodel.transform(dfval)
        dfpredict.sample(false,0.1,1).show(20)
    
        // Predict a single sample
        val features = dfval.head().getAs[Vector]("features")
        val single_result = lgbmodel.predict(features)
    
        println(single_result)
    
    
        /*================================================================================*/
        //  Step 6: save model
        /*================================================================================*/
        printlog("step6: saving model ...")
    
        // Save to the cluster (written as multiple files)
        lgbmodel.write.overwrite().save("lgbmodel.model")
        // Load the model back from the cluster
        println("load model ...")
        val lgbmodel_loaded = lgb.LightGBMClassificationModel.load("lgbmodel.model")
        val dfresult = lgbmodel_loaded.transform(dfval)
        dfresult.show() 
    
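        // Save locally as a single file in LightGBM's native format, compatible with the Python API
        //lgbmodel.saveNativeModel("lgb_model",true)
        // Load the native model (note the lgb. prefix from the import alias, and a distinct val name)
        //val lgbmodel_native = lgb.LightGBMClassificationModel.loadNativeModelFromFile("lgb_model")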
        
        }
        
    }
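
Step 3 of the listing pairs each feature name with its gain importance and sorts in descending order. The same zip-and-sort logic can be tried without Spark, as in the sketch below. The helper name rank and the extra share-of-total column are illustrative additions, not part of the original code; the sample values are rounded from the reference output in the next section. The length check is there because zip silently truncates when the two arrays differ in length.

```scala
object ImportanceSketch {
  /** Pair names with importances, sort descending, and add each feature's share of the total. */
  def rank(names: Array[String], gains: Array[Double]): Array[(String, Double, Double)] = {
    // zip would silently drop trailing entries if the lengths differed, so fail fast instead.
    require(names.length == gains.length, "one importance value per feature")
    val total = gains.sum
    names.zip(gains)
      .sortBy { case (_, g) => -g }
      .map { case (n, g) => (n, g, if (total > 0) g / total else 0.0) }
  }

  def main(args: Array[String]): Unit = {
    // A few gain values rounded from the article's reference output.
    val names = Array("mean_texture", "worst_area", "worst_perimeter")
    val gains = Array(9.27, 974.93, 885.37)
    rank(names, gains).foreach { case (n, g, share) =>
      println(f"$n%-16s gain=$g%10.2f share=$share%.3f")
    }
  }
}
```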
    

    III. Reference Output

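    Running the code above produces the following output.
    Note that println(lgbclassifier.explainParams) prints each LightGBM parameter's description along with its default and current values.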

    ================================================================================2021-07-17 22:16:29
    step1: preparing data ...
    
    root
     |-- mean_radius: integer (nullable = true)
     |-- mean_texture: integer (nullable = true)
     |-- mean_perimeter: double (nullable = true)
     |-- mean_area: double (nullable = true)
     |-- mean_smoothness: double (nullable = true)
     |-- mean_compactness: double (nullable = true)
     |-- mean_concavity: double (nullable = true)
     |-- mean_concave_points: double (nullable = true)
     |-- mean_symmetry: double (nullable = true)
     |-- mean_fractal_dimension: double (nullable = true)
     |-- radius_error: double (nullable = true)
     |-- texture_error: double (nullable = true)
     |-- perimeter_error: double (nullable = true)
     |-- area_error: double (nullable = true)
     |-- smoothness_error: double (nullable = true)
     |-- compactness_error: double (nullable = true)
     |-- concavity_error: double (nullable = true)
     |-- concave_points_error: double (nullable = true)
     |-- symmetry_error: double (nullable = true)
     |-- fractal_dimension_error: double (nullable = true)
     |-- worst_radius: double (nullable = true)
     |-- worst_texture: double (nullable = true)
     |-- worst_perimeter: double (nullable = true)
     |-- worst_area: double (nullable = true)
     |-- worst_smoothness: double (nullable = true)
     |-- worst_compactness: double (nullable = true)
     |-- worst_concavity: double (nullable = true)
     |-- worst_concave_points: double (nullable = true)
     |-- worst_symmetry: double (nullable = true)
     |-- worst_fractal_dimension: double (nullable = true)
     |-- label: integer (nullable = true)
    
    {"ml_attr":{"attrs":{"numeric":[{"idx":0,"name":"mean_radius"},{"idx":1,"name":"mean_texture"},{"idx":2,"name":"mean_perimeter"},{"idx":3,"name":"mean_area"},{"idx":4,"name":"mean_smoothness"},{"idx":5,"name":"mean_compactness"},{"idx":6,"name":"mean_concavity"},{"idx":7,"name":"mean_concave_points"},{"idx":8,"name":"mean_symmetry"},{"idx":9,"name":"mean_fractal_dimension"},{"idx":10,"name":"radius_error"},{"idx":11,"name":"texture_error"},{"idx":12,"name":"perimeter_error"},{"idx":13,"name":"area_error"},{"idx":14,"name":"smoothness_error"},{"idx":15,"name":"compactness_error"},{"idx":16,"name":"concavity_error"},{"idx":17,"name":"concave_points_error"},{"idx":18,"name":"symmetry_error"},{"idx":19,"name":"fractal_dimension_error"},{"idx":20,"name":"worst_radius"},{"idx":21,"name":"worst_texture"},{"idx":22,"name":"worst_perimeter"},{"idx":23,"name":"worst_area"},{"idx":24,"name":"worst_smoothness"},{"idx":25,"name":"worst_compactness"},{"idx":26,"name":"worst_concavity"},{"idx":27,"name":"worst_concave_points"},{"idx":28,"name":"worst_symmetry"},{"idx":29,"name":"worst_fractal_dimension"}]},"num_attrs":30}}
    +--------------------+-----+
    |            features|label|
    +--------------------+-----+
    |[17.0,10.0,122.8,...|    0|
    |[20.0,17.0,132.9,...|    0|
    |[19.0,21.0,130.0,...|    0|
    |[11.0,20.0,77.58,...|    0|
    |[20.0,14.0,135.1,...|    0|
    |[12.0,15.0,82.57,...|    0|
    |[18.0,19.0,119.6,...|    0|
    |[13.0,20.0,90.2,5...|    0|
    |[13.0,21.0,87.5,5...|    0|
    |[12.0,24.0,83.97,...|    0|
    +--------------------+-----+
    only showing top 10 rows
    
    ================================================================================2021-07-17 22:16:29
    step2: defining model ...
    
    baggingFraction: Bagging fraction (default: 1.0, current: 1.0)
    baggingFreq: Bagging frequency (default: 0, current: 0)
    baggingSeed: Bagging seed (default: 3, current: 1)
    boostFromAverage: Adjusts initial score to the mean of labels for faster convergence (default: true, current: false)
    boostingType: Default gbdt = traditional Gradient Boosting Decision Tree. Options are: gbdt, gbrt, rf (Random Forest), random_forest, dart (Dropouts meet Multiple Additive Regression Trees), goss (Gradient-based One-Side Sampling).  (default: gbdt, current: gbdt)
    categoricalSlotIndexes: List of categorical column indexes, the slot index in the features column (undefined)
    categoricalSlotNames: List of categorical column slot names, the slot name in the features column (current: [Ljava.lang.String;@351fb3fc)
    defaultListenPort: The default listen port on executors, used for testing (default: 12400)
    earlyStoppingRound: Early stopping round (default: 0)
    featureFraction: Feature fraction (default: 1.0, current: 1.0)
    featuresCol: features column name (default: features, current: features)
    initScoreCol: The name of the initial score column, used for continued training (undefined)
    isProvideTrainingMetric: Whether output metric result over training dataset. (default: false)
    isUnbalance: Set to true if training data is unbalanced in binary classification scenario (default: false)
    labelCol: label column name (default: label, current: label)
    lambdaL1: L1 regularization (default: 0.0, current: 0.0)
    lambdaL2: L2 regularization (default: 0.0, current: 0.0)
    learningRate: Learning rate or shrinkage rate (default: 0.1, current: 0.1)
    maxBin: Max bin (default: 255, current: 255)
    maxDepth: Max depth (default: -1, current: -1)
    minSumHessianInLeaf: Minimal sum hessian in one leaf (default: 0.001, current: 0.001)
    modelString: LightGBM model to retrain (default: )
    numBatches: If greater than 0, splits data into separate batches during training (default: 0)
    numIterations: Number of iterations, LightGBM constructs num_class * num_iterations trees (default: 100, current: 100)
    numLeaves: Number of leaves (default: 31, current: 31)
    objective: The Objective. For regression applications, this can be: regression_l2, regression_l1, huber, fair, poisson, quantile, mape, gamma or tweedie. For classification applications, this can be: binary, multiclass, or multiclassova.  (default: binary, current: binary)
    parallelism: Tree learner parallelism, can be set to data_parallel or voting_parallel (default: data_parallel)
    predictionCol: prediction column name (default: prediction)
    probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
    rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
    thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
    timeout: Timeout in seconds (default: 1200.0)
    useBarrierExecutionMode: Use new barrier execution mode in Beta testing, off by default. (default: false)
    validationIndicatorCol: Indicates whether the row is for training or validation (undefined)
    verbosity: Verbosity where lt 0 is Fatal, eq 0 is Error, eq 1 is Info, gt 1 is Debug (default: 1)
    weightCol: The name of the weight column (undefined)
    ================================================================================2021-07-17 22:16:29
    step3: training model ...
    
    +--------------------+------------------------+
    |        feature_name|feature_importance(gain)|
    +--------------------+------------------------+
    |          worst_area|       974.9349449056517|
    |     worst_perimeter|       885.3691593843923|
    |worst_concave_points|      255.67364284247745|
    | mean_concave_points|      250.21955942230738|
    |       worst_texture|      151.07745621304454|
    |          area_error|       65.75557372416814|
    |    worst_smoothness|       62.29973236144293|
    |     mean_smoothness|      19.902610011957194|
    |        worst_radius|        16.8275272153341|
    |           mean_area|       12.41261211467938|
    |      mean_perimeter|      12.127510878875537|
    |     worst_concavity|      11.414242858900646|
    |   compactness_error|      10.996194651604892|
    |        mean_texture|       9.274276675339683|
    |     concavity_error|       8.009578698471008|
    |      symmetry_error|        7.93458393366217|
    |        radius_error|       7.357747321194173|
    |      worst_symmetry|       5.951699663755868|
    |fractal_dimension...|       4.811246624133022|
    |concave_points_error|        4.73140145466917|
    |   worst_compactness|       4.469820723182832|
    |       texture_error|       4.356178728700959|
    |    mean_compactness|       3.123736411467967|
    |       mean_symmetry|      1.9968633063354835|
    |      mean_concavity|      1.9701941942285224|
    |    smoothness_error|       1.673042485476758|
    |worst_fractal_dim...|      1.3582115541525612|
    |mean_fractal_dime...|      0.6050912755332459|
    |     perimeter_error|      0.3889888676278275|
    |         mean_radius|    5.684356116234315...|
    +--------------------+------------------------+
    
    ================================================================================2021-07-17 22:16:30
    step4: evaluating model ...
    
    train_auc = 1.0
    val_auc = 0.9890340267698758
    ================================================================================2021-07-17 22:16:31
    step5: using model ...
    
    +--------------------+-----+--------------------+--------------------+----------+
    |            features|label|       rawPrediction|         probability|prediction|
    +--------------------+-----+--------------------+--------------------+----------+
    |[9.0,12.0,60.34,2...|    1|[-10.570726382467...|[-9.5707263824679...|       1.0|
    |[10.0,16.0,65.85,...|    1|[-10.120435089856...|[-9.1204350898567...|       1.0|
    |[10.0,21.0,68.51,...|    1|[-8.8020346337692...|[-7.8020346337692...|       1.0|
    |[11.0,14.0,73.53,...|    1|[-10.315758226759...|[-9.3157582267596...|       1.0|
    |[11.0,15.0,73.38,...|    1|[-10.086077130817...|[-9.0860771308174...|       1.0|
    |[11.0,16.0,74.72,...|    1|[-6.9649803118554...|[-5.9649803118554...|       1.0|
    |[11.0,17.0,71.25,...|    1|[-10.694667171248...|[-9.6946671712481...|       1.0|
    |[11.0,17.0,75.27,...|    1|[-9.0156792680894...|[-8.0156792680894...|       1.0|
    |[11.0,18.0,75.17,...|    1|[-5.7513546284621...|[-4.7513546284621...|       1.0|
    |[11.0,18.0,76.38,...|    1|[-4.3134421808792...|[-3.3134421808792...|       1.0|
    |[12.0,15.0,82.57,...|    0|[2.49310942805160...|[3.49310942805160...|       0.0|
    |[12.0,17.0,78.27,...|    1|[-10.516042459712...|[-9.5160424597122...|       1.0|
    |[12.0,18.0,83.19,...|    1|[-9.4899850168431...|[-8.4899850168431...|       1.0|
    |[12.0,22.0,78.75,...|    1|[-8.9917629958319...|[-7.9917629958319...|       1.0|
    |[14.0,15.0,92.68,...|    1|[-7.2724968676775...|[-6.2724968676775...|       1.0|
    |[14.0,15.0,95.77,...|    1|[-5.0143190624015...|[-4.0143190624015...|       1.0|
    |[14.0,16.0,96.22,...|    1|[-5.3849620427583...|[-4.3849620427583...|       1.0|
    |[14.0,19.0,97.83,...|    1|[-3.3292007261919...|[-2.3292007261919...|       1.0|
    |[16.0,14.0,104.3,...|    1|[4.66077729134426...|[5.66077729134426...|       0.0|
    |[19.0,24.0,122.0,...|    0|[10.1503565558166...|[11.1503565558166...|       0.0|
    +--------------------+-----+--------------------+--------------------+----------+
    
    1.0
    ================================================================================2021-07-17 22:16:31
    step6: saving model ...
    
    load model ...
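
To make the train_auc and val_auc figures from step 4 easier to interpret, here is a small sketch of what area under the ROC curve measures: the probability that a randomly chosen positive sample is scored higher than a randomly chosen negative one, with ties counted as one half. This pairwise version is an illustration only. It is not how Spark's BinaryClassificationEvaluator computes the metric internally, and the scores and labels are made up.

```scala
object AucSketch {
  /** AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5.
    * Quadratic in the number of samples, so suitable for illustration only. */
  def auc(scores: Seq[Double], labels: Seq[Int]): Double = {
    require(scores.length == labels.length, "scores and labels must align")
    val pos = scores.zip(labels).collect { case (s, 1) => s }
    val neg = scores.zip(labels).collect { case (s, 0) => s }
    require(pos.nonEmpty && neg.nonEmpty, "need at least one positive and one negative")
    val wins = for (p <- pos; n <- neg) yield
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    wins.sum / (pos.length.toLong * neg.length)
  }

  def main(args: Array[String]): Unit = {
    // Made-up scores: one positive (0.35) is ranked below one negative (0.6),
    // so 5 of the 6 positive/negative pairs are ordered correctly.
    val scores = Seq(0.9, 0.8, 0.35, 0.6, 0.1)
    val labels = Seq(1, 1, 1, 0, 0)
    println(auc(scores, labels))
  }
}
```

Read this way, a train_auc of 1.0 means every positive training sample was scored above every negative one, while the slightly lower validation value indicates a small number of misordered pairs on held-out data.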
    

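    The original post provides the code and dataset for training this binary classification model with LightGBM in Spark-Scala, as well as example code and datasets for training multiclass classification and regression models.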

          Article link: https://www.haomeiwen.com/subject/zmupqrtx.html