传送门:分类树
1、原理
分类与回归树(classification and regression tree,CART)模型由Breiman等人在1984年提出。CART同样由特征选择、树的生成及剪枝组成。
既然是决策树,那么必然会存在以下两个核心问题:如何选择划分点?如何决定叶节点的输出值?
一个回归树对应着输入空间(即特征空间)的一个划分以及在划分单元上的输出值。
- 分类树中,我们采用信息论中的方法,通过计算选择最佳划分点。
- 而在回归树中,采用的是启发式的方法。
假如我们有n个特征,每个特征有s_i(i∈(1,n))个取值,那我们遍历所有特征,尝试该特征所有取值,对空间进行划分,直到取到特征j的取值s,使得损失函数最小,这样就得到了一个划分点。描述该过程的公式如下:
2、算法描述
一个简单实例:训练数据见下表,目标是得到一棵最小二乘回归树。![](https://img.haomeiwen.com/i2434365/7e958405e7e74498.png)
2.1 选择最优切分变量j与最优切分点s
在本数据集中,只有一个变量,因此最优切分变量自然是x。
接下来我们考虑9个切分点[1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5].
你可能会问,为什么会带小数点呢?类比于篮球比赛的博彩,倘若两队比分是96:95,而盘口是“让1分 A队胜B队”,那A队让1分之后,到底是A队赢还是B队赢了?所以我们经常可以看到“让0.5分 A队胜B队”这样的盘口。在这个实例中,也是这个道理。
![](https://img.haomeiwen.com/i2434365/467b573f54a8815d.jpg)
![](https://img.haomeiwen.com/i2434365/69a245efcd01f5b7.png)
![](https://img.haomeiwen.com/i2434365/424588aaf022ee22.png)
2.2 对两个子区域继续调用上述步骤
![](https://img.haomeiwen.com/i2434365/685cbbcadbe6d089.png)
2.3 生成回归树
假设在生成3个区域之后停止划分,那么最终生成的回归树形式如下:![](https://img.haomeiwen.com/i2434365/4d3333225887be88.png)
3、代码
# coding=utf-8
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn import linear_model
# 画图支持中文显示
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']
# Data set
x = np.array(list(range(1, 11))).reshape(-1, 1)
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05]).ravel() # 多维数组转换成一维数组[5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05]
print('x:\n' + str(x) + '\ny:\n' + str(y))
# Fit regression model
model1 = DecisionTreeRegressor(max_depth=1)
model2 = DecisionTreeRegressor(max_depth=3)
model3 = linear_model.LinearRegression() # 线性回归模型
model1.fit(x, y)
model2.fit(x, y)
model3.fit(x, y)
# Predict
X_test = np.arange(0.0, 10.0, 0.01)[:, np.newaxis] # 1000行1列
print(X_test.size)
# test = np.zeros(shape=(2, 3, 4), dtype=int)
# print(test.size)
# print(test.ndim)
y_1 = model1.predict(X_test)
y_2 = model2.predict(X_test)
y_3 = model3.predict(X_test)
# Plot the results
plt.figure()
plt.scatter(x, y, s=20, color="darkorange", edgecolors='blue', label="src_data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=1", linewidth=2, linestyle='--')
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=3", linewidth=2, linestyle='--')
plt.plot(X_test, y_3, color='red', label='liner regression', linewidth=2, linestyle='--')
plt.xlim(0, 12) # 设置坐标轴
plt.ylim(4, 11)
plt.xlabel("x_data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()
![](https://img.haomeiwen.com/i2434365/ddecf3ff4ef48a85.png)
拓展阅读:https://blog.csdn.net/hy592070616/article/details/81628956
网友评论