aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_block_container.cpp
blob: 3dd4135e72cd65288770ae886f2fd516659d9032 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include "mkql_block_container.h"

#include <yql/essentials/minikql/computation/mkql_block_impl.h>

#include <yql/essentials/minikql/arrow/arrow_defs.h>
#include <yql/essentials/minikql/arrow/arrow_util.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_node_builder.h>

#include <arrow/util/bitmap_ops.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Arrow kernel body for BlockAsContainer: packs the kernel's arguments into a
// single struct-typed result.
//  * If every argument is a scalar, the result is an arrow::StructScalar built
//    from those scalars.
//  * Otherwise the result is a struct ArrayData whose i-th child is the i-th
//    argument; scalar arguments are broadcast to the common array length.
class TBlockAsContainerExec {
public:
    // argTypes        - MiniKQL block types of the arguments, in argument order;
    //                   used to broadcast scalar arguments to arrays.
    // returnArrowType - Arrow type of the produced struct value.
    TBlockAsContainerExec(const TVector<TType*>& argTypes, const std::shared_ptr<arrow::DataType>& returnArrowType)
        : ArgTypes(argTypes)
        , ReturnArrowType(returnArrowType)
    {}

    arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
        // Validate arity before doing any work, so both the scalar and the array
        // paths reject a malformed batch consistently.
        MKQL_ENSURE(ArgTypes.size() == batch.values.size(), "Mismatch batch columns");

        // The output length is taken from the first array argument.
        // NOTE(review): assumes all array arguments in a batch share this length.
        const arrow::Datum* firstArray = nullptr;
        for (const auto& x : batch.values) {
            if (!x.is_scalar()) {
                firstArray = &x;
                break;
            }
        }

        if (!firstArray) {
            // All arguments are scalars: return a scalar struct as well.
            std::vector<std::shared_ptr<arrow::Scalar>> arrowValue;
            arrowValue.reserve(batch.values.size());
            for (const auto& x : batch.values) {
                arrowValue.emplace_back(x.scalar());
            }

            *res = arrow::Datum(std::make_shared<arrow::StructScalar>(arrowValue, ReturnArrowType));
            return arrow::Status::OK();
        }

        const size_t length = firstArray->array()->length;
        // Struct array with no validity bitmap (null_count = 0, offset = 0):
        // the container itself is never null, only its fields may be.
        auto newArrayData = arrow::ArrayData::Make(ReturnArrowType, length, { nullptr }, 0, 0);
        newArrayData->child_data.reserve(batch.values.size());
        for (size_t i = 0; i < batch.values.size(); ++i) {
            const auto& datum = batch.values[i];
            if (datum.is_scalar()) {
                // Expand the scalar to an array of the common length.
                auto expandedArray = MakeArrayFromScalar(*datum.scalar(), length, AS_TYPE(TBlockType, ArgTypes[i])->GetItemType(), *ctx->memory_pool());
                newArrayData->child_data.push_back(expandedArray.array());
            } else {
                newArrayData->child_data.push_back(datum.array());
            }
        }

        *res = arrow::Datum(newArrayData);
        return arrow::Status::OK();
    }

private:
    const TVector<TType*> ArgTypes;
    const std::shared_ptr<arrow::DataType> ReturnArrowType;
};

std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockAsContainerKernel(const TVector<TType*>& argTypes, TType* resultType) {
    std::shared_ptr<arrow::DataType> returnArrowType;
    MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type");
    auto exec = std::make_shared<TBlockAsContainerExec>(argTypes, returnArrowType);
    auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
        [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
        return exec->Exec(ctx, batch, res);
    });

    kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
    return kernel;
}

} // namespace

// Computation-node factory for the BlockAsContainer callable.
// Locates a computation node for every input, records each input's static
// type, and wraps the BlockAsContainer Arrow kernel in a TBlockFuncNode.
IComputationNode* WrapBlockAsContainer(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    const ui32 inputsCount = callable.GetInputsCount();
    TComputationNodePtrVector inputNodes;
    TVector<TType*> inputTypes;
    for (ui32 index = 0; index != inputsCount; ++index) {
        inputNodes.push_back(LocateNode(ctx.NodeLocator, callable, index));
        inputTypes.push_back(callable.GetInput(index).GetStaticType());
    }

    TType* const returnType = callable.GetType()->GetReturnType();
    auto kernel = MakeBlockAsContainerKernel(inputTypes, returnType);
    // The kernel is passed both by reference and as a shared pointer.
    // NOTE(review): presumably so the node can keep it alive; confirm against TBlockFuncNode.
    return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(inputNodes), inputTypes, *kernel, kernel);
}

} // namespace NMiniKQL
} // namespace NKikimr