aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_block_if.cpp
blob: 4ec511fa51f1579d7d50d1121bd80b7721935f85 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
#include "mkql_block_if.h"

#include <yql/essentials/minikql/computation/mkql_block_reader.h>
#include <yql/essentials/minikql/computation/mkql_block_builder.h>
#include <yql/essentials/minikql/computation/mkql_block_impl.h>

#include <yql/essentials/minikql/arrow/arrow_defs.h>
#include <yql/essentials/minikql/arrow/arrow_util.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_node_cast.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Handles block If with a *scalar* predicate: the predicate selects one branch
// for the whole batch, so only that branch's datum has to be produced in full.
class TBlockIfScalarWrapper : public TMutableComputationNode<TBlockIfScalarWrapper> {
public:
    // Adapter exposing this node to the Arrow compute layer as a
    // three-argument (pred, then, else) scalar kernel named "If".
    class TArrowNode : public IArrowKernelComputationNode {
    public:
        TArrowNode(const TBlockIfScalarWrapper* parent)
            : Parent_(parent)
            , ArgsValuesDescr_(ToValueDescr(parent->ArgsTypes))
            , Kernel_(ConvertToInputTypes(parent->ArgsTypes), ConvertToOutputType(parent->ResultType), [parent](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
                *res = parent->CalculateImpl(MakeDatumProvider(batch.values[0]), MakeDatumProvider(batch.values[1]), MakeDatumProvider(batch.values[2]), *ctx->memory_pool());
                return arrow::Status::OK();
            })
        {
            // The kernel builds its own output, so Arrow must not preallocate
            // buffers or precompute null bitmaps on its behalf.
            Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
            Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
        }

        TStringBuf GetKernelName() const final {
            return "If";
        }

        const arrow::compute::ScalarKernel& GetArrowKernel() const {
            return Kernel_;
        }

        const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
            return ArgsValuesDescr_;
        }

        // Maps a kernel argument index to the computation node producing it.
        const IComputationNode* GetArgument(ui32 index) const {
            switch (index) {
            case 0:
                return Parent_->Pred;
            case 1:
                return Parent_->Then;
            case 2:
                return Parent_->Else;
            default:
                throw yexception() << "Bad argument index";
            }
        }

    private:
        const TBlockIfScalarWrapper* Parent_;
        const std::vector<arrow::ValueDescr> ArgsValuesDescr_;
        arrow::compute::ScalarKernel Kernel_;
    };
    friend class TArrowNode;

    TBlockIfScalarWrapper(TComputationMutables& mutables, IComputationNode* pred, IComputationNode* thenNode, IComputationNode* elseNode, TType* resultType,
                          bool thenIsScalar, bool elseIsScalar, const TVector<TType*>& argsTypes)
        : TMutableComputationNode(mutables)
        , Pred(pred)
        , Then(thenNode)
        , Else(elseNode)
        , ResultType(resultType)
        , ThenIsScalar(thenIsScalar)
        , ElseIsScalar(elseIsScalar)
        , ArgsTypes(argsTypes)
    {
    }

    std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
        Y_UNUSED(ctx);
        return std::make_unique<TArrowNode>(this);
    }

    // Evaluates If lazily: only the branch picked by the scalar predicate is
    // materialized via its provider, unless the unselected branch is needed
    // to determine the array shape of the output (see below).
    arrow::Datum CalculateImpl(const TDatumProvider& predProv, const TDatumProvider& thenProv, const TDatumProvider& elseProv,
        arrow::MemoryPool& memoryPool) const {
        auto predValue = predProv();

        const bool predScalarValue = GetPrimitiveScalarValue<bool>(*predValue.scalar());
        auto result = predScalarValue ? thenProv() : elseProv();

        // Result matches the expected output shape when both branches have the
        // same shape, or when the selected branch is already array-shaped.
        if (ThenIsScalar == ElseIsScalar || (predScalarValue ? !ThenIsScalar : !ElseIsScalar)) {
            // can return result as-is
            return result;
        }

        // Selected branch is a scalar but the other branch is an array, so the
        // output must be array-shaped: broadcast the scalar into chunks whose
        // lengths mirror the unselected branch's arrays.
        auto otherDatum = predScalarValue ? elseProv() : thenProv();
        MKQL_ENSURE(otherDatum.is_arraylike(), "Expecting array");

        std::shared_ptr<arrow::Scalar> resultScalar = result.scalar();

        TVector<std::shared_ptr<arrow::ArrayData>> resultArrays;
        auto itemType = AS_TYPE(TBlockType, ResultType)->GetItemType();
        ForEachArrayData(otherDatum, [&](const std::shared_ptr<arrow::ArrayData>& otherData) {
            auto chunk = MakeArrayFromScalar(*resultScalar, otherData->length, itemType, memoryPool);
            ForEachArrayData(chunk, [&](const auto& array) {
                resultArrays.push_back(array);
            });
        });
        return MakeArray(resultArrays);
    }

    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        return ctx.HolderFactory.CreateArrowBlock(CalculateImpl(MakeDatumProvider(Pred, ctx), MakeDatumProvider(Then, ctx), MakeDatumProvider(Else, ctx), ctx.ArrowMemoryPool));
    }

