aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/sql/v1/complete/sql_context.cpp
blob: 18f676e40b7147a361bb60234ffa393400ca48d8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include "sql_context.h"

#include "c3_engine.h"
#include "sql_syntax.h"

#include <yql/essentials/parser/antlr_ast/gen/v1_antlr4/SQLv1Antlr4Lexer.h>
#include <yql/essentials/parser/antlr_ast/gen/v1_antlr4/SQLv1Antlr4Parser.h>
#include <yql/essentials/parser/antlr_ast/gen/v1_ansi_antlr4/SQLv1Antlr4Lexer.h>
#include <yql/essentials/parser/antlr_ast/gen/v1_ansi_antlr4/SQLv1Antlr4Parser.h>

#include <util/generic/algorithm.h>
#include <util/stream/output.h>

namespace NSQLComplete {

    template <bool IsAnsiLexer>
    class TSpecializedSqlContextInference: public ISqlContextInference {
    private:
        using TDefaultYQLGrammar = TAntlrGrammar<
            NALADefaultAntlr4::SQLv1Antlr4Lexer,
            NALADefaultAntlr4::SQLv1Antlr4Parser>;

        using TAnsiYQLGrammar = TAntlrGrammar<
            NALAAnsiAntlr4::SQLv1Antlr4Lexer,
            NALAAnsiAntlr4::SQLv1Antlr4Parser>;

        using G = std::conditional_t<
            IsAnsiLexer,
            TAnsiYQLGrammar,
            TDefaultYQLGrammar>;

    public:
        TSpecializedSqlContextInference()
            : Grammar(&GetSqlGrammar(IsAnsiLexer))
            , C3(ComputeC3Config())
        {
        }

        TCompletionContext Analyze(TCompletionInput input) override {
            auto prefix = input.Text.Head(input.CursorPosition);
            auto tokens = C3.Complete(prefix);
            FilterIdKeywords(tokens);
            return {
                .Keywords = SiftedKeywords(tokens),
            };
        }

    private:
        IC3Engine::TConfig ComputeC3Config() {
            return {
                .IgnoredTokens = ComputeIgnoredTokens(),
                .PreferredRules = ComputePreferredRules(),
            };
        }

        std::unordered_set<TTokenId> ComputeIgnoredTokens() {
            auto ignoredTokens = Grammar->GetAllTokens();
            for (auto keywordToken : Grammar->GetKeywordTokens()) {
                ignoredTokens.erase(keywordToken);
            }
            return ignoredTokens;
        }

        std::unordered_set<TRuleId> ComputePreferredRules() {
            const auto& keywordRules = Grammar->GetKeywordRules();

            std::unordered_set<TRuleId> preferredRules;
            preferredRules.insert(std::begin(keywordRules), std::end(keywordRules));
            return preferredRules;
        }

        void FilterIdKeywords(TVector<TSuggestedToken>& tokens) {
            const auto& keywordRules = Grammar->GetKeywordRules();
            auto [first, last] = std::ranges::remove_if(tokens, [&](const TSuggestedToken& token) {
                return AnyOf(token.ParserCallStack, [&](TRuleId rule) {
                    return Find(keywordRules, rule) != std::end(keywordRules);
                });
            });
            tokens.erase(first, last);
        }

        TVector<TString> SiftedKeywords(const TVector<TSuggestedToken>& tokens) {
            const auto& vocabulary = Grammar->GetVocabulary();
            const auto& keywordTokens = Grammar->GetKeywordTokens();

            TVector<TString> keywords;
            for (const auto& token : tokens) {
                if (keywordTokens.contains(token.Number)) {
                    keywords.emplace_back(vocabulary.getDisplayName(token.Number));
                }
            }
            return keywords;
        }

        const ISqlGrammar* Grammar;
        TC3Engine<G> C3;
    };

    // Dialect-dispatching inference: holds one engine per lexer dialect and
    // forwards each request to the engine chosen by IsAnsiQuery() on the full
    // input text, not just the part before the cursor.
    class TSqlContextInference: public ISqlContextInference {
    public:
        TCompletionContext Analyze(TCompletionInput input) override {
            const bool isAnsiLexer = IsAnsiQuery(TString(input.Text));
            return GetSpecializedEngine(isAnsiLexer).Analyze(std::move(input));
        }

    private:
        // Returns the engine built for the requested dialect.
        ISqlContextInference& GetSpecializedEngine(bool isAnsiLexer) {
            if (!isAnsiLexer) {
                return DefaultEngine;
            }
            return AnsiEngine;
        }

        // Both engines are members, so both are constructed with this object.
        TSpecializedSqlContextInference</* IsAnsiLexer = */ false> DefaultEngine;
        TSpecializedSqlContextInference</* IsAnsiLexer = */ true> AnsiEngine;
    };

    // Factory: wraps a freshly allocated dialect-dispatching inference in the
    // interface's TPtr handle.
    ISqlContextInference::TPtr MakeSqlContextInference() {
        ISqlContextInference::TPtr inference(new TSqlContextInference());
        return inference;
    }

} // namespace NSQLComplete