美文网首页
实现一元线性回归

实现一元线性回归

作者: 大家都写错了字 | 来源:发表于2018-10-22 02:58 被阅读0次

机器学习最开始最入门的:一元线性回归

直线方程 f(x)= wx+b 对应到数组到点,则有有这样到关系:f(xi)=wxi+b

设输入为(x1,x2,x3,...xi),(y1,y2,y3,...yi)

损失函数为:△ =∑ √ (f(xi)-yi)**2  (所有计算所得的y_i与输入的yi的差值的平方再求平方根) 

△ =∑ √ (wxi+b-yi)**2 

为了方便计算,我们一般不再做开平方的操作,只做了平方保证其为正数即可

△ =∑  (wxi+b-yi)**2   (xi∈x)

梯度下降算法主要目的是使得损失函数变小,需要捕捉数损失函数到变化趋势,而导数则是变化趋势的表达形式,

于是,我们只要找到损失函数的导数为0的时候,对应的损失函数就是达到了极大或者极小值

关于w的偏导数 d△/dw 为: 2∑(wxi+b-yi)xi 

关于b的偏导数为: d△/db = 2∑(wxi+b-yi)

所以关键在于:如何求出损失函数的偏导数。

当关于w的偏导数为正数时,表示损失函数在W增加时,损失函数在增加;所以我们将w减去一个小值,会让损失函数变小,

同理为:关于w偏导数为负时,我们应该增加w

关于b的偏导数也是一样。

于是w的变化将依据关于w的偏导数来变化。

w = w- 每次变化的步长*偏导数/i(偏导数越大,w变化的越快)

简答的求导过过程,注意复合函数求导的关系还有导数求导的基本运算规则(自己也要从头复习了)

以下列出全部的代码

用python

数据样本为:

3.0000 10.5000 3.1000 10.3000 3.2000 11.6000 3.3000 12.7000 3.4000 11.1000 3.5000 11.8000 3.6000 11.8000 3.7000 12.800 3.8000 12.4000 3.9000 12.7000 4.0000 13.0000 4.1000 13.3000 4.2000 13.6000 4.3000 13.7000 4.4000 14.4000 4.5000 14.6000 4.6000 13.6000 4.7000 14.2000 4.8000 15.5000 4.9000 15.4000 5.0000 16.9000



import numpy as np

import matplotlib.pyplot as plt

rate = 0.01 # 每次更新的步幅

steps = 20000 #训练次数

data = np.loadtxt('data.txt')#用np的loadtxt方法读区文件,默认分隔符为空格,其他用‘;’这样的格式设置

print(data,type(data),data.shape)

x= []

y=[] #设置两个数组,用来承载读区的数据,即输入和输出样本

w=1 #默认一个w

b=8 #默认一个截距 b

for i in range(len(data)): #按照单双的交叉的方式拆分输入输出样本

    if i % 2 ==0: # 第0 2 4 6 8 双数为x

        x.append(data[i])

    if i%2 == 1: #其他为y

        y.append(data[i])

print(x,y)

plt.plot(x,y,'o')

delta = 0 #初始化一个误差

yy=[] #设置一个数组,为训练前计算出来的预测输出

y_=[] #预设数组为训练后预测输出

for i in range(len(x)): #计算损失函数

    delta += (( w*x[i] + b) - y[i])**2 # 误差距离平方

    yy.append(w*x[i] + b) #存下全部的初始化的时候的输出

plt.plot(x,y,'o')

plt.plot(x,yy)#画图

dwa=[] #预设数组,用来记录w的导数的变化,只是为了可视化而做的

dba=[] #预设数组,用来记录b的导数的变化 只是为了可视化而做的

D = []

