美文网首页
TFRecord 统一输入数据格式和组合数据

TFRecord 统一输入数据格式和组合数据

作者: huim | 来源:发表于2018-12-15 14:53 被阅读0次

    TF 提供了一种统一输入数据的格式—— TFRecord ~
    它有两个特别好的优点:
    1.可以将一个样本的所有信息统一起来存储,这些信息可以是不同的数据类型;
    2.利用文件队列的多线程操作,使得数据的读取和批量处理更加方便快捷。

    part 1 获得数据

    从 CelebA 数据集的20多万个数据中,得到每一个样本的图像及对应的标签,用作图像分类的训练和测试数据:

    def get_data(txt_path,img_path):
    imgs = []
    labels = []
    with open(txt_path) as f:
        # 解压后的 list_attr_celeba.txt 文件从第三行开始是数据内容
        line = f.readline() # 第一行
        line = f.readline() # 第二行
        line = f.readline() # 第三行
        while line:
            array = line.split()
            file_name = array[0]
            # print(file_name)
            img = cv2.imread(img_path+file_name)
            img = cv2.resize(img,(96,128))
            imgs.append(img)
            label = np.zeros([5,2]) 
            for i,idx in enumerate([16,35,36,38,39]):
                l = int(array[idx])
                if l == 1:
                    label[i,1] = 1
                else:
                    label[i,0] = 1
            labels.append(label)
            line = f.readline()
        print('Data prepared!')
    return imgs,labels
    

    调用上面定义的 get_data()函数,得到 images 和 labels(这里label取了5类,判断人脸是否含有帽子/眼镜/项链/耳环/领带等装饰):

    txt_path = r'E:/celeA/list_attr_celeba.txt'
    img_path = r'E:/celeA/img_align_celeba/'
    imgs,labels=get_data(txt_path,img_path)
    len(imgs),len(labels) # (202599, 202599)
    

    part 2 创建一个 writer 将数据写入TFRecord文件

    TFRcord 文件中的数据都是通过 tf.train.Example() 定义的,其中包含了一个从属性名称到取值的字典。
    属性名称为一个字符串,属性的取值可以为字符串(BytesList)/ 实数列表(FloatList)/ 整数列表(Int64List)。

    # 生成字符串型的属性。
    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    # 生成整数型的属性。
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
    # 生成实数型的属性。
    def _float_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    

    如:imgs 的数据类型为 uint8,而labels的数据类型为float64,tfrecord可以将图片及其对应的标签编码成字符串,作为 tfrecord 文件中的一条数据。下面取前20万数据做为训练数据,写入20个文件,每个文件记录10000条数据,剩下的作为测试数据:

    num_shards = 20 # 文件数
    instances_per_shard = 10000 # 每个文件包含的数据量
    
    for i in range(num_shards):
        # 文件名如'E:/celeA/data/data.tfrecords-00000-of-00100'
        filename = ('E:/celeA/all_data/test.tfrecords-%.5d-of-%.5d'%(i,num_shards-1))
    writer = tf.python_io.TFRecordWriter(filename)
    for j in range(instances_per_shard):
        # 将图像和标签转化成字符串
        image_raw = test_x[instances_per_shard*i+j].tostring()
        label_raw = test_y[instances_per_shard*i+j].tostring()
        # 将图像和标签数据作为一个example
        example = tf.train.Example(features=tf.train.Features(feature={
            'image':_bytes_feature(image_raw),
            'label':_bytes_feature(label_raw)
        }))
        writer.write(example.SerializeToString())
    writer.close()
    

    part 3 创建一个reader来读取TFRecord文件

    files=tf.train.match_filenames_once('E:/celeA/all_data/data.tfrecords-*')
    # 文件队列,方便利用多线程管理原始文件列表
    filename_queue = tf.train.string_input_producer(files,shuffle=False)
    
    reader = tf.TFRecordReader()
    # 解析读入的单个数据
    _,serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
    serialized_example,
    features={
        # tf.FixedLenFeature() 是一种属性解析方法,解析结果为一个Tensor
        'image':tf.FixedLenFeature([],tf.string),
        'label':tf.FixedLenFeature([],tf.string) # 这里的数据格式要和写入时一样
    })
    
    # tf.decode_raw() 用于解析字符串
    # tf.cast() 转换数据类型
    img = tf.decode_raw(features['image'],tf.uint8)
    image = tf.reshape(tf.cast(img,tf.float32), [96,128,3])
    l = tf.decode_raw(features['label'],tf.float64)
    label = tf.reshape(tf.cast(l,tf.float32), [5,2])
    
    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        # print(sess.run(files))
    
        # 启用多线程
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)    
    
        for i in range(80):
            im,la = sess.run([image,label])
            print(im.shape,la.dtype) # (96, 128, 3) float32
    
        coord.request_stop()
        coord.join(threads)
    

    注意:在解析字符串时,解析的数据类型如果和原始数据的数据类型不一样,解析得到的结果就和原始数据不同,所以在读写 tfrecord 文件时一定要明确原始数据类型。这里image的原始数据类型是uint8,为了作为在tensorflow中网络的输入数据(一般是 tf.float32),利用tf.cast()函数将数据类型转换成 tf.float32 ,label 亦然。

    还有一点需要注意:
    因为用到的文件队列操作,这里需要开启多线程(指定线程数量,默认为1)。

    part 4 组合数据 batching

    在训练网络时,通常将训练数据分成小批量的数据进行训练,这样能够提高模型训练效率。tensorflow 提供了tf.train.batch()tf.train.shuffle_batch函数来将组织小批量数据。

    batch_size = 64
    
    min_after_dequeue = 64 # 定义出队时最少元素个数来保证随机打乱的顺序
    capacity = min_after_dequeue+3*batch_size # batch 队列中最多可以存储的数据个数
    batch_x,batch_y = tf.train.shuffle_batch([image,label],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue)
    
    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        # print(sess.run(files))
    
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)    
    
        for i in range(10):
            b_x,b_y = sess.run([batch_x,batch_y])
            print(b_x,b_y)
    
        coord.request_stop()
        coord.join(threads)
    

    part 5 输入数据处理框架

    1.生成用于训练和测试的 tfrecord 文件
    2.定义计算图
    3.开启会话,在训练过程中,从不同文件中读取小批量数据(是否按顺序,可选)进行训练/验证

    相关文章

      网友评论

          本文标题:TFRecord 统一输入数据格式和组合数据

          本文链接:https://www.haomeiwen.com/subject/rvqahqtx.html