diff options
author | ssmike <ssmike@ydb.tech> | 2023-08-28 16:30:01 +0300 |
---|---|---|
committer | ssmike <ssmike@ydb.tech> | 2023-08-28 17:28:42 +0300 |
commit | 0d3994e1a90d60e68f14d414d9ff3dde1364c533 (patch) | |
tree | a1fa1135e4783ab1e558b2ec5b698c53e15dd043 | |
parent | 6ac97d50fb3bf76773f1ed44e72c49511eb4ae22 (diff) | |
download | ydb-0d3994e1a90d60e68f14d414d9ff3dde1364c533.tar.gz |
Pushdown bool comparison
-rw-r--r-- | ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp | 52 | ||||
-rw-r--r-- | ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp | 85 |
2 files changed, 131 insertions, 6 deletions
diff --git a/ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp b/ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp index 5d61b8f8f1..1b09e09974 100644 --- a/ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp +++ b/ydb/library/yql/core/extract_predicate/extract_predicate_impl.cpp @@ -585,7 +585,7 @@ bool IsMemberListBinOpNode(const TExprNode& node) { return IsListOfMembers(node.Head()) || IsListOfMembers(node.Tail()); } -TExprNode::TPtr OptimizeNodeForRangeExtraction(const TExprNode::TPtr& node, TExprContext& ctx) { +TExprNode::TPtr OptimizeNodeForRangeExtraction(const TExprNode::TPtr& node, const TExprNode::TPtr& parent, TExprContext& ctx) { auto it = node->IsCallable() ? SupportedBinOps.find(node->Content()) : SupportedBinOps.end(); if (it != SupportedBinOps.end()) { TExprNode::TPtr toExpand; @@ -606,8 +606,6 @@ TExprNode::TPtr OptimizeNodeForRangeExtraction(const TExprNode::TPtr& node, TExp YQL_CLOG(DEBUG, Core) << node->Content() << " over tuple"; return ExpandTupleBinOp(*toExpand, ctx); } - - return node; } if (node->IsCallable("Not")) { @@ -699,11 +697,53 @@ TExprNode::TPtr OptimizeNodeForRangeExtraction(const TExprNode::TPtr& node, TExp } } + if (node->IsCallable("!=")) { + TExprNode::TPtr litArg; + TExprNode::TPtr anyArg; + + if (node->Child(1)->IsCallable("Bool")) { + litArg = node->Child(1); + anyArg = node->Child(0); + } + + if (node->Child(0)->IsCallable("Bool")) { + litArg = node->Child(0); + anyArg = node->Child(1); + } + + if (litArg && anyArg) { + return ctx.Builder(node->Pos()) + .Callable("==") + .Add(0, anyArg) + .Add(1, MakeBool(node->Pos(), !FromString<bool>(litArg->Head().Content()), ctx)) + .Seal() + .Build(); + } + } + + if (node->IsCallable("Member") && (!parent || parent->IsCallable({"And", "Or", "Not", "Coalesce"}))) { + auto* typeAnn = node->GetTypeAnn(); + if (typeAnn->GetKind() == ETypeAnnotationKind::Optional) { + typeAnn = typeAnn->Cast<TOptionalExprType>()->GetItemType(); + } + if (typeAnn->GetKind() == ETypeAnnotationKind::Data && + typeAnn->Cast<TDataExprType>()->GetSlot() == EDataSlot::Bool) + { + YQL_CLOG(DEBUG, Core) << "Replace raw Member with explicit bool comparison"; + return ctx.Builder(node->Pos()) + .Callable("==") + .Add(0, node) + .Add(1, MakeBool(node->Pos(), true, ctx)) + .Seal() + .Build(); + } + } + return node; } -void DoOptimizeForRangeExtraction(const TExprNode::TPtr& input, TExprNode::TPtr& output, bool topLevel, TExprContext& ctx) { - output = OptimizeNodeForRangeExtraction(input, ctx); +void DoOptimizeForRangeExtraction(const TExprNode::TPtr& input, TExprNode::TPtr& output, bool topLevel, TExprContext& ctx, const TExprNode::TPtr& parent = nullptr) { + output = OptimizeNodeForRangeExtraction(input, parent, ctx); if (output != input) { return; } @@ -715,7 +755,7 @@ void DoOptimizeForRangeExtraction(const TExprNode::TPtr& input, TExprNode::TPtr& continue; } TExprNode::TPtr newChild = child; - DoOptimizeForRangeExtraction(child, newChild, false, ctx); + DoOptimizeForRangeExtraction(child, newChild, false, ctx, input); if (newChild != child) { changed = true; child = std::move(newChild); diff --git a/ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp b/ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp index c7346dcaf1..9c83da1568 100644 --- a/ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp +++ b/ydb/library/yql/core/extract_predicate/ut/extract_predicate_ut.cpp @@ -886,6 +886,91 @@ Y_UNIT_TEST_SUITE(TYqlExtractPredicate) { UNIT_ASSERT_EQUAL(lambda, canonicalLambda); //Cerr << DumpNode(*buildResult.ComputeNode, exprCtx); } + + Y_UNIT_TEST(BoolPredicate) { + TString prog = + "use plato;\n" + "declare $param as List<Int32>;\n" + "$src = [<|x:true, y:1|>, <|x:false, y:2|>];\n" + "insert into Output with truncate\n" + "select * from as_table($src) where not x;"; + + TExprContext exprCtx; + TTypeAnnotationContextPtr typesCtx; + TExprNode::TPtr exprRoot = ParseAndOptimize(prog, exprCtx, typesCtx); + TExprNode::TPtr filterLambda = LocateFilterLambda(exprRoot); + + THashSet<TString> usedColumns; + using NDetail::TPredicateRangeExtractor; + + TPredicateExtractorSettings settings; + auto extractor = MakePredicateRangeExtractor(settings); + + UNIT_ASSERT(extractor->Prepare(filterLambda, *filterLambda->Head().Head().GetTypeAnn(), usedColumns, exprCtx, *typesCtx)); + + auto buildResult = extractor->BuildComputeNode({ "x" }, exprCtx, *typesCtx); + + UNIT_ASSERT(buildResult.ComputeNode); + + auto canonicalRanges = + "(\n" + "(return (RangeFinalize (RangeMultiply (Uint64 '10000) (RangeUnion (RangeFor '=== (Bool 'false) (DataType 'Bool))))))\n" + ")\n"; + auto ranges = DumpNode(*buildResult.ComputeNode, exprCtx); + UNIT_ASSERT_EQUAL(ranges, canonicalRanges); + + auto canonicalLambda = + "(\n" + "(return (lambda '($1) (OptionalIf (Bool 'true) $1)))\n" + ")\n"; + auto lambda = DumpNode(*buildResult.PrunedLambda, exprCtx); + UNIT_ASSERT_EQUAL(lambda, canonicalLambda); + } + + Y_UNIT_TEST(BoolPredicateLiteralRange) { + TString prog = + "use plato;\n" + "declare $param as List<Int32>;\n" + "$src = [<|x:true, y:1|>, <|x:false, y:2|>];\n" + "insert into Output with truncate\n" + "select * from as_table($src) where not x and y = 1;"; + + TExprContext exprCtx; + TTypeAnnotationContextPtr typesCtx; + TExprNode::TPtr exprRoot = ParseAndOptimize(prog, exprCtx, typesCtx); + TExprNode::TPtr filterLambda = LocateFilterLambda(exprRoot); + + THashSet<TString> usedColumns; + using NDetail::TPredicateRangeExtractor; + + TPredicateExtractorSettings settings; + settings.BuildLiteralRange = true; + auto extractor = MakePredicateRangeExtractor(settings); + + UNIT_ASSERT(extractor->Prepare(filterLambda, *filterLambda->Head().Head().GetTypeAnn(), usedColumns, exprCtx, *typesCtx)); + + auto buildResult = extractor->BuildComputeNode({ "x", "y"}, exprCtx, *typesCtx); + + UNIT_ASSERT(buildResult.ComputeNode); + UNIT_ASSERT(buildResult.LiteralRange); + + auto canonicalRanges = + "(\n" + "(let $1 (RangeFor '== (Bool 'false) (DataType 'Bool)))\n" + "(let $2 (RangeFor '=== (Int32 '1) (DataType 'Int32)))\n" + "(return (RangeFinalize (RangeMultiply (Uint64 '10000) (RangeUnion (RangeIntersect (RangeMultiply (Uint64 '10000) $1 $2))))))\n" + ")\n"; + auto ranges = DumpNode(*buildResult.ComputeNode, exprCtx); + UNIT_ASSERT_EQUAL(ranges, canonicalRanges); + + auto canonicalLambda = + "(\n" + "(return (lambda '($1) (OptionalIf (Bool 'true) $1)))\n" + ")\n"; + auto lambda = DumpNode(*buildResult.PrunedLambda, exprCtx); + + UNIT_ASSERT_EQUAL(lambda, canonicalLambda); + } } } // namespace NYql |