aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_match_recognize_save_load.h
blob: d94ef5eb4b6318e36fac2ed74a5296d9238d7b94 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#pragma once

#include <yql/essentials/minikql/computation/mkql_computation_node.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/comp_nodes/mkql_saveload.h>
#include <yql/essentials/minikql/mkql_string_util.h>

#include <cstdint>
#include <map>
#include <set>
#include <utility>

namespace NKikimr::NMiniKQL::NMatchRecognize {

struct TSerializerContext {

    TSerializerContext(
        TComputationContext& ctx,
        TType* rowType,
        const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker)
        : Ctx(ctx)
        , RowType(rowType)
        , RowPacker(rowPacker)
        {}

    TComputationContext&    Ctx;
    TType*                  RowType;
    const TMutableObjectOverBoxedValue<TValuePackerBoxed>& RowPacker;
};

/// Variable template that is `false` for every type argument.
/// Because its value depends on the template argument, it can be used as
/// `static_assert(always_false_v<T>)` inside a discarded `if constexpr` branch.
/// Not referenced elsewhere in this header.
template <typename>
inline constexpr bool always_false_v{false};

/// Serializer that saves MATCH_RECOGNIZE state.
///
/// Extends TOutputSerializer with:
///  * Write(const NUdf::TUnboxedValue&) - packs a row with the row packer from the context;
///  * Write(const TIntrusivePtr<T>&)    - writes a shared object once; later pointers to the
///                                        same object are written as back-references;
///  * operator()(args...)               - writes each argument in order.
struct TMrOutputSerializer : TOutputSerializer {
private:
    /// Tag written after the address of a non-null pointer.
    enum class TPtrStateMode {
        Saved = 0,      // the object's own data follows (written by Type::Save)
        FromCache = 1   // the object was already written earlier in this stream
    };

public:
    TMrOutputSerializer(const TSerializerContext& context, EMkqlStateType stateType, ui32 stateVersion, TComputationContext& ctx)
        : TOutputSerializer(stateType, stateVersion, ctx)
        , Context(context)
    {}

    using TOutputSerializer::Write;

    /// Writes every argument, left to right, via the matching Write overload.
    template <typename... Ts>
    void operator()(Ts&&... args) {
        (Write(std::forward<Ts>(args)), ...);
    }

    /// Writes a row value packed with the context's row packer.
    void Write(const NUdf::TUnboxedValue& value) {
        WriteUnboxedValue(Context.RowPacker.RefMutableObject(Context.Ctx, false, Context.RowType), value);
    }

    /// Writes an intrusive pointer.
    /// Layout: bool isValid; if valid: ui64 address, ui8 TPtrStateMode and, only for
    /// Saved, the object's data as produced by Type::Save(*this).
    template<class Type>
    void Write(const TIntrusivePtr<Type>& ptr) {
        const bool isValid = static_cast<bool>(ptr);
        WriteBool(Buf, isValid);
        if (!isValid) {
            return;
        }
        const auto addr = reinterpret_cast<std::uintptr_t>(ptr.Get());
        WriteUi64(Buf, addr);

        if (Cache.find(addr) != Cache.end()) {
            WriteByte(Buf, static_cast<ui8>(TPtrStateMode::FromCache));
            return;
        }
        WriteByte(Buf, static_cast<ui8>(TPtrStateMode::Saved));
        ptr->Save(*this);
        // The address is recorded only after Save returns.
        // NOTE(review): an object that (transitively) saves a pointer to itself would
        // therefore recurse without bound; such cycles are assumed not to occur.
        Cache.insert(addr);
    }

private:
    const TSerializerContext& Context;
    // Addresses of objects whose data has already been written to this stream.
    std::set<std::uintptr_t> Cache;
};

struct TMrInputSerializer : TInputSerializer {
private:
    enum class TPtrStateMode {
        Saved = 0,
        FromCache = 1
    };

public:
    TMrInputSerializer(TSerializerContext& context, const NUdf::TUnboxedValue& state)
        : TInputSerializer(state, EMkqlStateType::SIMPLE_BLOB)
        , Context(context) {    
    }

    using TInputSerializer::Read;

    template <typename... Ts>
    void operator()(Ts&... args) {
        (Read(args), ...);
    }

    void Read(NUdf::TUnboxedValue& value) {
        value = ReadUnboxedValue(Context.RowPacker.RefMutableObject(Context.Ctx, false, Context.RowType), Context.Ctx);
    }

    template<class Type>
    void Read(TIntrusivePtr<Type>& ptr) {
        bool isValid = Read<bool>();
        if (!isValid) {
            ptr.Reset();
            return;
        }
        ui64 addr = Read<ui64>();
        TPtrStateMode mode = static_cast<TPtrStateMode>(Read<ui8>());
        if (mode == TPtrStateMode::Saved) {
            ptr = MakeIntrusive<Type>();
            ptr->Load(*this);
            Cache[addr] = ptr.Get();
            return;
        }
        auto it = Cache.find(addr);
        MKQL_ENSURE(it != Cache.end(), "Internal error");
        auto* cachePtr = static_cast<Type*>(it->second);
        ptr = TIntrusivePtr<Type>(cachePtr);
    }
 
private:
    TSerializerContext& Context;
    mutable std::map<std::uintptr_t, void *> Cache;
};

} //namespace NKikimr::NMiniKQL::NMatchRecognize