用于分类的GBT(Gradient-Boosted Trees)算法,基于J.H. Friedman. "Stochastic Gradient Boosting"实现,目前不支持多分类任务。Gradient Boosting vs. TreeBoost:
- 本实现基于Stochastic Gradient Boosting(随机梯度提升),而不是TreeBoost
- 两种方法都是通过最小化损失函数,学习树的集成
- TreeBoost方法相对于原始方法,基于损失函数对叶节点的输出进行了额外的修改
- Spark考虑未来实现TreeBoost
GBTClassifier
类
定义
一个唯一标识uid,继承了Predictor
类,继承了GBTClassifierParams
、DefaultParamsWritable
、Logging
特质。其中Predictor
中的三个元素分别代表: 特征类型、学习器、学习到用于预测的模型。
class GBTClassifier(override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
with GBTClassifierParams with DefaultParamsWritable with Logging
{
def this() = this(Identifiable.randomUID("gbtc"))
...
}
参数
为了兼容JAVA API,覆盖了继承自特质(with trait)的参数setter方法。
-
TreeClassifierParams
参数
-
maxDepth
树的最大深度,0意味着只有一个叶节点,1意味着有一个内部节点+两个叶节点。
支持:>=0
默认:5 -
maxBins
用于离散连续特征的最大分桶数,用于每个节点特征分裂时分裂点的选择,分桶数越大意味着粒度越高。
支持:>=2并且>=任一类别特征的分类数
默认:32 -
minInstancesPerNode
分裂后每个子节点含有的最小样本数,如果分裂后左孩子或右孩子含有的样本数少于该值,则该分裂无效。
支持:>=1
默认:1 -
minInfoGain
树节点分裂时的最小信息增益。
支持:>=0.0
默认:0.0 -
maxMemoryInMB
每次会对一组节点进行切分,分组是按照树的层次逐步进行。每组需要切分的节点个数视内存大小而定,如果内存太小,每次只能切分一个节点。单位MB
默认:256MB -
cacheNodeIds
如果为true,算法会为每个实例缓存树节点ID;如果为false,算法会将树传递给执行器用于匹配实例和树节点。缓存有利于加速训练深度较大的树,用户可以通过参数checkpointInterval设置缓存被检查的频率或者不检查。
默认:false -
checkpointInterval
表示缓存的树节点ID的检查频率,当cacheNodeIds为true并且检查目录(checkpoint directory)通过sparkContext设置过才有效。
支持:>=1或者-1代表不检查,10意味着每10次迭代检查一次。
默认:10 -
impurity
用于计算信息增益的准则。不支持通过GBTClassifier.setImpurity方法设置该值。
支持:entropy、gini
默认:gini
-
TreeEnsembleParams
参数
-
subsamplingRate
每一次迭代训练基学习器(决策树)时所使用的训练数据集的百分比。
支持:(0, 1]
默认:1.0 -
seed
随机数种子
默认:this.getClass.getName.hashCode.toLong
-
GBTParams
参数
-
maxIter
最大迭代次数
支持:>=0
默认:20 -
stepSize
学习率(learning rate/step size)参数,用于缩小(shrinking)每个基学习器的贡献。
支持:(0, 1]
默认:0.1
-
GBTClassifierParams
参数
-
lossType
GBT最小化的损失函数,不区分大小写。
支持:logistic
默认:logistic
方法
-
copy
方法
GBTClassifier的拷贝函数。 -
train
方法
GBTClassifier
类的主要方法,用于训练得到学习好的用于预测的模型。
// @input: 训练数据, DataSet
// @output: 学习到的模型, GBTClassificationModel
override protected def train(dataset: Dataset[_]):
GBTClassificationModel = {
// 得到类别特征
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
// 转换训练数据并进行验证
// 将DataSet转换成RDD[LabeledPoint]
// 只支持二分类,要求label为0或1
val oldDataset: RDD[LabeledPoint] =
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label == 0 || label == 1, s"GBTClassifier was given dataset with invalid label $label. Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.")
LabeledPoint(label, features)
}
// 获得特征个数及boosting策略
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
// 用于记录日志
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
instr.logNumFeatures(numFeatures)
instr.logNumClasses(2)
// 调用GradientBoostedTrees训练得到一组学习器及其权重
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed))
// 将学到的模型封装成GBTClassificationModel并返回
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
instr.logSuccess(m)
m
}
GBTClassifier
对象
object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
// final变量,访问支持的损失函数类型
final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
// 从目录中加载GBTClassifier
override def load(path: String): GBTClassifier = super.load(path)
}
GBTClassificationModel
类
用于分类的GBT模型,仅支持二分类,支持连续特征和类别特征。
定义
继承了PredictionModel
类以及多个特质,其中PredictionModel
的两个元素分别代表特征类型、学习到用于预测的模型。
class GBTClassificationModel private[ml](
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
with GBTClassifierParams
with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable
{
// 检查_trees.nonEmpty
// 检查_trees.length == _treeWeights.length
val numTrees: Int = _trees.length
...
}
方法
-
transformImpl
方法
首先将GBTClassificationModel
进行广播,然后通过udf进行预测数据,udf中调用predict
方法实现。
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
// 广播本类
val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
// udf通过本类的predict方法实现
bcastModel.value.predict(features.asInstanceOf[Vector])
}
// 使用udf将特征数据转换成预测数据
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
-
predict
方法
关键的预测方法,先得到每个基学习器的预测值,然后进行融合得到最终的预测结果,最后得到类别结果。可以看到这里得到的预测值不是概率而是类别0/1,因为label被转换成了-1/+1,所以这里通过prediction>0.0得到预测lebel。
override protected def predict(features: Vector): Double = {
// 得到每棵树的预测结果
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
// 乘以权重之后求和得到融合结果
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
// 得到预测lebel
if (prediction > 0.0) 1.0 else 0.0
}
-
copy
方法
GBTClassificationModel
的拷贝方法。 -
toOld
方法
将ml的模型转换成mllib中老的API,ml域的私有方法。
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}
-
write
方法
调用GBTClassificationModel
对象的方法保存本模型。
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
GBTClassificationModel
对象
-
fromOld
方法
从老的API中转换出当前模型 -
GBTClassificationModelReader
类
私有类,其中的load
方法用于从目录中读取模型 -
GBTClassificationModelWriter
类
私有类,其中的saveImpl
方法用于将本模型保存 -
read
方法
新建GBTClassificationModelReader
类 -
load
方法
网友评论