diff options
author | aneporada <aneporada@yandex-team.ru> | 2022-02-16 00:37:52 +0300 |
---|---|---|
committer | aneporada <aneporada@yandex-team.ru> | 2022-02-16 00:37:52 +0300 |
commit | 1de3c666314929ac9afa4646b07acd822fcd8baf (patch) | |
tree | 53d7956e1b508728f6fc967703e45ae395f6e3cb | |
parent | 0bd0e5d1ab192e2782657cb05a36637ccb42ed57 (diff) | |
download | ydb-1de3c666314929ac9afa4646b07acd822fcd8baf.tar.gz |
[YQL-14389] Implement MapNext for stream and flow
ref:d91668ca7d35bceb13e602ecd5f711e453cb14bc
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_core.cpp | 1 |
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.cpp | 34 |
-rw-r--r-- | ydb/library/yql/core/type_ann/type_ann_list.h | 1 |
-rw-r--r-- | ydb/library/yql/core/yql_expr_constraint.cpp | 1 |
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp | 2 |
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp | 205 |
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_mapnext.h | 10 |
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/ya.make | 2 |
-rw-r--r-- | ydb/library/yql/minikql/mkql_program_builder.cpp | 29 |
-rw-r--r-- | ydb/library/yql/minikql/mkql_program_builder.h | 1 |
-rw-r--r-- | ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp | 7 |
11 files changed, 293 insertions, 0 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp index cf811b38eb..743692061a 100644 --- a/ydb/library/yql/core/type_ann/type_ann_core.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp @@ -12807,6 +12807,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot> Functions["Likely"] = &BoolOpt1Wrapper; Functions["Map"] = &MapWrapper; Functions["OrderedMap"] = &MapWrapper; + Functions["MapNext"] = &MapNextWrapper; Functions["FoldMap"] = &FoldMapWrapper; Functions["Fold1Map"] = &Fold1MapWrapper; Functions["Chain1Map"] = &Chain1MapWrapper; diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp index b145cde389..1fa47b7600 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp @@ -783,6 +783,40 @@ namespace { return IGraphTransformer::TStatus::Ok; } + IGraphTransformer::TStatus MapNextWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { + Y_UNUSED(output); + if (!EnsureArgsCount(*input, 2, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + const TTypeAnnotationNode* itemType = nullptr; + if (!EnsureNewSeqType<false, false>(input->Head(), ctx.Expr, &itemType)) { + return IGraphTransformer::TStatus::Error; + } + + const auto status = ConvertToLambda(input->TailRef(), ctx.Expr, 2); + if (status.Level != IGraphTransformer::TStatus::Ok) { + return status; + } + + auto& lambda = input->TailRef(); + const TTypeAnnotationNode* nextItemType = ctx.Expr.MakeType<TOptionalExprType>(itemType); + if (!UpdateLambdaAllArgumentsTypes(lambda, {itemType, nextItemType}, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + if (!lambda->GetTypeAnn()) { + return IGraphTransformer::TStatus::Repeat; + } + + if (!EnsureComputableType(lambda->Pos(), *lambda->GetTypeAnn(), ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + 
input->SetTypeAnn(MakeSequenceType(input->Head().GetTypeAnn()->GetKind(), *lambda->GetTypeAnn(), ctx.Expr)); + return IGraphTransformer::TStatus::Ok; + } + IGraphTransformer::TStatus LMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { Y_UNUSED(output); if (!EnsureArgsCount(*input, 2, ctx.Expr)) { diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h index 608b813fe8..c56b4879ee 100644 --- a/ydb/library/yql/core/type_ann/type_ann_list.h +++ b/ydb/library/yql/core/type_ann/type_ann_list.h @@ -14,6 +14,7 @@ namespace NTypeAnnImpl { template <bool InverseCondition> IGraphTransformer::TStatus InclusiveFilterWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus MapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); + IGraphTransformer::TStatus MapNextWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); IGraphTransformer::TStatus LMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); template <bool Warn> IGraphTransformer::TStatus FlatMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx); diff --git a/ydb/library/yql/core/yql_expr_constraint.cpp b/ydb/library/yql/core/yql_expr_constraint.cpp index 78c8f68776..25a20effd7 100644 --- a/ydb/library/yql/core/yql_expr_constraint.cpp +++ b/ydb/library/yql/core/yql_expr_constraint.cpp @@ -124,6 +124,7 @@ public: Functions["WideFilter"] = &TCallableConstraintTransformer::FilterWrap<true>; Functions["OrderedMap"] = &TCallableConstraintTransformer::MapWrap<true, false>; Functions["Map"] = &TCallableConstraintTransformer::MapWrap<false, false>; + Functions["MapNext"] = &TCallableConstraintTransformer::MapWrap<true, false>; Functions["OrderedFlatMap"] = &TCallableConstraintTransformer::MapWrap<true, true>; Functions["FlatMap"] = &TCallableConstraintTransformer::MapWrap<false, true>; 
Functions["OrderedMultiMap"] = &TCallableConstraintTransformer::MapWrap<true, false>; diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp index 57599acb64..ae2b84405b 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp @@ -50,6 +50,7 @@ #include "mkql_logical.h" #include "mkql_lookup.h" #include "mkql_map.h" +#include "mkql_mapnext.h" #include "mkql_map_join.h" #include "mkql_multihopping.h" #include "mkql_multimap.h" @@ -144,6 +145,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller { {"Fold1", &WrapFold1}, {"Map", &WrapMap}, {"OrderedMap", &WrapMap}, + {"MapNext", &WrapMapNext}, {"MultiMap", &WrapMultiMap}, {"FlatMap", &WrapFlatMap}, {"OrderedFlatMap", &WrapFlatMap}, diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp new file mode 100644 index 0000000000..5d0b470472 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp @@ -0,0 +1,205 @@ +#include "mkql_mapnext.h" +#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h> + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +class TFlowMapNextWrapper : public TStatelessFlowComputationNode<TFlowMapNextWrapper> { + typedef TStatelessFlowComputationNode<TFlowMapNextWrapper> TBaseComputation; +public: + TFlowMapNextWrapper(EValueRepresentation kind, IComputationNode* flow, + IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem) + : TBaseComputation(flow, kind) + , Flow(flow) + , Item(item) + , NextItem(nextItem) + , NewItem(newItem) + {} + + NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const { + NUdf::TUnboxedValue result; + for (;;) { + if (Finish) { + if (!Prev) { + return NUdf::TUnboxedValuePod::MakeFinish(); + } + Item->SetValue(ctx, std::move(*Prev)); + Prev.reset(); + NextItem->SetValue(ctx, 
NUdf::TUnboxedValuePod()); + return NewItem->GetValue(ctx); + } + + auto item = Flow->GetValue(ctx); + if (item.IsYield()) { + return item; + } + + if (item.IsFinish()) { + Finish = true; + continue; + } + + if (!Prev) { + Prev = std::move(item); + continue; + } + + Item->SetValue(ctx, std::move(*Prev)); + Prev = item; + NextItem->SetValue(ctx, std::move(item)); + result = NewItem->GetValue(ctx); + break; + } + + return result; + } + +private: + void RegisterDependencies() const final { + if (const auto flow = FlowDependsOn(Flow)) { + Own(flow, Item); + Own(flow, NextItem); + DependsOn(flow, NewItem); + } + } + + IComputationNode* const Flow; + IComputationExternalNode* const Item; + IComputationExternalNode* const NextItem; + IComputationNode* const NewItem; + mutable std::optional<NUdf::TUnboxedValue> Prev; + mutable bool Finish = false; +}; + +class TStreamMapNextWrapper : public TMutableComputationNode<TStreamMapNextWrapper> { + typedef TMutableComputationNode<TStreamMapNextWrapper> TBaseComputation; +public: + TStreamMapNextWrapper(TComputationMutables& mutables, IComputationNode* stream, + IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem) + : TBaseComputation(mutables) + , Stream(stream) + , Item(item) + , NextItem(nextItem) + , NewItem(newItem) + {} + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + return ctx.HolderFactory.Create<TStreamValue>(ctx, Stream->GetValue(ctx), Item, NextItem, NewItem); + } + +private: + void RegisterDependencies() const final { + DependsOn(Stream); + Own(Item); + Own(NextItem); + DependsOn(NewItem); + } + + class TStreamValue : public TComputationValue<TStreamValue> { + public: + using TBase = TComputationValue<TStreamValue>; + + TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& stream, + IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem) + : TBase(memInfo) + , CompCtx(compCtx) 
+ , Stream(std::move(stream)) + , Item(item) + , NextItem(nextItem) + , NewItem(newItem) + { + } + + private: + ui32 GetTraverseCount() const final { + return 1U; + } + + NUdf::TUnboxedValue GetTraverseItem(ui32) const final { + return Stream; + } + + NUdf::TUnboxedValue Save() const final { + return NUdf::TUnboxedValuePod::Zero(); + } + + void Load(const NUdf::TStringRef&) final {} + + NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final { + for (;;) { + if (Finish) { + if (!Prev) { + return NUdf::EFetchStatus::Finish; + } + Item->SetValue(CompCtx, std::move(*Prev)); + Prev.reset(); + NextItem->SetValue(CompCtx, NUdf::TUnboxedValuePod()); + + result = NewItem->GetValue(CompCtx); + return NUdf::EFetchStatus::Ok; + } + + NUdf::TUnboxedValue item; + const auto status = Stream.Fetch(item); + if (status == NUdf::EFetchStatus::Yield) { + return status; + } + + if (status == NUdf::EFetchStatus::Finish) { + Finish = true; + continue; + } + + if (!Prev) { + Prev = std::move(item); + continue; + } + + Item->SetValue(CompCtx, std::move(*Prev)); + Prev = item; + NextItem->SetValue(CompCtx, std::move(item)); + result = NewItem->GetValue(CompCtx); + break; + } + return NUdf::EFetchStatus::Ok; + } + + TComputationContext& CompCtx; + const NUdf::TUnboxedValue Stream; + IComputationExternalNode* const Item; + IComputationExternalNode* const NextItem; + IComputationNode* const NewItem; + std::optional<NUdf::TUnboxedValue> Prev; + bool Finish = false; + }; + + IComputationNode* const Stream; + IComputationExternalNode* const Item; + IComputationExternalNode* const NextItem; + IComputationNode* const NewItem; +}; + +} + +IComputationNode* WrapMapNext(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args, got " << callable.GetInputsCount()); + const auto type = callable.GetType()->GetReturnType(); + + const auto input = LocateNode(ctx.NodeLocator, callable, 0); + const auto itemArg = 
LocateExternalNode(ctx.NodeLocator, callable, 1); + const auto nextItemArg = LocateExternalNode(ctx.NodeLocator, callable, 2); + const auto newItem = LocateNode(ctx.NodeLocator, callable, 3); + + if (type->IsFlow()) { + return new TFlowMapNextWrapper(GetValueRepresentation(type), input, itemArg, nextItemArg, newItem); + } else if (type->IsStream()) { + return new TStreamMapNextWrapper(ctx.Mutables, input, itemArg, nextItemArg, newItem); + } + + THROW yexception() << "Expected flow or stream."; +} + +} +} diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.h b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.h new file mode 100644 index 0000000000..a3741c69fb --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.h @@ -0,0 +1,10 @@ +#pragma once +#include <ydb/library/yql/minikql/computation/mkql_computation_node.h> + +namespace NKikimr { +namespace NMiniKQL { + +IComputationNode* WrapMapNext(TCallable& callable, const TComputationNodeFactoryContext& ctx); + +} +} diff --git a/ydb/library/yql/minikql/comp_nodes/ya.make b/ydb/library/yql/minikql/comp_nodes/ya.make index 55838c4012..de4081e35f 100644 --- a/ydb/library/yql/minikql/comp_nodes/ya.make +++ b/ydb/library/yql/minikql/comp_nodes/ya.make @@ -110,6 +110,8 @@ SRCS( mkql_lookup.h mkql_map.cpp mkql_map.h + mkql_mapnext.cpp + mkql_mapnext.h mkql_map_join.cpp mkql_map_join.h mkql_multihopping.cpp diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index 47ae6e4c8f..ff80f0605f 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -734,6 +734,35 @@ TRuntimeNode TProgramBuilder::OrderedMap(TRuntimeNode list, const TUnaryLambda& return BuildMap(__func__, list, handler); } +TRuntimeNode TProgramBuilder::MapNext(TRuntimeNode list, const TBinaryLambda& handler) { + const auto listType = list.GetStaticType(); + MKQL_ENSURE(listType->IsStream() || listType->IsFlow(), "Expected 
stream or flow"); + + const auto itemType = listType->IsFlow() ? + AS_TYPE(TFlowType, listType)->GetItemType(): + AS_TYPE(TStreamType, listType)->GetItemType(); + + ThrowIfListOfVoid(itemType); + + TType* nextItemType = TOptionalType::Create(itemType, Env); + + const auto itemArg = Arg(itemType); + const auto nextItemArg = Arg(nextItemType); + + const auto newItem = handler(itemArg, nextItemArg); + + const auto resultListType = listType->IsFlow() ? + (TType*)TFlowType::Create(newItem.GetStaticType(), Env): + (TType*)TStreamType::Create(newItem.GetStaticType(), Env); + + TCallableBuilder callableBuilder(Env, __func__, resultListType); + callableBuilder.Add(list); + callableBuilder.Add(itemArg); + callableBuilder.Add(nextItemArg); + callableBuilder.Add(newItem); + return TRuntimeNode(callableBuilder.Build(), false); +} + template <bool Ordered> TRuntimeNode TProgramBuilder::BuildExtract(TRuntimeNode list, const std::string_view& name) { const auto listType = list.GetStaticType(); diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h index 10f1ad7ccf..0d8402d46c 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.h +++ b/ydb/library/yql/minikql/mkql_program_builder.h @@ -325,6 +325,7 @@ public: TRuntimeNode Discard(TRuntimeNode stream); TRuntimeNode Map(TRuntimeNode list, const TUnaryLambda& handler); TRuntimeNode OrderedMap(TRuntimeNode list, const TUnaryLambda& handler); + TRuntimeNode MapNext(TRuntimeNode list, const TBinaryLambda& handler); TRuntimeNode Extract(TRuntimeNode list, const std::string_view& name); TRuntimeNode OrderedExtract(TRuntimeNode list, const std::string_view& name); TRuntimeNode ChainMap(TRuntimeNode list, TRuntimeNode state, const TBinaryLambda& handler); diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index dd1f162d78..7be905346b 100644 --- 
a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -1019,6 +1019,13 @@ TMkqlCommonCallableCompiler::TShared::TShared() { }); }); + AddCallable("MapNext", [](const TExprNode& node, TMkqlBuildContext& ctx) { + const auto list = MkqlBuildExpr(node.Head(), ctx); + return ctx.ProgramBuilder.MapNext(list, [&](TRuntimeNode item, TRuntimeNode nextItem) { + return MkqlBuildLambda(node.Tail(), ctx, {item, nextItem}); + }); + }); + AddCallable("Fold1", [](const TExprNode& node, TMkqlBuildContext& ctx) { const auto list = MkqlBuildExpr(node.Head(), ctx); return ctx.ProgramBuilder.Fold1(list, [&](TRuntimeNode item) { |