寻找了一下在 TensorFlow(C++ API)里面直接加载 CSV 数据集的方式,
用起来还挺别扭的。
CMakeLists.txt
# NOTE(review): the features used in this file are newer than 3.3 -- CMAKE_CXX_STANDARD 17
# and the IMPORTED_TARGET option of pkg_search_module() are documented as later additions.
# Raising the minimum also switches policy defaults, so verify the build before changing it.
cmake_minimum_required(VERSION 3.3)
project(test_parse_ops)

# Make pkg-config also look in /usr/local/lib/pkgconfig/ during this configure run.
# Only prepend the existing value when there is one, so the search path never
# starts with an empty entry.
if("$ENV{PKG_CONFIG_PATH}" STREQUAL "")
  set(ENV{PKG_CONFIG_PATH} "/usr/local/lib/pkgconfig/")
else()
  set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")
endif()

set(CMAKE_CXX_STANDARD 17)

# -g is a compiler option, not a preprocessor definition, so it belongs in
# add_compile_options() rather than add_definitions().
add_compile_options(-g)
# Load the build information generated by Conan's cmake generator and apply it.
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)

# Locate Parquet and each Arrow component through pkg-config.
# IMPORTED_TARGET additionally creates a PkgConfig::<prefix> target per module.
pkg_search_module(PKG_PARQUET       REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW         REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV     REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS      REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON    REQUIRED IMPORTED_TARGET arrow-json)

# Imported targets that the library target links against.
set(ARROW_LIBS
  PkgConfig::PKG_PARQUET
  PkgConfig::PKG_ARROW
  PkgConfig::PKG_ARROW_COMPUTE
  PkgConfig::PKG_ARROW_CSV
  PkgConfig::PKG_ARROW_DATASET
  PkgConfig::PKG_ARROW_FS
  PkgConfig::PKG_ARROW_JSON
)

# Header search paths reported by pkg-config for the same modules.
set(ARROW_INCLUDE_DIRS
  ${PKG_PARQUET_INCLUDE_DIRS}
  ${PKG_ARROW_INCLUDE_DIRS}
  ${PKG_ARROW_COMPUTE_INCLUDE_DIRS}
  ${PKG_ARROW_CSV_INCLUDE_DIRS}
  ${PKG_ARROW_DATASET_INCLUDE_DIRS}
  ${PKG_ARROW_FS_INCLUDE_DIRS}
  ${PKG_ARROW_JSON_INCLUDE_DIRS}
)

# The shared project headers live two directories above this one.
set(INCLUDE_DIRS
  ${CMAKE_CURRENT_SOURCE_DIR}/../../include
  ${ARROW_INCLUDE_DIRS}
)
include_directories(${INCLUDE_DIRS})
# Every .cpp file next to this CMakeLists.txt becomes its own test executable.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)

# Helper sources shared by all tests, taken from the impl/ folders under the
# common include tree.
file(GLOB APP_SOURCES
  ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc
  ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp
)

# One shared library carries the helpers; its dependencies are PUBLIC so that
# the test executables linking it inherit them.
add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib
  PUBLIC
    ${CONAN_LIBS}
    TensorflowCC::TensorflowCC
    ${ARROW_LIBS}
)

foreach(test_file ${test_file_list})
  file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
  # Strip only the trailing ".cpp"; a plain string(REPLACE) would also remove
  # ".cpp" occurring in the middle of a file name.
  string(REGEX REPLACE "\\.cpp$" "" test_name "${filename}")
  add_executable(${test_name} ${test_file})
  # Nothing links against an executable, so PRIVATE is sufficient here.
  target_link_libraries(${test_name} PRIVATE ${PROJECT_NAME}_lib)
