美文网首页
TFRecord实践

TFRecord实践

作者: 钢镚儿_e134 | 来源:发表于2020-08-06 14:01 被阅读0次

在工作中,需要在训练模型的过程中,读入大规模稀疏矩阵,因此考虑用tfrecord进行加载

1.生TFRecord

import tensorflow as tf
import numpy as np
"""
txt文件中保存的是矩阵每一行的行坐标,列坐标,以及元素值
数据格式为:‘行坐标’ + ‘[对应所有列坐标]’ + ‘[对应所有元素值]’
"""
def write_TFRecord(srcpath, dstpath):  
    writer = tf.python_io.TFRecordWriter(dstpath)
    f = open(srcpath)
    line = f.readline()
    while line:
        line_ = line.strip().split('\t')
        cols = eval(line[1])
        vals = eval(line[2])
        rows = [int(line[0])]
        features = tf.train.Features(
            feature={'rows': tf.train.Feature(int64_list=tf.train.Int64List(value=rows)),
                    'photos': tf.train.Feature(int64_list=tf.train.Int64List(value=cols)),
                    'vals': tf.train.Feature(float_list=tf.train.FloatList(value=vals))
                    })
            
        example = tf.train.Example(features=features)
        writer.write(example.SerializeToString())
    writer.close()          
    f.close()

2. 利用tf.data.TFRecordDataset接口进行解析

2.1 将每一行的值解析为稠密张量

def parser(example):
    dicts = {
        'rows': tf.FixedLenFeature(shape=[],dtype=tf.int64),
        'cols': tf.VarLenFeature(dtype=tf.int64), #由于cols为变长,需要使用 tf.VarLenFeature
        'vals': tf.VarLenFeature(dtype=tf.float32)
    }
    parsed_example = tf.parse_single_example(example, dicts)
    rows = parsed_example['rows']
    cols = parsed_example['cols']
    vals = parsed_example['vals']
    return rows, tf.sparse_tensor_to_dense(rows), tf.sparse_tensor_to_dense(vals)
# 采用这种方式,返回的是稀疏张量,需要用tf.sparse_tensor_to_dense转化为稠密张量

def get_batch_dataset(recordfile, parser):
    dataset = tf.data.TFRecordDataset(recordfile).map(parser).padded_batch(2, padded_shapes=([],[None],[None]))
# 由于row_index跟vals均不为定长,无法进行batch,所以需要对其进行填充,将短的张量用0填充,直到其长度与batch中最长的张量相等
    return dataset

dataset = get_batch_dataset('tfrecord',  parser)

2.2 直接读取为稀疏张量

def parser1(example):
    my_example_features = {'sparse': tf.SparseFeature(index_key=['rows', 'cols'],
                                                  value_key='vals',
                                                  dtype=tf.float32,
                                                  size=[1,max_col])} #size[0]表示一行,size[1]表示稀疏矩阵的列数
    parsed_example = tf.parse_single_example(example, my_example_features)
    return parsed_example['sparse']

def get_batch_dataset(recordfile, parser):                            
    dataset = tf.data.TFRecordDataset(recordfile).map(parser).repeat(2).batch(20000)
    return dataset

dataset = get_batch_dataset('tfrecord',  parser)

3. Iterator的使用

相关文章

  • TFRecord实践

    在工作中,需要在训练模型的过程中,读入大规模稀疏矩阵,因此考虑用tfrecord进行加载 1.生TFRecord ...

  • (Tensorflow)TFRecord样例程序读写

    一:TFRecord样例程序读 二:TFRecord样例程序写

  • tensorflow中的读写tfrecord文件

    之前一直对tfrecord不是很懂,今天终于弄明白了,果然是实践出真知。先留个坑,日后再填。

  • tfrecord文件读写

    将数据集保存为tfrecord文件 tensorflow读取tfrecord文件用于网络训练 参考文章:Tenso...

  • TFRecord

    TFRecord 是 TensorFlow 提供的用于高速读取数据的文件格式。该post介绍了如何将数据转换为TF...

  • tfrecord这个锤锤

    什么是TFRecord? TFRecord 是Google官方推荐的一种数据格式,是Google专门为Tensor...

  • TFRecord 读取和写入

    TFRecord 是tensorflow中用于存储二进制数据的一种简单格式。 # 1. 写入TFRecord文件 ...

  • Tensorflow(一) TFRecord生成与读取

    TFRecord生成 一、为什么使用TFRecord? 正常情况下我们训练文件夹经常会生成 train, test...

  • tfrecord开启多线程

    关于什么时候使用在tfrecord文件方面上使用多线程问题,tensorflow封装了一套对tfrecord多进程...

  • Tensorflow 自定义生成tfrecord文件

    一.TFRecord简介 TensorFlow提供了TFRecord的格式来统一存储数据,它是一种能够将图像数据和...

网友评论

      本文标题:TFRecord实践

      本文链接:https://www.haomeiwen.com/subject/gcabrktx.html