1、算法简介
1-1、算法思路
上一篇,简单线性回归算法的缺点之一是对于标签值是曲线结构的走势,很难拟合。那多项式回归算法出现,就是使得线性回归算法可以对非线性的数据进行回归分析,一个方面的优化、改进。
对于拟合出非线性关系,一般可以想到曲线;提起曲线,就得说在初中时学习的一元二次方程--y = ax^2 + bx + c。
简单线性回归算法的特征值是一次幂,如果想要形成一元二次方程的效果,就得需要在特征值中添加二次幂;这样的话,对于最后的解就变为了a、b、c。
1-2、图示
多项式回归如图,样本点中间有一条曲线,样本之间的关系试图要用一条曲线来拟合。
1-3、算法流程
同简单线性回归算法
1-4、优缺点
1-4-1、优点
a、拟合非线性的数据
b、理解与解释都十分直观
c、可以通过正则化来降低过拟合的风险
d、容易使用随机梯度下降和新数据更新模型权重
1-4-2、缺点
a、需要处理异常值
b、较简单回归算法复杂、困难
c、训练时间会增加
2、实践
2-1、采用bobo老师创建简单测试用例
import numpy as np
import matplotlib.pyplot as plt
# 创建测试数据
x = np.random.uniform(-3, 3, size=100)
X = x.reshape(-1, 1)
y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, 100)
plt.scatter(x, y)
plt.show() #见plt.show0
plt.show0
from sklearn.linear_model import LinearRegression
# 使用简单线性回归训练
lin_reg = LinearRegression()
lin_reg.fit(X, y)
y_predict = lin_reg.predict(X)
plt.scatter(x, y)
plt.plot(x, y_predict, color='r')
plt.show() # 见plt.show1
plt.show1
X2 = np.hstack([X, X**2]) # 添加一个二次幂特征
X2.shape
# (100, 2)
# 多项式回归
lin_reg2 = LinearRegression()
lin_reg2.fit(X2, y)
y_predict2 = lin_reg2.predict(X2)
plt.scatter(x, y)
plt.plot(np.sort(x), y_predict2[np.argsort(x)], color='r')
#要对x排序,否则是混乱的折线图
plt.show() # 见plt.show2
plt.show2
lin_reg2.coef_ # 特征和二次幂特征的系数
# array([0.85348244, 0.481137 ])
lin_reg2.intercept_ # 截距-b
# 2.032352537360585
网友评论