美文网首页
KNN(K近邻)二分类实践代码

KNN(K近邻)二分类实践代码

作者: 万州客 | 来源:发表于2022-04-25 15:12 被阅读0次

有好书就跟着走起来~

《深入浅出Python机器学习》

一、代码

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.datasets import make_regression
from sklearn.datasets import load_wine
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import KNeighborsRegressor
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
# NOTE(review): everything between the two ''' markers below is a bare string
# literal, not executable code, so none of these demos runs; only the wine
# example after the closing ''' executes. To run a demo, move it out of the
# string (ideally into its own function) instead of editing it in place.
#
# The disabled demos are:
#   1. Two-cluster make_blobs data (200 samples, random_state=8): fit a
#      KNeighborsClassifier with default settings, predict every point of a
#      0.02-step grid spanning the data plus a margin of 1 to colour the
#      decision regions, then mark and classify the new point (6.75, 4.82).
#   2. Five-cluster make_blobs data (500 samples): the same decision-region
#      plot. The printed score is clf.score(X2, y2) on the very data the model
#      was fitted on, so it reflects training fit, not generalisation.
#   3. make_regression data (1 feature, noise=50) fitted with
#      KNeighborsRegressor(n_neighbors=2). Predictions for 200 evenly spaced
#      inputs in [-3, 3] are drawn over the samples, and reg.score(X, y) is
#      printed, again on the training data.
#      NOTE(review): the prediction curve is drawn with
#      plt.scatter(..., linewidth=3); plt.plot was probably intended so the
#      prediction renders as a continuous line.
'''

data = make_blobs(n_samples=200, centers=2, random_state=8)
X, y = data
clf = KNeighborsClassifier()
clf.fit(X,y)
x_min, x_max = X[:, 0].min() -1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() -1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, .02),
                     np.arange(y_min, y_max, .02))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Pastel1)
plt.scatter(X[:, 0], X[:,1], c=y, cmap=plt.cm.spring, edgecolors='k')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("Classifier:KNN")
plt.scatter(6.75, 4.82, marker='*', c='red', s=200)
plt.show()

#对新数据点分类进行判断
print('新数据点的分类是:', clf.predict([[6.75, 4.82]]))

data2 = make_blobs(n_samples=500, centers=5, random_state=8)
X2, y2 = data2
#plt.scatter(X2[:, 0], X2[:,1], c=y2, cmap=plt.cm.spring, edgecolors='k')
#plt.show()
clf = KNeighborsClassifier()
clf.fit(X2,y2)
x_min, x_max = X2[:, 0].min() -1, X2[:, 0].max() + 1
y_min, y_max = X2[:, 1].min() -1, X2[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, .02),
                     np.arange(y_min, y_max, .02))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
print('模型正确率:{:.2f}'.format(clf.score(X2, y2)))
Z = Z.reshape(xx.shape)
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Pastel1)
plt.scatter(X2[:, 0], X2[:,1], c=y2, cmap=plt.cm.spring, edgecolors='k')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("Classifier:KNN")
#plt.scatter(6.75, 4.82, marker='*', c='red', s=200)
plt.show()

X, y = make_regression(n_features=1, n_informative=1, noise=50, random_state=8)
reg = KNeighborsRegressor(n_neighbors=2)
reg.fit(X, y)
z = np.linspace(-3, 3, 200).reshape(-1, 1)
plt.scatter(X, y, c='orange', edgecolors='k')
plt.scatter(z, reg.predict(z), c='k', linewidth=3)
plt.title('KNN Regressor')
print('模型评分:{:.2f}'.format(reg.score(X, y)))
plt.show()
'''
# --- Wine dataset: 1-nearest-neighbour classification ----------------------
# Load the bundled wine data, then report which keys it exposes and the shape
# of its feature matrix.
wine_dataset = load_wine()
print('红酒数据集中的键:{}\n'.format(wine_dataset.keys()))
feature_matrix = wine_dataset['data']
print('红酒数据集概况:{}\n'.format(feature_matrix.shape))
# For the dataset's own long-form description, print wine_dataset['DESCR'].

# Split features/labels into train and test parts. random_state=0 pins the
# shuffle so the printed shapes and score are reproducible between runs.
split_parts = train_test_split(
    feature_matrix, wine_dataset['target'], random_state=0
)
X_train, X_test, y_train, y_test = split_parts

# Report the shape of each of the four split arrays, in a fixed order.
for shape_template, split_array in (
    ('X_train shape:{}', X_train),
    ('X_test shape:{}', X_test),
    ('y_train shape:{}', y_train),
    ('y_test shape:{}', y_test),
):
    print(shape_template.format(split_array.shape))

# n_neighbors=1: each prediction is driven by the single nearest training
# sample. fit() returns the estimator itself, so the call can be chained.
knn = KNeighborsClassifier(n_neighbors=1).fit(X_train, y_train)
test_score = knn.score(X_test, y_test)
print('测试数据得分:{:.2f}'.format(test_score))

# One unseen sample with 13 values. NOTE(review): assumed to follow the same
# column order as wine_dataset['data']; confirm against feature_names.
X_new = np.array([[13.2, 2.77, 2.51, 18.5, 96.6, 1.04, 2.55, 0.57,
                   1.47, 6.2, 1.05, 3.33, 820]])
prediction = knn.predict(X_new)
# prediction holds class indices; index target_names with it to get labels.
predicted_names = wine_dataset['target_names'][prediction]
print('预测新红酒的分类: {}'.format(predicted_names))

二、效果

2022-04-25 15_09_58-MessageCenterUI.png
2022-04-25 17_53_49-MessageCenterUI.png 2022-04-25 15_49_57-.png

相关文章

网友评论

      本文标题:KNN(K近邻)二分类实践代码

      本文链接:https://www.haomeiwen.com/subject/cpwfyrtx.html