用Matplotlib绘制树形图
Matplotlib annotations
Matplotlib提供了一个非常有用的注解工具annotations,它可以在数据图形上添加文本注解。
import matplotlib.pyplot as plt
# Font configuration so that Chinese labels can be displayed in figures.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Text-box and arrow formats shared by the plotting functions below.
decisionNode = {'boxstyle': 'sawtooth', 'fc': '0.8'}  # box style for decision nodes
leafNode = {'boxstyle': 'round4', 'fc': '0.8'}  # box style for leaf nodes
arrow_args = {'arrowstyle': '<-'}  # passed as arrowprops when annotating
def createPlot():
    """Draw a small demo figure containing one decision node and one leaf node.

    The axes are stored on the function itself (``createPlot.ax1``) so that
    ``plotNode`` can draw onto them without receiving them as an argument.

    NOTE: a later ``createPlot(inTree)`` definition in this file replaces this
    one when the whole file is loaded as a single module.
    """
    # NOTE(review): with a black figure background and frameless axes, default
    # black arrows may be hard to see - confirm the facecolor is intended.
    figure = plt.figure(1, facecolor='black')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)

    # (text, box position, arrow start position, box style) for each demo node.
    demo_nodes = (
        ('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode),
        ('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for node_spec in demo_nodes:
        plotNode(*node_spec)

    plt.show()
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Annotate the shared axes with one boxed node linked to its parent.

    Args:
        nodeTxt: Text shown inside the box.
        centerPt: (x, y) position of the text box, in axes-fraction coordinates.
        parentPt: (x, y) position of the annotated point (the parent end of
            the arrow), in axes-fraction coordinates.
        nodeType: Dict of box-style options passed to ``bbox``.
    """
    annotation_options = dict(
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    # The axes are the ones createPlot stored on itself as a function attribute.
    createPlot.ax1.annotate(nodeTxt, **annotation_options)
执行createPlot(),得到如下图形。
构造注解树
绘制树需要知道有多少个叶节点,以便确定x轴的长度,还需要知道树有多少层,以便确定y轴的高度。
def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    The tree has the form ``{feature_label: {feature_value: subtree_or_leaf}}``:
    the first key of ``myTree`` is the root feature, and each value of its
    branch dict is either another tree (a dict) or a leaf (anything else).

    Args:
        myTree: Non-empty nested-dict decision tree.

    Returns:
        Total count of leaves reachable from the root.
    """
    firstStr = list(myTree)[0]  # root node label
    numLeafs = 0
    for child in myTree[firstStr].values():
        # isinstance also recognises dict subclasses (e.g. OrderedDict) as
        # subtrees; comparing the type name would count them as leaves.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)  # recurse into the subtree
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the number of levels below the root of a nested-dict decision tree.

    The tree has the form ``{feature_label: {feature_value: subtree_or_leaf}}``.
    A branch ending in a leaf contributes depth 1; a branch leading to a
    subtree contributes 1 plus that subtree's depth.

    Args:
        myTree: Non-empty nested-dict decision tree.

    Returns:
        The largest depth over all branches (0 if the root has no branches).
    """
    firstStr = list(myTree)[0]  # root node label
    maxDepth = 0
    for child in myTree[firstStr].values():
        # isinstance also recognises dict subclasses (e.g. OrderedDict) as
        # subtrees; comparing the type name would treat them as leaves.
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
def retrieveTree(i):
    """Return one of the hard-coded sample decision trees.

    Args:
        i: Index into the list of sample trees (0 or 1).

    Returns:
        A freshly built nested-dict tree of the form
        ``{feature_label: {feature_value: subtree_or_leaf}}``.

    Raises:
        IndexError: If ``i`` is outside the range of available trees.
    """
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    # The parameter is lowercase ``i``; indexing with ``I`` raised NameError.
    return listOfTrees[i]
def plotMidText(cntrPt, parentPt, txtString):
    """Write a text label halfway between a child node and its parent.

    Args:
        cntrPt: (x, y) position of the child node.
        parentPt: (x, y) position of the parent node.
        txtString: Text to place at the midpoint of the two positions.
    """
    child_x, child_y = cntrPt[0], cntrPt[1]
    mid_x = child_x + (parentPt[0] - child_x) / 2.0
    mid_y = child_y + (parentPt[1] - child_y) / 2.0
    # Draws on the axes that createPlot stored on itself as a function attribute.
    createPlot.ax1.text(mid_x, mid_y, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw a nested-dict decision tree onto the shared axes.

    Layout state lives in function attributes that ``createPlot`` initialises
    before the first call: ``plotTree.totalW`` (leaf count of the whole tree),
    ``plotTree.totalD`` (depth of the whole tree), and the running cursor
    ``plotTree.xOff`` / ``plotTree.yOff``.

    Args:
        myTree: Nested-dict tree ``{feature_label: {feature_value: child}}``.
        parentPt: (x, y) position of this subtree's parent node.
        nodeTxt: Label written on the edge from the parent to this subtree's
            root (the feature value that leads here).
    """
    # The leaf count of this subtree decides how far right its root is placed.
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree)[0]  # text of this subtree's root node
    cntrPt = (
        plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.yOff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]

    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # A dict child is a subtree: recurse with this node as its parent.
            plotTree(child, cntrPt, str(key))
        else:
            # Anything else is a leaf: advance the x cursor by one leaf slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # Restore the y cursor so siblings of this subtree are drawn at its level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Render a nested-dict decision tree in a Matplotlib figure and show it.

    Prepares a frameless, tick-free axes (stored as ``createPlot.ax1``),
    initialises the layout attributes that ``plotTree`` reads, draws the tree
    starting from the top centre, and finally displays the figure.

    Args:
        inTree: Nested-dict tree ``{feature_label: {feature_value: child}}``.
    """
    # NOTE(review): with a black figure background and frameless axes, default
    # black arrows and edge labels may be hard to see - confirm the facecolor.
    figure = plt.figure(1, facecolor='black')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    # Overall tree size, used by plotTree to scale positions.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))

    # The x cursor starts half a leaf slot left of the axes; y starts at the top.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0

    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
执行结果
决策树
测试和存储分类器
测试算法:使用决策树执行分类
def classify(inputTree, featLabels, testVec):
    """Classify a feature vector by walking a nested-dict decision tree.

    Args:
        inputTree: Nested-dict tree ``{feature_label: {feature_value: child}}``
            where each child is either a subtree (dict) or a class label.
        featLabels: List of feature labels; the position of a label gives the
            index of that feature's value in ``testVec``.
        testVec: Feature values of the sample to classify.

    Returns:
        The class label stored at the leaf reached by ``testVec``.

    Raises:
        ValueError: If a feature label of the tree is missing from
            ``featLabels``, or if the sample's value for the tested feature
            matches no branch of the tree.
    """
    firstStr = list(inputTree)[0]  # feature tested at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    for key, child in secondDict.items():
        if featValue == key:
            if isinstance(child, dict):
                return classify(child, featLabels, testVec)
            return child
    # No branch matched. Falling through used to hit an unbound local variable
    # and fail with an opaque UnboundLocalError; report the real cause instead.
    raise ValueError(
        'no branch for value %r of feature %r' % (featValue, firstStr)
    )
用之前预存树信息的retrieveTree函数获得一棵树,然后进行分类。
# Fetch the first sample tree and classify the feature vector [1, 1].
myTree = retrieveTree(0)
# Feature names in the order their values appear in the test vector; classify()
# looks up each tested feature's position in this list.
labels = ['no surfacing', 'flippers']
classify(myTree, labels, [1, 1])
结果为yes,分类正确。
决策树的存储
构建决策树是很耗时的任务,为了节省时间,最好在每次执行分类时直接使用已构建好的决策树。为此,需要用pickle模块将决策树序列化后存储到磁盘上,并在需要时读取出来。
# 使用pickle存储决策树
import pickle
def storeTree(inputTree, filename):
    """Serialize a decision tree to ``filename`` with pickle.

    Args:
        inputTree: The tree (any picklable object) to store.
        filename: Path of the file to write; an existing file is overwritten.
    """
    # pickle writes bytes, so the file must be opened in binary mode; the
    # with-block closes the file even if dumping fails.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return a decision tree previously pickled to ``filename``.

    Args:
        filename: Path of a file containing a pickled object.

    Returns:
        The unpickled object.
    """
    # SECURITY: unpickling can execute arbitrary code; only load files that
    # come from a trusted source.
    # pickle reads bytes, so the file must be opened in binary mode; the
    # with-block closes the file once loading is done.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
小结
开始处理数据集时,首先需要计算熵,然后寻找最优方案划分数据集,直到数据集中的所有数据属于同一分类。
ID3算法可以用于划分标称型数据集。
示例表明决策树可能会产生过多的数据集划分,从而产生过度匹配(过拟合)的问题。可以通过裁剪决策树(剪枝),合并相邻的无法产生大量信息增益的叶节点,来消除过度匹配问题。
还有其他决策树的构造算法,例如,C4.5和CART。
网友评论