美文网首页进阶的iOSer数学教育自然科普
如何快速理解线性回归及模型训练?

如何快速理解线性回归及模型训练?

作者: 岁与禾 | 来源:发表于2020-06-17 11:50 被阅读0次

在机器学习领域内,经常会听到回归问题、线性回归等专业术语,那什么是回归问题,线性回归又是什么意思以及怎么训练一个模型、怎么调整评估模型的预测性能呢,本文将结合网上资料及TensorFlow来训练一个最简单的线性模型并评估其性能优劣。

01

线性回归

什么是回归问题

回归分析是确定多个相互依赖的变量之间的定量关系的方法。回归分析通常用于预测,发现变量之间的因果关系。

线性回归

线性回归是回归分析的一种,它研究的是多个变量之间的线性组合,线性回归假设的是预测的目标和特征之间是线性相关,通常可以使用以下公式来表达:

image

其中w表示的是特征x的权重,b表示的是偏差。在几何中,w又表示直线的斜率,w表示截距。

线性回归的目的在于,通过大量的训练数据来找到最合适的w,b的值。机器学习中常见的专业术语“模型”,就是上述类似的一组描述特征和目标之间关系的公式。

机器学习训练出来的模型,最终需要应用的到验证数据上进行预测,而模型的优劣好坏如何度量,如何才能找到合适的w,b的值呢?这就需要定量分析,需要引入“损失”、“损失函数”等概念了。

02

损失函数

我们先看下面的一个例子,红色的直线表示当前的学习模型,蓝色的点表示测试数据,通过下图我们可以看到只有一个点跟模型拟合,其他的都欠拟合。从直觉上我们知道,当前的模型可能没那么好,那么如何定量的去衡量这个模型的好坏呢?

image

我们仍然以上图为例,取x = 12这个样本值,那么相应蓝色点的y值即为真实值,而相应红线所在的prediction(x)即表示预测值。那么y - prediction(x)即为损失。那么衡量这种损失的函数,我们称之为损失函数。

目前有这么几种损失函数:

  • 平方损失,又称L2损失

  • 残差平方和(RSS, Residual of Sum of Square)

  • 均方误差(MSE, Mean Square Error)

  • 均方根误差(RMSE, Root Mean Square Error)

其实都很好理解,我们一个一个来看。

平方损失

平方损失最好理解,单个样本的平方损失,可以有如下公式给出:

image

其实即为真实值与预测值之间差值的平方。

******残差平方和******

残差平方和指的是将单个样本的真实值与预测值之差的平方相加求和:

image

也即所有样本的平方损失之和。

****均方误差****

均方误差指的是平均平方损失,也即需要计算所有样本的平方损失之和,然后再除以样本之和。可以由如下公式给出:

image

其中(x,y)表示样本,x指的是样本的特征值,y指的是样本的标签(或者说目标值)。

D表示数据集,prediction(x)指的是模型预测值,N表示样本个数。

******均方根误差******

均方根误差与均方误差,从字面意思上来讲其实有很大的联系,均方根误差即为均方误差的开平方根。

image

在实际编程应用中,经常会先求出MSE,然后对其开平方即可得到RMSE。

了解完所有的损失函数之后,我们不禁有个疑问,这些误差到底在机器学习中起到了什么作用?

文章的开头我们指出,我们需要利用损失函数来评估模型的优劣,或者说我们通过损失函数的取值来调整模型的参数,以期获得最合适的模型。既然损失函数描述的是预测值与真实值的偏离情况,那么在实际训练过程中,我们应该找到使损失函数取最小值时取对应模型的参数。

我们假设目标和特征之间存在线性相关性,于是就有公式 y' = wx + b (假设只有一个特征)。以MSE(均方误差)及线性回归为例,我们将公式带入MSE中替换预测值,即可得如下公式:

image

该公式是以w,b为参数的函数,我们的目标是求取w,b的值使得该函数取最小值。由于N是一个常数,我们可以简化成下面的公式:

image

机器学习的模型训练问题到此就被我们简化成了求解上面公式(2)的最小值,求解函数最小值的方法有很多,比较常见的有:

  • 最小二乘法

  • 梯度下降法

最小二乘法我们这里省略不讲解,本文只讲解梯度下降法。

03

迭代法与梯度下降

讲解梯度下降法之前,我们先了解一下机器学习训练过程中的迭代法。

