美文网首页
一个更真实的Decision Tree

一个更真实的Decision Tree

作者: 醉看红尘这场梦 | 来源:发表于2020-03-11 18:06 被阅读0次

我们来看一个更真实的机器学习的例子。当然,它仍旧是基于decision tree算法的,相比你可能更多听到的“神经网络”、“支持向量机”等算法,decision tree最大的优点,就是我们几乎不需要任何数学基础,就可以了解这种算法的分类过程。

一个更真实的traing data set - Iris

首先,来看我们使用的training data:Iris

dt1

它是一套标准数据集合,通过萼片(Sepal)和花瓣(Petal)各自的宽度和长度,识别了三种不同的鸢尾花(Iris):Setosa / Versicolor / Virginica。其中,每一类花,都有50个不重复的样本记录(Examples)。

结合上面这张表,以及我们已经学过的training data中的术语,就可以发现,这份数据集合中包含了以下内容:

  • 4个Features,也就是识别花的四个不同属性:Sepal length / Sepal width / Petal length / Petal width;
  • 3个Label,也就是三种不同的鸢尾花:Setosa / Versicolor / Virginica。

接下来,要做的第一个事情,就是把这些记录先倒入到Scikit。

加载Iris测试数据集

Scikit数据集导入页面可以看到,Scikit已经提供了直接导入Iris的API方便我们学习,无需加载任何第三方文件。

dt1

新建一个叫做iris.py的文件,然后添加下面的代码:

from sklearn.datasets import load_iris

iris = load_iris()

print(iris.feature_names)
print(iris.target_names)

首先,我们从sklearn.datasets中,引入了load_iris方法,并直接调用它加载了全部的Iris测试数据集合;

其次,我们读取了feature_names属性,它是一个数组,包含了所有Features的名字;

第三,我们读取了target_names属性,它包含的是所有Labels的名字;

最后,我们读取了iris的中第一个Example;

保存退出后,执行python3s iris.py,就可以在控制台看到下面的结果了:

['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
['setosa' 'versicolor' 'virginica']

很简单对不对?接下来,我们继续读取一些样本数据:

print(iris.data[0])
print(iris.target[0])

其中,data表示样本中的Features,target表示和每一组Feature对应的Label。保存后重新执行一下,就能看到结果了:

[ 5.1  3.5  1.4  0.2]
0

把这个结果和之前feature_namestarget_names的值对应起来,你就可以理解它的含义了。没错,第一个样本数据表示一个setosa。如果你要确认Scikit已经加载了所有的Iris数据,可以这样:

for i in range(len(iris.target)):
    print("Example %d: features: %s, label: %s" % (i, iris.data[i], iris.target[I]))

区分学习和测试数据

接下来,我们就要用这组数据集训练Classifier了,为了方便稍后检查学习结果,我们得从每一类花的Examples中抽掉一个记录,用于测试。通过之前对测试数据的了解我们知道,iris.datairis.target中的第0,50和150条记录,分别对应着一种新花类型的开始,于是,我们可以用下面的代码,把这三条记录从datatarget中取出来,稍后用于检验:

import numpy as np

test_index = [0, 50 , 100]
training_data = np.delete(iris.data, test_index, axis = 0)
training_target = np.delete(iris.target, test_index)

这里,简单介绍下numpy中的delete方法:

首先,对于iris.target来说,它是一个像这样的一维数组:

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]

为了删掉这个数组中的第0/50/100个元素,我们直接传递给它对应位置的数组就好了。

其次,iris.data是一个二维数组:

[[ 5.1  3.5  1.4  0.2]
 [ 4.9  3.   1.4  0.2]
 [ 4.7  3.2  1.3  0.2]
 [ 4.6  3.1  1.5  0.2]
 ...
]

为了在这个二维数组中删掉第0/50/100行,我们就要给delete传递第三个参数axis,对于一个二维数组来说,0表示索引位置所在的行,1表示索引位置所在的列,因此,我们传递0。然后,delete会返回删除后的值,我们分别保存起来稍后用于训练。

最后,我们还要专门把第0/50/100位置的Feature和Label也单独保存出来,稍后用于检验学习结果:

testing_data = iris.data[test_index]
testing_target = iris.target[test_index]

对于机器学习来说,在开始训练之前搞清楚哪些数据用于训练,哪些数据用于验证训练效果,是一件非常重要的事情,搞不清楚它们,我们将无法了解学习的效果。

训练并检验学习结果

接下来的事情,就很简单了,过程和上一节一样。首先,创建决策树并填充features和labels进行训练:

from sklearn import tree

clf = tree.DecisionTreeClassifier()
clf.fit(training_data, training_target)

其次,用下面的代码检查学习结果:

print(testing_target)
print(clf.predict(testing_data))

按照之前的推断,预测的结果,应该和testing_target中的值,是完全一样的。重新执行一下,就能看到下面的结果了:

[0 1 2]
[0 1 2]

可视化decision tree的学习过程

在这一节最后,我们通过可视化的方式来看下机器是如何根据决策树进行判断的,在Scikit的官网上,可以找到输出PDF和生成png的例子。但有趣的是,我们要把这两部分的代码合并起来,生成的PDF才更加易懂。在iris.py里,添加下面的代码:

from IPython.display import Image
import pydotplus

dot_data = tree.export_graphviz(clf, out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True, rounded=True,
    special_characters=True)

graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("iris.pdf")

如果你还没装过IPython,执行conda install IPython安装就好。但通过conda安装pydotplus会报错。这时,直接执行~/Miniconda3/bin/pip install pydotplus来安装就好了。

安装完成之后,重新执行iris.py,就可以在当前目录看到生成的iris.pdf文件了。打开它之后,看上去是这样的:

dt1

如何理解它呢?我们用testing_data中的结果来举例:

print(testing_data[0])
print(testing_target[0])
# [ 5.1  3.5  1.4  0.2]
# 0

从之前的结果中我们知道,0对应的是Setosa,结合生成的PDF,从上向下看:

首先比较petal width,它是features中的最后一个值,0.2 <= 0.8成立,于是走到左边节点,由于这已经是一个叶子节点,可以从图中看到class = setosa

其次,我们读取testing_data中的第2个记录,我们知道,它是versicolor:

print(testing_data[1])
print(testing_target[1])
# [ 7\.   3.2  4.7  1.4]
# 1

这次,仍旧从树根开始比较petal width:1.4 > 0.8,走到决策树的右节点。继续比较:1.4 < 1.75,走到左节点,这次,比较petal length,这是features中的第三个属性,4.7 < 4.95,继续走左节点,重新比较petal width:1.4 < 1.65,最终,走左节点后,来到一个新的叶子节点,而这个节点的值,就是versicolor,和我们的预期是完全一样的。

相关文章

网友评论

      本文标题:一个更真实的Decision Tree

      本文链接:https://www.haomeiwen.com/subject/uwjhjhtx.html