美文网首页
【机器学习】决策树(Matplotlib可视化+项目实战)

【机器学习】决策树(Matplotlib可视化+项目实战)

作者: Geekero | 来源:发表于2020-09-06 14:17 被阅读0次

接着上一篇文章【机器学习】决策树(构造篇)

一、使用Matplotlib注释绘制树形图

1.1 Matplotlib注解annotations

import matplotlib.pyplot as plt

#定义文本框和箭头格式
decisionNode = dict(boxstyle='sawtooth', fc="0.8",)
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle="<-")

#绘制带箭头的注解 实际的绘图功能
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, \
                            xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction', \
                            va="center", ha='center', bbox=nodeType, arrowprops=arrow_args)

#主绘图函数
def createPlot():
    #创建一个绘图区的全局变量
    fig = plt.figure(1, facecolor='white')
    fig.clf() #清空绘图区
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('leafNode', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

运行:

import treePlotter
treePlotter.createPlot()

1.2 构造注释树

  • 确定x轴长度:叶节点长度
  • 确定y轴高度:树的层数

获得叶节点的数目和数的层数:

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
    
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
            
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
    
#输出预先存储的树结构
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0:'no', 1:{'flippers':\
                                               {0:'no', 1:'yes'}}}},
                   {'no surfacing': {0:'no', 1:{'flippers':\
                                               {0:{'head':{0:'no', 1:'yes'}}, 1:'no'}}}}
                  ]
    return listOfTrees[i]

运行:

treePlotter.retrieveTree(0)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

treePlotter.retrieveTree(1)
{'no surfacing': {0: 'no',
  1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}

myTree = treePlotter.retrieveTree(0)
treePlotter.getNumLeafs(myTree)
3

treePlotter.getTreeDepth(myTree)
2

plotTree函数:

#在父节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
    
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree) #计算宽
    depth = getTreeDepth(myTree) #计算高
    firstStr = list(myTree.keys())[0]
    #计算已经绘制的节点的位置,以及放置下一个节点的恰当位置
    #通过计算树所包含的所有叶子节点数,划分图形的宽度,从而计算得到当前节点的中心位置
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    # 标记子节点的属性值
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # 按比例减少全局变量plotTree.yOff,并标注此处需要绘制子节点
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #依次递减y坐标
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #在绘制完所有子节点以后,增加全局变量Y的偏移值
    
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree)) #存储树的宽度
    plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
    #追踪已经绘制的节点的位置
    plotTree.xOff = -0.5/plotTree.totalW 
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

运行:

treePlotter.createPlot(myTree)
myTree['no surfacing'][3]='maybe'
myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}

treePlotter.createPlot(myTree)

二、测试和存储分类器

2.1 测试算法:使用决策树执行分类

生成特征标签列表帮助程序确定特征在数据集中的位置

#使用决策树的分类函数
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr) #寻找当前列表中第一个匹配firstStr变量的元素
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

运行:

>>> reload(trees)
<module 'trees' from 'D:\\Data\\File\\ML\\trees\\trees.py'>
>>> myDat, labels = trees.createDataSet()
>>> labels
['no surfacing', 'flippers']
>>> myTree=trees.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree, labels, [1,0])
'no'
>>> trees.classify(myTree, labels, [1,1])
'yes'

2.2 使用算法:在硬盘中存储决策树分类器

使用pickle模块存储决策树

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()
    
def grapTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

运行:

>>> reload(trees)
<module 'trees' from 'D:\\Data\\File\\ML\\trees\\trees.py'>
>>> trees.storeTree(myTree, 'classifierStorage.txt')
>>> trees.grapTree('classifierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

这也是决策树的优点之一,k-近邻法无法持续化分类器

三、项目实战:使用决策树预测患者需要佩戴的隐形眼镜类型

步骤:

  • 收集数据:提供的文本文件
  • 准备数据:解析tab键分隔的数据行
  • 分析数据:快熟检查数据,确保正确地解析数据内容,使用createPlot()函数绘制最终的树形图
  • 训练数据:使用createTree()函数
  • 测试数据:编写测试函数1验证决策树可以正确分类给定的数据实例
  • 使用算法:存储树的数据结构,以便于下次使用时无需重新造树
>>> import trees
>>> fr=open('lenses.txt')
>>> lenses = [inst.strip().split('\t') for inst in fr.readlines()]
>>> lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
>>> lensesTree = trees.createTree(lenses, lensesLabels)
>>> lensesTree
{'tearRate': {'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}, 'reduced': 'no lenses'}}
>>> treePlotter.createPlot(lensesTree)
由ID3算法得到的决策树

问题与改进

这里使用的是ID3算法的决策树,很好的算法但不完美:

  1. 匹配选项太多, 过度匹配(overfitting)
  2. 解决过度匹配的问题,可以裁剪决策树,去掉没有必要的叶子节点
  3. 如果叶子节点只能增加少许信息,则可以删除该节点,并将它并入其他叶子节点中。也就是,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配的问题。
  4. ID3算法无法直接处理数值型数据,虽然可以通过量化的方法将数值型转化为标称型数值,但如果存在太多的特征划分,ID3仍然会面临其他问题
  5. 可以采取C4.5,或CART算法
总结
  • k-近邻法和决策树都是将数据实例明确划分到某个分类
  • 像朴素贝叶斯算法,将不能完全确定的数据实例划分带某个分类,或者只能给出数据实例属于给定分类的概率

学习自:《机器学习实战》

相关文章

网友评论

      本文标题:【机器学习】决策树(Matplotlib可视化+项目实战)

      本文链接:https://www.haomeiwen.com/subject/qugfektx.html