美文网首页
多项式回归算法

多项式回归算法

作者: 元宝的技术日常 | 来源:发表于2020-04-18 20:52 被阅读0次

    1、算法简介

    1-1、算法思路

    上一篇,简单线性回归算法的缺点之一是对于标签值是曲线结构的走势,很难拟合。那多项式回归算法出现,就是使得线性回归算法可以对非线性的数据进行回归分析,一个方面的优化、改进。

    对于拟合出非线性关系,一般可以想到曲线;提起曲线,就得说在初中时学习的一元二次方程--y = ax^2 + bx + c。

    简单线性回归算法的特征值是一次幂,如果想要形成一元二次方程的效果,就得需要在特征值中添加二次幂;这样的话,对于最后的解就变为了a、b、c。


    1-2、图示

    多项式回归

    如图,样本点中间有一条曲线,样本之间的关系试图要用一条曲线来拟合。


    1-3、算法流程
    简单线性回归算法


    1-4、优缺点

    1-4-1、优点

    a、拟合非线性的数据
    b、理解与解释都十分直观
    c、可以通过正则化来降低过拟合的风险
    d、容易使用随机梯度下降和新数据更新模型权重

    1-4-2、缺点

    a、需要处理异常值
    b、较简单回归算法复杂、困难
    c、训练时间会增加


    2、实践

    2-1、采用bobo老师创建简单测试用例

    import numpy as np 
    import matplotlib.pyplot as plt
    
    # 创建测试数据
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, 100)
    
    plt.scatter(x, y)
    plt.show() #见plt.show0
    
    plt.show0
    from sklearn.linear_model import LinearRegression
    
    # 使用简单线性回归训练
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    y_predict = lin_reg.predict(X)
    
    plt.scatter(x, y)
    plt.plot(x, y_predict, color='r')
    plt.show() # 见plt.show1
    
    plt.show1
    X2 = np.hstack([X, X**2]) # 添加一个二次幂特征
    X2.shape
    # (100, 2)
    
    # 多项式回归
    lin_reg2 = LinearRegression()
    lin_reg2.fit(X2, y)
    y_predict2 = lin_reg2.predict(X2)
    
    plt.scatter(x, y)
    plt.plot(np.sort(x), y_predict2[np.argsort(x)], color='r') 
    #要对x排序,否则是混乱的折线图
    plt.show() # 见plt.show2
    
    plt.show2
    lin_reg2.coef_ # 特征和二次幂特征的系数
    # array([0.85348244, 0.481137  ])
    
    lin_reg2.intercept_ # 截距-b
    # 2.032352537360585
    

    相关文章

      网友评论

          本文标题:多项式回归算法

          本文链接:https://www.haomeiwen.com/subject/ozayvhtx.html