Linear Regression and Gradient Descent Implementation

Author: 罗泽坤 | Published 2020-04-12 19:22

    First, create a set of sample data.

    Y = 2X + E + 3, where E is a standard normal error term.

    import numpy as np
    import pandas as pd
    from pandas import DataFrame,Series
    e = np.random.randn(100)                  # 100 standard normal error terms
    x = np.random.uniform(1,10,size=(100,1))  # draw 100 values from a uniform distribution on [1, 10)
    

    Clean the data and concatenate it into one table:

    X = DataFrame(data=x) 
    print(X)
    E = DataFrame(data=e)
    print(E)
    Y = 2*X+E+3
    #Y = Y.rename(columns={0:'1'})
    Y.columns  = ['1']
    print(Y)
    Dt = pd.concat([X,Y.reindex(X.index)],axis=1) # concatenate X and Y column-wise into one sample table
    print(Dt)
    Dt.to_csv('LinR.csv',index = False,header = False) # write LinR.csv without row index or column header
    
               0
    0   7.959296
    1   2.914118
    2   6.531930
    3   6.033350
    4   3.329694
    ..       ...
    95  6.689540
    96  9.719845
    97  1.911316
    98  6.540799
    99  2.833886
    
    [100 rows x 1 columns]
               0
    0  -0.988615
    1  -0.298217
    2   1.846047
    3   1.982957
    4   0.644001
    ..       ...
    95 -1.317616
    96  0.844359
    97  0.475709
    98 -0.182511
    99 -0.075437
    
    [100 rows x 1 columns]
                1
    0   17.929977
    1    8.530018
    2   17.909907
    3   17.049656
    4   10.303389
    ..        ...
    95  15.061463
    96  23.284048
    97   7.298340
    98  15.899087
    99   8.592336
    
    [100 rows x 1 columns]
               0          1
    0   7.959296  17.929977
    1   2.914118   8.530018
    2   6.531930  17.909907
    3   6.033350  17.049656
    4   3.329694  10.303389
    ..       ...        ...
    95  6.689540  15.061463
    96  9.719845  23.284048
    97  1.911316   7.298340
    98  6.540799  15.899087
    99  2.833886   8.592336
    
    [100 rows x 2 columns]
    
    import matplotlib.pyplot as plt
    sample = np.genfromtxt('LinR.csv',delimiter=',')
    print(sample)
    x = sample[:,0]  # extract the independent variable
    y = sample[:,1]  # extract the observed function values
    
    [[ 7.95929583 17.92997689]
     [ 2.91411788  8.5300183 ]
     [ 6.53192986 17.90990674]
     [ 6.0333498  17.04965623]
     [ 3.32969394 10.30338909]
     [ 9.78091003 24.11104012]
     [ 5.50414492 15.01162983]
     [ 2.3905083   7.98267035]
     [ 9.10073718 20.63593365]
     [ 8.6771435  20.00840521]
     [ 8.59297259 20.47870044]
     [ 3.71421162 13.64842939]
     [ 9.53717825 20.65189706]
     [ 1.65102104  5.72905434]
     [ 2.25873778  6.97617708]
     [ 4.22231735 11.55259733]
     [ 1.53631731  4.94287203]
     [ 5.34000791 13.16707094]
     [ 6.17545718 15.06864934]
     [ 4.90415263 12.41321868]
     [ 7.75327441 18.6794576 ]
     [ 5.69725718 12.84726204]
     [ 3.56909863 10.5966961 ]
     [ 9.94681425 21.33072322]
     [ 3.59799413  9.56658453]
     [ 3.87835658 11.11208729]
     [ 5.03228019 13.68579256]
     [ 9.48569476 22.77879716]
     [ 4.66228634 12.75729249]
     [ 2.4908625   7.26331882]
     [ 4.81153515 12.35458447]
     [ 9.58406536 22.38109883]
     [ 4.28121897 10.74117052]
     [ 2.45652558  8.71965016]
     [ 1.9736664   6.4809288 ]
     [ 9.51051954 22.72786535]
     [ 6.75959417 17.81727331]
     [ 5.61578003 14.22442185]
     [ 3.85178666  9.75542375]
     [ 5.68632445 13.52993683]
     [ 5.4966711  15.11321743]
     [ 1.11686002  6.05160856]
     [ 8.82731607 19.98729226]
     [ 3.23499699  8.89761767]
     [ 3.22340228  9.82438038]
     [ 3.79773535  9.75677403]
     [ 7.91327594 19.74140776]
     [ 6.96384595 16.55981935]
     [ 8.04811251 18.00044451]
     [ 5.95047664 15.91780519]
     [ 6.86179131 14.57196115]
     [ 9.92403963 24.85022378]
     [ 2.28427921  5.8765021 ]
     [ 3.65702265  9.91875919]
     [ 2.94110273 10.2254981 ]
     [ 1.44233169  6.68000409]
     [ 8.11509328 19.30801919]
     [ 9.61113455 22.19692088]
     [ 5.31715886 10.11119983]
     [ 4.18841633 10.38081049]
     [ 3.38696086 10.89271924]
     [ 6.79149049 16.38415715]
     [ 9.43502226 22.16533282]
     [ 2.39967181  8.2302037 ]
     [ 7.36239661 17.6441651 ]
     [ 2.94883598  9.58722116]
     [ 8.07898524 18.88577106]
     [ 7.27030879 17.48551352]
     [ 3.82638905 10.42645614]
     [ 2.21195155  6.25717108]
     [ 2.62373052  7.75337578]
     [ 4.2888883  13.05445549]
     [ 4.79372014 13.47932057]
     [ 7.27504625 18.03400179]
     [ 7.6804564  18.91495491]
     [ 4.96708467 11.14166529]
     [ 1.3403768   3.22305645]
     [ 3.31317609 10.54117355]
     [ 8.83562974 21.6072844 ]
     [ 5.2485316  15.68878013]
     [ 3.09095248  8.00604601]
     [ 1.40126741  5.61514191]
     [ 5.97243366 16.6769609 ]
     [ 9.27273317 21.88579855]
     [ 1.10719167  7.34636051]
     [ 7.64405604 16.02157126]
     [ 1.46871259  7.17774005]
     [ 8.04512646 19.50343429]
     [ 6.44703966 14.47562085]
     [ 2.53241481  6.62603648]
     [ 7.88971862 19.72029653]
     [ 5.25505582 13.99824232]
     [ 6.42160172 17.67089747]
     [ 4.71902614 10.58448458]
     [ 3.69982573 11.40370471]
     [ 6.68953965 15.06146343]
     [ 9.7198446  23.28404806]
     [ 1.91131556  7.29834036]
     [ 6.54079876 15.89908666]
     [ 2.83388631  8.59233584]]
    
    # Scatter plot of the sample y = 2x + e + 3
    plt.scatter(x, y, marker = 'o',color = 'red', s = 40 )
    plt.show()
    
    [Figure output_5_0.png: scatter plot of the sample points]

    # Loss function: mean squared error over the data set
    def loss_function(data,b,w):
        Total_Error = 0
        for i in range(len(data)):
            x = data[i][0]
            y = data[i][1]
            Total_Error += ((w*x+b)-y)**2
        return Total_Error/float(len(data))
    
    # One step of gradient descent
    def gradient_step(b_current,w_current,data,learning_rate):
        # initialize the gradients
        b_grad = 0
        w_grad = 0
        N = float(len(data))
        for i in range(len(data)):
            x = data[i][0]
            y = data[i][1]
            b_grad +=  (2/N)*(w_current*x+b_current-y)
            w_grad +=  x*(2/N)*(w_current*x+b_current-y)
        new_b = b_current - learning_rate*b_grad
        new_w = w_current - learning_rate*w_grad
        return [new_b,new_w]
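
    For reference, the updates in gradient_step come from differentiating the mean squared error loss with respect to b and w; in LaTeX notation:

        L(b, w) = \frac{1}{N}\sum_{i=1}^{N}\bigl(w x_i + b - y_i\bigr)^2

        \frac{\partial L}{\partial b} = \frac{2}{N}\sum_{i=1}^{N}\bigl(w x_i + b - y_i\bigr), \qquad
        \frac{\partial L}{\partial w} = \frac{2}{N}\sum_{i=1}^{N} x_i\bigl(w x_i + b - y_i\bigr)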
    
    # Iterate gradient descent for a fixed number of steps
    def gradient_descent_iter(initial_b,initial_w,iteration_num,data,learning_rate):
        b_current = initial_b
        w_current = initial_w
        for i in range(iteration_num):
            b_current,w_current = gradient_step(b_current,w_current,np.array(data),learning_rate)
        return [b_current,w_current]
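
    As an aside (not part of the original post), the same step can be written in vectorized NumPy form; this is a minimal sketch equivalent to gradient_step above:

    # Vectorized sketch of one gradient descent step (equivalent to gradient_step above)
    def gradient_step_vectorized(b_current, w_current, data, learning_rate):
        x = data[:, 0]
        y = data[:, 1]
        error = w_current * x + b_current - y    # residuals for all samples at once
        b_grad = 2 * error.mean()                # d(loss)/d b
        w_grad = 2 * (x * error).mean()          # d(loss)/d w
        return [b_current - learning_rate * b_grad,
                w_current - learning_rate * w_grad]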
    
    # Driver function
    def run_fun():
        initial_b = 0
        initial_w = 0
        iteration_num = 20000
        learning_rate = 0.0001
        data = np.genfromtxt('LinR.csv',delimiter=',')
        print('Starting gradient descent at b = {0},w={1},loss_value = {2}'.format(initial_b,initial_w,loss_function(data,initial_b,initial_w)))
        print('Running...')
        [b,w] = gradient_descent_iter(initial_b,initial_w,iteration_num,data,learning_rate)
        print('After {0} iterations b = {1},w = {2},loss_value = {3}'.format(iteration_num,b,w,loss_function(data,b,w)))
        
    if __name__ == '__main__':
        run_fun()
    
        
    
    Starting gradient descent at b = 0,w=0,loss_value = 217.6852494453724
    Running...
    After 20000 iterations b = 2.1616063195346547,w = 2.1614204809921613,loss_value = 1.4195063550286242
    

    After 20,000 iterations the fitted slope w is close to the true value of 2, and the loss has dropped from about 218 to about 1.4, roughly the variance of the noise term E.
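
    As a sanity check (a sketch, not in the original post), the gradient descent result can be compared against the closed-form least-squares fit, for example with np.polyfit:

    # Closed-form least-squares fit for comparison (degree-1 polynomial: slope, then intercept)
    w_ls, b_ls = np.polyfit(x, y, deg=1)
    print('Closed-form least squares: w = {0}, b = {1}'.format(w_ls, b_ls))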

    import matplotlib.pyplot as plt
    sample = np.genfromtxt('LinR.csv',delimiter=',')
    x_1 = sample[:,0]
    y_1 = sample[:,1]
    
    
    # Compare the fitted line with the sample points
    import matplotlib.pyplot as plt
    import numpy as np
    plt.scatter(x_1, y_1, marker = 'o',color = 'red', s = 40, label = 'sample points')
    x_2 = np.arange(0,10)
    w = 2.1614204809921613
    b = 2.1616063195346547
    y_2 = w*x_2+b
    plt.plot(x_2,y_2,label='Fitting Curve')
    plt.legend()
    plt.show()
    
    [Figure output_11_0.png: sample scatter plot with the fitted line]

    The fitted line tracks the sample points closely, so the fit looks quite good.
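
    To go beyond the visual check, here is a minimal sketch (reusing the fitted w and b above) that quantifies the fit with the coefficient of determination R^2:

    # Quantify the fit: coefficient of determination R^2
    y_pred = w * x_1 + b                          # fitted line evaluated at the sample x values
    ss_res = np.sum((y_1 - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_1 - np.mean(y_1)) ** 2)    # total sum of squares
    print('R^2 =', 1 - ss_res / ss_tot)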

    
    
