tensorflow动态加载文件

作者: 阿发贝塔伽马 | 来源:发表于2018-05-21 20:52 被阅读63次

    如果把文件全部加载到内存中,对大数据量来说是不可行的。TensorFlow 使用队列,并通过多线程来完成队列的入队和出队操作。下面举例说明。

    tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。

    下面这个例子将文件名加入到队列中,每次只能从队列中取出一个tensor,然后再读取图片数据,因此仍然存在频繁的IO操作:

    import tensorflow as tf
    import matplotlib.pyplot as plt
    %matplotlib inline
    def get_image(image_path):
        """Build the graph ops that load one image file as a 3-channel tensor.

        Args:
            image_path: file path (string or string tensor) of the image.

        Returns:
            The tensor produced by ``tf.image.decode_jpeg`` with ``channels=3``.
        """
        # NOTE(review): callers in this file also pass .png paths through this
        # JPEG decoder -- confirm the installed TensorFlow accepts that.
        raw_bytes = tf.read_file(image_path)
        return tf.image.decode_jpeg(raw_bytes, channels=3)
    def plot_pic(batch_img_one_val, batch_img_two_val, label):
        """Show two images side by side under one shared title.

        Args:
            batch_img_one_val: image data drawn in the left panel.
            batch_img_two_val: image data drawn in the right panel.
            label: text used as the figure-level title.
        """
        figure = plt.figure(figsize=(6, 2))
        plt.suptitle(label)
        # One row, two columns; subplot positions are 1-based.
        for position, picture in enumerate(
                (batch_img_one_val, batch_img_two_val), start=1):
            panel = figure.add_subplot(1, 2, position)
            panel.imshow(picture)
            panel.axis('off')
        plt.show()
        
    
    def slice_input_producer_one_sample():
        """Read image pairs one sample at a time through a TF1 input queue.

        Three parallel lists (first image path, second image path, label) are
        passed to ``tf.train.slice_input_producer`` so that each dequeue yields
        one aligned ``(path_one, path_two, label)`` slice.  Both images are
        decoded with ``get_image`` and displayed with ``plot_pic``.

        The loop runs 10 times over a 3-element list, so it relies on the
        producer cycling through the list (``num_epochs`` is left at its
        default).  The session and the queue-runner threads are always shut
        down, even when reading or plotting raises.
        """
        # Reset the default graph so repeated calls do not accumulate ops.
        tf.reset_default_graph()
        batch_size = 1
        images_one_path_list = ['lda.png', 'snapshot.png', 'hua.jpeg']
        images_two_path_list = ['tuzi.jpg', 'test.png', 'hua.jpeg']
        label_list = ['lad_tuzi', 'snap_test', 'hua']

        # Build the input queue.  ``capacity`` is the queue size; one queue
        # element in this example is a slice such as
        # ['lda.png', 'tuzi.jpg', 'lad_tuzi'].
        train_input_queue = tf.train.slice_input_producer(
            [images_one_path_list, images_two_path_list, label_list],
            capacity=1 * batch_size, shuffle=False)

        # Tensors derived from one dequeued slice.
        img_one_queue = get_image(train_input_queue[0])
        img_two_queue = get_image(train_input_queue[1])
        label_queue = train_input_queue[2]

        # The context manager closes the session on every exit path.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Start the queue-runner threads that feed the input queue.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for _ in range(10):
                    batch_img_one_val, batch_img_two_val, label = sess.run(
                        [img_one_queue, img_two_queue, label_queue])
                    plot_pic(batch_img_one_val, batch_img_two_val, label)
            finally:
                # Stop and join the threads even after an error; otherwise
                # they would keep running in the background.
                coord.request_stop()
                coord.join(threads)
    slice_input_producer_one_sample()
    

    第一个



    第二个



    等等。注意每次读取的两张图片和一个label与输入list之间的对应关系。

    现在把读取到内存中的图片数据加入到一个新的队列中,
    使用tf.train.shuffle_batch。
    取两次图片,每次取三个,这样程序就可以直接从队列中取出已经加载到内存的图片数据。

    import matplotlib.pyplot as plt
    def conver_image_size(img, hsize, wsize):
        """Convert an image tensor to float32 and resize it to a fixed size.

        Args:
            img: image tensor to convert.
            hsize: target height passed to ``tf.image.resize_images``.
            wsize: target width passed to ``tf.image.resize_images``.

        Returns:
            The resized float32 image tensor.
        """
        as_float = tf.image.convert_image_dtype(img, dtype=tf.float32)
        return tf.image.resize_images(as_float, [hsize, wsize])
    
    
    def slice_input_producer_demo():
        """Batch decoded image pairs through a second (shuffling) queue.

        File names and labels are sliced by ``tf.train.slice_input_producer``.
        The decoded images are resized to a common ``hsize`` x ``wsize`` with
        ``conver_image_size`` and pushed into ``tf.train.shuffle_batch``, so
        each ``sess.run`` returns ``batch_size`` already-loaded samples.  Two
        batches are fetched; each is printed and plotted as a grid with one
        row per sample.

        The session and the queue-runner threads are always shut down, even
        when fetching or plotting raises.
        """
        # Reset the default graph so repeated calls do not accumulate ops.
        tf.reset_default_graph()
        # Batch size and the common image size used for batching.
        batch_size = 3
        hsize = 377
        wsize = 500

        images_one_path_list = ['lda.png', 'snapshot.png', 'hua.jpeg']
        images_two_path_list = ['tuzi.jpg', 'test.png', 'hua.jpeg']
        label_list = ['lad_tuzi', 'snap_test', 'hua']

        # Build the input queue of (path_one, path_two, label) slices.
        train_input_queue = tf.train.slice_input_producer(
            [images_one_path_list, images_two_path_list, label_list],
            capacity=3, shuffle=False)

        # Tensors derived from one dequeued slice.
        img_one_queue = get_image(train_input_queue[0])
        img_two_queue = get_image(train_input_queue[1])
        label_queue = train_input_queue[2]

        # Resize both images to one fixed size before they are batched.
        img_one_queue = conver_image_size(img_one_queue, hsize, wsize)
        img_two_queue = conver_image_size(img_two_queue, hsize, wsize)

        # shuffle_batch reads samples from the queue in batches.
        # NOTE(review): min_after_dequeue (10) exceeds the 3 distinct source
        # samples, so a batch can presumably contain repeated samples.
        batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch(
            [img_one_queue, img_two_queue, label_queue],
            batch_size=batch_size,
            capacity=10 + 10 * batch_size,
            min_after_dequeue=10,
            num_threads=16)

        # The context manager closes the session on every exit path.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Start the queue-runner threads that feed both queues.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for _ in range(2):
                    batch_img_one_val, batch_img_two_val, label = sess.run(
                        [batch_img_one, batch_img_two, batch_label])
                    # Single-argument call form works on Python 2 and 3.
                    print(label)
                    fig = plt.figure(figsize=(4, 6))
                    samples = zip(batch_img_one_val, batch_img_two_val, label)
                    for k, (img_one, img_two, name) in enumerate(samples):
                        # Row k: first image on the left, second on the right.
                        ax1 = fig.add_subplot(batch_size, 2, 2 * k + 1)
                        ax1.set_title(name)
                        ax1.imshow(img_one)
                        ax2 = fig.add_subplot(batch_size, 2, 2 * k + 2)
                        ax2.set_title(name)
                        ax2.imshow(img_two)
                    plt.show()
            finally:
                # Stop and join the threads even after an error.
                coord.request_stop()
                coord.join(threads)
    
    第一次
    第二次

    string_input_producer加载序列

    def string_input_producter_demo():
        """Batch images whose file names come from ``string_input_producer``.

        File names are dequeued one at a time, decoded with ``get_image``,
        resized with ``conver_image_size`` and batched by
        ``tf.train.shuffle_batch``.  Two batches are fetched and every image
        in each batch is shown in its own figure.

        Each batch is fetched with a single ``sess.run``.  Running
        ``batch_img_one[k]`` once per image would dequeue a fresh batch on
        every call and throw away all but one image of it.
        """
        tf.reset_default_graph()
        images_one_path_list = ['lda.png', 'snapshot.png', 'hua.jpeg']
        batch_size = 2
        hsize = 377
        wsize = 500
        # Build the file-name queue.
        train_input_queue = tf.train.string_input_producer(
            images_one_path_list, capacity=10 * batch_size)

        # Tensor derived from one dequeued file name.
        img_one_queue = get_image(train_input_queue.dequeue())
        img_one_queue = conver_image_size(img_one_queue, hsize, wsize)

        # Load the decoded image data into a second, batching queue.
        batch_img_one = tf.train.shuffle_batch(
            [img_one_queue],
            batch_size=batch_size,
            capacity=10 + 10 * batch_size,
            min_after_dequeue=10,
            num_threads=16)

        # The context manager closes the session on every exit path.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for _ in range(2):
                    # One run == one dequeued batch; index it locally.
                    batch_val = sess.run(batch_img_one)
                    for k in range(batch_size):
                        plt.figure()
                        plt.imshow(batch_val[k])
                        plt.show()
            finally:
                # Stop and join the threads even after an error.
                coord.request_stop()
                coord.join(threads)
    string_input_producter_demo()
    

    加载CSV文件

    A.csv文件如下
    import tensorflow as tf
    from tensorflow.python.framework import ops
    # Read CSV rows from several files through a file-name queue and fetch
    # them in shuffled batches of ``batch_size``.
    ops.reset_default_graph()

    batch_size = 2
    filenames = ['A.csv', 'B.csv', 'C.csv']

    # Queue of file names; shuffle=False keeps the listed order.
    filename_queue = tf.train.string_input_producer(
        filenames, shuffle=False)
    # Define the reader that pulls records from the file-name queue.
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)

    # Define the decoder: two columns, each defaulting to the string 'null'.
    example, label = tf.decode_csv(
        value, record_defaults=[['null'], ['null']])
    batch_data, label_data = tf.train.shuffle_batch(
        [example, label],
        batch_size=batch_size,
        capacity=10 + 10 * batch_size,
        min_after_dequeue=10,
        num_threads=16)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Create a coordinator to manage the threads.
        coord = tf.train.Coordinator()
        # Start the queue runners; from here on the queues are being filled.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(9):
                batch_, label_ = sess.run([batch_data, label_data])
                # Single-argument call form works on Python 2 and 3.
                print(batch_)
                print(label_)
                print('-----')
        finally:
            # Stop and join the threads even after an error.
            coord.request_stop()
            coord.join(threads)
    

    每次从队列中取出两条数据


    参考
    Tensorflow 数据预读取--Queue

    相关文章

      网友评论

      本文标题:tensorflow动态加载文件

      本文链接:https://www.haomeiwen.com/subject/cebqjftx.html