TFRecord, this little hammer


Author: 林桉 | Published 2019-12-16 15:53

    What is TFRecord?

    TFRecord is Google's officially recommended data format, designed by Google specifically for TensorFlow.


    In practice, a TFRecord file is a binary file, which makes better use of memory and disk than raw formats. It contains a series of tf.train.Example messages, where Example is an implementation of the protocol buffer (protobuf) data standard. Each Example message holds a set of tf.train.Feature attributes, and each Feature is a key-value pair: the key is a string, and the value takes one of three forms:

    bytes_list: can store the string and byte types.
    float_list: can store the float (float32) and double (float64) types.
    int64_list: can store bool, enum, int32, uint32, int64, and uint64.

    It is worth mentioning that .proto files appear all over the TensorFlow source code; they define TensorFlow's core data structures, and multiple languages can consume this data directly, which is quite powerful.
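    As a side note on the binary format itself: each record in a TFRecord file is framed as a little-endian uint64 length, a masked CRC-32C of that length, the payload bytes (the serialized Example), and a masked CRC-32C of the payload. Below is a minimal plain-Python sketch of that framing, just to make the layout concrete; in real code you would let TensorFlow's writer handle all of this:

```python
import struct

def crc32c(data):
    # Bitwise CRC-32C (Castagnoli polynomial), the checksum TFRecord uses.
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF

def masked_crc(data):
    # TFRecord stores a rotated, offset "masked" CRC, not the raw value.
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF

def frame_record(payload):
    # length (uint64) + masked CRC of length (uint32)
    # + payload + masked CRC of payload (uint32)
    length = struct.pack('<Q', len(payload))
    return (length
            + struct.pack('<I', masked_crc(length))
            + payload
            + struct.pack('<I', masked_crc(payload)))

record = frame_record(b'serialized Example bytes')
print(len(record))  # 8 + 4 + 24 + 4 = 40
```

    The rotation and the constant 0xA282EAD8 come from the TFRecord format definition; the per-record checksums are what let readers detect corruption in the middle of a file.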

    Protobuf, this little hammer

    Advantages:

    • Platform-neutral, language-neutral, and extensible;
    • Friendly runtime libraries that are simple to use;
    • Fast parsing, roughly 20-100x faster than the equivalent XML;
    • Very compact serialized data, roughly 1/3 to 1/10 the size of the equivalent XML.
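    Much of that compactness comes from protobuf's wire format: fields are tagged by number rather than by name, and integers are varint-encoded. A small hand-rolled sketch makes the size difference concrete (the field number and the XML comparison here are purely illustrative):

```python
def encode_varint(n):
    # Protobuf varint: 7 bits per byte, MSB set on all but the last byte.
    out = bytearray()
    while True:
        byte = n & 0x7F
        n >>= 7
        if n:
            out.append(byte | 0x80)
        else:
            out.append(byte)
            return bytes(out)

def encode_int_field(field_number, value):
    # Field key = (field_number << 3) | wire_type; wire type 0 is varint.
    return encode_varint(field_number << 3) + encode_varint(value)

# The classic protobuf example: field 1 set to 150 encodes to 3 bytes.
pb = encode_int_field(1, 150)
print(pb.hex())                   # 089601
xml = b"<userid>150</userid>"     # the same value as a verbose XML element
print(len(pb), len(xml))          # 3 vs 20
```

    Three bytes versus a 20-byte XML element; since no field names are stored on the wire, the savings grow with message size.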



    Installation guide: https://blog.csdn.net/xxjuanq_only_one/article/details/50465272

    import "Common.proto";      // import Common.proto from the Protobuf SDK
    
    option optimize_for = LITE_RUNTIME;
    
    option java_package = "com.xxxx.entity.pb";    // package of the generated class
    option java_outer_classname = "PayInfo";       // name of the generated class
    
    message PayInfo{
        required string payid = 1;             // payment-related field
        optional string goodinfo = 2;          // optional: the field may be omitted
        required string prepayid = 3;          // required: the field must be set
        optional string mode = 4;
        optional int32  userid = 5;            // proto2 has no plain `int` type; use int32
        repeated string  extra = 6;            // repeated: a list of values
    } 
    

    protoc --java_out ./ ./PayInfo.proto

    Why use TFRecord, this little hammer?

    TFRecord is not the only data format TensorFlow supports; you can also use CSV, plain text, and other formats. For TensorFlow, however, TFRecord is the friendliest and most convenient. As mentioned above, a TFRecord file contains a series of Example messages implementing the protocol buffer standard, and for large datasets protocol buffers have a clear advantage over the alternatives.

    Converting to TFRecord

    writer = tf.python_io.TFRecordWriter(out_file_name)  # 1. create the writer
    
    for data in dataes:
        context = data[0]     # index the current sample, not the whole list
        question = data[1]
        answer = data[2]
    
        """ 2. build the features """
        example = tf.train.Example(
            features = tf.train.Features(
                feature = {
                   'context': tf.train.Feature(
                     int64_list=tf.train.Int64List(value=context)),
                   'question': tf.train.Feature(
                     int64_list=tf.train.Int64List(value=question)),
                   'answer': tf.train.Feature(
                     int64_list=tf.train.Int64List(value=answer))
                }))
    
        """ 3. serialize each example and write it to the file """
        writer.write(example.SerializeToString())
    
    writer.close()  # flush and close the file when done
    

    The reading API

    https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset

    Here's an example 🌰

    from __future__ import absolute_import, division, print_function
    
    import csv
    import requests
    import tensorflow as tf
    # Download Titanic dataset (in csv format).
    d = requests.get("https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/titanic_dataset.csv")
    with open("titanic_dataset.csv", "wb") as f:
        f.write(d.content)
    # Generate Integer Features.
    def build_int64_feature(data):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))
    
    # Generate Float Features.
    def build_float_feature(data):
        return tf.train.Feature(float_list=tf.train.FloatList(value=[data]))
    
    # Generate String Features.
    def build_string_feature(data):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))
    
    # Generate a TF `Example`, parsing all features of the dataset.
    def convert_to_tfexample(survived, pclass, name, sex, age, sibsp, parch, ticket, fare):
        return tf.train.Example(
            features=tf.train.Features(
                feature={
                    'survived': build_int64_feature(survived),
                    'pclass': build_int64_feature(pclass),
                    'name': build_string_feature(name),
                    'sex': build_string_feature(sex),
                    'age': build_float_feature(age),
                    'sibsp': build_int64_feature(sibsp),
                    'parch': build_int64_feature(parch),
                    'ticket': build_string_feature(ticket),
                    'fare': build_float_feature(fare),
                })
        )
    
    # Open dataset file.
    with open("titanic_dataset.csv") as f:
        # Output TFRecord file.
        with tf.io.TFRecordWriter("titanic_dataset.tfrecord") as w:
        # Generate a TF Example for each row in our dataset.
            # CSV reader will read and parse all rows.
            reader = csv.reader(f, skipinitialspace=True)
            for i, record in enumerate(reader):
                # Skip header.
                if i == 0:
                    continue
                survived, pclass, name, sex, age, sibsp, parch, ticket, fare = record
                # Parse each csv row to TF Example using the above functions.
                example = convert_to_tfexample(int(survived), int(pclass), name, sex, float(age), int(sibsp), int(parch), ticket, float(fare))
                # Serialize each TF Example to string, and write to TFRecord file.
                w.write(example.SerializeToString())
    # Build features template, with types.
    features = {
        'survived': tf.io.FixedLenFeature([], tf.int64),
        'pclass': tf.io.FixedLenFeature([], tf.int64),
        'name': tf.io.FixedLenFeature([], tf.string),
        'sex': tf.io.FixedLenFeature([], tf.string),
        'age': tf.io.FixedLenFeature([], tf.float32),
        'sibsp': tf.io.FixedLenFeature([], tf.int64),
        'parch': tf.io.FixedLenFeature([], tf.int64),
        'ticket': tf.io.FixedLenFeature([], tf.string),
        'fare': tf.io.FixedLenFeature([], tf.float32),
    }
    
    # Create a TensorFlow session (this example uses the TF 1.x graph-mode API).
    sess = tf.Session()
    
    # Load TFRecord data.
    filenames = ["titanic_dataset.tfrecord"]
    data = tf.data.TFRecordDataset(filenames)
    
    # Parse features, using the above template.
    def parse_record(record):
        return tf.io.parse_single_example(record, features=features)
    # Apply the parsing to each record from the dataset.
    data = data.map(parse_record)
    
    # Refill data indefinitely.
    data = data.repeat()
    # Shuffle data.
    data = data.shuffle(buffer_size=1000)
    # Batch data (aggregate records together).
    data = data.batch(batch_size=4)
    # Prefetch batch (pre-load batch for faster consumption).
    data = data.prefetch(buffer_size=1)
    
    # Create an iterator over the dataset.
    iterator = data.make_initializable_iterator()
    # Initialize the iterator.
    sess.run(iterator.initializer)
    
    # Get next data batch.
    x = iterator.get_next()
    
    # Dequeue data and display.
    for i in range(3):
        print(sess.run(x))
        print("")
    

    Original link: https://www.haomeiwen.com/subject/umldnctx.html