aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-09-16 14:42:49 +0300
committervvvv <vvvv@ydb.tech>2022-09-16 14:42:49 +0300
commit1bfaafcb1adce88e7a6161dba40b7769a04debb3 (patch)
treeb85ec717b2f90b91f2fced052e70a1eb7e925a9b
parentf6c4135c52c0ab15e44686ea64c139b868c81f46 (diff)
downloadydb-1bfaafcb1adce88e7a6161dba40b7769a04debb3.tar.gz
discovery of Arrow function output type over given input types
-rw-r--r--ydb/library/yql/minikql/arrow/CMakeLists.txt1
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions.cpp146
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions.h9
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions_ut.cpp95
4 files changed, 251 insertions, 0 deletions
diff --git a/ydb/library/yql/minikql/arrow/CMakeLists.txt b/ydb/library/yql/minikql/arrow/CMakeLists.txt
index b60d2eef53..eb3204951d 100644
--- a/ydb/library/yql/minikql/arrow/CMakeLists.txt
+++ b/ydb/library/yql/minikql/arrow/CMakeLists.txt
@@ -18,5 +18,6 @@ target_link_libraries(yql-minikql-arrow PUBLIC
library-yql-minikql
)
target_sources(yql-minikql-arrow PRIVATE
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/arrow/mkql_functions.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/arrow/mkql_memory_pool.cpp
)
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
new file mode 100644
index 0000000000..6d6856578a
--- /dev/null
+++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
@@ -0,0 +1,146 @@
+#include "mkql_functions.h"
+#include <ydb/library/yql/minikql/mkql_node_builder.h>
+#include <ydb/library/yql/minikql/mkql_node_cast.h>
+
+#include <contrib/libs/apache/arrow/cpp/src/arrow/datum.h>
+#include <contrib/libs/apache/arrow/cpp/src/arrow/visitor.h>
+#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.h>
+#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h>
+
+namespace NKikimr::NMiniKQL {
+
+// Builds an Arrow value descriptor from a MiniKQL block type.
+//
+// `type` is cast to TBlockType with AS_TYPE. descr.shape is always written
+// (SCALAR for a Scalar block, ARRAY otherwise); `isOptional` is filled in by
+// UnpackOptional while stripping the item type. descr.type is written only
+// when the function returns true, which happens for Bool and Uint64 data
+// items; every other item type yields false.
+bool ConvertInputArrowType(TType* type, bool& isOptional, arrow::ValueDescr& descr) {
+    const auto block = AS_TYPE(TBlockType, type);
+    if (block->GetShape() == TBlockType::EShape::Scalar) {
+        descr.shape = arrow::ValueDescr::SCALAR;
+    } else {
+        descr.shape = arrow::ValueDescr::ARRAY;
+    }
+
+    const auto item = UnpackOptional(block->GetItemType(), isOptional);
+    if (!item->IsData()) {
+        return false;
+    }
+
+    const auto dataSlot = AS_TYPE(TDataType, item)->GetDataSlot();
+    if (!dataSlot) {
+        return false;
+    }
+
+    // Only the slots listed here have an Arrow counterpart so far.
+    if (*dataSlot == NUdf::EDataSlot::Bool) {
+        descr.type = arrow::boolean();
+        return true;
+    }
+    if (*dataSlot == NUdf::EDataSlot::Uint64) {
+        descr.type = arrow::uint64();
+        return true;
+    }
+    return false;
+}
+
+// Arrow type visitor that maps a resolved Arrow data type back to a MiniKQL
+// data type allocated in the supplied type environment.
+//
+// Only BooleanType and UInt64Type are handled here; for any other Arrow type
+// the visit falls through to the arrow::TypeVisitor base implementation and
+// GetType() keeps returning nullptr.
+class TOutputTypeVisitor final : public arrow::TypeVisitor
+{
+public:
+    explicit TOutputTypeVisitor(TTypeEnvironment& env)
+        : Env_(env)
+    {}
+
+    // `override` makes the compiler verify that these really replace the
+    // base-class virtuals instead of silently declaring unrelated overloads.
+    arrow::Status Visit(const arrow::BooleanType&) override {
+        SetDataType(NUdf::EDataSlot::Bool);
+        return arrow::Status::OK();
+    }
+
+    arrow::Status Visit(const arrow::UInt64Type&) override {
+        SetDataType(NUdf::EDataSlot::Uint64);
+        return arrow::Status::OK();
+    }
+
+    // Returns the MiniKQL type produced by the last successful visit,
+    // or nullptr if no supported type has been visited yet.
+    TType* GetType() const {
+        return Type_;
+    }
+
+private:
+    // Creates the TDataType for `slot` in Env_ and remembers it.
+    void SetDataType(NUdf::EDataSlot slot) {
+        Type_ = TDataType::Create(NUdf::GetDataTypeInfo(slot).TypeId, Env_);
+    }
+
+private:
+    TTypeEnvironment& Env_;
+    TType* Type_ = nullptr;
+};
+
+// Resolves the kernel output type for the given input descriptors and turns
+// it into a MiniKQL block type stored in `outputType`.
+//
+// When `optional` is true the item type is wrapped into TOptionalType before
+// the block type is created. Returns false if the Arrow output type cannot be
+// resolved, if the visitor rejects the resolved data type, or if the resolved
+// shape is neither SCALAR nor ARRAY; `outputType` is left untouched then.
+bool ConvertOutputArrowType(const arrow::compute::OutputType& outType, const std::vector<arrow::ValueDescr>& values,
+    bool optional, TType*& outputType, TTypeEnvironment& env) {
+    arrow::compute::ExecContext execContext;
+    arrow::compute::KernelContext kernelContext(&execContext);
+
+    const auto resolved = outType.Resolve(&kernelContext, values);
+    if (!resolved.ok()) {
+        return false;
+    }
+    const arrow::ValueDescr& resultDescr = *resolved;
+
+    TOutputTypeVisitor visitor(env);
+    if (!resultDescr.type->Accept(&visitor).ok()) {
+        return false;
+    }
+
+    // The optional wrapper is created before the shape is inspected.
+    TType* const itemType = optional
+        ? TOptionalType::Create(visitor.GetType(), env)
+        : visitor.GetType();
+
+    if (resultDescr.shape == arrow::ValueDescr::SCALAR) {
+        outputType = TBlockType::Create(itemType, TBlockType::EShape::Scalar, env);
+        return true;
+    }
+    if (resultDescr.shape == arrow::ValueDescr::ARRAY) {
+        outputType = TBlockType::Create(itemType, TBlockType::EShape::Many, env);
+        return true;
+    }
+    return false;
+}
+
+// Looks up the Arrow compute function `name` and derives the MiniKQL block
+// type it would produce for `inputTypes`.
+//
+// Returns true and stores the result in `outputType` on success. Returns
+// false when the function is not registered, is not a SCALAR function, any
+// input type cannot be converted by ConvertInputArrowType, no kernel matches
+// the inputs exactly, or the output type cannot be converted back.
+// The result item type is optional when at least one input is optional and
+// the selected kernel does not declare OUTPUT_NOT_NULL null handling.
+bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env) {
+    auto registry = arrow::compute::GetFunctionRegistry();
+    auto resFunc = registry->GetFunction(TString(name));
+    if (!resFunc.ok()) {
+        return false;
+    }
+
+    const auto& func = *resFunc;
+    if (func->kind() != arrow::compute::Function::SCALAR) {
+        return false;
+    }
+
+    std::vector<arrow::ValueDescr> values;
+    // One allocation up front; descriptors are then filled in place instead
+    // of being built on the stack and copied (each holds a shared_ptr).
+    values.reserve(inputTypes.size());
+    bool hasOptionals = false;
+    for (const auto& type : inputTypes) {
+        bool isOptional = false;
+        if (!ConvertInputArrowType(type, isOptional, values.emplace_back())) {
+            return false;
+        }
+
+        hasOptionals = hasOptionals || isOptional;
+    }
+
+    auto resKernel = func->DispatchExact(values);
+    if (!resKernel.ok()) {
+        return false;
+    }
+
+    // The function kind was checked to be SCALAR above, so the dispatched
+    // kernel is treated as a ScalarKernel.
+    const auto* kernel = static_cast<const arrow::compute::ScalarKernel*>(*resKernel);
+    const bool notNull = (kernel->null_handling == arrow::compute::NullHandling::OUTPUT_NOT_NULL);
+    const auto& outType = kernel->signature->out_type();
+    return ConvertOutputArrowType(outType, values, hasOptionals && !notNull, outputType, env);
+}
+
+}
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.h b/ydb/library/yql/minikql/arrow/mkql_functions.h
new file mode 100644
index 0000000000..8b7a537bcf
--- /dev/null
+++ b/ydb/library/yql/minikql/arrow/mkql_functions.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <ydb/library/yql/minikql/mkql_node.h>
+
+namespace NKikimr::NMiniKQL {
+
+// Resolves the output block type of the Arrow compute function `name` when it
+// is applied to arguments of the block types listed in `inputTypes`.
+// On success returns true and stores the resulting type, allocated in `env`,
+// in `outputType`. Returns false if the function is unknown, is not a scalar
+// function, or has no kernel exactly matching the given input types.
+bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env);
+
+}
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions_ut.cpp b/ydb/library/yql/minikql/arrow/mkql_functions_ut.cpp
new file mode 100644
index 0000000000..9f27c39c91
--- /dev/null
+++ b/ydb/library/yql/minikql/arrow/mkql_functions_ut.cpp
@@ -0,0 +1,95 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "mkql_functions.h"
+
+namespace NKikimr::NMiniKQL {
+
+Y_UNIT_TEST_SUITE(TMiniKQLArrowFunctions) {
+    // Output type discovery for the binary "add" function over Uint64 blocks.
+    Y_UNIT_TEST(Add) {
+        TScopedAlloc alloc;
+        TTypeEnvironment env(alloc);
+
+        auto itemType = TDataType::Create(NUdf::GetDataTypeInfo(NUdf::EDataSlot::Uint64).TypeId, env);
+        auto itemTypeOpt = TOptionalType::Create(itemType, env);
+
+        auto scalarBlock = TBlockType::Create(itemType, TBlockType::EShape::Scalar, env);
+        auto arrayBlock = TBlockType::Create(itemType, TBlockType::EShape::Many, env);
+
+        auto scalarBlockOpt = TBlockType::Create(itemTypeOpt, TBlockType::EShape::Scalar, env);
+        auto arrayBlockOpt = TBlockType::Create(itemTypeOpt, TBlockType::EShape::Many, env);
+
+        TType* outputType;
+        // True when the lookup succeeds for the given function name and arguments.
+        const auto found = [&](TStringBuf name, TVector<TType*> args) {
+            return FindArrowFunction(name, args, outputType, env);
+        };
+        // True when "add" is found for `args` and resolves to `expected`.
+        const auto addYields = [&](TVector<TType*> args, const TType* expected) {
+            return FindArrowFunction("add", args, outputType, env) && outputType->Equals(*expected);
+        };
+
+        // Unknown name and wrong arity are rejected.
+        UNIT_ASSERT(!found("_add_", {}));
+        UNIT_ASSERT(!found("add", {}));
+        UNIT_ASSERT(!found("add", { scalarBlock }));
+        UNIT_ASSERT(!found("add", { arrayBlock }));
+        UNIT_ASSERT(!found("add", { scalarBlockOpt }));
+        UNIT_ASSERT(!found("add", { arrayBlockOpt }));
+
+        // No optional arguments.
+        UNIT_ASSERT(addYields({ arrayBlock, arrayBlock }, arrayBlock));
+        UNIT_ASSERT(addYields({ scalarBlock, arrayBlock }, arrayBlock));
+        UNIT_ASSERT(addYields({ arrayBlock, scalarBlock }, arrayBlock));
+        UNIT_ASSERT(addYields({ scalarBlock, scalarBlock }, scalarBlock));
+
+        // Optional second argument.
+        UNIT_ASSERT(addYields({ arrayBlock, arrayBlockOpt }, arrayBlockOpt));
+        UNIT_ASSERT(addYields({ scalarBlock, arrayBlockOpt }, arrayBlockOpt));
+        UNIT_ASSERT(addYields({ arrayBlock, scalarBlockOpt }, arrayBlockOpt));
+        UNIT_ASSERT(addYields({ scalarBlock, scalarBlockOpt }, scalarBlockOpt));
+
+        // Optional first argument.
+        UNIT_ASSERT(addYields({ arrayBlockOpt, arrayBlock }, arrayBlockOpt));
+        UNIT_ASSERT(addYields({ scalarBlockOpt, arrayBlock }, arrayBlockOpt));
+        UNIT_ASSERT(addYields({ arrayBlockOpt, scalarBlock }, arrayBlockOpt));
+        UNIT_ASSERT(addYields({ scalarBlockOpt, scalarBlock }, scalarBlockOpt));
+
+        // Both arguments optional.
+        UNIT_ASSERT(addYields({ arrayBlockOpt, arrayBlockOpt }, arrayBlockOpt));
+        UNIT_ASSERT(addYields({ scalarBlockOpt, arrayBlockOpt }, arrayBlockOpt));
+        UNIT_ASSERT(addYields({ arrayBlockOpt, scalarBlockOpt }, arrayBlockOpt));
+        UNIT_ASSERT(addYields({ scalarBlockOpt, scalarBlockOpt }, scalarBlockOpt));
+    }
+
+    // Output type discovery for the unary "is_null" function over Bool blocks:
+    // the expected result is never optional, even for optional input.
+    Y_UNIT_TEST(IsNull) {
+        TScopedAlloc alloc;
+        TTypeEnvironment env(alloc);
+
+        auto itemType = TDataType::Create(NUdf::GetDataTypeInfo(NUdf::EDataSlot::Bool).TypeId, env);
+        auto itemTypeOpt = TOptionalType::Create(itemType, env);
+
+        auto scalarBlock = TBlockType::Create(itemType, TBlockType::EShape::Scalar, env);
+        auto arrayBlock = TBlockType::Create(itemType, TBlockType::EShape::Many, env);
+
+        auto scalarBlockOpt = TBlockType::Create(itemTypeOpt, TBlockType::EShape::Scalar, env);
+        auto arrayBlockOpt = TBlockType::Create(itemTypeOpt, TBlockType::EShape::Many, env);
+
+        TType* outputType;
+        // True when the lookup of "is_null" succeeds for `args`.
+        const auto found = [&](TVector<TType*> args) {
+            return FindArrowFunction("is_null", args, outputType, env);
+        };
+        // True when "is_null" is found for `args` and resolves to `expected`.
+        const auto isNullYields = [&](TVector<TType*> args, const TType* expected) {
+            return FindArrowFunction("is_null", args, outputType, env) && outputType->Equals(*expected);
+        };
+
+        UNIT_ASSERT(!found({}));
+        UNIT_ASSERT(!found({ scalarBlock, scalarBlock }));
+
+        UNIT_ASSERT(isNullYields({ scalarBlock }, scalarBlock));
+        UNIT_ASSERT(isNullYields({ arrayBlock }, arrayBlock));
+
+        UNIT_ASSERT(isNullYields({ scalarBlockOpt }, scalarBlock));
+        UNIT_ASSERT(isNullYields({ arrayBlockOpt }, arrayBlock));
+    }
+}
+
+}