aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoraneporada <aneporada@yandex-team.com>2025-06-03 12:56:43 +0300
committeraneporada <aneporada@yandex-team.com>2025-06-03 13:14:17 +0300
commitb51b122d7876de2baef5405f3c7da3b10cf99b26 (patch)
tree6d90d7b4978c404ebde32b05d6b3a7ab1b1f24a2
parent3e8add594e2ff4cf7ff8f552e9a6e285f4d6c372 (diff)
downloadydb-b51b122d7876de2baef5405f3c7da3b10cf99b26.tar.gz
Support scalar-only logical ops
commit_hash:2be000baa1e203ec9b4bab5a4d236abc64609376
-rw-r--r--yql/essentials/minikql/comp_nodes/mkql_block_logical.cpp82
-rw-r--r--yql/essentials/minikql/comp_nodes/ut/mkql_block_logical_ut.cpp101
-rw-r--r--yql/essentials/minikql/comp_nodes/ut/ya.make.inc1
-rw-r--r--yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.cfg1
-rw-r--r--yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.sql5
5 files changed, 185 insertions, 5 deletions
diff --git a/yql/essentials/minikql/comp_nodes/mkql_block_logical.cpp b/yql/essentials/minikql/comp_nodes/mkql_block_logical.cpp
index df89388fd9d..968254a33c1 100644
--- a/yql/essentials/minikql/comp_nodes/mkql_block_logical.cpp
+++ b/yql/essentials/minikql/comp_nodes/mkql_block_logical.cpp
@@ -63,8 +63,10 @@ public:
arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
auto firstDatum = batch.values[0];
auto secondDatum = batch.values[1];
- MKQL_ENSURE(!firstDatum.is_scalar() || !secondDatum.is_scalar(), "Expected at least one array");
-
+ if (firstDatum.is_scalar() && secondDatum.is_scalar()) {
+ *res = CalcScalarScalar(firstDatum, secondDatum);
+ return arrow::Status::OK();
+ }
if (IsAllEqualsTo(firstDatum, false)) {
// false AND ... = false
if (firstDatum.is_array()) {
@@ -104,6 +106,30 @@ public:
}
private:
+ arrow::Datum CalcScalarScalar(const arrow::Datum& firstDatum, const arrow::Datum& secondDatum) const {
+ const auto& first = firstDatum.scalar_as<arrow::UInt8Scalar>();
+ const auto& second = secondDatum.scalar_as<arrow::UInt8Scalar>();
+
+ if (first.is_valid && second.is_valid) {
+ bool result = bool((first.value & second.value) & 1u);
+ return MakeScalarDatum(result);
+ }
+
+ if (!first.is_valid && !second.is_valid) {
+ return firstDatum;
+ }
+
+ if (!first.is_valid) {
+ // null and true -> null
+ // null and false -> false
+ return second.value ? firstDatum : secondDatum;
+ } else {
+ // true and null -> null
+ // false and null -> false
+ return first.value ? secondDatum : firstDatum;
+ }
+ }
+
arrow::Datum CalcScalarArray(arrow::MemoryPool* pool, ui8 value, bool valid, const std::shared_ptr<arrow::ArrayData>& arr) const {
bool first_true = valid && value;
bool first_false = valid && !value;
@@ -198,8 +224,10 @@ public:
arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
auto firstDatum = batch.values[0];
auto secondDatum = batch.values[1];
- MKQL_ENSURE(!firstDatum.is_scalar() || !secondDatum.is_scalar(), "Expected at least one array");
-
+ if (firstDatum.is_scalar() && secondDatum.is_scalar()) {
+ *res = CalcScalarScalar(firstDatum, secondDatum);
+ return arrow::Status::OK();
+ }
if (IsAllEqualsTo(firstDatum, true)) {
// true OR ... = true
if (firstDatum.is_array()) {
@@ -239,6 +267,30 @@ public:
}
private:
+ arrow::Datum CalcScalarScalar(const arrow::Datum& firstDatum, const arrow::Datum& secondDatum) const {
+ const auto& first = firstDatum.scalar_as<arrow::UInt8Scalar>();
+ const auto& second = secondDatum.scalar_as<arrow::UInt8Scalar>();
+
+ if (first.is_valid && second.is_valid) {
+ bool result = bool((first.value | second.value) & 1u);
+ return MakeScalarDatum(result);
+ }
+
+ if (!first.is_valid && !second.is_valid) {
+ return firstDatum;
+ }
+
+ if (!first.is_valid) {
+ // null or true -> true
+ // null or false -> null
+ return second.value ? secondDatum : firstDatum;
+ } else {
+ // true or null -> true
+ // false or null -> null
+ return first.value ? firstDatum : secondDatum;
+ }
+ }
+
arrow::Datum CalcScalarArray(arrow::MemoryPool* pool, ui8 value, bool valid, const std::shared_ptr<arrow::ArrayData>& arr) const {
bool first_true = valid && value;
bool first_false = valid && !value;
@@ -334,7 +386,10 @@ public:
arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
auto firstDatum = batch.values[0];
auto secondDatum = batch.values[1];
- MKQL_ENSURE(!firstDatum.is_scalar() || !secondDatum.is_scalar(), "Expected at least one array");
+ if (firstDatum.is_scalar() && secondDatum.is_scalar()) {
+ *res = CalcScalarScalar(firstDatum, secondDatum);
+ return arrow::Status::OK();
+ }
if (firstDatum.null_count() == firstDatum.length()) {
if (firstDatum.is_array()) {
*res = firstDatum;
@@ -369,6 +424,18 @@ public:
}
private:
+ arrow::Datum CalcScalarScalar(const arrow::Datum& firstDatum, const arrow::Datum& secondDatum) const {
+ const auto& first = firstDatum.scalar_as<arrow::UInt8Scalar>();
+ const auto& second = secondDatum.scalar_as<arrow::UInt8Scalar>();
+
+ if (first.is_valid && second.is_valid) {
+ bool result = bool((first.value ^ second.value) & 1u);
+ return MakeScalarDatum(result);
+ }
+
+ return first.is_valid ? secondDatum : firstDatum;
+ }
+
arrow::Datum CalcScalarArray(arrow::MemoryPool* pool, ui8 value, const std::shared_ptr<arrow::ArrayData>& arr) const {
std::shared_ptr<arrow::Buffer> bitmap = CopyBitmap(pool, arr->buffers[0], arr->offset, arr->length);
std::shared_ptr<arrow::Buffer> data = ARROW_RESULT(arrow::AllocateBuffer(arr->length, pool));
@@ -402,6 +469,11 @@ class TNotBlockExec {
public:
arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
const auto& input = batch.values[0];
+ if (input.is_scalar()) {
+ const auto& arg = input.scalar_as<arrow::UInt8Scalar>();
+ *res = arg.is_valid ? MakeScalarDatum(bool(~arg.value & 1u)) : input;
+ return arrow::Status::OK();
+ }
MKQL_ENSURE(input.is_array(), "Expected array");
const auto& arr = *input.array();
if (arr.GetNullCount() == arr.length) {
diff --git a/yql/essentials/minikql/comp_nodes/ut/mkql_block_logical_ut.cpp b/yql/essentials/minikql/comp_nodes/ut/mkql_block_logical_ut.cpp
new file mode 100644
index 00000000000..bb729a95506
--- /dev/null
+++ b/yql/essentials/minikql/comp_nodes/ut/mkql_block_logical_ut.cpp
@@ -0,0 +1,101 @@
+#include "mkql_computation_node_ut.h"
+
+#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
+#include <yql/essentials/minikql/computation/mkql_block_builder.h>
+
+namespace NKikimr {
+namespace NMiniKQL {
+
+namespace {
+
+template<typename Op>
+TMaybe<bool> ScalarOp(TMaybe<bool> arg1, TMaybe<bool> arg2, const Op& op) {
+ TSetup<false> setup;
+ TProgramBuilder& pb = *setup.PgmBuilder;
+ auto arg1val = arg1 ? pb.NewDataLiteral<bool>(*arg1) : pb.NewEmptyOptionalDataLiteral(NUdf::TDataType<bool>::Id);
+ auto arg2val = arg2 ? pb.NewDataLiteral<bool>(*arg2) : pb.NewEmptyOptionalDataLiteral(NUdf::TDataType<bool>::Id);
+ TRuntimeNode root = op(pb, pb.AsScalar(arg1val), pb.AsScalar(arg2val));
+ const auto graph = setup.BuildGraph(root);
+ NYql::NUdf::TUnboxedValue result = graph->GetValue();
+ auto datum = TArrowBlock::From(result).GetDatum();
+ UNIT_ASSERT(datum.is_scalar());
+ const auto& scalar = datum.scalar_as<arrow::UInt8Scalar>();
+ if (!scalar.is_valid) {
+ return {};
+ }
+ UNIT_ASSERT(scalar.value == 0 || scalar.value == 1);
+ return bool(scalar.value);
+}
+
+} //namespace
+
+Y_UNIT_TEST_SUITE(TMiniKQLBlockLogicalTest) {
+
+Y_UNIT_TEST(ScalarAnd) {
+ auto op = [](TProgramBuilder& pb, TRuntimeNode arg1, TRuntimeNode arg2) {
+ return pb.BlockAnd(arg1, arg2);
+ };
+
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, false, op), false);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, true, op), false);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, false, op), false);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, true, op), true);
+
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp({}, false, op), false);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, {}, op), false);
+
+ UNIT_ASSERT(ScalarOp({}, {}, op).Empty());
+ UNIT_ASSERT(ScalarOp(true, {}, op).Empty());
+ UNIT_ASSERT(ScalarOp({}, true, op).Empty());
+}
+
+Y_UNIT_TEST(ScalarOr) {
+ auto op = [](TProgramBuilder& pb, TRuntimeNode arg1, TRuntimeNode arg2) {
+ return pb.BlockOr(arg1, arg2);
+ };
+
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, false, op), false);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, true, op), true);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, false, op), true);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, true, op), true);
+
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp({}, true, op), true);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, {}, op), true);
+
+ UNIT_ASSERT(ScalarOp({}, {}, op).Empty());
+ UNIT_ASSERT(ScalarOp(false, {}, op).Empty());
+ UNIT_ASSERT(ScalarOp({}, false, op).Empty());
+}
+
+Y_UNIT_TEST(ScalarXor) {
+ auto op = [](TProgramBuilder& pb, TRuntimeNode arg1, TRuntimeNode arg2) {
+ return pb.BlockXor(arg1, arg2);
+ };
+
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, false, op), false);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, true, op), true);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, false, op), true);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, true, op), false);
+
+ UNIT_ASSERT(ScalarOp({}, {}, op).Empty());
+ UNIT_ASSERT(ScalarOp(false, {}, op).Empty());
+ UNIT_ASSERT(ScalarOp({}, false, op).Empty());
+ UNIT_ASSERT(ScalarOp(true, {}, op).Empty());
+ UNIT_ASSERT(ScalarOp({}, true, op).Empty());
+}
+
+Y_UNIT_TEST(ScalarNot) {
+ auto op = [](TProgramBuilder& pb, TRuntimeNode arg1, TRuntimeNode arg2) {
+ Y_UNUSED(arg2);
+ return pb.BlockNot(arg1);
+ };
+
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(false, {}, op), true);
+ UNIT_ASSERT_VALUES_EQUAL(ScalarOp(true, {}, op), false);
+ UNIT_ASSERT(ScalarOp({}, {}, op).Empty());
+}
+
+} // Y_UNIT_TEST_SUITE
+
+} // namespace NMiniKQL
+} // namespace NKikimr
diff --git a/yql/essentials/minikql/comp_nodes/ut/ya.make.inc b/yql/essentials/minikql/comp_nodes/ut/ya.make.inc
index 9b82a1b90b7..4ce865905c9 100644
--- a/yql/essentials/minikql/comp_nodes/ut/ya.make.inc
+++ b/yql/essentials/minikql/comp_nodes/ut/ya.make.inc
@@ -28,6 +28,7 @@ SET(ORIG_SOURCES
mkql_block_compress_ut.cpp
mkql_block_coalesce_ut.cpp
mkql_block_exists_ut.cpp
+ mkql_block_logical_ut.cpp
mkql_block_skiptake_ut.cpp
mkql_block_map_join_ut_utils.cpp
mkql_block_map_join_ut.cpp
diff --git a/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.cfg b/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.cfg
new file mode 100644
index 00000000000..a654f9117df
--- /dev/null
+++ b/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.cfg
@@ -0,0 +1 @@
+in Input input1.txt
diff --git a/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.sql b/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.sql
new file mode 100644
index 00000000000..82ae95df146
--- /dev/null
+++ b/yt/yql/tests/sql/suites/blocks/boolean_ops_scalar.sql
@@ -0,0 +1,5 @@
+USE plato;
+
+SELECT
+ ('w' = 's') OR ('v' = 's') OR value == "aaa" AS crash,
+FROM Input;