run
方法
根据任务类型训练得到一组弱学习器及对应的权重。分类任务(目前只能处理二分类)和回归任务调用的是相同的方法进行训练,分类任务可以看作是取值范围为[-1, +1]的回归任务。基学习器是DecisionTreeRegressionModel
。
// Method to train a gradient boosting model
// @input: 训练数据集, RDD[LabelPoint]
// @input: boosting策略, boostingStrategy
// @input: 随机数种子, seed
// @output: (array of decision tree models, array of model weights)
def run(
input: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
// 根据boosting策略选择回归还是分类
// 分类和回归调用的是同一个方法,唯一的区别就是需要将分类任务的label进行转换
algo match {
case OldAlgo.Regression =>
GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
case OldAlgo.Classification =>
// 分类任务, 需要先将label映射为-1, +1. 这样二分类就可以看作是[-1, +1]的回归问题
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
}
}
runWithValidation
方法
基于验证集的训练方法。同run
方法唯一的区别就是增加了验证集。验证集要和训练集不同并且符合相同的分布(e.g.通过randomSplit
方法得到的两个数据集)。验证集的功能就是:
- 在训练过程中通过验证判断是否提前结束训练(误差减小幅度太小;误差增加过拟合)
- 选择验证误差最小的模型(第m轮在验证集上的误差最小)
- 验证集的具体使用在
boost
方法中可以看到
// 和`run`方法的区别是调用`boost`方法时,参数validate的默认值为true,并且对于分类任务训练集和验证集都要进行label转换. 当validate为false时,验证集是无效的
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
computeInitialPredictionAndError
方法
计算gradient boosting第一次迭代产生模型(学习到的第一棵树)预测值以及误差。
// @input: 训练数据集, RDD[LabeledPoint]
// @input: 第一棵树的学习率(权重), Double
// @input: 第一棵树, DecisionTreeRegressionModel
// @input: 评价标准(evaluation metric), Loss
// @output: RDD(Tuple2(prediction, error))
def computeInitialPredictionAndError(
data: RDD[LabeledPoint],
initTreeWeight: Double,
initTree: DecisionTreeRegressionModel,
loss: OldLoss): RDD[(Double, Double)] = {
data.map { lp =>
// 调用updatePrediction得到预测结果, 已有的预测值为0.0
val pred = updatePrediction(lp.features, 0.0, initTree, initTreeWeight)
// 调用Loss计算预测误差
val error = loss.computeError(pred, lp.label)
(pred, error)
}
}
updatePrediction
方法
将新一轮boosting迭代产生模型的预测值累加到之前的预测值上。
// @input: 特征,Vector
// @input: 已有的预测值(通过累加的方式集成多个弱学习的学习结果)
// @input: 新的决策树模型, DecisionTreeRegressionModel
// @input: 新模型的权重(学习率)
def updatePrediction(
features: Vector,
prediction: Double,
tree: DecisionTreeRegressionModel,
weight: Double): Double = {
// 调用决策树的预测方法得到预测值,乘以学习率,累加到已有的预测值上
prediction + tree.rootNode.predictImpl(features).prediction * weight
}
updatePredictionError
方法
根据新一轮boosting迭代产生的模型调用updatePrediction
方法得到新的预测值,并计算新的误差。
// @input: 训练数据, 新的决策树模型及权重, 评估方法
// @input: 上一轮的(预测值, 误差), 用于更新新的预测值
def updatePredictionError(
data: RDD[LabeledPoint],
predictionAndError: RDD[(Double, Double)],
treeWeight: Double,
tree: DecisionTreeRegressionModel,
loss: OldLoss): RDD[(Double, Double)] = {
val newPredError = data.zip(predictionAndError).mapPartitions {
iter => iter.map {
// zip之后, 形成(key, value), 第一个RDD的元素是key
case (lp, (pred, error)) =>
val newPred = updatePrediction(lp.features, pred, tree, treeWeight)
val newError = loss.computeError(newPred, lp.label)
(newPred, newError)
}
}
newPredError
}
computeError
方法
计算GBT的误差,该方法没有在算法中使用,但是对于Debug很有用。计算输入数据的平均误差。
// @input: 输入数据, 基学习数组, 权重数组, 评估方法
// @output: 平均误差
def computeError(
data: RDD[LabeledPoint],
trees: Array[DecisionTreeRegressionModel],
treeWeights: Array[Double],
loss: OldLoss): Double = {
data.map { lp =>
// 计算预测值foldLeft从左边遍历元素(model, weight), 初始值0, 得到累加的预测值
val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) =>
updatePrediction(lp.features, acc, model, weight)
}
// 得到预测值之后计算误差
loss.computeError(predicted, lp.label)
}.mean()
}
evaluateEachIteration
方法
为gradient boosting的每一次迭代计算误差或损失,相当于evaluate方法。该方法好像没有被调用过。
// @input: 输入数据, 基学习数组, 权重数组, 评估方法, 算法类型(回归/分类)
def evaluateEachIteration(
data: RDD[LabeledPoint],
trees: Array[DecisionTreeRegressionModel],
treeWeights: Array[Double],
loss: OldLoss,
algo: OldAlgo.Value): Array[Double] = {
val sc = data.sparkContext
// 对于二分类任务需要将label映射到-1/+1
val remappedData = algo match {
case OldAlgo.Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
case _ => data
}
// 广播trees, Kryo序列化可以注册该类
val broadcastTrees = sc.broadcast(trees)
val localTreeWeights = treeWeights
val treesIndices = trees.indices
val dataCount = remappedData.count()
// 计算每一轮迭代的平均误差
val evaluation = remappedData.map { point =>
treesIndices.map { idx =>
// 计算每一个基学习器的预测值
val prediction = broadcastTrees.value(idx)
.rootNode
.predictImpl(point.features)
.prediction
prediction * localTreeWeights(idx)
}
// 累加得到每一轮的预测值
.scanLeft(0.0)(_ + _).drop(1)
// 计算得到每一轮的误差
.map(prediction => loss.computeError(prediction, point.label))
}
// 计算所有数据每一轮的平均误差
.aggregate(treesIndices.map(_ => 0.0))(
(aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
(a, b) => treesIndices.map(idx => a(idx) + b(idx)))
.map(_ / dataCount)
broadcastTrees.destroy()
evaluation.toArray
}
该方法中的scanLeft
和aggregate
方法需要额外说明。
// reduceLeft/Right没有初始值
// foldLeft/Right有初始值
// scanLeft/Right得到累积的中间结果的集合
val abc = List("A", "B", "C")
def add(res: String, x: String) = {
println(s"op: $res + $x = ${res + x}")
res + x
}
abc.reduceLeft(add)
// op: A + B = AB
// op: AB + C = ABC
// res: String = ABC
abc.foldLeft("z")(add)
// op: z + A = zA
// op: zA + B = zAB
// op: zAB + C = zABC
// res: String = zABC
abc.scanLeft("z")(add)
// op: z + A = zA
// op: zA + B = zAB
// op: zAB + C = zABC
// res: List[String] = List(z, zA, zAB, zABC)
def aggregate(zeroValue)(seqOp, combOp)
// zeroValue是初始值
// seqOp用于计算一个分区中的结果
// combOp用户合并不同分区的结果
data.aggregate(treesIndices.map(_ => 0.0))(
(aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
(a, b) => treesIndices.map(idx => a(idx) + b(idx))
)
// 初始值是0数组, aggregated是累积结果, row是data中的每一行数据。意思就是将data中每一行相同位置的数据进行累加。之后(a, b)是合并不同分区的结果
boost
方法
最关键的方法,用户训练得到模型。
def boost(
input: RDD[LabeledPoint],
validationInput: RDD[labeledPoint],
boostingStrategy: OldBoostingStrategy,
validate: Boolean,
seed: Long):
(Array[DecisionTreeRegressionModel], Array[Double]) = {
// 验证boosting策略是否有效(只支持二分类以及回归任务,要求学习率(0,1])
boostingStrategy.assertValid()
// 初始化gradient boosting参数(基学习器是回归树)
val numIterations = boostingStrategy.numIterations
val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
val baseLearnerWeights = new Array[Double](numIterations)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
// 初始化基学习器参数,基学习器是基于方差不纯度(variance impurity)的回归树
val treeStrategy = boostingStrategy.treeStrategy.copy
val validationTol = boostingStrategy.validationTol
treeStrategy.algo = OldAlgo.Regression
treeStrategy.impurity = OldVariance
// 基学习器策略要求分类任务类别>=2, 不纯度测量方法为Gini或Entropy; 回归任务要求不纯度测量方法为variance。maxDepth>=0, maxBins>=2, minInstancesPerNode>=1, maxMemoryInMB<=10240, subsamplingRate满足(0,1]
treeStrategy.assertValid()
// 缓存训练数据(true/false标志用于判断是否需要unpersist)
val persistedInput = if (input.getStorageLevel == StorageLevel.NONE)
{
input.persist(StorageLevel.MEMORY_AND_DISK)
true
} else {
false
}
// 为训练数据和验证数据准备周期性的检查, 每进行一次迭代都更新对应的预测值和误差, 训练完成后删除所有的检查
// 学习第一棵树(第一棵树的权重为1.0)
val firstTree = new DecisionTreeRegressor().setSeed(seed)
val firstTreeModel = firstTree.train(input, treeStrategy)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
// 计算训练数据、验证数据的预测值与误差, 更新检查点
computeInitialPredictionAndError(...)
// 初始化最优验证误差以及最优的位置
var bestValidateError =
if (validate)
validatePredError.values.mean()
else 0.0
var bestM = 1
// 主循环, 循环迭代学习, 梯度提升
var m = 1 // 当前迭代
var doneLearning = false // 是否早停
while (m < numIterations && !doneLearning) {
// 基于伪残差(pseudo-residuals, 负梯度)更新数据
val data = predError.zip(input).map { case ((pred, _), point) =>
LabeledPoint(-loss.gradient(pred, point.label), point.features)
}
val dt = new DecisionTreeRegressor().setSeed(seed + m)
val model = dt.train(data, treeStrategy)
baseLearners(m) = model
// 权重设置为学习率,这种方法对于除了平方误差之外的损失函数是不正确的。权重应该针对每一种损失函数进行优化,但是这种方法尽管不是最优的,但是是合理的
baseLearnerWeights(m) = learningRate
// 计算预测和误差,更新检查
// 验证集,用于判断早停和寻找最优的模型
if (validate) {
// 计算验证集预测和误差,更新验证集的检查
validatePredError = updatePredictionError(...)
val currentValidateError = validatePredError.values.mean()
// 早停条件(开始的最优验证误差较大)
// 1. 减小的误差小于validationTol(基学习器的参数),或者
// 2. 验证误差增加(差小于0了), 模型可能过拟合
// 返回对应最优验证误差的模型
if (bestValidateError - currentValidateError < validationTol * Math.max(currentValidateError, 0.01)) {
doneLearning = true
} else if (currentValidateError < bestValidateError) {
bestValidateError = currentValidateError
bestM = m + 1
}
}
m += 1
}
// 删除所有检查点
// unpersist数据
if (persistedInput) input.unpersist()
// 返回模型
if (validate) {
(baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
} else {
(baseLearners, baseLearnerWeights)
}
}
TODO
- 为什么学习率作为权重对于平方误差来说是正确的?对于其他损失函数不是最优的但是是合理的?
网友评论