What is TFRecord?
TFRecord is the data format officially recommended by Google, designed specifically for TensorFlow.
In practice, a TFRecord is a binary file, which makes better use of memory. Internally it contains a series of tf.train.Example messages, where Example is an implementation of the protocol buffer (protobuf) data standard. Each Example message holds a set of tf.train.Feature attributes, and each Feature is a key-value pair: the key is a string, and the value takes one of three forms:
- bytes_list: stores the string and byte data types.
- float_list: stores the float (float32) and double (float64) data types.
- int64_list: stores bool, enum, int32, uint32, int64, and uint64.
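As a quick illustration (a minimal sketch, assuming TensorFlow is installed), each of the three value types maps directly to a keyword argument of the tf.train.Feature constructor:

```python
import tensorflow as tf

# One tf.train.Feature per value type. Note that BytesList takes
# bytes, so a Python string must be encoded first.
bytes_feature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=["hello".encode()]))
float_feature = tf.train.Feature(
    float_list=tf.train.FloatList(value=[3.14]))
int64_feature = tf.train.Feature(
    int64_list=tf.train.Int64List(value=[1, 2, 3]))

# The stored values remain accessible on the message object.
print(list(int64_feature.int64_list.value))  # [1, 2, 3]
```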
It is worth mentioning that .proto files appear all over the TensorFlow source code. They define TensorFlow's important data structures, and the same data can be used directly from many languages, which is quite powerful.
So what is this protobuf thing?
Advantages:
- Platform-neutral, language-neutral, and extensible;
- Friendly runtime libraries that are simple to use;
- Fast parsing: roughly 20-100x faster than the equivalent XML;
- Very concise, compact serialized data: about 1/3 to 1/10 the size of the XML equivalent.
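The compactness claim is easy to probe with a message type that ships with TensorFlow: tf.train.Example is itself a protobuf message, so we can compare its binary wire format against its human-readable text rendering (a rough sketch, not a rigorous benchmark, and text format is not XML; the exact ratio depends on the message):

```python
import tensorflow as tf
from google.protobuf import text_format

# Build a small Example message.
example = tf.train.Example(
    features=tf.train.Features(feature={
        'age': tf.train.Feature(int64_list=tf.train.Int64List(value=[30])),
        'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'alice'])),
    }))

binary = example.SerializeToString()         # compact binary wire format
text = text_format.MessageToString(example)  # human-readable text format

print(len(binary), len(text))  # the binary encoding is much smaller
```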
Installation guide: https://blog.csdn.net/xxjuanq_only_one/article/details/50465272
```protobuf
import "Common.proto";            // import Common.proto, from the Protobuf SDK
option optimize_for = LITE_RUNTIME;
option java_package = "com.xxxx.entity.pb";  // package of the generated class
option java_outer_classname = "PayInfo";     // name of the generated class

message PayInfo {
    required string payid = 1;    // payment-related field
    optional string goodinfo = 2; // optional: the field may be omitted
    required string prepayid = 3; // required: the field must be set
    optional string mode = 4;
    optional int32 userid = 5;    // protobuf has no plain `int` type; use int32
    repeated string extra = 6;    // repeated: an array/list
}
```

Generate the Java class with:

```shell
protoc --java_out ./ ./PayInfo.proto
```
Why use TFRecord?
TFRecord is not the only data format TensorFlow supports; you can also use CSV, plain text, and other formats. For TensorFlow, however, TFRecord is the friendliest and most convenient. As mentioned above, a TFRecord file is internally a sequence of Example messages implementing the protocol buffer standard, and for large datasets, protobuf-encoded data has clear advantages over the other formats.
Converting to TFRecord
```python
writer = tf.io.TFRecordWriter(out_file_name)  # 1. create a writer
for data in dataes:
    context = data[0]
    question = data[1]
    answer = data[2]
    # 2. build the features
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'context': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=context)),
                'question': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=question)),
                'answer': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=answer)),
            }))
    # 3. serialize the Example and write it out
    writer.write(example.SerializeToString())
writer.close()
```
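Since context, question, and answer are variable-length integer lists, reading them back requires tf.io.VarLenFeature rather than FixedLenFeature. A minimal round-trip sketch (the field name follows the snippet above, the data is made up, and it runs eagerly under TF 2.x):

```python
import tensorflow as tf

# Serialize one Example holding a variable-length int64 feature.
example = tf.train.Example(
    features=tf.train.Features(feature={
        'context': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[1, 2, 3, 4])),
    }))
serialized = example.SerializeToString()

# Parse it back; a VarLenFeature is returned as a SparseTensor.
parsed = tf.io.parse_single_example(
    serialized, features={'context': tf.io.VarLenFeature(tf.int64)})
context = tf.sparse.to_dense(parsed['context'])
print(context)  # the original [1, 2, 3, 4]
```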
The reading API:
https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset
Here's an example 🌰:
```python
from __future__ import absolute_import, division, print_function

import csv
import requests
import tensorflow as tf

# Download Titanic dataset (in csv format).
d = requests.get("https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/titanic_dataset.csv")
with open("titanic_dataset.csv", "wb") as f:
    f.write(d.content)

# Generate Integer Features.
def build_int64_feature(data):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))

# Generate Float Features.
def build_float_feature(data):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[data]))

# Generate String Features (BytesList needs bytes, so encode the str first).
def build_string_feature(data):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data.encode()]))

# Generate a TF `Example`, parsing all features of the dataset.
def convert_to_tfexample(survived, pclass, name, sex, age, sibsp, parch, ticket, fare):
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                'survived': build_int64_feature(survived),
                'pclass': build_int64_feature(pclass),
                'name': build_string_feature(name),
                'sex': build_string_feature(sex),
                'age': build_float_feature(age),
                'sibsp': build_int64_feature(sibsp),
                'parch': build_int64_feature(parch),
                'ticket': build_string_feature(ticket),
                'fare': build_float_feature(fare),
            }))

# Open dataset file.
with open("titanic_dataset.csv") as f:
    # Output TFRecord file.
    with tf.io.TFRecordWriter("titanic_dataset.tfrecord") as w:
        # CSV reader will read and parse all rows.
        reader = csv.reader(f, skipinitialspace=True)
        for i, record in enumerate(reader):
            # Skip header.
            if i == 0:
                continue
            survived, pclass, name, sex, age, sibsp, parch, ticket, fare = record
            # Parse each csv row to TF Example using the above functions.
            example = convert_to_tfexample(int(survived), int(pclass), name, sex,
                                           float(age), int(sibsp), int(parch),
                                           ticket, float(fare))
            # Serialize each TF Example to string, and write to TFRecord file.
            w.write(example.SerializeToString())
```
```python
# Build features template, with types.
features = {
    'survived': tf.io.FixedLenFeature([], tf.int64),
    'pclass': tf.io.FixedLenFeature([], tf.int64),
    'name': tf.io.FixedLenFeature([], tf.string),
    'sex': tf.io.FixedLenFeature([], tf.string),
    'age': tf.io.FixedLenFeature([], tf.float32),
    'sibsp': tf.io.FixedLenFeature([], tf.int64),
    'parch': tf.io.FixedLenFeature([], tf.int64),
    'ticket': tf.io.FixedLenFeature([], tf.string),
    'fare': tf.io.FixedLenFeature([], tf.float32),
}

# Create TensorFlow session.
sess = tf.Session()

# Load TFRecord data.
filenames = ["titanic_dataset.tfrecord"]
data = tf.data.TFRecordDataset(filenames)

# Parse features, using the above template.
def parse_record(record):
    return tf.io.parse_single_example(record, features=features)

# Apply the parsing to each record from the dataset.
data = data.map(parse_record)
# Refill data indefinitely.
data = data.repeat()
# Shuffle data.
data = data.shuffle(buffer_size=1000)
# Batch data (aggregate records together).
data = data.batch(batch_size=4)
# Prefetch batch (pre-load batch for faster consumption).
data = data.prefetch(buffer_size=1)

# Create an iterator over the dataset.
iterator = data.make_initializable_iterator()
# Initialize the iterator.
sess.run(iterator.initializer)
# Get next data batch.
x = iterator.get_next()

# Dequeue data and display.
for i in range(3):
    print(sess.run(x))
    print("")
```
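Note that tf.Session and make_initializable_iterator are TensorFlow 1.x constructs, removed in 2.x. Under TF 2.x eager execution the same map/shuffle/batch/prefetch pipeline is consumed with a plain loop. A hedged sketch of the 2.x style, built on a small in-memory dataset (so it runs without the Titanic file, with a single made-up 'age' feature):

```python
import tensorflow as tf

features = {'age': tf.io.FixedLenFeature([], tf.float32)}

# Build a few serialized Examples in memory instead of reading a file.
def make_example(age):
    return tf.train.Example(features=tf.train.Features(feature={
        'age': tf.train.Feature(float_list=tf.train.FloatList(value=[age])),
    })).SerializeToString()

records = [make_example(a) for a in [22.0, 38.0, 26.0, 35.0]]

# The same pipeline shape as above, iterated directly (TF 2.x eager).
data = (tf.data.Dataset.from_tensor_slices(records)
        .map(lambda r: tf.io.parse_single_example(r, features))
        .shuffle(buffer_size=10)
        .batch(2)
        .prefetch(1))

for batch in data:
    print(batch['age'].numpy())
```

For a real TFRecord file, only the first line changes: replace from_tensor_slices(records) with tf.data.TFRecordDataset(filenames).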