aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoraneporada <aneporada@yandex-team.ru>2022-02-17 10:16:20 +0300
committeraneporada <aneporada@yandex-team.ru>2022-02-17 10:16:20 +0300
commit5044e948ba652d71055597d8986592c5cd46a509 (patch)
tree3d26e147878e151b7d972a09ecd28b8f9570324a
parent238dcee0609b29afef350ad1ec1f11a5f77f3ddb (diff)
downloadydb-5044e948ba652d71055597d8986592c5cd46a509.tar.gz
[YQL-14389] Fix handling state of MapNext, add unit test
ref:d9c4908b47be7e602ed7779bc9ca3445fd1450ee
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp81
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_mapnext_ut.cpp172
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/ya.make1
3 files changed, 224 insertions, 30 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp
index 5d0b4704727..2bbc22173a5 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp
@@ -6,27 +6,39 @@ namespace NMiniKQL {
namespace {
-class TFlowMapNextWrapper : public TStatelessFlowComputationNode<TFlowMapNextWrapper> {
- typedef TStatelessFlowComputationNode<TFlowMapNextWrapper> TBaseComputation;
+struct TState : public TComputationValue<TState> {
+ using TComputationValue::TComputationValue;
+
+ std::optional<NUdf::TUnboxedValue> Prev;
+ bool Finish = false;
+};
+
+class TFlowMapNextWrapper : public TStatefulFlowComputationNode<TFlowMapNextWrapper> {
+ typedef TStatefulFlowComputationNode<TFlowMapNextWrapper> TBaseComputation;
public:
- TFlowMapNextWrapper(EValueRepresentation kind, IComputationNode* flow,
+ TFlowMapNextWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow,
IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem)
- : TBaseComputation(flow, kind)
+ : TBaseComputation(mutables, flow, kind, EValueRepresentation::Any)
, Flow(flow)
, Item(item)
, NextItem(nextItem)
, NewItem(newItem)
{}
- NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
+ NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue& stateValue, TComputationContext& ctx) const {
+ if (!stateValue.HasValue()) {
+ stateValue = ctx.HolderFactory.Create<TState>();
+ }
+ TState& state = *static_cast<TState*>(stateValue.AsBoxed().Get());
+
NUdf::TUnboxedValue result;
for (;;) {
- if (Finish) {
- if (!Prev) {
+ if (state.Finish) {
+ if (!state.Prev) {
return NUdf::TUnboxedValuePod::MakeFinish();
}
- Item->SetValue(ctx, std::move(*Prev));
- Prev.reset();
+ Item->SetValue(ctx, std::move(*state.Prev));
+ state.Prev.reset();
NextItem->SetValue(ctx, NUdf::TUnboxedValuePod());
return NewItem->GetValue(ctx);
}
@@ -37,17 +49,17 @@ public:
}
if (item.IsFinish()) {
- Finish = true;
+ state.Finish = true;
continue;
}
- if (!Prev) {
- Prev = std::move(item);
+ if (!state.Prev) {
+ state.Prev = std::move(item);
continue;
}
- Item->SetValue(ctx, std::move(*Prev));
- Prev = item;
+ Item->SetValue(ctx, std::move(*state.Prev));
+ state.Prev = item;
NextItem->SetValue(ctx, std::move(item));
result = NewItem->GetValue(ctx);
break;
@@ -69,8 +81,6 @@ private:
IComputationExternalNode* const Item;
IComputationExternalNode* const NextItem;
IComputationNode* const NewItem;
- mutable std::optional<NUdf::TUnboxedValue> Prev;
- mutable bool Finish = false;
};
class TStreamMapNextWrapper : public TMutableComputationNode<TStreamMapNextWrapper> {
@@ -83,10 +93,11 @@ public:
, Item(item)
, NextItem(nextItem)
, NewItem(newItem)
+ , StateIndex(mutables.CurValueIndex++)
{}
NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
- return ctx.HolderFactory.Create<TStreamValue>(ctx, Stream->GetValue(ctx), Item, NextItem, NewItem);
+ return ctx.HolderFactory.Create<TStreamValue>(ctx, Stream->GetValue(ctx), Item, NextItem, NewItem, StateIndex);
}
private:
@@ -102,13 +113,14 @@ private:
using TBase = TComputationValue<TStreamValue>;
TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& stream,
- IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem)
+ IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem, ui32 stateIndex)
: TBase(memInfo)
, CompCtx(compCtx)
, Stream(std::move(stream))
, Item(item)
, NextItem(nextItem)
, NewItem(newItem)
+ , StateIndex(stateIndex)
{
}
@@ -128,13 +140,14 @@ private:
void Load(const NUdf::TStringRef&) final {}
NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
+ auto& state = GetState();
for (;;) {
- if (Finish) {
- if (!Prev) {
+ if (state.Finish) {
+ if (!state.Prev) {
return NUdf::EFetchStatus::Finish;
}
- Item->SetValue(CompCtx, std::move(*Prev));
- Prev.reset();
+ Item->SetValue(CompCtx, std::move(*state.Prev));
+ state.Prev.reset();
NextItem->SetValue(CompCtx, NUdf::TUnboxedValuePod());
result = NewItem->GetValue(CompCtx);
@@ -148,17 +161,17 @@ private:
}
if (status == NUdf::EFetchStatus::Finish) {
- Finish = true;
+ state.Finish = true;
continue;
}
- if (!Prev) {
- Prev = std::move(item);
+ if (!state.Prev) {
+ state.Prev = std::move(item);
continue;
}
- Item->SetValue(CompCtx, std::move(*Prev));
- Prev = item;
+ Item->SetValue(CompCtx, std::move(*state.Prev));
+ state.Prev = item;
NextItem->SetValue(CompCtx, std::move(item));
result = NewItem->GetValue(CompCtx);
break;
@@ -166,19 +179,27 @@ private:
return NUdf::EFetchStatus::Ok;
}
+ TState& GetState() const {
+ auto& result = CompCtx.MutableValues[StateIndex];
+ if (!result.HasValue()) {
+ result = CompCtx.HolderFactory.Create<TState>();
+ }
+ return *static_cast<TState*>(result.AsBoxed().Get());
+ }
+
TComputationContext& CompCtx;
const NUdf::TUnboxedValue Stream;
IComputationExternalNode* const Item;
IComputationExternalNode* const NextItem;
IComputationNode* const NewItem;
- std::optional<NUdf::TUnboxedValue> Prev;
- bool Finish = false;
+ const ui32 StateIndex;
};
IComputationNode* const Stream;
IComputationExternalNode* const Item;
IComputationExternalNode* const NextItem;
IComputationNode* const NewItem;
+ const ui32 StateIndex;
};
}
@@ -193,7 +214,7 @@ IComputationNode* WrapMapNext(TCallable& callable, const TComputationNodeFactory
const auto newItem = LocateNode(ctx.NodeLocator, callable, 3);
if (type->IsFlow()) {
- return new TFlowMapNextWrapper(GetValueRepresentation(type), input, itemArg, nextItemArg, newItem);
+ return new TFlowMapNextWrapper(ctx.Mutables, GetValueRepresentation(type), input, itemArg, nextItemArg, newItem);
} else if (type->IsStream()) {
return new TStreamMapNextWrapper(ctx.Mutables, input, itemArg, nextItemArg, newItem);
}
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_mapnext_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_mapnext_ut.cpp
new file mode 100644
index 00000000000..a416803e447
--- /dev/null
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_mapnext_ut.cpp
@@ -0,0 +1,172 @@
+#include "mkql_computation_node_ut.h"
+#include <ydb/library/yql/minikql/mkql_runtime_version.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+Y_UNIT_TEST_SUITE(TMiniKQLMapNextTest) {
+ Y_UNIT_TEST_LLVM(OverStream) {
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto data1 = pb.NewDataLiteral<ui16>(10);
+ const auto data2 = pb.NewDataLiteral<ui16>(20);
+ const auto data3 = pb.NewDataLiteral<ui16>(30);
+
+ const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id);
+ const auto optDataType = pb.NewOptionalType(dataType);
+ const auto tupleType = pb.NewTupleType({dataType, optDataType});
+
+ const auto list = pb.NewList(dataType, {data1, data2, data3});
+ const auto pgmReturn = pb.MapNext(pb.Iterator(list, {}),
+ [&](TRuntimeNode item, TRuntimeNode nextItem) {
+ return pb.NewTuple(tupleType, {item, nextItem});
+ });
+
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue();
+ NUdf::TUnboxedValue item;
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 10);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).template Get<ui16>(), 20);
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 20);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).template Get<ui16>(), 30);
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 30);
+ UNIT_ASSERT(!item.GetElement(1).HasValue());
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ }
+
+ Y_UNIT_TEST_LLVM(OverSingleElementStream) {
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto data1 = pb.NewDataLiteral<ui16>(10);
+
+ const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id);
+ const auto optDataType = pb.NewOptionalType(dataType);
+ const auto tupleType = pb.NewTupleType({dataType, optDataType});
+
+ const auto list = pb.NewList(dataType, {data1});
+ const auto pgmReturn = pb.MapNext(pb.Iterator(list, {}),
+ [&](TRuntimeNode item, TRuntimeNode nextItem) {
+ return pb.NewTuple(tupleType, {item, nextItem});
+ });
+
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue();
+ NUdf::TUnboxedValue item;
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 10);
+ UNIT_ASSERT(!item.GetElement(1).HasValue());
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ }
+
+ Y_UNIT_TEST_LLVM(OverEmptyStream) {
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id);
+ const auto optDataType = pb.NewOptionalType(dataType);
+ const auto tupleType = pb.NewTupleType({dataType, optDataType});
+
+ const auto list = pb.NewList(dataType, {});
+ const auto pgmReturn = pb.MapNext(pb.Iterator(list, {}),
+ [&](TRuntimeNode item, TRuntimeNode nextItem) {
+ return pb.NewTuple(tupleType, {item, nextItem});
+ });
+
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue();
+ NUdf::TUnboxedValue item;
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ }
+
+ Y_UNIT_TEST_LLVM(OverFlow) {
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto data1 = pb.NewDataLiteral<ui16>(10);
+ const auto data2 = pb.NewDataLiteral<ui16>(20);
+ const auto data3 = pb.NewDataLiteral<ui16>(30);
+
+ const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id);
+ const auto optDataType = pb.NewOptionalType(dataType);
+ const auto tupleType = pb.NewTupleType({dataType, optDataType});
+
+ const auto list = pb.NewList(dataType, {data1, data2, data3});
+ const auto pgmReturn = pb.FromFlow(pb.MapNext(pb.ToFlow(pb.Iterator(list, {})),
+ [&](TRuntimeNode item, TRuntimeNode nextItem) {
+ return pb.NewTuple(tupleType, {item, nextItem});
+ }));
+
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue();
+ NUdf::TUnboxedValue item;
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 10);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).template Get<ui16>(), 20);
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 20);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).template Get<ui16>(), 30);
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 30);
+ UNIT_ASSERT(!item.GetElement(1).HasValue());
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ }
+
+ Y_UNIT_TEST_LLVM(OverSingleElementFlow) {
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto data1 = pb.NewDataLiteral<ui16>(10);
+
+ const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id);
+ const auto optDataType = pb.NewOptionalType(dataType);
+ const auto tupleType = pb.NewTupleType({dataType, optDataType});
+
+ const auto list = pb.NewList(dataType, {data1});
+ const auto pgmReturn = pb.FromFlow(pb.MapNext(pb.ToFlow(pb.Iterator(list, {})),
+ [&](TRuntimeNode item, TRuntimeNode nextItem) {
+ return pb.NewTuple(tupleType, {item, nextItem});
+ }));
+
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue();
+ NUdf::TUnboxedValue item;
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 10);
+ UNIT_ASSERT(!item.GetElement(1).HasValue());
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ }
+
+ Y_UNIT_TEST_LLVM(OverEmptyFlow) {
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id);
+ const auto optDataType = pb.NewOptionalType(dataType);
+ const auto tupleType = pb.NewTupleType({dataType, optDataType});
+
+ const auto list = pb.NewList(dataType, {});
+ const auto pgmReturn = pb.FromFlow(pb.MapNext(pb.ToFlow(pb.Iterator(list, {})),
+ [&](TRuntimeNode item, TRuntimeNode nextItem) {
+ return pb.NewTuple(tupleType, {item, nextItem});
+ }));
+
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue();
+ NUdf::TUnboxedValue item;
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item));
+ }
+}
+
+}
+}
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/ya.make b/ydb/library/yql/minikql/comp_nodes/ut/ya.make
index dffda1318f3..7c7505394c2 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/ya.make
+++ b/ydb/library/yql/minikql/comp_nodes/ut/ya.make
@@ -52,6 +52,7 @@ SRCS(
mkql_wide_map_ut.cpp
mkql_wide_nodes_ut.cpp
mkql_listfromrange_ut.cpp
+ mkql_mapnext_ut.cpp
)
PEERDIR(