aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authora-romanov <Anton.Romanov@ydb.tech>2023-03-03 20:13:19 +0300
committera-romanov <Anton.Romanov@ydb.tech>2023-03-03 20:13:19 +0300
commit83db9dc18d8bef463b09a922b3ea85e4c0bbd726 (patch)
tree4aeded6f19a41d56320f2bc7f79575e595dd1b69
parent790e700992f79d88b2eebed52b7f43905896e034 (diff)
downloadydb-83db9dc18d8bef463b09a922b3ea85e4c0bbd726.tar.gz
Fix aggr compare NaNs.
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_compare_ut.cpp161
-rw-r--r--ydb/library/yql/minikql/comp_nodes/ut/mkql_decimal_ut.cpp127
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp5
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp7
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp7
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp7
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp7
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_max.cpp42
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_min.cpp41
-rw-r--r--ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp5
10 files changed, 370 insertions, 39 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_compare_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_compare_ut.cpp
index 8f51d66bc0d..05f1e5bd882 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_compare_ut.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_compare_ut.cpp
@@ -425,10 +425,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLCompareTest) {
UNIT_ASSERT(iterator.Next(item));
UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // ==
UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // !=
- UNIT_ASSERT(!item.GetElement(2).template Get<bool>()); // <
- UNIT_ASSERT(!item.GetElement(3).template Get<bool>()); // <=
- UNIT_ASSERT(item.GetElement(4).template Get<bool>()); // >
- UNIT_ASSERT(item.GetElement(5).template Get<bool>()); // >=
+ UNIT_ASSERT(item.GetElement(2).template Get<bool>()); // <
+ UNIT_ASSERT(item.GetElement(3).template Get<bool>()); // <=
+ UNIT_ASSERT(!item.GetElement(4).template Get<bool>()); // >
+ UNIT_ASSERT(!item.GetElement(5).template Get<bool>()); // >=
UNIT_ASSERT(iterator.Next(item));
UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // ==
@@ -449,14 +449,6 @@ Y_UNIT_TEST_SUITE(TMiniKQLCompareTest) {
UNIT_ASSERT(iterator.Next(item));
UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // ==
UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // !=
- UNIT_ASSERT(!item.GetElement(2).template Get<bool>()); // <
- UNIT_ASSERT(!item.GetElement(3).template Get<bool>()); // <=
- UNIT_ASSERT(item.GetElement(4).template Get<bool>()); // >
- UNIT_ASSERT(item.GetElement(5).template Get<bool>()); // >=
-
- UNIT_ASSERT(iterator.Next(item));
- UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // ==
- UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // !=
UNIT_ASSERT(item.GetElement(2).template Get<bool>()); // <
UNIT_ASSERT(item.GetElement(3).template Get<bool>()); // <=
UNIT_ASSERT(!item.GetElement(4).template Get<bool>()); // >
@@ -465,10 +457,18 @@ Y_UNIT_TEST_SUITE(TMiniKQLCompareTest) {
UNIT_ASSERT(iterator.Next(item));
UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // ==
UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // !=
- UNIT_ASSERT(item.GetElement(2).template Get<bool>()); // <
- UNIT_ASSERT(item.GetElement(3).template Get<bool>()); // <=
- UNIT_ASSERT(!item.GetElement(4).template Get<bool>()); // >
- UNIT_ASSERT(!item.GetElement(5).template Get<bool>()); // >=
+ UNIT_ASSERT(!item.GetElement(2).template Get<bool>()); // <
+ UNIT_ASSERT(!item.GetElement(3).template Get<bool>()); // <=
+ UNIT_ASSERT(item.GetElement(4).template Get<bool>()); // >
+ UNIT_ASSERT(item.GetElement(5).template Get<bool>()); // >=
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // ==
+ UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // !=
+ UNIT_ASSERT(!item.GetElement(2).template Get<bool>()); // <
+ UNIT_ASSERT(!item.GetElement(3).template Get<bool>()); // <=
+ UNIT_ASSERT(item.GetElement(4).template Get<bool>()); // >
+ UNIT_ASSERT(item.GetElement(5).template Get<bool>()); // >=
UNIT_ASSERT(iterator.Next(item));
UNIT_ASSERT(item.GetElement(0).template Get<bool>()); // ==
@@ -985,7 +985,134 @@ Y_UNIT_TEST_SUITE(TMiniKQLCompareTest) {
const auto graph = setup.BuildGraph(pgmReturn);
const auto result = graph->GetValue();
UNBOXED_VALUE_STR_EQUAL(result, "1970-01-01T03:00:06,Africa/Asmara");
- }
+ }
+
+ Y_UNIT_TEST_LLVM(TestAggrMinMaxFloats) {
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto data1 = pb.NewDataLiteral<float>(0.0f*HUGE_VALF);
+ const auto data2 = pb.NewDataLiteral<float>(HUGE_VALF);
+ const auto data3 = pb.NewDataLiteral<float>(3.14f);
+ const auto data4 = pb.NewDataLiteral<float>(-2.13f);
+ const auto data5 = pb.NewDataLiteral<float>(-HUGE_VALF);
+ const auto dataType = pb.NewDataType(NUdf::TDataType<float>::Id);
+ const auto list = pb.NewList(dataType, {data1, data2, data3, data4, data5});
+ const auto pgmReturn = pb.FlatMap(list,
+ [&](TRuntimeNode left) {
+ return pb.Map(list,
+ [&](TRuntimeNode right) {
+ return pb.NewTuple({pb.AggrMin(left, right), pb.AggrMax(left, right)});
+ });
+ });
+
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue().GetListIterator();
+ NUdf::TUnboxedValue item;
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(std::isnan(item.GetElement(0).Get<float>()));
+ UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>()));
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), HUGE_VALF);
+ UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>()));
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f);
+ UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>()));
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f);
+ UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>()));
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF);
+ UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>()));
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), HUGE_VALF);
+ UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>()));
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), HUGE_VALF);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f);
+ UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>()));
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f);
+ UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>()));
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), -2.13f);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), -2.13f);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF);
+ UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>()));
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), -2.13f);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF);
+ UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), -HUGE_VALF);
+
+ UNIT_ASSERT(!iterator.Next(item));
+ UNIT_ASSERT(!iterator.Next(item));
+ }
}
}
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_decimal_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_decimal_ut.cpp
index eb5ff4df71b..feba3674a47 100644
--- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_decimal_ut.cpp
+++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_decimal_ut.cpp
@@ -918,6 +918,133 @@ Y_UNIT_TEST_SUITE(TMiniKQLDecimalTest) {
UNIT_ASSERT(!iterator.Next(item));
}
+ Y_UNIT_TEST_LLVM(TestAggrMinMax) {
+ TSetup<LLVM> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+
+ const auto data1 = pb.NewDecimalLiteral(NYql::NDecimal::Nan(), 13, 2);
+ const auto data2 = pb.NewDecimalLiteral(+NYql::NDecimal::Inf(), 13, 2);
+ const auto data3 = pb.NewDecimalLiteral(314, 13, 2);
+ const auto data4 = pb.NewDecimalLiteral(-213, 13, 2);
+ const auto data5 = pb.NewDecimalLiteral(-NYql::NDecimal::Inf(), 13, 2);
+ const auto dataType = pb.NewDecimalType(13, 2);
+ const auto list = pb.NewList(dataType, {data1, data2, data3, data4, data5});
+ const auto pgmReturn = pb.FlatMap(list,
+ [&](TRuntimeNode left) {
+ return pb.Map(list,
+ [&](TRuntimeNode right) {
+ return pb.NewTuple({pb.AggrMin(left, right), pb.AggrMax(left, right)});
+ });
+ });
+
+ const auto graph = setup.BuildGraph(pgmReturn);
+ const auto iterator = graph->GetValue().GetListIterator();
+ NUdf::TUnboxedValue item;
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == NYql::NDecimal::Nan());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == +NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == 314);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -213);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Inf());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == 314);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -213);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == 314);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == 314);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == 314);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == 314);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -213);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == 314);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == 314);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -213);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -213);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -213);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == 314);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -213);
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == -213);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == -213);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf());
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == 314);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == -213);
+
+ UNIT_ASSERT(iterator.Next(item));
+ UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf());
+ UNIT_ASSERT(item.GetElement(1).GetInt128() == -NYql::NDecimal::Inf());
+
+ UNIT_ASSERT(!iterator.Next(item));
+ UNIT_ASSERT(!iterator.Next(item));
+ }
+
Y_UNIT_TEST_LLVM(TestAddSub) {
TSetup<LLVM> setup;
TProgramBuilder& pb = *setup.PgmBuilder;
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp
index 65de097dd9f..f4f7b3675c1 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp
@@ -36,8 +36,9 @@ Y_FORCE_INLINE bool Equals(T1 x, T2 y) {
using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>;
const auto l = static_cast<FT>(x);
const auto r = static_cast<FT>(y);
- if (Aggr && std::isunordered(l, r)) {
- return std::isnan(l) == std::isnan(r);
+ if constexpr (Aggr) {
+ if (std::isunordered(l, r))
+ return std::isnan(l) == std::isnan(r);
}
return l == r;
}
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp
index 07ab7645e90..2e7a66e36d8 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp
@@ -36,8 +36,9 @@ Y_FORCE_INLINE bool Greater(T1 x, T2 y) {
using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>;
const auto l = static_cast<FT>(x);
const auto r = static_cast<FT>(y);
- if (Aggr && std::isunordered(l, r)) {
- return !std::isnan(l) && std::isnan(r);
+ if constexpr (Aggr) {
+ if (std::isunordered(l, r))
+ return !std::isnan(r);
}
return l > r;
}
@@ -62,7 +63,7 @@ Value* GenGreaterFloats<false>(Value* lhs, Value* rhs, BasicBlock* block) {
template <>
Value* GenGreaterFloats<true>(Value* lhs, Value* rhs, BasicBlock* block) {
const auto ugt = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UGT, lhs, rhs, "greater", block);
- const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(lhs->getType(), 0.0), lhs, "ordered", block);
+ const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(rhs->getType(), 0.0), rhs, "ordered", block);
return BinaryOperator::CreateAnd(ugt, ord, "and", block);
}
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp
index 8f7fafb2c46..a97af279a1c 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp
@@ -36,8 +36,9 @@ Y_FORCE_INLINE bool GreaterOrEqual(T1 x, T2 y) {
using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>;
const auto l = static_cast<FT>(x);
const auto r = static_cast<FT>(y);
- if (Aggr && std::isunordered(l, r)) {
- return !std::isnan(l) || std::isnan(r);
+ if constexpr (Aggr) {
+ if (std::isunordered(l, r))
+ return std::isnan(l);
}
return l >= r;
}
@@ -62,7 +63,7 @@ Value* GenGreaterOrEqualFloats<false>(Value* lhs, Value* rhs, BasicBlock* block)
template <>
Value* GenGreaterOrEqualFloats<true>(Value* lhs, Value* rhs, BasicBlock* block) {
const auto oge = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_OGE, lhs, rhs, "greater_or_equal", block);
- const auto uno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(rhs->getType(), 0.0), rhs, "unordered", block);
+ const auto uno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(lhs->getType(), 0.0), lhs, "unordered", block);
return BinaryOperator::CreateOr(oge, uno, "or", block);
}
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp
index 2945bcea260..47b5c83964e 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp
@@ -36,8 +36,9 @@ Y_FORCE_INLINE bool Less(T1 x, T2 y) {
using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>;
const auto l = static_cast<FT>(x);
const auto r = static_cast<FT>(y);
- if (Aggr && std::isunordered(l, r)) {
- return std::isnan(l) && !std::isnan(r);
+ if constexpr (Aggr) {
+ if (std::isunordered(l, r))
+ return !std::isnan(l);
}
return l < r;
}
@@ -62,7 +63,7 @@ Value* GenLessFloats<false>(Value* lhs, Value* rhs, BasicBlock* block) {
template <>
Value* GenLessFloats<true>(Value* lhs, Value* rhs, BasicBlock* block) {
const auto ult = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ULT, lhs, rhs, "less", block);
- const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(rhs->getType(), 0.0), rhs, "ordered", block);
+ const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(lhs->getType(), 0.0), lhs, "ordered", block);
return BinaryOperator::CreateAnd(ult, ord, "and", block);
}
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp
index 02f03d49cab..dd9d26ed82e 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp
@@ -36,8 +36,9 @@ Y_FORCE_INLINE bool LessOrEqual(T1 x, T2 y) {
using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>;
const auto l = static_cast<FT>(x);
const auto r = static_cast<FT>(y);
- if (Aggr && std::isunordered(l, r)) {
- return std::isnan(l) || !std::isnan(r);
+ if constexpr (Aggr) {
+ if (std::isunordered(l, r))
+ return std::isnan(r);
}
return l <= r;
}
@@ -62,7 +63,7 @@ Value* GenLessOrEqualFloats<false>(Value* lhs, Value* rhs, BasicBlock* block) {
template <>
Value* GenLessOrEqualFloats<true>(Value* lhs, Value* rhs, BasicBlock* block) {
const auto ole = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_OLE, lhs, rhs, "less_or_equal", block);
- const auto uno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(lhs->getType(), 0.0), lhs, "unordered", block);
+ const auto uno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(rhs->getType(), 0.0), rhs, "unordered", block);
return BinaryOperator::CreateOr(ole, uno, "or", block);
}
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_max.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_max.cpp
index dc6cfe47780..137eb078804 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_max.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_max.cpp
@@ -28,7 +28,7 @@ struct TMax : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TMax<TLeft,
#ifndef MKQL_DISABLE_CODEGEN
static Value* Gen(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
{
- if (std::is_floating_point<TOutput>()) {
+ if constexpr (std::is_floating_point<TOutput>()) {
auto& context = ctx.Codegen->GetContext();
auto& module = ctx.Codegen->GetModule();
const auto fnType = FunctionType::get(GetTypeFor<TOutput>(context), {left->getType(), right->getType()}, false);
@@ -46,7 +46,24 @@ struct TMax : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TMax<TLeft,
};
template<typename TType>
-using TAggrMax = TMax<TType, TType, TType>;
+struct TFloatAggrMax : public TSimpleArithmeticBinary<TType, TType, TType, TFloatAggrMax<TType>> {
+ static TType Do(TType left, TType right)
+ {
+ return left > right || std::isnan(left) ? left : right;
+ }
+#ifndef MKQL_DISABLE_CODEGEN
+ static Value* Gen(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block)
+ {
+ const auto ugt = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UGT, left, right, "greater", block);
+ const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(right->getType(), 0.0), right, "ordered", block);
+ const auto both = BinaryOperator::CreateAnd(ugt, ord, "and", block);
+ return SelectInst::Create(both, left, right, "max", block);
+ }
+#endif
+};
+
+template<typename TType>
+using TAggrMax = std::conditional_t<std::is_floating_point<TType>::value, TFloatAggrMax<TType>, TMax<TType, TType, TType>>;
template<typename TType>
struct TTzMax : public TSelectArithmeticBinaryCopyTimezone<TType, TTzMax<TType>> {
@@ -134,6 +151,25 @@ struct TDecimalMax {
};
template<NUdf::EDataSlot Slot>
+struct TDecimalAggrMax {
+ static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) {
+ const auto lv = left.GetInt128();
+ const auto rv = right.GetInt128();
+ return lv > rv ? left : right;
+ }
+#ifndef MKQL_DISABLE_CODEGEN
+ static Value* Generate(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block)
+ {
+ const auto l = GetterForInt128(left, block);
+ const auto r = GetterForInt128(right, block);
+ const auto greater = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, l, r, "greater", block);
+ const auto res = SelectInst::Create(greater, left, right, "max", block);
+ return res;
+ }
+#endif
+};
+
+template<NUdf::EDataSlot Slot>
struct TCustomMax {
static NUdf::TUnboxedValuePod Execute(NUdf::TUnboxedValuePod left, NUdf::TUnboxedValuePod right) {
const bool r = CompareCustoms<Slot>(left, right) < 0;
@@ -175,7 +211,7 @@ void RegisterAggrMax(IBuiltinFunctionRegistry& registry) {
RegisterDatetimeAggregateFunction<TAggrMax, TBinaryArgsSameOpt>(registry, "AggrMax");
RegisterTzDatetimeAggregateFunction<TAggrTzMax, TBinaryArgsSameOpt>(registry, "AggrMax");
- RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TDecimal>, TDecimalMax, TBinaryArgsSameOpt>(registry, "AggrMax");
+ RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrMax, TBinaryArgsSameOpt>(registry, "AggrMax");
RegisterCustomAggregateFunction<NUdf::TDataType<char*>, TCustomMax, TBinaryArgsSameOpt>(registry, "AggrMax");
RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TUtf8>, TCustomMax, TBinaryArgsSameOpt>(registry, "AggrMax");
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_min.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_min.cpp
index 4fb61f6554f..f3049f0ef02 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_min.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_min.cpp
@@ -28,7 +28,7 @@ struct TMin : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TMin<TLeft,
#ifndef MKQL_DISABLE_CODEGEN
static Value* Gen(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block)
{
- if (std::is_floating_point<TOutput>()) {
+ if constexpr (std::is_floating_point<TOutput>()) {
auto& context = ctx.Codegen->GetContext();
auto& module = ctx.Codegen->GetModule();
const auto fnType = FunctionType::get(GetTypeFor<TOutput>(context), {left->getType(), right->getType()}, false);
@@ -45,6 +45,23 @@ struct TMin : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TMin<TLeft,
#endif
};
+template<typename TType>
+struct TFloatAggrMin : public TSimpleArithmeticBinary<TType, TType, TType, TFloatAggrMin<TType>> {
+ static TType Do(TType left, TType right)
+ {
+ return left < right || std::isnan(right) ? left : right;
+ }
+#ifndef MKQL_DISABLE_CODEGEN
+ static Value* Gen(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block)
+ {
+ const auto ult = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ULT, left, right, "less", block);
+ const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(left->getType(), 0.0), left, "ordered", block);
+ const auto both = BinaryOperator::CreateAnd(ult, ord, "and", block);
+ return SelectInst::Create(both, left, right, "min", block);
+ }
+#endif
+};
+
template<NUdf::EDataSlot Slot>
struct TDecimalMin {
static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) {
@@ -92,8 +109,26 @@ struct TDecimalMin {
#endif
};
+template<NUdf::EDataSlot Slot>
+struct TDecimalAggrMin {
+ static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) {
+ const auto lv = left.GetInt128();
+ const auto rv = right.GetInt128();
+ return lv < rv ? left : right;
+ }
+#ifndef MKQL_DISABLE_CODEGEN
+ static Value* Generate(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block)
+ {
+ const auto l = GetterForInt128(left, block);
+ const auto r = GetterForInt128(right, block);
+ const auto less = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, l, r, "less", block);
+ return SelectInst::Create(less, left, right, "min", block);
+ }
+#endif
+};
+
template<typename TType>
-using TAggrMin = TMin<TType, TType, TType>;
+using TAggrMin = std::conditional_t<std::is_floating_point<TType>::value, TFloatAggrMin<TType>, TMin<TType, TType, TType>>;
template<typename TType>
struct TTzMin : public TSelectArithmeticBinaryCopyTimezone<TType, TTzMin<TType>> {
@@ -175,7 +210,7 @@ void RegisterAggrMin(IBuiltinFunctionRegistry& registry) {
RegisterDatetimeAggregateFunction<TAggrMin, TBinaryArgsSameOpt>(registry, "AggrMin");
RegisterTzDatetimeAggregateFunction<TAggrTzMin, TBinaryArgsSameOpt>(registry, "AggrMin");
- RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TDecimal>, TDecimalMin, TBinaryArgsSameOpt>(registry, "AggrMin");
+ RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrMin, TBinaryArgsSameOpt>(registry, "AggrMin");
RegisterCustomAggregateFunction<NUdf::TDataType<char*>, TCustomMin, TBinaryArgsSameOpt>(registry, "AggrMin");
RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TUtf8>, TCustomMin, TBinaryArgsSameOpt>(registry, "AggrMin");
diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp
index 4c73392ae6d..79726e7d42b 100644
--- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp
+++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp
@@ -36,8 +36,9 @@ Y_FORCE_INLINE bool NotEquals(T1 x, T2 y) {
using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>;
const auto l = static_cast<FT>(x);
const auto r = static_cast<FT>(y);
- if (Aggr && std::isunordered(l, r)) {
- return std::isnan(l) != std::isnan(r);
+ if constexpr (Aggr) {
+ if (std::isunordered(l, r))
+ return std::isnan(l) != std::isnan(r);
}
return l != r;
}