diff options
author | aneporada <aneporada@yandex-team.com> | 2025-06-03 12:56:43 +0300 |
---|---|---|
committer | aneporada <aneporada@yandex-team.com> | 2025-06-03 13:14:17 +0300 |
commit | b51b122d7876de2baef5405f3c7da3b10cf99b26 (patch) | |
tree | 6d90d7b4978c404ebde32b05d6b3a7ab1b1f24a2 | |
parent | 3e8add594e2ff4cf7ff8f552e9a6e285f4d6c372 (diff) | |
download | ydb-b51b122d7876de2baef5405f3c7da3b10cf99b26.tar.gz |
Support scalar-only logical ops
commit_hash:2be000baa1e203ec9b4bab5a4d236abc64609376
5 files changed, 185 insertions, 5 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_block_logical.cpp b/yql/essentials/minikql/comp_nodes/mkql_block_logical.cpp index df89388fd9d..968254a33c1 100644 --- a/yql/essentials/minikql/comp_nodes/mkql_block_logical.cpp +++ b/yql/essentials/minikql/comp_nodes/mkql_block_logical.cpp @@ -63,8 +63,10 @@ public: arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const { auto firstDatum = batch.values[0]; auto secondDatum = batch.values[1]; - MKQL_ENSURE(!firstDatum.is_scalar() || !secondDatum.is_scalar(), "Expected at least one array"); - + if (firstDatum.is_scalar() && secondDatum.is_scalar()) { + *res = CalcScalarScalar(firstDatum, secondDatum); + return arrow::Status::OK(); + } if (IsAllEqualsTo(firstDatum, false)) { // false AND ... = false if (firstDatum.is_array()) { @@ -104,6 +106,30 @@ public: } private: + arrow::Datum CalcScalarScalar(const arrow::Datum& firstDatum, const arrow::Datum& secondDatum) const { + const auto& first = firstDatum.scalar_as<arrow::UInt8Scalar>(); + const auto& second = secondDatum.scalar_as<arrow::UInt8Scalar>(); + + if (first.is_valid && second.is_valid) { + bool result = bool((first.value & second.value) & 1u); + return MakeScalarDatum(result); + } + + if (!first.is_valid && !second.is_valid) { + return firstDatum; + } + + if (!first.is_valid) { + // null and true -> null + // null and false -> false + return second.value ? firstDatum : secondDatum; + } else { + // true and null -> null + // false and null -> false + return first.value ? secondDatum : firstDatum; + } + } + arrow::Datum CalcScalarArray(arrow::MemoryPool* pool, ui8 value, bool valid, const std::shared_ptr<arrow::ArrayData>& arr) const { bool first_true = valid && value; bool first_false = valid && !value; @@ -198,8 +224,10 @@ public: arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const { auto firstDatum = batch.values[0]; auto secondDatum = batch.values[1]; - MKQL_ENSURE(!firstDatum.is_scalar() || !secondDatum.is_scalar(), "Expected at least one array"); - + if (firstDatum.is_scalar() && secondDatum.is_scalar()) { + *res = CalcScalarScalar(firstDatum, secondDatum); + return arrow::Status::OK(); + } if (IsAllEqualsTo(firstDatum, true)) { // true OR ... = true if (firstDatum.is_array()) { @@ -239,6 +267,30 @@ public: } private: + arrow::Datum CalcScalarScalar(const arrow::Datum& firstDatum, const arrow::Datum& secondDatum) const { + const auto& first = firstDatum.scalar_as<arrow::UInt8Scalar>(); + const auto& second = secondDatum.scalar_as<arrow::UInt8Scalar>(); + + if (first.is_valid && second.is_valid) { + bool result = bool((first.value | second.value) & 1u); + return MakeScalarDatum(result); + } + + if (!first.is_valid && !second.is_valid) { + return firstDatum; + } + + if (!first.is_valid) { + // null or true -> true + // null or false -> null + return second.value ? secondDatum : firstDatum; + } else { + // true or null -> true + // false or null -> null + return first.value ? firstDatum : secondDatum; + } + } + arrow::Datum CalcScalarArray(arrow::MemoryPool* pool, ui8 value, bool valid, const std::shared_ptr<arrow::ArrayData>& arr) const { bool first_true = valid && value; bool first_false = valid && !value; @@ -334,7 +386,10 @@ public: arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const { auto firstDatum = batch.values[0]; auto secondDatum = batch.values[1]; - MKQL_ENSURE(!firstDatum.is_scalar() || !secondDatum.is_scalar(), "Expected at least one array"); + if (firstDatum.is_scalar() && secondDatum.is_scalar()) { + *res = CalcScalarScalar(firstDatum, secondDatum); + return arrow::Status::OK(); + } if (firstDatum.null_count() == firstDatum.length()) { if (firstDatum.is_array()) { *res = firstDatum; @@ -369,6 +424,18 @@ public: } private: + arrow::Datum CalcScalarScalar(const arrow::Datum& firstDatum, const arrow::Datum& secondDatum) const { + const auto& first = firstDatum.scalar_as<arrow::UInt8Scalar>(); + const auto& second = secondDatum.scalar_as<arrow::UInt8Scalar>(); + + if (first.is_valid && second.is_valid) { + bool result = bool((first.value ^ second.value) & 1u); + return MakeScalarDatum(result); + } + + return first.is_valid ? secondDatum : firstDatum; + } + arrow::Datum CalcScalarArray(arrow::MemoryPool* pool, ui8 value, const std::shared_ptr<arrow::ArrayData>& arr) const { std::shared_ptr<arrow::Buffer> bitmap = CopyBitmap(pool, arr->buffers[0], arr->offset, arr->length); std::shared_ptr<arrow::Buffer> data = ARROW_RESULT(arrow::AllocateBuffer(arr->length, pool)); @@ -402,6 +469,11 @@ class TNotBlockExec { public: arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const { const auto& input = batch.values[0]; + if (input.is_scalar()) { + const auto& arg = input.scalar_as<arrow::UInt8Scalar>(); + *res = arg.is_valid ? MakeScalarDatum(bool(~arg.value & 1u)) : input; + return arrow::Status::OK(); + } MKQL_ENSURE(input.is_array(), "Expected array"); const auto& arr = *input.array(); if (arr.GetNullCount() == arr.length) { diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_block_logical_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_block_logical_ut.cpp new file mode 100644 index 00000000000..bb729a95506 --- /dev/null +++ b/yql/essentials/minikql/comp_nodes/ut/mkql_block_logical_ut.cpp @@ -0,0 +1,101 @@ +#include "mkql_computation_node_ut.h" + +#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h> +#include <yql/essentials/minikql/computation/mkql_block_builder.h> + +namespace NKikimr { +namespace NMiniKQL { + +namespace { + +template<typename Op> +TMaybe<bool> ScalarOp(TMaybe<bool> arg1, TMaybe<bool> arg2, const Op& op) { + TSetup<false> setup; + TProgramBuilder& pb = *setup.PgmBuilder; + auto arg1val = arg1 ? pb.NewDataLiteral<bool>(*arg1) : pb.NewEmptyOptionalDataLiteral(NUdf::TDataType<bool>::Id); + auto arg2val = arg2 ? pb.NewDataLiteral<bool>(*arg2) : pb.NewEmptyOptionalDataLiteral(NUdf::TDataType<bool>::Id); + TRuntimeNode root = op(pb, pb.AsScalar(arg1val), pb.AsScalar(arg2val)); + const auto graph = setup.BuildGraph(root); + NYql::NUdf::TUnboxedValue result = graph->GetValue(); + auto datum = TArrowBlock::From(result).GetDatum(); + UNIT_ASSERT(datum.is_scalar()); + const auto& scalar = datum.scalar_as<arrow::UInt8Scalar>(); + if (!scalar.is_valid) { + return {}; + } + UNIT_ASSERT(scalar.value == 0 || scalar.value == 1); + return bool(scalar.value); +} + +} //namespace + +Y_UNIT_TEST_SUITE(TMiniKQLBlockLogicalTest) { + +Y_UNIT_TEST(ScalarAnd) { + auto op = [](TProgramBuilder& pb, TRuntimeNode arg1, TRuntimeNode arg2) { + return pb.BlockAnd(arg1, arg2); + }; + + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, false, op), false); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, true, op), false); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, false, op), false); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, true, op), true); + + UNIT_ASSERT_VALUES_EQUAL(ScalarOp({}, false, op), false); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, {}, op), false); + + UNIT_ASSERT(ScalarOp({}, {}, op).Empty()); + UNIT_ASSERT(ScalarOp(true, {}, op).Empty()); + UNIT_ASSERT(ScalarOp({}, true, op).Empty()); +} + +Y_UNIT_TEST(ScalarOr) { + auto op = [](TProgramBuilder& pb, TRuntimeNode arg1, TRuntimeNode arg2) { + return pb.BlockOr(arg1, arg2); + }; + + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, false, op), false); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, true, op), true); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, false, op), true); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, true, op), true); + + UNIT_ASSERT_VALUES_EQUAL(ScalarOp({}, true, op), true); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, {}, op), true); + + UNIT_ASSERT(ScalarOp({}, {}, op).Empty()); + UNIT_ASSERT(ScalarOp(false, {}, op).Empty()); + UNIT_ASSERT(ScalarOp({}, false, op).Empty()); +} + +Y_UNIT_TEST(ScalarXor) { + auto op = [](TProgramBuilder& pb, TRuntimeNode arg1, TRuntimeNode arg2) { + return pb.BlockXor(arg1, arg2); + }; + + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, false, op), false); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, true, op), true); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, false, op), true); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, true, op), false); + + UNIT_ASSERT(ScalarOp({}, {}, op).Empty()); + UNIT_ASSERT(ScalarOp(false, {}, op).Empty()); + UNIT_ASSERT(ScalarOp({}, false, op).Empty()); + UNIT_ASSERT(ScalarOp(true, {}, op).Empty()); + UNIT_ASSERT(ScalarOp({}, true, op).Empty()); +} + +Y_UNIT_TEST(ScalarNot) { + auto op = [](TProgramBuilder& pb, TRuntimeNode arg1, TRuntimeNode arg2) { + Y_UNUSED(arg2); + return pb.BlockNot(arg1); + }; + + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, {}, op), true); + UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, {}, op), false); + UNIT_ASSERT(ScalarOp({}, {}, op).Empty()); +} + +} // Y_UNIT_TEST_SUITE + +} // namespace NMiniKQL +} // namespace NKikimr diff --git a/yql/essentials/minikql/comp_nodes/ut/ya.make.inc b/yql/essentials/minikql/comp_nodes/ut/ya.make.inc index 9b82a1b90b7..4ce865905c9 100644 --- a/yql/essentials/minikql/comp_nodes/ut/ya.make.inc +++ b/yql/essentials/minikql/comp_nodes/ut/ya.make.inc @@ -28,6 +28,7 @@ SET(ORIG_SOURCES mkql_block_compress_ut.cpp mkql_block_coalesce_ut.cpp mkql_block_exists_ut.cpp + mkql_block_logical_ut.cpp mkql_block_skiptake_ut.cpp mkql_block_map_join_ut_utils.cpp mkql_block_map_join_ut.cpp diff --git a/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.cfg b/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.cfg new file mode 100644 index 00000000000..a654f9117df --- /dev/null +++ b/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.cfg @@ -0,0 +1 @@ +in Input input1.txt diff --git a/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.sql b/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.sql new file mode 100644 index 00000000000..82ae95df146 --- /dev/null +++ b/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.sql @@ -0,0 +1,5 @@ +USE plato; + +SELECT + ('w' = 's') OR ('v' = 's') OR value == "aaa" AS crash, +FROM Input; |