diff options
author | vvvv <vvvv@ydb.tech> | 2022-09-16 14:42:49 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-09-16 14:42:49 +0300 |
commit | 1bfaafcb1adce88e7a6161dba40b7769a04debb3 (patch) | |
tree | b85ec717b2f90b91f2fced052e70a1eb7e925a9b | |
parent | f6c4135c52c0ab15e44686ea64c139b868c81f46 (diff) | |
download | ydb-1bfaafcb1adce88e7a6161dba40b7769a04debb3.tar.gz |
discovery of Arrow function output type over given input types
-rw-r--r-- | ydb/library/yql/minikql/arrow/CMakeLists.txt | 1 | ||||
-rw-r--r-- | ydb/library/yql/minikql/arrow/mkql_functions.cpp | 146 | ||||
-rw-r--r-- | ydb/library/yql/minikql/arrow/mkql_functions.h | 9 | ||||
-rw-r--r-- | ydb/library/yql/minikql/arrow/mkql_functions_ut.cpp | 95 |
4 files changed, 251 insertions, 0 deletions
diff --git a/ydb/library/yql/minikql/arrow/CMakeLists.txt b/ydb/library/yql/minikql/arrow/CMakeLists.txt index b60d2eef53..eb3204951d 100644 --- a/ydb/library/yql/minikql/arrow/CMakeLists.txt +++ b/ydb/library/yql/minikql/arrow/CMakeLists.txt @@ -18,5 +18,6 @@ target_link_libraries(yql-minikql-arrow PUBLIC library-yql-minikql ) target_sources(yql-minikql-arrow PRIVATE + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/arrow/mkql_functions.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/arrow/mkql_memory_pool.cpp ) diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp new file mode 100644 index 0000000000..6d6856578a --- /dev/null +++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp @@ -0,0 +1,146 @@ +#include "mkql_functions.h" +#include <ydb/library/yql/minikql/mkql_node_builder.h> +#include <ydb/library/yql/minikql/mkql_node_cast.h> + +#include <contrib/libs/apache/arrow/cpp/src/arrow/datum.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/visitor.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h> + +namespace NKikimr::NMiniKQL { + +bool ConvertInputArrowType(TType* type, bool& isOptional, arrow::ValueDescr& descr) { + auto blockType = AS_TYPE(TBlockType, type); + descr.shape = blockType->GetShape() == TBlockType::EShape::Scalar ? arrow::ValueDescr::SCALAR : arrow::ValueDescr::ARRAY; + auto unpacked = UnpackOptional(blockType->GetItemType(), isOptional); + if (!unpacked->IsData()) { + return false; + } + + auto slot = AS_TYPE(TDataType, unpacked)->GetDataSlot(); + if (!slot) { + return false; + } + + switch (*slot) { + case NUdf::EDataSlot::Bool: + descr.type = arrow::boolean(); + return true; + case NUdf::EDataSlot::Uint64: + descr.type = arrow::uint64(); + return true; + default: + return false; + } +} + +class TOutputTypeVisitor : public arrow::TypeVisitor +{ +public: + TOutputTypeVisitor(TTypeEnvironment& env) + : Env_(env) + {} + + arrow::Status Visit(const arrow::BooleanType&) { + SetDataType(NUdf::EDataSlot::Bool); + return arrow::Status::OK(); + } + + arrow::Status Visit(const arrow::UInt64Type&) { + SetDataType(NUdf::EDataSlot::Uint64); + return arrow::Status::OK(); + } + + TType* GetType() const { + return Type_; + } + +private: + void SetDataType(NUdf::EDataSlot slot) { + Type_ = TDataType::Create(NUdf::GetDataTypeInfo(slot).TypeId, Env_); + } + +private: + TTypeEnvironment& Env_; + TType* Type_ = nullptr; +}; + +bool ConvertOutputArrowType(const arrow::compute::OutputType& outType, const std::vector<arrow::ValueDescr>& values, + bool optional, TType*& outputType, TTypeEnvironment& env) { + arrow::ValueDescr::Shape shape; + std::shared_ptr<arrow::DataType> dataType; + + auto execContext = arrow::compute::ExecContext(); + auto kernelContext = arrow::compute::KernelContext(&execContext); + auto descrRes = outType.Resolve(&kernelContext, values); + if (!descrRes.ok()) { + return false; + } + + const auto& descr = *descrRes; + dataType = descr.type; + shape = descr.shape; + + TOutputTypeVisitor visitor(env); + if (!dataType->Accept(&visitor).ok()) { + return false; + } + + TType* itemType = visitor.GetType(); + if (optional) { + itemType = TOptionalType::Create(itemType, env); + } + + switch (shape) { + case arrow::ValueDescr::SCALAR: + outputType = TBlockType::Create(itemType, TBlockType::EShape::Scalar, env); + return true; + case arrow::ValueDescr::ARRAY: + outputType = TBlockType::Create(itemType, TBlockType::EShape::Many, env); + return true; + default: + return false; + } +} + +bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env) { + auto registry = arrow::compute::GetFunctionRegistry(); + auto resFunc = registry->GetFunction(TString(name)); + if (!resFunc.ok()) { + return false; + } + + const auto& func = *resFunc; + if (func->kind() != arrow::compute::Function::SCALAR) { + return false; + } + + std::vector<arrow::ValueDescr> values; + bool hasOptionals = false; + for (const auto& type : inputTypes) { + arrow::ValueDescr descr; + bool isOptional; + if (!ConvertInputArrowType(type, isOptional, descr)) { + return false; + } + + hasOptionals = hasOptionals || isOptional; + values.push_back(descr); + } + + auto resKernel = func->DispatchExact(values); + if (!resKernel.ok()) { + return false; + } + + const auto& kernel = static_cast<const arrow::compute::ScalarKernel*>(*resKernel); + auto notNull = (kernel->null_handling == arrow::compute::NullHandling::OUTPUT_NOT_NULL); + const auto& outType = kernel->signature->out_type(); + if (!ConvertOutputArrowType(outType, values, hasOptionals && !notNull, outputType, env)) { + return false; + } + + return true; +} + +} diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.h b/ydb/library/yql/minikql/arrow/mkql_functions.h new file mode 100644 index 0000000000..8b7a537bcf --- /dev/null +++ b/ydb/library/yql/minikql/arrow/mkql_functions.h @@ -0,0 +1,9 @@ +#pragma once + +#include <ydb/library/yql/minikql/mkql_node.h> + +namespace NKikimr::NMiniKQL { + +bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env); + +} diff --git a/ydb/library/yql/minikql/arrow/mkql_functions_ut.cpp b/ydb/library/yql/minikql/arrow/mkql_functions_ut.cpp new file mode 100644 index 0000000000..9f27c39c91 --- /dev/null +++ b/ydb/library/yql/minikql/arrow/mkql_functions_ut.cpp @@ -0,0 +1,95 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "mkql_functions.h" + +namespace NKikimr::NMiniKQL { + +Y_UNIT_TEST_SUITE(TMiniKQLArrowFunctions) { + Y_UNIT_TEST(Add) { + TScopedAlloc alloc; + TTypeEnvironment env(alloc); + + auto uint64Type = TDataType::Create(NUdf::GetDataTypeInfo(NUdf::EDataSlot::Uint64).TypeId, env); + auto uint64TypeOpt = TOptionalType::Create(uint64Type, env); + + auto scalarType = TBlockType::Create(uint64Type, TBlockType::EShape::Scalar, env); + auto arrayType = TBlockType::Create(uint64Type, TBlockType::EShape::Many, env); + + auto scalarTypeOpt = TBlockType::Create(uint64TypeOpt, TBlockType::EShape::Scalar, env); + auto arrayTypeOpt = TBlockType::Create(uint64TypeOpt, TBlockType::EShape::Many, env); + + TType* outputType; + UNIT_ASSERT(!FindArrowFunction("_add_", {}, outputType, env)); + UNIT_ASSERT(!FindArrowFunction("add", {}, outputType, env)); + UNIT_ASSERT(!FindArrowFunction("add", TVector<TType*>{ scalarType }, outputType, env)); + UNIT_ASSERT(!FindArrowFunction("add", TVector<TType*>{ arrayType }, outputType, env)); + UNIT_ASSERT(!FindArrowFunction("add", TVector<TType*>{ scalarTypeOpt }, outputType, env)); + UNIT_ASSERT(!FindArrowFunction("add", TVector<TType*>{ arrayTypeOpt }, outputType, env)); + + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ arrayType, arrayType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayType)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ scalarType, arrayType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayType)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ arrayType, scalarType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayType)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ scalarType, scalarType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*scalarType)); + + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ arrayType, arrayTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayTypeOpt)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ scalarType, arrayTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayTypeOpt)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ arrayType, scalarTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayTypeOpt)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ scalarType, scalarTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*scalarTypeOpt)); + + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ arrayTypeOpt, arrayType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayTypeOpt)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ scalarTypeOpt, arrayType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayTypeOpt)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ arrayTypeOpt, scalarType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayTypeOpt)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ scalarTypeOpt, scalarType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*scalarTypeOpt)); + + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ arrayTypeOpt, arrayTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayTypeOpt)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ scalarTypeOpt, arrayTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayTypeOpt)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ arrayTypeOpt, scalarTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayTypeOpt)); + UNIT_ASSERT(FindArrowFunction("add", TVector<TType*>{ scalarTypeOpt, scalarTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*scalarTypeOpt)); + } + + Y_UNIT_TEST(IsNull) { + TScopedAlloc alloc; + TTypeEnvironment env(alloc); + + auto bool64Type = TDataType::Create(NUdf::GetDataTypeInfo(NUdf::EDataSlot::Bool).TypeId, env); + auto bool64TypeOpt = TOptionalType::Create(bool64Type, env); + + auto scalarType = TBlockType::Create(bool64Type, TBlockType::EShape::Scalar, env); + auto arrayType = TBlockType::Create(bool64Type, TBlockType::EShape::Many, env); + + auto scalarTypeOpt = TBlockType::Create(bool64TypeOpt, TBlockType::EShape::Scalar, env); + auto arrayTypeOpt = TBlockType::Create(bool64TypeOpt, TBlockType::EShape::Many, env); + + TType* outputType; + UNIT_ASSERT(!FindArrowFunction("is_null", {}, outputType, env)); + UNIT_ASSERT(!FindArrowFunction("is_null", TVector<TType*>{ scalarType, scalarType }, outputType, env)); + + UNIT_ASSERT(FindArrowFunction("is_null", TVector<TType*>{ scalarType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*scalarType)); + UNIT_ASSERT(FindArrowFunction("is_null", TVector<TType*>{ arrayType }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayType)); + + UNIT_ASSERT(FindArrowFunction("is_null", TVector<TType*>{ scalarTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*scalarType)); + UNIT_ASSERT(FindArrowFunction("is_null", TVector<TType*>{ arrayTypeOpt }, outputType, env)); + UNIT_ASSERT(outputType->Equals(*arrayType)); + } +} + +} |