Four Ways to Use tensorflow Dataset Iterators

Author: 我爱吃海鲜 | Published 2019-02-02 17:44

Background

This article covers the four common ways of consuming data with dataset. Without further ado, let's get straight to it.

dataset is the official tf API for data processing and is indispensable when building models. This article discusses only dataset iterators; tf officially provides four common kinds: one_shot, initializable, reinitializable, and feedable.

A one_shot iterator consumes its data exactly once. Once the data is exhausted it raises OutOfRangeError, which is caught by wrapping the consuming loop in a try/except.

An initializable iterator builds on one_shot by allowing the iterator to be used multiple times, but it must be re-initialized after each pass over the data. In short, it lets a single iterator over a single dataset swap in different feed data.

A reinitializable iterator lets a single iterator switch between multiple datasets.

A feedable iterator lets you switch between multiple iterators over multiple datasets.

Typically, we evaluate once every few training steps. For this scenario, a feedable iterator is the recommended way to consume data. The implementation is shown in the code below; in short, you create an iterator for the training set and another for the test set, and use a string handle to control when the training data is consumed and when the test data is consumed.

Words are dull, show me the code:

1. Import the environment

# import env
import tensorflow as tf
import numpy as np

2. Build the data

# data info: from numpy or from tfrecord
# fake data from numpy
train_np = (np.random.sample(size=(100, 4)), np.random.randint(2, size=(100, 1)))
test_np  = (np.random.sample(size=(20, 4)), np.random.randint(2, size=(20, 1)))
valid_np = (np.random.sample(size=(10, 4)), np.random.randint(2, size=(10, 1)))

3. one_shot iterator

# Usage 1: one_shot iterator and consume data
# adv: simple, a one_shot iterator needs no init, use it directly
# disadv: can only be consumed once

# build the datasets from the numpy data above
train_data = tf.data.Dataset.from_tensor_slices(train_np).batch(5)
test_data  = tf.data.Dataset.from_tensor_slices(test_np).batch(5)
valid_data = tf.data.Dataset.from_tensor_slices(valid_np).batch(5)

train_iterator = train_data.make_one_shot_iterator()
test_iterator  = test_data.make_one_shot_iterator()
valid_iterator = valid_data.make_one_shot_iterator()

next_train = train_iterator.get_next()
next_test  = test_iterator.get_next()
next_valid = valid_iterator.get_next()

exit_num = 1

with tf.Session() as sess:
    # consume train sample
    i = 0
    while True:
        try:
            sample = sess.run(next_train)
            i += 1
            print("train_sample:", sample[0])
            if exit_num == i:
                break
        except tf.errors.OutOfRangeError:
            print("All Train Samples consumed!")
            break

    # consume test sample
    i = 0
    while True:
        try:
            sample = sess.run(next_test)
            i += 1
            print("test_sample:", sample[0])
            if exit_num == i:
                break
        except tf.errors.OutOfRangeError:
            print("All Test Samples consumed!")
            break

    # consume valid sample
    i = 0
    while True:
        try:
            sample = sess.run(next_valid)
            i += 1
            print("valid_sample:", sample[0])
            if exit_num == i:
                break
        except tf.errors.OutOfRangeError:
            print("All Valid Samples consumed!")
            break

4. Initializable iterator

# Usage 2: initializable iterator
# adv: can be used any number of times, supports swapping the whole dataset
# disadv: must be initialized before use
# example: given train_data and test_data, we can consume all of train_data and then consume test_data.

