aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_dynamic_variant.cpp
blob: 98359fb86292d187c169e55d46a9310b31a4486e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include "mkql_dynamic_variant.h"

#include <yql/essentials/minikql/computation/mkql_computation_node_impl.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>

#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_node_builder.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

/// Shared base for the DynamicVariant computation nodes.
///
/// Stores the two child nodes every flavour needs: the payload (`Item`) and
/// the alternative selector (`Index`). It registers both as dependencies of
/// this mutable node. The derived class (CRTP parameter `TDerived`) provides
/// DoCalculate().
template <typename TDerived>
class TDynamicVariantBaseWrapper : public TMutableComputationNode<TDerived> {
public:
    using TBase = TMutableComputationNode<TDerived>;

    TDynamicVariantBaseWrapper(
        TComputationMutables& mutables,
        IComputationNode* itemNode,
        IComputationNode* indexNode)
        : TBase(mutables)
        , Item(itemNode)
        , Index(indexNode)
    {
    }

private:
    // Derived DoCalculate() implementations read both children, so both are
    // declared as dependencies (payload first, then selector).
    void RegisterDependencies() const final {
        this->DependsOn(Item);
        this->DependsOn(Index);
    }

protected:
    IComputationNode* const Item;   // value placed into the resulting variant
    IComputationNode* const Index;  // selects which alternative to build
};

/// DynamicVariant node for tuple-based variant types.
///
/// The selector is read as a `ui32` alternative number. The result is an
/// empty optional when the selector is empty or not below the number of
/// alternatives. Otherwise the item is wrapped in a variant holder for that
/// alternative and returned as a filled optional.
class TDynamicVariantTupleWrapper : public TDynamicVariantBaseWrapper<TDynamicVariantTupleWrapper> {
public:
    using TBase = TDynamicVariantBaseWrapper<TDynamicVariantTupleWrapper>;

    TDynamicVariantTupleWrapper(TComputationMutables& mutables, IComputationNode* item,
        IComputationNode* index, TVariantType* varType)
        : TBase(mutables, item, index)
        , AltCounts(varType->GetAlternativesCount())
    {
    }

    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        const auto selector = Index->GetValue(ctx);
        if (selector) {
            const ui32 altIndex = selector.Get<ui32>();
            if (altIndex < AltCounts) {
                // The payload is only evaluated once the selector is known to be valid.
                NUdf::TUnboxedValuePod payload = Item->GetValue(ctx).Release();
                NUdf::TUnboxedValuePod variant = ctx.HolderFactory.CreateVariantHolder(payload, altIndex);
                return variant.MakeOptional();
            }
        }

        // Missing or out-of-range selector: no variant can be built.
        return {};
    }

private:
    const ui32 AltCounts;  // number of alternatives in the target variant type
};

/// DynamicVariant node for struct-based variant types.
///
/// The selector is read as a string and looked up among the member names of
/// the variant's underlying struct. The result is an empty optional when the
/// selector is empty or names no member. Otherwise the item is wrapped in a
/// variant holder for that member's position and returned as a filled optional.
class TDynamicVariantStructWrapper : public TDynamicVariantBaseWrapper<TDynamicVariantStructWrapper> {
public:
    using TBase = TDynamicVariantBaseWrapper<TDynamicVariantStructWrapper>;

    TDynamicVariantStructWrapper(TComputationMutables& mutables, IComputationNode* item,
        IComputationNode* index, TVariantType* varType)
        : TBase(mutables, item, index)
        , Fields(MakeFields(varType))
    {
    }

    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        auto selector = Index->GetValue(ctx);
        if (!selector) {
            return {};
        }

        TStringBuf memberName = selector.AsStringRef();
        const ui32* const position = Fields.FindPtr(memberName);
        if (position == nullptr) {
            // Unknown member name.
            return {};
        }

        // The payload is only evaluated once the member name has been resolved.
        NUdf::TUnboxedValuePod payload = Item->GetValue(ctx).Release();
        NUdf::TUnboxedValuePod variant = ctx.HolderFactory.CreateVariantHolder(payload, *position);
        return variant.MakeOptional();
    }

private:
    // Builds a name -> position map over the members of the variant's
    // underlying struct type (AS_TYPE rejects a non-struct underlying type).
    static THashMap<TStringBuf, ui32> MakeFields(TVariantType* varType) {
        const auto structType = AS_TYPE(TStructType, varType->GetUnderlyingType());
        const ui32 membersCount = structType->GetMembersCount();

        THashMap<TStringBuf, ui32> nameToPosition;
        for (ui32 pos = 0; pos < membersCount; ++pos) {
            nameToPosition[structType->GetMemberName(pos)] = pos;
        }
        return nameToPosition;
    }

private:
    // Keys are views of the struct type's member names.
    // NOTE(review): relies on the type outliving this node — confirm.
    const THashMap<TStringBuf, ui32> Fields;
};

}

/// Factory for the DynamicVariant callable.
///
/// Expects exactly three inputs:
///   0 - the item to place into the variant;
///   1 - the alternative selector;
///   2 - an immediate type node holding the target TVariantType.
/// Tuple-based variants get TDynamicVariantTupleWrapper. Any other variant
/// is handed to TDynamicVariantStructWrapper, which requires a struct
/// underlying type. Always returns a newly allocated node; never returns null.
IComputationNode* WrapDynamicVariant(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 arguments");
    const auto item = LocateNode(ctx.NodeLocator, callable, 0);
    const auto index = LocateNode(ctx.NodeLocator, callable, 1);
    const auto varTypeNode = callable.GetInput(2);
    MKQL_ENSURE(varTypeNode.IsImmediate() && varTypeNode.GetStaticType()->IsType(), "Expected immediate type");
    const auto varType = AS_TYPE(TVariantType, static_cast<TType*>(varTypeNode.GetNode()));

    if (varType->GetUnderlyingType()->IsTuple()) {
        return new TDynamicVariantTupleWrapper(ctx.Mutables, item, index, varType);
    }

    // Both branches return a node, so there is no fall-through `nullptr` path.
    return new TDynamicVariantStructWrapper(ctx.Mutables, item, index, varType);
}

}
}