aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/ut/mkql_test_factory.cpp
blob: c30cd86c88a02f35437230a04bd99ccfac1e12cd (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#include "mkql_computation_node_ut.h"

#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_string_util.h>


namespace NKikimr {
namespace NMiniKQL {

namespace {


// Sentinel stored in g_TestYieldStreamData: when the yield stream reaches this
// entry, Fetch() returns EFetchStatus::Yield instead of producing an item.
constexpr ui64 g_Yield = std::numeric_limits<ui64>::max();
// Backing values for the "TestStream" callable; the stream emits the first
// `count` entries (count is clamped to the length of this table).
constexpr ui64 g_TestStreamData[] = {0, 0, 1, 0, 0, 0, 1, 2, 3};
// Backing values for the "TestYieldStream" callable; g_Yield entries mark the
// points where the stream yields.
constexpr ui64 g_TestYieldStreamData[] = {0, 1, 2, g_Yield, 0, g_Yield, 1, 2, 0, 1, 2, 0, g_Yield, 1, 2};

// Computation node for the "TestStream" callable. Each calculation yields a
// fresh stream over the first `count` entries of g_TestStreamData, where the
// constructor clamps `count` to the table length.
class TTestStreamWrapper: public TMutableComputationNode<TTestStreamWrapper> {
    using TBaseComputation = TMutableComputationNode<TTestStreamWrapper>;
public:
    // Stream state: a cursor walking g_TestStreamData up to Count items.
    class TStreamValue : public TComputationValue<TStreamValue> {
    public:
        using TBase = TComputationValue<TStreamValue>;

        TStreamValue(TMemoryUsageInfo* memInfo, ui64 count)
            : TBase(memInfo)
            , Count(count)
        {}

    private:
        // Emits the next table entry as a ui64 value, or Finish once Count
        // entries have been produced.
        NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override {
            if (Index < Count) {
                result = NUdf::TUnboxedValuePod(g_TestStreamData[Index]);
                ++Index;
                return NUdf::EFetchStatus::Ok;
            }
            return NUdf::EFetchStatus::Finish;
        }

    private:
        ui64 Index = 0;
        const ui64 Count;
    };

    TTestStreamWrapper(TComputationMutables& mutables, ui64 count)
        : TBaseComputation(mutables)
        , Count(Min<ui64>(count, Y_ARRAY_SIZE(g_TestStreamData)))
    {}

    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        return ctx.HolderFactory.Create<TStreamValue>(Count);
    }

private:
    // No child computation nodes are held, so nothing to register.
    void RegisterDependencies() const final {}

private:
    // Number of items each stream produces (already clamped to the table).
    const ui64 Count;
};

// Computation node for the "TestYieldStream" callable. Each calculation yields
// a stream that replays g_TestYieldStreamData. A g_Yield entry makes Fetch()
// return EFetchStatus::Yield. Every other entry is emitted as a two-element
// array {value, ToString(position of the entry in the table)}.
class TTestYieldStreamWrapper: public TMutableComputationNode<TTestYieldStreamWrapper> {
    using TBaseComputation = TMutableComputationNode<TTestYieldStreamWrapper>;
public:
    // Stream state: a cursor over g_TestYieldStreamData.
    class TStreamValue : public TComputationValue<TStreamValue> {
    public:
        using TBase = TComputationValue<TStreamValue>;

        TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx)
            : TBase(memInfo)
            , CompCtx(compCtx)
        {
        }

    private:
        NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override {
            if (Index == Y_ARRAY_SIZE(g_TestYieldStreamData)) {
                return NUdf::EFetchStatus::Finish;
            }

            const auto value = g_TestYieldStreamData[Index];
            if (value == g_Yield) {
                // Consume the marker so the next Fetch() resumes after it.
                ++Index;
                return NUdf::EFetchStatus::Yield;
            }

            NUdf::TUnboxedValue* items = nullptr;
            result = CompCtx.HolderFactory.CreateDirectArrayHolder(2, items);
            items[0] = NUdf::TUnboxedValuePod(value);
            // The position counts yield markers too, so it is the raw table index.
            items[1] = MakeString(ToString(Index));

            ++Index;
            return NUdf::EFetchStatus::Ok;
        }

    private:
        // Used to allocate the result arrays; held by reference, not owned.
        TComputationContext& CompCtx;
        ui64 Index = 0;
    };

    // explicit: this single-argument constructor must not serve as an
    // implicit conversion from TComputationMutables&.
    explicit TTestYieldStreamWrapper(TComputationMutables& mutables)
        : TBaseComputation(mutables)
    {
    }

    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        return ctx.HolderFactory.Create<TStreamValue>(ctx);
    }

private:
    // The node takes no child computation nodes, so nothing to register.
    void RegisterDependencies() const final {}
};

// Builds the node for the "TestStream" callable. Its single input must be a
// data literal whose value is read as the ui64 number of items to stream.
IComputationNode* WrapTestStream(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args");
    const auto countLiteral = AS_VALUE(TDataLiteral, callable.GetInput(0));
    // Read the count before allocating the node, as the original did.
    const ui64 count = countLiteral->AsValue().Get<ui64>();
    return new TTestStreamWrapper(ctx.Mutables, count);
}


// Builds the node for the "TestYieldStream" callable, which accepts no inputs.
IComputationNode* WrapTestYieldStream(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(!callable.GetInputsCount(), "Expected no args");
    auto& mutables = ctx.Mutables;
    return new TTestYieldStreamWrapper(mutables);
}

}

// Returns a node factory that handles the test-only callables "TestList",
// "TestStream" and "TestYieldStream". Any other callable goes first to
// `customFactory` (when one is set). If that yields nullptr, the builtin
// factory is used.
TComputationNodeFactory GetTestFactory(TComputationNodeFactory customFactory) {
    return [customFactory](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* {
        const auto& name = callable.GetType()->GetName();

        if (name == "TestList") {
            // Presumably the list value is supplied from outside by the test
            // harness (TExternalComputationNode); confirm against callers.
            return new TExternalComputationNode(ctx.Mutables);
        }
        if (name == "TestStream") {
            return WrapTestStream(callable, ctx);
        }
        if (name == "TestYieldStream") {
            return WrapTestYieldStream(callable, ctx);
        }

        if (customFactory) {
            if (auto* const node = customFactory(callable, ctx)) {
                return node;
            }
        }

        return GetBuiltinFactory()(callable, ctx);
    };
}

}
}