aboutsummaryrefslogblamecommitdiffstats
path: root/yql/essentials/minikql/computation/mkql_computation_node_graph_saveload_ut.cpp
blob: dc3b20ddcba3f06420297bd2c7371f70e272ef34 (plain) (tree)



























































































































































































































































































































































































                                                                                                                                             
#include <yql/essentials/minikql/mkql_node.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_program_builder.h>
#include <yql/essentials/minikql/mkql_function_registry.h>
#include <yql/essentials/minikql/computation/mkql_computation_node.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_graph_saveload.h>
#include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
#include <yql/essentials/minikql/comp_nodes/mkql_factories.h>

#include <library/cpp/testing/unittest/registar.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {
    TIntrusivePtr<IRandomProvider> CreateRandomProvider() {
        return CreateDeterministicRandomProvider(1);
    }

    TIntrusivePtr<ITimeProvider> CreateTimeProvider() {
        return CreateDeterministicTimeProvider(10000000);
    }

    TComputationNodeFactory GetAuxCallableFactory() {
        return [](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* {
            if (callable.GetType()->GetName() == "OneYieldStream") {
                return new TExternalComputationNode(ctx.Mutables);
            }

            return GetBuiltinFactory()(callable, ctx);
        };
    }

    struct TSetup {
        TSetup(TScopedAlloc& alloc)
            : Alloc(alloc)
        {
            FunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry());
            RandomProvider = CreateRandomProvider();
            TimeProvider = CreateTimeProvider();

            Env.Reset(new TTypeEnvironment(Alloc));
            PgmBuilder.Reset(new TProgramBuilder(*Env, *FunctionRegistry));
        }

        THolder<IComputationGraph> BuildGraph(TRuntimeNode pgm, const std::vector<TNode*>& entryPoints = std::vector<TNode*>()) {
            Explorer.Walk(pgm.GetNode(), *Env);
            TComputationPatternOpts opts(Alloc.Ref(), *Env, GetAuxCallableFactory(),
                FunctionRegistry.Get(),
                NUdf::EValidateMode::None, NUdf::EValidatePolicy::Fail, "OFF", EGraphPerProcess::Multi);
            Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts);
            TComputationOptsFull compOpts = opts.ToComputationOptions(*RandomProvider, *TimeProvider);
            return Pattern->Clone(compOpts);
        }

        TIntrusivePtr<IFunctionRegistry> FunctionRegistry;
        TIntrusivePtr<IRandomProvider> RandomProvider;
        TIntrusivePtr<ITimeProvider> TimeProvider;

        TScopedAlloc& Alloc;
        THolder<TTypeEnvironment> Env;
        THolder<TProgramBuilder> PgmBuilder;

        TExploringNodeVisitor Explorer;
        IComputationPattern::TPtr Pattern;
    };

    struct TStreamWithYield : public NUdf::TBoxedValue {
        TStreamWithYield(const TUnboxedValueVector& items, ui32 yieldPos, ui32 index)
            : Items(items)
            , YieldPos(yieldPos)
            , Index(index)
        {}

    private:
        TUnboxedValueVector Items;
        ui32 YieldPos;
        ui32 Index;

        ui32 GetTraverseCount() const override {
            return 0;
        }

        NUdf::TUnboxedValue Save() const override {
            return NUdf::TUnboxedValue::Zero();
        }

        bool Load2(const NUdf::TUnboxedValue& state) override {
            Y_UNUSED(state);
            return false;
        }

        NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
            if (Index >= Items.size()) {
                return NUdf::EFetchStatus::Finish;
            }
            if (Index == YieldPos) {
                return NUdf::EFetchStatus::Yield;
            }
            result = Items[Index++];
            return NUdf::EFetchStatus::Ok;
        }
    };
}

