summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author vityaman <[email protected]> 2025-04-07 23:50:34 +0300
committer robot-piglet <[email protected]> 2025-04-08 00:17:32 +0300
commit 5af5cb3b7b423f2d3fcebf2e1406d37cbd6f3bbf (patch)
tree 4a846ecb4e938d459032f34cf013fe5dfdb1010b
parent 6a114f0cafe3074d67f8022fe051f0b28dfe31ab (diff)
YQL-19747 Improve yql_complete tool and add input validation
No description --- Pull Request resolved: https://github.com/ytsaurus/ytsaurus/pull/1185 commit_hash:1def5874ff6a9a5b3dcdd0ad285d2e64b16c9306
-rw-r--r-- yql/essentials/parser/pg_wrapper/parser.cpp 4
-rw-r--r-- yql/essentials/public/fastcheck/format.cpp 7
-rw-r--r-- yql/essentials/public/issue/yql_issue.cpp 2
-rw-r--r-- yql/essentials/public/issue/yql_issue.h 4
-rw-r--r-- yql/essentials/sql/v1/complete/sql_complete.cpp 9
-rw-r--r-- yql/essentials/sql/v1/complete/sql_complete_ut.cpp 32
-rw-r--r-- yql/essentials/tools/yql_complete/ya.make 1
-rw-r--r-- yql/essentials/tools/yql_complete/yql_complete.cpp 22
8 files changed, 62 insertions, 19 deletions
diff --git a/yql/essentials/parser/pg_wrapper/parser.cpp b/yql/essentials/parser/pg_wrapper/parser.cpp
index 94369296a87..c7dc29756ca 100644
--- a/yql/essentials/parser/pg_wrapper/parser.cpp
+++ b/yql/essentials/parser/pg_wrapper/parser.cpp
@@ -2,7 +2,9 @@
#include "arena_ctx.h"
+#include <util/charset/utf8.h>
#include <util/generic/scope.h>
+
#include <fcntl.h>
#include <stdint.h>
@@ -219,7 +221,7 @@ void PGParse(const TString& input, IPGParseEvents& events) {
break;
}
- if (!TTextWalker::IsUtf8Intermediate(input[i])) {
+ if (!IsUTF8ContinuationByte(input[i])) {
++codepoints;
}
walker.Advance(input[i]);
diff --git a/yql/essentials/public/fastcheck/format.cpp b/yql/essentials/public/fastcheck/format.cpp
index dac43f0ffa6..d4717b4bf79 100644
--- a/yql/essentials/public/fastcheck/format.cpp
+++ b/yql/essentials/public/fastcheck/format.cpp
@@ -5,6 +5,7 @@
#include <yql/essentials/sql/v1/proto_parser/antlr4/proto_parser.h>
#include <yql/essentials/sql/v1/proto_parser/antlr4_ansi/proto_parser.h>
#include <yql/essentials/core/issue/yql_issue.h>
+#include <util/charset/utf8.h>
#include <util/string/builder.h>
namespace NYql {
@@ -88,7 +89,7 @@ private:
continue;
}
- while (i > 0 && TTextWalker::IsUtf8Intermediate(request.Program[i])) {
+ while (i > 0 && IsUTF8ContinuationByte(request.Program[i])) {
--i;
}
@@ -96,12 +97,12 @@ private:
}
TString formattedSample = formattedQuery.substr(i, FormatContextLimit);
- while (!formattedSample.empty() && TTextWalker::IsUtf8Intermediate(formattedQuery.back())) {
+ // Drop trailing UTF-8 continuation bytes from the sample. The byte tested must be
+ // formattedSample.back() (the string being shortened), matching the origSample loop;
+ // testing formattedQuery.back() is loop-invariant, so it either never trims or
+ // erases the whole sample.
+ while (!formattedSample.empty() && IsUTF8ContinuationByte(formattedSample.back())) {
formattedSample.erase(formattedSample.size() - 1);
}
TString origSample = request.Program.substr(i, FormatContextLimit);
- while (!origSample.empty() && TTextWalker::IsUtf8Intermediate(origSample.back())) {
+ while (!origSample.empty() && IsUTF8ContinuationByte(origSample.back())) {
origSample.erase(origSample.size() - 1);
}
diff --git a/yql/essentials/public/issue/yql_issue.cpp b/yql/essentials/public/issue/yql_issue.cpp
index bb171a78692..af47895927d 100644
--- a/yql/essentials/public/issue/yql_issue.cpp
+++ b/yql/essentials/public/issue/yql_issue.cpp
@@ -54,7 +54,7 @@ TTextWalker& TTextWalker::Advance(char c) {
}
ui32 charDistance = 1;
- if (Utf8Aware && IsUtf8Intermediate(c)) {
+ if (Utf8Aware && IsUTF8ContinuationByte(c)) {
charDistance = 0;
}
diff --git a/yql/essentials/public/issue/yql_issue.h b/yql/essentials/public/issue/yql_issue.h
index 07fcdfed86e..2c60b979531 100644
--- a/yql/essentials/public/issue/yql_issue.h
+++ b/yql/essentials/public/issue/yql_issue.h
@@ -63,10 +63,6 @@ public:
{
}
- static inline bool IsUtf8Intermediate(char c) {
- return (c & 0xC0) == 0x80;
- }
-
template<typename T>
TTextWalker& Advance(const T& buf) {
for (char c : buf) {
diff --git a/yql/essentials/sql/v1/complete/sql_complete.cpp b/yql/essentials/sql/v1/complete/sql_complete.cpp
index 74ddbc04154..ed3afa29df4 100644
--- a/yql/essentials/sql/v1/complete/sql_complete.cpp
+++ b/yql/essentials/sql/v1/complete/sql_complete.cpp
@@ -26,6 +26,15 @@ namespace NSQLComplete {
}
TCompletion Complete(TCompletionInput input) {
+ if (
+ input.CursorPosition < input.Text.length() &&
+ IsUTF8ContinuationByte(input.Text.at(input.CursorPosition)) ||
+ input.Text.length() < input.CursorPosition) {
+ ythrow yexception()
+ << "invalid cursor position " << input.CursorPosition
+ << " for input size " << input.Text.size();
+ }
+
auto prefix = input.Text.Head(input.CursorPosition);
auto completedToken = GetCompletedToken(prefix);
diff --git a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp
index 10d358c8d3e..e9f5dbdfb73 100644
--- a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp
+++ b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp
@@ -63,8 +63,8 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) {
return MakeSqlCompletionEngine(std::move(lexer), std::move(service));
}
- TVector<TCandidate> Complete(ISqlCompletionEngine::TPtr& engine, TStringBuf prefix) {
- return engine->Complete({prefix}).Candidates;
+ // Test helper: runs the engine on the given completion input and hands back
+ // only the candidate list of the result.
+ TVector<TCandidate> Complete(ISqlCompletionEngine::TPtr& engine, TCompletionInput input) {
+     auto completion = engine->Complete(input);
+     return std::move(completion.Candidates);
}
Y_UNIT_TEST(Beginning) {
@@ -438,17 +438,31 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) {
};
auto engine = MakeSqlCompletionEngineUT();
- UNIT_ASSERT_VALUES_EQUAL(Complete(engine, "se"), expected);
- UNIT_ASSERT_VALUES_EQUAL(Complete(engine, "sE"), expected);
- UNIT_ASSERT_VALUES_EQUAL(Complete(engine, "Se"), expected);
- UNIT_ASSERT_VALUES_EQUAL(Complete(engine, "SE"), expected);
+ UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"se"}), expected);
+ UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"sE"}), expected);
+ UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"Se"}), expected);
+ UNIT_ASSERT_VALUES_EQUAL(Complete(engine, {"SE"}), expected);
}
Y_UNIT_TEST(InvalidStatementsRecovery) {
auto engine = MakeSqlCompletionEngineUT();
- UNIT_ASSERT_GE(Complete(engine, "select select; ").size(), 35);
- UNIT_ASSERT_GE(Complete(engine, "select select;").size(), 35);
- UNIT_ASSERT_VALUES_EQUAL_C(Complete(engine, "!;").size(), 0, "Lexer failing");
+ UNIT_ASSERT_GE(Complete(engine, {"select select; "}).size(), 35); // text with a trailing space: at least 35 candidates expected; only Text is given in the braces, presumably CursorPosition takes its default - confirm in TCompletionInput
+ UNIT_ASSERT_GE(Complete(engine, {"select select;"}).size(), 35); // same text without the trailing space, same lower bound
+ UNIT_ASSERT_VALUES_EQUAL_C(Complete(engine, {"!;"}).size(), 0, "Lexer failing"); // no candidates expected; the assertion message attributes this case to the lexer failing
+ }
+
+ Y_UNIT_TEST(InvalidCursorPosition) {
+     // Cursor positions are byte offsets into the text. Complete is expected to throw
+     // yexception when the offset is past the end of the text or points at a UTF-8
+     // continuation byte, and to accept every other offset, including text.size().
+     auto engine = MakeSqlCompletionEngineUT();
+
+     // Empty text: only offset 0 is in range.
+     UNIT_ASSERT_NO_EXCEPTION(Complete(engine, {"", 0}));
+     UNIT_ASSERT_EXCEPTION(Complete(engine, {"", 1}), yexception);
+
+     // One-byte character: both boundaries are valid, anything beyond the end is not.
+     UNIT_ASSERT_NO_EXCEPTION(Complete(engine, {"s", 0}));
+     UNIT_ASSERT_NO_EXCEPTION(Complete(engine, {"s", 1}));
+     UNIT_ASSERT_EXCEPTION(Complete(engine, {"s", 2}), yexception);
+
+     // "ы" occupies two bytes (assuming this file is UTF-8 encoded): offset 1 falls
+     // inside the character, offset 2 is the end, offset 3 is past the end.
+     UNIT_ASSERT_NO_EXCEPTION(Complete(engine, {"ы", 0}));
+     UNIT_ASSERT_EXCEPTION(Complete(engine, {"ы", 1}), yexception);
+     UNIT_ASSERT_NO_EXCEPTION(Complete(engine, {"ы", 2}));
+     UNIT_ASSERT_EXCEPTION(Complete(engine, {"ы", 3}), yexception);
}
Y_UNIT_TEST(DefaultNameService) {
diff --git a/yql/essentials/tools/yql_complete/ya.make b/yql/essentials/tools/yql_complete/ya.make
index 107e6ba5625..21a98628b1e 100644
--- a/yql/essentials/tools/yql_complete/ya.make
+++ b/yql/essentials/tools/yql_complete/ya.make
@@ -7,6 +7,7 @@ PEERDIR(
yql/essentials/sql/v1/complete
yql/essentials/sql/v1/lexer/antlr4_pure
yql/essentials/sql/v1/lexer/antlr4_pure_ansi
+ yql/essentials/utils
)
SRCS(
diff --git a/yql/essentials/tools/yql_complete/yql_complete.cpp b/yql/essentials/tools/yql_complete/yql_complete.cpp
index 320b9f1b487..0d592240ebf 100644
--- a/yql/essentials/tools/yql_complete/yql_complete.cpp
+++ b/yql/essentials/tools/yql_complete/yql_complete.cpp
@@ -6,7 +6,11 @@
#include <yql/essentials/sql/v1/lexer/antlr4_pure/lexer.h>
#include <yql/essentials/sql/v1/lexer/antlr4_pure_ansi/lexer.h>
+#include <yql/essentials/utils/utf8.h>
+
#include <library/cpp/getopt/last_getopt.h>
+
+#include <util/charset/utf8.h>
#include <util/stream/file.h>
NSQLComplete::TFrequencyData LoadFrequencyDataFromFile(TString filepath) {
@@ -25,6 +29,11 @@ NSQLComplete::TLexerSupplier MakePureLexerSupplier() {
};
}
+// Maps `position` (passed to SubstrUTF8 as the start position within `text`) to the
+// byte offset at which the resulting substring begins inside `text`.
+// NOTE(review): presumably SubstrUTF8 counts UTF-8 code points and clamps a position
+// past the end to text.end() - confirm against util/charset/utf8.h.
+size_t UTF8PositionToBytes(const TStringBuf text, size_t position) {
+    const TStringBuf tail = SubstrUTF8(text, position, text.length());
+    const auto byteOffset = tail.begin() - text.begin();
+    return byteOffset;
+}
+
int Run(int argc, char* argv[]) {
NLastGetopt::TOpts opts = NLastGetopt::TOpts::Default();
@@ -60,9 +69,20 @@ int Run(int argc, char* argv[]) {
std::move(ranking)));
NSQLComplete::TCompletionInput input;
+
input.Text = queryString;
+ if (!NYql::IsUtf8(input.Text)) {
+ ythrow yexception() << "provided input is not UTF encoded";
+ }
+
if (pos) {
- input.CursorPosition = *pos;
+ input.CursorPosition = UTF8PositionToBytes(input.Text, *pos);
+ } else if (Count(input.Text, '#') == 1) {
+ Cerr << "Note: found an only '#', setting the cursor position\n";
+ input.CursorPosition = input.Text.find('#');
+ } else if (Count(input.Text, '#') >= 2) {
+ Cerr << "Note: found multiple '#', defaulting the cursor position\n";
+ input.CursorPosition = queryString.size();
} else {
input.CursorPosition = queryString.size();
}