背景
本文介绍dataset的常见四种使用方式。废话不多说,直接进入正题。
dataset作为tf官方提供的数据处理api,在模型构建过程中必不可少。本文仅对dataset的迭代器进行介绍,tf官方给出了四种常见的迭代器:one_shot,initializable,reinitializable,feedable。
one_shot iterator就是该迭代器只会将对应的数据消费一次,消费完抛OutOfRangeError,通过外部包一层try-catch来捕获该异常。
initializable iterator在one_shot的基础上可以多次使用该iterator,但是每次消费完数据集后,需要重新初始化,简单来说,实现了单个iterator下单个dataset中填充数据的切换。
reinitializable iterator实现单个iterator下多个dataset的切换。
feedable iterator实现了多个iterator下多个dataset的切换。
通常,我们在训练若干步后进行一次评估,对于这种场景,建议使用feedable iterator进行数据消费,实现的方式可见下文的代码,简单来说,就是同时创建训练数据集的迭代器和测试数据集的迭代器,通过handler来控制何时消费训练数据集,何时消费测试数据集。
文字始终是乏味的,show me the code:
1、导入环境
# import env
import tensorflow as tf
import numpy as np
2、构建数据
# data info : from numpy or from tfrecord
# fake data from numpy: each split is a (features, labels) tuple where
# features ~ Uniform[0, 1) with shape (n, 4) and labels are binary
# integers with shape (n, 1)
def _make_split(n):
    """Return a fake (features, labels) pair with n rows."""
    return (np.random.sample(size=(n, 4)), np.random.randint(2, size=(n, 1)))

train_np = _make_split(100)
test_np = _make_split(20)
valid_np = _make_split(10)
3、one_shot iterator
# Usage 1: one_shot iterator and consume data
# adv: simple -- a one-shot iterator needs no explicit init, use it directly
# disadv: each dataset can only be traversed once; once exhausted,
#         get_next() raises tf.errors.OutOfRangeError
# NOTE(review): the original snippet used train_data/test_data/valid_data
# before any definition; build them here from the numpy splits above.
# Batch size 5 matches the later sections -- adjust if needed.
train_data = tf.data.Dataset.from_tensor_slices(train_np).batch(5)
test_data = tf.data.Dataset.from_tensor_slices(test_np).batch(5)
valid_data = tf.data.Dataset.from_tensor_slices(valid_np).batch(5)

train_iterator = train_data.make_one_shot_iterator()
test_iterator = test_data.make_one_shot_iterator()
valid_iterator = valid_data.make_one_shot_iterator()
next_train = train_iterator.get_next()
next_test = test_iterator.get_next()
next_valid = valid_iterator.get_next()

exit_num = 1  # stop after this many batches from each split


def _consume(sess, next_op, name):
    """Pull up to exit_num batches of next_op, printing the features.

    name labels the printed output ("train"/"test"/"valid").
    """
    i = 0
    while True:
        try:
            sample = sess.run(next_op)
            i += 1
            # BUGFIX: original printed "train_sample:" for every split
            print("%s_sample:" % name, sample[0])
            if exit_num == i:
                break
        except tf.errors.OutOfRangeError:
            print("All %s Sample consumed!" % name.capitalize())
            # BUGFIX: without this break the loop re-runs get_next(),
            # which raises OutOfRangeError again -- an infinite loop
            break


with tf.Session() as sess:
    _consume(sess, next_train, "train")
    _consume(sess, next_test, "test")
    _consume(sess, next_valid, "valid")
