diff options
author | aneporada <[email protected]> | 2022-11-18 22:05:14 +0300 |
---|---|---|
committer | aneporada <[email protected]> | 2022-11-18 22:05:14 +0300 |
commit | 024dc480aedd4a427ee9e6aab87df7dc56ffd67b (patch) | |
tree | 8569cf2c08ac1101d50606af364b3f4f43c4b385 | |
parent | c99bf985cf0dbeb462857a91e80dd289a467f67a (diff) |
Add compare kernels, fix handling WideFilter with limit
10 files changed, 302 insertions, 15 deletions
diff --git a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp index fa4e2b233b5..72b2d829fd2 100644 --- a/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/ydb/library/yql/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -2611,11 +2611,18 @@ TExprNode::TPtr OptimizeSkip(const TExprNode::TPtr& node, TExprContext& ctx) { template <bool EnableNewOptimizers> TExprNode::TPtr OptimizeTake(const TExprNode::TPtr& node, TExprContext& ctx) { if constexpr (EnableNewOptimizers) { - if (const auto& input = node->Head(); input.IsCallable({"Filter", "OrderedFilter", "WideFilter"}) && 2U == input.ChildrenSize()) { - YQL_CLOG(DEBUG, CorePeepHole) << "Inject " << node->Content() << " limit into " << input.Content(); - auto list = input.ChildrenList(); - list.emplace_back(node->TailPtr()); - return ctx.ChangeChildren(input, std::move(list)); + if (const auto& input = node->Head(); input.IsCallable({"Filter", "OrderedFilter", "WideFilter"})) { + if (2U == input.ChildrenSize()) { + YQL_CLOG(DEBUG, CorePeepHole) << "Inject " << node->Content() << " limit into " << input.Content(); + auto list = input.ChildrenList(); + list.emplace_back(node->TailPtr()); + return ctx.ChangeChildren(input, std::move(list)); + } + + auto childLimit = input.ChildPtr(2); + auto myLimit = node->ChildPtr(1); + YQL_CLOG(DEBUG, CorePeepHole) << "Merge " << node->Content() << " limit into " << input.Content(); + return ctx.ChangeChild(input, 2, ctx.NewCallable(node->Pos(), "Min", { myLimit, childLimit })); } if (const auto& input = node->Head(); 1U == input.UseCount()) { @@ -4326,6 +4333,14 @@ struct TBlockRules { {"/", { "Div?" } }, // kernel produces optional output on non-optional inputs {"%", { "Mod?" } }, // kernel produces optional output on non-optional inputs {"Not", { "invert" }}, + + // comparison kernels + {"==", { "Equals" } }, + {"!=", { "NotEquals" } }, + {"<", { "Less" } }, + {"<=", { "LessOrEqual" } }, + {">", { "Greater" } }, + {">=", { "GreaterOrEqual" } }, }; TBlockRules() @@ -4621,7 +4636,7 @@ TExprNode::TPtr OptimizeWideMapBlocks(const TExprNode::TPtr& node, TExprContext& TExprNode::TPtr OptimizeWideFilterBlocks(const TExprNode::TPtr& node, TExprContext& ctx, TTypeAnnotationContext& types) { auto multiInputType = node->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>(); - auto lambda = node->TailPtr(); + auto lambda = node->ChildPtr(1); YQL_ENSURE(lambda->ChildrenSize() == 2); // filter lambda should have single output ui32 newNodes; diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp index bf3406e35ad..e8add1b8c11 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp @@ -56,11 +56,17 @@ void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, arrow::comput RegisterAggrMax(registry); RegisterAggrMin(registry); RegisterEquals(registry); + RegisterEquals(arrowRegistry); RegisterNotEquals(registry); + RegisterNotEquals(arrowRegistry); RegisterLess(registry); + RegisterLess(arrowRegistry); RegisterLessOrEqual(registry); + RegisterLessOrEqual(arrowRegistry); RegisterGreater(registry); + RegisterGreater(arrowRegistry); RegisterGreaterOrEqual(registry); + RegisterGreaterOrEqual(arrowRegistry); } void PrintType(NUdf::TDataTypeId schemeType, bool isOptional, IOutputStream& out) diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h index 8d865b86d98..3163168c72c 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_compare.h @@ -470,11 +470,17 @@ void RegisterAggrCompareStrings(IBuiltinFunctionRegistry& registry, const std::s } void RegisterEquals(IBuiltinFunctionRegistry& registry); +void RegisterEquals(arrow::compute::FunctionRegistry& registry); void RegisterNotEquals(IBuiltinFunctionRegistry& registry); +void RegisterNotEquals(arrow::compute::FunctionRegistry& registry); void RegisterLess(IBuiltinFunctionRegistry& registry); +void RegisterLess(arrow::compute::FunctionRegistry& registry); void RegisterLessOrEqual(IBuiltinFunctionRegistry& registry); +void RegisterLessOrEqual(arrow::compute::FunctionRegistry& registry); void RegisterGreater(IBuiltinFunctionRegistry& registry); +void RegisterGreater(arrow::compute::FunctionRegistry& registry); void RegisterGreaterOrEqual(IBuiltinFunctionRegistry& registry); +void RegisterGreaterOrEqual(arrow::compute::FunctionRegistry& registry); } } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp index 79f24a52f46..00857259d97 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp @@ -155,6 +155,14 @@ struct TEquals : public TCompareArithmeticBinary<TLeft, TRight, TEquals<TLeft, T #endif }; +template<typename TLeft, typename TRight, typename TOutput> +struct TEqualsOp; + +template<typename TLeft, typename TRight> +struct TEqualsOp<TLeft, TRight, bool> : public TEquals<TLeft, TRight, false> { + static constexpr bool DefaultNulls = true; +}; + template<typename TLeft, typename TRight, bool Aggr> struct TDiffDateEquals : public TCompareArithmeticBinary<TLeft, TRight, TDiffDateEquals<TLeft, TRight, Aggr>>, public TAggrEquals { static bool Do(TLeft left, TRight right) @@ -276,5 +284,9 @@ void RegisterEquals(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrEquals, TCompareArgsOpt>(registry, aggrName); } +void RegisterEquals(arrow::compute::FunctionRegistry& registry) { + AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TEqualsOp>>("Equals")); +} + } // namespace NMiniKQL } // namespace NKikimr diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp index 08297c89161..15f035a7bdb 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp @@ -148,6 +148,14 @@ struct TGreater : public TCompareArithmeticBinary<TLeft, TRight, TGreater<TLeft, #endif }; +template<typename TLeft, typename TRight, typename TOutput> +struct TGreaterOp; + +template<typename TLeft, typename TRight> +struct TGreaterOp<TLeft, TRight, bool> : public TGreater<TLeft, TRight, false> { + static constexpr bool DefaultNulls = true; +}; + template<typename TLeft, typename TRight, bool Aggr> struct TDiffDateGreater : public TCompareArithmeticBinary<TLeft, TRight, TDiffDateGreater<TLeft, TRight, Aggr>>, public TAggrGreater { static bool Do(TLeft left, TRight right) @@ -273,5 +281,9 @@ void RegisterGreater(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrGreater, TCompareArgsOpt>(registry, aggrName); } +void RegisterGreater(arrow::compute::FunctionRegistry& registry) { + AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TGreaterOp>>("Greater")); +} + } // namespace NMiniKQL } // namespace NKikimr diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp index dbd899eb9fe..604f9955507 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp @@ -148,6 +148,14 @@ struct TGreaterOrEqual : public TCompareArithmeticBinary<TLeft, TRight, TGreater #endif }; +template<typename TLeft, typename TRight, typename TOutput> +struct TGreaterOrEqualOp; + +template<typename TLeft, typename TRight> +struct TGreaterOrEqualOp<TLeft, TRight, bool> : public TGreaterOrEqual<TLeft, TRight, false> { + static constexpr bool DefaultNulls = true; +}; + template<typename TLeft, typename TRight, bool Aggr> struct TDiffDateGreaterOrEqual : public TCompareArithmeticBinary<TLeft, TRight, TDiffDateGreaterOrEqual<TLeft, TRight, Aggr>>, public TAggrGreaterOrEqual { static bool Do(TLeft left, TRight right) @@ -273,5 +281,9 @@ void RegisterGreaterOrEqual(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrGreaterOrEqual, TCompareArgsOpt>(registry, aggrName); } +void RegisterGreaterOrEqual(arrow::compute::FunctionRegistry& registry) { + AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TGreaterOrEqualOp>>("GreaterOrEqual")); +} + } // namespace NMiniKQL } // namespace NKikimr diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h index abf16cc75c1..d2fb0844966 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_impl.h @@ -923,6 +923,87 @@ inline arrow::Datum MakeScalarDatum<ui64>(ui64 value) { return arrow::Datum(std::make_shared<arrow::UInt64Scalar>(value)); } +template<typename TInput1, typename TInput2, + template<typename, typename, typename> class TPred> +void CalcPredScalarArray(TInput1 val1, const TInput2* in2, ui8* output, size_t size) +{ + using TPredInstance = TPred<TInput1, TInput2, bool>; + while (size >= 8) { + ui8 result = 0; + result |= (ui8(TPredInstance::Do(val1, *in2++)) << 0); + result |= (ui8(TPredInstance::Do(val1, *in2++)) << 1); + result |= (ui8(TPredInstance::Do(val1, *in2++)) << 2); + result |= (ui8(TPredInstance::Do(val1, *in2++)) << 3); + result |= (ui8(TPredInstance::Do(val1, *in2++)) << 4); + result |= (ui8(TPredInstance::Do(val1, *in2++)) << 5); + result |= (ui8(TPredInstance::Do(val1, *in2++)) << 6); + result |= (ui8(TPredInstance::Do(val1, *in2++)) << 7); + *output++ = result; + size -= 8; + } + if (size) { + ui8 result = 0; + for (ui8 i = 0; i < size; ++i) { + result |= (ui8(TPredInstance::Do(val1, *in2++)) << i); + } + *output = result; + } +} + +template<typename TInput1, typename TInput2, + template<typename, typename, typename> class TPred> +void CalcPredArrayScalar(const TInput1* in1, TInput2 val2, ui8* output, size_t size) +{ + using TPredInstance = TPred<TInput1, TInput2, bool>; + while (size >= 8) { + ui8 result = 0; + result |= (ui8(TPredInstance::Do(*in1++, val2)) << 0); + result |= (ui8(TPredInstance::Do(*in1++, val2)) << 1); + result |= (ui8(TPredInstance::Do(*in1++, val2)) << 2); + result |= (ui8(TPredInstance::Do(*in1++, val2)) << 3); + result |= (ui8(TPredInstance::Do(*in1++, val2)) << 4); + result |= (ui8(TPredInstance::Do(*in1++, val2)) << 5); + result |= (ui8(TPredInstance::Do(*in1++, val2)) << 6); + result |= (ui8(TPredInstance::Do(*in1++, val2)) << 7); + *output++ = result; + size -= 8; + } + if (size) { + ui8 result = 0; + for (ui8 i = 0; i < size; ++i) { + result |= (ui8(TPredInstance::Do(*in1++, val2)) << i); + } + *output = result; + } +} + +template<typename TInput1, typename TInput2, + template<typename, typename, typename> class TPred> +void CalcPredArrayArray(const TInput1* in1, const TInput2* in2, ui8* output, size_t size) +{ + using TPredInstance = TPred<TInput1, TInput2, bool>; + while (size >= 8) { + ui8 result = 0; + result |= (ui8(TPredInstance::Do(*in1++, *in2++)) << 0); + result |= (ui8(TPredInstance::Do(*in1++, *in2++)) << 1); + result |= (ui8(TPredInstance::Do(*in1++, *in2++)) << 2); + result |= (ui8(TPredInstance::Do(*in1++, *in2++)) << 3); + result |= (ui8(TPredInstance::Do(*in1++, *in2++)) << 4); + result |= (ui8(TPredInstance::Do(*in1++, *in2++)) << 5); + result |= (ui8(TPredInstance::Do(*in1++, *in2++)) << 6); + result |= (ui8(TPredInstance::Do(*in1++, *in2++)) << 7); + *output++ = result; + size -= 8; + } + if (size) { + ui8 result = 0; + for (ui8 i = 0; i < size; ++i) { + result |= (ui8(TPredInstance::Do(*in1++, *in2++)) << i); + } + *output = result; + } +} + template<typename TInput1, typename TInput2, typename TOutput, template<typename, typename, typename> class TFunc, bool DefaultNulls> struct TBinaryKernelExecs; @@ -958,9 +1039,15 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true> const auto& arr2 = *arg2.array(); auto length = arr2.length; const auto values2 = arr2.GetValues<TInput2>(1); - auto resValues = resArr.GetMutableValues<TOutput>(1); - for (int64_t i = 0; i < length; ++i) { - resValues[i] = TFuncInstance::Do(val1, values2[i]); + if constexpr (std::is_same<TOutput, bool>::value) { + MKQL_ENSURE(resArr.offset == 0, "Expecting zero output offset"); + ui8* resValues = resArr.GetMutableValues<ui8>(1, 0); + CalcPredScalarArray<TInput1, TInput2, TFunc>(val1, values2, resValues, length); + } else { + auto resValues = resArr.GetMutableValues<TOutput>(1); + for (int64_t i = 0; i < length; ++i) { + resValues[i] = TFuncInstance::Do(val1, values2[i]); + } } } @@ -976,10 +1063,16 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true> const auto& arr1 = *arg1.array(); auto length = arr1.length; const auto values1 = arr1.GetValues<TInput1>(1); - auto resValues = resArr.GetMutableValues<TOutput>(1); const auto val2 = GetPrimitiveScalarValue<TInput2>(*arg2.scalar()); - for (int64_t i = 0; i < length; ++i) { - resValues[i] = TFuncInstance::Do(values1[i], val2); + if constexpr (std::is_same<TOutput, bool>::value) { + MKQL_ENSURE(resArr.offset == 0, "Expecting zero output offset"); + ui8* resValues = resArr.GetMutableValues<ui8>(1, 0); + CalcPredArrayScalar<TInput1, TInput2, TFunc>(values1, val2, resValues, length); + } else { + auto resValues = resArr.GetMutableValues<TOutput>(1); + for (int64_t i = 0; i < length; ++i) { + resValues[i] = TFuncInstance::Do(values1[i], val2); + } } } @@ -997,9 +1090,15 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, true> auto length = arr1.length; const auto values1 = arr1.GetValues<TInput1>(1); const auto values2 = arr2.GetValues<TInput2>(1); - auto resValues = resArr.GetMutableValues<TOutput>(1); - for (int64_t i = 0; i < length; ++i) { - resValues[i] = TFuncInstance::Do(values1[i], values2[i]); + if constexpr (std::is_same<TOutput, bool>::value) { + MKQL_ENSURE(resArr.offset == 0, "Expecting zero output offset"); + ui8* resValues = resArr.GetMutableValues<ui8>(1, 0); + CalcPredArrayArray<TInput1, TInput2, TFunc>(values1, values2, resValues, length); + } else { + auto resValues = resArr.GetMutableValues<TOutput>(1); + for (int64_t i = 0; i < length; ++i) { + resValues[i] = TFuncInstance::Do(values1[i], values2[i]); + } } return arrow::Status::OK(); @@ -1034,6 +1133,7 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false> static arrow::Status ExecScalarArray(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch, arrow::Datum* res) { MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); + static_assert(!std::is_same<TOutput, bool>::value); const auto& arg1 = batch.values[0]; const auto& arg2 = batch.values[1]; const auto& arr2 = *arg2.array(); @@ -1068,6 +1168,7 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false> static arrow::Status ExecArrayScalar(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch, arrow::Datum* res) { MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); + static_assert(!std::is_same<TOutput, bool>::value); const auto& arg1 = batch.values[0]; const auto& arg2 = batch.values[1]; const auto& arr1 = *arg1.array(); @@ -1102,6 +1203,7 @@ struct TBinaryKernelExecs<TInput1, TInput2, TOutput, TFunc, false> static arrow::Status ExecArrayArray(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch, arrow::Datum* res) { MKQL_ENSURE(batch.values.size() == 2, "Expected 2 args"); + static_assert(!std::is_same<TOutput, bool>::value); const auto& arg1 = batch.values[0]; const auto& arg2 = batch.values[1]; const auto& arr1 = *arg1.array(); @@ -1244,5 +1346,90 @@ public: } }; +template<template<typename, typename, typename> class TPred> +void AddBinaryIntegralPredicateKernels(arrow::compute::ScalarFunction& function) { + AddBinaryKernel<ui8, ui8, bool, TPred>(function); + AddBinaryKernel<ui8, i8, bool, TPred>(function); + AddBinaryKernel<ui8, ui16, bool, TPred>(function); + AddBinaryKernel<ui8, i16, bool, TPred>(function); + AddBinaryKernel<ui8, ui32, bool, TPred>(function); + AddBinaryKernel<ui8, i32, bool, TPred>(function); + AddBinaryKernel<ui8, ui64, bool, TPred>(function); + AddBinaryKernel<ui8, i64, bool, TPred>(function); + + AddBinaryKernel<i8, ui8, bool, TPred>(function); + AddBinaryKernel<i8, i8, bool, TPred>(function); + AddBinaryKernel<i8, ui16, bool, TPred>(function); + AddBinaryKernel<i8, i16, bool, TPred>(function); + AddBinaryKernel<i8, ui32, bool, TPred>(function); + AddBinaryKernel<i8, i32, bool, TPred>(function); + AddBinaryKernel<i8, ui64, bool, TPred>(function); + AddBinaryKernel<i8, i64, bool, TPred>(function); + + AddBinaryKernel<ui16, ui8, bool, TPred>(function); + AddBinaryKernel<ui16, i8, bool, TPred>(function); + AddBinaryKernel<ui16, ui16, bool, TPred>(function); + AddBinaryKernel<ui16, i16, bool, TPred>(function); + AddBinaryKernel<ui16, ui32, bool, TPred>(function); + AddBinaryKernel<ui16, i32, bool, TPred>(function); + AddBinaryKernel<ui16, ui64, bool, TPred>(function); + AddBinaryKernel<ui16, i64, bool, TPred>(function); + + AddBinaryKernel<i16, ui8, bool, TPred>(function); + AddBinaryKernel<i16, i8, bool, TPred>(function); + AddBinaryKernel<i16, ui16, bool, TPred>(function); + AddBinaryKernel<i16, i16, bool, TPred>(function); + AddBinaryKernel<i16, ui32, bool, TPred>(function); + AddBinaryKernel<i16, i32, bool, TPred>(function); + AddBinaryKernel<i16, ui64, bool, TPred>(function); + AddBinaryKernel<i16, i64, bool, TPred>(function); + + AddBinaryKernel<ui32, ui8, bool, TPred>(function); + AddBinaryKernel<ui32, i8, bool, TPred>(function); + AddBinaryKernel<ui32, ui16, bool, TPred>(function); + AddBinaryKernel<ui32, i16, bool, TPred>(function); + AddBinaryKernel<ui32, ui32, bool, TPred>(function); + AddBinaryKernel<ui32, i32, bool, TPred>(function); + AddBinaryKernel<ui32, ui64, bool, TPred>(function); + AddBinaryKernel<ui32, i64, bool, TPred>(function); + + AddBinaryKernel<i32, ui8, bool, TPred>(function); + AddBinaryKernel<i32, i8, bool, TPred>(function); + AddBinaryKernel<i32, ui16, bool, TPred>(function); + AddBinaryKernel<i32, i16, bool, TPred>(function); + AddBinaryKernel<i32, ui32, bool, TPred>(function); + AddBinaryKernel<i32, i32, bool, TPred>(function); + AddBinaryKernel<i32, ui64, bool, TPred>(function); + AddBinaryKernel<i32, i64, bool, TPred>(function); + + AddBinaryKernel<ui64, ui8, bool, TPred>(function); + AddBinaryKernel<ui64, i8, bool, TPred>(function); + AddBinaryKernel<ui64, ui16, bool, TPred>(function); + AddBinaryKernel<ui64, i16, bool, TPred>(function); + AddBinaryKernel<ui64, ui32, bool, TPred>(function); + AddBinaryKernel<ui64, i32, bool, TPred>(function); + AddBinaryKernel<ui64, ui64, bool, TPred>(function); + AddBinaryKernel<ui64, i64, bool, TPred>(function); + + AddBinaryKernel<i64, ui8, bool, TPred>(function); + AddBinaryKernel<i64, i8, bool, TPred>(function); + AddBinaryKernel<i64, ui16, bool, TPred>(function); + AddBinaryKernel<i64, i16, bool, TPred>(function); + AddBinaryKernel<i64, ui32, bool, TPred>(function); + AddBinaryKernel<i64, i32, bool, TPred>(function); + AddBinaryKernel<i64, ui64, bool, TPred>(function); + AddBinaryKernel<i64, i64, bool, TPred>(function); +} + +template<template<typename, typename, typename> class TPred> +class TBinaryNumericPredicate : public arrow::compute::ScalarFunction { +public: + TBinaryNumericPredicate(const std::string& name) + : ScalarFunction(name, arrow::compute::Arity::Binary(), nullptr) + { + AddBinaryIntegralPredicateKernels<TPred>(*this); + } +}; + } } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp index af400f128a9..01e1c8e7bad 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp @@ -148,6 +148,14 @@ struct TLess : public TCompareArithmeticBinary<TLeft, TRight, TLess<TLeft, TRigh #endif }; +template<typename TLeft, typename TRight, typename TOutput> +struct TLessOp; + +template<typename TLeft, typename TRight> +struct TLessOp<TLeft, TRight, bool> : public TLess<TLeft, TRight, false> { + static constexpr bool DefaultNulls = true; +}; + template<typename TLeft, typename TRight, bool Aggr> struct TDiffDateLess : public TCompareArithmeticBinary<TLeft, TRight, TDiffDateLess<TLeft, TRight, Aggr>>, public TAggrLess { static bool Do(TLeft left, TRight right) @@ -273,5 +281,9 @@ void RegisterLess(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrLess, TCompareArgsOpt>(registry, aggrName); } +void RegisterLess(arrow::compute::FunctionRegistry& registry) { + AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TLessOp>>("Less")); +} + } // namespace NMiniKQL } // namespace NKikimr diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp index 37739a4e9ee..73e883534db 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp @@ -148,6 +148,14 @@ struct TLessOrEqual : public TCompareArithmeticBinary<TLeft, TRight, TLessOrEqua #endif }; +template<typename TLeft, typename TRight, typename TOutput> +struct TLessOrEqualOp; + +template<typename TLeft, typename TRight> +struct TLessOrEqualOp<TLeft, TRight, bool> : public TLessOrEqual<TLeft, TRight, false> { + static constexpr bool DefaultNulls = true; +}; + template<typename TLeft, typename TRight, bool Aggr> struct TDiffDateLessOrEqual : public TCompareArithmeticBinary<TLeft, TRight, TDiffDateLessOrEqual<TLeft, TRight, Aggr>>, public TAggrLessOrEqual { static bool Do(TLeft left, TRight right) @@ -273,5 +281,9 @@ void RegisterLessOrEqual(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrLessOrEqual, TCompareArgsOpt>(registry, aggrName); } +void RegisterLessOrEqual(arrow::compute::FunctionRegistry& registry) { + AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TLessOrEqualOp>>("LessOrEqual")); +} + } // namespace NMiniKQL } // namespace NKikimr diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp index d9d8991ca07..4fb26abf1ee 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp @@ -155,6 +155,14 @@ struct TNotEquals : public TCompareArithmeticBinary<TLeft, TRight, TNotEquals<TL #endif }; +template<typename TLeft, typename TRight, typename TOutput> +struct TNotEqualsOp; + +template<typename TLeft, typename TRight> +struct TNotEqualsOp<TLeft, TRight, bool> : public TNotEquals<TLeft, TRight, false> { + static constexpr bool DefaultNulls = true; +}; + template<typename TLeft, typename TRight, bool Aggr> struct TDiffDateNotEquals : public TCompareArithmeticBinary<TLeft, TRight, TDiffDateNotEquals<TLeft, TRight, Aggr>>, public TAggrNotEquals { static bool Do(TLeft left, TRight right) @@ -276,5 +284,10 @@ void RegisterNotEquals(IBuiltinFunctionRegistry& registry) { RegisterAggrCompareCustomOpt<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrNotEquals, TCompareArgsOpt>(registry, aggrName); } +void RegisterNotEquals(arrow::compute::FunctionRegistry& registry) { + AddFunction(registry, std::make_shared<TBinaryNumericPredicate<TNotEqualsOp>>("NotEquals")); +} + + } // namespace NMiniKQL } // namespace NKikimr |