美文网首页PythonTensorFlow
TensorFlow三种常用的数据加载方式(附python演练)

TensorFlow三种常用的数据加载方式(附python演练)

作者: 人工智能遇见磐创 | 来源:发表于2018-12-13 15:22 被阅读12次

    简介

    TensorFlow系列后期部分正在整理,整理好后会继续更新。在此段时间大家有什么疑问的,可以留言,我看见了会为您解答。

    今天主要说下一些在TensorFlow读取数据部分的内容,希望对大家有帮助。文章内容参考了一篇博客:https://blog.csdn.net/lujiandong1/article/details/53376802,尊重该博主原创。

    TensorFlow读取数据有三种方式:

    • Preloaded data: 预加载数据
    • Feeding: Python: 产生数据,再把数据喂给后端
    • Reading from file: 从文件中直接读取

    注意我们这里说的读取数据的三种方式不是针对有多少种不同的数据格式(比如像字典结构数据;bin file 读取数据;CSV读取数据,从原图读取数据等),而是指读取数据的不同方式。

    一 预加载数据

    Import tensorflow as tf
     # 设计Graph
     x1 = tf.constant([2,3,4])
     x2 = tf.constant([4,0,1])
     y = tf.add(x1,x2) 
     # 打开一个session,计算y
     with tf.Session() as sess:
       print(sess.run(y))
    

    可以看见预加载数据的读取方式是直接读取定义好的数据,直接嵌入至Graph,然后将Graph传入Session中运行。

    二 Feeding方式加载数据

    import tensorflow as tf
    # 设计Graph
    x1 = tf.placeholder(tf.int16)
    x2 = tf.placeholder(tf.int16)
    y = tf.add(x1, x2)
    # 用Python产生数据
    li1 = [2, 3, 4]
    li2 = [4, 0, 1]
    with tf.Session() as sess:
       print sess.run(y, feed_dict={x1: li1, x2: li2})
    

    Feeding方式加载数据时,是事先不知道传进来的数据是什么,只需要先用tf.placeholder方法定义好准备放入的数据的类型等特征。然后同预加载数据比较,在打开一个session后将具体的数据比如这里的li1,li2喂给我们提前用tf.placeholder定义好的位置x1,x2占位符。这样x1,x2这时候就会被传入li1,li2用于进行计算了。用占位符代替数据,待运行的时候填充数据。

    前两种方法很方便,但是如果遇到大型数据时候会很吃力。最好的办法是在Graph定义好文件读取的办法,让tensorflow自己从文件中读取数据,并解码成可用的样本。

    这么说可能大家会疑惑,难道前两种不也是读取好文件后才传的数据吗?

    举个例子,预加载数据一般是定义tensorflow中的Graph里所需要的常量,我们不会将所有的数据都像这样去定义,只会定义graph中所需要的一些常量而已。

    而Feeding读取数据是这个样子的:假设读入的时间是0.1s,计算的时间是0.9s。那么Feeding方式读取数据,在我们之前说的识别Mnist手写数据集例子中,先读取一个batch的数据,假设读取数据后计算,其中读取数据花费0.1s,计算花费0.9s。那么当我们进行下一次batch个数据的的读取时,又要先花费0.1s读取数据,再花费0.9s计算。这就意味在,在每次读取新的数据的那个0.1s时,CPU是处于空闲状态的。这样就会大大降低了与运行的效率。

    三 Reading from file方式载入数据

    Reading from file方式通过将读取数据和计算这两个过程分别放入两个线程中来解决CPU在读取数据时处于闲置状态而导致的低效率问题。

    图一

    读取线程将文件系统中的数据陆续读入进内存的队列中,而另外计算是另外一个线程。这样这两个线程同时工作,就能保证CPU一直在计算,而不会因为IO阶段而闲置的问题。

    图二
    $ echo -e "Alpha1,A1
    Alpha2,A2
    Alpha3,A3" > A.csv
    $ echo -e "Bee1,B1
    Bee2,B2
    Bee3,B3" > B.csv
    $ echo -e "Sea1,C1
    Sea2,C2
    Sea3,C3" > C.csv
    #单个Reader,单个样本
    #-*- coding:utf-8 -*-
    import tensorflow as tf
    # 生成一个先入先出队列和一个QueueRunner,生成文件名队列
    filenames = ['A.csv', 'B.csv', 'C.csv']
    filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
    # 定义Reader
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)
    # 定义Decoder
    example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
    #example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2)
    # 运行Graph
    with tf.Session() as sess:
       coord = tf.train.Coordinator() #创建一个协调器,管理线程
       threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。
       for i in range(10):
         print example.eval(),label.eval()
    coord.request_stop()
    coord.join(threads)
    

    读取步骤:

    1. 创建文件名列表list (filenames)
    2. 创建文件名队列创建文件名队列,调用tf.train.string_input_producer(),参数包含:文件名列表,num_epochs(定义重复次数),shuffle(定义是否打乱文件的顺序)
    3. 定义对应文件的阅读器。(针对不同类型的数据文件有不同的阅读器,如tf.ReaderBase、tf.TFRecordReader 、tf.TextLineReader 、tf.WholeFileReader 、tf.IdentityReader 、tf.FixedLengthRecordReader)
    4. 解析器 (同理针对不同类型的数据,如csv,图片,bin数据等,有不同的解析器如tf.decode_csv 、tf.decode_raw 、 tf.image.decode_image)
    5. 预处理,对原始数据进行处理,以适应network输入所需
    6. 生成batch,调用tf.train.batch() 或者 tf.train.shuffle_batch()
    7. prefetch(可选)使用预加载队列slim.prefetch_queue.prefetch_queue()
    8. 启动填充队列的线程,调用tf.train.start_queue_runners

    这个函数需要传入一个文件名list,系统会自动将它转为一个文件名队列。tf.train.string_input_producer还有两个重要的参数,一个是num_epochs,表示epoch数。另外一个就是shuffle是指在一个epoch内文件的顺序是否被打乱。在tensorflow中,内存队列不需要我们自己建立,我们只需要使用reader对象从文件名队列中读取数据就可以了。

    在我们使用tf.train.string_input_producer创建文件名队列后,整个系统其实还是处于"停滞状态"的,也就是说,我们文件名并没有真正被加入到队列中,此时如果我们开始计算,因为内存队列中什么也没有,计算单元就会一直等待,导致整个系统被阻塞。使用tf.train.start_queue_runners之后,才会启动填充队列的线程,这时系统就不再"停滞"。此后计算单元就可以拿到数据并进行计算,整个程序也就跑起来了。

    这次就说到这里,大致介绍了tensorflow读取数据的三种不同方式,尤其是最后一种,在针对大型数据时普遍采用的方法,内容是参考了原博主的内容。下次会着重介绍第三种数据加载方式更细节的内容。关于shuffle的执行和线程thread的使用方法。希望此次分享对大家用tensorflow处理数据读取上有所帮助。


    对深度学习感兴趣,热爱Tensorflow的小伙伴,欢迎关注我们的网站http://www.panchuang.net 我们的公众号:磐创AI。

    相关文章

      网友评论

        本文标题:TensorFlow三种常用的数据加载方式(附python演练)

        本文链接:https://www.haomeiwen.com/subject/wrcbhqtx.html