summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvityaman <[email protected]>2025-04-02 15:37:42 +0300
committerrobot-piglet <[email protected]>2025-04-02 15:53:10 +0300
commit9728f9489a1c25e2b2e1e7959fa662a389f68db8 (patch)
treece9061c436c9e73f7a6119119349f0be2c7b4b61
parent18a74248135f1108b545fb7e04607445925b764d (diff)
YQL-19747 Introduce types and functions ranking
- [x] Fix bug with incorrect no-case sorting. - [x] Get names from `sql_functions.json` and `types.json`. - [x] Add types and functions ranking according to `rules_corr_basic.json` data via a `PartialSort`. - [x] Add benchmark workspace. --- Pull Request resolved: https://github.com/ytsaurus/ytsaurus/pull/1167 commit_hash:84d93265fb69bf5651f905d6af038056657e9a16
-rw-r--r--yql/essentials/sql/v1/complete/bench/main.cpp40
-rw-r--r--yql/essentials/sql/v1/complete/bench/ya.make13
-rw-r--r--yql/essentials/sql/v1/complete/name/static/default_name_set.cpp73
-rw-r--r--yql/essentials/sql/v1/complete/name/static/frequency.cpp87
-rw-r--r--yql/essentials/sql/v1/complete/name/static/frequency.h17
-rw-r--r--yql/essentials/sql/v1/complete/name/static/frequency_ut.cpp37
-rw-r--r--yql/essentials/sql/v1/complete/name/static/json_name_set.cpp58
-rw-r--r--yql/essentials/sql/v1/complete/name/static/name_service.cpp84
-rw-r--r--yql/essentials/sql/v1/complete/name/static/name_service.h6
-rw-r--r--yql/essentials/sql/v1/complete/name/static/ranking.cpp102
-rw-r--r--yql/essentials/sql/v1/complete/name/static/ranking.h23
-rw-r--r--yql/essentials/sql/v1/complete/name/static/ranking_ut.cpp14
-rw-r--r--yql/essentials/sql/v1/complete/name/static/ut/ya.make7
-rw-r--r--yql/essentials/sql/v1/complete/name/static/ya.make16
-rw-r--r--yql/essentials/sql/v1/complete/sql_complete.cpp4
-rw-r--r--yql/essentials/sql/v1/complete/sql_complete_ut.cpp83
-rw-r--r--yql/essentials/sql/v1/complete/syntax/grammar.cpp4
-rw-r--r--yql/essentials/sql/v1/complete/syntax/ya.make5
-rw-r--r--yql/essentials/tools/yql_complete/ya.make2
-rw-r--r--yql/essentials/tools/yql_complete/yql_complete.cpp40
20 files changed, 566 insertions, 149 deletions
diff --git a/yql/essentials/sql/v1/complete/bench/main.cpp b/yql/essentials/sql/v1/complete/bench/main.cpp
new file mode 100644
index 00000000000..ace37c43e35
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/bench/main.cpp
@@ -0,0 +1,40 @@
+#include <benchmark/benchmark.h>
+
+#include <yql/essentials/sql/v1/complete/name/static/name_service.h>
+#include <yql/essentials/sql/v1/complete/name/static/ranking.h>
+#include <yql/essentials/sql/v1/complete/sql_complete.h>
+
+#include <yql/essentials/sql/v1/lexer/antlr4_pure/lexer.h>
+#include <yql/essentials/sql/v1/lexer/antlr4_pure_ansi/lexer.h>
+
+#include <util/generic/xrange.h>
+#include <util/system/compiler.h>
+
+namespace NSQLComplete {
+
+ // Creates a lexer supplier backed by the pure ANTLR4 lexer factories
+ // (regular and ANSI). The returned callable owns the factories it captured.
+ TLexerSupplier MakePureLexerSupplier() {
+     NSQLTranslationV1::TLexers factories;
+     factories.Antlr4Pure = NSQLTranslationV1::MakeAntlr4PureLexerFactory();
+     factories.Antlr4PureAnsi = NSQLTranslationV1::MakeAntlr4PureAnsiLexerFactory();
+     return [factories = std::move(factories)](bool ansi) {
+         return NSQLTranslationV1::MakeLexer(
+             factories, ansi, /* antlr4 = */ true,
+             NSQLTranslationV1::ELexerFlavor::Pure);
+     };
+ }
+
+ // Benchmarks one completion request for the input "SELECT ".
+ // The engine (default name set, default ranking, pure lexers) is built once,
+ // before the measured loop, so only Complete() is timed.
+ void BenchmarkComplete(benchmark::State& state) {
+     auto nameSet = MakeDefaultNameSet();
+     auto defaultRanking = MakeDefaultRanking();
+     auto engine = MakeSqlCompletionEngine(
+         MakePureLexerSupplier(),
+         MakeStaticNameService(std::move(nameSet), std::move(defaultRanking)));
+
+     for (const auto _ : state) {
+         auto result = engine->Complete({"SELECT "});
+         // Keep the result observable so the call is not optimized away.
+         benchmark::DoNotOptimize(result);
+     }
+ }
+
+} // namespace NSQLComplete
+
+BENCHMARK(NSQLComplete::BenchmarkComplete);
diff --git a/yql/essentials/sql/v1/complete/bench/ya.make b/yql/essentials/sql/v1/complete/bench/ya.make
new file mode 100644
index 00000000000..aec2f394c82
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/bench/ya.make
@@ -0,0 +1,13 @@
+G_BENCHMARK()
+
+SRCS(
+ main.cpp
+)
+
+PEERDIR(
+ yql/essentials/sql/v1/complete
+ yql/essentials/sql/v1/complete/name/static
+ yql/essentials/sql/v1/lexer
+)
+
+END()
diff --git a/yql/essentials/sql/v1/complete/name/static/default_name_set.cpp b/yql/essentials/sql/v1/complete/name/static/default_name_set.cpp
deleted file mode 100644
index f04c9180a3c..00000000000
--- a/yql/essentials/sql/v1/complete/name/static/default_name_set.cpp
+++ /dev/null
@@ -1,73 +0,0 @@
-#include "name_service.h"
-
-namespace NSQLComplete {
-
- // TODO(YQL-19747): Use some name registry
- NameSet MakeDefaultNameSet() {
- return {
- .Types = {
- "Bool",
- "Int8",
- "Uint8",
- "Int16",
- "Uint16",
- "Int32",
- "Uint32",
- "Int64",
- "Uint64",
- "Float",
- "Double",
- "String",
- "Utf8",
- "Yson",
- "Json",
- "Uuid",
- "JsonDocument",
- "Date",
- "Datetime",
- "Timestamp",
- "Interval",
- "TzDate",
- "TzDatetime",
- "TzTimestamp",
- "Date32",
- "Datetime64",
- "Timestamp64",
- "Interval64",
- "TzDate32",
- "TzDatetime64",
- "TzTimestamp64",
- "Decimal",
- "DyNumber",
- },
- .Functions = {
- "COALESCE",
- "LENGTH",
- "SUBSTRING",
- "FIND",
- "RFIND",
- "StartsWith",
- "EndsWith",
- "IF",
- "NANVL",
- "Random",
- "RandomNumber",
- "RandomUuid",
- "CurrentUtcDate",
- "CurrentUtcDatetime",
- "CurrentUtcTimestamp",
- "CurrentTzDate",
- "CurrentTzDatetime",
- "CurrentTzTimestamp",
- "AddTimezone",
- "RemoveTimezone",
- "Version",
- "MAX_OF",
- "MIN_OF",
- "GREATEST",
- "LEAST",
- },
- };
- }
-
-} // namespace NSQLComplete
diff --git a/yql/essentials/sql/v1/complete/name/static/frequency.cpp b/yql/essentials/sql/v1/complete/name/static/frequency.cpp
new file mode 100644
index 00000000000..d9c8ba9652c
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/name/static/frequency.cpp
@@ -0,0 +1,87 @@
+#include "frequency.h"
+
+#include <library/cpp/json/json_reader.h>
+#include <library/cpp/resource/resource.h>
+
+#include <util/charset/utf8.h>
+
+namespace NSQLComplete {
+
+ // String constants of the frequency JSON format used in this file:
+ // `Key` holds the object keys of a record, `Parent` holds the values of
+ // the "parent" field that are matched by name in Convert().
+ constexpr struct {
+ struct {
+ const char* Parent = "parent";
+ const char* Rule = "rule";
+ const char* Sum = "sum";
+ } Key;
+ struct {
+ const char* Type = "TYPE";
+ const char* Func = "FUNC";
+ const char* Module = "MODULE";
+ const char* ModuleFunc = "MODULE_FUNC";
+ } Parent;
+ } Json;
+
+ // One record of the frequency JSON array: the "parent", "rule" and "sum" fields.
+ struct TFrequencyItem {
+     TString Parent;
+     TString Rule;
+     size_t Sum;
+
+     // Reads a record from a JSON object; the fields are fetched in the order
+     // parent, rule, sum through the checked ("Safe") accessors.
+     static TFrequencyItem ParseJsonMap(NJson::TJsonValue::TMapType&& json) {
+         TFrequencyItem record;
+         record.Parent = json.at(Json.Key.Parent).GetStringSafe();
+         record.Rule = json.at(Json.Key.Rule).GetStringSafe();
+         record.Sum = json.at(Json.Key.Sum).GetUIntegerSafe();
+         return record;
+     }
+
+     // Converts every element of a JSON array into a record, preserving order.
+     // The array is taken by non-const reference because each element's map is
+     // handed to ParseJsonMap as an rvalue.
+     static TVector<TFrequencyItem> ParseListFromJsonArray(NJson::TJsonValue::TArray& json) {
+         TVector<TFrequencyItem> records;
+         records.reserve(json.size());
+         for (auto& entry : json) {
+             records.push_back(ParseJsonMap(std::move(entry.GetMapSafe())));
+         }
+         return records;
+     }
+
+     // Parses JSON text whose root is an array of records.
+     static TVector<TFrequencyItem> ParseListFromJsonText(const TStringBuf text) {
+         NJson::TJsonValue root = NJson::ReadJsonFastTree(text);
+         auto& array = root.GetArraySafe();
+         return ParseListFromJsonArray(array);
+     }
+ };
+
+ // Folds raw frequency records into per-kind tables keyed by the lowercased rule name.
+ //  - "TYPE" records are accumulated into `Types`;
+ //  - "FUNC" and "MODULE_FUNC" records are accumulated into `Functions`;
+ //  - "MODULE" records and every other parent (parser call stacks) are not
+ //    supported and are skipped.
+ // Records that map to the same lowercased name have their sums added together.
+ TFrequencyData Convert(TVector<TFrequencyItem> items) {
+     TFrequencyData data;
+     for (const auto& item : items) {
+         const bool isType = item.Parent == Json.Parent.Type;
+         const bool isFunction = item.Parent == Json.Parent.Func ||
+                                 item.Parent == Json.Parent.ModuleFunc;
+         if (!isType && !isFunction) {
+             continue;
+         }
+
+         // Lowercase only what is actually stored; skipped records are left untouched.
+         auto& table = isType ? data.Types : data.Functions;
+         table[ToLowerUTF8(item.Rule)] += item.Sum;
+     }
+     return data;
+ }
+
+ // Parses frequency JSON text (an array of {"parent", "rule", "sum"} records)
+ // and converts it into the per-kind frequency tables.
+ TFrequencyData ParseJsonFrequencyData(const TStringBuf text) {
+     TVector<TFrequencyItem> records = TFrequencyItem::ParseListFromJsonText(text);
+     return Convert(std::move(records));
+ }
+
+ // Loads the frequency data from the "rules_corr_basic.json" resource linked
+ // into the binary. Fails with a message naming the resource when it is absent.
+ TFrequencyData LoadFrequencyData() {
+     constexpr TStringBuf resource = "rules_corr_basic.json";
+     TString text;
+     Y_ENSURE(NResource::FindExact(resource, &text),
+              "Resource is not found: " << resource);
+     return ParseJsonFrequencyData(text);
+ }
+
+} // namespace NSQLComplete
diff --git a/yql/essentials/sql/v1/complete/name/static/frequency.h b/yql/essentials/sql/v1/complete/name/static/frequency.h
new file mode 100644
index 00000000000..3d128f824b4
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/name/static/frequency.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <util/generic/string.h>
+#include <util/generic/hash.h>
+
+namespace NSQLComplete {
+
+ // Usage counters for names, split by kind. The value is the accumulated
+ // "sum" for the name. Keys are expected to be lowercase: the JSON parser
+ // stores lowercased names and the ranking looks names up in lowercase.
+ struct TFrequencyData {
+ THashMap<TString, size_t> Types;
+ THashMap<TString, size_t> Functions;
+ };
+
+ TFrequencyData ParseJsonFrequencyData(const TStringBuf text);
+
+ TFrequencyData LoadFrequencyData();
+
+} // namespace NSQLComplete
diff --git a/yql/essentials/sql/v1/complete/name/static/frequency_ut.cpp b/yql/essentials/sql/v1/complete/name/static/frequency_ut.cpp
new file mode 100644
index 00000000000..dd6ee2cfbb2
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/name/static/frequency_ut.cpp
@@ -0,0 +1,37 @@
+#include "frequency.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+using namespace NSQLComplete;
+
+Y_UNIT_TEST_SUITE(FrequencyTests) {
+
+ // Checks the JSON -> TFrequencyData conversion: names are lowercased,
+ // FUNC and MODULE_FUNC both land in Functions, MODULE and parser-rule
+ // records are dropped.
+ Y_UNIT_TEST(FrequencyDataJson) {
+ TFrequencyData actual = ParseJsonFrequencyData(R"([
+ {"parent":"FUNC","rule":"ABC","sum":1},
+ {"parent":"TYPE","rule":"BIGINT","sum":7101},
+ {"parent":"MODULE_FUNC","rule":"Compress::BZip2","sum":2},
+ {"parent":"MODULE","rule":"re2","sum":3094},
+ {"parent":"TRule_action_or_subquery_args","rule":"TRule_action_or_subquery_args.Block2","sum":4874480}
+ ])");
+
+ // Only the TYPE, FUNC and MODULE_FUNC records are expected to survive.
+ TFrequencyData expected = {
+ .Types = {
+ {"bigint", 7101},
+ },
+ .Functions = {
+ {"abc", 1},
+ {"compress::bzip2", 2},
+ },
+ };
+
+ UNIT_ASSERT_VALUES_EQUAL(actual.Types, expected.Types);
+ UNIT_ASSERT_VALUES_EQUAL(actual.Functions, expected.Functions);
+ }
+
+ // Smoke test: the embedded frequency resource exists and parses without throwing.
+ Y_UNIT_TEST(FrequencyDataResource) {
+     TFrequencyData data = LoadFrequencyData();
+     Y_UNUSED(data);
+ }
+
+} // Y_UNIT_TEST_SUITE(FrequencyTests)
diff --git a/yql/essentials/sql/v1/complete/name/static/json_name_set.cpp b/yql/essentials/sql/v1/complete/name/static/json_name_set.cpp
new file mode 100644
index 00000000000..29c303b3102
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/name/static/json_name_set.cpp
@@ -0,0 +1,58 @@
+#include "name_service.h"
+
+#include <library/cpp/json/json_reader.h>
+#include <library/cpp/resource/resource.h>
+
+namespace NSQLComplete {
+
+ // Reads the resource `filename` linked into the binary and parses it as JSON.
+ // Fails with a message naming the resource when it is absent.
+ NJson::TJsonValue LoadJsonResource(const TStringBuf filename) {
+     TString text;
+     Y_ENSURE(NResource::FindExact(filename, &text),
+              "Resource is not found: " << filename);
+     return NJson::ReadJsonFastTree(text);
+ }
+
+ // Appends all elements of `rhs` to `lhs` and returns the result.
+ // Both arguments are taken by value, so the elements of `rhs` are moved
+ // rather than copied; the caller's own containers are never modified.
+ template <class T, class U>
+ T Merge(T lhs, U rhs) {
+     std::move(std::begin(rhs), std::end(rhs), std::back_inserter(lhs));
+     return lhs;
+ }
+
+ // Collects the "name" string of every object in a JSON array, preserving order.
+ TVector<TString> ParseNames(NJson::TJsonValue::TArray& json) {
+     TVector<TString> names;
+     names.reserve(json.size());
+     for (auto& entry : json) {
+         auto& object = entry.GetMapSafe();
+         names.emplace_back(object.at("name").GetStringSafe());
+     }
+     return names;
+ }
+
+ // Extracts type names from a JSON document whose root is an array of
+ // objects with a "name" field.
+ TVector<TString> ParseTypes(NJson::TJsonValue json) {
+ return ParseNames(json.GetArraySafe());
+ }
+
+ // Extracts function names from a JSON document whose root is an array of
+ // objects with a "name" field.
+ TVector<TString> ParseFunctions(NJson::TJsonValue json) {
+ return ParseNames(json.GetArraySafe());
+ }
+
+ // Extracts UDF names from a JSON object that maps a module name to an array
+ // of function objects. Every function is returned as "<module>::<name>".
+ TVector<TString> ParseUfs(NJson::TJsonValue json) {
+     TVector<TString> names;
+     for (auto& [module, v] : json.GetMapSafe()) {
+         auto functions = ParseNames(v.GetArraySafe());
+         names.reserve(names.size() + functions.size());
+         for (auto& function : functions) {
+             function.prepend("::").prepend(module);
+             // `functions` is a local temporary list: move the qualified name out of it.
+             names.emplace_back(std::move(function));
+         }
+     }
+     return names;
+ }
+
+ // Builds the default name set from resources linked into the binary:
+ // types come from "types.json"; functions are the names from
+ // "sql_functions.json" followed by the "<module>::<name>" UDF names from
+ // "udfs_basic.json".
+ NameSet MakeDefaultNameSet() {
+ return {
+ .Types = ParseTypes(LoadJsonResource("types.json")),
+ .Functions = Merge(
+ ParseFunctions(LoadJsonResource("sql_functions.json")),
+ ParseUfs(LoadJsonResource("udfs_basic.json"))),
+ };
+ }
+
+} // namespace NSQLComplete
diff --git a/yql/essentials/sql/v1/complete/name/static/name_service.cpp b/yql/essentials/sql/v1/complete/name/static/name_service.cpp
index fa2bbfb360b..fdb1dd4eae1 100644
--- a/yql/essentials/sql/v1/complete/name/static/name_service.cpp
+++ b/yql/essentials/sql/v1/complete/name/static/name_service.cpp
@@ -1,18 +1,30 @@
#include "name_service.h"
+#include "ranking.h"
+
namespace NSQLComplete {
+ // "Less than" for strings that ignores character case: a lexicographical
+ // comparison where every pair of characters is compared after ToLower.
+ bool NoCaseCompare(const TString& lhs, const TString& rhs) {
+     const auto lowerLess = [](const char left, const char right) {
+         return ToLower(left) < ToLower(right);
+     };
+     return std::lexicographical_compare(
+         std::begin(lhs), std::end(lhs),
+         std::begin(rhs), std::end(rhs),
+         lowerLess);
+ }
+
+ // Returns a "less than" comparator that looks at no more than the first
+ // `size` bytes of each string and ignores case (strncasecmp). It is used
+ // with EqualRange to select the names that start with a given prefix.
+ // NOTE(review): the name lists are sorted with NoCaseCompare (ToLower per
+ // char) while this uses strncasecmp; presumably both agree for the names
+ // in use - confirm for non-ASCII input.
+ auto NoCaseCompareLimit(size_t size) {
+ return [size](const TString& lhs, const TString& rhs) -> bool {
+ return strncasecmp(lhs.data(), rhs.data(), size) < 0;
+ };
+ }
+
const TVector<TStringBuf> FilteredByPrefix(
const TString& prefix,
const TVector<TString>& sorted Y_LIFETIME_BOUND) {
auto [first, last] = EqualRange(
- std::begin(sorted),
- std::end(sorted),
- prefix,
- [&](const TString& lhs, const TString& rhs) {
- return strncasecmp(lhs.data(), rhs.data(), prefix.size()) < 0;
- });
-
+ std::begin(sorted), std::end(sorted),
+ prefix, NoCaseCompareLimit(prefix.size()));
return TVector<TStringBuf>(first, last);
}
@@ -23,55 +35,23 @@ namespace NSQLComplete {
}
}
- size_t KindWeight(const TGenericName& name) {
- return std::visit([](const auto& name) {
- using T = std::decay_t<decltype(name)>;
- if constexpr (std::is_same_v<T, TFunctionName>) {
- return 1;
- }
- if constexpr (std::is_same_v<T, TTypeName>) {
- return 2;
- }
- }, name);
- }
-
- const TStringBuf ContentView(const TGenericName& name Y_LIFETIME_BOUND) {
- return std::visit([](const auto& name) -> TStringBuf {
- using T = std::decay_t<decltype(name)>;
- if constexpr (std::is_base_of_v<TIndentifier, T>) {
- return name.Indentifier;
- }
- }, name);
- }
-
- void Sort(TVector<TGenericName>& names) {
- Sort(names, [](const TGenericName& lhs, const TGenericName& rhs) {
- const auto lhs_weight = KindWeight(lhs);
- const auto lhs_content = ContentView(lhs);
-
- const auto rhs_weight = KindWeight(rhs);
- const auto rhs_content = ContentView(rhs);
-
- return std::tie(lhs_weight, lhs_content) <
- std::tie(rhs_weight, rhs_content);
- });
- }
-
class TStaticNameService: public INameService {
public:
- explicit TStaticNameService(NameSet names)
+ explicit TStaticNameService(NameSet names, IRanking::TPtr ranking)
: NameSet_(std::move(names))
+ , Ranking_(std::move(ranking))
{
- Sort(NameSet_.Types);
- Sort(NameSet_.Functions);
+ Sort(NameSet_.Types, NoCaseCompare);
+ Sort(NameSet_.Functions, NoCaseCompare);
}
TFuture<TNameResponse> Lookup(TNameRequest request) override {
TNameResponse response;
if (request.Constraints.TypeName) {
- AppendAs<TTypeName>(response.RankedNames,
- FilteredByPrefix(request.Prefix, NameSet_.Types));
+ AppendAs<TTypeName>(
+ response.RankedNames,
+ FilteredByPrefix(request.Prefix, NameSet_.Types));
}
if (request.Constraints.Function) {
@@ -80,18 +60,22 @@ namespace NSQLComplete {
FilteredByPrefix(request.Prefix, NameSet_.Functions));
}
- Sort(response.RankedNames);
+ Ranking_->CropToSortedPrefix(response.RankedNames, request.Limit);
- response.RankedNames.crop(request.Limit);
return NThreading::MakeFuture(std::move(response));
}
private:
NameSet NameSet_;
+ IRanking::TPtr Ranking_;
};
- INameService::TPtr MakeStaticNameService(NameSet names) {
- return INameService::TPtr(new TStaticNameService(std::move(names)));
+ INameService::TPtr MakeStaticNameService() {
+ return MakeStaticNameService(MakeDefaultNameSet(), MakeDefaultRanking());
+ }
+
+ INameService::TPtr MakeStaticNameService(NameSet names, IRanking::TPtr ranking) {
+ return INameService::TPtr(new TStaticNameService(std::move(names), std::move(ranking)));
}
} // namespace NSQLComplete
diff --git a/yql/essentials/sql/v1/complete/name/static/name_service.h b/yql/essentials/sql/v1/complete/name/static/name_service.h
index 79405598032..a5c90465c83 100644
--- a/yql/essentials/sql/v1/complete/name/static/name_service.h
+++ b/yql/essentials/sql/v1/complete/name/static/name_service.h
@@ -1,5 +1,7 @@
#pragma once
+#include "ranking.h"
+
#include <yql/essentials/sql/v1/complete/name/name_service.h>
namespace NSQLComplete {
@@ -11,6 +13,8 @@ namespace NSQLComplete {
NameSet MakeDefaultNameSet();
- INameService::TPtr MakeStaticNameService(NameSet names);
+ INameService::TPtr MakeStaticNameService();
+
+ INameService::TPtr MakeStaticNameService(NameSet names, IRanking::TPtr ranking);
} // namespace NSQLComplete
diff --git a/yql/essentials/sql/v1/complete/name/static/ranking.cpp b/yql/essentials/sql/v1/complete/name/static/ranking.cpp
new file mode 100644
index 00000000000..45e6e2b2fa2
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/name/static/ranking.cpp
@@ -0,0 +1,102 @@
+#include "ranking.h"
+
+#include "frequency.h"
+
+#include <yql/essentials/sql/v1/complete/name/name_service.h>
+
+#include <util/charset/utf8.h>
+
+namespace NSQLComplete {
+
+ // Frequency-based ranking. A name's weight is its counter in the frequency
+ // data (0 when the name is unknown); heavier names come first and names of
+ // equal weight are ordered by their identifier (case-sensitive).
+ class TRanking final: public IRanking {
+ private:
+     // A candidate together with its weight, so that the frequency lookup is
+     // done once per name instead of once per comparison.
+     struct TRow {
+         TGenericName Name;
+         size_t Weight;
+     };
+
+ public:
+     explicit TRanking(TFrequencyData frequency)
+         : Frequency_(std::move(frequency))
+     {
+     }
+
+     // Keeps in `names` only the `limit` best-ranked entries (all of them when
+     // there are fewer), ordered from the best to the worst.
+     void CropToSortedPrefix(TVector<TGenericName>& names, size_t limit) override {
+         limit = std::min(limit, names.size());
+
+         TVector<TRow> rows;
+         rows.reserve(names.size());
+         for (TGenericName& name : names) {
+             const size_t weight = Weight(name);
+             rows.emplace_back(std::move(name), weight);
+         }
+
+         // Only the first `limit` rows have to end up ordered.
+         ::PartialSort(
+             std::begin(rows), std::begin(rows) + limit, std::end(rows),
+             [](const TRow& lhs, const TRow& rhs) {
+                 if (lhs.Weight != rhs.Weight) {
+                     return lhs.Weight > rhs.Weight; // heavier names first
+                 }
+                 return ContentView(lhs.Name) < ContentView(rhs.Name);
+             });
+
+         // Every element of `names` was moved from above; refill the kept prefix.
+         names.crop(limit);
+         for (size_t i = 0; i < limit; ++i) {
+             names[i] = std::move(rows[i].Name);
+         }
+     }
+
+ private:
+     // Looks the lowercased identifier up in the table that matches the kind
+     // of the name; returns 0 when there is no entry.
+     size_t Weight(const TGenericName& name) const {
+         return std::visit([this](const auto& alternative) -> size_t {
+             using T = std::decay_t<decltype(alternative)>;
+
+             // Read the identifier directly: going through ContentView here would
+             // first convert the alternative back into a temporary TGenericName.
+             const auto identifier = ToLowerUTF8(TStringBuf(alternative.Indentifier));
+
+             if constexpr (std::is_same_v<T, TFunctionName>) {
+                 if (const auto* weight = Frequency_.Functions.FindPtr(identifier)) {
+                     return *weight;
+                 }
+             }
+
+             if constexpr (std::is_same_v<T, TTypeName>) {
+                 if (const auto* weight = Frequency_.Types.FindPtr(identifier)) {
+                     return *weight;
+                 }
+             }
+
+             return 0;
+         }, name);
+     }
+
+     // Returns a view of the identifier stored in `name`; valid while `name` lives.
+     static TStringBuf ContentView(const TGenericName& name Y_LIFETIME_BOUND) {
+         return std::visit([](const auto& alternative) -> TStringBuf {
+             using T = std::decay_t<decltype(alternative)>;
+             // Fail at compile time instead of falling off the end of the lambda
+             // if an alternative without an identifier is ever added.
+             static_assert(std::is_base_of_v<TIndentifier, T>,
+                           "every TGenericName alternative must derive from TIndentifier");
+             return alternative.Indentifier;
+         }, name);
+     }
+
+     TFrequencyData Frequency_;
+ };
+
+ // Creates the frequency-based ranking from the frequency data returned by
+ // LoadFrequencyData().
+ IRanking::TPtr MakeDefaultRanking() {
+ return IRanking::TPtr(new TRanking(LoadFrequencyData()));
+ }
+
+ // Creates the frequency-based ranking from caller-supplied frequency data.
+ // `frequency` is a by-value sink argument, so it is moved into the ranking
+ // instead of copying both hash maps.
+ IRanking::TPtr MakeDefaultRanking(TFrequencyData frequency) {
+     return IRanking::TPtr(new TRanking(std::move(frequency)));
+ }
+
+} // namespace NSQLComplete
diff --git a/yql/essentials/sql/v1/complete/name/static/ranking.h b/yql/essentials/sql/v1/complete/name/static/ranking.h
new file mode 100644
index 00000000000..e24607eded6
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/name/static/ranking.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "frequency.h"
+
+#include <yql/essentials/sql/v1/complete/name/name_service.h>
+
+#include <util/generic/hash.h>
+
+namespace NSQLComplete {
+
+ // Strategy that decides which completion candidates are kept and in what order.
+ class IRanking {
+ public:
+ using TPtr = THolder<IRanking>;
+
+ // Modifies `names` in place. The implementation in ranking.cpp keeps at
+ // most `limit` best-ranked names, ordered from the best to the worst.
+ virtual void CropToSortedPrefix(TVector<TGenericName>& names, size_t limit) = 0;
+ virtual ~IRanking() = default;
+ };
+
+ IRanking::TPtr MakeDefaultRanking();
+
+ IRanking::TPtr MakeDefaultRanking(TFrequencyData frequency);
+
+} // namespace NSQLComplete
diff --git a/yql/essentials/sql/v1/complete/name/static/ranking_ut.cpp b/yql/essentials/sql/v1/complete/name/static/ranking_ut.cpp
new file mode 100644
index 00000000000..fdd36593361
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/name/static/ranking_ut.cpp
@@ -0,0 +1,14 @@
+#include "ranking.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+using namespace NSQLComplete;
+
+// The suite has its own name: "FrequencyTests" is already used by
+// frequency_ut.cpp, and two suites with one name cannot share a test binary.
+Y_UNIT_TEST_SUITE(RankingTests) {
+
+    // Smoke test: the frequency data behind the default ranking can be loaded.
+    Y_UNIT_TEST(FrequencyDataIsParsable) {
+        TFrequencyData data = LoadFrequencyData();
+        Y_UNUSED(data);
+    }
+
+} // Y_UNIT_TEST_SUITE(RankingTests)
diff --git a/yql/essentials/sql/v1/complete/name/static/ut/ya.make b/yql/essentials/sql/v1/complete/name/static/ut/ya.make
new file mode 100644
index 00000000000..60963b761b0
--- /dev/null
+++ b/yql/essentials/sql/v1/complete/name/static/ut/ya.make
@@ -0,0 +1,7 @@
+UNITTEST_FOR(yql/essentials/sql/v1/complete/name/static)
+
+SRCS(
+ frequency_ut.cpp
+)
+
+END()
diff --git a/yql/essentials/sql/v1/complete/name/static/ya.make b/yql/essentials/sql/v1/complete/name/static/ya.make
index bdf97e2412f..639371447af 100644
--- a/yql/essentials/sql/v1/complete/name/static/ya.make
+++ b/yql/essentials/sql/v1/complete/name/static/ya.make
@@ -1,12 +1,26 @@
LIBRARY()
SRCS(
- default_name_set.cpp
+ frequency.cpp
+ json_name_set.cpp
name_service.cpp
+ ranking.cpp
)
PEERDIR(
yql/essentials/sql/v1/complete/name
+ yql/essentials/sql/v1/complete/text
+)
+
+RESOURCE(
+ yql/essentials/data/language/types.json types.json
+ yql/essentials/data/language/sql_functions.json sql_functions.json
+ yql/essentials/data/language/udfs_basic.json udfs_basic.json
+ yql/essentials/data/language/rules_corr_basic.json rules_corr_basic.json
)
END()
+
+RECURSE_FOR_TESTS(
+ ut
+)
diff --git a/yql/essentials/sql/v1/complete/sql_complete.cpp b/yql/essentials/sql/v1/complete/sql_complete.cpp
index b3ddda2b23f..b73aafe0a4f 100644
--- a/yql/essentials/sql/v1/complete/sql_complete.cpp
+++ b/yql/essentials/sql/v1/complete/sql_complete.cpp
@@ -127,11 +127,11 @@ namespace NSQLComplete {
lexers.Antlr4Pure = NSQLTranslationV1::MakeAntlr4PureLexerFactory();
lexers.Antlr4PureAnsi = NSQLTranslationV1::MakeAntlr4PureAnsiLexerFactory();
- INameService::TPtr names = MakeStaticNameService(MakeDefaultNameSet());
+ INameService::TPtr names = MakeStaticNameService(MakeDefaultNameSet(), MakeDefaultRanking());
return MakeSqlCompletionEngine([lexers = std::move(lexers)](bool ansi) {
return NSQLTranslationV1::MakeLexer(
- lexers, ansi, /* antlr4 = */ true,
+ lexers, ansi, /* antlr4 = */ true,
NSQLTranslationV1::ELexerFlavor::Pure);
}, std::move(names));
}
diff --git a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp
index 5f07d5c3388..1714ed47471 100644
--- a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp
+++ b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp
@@ -1,7 +1,9 @@
#include "sql_complete.h"
#include <yql/essentials/sql/v1/complete/name/fallback/name_service.h>
+#include <yql/essentials/sql/v1/complete/name/static/frequency.h>
#include <yql/essentials/sql/v1/complete/name/static/name_service.h>
+#include <yql/essentials/sql/v1/complete/name/static/ranking.h>
#include <yql/essentials/sql/v1/lexer/lexer.h>
#include <yql/essentials/sql/v1/lexer/antlr4_pure/lexer.h>
@@ -45,18 +47,20 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) {
lexers.Antlr4PureAnsi = NSQLTranslationV1::MakeAntlr4PureAnsiLexerFactory();
return [lexers = std::move(lexers)](bool ansi) {
return NSQLTranslationV1::MakeLexer(
- lexers, ansi, /* antlr4 = */ true,
+ lexers, ansi, /* antlr4 = */ true,
NSQLTranslationV1::ELexerFlavor::Pure);
};
}
ISqlCompletionEngine::TPtr MakeSqlCompletionEngineUT() {
TLexerSupplier lexer = MakePureLexerSupplier();
- INameService::TPtr names = MakeStaticNameService({
+ NameSet names = {
.Types = {"Uint64"},
.Functions = {"StartsWith"},
- });
- return MakeSqlCompletionEngine(std::move(lexer), std::move(names));
+ };
+ auto ranking = MakeDefaultRanking({});
+ INameService::TPtr service = MakeStaticNameService(std::move(names), std::move(ranking));
+ return MakeSqlCompletionEngine(std::move(lexer), std::move(service));
}
TVector<TCandidate> Complete(ISqlCompletionEngine::TPtr& engine, TStringBuf prefix) {
@@ -426,23 +430,24 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) {
Y_UNIT_TEST(InvalidStatementsRecovery) {
auto engine = MakeSqlCompletionEngineUT();
- UNIT_ASSERT_VALUES_EQUAL(Complete(engine, "select select; ").size(), 35);
- UNIT_ASSERT_VALUES_EQUAL(Complete(engine, "select select;").size(), 35);
+ UNIT_ASSERT_GE(Complete(engine, "select select; ").size(), 35);
+ UNIT_ASSERT_GE(Complete(engine, "select select;").size(), 35);
UNIT_ASSERT_VALUES_EQUAL_C(Complete(engine, "!;").size(), 0, "Lexer failing");
}
- Y_UNIT_TEST(DefaultNameSet) {
+ Y_UNIT_TEST(DefaultNameService) {
auto set = MakeDefaultNameSet();
- auto service = MakeStaticNameService(std::move(set));
+ auto service = MakeStaticNameService(std::move(set), MakeDefaultRanking());
auto engine = MakeSqlCompletionEngine(MakePureLexerSupplier(), std::move(service));
{
TVector<TCandidate> expected = {
- {TypeName, "Uint16"},
- {TypeName, "Uint32"},
{TypeName, "Uint64"},
- {TypeName, "Uint8"},
+ {TypeName, "Uint32"},
{TypeName, "Utf8"},
{TypeName, "Uuid"},
+ {TypeName, "Uint8"},
+ {TypeName, "Unit"},
+ {TypeName, "Uint16"},
};
UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT OPTIONAL<U"}), expected);
}
@@ -469,14 +474,62 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) {
auto silent = MakeHolder<TSilentNameService>();
auto primary = MakeDeadlinedNameService(std::move(silent), TDuration::MilliSeconds(1));
- auto standby = MakeStaticNameService(MakeDefaultNameSet());
+ auto standby = MakeStaticNameService(MakeDefaultNameSet(), MakeDefaultRanking({}));
auto fallback = MakeFallbackNameService(std::move(primary), std::move(standby));
auto engine = MakeSqlCompletionEngine(MakePureLexerSupplier(), std::move(fallback));
- UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT CAST (1 AS U"}).size(), 6);
- UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT CAST (1 AS "}).size(), 47);
- UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT "}).size(), 55);
+ UNIT_ASSERT_GE(Complete(engine, {"SELECT CAST (1 AS U"}).size(), 6);
+ UNIT_ASSERT_GE(Complete(engine, {"SELECT CAST (1 AS "}).size(), 47);
+ UNIT_ASSERT_GE(Complete(engine, {"SELECT "}).size(), 55);
+ }
+
+ // Checks the order of candidates produced with custom frequency data:
+ // higher frequency first, equal frequency ordered by name, names without
+ // a frequency entry last.
+ Y_UNIT_TEST(Ranking) {
+ TFrequencyData frequency = {
+ .Types = {
+ {"int32", 128},
+ {"int64", 64},
+ {"interval", 32},
+ {"interval64", 32},
+ },
+ .Functions = {
+ {"min", 128},
+ {"max", 64},
+ {"maxof", 64},
+ {"minby", 32},
+ {"maxby", 32},
+ },
+ };
+ auto service = MakeStaticNameService(MakeDefaultNameSet(), MakeDefaultRanking(frequency));
+ auto engine = MakeSqlCompletionEngine(MakePureLexerSupplier(), std::move(service));
+ {
+ // Interval and Interval64 share a weight and are ordered by name;
+ // Int16 and Int8 have no entry in `frequency` and come last.
+ TVector<TCandidate> expected = {
+ {TypeName, "Int32"},
+ {TypeName, "Int64"},
+ {TypeName, "Interval"},
+ {TypeName, "Interval64"},
+ {TypeName, "Int16"},
+ {TypeName, "Int8"},
+ };
+ UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SELECT OPTIONAL<I"}), expected);
+ }
+ {
+ // Max/MaxOf and MaxBy/MinBy share weights and are ordered by name;
+ // the Math:: functions have no entry in `frequency`.
+ TVector<TCandidate> expectedPrefix = {
+ {FunctionName, "Min"},
+ {FunctionName, "Max"},
+ {FunctionName, "MaxOf"},
+ {FunctionName, "MaxBy"},
+ {FunctionName, "MinBy"},
+ {FunctionName, "Math::Abs"},
+ {FunctionName, "Math::Acos"},
+ {FunctionName, "Math::Asin"},
+ };
+
+ // Only the head of the result is compared.
+ auto actualPrefix = Complete(engine, {"SELECT m"});
+ actualPrefix.crop(expectedPrefix.size());
+
+ UNIT_ASSERT_VALUES_EQUAL(actualPrefix, expectedPrefix);
+ }
+ }
} // Y_UNIT_TEST_SUITE(SqlCompleteTests)
diff --git a/yql/essentials/sql/v1/complete/syntax/grammar.cpp b/yql/essentials/sql/v1/complete/syntax/grammar.cpp
index 4274d4bfb44..c8f5a2e4a8f 100644
--- a/yql/essentials/sql/v1/complete/syntax/grammar.cpp
+++ b/yql/essentials/sql/v1/complete/syntax/grammar.cpp
@@ -1,6 +1,6 @@
#include "grammar.h"
-#include <yql/essentials/sql/v1/format/sql_format.h>
+#include <yql/essentials/sql/v1/reflect/sql_reflect.h>
namespace NSQLComplete {
@@ -44,7 +44,7 @@ namespace NSQLComplete {
std::unordered_set<TTokenId> ComputeKeywordTokens() {
const auto& vocabulary = GetVocabulary();
- const auto keywords = NSQLFormat::GetKeywords();
+ const auto keywords = NSQLReflect::LoadLexerGrammar().KeywordNames;
auto keywordTokens = GetAllTokens();
std::erase_if(keywordTokens, [&](TTokenId token) {
diff --git a/yql/essentials/sql/v1/complete/syntax/ya.make b/yql/essentials/sql/v1/complete/syntax/ya.make
index a3fe973e315..24fd94a952a 100644
--- a/yql/essentials/sql/v1/complete/syntax/ya.make
+++ b/yql/essentials/sql/v1/complete/syntax/ya.make
@@ -16,12 +16,11 @@ PEERDIR(
yql/essentials/parser/antlr_ast/gen/v1_ansi_antlr4
yql/essentials/parser/antlr_ast/gen/v1_antlr4
+ yql/essentials/parser/lexer_common
yql/essentials/sql/settings
yql/essentials/sql/v1/lexer
-
- # TODO(YQL-19747): Replace with the sql/v1/reflect to get keywords
- yql/essentials/sql/v1/format
+ yql/essentials/sql/v1/reflect
)
END()
diff --git a/yql/essentials/tools/yql_complete/ya.make b/yql/essentials/tools/yql_complete/ya.make
index d745f9142a5..107e6ba5625 100644
--- a/yql/essentials/tools/yql_complete/ya.make
+++ b/yql/essentials/tools/yql_complete/ya.make
@@ -5,6 +5,8 @@ PROGRAM()
PEERDIR(
library/cpp/getopt
yql/essentials/sql/v1/complete
+ yql/essentials/sql/v1/lexer/antlr4_pure
+ yql/essentials/sql/v1/lexer/antlr4_pure_ansi
)
SRCS(
diff --git a/yql/essentials/tools/yql_complete/yql_complete.cpp b/yql/essentials/tools/yql_complete/yql_complete.cpp
index 289573190bf..320b9f1b487 100644
--- a/yql/essentials/tools/yql_complete/yql_complete.cpp
+++ b/yql/essentials/tools/yql_complete/yql_complete.cpp
@@ -1,14 +1,38 @@
#include <yql/essentials/sql/v1/complete/sql_complete.h>
+#include <yql/essentials/sql/v1/complete/name/static/frequency.h>
+#include <yql/essentials/sql/v1/complete/name/static/ranking.h>
+#include <yql/essentials/sql/v1/complete/name/static/name_service.h>
+
+#include <yql/essentials/sql/v1/lexer/antlr4_pure/lexer.h>
+#include <yql/essentials/sql/v1/lexer/antlr4_pure_ansi/lexer.h>
#include <library/cpp/getopt/last_getopt.h>
#include <util/stream/file.h>
+// Reads the whole file at `filepath` and parses it as frequency JSON.
+// The path is only read, so it is taken by const reference.
+NSQLComplete::TFrequencyData LoadFrequencyDataFromFile(const TString& filepath) {
+    const TString text = TUnbufferedFileInput(filepath).ReadAll();
+    return NSQLComplete::ParseJsonFrequencyData(text);
+}
+
+// Creates a lexer supplier backed by the pure ANTLR4 lexer factories
+// (regular and ANSI). The returned callable owns the factories it captured.
+NSQLComplete::TLexerSupplier MakePureLexerSupplier() {
+    NSQLTranslationV1::TLexers factories;
+    factories.Antlr4Pure = NSQLTranslationV1::MakeAntlr4PureLexerFactory();
+    factories.Antlr4PureAnsi = NSQLTranslationV1::MakeAntlr4PureAnsiLexerFactory();
+    return [factories = std::move(factories)](bool ansi) {
+        return NSQLTranslationV1::MakeLexer(
+            factories, ansi, /* antlr4 = */ true,
+            NSQLTranslationV1::ELexerFlavor::Pure);
+    };
+}
+
int Run(int argc, char* argv[]) {
NLastGetopt::TOpts opts = NLastGetopt::TOpts::Default();
TString inFileName;
+ TString freqFileName;
TMaybe<ui64> pos;
opts.AddLongOption('i', "input", "input file").RequiredArgument("input").StoreResult(&inFileName);
+ opts.AddLongOption('f', "freq", "frequences file").StoreResult(&freqFileName);
opts.AddLongOption('p', "pos", "position").StoreResult(&pos);
opts.SetFreeArgsNum(0);
opts.AddHelpOption();
@@ -20,9 +44,21 @@ int Run(int argc, char* argv[]) {
inFile.Reset(new TUnbufferedFileInput(inFileName));
}
IInputStream& in = inFile ? *inFile.Get() : Cin;
-
auto queryString = in.ReadAll();
- auto engine = NSQLComplete::MakeSqlCompletionEngine();
+
+ NSQLComplete::IRanking::TPtr ranking;
+ if (freqFileName.empty()) {
+ ranking = NSQLComplete::MakeDefaultRanking();
+ } else {
+ auto freq = LoadFrequencyDataFromFile(freqFileName);
+ ranking = NSQLComplete::MakeDefaultRanking(std::move(freq));
+ }
+ auto engine = NSQLComplete::MakeSqlCompletionEngine(
+ MakePureLexerSupplier(),
+ NSQLComplete::MakeStaticNameService(
+ NSQLComplete::MakeDefaultNameSet(),
+ std::move(ranking)));
+
NSQLComplete::TCompletionInput input;
input.Text = queryString;
if (pos) {