高斯过程回归
一、高斯分布
高斯分布(正态分布)是一种非常常见的连续概率分布。其在统计学上十分重要,经常用在自然和社会科学来代表一个不明的随机变量。
高斯分布:若随机变量X服从一个位置参数为、尺度参数为的高斯分布,记作:
其概率密度函数为
高斯分布的数学期望为,决定了分布的位置;其方差为,决定了分布的幅度。
二、多维高斯分布
2.1 一般形式
N为随机向量,
2.2 二维高斯分布
其中是与之间的相关系数,且。
image-20200618162846149.png三、高斯过程
在概率论和统计学中,高斯过程(Gaussian process)是观测值出现在一个连续域(例如时间和空间)的随机过程。在高斯过程中,连续输入空间中每一点都是与一个正态分布的随机变量相关联。此外,这些随机变量的每个有限集合都有一个多元正态分布,换句话说他们的任意有限线性组合是一个正态分布。高斯过程的分布是所有哪些(无限多个)随机变量的联合分布,正因如此,它是连续域(例如时间和空间)上的函数分布。
B站上有一个视频介绍的比较形象,假设存在一个高斯过程表示人的一生,在出生的那一刻已经被固定了。如下图所示,在人生的某一时刻他的成就为其符合一个正态分布,平均成就值为,方差为。其中直观理解为表示的是该时刻的一个平均值,表示的是其他时刻的表现对该时刻的影响。
image-20200618164751802.png四、高斯过程回归
已知训练数据集,其中为输入向量,为输出向量。现有新的输入,预测对应的目标数据(真实输出为)。
首先我们做出先验假设(和的联合概率分布):
其中
这是个先验,就像最小二乘拟合中假设数据服从线性分布类似。
根据多维高斯分布的条件分布性质可得
有高斯分布的概率分布可得,的概率最大。
是核函数(或协方差函数),用来捕捉不同时刻随机变量之间的关系,有如下核函数可供选择。
五、代码
import numpy as np
import matplotlib.pyplot as plt
class GPR:
def __init__(self,optimize=True):
self.is_fit = False
self.train_X,self.train_y = None,None
self.params = {"l":0.5,"sigma_f":0.2}
self.optimize = optimize
def fit(self,X,y):
self.train_X = np.asarray(X)
self.train_y = np.asarray(y)
self.is_fit = True
def predict(self,X):
if not self.is_fit:
print("GPR Model not fit yet")
return
X = np.asarray(X)
kff = self.kernel(self.train_X,self.train_X)
kyy = self.kernel(X,X)
kfy = self.kernel(self.train_X,X)
kff_inv = np.linalg.inv(kff+1e-8*np.eye(len(self.train_X)))
mu = kfy.T.dot(kff_inv).dot(self.train_y)
conv = kyy - kfy.T.dot(kff_inv).dot(kfy)
return mu,conv
def kernel(self,x1,x2):
dist_matrix = np.sum(x1**2, 1).reshape(-1, 1) + np.sum(x2**2, 1) - 2 * np.dot(x1, x2.T)
return self.params["sigma_f"] ** 2 * np.exp(-0.5 / self.params["l"] ** 2 * dist_matrix)
def y(x,noise_sigma=0.0):
x = np.asarray(x)
y = np.cos(x)+np.random.normal(0,noise_sigma,size=x.shape)
return y.tolist()
train_X = np.array([3, 1, 4, 5, 9]).reshape(-1, 1)
train_y = y(train_X, noise_sigma=1e-4)
test_X = np.arange(0, 10, 0.1).reshape(-1, 1)
gpr = GPR()
gpr.fit(train_X, train_y)
mu, cov = gpr.predict(test_X)
test_y = mu.ravel()
uncertainty = 1.96 * np.sqrt(np.diag(cov)) #1.96表示95%的置信度
plt.figure()
plt.title("l=%.2f sigma_f=%.2f" % (gpr.params["l"], gpr.params["sigma_f"]))
plt.fill_between(test_X.ravel(), test_y + uncertainty, test_y - uncertainty, alpha=0.1)
plt.plot(test_X, test_y, label="predict")
plt.scatter(train_X, train_y, label="train", c="red", marker="x")
plt.legend()
plt.show()
运行结果:
image.png
网友评论