Reading batch data in TensorFlow

Author: 乘瓠散人 | Published 2019-11-03 00:27

Goal: each epoch should iterate over every sample once, and the data should be reshuffled for each new epoch.

  • Method 1: using tf.train.batch()
import numpy as np
import tensorflow as tf

def get_batch():
    datasets = np.asarray(range(0, 10))
    # slice_input_producer builds a queue that emits one sample at a time,
    # reshuffling the order on every pass over the data.
    input_queue = tf.train.slice_input_producer([datasets], shuffle=True)
    # batch() groups the queued samples into batches of 4.
    data_batch = tf.train.batch(input_queue, batch_size=4)
    return data_batch

if __name__ == "__main__":
    data_batch = get_batch()
    sess = tf.Session()
    # Initialize the local variables used by the input queue.
    sess.run(tf.local_variables_initializer())
    # Start the queue-runner threads that keep the queue filled.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    for epoch in range(3):
        print(epoch, '----------------')
        total_batch = 3  # ceil(10 samples / batch_size 4)
        for i in range(total_batch):
            data = sess.run([data_batch])
            print(data)

    coord.request_stop()
    coord.join(threads)
    sess.close()

Run results: (screenshot of the printed batches omitted)

As the output shows, each epoch covers roughly all of the samples once, but not strictly: batch_size (4) does not divide the 10 samples evenly, so leftovers from one pass are mixed into the first batch of the next. A variant that lets the queue itself mark the end of the data and keep the leftover samples is sketched below.
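If you would rather have the input pipeline stop on its own after a fixed number of passes, tf.train.slice_input_producer accepts a num_epochs argument and tf.train.batch an allow_smaller_final_batch flag. A minimal sketch under those assumptions (the helper name get_epoch_batches and the OutOfRangeError-driven loop are illustrative, not from the original post):

import numpy as np
import tensorflow as tf

def get_epoch_batches(epochs=3, batch_size=4):
    datasets = np.asarray(range(0, 10))
    # With num_epochs set, the queue closes itself after `epochs` shuffled passes.
    input_queue = tf.train.slice_input_producer(
        [datasets], num_epochs=epochs, shuffle=True)
    # allow_smaller_final_batch=True returns the leftover samples as a smaller
    # final batch instead of dropping them when the queue closes.
    return tf.train.batch(input_queue, batch_size=batch_size,
                          allow_smaller_final_batch=True)

if __name__ == "__main__":
    data_batch = get_epoch_batches()
    with tf.Session() as sess:
        # num_epochs is tracked in a local variable, so this initializer is required.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            while not coord.should_stop():
                print(sess.run(data_batch))
        except tf.errors.OutOfRangeError:
            pass  # all epochs have been consumed
        finally:
            coord.request_stop()
            coord.join(threads)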

  • Method 2: using tf.data.Dataset
import numpy as np
import tensorflow as tf
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

epochs = 3
bs = 4

dataset = tf.data.Dataset.from_tensor_slices(np.asarray(range(0, 10)))
# shuffle() reorders the samples on every pass, batch() groups them into
# batches of size bs, and repeat() replays the whole dataset `epochs` times.
dataset = dataset.shuffle(buffer_size=100).batch(bs).repeat(epochs)
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()

if __name__ == "__main__":

    with tf.Session() as sess:
        for epoch in range(epochs):
            print(epoch, '----------------')
            total_batch = 3
            for i in range(total_batch):
                print(sess.run(data))

Run results: (screenshot of the printed batches omitted)

To drop the final batch when it contains fewer than batch_size samples, change the batching line to:

dataset = dataset.shuffle(buffer_size=100).batch(bs, drop_remainder=True).repeat(epochs)
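Note that with drop_remainder=True there are only two full batches per pass (10 // 4 = 2), so total_batch in the loop above has to drop to 2, otherwise the repeated dataset runs out before the third epoch finishes. For reference, in TensorFlow 2.x the same pipeline can be consumed eagerly, without a Session or an explicit iterator; a minimal sketch assuming TensorFlow 2.x:

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

epochs = 3
bs = 4

dataset = tf.data.Dataset.from_tensor_slices(np.asarray(range(0, 10)))
# reshuffle_each_iteration=True (the default) reshuffles on every pass,
# and drop_remainder=True discards the final, incomplete batch.
dataset = dataset.shuffle(buffer_size=100, reshuffle_each_iteration=True).batch(bs, drop_remainder=True)

for epoch in range(epochs):
    print(epoch, '----------------')
    for batch in dataset:  # one full pass over the dataset per epoch
        print(batch.numpy())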
