aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoraneporada <aneporada@ydb.tech>2023-01-14 15:35:20 +0300
committeraneporada <aneporada@ydb.tech>2023-01-14 15:35:20 +0300
commitd95441a3c516b3781878af847751cf9d143af91d (patch)
tree8871fb639ae8a7856e00538214f7099c29f85ae9
parent3657be5988251fc9074ba5b86b62bfa985ff4643 (diff)
downloadydb-d95441a3c516b3781878af847751cf9d143af91d.tar.gz
Implement BlockIf
-rw-r--r--ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp5
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_blocks.cpp43
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_blocks.h1
-rw-r--r--ydb/library/yql/core/type_ann/type_ann_core.cpp1
-rw-r--r--ydb/library/yql/minikql/arrow/arrow_util.cpp30
-rw-r--r--ydb/library/yql/minikql/arrow/arrow_util.h3
-rw-r--r--ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp38
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_builder.h2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp2
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp193
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp198
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_if.h10
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp183
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h52
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_item.h13
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp4
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp2
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.cpp17
-rw-r--r--ydb/library/yql/minikql/mkql_program_builder.h1
-rw-r--r--ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp4
23 files changed, 585 insertions, 223 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
index 63a58e9fdb..e377d1592b 100644
--- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
+++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
@@ -4855,7 +4855,7 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo
TExprNode::TListType funcArgs;
std::string_view arrowFunctionName;
- if (node->IsCallable({"And", "Or", "Xor", "Not", "Coalesce"}))
+ if (node->IsCallable({"And", "Or", "Xor", "Not", "Coalesce", "If"}))
{
for (auto& child : node->ChildrenList()) {
if (child->IsComplete()) {
@@ -4868,9 +4868,8 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo
}
TString blockFuncName = TString("Block") + node->Content();
- if (funcArgs.size() > 2) {
+ if (node->IsCallable({"And", "Or", "Xor"}) && funcArgs.size() > 2) {
// Split original argument list by pairs (since the order is not important balanced tree is used)
- // this is only supported by And/Or/Xor
rewrites[node.Get()] = SplitByPairs(node->Pos(), blockFuncName, funcArgs, 0, funcArgs.size(), ctx);
} else {
rewrites[node.Get()] = ctx.NewCallable(node->Pos(), blockFuncName, std::move(funcArgs));
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
index 8800570e23..9984ed3766 100644
--- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp
@@ -189,6 +189,49 @@ IGraphTransformer::TStatus BlockLogicalWrapper(const TExprNode::TPtr& input, TEx
return IGraphTransformer::TStatus::Ok;
}
+IGraphTransformer::TStatus BlockIfWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ Y_UNUSED(output);
+ if (!EnsureArgsCount(*input, 3U, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ auto pred = input->Child(0);
+ auto thenNode = input->Child(1);
+ auto elseNode = input->Child(2);
+
+ if (!EnsureBlockOrScalarType(*pred, ctx.Expr) ||
+ !EnsureBlockOrScalarType(*thenNode, ctx.Expr) ||
+ !EnsureBlockOrScalarType(*elseNode, ctx.Expr))
+ {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ bool predIsScalar;
+ const TTypeAnnotationNode* predItemType = GetBlockItemType(*pred->GetTypeAnn(), predIsScalar);
+ if (!EnsureSpecificDataType(pred->Pos(), *predItemType, NUdf::EDataSlot::Bool, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ bool thenIsScalar;
+ const TTypeAnnotationNode* thenItemType = GetBlockItemType(*thenNode->GetTypeAnn(), thenIsScalar);
+
+ bool elseIsScalar;
+ const TTypeAnnotationNode* elseItemType = GetBlockItemType(*elseNode->GetTypeAnn(), elseIsScalar);
+
+ if (!IsSameAnnotation(*thenItemType, *elseItemType)) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), TStringBuilder() <<
+ "Mismatch item types: then branch is " << *thenItemType << ", else branch is " << *elseItemType));
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (predIsScalar && thenIsScalar && elseIsScalar) {
+ input->SetTypeAnn(ctx.Expr.MakeType<TScalarExprType>(thenItemType));
+ } else {
+ input->SetTypeAnn(ctx.Expr.MakeType<TBlockExprType>(thenItemType));
+ }
+ return IGraphTransformer::TStatus::Ok;
+}
+
IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx) {
Y_UNUSED(output);
if (!EnsureMinArgsCount(*input, 2U, ctx.Expr)) {
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.h b/ydb/library/yql/core/type_ann/type_ann_blocks.h
index 9e2364dacf..07fed8ae83 100644
--- a/ydb/library/yql/core/type_ann/type_ann_blocks.h
+++ b/ydb/library/yql/core/type_ann/type_ann_blocks.h
@@ -13,6 +13,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus BlockExpandChunkedWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockCoalesceWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockLogicalWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+ IGraphTransformer::TStatus BlockIfWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
IGraphTransformer::TStatus BlockBitCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
IGraphTransformer::TStatus BlockCombineAllWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index 11a259c61d..d54270154a 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -11840,6 +11840,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["BlockOr"] = &BlockLogicalWrapper;
Functions["BlockXor"] = &BlockLogicalWrapper;
Functions["BlockNot"] = &BlockLogicalWrapper;
+ Functions["BlockIf"] = &BlockIfWrapper;
ExtFunctions["BlockFunc"] = &BlockFuncWrapper;
ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper;
diff --git a/ydb/library/yql/minikql/arrow/arrow_util.cpp b/ydb/library/yql/minikql/arrow/arrow_util.cpp
index 37ded548e4..ff8015734c 100644
--- a/ydb/library/yql/minikql/arrow/arrow_util.cpp
+++ b/ydb/library/yql/minikql/arrow/arrow_util.cpp
@@ -1,5 +1,9 @@
#include "arrow_util.h"
#include "mkql_bit_utils.h"
+
+#include <arrow/array/array_base.h>
+#include <arrow/chunked_array.h>
+
#include <ydb/library/yql/minikql/mkql_node_builder.h>
#include <util/system/yassert.h>
@@ -67,4 +71,30 @@ std::shared_ptr<arrow::Buffer> MakeDenseBitmap(const ui8* srcSparse, size_t len,
return bitmap;
}
+void ForEachArrayData(const arrow::Datum& datum, const std::function<void(const std::shared_ptr<arrow::ArrayData>&)>& func) {
+ MKQL_ENSURE(datum.is_arraylike(), "Expected array");
+ if (datum.is_array()) {
+ func(datum.array());
+ } else {
+ for (auto& chunk : datum.chunks()) {
+ func(chunk->data());
+ }
+ }
+}
+
+arrow::Datum MakeArray(const TVector<std::shared_ptr<arrow::ArrayData>>& chunks) {
+ MKQL_ENSURE(!chunks.empty(), "Expected non empty chunks");
+ arrow::ArrayVector resultChunks;
+ for (auto& chunk : chunks) {
+ resultChunks.push_back(arrow::Datum(chunk).make_array());
+ }
+
+ if (resultChunks.size() > 1) {
+ auto type = resultChunks.front()->type();
+ auto chunked = ARROW_RESULT(arrow::ChunkedArray::Make(std::move(resultChunks), type));
+ return arrow::Datum(chunked);
+ }
+ return arrow::Datum(resultChunks.front());
+}
+
}
diff --git a/ydb/library/yql/minikql/arrow/arrow_util.h b/ydb/library/yql/minikql/arrow/arrow_util.h
index 005b3f2938..86e0890732 100644
--- a/ydb/library/yql/minikql/arrow/arrow_util.h
+++ b/ydb/library/yql/minikql/arrow/arrow_util.h
@@ -28,6 +28,9 @@ inline arrow::internal::Bitmap GetBitmap(const arrow::ArrayData& arr, int index)
return arrow::internal::Bitmap{ arr.buffers[index], arr.offset, arr.length };
}
+void ForEachArrayData(const arrow::Datum& datum, const std::function<void(const std::shared_ptr<arrow::ArrayData>&)>& func);
+arrow::Datum MakeArray(const TVector<std::shared_ptr<arrow::ArrayData>>& chunks);
+
template <typename T>
T GetPrimitiveScalarValue(const arrow::Scalar& scalar) {
return *static_cast<const T*>(dynamic_cast<const arrow::internal::PrimitiveScalarBase&>(scalar).data());
diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt
index 2e6b5588c3..6cc5135bdb 100644
--- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt
+++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.darwin.txt
@@ -42,6 +42,8 @@ target_sources(yql-minikql-comp_nodes PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_coalesce.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_logical.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt
index 303a6c6b5e..87be655fd0 100644
--- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt
+++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux-aarch64.txt
@@ -43,6 +43,8 @@ target_sources(yql-minikql-comp_nodes PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_coalesce.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_logical.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt
index 303a6c6b5e..87be655fd0 100644
--- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt
+++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.linux.txt
@@ -43,6 +43,8 @@ target_sources(yql-minikql-comp_nodes PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_sum.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_coalesce.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_logical.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp
index a94d01c8b0..08fdd91649 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.cpp
@@ -17,29 +17,6 @@ namespace NMiniKQL {
namespace {
-bool AlwaysUseChunks(const TType* type) {
- if (type->IsOptional()) {
- return AlwaysUseChunks(AS_TYPE(TOptionalType, type)->GetItemType());
- }
-
- if (type->IsTuple()) {
- auto tupleType = AS_TYPE(TTupleType, type);
- for (ui32 i = 0; i < tupleType->GetElementsCount(); ++i) {
- if (AlwaysUseChunks(tupleType->GetElementType(i))) {
- return true;
- }
- }
- return false;
- }
-
- if (type->IsData()) {
- auto slot = *AS_TYPE(TDataType, type)->GetDataSlot();
- return (GetDataTypeInfo(slot).Features & NYql::NUdf::EDataTypeFeatures::StringType) != 0u;
- }
-
- MKQL_ENSURE(false, "Unsupported type");
-}
-
std::shared_ptr<arrow::DataType> GetArrowType(TType* type) {
std::shared_ptr<arrow::DataType> result;
Y_VERIFY(ConvertArrowType(type, result));
@@ -101,21 +78,14 @@ public:
CurrLen += popCount;
}
- NUdf::TUnboxedValuePod Build(TComputationContext& ctx, bool finish) final {
+ arrow::Datum Build(bool finish) final {
auto tree = BuildTree(finish);
- arrow::ArrayVector chunks;
+ TVector<std::shared_ptr<arrow::ArrayData>> chunks;
while (size_t size = CalcSliceSize(*tree)) {
- std::shared_ptr<arrow::ArrayData> data = Slice(*tree, size);
- chunks.push_back(arrow::Datum(data).make_array());
+ chunks.push_back(Slice(*tree, size));
}
- Y_VERIFY(!chunks.empty());
-
- if (chunks.size() > 1 || AlwaysUseChunks(Type)) {
- auto chunked = ARROW_RESULT(arrow::ChunkedArray::Make(std::move(chunks), GetArrowType(Type)));
- return ctx.HolderFactory.CreateArrowBlock(std::move(chunked));
- }
- return ctx.HolderFactory.CreateArrowBlock(chunks.front());
+ return MakeArray(chunks);
}
TBlockArrayTree::Ptr BuildTree(bool finish) {
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.h
index 4769c3d59f..3a937d6dd7 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_builder.h
@@ -31,7 +31,7 @@ public:
virtual void Add(NUdf::TUnboxedValuePod value) = 0;
virtual void Add(TBlockItem value) = 0;
virtual void AddMany(const arrow::ArrayData& array, size_t popCount, const ui8* sparseBitmap, size_t bitmapSize) = 0;
- virtual NUdf::TUnboxedValuePod Build(TComputationContext& ctx, bool finish) = 0;
+ virtual arrow::Datum Build(bool finish) = 0;
};
std::unique_ptr<IBlockBuilder> MakeBlockBuilder(TType* type, arrow::MemoryPool& pool, size_t maxBlockLength);
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp
index 68801afd6d..34a5d098a8 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_compress.cpp
@@ -315,7 +315,7 @@ private:
for (ui32 i = 0, outIndex = 0; i < Width_; ++i) {
bool isScalar = Types_[i]->GetShape() == TBlockType::EShape::Scalar;
if (i != BitmapIndex_ && output[outIndex]) {
- *output[outIndex] = isScalar ? s.InputValues_[i] : s.Builders_[i]->Build(ctx, s.Finish_);
+ *output[outIndex] = isScalar ? s.InputValues_[i] : ctx.HolderFactory.CreateArrowBlock(s.Builders_[i]->Build(s.Finish_));
}
if (i != BitmapIndex_) {
outIndex++;
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
index d8b4aa8b43..d692a30464 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
@@ -1,47 +1,18 @@
#include "mkql_block_func.h"
+#include "mkql_block_impl.h"
#include <ydb/library/yql/minikql/arrow/arrow_defs.h>
-#include <ydb/library/yql/minikql/arrow/mkql_functions.h>
-#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
-#include <ydb/library/yql/minikql/computation/mkql_computation_node_codegen.h>
#include <ydb/library/yql/minikql/mkql_node_builder.h>
#include <ydb/library/yql/minikql/mkql_node_cast.h>
#include <ydb/library/yql/minikql/mkql_type_builder.h>
-#include <arrow/array/builder_primitive.h>
#include <arrow/compute/cast.h>
-#include <arrow/compute/exec_internal.h>
-#include <arrow/compute/function.h>
-#include <arrow/compute/kernel.h>
-#include <arrow/compute/registry.h>
-#include <arrow/util/bit_util.h>
namespace NKikimr {
namespace NMiniKQL {
namespace {
-arrow::ValueDescr ToValueDescr(TType* type) {
- arrow::ValueDescr ret;
- MKQL_ENSURE(ConvertInputArrowType(type, ret), "can't get arrow type");
- return ret;
-}
-
-std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) {
- std::vector<arrow::ValueDescr> res;
- res.reserve(types.size());
- for (const auto& type : types) {
- res.emplace_back(ToValueDescr(type));
- }
-
- return res;
-}
-
-const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function& function, const std::vector<arrow::ValueDescr>& args) {
- const auto kernel = ARROW_RESULT(function.DispatchExact(args));
- return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
-}
-
const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes, TType* returnType) {
std::vector<NUdf::TDataTypeId> argTypes;
for (const auto& t : inputTypes) {
@@ -64,166 +35,31 @@ const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TSt
return *kernel;
}
-struct TState : public TComputationValue<TState> {
- using TComputationValue::TComputationValue;
-
- TState(TMemoryUsageInfo* memInfo, const arrow::compute::FunctionOptions* options,
- const arrow::compute::ScalarKernel& kernel,
- const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx)
- : TComputationValue(memInfo)
- , Options(options)
- , ExecContext(&ctx.ArrowMemoryPool, nullptr, nullptr)
- , KernelContext(&ExecContext)
- {
- if (kernel.init) {
- State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
- KernelContext.SetState(State.get());
- }
-
- Values.reserve(argsValuesDescr.size());
- }
-
- const arrow::compute::FunctionOptions* Options;
- arrow::compute::ExecContext ExecContext;
- arrow::compute::KernelContext KernelContext;
- std::unique_ptr<arrow::compute::KernelState> State;
-
- std::vector<arrow::Datum> Values;
-};
-
-class TBlockFuncWrapper : public TMutableComputationNode<TBlockFuncWrapper> {
-public:
- TBlockFuncWrapper(TComputationMutables& mutables,
- const IBuiltinFunctionRegistry& builtins,
- const TString& funcName,
- TVector<IComputationNode*>&& argsNodes,
- TVector<TType*>&& argsTypes,
- TType* returnType)
- : TMutableComputationNode(mutables)
- , StateIndex(mutables.CurValueIndex++)
- , FuncName(funcName)
- , ArgsNodes(std::move(argsNodes))
- , ArgsTypes(std::move(argsTypes))
- , ArgsValuesDescr(ToValueDescr(ArgsTypes))
- , Kernel(ResolveKernel(builtins, FuncName, ArgsTypes, returnType))
- {
- }
-
- NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
- auto& state = GetState(ctx);
-
- state.Values.clear();
- for (ui32 i = 0; i < ArgsNodes.size(); ++i) {
- state.Values.emplace_back(TArrowBlock::From(ArgsNodes[i]->GetValue(ctx)).GetDatum());
- Y_VERIFY_DEBUG(ArgsValuesDescr[i] == state.Values.back().descr());
- }
-
- auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
- auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
- ARROW_OK(executor->Init(&state.KernelContext, { &Kernel.GetArrowKernel(), ArgsValuesDescr, state.Options }));
- ARROW_OK(executor->Execute(state.Values, listener.get()));
- auto output = executor->WrapResults(state.Values, listener->values());
- return ctx.HolderFactory.CreateArrowBlock(std::move(output));
- }
-
-private:
- void RegisterDependencies() const final {
- for (const auto& arg : ArgsNodes) {
- this->DependsOn(arg);
- }
- }
-
- static const arrow::compute::Function& ResolveFunction(const arrow::compute::FunctionRegistry& registry, const TString& funcName) {
- auto function = ARROW_RESULT(registry.GetFunction(funcName));
- MKQL_ENSURE(function != nullptr, "missing function");
- MKQL_ENSURE(function->kind() == arrow::compute::Function::SCALAR, "expected SCALAR function");
- return *function;
- }
-
- TState& GetState(TComputationContext& ctx) const {
- auto& result = ctx.MutableValues[StateIndex];
- if (!result.HasValue()) {
- result = ctx.HolderFactory.Create<TState>(Kernel.Family.FunctionOptions, Kernel.GetArrowKernel(), ArgsValuesDescr, ctx);
- }
-
- return *static_cast<TState*>(result.AsBoxed().Get());
- }
-
-private:
- const ui32 StateIndex;
- const TString FuncName;
- const TVector<IComputationNode*> ArgsNodes;
- const TVector<TType*> ArgsTypes;
-
- const std::vector<arrow::ValueDescr> ArgsValuesDescr;
- const TKernel& Kernel;
-};
-
-class TBlockBitCastWrapper : public TMutableComputationNode<TBlockBitCastWrapper> {
+class TBlockBitCastWrapper : public TBlockFuncNode {
public:
- TBlockBitCastWrapper(TComputationMutables& mutables,
- IComputationNode* arg,
- TType* argType,
- TType* to)
- : TMutableComputationNode(mutables)
- , StateIndex(mutables.CurValueIndex++)
- , Arg(arg)
- , ArgsValuesDescr({ ToValueDescr(argType) })
- , Function(ResolveFunction(to))
- , Kernel(ResolveKernel(Function, ArgsValuesDescr))
+ TBlockBitCastWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* argType, TType* to)
+ : TBlockFuncNode(mutables, { arg }, { argType }, ResolveKernel(argType, to), {}, &CastOptions)
, CastOptions(false)
{
}
-
- NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
- auto& state = GetState(ctx);
-
- state.Values.clear();
- state.Values.emplace_back(TArrowBlock::From(Arg->GetValue(ctx)).GetDatum());
- Y_VERIFY_DEBUG(ArgsValuesDescr[0] == state.Values.back().descr());
-
- auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
- auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
- ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, state.Options }));
- ARROW_OK(executor->Execute(state.Values, listener.get()));
- auto output = executor->WrapResults(state.Values, listener->values());
- return ctx.HolderFactory.CreateArrowBlock(std::move(output));
- }
-
private:
- void RegisterDependencies() const final {
- this->DependsOn(Arg);
- }
-
- static const arrow::compute::Function& ResolveFunction(TType* to) {
+ static const arrow::compute::ScalarKernel& ResolveKernel(TType* from, TType* to) {
std::shared_ptr<arrow::DataType> type;
MKQL_ENSURE(ConvertArrowType(to, type), "can't get arrow type");
auto function = ARROW_RESULT(arrow::compute::GetCastFunction(type));
MKQL_ENSURE(function != nullptr, "missing function");
MKQL_ENSURE(function->kind() == arrow::compute::Function::SCALAR, "expected SCALAR function");
- return *function;
- }
-
- TState& GetState(TComputationContext& ctx) const {
- auto& result = ctx.MutableValues[StateIndex];
- if (!result.HasValue()) {
- result = ctx.HolderFactory.Create<TState>((const arrow::compute::FunctionOptions*)&CastOptions, Kernel, ArgsValuesDescr, ctx);
- }
- return *static_cast<TState*>(result.AsBoxed().Get());
+ std::vector<arrow::ValueDescr> args = { ToValueDescr(from) };
+ const auto kernel = ARROW_RESULT(function->DispatchExact(args));
+ return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
}
-private:
- const ui32 StateIndex;
- IComputationNode* Arg;
- const std::vector<arrow::ValueDescr> ArgsValuesDescr;
- const arrow::compute::Function& Function;
- const arrow::compute::ScalarKernel& Kernel;
- arrow::compute::CastOptions CastOptions;
+ const arrow::compute::CastOptions CastOptions;
};
-}
+} // namespace
IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
MKQL_ENSURE(callable.GetInputsCount() >= 1, "Expected at least 1 arg");
@@ -237,13 +73,8 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto
argsTypes.push_back(callableType->GetArgumentType(i));
}
- return new TBlockFuncWrapper(ctx.Mutables,
- *ctx.FunctionRegistry.GetBuiltins(),
- funcName,
- std::move(argsNodes),
- std::move(argsTypes),
- callableType->GetReturnType()
- );
+ const TKernel& kernel = ResolveKernel(*ctx.FunctionRegistry.GetBuiltins(), funcName, argsTypes, callableType->GetReturnType());
+ return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, kernel.GetArrowKernel(), {}, kernel.Family.FunctionOptions);
}
IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp
new file mode 100644
index 0000000000..52084e3ae6
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.cpp
@@ -0,0 +1,198 @@
+#include "mkql_block_if.h"
+#include "mkql_block_impl.h"
+#include "mkql_block_reader.h"
+#include "mkql_block_builder.h"
+
+#include <ydb/library/yql/minikql/arrow/arrow_defs.h>
+#include <ydb/library/yql/minikql/arrow/arrow_util.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
+#include <ydb/library/yql/minikql/mkql_node_cast.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+namespace {
+
+class TBlockIfScalarWrapper : public TMutableComputationNode<TBlockIfScalarWrapper> {
+public:
+ TBlockIfScalarWrapper(TComputationMutables& mutables, IComputationNode* pred, IComputationNode* thenNode, IComputationNode* elseNode, TType* resultType,
+ bool thenIsScalar, bool elseIsScalar)
+ : TMutableComputationNode(mutables)
+ , Pred(pred)
+ , Then(thenNode)
+ , Else(elseNode)
+ , Type(resultType)
+ , ThenIsScalar(thenIsScalar)
+ , ElseIsScalar(elseIsScalar)
+ {
+ }
+
+ NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
+ auto predValue = Pred->GetValue(ctx);
+
+ const bool predScalarValue = GetPrimitiveScalarValue<bool>(*TArrowBlock::From(predValue).GetDatum().scalar());
+ auto result = predScalarValue ? Then->GetValue(ctx) : Else->GetValue(ctx);
+
+ if (ThenIsScalar == ElseIsScalar || (predScalarValue ? !ThenIsScalar : !ElseIsScalar)) {
+ // can return result as-is
+ return result.Release();
+ }
+
+ auto other = predScalarValue ? Else->GetValue(ctx) : Then->GetValue(ctx);
+ const auto& otherDatum = TArrowBlock::From(other).GetDatum();
+ MKQL_ENSURE(otherDatum.is_arraylike(), "Expecting array");
+
+ std::shared_ptr<arrow::Scalar> resultScalar = TArrowBlock::From(result).GetDatum().scalar();
+
+ TVector<std::shared_ptr<arrow::ArrayData>> resultArrays;
+ ForEachArrayData(otherDatum, [&](const std::shared_ptr<arrow::ArrayData>& otherData) {
+ auto chunk = MakeArrayFromScalar(*resultScalar, otherData->length, Type, ctx.ArrowMemoryPool);
+ ForEachArrayData(chunk, [&](const auto& array) {
+ resultArrays.push_back(array);
+ });
+ });
+ return ctx.HolderFactory.CreateArrowBlock(MakeArray(resultArrays));
+ }
+private:
+ void RegisterDependencies() const final {
+ DependsOn(Pred);
+ DependsOn(Then);
+ DependsOn(Else);
+ }
+
+ IComputationNode* const Pred;
+ IComputationNode* const Then;
+ IComputationNode* const Else;
+ TType* const Type;
+ const bool ThenIsScalar;
+ const bool ElseIsScalar;
+};
+
+template<bool ThenIsScalar, bool ElseIsScalar>
+class TIfBlockExec {
+public:
+ explicit TIfBlockExec(TType* type)
+ : Type(type)
+ , ThenReader(MakeBlockReader(type))
+ , ElseReader(MakeBlockReader(type))
+ {
+ }
+
+ arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
+ arrow::Datum predDatum = batch.values[0];
+ arrow::Datum thenDatum = batch.values[1];
+ arrow::Datum elseDatum = batch.values[2];
+
+ TBlockItem thenItem;
+ const arrow::ArrayData* thenArray = nullptr;
+ if constexpr(ThenIsScalar) {
+ thenItem = ThenReader->GetScalarItem(*thenDatum.scalar());
+ } else {
+ MKQL_ENSURE(thenDatum.is_array(), "Expecting array");
+ thenArray = thenDatum.array().get();
+ }
+
+ TBlockItem elseItem;
+ const arrow::ArrayData* elseArray = nullptr;
+ if constexpr(ElseIsScalar) {
+ elseItem = ElseReader->GetScalarItem(*elseDatum.scalar());
+ } else {
+ MKQL_ENSURE(elseDatum.is_array(), "Expecting array");
+ elseArray = elseDatum.array().get();
+ }
+
+ MKQL_ENSURE(predDatum.is_array(), "Expecting array");
+ const std::shared_ptr<arrow::ArrayData>& pred = predDatum.array();
+
+ const size_t len = pred->length;
+ auto builder = MakeBlockBuilder(Type, *ctx->memory_pool(), len);
+ const ui8* predValues = pred->GetValues<uint8_t>(1);
+ for (size_t i = 0; i < len; ++i) {
+ if constexpr (!ThenIsScalar) {
+ thenItem = ThenReader->GetItem(*thenArray, i);
+ }
+ if constexpr (!ElseIsScalar) {
+ elseItem = ElseReader->GetItem(*elseArray, i);
+ }
+
+ ui64 mask = -ui64(predValues[i]);
+
+ TBlockItem result;
+ ui64 low = (thenItem.Low() & mask) | (elseItem.Low() & ~mask);
+ ui64 high = (thenItem.High() & mask) | (elseItem.High() & ~mask);
+ builder->Add(TBlockItem{low, high});
+ }
+ *res = builder->Build(true);
+ return arrow::Status::OK();
+ }
+
+private:
+ const std::unique_ptr<IBlockReader> ThenReader;
+ const std::unique_ptr<IBlockReader> ElseReader;
+ TType* const Type;
+};
+
+
+template<bool ThenIsScalar, bool ElseIsScalar>
+std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockIfKernel(const TVector<TType*>& argTypes, TType* resultType) {
+ using TExec = TIfBlockExec<ThenIsScalar, ElseIsScalar>;
+
+ auto exec = std::make_shared<TExec>(AS_TYPE(TBlockType, resultType)->GetItemType());
+ auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
+ [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
+ return exec->Exec(ctx, batch, res);
+ });
+
+ kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
+ return kernel;
+}
+
+} // namespace
+
+IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
+ MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");
+
+ auto pred = callable.GetInput(0);
+ auto thenNode = callable.GetInput(1);
+ auto elseNode = callable.GetInput(2);
+
+ auto predType = AS_TYPE(TBlockType, pred.GetStaticType());
+ MKQL_ENSURE(AS_TYPE(TDataType, predType->GetItemType())->GetSchemeType() == NUdf::TDataType<bool>::Id,
+ "Expected bool as first argument");
+
+ auto thenType = AS_TYPE(TBlockType, thenNode.GetStaticType());
+ auto elseType = AS_TYPE(TBlockType, elseNode.GetStaticType());
+ MKQL_ENSURE(thenType->GetItemType()->IsSameType(*elseType->GetItemType()), "Different return types in branches.");
+
+ auto predCompute = LocateNode(ctx.NodeLocator, callable, 0);
+ auto thenCompute = LocateNode(ctx.NodeLocator, callable, 1);
+ auto elseCompute = LocateNode(ctx.NodeLocator, callable, 2);
+
+ bool predIsScalar = predType->GetShape() == TBlockType::EShape::Scalar;
+ bool thenIsScalar = thenType->GetShape() == TBlockType::EShape::Scalar;
+ bool elseIsScalar = elseType->GetShape() == TBlockType::EShape::Scalar;
+
+ if (predIsScalar) {
+ return new TBlockIfScalarWrapper(ctx.Mutables, predCompute, thenCompute, elseCompute, thenType->GetItemType(),
+ thenIsScalar, elseIsScalar);
+ }
+
+ TVector<IComputationNode*> argsNodes = { predCompute, thenCompute, elseCompute };
+ TVector<TType*> argsTypes = { predType, thenType, elseType };
+
+ std::shared_ptr<arrow::compute::ScalarKernel> kernel;
+ if (thenIsScalar && elseIsScalar) {
+ kernel = MakeBlockIfKernel<true, true>(argsTypes, thenType);
+ } else if (thenIsScalar && !elseIsScalar) {
+ kernel = MakeBlockIfKernel<true, false>(argsTypes, thenType);
+ } else if (!thenIsScalar && elseIsScalar) {
+ kernel = MakeBlockIfKernel<false, true>(argsTypes, thenType);
+ } else {
+ kernel = MakeBlockIfKernel<false, false>(argsTypes, thenType);
+ }
+
+ return new TBlockFuncNode(ctx.Mutables, std::move(argsNodes), argsTypes, *kernel, kernel);
+}
+
+}
+}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_if.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.h
new file mode 100644
index 0000000000..62fc88c2a2
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_if.h
@@ -0,0 +1,10 @@
+#pragma once
+#include <ydb/library/yql/minikql/computation/mkql_computation_node.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactoryContext& ctx);
+
+}
+}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp
new file mode 100644
index 0000000000..f0b709d949
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.cpp
@@ -0,0 +1,183 @@
+#include "mkql_block_impl.h"
+#include "mkql_block_builder.h"
+#include "mkql_block_reader.h"
+
+#include <ydb/library/yql/minikql/arrow/mkql_functions.h>
+#include <ydb/library/yql/minikql/mkql_node_builder.h>
+#include <ydb/library/yql/minikql/arrow/arrow_util.h>
+
+#include <arrow/compute/exec_internal.h>
+
+namespace NKikimr::NMiniKQL {
+
+namespace {
+
+class TArgsDechunker {
+public:
+ explicit TArgsDechunker(std::vector<arrow::Datum>&& args)
+ : Args(std::move(args))
+ , Arrays(Args.size())
+ {
+ for (size_t i = 0; i < Args.size(); ++i) {
+ if (Args[i].is_arraylike()) {
+ ForEachArrayData(Args[i], [&](const auto& data) {
+ Arrays[i].push_back(data);
+ });
+ }
+ }
+ }
+
+ bool Next(std::vector<arrow::Datum>& chunk) {
+ if (Finish) {
+ return false;
+ }
+
+ size_t minSize = Max<size_t>();
+ bool haveData = false;
+ chunk.resize(Args.size());
+ for (size_t i = 0; i < Args.size(); ++i) {
+ if (Args[i].is_scalar()) {
+ chunk[i] = Args[i];
+ continue;
+ }
+ while (!Arrays[i].empty() && Arrays[i].front()->length == 0) {
+ Arrays[i].pop_front();
+ }
+ if (!Arrays[i].empty()) {
+ haveData = true;
+ minSize = std::min<size_t>(minSize, Arrays[i].front()->length);
+ } else {
+ minSize = 0;
+ }
+ }
+
+ MKQL_ENSURE(!haveData || minSize > 0, "Block length mismatch");
+ if (!haveData) {
+ Finish = true;
+ return false;
+ }
+
+ for (size_t i = 0; i < Args.size(); ++i) {
+ if (!Args[i].is_scalar()) {
+ MKQL_ENSURE(!Arrays[i].empty(), "Block length mismatch");
+ chunk[i] = arrow::Datum(Chop(Arrays[i].front(), minSize));
+ }
+ }
+ return true;
+ }
+private:
+ const std::vector<arrow::Datum> Args;
+ std::vector<std::deque<std::shared_ptr<arrow::ArrayData>>> Arrays;
+ bool Finish = false;
+};
+
+std::vector<arrow::ValueDescr> ToValueDescr(const TVector<TType*>& types) {
+ std::vector<arrow::ValueDescr> res;
+ res.reserve(types.size());
+ for (const auto& type : types) {
+ res.emplace_back(ToValueDescr(type));
+ }
+
+ return res;
+}
+
+} // namespace
+
+arrow::Datum MakeArrayFromScalar(const arrow::Scalar& scalar, size_t len, TType* type, arrow::MemoryPool& pool) {
+ MKQL_ENSURE(len > 0, "Invalid block size");
+ auto reader = MakeBlockReader(type);
+ auto builder = MakeBlockBuilder(type, pool, len);
+
+ auto scalarItem = reader->GetScalarItem(scalar);
+ for (size_t i = 0; i < len; ++i) {
+ builder->Add(scalarItem);
+ }
+
+ return builder->Build(true);
+}
+
+arrow::ValueDescr ToValueDescr(TType* type) {
+ arrow::ValueDescr ret;
+ MKQL_ENSURE(ConvertInputArrowType(type, ret), "can't get arrow type");
+ return ret;
+}
+
+std::vector<arrow::compute::InputType> ConvertToInputTypes(const TVector<TType*>& argTypes) {
+ std::vector<arrow::compute::InputType> result;
+ result.reserve(argTypes.size());
+ for (auto& type : argTypes) {
+ result.emplace_back(ToValueDescr(type));
+ }
+ return result;
+}
+
+arrow::compute::OutputType ConvertToOutputType(TType* output) {
+ return arrow::compute::OutputType(ToValueDescr(output));
+}
+
+TBlockFuncNode::TBlockFuncNode(TComputationMutables& mutables, TVector<IComputationNode*>&& argsNodes,
+ const TVector<TType*>& argsTypes, const arrow::compute::ScalarKernel& kernel,
+ std::shared_ptr<arrow::compute::ScalarKernel> kernelHolder,
+ const arrow::compute::FunctionOptions* functionOptions)
+ : TMutableComputationNode(mutables)
+ , StateIndex(mutables.CurValueIndex++)
+ , ArgsNodes(std::move(argsNodes))
+ , ArgsValuesDescr(ToValueDescr(argsTypes))
+ , Kernel(kernel)
+ , KernelHolder(std::move(kernelHolder))
+ , Options(functionOptions)
+ , ScalarOutput(GetResultShape(argsTypes) == TBlockType::EShape::Scalar)
+{
+}
+
+NUdf::TUnboxedValuePod TBlockFuncNode::DoCalculate(TComputationContext& ctx) const {
+ auto& state = GetState(ctx);
+
+ std::vector<arrow::Datum> argDatums;
+ for (ui32 i = 0; i < ArgsNodes.size(); ++i) {
+ argDatums.emplace_back(TArrowBlock::From(ArgsNodes[i]->GetValue(ctx)).GetDatum());
+ Y_VERIFY_DEBUG(ArgsValuesDescr[i] == argDatums.back().descr());
+ }
+
+ auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
+ ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, Options }));
+
+ if (ScalarOutput) {
+ auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
+ ARROW_OK(executor->Execute(argDatums, listener.get()));
+ auto output = executor->WrapResults(argDatums, listener->values());
+ return ctx.HolderFactory.CreateArrowBlock(std::move(output));
+ }
+
+ TArgsDechunker dechunker(std::move(argDatums));
+ std::vector<arrow::Datum> chunk;
+ TVector<std::shared_ptr<arrow::ArrayData>> arrays;
+
+ while (dechunker.Next(chunk)) {
+ arrow::compute::detail::DatumAccumulator listener;
+ ARROW_OK(executor->Execute(chunk, &listener));
+ auto output = executor->WrapResults(chunk, listener.values());
+
+ ForEachArrayData(output, [&](const auto& arr) { arrays.push_back(arr); });
+ }
+
+ return ctx.HolderFactory.CreateArrowBlock(MakeArray(arrays));
+}
+
+
+void TBlockFuncNode::RegisterDependencies() const {
+ for (const auto& arg : ArgsNodes) {
+ DependsOn(arg);
+ }
+}
+
+TBlockFuncNode::TState& TBlockFuncNode::GetState(TComputationContext& ctx) const {
+ auto& result = ctx.MutableValues[StateIndex];
+ if (!result.HasValue()) {
+ result = ctx.HolderFactory.Create<TState>(Options, Kernel, ArgsValuesDescr, ctx);
+ }
+
+ return *static_cast<TState*>(result.AsBoxed().Get());
+}
+
+}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h
index 58885a2b1e..3e82ae1e61 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_impl.h
@@ -6,9 +6,61 @@
#include <ydb/library/yql/minikql/arrow/arrow_util.h>
#include <arrow/array.h>
+#include <arrow/scalar.h>
+#include <arrow/datum.h>
+#include <arrow/compute/kernel.h>
namespace NKikimr::NMiniKQL {
+arrow::Datum MakeArrayFromScalar(const arrow::Scalar& scalar, size_t len, TType* type, arrow::MemoryPool& pool);
+
+arrow::ValueDescr ToValueDescr(TType* type);
+
+std::vector<arrow::compute::InputType> ConvertToInputTypes(const TVector<TType*>& argTypes);
+arrow::compute::OutputType ConvertToOutputType(TType* output);
+
+class TBlockFuncNode : public TMutableComputationNode<TBlockFuncNode> {
+public:
+ TBlockFuncNode(TComputationMutables& mutables, TVector<IComputationNode*>&& argsNodes,
+ const TVector<TType*>& argsTypes, const arrow::compute::ScalarKernel& kernel,
+ std::shared_ptr<arrow::compute::ScalarKernel> kernelHolder = {},
+ const arrow::compute::FunctionOptions* functionOptions = nullptr);
+
+ NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const;
+private:
+ struct TState : public TComputationValue<TState> {
+ using TComputationValue::TComputationValue;
+
+ TState(TMemoryUsageInfo* memInfo, const arrow::compute::FunctionOptions* options,
+ const arrow::compute::ScalarKernel& kernel, const std::vector<arrow::ValueDescr>& argsValuesDescr,
+ TComputationContext& ctx)
+ : TComputationValue(memInfo)
+ , ExecContext(&ctx.ArrowMemoryPool, nullptr, nullptr)
+ , KernelContext(&ExecContext)
+ {
+ if (kernel.init) {
+ State = ARROW_RESULT(kernel.init(&KernelContext, { &kernel, argsValuesDescr, options }));
+ KernelContext.SetState(State.get());
+ }
+ }
+
+ arrow::compute::ExecContext ExecContext;
+ arrow::compute::KernelContext KernelContext;
+ std::unique_ptr<arrow::compute::KernelState> State;
+ };
+
+ void RegisterDependencies() const final;
+ TState& GetState(TComputationContext& ctx) const;
+private:
+ const ui32 StateIndex;
+ const TVector<IComputationNode*> ArgsNodes;
+ const std::vector<arrow::ValueDescr> ArgsValuesDescr;
+ const arrow::compute::ScalarKernel& Kernel;
+ const std::shared_ptr<arrow::compute::ScalarKernel> KernelHolder;
+ const arrow::compute::FunctionOptions* const Options;
+ const bool ScalarOutput;
+};
+
template <typename TDerived>
class TStatefulWideFlowBlockComputationNode: public TWideFlowBaseComputationNode<TDerived>
{
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_item.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_item.h
index 84924977f0..37c357fcc9 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_item.h
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_item.h
@@ -36,6 +36,19 @@ public:
Raw.Simple.Meta = static_cast<ui8>(EMarkers::Present);
}
+ inline TBlockItem(ui64 low, ui64 high) {
+ Raw.Halfs[0] = low;
+ Raw.Halfs[1] = high;
+ }
+
+ inline ui64 Low() const {
+ return Raw.Halfs[0];
+ }
+
+ inline ui64 High() const {
+ return Raw.Halfs[1];
+ }
+
template <typename T, typename = std::enable_if_t<NYql::NUdf::TPrimitiveDataType<T>::Result>>
inline T As() const;
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
index 789cb67df2..2e02a6a9c6 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
@@ -43,7 +43,7 @@ public:
builder->Add(result);
}
- return builder->Build(ctx, true);
+ return ctx.HolderFactory.CreateArrowBlock(builder->Build(true));
}
private:
@@ -98,7 +98,7 @@ public:
for (size_t i = 0; i < Width_; ++i) {
if (auto* out = output[i]; out != nullptr) {
- *out = s.Builders_[i]->Build(ctx, s.IsFinished_);
+ *out = ctx.HolderFactory.CreateArrowBlock(s.Builders_[i]->Build(s.IsFinished_));
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
index e29c051c97..ddeecfa047 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
@@ -8,6 +8,7 @@
#include "mkql_blocks.h"
#include "mkql_block_agg.h"
#include "mkql_block_coalesce.h"
+#include "mkql_block_if.h"
#include "mkql_block_logical.h"
#include "mkql_block_compress.h"
#include "mkql_block_skiptake.h"
@@ -276,6 +277,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller {
{"WideTakeBlocks", &WrapWideTakeBlocks},
{"AsScalar", &WrapAsScalar},
{"BlockCoalesce", &WrapBlockCoalesce},
+ {"BlockIf", &WrapBlockIf},
{"BlockAnd", &WrapBlockAnd},
{"BlockOr", &WrapBlockOr},
{"BlockXor", &WrapBlockXor},
diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp
index 6eb2f14802..a4f25520f0 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_program_builder.cpp
@@ -5312,6 +5312,23 @@ TRuntimeNode TProgramBuilder::PgInternal0(TType* returnType) {
return TRuntimeNode(callableBuilder.Build(), false);
}
+TRuntimeNode TProgramBuilder::BlockIf(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch) {
+ const auto conditionType = AS_TYPE(TBlockType, condition.GetStaticType());
+ MKQL_ENSURE(AS_TYPE(TDataType, conditionType->GetItemType())->GetSchemeType() == NUdf::TDataType<bool>::Id,
+ "Expected bool as first argument");
+
+ const auto thenType = AS_TYPE(TBlockType, thenBranch.GetStaticType());
+ const auto elseType = AS_TYPE(TBlockType, elseBranch.GetStaticType());
+ MKQL_ENSURE(thenType->GetItemType()->IsSameType(*elseType->GetItemType()), "Different return types in branches.");
+
+ auto returnType = NewBlockType(thenType->GetItemType(), GetResultShape({conditionType, thenType, elseType}));
+ TCallableBuilder callableBuilder(Env, __func__, returnType);
+ callableBuilder.Add(condition);
+ callableBuilder.Add(thenBranch);
+ callableBuilder.Add(elseBranch);
+ return TRuntimeNode(callableBuilder.Build(), false);
+}
+
TRuntimeNode TProgramBuilder::BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args) {
for (const auto& arg : args) {
MKQL_ENSURE(arg.GetStaticType()->IsBlock(), "Expected Block type");
diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h
index 99b4233201..0e2076e798 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.h
+++ b/ydb/library/yql/minikql/mkql_program_builder.h
@@ -257,6 +257,7 @@ public:
TRuntimeNode BlockOr(TRuntimeNode first, TRuntimeNode second);
TRuntimeNode BlockXor(TRuntimeNode first, TRuntimeNode second);
+ TRuntimeNode BlockIf(TRuntimeNode condition, TRuntimeNode thenBranch, TRuntimeNode elseBranch);
TRuntimeNode BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args);
TRuntimeNode BlockBitCast(TRuntimeNode value, TType* targetType);
TRuntimeNode BlockCombineAll(TRuntimeNode flow, std::optional<ui32> filterColumn,
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index 7531252eaa..c0f1254586 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -552,7 +552,9 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
{"ListFromRange", &TProgramBuilder::ListFromRange},
- {"PreserveStream", &TProgramBuilder::PreserveStream}
+ {"PreserveStream", &TProgramBuilder::PreserveStream},
+
+ {"BlockIf", &TProgramBuilder::BlockIf},
});
AddSimpleCallables({