TFRecordReader 读取的是 TFRecord 格式的文件,文件中的每条记录是一个序列化后的 protobuf 消息(本例中为 tf.train.Example),因此可以在不同语言之间交换数据,主要用于存储机器学习的特征数据。配合 ReaderRead 使用时,TFRecordReader 每次只读取一条记录,也就是逐条(per record)读取。
Python 侧生成 TFRecord 文件的代码如下:
test_create_tf_record.py
import tensorflow as tf
import numpy as np
import json
# Output path of the generated TFRecord file.
# NOTE: the C++ test reads /cppwork/_tf/test_parse_ops/data/train.tfrecord,
# so copy the file there (or make the two paths agree) before running it.
tfrecord_filename = '/tmp/train.tfrecord'

# Write 100 tf.train.Example records. The `with` block guarantees the writer
# is flushed and closed even if building or serializing a record raises.
with tf.io.TFRecordWriter(tfrecord_filename) as writer:
    for i in range(100):
        # 30x7 array of random integers in [0, 255]. randint's upper bound is
        # exclusive, hence 256 (np.random.random_integers is deprecated).
        img_raw = np.random.randint(0, 256, size=(30, 7))
        # Encode the array as UTF-8 JSON bytes so it fits in a BytesList.
        img_raw = bytes(json.dumps(img_raw.tolist()), "utf-8")
        example = tf.train.Example(features=tf.train.Features(
            feature={
                # Int64List stores integer data; the label is the record index.
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
                # BytesList stores raw byte strings.
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
        # Serialize the protobuf message and append it as one record.
        writer.write(example.SerializeToString())
程序目录结构如下:

程序代码如下:
CMakeLists.txt
# Minimum version: CMAKE_CXX_STANDARD 17 requires CMake 3.8 and the
# IMPORTED_TARGET option of pkg_search_module requires CMake 3.6
# (CMake 4.x also rejects minimum versions below 3.5).
cmake_minimum_required(VERSION 3.8)
project(test_parse_ops)

# Also search /usr/local/lib/pkgconfig so pkg-config can find the locally
# installed Arrow/Parquet .pc files.
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")
set(CMAKE_CXX_STANDARD 17)
# Always emit debug info. -g is a compiler option, not a preprocessor
# definition, so add_compile_options is the appropriate command.
add_compile_options(-g)

# Conan-generated build info; conan_basic_setup() provides CONAN_LIBS used below.
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS ${PKG_PARQUET_INCLUDE_DIRS} ${PKG_ARROW_INCLUDE_DIRS} ${PKG_ARROW_COMPUTE_INCLUDE_DIRS} ${PKG_ARROW_CSV_INCLUDE_DIRS} ${PKG_ARROW_DATASET_INCLUDE_DIRS} ${PKG_ARROW_FS_INCLUDE_DIRS} ${PKG_ARROW_JSON_INCLUDE_DIRS})
set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})
set(ARROW_LIBS PkgConfig::PKG_PARQUET PkgConfig::PKG_ARROW PkgConfig::PKG_ARROW_COMPUTE PkgConfig::PKG_ARROW_CSV PkgConfig::PKG_ARROW_DATASET PkgConfig::PKG_ARROW_FS PkgConfig::PKG_ARROW_JSON)
include_directories(${INCLUDE_DIRS})

# Every *.cpp in this directory becomes its own test executable.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
# Shared helper sources, compiled once into a shared library that every
# test executable links against.
file(GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)
add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})

foreach(test_file ${test_file_list})
    # Target name = file name with only the trailing ".cpp" removed,
    # e.g. tf_record_reader_test.cpp -> tf_record_reader_test.
    get_filename_component(test_name ${test_file} NAME)
    string(REGEX REPLACE "\\.cpp$" "" test_name ${test_name})
    add_executable(${test_name} ${test_file})
    target_link_libraries(${test_name} PUBLIC ${PROJECT_NAME}_lib)
