TensorFlow Data Prefetching with Queues


Author: yalesaleng | Published: 2018-07-16 16:16

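    Google's open-source deep learning framework TensorFlow includes dedicated data-prefetching facilities that make model training and inference more efficient by keeping the time spent on I/O to a minimum. This article uses a few simple examples to show how to use the functions TensorFlow commonly offers for building queues.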

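    Deep learning models are usually trained on large datasets. When the data fits, it can all be loaded into memory so that training involves no read I/O. When the data is too large for that, as with image or video collections, it cannot all be held in memory. The alternative is to read each piece of data only when training needs it, but the added I/O time makes training very slow.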

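    TensorFlow provides queues to handle this problem. A queue is a buffer of size capacity: multiple threads enqueue data into it while the neural network model dequeues data from it. If capacity is large enough and the loading threads keep up, loading and consuming run concurrently without blocking, so the I/O time becomes almost negligible.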
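
    The buffering idea can be tried out without TensorFlow. The following is a minimal sketch using only the Python standard library, not TensorFlow's implementation: the names prefetch and slow_load are hypothetical, and slow_load merely stands in for reading and decoding a file.

```python
import queue
import threading
import time


def prefetch(items, load, capacity=4, num_threads=2):
    """Yield load(item) for every item, loading in background threads."""
    todo = queue.Queue()
    for item in items:
        todo.put(item)

    buffer = queue.Queue(maxsize=capacity)  # bounded buffer, like `capacity`
    done = object()                         # sentinel: one per finished worker

    def worker():
        while True:
            try:
                item = todo.get_nowait()
            except queue.Empty:
                break
            buffer.put(load(item))  # blocks while the buffer is full
        buffer.put(done)

    for _ in range(num_threads):
        threading.Thread(target=worker, daemon=True).start()

    finished = 0
    while finished < num_threads:
        value = buffer.get()  # blocks only while the buffer is empty
        if value is done:
            finished += 1
        else:
            yield value


def slow_load(path):
    """Hypothetical stand-in for reading and decoding a file."""
    time.sleep(0.01)
    return "decoded:" + path


if __name__ == "__main__":
    for sample in prefetch(["a.jpg", "b.jpg", "c.jpg"], slow_load):
        print(sample)  # order may vary because two threads load in parallel
```

    The consumer waits only when the buffer is empty, and the workers wait only when it is full, which is the behaviour the capacity argument controls in the TensorFlow functions below.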

    slice_input_producer

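    Scenario: the images are stored on local disk and only their file paths are held in memory. We build queues that read the images from disk and buffer them. The whole process looks like this: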

    import matplotlib.pyplot as plt
    import tensorflow as tf  # written against the TensorFlow 1.x graph-mode API

    # batch_size, image_height, image_width and image_channel are assumed to be
    # module-level settings defined elsewhere.

    def slice_input_producer_demo(image_pair_path, summary_path):
        # Reset the default graph
        tf.reset_default_graph()
        # Get three lists: image-one paths, image-two paths, labels
        # (load_data is defined under supplementary)
        image_one_path_list, image_two_path_list, label_list = load_data()
        # Build the input queue; each element is one (path one, path two, label) slice
        train_input_queue = tf.train.slice_input_producer(
            [image_one_path_list, image_two_path_list, label_list],
            capacity=10 * batch_size)

        # Tensors dequeued from the input queue
        img_one_queue = get_image(train_input_queue[0])
        img_two_queue = get_image(train_input_queue[1])
        label_queue = train_input_queue[2]

        # shuffle_batch reads from the queue in batches.
        # decode_jpeg yields [height, width, channels], so shapes uses that order.
        batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch(
            [img_one_queue, img_two_queue, label_queue],
            batch_size=batch_size,
            capacity=10 + 10 * batch_size,
            min_after_dequeue=10,
            num_threads=16,
            shapes=[(image_height, image_width, image_channel),
                    (image_height, image_width, image_channel),
                    ()])

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(summary_path, graph=sess.graph)

        # Start the queue-runner threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for i in range(10):
            batch_img_one_val, batch_img_two_val, label = sess.run(
                [batch_img_one, batch_img_two, batch_label])
            # Show each image pair of the batch side by side
            for k in range(batch_size):
                fig = plt.figure()
                fig.add_subplot(1, 2, 1)
                plt.imshow(batch_img_one_val[k])
                fig.add_subplot(1, 2, 2)
                plt.imshow(batch_img_two_val[k])
                plt.show()

        # Stop the queue threads and release resources
        coord.request_stop()
        coord.join(threads)
        sess.close()
        summary_writer.close()
    

    The process is straightforward and consists of the following steps:
    1. Load the image paths and labels into memory: image_one_path_list, image_two_path_list, label_list = load_data()
    2. Build the first queue: train_input_queue = tf.train.slice_input_producer([image_one_path_list, image_two_path_list, label_list], capacity=10 * batch_size)
    3. Dequeue an image path and load the image: img_one_queue = get_image(train_input_queue[0])
    4. Build the second queue, the shuffle queue: image data is enqueued into the buffer and dequeued in batches. batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch([img_one_queue, img_two_queue, label_queue], ...)
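
    To make the roles of capacity and min_after_dequeue in step 4 more concrete, here is a simplified single-threaded model in plain Python. It is a sketch under assumptions, not TensorFlow's implementation: the name shuffle_batches is hypothetical, refilling happens in the same thread rather than in background threads, and the trailing partial batch is dropped, mirroring what tf.train.shuffle_batch does by default (allow_smaller_final_batch=False).

```python
import random


def shuffle_batches(examples, batch_size, capacity, min_after_dequeue, seed=0):
    """Yield shuffled batches from a bounded buffer (single-threaded model)."""
    if capacity <= min_after_dequeue:
        raise ValueError("capacity must be larger than min_after_dequeue")
    rng = random.Random(seed)
    source = iter(examples)
    buffer, batch = [], []
    exhausted = False
    while True:
        # Refill whenever another dequeue would leave fewer than
        # min_after_dequeue elements behind.
        if not exhausted and len(buffer) <= min_after_dequeue:
            while len(buffer) < capacity:
                try:
                    buffer.append(next(source))
                except StopIteration:
                    exhausted = True
                    break
        if not buffer:
            return  # input used up; a trailing partial batch is dropped
        # Dequeue one element chosen uniformly at random.
        index = rng.randrange(len(buffer))
        buffer[index], buffer[-1] = buffer[-1], buffer[index]
        batch.append(buffer.pop())
        if len(batch) == batch_size:
            yield batch
            batch = []


if __name__ == "__main__":
    for batch in shuffle_batches(range(20), batch_size=4, capacity=10,
                                 min_after_dequeue=5):
        print(batch)
```

    A larger min_after_dequeue means every element is drawn from a bigger pool and the output is mixed more thoroughly, at the price of more memory and a longer warm-up before the first batch is available.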

    string_input_producer

    string_input_producer outputs strings (for example file names) to a queue for an input pipeline.

    def string_input_producer_demo(image_pair_path, summary_path):
        tf.reset_default_graph()

        image_one_path_list, image_two_path_list, label_list = load_data()
        # Build the input queue of path strings
        train_input_queue = tf.train.string_input_producer(image_one_path_list, capacity=10 * batch_size)

        # Dequeue one path and decode the image it points to
        img_one_queue = get_image(train_input_queue.dequeue())

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter(summary_path, graph=sess.graph)

        # Start the queue-runner threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for i in range(10):
            # sess.run on a one-element list returns a one-element list
            img_one_val = sess.run([img_one_queue])
            fig = plt.figure()
            plt.imshow(img_one_val[0])
            plt.show()

        coord.request_stop()
        coord.join(threads)
        sess.close()
        summary_writer.close()
    

    range_input_producer: produces a queue of the integers from 0 to limit-1
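
    Here the queue carries indices rather than data, and each dequeued index is used to look up the matching entries in several aligned lists. A plain-Python model of that idea follows; it is a sketch, not TensorFlow's implementation. The name range_producer and the sample lists are hypothetical, and reshuffling once per epoch is an assumption about how the shuffle option behaves.

```python
import random


def range_producer(limit, num_epochs=1, shuffle=True, seed=0):
    """Yield each index in [0, limit) once per epoch, reshuffled every epoch."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        indices = list(range(limit))
        if shuffle:
            rng.shuffle(indices)
        yield from indices


if __name__ == "__main__":
    # Hypothetical aligned lists, in the shape load_data returns them.
    paths_one = ["a0.jpg", "a1.jpg", "a2.jpg"]
    paths_two = ["b0.jpg", "b1.jpg", "b2.jpg"]
    labels = [0, 1, 0]
    for i in range_producer(len(labels)):
        # The same index selects the matching entry in every list.
        print(paths_one[i], paths_two[i], labels[i])
```

    The TensorFlow demo does the same lookup on tensors with tf.gather.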

    def range_input_producer_demo(image_pair_path, summary_path):
        tf.reset_default_graph()

        image_one_path_list, image_two_path_list, label_list = load_data()
        length_data = len(image_one_path_list)

        # Convert the Python lists to tensors so they can be indexed with tf.gather
        image_one_path_list = tf.convert_to_tensor(image_one_path_list)
        image_two_path_list = tf.convert_to_tensor(image_two_path_list)
        label_list = tf.convert_to_tensor(label_list)

        # Build a queue of the indices 0 .. length_data - 1
        train_input_queue = tf.train.range_input_producer(length_data, capacity=10 * batch_size)

        # Look up the paths and the label that belong to the dequeued index
        range_index = train_input_queue.dequeue()
        img_one_queue = get_image(tf.gather(image_one_path_list, range_index))
        img_two_queue = get_image(tf.gather(image_two_path_list, range_index))
        label_queue = tf.gather(label_list, range_index)

        # Read from the queue in batches.
        # decode_jpeg yields [height, width, channels], so shapes uses that order.
        batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch(
            [img_one_queue, img_two_queue, label_queue],
            batch_size=batch_size,
            capacity=10 + 10 * batch_size,
            min_after_dequeue=10,
            num_threads=16,
            shapes=[(image_height, image_width, image_channel),
                    (image_height, image_width, image_channel),
                    ()])

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(summary_path, graph=sess.graph)

        # Start the queue-runner threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for i in range(10):
            batch_img_one_val, batch_img_two_val, label = sess.run(
                [batch_img_one, batch_img_two, batch_label])
            for k in range(batch_size):
                fig = plt.figure()
                fig.add_subplot(1, 2, 1)
                plt.imshow(batch_img_one_val[k])
                fig.add_subplot(1, 2, 2)
                plt.imshow(batch_img_two_val[k])
                plt.show()

        coord.request_stop()
        coord.join(threads)
        sess.close()
        summary_writer.close()
    
    

    input_producer: builds a queue from the rows of input_tensor

    def input_producer_demo(image_pair_path, summary_path):
        tf.reset_default_graph()

        image_one_path_list, image_two_path_list, label_list = load_data()

        image_one_path_list = tf.convert_to_tensor(image_one_path_list)

        # Build the input queue; each row of the tensor becomes one element
        train_input_queue = tf.train.input_producer(image_one_path_list, capacity=10 * batch_size)

        # Call dequeue() first: passing the queue object itself to get_image fails with
        # "Expected string, got <tensorflow.python.ops.data_flow_ops.FIFOQueue object ...> of type 'FIFOQueue' instead."
        img_one_queue = get_image(train_input_queue.dequeue())

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(summary_path, graph=sess.graph)

        # Start the queue-runner threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for i in range(10):
            # sess.run on a one-element list returns a one-element list
            batch_img_one_val = sess.run([img_one_queue])

            print(len(batch_img_one_val))
            fig = plt.figure()
            plt.imshow(batch_img_one_val[0])
            plt.show()

        coord.request_stop()
        coord.join(threads)
        sess.close()
        summary_writer.close()
    

    supplementary

    Data format (one record per line; load_data splits the fields on tabs):
    /home/Alex/4000.jpg /home/Alex/4001.jpg 0
    /home/Alex/4000.jpg /home/Alex/4002.jpg 1

    # Load the (image-one local path, image-two local path, label) records
    def load_data():
        # image_pair_path is expected to be a module-level variable that holds
        # the path of the pair file described above.
        image_one_path_list = []
        image_two_path_list = []
        label_list = []

        with open(image_pair_path, 'r') as reader_handler:
            for line in reader_handler:
                elems = line.split("\t")
                # Skip malformed lines that do not have all three fields
                if len(elems) < 3:
                    print("len(elems) < 3:" + line)
                    continue
                image_one_path = elems[0].strip()
                image_two_path = elems[1].strip()
                label = int(elems[2].strip())

                image_one_path_list.append(image_one_path)
                image_two_path_list.append(image_two_path)
                label_list.append(label)

        return image_one_path_list, image_two_path_list, label_list

    # Read an image given its path
    def get_image(image_path):
        """Reads the JPEG image at image_path.
        Args:
            image_path: tf.string tensor
        Returns:
            the decoded JPEG image as a uint8 tensor of shape [height, width, 3]
        """
        content = tf.read_file(image_path)
        tf_image = tf.image.decode_jpeg(content, channels=3)

        return tf_image
    
