aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_lazy_list.cpp
blob: 9e6ec251e3e3e0b9ba79ec7196f5961cdc7f45c9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include "mkql_lazy_list.h"
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h>  // Y_IGNORE
#include <yql/essentials/minikql/mkql_node_cast.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Computation node behind the LazyList callable.
//
// If the input list exposes its elements directly (GetElements() returns a
// non-null pointer), the value is wrapped with THolderFactory::LazyList;
// otherwise the input value is returned unchanged.
//
// IsOptional: selected by WrapLazyList from the input's static type. When true,
// an empty input is returned as an empty value without any further checks.
template <bool IsOptional>
class TLazyListWrapper : public TMutableCodegeneratorNode<TLazyListWrapper<IsOptional>> {
    typedef TMutableCodegeneratorNode<TLazyListWrapper<IsOptional>> TBaseComputation;
public:

    // mutables - passed to the base node, which is registered with a Boxed value representation.
    // list     - node producing the input list (an optional list when IsOptional).
    TLazyListWrapper(TComputationMutables& mutables, IComputationNode* list)
        : TBaseComputation(mutables, EValueRepresentation::Boxed), List(list)
    {}

    // Interpreted path: evaluates List and either wraps it with LazyList or
    // passes it through as is.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        auto list = List->GetValue(ctx);

        // Empty optional input: nothing to wrap, return an empty value.
        if (IsOptional && !list) {
            return NUdf::TUnboxedValuePod();
        }

        // Non-null GetElements(): the list has directly accessible elements,
        // so the value is handed (via Release()) to LazyList for wrapping.
        if (list.GetElements()) {
            return ctx.HolderFactory.LazyList(list.Release());
        }

        // Otherwise the input value is returned unchanged.
        return list.Release();
    }

#ifndef MKQL_DISABLE_CODEGEN
    // LLVM codegen mirroring DoCalculate. Emitted control flow:
    //   current block --(IsOptional only: value is empty)-----------> done
    //   current/"test" block --(GetElements() == nullptr)------------> done
    //   "wrap": call THolderFactory::LazyList(factory, list) --------> done
    // The PHI "lazy" in "done" yields either the original list or the LazyList
    // result; on return, `block` points at "done".
    Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto factory = ctx.GetFactory();
        // Address of THolderFactory::LazyList as an i64 constant; it is converted
        // to a function pointer (IntToPtr) below and called with `factory` first.
        const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::LazyList));

        const auto list = GetNodeValue(List, ctx, block);

        const auto wrap = BasicBlock::Create(context, "wrap", ctx.Func);
        const auto done = BasicBlock::Create(context, "done", ctx.Func);
        // One reserved incoming slot per predecessor of "done": the pass-through
        // edge(s) plus the edge from "wrap" (IsOptional adds the empty-value edge).
        const auto lazy = PHINode::Create(list->getType(), IsOptional ? 3U : 2U, "lazy", done);
        // Pass-through edge from the current block: the empty-value branch when
        // IsOptional, otherwise the GetElements() == nullptr branch emitted below.
        lazy->addIncoming(list, block);

        if constexpr (IsOptional) {
            const auto test = BasicBlock::Create(context, "test", ctx.Func);
            // Empty value jumps straight to "done"; otherwise continue in "test".
            BranchInst::Create(done, test, IsEmpty(list, block), block);

            block = test;
            // Pass-through edge for the GetElements() == nullptr branch, which is
            // emitted into "test" below.
            lazy->addIncoming(list, block);
        }

        const auto ptrType = PointerType::getUnqual(list->getType());
        // Virtual call to the boxed value's GetElements(); a null result means the
        // list is passed through unchanged, a non-null one leads to "wrap".
        const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(ptrType, list, ctx.Codegen, block);
        const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block);

        BranchInst::Create(done, wrap, null, block);

        block = wrap;

        if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) {
            // Non-Windows targets: LazyList is called as result(factory, list) and
            // its return value is used directly.
            const auto funType = FunctionType::get(list->getType(), {factory->getType(), list->getType()}, false);
            const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
            const auto res = CallInst::Create(funType, funcPtr, {factory, list}, "res", block);
            lazy->addIncoming(res, block);
        } else {
            // Windows target: the input is stored into a stack slot (alloca); the
            // call returns void and receives that same slot both as the result
            // pointer and as the argument, and the result is loaded back from it.
            // NOTE(review): presumably this matches how the Windows ABI returns and
            // passes the 16-byte TUnboxedValuePod through memory — confirm.
            const auto retPtr = new AllocaInst(list->getType(), 0U, "ret_ptr", block);
            new StoreInst(list, retPtr, block);
            const auto funType = FunctionType::get(Type::getVoidTy(context), {factory->getType(), retPtr->getType(), retPtr->getType()}, false);
            const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
            CallInst::Create(funType, funcPtr, {factory, retPtr, retPtr}, "", block);
            const auto res = new LoadInst(list->getType(), retPtr, "res", block);
            lazy->addIncoming(res, block);
        }

        BranchInst::Create(done, block);

        block = done;
        return lazy;
    }
#endif
private:
    // Declares List as the node this computation depends on.
    void RegisterDependencies() const final {
        this->DependsOn(List);
    }

    // Node producing the input list.
    IComputationNode* const List;
};

}

// Factory for the LazyList callable.
// Expects exactly one input; the wrapper specialization (optional or not) is
// chosen from the static type of that input.
IComputationNode* WrapLazyList(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    const auto argsCount = callable.GetInputsCount();
    MKQL_ENSURE(argsCount == 1U, "Expected single arg, got " << argsCount);

    const auto listNode = LocateNode(ctx.NodeLocator, callable, 0);
    const bool optionalInput = callable.GetInput(0).GetStaticType()->IsOptional();

    if (!optionalInput) {
        return new TLazyListWrapper<false>(ctx.Mutables, listNode);
    }

    return new TLazyListWrapper<true>(ctx.Mutables, listNode);
}

}
}