diff options
author | vvvv <vvvv@yandex-team.com> | 2024-11-07 04:19:26 +0300 |
---|---|---|
committer | vvvv <vvvv@yandex-team.com> | 2024-11-07 04:29:50 +0300 |
commit | 2661be00f3bc47590fda9218bf0386d6355c8c88 (patch) | |
tree | 3d316c07519191283d31c5f537efc6aabb42a2f0 /yql/essentials/minikql/comp_nodes/mkql_invoke.cpp | |
parent | cf2a23963ac10add28c50cc114fbf48953eca5aa (diff) | |
download | ydb-2661be00f3bc47590fda9218bf0386d6355c8c88.tar.gz |
Moved yql/minikql YQL-19206
init
[nodiff:caesar]
commit_hash:d1182ef7d430ccf7e4d37ed933c7126d7bd5d6e4
Diffstat (limited to 'yql/essentials/minikql/comp_nodes/mkql_invoke.cpp')
-rw-r--r-- | yql/essentials/minikql/comp_nodes/mkql_invoke.cpp | 223 |
1 files changed, 223 insertions, 0 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_invoke.cpp b/yql/essentials/minikql/comp_nodes/mkql_invoke.cpp new file mode 100644 index 0000000000..aa9f8b7e7b --- /dev/null +++ b/yql/essentials/minikql/comp_nodes/mkql_invoke.cpp @@ -0,0 +1,223 @@ +#include "mkql_invoke.h" +#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> +#include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE +#include <yql/essentials/minikql/mkql_node_cast.h> +#include <yql/essentials/minikql/mkql_node_builder.h> +#include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h> + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +template<bool IsOptional> +class TUnaryArgInvokeBase { +protected: + TUnaryArgInvokeBase(TStringBuf name, const TFunctionDescriptor& descr) + : Name(name), Descriptor(descr) + {} + + NUdf::TUnboxedValuePod DoCalc(const NUdf::TUnboxedValuePod& arg) const { + if (IsOptional && !arg) { + return {}; + } + return Descriptor.Function(&arg); + } + +#ifndef MKQL_DISABLE_CODEGEN + Value* DoGenGetValue(const TCodegenContext& ctx, Value* arg, BasicBlock*& block) const { + if (IsOptional) { + auto& context = ctx.Codegen.GetContext(); + + const auto good = BasicBlock::Create(context, "good", ctx.Func); + const auto done = BasicBlock::Create(context, "done", ctx.Func); + const auto result = PHINode::Create(arg->getType(), 2U, "result", done); + + result->addIncoming(arg, block); + BranchInst::Create(good, done, IsExists(arg, block), block); + + block = good; + const auto out = reinterpret_cast<TGeneratorPtr>(Descriptor.Generator)(&arg, ctx, block); + + result->addIncoming(out, block); + BranchInst::Create(done, block); + + block = done; + return result; + } else { + return reinterpret_cast<TGeneratorPtr>(Descriptor.Generator)(&arg, ctx, block); + } + } +#endif + const TStringBuf Name; + const TFunctionDescriptor Descriptor; +}; + +template<bool IsOptional> +class TSimpleUnaryArgInvokeWrapper : public TDecoratorCodegeneratorNode<TSimpleUnaryArgInvokeWrapper<IsOptional>>, private TUnaryArgInvokeBase<IsOptional> { + typedef TDecoratorCodegeneratorNode<TSimpleUnaryArgInvokeWrapper<IsOptional>> TBaseComputation; +public: + TSimpleUnaryArgInvokeWrapper(TStringBuf name, const TFunctionDescriptor& descr, IComputationNode* arg) + : TBaseComputation(arg), TUnaryArgInvokeBase<IsOptional>(name, descr) + {} + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext&, const NUdf::TUnboxedValuePod& arg) const { + return this->DoCalc(arg); + } + +#ifndef MKQL_DISABLE_CODEGEN + Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* arg, BasicBlock*& block) const { + return this->DoGenGetValue(ctx, arg, block); + } +#endif + +private: + TString DebugString() const final { + return TBaseComputation::DebugString() + "(" + this->Name + ")" ; + } +}; + +template<bool IsOptional> +class TDefaultUnaryArgInvokeWrapper : public TMutableCodegeneratorNode<TDefaultUnaryArgInvokeWrapper<IsOptional>>, private TUnaryArgInvokeBase<IsOptional> { + typedef TMutableCodegeneratorNode<TDefaultUnaryArgInvokeWrapper<IsOptional>> TBaseComputation; +public: + TDefaultUnaryArgInvokeWrapper(TComputationMutables& mutables, EValueRepresentation kind, TStringBuf name, const TFunctionDescriptor& descr, IComputationNode* arg) + : TBaseComputation(mutables, kind), TUnaryArgInvokeBase<IsOptional>(name, descr), Arg(arg) + {} + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + return this->DoCalc(Arg->GetValue(ctx)); + } + +#ifndef MKQL_DISABLE_CODEGEN + Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { + const auto arg = GetNodeValue(Arg, ctx, block); + return this->DoGenGetValue(ctx, arg, block); + } +#endif + +private: + void RegisterDependencies() const final { + this->DependsOn(Arg); + } + + TString DebugString() const final { + return TBaseComputation::DebugString() + "(" + this->Name + ")" ; + } + + IComputationNode *const Arg; +}; + +class TBinaryInvokeWrapper : public TBinaryCodegeneratorNode<TBinaryInvokeWrapper> { + typedef TBinaryCodegeneratorNode<TBinaryInvokeWrapper> TBaseComputation; +public: + TBinaryInvokeWrapper(TStringBuf name, const TFunctionDescriptor& descr, IComputationNode* left, IComputationNode* right, EValueRepresentation kind = EValueRepresentation::Embedded) + : TBaseComputation(left, right, kind), Name(name), Descriptor(descr) + { + } + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { + const std::array<NUdf::TUnboxedValue, 2U> args {{Left->GetValue(compCtx), Right->GetValue(compCtx)}}; + return Descriptor.Function(args.data()); + } + +#ifndef MKQL_DISABLE_CODEGEN + Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { + const std::array<Value*, 2U> args {{GetNodeValue(Left, ctx, block), GetNodeValue(Right, ctx, block)}}; + return reinterpret_cast<TGeneratorPtr>(Descriptor.Generator)(args.data(), ctx, block); + } +#endif + +private: + TString DebugString() const final { + return TBaseComputation::DebugString() + "(" + Name + ")" ; + } + + const TStringBuf Name; + const TFunctionDescriptor Descriptor; +}; + +template<size_t Size> +class TInvokeWrapper : public TMutableCodegeneratorNode<TInvokeWrapper<Size>> { + typedef TMutableCodegeneratorNode<TInvokeWrapper<Size>> TBaseComputation; +public: + TInvokeWrapper(TComputationMutables& mutables, EValueRepresentation kind, TStringBuf name, const TFunctionDescriptor& descr, TComputationNodePtrVector&& argNodes) + : TBaseComputation(mutables, kind) + , Name(name), Descriptor(descr) + , ArgNodes(std::move(argNodes)) + { + } + + NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const { + std::array<NUdf::TUnboxedValue, Size> values; + std::transform(ArgNodes.cbegin(), ArgNodes.cend(), values.begin(), + std::bind(&IComputationNode::GetValue, std::placeholders::_1, std::ref(ctx)) + ); + return Descriptor.Function(values.data()); + } + +#ifndef MKQL_DISABLE_CODEGEN + Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { + std::array<Value*, Size> values; + std::transform(ArgNodes.cbegin(), ArgNodes.cend(), values.begin(), + [&](IComputationNode* node) { return GetNodeValue(node, ctx, block); } + ); + return reinterpret_cast<TGeneratorPtr>(Descriptor.Generator)(values.data(), ctx, block); + } +#endif + +private: + void RegisterDependencies() const final { + std::for_each(ArgNodes.cbegin(), ArgNodes.cend(), std::bind(&TInvokeWrapper::DependsOn, this, std::placeholders::_1)); + } + + TString DebugString() const final { + return TBaseComputation::DebugString() + "(" + Name + ")" ; + } + + const TStringBuf Name; + const TFunctionDescriptor Descriptor; + const TComputationNodePtrVector ArgNodes; +}; + +} + +IComputationNode* WrapInvoke(TCallable& callable, const TComputationNodeFactoryContext& ctx) { + MKQL_ENSURE(callable.GetInputsCount() >= 2U && callable.GetInputsCount() <= 4U, "Expected from one to three arguments."); + const auto returnType = callable.GetType()->GetReturnType(); + + const auto inputsCount = callable.GetInputsCount(); + std::array<TArgType, 4U> argsTypes; + TComputationNodePtrVector argNodes; + argNodes.reserve(inputsCount - 1U); + argsTypes.front().first = UnpackOptionalData(returnType, argsTypes.front().second)->GetSchemeType(); + for (ui32 i = 1U; i < inputsCount; ++i) { + argsTypes[i].first = UnpackOptionalData(callable.GetInput(i), argsTypes[i].second)->GetSchemeType(); + argNodes.emplace_back(LocateNode(ctx.NodeLocator, callable, i)); + } + + const auto funcName = AS_VALUE(TDataLiteral, callable.GetInput(0))->AsValue().AsStringRef(); + const auto funcDesc = ctx.FunctionRegistry.GetBuiltins()->GetBuiltin(funcName, argsTypes.data(), inputsCount); + + const auto returnKind = GetValueRepresentation(returnType); + switch (argNodes.size()) { + case 1U: + if (EValueRepresentation::Embedded == returnKind) { + return new TSimpleUnaryArgInvokeWrapper<false>(funcName, funcDesc, argNodes.front()); + } else { + return new TDefaultUnaryArgInvokeWrapper<false>(ctx.Mutables, returnKind, funcName, funcDesc, argNodes.front()); + } + case 2U: + if (EValueRepresentation::Embedded == returnKind) { + return new TBinaryInvokeWrapper(funcName, funcDesc, argNodes.front(), argNodes.back()); + } + return new TInvokeWrapper<2U>(ctx.Mutables, returnKind, funcName, funcDesc, std::move(argNodes)); + case 3U: + return new TInvokeWrapper<3U>(ctx.Mutables, returnKind, funcName, funcDesc, std::move(argNodes)); + default: + Y_ABORT("Too wide invoke."); + } +} + +} +} |