美文网首页
Tensorflow C++ 使用xtensor加载Fashio

TensorFlow C++ 使用 xtensor 加载 Fashion-MNIST 测试集

作者: FredricZhu | 来源:发表于2022-04-01 17:27 被阅读0次

本文主要使用 xtensor 读取 NumPy 生成的 npy 文件,并在 C++ 侧对模型精度进行测试。
Python 侧带检查点(checkpoint)的断点续训代码如下:

import tensorflow as tf
import os
import numpy as np

# Disable NumPy's summarisation threshold so that str(v.numpy()) below writes
# every weight value instead of an abbreviated "..." form.
np.set_printoptions(threshold=np.inf)

# Load Fashion-MNIST and divide the pixel values by 255.0 to normalise them.
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Flatten -> Dense(128, relu) -> Dense(10, softmax) classifier.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])


# The last layer already applies softmax, hence from_logits=False.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from the previous checkpoint when its index file exists.
checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Checkpoint callback: weights only, and only the best result seen so far.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

# Dump every trainable variable (name, shape, values) to a text file.
# The `with` block closes the file even if one of the writes raises.
print(model.trainable_variables)
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')

Python 侧保存测试集(并用前 12 个样本验证预测结果)的代码如下:

from PIL import Image
import numpy as np
import tensorflow as tf

# Load Fashion-MNIST; only the test split is exported below.
mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)

# Add a trailing channel axis, normalise by 255.0 and fix the dtypes that the
# C++ loader requests (float for images, 32-bit int for labels).
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
x_test = x_test / 255.0
x_test = x_test.astype(np.float32)
y_test = y_test.astype(np.int32)

print("Test data length: ", x_test.shape[0])

# Save file to numpy files
np.save("x_batch.npy", x_test)
np.save("y_p.npy", y_test)

# Rebuild the same architecture as the training script and restore its weights.
model_save_path = './checkpoint/fashion.ckpt'
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.load_weights(model_save_path)

# Predict on the first 12 samples only.
y_p = y_test[:12]
x_test_ = x_test[:12]
result = model.predict(x_test_)
pred = tf.argmax(result, axis=1)
# y_p holds the ground-truth labels and pred the model output; the original
# printed them under each other's caption.
print("Actual:", y_p)
print("Predict:", pred)

C++ conanfile.txt

 # Third-party packages resolved by Conan for the C++ test project.
 [requires]
 gtest/1.10.0
 glog/0.4.0
 protobuf/3.9.1
 eigen/3.4.0
 dataframe/1.20.0
 opencv/3.4.17
 boost/1.76.0
 abseil/20210324.0
 # xtensor supplies xtensor/xnpy.hpp (xt::load_npy), used by tf_/tensor_testutil.h.
 xtensor/0.23.10

 # The cmake generator emits conanbuildinfo.cmake, which CMakeLists.txt includes.
 [generators]
 cmake

C++ CMakeLists.txt

# CMake 3.8 is the first release that accepts CMAKE_CXX_STANDARD 17, and the
# IMPORTED_TARGET option used with pkg_search_module below needs 3.6+, so the
# previously declared minimum of 3.3 could not configure this project.
cmake_minimum_required(VERSION 3.8)


project(tf_fashion_predict)

# Let pkg-config also search /usr/local when resolving the Arrow/Parquet modules.
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")

set(CMAKE_CXX_STANDARD 17)
# -g is a compiler option, not a preprocessor definition.
add_compile_options(-g)

# Conan-generated settings (produced by the "cmake" generator in conanfile.txt).
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS
    ${PKG_PARQUET_INCLUDE_DIRS}
    ${PKG_ARROW_INCLUDE_DIRS}
    ${PKG_ARROW_COMPUTE_INCLUDE_DIRS}
    ${PKG_ARROW_CSV_INCLUDE_DIRS}
    ${PKG_ARROW_DATASET_INCLUDE_DIRS}
    ${PKG_ARROW_FS_INCLUDE_DIRS}
    ${PKG_ARROW_JSON_INCLUDE_DIRS})

set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})

set(ARROW_LIBS
    PkgConfig::PKG_PARQUET
    PkgConfig::PKG_ARROW
    PkgConfig::PKG_ARROW_COMPUTE
    PkgConfig::PKG_ARROW_CSV
    PkgConfig::PKG_ARROW_DATASET
    PkgConfig::PKG_ARROW_FS
    PkgConfig::PKG_ARROW_JSON)

include_directories(${INCLUDE_DIRS})


# Every .cpp file next to this CMakeLists.txt becomes its own executable.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)

# Shared helper sources that live under the common include tree.
file(GLOB APP_SOURCES
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})

foreach(test_file ${test_file_list})
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    # Strip only the trailing ".cpp"; a plain REPLACE would also remove a
    # ".cpp" occurring in the middle of the file name.
    string(REGEX REPLACE "\\.cpp$" "" test_name "${filename}")
    add_executable(${test_name} ${test_file})
    target_link_libraries(${test_name} PUBLIC ${PROJECT_NAME}_lib)
endforeach()

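C++ tf_fashion_model_test.cpp

(注意:下面的测试从 ../fashion_model 目录加载 SavedModel,而上文的 Python 代码只保存了检查点权重;运行测试前需要先在 Python 侧把模型导出为 SavedModel,例如调用 model.save,该步骤未包含在本文代码中。)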

#include <fstream>

#include <tensorflow/c/c_api.h>

#include "death_handler/death_handler.h"

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

#include <vector>
#include "tensorflow/core/public/session.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"


using namespace tensorflow;

using BatchDef = std::initializer_list<tensorflow::int64>;

// Test-runner entry point. A Debug::DeathHandler is constructed first
// (presumably to report crashes -- see death_handler.h), then GoogleTest
// parses the command line and the aggregated test result becomes the
// process exit code.
int main(int argc, char** argv) {
    Debug::DeathHandler crash_handler;
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}

// Reads the Fashion test images from the NumPy file ../data/x_batch.npy into
// a float tensor of shape {10000, 28, 28, 1}.
Tensor GetInputTensor() {
    return test::LoadNumPy<float>("../data/x_batch.npy", {10000, 28, 28, 1});
}

std::vector<int> GetOutputBatches() {
    // 从numpy 文件读取 fashion 预测结果
    auto pred_tensor = test::LoadNumPy<int>("../data/y_p.npy", {10000, 1});
    auto pred_v = test::GetTensorValue<int>(pred_tensor);
    return pred_v;
}


// Runs ArgMax along dimension 1 of 'tensor_' in a throw-away graph and
// returns the resulting int32 indices as a flat vector.
//
// The status of ClientSession::Run is checked before the result is touched:
// the original ignored it, so a failed run would index an empty 'outputs'.
std::vector<int> ConvertTensorToIndexValue(Tensor const& tensor_) {
    Scope root = Scope::NewRootScope();
    ClientSession session(root);

    auto dim_ = ops::Const(root, 1);
    auto attrs = ops::ArgMax::OutputType(DT_INT32);
    auto arg_max_op = ops::ArgMax(root, tensor_, dim_, attrs);

    std::vector<Tensor> outputs;
    TF_CHECK_OK(session.Run({arg_max_op.output}, &outputs));
    CHECK(!outputs.empty()) << "ArgMax produced no output tensor";

    return test::GetTensorValue<int>(outputs[0]);
}

// Loads the SavedModel from ../fashion_model, runs it on the tensor returned
// by GetInputTensor(), converts the scores to class indices and reports how
// many of them match the labels returned by GetOutputBatches().
TEST(TfFashionModelTest, LoadAndPredict) {
    SavedModelBundleLite bundle;
    SessionOptions session_options;
    RunOptions run_options;

    const string export_dir = "../fashion_model";
    TF_CHECK_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));

    auto input_tensor = GetInputTensor();

    std::vector<tensorflow::Tensor> out_tensors;
    TF_CHECK_OK(bundle.GetSession()->Run({{"serving_default_flatten_input:0", input_tensor}},
    {"StatefulPartitionedCall:0"}, {}, &out_tensors));
    // Guard the out_tensors[0] access below.
    ASSERT_FALSE(out_tensors.empty());

    std::cout << "Print Index Value\n";
    const auto predict_res = ConvertTensorToIndexValue(out_tensors[0]);
    for (auto ele : predict_res) {
        std::cout << ele << "\n";
    }

    const auto labels = GetOutputBatches();
    // labels[i] is indexed with the prediction index, so both vectors must
    // have the same length; otherwise the loop would read out of bounds.
    ASSERT_EQ(predict_res.size(), labels.size());

    int correct{0};
    for (std::size_t i = 0; i < predict_res.size(); ++i) {
        if (predict_res[i] == labels[i]) {
            ++correct;
        }
    }

    std::cout << "Total correct: " << correct << "\n";
    std::cout << "Total datasets: " << labels.size() << "\n";
    std::cout << "Accuracy is: " << static_cast<float>(correct) / labels.size() << "\n";
}

tf_/tensor_testutil.h

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_

#include <algorithm>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <type_traits>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tf_/queue_runner.h"
#include "xtensor/xnpy.hpp"
#include "xtensor/xarray.hpp"

namespace tensorflow {
namespace test {

// Builds a tensor with an empty (scalar) shape whose single element is 'val'.
// The dtype is derived from T through DataTypeToEnum.
template <typename T>
Tensor AsScalar(const T& val) {
  Tensor scalar_tensor(DataTypeToEnum<T>::value, TensorShape());
  scalar_tensor.scalar<T>()() = val;
  return scalar_tensor;
}

using Code = tensorflow::error::Code;

// Assembles a QueueRunnerDef proto from its parts: the queue name, the
// enqueue op names, the close/cancel op names and the error codes that are
// recorded as "queue closed" exception types.
//
// Declared inline because this is a non-template function defined in a
// header: without it every translation unit that includes the header emits
// its own external definition, which violates the one-definition rule.
inline QueueRunnerDef BuildQueueRunnerDef(
    const std::string& queue_name, const std::vector<std::string>& enqueue_ops,
    const std::string& close_op, const std::string& cancel_op,
    const std::vector<Code>& queue_closed_error_codes) {
  QueueRunnerDef queue_runner_def;
  *queue_runner_def.mutable_queue_name() = queue_name;
  for (const std::string& enqueue_op : enqueue_ops) {
    *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op;
  }
  *queue_runner_def.mutable_close_op_name() = close_op;
  *queue_runner_def.mutable_cancel_op_name() = cancel_op;
  for (const auto& error_code : queue_closed_error_codes) {
    *queue_runner_def.mutable_queue_closed_exception_types()->Add() =
        error_code;
  }
  return queue_runner_def;
}

// Builds a one-dimensional tensor that holds a copy of 'vals'; its length is
// vals.size() and its dtype is derived from T.
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals) {
  const int64 element_count = static_cast<int64>(vals.size());
  Tensor flat_tensor(DataTypeToEnum<T>::value, {element_count});
  T* const destination = flat_tensor.flat<T>().data();
  std::copy(vals.begin(), vals.end(), destination);
  return flat_tensor;
}

// Writes every element of 'tensor' (interpreted as T) to 'os', one value per
// line, and returns 'os'.
//
// uint8 elements are widened to int so they print as numbers rather than as
// characters. The choice is made with `if constexpr`, so the integer cast is
// only instantiated for uint8; the previous runtime typeid test compiled the
// cast for every T and therefore rejected element types that cannot be cast
// to int.
//
// Note: the stream precision is raised to long-double precision and is left
// that way for the caller.
template <typename T>
std::ostream& PrintTensorValue(std::ostream& os, Tensor const& tensor) {
    T const* tensor_pt = tensor.unaligned_flat<T>().data();
    auto size = tensor.NumElements();
    os << std::setprecision(std::numeric_limits<long double>::digits10 + 1);
    for (decltype(size) i = 0; i < size; ++i) {
        if constexpr (std::is_same_v<T, uint8>) {
            os << static_cast<int>(tensor_pt[i]) << "\n";
        } else {
            os << tensor_pt[i] << "\n";
        }
    }
    return os;
}

template <typename T>
std::ostream& PrintTensorValue(std::ostream& os, Tensor const& tensor, int per_line_count) {
   // 打印Tensor值
    T const* tensor_pt = tensor.unaligned_flat<T>().data();
    auto size = tensor.NumElements();
    os << std::setprecision(std::numeric_limits<long double>::digits10 + 1);
    bool is_uint8 = typeid(tensor_pt[0]) == typeid(uint8);
    for(decltype(size) i=0; i<size; ++i) {
        if(i!=0 && (i+1)%per_line_count == 0) {  
          if(is_uint8) {
            os << (int)tensor_pt[i] << "\n";
          }else {
            os << tensor_pt[i] << "\n";
          }
        }else {
          if(is_uint8) {
            os << (int)tensor_pt[i] << "\t";
          }else {
            os << tensor_pt[i] << "\t";
          }
        }
    }
    return os;
}

// Copies all elements of 'tensor' (interpreted as T) into a std::vector<T>,
// in flat order.
//
// The vector is built directly from the element range, which sizes it once
// instead of growing it element by element.
template <typename T>
std::vector<T> GetTensorValue( Tensor const& tensor) {
    T const* first = tensor.unaligned_flat<T>().data();
    return std::vector<T>(first, first + tensor.NumElements());
}

// Adds a reduction of type OpType to scope 's' and returns the three graph
// outputs in this order: the "input" placeholder (dtype 'tf_type', shape
// 'shape'), the int32 "axis" placeholder, and the "my_reduce" op itself.
// 'keep_dims' is forwarded through the op's Attrs.
template <typename OpType>
std::vector<Output> CreateReduceOP(Scope const& s, DataType tf_type, PartialTensorShape const& shape, bool keep_dims) {
  auto data_placeholder = ops::Placeholder(s.WithOpName("input"), tf_type, ops::Placeholder::Shape(shape));
  auto axis_placeholder = ops::Placeholder(s.WithOpName("axis"), DT_INT32);

  typename OpType::Attrs reduce_attrs;
  reduce_attrs.keep_dims_ = keep_dims;
  auto reduce_op = OpType(s.WithOpName("my_reduce"), data_placeholder, axis_placeholder, reduce_attrs);

  std::vector<Output> graph_outputs{};
  graph_outputs.emplace_back(std::move(data_placeholder));
  graph_outputs.emplace_back(std::move(axis_placeholder));
  graph_outputs.emplace_back(std::move(reduce_op));
  return graph_outputs;
}

// Constructs a tensor of "shape" with values "vals".
// The values are first packed into a flat tensor by the single-argument
// AsTensor overload and then given the requested shape via Tensor::CopyFrom;
// CHECK aborts the process when CopyFrom returns false.
// NOTE(review): presumably CopyFrom fails when vals.size() differs from the
// number of elements implied by "shape" -- confirm against Tensor::CopyFrom.
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
  Tensor ret;
  CHECK(ret.CopyFrom(AsTensor(vals), shape));
  return ret;
}


// Loads the .npy file at 'file_path' with xt::load_npy<T> and returns its
// contents as a tensor of the requested 'shape'.
//
// The element count of the file is compared with the element count of
// 'shape' up front, so a mismatch aborts with a message that names the file
// and the shape instead of failing inside AsTensor's bare CHECK.
template <typename T>
Tensor LoadNumPy(char const* file_path, const TensorShape& shape) {
  const auto npy_values = xt::load_npy<T>(file_path);
  CHECK_EQ(static_cast<int64>(npy_values.size()), shape.num_elements())
      << "Element count of '" << file_path
      << "' does not match requested shape " << shape.DebugString();
  return test::AsTensor<T>(npy_values, shape);
}

// Overwrites the contents of '*tensor' with 'vals'. The tensor must already
// hold exactly vals.size() elements of type T (enforced with CHECK_EQ). E.g.,
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillValues<float>(&x, {11, 21, 21, 22});
template <typename T>
void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
  auto flat = tensor->flat<T>();
  CHECK_EQ(flat.size(), vals.size());
  if (flat.size() <= 0) {
    return;  // nothing to copy
  }
  std::copy(vals.begin(), vals.end(), flat.data());
}

// Overwrites the contents of '*tensor' with 'vals', converting each SrcType
// value to T. The tensor must already hold exactly vals.size() elements
// (enforced with CHECK_EQ).
template <typename T, typename SrcType>
void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
  auto flat = tensor->flat<T>();
  CHECK_EQ(flat.size(), vals.size());
  size_t position = 0;
  for (const SrcType& source_value : vals) {
    flat(position) = T(source_value);
    ++position;
  }
}

// Overwrites '*tensor' with the increasing sequence val, val+1, val+2, ...
// in flat element order. E.g.,
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillIota<float>(&x, 1.0);
template <typename T>
void FillIota(Tensor* tensor, const T& val) {
  auto flat = tensor->flat<T>();
  T* const first = flat.data();
  T* const last = first + flat.size();
  std::iota(first, last, val);
}

// Overwrites '*tensor' so that flat element i becomes fn(i), for i = 0, 1, ...
// E.g.,
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillFn<float>(&x, [](int i)->float { return i*i; });
template <typename T>
void FillFn(Tensor* tensor, std::function<T(int)> fn) {
  auto flat = tensor->flat<T>();
  const auto element_count = flat.size();
  for (int index = 0; index < element_count; ++index) {
    flat(index) = fn(index);
  }
}

// Comparison mode for ExpectEqual:
//   kNone    - floating point values must compare exactly equal;
//   kDefault - floating point values may differ slightly (see ExpectEqual).
enum class Tolerance {
  kNone,
  kDefault,
};
// Expects "x" and "y" are tensors of the same type, same shape, and identical
// values (within 4 ULPs for floating point types unless explicitly disabled
// by passing Tolerance::kNone).
void ExpectEqual(const Tensor& x, const Tensor& y,
                 Tolerance t = Tolerance ::kDefault);

// Expects "x" and "y" are tensors of the same (floating point) type,
// same shape and element-wise difference between x and y is no more
// than atol + rtol * abs(x). If atol or rtol is negative, a default derived
// from the data type's epsilon is used instead.
void ExpectClose(const Tensor& x, const Tensor& y, double atol = -1.0,
                 double rtol = -1.0);

// Expects "x" and "y" are tensors of the same type T, same shape, and
// equal values. Consider using ExpectEqual above instead.
// Only x's dtype is compared with T here; the type/shape agreement between
// x and y and the element comparison are delegated to ExpectEqual.
template <typename T>
void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
  EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
  ExpectEqual(x, y);
}

// Expects "x" and "y" are tensors of the same type T, same shape, and
// approximate equal values. Consider using ExpectClose above instead.
// Only x's dtype is compared with T here; the remaining checks are delegated
// to ExpectClose with the given absolute tolerance and a relative tolerance
// of zero.
template <typename T>
void ExpectTensorNear(const Tensor& x, const Tensor& y, double atol) {
  EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
  ExpectClose(x, y, atol, /*rtol=*/0.0);
}

// For tensor_testutil_test only.
// Scalar closeness checks for half, float and double. A negative atol/rtol
// (the default) selects a default tolerance chosen by the implementation.
namespace internal_test {
::testing::AssertionResult IsClose(Eigen::half x, Eigen::half y,
                                   double atol = -1.0, double rtol = -1.0);
::testing::AssertionResult IsClose(float x, float y, double atol = -1.0,
                                   double rtol = -1.0);
::testing::AssertionResult IsClose(double x, double y, double atol = -1.0,
                                   double rtol = -1.0);
}  // namespace internal_test

}  // namespace test
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_

tf_/tensor_testutil.cc

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tf_/tensor_testutil.h"

#include <cmath>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace test {

// Succeeds when both tensors report the same dtype; otherwise returns a
// failure whose message names the two dtypes.
static ::testing::AssertionResult IsSameType(const Tensor& x, const Tensor& y) {
  const bool same_dtype = (x.dtype() == y.dtype());
  if (same_dtype) {
    return ::testing::AssertionSuccess();
  }
  return ::testing::AssertionFailure()
         << "Tensors have different dtypes (" << x.dtype() << " vs "
         << y.dtype() << ")";
}

// Succeeds when x.IsSameSize(y) holds; otherwise returns a failure whose
// message shows the debug strings of both shapes.
static ::testing::AssertionResult IsSameShape(const Tensor& x,
                                              const Tensor& y) {
  const bool same_size = x.IsSameSize(y);
  if (same_size) {
    return ::testing::AssertionSuccess();
  }
  return ::testing::AssertionFailure()
         << "Tensors have different shapes (" << x.shape().DebugString()
         << " vs " << y.shape().DebugString() << ")";
}

// Builds the failure result "<x> not equal to <y>", printing the values with
// numeric_limits<T>::digits10 + 2 digits of precision.
template <typename T>
static ::testing::AssertionResult EqualFailure(const T& x, const T& y) {
  const int print_digits = std::numeric_limits<T>::digits10 + 2;
  ::testing::AssertionResult failure = ::testing::AssertionFailure();
  failure << std::setprecision(print_digits) << x << " not equal to " << y;
  return failure;
}
// Compares two floats. Two NaNs count as equal. With Tolerance::kNone the
// values must compare exactly equal; otherwise gtest's floating point
// comparison helper decides.
static ::testing::AssertionResult IsEqual(float x, float y, Tolerance t) {
  const bool both_nan = Eigen::numext::isnan(x) && Eigen::numext::isnan(y);
  if (both_nan) {
    return ::testing::AssertionSuccess();
  }

  bool values_match = false;
  if (t == Tolerance::kNone) {
    values_match = (x == y);
  } else {
    values_match = static_cast<bool>(
        ::testing::internal::CmpHelperFloatingPointEQ<float>("", "", x, y));
  }
  return values_match ? ::testing::AssertionSuccess() : EqualFailure(x, y);
}
// Compares two doubles. Two NaNs count as equal. With Tolerance::kNone the
// values must compare exactly equal; otherwise gtest's floating point
// comparison helper decides.
static ::testing::AssertionResult IsEqual(double x, double y, Tolerance t) {
  const bool both_nan = Eigen::numext::isnan(x) && Eigen::numext::isnan(y);
  if (both_nan) {
    return ::testing::AssertionSuccess();
  }

  bool values_match = false;
  if (t == Tolerance::kNone) {
    values_match = (x == y);
  } else {
    values_match = static_cast<bool>(
        ::testing::internal::CmpHelperFloatingPointEQ<double>("", "", x, y));
  }
  return values_match ? ::testing::AssertionSuccess() : EqualFailure(x, y);
}
// Compares two Eigen::half values. Two NaNs count as equal, a single NaN is
// always a mismatch. Otherwise the raw 16-bit patterns are mapped to a
// "biased" integer form and either compared for equality (Tolerance::kNone)
// or required to be at most kMaxUlps (4) apart.
static ::testing::AssertionResult IsEqual(Eigen::half x, Eigen::half y,
                                          Tolerance t) {
  // We consider NaNs equal for testing.
  if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
    return ::testing::AssertionSuccess();

  // Below is a reimplementation of CmpHelperFloatingPointEQ<Eigen::half>, which
  // we cannot use because Eigen::half is not default-constructible.

  // Exactly one operand is NaN at this point (both-NaN returned above).
  if (Eigen::numext::isnan(x) || Eigen::numext::isnan(y))
    return EqualFailure(x, y);

  // Maps a sign-and-magnitude bit pattern to a biased representation so that
  // the two values can be compared and subtracted as integers.
  // NOTE(review): both return expressions operate on uint16_t operands, which
  // are presumably promoted to int, so the lambda returns int -- confirm
  // before changing any of the types below.
  auto sign_and_magnitude_to_biased = [](uint16_t sam) {
    const uint16_t kSignBitMask = 0x8000;
    if (kSignBitMask & sam) return ~sam + 1;  // negative number.
    return kSignBitMask | sam;                // positive number.
  };

  auto xb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(x));
  auto yb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(y));
  if (t == Tolerance::kNone) {
    // Exact mode: the biased representations must be identical.
    if (xb == yb) return ::testing::AssertionSuccess();
  } else {
    // Tolerant mode: accept a distance of up to kMaxUlps between the two
    // biased representations.
    auto distance = xb >= yb ? xb - yb : yb - xb;
    const uint16_t kMaxUlps = 4;
    if (distance <= kMaxUlps) return ::testing::AssertionSuccess();
  }
  return EqualFailure(x, y);
}
// Generic comparison for element types without a dedicated overload: uses
// gtest's equality helper. The tolerance argument 't' is not consulted.
template <typename T>
static ::testing::AssertionResult IsEqual(const T& x, const T& y, Tolerance t) {
  const bool values_match =
      static_cast<bool>(::testing::internal::CmpHelperEQ<T>("", "", x, y));
  return values_match ? ::testing::AssertionSuccess() : EqualFailure(x, y);
}
// Compares two complex numbers: the real parts and then the imaginary parts
// must each pass IsEqual with tolerance mode 't'. The imaginary parts are
// only examined when the real parts already matched.
template <typename T>
static ::testing::AssertionResult IsEqual(const std::complex<T>& x,
                                          const std::complex<T>& y,
                                          Tolerance t) {
  const bool real_matches = static_cast<bool>(IsEqual(x.real(), y.real(), t));
  if (!real_matches) {
    return EqualFailure(x, y);
  }
  const bool imag_matches = static_cast<bool>(IsEqual(x.imag(), y.imag(), t));
  if (!imag_matches) {
    return EqualFailure(x, y);
  }
  return ::testing::AssertionSuccess();
}

// Element-wise comparison of two tensors whose elements are read as T, using
// IsEqual with tolerance mode 't'. Each mismatch is reported through
// EXPECT_TRUE together with its flat index; once num_failures reaches
// max_failures (10) the ASSERT_LT aborts the comparison.
// NOTE(review): the loop bound is x.NumElements() only; the public
// ExpectEqual verifies that x and y have the same shape before calling this.
template <typename T>
static void ExpectEqual(const Tensor& x, const Tensor& y,
                        Tolerance t = Tolerance::kDefault) {
  const T* Tx = x.unaligned_flat<T>().data();
  const T* Ty = y.unaligned_flat<T>().data();
  auto size = x.NumElements();
  int max_failures = 10;
  int num_failures = 0;
  for (decltype(size) i = 0; i < size; ++i) {
    // The comma expression bumps num_failures as part of building the failure
    // message; presumably gtest only evaluates that message when the
    // expectation fails -- confirm before restructuring this line.
    EXPECT_TRUE(IsEqual(Tx[i], Ty[i], t)) << "i = " << (++num_failures, i);
    ASSERT_LT(num_failures, max_failures) << "Too many mismatches, giving up.";
  }
}

// Succeeds when x and y are both NaN, compare equal (this also covers equal
// infinities), or differ by at most atol + rtol * |x|. Otherwise returns the
// failure "<x> not close to <y>".
template <typename T>
static ::testing::AssertionResult IsClose(const T& x, const T& y, const T& atol,
                                          const T& rtol) {
  const bool both_nan = Eigen::numext::isnan(x) && Eigen::numext::isnan(y);
  if (both_nan || x == y) {
    return ::testing::AssertionSuccess();
  }
  const auto allowed_difference = atol + rtol * Eigen::numext::abs(x);
  const auto actual_difference = Eigen::numext::abs(x - y);
  if (actual_difference <= allowed_difference) {
    return ::testing::AssertionSuccess();
  }
  return ::testing::AssertionFailure() << x << " not close to " << y;
}

// Complex closeness: the real parts and then the imaginary parts must each
// pass the scalar IsClose with the same tolerances. The imaginary parts are
// only examined when the real parts were already close.
template <typename T>
static ::testing::AssertionResult IsClose(const std::complex<T>& x,
                                          const std::complex<T>& y,
                                          const T& atol, const T& rtol) {
  const bool real_close =
      static_cast<bool>(IsClose(x.real(), y.real(), atol, rtol));
  const bool both_close =
      real_close && static_cast<bool>(IsClose(x.imag(), y.imag(), atol, rtol));
  if (both_close) {
    return ::testing::AssertionSuccess();
  }
  return ::testing::AssertionFailure() << x << " not close to " << y;
}

// Converts a user-supplied tolerance into the real type that belongs to T.
// A negative 'tolerance' selects the default of 5 * epsilon of T; any other
// value is cast to the real type as is. The result is additionally expected
// to be non-negative.
// Return type can be different from T, e.g. float for T=std::complex<float>.
template <typename T>
static auto GetTolerance(double tolerance) {
  using Real = typename Eigen::NumTraits<T>::Real;
  auto default_tol = static_cast<Real>(5.0) * Eigen::NumTraits<T>::epsilon();
  auto result = tolerance < 0.0 ? default_tol : static_cast<Real>(tolerance);
  EXPECT_GE(result, static_cast<Real>(0));
  return result;
}

// Element-wise closeness check of two tensors whose elements are read as T.
// atol/rtol are first converted with GetTolerance<T> (negative values select
// the default). Each mismatch is reported with its index and both values;
// once num_failures reaches max_failures (10) the ASSERT_LT aborts the
// comparison. If the loop completes, a final expectation requires that no
// mismatch was counted.
// NOTE(review): the loop bound is x.NumElements() only; the public
// ExpectClose verifies that x and y have the same shape before calling this.
template <typename T>
static void ExpectClose(const Tensor& x, const Tensor& y, double atol,
                        double rtol) {
  auto typed_atol = GetTolerance<T>(atol);
  auto typed_rtol = GetTolerance<T>(rtol);

  const T* Tx = x.unaligned_flat<T>().data();
  const T* Ty = y.unaligned_flat<T>().data();
  auto size = x.NumElements();
  int max_failures = 10;
  int num_failures = 0;
  for (decltype(size) i = 0; i < size; ++i) {
    // The comma expression bumps num_failures as part of building the failure
    // message; presumably gtest only evaluates that message when the
    // expectation fails -- confirm before restructuring this statement.
    EXPECT_TRUE(IsClose(Tx[i], Ty[i], typed_atol, typed_rtol))
        << "i = " << (++num_failures, i) << " Tx[i] = " << Tx[i]
        << " Ty[i] = " << Ty[i];
    ASSERT_LT(num_failures, max_failures)
        << "Too many mismatches (atol = " << atol << " rtol = " << rtol
        << "), giving up.";
  }
  EXPECT_EQ(num_failures, 0)
      << "Mismatches detected (atol = " << atol << " rtol = " << rtol << ").";
}

// Public entry point: asserts that x and y have the same dtype and shape,
// then dispatches on x.dtype() to the typed ExpectEqual<T> helper.
// The tolerance mode 't' is forwarded only for float, double, complex64,
// complex128, bfloat16 and half; all other supported dtypes call the helper
// without it. An unsupported dtype records a test failure.
void ExpectEqual(const Tensor& x, const Tensor& y, Tolerance t) {
  ASSERT_TRUE(IsSameType(x, y));
  ASSERT_TRUE(IsSameShape(x, y));

  switch (x.dtype()) {
    case DT_FLOAT:
      return ExpectEqual<float>(x, y, t);
    case DT_DOUBLE:
      return ExpectEqual<double>(x, y, t);
    case DT_INT32:
      return ExpectEqual<int32>(x, y);
    case DT_UINT32:
      return ExpectEqual<uint32>(x, y);
    case DT_UINT16:
      return ExpectEqual<uint16>(x, y);
    case DT_UINT8:
      return ExpectEqual<uint8>(x, y);
    case DT_INT16:
      return ExpectEqual<int16>(x, y);
    case DT_INT8:
      return ExpectEqual<int8>(x, y);
    case DT_STRING:
      return ExpectEqual<tstring>(x, y);
    case DT_COMPLEX64:
      return ExpectEqual<complex64>(x, y, t);
    case DT_COMPLEX128:
      return ExpectEqual<complex128>(x, y, t);
    case DT_INT64:
      return ExpectEqual<int64>(x, y);
    case DT_UINT64:
      return ExpectEqual<uint64>(x, y);
    case DT_BOOL:
      return ExpectEqual<bool>(x, y);
    case DT_QINT8:
      return ExpectEqual<qint8>(x, y);
    case DT_QUINT8:
      return ExpectEqual<quint8>(x, y);
    case DT_QINT16:
      return ExpectEqual<qint16>(x, y);
    case DT_QUINT16:
      return ExpectEqual<quint16>(x, y);
    case DT_QINT32:
      return ExpectEqual<qint32>(x, y);
    case DT_BFLOAT16:
      return ExpectEqual<bfloat16>(x, y, t);
    case DT_HALF:
      return ExpectEqual<Eigen::half>(x, y, t);
    default:
      EXPECT_TRUE(false) << "Unsupported type : " << DataTypeString(x.dtype());
  }
}

// Public entry point: asserts that x and y have the same dtype and shape,
// then dispatches on x.dtype() to the typed ExpectClose<T> helper. Only the
// floating point and complex dtypes listed below are supported; any other
// dtype records a test failure.
void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
  ASSERT_TRUE(IsSameType(x, y));
  ASSERT_TRUE(IsSameShape(x, y));

  switch (x.dtype()) {
    case DT_HALF:
      return ExpectClose<Eigen::half>(x, y, atol, rtol);
    case DT_BFLOAT16:
      return ExpectClose<Eigen::bfloat16>(x, y, atol, rtol);
    case DT_FLOAT:
      return ExpectClose<float>(x, y, atol, rtol);
    case DT_DOUBLE:
      return ExpectClose<double>(x, y, atol, rtol);
    case DT_COMPLEX64:
      return ExpectClose<complex64>(x, y, atol, rtol);
    case DT_COMPLEX128:
      return ExpectClose<complex128>(x, y, atol, rtol);
    default:
      EXPECT_TRUE(false) << "Unsupported type : " << DataTypeString(x.dtype());
  }
}

// Test-only wrapper: converts atol/rtol with GetTolerance<Eigen::half> and
// forwards to the scalar IsClose for half values.
::testing::AssertionResult internal_test::IsClose(Eigen::half x, Eigen::half y,
                                                  double atol, double rtol) {
  const auto absolute_tolerance = GetTolerance<Eigen::half>(atol);
  const auto relative_tolerance = GetTolerance<Eigen::half>(rtol);
  return test::IsClose(x, y, absolute_tolerance, relative_tolerance);
}
// Test-only wrapper: converts atol/rtol with GetTolerance<float> and
// forwards to the scalar IsClose for float values.
::testing::AssertionResult internal_test::IsClose(float x, float y, double atol,
                                                  double rtol) {
  const auto absolute_tolerance = GetTolerance<float>(atol);
  const auto relative_tolerance = GetTolerance<float>(rtol);
  return test::IsClose(x, y, absolute_tolerance, relative_tolerance);
}
// Test-only wrapper: converts atol/rtol with GetTolerance<double> and
// forwards to the scalar IsClose for double values.
::testing::AssertionResult internal_test::IsClose(double x, double y,
                                                  double atol, double rtol) {
  const auto absolute_tolerance = GetTolerance<double>(atol);
  const auto relative_tolerance = GetTolerance<double>(rtol);
  return test::IsClose(x, y, absolute_tolerance, relative_tolerance);
}

}  // end namespace test
}  // end namespace tensorflow

程序输出如下:


图片.png

相关文章

网友评论

      本文标题:Tensorflow C++ 使用xtensor加载Fashio

      本文链接:https://www.haomeiwen.com/subject/meiajrtx.html