aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-09-19 14:19:11 +0300
committervvvv <vvvv@ydb.tech>2022-09-19 14:19:11 +0300
commit114d80934ba1d0aba0ebb40b04313cff5fc0ef46 (patch)
tree44b073f5ac24852d77876a489822e4ad0ec89483
parentf6fbece613358589c77e5732ea9782ebf4342158 (diff)
downloadydb-114d80934ba1d0aba0ebb40b04313cff5fc0ef46.tar.gz
generic block function execution
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions.cpp8
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions.h2
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_memory_pool.h2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/CMakeLists.txt1
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp193
5 files changed, 106 insertions, 100 deletions
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
index 6d6856578a..2f4133405f 100644
--- a/ydb/library/yql/minikql/arrow/mkql_functions.cpp
+++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
@@ -2,10 +2,10 @@
#include <ydb/library/yql/minikql/mkql_node_builder.h>
#include <ydb/library/yql/minikql/mkql_node_cast.h>
-#include <contrib/libs/apache/arrow/cpp/src/arrow/datum.h>
-#include <contrib/libs/apache/arrow/cpp/src/arrow/visitor.h>
-#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.h>
-#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h>
+#include <arrow/datum.h>
+#include <arrow/visitor.h>
+#include <arrow/compute/registry.h>
+#include <arrow/compute/function.h>
namespace NKikimr::NMiniKQL {
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.h b/ydb/library/yql/minikql/arrow/mkql_functions.h
index 8b7a537bcf..30d84518e9 100644
--- a/ydb/library/yql/minikql/arrow/mkql_functions.h
+++ b/ydb/library/yql/minikql/arrow/mkql_functions.h
@@ -1,9 +1,11 @@
#pragma once
#include <ydb/library/yql/minikql/mkql_node.h>
+#include <arrow/datum.h>
namespace NKikimr::NMiniKQL {
bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env);
+bool ConvertInputArrowType(TType* type, bool& isOptional, arrow::ValueDescr& descr);
}
diff --git a/ydb/library/yql/minikql/arrow/mkql_memory_pool.h b/ydb/library/yql/minikql/arrow/mkql_memory_pool.h
index e48e0a7af2..8caacd24c2 100644
--- a/ydb/library/yql/minikql/arrow/mkql_memory_pool.h
+++ b/ydb/library/yql/minikql/arrow/mkql_memory_pool.h
@@ -2,7 +2,7 @@
#include <ydb/library/yql/minikql/mkql_alloc.h>
-#include <contrib/libs/apache/arrow/cpp/src/arrow/memory_pool.h>
+#include <arrow/memory_pool.h>
namespace NKikimr::NMiniKQL {
diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt
index 1a50d6ec5d..9a8ecbcd28 100644
--- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt
+++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt
@@ -18,6 +18,7 @@ target_link_libraries(yql-minikql-comp_nodes PUBLIC
libs-apache-arrow
ydb-library-binary_json
library-yql-minikql
+ yql-minikql-arrow
yql-minikql-invoke_builtins
library-yql-utils
yql-minikql-codegen
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
index 917678ada8..2994b92b36 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
@@ -5,6 +5,7 @@
#include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h>
#include <ydb/library/yql/minikql/mkql_node_builder.h>
#include <ydb/library/yql/minikql/mkql_node_cast.h>
+#include <ydb/library/yql/minikql/arrow/mkql_functions.h>
#include <arrow/array/builder_primitive.h>
#include <arrow/compute/exec_internal.h>
@@ -22,137 +23,139 @@ class TBlockFuncWrapper : public TMutableComputationNode<TBlockFuncWrapper> {
public:
TBlockFuncWrapper(TComputationMutables& mutables,
const TString& funcName,
- IComputationNode* leftArg,
- IComputationNode* rightArg,
- TType* leftArgType,
- TType* rightArgType,
- TType* outputType)
+ TVector<IComputationNode*>&& argsNodes,
+ TVector<TType*>&& argsTypes)
: TMutableComputationNode(mutables)
+ , StateIndex(mutables.CurValueIndex++)
, FuncName(funcName)
- , LeftArg(leftArg)
- , RightArg(rightArg)
- , LeftValueDesc(ToValueDescr(leftArgType))
- , RightValueDesc(ToValueDescr(rightArgType))
- , OutputValueDescr(ToValueDescr(outputType))
- , Kernel(ResolveKernel(FuncName, LeftValueDesc, RightValueDesc))
- , OutputTypeBitWidth(static_cast<const arrow::FixedWidthType&>(*OutputValueDescr.type).bit_width())
+ , ArgsNodes(std::move(argsNodes))
+ , ArgsTypes(std::move(argsTypes))
+ , ArgsValuesDescr(ToValueDescr(ArgsTypes))
, FunctionRegistry(*arrow::compute::GetFunctionRegistry())
+ , Function(ResolveFunction(FunctionRegistry, FuncName))
+ , Kernel(ResolveKernel(Function, ArgsValuesDescr))
{
- {
- auto execContext = arrow::compute::ExecContext();
- auto kernelContext = arrow::compute::KernelContext(&execContext);
- const auto kernelOutputValueDesc = ARROW_RESULT(Kernel.signature->out_type().Resolve(&kernelContext, {
- LeftValueDesc,
- RightValueDesc
- }));
- Y_VERIFY_DEBUG(kernelOutputValueDesc == OutputValueDescr);
- }
-
- Y_VERIFY_DEBUG(
- LeftValueDesc.shape == arrow::ValueDescr::ARRAY && RightValueDesc.shape == arrow::ValueDescr::ARRAY ||
- LeftValueDesc.shape == arrow::ValueDescr::SCALAR && RightValueDesc.shape == arrow::ValueDescr::ARRAY ||
- LeftValueDesc.shape == arrow::ValueDescr::ARRAY && RightValueDesc.shape == arrow::ValueDescr::SCALAR);
}
NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
- const auto leftValue = LeftArg->GetValue(ctx);
- const auto rightValue = RightArg->GetValue(ctx);
- auto& leftDatum = TArrowBlock::From(leftValue).GetDatum();
- auto& rightDatum = TArrowBlock::From(rightValue).GetDatum();
- Y_VERIFY_DEBUG(leftDatum.descr() == LeftValueDesc);
- Y_VERIFY_DEBUG(rightDatum.descr() == RightValueDesc);
- const auto leftKind = leftDatum.kind();
- const auto rightKind = rightDatum.kind();
- MKQL_ENSURE(leftKind != arrow::Datum::ARRAY || rightKind != arrow::Datum::ARRAY ||
- leftDatum.array()->length == rightDatum.array()->length,
- "block size mismatch: "
- << static_cast<ui64>(leftDatum.array()->length)
- << " != "
- << static_cast<ui64>(rightDatum.array()->length));
- const auto blockLength = leftKind == arrow::Datum::ARRAY
- ? leftDatum.array()->length
- : rightDatum.array()->length;
-
- auto execContext = arrow::compute::ExecContext(&ctx.ArrowMemoryPool, nullptr, &FunctionRegistry);
- auto kernelContext = arrow::compute::KernelContext(&execContext);
-
- arrow::Datum output = arrow::ArrayData::Make(
- OutputValueDescr.type,
- blockLength,
- std::vector<std::shared_ptr<arrow::Buffer>> {
- ARROW_RESULT(kernelContext.AllocateBitmap(blockLength)),
- ARROW_RESULT(kernelContext.Allocate(arrow::BitUtil::BytesForBits(OutputTypeBitWidth * blockLength)))
- });
- const auto inputBatch = arrow::compute::ExecBatch({leftDatum, rightDatum}, blockLength);
- ARROW_OK(arrow::compute::detail::PropagateNulls(&kernelContext, inputBatch, output.array().get()));
- ARROW_OK(Kernel.exec(&kernelContext, inputBatch, &output));
+ auto& state = GetState(ctx);
+
+ state.Values.clear();
+ for (ui32 i = 0; i < ArgsNodes.size(); ++i) {
+ state.Values.emplace_back(TArrowBlock::From(ArgsNodes[i]->GetValue(ctx)).GetDatum());
+ Y_VERIFY_DEBUG(ArgsValuesDescr[i] == state.Values.back().descr());
+ }
+
+ auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
+ ARROW_OK(state.Executor->Execute(state.Values, listener.get()));
+ auto output = state.Executor->WrapResults(state.Values, listener->values());
return ctx.HolderFactory.CreateArrowBlock(std::move(output));
}
private:
void RegisterDependencies() const final {
- this->DependsOn(LeftArg);
- this->DependsOn(RightArg);
+ for (const auto& arg : ArgsNodes) {
+ this->DependsOn(arg);
+ }
}
- static const arrow::compute::ScalarKernel& ResolveKernel(const TString& funcName,
- const arrow::ValueDescr& leftArg,
- const arrow::ValueDescr& rightArg)
- {
- auto* functionRegistry = arrow::compute::GetFunctionRegistry();
- Y_VERIFY_DEBUG(functionRegistry != nullptr);
- auto function = ARROW_RESULT(functionRegistry->GetFunction(funcName));
- Y_VERIFY_DEBUG(function != nullptr);
- Y_VERIFY_DEBUG(function->kind() == arrow::compute::Function::SCALAR);
+ static const arrow::compute::Function& ResolveFunction(const arrow::compute::FunctionRegistry& registry, const TString& funcName) {
+ auto function = ARROW_RESULT(registry.GetFunction(funcName));
+ MKQL_ENSURE(function != nullptr, "missing function");
+ MKQL_ENSURE(function->kind() == arrow::compute::Function::SCALAR, "expected SCALAR function");
+ return *function;
+ }
- const auto* kernel = ARROW_RESULT(function->DispatchExact({leftArg, rightArg}));
+ static const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function& function, const std::vector<arrow::ValueDescr>& args) {
+ const auto kernel = ARROW_RESULT(function.DispatchExact(args));
return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
}
- static std::shared_ptr<arrow::DataType> ConvertType(TType* type) {
+ static arrow::ValueDescr ToValueDescr(TType* type) {
bool isOptional;
- const auto dataType = UnpackOptionalData(type, isOptional);
- switch (*dataType->GetDataSlot()) {
- case NUdf::EDataSlot::Uint64:
- return arrow::uint64();
- default:
- Y_FAIL("unexpected type %s", TString(dataType->GetKindAsStr()).c_str());
+ arrow::ValueDescr ret;
+ MKQL_ENSURE(ConvertInputArrowType(type, isOptional, ret), "can't get arrow type");
+ return ret;
+ }
+
+ static std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) {
+ std::vector<arrow::ValueDescr> res;
+ res.reserve(types.size());
+ for (const auto& type : types) {
+ res.emplace_back(ToValueDescr(type));
}
+
+ return res;
}
- static arrow::ValueDescr ToValueDescr(TType* type) {
- auto* blockType = AS_TYPE(TBlockType, type);
- const auto shape = blockType->GetShape() == TBlockType::EShape::Scalar
- ? arrow::ValueDescr::SCALAR
- : arrow::ValueDescr::ARRAY;
- return arrow::ValueDescr(ConvertType(blockType->GetItemType()), shape);
+ struct TState : public TComputationValue<TState> {
+ using TComputationValue::TComputationValue;
+
+ TState(TMemoryUsageInfo* memInfo, const arrow::compute::Function& function, const arrow::compute::ScalarKernel& kernel,
+ arrow::compute::FunctionRegistry& registry, const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx)
+ : TComputationValue(memInfo)
+ , ExecContext(&ctx.ArrowMemoryPool, nullptr, &registry)
+ , KernelContext(&ExecContext)
+ , Executor(arrow::compute::detail::KernelExecutor::MakeScalar())
+ {
+ auto options = function.default_options();
+ if (kernel.init) {
+ State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ KernelContext.SetState(State.get());
+ }
+
+ ARROW_OK(Executor->Init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ Values.reserve(argsValuesDescr.size());
+ }
+
+ arrow::compute::ExecContext ExecContext;
+ arrow::compute::KernelContext KernelContext;
+ std::unique_ptr<arrow::compute::KernelState> State;
+ std::unique_ptr<arrow::compute::detail::KernelExecutor> Executor;
+
+ std::vector<arrow::Datum> Values;
+ };
+
+ TState& GetState(TComputationContext& ctx) const {
+ auto& result = ctx.MutableValues[StateIndex];
+ if (!result.HasValue()) {
+ result = ctx.HolderFactory.Create<TState>(Function, Kernel, FunctionRegistry, ArgsValuesDescr, ctx);
+ }
+
+ return *static_cast<TState*>(result.AsBoxed().Get());
}
private:
+ const ui32 StateIndex;
const TString FuncName;
- IComputationNode* LeftArg;
- IComputationNode* RightArg;
- const arrow::ValueDescr LeftValueDesc;
- const arrow::ValueDescr RightValueDesc;
- const arrow::ValueDescr OutputValueDescr;
- const arrow::compute::ScalarKernel& Kernel;
- const int OutputTypeBitWidth;
+ const TVector<IComputationNode*> ArgsNodes;
+ const TVector<TType*> ArgsTypes;
+
+ const std::vector<arrow::ValueDescr> ArgsValuesDescr;
arrow::compute::FunctionRegistry& FunctionRegistry;
+ const arrow::compute::Function& Function;
+ const arrow::compute::ScalarKernel& Kernel;
};
}
IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- const auto* callableType = callable.GetType();
+ MKQL_ENSURE(callable.GetInputsCount() >= 1, "Expected at least 1 arg");
const auto funcNameData = AS_VALUE(TDataLiteral, callable.GetInput(0));
const auto funcName = TString(funcNameData->AsValue().AsStringRef());
+ TVector<IComputationNode*> argsNodes;
+ TVector<TType*> argsTypes;
+ const auto callableType = callable.GetType();
+ for (ui32 i = 1; i < callable.GetInputsCount(); ++i) {
+ argsNodes.push_back(LocateNode(ctx.NodeLocator, callable, i));
+ argsTypes.push_back(callableType->GetArgumentType(i));
+ }
+
return new TBlockFuncWrapper(ctx.Mutables,
funcName,
- LocateNode(ctx.NodeLocator, callable, 1),
- LocateNode(ctx.NodeLocator, callable, 2),
- callableType->GetArgumentType(1),
- callableType->GetArgumentType(2),
- callableType->GetReturnType());
+ std::move(argsNodes),
+ std::move(argsTypes)
+ );
}
}