summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author: aneporada <[email protected]> 2022-10-30 10:43:33 +0300
committer: aneporada <[email protected]> 2022-10-30 10:43:33 +0300
commit: 9d1d72f5a7fd1f6f3649918ce82af1fb42d2d8be (patch)
tree: 9071d235b3c45bb79a25c1a79d31b1b8e032214d
parent: a34b8d13197580c648158c2d2a29812d0b211713 (diff)
Support Wide{Take/Skip}Blocks
-rw-r--r-- ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp | 43
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_core.cpp | 2
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_wide.cpp | 42
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_wide.h | 2
-rw-r--r-- ydb/library/yql/core/yql_expr_type_annotation.cpp | 32
-rw-r--r-- ydb/library/yql/core/yql_expr_type_annotation.h | 1
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/CMakeLists.txt | 1
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp | 164
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.h | 11
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp | 3
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/ut/mkql_block_skiptake_ut.cpp | 179
-rw-r--r-- ydb/library/yql/minikql/mkql_program_builder.cpp | 61
-rw-r--r-- ydb/library/yql/minikql/mkql_program_builder.h | 3
-rw-r--r-- ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp | 3
14 files changed, 508 insertions, 39 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
index b061c7c18f3..aadfc855c62 100644
--- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
+++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp
@@ -4563,6 +4563,47 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext&
.Build();
}
+// Peephole rule for "Skip"/"Take": when the input is a wide flow (Flow of Multi) whose
+// columns are not blocks yet and are all supported by the Arrow resolver, rewrites
+//     Skip/Take(input, count)
+// into
+//     WideFromBlocks(WideSkipBlocks/WideTakeBlocks(WideToBlocks(input), count)).
+// In every other case the node is returned unchanged.
+TExprNode::TPtr OptimizeSkipTakeToBlocks(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& types) {
+    // Without a resolver there is no way to check that the column types are supported.
+    if (!types.ArrowResolver) {
+        return node;
+    }
+
+    const TTypeAnnotationNode* inputType = node->Head().GetTypeAnn();
+    if (inputType->GetKind() != ETypeAnnotationKind::Flow) {
+        return node;
+    }
+
+    const TTypeAnnotationNode* itemType = inputType->Cast<TFlowExprType>()->GetItemType();
+    if (itemType->GetKind() != ETypeAnnotationKind::Multi) {
+        return node;
+    }
+
+    const auto& allTypes = itemType->Cast<TMultiExprType>()->GetItems();
+    const bool hasBlockColumn = AnyOf(allTypes, [](const TTypeAnnotationNode* columnType) {
+        return columnType->GetKind() == ETypeAnnotationKind::Block;
+    });
+    if (hasBlockColumn) {
+        return node;
+    }
+
+    bool supported = false;
+    YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Head().Pos()),
+        TVector<const TTypeAnnotationNode*>(allTypes.begin(), allTypes.end()),
+        supported, ctx));
+    if (!supported) {
+        return node;
+    }
+
+    // This rule is registered for "Skip" and "Take" only, so anything that is not "Skip" is "Take".
+    const bool isSkip = node->Content() == "Skip";
+    TStringBuf newName = isSkip ? "WideSkipBlocks" : "WideTakeBlocks";
+    YQL_CLOG(DEBUG, CorePeepHole) << "Convert " << node->Content() << " to " << newName;
+
+    return ctx.Builder(node->Pos())
+        .Callable("WideFromBlocks")
+            .Callable(0, newName)
+                .Callable(0, "WideToBlocks")
+                    .Add(0, node->HeadPtr())
+                .Seal()
+                .Add(1, node->ChildPtr(1))
+            .Seal()
+        .Seal()
+        .Build();
+}
+
TExprNode::TPtr OptimizeWideMaps(const TExprNode::TPtr& node, TExprContext& ctx) {
if (const auto& input = node->Head(); input.IsCallable("ExpandMap")) {
YQL_CLOG(DEBUG, CorePeepHole) << "Fuse " << node->Content() << " with " << input.Content();
@@ -6233,6 +6274,8 @@ struct TPeepHoleRules {
{"WideMap", &OptimizeWideMapBlocks},
{"NarrowMap", &OptimizeWideMapBlocks},
{"WideToBlocks", &OptimizeWideToBlocks},
+ {"Skip", &OptimizeSkipTakeToBlocks},
+ {"Take", &OptimizeSkipTakeToBlocks},
};
TPeepHoleRules()
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index c8dd50f9023..4ec696f19c2 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -11691,6 +11691,8 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["WideToBlocks"] = &WideToBlocksWrapper;
Functions["WideFromBlocks"] = &WideFromBlocksWrapper;
+ Functions["WideSkipBlocks"] = &WideSkipTakeBlocksWrapper;
+ Functions["WideTakeBlocks"] = &WideSkipTakeBlocksWrapper;
Functions["AsScalar"] = &AsScalarWrapper;
ExtFunctions["BlockFunc"] = &BlockFuncWrapper;
ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper;
diff --git a/ydb/library/yql/core/type_ann/type_ann_wide.cpp b/ydb/library/yql/core/type_ann/type_ann_wide.cpp
index 93e5622e187..6e51ff57c79 100644
--- a/ydb/library/yql/core/type_ann/type_ann_wide.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_wide.cpp
@@ -658,40 +658,44 @@ IGraphTransformer::TStatus WideFromBlocksWrapper(const TExprNode::TPtr& input, T
return IGraphTransformer::TStatus::Error;
}
- if (!EnsureWideFlowType(input->Head(), ctx.Expr)) {
+ TTypeAnnotationNode::TListType retMultiType;
+ if (!EnsureWideFlowBlockType(input->Head(), retMultiType, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
- const auto multiType = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>();
- if (multiType->GetSize() == 0) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Expected at least one column"));
+ YQL_ENSURE(!retMultiType.empty());
+ retMultiType.pop_back();
+ auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType);
+ input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType));
+ return IGraphTransformer::TStatus::Ok;
+}
+
+IGraphTransformer::TStatus WideSkipTakeBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
+ if (!EnsureArgsCount(*input, 2U, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
- TTypeAnnotationNode::TListType retMultiType;
- bool isScalar;
- for (const auto& type : multiType->GetItems()) {
- if (!EnsureBlockOrScalarType(input->Pos(), *type, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
- }
-
- retMultiType.push_back(GetBlockItemType(*type, isScalar));
+ TTypeAnnotationNode::TListType blockItemTypes;
+ if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
+ return IGraphTransformer::TStatus::Error;
}
- if (!isScalar) {
- ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Last column should be a scalar"));
+ output = input;
+ const TTypeAnnotationNode* expectedType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64);
+ auto convertStatus = TryConvertTo(input->ChildRef(1), *expectedType, ctx.Expr);
+ if (convertStatus.Level == IGraphTransformer::TStatus::Error) {
+ ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(1)->Pos()), "Can not convert argument to Uint64"));
return IGraphTransformer::TStatus::Error;
}
- if (!EnsureSpecificDataType(input->Pos(), *retMultiType.back(), EDataSlot::Uint64, ctx.Expr)) {
- return IGraphTransformer::TStatus::Error;
+ if (convertStatus.Level != IGraphTransformer::TStatus::Ok) {
+ return convertStatus;
}
- retMultiType.pop_back();
- auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType);
- input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType));
+ input->SetTypeAnn(input->Head().GetTypeAnn());
return IGraphTransformer::TStatus::Ok;
}
+
} // namespace NTypeAnnImpl
}
diff --git a/ydb/library/yql/core/type_ann/type_ann_wide.h b/ydb/library/yql/core/type_ann/type_ann_wide.h
index b02cf26a2eb..f9c3227f681 100644
--- a/ydb/library/yql/core/type_ann/type_ann_wide.h
+++ b/ydb/library/yql/core/type_ann/type_ann_wide.h
@@ -23,5 +23,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus WideToBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus WideFromBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+
+ IGraphTransformer::TStatus WideSkipTakeBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
} // namespace NTypeAnnImpl
} // namespace NYql
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp
index 1a286d143fb..617e9166c6c 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.cpp
+++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp
@@ -2677,6 +2677,38 @@ bool EnsureWideFlowType(TPositionHandle position, const TTypeAnnotationNode& typ
return true;
}
+// Verifies that `node` is annotated as a wide flow of blocks:
+//   * it is a wide flow (checked by EnsureWideFlowType),
+//   * it has at least one column,
+//   * every column is a block or a scalar,
+//   * the last column is a scalar of type Uint64.
+//
+// The item type of each column (the type inside the block/scalar wrapper) is appended
+// to `blockItemTypes` in column order; the list is not cleared first.
+//
+// Returns true on success. On failure an issue is added to `ctx` and false is returned;
+// `blockItemTypes` may then contain a partial result.
+bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx) {
+    if (!EnsureWideFlowType(node, ctx)) {
+        return false;
+    }
+
+    auto& items = node.GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
+    if (items.empty()) {
+        ctx.AddError(TIssue(ctx.GetPosition(node.Pos()), "Expected at least one column"));
+        // This function reports failure as `false`. A transformer status enumerator must not be
+        // returned here: it would be converted to bool by value and could read as success.
+        return false;
+    }
+
+    // Overwritten for every column; after the loop it describes the last column.
+    bool isScalar = false;
+    for (const auto& type : items) {
+        if (!EnsureBlockOrScalarType(node.Pos(), *type, ctx)) {
+            return false;
+        }
+
+        blockItemTypes.push_back(GetBlockItemType(*type, isScalar));
+    }
+
+    if (!isScalar) {
+        ctx.AddError(TIssue(ctx.GetPosition(node.Pos()), "Last column should be a scalar"));
+        return false;
+    }
+
+    if (!EnsureSpecificDataType(node.Pos(), *blockItemTypes.back(), EDataSlot::Uint64, ctx)) {
+        return false;
+    }
+
+    return true;
+}
+
bool EnsureOptionalType(const TExprNode& node, TExprContext& ctx) {
if (!node.GetTypeAnn()) {
YQL_ENSURE(node.Type() == TExprNode::Lambda);
diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h
index 4392a4b5504..781d6d13fba 100644
--- a/ydb/library/yql/core/yql_expr_type_annotation.h
+++ b/ydb/library/yql/core/yql_expr_type_annotation.h
@@ -118,6 +118,7 @@ bool EnsureFlowType(const TExprNode& node, TExprContext& ctx);
bool EnsureFlowType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx);
bool EnsureWideFlowType(const TExprNode& node, TExprContext& ctx);
bool EnsureWideFlowType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx);
+bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx);
bool EnsureOptionalType(const TExprNode& node, TExprContext& ctx);
bool EnsureOptionalType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx);
bool EnsureType(const TExprNode& node, TExprContext& ctx);
diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt
index 7312d9050f3..0dea4f84dc8 100644
--- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt
+++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt
@@ -39,6 +39,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_callable.cpp
${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_chain_map.cpp
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp
new file mode 100644
index 00000000000..2fb510eb99e
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp
@@ -0,0 +1,164 @@
+#include "mkql_block_skiptake.h"
+
+#include <ydb/library/yql/minikql/arrow/arrow_defs.h>
+#include <ydb/library/yql/minikql/mkql_type_builder.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
+#include <ydb/library/yql/minikql/mkql_node_builder.h>
+#include <ydb/library/yql/minikql/mkql_node_cast.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+namespace {
+
+// Wide flow computation node that drops the first `Count` rows of a block flow.
+// The last column (index Width - 1) of every fetched item is read as a UInt64 scalar
+// holding the block length. The node state holds the number of rows still to be skipped.
+class TWideSkipBlocksWrapper: public TStatefulWideFlowComputationNode<TWideSkipBlocksWrapper> {
+ typedef TStatefulWideFlowComputationNode<TWideSkipBlocksWrapper> TBaseComputation;
+
+public:
+ // flow  - input wide flow of blocks
+ // count - node producing the ui64 number of rows to skip
+ // width - number of columns in the flow, including the trailing block-length column
+ TWideSkipBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, size_t width)
+ : TBaseComputation(mutables, flow, EValueRepresentation::Any)
+ , Flow(flow)
+ , Count(count)
+ , Width(width)
+ {
+ }
+
+ // Fetches blocks from the input until the skip counter is exhausted.
+ // Blocks that fit entirely into the remaining counter are discarded; the first block that is
+ // longer than the remaining counter is sliced so that only its tail is returned.
+ // Any non-One result of the input (and every result once the counter is zero) is forwarded as is.
+ EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
+ // First call: initialize the remaining-rows counter from the Count node.
+ if (!state.HasValue()) {
+ state = Count->GetValue(ctx);
+ }
+
+ auto count = state.Get<ui64>();
+ ui64 blockSize = 0;
+ for (;;) {
+ auto result = Flow->FetchValues(ctx, output);
+ if (count == 0 || result != EFetchResult::One) {
+ return result;
+ }
+
+ blockSize = TArrowBlock::From(*output[Width - 1]).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
+ if (blockSize > count) {
+ break;
+ }
+ // The whole block is skipped; persist the decreased counter in the state before
+ // fetching again, so it is kept if the next fetch returns something other than One.
+ count -= blockSize;
+ state = NUdf::TUnboxedValuePod(count);
+ }
+
+ // Only the first `count` rows of this block are skipped; keep the remaining tail.
+ ui64 tailSize = blockSize - count;
+ for (size_t i = 0; i < Width - 1; ++i) {
+ auto& datum = TArrowBlock::From(*output[i]).GetDatum();
+ // Scalar columns are not sliced.
+ if (datum.is_scalar()) {
+ continue;
+ }
+
+ Y_VERIFY_DEBUG(datum.is_array());
+ *output[i] = ctx.HolderFactory.CreateArrowBlock(datum.array()->Slice(count, tailSize));
+ }
+
+ // Replace the block-length column with the length of the tail and mark skipping as finished.
+ *output[Width - 1] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(static_cast<uint64_t>(tailSize)));
+ state = NUdf::TUnboxedValuePod::Zero();
+ return EFetchResult::One;
+ }
+
+private:
+ void RegisterDependencies() const final {
+ if (const auto flow = FlowDependsOn(Flow)) {
+ DependsOn(flow, Count);
+ }
+ }
+
+ IComputationWideFlowNode* const Flow;
+ IComputationNode* const Count;
+ const size_t Width;
+};
+
+// Wide flow computation node that passes through at most `Count` rows of a block flow.
+// The last column (index Width - 1) of every fetched item is read as a UInt64 scalar
+// holding the block length. The node state holds the number of rows still allowed through.
+class TWideTakeBlocksWrapper: public TStatefulWideFlowComputationNode<TWideTakeBlocksWrapper> {
+    using TBaseComputation = TStatefulWideFlowComputationNode<TWideTakeBlocksWrapper>;
+
+public:
+    // flow  - input wide flow of blocks
+    // count - node producing the ui64 number of rows to take
+    // width - number of columns in the flow, including the trailing block-length column
+    TWideTakeBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, size_t width)
+        : TBaseComputation(mutables, flow, EValueRepresentation::Any)
+        , Flow(flow)
+        , Count(count)
+        , Width(width)
+    {
+    }
+
+    // Returns Finish once the counter reaches zero (without touching the input); otherwise
+    // forwards the next input result, truncating the block that crosses the limit.
+    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
+        // First call: initialize the remaining-rows counter from the Count node.
+        if (!state.HasValue()) {
+            state = Count->GetValue(ctx);
+        }
+
+        const ui64 remaining = state.Get<ui64>();
+        if (remaining == 0) {
+            return EFetchResult::Finish;
+        }
+
+        const auto fetched = Flow->FetchValues(ctx, output);
+        if (fetched != EFetchResult::One) {
+            return fetched;
+        }
+
+        const ui64 blockSize = TArrowBlock::From(*output[Width - 1]).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
+        if (blockSize <= remaining) {
+            // The whole block fits into the limit: pass it through untouched.
+            state = NUdf::TUnboxedValuePod(remaining - blockSize);
+            return fetched;
+        }
+
+        // The block crosses the limit: keep only its first `remaining` rows.
+        for (size_t i = 0; i < Width - 1; ++i) {
+            auto& datum = TArrowBlock::From(*output[i]).GetDatum();
+            if (!datum.is_scalar()) {
+                Y_VERIFY_DEBUG(datum.is_array());
+                *output[i] = ctx.HolderFactory.CreateArrowBlock(datum.array()->Slice(0, remaining));
+            }
+        }
+
+        *output[Width - 1] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(static_cast<uint64_t>(remaining)));
+        state = NUdf::TUnboxedValuePod::Zero();
+        return fetched;
+    }
+
+private:
+    void RegisterDependencies() const final {
+        if (const auto flow = FlowDependsOn(Flow)) {
+            DependsOn(flow, Count);
+        }
+    }
+
+    IComputationWideFlowNode* const Flow;
+    IComputationNode* const Count;
+    const size_t Width;
+};
+
+// Shared factory for the wide skip/take block nodes.
+// Validates the callable (two inputs: a non-empty wide flow and a ui64 count) and creates
+// TWideSkipBlocksWrapper when `skip` is true, TWideTakeBlocksWrapper otherwise.
+IComputationNode* WrapSkipTake(bool skip, TCallable& callable, const TComputationNodeFactoryContext& ctx) {
+    MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
+
+    const auto inputFlowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
+    const auto tupleType = AS_TYPE(TTupleType, inputFlowType->GetItemType());
+    MKQL_ENSURE(tupleType->GetElementsCount() > 0, "Expected at least one column");
+
+    auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
+    MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
+
+    const auto countNode = LocateNode(ctx.NodeLocator, callable, 1);
+    const auto countType = AS_TYPE(TDataType, callable.GetInput(1).GetStaticType());
+    MKQL_ENSURE(countType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
+
+    IComputationNode* wrapper = nullptr;
+    if (skip) {
+        wrapper = new TWideSkipBlocksWrapper(ctx.Mutables, wideFlow, countNode, tupleType->GetElementsCount());
+    } else {
+        wrapper = new TWideTakeBlocksWrapper(ctx.Mutables, wideFlow, countNode, tupleType->GetElementsCount());
+    }
+    return wrapper;
+}
+
+} //namespace
+
+// Creates the computation node that skips leading rows of a wide block flow.
+IComputationNode* WrapWideSkipBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
+    return WrapSkipTake(/* skip = */ true, callable, ctx);
+}
+
+// Creates the computation node that limits the number of rows of a wide block flow.
+IComputationNode* WrapWideTakeBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
+    return WrapSkipTake(/* skip = */ false, callable, ctx);
+}
+
+}
+} \ No newline at end of file
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.h
new file mode 100644
index 00000000000..79d20f0cb77
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.h
@@ -0,0 +1,11 @@
+#pragma once
+#include <ydb/library/yql/minikql/computation/mkql_computation_node.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+IComputationNode* WrapWideSkipBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx);
+IComputationNode* WrapWideTakeBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx);
+
+}
+}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
index 5348a807530..8cd9773020c 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
@@ -7,6 +7,7 @@
#include "mkql_block_func.h"
#include "mkql_blocks.h"
#include "mkql_block_agg.h"
+#include "mkql_block_skiptake.h"
#include "mkql_callable.h"
#include "mkql_chain_map.h"
#include "mkql_chain1_map.h"
@@ -268,6 +269,8 @@ struct TCallableComputationNodeBuilderFuncMapFiller {
{"BlockBitCast", &WrapBlockBitCast},
{"FromBlocks", &WrapFromBlocks},
{"WideFromBlocks", &WrapWideFromBlocks},
+ {"WideSkipBlocks", &WrapWideSkipBlocks},
+ {"WideTakeBlocks", &WrapWideTakeBlocks},
{"AsScalar", &WrapAsScalar},
{"BlockCombineAll", &WrapBlockCombineAll},
{"MakeHeap", &WrapMakeHeap},
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_block_skiptake_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_block_skiptake_ut.cpp
new file mode 100644
index 00000000000..4b12ffc7c35
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_block_skiptake_ut.cpp
@@ -0,0 +1,179 @@
+#include "mkql_computation_node_ut.h"
+
+#include <ydb/library/yql/minikql/arrow/arrow_defs.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
+
+#include <arrow/array/builder_primitive.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+namespace {
+
+// Test source node: emits `BlockCount` items of a three-column wide block flow.
+//   column 0 - UInt64 array of `BlockSize` consecutive values, continuing across blocks
+//   column 1 - UInt64 scalar with the zero-based index of the block
+//   column 2 - UInt64 scalar with the block length
+// The node state stores the index of the next block to produce.
+class TTestBlockFlowWrapper: public TStatefulWideFlowComputationNode<TTestBlockFlowWrapper> {
+    using TBaseComputation = TStatefulWideFlowComputationNode<TTestBlockFlowWrapper>;
+
+public:
+    TTestBlockFlowWrapper(TComputationMutables& mutables, size_t blockSize, size_t blockCount)
+        : TBaseComputation(mutables, nullptr, EValueRepresentation::Any)
+        , BlockSize(blockSize)
+        , BlockCount(blockCount)
+    {
+    }
+
+    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
+        if (!state.HasValue()) {
+            state = NUdf::TUnboxedValue::Zero();
+        }
+
+        const ui64 index = state.Get<ui64>();
+        if (index >= BlockCount) {
+            return EFetchResult::Finish;
+        }
+
+        arrow::UInt64Builder builder(&ctx.ArrowMemoryPool);
+        ARROW_OK(builder.Reserve(BlockSize));
+        // Space was reserved above, so the unchecked append is used.
+        for (size_t offset = 0; offset < BlockSize; ++offset) {
+            builder.UnsafeAppend(index * BlockSize + offset);
+        }
+
+        std::shared_ptr<arrow::ArrayData> block;
+        ARROW_OK(builder.FinishInternal(&block));
+
+        *output[0] = ctx.HolderFactory.CreateArrowBlock(std::move(block));
+        *output[1] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(index)));
+        *output[2] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(BlockSize)));
+
+        // Remember which block comes next.
+        const ui64 nextIndex = index + 1;
+        state = NUdf::TUnboxedValuePod(nextIndex);
+        return EFetchResult::One;
+    }
+
+private:
+    // The node has no inputs, so there is nothing to register.
+    void RegisterDependencies() const final {
+    }
+
+    const size_t BlockSize;
+    const size_t BlockCount;
+};
+
+// Factory for the "TestBlockFlow" callable: a source producing 2 blocks of 5 rows each.
+IComputationNode* WrapTestBlockFlow(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
+    MKQL_ENSURE(callable.GetInputsCount() == 0, "Expected no args");
+
+    const size_t blockSize = 5;
+    const size_t blockCount = 2;
+    return new TTestBlockFlowWrapper(ctx.Mutables, blockSize, blockCount);
+}
+
+// Random provider for the test graph; deterministic, created with the fixed argument 1.
+TIntrusivePtr<IRandomProvider> CreateRandomProvider() {
+    TIntrusivePtr<IRandomProvider> provider = CreateDeterministicRandomProvider(1);
+    return provider;
+}
+
+// Time provider for the test graph; deterministic, created with the fixed argument 10000000.
+TIntrusivePtr<ITimeProvider> CreateTimeProvider() {
+    TIntrusivePtr<ITimeProvider> provider = CreateDeterministicTimeProvider(10000000);
+    return provider;
+}
+
+// Computation node factory for the test: handles the "TestBlockFlow" callable itself
+// and delegates every other callable to the builtin factory.
+TComputationNodeFactory GetTestFactory() {
+    return [](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* {
+        const bool isTestSource = callable.GetType()->GetName() == "TestBlockFlow";
+        return isTestSource
+            ? WrapTestBlockFlow(callable, ctx)
+            : GetBuiltinFactory()(callable, ctx);
+    };
+}
+
+// Test fixture bundling everything needed to build and run a MiniKQL program:
+// allocator, type environment, program builder, function registry and the
+// random/time providers.
+struct TSetup_ {
+ TSetup_()
+ : Alloc(__LOCATION__)
+ {
+ FunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry());
+ RandomProvider = CreateRandomProvider();
+ TimeProvider = CreateTimeProvider();
+
+ // Env is created on Alloc and the program builder is created on Env.
+ Env.Reset(new TTypeEnvironment(Alloc));
+ PgmBuilder.Reset(new TProgramBuilder(*Env, *FunctionRegistry));
+ }
+
+ // Builds a computation graph for `pgm`:
+ // walks the program with Explorer, creates a computation pattern using the test node
+ // factory (GetTestFactory), stores it in Pattern and returns a graph cloned from it.
+ TAutoPtr<IComputationGraph> BuildGraph(TRuntimeNode pgm, EGraphPerProcess graphPerProcess = EGraphPerProcess::Multi, const std::vector<TNode*>& entryPoints = std::vector<TNode*>()) {
+ Explorer.Walk(pgm.GetNode(), *Env);
+ TComputationPatternOpts opts(Alloc.Ref(), *Env, GetTestFactory(), FunctionRegistry.Get(),
+ NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception, "OFF", graphPerProcess);
+ Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts);
+ return Pattern->Clone(opts.ToComputationOptions(*RandomProvider, *TimeProvider));
+ }
+
+ TIntrusivePtr<IFunctionRegistry> FunctionRegistry;
+ TIntrusivePtr<IRandomProvider> RandomProvider;
+ TIntrusivePtr<ITimeProvider> TimeProvider;
+
+ // NOTE(review): Env and PgmBuilder are declared after Alloc, so they are destroyed before it;
+ // presumably this ordering is intentional - confirm before reordering members.
+ TScopedAlloc Alloc;
+ THolder<TTypeEnvironment> Env;
+ THolder<TProgramBuilder> PgmBuilder;
+
+ TExploringNodeVisitor Explorer;
+ IComputationPattern::TPtr Pattern;
+};
+
+// Builds the "TestBlockFlow" source callable typed as
+// Flow<Tuple<Block<Uint64>, Scalar<Uint64>, Scalar<Uint64>>>.
+TRuntimeNode MakeFlow(TSetup_& setup) {
+    TProgramBuilder& pb = *setup.PgmBuilder;
+
+    // Every column wraps Uint64; only the block shape differs.
+    const auto makeColumn = [&pb](TBlockType::EShape shape) {
+        return pb.NewBlockType(pb.NewDataType(NUdf::EDataSlot::Uint64), shape);
+    };
+
+    const auto itemType = pb.NewTupleType({
+        makeColumn(TBlockType::EShape::Many),
+        makeColumn(TBlockType::EShape::Scalar),
+        makeColumn(TBlockType::EShape::Scalar),
+    });
+
+    TCallableBuilder callableBuilder(*setup.Env, "TestBlockFlow", pb.NewFlowType(itemType));
+    return TRuntimeNode(callableBuilder.Build(), false);
+}
+
+} // namespace
+
+
+// Checks WideSkipBlocks/WideTakeBlocks on a source of 2 blocks with 5 rows each:
+// Skip(3) cuts inside the first block, Take(5) cuts inside the second one.
+Y_UNIT_TEST_SUITE(TMiniKQLWideTakeSkipBlocks) {
+    Y_UNIT_TEST(TestWideTakeSkipBlocks) {
+        TSetup_ setup;
+        TProgramBuilder& pb = *setup.PgmBuilder;
+
+        const auto flow = MakeFlow(setup);
+
+        const auto part = pb.WideTakeBlocks(pb.WideSkipBlocks(flow, pb.NewDataLiteral<ui64>(3)), pb.NewDataLiteral<ui64>(5));
+        const auto plain = pb.WideFromBlocks(part);
+
+        const auto singleValueFlow = pb.NarrowMap(plain, [&](TRuntimeNode::TList items) -> TRuntimeNode {
+            // Source rows as (value, block index); "->" marks rows that survive Skip(3) + Take(5)
+            // together with the sum returned by this lambda:
+            // 0, 0;
+            // 1, 0;
+            // 2, 0;
+            // 3, 0; -> 3
+            // 4, 0; -> 4
+            // 5, 1; -> 6
+            // 6, 1; -> 7
+            // 7, 1; -> 8
+            // 8, 1;
+            // 9, 1;
+            return pb.Add(items[0], items[1]);
+        });
+
+        const auto pgmReturn = pb.ForwardList(singleValueFlow);
+
+        const auto graph = setup.BuildGraph(pgmReturn);
+        const auto iterator = graph->GetValue().GetListIterator();
+
+        NUdf::TUnboxedValue item;
+        UNIT_ASSERT(iterator.Next(item));
+        UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 3);
+
+        UNIT_ASSERT(iterator.Next(item));
+        UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 4);
+
+        UNIT_ASSERT(iterator.Next(item));
+        UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 6);
+
+        UNIT_ASSERT(iterator.Next(item));
+        UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 7);
+
+        UNIT_ASSERT(iterator.Next(item));
+        UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 8);
+
+        // Take(5) must stop the stream here: rows 8 and 9 of the source must not leak through.
+        UNIT_ASSERT(!iterator.Next(item));
+    }
+}
+
+} // namespace NMiniKQL
+} // namespace NKikimr
+
+
diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp
index 8371a62d730..859aa2b8597 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_program_builder.cpp
@@ -226,6 +226,24 @@ bool ReduceOptionalElements(const TType* type, const TArrayRef<const ui32>& test
return multiOptional;
}
+// Validates that `flowType` is Flow<Tuple<...>> with at least one column, where every column
+// is a block type and the last column is a scalar block of Uint64.
+// Returns the item types of all columns (the types inside the block wrappers), in column order.
+std::vector<TType*> ValidateBlockFlowType(const TType* flowType) {
+    const auto* inputTupleType = AS_TYPE(TTupleType, AS_TYPE(TFlowType, flowType)->GetItemType());
+    MKQL_ENSURE(inputTupleType->GetElementsCount() > 0, "Expected at least one column");
+
+    std::vector<TType*> flowItems;
+    flowItems.reserve(inputTupleType->GetElementsCount());
+
+    // Overwritten for every column; after the loop it describes the last column.
+    bool isScalar = false;
+    for (size_t column = 0; column < inputTupleType->GetElementsCount(); ++column) {
+        const auto columnType = AS_TYPE(TBlockType, inputTupleType->GetElementType(column));
+        isScalar = (TBlockType::EShape::Scalar == columnType->GetShape());
+        flowItems.push_back(columnType->GetItemType());
+    }
+
+    MKQL_ENSURE(isScalar, "Last column should be scalar");
+    MKQL_ENSURE(AS_TYPE(TDataType, flowItems.back())->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
+    return flowItems;
+}
+
} // namespace
std::string_view ScriptTypeAsStr(EScriptType type) {
@@ -1449,31 +1467,22 @@ TRuntimeNode TProgramBuilder::FromBlocks(TRuntimeNode flow) {
}
TRuntimeNode TProgramBuilder::WideFromBlocks(TRuntimeNode flow) {
- TType* outputTupleType;
- {
- const auto* inputTupleType = AS_TYPE(TTupleType, AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType());
- MKQL_ENSURE(inputTupleType->GetElementsCount() > 0, "Expected at least one column");
- std::vector<TType*> outputTupleItems;
- outputTupleItems.reserve(inputTupleType->GetElementsCount());
- bool isScalar;
- for (size_t i = 0; i < inputTupleType->GetElementsCount(); ++i) {
- auto blockType = AS_TYPE(TBlockType, inputTupleType->GetElementType(i));
- isScalar = blockType->GetShape() == TBlockType::EShape::Scalar;
- auto withoutBlock = blockType->GetItemType();
- outputTupleItems.push_back(withoutBlock);
- }
-
- MKQL_ENSURE(isScalar, "Last column should be scalar");
- MKQL_ENSURE(AS_TYPE(TDataType, outputTupleItems.back())->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
- outputTupleItems.pop_back();
- outputTupleType = NewTupleType(outputTupleItems);
- }
-
+ auto outputTupleItems = ValidateBlockFlowType(flow.GetStaticType());
+ outputTupleItems.pop_back();
+ TType* outputTupleType = NewTupleType(outputTupleItems);
TCallableBuilder callableBuilder(Env, __func__, NewFlowType(outputTupleType));
callableBuilder.Add(flow);
return TRuntimeNode(callableBuilder.Build(), false);
}
+// Builds a callable named after this method that skips `count` leading rows of the block flow.
+TRuntimeNode TProgramBuilder::WideSkipBlocks(TRuntimeNode flow, TRuntimeNode count) {
+    // __func__ must be evaluated here: it supplies the callable name.
+    const std::string_view callableName = __func__;
+    return BuildWideSkipTakeBlocks(callableName, flow, count);
+}
+
+// Builds a callable named after this method that keeps at most `count` rows of the block flow.
+TRuntimeNode TProgramBuilder::WideTakeBlocks(TRuntimeNode flow, TRuntimeNode count) {
+    // __func__ must be evaluated here: it supplies the callable name.
+    const std::string_view callableName = __func__;
+    return BuildWideSkipTakeBlocks(callableName, flow, count);
+}
+
TRuntimeNode TProgramBuilder::AsScalar(TRuntimeNode value) {
TCallableBuilder callableBuilder(Env, __func__, NewBlockType(value.GetStaticType(), TBlockType::EShape::Scalar));
callableBuilder.Add(value);
@@ -2451,6 +2460,18 @@ TRuntimeNode TProgramBuilder::BuildMinMax(const std::string_view& callableName,
return BuildMinMax(callableName, args.data(), args.size());
}
+// Shared builder for WideSkipBlocks/WideTakeBlocks.
+// Validates that `flow` is a wide block flow and `count` is a ui64 data value, then creates
+// a callable `callableName`(flow, count) whose result type equals the type of `flow`.
+TRuntimeNode TProgramBuilder::BuildWideSkipTakeBlocks(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count) {
+    // Called for validation only; the returned item types are not needed here.
+    ValidateBlockFlowType(flow.GetStaticType());
+
+    MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data");
+    MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
+
+    const auto resultType = flow.GetStaticType();
+    TCallableBuilder builder(Env, callableName, resultType);
+    builder.Add(flow);
+    builder.Add(count);
+    return TRuntimeNode(builder.Build(), false);
+}
+
TRuntimeNode TProgramBuilder::Min(const TArrayRef<const TRuntimeNode>& args) {
return BuildMinMax(__func__, args.data(), args.size());
}
diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h
index b97041c0391..00f690d4397 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.h
+++ b/ydb/library/yql/minikql/mkql_program_builder.h
@@ -243,6 +243,8 @@ public:
TRuntimeNode WideToBlocks(TRuntimeNode flow);
TRuntimeNode FromBlocks(TRuntimeNode flow);
TRuntimeNode WideFromBlocks(TRuntimeNode flow);
+ TRuntimeNode WideSkipBlocks(TRuntimeNode flow, TRuntimeNode count);
+ TRuntimeNode WideTakeBlocks(TRuntimeNode flow, TRuntimeNode count);
TRuntimeNode AsScalar(TRuntimeNode value);
TRuntimeNode BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args);
@@ -673,6 +675,7 @@ protected:
TRuntimeNode BuildLogical(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& args);
TRuntimeNode BuildBinaryLogical(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2);
TRuntimeNode BuildMinMax(const std::string_view& callableName, const TRuntimeNode* data, size_t size);
+ TRuntimeNode BuildWideSkipTakeBlocks(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count);
private:
TRuntimeNode BuildWideFilter(const std::string_view& callableName, TRuntimeNode flow, const TNarrowLambda& handler);
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index d46d4c1c066..a456f706793 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -518,6 +518,9 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
{"Take", &TProgramBuilder::Take},
{"Limit", &TProgramBuilder::Take},
+ {"WideTakeBlocks", &TProgramBuilder::WideTakeBlocks},
+ {"WideSkipBlocks", &TProgramBuilder::WideSkipBlocks},
+
{"Append", &TProgramBuilder::Append},
{"Insert", &TProgramBuilder::Append},
{"Prepend", &TProgramBuilder::Prepend},