散点图能够直观地看出预测值与真实值之间的关系,同时绘制完整散点图非常重要。一般散点图包含下列数据。
- 显示的数值,比如回归预测的 $R^2$ 和 RMSE,还有样本大小等。
- 显示比例线,一般是 1:1 线($y = x$)和预测值与真实值之间的拟合线,以及对应的拟合方程。
- 标题,轴的含义,单位等重要的量。
注意:
python和库的版本,我的版本是
python 3.6
seaborn 0.10.0
pandas 1.0.1
numpy 1.19.1
matplotlib 2.2.5
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
sns.set(style="white", font_scale=1.5, color_codes=True)

# Draw one predicted-vs-observed scatter plot per input CSV.
# NOTE: `file_name` is not defined in this snippet; it must be defined
# beforehand as an iterable of CSV paths. Each CSV is read for its 'y_pred'
# and 'y_test' columns.
for file in file_name:
    dataset = pd.read_csv(file)

    # Keep variable names and columns aligned with the plot axes:
    # x axis = 'y_pred', y axis = 'y_test' (same as the scatter below).
    y_pred = dataset['y_pred'].to_numpy()
    y_test = dataset['y_test'].to_numpy()
    pred_min, pred_max = dataset['y_pred'].min(), dataset['y_pred'].max()
    true_min, true_max = dataset['y_test'].min(), dataset['y_test'].max()

    # R^2 = 1 - SS_res / SS_tot, with SS_tot taken around the mean of the
    # observed values. Guard the degenerate case of constant observations.
    ss_res = np.sum((y_test - y_pred) ** 2)
    ss_tot = np.sum((y_test - np.mean(y_test)) ** 2)
    R2 = 1 - ss_res / ss_tot if ss_tot > 0 else float('nan')
    RMSE = np.sqrt(np.mean((y_test - y_pred) ** 2))

    # Least-squares line y_test = p[0] * y_pred + p[1]. The fit must use the
    # same x/y roles as the axes, otherwise the drawn line and its equation
    # describe the inverse relationship.
    p = np.polyfit(y_pred, y_test, 1)
    formatSpec = 'y = %.4fx+ %.4f' % (p[0], p[1])
    formatXy = 'y = x'
    x1 = np.linspace(pred_min, pred_max)
    y1 = np.polyval(p, x1)
    str_R2 = '$R^2$ = %.4f\nRMSE = %.2f \nsamples = %d' % (R2, RMSE, dataset.shape[0])

    f, ax = plt.subplots(figsize=(14, 10))
    ax.set_title('The demo of scatter map')
    # Tick label font size; see
    # https://stackoverflow.com/questions/34706845/change-xticklabels-fontsize-of-seaborn-heatmap
    # https://www.cnblogs.com/lemonbit/p/7419851.html
    ax.tick_params(labelsize=16)

    sns.scatterplot(x='y_pred', y='y_test', data=dataset, alpha=0.8, color='b', ax=ax)

    # 1:1 reference line. Two end points are enough for a straight line and,
    # unlike a fixed 0.1 step, work for any data range.
    line_lo = min(pred_min, true_min)
    line_hi = max(pred_max, true_max)
    ax.plot([line_lo, line_hi], [line_lo, line_hi],
            color='r', linewidth=3, alpha=0.6, label=formatXy)
    # Fitted regression line.
    ax.plot(x1, y1, color='k', linewidth=3, alpha=0.6, label=formatSpec)
    ax.legend(loc='lower right', fontsize=20)

    # Place the statistics in axes-fraction coordinates (top-left corner) so
    # the position does not depend on the magnitude of the data.
    ax.text(0.03, 0.97, str_R2, transform=ax.transAxes,
            ha='left', va='top', fontsize=20)

    # Axis labels must be set before saving, or the PNG will not contain them.
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')

    # NOTE: every iteration writes the same 'temp.png'; derive the name from
    # `file` instead if one image per input CSV is wanted.
    out_name = 'temp'
    f.savefig('%s.png' % out_name, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so that looping over many files does not accumulate
    # open figures.
    plt.close(f)
temp.png
网友评论