aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_prepend.cpp
blob: 334197d48896ca3aa8ec0ec52753c368f111a37f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include "mkql_prepend.h"
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h>  // Y_IGNORE
#include <yql/essentials/minikql/mkql_node_cast.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Computation node that places the value produced by Left in front of the list
// produced by Right by calling THolderFactory::Prepend.
//
// Template parameter IsVoid: when true, a left value that is not boxed makes the
// node return the list from Right unchanged instead of calling Prepend.
// The node reuses the representation reported by the right (list) node.
template<bool IsVoid>
class TPrependWrapper : public TMutableCodegeneratorNode<TPrependWrapper<IsVoid>> {
    using TBaseComputation = TMutableCodegeneratorNode<TPrependWrapper<IsVoid>>;
public:
    // left  - node producing the item to prepend.
    // right - node producing the list; its representation becomes this node's one.
    TPrependWrapper(TComputationMutables& mutables, IComputationNode* left, IComputationNode* right)
        : TBaseComputation(mutables, right->GetRepresentation())
        , Left(left)
        , Right(right)
    {
    }

    // Interpreter path: evaluates both inputs (left first) and returns the new list.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        auto item = Left->GetValue(ctx);
        auto list = Right->GetValue(ctx);

        if constexpr (IsVoid) {
            // Nothing to prepend: hand the list back untouched.
            if (!item.IsBoxed()) {
                return list.Release();
            }
        }

        return ctx.HolderFactory.Prepend(item.Release(), list.Release());
    }

#ifndef MKQL_DISABLE_CODEGEN
    // Codegen path: emits IR mirroring DoCalculate. `block` is updated to the
    // basic block where emission must continue; the returned value is the result.
    Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();

        const auto factory = ctx.GetFactory();

        // Address of THolderFactory::Prepend as an i64 constant; cast to a function pointer at the call site.
        const auto prependAddr = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::Prepend));

        const auto item = GetNodeValue(Left, ctx, block);
        const auto list = GetNodeValue(Right, ctx, block);
        const auto valueType = list->getType();

        // Appends the call to Prepend to basic block `at` and yields the value holding the resulting list.
        const auto emitPrepend = [&](BasicBlock* at) -> Value* {
            if (NYql::NCodegen::ETarget::Windows == ctx.Codegen.GetEffectiveTarget()) {
                // Windows target: arguments and result travel through stack slots;
                // the list slot doubles as the return slot.
                // NOTE(review): presumably dictated by the calling convention for 128-bit values - confirm.
                const auto retPtr = new AllocaInst(valueType, 0U, "ret_ptr", at);
                const auto itemPtr = new AllocaInst(item->getType(), 0U, "item_ptr", at);
                new StoreInst(list, retPtr, at);
                new StoreInst(item, itemPtr, at);
                const auto funType = FunctionType::get(Type::getVoidTy(context), {factory->getType(), retPtr->getType(), itemPtr->getType(), retPtr->getType()}, false);
                const auto funcPtr = CastInst::Create(Instruction::IntToPtr, prependAddr, PointerType::getUnqual(funType), "function", at);
                CallInst::Create(funType, funcPtr, {factory, retPtr, itemPtr, retPtr}, "", at);
                return new LoadInst(valueType, retPtr, "res", at);
            }

            // Other targets: values are passed and returned directly.
            const auto funType = FunctionType::get(valueType, {factory->getType(), item->getType(), valueType}, false);
            const auto funcPtr = CastInst::Create(Instruction::IntToPtr, prependAddr, PointerType::getUnqual(funType), "function", at);
            return CallInst::Create(funType, funcPtr, {factory, item, list}, "res", at);
        };

        if constexpr (IsVoid) {
            const auto work = BasicBlock::Create(context, "work", ctx.Func);
            const auto done = BasicBlock::Create(context, "done", ctx.Func);

            // The result is either the original list (check failed) or the output of Prepend.
            const auto result = PHINode::Create(valueType, 2, "result", done);
            result->addIncoming(list, block);

            // 128-bit mask with only the two bits 0x3 << 56 of the upper 64-bit word set.
            // Prepend is called only when both bits are set in the item.
            // NOTE(review): presumably the codegen counterpart of IsBoxed() - confirm against the value layout.
            const uint64_t boxedBits[] = {0x0ULL, 0x300000000000000ULL};
            const auto mask = ConstantInt::get(item->getType(), APInt(128, 2, boxedBits));
            const auto boxed = BinaryOperator::CreateAnd(item, mask, "boxed",  block);
            const auto check = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, boxed, mask, "check", block);
            BranchInst::Create(work, done, check, block);

            block = work;
            result->addIncoming(emitPrepend(block), block);
            BranchInst::Create(done, block);

            block = done;
            return result;
        } else {
            return emitPrepend(block);
        }
    }
#endif
private:
    // Declares both inputs as dependencies of this node.
    void RegisterDependencies() const final {
        this->DependsOn(Left);
        this->DependsOn(Right);
    }

    IComputationNode* const Left;   // produces the item to prepend
    IComputationNode* const Right;  // produces the list
};

}

// Builds the computation node for a Prepend callable.
// Input 0 is the item and input 1 is the list; the list's item type must be the
// same as the static type of input 0. Returns a newly allocated TPrependWrapper,
// specialized on whether the item type is Void.
IComputationNode* WrapPrepend(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");

    const auto leftType = callable.GetInput(0).GetStaticType();
    const auto rightType = AS_TYPE(TListType, callable.GetInput(1));

    MKQL_ENSURE(rightType->GetItemType()->IsSameType(*leftType), "Mismatch item type");

    const auto itemNode = LocateNode(ctx.NodeLocator, callable, 0);
    const auto listNode = LocateNode(ctx.NodeLocator, callable, 1);

    // A Void item gets the specialization that may return the list unchanged.
    if (leftType->IsVoid()) {
        return new TPrependWrapper<true>(ctx.Mutables, itemNode, listNode);
    }

    return new TPrependWrapper<false>(ctx.Mutables, itemNode, listNode);
}

}
}