美文网首页
TensorFlow2.0 使用tf.function装饰器将动

TensorFlow2.0 使用tf.function装饰器将动

作者: 又双叒叕苟了一天 | 来源:发表于2021-01-24 15:20 被阅读0次

    这里只是一些简单的介绍

    tf.function简单使用

    假设我们要把一个模型的前向传播转化成静态图:

    import tensorflow as tf
    import tensorflow.keras as keras
    import tensorflow.keras.layers as layers
    
    class model(keras.Model):
        def __init__(self):
            pass
    
        @tf.function
        def call(x):
            pass
    

    这个装饰器对任何只包含tensor操作的函数都有效.

    动态图和静态图转换时需要注意的区别

    for, while

    在eager执行模式下, 可以使用普通的python语法对流程进行控制, 但是在tf.function装饰的函数中, 要对上面的2种方式进行转化.

    for

    def funa():
        for i in range(10):
            pass
    
    @tf.function
    def funb():
        for i in tf.range(10):
            pass
    

    while

    def funa():
        i = 0
        while i < 10:
            i += 1
            pass
    
    @tf.function
    def funb():
        i = tf.constant(0)
        while i < tf.constant(10):
            i += 1
            pass
    

    使用1.x的tf.cond, tf.while_loop的方式进行控制应该也是可以的.

    print

    在使用tf.function装饰的函数中print只会在最初执行1次, tf.Variable()也是. 如果要每次都执行需要使用tf.print

    def funa():
        i = 0
        while i < 10:
            print(i)
            i += 1
    
    @tf.function
    def funb():
        i = tf.constant(0)
        while i < tf.constant(10):
            tf.print(i)
            i += 1
    

    TensorArry

    如果要使用类似python中类似list的数据结构, 可以使用tf.TensorArray

    def funa():
        i = 0
        res = []
        while i < 10:
            print(i)
            res.append(i)
            i += 1
    
    @tf.function
    def funb():
        i = tf.constant(0)
        res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size)
        while i < tf.constant(10):
            tf.print(i)
            res = res.write(i, i)  # 注意这个`=`, 如果只写res.write(i, i)会出错
            i += 1
    

    input_signature

    @tf.function是支持多态的, 假设有以下函数

    @tf.function
    def fun(x, y, training=True):
        return x + y
    

    x=tf.constant(0)y=tf.constant(1), x=tf.constant(0.0)y=tf.constant(1.0)的情况下是会产生两个不同的静态图的, 甚至x=tf.constant(0)y=tf.constant(1), x=tf.constant(1)y=tf.constant(1) 都是两个不同的静态图, 因为他们的数据类型不同, 或者数值不同都会造成静态图不同, 这时候静态图可能比eager执行方式更加费时, 因为需要retracing是哪一张静态图. 所以在使用@tf.function时最好指定输入数据的类型和shape, 类似于tensorflow1.x中tf.placehold的效果:

    @tf.function(input_signature=(tf.TensorSpec(shape=None, dtype=tf.int32), tf.TensorSpec(shape=None, dtype=tf.int32))
    def fun(x, y, training=True):
        return x + y
    

    此时输入x=tf.constant(0)y=tf.constant(1), x=tf.constant(1)y=tf.constant(1)都会调用同一张静态图. 另外, 传入的每一个python类型也都会构造一个图, 所以最好把training=True改为training=tf.constant(True).

    shape

    和tensorflow1.x中tf.shape于get_shape()/shape的区别类似, 在tf.function装饰的函数中, 需要使用tf.shape()获取tensor的shape, 而不能使用get_shape()或者shape. 否则会产生NoneType错误.

    相关文章

      网友评论

          本文标题:TensorFlow2.0 使用tf.function装饰器将动

          本文链接:https://www.haomeiwen.com/subject/ozsmzktx.html