As described in the previous post, we have already walked through reading the CSV dataset with TensorFlow's IO ops and building the input tensor. The model, however, emits its predictions as a probability distribution, which we previously had to parse with std::max_element to recover the class indices. That manual step is nonstandard and hard to follow, so here we replace it with the ops::ArgMax op.
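For reference, the manual decoding being replaced looked roughly like the sketch below; ManualArgMax is a hypothetical helper name for illustration, not code from the previous post. It scans each row of the (N, num_classes) probability tensor with std::max_element and records the offset of the largest entry.
#include <algorithm>
#include <vector>
#include "tensorflow/core/framework/tensor.h"

// Hypothetical helper: per-row argmax over a rank-2 float tensor,
// computed on the host with std::max_element.
std::vector<int> ManualArgMax(tensorflow::Tensor const& probs) {
    auto flat = probs.flat<float>();
    int const rows = static_cast<int>(probs.dim_size(0));
    int const cols = static_cast<int>(probs.dim_size(1));
    std::vector<int> indices;
    indices.reserve(rows);
    for (int i = 0; i < rows; ++i) {
        float const* begin = flat.data() + i * cols;
        float const* end = begin + cols;
        // Offset of the max element within this row == predicted class index.
        indices.push_back(static_cast<int>(std::max_element(begin, end) - begin));
    }
    return indices;
}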
CMakeLists.txt
cmake_minimum_required(VERSION 3.3)
project(test_iris_predict)
set(CMAKE_CXX_STANDARD 17)
add_definitions(-g)
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()
set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
include_directories(${INCLUDE_DIRS})
find_package(TensorflowCC REQUIRED)
file( GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp)
add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC)
foreach(test_file ${test_file_list})
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    string(REPLACE ".cpp" "" file ${filename})
    add_executable(${file} ${test_file})
    target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
endforeach()
conanfile.txt
[requires]
gtest/1.10.0
glog/0.4.0
protobuf/3.9.1
dataframe/1.20.0
[generators]
cmake
tf_iris_model_test.cpp
#include <fstream>
#include <vector>

#include <tensorflow/c/c_api.h>
#include "death_handler/death_handler.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/standard_ops.h"  // DecodeCSV / Reshape / Concat / ArgMax
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tf_/tensor_testutil.h"
using namespace tensorflow;
using BatchDef = std::initializer_list<tensorflow::int64>;
char const* data_csv = "../data/iris.csv";
int main(int argc, char** argv) {
    Debug::DeathHandler dh;
    ::testing::InitGoogleTest(&argc, argv);
    int ret = RUN_ALL_TESTS();
    return ret;
}
// Read the CSV file into a vector of lines.
std::vector<tstring> GetCSVLines() {
    std::fstream ifs {data_csv};
    std::string line;
    std::vector<tstring> lines;
    while (std::getline(ifs, line)) {
        lines.emplace_back(tstring(line));
    }
    return lines;
}
Tensor GetInputTensor() {
    // Build the iris feature tensor from the CSV lines.
    // https://www.tensorflow.org/versions/r2.6/api_docs/cc/class/tensorflow/ops/decode-c-s-v
    Scope root = Scope::NewRootScope();
    ClientSession session(root);
    auto lines = GetCSVLines();
    auto input = test::AsTensor<tensorflow::tstring>(lines, {(long)lines.size()});
    // DecodeCSV infers the number and types of its output columns from the
    // record_defaults argument, so these defaults must match the CSV schema.
    // 1. Decode the CSV into per-column tensors.
    auto decode_csv_op = ops::DecodeCSV(root, input, {Input(1), Input(1.0f), Input(1.0f), Input(1.0f), Input(1.0f), Input(1)});
    // 2. Reshape each feature column to (150, 1) so the columns can be concatenated row-wise.
    auto input_1 = ops::Reshape(root, decode_csv_op.output[1], {150, 1});
    auto input_2 = ops::Reshape(root, decode_csv_op.output[2], {150, 1});
    auto input_3 = ops::Reshape(root, decode_csv_op.output[3], {150, 1});
    auto input_4 = ops::Reshape(root, decode_csv_op.output[4], {150, 1});
    // 3. Concatenate along axis 1 into a (150, 4) tensor: the iris feature matrix.
    auto concat_op = ops::Concat(root, {Input(input_1), Input(input_2), Input(input_3), Input(input_4)}, {1});
    // 4. Run the graph in a ClientSession to materialize the result.
    std::vector<Tensor> outputs_concat {};
    TF_CHECK_OK(session.Run({concat_op}, &outputs_concat));
    return outputs_concat[0];
}
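// (Hypothetical sanity check, not in the original post: verify that the
// pipeline above really yields a (150, 4) feature matrix before feeding it
// to the model. CHECK_EQ comes from the logging header already included.)
void CheckInputTensorShape() {
    auto t = GetInputTensor();
    CHECK_EQ(t.dims(), 2);
    CHECK_EQ(t.dim_size(0), 150);
    CHECK_EQ(t.dim_size(1), 4);
}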
std::vector<int> GetOutputBatches() {
    Scope root = Scope::NewRootScope();
    auto lines = GetCSVLines();
    auto input = test::AsTensor<tensorflow::tstring>(lines, {(long)lines.size()});
    // DecodeCSV infers the number and types of its output columns from the
    // record_defaults argument, so these defaults must match the CSV schema.
    auto decode_csv_op = ops::DecodeCSV(root, input, {Input(1), Input(1.0f), Input(1.0f), Input(1.0f), Input(1.0f), Input(1)});
    ClientSession session(root);
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session.Run(decode_csv_op.output, &outputs));
    // Column 5 holds the ground-truth labels.
    return test::GetTensorValue<int>(outputs[5]);
}
std::vector<int> ConvertTensorToIndexValue(Tensor const& tensor_) {
    Scope root = Scope::NewRootScope();
    ClientSession session(root);
    std::vector<Tensor> outputs;
    // Reduce along axis 1: one index per row of the probability tensor.
    auto dim_ = ops::Const(root, 1);
    // ArgMax returns DT_INT64 by default; request DT_INT32 instead.
    auto attrs = ops::ArgMax::OutputType(DT_INT32);
    auto arg_max_op = ops::ArgMax(root, tensor_, dim_, attrs);
    TF_CHECK_OK(session.Run({arg_max_op.output}, &outputs));
    auto predict_res = test::GetTensorValue<int>(outputs[0]);
    return predict_res;
}
TEST(TfIrisModelTest, LoadAndPredict) {
    SavedModelBundleLite bundle;
    SessionOptions session_options;
    RunOptions run_options;
    const string export_dir = "../iris_model";
    TF_CHECK_OK(LoadSavedModel(session_options, run_options, export_dir,
                               {kSavedModelTagServe}, &bundle));
    auto input_tensor = GetInputTensor();
    std::vector<tensorflow::Tensor> out_tensors;
    TF_CHECK_OK(bundle.GetSession()->Run({{"serving_default_input_1:0", input_tensor}},
                                         {"StatefulPartitionedCall:0"}, {}, &out_tensors));
    std::cout << "Print Tensor Value\n";
    test::PrintTensorValue<float>(std::cout, out_tensors[0], 3);
    std::cout << "\n";
    std::cout << "Print Index Value\n";
    auto predict_res = ConvertTensorToIndexValue(out_tensors[0]);
    for (auto ele : predict_res) {
        std::cout << ele << "\n";
    }
    auto labels = GetOutputBatches();
    int correct {0};
    for (std::size_t i = 0; i < predict_res.size(); ++i) {
        if (predict_res[i] == labels[i]) {
            ++correct;
        }
    }
    std::cout << "Total correct: " << correct << "\n";
    std::cout << "Total datasets: " << labels.size() << "\n";
    std::cout << "Accuracy is: " << (float)correct / labels.size() << "\n";
}
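To isolate the new step: ops::ArgMax takes a tensor and a reduction dimension and returns the index of the largest entry along that dimension. Below is a minimal, self-contained sketch (a hypothetical demo, not part of the test above) that runs ArgMax on a constant 2x3 tensor; row 0 peaks at column 2 and row 1 at column 0, so the result is {2, 0}.
#include <vector>
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"

void ArgMaxDemo() {
    using namespace tensorflow;
    Scope root = Scope::NewRootScope();
    // A fake (2, 3) "probability" tensor.
    auto probs = ops::Const(root, {{0.1f, 0.2f, 0.7f}, {0.8f, 0.1f, 0.1f}});
    // dimension = 1: reduce across columns, producing one index per row.
    // OutputType(DT_INT32) overrides the default DT_INT64 output.
    auto arg_max = ops::ArgMax(root, probs, 1, ops::ArgMax::OutputType(DT_INT32));
    ClientSession session(root);
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session.Run({arg_max.output}, &outputs));
    // outputs[0] now holds the int32 vector {2, 0}.
}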
iris.csv
0,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0
5,5.4,3.9,1.7,0.4,0
6,4.6,3.4,1.4,0.3,0
7,5.0,3.4,1.5,0.2,0
8,4.4,2.9,1.4,0.2,0
9,4.9,3.1,1.5,0.1,0
10,5.4,3.7,1.5,0.2,0
11,4.8,3.4,1.6,0.2,0
12,4.8,3.0,1.4,0.1,0
13,4.3,3.0,1.1,0.1,0
14,5.8,4.0,1.2,0.2,0
15,5.7,4.4,1.5,0.4,0
16,5.4,3.9,1.3,0.4,0
17,5.1,3.5,1.4,0.3,0
18,5.7,3.8,1.7,0.3,0
19,5.1,3.8,1.5,0.3,0
20,5.4,3.4,1.7,0.2,0
21,5.1,3.7,1.5,0.4,0
22,4.6,3.6,1.0,0.2,0
23,5.1,3.3,1.7,0.5,0
24,4.8,3.4,1.9,0.2,0
25,5.0,3.0,1.6,0.2,0
26,5.0,3.4,1.6,0.4,0
27,5.2,3.5,1.5,0.2,0
28,5.2,3.4,1.4,0.2,0
29,4.7,3.2,1.6,0.2,0
30,4.8,3.1,1.6,0.2,0
31,5.4,3.4,1.5,0.4,0
32,5.2,4.1,1.5,0.1,0
33,5.5,4.2,1.4,0.2,0
34,4.9,3.1,1.5,0.2,0
35,5.0,3.2,1.2,0.2,0
36,5.5,3.5,1.3,0.2,0
37,4.9,3.6,1.4,0.1,0
38,4.4,3.0,1.3,0.2,0
39,5.1,3.4,1.5,0.2,0
40,5.0,3.5,1.3,0.3,0
41,4.5,2.3,1.3,0.3,0
42,4.4,3.2,1.3,0.2,0
43,5.0,3.5,1.6,0.6,0
44,5.1,3.8,1.9,0.4,0
45,4.8,3.0,1.4,0.3,0
46,5.1,3.8,1.6,0.2,0
47,4.6,3.2,1.4,0.2,0
48,5.3,3.7,1.5,0.2,0
49,5.0,3.3,1.4,0.2,0
50,7.0,3.2,4.7,1.4,1
51,6.4,3.2,4.5,1.5,1
52,6.9,3.1,4.9,1.5,1
53,5.5,2.3,4.0,1.3,1
54,6.5,2.8,4.6,1.5,1
55,5.7,2.8,4.5,1.3,1
56,6.3,3.3,4.7,1.6,1
57,4.9,2.4,3.3,1.0,1
58,6.6,2.9,4.6,1.3,1
59,5.2,2.7,3.9,1.4,1
60,5.0,2.0,3.5,1.0,1
61,5.9,3.0,4.2,1.5,1
62,6.0,2.2,4.0,1.0,1
63,6.1,2.9,4.7,1.4,1
64,5.6,2.9,3.6,1.3,1
65,6.7,3.1,4.4,1.4,1
66,5.6,3.0,4.5,1.5,1
67,5.8,2.7,4.1,1.0,1
68,6.2,2.2,4.5,1.5,1
69,5.6,2.5,3.9,1.1,1
70,5.9,3.2,4.8,1.8,1
71,6.1,2.8,4.0,1.3,1
72,6.3,2.5,4.9,1.5,1
73,6.1,2.8,4.7,1.2,1
74,6.4,2.9,4.3,1.3,1
75,6.6,3.0,4.4,1.4,1
76,6.8,2.8,4.8,1.4,1
77,6.7,3.0,5.0,1.7,1
78,6.0,2.9,4.5,1.5,1
79,5.7,2.6,3.5,1.0,1
80,5.5,2.4,3.8,1.1,1
81,5.5,2.4,3.7,1.0,1
82,5.8,2.7,3.9,1.2,1
83,6.0,2.7,5.1,1.6,1
84,5.4,3.0,4.5,1.5,1
85,6.0,3.4,4.5,1.6,1
86,6.7,3.1,4.7,1.5,1
87,6.3,2.3,4.4,1.3,1
88,5.6,3.0,4.1,1.3,1
89,5.5,2.5,4.0,1.3,1
90,5.5,2.6,4.4,1.2,1
91,6.1,3.0,4.6,1.4,1
92,5.8,2.6,4.0,1.2,1
93,5.0,2.3,3.3,1.0,1
94,5.6,2.7,4.2,1.3,1
95,5.7,3.0,4.2,1.2,1
96,5.7,2.9,4.2,1.3,1
97,6.2,2.9,4.3,1.3,1
98,5.1,2.5,3.0,1.1,1
99,5.7,2.8,4.1,1.3,1
100,6.3,3.3,6.0,2.5,2
101,5.8,2.7,5.1,1.9,2
102,7.1,3.0,5.9,2.1,2
103,6.3,2.9,5.6,1.8,2
104,6.5,3.0,5.8,2.2,2
105,7.6,3.0,6.6,2.1,2
106,4.9,2.5,4.5,1.7,2
107,7.3,2.9,6.3,1.8,2
108,6.7,2.5,5.8,1.8,2
109,7.2,3.6,6.1,2.5,2
110,6.5,3.2,5.1,2.0,2
111,6.4,2.7,5.3,1.9,2
112,6.8,3.0,5.5,2.1,2
113,5.7,2.5,5.0,2.0,2
114,5.8,2.8,5.1,2.4,2
115,6.4,3.2,5.3,2.3,2
116,6.5,3.0,5.5,1.8,2
117,7.7,3.8,6.7,2.2,2
118,7.7,2.6,6.9,2.3,2
119,6.0,2.2,5.0,1.5,2
120,6.9,3.2,5.7,2.3,2
121,5.6,2.8,4.9,2.0,2
122,7.7,2.8,6.7,2.0,2
123,6.3,2.7,4.9,1.8,2
124,6.7,3.3,5.7,2.1,2
125,7.2,3.2,6.0,1.8,2
126,6.2,2.8,4.8,1.8,2
127,6.1,3.0,4.9,1.8,2
128,6.4,2.8,5.6,2.1,2
129,7.2,3.0,5.8,1.6,2
130,7.4,2.8,6.1,1.9,2
131,7.9,3.8,6.4,2.0,2
132,6.4,2.8,5.6,2.2,2
133,6.3,2.8,5.1,1.5,2
134,6.1,2.6,5.6,1.4,2
135,7.7,3.0,6.1,2.3,2
136,6.3,3.4,5.6,2.4,2
137,6.4,3.1,5.5,1.8,2
138,6.0,3.0,4.8,1.8,2
139,6.9,3.1,5.4,2.1,2
140,6.7,3.1,5.6,2.4,2
141,6.9,3.1,5.1,2.3,2
142,5.8,2.7,5.1,1.9,2
143,6.8,3.2,5.9,2.3,2
144,6.7,3.3,5.7,2.5,2
145,6.7,3.0,5.2,2.3,2
146,6.3,2.5,5.0,1.9,2
147,6.5,3.0,5.2,2.0,2
148,6.2,3.4,5.4,2.3,2
149,5.9,3.0,5.1,1.8,2
For the remaining parts, see the earlier post "Rewriting Iris Dataset Prediction with the ops::DecodeCSV Op".
The program output is as follows:
