美文网首页
使用ops::ArgMax算子标准化预测输出

使用ops::ArgMax算子标准化预测输出

作者: FredricZhu | 来源:发表于2022-03-28 19:49 被阅读0次

如前面的博客所述,我们已经完成了用Tensorflow的io算子读取CSV数据集,并生成输入张量的过程。但是输出结果是以概率分布的形式输出的,需要用std::max_element函数进行解析,转换成索引。这个过程非标准化过程,难以理解。在这里我们将此过程用ops::ArgMax函数进行标准化。
CMakeLists.txt

cmake_minimum_required(VERSION 3.3)


project(test_iris_predict)

set(CMAKE_CXX_STANDARD 17)
add_definitions(-g)

include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
include_directories(${INCLUDE_DIRS})

find_package(TensorflowCC REQUIRED)

file( GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 

file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC)

foreach( test_file ${test_file_list} )
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    string(REPLACE ".cpp" "" file ${filename})
    add_executable(${file}  ${test_file})
    target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
endforeach( test_file ${test_file_list})

conanfile.txt

 [requires]
 gtest/1.10.0
 glog/0.4.0
 protobuf/3.9.1
 dataframe/1.20.0

 [generators]
 cmake

tf_iris_model_test.cpp

#include <fstream>

#include <tensorflow/c/c_api.h>

#include "death_handler/death_handler.h"

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

#include <vector>
#include "tensorflow/core/public/session.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"


using namespace tensorflow;

using BatchDef = std::initializer_list<tensorflow::int64>;
char const* data_csv = "../data/iris.csv";

int main(int argc, char** argv) {
    Debug::DeathHandler dh;
    ::testing::InitGoogleTest(&argc, argv);
    int ret = RUN_ALL_TESTS();
    return ret;
}

// 获取CSV文件行
std::vector<tstring> GetCSVLines() {
    std::fstream ifs {data_csv};
    std::string line;
    std::vector<tstring> lines;
    while(std::getline(ifs, line)) {
        lines.emplace_back(tstring(line));
    }
    return lines;
}

Tensor GetInputTensor() {
    // 生成iris数据集
    // https://www.tensorflow.org/versions/r2.6/api_docs/cc/class/tensorflow/ops/decode-c-s-v
    Scope root = Scope::NewRootScope();
    ClientSession session(root);
    
    auto lines = GetCSVLines();
    auto input = test::AsTensor<tensorflow::tstring>(lines, {(long)lines.size()});
    // DecodeCSV函数使用Default Value来推算 输出张量的列数 和类型,不能随便填
    // 1. Decode CSV成列张量
    auto decode_csv_op = ops::DecodeCSV(root, input, {Input(1), Input(1.0f), Input(1.0f), Input(1.0f), Input(1.0f), Input(1)});

    // 2. Reshape 成 (150, 1), 便于按行concat
    auto input_1 = ops::Reshape(root, decode_csv_op.output[1], {150, 1});
    auto input_2 = ops::Reshape(root, decode_csv_op.output[2], {150, 1});
    auto input_3 = ops::Reshape(root, decode_csv_op.output[3], {150, 1});
    auto input_4 = ops::Reshape(root, decode_csv_op.output[4], {150, 1});

    // 3. 按行 concat成 (150, 4),用于iris数据集
    auto concat_op = ops::Concat(root, {Input(input_1), Input(input_2), Input(input_3), Input(input_4)}, {1});

    // 4. Client Session Run,出结果
    std::vector<Tensor> outputs_concat {};
    session.Run({concat_op}, &outputs_concat);
    return outputs_concat[0];
}

