Gradient descent is arguably the cornerstone of modern machine learning. At bottom, every machine learning problem is an optimization problem, and optimization is conventionally cast as minimization. Gradient descent is the workhorse for finding those minima.
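The core iteration is easy to state: starting from an initial point $x_0$, repeat the update

$$x_{k+1} = x_k - \eta \, \nabla f(x_k)$$

where $\nabla f$ is the gradient of the objective $f$ and the learning rate $\eta$ controls the step size. Each step moves downhill along the negative gradient, and the iteration stops once the gradient is close to zero.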
To understand gradient descent you need two mathematical concepts: the derivative and the partial derivative. There are plenty of articles online explaining gradient descent, and most of them cover the concepts well.
For a gentle introduction, see the Jianshu article 《深入浅出--梯度下降法及其实现》 (a plain-language guide to gradient descent and its implementation).
For a deeper treatment, see the Zhihu article 《理解梯度下降法》 (understanding gradient descent).
Think you can write the code after reading those two articles? Test yourself with the two problems below: implement gradient descent in code to find each minimum. Problem 1 is the most basic case, a function of one variable; Problem 2 is a function of three variables.
Problem 1: find the minimum of $y = x^2$.
The Python code:

def y(x):
    return x ** 2

def grad(x):
    """Compute the gradient (here simply the derivative, 2x)."""
    return 2 * x

lr = 0.1  # learning rate (step size)
max_iter = 100  # maximum number of iterations
stop = 1e-6  # stop once the absolute gradient is at most this value
x = 1.  # initial value

for i in range(max_iter):
    g = grad(x)
    x = x - lr * g
    print('iter: {}: x={} y={} grad={}'.format(i, x, y(x), g))
    if abs(g) <= stop:
        break

print('min y: ', y(x))
Running it produces:
iter: 0: x=0.8 y=0.6400000000000001 grad=2.0
iter: 1: x=0.64 y=0.4096 grad=1.6
iter: 2: x=0.512 y=0.262144 grad=1.28
iter: 3: x=0.4096 y=0.16777216 grad=1.024
iter: 4: x=0.32768 y=0.10737418240000002 grad=0.8192
iter: 5: x=0.26214400000000004 y=0.06871947673600003 grad=0.65536
iter: 6: x=0.20971520000000005 y=0.04398046511104002 grad=0.5242880000000001
iter: 7: x=0.16777216000000003 y=0.02814749767106561 grad=0.4194304000000001
iter: 8: x=0.13421772800000004 y=0.018014398509481992 grad=0.33554432000000006
iter: 9: x=0.10737418240000003 y=0.011529215046068476 grad=0.26843545600000007
iter: 10: x=0.08589934592000002 y=0.0073786976294838245 grad=0.21474836480000006
iter: 11: x=0.06871947673600001 y=0.004722366482869647 grad=0.17179869184000005
iter: 12: x=0.05497558138880001 y=0.003022314549036574 grad=0.13743895347200002
iter: 13: x=0.04398046511104001 y=0.0019342813113834073 grad=0.10995116277760002
iter: 14: x=0.035184372088832 y=0.0012379400392853804 grad=0.08796093022208001
iter: 15: x=0.028147497671065603 y=0.0007922816251426435 grad=0.070368744177664
iter: 16: x=0.02251799813685248 y=0.0005070602400912918 grad=0.056294995342131206
iter: 17: x=0.018014398509481985 y=0.00032451855365842676 grad=0.04503599627370496
iter: 18: x=0.014411518807585589 y=0.00020769187434139315 grad=0.03602879701896397
iter: 19: x=0.01152921504606847 y=0.0001329227995784916 grad=0.028823037615171177
iter: 20: x=0.009223372036854777 y=8.507059173023463e-05 grad=0.02305843009213694
iter: 21: x=0.007378697629483821 y=5.444517870735016e-05 grad=0.018446744073709553
iter: 22: x=0.005902958103587057 y=3.4844914372704106e-05 grad=0.014757395258967642
iter: 23: x=0.004722366482869646 y=2.230074519853063e-05 grad=0.011805916207174114
iter: 24: x=0.0037778931862957168 y=1.4272476927059604e-05 grad=0.009444732965739291
iter: 25: x=0.0030223145490365735 y=9.134385233318147e-06 grad=0.0075557863725914335
iter: 26: x=0.002417851639229259 y=5.846006549323614e-06 grad=0.006044629098073147
iter: 27: x=0.001934281311383407 y=3.741444191567113e-06 grad=0.004835703278458518
iter: 28: x=0.0015474250491067257 y=2.3945242826029522e-06 grad=0.003868562622766814
iter: 29: x=0.0012379400392853806 y=1.5324955408658897e-06 grad=0.0030948500982134514
iter: 30: x=0.0009903520314283045 y=9.807971461541695e-07 grad=0.002475880078570761
iter: 31: x=0.0007922816251426436 y=6.277101735386685e-07 grad=0.001980704062856609
iter: 32: x=0.000633825300114115 y=4.017345110647479e-07 grad=0.0015845632502852873
iter: 33: x=0.0005070602400912919 y=2.571100870814386e-07 grad=0.00126765060022823
iter: 34: x=0.0004056481920730336 y=1.6455045573212074e-07 grad=0.0010141204801825839
iter: 35: x=0.00032451855365842687 y=1.0531229166855728e-07 grad=0.0008112963841460671
iter: 36: x=0.0002596148429267415 y=6.739986666787666e-08 grad=0.0006490371073168537
iter: 37: x=0.0002076918743413932 y=4.313591466744107e-08 grad=0.000519229685853483
iter: 38: x=0.00016615349947311455 y=2.7606985387162278e-08 grad=0.0004153837486827864
iter: 39: x=0.00013292279957849163 y=1.7668470647783856e-08 grad=0.0003323069989462291
iter: 40: x=0.00010633823966279331 y=1.130782121458167e-08 grad=0.00026584559915698326
iter: 41: x=8.507059173023465e-05 y=7.2370055773322674e-09 grad=0.00021267647932558662
iter: 42: x=6.805647338418771e-05 y=4.631683569492651e-09 grad=0.0001701411834604693
iter: 43: x=5.444517870735017e-05 y=2.9642774844752965e-09 grad=0.00013611294676837542
iter: 44: x=4.3556142965880136e-05 y=1.8971375900641898e-09 grad=0.00010889035741470034
iter: 45: x=3.4844914372704106e-05 y=1.2141680576410813e-09 grad=8.711228593176027e-05
iter: 46: x=2.7875931498163285e-05 y=7.77067556890292e-10 grad=6.968982874540821e-05
iter: 47: x=2.230074519853063e-05 y=4.973232364097869e-10 grad=5.575186299632657e-05
iter: 48: x=1.7840596158824504e-05 y=3.1828687130226364e-10 grad=4.460149039706126e-05
iter: 49: x=1.4272476927059604e-05 y=2.0370359763344877e-10 grad=3.568119231764901e-05
iter: 50: x=1.1417981541647683e-05 y=1.3037030248540718e-10 grad=2.8544953854119208e-05
iter: 51: x=9.134385233318147e-06 y=8.343699359066061e-11 grad=2.2835963083295365e-05
iter: 52: x=7.307508186654517e-06 y=5.3399675898022794e-11 grad=1.8268770466636293e-05
iter: 53: x=5.8460065493236135e-06 y=3.417579257473458e-11 grad=1.4615016373309035e-05
iter: 54: x=4.676805239458891e-06 y=2.1872507247830133e-11 grad=1.1692013098647227e-05
iter: 55: x=3.741444191567113e-06 y=1.3998404638611287e-11 grad=9.353610478917782e-06
iter: 56: x=2.9931553532536903e-06 y=8.958978968711224e-12 grad=7.482888383134226e-06
iter: 57: x=2.3945242826029522e-06 y=5.733746539975183e-12 grad=5.986310706507381e-06
iter: 58: x=1.915619426082362e-06 y=3.669597785584118e-12 grad=4.7890485652059045e-06
iter: 59: x=1.5324955408658897e-06 y=2.3485425827738356e-12 grad=3.831238852164724e-06
iter: 60: x=1.2259964326927118e-06 y=1.503067252975255e-12 grad=3.0649910817317793e-06
iter: 61: x=9.807971461541693e-07 y=9.61963041904163e-13 grad=2.4519928653854235e-06
iter: 62: x=7.846377169233355e-07 y=6.156563468186643e-13 grad=1.9615942923083387e-06
iter: 63: x=6.277101735386684e-07 y=3.9402006196394517e-13 grad=1.569275433846671e-06
iter: 64: x=5.021681388309347e-07 y=2.5217283965692494e-13 grad=1.2554203470773367e-06
iter: 65: x=4.0173451106474777e-07 y=1.6139061738043196e-13 grad=1.0043362776618695e-06
iter: 66: x=3.213876088517982e-07 y=1.0328999512347644e-13 grad=8.034690221294955e-07
min y: 1.0328999512347644e-13
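A detail worth noticing in this log: with $y = x^2$ and $\eta = 0.1$, each update is $x_{k+1} = x_k - 0.1 \cdot 2x_k = 0.8\,x_k$, so the iterates form the geometric sequence $0.8, 0.64, 0.512, \dots$ seen above. Both $x$ and the gradient shrink by a factor of 0.8 per step until the gradient drops below the threshold at iteration 66.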
Problem 2: find the minimum of $y = (5x_1 + 3x_2 - 1)^2 + (3x_1 + 4x_3 - 1)^2$.
Approach 1: derive the partial derivatives of $y$ with respect to $x_1$, $x_2$, and $x_3$ by hand:

$$\frac{\partial y}{\partial x_1} = 10(5x_1 + 3x_2 - 1) + 6(3x_1 + 4x_3 - 1) = 68x_1 + 30x_2 + 24x_3 - 16$$
$$\frac{\partial y}{\partial x_2} = 6(5x_1 + 3x_2 - 1) = 30x_1 + 18x_2 - 6$$
$$\frac{\partial y}{\partial x_3} = 8(3x_1 + 4x_3 - 1) = 24x_1 + 32x_3 - 8$$
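Hand derivation is easy to get wrong, so it can be worth checking the formulas symbolically before trusting them. A minimal sketch using sympy (an extra dependency, not part of the original solution):

import sympy as sp

x1, x2, x3 = sp.symbols('x1 x2 x3')
expr = (5 * x1 + 3 * x2 - 1) ** 2 + (3 * x1 + 4 * x3 - 1) ** 2
for v in (x1, x2, x3):
    # expand() normalizes each derivative into the linear form used in grad() below
    print(sp.expand(sp.diff(expr, v)))
# 68*x1 + 30*x2 + 24*x3 - 16
# 30*x1 + 18*x2 - 6
# 24*x1 + 32*x3 - 8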
import numpy as np

def y(x):
    x1, x2, x3 = x[0], x[1], x[2]
    return (5 * x1 + 3 * x2 - 1) ** 2 + (3 * x1 + 4 * x3 - 1) ** 2

def grad(x):
    """Hand-derived partial-derivative formulas."""
    x1, x2, x3 = x[0], x[1], x[2]
    return np.array([68 * x1 + 30 * x2 + 24 * x3 - 16,
                     30 * x1 + 18 * x2 - 6,
                     24 * x1 + 32 * x3 - 8])

lr = 0.01  # learning rate (step size)
max_iter = 100  # maximum number of iterations
stop = 1e-6  # stop once every gradient component is at most this value
x = np.array([1., 1., 1.])  # initial value

for i in range(max_iter):
    g = grad(x)
    x = x - lr * g
    print('iter {}: x={} y={} grad={} grad_norm={}'.format(i, x, y(x), g, np.linalg.norm(g)))
    if np.all(np.abs(g) <= stop):
        break

print('y_min: ', y(x))
Running it produces:
iter 0: x=[-0.06 0.58 0.52] y=1.0035999999999998 grad=[106. 42. 48.] grad_norm=123.70933675353692
iter 1: x=[-0.158 0.5536 0.448 ] y=0.11781664 grad=[9.8 2.64 7.2 ] grad_norm=12.443857922686194
iter 2: x=[-0.16416 0.561352 0.42256 ] y=0.05780793913599991 grad=[ 0.616 -0.7752 2.544 ] grad_norm=2.729895060254149
iter 3: x=[-0.1623512 0.56955664 0.4067392 ] y=0.030199645260006416 grad=[-0.18088 -0.820464 1.58208 ] grad_norm=1.791327964415225
iter 4: x=[-0.16043678 0.5757418 0.39554694] y=0.015795032234661038 grad=[-0.1914416 -0.61851648 1.1192256 ] grad_norm=1.293011394357185
iter 5: x=[-0.15899358 0.58023932 0.38747675] y=0.008261296899959504 grad=[-0.14432051 -0.44975103 0.80701939] grad_norm=0.9350853979569257
iter 6: x=[-0.15794416 0.58349431 0.38164265] y=0.004320918811588607 grad=[-0.10494191 -0.32549969 0.58341011] grad_norm=0.6762619393258339
iter 7: x=[-0.15718466 0.58584858 0.3774236 ] y=0.0022599768204799413 grad=[-0.07594993 -0.23542718 0.42190493] grad_norm=0.489078847504133
iter 8: x=[-0.15663533 0.58755124 0.37437237] y=0.0011820391570467469 grad=[-0.05493301 -0.17026531 0.30512334] grad_norm=0.3537063196554449
iter 9: x=[-0.15623804 0.58878261 0.37216569] y=0.0006182437607913295 grad=[-0.03972857 -0.12313765 0.22066779] grad_norm=0.2558036627980692
iter 10: x=[-0.15595072 0.58967316 0.3705698 ] y=0.0003233609863757719 grad=[-0.02873212 -0.0890543 0.15958895] grad_norm=0.1849995611186115
iter 11: x=[-0.15574293 0.59031721 0.36941564] y=0.0001691279947185792 grad=[-0.02077934 -0.06440489 0.1154162 ] grad_norm=0.13379338372139668
iter 12: x=[-0.15559265 0.59078299 0.36858094] y=8.845927555493002e-05 grad=[-0.01502781 -0.04657821 0.08347006] grad_norm=0.09676060537324213
iter 13: x=[-0.15548397 0.59111985 0.36797727] y=4.626699113132352e-05 grad=[-0.01086825 -0.03368579 0.06036631] grad_norm=0.06997815954556025
iter 14: x=[-0.15540537 0.59136346 0.3675407 ] y=2.4199095628099958e-05 grad=[-0.00786002 -0.02436187 0.04365747] grad_norm=0.05060884845123345
iter 15: x=[-0.15534853 0.59153965 0.36722496] y=1.265689025585839e-05 grad=[-0.00568444 -0.01761873 0.03157348] grad_norm=0.036600784561823196
iter 16: x=[-0.15530741 0.59166707 0.36699662] y=6.61995280363997e-06 grad=[-0.00411104 -0.01274203 0.02283423] grad_norm=0.02647002394910958
iter 17: x=[-0.15527768 0.59175922 0.36683148] y=3.4624441104046973e-06 grad=[-0.00297314 -0.00921515 0.01651393] grad_norm=0.019143364718940906
iter 18: x=[-0.15525618 0.59182587 0.36671205] y=1.8109674756416336e-06 grad=[-0.0021502 -0.00666448 0.01194302] grad_norm=0.013844657393094468
iter 19: x=[-0.15524063 0.59187407 0.36662568] y=9.471931078902917e-07 grad=[-0.00155505 -0.00481981 0.00863731] grad_norm=0.010012583532012379
iter 20: x=[-0.15522938 0.59190892 0.36656321] y=4.954118700101854e-07 grad=[-0.00112462 -0.00348573 0.00624658] grad_norm=0.00724119247873167
iter 21: x=[-0.15522125 0.59193413 0.36651804] y=2.591160333647869e-07 grad=[-0.00081334 -0.00252092 0.00451758] grad_norm=0.005236896985319308
iter 22: x=[-0.15521537 0.59195236 0.36648537] y=1.3552585800024948e-07 grad=[-0.00058821 -0.00182315 0.00326716] grad_norm=0.003787372054451165
iter 23: x=[-0.15521112 0.59196555 0.36646174] y=7.088429823583231e-08 grad=[-0.0004254 -0.00131852 0.00236284] grad_norm=0.0027390622956763953
iter 24: x=[-0.15520804 0.59197508 0.36644465] y=3.7074723676619534e-08 grad=[-0.00030765 -0.00095356 0.00170883] grad_norm=0.0019809150386429397
iter 25: x=[-0.15520581 0.59198198 0.36643229] y=1.9391249824058642e-08 grad=[-0.0002225 -0.00068963 0.00123584] grad_norm=0.0014326159709912123
iter 26: x=[-0.1552042 0.59198697 0.36642335] y=1.0142235260318824e-08 grad=[-0.00016091 -0.00049874 0.00089377] grad_norm=0.001036081043507717
iter 27: x=[-0.15520304 0.59199058 0.36641689] y=5.304708928439957e-09 grad=[-0.00011637 -0.0003607 0.00064638] grad_norm=0.0007493033377066719
iter 28: x=[-0.1552022 0.59199318 0.36641221] y=2.774530080738553e-09 grad=[-8.41625077e-05 -2.60858998e-04 4.67469984e-04] grad_norm=0.0005419030638727381
iter 29: x=[-0.15520159 0.59199507 0.36640883] y=1.4511667412422119e-09 grad=[-6.08670994e-05 -1.88655626e-04 3.38078591e-04] grad_norm=0.00039190927873711517
iter 30: x=[-0.15520115 0.59199644 0.36640639] y=7.59005975631017e-10 grad=[-4.40196460e-05 -1.36437483e-04 2.44501546e-04] grad_norm=0.0002834323940940813
iter 31: x=[-0.15520083 0.59199742 0.36640462] y=3.9698406438671973e-10 grad=[-3.18354128e-05 -9.86728425e-05 1.76825766e-04] grad_norm=0.00020498091364388083
iter 32: x=[-0.1552006 0.59199814 0.36640334] y=2.0763518660267846e-10 grad=[-2.30236632e-05 -7.13611070e-05 1.27882020e-04] grad_norm=0.00014824408160084345
iter 33: x=[-0.15520044 0.59199865 0.36640242] y=1.0859975142801017e-10 grad=[-1.66509250e-05 -5.16090088e-05 9.24854529e-05] grad_norm=0.00010721148295828789
iter 34: x=[-0.15520031 0.59199902 0.36640175] y=5.680109524789638e-11 grad=[-1.20421020e-05 -3.73241097e-05 6.68863299e-05] grad_norm=7.753633031283164e-05
iter 35: x=[-0.15520023 0.59199929 0.36640126] y=2.970876432768251e-11 grad=[-8.70895893e-06 -2.69931393e-05 4.83728088e-05] grad_norm=5.607498705048494e-05
iter 36: x=[-0.15520016 0.59199949 0.36640091] y=1.5538620761544232e-11 grad=[-6.29839918e-06 -1.95216866e-05 3.49836602e-05] grad_norm=4.0553946258328485e-05
iter 37: x=[-0.15520012 0.59199963 0.36640066] y=8.12718874970174e-12 grad=[-4.55506020e-06 -1.41182632e-05 2.53005047e-05] grad_norm=2.9328986837998224e-05
iter 38: x=[-0.15520009 0.59199973 0.36640048] y=4.25077604802532e-12 grad=[-3.29426143e-06 -1.02104578e-05 1.82975577e-05] grad_norm=2.121099297121382e-05
iter 39: x=[-0.15520006 0.59199981 0.36640035] y=2.223289943460344e-12 grad=[-2.38244015e-06 -7.38429697e-06 1.32329619e-05] grad_norm=1.5339985156425394e-05
iter 40: x=[-0.15520005 0.59199986 0.36640025] y=1.162850763666433e-12 grad=[-1.72300263e-06 -5.34039147e-06 9.57019976e-06] grad_norm=1.109401832157149e-05
iter 41: x=[-0.15520003 0.5919999 0.36640018] y=6.082076260088465e-13 grad=[-1.24609135e-06 -3.86222022e-06 6.92125647e-06] grad_norm=8.02329606364359e-06
iter 42: x=[-0.15520002 0.59199993 0.36640013] y=3.181117709801441e-13 grad=[-9.01184718e-07 -2.79319318e-06 5.00551632e-06] grad_norm=5.80252149114771e-06
iter 43: x=[-0.15520002 0.59199995 0.36640009] y=1.663824896935061e-13 grad=[-6.51745072e-07 -2.02006299e-06 3.62003543e-06] grad_norm=4.196436895314472e-06
iter 44: x=[-0.15520001 0.59199996 0.36640007] y=8.702328981561944e-14 grad=[-4.71348033e-07 -1.46092813e-06 2.61804291e-06] grad_norm=3.0349017515271214e-06
iter 45: x=[-0.15520001 0.59199997 0.36640005] y=4.551592522249042e-14 grad=[-3.40883227e-07 -1.05655666e-06 1.89339271e-06] grad_norm=2.1948688531434195e-06
iter 46: x=[-0.15520001 0.59199998 0.36640004] y=2.3806264492661115e-14 grad=[-2.46529886e-07 -7.64111489e-07 1.36931901e-06] grad_norm=1.587349335666307e-06
iter 47: x=[-0.1552 0.59199999 0.36640003] y=1.2451427244674976e-14 grad=[-1.78292686e-07 -5.52612456e-07 9.90304102e-07] grad_norm=1.147985636932428e-06
y_min: 1.2451427244674976e-14
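A sanity check on this result: the true minimum is $y = 0$, attained wherever both squared terms vanish, i.e. on the whole line defined by $5x_1 + 3x_2 = 1$ and $3x_1 + 4x_3 = 1$ (three unknowns, two equations, so the minimizer is not unique). The final iterate satisfies both: $5(-0.1552) + 3(0.592) = 1.000$ and $3(-0.1552) + 4(0.3664) = 1.000$, so gradient descent has converged to one of infinitely many minimizers.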
Approach 2: a partial derivative is just the slope of the function along one coordinate axis, so instead of deriving formulas we can estimate each component numerically with a finite (forward) difference. This more general approach sidesteps the hand derivation entirely:
import numpy as np

def y(x):
    x1, x2, x3 = x[0], x[1], x[2]
    return (5 * x1 + 3 * x2 - 1) ** 2 + (3 * x1 + 4 * x3 - 1) ** 2

def grad(x):
    """Approximate each partial derivative with a forward difference."""
    delta = 1e-8
    d_x0 = (y([x[0] + delta, x[1], x[2]]) - y(x)) / delta
    d_x1 = (y([x[0], x[1] + delta, x[2]]) - y(x)) / delta
    d_x2 = (y([x[0], x[1], x[2] + delta]) - y(x)) / delta
    return np.array([d_x0, d_x1, d_x2])

lr = 0.01  # learning rate (step size)
max_iter = 100  # maximum number of iterations
stop = 1e-6  # stop once every gradient component is at most this value
x = np.array([1., 1., 1.])  # initial value

for i in range(max_iter):
    g = grad(x)
    x = x - lr * g
    print('iter {}: x={} y={} grad={} grad_norm={}'.format(i, x, y(x), g, np.linalg.norm(g)))
    if np.all(np.abs(g) <= stop):
        break

print('y_min: ', y(x))
Running it produces:
iter 0: x=[-0.06000002 0.58000001 0.52 ] y=1.0035998786360978 grad=[106.00000167 41.99999921 47.99999971] grad_norm=123.70933779945847
iter 1: x=[-0.15800001 0.55360001 0.448 ] y=0.11781663499253381 grad=[9.79999952 2.63999975 7.19999989] grad_norm=12.443857428014446
iter 2: x=[-0.16416001 0.56135201 0.42256 ] y=0.057807938341557265 grad=[ 0.61599995 -0.77520009 2.544 ] grad_norm=2.729895076249535
iter 3: x=[-0.16235121 0.56955665 0.4067392 ] y=0.030199644492460674 grad=[-0.18087999 -0.82046405 1.58208002] grad_norm=1.791328002370116
iter 4: x=[-0.1604368 0.57574182 0.39554695] y=0.015795031442206904 grad=[-0.19144159 -0.61851653 1.11922561] grad_norm=1.2930114228260368
iter 5: x=[-0.15899359 0.58023933 0.38747675] y=0.008261296195812924 grad=[-0.1443205 -0.44975107 0.8070194 ] grad_norm=0.9350854177457265
iter 6: x=[-0.15794417 0.58349433 0.38164265] y=0.004320918234695532 grad=[-0.10494189 -0.32549973 0.58341011] grad_norm=0.6762619532127319
iter 7: x=[-0.15718467 0.5858486 0.3774236 ] y=0.002259976368159793 grad=[-0.07594991 -0.23542721 0.42190493] grad_norm=0.48907885743671226
iter 8: x=[-0.15663534 0.58755125 0.37437237] y=0.0011820388121374136 grad=[-0.05493299 -0.17026534 0.30512333] grad_norm=0.35370632660999723
iter 9: x=[-0.15623806 0.58878263 0.37216569] y=0.0006182435017803573 grad=[-0.03972856 -0.12313768 0.22066778] grad_norm=0.2558036679709723
iter 10: x=[-0.15595074 0.58967317 0.3705698 ] y=0.0003233607945279679 grad=[-0.0287321 -0.08905433 0.15958895] grad_norm=0.18499956450364702
iter 11: x=[-0.15574294 0.59031722 0.36941564] y=0.0001691278538960136 grad=[-0.02077932 -0.06440492 0.11541619] grad_norm=0.13379338586763737
iter 12: x=[-0.15559266 0.590783 0.36858094] y=8.84591723127018e-05 grad=[-0.01502779 -0.04657824 0.08347004] grad_norm=0.09676060737142307
iter 13: x=[-0.15548398 0.59111986 0.36797728] y=4.626691594568733e-05 grad=[-0.01086823 -0.03368582 0.0603663 ] grad_norm=0.06997816057333908
iter 14: x=[-0.15540538 0.59136348 0.3675407 ] y=2.4199040946547724e-05 grad=[-0.00786 -0.0243619 0.04365746] grad_norm=0.0506088492895908
iter 15: x=[-0.15534854 0.59153967 0.36722497] y=1.2656850584237097e-05 grad=[-0.00568442 -0.01761876 0.03157347] grad_norm=0.03660078503652166
iter 16: x=[-0.15530743 0.59166709 0.36699663] y=6.619924005703637e-06 grad=[-0.00411102 -0.01274205 0.02283422] grad_norm=0.02647002450959095
iter 17: x=[-0.1552777 0.59175924 0.36683149] y=3.4624232532766325e-06 grad=[-0.00297312 -0.00921518 0.01651392] grad_norm=0.019143364938747714
iter 18: x=[-0.1552562 0.59182588 0.36671206] y=1.810952377305285e-06 grad=[-0.00215019 -0.00666451 0.01194301] grad_norm=0.01384465753797714
iter 19: x=[-0.15524064 0.59187408 0.36662568] y=9.4718217360823e-07 grad=[-0.00155503 -0.00481984 0.00863729] grad_norm=0.010012583742810544
iter 20: x=[-0.1552294 0.59190894 0.36656322] y=4.954039568822264e-07 grad=[-0.00112461 -0.00348576 0.00624657] grad_norm=0.00724119258550554
iter 21: x=[-0.15522127 0.59193415 0.36651804] y=2.591103071282859e-07 grad=[-0.00081332 -0.00252094 0.00451757] grad_norm=0.005236897081079024
iter 22: x=[-0.15521538 0.59195238 0.36648537] y=1.3552171547618412e-07 grad=[-0.0005882 -0.00182318 0.00326715] grad_norm=0.0037873721090173297
iter 23: x=[-0.15521113 0.59196557 0.36646174] y=7.08813020998077e-08 grad=[-0.00042539 -0.00131855 0.00236283] grad_norm=0.002739062318989531
iter 24: x=[-0.15520805 0.5919751 0.36644465] y=3.707255678388081e-08 grad=[-0.00030764 -0.00095359 0.00170881] grad_norm=0.0019809150593567753
iter 25: x=[-0.15520583 0.591982 0.3664323 ] y=1.9389682790598714e-08 grad=[-0.00022248 -0.00068965 0.00123583] grad_norm=0.0014326159858460016
iter 26: x=[-0.15520422 0.59198699 0.36642336] y=1.0141102134263334e-08 grad=[-0.0001609 -0.00049877 0.00089376] grad_norm=0.0010360810528812455
iter 27: x=[-0.15520306 0.59199059 0.36641689] y=5.3038896407188295e-09 grad=[-0.00011636 -0.00036072 0.00064637] grad_norm=0.0007493033446737724
iter 28: x=[-0.15520221 0.5919932 0.36641222] y=2.7739377788545485e-09 grad=[-8.41463763e-05 -2.60885881e-04 4.67457893e-04] grad_norm=0.0005419030696615658
iter 29: x=[-0.15520161 0.59199509 0.36640884] y=1.4507386086007354e-09 grad=[-6.08509715e-05 -1.88682507e-04 3.38066497e-04] grad_norm=0.0003919092828529942
iter 30: x=[-0.15520117 0.59199645 0.36640639] y=7.586965719153159e-10 grad=[-4.40035182e-05 -1.36464365e-04 2.44489453e-04] grad_norm=0.00028343239969230184
iter 31: x=[-0.15520085 0.59199744 0.36640463] y=3.9676053172038427e-10 grad=[-3.18192859e-05 -9.86997238e-05 1.76813671e-04] grad_norm=0.00020498091786609265
iter 32: x=[-0.15520062 0.59199816 0.36640335] y=2.0747375812448486e-10 grad=[-2.30075351e-05 -7.13879870e-05 1.27869925e-04] grad_norm=0.00014824408608517318
iter 33: x=[-0.15520045 0.59199867 0.36640242] y=1.0848323754807052e-10 grad=[-1.66347964e-05 -5.16358892e-05 9.24733574e-05] grad_norm=0.00010721148879348808
iter 34: x=[-0.15520033 0.59199905 0.36640175] y=5.671706412707642e-11 grad=[-1.20259734e-05 -3.73509901e-05 6.68742345e-05] grad_norm=7.753633819927338e-05
iter 35: x=[-0.15520024 0.59199932 0.36640127] y=2.9648225375971145e-11 grad=[-8.69283154e-06 -2.70200198e-05 4.83607125e-05] grad_norm=5.607499709062693e-05
iter 36: x=[-0.15520018 0.59199951 0.36640092] y=1.549507150064275e-11 grad=[-6.28227063e-06 -1.95485665e-05 3.49715645e-05] grad_norm=4.0553960380623134e-05
iter 37: x=[-0.15520014 0.59199965 0.36640067] y=8.095926604511534e-12 grad=[-4.53893240e-06 -1.41451435e-05 2.52884088e-05] grad_norm=2.9329006278240914e-05
iter 38: x=[-0.1552001 0.59199976 0.36640049] y=4.228400087587236e-12 grad=[-3.27813344e-06 -1.02373379e-05 1.82854617e-05] grad_norm=2.1211019675900683e-05
iter 39: x=[-0.15520008 0.59199983 0.36640035] y=2.207340577392433e-12 grad=[-2.36631222e-06 -7.41117695e-06 1.32208658e-05] grad_norm=1.534002184680652e-05
iter 40: x=[-0.15520006 0.59199988 0.36640026] y=1.1515491611298115e-12 grad=[-1.70687453e-06 -5.36727143e-06 9.55810376e-06] grad_norm=1.10940691735635e-05
iter 41: x=[-0.15520005 0.59199992 0.36640019] y=6.002673272742683e-13 grad=[-1.22996334e-06 -3.88910022e-06 6.90916045e-06] grad_norm=8.023366399927621e-06
iter 42: x=[-0.15520004 0.59199995 0.36640014] y=3.126023968333678e-13 grad=[-8.85056713e-07 -2.82007319e-06 4.99342034e-06] grad_norm=5.80261879145478e-06
iter 43: x=[-0.15520003 0.59199997 0.3664001 ] y=1.6263118319242774e-13 grad=[-6.35617061e-07 -2.04694300e-06 3.60793941e-06] grad_norm=4.196571396136328e-06
iter 44: x=[-0.15520003 0.59199999 0.36640008] y=8.454343398301524e-14 grad=[-4.55220043e-07 -1.48780812e-06 2.60594690e-06] grad_norm=3.0350877292318492e-06
iter 45: x=[-0.15520003 0.592 0.36640006] y=4.395559434501895e-14 grad=[-3.24755215e-07 -1.08343666e-06 1.88129670e-06] grad_norm=2.1951260110189777e-06
iter 46: x=[-0.15520002 0.592 0.36640004] y=2.2910942287870438e-14 grad=[-2.30401882e-07 -7.90991486e-07 1.35722301e-06] grad_norm=1.5877049006594887e-06
iter 47: x=[-0.15520002 0.59200001 0.36640003] y=1.2037045292228435e-14 grad=[-1.62164681e-07 -5.79492454e-07 9.78208098e-07] grad_norm=1.148477240177836e-06
y_min: 1.2037045292228435e-14
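One refinement worth knowing about (not in the original code): the forward difference above has truncation error of order $\delta$, while a central difference costs one extra function evaluation per coordinate but is accurate to order $\delta^2$. A sketch of a drop-in replacement for grad, assuming the same y:

def grad(x):
    """Central-difference gradient: error O(delta**2) instead of O(delta)."""
    delta = 1e-6  # a larger step is fine here because the scheme is more accurate
    g = np.zeros(len(x))
    for k in range(len(x)):
        xp, xm = list(x), list(x)
        xp[k] += delta
        xm[k] -= delta
        g[k] = (y(xp) - y(xm)) / (2 * delta)
    return g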
Approach 3: implement it with tensorflow-2.0.0, letting automatic differentiation compute the gradient.
import tensorflow as tf
import numpy as np

# Run on CPU only (hide any GPUs).
tf.config.experimental.set_visible_devices([], device_type='GPU')

def f(x):
    return (5 * x[0] + 3 * x[1] - 1) ** 2 + (3 * x[0] + 4 * x[2] - 1) ** 2

lr = 0.01  # learning rate (step size)
max_iter = 100  # maximum number of iterations
stop = 1e-6  # stop once every gradient component is at most this value
x = tf.constant([1., 1., 1.])  # initial value

for i in range(max_iter):
    with tf.GradientTape() as tape:
        tape.watch([x])  # x is a constant tensor, so the tape must be told to track it
        y = f(x)
    g = tape.gradient(y, [x])[0]
    x = x - lr * g
    print('iter {}: x={} y={} grad={} grad_norm={}'.format(i, x.numpy(), f(x.numpy()), g.numpy(), np.linalg.norm(g.numpy())))
    if np.all(np.abs(g.numpy()) <= stop):
        break

print('y_min: ', f(x.numpy()))
Running it produces:
iter 0: x=[-0.05999994 0.58000004 0.52 ] y=1.0036005367280865 grad=[106. 42. 48.] grad_norm=123.70933532714844
iter 1: x=[-0.15799999 0.5536 0.44799998] y=0.11781659378563702 grad=[9.800005 2.6400025 7.200001 ] grad_norm=12.443862915039062
iter 2: x=[-0.16416 0.561352 0.42255998] y=0.057807889841356985 grad=[ 0.61599994 -0.7751999 2.5439997 ] grad_norm=2.7298946380615234
iter 3: x=[-0.16235119 0.56955665 0.40673918] y=0.030199607572857445 grad=[-0.18088043 -0.8204638 1.5820789 ] grad_norm=1.7913269996643066
iter 4: x=[-0.16043676 0.5757418 0.39554694] y=0.015795018348294665 grad=[-0.19144213 -0.6185163 1.1192245 ] grad_norm=1.2930104732513428
iter 5: x=[-0.15899359 0.5802393 0.38747674] y=0.008261299425506241 grad=[-0.14431751 -0.4497496 0.8070202 ] grad_norm=0.9350849390029907
iter 6: x=[-0.15794416 0.5834943 0.38164264] y=0.004320916231244087 grad=[-0.10494292 -0.32550037 0.58341026] grad_norm=0.6762625575065613
iter 7: x=[-0.15718466 0.58584857 0.37742358] y=0.002259974556757527 grad=[-0.07595015 -0.23542714 0.42190456] grad_norm=0.489078551530838
iter 8: x=[-0.15663531 0.58755124 0.37437236] y=0.001182037898752597 grad=[-0.05493414 -0.17026556 0.30512238] grad_norm=0.3537057936191559
iter 9: x=[-0.15623803 0.5887826 0.37216568] y=0.0006182425492107235 grad=[-0.03972745 -0.123137 0.22066784] grad_norm=0.2558031976222992
iter 10: x=[-0.15595073 0.58967316 0.3705698 ] y=0.00032336028946389206 grad=[-0.02873111 -0.08905363 0.15958881] grad_norm=0.18499895930290222
iter 11: x=[-0.15574293 0.5903172 0.36941564] y=0.00016912902688170917 grad=[-0.02077973 -0.06440485 0.11541557] grad_norm=0.13379287719726562
iter 12: x=[-0.15559264 0.590783 0.36858094] y=8.84587340199694e-05 grad=[-0.01502872 -0.04657888 0.08347034] grad_norm=0.0967613235116005
iter 13: x=[-0.15548398 0.5911198 0.36797726] y=4.62670071059712e-05 grad=[-0.01086664 -0.03368497 0.06036663] grad_norm=0.0699777901172638
iter 14: x=[-0.15540537 0.59136343 0.3675407 ] y=2.419935945319196e-05 grad=[-0.0078603 -0.02436197 0.0436573 ] grad_norm=0.0506087951362133
iter 15: x=[-0.15534851 0.5915396 0.36722496] y=1.265716039178244e-05 grad=[-0.00568604 -0.01761961 0.0315733 ] grad_norm=0.036601293832063675
iter 16: x=[-0.1553074 0.59166706 0.36699662] y=6.619958903275602e-06 grad=[-0.00411105 -0.01274228 0.02283478] grad_norm=0.026470616459846497
iter 17: x=[-0.15527767 0.5917592 0.36683148] y=3.462529583941887e-06 grad=[-0.00297296 -0.009215 0.01651382] grad_norm=0.019143173471093178
iter 18: x=[-0.15525618 0.59182584 0.36671203] y=1.8109340462757473e-06 grad=[-0.00214946 -0.0066644 0.01194382] grad_norm=0.01384518388658762
iter 19: x=[-0.15524063 0.59187406 0.36662567] y=9.471220696610771e-07 grad=[-0.00155616 -0.00482011 0.00863647] grad_norm=0.010012180544435978
iter 20: x=[-0.15522937 0.59190893 0.3665632 ] y=4.953290013709477e-07 grad=[-0.00112534 -0.00348616 0.00624657] grad_norm=0.007241495884954929
iter 21: x=[-0.15522125 0.59193414 0.36651802] y=2.590420553616468e-07 grad=[-0.00081277 -0.00252056 0.00451756] grad_norm=0.005236614029854536
iter 22: x=[-0.15521537 0.5919524 0.36648536] y=1.35480251461928e-07 grad=[-0.00058889 -0.00182319 0.00326633] grad_norm=0.0037867859937250614
iter 23: x=[-0.15521112 0.59196556 0.36646172] y=7.085719921917644e-08 grad=[-0.00042403 -0.00131786 0.0023632 ] grad_norm=0.0027388480957597494
iter 24: x=[-0.15520804 0.5919751 0.36644465] y=3.706684070792221e-08 grad=[-0.00030804 -0.00095344 0.00170803] grad_norm=0.001980226021260023
iter 25: x=[-0.15520582 0.591982 0.36643228] y=1.9368842529843278e-08 grad=[-0.00022221 -0.00068951 0.00123596] grad_norm=0.0014326188247650862
iter 26: x=[-0.1552042 0.591987 0.36642334] y=1.0117107152041172e-08 grad=[-0.00016069 -0.00049853 0.00089359] grad_norm=0.0010357925202697515
iter 27: x=[-0.15520306 0.5919906 0.36641687] y=5.294389371357511e-09 grad=[-0.00011539 -0.00035977 0.00064564] grad_norm=0.0007480646600015461
iter 28: x=[-0.15520221 0.5919932 0.36641222] y=2.772573282072699e-09 grad=[-8.5353851e-05 -2.6106834e-04 4.6634674e-04] grad_norm=0.0005412219907157123
iter 29: x=[-0.1552016 0.5919951 0.36640885] y=1.4500369793779555e-09 grad=[-6.0915947e-05 -1.8846989e-04 3.3760071e-04] grad_norm=0.0003914152330253273
iter 30: x=[-0.15520117 0.5919965 0.3664064 ] y=7.580158722930719e-10 grad=[-4.2676926e-05 -1.3589859e-04 2.4509430e-04] grad_norm=0.0002834800980053842
iter 31: x=[-0.15520087 0.59199744 0.36640462] y=3.965698880392665e-10 grad=[-3.0279160e-05 -9.7990036e-05 1.7738342e-04] grad_norm=0.00020489937742240727
iter 32: x=[-0.15520063 0.59199816 0.36640334] y=2.0686208301867737e-10 grad=[-2.4080276e-05 -7.1525574e-05 1.2683868e-04] grad_norm=0.00014759342593606561
iter 33: x=[-0.15520045 0.5919987 0.36640242] y=1.0719425347360811e-10 grad=[-1.7166138e-05 -5.1498413e-05 9.1552734e-05] grad_norm=0.00010643620771588758
iter 34: x=[-0.15520033 0.59199905 0.36640176] y=5.6852300645005016e-11 grad=[-1.1920929e-05 -3.7193298e-05 6.6757202e-05] grad_norm=7.734322571195662e-05
iter 35: x=[-0.15520024 0.59199935 0.36640128] y=2.943423282886215e-11 grad=[-8.8214874e-06 -2.7179718e-05 4.8637390e-05] grad_norm=5.641056122840382e-05
iter 36: x=[-0.1552002 0.59199953 0.36640093] y=1.545474859199203e-11 grad=[-4.529953e-06 -1.859665e-05 3.528595e-05] grad_norm=4.014292062493041e-05
iter 37: x=[-0.15520014 0.59199965 0.3664007 ] y=8.51274606361585e-12 grad=[-5.2452087e-06 -1.4305115e-05 2.4795532e-05] grad_norm=2.9102697226335295e-05
iter 38: x=[-0.15520011 0.59199977 0.3664005 ] y=4.46620518346208e-12 grad=[-2.9802322e-06 -1.0371208e-05 1.9073486e-05] grad_norm=2.1914416720392182e-05
iter 39: x=[-0.15520008 0.5919998 0.3664004 ] y=2.5850432905372145e-12 grad=[-2.5033951e-06 -7.5101852e-06 1.3351440e-05] grad_norm=1.5521945897489786e-05
iter 40: x=[-0.15520008 0.5919999 0.36640027] y=1.2545520178264269e-12 grad=[-4.76837158e-07 -5.00679016e-06 1.04904175e-05] grad_norm=1.1633751455519814e-05
iter 41: x=[-0.15520006 0.59199995 0.3664002 ] y=6.572520305780927e-13 grad=[-2.1457672e-06 -4.2915344e-06 6.6757202e-06] grad_norm=8.221120879170485e-06
iter 42: x=[-0.15520005 0.592 0.36640015] y=2.6334490144108713e-13 grad=[-1.0728836e-06 -3.2186508e-06 5.7220459e-06] grad_norm=6.6522629822429735e-06
iter 43: x=[-0.15520005 0.592 0.36640012] y=1.674216321134736e-13 grad=[ 4.7683716e-07 -1.4305115e-06 3.8146973e-06] grad_norm=4.1019084164872766e-06
iter 44: x=[-0.15520005 0.592 0.3664001 ] y=9.992007221626409e-14 grad=[-2.3841858e-07 -1.4305115e-06 2.8610229e-06] grad_norm=3.207593863407965e-06
iter 45: x=[-0.15520003 0.592 0.36640006] y=4.440892098500626e-14 grad=[-9.5367432e-07 -1.4305115e-06 1.9073486e-06] grad_norm=2.5678466499812203e-06
iter 46: x=[-0.15520003 0.592 0.36640006] y=4.440892098500626e-14 grad=[-4.7683716e-07 -7.1525574e-07 9.5367432e-07] grad_norm=1.2839233249906101e-06
y_min: 4.440892098500626e-14
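As an aside, idiomatic TF2 code would usually store x in a tf.Variable, which GradientTape tracks automatically (no tape.watch needed), and apply updates through a built-in optimizer. A minimal sketch assuming the same f, lr, and max_iter as above:

x = tf.Variable([1., 1., 1.])  # variables are watched by GradientTape automatically
opt = tf.keras.optimizers.SGD(learning_rate=lr)

for i in range(max_iter):
    with tf.GradientTape() as tape:
        y = f(x)
    g = tape.gradient(y, [x])
    opt.apply_gradients(zip(g, [x]))  # plain SGD: equivalent to x.assign_sub(lr * g[0])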