1. 源码实现（保存为 example.py）
import numpy as np
from sklearn import tree
# Toy training set: the rows enumerate every combination of three binary
# features (000 .. 111); y gives one integer class label per row.
x = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
              [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]])
y = [0, 1, 1, 1, 2, 3, 3, 4]
# Create the decision tree classifier.
# random_state is pinned because scikit-learn randomly permutes the features
# at every split. On this data, splitting on X[1] or X[2] below the root gives
# exactly the same impurity decrease, so without a fixed seed the exported
# tree can change from run to run. The seed value itself is arbitrary.
clf = tree.DecisionTreeClassifier(random_state=0)
# Fit the tree to the eight training samples.
clf.fit(x, y)
# Classify one sample. [1, 0, 0] appears in the training set with label 2.
print(clf.predict([[1, 0, 0]]))
# Export the fitted tree as Graphviz DOT source; out_file=None makes
# export_graphviz return the text instead of writing a file.
data = tree.export_graphviz(clf, out_file=None)
print(data)
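2. 运行及结果
（注：在根节点之下，按 X[1] 或按 X[2] 划分所得的 Gini 下降完全相同，而 DecisionTreeClassifier 在每次划分时都会随机打乱特征顺序；若未固定 random_state，节点 1 和节点 6 所选的特征可能随运行而不同，下面的输出只是其中一种可能结果。）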
$ python3 example.py
[2]
digraph Tree {
node [shape=box] ;
0 [label="X[0] <= 0.5\ngini = 0.75\nsamples = 8\nvalue = [1, 3, 1, 2, 1]"] ;
1 [label="X[2] <= 0.5\ngini = 0.375\nsamples = 4\nvalue = [1, 3, 0, 0, 0]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[1] <= 0.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1, 0, 0, 0]"] ;
1 -> 2 ;
3 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0, 0, 0, 0]"] ;
2 -> 3 ;
4 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1, 0, 0, 0]"] ;
2 -> 4 ;
5 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2, 0, 0, 0]"] ;
1 -> 5 ;
6 [label="X[1] <= 0.5\ngini = 0.625\nsamples = 4\nvalue = [0, 0, 1, 2, 1]"] ;
0 -> 6 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
7 [label="X[2] <= 0.5\ngini = 0.5\nsamples = 2\nvalue = [0, 0, 1, 1, 0]"] ;
6 -> 7 ;
8 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 1, 0, 0]"] ;
7 -> 8 ;
9 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 0, 1, 0]"] ;
7 -> 9 ;
10 [label="X[2] <= 0.5\ngini = 0.5\nsamples = 2\nvalue = [0, 0, 0, 1, 1]"] ;
6 -> 10 ;
11 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 0, 1, 0]"] ;
10 -> 11 ;
12 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 0, 0, 1]"] ;
10 -> 12 ;
}
网友评论