美文网首页
tfrecord文件读写

tfrecord文件读写

作者: 乘瓠散人 | 来源:发表于2019-11-04 22:46 被阅读0次
  • 将数据集保存为tfrecord文件
# save numpy array as tfrecord
def save_tfrecords(datas, labels, outfile):
    """Write paired samples to a TFRecord file, one ``tf.train.Example`` each.

    Every sample is cast to float32 and stored as raw bytes under the
    ``"data"`` and ``"label"`` keys. The first two dimensions of the data
    array are stored under ``"shape"`` so a reader can restore its layout.

    Args:
        datas: Indexable collection of NumPy arrays. ``shape[0]`` and
            ``shape[1]`` of every element are read, so each array needs at
            least two dimensions; only 2-D arrays round-trip losslessly
            because just two shape entries are recorded.
        labels: Collection of NumPy arrays indexed in step with ``datas``.
        outfile: Path of the TFRecord file to create.
    """
    with tf.python_io.TFRecordWriter(outfile) as writer:
        for i, data in enumerate(datas):
            label = labels[i]
            # tobytes() yields the same bytes as the deprecated tostring()
            # alias, which newer NumPy releases no longer provide.
            data_bytes = data.astype(np.float32).tobytes()
            label_bytes = label.astype(np.float32).tobytes()
            features = tf.train.Features(
                feature={
                    "data": tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[data_bytes])),
                    "label": tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[label_bytes])),
                    "shape": tf.train.Feature(
                        int64_list=tf.train.Int64List(
                            value=[data.shape[0], data.shape[1]])),
                }
            )
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())

train_record = './train.tfrecord'
valid_record = './valid.tfrecord'

# Regenerate both files when either one is missing. Checking only the training
# file would leave a missing validation file unwritten, and writing the two
# together keeps the train/valid split taken from a single prepare() call.
if not (os.path.exists(train_record) and os.path.exists(valid_record)):
    datas, labels = prepare()  # function of preparing your dataset
    n = 1000  # number of trailing samples held out for validation
    save_tfrecords(datas[:-n], labels[:-n], train_record)
    save_tfrecords(datas[-n:], labels[-n:], valid_record)

  • 使用TensorFlow读取tfrecord文件用于网络训练
# read numpy array from tfrecord
def _parse_function(example_proto):
    """Decode one serialized ``tf.train.Example`` into float32 tensors.

    Intended as the callable passed to ``Dataset.map`` on a
    ``TFRecordDataset``, which supplies a single serialized record per call;
    the function therefore takes exactly one argument.

    Args:
        example_proto: Scalar string tensor holding one serialized Example
            with the ``"data"``, ``"label"`` and ``"shape"`` features.

    Returns:
        A ``(data, label)`` tuple of float32 tensors, both reshaped to the
        two-element shape stored in the record.
    """
    features = {"data": tf.FixedLenFeature((), tf.string),
                "label": tf.FixedLenFeature((), tf.string),
                "shape": tf.FixedLenFeature([2], tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, features)
    shape = parsed_features['shape']
    # The raw bytes carry no dtype or layout, so decode as float32 and
    # restore the layout from the stored shape.
    data = tf.decode_raw(parsed_features['data'], tf.float32)
    label = tf.decode_raw(parsed_features['label'], tf.float32)
    # NOTE(review): the label is reshaped with the *data* shape, which assumes
    # labels have the same dimensions as their data arrays -- confirm.
    return tf.reshape(data, shape), tf.reshape(label, shape)

if __name__ == "__main__":
    # Train and validate from TFRecord files, switching between the two input
    # pipelines through a single feedable iterator.
    #
    # NOTE(review): model_function, the `data` / `label` placeholders and
    # `train_steps` / `valid_steps` are not defined in this snippet; they must
    # be provided elsewhere (presumably steps = number of records // bs --
    # confirm against the dataset size).
    epochs = 100
    bs = 32

    train_dataset = tf.data.TFRecordDataset('./train.tfrecord')  # load tfrecord file
    train_dataset = train_dataset.map(_parse_function)  # parse data into tensor
    train_dataset = train_dataset.shuffle(buffer_size=1000).batch(bs, drop_remainder=True).repeat(epochs)

    valid_dataset = tf.data.TFRecordDataset('./valid.tfrecord')
    valid_dataset = valid_dataset.map(_parse_function)
    valid_dataset = valid_dataset.batch(bs, drop_remainder=True).repeat(epochs)

    # make two handles of iterator to process trainset and validset separately
    train_handle = train_dataset.make_one_shot_iterator().string_handle()
    valid_handle = valid_dataset.make_one_shot_iterator().string_handle()

    # A feedable iterator: the string fed into `handle` selects which
    # underlying iterator get_next() pulls the next batch from.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes
    )
    x, y = iterator.get_next()

    loss = model_function()
    train_op = tf.train.AdamOptimizer(0.001).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # The handle tensors must be evaluated to concrete strings once before
        # they can be fed through the `handle` placeholder.
        handle_train, handle_valid = sess.run([train_handle, valid_handle])
        for epoch in range(epochs):
            print(epoch, '----------------')
            # training
            train_loss = []
            for b in range(train_steps):
                x_b, y_b = sess.run([x, y], feed_dict={handle: handle_train})
                _, t_loss = sess.run([train_op, loss], feed_dict={data: x_b, label: y_b})
                train_loss.append(t_loss)
            # evaluation
            valid_loss = []
            for b in range(valid_steps):
                x_v, y_v = sess.run([x, y], feed_dict={handle: handle_valid})
                # Fetch `loss` directly rather than as [loss], so the value
                # appended is a scalar like t_loss, not a one-element list.
                e_loss = sess.run(loss, feed_dict={data: x_v, label: y_v})
                valid_loss.append(e_loss)


参考文章:
Tensorflow数据读写:Numpy存储为TFRecord文件与读取

相关文章

网友评论

      本文标题:tfrecord文件读写

      本文链接:https://www.haomeiwen.com/subject/gfsjbctx.html