aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvvvv <vvvv@ydb.tech>2022-11-21 20:34:51 +0300
committervvvv <vvvv@ydb.tech>2022-11-21 20:34:51 +0300
commit27bc344a905db44fe44c7898b77dd7d5ee81a58d (patch)
tree965c28a1e1961541b0832b24bcae1623df4cf878
parent179c88461116abaccd2fe2d0999c712a1bd2465d (diff)
downloadydb-27bc344a905db44fe44c7898b77dd7d5ee81a58d.tar.gz
discovery of kernels by YQL types
-rw-r--r--ydb/library/yql/minikql/arrow/mkql_functions.cpp39
-rw-r--r--ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp46
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp121
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h12
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h379
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mod.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mul.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp4
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_sub.cpp4
-rw-r--r--ydb/library/yql/minikql/mkql_function_metadata.h84
-rw-r--r--ydb/library/yql/minikql/mkql_type_builder.cpp28
-rw-r--r--ydb/library/yql/minikql/mkql_type_builder.h1
19 files changed, 465 insertions, 289 deletions
diff --git a/ydb/library/yql/minikql/arrow/mkql_functions.cpp b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
index 01a6582454f..798e3bf02b9 100644
--- a/ydb/library/yql/minikql/arrow/mkql_functions.cpp
+++ b/ydb/library/yql/minikql/arrow/mkql_functions.cpp
@@ -123,41 +123,34 @@ bool ConvertOutputArrowType(const arrow::compute::OutputType& outType, const std
}
bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TType*& outputType, TTypeEnvironment& env, const IBuiltinFunctionRegistry& registry) {
- auto resFunc = registry.GetArrowFunctionRegistry()->GetFunction(TString(name));
- if (!resFunc.ok()) {
- return false;
- }
-
- const auto& func = *resFunc;
- if (func->kind() != arrow::compute::Function::SCALAR) {
- return false;
- }
-
- std::vector<arrow::ValueDescr> values;
bool hasOptionals = false;
- for (const auto& type : inputTypes) {
- arrow::ValueDescr descr;
+ bool many = false;
+ std::vector<NUdf::TDataTypeId> argTypes;
+ for (const auto& t : inputTypes) {
bool isOptional;
- if (!ConvertInputArrowType(type, isOptional, descr)) {
- return false;
+ auto asBlockType = AS_TYPE(TBlockType, t);
+ if (asBlockType->GetShape() == TBlockType::EShape::Many) {
+ many = true;
}
+ auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional);
hasOptionals = hasOptionals || isOptional;
- values.push_back(descr);
+ argTypes.push_back(dataType->GetSchemeType());
}
- auto resKernel = func->DispatchExact(values);
- if (!resKernel.ok()) {
+ auto kernel = registry.FindKernel(name, argTypes.data(), argTypes.size());
+ if (!kernel) {
return false;
}
- const auto& kernel = static_cast<const arrow::compute::ScalarKernel*>(*resKernel);
- auto notNull = (kernel->null_handling == arrow::compute::NullHandling::OUTPUT_NOT_NULL);
- const auto& outType = kernel->signature->out_type();
- if (!ConvertOutputArrowType(outType, values, name.EndsWith("?") || (hasOptionals && !notNull), outputType, env)) {
- return false;
+ outputType = TDataType::Create(kernel->ReturnType, env);
+ if (kernel->Family.NullMode != TKernelFamily::ENullMode::AlwaysNotNull) {
+ if (hasOptionals || kernel->Family.NullMode == TKernelFamily::ENullMode::AlwaysNull) {
+ outputType = TOptionalType::Create(outputType, env);
+ }
}
+ outputType = TBlockType::Create(outputType, many ? TBlockType::EShape::Many : TBlockType::EShape::Scalar, env);
return true;
}
diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
index b9bce7f6de9..81ee6bb78de 100644
--- a/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/mkql_block_func.cpp
@@ -43,15 +43,29 @@ const arrow::compute::ScalarKernel& ResolveKernel(const arrow::compute::Function
return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
}
+const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes) {
+ std::vector<NUdf::TDataTypeId> argTypes;
+ for (const auto& t : inputTypes) {
+ auto asBlockType = AS_TYPE(TBlockType, t);
+ bool isOptional;
+ auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional);
+ argTypes.push_back(dataType->GetSchemeType());
+ }
+
+ auto kernel = builtins.FindKernel(funcName, argTypes.data(), argTypes.size());
+ MKQL_ENSURE(kernel, "Can't find kernel for " << funcName);
+ return *kernel;
+}
+
struct TState : public TComputationValue<TState> {
using TComputationValue::TComputationValue;
- TState(TMemoryUsageInfo* memInfo, const arrow::compute::Function& function, const arrow::compute::FunctionOptions* options,
+ TState(TMemoryUsageInfo* memInfo, const arrow::compute::FunctionOptions* options,
const arrow::compute::ScalarKernel& kernel,
- const arrow::compute::FunctionRegistry& registry, const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx)
+ const std::vector<arrow::ValueDescr>& argsValuesDescr, TComputationContext& ctx)
: TComputationValue(memInfo)
, Options(options)
- , ExecContext(&ctx.ArrowMemoryPool, nullptr, const_cast<arrow::compute::FunctionRegistry*>(&registry))
+ , ExecContext(&ctx.ArrowMemoryPool, nullptr, nullptr)
, KernelContext(&ExecContext)
{
if (kernel.init) {
@@ -73,7 +87,7 @@ struct TState : public TComputationValue<TState> {
class TBlockFuncWrapper : public TMutableComputationNode<TBlockFuncWrapper> {
public:
TBlockFuncWrapper(TComputationMutables& mutables,
- const arrow::compute::FunctionRegistry& functionRegistry,
+ const IBuiltinFunctionRegistry& builtins,
const TString& funcName,
TVector<IComputationNode*>&& argsNodes,
TVector<TType*>&& argsTypes)
@@ -83,9 +97,7 @@ public:
, ArgsNodes(std::move(argsNodes))
, ArgsTypes(std::move(argsTypes))
, ArgsValuesDescr(ToValueDescr(ArgsTypes))
- , FunctionRegistry(functionRegistry)
- , Function(ResolveFunction(FunctionRegistry, FuncName))
- , Kernel(ResolveKernel(Function, ArgsValuesDescr))
+ , Kernel(ResolveKernel(builtins, FuncName, ArgsTypes))
{
}
@@ -100,7 +112,7 @@ public:
auto listener = std::make_shared<arrow::compute::detail::DatumAccumulator>();
auto executor = arrow::compute::detail::KernelExecutor::MakeScalar();
- ARROW_OK(executor->Init(&state.KernelContext, { &Kernel, ArgsValuesDescr, state.Options }));
+ ARROW_OK(executor->Init(&state.KernelContext, { &Kernel.GetArrowKernel(), ArgsValuesDescr, state.Options }));
ARROW_OK(executor->Execute(state.Values, listener.get()));
auto output = executor->WrapResults(state.Values, listener->values());
return ctx.HolderFactory.CreateArrowBlock(std::move(output));
@@ -123,7 +135,7 @@ private:
TState& GetState(TComputationContext& ctx) const {
auto& result = ctx.MutableValues[StateIndex];
if (!result.HasValue()) {
- result = ctx.HolderFactory.Create<TState>(Function, Function.default_options(), Kernel, FunctionRegistry, ArgsValuesDescr, ctx);
+ result = ctx.HolderFactory.Create<TState>(Kernel.Family.FunctionOptions, Kernel.GetArrowKernel(), ArgsValuesDescr, ctx);
}
return *static_cast<TState*>(result.AsBoxed().Get());
@@ -136,15 +148,12 @@ private:
const TVector<TType*> ArgsTypes;
const std::vector<arrow::ValueDescr> ArgsValuesDescr;
- const arrow::compute::FunctionRegistry& FunctionRegistry;
- const arrow::compute::Function& Function;
- const arrow::compute::ScalarKernel& Kernel;
+ const TKernel& Kernel;
};
class TBlockBitCastWrapper : public TMutableComputationNode<TBlockBitCastWrapper> {
public:
TBlockBitCastWrapper(TComputationMutables& mutables,
- const arrow::compute::FunctionRegistry& functionRegistry,
IComputationNode* arg,
TType* argType,
TType* to)
@@ -152,8 +161,7 @@ public:
, StateIndex(mutables.CurValueIndex++)
, Arg(arg)
, ArgsValuesDescr({ ToValueDescr(argType) })
- , FunctionRegistry(functionRegistry)
- , Function(ResolveFunction(FunctionRegistry, to))
+ , Function(ResolveFunction(to))
, Kernel(ResolveKernel(Function, ArgsValuesDescr))
, CastOptions(false)
{
@@ -179,7 +187,7 @@ private:
this->DependsOn(Arg);
}
- static const arrow::compute::Function& ResolveFunction(const arrow::compute::FunctionRegistry& registry, TType* to) {
+ static const arrow::compute::Function& ResolveFunction(TType* to) {
bool isOptional;
std::shared_ptr<arrow::DataType> type;
MKQL_ENSURE(ConvertArrowType(to, isOptional, type), "can't get arrow type");
@@ -193,7 +201,7 @@ private:
TState& GetState(TComputationContext& ctx) const {
auto& result = ctx.MutableValues[StateIndex];
if (!result.HasValue()) {
- result = ctx.HolderFactory.Create<TState>(Function, (const arrow::compute::FunctionOptions*)&CastOptions, Kernel, FunctionRegistry, ArgsValuesDescr, ctx);
+ result = ctx.HolderFactory.Create<TState>((const arrow::compute::FunctionOptions*)&CastOptions, Kernel, ArgsValuesDescr, ctx);
}
return *static_cast<TState*>(result.AsBoxed().Get());
@@ -203,7 +211,6 @@ private:
const ui32 StateIndex;
IComputationNode* Arg;
const std::vector<arrow::ValueDescr> ArgsValuesDescr;
- const arrow::compute::FunctionRegistry& FunctionRegistry;
const arrow::compute::Function& Function;
const arrow::compute::ScalarKernel& Kernel;
arrow::compute::CastOptions CastOptions;
@@ -224,7 +231,7 @@ IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFacto
}
return new TBlockFuncWrapper(ctx.Mutables,
- *ctx.FunctionRegistry.GetBuiltins()->GetArrowFunctionRegistry(),
+ *ctx.FunctionRegistry.GetBuiltins(),
funcName,
std::move(argsNodes),
std::move(argsTypes)
@@ -236,7 +243,6 @@ IComputationNode* WrapBlockBitCast(TCallable& callable, const TComputationNodeFa
auto argNode = LocateNode(ctx.NodeLocator, callable, 0);
MKQL_ENSURE(callable.GetInput(1).GetStaticType()->IsType(), "Expected type");
return new TBlockBitCastWrapper(ctx.Mutables,
- *ctx.FunctionRegistry.GetBuiltins()->GetArrowFunctionRegistry(),
argNode,
callable.GetType()->GetArgumentType(0),
static_cast<TType*>(callable.GetInput(1).GetNode())
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp
index e8add1b8c11..6ab9f7e66dc 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp
@@ -2,6 +2,8 @@
#include "mkql_builtins_impl.h"
#include "mkql_builtins_compare.h"
+#include <ydb/library/yql/minikql/arrow/arrow_defs.h>
+
#include <util/digest/murmur.h>
#include <util/generic/yexception.h>
#include <util/generic/maybe.h>
@@ -16,18 +18,77 @@ namespace NMiniKQL {
namespace {
-void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, arrow::compute::FunctionRegistry& arrowRegistry) {
+class TForeignKernel : public TKernel {
+public:
+ TForeignKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType,
+ const std::shared_ptr<arrow::compute::Function>& function)
+ : TKernel(family, argTypes, returnType)
+ , Function(function)
+ , ArrowKernel(ResolveKernel(Function, argTypes))
+ {}
+
+ const arrow::compute::ScalarKernel& GetArrowKernel() const final {
+ return ArrowKernel;
+ }
+
+private:
+ static const arrow::compute::ScalarKernel& ResolveKernel(const std::shared_ptr<arrow::compute::Function>& function,
+ const std::vector<NUdf::TDataTypeId>& argTypes) {
+ std::vector<arrow::ValueDescr> args;
+ for (const auto& t : argTypes) {
+ args.emplace_back();
+ auto slot = NUdf::FindDataSlot(t);
+ MKQL_ENSURE(slot, "Unexpected data type");
+ MKQL_ENSURE(ConvertArrowType(*slot, args.back().type), "Can't get arrow type");
+ }
+
+ const auto kernel = ARROW_RESULT(function->DispatchExact(args));
+ return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
+ }
+
+private:
+ const std::shared_ptr<arrow::compute::Function> Function;
+ const arrow::compute::ScalarKernel& ArrowKernel;
+};
+
+template <typename TInput1, typename TOutput>
+void RegisterUnary(const arrow::compute::FunctionRegistry& registry, std::string_view name, TKernelFamilyMap& kernelFamilyMap) {
+ auto func = ARROW_RESULT(registry.GetFunction(std::string(name)));
+
+ std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id });
+ NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id;
+
+ auto family = std::make_unique<TKernelFamilyBase>();
+ family->KernelMap.emplace(argTypes, std::make_unique<TForeignKernel>(*family, argTypes, returnType, func));
+
+ Y_ENSURE(kernelFamilyMap.emplace(TString(name), std::move(family)).second);
+}
+
+template <typename TInput1, typename TInput2, typename TOutput>
+void RegisterBinary(const arrow::compute::FunctionRegistry& registry, std::string_view name, TKernelFamilyMap& kernelFamilyMap) {
+ auto func = ARROW_RESULT(registry.GetFunction(std::string(name)));
+
+ std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id, NUdf::TDataType<TInput2>::Id });
+ NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id;
+
+ auto family = std::make_unique<TKernelFamilyBase>();
+ family->KernelMap.emplace(argTypes, std::make_unique<TForeignKernel>(*family, argTypes, returnType, func));
+
+ Y_ENSURE(kernelFamilyMap.emplace(TString(name), std::move(family)).second);
+}
+
+void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, TKernelFamilyMap& kernelFamilyMap) {
RegisterAdd(registry);
- RegisterAdd(arrowRegistry);
+ RegisterAdd(kernelFamilyMap);
RegisterAggrAdd(registry);
RegisterSub(registry);
- RegisterSub(arrowRegistry);
+ RegisterSub(kernelFamilyMap);
RegisterMul(registry);
- RegisterMul(arrowRegistry);
+ RegisterMul(kernelFamilyMap);
RegisterDiv(registry);
- RegisterDiv(arrowRegistry);
+ RegisterDiv(kernelFamilyMap);
RegisterMod(registry);
- RegisterMod(arrowRegistry);
+ RegisterMod(kernelFamilyMap);
RegisterIncrement(registry);
RegisterDecrement(registry);
RegisterBitAnd(registry);
@@ -56,17 +117,17 @@ void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, arrow::comput
RegisterAggrMax(registry);
RegisterAggrMin(registry);
RegisterEquals(registry);
- RegisterEquals(arrowRegistry);
+ RegisterEquals(kernelFamilyMap);
RegisterNotEquals(registry);
- RegisterNotEquals(arrowRegistry);
+ RegisterNotEquals(kernelFamilyMap);
RegisterLess(registry);
- RegisterLess(arrowRegistry);
+ RegisterLess(kernelFamilyMap);
RegisterLessOrEqual(registry);
- RegisterLessOrEqual(arrowRegistry);
+ RegisterLessOrEqual(kernelFamilyMap);
RegisterGreater(registry);
- RegisterGreater(arrowRegistry);
+ RegisterGreater(kernelFamilyMap);
RegisterGreaterOrEqual(registry);
- RegisterGreaterOrEqual(arrowRegistry);
+ RegisterGreaterOrEqual(kernelFamilyMap);
}
void PrintType(NUdf::TDataTypeId schemeType, bool isOptional, IOutputStream& out)
@@ -153,25 +214,31 @@ private:
const TFunctionsMap& GetFunctions() const final;
- arrow::compute::FunctionRegistry* GetArrowFunctionRegistry() const final;
-
void CalculateMetadataEtag();
std::optional<TFunctionDescriptor> FindBuiltin(const std::string_view& name, const std::pair<NUdf::TDataTypeId, bool>* argTypes, size_t argTypesCount) const;
const TDescriptionList& FindCandidates(const std::string_view& name) const;
+ const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const final;
+
+ void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) final;
+
TFunctionsMap Functions;
TFunctionParamMetadataList ArgumentsMetadata;
std::optional<ui64> MetadataEtag;
- std::unique_ptr<arrow::compute::FunctionRegistry> ArrowRegistry;
+ TKernelFamilyMap KernelFamilyMap;
};
TBuiltinFunctionRegistry::TBuiltinFunctionRegistry()
- : ArrowRegistry(arrow::compute::FunctionRegistry::Make())
{
- RegisterDefaultOperations(*this, *ArrowRegistry);
- arrow::compute::internal::RegisterScalarBoolean(ArrowRegistry.get());
+ RegisterDefaultOperations(*this, KernelFamilyMap);
+ auto arrowRegistry = arrow::compute::FunctionRegistry::Make();
+ arrow::compute::internal::RegisterScalarBoolean(arrowRegistry.get());
+ RegisterUnary<bool, bool>(*arrowRegistry, "invert", KernelFamilyMap);
+ RegisterBinary<bool, bool, bool>(*arrowRegistry, "and_kleene", KernelFamilyMap);
+ RegisterBinary<bool, bool, bool>(*arrowRegistry, "or_kleene", KernelFamilyMap);
+ RegisterBinary<bool, bool, bool>(*arrowRegistry, "xor", KernelFamilyMap);
CalculateMetadataEtag();
}
@@ -296,8 +363,17 @@ void TBuiltinFunctionRegistry::PrintInfoTo(IOutputStream& out) const
}
}
-arrow::compute::FunctionRegistry* TBuiltinFunctionRegistry::GetArrowFunctionRegistry() const {
- return ArrowRegistry.get();
+const TKernel* TBuiltinFunctionRegistry::FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const {
+ auto fit = KernelFamilyMap.find(TString(name));
+ if (fit == KernelFamilyMap.end()) {
+ return nullptr;
+ }
+
+ return fit->second->FindKernel(argTypes, argTypesCount);
+}
+
+void TBuiltinFunctionRegistry::RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) {
+ Y_ENSURE(KernelFamilyMap.emplace(TString(name), std::move(family)).second);
}
} // namespace
@@ -306,10 +382,5 @@ IBuiltinFunctionRegistry::TPtr CreateBuiltinRegistry() {
return MakeIntrusive<TBuiltinFunctionRegistry>();
}
-void AddFunction(arrow::compute::FunctionRegistry& registry, const std::shared_ptr<arrow::compute::ScalarFunction>& f) {
- ARROW_OK(registry.AddFunction(f));
-}
-
-
} // namespace NMiniKQL
} // namespace NKikimr
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp
index 0cce313512f..967f16e01d6 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp
@@ -192,8 +192,8 @@ void RegisterAdd(IBuiltinFunctionRegistry& registry) {
NUdf::TDataType<NUdf::TInterval>, TDateTimeAdd, TBinaryArgsOptWithNullableResult>(registry, "Add");
}
-void RegisterAdd(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericFunction<TAdd>>("Add"));
+void RegisterAdd(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["Add"] = std::make_unique<TBinaryNumericKernelFamily<TAdd>>();
}
void RegisterAggrAdd(IBuiltinFunctionRegistry& registry) {
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h
index 3163168c72c..0542d1f1a16 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h
@@ -470,17 +470,17 @@ void RegisterAggrCompareStrings(IBuiltinFunctionRegistry& registry, const std::s
}
void RegisterEquals(IBuiltinFunctionRegistry& registry);
-void RegisterEquals(arrow::compute::FunctionRegistry& registry);
+void RegisterEquals(TKernelFamilyMap& kernelFamilyMap);
void RegisterNotEquals(IBuiltinFunctionRegistry& registry);
-void RegisterNotEquals(arrow::compute::FunctionRegistry& registry);
+void RegisterNotEquals(TKernelFamilyMap& kernelFamilyMap);
void RegisterLess(IBuiltinFunctionRegistry& registry);
-void RegisterLess(arrow::compute::FunctionRegistry& registry);
+void RegisterLess(TKernelFamilyMap& kernelFamilyMap);
void RegisterLessOrEqual(IBuiltinFunctionRegistry& registry);
-void RegisterLessOrEqual(arrow::compute::FunctionRegistry& registry);
+void RegisterLessOrEqual(TKernelFamilyMap& kernelFamilyMap);
void RegisterGreater(IBuiltinFunctionRegistry& registry);
-void RegisterGreater(arrow::compute::FunctionRegistry& registry);
+void RegisterGreater(TKernelFamilyMap& kernelFamilyMap);
void RegisterGreaterOrEqual(IBuiltinFunctionRegistry& registry);
-void RegisterGreaterOrEqual(arrow::compute::FunctionRegistry& registry);
+void RegisterGreaterOrEqual(TKernelFamilyMap& kernelFamilyMap);
}
}
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp
index 607b6f444b2..18a76503bba 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp
@@ -166,8 +166,8 @@ void RegisterDiv(IBuiltinFunctionRegistry& registry) {
NUdf::TDataType<NUdf::TInterval>, TNumDivInterval, TBinaryArgsOptWithNullableResult>(registry, "Div");
}
-void RegisterDiv(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericFunction<TIntegralDiv>>("Div?"));
+void RegisterDiv(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["Div"] = std::make_unique<TBinaryNumericKernelFamily<TIntegralDiv>>(TKernelFamily::ENullMode::AlwaysNull);
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp
index 00857259d97..878bc03d5af 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp
@@ -284,8 +284,8 @@ void RegisterEquals(IBuiltinFunctionRegistry& registry) {
RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrEquals, TCompareArgsOpt>(registry, aggrName);
}
-void RegisterEquals(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TEqualsOp>>("Equals"));
+void RegisterEquals(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["Equals"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TEqualsOp>>();
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp
index 15f035a7bdb..fd7eabe41fb 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp
@@ -281,8 +281,8 @@ void RegisterGreater(IBuiltinFunctionRegistry& registry) {
RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrGreater, TCompareArgsOpt>(registry, aggrName);
}
-void RegisterGreater(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TGreaterOp>>("Greater"));
+void RegisterGreater(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["Greater"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TGreaterOp>>();
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp
index 604f9955507..e402584917c 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp
@@ -281,8 +281,8 @@ void RegisterGreaterOrEqual(IBuiltinFunctionRegistry& registry) {
RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrGreaterOrEqual, TCompareArgsOpt>(registry, aggrName);
}
-void RegisterGreaterOrEqual(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TGreaterOrEqualOp>>("GreaterOrEqual"));
+void RegisterGreaterOrEqual(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["GreaterOrEqual"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TGreaterOrEqualOp>>();
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h
index d2fb0844966..ff8255c9cc3 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h
@@ -768,16 +768,16 @@ void RegisterBinaryRealFunction(IBuiltinFunctionRegistry& registry, const std::s
}
void RegisterAdd(IBuiltinFunctionRegistry& registry);
-void RegisterAdd(arrow::compute::FunctionRegistry& registry);
+void RegisterAdd(TKernelFamilyMap& kernelFamilyMap);
void RegisterAggrAdd(IBuiltinFunctionRegistry& registry);
void RegisterSub(IBuiltinFunctionRegistry& registry);
-void RegisterSub(arrow::compute::FunctionRegistry& registry);
+void RegisterSub(TKernelFamilyMap& kernelFamilyMap);
void RegisterMul(IBuiltinFunctionRegistry& registry);
-void RegisterMul(arrow::compute::FunctionRegistry& registry);
+void RegisterMul(TKernelFamilyMap& kernelFamilyMap);
void RegisterDiv(IBuiltinFunctionRegistry& registry);
-void RegisterDiv(arrow::compute::FunctionRegistry& registry);
+void RegisterDiv(TKernelFamilyMap& kernelFamilyMap);
void RegisterMod(IBuiltinFunctionRegistry& registry);
-void RegisterMod(arrow::compute::FunctionRegistry& registry);
+void RegisterMod(TKernelFamilyMap& kernelFamilyMap);
void RegisterIncrement(IBuiltinFunctionRegistry& registry);
void RegisterDecrement(IBuiltinFunctionRegistry& registry);
void RegisterBitAnd(IBuiltinFunctionRegistry& registry);
@@ -806,8 +806,6 @@ void RegisterAggrMax(IBuiltinFunctionRegistry& registry);
void RegisterAggrMin(IBuiltinFunctionRegistry& registry);
void RegisterWith(IBuiltinFunctionRegistry& registry);
-void AddFunction(arrow::compute::FunctionRegistry& registry, const std::shared_ptr<arrow::compute::ScalarFunction>& f);
-
inline arrow::internal::Bitmap GetBitmap(const arrow::ArrayData& arr, int index) {
return arrow::internal::Bitmap{ arr.buffers[index], arr.offset, arr.length };
}
@@ -861,8 +859,8 @@ inline std::shared_ptr<arrow::DataType> GetPrimitiveDataType<ui64>() {
}
template <typename T>
-arrow::compute::InputType GetPrimitiveInputArrowType(bool isScalar) {
- return arrow::compute::InputType(GetPrimitiveDataType<T>(), isScalar ? arrow::ValueDescr::SCALAR : arrow::ValueDescr::ARRAY);
+arrow::compute::InputType GetPrimitiveInputArrowType() {
+ return arrow::compute::InputType(GetPrimitiveDataType<T>(), arrow::ValueDescr::ANY);
}
template <typename T>
@@ -1008,9 +1006,31 @@ template<typename TInput1, typename TInput2, typename TOutput,
template<typename, typename, typename> class TFunc, bool DefaultNulls>
struct TBinaryKernelExecs;
+template<typename TDerived>
+struct TBinaryKernelExecsBase {
+ static arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
+ MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args");
+ const auto& arg1 = batch.values[0];
+ const auto& arg2 = batch.values[1];
+ if (arg1.is_scalar()) {
+ if (arg2.is_scalar()) {
+ return TDerived::ExecScalarScalar(ctx, batch, res);
+ } else {
+ return TDerived::ExecScalarArray(ctx, batch, res);
+ }
+ } else {
+ if (arg2.is_scalar()) {
+ return TDerived::ExecArrayScalar(ctx, batch, res);
+ } else {
+ return TDerived::ExecArrayArray(ctx, batch, res);
+ }
+ }
+ }
+};
+
template<typename TInput1, typename TInput2, typename TOutput,
template<typename, typename, typename> class TFunc>
-struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true>
+struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true> : TBinaryKernelExecsBase<TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true>>
{
using TFuncInstance = TFunc<TInput1, TInput2, TOutput>;
@@ -1107,7 +1127,7 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true>
template<typename TInput1, typename TInput2, typename TOutput,
template<typename, typename, typename> class TFunc>
-struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false>
+struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false> : TBinaryKernelExecsBase<TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false>>
{
using TFuncInstance = TFunc<TInput1, TInput2, TOutput>;
@@ -1237,195 +1257,200 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false>
}
};
+class TPlainKernel : public TKernel {
+public:
+ TPlainKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType, const arrow::compute::ScalarKernel& arrowKernel)
+ : TKernel(family, argTypes, returnType)
+ , ArrowKernel(arrowKernel)
+ {
+ }
+
+ const arrow::compute::ScalarKernel& GetArrowKernel() const final {
+ return ArrowKernel;
+ }
+
+private:
+ const arrow::compute::ScalarKernel ArrowKernel;
+};
+
template<typename TInput1, typename TInput2, typename TOutput,
template<typename, typename, typename> class TFunc>
-void AddBinaryKernel(arrow::compute::ScalarFunction& function) {
+void AddBinaryKernel(TKernelFamilyBase& owner) {
using TFuncInstance = TFunc<TInput1, TInput2, TOutput>;
using TExecs = TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, TFuncInstance::DefaultNulls>;
- auto nullHandling = TFuncInstance::DefaultNulls ? arrow::compute::NullHandling::INTERSECTION : arrow::compute::NullHandling::COMPUTED_PREALLOCATE;
-
- arrow::compute::ScalarKernel ss({GetPrimitiveInputArrowType<TInput1>(true), GetPrimitiveInputArrowType<TInput2>(true) }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::ExecScalarScalar);
- ss.null_handling = nullHandling;
- ARROW_OK(function.AddKernel(ss));
-
- arrow::compute::ScalarKernel sa({ GetPrimitiveInputArrowType<TInput1>(true), GetPrimitiveInputArrowType<TInput2>(false) }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::ExecScalarArray);
- sa.null_handling = nullHandling;
- ARROW_OK(function.AddKernel(sa));
- arrow::compute::ScalarKernel as({ GetPrimitiveInputArrowType<TInput1>(false), GetPrimitiveInputArrowType<TInput2>(true) }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::ExecArrayScalar);
- as.null_handling = nullHandling;
- ARROW_OK(function.AddKernel(as));
+ std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id, NUdf::TDataType<TInput2>::Id });
+ NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id;
- arrow::compute::ScalarKernel aa({ GetPrimitiveInputArrowType<TInput1>(false), GetPrimitiveInputArrowType<TInput2>(false) }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::ExecArrayArray);
- aa.null_handling = nullHandling;
- ARROW_OK(function.AddKernel(aa));
+ arrow::compute::ScalarKernel k({ GetPrimitiveInputArrowType<TInput1>(), GetPrimitiveInputArrowType<TInput2>() }, GetPrimitiveOutputArrowType<TOutput>(), &TExecs::Exec);
+ k.null_handling = owner.NullMode == TKernelFamily::ENullMode::Default ? arrow::compute::NullHandling::INTERSECTION : arrow::compute::NullHandling::COMPUTED_PREALLOCATE;
+ owner.KernelMap.emplace(argTypes, std::make_unique<TPlainKernel>(owner, argTypes, returnType, k));
}
template<template<typename, typename, typename> class TFunc>
-void AddBinaryIntegralKernels(arrow::compute::ScalarFunction& function) {
- AddBinaryKernel<ui8, ui8, ui8, TFunc>(function);
- AddBinaryKernel<ui8, i8, i8, TFunc>(function);
- AddBinaryKernel<ui8, ui16, ui16, TFunc>(function);
- AddBinaryKernel<ui8, i16, i16, TFunc>(function);
- AddBinaryKernel<ui8, ui32, ui32, TFunc>(function);
- AddBinaryKernel<ui8, i32, i32, TFunc>(function);
- AddBinaryKernel<ui8, ui64, ui64, TFunc>(function);
- AddBinaryKernel<ui8, i64, i64, TFunc>(function);
-
- AddBinaryKernel<i8, ui8, i8, TFunc>(function);
- AddBinaryKernel<i8, i8, i8, TFunc>(function);
- AddBinaryKernel<i8, ui16, ui16, TFunc>(function);
- AddBinaryKernel<i8, i16, i16, TFunc>(function);
- AddBinaryKernel<i8, ui32, ui32, TFunc>(function);
- AddBinaryKernel<i8, i32, i32, TFunc>(function);
- AddBinaryKernel<i8, ui64, ui64, TFunc>(function);
- AddBinaryKernel<i8, i64, i64, TFunc>(function);
-
- AddBinaryKernel<ui16, ui8, ui16, TFunc>(function);
- AddBinaryKernel<ui16, i8, ui16, TFunc>(function);
- AddBinaryKernel<ui16, ui16, ui16, TFunc>(function);
- AddBinaryKernel<ui16, i16, i16, TFunc>(function);
- AddBinaryKernel<ui16, ui32, ui32, TFunc>(function);
- AddBinaryKernel<ui16, i32, i32, TFunc>(function);
- AddBinaryKernel<ui16, ui64, ui64, TFunc>(function);
- AddBinaryKernel<ui16, i64, i64, TFunc>(function);
-
- AddBinaryKernel<i16, ui8, i16, TFunc>(function);
- AddBinaryKernel<i16, i8, i16, TFunc>(function);
- AddBinaryKernel<i16, ui16, i16, TFunc>(function);
- AddBinaryKernel<i16, i16, i16, TFunc>(function);
- AddBinaryKernel<i16, ui32, ui32, TFunc>(function);
- AddBinaryKernel<i16, i32, i32, TFunc>(function);
- AddBinaryKernel<i16, ui64, ui64, TFunc>(function);
- AddBinaryKernel<i16, i64, i64, TFunc>(function);
-
- AddBinaryKernel<ui32, ui8, ui32, TFunc>(function);
- AddBinaryKernel<ui32, i8, ui32, TFunc>(function);
- AddBinaryKernel<ui32, ui16, ui32, TFunc>(function);
- AddBinaryKernel<ui32, i16, ui32, TFunc>(function);
- AddBinaryKernel<ui32, ui32, ui32, TFunc>(function);
- AddBinaryKernel<ui32, i32, i32, TFunc>(function);
- AddBinaryKernel<ui32, ui64, ui64, TFunc>(function);
- AddBinaryKernel<ui32, i64, i64, TFunc>(function);
-
- AddBinaryKernel<i32, ui8, i32, TFunc>(function);
- AddBinaryKernel<i32, i8, i32, TFunc>(function);
- AddBinaryKernel<i32, ui16, i32, TFunc>(function);
- AddBinaryKernel<i32, i16, i32, TFunc>(function);
- AddBinaryKernel<i32, ui32, i32, TFunc>(function);
- AddBinaryKernel<i32, i32, i32, TFunc>(function);
- AddBinaryKernel<i32, ui64, ui64, TFunc>(function);
- AddBinaryKernel<i32, i64, i64, TFunc>(function);
-
- AddBinaryKernel<ui64, ui8, ui64, TFunc>(function);
- AddBinaryKernel<ui64, i8, ui64, TFunc>(function);
- AddBinaryKernel<ui64, ui16, ui64, TFunc>(function);
- AddBinaryKernel<ui64, i16, ui64, TFunc>(function);
- AddBinaryKernel<ui64, ui32, ui64, TFunc>(function);
- AddBinaryKernel<ui64, i32, ui64, TFunc>(function);
- AddBinaryKernel<ui64, ui64, ui64, TFunc>(function);
- AddBinaryKernel<ui64, i64, i64, TFunc>(function);
-
- AddBinaryKernel<i64, ui8, i64, TFunc>(function);
- AddBinaryKernel<i64, i8, i64, TFunc>(function);
- AddBinaryKernel<i64, ui16, i64, TFunc>(function);
- AddBinaryKernel<i64, i16, i64, TFunc>(function);
- AddBinaryKernel<i64, ui32, i64, TFunc>(function);
- AddBinaryKernel<i64, i32, i64, TFunc>(function);
- AddBinaryKernel<i64, ui64, i64, TFunc>(function);
- AddBinaryKernel<i64, i64, i64, TFunc>(function);
+void AddBinaryIntegralKernels(TKernelFamilyBase& owner) {
+ AddBinaryKernel<ui8, ui8, ui8, TFunc>(owner);
+ AddBinaryKernel<ui8, i8, i8, TFunc>(owner);
+ AddBinaryKernel<ui8, ui16, ui16, TFunc>(owner);
+ AddBinaryKernel<ui8, i16, i16, TFunc>(owner);
+ AddBinaryKernel<ui8, ui32, ui32, TFunc>(owner);
+ AddBinaryKernel<ui8, i32, i32, TFunc>(owner);
+ AddBinaryKernel<ui8, ui64, ui64, TFunc>(owner);
+ AddBinaryKernel<ui8, i64, i64, TFunc>(owner);
+
+ AddBinaryKernel<i8, ui8, i8, TFunc>(owner);
+ AddBinaryKernel<i8, i8, i8, TFunc>(owner);
+ AddBinaryKernel<i8, ui16, ui16, TFunc>(owner);
+ AddBinaryKernel<i8, i16, i16, TFunc>(owner);
+ AddBinaryKernel<i8, ui32, ui32, TFunc>(owner);
+ AddBinaryKernel<i8, i32, i32, TFunc>(owner);
+ AddBinaryKernel<i8, ui64, ui64, TFunc>(owner);
+ AddBinaryKernel<i8, i64, i64, TFunc>(owner);
+
+ AddBinaryKernel<ui16, ui8, ui16, TFunc>(owner);
+ AddBinaryKernel<ui16, i8, ui16, TFunc>(owner);
+ AddBinaryKernel<ui16, ui16, ui16, TFunc>(owner);
+ AddBinaryKernel<ui16, i16, i16, TFunc>(owner);
+ AddBinaryKernel<ui16, ui32, ui32, TFunc>(owner);
+ AddBinaryKernel<ui16, i32, i32, TFunc>(owner);
+ AddBinaryKernel<ui16, ui64, ui64, TFunc>(owner);
+ AddBinaryKernel<ui16, i64, i64, TFunc>(owner);
+
+ AddBinaryKernel<i16, ui8, i16, TFunc>(owner);
+ AddBinaryKernel<i16, i8, i16, TFunc>(owner);
+ AddBinaryKernel<i16, ui16, i16, TFunc>(owner);
+ AddBinaryKernel<i16, i16, i16, TFunc>(owner);
+ AddBinaryKernel<i16, ui32, ui32, TFunc>(owner);
+ AddBinaryKernel<i16, i32, i32, TFunc>(owner);
+ AddBinaryKernel<i16, ui64, ui64, TFunc>(owner);
+ AddBinaryKernel<i16, i64, i64, TFunc>(owner);
+
+ AddBinaryKernel<ui32, ui8, ui32, TFunc>(owner);
+ AddBinaryKernel<ui32, i8, ui32, TFunc>(owner);
+ AddBinaryKernel<ui32, ui16, ui32, TFunc>(owner);
+ AddBinaryKernel<ui32, i16, ui32, TFunc>(owner);
+ AddBinaryKernel<ui32, ui32, ui32, TFunc>(owner);
+ AddBinaryKernel<ui32, i32, i32, TFunc>(owner);
+ AddBinaryKernel<ui32, ui64, ui64, TFunc>(owner);
+ AddBinaryKernel<ui32, i64, i64, TFunc>(owner);
+
+ AddBinaryKernel<i32, ui8, i32, TFunc>(owner);
+ AddBinaryKernel<i32, i8, i32, TFunc>(owner);
+ AddBinaryKernel<i32, ui16, i32, TFunc>(owner);
+ AddBinaryKernel<i32, i16, i32, TFunc>(owner);
+ AddBinaryKernel<i32, ui32, i32, TFunc>(owner);
+ AddBinaryKernel<i32, i32, i32, TFunc>(owner);
+ AddBinaryKernel<i32, ui64, ui64, TFunc>(owner);
+ AddBinaryKernel<i32, i64, i64, TFunc>(owner);
+
+ AddBinaryKernel<ui64, ui8, ui64, TFunc>(owner);
+ AddBinaryKernel<ui64, i8, ui64, TFunc>(owner);
+ AddBinaryKernel<ui64, ui16, ui64, TFunc>(owner);
+ AddBinaryKernel<ui64, i16, ui64, TFunc>(owner);
+ AddBinaryKernel<ui64, ui32, ui64, TFunc>(owner);
+ AddBinaryKernel<ui64, i32, ui64, TFunc>(owner);
+ AddBinaryKernel<ui64, ui64, ui64, TFunc>(owner);
+ AddBinaryKernel<ui64, i64, i64, TFunc>(owner);
+
+ AddBinaryKernel<i64, ui8, i64, TFunc>(owner);
+ AddBinaryKernel<i64, i8, i64, TFunc>(owner);
+ AddBinaryKernel<i64, ui16, i64, TFunc>(owner);
+ AddBinaryKernel<i64, i16, i64, TFunc>(owner);
+ AddBinaryKernel<i64, ui32, i64, TFunc>(owner);
+ AddBinaryKernel<i64, i32, i64, TFunc>(owner);
+ AddBinaryKernel<i64, ui64, i64, TFunc>(owner);
+ AddBinaryKernel<i64, i64, i64, TFunc>(owner);
}
template<template<typename, typename, typename> class TFunc>
-class TBinaryNumericFunction : public arrow::compute::ScalarFunction {
+class TBinaryNumericKernelFamily : public TKernelFamilyBase {
public:
- TBinaryNumericFunction(const std::string& name)
- : ScalarFunction(name, arrow::compute::Arity::Binary(), nullptr)
+ TBinaryNumericKernelFamily(TKernelFamily::ENullMode nullMode = TKernelFamily::ENullMode::Default)
+ : TKernelFamilyBase(nullMode)
{
AddBinaryIntegralKernels<TFunc>(*this);
}
};
template<template<typename, typename, typename> class TPred>
-void AddBinaryIntegralPredicateKernels(arrow::compute::ScalarFunction& function) {
- AddBinaryKernel<ui8, ui8, bool, TPred>(function);
- AddBinaryKernel<ui8, i8, bool, TPred>(function);
- AddBinaryKernel<ui8, ui16, bool, TPred>(function);
- AddBinaryKernel<ui8, i16, bool, TPred>(function);
- AddBinaryKernel<ui8, ui32, bool, TPred>(function);
- AddBinaryKernel<ui8, i32, bool, TPred>(function);
- AddBinaryKernel<ui8, ui64, bool, TPred>(function);
- AddBinaryKernel<ui8, i64, bool, TPred>(function);
-
- AddBinaryKernel<i8, ui8, bool, TPred>(function);
- AddBinaryKernel<i8, i8, bool, TPred>(function);
- AddBinaryKernel<i8, ui16, bool, TPred>(function);
- AddBinaryKernel<i8, i16, bool, TPred>(function);
- AddBinaryKernel<i8, ui32, bool, TPred>(function);
- AddBinaryKernel<i8, i32, bool, TPred>(function);
- AddBinaryKernel<i8, ui64, bool, TPred>(function);
- AddBinaryKernel<i8, i64, bool, TPred>(function);
-
- AddBinaryKernel<ui16, ui8, bool, TPred>(function);
- AddBinaryKernel<ui16, i8, bool, TPred>(function);
- AddBinaryKernel<ui16, ui16, bool, TPred>(function);
- AddBinaryKernel<ui16, i16, bool, TPred>(function);
- AddBinaryKernel<ui16, ui32, bool, TPred>(function);
- AddBinaryKernel<ui16, i32, bool, TPred>(function);
- AddBinaryKernel<ui16, ui64, bool, TPred>(function);
- AddBinaryKernel<ui16, i64, bool, TPred>(function);
-
- AddBinaryKernel<i16, ui8, bool, TPred>(function);
- AddBinaryKernel<i16, i8, bool, TPred>(function);
- AddBinaryKernel<i16, ui16, bool, TPred>(function);
- AddBinaryKernel<i16, i16, bool, TPred>(function);
- AddBinaryKernel<i16, ui32, bool, TPred>(function);
- AddBinaryKernel<i16, i32, bool, TPred>(function);
- AddBinaryKernel<i16, ui64, bool, TPred>(function);
- AddBinaryKernel<i16, i64, bool, TPred>(function);
-
- AddBinaryKernel<ui32, ui8, bool, TPred>(function);
- AddBinaryKernel<ui32, i8, bool, TPred>(function);
- AddBinaryKernel<ui32, ui16, bool, TPred>(function);
- AddBinaryKernel<ui32, i16, bool, TPred>(function);
- AddBinaryKernel<ui32, ui32, bool, TPred>(function);
- AddBinaryKernel<ui32, i32, bool, TPred>(function);
- AddBinaryKernel<ui32, ui64, bool, TPred>(function);
- AddBinaryKernel<ui32, i64, bool, TPred>(function);
-
- AddBinaryKernel<i32, ui8, bool, TPred>(function);
- AddBinaryKernel<i32, i8, bool, TPred>(function);
- AddBinaryKernel<i32, ui16, bool, TPred>(function);
- AddBinaryKernel<i32, i16, bool, TPred>(function);
- AddBinaryKernel<i32, ui32, bool, TPred>(function);
- AddBinaryKernel<i32, i32, bool, TPred>(function);
- AddBinaryKernel<i32, ui64, bool, TPred>(function);
- AddBinaryKernel<i32, i64, bool, TPred>(function);
-
- AddBinaryKernel<ui64, ui8, bool, TPred>(function);
- AddBinaryKernel<ui64, i8, bool, TPred>(function);
- AddBinaryKernel<ui64, ui16, bool, TPred>(function);
- AddBinaryKernel<ui64, i16, bool, TPred>(function);
- AddBinaryKernel<ui64, ui32, bool, TPred>(function);
- AddBinaryKernel<ui64, i32, bool, TPred>(function);
- AddBinaryKernel<ui64, ui64, bool, TPred>(function);
- AddBinaryKernel<ui64, i64, bool, TPred>(function);
-
- AddBinaryKernel<i64, ui8, bool, TPred>(function);
- AddBinaryKernel<i64, i8, bool, TPred>(function);
- AddBinaryKernel<i64, ui16, bool, TPred>(function);
- AddBinaryKernel<i64, i16, bool, TPred>(function);
- AddBinaryKernel<i64, ui32, bool, TPred>(function);
- AddBinaryKernel<i64, i32, bool, TPred>(function);
- AddBinaryKernel<i64, ui64, bool, TPred>(function);
- AddBinaryKernel<i64, i64, bool, TPred>(function);
+void AddBinaryIntegralPredicateKernels(TKernelFamilyBase& owner) {
+ AddBinaryKernel<ui8, ui8, bool, TPred>(owner);
+ AddBinaryKernel<ui8, i8, bool, TPred>(owner);
+ AddBinaryKernel<ui8, ui16, bool, TPred>(owner);
+ AddBinaryKernel<ui8, i16, bool, TPred>(owner);
+ AddBinaryKernel<ui8, ui32, bool, TPred>(owner);
+ AddBinaryKernel<ui8, i32, bool, TPred>(owner);
+ AddBinaryKernel<ui8, ui64, bool, TPred>(owner);
+ AddBinaryKernel<ui8, i64, bool, TPred>(owner);
+
+ AddBinaryKernel<i8, ui8, bool, TPred>(owner);
+ AddBinaryKernel<i8, i8, bool, TPred>(owner);
+ AddBinaryKernel<i8, ui16, bool, TPred>(owner);
+ AddBinaryKernel<i8, i16, bool, TPred>(owner);
+ AddBinaryKernel<i8, ui32, bool, TPred>(owner);
+ AddBinaryKernel<i8, i32, bool, TPred>(owner);
+ AddBinaryKernel<i8, ui64, bool, TPred>(owner);
+ AddBinaryKernel<i8, i64, bool, TPred>(owner);
+
+ AddBinaryKernel<ui16, ui8, bool, TPred>(owner);
+ AddBinaryKernel<ui16, i8, bool, TPred>(owner);
+ AddBinaryKernel<ui16, ui16, bool, TPred>(owner);
+ AddBinaryKernel<ui16, i16, bool, TPred>(owner);
+ AddBinaryKernel<ui16, ui32, bool, TPred>(owner);
+ AddBinaryKernel<ui16, i32, bool, TPred>(owner);
+ AddBinaryKernel<ui16, ui64, bool, TPred>(owner);
+ AddBinaryKernel<ui16, i64, bool, TPred>(owner);
+
+ AddBinaryKernel<i16, ui8, bool, TPred>(owner);
+ AddBinaryKernel<i16, i8, bool, TPred>(owner);
+ AddBinaryKernel<i16, ui16, bool, TPred>(owner);
+ AddBinaryKernel<i16, i16, bool, TPred>(owner);
+ AddBinaryKernel<i16, ui32, bool, TPred>(owner);
+ AddBinaryKernel<i16, i32, bool, TPred>(owner);
+ AddBinaryKernel<i16, ui64, bool, TPred>(owner);
+ AddBinaryKernel<i16, i64, bool, TPred>(owner);
+
+ AddBinaryKernel<ui32, ui8, bool, TPred>(owner);
+ AddBinaryKernel<ui32, i8, bool, TPred>(owner);
+ AddBinaryKernel<ui32, ui16, bool, TPred>(owner);
+ AddBinaryKernel<ui32, i16, bool, TPred>(owner);
+ AddBinaryKernel<ui32, ui32, bool, TPred>(owner);
+ AddBinaryKernel<ui32, i32, bool, TPred>(owner);
+ AddBinaryKernel<ui32, ui64, bool, TPred>(owner);
+ AddBinaryKernel<ui32, i64, bool, TPred>(owner);
+
+ AddBinaryKernel<i32, ui8, bool, TPred>(owner);
+ AddBinaryKernel<i32, i8, bool, TPred>(owner);
+ AddBinaryKernel<i32, ui16, bool, TPred>(owner);
+ AddBinaryKernel<i32, i16, bool, TPred>(owner);
+ AddBinaryKernel<i32, ui32, bool, TPred>(owner);
+ AddBinaryKernel<i32, i32, bool, TPred>(owner);
+ AddBinaryKernel<i32, ui64, bool, TPred>(owner);
+ AddBinaryKernel<i32, i64, bool, TPred>(owner);
+
+ AddBinaryKernel<ui64, ui8, bool, TPred>(owner);
+ AddBinaryKernel<ui64, i8, bool, TPred>(owner);
+ AddBinaryKernel<ui64, ui16, bool, TPred>(owner);
+ AddBinaryKernel<ui64, i16, bool, TPred>(owner);
+ AddBinaryKernel<ui64, ui32, bool, TPred>(owner);
+ AddBinaryKernel<ui64, i32, bool, TPred>(owner);
+ AddBinaryKernel<ui64, ui64, bool, TPred>(owner);
+ AddBinaryKernel<ui64, i64, bool, TPred>(owner);
+
+ AddBinaryKernel<i64, ui8, bool, TPred>(owner);
+ AddBinaryKernel<i64, i8, bool, TPred>(owner);
+ AddBinaryKernel<i64, ui16, bool, TPred>(owner);
+ AddBinaryKernel<i64, i16, bool, TPred>(owner);
+ AddBinaryKernel<i64, ui32, bool, TPred>(owner);
+ AddBinaryKernel<i64, i32, bool, TPred>(owner);
+ AddBinaryKernel<i64, ui64, bool, TPred>(owner);
+ AddBinaryKernel<i64, i64, bool, TPred>(owner);
}
template<template<typename, typename, typename> class TPred>
-class TBinaryNumericPredicate : public arrow::compute::ScalarFunction {
+class TBinaryNumericPredicateKernelFamily : public TKernelFamilyBase {
public:
- TBinaryNumericPredicate(const std::string& name)
- : ScalarFunction(name, arrow::compute::Arity::Binary(), nullptr)
+ TBinaryNumericPredicateKernelFamily()
{
AddBinaryIntegralPredicateKernels<TPred>(*this);
}
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp
index 01e1c8e7bad..f332867d06e 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp
@@ -281,8 +281,8 @@ void RegisterLess(IBuiltinFunctionRegistry& registry) {
RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrLess, TCompareArgsOpt>(registry, aggrName);
}
-void RegisterLess(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TLessOp>>("Less"));
+void RegisterLess(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["Less"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TLessOp>>();
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp
index 73e883534db..4dd4b8b5d7b 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp
@@ -281,8 +281,8 @@ void RegisterLessOrEqual(IBuiltinFunctionRegistry& registry) {
RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrLessOrEqual, TCompareArgsOpt>(registry, aggrName);
}
-void RegisterLessOrEqual(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TLessOrEqualOp>>("LessOrEqual"));
+void RegisterLessOrEqual(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["LessOrEqual"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TLessOrEqualOp>>();
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mod.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mod.cpp
index d199fb8c089..57a54defa1a 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mod.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mod.cpp
@@ -90,8 +90,8 @@ void RegisterMod(IBuiltinFunctionRegistry& registry) {
RegisterBinaryIntegralFunctionOpt<TIntegralMod, TBinaryArgsOptWithNullableResult>(registry, "Mod");
}
-void RegisterMod(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericFunction<TIntegralMod>>("Mod?"));
+void RegisterMod(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["Mod"] = std::make_unique<TBinaryNumericKernelFamily<TIntegralMod>>(TKernelFamily::ENullMode::AlwaysNull);
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mul.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mul.cpp
index 1dc0202deb2..0b0ab631a6a 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mul.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_mul.cpp
@@ -97,8 +97,8 @@ void RegisterMul(IBuiltinFunctionRegistry& registry) {
NUdf::TDataType<NUdf::TInterval>, TNumMulInterval, TBinaryArgsOptWithNullableResult>(registry, "Mul");
}
-void RegisterMul(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericFunction<TMul>>("Mul"));
+void RegisterMul(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["Mul"] = std::make_unique<TBinaryNumericKernelFamily<TMul>>();
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp
index 4fb26abf1ee..a5f5eca21da 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp
@@ -284,8 +284,8 @@ void RegisterNotEquals(IBuiltinFunctionRegistry& registry) {
RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrNotEquals, TCompareArgsOpt>(registry, aggrName);
}
-void RegisterNotEquals(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TNotEqualsOp>>("NotEquals"));
+void RegisterNotEquals(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["NotEquals"] = std::make_unique<TBinaryNumericPredicateKernelFamily<TNotEqualsOp>>();
}
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_sub.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_sub.cpp
index 06cd093241b..bfd8e0782bf 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_sub.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_sub.cpp
@@ -289,8 +289,8 @@ void RegisterSub(IBuiltinFunctionRegistry& registry) {
NUdf::TDataType<NUdf::TTzTimestamp>, TAnyDateTimeSubIntervalTz, TBinaryArgsOptWithNullableResult>(registry, "Sub");
}
-void RegisterSub(arrow::compute::FunctionRegistry& registry) {
- AddFunction(registry, std::make_shared<TBinaryNumericFunction<TSub>>("Sub"));
+void RegisterSub(TKernelFamilyMap& kernelFamilyMap) {
+ kernelFamilyMap["Sub"] = std::make_unique<TBinaryNumericKernelFamily<TSub>>();
}
} // namespace NMiniKQL
diff --git a/ydb/library/yql/minikql/mkql_function_metadata.h b/ydb/library/yql/minikql/mkql_function_metadata.h
index a880b0558b5..3042922133f 100644
--- a/ydb/library/yql/minikql/mkql_function_metadata.h
+++ b/ydb/library/yql/minikql/mkql_function_metadata.h
@@ -1,10 +1,9 @@
#pragma once
#include <ydb/library/yql/public/udf/udf_value.h>
+#include <util/digest/numeric.h>
-namespace arrow::compute {
-class FunctionRegistry;
-};
+#include <arrow/compute/kernel.h>
namespace NKikimr {
@@ -51,6 +50,81 @@ using TArgType = std::pair<NUdf::TDataTypeId, bool>; // type with optional flag
using TDescriptionList = std::vector<TFunctionDescriptor>;
using TFunctionsMap = std::unordered_map<TString, TDescriptionList>;
+class TKernel;
+
+class TKernelFamily {
+public:
+ enum ENullMode {
+ Default,
+ AlwaysNull,
+ AlwaysNotNull
+ };
+
+ const ENullMode NullMode;
+ const arrow::compute::FunctionOptions* FunctionOptions;
+
+ TKernelFamily(ENullMode nullMode = ENullMode::Default, const arrow::compute::FunctionOptions* functionOptions = nullptr)
+ : NullMode(nullMode)
+ , FunctionOptions(functionOptions)
+ {}
+
+ virtual ~TKernelFamily() = default;
+ virtual const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const = 0;
+};
+
+class TKernel {
+public:
+ const TKernelFamily& Family;
+ const std::vector<NUdf::TDataTypeId> ArgTypes;
+ const NUdf::TDataTypeId ReturnType;
+
+ TKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType)
+ : Family(family)
+ , ArgTypes(argTypes)
+ , ReturnType(returnType)
+ {
+ }
+
+ virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0;
+
+ virtual ~TKernel() = default;
+};
+
+struct TTypeHasher {
+ std::size_t operator()(const std::vector<NUdf::TDataTypeId>& s) const noexcept {
+ size_t r = 0;
+ for (const auto& x : s) {
+ r = CombineHashes<size_t>(r, x);
+ }
+
+ return r;
+ }
+};
+
+using TKernelMap = std::unordered_map<std::vector<NUdf::TDataTypeId>, std::unique_ptr<TKernel>, TTypeHasher>;
+
+using TKernelFamilyMap = std::unordered_map<TString, std::unique_ptr<TKernelFamily>>;
+
+class TKernelFamilyBase : public TKernelFamily
+{
+public:
+ TKernelFamilyBase(ENullMode nullMode = ENullMode::Default, const arrow::compute::FunctionOptions* functionOptions = nullptr)
+ : TKernelFamily(nullMode, functionOptions)
+ {}
+
+ const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const final {
+ std::vector<NUdf::TDataTypeId> key(argTypes, argTypes + argTypesCount);
+ auto it = KernelMap.find(key);
+ if (it == KernelMap.end()) {
+ return nullptr;
+ }
+
+ return it->second.get();
+ }
+
+ TKernelMap KernelMap;
+};
+
class IBuiltinFunctionRegistry: public TThrRefBase, private TNonCopyable
{
public:
@@ -70,7 +144,9 @@ public:
virtual TFunctionDescriptor GetBuiltin(const std::string_view& name, const std::pair<NUdf::TDataTypeId, bool>* argTypes, size_t argTypesCount) const = 0;
- virtual arrow::compute::FunctionRegistry* GetArrowFunctionRegistry() const = 0;
+ virtual const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount) const = 0;
+
+ virtual void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) = 0;
};
}
diff --git a/ydb/library/yql/minikql/mkql_type_builder.cpp b/ydb/library/yql/minikql/mkql_type_builder.cpp
index 04ec05bec41..5c4b28f36b1 100644
--- a/ydb/library/yql/minikql/mkql_type_builder.cpp
+++ b/ydb/library/yql/minikql/mkql_type_builder.cpp
@@ -1313,18 +1313,8 @@ private:
namespace NMiniKQL {
-bool ConvertArrowType(TType* itemType, bool& isOptional, std::shared_ptr<arrow::DataType>& type) {
- auto unpacked = UnpackOptional(itemType, isOptional);
- if (!unpacked->IsData()) {
- return false;
- }
-
- auto slot = AS_TYPE(TDataType, unpacked)->GetDataSlot();
- if (!slot) {
- return false;
- }
-
- switch (*slot) {
+bool ConvertArrowType(NUdf::EDataSlot slot, std::shared_ptr<arrow::DataType>& type) {
+ switch (slot) {
case NUdf::EDataSlot::Bool:
type = arrow::boolean();
return true;
@@ -1361,6 +1351,20 @@ bool ConvertArrowType(TType* itemType, bool& isOptional, std::shared_ptr<arrow::
}
}
+bool ConvertArrowType(TType* itemType, bool& isOptional, std::shared_ptr<arrow::DataType>& type) {
+ auto unpacked = UnpackOptional(itemType, isOptional);
+ if (!unpacked->IsData()) {
+ return false;
+ }
+
+ auto slot = AS_TYPE(TDataType, unpacked)->GetDataSlot();
+ if (!slot) {
+ return false;
+ }
+
+ return ConvertArrowType(*slot, type);
+}
+
void TArrowType::Export(ArrowSchema* out) const {
auto status = arrow::ExportType(*Type, out);
if (!status.ok()) {
diff --git a/ydb/library/yql/minikql/mkql_type_builder.h b/ydb/library/yql/minikql/mkql_type_builder.h
index 9cc4ea126ef..7e5e95bcba9 100644
--- a/ydb/library/yql/minikql/mkql_type_builder.h
+++ b/ydb/library/yql/minikql/mkql_type_builder.h
@@ -11,6 +11,7 @@ namespace NKikimr {
namespace NMiniKQL {
bool ConvertArrowType(TType* itemType, bool& isOptional, std::shared_ptr<arrow::DataType>& type);
+bool ConvertArrowType(NUdf::EDataSlot slot, std::shared_ptr<arrow::DataType>& type);
class TArrowType : public NUdf::IArrowType {
public: