diff options
author | vvvv <vvvv@ydb.tech> | 2022-09-19 14:19:11 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-09-19 14:19:11 +0300 |
commit | 114d80934ba1d0aba0ebb40b04313cff5fc0ef46 (patch) | |
tree | 44b073f5ac24852d77876a489822e4ad0ec89483 | |
parent | f6fbece613358589c77e5732ea9782ebf4342158 (diff) | |
download | ydb-114d80934ba1d0aba0ebb40b04313cff5fc0ef46.tar.gz |
generic block function execution
5 files changed, 106 insertions, 100 deletions
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp index 6d6856578a..2f4133405f 100644 --- a/ydb/library/yql/minikql/arrow/mkql_functions.cpp +++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp @@ -2,10 +2,10 @@ #include <ydb/library/yql/minikql/mkql_node_builder.h> #include <ydb/library/yql/minikql/mkql_node_cast.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/datum.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/visitor.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/registry.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/function.h> +#include <arrow/datum.h> +#include <arrow/visitor.h> +#include <arrow/compute/registry.h> +#include <arrow/compute/function.h> namespace NKikimr::NMiniKQL { diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.h b/ydb/library/yql/minikql/arrow/mkql_functions.h index 8b7a537bcf..30d84518e9 100644 --- a/ydb/library/yql/minikql/arrow/mkql_functions.h +++ b/ydb/library/yql/minikql/arrow/mkql_functions.h @@ -1,9 +1,11 @@ #pragma once #include <ydb/library/yql/minikql/mkql_node.h> +#include <arrow/datum.h> namespace NKikimr::NMiniKQL { bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env); +bool ConvertInputArrowType(TType* type, bool& isOptional, arrow::ValueDescr& descr); } diff --git a/ydb/library/yql/minikql/arrow/mkql_memory_pool.h b/ydb/library/yql/minikql/arrow/mkql_memory_pool.h index e48e0a7af2..8caacd24c2 100644 --- a/ydb/library/yql/minikql/arrow/mkql_memory_pool.h +++ b/ydb/library/yql/minikql/arrow/mkql_memory_pool.h @@ -2,7 +2,7 @@ #include <ydb/library/yql/minikql/mkql_alloc.h> -#include <contrib/libs/apache/arrow/cpp/src/arrow/memory_pool.h> +#include <arrow/memory_pool.h> namespace NKikimr::NMiniKQL { diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt index 1a50d6ec5d..9a8ecbcd28 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt @@ -18,6 +18,7 @@ target_link_libraries(yql-minikql-comp_nodes PUBLIC libs-apache-arrow ydb-library-binary_json library-yql-minikql + yql-minikql-arrow yql-minikql-invoke_builtins library-yql-utils yql-minikql-codegen diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp index 917678ada8..2994b92b36 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp @@ -5,6 +5,7 @@ #include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h> #include <ydb/library/yql/minikql/mkql_node_builder.h> #include <ydb/library/yql/minikql/mkql_node_cast.h> +#include <ydb/library/yql/minikql/arrow/mkql_functions.h> #include <arrow/array/builder_primitive.h> #include <arrow/compute/exec_internal.h> @@ -22,137 +23,139 @@ class TBlockFuncWrapper : public TMutableComputationNode<TBlockFuncWrapper> { public: TBlockFuncWrapper(TComputationMutables& mutables, const TString& funcName, - IComputationNode* leftArg, - IComputationNode* rightArg, - TType* leftArgType, - TType* rightArgType, - TType* outputType) + TVector<IComputationNode*>&& argsNodes, + TVector<TType*>&& argsTypes) : TMutableComputationNode(mutables) + , StateIndex(mutables.CurValueIndex++) , FuncName(funcName) - , LeftArg(leftArg) - , RightArg(rightArg) - , LeftValueDesc(ToValueDescr(leftArgType)) - , RightValueDesc(ToValueDescr(rightArgType)) - , OutputValueDescr(ToValueDescr(outputType)) - , Kernel(ResolveKernel(FuncName, LeftValueDesc, RightValueDesc)) - , OutputTypeBitWidth(static_cast<const arrow::FixedWidthType&>(*OutputValueDescr.type).bit_width()) + , ArgsNodes(std::move(argsNodes)) + , ArgsTypes(std::move(argsTypes)) + , ArgsValuesDescr(ToValueDescr(ArgsTypes)) , FunctionRegistry(*arrow::compute::GetFunctionRegistry()) + , Function(ResolveFunction(FunctionRegistry, FuncName)) + , Kernel(ResolveKernel(Function, ArgsValuesDescr)) { - { - auto execContext = arrow::compute::ExecContext(); - auto kernelContext = arrow::compute::KernelContext(&execContext); - const auto kernelOutputValueDesc = ARROW_RESULT(Kernel.signature->out_type().Resolve(&kernelContext, { - LeftValueDesc, - RightValueDesc - })); - Y_VERIFY_DEBUG(kernelOutputValueDesc == OutputValueDescr); - } - - Y_VERIFY_DEBUG( - LeftValueDesc.shape == arrow::ValueDescr::ARRAY && RightValueDesc.shape == arrow::ValueDescr::ARRAY || - LeftValueDesc.shape == arrow::ValueDescr::SCALAR && RightValueDesc.shape == arrow::ValueDescr::ARRAY || - LeftValueDesc.shape == arrow::ValueDescr::ARRAY && RightValueDesc.shape == arrow::ValueDescr::SCALAR); } NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - const auto leftValue = LeftArg->GetValue(ctx); - const auto rightValue = RightArg->GetValue(ctx); - auto& leftDatum = TArrowBlock::From(leftValue).GetDatum(); - auto& rightDatum = TArrowBlock::From(rightValue).GetDatum(); - Y_VERIFY_DEBUG(leftDatum.descr() == LeftValueDesc); - Y_VERIFY_DEBUG(rightDatum.descr() == RightValueDesc); - const auto leftKind = leftDatum.kind(); - const auto rightKind = rightDatum.kind(); - MKQL_ENSURE(leftKind != arrow::Datum::ARRAY || rightKind != arrow::Datum::ARRAY || - leftDatum.array()->length == rightDatum.array()->length, - "block size mismatch: " - << static_cast<ui64>(leftDatum.array()->length) - << " != " - << static_cast<ui64>(rightDatum.array()->length)); - const auto blockLength = leftKind == arrow::Datum::ARRAY - ? leftDatum.array()->length - : rightDatum.array()->length; - - auto execContext = arrow::compute::ExecContext(&ctx.ArrowMemoryPool, nullptr, &FunctionRegistry); - auto kernelContext = arrow::compute::KernelContext(&execContext); - - arrow::Datum output = arrow::ArrayData::Make( - OutputValueDescr.type, - blockLength, - std::vector<std::shared_ptr<arrow::Buffer>> { - ARROW_RESULT(kernelContext.AllocateBitmap(blockLength)), - ARROW_RESULT(kernelContext.Allocate(arrow::BitUtil::BytesForBits(OutputTypeBitWidth * blockLength))) - }); - const auto inputBatch = arrow::compute::ExecBatch({leftDatum, rightDatum}, blockLength); - ARROW_OK(arrow::compute::detail::PropagateNulls(&kernelContext, inputBatch, output.array().get())); - ARROW_OK(Kernel.exec(&kernelContext, inputBatch, &output)); + auto& state = GetState(ctx); + + state.Values.clear(); + for (ui32 i = 0; i < ArgsNodes.size(); ++i) { + state.Values.emplace_back(TArrowBlock::From(ArgsNodes[i]->GetValue(ctx)).GetDatum()); + Y_VERIFY_DEBUG(ArgsValuesDescr[i] == state.Values.back().descr()); + } + + auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
+ ARROW_OK(state.Executor->Execute(state.Values, listener.get()));
+ auto output = state.Executor->WrapResults(state.Values, listener->values()); return ctx.HolderFactory.CreateArrowBlock(std::move(output)); } private: void RegisterDependencies() const final { - this->DependsOn(LeftArg); - this->DependsOn(RightArg); + for (const auto& arg : ArgsNodes) { + this->DependsOn(arg); + } } - static const arrow::compute::ScalarKernel& ResolveKernel(const TString& funcName, - const arrow::ValueDescr& leftArg, - const arrow::ValueDescr& rightArg) - { - auto* functionRegistry = arrow::compute::GetFunctionRegistry(); - Y_VERIFY_DEBUG(functionRegistry != nullptr); - auto function = ARROW_RESULT(functionRegistry->GetFunction(funcName)); - Y_VERIFY_DEBUG(function != nullptr); - Y_VERIFY_DEBUG(function->kind() == arrow::compute::Function::SCALAR); + static const arrow::compute::Function& ResolveFunction(const arrow::compute::FunctionRegistry& registry, const TString& funcName) { + auto function = ARROW_RESULT(registry.GetFunction(funcName)); + MKQL_ENSURE(function != nullptr, "missing function"); + MKQL_ENSURE(function->kind() == arrow::compute::Function::SCALAR, "expected SCALAR function"); + return *function; + } - const auto* kernel = ARROW_RESULT(function->DispatchExact({leftArg, rightArg})); + static const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function& function, const std::vector<arrow::ValueDescr>& args) { + const auto kernel = ARROW_RESULT(function.DispatchExact(args)); return *static_cast<const arrow::compute::ScalarKernel*>(kernel); } - static std::shared_ptr<arrow::DataType> ConvertType(TType* type) { + static arrow::ValueDescr ToValueDescr(TType* type) { bool isOptional; - const auto dataType = UnpackOptionalData(type, isOptional); - switch (*dataType->GetDataSlot()) { - case NUdf::EDataSlot::Uint64: - return arrow::uint64(); - default: - Y_FAIL("unexpected type %s", TString(dataType->GetKindAsStr()).c_str()); + arrow::ValueDescr ret; + MKQL_ENSURE(ConvertInputArrowType(type, isOptional, ret), "can't get arrow type"); + return ret; + } + + static std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) { + std::vector<arrow::ValueDescr> res; + res.reserve(types.size()); + for (const auto& type : types) { + res.emplace_back(ToValueDescr(type)); } + + return res; } - static arrow::ValueDescr ToValueDescr(TType* type) { - auto* blockType = AS_TYPE(TBlockType, type); - const auto shape = blockType->GetShape() == TBlockType::EShape::Scalar - ? arrow::ValueDescr::SCALAR - : arrow::ValueDescr::ARRAY; - return arrow::ValueDescr(ConvertType(blockType->GetItemType()), shape); + struct TState : public TComputationValue<TState> { + using TComputationValue::TComputationValue; + + TState(TMemoryUsageInfo* memInfo, const arrow::compute::Function& function, const arrow::compute::ScalarKernel& kernel, + arrow::compute::FunctionRegistry& registry, const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx) + : TComputationValue(memInfo) + , ExecContext(&ctx.ArrowMemoryPool, nullptr, ®istry) + , KernelContext(&ExecContext) + , Executor(arrow::compute::detail::KernelExecutor::MakeScalar()) + { + auto options = function.default_options(); + if (kernel.init) {
+ State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ KernelContext.SetState(State.get());
+ }
+
+ ARROW_OK(Executor->Init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ Values.reserve(argsValuesDescr.size()); + } + + arrow::compute::ExecContext ExecContext; + arrow::compute::KernelContext KernelContext; + std::unique_ptr<arrow::compute::KernelState> State; + std::unique_ptr<arrow::compute::detail::KernelExecutor> Executor; + + std::vector<arrow::Datum> Values; + }; + + TState& GetState(TComputationContext& ctx) const { + auto& result = ctx.MutableValues[StateIndex]; + if (!result.HasValue()) { + result = ctx.HolderFactory.Create<TState>(Function, Kernel, FunctionRegistry, ArgsValuesDescr, ctx); + } + + return *static_cast<TState*>(result.AsBoxed().Get()); } private: + const ui32 StateIndex; const TString FuncName; - IComputationNode* LeftArg; - IComputationNode* RightArg; - const arrow::ValueDescr LeftValueDesc; - const arrow::ValueDescr RightValueDesc; - const arrow::ValueDescr OutputValueDescr; - const arrow::compute::ScalarKernel& Kernel; - const int OutputTypeBitWidth; + const TVector<IComputationNode*> ArgsNodes; + const TVector<TType*> ArgsTypes; + + const std::vector<arrow::ValueDescr> ArgsValuesDescr; arrow::compute::FunctionRegistry& FunctionRegistry; + const arrow::compute::Function& Function; + const arrow::compute::ScalarKernel& Kernel; }; } IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx) { - const auto* callableType = callable.GetType(); + MKQL_ENSURE(callable.GetInputsCount() >= 1, "Expected at least 1 arg"); const auto funcNameData = AS_VALUE(TDataLiteral, callable.GetInput(0)); const auto funcName = TString(funcNameData->AsValue().AsStringRef()); + TVector<IComputationNode*> argsNodes; + TVector<TType*> argsTypes; + const auto callableType = callable.GetType(); + for (ui32 i = 1; i < callable.GetInputsCount(); ++i) { + argsNodes.push_back(LocateNode(ctx.NodeLocator, callable, i)); + argsTypes.push_back(callableType->GetArgumentType(i)); + } + return new TBlockFuncWrapper(ctx.Mutables, funcName, - LocateNode(ctx.NodeLocator, callable, 1), - LocateNode(ctx.NodeLocator, callable, 2), - callableType->GetArgumentType(1), - callableType->GetArgumentType(2), - callableType->GetReturnType()); + std::move(argsNodes), + std::move(argsTypes) + ); } } |