private:
    void RegisterDependencies() const final {
        DependsOn(Pred);
        DependsOn(Then);
        DependsOn(Else);
    }

    IComputationNode* const Pred;
    IComputationNode* const Then;
    IComputationNode* const Else;
    TType* const ResultType;
    const bool ThenIsScalar;
    const bool ElseIsScalar;
    const TVector<TType*> ArgsTypes;
};

// Per-row kernel for If with an *array* predicate. Whether each branch is a
// scalar or an array is fixed at compile time via the template parameters,
// so the per-row loop contains no runtime shape checks.
template<bool ThenIsScalar, bool ElseIsScalar>
class TIfBlockExec {
public:
    // `type` is the item type of both branches (and of the result); two
    // separate readers are kept so each branch datum is decoded independently.
    explicit TIfBlockExec(TType* type)
        : ThenReader(MakeBlockReader(TTypeInfoHelper(), type)), ElseReader(MakeBlockReader(TTypeInfoHelper(), type)), Type(type)
    {
    }

    // Blends the two branch blocks row by row according to the predicate array.
    arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
        arrow::Datum predDatum = batch.values[0];
        arrow::Datum thenDatum = batch.values[1];
        arrow::Datum elseDatum = batch.values[2];

        // A scalar branch is decoded once up front; an array branch is read
        // per row inside the loop below.
        TBlockItem thenItem;
        const arrow::ArrayData* thenArray = nullptr;
        if constexpr(ThenIsScalar) {
            thenItem = ThenReader->GetScalarItem(*thenDatum.scalar());
        } else {
            MKQL_ENSURE(thenDatum.is_array(), "Expecting array");
            thenArray = thenDatum.array().get();
        }

        TBlockItem elseItem;
        const arrow::ArrayData* elseArray = nullptr;
        if constexpr(ElseIsScalar) {
            elseItem = ElseReader->GetScalarItem(*elseDatum.scalar());
        } else {
            MKQL_ENSURE(elseDatum.is_array(), "Expecting array");
            elseArray = elseDatum.array().get();
        }

        MKQL_ENSURE(predDatum.is_array(), "Expecting array");
        const std::shared_ptr<arrow::ArrayData>& pred = predDatum.array();

