summaryrefslogtreecommitdiffstats
path: root/yql/essentials/parser/pg_wrapper/pg_ops.cpp
blob: af7205eeb326dde657c584d75997cfe6f9f8c972 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#include "pg_ops.h"

#include <yql/essentials/parser/pg_catalog/catalog.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/parser/pg_wrapper/utils.h>

namespace NYql {

using namespace NKikimr::NMiniKQL;

// Per-call state for the numeric flavour of TPgSign.
// It invokes the resolved sign() call and reduces its result to -1/0/+1 by
// casting that result to float8.
// Call_ and ResultCaster_ are references to the objects passed to the
// constructor, so those objects must outlive this state.
class TPgSignNumericState: public TPgSign::TCallState {
public:
    TPgSignNumericState(const TPgResolvedCallWithCast& call, const TPgCast& resultCaster)
        : Call_(call)
        , ResultCaster_(resultCaster)
        , CallState_(MakeHolder<TPgResolvedCallWithCastState>(call.GetCall(), call.GetCasters()))
        , ResultCastState_(MakeHolder<TPgCastState>(resultCaster.GetFInfo1(), resultCaster.GetFInfo2()))
    {
    }

    // Returns Nothing() when the sign() call yields NULL. Otherwise it returns
    // +1, -1 or 0 according to the float8 value of the sign() result. A NaN
    // compares false against zero in both directions and is reported as 0.
    TMaybe<i32> GetSign(const NUdf::TUnboxedValue& value) override {
        std::array<NUdf::TUnboxedValue, 1> callArgs = {value};
        auto signDatum = Call_.CallFunctionWithCast(*CallState_, callArgs);
        if (!signDatum) {
            return Nothing();
        }

        auto asFloat8 = ResultCaster_.Calculate(signDatum, -1, *ResultCastState_);
        MKQL_ENSURE(asFloat8, "Result cast returned NULL unexpectedly");
        const auto numericSign = DatumGetFloat8(ScalarDatumFromPod(asFloat8));
        // Each comparison contributes 0 or 1, so the difference is +1, -1 or 0.
        return static_cast<i32>(numericSign > 0) - static_cast<i32>(numericSign < 0);
    }

private:
    const TPgResolvedCallWithCast& Call_;
    const TPgCast& ResultCaster_;
    THolder<TPgResolvedCallWithCastState> CallState_;
    THolder<TPgCastState> ResultCastState_;
};

// Per-call state for the interval flavour of TPgSign.
// It compares the value with a zero interval, first via "=" (yielding 0 on
// equality) and then via ">" (yielding +1 when greater, -1 otherwise).
// The calls and the zero constant are held by reference, so the objects passed
// to the constructor must outlive this state.
class TPgSignIntervalState: public TPgSign::TCallState {
public:
    TPgSignIntervalState(const TPgResolvedCall<false>& eqCall, const TPgResolvedCall<false>& gtCall,
                         const TPgConst& zeroInterval)
        : ZeroInterval_(zeroInterval)
        , EqCall_(eqCall)
        , GtCall_(gtCall)
        , EqState_(MakeHolder<TPgResolvedCallState>(eqCall.GetArgCount(), &eqCall.GetFInfo()))
        , GtState_(MakeHolder<TPgResolvedCallState>(gtCall.GetArgCount(), &gtCall.GetFInfo()))
    {
    }

    // Returns Nothing() when the equality comparison yields NULL. A NULL from
    // the ">" comparison is treated as an internal error.
    TMaybe<i32> GetSign(const NUdf::TUnboxedValue& value) override {
        auto zero = ZeroInterval_.ExtractConst();
        std::array<NUdf::TUnboxedValue, 2> operands = {value, zero};

        auto equalsZero = EqCall_.CallFunction(*EqState_, operands);
        if (!equalsZero) {
            return Nothing();
        }
        if (DatumGetBool(ScalarDatumFromPod(equalsZero))) {
            return 0;
        }

        auto greaterThanZero = GtCall_.CallFunction(*GtState_, operands);
        MKQL_ENSURE(greaterThanZero, "Comparison operator returned NULL unexpectedly");
        return DatumGetBool(ScalarDatumFromPod(greaterThanZero)) ? 1 : -1;
    }

private:
    const TPgConst& ZeroInterval_;
    const TPgResolvedCall<false>& EqCall_;
    const TPgResolvedCall<false>& GtCall_;
    THolder<TPgResolvedCallState> EqState_;
    THolder<TPgResolvedCallState> GtState_;
};

// Implementation for numeric types (uses sign() function).
// The sign() overload for the input type is resolved once, in the constructor.
// Per-call states created by MakeCallState() then reuse it.
class TPgSignNumeric: public TPgSign {
public:
    // Resolves sign(inputTypeOid) and a cast of its return type to float8.
    explicit TPgSignNumeric(ui32 inputTypeOid);

