美文网首页
ops::TFRecordReader类使用

ops::TFRecordReader类使用

作者: FredricZhu | 来源:发表于2022-03-24 14:24 被阅读0次

TFRecord 格式借助 protobuf(本文中为 tf.train.Example)序列化数据,因此可以在不同语言之间交换数据,主要用于存储机器学习的特征数据。需要注意的是,ops::TFRecordReader 每次读取只返回一条记录(per-record 读取),要读完整个文件需要多次执行读取操作。
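Python 侧生成 TFRecord 文件的代码如下(注意:脚本写入 /tmp/train.tfrecord,而下文的 C++ 测试读取 /cppwork/_tf/test_parse_ops/data/train.tfrecord,运行测试前需要把文件复制到该位置):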
test_create_tf_record.py

import tensorflow as tf
import numpy as np
import json


# Output path of the generated TFRecord file.
# NOTE(review): tf_record_reader_test.cpp reads
# /cppwork/_tf/test_parse_ops/data/train.tfrecord, so copy the file there
# (or make the two paths agree) before running the C++ test.
tfrecord_filename = '/tmp/train.tfrecord'

# Write 100 serialized tf.train.Example records. The `with` block guarantees
# the writer is flushed and closed even if building or serializing a record
# raises part-way through.
with tf.io.TFRecordWriter(tfrecord_filename) as writer:
    for i in range(100):
        # 30x7 array of random integers in [0, 255]. randint's upper bound is
        # exclusive (hence 256); this replaces the deprecated random_integers.
        img_raw = np.random.randint(0, 256, size=(30, 7))
        # Stored as UTF-8 encoded JSON text in a BytesList feature.
        img_raw = bytes(json.dumps(img_raw.tolist()), "utf-8")
        example = tf.train.Example(features=tf.train.Features(
            feature={
                # Int64List holds integer data: here, the record index.
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
                # BytesList holds raw byte strings.
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
        # Serialize the protobuf and append it to the file as one record.
        writer.write(example.SerializeToString())

程序目录结构如下,


image.png

程序代码如下,
CMakeLists.txt

# pkg_search_module(... IMPORTED_TARGET) requires CMake >= 3.6 and
# CMAKE_CXX_STANDARD 17 requires CMake >= 3.8, so the previous 3.3 minimum
# did not describe what this file actually needs.
cmake_minimum_required(VERSION 3.8)

project(test_parse_ops)

# Also let pkg-config find .pc files installed under /usr/local
# (used by the Parquet/Arrow lookups below).
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")

set(CMAKE_CXX_STANDARD 17)
# Emit debug info. -g is a compiler flag, so it belongs in add_compile_options;
# add_definitions is intended for preprocessor definitions.
add_compile_options(-g)

# Conan-generated build info (provides CONAN_LIBS used below).
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS
    ${PKG_PARQUET_INCLUDE_DIRS}
    ${PKG_ARROW_INCLUDE_DIRS}
    ${PKG_ARROW_COMPUTE_INCLUDE_DIRS}
    ${PKG_ARROW_CSV_INCLUDE_DIRS}
    ${PKG_ARROW_DATASET_INCLUDE_DIRS}
    ${PKG_ARROW_FS_INCLUDE_DIRS}
    ${PKG_ARROW_JSON_INCLUDE_DIRS})

# Repository-level shared include tree (headers plus helper implementations).
set(REPO_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../include)

set(INCLUDE_DIRS ${REPO_INCLUDE_DIR} ${ARROW_INCLUDE_DIRS})

set(ARROW_LIBS
    PkgConfig::PKG_PARQUET
    PkgConfig::PKG_ARROW
    PkgConfig::PKG_ARROW_COMPUTE
    PkgConfig::PKG_ARROW_CSV
    PkgConfig::PKG_ARROW_DATASET
    PkgConfig::PKG_ARROW_FS
    PkgConfig::PKG_ARROW_JSON)

include_directories(${INCLUDE_DIRS})

# Every *.cpp in this directory becomes its own test executable.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)

# Helper sources compiled once into a shared library that all tests link.
file(GLOB APP_SOURCES
    ${REPO_INCLUDE_DIR}/tf_/impl/tensor_testutil.cc
    ${REPO_INCLUDE_DIR}/tf_/impl/queue_runner.cc
    ${REPO_INCLUDE_DIR}/tf_/impl/coordinator.cc
    ${REPO_INCLUDE_DIR}/tf_/impl/status.cc
    ${REPO_INCLUDE_DIR}/death_handler/impl/*.cpp
    ${REPO_INCLUDE_DIR}/df/impl/*.cpp
    ${REPO_INCLUDE_DIR}/arr_/impl/*.cpp
    ${REPO_INCLUDE_DIR}/img_util/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})

foreach(test_file ${test_file_list})
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    # Target name is the file name without its ".cpp" extension.
    string(REPLACE ".cpp" "" target_name ${filename})
    add_executable(${target_name} ${test_file})
    target_link_libraries(${target_name} PUBLIC ${PROJECT_NAME}_lib)
endforeach()

tf_record_reader_test.cpp

#include <string>
#include <vector>
#include <array>
#include <fstream>
#include <iostream>
#include <memory>
#include <glog/logging.h>
#include "tensorflow/core/platform/test.h"
#include "death_handler/death_handler.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
#include "tf_/queue_runner.h"


using namespace tensorflow;


int main(int argc, char** argv) {
    // Configure glog before initializing it: write log files into the
    // current directory and mirror every message to stderr.
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // Minimum severity to log; INFO, WARNING, ERROR, FATAL are 0, 1, 2, 3.
    FLAGS_minloglevel = 0;

    // Project-provided handler from death_handler/death_handler.h; presumably
    // reports crashes (e.g. prints a backtrace) for the scope of main — TODO confirm.
    Debug::DeathHandler dh;

    // InitGoogleLogging takes the program name (conventionally argv[0]),
    // which glog uses to name log files and label log lines; it is not a
    // log-file path, so the former "./logs.log" argument was misleading.
    google::InitGoogleLogging(argv[0]);
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}

TEST(TfArrayOpsTests, FixLenRecordReader) {
    // Reads the TFRecord file produced by test_create_tf_record.py with
    // ops::TFRecordReader, parses each serialized tf.train.Example with
    // ops::ParseSingleExample, and checks that the "label" feature of the
    // i-th record equals i.
    // NOTE(review): despite the "FixLen" name this test exercises
    // TFRecordReader, not FixedLengthRecordReader; the name is kept so
    // existing --gtest_filter patterns keep matching.
    constexpr const char* kRecordPath = "/cppwork/_tf/test_parse_ops/data/train.tfrecord";
    // The Python generator writes 100 records labelled 0..99.
    constexpr int kNumRecords = 100;

    Scope root = Scope::NewRootScope();

    // FIFO queue of file names that the reader consumes.
    auto queue_ = ops::FIFOQueue(root.WithOpName("queue"), {DT_STRING},
                                 ops::FIFOQueue::Capacity(200));

    // 1-D string tensor holding the single file name to enqueue.
    auto tensor_ = ops::Const(root, {kRecordPath});
    auto enque_ = ops::QueueEnqueueMany(root.WithOpName("enque"), queue_, {tensor_});
    auto close_ = ops::QueueClose(root.WithOpName("close"), queue_);

    // The reader takes file names from the queue; each evaluation of
    // ReaderRead yields one serialized record in read_res.value.
    auto reader = ops::TFRecordReader(root);
    auto read_res = ops::ReaderRead(root.WithOpName("rec_read"), reader, queue_);

    // Default values for the dense features: one per dense key, in the same
    // order, each with the same dtype as its feature.
    Tensor dense_def0(DT_STRING, {1});
    Tensor dense_def1(DT_INT64, {1});
    // ParseSingleExample arguments, in order:
    //   serialized     - the record from the reader (read_res.value);
    //   dense_defaults - dense_def0 / dense_def1 (N dense features => N defaults);
    //   num_sparse = 0, sparse_keys = {};
    //   dense_keys     - "img_raw", "label": the feature names used on the Python side;
    //   sparse_types = {};
    //   dense_shapes   - {1}, {1}: the shape of each single feature.
    // Example features can only be DT_INT64 (Int64List), DT_FLOAT (FloatList,
    // float32) or DT_STRING (BytesList, i.e. Python bytes); there is no
    // 64-bit float feature type.
    auto parse_op = ops::ParseSingleExample(root.WithOpName("parse_op"), {read_res.value},
                                            {dense_def0, dense_def1}, 0, {},
                                            {"img_raw", "label"}, {}, {{1}, {1}});

    SessionOptions options;
    std::unique_ptr<Session> session(NewSession(options));
    ASSERT_NE(session.get(), nullptr);
    GraphDef graph_def;
    // ASSERT rather than EXPECT: nothing below is meaningful without a
    // valid graph loaded into the session.
    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    TF_ASSERT_OK(session->Create(graph_def));

    // Queue runner over the "queue" node with "enque" as its enqueue op and
    // "close" as its close op; CANCELLED is listed as the queue-closed error.
    QueueRunnerDef queue_runner_def =
        test::BuildQueueRunnerDef("queue", {"enque"}, "close", "", {tensorflow::error::CANCELLED});
    std::unique_ptr<QueueRunner> qr;
    // ASSERT: qr is dereferenced on the next line.
    TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
    TF_ASSERT_OK(qr->Start(session.get()));
    TF_EXPECT_OK(session->Run({}, {}, {"enque"}, nullptr));
    TF_EXPECT_OK(session->Run({}, {}, {"close"}, nullptr));

    // Each Run of the parse outputs reads and parses exactly one record.
    for (int i = 0; i < kNumRecords; ++i) {
        std::vector<Tensor> outputs_res;
        // Check the status before indexing: on failure outputs_res is empty
        // and outputs_res[0] / outputs_res[1] would be out of bounds.
        TF_ASSERT_OK(session->Run({}, {parse_op.dense_values[0].name(), parse_op.dense_values[1].name()},
                                  {}, &outputs_res));
        ASSERT_EQ(outputs_res.size(), 2u);
        std::cout << outputs_res[0].DebugString() << "\n";
        std::cout << outputs_res[1].DebugString() << "\n";
        auto res = test::GetTensorValue<int64>(outputs_res[1]);
        ASSERT_EQ(i, res[0]);
    }

    TF_EXPECT_OK(qr->Join());
}

程序输出如下,


image.png

相关文章

  • ops::TFRecordReader类使用

    TFRecordReader类内部使用protobuf,可以在不同的语言之间交换数据。主要是用于机器学习的特征数据...

  • OPS - tcpdump使用

    tcpdump 官网 -> http://www.tcpdump.org 1. 安装步骤 在官网分别下载 Tcpd...

  • 遍历二维数组

    vector& ops 使用for(auto op:ops)进行遍历,而不需要求出二维数组的长度

  • 接口目录

    math_ops(一)math_ops函数使用,本篇为算术函数和基本数学函数。1.1 tf.add(x,y) ...

  • OPS规范及各版本对比

    OPS 规范Intel官方主页Open Pluggable Specification(OPS and OPS+)...

  • Tensorflow C++ 使用ops::Equal和ops:

    到此为止,所有标准C++的操作,全部使用ops算子替换完成。本例最重要的函数是GetMatchNum。使用的是mn...

  • 转发

    /*转发 */ onShareAppMessage: function(ops) { if (ops.fro...

  • ISP:接口隔离原则

    在上图应用中,有多个用户需要操作OPS类。现在,我们假设这里的User1只需要使用op1,User2只需要使用op...

  • 项目 ajax封装

    Ajax的基本封装 function ajax(ops){ // 先处理默认属性 ops.type = ops.t...

  • tensorflow math api 汇总

    Defined in tensorflow/python/ops/math_ops.py

网友评论

      本文标题:ops::TFRecordReader类使用

      本文链接:https://www.haomeiwen.com/subject/aenkjrtx.html