aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/sql/v1/complete/sql_antlr4.cpp
blob: 33c847f3e219285d73048962409e8d30e3a31f0e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include "sql_antlr4.h"

#include <yql/essentials/sql/v1/format/sql_format.h>

#include <yql/essentials/parser/antlr_ast/gen/v1_antlr4/SQLv1Antlr4Lexer.h>
#include <yql/essentials/parser/antlr_ast/gen/v1_antlr4/SQLv1Antlr4Parser.h>
#include <yql/essentials/parser/antlr_ast/gen/v1_ansi_antlr4/SQLv1Antlr4Lexer.h>
#include <yql/essentials/parser/antlr_ast/gen/v1_ansi_antlr4/SQLv1Antlr4Parser.h>

// Expands to the rule index constant `Rule<name>` of the generated parser
// for the given mode, i.e. NALA<mode>Antlr4::SQLv1Antlr4Parser::Rule<name>.
#define RULE_(mode, name) NALA##mode##Antlr4::SQLv1Antlr4Parser::Rule##name

// Shorthand for the rule index taken from the Default-mode parser.
#define RULE(name) RULE_(Default, name)

// Compile-time check that a rule has the same index in the Default and
// the Ansi parsers, so an index obtained via RULE(name) is valid for both.
#define STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(name) \
    static_assert(RULE_(Default, name) == RULE_(Ansi, name))

namespace NSQLComplete {

    class TSqlGrammar: public ISqlGrammar {
    public:
        TSqlGrammar(bool isAnsiLexer)
            : Vocabulary(GetVocabulary(isAnsiLexer))
            , AllTokens(ComputeAllTokens())
            , KeywordTokens(ComputeKeywordTokens())
        {
        }

        const antlr4::dfa::Vocabulary& GetVocabulary() const override {
            return *Vocabulary;
        }

        const std::unordered_set<TTokenId>& GetAllTokens() const override {
            return AllTokens;
        }

        const std::unordered_set<TTokenId>& GetKeywordTokens() const override {
            return KeywordTokens;
        }

        const TVector<TRuleId>& GetKeywordRules() const override {
            static const TVector<TRuleId> KeywordRules = {
                RULE(Keyword),
                RULE(Keyword_expr_uncompat),
                RULE(Keyword_table_uncompat),
                RULE(Keyword_select_uncompat),
                RULE(Keyword_alter_uncompat),
                RULE(Keyword_in_uncompat),
                RULE(Keyword_window_uncompat),
                RULE(Keyword_hint_uncompat),
                RULE(Keyword_as_compat),
                RULE(Keyword_compat),
            };

            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword);
            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword_expr_uncompat);
            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword_table_uncompat);
            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword_select_uncompat);
            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword_alter_uncompat);
            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword_in_uncompat);
            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword_window_uncompat);
            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword_hint_uncompat);
            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword_as_compat);
            STATIC_ASSERT_RULE_ID_MODE_INDEPENDENT(Keyword_compat);

            return KeywordRules;
        }

    private:
        static const antlr4::dfa::Vocabulary* GetVocabulary(bool isAnsiLexer) {
            if (isAnsiLexer) { // Taking a reference is okay as vocabulary storage is static
                return &NALAAnsiAntlr4::SQLv1Antlr4Parser(nullptr).getVocabulary();
            }
            return &NALADefaultAntlr4::SQLv1Antlr4Parser(nullptr).getVocabulary();
        }

        std::unordered_set<TTokenId> ComputeAllTokens() {
            const auto& vocabulary = GetVocabulary();

            std::unordered_set<TTokenId> allTokens;

            for (size_t type = 1; type <= vocabulary.getMaxTokenType(); ++type) {
                allTokens.emplace(type);
            }

            return allTokens;
        }

        std::unordered_set<TTokenId> ComputeKeywordTokens() {
            const auto& vocabulary = GetVocabulary();
            const auto keywords = NSQLFormat::GetKeywords();

            auto keywordTokens = GetAllTokens();
            std::erase_if(keywordTokens, [&](TTokenId token) {
                return !keywords.contains(vocabulary.getSymbolicName(token));
            });
            keywordTokens.erase(TOKEN_EOF);

            return keywordTokens;
        }

        const antlr4::dfa::Vocabulary* Vocabulary;
        const std::unordered_set<TTokenId> AllTokens;
        const std::unordered_set<TTokenId> KeywordTokens;
    };

    // Returns the shared grammar instance matching the requested lexer mode.
    // Both instances are function-local statics, so they are constructed the
    // first time control passes through their declarations and reused after.
    const ISqlGrammar& GetSqlGrammar(bool isAnsiLexer) {
        static const TSqlGrammar defaultGrammar(/* isAnsiLexer = */ false);
        static const TSqlGrammar ansiGrammar(/* isAnsiLexer = */ true);

        return isAnsiLexer ? ansiGrammar : defaultGrammar;
    }

} // namespace NSQLComplete