std::vector<int> GetOutputBatches() {
    Scope root = Scope::NewRootScope();

    auto lines = GetCSVLines();
    auto input = test::AsTensor<tensorflow::tstring>(lines, {(long)lines.size()});
    // DecodeCSV函数使用Default Value来推算 输出张量的列数 和类型,不能随便填
    auto decode_csv_op = ops::DecodeCSV(root, input, {Input(1), Input(1.0f), Input(1.0f), Input(1.0f), Input(1.0f), Input(1)});

    ClientSession session(root);
    std::vector<Tensor> outputs;

    session.Run(decode_csv_op.output, &outputs);
    return test::GetTensorValue<int>(outputs[5]);
}


std::vector<int> ConvertTensorToIndexValue(Tensor const& tensor_) {
    Scope root = Scope::NewRootScope();
    ClientSession session(root);
    std::vector<Tensor> outputs;
    auto dim_ = ops::Const(root, 1);
    auto attrs = ops::ArgMax::OutputType(DT_INT32);
    auto arg_max_op = ops::ArgMax(root, tensor_, dim_, attrs);
    session.Run({arg_max_op.output}, &outputs);
    auto predict_res = test::GetTensorValue<int>(outputs[0]); 
    return predict_res;
}

TEST(TfIrisModelTest, LoadAndPredict) {
    SavedModelBundleLite bundle;
    SessionOptions session_options;
    RunOptions run_options;

    const string export_dir = "../iris_model";
    TF_CHECK_OK(LoadSavedModel(session_options, run_options, export_dir,
                              {kSavedModelTagServe}, &bundle));
    
    auto input_tensor = GetInputTensor();

    std::vector<tensorflow::Tensor> out_tensors;
    TF_CHECK_OK(bundle.GetSession()->Run({{"serving_default_input_1:0", input_tensor}},
    {"StatefulPartitionedCall:0"}, {}, &out_tensors)); 

    std::cout << "Print Tensor Value\n";
    test::PrintTensorValue<float>(std::cout, out_tensors[0], 3);
    std::cout << "\n";

    std::cout << "Print Index Value\n";
    auto predict_res = ConvertTensorToIndexValue(out_tensors[0]);
    for(auto ele: predict_res) {
        std::cout << ele << "\n";
    }

    auto labels = GetOutputBatches();
    int correct {0};
    for(int i=0; i<predict_res.size(); ++i) {
        if(predict_res[i] == labels[i]) {
            ++ correct;
        }
    }
    
    std::cout << "Total correct: " << correct << "\n";
    std::cout << "Total datasets: " << labels.size() << "\n"; 
    std::cout << "Accuracy is: " << (float)(correct)/labels.size() << "\n";
}

iris.csv