在机器学习过程中,我们的会选取一个初始化参数(w,b),然后通过损失函数来不断地调整参数值(w,b)直到训练处一个符合预期的模型。下图是一个简化版本的机器学习迭代过程:

image

图例来自Google机器学习文档

上图中,“计算参数更新”就是梯度下降法的本质。在回归问题中,权重值w的取值和损失之间的关系如下图,始终是一个凸函数(凸函数,用几何的方式来理解就是图形内任意两点构成的线段仍然被图形包围)。也就是说找到了一个极小值点,那么这个极小值点就是全局最小值点。

image

图例来自Google机器学习文档

上图中,任意选取一个起点。由于我们能够看到图形全貌,我们很快就能找到损失下降的方向,并最终定位到极小值点。但是在实际应用中,我们任意选取的起点后,是不能立刻确定哪个方向是损失下降的方向。举个例子,常识告诉我们地球是圆的,但站着地球上,我们始终感受不到弧度,也就无从找到某段圆弧的最低点。同样,我们就需要通过某种方式判断哪个方向是损失下降的方向。

在熟知的高数中,我们知道导数可以理解成斜率,或者曲线某个点的导数可以理解成改点切线的斜率。导数为正,则曲线递增;反之,曲线递减。梯度就可以理解成导数(偏导数)。

梯度下降就是按照负梯度的方向,逐渐的去找到最小值点。

04

模型训练

了解了上面的线性回归的基本原理以及模型训练的目的之后,我们总结一下线性回归模型训练的套路:

  • 分析训练数据,找到数据的合适的分布

  • 确定问题的类型,回归还是分类问题

  • 加载数据集,输入训练数据

  • 开始训练模型,迭代过程

  • 使用测试数据来分析当前模型的优劣

  • 其他(后面再讲)

可以参照google机器学习文档来详细了解如何使用tensorflow训练模型。接下来会写一篇TensorFlow模型训练初体验。

** 结语 **

线性回归是机器学习中常见的一种类型,通过损失函数来反复迭代机器学习模型。在迭代过程中采用了梯度下降法,寻找最优解,整个流程比较容易理解。

关注公众账号<岁与禾>,获取最新信息

相关文章

  • 如何快速理解线性回归及模型训练?

    在机器学习领域内,经常会听到回归问题、线性回归等专业术语,那什么是回归问题,线性回归又是什么意思以及怎么训练一个模...

  • 动手学深度学习pytorch(第一天、第二天)

    1.线性回归 线性回归输出是⼀个连续值,因此适⽤用于回归问题。 基本要素 ①模型定义 ②模型训练 (1)训练数据:...

  • 数据挖掘3

    建模调参 内容介绍 线性回归模型:线性回归对于特征的要求;处理长尾分布;理解线性回归模型; 模型性能验证:评价函数...

  • 零基础入门数据挖掘-Task4 建模调参

    内容介绍 线性回归模型:线性回归对于特征的要求;处理长尾分布;理解线性回归模型; 模型性能验证:评价函数与目标函数...

  • Tensorflow学习-No.1

    0 深度学习模型训练入门 如何通过tensorflow训练最简单的全连接网络,生成线性回归模型。 来自https:...

  • 2020-04-01

    线性回归模型:线性回归对于特征的要求;处理长尾分布;理解线性回归模型;模型性能验证:评价函数与目标函数;交叉验证方...

  • 线性回归代码实现

    线性回归是比较常用的模型。本文会简单介绍线性回归的原理,以及如何用代码实现线性回归模型。 什么是线性回归 简单举一...

  • 线性回归模型原理与推导

    什么是线性回归? 线性回归是在假设特征满足线性关系,根据给定的训练数据训练一个模型,并用此模型进行预测。其函数形式...

  • 2020机器学习线性模型(1)

    线性模型 今天我们来讨论一下线性模型,之前已经了解到线性模型来做回归问题,所谓回归问题就是根据给定样本数据训练出一...

  • 【机器学习实践】有监督学习:线性分类、回归模型

    线性模型 为线性模型 分类和回归的区别 分类:离散回归:连续本文主要关注线性回归模型 常用线性回归模型类型 OLS...

网友评论

    本文标题:如何快速理解线性回归及模型训练?

    本文链接:https://www.haomeiwen.com/subject/etojxktx.html