diff options
| author | vitalyisaev <[email protected]> | 2023-11-14 09:58:56 +0300 |
|---|---|---|
| committer | vitalyisaev <[email protected]> | 2023-11-14 10:20:20 +0300 |
| commit | c2b2dfd9827a400a8495e172a56343462e3ceb82 (patch) | |
| tree | cd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/Interpreters/RewriteAnyFunctionVisitor.cpp | |
| parent | d4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff) | |
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/Interpreters/RewriteAnyFunctionVisitor.cpp')
| -rw-r--r-- | contrib/clickhouse/src/Interpreters/RewriteAnyFunctionVisitor.cpp | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/Interpreters/RewriteAnyFunctionVisitor.cpp b/contrib/clickhouse/src/Interpreters/RewriteAnyFunctionVisitor.cpp new file mode 100644 index 00000000000..163e117f93d --- /dev/null +++ b/contrib/clickhouse/src/Interpreters/RewriteAnyFunctionVisitor.cpp @@ -0,0 +1,124 @@ +#include <Common/typeid_cast.h> +#include <Parsers/ASTFunction.h> +#include <Parsers/ASTIdentifier.h> +#include <Parsers/ASTSubquery.h> +#include <Interpreters/RewriteAnyFunctionVisitor.h> +#include <AggregateFunctions/AggregateFunctionFactory.h> +#include <Parsers/ASTTablesInSelectQuery.h> + +namespace DB +{ + +namespace +{ + +bool extractIdentifiers(const ASTFunction & func, std::unordered_set<ASTPtr *> & identifiers) +{ + for (auto & arg : func.arguments->children) + { + if (const auto * arg_func = arg->as<ASTFunction>()) + { + /// arrayJoin() is special and should not be optimized (think about + /// it as a an aggregate function), otherwise wrong result will be + /// produced: + /// SELECT *, any(arrayJoin([[], []])) FROM numbers(1) GROUP BY number + /// ┌─number─┬─arrayJoin(array(array(), array()))─┐ + /// │ 0 │ [] │ + /// │ 0 │ [] │ + /// └────────┴────────────────────────────────────┘ + /// While should be: + /// ┌─number─┬─any(arrayJoin(array(array(), array())))─┐ + /// │ 0 │ [] │ + /// └────────┴─────────────────────────────────────────┘ + if (arg_func->name == "arrayJoin") + return false; + + if (arg_func->name == "lambda") + return false; + + // We are looking for identifiers inside a function calculated inside + // the aggregate function `any()`. Window or aggregate function can't + // be inside `any`, but this check in GetAggregatesMatcher happens + // later, so we have to explicitly skip these nested functions here. + if (arg_func->is_window_function + || AggregateUtils::isAggregateFunction(*arg_func)) + { + return false; + } + + if (!extractIdentifiers(*arg_func, identifiers)) + return false; + } + else if (arg->as<ASTIdentifier>()) + identifiers.emplace(&arg); + } + + return true; +} + +} + + +void RewriteAnyFunctionMatcher::visit(ASTPtr & ast, Data & data) +{ + if (auto * func = ast->as<ASTFunction>()) + { + if (func->is_window_function) + return; + + visit(*func, ast, data); + } +} + +void RewriteAnyFunctionMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data & data) +{ + if (!func.arguments || func.arguments->children.empty() || !func.arguments->children[0]) + return; + + if (func.name != "any" && func.name != "anyLast") + return; + + auto & func_arguments = func.arguments->children; + + if (func_arguments.size() != 1) + return; + + const auto * first_arg_func = func_arguments[0]->as<ASTFunction>(); + if (!first_arg_func || first_arg_func->arguments->children.empty()) + return; + + /// We have rewritten this function. Just unwrap its argument. + if (data.rewritten.contains(ast.get())) + { + func_arguments[0]->setAlias(func.alias); + ast = func_arguments[0]; + return; + } + + std::unordered_set<ASTPtr *> identifiers; /// implicit remove duplicates + if (!extractIdentifiers(func, identifiers)) + return; + + /// Wrap identifiers: any(f(x, y, g(z))) -> any(f(any(x), any(y), g(any(z)))) + for (auto * ast_to_change : identifiers) + { + ASTPtr identifier_ast = *ast_to_change; + *ast_to_change = makeASTFunction(func.name); + (*ast_to_change)->as<ASTFunction>()->arguments->children.emplace_back(identifier_ast); + } + + data.rewritten.insert(ast.get()); + + /// Unwrap function: any(f(any(x), any(y), g(any(z)))) -> f(any(x), any(y), g(any(z))) + func_arguments[0]->setAlias(func.alias); + ast = func_arguments[0]; +} + +bool RewriteAnyFunctionMatcher::needChildVisit(const ASTPtr & node, const ASTPtr &) +{ + return !node->as<ASTSubquery>() && + !node->as<ASTTableExpression>() && + !node->as<ASTArrayJoin>(); +} + +} |
