3.2.2 kNN在多元分类的应用
后续测试了复杂一点的5类点的例子(用下面的代码, 去掉random_state, 可以产生各色各样的点分类图), 下面选的这个分类有点麻烦, 因为有两个类粘在一起了. 如果是分得比较开的类, 后续的测试结果会好很多.
# 样品量500, 5类, 2特征, 还是设定种子8
data2 = make_blobs(n_samples=500,n_features=2, centers=5, random_state=8)
X2,y2 = data2
# 散点图可视化
plt.scatter(X2[:,0],X2[:,1],c=y2, cmap=plt.cm.spring,edgecolor='k')
plt.show()
分类数量为5
类似地, 进行拟合和模型显示. 可以看出中间两个类太接近了.
# kNN算法建立模型
clf = KNeighborsClassifier()
clf.fit(X2,y2)
#下面的代码用于画图
x_min, x_max = X2[:, 0].min() - 1, X2[:, 0].max() + 1
y_min, y_max = X2[:, 1].min() - 1, X2[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, .02),
np.arange(y_min, y_max, .02))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Pastel1)
plt.scatter(X2[:, 0], X2[:, 1], c=y2, cmap=plt.cm.spring, edgecolor='k')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("Classifier:KNN")
plt.show()
对5类数据的kNN分类
print('模型正确率:{:.2f}'.format(clf.score(X2,y2)))
# 模型正确率:0.96
这里的正确率(score)实际是正确样本数/总样本数
, 500点的话, 就是错了20个点. 如果样本分类分离较好, 则正确率会上升, 例如:
TODO: 怎么知道此时的随机种子是多少??
整体而言, kNN在分类时的效果还是不错的. 不过这里没有做训练集和测试集分离的测试.
如果想分割成训练集和测试集, 可以:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X2,y2,test_size=0.2)
clf2 = KNeighborsClassifier()
clf2.fit(X_train, y_train)
print('模型正确率:{:.2f}'.format(clf2.score(X_test, y_test)))
# 模型正确率:0.92
### 在不设随机值时, 正确率在0.91~0.97间波动
网友评论