Gradient Descent (Code Examples)

Author: Jinglever | Published 2020-03-21 22:02

Gradient descent is arguably a cornerstone of modern machine learning. Every machine learning problem ultimately reduces to an optimization problem, and optimization is in turn cast as minimization. Gradient descent is a powerful tool for finding minima.

To understand gradient descent you need two mathematical concepts: the derivative and the partial derivative. There are plenty of articles online that explain gradient descent, and most of them cover the concepts well.

Recommended reading on Jianshu: 《深入浅出--梯度下降法及其实现》

For a more in-depth treatment, see the Zhihu article 《理解梯度下降法》.
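Whichever explanation you read, the method boils down to a single update rule applied repeatedly until the gradient is (close to) zero:

x_new = x - lr * grad(x)

where lr is the learning rate (step size) and grad(x) is the gradient of the objective at x (in the one-variable case, simply the derivative). All of the code below is this rule wrapped in a loop.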

Think you can write the code after reading those two articles? Test yourself with the two problems below. The task is to implement gradient descent in code and find the minimum. Problem 1 is the most basic single-variable function; Problem 2 is a function of three variables.

Problem 1: find the minimum of y = x^2

The Python code is as follows:

def y(x):
    return x ** 2

def grad(x):
    """
    Gradient of y at x
    """
    return 2 * x

lr = 0.1 # learning rate (step size)
max_iter = 100 # maximum number of iterations
stop = 1e-6 # stop once the absolute gradient is no larger than this
x = 1. # initial value
for i in range(max_iter):
    g = grad(x)
    x = x - lr * g
    print('iter: {}: x={} y={} grad={}'.format(i, x, y(x), g))
    if abs(g) <= stop:
        break
print('min y: ', y(x))

Output:

iter: 0: x=0.8 y=0.6400000000000001 grad=2.0
iter: 1: x=0.64 y=0.4096 grad=1.6
iter: 2: x=0.512 y=0.262144 grad=1.28
iter: 3: x=0.4096 y=0.16777216 grad=1.024
iter: 4: x=0.32768 y=0.10737418240000002 grad=0.8192
iter: 5: x=0.26214400000000004 y=0.06871947673600003 grad=0.65536
iter: 6: x=0.20971520000000005 y=0.04398046511104002 grad=0.5242880000000001
iter: 7: x=0.16777216000000003 y=0.02814749767106561 grad=0.4194304000000001
iter: 8: x=0.13421772800000004 y=0.018014398509481992 grad=0.33554432000000006
iter: 9: x=0.10737418240000003 y=0.011529215046068476 grad=0.26843545600000007
iter: 10: x=0.08589934592000002 y=0.0073786976294838245 grad=0.21474836480000006
iter: 11: x=0.06871947673600001 y=0.004722366482869647 grad=0.17179869184000005
iter: 12: x=0.05497558138880001 y=0.003022314549036574 grad=0.13743895347200002
iter: 13: x=0.04398046511104001 y=0.0019342813113834073 grad=0.10995116277760002
iter: 14: x=0.035184372088832 y=0.0012379400392853804 grad=0.08796093022208001
iter: 15: x=0.028147497671065603 y=0.0007922816251426435 grad=0.070368744177664
iter: 16: x=0.02251799813685248 y=0.0005070602400912918 grad=0.056294995342131206
iter: 17: x=0.018014398509481985 y=0.00032451855365842676 grad=0.04503599627370496
iter: 18: x=0.014411518807585589 y=0.00020769187434139315 grad=0.03602879701896397
iter: 19: x=0.01152921504606847 y=0.0001329227995784916 grad=0.028823037615171177
iter: 20: x=0.009223372036854777 y=8.507059173023463e-05 grad=0.02305843009213694
iter: 21: x=0.007378697629483821 y=5.444517870735016e-05 grad=0.018446744073709553
iter: 22: x=0.005902958103587057 y=3.4844914372704106e-05 grad=0.014757395258967642
iter: 23: x=0.004722366482869646 y=2.230074519853063e-05 grad=0.011805916207174114
iter: 24: x=0.0037778931862957168 y=1.4272476927059604e-05 grad=0.009444732965739291
iter: 25: x=0.0030223145490365735 y=9.134385233318147e-06 grad=0.0075557863725914335
iter: 26: x=0.002417851639229259 y=5.846006549323614e-06 grad=0.006044629098073147
iter: 27: x=0.001934281311383407 y=3.741444191567113e-06 grad=0.004835703278458518
iter: 28: x=0.0015474250491067257 y=2.3945242826029522e-06 grad=0.003868562622766814
iter: 29: x=0.0012379400392853806 y=1.5324955408658897e-06 grad=0.0030948500982134514
iter: 30: x=0.0009903520314283045 y=9.807971461541695e-07 grad=0.002475880078570761
iter: 31: x=0.0007922816251426436 y=6.277101735386685e-07 grad=0.001980704062856609
iter: 32: x=0.000633825300114115 y=4.017345110647479e-07 grad=0.0015845632502852873
iter: 33: x=0.0005070602400912919 y=2.571100870814386e-07 grad=0.00126765060022823
iter: 34: x=0.0004056481920730336 y=1.6455045573212074e-07 grad=0.0010141204801825839
iter: 35: x=0.00032451855365842687 y=1.0531229166855728e-07 grad=0.0008112963841460671
iter: 36: x=0.0002596148429267415 y=6.739986666787666e-08 grad=0.0006490371073168537
iter: 37: x=0.0002076918743413932 y=4.313591466744107e-08 grad=0.000519229685853483
iter: 38: x=0.00016615349947311455 y=2.7606985387162278e-08 grad=0.0004153837486827864
iter: 39: x=0.00013292279957849163 y=1.7668470647783856e-08 grad=0.0003323069989462291
iter: 40: x=0.00010633823966279331 y=1.130782121458167e-08 grad=0.00026584559915698326
iter: 41: x=8.507059173023465e-05 y=7.2370055773322674e-09 grad=0.00021267647932558662
iter: 42: x=6.805647338418771e-05 y=4.631683569492651e-09 grad=0.0001701411834604693
iter: 43: x=5.444517870735017e-05 y=2.9642774844752965e-09 grad=0.00013611294676837542
iter: 44: x=4.3556142965880136e-05 y=1.8971375900641898e-09 grad=0.00010889035741470034
iter: 45: x=3.4844914372704106e-05 y=1.2141680576410813e-09 grad=8.711228593176027e-05
iter: 46: x=2.7875931498163285e-05 y=7.77067556890292e-10 grad=6.968982874540821e-05
iter: 47: x=2.230074519853063e-05 y=4.973232364097869e-10 grad=5.575186299632657e-05
iter: 48: x=1.7840596158824504e-05 y=3.1828687130226364e-10 grad=4.460149039706126e-05
iter: 49: x=1.4272476927059604e-05 y=2.0370359763344877e-10 grad=3.568119231764901e-05
iter: 50: x=1.1417981541647683e-05 y=1.3037030248540718e-10 grad=2.8544953854119208e-05
iter: 51: x=9.134385233318147e-06 y=8.343699359066061e-11 grad=2.2835963083295365e-05
iter: 52: x=7.307508186654517e-06 y=5.3399675898022794e-11 grad=1.8268770466636293e-05
iter: 53: x=5.8460065493236135e-06 y=3.417579257473458e-11 grad=1.4615016373309035e-05
iter: 54: x=4.676805239458891e-06 y=2.1872507247830133e-11 grad=1.1692013098647227e-05
iter: 55: x=3.741444191567113e-06 y=1.3998404638611287e-11 grad=9.353610478917782e-06
iter: 56: x=2.9931553532536903e-06 y=8.958978968711224e-12 grad=7.482888383134226e-06
iter: 57: x=2.3945242826029522e-06 y=5.733746539975183e-12 grad=5.986310706507381e-06
iter: 58: x=1.915619426082362e-06 y=3.669597785584118e-12 grad=4.7890485652059045e-06
iter: 59: x=1.5324955408658897e-06 y=2.3485425827738356e-12 grad=3.831238852164724e-06
iter: 60: x=1.2259964326927118e-06 y=1.503067252975255e-12 grad=3.0649910817317793e-06
iter: 61: x=9.807971461541693e-07 y=9.61963041904163e-13 grad=2.4519928653854235e-06
iter: 62: x=7.846377169233355e-07 y=6.156563468186643e-13 grad=1.9615942923083387e-06
iter: 63: x=6.277101735386684e-07 y=3.9402006196394517e-13 grad=1.569275433846671e-06
iter: 64: x=5.021681388309347e-07 y=2.5217283965692494e-13 grad=1.2554203470773367e-06
iter: 65: x=4.0173451106474777e-07 y=1.6139061738043196e-13 grad=1.0043362776618695e-06
iter: 66: x=3.213876088517982e-07 y=1.0328999512347644e-13 grad=8.034690221294955e-07
min y:  1.0328999512347644e-13
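A quick side experiment (my own addition, reusing the y and grad functions defined above): for y = x^2 the update is x <- x - lr * 2x = (1 - 2*lr) * x, so the iteration only converges when 0 < lr < 1. With lr = 1.0 the iterate just flips sign forever, and with lr > 1 it diverges:

lr = 1.0  # deliberately too large
x = 1.
for i in range(4):
    x = x - lr * grad(x)  # equals (1 - 2*lr) * x = -x
    print('iter {}: x={} y={}'.format(i, x, y(x)))
# the output oscillates: x = -1.0, 1.0, -1.0, 1.0 while y stays at 1.0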

Problem 2: find the minimum of y = (5x_1 + 3x_2 - 1)^2 + (3x_1 + 4x_3 - 1)^2

Approach 1: derive the partial derivatives of y with respect to x_1, x_2, and x_3 by hand
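Applying the chain rule term by term gives the three formulas used in grad() below:

∂y/∂x_1 = 2(5x_1 + 3x_2 - 1)·5 + 2(3x_1 + 4x_3 - 1)·3 = 68x_1 + 30x_2 + 24x_3 - 16
∂y/∂x_2 = 2(5x_1 + 3x_2 - 1)·3 = 30x_1 + 18x_2 - 6
∂y/∂x_3 = 2(3x_1 + 4x_3 - 1)·4 = 24x_1 + 32x_3 - 8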

import numpy as np

def y(x):
    x1, x2, x3 = x[0], x[1], x[2]
    return (5 * x1 + 3 * x2 - 1) ** 2 + (3 * x1 + 4 * x3 - 1) ** 2

def grad(x):
    """
    Hand-derived partial derivative formulas
    """
    x1, x2, x3 = x[0], x[1], x[2]
    return np.array([68 * x1 + 30 * x2 + 24 * x3 - 16,
            30 * x1 + 18 * x2 - 6,
            24 * x1 + 32 * x3 - 8])

lr = 0.01  # learning rate (step size)
max_iter = 100  # maximum number of iterations
stop = 1e-6 # stop once every component of the gradient is no larger than this
x = np.array([1., 1., 1.])  # initial value
for i in range(max_iter):
    g = grad(x)
    x = x - lr * g
    print('iter {}: x={} y={} grad={} grad_norm={}'.format(i, x, y(x), g, np.linalg.norm(g)))
    if np.all(np.abs(g) <= stop):
        break
print('y_min: ', y([x[0], x[1], x[2]]))

Output:

iter 0: x=[-0.06  0.58  0.52] y=1.0035999999999998 grad=[106.  42.  48.] grad_norm=123.70933675353692
iter 1: x=[-0.158   0.5536  0.448 ] y=0.11781664 grad=[9.8  2.64 7.2 ] grad_norm=12.443857922686194
iter 2: x=[-0.16416   0.561352  0.42256 ] y=0.05780793913599991 grad=[ 0.616  -0.7752  2.544 ] grad_norm=2.729895060254149
iter 3: x=[-0.1623512   0.56955664  0.4067392 ] y=0.030199645260006416 grad=[-0.18088  -0.820464  1.58208 ] grad_norm=1.791327964415225
iter 4: x=[-0.16043678  0.5757418   0.39554694] y=0.015795032234661038 grad=[-0.1914416  -0.61851648  1.1192256 ] grad_norm=1.293011394357185
iter 5: x=[-0.15899358  0.58023932  0.38747675] y=0.008261296899959504 grad=[-0.14432051 -0.44975103  0.80701939] grad_norm=0.9350853979569257
iter 6: x=[-0.15794416  0.58349431  0.38164265] y=0.004320918811588607 grad=[-0.10494191 -0.32549969  0.58341011] grad_norm=0.6762619393258339
iter 7: x=[-0.15718466  0.58584858  0.3774236 ] y=0.0022599768204799413 grad=[-0.07594993 -0.23542718  0.42190493] grad_norm=0.489078847504133
iter 8: x=[-0.15663533  0.58755124  0.37437237] y=0.0011820391570467469 grad=[-0.05493301 -0.17026531  0.30512334] grad_norm=0.3537063196554449
iter 9: x=[-0.15623804  0.58878261  0.37216569] y=0.0006182437607913295 grad=[-0.03972857 -0.12313765  0.22066779] grad_norm=0.2558036627980692
iter 10: x=[-0.15595072  0.58967316  0.3705698 ] y=0.0003233609863757719 grad=[-0.02873212 -0.0890543   0.15958895] grad_norm=0.1849995611186115
iter 11: x=[-0.15574293  0.59031721  0.36941564] y=0.0001691279947185792 grad=[-0.02077934 -0.06440489  0.1154162 ] grad_norm=0.13379338372139668
iter 12: x=[-0.15559265  0.59078299  0.36858094] y=8.845927555493002e-05 grad=[-0.01502781 -0.04657821  0.08347006] grad_norm=0.09676060537324213
iter 13: x=[-0.15548397  0.59111985  0.36797727] y=4.626699113132352e-05 grad=[-0.01086825 -0.03368579  0.06036631] grad_norm=0.06997815954556025
iter 14: x=[-0.15540537  0.59136346  0.3675407 ] y=2.4199095628099958e-05 grad=[-0.00786002 -0.02436187  0.04365747] grad_norm=0.05060884845123345
iter 15: x=[-0.15534853  0.59153965  0.36722496] y=1.265689025585839e-05 grad=[-0.00568444 -0.01761873  0.03157348] grad_norm=0.036600784561823196
iter 16: x=[-0.15530741  0.59166707  0.36699662] y=6.61995280363997e-06 grad=[-0.00411104 -0.01274203  0.02283423] grad_norm=0.02647002394910958
iter 17: x=[-0.15527768  0.59175922  0.36683148] y=3.4624441104046973e-06 grad=[-0.00297314 -0.00921515  0.01651393] grad_norm=0.019143364718940906
iter 18: x=[-0.15525618  0.59182587  0.36671205] y=1.8109674756416336e-06 grad=[-0.0021502  -0.00666448  0.01194302] grad_norm=0.013844657393094468
iter 19: x=[-0.15524063  0.59187407  0.36662568] y=9.471931078902917e-07 grad=[-0.00155505 -0.00481981  0.00863731] grad_norm=0.010012583532012379
iter 20: x=[-0.15522938  0.59190892  0.36656321] y=4.954118700101854e-07 grad=[-0.00112462 -0.00348573  0.00624658] grad_norm=0.00724119247873167
iter 21: x=[-0.15522125  0.59193413  0.36651804] y=2.591160333647869e-07 grad=[-0.00081334 -0.00252092  0.00451758] grad_norm=0.005236896985319308
iter 22: x=[-0.15521537  0.59195236  0.36648537] y=1.3552585800024948e-07 grad=[-0.00058821 -0.00182315  0.00326716] grad_norm=0.003787372054451165
iter 23: x=[-0.15521112  0.59196555  0.36646174] y=7.088429823583231e-08 grad=[-0.0004254  -0.00131852  0.00236284] grad_norm=0.0027390622956763953
iter 24: x=[-0.15520804  0.59197508  0.36644465] y=3.7074723676619534e-08 grad=[-0.00030765 -0.00095356  0.00170883] grad_norm=0.0019809150386429397
iter 25: x=[-0.15520581  0.59198198  0.36643229] y=1.9391249824058642e-08 grad=[-0.0002225  -0.00068963  0.00123584] grad_norm=0.0014326159709912123
iter 26: x=[-0.1552042   0.59198697  0.36642335] y=1.0142235260318824e-08 grad=[-0.00016091 -0.00049874  0.00089377] grad_norm=0.001036081043507717
iter 27: x=[-0.15520304  0.59199058  0.36641689] y=5.304708928439957e-09 grad=[-0.00011637 -0.0003607   0.00064638] grad_norm=0.0007493033377066719
iter 28: x=[-0.1552022   0.59199318  0.36641221] y=2.774530080738553e-09 grad=[-8.41625077e-05 -2.60858998e-04  4.67469984e-04] grad_norm=0.0005419030638727381
iter 29: x=[-0.15520159  0.59199507  0.36640883] y=1.4511667412422119e-09 grad=[-6.08670994e-05 -1.88655626e-04  3.38078591e-04] grad_norm=0.00039190927873711517
iter 30: x=[-0.15520115  0.59199644  0.36640639] y=7.59005975631017e-10 grad=[-4.40196460e-05 -1.36437483e-04  2.44501546e-04] grad_norm=0.0002834323940940813
iter 31: x=[-0.15520083  0.59199742  0.36640462] y=3.9698406438671973e-10 grad=[-3.18354128e-05 -9.86728425e-05  1.76825766e-04] grad_norm=0.00020498091364388083
iter 32: x=[-0.1552006   0.59199814  0.36640334] y=2.0763518660267846e-10 grad=[-2.30236632e-05 -7.13611070e-05  1.27882020e-04] grad_norm=0.00014824408160084345
iter 33: x=[-0.15520044  0.59199865  0.36640242] y=1.0859975142801017e-10 grad=[-1.66509250e-05 -5.16090088e-05  9.24854529e-05] grad_norm=0.00010721148295828789
iter 34: x=[-0.15520031  0.59199902  0.36640175] y=5.680109524789638e-11 grad=[-1.20421020e-05 -3.73241097e-05  6.68863299e-05] grad_norm=7.753633031283164e-05
iter 35: x=[-0.15520023  0.59199929  0.36640126] y=2.970876432768251e-11 grad=[-8.70895893e-06 -2.69931393e-05  4.83728088e-05] grad_norm=5.607498705048494e-05
iter 36: x=[-0.15520016  0.59199949  0.36640091] y=1.5538620761544232e-11 grad=[-6.29839918e-06 -1.95216866e-05  3.49836602e-05] grad_norm=4.0553946258328485e-05
iter 37: x=[-0.15520012  0.59199963  0.36640066] y=8.12718874970174e-12 grad=[-4.55506020e-06 -1.41182632e-05  2.53005047e-05] grad_norm=2.9328986837998224e-05
iter 38: x=[-0.15520009  0.59199973  0.36640048] y=4.25077604802532e-12 grad=[-3.29426143e-06 -1.02104578e-05  1.82975577e-05] grad_norm=2.121099297121382e-05
iter 39: x=[-0.15520006  0.59199981  0.36640035] y=2.223289943460344e-12 grad=[-2.38244015e-06 -7.38429697e-06  1.32329619e-05] grad_norm=1.5339985156425394e-05
iter 40: x=[-0.15520005  0.59199986  0.36640025] y=1.162850763666433e-12 grad=[-1.72300263e-06 -5.34039147e-06  9.57019976e-06] grad_norm=1.109401832157149e-05
iter 41: x=[-0.15520003  0.5919999   0.36640018] y=6.082076260088465e-13 grad=[-1.24609135e-06 -3.86222022e-06  6.92125647e-06] grad_norm=8.02329606364359e-06
iter 42: x=[-0.15520002  0.59199993  0.36640013] y=3.181117709801441e-13 grad=[-9.01184718e-07 -2.79319318e-06  5.00551632e-06] grad_norm=5.80252149114771e-06
iter 43: x=[-0.15520002  0.59199995  0.36640009] y=1.663824896935061e-13 grad=[-6.51745072e-07 -2.02006299e-06  3.62003543e-06] grad_norm=4.196436895314472e-06
iter 44: x=[-0.15520001  0.59199996  0.36640007] y=8.702328981561944e-14 grad=[-4.71348033e-07 -1.46092813e-06  2.61804291e-06] grad_norm=3.0349017515271214e-06
iter 45: x=[-0.15520001  0.59199997  0.36640005] y=4.551592522249042e-14 grad=[-3.40883227e-07 -1.05655666e-06  1.89339271e-06] grad_norm=2.1948688531434195e-06
iter 46: x=[-0.15520001  0.59199998  0.36640004] y=2.3806264492661115e-14 grad=[-2.46529886e-07 -7.64111489e-07  1.36931901e-06] grad_norm=1.587349335666307e-06
iter 47: x=[-0.1552      0.59199999  0.36640003] y=1.2451427244674976e-14 grad=[-1.78292686e-07 -5.52612456e-07  9.90304102e-07] grad_norm=1.147985636932428e-06
y_min:  1.2451427244674976e-14
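A quick sanity check on the result: y is a sum of two squares, so its minimum is 0, attained whenever 5x_1 + 3x_2 = 1 and 3x_1 + 4x_3 = 1. The point found above satisfies both: 5·(-0.1552) + 3·0.592 = -0.776 + 1.776 = 1 and 3·(-0.1552) + 4·0.3664 = -0.4656 + 1.4656 = 1. Note that with only two equations in three unknowns the minimizer is not unique; which point on that line gradient descent converges to depends on the initial value.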

Approach 2: a partial derivative is really just the slope along one variable's direction, so a more general numerical implementation can avoid deriving the formulas by hand

import numpy as np

def y(x):
    x1, x2, x3 = x[0], x[1], x[2]
    return (5 * x1 + 3 * x2 - 1) ** 2 + (3 * x1 + 4 * x3 - 1) ** 2

def grad(x):
    """
    A more general approach: approximate each partial derivative
    with a forward finite difference
    """
    delta = 1e-8
    d_x0 = (y([(x[0] + delta), x[1], x[2]]) - y(x)) / delta
    d_x1 = (y([x[0], (x[1] + delta), x[2]]) - y(x)) / delta
    d_x2 = (y([x[0], x[1], (x[2] + delta)]) - y(x)) / delta
    return np.array([d_x0, d_x1, d_x2])

lr = 0.01  # learning rate (step size)
max_iter = 100  # maximum number of iterations
stop = 1e-6 # stop once every component of the gradient is no larger than this
x = np.array([1., 1., 1.])  # initial value
for i in range(max_iter):
    g = grad(x)
    x = x - lr * g
    print('iter {}: x={} y={} grad={} grad_norm={}'.format(i, x, y(x), g, np.linalg.norm(g)))
    if np.all(np.abs(g) <= stop):
        break
print('y_min: ', y([x[0], x[1], x[2]]))

Output:

iter 0: x=[-0.06000002  0.58000001  0.52      ] y=1.0035998786360978 grad=[106.00000167  41.99999921  47.99999971] grad_norm=123.70933779945847
iter 1: x=[-0.15800001  0.55360001  0.448     ] y=0.11781663499253381 grad=[9.79999952 2.63999975 7.19999989] grad_norm=12.443857428014446
iter 2: x=[-0.16416001  0.56135201  0.42256   ] y=0.057807938341557265 grad=[ 0.61599995 -0.77520009  2.544     ] grad_norm=2.729895076249535
iter 3: x=[-0.16235121  0.56955665  0.4067392 ] y=0.030199644492460674 grad=[-0.18087999 -0.82046405  1.58208002] grad_norm=1.791328002370116
iter 4: x=[-0.1604368   0.57574182  0.39554695] y=0.015795031442206904 grad=[-0.19144159 -0.61851653  1.11922561] grad_norm=1.2930114228260368
iter 5: x=[-0.15899359  0.58023933  0.38747675] y=0.008261296195812924 grad=[-0.1443205  -0.44975107  0.8070194 ] grad_norm=0.9350854177457265
iter 6: x=[-0.15794417  0.58349433  0.38164265] y=0.004320918234695532 grad=[-0.10494189 -0.32549973  0.58341011] grad_norm=0.6762619532127319
iter 7: x=[-0.15718467  0.5858486   0.3774236 ] y=0.002259976368159793 grad=[-0.07594991 -0.23542721  0.42190493] grad_norm=0.48907885743671226
iter 8: x=[-0.15663534  0.58755125  0.37437237] y=0.0011820388121374136 grad=[-0.05493299 -0.17026534  0.30512333] grad_norm=0.35370632660999723
iter 9: x=[-0.15623806  0.58878263  0.37216569] y=0.0006182435017803573 grad=[-0.03972856 -0.12313768  0.22066778] grad_norm=0.2558036679709723
iter 10: x=[-0.15595074  0.58967317  0.3705698 ] y=0.0003233607945279679 grad=[-0.0287321  -0.08905433  0.15958895] grad_norm=0.18499956450364702
iter 11: x=[-0.15574294  0.59031722  0.36941564] y=0.0001691278538960136 grad=[-0.02077932 -0.06440492  0.11541619] grad_norm=0.13379338586763737
iter 12: x=[-0.15559266  0.590783    0.36858094] y=8.84591723127018e-05 grad=[-0.01502779 -0.04657824  0.08347004] grad_norm=0.09676060737142307
iter 13: x=[-0.15548398  0.59111986  0.36797728] y=4.626691594568733e-05 grad=[-0.01086823 -0.03368582  0.0603663 ] grad_norm=0.06997816057333908
iter 14: x=[-0.15540538  0.59136348  0.3675407 ] y=2.4199040946547724e-05 grad=[-0.00786    -0.0243619   0.04365746] grad_norm=0.0506088492895908
iter 15: x=[-0.15534854  0.59153967  0.36722497] y=1.2656850584237097e-05 grad=[-0.00568442 -0.01761876  0.03157347] grad_norm=0.03660078503652166
iter 16: x=[-0.15530743  0.59166709  0.36699663] y=6.619924005703637e-06 grad=[-0.00411102 -0.01274205  0.02283422] grad_norm=0.02647002450959095
iter 17: x=[-0.1552777   0.59175924  0.36683149] y=3.4624232532766325e-06 grad=[-0.00297312 -0.00921518  0.01651392] grad_norm=0.019143364938747714
iter 18: x=[-0.1552562   0.59182588  0.36671206] y=1.810952377305285e-06 grad=[-0.00215019 -0.00666451  0.01194301] grad_norm=0.01384465753797714
iter 19: x=[-0.15524064  0.59187408  0.36662568] y=9.4718217360823e-07 grad=[-0.00155503 -0.00481984  0.00863729] grad_norm=0.010012583742810544
iter 20: x=[-0.1552294   0.59190894  0.36656322] y=4.954039568822264e-07 grad=[-0.00112461 -0.00348576  0.00624657] grad_norm=0.00724119258550554
iter 21: x=[-0.15522127  0.59193415  0.36651804] y=2.591103071282859e-07 grad=[-0.00081332 -0.00252094  0.00451757] grad_norm=0.005236897081079024
iter 22: x=[-0.15521538  0.59195238  0.36648537] y=1.3552171547618412e-07 grad=[-0.0005882  -0.00182318  0.00326715] grad_norm=0.0037873721090173297
iter 23: x=[-0.15521113  0.59196557  0.36646174] y=7.08813020998077e-08 grad=[-0.00042539 -0.00131855  0.00236283] grad_norm=0.002739062318989531
iter 24: x=[-0.15520805  0.5919751   0.36644465] y=3.707255678388081e-08 grad=[-0.00030764 -0.00095359  0.00170881] grad_norm=0.0019809150593567753
iter 25: x=[-0.15520583  0.591982    0.3664323 ] y=1.9389682790598714e-08 grad=[-0.00022248 -0.00068965  0.00123583] grad_norm=0.0014326159858460016
iter 26: x=[-0.15520422  0.59198699  0.36642336] y=1.0141102134263334e-08 grad=[-0.0001609  -0.00049877  0.00089376] grad_norm=0.0010360810528812455
iter 27: x=[-0.15520306  0.59199059  0.36641689] y=5.3038896407188295e-09 grad=[-0.00011636 -0.00036072  0.00064637] grad_norm=0.0007493033446737724
iter 28: x=[-0.15520221  0.5919932   0.36641222] y=2.7739377788545485e-09 grad=[-8.41463763e-05 -2.60885881e-04  4.67457893e-04] grad_norm=0.0005419030696615658
iter 29: x=[-0.15520161  0.59199509  0.36640884] y=1.4507386086007354e-09 grad=[-6.08509715e-05 -1.88682507e-04  3.38066497e-04] grad_norm=0.0003919092828529942
iter 30: x=[-0.15520117  0.59199645  0.36640639] y=7.586965719153159e-10 grad=[-4.40035182e-05 -1.36464365e-04  2.44489453e-04] grad_norm=0.00028343239969230184
iter 31: x=[-0.15520085  0.59199744  0.36640463] y=3.9676053172038427e-10 grad=[-3.18192859e-05 -9.86997238e-05  1.76813671e-04] grad_norm=0.00020498091786609265
iter 32: x=[-0.15520062  0.59199816  0.36640335] y=2.0747375812448486e-10 grad=[-2.30075351e-05 -7.13879870e-05  1.27869925e-04] grad_norm=0.00014824408608517318
iter 33: x=[-0.15520045  0.59199867  0.36640242] y=1.0848323754807052e-10 grad=[-1.66347964e-05 -5.16358892e-05  9.24733574e-05] grad_norm=0.00010721148879348808
iter 34: x=[-0.15520033  0.59199905  0.36640175] y=5.671706412707642e-11 grad=[-1.20259734e-05 -3.73509901e-05  6.68742345e-05] grad_norm=7.753633819927338e-05
iter 35: x=[-0.15520024  0.59199932  0.36640127] y=2.9648225375971145e-11 grad=[-8.69283154e-06 -2.70200198e-05  4.83607125e-05] grad_norm=5.607499709062693e-05
iter 36: x=[-0.15520018  0.59199951  0.36640092] y=1.549507150064275e-11 grad=[-6.28227063e-06 -1.95485665e-05  3.49715645e-05] grad_norm=4.0553960380623134e-05
iter 37: x=[-0.15520014  0.59199965  0.36640067] y=8.095926604511534e-12 grad=[-4.53893240e-06 -1.41451435e-05  2.52884088e-05] grad_norm=2.9329006278240914e-05
iter 38: x=[-0.1552001   0.59199976  0.36640049] y=4.228400087587236e-12 grad=[-3.27813344e-06 -1.02373379e-05  1.82854617e-05] grad_norm=2.1211019675900683e-05
iter 39: x=[-0.15520008  0.59199983  0.36640035] y=2.207340577392433e-12 grad=[-2.36631222e-06 -7.41117695e-06  1.32208658e-05] grad_norm=1.534002184680652e-05
iter 40: x=[-0.15520006  0.59199988  0.36640026] y=1.1515491611298115e-12 grad=[-1.70687453e-06 -5.36727143e-06  9.55810376e-06] grad_norm=1.10940691735635e-05
iter 41: x=[-0.15520005  0.59199992  0.36640019] y=6.002673272742683e-13 grad=[-1.22996334e-06 -3.88910022e-06  6.90916045e-06] grad_norm=8.023366399927621e-06
iter 42: x=[-0.15520004  0.59199995  0.36640014] y=3.126023968333678e-13 grad=[-8.85056713e-07 -2.82007319e-06  4.99342034e-06] grad_norm=5.80261879145478e-06
iter 43: x=[-0.15520003  0.59199997  0.3664001 ] y=1.6263118319242774e-13 grad=[-6.35617061e-07 -2.04694300e-06  3.60793941e-06] grad_norm=4.196571396136328e-06
iter 44: x=[-0.15520003  0.59199999  0.36640008] y=8.454343398301524e-14 grad=[-4.55220043e-07 -1.48780812e-06  2.60594690e-06] grad_norm=3.0350877292318492e-06
iter 45: x=[-0.15520003  0.592       0.36640006] y=4.395559434501895e-14 grad=[-3.24755215e-07 -1.08343666e-06  1.88129670e-06] grad_norm=2.1951260110189777e-06
iter 46: x=[-0.15520002  0.592       0.36640004] y=2.2910942287870438e-14 grad=[-2.30401882e-07 -7.90991486e-07  1.35722301e-06] grad_norm=1.5877049006594887e-06
iter 47: x=[-0.15520002  0.59200001  0.36640003] y=1.2037045292228435e-14 grad=[-1.62164681e-07 -5.79492454e-07  9.78208098e-07] grad_norm=1.148477240177836e-06
y_min:  1.2037045292228435e-14
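As a side note (my own refinement, not in the original post): the forward difference above has truncation error on the order of delta, while a central difference, which nudges each coordinate both up and down and divides by 2 * delta, has error on the order of delta^2 and is usually a more accurate drop-in replacement for grad():

def grad_central(x, delta=1e-6):
    """Central-difference gradient: perturb each coordinate in both directions."""
    g = np.zeros(len(x))
    for i in range(len(x)):
        xp = list(x); xp[i] += delta  # x with the i-th coordinate nudged up
        xm = list(x); xm[i] -= delta  # x with the i-th coordinate nudged down
        g[i] = (y(xp) - y(xm)) / (2 * delta)
    return g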

Approach 3: implement it with tensorflow-2.0.0

import tensorflow as tf
import numpy as np

# use CPU only (hide any GPUs)
tf.config.experimental.set_visible_devices([], device_type='GPU')

def f(x):
    return (5 * x[0] + 3 * x[1] - 1) ** 2 + (3 * x[0] + 4 * x[2] - 1) ** 2

lr = 0.01  # learning rate (step size)
max_iter = 100  # maximum number of iterations
stop = 1e-6 # stop once every component of the gradient is no larger than this
x = tf.constant([1., 1., 1.]) # initial value
for i in range(max_iter):
    with tf.GradientTape() as tape:
        tape.watch([x])
        y = f(x)
    g = tape.gradient(y, [x])[0]
    x = x - lr * g
    print('iter {}: x={} y={} grad={} grad_norm={}'.format(i, x.numpy(), f(x.numpy()), g.numpy(), np.linalg.norm(g.numpy())))
    if np.all(np.abs(g.numpy()) <= stop):
        break
print('y_min: ', f(x.numpy()))

Output:

iter 0: x=[-0.05999994  0.58000004  0.52      ] y=1.0036005367280865 grad=[106.  42.  48.] grad_norm=123.70933532714844
iter 1: x=[-0.15799999  0.5536      0.44799998] y=0.11781659378563702 grad=[9.800005  2.6400025 7.200001 ] grad_norm=12.443862915039062
iter 2: x=[-0.16416     0.561352    0.42255998] y=0.057807889841356985 grad=[ 0.61599994 -0.7751999   2.5439997 ] grad_norm=2.7298946380615234
iter 3: x=[-0.16235119  0.56955665  0.40673918] y=0.030199607572857445 grad=[-0.18088043 -0.8204638   1.5820789 ] grad_norm=1.7913269996643066
iter 4: x=[-0.16043676  0.5757418   0.39554694] y=0.015795018348294665 grad=[-0.19144213 -0.6185163   1.1192245 ] grad_norm=1.2930104732513428
iter 5: x=[-0.15899359  0.5802393   0.38747674] y=0.008261299425506241 grad=[-0.14431751 -0.4497496   0.8070202 ] grad_norm=0.9350849390029907
iter 6: x=[-0.15794416  0.5834943   0.38164264] y=0.004320916231244087 grad=[-0.10494292 -0.32550037  0.58341026] grad_norm=0.6762625575065613
iter 7: x=[-0.15718466  0.58584857  0.37742358] y=0.002259974556757527 grad=[-0.07595015 -0.23542714  0.42190456] grad_norm=0.489078551530838
iter 8: x=[-0.15663531  0.58755124  0.37437236] y=0.001182037898752597 grad=[-0.05493414 -0.17026556  0.30512238] grad_norm=0.3537057936191559
iter 9: x=[-0.15623803  0.5887826   0.37216568] y=0.0006182425492107235 grad=[-0.03972745 -0.123137    0.22066784] grad_norm=0.2558031976222992
iter 10: x=[-0.15595073  0.58967316  0.3705698 ] y=0.00032336028946389206 grad=[-0.02873111 -0.08905363  0.15958881] grad_norm=0.18499895930290222
iter 11: x=[-0.15574293  0.5903172   0.36941564] y=0.00016912902688170917 grad=[-0.02077973 -0.06440485  0.11541557] grad_norm=0.13379287719726562
iter 12: x=[-0.15559264  0.590783    0.36858094] y=8.84587340199694e-05 grad=[-0.01502872 -0.04657888  0.08347034] grad_norm=0.0967613235116005
iter 13: x=[-0.15548398  0.5911198   0.36797726] y=4.62670071059712e-05 grad=[-0.01086664 -0.03368497  0.06036663] grad_norm=0.0699777901172638
iter 14: x=[-0.15540537  0.59136343  0.3675407 ] y=2.419935945319196e-05 grad=[-0.0078603  -0.02436197  0.0436573 ] grad_norm=0.0506087951362133
iter 15: x=[-0.15534851  0.5915396   0.36722496] y=1.265716039178244e-05 grad=[-0.00568604 -0.01761961  0.0315733 ] grad_norm=0.036601293832063675
iter 16: x=[-0.1553074   0.59166706  0.36699662] y=6.619958903275602e-06 grad=[-0.00411105 -0.01274228  0.02283478] grad_norm=0.026470616459846497
iter 17: x=[-0.15527767  0.5917592   0.36683148] y=3.462529583941887e-06 grad=[-0.00297296 -0.009215    0.01651382] grad_norm=0.019143173471093178
iter 18: x=[-0.15525618  0.59182584  0.36671203] y=1.8109340462757473e-06 grad=[-0.00214946 -0.0066644   0.01194382] grad_norm=0.01384518388658762
iter 19: x=[-0.15524063  0.59187406  0.36662567] y=9.471220696610771e-07 grad=[-0.00155616 -0.00482011  0.00863647] grad_norm=0.010012180544435978
iter 20: x=[-0.15522937  0.59190893  0.3665632 ] y=4.953290013709477e-07 grad=[-0.00112534 -0.00348616  0.00624657] grad_norm=0.007241495884954929
iter 21: x=[-0.15522125  0.59193414  0.36651802] y=2.590420553616468e-07 grad=[-0.00081277 -0.00252056  0.00451756] grad_norm=0.005236614029854536
iter 22: x=[-0.15521537  0.5919524   0.36648536] y=1.35480251461928e-07 grad=[-0.00058889 -0.00182319  0.00326633] grad_norm=0.0037867859937250614
iter 23: x=[-0.15521112  0.59196556  0.36646172] y=7.085719921917644e-08 grad=[-0.00042403 -0.00131786  0.0023632 ] grad_norm=0.0027388480957597494
iter 24: x=[-0.15520804  0.5919751   0.36644465] y=3.706684070792221e-08 grad=[-0.00030804 -0.00095344  0.00170803] grad_norm=0.001980226021260023
iter 25: x=[-0.15520582  0.591982    0.36643228] y=1.9368842529843278e-08 grad=[-0.00022221 -0.00068951  0.00123596] grad_norm=0.0014326188247650862
iter 26: x=[-0.1552042   0.591987    0.36642334] y=1.0117107152041172e-08 grad=[-0.00016069 -0.00049853  0.00089359] grad_norm=0.0010357925202697515
iter 27: x=[-0.15520306  0.5919906   0.36641687] y=5.294389371357511e-09 grad=[-0.00011539 -0.00035977  0.00064564] grad_norm=0.0007480646600015461
iter 28: x=[-0.15520221  0.5919932   0.36641222] y=2.772573282072699e-09 grad=[-8.5353851e-05 -2.6106834e-04  4.6634674e-04] grad_norm=0.0005412219907157123
iter 29: x=[-0.1552016   0.5919951   0.36640885] y=1.4500369793779555e-09 grad=[-6.0915947e-05 -1.8846989e-04  3.3760071e-04] grad_norm=0.0003914152330253273
iter 30: x=[-0.15520117  0.5919965   0.3664064 ] y=7.580158722930719e-10 grad=[-4.2676926e-05 -1.3589859e-04  2.4509430e-04] grad_norm=0.0002834800980053842
iter 31: x=[-0.15520087  0.59199744  0.36640462] y=3.965698880392665e-10 grad=[-3.0279160e-05 -9.7990036e-05  1.7738342e-04] grad_norm=0.00020489937742240727
iter 32: x=[-0.15520063  0.59199816  0.36640334] y=2.0686208301867737e-10 grad=[-2.4080276e-05 -7.1525574e-05  1.2683868e-04] grad_norm=0.00014759342593606561
iter 33: x=[-0.15520045  0.5919987   0.36640242] y=1.0719425347360811e-10 grad=[-1.7166138e-05 -5.1498413e-05  9.1552734e-05] grad_norm=0.00010643620771588758
iter 34: x=[-0.15520033  0.59199905  0.36640176] y=5.6852300645005016e-11 grad=[-1.1920929e-05 -3.7193298e-05  6.6757202e-05] grad_norm=7.734322571195662e-05
iter 35: x=[-0.15520024  0.59199935  0.36640128] y=2.943423282886215e-11 grad=[-8.8214874e-06 -2.7179718e-05  4.8637390e-05] grad_norm=5.641056122840382e-05
iter 36: x=[-0.1552002   0.59199953  0.36640093] y=1.545474859199203e-11 grad=[-4.529953e-06 -1.859665e-05  3.528595e-05] grad_norm=4.014292062493041e-05
iter 37: x=[-0.15520014  0.59199965  0.3664007 ] y=8.51274606361585e-12 grad=[-5.2452087e-06 -1.4305115e-05  2.4795532e-05] grad_norm=2.9102697226335295e-05
iter 38: x=[-0.15520011  0.59199977  0.3664005 ] y=4.46620518346208e-12 grad=[-2.9802322e-06 -1.0371208e-05  1.9073486e-05] grad_norm=2.1914416720392182e-05
iter 39: x=[-0.15520008  0.5919998   0.3664004 ] y=2.5850432905372145e-12 grad=[-2.5033951e-06 -7.5101852e-06  1.3351440e-05] grad_norm=1.5521945897489786e-05
iter 40: x=[-0.15520008  0.5919999   0.36640027] y=1.2545520178264269e-12 grad=[-4.76837158e-07 -5.00679016e-06  1.04904175e-05] grad_norm=1.1633751455519814e-05
iter 41: x=[-0.15520006  0.59199995  0.3664002 ] y=6.572520305780927e-13 grad=[-2.1457672e-06 -4.2915344e-06  6.6757202e-06] grad_norm=8.221120879170485e-06
iter 42: x=[-0.15520005  0.592       0.36640015] y=2.6334490144108713e-13 grad=[-1.0728836e-06 -3.2186508e-06  5.7220459e-06] grad_norm=6.6522629822429735e-06
iter 43: x=[-0.15520005  0.592       0.36640012] y=1.674216321134736e-13 grad=[ 4.7683716e-07 -1.4305115e-06  3.8146973e-06] grad_norm=4.1019084164872766e-06
iter 44: x=[-0.15520005  0.592       0.3664001 ] y=9.992007221626409e-14 grad=[-2.3841858e-07 -1.4305115e-06  2.8610229e-06] grad_norm=3.207593863407965e-06
iter 45: x=[-0.15520003  0.592       0.36640006] y=4.440892098500626e-14 grad=[-9.5367432e-07 -1.4305115e-06  1.9073486e-06] grad_norm=2.5678466499812203e-06
iter 46: x=[-0.15520003  0.592       0.36640006] y=4.440892098500626e-14 grad=[-4.7683716e-07 -7.1525574e-07  9.5367432e-07] grad_norm=1.2839233249906101e-06
y_min:  4.440892098500626e-14
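A small variant worth knowing (my own addition, not from the original post): if x is created as a tf.Variable instead of a tf.constant, GradientTape tracks it automatically and tape.watch is unnecessary; the manual update can also be delegated to a built-in optimizer such as tf.keras.optimizers.SGD, which performs the same x <- x - lr * g step:

import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.01)  # plain gradient descent
x = tf.Variable([1., 1., 1.])  # Variables are watched by the tape automatically
for i in range(100):
    with tf.GradientTape() as tape:
        loss = f(x)  # f as defined above
    g = tape.gradient(loss, x)
    opt.apply_gradients([(g, x)])  # in-place update: x <- x - lr * g
print('y_min: ', f(x.numpy()))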