1,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0
5,5.4,3.9,1.7,0.4,0
6,4.6,3.4,1.4,0.3,0
7,5.0,3.4,1.5,0.2,0
8,4.4,2.9,1.4,0.2,0
9,4.9,3.1,1.5,0.1,0
10,5.4,3.7,1.5,0.2,0
11,4.8,3.4,1.6,0.2,0
12,4.8,3.0,1.4,0.1,0
13,4.3,3.0,1.1,0.1,0
14,5.8,4.0,1.2,0.2,0
15,5.7,4.4,1.5,0.4,0
16,5.4,3.9,1.3,0.4,0
17,5.1,3.5,1.4,0.3,0
18,5.7,3.8,1.7,0.3,0
19,5.1,3.8,1.5,0.3,0
20,5.4,3.4,1.7,0.2,0
21,5.1,3.7,1.5,0.4,0
22,4.6,3.6,1.0,0.2,0
23,5.1,3.3,1.7,0.5,0
24,4.8,3.4,1.9,0.2,0
25,5.0,3.0,1.6,0.2,0
26,5.0,3.4,1.6,0.4,0
27,5.2,3.5,1.5,0.2,0
28,5.2,3.4,1.4,0.2,0
29,4.7,3.2,1.6,0.2,0
30,4.8,3.1,1.6,0.2,0
31,5.4,3.4,1.5,0.4,0
32,5.2,4.1,1.5,0.1,0
33,5.5,4.2,1.4,0.2,0
34,4.9,3.1,1.5,0.2,0
35,5.0,3.2,1.2,0.2,0
36,5.5,3.5,1.3,0.2,0
37,4.9,3.6,1.4,0.1,0
38,4.4,3.0,1.3,0.2,0
39,5.1,3.4,1.5,0.2,0
40,5.0,3.5,1.3,0.3,0
41,4.5,2.3,1.3,0.3,0
42,4.4,3.2,1.3,0.2,0
43,5.0,3.5,1.6,0.6,0
44,5.1,3.8,1.9,0.4,0
45,4.8,3.0,1.4,0.3,0
46,5.1,3.8,1.6,0.2,0
47,4.6,3.2,1.4,0.2,0
48,5.3,3.7,1.5,0.2,0
49,5.0,3.3,1.4,0.2,0
50,7.0,3.2,4.7,1.4,1
51,6.4,3.2,4.5,1.5,1
52,6.9,3.1,4.9,1.5,1
53,5.5,2.3,4.0,1.3,1
54,6.5,2.8,4.6,1.5,1
55,5.7,2.8,4.5,1.3,1
56,6.3,3.3,4.7,1.6,1
57,4.9,2.4,3.3,1.0,1
58,6.6,2.9,4.6,1.3,1
59,5.2,2.7,3.9,1.4,1
60,5.0,2.0,3.5,1.0,1
61,5.9,3.0,4.2,1.5,1
62,6.0,2.2,4.0,1.0,1
63,6.1,2.9,4.7,1.4,1
64,5.6,2.9,3.6,1.3,1
65,6.7,3.1,4.4,1.4,1
66,5.6,3.0,4.5,1.5,1
67,5.8,2.7,4.1,1.0,1
68,6.2,2.2,4.5,1.5,1
69,5.6,2.5,3.9,1.1,1
70,5.9,3.2,4.8,1.8,1
71,6.1,2.8,4.0,1.3,1
72,6.3,2.5,4.9,1.5,1
73,6.1,2.8,4.7,1.2,1
74,6.4,2.9,4.3,1.3,1
75,6.6,3.0,4.4,1.4,1
76,6.8,2.8,4.8,1.4,1
77,6.7,3.0,5.0,1.7,1
78,6.0,2.9,4.5,1.5,1
79,5.7,2.6,3.5,1.0,1
80,5.5,2.4,3.8,1.1,1
81,5.5,2.4,3.7,1.0,1
82,5.8,2.7,3.9,1.2,1
83,6.0,2.7,5.1,1.6,1
84,5.4,3.0,4.5,1.5,1
85,6.0,3.4,4.5,1.6,1
86,6.7,3.1,4.7,1.5,1
87,6.3,2.3,4.4,1.3,1
88,5.6,3.0,4.1,1.3,1
89,5.5,2.5,4.0,1.3,1
90,5.5,2.6,4.4,1.2,1
91,6.1,3.0,4.6,1.4,1
92,5.8,2.6,4.0,1.2,1
93,5.0,2.3,3.3,1.0,1
94,5.6,2.7,4.2,1.3,1
95,5.7,3.0,4.2,1.2,1
96,5.7,2.9,4.2,1.3,1
97,6.2,2.9,4.3,1.3,1
98,5.1,2.5,3.0,1.1,1
99,5.7,2.8,4.1,1.3,1
100,6.3,3.3,6.0,2.5,2
101,5.8,2.7,5.1,1.9,2
102,7.1,3.0,5.9,2.1,2
103,6.3,2.9,5.6,1.8,2
104,6.5,3.0,5.8,2.2,2
105,7.6,3.0,6.6,2.1,2
106,4.9,2.5,4.5,1.7,2
107,7.3,2.9,6.3,1.8,2
108,6.7,2.5,5.8,1.8,2
109,7.2,3.6,6.1,2.5,2
110,6.5,3.2,5.1,2.0,2
111,6.4,2.7,5.3,1.9,2
112,6.8,3.0,5.5,2.1,2
113,5.7,2.5,5.0,2.0,2
114,5.8,2.8,5.1,2.4,2
115,6.4,3.2,5.3,2.3,2
116,6.5,3.0,5.5,1.8,2
117,7.7,3.8,6.7,2.2,2
118,7.7,2.6,6.9,2.3,2
119,6.0,2.2,5.0,1.5,2
120,6.9,3.2,5.7,2.3,2
121,5.6,2.8,4.9,2.0,2
122,7.7,2.8,6.7,2.0,2
123,6.3,2.7,4.9,1.8,2
124,6.7,3.3,5.7,2.1,2
125,7.2,3.2,6.0,1.8,2
126,6.2,2.8,4.8,1.8,2
127,6.1,3.0,4.9,1.8,2
128,6.4,2.8,5.6,2.1,2
129,7.2,3.0,5.8,1.6,2
130,7.4,2.8,6.1,1.9,2
131,7.9,3.8,6.4,2.0,2
132,6.4,2.8,5.6,2.2,2
133,6.3,2.8,5.1,1.5,2
134,6.1,2.6,5.6,1.4,2
135,7.7,3.0,6.1,2.3,2
136,6.3,3.4,5.6,2.4,2
137,6.4,3.1,5.5,1.8,2
138,6.0,3.0,4.8,1.8,2
139,6.9,3.1,5.4,2.1,2
140,6.7,3.1,5.6,2.4,2
141,6.9,3.1,5.1,2.3,2
142,5.8,2.7,5.1,1.9,2
143,6.8,3.2,5.9,2.3,2
144,6.7,3.3,5.7,2.5,2
145,6.7,3.0,5.2,2.3,2
146,6.3,2.5,5.0,1.9,2
147,6.5,3.0,5.2,2.0,2
148,6.2,3.4,5.4,2.3,2
149,5.9,3.0,5.1,1.8,2

