tf.data模块

作者: 612twilight | 来源:发表于2020-03-19 14:42 被阅读0次

    tf.data是tensorflow提供的用来构建模型输入流水线的模块,集成了map,reduce,batch,shuffle等功能,使用起来比较方便,最佳的自然去看官网链接,这里只是我的学习记录。

    tf.data.Dataset.from_tensor_slices

    • 传入一维的list,输出的是scalar
    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3]
    

    传入二维的tensor,输出一维的tensor

    >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
    >>> list(dataset.as_numpy_iterator())
    [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
    

    传入字典,字典的value对应的是tensor,输出的也是字典,key不变,value正常切割

    >>> # Dictionary structure is also preserved.
    >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
    >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
    ...                                       {'a': 2, 'b': 4}]
    True
    

    传入tuple构成的数据,相当于tuple内部的元素依次切割,并然后在组合起来

    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
    >>> list(dataset.as_numpy_iterator())
    [(1, 3, 5), (2, 4, 6)]
    
    # 这种格式在tf.keras的model.fit里面很常用,尤其是对于有多个输入和输出的时候,可以用key去指定,这里前一个字典代表了feature,后一个字典代表了label
    >>> dataset = tf.data.Dataset.from_tensor_slices(({"a": [1, 2]}, {"b": [3, 4]}))
    >>> list(dataset.as_numpy_iterator())
    [({'a': 1}, {'b': 3}), ({'a': 2}, {'b': 4})]
    

    注意:这里不支持non-rectangular形式的输入tensor,比如这种就不行,然而使用from_generator可以接受不一样的tensor

    dataset = tf.data.Dataset.from_tensor_slices([[1], [2,3]])
    print(list(dataset.as_numpy_iterator()))
    

    tf.data.Dataset.from_generator

    def gen_series1(): #生成不定长度
        i = 0
        while True:
            size = np.random.randint(0, 10)
            yield np.random.normal(size=(size,))
            i += 1
    
    ds_series = tf.data.Dataset.from_generator(gen_series1, output_types=tf.float32, output_shapes=None)
    
    def gen_series2(): # 定长与不定长的组合tuple
        i = 0
        while True:
            size = np.random.randint(0, 10)
            yield i, np.random.normal(size=(size,))
            i += 1
    
    
    ds_series = tf.data.Dataset.from_generator(gen_series2, output_types=(tf.int32, tf.float32), output_shapes=((), (None,)))
    
    
    这里的output_types是用来指定类型
    output_shapes是指定shape
    
    

    tf.data.TFRecordDataset

    还有一种是从TFRecord文件里面读取数据的接口,TFRecords是tensorflow推荐的数据存取方式,里面每一个元素都是一个tf.train.Example,一般需要先解码才可以使用。

    def parse_example(example):
        feature_dict = {
            "fixlen1": tf.io.FixedLenFeature([10], tf.int32),
            "fixlen2": tf.io.FixedLenFeature([10], tf.int32),
            "varlen": tf.io.VarLenFeature(tf.int32)
        }
        feature = tf.io.parse_single_example(example, feature_dict)
        # 如果是要输入给tf.keras,假设fixlen1,fixlen2是feature,而varlen是label
        # 这里可以做一下转换,变成tuple
        return {"fixlen1": feature['fixlen1'], "fixlen2": feature["fixlen2"]}, {"varlen": feature["varlen"]}
    
    dataset = tf.data.TFRecordDataset(["recordfile.records"])
    dataset.map(parse_example)
    

    tf.data.TextLineDataset

    dataset = tf.data.TextLineDataset(file_paths)
    

    一行行的读取与返回

    tf.data.experimental.make_csv_dataset

    titanic_batches = tf.data.experimental.make_csv_dataset(
        titanic_file, batch_size=4,
        label_name="survived", select_columns=['class', 'fare', 'survived'])
    

    可以指定label列名,以及选取哪几列作为特征

    相关文章

      网友评论

        本文标题:tf.data模块

        本文链接:https://www.haomeiwen.com/subject/qahlyhtx.html