美文网首页
回归树总结

回归树总结

作者: Max_7 | 来源:发表于2018-10-28 19:58 被阅读0次

    前言

    最近刚开始看XGBoost,发现和回归树有关,这一块确实不太熟悉,于是在网上找了一些资料了解了一下。

    作用

    首先区分回归树与决策树。决策树的作用说白了是一个分类器,通过对特征的选择,划分,对数据进行分类。具体的算法这里也不再多说了,李航老师的《统计学习方法》里面讲的已经很清楚了。
    与决策树不同,回归树做的是回归,是对值的回归预测。比如可以通过回归树预测房价,或者预测人的年龄,等等。输出的是连续值,而不是离散的分类类别。

    算法

    通俗的讲一回归树的思路,找到一个最好的特征的最优的划分点,把整个数据集根据这个划分点分成大于和小于的两个子集。然后对于这两个划分后的子集再分别寻找最优的划分点。直到满足终止条件,那么回归树就构建完成了。
    首先,如何选择回归树的划分点。
    遍历数据空间的特征,和每个特征所对应的所有取值。假设将j特征的s取值处选为取值点,那么由这个切分点将得到两个区域。
    R_1(j,s)=\lbrace x|x^{(j)} \leq s \rbrace
    R_2(j,s)=\lbrace x|x^{(j)} > s \rbrace
    对于最优切分点的寻找是通过最小化目标函数。
    min[min_{c_1}\sum_{x_{i}\in R_1(j,s)}(y_i-c_1)^2+min_{c_2}\sum_{x_{i}\in R_2(j,s)}(y_i-c_2)^2]
    其中c_1c_2的计算是计算区间内的平均值。
    使用均值的原因如下:
    假设我们用L来表示区间上的损失,那么对于真实值y和区间的表示值y^{'}而言,L=-\frac{1}{2}\sum_{i=0}^{m}(y_{i}-y^{'})^{2}
    为了最小化这个损失,求导,将梯度设为0之后,可以求得结果。
    y^{'}=\frac{1}{m}\sum_{i=0}^{m}y_{i}
    所以,对于每个划分出来的区间,我们用均值来表示这个区间的值。
    接下来就是不停的重复以上的步骤,寻找特征,再寻找特征里的最优划分点,划分区域,把均值作为这个区域的输出。直到最后构建好回归树。
    下面具体看一下回归树算法的流程(图片来自《统计学习方法》),


    算法过程

    关于终止条件一直没有找到一个很确切的定义,个人理解可以人为的设定树的深度,比如当树的深度达到5层时就停止继续划分。另一种思路可以设置一个关于准确度的阈值,当整个回归树的预测准确度(误差)低于阈值时就停止进一步的划分。如果有其他的方法希望可以在下面留言回复。
    关于回归树的复杂度,假设当前的数据存在F个特征,每个特征里面有N个取值。如果生成的回归树最终有S个内部结点,那么整个的复杂度为O(F*N*S)

    代码分析

    这里使用sklearn分别构建了3棵回归树,对应的对应的深度分别为1,2,5。并将结果与线性回归做了简单的对比。


    结果分析

    从结果分析中可以看出,当树的深度为5时,很好的拟合了数据点。表现要比普通的线性回归好很多。
    对于树的深度的选择就涉及到了过拟合问题,包括了树的剪枝。后续如果遇到这些情况会再针对剪枝写篇文章总结一下。

    代码地址

    相关文章

      网友评论

          本文标题:回归树总结

          本文链接:https://www.haomeiwen.com/subject/btukzftx.html