美文网首页
Spark.GBDT学习-GBTClassifier

Spark.GBDT学习-GBTClassifier

作者: 松鼠胃口好 | 来源:发表于2018-07-03 17:58 被阅读0次

    用于分类的GBT(Gradient-Boosted Trees)算法,基于J.H. Friedman. "Stochastic Gradient Boosting"实现,目前不支持多分类任务。Gradient Boosting vs. TreeBoost:

    • 本实现基于Stochastic Gradient Boosting(随机梯度提升),而不是TreeBoost
    • 两种方法都是通过最小化损失函数,学习树的集成
    • TreeBoost方法相对于原始方法,基于损失函数对叶节点的输出进行了额外的修改
    • Spark考虑未来实现TreeBoost

    GBTClassifier

    定义

    一个唯一标识uid,继承了Predictor类,继承了GBTClassifierParamsDefaultParamsWritableLogging特质。其中Predictor中的三个元素分别代表: 特征类型、学习器、学习到用于预测的模型

    class GBTClassifier(override val uid: String) 
    extends Predictor[Vector, GBTClassifier, GBTClassificationModel] 
    with GBTClassifierParams with DefaultParamsWritable with Logging 
    {
        def this() = this(Identifiable.randomUID("gbtc"))
        ...
    }
    

    参数

    为了兼容JAVA API,覆盖了继承自特质(with trait)的参数setter方法。

    1. TreeClassifierParams参数
    • maxDepth
      树的最大深度,0意味着只有一个叶节点,1意味着有一个内部节点+两个叶节点。
      支持:>=0
      默认:5
    • maxBins
      用于离散连续特征的最大分桶数,用于每个节点特征分裂时分裂点的选择,分桶数越大意味着粒度越高。
      支持:>=2并且>=任一类别特征的分类数
      默认:32
    • minInstancesPerNode
      分裂后每个子节点含有的最小样本数,如果分裂后左孩子或右孩子含有的样本数少于该值,则该分裂无效。
      支持:>=1
      默认:1
    • minInfoGain
      树节点分裂时的最小信息增益。
      支持:>=0.0
      默认:0.0
    • maxMemoryInMB
      每次会对一组节点进行切分,分组是按照树的层次逐步进行。每组需要切分的节点个数视内存大小而定,如果内存太小,每次只能切分一个节点。单位MB
      默认:256MB
    • cacheNodeIds
      如果为true,算法会为每个实例缓存树节点ID;如果为false,算法会将树传递给执行器用于匹配实例和树节点。缓存有利于加速训练深度较大的树,用户可以通过参数checkpointInterval设置缓存被检查的频率或者不检查。
      默认:false
    • checkpointInterval
      表示缓存的树节点ID的检查频率,当cacheNodeIds为true并且检查目录(checkpoint directory)通过sparkContext设置过才有效。
      支持:>=1或者-1代表不检查,10意味着每10次迭代检查一次。
      默认:10
    • impurity
      用于计算信息增益的准则。不支持通过GBTClassifier.setImpurity方法设置该值。
      支持:entropy、gini
      默认:gini
    1. TreeEnsembleParams参数
    • subsamplingRate
      每一次迭代训练基学习器(决策树)时所使用的训练数据集的百分比。
      支持:(0, 1]
      默认:1.0
    • seed
      随机数种子
      默认:this.getClass.getName.hashCode.toLong
    1. GBTParams参数
    • maxIter
      最大迭代次数
      支持:>=0
      默认:20
    • stepSize
      学习率(learning rate/step size)参数,用于缩小(shrinking)每个基学习器的贡献。
      支持:(0, 1]
      默认:0.1
    1. GBTClassifierParams参数
    • lossType
      GBT最小化的损失函数,不区分大小写。
      支持:logistic
      默认:logistic

    方法

    1. copy方法
      GBTClassifier的拷贝函数。
    2. train方法
      GBTClassifier类的主要方法,用于训练得到学习好的用于预测的模型。
    // @input: 训练数据, DataSet
    // @output: 学习到的模型, GBTClassificationModel
    override protected def train(dataset: Dataset[_]):
    GBTClassificationModel = {
        // 得到类别特征
        val categoricalFeatures: Map[Int, Int] =
        MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
        // 转换训练数据并进行验证
        // 将DataSet转换成RDD[LabeledPoint]
        // 只支持二分类,要求label为0或1
        val oldDataset: RDD[LabeledPoint] =
            dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
                case Row(label: Double, features: Vector) =>
                    require(label == 0 || label == 1, s"GBTClassifier was given dataset with invalid label $label.  Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.")
                LabeledPoint(label, features)
            }
        // 获得特征个数及boosting策略
        val numFeatures = oldDataset.first().features.size
        val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
        // 用于记录日志
        val instr = Instrumentation.create(this, oldDataset)
        instr.logParams(params: _*)
        instr.logNumFeatures(numFeatures)
        instr.logNumClasses(2)
        // 调用GradientBoostedTrees训练得到一组学习器及其权重
        val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed))
        // 将学到的模型封装成GBTClassificationModel并返回
        val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
        instr.logSuccess(m)
        m
    }
    

    GBTClassifier对象

    object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
        // final变量,访问支持的损失函数类型
        final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
        // 从目录中加载GBTClassifier
        override def load(path: String): GBTClassifier = super.load(path)
    }
    

    GBTClassificationModel

    用于分类的GBT模型,仅支持二分类,支持连续特征和类别特征。

    定义

    继承了PredictionModel类以及多个特质,其中PredictionModel的两个元素分别代表特征类型、学习到用于预测的模型

    class GBTClassificationModel private[ml](
        override val uid: String,
        private val _trees: Array[DecisionTreeRegressionModel],
        private val _treeWeights: Array[Double],
        override val numFeatures: Int)
    extends PredictionModel[Vector, GBTClassificationModel]
    with GBTClassifierParams 
    with TreeEnsembleModel[DecisionTreeRegressionModel]
    with MLWritable with Serializable 
    {
        // 检查_trees.nonEmpty
        // 检查_trees.length == _treeWeights.length
        val numTrees: Int = _trees.length
        ...
    }
    

    方法

    1. transformImpl方法
      首先将GBTClassificationModel进行广播,然后通过udf进行预测数据,udf中调用predict方法实现。
    override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
        // 广播本类
        val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
        val predictUDF = udf { (features: Any) =>
            // udf通过本类的predict方法实现
            bcastModel.value.predict(features.asInstanceOf[Vector])
        }
        // 使用udf将特征数据转换成预测数据
        dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
      }
    
    1. predict方法
      关键的预测方法,先得到每个基学习器的预测值,然后进行融合得到最终的预测结果,最后得到类别结果。可以看到这里得到的预测值不是概率而是类别0/1,因为label被转换成了-1/+1,所以这里通过prediction>0.0得到预测lebel。
    override protected def predict(features: Vector): Double = {
        // 得到每棵树的预测结果
        val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
        // 乘以权重之后求和得到融合结果
        val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
        // 得到预测lebel
        if (prediction > 0.0) 1.0 else 0.0
      }
    
    1. copy方法
      GBTClassificationModel的拷贝方法。
    2. toOld方法
      将ml的模型转换成mllib中老的API,ml域的私有方法。
    private[ml] def toOld: OldGBTModel = {
        new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
    }
    
    1. write方法
      调用GBTClassificationModel对象的方法保存本模型。
    override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
    

    GBTClassificationModel对象

    1. fromOld方法
      从老的API中转换出当前模型
    2. GBTClassificationModelReader
      私有类,其中的load方法用于从目录中读取模型
    3. GBTClassificationModelWriter
      私有类,其中的saveImpl方法用于将本模型保存
    4. read方法
      新建GBTClassificationModelReader
    5. load方法

    相关文章

      网友评论

          本文标题:Spark.GBDT学习-GBTClassifier

          本文链接:https://www.haomeiwen.com/subject/rqujuftx.html