概述
- 优点:可以对复杂和非线性的数据建模
- 缺点:结果不易理解
- 适用数据类型:数值型和标称型数据
CART是树构建算法,使用二元切分来处理连续型变量。基于ID3算法的决策树使用香农熵来度量集合的无序程度,如果用其他方法来代替香农熵,就可以用树构建算法来完成回归。
树的构建
在树的构建过程中,需要解决多种类型数据的存储问题。这里将使用字典来存储树的数据结构,该字典包含以下4个元素。
- 待切分的特征
- 待切分的特征值
- 右子树。当不再需要切分时,也可以是单个值
- 左子树。与右子树类似
后面将构建两种树:
- 回归树,每个叶节点包含单个值;
- 模型树,每个叶节点包含一个线性方程。
import numpy as np
# 加载数据集
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
# 将每行数据都映射成浮点数
fltLine = map(float,curLine)
dataMat.append(fltLine)
return dataMat
def splitDataSet(dataSet, feature, value):
mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
return mat0, mat1
def createTree(dataSet, leafType = regLeaf, errType = regErr, ops = (1,4)):
feat, val = chooseBestSplit(dataSet, leafType, regErr, ops)
if feat == None:
return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
left, right = splitDataSet(dataSet, feat, val)
retTree['left'] = createTree(left, leafType, errType, ops)
retTree['right'] = createTree(right, leafType, errType, ops)
return retTree
splitDataSet
函数在给定特征和特征值的情况下,通过数组过滤方式将数据集切分得到两个子集返回。
树构建函数createTree
有4个参数,数据集;leafType
给出建立叶节点的函数;errType
代表误差计算函数;ops
元组带有两个值,一个是容许的误差下降值,一个是切分的最少样本数。
树的构建最重要的是找到最佳的划分点。用决策树进行分类,会计算数据集的熵,用来度量数据的混乱度。那么连续型数值混乱度如何计算?
首先是计算所有数据的均值,然后计算每条数据到均值的差值的平方。这有点儿类似于方差计算。唯一不同的是,方差是平方误差的均值(均方差),而这里是平方误差的总值(总方差)。总方差可以通过均方差乘以数据集样本个数得到。
# 生成叶节点
def regLeaf(dataSet):
return np.mean(dataSet[:,-1])
# 误差估计
def regErr(dataSet):
return np.var(dataSet[:,-1]) * dataSet.shape[0]
def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops = (1,4)):
tolS = ops[0]
tolN = ops[1]
# 如果所有值相等,则退出
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
m,n = dataSet.shape
S = errType(dataSet)
bestS = np.inf
bestFeat = 0
bestVal = 0
for featIndex in range(n-1): # 遍历所有特征
for splitVal in set(dataSet[:, featIndex].T.tolist()[0]): # 遍历所有特征值
mat0, mat1 = splitDataSet(dataSet, featIndex, splitVal)
if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestFeat = featIndex
bestVal = splitVal
bestS = newS
# 如果误差减少不大则退出
if (S - bestS) < tolS:
return None, leafType(dataSet)
# 如果切分出的数据集很小则退出
if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
return None, leafType(dataSet)
return bestFeat, bestVal
regLeaf
函数负责生成叶节点,在回归树中就是目标变量的均值。
chooseBestSplit
是回归树构建的核心函数,目的是找到最佳二元划分方式。其参数ops
设定了tolN
和tolS
两个值,用于控制函数的停止时机。
执行createTree
查看效果。
dataSet = loadDataSet('ex00.txt')
myMat = np.mat(dataSet)
createTree(myMat)
结果如下:
{'spInd': 0,
'spVal': 0.48813,
'left': 1.0180967672413792,
'right': -0.04465028571428572}
数据分布如下
再看一个多次切分的例子。数据分布如下。
构建回归树。
dataSet = loadDataSet('ex0.txt')
myMat = np.mat(dataSet)
createTree(myMat)
结果如下。
{'spInd': 1,
'spVal': 0.39435,
'left': {'spInd': 1,
'spVal': 0.582002,
'left': {'spInd': 1,
'spVal': 0.797583,
'left': 3.9871632,
'right': 2.9836209534883724},
'right': 1.980035071428571},
'right': {'spInd': 1,
'spVal': 0.197834,
'left': 1.0289583666666666,
'right': -0.023838155555555553}}
网友评论