其他部分请参考《使用ops::DecodeCSV算子重写鸢尾花数据集预测》。
程序输出如下,


图片.png

相关文章

  • 使用ops::ArgMax算子标准化预测输出

    如前面的博客所述,我们已经完成了用Tensorflow的io算子读取CSV数据集,并生成输入张量的过程。但是输出结...

  • 函数索引

    Math_Ops.py argmax

  • 使用ops::DecodeCSV算子重写鸢尾花数据集预测

    本文使用ops::DecodeCSV算子重写鸢尾花数据集预测,这样就不需要依赖三方的hmdf::DataFrame...

  • tf.argmax

    文件 math_ops.py 函数定义 参数 返回 案例 输出结果解释首先,argmax返回的是索引值,返回每一行...

  • Tensorflow C++ 使用ops::Equal和ops:

    到此为止,所有标准C++的操作,全部使用ops算子替换完成。本例最重要的函数是GetMatchNum。使用的是mn...

  • printf函数--小问题未解决

    使用语言的过程实际上就是输入---处理---输出,本次主要学习标准化输出printf() 标准化输出、scanf(...

  • 3、评估MNIST

    预测正确的标签 首先让我们找出那些预测正确的标签。tf.argmax 是一个非常有用的函数,它能给出某个tenso...

  • OPS - tcpdump使用

    tcpdump 官网 -> http://www.tcpdump.org 1. 安装步骤 在官网分别下载 Tcpd...

  • 遍历二维数组

    vector& ops 使用for(auto op:ops)进行遍历,而不需要求出二维数组的长度

  • numpy的argmax用法(转)

    原链接: (Python)numpy的argmax用法 一维数组 argmax返回的是最大数的索引.argmax有...

网友评论

      本文标题:使用ops::ArgMax算子标准化预测输出

      本文链接:https://www.haomeiwen.com/subject/txxmjrtx.html