Machine Learning (7): Decision Trees

Author: fromeast | Published 2019-08-04 20:59

1. Basic Principles

1.1 Idea and Procedure

A decision tree is a common model for classification and regression in machine learning; it is a discriminative model with a tree-shaped structure. A decision tree can be viewed as a set of if-then rules that is mutually exclusive and collectively exhaustive. It can also be read as the conditional probability distribution of the class given the features: the feature space is partitioned into disjoint cells (regions), and defining a class probability distribution on each cell yields the conditional distribution. This distribution can be written as P(Y | X), where X is the random variable over the features and Y is the random variable over the classes.

Figure: correspondence between a decision tree and the conditional probability distribution

In essence, decision tree learning induces a set of classification rules from the training set that conflicts with the training data as little as possible while still generalizing well. The loss function is usually a regularized maximum-likelihood function, and learning amounts to minimizing this loss. A decision tree algorithm typically works recursively: it selects the best feature, splits the training data according to that feature so that each resulting subset is classified as well as possible, and then repeats the process on the subsets. The basic procedure is as follows:

Figure: basic procedure of decision tree learning

1.2 Feature Selection

Feature selection means picking the features that have real discriminative power for the training data; it is the key step in decision tree learning. The most common selection criteria are information gain and information gain ratio.

The information gain is g(D, A)=H(D)-H(D | A), where A is a feature, D is the training set, H(D) is the information entropy, and H(D | A) is the conditional entropy.
The information entropy of a random variable X is H(X)=-\sum_{i=1}^{n} p_{i} \log p_{i}, and the conditional entropy of Y given X is the expectation, over X, of the entropy of the conditional distribution of Y: H(Y | X)=\sum_{i=1}^{n} p_{i} H\left(Y | X=x_{i}\right), where p_{i}=P(X=x_{i}).
The information gain ratio divides the information gain by the entropy of D with respect to the values of feature A: g_{R}(D, A)=\frac{g(D, A)}{H_{A}(D)}, where H_{A}(D)=-\sum_{i=1}^{n} \frac{\left|D_{i}\right|}{|D|} \log _{2} \frac{\left|D_{i}\right|}{|D|} and D_{i} is the subset of D on which A takes its i-th value.
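
As a quick numerical check of these definitions, the standalone sketch below (using hypothetical helper names entropy and info_gain, not the functions implemented later in this post) computes the entropy of a label column and the information gain of a single binary feature:

from collections import Counter
from math import log2

def entropy(labels):
    # Shannon entropy of a list of class labels, in bits
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def info_gain(rows, feature_idx):
    # g(D, A) = H(D) - H(D | A) for the feature stored at feature_idx
    base = entropy([r[-1] for r in rows])
    cond = 0.0
    for v in set(r[feature_idx] for r in rows):
        subset = [r[-1] for r in rows if r[feature_idx] == v]
        cond += len(subset) / len(rows) * entropy(subset)
    return base - cond

toy = [[0, 'no'], [0, 'no'], [1, 'yes'], [1, 'yes'], [1, 'no']]   # hypothetical toy data: [feature, label]
print(entropy([r[-1] for r in toy]))   # H(D) ≈ 0.971
print(info_gain(toy, 0))               # g(D, A) ≈ 0.420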

1.3 Pruning

Pruning is the main tool decision trees use against overfitting. Concretely, pruning removes some subtrees or leaf nodes from the tree that has been grown and turns their root (parent) node into a new leaf, thereby simplifying the model. Pruning comes in two flavors, pre-pruning and post-pruning. In pre-pruning, each node is evaluated before it is split during tree construction; if splitting the node does not improve the tree's generalization performance, the split is abandoned and the node is marked as a leaf. In post-pruning, a complete tree is first grown from the training set and its non-leaf nodes are then examined bottom-up; if replacing the subtree rooted at a node with a leaf improves generalization performance, the replacement is made.
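
For intuition, here is a minimal sketch of the two strategies in scikit-learn (assuming scikit-learn >= 0.22 and the breast cancer data used later in section 2.2); pre-pruning is expressed through max_depth, post-pruning through cost-complexity pruning with ccp_alpha:

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# pre-pruning: stop growing the tree once it reaches depth 4
pre = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train, y_train)

# post-pruning: grow the full tree, compute the cost-complexity pruning path,
# then refit with one of the candidate alphas (in practice chosen by cross-validation)
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]
post = DecisionTreeClassifier(ccp_alpha=alpha, random_state=0).fit(X_train, y_train)

print(pre.score(X_test, y_test), post.score(X_test, y_test))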

2. Algorithm Implementation

2.1 Manual Implementation

1. Module imports and data generation

import pickle
import operator
import matplotlib.pyplot as plt
from math import log

def createDataSet():
    dataSet = [[0, 0, 0, 0, 'no'],                      # data set
            [0, 0, 0, 1, 'no'],
            [0, 1, 0, 1, 'yes'],
            [0, 1, 1, 0, 'yes'],
            [0, 0, 0, 0, 'no'],
            [1, 0, 0, 0, 'no'],
            [1, 0, 0, 1, 'no'],
            [1, 1, 1, 1, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [2, 0, 1, 2, 'yes'],
            [2, 0, 1, 1, 'yes'],
            [2, 1, 0, 1, 'yes'],
            [2, 1, 0, 2, 'yes'],
            [2, 0, 0, 0, 'no']]
    labels = ['age', 'work', 'house', 'credit']     # feature labels
    return dataSet, labels  

def splitData(data,axis,value):
    # return the rows whose feature at position `axis` equals `value`,
    # with that feature column removed
    newData = []
    for vec in data:
        if vec[axis] == value:
            newData.append(vec[:axis]+vec[axis+1:])
    return newData
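
A quick usage check with the loan data set defined above:

data, labels = createDataSet()
print(splitData(data, 2, 1))   # the 6 rows with house == 1, 'house' column removed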

2. Computing the information entropy

def calcEntropy(data):
    # Shannon entropy of the class labels (the last column), in bits
    row = len(data)
    label = {}
    for vec in data:
        current_label = vec[-1]
        if current_label not in label.keys():
            label[current_label] = 0
        label[current_label] += 1
    entropy = 0
    for key in label:
        prob = float(label[key])/row
        entropy -= prob*log(prob,2)
    return entropy
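
On the loan data set this gives the empirical class entropy of the 9 'yes' and 6 'no' samples:

data, labels = createDataSet()
print(calcEntropy(data))   # -(9/15)*log2(9/15) - (6/15)*log2(6/15) ≈ 0.971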

3. Choosing the best feature

def chooseFeature(data):
    # return the index of the feature with the largest information gain
    features = len(data[0])-1
    entropy = calcEntropy(data)
    best_info_gain = 0.0
    best_feature = -1
    for i in range(features):
        feature_list = set([example[i] for example in data])
        temp = 0.0
        for value in feature_list:
            subdata = splitData(data,i,value)
            prob = len(subdata)/len(data)
            temp += prob*calcEntropy(subdata)       # accumulate the conditional entropy H(D|A)
        info_gain = entropy - temp
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature
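
With chooseFeature in place, the best first split on the loan data set is feature 2 ('house'), whose information gain of about 0.420 is the largest of the four features:

data, labels = createDataSet()
print(chooseFeature(data))   # 2, i.e. the 'house' feature (information gain ≈ 0.420)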

4. Finding the majority class label

def majorityClass(class_list):
    # return the most frequent class label in class_list
    class_count = {}
    for vote in class_list:
        if vote not in class_count.keys():
            class_count[vote] = 0
        class_count[vote] += 1
    sort_class = sorted(class_count.items(),key = operator.itemgetter(1),reverse = True)
    return sort_class[0][0]

5. Building the decision tree

def createTree(data,labels,features):
    class_list = [example[-1] for example in data]
    if class_list.count(class_list[0]) == len(class_list):     # all samples share one class
        return class_list[0]
    if len(data[0]) == 1 or len(labels) == 0:                  # no features left to split on
        return majorityClass(class_list)
    best_feature = chooseFeature(data)
    best_label = labels[best_feature]
    features.append(best_label)             # record the order in which features are used
    
    tree = {best_label:{}}
    del(labels[best_feature])
    
    feature_list = set([example[best_feature] for example in data])
    for value in feature_list:
        sublabels = labels[:]
        tree[best_label][value] = createTree(splitData(data,best_feature,value),sublabels,features)
    return tree
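
On the loan data set this produces the nested-dictionary tree below: applicants who own a house are approved, and the rest are decided by whether they have a job.

data, labels = createDataSet()
print(createTree(data, labels, []))
# {'house': {0: {'work': {0: 'no', 1: 'yes'}}, 1: 'yes'}}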

6. Counting leaf nodes and computing tree depth

def numLeafs(tree):
    # recursively count the leaf nodes of the nested-dict tree
    leafs = 0
    string = next(iter(tree))
    tree_dict = tree[string]
    for key in tree_dict.keys():
        if type(tree_dict[key]).__name__=='dict':
            leafs += numLeafs(tree_dict[key])
        else:
            leafs += 1
    return leafs
def treeDepth(tree):
    # recursively compute the number of decision levels in the tree
    max_depth = 0
    string = next(iter(tree))
    tree_dict = tree[string]
    for key in tree_dict.keys():
        if type(tree_dict[key]).__name__=='dict':
            depth = 1 + treeDepth(tree_dict[key])
        else:
            depth = 1
        if depth>max_depth:
            max_depth = depth
    return max_depth
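
For the loan-application tree these helpers report 3 leaves and a depth of 2; the plotting code below uses both values to lay out the figure:

data, labels = createDataSet()
tree = createTree(data, labels, [])
print(numLeafs(tree), treeDepth(tree))   # 3 2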

7. Plotting the figure, including the tree, nodes, and edge attributes

def plotNode(node_txt,centerPt,parent,node_type):
    arrow = dict(arrowstyle = '<-')
    createPlot.ax1.annotate(node_txt,xy=parent,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va='center',ha='center',bbox=node_type,arrowprops=arrow)
         
def plotText(cntr,parent,txt):
    x_mid = (parent[0]-cntr[0])/2.0 + cntr[0]
    y_mid = (parent[1]-cntr[1])/2.0 + cntr[1]
    createPlot.ax1.text(x_mid,y_mid,txt,va='center',ha='center',rotation=30)
    
def createPlot(tree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW = float(numLeafs(tree))
    plotTree.totalD = float(treeDepth(tree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(tree,(0.5,1.0),'')
    plt.show()
    
def plotTree(tree,parent,txt):
    decision_node = dict(boxstyle='sawtooth',fc='0.8')
    leaf_node = dict(boxstyle='round4',fc='0.8')
    leafs = numLeafs(tree)
    depth = treeDepth(tree)
    string = next(iter(tree))
    cntr = (plotTree.xOff + (1.0+float(leafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotText(cntr,parent,txt)
    plotNode(string,cntr,parent,decision_node)
    tree_dict = tree[string]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    
    for key in tree_dict.keys():
        if type(tree_dict[key]).__name__=='dict':
            plotTree(tree_dict[key],cntr,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(tree_dict[key],(plotTree.xOff,plotTree.yOff),cntr,leaf_node)
            plotText((plotTree.xOff,plotTree.yOff),cntr,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

8. Classifying with the decision tree and saving it to disk

def classify(tree,labels,vec):
    # descend the tree, at each node following the branch that matches vec's value for that feature
    string = next(iter(tree))
    tree_dict = tree[string]
    index = labels.index(string)
    
    for key in tree_dict.keys():
        if vec[index] == key:
            if type(tree_dict[key]).__name__=='dict':
                class_label = classify(tree_dict[key],labels,vec)
            else:
                class_label = tree_dict[key]
    return class_label

def storeTree(tree,filename):
    with open(filename,'wb') as f:
        pickle.dump(tree,f)
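
A matching loader (a small helper of our own, not part of the listing above) restores the pickled tree so it can be reused without rebuilding it:

def grabTree(filename):
    with open(filename,'rb') as f:
        return pickle.load(f)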

9. Main program

if __name__ == '__main__':
    data,labels = createDataSet()
    feature_labels = []
    tree = createTree(data,labels,feature_labels)
    createPlot(tree)
    
    test_vec = [0,1]
    result = classify(tree,feature_labels,test_vec)
    if result == 'yes':
        print('lending')
    if result == 'no':
        print('no lending')
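
Because feature_labels ends up as ['house', 'work'] (the order in which createTree used the features), the test vector [0, 1] describes an applicant without a house but with a job, so the tree returns 'yes' and the program prints 'lending'.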

The figure below is the classification diagram produced by the decision tree above, showing how the classification tree operates.

Figure: classification diagram

2.2 Using the sklearn library

The mglearn library can be used to illustrate the process of classifying animals, as follows:

import mglearn
mglearn.plots.plot_animal_tree()
Figure: an animal-classification example

Using DecisionTreeClassifier from sklearn on the breast cancer data set gives the following results:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import export_graphviz
import graphviz

cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(cancer.data,cancer.target,stratify=cancer.target,random_state=42)

tree = DecisionTreeClassifier(max_depth=4,random_state=0)
tree.fit(X_train,y_train)

print('accuracy on training set:{:.3f}'.format(tree.score(X_train,y_train)))
print('accuracy on test set:{:.3f}'.format(tree.score(X_test,y_test)))

accuracy on training set:0.988
accuracy on test set:0.951
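
For comparison, leaving max_depth unset lets the tree grow until every leaf is pure, so the training accuracy reaches 1.000 while the test accuracy typically ends up somewhat lower (the exact number depends on the split):

full_tree = DecisionTreeClassifier(random_state=0)    # no max_depth: grow until all leaves are pure
full_tree.fit(X_train,y_train)
print('accuracy on training set:{:.3f}'.format(full_tree.score(X_train,y_train)))
print('accuracy on test set:{:.3f}'.format(full_tree.score(X_test,y_test)))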

The decision process can be visualized as follows:

export_graphviz(tree,out_file='tree.dot',class_names=['malignant','benign'],feature_names=cancer.feature_names,impurity=False,filled=True)
with open('tree.dot') as f:
    dot_graph = f.read()
    
graphviz.Source(dot_graph)
Figure: decision process on the breast cancer data set

Similarly, the feature importances can be extracted and plotted:

n_features = cancer.data.shape[1]
plt.barh(range(n_features),tree.feature_importances_,align='center')
plt.yticks(np.arange(n_features),cancer.feature_names)
plt.xlabel('Feature importance')
plt.ylabel('Feature')
Figure: distribution of feature importances

3. Further Discussion

(Information) entropy, joint entropy, conditional entropy, relative entropy, and mutual information

Entropy is a measure of the uncertainty of a random variable. Let X be a discrete random variable with probability mass function p(x)=P(X=x); its entropy is H(X)=-\sum_{x} p(x) \log p(x). When the logarithm has base 2 the unit is the bit; when the base is e the unit is the nat.
If the pair of random variables (X, Y) \sim p(x, y), the joint entropy is H(X, Y)=-\sum_{x} \sum_{y} p(x, y) \log p(x, y)=-E \log p(X, Y).
Similarly, the conditional entropy is H(Y | X)=\sum_{x} p(x) H(Y | X=x)=-\sum_{x} p(x) \sum_{y} p(y | x) \log p(y | x)=-\sum_{x} \sum_{y} p(x, y) \log p(y | x)=-E \log p(Y | X).
Moreover, the chain rule H(X, Y)=H(X)+H(Y | X) holds. Proof: H(X, Y)=-\sum_{x} \sum_{y} p(x, y) \log p(x, y)=-\sum_{x} \sum_{y} p(x, y) \log p(x) p(y | x)=-\sum_{x} \sum_{y} p(x, y) \log p(x)-\sum_{x} \sum_{y} p(x, y) \log p(y | x)=-\sum_{x} p(x) \log p(x)+H(Y | X)=H(X)+H(Y | X).
Relative entropy (also known as the Kullback-Leibler divergence) is a measure of the distance between two distributions. When the true distribution is p(x) and the assumed distribution is q(x), it is D(p \| q)=\sum_{x} p(x) \log \frac{p(x)}{q(x)}. It should not be confused with the cross entropy, H(p, q)=-\sum_{x} p(x) \log q(x)=H(p)+D(p \| q).
Mutual information measures how much information one random variable carries about another, i.e. the reduction in the uncertainty of one variable once the other is known: I(X ; Y)=\sum_{x} \sum_{y} p(x, y) \log \frac{p(x, y)}{p(x) p(y)}=D(p(x, y) \| p(x) p(y)).
The relation between entropy and mutual information is I(X ; Y)=H(Y)-H(Y | X)=H(X)-H(X | Y). Proof:
I(X ; Y)=\sum_{x, y} p(x, y) \log \frac{p(x, y)}{p(x) p(y)}=\sum_{x, y} p(x, y) \log \frac{p(x | y)}{p(x)}=-\sum_{x, y} p(x, y) \log p(x)-\left(-\sum_{x, y} p(x, y) \log p(x | y)\right)=H(X)-H(X | Y).
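A small worked example makes these quantities concrete: a fair coin with p(\text{head})=p(\text{tail})=1/2 has entropy H(X)=-\frac{1}{2} \log_{2} \frac{1}{2}-\frac{1}{2} \log_{2} \frac{1}{2}=1 bit, and the relative entropy between this fair coin p and a biased coin q with q(\text{head})=0.9 is D(p \| q)=\frac{1}{2} \log_{2} \frac{0.5}{0.9}+\frac{1}{2} \log_{2} \frac{0.5}{0.1} \approx 0.737 bit, while D(q \| p) \approx 0.531 bit, illustrating that the KL divergence is not symmetric.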


"I urge you to drain one more cup of wine: west of the Yang Pass you will find no old friends." — Wang Wei, "Seeing Off Yuan the Second on a Mission to Anxi"
