决策树(附代码实现)

作者: 老居搞机 | 来源:发表于2020-03-22 10:30 被阅读0次

    决策树

    决策树其实跟人做决策过程的道理是比较相似的.

    比如小明在想今天要不要出去玩这件事情,首先会看看今天是不是工作日,如果是工作日,好了还是安心去搬砖吧! 如果不是工作日那再看看当前新冠肺炎疫情情况,如果没好,那还是窝在家里比较安全,如果疫情好了那小明再去看看今天天气怎么样,如果刮风下雨温度太低什么的那也窝在家里吧,如果不下雨, ok 那小明就开开心心的出去浪了, 这个就是整个决策树的流程了.

    如图:

    这里有个关键问题是,小明要决策的条件很多(是不是刮风、是不是下雨、温度怎么样), 怎么判断先用哪个条件做决策再用哪个条件呢?换句话说这些条件的先后顺序怎么定?

    image

    在决策树中定义这些分裂条件优先级比较有名的算法有:信息增益(ID3)、信息增益率(ID4.5)、基尼系数(CART),下面分别介绍一下

    信息增益(ID3)

    ID3算法使用的是信息增益作为划分标准

    • 总样本的信息熵:
    • 某个属性的信息熵:
    • 那么信息增益就是:
    • 举个例子🌰,这里有一份小明历史上是否出去玩的数据:

    总共14条数据,其中出去玩(Yes)的次数9, 那出去玩的概率是 P(Y) = 9/14

    不出去玩(No)的次数5,那不出去玩的概率是 P(N) = 5/14

    • 那么该样本的信息熵 info(D)就是:

    样本中共有4个属性(条件): Outlook, Temperature, Humidity, Wind

    Outlook属性的信息熵

    Sunny: 总数:5, Yes:2, No:3

    • 它的信息熵就是:

    Overcast: 总数:4, Yes:4, No:0

    • 它的信息熵就是:

    Rain: 总数:5, Yes:3, No:2

    • 它的信息熵就是:
    • Outlook属性总的信息熵就是:

    • 最终算出Outlook属性信息增益是:

    使用同样方法我们可以计算出其它属性的信息增益是:

    所以Outlook的信息增益最高,所有我们切分决策树第一层的时候就选择使用使用天气属性(Outlook),下一层依然采用相同方法来寻找信息增益最大的属性

    信息增益率(ID4.5)

    ID3算法存在一个问题,就是偏向于多值属性(过拟合), 例如,如果存在唯一标识属性ID,则ID3会选择它作为分 裂属性,这样虽然使得划分充分纯净,但这种划分对分类几乎毫无用处

    ID3的后继改进算法C4.5使用信息增益率(gain ratio),试图克服这个偏倚

    • 信息增益:

    • 引入的分裂信息熵:

    • 那么信息增益率就是:

      需要注意的是,信息增益率准则对可取值数目较少的属性有所偏好,因此C4.5算法并不是直接选择增益率最大的候选划分属性,而是使用一个启发式:先从候选划分属性中找出信息增益高于平均值水平的属性,再从中选择增益率最高的属性

    基尼系数(CART)

    CART算法在分类树使用基尼指数增益, 回归树使用最小化均方差

    • 基尼指数:

    • 分裂后的基尼指数:

    • 基尼指数增益:

    我们在候选属性集合A中,选择那个使划分后基尼指数最小的属性作为最优划分属性,即 a = argmin Gini(A)(D, a)

    决策树剪枝

    为了防止决策树的过拟合, 在决策树的构造中通常会做剪枝操作,剪枝分为预剪枝和后剪枝:

    • 预剪枝:

    在构造过程中,当某个节点满足剪枝条件, 则直接停止 此分支的构造

    • 后剪枝:

    先构造完完整的决策树, 在通过某些条件遍历树进行剪枝

    附代码

    # -*- encoding: utf8 -*-
    from math import log
    ​
    ​
    class DecisionTree:
        """
        ID3算法实现的决策树
        """
    ​
        def __init__(self):
    ​
            self.__feature_labels = dict()
            self.__tree = dict()
    ​
        def cul_entropy(self, data_set):
            """
            计算信息熵
            :param data_set:
            :return: float ent
            """
            total_len = len(data_set)
            label_counts = {}
    ​
            for item in data_set:
                label_counts[item[-1]] = label_counts.get(item[-1], 0) + 1
    ​
            ent = 0
            for key in label_counts:
                label_prob = float(label_counts[key]) / total_len
                ent -= label_prob * log(label_prob, 2)
    ​
            return ent
    ​
        def cul_information_gain_label(self, before_ent, data_set, axis):
            """
            计算信息增益
            :param before_ent:
            :param data_set:
            :param axis:
            :return: float information gain
            """
            total_len = len(data_set)
            feature_values = [item[axis] for item in data_set]
            unique_feature_values = set(feature_values)
    ​
            label_ent = 0
            for feature_value in unique_feature_values:
                label_data_set = self.split_data_lable(data_set, axis, feature_value)
                label_ent += float(len(label_data_set)) / total_len * self.cul_entropy(label_data_set)
    ​
            return before_ent - label_ent
    ​
        def get_best_feature(self, data_set):
            """
            获得信息增益最大的一个特征
            :param data_set:
            :return: feature axis
            """
            features_len = len(data_set[0][:-1])
            before_ent = self.cul_entropy(data_set)
    ​
            label_gain = {}
            for axis in range(features_len):
                label_gain[axis] = self.cul_information_gain_label(before_ent, data_set, axis)
    ​
            return max(label_gain.iterkeys(), key=lambda k: label_gain[k])
    ​
        def split_data_lable(self, data_set, axis, feature_value):
            label_data_set = []
            for item in data_set:
                if item[axis] == feature_value:
                    item_vec = item[:axis]
                    item_vec.extend(item[axis + 1:])
                    label_data_set.append(item_vec)
    ​
            return label_data_set
    ​
        def create_tree_id3(self, data_set, labels):
            class_list = [item[-1] for item in data_set]
            # 预剪枝 - 所有数据属于同一类型停止分枝
            if class_list.count(class_list[0]) == len(class_list):
                return class_list[0]
    ​
            best_feature = self.get_best_feature(data_set)
            best_feature_label = labels[best_feature]
            del(labels[best_feature])
            my_tree = {best_feature_label: {}}
    ​
            feature_values = [item[best_feature] for item in data_set]
            unique_feature_values = set(feature_values)
            for feature_value in unique_feature_values:
                sub_labels = labels[:]
                my_tree[best_feature_label][feature_value] = self.create_tree_id3(self.split_data_lable(data_set, best_feature, feature_value), sub_labels)
    ​
            return my_tree
    ​
        def fit(self, data_set, labels):
            for axis in range(len(labels)):
                self.__feature_labels[labels[axis]] = axis
    ​
            self.__tree = self.create_tree_id3(data_set, labels)
            print "tree:", self.__tree
    ​
        def predict_tree(self, tree, product):
            for key in tree:
                feature_value = product[self.__feature_labels[key]]
                if isinstance(tree[key][feature_value], dict):
                    return self.predict_tree(tree[key][feature_value], product)
                else:
                    return tree[key][feature_value]
    ​
        def predict(self, products):
            return [self.predict_tree(self.__tree, product) for product in products]
    ​
    ​
    def load_data():
        labels = ['Outlook', 'Temperature', 'Humidity', 'Wind', 'Play']
        data_set = [
            ['Sunny', 'Hot', 'High', 'Weak', 'No'],
            ['Sunny', 'Hot', 'High', 'Strong', 'No'],
            ['Overcast', 'Hot', 'High', 'Weak', 'Yes'],
            ['Rain', 'Mild', 'High', 'Weak', 'Yes'],
            ['Rain', 'Cool', 'Normal', 'Weak', 'Yes'],
            ['Rain', 'Cool', 'Normal', 'Strong', 'No'],
            ['Overcast', 'Cool', 'Normal', 'Strong', 'Yes'],
            ['Sunny', 'Mild', 'High', 'Weak', 'No'],
            ['Sunny', 'Cool', 'Normal', 'Weak', 'Yes'],
            ['Rain', 'Mild', 'Normal', 'Weak', 'Yes'],
            ['Sunny', 'Mild', 'Normal', 'Strong', 'Yes'],
            ['Overcast', 'Mild', 'High', 'Strong', 'Yes'],
            ['Overcast', 'Hot', 'Normal', 'Weak', 'Yes'],
            ['Rain', 'Mild', 'High', 'Strong', 'No']
        ]
    ​
        return data_set, labels
    ​
    ​
    if __name__ == '__main__':
        data_set, labels = load_data()
        dt = DecisionTree()
        dt.fit(data_set, labels)
    ​
        products = [
                    ['Rain', 'Mild', 'High', 'Strong'],
                    ['Overcast', 'Hot', 'High', 'Strong'],
                    ]
    ​
        probs = dt.predict(products)
        print "probs:", probs
    

    参考

    • [1] 周志华.<<机器学习>>
    • [2] 李航.<<统计学习方法>>


      关注公众号

    相关文章

      网友评论

        本文标题:决策树(附代码实现)

        本文链接:https://www.haomeiwen.com/subject/wijyyhtx.html