diff options
author | aneporada <[email protected]> | 2022-10-30 10:43:33 +0300 |
---|---|---|
committer | aneporada <[email protected]> | 2022-10-30 10:43:33 +0300 |
commit | 9d1d72f5a7fd1f6f3649918ce82af1fb42d2d8be (patch) | |
tree | 9071d235b3c45bb79a25c1a79d31b1b8e032214d | |
parent | a34b8d13197580c648158c2d2a29812d0b211713 (diff) |
Support Wide{Take/Skip}Blocks
14 files changed, 508 insertions, 39 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index b061c7c18f3..aadfc855c62 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -4563,6 +4563,47 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext& .Build(); } +TExprNode::TPtr OptimizeSkipTakeToBlocks(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& types) { + if (!types.ArrowResolver) { + return node; + } + + if (node->Head().GetTypeAnn()->GetKind() != ETypeAnnotationKind::Flow) { + return node; + } + + auto flowItemType = node->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType(); + if (flowItemType->GetKind() != ETypeAnnotationKind::Multi) { + return node; + } + + const auto& allTypes = flowItemType->Cast<TMultiExprType>()->GetItems(); + if (AnyOf(allTypes, [](const TTypeAnnotationNode* type) { return type->GetKind() == ETypeAnnotationKind::Block; })) { + return node; + } + + bool supported = false; + YQL_ENSURE(types.ArrowResolver->AreTypesSupported(ctx.GetPosition(node->Head().Pos()), + TVector<const TTypeAnnotationNode*>(allTypes.begin(), allTypes.end()), + supported, ctx)); + if (!supported) { + return node; + } + + TStringBuf newName = node->Content() == "Skip" ? "WideSkipBlocks" : "WideTakeBlocks"; + YQL_CLOG(DEBUG, CorePeepHole) << "Convert " << node->Content() << " to " << newName; + return ctx.Builder(node->Pos()) + .Callable("WideFromBlocks") + .Callable(0, newName) + .Callable(0, "WideToBlocks") + .Add(0, node->HeadPtr()) + .Seal() + .Add(1, node->ChildPtr(1)) + .Seal() + .Seal() + .Build(); +} + TExprNode::TPtr OptimizeWideMaps(const TExprNode::TPtr& node, TExprContext& ctx) { if (const auto& input = node->Head(); input.IsCallable("ExpandMap")) { YQL_CLOG(DEBUG, CorePeepHole) << "Fuse " << node->Content() << " with " << input.Content(); @@ -6233,6 +6274,8 @@ struct TPeepHoleRules { {"WideMap", &OptimizeWideMapBlocks}, {"NarrowMap", &OptimizeWideMapBlocks}, {"WideToBlocks", &OptimizeWideToBlocks}, + {"Skip", &OptimizeSkipTakeToBlocks}, + {"Take", &OptimizeSkipTakeToBlocks}, }; TPeepHoleRules() diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index c8dd50f9023..4ec696f19c2 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -11691,6 +11691,8 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["WideToBlocks"] = &WideToBlocksWrapper; Functions["WideFromBlocks"] = &WideFromBlocksWrapper; + Functions["WideSkipBlocks"] = &WideSkipTakeBlocksWrapper; + Functions["WideTakeBlocks"] = &WideSkipTakeBlocksWrapper; Functions["AsScalar"] = &AsScalarWrapper; ExtFunctions["BlockFunc"] = &BlockFuncWrapper; ExtFunctions["BlockBitCast"] = &BlockBitCastWrapper; diff --git a/ydb/library/yql/core/type_ann/type_ann_wide.cpp b/ydb/library/yql/core/type_ann/type_ann_wide.cpp index 93e5622e187..6e51ff57c79 100644 --- a/ydb/library/yql/core/type_ann/type_ann_wide.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_wide.cpp @@ -658,40 +658,44 @@ IGraphTransformer::TStatus WideFromBlocksWrapper(const TExprNode::TPtr& input, T return IGraphTransformer::TStatus::Error; } - if (!EnsureWideFlowType(input->Head(), ctx.Expr)) { + TTypeAnnotationNode::TListType retMultiType; + if (!EnsureWideFlowBlockType(input->Head(), retMultiType, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - const auto multiType = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>(); - if (multiType->GetSize() == 0) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Expected at least one column")); + YQL_ENSURE(!retMultiType.empty()); + retMultiType.pop_back(); + auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType); + input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType)); + return IGraphTransformer::TStatus::Ok; +} + +IGraphTransformer::TStatus WideSkipTakeBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + if (!EnsureArgsCount(*input, 2U, ctx.Expr)) { return IGraphTransformer::TStatus::Error; } - TTypeAnnotationNode::TListType retMultiType; - bool isScalar; - for (const auto& type : multiType->GetItems()) { - if (!EnsureBlockOrScalarType(input->Pos(), *type, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } - - retMultiType.push_back(GetBlockItemType(*type, isScalar)); + TTypeAnnotationNode::TListType blockItemTypes; + if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; } - if (!isScalar) { - ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "Last column should be a scalar")); + output = input; + const TTypeAnnotationNode* expectedType = ctx.Expr.MakeType<TDataExprType>(EDataSlot::Uint64); + auto convertStatus = TryConvertTo(input->ChildRef(1), *expectedType, ctx.Expr); + if (convertStatus.Level == IGraphTransformer::TStatus::Error) { + ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Child(1)->Pos()), "Can not convert argument to Uint64")); return IGraphTransformer::TStatus::Error; } - if (!EnsureSpecificDataType(input->Pos(), *retMultiType.back(), EDataSlot::Uint64, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; + if (convertStatus.Level != IGraphTransformer::TStatus::Ok) { + return convertStatus; } - retMultiType.pop_back(); - auto outputItemType = ctx.Expr.MakeType<TMultiExprType>(retMultiType); - input->SetTypeAnn(ctx.Expr.MakeType<TFlowExprType>(outputItemType)); + input->SetTypeAnn(input->Head().GetTypeAnn()); return IGraphTransformer::TStatus::Ok; } + } // namespace NTypeAnnImpl } diff --git a/ydb/library/yql/core/type_ann/type_ann_wide.h b/ydb/library/yql/core/type_ann/type_ann_wide.h index b02cf26a2eb..f9c3227f681 100644 --- a/ydb/library/yql/core/type_ann/type_ann_wide.h +++ b/ydb/library/yql/core/type_ann/type_ann_wide.h @@ -23,5 +23,7 @@ namespace NTypeAnnImpl { IGraphTransformer::TStatus WideToBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus WideFromBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + + IGraphTransformer::TStatus WideSkipTakeBlocksWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); } // namespace NTypeAnnImpl } // namespace NYql diff --git a/ydb/library/yql/core/yql_expr_type_annotation.cpp b/ydb/library/yql/core/yql_expr_type_annotation.cpp index 1a286d143fb..617e9166c6c 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.cpp +++ b/ydb/library/yql/core/yql_expr_type_annotation.cpp @@ -2677,6 +2677,38 @@ bool EnsureWideFlowType(TPositionHandle position, const TTypeAnnotationNode& typ return true; } +bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx) { + if (!EnsureWideFlowType(node, ctx)) { + return false; + } + + auto& items = node.GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems(); + if (items.empty()) { + ctx.AddError(TIssue(ctx.GetPosition(node.Pos()), "Expected at least one column")); + return IGraphTransformer::TStatus::Error; + } + + bool isScalar; + for (const auto& type : items) { + if (!EnsureBlockOrScalarType(node.Pos(), *type, ctx)) { + return false; + } + + blockItemTypes.push_back(GetBlockItemType(*type, isScalar)); + } + + if (!isScalar) { + ctx.AddError(TIssue(ctx.GetPosition(node.Pos()), "Last column should be a scalar")); + return false; + } + + if (!EnsureSpecificDataType(node.Pos(), *blockItemTypes.back(), EDataSlot::Uint64, ctx)) { + return false; + } + + return true; +} + bool EnsureOptionalType(const TExprNode& node, TExprContext& ctx) { if (!node.GetTypeAnn()) { YQL_ENSURE(node.Type() == TExprNode::Lambda); diff --git a/ydb/library/yql/core/yql_expr_type_annotation.h b/ydb/library/yql/core/yql_expr_type_annotation.h index 4392a4b5504..781d6d13fba 100644 --- a/ydb/library/yql/core/yql_expr_type_annotation.h +++ b/ydb/library/yql/core/yql_expr_type_annotation.h @@ -118,6 +118,7 @@ bool EnsureFlowType(const TExprNode& node, TExprContext& ctx); bool EnsureFlowType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx); bool EnsureWideFlowType(const TExprNode& node, TExprContext& ctx); bool EnsureWideFlowType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx); +bool EnsureWideFlowBlockType(const TExprNode& node, TTypeAnnotationNode::TListType& blockItemTypes, TExprContext& ctx); bool EnsureOptionalType(const TExprNode& node, TExprContext& ctx); bool EnsureOptionalType(TPositionHandle position, const TTypeAnnotationNode& type, TExprContext& ctx); bool EnsureType(const TExprNode& node, TExprContext& ctx); diff --git a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt index 7312d9050f3..0dea4f84dc8 100644 --- a/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt +++ b/ydb/library/yql/minikql/comp_nodes/CMakeLists.txt @@ -39,6 +39,7 @@ target_sources(yql-minikql-comp_nodes PRIVATE ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_count.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_agg_factory.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp + ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_callable.cpp ${CMAKE_SOURCE_DIR}/ydb/library/yql/minikql/comp_nodes/mkql_chain_map.cpp diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp new file mode 100644 index 00000000000..2fb510eb99e --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.cpp @@ -0,0 +1,164 @@ +#include "mkql_block_skiptake.h" + +#include <ydb/library/yql/minikql/arrow/arrow_defs.h> +#include <ydb/library/yql/minikql/mkql_type_builder.h> +#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> +#include <ydb/library/yql/minikql/mkql_node_builder.h> +#include <ydb/library/yql/minikql/mkql_node_cast.h> + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +class TWideSkipBlocksWrapper: public TStatefulWideFlowComputationNode<TWideSkipBlocksWrapper> { + typedef TStatefulWideFlowComputationNode<TWideSkipBlocksWrapper> TBaseComputation; + +public: + TWideSkipBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, size_t width) + : TBaseComputation(mutables, flow, EValueRepresentation::Any) + , Flow(flow) + , Count(count) + , Width(width) + { + } + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + if (!state.HasValue()) { + state = Count->GetValue(ctx); + } + + auto count = state.Get<ui64>(); + ui64 blockSize = 0; + for (;;) { + auto result = Flow->FetchValues(ctx, output); + if (count == 0 || result != EFetchResult::One) { + return result; + } + + blockSize = TArrowBlock::From(*output[Width - 1]).GetDatum().scalar_as<arrow::UInt64Scalar>().value; + if (blockSize > count) { + break; + } + count -= blockSize; + state = NUdf::TUnboxedValuePod(count); + } + + ui64 tailSize = blockSize - count; + for (size_t i = 0; i < Width - 1; ++i) { + auto& datum = TArrowBlock::From(*output[i]).GetDatum(); + if (datum.is_scalar()) { + continue; + } + + Y_VERIFY_DEBUG(datum.is_array()); + *output[i] = ctx.HolderFactory.CreateArrowBlock(datum.array()->Slice(count, tailSize)); + } + + *output[Width - 1] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(static_cast<uint64_t>(tailSize))); + state = NUdf::TUnboxedValuePod::Zero(); + return EFetchResult::One; + } + +private: + void RegisterDependencies() const final { + if (const auto flow = FlowDependsOn(Flow)) { + DependsOn(flow, Count); + } + } + + IComputationWideFlowNode* const Flow; + IComputationNode* const Count; + const size_t Width; +}; + +class TWideTakeBlocksWrapper: public TStatefulWideFlowComputationNode<TWideTakeBlocksWrapper> { + typedef TStatefulWideFlowComputationNode<TWideTakeBlocksWrapper> TBaseComputation; + +public: + TWideTakeBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, size_t width) + : TBaseComputation(mutables, flow, EValueRepresentation::Any) + , Flow(flow) + , Count(count) + , Width(width) + { + } + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + if (!state.HasValue()) { + state = Count->GetValue(ctx); + } + + auto count = state.Get<ui64>(); + if (!count) { + return EFetchResult::Finish; + } + + auto result = Flow->FetchValues(ctx, output); + if (result == EFetchResult::One) { + ui64 blockSize = TArrowBlock::From(*output[Width - 1]).GetDatum().scalar_as<arrow::UInt64Scalar>().value; + if (blockSize > count) { + for (size_t i = 0; i < Width - 1; ++i) { + auto& datum = TArrowBlock::From(*output[i]).GetDatum(); + if (datum.is_scalar()) { + continue; + } + + Y_VERIFY_DEBUG(datum.is_array()); + *output[i] = ctx.HolderFactory.CreateArrowBlock(datum.array()->Slice(0, count)); + } + *output[Width - 1] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(static_cast<uint64_t>(count))); + state = NUdf::TUnboxedValuePod::Zero(); + } else { + state = NUdf::TUnboxedValuePod(count - blockSize); + } + } + return result; + } + +private: + void RegisterDependencies() const final { + if (const auto flow = FlowDependsOn(Flow)) { + DependsOn(flow, Count); + } + } + + IComputationWideFlowNode* const Flow; + IComputationNode* const Count; + const size_t Width; +}; + +IComputationNode* WrapSkipTake(bool skip, TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args"); + + const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); + const auto tupleType = AS_TYPE(TTupleType, flowType->GetItemType()); + MKQL_ENSURE(tupleType->GetElementsCount() > 0, "Expected at least one column"); + + auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0)); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + + const auto count = LocateNode(ctx.NodeLocator, callable, 1); + const auto countType = AS_TYPE(TDataType, callable.GetInput(1).GetStaticType()); + MKQL_ENSURE(countType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64"); + + if (skip) { + return new TWideSkipBlocksWrapper(ctx.Mutables, wideFlow, count, tupleType->GetElementsCount()); + } + return new TWideTakeBlocksWrapper(ctx.Mutables, wideFlow, count, tupleType->GetElementsCount()); +} + +} //namespace + +IComputationNode* WrapWideSkipBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + bool skip = true; + return WrapSkipTake(skip, callable, ctx); +} + +IComputationNode* WrapWideTakeBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + bool skip = false; + return WrapSkipTake(skip, callable, ctx); +} + +} +}
\ No newline at end of file diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.h b/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.h new file mode 100644 index 00000000000..79d20f0cb77 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_skiptake.h @@ -0,0 +1,11 @@ +#pragma once +#include <ydb/library/yql/minikql/computation/mkql_computation_node.h> + +namespace NKikimr { +namespace NMiniKQL { + +IComputationNode* WrapWideSkipBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx); +IComputationNode* WrapWideTakeBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx); + +} +} diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index 5348a807530..8cd9773020c 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -7,6 +7,7 @@ #include "mkql_block_func.h" #include "mkql_blocks.h" #include "mkql_block_agg.h" +#include "mkql_block_skiptake.h" #include "mkql_callable.h" #include "mkql_chain_map.h" #include "mkql_chain1_map.h" @@ -268,6 +269,8 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"BlockBitCast", &WrapBlockBitCast}, {"FromBlocks", &WrapFromBlocks}, {"WideFromBlocks", &WrapWideFromBlocks}, + {"WideSkipBlocks", &WrapWideSkipBlocks}, + {"WideTakeBlocks", &WrapWideTakeBlocks}, {"AsScalar", &WrapAsScalar}, {"BlockCombineAll", &WrapBlockCombineAll}, {"MakeHeap", &WrapMakeHeap}, diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_block_skiptake_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_block_skiptake_ut.cpp new file mode 100644 index 00000000000..4b12ffc7c35 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_block_skiptake_ut.cpp @@ -0,0 +1,179 @@ +#include "mkql_computation_node_ut.h" + +#include <ydb/library/yql/minikql/arrow/arrow_defs.h> +#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> + +#include <arrow/array/builder_primitive.h> + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +class TTestBlockFlowWrapper: public TStatefulWideFlowComputationNode<TTestBlockFlowWrapper> { + typedef TStatefulWideFlowComputationNode<TTestBlockFlowWrapper> TBaseComputation; + +public: + TTestBlockFlowWrapper(TComputationMutables& mutables, size_t blockSize, size_t blockCount) + : TBaseComputation(mutables, nullptr, EValueRepresentation::Any) + , BlockSize(blockSize) + , BlockCount(blockCount) + { + } + + EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const { + if (!state.HasValue()) { + state = NUdf::TUnboxedValue::Zero(); + } + + ui64 index = state.Get<ui64>(); + if (index >= BlockCount) { + return EFetchResult::Finish; + } + + arrow::UInt64Builder builder(&ctx.ArrowMemoryPool); + ARROW_OK(builder.Reserve(BlockSize)); + for (size_t i = 0; i < BlockSize; ++i) { + builder.UnsafeAppend(index * BlockSize + i); + } + + std::shared_ptr<arrow::ArrayData> block; + ARROW_OK(builder.FinishInternal(&block)); + + *output[0] = ctx.HolderFactory.CreateArrowBlock(std::move(block)); + *output[1] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(index))); + *output[2] = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(BlockSize))); + + state = NUdf::TUnboxedValuePod(++index); + return EFetchResult::One; + } + +private: + void RegisterDependencies() const final { + } + + const size_t BlockSize; + const size_t BlockCount; +}; + +IComputationNode* WrapTestBlockFlow(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 0, "Expected no args"); + return new TTestBlockFlowWrapper(ctx.Mutables, 5, 2); +} + +TIntrusivePtr<IRandomProvider> CreateRandomProvider() { + return CreateDeterministicRandomProvider(1); +} + +TIntrusivePtr<ITimeProvider> CreateTimeProvider() { + return CreateDeterministicTimeProvider(10000000); +} + +TComputationNodeFactory GetTestFactory() { + return [](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* { + if (callable.GetType()->GetName() == "TestBlockFlow") { + return WrapTestBlockFlow(callable, ctx); + } + return GetBuiltinFactory()(callable, ctx); + }; +} + +struct TSetup_ { + TSetup_() + : Alloc(__LOCATION__) + { + FunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry()); + RandomProvider = CreateRandomProvider(); + TimeProvider = CreateTimeProvider(); + + Env.Reset(new TTypeEnvironment(Alloc)); + PgmBuilder.Reset(new TProgramBuilder(*Env, *FunctionRegistry)); + } + + TAutoPtr<IComputationGraph> BuildGraph(TRuntimeNode pgm, EGraphPerProcess graphPerProcess = EGraphPerProcess::Multi, const std::vector<TNode*>& entryPoints = std::vector<TNode*>()) { + Explorer.Walk(pgm.GetNode(), *Env); + TComputationPatternOpts opts(Alloc.Ref(), *Env, GetTestFactory(), FunctionRegistry.Get(), + NUdf::EValidateMode::None, NUdf::EValidatePolicy::Exception, "OFF", graphPerProcess); + Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts); + return Pattern->Clone(opts.ToComputationOptions(*RandomProvider, *TimeProvider)); + } + + TIntrusivePtr<IFunctionRegistry> FunctionRegistry; + TIntrusivePtr<IRandomProvider> RandomProvider; + TIntrusivePtr<ITimeProvider> TimeProvider; + + TScopedAlloc Alloc; + THolder<TTypeEnvironment> Env; + THolder<TProgramBuilder> PgmBuilder; + + TExploringNodeVisitor Explorer; + IComputationPattern::TPtr Pattern; +}; + +TRuntimeNode MakeFlow(TSetup_& setup) { + TProgramBuilder& pb = *setup.PgmBuilder; + TCallableBuilder callableBuilder(*setup.Env, "TestBlockFlow", + pb.NewFlowType( + pb.NewTupleType({ + pb.NewBlockType(pb.NewDataType(NUdf::EDataSlot::Uint64), TBlockType::EShape::Many), + pb.NewBlockType(pb.NewDataType(NUdf::EDataSlot::Uint64), TBlockType::EShape::Scalar), + pb.NewBlockType(pb.NewDataType(NUdf::EDataSlot::Uint64), TBlockType::EShape::Scalar), + }))); + return TRuntimeNode(callableBuilder.Build(), false); +} + +} // namespace + + +Y_UNIT_TEST_SUITE(TMiniKQLWideTakeSkipBlocks) { + Y_UNIT_TEST(TestWideTakeSkipBlocks) { + TSetup_ setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto flow = MakeFlow(setup); + + const auto part = pb.WideTakeBlocks(pb.WideSkipBlocks(flow, pb.NewDataLiteral<ui64>(3)), pb.NewDataLiteral<ui64>(5)); + const auto plain = pb.WideFromBlocks(part); + + const auto singleValueFlow = pb.NarrowMap(plain, [&](TRuntimeNode::TList items) -> TRuntimeNode { + // 0, 0; + // 1, 0; + // 2, 0; + // 3, 0; -> 3 + // 4, 0; -> 4 + // 5, 1; -> 6 + // 6, 1; -> 7 + // 7, 1; -> 8 + // 8, 1; + // 9, 1; + // 10, 1; + return pb.Add(items[0], items[1]); + }); + + const auto pgmReturn = pb.ForwardList(singleValueFlow); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + + NUdf::TUnboxedValue item; + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 3); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 4); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 6); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 7); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 8); + } +} + +} // namespace NMiniKQL +} // namespace NKikimr + + diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 8371a62d730..859aa2b8597 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -226,6 +226,24 @@ bool ReduceOptionalElements(const TType* type, const TArrayRef<const ui32>& test return multiOptional; } +std::vector<TType*> ValidateBlockFlowType(const TType* flowType) { + const auto* inputTupleType = AS_TYPE(TTupleType, AS_TYPE(TFlowType, flowType)->GetItemType()); + MKQL_ENSURE(inputTupleType->GetElementsCount() > 0, "Expected at least one column"); + std::vector<TType*> flowItems; + flowItems.reserve(inputTupleType->GetElementsCount()); + bool isScalar; + for (size_t i = 0; i < inputTupleType->GetElementsCount(); ++i) { + auto blockType = AS_TYPE(TBlockType, inputTupleType->GetElementType(i)); + isScalar = blockType->GetShape() == TBlockType::EShape::Scalar; + auto withoutBlock = blockType->GetItemType(); + flowItems.push_back(withoutBlock); + } + + MKQL_ENSURE(isScalar, "Last column should be scalar"); + MKQL_ENSURE(AS_TYPE(TDataType, flowItems.back())->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64"); + return flowItems; +} + } // namespace std::string_view ScriptTypeAsStr(EScriptType type) { @@ -1449,31 +1467,22 @@ TRuntimeNode TProgramBuilder::FromBlocks(TRuntimeNode flow) { } TRuntimeNode TProgramBuilder::WideFromBlocks(TRuntimeNode flow) { - TType* outputTupleType; - { - const auto* inputTupleType = AS_TYPE(TTupleType, AS_TYPE(TFlowType, flow.GetStaticType())->GetItemType()); - MKQL_ENSURE(inputTupleType->GetElementsCount() > 0, "Expected at least one column"); - std::vector<TType*> outputTupleItems; - outputTupleItems.reserve(inputTupleType->GetElementsCount()); - bool isScalar; - for (size_t i = 0; i < inputTupleType->GetElementsCount(); ++i) { - auto blockType = AS_TYPE(TBlockType, inputTupleType->GetElementType(i)); - isScalar = blockType->GetShape() == TBlockType::EShape::Scalar; - auto withoutBlock = blockType->GetItemType(); - outputTupleItems.push_back(withoutBlock); - } - - MKQL_ENSURE(isScalar, "Last column should be scalar"); - MKQL_ENSURE(AS_TYPE(TDataType, outputTupleItems.back())->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64"); - outputTupleItems.pop_back(); - outputTupleType = NewTupleType(outputTupleItems); - } - + auto outputTupleItems = ValidateBlockFlowType(flow.GetStaticType()); + outputTupleItems.pop_back(); + TType* outputTupleType = NewTupleType(outputTupleItems); TCallableBuilder callableBuilder(Env, __func__, NewFlowType(outputTupleType)); callableBuilder.Add(flow); return TRuntimeNode(callableBuilder.Build(), false); } +TRuntimeNode TProgramBuilder::WideSkipBlocks(TRuntimeNode flow, TRuntimeNode count) { + return BuildWideSkipTakeBlocks(__func__, flow, count); +} + +TRuntimeNode TProgramBuilder::WideTakeBlocks(TRuntimeNode flow, TRuntimeNode count) { + return BuildWideSkipTakeBlocks(__func__, flow, count); +} + TRuntimeNode TProgramBuilder::AsScalar(TRuntimeNode value) { TCallableBuilder callableBuilder(Env, __func__, NewBlockType(value.GetStaticType(), TBlockType::EShape::Scalar)); callableBuilder.Add(value); @@ -2451,6 +2460,18 @@ TRuntimeNode TProgramBuilder::BuildMinMax(const std::string_view& callableName, return BuildMinMax(callableName, args.data(), args.size()); } +TRuntimeNode TProgramBuilder::BuildWideSkipTakeBlocks(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count) { + ValidateBlockFlowType(flow.GetStaticType()); + + MKQL_ENSURE(count.GetStaticType()->IsData(), "Expected data"); + MKQL_ENSURE(static_cast<const TDataType&>(*count.GetStaticType()).GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64"); + + TCallableBuilder callableBuilder(Env, callableName, flow.GetStaticType()); + callableBuilder.Add(flow); + callableBuilder.Add(count); + return TRuntimeNode(callableBuilder.Build(), false); +} + TRuntimeNode TProgramBuilder::Min(const TArrayRef<const TRuntimeNode>& args) { return BuildMinMax(__func__, args.data(), args.size()); } diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index b97041c0391..00f690d4397 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -243,6 +243,8 @@ public: TRuntimeNode WideToBlocks(TRuntimeNode flow); TRuntimeNode FromBlocks(TRuntimeNode flow); TRuntimeNode WideFromBlocks(TRuntimeNode flow); + TRuntimeNode WideSkipBlocks(TRuntimeNode flow, TRuntimeNode count); + TRuntimeNode WideTakeBlocks(TRuntimeNode flow, TRuntimeNode count); TRuntimeNode AsScalar(TRuntimeNode value); TRuntimeNode BlockFunc(const std::string_view& funcName, TType* returnType, const TArrayRef<const TRuntimeNode>& args); @@ -673,6 +675,7 @@ protected: TRuntimeNode BuildLogical(const std::string_view& callableName, const TArrayRef<const TRuntimeNode>& args); TRuntimeNode BuildBinaryLogical(const std::string_view& callableName, TRuntimeNode data1, TRuntimeNode data2); TRuntimeNode BuildMinMax(const std::string_view& callableName, const TRuntimeNode* data, size_t size); + TRuntimeNode BuildWideSkipTakeBlocks(const std::string_view& callableName, TRuntimeNode flow, TRuntimeNode count); private: TRuntimeNode BuildWideFilter(const std::string_view& callableName, TRuntimeNode flow, const TNarrowLambda& handler); diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index d46d4c1c066..a456f706793 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -518,6 +518,9 @@ TMkqlCommonCallableCompiler::TShared::TShared() { {"Take", &TProgramBuilder::Take}, {"Limit", &TProgramBuilder::Take}, + {"WideTakeBlocks", &TProgramBuilder::WideTakeBlocks}, + {"WideSkipBlocks", &TProgramBuilder::WideSkipBlocks}, + {"Append", &TProgramBuilder::Append}, {"Insert", &TProgramBuilder::Append}, {"Prepend", &TProgramBuilder::Prepend}, |