4、可初始化iterator
# Usage 2: initializable iterator
# adv : can be reused any number of times; the whole backing data can be
#       swapped via feed_dict each time the iterator is re-initialized
# disadv: must be (re-)initialized before every pass
# example: given train_np and test_np, consume all train data, then
#          re-initialize on test data and consume all of it.
x = tf.placeholder(shape=[None, 4], dtype=tf.float32)
y = tf.placeholder(shape=[None, 1], dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(2)
iterator = dataset.make_initializable_iterator()
next_ele = iterator.get_next()


def _drain(sess, tag, features, labels):
    """Re-initialize the iterator on (features, labels) and consume it all.

    Prints the first batch and every 10th batch, labelled with tag.
    """
    sess.run(iterator.initializer, feed_dict={x: features, y: labels})
    step = 0
    while True:
        try:
            batch = sess.run(next_ele)
        except tf.errors.OutOfRangeError:
            print("Consume All %s Sample!" % tag.capitalize())
            break
        step += 1
        if step % 10 == 0 or step == 1:
            print("%s step:%d,ele:%s" % (tag, step, str(batch[0])))


with tf.Session() as sess:
    # consume all train data, then all test data
    _drain(sess, "train", train_np[0], train_np[1])
    _drain(sess, "test", test_np[0], test_np[1])
5、可重复初始化iterator
# Usage 3: reinitializable iterator
# adv: similar to usage 2; the difference is that a single iterator is
#      switched between several pre-built datasets (no feed_dict needed)
# disadv: still only one iterator
# first construct more than one dataset
train_data = tf.data.Dataset.from_tensor_slices(train_np).batch(5)
test_data = tf.data.Dataset.from_tensor_slices(test_np).batch(5)
valid_data = tf.data.Dataset.from_tensor_slices(valid_np).batch(5)

# one structure-typed iterator shared by all three datasets
iterator = tf.data.Iterator.from_structure(
    output_shapes=train_data.output_shapes,
    output_types=train_data.output_types)
next_ele = iterator.get_next()

# one init op per dataset; running an op points the iterator at that dataset
train_op = iterator.make_initializer(train_data)
test_op = iterator.make_initializer(test_data)
valid_op = iterator.make_initializer(valid_data)

with tf.Session() as sess:
    # switch the single iterator across the three datasets in turn
    for init_op, tag in ((train_op, "Train"),
                         (test_op, "Test"),
                         (valid_op, "Valid")):
        sess.run(init_op)
        batch = sess.run(next_ele)
        print("%s Sample:" % tag, batch[0])
    # conclusion: re-initializing on an already-consumed dataset restarts
    # it from the beginning rather than resuming where the last pass stopped
    sess.run(train_op)
    batch2 = sess.run(next_ele)
    print("Train Sample:", batch2[0])
6、feedable iterator
# Usage 4: feedable iterator
# adv: similar to usage 3, but the switch happens between *iterators*
#      (each possibly of a different kind), not between datasets
# similar to usage 3: construct the datasets
train_np = (np.random.sample(size=(100, 4)), np.random.randint(2, size=(100, 1)))
test_np = (np.random.sample(size=(20, 4)), np.random.randint(2, size=(20, 1)))
valid_np = (np.random.sample(size=(10, 4)), np.random.randint(2, size=(10, 1)))
valid_np2 = (np.random.sample(size=(10, 4)), np.random.randint(2, size=(10, 1)))
train_data = tf.data.Dataset.from_tensor_slices(train_np).batch(5)
test_data = tf.data.Dataset.from_tensor_slices(test_np).batch(5)
valid_data = tf.data.Dataset.from_tensor_slices(valid_np).batch(5)
valid_data2 = tf.data.Dataset.from_tensor_slices(valid_np2).batch(5)

# construct more than one iterator, one of each earlier kind
train_iterator = train_data.make_one_shot_iterator()          # usage 1
test_iterator = test_data.make_initializable_iterator()       # usage 2
valid_iterator = tf.data.Iterator.from_structure(             # usage 3
    output_shapes=valid_data.output_shapes,
    output_types=valid_data.output_types)

# init ops used for choosing a specific dataset for the usage-3 iterator
valid_op1 = valid_iterator.make_initializer(valid_data)
valid_op2 = valid_iterator.make_initializer(valid_data2)

# the string handle placeholder selects which iterator feeds next_ele
handler = tf.placeholder(shape=[], dtype=tf.string, name='handler')
iterator = tf.data.Iterator.from_string_handle(
    string_handle=handler,
    output_shapes=train_iterator.output_shapes,
    output_types=train_iterator.output_types)
next_ele = iterator.get_next()


def _feed_once(sess, handle_value, tag):
    """Fetch one batch through the feedable iterator and print it."""
    batch = sess.run(next_ele, feed_dict={handler: handle_value})
    print("%s Sample:" % tag, batch[0])


with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    # train: the one-shot iterator needs no explicit init
    _feed_once(sess, sess.run(train_iterator.string_handle()), "Train")
    # test: the initializable iterator must be initialized first
    sess.run(test_iterator.initializer)
    _feed_once(sess, sess.run(test_iterator.string_handle()), "Test")
    # valid: the reinitializable iterator -- point it at valid_data2 first
    sess.run(valid_op2)
    _feed_once(sess, sess.run(valid_iterator.string_handle()), "Valid")
    # switching back to train continues where the one-shot iterator left off
    _feed_once(sess, sess.run(train_iterator.string_handle()), "Train")
    # test again WITHOUT re-running the initializer: the iterator resumes
    # from its previous position instead of restarting
    # sess.run(test_iterator.initializer)
    _feed_once(sess, sess.run(test_iterator.string_handle()), "Test")
参考文献:
1、How to use Dataset in TensorFlow
2、tf官方文档
网友评论