Apache Arrow Advanced Usage Collection

Author: FredricZhu | Published 2022-01-13 12:07

A few days ago I worked through the Apache Arrow documentation and part of the Arrow library's unit tests, and then explored the Arrow-Dataset and Arrow-ExecEngine APIs on top of Apache Arrow 5.0.0.
Arrow-ExecEngine is an experimental feature that resembles chaining Spark RDD operators, e.g.
rdd.map().filter().join().
As of 5.0.0 it only supports simple filtering and single-/multi-column computation; join is not supported.
Arrow-ExecEngine is quite new and is only exposed through the C++ binding; the Python binding does not have it yet. A minimal sketch of the chaining style is shown below.
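For orientation, here is what that chained style looks like with the Dataset scanner API used later in this post. This is a sketch under assumptions, not code from the original project: it presumes a table with int64 columns "int1" and "int2" (the shape of the comp_gt.csv data below), and FilterThenAdd is a hypothetical name.

#include <arrow/api.h>
#include <arrow/compute/api.h>
#include <arrow/compute/exec/expression.h>
#include <arrow/dataset/api.h>
#include <memory>

// Sketch: filter rows where int1 > 3, then project int1 + int2 into "sum".
arrow::Result<std::shared_ptr<arrow::Table>> FilterThenAdd(
    std::shared_ptr<arrow::Table> tb) {
    namespace cp = arrow::compute;
    auto ds = std::make_shared<arrow::dataset::InMemoryDataset>(tb);
    ARROW_ASSIGN_OR_RAISE(auto builder, ds->NewScan());
    // The Spark-like chaining: filter(...) then project(...).
    ARROW_RETURN_NOT_OK(builder->Filter(
        cp::greater(cp::field_ref("int1"), cp::literal(3))));
    ARROW_RETURN_NOT_OK(builder->Project(
        {cp::call("add", {cp::field_ref("int1"), cp::field_ref("int2")})},
        {"sum"}));
    ARROW_ASSIGN_OR_RAISE(auto scanner, builder->Finish());
    return scanner->ToTable();
}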
To experiment with Arrow-Dataset and Arrow-ExecEngine, configure the arrow port in your vcpkg manifest as follows:

{
  "name": "arrow",
  "version>=": "5.0.0",
  "default-features": false,
  "features": [
    "csv",
    "filesystem",
    "json",
    "parquet",
    "dataset",
    "flight"
  ]
}
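This entry belongs in the "dependencies" array of the project's vcpkg.json; in manifest mode, running vcpkg install resolves it automatically. The dataset feature is what provides the arrow_dataset library linked in the CMakeLists.txt below.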

The project directory structure is as follows:

[image: project directory layout]

The code is listed below.
CMakeLists.txt

cmake_minimum_required(VERSION 3.1)

if(APPLE)
    message(STATUS "This is Apple, do nothing.")
    set(CMAKE_MACOSX_RPATH 1)
    set(CMAKE_PREFIX_PATH /Users/aabjfzhu/software/vcpkg/ports/cppwork/vcpkg_installed/x64-osx/share )
elseif(UNIX)
    message(STATUS "This is linux, set CMAKE_PREFIX_PATH.")
    set(CMAKE_PREFIX_PATH /vcpkg/ports/cppwork/vcpkg_installed/x64-linux/share)
endif(APPLE)

set(Boost_NO_WARN_NEW_VERSIONS 1)

project(arrow_test)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_definitions(-g)

find_package(ZLIB)

find_package(glog REQUIRED)

find_package(re2 REQUIRED)

find_package(OpenCV REQUIRED )

find_package(OpenSSL REQUIRED)

find_package(Arrow CONFIG REQUIRED)

find_package(unofficial-brotli REQUIRED)

find_package(unofficial-utf8proc CONFIG REQUIRED)
find_package(Thrift CONFIG REQUIRED)

find_package(Boost REQUIRED COMPONENTS
    system
    filesystem
    serialization
    program_options
    thread
    )

find_package(DataFrame REQUIRED)

if(APPLE)
    MESSAGE(STATUS "This is APPLE, set INCLUDE_DIRS")
set(INCLUDE_DIRS ${Boost_INCLUDE_DIRS} /usr/local/include /usr/local/iODBC/include /opt/snowflake/snowflakeodbc/include/ ${CMAKE_CURRENT_SOURCE_DIR}/../include/ ${CMAKE_CURRENT_SOURCE_DIR}/../../../include)
    set(ARROW_INCLUDE_DIR /Users/aabjfzhu/software/vcpkg/ports/cppwork/vcpkg_installed/x64-osx/include)
elseif(UNIX)
    MESSAGE(STATUS "This is linux, set INCLUDE_DIRS")
    set(INCLUDE_DIRS ${Boost_INCLUDE_DIRS} /usr/local/include ${CMAKE_CURRENT_SOURCE_DIR}/../include/   ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/ /vcpkg/ports/cppwork/vcpkg_installed/x64-linux/include)
    set(ARROW_INCLUDE_DIR /vcpkg/ports/cppwork/vcpkg_installed/x64-linux/include)
endif(APPLE)


if(APPLE)
    MESSAGE(STATUS "This is APPLE, set LINK_DIRS")
    set(LINK_DIRS /usr/local/lib /usr/local/iODBC/lib /opt/snowflake/snowflakeodbc/lib/universal /Users/aabjfzhu/software/vcpkg/ports/cppwork/vcpkg_installed/x64-osx/lib)
elseif(UNIX)
    MESSAGE(STATUS "This is linux, set LINK_DIRS")
    set(LINK_DIRS /usr/local/lib /vcpkg/ports/cppwork/vcpkg_installed/x64-linux/lib)
endif(APPLE)

if(APPLE)
    MESSAGE(STATUS "This is APPLE, set ODBC_LIBS")
    set(ODBC_LIBS iodbc iodbcinst)
elseif(UNIX)
    MESSAGE(STATUS "This is linux, set ODBC_LIBS")
    set(ODBC_LIBS odbc odbcinst ltdl)
endif(APPLE)

include_directories(${INCLUDE_DIRS})
LINK_DIRECTORIES(${LINK_DIRS})

file( GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 

file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/http/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/yaml/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/df/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/death_handler/impl/*.cpp  ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/api_accuracy/utils/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/api_accuracy/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib ${Boost_LIBRARIES} ZLIB::ZLIB glog::glog DataFrame::DataFrame ${OpenCV_LIBS})
target_link_libraries(${PROJECT_NAME}_lib OpenSSL::SSL OpenSSL::Crypto libgtest.a pystring libyaml-cpp.a libgmock.a ${ODBC_LIBS} libnanodbc.a pthread dl backtrace libzstd.a libbz2.a libsnappy.a re2::re2 parquet lz4 unofficial::brotli::brotlidec-static unofficial::brotli::brotlienc-static unofficial::brotli::brotlicommon-static utf8proc thrift::thrift arrow arrow_dataset)

foreach( test_file ${test_file_list} )
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    string(REPLACE ".cpp" "" file ${filename})
    add_executable(${file}  ${test_file})
    target_link_libraries(${file} ${PROJECT_NAME}_lib)
endforeach( test_file ${test_file_list})
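With this loop, every .cpp in the test directory becomes its own executable linked against ${PROJECT_NAME}_lib, so each of the test files below builds and runs independently.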

arr_/arr_.h

#ifndef _FREDRIC_ARR__H_
#define _FREDRIC_ARR__H_

#define ARROW_COMPUTE

#include <arrow/compute/api.h>
#include <arrow/compute/exec/exec_plan.h>

#include <arrow/dataset/api.h>
#include <arrow/filesystem/api.h>
#include "arrow/pretty_print.h"
#include <arrow/api.h>
#include <arrow/csv/api.h>
#include <arrow/json/api.h>
#include <arrow/io/api.h>
#include <arrow/table.h>
#include <arrow/result.h>
#include <arrow/status.h>
#include <arrow/ipc/api.h>
#include "arrow/util/logging.h"
#include "arrow/util/vector.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/thread_pool.h"

#include <arrow/testing/gtest_util.h>
#include <parquet/arrow/reader.h>
#include <parquet/arrow/writer.h>
#include <parquet/exception.h>
#include <memory>
#include <iostream>
#include <string>
#include <chrono>
#include <thread>

template <typename T>
using numbuildT = arrow::NumericBuilder<T>;

// A set of ExecBatches plus the schema they share, mirroring the helper
// struct used in Arrow's ExecPlan unit tests.
struct BatchesWithSchema {
  std::vector<arrow::compute::ExecBatch> batches;
  std::shared_ptr<arrow::Schema> schema;
};


struct ArrowUtil {
    static arrow::Status read_csv(char const* file_name, std::shared_ptr<arrow::Table>& tb);
    static arrow::Status read_ipc(char const* file_name, std::shared_ptr<arrow::Table>& tb);
    static arrow::Status read_parquet(char const* file_name, std::shared_ptr<arrow::Table>& tb);
    static arrow::Status read_json(char const* file_name, std::shared_ptr<arrow::Table>& tb);

    static arrow::Status write_ipc(arrow::Table const& tb , char const* file_name);
    static arrow::Status write_parquet(arrow::Table const& tb , char const* file_name);
    
    static void SleepFor(double seconds) {
        std::this_thread::sleep_for(
            std::chrono::nanoseconds(static_cast<int64_t>(seconds * 1e9)));
    }

    // Flatten a (possibly multi-chunk) numeric ChunkedArray into one
    // contiguous arrow::Array by copying the values through a builder.
    template <typename T, typename buildT, typename arrayT>
    inline static std::shared_ptr<arrow::Array> chunked_array_to_array(std::shared_ptr<arrow::ChunkedArray> const& array_a) {
        buildT builder;
        builder.Resize(array_a->length());
        std::vector<T> values;
        values.reserve(array_a->length());
        for(int i=0; i<array_a->num_chunks(); ++i) {
            auto inner_arr = array_a->chunk(i);
            auto typed_arr = std::static_pointer_cast<arrayT>(inner_arr);
            for(int j=0; j<typed_arr->length(); ++j) {
                values.push_back(typed_arr->Value(j));
            }
        }

        builder.AppendValues(values);
        std::shared_ptr<arrow::Array> array_res;
        builder.Finish(&array_res);
        return array_res;
    }
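
    // Usage sketch (matching the compute tests below), assuming `col` is a
    // std::shared_ptr<arrow::ChunkedArray> holding int64 values:
    //   auto arr = ArrowUtil::chunked_array_to_array<
    //       int64_t, numbuildT<arrow::Int64Type>, arrow::Int64Array>(col);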


    // Same flattening, but into a std::vector<T> instead of an arrow::Array.
    template <typename T, typename arrayT>
    inline static std::vector<T> chunked_array_to_vector(std::shared_ptr<arrow::ChunkedArray> const& array_a) {
        std::vector<T> values;
        values.reserve(array_a->length());
        for(int i=0; i<array_a->num_chunks(); ++i) {
            auto inner_arr = array_a->chunk(i);
            auto typed_arr = std::static_pointer_cast<arrayT>(inner_arr);
            for(int j=0; j<typed_arr->length(); ++j) {
                values.push_back(typed_arr->Value(j));
            }
        }
        return values;
    }

    // String variant: flatten a ChunkedArray of strings into a std::vector<std::string>.
    inline static std::vector<std::string> chunked_array_to_str_vector(std::shared_ptr<arrow::ChunkedArray> const& array_a) {
        std::vector<std::string> values;
        values.reserve(array_a->length());
        for(int i=0; i<array_a->num_chunks(); ++i) {
            auto inner_arr = array_a->chunk(i);
            auto str_arr = std::static_pointer_cast<arrow::StringArray>(inner_arr);
            for(int j=0; j<str_arr->length(); ++j) {
                values.push_back(str_arr->Value(j).to_string());
            }
        }
        return values;
    }


    // String variant: flatten a ChunkedArray of strings into one contiguous StringArray.
    inline static std::shared_ptr<arrow::Array> chunked_array_to_str_array(std::shared_ptr<arrow::ChunkedArray> const& array_a) {
        arrow::StringBuilder builder;
        builder.Resize(array_a->length());
        std::vector<std::string> values;
        values.reserve(array_a->length());
        for(int i=0; i<array_a->num_chunks(); ++i) {
            auto inner_arr = array_a->chunk(i);
            auto str_arr = std::static_pointer_cast<arrow::StringArray>(inner_arr);
            for(int j=0; j<str_arr->length(); ++j) {
                values.push_back(str_arr->Value(j).to_string());
            }
        }
        builder.AppendValues(values);
        std::shared_ptr<arrow::Array> array_res;
        builder.Finish(&array_res);
        return array_res;
    }

    static arrow::Result<arrow::compute::ExecNode*> MakeTestSourceNode(arrow::compute::ExecPlan* plan, std::string label,
                                     BatchesWithSchema batches_with_schema, bool parallel,
                                     bool slow);
    
    static arrow::Future<std::vector<arrow::compute::ExecBatch>> StartAndCollect(
        arrow::compute::ExecPlan* plan, arrow::AsyncGenerator<arrow::util::optional<arrow::compute::ExecBatch>> gen);
    
    static BatchesWithSchema MakeBatchAndSchema(std::shared_ptr<arrow::Table> const& tb);
    static arrow::Status ConvertExecBatchToRecBatch(std::shared_ptr<arrow::Schema> const& schema, std::vector<arrow::compute::ExecBatch> const& exec_batches,  arrow::RecordBatchVector& out_rec_batches);
 
};

#endif

arr_/impl/arr_.cpp

#include "arr_/arr_.h"
#include <iostream>



// Read a CSV file into an arrow::Table via the Dataset API: discover the file,
// build a dataset and scanner, then materialize the whole scan into a table.
arrow::Status ArrowUtil::read_csv(char const* file_name, std::shared_ptr<arrow::Table>& tb) {

    auto fs = std::make_shared<arrow::fs::LocalFileSystem>();

    ARROW_ASSIGN_OR_RAISE(auto info, fs->GetFileInfo(file_name));

    auto format = std::make_shared<arrow::dataset::CsvFileFormat>();
    ARROW_ASSIGN_OR_RAISE(auto factory,
      arrow::dataset::FileSystemDatasetFactory::Make(fs, {info}, format, arrow::dataset::FileSystemFactoryOptions()));
    ARROW_ASSIGN_OR_RAISE(auto dataset, factory->Finish());
    ARROW_ASSIGN_OR_RAISE(auto scan_builder, dataset->NewScan());
    ARROW_ASSIGN_OR_RAISE(auto scanner, scan_builder->Finish());
    ARROW_ASSIGN_OR_RAISE(auto table, scanner->ToTable());
    tb = table;
    return arrow::Status::OK();
}

// Same Dataset-API pattern as read_csv, using the Arrow IPC file format.
arrow::Status ArrowUtil::read_ipc(char const* file_name, std::shared_ptr<arrow::Table>& tb) {

    auto fs = std::make_shared<arrow::fs::LocalFileSystem>();

    ARROW_ASSIGN_OR_RAISE(auto info, fs->GetFileInfo(file_name));

    auto format = std::make_shared<arrow::dataset::IpcFileFormat>();
    ARROW_ASSIGN_OR_RAISE(auto factory,
      arrow::dataset::FileSystemDatasetFactory::Make(fs, {info}, format, arrow::dataset::FileSystemFactoryOptions()));
    ARROW_ASSIGN_OR_RAISE(auto dataset, factory->Finish());
    ARROW_ASSIGN_OR_RAISE(auto scan_builder, dataset->NewScan());
    ARROW_ASSIGN_OR_RAISE(auto scanner, scan_builder->Finish());
    ARROW_ASSIGN_OR_RAISE(auto table, scanner->ToTable());
    tb = table;
    return arrow::Status::OK();
}

// Same Dataset-API pattern as read_csv, using the Parquet file format.
arrow::Status ArrowUtil::read_parquet(char const* file_name, std::shared_ptr<arrow::Table>& tb) {

    auto fs = std::make_shared<arrow::fs::LocalFileSystem>();

    ARROW_ASSIGN_OR_RAISE(auto info, fs->GetFileInfo(file_name));

    auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
    ARROW_ASSIGN_OR_RAISE(auto factory,
      arrow::dataset::FileSystemDatasetFactory::Make(fs, {info}, format, arrow::dataset::FileSystemFactoryOptions()));
    ARROW_ASSIGN_OR_RAISE(auto dataset, factory->Finish());
    ARROW_ASSIGN_OR_RAISE(auto scan_builder, dataset->NewScan());
    ARROW_ASSIGN_OR_RAISE(auto scanner, scan_builder->Finish());
    ARROW_ASSIGN_OR_RAISE(auto table, scanner->ToTable());
    tb = table;
    return arrow::Status::OK();
}

arrow::Status ArrowUtil::read_json(char const* file_name, std::shared_ptr<arrow::Table>& tb) {

    std::shared_ptr<arrow::io::ReadableFile> infile;
    PARQUET_ASSIGN_OR_THROW(infile,
                          arrow::io::ReadableFile::Open(file_name,
                                                        arrow::default_memory_pool()));

    ARROW_ASSIGN_OR_RAISE(auto reader, arrow::json::TableReader::Make(arrow::default_memory_pool(), infile, arrow::json::ReadOptions::Defaults(), arrow::json::ParseOptions::Defaults()));

    ARROW_ASSIGN_OR_RAISE(auto res_tb, reader->Read());
    tb = res_tb;
    return arrow::Status::OK();
}   

arrow::Status ArrowUtil::write_ipc(arrow::Table const& tb , char const* file_name) {

    auto fs = std::make_shared<arrow::fs::LocalFileSystem>();
    ARROW_ASSIGN_OR_RAISE(auto output_file, fs->OpenOutputStream
                        (file_name));
    ARROW_ASSIGN_OR_RAISE(auto batch_writer,
                        arrow::ipc::MakeFileWriter(output_file, tb.schema()));
    ARROW_RETURN_NOT_OK(batch_writer->WriteTable(tb));
    ARROW_RETURN_NOT_OK(batch_writer->Close());

    return arrow::Status::OK();
}

arrow::Status ArrowUtil::write_parquet(arrow::Table const& tb , char const* file_name) {
    auto fs = std::make_shared<arrow::fs::LocalFileSystem>();
    PARQUET_ASSIGN_OR_THROW(
      auto outfile, fs->OpenOutputStream(file_name));
    // The last argument to the function call is the size of the RowGroup in
    // the parquet file. Normally you would choose this to be rather large but
    // for the example, we use a small value to have multiple RowGroups.
    PARQUET_THROW_NOT_OK(
      parquet::arrow::WriteTable(tb, arrow::default_memory_pool(), outfile, 2048));
    return arrow::Status::OK();
}
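
// Not part of the original post's utilities: a minimal sketch showing how one
// could verify how many RowGroups the chunk_size passed to WriteTable above
// actually produced. All headers used here already come in via arr_/arr_.h.
arrow::Status inspect_row_groups(char const* file_name) {
    ARROW_ASSIGN_OR_RAISE(auto infile, arrow::io::ReadableFile::Open(file_name));
    std::unique_ptr<parquet::arrow::FileReader> reader;
    // The Parquet footer records the RowGroup layout.
    ARROW_RETURN_NOT_OK(parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader));
    std::cerr << reader->num_row_groups() << " row group(s)" << std::endl;
    return arrow::Status::OK();
}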


arrow::Result<arrow::compute::ExecNode*> ArrowUtil::MakeTestSourceNode(arrow::compute::ExecPlan* plan, std::string label,
                                     BatchesWithSchema batches_with_schema, bool parallel,
                                     bool slow) {
  DCHECK_GT(batches_with_schema.batches.size(), 0);

  auto opt_batches = ::arrow::internal::MapVector(
      [](arrow::compute::ExecBatch batch) { return arrow::util::make_optional(std::move(batch)); },
      std::move(batches_with_schema.batches));

  arrow::AsyncGenerator<arrow::util::optional<arrow::compute::ExecBatch>> gen;

  if (parallel) {
    // emulate batches completing initial decode-after-scan on a cpu thread
    ARROW_ASSIGN_OR_RAISE(
        gen, MakeBackgroundGenerator(arrow::MakeVectorIterator(std::move(opt_batches)),
                                     ::arrow::internal::GetCpuThreadPool()));

    // ensure that callbacks are not executed immediately on a background thread
    gen = MakeTransferredGenerator(std::move(gen), ::arrow::internal::GetCpuThreadPool());
  } else {
    gen = arrow::MakeVectorGenerator(std::move(opt_batches));
  }

  if (slow) {
    gen = arrow::MakeMappedGenerator(std::move(gen), [](const arrow::util::optional<arrow::compute::ExecBatch>& batch) {
      ArrowUtil::SleepFor(1e-3);
      return batch;
    });
  }

  return arrow::compute::MakeSourceNode(plan, label, std::move(batches_with_schema.schema),
                        std::move(gen));
}

arrow::Future<std::vector<arrow::compute::ExecBatch>> ArrowUtil::StartAndCollect(
    arrow::compute::ExecPlan* plan, arrow::AsyncGenerator<arrow::util::optional<arrow::compute::ExecBatch>> gen) {
  // Validate and start the plan, then collect everything the sink generator yields.
  RETURN_NOT_OK(plan->Validate());
  RETURN_NOT_OK(plan->StartProducing());

  auto collected_fut = CollectAsyncGenerator(gen);

  return arrow::AllComplete({plan->finished(), arrow::Future<>(collected_fut)})
      .Then([collected_fut]() -> arrow::Result<std::vector<arrow::compute::ExecBatch>> {
        ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result());
        return ::arrow::internal::MapVector(
            [](arrow::util::optional<arrow::compute::ExecBatch> batch) { return std::move(*batch); },
            std::move(collected));
      });
}


// Slice a Table into ExecBatches (one per RecordBatch), keeping the shared schema.
BatchesWithSchema ArrowUtil::MakeBatchAndSchema(std::shared_ptr<arrow::Table> const& tb) {
    arrow::TableBatchReader reader(*tb);
    auto schema = tb->schema();

    std::vector<arrow::compute::ExecBatch> ex_batches;
    std::shared_ptr<arrow::RecordBatch> a_batch;
    while(reader.ReadNext(&a_batch).ok()) {
        if(nullptr == a_batch) {
            break;
        }
        arrow::compute::ExecBatch batch(*a_batch);
        ex_batches.emplace_back(std::move(batch));
    }
    
    BatchesWithSchema batche_sh {ex_batches, schema};
    return batche_sh;
}

arrow::Status ArrowUtil::ConvertExecBatchToRecBatch(std::shared_ptr<arrow::Schema> const& schema, std::vector<arrow::compute::ExecBatch> const& exec_batches,  arrow::RecordBatchVector& out_rec_batches) {
    std::vector<std::shared_ptr<arrow::RecordBatch>> rec_batches;
    for(auto const& ex_batch: exec_batches) {
        rec_batches.emplace_back(ex_batch.ToRecordBatch(schema, arrow::default_memory_pool()).ValueOrDie());
    }
    out_rec_batches = std::move(rec_batches);
    return arrow::Status::OK();
}

arr_/test/arr_rw_test.cpp

#include "arr_/arr_.h"
#include "death_handler/death_handler.h"
#include <glog/logging.h>
#include <gtest/gtest.h>


#include <iostream>
#include <vector>

int main(int argc, char** argv) {
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // glog levels: INFO, WARNING, ERROR, FATAL map to 0, 1, 2, 3
    FLAGS_minloglevel = 0;

    Debug::DeathHandler dh;

    google::InitGoogleLogging("./logs.log");
    testing::InitGoogleTest(&argc, argv);
    int ret = RUN_ALL_TESTS();
    return ret;
}


GTEST_TEST(RWTests, ReadCsv) { 
    // Read a CSV file
    char const* csv_path = "../data/test.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    ASSERT_EQ(tb_.num_rows(), 2); 
}

GTEST_TEST(RWTests, ReadIpc) { 
    // Read an Arrow IPC file
    char const* ipc_path = "../data/test_dst.arrow";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_ipc(ipc_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    ASSERT_EQ(tb_.num_rows(), 2); 
}

GTEST_TEST(RWTests, WriteIpc) { 
    // Read a CSV file, then write it out as an Arrow IPC file
    char const* csv_path = "../data/test.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;

    char const* write_ipc_path = "../data/test_dst.arrow";
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto write_res = ArrowUtil::write_ipc(tb_, write_ipc_path);
    ASSERT_TRUE(write_res == arrow::Status::OK());
}

GTEST_TEST(RWTests, WriteParquet) { 
    // Read a CSV file, then write it out as a Parquet file
    char const* csv_path = "../data/test.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;

    char const* write_parquet_path = "../data/test_dst.parquet";
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto write_res = ArrowUtil::write_parquet(tb_, write_parquet_path);
    ASSERT_TRUE(write_res == arrow::Status::OK());
}


GTEST_TEST(RWTests, ReadParquet) { 
    // Read a Parquet file
    char const* parquet_path = "../data/test_dst.parquet";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_parquet(parquet_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    ASSERT_EQ(tb_.num_rows(), 2); 
}

GTEST_TEST(RWTests, ReadJson) { 
    // Read a JSON file
    char const* json_path = "../data/test.json";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_json(json_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    ASSERT_EQ(tb_.num_rows(), 2); 
}

arr_/test/arr_compute_test.cpp

#include "arr_/arr_.h"
#include "death_handler/death_handler.h"
#include <glog/logging.h>
#include <gtest/gtest.h>


#include <iostream>
#include <vector>

int main(int argc, char** argv) {
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // glog levels: INFO, WARNING, ERROR, FATAL map to 0, 1, 2, 3
    FLAGS_minloglevel = 0;

    Debug::DeathHandler dh;

    google::InitGoogleLogging("./logs.log");
    testing::InitGoogleTest(&argc, argv);
    int ret = RUN_ALL_TESTS();
    return ret;
}

GTEST_TEST(ArrComputeTests, ComputeGreater) { 
    // Compare two int columns with the "greater" function: int1 > int2
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto array_a = tb_.GetColumnByName("int1");
    auto array_b = tb_.GetColumnByName("int2");
    
    auto array_a_res = ArrowUtil::chunked_array_to_array<int64_t, numbuildT<arrow::Int64Type>, arrow::Int64Array>(array_a);
    auto array_b_res = ArrowUtil::chunked_array_to_array<int64_t, numbuildT<arrow::Int64Type>, arrow::Int64Array>(array_b);

    auto compared_datum = arrow::compute::CallFunction("greater", {array_a_res, array_b_res});
    auto array_a_gt_b_compute = compared_datum->make_array();
    
    arrow::PrettyPrint(*array_a_gt_b_compute, {}, &std::cerr);

    auto schema =
      arrow::schema({arrow::field("int1", arrow::int64()), arrow::field("int2", arrow::int64()),
                     arrow::field("a>b? (arrow)", arrow::boolean())});
    
    std::shared_ptr<arrow::Table> my_table = arrow::Table::Make(
      schema, {array_a_res, array_b_res, array_a_gt_b_compute}, tb_.num_rows());
    
    arrow::PrettyPrint(*my_table, {}, &std::cerr);
}

GTEST_TEST(ArrComputeTests, ComputeMinMax) {
    // Compute the min and max of column int1
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto array_a = tb_.GetColumnByName("int1");
    auto array_a_res = ArrowUtil::chunked_array_to_array<int64_t, numbuildT<arrow::Int64Type>, arrow::Int64Array>(array_a);

    arrow::compute::ScalarAggregateOptions scalar_aggregate_options;
    scalar_aggregate_options.skip_nulls = false;

    auto min_max = arrow::compute::CallFunction("min_max", {array_a_res}, &scalar_aggregate_options);

    // Unpack struct scalar result (a two-field {"min", "max"} scalar)
    auto min_value = min_max->scalar_as<arrow::StructScalar>().value[0];
    auto max_value = min_max->scalar_as<arrow::StructScalar>().value[1];

    ASSERT_EQ(min_value->ToString(), "1");
    ASSERT_EQ(max_value->ToString(), "8");
} 

GTEST_TEST(ArrComputeTests, ComputeMean) {
    // Compute the mean of column int1
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto array_a = tb_.GetColumnByName("int1");
    auto array_a_res = ArrowUtil::chunked_array_to_array<int64_t, numbuildT<arrow::Int64Type>, arrow::Int64Array>(array_a);

    arrow::compute::ScalarAggregateOptions scalar_aggregate_options;
    scalar_aggregate_options.skip_nulls = false;

    auto mean = arrow::compute::CallFunction("mean", {array_a_res}, &scalar_aggregate_options);

    auto const& mean_value = mean->scalar_as<arrow::Scalar>();
    
    ASSERT_EQ(mean_value.ToString(), "4.5");
} 

GTEST_TEST(ArrComputeTests, ComputeAdd) {
    // Add 3 to every value in the first column
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto array_a = tb_.GetColumnByName("int1");
    auto array_a_res = ArrowUtil::chunked_array_to_array<int64_t, numbuildT<arrow::Int64Type>, arrow::Int64Array>(array_a);

    std::shared_ptr<arrow::Scalar> increment = std::make_shared<arrow::Int64Scalar>(3);

    auto add = arrow::compute::CallFunction("add", {array_a_res, increment});
    std::shared_ptr<arrow::Array> incremented_array = add->array_as<arrow::Array>();
    arrow::PrettyPrint(*incremented_array, {}, &std::cerr);
} 


GTEST_TEST(ArrComputeTests, ComputeAddArray) {
    // Add columns int1 and int2 element-wise
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto array_a = tb_.GetColumnByName("int1");
    auto array_a_res = ArrowUtil::chunked_array_to_array<int64_t, numbuildT<arrow::Int64Type>, arrow::Int64Array>(array_a);
    
    auto array_b = tb_.GetColumnByName("int2");
    auto array_b_res = ArrowUtil::chunked_array_to_array<int64_t, numbuildT<arrow::Int64Type>, arrow::Int64Array>(array_b);

    auto add = arrow::compute::CallFunction("add", {array_a_res, array_b_res});
    std::shared_ptr<arrow::Array> incremented_array = add->array_as<arrow::Array>();
    arrow::PrettyPrint(*incremented_array, {}, &std::cerr);
} 

GTEST_TEST(ArrComputeTests, ComputeStringEqual) {
    // Compare string columns s1 and s2 for equality
    char const* csv_path = "../data/comp_s_eq.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);

    auto array_a = tb_.GetColumnByName("s1");
    auto array_a_res = ArrowUtil::chunked_array_to_str_array(array_a);

    auto array_b = tb_.GetColumnByName("s2");
    auto array_b_res = ArrowUtil::chunked_array_to_str_array(array_b);

    auto eq_ = arrow::compute::CallFunction("equal", {array_a_res, array_b_res});
    std::shared_ptr<arrow::Array> eq_array = eq_->array_as<arrow::Array>();
    arrow::PrettyPrint(*eq_array, {}, &std::cerr);
}

GTEST_TEST(ArrComputeTests, ComputeCustom) {
    // Hand-rolled element-wise string equality check
    char const* csv_path = "../data/comp_s_eq.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto arr1 = tb_.GetColumnByName("s1");
    auto arr2 = tb_.GetColumnByName("s2");
    auto v1 = ArrowUtil::chunked_array_to_str_vector(arr1);
    auto v2 = ArrowUtil::chunked_array_to_str_vector(arr2);
    for(std::size_t i=0; i<v1.size(); ++i) {
        if(v1[i] != v2[i]) {
            std::cerr << v1[i] << "!=" << v2[i] << "\n";
        }
    }
}

GTEST_TEST(ArrComputeTests, ComputeCustomDbl) { 
    // Hand-rolled comparison of double values
    char const* csv_path = "../data/custom_dbl.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto arr1 = tb_.GetColumnByName("dbl1");
    auto arr2 = tb_.GetColumnByName("dbl2");
    auto v1 = ArrowUtil::chunked_array_to_vector<double, arrow::DoubleArray>(arr1);
    auto v2 = ArrowUtil::chunked_array_to_vector<double, arrow::DoubleArray>(arr2);
    for(std::size_t i=0; i<v1.size(); ++i) {
        if(v1[i] != v2[i]) {
            std::cerr << v1[i] << "!=" << v2[i] << "\n";
        }
    }
}

GTEST_TEST(ArrComputeTests, ComputeEqualDbl) { 
    // Compare double values with the "equal" function
    char const* csv_path = "../data/custom_dbl.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);
    auto arr1 = tb_.GetColumnByName("dbl1");
    auto arr2 = tb_.GetColumnByName("dbl2");

    auto dbl_arr1 = ArrowUtil::chunked_array_to_array<double, numbuildT<arrow::DoubleType>, arrow::DoubleArray>(arr1);
    auto dbl_arr2 = ArrowUtil::chunked_array_to_array<double, numbuildT<arrow::DoubleType>, arrow::DoubleArray>(arr2);

    auto eq_ = arrow::compute::CallFunction("equal", {dbl_arr1, dbl_arr2});
    std::shared_ptr<arrow::Array> eq_array = eq_->array_as<arrow::Array>();
    arrow::PrettyPrint(*eq_array, {}, &std::cerr);
}

GTEST_TEST(ArrComputeTests, StrStartsWith) {
    // Check which values in column s1 start with "Zha"
    char const* csv_path = "../data/comp_s_eq.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);
    auto const& tb_ = *tb;
    arrow::PrettyPrint(tb_, {}, &std::cerr);

    auto array_a = tb_.GetColumnByName("s1");
    auto array_a_res = ArrowUtil::chunked_array_to_str_array(array_a);

    arrow::compute::MatchSubstringOptions options("Zha");

    auto eq_ = arrow::compute::CallFunction("starts_with", {array_a_res}, &options);
    std::shared_ptr<arrow::Array> eq_array = eq_->array_as<arrow::Array>();
    arrow::PrettyPrint(*eq_array, {}, &std::cerr);
}

arr_/test/arr_dataset_test.cpp

#include "arr_/arr_.h"
#include <memory>
#include "death_handler/death_handler.h"
#include <glog/logging.h>
#include <gtest/gtest.h>


#include <iostream>
#include <vector>

int main(int argc, char** argv) {
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // glog levels: INFO, WARNING, ERROR, FATAL map to 0, 1, 2, 3
    FLAGS_minloglevel = 0;

    Debug::DeathHandler dh;

    google::InitGoogleLogging("./logs.log");
    testing::InitGoogleTest(&argc, argv);
    int ret = RUN_ALL_TESTS();
    return ret;
}


GTEST_TEST(DatasetTests, ProjectGreater) { 
    // Project two columns plus a greater-than comparison of them
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto ds = std::make_shared<arrow::dataset::InMemoryDataset>(tb);
    auto scanner_builder = ds->NewScan().ValueOrDie();
    scanner_builder->Project({
        arrow::compute::field_ref("int1"),
        arrow::compute::field_ref("int2"),
        arrow::compute::call("greater", {arrow::compute::field_ref("int1"), arrow::compute::field_ref("int2")})
    }, {"int1", "int2", "res"});
    auto scanner = scanner_builder->Finish().ValueOrDie();
    auto proj_tb = scanner->ToTable().ValueOrDie();
    arrow::PrettyPrint(*proj_tb, {}, &std::cerr);
}

GTEST_TEST(DatasetTests, ProjectAddLiteral) { 
    // Add a scalar literal to a column
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto ds = std::make_shared<arrow::dataset::InMemoryDataset>(tb);
    auto scanner_builder = ds->NewScan().ValueOrDie();

    scanner_builder->Project({
        arrow::compute::field_ref("int1"),
        arrow::compute::field_ref("int2"),
        arrow::compute::call("add", {arrow::compute::field_ref("int1"), arrow::compute::literal(3)})
    }, {"int1", "int2", "added_res"});
    auto scanner = scanner_builder->Finish().ValueOrDie();
    auto proj_tb = scanner->ToTable().ValueOrDie();
    arrow::PrettyPrint(*proj_tb, {}, &std::cerr);
}

GTEST_TEST(DatasetTests, ProjectAddColumn) { 
    // Add two columns element-wise
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto ds = std::make_shared<arrow::dataset::InMemoryDataset>(tb);
    auto scanner_builder = ds->NewScan().ValueOrDie();

    scanner_builder->Project({
        arrow::compute::field_ref("int1"),
        arrow::compute::field_ref("int2"),
        arrow::compute::call("add", {arrow::compute::field_ref("int1"), arrow::compute::field_ref("int2")})
    }, {"int1", "int2", "added_column_res"});
    auto scanner = scanner_builder->Finish().ValueOrDie();
    auto proj_tb = scanner->ToTable().ValueOrDie();
    arrow::PrettyPrint(*proj_tb, {}, &std::cerr);
}

GTEST_TEST(DatasetTests, ProjectStrEq) { 
    // Compare two string columns for equality
    char const* csv_path = "../data/comp_s_eq.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto ds = std::make_shared<arrow::dataset::InMemoryDataset>(tb);
    auto scanner_builder = ds->NewScan().ValueOrDie();

    scanner_builder->Project({
        arrow::compute::field_ref("s1"),
        arrow::compute::field_ref("s2"),
        arrow::compute::call("equal", {arrow::compute::field_ref("s1"), arrow::compute::field_ref("s2")})
    }, {"s1", "s2", "is_equal"});
    auto scanner = scanner_builder->Finish().ValueOrDie();
    auto proj_tb = scanner->ToTable().ValueOrDie();
    arrow::PrettyPrint(*proj_tb, {}, &std::cerr);
}

GTEST_TEST(DatasetTests, ProjectDoubleEq) { 
    // Compare two double columns for equality
    char const* csv_path = "../data/custom_dbl.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto ds = std::make_shared<arrow::dataset::InMemoryDataset>(tb);
    auto scanner_builder = ds->NewScan().ValueOrDie();

    scanner_builder->Project({
        arrow::compute::field_ref("dbl1"),
        arrow::compute::field_ref("dbl2"),
        arrow::compute::call("equal", {arrow::compute::field_ref("dbl1"), arrow::compute::field_ref("dbl2")})
    }, {"dbl1", "dbl2", "is_dbl_equal"});
    auto scanner = scanner_builder->Finish().ValueOrDie();
    auto proj_tb = scanner->ToTable().ValueOrDie();
    arrow::PrettyPrint(*proj_tb, {}, &std::cerr);
}

GTEST_TEST(DatasetTests, ProjectStrStartsWith) { 
    // Check which strings start with "Zha"
    char const* csv_path = "../data/comp_s_eq.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto ds = std::make_shared<arrow::dataset::InMemoryDataset>(tb);
    auto scanner_builder = ds->NewScan().ValueOrDie();

    arrow::compute::MatchSubstringOptions options("Zha");

    scanner_builder->Project({
        arrow::compute::field_ref("s1"),
        arrow::compute::field_ref("s2"),
        arrow::compute::call("starts_with", {arrow::compute::field_ref("s1")}, options)
    }, {"s1", "s2", "is_str_starts_with"});
    auto scanner = scanner_builder->Finish().ValueOrDie();
    auto proj_tb = scanner->ToTable().ValueOrDie();
    arrow::PrettyPrint(*proj_tb, {}, &std::cerr);
}

arr_/test/arr_execution_engine_test.cpp

#include "arr_/arr_.h"
#include <memory>
#include "death_handler/death_handler.h"
#include <glog/logging.h>
#include <gtest/gtest.h>



#include <iostream>
#include <vector>

int main(int argc, char** argv) {
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // glog levels: INFO, WARNING, ERROR, FATAL map to 0, 1, 2, 3
    FLAGS_minloglevel = 0;

    Debug::DeathHandler dh;

    google::InitGoogleLogging("./logs.log");
    testing::InitGoogleTest(&argc, argv);
    int ret = RUN_ALL_TESTS();
    return ret;
}

GTEST_TEST(ExecEngineTests, ProjectSourceSink) { 
    // Sink the source data through the plan unchanged
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto plan = arrow::compute::ExecPlan::Make().ValueOrDie();
    auto batche_sh = ArrowUtil::MakeBatchAndSchema(tb);

    auto source = ArrowUtil::MakeTestSourceNode(plan.get(), "source",
                                                           batche_sh, true, false).ValueOrDie();

    auto sink_gen = arrow::compute::MakeSinkNode(source, "sink");
    auto res_ex_batches = ArrowUtil::StartAndCollect(plan.get(), sink_gen).result().ValueOrDie();
    
    std::vector<std::shared_ptr<arrow::RecordBatch>> rec_batches;
    ArrowUtil::ConvertExecBatchToRecBatch(tb->schema(), res_ex_batches, rec_batches);

    auto res_tb = arrow::Table::FromRecordBatches(rec_batches).ValueOrDie();
    arrow::PrettyPrint(*res_tb, {}, &std::cerr);
}

GTEST_TEST(ExecEngineTests, ProjectFilterThenProject) { 
    // Filter rows where int1 > 3, then project the sum of int1 and int2
    char const* csv_path = "../data/comp_gt.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto plan = arrow::compute::ExecPlan::Make().ValueOrDie();
    auto batche_sh = ArrowUtil::MakeBatchAndSchema(tb);

    auto source = ArrowUtil::MakeTestSourceNode(plan.get(), "source",
                                                           batche_sh, true, false).ValueOrDie();

    auto predicate = arrow::compute::greater(arrow::compute::field_ref("int1"), arrow::compute::literal(3)).Bind(*batche_sh.schema).ValueOrDie();

    auto filter = arrow::compute::MakeFilterNode(source, "filter", predicate).ValueOrDie();

    std::vector<arrow::compute::Expression> exprs{
        arrow::compute::field_ref("int1"),
        arrow::compute::field_ref("int2"),
        arrow::compute::call("add", { arrow::compute::field_ref("int1"),  arrow::compute::field_ref("int2")}),
    };

    for (auto& expr : exprs) {
        expr = expr.Bind(*batche_sh.schema).ValueOrDie();
    }

    auto projection = arrow::compute::MakeProjectNode(filter, "project", exprs, {"int1", "int2", "add_res"}).ValueOrDie();
    auto sink_gen = arrow::compute::MakeSinkNode(projection, "sink");
    auto res_ex_batches = ArrowUtil::StartAndCollect(plan.get(), sink_gen).result().ValueOrDie();
    
    auto result_schema = arrow::schema({arrow::field("int1", arrow::int64()),
            arrow::field("int2", arrow::int64()), 
            arrow::field("add_res", arrow::int64())}); 

    std::vector<std::shared_ptr<arrow::RecordBatch>> rec_batches;
    ArrowUtil::ConvertExecBatchToRecBatch(result_schema, res_ex_batches, rec_batches);

    auto res_tb = arrow::Table::FromRecordBatches(rec_batches).ValueOrDie();
    arrow::PrettyPrint(*res_tb, {}, &std::cerr);
}


GTEST_TEST(ExecEngineTests, ProjectDblFilter) { 
    // First filter on s1 starts_with "Zha", then on s1 ends_with "liu"
    char const* csv_path = "../data/comp_s_eq.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto plan = arrow::compute::ExecPlan::Make().ValueOrDie();
    auto batche_sh = ArrowUtil::MakeBatchAndSchema(tb);

    auto source = ArrowUtil::MakeTestSourceNode(plan.get(), "source",
                                                           batche_sh, true, false).ValueOrDie();
    arrow::compute::MatchSubstringOptions s_options("Zha");
    auto s_predicate = arrow::compute::call("starts_with", {arrow::compute::field_ref("s1")}, s_options).Bind(*batche_sh.schema).ValueOrDie();
    auto s_filter = arrow::compute::MakeFilterNode(source, "filter", s_predicate).ValueOrDie();

    arrow::compute::MatchSubstringOptions e_options("liu");
    auto e_predicate = arrow::compute::call("ends_with", {arrow::compute::field_ref("s1")}, e_options).Bind(*batche_sh.schema).ValueOrDie();
    auto e_filter = arrow::compute::MakeFilterNode(s_filter, "filter", e_predicate).ValueOrDie();

    auto sink_gen = arrow::compute::MakeSinkNode(e_filter, "sink");

    auto res_ex_batches = ArrowUtil::StartAndCollect(plan.get(), sink_gen).result().ValueOrDie();
    std::vector<std::shared_ptr<arrow::RecordBatch>> rec_batches;
    ArrowUtil::ConvertExecBatchToRecBatch(tb->schema(), res_ex_batches, rec_batches);

    auto res_tb = arrow::Table::FromRecordBatches(rec_batches).ValueOrDie();
    arrow::PrettyPrint(*res_tb, {}, &std::cerr);
}


GTEST_TEST(ExecEngineTests, ProjectStrEq) { 
    // Compare two string columns for equality
    char const* csv_path = "../data/comp_s_eq.csv";
    std::shared_ptr<arrow::Table> tb;
    ArrowUtil::read_csv(csv_path, tb);

    auto plan = arrow::compute::ExecPlan::Make().ValueOrDie();
    auto batche_sh = ArrowUtil::MakeBatchAndSchema(tb);

    auto source = ArrowUtil::MakeTestSourceNode(plan.get(), "source",
                                                           batche_sh, true, false).ValueOrDie();

   
    std::vector<arrow::compute::Expression> exprs{
        arrow::compute::field_ref("s1"),
        arrow::compute::field_ref("s2"),
        arrow::compute::call("equal", {arrow::compute::field_ref("s1"),  arrow::compute::field_ref("s2")}),
    };

    for (auto& expr : exprs) {
        expr = expr.Bind(*batche_sh.schema).ValueOrDie();
    }

    auto projection = arrow::compute::MakeProjectNode(source, "project", exprs, {"s1", "s2", "is_equal"}).ValueOrDie();

    auto sink_gen = arrow::compute::MakeSinkNode(projection, "sink");
    auto res_ex_batches = ArrowUtil::StartAndCollect(plan.get(), sink_gen).result().ValueOrDie();
    
    auto result_schema = arrow::schema({arrow::field("s1", arrow::utf8()),
            arrow::field("s2", arrow::utf8()), 
            arrow::field("is_equal", arrow::boolean())}); 

    std::vector<std::shared_ptr<arrow::RecordBatch>> rec_batches;
    ArrowUtil::ConvertExecBatchToRecBatch(result_schema, res_ex_batches, rec_batches);

    auto res_tb = arrow::Table::FromRecordBatches(rec_batches).ValueOrDie();
    arrow::PrettyPrint(*res_tb, {}, &std::cerr);
}

The program output is as follows:

[image: program output]
