diff options
author | xenoxeno <xeno@ydb.tech> | 2023-07-25 17:15:39 +0300 |
---|---|---|
committer | root <root@qavm-2ed34686.qemu> | 2023-07-25 17:15:39 +0300 |
commit | 4d672f41f6b7f2f55f5122277fc719f0d4b402e0 (patch) | |
tree | f6656e5ded01eac7199a7b214a3ee168e79d52be | |
parent | 29e76423c4d4340981732d6bf77ab8dc85fd31ed (diff) | |
download | ydb-4d672f41f6b7f2f55f5122277fc719f0d4b402e0.tar.gz |
split statements KIKIMR-18684
-rw-r--r-- | ydb/core/local_pgwire/CMakeLists.darwin-x86_64.txt | 1 | ||||
-rw-r--r-- | ydb/core/local_pgwire/CMakeLists.linux-aarch64.txt | 1 | ||||
-rw-r--r-- | ydb/core/local_pgwire/CMakeLists.linux-x86_64.txt | 1 | ||||
-rw-r--r-- | ydb/core/local_pgwire/CMakeLists.windows-x86_64.txt | 1 | ||||
-rw-r--r-- | ydb/core/local_pgwire/local_pgwire_connection.cpp | 54 | ||||
-rw-r--r-- | ydb/core/local_pgwire/local_pgwire_util.h | 11 | ||||
-rw-r--r-- | ydb/core/local_pgwire/pgwire_kqp_proxy.cpp | 23 | ||||
-rw-r--r-- | ydb/core/local_pgwire/pgwire_kqp_proxy.h | 24 | ||||
-rw-r--r-- | ydb/core/local_pgwire/sql_parser.cpp | 570 | ||||
-rw-r--r-- | ydb/core/local_pgwire/sql_parser.h | 59 | ||||
-rw-r--r-- | ydb/core/local_pgwire/ya.make | 2 | ||||
-rw-r--r-- | ydb/core/pgproxy/pg_connection.cpp | 4 | ||||
-rw-r--r-- | ydb/core/pgproxy/pg_proxy_events.h | 2 |
13 files changed, 725 insertions, 28 deletions
diff --git a/ydb/core/local_pgwire/CMakeLists.darwin-x86_64.txt b/ydb/core/local_pgwire/CMakeLists.darwin-x86_64.txt index 83fd0a6768..e739b8bcdb 100644 --- a/ydb/core/local_pgwire/CMakeLists.darwin-x86_64.txt +++ b/ydb/core/local_pgwire/CMakeLists.darwin-x86_64.txt @@ -32,4 +32,5 @@ target_sources(ydb-core-local_pgwire PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/local_pgwire.cpp ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/local_pgwire_util.cpp ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/pgwire_kqp_proxy.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/sql_parser.cpp ) diff --git a/ydb/core/local_pgwire/CMakeLists.linux-aarch64.txt b/ydb/core/local_pgwire/CMakeLists.linux-aarch64.txt index 1ac49af2aa..98e3cee271 100644 --- a/ydb/core/local_pgwire/CMakeLists.linux-aarch64.txt +++ b/ydb/core/local_pgwire/CMakeLists.linux-aarch64.txt @@ -33,4 +33,5 @@ target_sources(ydb-core-local_pgwire PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/local_pgwire.cpp ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/local_pgwire_util.cpp ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/pgwire_kqp_proxy.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/sql_parser.cpp ) diff --git a/ydb/core/local_pgwire/CMakeLists.linux-x86_64.txt b/ydb/core/local_pgwire/CMakeLists.linux-x86_64.txt index 1ac49af2aa..98e3cee271 100644 --- a/ydb/core/local_pgwire/CMakeLists.linux-x86_64.txt +++ b/ydb/core/local_pgwire/CMakeLists.linux-x86_64.txt @@ -33,4 +33,5 @@ target_sources(ydb-core-local_pgwire PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/local_pgwire.cpp ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/local_pgwire_util.cpp ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/pgwire_kqp_proxy.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/sql_parser.cpp ) diff --git a/ydb/core/local_pgwire/CMakeLists.windows-x86_64.txt b/ydb/core/local_pgwire/CMakeLists.windows-x86_64.txt index 83fd0a6768..e739b8bcdb 100644 --- a/ydb/core/local_pgwire/CMakeLists.windows-x86_64.txt +++ b/ydb/core/local_pgwire/CMakeLists.windows-x86_64.txt @@ -32,4 +32,5 @@ target_sources(ydb-core-local_pgwire PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/local_pgwire.cpp ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/local_pgwire_util.cpp ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/pgwire_kqp_proxy.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/local_pgwire/sql_parser.cpp ) diff --git a/ydb/core/local_pgwire/local_pgwire_connection.cpp b/ydb/core/local_pgwire/local_pgwire_connection.cpp index c532865d13..3221d2de8f 100644 --- a/ydb/core/local_pgwire/local_pgwire_connection.cpp +++ b/ydb/core/local_pgwire/local_pgwire_connection.cpp @@ -1,5 +1,7 @@ #include "log_impl.h" #include "local_pgwire_util.h" +#include "sql_parser.h" +#include "pgwire_kqp_proxy.h" #include <ydb/core/grpc_services/local_rpc/local_rpc.h> #include <ydb/core/pgproxy/pg_proxy_events.h> @@ -20,14 +22,6 @@ namespace NLocalPgWire { using namespace NActors; using namespace NKikimr; -extern NActors::IActor* CreatePgwireKqpProxy( - std::unordered_map<TString, TString> params -); - -NActors::IActor* CreatePgwireKqpProxyQuery(const TActorId& owner, std::unordered_map<TString, TString> params, const TConnectionState& connection, NPG::TEvPGEvents::TEvQuery::TPtr&& evQuery); -NActors::IActor* CreatePgwireKqpProxyParse(const TActorId& owner, std::unordered_map<TString, TString> params, const TConnectionState& connection, NPG::TEvPGEvents::TEvParse::TPtr&& evParse); -NActors::IActor* CreatePgwireKqpProxyExecute(const TActorId& owner, std::unordered_map<TString, TString> params, const TConnectionState& connection, NPG::TEvPGEvents::TEvExecute::TPtr&& evExecute, const TParsedStatement& statement); - class TPgYdbConnection : public TActor<TPgYdbConnection> { using TBase = TActor<TPgYdbConnection>; @@ -44,23 +38,49 @@ public: , ConnectionParams(std::move(params)) {} - void Handle(NPG::TEvPGEvents::TEvQuery::TPtr& ev) { - BLOG_D("TEvQuery " << ev->Sender); - if (IsQueryEmpty(ev->Get()->Message->GetQuery())) { + void ProcessEventsQueue() { + while (!Events.empty() && Inflight == 0) { + StateWork(Events.front()); + Events.pop_front(); + } + } + + void Handle(TEvEvents::TEvSingleQuery::TPtr& ev) { + BLOG_D("TEvSingleQuery " << ev->Sender); + if (IsQueryEmpty(ev->Get()->Query)) { auto response = std::make_unique<NPG::TEvPGEvents::TEvQueryResponse>(); response->EmptyQuery = true; + response->ReadyForQuery = ev->Get()->FinalQuery; Send(ev->Sender, response.release(), 0, ev->Cookie); return; } + ++Inflight; - TActorId actorId = Register(CreatePgwireKqpProxyQuery(SelfId(), ConnectionParams, Connection, std::move(ev))); + TActorId actorId = RegisterWithSameMailbox(CreatePgwireKqpProxyQuery(SelfId(), ConnectionParams, Connection, std::move(ev))); BLOG_D("Created pgwireKqpProxyQuery: " << actorId); } + void Handle(NPG::TEvPGEvents::TEvQuery::TPtr& ev) { + BLOG_D("TEvQuery " << ev->Sender); + + TStatementIterator stmtIter((TString(ev->Get()->Message->GetQuery()))); + std::vector<TString> statements; + + for (auto pStmt = stmtIter.Next(); pStmt != nullptr; pStmt = stmtIter.Next()) { + statements.push_back(*pStmt); + } + + for (std::size_t n = 0; n < statements.size(); ++n) { + Events.push_front(new NActors::IEventHandle(SelfId(), ev->Sender, new TEvEvents::TEvSingleQuery(statements[statements.size() - n - 1], n == 0), 0, ev->Cookie)); + } + + ProcessEventsQueue(); + } + void Handle(NPG::TEvPGEvents::TEvParse::TPtr& ev) { BLOG_D("TEvParse " << ev->Sender); ++Inflight; - TActorId actorId = Register(CreatePgwireKqpProxyParse(SelfId(), ConnectionParams, Connection, std::move(ev))); + TActorId actorId = RegisterWithSameMailbox(CreatePgwireKqpProxyParse(SelfId(), ConnectionParams, Connection, std::move(ev))); BLOG_D("Created pgwireKqpProxyParse: " << actorId); return; } @@ -136,7 +156,7 @@ public: } ++Inflight; - TActorId actorId = Register(CreatePgwireKqpProxyExecute(SelfId(), ConnectionParams, Connection, std::move(ev), it->second)); + TActorId actorId = RegisterWithSameMailbox(CreatePgwireKqpProxyExecute(SelfId(), ConnectionParams, Connection, std::move(ev), it->second)); BLOG_D("Created pgwireKqpProxyExecute: " << actorId); } @@ -170,10 +190,7 @@ public: Connection.SessionId = connection.SessionId; } } - while (!Events.empty() && Inflight == 0) { - StateWork(Events.front()); - Events.pop_front(); - } + ProcessEventsQueue(); } void PassAway() override { @@ -203,6 +220,7 @@ public: STATEFN(StateWork) { switch (ev->GetTypeRewrite()) { hFunc(NPG::TEvPGEvents::TEvQuery, Handle); + hFunc(TEvEvents::TEvSingleQuery, Handle); hFunc(NPG::TEvPGEvents::TEvParse, Handle); hFunc(NPG::TEvPGEvents::TEvBind, Handle); hFunc(NPG::TEvPGEvents::TEvDescribe, Handle); diff --git a/ydb/core/local_pgwire/local_pgwire_util.h b/ydb/core/local_pgwire/local_pgwire_util.h index 60cf61e743..22ba3d3fd8 100644 --- a/ydb/core/local_pgwire/local_pgwire_util.h +++ b/ydb/core/local_pgwire/local_pgwire_util.h @@ -44,6 +44,7 @@ enum EFormatType : int16_t { struct TEvEvents { enum EEv { EvProxyCompleted = EventSpaceBegin(NActors::TEvents::ES_PRIVATE), + EvSingleQuery, EvEnd }; @@ -63,6 +64,16 @@ struct TEvEvents { : ParsedStatement(parsedStatement) {} }; + + struct TEvSingleQuery : NActors::TEventLocal<TEvSingleQuery, EvSingleQuery> { + TString Query; + bool FinalQuery = true; + + TEvSingleQuery(const TString& query, bool finalQuery) + : Query(query) + , FinalQuery(finalQuery) + {} + }; }; TString ColumnPrimitiveValueToString(NYdb::TValueParser& valueParser); diff --git a/ydb/core/local_pgwire/pgwire_kqp_proxy.cpp b/ydb/core/local_pgwire/pgwire_kqp_proxy.cpp index 900edfaf29..5c72508053 100644 --- a/ydb/core/local_pgwire/pgwire_kqp_proxy.cpp +++ b/ydb/core/local_pgwire/pgwire_kqp_proxy.cpp @@ -1,5 +1,6 @@ #include "log_impl.h" #include "local_pgwire_util.h" +#include "pgwire_kqp_proxy.h" #include <ydb/core/kqp/common/events/events.h> #include <ydb/core/kqp/common/simple/services.h> #include <ydb/core/kqp/executer_actor/kqp_executer.h> @@ -142,19 +143,22 @@ protected: class TPgwireKqpProxyQuery : public TPgwireKqpProxy<TPgwireKqpProxyQuery> { using TBase = TPgwireKqpProxy<TPgwireKqpProxyQuery>; - NPG::TEvPGEvents::TEvQuery::TPtr EventQuery_; + TEvEvents::TEvSingleQuery::TPtr EventQuery_; bool WasMeta_ = false; std::size_t RowsSelected_ = 0; public: - TPgwireKqpProxyQuery(const TActorId& owner, std::unordered_map<TString, TString> params, const TConnectionState& connection, NPG::TEvPGEvents::TEvQuery::TPtr&& evQuery) + TPgwireKqpProxyQuery(const TActorId& owner, + std::unordered_map<TString, TString> params, + const TConnectionState& connection, + TEvEvents::TEvSingleQuery::TPtr&& evQuery) : TPgwireKqpProxy(owner, std::move(params), connection) , EventQuery_(std::move(evQuery)) { } void Bootstrap() { - auto query(EventQuery_->Get()->Message->GetQuery()); + auto query(EventQuery_->Get()->Query); auto event = MakeKqpRequest(); NKikimrKqp::TQueryRequest& request = *event->Record.MutableRequest(); @@ -205,13 +209,14 @@ public: } FillResultSet(resultSet, response.get()->DataRows); response->CommandCompleted = false; + response->ReadyForQuery = false; RowsSelected_ += response->DataRows.size(); - BLOG_D(this->SelfId() << "Send rowset data (" << ev->Get()->Record.GetSeqNo() << ") to: " << EventQuery_->Sender); + BLOG_D(this->SelfId() << " Send rowset " << ev->Get()->Record.GetQueryResultIndex() << " data " << ev->Get()->Record.GetSeqNo() << " to " << EventQuery_->Sender); Send(EventQuery_->Sender, response.release(), 0, EventQuery_->Cookie); - BLOG_D(this->SelfId() << "Send stream data ack to: " << ev->Sender); + BLOG_D(this->SelfId() << " Send stream data ack to " << ev->Sender); auto resp = MakeHolder<NKqp::TEvKqpExecuter::TEvStreamDataAck>(); resp->Record.SetSeqNo(ev->Get()->Record.GetSeqNo()); resp->Record.SetFreeSpace(std::numeric_limits<ui64>::max()); @@ -248,6 +253,7 @@ public: response->ErrorFields.push_back({'M', e.what()}); } response->CommandCompleted = true; + response->ReadyForQuery = EventQuery_->Get()->FinalQuery; BLOG_D("Finally replying to " << EventQuery_->Sender); Send(EventQuery_->Sender, response.release(), 0, EventQuery_->Cookie); PassAway(); @@ -413,13 +419,14 @@ public: auto response = MakeResponse(); FillResultSet(resultSet, response.get()->DataRows, Statement_.BindData.ResultsFormat); response->CommandCompleted = false; + response->ReadyForQuery = false; RowsSelected_ += response->DataRows.size(); - BLOG_D(this->SelfId() << "Send rowset data (" << ev->Get()->Record.GetSeqNo() << ") to: " << EventExecute_->Sender); + BLOG_D(this->SelfId() << " Send rowset " << ev->Get()->Record.GetQueryResultIndex() << " data " << ev->Get()->Record.GetSeqNo() << " to " << EventExecute_->Sender); Send(EventExecute_->Sender, response.release(), 0, EventExecute_->Cookie); - BLOG_D(this->SelfId() << "Send stream data ack to: " << ev->Sender); + BLOG_D(this->SelfId() << " Send stream data ack to " << ev->Sender); auto resp = MakeHolder<NKqp::TEvKqpExecuter::TEvStreamDataAck>(); resp->Record.SetSeqNo(ev->Get()->Record.GetSeqNo()); resp->Record.SetFreeSpace(std::numeric_limits<ui64>::max()); @@ -472,7 +479,7 @@ public: NActors::IActor* CreatePgwireKqpProxyQuery(const TActorId& owner, std::unordered_map<TString, TString> params, const TConnectionState& connection, - NPG::TEvPGEvents::TEvQuery::TPtr&& evQuery) { + TEvEvents::TEvSingleQuery::TPtr&& evQuery) { return new TPgwireKqpProxyQuery(owner, std::move(params), connection, std::move(evQuery)); } diff --git a/ydb/core/local_pgwire/pgwire_kqp_proxy.h b/ydb/core/local_pgwire/pgwire_kqp_proxy.h new file mode 100644 index 0000000000..3dab049c17 --- /dev/null +++ b/ydb/core/local_pgwire/pgwire_kqp_proxy.h @@ -0,0 +1,24 @@ +#pragma once + +#include <ydb/core/pgproxy/pg_proxy_events.h> +#include "local_pgwire_util.h" + +namespace NLocalPgWire { + +NActors::IActor* CreatePgwireKqpProxyQuery(const NActors::TActorId& owner, + std::unordered_map<TString, TString> params, + const TConnectionState& connection, + TEvEvents::TEvSingleQuery::TPtr&& evQuery); + +NActors::IActor* CreatePgwireKqpProxyParse(const NActors::TActorId& owner, + std::unordered_map<TString, TString> params, + const TConnectionState& connection, + NPG::TEvPGEvents::TEvParse::TPtr&& evParse); + +NActors::IActor* CreatePgwireKqpProxyExecute(const NActors::TActorId& owner, + std::unordered_map<TString, TString> params, + const TConnectionState& connection, + NPG::TEvPGEvents::TEvExecute::TPtr&& evExecute, + const TParsedStatement& statement); + +} diff --git a/ydb/core/local_pgwire/sql_parser.cpp b/ydb/core/local_pgwire/sql_parser.cpp new file mode 100644 index 0000000000..0b69c6ec03 --- /dev/null +++ b/ydb/core/local_pgwire/sql_parser.cpp @@ -0,0 +1,570 @@ +#include "sql_parser.h" +#include <util/string/split.h> +#include <util/string/builder.h> + +TStatementIterator::TStatementIterator(const TString& program) + : Program_(program) + , Cur_() + , Pos_(0) + , State_(State::InOperator) + , AtStmtStart_(true) + , Mode_(State::InOperator) + , Depth_(0) + , Tag_() + , StandardConformingStrings_(true) +{ +} + +bool TStatementIterator::isInWsSignificantState(State state) { + switch (state) { + case State::QuotedIdentifier: + case State::StringLiteral: + case State::EscapedStringLiteral: + case State::DollarStringLiteral: + return true; + default: + return false; + } +} + +bool TStatementIterator::isEscapedChar(const TString& s, size_t pos) { + bool escaped = false; + while (s[--pos] == '\\') { + escaped = !escaped; + } + return escaped; +} + +TString TStatementIterator::RemoveEmptyLines(const TString& s, bool inStatement) { + if (s.empty()) { + return {}; + } + + TStringBuilder sb; + auto isFirstLine = true; + + if (inStatement && s[0] == '\n') { + sb << '\n'; + } + + for (TStringBuf line : StringSplitter(s).SplitBySet("\r\n").SkipEmpty()) { + if (isFirstLine) { + isFirstLine = false; + } else { + sb << '\n'; + } + sb << line; + } + return sb; +} + +const TString* TStatementIterator::Next() { + if (TStringBuf::npos == Pos_) + return nullptr; + + size_t startPos = Pos_; + size_t curPos = Pos_; + size_t endPos; + auto prevState = State_; + + TStringBuilder stmt; + TStringBuilder rawStmt; + auto inStatement = false; + + while (!CallParser(startPos)) { + endPos = (TStringBuf::npos != Pos_) ? Pos_ : Program_.length(); + + TStringBuf part{&Program_[curPos], endPos - curPos}; + + if (isInWsSignificantState(prevState)) { + if (!rawStmt.empty()) { + stmt << RemoveEmptyLines(rawStmt, inStatement); + rawStmt.clear(); + } + stmt << part; + inStatement = true; + } else { + rawStmt << part; + } + curPos = endPos; + prevState = State_; + } + endPos = (TStringBuf::npos != Pos_) ? Pos_ : Program_.length(); + + TStringBuf part{&Program_[curPos], endPos - curPos}; + + if (isInWsSignificantState(prevState)) { + if (!rawStmt.empty()) { + stmt << RemoveEmptyLines(rawStmt, inStatement); + rawStmt.clear(); + } + stmt << part; + inStatement = true; + } else { + rawStmt << part; + } + +#if 0 + if (0 < Pos_ && !(Pos_ == TStringBuf::npos || Program_[Pos_-1] == '\n')) { + Cerr << "Last char: '" << Program_[Pos_-1] << "'\n"; + } +#endif + + stmt << RemoveEmptyLines(rawStmt, inStatement); + // inv: Pos_ is at the start of next token + if (startPos == endPos) + return nullptr; + + stmt << '\n'; + Cur_ = stmt; + + ApplyStateFromStatement(Cur_); + + return &Cur_; +} + +// States: +// - in-operator +// - line comment +// - block comment +// - quoted identifier (U& quoted identifier is no difference) +// - string literal (U& string literal is the same for our purpose) +// - E string literal +// - $ string literal +// - end-of-operator + +// Rules: +// - in-operator +// -- -> next: line comment +// /* -> depth := 1, next: block comment +// " -> next: quoted identifier +// ' -> next: string literal +// E' -> next: E string literal +// $tag$, not preceded by alnum char (a bit of simplification here but sufficient) -> tag := tag, next: $ string literal +// ; -> current_mode := end-of-operator, next: end-of-operator + +// - line comment +// EOL -> next: current_mode + +// - block comment +// /* -> ++depth +// */ -> --depth, if (depth == 0) -> next: current_mode + +// - quoted identifier +// " -> next: in-operator + +// - string literal +// ' -> next: in-operator + +// - E string literal +// ' -> if not preceeded by \ next: in-operator + +// - $ string literal +// $tag$ -> next: in-operator + +// - end-of-operator +// -- -> next: line comment, just once +// /* -> depth := 1, next: block comment +// non-space char -> unget, emit, current_mode := in-operator, next: in-operator + +// In every state: +// EOS -> emit if consumed part of the input is not empty + +bool TStatementIterator::SaveDollarTag() { + if (Pos_ + 1 == Program_.length()) + return false; + + auto p = Program_.cbegin() + (Pos_ + 1); + + if (std::isdigit(*p)) + return false; + + for (;p != Program_.cend(); ++p) { + if (*p == '$') { + auto bp = &Program_[Pos_]; + auto l = p - bp; + Tag_ = TStringBuf(bp, l + 1); + Pos_ += l; + + return true; + } + if (!(std::isalpha(*p) || std::isdigit(*p) || *p == '_')) + return false; + } + return false; +} + +bool TStatementIterator::IsCopyFromStdin(size_t startPos, size_t endPos) { + TString stmt(Program_, startPos, endPos - startPos + 1); + stmt.to_upper(); + // FROM STDOUT is used in insert.sql testcase, probably a bug + return stmt.Contains(" FROM STDIN") || stmt.Contains(" FROM STDOUT"); +} + +bool TStatementIterator::InOperatorParser(size_t startPos) { + // need \ to detect psql meta-commands + static const TString midNextTokens{"'\";-/$\\"}; + // need : for basic psql-vars support + static const TString initNextTokens{"'\";-/$\\:"}; + const auto& nextTokens = (AtStmtStart_) ? initNextTokens : midNextTokens; + + if (AtStmtStart_) { + Pos_ = Program_.find_first_not_of(" \t\n\r\v", Pos_); + if (TString::npos == Pos_) { + return true; + } + } + + Pos_ = Program_.find_first_of(nextTokens, Pos_); + + if (TString::npos == Pos_) { + return true; + } + + switch (Program_[Pos_]) { + case '\'': + State_ = (!StandardConformingStrings_ || 0 < Pos_ && std::toupper(Program_[Pos_ - 1]) == 'E') + ? State::EscapedStringLiteral + : State::StringLiteral; + break; + + case '"': + State_ = State::QuotedIdentifier; + break; + + case ';': + State_ = Mode_ = IsCopyFromStdin(startPos, Pos_) + ? State::InCopyFromStdin + : State::EndOfOperator; + break; + + case '-': + if (Pos_ < Program_.length() && Program_[Pos_ + 1] == '-') { + State_ = State::LineComment; + ++Pos_; + } + break; + + case '/': + if (Pos_ < Program_.length() && Program_[Pos_ + 1] == '*') { + State_ = State::BlockComment; + ++Depth_; + ++Pos_; + } + break; + + case '$': + if (Pos_ == 0 || std::isspace(Program_[Pos_ - 1])) { + if (SaveDollarTag()) + State_ = State::DollarStringLiteral; + } + break; + + case '\\': + if (AtStmtStart_) { + State_ = State::InMetaCommand; + } else if (Program_.Contains("\\gexec", Pos_)) { + Pos_ += 6; + return Emit(Program_[Pos_] == '\n'); + } + break; + + case ':': + if (Pos_ == 0 || Program_[Pos_-1] == '\n') { + State_ = State::InVar; + } + break; + } + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + return false; +} + +bool TStatementIterator::Emit(bool atEol) { + State_ = Mode_ = State::InOperator; + AtStmtStart_ = true; + + if (atEol) { + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + } + // else do not consume as we're expected to be on the first char of the next statement + + return true; +} + +bool TStatementIterator::EndOfOperatorParser() { + const auto p = std::find_if_not(Program_.cbegin() + Pos_, Program_.cend(), [](const auto& c) { + return c == ' ' || c == '\t' || c == '\r'; + }); + + if (p == Program_.cend()) { + Pos_ = TStringBuf::npos; + return true; + } + + Pos_ = p - Program_.cbegin(); + + switch (*p) { + case '-': + if (Pos_ < Program_.length() && Program_[Pos_ + 1] == '-') { + State_ = State::LineComment; + ++Pos_; + } + break; + + case '/': + if (Pos_ < Program_.length() && Program_[Pos_ + 1] == '*') { + State_ = State::BlockComment; + ++Depth_; + ++Pos_; + } + break; + + default: + return Emit(*p == '\n'); + } + + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + return false; +} + +bool TStatementIterator::LineCommentParser() { + Pos_ = Program_.find('\n', Pos_); + + if (TString::npos == Pos_) + return true; + + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + if (Mode_ == State::EndOfOperator) { + return Emit(false); + } + + State_ = Mode_; + + return false; +} + +bool TStatementIterator::BlockCommentParser() { + Pos_ = Program_.find_first_of("*/", Pos_); + + if (TString::npos == Pos_) + return true; + + switch(Program_[Pos_]) { + case '/': + if (Pos_ < Program_.length() && Program_[Pos_ + 1] == '*') { + ++Depth_; + ++Pos_; + } + break; + + case '*': + if (Pos_ < Program_.length() && Program_[Pos_ + 1] == '/') { + --Depth_; + ++Pos_; + + if (0 == Depth_) { + State_ = Mode_; + } + } + break; + } + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + return false; +} + +bool TStatementIterator::QuotedIdentifierParser() { + Pos_ = Program_.find('"', Pos_); + + if (TString::npos == Pos_) + return true; + + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + State_ = State::InOperator; + AtStmtStart_ = false; + + return false; +} + +bool TStatementIterator::StringLiteralParser() { + Pos_ = Program_.find('\'', Pos_); + + if (TString::npos == Pos_) + return true; + + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + State_ = State::InOperator; + AtStmtStart_ = false; + + return false; +} + +bool TStatementIterator::EscapedStringLiteralParser() { + Pos_ = Program_.find('\'', Pos_); + + if (TString::npos == Pos_) + return true; + + if (isEscapedChar(Program_, Pos_)) { + ++Pos_; + return false; + } + + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + State_ = State::InOperator; + AtStmtStart_ = false; + + return false; +} + +bool TStatementIterator::DollarStringLiteralParser() { + //Y_ENSURE(Tag_ != nullptr && 2 <= Tag_.length()); + + Pos_ = Program_.find(Tag_, Pos_); + + if (TString::npos == Pos_) + return true; + + Pos_ += Tag_.length(); + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + Tag_.Clear(); + + State_ = State::InOperator; + AtStmtStart_ = false; + + return false; +} + +bool TStatementIterator::MetaCommandParser() { + Pos_ = Program_.find('\n', Pos_); + + if (TString::npos == Pos_) + return true; + + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + return Emit(false); +} + +bool TStatementIterator::InCopyFromStdinParser() { + Pos_ = Program_.find("\n\\.\n", Pos_); + + if (TString::npos == Pos_) + return true; + + Pos_ += 4; + return Emit(false); +} + +// For now we support vars occupying a whole line only +bool TStatementIterator::VarParser() { + // TODO: validate var name + Pos_ = Program_.find('\n', Pos_); + + if (TString::npos == Pos_) + return true; + + ++Pos_; + if (Program_.length() == Pos_) { + Pos_ = TString::npos; + return true; + } + + return Emit(false); +} + +bool TStatementIterator::CallParser(size_t startPos) { + switch (State_) { + case State::InOperator: + return InOperatorParser(startPos); + + case State::EndOfOperator: + return EndOfOperatorParser(); + + case State::LineComment: + return LineCommentParser(); + + case State::BlockComment: + return BlockCommentParser(); + + case State::QuotedIdentifier: + return QuotedIdentifierParser(); + + case State::StringLiteral: + return StringLiteralParser(); + + case State::EscapedStringLiteral: + return EscapedStringLiteralParser(); + + case State::DollarStringLiteral: + return DollarStringLiteralParser(); + + case State::InMetaCommand: + return MetaCommandParser(); + + case State::InCopyFromStdin: + return InCopyFromStdinParser(); + + case State::InVar: + return VarParser(); + + default: + Y_UNREACHABLE(); + } +} + +void TStatementIterator::ApplyStateFromStatement(const TStringBuf& stmt) { + if (stmt.contains("set standard_conforming_strings = on;") || + stmt.contains("reset standard_conforming_strings;")) + { + StandardConformingStrings_ = true; + } else if (stmt.contains("set standard_conforming_strings = off;")) { + StandardConformingStrings_ = false; + } +} diff --git a/ydb/core/local_pgwire/sql_parser.h b/ydb/core/local_pgwire/sql_parser.h new file mode 100644 index 0000000000..0567e05ccb --- /dev/null +++ b/ydb/core/local_pgwire/sql_parser.h @@ -0,0 +1,59 @@ + +#include <util/generic/string.h> +#include <util/generic/iterator.h> + +// a temporary copy-paste from: yql/tools/pgrun/pgrun.cpp?rev=r11829751#L49 + +class TStatementIterator final + : public TInputRangeAdaptor<TStatementIterator> +{ + enum class State { + InOperator, + EndOfOperator, + LineComment, + BlockComment, + QuotedIdentifier, + StringLiteral, + EscapedStringLiteral, + DollarStringLiteral, + InMetaCommand, + InCopyFromStdin, + InVar, + }; + +public: + TStatementIterator(const TString& program); + static bool isInWsSignificantState(State state); + static bool isEscapedChar(const TString& s, size_t pos); + TString RemoveEmptyLines(const TString& s, bool inStatement); + const TString* Next(); + +private: + bool SaveDollarTag(); + bool IsCopyFromStdin(size_t startPos, size_t endPos); + bool InOperatorParser(size_t startPos); + bool Emit(bool atEol); + bool EndOfOperatorParser(); + bool LineCommentParser(); + bool BlockCommentParser(); + bool QuotedIdentifierParser(); + bool StringLiteralParser(); + bool EscapedStringLiteralParser(); + bool DollarStringLiteralParser(); + bool MetaCommandParser(); + bool InCopyFromStdinParser(); + bool VarParser(); + bool CallParser(size_t startPos); + void ApplyStateFromStatement(const TStringBuf& stmt); + + TString Program_; + TString Cur_; + size_t Pos_; + State State_; + bool AtStmtStart_; + + State Mode_; + ui16 Depth_; + TStringBuf Tag_; + bool StandardConformingStrings_; +}; diff --git a/ydb/core/local_pgwire/ya.make b/ydb/core/local_pgwire/ya.make index a075a76ad7..58458166ba 100644 --- a/ydb/core/local_pgwire/ya.make +++ b/ydb/core/local_pgwire/ya.make @@ -8,6 +8,8 @@ SRCS( local_pgwire_util.h log_impl.h pgwire_kqp_proxy.cpp + sql_parser.cpp + sql_parser.h ) PEERDIR( diff --git a/ydb/core/pgproxy/pg_connection.cpp b/ydb/core/pgproxy/pg_connection.cpp index 7e3b3c2d1a..ce0be0757d 100644 --- a/ydb/core/pgproxy/pg_connection.cpp +++ b/ydb/core/pgproxy/pg_connection.cpp @@ -564,7 +564,7 @@ protected: } else { SendErrorResponse(ev->Get()->ErrorFields); } - if (ev->Get()->CommandCompleted) { + if (ev->Get()->ReadyForQuery) { BecomeReadyForQuery(); } } else { @@ -629,7 +629,7 @@ protected: } else { SendErrorResponse(ev->Get()->ErrorFields); } - if (ev->Get()->CommandCompleted) { + if (ev->Get()->ReadyForQuery) { ++OutgoingSequenceNumber; BecomeReadyForQuery(); } diff --git a/ydb/core/pgproxy/pg_proxy_events.h b/ydb/core/pgproxy/pg_proxy_events.h index 6b79b90bbc..e01795f716 100644 --- a/ydb/core/pgproxy/pg_proxy_events.h +++ b/ydb/core/pgproxy/pg_proxy_events.h @@ -112,6 +112,7 @@ struct TEvPGEvents { TString Tag; bool EmptyQuery = false; bool CommandCompleted = true; + bool ReadyForQuery = true; char TransactionStatus = 0; }; @@ -215,6 +216,7 @@ struct TEvPGEvents { TString Tag; bool EmptyQuery = false; bool CommandCompleted = true; + bool ReadyForQuery = true; char TransactionStatus = 0; }; |