Goal: each epoch should visit every sample, and the data should be reshuffled between epochs.
- Method 1: using `tf.train.batch()`

```python
import numpy as np
import tensorflow as tf

def get_batch():
    datasets = np.asarray(range(0, 10))
    # slice_input_producer reshuffles the samples on every pass through the data
    input_queue = tf.train.slice_input_producer([datasets], shuffle=True)
    data_batch = tf.train.batch(input_queue, batch_size=4)
    return data_batch

if __name__ == "__main__":
    data_batch = get_batch()
    sess = tf.Session()
    # queue runners rely on local variables (the epoch counter)
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    for epoch in range(3):
        print(epoch, '----------------')
        total_batch = 3
        for i in range(total_batch):
            data = sess.run([data_batch])
            print(data)
    coord.request_stop()
    coord.join(threads)
    sess.close()
```
Output:
![](https://img.haomeiwen.com/i5551994/e07d7d0ef5d8c456.png)
As the output shows, each epoch covers essentially all samples, but not strictly: 10 samples do not divide evenly into batches of 4, so a batch can straddle the epoch boundary.
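Why the boundary blurs can be illustrated with a pure-NumPy toy model (an approximation for illustration, not TF code): the queue is continuously refilled with a fresh shuffle of all samples, and fixed-size batches are cut from that endless stream, so with 10 % 4 = 2 the third batch mixes the two leftovers with samples from the next shuffle.

```python
import numpy as np

num_samples, batch_size, epochs = 10, 4, 3
rng = np.random.default_rng(0)

# The queue behaves like a stream that is refilled with a fresh
# shuffle of all samples each time it runs dry.
stream = np.concatenate(
    [rng.permutation(num_samples) for _ in range(epochs + 1)])

# Fixed-size batches are cut from the stream with no regard for
# where one shuffle ends and the next begins.
batches = [stream[i * batch_size:(i + 1) * batch_size]
           for i in range(epochs * num_samples // batch_size)]

# Batches 0-1 use 8 of the 10 samples; batch 2 holds the 2 leftovers
# plus 2 samples already drawn from the next epoch's shuffle.
for b in batches[:3]:
    print(b)
```

Every batch has exactly `batch_size` elements; it is the epoch boundary, not the batch size, that gives way.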
- Method 2: using `tf.data.Dataset`

```python
import numpy as np
import tensorflow as tf
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

epochs = 3
bs = 4

dataset = tf.data.Dataset.from_tensor_slices(np.asarray(range(0, 10)))
# shuffle() reshuffles on each repetition by default (reshuffle_each_iteration=True)
dataset = dataset.shuffle(buffer_size=100).batch(bs).repeat(epochs)
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()

if __name__ == "__main__":
    with tf.Session() as sess:
        for epoch in range(epochs):
            print(epoch, '----------------')
            total_batch = 3
            for i in range(total_batch):
                print(sess.run(data))
```
Output:
![](https://img.haomeiwen.com/i5551994/82c6d56fdbabf83a.png)
To discard the final batch when it has fewer than `batch_size` elements, change the pipeline to:

```python
dataset = dataset.shuffle(buffer_size=100).batch(bs, drop_remainder=True).repeat(epochs)
```
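The semantics of this shuffle → batch(drop_remainder=True) → repeat pipeline can be sketched in plain NumPy (an illustrative approximation, not the tf.data implementation): each epoch gets an independent reshuffle, and the 10 % 4 = 2 leftover samples are dropped, so every emitted batch has exactly `bs` elements.

```python
import numpy as np

epochs, bs = 3, 4
data = np.asarray(range(0, 10))
rng = np.random.default_rng(42)

for epoch in range(epochs):
    shuffled = rng.permutation(data)   # fresh shuffle each epoch
    n_full = len(shuffled) // bs       # drop_remainder: keep only full batches
    for i in range(n_full):
        print(epoch, shuffled[i * bs:(i + 1) * bs])
```

The trade-off: batches never straddle epochs and always have a fixed shape, but 2 of the 10 samples are skipped in every epoch (a different 2 each time, thanks to the reshuffle).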