endforeach()
conanfile.txt
# Packages Conan has to provide, each pinned as name/version.
[requires]
gtest/1.10.0
glog/0.4.0
protobuf/3.9.1
eigen/3.4.0
dataframe/1.20.0
opencv/3.4.17
# Generator whose output (conanbuildinfo.cmake) is included by CMakeLists.txt.
[generators]
cmake
csv文件内容
0,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0
5,5.4,3.9,1.7,0.4,0
6,4.6,3.4,1.4,0.3,0
7,5.0,3.4,1.5,0.2,0
8,4.4,2.9,1.4,0.2,0
9,4.9,3.1,1.5,0.1,0
10,5.4,3.7,1.5,0.2,0
11,4.8,3.4,1.6,0.2,0
12,4.8,3.0,1.4,0.1,0
13,4.3,3.0,1.1,0.1,0
14,5.8,4.0,1.2,0.2,0
15,5.7,4.4,1.5,0.4,0
16,5.4,3.9,1.3,0.4,0
17,5.1,3.5,1.4,0.3,0
18,5.7,3.8,1.7,0.3,0
19,5.1,3.8,1.5,0.3,0
20,5.4,3.4,1.7,0.2,0
21,5.1,3.7,1.5,0.4,0
22,4.6,3.6,1.0,0.2,0
23,5.1,3.3,1.7,0.5,0
24,4.8,3.4,1.9,0.2,0
25,5.0,3.0,1.6,0.2,0
26,5.0,3.4,1.6,0.4,0
27,5.2,3.5,1.5,0.2,0
28,5.2,3.4,1.4,0.2,0
29,4.7,3.2,1.6,0.2,0
30,4.8,3.1,1.6,0.2,0
31,5.4,3.4,1.5,0.4,0
32,5.2,4.1,1.5,0.1,0
33,5.5,4.2,1.4,0.2,0
34,4.9,3.1,1.5,0.2,0
35,5.0,3.2,1.2,0.2,0
36,5.5,3.5,1.3,0.2,0
37,4.9,3.6,1.4,0.1,0
38,4.4,3.0,1.3,0.2,0
39,5.1,3.4,1.5,0.2,0
40,5.0,3.5,1.3,0.3,0
41,4.5,2.3,1.3,0.3,0
42,4.4,3.2,1.3,0.2,0
43,5.0,3.5,1.6,0.6,0
44,5.1,3.8,1.9,0.4,0
45,4.8,3.0,1.4,0.3,0
46,5.1,3.8,1.6,0.2,0
47,4.6,3.2,1.4,0.2,0
48,5.3,3.7,1.5,0.2,0
49,5.0,3.3,1.4,0.2,0
50,7.0,3.2,4.7,1.4,1
51,6.4,3.2,4.5,1.5,1
52,6.9,3.1,4.9,1.5,1
53,5.5,2.3,4.0,1.3,1
54,6.5,2.8,4.6,1.5,1
55,5.7,2.8,4.5,1.3,1
56,6.3,3.3,4.7,1.6,1
57,4.9,2.4,3.3,1.0,1
58,6.6,2.9,4.6,1.3,1
59,5.2,2.7,3.9,1.4,1
60,5.0,2.0,3.5,1.0,1
61,5.9,3.0,4.2,1.5,1
62,6.0,2.2,4.0,1.0,1
63,6.1,2.9,4.7,1.4,1
64,5.6,2.9,3.6,1.3,1
65,6.7,3.1,4.4,1.4,1
66,5.6,3.0,4.5,1.5,1
67,5.8,2.7,4.1,1.0,1
68,6.2,2.2,4.5,1.5,1
69,5.6,2.5,3.9,1.1,1
70,5.9,3.2,4.8,1.8,1
71,6.1,2.8,4.0,1.3,1
72,6.3,2.5,4.9,1.5,1
73,6.1,2.8,4.7,1.2,1
74,6.4,2.9,4.3,1.3,1
75,6.6,3.0,4.4,1.4,1
76,6.8,2.8,4.8,1.4,1
77,6.7,3.0,5.0,1.7,1
78,6.0,2.9,4.5,1.5,1
79,5.7,2.6,3.5,1.0,1
80,5.5,2.4,3.8,1.1,1
81,5.5,2.4,3.7,1.0,1
82,5.8,2.7,3.9,1.2,1
83,6.0,2.7,5.1,1.6,1
84,5.4,3.0,4.5,1.5,1
85,6.0,3.4,4.5,1.6,1
86,6.7,3.1,4.7,1.5,1
87,6.3,2.3,4.4,1.3,1
88,5.6,3.0,4.1,1.3,1
89,5.5,2.5,4.0,1.3,1
90,5.5,2.6,4.4,1.2,1
91,6.1,3.0,4.6,1.4,1
92,5.8,2.6,4.0,1.2,1
93,5.0,2.3,3.3,1.0,1
94,5.6,2.7,4.2,1.3,1
95,5.7,3.0,4.2,1.2,1
96,5.7,2.9,4.2,1.3,1
97,6.2,2.9,4.3,1.3,1
98,5.1,2.5,3.0,1.1,1
99,5.7,2.8,4.1,1.3,1
100,6.3,3.3,6.0,2.5,2
101,5.8,2.7,5.1,1.9,2
102,7.1,3.0,5.9,2.1,2
103,6.3,2.9,5.6,1.8,2
104,6.5,3.0,5.8,2.2,2
105,7.6,3.0,6.6,2.1,2
106,4.9,2.5,4.5,1.7,2
107,7.3,2.9,6.3,1.8,2
108,6.7,2.5,5.8,1.8,2
109,7.2,3.6,6.1,2.5,2
110,6.5,3.2,5.1,2.0,2
111,6.4,2.7,5.3,1.9,2
112,6.8,3.0,5.5,2.1,2
113,5.7,2.5,5.0,2.0,2
114,5.8,2.8,5.1,2.4,2
115,6.4,3.2,5.3,2.3,2
116,6.5,3.0,5.5,1.8,2
117,7.7,3.8,6.7,2.2,2
118,7.7,2.6,6.9,2.3,2
119,6.0,2.2,5.0,1.5,2
120,6.9,3.2,5.7,2.3,2
121,5.6,2.8,4.9,2.0,2
122,7.7,2.8,6.7,2.0,2
123,6.3,2.7,4.9,1.8,2
124,6.7,3.3,5.7,2.1,2
125,7.2,3.2,6.0,1.8,2
126,6.2,2.8,4.8,1.8,2
127,6.1,3.0,4.9,1.8,2
128,6.4,2.8,5.6,2.1,2
129,7.2,3.0,5.8,1.6,2
130,7.4,2.8,6.1,1.9,2
131,7.9,3.8,6.4,2.0,2
132,6.4,2.8,5.6,2.2,2
133,6.3,2.8,5.1,1.5,2
134,6.1,2.6,5.6,1.4,2
135,7.7,3.0,6.1,2.3,2
136,6.3,3.4,5.6,2.4,2
137,6.4,3.1,5.5,1.8,2
138,6.0,3.0,4.8,1.8,2
139,6.9,3.1,5.4,2.1,2
140,6.7,3.1,5.6,2.4,2
141,6.9,3.1,5.1,2.3,2
142,5.8,2.7,5.1,1.9,2
143,6.8,3.2,5.9,2.3,2
144,6.7,3.3,5.7,2.5,2
145,6.7,3.0,5.2,2.3,2
146,6.3,2.5,5.0,1.9,2
147,6.5,3.0,5.2,2.0,2
148,6.2,3.4,5.4,2.3,2
149,5.9,3.0,5.1,1.8,2
tf_decode_csv_test.cpp
#include <string>
#include <vector>
#include <fstream>
#include <glog/logging.h>
#include "tensorflow/core/platform/test.h"
#include "death_handler/death_handler.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
// Brings the TensorFlow names used below (Tensor, Scope, ClientSession, ...) into scope.
using namespace tensorflow;
// List of dimension sizes from which MakeTensor builds a TensorShape.
using BatchDef = std::initializer_list<tensorflow::int64>;
// CSV fixture read by the test; the path is relative, so it is resolved
// against the working directory the test binary is started from.
char const* csv_filepath = "../data/iris.csv";
int main(int argc, char** argv) {
    // Configure glog before initialising it: write log files to the current
    // directory and mirror every message to stderr.
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // Severities INFO, WARNING, ERROR and FATAL correspond to 0, 1, 2 and 3;
    // 0 therefore lets every message through.
    FLAGS_minloglevel = 0;

    // Kept alive for the whole run; presumably installs crash handlers --
    // confirm against death_handler/death_handler.h.
    Debug::DeathHandler dh;

    google::InitGoogleLogging("./logs.log");
    ::testing::InitGoogleTest(&argc, argv);

    // The GoogleTest result is the process exit code.
    return RUN_ALL_TESTS();
}
// Builds a DT_STRING tensor of shape `batch_def` and copies the strings of
// `batch` into it in flat (row-major) order.
//
// batch:     strings to store, one per tensor element.
// batch_def: dimension sizes of the resulting tensor.
//
// Returns the filled tensor. If `batch` holds more strings than the shape has
// elements, the surplus strings are ignored instead of being written past the
// end of the tensor; if it holds fewer, the remaining elements stay empty.
Tensor MakeTensor(std::vector<std::string> const& batch, BatchDef const& batch_def) {
    Tensor t(DT_STRING, TensorShape(batch_def));

    // Create the flat view once rather than on every loop iteration.
    auto flat = t.flat<tensorflow::tstring>();

    // Never write beyond the number of elements the shape provides.
    const tensorflow::int64 capacity = flat.size();
    const tensorflow::int64 wanted = static_cast<tensorflow::int64>(batch.size());
    const tensorflow::int64 count = wanted < capacity ? wanted : capacity;

    for (tensorflow::int64 i = 0; i < count; ++i) {
        flat(i) = tensorflow::tstring(batch[i]);
    }
    return t;
}
// Reads the CSV fixture line by line, feeds all lines to a DecodeCSV op as one
// string tensor and checks the decoded index and first float column.
TEST(TfArrayOpsTests, DecodeCSV) {
    // Decode a CSV file with the DecodeCSV op.
    // https://www.tensorflow.org/versions/r2.6/api_docs/cc/class/tensorflow/ops/decode-c-s-v
    Scope root = Scope::NewRootScope();
    auto input =
        tensorflow::ops::Placeholder(root.WithOpName("input"), DT_STRING);
    // DecodeCSV derives the number of output columns and the dtype of each one
    // from the default values, so they cannot be chosen arbitrarily.
    auto decode_csv_op = ops::DecodeCSV(root, input, {Input(1), Input(2.0f), Input(1.2f), Input(2.0f), Input(1.2f), Input(1)});
    ClientSession session(root);

    // Open read-only: std::fstream defaults to read/write and would fail on a
    // read-only fixture file.
    std::ifstream ifs{csv_filepath};
    ASSERT_TRUE(ifs.is_open()) << "cannot open " << csv_filepath;

    std::vector<std::string> lines;
    for (std::string line; std::getline(ifs, line);) {
        lines.emplace_back(line);
    }

    auto lines_tensor =
        MakeTensor(lines, {static_cast<tensorflow::int64>(lines.size())});

    std::vector<Tensor> outputs;
    // Stop here if the graph fails to run instead of indexing an empty result.
    TF_ASSERT_OK(session.Run({{input, lines_tensor}}, decode_csv_op.output, &outputs));
    // One output tensor per default value passed to DecodeCSV.
    ASSERT_EQ(6u, outputs.size());

    for (auto const& column : outputs) {
        std::cout << column.DebugString() << "\n";
    }

    auto index_vals = test::GetTensorValue<int32>(outputs[0]);
    auto huae_length_vals = test::GetTensorValue<float>(outputs[1]);
    test::PrintTensorValue<float>(std::cout, outputs[1]);

    // Check the sizes first so the element accesses below are known to be valid.
    ASSERT_EQ(150, index_vals.size());
    ASSERT_EQ(150, huae_length_vals.size());
    ASSERT_EQ(0, index_vals[0]);
    // Compare as float with a tolerance rather than against a long double literal.
    ASSERT_FLOAT_EQ(5.1f, huae_length_vals[0]);
}
程序输出如下:
image.png
网友评论