美文网首页
多项线性回归

多项线性回归

作者: NextStepPeng | 来源:发表于2017-10-29 22:57 被阅读0次

    之前学习了单一线性回归和多元线性回归,这次来学习下多项线性回归,那什么事多项线性回归呢?先看下图Polynomial Linear Regression

    多项目线性回归表达式

    来,上代码

    import numpy as np

    import matplotlib.pyplot as plt

    import pandas as pd

    #import the dataset

    dataset = pd.read_csv('Position_Salaries.csv')

    X = dataset.iloc[:,1:2].values

    y = dataset.iloc[:,2].values

    吃进去的数据是这样(样例数据)

    样例数据(级别与薪资)

    来生成多项模型

    #多项线性模型

    #项 = 几次方 【1,2,3,4,5】 直线 【1,4,9,16,25】【1,8,27,84,125】 曲线

    from sklearn.preprocessing import PolynomialFeatures

    poly_reg = PolynomialFeatures(degree = 2) #转成平方

    X_poly = poly_reg.fit_transform(X)

    来查看先X_poly

    X_poly

    为什么前面多 了一列1,解释请看下面公式,进一步理解什么是多项线性回归

    Y = 2º + 2¹ + 2²  (用2距离)

    为了跟好了理解下多项详细回归,我们先用数据

    (1)、生成单一线性模型

    #单一 线性回归

    from sklearn.linear_model import LinearRegression

    lin_reg = LinearRegression()

    lin_reg.fit(X, y)

    (2)、可视化

    #把数据可视化

    plt.scatter(X,y,color = "red")

    #先用单一线性回归的模型

    plt.plot(X,lin_reg.predict(X), color = "blue")

    plt.title("Truth or Bluff(Linear Regression)")

    plt.xlabel("Position level")

    plt.ylabel("Salary")

    plt.show()

    通过单一线性回归得出来的可视化图

    (3)、通过这种图可以发现,这个模型的精度非常低,几乎无法预测、下面让我们再来看下,通过多项线性模型得出的来的图。

    代码:

    模型

    #多项线性模型

    #项 = 几次方 【1,2,3,4,5】 直线 【1,4,9,16,25】【1,8,27,84,125】 曲线

    from sklearn.preprocessing import PolynomialFeatures

    poly_reg = PolynomialFeatures(degree = 2) #转成平方  3、立方 4、 可以更换尝试

    X_poly = poly_reg.fit_transform(X)

    lin_reg2 = LinearRegression()

    lin_reg2.fit(X_poly, y)

    #把数据可视化

    plt.scatter(X,y,color = "red")

    #多项线性回归的模型

    plt.plot(X,lin_reg2.predict(poly_reg.fit_transform(X)), color = "blue")

    plt.title("Truth or Bluff(PolynomialFeatures)")

    plt.xlabel("Position level")

    plt.ylabel("Salary")

    plt.show()

    通过多项线性回归(平方)产生的图 通过多项线性回归(立方)产生的图 通过多项线性回归(4次方)产生的图

    可以看到4次方的时候,模型的精确度越来越高

    通过两个模型对6.5级别的薪资 进行下预测看下那个精度高些,明显是多项线性模型 预测的是158862

    #作出预测

    lin_reg.predict(6.5) #array([[ 330378.78787879]])

    lin_reg2.predict(poly_reg.fit_transform(6.5)) #array([[ 158862.4526516]])

    线性预测的结果是:array([[ 330378.78787879]])

    那么问题又来了,选择使用什么样的模型,或者如何提高模型的精确度,请看“SVR”

    相关文章

      网友评论

          本文标题:多项线性回归

          本文链接:https://www.haomeiwen.com/subject/vmhupxtx.html