美文网首页数据挖掘与机器学习
十大数据挖掘算法之CART回归树

十大数据挖掘算法之CART回归树

作者: 清梦载星河 | 来源:发表于2020-01-28 17:38 被阅读0次

    一、CART回归树概述

    决策树算法的关键在于选择最佳划分特征及特征重最佳划分点位置,即划分算法。ID3决策树的划分算法基于信息增益,C4.5基于信息增益率,CART基于GINI系数。CART决策树算法既可用于分类,亦可用于回归。

    CART算法使用二元切分来处理连续型变量,因此用CART构建的回归树,其叶节点的数据类型不是分类树的离散型,而是连续型。

    二、 CART基尼系数算法的流程

    1. 迭代计算每个特征的每个二分切点GINI系数;
    2. 选取GINI最小的特征及对应切分点为最佳划分点;
    3. 进行一次样本划分;
    4. 对划分后的两部分样本重复以上迭代过程,逐步向下划分;
    5. 所有样本被分到叶结点中;
    6. 结束。

    二、连续型和离散型特征的树的构建

    使用一部字典来存储树的数据结构,该字典包含以下4个元素:

    • 待切分的特征;
    • 待切分的特征值;
    • 右子树。当不再需要切分的时候,也可以是单个值;
    • 左子树。与右子树类似。

    先上代码:

    import numpy as np
    
    def loadDataSet(fileName):
        '''
        此函数作用为读取一个以tab键为分隔符的文件,
        然后将每行的内容保存成浮点数
        '''
        datamat = []
        with open(fileName) as f:
            for line in f.readlines():
                curLine = line.strip().split('\t')
                # 将每行映射成浮点数
                fltLine = list(map(float,curLine))
                datamat.append(fltLine)
        return datamat
    
    def binSplitDataSet(dataSet,feature,value):
        '''
        参数分别为:数据集合、待切分的特征和该特征的某个值;
        在给定特征和特征值的情况下,该函数通过数组过滤的方式将数据集切分得到两个子集
        '''
        mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:]
        mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:]
        return mat0,mat1
    
    def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
        '''
        参数分别为:数据集和3个可选参数;
        leafType给出建立叶节点的函数;
        errType代表误差计算函数;
        ops是一个包含树构建所需其他参数的元组。
        '''
        # 满足停止条件时返回None和某类模型的值
        # 如果是回归树,则模型为一个常数;如果是模型树,则模型为一个线性方程。
        feat,val = chooseBestSplit(dataSet,leafType,errType,ops)
        if feat == None:
            return val
        retTree = {}
        retTree['spInd'] = feat
        retTree['spVal'] = val
        lSet,rSet = binSplitDataSet(dataSet,feat,val)
        retTree['left'] = createTree(lSet,leafType,errType,ops)
        retTree['right'] = createTree(rSet,leafType,errType,ops)
        return retTree
    

    上面的代码中,createTree()函数用于构建树,createTree是一个递归函数,使用后面写的chooseBestSplit()函数将数据集分为两个部分。createTree()函数的伪代码如下:

    找到最佳的待切分特征:
        如果该节点不能再分,将该节点存为叶节点
        执行二元切分
        在右子树调用createTree()方法
        在左子树调用createTree()方法
    

    三、将CART算法用于回归

    回归树的切分函数:

    def regLeaf(dataSet):
        '''
        此函数用于生成叶节点
        '''
        return np.mean(dataSet[:,-1])
    
    def regErr(dataSet):
        '''
        此函数用于生成计算目标变量的平方误差
        '''
        return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]
    
    def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
        '''
        该函数的目的是找到数据的最佳二元切分方式
        '''
        # tolS是容许的误差下降值
        # tolN是切分的最少样本数
        tolS = ops[0]
        tolN = ops[1]
        # 如果不同剩余特征值的数目为为1,那么就不需要再切分而直接返回
        if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
            return None,leafType(dataSet)
        m,n = np.shape(dataSet)
        # S用于与新切分误差进行对比,来检查新切分能否降低误差
        S = errType(dataSet)
        bestS = np.inf
        bestIndex = 0
        bestValue = 0
        for featIndex in range(n-1):
            for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]):
                mat0,mat1 = binSplitDataSet(dataSet,featIndex,splitVal)
                if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
                    continue
                newS = errType(mat0) + errType(mat1)
                if newS < bestS:
                    bestIndex = featIndex
                    bestValue = splitVal
                    bestS = newS
        if (S - bestS) < tolS:
            return None,leafType(dataSet)
        mat0,mat1 = binSplitDataSet(dataSet,bestIndex,bestValue)
        if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
            return None,leafType(dataSet)
        return bestIndex,bestValue
    

    上述代码中,chooseBestSplit()函数的目标是找到数据集切分的最佳位置,其伪代码如下:

    对每个特征:
        对每个特征值:
            将数据集切分为两份
            计算切分的误差
            如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差
    返回最佳切分的特征和阈值
    

    代码运行(Python提示符):

    >>>myDat = loadDataSet('FILENAME')
    >>>myMat = np.mat(myDat)
    >>>createTree(myMat)
    

    四、树剪枝

    一棵树如果节点过多,即每个结点划分为一个叶结点,表明该模型可能对数据进行了“过拟合”,为了避免过拟合,我们可以采用“预剪枝”和“后剪枝”两种方案。

    4.1 预剪枝

    预剪枝是指在决策树生成过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶节点。(摘自《机器学习 (周志华著)》)

    在这里,可以通过指定参数 ops 来进行预剪枝,ops 参数分别对应 tolS和tolN ,值越高,代表条件越苛刻。

    4.2 后剪枝

    后剪枝是从训练集生成一棵完整的决策树,然后自底而上地对非叶节点进行考察,若将该节点对应地子树替换为叶节点能带来决策树泛化性能提升,则将该子树替换为叶节点。(摘自《机器学习 (周志华著)》)

    虽然《机器学习》里面说的是“自底而上”找非叶节点,但这里用的是自上而下找叶节点。使用后剪枝方法需要将数据集分为测试集和训练集。首先指定参数,使得构建出的树足够大,足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差。如果是就合并。

    def isTree(obj):
        '''
        判断当前处理的节点是否为叶节点
        '''
        return (type(obj).__name__=='dict')
    
    def getMean(tree):
        '''
        从上往下遍历树直到叶节点为止。
        如果找到两个叶节点则计算它们的平均值。
        '''
        if isTree(tree['right']):
            tree['right'] = getMean(tree['right'])
        if isTree(tree['left']):
            tree['left'] = getMean(tree['left'])
        return (tree['left'] + tree['right']) / 2.0
    
    def prune(tree,testData):
        '''
        参数tree为待剪枝的树
        参数testData为剪枝所需的测试数据
        '''
        # 确认测试集是否为空
        if np.shape(testData)[0] == 0:
            return getMean(tree)
        # 如果测试集非空,则反复递归调用函数prune()对测试数据进行切分
        if (isTree(tree['right']) or isTree(tree['left'])):
            lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):
            tree['right'] =  prune(tree['right'], rSet)
        # 如果两个分支已经不再是子树,那么就可以进行合并
        # 如果合并后的误差比合并前的误差小就进行合并,反之则不合并
        if not isTree(tree['left']) and not isTree(tree['right']):
            lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
            errorNoMerge = np.sum(np.power(lSet[:,-1] - tree['left'],2)) +\
                np.sum(np.power(rSet[:,-1] - tree['right'],2))
            treeMean = (tree['left']+tree['right'])/2.0
            errorMerge = np.sum(np.power(testData[:,-1] - treeMean,2))
            if errorMerge < errorNoMerge: 
                print("merging")
                return treeMean
            else: 
                return tree
        else:
            return tree
    

    其中prune()函数的伪代码如下:

    基于已有的树切分测试数据:
        如果存在任一子集是一棵树,则在该子集递归剪枝过程
        计算将当前两个叶节点合并后的误差
        计算不合并的误差
        如果合并会降低误差的话,就将叶节点合并
    

    (具体解释和代码运行请另参阅《机器学习实战》)

    五、使用sklearn中的决策树

    # 决策树分类器
    from sklearn.tree import DecisionTreeClassifier
    # 决策树回归器
    from sklearn.tree import DecisionTreeRegressor
    
    DecisionTreeClassifier(
        criterion='gini', # 划分算法,默认是gini,另一个是entropy。如果是回归模型,参数为"mse"或"mae"
        splitter='best', # 特征中选择最佳划分点。有两个选择,"best"或者"random",best是从所有样本中选择最优,random是从随机抽样的样本中选择最优,一般为best,数据规模特别大的时候选择random
        max_depth=None, # 控制树的深度,过拟合时减少,欠拟合时增大
        min_samples_split=2, # 任意节点样本量达到多少时就不再分裂,一般默认
        min_samples_leaf=1, # 每个叶节点上最少的样本数,提高这个值可以防止过拟合
        max_features=None, # 控制所搜的特征数量,一般为默认,即不限制
        max_leaf_nodes=None, # 最大叶节点数,提高参数值可防止过拟合
        min_impurity_decrease=0.0, # 切分点不纯度最小减少量,和后剪枝相关
        min_impurity_split=None, # 切分点最小不纯度,和预剪枝相关
        class_weight=None, # 指定样本各类别的权重
    )
    

    相关文章

      网友评论

        本文标题:十大数据挖掘算法之CART回归树

        本文链接:https://www.haomeiwen.com/subject/dikmthtx.html