可以采用两种方法,一种是直接调用像seaborn这样的库,另一种是在matplotlib的基础上根据置信带的原理自己完善。
一、调用seaborn
可以使用sns.regplot()这个函数,其中参数ci为置信水平,默认为95%,我们可以设置为99%或者其他值。
调用方式为:
sns.regplot(x=x, y=y, ci=95)
具体示例如下所示:
代码
# Import standard packages
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
# Set time seed
np.random.seed(8)
# Generate data
mean, cov = [4, 6], [(1.5, .7), (.7, 1)]
x, y = np.random.multivariate_normal(mean, cov, 80).T
# Plot figure
ax = sns.regplot(x=x, y=y, ci=95)
plt.show()
lr.png
二、自己实现
在理解置信区间(confidence interval)的基础上,尝试实现
例如下面的代码中,函数fit_plot_line()已写好,直接调用即可
调用方式为
fit_plot_line(x=x, y=y, ci=95)
具体示例如下所示:
作者:知乎用户
链接:https://www.zhihu.com/question/425566655/answer/1524420211
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
np.random.seed(8)
def fit_plot_line(x=[], y=[], ci=95):
alpha = 1 - ci / 100
n = len(x)
Sxx = np.sum(x**2) - np.sum(x)**2 / n
Sxy = np.sum(x * y) - np.sum(x)*np.sum(y) / n
mean_x = np.mean(x)
mean_y = np.mean(y)
# Linefit
b = Sxy / Sxx
a = mean_y - b * mean_x
# Residuals
def fit(xx):
return a + b * xx
residuals = y - fit(x)
var_res = np.sum(residuals**2) / (n - 2)
sd_res = np.sqrt(var_res)
# Confidence intervals
se_b = sd_res / np.sqrt(Sxx)
se_a = sd_res * np.sqrt(np.sum(x**2)/(n * Sxx))
df = n-2 # degrees of freedom
tval = stats.t.isf(alpha/2., df) # appropriate t value
ci_a = a + tval * se_a * np.array([-1, 1])
ci_b = b + tval * se_b * np.array([-1, 1])
# create series of new test x-values to predict for
npts = 100
px = np.linspace(np.min(x), np.max(x), num=npts)
def se_fit(x):
return sd_res * np.sqrt(1. / n + (x - mean_x)**2 / Sxx)
# Plot the data
plt.figure()
plt.plot(px, fit(px), 'k', label='Regression line')
plt.plot(x, y, 'k.')
x.sort()
limit = (1 - alpha) * 100
plt.plot(x, fit(x) + tval * se_fit(x), 'r--', lw=2,
label='Confidence limit ({0:.1f}%)'.format(limit))
plt.plot(x, fit(x) - tval * se_fit(x), 'r--', lw=2)
plt.xlabel('X values')
plt.ylabel('Y values')
plt.title('Linear regression and confidence limits')
plt.legend(loc='best')
plt.show()
# generate data
mean, cov = [4, 6], [(1.5, .7), (.7, 1)]
x, y = np.random.multivariate_normal(mean, cov, 80).T
# fit line and plot figure
fit_plot_line(x=x, y=y, ci=95)
lrl.png
网友评论