
Polynomial regression is a special case of multiple linear regression that models the relationship between the response variable and polynomial feature terms. We will walk through it in code; the model here has a single feature, unlike the earlier 盖浇饭 (rice bowl) example above, which had two features.
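The three plotting functions below (equation_1, equation_2, equation_3) all assume the same surrounding setup: the imports, a fitted degree-1 regressor, the train/test splits, and a dense grid xx used to draw the curves. The following is a minimal sketch of that setup; the numbers in X_train, y_train, X_test and y_test are placeholder values for illustration only, not necessarily the data used earlier in this post.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

# Placeholder single-feature data; substitute the real train/test sets here
X_train = [[6], [8], [10], [14], [18]]
y_train = [[7], [9], [13], [17.5], [18]]
X_test = [[6], [8], [11], [16]]
y_test = [[8], [12], [15], [18]]

# Degree-1 baseline model shared by all three functions
regressor = LinearRegression()
regressor.fit(X_train, y_train)

# Dense grid of x values used to draw the fitted curves
xx = np.linspace(0, 26, 100)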
With two parameters plus an intercept, that is, three values to solve for, we cannot simply divide both sides of the equation by a matrix. Instead of matrix division, we multiply by the matrix inverse. Note that only square (and non-singular) matrices are invertible.
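To make that concrete, here is a small NumPy sketch (with a toy design matrix, not data from this post) of solving ordinary least squares by multiplying with the inverse of the square matrix XᵀX rather than "dividing" by X. In practice np.linalg.lstsq or LinearRegression is preferred for numerical stability.

import numpy as np

# Toy design matrix: a column of ones (the intercept) plus two feature columns
X = np.array([[1.0, 2.0, 3.0],
              [1.0, 4.0, 1.0],
              [1.0, 5.0, 6.0],
              [1.0, 7.0, 2.0]])
y = np.array([5.0, 6.0, 12.0, 10.0])

# X itself is not square, but X.T @ X is; if it is non-singular it can be
# inverted: w = (X^T X)^{-1} X^T y
w = np.linalg.inv(X.T @ X) @ X.T @ y
print(w)

# The numerically safer equivalent
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(w_lstsq)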
def equation_1():
    # Predictions of the fitted degree-1 (simple linear) model over the xx grid
    yy = regressor.predict(xx.reshape(xx.shape[0], 1))
    plt.plot(xx, yy, label='degree=1')
    plt.axis([0, 28, 0, 28])
    plt.legend()
    plt.show()

Simple linear regression gives a straight line, and the model is too simple to capture the curvature in the data. This is what we call underfitting. We can add polynomial terms to increase the model's expressive power.
def equation_2():
    # Degree-1 (simple linear) predictions over the xx grid, for comparison
    yy = regressor.predict(xx.reshape(xx.shape[0], 1))
    # Degree-2 polynomial feature generator
    quadratic_featurizer = PolynomialFeatures(degree=2)
    X_train_quadratic = quadratic_featurizer.fit_transform(X_train)
    X_test_quadratic = quadratic_featurizer.transform(X_test)
    regressor_quadratic = LinearRegression()
    regressor_quadratic.fit(X_train_quadratic, y_train)
    plt.scatter(X_train, y_train)
    xx_quadratic = quadratic_featurizer.transform(xx.reshape(xx.shape[0], 1))
    yy_quadratic = regressor_quadratic.predict(xx_quadratic)
    plt.plot(xx, yy, label='degree=1')
    plt.plot(xx, yy_quadratic, label='degree=2')
    plt.axis([0, 28, 0, 28])
    plt.legend()
    # Expected output: Quadratic regression r-squared 0.8675443656345054
    print('Quadratic regression r-squared', regressor_quadratic.score(X_test_quadratic, y_test))
    plt.show()

Clearly the degree-2 curve fits these points better than simple linear regression: the quadratic model raises the test-set R² from about 0.81 to about 0.87.
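For reference, the score method used above returns the coefficient of determination R² = 1 − SS_res / SS_tot on the given data. The helper below is a minimal sketch of the same computation (the function name is made up); it should agree with model.score(X, y) and with sklearn.metrics.r2_score.

import numpy as np

def r_squared_by_hand(model, X, y):
    # R^2 = 1 - (residual sum of squares) / (total sum of squares)
    y_true = np.asarray(y, dtype=float).ravel()
    y_pred = np.asarray(model.predict(X), dtype=float).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# e.g. inside equation_2, r_squared_by_hand(regressor_quadratic, X_test_quadratic, y_test)
# matches regressor_quadratic.score(X_test_quadratic, y_test)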
def equation_3():
    # Degree-1 (simple linear) predictions over the xx grid, for comparison
    yy = regressor.predict(xx.reshape(xx.shape[0], 1))
    # Degree-4 polynomial feature generator
    biquadrate_featurizer = PolynomialFeatures(degree=4)
    X_train_biquadrate = biquadrate_featurizer.fit_transform(X_train)
    X_test_biquadrate = biquadrate_featurizer.transform(X_test)
    regressor_biquadrate = LinearRegression()
    regressor_biquadrate.fit(X_train_biquadrate, y_train)
    plt.scatter(X_train, y_train)
    xx_biquadrate = biquadrate_featurizer.transform(xx.reshape(xx.shape[0], 1))
    yy_biquadrate = regressor_biquadrate.predict(xx_biquadrate)
    plt.plot(xx, yy, label='degree=1')
    plt.plot(xx, yy_biquadrate, label='degree=4')
    plt.axis([0, 28, 0, 28])
    plt.legend()
    # Expected output: Biquadrate regression r-squared 0.8095880795782215
    print('Biquadrate regression r-squared', regressor_biquadrate.score(X_test_biquadrate, y_test))
    plt.show()

As model capacity grows, the degree-4 polynomial's regression curve passes through all the training points. But although the curve performs extremely well on the training set, a glance at its shape shows something is wrong, and on the test set the degree-4 model's R² drops back to roughly 0.81.
This is what we call overfitting. By contrast, the earlier simple linear regression was too simple to capture these points, which is underfitting.
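One simple way to see this trade-off, beyond eyeballing the curves, is to sweep the polynomial degree and compare training and test R²: the training score keeps climbing with degree while the test score peaks and then drops. This sketch is not part of the original code; it reuses the X_train/X_test/y_train/y_test variables assumed above.

def compare_degrees(max_degree=6):
    # Training R^2 rising while test R^2 falls is the signature of overfitting
    for degree in range(1, max_degree + 1):
        featurizer = PolynomialFeatures(degree=degree)
        X_tr = featurizer.fit_transform(X_train)
        X_te = featurizer.transform(X_test)
        model = LinearRegression().fit(X_tr, y_train)
        print(degree,
              'train r-squared', round(model.score(X_tr, y_train), 3),
              'test r-squared', round(model.score(X_te, y_test), 3))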