summaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/computation/mkql_computation_node_dict_ut.cpp
blob: 7c940247b21fdfc3727a5010ec5fb33ad0db065b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <yql/essentials/minikql/mkql_node.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_program_builder.h>
#include <yql/essentials/minikql/mkql_function_registry.h>
#include <yql/essentials/minikql/computation/mkql_computation_node.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
#include <yql/essentials/minikql/comp_nodes/mkql_factories.h>

#include <library/cpp/testing/unittest/registar.h>

#include <vector>
#include <utility>
#include <algorithm>

namespace NKikimr {
namespace NMiniKQL {

namespace {
// Test fixture: owns the function registry, the random/time providers, the
// type environment and the program builder, and turns a built program into a
// computation graph. The allocator is borrowed from the caller by reference.
struct TSetup {
    // Initializers are written in member declaration order, which is the order
    // in which they run: Env is created from Alloc, and PgmBuilder from Env and
    // FunctionRegistry, so each dependency is ready before it is used.
    TSetup(TScopedAlloc& alloc)
        : FunctionRegistry(CreateFunctionRegistry(CreateBuiltinRegistry()))
        , RandomProvider(CreateDeterministicRandomProvider(1))
        , TimeProvider(CreateDeterministicTimeProvider(10000000))
        , Alloc(alloc)
        , Env(new TTypeEnvironment(Alloc))
        , PgmBuilder(new TProgramBuilder(*Env, *FunctionRegistry))
    {
    }

    // Walks the program with Explorer, builds a computation pattern for it
    // (stored in Pattern) and returns a graph cloned from that pattern.
    //   pgm         - root runtime node of the program.
    //   entryPoints - entry point nodes forwarded to MakeComputationPattern;
    //                 empty by default.
    THolder<IComputationGraph> BuildGraph(TRuntimeNode pgm, const std::vector<TNode*>& entryPoints = std::vector<TNode*>()) {
        Explorer.Walk(pgm.GetNode(), *Env);

        TComputationPatternOpts patternOpts(
            Alloc.Ref(),
            *Env,
            GetBuiltinFactory(),
            FunctionRegistry.Get(),
            NUdf::EValidateMode::None,
            NUdf::EValidatePolicy::Exception,
            "OFF",
            EGraphPerProcess::Multi);
        Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, patternOpts);

        // The computation options are only needed for the Clone call itself.
        return Pattern->Clone(patternOpts.ToComputationOptions(*RandomProvider, *TimeProvider));
    }

    // Services handed to the program builder and to the computation options.
    TIntrusivePtr<IFunctionRegistry> FunctionRegistry;
    TIntrusivePtr<IRandomProvider> RandomProvider;
    TIntrusivePtr<ITimeProvider> TimeProvider;

    // Caller-owned allocator; declared before Env because Env is built from it.
    TScopedAlloc& Alloc;
    THolder<TTypeEnvironment> Env;
    THolder<TProgramBuilder> PgmBuilder;

    // State filled in by BuildGraph.
    TExploringNodeVisitor Explorer;
    IComputationPattern::TPtr Pattern;
};
} // namespace

Y_UNIT_TEST_SUITE(TestCompactMultiDict) {
// Builds a compact hashed multi-dict from (key, payload) tuples and verifies
// that its keys iterator, payloads iterator and dict (key/payloads pair)
// iterator each expose exactly the contents of `items`.
Y_UNIT_TEST(TestIterate) {
    TScopedAlloc alloc(__LOCATION__);

    TSetup setup(alloc);

    // Expected contents: key -> payloads. Entries are listed in ascending key
    // order with ascending payloads; the comparisons below rely on that.
    const std::vector<std::pair<ui32, std::vector<ui32>>> items = {{1, {1, 2}}, {2, {1}}, {3, {0}}, {6, {1, 7}}};

    // Flatten `items` into a list of (key, payload) tuples, one per payload.
    TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
    TVector<TRuntimeNode> rItems;
    for (const auto& [k, vv] : items) {
        for (const auto& v : vv) {
            rItems.push_back(pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<ui32>(k), pgmBuilder.NewDataLiteral<ui32>(v)}));
        }
    }
    auto ui32Type = pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id);
    auto list = pgmBuilder.NewList(pgmBuilder.NewTupleType({ui32Type, ui32Type}), rItems);

    // Tuple element 0 is the key, element 1 is the payload.
    auto dict = pgmBuilder.ToHashedDict(list, /*all*/ true,
                                        [&pgmBuilder](TRuntimeNode item) { return pgmBuilder.Nth(item, 0); },
                                        [&pgmBuilder](TRuntimeNode item) { return pgmBuilder.Nth(item, 1); },
                                        /*isCompact*/ true,
                                        items.size());

    auto graph = setup.BuildGraph(dict, {});
    NUdf::TUnboxedValue res = graph->GetValue();

    // 1. Keys iterator: collect, sort, and compare against the expected keys.
    std::vector<ui32> keyVals;
    for (NUdf::TUnboxedValue keys = res.GetKeysIterator(), v; keys.Next(v);) {
        keyVals.push_back(v.Get<ui32>());
    }
    // The size check guards the three-iterator std::equal below.
    UNIT_ASSERT_VALUES_EQUAL(keyVals.size(), items.size());
    std::sort(keyVals.begin(), keyVals.end());
    UNIT_ASSERT(
        std::equal(keyVals.begin(), keyVals.end(), items.begin(),
                   [](ui32 l, const std::pair<ui32, std::vector<ui32>>& r) { return l == r.first; }));

    // 2. Payloads iterator: compare the payload lists as an order-insensitive
    //    collection (each list sorted, then the collection of lists sorted).
    std::vector<std::vector<ui32>> origPayloads;
    for (const auto& item : items) {
        origPayloads.push_back(item.second);
        std::sort(origPayloads.back().begin(), origPayloads.back().end());
    }
    std::sort(origPayloads.begin(), origPayloads.end());

    std::vector<std::vector<ui32>> payloadVals;
    for (NUdf::TUnboxedValue payloads = res.GetPayloadsIterator(), v; payloads.Next(v);) {
        payloadVals.emplace_back();
        for (NUdf::TUnboxedValue i = v.GetListIterator(), p; i.Next(p);) {
            payloadVals.back().push_back(p.Get<ui32>());
        }
        std::sort(payloadVals.back().begin(), payloadVals.back().end());
    }
    std::sort(payloadVals.begin(), payloadVals.end());
    UNIT_ASSERT_VALUES_EQUAL(origPayloads, payloadVals);

    // 3. Dict iterator: collect (key, sorted payloads) pairs.
    std::vector<std::pair<ui32, std::vector<ui32>>> vals;
    for (NUdf::TUnboxedValue values = res.GetDictIterator(), k, payloads; values.NextPair(k, payloads);) {
        vals.emplace_back(k.Get<ui32>(), std::vector<ui32>{});
        for (NUdf::TUnboxedValue i = payloads.GetListIterator(), p; i.Next(p);) {
            vals.back().second.push_back(p.Get<ui32>());
        }
        std::sort(vals.back().second.begin(), vals.back().second.end());
    }
    // The keys and payloads checks above sort before comparing; do the same
    // here so this comparison does not depend on the dict's iteration order.
    // `items` is already in ascending key order with unique keys, so sorting
    // the pairs lines them up with it.
    std::sort(vals.begin(), vals.end());
    UNIT_ASSERT_VALUES_EQUAL(items, vals);
}
} // Y_UNIT_TEST_SUITE(TestCompactMultiDict)
} // namespace NMiniKQL
} // namespace NKikimr