        const size_t len = pred->length;
        auto builder = MakeArrayBuilder(TTypeInfoHelper(), Type, *ctx->memory_pool(), len, nullptr);
        // Predicate values are read byte-wise from buffer 1 (assumes one byte
        // per row holding 0 or 1 — required for the mask trick below).
        const ui8* predValues = pred->GetValues<uint8_t>(1);
        for (size_t i = 0; i < len; ++i) {
            if constexpr (!ThenIsScalar) {
                thenItem = ThenReader->GetItem(*thenArray, i);
            }
            if constexpr (!ElseIsScalar) {
                elseItem = ElseReader->GetItem(*elseArray, i);
            }

            // Branch-free select: mask is all-ones when the predicate byte is 1
            // and all-zeros when it is 0, so the whole 128-bit item (Low/High
            // pair) is taken from exactly one side — correctness does not depend
            // on what the two ui64 halves encode.
            ui64 mask = -ui64(predValues[i]);
            ui64 low = (thenItem.Low() & mask) | (elseItem.Low() & ~mask);
            ui64 high = (thenItem.High() & mask) | (elseItem.High() & ~mask);
            builder->Add(TBlockItem{low, high});
        }
        // Build(true) finalizes the builder into a single output datum.
        *res = builder->Build(true);
        return arrow::Status::OK();
    }

private:
    const std::unique_ptr<IBlockReader> ThenReader;
    const std::unique_ptr<IBlockReader> ElseReader;
    TType* const Type;
};


template<bool ThenIsScalar, bool ElseIsScalar>
std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockIfKernel(const TVector<TType*>& argTypes, TType* resultType) {
    using TExec = TIfBlockExec<ThenIsScalar, ElseIsScalar>;

    auto exec = std::make_shared<TExec>(AS_TYPE(TBlockType, resultType)->GetItemType());
    auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
        [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
            return exec->Exec(ctx, batch, res);
    });

    kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
    return kernel;
}

} // namespace

// Constructs the computation node for the block If callable. A scalar
// predicate gets a dedicated wrapper node that evaluates branches lazily;
// an array predicate is handled by a per-row blending Arrow kernel.
IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");

    const auto predInput = callable.GetInput(0);
    const auto thenInput = callable.GetInput(1);
    const auto elseInput = callable.GetInput(2);

    const auto predBlockType = AS_TYPE(TBlockType, predInput.GetStaticType());
    MKQL_ENSURE(AS_TYPE(TDataType, predBlockType->GetItemType())->GetSchemeType() == NUdf::TDataType<bool>::Id,
                "Expected bool as first argument");

    const auto thenBlockType = AS_TYPE(TBlockType, thenInput.GetStaticType());
    const auto elseBlockType = AS_TYPE(TBlockType, elseInput.GetStaticType());
    MKQL_ENSURE(thenBlockType->GetItemType()->IsSameType(*elseBlockType->GetItemType()), "Different return types in branches.");

    const auto predNode = LocateNode(ctx.NodeLocator, callable, 0);
    const auto thenComputeNode = LocateNode(ctx.NodeLocator, callable, 1);
    const auto elseComputeNode = LocateNode(ctx.NodeLocator, callable, 2);

    const bool thenIsScalar = thenBlockType->GetShape() == TBlockType::EShape::Scalar;
    const bool elseIsScalar = elseBlockType->GetShape() == TBlockType::EShape::Scalar;
    TVector<TType*> argsTypes = { predBlockType, thenBlockType, elseBlockType };

    if (predBlockType->GetShape() == TBlockType::EShape::Scalar) {
        // Scalar predicate: one branch decides the whole batch.
        return new TBlockIfScalarWrapper(ctx.Mutables, predNode, thenComputeNode, elseComputeNode, thenBlockType,
                                         thenIsScalar, elseIsScalar, argsTypes);
    }

    // Array predicate: pick the kernel instantiation matching the branch shapes.
    std::shared_ptr<arrow::compute::ScalarKernel> kernel;
    if (thenIsScalar) {
        kernel = elseIsScalar ? MakeBlockIfKernel<true, true>(argsTypes, thenBlockType)
                              : MakeBlockIfKernel<true, false>(argsTypes, thenBlockType);
    } else {
        kernel = elseIsScalar ? MakeBlockIfKernel<false, true>(argsTypes, thenBlockType)
                              : MakeBlockIfKernel<false, false>(argsTypes, thenBlockType);
    }

    TComputationNodePtrVector argsNodes = { predNode, thenComputeNode, elseComputeNode };
    return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
}

}
}