函数createTree()的伪代码:
找到最佳的带切分特征:
如果该节点不能再分,将节点存为叶节点
执行二元切分
在右子树调用createTree()方法
在左子树调用createTree()方法
如何实现数据切分,如何才能知道是否已经充分切分呢?取决于叶节点的建模方式。回归树假设叶节点是常数值,认为数据中复杂关系可以用树结构来概括。
如何计算连续型数值的混乱度?
先计算所有数据的均值,然后计算每条数据的值到均值的差值。为了对正负差值同等看待,一般使用绝对值或平方值来代替上述差值。这里需要的是平方误差的总值(总方差),可以通过均方差乘数据个数得到。
from numpy import *
import matplotlib.pyplot as plt
def loadDataSet(fileName):
dataMat=[]
fr=open(fileName)
for line in fr.readlines():
curLine=line.strip().split('\t')
fltLine=list(map(float,curLine)) #将每行映射成浮点数
dataMat.append(fltLine)
return dataMat
#参数 : 数据集、待切分特征、该特征某个值
#通过数组过滤方式将数据集切分得到两个子集返回
def binSplitDataSet(dataSet,feature,value):
#nonzero(dataSet[:,feature]>value)[0] 取得第feature特征列中大于value且不为零的元素行坐标
mat0=dataSet[nonzero(dataSet[:,feature]>value)[0],:]
mat1=dataSet[nonzero(dataSet[:,feature]<=value)[0],:]
return mat0,mat1
#数据集、建立叶节点函数、误差计算函数、ops是饱含树构建所需其他参数的元组
def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
feat,val=chooseBestSplit(dataSet,leafType,errType,ops)
if feat==None: return val
retTree={}
retTree['spInd']=feat
retTree['spVal']=val
lSet,rSet=binSplitDataSet(dataSet,feat,val)
retTree['left']=createTree(lSet,leafType,errType,ops)
retTree['right']=createTree(rSet,leafType,errType,ops)
return retTree
chooseBestSplit() 完成两件事:用最佳方式切分数据集和生成相应的节点。该函数目标是找到数据集切分的最佳位置。它遍历所有的特征及其可能的取值来找到使误差最小化的切分阈值。
伪代码:
对每个特征:
对每个特征值:
将数据集切分成两份
计算切分的误差
如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
regLeaf()生成叶节点
#负责生成叶节点。在回归树中,该模型其实就是目标变量的均值
#mean() 求均值
def regLeaf(dataSet):
return mean(dataSet[:,-1])
#var():表示方差 即各项-均值的平方求和之后再除以N 此处返回总方差,所以乘个数
def regErr(dataSet):
return var(dataSet[:,-1])*shape(dataSet)[0]
#找到数据的最佳二元切分方式,返回特征编号和切分特征值
#停止条件tolS对误差的数量级十分敏感
def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
#tolS,tolN用来控制函数停止的时机
tolS=ops[0] #容许的误差下降值
tolN=ops[1] #切分的最小样本数
#如果所有特征值相同则退出
if len(set(dataSet[:,-1].T.tolist()[0]))==1:
return None,leafType(dataSet)
m,n=shape(dataSet)
#默认最后一个特征为最佳切分特征,计算误差估计
S=errType(dataSet)
#最佳误差,最佳特征切分索引值,最佳特征值
bestS=inf;bestIndex=0;bestValue=0
#遍历特征列
for featIndex in range(n-1):
#遍历特征值
for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]):
mat0,mat1=binSplitDataSet(dataSet,featIndex,splitVal)
#如果数据少于tolN,则退出
if(shape(mat0)[0]<tolN)or (shape(mat1)[0]<tolN):continue
#计算误差估计
newS=errType(mat0)+ errType(mat1)
if newS<bestS:
bestIndex=featIndex
bestValue=splitVal
bestS=newS
#如果切分后误差减少不够大,则不应进行切分而直接创建叶节点
if(S-bestS)<tolS:
return None,leafType(dataSet)
mat0,mat1=binSplitDataSet(dataSet,bestIndex,bestValue)
#如果切分出的数据集很小则退出
if(shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):
return None,leafType(dataSet)
return bestIndex,bestValue
下面测试一下代码实际效果:
ex00.txt数据如下两列:
0.036098 0.155096
0.993349 1.077553
0.530897 0.893462
0.712386 0.564858
0.343554 -0.371700
0.098016 -0.332760
0.691115 0.834391
0.091358 0.099935
0.727098 1.000567
0.951949 0.945255
0.768596 0.760219
...
from regTrees import *
from numpy import *
myDat=loadDataSet('ex00.txt')
myMat=mat(myDat)
feat,val=chooseBestSplit(myMat,regLeaf,regErr,(1,4))
print(feat)
print(val)
输出结果
0
0.48813
{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}
再看一个多次切分的例子:
myDat=loadDataSet('ex0.txt')
myMat=mat(myDat)
print(createTree(myMat))
ex0.txt数据集如下
1.000000 0.409175 1.883180
1.000000 0.182603 0.063908
1.000000 0.663687 3.042257
1.000000 0.517395 2.305004
1.000000 0.013643 -0.067698
1.000000 0.469643 1.662809
1.000000 0.725426 3.275749
1.000000 0.394350 1.118077
1.000000 0.507760 2.095059
1.000000 0.237395 1.181912
1.000000 0.057534 0.221663
1.000000 0.369820 0.938453
1.000000 0.976819 4.149409
1.000000 0.616051 3.105444
1.000000 0.413700 1.896278
...
数据如图
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1,
'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right':
1.980035071428571}, 'right': {'spInd': 1, 'spVal': 0.197834, 'left':
1.0289583666666666, 'right': -0.023838155555555553}}
树剪枝
目前完成了回归树构建,但仍需要检查构建是否得当。通过树剪枝(pruning)技术。如果一个树节点过多,可能过拟合。
通过降低决策树复杂度避免过拟合的过程叫做剪枝。前面函数chooseBestSplit(prepruning)中提前终止条件实际是预剪枝操作。后剪枝(postpruning)需要使用测试机和训练集。
后剪枝:
#回归树剪枝函数
def isTree(obj):
return (type(obj).__name__=='dict')
def getMean(tree):
if isTree(tree['right']):tree['right']=getMean(tree['right'])
if isTree(tree['left']):tree['left']=getMean(tree['left'])
return (tree['left']+tree['right'])/2.0
def prune(tree, testData):
"""
函数说明:后剪枝
Parameters:
tree - 树
test - 测试集
Returns:
树的平均值
"""
#如果测试集为空,则对树进行塌陷处理
if shape(testData)[0] == 0: return getMean(tree)
#如果有左子树或者右子树,则切分数据集
if (isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
#处理左子树(剪枝)
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
#处理右子树(剪枝)
if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
#如果当前结点的左右结点为叶结点
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
#计算没有合并的误差
errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) + sum(power(rSet[:,-1] - tree['right'],2))
#计算合并的均值
treeMean = (tree['left'] + tree['right']) / 2.0
#计算合并的误差
errorMerge = sum(power(testData[:,-1] - treeMean, 2))
#如果合并的误差小于没有合并的误差,则合并
if errorMerge < errorNoMerge:
# print("merging")
return treeMean
else: return tree
else: return tree
网友评论