1、决策树
决策树学习通常包括3个步骤:
- 特征选择。
- 决策树生成。
- 决策树剪枝。
决策树的学习目标是:根据给定的训练数据集构造一个决策树模型,使得它能够对样本进行正确的分类。
决策树最优化的策略是:损失函数最小化。决策树的损失函数通常是正则化的极大似然函数。
决策树生成算法
决策树的生成算法:
- 构建根结点:将所有训练数据放在根结点。
- 选择一个最优特征,根据这个特征将训练数据分割成子集,使得各个子集有一个在当前条件下最好的分类。
- 若这些子集已能够被基本正确分类,则将该子集构成叶结点。
- 若某个子集不能够被基本正确分类,则对该子集选择新的最优的特征,继续对该子集进行分割,构建相应的结点。
- 如此递归下去,直至所有训练数据子集都被基本正确分类,或者没有合适的特征为止。
上述生成的决策树可能对训练数据有很好的分类能力,但是对于未知的测试数据却未必有很好要的分类能力,即可能发生过拟合的现象。
- 解决的方法是:对生成的决策树进行剪枝,从而使得决策树具有更好的泛化能力。
- 剪枝是去掉过于细分的叶结点,使得该叶结点中的子集回退到父结点或更高层次的结点并让其成为叶结点。
决策树的生成对应着模型的局部最优,决策树的剪枝则考虑全局最优。
如果学习任意大小的决策树,则可以将决策树算法视作非参数算法。但是实践中经常有大小限制(如限制树的高度、限制树的叶结点数量),从而使得决策树成为有参数模型。
2、特征选择
特征选择的关键是:选取对训练数据有较强分类能力的特征。若一个特征的分类结果与随机分类的结果没有什么差别,则称这个特征是没有分类能力的。
通常特征选择的指标是:信息增益或者信息增益比。这两个指标刻画了特征的分类能力。
以信息增益作为划分训练集的特征选取方案,存在偏向于选取值较多的特征的问题。可以通过定义信息增益比来解决该问题。
信息增益比本质上是对信息增益乘以一个加权系数:
- 当特征 A 的取值集合较大时,加权系数较小,表示抑制该特征。
- 当特征 A 的取值集合较小时,加权系数较大,表示鼓励该特征。
3、生成算法
决策树有两种常用的生成算法:
- ID3 生成算法。
- C4.5 生成算法。
ID3 生成算法和 C4.5 生成算法只有树的生成算法,生成的树容易产生过拟合:对训练集拟合得很好,但是预测测试集效果较差。
ID3 生成算法
ID3 生成算法核心是在决策树的每个结点上应用信息增益准则选择特征,递归地构建决策树:
- 从根结点开始,计算结点所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征划分出子结点。
- 再对子结点递归地调用以上方法,构建决策树。
- 直到所有特征的信息增益均很小或者没有特征可以选择为止,最后得到一个决策树 。
设置特征信息增益阈值,如果不设置特征信息增益的下限,则可能会使得每个叶子都只有一个样本点,从而划分得太细。
C4.5 生成算法
C4.5 生成算法与 ID3 算法相似,但是 C4.5 算法在生成过程中用信息增益比来选择特征。
4、剪枝算法
- 决策树生成算法生成的树往往对于训练数据拟合很准确,但是对于未知的测试数据分类却没有那么准确。即出现过拟合现象。过拟合产生得原因是决策树太复杂。解决的办法是:对决策树剪枝,即对生成的决策树进行简化。
- 决策树的剪枝是从已生成的树上裁掉一些子树或者叶结点,并将根结点或者其父结点作为新的叶结点。剪枝的依据是:极小化决策树的整体损失函数或者代价函数。
- 决策树生成算法是学习局部的模型,决策树剪枝是学习整体的模型。即:生成算法仅考虑局部最优,而剪枝算法考虑全局最优
原理

算法

4、CART 树
回归树

分类树

其他讨论

CART 剪枝

5、面试题
防止过拟合:
- 预剪枝,及早停止树的生长。限制条件可能包括:最大深度、限制叶结点的最大数目、限制一个结点中数据点的最小数目
- 后剪枝,先构造树,但随后删除或折叠信息量较少的结点。
决策树缺点
树不能在训练数据的范围之外生成“新的”响应。

绿色代表决策树的预测,数据集是 2000 年以前的,特征为年份,可以看出决策树预测 2000 年以后的是一条直线,明显是错误的。
优点、缺点和参数
控制模型复杂度的参数是预剪枝参数,选择一种预剪枝策略(max_depth、max_leaf_nodes、min_samples_leaf)足以防止过拟合。
决策树的优点:
- 得到的模型很容易可视化
- 算法完全不受数据缩放的影响
决策树的缺点:
- 容易过拟合,泛化性能较差
6、代码
决策树分类
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
cancer=load_breast_cancer()
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,stratify=cancer.target,random_state=42)
tree=DecisionTreeClassifier(random_state=0)
tree.fit(X_train,y_train)
print('Accuracy on training set :{:.3f}'.format(tree.score(X_train,y_train)))
print('Accuracy on test set:{:.3f}'.format(tree.score(X_test,y_test)))
# 限制树的深度
tree=DecisionTreeClassifier(max_depth=4,random_state=0)
tree.fit(X_train,y_train)
print('Accuracy on training set :{:.3f}'.format(tree.score(X_train,y_train)))
print('Accuracy on test set:{:.3f}'.format(tree.score(X_test,y_test)))
print('Feature importances:\n{}'.format(tree.feature_importances_))
def plot_feature_importances_cancer(model):
n_features=cancer.data.shape[1]
plt.barh(range(n_features),model.feature_importances_,align='center')
plt.yticks(np.arange(n_features),cancer.feature_names)
plt.xlabel('Feature importance')
plt.ylabel('Feature')
plt.show()
plot_feature_importances_cancer(tree)
决策树回归
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression
ram_prices=pd.read_csv('data/ram_price.csv')
plt.semilogy(ram_prices.date,ram_prices.price)
plt.xlabel('Year')
plt.ylabel('Price in $/Mbyte')
plt.show()
plt.clf()
data_train=ram_prices[ram_prices.date<2000]
data_test=ram_prices[ram_prices.date>=2000]
X_train=data_train.date[:,np.newaxis]
y_train=np.log(data_train.price)
tree=DecisionTreeRegressor(max_depth=4).fit(X_train,y_train)
linear_reg=LinearRegression().fit(X_train,y_train)
X_all=ram_prices.date[:,np.newaxis]
pred_tree=tree.predict(X_all)
pred_lr=linear_reg.predict(X_all)
price_tree=np.exp(pred_tree)
price_lr=np.exp(pred_lr)
plt.semilogy(data_train.date,data_train.price,label='Training data')
plt.semilogy(data_test.date,data_test.price,label='Test data')
plt.semilogy(ram_prices.date,price_tree,label='Tree prediction')
plt.semilogy(ram_prices.date,price_lr,label='Linear prediction')
plt.legend()
plt.show()
网友评论