aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/public/purecalc/ut/test_mixed_allocators.cpp
blob: 797f3c5b5125c422679dce862fe6b035fd48716f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <library/cpp/testing/unittest/registar.h>

#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_string_util.h>

#include <yql/essentials/public/purecalc/io_specs/protobuf/spec.h>
#include <yql/essentials/public/purecalc/ut/protos/test_structs.pb.h>

using namespace NYql::NPureCalc;

namespace {
    class TStatelessInputSpec : public TInputSpecBase {
    public:
        TStatelessInputSpec()
            : Schemas_({NYT::TNode::CreateList()
                .Add("StructType")
                .Add(NYT::TNode::CreateList()
                    .Add(NYT::TNode::CreateList()
                        .Add("InputValue")
                        .Add(NYT::TNode::CreateList()
                            .Add("DataType")
                            .Add("Utf8")
                        )
                    )
                )
            })
        {};

        const TVector<NYT::TNode>& GetSchemas() const override {
            return Schemas_;
        }

    private:
        const TVector<NYT::TNode> Schemas_;
    };

    /// Push-stream input consumer that feeds raw unboxed values into a worker.
    ///
    /// Every call into the worker is made while the worker's own scoped
    /// allocator (`Worker_->GetScopedAlloc()`) is held.
    class TStatelessInputConsumer : public IConsumer<const NYql::NUdf::TUnboxedValue&> {
    public:
        /// Takes ownership of the worker.
        /// Declared `explicit` so that a worker holder is never silently
        /// converted into a consumer.
        explicit TStatelessInputConsumer(TWorkerHolder<IPushStreamWorker> worker)
            : Worker_(std::move(worker))
        {}

        /// Wraps `value` into a one-element array holder created by the worker
        /// graph's holder factory and pushes that holder into the worker.
        void OnObject(const NYql::NUdf::TUnboxedValue& value) override {
            with_lock (Worker_->GetScopedAlloc()) {
                // CreateDirectArrayHolder hands back a pointer to the holder's
                // item storage through `items`.
                NYql::NUdf::TUnboxedValue* items = nullptr;
                NYql::NUdf::TUnboxedValue result = Worker_->GetGraph().GetHolderFactory().CreateDirectArrayHolder(1, items);

                items[0] = value;

                Worker_->Push(std::move(result));

                // Invalidate the graph after each object: the pushed value was
                // allocated on another allocator and has to be released.
                Worker_->GetGraph().Invalidate();
            }
        }

        /// Forwards the end-of-stream notification to the worker.
        void OnFinish() override {
            with_lock (Worker_->GetScopedAlloc()) {
                Worker_->OnFinish();
            }
        }

    private:
        TWorkerHolder<IPushStreamWorker> Worker_;
    };

    /// Output consumer used by the test: checks the `X` field of every produced
    /// message and, on finish, the total number of messages received.
    class TStatelessConsumer : public IConsumer<NPureCalcProto::TStringMessage*> {
    public:
        /// @param expectedData value every message's `X` field must equal
        /// @param expectedRows number of messages expected before `OnFinish`
        TStatelessConsumer(const TString& expectedData, ui64 expectedRows)
            : ExpectedData_(expectedData)
            , ExpectedRows_(expectedRows)
        {
        }

        /// Asserts the message payload and counts the row; the current row
        /// index is attached to the assertion as context.
        void OnObject(NPureCalcProto::TStringMessage* message) override {
            UNIT_ASSERT_VALUES_EQUAL_C(ExpectedData_, message->GetX(), RowId_);
            ++RowId_;
        }

        /// Asserts that exactly the expected number of rows was received.
        void OnFinish() override {
            UNIT_ASSERT_VALUES_EQUAL(ExpectedRows_, RowId_);
        }

    private:
        const TString ExpectedData_;
        const ui64 ExpectedRows_;
        ui64 RowId_ = 0; // number of messages seen so far
    };
}

// Traits specialization for TStatelessInputSpec: only the push-stream
// capability is declared here, and the consumer type is a raw TUnboxedValue
// consumer.
template <>
struct TInputSpecTraits<TStatelessInputSpec> {
    using TConsumerType = THolder<IConsumer<const NYql::NUdf::TUnboxedValue&>>;

    static constexpr bool SupportPushStreamMode = true;
    static constexpr bool IsPartial = false;

    // Creates the consumer that owns `worker`; the spec instance is not used.
    static TConsumerType MakeConsumer(const TStatelessInputSpec& /*spec*/, TWorkerHolder<IPushStreamWorker> worker) {
        TConsumerType consumer(MakeHolder<TStatelessInputConsumer>(std::move(worker)));
        return consumer;
    }
};

Y_UNIT_TEST_SUITE(TestMixedAllocators) {
    // Pushes string values created on a test-owned allocator through a
    // push-stream program (whose consumer works under the worker's own
    // allocator) and checks that only the rows matching the filter come out.
    Y_UNIT_TEST(TestPushStream) {
        const auto matchingValue = "large string >= 14 bytes";
        const ui64 expectedRowCount = 5;

        const auto factory = MakeProgramFactory();

        TStringBuilder query;
        query << "SELECT InputValue AS X FROM Input WHERE InputValue = \"" << matchingValue << "\";";

        const auto program = factory->MakePushStreamProgram(
            TStatelessInputSpec(),
            TProtobufOutputSpec<NPureCalcProto::TStringMessage>(),
            query
        );

        const auto inputConsumer = program->Apply(MakeHolder<TStatelessConsumer>(matchingValue, expectedRowCount));

        // Allocator owned by the test; input strings are created on it.
        NKikimr::NMiniKQL::TScopedAlloc alloc(__LOCATION__, NKikimr::TAlignedPagePoolCounters(), true, false);

        const auto pushValue = [&](TString inputValue) {
            NYql::NUdf::TUnboxedValue boxed;

            // Create the value under the test allocator and lock it there
            // for the duration of the push.
            with_lock (alloc) {
                boxed = NKikimr::NMiniKQL::MakeString(inputValue);
                alloc.Ref().LockObject(boxed);
            }

            inputConsumer->OnObject(boxed);

            // Unlock and drop the value, again under the test allocator.
            with_lock (alloc) {
                alloc.Ref().UnlockObject(boxed);
                boxed.Clear();
            }
        };

        // Each round pushes one value equal to the filter literal and one
        // that differs from it; the consumer expects expectedRowCount rows.
        for (ui64 remaining = expectedRowCount; remaining > 0; --remaining) {
            pushValue(matchingValue);
            pushValue("another large string >= 14 bytes");
        }

        inputConsumer->OnFinish();
    }
}