
A brief introduction to importing data with Dataset

Author: 灿烂的GL | Published 2018-04-16 20:04

Dataset serves as the interface between the data and the model.

1. Creating a dataset

From tensors in memory: tf.data.Dataset.from_tensors() or tf.data.Dataset.from_tensor_slices()

From TFRecord files stored on disk: tf.data.TFRecordDataset

2. Transforming the data

Per-element transformations, applied to each element individually: e.g. Dataset.map()

Multi-element transformations: e.g. Dataset.batch()

3. Importing different kinds of data

Example:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types)  # ==> "tf.float32"
print(dataset1.output_shapes)  # ==> "(10,)"

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random_uniform([4]),
    tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes)  # ==> "((), (100,))"

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"

dataset = tf.data.Dataset.from_tensor_slices(
   {"a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

4. The different iterator types:

one-shot,

initializable,

reinitializable, and

feedable.

Reference: a detailed introduction to Dataset

An advantage of using an iterator here is that when the dataset is exhausted, get_next() raises tf.errors.OutOfRangeError; catching it ends the input loop directly, so there is no need to hand-manage the usual TF `while True` loop bookkeeping.

5. Importing and processing TFRecord data

Example:

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # Parse the record into tensors.
dataset = dataset.repeat()  # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()

# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})

Randomly shuffling the input data is also supported:

Example:

def dataset_input_fn():
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  # Use `tf.parse_single_example()` to extract data from a `tf.Example`
  # protocol buffer, and perform any additional per-record preprocessing.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=0),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Perform additional preprocessing on the parsed data.
    image = tf.image.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)
    return {"image_data": image, "date_time": parsed["date_time"]}, label

  # Use `Dataset.map()` to build a pair of a feature dictionary and a label
  # tensor for each example.
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)  # `num_epochs` must be defined by the caller
  iterator = dataset.make_one_shot_iterator()

  # `features` is a dictionary in which each value is a batch of values for
  # that feature; `labels` is a batch of labels.
  features, labels = iterator.get_next()
  return features, labels

One question remains here: what is the difference between decoding TFRecords directly and training via TFRecord + Dataset? This will be filled in later. For applying this within a framework, see: combining several high-level APIs.
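As one sketch of such a combination (assuming TensorFlow 1.x, and using made-up in-memory toy data in place of real TFRecord files), an input function built on `tf.data` can be passed straight to a canned Estimator:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x

# Hypothetical toy data standing in for a real TFRecord pipeline.
def train_input_fn():
    xs = np.array([[1.0], [2.0], [3.0], [4.0]], dtype=np.float32)
    ys = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)
    dataset = tf.data.Dataset.from_tensor_slices(({"x": xs}, ys))
    return dataset.shuffle(4).repeat().batch(2)

feature_columns = [tf.feature_column.numeric_column("x", shape=[1])]
estimator = tf.estimator.LinearRegressor(feature_columns=feature_columns)
estimator.train(input_fn=train_input_fn, steps=20)
```

The Estimator calls `train_input_fn` itself, builds the iterator internally, and runs the end-of-data loop on our behalf.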


Original link: https://www.haomeiwen.com/subject/pjixhftx.html