美文网首页
使用ops::DecodeCSV算子重写鸢尾花数据集预测

使用ops::DecodeCSV算子重写鸢尾花数据集预测

作者: FredricZhu | 来源:发表于2022-03-22 11:35 被阅读0次

    本文使用 ops::DecodeCSV 算子重写鸢尾花数据集预测，这样就不再需要依赖第三方库 hmdf::DataFrame 了。
    程序结构如下：


    图片.png

    conanfile.txt

     # Conan packages for this project; CMakeLists.txt consumes them via
     # conanbuildinfo.cmake / conan_basic_setup() and links ${CONAN_LIBS}.
     [requires]
     gtest/1.10.0
     glog/0.4.0
     protobuf/3.9.1
     # NOTE(review): presumably the hmdf::DataFrame package. The accompanying text
     # says DataFrame is no longer needed, but CMakeLists.txt still globs
     # include/df/impl/*.cpp -- confirm nothing there uses it before dropping this.
     dataframe/1.20.0
    
     # The cmake generator produces the conanbuildinfo.cmake included by CMakeLists.txt.
     [generators]
     cmake
    

    CMakeLists.txt

    cmake_minimum_required(VERSION 3.3)
    
    
    project(test_iris_predict)
    
    set(CMAKE_CXX_STANDARD 17)
    # Always compile with debug info (-g), independent of the build type.
    add_definitions(-g)
    
    # Conan-generated build info: provides conan_basic_setup() and ${CONAN_LIBS}.
    include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
    conan_basic_setup()
    
    # Shared project headers live two directories above this one.
    set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
    include_directories(${INCLUDE_DIRS})
    
    # Provides the TensorflowCC::TensorflowCC target linked below.
    find_package(TensorflowCC REQUIRED)
    
    # Every .cpp in this directory is a separate test program.
    file( GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 
    
    # Helper sources shared by all test programs (TF tensor test utilities,
    # death handler, df helpers), compiled once into a shared library.
    file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp)
    
    # PUBLIC so that executables linking the helper library also get Conan and TF.
    add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
    target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC)
    
    # One executable per test file, named after the file with ".cpp" removed.
    foreach( test_file ${test_file_list} )
        file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
        string(REPLACE ".cpp" "" file ${filename})
        add_executable(${file}  ${test_file})
        target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
    endforeach( test_file ${test_file_list})
    

    tf_iris_model_test.cpp

    #include <fstream>
    
    #include <tensorflow/c/c_api.h>
    
    #include "death_handler/death_handler.h"
    
    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/client/client_session.h"
    #include "tensorflow/cc/saved_model/constants.h"
    #include "tensorflow/cc/saved_model/loader.h"
    #include "tensorflow/cc/saved_model/signature_constants.h"
    #include "tensorflow/cc/saved_model/tag_constants.h"
    
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/lib/io/path.h"
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/platform/init_main.h"
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/platform/types.h"
    
    #include <vector>
    #include "tensorflow/core/public/session.h"
    #include "tensorflow/cc/ops/const_op.h"
    #include "tf_/tensor_testutil.h"
    #include "tensorflow/core/framework/node_def_util.h"
    #include "tensorflow/core/lib/core/status_test_util.h"
    #include "tensorflow/core/platform/test.h"
    
    
    using namespace tensorflow;
    
    // Alias for a braced list of int64 values, presumably meant for writing tensor
    // shapes. NOTE(review): not referenced by the code shown in this file.
    using BatchDef = std::initializer_list<tensorflow::int64>;
    // Path of the iris CSV read by GetCSVLines(). It is relative, so it is resolved
    // against the working directory the test binary is started from.
    // constexpr: the pointer itself is constant and cannot be reassigned at runtime.
    constexpr char const* data_csv = "../data/iris.csv";
    
    // Entry point of the test binary: creates a Debug::DeathHandler (presumably a
    // crash/stack-trace handler), initializes GoogleTest and runs all tests.
    int main(int argc, char** argv) {
        // Declared first so it stays alive for the whole test run.
        Debug::DeathHandler death_handler;
        ::testing::InitGoogleTest(&argc, argv);
        return RUN_ALL_TESTS();
    }
    
    // 获取CSV文件行
    std::vector<tstring> GetCSVLines() {
        std::fstream ifs {data_csv};
        std::string line;
        std::vector<tstring> lines;
        while(std::getline(ifs, line)) {
            lines.emplace_back(tstring(line));
        }
        return lines;
    }
    
    // Builds the (N, 4) float feature tensor for the iris model, where N is the
    // number of lines returned by GetCSVLines().
    //
    // The DecodeCSV record defaults below imply 6 columns per record: int, four
    // floats, int. Columns 1..4 (the floats) are the features used here.
    // https://www.tensorflow.org/versions/r2.6/api_docs/cc/class/tensorflow/ops/decode-c-s-v
    Tensor GetInputTensor() {
        Scope root = Scope::NewRootScope();
        ClientSession session(root);

        auto lines = GetCSVLines();
        auto input = test::AsTensor<tensorflow::tstring>(
            lines, {static_cast<tensorflow::int64>(lines.size())});

        // DecodeCSV infers the number and dtype of its output columns from the
        // record defaults, so these must match the file layout exactly.
        // 1. Decode the CSV into one 1-D tensor per column.
        auto decode_csv_op = ops::DecodeCSV(
            root, input,
            {Input(1), Input(1.0f), Input(1.0f), Input(1.0f), Input(1.0f), Input(1)});

        // 2. Reshape each feature column to (N, 1) so the columns can be placed
        //    side by side. -1 lets TF infer N instead of hard-coding 150 rows.
        auto input_1 = ops::Reshape(root, decode_csv_op.output[1], {-1, 1});
        auto input_2 = ops::Reshape(root, decode_csv_op.output[2], {-1, 1});
        auto input_3 = ops::Reshape(root, decode_csv_op.output[3], {-1, 1});
        auto input_4 = ops::Reshape(root, decode_csv_op.output[4], {-1, 1});

        // 3. Concatenate along axis 1 into the (N, 4) feature matrix.
        auto concat_op = ops::Concat(
            root, {Input(input_1), Input(input_2), Input(input_3), Input(input_4)}, {1});

        // 4. Run the graph. Fail loudly on graph-construction or runtime errors
        //    instead of indexing into an empty output vector.
        TF_CHECK_OK(root.status());
        std::vector<Tensor> outputs_concat {};
        TF_CHECK_OK(session.Run({concat_op}, &outputs_concat));
        return outputs_concat[0];
    }
    
    std::vector<int> GetOutputBatches() {
        Scope root = Scope::NewRootScope();
    
        auto lines = GetCSVLines();
        auto input = test::AsTensor<tensorflow::tstring>(lines, {(long)lines.size()});
        // DecodeCSV函数使用Default Value来推算 输出张量的列数 和类型,不能随便填
        auto decode_csv_op = ops::DecodeCSV(root, input, {Input(1), Input(1.0f), Input(1.0f), Input(1.0f), Input(1.0f), Input(1)});
    
        ClientSession session(root);
        std::vector<Tensor> outputs;
    
        session.Run(decode_csv_op.output, &outputs);
        return test::GetTensorValue<int>(outputs[5]);
    }
    
    // Converts a (rows, num_classes) float score tensor into the arg-max class
    // index of each row.
    //
    // num_classes is taken from the tensor's last dimension (3 for the iris
    // model) instead of being hard-coded. Trailing values that do not form a
    // complete row are ignored. A tensor with no dimensions or a zero-sized
    // last dimension yields an empty result.
    std::vector<int> ConvertTensorToIndexValue(Tensor const& tensor_) {
        auto const scores = test::GetTensorValue<float>(tensor_);
        std::vector<int> predict_res{};
        if (tensor_.dims() < 1) {
            return predict_res;
        }
        auto const num_classes =
            static_cast<std::size_t>(tensor_.dim_size(tensor_.dims() - 1));
        if (num_classes == 0) {
            return predict_res;  // avoid a zero-step loop
        }

        predict_res.reserve(scores.size() / num_classes);
        for (std::size_t row_begin = 0; row_begin + num_classes <= scores.size();
             row_begin += num_classes) {
            auto const first = scores.begin() + row_begin;
            auto const max_it = std::max_element(first, first + num_classes);
            predict_res.emplace_back(static_cast<int>(max_it - first));
        }
        return predict_res;
    }
    
    
    // Loads the exported iris SavedModel, runs it on the features decoded from
    // the CSV, prints the raw scores and predicted class indices, and reports how
    // many predictions match the CSV labels.
    TEST(TfIrisModelTest, LoadAndPredict) {
        SavedModelBundleLite bundle;
        SessionOptions session_options;
        RunOptions run_options;

        const string export_dir = "../iris_model";
        TF_CHECK_OK(LoadSavedModel(session_options, run_options, export_dir,
                                  {kSavedModelTagServe}, &bundle));

        auto input_tensor = GetInputTensor();

        // Feed/fetch names presumably come from the model's serving_default signature.
        std::vector<tensorflow::Tensor> out_tensors;
        TF_CHECK_OK(bundle.GetSession()->Run({{"serving_default_input_1:0", input_tensor}},
            {"StatefulPartitionedCall:0"}, {}, &out_tensors));
        ASSERT_EQ(out_tensors.size(), 1u);

        std::cout << "Print Tensor Value\n";
        test::PrintTensorValue<float>(std::cout, out_tensors[0], 3);
        std::cout << "\n";

        std::cout << "Print Index Value\n";
        auto predict_res = ConvertTensorToIndexValue(out_tensors[0]);
        for (auto ele : predict_res) {
            std::cout << ele << "\n";
        }

        auto labels = GetOutputBatches();
        // Each prediction needs a label. Otherwise the comparison below would read
        // past the end of `labels`, and an empty set would divide by zero.
        ASSERT_EQ(predict_res.size(), labels.size());
        ASSERT_FALSE(labels.empty());

        int correct {0};
        for (std::size_t i = 0; i < predict_res.size(); ++i) {
            if (predict_res[i] == labels[i]) {
                ++correct;
            }
        }

        std::cout << "Total correct: " << correct << "\n";
        std::cout << "Total datasets: " << labels.size() << "\n";
        std::cout << "Accuracy is: " << static_cast<float>(correct) / labels.size() << "\n";
    }
    

    程序输出如下：


    图片.png

    相关文章

      网友评论

          本文标题:使用ops::DecodeCSV算子重写鸢尾花数据集预测

          本文链接:https://www.haomeiwen.com/subject/sclfjrtx.html