目标函数 func(x) = x * x = x^2的极小值,由于func 是一个凸函数,因此他唯一的极小值,也是他的最小值,其一阶导函数 dfunc(x) = 2 * x
实现代码如下:
1.引入: 计算(numpy) | 画图( matplotlib)
import numpy as np
import matplotlib.pyplot as plt
2.定义函数
# 目标函数:y=x^2
def func(x):
return np.square(x)
# 目标函数的一阶导函数也就是偏导数:dy/dx = 2*x
def dfunc(x):
return 2 * x
# 梯度下降算法函数
def gradirent_descent(x_start,df,epochs,learning_rate):
"""
梯度下降算法。给定起始点和目标函数的一阶导函数,求在epochs 次迭代中x的更新值
args:
x_start:x的起始点
func_deri:目标函数的一阶导函数
epochs:迭代周期
learning_rate:学习率
return:
xs 在每次迭代后的位置(包括起始点),长度为epochs+1
"""
theta_x = np.zeros(epochs + 1)
temp_x = x_start
theta_x[0] = temp_x
for i in range(epochs):
deri_x = dfunc(temp_x)
# v 表示 x 要改变的幅度
delta = - deri_x * learning_rate
temp_x = temp_x + delta
theta_x[i+1] = temp_x
return theta_x
# 定义画图函数
def mat_plot():
# 利用matplotlib 绘制图像
line_x = np.linspace(-5,5,100)
line_y = func(line_x)
x_start = -5
epochs = 5
lr = 0.3
x = gradirent_descent(x_start,dfunc,epochs,0.3)
color = 'r'
# 实现绘制主功能
plt.plot(line_x,line_y,c='b')
plt.plot(x,func(x),c=color,label='lr={}'.format(lr))
plt.plot(x,func(x),c=color,)
# legend 函数显示图例
plt.legend()
# show 函数显示
plt.show()
3.调用
mat_plot()
展示效果如下:

心得体会:原书中代码有些失误,幸好,编译器够好用,定位到出错行,最后根据错误解决。略微不足的是书中图例在右上,每个折点都有圈
纠错:最后一个plot改为
plt.scatter(x,func(x),c=color,)
就会变得好看

网友评论