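So far we have been covering the theory behind decision trees. Today, Xiaoyu will show you how to visualize one, so we can see concretely what a decision tree looks like.
First, import the packages and run the notebook magic command: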
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing
%matplotlib inline
To skip the data preprocessing step, Xiaoyu uses a dataset that ships with sklearn: the California housing dataset, whose target is the median house value.
housing = fetch_california_housing()
housing
The returned housing object has the following structure:
{'data': array([[ 8.3252 , 41. , 6.98412698, ..., 2.55555556,
37.88 , -122.23 ],
[ 8.3014 , 21. , 6.23813708, ..., 2.10984183,
37.86 , -122.22 ],
[ 7.2574 , 52. , 8.28813559, ..., 2.80225989,
37.85 , -122.24 ],
...,
[ 1.7 , 17. , 5.20554273, ..., 2.3256351 ,
39.43 , -121.22 ],
[ 1.8672 , 18. , 5.32951289, ..., 2.12320917,
39.43 , -121.32 ],
[ 2.3886 , 16. , 5.25471698, ..., 2.61698113,
39.37 , -121.24 ]]),
'target': array([4.526, 3.585, 3.521, ..., 0.923, 0.847, 0.894]),
'frame': None,
'target_names': ['MedHouseVal'],
'feature_names': ['MedInc',
'HouseAge',
'AveRooms',
'AveBedrms',
'Population',
'AveOccup',
'Latitude',
'Longitude'],
'DESCR': '.. _california_housing_dataset:\n\nCalifornia Housing dataset\n--------------------------\n\n**Data Set Characteristics:**\n\n :Number of Instances: 20640\n\n :Number of Attributes: 8 numeric, predictive attributes and the target\n\n :Attribute Information:\n - MedInc median income in block group\n - HouseAge median house age in block group\n - AveRooms average number of rooms per household\n - AveBedrms average number of bedrooms per household\n - Population block group population\n - AveOccup average number of household members\n - Latitude block group latitude\n - Longitude block group longitude\n\n :Missing Attribute Values: None\n\nThis dataset was obtained from the StatLib repository.\nhttps://www.dcc.fc.up.pt/~ltorgo/Regression/cal_housing.html\n\nThe target variable is the median house value for California districts,\nexpressed in hundreds of thousands of dollars ($100,000).\n\nThis dataset was derived from the 1990 U.S. census, using one row per census\nblock group. A block group is the smallest geographical unit for which the U.S.\nCensus Bureau publishes sample data (a block group typically has a population\nof 600 to 3,000 people).\n\nAn household is a group of people residing within a home. Since the average\nnumber of rooms and bedrooms in this dataset are provided per household, these\ncolumns may take surpinsingly large values for block groups with few households\nand many empty houses, such as vacation resorts.\n\nIt can be downloaded/loaded using the\n:func:`sklearn.datasets.fetch_california_housing` function.\n\n.. topic:: References\n\n - Pace, R. Kelley and Ronald Barry, Sparse Spatial Autoregressions,\n Statistics and Probability Letters, 33 (1997) 291-297\n'}
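Here housing.data holds the features and housing.target holds the labels. The object also carries the label name in housing.target_names, the feature names in housing.feature_names, and other metadata.
For more details about the dataset, print housing.DESCR.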
Next, to keep the visualization readable, we use only the first 3 features and train a decision tree with a maximum depth of 3:
>> from sklearn.tree import DecisionTreeRegressor
>> dtr = DecisionTreeRegressor(max_depth=3)
>> dtr.fit(housing.data[:,:3], housing.target)
DecisionTreeRegressor(max_depth=3)
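To see what the column slice and the depth limit do, here is a minimal sketch on synthetic data. The random arrays are only a stand-in for housing.data and housing.target (an assumption, so the snippet runs without downloading anything), and toy_tree is an illustrative name that is not part of the walkthrough. A binary tree of depth 3 can have at most 2^3 = 8 leaves and 15 nodes, which is why the final picture stays small.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic stand-in for housing.data / housing.target (assumption, not the real data)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8))                          # 200 rows, 8 features like the real dataset
y = 2.0 * X[:, 0] + rng.normal(scale=0.1, size=200)    # made-up target

# [:, :3] keeps every row but only the first three columns
X3 = X[:, :3]

# random_state is fixed only to make this sketch reproducible
toy_tree = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X3, y)

depth = toy_tree.get_depth()          # never exceeds max_depth
n_leaves = toy_tree.get_n_leaves()    # at most 2 ** 3 = 8
n_nodes = toy_tree.tree_.node_count   # at most 2 ** 4 - 1 = 15
print(X3.shape, depth, n_leaves, n_nodes)
```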
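The variable dtr from the fit call is our trained decision tree. Before displaying it, make sure Graphviz is installed on your system, since it is needed to render the image. Then we can use the export_graphviz function provided by sklearn to export the tree: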
>> from sklearn.tree import export_graphviz
>> dot_data = export_graphviz(
dtr,
out_file=None,
feature_names=housing.feature_names[:3],
filled=True,
impurity=False,
rounded=True
)
>> type(dot_data)
str
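Graphviz download page: http://www.graphviz.org/download/. Download the version that matches your operating system.
The dot_data we obtained is a string because DOT is a plain-text format: it describes the elements of a graph and the relationships between them, so that Graphviz can generate a graphical representation of them.
Finally, use the following Python libraries to render the DOT data as a PNG and display it: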
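To make the idea of "a text description of nodes and the relationships between them" concrete, here is a minimal sketch that builds a tiny DOT string by hand. The helper to_dot and the node labels are hypothetical and only for illustration; the text produced by export_graphviz carries more attributes per node, but it follows the same node-and-edge pattern.

```python
def to_dot(nodes, edges):
    """Build a small DOT digraph from {node_id: label} and (parent, child) pairs."""
    lines = ["digraph Tree {"]
    for node_id, label in nodes.items():
        # One statement per node, carrying its display label
        lines.append(f'    {node_id} [label="{label}"];')
    for parent, child in edges:
        # One statement per edge, describing the parent/child relationship
        lines.append(f"    {parent} -> {child};")
    lines.append("}")
    return "\n".join(lines)

# Hypothetical three-node tree: one split and two leaves
nodes = {0: "MedInc <= t", 1: "left leaf", 2: "right leaf"}
edges = [(0, 1), (0, 2)]

toy_dot = to_dot(nodes, edges)
print(toy_dot)
```

A string like toy_dot could be passed to Graphviz in the same way as dot_data, which is why a plain str is all that the export step needs to return.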
import pydotplus
from IPython.display import Image
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
Rendered result:
