aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/computation/mkql_computation_node_graph_saveload_ut.cpp
diff options
context:
space:
mode:
authorvvvv <vvvv@yandex-team.com>2024-11-07 04:19:26 +0300
committervvvv <vvvv@yandex-team.com>2024-11-07 04:29:50 +0300
commit2661be00f3bc47590fda9218bf0386d6355c8c88 (patch)
tree3d316c07519191283d31c5f537efc6aabb42a2f0 /yql/essentials/minikql/computation/mkql_computation_node_graph_saveload_ut.cpp
parentcf2a23963ac10add28c50cc114fbf48953eca5aa (diff)
downloadydb-2661be00f3bc47590fda9218bf0386d6355c8c88.tar.gz
Moved yql/minikql YQL-19206
init [nodiff:caesar] commit_hash:d1182ef7d430ccf7e4d37ed933c7126d7bd5d6e4
Diffstat (limited to 'yql/essentials/minikql/computation/mkql_computation_node_graph_saveload_ut.cpp')
-rw-r--r--yql/essentials/minikql/computation/mkql_computation_node_graph_saveload_ut.cpp381
1 files changed, 381 insertions, 0 deletions
diff --git a/yql/essentials/minikql/computation/mkql_computation_node_graph_saveload_ut.cpp b/yql/essentials/minikql/computation/mkql_computation_node_graph_saveload_ut.cpp
new file mode 100644
index 0000000000..dc3b20ddcb
--- /dev/null
+++ b/yql/essentials/minikql/computation/mkql_computation_node_graph_saveload_ut.cpp
@@ -0,0 +1,381 @@
+#include <yql/essentials/minikql/mkql_node.h>
+#include <yql/essentials/minikql/mkql_node_cast.h>
+#include <yql/essentials/minikql/mkql_program_builder.h>
+#include <yql/essentials/minikql/mkql_function_registry.h>
+#include <yql/essentials/minikql/computation/mkql_computation_node.h>
+#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
+#include <yql/essentials/minikql/computation/mkql_computation_node_graph_saveload.h>
+#include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
+#include <yql/essentials/minikql/comp_nodes/mkql_factories.h>
+
+#include <library/cpp/testing/unittest/registar.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+namespace {
+ TIntrusivePtr<IRandomProvider> CreateRandomProvider() {
+ return CreateDeterministicRandomProvider(1);
+ }
+
+ TIntrusivePtr<ITimeProvider> CreateTimeProvider() {
+ return CreateDeterministicTimeProvider(10000000);
+ }
+
+ TComputationNodeFactory GetAuxCallableFactory() {
+ return [](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* {
+ if (callable.GetType()->GetName() == "OneYieldStream") {
+ return new TExternalComputationNode(ctx.Mutables);
+ }
+
+ return GetBuiltinFactory()(callable, ctx);
+ };
+ }
+
+ struct TSetup {
+ TSetup(TScopedAlloc& alloc)
+ : Alloc(alloc)
+ {
+ FunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry());
+ RandomProvider = CreateRandomProvider();
+ TimeProvider = CreateTimeProvider();
+
+ Env.Reset(new TTypeEnvironment(Alloc));
+ PgmBuilder.Reset(new TProgramBuilder(*Env, *FunctionRegistry));
+ }
+
+ THolder<IComputationGraph> BuildGraph(TRuntimeNode pgm, const std::vector<TNode*>& entryPoints = std::vector<TNode*>()) {
+ Explorer.Walk(pgm.GetNode(), *Env);
+ TComputationPatternOpts opts(Alloc.Ref(), *Env, GetAuxCallableFactory(),
+ FunctionRegistry.Get(),
+ NUdf::EValidateMode::None, NUdf::EValidatePolicy::Fail, "OFF", EGraphPerProcess::Multi);
+ Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts);
+ TComputationOptsFull compOpts = opts.ToComputationOptions(*RandomProvider, *TimeProvider);
+ return Pattern->Clone(compOpts);
+ }
+
+ TIntrusivePtr<IFunctionRegistry> FunctionRegistry;
+ TIntrusivePtr<IRandomProvider> RandomProvider;
+ TIntrusivePtr<ITimeProvider> TimeProvider;
+
+ TScopedAlloc& Alloc;
+ THolder<TTypeEnvironment> Env;
+ THolder<TProgramBuilder> PgmBuilder;
+
+ TExploringNodeVisitor Explorer;
+ IComputationPattern::TPtr Pattern;
+ };
+
+ struct TStreamWithYield : public NUdf::TBoxedValue {
+ TStreamWithYield(const TUnboxedValueVector& items, ui32 yieldPos, ui32 index)
+ : Items(items)
+ , YieldPos(yieldPos)
+ , Index(index)
+ {}
+
+ private:
+ TUnboxedValueVector Items;
+ ui32 YieldPos;
+ ui32 Index;
+
+ ui32 GetTraverseCount() const override {
+ return 0;
+ }
+
+ NUdf::TUnboxedValue Save() const override {
+ return NUdf::TUnboxedValue::Zero();
+ }
+
+ bool Load2(const NUdf::TUnboxedValue& state) override {
+ Y_UNUSED(state);
+ return false;
+ }
+
+ NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
+ if (Index >= Items.size()) {
+ return NUdf::EFetchStatus::Finish;
+ }
+ if (Index == YieldPos) {
+ return NUdf::EFetchStatus::Yield;
+ }
+ result = Items[Index++];
+ return NUdf::EFetchStatus::Ok;
+ }
+ };
+}
+
+Y_UNIT_TEST_SUITE(TMiniKQLSaveLoadTest) {
+ Y_UNIT_TEST(TestSqueezeSaveLoad) {
+ TScopedAlloc alloc(__LOCATION__);
+
+ const std::vector<ui32> items = {2, 3, 4, 5, 6, 7, 8};
+
+ auto buildGraph = [&items] (TSetup& setup, ui32 yieldPos, ui32 startIndex) -> THolder<IComputationGraph> {
+ TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
+
+ auto dataType = pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id);
+ auto streamType = pgmBuilder.NewStreamType(dataType);
+
+ TCallableBuilder inStream(pgmBuilder.GetTypeEnvironment(), "OneYieldStream", streamType);
+ auto streamNode = inStream.Build();
+
+ auto pgmReturn = pgmBuilder.Squeeze(
+ TRuntimeNode(streamNode, false),
+ pgmBuilder.NewDataLiteral<ui32>(1),
+ [&](TRuntimeNode item, TRuntimeNode state) {
+ return pgmBuilder.Add(item, state);
+ },
+ [](TRuntimeNode state) {
+ return state;
+ },
+ [](TRuntimeNode state) {
+ return state;
+ });
+
+ TUnboxedValueVector streamItems;
+ for (auto item : items) {
+ streamItems.push_back(NUdf::TUnboxedValuePod(item));
+ }
+
+ auto graph = setup.BuildGraph(pgmReturn, {streamNode});
+ auto streamValue = NUdf::TUnboxedValuePod(new TStreamWithYield(streamItems, yieldPos, startIndex));
+ graph->GetEntryPoint(0, true)->SetValue(graph->GetContext(), std::move(streamValue));
+ return graph;
+ };
+
+ for (ui32 yieldPos = 0; yieldPos < items.size(); ++yieldPos) {
+ TSetup setup1(alloc);
+ auto graph1 = buildGraph(setup1, yieldPos, 0);
+
+ auto root1 = graph1->GetValue();
+ NUdf::TUnboxedValue res;
+ auto status = root1.Fetch(res);
+ UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Yield);
+
+ TString graphState;
+ SaveGraphState(&root1, 1, 0ULL, graphState);
+
+ TSetup setup2(alloc);
+ auto graph2 = buildGraph(setup2, -1, yieldPos);
+
+ auto root2 = graph2->GetValue();
+ LoadGraphState(&root2, 1, 0ULL, graphState);
+
+ status = root2.Fetch(res);
+ UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Ok);
+ UNIT_ASSERT_EQUAL(res.Get<ui32>(), 36);
+
+ status = root2.Fetch(res);
+ UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Finish);
+ }
+ }
+
+ Y_UNIT_TEST(TestSqueeze1SaveLoad) {
+ TScopedAlloc alloc(__LOCATION__);
+
+ const std::vector<ui32> items = {1, 2, 3, 4, 5, 6, 7, 8};
+
+ auto buildGraph = [&items] (TSetup& setup, ui32 yieldPos, ui32 startIndex) -> THolder<IComputationGraph> {
+ TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
+
+ auto dataType = pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id);
+ auto streamType = pgmBuilder.NewStreamType(dataType);
+
+ TCallableBuilder inStream(pgmBuilder.GetTypeEnvironment(), "OneYieldStream", streamType);
+ auto streamNode = inStream.Build();
+
+ auto pgmReturn = pgmBuilder.Squeeze1(
+ TRuntimeNode(streamNode, false),
+ [](TRuntimeNode item) {
+ return item;
+ },
+ [&](TRuntimeNode item, TRuntimeNode state) {
+ return pgmBuilder.Add(item, state);
+ },
+ [](TRuntimeNode state) {
+ return state;
+ },
+ [](TRuntimeNode state) {
+ return state;
+ });
+
+ TUnboxedValueVector streamItems;
+ for (auto item : items) {
+ streamItems.push_back(NUdf::TUnboxedValuePod(item));
+ }
+
+ auto graph = setup.BuildGraph(pgmReturn, {streamNode});
+ auto streamValue = NUdf::TUnboxedValuePod(new TStreamWithYield(streamItems, yieldPos, startIndex));
+ graph->GetEntryPoint(0, true)->SetValue(graph->GetContext(), std::move(streamValue));
+ return graph;
+ };
+
+ for (ui32 yieldPos = 0; yieldPos < items.size(); ++yieldPos) {
+ TSetup setup1(alloc);
+ auto graph1 = buildGraph(setup1, yieldPos, 0);
+
+ auto root1 = graph1->GetValue();
+
+ NUdf::TUnboxedValue res;
+ auto status = root1.Fetch(res);
+ UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Yield);
+
+ TString graphState;
+ SaveGraphState(&root1, 1, 0ULL, graphState);
+
+ TSetup setup2(alloc);
+ auto graph2 = buildGraph(setup2, -1, yieldPos);
+
+ auto root2 = graph2->GetValue();
+ LoadGraphState(&root2, 1, 0ULL, graphState);
+
+ status = root2.Fetch(res);
+ UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Ok);
+ UNIT_ASSERT_EQUAL(res.Get<ui32>(), 36);
+
+ status = root2.Fetch(res);
+ UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Finish);
+ }
+ }
+
+ Y_UNIT_TEST(TestHoppingSaveLoad) {
+ TScopedAlloc alloc(__LOCATION__);
+
+ const std::vector<std::pair<i64, ui32>> items = {
+ {1, 2},
+ {2, 3},
+ {15, 4},
+ {23, 6},
+ {24, 5},
+ {25, 7},
+ {40, 2},
+ {47, 1},
+ {51, 6},
+ {59, 2},
+ {85, 8},
+ {55, 1000},
+ {200, 0}
+ };
+
+ auto buildGraph = [&items] (TSetup& setup, ui32 yieldPos, ui32 startIndex) -> THolder<IComputationGraph> {
+ TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
+
+ auto structType = pgmBuilder.NewEmptyStructType();
+ structType = pgmBuilder.NewStructType(structType, "time",
+ pgmBuilder.NewDataType(NUdf::TDataType<NUdf::TTimestamp>::Id));
+ structType = pgmBuilder.NewStructType(structType, "sum",
+ pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id));
+ auto timeIndex = AS_TYPE(TStructType, structType)->GetMemberIndex("time");
+ auto sumIndex = AS_TYPE(TStructType, structType)->GetMemberIndex("sum");
+
+ auto inStreamType = pgmBuilder.NewStreamType(structType);
+
+ TCallableBuilder inStream(pgmBuilder.GetTypeEnvironment(), "OneYieldStream", inStreamType);
+ auto streamNode = inStream.Build();
+
+ ui64 hop = 10, interval = 30, delay = 20;
+
+ auto pgmReturn = pgmBuilder.HoppingCore(
+ TRuntimeNode(streamNode, false),
+ [&](TRuntimeNode item) { // timeExtractor
+ return pgmBuilder.Member(item, "time");
+ },
+ [&](TRuntimeNode item) { // init
+ std::vector<std::pair<std::string_view, TRuntimeNode>> members;
+ members.emplace_back("sum", pgmBuilder.Member(item, "sum"));
+ return pgmBuilder.NewStruct(members);
+ },
+ [&](TRuntimeNode item, TRuntimeNode state) { // update
+ auto add = pgmBuilder.AggrAdd(
+ pgmBuilder.Member(item, "sum"),
+ pgmBuilder.Member(state, "sum"));
+ std::vector<std::pair<std::string_view, TRuntimeNode>> members;
+ members.emplace_back("sum", add);
+ return pgmBuilder.NewStruct(members);
+ },
+ [&](TRuntimeNode state) { // save
+ return pgmBuilder.Member(state, "sum");
+ },
+ [&](TRuntimeNode savedState) { // load
+ std::vector<std::pair<std::string_view, TRuntimeNode>> members;
+ members.emplace_back("sum", savedState);
+ return pgmBuilder.NewStruct(members);
+ },
+ [&](TRuntimeNode state1, TRuntimeNode state2) { // merge
+ auto add = pgmBuilder.AggrAdd(
+ pgmBuilder.Member(state1, "sum"),
+ pgmBuilder.Member(state2, "sum"));
+ std::vector<std::pair<std::string_view, TRuntimeNode>> members;
+ members.emplace_back("sum", add);
+ return pgmBuilder.NewStruct(members);
+ },
+ [&](TRuntimeNode state, TRuntimeNode time) { // finish
+ Y_UNUSED(time);
+ std::vector<std::pair<std::string_view, TRuntimeNode>> members;
+ members.emplace_back("sum", pgmBuilder.Member(state, "sum"));
+ return pgmBuilder.NewStruct(members);
+ },
+ pgmBuilder.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&hop, sizeof(hop))), // hop
+ pgmBuilder.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&interval, sizeof(interval))), // interval
+ pgmBuilder.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&delay, sizeof(delay))) // delay
+ );
+
+ auto graph = setup.BuildGraph(pgmReturn, {streamNode});
+
+ TUnboxedValueVector streamItems;
+ for (size_t i = 0; i < items.size(); ++i) {
+ NUdf::TUnboxedValue* itemsPtr;
+ auto structValues = graph->GetHolderFactory().CreateDirectArrayHolder(2, itemsPtr);
+ itemsPtr[timeIndex] = NUdf::TUnboxedValuePod(items[i].first);
+ itemsPtr[sumIndex] = NUdf::TUnboxedValuePod(items[i].second);
+ streamItems.push_back(std::move(structValues));
+ }
+
+ auto streamValue = NUdf::TUnboxedValuePod(new TStreamWithYield(streamItems, yieldPos, startIndex));
+ graph->GetEntryPoint(0, true)->SetValue(graph->GetContext(), std::move(streamValue));
+ return graph;
+ };
+
+ for (ui32 yieldPos = 0; yieldPos < items.size(); ++yieldPos) {
+ std::vector<ui32> result;
+
+ TSetup setup1(alloc);
+ auto graph1 = buildGraph(setup1, yieldPos, 0);
+ auto root1 = graph1->GetValue();
+
+ NUdf::EFetchStatus status = NUdf::EFetchStatus::Ok;
+ while (status == NUdf::EFetchStatus::Ok) {
+ NUdf::TUnboxedValue val;
+ status = root1.Fetch(val);
+ if (status == NUdf::EFetchStatus::Ok) {
+ result.push_back(val.GetElement(0).Get<ui32>());
+ }
+ }
+ UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Yield);
+
+ TString graphState;
+ SaveGraphState(&root1, 1, 0ULL, graphState);
+
+ TSetup setup2(alloc);
+ auto graph2 = buildGraph(setup2, -1, yieldPos);
+ auto root2 = graph2->GetValue();
+ LoadGraphState(&root2, 1, 0ULL, graphState);
+
+ status = NUdf::EFetchStatus::Ok;
+ while (status == NUdf::EFetchStatus::Ok) {
+ NUdf::TUnboxedValue val;
+ status = root2.Fetch(val);
+ if (status == NUdf::EFetchStatus::Ok) {
+ result.push_back(val.GetElement(0).Get<ui32>());
+ }
+ }
+ UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Finish);
+
+ const std::vector<ui32> resultCompare = {5, 9, 27, 22, 21, 11, 11, 8, 8, 8, 8};
+ UNIT_ASSERT_EQUAL(result, resultCompare);
+ }
+ }
+}
+
+} // namespace NMiniKQL
+} // namespace NKikimr