summaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/Functions/UserDefined/UserDefinedSQLFunctionVisitor.cpp
diff options
context:
space:
mode:
authorvitalyisaev <[email protected]>2023-11-14 09:58:56 +0300
committervitalyisaev <[email protected]>2023-11-14 10:20:20 +0300
commitc2b2dfd9827a400a8495e172a56343462e3ceb82 (patch)
treecd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/Functions/UserDefined/UserDefinedSQLFunctionVisitor.cpp
parentd4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff)
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/Functions/UserDefined/UserDefinedSQLFunctionVisitor.cpp')
-rw-r--r--contrib/clickhouse/src/Functions/UserDefined/UserDefinedSQLFunctionVisitor.cpp225
1 files changed, 225 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/Functions/UserDefined/UserDefinedSQLFunctionVisitor.cpp b/contrib/clickhouse/src/Functions/UserDefined/UserDefinedSQLFunctionVisitor.cpp
new file mode 100644
index 00000000000..360d1cdf76c
--- /dev/null
+++ b/contrib/clickhouse/src/Functions/UserDefined/UserDefinedSQLFunctionVisitor.cpp
@@ -0,0 +1,225 @@
+#include "UserDefinedSQLFunctionVisitor.h"
+
+#include <unordered_map>
+#include <unordered_set>
+#include <stack>
+
+#include <Parsers/ASTAlterQuery.h>
+#include <Parsers/ASTCreateQuery.h>
+#include <Parsers/ASTFunction.h>
+#include <Parsers/ASTCreateFunctionQuery.h>
+#include <Parsers/ASTExpressionList.h>
+#include <Parsers/ASTIdentifier.h>
+#include <Functions/UserDefined/UserDefinedSQLFunctionFactory.h>
+#include "Parsers/ASTColumnDeclaration.h"
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int UNSUPPORTED_METHOD;
+}
+
+void UserDefinedSQLFunctionVisitor::visit(ASTPtr & ast)
+{
+ if (!ast)
+ {
+ chassert(false);
+ return;
+ }
+
+ const auto visit_child_with_shared_ptr = [&](ASTPtr & child)
+ {
+ if (!child)
+ return;
+
+ auto * old_value = child.get();
+ visit(child);
+
+ // child did not change
+ if (old_value == child.get())
+ return;
+
+ // child changed, we need to modify it in the list of children of the parent also
+ for (auto & current_child : ast->children)
+ {
+ if (current_child.get() == old_value)
+ current_child = child;
+ }
+ };
+
+ if (auto * col_decl = ast->as<ASTColumnDeclaration>())
+ {
+ visit_child_with_shared_ptr(col_decl->default_expression);
+ visit_child_with_shared_ptr(col_decl->ttl);
+ return;
+ }
+
+ if (auto * storage = ast->as<ASTStorage>())
+ {
+ const auto visit_child = [&](IAST * & child)
+ {
+ if (!child)
+ return;
+
+ if (const auto * function = child->template as<ASTFunction>())
+ {
+ std::unordered_set<std::string> udf_in_replace_process;
+ auto replace_result = tryToReplaceFunction(*function, udf_in_replace_process);
+ if (replace_result)
+ ast->setOrReplace(child, replace_result);
+ }
+
+ visit(child);
+ };
+
+ visit_child(storage->partition_by);
+ visit_child(storage->primary_key);
+ visit_child(storage->order_by);
+ visit_child(storage->sample_by);
+ visit_child(storage->ttl_table);
+
+ return;
+ }
+
+ if (auto * alter = ast->as<ASTAlterCommand>())
+ {
+ visit_child_with_shared_ptr(alter->col_decl);
+ visit_child_with_shared_ptr(alter->column);
+ visit_child_with_shared_ptr(alter->partition);
+ visit_child_with_shared_ptr(alter->order_by);
+ visit_child_with_shared_ptr(alter->sample_by);
+ visit_child_with_shared_ptr(alter->index_decl);
+ visit_child_with_shared_ptr(alter->index);
+ visit_child_with_shared_ptr(alter->constraint_decl);
+ visit_child_with_shared_ptr(alter->constraint);
+ visit_child_with_shared_ptr(alter->projection_decl);
+ visit_child_with_shared_ptr(alter->projection);
+ visit_child_with_shared_ptr(alter->predicate);
+ visit_child_with_shared_ptr(alter->update_assignments);
+ visit_child_with_shared_ptr(alter->values);
+ visit_child_with_shared_ptr(alter->ttl);
+ visit_child_with_shared_ptr(alter->select);
+
+ return;
+ }
+
+ if (const auto * function = ast->template as<ASTFunction>())
+ {
+ std::unordered_set<std::string> udf_in_replace_process;
+ auto replace_result = tryToReplaceFunction(*function, udf_in_replace_process);
+ if (replace_result)
+ ast = replace_result;
+ }
+
+ for (auto & child : ast->children)
+ visit(child);
+}
+
+void UserDefinedSQLFunctionVisitor::visit(IAST * ast)
+{
+ if (!ast)
+ return;
+
+ for (auto & child : ast->children)
+ visit(child);
+}
+
+ASTPtr UserDefinedSQLFunctionVisitor::tryToReplaceFunction(const ASTFunction & function, std::unordered_set<std::string> & udf_in_replace_process)
+{
+ if (udf_in_replace_process.find(function.name) != udf_in_replace_process.end())
+ throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
+ "Recursive function call detected during function call {}",
+ function.name);
+
+ auto user_defined_function = UserDefinedSQLFunctionFactory::instance().tryGet(function.name);
+ if (!user_defined_function)
+ return nullptr;
+
+ const auto & function_arguments_list = function.children.at(0)->as<ASTExpressionList>();
+ auto & function_arguments = function_arguments_list->children;
+
+ const auto & create_function_query = user_defined_function->as<ASTCreateFunctionQuery>();
+ auto & function_core_expression = create_function_query->function_core->children.at(0);
+
+ const auto & identifiers_expression_list = function_core_expression->children.at(0)->children.at(0)->as<ASTExpressionList>();
+ const auto & identifiers_raw = identifiers_expression_list->children;
+
+ if (function_arguments.size() != identifiers_raw.size())
+ throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
+ "Function {} expects {} arguments actual arguments {}",
+ create_function_query->getFunctionName(),
+ identifiers_raw.size(),
+ function_arguments.size());
+
+ std::unordered_map<std::string, ASTPtr> identifier_name_to_function_argument;
+
+ for (size_t parameter_index = 0; parameter_index < identifiers_raw.size(); ++parameter_index)
+ {
+ const auto & identifier = identifiers_raw[parameter_index]->as<ASTIdentifier>();
+ const auto & function_argument = function_arguments[parameter_index];
+ const auto & identifier_name = identifier->name();
+
+ identifier_name_to_function_argument.emplace(identifier_name, function_argument);
+ }
+
+ auto [it, _] = udf_in_replace_process.emplace(function.name);
+
+ auto function_body_to_update = function_core_expression->children.at(1)->clone();
+
+ auto expression_list = std::make_shared<ASTExpressionList>();
+ expression_list->children.emplace_back(std::move(function_body_to_update));
+
+ std::stack<ASTPtr> ast_nodes_to_update;
+ ast_nodes_to_update.push(expression_list);
+
+ while (!ast_nodes_to_update.empty())
+ {
+ auto ast_node_to_update = ast_nodes_to_update.top();
+ ast_nodes_to_update.pop();
+
+ for (auto & child : ast_node_to_update->children)
+ {
+ if (auto * inner_function = child->as<ASTFunction>())
+ {
+ auto replace_result = tryToReplaceFunction(*inner_function, udf_in_replace_process);
+ if (replace_result)
+ child = replace_result;
+ }
+
+ auto identifier_name_opt = tryGetIdentifierName(child);
+ if (identifier_name_opt)
+ {
+ auto function_argument_it = identifier_name_to_function_argument.find(*identifier_name_opt);
+
+ if (function_argument_it == identifier_name_to_function_argument.end())
+ continue;
+
+ auto child_alias = child->tryGetAlias();
+ child = function_argument_it->second->clone();
+
+ if (!child_alias.empty())
+ child->setAlias(child_alias);
+
+ continue;
+ }
+
+ ast_nodes_to_update.push(child);
+ }
+ }
+
+ udf_in_replace_process.erase(it);
+
+ function_body_to_update = expression_list->children[0];
+
+ auto function_alias = function.tryGetAlias();
+
+ if (!function_alias.empty())
+ function_body_to_update->setAlias(function_alias);
+
+ return function_body_to_update;
+}
+
+}