aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_aggrcount.cpp
blob: 0a4cb6c7f32bb09b53eb7ec43791f0a385b7e1e5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include "mkql_aggrcount.h"
#include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h>  // Y_IGNORE
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
#include <yql/essentials/minikql/mkql_node_cast.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Decorator node producing the starting counter for a count aggregate:
// the wrapped node's value is mapped to 1 when it is non-empty and to 0 otherwise.
class TAggrCountInitWrapper : public TDecoratorCodegeneratorNode<TAggrCountInitWrapper> {
    using TBaseComputation = TDecoratorCodegeneratorNode<TAggrCountInitWrapper>;
public:
    TAggrCountInitWrapper(IComputationNode* value)
        : TBaseComputation(value)
    {}

    // Interpreter path: `value` is the already computed result of the wrapped node.
    // Returns a ui64 equal to 1 if the value is present, 0 if it is empty.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext&, const NUdf::TUnboxedValuePod& value) const {
        const ui64 initial = value ? 1ULL : 0ULL;
        return NUdf::TUnboxedValuePod(initial);
    }

#ifndef MKQL_DISABLE_CODEGEN
    // Codegen path: emits the existence test for `value`, cleans the input up when the
    // wrapped node reports a temporary value, and returns the test result via MakeBoolean.
    // NOTE(review): presumably the boolean encoding coincides with the ui64 0/1 produced
    // by DoCalculate — confirm against MakeBoolean.
    Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* value, BasicBlock*& block) const {
        const auto exists = IsExists(value, block);
        // The existence bit is taken before the cleanup of the input is emitted.
        if (Node->IsTemporaryValue()) {
            ValueCleanup(Node->GetRepresentation(), value, ctx, block);
        }
        auto& context = ctx.Codegen.GetContext();
        return MakeBoolean(exists, context, block);
    }
#endif
};

// Decorator node over the aggregate state: unconditionally bumps the ui64 counter by one.
// The factory in this file selects it when the counted argument is not optional.
class TAggrCountUpdateWrapper : public TDecoratorCodegeneratorNode<TAggrCountUpdateWrapper> {
    using TBaseComputation = TDecoratorCodegeneratorNode<TAggrCountUpdateWrapper>;
public:
    TAggrCountUpdateWrapper(IComputationNode* state)
        : TBaseComputation(state)
    {}

    // Interpreter path: `value` is the current state; returns state + 1 as ui64.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext&, const NUdf::TUnboxedValuePod& value) const {
        const ui64 next = value.Get<ui64>() + 1U;
        return NUdf::TUnboxedValuePod(next);
    }

#ifndef MKQL_DISABLE_CODEGEN
    // Codegen path: emits a single add of the constant 1 (typed like `value`) to `value`.
    Value* DoGenerateGetValue(const TCodegenContext&, Value* value, BasicBlock*& block) const {
        const auto one = ConstantInt::get(value->getType(), 1);
        return BinaryOperator::CreateAdd(value, one, "incr", block);
    }
#endif
};

// Conditional counter update: the state grows by one only when the argument value
// is present; an empty argument leaves the state as it was.
class TAggrCountIfUpdateWrapper : public TMutableCodegeneratorNode<TAggrCountIfUpdateWrapper> {
    using TBaseComputation = TMutableCodegeneratorNode<TAggrCountIfUpdateWrapper>;
public:
    // `value` is the node being counted, `state` is the node yielding the current counter.
    TAggrCountIfUpdateWrapper(TComputationMutables& mutables, IComputationNode* value, IComputationNode* state)
        : TBaseComputation(mutables, EValueRepresentation::Embedded)
        , Arg(value)
        , State(state)
    {
    }

    // Interpreter path: the state is evaluated first, then the argument.
    // Returns state + 1 when the argument is present, otherwise the state itself.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
        auto current = State->GetValue(compCtx);
        if (Arg->GetValue(compCtx)) {
            return NUdf::TUnboxedValuePod(current.Get<ui64>() + 1U);
        }
        return current.Release();
    }

#ifndef MKQL_DISABLE_CODEGEN
    // Codegen path: branch-free form — the existence bit of the argument is
    // zero-extended to the state's type and added to the state.
    Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
        const auto current = GetNodeValue(State, ctx, block);
        const auto arg = GetNodeValue(Arg, ctx, block);
        const auto exists = IsExists(arg, block);
        // The existence bit is taken before the cleanup of the argument is emitted.
        if (Arg->IsTemporaryValue()) {
            ValueCleanup(Arg->GetRepresentation(), arg, ctx, block);
        }
        const auto delta = new ZExtInst(exists, current->getType(), "zext", block);
        return BinaryOperator::CreateAdd(current, delta, "incr", block);
    }
#endif
private:
    // Declares both input nodes as dependencies of this node.
    void RegisterDependencies() const final {
        DependsOn(Arg);
        DependsOn(State);
    }

    IComputationNode* const Arg;   // node whose presence is being counted
    IComputationNode* const State; // node producing the current ui64 counter
};

}

// Factory for the initial count state. Requires exactly one input.
// A non-optional input always counts as one, so an immutable ui64 constant 1 is created;
// an optional input is wrapped into TAggrCountInitWrapper to decide at run time.
IComputationNode* WrapAggrCountInit(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");

    const bool argIsOptional = callable.GetInput(0).GetStaticType()->IsOptional();
    if (!argIsOptional) {
        return ctx.NodeFactory.CreateImmutableNode(NUdf::TUnboxedValuePod(ui64(1ULL)));
    }

    return new TAggrCountInitWrapper(LocateNode(ctx.NodeLocator, callable, 0));
}

// Factory for the count state update. Requires two inputs: the counted argument (index 0)
// and the current state (index 1), whose data type must be ui64.
// A non-optional argument yields an unconditional increment of the state;
// an optional argument yields an increment conditional on the argument being present.
IComputationNode* WrapAggrCountUpdate(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
    MKQL_ENSURE(AS_TYPE(TDataType, callable.GetInput(1))->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64 type");

    const bool argIsOptional = callable.GetInput(0).GetStaticType()->IsOptional();
    if (!argIsOptional) {
        // Only the state node is needed here; the argument is never evaluated.
        return new TAggrCountUpdateWrapper(LocateNode(ctx.NodeLocator, callable, 1));
    }

    return new TAggrCountIfUpdateWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0), LocateNode(ctx.NodeLocator, callable, 1));
}

}
}