美文网首页
tf.function让开发者轻松创建TensorFlow静态图

tf.function让开发者轻松创建TensorFlow静态图

作者: LabVIEW_Python | 来源:发表于2021-10-05 16:23 被阅读0次

曾几何时, TensorFlow 1.x那种先构建静态图,然后再运行的方式,让无数人因为编码方式反直觉化,反Pythonic化,而觉得异常“恶心”

import tensorflow as tf
 
# Build a graph.
a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
c = a * b
# 反直觉化,反Pythonic化编码方式,让人感到不舒服
with tf.Session() as sess:
    print sess.run(c)

TensorFlow 2.x后,引入了tf.function(),终于让开发者在获得静态图高性能的情况下,不用再使用session.run()了。

什么是tf.funciton()呢?
tf.funciton通过一个装饰器来对函数进行编译,转化为tensorflow的静态计算图进行计算。让开发者在eager mode下调试代码,然后通过装饰器@tf.function,非常方便的把调试好的代码一键转换为TensorFlow静态图,从而获得执行的高性能。

完整的tf.function性能试验代码:

import tensorflow as tf 
from tensorflow import keras 
import time

# 创建一个模型.
model = keras.Sequential(
    [
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(10),
    ]
)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# 定义一个训练函数
# 在eager mode下,调试函数
# 调试完毕后加入装饰器@tf.function,可以把普通Python函数编译为TensorFlow静态图
# 大大提升执行效率

@tf.function  # 将普通Python函数编译为TensorFlow静态图,大大提高执行效率
def train_on_batch(x, y):
    with tf.GradientTape() as tape:
        logits = model(x)
        loss = loss_fn(y, logits)
        gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return loss


# Prepare a dataset.
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices(
    (x_train.reshape(60000, 784).astype("float32") / 255, y_train)
)
dataset = dataset.shuffle(buffer_size=1024).batch(64)

start = time.time()

for step, (x, y) in enumerate(dataset):   
    loss = train_on_batch(x, y)
    if step % 100 == 0:
        print("Step:", step, "Loss:", float(loss))

end = time.time()
print("Time:", (end-start))

通过注释掉@tf.function可以比较eager mode下和编译为静态图后的执行效率差异。
Eager mode下执行普通Python函数:5.14s
编译为TensorFlow静态图后:1.62s
性能相差接近4倍!

TensorFlow静态图执行性能 eager mode执行性能

相关文章

网友评论

      本文标题:tf.function让开发者轻松创建TensorFlow静态图

      本文链接:https://www.haomeiwen.com/subject/uhsgnltx.html