这里只是一些简单的介绍
tf.function简单使用
假设我们要把一个模型的前向传播转化成静态图:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
class model(keras.Model):
def __init__(self):
pass
@tf.function
def call(x):
pass
这个装饰器对任何只包含tensor操作的函数都有效.
动态图和静态图转换时需要注意的区别
for, while
在eager执行模式下, 可以使用普通的python语法对流程进行控制, 但是在tf.function装饰的函数中, 要对上面的2种方式进行转化.
for
def funa():
for i in range(10):
pass
@tf.function
def funb():
for i in tf.range(10):
pass
while
def funa():
i = 0
while i < 10:
i += 1
pass
@tf.function
def funb():
i = tf.constant(0)
while i < tf.constant(10):
i += 1
pass
使用1.x的tf.cond
, tf.while_loop
的方式进行控制应该也是可以的.
在使用tf.function
装饰的函数中print只会在最初执行1次, tf.Variable()
也是. 如果要每次都执行需要使用tf.print
def funa():
i = 0
while i < 10:
print(i)
i += 1
@tf.function
def funb():
i = tf.constant(0)
while i < tf.constant(10):
tf.print(i)
i += 1
TensorArry
如果要使用类似python中类似list的数据结构, 可以使用tf.TensorArray
def funa():
i = 0
res = []
while i < 10:
print(i)
res.append(i)
i += 1
@tf.function
def funb():
i = tf.constant(0)
res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size)
while i < tf.constant(10):
tf.print(i)
res = res.write(i, i) # 注意这个`=`, 如果只写res.write(i, i)会出错
i += 1
input_signature
@tf.function是支持多态的, 假设有以下函数
@tf.function
def fun(x, y, training=True):
return x + y
在x=tf.constant(0)
和y=tf.constant(1)
, x=tf.constant(0.0)
和y=tf.constant(1.0)
的情况下是会产生两个不同的静态图的, 甚至x=tf.constant(0)
和y=tf.constant(1)
, x=tf.constant(1)
和y=tf.constant(1)
都是两个不同的静态图, 因为他们的数据类型不同, 或者数值不同都会造成静态图不同, 这时候静态图可能比eager执行方式更加费时, 因为需要retracing是哪一张静态图. 所以在使用@tf.function时最好指定输入数据的类型和shape, 类似于tensorflow1.x中tf.placehold
的效果:
@tf.function(input_signature=(tf.TensorSpec(shape=None, dtype=tf.int32), tf.TensorSpec(shape=None, dtype=tf.int32))
def fun(x, y, training=True):
return x + y
此时输入x=tf.constant(0)
和y=tf.constant(1)
, x=tf.constant(1)
和y=tf.constant(1)
都会调用同一张静态图. 另外, 传入的每一个python类型也都会构造一个图, 所以最好把training=True
改为training=tf.constant(True)
.
shape
和tensorflow1.x中tf.shape于get_shape()/shape的区别类似, 在tf.function
装饰的函数中, 需要使用tf.shape()
获取tensor的shape, 而不能使用get_shape()
或者shape
. 否则会产生NoneType
错误.
网友评论