JAX Quick Start
Author: 辘轳鹿鹿 | Published 2022-09-17 09:28

    JAX is built from autograd (automatic differentiation) and XLA (Accelerated Linear Algebra).
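A minimal sketch of the XLA side: `jax.jit` compiles a plain Python function with XLA so repeated calls run as one fused kernel. The function and values below are illustrative, not from the article.

```python
import jax
import jax.numpy as jnp

# jax.jit hands the traced computation to XLA for compilation;
# the autograd side (jax.grad) is shown later in the article.
@jax.jit
def scaled_sum(x):
    return jnp.sum(2.0 * x)

x = jnp.arange(4.0)       # [0., 1., 2., 3.]
print(scaled_sum(x))      # 2*(0+1+2+3) = 12.0
```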

    • Function optimization (a single-layer perceptron)
    import jax.numpy as jnp

    def func(x, a, b):
        # simple linear model: y = a*x + b
        y = x * a + b
        return y

    def loss_func(weights, x, y):
        # mean squared error; grad needs a scalar output,
        # and jax.numpy (not plain numpy) lets JAX trace the computation
        a, b = weights
        y_hat = func(x, a, b)
        return jnp.mean((y_hat - y) ** 2)
    

    What JAX adds here is the gradient:

    from jax import grad

    def f(x):
        return x ** 2

    df = grad(f)  # df is a new function that computes f'(x)
    df(3.0)       # returns 6.0
    
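By default `grad` differentiates with respect to the first argument; the `argnums` parameter selects other arguments, which is useful for multi-argument functions like the loss above. A small sketch (the function `f` here is illustrative):

```python
from jax import grad

def f(x, a, b):
    return a * x + b

# argnums=1 differentiates f with respect to a instead of x
df_da = grad(f, argnums=1)
print(df_da(2.0, 3.0, 4.0))  # d(a*x+b)/da = x = 2.0
```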
    import numpy as np

    a = np.random.random()
    b = np.random.random()
    weights = [a, b]
    x = np.random.random(1000)
    y = 3 * x + 4              # ground truth: a = 3, b = 4

    grad_func = grad(loss_func)
    grad_func(weights, x, y)   # returns [d_loss/d_a, d_loss/d_b]
    
    
    
    learning_rate = 0.1
    for i in range(1000):
        loss = loss_func(weights, x, y)   # optional: monitor the loss
        da, db = grad_func(weights, x, y)
        a = a - learning_rate * da
        b = b - learning_rate * db
        weights = [a, b]   # update the list actually passed to grad_func
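Putting the pieces together, here is a minimal end-to-end sketch of the fit. The learning rate, step count, and use of `jax.random` for data generation are illustrative choices, not from the article; the recovered weights should approach the ground truth a = 3, b = 4.

```python
import jax
import jax.numpy as jnp

def loss_func(weights, x, y):
    # mean squared error of the linear model a*x + b
    a, b = weights
    return jnp.mean((a * x + b - y) ** 2)

grad_func = jax.grad(loss_func)

# synthetic data: y = 3x + 4
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (1000,))
y = 3 * x + 4

weights = jnp.array([0.5, 0.5])
learning_rate = 0.1
for _ in range(2000):
    # plain gradient descent on both parameters at once
    weights = weights - learning_rate * grad_func(weights, x, y)

print(weights)  # approximately [3., 4.]
```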
    


Original link: https://www.haomeiwen.com/subject/dzmjortx.html