001-梯度下降算法

作者: 吉林天师 | 来源:发表于2019-07-20 00:06 被阅读0次

目标函数 func(x) = x * x = x^2的极小值,由于func 是一个凸函数,因此他唯一的极小值,也是他的最小值,其一阶导函数 dfunc(x) = 2 * x

实现代码如下:

1.引入: 计算(numpy) | 画图( matplotlib)

import numpy as np
import matplotlib.pyplot as plt

2.定义函数

# 目标函数:y=x^2
def func(x):
    return np.square(x)

# 目标函数的一阶导函数也就是偏导数:dy/dx = 2*x
def dfunc(x):
    return 2 * x

#  梯度下降算法函数
def gradirent_descent(x_start,df,epochs,learning_rate):
    """
    梯度下降算法。给定起始点和目标函数的一阶导函数,求在epochs 次迭代中x的更新值
    args:
        x_start:x的起始点
        func_deri:目标函数的一阶导函数
        epochs:迭代周期
        learning_rate:学习率
    return:
        xs 在每次迭代后的位置(包括起始点),长度为epochs+1
    """
    theta_x = np.zeros(epochs + 1)
    temp_x = x_start
    theta_x[0] = temp_x
    for i in range(epochs):
        deri_x = dfunc(temp_x)
        # v 表示 x 要改变的幅度
        delta = - deri_x * learning_rate
        temp_x = temp_x + delta
        theta_x[i+1] = temp_x
    return theta_x

# 定义画图函数
def mat_plot():
    # 利用matplotlib 绘制图像
    line_x = np.linspace(-5,5,100)
    line_y = func(line_x)
    
    x_start = -5
    epochs = 5
    lr = 0.3
    x = gradirent_descent(x_start,dfunc,epochs,0.3)
    
    color = 'r'
    # 实现绘制主功能
    plt.plot(line_x,line_y,c='b')
    plt.plot(x,func(x),c=color,label='lr={}'.format(lr))
    plt.plot(x,func(x),c=color,)
    # legend 函数显示图例
    plt.legend()
    # show 函数显示
    plt.show()

3.调用

mat_plot()

展示效果如下:

x^2 的一阶偏导数梯度下降算法

心得体会:原书中代码有些失误,幸好,编译器够好用,定位到出错行,最后根据错误解决。略微不足的是书中图例在右上,每个折点都有圈

纠错:最后一个plot改为

plt.scatter(x,func(x),c=color,)

就会变得好看


带有折点的图

相关文章

网友评论

    本文标题:001-梯度下降算法

    本文链接:https://www.haomeiwen.com/subject/hepelctx.html