diff options
author | vitya-smirnov <[email protected]> | 2025-08-04 12:34:52 +0300 |
---|---|---|
committer | vitya-smirnov <[email protected]> | 2025-08-04 12:47:13 +0300 |
commit | 2193f20a163ea765183f000e8d4e1902a96fed71 (patch) | |
tree | 953970cf094d61a2808c7ffa30ef5b8d1b09f6c9 | |
parent | b775b25ed665fee10e5af9a506a151f761cad219 (diff) |
YQL-19747: Partially concat and resolve symbols
- Migrated GetId to a Partial Evaluation function.
- Centralized named expressions collection.
- Centralized query parameters resolution.
- Supported symbols resolution and concatenation.
- Migrated USE and Columns Synthesis to the evaluation.
commit_hash:fafedc4330bcd4a7ab607ae1d5f2691a2f5a60f2
14 files changed, 520 insertions, 164 deletions
diff --git a/yql/essentials/sql/v1/complete/analysis/global/column.cpp b/yql/essentials/sql/v1/complete/analysis/global/column.cpp index 8fbf90e0c36..64c5925ffec 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/column.cpp +++ b/yql/essentials/sql/v1/complete/analysis/global/column.cpp @@ -1,10 +1,9 @@ #include "column.h" #include "base_visitor.h" +#include "evaluate.h" #include "narrowing_visitor.h" -#include <yql/essentials/sql/v1/complete/syntax/format.h> - #include <util/generic/hash_set.h> #include <util/generic/scope.h> @@ -12,72 +11,10 @@ namespace NSQLComplete { namespace { - // TODO: Extract it to `identifier.cpp` and reuse it also at `use.cpp` - // and replace `GetId` at `parse_tree.cpp`. - class TIdentifierVisitor: public SQLv1Antlr4BaseVisitor { - public: - std::any visitCluster_expr(SQLv1::Cluster_exprContext* ctx) override { - if (auto* x = ctx->pure_column_or_named()) { - return visit(x); - } - return {}; - } - - std::any visitTable_key(SQLv1::Table_keyContext* ctx) override { - if (auto* x = ctx->id_table_or_type()) { - return visit(x); - } - return {}; - } - - std::any visitUnary_casual_subexpr(SQLv1::Unary_casual_subexprContext* ctx) override { - std::any prev; - if (auto* x = ctx->id_expr()) { - prev = visit(x); - } else if (auto* x = ctx->atom_expr()) { - prev = visit(x); - } - - std::any next = visit(ctx->unary_subexpr_suffix()); - if (!next.has_value()) { - return prev; - } - - return {}; - } - - std::any visitTerminal(antlr4::tree::TerminalNode* node) override { - switch (node->getSymbol()->getType()) { - case SQLv1::TOKEN_ID_QUOTED: - return TString(Unquoted(GetText(node))); - case SQLv1::TOKEN_ID_PLAIN: - return GetText(node); - } - return {}; - } - - private: - TString GetText(antlr4::tree::ParseTree* tree) const { - return TString(tree->getText()); - } - }; - - TMaybe<TString> GetId(antlr4::ParserRuleContext* ctx) { - if (ctx == nullptr) { - return Nothing(); - } - - std::any result = TIdentifierVisitor().visit(ctx); - if 
(!result.has_value()) { - return Nothing(); - } - return std::any_cast<TString>(result); - } - class TInferenceVisitor: public TSQLv1BaseVisitor { public: - explicit TInferenceVisitor(THashMap<TString, SQLv1::Subselect_stmtContext*> subqueries) - : Subqueries_(std::move(subqueries)) + TInferenceVisitor(const TNamedNodes* nodes) + : Nodes_(nodes) { } @@ -106,8 +43,9 @@ namespace NSQLComplete { } std::any visitTable_ref(SQLv1::Table_refContext* ctx) override { - if (TMaybe<TString> path = GetId(ctx->table_key())) { - TString cluster = GetId(ctx->cluster_expr()).GetOrElse(""); + if (TMaybe<TString> path; (path = GetObjectId(ctx->table_key())) || + (path = GetObjectId(ctx->bind_parameter()))) { + TString cluster = GetObjectId(ctx->cluster_expr()).GetOrElse(""); return TColumnContext{ .Tables = { TTableId{std::move(cluster), std::move(*path)}, @@ -115,19 +53,22 @@ namespace NSQLComplete { }; } - if (TMaybe<TString> named = NSQLComplete::GetId(ctx->bind_parameter())) { - if (auto it = Subqueries_.find(*named); it != Subqueries_.end()) { - if (Resolving_.contains(*named)) { - return {}; - } - - Resolving_.emplace(*named); - Y_DEFER { - Resolving_.erase(*named); - }; + if (TMaybe<TString> named = NSQLComplete::GetName(ctx->bind_parameter())) { + const TNamedNode* node = Nodes_->FindPtr(*named); + if (!node || !std::holds_alternative<SQLv1::Subselect_stmtContext*>(*node)) { + return {}; + } - return visit(it->second); + if (Resolving_.contains(*named)) { + return {}; } + + Resolving_.emplace(*named); + Y_DEFER { + Resolving_.erase(*named); + }; + + return visit(std::get<SQLv1::Subselect_stmtContext*>(*node)); } return {}; @@ -176,7 +117,7 @@ namespace NSQLComplete { } if (ctx->opt_id_prefix() != nullptr && ctx->TOKEN_ASTERISK() != nullptr) { - TMaybe<TString> alias = GetId(ctx->opt_id_prefix()->an_id()); + TMaybe<TString> alias = GetColumnId(ctx->opt_id_prefix()->an_id()); if (alias.Empty()) { return TColumnContext::Asterisk(); } @@ -205,9 +146,9 @@ namespace NSQLComplete { 
}; std::any visitWithout_column_name(SQLv1::Without_column_nameContext* ctx) override { - TString table = GetId(ctx->an_id(0)).GetOrElse(""); - TMaybe<TString> column = GetId(ctx->an_id(1)).Or([&] { - return GetId(ctx->an_id_without()); + TString table = GetObjectId(ctx->an_id(0)).GetOrElse(""); + TMaybe<TString> column = GetColumnId(ctx->an_id(1)).Or([&] { + return GetColumnId(ctx->an_id_without()); }); if (column.Empty()) { @@ -223,8 +164,8 @@ namespace NSQLComplete { private: TMaybe<TString> GetAlias(SQLv1::Named_single_sourceContext* ctx) const { - TMaybe<TString> alias = GetId(ctx->an_id()); - alias = alias.Defined() ? alias : GetId(ctx->an_id_as_compat()); + TMaybe<TString> alias = GetColumnId(ctx->an_id()); + alias = alias.Defined() ? alias : GetColumnId(ctx->an_id_as_compat()); return alias; } @@ -236,7 +177,7 @@ namespace NSQLComplete { id = ctx->an_id_or_type(); id = id ? id : ctx->an_id_as_compat(); } - return GetId(id); + return GetColumnId(id); } TMaybe<TColumnContext> Head(SQLv1::Select_coreContext* ctx) { @@ -269,14 +210,36 @@ namespace NSQLComplete { }); } - THashMap<TString, SQLv1::Subselect_stmtContext*> Subqueries_; + TMaybe<TString> GetColumnId(antlr4::ParserRuleContext* ctx) const { + if (!ctx) { + return Nothing(); + } + + TPartialValue value = PartiallyEvaluate(ctx, *Nodes_); + if (!std::holds_alternative<TIdentifier>(value)) { + return Nothing(); + } + + return std::get<TIdentifier>(value); + } + + TMaybe<TString> GetObjectId(antlr4::ParserRuleContext* ctx) const { + if (!ctx) { + return Nothing(); + } + + return ToObjectRef(PartiallyEvaluate(ctx, *Nodes_)); + } + THashSet<TString> Resolving_; + const TNamedNodes* Nodes_; }; class TVisitor: public TSQLv1NarrowingVisitor { public: - TVisitor(const TParsedInput& input) + TVisitor(const TParsedInput& input, const TNamedNodes* nodes) : TSQLv1NarrowingVisitor(input) + , Nodes_(nodes) { } @@ -304,46 +267,22 @@ namespace NSQLComplete { return {}; } - return 
TInferenceVisitor(std::move(Subqueries_)).visit(source); + return TInferenceVisitor(Nodes_).visit(source); } private: - std::any visitNamed_nodes_stmt(SQLv1::Named_nodes_stmtContext* ctx) override { - TMaybe<std::string> name = Name(ctx->bind_parameter_list()); - if (name.Empty()) { - return {}; - } - - SQLv1::Subselect_stmtContext* subselect = ctx->subselect_stmt(); - if (subselect == nullptr) { - return {}; - } - - Subqueries_[std::move(*name)] = subselect; - return {}; - } - - TMaybe<std::string> Name(SQLv1::Bind_parameter_listContext* ctx) const { - auto parameters = ctx->bind_parameter(); - if (parameters.size() != 1) { - return Nothing(); - } - - return NSQLComplete::GetId(parameters[0]); - } - bool IsEnclosingStrict(antlr4::ParserRuleContext* ctx) const { return ctx != nullptr && IsEnclosing(ctx); } - THashMap<TString, SQLv1::Subselect_stmtContext*> Subqueries_; + const TNamedNodes* Nodes_; }; } // namespace - TMaybe<TColumnContext> InferColumnContext(TParsedInput input) { + TMaybe<TColumnContext> InferColumnContext(TParsedInput input, const TNamedNodes& nodes) { // TODO: add utility `auto ToMaybe<T>(std::any any) -> TMaybe<T>` - std::any result = TVisitor(input).visit(input.SqlQuery); + std::any result = TVisitor(input, &nodes).visit(input.SqlQuery); if (!result.has_value()) { return Nothing(); } diff --git a/yql/essentials/sql/v1/complete/analysis/global/column.h b/yql/essentials/sql/v1/complete/analysis/global/column.h index 306626ba061..0caa617ea7c 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/column.h +++ b/yql/essentials/sql/v1/complete/analysis/global/column.h @@ -2,9 +2,10 @@ #include "global.h" #include "input.h" +#include "named_node.h" namespace NSQLComplete { - TMaybe<TColumnContext> InferColumnContext(TParsedInput input); + TMaybe<TColumnContext> InferColumnContext(TParsedInput input, const TNamedNodes& nodes); } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/analysis/global/evaluate.cpp 
b/yql/essentials/sql/v1/complete/analysis/global/evaluate.cpp index 9ca19429253..c42fef22c24 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/evaluate.cpp +++ b/yql/essentials/sql/v1/complete/analysis/global/evaluate.cpp @@ -1,46 +1,223 @@ #include "evaluate.h" +#include <yql/essentials/sql/v1/complete/syntax/format.h> + +#include <contrib/libs/re2/re2/re2.h> + +#include <util/generic/hash_set.h> +#include <util/generic/scope.h> + namespace NSQLComplete { namespace { class TVisitor: public SQLv1Antlr4BaseVisitor { public: - explicit TVisitor(const TEnvironment* env) - : Env_(env) + explicit TVisitor(const TNamedNodes* nodes) + : Nodes_(nodes) { } - std::any visitBind_parameter(SQLv1::Bind_parameterContext* ctx) override { - TMaybe<std::string> id = GetId(ctx); - if (id.Empty()) { - return defaultResult(); + std::any visitCluster_expr(SQLv1::Cluster_exprContext* ctx) override { + if (auto* x = ctx->pure_column_or_named()) { + return visit(x); + } + return defaultResult(); + } + + std::any visitTable_key(SQLv1::Table_keyContext* ctx) override { + if (auto* x = ctx->id_table_or_type()) { + return visit(x); + } + return defaultResult(); + } + + std::any visitUnary_casual_subexpr(SQLv1::Unary_casual_subexprContext* ctx) override { + TPartialValue prev; + if (auto* x = ctx->id_expr()) { + prev = std::any_cast<TPartialValue>(visit(x)); + } else if (auto* x = ctx->atom_expr()) { + prev = std::any_cast<TPartialValue>(visit(x)); + } + + auto next = std::any_cast<TPartialValue>(visit(ctx->unary_subexpr_suffix())); + if (!IsDefined(next)) { + return prev; + } + + return defaultResult(); + } + + std::any visitMul_subexpr(SQLv1::Mul_subexprContext* ctx) override { + auto args = ctx->con_subexpr(); + Y_ENSURE(!args.empty()); + + if (args.size() == 1) { + return visit(args[0]); } - id->insert(0, "$"); - if (const NYT::TNode* node = Env_->Parameters.FindPtr(*id)) { - return *node; + NYT::TNode result; + for (auto* arg : args) { + if (!arg) { + return defaultResult(); + } 
+ + auto value = std::any_cast<TPartialValue>(visit(arg)); + if (!std::holds_alternative<NYT::TNode>(value)) { + return defaultResult(); + } + + auto node = std::get<NYT::TNode>(value); + auto maybe = Concat(std::move(result), std::move(node)); + if (!maybe) { + return defaultResult(); + } + + result = std::move(*maybe); } + return TPartialValue(std::move(result)); + } + std::any visitTerminal(antlr4::tree::TerminalNode* node) override { + switch (node->getSymbol()->getType()) { + case SQLv1::TOKEN_ID_QUOTED: + return TPartialValue(TString(Unquoted(GetText(node)))); + case SQLv1::TOKEN_ID_PLAIN: + return TPartialValue(GetText(node)); + case SQLv1::TOKEN_STRING_VALUE: + if (auto content = GetContent(node)) { + return TPartialValue(NYT::TNode(std::move(*content))); + } + } return defaultResult(); } + std::any visitBind_parameter(SQLv1::Bind_parameterContext* ctx) override { + TMaybe<std::string> id = GetName(ctx); + if (id.Empty()) { + return defaultResult(); + } + + return EvaluateNode(std::move(*id)); + } + + protected: std::any defaultResult() override { - return NYT::TNode(); + return TPartialValue(std::monostate()); } private: - const TEnvironment* Env_; + TPartialValue EvaluateNode(std::string name) { + const TNamedNode* node = Nodes_->FindPtr(name); + if (!node) { + return std::monostate(); + } + + if (std::holds_alternative<NYT::TNode>(*node)) { + return std::get<NYT::TNode>(*node); + } + + if (std::holds_alternative<SQLv1::ExprContext*>(*node)) { + if (Resolving_.contains(name)) { + return std::monostate(); + } + + Resolving_.emplace(name); + Y_DEFER { + Resolving_.erase(name); + }; + + std::any any = visit(std::get<SQLv1::ExprContext*>(*node)); + return std::any_cast<TPartialValue>(std::move(any)); + } + + return std::monostate(); + } + + TMaybe<NYT::TNode> Concat(NYT::TNode lhs, NYT::TNode rhs) { + if (!lhs.HasValue()) { + return rhs; + } + + NYT::TNode::EType type = rhs.GetType(); + if (type != lhs.GetType()) { + return Nothing(); + } + + switch (type) { 
+ case NYT::TNode::String: + return lhs.AsString() + rhs.AsString(); + case NYT::TNode::Int64: + case NYT::TNode::Uint64: + case NYT::TNode::Double: + case NYT::TNode::Bool: + case NYT::TNode::List: + case NYT::TNode::Map: + case NYT::TNode::Undefined: + case NYT::TNode::Null: + return Nothing(); + } + } + + TIdentifier GetText(antlr4::tree::ParseTree* tree) const { + return TIdentifier(tree->getText()); + } + + TMaybe<TString> GetContent(antlr4::tree::TerminalNode* node) const { + static RE2 regex(R"re(["']([^"'\\]*)["'])re"); + + TString text = GetText(node); + TString content; + if (!RE2::FullMatch(text, regex, &content)) { + return Nothing(); + } + + return content; + } + + THashSet<std::string> Resolving_; + const TNamedNodes* Nodes_; }; - NYT::TNode EvaluateG(antlr4::ParserRuleContext* ctx, const TEnvironment& env) { - return std::any_cast<NYT::TNode>(TVisitor(&env).visit(ctx)); + TPartialValue EvaluateG(antlr4::ParserRuleContext* ctx, const TNamedNodes& nodes) { + return std::any_cast<TPartialValue>(TVisitor(&nodes).visit(ctx)); } } // namespace - NYT::TNode Evaluate(SQLv1::Bind_parameterContext* ctx, const TEnvironment& env) { - return EvaluateG(ctx, env); + bool IsDefined(const TPartialValue& value) { + return !std::holds_alternative<std::monostate>(value); + } + + TMaybe<TString> ToObjectRef(const TPartialValue& value) { + return std::visit([](const auto& value) -> TMaybe<TString> { + using T = std::decay_t<decltype(value)>; + if constexpr (std::is_same_v<T, NYT::TNode>) { + if (!value.IsString()) { + return Nothing(); + } + + return value.AsString(); + } else if constexpr (std::is_same_v<T, TIdentifier>) { + return value; + } else if constexpr (std::is_same_v<T, std::monostate>) { + return Nothing(); + } else { + static_assert(false); + } + }, value); + } + + NYT::TNode Evaluate(SQLv1::Bind_parameterContext* ctx, const TNamedNodes& nodes) { + TPartialValue value = EvaluateG(ctx, nodes); + if (std::holds_alternative<NYT::TNode>(value)) { + return 
std::get<NYT::TNode>(value); + } + return NYT::TNode(); + } + + TPartialValue PartiallyEvaluate(antlr4::ParserRuleContext* ctx, const TNamedNodes& nodes) { + return EvaluateG(ctx, nodes); } } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/analysis/global/evaluate.h b/yql/essentials/sql/v1/complete/analysis/global/evaluate.h index 03cbcdc798b..fd10939f0ea 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/evaluate.h +++ b/yql/essentials/sql/v1/complete/analysis/global/evaluate.h @@ -1,11 +1,25 @@ #pragma once +#include "named_node.h" #include "parse_tree.h" #include <yql/essentials/sql/v1/complete/core/environment.h> namespace NSQLComplete { - NYT::TNode Evaluate(SQLv1::Bind_parameterContext* ctx, const TEnvironment& env); + using TIdentifier = TString; + + using TPartialValue = std::variant< + NYT::TNode, + TIdentifier, + std::monostate>; + + bool IsDefined(const TPartialValue& value); + + TMaybe<TString> ToObjectRef(const TPartialValue& value); + + NYT::TNode Evaluate(SQLv1::Bind_parameterContext* ctx, const TNamedNodes& nodes); + + TPartialValue PartiallyEvaluate(antlr4::ParserRuleContext* ctx, const TNamedNodes& nodes); } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/analysis/global/global.cpp b/yql/essentials/sql/v1/complete/analysis/global/global.cpp index 130338e03de..e394570c35b 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/global.cpp +++ b/yql/essentials/sql/v1/complete/analysis/global/global.cpp @@ -167,10 +167,12 @@ namespace NSQLComplete { .SqlQuery = sqlQuery, }; - ctx.Use = FindUseStatement(parsed, env); - ctx.Names = CollectNamedNodes(parsed); + TNamedNodes nodes = CollectNamedNodes(parsed, env); + + ctx.Use = FindUseStatement(parsed, nodes); + ctx.Names = Keys(nodes); ctx.EnclosingFunction = EnclosingFunction(parsed); - ctx.Column = InferColumnContext(parsed); + ctx.Column = InferColumnContext(parsed, nodes); if (ctx.Use && ctx.Column) { EnrichTableClusters(*ctx.Column, *ctx.Use); 
@@ -196,6 +198,15 @@ namespace NSQLComplete { return Parser_.sql_query(); } + TVector<TString> Keys(const TNamedNodes& nodes) { + TVector<TString> keys; + keys.reserve(nodes.size()); + for (const auto& [name, _] : nodes) { + keys.emplace_back(name); + } + return keys; + } + void EnrichTableClusters(TColumnContext& column, const TUseContext& use) { for (auto& table : column.Tables) { if (table.Cluster.empty()) { diff --git a/yql/essentials/sql/v1/complete/analysis/global/global_ut.cpp b/yql/essentials/sql/v1/complete/analysis/global/global_ut.cpp index 522628bb5d5..4f0d8c8a456 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/global_ut.cpp +++ b/yql/essentials/sql/v1/complete/analysis/global/global_ut.cpp @@ -359,4 +359,90 @@ Y_UNIT_TEST_SUITE(GlobalAnalysisTests) { } } + Y_UNIT_TEST(EvaluationAssignment) { + IGlobalAnalysis::TPtr global = MakeGlobalAnalysis(); + + TString query = R"sql( + DECLARE $t1 AS String; + DECLARE $t2 AS String; + + $x1 = $t1; + $x2 = $t2; + + $y1 = $x1; + $y2 = $x2; + + $q1 = (SELECT * FROM $y1); + $q2 = (SELECT * FROM $y2); + + SELECT # FROM $q1 JOIN $q2 + )sql"; + + TEnvironment env = { + .Parameters = { + {"$t1", "table1"}, + {"$t2", "table2"}, + }, + }; + + TGlobalContext ctx = global->Analyze(SharpedInput(query), env); + + TColumnContext expected = { + .Tables = { + TAliased<TTableId>("", {"", "table1"}), + TAliased<TTableId>("", {"", "table2"}), + }, + }; + UNIT_ASSERT_VALUES_EQUAL(ctx.Column, expected); + } + + Y_UNIT_TEST(EvaluationRecursion) { + IGlobalAnalysis::TPtr global = MakeGlobalAnalysis(); + + TVector<TString> queries = { + R"sql($x = $x; SELECT # FROM $x)sql", + R"sql($x = $y; $y = $x; SELECT # FROM $x)sql", + }; + + for (TString& query : queries) { + TGlobalContext ctx = global->Analyze(SharpedInput(query), {}); + Y_DO_NOT_OPTIMIZE_AWAY(ctx); + } + } + + Y_UNIT_TEST(EvaluationSubquery) { + IGlobalAnalysis::TPtr global = MakeGlobalAnalysis(); + + TString query = R"sql( + $x = (SELECT * FROM $y); + $y = (SELECT * 
FROM $x); + $z = $x || $y; + SELECT # FROM $z; + )sql"; + + TGlobalContext ctx = global->Analyze(SharpedInput(query), {}); + + TColumnContext expected = {}; + UNIT_ASSERT_VALUES_EQUAL(ctx.Column, expected); + } + + Y_UNIT_TEST(EvaluationStringConcat) { + IGlobalAnalysis::TPtr global = MakeGlobalAnalysis(); + + TString query = R"sql( + $cluster = 'ex' || 'am' || "ple"; + $product = "yql"; + $seq = "1"; + $source = "/home/" || $product || "/" || $seq; + SELECT # FROM $cluster.$source; + )sql"; + + TGlobalContext ctx = global->Analyze(SharpedInput(query), {}); + + TColumnContext expected = { + .Tables = {TAliased<TTableId>("", {"example", "/home/yql/1"})}, + }; + UNIT_ASSERT_VALUES_EQUAL(ctx.Column, expected); + } + } // Y_UNIT_TEST_SUITE(GlobalAnalysisTests) diff --git a/yql/essentials/sql/v1/complete/analysis/global/named_node.cpp b/yql/essentials/sql/v1/complete/analysis/global/named_node.cpp index 57223e73405..792bd3b04cc 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/named_node.cpp +++ b/yql/essentials/sql/v1/complete/analysis/global/named_node.cpp @@ -14,9 +14,10 @@ namespace NSQLComplete { class TVisitor: public TSQLv1NarrowingVisitor { public: - TVisitor(const TParsedInput& input, THashSet<TString>* names) + TVisitor(const TParsedInput& input, TNamedNodes* names, const TEnvironment* env) : TSQLv1NarrowingVisitor(input) , Names_(names) + , Env_(env) { } @@ -32,7 +33,26 @@ namespace NSQLComplete { } std::any visitDeclare_stmt(SQLv1::Declare_stmtContext* ctx) override { - VisitNullable(ctx->bind_parameter()); + auto* parameter = ctx->bind_parameter(); + if (!parameter) { + return {}; + } + + TMaybe<std::string> id = GetName(parameter); + if (id.Empty() || id == "_") { + return {}; + } + + id->insert(0, "$"); + const NYT::TNode* node = Env_->Parameters.FindPtr(*id); + id->erase(0, 1); + + if (node) { + (*Names_)[*id] = *node; + } else { + (*Names_)[*id] = std::monostate(); + } + return {}; } @@ -53,8 +73,37 @@ namespace NSQLComplete { std::any 
visitNamed_nodes_stmt(SQLv1::Named_nodes_stmtContext* ctx) override { VisitNullable(ctx->bind_parameter_list()); if (IsEnclosing(ctx)) { - return visitChildren(ctx); + visitChildren(ctx); + } + + auto* list = ctx->bind_parameter_list(); + if (!list) { + return {}; + } + + auto parameters = list->bind_parameter(); + if (parameters.size() != 1) { + return {}; + } + + auto* parameter = parameters[0]; + if (!parameter) { + return {}; } + + TMaybe<std::string> id = GetName(parameter); + if (id.Empty() || id == "_") { + return {}; + } + + if (auto* expr = ctx->expr()) { + (*Names_)[std::move(*id)] = expr; + } else if (auto* subselect = ctx->subselect_stmt()) { + (*Names_)[std::move(*id)] = subselect; + } else { + (*Names_)[std::move(*id)] = std::monostate(); + } + return {}; } @@ -85,12 +134,12 @@ namespace NSQLComplete { return {}; } - TMaybe<std::string> id = GetId(ctx); + TMaybe<std::string> id = GetName(ctx); if (id.Empty() || id == "_") { return {}; } - Names_->emplace(std::move(*id)); + (*Names_)[std::move(*id)] = std::monostate(); return {}; } @@ -102,15 +151,16 @@ namespace NSQLComplete { visit(tree); } - THashSet<TString>* Names_; + TNamedNodes* Names_; + const TEnvironment* Env_; }; } // namespace - TVector<TString> CollectNamedNodes(TParsedInput input) { - THashSet<TString> names; - TVisitor(input, &names).visit(input.SqlQuery); - return TVector<TString>(begin(names), end(names)); + TNamedNodes CollectNamedNodes(TParsedInput input, const TEnvironment& env) { + TNamedNodes names; + TVisitor(input, &names, &env).visit(input.SqlQuery); + return names; } } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/analysis/global/named_node.h b/yql/essentials/sql/v1/complete/analysis/global/named_node.h index 55466765013..07e1e6f91b9 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/named_node.h +++ b/yql/essentials/sql/v1/complete/analysis/global/named_node.h @@ -2,11 +2,23 @@ #include "input.h" +#include 
<yql/essentials/sql/v1/complete/core/environment.h> + +#include <library/cpp/yson/node/node.h> + #include <util/generic/string.h> #include <util/generic/vector.h> namespace NSQLComplete { - TVector<TString> CollectNamedNodes(TParsedInput input); + using TNamedNode = std::variant< + SQLv1::ExprContext*, + SQLv1::Subselect_stmtContext*, + NYT::TNode, + std::monostate>; + + using TNamedNodes = THashMap<TString, TNamedNode>; + + TNamedNodes CollectNamedNodes(TParsedInput input, const TEnvironment& env); } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/analysis/global/parse_tree.cpp b/yql/essentials/sql/v1/complete/analysis/global/parse_tree.cpp index 6a87a540066..3577a6cf8d1 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/parse_tree.cpp +++ b/yql/essentials/sql/v1/complete/analysis/global/parse_tree.cpp @@ -1,11 +1,13 @@ #include "parse_tree.h" +#include <yql/essentials/sql/v1/complete/syntax/format.h> + #include <util/system/yassert.h> #include <util/generic/maybe.h> namespace NSQLComplete { - TMaybe<std::string> GetId(SQLv1::Bind_parameterContext* ctx) { + TMaybe<std::string> GetName(SQLv1::Bind_parameterContext* ctx) { if (ctx == nullptr) { return Nothing(); } diff --git a/yql/essentials/sql/v1/complete/analysis/global/parse_tree.h b/yql/essentials/sql/v1/complete/analysis/global/parse_tree.h index f194232aea5..72423526547 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/parse_tree.h +++ b/yql/essentials/sql/v1/complete/analysis/global/parse_tree.h @@ -17,6 +17,6 @@ namespace NSQLComplete { using NALADefaultAntlr4::SQLv1Antlr4BaseVisitor; - TMaybe<std::string> GetId(SQLv1::Bind_parameterContext* ctx); + TMaybe<std::string> GetName(SQLv1::Bind_parameterContext* ctx); } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/analysis/global/use.cpp b/yql/essentials/sql/v1/complete/analysis/global/use.cpp index ce592b724c3..4d3a98414ba 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/use.cpp +++ 
b/yql/essentials/sql/v1/complete/analysis/global/use.cpp @@ -9,9 +9,9 @@ namespace NSQLComplete { class TVisitor: public TSQLv1NarrowingVisitor { public: - TVisitor(const TParsedInput& input, const TEnvironment* env) + TVisitor(const TParsedInput& input, const TNamedNodes* nodes) : TSQLv1NarrowingVisitor(input) - , Env_(env) + , Nodes_(nodes) { } @@ -63,21 +63,21 @@ namespace NSQLComplete { } TMaybe<TString> GetId(SQLv1::Bind_parameterContext* ctx) const { - NYT::TNode node = Evaluate(ctx, *Env_); + NYT::TNode node = Evaluate(ctx, *Nodes_); if (!node.HasValue() || !node.IsString()) { return Nothing(); } return node.AsString(); } - const TEnvironment* Env_; + const TNamedNodes* Nodes_; }; } // namespace // TODO(YQL-19747): Use any to maybe conversion function - TMaybe<TUseContext> FindUseStatement(TParsedInput input, const TEnvironment& env) { - std::any result = TVisitor(input, &env).visit(input.SqlQuery); + TMaybe<TUseContext> FindUseStatement(TParsedInput input, const TNamedNodes& nodes) { + std::any result = TVisitor(input, &nodes).visit(input.SqlQuery); if (!result.has_value()) { return Nothing(); } diff --git a/yql/essentials/sql/v1/complete/analysis/global/use.h b/yql/essentials/sql/v1/complete/analysis/global/use.h index 964519e4e99..411e0f9fcd8 100644 --- a/yql/essentials/sql/v1/complete/analysis/global/use.h +++ b/yql/essentials/sql/v1/complete/analysis/global/use.h @@ -2,6 +2,7 @@ #include "global.h" #include "input.h" +#include "named_node.h" #include <util/generic/ptr.h> #include <util/generic/maybe.h> @@ -9,6 +10,6 @@ namespace NSQLComplete { - TMaybe<TUseContext> FindUseStatement(TParsedInput input, const TEnvironment& env); + TMaybe<TUseContext> FindUseStatement(TParsedInput input, const TNamedNodes& nodes); } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/analysis/global/ya.make b/yql/essentials/sql/v1/complete/analysis/global/ya.make index cf2a9c9aada..9a6d824ca1e 100644 --- 
a/yql/essentials/sql/v1/complete/analysis/global/ya.make +++ b/yql/essentials/sql/v1/complete/analysis/global/ya.make @@ -19,6 +19,7 @@ PEERDIR( yql/essentials/sql/v1/complete/text yql/essentials/parser/antlr_ast/gen/v1_antlr4 yql/essentials/parser/antlr_ast/gen/v1_ansi_antlr4 + contrib/libs/re2 ) END() diff --git a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp index d1dd837f9f9..604af6008d9 100644 --- a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp +++ b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp @@ -225,19 +225,24 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) { Y_UNIT_TEST(UseClusterResultion) { auto engine = MakeSqlCompletionEngineUT(); { + TString query = R"sql( + DECLARE $cluster_name AS String; + USE yt:$cluster_name; + SELECT * FROM # + )sql"; + + TEnvironment env = { + .Parameters = {{"$cluster_name", "saurus"}}, + }; + TVector<TCandidate> expected = { + {BindingName, "$cluster_name"}, {TableName, "`maxim`"}, {ClusterName, "example"}, {ClusterName, "saurus"}, - {Keyword, "ANY"}, }; - UNIT_ASSERT_VALUES_EQUAL( - CompleteTop( - 4, - engine, - "USE yt:$cluster_name; SELECT * FROM ", - {.Parameters = {{"$cluster_name", "saurus"}}}), - expected); + + UNIT_ASSERT_VALUES_EQUAL(CompleteTop(4, engine, query, env), expected); } { TVector<TCandidate> expected = { @@ -1310,6 +1315,63 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) { } } + Y_UNIT_TEST(ColumnsFromNamedExpr) { + auto engine = MakeSqlCompletionEngineUT(); + { + TVector<TString> queries = { + R"sql(SELECT # FROM $)sql", + R"sql(SELECT # FROM $$)sql", + R"sql(SELECT # FROM $x)sql", + }; + + TVector<TCandidate> expected = { + {Keyword, "ALL"}, + }; + + UNIT_ASSERT_VALUES_EQUAL(CompleteTop(1, engine, queries[0]), expected); + UNIT_ASSERT_VALUES_EQUAL(CompleteTop(1, engine, queries[1]), expected); + UNIT_ASSERT_VALUES_EQUAL(CompleteTop(1, engine, queries[2]), expected); + } + { + TString declare = R"sql(DECLARE $x AS String;)sql"; + + TVector<TString> 
queries = { + declare + R"sql(SELECT # FROM example.$x)sql", + declare + R"sql(USE example; SELECT # FROM $x)sql", + }; + + TVector<TCandidate> expected = { + {BindingName, "$x"}, + {ColumnName, "Age"}, + {ColumnName, "Name"}, + }; + + TEnvironment env = { + .Parameters = {{"$x", "/people"}}, + }; + + UNIT_ASSERT_VALUES_EQUAL(CompleteTop(3, engine, queries[0], env), expected); + UNIT_ASSERT_VALUES_EQUAL(CompleteTop(3, engine, queries[1], env), expected); + } + { + TString query = R"sql( + USE example; + SELECT # FROM $x; + )sql"; + + TEnvironment env = { + .Parameters = {{"$x", "/people"}}, + }; + + TVector<TCandidate> expected = { + {Keyword, "ALL"}, + }; + + UNIT_ASSERT_VALUES_EQUAL(CompleteTop(1, engine, query, {}), expected); + UNIT_ASSERT_VALUES_EQUAL(CompleteTop(1, engine, query, env), expected); + } + } + Y_UNIT_TEST(ColumnPositions) { auto engine = MakeSqlCompletionEngineUT(); |