for s in range(steps):

    delta = 0 #对多个参数做初始化

    dw = 0

    db = 0

    y_temp=[] #用来临时记录输出值,用来在设置的频率下,观察输出的直线变化,

    for i in range(len(x)):#进行批量的训练,一次训练全部的数据样本:样本为多组 一元参数

        dw += (w*x[i]+b-y[i])*x[i] # 这是关于w的偏导数

        db += (w*x[i]+b-y[i])  #这是关于b的偏导数

    w = w - rate/len(x) * dw #这里进行w的更新,按照样本数量平分步长然后乘以w的偏导数,这样等于根据损失函数的偏导数大小来变化更新的步长,偏导数绝对值越大,表示误差变化越大,则表示误差还有一个很大的范围,所以修改的值也越大

    b = b - rate/len(x) * db #同上

    for q in range(len(x)): #计算本次之后的损失函数,主要为了显示观察,可以注释掉

        delta += (( w*x[q] + b) - y[q])**2 #循环一开始就讲delta初始化为0,然后每次都从新计算

#    print(delta)

    D.append(delta)

    for temp_i in range(len(x)):

        y_temp.append( w*x[temp_i] + b)

    if s%500 == 0: #每五次显示一次误差,同时记录一次图像,

        plt.plot(x,y_temp,)

        #print(s,delta)

    dwa.append(dw) #记录损失函数关于w的偏导数的数值,为了记录和观察变化,当将步长变得比较大的时候就会发现,偏导数会剧烈的跳到,使得误差收敛误差

    dba.append(db)

for i in range(len(x)):

    delta += (( w*x[i] + b) - y[i])*(( w*x[i] + b) - y[i])

    y_.append(w*x[i] + b)#记录最终状态下的输出 ,主要是为了画图

print("所以函数是:y=" ,w,"x+",b) #当然还是要写出这个公式来

#以下是多个图形显示变化

plt.plot(x,y,'o')

plt.title('训练过程中的输出变化',fontproperties='SimHei')

plt.show()

plt.subplot(121)

plt.title('损失函数变化',fontproperties='SimHei')

plt.plot(range(steps),D)

plt.subplot(122)

plt.title('w偏导数',fontproperties='SimHei')

plt.plot(range(steps),dwa)

plt.show()

plt.subplot(121)

plt.title('b偏导数',fontproperties='SimHei')

plt.plot(range(steps),dba)

plt.subplot(122)

plt.title('样本+直线',fontproperties='SimHei')

plt.plot(x,y_)

plt.plot(x,y,'o')

plt.show()


    生成数据图为:

相关文章

  • 2020-08-13--线性回归01

    线性回归算法简介 解决回归问题 思想简单,容易实现 许多强大的非线性模型的基础 结果具有很好的 线性回归分为一元线...

  • 2020-02-14

    线性回归:线性回归分为一元线性回归和多元线性回归,一元线性回归用一条直线描述数据之间的关系,多元回归是用一条曲线描...

  • Matlab一元/多元回归(后续会有更新)

    一元线性回归&一元非线性回归 多元回归/逐步回归 多元回归 先画散点图,看有没有线性性质,再决定能不能用多元线性回...

  • 使用TF实现一元线性回归

    前言 要实现一元线性回归,总体思路按照Tensorflow的指导手册进行改装

  • 一元线性回归方程

    目标:写清楚一元线性回归分析的全部过程。 一元线性回归分析步骤: 确定变量variable:independent...

  • 数学建模系列笔记2:回归和时间序列

    数学建模 @[toc] 3-1-1 一元线性回归 一般,假设 若 称为一元正态线性回归模型 回归分析要解决的主要问...

  • 实现一元线性回归

    机器学习最开始最入门的:一元线性回归 直线方程f(x)= wx+b 对应到数组到点,则有有这样到关系:f(xi)=...

  • 机器学习

    1.线性回归 1.1一元线性回归 y=a+bx 1.2多元线性回归 y=a+b1x1+b2x2+...+bnxn ...

  • 机器学习第4天:线性回归及梯度下降

    联系我:ke.zb@qq.com我的技术博客:明天依旧可好-CSDN 一、简单线性回归(即一元线性回归) 线性回归...

  • 线性回归--sklearn框架实现

    线性回归--原理 线性回归--python实现(不使用框架) 线性回归--sklearn框架实现 这里使用skle...

网友评论

      本文标题:实现一元线性回归

      本文链接:https://www.haomeiwen.com/subject/hacwzftx.html