import tensorflow as tf
import numpy as np
def _parse_function(example_proto):
    """Parse one serialized ``tf.train.Example`` into a ``(data, label)`` pair.

    The record must hold two scalar features: ``'images'`` (a byte string)
    and ``'labels'`` (an int64). The image bytes are reinterpreted as
    float32 values with ``tf.decode_raw``.

    Args:
        example_proto: scalar string tensor holding one serialized example,
            as yielded by ``tf.data.TFRecordDataset``.

    Returns:
        ``(data, label)``: a 1-D float32 tensor of the decoded image values
        and the scalar int64 label.
    """
    # NOTE(review): presumably the writer stored float32 arrays as raw bytes
    # (e.g. ``ndarray.tobytes()``); the flat result is never reshaped here,
    # so any original image shape must be restored by the caller — confirm.
    feature_spec = {
        'images': tf.FixedLenFeature((), tf.string),
        'labels': tf.FixedLenFeature((), tf.int64),
    }
    example = tf.parse_single_example(example_proto, feature_spec)
    raw_bytes = example['images']
    label = example['labels']
    return tf.decode_raw(raw_bytes, tf.float32), label
def read_one_batch(pattern='data/*', num_epochs=2, batch_size=32):
    """Read fixed-size batches through a one-shot iterator.

    Every TFRecord file matching ``pattern`` is read, and each record is
    parsed by ``_parse_function``. The whole dataset is repeated
    ``num_epochs`` times and grouped into batches of ``batch_size``.

    Args:
        pattern: glob passed to ``tf.gfile.Glob`` to locate the TFRecord files.
        num_epochs: number of passes over the data (``Dataset.repeat``).
        batch_size: records per batch; the final batch may be smaller.

    Returns:
        The ``(images, labels)`` tensors from ``iterator.get_next()``. Each
        ``sess.run`` of them yields the next batch; ``OutOfRangeError`` is
        raised once the repeated data is exhausted.

    Raises:
        ValueError: if no file matches ``pattern``.
    """
    filenames = tf.gfile.Glob(pattern)
    if not filenames:
        # Without this check an empty file list silently yields an empty
        # dataset, and the first sess.run fails with OutOfRangeError.
        raise ValueError('no TFRecord files match %r' % pattern)
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_function)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    # One-shot iterators need no explicit initialization.
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()
def read_N_batch(pattern='data/*', num_epochs=2, *, return_init=False):
    """Read batches whose size is chosen when the iterator is initialized.

    The batch size is a scalar int64 placeholder consumed by
    ``Dataset.batch``. Its value is fed when the initializable iterator is
    initialized, so the same graph can be re-initialized with a different
    batch size.

    The returned tensors can only be evaluated after the iterator's
    initializer has run with the placeholder fed. Pass ``return_init=True``
    to get both handles::

        next_data, init_op, batch_size = read_N_batch(return_init=True)
        sess.run(init_op, feed_dict={batch_size: 64})
        images, labels = sess.run(next_data)

    Args:
        pattern: glob passed to ``tf.gfile.Glob`` to locate the TFRecord files.
        num_epochs: number of passes over the data (``Dataset.repeat``).
        return_init: if True, also return the iterator initializer and the
            batch-size placeholder.

    Returns:
        With ``return_init=False`` (the default), the ``(images, labels)``
        tensors from ``iterator.get_next()``. With ``return_init=True``, the
        tuple ``(next_data, iterator.initializer, batch_placeholder)``.

    Raises:
        ValueError: if no file matches ``pattern``.
    """
    filenames = tf.gfile.Glob(pattern)
    if not filenames:
        raise ValueError('no TFRecord files match %r' % pattern)
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_function)
    dataset = dataset.repeat(num_epochs)
    batch = tf.placeholder(tf.int64, shape=[])
    dataset = dataset.batch(batch)
    iterator = dataset.make_initializable_iterator()
    next_data = iterator.get_next()
    if return_init:
        # Without these handles the caller cannot initialize the iterator
        # or feed the batch size, so next_data could never be run.
        return next_data, iterator.initializer, batch
    return next_data
def read_diff_batch(train_pattern='data/*', test_pattern='data/*',
                    train_batch_size=32, test_batch_size=16, num_epochs=2):
    """Switch one iterator between a training and a test pipeline.

    Both pipelines parse records with ``_parse_function``, repeat the data
    ``num_epochs`` times and batch it. They differ only in their source
    files and batch size. A reinitializable iterator
    (``Iterator.from_structure``) is shared between them:
    - running ``train_op`` points it at the start of the training data;
    - running ``test_op`` points it at the start of the test data.

    By default both patterns are ``'data/*'``, so "train" and "test" read
    the same files. Pass distinct patterns to use separate splits.

    Args:
        train_pattern: glob for the training TFRecord files.
        test_pattern: glob for the test TFRecord files.
        train_batch_size: records per training batch.
        test_batch_size: records per test batch.
        num_epochs: number of passes over each dataset.

    Returns:
        ``(train_op, test_op, next_data)``, where ``next_data`` is the
        ``(images, labels)`` tensor pair produced by the shared iterator.

    Raises:
        ValueError: if either pattern matches no file.
    """
    def _build(pattern, batch_size):
        # Identical pipeline for both splits; only files and batch size vary.
        filenames = tf.gfile.Glob(pattern)
        if not filenames:
            raise ValueError('no TFRecord files match %r' % pattern)
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(_parse_function)
        dataset = dataset.repeat(num_epochs)  # passes over the whole dataset
        return dataset.batch(batch_size)

    tr_dataset = _build(train_pattern, train_batch_size)
    te_dataset = _build(test_pattern, test_batch_size)
    # The iterator's structure comes from the training dataset. The test
    # dataset must match its types and have compatible shapes; batch() leaves
    # the batch dimension unknown, so different batch sizes are allowed.
    iterator = tf.data.Iterator.from_structure(tr_dataset.output_types,
                                               tr_dataset.output_shapes)
    train_op = iterator.make_initializer(tr_dataset)
    test_op = iterator.make_initializer(te_dataset)
    next_data = iterator.get_next()
    return train_op, test_op, next_data
if __name__ == '__main__':
    # Demo: alternate the shared iterator between the training and test
    # pipelines and print the shape of each image batch.
    # Build the graph before opening the session. read_diff_batch must be
    # *called*; unpacking the bare function object raises TypeError.
    train_op, test_op, next_data = read_diff_batch()
    with tf.Session() as sess:
        for _ in range(2):
            # Point the iterator at the start of the training data...
            sess.run(train_op)
            for _ in range(3):
                print(np.shape(sess.run(next_data)[0]))
            # ...then at the start of the test data.
            sess.run(test_op)
            for _ in range(2):
                print(np.shape(sess.run(next_data)[0]))
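# 网友评论 ("reader comments"): appears to be leftover footer text from the web page this script was taken from; it is not part of the program.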