diff options
author | aneporada <aneporada@yandex-team.ru> | 2022-02-17 10:16:20 +0300 |
---|---|---|
committer | aneporada <aneporada@yandex-team.ru> | 2022-02-17 10:16:20 +0300 |
commit | 5044e948ba652d71055597d8986592c5cd46a509 (patch) | |
tree | 3d26e147878e151b7d972a09ecd28b8f9570324a | |
parent | 238dcee0609b29afef350ad1ec1f11a5f77f3ddb (diff) | |
download | ydb-5044e948ba652d71055597d8986592c5cd46a509.tar.gz |
[YQL-14389] Fix handling state of MapNext, add unit test
ref:d9c4908b47be7e602ed7779bc9ca3445fd1450ee
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp | 81 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/ut/mkql_mapnext_ut.cpp | 172 | ||||
-rw-r--r-- | ydb/library/yql/minikql/comp_nodes/ut/ya.make | 1 |
3 files changed, 224 insertions, 30 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp index 5d0b4704727..2bbc22173a5 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_mapnext.cpp @@ -6,27 +6,39 @@ namespace NMiniKQL { namespace { -class TFlowMapNextWrapper : public TStatelessFlowComputationNode<TFlowMapNextWrapper> { - typedef TStatelessFlowComputationNode<TFlowMapNextWrapper> TBaseComputation; +struct TState : public TComputationValue<TState> { + using TComputationValue::TComputationValue; + + std::optional<NUdf::TUnboxedValue> Prev; + bool Finish = false; +}; + +class TFlowMapNextWrapper : public TStatefulFlowComputationNode<TFlowMapNextWrapper> { + typedef TStatefulFlowComputationNode<TFlowMapNextWrapper> TBaseComputation; public: - TFlowMapNextWrapper(EValueRepresentation kind, IComputationNode* flow, + TFlowMapNextWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow, IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem) - : TBaseComputation(flow, kind) + : TBaseComputation(mutables, flow, kind, EValueRepresentation::Any) , Flow(flow) , Item(item) , NextItem(nextItem) , NewItem(newItem) {} - NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const { + NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue& stateValue, TComputationContext& ctx) const { + if (!stateValue.HasValue()) { + stateValue = ctx.HolderFactory.Create<TState>(); + } + TState& state = *static_cast<TState*>(stateValue.AsBoxed().Get()); + NUdf::TUnboxedValue result; for (;;) { - if (Finish) { - if (!Prev) { + if (state.Finish) { + if (!state.Prev) { return NUdf::TUnboxedValuePod::MakeFinish(); } - Item->SetValue(ctx, std::move(*Prev)); - Prev.reset(); + Item->SetValue(ctx, std::move(*state.Prev)); + state.Prev.reset(); NextItem->SetValue(ctx, NUdf::TUnboxedValuePod()); return NewItem->GetValue(ctx); } @@ -37,17 +49,17 @@ public: } if (item.IsFinish()) { - Finish = true; + state.Finish = true; continue; } - if (!Prev) { - Prev = std::move(item); + if (!state.Prev) { + state.Prev = std::move(item); continue; } - Item->SetValue(ctx, std::move(*Prev)); - Prev = item; + Item->SetValue(ctx, std::move(*state.Prev)); + state.Prev = item; NextItem->SetValue(ctx, std::move(item)); result = NewItem->GetValue(ctx); break; @@ -69,8 +81,6 @@ private: IComputationExternalNode* const Item; IComputationExternalNode* const NextItem; IComputationNode* const NewItem; - mutable std::optional<NUdf::TUnboxedValue> Prev; - mutable bool Finish = false; }; class TStreamMapNextWrapper : public TMutableComputationNode<TStreamMapNextWrapper> { @@ -83,10 +93,11 @@ public: , Item(item) , NextItem(nextItem) , NewItem(newItem) + , StateIndex(mutables.CurValueIndex++) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { - return ctx.HolderFactory.Create<TStreamValue>(ctx, Stream->GetValue(ctx), Item, NextItem, NewItem); + return ctx.HolderFactory.Create<TStreamValue>(ctx, Stream->GetValue(ctx), Item, NextItem, NewItem, StateIndex); } private: @@ -102,13 +113,14 @@ private: using TBase = TComputationValue<TStreamValue>; TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& stream, - IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem) + IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem, ui32 stateIndex) : TBase(memInfo) , CompCtx(compCtx) , Stream(std::move(stream)) , Item(item) , NextItem(nextItem) , NewItem(newItem) + , StateIndex(stateIndex) { } @@ -128,13 +140,14 @@ private: void Load(const NUdf::TStringRef&) final {} NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final { + auto& state = GetState(); for (;;) { - if (Finish) { - if (!Prev) { + if (state.Finish) { + if (!state.Prev) { return NUdf::EFetchStatus::Finish; } - Item->SetValue(CompCtx, std::move(*Prev)); - Prev.reset(); + Item->SetValue(CompCtx, std::move(*state.Prev)); + state.Prev.reset(); NextItem->SetValue(CompCtx, NUdf::TUnboxedValuePod()); result = NewItem->GetValue(CompCtx); @@ -148,17 +161,17 @@ private: } if (status == NUdf::EFetchStatus::Finish) { - Finish = true; + state.Finish = true; continue; } - if (!Prev) { - Prev = std::move(item); + if (!state.Prev) { + state.Prev = std::move(item); continue; } - Item->SetValue(CompCtx, std::move(*Prev)); - Prev = item; + Item->SetValue(CompCtx, std::move(*state.Prev)); + state.Prev = item; NextItem->SetValue(CompCtx, std::move(item)); result = NewItem->GetValue(CompCtx); break; @@ -166,19 +179,27 @@ private: return NUdf::EFetchStatus::Ok; } + TState& GetState() const { + auto& result = CompCtx.MutableValues[StateIndex]; + if (!result.HasValue()) { + result = CompCtx.HolderFactory.Create<TState>(); + } + return *static_cast<TState*>(result.AsBoxed().Get()); + } + TComputationContext& CompCtx; const NUdf::TUnboxedValue Stream; IComputationExternalNode* const Item; IComputationExternalNode* const NextItem; IComputationNode* const NewItem; - std::optional<NUdf::TUnboxedValue> Prev; - bool Finish = false; + const ui32 StateIndex; }; IComputationNode* const Stream; IComputationExternalNode* const Item; IComputationExternalNode* const NextItem; IComputationNode* const NewItem; + const ui32 StateIndex; }; } @@ -193,7 +214,7 @@ IComputationNode* WrapMapNext(TCallable& callable, const TComputationNodeFactory const auto newItem = LocateNode(ctx.NodeLocator, callable, 3); if (type->IsFlow()) { - return new TFlowMapNextWrapper(GetValueRepresentation(type), input, itemArg, nextItemArg, newItem); + return new TFlowMapNextWrapper(ctx.Mutables, GetValueRepresentation(type), input, itemArg, nextItemArg, newItem); } else if (type->IsStream()) { return new TStreamMapNextWrapper(ctx.Mutables, input, itemArg, nextItemArg, newItem); } diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_mapnext_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_mapnext_ut.cpp new file mode 100644 index 00000000000..a416803e447 --- /dev/null +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_mapnext_ut.cpp @@ -0,0 +1,172 @@ +#include "mkql_computation_node_ut.h" +#include <ydb/library/yql/minikql/mkql_runtime_version.h> + +namespace NKikimr { +namespace NMiniKQL { + +Y_UNIT_TEST_SUITE(TMiniKQLMapNextTest) { + Y_UNIT_TEST_LLVM(OverStream) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto data1 = pb.NewDataLiteral<ui16>(10); + const auto data2 = pb.NewDataLiteral<ui16>(20); + const auto data3 = pb.NewDataLiteral<ui16>(30); + + const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id); + const auto optDataType = pb.NewOptionalType(dataType); + const auto tupleType = pb.NewTupleType({dataType, optDataType}); + + const auto list = pb.NewList(dataType, {data1, data2, data3}); + const auto pgmReturn = pb.MapNext(pb.Iterator(list, {}), + [&](TRuntimeNode item, TRuntimeNode nextItem) { + return pb.NewTuple(tupleType, {item, nextItem}); + }); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue(); + NUdf::TUnboxedValue item; + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 10); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).template Get<ui16>(), 20); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 20); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).template Get<ui16>(), 30); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 30); + UNIT_ASSERT(!item.GetElement(1).HasValue()); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + } + + Y_UNIT_TEST_LLVM(OverSingleElementStream) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto data1 = pb.NewDataLiteral<ui16>(10); + + const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id); + const auto optDataType = pb.NewOptionalType(dataType); + const auto tupleType = pb.NewTupleType({dataType, optDataType}); + + const auto list = pb.NewList(dataType, {data1}); + const auto pgmReturn = pb.MapNext(pb.Iterator(list, {}), + [&](TRuntimeNode item, TRuntimeNode nextItem) { + return pb.NewTuple(tupleType, {item, nextItem}); + }); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue(); + NUdf::TUnboxedValue item; + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 10); + UNIT_ASSERT(!item.GetElement(1).HasValue()); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + } + + Y_UNIT_TEST_LLVM(OverEmptyStream) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id); + const auto optDataType = pb.NewOptionalType(dataType); + const auto tupleType = pb.NewTupleType({dataType, optDataType}); + + const auto list = pb.NewList(dataType, {}); + const auto pgmReturn = pb.MapNext(pb.Iterator(list, {}), + [&](TRuntimeNode item, TRuntimeNode nextItem) { + return pb.NewTuple(tupleType, {item, nextItem}); + }); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue(); + NUdf::TUnboxedValue item; + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + } + + Y_UNIT_TEST_LLVM(OverFlow) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto data1 = pb.NewDataLiteral<ui16>(10); + const auto data2 = pb.NewDataLiteral<ui16>(20); + const auto data3 = pb.NewDataLiteral<ui16>(30); + + const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id); + const auto optDataType = pb.NewOptionalType(dataType); + const auto tupleType = pb.NewTupleType({dataType, optDataType}); + + const auto list = pb.NewList(dataType, {data1, data2, data3}); + const auto pgmReturn = pb.FromFlow(pb.MapNext(pb.ToFlow(pb.Iterator(list, {})), + [&](TRuntimeNode item, TRuntimeNode nextItem) { + return pb.NewTuple(tupleType, {item, nextItem}); + })); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue(); + NUdf::TUnboxedValue item; + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 10); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).template Get<ui16>(), 20); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 20); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).template Get<ui16>(), 30); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 30); + UNIT_ASSERT(!item.GetElement(1).HasValue()); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + } + + Y_UNIT_TEST_LLVM(OverSingleElementFlow) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto data1 = pb.NewDataLiteral<ui16>(10); + + const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id); + const auto optDataType = pb.NewOptionalType(dataType); + const auto tupleType = pb.NewTupleType({dataType, optDataType}); + + const auto list = pb.NewList(dataType, {data1}); + const auto pgmReturn = pb.FromFlow(pb.MapNext(pb.ToFlow(pb.Iterator(list, {})), + [&](TRuntimeNode item, TRuntimeNode nextItem) { + return pb.NewTuple(tupleType, {item, nextItem}); + })); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue(); + NUdf::TUnboxedValue item; + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Ok, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).template Get<ui16>(), 10); + UNIT_ASSERT(!item.GetElement(1).HasValue()); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + } + + Y_UNIT_TEST_LLVM(OverEmptyFlow) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto dataType = pb.NewDataType(NUdf::TDataType<ui16>::Id); + const auto optDataType = pb.NewOptionalType(dataType); + const auto tupleType = pb.NewTupleType({dataType, optDataType}); + + const auto list = pb.NewList(dataType, {}); + const auto pgmReturn = pb.FromFlow(pb.MapNext(pb.ToFlow(pb.Iterator(list, {})), + [&](TRuntimeNode item, TRuntimeNode nextItem) { + return pb.NewTuple(tupleType, {item, nextItem}); + })); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue(); + NUdf::TUnboxedValue item; + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + UNIT_ASSERT_VALUES_EQUAL(NUdf::EFetchStatus::Finish, iterator.Fetch(item)); + } +} + +} +} diff --git a/ydb/library/yql/minikql/comp_nodes/ut/ya.make b/ydb/library/yql/minikql/comp_nodes/ut/ya.make index dffda1318f3..7c7505394c2 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/ya.make +++ b/ydb/library/yql/minikql/comp_nodes/ut/ya.make @@ -52,6 +52,7 @@ SRCS( mkql_wide_map_ut.cpp mkql_wide_nodes_ut.cpp mkql_listfromrange_ut.cpp + mkql_mapnext_ut.cpp ) PEERDIR( |