aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author aneporada <aneporada@yandex-team.ru> 2022-02-16 00:37:52 +0300
committer aneporada <aneporada@yandex-team.ru> 2022-02-16 00:37:52 +0300
commit 1de3c666314929ac9afa4646b07acd822fcd8baf (patch)
tree 53d7956e1b508728f6fc967703e45ae395f6e3cb
parent 0bd0e5d1ab192e2782657cb05a36637ccb42ed57 (diff)
download ydb-1de3c666314929ac9afa4646b07acd822fcd8baf.tar.gz
[YQL-14389] Implement MapNext for stream and flow
ref:d91668ca7d35bceb13e602ecd5f711e453cb14bc
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_core.cpp 1
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_list.cpp 34
-rw-r--r-- ydb/library/yql/core/type_ann/type_ann_list.h 1
-rw-r--r-- ydb/library/yql/core/yql_expr_constraint.cpp 1
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp 2
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp 205
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/mkql_mapnext.h 10
-rw-r--r-- ydb/library/yql/minikql/comp_nodes/ya.make 2
-rw-r--r-- ydb/library/yql/minikql/mkql_program_builder.cpp 29
-rw-r--r-- ydb/library/yql/minikql/mkql_program_builder.h 1
-rw-r--r-- ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp 7
11 files changed, 293 insertions, 0 deletions
diff --git a/ydb/library/yql/core/type_ann/type_ann_core.cpp b/ydb/library/yql/core/type_ann/type_ann_core.cpp
index cf811b38eb..743692061a 100644
--- a/ydb/library/yql/core/type_ann/type_ann_core.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_core.cpp
@@ -12807,6 +12807,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["Likely"] = &BoolOpt1Wrapper;
Functions["Map"] = &MapWrapper;
Functions["OrderedMap"] = &MapWrapper;
+ Functions["MapNext"] = &MapNextWrapper;
Functions["FoldMap"] = &FoldMapWrapper;
Functions["Fold1Map"] = &Fold1MapWrapper;
Functions["Chain1Map"] = &Chain1MapWrapper;
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.cpp b/ydb/library/yql/core/type_ann/type_ann_list.cpp
index b145cde389..1fa47b7600 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.cpp
+++ b/ydb/library/yql/core/type_ann/type_ann_list.cpp
@@ -783,6 +783,40 @@ namespace {
return IGraphTransformer::TStatus::Ok;
}
+ IGraphTransformer::TStatus MapNextWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { // Type annotation for MapNext(sequence, lambda): the lambda is typed as (item, Optional<item>) and its result type becomes the item type of the annotated sequence.
+ Y_UNUSED(output); // the node is annotated in place; no replacement node is written to output
+ if (!EnsureArgsCount(*input, 2, ctx.Expr)) { // exactly two children are required: Head() is the input sequence, Tail() is the handler
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ const TTypeAnnotationNode* itemType = nullptr; // filled in by EnsureNewSeqType below
+ if (!EnsureNewSeqType<false, false>(input->Head(), ctx.Expr, &itemType)) { // checks the first child and reports its item type via itemType; the meaning of the <false, false> flags is defined by EnsureNewSeqType (not visible here)
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ const auto status = ConvertToLambda(input->TailRef(), ctx.Expr, 2); // presumably 2 is the expected lambda argument count (item, nextItem) - confirm against ConvertToLambda
+ if (status.Level != IGraphTransformer::TStatus::Ok) { // any non-Ok status is handed back to the caller unchanged
+ return status;
+ }
+
+ auto& lambda = input->TailRef();
+ const TTypeAnnotationNode* nextItemType = ctx.Expr.MakeType<TOptionalExprType>(itemType); // second lambda argument is Optional<item>
+ if (!UpdateLambdaAllArgumentsTypes(lambda, {itemType, nextItemType}, ctx.Expr)) { // assigns types to both lambda arguments
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ if (!lambda->GetTypeAnn()) { // lambda has no type annotation yet: ask for another pass
+ return IGraphTransformer::TStatus::Repeat;
+ }
+
+ if (!EnsureComputableType(lambda->Pos(), *lambda->GetTypeAnn(), ctx.Expr)) { // the lambda's type must pass the computable-type check
+ return IGraphTransformer::TStatus::Error;
+ }
+
+ input->SetTypeAnn(MakeSequenceType(input->Head().GetTypeAnn()->GetKind(), *lambda->GetTypeAnn(), ctx.Expr)); // result: same type kind as the input, item type = the lambda's type
+ return IGraphTransformer::TStatus::Ok;
+ }
+
IGraphTransformer::TStatus LMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Y_UNUSED(output);
if (!EnsureArgsCount(*input, 2, ctx.Expr)) {
diff --git a/ydb/library/yql/core/type_ann/type_ann_list.h b/ydb/library/yql/core/type_ann/type_ann_list.h
index 608b813fe8..c56b4879ee 100644
--- a/ydb/library/yql/core/type_ann/type_ann_list.h
+++ b/ydb/library/yql/core/type_ann/type_ann_list.h
@@ -14,6 +14,7 @@ namespace NTypeAnnImpl {
template <bool InverseCondition>
IGraphTransformer::TStatus InclusiveFilterWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus MapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
+ IGraphTransformer::TStatus MapNextWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus LMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
template <bool Warn>
IGraphTransformer::TStatus FlatMapWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
diff --git a/ydb/library/yql/core/yql_expr_constraint.cpp b/ydb/library/yql/core/yql_expr_constraint.cpp
index 78c8f68776..25a20effd7 100644
--- a/ydb/library/yql/core/yql_expr_constraint.cpp
+++ b/ydb/library/yql/core/yql_expr_constraint.cpp
@@ -124,6 +124,7 @@ public:
Functions["WideFilter"] = &TCallableConstraintTransformer::FilterWrap<true>;
Functions["OrderedMap"] = &TCallableConstraintTransformer::MapWrap<true, false>;
Functions["Map"] = &TCallableConstraintTransformer::MapWrap<false, false>;
+ Functions["MapNext"] = &TCallableConstraintTransformer::MapWrap<true, false>;
Functions["OrderedFlatMap"] = &TCallableConstraintTransformer::MapWrap<true, true>;
Functions["FlatMap"] = &TCallableConstraintTransformer::MapWrap<false, true>;
Functions["OrderedMultiMap"] = &TCallableConstraintTransformer::MapWrap<true, false>;
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
index 57599acb64..ae2b84405b 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
@@ -50,6 +50,7 @@
#include "mkql_logical.h"
#include "mkql_lookup.h"
#include "mkql_map.h"
+#include "mkql_mapnext.h"
#include "mkql_map_join.h"
#include "mkql_multihopping.h"
#include "mkql_multimap.h"
@@ -144,6 +145,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller {
{"Fold1", &WrapFold1},
{"Map", &WrapMap},
{"OrderedMap", &WrapMap},
+ {"MapNext", &WrapMapNext},
{"MultiMap", &WrapMultiMap},
{"FlatMap", &WrapFlatMap},
{"OrderedFlatMap", &WrapFlatMap},
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp
new file mode 100644
index 0000000000..5d0b470472
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp
@@ -0,0 +1,205 @@
+#include "mkql_mapnext.h"
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+namespace {
+
+class TFlowMapNextWrapper : public TStatelessFlowComputationNode<TFlowMapNextWrapper> { // Flow variant of MapNext: evaluates NewItem once per input item with Item = that item and NextItem = the item that follows it (the empty value for the last item).
+ typedef TStatelessFlowComputationNode<TFlowMapNextWrapper> TBaseComputation;
+public:
+ TFlowMapNextWrapper(EValueRepresentation kind, IComputationNode* flow, // flow: input flow; kind is forwarded to the base class
+ IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem) // item/nextItem: nodes this class writes via SetValue; newItem: node evaluated to produce each output
+ : TBaseComputation(flow, kind)
+ , Flow(flow)
+ , Item(item)
+ , NextItem(nextItem)
+ , NewItem(newItem)
+ {}
+
+ NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const { // Returns one output value per call, or forwards Yield / returns Finish.
+ NUdf::TUnboxedValue result;
+ for (;;) {
+ if (Finish) { // input already reported Finish: flush the buffered last item, then report Finish
+ if (!Prev) { // nothing buffered (empty input, or last item already emitted)
+ return NUdf::TUnboxedValuePod::MakeFinish();
+ }
+ Item->SetValue(ctx, std::move(*Prev));
+ Prev.reset(); // the next call will take the Finish branch above
+ NextItem->SetValue(ctx, NUdf::TUnboxedValuePod()); // default-constructed value: the last item has no successor
+ return NewItem->GetValue(ctx);
+ }
+
+ auto item = Flow->GetValue(ctx); // pull the next value from the input flow
+ if (item.IsYield()) { // Yield is forwarded; Prev stays buffered for the next call
+ return item;
+ }
+
+ if (item.IsFinish()) { // remember end of input and loop once more to flush Prev
+ Finish = true;
+ continue;
+ }
+
+ if (!Prev) { // first item seen: only buffer it, nothing is emitted until the following fetch
+ Prev = std::move(item);
+ continue;
+ }
+
+ Item->SetValue(ctx, std::move(*Prev)); // current = previously buffered item
+ Prev = item; // copy, because item is still needed for NextItem on the next line
+ NextItem->SetValue(ctx, std::move(item)); // NOTE(review): the successor is passed without an explicit optional wrap while "no successor" is the empty value above - confirm the two cannot be confused when an item is itself an empty optional
+ result = NewItem->GetValue(ctx);
+ break;
+ }
+
+ return result;
+ }
+
+private:
+ void RegisterDependencies() const final { // registers Flow, the two argument nodes and NewItem with the base-class dependency helpers
+ if (const auto flow = FlowDependsOn(Flow)) {
+ Own(flow, Item);
+ Own(flow, NextItem);
+ DependsOn(flow, NewItem);
+ }
+ }
+
+ IComputationNode* const Flow; // input flow
+ IComputationExternalNode* const Item; // receives the current item
+ IComputationExternalNode* const NextItem; // receives the following item, or the empty value
+ IComputationNode* const NewItem; // evaluated after Item/NextItem are set
+ mutable std::optional<NUdf::TUnboxedValue> Prev; // item buffered until the following fetch. NOTE(review): Prev and Finish are stored in the node and mutated from const DoCalculate (not kept in TComputationContext) although the base class is TStatelessFlowComputationNode - confirm a node instance is never evaluated with more than one context or restarted
+ mutable bool Finish = false; // set once the input flow has returned Finish
+};
+
+class TStreamMapNextWrapper : public TMutableComputationNode<TStreamMapNextWrapper> { // Stream variant of MapNext: DoCalculate wraps the input stream into a TStreamValue whose Fetch pairs every item with the item that follows it.
+ typedef TMutableComputationNode<TStreamMapNextWrapper> TBaseComputation;
+public:
+ TStreamMapNextWrapper(TComputationMutables& mutables, IComputationNode* stream, // stream: node producing the input stream
+ IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem) // item/nextItem: nodes written via SetValue in Fetch; newItem: node evaluated to produce each output
+ : TBaseComputation(mutables)
+ , Stream(stream)
+ , Item(item)
+ , NextItem(nextItem)
+ , NewItem(newItem)
+ {}
+
+ NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { // evaluates the input stream node and wraps the result; the per-item work happens in TStreamValue::Fetch
+ return ctx.HolderFactory.Create<TStreamValue>(ctx, Stream->GetValue(ctx), Item, NextItem, NewItem);
+ }
+
+private:
+ void RegisterDependencies() const final { // registers the input stream, the two argument nodes and NewItem with the base-class dependency helpers
+ DependsOn(Stream);
+ Own(Item);
+ Own(NextItem);
+ DependsOn(NewItem);
+ }
+
+ class TStreamValue : public TComputationValue<TStreamValue> { // the stream value returned by DoCalculate; holds the one-item look-ahead buffer
+ public:
+ using TBase = TComputationValue<TStreamValue>;
+
+ TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& stream, // stream: the already evaluated input stream value
+ IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem)
+ : TBase(memInfo)
+ , CompCtx(compCtx)
+ , Stream(std::move(stream))
+ , Item(item)
+ , NextItem(nextItem)
+ , NewItem(newItem)
+ {
+ }
+
+ private:
+ ui32 GetTraverseCount() const final { // exactly one traversable child: the wrapped input stream
+ return 1U;
+ }
+
+ NUdf::TUnboxedValue GetTraverseItem(ui32) const final { // the index is ignored; always returns the wrapped stream
+ return Stream;
+ }
+
+ NUdf::TUnboxedValue Save() const final { // NOTE(review): Save returns Zero() and Load is a no-op, so Prev and Finish are not part of any saved state - confirm this value never has to be restored mid-stream
+ return NUdf::TUnboxedValuePod::Zero();
+ }
+
+ void Load(const NUdf::TStringRef&) final {}
+
+ NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final { // Writes one output value to result and returns Ok, or returns Yield/Finish without touching result.
+ for (;;) {
+ if (Finish) { // input already reported Finish: flush the buffered last item, then report Finish
+ if (!Prev) { // nothing buffered (empty input, or last item already emitted)
+ return NUdf::EFetchStatus::Finish;
+ }
+ Item->SetValue(CompCtx, std::move(*Prev));
+ Prev.reset(); // the next Fetch will return Finish
+ NextItem->SetValue(CompCtx, NUdf::TUnboxedValuePod()); // default-constructed value: the last item has no successor
+
+ result = NewItem->GetValue(CompCtx);
+ return NUdf::EFetchStatus::Ok;
+ }
+
+ NUdf::TUnboxedValue item;
+ const auto status = Stream.Fetch(item); // pull the next value from the input stream
+ if (status == NUdf::EFetchStatus::Yield) { // Yield is forwarded; Prev stays buffered for the next Fetch
+ return status;
+ }
+
+ if (status == NUdf::EFetchStatus::Finish) { // remember end of input and loop once more to flush Prev
+ Finish = true;
+ continue;
+ }
+
+ if (!Prev) { // first item seen: only buffer it, nothing is emitted until the following fetch
+ Prev = std::move(item);
+ continue;
+ }
+
+ Item->SetValue(CompCtx, std::move(*Prev)); // current = previously buffered item
+ Prev = item; // copy, because item is still needed for NextItem on the next line
+ NextItem->SetValue(CompCtx, std::move(item)); // NOTE(review): the successor is passed without an explicit optional wrap while "no successor" is the empty value above - confirm the two cannot be confused when an item is itself an empty optional
+ result = NewItem->GetValue(CompCtx);
+ break;
+ }
+ return NUdf::EFetchStatus::Ok;
+ }
+
+ TComputationContext& CompCtx; // context passed in at construction and used for every SetValue/GetValue in Fetch
+ const NUdf::TUnboxedValue Stream; // wrapped input stream
+ IComputationExternalNode* const Item; // receives the current item
+ IComputationExternalNode* const NextItem; // receives the following item, or the empty value
+ IComputationNode* const NewItem; // evaluated after Item/NextItem are set
+ std::optional<NUdf::TUnboxedValue> Prev; // item buffered until the following fetch
+ bool Finish = false; // set once the input stream has returned Finish
+ };
+
+ IComputationNode* const Stream; // node producing the input stream
+ IComputationExternalNode* const Item;
+ IComputationExternalNode* const NextItem;
+ IComputationNode* const NewItem;
+};
+
+}
+
+IComputationNode* WrapMapNext(TCallable& callable, const TComputationNodeFactoryContext& ctx) { // Factory for the MapNext callable: returns the flow or the stream implementation depending on the callable's return type.
+ MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args, got " << callable.GetInputsCount()); // inputs are located by index below: 0 input, 1 item arg, 2 next-item arg, 3 new item
+ const auto type = callable.GetType()->GetReturnType(); // the return type selects the implementation
+
+ const auto input = LocateNode(ctx.NodeLocator, callable, 0);
+ const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 1);
+ const auto nextItemArg = LocateExternalNode(ctx.NodeLocator, callable, 2);
+ const auto newItem = LocateNode(ctx.NodeLocator, callable, 3);
+
+ if (type->IsFlow()) {
+ return new TFlowMapNextWrapper(GetValueRepresentation(type), input, itemArg, nextItemArg, newItem);
+ } else if (type->IsStream()) {
+ return new TStreamMapNextWrapper(ctx.Mutables, input, itemArg, nextItemArg, newItem);
+ }
+
+ THROW yexception() << "Expected flow or stream."; // reached only when the return type is neither a flow nor a stream
+}
+
+}
+}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.h b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.h
new file mode 100644
index 0000000000..a3741c69fb
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.h
@@ -0,0 +1,10 @@
+#pragma once
+#include <ydb/library/yql/minikql/computation/mkql_computation_node.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+IComputationNode* WrapMapNext(TCallable& callable, const TComputationNodeFactoryContext& ctx); // Creates the computation node for the MapNext callable; defined in mkql_mapnext.cpp.
+
+}
+}
diff --git a/ydb/library/yql/minikql/comp_nodes/ya.make b/ydb/library/yql/minikql/comp_nodes/ya.make
index 55838c4012..de4081e35f 100644
--- a/ydb/library/yql/minikql/comp_nodes/ya.make
+++ b/ydb/library/yql/minikql/comp_nodes/ya.make
@@ -110,6 +110,8 @@ SRCS(
mkql_lookup.h
mkql_map.cpp
mkql_map.h
+ mkql_mapnext.cpp
+ mkql_mapnext.h
mkql_map_join.cpp
mkql_map_join.h
mkql_multihopping.cpp
diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp
index 47ae6e4c8f..ff80f0605f 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_program_builder.cpp
@@ -734,6 +734,35 @@ TRuntimeNode TProgramBuilder::OrderedMap(TRuntimeNode list, const TUnaryLambda&
return BuildMap(__func__, list, handler);
}
+TRuntimeNode TProgramBuilder::MapNext(TRuntimeNode list, const TBinaryLambda& handler) { // Builds a MapNext callable over a stream or flow; handler receives (item, Optional<item>) argument nodes and returns the node computing the new item.
+ const auto listType = list.GetStaticType();
+ MKQL_ENSURE(listType->IsStream() || listType->IsFlow(), "Expected stream or flow"); // despite the parameter name, only streams and flows are accepted
+
+ const auto itemType = listType->IsFlow() ?
+ AS_TYPE(TFlowType, listType)->GetItemType():
+ AS_TYPE(TStreamType, listType)->GetItemType();
+
+ ThrowIfListOfVoid(itemType); // presumably rejects Void item types - confirm against ThrowIfListOfVoid
+
+ TType* nextItemType = TOptionalType::Create(itemType, Env); // the second handler argument is typed Optional<item>
+
+ const auto itemArg = Arg(itemType); // argument nodes that the handler body refers to
+ const auto nextItemArg = Arg(nextItemType);
+
+ const auto newItem = handler(itemArg, nextItemArg); // invoke the handler once to obtain the new-item node
+
+ const auto resultListType = listType->IsFlow() ? // result is a flow or stream (same as the input) of the handler's result type
+ (TType*)TFlowType::Create(newItem.GetStaticType(), Env):
+ (TType*)TStreamType::Create(newItem.GetStaticType(), Env);
+
+ TCallableBuilder callableBuilder(Env, __func__, resultListType); // callable name comes from __func__, i.e. "MapNext"
+ callableBuilder.Add(list); // inputs are added in this order: input, item arg, next-item arg, new item
+ callableBuilder.Add(itemArg);
+ callableBuilder.Add(nextItemArg);
+ callableBuilder.Add(newItem);
+ return TRuntimeNode(callableBuilder.Build(), false); // meaning of the false flag is defined by TRuntimeNode (not visible here)
+}
+
template <bool Ordered>
TRuntimeNode TProgramBuilder::BuildExtract(TRuntimeNode list, const std::string_view& name) {
const auto listType = list.GetStaticType();
diff --git a/ydb/library/yql/minikql/mkql_program_builder.h b/ydb/library/yql/minikql/mkql_program_builder.h
index 10f1ad7ccf..0d8402d46c 100644
--- a/ydb/library/yql/minikql/mkql_program_builder.h
+++ b/ydb/library/yql/minikql/mkql_program_builder.h
@@ -325,6 +325,7 @@ public:
TRuntimeNode Discard(TRuntimeNode stream);
TRuntimeNode Map(TRuntimeNode list, const TUnaryLambda& handler);
TRuntimeNode OrderedMap(TRuntimeNode list, const TUnaryLambda& handler);
+ TRuntimeNode MapNext(TRuntimeNode list, const TBinaryLambda& handler);
TRuntimeNode Extract(TRuntimeNode list, const std::string_view& name);
TRuntimeNode OrderedExtract(TRuntimeNode list, const std::string_view& name);
TRuntimeNode ChainMap(TRuntimeNode list, TRuntimeNode state, const TBinaryLambda& handler);
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
index dd1f162d78..7be905346b 100644
--- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
+++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp
@@ -1019,6 +1019,13 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
});
});
+ AddCallable("MapNext", [](const TExprNode& node, TMkqlBuildContext& ctx) { // Compiles a MapNext expression node through TProgramBuilder::MapNext.
+ const auto list = MkqlBuildExpr(node.Head(), ctx); // first child: the input
+ return ctx.ProgramBuilder.MapNext(list, [&](TRuntimeNode item, TRuntimeNode nextItem) { // [&] captures node and ctx for use while the handler is invoked
+ return MkqlBuildLambda(node.Tail(), ctx, {item, nextItem}); // last child: the lambda, built with the two argument nodes
+ });
+ });
+
AddCallable("Fold1", [](const TExprNode& node, TMkqlBuildContext& ctx) {
const auto list = MkqlBuildExpr(node.Head(), ctx);
return ctx.ProgramBuilder.Fold1(list, [&](TRuntimeNode item) {