美文网首页
【机器学习】决策树(Matplotlib可视化+项目实战)

【机器学习】决策树(Matplotlib可视化+项目实战)

作者: Geekero | 来源:发表于2020-09-06 14:17 被阅读0次

    接着上一篇文章【机器学习】决策树(构造篇)

    一、使用Matplotlib注释绘制树形图

    1.1 Matplotlib注解annotations

    import matplotlib.pyplot as plt
    
    #定义文本框和箭头格式
    decisionNode = dict(boxstyle='sawtooth', fc="0.8",)
    leafNode = dict(boxstyle='round4', fc='0.8')
    arrow_args = dict(arrowstyle="<-")
    
    #绘制带箭头的注解 实际的绘图功能
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        createPlot.ax1.annotate(nodeTxt, xy=parentPt, \
                                xycoords='axes fraction',
                                xytext=centerPt, textcoords='axes fraction', \
                                va="center", ha='center', bbox=nodeType, arrowprops=arrow_args)
    
    #主绘图函数
    def createPlot():
        #创建一个绘图区的全局变量
        fig = plt.figure(1, facecolor='white')
        fig.clf() #清空绘图区
        createPlot.ax1 = plt.subplot(111, frameon=False)
        plotNode('decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode)
        plotNode('leafNode', (0.8, 0.1), (0.3, 0.8), leafNode)
        plt.show()
    

    运行:

    import treePlotter
    treePlotter.createPlot()
    

    1.2 构造注释树

    • 确定x轴长度:叶节点长度
    • 确定y轴高度:树的层数

    获得叶节点的数目和数的层数:

    def getNumLeafs(myTree):
        numLeafs = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                numLeafs += getNumLeafs(secondDict[key])
            else:
                numLeafs += 1
        return numLeafs
        
    def getTreeDepth(myTree):
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
                
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth
        
    #输出预先存储的树结构
    def retrieveTree(i):
        listOfTrees = [{'no surfacing': {0:'no', 1:{'flippers':\
                                                   {0:'no', 1:'yes'}}}},
                       {'no surfacing': {0:'no', 1:{'flippers':\
                                                   {0:{'head':{0:'no', 1:'yes'}}, 1:'no'}}}}
                      ]
        return listOfTrees[i]
    

    运行:

    treePlotter.retrieveTree(0)
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    
    treePlotter.retrieveTree(1)
    {'no surfacing': {0: 'no',
      1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
    
    myTree = treePlotter.retrieveTree(0)
    treePlotter.getNumLeafs(myTree)
    3
    
    treePlotter.getTreeDepth(myTree)
    2
    

    plotTree函数:

    #在父节点间填充文本信息
    def plotMidText(cntrPt, parentPt, txtString):
        xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
        yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
        createPlot.ax1.text(xMid, yMid, txtString)
        
    def plotTree(myTree, parentPt, nodeTxt):
        numLeafs = getNumLeafs(myTree) #计算宽
        depth = getTreeDepth(myTree) #计算高
        firstStr = list(myTree.keys())[0]
        #计算已经绘制的节点的位置,以及放置下一个节点的恰当位置
        #通过计算树所包含的所有叶子节点数,划分图形的宽度,从而计算得到当前节点的中心位置
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
        # 标记子节点的属性值
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        # 按比例减少全局变量plotTree.yOff,并标注此处需要绘制子节点
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #依次递减y坐标
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                plotTree(secondDict[key], cntrPt, str(key))
            else:
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #在绘制完所有子节点以后,增加全局变量Y的偏移值
        
    def createPlot(inTree):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
        plotTree.totalW = float(getNumLeafs(inTree)) #存储树的宽度
        plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
        #追踪已经绘制的节点的位置
        plotTree.xOff = -0.5/plotTree.totalW 
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()
    

    运行:

    treePlotter.createPlot(myTree)
    
    myTree['no surfacing'][3]='maybe'
    myTree
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
    
    treePlotter.createPlot(myTree)
    

    二、测试和存储分类器

    2.1 测试算法:使用决策树执行分类

    生成特征标签列表帮助程序确定特征在数据集中的位置

    #使用决策树的分类函数
    def classify(inputTree, featLabels, testVec):
        firstStr = list(inputTree.keys())[0]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr) #寻找当前列表中第一个匹配firstStr变量的元素
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel
    

    运行:

    >>> reload(trees)
    <module 'trees' from 'D:\\Data\\File\\ML\\trees\\trees.py'>
    >>> myDat, labels = trees.createDataSet()
    >>> labels
    ['no surfacing', 'flippers']
    >>> myTree=trees.retrieveTree(0)
    >>> myTree
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    >>> trees.classify(myTree, labels, [1,0])
    'no'
    >>> trees.classify(myTree, labels, [1,1])
    'yes'
    

    2.2 使用算法:在硬盘中存储决策树分类器

    使用pickle模块存储决策树

    def storeTree(inputTree, filename):
        import pickle
        fw = open(filename, 'wb')
        pickle.dump(inputTree, fw)
        fw.close()
        
    def grapTree(filename):
        import pickle
        fr = open(filename, 'rb')
        return pickle.load(fr)
    

    运行:

    >>> reload(trees)
    <module 'trees' from 'D:\\Data\\File\\ML\\trees\\trees.py'>
    >>> trees.storeTree(myTree, 'classifierStorage.txt')
    >>> trees.grapTree('classifierStorage.txt')
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    

    这也是决策树的优点之一,k-近邻法无法持续化分类器

    三、项目实战:使用决策树预测患者需要佩戴的隐形眼镜类型

    步骤:

    • 收集数据:提供的文本文件
    • 准备数据:解析tab键分隔的数据行
    • 分析数据:快熟检查数据,确保正确地解析数据内容,使用createPlot()函数绘制最终的树形图
    • 训练数据:使用createTree()函数
    • 测试数据:编写测试函数1验证决策树可以正确分类给定的数据实例
    • 使用算法:存储树的数据结构,以便于下次使用时无需重新造树
    >>> import trees
    >>> fr=open('lenses.txt')
    >>> lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    >>> lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    >>> lensesTree = trees.createTree(lenses, lensesLabels)
    >>> lensesTree
    {'tearRate': {'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}, 'reduced': 'no lenses'}}
    >>> treePlotter.createPlot(lensesTree)
    
    由ID3算法得到的决策树

    问题与改进

    这里使用的是ID3算法的决策树,很好的算法但不完美:

    1. 匹配选项太多, 过度匹配(overfitting)
    2. 解决过度匹配的问题,可以裁剪决策树,去掉没有必要的叶子节点
    3. 如果叶子节点只能增加少许信息,则可以删除该节点,并将它并入其他叶子节点中。也就是,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配的问题。
    4. ID3算法无法直接处理数值型数据,虽然可以通过量化的方法将数值型转化为标称型数值,但如果存在太多的特征划分,ID3仍然会面临其他问题
    5. 可以采取C4.5,或CART算法
    总结
    • k-近邻法和决策树都是将数据实例明确划分到某个分类
    • 像朴素贝叶斯算法,将不能完全确定的数据实例划分带某个分类,或者只能给出数据实例属于给定分类的概率

    学习自:《机器学习实战》

    相关文章

      网友评论

          本文标题:【机器学习】决策树(Matplotlib可视化+项目实战)

          本文链接:https://www.haomeiwen.com/subject/qugfektx.html