# save numpy array as tfrecord (TensorFlow 1.x API)
import os

import numpy as np
import tensorflow as tf

def save_tfrecords(datas, labels, outfile):
    with tf.python_io.TFRecordWriter(outfile) as writer:
        for i in range(len(datas)):
            features = tf.train.Features(
                feature={
                    # raw bytes of the float32-encoded array
                    # (.tobytes() is the non-deprecated spelling of .tostring())
                    "data": tf.train.Feature(bytes_list=tf.train.BytesList(
                        value=[datas[i].astype(np.float32).tobytes()])),
                    "label": tf.train.Feature(bytes_list=tf.train.BytesList(
                        value=[labels[i].astype(np.float32).tobytes()])),
                    # the 2-D shape, stored so the array can be restored on read
                    "shape": tf.train.Feature(int64_list=tf.train.Int64List(
                        value=[datas[i].shape[0], datas[i].shape[1]]))
                }
            )
            example = tf.train.Example(features=features)
            serialized = example.SerializeToString()
            writer.write(serialized)
train_record = './train.tfrecord'
valid_record = './valid.tfrecord'

if not os.path.exists(train_record):
    datas, labels = prepare()  # your own function that returns the full dataset
    n = 1000  # number of examples held out for validation
    save_tfrecords(datas[:-n], labels[:-n], train_record)
    save_tfrecords(datas[-n:], labels[-n:], valid_record)
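
To exercise the write path end to end, `prepare` can be stubbed with random data. A minimal sketch, assuming 2-D float32 examples (the 5000-example count and the 64x64 shape are arbitrary choices, not from the original):

def prepare():
    # hypothetical stub: 5000 random 64x64 arrays, with labels of the same shape
    num, h, w = 5000, 64, 64
    datas = np.random.rand(num, h, w).astype(np.float32)
    labels = np.random.rand(num, h, w).astype(np.float32)
    return datas, labels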
- Reading the TFRecord file with TensorFlow for network training
# read numpy array from tfrecord
# (plain function, not a method: dataset.map passes only the serialized example)
def _parse_function(example_proto):
    features = {"data": tf.FixedLenFeature((), tf.string),
                "label": tf.FixedLenFeature((), tf.string),
                "shape": tf.FixedLenFeature([2], tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, features)
    shape = parsed_features['shape']
    # decode the raw bytes back into float32 tensors and restore the 2-D shape
    data = tf.decode_raw(parsed_features['data'], tf.float32)
    label = tf.decode_raw(parsed_features['label'], tf.float32)
    return tf.reshape(data, shape), tf.reshape(label, shape)
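
A written file can also be sanity-checked record by record through the protobuf API, without building a graph or starting a session. A small sketch using the `./train.tfrecord` path from above:

# peek at the first serialized record
raw = next(tf.python_io.tf_record_iterator('./train.tfrecord'))
example = tf.train.Example()
example.ParseFromString(raw)
print(example.features.feature['shape'].int64_list.value)  # e.g. [64, 64]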
if __name__ == "__main__":
    epochs = 100
    bs = 32  # batch size
    train_dataset = tf.data.TFRecordDataset('./train.tfrecord')  # load tfrecord file
    train_dataset = train_dataset.map(_parse_function)  # parse each record into tensors
    train_dataset = train_dataset.shuffle(buffer_size=1000).batch(bs, drop_remainder=True).repeat(epochs)
    valid_dataset = tf.data.TFRecordDataset('./valid.tfrecord')
    valid_dataset = valid_dataset.map(_parse_function)
    valid_dataset = valid_dataset.batch(bs, drop_remainder=True).repeat(epochs)

    # batches per epoch, obtained by counting the records in each file
    num_train = sum(1 for _ in tf.python_io.tf_record_iterator('./train.tfrecord'))
    num_valid = sum(1 for _ in tf.python_io.tf_record_iterator('./valid.tfrecord'))
    train_steps = num_train // bs
    valid_steps = num_valid // bs

    # create two string handles so a single feedable iterator can switch
    # between the training set and the validation set
    train_handle = train_dataset.make_one_shot_iterator().string_handle()
    valid_handle = valid_dataset.make_one_shot_iterator().string_handle()
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes
    )
    x, y = iterator.get_next()

    # model inputs; model_function is your own network and must return a scalar loss
    data = tf.placeholder(tf.float32, shape=[None, None, None])   # batch of 2-D examples
    label = tf.placeholder(tf.float32, shape=[None, None, None])
    loss = model_function(data, label)
    train_op = tf.train.AdamOptimizer(0.001).minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # the handle tensors must be evaluated once so their string values
        # can be fed through the `handle` placeholder below
        handle_train, handle_valid = sess.run([train_handle, valid_handle])
        for epoch in range(epochs):
            print(epoch, '----------------')
            # training
            train_loss = []
            for b in range(train_steps):
                x_b, y_b = sess.run([x, y], feed_dict={handle: handle_train})
                _, t_loss = sess.run([train_op, loss], feed_dict={data: x_b, label: y_b})
                train_loss.append(t_loss)
            # evaluation
            valid_loss = []
            for b in range(valid_steps):
                x_v, y_v = sess.run([x, y], feed_dict={handle: handle_valid})
                e_loss = sess.run(loss, feed_dict={data: x_v, label: y_v})
                valid_loss.append(e_loss)
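
Fetching each batch as NumPy arrays and feeding it back through the `data`/`label` placeholders costs an extra host round trip per step. A common variant (my sketch, not the original author's code) builds the model directly on the iterator output `x, y`, so a single `sess.run` drives the input pipeline and the update together:

# hypothetical variant, reusing the datasets and handles built above
loss = model_function(x, y)  # same user-defined model, now fed by the iterator
train_op = tf.train.AdamOptimizer(0.001).minimize(loss)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    handle_train = sess.run(train_handle)
    for b in range(train_steps):
        _, t_loss = sess.run([train_op, loss], feed_dict={handle: handle_train})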