梯度下降法
梯度定义
梯度的本意是一个向量(矢量),表示某一函数在该点处的方向导数沿着该方向取得最大值,即函数在该点处沿着该方向(此梯度的方向)变化最快,变化率最大(为该梯度的模)。
<p align="right">--------百度百科</p>
对于来说,其梯度为:
对于来说,其梯度为:
梯度下降法思路
因为梯度是函数上升最快的方向,所以如果我们要寻找函数的最小值,只需沿着梯度的反方向寻找即可。这里以为例,简述梯度下降法实现的大体步骤:
- 确定变量的初始点,从初始点开始一步步向函数最小值逼近。
- 求函数梯度,然后求梯度的反向,将变量的初始点代入,确定变量变化的方向:;用求得的梯度向量(变量变化的方向)乘以学习率 (变量变化的步长)得到一个新的向量;变量的初始点加上求得的新向量,到达下一个点。
- 判断此时函数值的变化量是否满足精度要求。定义一个我们认为满足要求的精度;用上一个点的函数值减去当前点的函数值,得到此时函数值变化量的精度值(可以近似认为p为损失函数);判断是否成立。不成立则反复执行步骤2、3。
但是梯度下降法对初始点的选取要求比较高,选取不当容易陷入极小值(局部最优解)。
梯度下降法的简单应用
梯度下降法求二维曲线的最小值
下图为梯度下降法求曲线最小值的结果图,左图红色的点为求解过程中的过程点,右图为求解过程中精度的变化(损失函数值的变化),代码见附录。
梯度下降法求三维曲面的最小值
下图为梯度下降法求曲面最小值的结果图,图中红色的点为求解过程中的过程点,代码见附录。
代码附录
# -*- encoding=utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as aplt
from mpl_toolkits.mplot3d.axes3d import Axes3D
import sympy
class gradientDescent(object):
def init2D(self,vector:float,precision:float,startPoint:float):
"""
vector:学习率
precision:精度
startPoint:起始点
"""
self.vector = vector
self.precision = precision
self.startPoint = startPoint
self.startPrecision = precision + 1
def init3D(self,vector:float,precision:float,startVar1Point:float,startVar2Point:float):
"""
vector:学习率
precision:精度
startVar1Point:变量1的起始位置
startVar2Point:变量2的起始位置
"""
self.vector = vector
self.precision = precision
self.startVar1Point = startVar1Point
self.startVar2Point = startVar2Point
self.startPrecision = precision + 1
def singleVar2D(self, func:str, var:str):
grad = sympy.diff(func, var)
grad = str(grad)
xpoint = []
ypoint = []
errors = []
x = self.startPoint
while self.startPrecision > self.precision:
y = eval(func)
xpoint.append(x)
ypoint.append(y)
x1 = x - self.vector*eval(grad)
x = x1
y1 = eval(func)
self.startPrecision = y - y1
errors.append(self.startPrecision)
xpoint.append(x)
ypoint.append(y)
xlen = len(xpoint)
return [xpoint,ypoint,errors,xlen]
def doubleVar3D(self, func:str, var1:str, var2:str):
var1Grad = sympy.diff(func, var1)
var1Grad = str(var1Grad)
var1Grad = var1Grad.replace("sqrt","np.sqrt")
var2Grad = sympy.diff(func, var2)
var2Grad = str(var2Grad)
var2Grad = var2Grad.replace("sqrt","np.sqrt")
func = func.replace("sqrt","np.sqrt")
xpoint = []
ypoint = []
zpoint = []
errors = []
x = self.startVar1Point
y = self.startVar2Point
while self.startPrecision > self.precision:
z = eval(func)
xpoint.append(x)
ypoint.append(y)
zpoint.append(z)
x1 = x - self.vector*eval(var1Grad)
y1 = y - self.vector*eval(var2Grad)
x = x1
y = y1
z1 = eval(func)
self.startPrecision = z - z1
errors.append(self.startPrecision)
xpoint.append(x)
ypoint.append(y)
zpoint.append(z)
xlen = len(xpoint)
return [xpoint,ypoint,zpoint,errors,xlen]
if __name__ == '__main__':
xData = np.arange(-100,100,0.1)
yData = xData**2 + 2*xData + 5
vector=0.2
precision=10e-6
startPoint=-100
x = sympy.symbols("x")
func = "x**2+2*x+5"
gradient_descent = gradientDescent()
gradient_descent.init2D(vector,precision,startPoint)
[xpoint,ypoint,errors,xlen] = gradient_descent.singleVar2D(func,x)
fig,ax = plt.subplots(figsize=(12,8),ncols=2,nrows=1)
for i in range(xlen):
ax[0].cla()
ax[0].plot(xData,yData,color="green",label="$y=x^2+2x+5$")
ax[0].scatter(xpoint[i],ypoint[i],color="red",label="process point")
plt.pause(0.1)
ax[0].legend(loc = "best")
ax[1].plot(errors,label="Loss curve")
ax[1].legend(loc = "best")
plt.pause(0.1)
plt.show()
# =======================================================================
xData = np.arange(-100,100,0.1)
yData = np.arange(-100,100,0.1)
X,Y = np.meshgrid(xData,yData)
# z = sqrt(x^2+y^2)
Z = np.sqrt(X**2+Y**2)
x = sympy.symbols("x")
y = sympy.symbols("y")
func = "sqrt(x**2+y**2)"
vector=0.2
precision=10e-6
startVar1Point=100
startVar2Point=-100
gradient_descent = gradientDescent()
gradient_descent.init3D(vector, precision, startVar1Point, startVar2Point)
[xpoint,ypoint,zpoint,errors,xlen] = gradient_descent.doubleVar3D(func,x,y)
fig = plt.figure()
ax = Axes3D(fig)
surf = ax.plot_surface(X,Y,Z,label="$z=\sqrt{x^2+y^2}$")
ax.scatter(xpoint,ypoint,zpoint,color="red",label="process point")
# 解决标签报错,不显示问题
surf._facecolors2d=surf._facecolors3d
surf._edgecolors2d=surf._edgecolors3d
ax.legend()
plt.show()
网友评论