diff options
author | vityaman <[email protected]> | 2025-04-02 15:37:42 +0300 |
---|---|---|
committer | robot-piglet <[email protected]> | 2025-04-02 15:53:10 +0300 |
commit | 9728f9489a1c25e2b2e1e7959fa662a389f68db8 (patch) | |
tree | ce9061c436c9e73f7a6119119349f0be2c7b4b61 /yql/essentials/sql/v1 | |
parent | 18a74248135f1108b545fb7e04607445925b764d (diff) |
YQL-19747 Introduce types and functions ranking
- [x] Fix bug with incorrect no-case sorting.
- [x] Get names from `sql_functions.json` and `types.json`.
- [x] Add types and functions ranking according to `rules_corr_basic.json` data via a `PartialSort`.
- [x] Add benchmark workspace.
---
Pull Request resolved: https://github.com/ytsaurus/ytsaurus/pull/1167
commit_hash:84d93265fb69bf5651f905d6af038056657e9a16
Diffstat (limited to 'yql/essentials/sql/v1')
18 files changed, 526 insertions, 147 deletions
diff --git a/yql/essentials/sql/v1/complete/bench/main.cpp b/yql/essentials/sql/v1/complete/bench/main.cpp new file mode 100644 index 00000000000..ace37c43e35 --- /dev/null +++ b/yql/essentials/sql/v1/complete/bench/main.cpp @@ -0,0 +1,40 @@ +#include <benchmark/benchmark.h> + +#include <yql/essentials/sql/v1/complete/name/static/name_service.h> +#include <yql/essentials/sql/v1/complete/name/static/ranking.h> +#include <yql/essentials/sql/v1/complete/sql_complete.h> + +#include <yql/essentials/sql/v1/lexer/antlr4_pure/lexer.h> +#include <yql/essentials/sql/v1/lexer/antlr4_pure_ansi/lexer.h> + +#include <util/generic/xrange.h> +#include <util/system/compiler.h> + +namespace NSQLComplete { + + NSQLComplete::TLexerSupplier MakePureLexerSupplier() { + NSQLTranslationV1::TLexers lexers; + lexers.Antlr4Pure = NSQLTranslationV1::MakeAntlr4PureLexerFactory(); + lexers.Antlr4PureAnsi = NSQLTranslationV1::MakeAntlr4PureAnsiLexerFactory(); + return [lexers = std::move(lexers)](bool ansi) { + return NSQLTranslationV1::MakeLexer( + lexers, ansi, /* antlr4 = */ true, + NSQLTranslationV1::ELexerFlavor::Pure); + }; + } + + void BenchmarkComplete(benchmark::State& state) { + auto names = NSQLComplete::MakeDefaultNameSet(); + auto ranking = NSQLComplete::MakeDefaultRanking(); + auto service = MakeStaticNameService(std::move(names), std::move(ranking)); + auto engine = MakeSqlCompletionEngine(MakePureLexerSupplier(), std::move(service)); + + for (const auto _ : state) { + auto completion = engine->Complete({"SELECT "}); + benchmark::DoNotOptimize(completion); + } + } + +} // namespace NSQLComplete + +BENCHMARK(NSQLComplete::BenchmarkComplete); diff --git a/yql/essentials/sql/v1/complete/bench/ya.make b/yql/essentials/sql/v1/complete/bench/ya.make new file mode 100644 index 00000000000..aec2f394c82 --- /dev/null +++ b/yql/essentials/sql/v1/complete/bench/ya.make @@ -0,0 +1,13 @@ +G_BENCHMARK() + +SRCS( + main.cpp +) + +PEERDIR( + yql/essentials/sql/v1/complete + yql/essentials/sql/v1/complete/name/static + yql/essentials/sql/v1/lexer +) + +END() diff --git a/yql/essentials/sql/v1/complete/name/static/default_name_set.cpp b/yql/essentials/sql/v1/complete/name/static/default_name_set.cpp deleted file mode 100644 index f04c9180a3c..00000000000 --- a/yql/essentials/sql/v1/complete/name/static/default_name_set.cpp +++ /dev/null @@ -1,73 +0,0 @@ -#include "name_service.h" - -namespace NSQLComplete { - - // TODO(YQL-19747): Use some name registry - NameSet MakeDefaultNameSet() { - return { - .Types = { - "Bool", - "Int8", - "Uint8", - "Int16", - "Uint16", - "Int32", - "Uint32", - "Int64", - "Uint64", - "Float", - "Double", - "String", - "Utf8", - "Yson", - "Json", - "Uuid", - "JsonDocument", - "Date", - "Datetime", - "Timestamp", - "Interval", - "TzDate", - "TzDatetime", - "TzTimestamp", - "Date32", - "Datetime64", - "Timestamp64", - "Interval64", - "TzDate32", - "TzDatetime64", - "TzTimestamp64", - "Decimal", - "DyNumber", - }, - .Functions = { - "COALESCE", - "LENGTH", - "SUBSTRING", - "FIND", - "RFIND", - "StartsWith", - "EndsWith", - "IF", - "NANVL", - "Random", - "RandomNumber", - "RandomUuid", - "CurrentUtcDate", - "CurrentUtcDatetime", - "CurrentUtcTimestamp", - "CurrentTzDate", - "CurrentTzDatetime", - "CurrentTzTimestamp", - "AddTimezone", - "RemoveTimezone", - "Version", - "MAX_OF", - "MIN_OF", - "GREATEST", - "LEAST", - }, - }; - } - -} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/static/frequency.cpp b/yql/essentials/sql/v1/complete/name/static/frequency.cpp new file mode 100644 index 00000000000..d9c8ba9652c --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/static/frequency.cpp @@ -0,0 +1,87 @@ +#include "frequency.h" + +#include <library/cpp/json/json_reader.h> +#include <library/cpp/resource/resource.h> + +#include <util/charset/utf8.h> + +namespace NSQLComplete { + + constexpr struct { + struct { + const char* Parent = "parent"; + const char* Rule = "rule"; + const char* Sum = "sum"; + } Key; + struct { + const char* Type = "TYPE"; + const char* Func = "FUNC"; + const char* Module = "MODULE"; + const char* ModuleFunc = "MODULE_FUNC"; + } Parent; + } Json; + + struct TFrequencyItem { + TString Parent; + TString Rule; + size_t Sum; + + static TFrequencyItem ParseJsonMap(NJson::TJsonValue::TMapType&& json) { + return { + .Parent = json.at(Json.Key.Parent).GetStringSafe(), + .Rule = json.at(Json.Key.Rule).GetStringSafe(), + .Sum = json.at(Json.Key.Sum).GetUIntegerSafe(), + }; + } + + static TVector<TFrequencyItem> ParseListFromJsonArray(NJson::TJsonValue::TArray& json) { + TVector<TFrequencyItem> items; + items.reserve(json.size()); + for (auto& element : json) { + auto item = TFrequencyItem::ParseJsonMap(std::move(element.GetMapSafe())); + items.emplace_back(std::move(item)); + } + return items; + } + + static TVector<TFrequencyItem> ParseListFromJsonText(const TStringBuf text) { + NJson::TJsonValue json = NJson::ReadJsonFastTree(text); + return ParseListFromJsonArray(json.GetArraySafe()); + } + }; + + TFrequencyData Convert(TVector<TFrequencyItem> items) { + TFrequencyData data; + for (auto& item : items) { + if (item.Parent == Json.Parent.Type || + item.Parent == Json.Parent.Func || + item.Parent == Json.Parent.ModuleFunc || + item.Parent == Json.Parent.Module) { + item.Rule = ToLowerUTF8(item.Rule); + } + + if (item.Parent == Json.Parent.Type) { + data.Types[item.Rule] += item.Sum; + } else if (item.Parent == Json.Parent.Func || + item.Parent == Json.Parent.ModuleFunc) { + data.Functions[item.Rule] += item.Sum; + } else if (item.Parent == Json.Parent.Module) { + // Ignore, unsupported: Modules + } else { + // Ignore, unsupported: Parser Call Stacks + } + } + return data; + } + + TFrequencyData ParseJsonFrequencyData(const TStringBuf text) { + return Convert(TFrequencyItem::ParseListFromJsonText(text)); + } + + TFrequencyData LoadFrequencyData() { + TString text; + Y_ENSURE(NResource::FindExact("rules_corr_basic.json", &text)); + return ParseJsonFrequencyData(text); + } + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/static/frequency.h b/yql/essentials/sql/v1/complete/name/static/frequency.h new file mode 100644 index 00000000000..3d128f824b4 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/static/frequency.h @@ -0,0 +1,17 @@ +#pragma once + +#include <util/generic/string.h> +#include <util/generic/hash.h> + +namespace NSQLComplete { + + struct TFrequencyData { + THashMap<TString, size_t> Types; + THashMap<TString, size_t> Functions; + }; + + TFrequencyData ParseJsonFrequencyData(const TStringBuf text); + + TFrequencyData LoadFrequencyData(); + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/static/frequency_ut.cpp b/yql/essentials/sql/v1/complete/name/static/frequency_ut.cpp new file mode 100644 index 00000000000..dd6ee2cfbb2 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/static/frequency_ut.cpp @@ -0,0 +1,37 @@ +#include "frequency.h" + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NSQLComplete; + +Y_UNIT_TEST_SUITE(FrequencyTests) { + + Y_UNIT_TEST(FrequencyDataJson) { + TFrequencyData actual = ParseJsonFrequencyData(R"([ + {"parent":"FUNC","rule":"ABC","sum":1}, + {"parent":"TYPE","rule":"BIGINT","sum":7101}, + {"parent":"MODULE_FUNC","rule":"Compress::BZip2","sum":2}, + {"parent":"MODULE","rule":"re2","sum":3094}, + {"parent":"TRule_action_or_subquery_args","rule":"TRule_action_or_subquery_args.Block2","sum":4874480} + ])"); + + TFrequencyData expected = { + .Types = { + {"bigint", 7101}, + }, + .Functions = { + {"abc", 1}, + {"compress::bzip2", 2}, + }, + }; + + UNIT_ASSERT_VALUES_EQUAL(actual.Types, expected.Types); + UNIT_ASSERT_VALUES_EQUAL(actual.Functions, expected.Functions); + } + + Y_UNIT_TEST(FrequencyDataResouce) { + TFrequencyData data = LoadFrequencyData(); + Y_UNUSED(data); + } + +} // Y_UNIT_TEST_SUITE(FrequencyTests) diff --git a/yql/essentials/sql/v1/complete/name/static/json_name_set.cpp b/yql/essentials/sql/v1/complete/name/static/json_name_set.cpp new file mode 100644 index 00000000000..29c303b3102 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/static/json_name_set.cpp @@ -0,0 +1,58 @@ +#include "name_service.h" + +#include <library/cpp/json/json_reader.h> +#include <library/cpp/resource/resource.h> + +namespace NSQLComplete { + + NJson::TJsonValue LoadJsonResource(const TStringBuf filename) { + TString text; + Y_ENSURE(NResource::FindExact(filename, &text)); + return NJson::ReadJsonFastTree(text); + } + + template <class T, class U> + T Merge(T lhs, U rhs) { + std::copy(std::begin(rhs), std::end(rhs), std::back_inserter(lhs)); + return lhs; + } + + TVector<TString> ParseNames(NJson::TJsonValue::TArray& json) { + TVector<TString> keys; + keys.reserve(json.size()); + for (auto& item : json) { + keys.emplace_back(item.GetMapSafe().at("name").GetStringSafe()); + } + return keys; + } + + TVector<TString> ParseTypes(NJson::TJsonValue json) { + return ParseNames(json.GetArraySafe()); + } + + TVector<TString> ParseFunctions(NJson::TJsonValue json) { + return ParseNames(json.GetArraySafe()); + } + + TVector<TString> ParseUfs(NJson::TJsonValue json) { + TVector<TString> names; + for (auto& [module, v] : json.GetMapSafe()) { + auto functions = ParseNames(v.GetArraySafe()); + for (auto& function : functions) { + function.prepend("::").prepend(module); + } + std::copy(std::begin(functions), std::end(functions), std::back_inserter(names)); + } + return names; + } + + NameSet MakeDefaultNameSet() { + return { + .Types = ParseTypes(LoadJsonResource("types.json")), + .Functions = Merge( + ParseFunctions(LoadJsonResource("sql_functions.json")), + ParseUfs(LoadJsonResource("udfs_basic.json"))), + }; + } + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/static/name_service.cpp b/yql/essentials/sql/v1/complete/name/static/name_service.cpp index fa2bbfb360b..fdb1dd4eae1 100644 --- a/yql/essentials/sql/v1/complete/name/static/name_service.cpp +++ b/yql/essentials/sql/v1/complete/name/static/name_service.cpp @@ -1,18 +1,30 @@ #include "name_service.h" +#include "ranking.h" + namespace NSQLComplete { + bool NoCaseCompare(const TString& lhs, const TString& rhs) { + return std::lexicographical_compare( + std::begin(lhs), std::end(lhs), + std::begin(rhs), std::end(rhs), + [](const char lhs, const char rhs) { + return ToLower(lhs) < ToLower(rhs); + }); + } + + auto NoCaseCompareLimit(size_t size) { + return [size](const TString& lhs, const TString& rhs) -> bool { + return strncasecmp(lhs.data(), rhs.data(), size) < 0; + }; + } + const TVector<TStringBuf> FilteredByPrefix( const TString& prefix, const TVector<TString>& sorted Y_LIFETIME_BOUND) { auto [first, last] = EqualRange( - std::begin(sorted), - std::end(sorted), - prefix, - [&](const TString& lhs, const TString& rhs) { - return strncasecmp(lhs.data(), rhs.data(), prefix.size()) < 0; - }); - + std::begin(sorted), std::end(sorted), + prefix, NoCaseCompareLimit(prefix.size())); return TVector<TStringBuf>(first, last); } @@ -23,55 +35,23 @@ namespace NSQLComplete { } } - size_t KindWeight(const TGenericName& name) { - return std::visit([](const auto& name) { - using T = std::decay_t<decltype(name)>; - if constexpr (std::is_same_v<T, TFunctionName>) { - return 1; - } - if constexpr (std::is_same_v<T, TTypeName>) { - return 2; - } - }, name); - } - - const TStringBuf ContentView(const TGenericName& name Y_LIFETIME_BOUND) { - return std::visit([](const auto& name) -> TStringBuf { - using T = std::decay_t<decltype(name)>; - if constexpr (std::is_base_of_v<TIndentifier, T>) { - return name.Indentifier; - } - }, name); - } - - void Sort(TVector<TGenericName>& names) { - Sort(names, [](const TGenericName& lhs, const TGenericName& rhs) { - const auto lhs_weight = KindWeight(lhs); - const auto lhs_content = ContentView(lhs); - - const auto rhs_weight = KindWeight(rhs); - const auto rhs_content = ContentView(rhs); - - return std::tie(lhs_weight, lhs_content) < - std::tie(rhs_weight, rhs_content); - }); - } - class TStaticNameService: public INameService { public: - explicit TStaticNameService(NameSet names) + explicit TStaticNameService(NameSet names, IRanking::TPtr ranking) : NameSet_(std::move(names)) + , Ranking_(std::move(ranking)) { - Sort(NameSet_.Types); - Sort(NameSet_.Functions); + Sort(NameSet_.Types, NoCaseCompare); + Sort(NameSet_.Functions, NoCaseCompare); } TFuture<TNameResponse> Lookup(TNameRequest request) override { TNameResponse response; if (request.Constraints.TypeName) { - AppendAs<TTypeName>(response.RankedNames, - FilteredByPrefix(request.Prefix, NameSet_.Types)); + AppendAs<TTypeName>( + response.RankedNames, + FilteredByPrefix(request.Prefix, NameSet_.Types)); } if (request.Constraints.Function) { @@ -80,18 +60,22 @@ namespace NSQLComplete { FilteredByPrefix(request.Prefix, NameSet_.Functions)); } - Sort(response.RankedNames); + Ranking_->CropToSortedPrefix(response.RankedNames, request.Limit); - response.RankedNames.crop(request.Limit); return NThreading::MakeFuture(std::move(response)); } private: NameSet NameSet_; + IRanking::TPtr Ranking_; }; - INameService::TPtr MakeStaticNameService(NameSet names) { - return INameService::TPtr(new TStaticNameService(std::move(names))); + INameService::TPtr MakeStaticNameService() { + return MakeStaticNameService(MakeDefaultNameSet(), MakeDefaultRanking()); + } + + INameService::TPtr MakeStaticNameService(NameSet names, IRanking::TPtr ranking) { + return INameService::TPtr(new TStaticNameService(std::move(names), std::move(ranking))); } } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/static/name_service.h b/yql/essentials/sql/v1/complete/name/static/name_service.h index 79405598032..a5c90465c83 100644 --- a/yql/essentials/sql/v1/complete/name/static/name_service.h +++ b/yql/essentials/sql/v1/complete/name/static/name_service.h @@ -1,5 +1,7 @@ #pragma once +#include "ranking.h" + #include <yql/essentials/sql/v1/complete/name/name_service.h> namespace NSQLComplete { @@ -11,6 +13,8 @@ namespace NSQLComplete { NameSet MakeDefaultNameSet(); - INameService::TPtr MakeStaticNameService(NameSet names); + INameService::TPtr MakeStaticNameService(); + + INameService::TPtr MakeStaticNameService(NameSet names, IRanking::TPtr ranking); } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/static/ranking.cpp b/yql/essentials/sql/v1/complete/name/static/ranking.cpp new file mode 100644 index 00000000000..45e6e2b2fa2 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/static/ranking.cpp @@ -0,0 +1,102 @@ +#include "ranking.h" + +#include "frequency.h" + +#include <yql/essentials/sql/v1/complete/name/name_service.h> + +#include <util/charset/utf8.h> + +namespace NSQLComplete { + + class TRanking: public IRanking { + private: + struct TRow { + TGenericName Name; + size_t Weight; + }; + + public: + TRanking(TFrequencyData frequency) + : Frequency_(std::move(frequency)) + { + } + + void CropToSortedPrefix(TVector<TGenericName>& names, size_t limit) override { + limit = std::min(limit, names.size()); + + TVector<TRow> rows; + rows.reserve(names.size()); + for (TGenericName& name : names) { + size_t weight = Weight(name); + rows.emplace_back(std::move(name), weight); + } + + ::PartialSort( + std::begin(rows), std::begin(rows) + limit, std::end(rows), + [this](const TRow& lhs, const TRow& rhs) { + const size_t lhs_weight = ReversedWeight(lhs.Weight); + const auto lhs_content = ContentView(lhs.Name); + + const size_t rhs_weight = ReversedWeight(rhs.Weight); + const auto rhs_content = ContentView(rhs.Name); + + return std::tie(lhs_weight, lhs_content) < + std::tie(rhs_weight, rhs_content); + }); + + names.crop(limit); + rows.crop(limit); + + for (size_t i = 0; i < limit; ++i) { + names[i] = std::move(rows[i].Name); + } + } + + private: + size_t Weight(const TGenericName& name) const { + return std::visit([this](const auto& name) -> size_t { + using T = std::decay_t<decltype(name)>; + + auto identifier = ToLowerUTF8(ContentView(name)); + + if constexpr (std::is_same_v<T, TFunctionName>) { + if (auto weight = Frequency_.Functions.FindPtr(identifier)) { + return *weight; + } + } + + if constexpr (std::is_same_v<T, TTypeName>) { + if (auto weight = Frequency_.Types.FindPtr(identifier)) { + return *weight; + } + } + + return 0; + }, name); + } + + static size_t ReversedWeight(size_t weight) { + return std::numeric_limits<size_t>::max() - weight; + } + + const TStringBuf ContentView(const TGenericName& name Y_LIFETIME_BOUND) const { + return std::visit([](const auto& name) -> TStringBuf { + using T = std::decay_t<decltype(name)>; + if constexpr (std::is_base_of_v<TIndentifier, T>) { + return name.Indentifier; + } + }, name); + } + + TFrequencyData Frequency_; + }; + + IRanking::TPtr MakeDefaultRanking() { + return IRanking::TPtr(new TRanking(LoadFrequencyData())); + } + + IRanking::TPtr MakeDefaultRanking(TFrequencyData frequency) { + return IRanking::TPtr(new TRanking(frequency)); + } + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/static/ranking.h b/yql/essentials/sql/v1/complete/name/static/ranking.h new file mode 100644 index 00000000000..e24607eded6 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/static/ranking.h @@ -0,0 +1,23 @@ +#pragma once + +#include "frequency.h" + +#include <yql/essentials/sql/v1/complete/name/name_service.h> + +#include <util/generic/hash.h> + +namespace NSQLComplete { + + class IRanking { + public: + using TPtr = THolder<IRanking>; + + virtual void CropToSortedPrefix(TVector<TGenericName>& names, size_t limit) = 0; + virtual ~IRanking() = default; + }; + + IRanking::TPtr MakeDefaultRanking(); + + IRanking::TPtr MakeDefaultRanking(TFrequencyData frequency); + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/static/ranking_ut.cpp b/yql/essentials/sql/v1/complete/name/static/ranking_ut.cpp new file mode 100644 index 00000000000..fdd36593361 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/static/ranking_ut.cpp @@ -0,0 +1,14 @@ +#include "ranking.h" + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NSQLComplete; + +Y_UNIT_TEST_SUITE(FrequencyTests) { + + Y_UNIT_TEST(FrequencyDataIsParsable) { + TFrequencyData data = LoadFrequencyData(); + Y_UNUSED(data); + } + +} // Y_UNIT_TEST_SUITE(FrequencyTests) diff --git a/yql/essentials/sql/v1/complete/name/static/ut/ya.make b/yql/essentials/sql/v1/complete/name/static/ut/ya.make new file mode 100644 index 00000000000..60963b761b0 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/static/ut/ya.make @@ -0,0 +1,7 @@ +UNITTEST_FOR(yql/essentials/sql/v1/complete/name/static) + +SRCS( + frequency_ut.cpp +) + +END() diff --git a/yql/essentials/sql/v1/complete/name/static/ya.make b/yql/essentials/sql/v1/complete/name/static/ya.make index bdf97e2412f..639371447af 100644 --- a/yql/essentials/sql/v1/complete/name/static/ya.make +++ b/yql/essentials/sql/v1/complete/name/static/ya.make @@ -1,12 +1,26 @@ LIBRARY() SRCS( - default_name_set.cpp + frequency.cpp + json_name_set.cpp name_service.cpp + ranking.cpp ) PEERDIR( yql/essentials/sql/v1/complete/name + yql/essentials/sql/v1/complete/text +) + +RESOURCE( + yql/essentials/data/language/types.json types.json + yql/essentials/data/language/sql_functions.json sql_functions.json + yql/essentials/data/language/udfs_basic.json udfs_basic.json + yql/essentials/data/language/rules_corr_basic.json rules_corr_basic.json ) END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/yql/essentials/sql/v1/complete/sql_complete.cpp b/yql/essentials/sql/v1/complete/sql_complete.cpp index b3ddda2b23f..b73aafe0a4f 100644 --- a/yql/essentials/sql/v1/complete/sql_complete.cpp +++ b/yql/essentials/sql/v1/complete/sql_complete.cpp @@ -127,11 +127,11 @@ namespace NSQLComplete { lexers.Antlr4Pure = NSQLTranslationV1::MakeAntlr4PureLexerFactory(); lexers.Antlr4PureAnsi = NSQLTranslationV1::MakeAntlr4PureAnsiLexerFactory(); - INameService::TPtr names = MakeStaticNameService(MakeDefaultNameSet()); + INameService::TPtr names = MakeStaticNameService(MakeDefaultNameSet(), MakeDefaultRanking()); return MakeSqlCompletionEngine([lexers = std::move(lexers)](bool ansi) { return NSQLTranslationV1::MakeLexer( - lexers, ansi, /* antlr4 = */ true, + lexers, ansi, /* antlr4 = */ true, NSQLTranslationV1::ELexerFlavor::Pure); }, std::move(names)); } diff --git a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp index 5f07d5c3388..1714ed47471 100644 --- a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp +++ b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp @@ -1,7 +1,9 @@ #include "sql_complete.h" #include <yql/essentials/sql/v1/complete/name/fallback/name_service.h> +#include <yql/essentials/sql/v1/complete/name/static/frequency.h> #include <yql/essentials/sql/v1/complete/name/static/name_service.h> +#include <yql/essentials/sql/v1/complete/name/static/ranking.h> #include <yql/essentials/sql/v1/lexer/lexer.h> #include <yql/essentials/sql/v1/lexer/antlr4_pure/lexer.h> @@ -45,18 +47,20 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) { lexers.Antlr4PureAnsi = NSQLTranslationV1::MakeAntlr4PureAnsiLexerFactory(); return [lexers = std::move(lexers)](bool ansi) { return NSQLTranslationV1::MakeLexer( - lexers, ansi, /* antlr4 = */ true, + lexers, ansi, /* antlr4 = */ true, NSQLTranslationV1::ELexerFlavor::Pure); }; } ISqlCompletionEngine::TPtr MakeSqlCompletionEngineUT() { TLexerSupplier lexer = MakePureLexerSupplier(); - INameService::TPtr names = MakeStaticNameService({ + NameSet names = { .Types = {"Uint64"}, .Functions = {"StartsWith"}, - }); - return MakeSqlCompletionEngine(std::move(lexer), std::move(names)); + }; + auto ranking = MakeDefaultRanking({}); + INameService::TPtr service = MakeStaticNameService(std::move(names), std::move(ranking)); + return MakeSqlCompletionEngine(std::move(lexer), std::move(service)); } TVector<TCandidate> Complete(ISqlCompletionEngine::TPtr& engine, TStringBuf prefix) { @@ -426,23 +430,24 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) { Y_UNIT_TEST(InvalidStatementsRecovery) { auto engine = MakeSqlCompletionEngineUT(); - UNIT_ASSERT_VALUES_EQUAL(Complete(engine, "select select; ").size(), 35); - UNIT_ASSERT_VALUES_EQUAL(Complete(engine, "select select;").size(), 35); + UNIT_ASSERT_GE(Complete(engine, "select select; ").size(), 35); + UNIT_ASSERT_GE(Complete(engine, "select select;").size(), 35); UNIT_ASSERT_VALUES_EQUAL_C(Complete(engine, "!;").size(), 0, "Lexer failing"); } - Y_UNIT_TEST(DefaultNameSet) { + Y_UNIT_TEST(DefaultNameService) { auto set = MakeDefaultNameSet(); - auto service = MakeStaticNameService(std::move(set)); + auto service = MakeStaticNameService(std::move(set), MakeDefaultRanking()); auto engine = MakeSqlCompletionEngine(MakePureLexerSupplier(), std::move(service)); { TVector<TCandidate> expected = { - {TypeName, "Uint16"}, - {TypeName, "Uint32"}, {TypeName, "Uint64"}, - {TypeName, "Uint8"}, + {TypeName, "Uint32"}, {TypeName, "Utf8"}, {TypeName, "Uuid"}, + {TypeName, "Uint8"}, + {TypeName, "Unit"}, + {TypeName, "Uint16"}, }; UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT OPTIONAL<U"}), expected); } @@ -469,14 +474,62 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) { auto silent = MakeHolder<TSilentNameService>(); auto primary = MakeDeadlinedNameService(std::move(silent), TDuration::MilliSeconds(1)); - auto standby = MakeStaticNameService(MakeDefaultNameSet()); + auto standby = MakeStaticNameService(MakeDefaultNameSet(), MakeDefaultRanking({})); auto fallback = MakeFallbackNameService(std::move(primary), std::move(standby)); auto engine = MakeSqlCompletionEngine(MakePureLexerSupplier(), std::move(fallback)); - UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT CAST (1 AS U"}).size(), 6); - UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT CAST (1 AS "}).size(), 47); - UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT "}).size(), 55); + UNIT_ASSERT_GE(Complete(engine, {"SELECT CAST (1 AS U"}).size(), 6); + UNIT_ASSERT_GE(Complete(engine, {"SELECT CAST (1 AS "}).size(), 47); + UNIT_ASSERT_GE(Complete(engine, {"SELECT "}).size(), 55); + } + + Y_UNIT_TEST(Ranking) { + TFrequencyData frequency = { + .Types = { + {"int32", 128}, + {"int64", 64}, + {"interval", 32}, + {"interval64", 32}, + }, + .Functions = { + {"min", 128}, + {"max", 64}, + {"maxof", 64}, + {"minby", 32}, + {"maxby", 32}, + }, + }; + auto service = MakeStaticNameService(MakeDefaultNameSet(), MakeDefaultRanking(frequency)); + auto engine = MakeSqlCompletionEngine(MakePureLexerSupplier(), std::move(service)); + { + TVector<TCandidate> expected = { + {TypeName, "Int32"}, + {TypeName, "Int64"}, + {TypeName, "Interval"}, + {TypeName, "Interval64"}, + {TypeName, "Int16"}, + {TypeName, "Int8"}, + }; + UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT OPTIONAL<I"}), expected); + } + { + TVector<TCandidate> expectedPrefix = { + {FunctionName, "Min"}, + {FunctionName, "Max"}, + {FunctionName, "MaxOf"}, + {FunctionName, "MaxBy"}, + {FunctionName, "MinBy"}, + {FunctionName, "Math::Abs"}, + {FunctionName, "Math::Acos"}, + {FunctionName, "Math::Asin"}, + }; + + auto actualPrefix = Complete(engine, {"SELECT m"}); + actualPrefix.crop(expectedPrefix.size()); + + UNIT_ASSERT_VALUES_EQUAL(actualPrefix, expectedPrefix); + } } } // Y_UNIT_TEST_SUITE(SqlCompleteTests) diff --git a/yql/essentials/sql/v1/complete/syntax/grammar.cpp b/yql/essentials/sql/v1/complete/syntax/grammar.cpp index 4274d4bfb44..c8f5a2e4a8f 100644 --- a/yql/essentials/sql/v1/complete/syntax/grammar.cpp +++ b/yql/essentials/sql/v1/complete/syntax/grammar.cpp @@ -1,6 +1,6 @@ #include "grammar.h" -#include <yql/essentials/sql/v1/format/sql_format.h> +#include <yql/essentials/sql/v1/reflect/sql_reflect.h> namespace NSQLComplete { @@ -44,7 +44,7 @@ namespace NSQLComplete { std::unordered_set<TTokenId> ComputeKeywordTokens() { const auto& vocabulary = GetVocabulary(); - const auto keywords = NSQLFormat::GetKeywords(); + const auto keywords = NSQLReflect::LoadLexerGrammar().KeywordNames; auto keywordTokens = GetAllTokens(); std::erase_if(keywordTokens, [&](TTokenId token) { diff --git a/yql/essentials/sql/v1/complete/syntax/ya.make b/yql/essentials/sql/v1/complete/syntax/ya.make index a3fe973e315..24fd94a952a 100644 --- a/yql/essentials/sql/v1/complete/syntax/ya.make +++ b/yql/essentials/sql/v1/complete/syntax/ya.make @@ -16,12 +16,11 @@ PEERDIR( yql/essentials/parser/antlr_ast/gen/v1_ansi_antlr4 yql/essentials/parser/antlr_ast/gen/v1_antlr4 + yql/essentials/parser/lexer_common yql/essentials/sql/settings yql/essentials/sql/v1/lexer - - # TODO(YQL-19747): Replace with the sql/v1/reflect to get keywords - yql/essentials/sql/v1/format + yql/essentials/sql/v1/reflect ) END() |