aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2023-05-17 16:17:09 +0300
committervvvv <vvvv@ydb.tech>2023-05-17 16:17:09 +0300
commit2c846092d67299183480076899a56ee76e9c8515 (patch)
tree7e72f3e8c5785f6c3c5038f1f7da8b8df6c1f6d2
parent650b8705c872068ffd6b1cacd055276370dc79af (diff)
downloadydb-2c846092d67299183480076899a56ee76e9c8515.tar.gz
Expose arrow kernels from IComputationNode
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp4
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp96
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp52
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h20
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_just.cpp2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp4
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp40
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp145
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h3
-rw-r--r--ydb/library/yql/minikql/computation/mkql_computation_node.cpp17
-rw-r--r--ydb/library/yql/minikql/computation/mkql_computation_node.h37
-rw-r--r--ydb/library/yql/minikql/computation/mkql_computation_node_graph.cpp106
-rw-r--r--ydb/library/yql/parser/pg_wrapper/comp_factory.cpp6
13 files changed, 477 insertions, 55 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
index d692a304642..6803a70e835 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
@@ -38,7 +38,7 @@ const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TSt
class TBlockBitCastWrapper : public TBlockFuncNode {
public:
TBlockBitCastWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* argType, TType* to)
- : TBlockFuncNode(mutables, { arg }, { argType }, ResolveKernel(argType, to), {}, &CastOptions)
+ : TBlockFuncNode(mutables, "BitCast", { arg }, { argType }, ResolveKernel(argType, to), {}, &CastOptions)
, CastOptions(false)
{
}
@@ -74,7 +74,7 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto
}
const TKernel& kernel = ResolveKernel(*ctx.FunctionRegistry.GetBuiltins(), funcName, argsTypes, callableType->GetReturnType());
- return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, kernel.GetArrowKernel(), {}, kernel.Family.FunctionOptions);
+ return new TBlockFuncNode(ctx.Mutables, funcName, std::move(argsNodes), argsTypes, kernel.GetArrowKernel(), {}, kernel.Family.FunctionOptions);
}
IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp
index 37c4f88c61b..544f67e6694 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp
@@ -14,45 +14,103 @@ namespace NMiniKQL {
namespace {
class TBlockIfScalarWrapper : public TMutableComputationNode<TBlockIfScalarWrapper> {
+friend class TArrowNode;
public:
+ class TArrowNode : public IArrowKernelComputationNode {
+ public:
+ TArrowNode(const TBlockIfScalarWrapper* parent)
+ : Parent_(parent)
+ , ArgsValuesDescr_(ToValueDescr(parent->ArgsTypes))
+ , Kernel_(ConvertToInputTypes(parent->ArgsTypes), ConvertToOutputType(parent->ResultType), [parent](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
+ *res = parent->CalculateImpl(MakeDatumProvider(batch.values[0]), MakeDatumProvider(batch.values[1]), MakeDatumProvider(batch.values[2]), *ctx->memory_pool());
+ return arrow::Status::OK();
+ })
+ {
+ Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
+ Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
+ }
+
+ TStringBuf GetKernelName() const final {
+ return "If";
+ }
+
+ const arrow::compute::ScalarKernel& GetArrowKernel() const {
+ return Kernel_;
+ }
+
+ const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
+ return ArgsValuesDescr_;
+ }
+
+ const IComputationNode* GetArgument(ui32 index) const {
+ switch (index) {
+ case 0:
+ return Parent_->Pred;
+ case 1:
+ return Parent_->Then;
+ case 2:
+ return Parent_->Else;
+ default:
+ throw yexception() << "Bad argument index";
+ }
+ }
+
+ private:
+ const TBlockIfScalarWrapper* Parent_;
+ const std::vector<arrow::ValueDescr> ArgsValuesDescr_;
+ arrow::compute::ScalarKernel Kernel_;
+ };
+
TBlockIfScalarWrapper(TComputationMutables& mutables, IComputationNode* pred, IComputationNode* thenNode, IComputationNode* elseNode, TType* resultType,
- bool thenIsScalar, bool elseIsScalar)
+ bool thenIsScalar, bool elseIsScalar, const TVector<TType*>& argsTypes)
: TMutableComputationNode(mutables)
, Pred(pred)
, Then(thenNode)
, Else(elseNode)
- , Type(resultType)
+ , ResultType(resultType)
, ThenIsScalar(thenIsScalar)
, ElseIsScalar(elseIsScalar)
+ , ArgsTypes(argsTypes)
{
}
- NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
- auto predValue = Pred->GetValue(ctx);
+ std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
+ Y_UNUSED(ctx);
+ return std::make_unique<TArrowNode>(this);
+ }
+
+ arrow::Datum CalculateImpl(const TDatumProvider& predProv, const TDatumProvider& thenProv, const TDatumProvider& elseProv,
+ arrow::MemoryPool& memoryPool) const {
+ auto predValue = predProv();
- const bool predScalarValue = GetPrimitiveScalarValue<bool>(*TArrowBlock::From(predValue).GetDatum().scalar());
- auto result = predScalarValue ? Then->GetValue(ctx) : Else->GetValue(ctx);
+ const bool predScalarValue = GetPrimitiveScalarValue<bool>(*predValue.scalar());
+ auto result = predScalarValue ? thenProv() : elseProv();
if (ThenIsScalar == ElseIsScalar || (predScalarValue ? !ThenIsScalar : !ElseIsScalar)) {
// can return result as-is
- return result.Release();
+ return result;
}
- auto other = predScalarValue ? Else->GetValue(ctx) : Then->GetValue(ctx);
- const auto& otherDatum = TArrowBlock::From(other).GetDatum();
+ auto otherDatum = predScalarValue ? elseProv() : thenProv();
MKQL_ENSURE(otherDatum.is_arraylike(), "Expecting array");
- std::shared_ptr<arrow::Scalar> resultScalar = TArrowBlock::From(result).GetDatum().scalar();
+ std::shared_ptr<arrow::Scalar> resultScalar = result.scalar();
TVector<std::shared_ptr<arrow::ArrayData>> resultArrays;
+ auto itemType = AS_TYPE(TBlockType, ResultType)->GetItemType();
ForEachArrayData(otherDatum, [&](const std::shared_ptr<arrow::ArrayData>& otherData) {
- auto chunk = MakeArrayFromScalar(*resultScalar, otherData->length, Type, ctx.ArrowMemoryPool);
+ auto chunk = MakeArrayFromScalar(*resultScalar, otherData->length, itemType, memoryPool);
ForEachArrayData(chunk, [&](const auto& array) {
resultArrays.push_back(array);
});
});
- return ctx.HolderFactory.CreateArrowBlock(MakeArray(resultArrays));
+ return MakeArray(resultArrays);
+ }
+
+ NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
+ return ctx.HolderFactory.CreateArrowBlock(CalculateImpl(MakeDatumProvider(Pred, ctx), MakeDatumProvider(Then, ctx), MakeDatumProvider(Else, ctx), ctx.ArrowMemoryPool));
}
+
private:
void RegisterDependencies() const final {
DependsOn(Pred);
@@ -63,9 +121,10 @@ private:
IComputationNode* const Pred;
IComputationNode* const Then;
IComputationNode* const Else;
- TType* const Type;
+ TType* const ResultType;
const bool ThenIsScalar;
const bool ElseIsScalar;
+ const TVector<TType*> ArgsTypes;
};
template<bool ThenIsScalar, bool ElseIsScalar>
@@ -171,15 +230,16 @@ IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactory
bool predIsScalar = predType->GetShape() == TBlockType::EShape::Scalar;
bool thenIsScalar = thenType->GetShape() == TBlockType::EShape::Scalar;
bool elseIsScalar = elseType->GetShape() == TBlockType::EShape::Scalar;
+ TVector<TType*> argsTypes = { predType, thenType, elseType };
+
if (predIsScalar) {
- return new TBlockIfScalarWrapper(ctx.Mutables, predCompute, thenCompute, elseCompute, thenType->GetItemType(),
- thenIsScalar, elseIsScalar);
+ return new TBlockIfScalarWrapper(ctx.Mutables, predCompute, thenCompute, elseCompute, thenType,
+ thenIsScalar, elseIsScalar, argsTypes);
}
TVector<IComputationNode*> argsNodes = { predCompute, thenCompute, elseCompute };
- TVector<TType*> argsTypes = { predType, thenType, elseType };
-
+
std::shared_ptr<arrow::compute::ScalarKernel> kernel;
if (thenIsScalar && elseIsScalar) {
kernel = MakeBlockIfKernel<true, true>(argsTypes, thenType);
@@ -191,7 +251,7 @@ IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactory
kernel = MakeBlockIfKernel<false, false>(argsTypes, thenType);
}
- return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel);
+ return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp
index 7bca187600f..7410a8e62c5 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp
@@ -11,20 +11,6 @@
namespace NKikimr::NMiniKQL {
-namespace {
-
-std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) {
- std::vector<arrow::ValueDescr> res;
- res.reserve(types.size());
- for (const auto& type : types) {
- res.emplace_back(ToValueDescr(type));
- }
-
- return res;
-}
-
-} // namespace
-
arrow::Datum MakeArrayFromScalar(const arrow::Scalar& scalar, size_t len, TType* type, arrow::MemoryPool& pool) {
MKQL_ENSURE(len > 0, "Invalid block size");
auto reader = MakeBlockReader(TTypeInfoHelper(), type);
@@ -44,6 +30,16 @@ arrow::ValueDescr ToValueDescr(TType* type) {
return ret;
}
+std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) {
+ std::vector<arrow::ValueDescr> res;
+ res.reserve(types.size());
+ for (const auto& type : types) {
+ res.emplace_back(ToValueDescr(type));
+ }
+
+ return res;
+}
+
std::vector<arrow::compute::InputType> ConvertToInputTypes(const TVector<TType*>& argTypes) {
std::vector<arrow::compute::InputType> result;
result.reserve(argTypes.size());
@@ -57,7 +53,7 @@ arrow::compute::OutputType ConvertToOutputType(TType* output) {
return arrow::compute::OutputType(ToValueDescr(output));
}
-TBlockFuncNode::TBlockFuncNode(TComputationMutables& mutables, TVector<IComputationNode*>&& argsNodes,
+TBlockFuncNode::TBlockFuncNode(TComputationMutables& mutables, TStringBuf name, TVector<IComputationNode*>&& argsNodes,
const TVector<TType*>& argsTypes, const arrow::compute::ScalarKernel& kernel,
std::shared_ptr<arrow::compute::ScalarKernel> kernelHolder,
const arrow::compute::FunctionOptions* functionOptions)
@@ -69,6 +65,7 @@ TBlockFuncNode::TBlockFuncNode(TComputationMutables& mutables, TVector<IComputat
, KernelHolder(std::move(kernelHolder))
, Options(functionOptions)
, ScalarOutput(GetResultShape(argsTypes) == TBlockType::EShape::Scalar)
+ , Name(name.starts_with("Block") ? name.substr(5) : name)
{
}
@@ -122,4 +119,29 @@ TBlockFuncNode::TState& TBlockFuncNode::GetState(TComputationContext& ctx) const
return *static_cast<TState*>(result.AsBoxed().Get());
}
+std::unique_ptr<IArrowKernelComputationNode> TBlockFuncNode::PrepareArrowKernelComputationNode(TComputationContext& ctx) const {
+ return std::make_unique<TArrowNode>(this);
+}
+
+TBlockFuncNode::TArrowNode::TArrowNode(const TBlockFuncNode* parent)
+ : Parent_(parent)
+{}
+
+TStringBuf TBlockFuncNode::TArrowNode::GetKernelName() const {
+ return Parent_->Name;
+}
+
+const arrow::compute::ScalarKernel& TBlockFuncNode::TArrowNode::GetArrowKernel() const {
+ return Parent_->Kernel;
+}
+
+const std::vector<arrow::ValueDescr>& TBlockFuncNode::TArrowNode::GetArgsDesc() const {
+ return Parent_->ArgsValuesDescr;
+}
+
+const IComputationNode* TBlockFuncNode::TArrowNode::GetArgument(ui32 index) const {
+ MKQL_ENSURE(index < Parent_->ArgsNodes.size(), "Wrong index");
+ return Parent_->ArgsNodes[index];
+}
+
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h
index 3e82ae1e617..1cdd661a61d 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h
@@ -15,19 +15,33 @@ namespace NKikimr::NMiniKQL {
arrow::Datum MakeArrayFromScalar(const arrow::Scalar& scalar, size_t len, TType* type, arrow::MemoryPool& pool);
arrow::ValueDescr ToValueDescr(TType* type);
+std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types);
std::vector<arrow::compute::InputType> ConvertToInputTypes(const TVector<TType*>& argTypes);
arrow::compute::OutputType ConvertToOutputType(TType* output);
class TBlockFuncNode : public TMutableComputationNode<TBlockFuncNode> {
+friend class TArrowNode;
public:
- TBlockFuncNode(TComputationMutables& mutables, TVector<IComputationNode*>&& argsNodes,
+ TBlockFuncNode(TComputationMutables& mutables, TStringBuf name, TVector<IComputationNode*>&& argsNodes,
const TVector<TType*>& argsTypes, const arrow::compute::ScalarKernel& kernel,
std::shared_ptr<arrow::compute::ScalarKernel> kernelHolder = {},
const arrow::compute::FunctionOptions* functionOptions = nullptr);
NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const;
private:
+ class TArrowNode : public IArrowKernelComputationNode {
+ public:
+ TArrowNode(const TBlockFuncNode* parent);
+ TStringBuf GetKernelName() const final;
+ const arrow::compute::ScalarKernel& GetArrowKernel() const final;
+ const std::vector<arrow::ValueDescr>& GetArgsDesc() const final;
+ const IComputationNode* GetArgument(ui32 index) const final;
+
+ private:
+ const TBlockFuncNode* Parent_;
+ };
+
struct TState : public TComputationValue<TState> {
using TComputationValue::TComputationValue;
@@ -51,6 +65,9 @@ private:
void RegisterDependencies() const final;
TState& GetState(TComputationContext& ctx) const;
+
+ std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final;
+
private:
const ui32 StateIndex;
const TVector<IComputationNode*> ArgsNodes;
@@ -59,6 +76,7 @@ private:
const std::shared_ptr<arrow::compute::ScalarKernel> KernelHolder;
const arrow::compute::FunctionOptions* const Options;
const bool ScalarOutput;
+ const TString Name;
};
template <typename TDerived>
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_just.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_just.cpp
index 3dac133b81d..65b0273e0cb 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_just.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_just.cpp
@@ -81,7 +81,7 @@ IComputationNode* WrapBlockJust(TCallable& callable, const TComputationNodeFacto
kernel = MakeBlockJustKernel<true>(argsTypes, callable.GetType()->GetReturnType());
}
- return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel);
+ return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp
index 1a97f96a208..1a61b9c7f45 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp
@@ -155,7 +155,7 @@ IComputationNode* WrapBlockAsTuple(TCallable& callable, const TComputationNodeFa
}
auto kernel = MakeBlockAsTupleKernel(argsTypes, callable.GetType()->GetReturnType());
- return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel);
+ return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
}
IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
@@ -175,7 +175,7 @@ IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactor
TVector<IComputationNode*> argsNodes = { tuple };
TVector<TType*> argsTypes = { blockType };
auto kernel = MakeBlockNthKernel(argsTypes, callable.GetType()->GetReturnType(), index, isOptional, needExternalOptional);
- return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel);
+ return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
index 8adc1c6ca43..ee0b91e3604 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
@@ -345,6 +345,40 @@ private:
class TAsScalarWrapper : public TMutableComputationNode<TAsScalarWrapper> {
public:
+ class TArrowNode : public IArrowKernelComputationNode {
+ public:
+ TArrowNode(const arrow::Datum& datum)
+ : Kernel_({}, datum.scalar()->type, [datum](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
+ *res = datum;
+ return arrow::Status::OK();
+ })
+ {
+ Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
+ Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
+ }
+
+ TStringBuf GetKernelName() const final {
+ return "AsScalar";
+ }
+
+ const arrow::compute::ScalarKernel& GetArrowKernel() const {
+ return Kernel_;
+ }
+
+ const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
+ return EmptyDesc_;
+ }
+
+ const IComputationNode* GetArgument(ui32 index) const {
+ Y_UNUSED(index);
+ ythrow yexception() << "No input arguments";
+ }
+
+ private:
+ arrow::compute::ScalarKernel Kernel_;
+ const std::vector<arrow::ValueDescr> EmptyDesc_;
+ };
+
TAsScalarWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* type)
: TMutableComputationNode(mutables)
, Arg_(arg)
@@ -360,6 +394,12 @@ public:
return ctx.HolderFactory.CreateArrowBlock(std::move(result));
}
+ std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
+ auto value = Arg_->GetValue(ctx);
+ arrow::Datum result = ConvertScalar(Type_, value, ctx);
+ return std::make_unique<TArrowNode>(result);
+ }
+
private:
void RegisterDependencies() const final {
DependsOn(Arg_);
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp
index f2ddd998661..b409a523a68 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_blocks_ut.cpp
@@ -2,8 +2,39 @@
#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
+#include <arrow/compute/exec_internal.h>
+#include <arrow/array/builder_primitive.h>
+
namespace NKikimr {
namespace NMiniKQL {
+
+namespace {
+ arrow::Datum ExecuteOneKernel(const IArrowKernelComputationNode* kernelNode,
+ const std::vector<arrow::Datum>& argDatums, arrow::compute::ExecContext& execContext) {
+ const auto& kernel = kernelNode->GetArrowKernel();
+ arrow::compute::KernelContext kernelContext(&execContext);
+ std::unique_ptr<arrow::compute::KernelState> state;
+ auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
+ ARROW_OK(executor->Init(&kernelContext, { &kernel, kernelNode->GetArgsDesc(), nullptr }));
+ auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
+ ARROW_OK(executor->Execute(argDatums, listener.get()));
+ return executor->WrapResults(argDatums, listener->values());
+ }
+
+ void ExecuteAllKernels(std::vector<arrow::Datum>& datums, const TArrowKernelsTopology* topology, arrow::compute::ExecContext& execContext) {
+ for (ui32 i = 0; i < topology->Items.size(); ++i) {
+ std::vector<arrow::Datum> argDatums;
+ argDatums.reserve(topology->Items[i].Inputs.size());
+ for (auto j : topology->Items[i].Inputs) {
+ argDatums.emplace_back(datums[j]);
+ }
+
+ arrow::Datum output = ExecuteOneKernel(topology->Items[i].Node.get(), argDatums, execContext);
+ datums[i + topology->InputArgsCount] = output;
+ }
+ }
+}
+
Y_UNIT_TEST_SUITE(TMiniKQLBlocksTest) {
Y_UNIT_TEST(TestEmpty) {
TSetup<false> setup;
@@ -457,7 +488,121 @@ Y_UNIT_TEST(TestWideFromBlocks) {
UNIT_ASSERT(iterator.Next(item));
UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 30);
}
+}
+
+Y_UNIT_TEST_SUITE(TMiniKQLDirectKernelTest) {
+Y_UNIT_TEST(Simple) {
+ TSetup<false> setup;
+ auto& pb = *setup.PgmBuilder;
+
+ const auto boolType = pb.NewDataType(NUdf::TDataType<bool>::Id);
+ const auto ui64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id);
+ const auto boolBlocksType = pb.NewBlockType(boolType, TBlockType::EShape::Many);
+ const auto ui64BlocksType = pb.NewBlockType(ui64Type, TBlockType::EShape::Many);
+ const auto arg1 = pb.Arg(boolBlocksType);
+ const auto arg2 = pb.Arg(ui64BlocksType);
+ const auto arg3 = pb.Arg(ui64BlocksType);
+ const auto ifNode = pb.BlockIf(arg1, arg2, arg3);
+ const auto eqNode = pb.BlockFunc("Equals", boolBlocksType, { ifNode, arg2 });
+
+ const auto graph = setup.BuildGraph(eqNode, {arg1.GetNode(), arg2.GetNode(), arg3.GetNode()});
+ const auto topology = graph->GetKernelsTopology();
+ UNIT_ASSERT(topology);
+ UNIT_ASSERT_VALUES_EQUAL(topology->InputArgsCount, 3);
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items.size(), 2);
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items[0].Node->GetKernelName(), "If");
+ const std::vector<ui32> expectedInputs1{{0, 1, 2}};
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items[0].Inputs, expectedInputs1);
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items[1].Node->GetKernelName(), "Equals");
+ const std::vector<ui32> expectedInputs2{{3, 1}};
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items[1].Inputs, expectedInputs2);
+
+ arrow::compute::ExecContext execContext;
+ const size_t blockSize = 100000;
+ std::vector<arrow::Datum> datums(topology->InputArgsCount + topology->Items.size());
+ {
+ arrow::UInt8Builder builder1(execContext.memory_pool());
+ arrow::UInt64Builder builder2(execContext.memory_pool()), builder3(execContext.memory_pool());
+ ARROW_OK(builder1.Reserve(blockSize));
+ ARROW_OK(builder2.Reserve(blockSize));
+ ARROW_OK(builder3.Reserve(blockSize));
+ for (size_t i = 0; i < blockSize; ++i) {
+ builder1.UnsafeAppend(i & 1);
+ builder2.UnsafeAppend(i);
+ builder3.UnsafeAppend(3 * i);
+ }
+
+ std::shared_ptr<arrow::ArrayData> data1;
+ ARROW_OK(builder1.FinishInternal(&data1));
+ std::shared_ptr<arrow::ArrayData> data2;
+ ARROW_OK(builder2.FinishInternal(&data2));
+ std::shared_ptr<arrow::ArrayData> data3;
+ ARROW_OK(builder3.FinishInternal(&data3));
+ datums[0] = data1;
+ datums[1] = data2;
+ datums[2] = data3;
+ }
+
+ ExecuteAllKernels(datums, topology, execContext);
+
+ auto res = datums.back().array()->GetValues<ui8>(1);
+ for (size_t i = 0; i < blockSize; ++i) {
+ auto expected = (((i & 1) ? i : i * 3) == i) ? 1 : 0;
+ UNIT_ASSERT_VALUES_EQUAL(res[i], expected);
+ }
+}
+Y_UNIT_TEST(WithScalars) {
+ TSetup<false> setup;
+ auto& pb = *setup.PgmBuilder;
+
+ const auto ui64Type = pb.NewDataType(NUdf::TDataType<ui64>::Id);
+ const auto ui64BlocksType = pb.NewBlockType(ui64Type, TBlockType::EShape::Many);
+ const auto scalar = pb.AsScalar(pb.NewDataLiteral(false));
+ const auto arg1 = pb.Arg(ui64BlocksType);
+ const auto arg2 = pb.Arg(ui64BlocksType);
+ const auto ifNode = pb.BlockIf(scalar, arg1, arg2);
+
+ const auto graph = setup.BuildGraph(ifNode, {arg1.GetNode(), arg2.GetNode()});
+ const auto topology = graph->GetKernelsTopology();
+ UNIT_ASSERT(topology);
+ UNIT_ASSERT_VALUES_EQUAL(topology->InputArgsCount, 2);
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items.size(), 2);
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items[0].Node->GetKernelName(), "AsScalar");
+ const std::vector<ui32> expectedInputs1;
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items[0].Inputs, expectedInputs1);
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items[1].Node->GetKernelName(), "If");
+ const std::vector<ui32> expectedInputs2{{2, 0, 1}};
+ UNIT_ASSERT_VALUES_EQUAL(topology->Items[1].Inputs, expectedInputs2);
+
+ arrow::compute::ExecContext execContext;
+ const size_t blockSize = 100000;
+ std::vector<arrow::Datum> datums(topology->InputArgsCount + topology->Items.size());
+ {
+ arrow::UInt64Builder builder1(execContext.memory_pool()), builder2(execContext.memory_pool());
+ ARROW_OK(builder1.Reserve(blockSize));
+ ARROW_OK(builder2.Reserve(blockSize));
+ for (size_t i = 0; i < blockSize; ++i) {
+ builder1.UnsafeAppend(i);
+ builder2.UnsafeAppend(3 * i);
+ }
+
+ std::shared_ptr<arrow::ArrayData> data1;
+ ARROW_OK(builder1.FinishInternal(&data1));
+ std::shared_ptr<arrow::ArrayData> data2;
+ ARROW_OK(builder2.FinishInternal(&data2));
+ datums[0] = data1;
+ datums[1] = data2;
+ }
+
+ ExecuteAllKernels(datums, topology, execContext);
+
+ auto res = datums.back().array()->GetValues<ui64>(1);
+ for (size_t i = 0; i < blockSize; ++i) {
+ auto expected = 3 * i;
+ UNIT_ASSERT_VALUES_EQUAL(res[i], expected);
+ }
+}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h b/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h
index bf20ea2c110..0c3eb242122 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_computation_node_ut.h
@@ -91,7 +91,8 @@ struct TSetup {
Reset();
Explorer.Walk(pgm.GetNode(), *Env);
TComputationPatternOpts opts(Alloc.Ref(), *Env, GetTestFactory(NodeFactory),
- FunctionRegistry.Get(), NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception, UseLLVM ? "" : "OFF", EGraphPerProcess::Multi);
+ FunctionRegistry.Get(), NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception,
+ UseLLVM ? "" : "OFF", EGraphPerProcess::Multi, nullptr, nullptr, nullptr);
Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts);
auto graph = Pattern->Clone(opts.ToComputationOptions(*RandomProvider, *TimeProvider));
Terminator.Reset(new TBindTerminator(graph->GetTerminator()));
diff --git a/ydb/library/yql/minikql/computation/mkql_computation_node.cpp b/ydb/library/yql/minikql/computation/mkql_computation_node.cpp
index 6921e83fedd..273eaace7a4 100644
--- a/ydb/library/yql/minikql/computation/mkql_computation_node.cpp
+++ b/ydb/library/yql/minikql/computation/mkql_computation_node.cpp
@@ -26,6 +26,23 @@
namespace NKikimr {
namespace NMiniKQL {
+std::unique_ptr<IArrowKernelComputationNode> IComputationNode::PrepareArrowKernelComputationNode(TComputationContext& ctx) const {
+ Y_UNUSED(ctx);
+ return {};
+}
+
+TDatumProvider MakeDatumProvider(const arrow::Datum& datum) {
+ return [datum]() {
+ return datum;
+ };
+}
+
+TDatumProvider MakeDatumProvider(const IComputationNode* node, TComputationContext& ctx) {
+ return [node, &ctx]() {
+ return TArrowBlock::From(node->GetValue(ctx)).GetDatum();
+ };
+}
+
TComputationContext::TComputationContext(const THolderFactory& holderFactory,
const NUdf::IValueBuilder* builder,
TComputationOptsFull& opts,
diff --git a/ydb/library/yql/minikql/computation/mkql_computation_node.h b/ydb/library/yql/minikql/computation/mkql_computation_node.h
index 640ef0f5813..ab898d7ee61 100644
--- a/ydb/library/yql/minikql/computation/mkql_computation_node.h
+++ b/ydb/library/yql/minikql/computation/mkql_computation_node.h
@@ -143,6 +143,8 @@ private:
#endif
};
+class IArrowKernelComputationNode;
+
class IComputationNode {
public:
typedef TIntrusivePtr<IComputationNode> TPtr;
@@ -177,6 +179,8 @@ public:
virtual void Ref() = 0;
virtual void UnRef() = 0;
virtual ui32 RefCount() const = 0;
+
+ virtual std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const;
};
class IComputationExternalNode : public IComputationNode {
@@ -209,6 +213,31 @@ public:
virtual void InvalidateValue(TComputationContext& compCtx) const = 0;
};
+using TDatumProvider = std::function<arrow::Datum()>;
+
+TDatumProvider MakeDatumProvider(const arrow::Datum& datum);
+TDatumProvider MakeDatumProvider(const IComputationNode* node, TComputationContext& ctx);
+
+class IArrowKernelComputationNode {
+public:
+ virtual ~IArrowKernelComputationNode() = default;
+
+ virtual TStringBuf GetKernelName() const = 0;
+ virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0;
+ virtual const std::vector<arrow::ValueDescr>& GetArgsDesc() const = 0;
+ virtual const IComputationNode* GetArgument(ui32 index) const = 0;
+};
+
+struct TArrowKernelsTopologyItem {
+ std::vector<ui32> Inputs;
+ std::unique_ptr<IArrowKernelComputationNode> Node;
+};
+
+struct TArrowKernelsTopology {
+ ui32 InputArgsCount = 0;
+ std::vector<TArrowKernelsTopologyItem> Items;
+};
+
using TComputationNodePtrVector = std::vector<IComputationNode*, TMKQLAllocator<IComputationNode*>>;
using TComputationWideFlowNodePtrVector = std::vector<IComputationWideFlowNode*, TMKQLAllocator<IComputationWideFlowNode*>>;
using TComputationExternalNodePtrVector = std::vector<IComputationExternalNode*, TMKQLAllocator<IComputationExternalNode*>>;
@@ -223,6 +252,7 @@ public:
virtual NUdf::TUnboxedValue GetValue() = 0;
virtual TComputationContext& GetContext() = 0;
virtual IComputationExternalNode* GetEntryPoint(size_t index, bool require) = 0;
+ virtual const TArrowKernelsTopology* GetKernelsTopology() = 0;
virtual const TComputationNodePtrDeque& GetNodes() const = 0;
virtual void Invalidate() = 0;
virtual TMemoryUsageInfo& GetMemInfo() const = 0;
@@ -310,7 +340,8 @@ struct TComputationPatternOpts {
const TString& optLLVM,
EGraphPerProcess graphPerProcess,
IStatsRegistry* stats = nullptr,
- NUdf::ICountersProvider* countersProvider = nullptr)
+ NUdf::ICountersProvider* countersProvider = nullptr,
+ const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr)
: AllocState(allocState)
, Env(env)
, Factory(factory)
@@ -321,12 +352,14 @@ struct TComputationPatternOpts {
, GraphPerProcess(graphPerProcess)
, Stats(stats)
, CountersProvider(countersProvider)
+ , SecureParamsProvider(secureParamsProvider)
{}
void SetOptions(TComputationNodeFactory factory, const IFunctionRegistry* functionRegistry,
NUdf::EValidateMode validateMode, NUdf::EValidatePolicy validatePolicy,
const TString& optLLVM, EGraphPerProcess graphPerProcess, IStatsRegistry* stats = nullptr,
- NUdf::ICountersProvider* counters = nullptr, const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) {
+ NUdf::ICountersProvider* counters = nullptr,
+ const NUdf::ISecureParamsProvider* secureParamsProvider = nullptr) {
Factory = factory;
FunctionRegistry = functionRegistry;
ValidateMode = validateMode;
diff --git a/ydb/library/yql/minikql/computation/mkql_computation_node_graph.cpp b/ydb/library/yql/minikql/computation/mkql_computation_node_graph.cpp
index c8de496a533..2eb542fc21f 100644
--- a/ydb/library/yql/minikql/computation/mkql_computation_node_graph.cpp
+++ b/ydb/library/yql/minikql/computation/mkql_computation_node_graph.cpp
@@ -186,9 +186,9 @@ public:
}
IComputationExternalNode* GetEntryPoint(size_t index, bool require) {
- MKQL_ENSURE(index < Runtime2Computation.size() && (!require || Runtime2Computation[index]),
+ MKQL_ENSURE(index < Runtime2ComputationEntryPoints.size() && (!require || Runtime2ComputationEntryPoints[index]),
"Pattern nodes can not get computation node by index: " << index << ", require: " << require);
- return Runtime2Computation[index];
+ return Runtime2ComputationEntryPoints[index];
}
IComputationNode* GetRoot() {
@@ -199,6 +199,10 @@ public:
return SuitableForCache;
}
+ size_t GetEntryPointsCount() const {
+ return Runtime2ComputationEntryPoints.size();
+ }
+
private:
friend class TComputationGraphBuildingVisitor;
friend class TComputationGraph;
@@ -210,7 +214,7 @@ private:
TComputationMutables Mutables;
TComputationNodePtrDeque ComputationNodesList;
IComputationNode* RootNode = nullptr;
- TComputationExternalNodePtrVector Runtime2Computation;
+ TComputationExternalNodePtrVector Runtime2ComputationEntryPoints;
TComputationNodeOnNodeMap ElementsCache;
bool SuitableForCache = true;
};
@@ -527,8 +531,8 @@ public:
PatternNodes->RootNode = rootNode;
}
- void PreserveEntryPoints(TComputationExternalNodePtrVector&& runtime2Computation) {
- PatternNodes->Runtime2Computation = std::move(runtime2Computation);
+ void PreserveEntryPoints(TComputationExternalNodePtrVector&& runtime2ComputationEntryPoints) {
+ PatternNodes->Runtime2ComputationEntryPoints = std::move(runtime2ComputationEntryPoints);
}
private:
@@ -617,6 +621,72 @@ public:
return PatternNodes->GetEntryPoint(index, require);
}
+ const TArrowKernelsTopology* GetKernelsTopology() override {
+ Prepare();
+ if (!KernelsTopology.has_value()) {
+ CalculateKernelTopology(*Ctx);
+ }
+
+ return &KernelsTopology.value();
+ }
+
+ void CalculateKernelTopology(TComputationContext& ctx) {
+ KernelsTopology.emplace();
+ KernelsTopology->InputArgsCount = PatternNodes->GetEntryPointsCount();
+
+ std::stack<const IComputationNode*> stack;
+ struct TNodeState {
+ bool Visited;
+ ui32 Index;
+ };
+
+ std::unordered_map<const IComputationNode*, TNodeState> deps;
+ for (ui32 i = 0; i < KernelsTopology->InputArgsCount; ++i) {
+ auto entryPoint = PatternNodes->GetEntryPoint(i, false);
+ if (!entryPoint) {
+ continue;
+ }
+
+ deps.emplace(entryPoint, TNodeState{ true, i});
+ }
+
+ stack.push(PatternNodes->GetRoot());
+ while (!stack.empty()) {
+ auto node = stack.top();
+ auto [iter, inserted] = deps.emplace(node, TNodeState{ false, 0 });
+ auto extNode = dynamic_cast<const IComputationExternalNode*>(node);
+ if (extNode) {
+ MKQL_ENSURE(!inserted, "Unexpected external node");
+ stack.pop();
+ continue;
+ }
+
+ auto kernelNode = node->PrepareArrowKernelComputationNode(ctx);
+ MKQL_ENSURE(kernelNode, "No kernel for node: " << node->DebugString());
+ auto argsCount = kernelNode->GetArgsDesc().size();
+
+ if (!iter->second.Visited) {
+ for (ui32 j = 0; j < argsCount; ++j) {
+ stack.push(kernelNode->GetArgument(j));
+ }
+ iter->second.Visited = true;
+ } else {
+ iter->second.Index = KernelsTopology->InputArgsCount + KernelsTopology->Items.size();
+ KernelsTopology->Items.emplace_back();
+ auto& i = KernelsTopology->Items.back();
+ i.Inputs.reserve(argsCount);
+ for (ui32 j = 0; j < argsCount; ++j) {
+ auto it = deps.find(kernelNode->GetArgument(j));
+ MKQL_ENSURE(it != deps.end(), "Missing argument");
+ i.Inputs.emplace_back(it->second.Index);
+ }
+
+ i.Node = std::move(kernelNode);
+ stack.pop();
+ }
+ }
+ }
+
void Invalidate() override {
std::fill_n(Ctx->MutableValues.get(), PatternNodes->GetMutables().CurValueIndex, NUdf::TUnboxedValue(NUdf::TUnboxedValuePod::Invalid()));
}
@@ -690,6 +760,7 @@ private:
THolder<TComputationContext> Ctx;
TComputationOptsFull CompOpts;
bool IsPrepared = false;
+ std::optional<TArrowKernelsTopology> KernelsTopology;
};
} // namespace
@@ -861,16 +932,31 @@ TIntrusivePtr<TComputationPatternImpl> MakeComputationPatternImpl(TExploringNode
const auto rootNode = builder->GetComputationNode(root.GetNode());
- TComputationExternalNodePtrVector runtime2Computation;
- runtime2Computation.resize(entryPoints.size(), nullptr);
+ TComputationExternalNodePtrVector runtime2ComputationEntryPoints;
+ runtime2ComputationEntryPoints.resize(entryPoints.size(), nullptr);
+ std::unordered_map<TNode*, std::vector<ui32>> entryPointIndex;
+ for (ui32 i = 0; i < entryPoints.size(); ++i) {
+ entryPointIndex[entryPoints[i]].emplace_back(i);
+ }
+
for (const auto& node : explorer.GetNodes()) {
- for (auto iter = std::find(entryPoints.cbegin(), entryPoints.cend(), node); entryPoints.cend() != iter; iter = std::find(iter + 1, entryPoints.cend(), node)) {
- runtime2Computation[iter - entryPoints.begin()] = dynamic_cast<IComputationExternalNode*>(builder->GetComputationNode(node));
+ auto it = entryPointIndex.find(node);
+ if (it == entryPointIndex.cend()) {
+ continue;
+ }
+
+ auto compNode = dynamic_cast<IComputationExternalNode*>(builder->GetComputationNode(node));
+ for (auto index : it->second) {
+ runtime2ComputationEntryPoints[index] = compNode;
}
+ }
+
+ for (const auto& node : explorer.GetNodes()) {
node->SetCookie(0);
}
+
builder->PreserveRoot(rootNode);
- builder->PreserveEntryPoints(std::move(runtime2Computation));
+ builder->PreserveEntryPoints(std::move(runtime2ComputationEntryPoints));
return MakeIntrusive<TComputationPatternImpl>(std::move(builder), opts);
}
diff --git a/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp b/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp
index bc694c4b311..0a5144f41fb 100644
--- a/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp
+++ b/ydb/library/yql/parser/pg_wrapper/comp_factory.cpp
@@ -1939,7 +1939,7 @@ TComputationNodeFactory GetPgFactory() {
auto execFunc = FindExec(id);
YQL_ENSURE(execFunc);
auto kernel = MakePgKernel(argTypes, returnType, execFunc, id);
- return new TBlockFuncNode(ctx.Mutables, std::move(argNodes), argTypes, *kernel, kernel);
+ return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argNodes), argTypes, *kernel, kernel);
}
if (name == "PgCast") {
@@ -1995,7 +1995,7 @@ TComputationNodeFactory GetPgFactory() {
auto returnType = callable.GetType()->GetReturnType();
ui32 sourceId = AS_TYPE(TPgType, AS_TYPE(TBlockType, inputType)->GetItemType())->GetTypeId();
auto kernel = MakeFromPgKernel(inputType, returnType, sourceId);
- return new TBlockFuncNode(ctx.Mutables, { arg }, { inputType }, *kernel, kernel);
+ return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), { arg }, { inputType }, *kernel, kernel);
}
if (name == "ToPg") {
@@ -2030,7 +2030,7 @@ TComputationNodeFactory GetPgFactory() {
auto returnType = callable.GetType()->GetReturnType();
auto targetId = AS_TYPE(TPgType, AS_TYPE(TBlockType, returnType)->GetItemType())->GetTypeId();
auto kernel = MakeToPgKernel(inputType, returnType, targetId);
- return new TBlockFuncNode(ctx.Mutables, { arg }, { inputType }, *kernel, kernel);
+ return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), { arg }, { inputType }, *kernel, kernel);
}
if (name == "PgArray") {