美文网首页
tensorflow教程2:数据读取

tensorflow教程2:数据读取

作者: 是neinei啊 | 来源:发表于2017-11-25 12:54 被阅读862次

    Tensorflow的数据读取有三种方式:

    Preloaded data: 预加载数据,也就是TensorFlow图中的常量或变量保留所有数据(对于小数据集)。
    Feeding: Python产生数据,再把数据喂给后端。
    Reading from file: 从文件中直接读取,输入流水线从TensorFlow图开头的文件中读取数据。

    Preloaded data: 预加载数据

    预加载数据方法仅限于用在可以完全加载到内存中的小数据集上,主要有两种方法:

    把数据存在常量(constant)中。
    把数据存在变量(variable)中,我们初始化并且永不改变它的值。
    用常量更简单些,但会占用更多的内存,因为常量存储在graph数据结构内部。例如:

    import tensorflow as tf
    # 构造Graph
    x1 = tf.constant([2, 3, 4])
    x2 = tf.constant([4, 0, 1])
    y = tf.add(x1, x2)
    # 打开一个session --> 计算y
    with tf.Session() as sess:
        print sess.run(y)
    

    这种方法在设计Graph的时候,x1和x2就被定义成了两个有值的列表,在计算y的时候直接取x1和x2的值。

    如果用变量的话,我们需要在graph构建好之后初始化该变量。例如:

    training_data = ...
    training_labels = ...
    with tf.Session() as sess:
      data_initializer = tf.placeholder(dtype=training_data.dtype,
                                        shape=training_data.shape)
      label_initializer = tf.placeholder(dtype=training_labels.dtype,
                                         shape=training_labels.shape)
      input_data = tf.Variable(data_initializer, trainable=False, collections=[])
      input_labels = tf.Variable(label_initializer, trainable=False, collections=[])
      ...
      sess.run(input_data.initializer,
               feed_dict={data_initializer: training_data})
      sess.run(input_labels.initializer,
               feed_dict={label_initializer: training_labels})
    

    Feeding: 供给数据

    我们一般用tf.placeholder节点来feed数据,该节点不需要初始化也不包含任何数据,我们在执行run()或者eval()指令时通过feed_dict参数把数据传入graph中来计算。如果在运行过程中没有对tf.placeholder节点传入数据,程序会报错。例如:

    import tensorflow as tf
    # 设计Graph
    x1 = tf.placeholder(tf.int16)
    x2 = tf.placeholder(tf.int16)
    y = tf.add(x1, x2)
    # 用Python产生数据
    li1 = [2, 3, 4]
    li2 = [4, 0, 1]
    # 打开一个session --> 喂数据 --> 计算y
    with tf.Session() as sess:
        print sess.run(y, feed_dict={x1: li1, x2: li2})
    

    两种方法的区别

    Preload:

    将数据直接内嵌到Graph中,再把Graph传入Session中运行。当数据量比较大时,Graph的传输会遇到效率问题。

    Feeding:

    用占位符替代数据,待运行的时候填充数据。

    Reading From File 从文件中读数据

    前两种方法很方便,但是遇到大型数据的时候就会很吃力,即使是Feeding,中间环节的增加也是不小的开销,比如数据类型转换等等。最优的方案就是在Graph定义好文件读取的方法,让TF自己去从文件中读取数据,并解码成可使用的样本集。从文件中读取数据一般包含以下步骤:

    • 文件名列表
    • 文件名随机排序(可选的)
    • 迭代控制(可选的)
    • 文件名队列
    • 针对输入文件格式的阅读器
    • 记录解析器
    • 预处理器(可选的)
    • 样本队列

    在了解具体的操作之前首先了解文件读取数据的优点:


    AnimatedFileQueues.gif

    在上图中,首先由一个单线程把文件名堆入队列,两个Reader同时从队列中取文件名并读取数据,Decoder将读出的数据解码后堆入样本队列,最后单个或批量取出样本(图中没有展示样本出列)。我们这里通过三段代码逐步实现上图的数据流,这里我们不使用随机,让结果更清晰。

    文件准备

    $ echo -e "Alpha1,A1\nAlpha2,A2\nAlpha3,A3" > A.csv
    $ echo -e "Bee1,B1\nBee2,B2\nBee3,B3" > B.csv
    $ echo -e "Sea1,C1\nSea2,C2\nSea3,C3" > C.csv
    $ cat A.csv
    Alpha1,A1
    Alpha2,A2
    Alpha3,A3
    

    单个Reader,单个样本

    import tensorflow as tf
    # 生成一个先入先出队列和一个QueueRunner
    filenames = ['A.csv', 'B.csv', 'C.csv']
    filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
    # 定义Reader
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)
    # 定义Decoder
    example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
    # 运行Graph
    with tf.Session() as sess:
        coord = tf.train.Coordinator()  #创建一个协调器,管理线程
        threads = tf.train.start_queue_runners(coord=coord)  #启动QueueRunner, 此时文件名队列已经进队。
        for i in range(10):
            print example.eval()   #取样本的时候,一个Reader先从文件名队列中取出文件名,读出数据,Decoder解析后进入样本队列。
        coord.request_stop()
        coord.join(threads)
    # outpt
    Alpha1
    Alpha2
    Alpha3
    Bee1
    Bee2
    Bee3
    Sea1
    Sea2
    Sea3
    Alpha1
    

    单个Reader,多个样本

    import tensorflow as tf
    filenames = ['A.csv', 'B.csv', 'C.csv']
    filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)
    example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
    # 使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。Decoder解码后数据会进入这个队列,再批量出队。
    # 虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。
    example_batch, label_batch = tf.train.batch(
          [example, label], batch_size=5)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(10):
            print example_batch.eval()
        coord.request_stop()
        coord.join(threads)
    # output
    # ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
    # ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
    # ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
    # ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
    # ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
    # ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']
    # ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']
    # ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']
    # ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']
    # ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
    

    多Reader,多个样本

    import tensorflow as tf
    filenames = ['A.csv', 'B.csv', 'C.csv']
    filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)
    record_defaults = [['null'], ['null']]
    example_list = [tf.decode_csv(value, record_defaults=record_defaults)
                      for _ in range(2)]  # Reader设置为2
    # 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。
    example_batch, label_batch = tf.train.batch_join(
          example_list, batch_size=5)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(10):
            print example_batch.eval()
        coord.request_stop()
        coord.join(threads)
        
    # output
    # ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
    # ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
    # ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
    # ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
    # ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
    # ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']
    # ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']
    # ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']
    # ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']
    # ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
    

    tf.train.batch与tf.train.shuffle_batch函数是单个Reader读取,但是可以多线程。tf.train.batch_join与tf.train.shuffle_batch_join可设置多Reader读取,每个Reader使用一个线程。至于两种方法的效率,单Reader时,2个线程就达到了速度的极限。多Reader时,2个Reader就达到了极限。所以并不是线程越多越快,甚至更多的线程反而会使效率下降。

    迭代控制

    filenames = ['A.csv', 'B.csv', 'C.csv']
    filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=3)  # num_epoch: 设置迭代数
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)
    record_defaults = [['null'], ['null']]
    example_list = [tf.decode_csv(value, record_defaults=record_defaults)
                      for _ in range(2)]
    example_batch, label_batch = tf.train.batch_join(
          example_list, batch_size=5)
    init_local_op = tf.initialize_local_variables()
    with tf.Session() as sess:
        sess.run(init_local_op)   # 初始化本地变量 
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                print example_batch.eval()
        except tf.errors.OutOfRangeError:
            print('Epochs Complete!')
        finally:
            coord.request_stop()
        coord.join(threads)
        coord.request_stop()
        coord.join(threads)
    # output
    # ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']
    # ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']
    # ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']
    # ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']
    # ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']
    

    在迭代控制中,记得添加tf.initialize_local_variables(),官网教程没有说明,但是如果不初始化,运行就会报错。

    下面开始正式的步骤:

    文件名列表

    文件名列表.jpg

    我们首先要有个文件名列表,为了产生文件名列表,我们可以手动用Python输入字符串,例如:

    ["file0", "file1"]
    [("file%d" % i) for i in range(2)]
    [("file%d" % i) for i in range(2)]
    

    我们也可以用tf.train.match_filenames_once函数来生成文件名列表。

    有了文件名列表后,我们需要把它送入 tf.train.string_input_producer函数中生成一个先入先出的文件名队列,文件阅读器需要从该队列中读取文件名。

    string_input_producer(
        string_tensor,
        num_epochs=None,
        shuffle=True,
        seed=None,
        capacity=32,
        shared_name=None,
        name=None,
        cancel_op=None
    )
    

    一个QueueRunner每次会把每批次的所有文件名送入队列中,可以通过设置string_input_producer函数的shuffle参数来对文件名随机排序,或者通过设置num_epochs来决定对string_tensor里的文件使用多少次,类型为整型,如果想要迭代控制则需要设置了num_epochs参数,同时需要添加tf.local_variables_initializer()进行初始化,如果不初始化会报错。
    这个QueueRunner的工作线程独立于文件阅读器的线程, 因此随机排序和将文件名送入到文件名队列这些过程不会阻碍文件阅读器的运行。

    文件格式

    根据不同的文件格式, 应该选择对应的文件阅读器, 然后将文件名队列提供给阅读器的read方法。阅读器每次从队列中读取一个文件,它的read方法会输出一个key来表征读入的文件和其中的纪录(对于调试非常有用),同时得到一个字符串标量, 这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。
    根据不同的文件类型,有三种不同的文件阅读器:

    • tf.TextLineReader
    • tf.FixedLengthRecordReader
    • tf.TFRecordReader

    它们分别用于单行读取(如CSV文件)、固定长度读取(如CIFAR-10的.bin二进制文件)、TensorFlow标准格式读取。

    根据不同的文件阅读器,有三种不同的解析器,它们分别对应上面三种阅读器:

    • tf.decode_csv
    • tf.decode_raw
    • tf.parse_single_exampletf.parse_example

    CSV文件

    当我们读入CSV格式的文件时,我们可以使用tf.TextLineReader阅读器和tf.decode_csv解析器。例如:

    #!/usr/bin/python
    # -*- coding: UTF-8 -*-
    import tensorflow as tf
    import numpy as np
    
    filename_queue = tf.train.string_input_producer(["./data/data1.csv", "./data/data2.csv"])
    
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)
    # key返回的是读取文件和行数信息 b'./data/iris.csv:146'
    # value是按行读取到的原始字符串,送到下面的decoder去解析
    
    record_defaults = [[1.0], [1.0], [1.0], [1.0], ["Null"]] # 这里的数据类型决定了读取的数据类型,而且必须是list形式
    col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults) # 解析出的每一个属性都是rank为0的标量,每次解码一行,col对应这一行的一列也就是一个数字
    features = tf.stack([col1, col2, col3, col4])
    
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
    
        for i in range(100):
            example, label = sess.run([features, col5])
            print (example,col5)
        coord.request_stop()
        coord.join(threads)   
    

    每次read的执行都会从文件中读取一行内容,decode_csv操作会解析这一行内容并将其转为张量列表。在调用run或者eval去执行read之前, 必须先调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。

    record_defaults = [[1], [1], [1], [1], [1]]代表了解析的摸版,默认用,隔开,是用于指定矩阵格式以及数据类型的,CSV文件中的矩阵是NXM的,则此处为1XM,例如上例中M=5[1]表示解析为整型,如果矩阵中有小数,则应为float型,[1]应该变为[1.0][‘null’]解析为string类型。

    col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults = record_defaults), 矩阵中有几列,这里就要写几个参数,比如5列,就要写到col5,不管你到底用多少。否则报错。

    固定长度记录

    我们也可以从二进制文件‘(.bin)中读取固定长度的数据,使用的是tf.FixedLengthRecordReader阅读器和tf.decode_raw解析器。decode_raw节点会把string转化为uint8类型的张量。

    例如CIFAR-10数据集就采用的固定长度的数据,1字节的标签,后面跟着3072字节的图像数据。使用uint8类型张量的标准操作可以把每个图像的片段截取下来并且按照需要重组。下面有一个例子:

    reader = tf.FixedLengthRecordReader(record_bytes = record_bytes)
    key, value = reader.read(filename_queue)
    record_bytes = tf.decode_raw(value, tf.uint8)
    label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
    image_raw = tf.slice(record_bytes, [label_bytes], [image_bytes])
    image_raw = tf.reshape(image_raw, [depth, height, width])
    image = tf.transpose(image_raw, (1,2,0)) # 图像形状为[height, width, channels]     
    image = tf.cast(image, tf.float32)
    

    这里介绍上述代码中出现的函数:tf.slice()

    slice(
        input_,
        begin,
        size,
        name=None
    )
    

    从一个张量input中提取出长度为size的一部分,提取的起点由begin定义。size是一个向量,它代表着在每个维度提取出的tensor的大小。begin表示提取的位置,它表示的是input的起点偏离值,也就是从每个维度第几个值开始提取。

    begin从0开始,size从1开始,如果size[i]的值为-1,则第i个维度从begin处到余下的所有值都被提取出来。

    例如:

    # 'input' is [[[1, 1, 1], [2, 2, 2]],
    #             [[3, 3, 3], [4, 4, 4]],
    #             [[5, 5, 5], [6, 6, 6]]]
    tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
    tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
                                                [4, 4, 4]]]
    tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
                                               [[5, 5, 5]]]
    

    标准TensorFlow格式

    我们也可以把任意的数据转换为TensorFlow所支持的格式, 这种方法使TensorFlow的数据集更容易与网络应用架构相匹配。这种方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Exampleprotocol buffer(里面包含了名为Features的字段)。你可以写一段代码获取你的数据, 将数据填入到Exampleprotocol buffer,将protocol buffer序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter类写入到TFRecords文件。

    从TFRecords文件中读取数据, 可以使用tf.TFRecordReader阅读器以及tf.parse_single_example解析器。parse_single_example操作可以将Exampleprotocol buffer解析为张量。 具体可以参考如下例子,把MNIST数据集转化为TFRecords格式:

    SparseTensors这种稀疏输入数据类型使用队列来处理不是太好。如果要使用SparseTensors你就必须在批处理之后使用tf.parse_example去解析字符串记录 (而不是在批处理之前使用tf.parse_single_example) 。

    预处理

    我们可以对输入的样本数据进行任意的预处理, 这些预处理不依赖于训练参数, 比如数据归一化, 提取随机数据片,增加噪声或失真等等。具体可以参考如下对CIFAR-10处理的例子:

    批处理

    经过了之前的步骤,在数据读取流程的最后, 我们需要有另一个队列来批量执行输入样本的训练,评估或者推断。根据要不要打乱顺序,我们常用的有两个函数:

    • tf.train.batch()
    • tf.train.shuffle_batch()

    下面来分别介绍:

    tf.train.batch()

    tf.train.batch(
       tensors,
       batch_size,
       num_threads=1,
       capacity=32,
       enqueue_many=False,
       shapes=None,
       dynamic_pad=False,
       allow_smaller_final_batch=False,
       shared_name=None,
       name=None
    )
    

    该函数将会使用一个队列,函数读取一定数量的tensors送入队列,然后每次从中选取batch_size个tensors组成一个新的tensors返回出来。

    capacity参数决定了队列的长度。

    num_threads决定了有多少个线程进行入队操作,如果设置的超过一个线程,它们将从不同文件不同位置同时读取,可以更加充分的混合训练样本。

    如果enqueue_many参数为False,则输入参数tensors为一个形状为[x, y, z]的张量,输出为一个形状为[batch_size, x, y, z]的张量。如果enqueue_many参数为True,则输入参数tensors为一个形状为[*, x, y, z]的张量,其中所有*的数值相同,输出为一个形状为[batch_size, x, y, z]的张量。

    allow_smaller_final_batchTrue时,如果队列中的张量数量不足batch_size,将会返回小于batch_size长度的张量,如果为False,剩下的张量会被丢弃。

    tf.train.shuffle_batch()

    tf.train.shuffle_batch(
        tensors,
        batch_size,
        capacity,
        min_after_dequeue,
        num_threads=1,
        seed=None,
        enqueue_many=False,
        shapes=None,
        allow_smaller_final_batch=False,
        shared_name=None,
        name=None
    )
    

    该函数类似于上面的tf.train.batch(),同样创建一个队列,主要区别是会首先把队列中的张量进行乱序处理,然后再选取其中的batch_size个张量组成一个新的张量返回。但是新增加了几个参数。

    capacity参数依然为队列的长度,建议capacity的取值如下:

    min_after_dequeue + (num_threads + a small safety margin) * batch_size

    min_after_dequeue这个参数的意思是队列中,做dequeue(取数据)的操作后,线程要保证队列中至少剩下min_after_dequeue个数据。如果min_after_dequeue设置的过少,则即使shuffleTrue,也达不到好的混合效果。

    假设你有一个队列,现在里面有m个数据,你想要每次随机从队列中取n个数据,则代表先混合了m个数据,再从中取走n个。

    当第一次取走n个后,队列就变为m-n个数据;当你下次再想要取n个时,假设队列在此期间入队进来了k个数据,则现在的队列中有(m-n+k)个数据,则此时会从混合的(m-n+k)个数据中随机取走n个。

    如果队列填充的速度比较慢,k就比较小,那你取出来的n个数据只是与周围很小的一部分(m-n+k)个数据进行了混合。

    因为我们的目的肯定是想尽最大可能的混合数据,因此设置min_after_dequeue,可以保证每次dequeue后都有足够量的数据填充尽队列,保证下次dequeue时可以很充分的混合数据。

    但是min_after_dequeue也不能设置的太大,这样会导致队列填充的时间变长,尤其是在最初的装载阶段,会花费比较长的时间。

    其他参数和tf.train.batch()相同。

    这里我们使用tf.train.shuffle_batch函数来对队列中的样本进行乱序处理。如下的模版:

    def read_my_file_format(filename_queue):
      reader = tf.SomeReader()
      key, record_string = reader.read(filename_queue)
      example, label = tf.some_decoder(record_string)
      processed_example = some_processing(example)
      return processed_example, label
    def input_pipeline(filenames, batch_size, num_epochs=None):
      filename_queue = tf.train.string_input_producer(
          filenames, num_epochs=num_epochs, shuffle=True)
      example, label = read_my_file_format(filename_queue)
      # min_after_dequeue 越大意味着随机效果越好但是也会占用更多的时间和内存
      # capacity 必须比 min_after_dequeue 大
      # 建议capacity的取值如下:
      # min_after_dequeue + (num_threads + a small safety margin) * batch_size
      min_after_dequeue = 10000
      capacity = min_after_dequeue + 3 * batch_size
      example_batch, label_batch = tf.train.shuffle_batch(
          [example, label], batch_size=batch_size, capacity=capacity,
          min_after_dequeue=min_after_dequeue)
      return example_batch, label_batch```
    

    一个具体的例子如下,该例采用了CIFAR-10数据集,采用了固定长度读取的tf.FixedLengthRecordReader阅读器和tf.decode_raw解析器,同时进行了数据预处理操作中的标准化操作,最后使用tf.train.shuffle_batch函数批量执行数据的乱序处理。

    class cifar10_data(object):
        def __init__(self, filename_queue):
            self.height = 32
            self.width = 32
            self.depth = 3
            self.label_bytes = 1
            self.image_bytes = self.height * self.width * self.depth
            self.record_bytes = self.label_bytes + self.image_bytes
            self.label, self.image = self.read_cifar10(filename_queue)
            
        def read_cifar10(self, filename_queue):
            reader = tf.FixedLengthRecordReader(record_bytes = self.record_bytes)
            key, value = reader.read(filename_queue)
            record_bytes = tf.decode_raw(value, tf.uint8)
            label = tf.cast(tf.slice(record_bytes, [0], [self.label_bytes]), tf.int32)
            image_raw = tf.slice(record_bytes, [self.label_bytes], [self.image_bytes])
            image_raw = tf.reshape(image_raw, [self.depth, self.height, self.width])
            image = tf.transpose(image_raw, (1,2,0))        
            image = tf.cast(image, tf.float32)
            return label, image
    
    def inputs(data_dir, batch_size, train = True, name = 'input'):
        with tf.name_scope(name):
            if train:    
                filenames = [os.path.join(data_dir,'data_batch_%d.bin' % ii) 
                            for ii in range(1,6)]
                for f in filenames:
                    if not tf.gfile.Exists(f):
                        raise ValueError('Failed to find file: ' + f)
                        
                filename_queue = tf.train.string_input_producer(filenames)
                read_input = cifar10_data(filename_queue)
                images = read_input.image
                images = tf.image.per_image_standardization(images)
                labels = read_input.label
                image, label = tf.train.shuffle_batch(
                                        [images,labels], batch_size = batch_size, 
                                        min_after_dequeue = 20000, capacity = 20192)
            
                return image, tf.reshape(label, [batch_size])
                
            else:
                filenames = [os.path.join(data_dir,'test_batch.bin')]
                for f in filenames:
                    if not tf.gfile.Exists(f):
                        raise ValueError('Failed to find file: ' + f)
                        
                filename_queue = tf.train.string_input_producer(filenames)
                read_input = cifar10_data(filename_queue)
                images = read_input.image
                images = tf.image.per_image_standardization(images)
                labels = read_input.label
                image, label = tf.train.shuffle_batch(
                                        [images,labels], batch_size = batch_size, 
                                        min_after_dequeue = 20000, capacity = 20192)
            
                return image, tf.reshape(label, [batch_size])
    

    这里介绍下函数tf.image.per_image_standardization(image),该函数对图像进行线性变换使它具有零均值和单位方差,即规范化。其中参数image是一个3-D的张量,形状为[height, width, channels]。

    参考 ZangBo

    相关文章

      网友评论

          本文标题:tensorflow教程2:数据读取

          本文链接:https://www.haomeiwen.com/subject/heoqbxtx.html