1. TensorFlow in Practice: Reading Data with TFRecord

Author: 闪电侠悟空 | Published: 2017-11-30 15:25

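    For the basics of reading data, see the Input Pipeline part of CS 20SI. TensorFlow has two main ways of loading data:

    1. Feeding: declare placeholders, then pass the data in through the session (feed_dict) at run time.
    2. Reading from files: no explicit placeholders are used. Records read from files are pushed into a queue, and the dequeued tensors, after decoding and type conversion (for example with tf.cast), flow directly into the TensorFlow Graph.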
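
    The producer/consumer idea behind the second approach can be illustrated without TensorFlow. The following is a minimal sketch in plain Python, not TensorFlow's implementation: a background thread fills a bounded queue while the consumer pulls records from it and groups them into batches. The names producer, read_batches, batch_size and capacity are hypothetical and chosen only for this illustration.

```python
import queue
import threading

_DONE = object()  # sentinel marking the end of the input


def producer(records, q):
    """Background thread body: push every record, then the end marker."""
    for record in records:
        q.put(record)  # blocks while the queue is full
    q.put(_DONE)


def read_batches(records, batch_size, capacity):
    """Consume records through a bounded queue and group them into batches."""
    q = queue.Queue(maxsize=capacity)
    worker = threading.Thread(target=producer, args=(records, q), daemon=True)
    worker.start()

    batches, batch = [], []
    while True:
        item = q.get()  # blocks while the queue is empty
        if item is _DONE:
            break
        batch.append(item)
        if len(batch) == batch_size:
            batches.append(batch)
            batch = []
    if batch:  # keep the final, possibly smaller, batch
        batches.append(batch)
    worker.join()
    return batches


if __name__ == "__main__":
    # prints [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    print(read_batches(range(10), batch_size=4, capacity=3))
```

    The bounded queue is what keeps memory use fixed: the reader thread stalls when the queue is full, and the consumer stalls when it is empty, so neither side has to know how large the input is.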

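    For very large data files the second approach is the usual choice, preferably with TensorFlow's own TFRecord file format. TFRecord is fairly simple to use (for a solid understanding of the queue and multithreading mechanisms behind it, see the slides provided by the CS20SI course). Reference pages for the implementation: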

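    Usage notes for tf.train.shuffle_batch

    Understanding tf.train.batch and tf.train.shuffle_batch: explains in detail how the two batch functions differ. The larger min_after_dequeue is, the more thoroughly the data is shuffled; for efficiency, I personally find that keeping it at 1/2 to 3/4 of capacity shuffles well enough.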
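
    The effect of min_after_dequeue on shuffling can be checked with a toy model. This is a sketch under simplifying assumptions, not TensorFlow's shuffle queue: the consumer is assumed to be always ready, so capacity is not modeled, and a random buffered element is emitted whenever more than min_after_dequeue elements are waiting. The functions simulate_shuffle_queue and mean_displacement are hypothetical helpers written for this illustration.

```python
import random


def simulate_shuffle_queue(items, min_after_dequeue, seed=0):
    """Toy shuffle queue: emit a random buffered element whenever more than
    min_after_dequeue elements are waiting, then drain what is left."""
    if min_after_dequeue < 0:
        raise ValueError("min_after_dequeue must be >= 0")
    rng = random.Random(seed)
    buffer, output = [], []

    def pop_random():
        idx = rng.randrange(len(buffer))
        # Swap the chosen element to the end so that pop() is O(1)
        buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
        output.append(buffer.pop())

    for item in items:
        buffer.append(item)
        while len(buffer) > min_after_dequeue:
            pop_random()
    while buffer:  # end of input: flush the remaining elements
        pop_random()
    return output


def mean_displacement(output):
    """Average distance between an element's output position and its value,
    assuming the input was 0, 1, 2, ... in order."""
    return sum(abs(pos - value) for pos, value in enumerate(output)) / len(output)


if __name__ == "__main__":
    for m in (0, 10, 500, 750):
        shuffled = simulate_shuffle_queue(range(1000), m)
        print(m, round(mean_displacement(shuffled), 1))
```

    With min_after_dequeue set to 0 the stream comes out in its original order. Larger values let each element travel further from its original position, at the cost of holding more elements in memory before the first batch is available.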

    Write and read experiment

    Input image and its bright-pixel ratio label

    The full implementation follows:

    '''
    TFRecord Study
    Writing data to and reading data from a TFRecord file
    Author: 闪电侠悟空
    Date: 2017-11-30
    '''
    from time import sleep
    
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    TRAIN_NUM = 10000
    
    def write2tfrecord():
        '''
        Write to TFRecord files
        '''
        # Step 1. construct the TFRecord Writer
        writer = tf.python_io.TFRecordWriter(path='IMBD.tfrecords')
    
        for (threshold,i) in zip(np.linspace(0,1,TRAIN_NUM,dtype=np.float32),range(TRAIN_NUM)):
            print(threshold,'and ', i,'is saving!')
            prob = np.random.uniform(0,1,[64,64]) # construct the data set
            image = np.uint8(prob<threshold)*255
            print(type(image[1,1]))
    
            # Step 2. to bytes (tobytes() replaces the deprecated tostring())
            image_raw = image.tobytes()
    
            # Step 3. construct the example
            y = tf.train.Feature(float_list=tf.train.FloatList(value=[threshold]))
            x = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw]))
            z = tf.train.Feature(int64_list=tf.train.Int64List(value = [i]))
    
            example = tf.train.Example(features=tf.train.Features(feature = {"percent":y,"number":z,"raw_image":x }))
    
            # Step 4. write the example to the file
            writer.write(example.SerializeToString())
    
        # Step 5. close the writer
        writer.close()
    
    
    def readanddecode():
        filename_queue = tf.train.string_input_producer(['IMBD.tfrecords'])
    
        reader = tf.TFRecordReader()
        _,serialized_example = reader.read(filename_queue) # Returns (key, serialized record)
        features = tf.parse_single_example(serialized_example,features={"percent":tf.FixedLenFeature([],tf.float32),
                                                                        "number":tf.FixedLenFeature([],tf.int64),
                                                                        "raw_image":tf.FixedLenFeature([],tf.string)})
        img = tf.decode_raw(features["raw_image"],tf.uint8)
        img = tf.reshape(img,[64,64])
        img = tf.cast(img,tf.uint8)
    
        i = tf.cast(features["number"],tf.int64)
        percent = tf.cast(features["percent"],tf.float32)
        return  img,i,percent
    
    def mainloop():
        img, i, percent = readanddecode()  # get a single example
        img_batch,i_batch = tf.train.shuffle_batch([img,i],batch_size=20,capacity=10000,min_after_dequeue=9999)
    
        init = tf.global_variables_initializer()
    
        with tf.Session() as sess:
            sess.run(init)
            # A Coordinator lets the queue-runner threads be stopped cleanly
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for j in range(502):
                    val_images, val_is = sess.run([img_batch,i_batch])
                    print(val_is)
            finally:
                coord.request_stop()
                coord.join(threads)
    
    if __name__ =="__main__":
        #write2tfrecord()
        #readanddecode()
        mainloop()
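
    The byte round trip used above (array to bytes when writing, tf.decode_raw plus tf.reshape when reading) can be checked with NumPy alone. The sketch below is an illustration, not part of the original script: make_sample and from_bytes are hypothetical helper names, and the threshold of 0.3 and the seed are arbitrary choices. It also shows why the label works: the fraction of bright pixels stays close to the threshold used to generate the image.

```python
import numpy as np


def make_sample(threshold, rng, size=64):
    """Build one binary image whose bright-pixel fraction is close to threshold."""
    prob = rng.uniform(0, 1, [size, size])
    return (prob < threshold).astype(np.uint8) * np.uint8(255)


def from_bytes(raw, shape, dtype=np.uint8):
    """Rebuild the array; raw bytes carry neither dtype nor shape,
    so the reader has to supply both."""
    return np.frombuffer(raw, dtype=dtype).reshape(shape)


rng = np.random.RandomState(0)
threshold = 0.3
image = make_sample(threshold, rng)

raw = image.tobytes()                  # what goes into the BytesList feature
restored = from_bytes(raw, (64, 64))   # mirrors decode_raw followed by reshape

bright_fraction = float((image == 255).mean())
print(len(raw), np.array_equal(restored, image), bright_fraction)
```

    Because the serialized bytes are just the flattened pixel values, decoding with a different dtype or reshaping to a different size would silently produce a wrong image or fail, which is why the reader hard-codes tf.uint8 and [64,64].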
    
