# Import the required modules and the dataset loader
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
# Load the wine dataset. The returned object exposes `.data` and `.target`,
# which the rest of the script reads.
wine = load_wine()
# 数据信息 (Dataset overview)
# Shape of the feature matrix: (n_samples, n_features).
wine.data.shape
# (178, 13)

# Class label of every sample.
wine.target
# Recorded output (kept as a comment rather than a string literal, so that it
# is never evaluated or displayed in place of the real value):
# array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
#        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
#        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
#        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
#        2, 2])
import pandas as pd

# Show features and labels side by side in one table.
# The columns are named explicitly: concatenating two default-labelled frames
# would give both the first feature and the label column the name `0`, which
# makes a lookup such as `df[0]` ambiguous.
pd.DataFrame(wine.data, columns=wine.feature_names).assign(target=wine.target)

# 划分数据集 (Splitting the dataset)
# Split into a training set and a test set, holding out 30% for testing.
# The split is seeded so that the partition, and every score computed from
# it, is identical on each run. Unseeded, train_test_split reshuffles on every
# call, so results change between runs even though the tree itself is seeded.
# NOTE: the accuracies recorded in this file came from an unseeded run, so
# expect slightly different numbers.
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=30
)
Xtrain.shape
# (124, 13)
Xtest.shape
# (54, 13)
# 决策树创建 (Building the decision tree)
# Build and train the decision tree, using entropy as the split criterion.
# random_state=30 acts like a fixed random seed, so the same training data
# always produces the same tree.
clf = tree.DecisionTreeClassifier(
    criterion='entropy',
    random_state=30,
).fit(Xtrain, Ytrain)

# Evaluate on the held-out test set.
score = clf.score(Xtest, Ytest)
score
# 0.9629629629629629
# 可视化决策树 (Visualizing the decision tree)
# Draw the decision tree.
# Display labels for the 13 features; presumably in the same order as the
# columns of wine.data -- verify against the dataset description.
# NOTE(review): 'od/280od315稀释葡萄酒' looks like a typo for
# 'od280/od315稀释葡萄酒'. Left unchanged because the label appears in the
# rendered tree and in the recorded outputs.
feature_name = [
    '酒精',
    '苹果酸',
    '灰',
    '灰的碱性',
    '镁',
    '总酚',
    '类黄酮',
    '非黄烷类酚类',
    '花青素',
    '颜色强度',
    '色调',
    'od/280od315稀释葡萄酒',
    '脯氨酸',
]
import graphviz

dot_data = tree.export_graphviz(
    clf,
    feature_names=feature_name,
    filled=True,   # fill the nodes with different colours
    rounded=True,  # draw node boxes with rounded corners
)
graph = graphviz.Source(dot_data)
# Leaving `graph` as the last expression lets a notebook render it inline.
graph

# 分析重要信息 (Analyzing feature importance)
# Feature importances of the fitted tree.
clf.feature_importances_
# Pair every importance value with its display label.
list(zip(feature_name, clf.feature_importances_))
# Recorded output from one run. Kept as a comment: as a live list literal it
# would be the last expression of the cell and be shown instead of the values
# actually computed above.
# [('酒精', 0.0),
#  ('苹果酸', 0.0),
#  ('灰', 0.0),
#  ('灰的碱性', 0.0),
#  ('镁', 0.0),
#  ('总酚', 0.0),
#  ('类黄酮', 0.3990915395984506),
#  ('非黄烷类酚类', 0.0),
#  ('花青素', 0.0),
#  ('颜色强度', 0.4679552572083532),
#  ('色调', 0.034102883482108784),
#  ('od/280od315稀释葡萄酒', 0.0),
#  ('脯氨酸', 0.0988503197110873)]
# 决策树的裁剪 (Pruning the decision tree)
# Prune the decision tree to guard against overfitting.
# Pre-pruning is the commonly used approach:
#   max_depth          limits the depth of the tree
#   min_samples_leaf   minimum number of samples each node must keep after a split
#   min_samples_split  minimum number of samples a node needs before it may split
clf = tree.DecisionTreeClassifier(
    criterion='entropy',
    random_state=30,
    max_depth=2,
    min_samples_leaf=10,
    min_samples_split=10,
).fit(Xtrain, Ytrain)

# Accuracy of the pruned tree on the test set.
score = clf.score(Xtest, Ytest)
score
# 0.7962962962962963
# 可视化结果 (Visualizing the result)
# Export the current `clf` to DOT text and wrap it in a graphviz Source.
dot_data = tree.export_graphviz(
    clf,
    feature_names=feature_name,
    filled=True,   # fill the nodes with different colours
    rounded=True,  # draw node boxes with rounded corners
)
graph = graphviz.Source(dot_data)
# Leaving `graph` as the last expression lets a notebook render it inline.
graph

# 其它一些接口 (Other useful interfaces)
# Some other important attributes and interfaces.
# apply() returns, for every test sample, the index of the leaf node that the
# sample ends up in.
clf.apply(Xtest)
# Recorded output from one run. Kept as a comment: as a live string literal it
# would be displayed instead of the real result of apply().
# array([4, 1, 4, 1, 1, 4, 4, 3, 4, 4, 1, 4, 4, 3, 3, 4, 4, 1, 3, 3, 1, 1,
#        4, 3, 1, 3, 1, 3, 1, 1, 3, 4, 4, 3, 4, 3, 4, 1, 3, 4, 4, 4, 1, 4,
#        4, 4, 4, 4, 3, 3, 4, 4, 4, 4], dtype=int64)
# 网友评论 (Reader comments -- page footer carried over from the original article)