美文网首页
TFRecord 读取和写入

TFRecord 读取和写入

作者: JeremyL | 来源:发表于2022-12-17 17:52 被阅读0次

TFRecord 是tensorflow中用于存储二进制数据的一种简单格式。

# 1. 写入TFRecord文件

import tensorflow as tf
import numpy as np

## \##将数据转换为tf.Example接受的格式
#The following functions can be used to convert a value to a type compatible with tf.Example.
def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#数据格式转换例子
print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))
print(_float_feature(np.exp(1)))
print(_int64_feature(True))
print(_int64_feature(1))
feature = _float_feature(np.exp(1))
feature.SerializeToString() #序列化

## \##将数据转换为tf.Example接受的格式
def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.Example message ready to be written to a file.
  """
  # Create a dictionary mapping the feature name to the tf.Example-compatible
  # data type.
  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }

  # Create a Features message using tf.train.Example.
  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

## \##例子
# This is an example observation from the dataset.
example_observation = []
serialized_example = serialize_example(False, 4, 'goat', 0.9876)
serialized_example

#tf.train.Example.FromString可以解码二进制信息
example_proto = tf.train.Example.FromString(serialized_example)
example_proto

#写入文件
filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialize_example(1, 4, "yummy food.", 0.9876))
writer.close()

filename = 'test.tfrecord'
writer = tf.io.TFRecordWriter(filename)
writer.write(serialize_example(False, 4, b'goat', 0.9876))
writer.close()

# The number of observations in the dataset.
n_observations = int(1e4)
# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)
# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)
# String feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]
# Float feature, from a standard normal distribution
feature3 = np.random.randn(n_observations)
# Write the `tf.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)

# 2. 读取TFRecord文件

filenames = ['test.tfrecord']
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset

for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)

# 3. 实例

cat_in_snow  = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
import IPython.display as display
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))
display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))
## \## 写入 TFRecord 文件
image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}
# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()

label = image_labels[cat_in_snow]

# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
  image_shape = tf.image.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')
# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for filename, label in image_labels.items():
    image_string = open(filename, 'rb').read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())

## \## 读取TFRecord 文件
raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Create a dictionary describing the features.
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
## \##从 TFRecord 文件中恢复图像:
for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))

# 4. tf.Example

tf.Example 中储存着 {"string": tf.train.Feature}
tf.train.Feature接受三种数据类型

1.  tf.train.BytesList
    *   `string`
    *   `byte`
2.  tf.train.FloatList
    *   `float` (`float32`)
    *   `double` (`float64`)
3.  tf.train.Int64List
    *   `bool`
    *   `enum`
    *   `int32`
    *   `uint32`
    *   `int64`
    *   `uint64`

# 参考:

Tensorflow之TFRecord的原理和使用心得
TFRecord 的保存与读取
TFRecord and tf.train.Example

相关文章

网友评论

      本文标题:TFRecord 读取和写入

      本文链接:https://www.haomeiwen.com/subject/zqbmfdtx.html