#include <Interpreters/InJoinSubqueriesPreprocessor.h>
#include <Interpreters/Context.h>
#include <Interpreters/DatabaseAndTableWithAlias.h>
#include <Interpreters/IdentifierSemantic.h>
#include <Interpreters/InDepthNodeVisitor.h>
#include <Storages/StorageDistributed.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ASTFunction.h>
#include <Common/typeid_cast.h>
#include <Common/checkStackSize.h>


namespace DB
{

namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED;
    extern const int LOGICAL_ERROR;
}


namespace
{

/// Resolves the table AST node to a storage ID through the context and looks the table up
/// in the database catalog. Returns an empty pointer when the storage ID cannot be resolved;
/// otherwise returns whatever DatabaseCatalog::tryGetTable yields for that ID.
StoragePtr tryGetTable(const ASTPtr & database_and_table, ContextPtr context)
{
    auto table_id = context->tryResolveStorageID(database_and_table);
    if (!table_id)
        return {};
    return DatabaseCatalog::instance().tryGetTable(table_id, context);
}

using CheckShardsAndTables = InJoinSubqueriesPreprocessor::CheckShardsAndTables;

/// Visitor data applied to every ASTTableExpression found inside a non-GLOBAL IN subquery
/// or a non-GLOBAL JOIN right-hand side.
///
/// Exactly one of `function` / `table_join` is expected to be non-null: it identifies the
/// enclosing IN function or JOIN that gets switched to its GLOBAL form when required.
/// Identifiers rewritten in 'local' mode are recorded (as clones of the original) in `renamed_tables`.
struct NonGlobalTableData : public WithContext
{
    using TypeToVisit = ASTTableExpression;

    NonGlobalTableData(
        ContextPtr context_,
        const CheckShardsAndTables & checker_,
        std::vector<ASTPtr> & renamed_tables_,
        ASTFunction * function_,
        ASTTableJoin * table_join_)
        : WithContext(context_), checker(checker_), renamed_tables(renamed_tables_), function(function_), table_join(table_join_)
    {
    }

    const CheckShardsAndTables & checker;
    std::vector<ASTPtr> & renamed_tables;
    ASTFunction * function = nullptr;
    ASTTableJoin * table_join = nullptr;

