diff options
author | vvvv <vvvv@ydb.tech> | 2023-05-17 16:17:09 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2023-05-17 16:17:09 +0300 |
commit | 2c846092d67299183480076899a56ee76e9c8515 (patch) | |
tree | 7e72f3e8c5785f6c3c5038f1f7da8b8df6c1f6d2 | |
parent | 650b8705c872068ffd6b1cacd055276370dc79af (diff) | |
download | ydb-2c846092d67299183480076899a56ee76e9c8515.tar.gz |
Expose arrow kernels from IComputationNode
13 files changed, 477 insertions, 55 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp index d692a304642..6803a70e835 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp @@ -38,7 +38,7 @@ const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TSt class TBlockBitCastWrapper : public TBlockFuncNode { public: TBlockBitCastWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* argType, TType* to) - : TBlockFuncNode(mutables, { arg }, { argType }, ResolveKernel(argType, to), {}, &CastOptions) + : TBlockFuncNode(mutables, "BitCast", { arg }, { argType }, ResolveKernel(argType, to), {}, &CastOptions) , CastOptions(false) { } @@ -74,7 +74,7 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto } const TKernel& kernel = ResolveKernel(*ctx.FunctionRegistry.GetBuiltins(), funcName, argsTypes, callableType->GetReturnType()); - return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, kernel.GetArrowKernel(), {}, kernel.Family.FunctionOptions); + return new TBlockFuncNode(ctx.Mutables, funcName, std::move(argsNodes), argsTypes, kernel.GetArrowKernel(), {}, kernel.Family.FunctionOptions); } IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFactoryContext& ctx) { diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp index 37c4f88c61b..544f67e6694 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp @@ -14,45 +14,103 @@ namespace NMiniKQL { namespace { class TBlockIfScalarWrapper : public TMutableComputationNode<TBlockIfScalarWrapper> { +friend class TArrowNode; public: + class TArrowNode : public IArrowKernelComputationNode { + public: + TArrowNode(const TBlockIfScalarWrapper* parent) + : Parent_(parent) + , ArgsValuesDescr_(ToValueDescr(parent->ArgsTypes)) + , Kernel_(ConvertToInputTypes(parent->ArgsTypes), ConvertToOutputType(parent->ResultType), [parent](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) { + *res = parent->CalculateImpl(MakeDatumProvider(batch.values[0]), MakeDatumProvider(batch.values[1]), MakeDatumProvider(batch.values[2]), *ctx->memory_pool()); + return arrow::Status::OK(); + }) + { + Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; + Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; + } + + TStringBuf GetKernelName() const final { + return "If"; + } + + const arrow::compute::ScalarKernel& GetArrowKernel() const { + return Kernel_; + } + + const std::vector<arrow::ValueDescr>& GetArgsDesc() const { + return ArgsValuesDescr_; + } + + const IComputationNode* GetArgument(ui32 index) const { + switch (index) { + case 0: + return Parent_->Pred; + case 1: + return Parent_->Then; + case 2: + return Parent_->Else; + default: + throw yexception() << "Bad argument index"; + } + } + + private: + const TBlockIfScalarWrapper* Parent_; + const std::vector<arrow::ValueDescr> ArgsValuesDescr_; + arrow::compute::ScalarKernel Kernel_; + }; + TBlockIfScalarWrapper(TComputationMutables& mutables, IComputationNode* pred, IComputationNode* thenNode, IComputationNode* elseNode, TType* resultType, - bool thenIsScalar, bool elseIsScalar) + bool thenIsScalar, bool elseIsScalar, const TVector<TType*>& argsTypes) : TMutableComputationNode(mutables) , Pred(pred) , Then(thenNode) , Else(elseNode) - , Type(resultType) + , ResultType(resultType) , ThenIsScalar(thenIsScalar) , ElseIsScalar(elseIsScalar) + , ArgsTypes(argsTypes) { } - NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - auto predValue = Pred->GetValue(ctx); + std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final { + Y_UNUSED(ctx); + return std::make_unique<TArrowNode>(this); + } + + arrow::Datum CalculateImpl(const TDatumProvider& predProv, const TDatumProvider& thenProv, const TDatumProvider& elseProv, + arrow::MemoryPool& memoryPool) const { + auto predValue = predProv(); - const bool predScalarValue = GetPrimitiveScalarValue<bool>(*TArrowBlock::From(predValue).GetDatum().scalar()); - auto result = predScalarValue ? Then->GetValue(ctx) : Else->GetValue(ctx); + const bool predScalarValue = GetPrimitiveScalarValue<bool>(*predValue.scalar()); + auto result = predScalarValue ? thenProv() : elseProv(); if (ThenIsScalar == ElseIsScalar || (predScalarValue ? !ThenIsScalar : !ElseIsScalar)) { // can return result as-is - return result.Release(); + return result; } - auto other = predScalarValue ? Else->GetValue(ctx) : Then->GetValue(ctx); - const auto& otherDatum = TArrowBlock::From(other).GetDatum(); + auto otherDatum = predScalarValue ? elseProv() : thenProv(); MKQL_ENSURE(otherDatum.is_arraylike(), "Expecting array"); - std::shared_ptr<arrow::Scalar> resultScalar = TArrowBlock::From(result).GetDatum().scalar(); + std::shared_ptr<arrow::Scalar> resultScalar = result.scalar(); TVector<std::shared_ptr<arrow::ArrayData>> resultArrays; + auto itemType = AS_TYPE(TBlockType, ResultType)->GetItemType(); ForEachArrayData(otherDatum, [&](const std::shared_ptr<arrow::ArrayData>& otherData) { - auto chunk = MakeArrayFromScalar(*resultScalar, otherData->length, Type, ctx.ArrowMemoryPool); + auto chunk = MakeArrayFromScalar(*resultScalar, otherData->length, itemType, memoryPool); ForEachArrayData(chunk, [&](const auto& array) { resultArrays.push_back(array); }); }); - return ctx.HolderFactory.CreateArrowBlock(MakeArray(resultArrays)); + return MakeArray(resultArrays); + } + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + return ctx.HolderFactory.CreateArrowBlock(CalculateImpl(MakeDatumProvider(Pred, ctx), MakeDatumProvider(Then, ctx), MakeDatumProvider(Else, ctx), ctx.ArrowMemoryPool)); } + private: void RegisterDependencies() const final { DependsOn(Pred); @@ -63,9 +121,10 @@ private: IComputationNode* const Pred; IComputationNode* const Then; IComputationNode* const Else; - TType* const Type; + TType* const ResultType; const bool ThenIsScalar; const bool ElseIsScalar; + const TVector<TType*> ArgsTypes; }; template<bool ThenIsScalar, bool ElseIsScalar> @@ -171,15 +230,16 @@ IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactory bool predIsScalar = predType->GetShape() == TBlockType::EShape::Scalar; bool thenIsScalar = thenType->GetShape() == TBlockType::EShape::Scalar; bool elseIsScalar = elseType->GetShape() == TBlockType::EShape::Scalar; + TVector<TType*> argsTypes = { predType, thenType, elseType }; + if (predIsScalar) { - return new TBlockIfScalarWrapper(ctx.Mutables, predCompute, thenCompute, elseCompute, thenType->GetItemType(), - thenIsScalar, elseIsScalar); + return new TBlockIfScalarWrapper(ctx.Mutables, predCompute, thenCompute, elseCompute, thenType, + thenIsScalar, elseIsScalar, argsTypes); } TVector<IComputationNode*> argsNodes = { predCompute, thenCompute, elseCompute }; - TVector<TType*> argsTypes = { predType, thenType, elseType }; - + std::shared_ptr<arrow::compute::ScalarKernel> kernel; if (thenIsScalar && elseIsScalar) { kernel = MakeBlockIfKernel<true, true>(argsTypes, thenType); @@ -191,7 +251,7 @@ IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactory kernel = MakeBlockIfKernel<false, false>(argsTypes, thenType); } - return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel); + return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp index 7bca187600f..7410a8e62c5 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp @@ -11,20 +11,6 @@ namespace NKikimr::NMiniKQL { -namespace { - -std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) { - std::vector<arrow::ValueDescr> res; - res.reserve(types.size()); - for (const auto& type : types) { - res.emplace_back(ToValueDescr(type)); - } - - return res; -} - -} // namespace - arrow::Datum MakeArrayFromScalar(const arrow::Scalar& scalar, size_t len, TType* type, arrow::MemoryPool& pool) { MKQL_ENSURE(len > 0, "Invalid block size"); auto reader = MakeBlockReader(TTypeInfoHelper(), type); @@ -44,6 +30,16 @@ arrow::ValueDescr ToValueDescr(TType* type) { return ret; } +std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) { + std::vector<arrow::ValueDescr> res; + res.reserve(types.size()); + for (const auto& type : types) { + res.emplace_back(ToValueDescr(type)); + } + + return res; +} + std::vector<arrow::compute::InputType> ConvertToInputTypes(const TVector<TType*>& argTypes) { std::vector<arrow::compute::InputType> result; result.reserve(argTypes.size()); @@ -57,7 +53,7 @@ arrow::compute::OutputType ConvertToOutputType(TType* output) { return arrow::compute::OutputType(ToValueDescr(output)); } -TBlockFuncNode::TBlockFuncNode(TComputationMutables& mutables, TVector<IComputationNode*>&& argsNodes, +TBlockFuncNode::TBlockFuncNode(TComputationMutables& mutables, TStringBuf name, TVector<IComputationNode*>&& argsNodes, const TVector<TType*>& argsTypes, const arrow::compute::ScalarKernel& kernel, std::shared_ptr<arrow::compute::ScalarKernel> kernelHolder, const arrow::compute::FunctionOptions* functionOptions) @@ -69,6 +65,7 @@ TBlockFuncNode::TBlockFuncNode(TComputationMutables& mutables, TVector<IComputat , KernelHolder(std::move(kernelHolder)) , Options(functionOptions) , ScalarOutput(GetResultShape(argsTypes) == TBlockType::EShape::Scalar) + , Name(name.starts_with("Block") ? name.substr(5) : name) { } @@ -122,4 +119,29 @@ TBlockFuncNode::TState& TBlockFuncNode::GetState(TComputationContext& ctx) const return *static_cast<TState*>(result.AsBoxed().Get()); } +std::unique_ptr<IArrowKernelComputationNode> TBlockFuncNode::PrepareArrowKernelComputationNode(TComputationContext& ctx) const { + return std::make_unique<TArrowNode>(this); +} + +TBlockFuncNode::TArrowNode::TArrowNode(const TBlockFuncNode* parent) + : Parent_(parent) +{} + +TStringBuf TBlockFuncNode::TArrowNode::GetKernelName() const { + return Parent_->Name; +} + +const arrow::compute::ScalarKernel& TBlockFuncNode::TArrowNode::GetArrowKernel() const { + return Parent_->Kernel; +} + +const std::vector<arrow::ValueDescr>& TBlockFuncNode::TArrowNode::GetArgsDesc() const { + return Parent_->ArgsValuesDescr; +} + +const IComputationNode* TBlockFuncNode::TArrowNode::GetArgument(ui32 index) const { + MKQL_ENSURE(index < Parent_->ArgsNodes.size(), "Wrong index"); + return Parent_->ArgsNodes[index]; +} + } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h index 3e82ae1e617..1cdd661a61d 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h @@ -15,19 +15,33 @@ namespace NKikimr::NMiniKQL { arrow::Datum MakeArrayFromScalar(const arrow::Scalar& scalar, size_t len, TType* type, arrow::MemoryPool& pool); arrow::ValueDescr ToValueDescr(TType* type); +std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types); std::vector<arrow::compute::InputType> ConvertToInputTypes(const TVector<TType*>& argTypes); arrow::compute::OutputType ConvertToOutputType(TType* output); class TBlockFuncNode : public TMutableComputationNode<TBlockFuncNode> { +friend class TArrowNode; public: - TBlockFuncNode(TComputationMutables& mutables, TVector<IComputationNode*>&& argsNodes, + TBlockFuncNode(TComputationMutables& mutables, TStringBuf name, TVector<IComputationNode*>&& argsNodes, const TVector<TType*>& argsTypes, const arrow::compute::ScalarKernel& kernel, std::shared_ptr<arrow::compute::ScalarKernel> kernelHolder = {}, const arrow::compute::FunctionOptions* functionOptions = nullptr); NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const; private: + class TArrowNode : public IArrowKernelComputationNode { + public: + TArrowNode(const TBlockFuncNode* parent); + TStringBuf GetKernelName() const final; + const arrow::compute::ScalarKernel& GetArrowKernel() const final; + const std::vector<arrow::ValueDescr>& GetArgsDesc() const final; + const IComputationNode* GetArgument(ui32 index) const final; + + private: + const TBlockFuncNode* Parent_; + }; + struct TState : public TComputationValue<TState> { using TComputationValue::TComputationValue; @@ -51,6 +65,9 @@ private: void RegisterDependencies() const final; TState& GetState(TComputationContext& ctx) const; + + std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final; + private: const ui32 StateIndex; const TVector<IComputationNode*> ArgsNodes; @@ -59,6 +76,7 @@ private: const std::shared_ptr<arrow::compute::ScalarKernel> KernelHolder; const arrow::compute::FunctionOptions* const Options; const bool ScalarOutput; + const TString Name; }; template <typename TDerived> diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_just.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_just.cpp index 3dac133b81d..65b0273e0cb 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_just.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_just.cpp @@ -81,7 +81,7 @@ IComputationNode* WrapBlockJust(TCallable& callable, const TComputationNodeFacto kernel = MakeBlockJustKernel<true>(argsTypes, callable.GetType()->GetReturnType()); } - return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel); + return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp index 1a97f96a208..1a61b9c7f45 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp @@ -155,7 +155,7 @@ IComputationNode* WrapBlockAsTuple(TCallable& callable, const TComputationNodeFa } auto kernel = MakeBlockAsTupleKernel(argsTypes, callable.GetType()->GetReturnType()); - return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel); + return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel); } IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactoryContext& ctx) { @@ -175,7 +175,7 @@ IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactor TVector<IComputationNode*> argsNodes = { tuple }; TVector<TType*> argsTypes = { blockType }; auto kernel = MakeBlockNthKernel(argsTypes, callable.GetType()->GetReturnType(), index, isOptional, needExternalOptional); - return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel); + return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel); } } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp index 8adc1c6ca43..ee0b91e3604 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp @@ -345,6 +345,40 @@ private: class TAsScalarWrapper : public TMutableComputationNode<TAsScalarWrapper> { public: + class TArrowNode : public IArrowKernelComputationNode { + public: + TArrowNode(const arrow::Datum& datum) + : Kernel_({}, datum.scalar()->type, [datum](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) { + *res = datum; + return arrow::Status::OK(); + }) + { + Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; + Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; + } + + TStringBuf GetKernelName() const final { + return "AsScalar"; + } + + const arrow::compute::ScalarKernel& GetArrowKernel() const { + return Kernel_; + } + + const std::vector<arrow::ValueDescr>& GetArgsDesc() const { + return EmptyDesc_; + } + + const IComputationNode* GetArgument(ui32 index) const { + Y_UNUSED(index); + ythrow yexception() << "No input arguments"; + } + + private: + arrow::compute::ScalarKernel Kernel_; + const std::vector<arrow::ValueDescr> EmptyDesc_; + }; + TAsScalarWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* type) : TMutableComputationNode(mutables) , Arg_(arg) @@ -360,6 +394,12 @@ public: return ctx.HolderFactory.CreateArrowBlock(std::move(result)); } + std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final { + auto value = Arg_->GetValue(ctx); + arrow::Datum result = ConvertScalar(Type_, value, ctx); + return std::make_unique<TArrowNode>(result); + } + private: void RegisterDependencies() const final { DependsOn(Arg_); diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp index f2ddd998661..b409a523a68 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp @@ -2,8 +2,39 @@ #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> +#include <arrow/compute/exec_internal.h> +#include <arrow/array/builder_primitive.h> + namespace NKikimr { namespace NMiniKQL { + +namespace { + arrow::Datum ExecuteOneKernel(const IArrowKernelComputationNode* kernelNode, + const std::vector<arrow::Datum>& argDatums, arrow::compute::ExecContext& execContext) { + const auto& kernel = kernelNode->GetArrowKernel(); + arrow::compute::KernelContext kernelContext(&execContext); + std::unique_ptr<arrow::compute::KernelState> state; + auto executor = arrow::compute::detail::KernelExecutor::MakeScalar(); + ARROW_OK(executor->Init(&kernelContext, { &kernel, kernelNode->GetArgsDesc(), nullptr })); + auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>(); + ARROW_OK(executor->Execute(argDatums, listener.get())); + return executor->WrapResults(argDatums, listener->values()); + } + + void ExecuteAllKernels(std::vector<arrow::Datum>& datums, const TArrowKernelsTopology* topology, arrow::compute::ExecContext& execContext) { + for (ui32 i = 0; i < topology->Items.size(); ++i) { + std::vector<arrow::Datum> argDatums; + argDatums.reserve(topology->Items[i].Inputs.size()); + for (auto j : topology->Items[i].Inputs) { + argDatums.emplace_back(datums[j]); + } + + arrow::Datum output = ExecuteOneKernel(topology->Items[i].Node.get(), argDatums, execContext); + datums[i + topology->InputArgsCount] = output; + } + } +} + Y_UNIT_TEST_SUITE(TMiniKQLBlocksTest) { Y_UNIT_TEST(TestEmpty) { TSetup<false> setup; @@ -457,7 +488,121 @@ Y_UNIT_TEST(TestWideFromBlocks) { UNIT_ASSERT(iterator.Next(item)); UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 30); } +} + +Y_UNIT_TEST_SUITE(TMiniKQLDirectKernelTest) { +Y_UNIT_TEST(Simple) { + TSetup<false> setup; + auto& pb = *setup.PgmBuilder; + + const auto boolType = pb.NewDataType(NUdf::TDataType<bool>::Id); + const auto ui64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id); + const auto boolBlocksType = pb.NewBlockType(boolType, TBlockType::EShape::Many); + const auto ui64BlocksType = pb.NewBlockType(ui64Type, TBlockType::EShape::Many); + const auto arg1 = pb.Arg(boolBlocksType); + const auto arg2 = pb.Arg(ui64BlocksType); + const auto arg3 = pb.Arg(ui64BlocksType); + const auto ifNode = pb.BlockIf(arg1, arg2, arg3); + const auto eqNode = pb.BlockFunc("Equals", boolBlocksType, { ifNode, arg2 }); + + const auto graph = setup.BuildGraph(eqNode, {arg1.GetNode(), arg2.GetNode(), arg3.GetNode()}); + const auto topology = graph->GetKernelsTopology(); + UNIT_ASSERT(topology); + UNIT_ASSERT_VALUES_EQUAL(topology->InputArgsCount, 3); + UNIT_ASSERT_VALUES_EQUAL(topology->Items.size(), 2); + UNIT_ASSERT_VALUES_EQUAL(topology->Items[0].Node->GetKernelName(), "If"); + const std::vector<ui32> expectedInputs1{{0, 1, 2}}; + UNIT_ASSERT_VALUES_EQUAL(topology->Items[0].Inputs, expectedInputs1); + UNIT_ASSERT_VALUES_EQUAL(topology->Items[1].Node->GetKernelName(), "Equals"); + const std::vector<ui32> expectedInputs2{{3, 1}}; + UNIT_ASSERT_VALUES_EQUAL(topology->Items[1].Inputs, expectedInputs2); + + arrow::compute::ExecContext execContext; + const size_t blockSize = 100000; + std::vector<arrow::Datum> datums(topology->InputArgsCount + topology->Items.size()); + { + arrow::UInt8Builder builder1(execContext.memory_pool()); + arrow::UInt64Builder builder2(execContext.memory_pool()), builder3(execContext.memory_pool()); + ARROW_OK(builder1.Reserve(blockSize)); + ARROW_OK(builder2.Reserve(blockSize)); + ARROW_OK(builder3.Reserve(blockSize)); + for (size_t i = 0; i < blockSize; ++i) { + builder1.UnsafeAppend(i & 1); + builder2.UnsafeAppend(i); + builder3.UnsafeAppend(3 * i); + } + + std::shared_ptr<arrow::ArrayData> data1; + ARROW_OK(builder1.FinishInternal(&data1)); + std::shared_ptr<arrow::ArrayData> data2; + ARROW_OK(builder2.FinishInternal(&data2)); + std::shared_ptr<arrow::ArrayData> data3; + ARROW_OK(builder3.FinishInternal(&data3)); + datums[0] = data1; + datums[1] = data2; + datums[2] = data3; + } + + ExecuteAllKernels(datums, topology, execContext); + + auto res = datums.back().array()->GetValues<ui8>(1); + for (size_t i = 0; i < blockSize; ++i) { + auto expected = (((i & 1) ? i : i * 3) == i) ? 1 : 0; + UNIT_ASSERT_VALUES_EQUAL(res[i], expected); + } +} +Y_UNIT_TEST(WithScalars) { + TSetup<false> setup; + auto& pb = *setup.PgmBuilder; + + const auto ui64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id); + const auto ui64BlocksType = pb.NewBlockType(ui64Type, TBlockType::EShape::Many); + const auto scalar = pb.AsScalar(pb.NewDataLiteral(false)); + const auto arg1 = pb.Arg(ui64BlocksType); + const auto arg2 = pb.Arg(ui64BlocksType); + const auto ifNode = pb.BlockIf(scalar, arg1, arg2); + + const auto graph = setup.BuildGraph(ifNode, {arg1.GetNode(), arg2.GetNode()}); + const auto topology = graph->GetKernelsTopology(); + UNIT_ASSERT(topology); + UNIT_ASSERT_VALUES_EQUAL(topology->InputArgsCount, 2); + UNIT_ASSERT_VALUES_EQUAL(topology->Items.size(), 2); + UNIT_ASSERT_VALUES_EQUAL(topology->Items[0].Node->GetKernelName(), "AsScalar"); + const std::vector<ui32> expectedInputs1; + UNIT_ASSERT_VALUES_EQUAL(topology->Items[0].Inputs, expectedInputs1); + UNIT_ASSERT_VALUES_EQUAL(topology->Items[1].Node->GetKernelName(), "If"); + const std::vector<ui32> expectedInputs2{{2, 0, 1}}; + UNIT_ASSERT_VALUES_EQUAL(topology->Items[1].Inputs, expectedInputs2); + + arrow::compute::ExecContext execContext; + const size_t blockSize = 100000; + std::vector<arrow::Datum> datums(topology->InputArgsCount + topology->Items.size()); + { + arrow::UInt64Builder builder1(execContext.memory_pool()), builder2(execContext.memory_pool()); + ARROW_OK(builder1.Reserve(blockSize)); + ARROW_OK(builder2.Reserve(blockSize)); + for (size_t i = 0; i < blockSize; ++i) { + builder1.UnsafeAppend(i); + builder2.UnsafeAppend(3 * i); + } + + std::shared_ptr<arrow::ArrayData> data1; + ARROW_OK(builder1.FinishInternal(&data1)); + std::shared_ptr<arrow::ArrayData> data2; + ARROW_OK(builder2.FinishInternal(&data2)); + datums[0] = data1; + datums[1] = data2; + } + + ExecuteAllKernels(datums, topology, execContext); + + auto res = datums.back().array()->GetValues<ui64>(1); + for (size_t i = 0; i < blockSize; ++i) { + auto expected = 3 * i; + UNIT_ASSERT_VALUES_EQUAL(res[i], expected); + } +} } diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h b/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h index bf20ea2c110..0c3eb242122 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h @@ -91,7 +91,8 @@ struct TSetup { Reset(); Explorer.Walk(pgm.GetNode(), *Env); TComputationPatternOpts opts(Alloc.Ref(), *Env, GetTestFactory(NodeFactory), - FunctionRegistry.Get(), NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception, UseLLVM ? "" : "OFF", EGraphPerProcess::Multi); + FunctionRegistry.Get(), NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception, + UseLLVM ? "" : "OFF", EGraphPerProcess::Multi, nullptr, nullptr, nullptr); Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts); auto graph = Pattern->Clone(opts.ToComputationOptions(*RandomProvider, *TimeProvider)); Terminator.Reset(new TBindTerminator(graph->GetTerminator())); diff --git a/ydb/library/yql/minikql/computation/mkql_computation_node.cpp b/ydb/library/yql/minikql/computation/mkql_computation_node.cpp index 6921e83fedd..273eaace7a4 100644 --- a/ydb/library/yql/minikql/computation/mkql_computation_node.cpp +++ b/ydb/library/yql/minikql/computation/mkql_computation_node.cpp @@ -26,6 +26,23 @@ namespace NKikimr { namespace NMiniKQL { +std::unique_ptr<IArrowKernelComputationNode> IComputationNode::PrepareArrowKernelComputationNode(TComputationContext& ctx) const { + Y_UNUSED(ctx); + return {}; +} + +TDatumProvider MakeDatumProvider(const arrow::Datum& datum) { + return [datum]() { + return datum; + }; +} + +TDatumProvider MakeDatumProvider(const IComputationNode* node, TComputationContext& ctx) { + return [node, &ctx]() { + return TArrowBlock::From(node->GetValue(ctx)).GetDatum(); + }; +} + TComputationContext::TComputationContext(const THolderFactory& holderFactory, const NUdf::IValueBuilder* builder, TComputationOptsFull& opts, diff --git a/ydb/library/yql/minikql/computation/mkql_computation_node.h b/ydb/library/yql/minikql/computation/mkql_computation_node.h index 640ef0f5813..ab898d7ee61 100644 --- a/ydb/library/yql/minikql/computation/mkql_computation_node.h +++ b/ydb/library/yql/minikql/computation/mkql_computation_node.h @@ -143,6 +143,8 @@ private: #endif }; +class IArrowKernelComputationNode; + class IComputationNode { public: typedef TIntrusivePtr<IComputationNode> TPtr; @@ -177,6 +179,8 @@ public: virtual void Ref() = 0; virtual void UnRef() = 0; virtual ui32 RefCount() const = 0; + + virtual std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const; }; class IComputationExternalNode : public IComputationNode { @@ -209,6 +213,31 @@ public: virtual void InvalidateValue(TComputationContext& compCtx) const = 0; }; +using TDatumProvider = std::function<arrow::Datum()>; + +TDatumProvider MakeDatumProvider(const arrow::Datum& datum); +TDatumProvider MakeDatumProvider(const IComputationNode* node, TComputationContext& ctx); + +class IArrowKernelComputationNode { +public: + virtual ~IArrowKernelComputationNode() = default; + + virtual TStringBuf GetKernelName() const = 0; + virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0; + virtual const std::vector<arrow::ValueDescr>& GetArgsDesc() const = 0; + virtual const IComputationNode* GetArgument(ui32 index) const = 0; +}; + +struct TArrowKernelsTopologyItem { + std::vector<ui32> Inputs; + std::unique_ptr<IArrowKernelComputationNode> Node; +}; + +struct TArrowKernelsTopology { + ui32 InputArgsCount = 0; + std::vector<TArrowKernelsTopologyItem> Items; +}; + using TComputationNodePtrVector = std::vector<IComputationNode*, TMKQLAllocator<IComputationNode*>>; using TComputationWideFlowNodePtrVector = std::vector<IComputationWideFlowNode*, TMKQLAllocator<IComputationWideFlowNode*>>; using TComputationExternalNodePtrVector = std::vector<IComputationExternalNode*, TMKQLAllocator<IComputationExternalNode*>>; @@ -223,6 +252,7 @@ public: virtual NUdf::TUnboxedValue GetValue() = 0; virtual TComputationContext& GetContext() = 0; virtual IComputationExternalNode* GetEntryPoint(size_t index, bool require) = 0; + virtual const TArrowKernelsTopology* GetKernelsTopology() = 0; virtual const TComputationNodePtrDeque& GetNodes() const = 0; virtual void Invalidate() = 0; virtual TMemoryUsageInfo& GetMemInfo() const = 0; @@ -310,7 +340,8 @@ struct TComputationPatternOpts { const TString& optLLVM, EGraphPerProcess graphPerProcess, IStatsRegistry* stats = nullptr, - NUdf::ICountersProvider* countersProvider = nullptr) + NUdf::ICountersProvider* countersProvider = nullptr, + const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) : AllocState(allocState) , Env(env) , Factory(factory) @@ -321,12 +352,14 @@ struct TComputationPatternOpts { , GraphPerProcess(graphPerProcess) , Stats(stats) , CountersProvider(countersProvider) + , SecureParamsProvider(secureParamsProvider) {} void SetOptions(TComputationNodeFactory factory, const IFunctionRegistry* functionRegistry, NUdf::EValidateMode validateMode, NUdf::EValidatePolicy validatePolicy, const TString& optLLVM, EGraphPerProcess graphPerProcess, IStatsRegistry* stats = nullptr, - NUdf::ICountersProvider* counters = nullptr, const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) { + NUdf::ICountersProvider* counters = nullptr, + const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) { Factory = factory; FunctionRegistry = functionRegistry; ValidateMode = validateMode; diff --git a/ydb/library/yql/minikql/computation/mkql_computation_node_graph.cpp b/ydb/library/yql/minikql/computation/mkql_computation_node_graph.cpp index c8de496a533..2eb542fc21f 100644 --- a/ydb/library/yql/minikql/computation/mkql_computation_node_graph.cpp +++ b/ydb/library/yql/minikql/computation/mkql_computation_node_graph.cpp @@ -186,9 +186,9 @@ public: } IComputationExternalNode* GetEntryPoint(size_t index, bool require) { - MKQL_ENSURE(index < Runtime2Computation.size() && (!require || Runtime2Computation[index]), + MKQL_ENSURE(index < Runtime2ComputationEntryPoints.size() && (!require || Runtime2ComputationEntryPoints[index]), "Pattern nodes can not get computation node by index: " << index << ", require: " << require); - return Runtime2Computation[index]; + return Runtime2ComputationEntryPoints[index]; } IComputationNode* GetRoot() { @@ -199,6 +199,10 @@ public: return SuitableForCache; } + size_t GetEntryPointsCount() const { + return Runtime2ComputationEntryPoints.size(); + } + private: friend class TComputationGraphBuildingVisitor; friend class TComputationGraph; @@ -210,7 +214,7 @@ private: TComputationMutables Mutables; TComputationNodePtrDeque ComputationNodesList; IComputationNode* RootNode = nullptr; - TComputationExternalNodePtrVector Runtime2Computation; + TComputationExternalNodePtrVector Runtime2ComputationEntryPoints; TComputationNodeOnNodeMap ElementsCache; bool SuitableForCache = true; }; @@ -527,8 +531,8 @@ public: PatternNodes->RootNode = rootNode; } - void PreserveEntryPoints(TComputationExternalNodePtrVector&& runtime2Computation) { - PatternNodes->Runtime2Computation = std::move(runtime2Computation); + void PreserveEntryPoints(TComputationExternalNodePtrVector&& runtime2ComputationEntryPoints) { + PatternNodes->Runtime2ComputationEntryPoints = std::move(runtime2ComputationEntryPoints); } private: @@ -617,6 +621,72 @@ public: return PatternNodes->GetEntryPoint(index, require); } + const TArrowKernelsTopology* GetKernelsTopology() override { + Prepare(); + if (!KernelsTopology.has_value()) { + CalculateKernelTopology(*Ctx); + } + + return &KernelsTopology.value(); + } + + void CalculateKernelTopology(TComputationContext& ctx) { + KernelsTopology.emplace(); + KernelsTopology->InputArgsCount = PatternNodes->GetEntryPointsCount(); + + std::stack<const IComputationNode*> stack; + struct TNodeState { + bool Visited; + ui32 Index; + }; + + std::unordered_map<const IComputationNode*, TNodeState> deps; + for (ui32 i = 0; i < KernelsTopology->InputArgsCount; ++i) { + auto entryPoint = PatternNodes->GetEntryPoint(i, false); + if (!entryPoint) { + continue; + } + + deps.emplace(entryPoint, TNodeState{ true, i}); + } + + stack.push(PatternNodes->GetRoot()); + while (!stack.empty()) { + auto node = stack.top(); + auto [iter, inserted] = deps.emplace(node, TNodeState{ false, 0 }); + auto extNode = dynamic_cast<const IComputationExternalNode*>(node); + if (extNode) { + MKQL_ENSURE(!inserted, "Unexpected external node"); + stack.pop(); + continue; + } + + auto kernelNode = node->PrepareArrowKernelComputationNode(ctx); + MKQL_ENSURE(kernelNode, "No kernel for node: " << node->DebugString()); + auto argsCount = kernelNode->GetArgsDesc().size(); + + if (!iter->second.Visited) { + for (ui32 j = 0; j < argsCount; ++j) { + stack.push(kernelNode->GetArgument(j)); + } + iter->second.Visited = true; + } else { + iter->second.Index = KernelsTopology->InputArgsCount + KernelsTopology->Items.size(); + KernelsTopology->Items.emplace_back(); + auto& i = KernelsTopology->Items.back(); + i.Inputs.reserve(argsCount); + for (ui32 j = 0; j < argsCount; ++j) { + auto it = deps.find(kernelNode->GetArgument(j)); + MKQL_ENSURE(it != deps.end(), "Missing argument"); + i.Inputs.emplace_back(it->second.Index); + } + + i.Node = std::move(kernelNode); + stack.pop(); + } + } + } + void Invalidate() override { std::fill_n(Ctx->MutableValues.get(), PatternNodes->GetMutables().CurValueIndex, NUdf::TUnboxedValue(NUdf::TUnboxedValuePod::Invalid())); } @@ -690,6 +760,7 @@ private: THolder<TComputationContext> Ctx; TComputationOptsFull CompOpts; bool IsPrepared = false; + std::optional<TArrowKernelsTopology> KernelsTopology; }; } // namespace @@ -861,16 +932,31 @@ TIntrusivePtr<TComputationPatternImpl> MakeComputationPatternImpl(TExploringNode const auto rootNode = builder->GetComputationNode(root.GetNode()); - TComputationExternalNodePtrVector runtime2Computation; - runtime2Computation.resize(entryPoints.size(), nullptr); + TComputationExternalNodePtrVector runtime2ComputationEntryPoints; + runtime2ComputationEntryPoints.resize(entryPoints.size(), nullptr); + std::unordered_map<TNode*, std::vector<ui32>> entryPointIndex; + for (ui32 i = 0; i < entryPoints.size(); ++i) { + entryPointIndex[entryPoints[i]].emplace_back(i); + } + for (const auto& node : explorer.GetNodes()) { - for (auto iter = std::find(entryPoints.cbegin(), entryPoints.cend(), node); entryPoints.cend() != iter; iter = std::find(iter + 1, entryPoints.cend(), node)) { - runtime2Computation[iter - entryPoints.begin()] = dynamic_cast<IComputationExternalNode*>(builder->GetComputationNode(node)); + auto it = entryPointIndex.find(node); + if (it == entryPointIndex.cend()) { + continue; + } + + auto compNode = dynamic_cast<IComputationExternalNode*>(builder->GetComputationNode(node)); + for (auto index : it->second) { + runtime2ComputationEntryPoints[index] = compNode; } + } + + for (const auto& node : explorer.GetNodes()) { node->SetCookie(0); } + builder->PreserveRoot(rootNode); - builder->PreserveEntryPoints(std::move(runtime2Computation)); + builder->PreserveEntryPoints(std::move(runtime2ComputationEntryPoints)); return MakeIntrusive<TComputationPatternImpl>(std::move(builder), opts); } diff --git a/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp b/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp index bc694c4b311..0a5144f41fb 100644 --- a/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp +++ b/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp @@ -1939,7 +1939,7 @@ TComputationNodeFactory GetPgFactory() { auto execFunc = FindExec(id); YQL_ENSURE(execFunc); auto kernel = MakePgKernel(argTypes, returnType, execFunc, id); - return new TBlockFuncNode(ctx.Mutables, std::move(argNodes), argTypes, *kernel, kernel); + return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argNodes), argTypes, *kernel, kernel); } if (name == "PgCast") { @@ -1995,7 +1995,7 @@ TComputationNodeFactory GetPgFactory() { auto returnType = callable.GetType()->GetReturnType(); ui32 sourceId = AS_TYPE(TPgType, AS_TYPE(TBlockType, inputType)->GetItemType())->GetTypeId(); auto kernel = MakeFromPgKernel(inputType, returnType, sourceId); - return new TBlockFuncNode(ctx.Mutables, { arg }, { inputType }, *kernel, kernel); + return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), { arg }, { inputType }, *kernel, kernel); } if (name == "ToPg") { @@ -2030,7 +2030,7 @@ TComputationNodeFactory GetPgFactory() { auto returnType = callable.GetType()->GetReturnType(); auto targetId = AS_TYPE(TPgType, AS_TYPE(TBlockType, returnType)->GetItemType())->GetTypeId(); auto kernel = MakeToPgKernel(inputType, returnType, targetId); - return new TBlockFuncNode(ctx.Mutables, { arg }, { inputType }, *kernel, kernel); + return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), { arg }, { inputType }, *kernel, kernel); } if (name == "PgArray") { |