diff options
author | vvvv <vvvv@ydb.tech> | 2022-11-21 20:34:51 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2022-11-21 20:34:51 +0300 |
commit | 27bc344a905db44fe44c7898b77dd7d5ee81a58d (patch) | |
tree | 965c28a1e1961541b0832b24bcae1623df4cf878 | |
parent | 179c88461116abaccd2fe2d0999c712a1bd2465d (diff) | |
download | ydb-27bc344a905db44fe44c7898b77dd7d5ee81a58d.tar.gz |
discovery of kernels by YQL types
19 files changed, 465 insertions, 289 deletions
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp index 01a6582454f..798e3bf02b9 100644 --- a/ydb/library/yql/minikql/arrow/mkql_functions.cpp +++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp @@ -123,41 +123,34 @@ bool ConvertOutputArrowType(const arrow::compute::OutputType& outType, const std } bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env, const IBuiltinFunctionRegistry& registry) { - auto resFunc = registry.GetArrowFunctionRegistry()->GetFunction(TString(name)); - if (!resFunc.ok()) { - return false; - } - - const auto& func = *resFunc; - if (func->kind() != arrow::compute::Function::SCALAR) { - return false; - } - - std::vector<arrow::ValueDescr> values; bool hasOptionals = false; - for (const auto& type : inputTypes) { - arrow::ValueDescr descr; + bool many = false; + std::vector<NUdf::TDataTypeId> argTypes; + for (const auto& t : inputTypes) { bool isOptional; - if (!ConvertInputArrowType(type, isOptional, descr)) { - return false; + auto asBlockType = AS_TYPE(TBlockType, t); + if (asBlockType->GetShape() == TBlockType::EShape::Many) { + many = true; } + auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional); hasOptionals = hasOptionals || isOptional; - values.push_back(descr); + argTypes.push_back(dataType->GetSchemeType()); } - auto resKernel = func->DispatchExact(values); - if (!resKernel.ok()) { + auto kernel = registry.FindKernel(name, argTypes.data(), argTypes.size()); + if (!kernel) { return false; } - const auto& kernel = static_cast<const arrow::compute::ScalarKernel*>(*resKernel); - auto notNull = (kernel->null_handling == arrow::compute::NullHandling::OUTPUT_NOT_NULL); - const auto& outType = kernel->signature->out_type(); - if (!ConvertOutputArrowType(outType, values, name.EndsWith("?") || (hasOptionals && !notNull), outputType, env)) { - return false; + outputType = 
TDataType::Create(kernel->ReturnType, env); + if (kernel->Family.NullMode != TKernelFamily::ENullMode::AlwaysNotNull) { + if (hasOptionals || kernel->Family.NullMode == TKernelFamily::ENullMode::AlwaysNull) { + outputType = TOptionalType::Create(outputType, env); + } } + outputType = TBlockType::Create(outputType, many ? TBlockType::EShape::Many : TBlockType::EShape::Scalar, env); return true; } diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp index b9bce7f6de9..81ee6bb78de 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp @@ -43,15 +43,29 @@ const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function return *static_cast<const arrow::compute::ScalarKernel*>(kernel); } +const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes) { + std::vector<NUdf::TDataTypeId> argTypes; + for (const auto& t : inputTypes) { + auto asBlockType = AS_TYPE(TBlockType, t); + bool isOptional; + auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional); + argTypes.push_back(dataType->GetSchemeType()); + } + + auto kernel = builtins.FindKernel(funcName, argTypes.data(), argTypes.size()); + MKQL_ENSURE(kernel, "Can't find kernel for " << funcName); + return *kernel; +} + struct TState : public TComputationValue<TState> { using TComputationValue::TComputationValue; - TState(TMemoryUsageInfo* memInfo, const arrow::compute::Function& function, const arrow::compute::FunctionOptions* options, + TState(TMemoryUsageInfo* memInfo, const arrow::compute::FunctionOptions* options, const arrow::compute::ScalarKernel& kernel, - const arrow::compute::FunctionRegistry& registry, const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx) + const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx) : 
TComputationValue(memInfo) , Options(options) - , ExecContext(&ctx.ArrowMemoryPool, nullptr, const_cast<arrow::compute::FunctionRegistry*>(&registry)) + , ExecContext(&ctx.ArrowMemoryPool, nullptr, nullptr) , KernelContext(&ExecContext) { if (kernel.init) {
@@ -73,7 +87,7 @@ struct TState : public TComputationValue<TState> { class TBlockFuncWrapper : public TMutableComputationNode<TBlockFuncWrapper> { public: TBlockFuncWrapper(TComputationMutables& mutables, - const arrow::compute::FunctionRegistry& functionRegistry, + const IBuiltinFunctionRegistry& builtins, const TString& funcName, TVector<IComputationNode*>&& argsNodes, TVector<TType*>&& argsTypes) @@ -83,9 +97,7 @@ public: , ArgsNodes(std::move(argsNodes)) , ArgsTypes(std::move(argsTypes)) , ArgsValuesDescr(ToValueDescr(ArgsTypes)) - , FunctionRegistry(functionRegistry) - , Function(ResolveFunction(FunctionRegistry, FuncName)) - , Kernel(ResolveKernel(Function, ArgsValuesDescr)) + , Kernel(ResolveKernel(builtins, FuncName, ArgsTypes)) { } @@ -100,7 +112,7 @@ public: auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>(); auto executor = arrow::compute::detail::KernelExecutor::MakeScalar(); - ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, state.Options }));
+ ARROW_OK(executor->Init(&state.KernelContext, { &Kernel.GetArrowKernel(), ArgsValuesDescr, state.Options }));
ARROW_OK(executor->Execute(state.Values, listener.get()));
auto output = executor->WrapResults(state.Values, listener->values()); return ctx.HolderFactory.CreateArrowBlock(std::move(output)); @@ -123,7 +135,7 @@ private: TState& GetState(TComputationContext& ctx) const { auto& result = ctx.MutableValues[StateIndex]; if (!result.HasValue()) { - result = ctx.HolderFactory.Create<TState>(Function, Function.default_options(), Kernel, FunctionRegistry, ArgsValuesDescr, ctx); + result = ctx.HolderFactory.Create<TState>(Kernel.Family.FunctionOptions, Kernel.GetArrowKernel(), ArgsValuesDescr, ctx); } return *static_cast<TState*>(result.AsBoxed().Get()); @@ -136,15 +148,12 @@ private: const TVector<TType*> ArgsTypes; const std::vector<arrow::ValueDescr> ArgsValuesDescr; - const arrow::compute::FunctionRegistry& FunctionRegistry; - const arrow::compute::Function& Function; - const arrow::compute::ScalarKernel& Kernel; + const TKernel& Kernel; }; class TBlockBitCastWrapper : public TMutableComputationNode<TBlockBitCastWrapper> { public: TBlockBitCastWrapper(TComputationMutables& mutables, - const arrow::compute::FunctionRegistry& functionRegistry, IComputationNode* arg, TType* argType, TType* to) @@ -152,8 +161,7 @@ public: , StateIndex(mutables.CurValueIndex++) , Arg(arg) , ArgsValuesDescr({ ToValueDescr(argType) }) - , FunctionRegistry(functionRegistry) - , Function(ResolveFunction(FunctionRegistry, to)) + , Function(ResolveFunction(to)) , Kernel(ResolveKernel(Function, ArgsValuesDescr)) , CastOptions(false) { @@ -179,7 +187,7 @@ private: this->DependsOn(Arg); } - static const arrow::compute::Function& ResolveFunction(const arrow::compute::FunctionRegistry& registry, TType* to) { + static const arrow::compute::Function& ResolveFunction(TType* to) { bool isOptional; std::shared_ptr<arrow::DataType> type; MKQL_ENSURE(ConvertArrowType(to, isOptional, type), "can't get arrow type"); @@ -193,7 +201,7 @@ private: TState& GetState(TComputationContext& ctx) const { auto& result = ctx.MutableValues[StateIndex]; if (!result.HasValue()) { - 
result = ctx.HolderFactory.Create<TState>(Function, (const arrow::compute::FunctionOptions*)&CastOptions, Kernel, FunctionRegistry, ArgsValuesDescr, ctx); + result = ctx.HolderFactory.Create<TState>((const arrow::compute::FunctionOptions*)&CastOptions, Kernel, ArgsValuesDescr, ctx); } return *static_cast<TState*>(result.AsBoxed().Get()); @@ -203,7 +211,6 @@ private: const ui32 StateIndex; IComputationNode* Arg; const std::vector<arrow::ValueDescr> ArgsValuesDescr; - const arrow::compute::FunctionRegistry& FunctionRegistry; const arrow::compute::Function& Function; const arrow::compute::ScalarKernel& Kernel; arrow::compute::CastOptions CastOptions; @@ -224,7 +231,7 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto } return new TBlockFuncWrapper(ctx.Mutables, - *ctx.FunctionRegistry.GetBuiltins()->GetArrowFunctionRegistry(), + *ctx.FunctionRegistry.GetBuiltins(), funcName, std::move(argsNodes), std::move(argsTypes) @@ -236,7 +243,6 @@ IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFa auto argNode = LocateNode(ctx.NodeLocator, callable, 0); MKQL_ENSURE(callable.GetInput(1).GetStaticType()->IsType(), "Expected type"); return new TBlockBitCastWrapper(ctx.Mutables, - *ctx.FunctionRegistry.GetBuiltins()->GetArrowFunctionRegistry(), argNode, callable.GetType()->GetArgumentType(0), static_cast<TType*>(callable.GetInput(1).GetNode()) diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp index e8add1b8c11..6ab9f7e66dc 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp @@ -2,6 +2,8 @@ #include "mkql_builtins_impl.h" #include "mkql_builtins_compare.h" +#include <ydb/library/yql/minikql/arrow/arrow_defs.h> + #include <util/digest/murmur.h> #include <util/generic/yexception.h> #include <util/generic/maybe.h> @@ -16,18 +18,77 @@ namespace NMiniKQL { namespace 
{ -void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, arrow::compute::FunctionRegistry& arrowRegistry) { +class TForeignKernel : public TKernel { +public: + TForeignKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType, + const std::shared_ptr<arrow::compute::Function>& function) + : TKernel(family, argTypes, returnType) + , Function(function) + , ArrowKernel(ResolveKernel(Function, argTypes)) + {} + + const arrow::compute::ScalarKernel& GetArrowKernel() const final { + return ArrowKernel; + } + +private: + static const arrow::compute::ScalarKernel& ResolveKernel(const std::shared_ptr<arrow::compute::Function>& function, + const std::vector<NUdf::TDataTypeId>& argTypes) { + std::vector<arrow::ValueDescr> args; + for (const auto& t : argTypes) { + args.emplace_back(); + auto slot = NUdf::FindDataSlot(t); + MKQL_ENSURE(slot, "Unexpected data type"); + MKQL_ENSURE(ConvertArrowType(*slot, args.back().type), "Can't get arrow type"); + } + + const auto kernel = ARROW_RESULT(function->DispatchExact(args)); + return *static_cast<const arrow::compute::ScalarKernel*>(kernel); + } + +private: + const std::shared_ptr<arrow::compute::Function> Function; + const arrow::compute::ScalarKernel& ArrowKernel; +}; + +template <typename TInput1, typename TOutput> +void RegisterUnary(const arrow::compute::FunctionRegistry& registry, std::string_view name, TKernelFamilyMap& kernelFamilyMap) { + auto func = ARROW_RESULT(registry.GetFunction(std::string(name)));
+ + std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id }); + NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id; + + auto family = std::make_unique<TKernelFamilyBase>(); + family->KernelMap.emplace(argTypes, std::make_unique<TForeignKernel>(*family, argTypes, returnType, func)); + + Y_ENSURE(kernelFamilyMap.emplace(TString(name), std::move(family)).second); +} + +template <typename TInput1, typename TInput2, typename TOutput> +void RegisterBinary(const arrow::compute::FunctionRegistry& registry, std::string_view name, TKernelFamilyMap& kernelFamilyMap) { + auto func = ARROW_RESULT(registry.GetFunction(std::string(name)));
+ + std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id, NUdf::TDataType<TInput2>::Id }); + NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id; + + auto family = std::make_unique<TKernelFamilyBase>(); + family->KernelMap.emplace(argTypes, std::make_unique<TForeignKernel>(*family, argTypes, returnType, func)); + + Y_ENSURE(kernelFamilyMap.emplace(TString(name), std::move(family)).second); +} + +void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, TKernelFamilyMap& kernelFamilyMap) { RegisterAdd(registry); - RegisterAdd(arrowRegistry); + RegisterAdd(kernelFamilyMap); RegisterAggrAdd(registry); RegisterSub(registry); - RegisterSub(arrowRegistry); + RegisterSub(kernelFamilyMap); RegisterMul(registry); - RegisterMul(arrowRegistry); + RegisterMul(kernelFamilyMap); RegisterDiv(registry); - RegisterDiv(arrowRegistry); + RegisterDiv(kernelFamilyMap); RegisterMod(registry); - RegisterMod(arrowRegistry); + RegisterMod(kernelFamilyMap); RegisterIncrement(registry); RegisterDecrement(registry); RegisterBitAnd(registry); @@ -56,17 +117,17 @@ void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, arrow::comput RegisterAggrMax(registry); RegisterAggrMin(registry); RegisterEquals(registry); - RegisterEquals(arrowRegistry); + RegisterEquals(kernelFamilyMap); RegisterNotEquals(registry); - RegisterNotEquals(arrowRegistry); + RegisterNotEquals(kernelFamilyMap); RegisterLess(registry); - RegisterLess(arrowRegistry); + RegisterLess(kernelFamilyMap); RegisterLessOrEqual(registry); - RegisterLessOrEqual(arrowRegistry); + RegisterLessOrEqual(kernelFamilyMap); RegisterGreater(registry); - RegisterGreater(arrowRegistry); + RegisterGreater(kernelFamilyMap); RegisterGreaterOrEqual(registry); - RegisterGreaterOrEqual(arrowRegistry); + RegisterGreaterOrEqual(kernelFamilyMap); } void PrintType(NUdf::TDataTypeId schemeType, bool isOptional, IOutputStream& out) @@ -153,25 +214,31 @@ private: const TFunctionsMap& GetFunctions() const final; - 
arrow::compute::FunctionRegistry* GetArrowFunctionRegistry() const final; - void CalculateMetadataEtag(); std::optional<TFunctionDescriptor> FindBuiltin(const std::string_view& name, const std::pair<NUdf::TDataTypeId, bool>* argTypes, size_t argTypesCount) const; const TDescriptionList& FindCandidates(const std::string_view& name) const; + const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const final; + + void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) final; + TFunctionsMap Functions; TFunctionParamMetadataList ArgumentsMetadata; std::optional<ui64> MetadataEtag; - std::unique_ptr<arrow::compute::FunctionRegistry> ArrowRegistry; + TKernelFamilyMap KernelFamilyMap; }; TBuiltinFunctionRegistry::TBuiltinFunctionRegistry() - : ArrowRegistry(arrow::compute::FunctionRegistry::Make()) { - RegisterDefaultOperations(*this, *ArrowRegistry); - arrow::compute::internal::RegisterScalarBoolean(ArrowRegistry.get()); + RegisterDefaultOperations(*this, KernelFamilyMap); + auto arrowRegistry = arrow::compute::FunctionRegistry::Make(); + arrow::compute::internal::RegisterScalarBoolean(arrowRegistry.get()); + RegisterUnary<bool, bool>(*arrowRegistry, "invert", KernelFamilyMap); + RegisterBinary<bool, bool, bool>(*arrowRegistry, "and_kleene", KernelFamilyMap); + RegisterBinary<bool, bool, bool>(*arrowRegistry, "or_kleene", KernelFamilyMap); + RegisterBinary<bool, bool, bool>(*arrowRegistry, "xor", KernelFamilyMap); CalculateMetadataEtag(); } @@ -296,8 +363,17 @@ void TBuiltinFunctionRegistry::PrintInfoTo(IOutputStream& out) const } } -arrow::compute::FunctionRegistry* TBuiltinFunctionRegistry::GetArrowFunctionRegistry() const { - return ArrowRegistry.get(); +const TKernel* TBuiltinFunctionRegistry::FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const { + auto fit = KernelFamilyMap.find(TString(name)); + if (fit == 
KernelFamilyMap.end()) { + return nullptr; + } + + return fit->second->FindKernel(argTypes, argTypesCount); +} + +void TBuiltinFunctionRegistry::RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) { + Y_ENSURE(KernelFamilyMap.emplace(TString(name), std::move(family)).second); } } // namespace @@ -306,10 +382,5 @@ IBuiltinFunctionRegistry::TPtr CreateBuiltinRegistry() { return MakeIntrusive<TBuiltinFunctionRegistry>(); } -void AddFunction(arrow::compute::FunctionRegistry& registry, const std::shared_ptr<arrow::compute::ScalarFunction>& f) { - ARROW_OK(registry.AddFunction(f)); -} - - } // namespace NMiniKQL } // namespace NKikimr diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp index 0cce313512f..967f16e01d6 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp @@ -192,8 +192,8 @@ void RegisterAdd(IBuiltinFunctionRegistry& registry) { NUdf::TDataType<NUdf::TInterval>, TDateTimeAdd, TBinaryArgsOptWithNullableResult>(registry, "Add"); } -void RegisterAdd(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericFunction<TAdd>>("Add")); +void RegisterAdd(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["Add"] = std::make_unique<TBinaryNumericKernelFamily<TAdd>>(); } void RegisterAggrAdd(IBuiltinFunctionRegistry& registry) { diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h index 3163168c72c..0542d1f1a16 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h @@ -470,17 +470,17 @@ void RegisterAggrCompareStrings(IBuiltinFunctionRegistry& registry, const std::s } void RegisterEquals(IBuiltinFunctionRegistry& registry); -void 
RegisterEquals(arrow::compute::FunctionRegistry& registry); +void RegisterEquals(TKernelFamilyMap& kernelFamilyMap); void RegisterNotEquals(IBuiltinFunctionRegistry& registry); -void RegisterNotEquals(arrow::compute::FunctionRegistry& registry); +void RegisterNotEquals(TKernelFamilyMap& kernelFamilyMap); void RegisterLess(IBuiltinFunctionRegistry& registry); -void RegisterLess(arrow::compute::FunctionRegistry& registry); +void RegisterLess(TKernelFamilyMap& kernelFamilyMap); void RegisterLessOrEqual(IBuiltinFunctionRegistry& registry); -void RegisterLessOrEqual(arrow::compute::FunctionRegistry& registry); +void RegisterLessOrEqual(TKernelFamilyMap& kernelFamilyMap); void RegisterGreater(IBuiltinFunctionRegistry& registry); -void RegisterGreater(arrow::compute::FunctionRegistry& registry); +void RegisterGreater(TKernelFamilyMap& kernelFamilyMap); void RegisterGreaterOrEqual(IBuiltinFunctionRegistry& registry); -void RegisterGreaterOrEqual(arrow::compute::FunctionRegistry& registry); +void RegisterGreaterOrEqual(TKernelFamilyMap& kernelFamilyMap); } } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp index 607b6f444b2..18a76503bba 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp @@ -166,8 +166,8 @@ void RegisterDiv(IBuiltinFunctionRegistry& registry) { NUdf::TDataType<NUdf::TInterval>, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div"); } -void RegisterDiv(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericFunction<TIntegralDiv>>("Div?")); +void RegisterDiv(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["Div"] = std::make_unique<TBinaryNumericKernelFamily<TIntegralDiv>>(TKernelFamily::ENullMode::AlwaysNull); } } // namespace NMiniKQL diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp 
b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp index 00857259d97..878bc03d5af 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp @@ -284,8 +284,8 @@ void RegisterEquals(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrEquals, TCompareArgsOpt>(registry, aggrName); } -void RegisterEquals(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TEqualsOp>>("Equals")); +void RegisterEquals(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["Equals"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TEqualsOp>>(); } } // namespace NMiniKQL diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp index 15f035a7bdb..fd7eabe41fb 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp @@ -281,8 +281,8 @@ void RegisterGreater(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrGreater, TCompareArgsOpt>(registry, aggrName); } -void RegisterGreater(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TGreaterOp>>("Greater")); +void RegisterGreater(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["Greater"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TGreaterOp>>(); } } // namespace NMiniKQL diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp index 604f9955507..e402584917c 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp +++ 
b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp @@ -281,8 +281,8 @@ void RegisterGreaterOrEqual(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrGreaterOrEqual, TCompareArgsOpt>(registry, aggrName); } -void RegisterGreaterOrEqual(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TGreaterOrEqualOp>>("GreaterOrEqual")); +void RegisterGreaterOrEqual(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["GreaterOrEqual"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TGreaterOrEqualOp>>(); } } // namespace NMiniKQL diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h index d2fb0844966..ff8255c9cc3 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h @@ -768,16 +768,16 @@ void RegisterBinaryRealFunction(IBuiltinFunctionRegistry& registry, const std::s } void RegisterAdd(IBuiltinFunctionRegistry& registry); -void RegisterAdd(arrow::compute::FunctionRegistry& registry); +void RegisterAdd(TKernelFamilyMap& kernelFamilyMap); void RegisterAggrAdd(IBuiltinFunctionRegistry& registry); void RegisterSub(IBuiltinFunctionRegistry& registry); -void RegisterSub(arrow::compute::FunctionRegistry& registry); +void RegisterSub(TKernelFamilyMap& kernelFamilyMap); void RegisterMul(IBuiltinFunctionRegistry& registry); -void RegisterMul(arrow::compute::FunctionRegistry& registry); +void RegisterMul(TKernelFamilyMap& kernelFamilyMap); void RegisterDiv(IBuiltinFunctionRegistry& registry); -void RegisterDiv(arrow::compute::FunctionRegistry& registry); +void RegisterDiv(TKernelFamilyMap& kernelFamilyMap); void RegisterMod(IBuiltinFunctionRegistry& registry); -void RegisterMod(arrow::compute::FunctionRegistry& registry); +void RegisterMod(TKernelFamilyMap& 
kernelFamilyMap); void RegisterIncrement(IBuiltinFunctionRegistry& registry); void RegisterDecrement(IBuiltinFunctionRegistry& registry); void RegisterBitAnd(IBuiltinFunctionRegistry& registry); @@ -806,8 +806,6 @@ void RegisterAggrMax(IBuiltinFunctionRegistry& registry); void RegisterAggrMin(IBuiltinFunctionRegistry& registry); void RegisterWith(IBuiltinFunctionRegistry& registry); -void AddFunction(arrow::compute::FunctionRegistry& registry, const std::shared_ptr<arrow::compute::ScalarFunction>& f); - inline arrow::internal::Bitmap GetBitmap(const arrow::ArrayData& arr, int index) { return arrow::internal::Bitmap{ arr.buffers[index], arr.offset, arr.length }; } @@ -861,8 +859,8 @@ inline std::shared_ptr<arrow::DataType> GetPrimitiveDataType<ui64>() { } template <typename T> -arrow::compute::InputType GetPrimitiveInputArrowType(bool isScalar) { - return arrow::compute::InputType(GetPrimitiveDataType<T>(), isScalar ? arrow::ValueDescr::SCALAR : arrow::ValueDescr::ARRAY); +arrow::compute::InputType GetPrimitiveInputArrowType() { + return arrow::compute::InputType(GetPrimitiveDataType<T>(), arrow::ValueDescr::ANY); } template <typename T> @@ -1008,9 +1006,31 @@ template<typename TInput1, typename TInput2, typename TOutput, template<typename, typename, typename> class TFunc, bool DefaultNulls> struct TBinaryKernelExecs; +template<typename TDerived> +struct TBinaryKernelExecsBase { + static arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) { + MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); + const auto& arg1 = batch.values[0]; + const auto& arg2 = batch.values[1]; + if (arg1.is_scalar()) { + if (arg2.is_scalar()) { + return TDerived::ExecScalarScalar(ctx, batch, res); + } else { + return TDerived::ExecScalarArray(ctx, batch, res); + } + } else { + if (arg2.is_scalar()) { + return TDerived::ExecArrayScalar(ctx, batch, res); + } else { + return TDerived::ExecArrayArray(ctx, batch, res); + } + } + 
} +}; + template<typename TInput1, typename TInput2, typename TOutput, template<typename, typename, typename> class TFunc> -struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true> +struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true> : TBinaryKernelExecsBase<TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true>> { using TFuncInstance = TFunc<TInput1, TInput2, TOutput>; @@ -1107,7 +1127,7 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true> template<typename TInput1, typename TInput2, typename TOutput, template<typename, typename, typename> class TFunc> -struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false> +struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false> : TBinaryKernelExecsBase<TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false>> { using TFuncInstance = TFunc<TInput1, TInput2, TOutput>; @@ -1237,195 +1257,200 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false> } }; +class TPlainKernel : public TKernel { +public: + TPlainKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType, const arrow::compute::ScalarKernel& arrowKernel) + : TKernel(family, argTypes, returnType) + , ArrowKernel(arrowKernel) + { + } + + const arrow::compute::ScalarKernel& GetArrowKernel() const final { + return ArrowKernel; + } + +private: + const arrow::compute::ScalarKernel ArrowKernel; +}; + template<typename TInput1, typename TInput2, typename TOutput, template<typename, typename, typename> class TFunc> -void AddBinaryKernel(arrow::compute::ScalarFunction& function) { +void AddBinaryKernel(TKernelFamilyBase& owner) { using TFuncInstance = TFunc<TInput1, TInput2, TOutput>; using TExecs = TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, TFuncInstance::DefaultNulls>; - auto nullHandling = TFuncInstance::DefaultNulls ? 
arrow::compute::NullHandling::INTERSECTION : arrow::compute::NullHandling::COMPUTED_PREALLOCATE; - - arrow::compute::ScalarKernel ss({GetPrimitiveInputArrowType<TInput1>(true), GetPrimitiveInputArrowType<TInput2>(true) }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::ExecScalarScalar); - ss.null_handling = nullHandling; - ARROW_OK(function.AddKernel(ss)); - - arrow::compute::ScalarKernel sa({ GetPrimitiveInputArrowType<TInput1>(true), GetPrimitiveInputArrowType<TInput2>(false) }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::ExecScalarArray); - sa.null_handling = nullHandling; - ARROW_OK(function.AddKernel(sa)); - arrow::compute::ScalarKernel as({ GetPrimitiveInputArrowType<TInput1>(false), GetPrimitiveInputArrowType<TInput2>(true) }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::ExecArrayScalar); - as.null_handling = nullHandling; - ARROW_OK(function.AddKernel(as)); + std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id, NUdf::TDataType<TInput2>::Id }); + NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id; - arrow::compute::ScalarKernel aa({ GetPrimitiveInputArrowType<TInput1>(false), GetPrimitiveInputArrowType<TInput2>(false) }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::ExecArrayArray); - aa.null_handling = nullHandling; - ARROW_OK(function.AddKernel(aa)); + arrow::compute::ScalarKernel k({ GetPrimitiveInputArrowType<TInput1>(), GetPrimitiveInputArrowType<TInput2>() }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::Exec); + k.null_handling = owner.NullMode == TKernelFamily::ENullMode::Default ? 
arrow::compute::NullHandling::INTERSECTION : arrow::compute::NullHandling::COMPUTED_PREALLOCATE; + owner.KernelMap.emplace(argTypes, std::make_unique<TPlainKernel>(owner, argTypes, returnType, k)); } template<template<typename, typename, typename> class TFunc> -void AddBinaryIntegralKernels(arrow::compute::ScalarFunction& function) { - AddBinaryKernel<ui8, ui8, ui8, TFunc>(function); - AddBinaryKernel<ui8, i8, i8, TFunc>(function); - AddBinaryKernel<ui8, ui16, ui16, TFunc>(function); - AddBinaryKernel<ui8, i16, i16, TFunc>(function); - AddBinaryKernel<ui8, ui32, ui32, TFunc>(function); - AddBinaryKernel<ui8, i32, i32, TFunc>(function); - AddBinaryKernel<ui8, ui64, ui64, TFunc>(function); - AddBinaryKernel<ui8, i64, i64, TFunc>(function); - - AddBinaryKernel<i8, ui8, i8, TFunc>(function); - AddBinaryKernel<i8, i8, i8, TFunc>(function); - AddBinaryKernel<i8, ui16, ui16, TFunc>(function); - AddBinaryKernel<i8, i16, i16, TFunc>(function); - AddBinaryKernel<i8, ui32, ui32, TFunc>(function); - AddBinaryKernel<i8, i32, i32, TFunc>(function); - AddBinaryKernel<i8, ui64, ui64, TFunc>(function); - AddBinaryKernel<i8, i64, i64, TFunc>(function); - - AddBinaryKernel<ui16, ui8, ui16, TFunc>(function); - AddBinaryKernel<ui16, i8, ui16, TFunc>(function); - AddBinaryKernel<ui16, ui16, ui16, TFunc>(function); - AddBinaryKernel<ui16, i16, i16, TFunc>(function); - AddBinaryKernel<ui16, ui32, ui32, TFunc>(function); - AddBinaryKernel<ui16, i32, i32, TFunc>(function); - AddBinaryKernel<ui16, ui64, ui64, TFunc>(function); - AddBinaryKernel<ui16, i64, i64, TFunc>(function); - - AddBinaryKernel<i16, ui8, i16, TFunc>(function); - AddBinaryKernel<i16, i8, i16, TFunc>(function); - AddBinaryKernel<i16, ui16, i16, TFunc>(function); - AddBinaryKernel<i16, i16, i16, TFunc>(function); - AddBinaryKernel<i16, ui32, ui32, TFunc>(function); - AddBinaryKernel<i16, i32, i32, TFunc>(function); - AddBinaryKernel<i16, ui64, ui64, TFunc>(function); - AddBinaryKernel<i16, i64, i64, TFunc>(function); - - 
AddBinaryKernel<ui32, ui8, ui32, TFunc>(function); - AddBinaryKernel<ui32, i8, ui32, TFunc>(function); - AddBinaryKernel<ui32, ui16, ui32, TFunc>(function); - AddBinaryKernel<ui32, i16, ui32, TFunc>(function); - AddBinaryKernel<ui32, ui32, ui32, TFunc>(function); - AddBinaryKernel<ui32, i32, i32, TFunc>(function); - AddBinaryKernel<ui32, ui64, ui64, TFunc>(function); - AddBinaryKernel<ui32, i64, i64, TFunc>(function); - - AddBinaryKernel<i32, ui8, i32, TFunc>(function); - AddBinaryKernel<i32, i8, i32, TFunc>(function); - AddBinaryKernel<i32, ui16, i32, TFunc>(function); - AddBinaryKernel<i32, i16, i32, TFunc>(function); - AddBinaryKernel<i32, ui32, i32, TFunc>(function); - AddBinaryKernel<i32, i32, i32, TFunc>(function); - AddBinaryKernel<i32, ui64, ui64, TFunc>(function); - AddBinaryKernel<i32, i64, i64, TFunc>(function); - - AddBinaryKernel<ui64, ui8, ui64, TFunc>(function); - AddBinaryKernel<ui64, i8, ui64, TFunc>(function); - AddBinaryKernel<ui64, ui16, ui64, TFunc>(function); - AddBinaryKernel<ui64, i16, ui64, TFunc>(function); - AddBinaryKernel<ui64, ui32, ui64, TFunc>(function); - AddBinaryKernel<ui64, i32, ui64, TFunc>(function); - AddBinaryKernel<ui64, ui64, ui64, TFunc>(function); - AddBinaryKernel<ui64, i64, i64, TFunc>(function); - - AddBinaryKernel<i64, ui8, i64, TFunc>(function); - AddBinaryKernel<i64, i8, i64, TFunc>(function); - AddBinaryKernel<i64, ui16, i64, TFunc>(function); - AddBinaryKernel<i64, i16, i64, TFunc>(function); - AddBinaryKernel<i64, ui32, i64, TFunc>(function); - AddBinaryKernel<i64, i32, i64, TFunc>(function); - AddBinaryKernel<i64, ui64, i64, TFunc>(function); - AddBinaryKernel<i64, i64, i64, TFunc>(function); +void AddBinaryIntegralKernels(TKernelFamilyBase& owner) { + AddBinaryKernel<ui8, ui8, ui8, TFunc>(owner); + AddBinaryKernel<ui8, i8, i8, TFunc>(owner); + AddBinaryKernel<ui8, ui16, ui16, TFunc>(owner); + AddBinaryKernel<ui8, i16, i16, TFunc>(owner); + AddBinaryKernel<ui8, ui32, ui32, TFunc>(owner); + AddBinaryKernel<ui8, 
i32, i32, TFunc>(owner); + AddBinaryKernel<ui8, ui64, ui64, TFunc>(owner); + AddBinaryKernel<ui8, i64, i64, TFunc>(owner); + + AddBinaryKernel<i8, ui8, i8, TFunc>(owner); + AddBinaryKernel<i8, i8, i8, TFunc>(owner); + AddBinaryKernel<i8, ui16, ui16, TFunc>(owner); + AddBinaryKernel<i8, i16, i16, TFunc>(owner); + AddBinaryKernel<i8, ui32, ui32, TFunc>(owner); + AddBinaryKernel<i8, i32, i32, TFunc>(owner); + AddBinaryKernel<i8, ui64, ui64, TFunc>(owner); + AddBinaryKernel<i8, i64, i64, TFunc>(owner); + + AddBinaryKernel<ui16, ui8, ui16, TFunc>(owner); + AddBinaryKernel<ui16, i8, ui16, TFunc>(owner); + AddBinaryKernel<ui16, ui16, ui16, TFunc>(owner); + AddBinaryKernel<ui16, i16, i16, TFunc>(owner); + AddBinaryKernel<ui16, ui32, ui32, TFunc>(owner); + AddBinaryKernel<ui16, i32, i32, TFunc>(owner); + AddBinaryKernel<ui16, ui64, ui64, TFunc>(owner); + AddBinaryKernel<ui16, i64, i64, TFunc>(owner); + + AddBinaryKernel<i16, ui8, i16, TFunc>(owner); + AddBinaryKernel<i16, i8, i16, TFunc>(owner); + AddBinaryKernel<i16, ui16, i16, TFunc>(owner); + AddBinaryKernel<i16, i16, i16, TFunc>(owner); + AddBinaryKernel<i16, ui32, ui32, TFunc>(owner); + AddBinaryKernel<i16, i32, i32, TFunc>(owner); + AddBinaryKernel<i16, ui64, ui64, TFunc>(owner); + AddBinaryKernel<i16, i64, i64, TFunc>(owner); + + AddBinaryKernel<ui32, ui8, ui32, TFunc>(owner); + AddBinaryKernel<ui32, i8, ui32, TFunc>(owner); + AddBinaryKernel<ui32, ui16, ui32, TFunc>(owner); + AddBinaryKernel<ui32, i16, ui32, TFunc>(owner); + AddBinaryKernel<ui32, ui32, ui32, TFunc>(owner); + AddBinaryKernel<ui32, i32, i32, TFunc>(owner); + AddBinaryKernel<ui32, ui64, ui64, TFunc>(owner); + AddBinaryKernel<ui32, i64, i64, TFunc>(owner); + + AddBinaryKernel<i32, ui8, i32, TFunc>(owner); + AddBinaryKernel<i32, i8, i32, TFunc>(owner); + AddBinaryKernel<i32, ui16, i32, TFunc>(owner); + AddBinaryKernel<i32, i16, i32, TFunc>(owner); + AddBinaryKernel<i32, ui32, i32, TFunc>(owner); + AddBinaryKernel<i32, i32, i32, TFunc>(owner); + 
AddBinaryKernel<i32, ui64, ui64, TFunc>(owner); + AddBinaryKernel<i32, i64, i64, TFunc>(owner); + + AddBinaryKernel<ui64, ui8, ui64, TFunc>(owner); + AddBinaryKernel<ui64, i8, ui64, TFunc>(owner); + AddBinaryKernel<ui64, ui16, ui64, TFunc>(owner); + AddBinaryKernel<ui64, i16, ui64, TFunc>(owner); + AddBinaryKernel<ui64, ui32, ui64, TFunc>(owner); + AddBinaryKernel<ui64, i32, ui64, TFunc>(owner); + AddBinaryKernel<ui64, ui64, ui64, TFunc>(owner); + AddBinaryKernel<ui64, i64, i64, TFunc>(owner); + + AddBinaryKernel<i64, ui8, i64, TFunc>(owner); + AddBinaryKernel<i64, i8, i64, TFunc>(owner); + AddBinaryKernel<i64, ui16, i64, TFunc>(owner); + AddBinaryKernel<i64, i16, i64, TFunc>(owner); + AddBinaryKernel<i64, ui32, i64, TFunc>(owner); + AddBinaryKernel<i64, i32, i64, TFunc>(owner); + AddBinaryKernel<i64, ui64, i64, TFunc>(owner); + AddBinaryKernel<i64, i64, i64, TFunc>(owner); } template<template<typename, typename, typename> class TFunc> -class TBinaryNumericFunction : public arrow::compute::ScalarFunction { +class TBinaryNumericKernelFamily : public TKernelFamilyBase { public: - TBinaryNumericFunction(const std::string& name) - : ScalarFunction(name, arrow::compute::Arity::Binary(), nullptr) + TBinaryNumericKernelFamily(TKernelFamily::ENullMode nullMode = TKernelFamily::ENullMode::Default) + : TKernelFamilyBase(nullMode) { AddBinaryIntegralKernels<TFunc>(*this); } }; template<template<typename, typename, typename> class TPred> -void AddBinaryIntegralPredicateKernels(arrow::compute::ScalarFunction& function) { - AddBinaryKernel<ui8, ui8, bool, TPred>(function); - AddBinaryKernel<ui8, i8, bool, TPred>(function); - AddBinaryKernel<ui8, ui16, bool, TPred>(function); - AddBinaryKernel<ui8, i16, bool, TPred>(function); - AddBinaryKernel<ui8, ui32, bool, TPred>(function); - AddBinaryKernel<ui8, i32, bool, TPred>(function); - AddBinaryKernel<ui8, ui64, bool, TPred>(function); - AddBinaryKernel<ui8, i64, bool, TPred>(function); - - AddBinaryKernel<i8, ui8, bool, 
TPred>(function); - AddBinaryKernel<i8, i8, bool, TPred>(function); - AddBinaryKernel<i8, ui16, bool, TPred>(function); - AddBinaryKernel<i8, i16, bool, TPred>(function); - AddBinaryKernel<i8, ui32, bool, TPred>(function); - AddBinaryKernel<i8, i32, bool, TPred>(function); - AddBinaryKernel<i8, ui64, bool, TPred>(function); - AddBinaryKernel<i8, i64, bool, TPred>(function); - - AddBinaryKernel<ui16, ui8, bool, TPred>(function); - AddBinaryKernel<ui16, i8, bool, TPred>(function); - AddBinaryKernel<ui16, ui16, bool, TPred>(function); - AddBinaryKernel<ui16, i16, bool, TPred>(function); - AddBinaryKernel<ui16, ui32, bool, TPred>(function); - AddBinaryKernel<ui16, i32, bool, TPred>(function); - AddBinaryKernel<ui16, ui64, bool, TPred>(function); - AddBinaryKernel<ui16, i64, bool, TPred>(function); - - AddBinaryKernel<i16, ui8, bool, TPred>(function); - AddBinaryKernel<i16, i8, bool, TPred>(function); - AddBinaryKernel<i16, ui16, bool, TPred>(function); - AddBinaryKernel<i16, i16, bool, TPred>(function); - AddBinaryKernel<i16, ui32, bool, TPred>(function); - AddBinaryKernel<i16, i32, bool, TPred>(function); - AddBinaryKernel<i16, ui64, bool, TPred>(function); - AddBinaryKernel<i16, i64, bool, TPred>(function); - - AddBinaryKernel<ui32, ui8, bool, TPred>(function); - AddBinaryKernel<ui32, i8, bool, TPred>(function); - AddBinaryKernel<ui32, ui16, bool, TPred>(function); - AddBinaryKernel<ui32, i16, bool, TPred>(function); - AddBinaryKernel<ui32, ui32, bool, TPred>(function); - AddBinaryKernel<ui32, i32, bool, TPred>(function); - AddBinaryKernel<ui32, ui64, bool, TPred>(function); - AddBinaryKernel<ui32, i64, bool, TPred>(function); - - AddBinaryKernel<i32, ui8, bool, TPred>(function); - AddBinaryKernel<i32, i8, bool, TPred>(function); - AddBinaryKernel<i32, ui16, bool, TPred>(function); - AddBinaryKernel<i32, i16, bool, TPred>(function); - AddBinaryKernel<i32, ui32, bool, TPred>(function); - AddBinaryKernel<i32, i32, bool, TPred>(function); - AddBinaryKernel<i32, ui64, 
bool, TPred>(function); - AddBinaryKernel<i32, i64, bool, TPred>(function); - - AddBinaryKernel<ui64, ui8, bool, TPred>(function); - AddBinaryKernel<ui64, i8, bool, TPred>(function); - AddBinaryKernel<ui64, ui16, bool, TPred>(function); - AddBinaryKernel<ui64, i16, bool, TPred>(function); - AddBinaryKernel<ui64, ui32, bool, TPred>(function); - AddBinaryKernel<ui64, i32, bool, TPred>(function); - AddBinaryKernel<ui64, ui64, bool, TPred>(function); - AddBinaryKernel<ui64, i64, bool, TPred>(function); - - AddBinaryKernel<i64, ui8, bool, TPred>(function); - AddBinaryKernel<i64, i8, bool, TPred>(function); - AddBinaryKernel<i64, ui16, bool, TPred>(function); - AddBinaryKernel<i64, i16, bool, TPred>(function); - AddBinaryKernel<i64, ui32, bool, TPred>(function); - AddBinaryKernel<i64, i32, bool, TPred>(function); - AddBinaryKernel<i64, ui64, bool, TPred>(function); - AddBinaryKernel<i64, i64, bool, TPred>(function); +void AddBinaryIntegralPredicateKernels(TKernelFamilyBase& owner) { + AddBinaryKernel<ui8, ui8, bool, TPred>(owner); + AddBinaryKernel<ui8, i8, bool, TPred>(owner); + AddBinaryKernel<ui8, ui16, bool, TPred>(owner); + AddBinaryKernel<ui8, i16, bool, TPred>(owner); + AddBinaryKernel<ui8, ui32, bool, TPred>(owner); + AddBinaryKernel<ui8, i32, bool, TPred>(owner); + AddBinaryKernel<ui8, ui64, bool, TPred>(owner); + AddBinaryKernel<ui8, i64, bool, TPred>(owner); + + AddBinaryKernel<i8, ui8, bool, TPred>(owner); + AddBinaryKernel<i8, i8, bool, TPred>(owner); + AddBinaryKernel<i8, ui16, bool, TPred>(owner); + AddBinaryKernel<i8, i16, bool, TPred>(owner); + AddBinaryKernel<i8, ui32, bool, TPred>(owner); + AddBinaryKernel<i8, i32, bool, TPred>(owner); + AddBinaryKernel<i8, ui64, bool, TPred>(owner); + AddBinaryKernel<i8, i64, bool, TPred>(owner); + + AddBinaryKernel<ui16, ui8, bool, TPred>(owner); + AddBinaryKernel<ui16, i8, bool, TPred>(owner); + AddBinaryKernel<ui16, ui16, bool, TPred>(owner); + AddBinaryKernel<ui16, i16, bool, TPred>(owner); + AddBinaryKernel<ui16, 
ui32, bool, TPred>(owner); + AddBinaryKernel<ui16, i32, bool, TPred>(owner); + AddBinaryKernel<ui16, ui64, bool, TPred>(owner); + AddBinaryKernel<ui16, i64, bool, TPred>(owner); + + AddBinaryKernel<i16, ui8, bool, TPred>(owner); + AddBinaryKernel<i16, i8, bool, TPred>(owner); + AddBinaryKernel<i16, ui16, bool, TPred>(owner); + AddBinaryKernel<i16, i16, bool, TPred>(owner); + AddBinaryKernel<i16, ui32, bool, TPred>(owner); + AddBinaryKernel<i16, i32, bool, TPred>(owner); + AddBinaryKernel<i16, ui64, bool, TPred>(owner); + AddBinaryKernel<i16, i64, bool, TPred>(owner); + + AddBinaryKernel<ui32, ui8, bool, TPred>(owner); + AddBinaryKernel<ui32, i8, bool, TPred>(owner); + AddBinaryKernel<ui32, ui16, bool, TPred>(owner); + AddBinaryKernel<ui32, i16, bool, TPred>(owner); + AddBinaryKernel<ui32, ui32, bool, TPred>(owner); + AddBinaryKernel<ui32, i32, bool, TPred>(owner); + AddBinaryKernel<ui32, ui64, bool, TPred>(owner); + AddBinaryKernel<ui32, i64, bool, TPred>(owner); + + AddBinaryKernel<i32, ui8, bool, TPred>(owner); + AddBinaryKernel<i32, i8, bool, TPred>(owner); + AddBinaryKernel<i32, ui16, bool, TPred>(owner); + AddBinaryKernel<i32, i16, bool, TPred>(owner); + AddBinaryKernel<i32, ui32, bool, TPred>(owner); + AddBinaryKernel<i32, i32, bool, TPred>(owner); + AddBinaryKernel<i32, ui64, bool, TPred>(owner); + AddBinaryKernel<i32, i64, bool, TPred>(owner); + + AddBinaryKernel<ui64, ui8, bool, TPred>(owner); + AddBinaryKernel<ui64, i8, bool, TPred>(owner); + AddBinaryKernel<ui64, ui16, bool, TPred>(owner); + AddBinaryKernel<ui64, i16, bool, TPred>(owner); + AddBinaryKernel<ui64, ui32, bool, TPred>(owner); + AddBinaryKernel<ui64, i32, bool, TPred>(owner); + AddBinaryKernel<ui64, ui64, bool, TPred>(owner); + AddBinaryKernel<ui64, i64, bool, TPred>(owner); + + AddBinaryKernel<i64, ui8, bool, TPred>(owner); + AddBinaryKernel<i64, i8, bool, TPred>(owner); + AddBinaryKernel<i64, ui16, bool, TPred>(owner); + AddBinaryKernel<i64, i16, bool, TPred>(owner); + AddBinaryKernel<i64, 
ui32, bool, TPred>(owner); + AddBinaryKernel<i64, i32, bool, TPred>(owner); + AddBinaryKernel<i64, ui64, bool, TPred>(owner); + AddBinaryKernel<i64, i64, bool, TPred>(owner); } template<template<typename, typename, typename> class TPred> -class TBinaryNumericPredicate : public arrow::compute::ScalarFunction { +class TBinaryNumericPredicateKernelFamily : public TKernelFamilyBase { public: - TBinaryNumericPredicate(const std::string& name) - : ScalarFunction(name, arrow::compute::Arity::Binary(), nullptr) + TBinaryNumericPredicateKernelFamily() { AddBinaryIntegralPredicateKernels<TPred>(*this); } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp index 01e1c8e7bad..f332867d06e 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp @@ -281,8 +281,8 @@ void RegisterLess(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrLess, TCompareArgsOpt>(registry, aggrName); } -void RegisterLess(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TLessOp>>("Less")); +void RegisterLess(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["Less"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TLessOp>>(); } } // namespace NMiniKQL diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp index 73e883534db..4dd4b8b5d7b 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp @@ -281,8 +281,8 @@ void RegisterLessOrEqual(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrLessOrEqual, TCompareArgsOpt>(registry, 
aggrName); } -void RegisterLessOrEqual(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TLessOrEqualOp>>("LessOrEqual")); +void RegisterLessOrEqual(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["LessOrEqual"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TLessOrEqualOp>>(); } } // namespace NMiniKQL diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mod.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mod.cpp index d199fb8c089..57a54defa1a 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mod.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mod.cpp @@ -90,8 +90,8 @@ void RegisterMod(IBuiltinFunctionRegistry& registry) { RegisterBinaryIntegralFunctionOpt<TIntegralMod, TBinaryArgsOptWithNullableResult>(registry, "Mod"); } -void RegisterMod(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericFunction<TIntegralMod>>("Mod?")); +void RegisterMod(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["Mod"] = std::make_unique<TBinaryNumericKernelFamily<TIntegralMod>>(TKernelFamily::ENullMode::AlwaysNull); } } // namespace NMiniKQL diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mul.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mul.cpp index 1dc0202deb2..0b0ab631a6a 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mul.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mul.cpp @@ -97,8 +97,8 @@ void RegisterMul(IBuiltinFunctionRegistry& registry) { NUdf::TDataType<NUdf::TInterval>, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul"); } -void RegisterMul(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericFunction<TMul>>("Mul")); +void RegisterMul(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["Mul"] = std::make_unique<TBinaryNumericKernelFamily<TMul>>(); 
} } // namespace NMiniKQL diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp index 4fb26abf1ee..a5f5eca21da 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp @@ -284,8 +284,8 @@ void RegisterNotEquals(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrNotEquals, TCompareArgsOpt>(registry, aggrName); } -void RegisterNotEquals(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TNotEqualsOp>>("NotEquals")); +void RegisterNotEquals(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["NotEquals"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TNotEqualsOp>>(); } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_sub.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_sub.cpp index 06cd093241b..bfd8e0782bf 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_sub.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_sub.cpp @@ -289,8 +289,8 @@ void RegisterSub(IBuiltinFunctionRegistry& registry) { NUdf::TDataType<NUdf::TTzTimestamp>, TAnyDateTimeSubIntervalTz, TBinaryArgsOptWithNullableResult>(registry, "Sub"); } -void RegisterSub(arrow::compute::FunctionRegistry& registry) { - AddFunction(registry, std::make_shared<TBinaryNumericFunction<TSub>>("Sub")); +void RegisterSub(TKernelFamilyMap& kernelFamilyMap) { + kernelFamilyMap["Sub"] = std::make_unique<TBinaryNumericKernelFamily<TSub>>(); } } // namespace NMiniKQL diff --git a/ydb/library/yql/minikql/mkql_function_metadata.h b/ydb/library/yql/minikql/mkql_function_metadata.h index a880b0558b5..3042922133f 100644 --- a/ydb/library/yql/minikql/mkql_function_metadata.h +++ b/ydb/library/yql/minikql/mkql_function_metadata.h @@ -1,10 +1,9 @@ 
#pragma once #include <ydb/library/yql/public/udf/udf_value.h> +#include <util/digest/numeric.h> -namespace arrow::compute { -class FunctionRegistry; -}; +#include <arrow/compute/kernel.h> namespace NKikimr { @@ -51,6 +50,81 @@ using TArgType = std::pair<NUdf::TDataTypeId, bool>; // type with optional flag using TDescriptionList = std::vector<TFunctionDescriptor>; using TFunctionsMap = std::unordered_map<TString, TDescriptionList>; +class TKernel; + +class TKernelFamily { +public: + enum ENullMode { + Default, + AlwaysNull, + AlwaysNotNull + }; + + const ENullMode NullMode; + const arrow::compute::FunctionOptions* FunctionOptions; + + TKernelFamily(ENullMode nullMode = ENullMode::Default, const arrow::compute::FunctionOptions* functionOptions = nullptr) + : NullMode(nullMode) + , FunctionOptions(functionOptions) + {} + + virtual ~TKernelFamily() = default; + virtual const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const = 0; +}; + +class TKernel { +public: + const TKernelFamily& Family; + const std::vector<NUdf::TDataTypeId> ArgTypes; + const NUdf::TDataTypeId ReturnType; + + TKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType) + : Family(family) + , ArgTypes(argTypes) + , ReturnType(returnType) + { + } + + virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0; + + virtual ~TKernel() = default; +}; + +struct TTypeHasher { + std::size_t operator()(const std::vector<NUdf::TDataTypeId>& s) const noexcept { + size_t r = 0; + for (const auto& x : s) { + r = CombineHashes<size_t>(r, x); + } + + return r; + } +}; + +using TKernelMap = std::unordered_map<std::vector<NUdf::TDataTypeId>, std::unique_ptr<TKernel>, TTypeHasher>; + +using TKernelFamilyMap = std::unordered_map<TString, std::unique_ptr<TKernelFamily>>; + +class TKernelFamilyBase : public TKernelFamily +{ +public: + TKernelFamilyBase(ENullMode nullMode = ENullMode::Default, const 
arrow::compute::FunctionOptions* functionOptions = nullptr) + : TKernelFamily(nullMode, functionOptions) + {} + + const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const final { + std::vector<NUdf::TDataTypeId> key(argTypes, argTypes + argTypesCount); + auto it = KernelMap.find(key); + if (it == KernelMap.end()) { + return nullptr; + } + + return it->second.get(); + } + + TKernelMap KernelMap; +}; + class IBuiltinFunctionRegistry: public TThrRefBase, private TNonCopyable { public: @@ -70,7 +144,9 @@ public: virtual TFunctionDescriptor GetBuiltin(const std::string_view& name, const std::pair<NUdf::TDataTypeId, bool>* argTypes, size_t argTypesCount) const = 0; - virtual arrow::compute::FunctionRegistry* GetArrowFunctionRegistry() const = 0; + virtual const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const = 0; + + virtual void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) = 0; }; } diff --git a/ydb/library/yql/minikql/mkql_type_builder.cpp b/ydb/library/yql/minikql/mkql_type_builder.cpp index 04ec05bec41..5c4b28f36b1 100644 --- a/ydb/library/yql/minikql/mkql_type_builder.cpp +++ b/ydb/library/yql/minikql/mkql_type_builder.cpp @@ -1313,18 +1313,8 @@ private: namespace NMiniKQL { -bool ConvertArrowType(TType* itemType, bool& isOptional, std::shared_ptr<arrow::DataType>& type) { - auto unpacked = UnpackOptional(itemType, isOptional); - if (!unpacked->IsData()) { - return false; - } - - auto slot = AS_TYPE(TDataType, unpacked)->GetDataSlot(); - if (!slot) { - return false; - } - - switch (*slot) { +bool ConvertArrowType(NUdf::EDataSlot slot, std::shared_ptr<arrow::DataType>& type) { + switch (slot) { case NUdf::EDataSlot::Bool: type = arrow::boolean(); return true; @@ -1361,6 +1351,20 @@ bool ConvertArrowType(TType* itemType, bool& isOptional, std::shared_ptr<arrow:: } } +bool ConvertArrowType(TType* itemType, bool& isOptional, 
std::shared_ptr<arrow::DataType>& type) { + auto unpacked = UnpackOptional(itemType, isOptional); + if (!unpacked->IsData()) { + return false; + } + + auto slot = AS_TYPE(TDataType, unpacked)->GetDataSlot(); + if (!slot) { + return false; + } + + return ConvertArrowType(*slot, type); +} + void TArrowType::Export(ArrowSchema* out) const { auto status = arrow::ExportType(*Type, out); if (!status.ok()) { diff --git a/ydb/library/yql/minikql/mkql_type_builder.h b/ydb/library/yql/minikql/mkql_type_builder.h index 9cc4ea126ef..7e5e95bcba9 100644 --- a/ydb/library/yql/minikql/mkql_type_builder.h +++ b/ydb/library/yql/minikql/mkql_type_builder.h @@ -11,6 +11,7 @@ namespace NKikimr { namespace NMiniKQL { bool ConvertArrowType(TType* itemType, bool& isOptional, std::shared_ptr<arrow::DataType>& type); +bool ConvertArrowType(NUdf::EDataSlot slot, std::shared_ptr<arrow::DataType>& type); class TArrowType : public NUdf::IArrowType { public: |