    /// Only table expressions that name a table directly are of interest.
    void visit(ASTTableExpression & node, ASTPtr &)
    {
        ASTPtr & database_and_table = node.database_and_table_name;
        if (database_and_table)
            renameIfNeeded(database_and_table);
    }

private:
    /// Applies the `distributed_product_mode` policy to one table reference.
    ///
    /// Nothing happens unless the table resolves to a storage for which the checker
    /// reports at least two shards. Otherwise:
    ///  - LOCAL: the identifier is redirected to the remote database/table
    ///    (throws DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED if the table has no alias);
    ///  - `prefer_global_in_and_join` or GLOBAL: the enclosing IN / JOIN is made GLOBAL
    ///    (throws LOGICAL_ERROR on an unexpected function name or when neither
    ///    `function` nor `table_join` is set);
    ///  - DENY: throws DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED;
    ///  - anything else: throws LOGICAL_ERROR.
    void renameIfNeeded(ASTPtr & database_and_table)
    {
        const DistributedProductMode distributed_product_mode = getContext()->getSettingsRef().distributed_product_mode;

        StoragePtr storage = tryGetTable(database_and_table, getContext());
        if (!storage || !checker.hasAtLeastTwoShards(*storage))
            return;

        if (distributed_product_mode == DistributedProductMode::LOCAL)
        {
            /// Convert distributed table to corresponding remote table.

            /// first = remote database name, second = remote table name.
            const auto remote_name = checker.getRemoteDatabaseAndTableName(*storage);

            String alias = database_and_table->tryGetAlias();
            if (alias.empty())
                throw Exception(ErrorCodes::DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED,
                                "Distributed table should have an alias when distributed_product_mode set to local");

            auto & identifier = database_and_table->as<ASTTableIdentifier &>();
            /// Remember the original identifier before it is overwritten.
            renamed_tables.emplace_back(identifier.clone());
            identifier.resetTable(remote_name.first, remote_name.second);
        }
        else if (getContext()->getSettingsRef().prefer_global_in_and_join || distributed_product_mode == DistributedProductMode::GLOBAL)
        {
            if (function)
            {
                /// `function` already has the concrete type, so it is used directly.
                if (function->name == "in")
                    function->name = "globalIn";
                else if (function->name == "notIn")
                    function->name = "globalNotIn";
                else if (function->name == "globalIn" || function->name == "globalNotIn")
                {
                    /// Already processed.
                }
                else
                    throw Exception(ErrorCodes::LOGICAL_ERROR, "Logical error: unexpected function name {}", function->name);
            }
            else if (table_join)
                table_join->locality = JoinLocality::Global;
            else
                throw Exception(ErrorCodes::LOGICAL_ERROR, "Logical error: unexpected AST node");
        }
        else if (distributed_product_mode == DistributedProductMode::DENY)
        {
            throw Exception(ErrorCodes::DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED,
                            "Double-distributed IN/JOIN subqueries is denied (distributed_product_mode = 'deny'). "
                            "You may rewrite query to use local tables "
                            "in subqueries, or use GLOBAL keyword, or set distributed_product_mode to suitable value.");
        }
        else
            throw Exception(ErrorCodes::LOGICAL_ERROR, "InJoinSubqueriesPreprocessor: unexpected value of 'distributed_product_mode' setting");
    }
};

using NonGlobalTableMatcher = OneTypeMatcher<NonGlobalTableData>;
using NonGlobalTableVisitor = InDepthNodeVisitor<NonGlobalTableMatcher, true>;


/// Finds non-GLOBAL `in` / `notIn` functions and non-GLOBAL JOIN elements and runs
/// NonGlobalTableVisitor over the tables they reference.
class NonGlobalSubqueryMatcher
{
public:
    struct Data : public WithContext
    {
        /// Each entry pairs a processed AST node (subquery or table expression)
        /// with the clones of the identifiers that were rewritten inside it.
        using RenamedTables = std::vector<std::pair<ASTPtr, std::vector<ASTPtr>>>;

        Data(ContextPtr context_, const CheckShardsAndTables & checker_, RenamedTables & renamed_tables_)
            : WithContext(context_), checker(checker_), renamed_tables(renamed_tables_)
        {
        }

        const CheckShardsAndTables & checker;
        RenamedTables & renamed_tables;
    };

    /// Dispatches on the node type; nodes of any other type are left untouched.
    static void visit(ASTPtr & node, Data & data)
    {
        if (auto * function = node->as<ASTFunction>())
            visit(*function, node, data);
        if (const auto * tables = node->as<ASTTablesInSelectQueryElement>())
            visit(*tables, node, data);
    }

    static bool needChildVisit(ASTPtr & node, const ASTPtr & child)
    {
        if (auto * function = node->as<ASTFunction>())
            if (function->name == "in" || function->name == "notIn")
                return false; /// Processed, process others

        if (const auto * t = node->as<ASTTablesInSelectQueryElement>())
            if (t->table_join && t->table_expression)
                return false; /// Processed, process others

        /// Descent into all children, but not into subqueries of other kind (scalar subqueries), that are irrelevant to us.
        return !child->as<ASTSelectQuery>();
    }

private:
    /// Handles `in` / `notIn`: the second argument is scanned for distributed tables.
    /// Throws BAD_ARGUMENTS when the function does not have exactly two arguments.
    static void visit(ASTFunction & node, ASTPtr &, Data & data)
    {
        if (node.name == "in" || node.name == "notIn")
        {
            if (node.arguments->children.size() != 2)
            {
                throw Exception(ErrorCodes::BAD_ARGUMENTS,
                                "Function '{}' expects two arguments, given: '{}'",
                                node.name, node.formatForErrorMessage());
            }

            auto & subquery = node.arguments->children.at(1);
            std::vector<ASTPtr> renamed;
            NonGlobalTableVisitor::Data table_data(data.getContext(), data.checker, renamed, &node, nullptr);
            NonGlobalTableVisitor(table_data).visit(subquery);
            if (!renamed.empty())
                data.renamed_tables.emplace_back(subquery, std::move(renamed));
        }
    }

