美文网首页
tf.data.Dataset 属性及方法

tf.data.Dataset 属性及方法

作者: butters001 | 来源:发表于2020-06-11 11:51 被阅读0次

    大型元素集

    • 源数据集 (Source Datasets)

      创建数据集的最简单方法 list:

      dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
      for element in dataset:
        print(element)
      

      处理文本文件:

      dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
      

      处理TFRecord文件

      dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
      

      创建一个匹配规则的所有文件的数据集

      dataset = tf.data.Dataset.list_files("/path/*.txt")  # doctest: +SKIP
      
    • 转换 (Transformations)

      有了数据集后,您可以对准备的数据进行转换

      dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
      dataset = dataset.map(lambda x: x*2)
      list(dataset.as_numpy_iterator())
      # 输出 [2, 4, 6]
      
    • 小 tips

      数据集里的元素说明 element_spec

      dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).element_spec
      TensorSpec(shape=(), dtype=tf.int32, name=None)
      

    方法

    • apply

      apply(
          transformation_func
      )
      参数:transformation_func 一个方法名(此方法接收一个dataset参数 并返回处理后的dataset)
      返回值:apply的返回值 即参数 transformation_func 方法的返回值
      

      将转换函数应用于此数据集

      dataset = tf.data.Dataset.range(100)
      def dataset_fn(ds):
        return ds.filter(lambda x: x < 5)
      dataset = dataset.apply(dataset_fn)
      list(dataset.as_numpy_iterator())
      # Output: [0, 1, 2, 3, 4]
      
    • as_numpy_iterator

      返回一个迭代器,该迭代器将数据集的所有元素转换为numpy

      dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
      for element in dataset.as_numpy_iterator():
        print(element)
      # Output: 
      # 1
      # 2
      # 3
      
      dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
      print(list(dataset.as_numpy_iterator()))
      # Output: [1, 2, 3]
      
    • as_numpy_iterator() 将保留数据集元素的嵌套结构

      dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
                                                    'b': [5, 6]})
      list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
                                            {'a': (2, 4), 'b': 6}]
      # Output: True
      
    • batch

      将数据集进行分批处理

      batch(
          batch_size, drop_remainder=False
      )
      # param1: batch_size 每批次包含几个元素
      # param2: drop_remainder ds_length/batch_size 不能被整除时 是否删掉最后一个批次
      # return: 一个 Dataset
      
      dataset = tf.data.Dataset.range(8)
      dataset = dataset.batch(3)
      list(dataset.as_numpy_iterator())
      # Output: [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
      
      dataset = tf.data.Dataset.range(8)
      dataset = dataset.batch(3, drop_remainder=True)
      list(dataset.as_numpy_iterator())
      # Output: [array([0, 1, 2]), array([3, 4, 5])]
      
    • cache

      缓存数据集中的元素 可缓存在内存或指定文件内

      cache(
          filename=''
      )
      # param: filename 文件名 如未提供此参数 则默认缓存到内存
      # return: 一个 Dataset
      
      dataset = tf.data.Dataset.range(5)
      dataset = dataset.map(lambda x: x**2)
      dataset = dataset.cache()  # 缓存到内存中
      # The first time reading through the data will generate the data using `range` and `map`.
      list(dataset.as_numpy_iterator())
      
      # Subsequent iterations read from the cache.
      list(dataset.as_numpy_iterator())
      
      dataset = tf.data.Dataset.range(5)
      dataset = dataset.cache("/path/to/file")  # doctest: +SKIP  缓存到文件中
      list(dataset.as_numpy_iterator())  # doctest: +SKIP
      
      dataset = tf.data.Dataset.range(10)
      dataset = dataset.cache("/path/to/file")  # Same file! # doctest: +SKIP
      list(dataset.as_numpy_iterator())  # doctest: +SKIP
      
    • concatenate

      通过将给定数据集与此数据集连接来创建一个 Dataset

      注意 两个数据集的 结构 和 数据类型 必须一致

      concatenate(
          dataset
      )
      
      a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
      b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
      ds = a.concatenate(b)
      list(ds.as_numpy_iterator())
      # Output: [1, 2, 3, 4, 5, 6, 7]
      
    • enumerate

      枚举此数据集的元素

      enumerate(
          start=0
      )
      # param: start 表示枚举的起始值
      # return: A Dataset
      
      dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
      dataset = dataset.enumerate(start=5)
      for element in dataset.as_numpy_iterator():
        print(element)
      # Output:
      # (5, 1)
      # (6, 2)
      # (7, 3)
      
      # The nested structure of the input dataset determines the structure of
      # elements in the resulting dataset.
      dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
      dataset = dataset.enumerate()
      for element in dataset.as_numpy_iterator():
        print(element)
      # Output:
      # (0, array([7, 8], dtype=int32))
      # (1, array([ 9, 10], dtype=int32))
      
    • filter

      过滤数据集 进行条件筛选

      filter(
          predicate
      )
      # param: predicate 将数据集元素映射到布尔值的函数。
      # return: A Dataset
      
      dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
      dataset = dataset.filter(lambda x: x < 3)
      list(dataset.as_numpy_iterator())
      # Output: [1, 2]
      
      # `tf.math.equal(x, y)` is required for equality comparison
      def filter_fn(x):
        return tf.math.equal(x, 1)
      dataset = dataset.filter(filter_fn)
      list(dataset.as_numpy_iterator())
      # Output: [1]
      
    • flat_map

      flat_map(
          map_func
      )
      # param: map_func 映射数据集里每一个元素的方法
      # return: A Dataset
      

      根据map_func方法映射并展开数据集

      dataset = Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
      dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x))
      list(dataset.as_numpy_iterator())
      # Output: [1, 2, 3, 4, 5, 6, 7, 8, 9]
      
    • from_generator

      根据生成器创建一个数据集

      @staticmethod
      from_generator(
          generator, output_types, output_shapes=None, args=None
      )
      # param1: generator 一个可调用的生成器 其返回值的结构和类型必须与param2, param3一致 该生成器函数的参数数量需要与param4一致
      # param2: 确定生成器返回的每个值的数据类型
      # param3: 确定生成器返回的每个值的结构形状
      # param4: 一个元组 传递给生成器作为参数
      
      import itertools
      
      def gen():
        for i in itertools.count(1):
          yield (i, [1] * i)
      
      dataset = tf.data.Dataset.from_generator(
           gen,
           (tf.int64, tf.int64),
           (tf.TensorShape([]), tf.TensorShape([None])))
      
      list(dataset.take(3).as_numpy_iterator())
      # Output: [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))]
      

    相关文章

      网友评论

          本文标题:tf.data.Dataset 属性及方法

          本文链接:https://www.haomeiwen.com/subject/lcxbtktx.html