美文网首页
树回归(二)

树回归(二)

作者: RossH | 来源:发表于2019-11-07 22:43 被阅读0次

树剪枝

一棵树如果节点过多,说明该模型存在过拟合问题。

通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。树回归(一)中的chooseBestSplit函数中的提前终止条件,实际上是一种预剪枝(prepruning)操作。另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。

预剪枝的不足

树回归(一)中的树构建算法对参数tolStolN非常敏感,下面用树回归(一)中的第一个数据集,采用不同的ops参数,来观察结果。

dataSet = loadDataSet('ex00.txt')
myMat = np.mat(dataSet)
createTree(myMat, ops=(0,1))

结果如下:

{'spInd': 0,
 'spVal': 0.48813,
 'left': {'spInd': 0,
  'spVal': 0.620599,
  'left': {'spInd': 0,
 ......
  'right': {'spInd': 0,
   'spVal': 0.325412,
   'left': {'spInd': 0, 'spVal': 0.3371, 'left': 0.1910235, 'right': 0.118208},
   'right': -0.028594120689655174}}}

由于输出过长,这里省略部分内容。

与上文中只包含两个节点的树相比,这里构建的树过于臃肿。

下面用一个与ex00.txt数据集分布类似,但y轴数量级是其100倍的ex2.txt数据集来构建树。

dataSet = loadDataSet('ex2.txt')
myMat2 = np.mat(dataSet)
createTree(myMat2)

结果如下

{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': 108.838789625,
 ......
 'right': {'spInd': 0,
  'spVal': 0.457563,
  'left': 7.969946125,
  'right': -3.6244789069767447}}

用默认参数构建的树显得比较臃肿。下面是其分布。


ex00.txtex2.txt两个数据集分布类似,但在都采用默认参数的情况下,ex00.txt构建的树只有两个叶节点,而ex2.txt却有很多。产生这种现象的原因在于,停止条件tolS对误差的数量级十分敏感。如果在选项上花费时间并对上述误差容忍度取平方值,也能得到两个叶节点的树:

createTree(myMat2, ops=(10000, 4))

output:
{'spInd': 0,
 'spVal': 0.499171,
 'left': 101.35815937735848,
 'right': -2.637719329787234}

然而,通过不断修改参数来得到合理结果并不是很好的办法。

下面将介绍后剪枝,利用测试集来对树进行剪枝,并不需要指定参数,是一种更理想化的剪枝方法。

后剪枝

剪枝函数prune()的伪代码如下:

基于已有的树切分测试数据:
    如果存在任意子集是一棵树,则在该子集递归剪枝
    计算将当前两个叶节点合并后的误差
    计算不合并的误差
    如果合并会降低误差的话,就将叶节点合并
def isTree(obj):
    return (type(obj).__name__ == 'dict')

def getMean(tree):
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    return (tree['left'] + tree['right'])/2

def prune(tree, testData):
    # 没有测试数据则对树进行塌陷处理
    if testData.shape[0] == 0: 
        return getMean(tree)
    
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): 
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): 
        tree['right'] =  prune(tree['right'], rSet)
    
    # 如果都是叶节点,看能不能合并
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(np.power(lSet[:,-1] - tree['left'],2)) +\
            sum(np.power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(np.power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge: 
            print("merging")
            return treeMean
        else: return tree
    else: return tree

isTree()判断是否为树。
getMean()从上往下遍历树直到叶节点为止。该函数对树进行塌陷处理,即返回树平均值。
接下来看看实际效果。

# 构建一个过拟合的树
myTree = createTree(myMat2, ops=(0,1))
# 加载测试集
testData = loadDataSet('ex2test.txt')
testMat = np.mat(testData)
# 剪枝
prune(myTree, testMat)

运行后观察两棵树,可以发现大量节点被剪枝掉,但没有预期那样剪枝成两部分。
一般地,为了寻求最佳模型可以同时使用两种剪枝技术。

相关文章

  • 树回归(二)

    树剪枝 一棵树如果节点过多,说明该模型存在过拟合问题。 通过降低决策树的复杂度来避免过拟合的过程称为剪枝(prun...

  • XGBoost详解

    Prerequisite: CART回归树 CART回归树是假设树为二叉树,通过不断将特征进行分裂。比如当前树结点...

  • 机器学习入门之 — 提升树

    对于回归提升树 提升树是以分类树或回归树为基分类器。它的idea在于,第一个回归树预测的效果可能一般,但是第二个回...

  • 机器学习系列(三十六)——回归决策树与决策树总结

    本篇主要内容:回归决策树原理、回归树学习曲线、决策树总结 回归决策树原理 回归决策树树是用于回归的决策树模型,回归...

  • xgboost slide(二)

    这部分主要是介绍回归树和组合: 回归树(cart)1、回归树(也可以叫做分类回归树),加入gbdt是回归树,但是根...

  • 机器学习实战之树回归

    “回归”与“树” 在讲解树回归之前,我们看看回归和树巧妙结合的原因。 线性回归的弊端 线性回归需要拟合所有样本点,...

  • 树回归

    CART算法(分类回归树) 可以用来分类,也可以用来回归 回归树 回归树与分类树的思路类似,但叶节点的数据类型不是...

  • 决策树-QA

    Q1:决策树是“二叉”树吗? 分类回归树(CART)是“二叉”分支的决策树,内部结点特征取值为“是”和“否”,左分...

  • 对于树模型的一些见解

    首先树模型根据根据基于决策树和回归树可分为分类模型和回归模型。 对于分类和回归问题来说,不同的地方在于,对回归...

  • 【机器学习实战】第9章 树回归

    第9章 树回归 树回归 概述 我们本章介绍 CART(Classification And Regression ...

网友评论

      本文标题:树回归(二)

      本文链接:https://www.haomeiwen.com/subject/ieubbctx.html