x = tf.placeholder(shape=[None, 4], dtype=tf.float32)
y = tf.placeholder(shape=[None, 1], dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(2)
iterator = dataset.make_initializable_iterator()
next_ele = iterator.get_next()

with tf.Session() as sess:
    # consume all train data
    sess.run(iterator.initializer, feed_dict={x: train_np[0], y: train_np[1]})
    i = 0
    while True:
        try:
            ele = sess.run(next_ele)
            i += 1
            if i % 10 == 0 or i == 1:
                print("train step:%d, ele:%s" % (i, str(ele[0])))
        except tf.errors.OutOfRangeError:
            print("Consumed All Train Samples!")
            break

    # consume all test data
    sess.run(iterator.initializer, feed_dict={x: test_np[0], y: test_np[1]})
    i = 0
    while True:
        try:
            ele = sess.run(next_ele)
            i += 1
            if i % 10 == 0 or i == 1:
                print("test step:%d, ele:%s" % (i, str(ele[0])))
        except tf.errors.OutOfRangeError:
            print("Consumed All Test Samples!")
            break

5. Reinitializable iterator

# Usage 3: reinitializable iterator
# adv: similar to usage 2; the difference is that one iterator switches between several datasets
# disadv: still only one iterator

# first construct more than one dataset
train_data = tf.data.Dataset.from_tensor_slices(train_np).batch(5)
test_data  = tf.data.Dataset.from_tensor_slices(test_np).batch(5)
valid_data = tf.data.Dataset.from_tensor_slices(valid_np).batch(5)

# construct one iterator
iterator = tf.data.Iterator.from_structure(output_shapes=train_data.output_shapes, output_types=train_data.output_types)
next_ele = iterator.get_next()
train_op = iterator.make_initializer(train_data)
test_op  = iterator.make_initializer(test_data)
valid_op = iterator.make_initializer(valid_data)

with tf.Session() as sess:
    # use train data and consume
    sess.run(train_op)
    ele = sess.run(next_ele)
    print("Train Sample:", ele[0])

    # use test data and consume
    sess.run(test_op)
    ele = sess.run(next_ele)
    print("Test Sample:", ele[0])

    # use valid data and consume
    sess.run(valid_op)
    ele = sess.run(next_ele)
    print("Valid Sample:", ele[0])

    # test:
    # conclusion: re-initializing on an already-consumed dataset does not resume where it left off;
    # consumption restarts from the beginning
    sess.run(train_op)
    ele2 = sess.run(next_ele)
    print("Train Sample:", ele2[0])

6. Feedable iterator

# Usage 4: feedable iterator
# adv: similar to usage 3; the difference is that there is more than one iterator, and we switch between iterators, not datasets

# similar to usage 3, construct the datasets
train_np = (np.random.sample(size=(100, 4)), np.random.randint(2, size=(100, 1)))
test_np  = (np.random.sample(size=(20, 4)), np.random.randint(2, size=(20, 1)))
valid_np = (np.random.sample(size=(10, 4)), np.random.randint(2, size=(10, 1)))
valid_np2 = (np.random.sample(size=(10, 4)), np.random.randint(2, size=(10, 1)))

train_data = tf.data.Dataset.from_tensor_slices(train_np).batch(5)
test_data  = tf.data.Dataset.from_tensor_slices(test_np).batch(5)
valid_data = tf.data.Dataset.from_tensor_slices(valid_np).batch(5)
valid_data2 = tf.data.Dataset.from_tensor_slices(valid_np2).batch(5)

# construct more than one iterator
train_iterator = train_data.make_one_shot_iterator()  # usage 1
test_iterator  = test_data.make_initializable_iterator()  # usage 2
valid_iterator = tf.data.Iterator.from_structure(output_shapes=valid_data.output_shapes, output_types=valid_data.output_types)  # usage 3

# used for choosing a specific dataset
valid_op1 = valid_iterator.make_initializer(valid_data)
valid_op2 = valid_iterator.make_initializer(valid_data2)

# used for choosing a specific iterator
handler = tf.placeholder(shape=[], dtype=tf.string, name='handler')
iterator = tf.data.Iterator.from_string_handle(string_handle=handler, output_shapes=train_iterator.output_shapes, output_types=train_iterator.output_types)
next_ele = iterator.get_next()

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    # consume train data
    # get the train iterator handle (a one_shot iterator needs no init)
    train_hd = sess.run(train_iterator.string_handle())
    ele = sess.run(next_ele, feed_dict={handler: train_hd})
    print("Train Sample:", ele[0])

    # consume test data
    # first init the test iterator, then get its handle
    sess.run(test_iterator.initializer)
    test_hd = sess.run(test_iterator.string_handle())
    ele = sess.run(next_ele, feed_dict={handler: test_hd})
    print("Test Sample:", ele[0])

    # consume valid data
    # first init the valid iterator, then get its handle
    sess.run(valid_op2)
    valid_hd = sess.run(valid_iterator.string_handle())
    ele = sess.run(next_ele, feed_dict={handler: valid_hd})
    print("Valid Sample:", ele[0])

    # consume train data again
    # switching back to the train handle resumes where that iterator left off
    train_hd = sess.run(train_iterator.string_handle())
    ele = sess.run(next_ele, feed_dict={handler: train_hd})
    print("Train Sample:", ele[0])

    # consume test data again
    # no re-init this time, so the test iterator also resumes
#    sess.run(test_iterator.initializer)
    test_hd = sess.run(test_iterator.string_handle())
    ele = sess.run(next_ele, feed_dict={handler: test_hd})
    print("Test Sample:", ele[0])

