美文网首页
tf.FixedLengthRecordReader测试-定长记

tf.FixedLengthRecordReader测试-定长记

作者: FredricZhu | 来源:发表于2022-03-23 14:53 被阅读0次

本文演示 TensorFlow C++ 版本中 FixedLengthRecordReader 的用法。
需要先构建一个 QueueRunner,然后把与队列相关的操作(op)加入到 QueueRunnerDef 里面。
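最后从 Queue 里面逐条读取数据即可。如果想要达到 shuffle 读取数据的目的,可以把 ops::ReaderRead 操作的名称 "fix_read" 加入到 test::BuildQueueRunnerDef 方法的 enqueue_ops 参数里面。(注意:这样做实际上是让 QueueRunner 的后台线程与主线程并发执行读取,主线程能读到哪些记录会变得不确定,并非严格意义上的随机打乱;更常见的做法是把读取结果放入 RandomShuffleQueue。)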
由于 Tensorflow-CC 的 Docker 镜像中有几个 .cpp 文件(例如 queue_runner、tensor_testutil 对应的实现)没有被编译进 .so 库,我把相应的代码单独摘出来自行编译了一下,过程还算简单。

废话不多说,上代码了。


image.png

conanfile.txt

 [requires]
 gtest/1.10.0
 glog/0.4.0
 protobuf/3.9.1
 eigen/3.4.0
 dataframe/1.20.0
 opencv/3.4.17
 boost/1.76.0
 abseil/20210324.0

 [generators]
 cmake

tf_fx_length_record_reader_test.cpp

#include <string>
#include <vector>
#include <array>
#include <fstream>
#include <glog/logging.h>
#include "tensorflow/core/platform/test.h"
#include "death_handler/death_handler.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
#include "tf_/queue_runner.h"


using namespace tensorflow;


int main(int argc, char** argv) {
    // glog configuration: log files go to the current directory and are also
    // mirrored to stderr.
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // Minimum severity to log: INFO, WARNING, ERROR, FATAL are 0, 1, 2, 3.
    FLAGS_minloglevel = 0;

    // NOTE(review): presumably installs a crash handler that prints a stack
    // trace (death_handler library) - TODO confirm.
    Debug::DeathHandler dh;

    // InitGoogleLogging() takes the program name (argv[0]), not a log file path;
    // glog derives log file names from it, while the directory comes from
    // FLAGS_log_dir above. Fall back to a fixed name if argv[0] is unavailable.
    google::InitGoogleLogging(argc > 0 && argv[0] != nullptr
                                  ? argv[0]
                                  : "tf_fx_length_record_reader_test");
    ::testing::InitGoogleTest(&argc, argv);
    const int ret = RUN_ALL_TESTS();
    // Flush and release glog resources before exiting.
    google::ShutdownGoogleLogging();
    return ret;
}

TEST(TfArrayOpsTests, FixLenRecordReader) {
    // Reads 1.txt and 2.txt (relative to the working directory) through a
    // FixedLengthRecordReader that splits the input into 4-byte records.
    // API reference:
    // https://www.tensorflow.org/versions/r2.6/api_docs/cc/class/tensorflow/ops/fixed-length-record-reader#classtensorflow_1_1ops_1_1_fixed_length_record_reader_1aa6ad72f08d89016af3043f72912d11eb
    Scope root = Scope::NewRootScope();

    // Filename queue feeding the reader: "enque" pushes both file names,
    // "close" closes the queue.
    auto attrs = ops::FIFOQueue::Capacity(100);
    auto queue_ = ops::FIFOQueue(root.WithOpName("queue"), {DT_STRING}, attrs);

    auto tensor_ = ops::Const(root, {"1.txt", "2.txt"});
    auto enque_ = ops::QueueEnqueueMany(root.WithOpName("enque"), queue_, {tensor_});
    auto close_ = ops::QueueClose(root.WithOpName("close"), queue_);

    // Record length 4: every run of "fix_read" yields one (key, value) pair.
    auto reader = ops::FixedLengthRecordReader(root, 4);
    auto read_res = ops::ReaderRead(root.WithOpName("fix_read"), reader, queue_);

    SessionOptions options;
    std::unique_ptr<Session> session(NewSession(options));
    ASSERT_TRUE(session != nullptr);
    GraphDef graph_def;
    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    // Check Create(): if it fails, every later Run() fails with a less obvious error.
    TF_ASSERT_OK(session->Create(graph_def));

    // Runner that keeps executing "enque" in the background; CANCELLED is the
    // error code treated as "queue closed".
    QueueRunnerDef queue_runner_def =
      test::BuildQueueRunnerDef("queue", {"enque"}, "close", "", {tensorflow::error::CANCELLED});
    std::unique_ptr<QueueRunner> qr;
    // ASSERT: a failed New() leaves qr null, which Start() would dereference.
    TF_ASSERT_OK(QueueRunner::New(queue_runner_def, &qr));
    TF_ASSERT_OK(qr->Start(session.get()));

    TF_EXPECT_OK(session->Run({}, {}, {"enque"}, nullptr));

    // 1.txt ("abcdefghijkl") followed by 2.txt ("ABCDEFGHIJKL"), in 4-byte records.
    const std::array<std::string, 6> expected_records {"abcd", "efgh", "ijkl",
                                                       "ABCD", "EFGH", "IJKL"};

    for (std::size_t i = 0; i < expected_records.size(); ++i) {
        std::vector<Tensor> outputs;
        const Status read_status = session->Run(
            {}, {read_res.key.name(), read_res.value.name()}, {}, &outputs);
        // A failed read must fail the test instead of being silently skipped.
        // Stop reading, but fall through so the queue is still closed and the
        // runner joined below.
        if (!read_status.ok() || outputs.size() != 2) {
            ADD_FAILURE() << "read #" << i << " failed: " << read_status.ToString()
                          << " (got " << outputs.size() << " outputs)";
            break;
        }
        for (const auto& output : outputs) {
            std::cout << output.DebugString() << "\n";
        }
        // outputs[0] is the record key fetched from read_res.key; outputs[1] is
        // the record value fetched from read_res.value.
        test::ExpectTensorEqual<tstring>(
            outputs[1], test::AsTensor<tstring>({expected_records[i]}, {}));
    }

    TF_EXPECT_OK(session->Run({}, {}, {"close"}, nullptr));
    TF_EXPECT_OK(qr->Join());
}

tf_/tensor_testutil.h

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_

#include <numeric>
#include <limits>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tf_/queue_runner.h"
#include <iostream>

namespace tensorflow {
namespace test {

// Constructs a rank-0 (scalar) tensor whose single element is a copy of 'val';
// the dtype is derived from T via DataTypeToEnum.
template <typename T>
Tensor AsScalar(const T& val) {
  Tensor scalar_tensor(DataTypeToEnum<T>::value, TensorShape());
  auto slot = scalar_tensor.scalar<T>();
  slot() = val;
  return scalar_tensor;
}

using Code = tensorflow::error::Code;

// Builds a QueueRunnerDef proto from its parts.
//   queue_name:               name of the queue node the runner manages.
//   enqueue_ops:              op names the runner executes repeatedly.
//   close_op / cancel_op:     op names used to close / cancel the queue
//                             (an empty string leaves the field empty).
//   queue_closed_error_codes: error codes to be treated as "queue closed".
//
// Declared `inline` because this non-template function is defined in a header:
// without it, every translation unit including the header emits its own
// definition, which is a multiple-definition (ODR) link error.
inline QueueRunnerDef BuildQueueRunnerDef(
    const std::string& queue_name, const std::vector<std::string>& enqueue_ops,
    const std::string& close_op, const std::string& cancel_op,
    const std::vector<Code>& queue_closed_error_codes) {
  QueueRunnerDef queue_runner_def;
  queue_runner_def.set_queue_name(queue_name);
  for (const std::string& enqueue_op : enqueue_ops) {
    queue_runner_def.add_enqueue_op_name(enqueue_op);
  }
  queue_runner_def.set_close_op_name(close_op);
  queue_runner_def.set_cancel_op_name(cancel_op);
  for (const Code error_code : queue_closed_error_codes) {
    queue_runner_def.add_queue_closed_exception_types(error_code);
  }
  return queue_runner_def;
}

// Constructs a rank-1 tensor holding a copy of 'vals', in order; the dtype is
// derived from T via DataTypeToEnum.
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals) {
  const int64 element_count = static_cast<int64>(vals.size());
  Tensor flat_tensor(DataTypeToEnum<T>::value, {element_count});
  std::copy(vals.begin(), vals.end(), flat_tensor.flat<T>().data());
  return flat_tensor;
}

// Writes every element of 'tensor' to 'os' in flat (row-major) order, one
// element per line, and returns 'os'.
// T must be the C++ type that matches tensor.dtype().
// Values are printed with numeric_limits<long double>::digits10 + 1 significant
// digits. The caller's stream precision is restored afterwards, so the
// stream is not left modified.
template <typename T>
std::ostream& PrintTensorValue(std::ostream& os, Tensor const& tensor) {
  T const* tensor_pt = tensor.unaligned_flat<T>().data();
  auto size = tensor.NumElements();
  // precision(n) sets the new precision and returns the previous one.
  const std::streamsize saved_precision =
      os.precision(std::numeric_limits<long double>::digits10 + 1);
  for (decltype(size) i = 0; i < size; ++i) {
    os << tensor_pt[i] << "\n";
  }
  os.precision(saved_precision);
  return os;
}

template <typename T>
std::ostream& PrintTensorValue(std::ostream& os, Tensor const& tensor, int per_line_count) {
   // 打印Tensor值
    T const* tensor_pt = tensor.unaligned_flat<T>().data();
    auto size = tensor.NumElements();
    os << std::setprecision(std::numeric_limits<long double>::digits10 + 1);
    for(decltype(size) i=0; i<size; ++i) {
        if(i!=0 && (i+1)%per_line_count == 0) {
          os << tensor_pt[i] << "\n";
        }else {
          os << tensor_pt[i] << "\t";
        }
    }
    return os;
}

// Returns a copy of all elements of 'tensor', in flat (row-major) order.
// T must be the C++ type that matches tensor.dtype().
template <typename T>
std::vector<T> GetTensorValue(Tensor const& tensor) {
  const auto flat_view = tensor.unaligned_flat<T>();
  T const* first = flat_view.data();
  T const* last = first + tensor.NumElements();
  return std::vector<T>(first, last);
}

// Builds a small reduction graph in scope 's':
//   "input"     - placeholder of dtype 'tf_type' and shape 'shape' (data to reduce),
//   "axis"      - DT_INT32 placeholder holding the reduction axes,
//   "my_reduce" - OpType(input, axis) with keep_dims set to 'keep_dims'.
// Returns the three outputs in that order: {input, axis, reduce result}.
// OpType is expected to be a reduction op wrapper whose Attrs has a
// keep_dims_ field (e.g. ops::Sum).
template <typename OpType>
std::vector<Output> CreateReduceOP(Scope const& s, DataType tf_type, PartialTensorShape const& shape, bool keep_dims) {
  const auto data_ph = ops::Placeholder(s.WithOpName("input"), tf_type,
                                        ops::Placeholder::Shape(shape));
  const auto axis_ph = ops::Placeholder(s.WithOpName("axis"), DT_INT32);

  typename OpType::Attrs reduce_attrs;
  reduce_attrs.keep_dims_ = keep_dims;
  const auto reduced = OpType(s.WithOpName("my_reduce"), data_ph, axis_ph, reduce_attrs);

  return {data_ph, axis_ph, reduced};
}

// Constructs a tensor of "shape" with values "vals".
// The values are first packed into a flat 1-D tensor by AsTensor(vals) and
// then copied into 'ret' with shape 'shape' via Tensor::CopyFrom. If CopyFrom
// returns false, the CHECK aborts; per TensorFlow's CopyFrom contract this
// presumably happens when shape's element count differs from vals.size().
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
  Tensor ret;
  CHECK(ret.CopyFrom(AsTensor(vals), shape));
  return ret;
}

// Fills in '*tensor' with 'vals'. E.g.,
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillValues<float>(&x, {11, 21, 21, 22});
// The tensor must already be allocated with exactly vals.size() elements of
// type T; a size mismatch aborts via CHECK_EQ.
template <typename T>
void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
  auto flat = tensor->flat<T>();
  CHECK_EQ(flat.size(), vals.size());
  // Nothing to write into an empty tensor.
  if (flat.size() <= 0) return;
  std::copy(vals.begin(), vals.end(), flat.data());
}

// Fills in '*tensor' with 'vals', converting the types as needed.
// Each source value is converted with T(value). The tensor must already hold
// exactly vals.size() elements of type T; a size mismatch aborts via CHECK_EQ.
template <typename T, typename SrcType>
void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
  auto flat = tensor->flat<T>();
  CHECK_EQ(flat.size(), vals.size());
  if (flat.size() <= 0) return;
  size_t pos = 0;
  for (const SrcType& src : vals) {
    flat(pos++) = T(src);
  }
}

// Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillIota<float>(&x, 1.0);
// Elements are written in flat (row-major) order.
template <typename T>
void FillIota(Tensor* tensor, const T& val) {
  auto flat = tensor->flat<T>();
  T* const first = flat.data();
  T* const last = first + flat.size();
  std::iota(first, last, val);
}

// Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillFn<float>(&x, [](int i)->float { return i*i; });
// fn receives the flat (row-major) element index.
template <typename T>
void FillFn(Tensor* tensor, std::function<T(int)> fn) {
  auto flat = tensor->flat<T>();
  const auto count = flat.size();
  for (int idx = 0; idx < count; ++idx) {
    flat(idx) = fn(idx);
  }
}

// Tolerance mode for ExpectEqual. kDefault is the documented behavior: values
// may differ by up to 4 ULPs for floating point types. kNone presumably is the
// "explicitly disabled" case, meaning exact comparison. TODO confirm against
// the definition, which is not in this header.
enum class Tolerance {
  kNone,
  kDefault,
};
// Expects "x" and "y" are tensors of the same type, same shape, and identical
// values (within 4 ULPs for floating point types unless explicitly disabled).
// Declaration only; the definition presumably comes from a separately
// compiled tensor_testutil source file.
void ExpectEqual(const Tensor& x, const Tensor& y,
                 Tolerance t = Tolerance ::kDefault);

// Expects "x" and "y" are tensors of the same (floating point) type,
// same shape and element-wise difference between x and y is no more
// than atol + rtol * abs(x). If atol or rtol is negative, the data type's
// epsilon * kSlackFactor is used.
// The default arguments (-1.0) therefore select those dtype-based tolerances.
// Declaration only; the definition lives outside this header.
void ExpectClose(const Tensor& x, const Tensor& y, double atol = -1.0,
                 double rtol = -1.0);

// Expects "x" and "y" are tensors of the same type T, same shape, and
// equal values. Consider using ExpectEqual above instead.
// First checks (non-fatally) that x's dtype is the dtype mapped from T, then
// delegates to ExpectEqual with its default tolerance.
template <typename T>
void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
  EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
  ExpectEqual(x, y);
}

// Expects "x" and "y" are tensors of the same type T, same shape, and
// approximate equal values. Consider using ExpectClose above instead.
// First checks (non-fatally) that x's dtype is the dtype mapped from T, then
// delegates to ExpectClose with rtol = 0. The allowed element-wise difference
// is therefore the absolute tolerance 'atol' alone.
template <typename T>
void ExpectTensorNear(const Tensor& x, const Tensor& y, double atol) {
  EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
  ExpectClose(x, y, atol, /*rtol=*/0.0);
}

// For tensor_testutil_test only.
// Scalar closeness predicates, declared here and defined elsewhere. The -1.0
// defaults presumably follow ExpectClose's convention: a negative tolerance
// selects a dtype-based default. TODO confirm against the definition.
namespace internal_test {
::testing::AssertionResult IsClose(Eigen::half x, Eigen::half y,
                                   double atol = -1.0, double rtol = -1.0);
::testing::AssertionResult IsClose(float x, float y, double atol = -1.0,
                                   double rtol = -1.0);
::testing::AssertionResult IsClose(double x, double y, double atol = -1.0,
                                   double rtol = -1.0);
}  // namespace internal_test

}  // namespace test
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_

程序输出如下:


image.png

build/1.txt

abcdefghijkl

build/2.txt

ABCDEFGHIJKL

相关文章

  • tf.FixedLengthRecordReader测试-定长记

    本文演示Tensorflow C++版本FixedLengthRecordReader的用法。需要先构建一个Que...

  • python 批量生成随机字符串的hash值

    python 批量生成随机字符串的hash值 需求 由于测试需要,需产生大量SHA1序列,通过生成随机定长序列,然...

  • 8-28

    测试监控程序 无事可记

  • mysql 优化1

    1.表的优化:定长和不定长分离 Int 4个字节,char4个字节,定长,time也是定长 核心且常用的字段,应该...

  • LevelDB varint 变长编码

    定义在coding.h 文件中。 固定长度编码 注意到有注释说指令优化,写了简单代码,测试了下,的确如此。 变长编...

  • Mysql数据库类型

    一、列数据类型: 1、字符类型:char(固定长度)、varchar(不固定长度)、binary(固定长度)、va...

  • mysql 字段

    String 及文本 char定长,最大只有256bytes。存入内容长度大于指定长度,严格模式报错,否则截取定长...

  • Scanner的操作

    Scanner的操作 知道固定长度 不知道固定长度

  • 【mysql】char 和 varchar的区别

    char 定长,最大可存储255个字节,指定长度,不满足使用空格填充。varchar 不定长,最大可存储65525...

  • Scala基础——数组

    定长数组 数组一般包括定长数组和变长数组,在Scala中使用Array进行声明定长数组注意:scalad的索引标示...

网友评论

      本文标题:tf.FixedLengthRecordReader测试-定长记

      本文链接:https://www.haomeiwen.com/subject/kzbcjrtx.html