aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author ssmike <ssmike@ydb.tech> 2023-08-28 16:30:01 +0300
committer ssmike <ssmike@ydb.tech> 2023-08-28 17:28:42 +0300
commit 0d3994e1a90d60e68f14d414d9ff3dde1364c533 (patch)
tree a1fa1135e4783ab1e558b2ec5b698c53e15dd043
parent 6ac97d50fb3bf76773f1ed44e72c49511eb4ae22 (diff)
download ydb-0d3994e1a90d60e68f14d414d9ff3dde1364c533.tar.gz
Pushdown bool comparison
-rw-r--r-- ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp | 52
-rw-r--r-- ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp | 85
2 files changed, 131 insertions, 6 deletions
diff --git a/ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp b/ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp
index 5d61b8f8f1..1b09e09974 100644
--- a/ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp
+++ b/ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp
@@ -585,7 +585,7 @@ bool IsMemberListBinOpNode(const TExprNode& node) {
return IsListOfMembers(node.Head()) || IsListOfMembers(node.Tail());
}
-TExprNode::TPtr OptimizeNodeForRangeExtraction(const TExprNode::TPtr& node, TExprContext& ctx) {
+TExprNode::TPtr OptimizeNodeForRangeExtraction(const TExprNode::TPtr& node, const TExprNode::TPtr& parent, TExprContext& ctx) {
auto it = node->IsCallable() ? SupportedBinOps.find(node->Content()) : SupportedBinOps.end();
if (it != SupportedBinOps.end()) {
TExprNode::TPtr toExpand;
@@ -606,8 +606,6 @@ TExprNode::TPtr OptimizeNodeForRangeExtraction(const TExprNode::TPtr& node, TExp
YQL_CLOG(DEBUG, Core) << node->Content() << " over tuple";
return ExpandTupleBinOp(*toExpand, ctx);
}
-
- return node;
}
if (node->IsCallable("Not")) {
@@ -699,11 +697,53 @@ TExprNode::TPtr OptimizeNodeForRangeExtraction(const TExprNode::TPtr& node, TExp
}
}
+ if (node->IsCallable("!=")) { // Rewrite `x != Bool(b)` (Bool literal on either side) as `x == Bool(!b)`.
+ TExprNode::TPtr litArg; // the operand that is a "Bool" callable, if any
+ TExprNode::TPtr anyArg; // the remaining operand
+ // Check the right-hand operand first.
+ if (node->Child(1)->IsCallable("Bool")) {
+ litArg = node->Child(1);
+ anyArg = node->Child(0);
+ }
+ // Then the left-hand operand; if both operands are "Bool" callables this assignment overrides the one above.
+ if (node->Child(0)->IsCallable("Bool")) {
+ litArg = node->Child(0);
+ anyArg = node->Child(1);
+ }
+ // Both pointers are still null when neither operand is a "Bool" callable; the node then falls through to the checks below.
+ if (litArg && anyArg) {
+ return ctx.Builder(node->Pos())
+ .Callable("==")
+ .Add(0, anyArg) // the non-literal operand always ends up on the left
+ .Add(1, MakeBool(node->Pos(), !FromString<bool>(litArg->Head().Content()), ctx)) // value parsed from the Bool node's first child, then negated
+ .Seal()
+ .Build();
+ }
+ }
+ // Wrap a bare Bool / Optional<Bool> Member in `Member == Bool(true)`; presumably so range extraction sees a comparison with a literal - confirm.
+ if (node->IsCallable("Member") && (!parent || parent->IsCallable({"And", "Or", "Not", "Coalesce"}))) { // only with no parent passed, or directly under And/Or/Not/Coalesce
+ auto* typeAnn = node->GetTypeAnn();
+ if (typeAnn->GetKind() == ETypeAnnotationKind::Optional) { // look through one Optional level
+ typeAnn = typeAnn->Cast<TOptionalExprType>()->GetItemType();
+ }
+ if (typeAnn->GetKind() == ETypeAnnotationKind::Data && // the (unwrapped) type must be the Bool data slot
+ typeAnn->Cast<TDataExprType>()->GetSlot() == EDataSlot::Bool)
+ {
+ YQL_CLOG(DEBUG, Core) << "Replace raw Member with explicit bool comparison";
+ return ctx.Builder(node->Pos())
+ .Callable("==")
+ .Add(0, node) // the original Member node is reused as the left operand
+ .Add(1, MakeBool(node->Pos(), true, ctx))
+ .Seal()
+ .Build();
+ }
+ }
+
return node;
}
-void DoOptimizeForRangeExtraction(const TExprNode::TPtr& input, TExprNode::TPtr& output, bool topLevel, TExprContext& ctx) {
- output = OptimizeNodeForRangeExtraction(input, ctx);
+void DoOptimizeForRangeExtraction(const TExprNode::TPtr& input, TExprNode::TPtr& output, bool topLevel, TExprContext& ctx, const TExprNode::TPtr& parent = nullptr) {
+ output = OptimizeNodeForRangeExtraction(input, parent, ctx);
if (output != input) {
return;
}
@@ -715,7 +755,7 @@ void DoOptimizeForRangeExtraction(const TExprNode::TPtr& input, TExprNode::TPtr&
continue;
}
TExprNode::TPtr newChild = child;
- DoOptimizeForRangeExtraction(child, newChild, false, ctx);
+ DoOptimizeForRangeExtraction(child, newChild, false, ctx, input);
if (newChild != child) {
changed = true;
child = std::move(newChild);
diff --git a/ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp b/ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp
index c7346dcaf1..9c83da1568 100644
--- a/ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp
+++ b/ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp
@@ -886,6 +886,91 @@ Y_UNIT_TEST_SUITE(TYqlExtractPredicate) {
UNIT_ASSERT_EQUAL(lambda, canonicalLambda);
//Cerr << DumpNode(*buildResult.ComputeNode, exprCtx);
}
+
+ Y_UNIT_TEST(BoolPredicate) { // Checks the extracted range and pruned lambda for the filter `where not x` on a Bool field.
+ TString prog =
+ "use plato;\n"
+ "declare $param as List<Int32>;\n" // $param is not referenced by the rest of this query text
+ "$src = [<|x:true, y:1|>, <|x:false, y:2|>];\n"
+ "insert into Output with truncate\n"
+ "select * from as_table($src) where not x;";
+ // Build the expression and take its filter lambda via the ParseAndOptimize / LocateFilterLambda helpers (not shown here).
+ TExprContext exprCtx;
+ TTypeAnnotationContextPtr typesCtx;
+ TExprNode::TPtr exprRoot = ParseAndOptimize(prog, exprCtx, typesCtx);
+ TExprNode::TPtr filterLambda = LocateFilterLambda(exprRoot);
+
+ THashSet<TString> usedColumns;
+ using NDetail::TPredicateRangeExtractor;
+ // Extractor with default-constructed settings.
+ TPredicateExtractorSettings settings;
+ auto extractor = MakePredicateRangeExtractor(settings);
+ // Prepare is given the lambda and the type annotation of the lambda's first argument; it must succeed.
+ UNIT_ASSERT(extractor->Prepare(filterLambda, *filterLambda->Head().Head().GetTypeAnn(), usedColumns, exprCtx, *typesCtx));
+ // { "x" } is presumably the list of key columns to build ranges for - confirm against BuildComputeNode.
+ auto buildResult = extractor->BuildComputeNode({ "x" }, exprCtx, *typesCtx);
+
+ UNIT_ASSERT(buildResult.ComputeNode);
+ // Expected dump of the compute node: a single RangeFor '=== over Bool 'false.
+ auto canonicalRanges =
+ "(\n"
+ "(return (RangeFinalize (RangeMultiply (Uint64 '10000) (RangeUnion (RangeFor '=== (Bool 'false) (DataType 'Bool))))))\n"
+ ")\n";
+ auto ranges = DumpNode(*buildResult.ComputeNode, exprCtx);
+ UNIT_ASSERT_EQUAL(ranges, canonicalRanges);
+ // Expected dump of the pruned lambda: OptionalIf with the constant predicate Bool 'true.
+ auto canonicalLambda =
+ "(\n"
+ "(return (lambda '($1) (OptionalIf (Bool 'true) $1)))\n"
+ ")\n";
+ auto lambda = DumpNode(*buildResult.PrunedLambda, exprCtx);
+ UNIT_ASSERT_EQUAL(lambda, canonicalLambda);
+ }
+
+ Y_UNIT_TEST(BoolPredicateLiteralRange) { // Same scenario with the filter `where not x and y = 1` and BuildLiteralRange enabled.
+ TString prog =
+ "use plato;\n"
+ "declare $param as List<Int32>;\n" // $param is not referenced by the rest of this query text
+ "$src = [<|x:true, y:1|>, <|x:false, y:2|>];\n"
+ "insert into Output with truncate\n"
+ "select * from as_table($src) where not x and y = 1;";
+ // Build the expression and take its filter lambda via the ParseAndOptimize / LocateFilterLambda helpers (not shown here).
+ TExprContext exprCtx;
+ TTypeAnnotationContextPtr typesCtx;
+ TExprNode::TPtr exprRoot = ParseAndOptimize(prog, exprCtx, typesCtx);
+ TExprNode::TPtr filterLambda = LocateFilterLambda(exprRoot);
+
+ THashSet<TString> usedColumns;
+ using NDetail::TPredicateRangeExtractor;
+
+ TPredicateExtractorSettings settings;
+ settings.BuildLiteralRange = true; // the only setting changed from the defaults; LiteralRange is asserted below
+ auto extractor = MakePredicateRangeExtractor(settings);
+ // Prepare is given the lambda and the type annotation of the lambda's first argument; it must succeed.
+ UNIT_ASSERT(extractor->Prepare(filterLambda, *filterLambda->Head().Head().GetTypeAnn(), usedColumns, exprCtx, *typesCtx));
+ // { "x", "y" } is presumably the ordered list of key columns - confirm against BuildComputeNode.
+ auto buildResult = extractor->BuildComputeNode({ "x", "y"}, exprCtx, *typesCtx);
+
+ UNIT_ASSERT(buildResult.ComputeNode);
+ UNIT_ASSERT(buildResult.LiteralRange);
+ // Expected dump: RangeFor '== on Bool 'false combined with RangeFor '=== on Int32 '1. NOTE(review): '== vs '=== differ here - presumably intentional, confirm.
+ auto canonicalRanges =
+ "(\n"
+ "(let $1 (RangeFor '== (Bool 'false) (DataType 'Bool)))\n"
+ "(let $2 (RangeFor '=== (Int32 '1) (DataType 'Int32)))\n"
+ "(return (RangeFinalize (RangeMultiply (Uint64 '10000) (RangeUnion (RangeIntersect (RangeMultiply (Uint64 '10000) $1 $2))))))\n"
+ ")\n";
+ auto ranges = DumpNode(*buildResult.ComputeNode, exprCtx);
+ UNIT_ASSERT_EQUAL(ranges, canonicalRanges);
+ // Expected dump of the pruned lambda: OptionalIf with the constant predicate Bool 'true.
+ auto canonicalLambda =
+ "(\n"
+ "(return (lambda '($1) (OptionalIf (Bool 'true) $1)))\n"
+ ")\n";
+ auto lambda = DumpNode(*buildResult.PrunedLambda, exprCtx);
+
+ UNIT_ASSERT_EQUAL(lambda, canonicalLambda);
+ }
}
} // namespace NYql