接着上一篇文章【机器学习】决策树(构造篇)
一、使用Matplotlib注释绘制树形图
1.1 Matplotlib注解annotations
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode = dict(boxstyle='sawtooth', fc="0.8",)
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle="<-")
#绘制带箭头的注解 实际的绘图功能
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, \
xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction', \
va="center", ha='center', bbox=nodeType, arrowprops=arrow_args)
#主绘图函数
def createPlot():
#创建一个绘图区的全局变量
fig = plt.figure(1, facecolor='white')
fig.clf() #清空绘图区
createPlot.ax1 = plt.subplot(111, frameon=False)
plotNode('decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('leafNode', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
运行:
import treePlotter
treePlotter.createPlot()
1.2 构造注释树
- 确定x轴长度:叶节点长度
- 确定y轴高度:树的层数
获得叶节点的数目和数的层数:
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
#输出预先存储的树结构
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0:'no', 1:{'flippers':\
{0:'no', 1:'yes'}}}},
{'no surfacing': {0:'no', 1:{'flippers':\
{0:{'head':{0:'no', 1:'yes'}}, 1:'no'}}}}
]
return listOfTrees[i]
运行:
treePlotter.retrieveTree(0)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
treePlotter.retrieveTree(1)
{'no surfacing': {0: 'no',
1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
myTree = treePlotter.retrieveTree(0)
treePlotter.getNumLeafs(myTree)
3
treePlotter.getTreeDepth(myTree)
2
plotTree函数:
#在父节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) #计算宽
depth = getTreeDepth(myTree) #计算高
firstStr = list(myTree.keys())[0]
#计算已经绘制的节点的位置,以及放置下一个节点的恰当位置
#通过计算树所包含的所有叶子节点数,划分图形的宽度,从而计算得到当前节点的中心位置
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
# 标记子节点的属性值
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
# 按比例减少全局变量plotTree.yOff,并标注此处需要绘制子节点
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #依次递减y坐标
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #在绘制完所有子节点以后,增加全局变量Y的偏移值
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree)) #存储树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
#追踪已经绘制的节点的位置
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
运行:
treePlotter.createPlot(myTree)
myTree['no surfacing'][3]='maybe'
myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
treePlotter.createPlot(myTree)
二、测试和存储分类器
2.1 测试算法:使用决策树执行分类
生成特征标签列表帮助程序确定特征在数据集中的位置
#使用决策树的分类函数
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr) #寻找当前列表中第一个匹配firstStr变量的元素
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
运行:
>>> reload(trees)
<module 'trees' from 'D:\\Data\\File\\ML\\trees\\trees.py'>
>>> myDat, labels = trees.createDataSet()
>>> labels
['no surfacing', 'flippers']
>>> myTree=trees.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree, labels, [1,0])
'no'
>>> trees.classify(myTree, labels, [1,1])
'yes'
2.2 使用算法:在硬盘中存储决策树分类器
使用pickle模块存储决策树
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
def grapTree(filename):
import pickle
fr = open(filename, 'rb')
return pickle.load(fr)
运行:
>>> reload(trees)
<module 'trees' from 'D:\\Data\\File\\ML\\trees\\trees.py'>
>>> trees.storeTree(myTree, 'classifierStorage.txt')
>>> trees.grapTree('classifierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
这也是决策树的优点之一,k-近邻法无法持续化分类器
三、项目实战:使用决策树预测患者需要佩戴的隐形眼镜类型
步骤:
- 收集数据:提供的文本文件
- 准备数据:解析tab键分隔的数据行
- 分析数据:快熟检查数据,确保正确地解析数据内容,使用createPlot()函数绘制最终的树形图
- 训练数据:使用createTree()函数
- 测试数据:编写测试函数1验证决策树可以正确分类给定的数据实例
- 使用算法:存储树的数据结构,以便于下次使用时无需重新造树
>>> import trees
>>> fr=open('lenses.txt')
>>> lenses = [inst.strip().split('\t') for inst in fr.readlines()]
>>> lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
>>> lensesTree = trees.createTree(lenses, lensesLabels)
>>> lensesTree
{'tearRate': {'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}, 'reduced': 'no lenses'}}
>>> treePlotter.createPlot(lensesTree)
由ID3算法得到的决策树
问题与改进
这里使用的是ID3算法的决策树,很好的算法但不完美:
- 匹配选项太多, 过度匹配(overfitting)
- 解决过度匹配的问题,可以裁剪决策树,去掉没有必要的叶子节点
- 如果叶子节点只能增加少许信息,则可以删除该节点,并将它并入其他叶子节点中。也就是,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配的问题。
- ID3算法无法直接处理数值型数据,虽然可以通过量化的方法将数值型转化为标称型数值,但如果存在太多的特征划分,ID3仍然会面临其他问题
- 可以采取C4.5,或CART算法
总结
- k-近邻法和决策树都是将数据实例明确划分到某个分类
- 像朴素贝叶斯算法,将不能完全确定的数据实例划分带某个分类,或者只能给出数据实例属于给定分类的概率
学习自:《机器学习实战》
网友评论