    /// Handles a JOIN element whose locality is not already GLOBAL: either its subquery
    /// or its directly named table is scanned for distributed tables.
    static void visit(const ASTTablesInSelectQueryElement & node, ASTPtr &, Data & data)
    {
        if (!node.table_join || !node.table_expression)
            return;

        ASTTableJoin * table_join = node.table_join->as<ASTTableJoin>();
        /// Skip when the node is not an ASTTableJoin (nothing to inspect) or the JOIN is already GLOBAL.
        if (!table_join || table_join->locality == JoinLocality::Global)
            return;

        auto * table = node.table_expression->as<ASTTableExpression>();
        if (!table)
            return;

        if (auto & subquery = table->subquery)
        {
            std::vector<ASTPtr> renamed;
            NonGlobalTableVisitor::Data table_data(data.getContext(), data.checker, renamed, nullptr, table_join);
            NonGlobalTableVisitor(table_data).visit(subquery);
            if (!renamed.empty())
                data.renamed_tables.emplace_back(subquery, std::move(renamed));
        }
        else if (table->database_and_table_name)
        {
            /// `node` is const, while the visitor needs a mutable ASTPtr: use a local copy of the pointer.
            auto tb = node.table_expression;
            std::vector<ASTPtr> renamed;
            NonGlobalTableVisitor::Data table_data{data.getContext(), data.checker, renamed, nullptr, table_join};
            NonGlobalTableVisitor(table_data).visit(tb);
            if (!renamed.empty())
                data.renamed_tables.emplace_back(tb, std::move(renamed));
        }
    }
};

using NonGlobalSubqueryVisitor = InDepthNodeVisitor<NonGlobalSubqueryMatcher, true>;

}


/// Entry point. Does nothing unless `ast` is a SELECT query with tables, `distributed_product_mode`
/// is not ALLOW, and the first table of the query is a directly named table for which the checker
/// reports at least two shards. In that case the whole query is processed by NonGlobalSubqueryVisitor.
void InJoinSubqueriesPreprocessor::visit(ASTPtr & ast) const
{
    if (!ast)
        return;

    checkStackSize();

    ASTSelectQuery * query = ast->as<ASTSelectQuery>();
    if (!query || !query->tables())
        return;

    if (getContext()->getSettingsRef().distributed_product_mode == DistributedProductMode::ALLOW)
        return;

    const auto & tables_in_select_query = query->tables()->as<ASTTablesInSelectQuery &>();
    if (tables_in_select_query.children.empty())
        return;

    const auto & tables_element = tables_in_select_query.children[0]->as<ASTTablesInSelectQueryElement &>();
    if (!tables_element.table_expression)
        return;

    const auto * table_expression = tables_element.table_expression->as<ASTTableExpression>();

    /// If not ordinary table, skip it.
    /// The pointer-form cast returns nullptr on a type mismatch, so it is checked before use.
    if (!table_expression || !table_expression->database_and_table_name)
        return;

    /// If not really distributed table, skip it.
    {
        StoragePtr storage = tryGetTable(table_expression->database_and_table_name, getContext());
        if (!storage || !checker->hasAtLeastTwoShards(*storage))
            return;
    }

    NonGlobalSubqueryVisitor::Data visitor_data{getContext(), *checker, renamed_tables};
    NonGlobalSubqueryVisitor(visitor_data).visit(ast);
}


/// Returns true only when `table` is a StorageDistributed whose shard count is at least 2.
bool InJoinSubqueriesPreprocessor::CheckShardsAndTables::hasAtLeastTwoShards(const IStorage & table) const
{
    const StorageDistributed * distributed = dynamic_cast<const StorageDistributed *>(&table);
    if (!distributed)
        return false;

    return distributed->getShardCount() >= 2;
}


/// Returns {remote database name, remote table name} of a distributed table.
/// `table` must be a StorageDistributed: the reference dynamic_cast throws std::bad_cast otherwise.
std::pair<std::string, std::string>
InJoinSubqueriesPreprocessor::CheckShardsAndTables::getRemoteDatabaseAndTableName(const IStorage & table) const
{
    const StorageDistributed & distributed = dynamic_cast<const StorageDistributed &>(table);
    return { distributed.getRemoteDatabaseName(), distributed.getRemoteTableName() };
}

}