十分钟了解决策树

作者: littlehei | 来源:发表于2017-03-05 19:32 被阅读750次
    鸢尾花数据集上描绘决策树的决策面图

    决策树是一种非参数的监督学习方法,可用于分类和回归的应用中。旨在通过数据学习出简单的决策规则来创建模型,进而预测和判定目标变量的结果。

    ** 决策树的工作原理 **

    顾名思义,决策树以「树」这样的数据结构来进行分类或预测决策,在分类应用中,树中的每个节点构成类标签,叶子节点就是最终的分类类标号;树中的分支由决策规则组成。

    一旦决策树构造好了,对目标变量进行分类时,从树的根节点开始,依次经过不同的规则分支,达到叶子节点,该叶子节点的类标号即为该目标变量的分类结果。

    ** 如何构造决策树 **

    构造决策树的步骤为:
    1、开始的时候,所有记录当作一个节点;
    2、选择一个属性测试条件用于分割节点,往往需要遍历每一种分割方式,以找到最好的分割点;
    3、将节点分割,作为其子节点,如分割成N1和N2;
    4、对子节点,如N1和N2,继续执行第2、3步,直到节点满足停止分割的条件

    从步骤中,我们可以看出,整个算法必须解决「如何停止分割」和「如何选择分割」两个关键问题。

    为此,需要引入不纯度来度量每个节点,即度量节点中的记录够不够纯净,是否都是同一类别的记录。不纯度的计算方式包括:
    1、熵, Entropy = - sum( p(i) * log(p(i)) )
    2、Gini不纯度,Gini = 1 - sum( p(i) * p(i) )
    3、错误率,Error = 1 - max(p(i))
    注:p(i)为节点中第i种类别记录的占比。

    一种停止分割的方法是,当节点的不纯度满足一定条件时,则不再对该节点继续分割。而属性测试条件的选择也可用不纯度来衡量,当父节点和子节点不纯度差别越大时,说明分割得越好。

    若想了解具体的实现,可参考决策树的scikit-learn源码

    ** 决策树的优缺点 **

    优点
    • 简单易懂,可解释性强,且构造的树能够可视化。
    • 只需要较少的数据准备,而其他一些技术常常需要做数据标准化、哑变量的创建等等数据准备工作。
    • 使用成本低,一旦创建了树,对目标变量的决策所需要消耗的时间很少。
    • 能够同时处理数值和分类变量,其他的一些技术往往只能处理特定数据类型的变量。
    • 能处理多输出问题。
    • 可使用统计检验来验证模型,从而可保证模型的可靠性。
    缺点
    • 容易出现过拟合,特别是在构造了过于复杂的树的情况下。
    • 不够稳定,哪怕是训练数据出现了一点小的变化,最后生成的树也可能千差万别。
    • 训练寻找一个最优的决策树是一个NP完全问题,因此决策树的构造算法也多使用贪心算法,得到的往往是一个局部最优结果。
    • 对于异或、多路复用等问题,决策树表现一般,因为决策树很难去表达它们。
    剪枝

    当构造的决策树节点过多,显得过于复杂、过于「茂盛」时,就容易出现过拟合的现象,此时需要对树进行剪枝。剪枝方法包括前置剪枝和后置剪枝,也叫先剪枝和后剪枝。

    前置剪枝,即在构造树的时候就进行剪枝,这样在产生完全拟合训练数据之前就停止决策树的生长了。为做到这一点,通常采用更具有限制性的分割结束条件,如控制节点数,设置较高的分割阈值等。

    后置剪枝,初始决策树按照最大规模生长,完全拟合训练数据,然后在进行剪枝步骤,如用新的叶子节点替换子树,或用子树中最常见的分支代替子树。

    ** 示例 **

    下面示例使用决策树对鸢尾花进行分类计算,scikit-learn自带的鸢尾花数据集中有4个属性,示例中将每两两属性为一组,使用决策树从训练数据中学习得到对应的分类阈值规则,然后组成分类的决策边界。

    示例绘制了决策树分类结果的决策面,该示例来源于scikit-learn,详细代码如下:

    # -*- coding: utf-8 -*-
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier
    
    # 参数
    n_classes = 3
    plot_colors = "bry"
    plot_step = 0.02
    
    # 鸢尾花数据加载
    iris = load_iris()
    
    # 鸢尾花数据有4个属性,取两两属性为一组,遍历
    for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3],
                                    [1, 2], [1, 3], [2, 3]]):
        # 每次取两个属性的数据
        X = iris.data[:, pair]
        y = iris.target
    
        # 训练分类器
        clf = DecisionTreeClassifier().fit(X, y)
    
        # 绘制决策边界,共2行3列6个子图,每组属性画一个图
        plt.subplot(2, 3, pairidx + 1)
    
        # 生成坐标矩阵数据
        x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
        y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                             np.arange(y_min, y_max, plot_step))
    
        # 使用分类器计算每个坐标点的分类,并绘制分类结果矩阵的等高线
        Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)
        cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
    
        # 使用属性特征名作为横纵坐标名
        plt.xlabel(iris.feature_names[pair[0]])
        plt.ylabel(iris.feature_names[pair[1]])
        plt.axis("tight")
    
        # 绘制训练数据点,并用不同颜色区分
        for i, color in zip(range(n_classes), plot_colors):
            idx = np.where(y == i)
            plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],
                        cmap=plt.cm.Paired)
    
        plt.axis("tight")
    
    plt.suptitle("Decision surface of a decision tree using paired features")
    plt.legend()
    plt.show()
    

    更多有关决策树的内容见这里

    (全文完)

    相关文章

      网友评论

        本文标题:十分钟了解决策树

        本文链接:https://www.haomeiwen.com/subject/hsxdgttx.html