Y_UNIT_TEST_SUITE(TMiniKQLSaveLoadTest) {
    Y_UNIT_TEST(TestSqueezeSaveLoad) {
        TScopedAlloc alloc(__LOCATION__);

        const std::vector<ui32> items = {2, 3, 4, 5, 6, 7, 8};

        auto buildGraph = [&items] (TSetup& setup, ui32 yieldPos, ui32 startIndex) -> THolder<IComputationGraph> {
            TProgramBuilder& pgmBuilder = *setup.PgmBuilder;

            auto dataType = pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id);
            auto streamType = pgmBuilder.NewStreamType(dataType);

            TCallableBuilder inStream(pgmBuilder.GetTypeEnvironment(), "OneYieldStream", streamType);
            auto streamNode = inStream.Build();

            auto pgmReturn = pgmBuilder.Squeeze(
                TRuntimeNode(streamNode, false),
                pgmBuilder.NewDataLiteral<ui32>(1),
                [&](TRuntimeNode item, TRuntimeNode state) {
                    return pgmBuilder.Add(item, state);
                },
                [](TRuntimeNode state) {
                    return state;
                },
                [](TRuntimeNode state) {
                    return state;
                });

            TUnboxedValueVector streamItems;
            for (auto item : items) {
                streamItems.push_back(NUdf::TUnboxedValuePod(item));
            }

            auto graph = setup.BuildGraph(pgmReturn, {streamNode});
            auto streamValue = NUdf::TUnboxedValuePod(new TStreamWithYield(streamItems, yieldPos, startIndex));
            graph->GetEntryPoint(0, true)->SetValue(graph->GetContext(), std::move(streamValue));
            return graph;
        };

        for (ui32 yieldPos = 0; yieldPos < items.size(); ++yieldPos) {
            TSetup setup1(alloc);
            auto graph1 = buildGraph(setup1, yieldPos, 0);

            auto root1 = graph1->GetValue();
            NUdf::TUnboxedValue res;
            auto status = root1.Fetch(res);
            UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Yield);

            TString graphState;
            SaveGraphState(&root1, 1, 0ULL, graphState);

            TSetup setup2(alloc);
            auto graph2 = buildGraph(setup2, -1, yieldPos);

            auto root2 = graph2->GetValue();
            LoadGraphState(&root2, 1, 0ULL, graphState);

            status = root2.Fetch(res);
            UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Ok);
            UNIT_ASSERT_EQUAL(res.Get<ui32>(), 36);

            status = root2.Fetch(res);
            UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Finish);
        }
    }

    Y_UNIT_TEST(TestSqueeze1SaveLoad) {
        TScopedAlloc alloc(__LOCATION__);

        const std::vector<ui32> items = {1, 2, 3, 4, 5, 6, 7, 8};

        auto buildGraph = [&items] (TSetup& setup, ui32 yieldPos, ui32 startIndex) -> THolder<IComputationGraph> {
            TProgramBuilder& pgmBuilder = *setup.PgmBuilder;

            auto dataType = pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id);
            auto streamType = pgmBuilder.NewStreamType(dataType);

            TCallableBuilder inStream(pgmBuilder.GetTypeEnvironment(), "OneYieldStream", streamType);
            auto streamNode = inStream.Build();

            auto pgmReturn = pgmBuilder.Squeeze1(
                TRuntimeNode(streamNode, false),
                [](TRuntimeNode item) {
                    return item;
                },
                [&](TRuntimeNode item, TRuntimeNode state) {
                    return pgmBuilder.Add(item, state);
                },
                [](TRuntimeNode state) {
                    return state;
                },
                [](TRuntimeNode state) {
                    return state;
                });

            TUnboxedValueVector streamItems;
            for (auto item : items) {
                streamItems.push_back(NUdf::TUnboxedValuePod(item));
            }

            auto graph = setup.BuildGraph(pgmReturn, {streamNode});
            auto streamValue = NUdf::TUnboxedValuePod(new TStreamWithYield(streamItems, yieldPos, startIndex));
            graph->GetEntryPoint(0, true)->SetValue(graph->GetContext(), std::move(streamValue));
            return graph;
        };

        for (ui32 yieldPos = 0; yieldPos < items.size(); ++yieldPos) {
            TSetup setup1(alloc);
            auto graph1 = buildGraph(setup1, yieldPos, 0);

            auto root1 = graph1->GetValue();

            NUdf::TUnboxedValue res;
            auto status = root1.Fetch(res);
            UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Yield);

            TString graphState;
            SaveGraphState(&root1, 1, 0ULL, graphState);

            TSetup setup2(alloc);
            auto graph2 = buildGraph(setup2, -1, yieldPos);

            auto root2 = graph2->GetValue();
            LoadGraphState(&root2, 1, 0ULL, graphState);

            status = root2.Fetch(res);
            UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Ok);
            UNIT_ASSERT_EQUAL(res.Get<ui32>(), 36);

            status = root2.Fetch(res);
            UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Finish);
        }
    }

    Y_UNIT_TEST(TestHoppingSaveLoad) {
        TScopedAlloc alloc(__LOCATION__);

        const std::vector<std::pair<i64, ui32>> items = {
            {1, 2},
            {2, 3},
            {15, 4},
            {23, 6},
            {24, 5},
            {25, 7},
            {40, 2},
            {47, 1},
            {51, 6},
            {59, 2},
            {85, 8},
            {55, 1000},
            {200, 0}
        };

        auto buildGraph = [&items] (TSetup& setup, ui32 yieldPos, ui32 startIndex) -> THolder<IComputationGraph> {
            TProgramBuilder& pgmBuilder = *setup.PgmBuilder;

            auto structType = pgmBuilder.NewEmptyStructType();
            structType = pgmBuilder.NewStructType(structType, "time",
                pgmBuilder.NewDataType(NUdf::TDataType<NUdf::TTimestamp>::Id));
            structType = pgmBuilder.NewStructType(structType, "sum",
                pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id));
            auto timeIndex = AS_TYPE(TStructType, structType)->GetMemberIndex("time");
            auto sumIndex = AS_TYPE(TStructType, structType)->GetMemberIndex("sum");

            auto inStreamType = pgmBuilder.NewStreamType(structType);

            TCallableBuilder inStream(pgmBuilder.GetTypeEnvironment(), "OneYieldStream", inStreamType);
            auto streamNode = inStream.Build();

            ui64 hop = 10, interval = 30, delay = 20;

            auto pgmReturn = pgmBuilder.HoppingCore(
                TRuntimeNode(streamNode, false),
                [&](TRuntimeNode item) { // timeExtractor
                    return pgmBuilder.Member(item, "time");
                },
                [&](TRuntimeNode item) { // init
                    std::vector<std::pair<std::string_view, TRuntimeNode>> members;
                    members.emplace_back("sum", pgmBuilder.Member(item, "sum"));
                    return pgmBuilder.NewStruct(members);
                },
                [&](TRuntimeNode item, TRuntimeNode state) { // update
                    auto add = pgmBuilder.AggrAdd(
                        pgmBuilder.Member(item, "sum"),
                        pgmBuilder.Member(state, "sum"));
                    std::vector<std::pair<std::string_view, TRuntimeNode>> members;
                    members.emplace_back("sum", add);
                    return pgmBuilder.NewStruct(members);
                },
                [&](TRuntimeNode state) { // save
                    return pgmBuilder.Member(state, "sum");
                },
                [&](TRuntimeNode savedState) { // load
                    std::vector<std::pair<std::string_view, TRuntimeNode>> members;
                    members.emplace_back("sum", savedState);
                    return pgmBuilder.NewStruct(members);
                },
                [&](TRuntimeNode state1, TRuntimeNode state2) { // merge
                    auto add = pgmBuilder.AggrAdd(
                        pgmBuilder.Member(state1, "sum"),
                        pgmBuilder.Member(state2, "sum"));
                    std::vector<std::pair<std::string_view, TRuntimeNode>> members;
                    members.emplace_back("sum", add);
                    return pgmBuilder.NewStruct(members);
                },
                [&](TRuntimeNode state, TRuntimeNode time) { // finish
                    Y_UNUSED(time);
                    std::vector<std::pair<std::string_view, TRuntimeNode>> members;
                    members.emplace_back("sum", pgmBuilder.Member(state, "sum"));
                    return pgmBuilder.NewStruct(members);
                },
                pgmBuilder.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&hop, sizeof(hop))), // hop
                pgmBuilder.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&interval, sizeof(interval))), // interval
                pgmBuilder.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&delay, sizeof(delay)))  // delay
            );

            auto graph = setup.BuildGraph(pgmReturn, {streamNode});

            TUnboxedValueVector streamItems;
            for (size_t i = 0; i < items.size(); ++i) {
                NUdf::TUnboxedValue* itemsPtr;
                auto structValues = graph->GetHolderFactory().CreateDirectArrayHolder(2, itemsPtr);
                itemsPtr[timeIndex] = NUdf::TUnboxedValuePod(items[i].first);
                itemsPtr[sumIndex] = NUdf::TUnboxedValuePod(items[i].second);
                streamItems.push_back(std::move(structValues));
            }

            auto streamValue = NUdf::TUnboxedValuePod(new TStreamWithYield(streamItems, yieldPos, startIndex));
            graph->GetEntryPoint(0, true)->SetValue(graph->GetContext(), std::move(streamValue));
            return graph;
        };

        for (ui32 yieldPos = 0; yieldPos < items.size(); ++yieldPos) {
            std::vector<ui32> result;

            TSetup setup1(alloc);
            auto graph1 = buildGraph(setup1, yieldPos, 0);
            auto root1 = graph1->GetValue();

            NUdf::EFetchStatus status = NUdf::EFetchStatus::Ok;
            while (status == NUdf::EFetchStatus::Ok) {
                NUdf::TUnboxedValue val;
                status = root1.Fetch(val);
                if (status == NUdf::EFetchStatus::Ok) {
                    result.push_back(val.GetElement(0).Get<ui32>());
                }
            }
            UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Yield);

            TString graphState;
            SaveGraphState(&root1, 1, 0ULL, graphState);

            TSetup setup2(alloc);
            auto graph2 = buildGraph(setup2, -1, yieldPos);
            auto root2 = graph2->GetValue();
            LoadGraphState(&root2, 1, 0ULL, graphState);

            status = NUdf::EFetchStatus::Ok;
            while (status == NUdf::EFetchStatus::Ok) {
                NUdf::TUnboxedValue val;
                status = root2.Fetch(val);
                if (status == NUdf::EFetchStatus::Ok) {
                    result.push_back(val.GetElement(0).Get<ui32>());
                }
            }
            UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Finish);

            const std::vector<ui32> resultCompare = {5, 9, 27, 22, 21, 11, 11, 8, 8, 8, 8};
            UNIT_ASSERT_EQUAL(result, resultCompare);
        }
    }
}

} // namespace NMiniKQL
} // namespace NKikimr