美文网首页
手写简单梯度下降

手写简单梯度下降

作者: zeolite | 来源:发表于2021-05-15 21:22 被阅读0次
import numpy as np

x = np.linspace(1, 10, 10)
y = x + 0.5

w = np.random.rand() * 10
b = np.random.rand()
lr = 0.001
for n in range(1000):
    loss = np.zeros_like(x)
    for i in range(len(x)):
        y_ = w * x[i] + b
        loss[i] = (y[i] - y_) ** 2
    w = w - lr * x[i]
    b = b - lr
    mse = loss.mean()
    print(n, ':', mse)
    if (mse < 0.01):
        break
print(w, b)

相关文章

网友评论

      本文标题:手写简单梯度下降

      本文链接:https://www.haomeiwen.com/subject/gymkjltx.html