美文网首页TensorFlow2.0
【TensorFlow2】自定义函数 & 自动求导

【TensorFlow2】自定义函数 & 自动求导

作者: Hennyxu | 来源:发表于2020-03-23 23:33 被阅读0次

    TensorFlow2.0对自定义函数自动求导

    % Author: XuYihang

    本段代码实现了自定义函数f对x的求导,f是二范数形式,也是标量,x是列向量,在机器学习中是比较常见的求导形式:

    # 初始化参数
    z = tf.zeros((100,1))
    Lnz = tf.zeros((100,1))
    r = tf.zeros((99,1))
    
    # 写出函数的表达式,并将梯度信息记录在磁带tape中
    with tf.GradientTape() as tape:
        # 如果不加 persistent=True,tape.gradient()在调用一次后就会被释放
        # 添加 persistent=True 后,可多次调用tape.gradient(),最后通过del tape释放
        tape.watch(x)
        y = tf.matmul(A,x,transpose_a=True) + b 
        z = tf.sigmoid(y)
        Lnz = tf.math.log(z)
        r = 0.5 * tf.matmul(D,Lnz)
        # 矩阵的数乘也可以通过函数完成 —— r = tf.multiply(r,0.5)
        f = 0.5 * tf.reduce_sum(tf.square(tf.matmul(W, r) - s))
    
    df_dx = tape.gradient(f, x)
    # print(df_dx)
    
    df_dx = np.mat(df_dx.numpy())
    # print(df_dx)
    io.savemat('df_dx.mat',{'df_dx':df_dx})
    

    将matlab数据导入

    import tensorflow as tf
    import numpy as np
    import scipy.io as io
    
    # 读入matlab保存的.mat文件中的变量
    GradientComput = io.loadmat('GradientComputV2.mat')
    
    A = tf.constant(GradientComput['A'])
    b = tf.constant(GradientComput['b'])
    x = tf.Variable(initial_value = GradientComput['x'])
    W = tf.constant(GradientComput['W'])
    W = tf.cast(W, dtype = tf.float32)
    s = tf.constant(GradientComput['s'])
    s = tf.cast(s, dtype = tf.float32)
    D = tf.constant(GradientComput['D'])
    D = tf.cast(D, dtype = tf.float32)
    
    # 查看数据类型是否匹配,避免后续运算报错
    print(A.dtype, b.dtype, x.dtype, W.dtype, s.dtype, D.dtype)
    print(A.shape, b.shape, x.shape, W.shape, s.shape, D.shape)
    
    <dtype: 'float32'> <dtype: 'float32'> <dtype: 'float32'> <dtype: 'float32'> <dtype: 'float32'> <dtype: 'float32'>
    (80, 100) (100, 1) (80, 1) (198, 99) (198, 1) (99, 100)
    
    参考链接
    1. 简单粗暴TensorFlow2.0
    2. TensorFlow2.0 tutorials

    相关文章

      网友评论

        本文标题:【TensorFlow2】自定义函数 & 自动求导

        本文链接:https://www.haomeiwen.com/subject/rhbkyhtx.html