diff options
author | a-romanov <Anton.Romanov@ydb.tech> | 2023-03-03 20:13:19 +0300 |
---|---|---|
committer | a-romanov <Anton.Romanov@ydb.tech> | 2023-03-03 20:13:19 +0300 |
commit | 83db9dc18d8bef463b09a922b3ea85e4c0bbd726 (patch) | |
tree | 4aeded6f19a41d56320f2bc7f79575e595dd1b69 | |
parent | 790e700992f79d88b2eebed52b7f43905896e034 (diff) | |
download | ydb-83db9dc18d8bef463b09a922b3ea85e4c0bbd726.tar.gz |
Fix aggr compare NaNs.
10 files changed, 370 insertions, 39 deletions
diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_compare_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_compare_ut.cpp index 8f51d66bc0d..05f1e5bd882 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_compare_ut.cpp +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_compare_ut.cpp @@ -425,10 +425,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLCompareTest) { UNIT_ASSERT(iterator.Next(item)); UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // == UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // != - UNIT_ASSERT(!item.GetElement(2).template Get<bool>()); // < - UNIT_ASSERT(!item.GetElement(3).template Get<bool>()); // <= - UNIT_ASSERT(item.GetElement(4).template Get<bool>()); // > - UNIT_ASSERT(item.GetElement(5).template Get<bool>()); // >= + UNIT_ASSERT(item.GetElement(2).template Get<bool>()); // < + UNIT_ASSERT(item.GetElement(3).template Get<bool>()); // <= + UNIT_ASSERT(!item.GetElement(4).template Get<bool>()); // > + UNIT_ASSERT(!item.GetElement(5).template Get<bool>()); // >= UNIT_ASSERT(iterator.Next(item)); UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // == @@ -449,14 +449,6 @@ Y_UNIT_TEST_SUITE(TMiniKQLCompareTest) { UNIT_ASSERT(iterator.Next(item)); UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // == UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // != - UNIT_ASSERT(!item.GetElement(2).template Get<bool>()); // < - UNIT_ASSERT(!item.GetElement(3).template Get<bool>()); // <= - UNIT_ASSERT(item.GetElement(4).template Get<bool>()); // > - UNIT_ASSERT(item.GetElement(5).template Get<bool>()); // >= - - UNIT_ASSERT(iterator.Next(item)); - UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // == - UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // != UNIT_ASSERT(item.GetElement(2).template Get<bool>()); // < UNIT_ASSERT(item.GetElement(3).template Get<bool>()); // <= UNIT_ASSERT(!item.GetElement(4).template Get<bool>()); // > @@ -465,10 +457,18 @@ Y_UNIT_TEST_SUITE(TMiniKQLCompareTest) { UNIT_ASSERT(iterator.Next(item)); UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // == UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // != - UNIT_ASSERT(item.GetElement(2).template Get<bool>()); // < - UNIT_ASSERT(item.GetElement(3).template Get<bool>()); // <= - UNIT_ASSERT(!item.GetElement(4).template Get<bool>()); // > - UNIT_ASSERT(!item.GetElement(5).template Get<bool>()); // >= + UNIT_ASSERT(!item.GetElement(2).template Get<bool>()); // < + UNIT_ASSERT(!item.GetElement(3).template Get<bool>()); // <= + UNIT_ASSERT(item.GetElement(4).template Get<bool>()); // > + UNIT_ASSERT(item.GetElement(5).template Get<bool>()); // >= + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(!item.GetElement(0).template Get<bool>()); // == + UNIT_ASSERT(item.GetElement(1).template Get<bool>()); // != + UNIT_ASSERT(!item.GetElement(2).template Get<bool>()); // < + UNIT_ASSERT(!item.GetElement(3).template Get<bool>()); // <= + UNIT_ASSERT(item.GetElement(4).template Get<bool>()); // > + UNIT_ASSERT(item.GetElement(5).template Get<bool>()); // >= UNIT_ASSERT(iterator.Next(item)); UNIT_ASSERT(item.GetElement(0).template Get<bool>()); // == @@ -985,7 +985,134 @@ Y_UNIT_TEST_SUITE(TMiniKQLCompareTest) { const auto graph = setup.BuildGraph(pgmReturn); const auto result = graph->GetValue(); UNBOXED_VALUE_STR_EQUAL(result, "1970-01-01T03:00:06,Africa/Asmara"); - } + } + + Y_UNIT_TEST_LLVM(TestAggrMinMaxFloats) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto data1 = pb.NewDataLiteral<float>(0.0f*HUGE_VALF); + const auto data2 = pb.NewDataLiteral<float>(HUGE_VALF); + const auto data3 = pb.NewDataLiteral<float>(3.14f); + const auto data4 = pb.NewDataLiteral<float>(-2.13f); + const auto data5 = pb.NewDataLiteral<float>(-HUGE_VALF); + const auto dataType = pb.NewDataType(NUdf::TDataType<float>::Id); + const auto list = pb.NewList(dataType, {data1, data2, data3, data4, data5}); + const auto pgmReturn = pb.FlatMap(list, + [&](TRuntimeNode left) { + return pb.Map(list, + [&](TRuntimeNode right) { + return pb.NewTuple({pb.AggrMin(left, right), pb.AggrMax(left, right)}); + }); + }); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(std::isnan(item.GetElement(0).Get<float>())); + UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>())); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), HUGE_VALF); + UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>())); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f); + UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>())); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f); + UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>())); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF); + UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>())); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), HUGE_VALF); + UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>())); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), HUGE_VALF); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f); + UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>())); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), 3.14f); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f); + UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>())); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -2.13f); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), -2.13f); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), -2.13f); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF); + UNIT_ASSERT(std::isnan(item.GetElement(1).Get<float>())); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), HUGE_VALF); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), 3.14f); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), -2.13f); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(0).Get<float>(), -HUGE_VALF); + UNIT_ASSERT_VALUES_EQUAL(item.GetElement(1).Get<float>(), -HUGE_VALF); + + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } } } diff --git a/ydb/library/yql/minikql/comp_nodes/ut/mkql_decimal_ut.cpp b/ydb/library/yql/minikql/comp_nodes/ut/mkql_decimal_ut.cpp index eb5ff4df71b..feba3674a47 100644 --- a/ydb/library/yql/minikql/comp_nodes/ut/mkql_decimal_ut.cpp +++ b/ydb/library/yql/minikql/comp_nodes/ut/mkql_decimal_ut.cpp @@ -918,6 +918,133 @@ Y_UNIT_TEST_SUITE(TMiniKQLDecimalTest) { UNIT_ASSERT(!iterator.Next(item)); } + Y_UNIT_TEST_LLVM(TestAggrMinMax) { + TSetup<LLVM> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + + const auto data1 = pb.NewDecimalLiteral(NYql::NDecimal::Nan(), 13, 2); + const auto data2 = pb.NewDecimalLiteral(+NYql::NDecimal::Inf(), 13, 2); + const auto data3 = pb.NewDecimalLiteral(314, 13, 2); + const auto data4 = pb.NewDecimalLiteral(-213, 13, 2); + const auto data5 = pb.NewDecimalLiteral(-NYql::NDecimal::Inf(), 13, 2); + const auto dataType = pb.NewDecimalType(13, 2); + const auto list = pb.NewList(dataType, {data1, data2, data3, data4, data5}); + const auto pgmReturn = pb.FlatMap(list, + [&](TRuntimeNode left) { + return pb.Map(list, + [&](TRuntimeNode right) { + return pb.NewTuple({pb.AggrMin(left, right), pb.AggrMax(left, right)}); + }); + }); + + const auto graph = setup.BuildGraph(pgmReturn); + const auto iterator = graph->GetValue().GetListIterator(); + NUdf::TUnboxedValue item; + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == NYql::NDecimal::Nan()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == +NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == 314); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -213); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Inf()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == 314); + UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -213); + UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == 314); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == 314); + UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == 314); + UNIT_ASSERT(item.GetElement(1).GetInt128() == 314); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -213); + UNIT_ASSERT(item.GetElement(1).GetInt128() == 314); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == 314); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -213); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -213); + UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -213); + UNIT_ASSERT(item.GetElement(1).GetInt128() == 314); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -213); + UNIT_ASSERT(item.GetElement(1).GetInt128() == -213); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == -213); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == NYql::NDecimal::Nan()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == +NYql::NDecimal::Inf()); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == 314); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == -213); + + UNIT_ASSERT(iterator.Next(item)); + UNIT_ASSERT(item.GetElement(0).GetInt128() == -NYql::NDecimal::Inf()); + UNIT_ASSERT(item.GetElement(1).GetInt128() == -NYql::NDecimal::Inf()); + + UNIT_ASSERT(!iterator.Next(item)); + UNIT_ASSERT(!iterator.Next(item)); + } + Y_UNIT_TEST_LLVM(TestAddSub) { TSetup<LLVM> setup; TProgramBuilder& pb = *setup.PgmBuilder; diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp index 65de097dd9f..f4f7b3675c1 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_equals.cpp @@ -36,8 +36,9 @@ Y_FORCE_INLINE bool Equals(T1 x, T2 y) { using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>; const auto l = static_cast<FT>(x); const auto r = static_cast<FT>(y); - if (Aggr && std::isunordered(l, r)) { - return std::isnan(l) == std::isnan(r); + if constexpr (Aggr) { + if (std::isunordered(l, r)) + return std::isnan(l) == std::isnan(r); } return l == r; } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp index 07ab7645e90..2e7a66e36d8 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater.cpp @@ -36,8 +36,9 @@ Y_FORCE_INLINE bool Greater(T1 x, T2 y) { using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>; const auto l = static_cast<FT>(x); const auto r = static_cast<FT>(y); - if (Aggr && std::isunordered(l, r)) { - return !std::isnan(l) && std::isnan(r); + if constexpr (Aggr) { + if (std::isunordered(l, r)) + return !std::isnan(r); } return l > r; } @@ -62,7 +63,7 @@ Value* GenGreaterFloats<false>(Value* lhs, Value* rhs, BasicBlock* block) { template <> Value* GenGreaterFloats<true>(Value* lhs, Value* rhs, BasicBlock* block) { const auto ugt = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UGT, lhs, rhs, "greater", block); - const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(lhs->getType(), 0.0), lhs, "ordered", block); + const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(rhs->getType(), 0.0), rhs, "ordered", block); return BinaryOperator::CreateAnd(ugt, ord, "and", block); } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp index 8f7fafb2c46..a97af279a1c 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_greater_or_equal.cpp @@ -36,8 +36,9 @@ Y_FORCE_INLINE bool GreaterOrEqual(T1 x, T2 y) { using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>; const auto l = static_cast<FT>(x); const auto r = static_cast<FT>(y); - if (Aggr && std::isunordered(l, r)) { - return !std::isnan(l) || std::isnan(r); + if constexpr (Aggr) { + if (std::isunordered(l, r)) + return std::isnan(l); } return l >= r; } @@ -62,7 +63,7 @@ Value* GenGreaterOrEqualFloats<false>(Value* lhs, Value* rhs, BasicBlock* block) template <> Value* GenGreaterOrEqualFloats<true>(Value* lhs, Value* rhs, BasicBlock* block) { const auto oge = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_OGE, lhs, rhs, "greater_or_equal", block); - const auto uno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(rhs->getType(), 0.0), rhs, "unordered", block); + const auto uno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(lhs->getType(), 0.0), lhs, "unordered", block); return BinaryOperator::CreateOr(oge, uno, "or", block); } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp index 2945bcea260..47b5c83964e 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less.cpp @@ -36,8 +36,9 @@ Y_FORCE_INLINE bool Less(T1 x, T2 y) { using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>; const auto l = static_cast<FT>(x); const auto r = static_cast<FT>(y); - if (Aggr && std::isunordered(l, r)) { - return std::isnan(l) && !std::isnan(r); + if constexpr (Aggr) { + if (std::isunordered(l, r)) + return !std::isnan(l); } return l < r; } @@ -62,7 +63,7 @@ Value* GenLessFloats<false>(Value* lhs, Value* rhs, BasicBlock* block) { template <> Value* GenLessFloats<true>(Value* lhs, Value* rhs, BasicBlock* block) { const auto ult = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ULT, lhs, rhs, "less", block); - const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(rhs->getType(), 0.0), rhs, "ordered", block); + const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(lhs->getType(), 0.0), lhs, "ordered", block); return BinaryOperator::CreateAnd(ult, ord, "and", block); } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp index 02f03d49cab..dd9d26ed82e 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_less_or_equal.cpp @@ -36,8 +36,9 @@ Y_FORCE_INLINE bool LessOrEqual(T1 x, T2 y) { using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>; const auto l = static_cast<FT>(x); const auto r = static_cast<FT>(y); - if (Aggr && std::isunordered(l, r)) { - return std::isnan(l) || !std::isnan(r); + if constexpr (Aggr) { + if (std::isunordered(l, r)) + return std::isnan(r); } return l <= r; } @@ -62,7 +63,7 @@ Value* GenLessOrEqualFloats<false>(Value* lhs, Value* rhs, BasicBlock* block) { template <> Value* GenLessOrEqualFloats<true>(Value* lhs, Value* rhs, BasicBlock* block) { const auto ole = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_OLE, lhs, rhs, "less_or_equal", block); - const auto uno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(lhs->getType(), 0.0), lhs, "unordered", block); + const auto uno = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UNO, ConstantFP::get(rhs->getType(), 0.0), rhs, "unordered", block); return BinaryOperator::CreateOr(ole, uno, "or", block); } diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_max.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_max.cpp index dc6cfe47780..137eb078804 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_max.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_max.cpp @@ -28,7 +28,7 @@ struct TMax : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TMax<TLeft, #ifndef MKQL_DISABLE_CODEGEN static Value* Gen(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block) { - if (std::is_floating_point<TOutput>()) { + if constexpr (std::is_floating_point<TOutput>()) { auto& context = ctx.Codegen->GetContext(); auto& module = ctx.Codegen->GetModule(); const auto fnType = FunctionType::get(GetTypeFor<TOutput>(context), {left->getType(), right->getType()}, false); @@ -46,7 +46,24 @@ struct TMax : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TMax<TLeft, }; template<typename TType> -using TAggrMax = TMax<TType, TType, TType>; +struct TFloatAggrMax : public TSimpleArithmeticBinary<TType, TType, TType, TFloatAggrMax<TType>> { + static TType Do(TType left, TType right) + { + return left > right || std::isnan(left) ? left : right; + } +#ifndef MKQL_DISABLE_CODEGEN + static Value* Gen(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block) + { + const auto ugt = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_UGT, left, right, "greater", block); + const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(right->getType(), 0.0), right, "ordered", block); + const auto both = BinaryOperator::CreateAnd(ugt, ord, "and", block); + return SelectInst::Create(both, left, right, "max", block); + } +#endif +}; + +template<typename TType> +using TAggrMax = std::conditional_t<std::is_floating_point<TType>::value, TFloatAggrMax<TType>, TMax<TType, TType, TType>>; template<typename TType> struct TTzMax : public TSelectArithmeticBinaryCopyTimezone<TType, TTzMax<TType>> { @@ -134,6 +151,25 @@ struct TDecimalMax { }; template<NUdf::EDataSlot Slot> +struct TDecimalAggrMax { + static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) { + const auto lv = left.GetInt128(); + const auto rv = right.GetInt128(); + return lv > rv ? left : right; + } +#ifndef MKQL_DISABLE_CODEGEN + static Value* Generate(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block) + { + const auto l = GetterForInt128(left, block); + const auto r = GetterForInt128(right, block); + const auto greater = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, l, r, "greater", block); + const auto res = SelectInst::Create(greater, left, right, "max", block); + return res; + } +#endif +}; + +template<NUdf::EDataSlot Slot> struct TCustomMax { static NUdf::TUnboxedValuePod Execute(NUdf::TUnboxedValuePod left, NUdf::TUnboxedValuePod right) { const bool r = CompareCustoms<Slot>(left, right) < 0; @@ -175,7 +211,7 @@ void RegisterAggrMax(IBuiltinFunctionRegistry& registry) { RegisterDatetimeAggregateFunction<TAggrMax, TBinaryArgsSameOpt>(registry, "AggrMax"); RegisterTzDatetimeAggregateFunction<TAggrTzMax, TBinaryArgsSameOpt>(registry, "AggrMax"); - RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TDecimal>, TDecimalMax, TBinaryArgsSameOpt>(registry, "AggrMax"); + RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrMax, TBinaryArgsSameOpt>(registry, "AggrMax"); RegisterCustomAggregateFunction<NUdf::TDataType<char*>, TCustomMax, TBinaryArgsSameOpt>(registry, "AggrMax"); RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TUtf8>, TCustomMax, TBinaryArgsSameOpt>(registry, "AggrMax"); diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_min.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_min.cpp index 4fb61f6554f..f3049f0ef02 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_min.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_min.cpp @@ -28,7 +28,7 @@ struct TMin : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TMin<TLeft, #ifndef MKQL_DISABLE_CODEGEN static Value* Gen(Value* left, Value* right, const TCodegenContext& ctx, BasicBlock*& block) { - if (std::is_floating_point<TOutput>()) { + if constexpr (std::is_floating_point<TOutput>()) { auto& context = ctx.Codegen->GetContext(); auto& module = ctx.Codegen->GetModule(); const auto fnType = FunctionType::get(GetTypeFor<TOutput>(context), {left->getType(), right->getType()}, false); @@ -45,6 +45,23 @@ struct TMin : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TMin<TLeft, #endif }; +template<typename TType> +struct TFloatAggrMin : public TSimpleArithmeticBinary<TType, TType, TType, TFloatAggrMin<TType>> { + static TType Do(TType left, TType right) + { + return left < right || std::isnan(right) ? left : right; + } +#ifndef MKQL_DISABLE_CODEGEN + static Value* Gen(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block) + { + const auto ult = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ULT, left, right, "less", block); + const auto ord = CmpInst::Create(Instruction::FCmp, FCmpInst::FCMP_ORD, ConstantFP::get(left->getType(), 0.0), left, "ordered", block); + const auto both = BinaryOperator::CreateAnd(ult, ord, "and", block); + return SelectInst::Create(both, left, right, "min", block); + } +#endif +}; + template<NUdf::EDataSlot Slot> struct TDecimalMin { static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) { @@ -92,8 +109,26 @@ struct TDecimalMin { #endif }; +template<NUdf::EDataSlot Slot> +struct TDecimalAggrMin { + static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right) { + const auto lv = left.GetInt128(); + const auto rv = right.GetInt128(); + return lv < rv ? left : right; + } +#ifndef MKQL_DISABLE_CODEGEN + static Value* Generate(Value* left, Value* right, const TCodegenContext&, BasicBlock*& block) + { + const auto l = GetterForInt128(left, block); + const auto r = GetterForInt128(right, block); + const auto less = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, l, r, "less", block); + return SelectInst::Create(less, left, right, "min", block); + } +#endif +}; + template<typename TType> -using TAggrMin = TMin<TType, TType, TType>; +using TAggrMin = std::conditional_t<std::is_floating_point<TType>::value, TFloatAggrMin<TType>, TMin<TType, TType, TType>>; template<typename TType> struct TTzMin : public TSelectArithmeticBinaryCopyTimezone<TType, TTzMin<TType>> { @@ -175,7 +210,7 @@ void RegisterAggrMin(IBuiltinFunctionRegistry& registry) { RegisterDatetimeAggregateFunction<TAggrMin, TBinaryArgsSameOpt>(registry, "AggrMin"); RegisterTzDatetimeAggregateFunction<TAggrTzMin, TBinaryArgsSameOpt>(registry, "AggrMin"); - RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TDecimal>, TDecimalMin, TBinaryArgsSameOpt>(registry, "AggrMin"); + RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TDecimal>, TDecimalAggrMin, TBinaryArgsSameOpt>(registry, "AggrMin"); RegisterCustomAggregateFunction<NUdf::TDataType<char*>, TCustomMin, TBinaryArgsSameOpt>(registry, "AggrMin"); RegisterCustomAggregateFunction<NUdf::TDataType<NUdf::TUtf8>, TCustomMin, TBinaryArgsSameOpt>(registry, "AggrMin"); diff --git a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp index 4c73392ae6d..79726e7d42b 100644 --- a/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp +++ b/ydb/library/yql/minikql/invoke_builtins/mkql_builtins_not_equals.cpp @@ -36,8 +36,9 @@ Y_FORCE_INLINE bool NotEquals(T1 x, T2 y) { using FT = std::conditional_t<(sizeof(F1) > sizeof(F2)), F1, F2>; const auto l = static_cast<FT>(x); const auto r = static_cast<FT>(y); - if (Aggr && std::isunordered(l, r)) { - return std::isnan(l) != std::isnan(r); + if constexpr (Aggr) { + if (std::isunordered(l, r)) + return std::isnan(l) != std::isnan(r); } return l != r; } |