美文网首页
(二)线性回归正规方程

(二)线性回归正规方程

作者: 羽天驿 | 来源:发表于2020-04-06 16:11 被阅读0次

    一、什么是正规地方程

    (1)找到合适的预测函数
    (2)找到预测值与真实值之间的损失函数。

    • 正规方程--最小二乘法就是线性回归所用的损失函数。
    • 最小二乘法,实际上是想让拟合的直线方程与实际的误差最小。
    • 线性回归的使用的就是最小二乘法
    正规方程.png

    二、正规方程的详细解释

    • w = (X^TX)^{-1}X^Ty

    • 正规方程就是矩阵运算求解方式

    • \min\limits_w||Xw - y||_2^2

    • (x\theta - y)^2 = x^2\theta^2 + y^2 - 2x\theta y

    • (a\theta - b)^2 = a^2\theta^2 - 2ab\theta + b^2

    • X,w,y都是矩阵

    • 最小二乘法的方程大于等于0的

    • 开口向上的一个函数


      最小二乘法损失函数.png
    • 如果X,和y是数字,求导非常简单

    • f(x) = 3x^2 + 4x + 7

    • f'(x) = 6x + 4

    现在的方程不是数字,而是矩阵,矩阵求导法则 :


    矩阵求导公式.jpg
    • 矩阵常用求导公式:
      • \frac{dX^T}{dX} = I 求解出来是单位矩阵
      • \frac{dX}{dX^T} = I
      • \frac{dX^TA}{dX} = A
      • \frac{dAX}{dX} = A^T
      • \frac{dXA}{dX} = A^T
      • \frac{dAX}{dX^T} = A
        矩阵求导推导.png
    矩阵求导公式推导(二).png

    三 、正规方程矩阵的推导过程

    • f(w) = ||Xw - y||_2^2

    • <font color = red>展开之后并不是:</font>f(w) = (Xw - y)(Xw - y)

    • f(w) = (Xw - y)^T(Xw - y)

    • 为什么展开之后,带着T,进行了转置:

      • 向量2-范数表示:每个元素的平方和再开平方根
        ||X||_2 = \sqrt{\sum\limits_{i = 1}^nx_i^2}
      • 表示,自己和自己相乘
    • f(w) = (w^TX^T - y^T)(Xw - y)

    • f(w) = w^TX^TXw - w^TX^Ty - y^TXw + y^Ty 矩阵乘法形式的变换

    • f(w) = w^TX^TXw - 2y^TXw + y^Ty

    • 进行导数求解:

      • f'(w) = 2X^TXw - 2X^Ty

      • f'(w) = 0

      • 2X^TXw - 2X^Ty = 0

      • X^TXw = X^Ty

      • 矩阵运算,没有除法,逆矩阵

      • (X^TX)^{-1}(X^TX)w = (X^TX)^{-1}X^Ty

      • Iw = (X^TX)^{-1}X^Ty

      • w = (X^TX)^{-1}X^Ty

    四、代码实现

    (一.正规方程)

    import numpy as np
    
    X = np.random.randint(0,10,size = (5,5))
    X
    
    array([[5, 8, 3, 4, 9],
           [3, 1, 7, 0, 4],
           [0, 8, 8, 4, 0],
           [4, 6, 7, 1, 7],
           [9, 6, 8, 9, 7]])
    

    向量2-范数表示:每个元素的平方和再开平方根

    ||X||_2 = \sqrt{\sum\limits_{i = 1}^nx_i^2}

    使用矩阵,进行运算

    ||X||_2 = X^TX

    使用矩阵,展开,不是下面这种写法

    ||X||_2 = XX

    # 线性代数的方法,进行了计算,肯定正确
    np.linalg.norm(X,ord = 2)
    
    27.14866872518055
    
    X.dot(X).sum()
    
    364
    
    (((X).dot(X.T)).sum(axis = 1))**0.5
    
    array([27.40437921, 20.76053949, 23.83275058, 26.38181192, 31.27299154])
    
    X.T.dot(X)
    
    array([[131, 121, 136, 105, 148],
           [121, 201, 185, 124, 160],
           [136, 185, 235, 123, 160],
           [105, 124, 123, 114, 106],
           [148, 160, 160, 106, 195]])
    
    np.dot(X.T,X)
    
    array([[131, 121, 136, 105, 148],
           [121, 201, 185, 124, 160],
           [136, 185, 235, 123, 160],
           [105, 124, 123, 114, 106],
           [148, 160, 160, 106, 195]])
    

    (二.数据挖掘2020年天猫双十二预测)

    import numpy as np
    
    from sklearn.linear_model import LinearRegression
    
    # pip install matplotlib
    import matplotlib.pyplot as plt
    
    image
    X = np.arange(2009,2020)
    X
    
    array([2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019])
    
    # 销售额
    y = np.array([0.5,9.36,52,191,350,571,912,1207,1682,2135,2684])
    
    plt.plot(X,y)#线形图
    plt.scatter(X,y,color = 'red')#散点图
    
    <matplotlib.collections.PathCollection at 0x200ef1be198>
    

    [图片上传失败...(image-93c373-1586160670328)]

    销量随着年份增加,越来越大,速度越来越慢

    X年 -----> y销量之间存在一个函数关系
    画图显示,不是直线
    X ------> y之间的关系,多项式关系

    假设X和y之间的关系是一元三次幂关系

    f(x) = w_1x + w_2x^2 + w_3x^3 + b

    f(x) = w_0x^0 + w_1x + w_2x^2 + w_3x^3

    X 年份,数字太大,差别不明显,数据处理,优化

    # 2009 2010 相差的绝对值是1,相差的百分比,很小 1/2009 = 0.0004977
    # 1     2   相差绝对值是1,相差的百分比,1/1 = 100%
    # 差异明显了
    
    # 放大差异,放大镜观察数据
    X = X - 2008
    X
    
    array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
    
    plt.plot(X,y)#线形图
    plt.scatter(X,y,color = 'red')#散点图
    
    <matplotlib.collections.PathCollection at 0x20091fb6898>
    

    [图片上传失败...(image-57f556-1586160670328)]

    a = np.array([1,2,3])
    
    b = np.array([3,6,-4])
    
    # 级联,数据合并
    np.concatenate([a,a])
    
    array([1, 2, 3, 1, 2, 3])
    
    np.c_[a,a,a,b]
    
    array([[ 1,  1,  1,  3],
           [ 2,  2,  2,  6],
           [ 3,  3,  3, -4]])
    
    # 数据级联到一起
    # 0次幂,1次幂,2次幂,3次幂
    X_train = np.c_[X**0,X,X**2,X**3]
    X_train
    
    array([[   1,    1,    1,    1],
           [   1,    2,    4,    8],
           [   1,    3,    9,   27],
           [   1,    4,   16,   64],
           [   1,    5,   25,  125],
           [   1,    6,   36,  216],
           [   1,    7,   49,  343],
           [   1,    8,   64,  512],
           [   1,    9,   81,  729],
           [   1,   10,  100, 1000],
           [   1,   11,  121, 1331]])
    

    训练

    # fit_intercept=False 将截距设置为零
    linear = LinearRegression(fit_intercept=False)
    linear.fit(X_train,y)
    w_ = linear.coef_
    print(linear.coef_.round(2))
    print(linear.intercept_)
    
    [ 58.77 -84.06  27.95   0.13]
    0.0
    

    f(x) = 58.77 - 84.06*x + 27.95*x^2 + 0.13*x^3

    画图预测,验证,准确吗???

    # 测试数据
    X_test = np.linspace(0,12,256)
    # 0次幂,1次幂,2次幂,3次幂
    # 因为训练数据,是四维属性
    X_test = np.c_[X_test**0,X_test,X_test**2,X_test**3]
    
    # 使用模型,预测,销量(连续的方程了)
    # 返回的y_就是销量
    y_ = linear.predict(X_test)
    
    # 绘制图形
    plt.plot(np.linspace(0,12,256),y_,color = 'g')
    # 天猫双十一真实销量情况
    plt.scatter(np.arange(1,12),y,color = 'r')
    
    <matplotlib.collections.PathCollection at 0x200941543c8>
    

    [图片上传失败...(image-56ed9d-1586160670328)]

    fun = lambda x : w_[0] + w_[1]*x + w_[2]*x**2 + w_[-1]*x**3
    

    一元三次方程,基本吻合十一年的销量数据

    使用这个模型预测,2020年销量

    模型就是上面写好的fun方程

    2020 - 2008 = 12 数字12代表2020年

    1684,2135,2684,3294

    fun(12)

    正规方程,进行求解

    w = (X^TX)^{-1}X^Ty

    print("使用sklearn封装好的算法,计算的方程斜率",w_)
    
    使用sklearn封装好的算法,计算的方程斜率 [ 58.76727273 -84.0594561   27.94659674   0.12726496]
    
    w = np.linalg.inv(X_train.T.dot(X_train)).dot(X_train.T).dot(y)
    w
    
    array([ 58.76727273, -84.0594561 ,  27.94659674,   0.12726496])
    

    相关文章

      网友评论

          本文标题:(二)线性回归正规方程

          本文链接:https://www.haomeiwen.com/subject/hrnruhtx.html