pcolormesh() 用于绘制伪彩色网格图。在分类问题中，常先用训练好的模型预测网格上每个点的类别，再用 pcolormesh() 把预测结果画成背景色块，从而直观地显示模型的决策区域。
例一：以鸢尾花数据集（决策树模型）为例
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets # 机器学习库
# Load the iris dataset bundled with scikit-learn (no external data file is read).
scikit_iris = datasets.load_iris()
iris_feature = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度', u'类别'
# Stack the four measurements and the class label into one table, using the
# display names directly as column labels.
data = pd.DataFrame(data=np.c_[scikit_iris['data'], scikit_iris['target']],
                    columns=list(iris_feature))
# np.c_ promotes the integer target to float; convert it back to integer codes.
data['类别'] = pd.Categorical(data['类别']).codes
# Keep only two features so the decision regions can be drawn in 2-D.
x_train = data[['花萼长度', '花瓣长度']]
y_train = data['类别']
# min_samples_leaf=3: every leaf must hold at least 3 training samples.
model = DecisionTreeClassifier(criterion='entropy', min_samples_leaf=3)
model.fit(x_train, y_train)
N, M = 500, 500  # number of sample points along the horizontal / vertical axis
# Bounding box of the two training features.
x1_min, x2_min = x_train.min(axis=0)
x1_max, x2_max = x_train.max(axis=0)
t1 = np.linspace(x1_min, x1_max, N)
t2 = np.linspace(x2_min, x2_max, M)
x1, x2 = np.meshgrid(t1, t2)  # grid coordinate matrices, each of shape (M, N)
# Flatten the grid into an (M*N, 2) table of test points. It is wrapped in a
# DataFrame with the training column names because the model was fitted on a
# DataFrame; predicting on a bare ndarray makes scikit-learn warn that X has
# no valid feature names.
x_show = pd.DataFrame(np.column_stack((x1.ravel(), x2.ravel())),
                      columns=x_train.columns)
y_predict = model.predict(x_show)
print(y_predict.shape)
print(x1.shape)
# Use a font that can render the Chinese labels; with that font the Unicode
# minus glyph is disabled so negative tick labels still display.
mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False
# Light colours for the decision-region background, dark ones for the samples,
# one colour per class code (0, 1, 2).
cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
plt.xlim(x1_min, x1_max)
plt.ylim(x2_min, x2_max)
# Reshape the flat predictions back onto the sampling grid.
y_grid = y_predict.reshape(x1.shape)
print(y_grid.shape)
# X, Y and C all share the same shape here. Request shading='auto' (cells
# centred on the grid points) explicitly: with the old 'flat' default,
# same-shaped inputs are deprecated (Matplotlib 3.3-3.4) or rejected.
plt.pcolormesh(x1, x2, y_grid, shading='auto', cmap=cm_light)
# Overlay the training samples, coloured by their true class.
plt.scatter(x_train['花萼长度'], x_train['花瓣长度'], c=y_train, cmap=cm_dark, marker='o', edgecolors='k')
plt.xlabel('花萼长度')
plt.ylabel('花瓣长度')
plt.title('鸢尾花分类')
plt.grid(True, ls=':')
plt.show()
绘制结果:

例二：以 KNN（k 近邻）分类器为例
# 导入画图工具
import matplotlib.pyplot as plt
# 导入数组工具
import numpy as np
# 导入数据集生成器
from sklearn.datasets import make_blobs
# 导入KNN 分类器
from sklearn.neighbors import KNeighborsClassifier
# 导入数据集拆分工具
from sklearn.model_selection import train_test_split
# Generate a reproducible synthetic dataset: 500 two-dimensional samples
# drawn from 5 blobs (classes).
data = make_blobs(
    n_samples=500,
    n_features=2,
    centers=5,
    cluster_std=1.0,
    random_state=8,
)
X, Y = data  # feature matrix and class labels
# Sanity output: shape of the feature matrix, then the first label.
print('=============X', X.shape, sep='\n')
print('=============y', Y[0], sep='\n')
# Visualise the raw samples, coloured by their label.
plt.scatter(X[:, 0], X[:, 1], s=80, c=Y, cmap=plt.cm.spring, edgecolors='k')
plt.show()
# Fit a k-nearest-neighbours classifier with default settings on all samples.
clf = KNeighborsClassifier()
clf.fit(X, Y)
# Decision regions: sample the plane on a fine grid extending one unit beyond
# the data in every direction.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
# The "y" range here is the second feature column of X (the vertical axis),
# not the model output.
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
# meshgrid turns the two coordinate vectors into coordinate matrices.
xx, yy = np.meshgrid(np.arange(x_min, x_max, .02), np.arange(y_min, y_max, .02))
# ravel() flattens each coordinate matrix; np.c_ joins the two flat arrays
# column-wise into an (n_points, 2) array of grid points to classify.
z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
print('=============z')
print(z.shape)
z = z.reshape(xx.shape)  # back to the grid's 2-D shape for plotting
print('=============z reshape')
print(z.shape)
# X, Y and C share the same shape, so request shading='auto' explicitly; the
# old 'flat' default deprecates (3.3-3.4) or rejects same-shaped inputs.
plt.pcolormesh(xx, yy, z, shading='auto', cmap=plt.cm.Pastel1)
# Overlay the samples as a scatter plot: s = marker size, c = colour values,
# cmap = colour scheme for the samples, edgecolors = marker outline colour.
plt.scatter(X[:, 0], X[:, 1], s=80, c=Y, cmap=plt.cm.spring, edgecolors='k')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("Classifier:KNN")
# Mark the point to be classified with a star.
plt.scatter(0, 5, marker='*', c='red', s=200)
# Predict the star's class and label it next to the star.
res = clf.predict([[0, 5]])
plt.text(0.2, 4.6, 'Classification flag: ' + str(res))
# NOTE: clf was fitted on X, Y, so this is training-set accuracy, not
# held-out accuracy. The label is placed in axes coordinates (bottom-right
# corner) instead of a hard-coded data position, whose visibility depends on
# this particular dataset's value range.
plt.text(0.98, 0.02, 'Model accuracy: {:.2f}'.format(clf.score(X, Y)),
         transform=plt.gca().transAxes, ha='right', va='bottom')
plt.show()
绘制结果:

网友评论