    // Returns a new TPgSignNumericState bound to Call_ and ResultCaster_.
    // The state keeps references to both, so this object must outlive it.
    std::unique_ptr<TCallState> MakeCallState() const override;

private:
    // NOTE: declaration order matters. ResultCaster_ is initialized from
    // Call_.GetReturnTypeId(), so Call_ must be declared (and built) first.
    TPgResolvedCallWithCast Call_;
    TPgCast ResultCaster_;
};

// Implementation for interval type (uses the "=" and ">" operators).
// It compares the input against a zero interval instead of calling sign().
class TPgSignInterval: public TPgSign {
public:
    // Resolves interval "=" and ">" and builds the "0 seconds" constant.
    // Throws (MKQL_ENSURE) unless both operators are (interval, interval) -> bool.
    explicit TPgSignInterval(ui32 inputTypeOid);

    // Returns a new TPgSignIntervalState referencing EqCall_, GtCall_ and
    // ZeroInterval_, so this object must outlive the returned state.
    std::unique_ptr<TCallState> MakeCallState() const override;

private:
    TPgResolvedCall<false> EqCall_;   // interval = interval
    TPgResolvedCall<false> GtCall_;   // interval > interval
    TPgConst ZeroInterval_;           // the "0 seconds" interval constant
};

// TPgSign factory.
// Interval inputs get the comparison-based implementation; every other type
// goes through the catalog sign() function.
std::unique_ptr<TPgSign> TPgSign::Create(ui32 inputTypeOid) {
    if (inputTypeOid != INTERVALOID) {
        return std::make_unique<TPgSignNumeric>(inputTypeOid);
    }
    return std::make_unique<TPgSignInterval>(inputTypeOid);
}

// TPgSignNumeric implementation
// Resolves the `sign` procedure for the given input type. It also prepares a
// cast of that procedure's return type to float8, so that GetSign can compare
// the result against zero. This relies on Call_ being declared before
// ResultCaster_ in the class, because the caster is built from
// Call_.GetReturnTypeId().
TPgSignNumeric::TPgSignNumeric(ui32 inputTypeOid)
    : Call_(TPgResolvedCallWithCast::ForProc("sign", {inputTypeOid}))
    , ResultCaster_(Call_.GetReturnTypeId(), FLOAT8OID, false)
{
}

// Creates a fresh per-call state. The state references Call_ and ResultCaster_,
// so this TPgSignNumeric must outlive the returned object.
std::unique_ptr<TPgSign::TCallState> TPgSignNumeric::MakeCallState() const {
    auto state = std::make_unique<TPgSignNumericState>(Call_, ResultCaster_);
    return state;
}

namespace {

// Resolves the catalog operator `operName` for exactly (leftType, rightType).
// The operator's implementing procedure is wrapped in a TPgResolvedCall<false>,
// whose callers invoke it via CallFunction (no argument casts).
// Throws (MKQL_ENSURE) if the resolved operator's operand types differ from
// the requested ones. The message reports both the requested and the resolved
// type OIDs.
TPgResolvedCall<false> MakeOperatorCall(std::string_view operName, ui32 leftType, ui32 rightType) {
    const auto& oper = NPg::LookupOper(TString(operName), {leftType, rightType});
    // No casts are applied at call time, so only an exact operand match is usable.
    MKQL_ENSURE(oper.LeftType == leftType && oper.RightType == rightType,
                "Type mismatch for operator " << operName
                    << ": requested (" << leftType << ", " << rightType
                    << "), resolved (" << oper.LeftType << ", " << oper.RightType << ")");
    TVector<ui32> argTypes = {oper.LeftType, oper.RightType};
    return TPgResolvedCall<false>(operName, oper.ProcId, argTypes, oper.ResultType);
}

} // namespace

// TPgSignInterval implementation
// Resolves the interval "=" and ">" operators and the "0 seconds" constant.
// It then validates that both operators have the (interval, interval) -> bool
// shape that TPgSignIntervalState relies on.
TPgSignInterval::TPgSignInterval(ui32 inputTypeOid)
    : EqCall_(MakeOperatorCall("=", inputTypeOid, INTERVALOID))
    , GtCall_(MakeOperatorCall(">", inputTypeOid, INTERVALOID))
    , ZeroInterval_(INTERVALOID, "0 seconds")
{
    // Both operators must yield a boolean.
    const auto eqReturnType = EqCall_.GetReturnTypeId();
    MKQL_ENSURE(eqReturnType == BOOLOID,
                "Equality operator must return bool, got " << eqReturnType);
    const auto gtReturnType = GtCall_.GetReturnTypeId();
    MKQL_ENSURE(gtReturnType == BOOLOID,
                "Greater-than operator must return bool, got " << gtReturnType);

    // Both operators must take two intervals.
    const auto eqLeft = EqCall_.GetArgTypeId(0);
    const auto eqRight = EqCall_.GetArgTypeId(1);
    MKQL_ENSURE(eqLeft == INTERVALOID && eqRight == INTERVALOID,
                "Equality operator must expect (interval, interval), got (" << eqLeft << ", " << eqRight << ")");
    const auto gtLeft = GtCall_.GetArgTypeId(0);
    const auto gtRight = GtCall_.GetArgTypeId(1);
    MKQL_ENSURE(gtLeft == INTERVALOID && gtRight == INTERVALOID,
                "Greater-than operator must expect (interval, interval), got (" << gtLeft << ", " << gtRight << ")");
}

// Creates a fresh per-call state. The state references EqCall_, GtCall_ and
// ZeroInterval_, so this TPgSignInterval must outlive the returned object.
std::unique_ptr<TPgSign::TCallState> TPgSignInterval::MakeCallState() const {
    auto state = std::make_unique<TPgSignIntervalState>(EqCall_, GtCall_, ZeroInterval_);
    return state;
}

// Builds a per-call state for `call`. State is obtained from call.CreateState()
// and Call is initialized from `call`. NOTE(review): Call is presumably a
// reference member (see pg_ops.h); if so, `call` must outlive this state.
TPgCompareOp::TCallState::TCallState(const TPgResolvedCallWithCast& call)
    : State(call.CreateState())
    , Call(call)
{
}

// Resolves operator `operName` for the operand types (lhsTypeOid, rhsTypeOid)
// via TPgResolvedCallWithCast::ForOperator.
// Throws (MKQL_ENSURE) unless the resolved operator returns bool.
TPgCompareOp::TPgCompareOp(ui32 lhsTypeOid, ui32 rhsTypeOid, std::string_view operName)
    : Call(TPgResolvedCallWithCast::ForOperator(operName, {lhsTypeOid, rhsTypeOid}))
{
    const auto returnTypeId = Call.GetReturnTypeId();
    MKQL_ENSURE(returnTypeId == BOOLOID,
                "Comparison operator must return bool, got " << returnTypeId);
}

// Returns a new per-call state bound to this operator's resolved call.
TPgCompareOp::TCallState TPgCompareOp::MakeCallState() const {
    return TCallState{Call};
}

// Applies the resolved comparison operator to (lhs, rhs).
// Returns Nothing() when the call yields NULL; otherwise returns its boolean result.
TMaybe<bool> TPgCompareOp::TCallState::Compare(const NUdf::TUnboxedValue& lhs, const NUdf::TUnboxedValue& rhs) {
    std::array<NUdf::TUnboxedValue, 2> operands = {lhs, rhs};
    auto outcome = Call.CallFunctionWithCast(State, operands);
    if (outcome) {
        return DatumGetBool(ScalarDatumFromPod(outcome));
    }
    return Nothing();
}

} // namespace NYql