
Decision Trees (Part 2)

Author: RossH | Published 2019-10-26 11:59

    Plotting Tree Diagrams with Matplotlib

    Matplotlib annotations

    Matplotlib provides a very useful annotation tool, annotations, which can be used to add text annotations to a data plot.

    import matplotlib.pyplot as plt
    
    # Font configuration so that Chinese node labels display correctly
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    
    # Define the text-box and arrow styles
    decisionNode = dict(boxstyle='sawtooth', fc='0.8')
    leafNode = dict(boxstyle='round4', fc='0.8')
    arrow_args = dict(arrowstyle = '<-')
    
    def createPlot():
        fig = plt.figure(1, facecolor='white')   # white background keeps the black arrows and text visible
        fig.clf()
        createPlot.ax1 = plt.subplot(111, frameon=False)
        plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
        plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
        plt.show()
    
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords='axes fraction',\
                               xytext = centerPt, textcoords='axes fraction',\
                               va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)
    

    Running createPlot() produces the figure shown below.

    (Figure: annotation example with a decision node and a leaf node)

    Constructing an Annotated Tree

    To plot a tree we need to know how many leaf nodes it has, so that we can size the x-axis, and how many levels deep it is, so that we can size the y-axis.

    # Get the number of leaf nodes
    def getNumLeafs(myTree):
        numLeafs = 0
        firstStr = list(myTree.keys())[0]   # the root node
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':    # a dict means an internal node
                numLeafs += getNumLeafs(secondDict[key])    # recurse into the subtree
            else:
                numLeafs += 1
        return numLeafs
    
    # Get the number of levels (depth) of the tree
    def getTreeDepth(myTree):
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth
    
    # Pre-stored trees, so we can test without rebuilding one each time
    def retrieveTree(i):
        listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                      {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                      ]
        return listOfTrees[i]
    
    # Fill in text between a parent node and its child
    def plotMidText(cntrPt, parentPt, txtString):
        xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
        yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
        createPlot.ax1.text(xMid, yMid, txtString)
        
    def plotTree(myTree, parentPt, nodeTxt):
        # Compute the width (number of leaves) and depth, which determine the x- and y-axis layout
        numLeafs = getNumLeafs(myTree)  
        depth = getTreeDepth(myTree)
        
        firstStr = list(myTree.keys())[0]     # text of the current node
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
        for key in secondDict.keys():
            if type(secondDict[key]).__name__=='dict': # a dict means an internal node; anything else is a leaf
                plotTree(secondDict[key],cntrPt,str(key))        # recurse into the subtree
            else:   
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    
    def createPlot(inTree):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
        plotTree.totalW = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        plotTree.xOff = -0.5/plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5,1.0), '')
        plt.show()
    

    The result of running the code is shown below.

    (Figure: the decision tree drawn by createPlot)
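
    For reference, a minimal way to exercise these helpers on the first pre-stored tree might look like the sketch below; the values in the comments are what the functions return for that tree.

    # Usage sketch: inspect and plot the first pre-stored tree
    myTree = retrieveTree(0)
    print(getNumLeafs(myTree))    # 3 leaf nodes: 'no', 'no', 'yes'
    print(getTreeDepth(myTree))   # depth 2
    createPlot(myTree)            # the second createPlot above, which takes a tree argument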

    Testing and Storing the Classifier

    Testing the algorithm: classifying with the decision tree

    # Classifier that walks the decision tree
    def classify(inputTree, featLabels, testVec):
        firstStr = list(inputTree.keys())[0]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel
    

    Use the retrieveTree function defined earlier to fetch a pre-stored tree, then classify with it.

    myTree = retrieveTree(0)
    classify(myTree, labels, [1,1])
    

    The result is 'yes', which is the correct classification.
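
    Note that labels here is the feature-name list built in the previous post. Since the stored tree keys on the feature names 'no surfacing' and 'flippers', a self-contained version of the call would look roughly like this:

    # Self-contained classification sketch; 'labels' is assumed to match the feature names in the stored tree
    labels = ['no surfacing', 'flippers']
    myTree = retrieveTree(0)
    print(classify(myTree, labels, [1, 0]))   # -> 'no'
    print(classify(myTree, labels, [1, 1]))   # -> 'yes'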

    Storing the Decision Tree

    Building a decision tree is time-consuming, so to save time it is best to reuse a tree that has already been built. To do that, we serialize the tree to disk with the pickle module.

    # Store the decision tree with pickle
    import pickle
    def storeTree(inputTree, filename):
        fw = open(filename, 'wb')    # pickle requires binary mode in Python 3
        pickle.dump(inputTree, fw)
        fw.close()
        
    def grabTree(filename):
        fr = open(filename, 'rb')
        return pickle.load(fr)
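
    A quick round-trip check might look like this (the filename is just an example):

    # Round-trip sketch: serialize the tree, then load it back
    storeTree(myTree, 'classifierStorage.txt')   # example filename
    print(grabTree('classifierStorage.txt'))     # prints the same tree dict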
    

    Summary

    When we start processing a dataset, we first measure its entropy, then find the best split, and repeat on each subset until all the data in a subset belongs to the same class.
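
    As a rough recap, that recursive flow can be sketched as follows, assuming helpers named chooseBestFeatureToSplit(dataSet), splitDataSet(dataSet, axis, value) and majorityCnt(classList) from the tree-building step in the previous post; this is only an outline, not code introduced in this post.

    # Outline of the recursive tree construction (helper functions assumed, see note above)
    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0]) == len(classList):
            return classList[0]                        # all examples share one class: stop
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)              # no features left: majority vote
        bestFeat = chooseBestFeatureToSplit(dataSet)   # feature with the highest information gain
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}
        subLabels = labels[:bestFeat] + labels[bestFeat+1:]
        for value in set(example[bestFeat] for example in dataSet):
            myTree[bestFeatLabel][value] = createTree(
                splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree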

    The ID3 algorithm can be used to split datasets with nominal (categorical) features.

    The example also shows that a decision tree can create too many splits of the dataset and therefore overfit. Overfitting can be reduced by pruning the tree, i.e. merging adjacent leaf nodes whose split yields little information gain.

    There are other algorithms for building decision trees as well, such as C4.5 and CART.
