在学习TF2.0官方文档时的总结,跳转https://www.tensorflow.org/guide/data_performance#top_of_page
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
人工模拟一个数据读取过程:打开文件的时间和读取文件的时间
每个epoch的file打开、读取和训练是串行的,所以整个过程多需时间更长
class ArtificalDataset(tf.data.Dataset):
def _generator(num_samples):
time.sleep(0.03) # 打开文件
for sample_idex in range(num_samples):
time.sleep(0.015) # 读取数据
yield(sample_idex,)
def __new__(cls, num_samples=3):
return tf.data.Dataset.from_generator(
cls._generator,
output_types=tf.dtypes.int64,
output_shapes=(1,),
args=(num_samples,)
)
def benchmark(dataset, num_epochs=2):
start_time = time.perf_counter()
for epoch_num in range(num_epochs):
for sample in dataset:
time.sleep(0.015) # 每个训练
tf.print('Execution time: ', time.perf_counter() - start_time)
benchmark(ArtificalDataset())
对上述的数据方式进行改进,将数据预取和训练步骤的执行进行重叠
tf.data.Dataset.prefetch()
使用了一个后台进程和一个内部的缓存提前从数据集中获取数据,每次获取的数据必须大于等于单个训练步骤所需的数据,可以通过tf.data.experimental.AUTOTUNE
可以自动设置参数。
benchmark(
ArtificalDataset().prefetch(tf.data.experimental.AUTOTUNE)
)
当使用一个管道进行远程读取数据时就会产生I/O瓶颈。tf.data.Dataset.interleave()
可以将数据加载并行化。
cycle_length
参数用用来控制重叠执行的数量,num_paraller_calls
用来控制并行程度
benchmark(
tf.data.Dataset.range(2).interleave(ArtificalDataset)
)
使用参数num_paraller_calls
参数并行加载多个数据集
benchmark(
tf.data.Dataset.range(2).interleave(
ArtificalDataset,
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
)
tf.data.Dataset.map()
可以通过自定义函数对每个数据进行预处理,这一过程也可以并行
def mapped_function(s):
tf.py_function(lambda : time.sleep(0.03), [], ()) # 模仿预处理
return s
benchmark(
ArtificalDataset().map(mapped_function)
)
现在对该过程进行并行
benchmark(
ArtificalDataset().map(
mapped_function,
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
)
tf.data.Dataset.cache()
可以将数据缓存在本地或者内存
benchmark(
ArtificalDataset().map(
mapped_function
).cache(),
5
)
网友评论