美文网首页
Apache Arrow 构建简单的UDF,用于DataSet

Apache Arrow 构建简单的UDF,用于DataSet

作者: FredricZhu | 来源:发表于2024-02-06 14:25 被阅读0次

环境配置请参考下面的博客:
https://www.jianshu.com/p/6f6b5c88acc4
本例演示如何在 Apache Arrow 中构建并注册自定义标量函数(UDF),并在 Dataset 扫描的投影(Project)中调用它完成自定义运算。
本例定义了一个三元求和函数 add_three,计算 k1、k2、k3 三列之和。
代码如下:
conanfile.txt

[requires]
boost/1.72.0
arrow/15.0.0

[generators]
cmake

CMakeLists.txt

# Builds one executable per *.cpp file in this directory (each also compiled with any
# *.cc sources found here) and links it against the libraries from Conan's "cmake"
# generator.
# CMAKE_CXX_STANDARD 17 is only understood by CMake >= 3.8.
cmake_minimum_required(VERSION 3.8)

project(9_udf_in_dataset)

set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")

# Append instead of overwriting so flags passed via -DCMAKE_CXX_FLAGS=... are kept.
string(APPEND CMAKE_CXX_FLAGS " -pthread")
set(CMAKE_CXX_STANDARD 17)
# The code needs C++17; fail at configure time rather than silently using an older standard.
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# -g is a compiler option, not a preprocessor definition (add_definitions is meant for -D flags).
add_compile_options(-g)

include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

# NOTE(review): INCLUDE_DIRS / LINK_DIRS are not set in this file; presumably supplied
# on the command line (-DINCLUDE_DIRS=...). They expand to nothing otherwise.
include_directories(${INCLUDE_DIRS})
link_directories(${LINK_DIRS})

file(GLOB main_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
file(GLOB sources ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)

foreach(main_file ${main_file_list})
    # Executable name = source file name with ".cpp" removed.
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${main_file})
    string(REPLACE ".cpp" "" file ${filename})
    add_executable(${file} ${main_file} ${sources})
    target_link_libraries(${file} ${CONAN_LIBS_ARROW} ${CONAN_LIBS} pthread)
endforeach()

udf_in_dataset.cpp

#include <arrow/api.h>
#include <arrow/compute/api.h>
#include <arrow/csv/api.h>
#include "arrow/acero/exec_plan.h"
#include "arrow/compute/expression.h"

#include <arrow/dataset/dataset.h>
#include <arrow/dataset/plan.h>
#include <arrow/dataset/scanner.h>

#include <arrow/io/interfaces.h>
#include <arrow/io/memory.h>
#include <arrow/io/stdio.h>

#include <arrow/filesystem/filesystem.h>

#include <arrow/result.h>
#include <arrow/status.h>

#include <arrow/util/vector.h>

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>


namespace ds = arrow::dataset;
namespace cp = arrow::compute;


// Embedded sample CSV: a header row (k1,k2,k3) followed by four rows of integers.
// There is deliberately no newline after the last row.
char ch_csv_data[] =
    "k1,k2,k3\n"
    "1,4,7\n"
    "2,5,8\n"
    "11,20,21\n"
    "3,6,9";

cp::FunctionDoc func_doc{
    "User-defined-function usage to demonstrate registering an out-of-tree function",
    "returns x + y + z",
    {"x", "y", "z"},
    "UDFOptions"
};

/// Parses CSV text into an in-memory Arrow dataset.
///
/// @param csv_data CSV text whose first row is the header; defaults to the embedded
///                 sample `ch_csv_data`. The text is echoed to stdout before parsing.
/// @return An InMemoryDataset wrapping the parsed table (read/parse/convert options are
///         the Arrow CSV defaults), or the error Status from the CSV reader.
arrow::Result<std::shared_ptr<ds::Dataset>> CreateDatasetFromCsvData(
    std::string csv_data = ch_csv_data) {
    std::cout << csv_data << std::endl;

    // FromString takes ownership of the string, so the text is not copied again.
    std::shared_ptr<arrow::io::InputStream> input =
        arrow::io::BufferReader::FromString(std::move(csv_data));

    ARROW_ASSIGN_OR_RAISE(
        std::shared_ptr<arrow::csv::TableReader> table_reader,
        arrow::csv::TableReader::Make(arrow::io::default_io_context(), input,
                                      arrow::csv::ReadOptions::Defaults(),
                                      arrow::csv::ParseOptions::Defaults(),
                                      arrow::csv::ConvertOptions::Defaults()));
    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Table> table, table_reader->Read());

    // shared_ptr<InMemoryDataset> converts implicitly to Result<shared_ptr<Dataset>>.
    return std::make_shared<ds::InMemoryDataset>(std::move(table));
}

/// Kernel for the ternary "add_three" UDF: out[i] = x[i] + y[i] + z[i] on int64.
///
/// @param ctx   Kernel context; not used by this kernel.
/// @param batch The three int64 arguments. Each is either an array or, when the
///              expression passes a literal, a scalar that is broadcast to every row.
/// @param out   Output array span whose buffers were allocated by the executor
///              (the kernel is registered with MemAllocation::PREALLOCATE).
/// @return Always Status::OK().
arrow::Status SampleFunction(cp::KernelContext* ctx, cp::ExecSpan const& batch,
    cp::ExecResult* out) {
    (void)ctx;
    constexpr int kNumArgs = 3;

    // Per argument: pointer to its values (array case), or nullptr plus the broadcast
    // value (scalar case). GetValues<T>(1) already applies the span's offset.
    const int64_t* arg_values[kNumArgs];
    int64_t arg_scalars[kNumArgs] = {0, 0, 0};
    for (int a = 0; a < kNumArgs; ++a) {
        const cp::ExecValue& arg = batch[a];
        if (arg.is_scalar()) {
            arg_values[a] = nullptr;
            arg_scalars[a] = static_cast<const arrow::Int64Scalar*>(arg.scalar)->value;
        } else {
            arg_values[a] = arg.array.GetValues<int64_t>(1);
        }
    }

    int64_t* out_values = out->array_span_mutable()->GetValues<int64_t>(1);
    for (int64_t i = 0; i < batch.length; ++i) {
        // Accumulate in uint64_t so overflow wraps (two's complement) instead of being
        // signed-overflow UB. Null slots may hold unspecified values and are summed
        // too; their results are masked by the INTERSECTION validity bitmap.
        uint64_t sum = 0;
        for (int a = 0; a < kNumArgs; ++a) {
            const int64_t v = arg_values[a] != nullptr ? arg_values[a][i] : arg_scalars[a];
            sum += static_cast<uint64_t>(v);
        }
        out_values[i] = static_cast<int64_t>(sum);
    }
    return arrow::Status::OK();
}

namespace {

/// Name under which the ternary sum UDF is registered and later called.
constexpr char kAddThreeName[] = "add_three";

/// Registers "add_three" (int64 x int64 x int64 -> int64, implemented by SampleFunction)
/// in `registry`. If a function with that name is already registered, this is a no-op,
/// so running the demo more than once in a process does not fail.
arrow::Status RegisterAddThreeFunction(cp::FunctionRegistry* registry) {
    if (registry->GetFunction(kAddThreeName).ok()) {
        return arrow::Status::OK();
    }

    auto func = std::make_shared<cp::ScalarFunction>(kAddThreeName, cp::Arity::Ternary(),
                                                     func_doc);
    cp::ScalarKernel kernel({arrow::int64(), arrow::int64(), arrow::int64()},
        arrow::int64(), SampleFunction);
    // The executor allocates the output buffers before invoking the kernel.
    kernel.mem_allocation = cp::MemAllocation::PREALLOCATE;
    // An output slot is null whenever any input slot is null.
    kernel.null_handling = cp::NullHandling::INTERSECTION;
    ARROW_RETURN_NOT_OK(func->AddKernel(std::move(kernel)));

    return registry->AddFunction(std::move(func));
}

}  // namespace

/// Builds a dataset from the embedded CSV, registers the "add_three" UDF in the default
/// function registry, then scans the dataset projecting k1, k2, k3 and their row-wise
/// sum (column "sum"), printing the row count and the resulting table to stdout.
///
/// @return OK on success, otherwise the first error from dataset creation,
///         registration, projection or scanning.
arrow::Status UDFDatasetScan() {
    ARROW_ASSIGN_OR_RAISE(auto data_set, CreateDatasetFromCsvData());
    ARROW_ASSIGN_OR_RAISE(auto scan_builder, data_set->NewScan());

    ARROW_RETURN_NOT_OK(RegisterAddThreeFunction(cp::GetFunctionRegistry()));

    // Project the three source columns plus the UDF result.
    ARROW_RETURN_NOT_OK(scan_builder->Project({
        cp::field_ref("k1"),
        cp::field_ref("k2"),
        cp::field_ref("k3"),
        cp::call(kAddThreeName, {cp::field_ref("k1"), cp::field_ref("k2"), cp::field_ref("k3")})
    }, {"k1", "k2", "k3", "sum"}));

    ARROW_ASSIGN_OR_RAISE(auto scanner, scan_builder->Finish());
    ARROW_ASSIGN_OR_RAISE(auto table, scanner->ToTable());
    std::cout << "Result " << table->num_rows() << " rows" << std::endl;
    std::cout << table->ToString() << std::endl;
    return arrow::Status::OK();
}


int main(int argc, char* argv[]) {
    arrow::Status st = UDFDatasetScan();
    if(!st.ok()) {
        std::cerr << "Error occurred: " << st.message() << std::endl;
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

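程序输出如下(原文以截图展示:程序先打印 CSV 原文,再打印 "Result 4 rows" 及结果表,表中各行的 sum 分别为 1+4+7=12、2+5+8=15、11+20+21=52、3+6+9=18):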


image.png
image.png

相关文章

网友评论

      本文标题:Apache Arrow 构建简单的UDF,用于DataSet

      本文链接:https://www.haomeiwen.com/subject/empcadtx.html