endforeach()
tf_record_reader_test.cpp
#include <string>
#include <vector>
#include <array>
#include <fstream>
#include <glog/logging.h>
#include "tensorflow/core/platform/test.h"
#include "death_handler/death_handler.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
#include "tf_/queue_runner.h"
using namespace tensorflow;
// Test entry point: configures glog, installs the project's death handler,
// and runs every registered gtest case.
int main(int argc, char** argv) {
  // Write log files into the current directory and mirror them to stderr.
  FLAGS_log_dir = "./";
  FLAGS_alsologtostderr = true;
  // Minimum severity that is logged: INFO=0, WARNING=1, ERROR=2, FATAL=3.
  FLAGS_minloglevel = 0;

  // NOTE(review): project type; presumably installs a crash/stack-trace
  // handler for the lifetime of main — confirm in death_handler.h.
  Debug::DeathHandler dh;

  // InitGoogleLogging expects the program name (argv[0]), from which glog
  // derives the log file names; it is not a log-file path. Fall back to a
  // fixed name if the process was started with an empty argv.
  const char* program_name =
      (argc > 0 && argv[0] != nullptr) ? argv[0] : "tf_record_reader_test";
  google::InitGoogleLogging(program_name);

  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
// Reads the TFRecord file produced by test_create_tf_record.py with
// ops::TFRecordReader and decodes every record with ops::ParseSingleExample.
// NOTE(review): despite the test name, this exercises TFRecordReader, not
// FixedLengthRecordReader.
TEST(TfArrayOpsTests, FixLenRecordReader) {
  Scope root = Scope::NewRootScope();

  // Filename queue that feeds the reader.
  auto attrs = ops::FIFOQueue::Capacity(200);
  auto queue_ = ops::FIFOQueue(root.WithOpName("queue"), {DT_STRING}, attrs);
  // NOTE(review): the Python script writes /tmp/train.tfrecord; the file must
  // be copied to this path (or the paths unified) before running the test.
  auto tensor_ = ops::Const(root, {"/cppwork/_tf/test_parse_ops/data/train.tfrecord"});
  auto enque_ = ops::QueueEnqueueMany(root.WithOpName("enque"), queue_, {tensor_});
  auto close_ = ops::QueueClose(root.WithOpName("close"), queue_);

  // ReaderRead yields one serialized record (read_res.value) per Run.
  auto reader = ops::TFRecordReader(root);
  auto read_res = ops::ReaderRead(root.WithOpName("rec_read"), reader, queue_);

  Tensor dense_def0(DT_STRING, {1});
  Tensor dense_def1(DT_INT64, {1});
  // ParseSingleExample notes:
  // 1. It parses read_res.value. dense_defaults (dense_def0, dense_def1) must
  //    provide one default per dense feature, in the same order as dense_keys,
  //    and each default must have the same dtype as its feature. Three
  //    features would need three defaults.
  // 2. "img_raw" and "label" are the feature keys used on the Python side.
  // 3. {1}, {1} are the dense shapes: each feature holds a single value.
  // 4. Supported feature dtypes are DT_INT64, DT_FLOAT (float32) and
  //    DT_STRING; DT_STRING corresponds to a bytes value on the Python side.
  auto parse_op = ops::ParseSingleExample(root.WithOpName("parse_op"), {read_res.value}, {dense_def0, dense_def1}, 0, {}, {"img_raw", "label"}, {}, {{1}, {1}});

  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  GraphDef graph_def;
  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
  // A failed Create would make every later Run fail, so stop here.
  TF_ASSERT_OK(session->Create(graph_def));

  QueueRunnerDef queue_runner_def =
      test::BuildQueueRunnerDef("queue", {"enque"}, "close", "", {tensorflow::error::CANCELLED});
  std::unique_ptr<QueueRunner> qr;
  TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
  TF_ASSERT_OK(qr->Start(session.get()));
  // NOTE(review): if tf_/queue_runner behaves like tensorflow::QueueRunner, it
  // also runs "enque" repeatedly on its own thread, so how many filename
  // copies are queued before "close" is nondeterministic. Only the first
  // 100 records are read below, so the result does not depend on it.
  TF_EXPECT_OK(session->Run({}, {}, {"enque"}, nullptr));
  TF_EXPECT_OK(session->Run({}, {}, {"close"}, nullptr));

  // Each Run pulls exactly one record through ReaderRead -> ParseSingleExample.
  const std::vector<std::string> fetches = {parse_op.dense_values[0].name(),
                                            parse_op.dense_values[1].name()};
  for (int i = 0; i < 100; ++i) {
    std::vector<Tensor> outputs_res;
    // Abort on failure: indexing outputs_res after a failed Run is out of bounds.
    TF_ASSERT_OK(session->Run({}, fetches, {}, &outputs_res));
    ASSERT_EQ(outputs_res.size(), 2u);
    std::cout << outputs_res[0].DebugString() << "\n";
    std::cout << outputs_res[1].DebugString() << "\n";
    // The Python script stores the record index as "label".
    auto res = test::GetTensorValue<int64>(outputs_res[1]);
    ASSERT_EQ(i, res[0]);
  }
  TF_EXPECT_OK(qr->Join());
}
程序输出如下:

网友评论