背景:大家在使用tensorflow 训练model 的时候,如何更好更快的加载数据,tensorflow官方给出了tf record这种格式,这种数据格式无论低级别api/高级别api都可以加载。
1.什么是tf record,tf record是一种基于protobuffer的数据格式(对protobuffer,具体如下图所示,每一个训练样本就是一个Example这个结构体。我们只需要将之前每一个训练样本里面的多个特征,都按照下述格式存储起来即可。
tf record pb2.如何产出,比如我这里有一个训练样本,有四个特征分别是姓名(name),点击序列(click_list),ctr,week_num。下面就用几种方式来介绍如何生成tf_record. 第一种是使用tf 的api(这里采用tf2.0)。tf.train.Example.代码如下所示,这样就没有采用python 处理proto的语法格式直接来产出的。不需要引入proto 编译产出的xxxpb2.py 只需要import tensorflow as tf。
tf record 产出方法1产出方法2:采用python处理proto的一些方法。首先要生成 proto对应的python 类。 protoc tf_feature.proto --python_out=./
然后import 一下各类。采用python 原有api 去处理。
tfrecord 生产method2产出方法3,本来可以采用json_format parse,但是bytes_list 这个格式,老是报错。这里不过多介绍了。
如何读取:两种方法
(1)读成dict key是特征名字,value 是tensor。需要定义一个feature_schema解释每一个特征是fixedlen还是val len。 然后用tf.data.dataset 读入。结果如图2.
读取方法1 结果(2)直接采用pb的序列化反序列化。
读取方式2当然在实际生产项目中,我们需要采用mr/spark 道理是一样的。
网友评论