#include "mkql_apply.h"
#include <yql/essentials/minikql/computation/mkql_block_impl.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <library/cpp/containers/stack_array/stack_array.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/computation/mkql_value_builder.h>
namespace NKikimr {
namespace NMiniKQL {
namespace {
// Computation node that evaluates a callable node and runs the resulting callable
// with the values produced by its argument nodes.
// Provides three execution paths: the interpreter (DoCalculate), LLVM code generation
// (DoGenerateGetValue) and an Arrow scalar kernel (PrepareArrowKernelComputationNode).
class TApplyWrapper: public TMutableCodegeneratorPtrNode<TApplyWrapper> {
typedef TMutableCodegeneratorPtrNode<TApplyWrapper> TBaseComputation;
public:
// State attached to the Arrow kernel: its own allocator, holder factory and value builder,
// plus a reusable buffer for the callable's arguments.
struct TKernelState : public arrow::compute::KernelState {
// argsCount: number of slots preallocated in Args.
TKernelState(ui32 argsCount)
: Alloc(__LOCATION__)
, MemInfo("Apply")
, HolderFactory(Alloc.Ref(), MemInfo)
, ValueBuilder(HolderFactory, NUdf::EValidatePolicy::Exception)
, Args(argsCount)
{
Alloc.Ref().EnableArrowTracking = false;
// The allocator is released here and re-acquired under a guard on every kernel call.
Alloc.Release();
}
~TKernelState()
{
// Pairs with the Release() in the constructor.
// NOTE(review): presumably needed so the members below are destroyed with the allocator acquired - confirm.
Alloc.Acquire();
}
TScopedAlloc Alloc;
TMemoryUsageInfo MemInfo;
THolderFactory HolderFactory;
TDefaultValueBuilder ValueBuilder;
TVector<NUdf::TUnboxedValue> Args;
};
// Exposes an already evaluated callable as an Arrow scalar kernel.
class TArrowNode : public IArrowKernelComputationNode {
public:
// parent: the owning wrapper, used only to look up argument nodes.
// callable: the callable value invoked by the kernel.
// returnType / argsTypes: types used to build the kernel signature and the argument descriptors.
TArrowNode(const TApplyWrapper* parent, const NUdf::TUnboxedValue& callable, TType* returnType, const TVector<TType*>& argsTypes)
: Parent_(parent)
, Callable_(callable)
, ArgsValuesDescr_(ToValueDescr(argsTypes))
// Kernel body: wraps every input datum into an arrow block value, runs the callable on them
// and writes the datum of the returned block into *res. Captures `this` to reach Callable_.
, Kernel_(ConvertToInputTypes(argsTypes), ConvertToOutputType(returnType), [this](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
auto& state = dynamic_cast<TKernelState&>(*ctx->state());
// Hold the state's allocator for the duration of the call.
auto guard = Guard(state.Alloc);
Y_ENSURE(batch.values.size() == state.Args.size());
for (ui32 i = 0; i < batch.values.size(); ++i) {
state.Args[i] = state.HolderFactory.CreateArrowBlock(arrow::Datum(batch.values[i]));
}
const auto& ret = Callable_.Run(&state.ValueBuilder, state.Args.data());
*res = TArrowBlock::From(ret).GetDatum();
return arrow::Status::OK();
})
{
Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
// Each kernel initialization creates a fresh TKernelState sized for the argument list.
Kernel_.init = [argsCount = argsTypes.size()](arrow::compute::KernelContext*, const arrow::compute::KernelInitArgs&) {
auto state = std::make_unique<TKernelState>(argsCount);
return arrow::Result(std::move(state));
};
}
TStringBuf GetKernelName() const final {
return "Apply";
}
const arrow::compute::ScalarKernel& GetArrowKernel() const {
return Kernel_;
}
const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
return ArgsValuesDescr_;
}
// Returns the parent's argument node at the given index (no bounds check is performed here).
const IComputationNode* GetArgument(ui32 index) const {
return Parent_->ArgNodes[index];
}
private:
const TApplyWrapper* Parent_;
const NUdf::TUnboxedValue Callable_;
const std::vector<arrow::ValueDescr> ArgsValuesDescr_;
arrow::compute::ScalarKernel Kernel_;
};
// TArrowNode reads the private ArgNodes member.
friend class TArrowNode;
// callableNode: node producing the callable value.
// argNodes: argument nodes; entries may be null (see DoCalculate / DoGenerateGetValue).
// usedArgs: number of leading entries of argNodes passed to the callable by DoCalculate.
// pos: source position published through the context while the callable runs.
// callableType: type of the callable, consulted when preparing the Arrow kernel.
TApplyWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* callableNode,
TComputationNodePtrVector&& argNodes, ui32 usedArgs, const NUdf::TSourcePosition& pos, TCallableType* callableType)
: TBaseComputation(mutables, kind)
, CallableNode(callableNode)
, ArgNodes(std::move(argNodes))
, UsedArgs(usedArgs)
, Position(pos)
, CallableType(callableType)
{
Stateless = false;
}
// Builds an Arrow kernel node for this apply, or returns an empty pointer when that is not possible:
// the used argument count differs from the callable's argument count, or the return type or any
// argument type is not a block type whose item type ConvertArrowType accepts.
std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
if (UsedArgs != CallableType->GetArgumentsCount()) {
return {};
}
// Only the success flag of ConvertArrowType is used; `t` is a scratch output.
std::shared_ptr<arrow::DataType> t;
if (!CallableType->GetReturnType()->IsBlock() ||
!ConvertArrowType(AS_TYPE(TBlockType, CallableType->GetReturnType())->GetItemType(), t)) {
return {};
}
TVector<TType*> argsTypes;
for (ui32 i = 0; i < CallableType->GetArgumentsCount(); ++i) {
argsTypes.push_back(CallableType->GetArgumentType(i));
if (!CallableType->GetArgumentType(i)->IsBlock() ||
!ConvertArrowType(AS_TYPE(TBlockType, CallableType->GetArgumentType(i))->GetItemType(), t)) {
return {};
}
}
// The callable is evaluated once here and stored inside the kernel node.
const auto callable = CallableNode->GetValue(ctx);
return std::make_unique<TArrowNode>(this, callable, CallableType->GetReturnType(), argsTypes);
}
// Interpreter path: evaluates the first UsedArgs argument nodes, then the callable, and runs it.
NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
NStackArray::TStackArray<NUdf::TUnboxedValue> values(ALLOC_ON_STACK(NUdf::TUnboxedValue, UsedArgs));
for (size_t i = 0; i < UsedArgs; ++i) {
// A null argument node leaves the corresponding slot default-constructed.
if (const auto valueNode = ArgNodes[i]) {
values[i] = valueNode->GetValue(ctx);
}
}
const auto callable = CallableNode->GetValue(ctx);
// Publish this call's source position for the duration of Run, then restore the previous one.
// NOTE(review): the previous position is not restored if Run throws - confirm this is intended.
const auto prev = ctx.CalleePosition;
ctx.CalleePosition = &Position;
const auto ret = callable.Run(ctx.Builder, values.data());
ctx.CalleePosition = prev;
return ret;
}
#ifndef MKQL_DISABLE_CODEGEN
// LLVM path: emits code that materializes all argument values into a local array, invokes the callable
// and stores its result through `pointer`.
void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
auto& context = ctx.Codegen.GetContext();
const auto idxType = Type::getInt32Ty(context);
const auto valType = Type::getInt128Ty(context);
// One 128-bit slot per entry of ArgNodes (not just the first UsedArgs).
const auto arrayType = ArrayType::get(valType, ArgNodes.size());
// The alloca goes into the function's entry block when stateless or always-inlined, otherwise into the current block.
const auto args = *Stateless || ctx.AlwaysInline ?
new AllocaInst(arrayType, 0U, "args", &ctx.Func->getEntryBlock().back()):
new AllocaInst(arrayType, 0U, "args", block);
ui32 i = 0;
// Slots filled from real nodes, remembered with their representation so they can be unreferenced after the call.
std::vector<std::pair<Value*, EValueRepresentation>> argsv;
argsv.reserve(ArgNodes.size());
for (const auto node : ArgNodes) {
const auto argPtr = GetElementPtrInst::CreateInBounds(arrayType, args, {ConstantInt::get(idxType, 0), ConstantInt::get(idxType, i++)}, "arg_ptr", block);
if (node) {
GetNodeValue(argPtr, node, ctx, block);
argsv.emplace_back(argPtr, node->GetRepresentation());
} else {
// Missing argument: store a zero value into the slot.
new StoreInst(ConstantInt::get(valType, 0), argPtr, block);
}
}
// If the callable node can generate its own run code, delegate to it; otherwise call Run on the boxed value.
if (const auto codegen = dynamic_cast<ICodegeneratorRunNode*>(CallableNode)) {
codegen->CreateRun(ctx, block, pointer, args);
} else {
const auto callable = GetNodeValue(CallableNode, ctx, block);
// NOTE(review): field index 6 of the computation context struct is presumably CalleePosition,
// mirroring the save/set/restore done in DoCalculate - confirm against the context layout.
const auto calleePtr = GetElementPtrInst::CreateInBounds(GetCompContextType(context), ctx.Ctx, {ConstantInt::get(idxType, 0), ConstantInt::get(idxType, 6)}, "callee_ptr", block);
const auto previous = new LoadInst(PointerType::getUnqual(GetSourcePosType(context)), calleePtr, "previous", block);
// The address of Position is embedded into the generated code as an integer constant.
const auto callee = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), ui64(&Position)), previous->getType(), "callee", block);
new StoreInst(callee, calleePtr, block);
CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, callable, ctx.Codegen, block, ctx.GetBuilder(), args);
new StoreInst(previous, calleePtr, block);
if (CallableNode->IsTemporaryValue()) {
CleanupBoxed(callable, ctx, block);
}
}
// Drop the references held by the argument slots that were filled from nodes.
for (const auto& arg : argsv) {
ValueUnRef(arg.second, arg.first, ctx, block);
}
}
#endif
private:
// Declares the callable node and every non-null argument node as dependencies.
void RegisterDependencies() const final {
DependsOn(CallableNode);
for (const auto node : ArgNodes) {
if (node) {
DependsOn(node);
}
}
}
IComputationNode *const CallableNode;
const TComputationNodePtrVector ArgNodes;
const ui32 UsedArgs;
const NUdf::TSourcePosition Position;
TCallableType* CallableType;
};
}
// Factory for the "Apply" / "Apply2" callables: validates the inputs and builds a TApplyWrapper.
//
// Input layout as read below:
//   [0]       the function to apply: a non-immediate callable node whose return type is a callable type;
//   [1]       ui32 literal with the number of trailing dependent nodes;
//   [2..4]    "Apply2" only: file, row and column of the source position;
//   then      the arguments actually passed, followed by the dependent nodes.
//
// NOTE(review): the MKQL_ENSURE conditions (and the identifiers they mention) are kept verbatim
// in case the macro embeds the expression text into the error message - confirm before renaming.
IComputationNode* WrapApply(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    const bool hasPosition = callable.GetType()->GetName() == "Apply2";
    const ui32 deltaArgs = hasPosition ? 3 : 0;
    MKQL_ENSURE(callable.GetInputsCount() >= 2 + deltaArgs, "Expected at least " << (2 + deltaArgs) << " arguments");

    const auto function = callable.GetInput(0);
    MKQL_ENSURE(!function.IsImmediate() && function.GetNode()->GetType()->IsCallable(),
                "First argument of Apply must be a callable");

    const auto returnType = static_cast<TCallable*>(function.GetNode())->GetType()->GetReturnType();
    MKQL_ENSURE(returnType->IsCallable(), "Expected callable as return type");

    // Source position of the call site; stays empty/zero for plain "Apply".
    const TStringBuf posFile = hasPosition ? AS_VALUE(TDataLiteral, callable.GetInput(2))->AsValue().AsStringRef() : NUdf::TStringRef();
    const ui32 posRow = hasPosition ? AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().Get<ui32>() : 0;
    const ui32 posColumn = hasPosition ? AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().Get<ui32>() : 0;

    // Everything after the function and the dependent-count literal (and the position inputs, if any).
    const ui32 argsCount = callable.GetInputsCount() - deltaArgs - 2;

    const ui32 dependentCount = AS_VALUE(TDataLiteral, callable.GetInput(1))->AsValue().Get<ui32>();
    MKQL_ENSURE(dependentCount <= argsCount, "Too many dependent nodes");
    const ui32 usedArgs = argsCount - dependentCount;

    auto callableType = static_cast<TCallableType*>(returnType);
    MKQL_ENSURE(usedArgs <= callableType->GetArgumentsCount(), "Too many arguments");
    MKQL_ENSURE(usedArgs >= callableType->GetArgumentsCount() - callableType->GetOptionalArgumentsCount(), "Too few arguments");

    // Slots [0, declaredArgs) hold the callable's arguments (omitted ones stay null);
    // slots [declaredArgs, declaredArgs + dependentCount) hold the dependent nodes.
    const ui32 declaredArgs = callableType->GetArgumentsCount();
    const ui32 firstArgInput = 2 + deltaArgs;
    TComputationNodePtrVector nodes(declaredArgs + dependentCount);
    for (ui32 k = 0; k < usedArgs; ++k) {
        nodes[k] = LocateNode(ctx.NodeLocator, callable, firstArgInput + k);
    }

    const ui32 firstDependentInput = firstArgInput + usedArgs;
    for (ui32 k = 0; k < dependentCount; ++k) {
        nodes[declaredArgs + k] = LocateNode(ctx.NodeLocator, callable, firstDependentInput + k);
    }

    const auto calleeNode = LocateNode(ctx.NodeLocator, callable, 0);
    const auto resultKind = GetValueRepresentation(callable.GetType()->GetReturnType());
    const NUdf::TSourcePosition position(posRow, posColumn, posFile);
    // The wrapper receives the callable's full declared argument count as its used-args value.
    return new TApplyWrapper(ctx.Mutables, resultKind, calleeNode, std::move(nodes),
                             declaredArgs, position, callableType);
}
}
}