美文网首页
简单线性回归算法

简单线性回归算法

作者: 元宝的技术日常 | 来源:发表于2020-04-17 21:31 被阅读0次

    1、算法简介

    1-1、算法思路
    简单线性回归(SimpleLinearRegression)解决的是回归问题,上一篇是分类,这两个概念的区别是标签值-label的差距。分类问题的label一般是类别型,像性别类别,品牌类别... 而回归问题的label一般为为连续数值型,像身高、体重... 监督学习的样本点是由特征-feature和标签-label组成。

    简单线性回归之所以有简单两个字,是因为确实比较简单,如果能把样本点映射到空间中,它的作用就是试图用一条线把所有点表示。比如,样本只有一个特征值-x和一个标签值-y,最终的表示则为一元方程:y=ax+b;a、b则为此算法要求出的未知量。


    1-2、图示

    简单线性回归

    如图,样本点中间有一条直线,样本之间的关系试图要用一条直线来模拟。


    1-3、算法流程
    1--- 假如样本只有一个特征值-x和一个标签值-y,最终的表示则为一元方程:y^=ax+b
    2--- 试图求出a、b,使得推理出的-y^和实际标签值-y无限接近;判断是否接近的评测标准一般用:均方差-MSE、均方根差-RMSE、平均绝对误差-MAE和R Square。
    3--- 采用最小二乘法,可以求出a、b(推导过程见这里
    4--- 不同于kNN算法,在训练数据集上训练好参数a和b的值之后,推理/预测的时候,只需要使用学习到的参数a和b对每一个待推理/预测的样本进行计算就好了。这就是一个典型的参数学习算法。


    1-4、优缺点

    1-4-1、优点

    a、思想简单,实现容易;
    b、结果有可解释性;
    c、有强大的非线性模型的基础。

    1-4-2、缺点

    a、在数学模型中表示一根直线,而现实环境中很多的数据,例如房价,销售涨跌都是曲线结构的,使得推理/预测率低;
    b、难以很好地表达高度复杂的数据。

    2、实践

    2-1、采用bobo老师创建简单测试用例

    import numpy as np
    import matplotlib.pyplot as plt
    
    # 创建测试数据
    x = np.array([1., 2., 3., 4., 5.])
    y = np.array([1., 3., 2., 3., 5.])
    
    plt.scatter(x, y)
    plt.axis([0, 6, 0, 6])
    plt.show() #见plt.show1
    
    plt.show1
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    
    # 最小化误差的平方,用最小二乘法求解:
    num = 0.0 
    d = 0.0 
    for x_i, y_i in zip(x, y):
        num += (x_i - x_mean) * (y_i - y_mean)
        d += (x_i - x_mean) ** 2
    
    a = num/d
    b = y_mean - a * x_mean
    y_hat = a * x + b
    
    plt.scatter(x, y)
    plt.plot(x, y_hat, color='r')
    plt.axis([0, 6, 0, 6])
    plt.show() #见plt.show2
    
    plt.show2

    相关文章

      网友评论

          本文标题:简单线性回归算法

          本文链接:https://www.haomeiwen.com/subject/zppsvhtx.html