diff options
author | xenoxeno <xeno@ydb.tech> | 2022-12-30 18:37:20 +0300 |
---|---|---|
committer | xenoxeno <xeno@ydb.tech> | 2022-12-30 18:37:20 +0300 |
commit | 20070209a718431338a2eed7e45673d99c9f3738 (patch) | |
tree | 851516f961d2252654973bc02f3394b2605a7351 | |
parent | 08fb420704acef4bf445c80b67bc9f9fc648583f (diff) | |
download | ydb-20070209a718431338a2eed7e45673d99c9f3738.tar.gz |
extended query protocol
-rw-r--r-- | ydb/core/base/ticket_parser.h | 2 | ||||
-rw-r--r-- | ydb/core/pgproxy/CMakeLists.darwin.txt | 1 | ||||
-rw-r--r-- | ydb/core/pgproxy/CMakeLists.linux-aarch64.txt | 1 | ||||
-rw-r--r-- | ydb/core/pgproxy/CMakeLists.linux.txt | 1 | ||||
-rw-r--r-- | ydb/core/pgproxy/pg_connection.cpp | 427 | ||||
-rw-r--r-- | ydb/core/pgproxy/pg_proxy_events.h | 125 | ||||
-rw-r--r-- | ydb/core/pgproxy/pg_proxy_types.cpp | 207 | ||||
-rw-r--r-- | ydb/core/pgproxy/pg_proxy_types.h | 160 | ||||
-rw-r--r-- | ydb/core/pgproxy/pg_proxy_ut.cpp | 2 | ||||
-rw-r--r-- | ydb/core/pgproxy/pg_stream.h | 77 |
10 files changed, 807 insertions, 196 deletions
diff --git a/ydb/core/base/ticket_parser.h b/ydb/core/base/ticket_parser.h index d1b8454e6b9..3e5c14d92ee 100644 --- a/ydb/core/base/ticket_parser.h +++ b/ydb/core/base/ticket_parser.h @@ -1,4 +1,4 @@ -#pragma once + #pragma once #include <library/cpp/containers/stack_vector/stack_vec.h> #include <ydb/core/base/defs.h> #include <ydb/core/base/events.h> diff --git a/ydb/core/pgproxy/CMakeLists.darwin.txt b/ydb/core/pgproxy/CMakeLists.darwin.txt index c2f7d21243d..a6984d0aa65 100644 --- a/ydb/core/pgproxy/CMakeLists.darwin.txt +++ b/ydb/core/pgproxy/CMakeLists.darwin.txt @@ -19,5 +19,6 @@ target_link_libraries(ydb-core-pgproxy PUBLIC target_sources(ydb-core-pgproxy PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_connection.cpp ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_listener.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy_types.cpp ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy.cpp ) diff --git a/ydb/core/pgproxy/CMakeLists.linux-aarch64.txt b/ydb/core/pgproxy/CMakeLists.linux-aarch64.txt index 783f7a53b77..0db70f26363 100644 --- a/ydb/core/pgproxy/CMakeLists.linux-aarch64.txt +++ b/ydb/core/pgproxy/CMakeLists.linux-aarch64.txt @@ -20,5 +20,6 @@ target_link_libraries(ydb-core-pgproxy PUBLIC target_sources(ydb-core-pgproxy PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_connection.cpp ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_listener.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy_types.cpp ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy.cpp ) diff --git a/ydb/core/pgproxy/CMakeLists.linux.txt b/ydb/core/pgproxy/CMakeLists.linux.txt index 783f7a53b77..0db70f26363 100644 --- a/ydb/core/pgproxy/CMakeLists.linux.txt +++ b/ydb/core/pgproxy/CMakeLists.linux.txt @@ -20,5 +20,6 @@ target_link_libraries(ydb-core-pgproxy PUBLIC target_sources(ydb-core-pgproxy PRIVATE ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_connection.cpp ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_listener.cpp + ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy_types.cpp ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy.cpp ) diff --git a/ydb/core/pgproxy/pg_connection.cpp b/ydb/core/pgproxy/pg_connection.cpp index 135ccd92240..dbd2015ebcf 100644 --- a/ydb/core/pgproxy/pg_connection.cpp +++ b/ydb/core/pgproxy/pg_connection.cpp @@ -12,11 +12,17 @@ using namespace NActors; class TPGConnection : public TActorBootstrapped<TPGConnection>, public TNetworkConfig { public: + // incoming messages only enum class EMessageCode : char { Initial = 'i', Query = 'Q', Terminate = 'X', PasswordMessage = 'p', + Parse = 'P', + ParameterStatus = 'S', + Bind = 'B', + Describe = 'D', + Execute = 'E', }; using TBase = TActorBootstrapped<TPGConnection>; @@ -43,6 +49,9 @@ public: TSocketBuffer BufferOutput; TActorId DatabaseProxy; std::shared_ptr<TPGInitial> InitialMessage; + ui64 IncomingSequenceNumber = 1; + ui64 OutgoingSequenceNumber = 1; + std::deque<TAutoPtr<IEventHandle>> PostponedEvents; TPGConnection(TIntrusivePtr<TSocketDescriptor> socket, TNetworkConfig::TSocketAddressType address, const TActorId& databaseProxy) : Socket(std::move(socket)) @@ -117,7 +126,12 @@ protected: OnAccept(); } - TStringBuf GetMessageName(const TPGMessage& message) const { + enum class EDirection { + Incoming, + Outgoing, + }; + + TStringBuf GetMessageName(EDirection direction, const TPGMessage& message) const { static const std::unordered_map<char, TStringBuf> messageName = { {'i', "Initial"}, {'R', "Auth"}, @@ -126,20 +140,41 @@ protected: {'C', "CommandComplete"}, {'X', "Terminate"}, {'T', "RowDescription"}, - {'D', "DataRow"}, {'S', "ParameterStatus"}, - {'E', "ErrorResponse"}, {'I', "EmptyQueryResponse"}, {'p', "PasswordMessage"}, + {'P', "Parse"}, + {'1', "ParseComplete"}, + {'B', "Bind"}, + {'2', "BindComplete"}, }; auto itMessageName = messageName.find(message.Message); if (itMessageName != messageName.end()) { return itMessageName->second; } + static const std::unordered_map<char, TStringBuf> incomingMessageName = { + {'E', "Execute"}, + {'D', "Describe"}, + }; + static const std::unordered_map<char, TStringBuf> outgoingMessageName = { + {'E', "ErrorResponse"}, + {'D', "DataRow"}, + }; + switch (direction) { + case EDirection::Incoming: + itMessageName = incomingMessageName.find(message.Message); + break; + case EDirection::Outgoing: + itMessageName = outgoingMessageName.find(message.Message); + break; + } + if (itMessageName != messageName.end()) { + return itMessageName->second; + } return {}; } - TString GetMessageDump(const TPGMessage& message) const { + TString GetMessageDump(EDirection direction, const TPGMessage& message) const { switch (message.Message) { case 'i': return ((const TPGInitial&)message).Dump(); @@ -153,36 +188,66 @@ protected: return ((const TPGCommandComplete&)message).Dump(); case 'R': return ((const TPGAuth&)message).Dump(); + case 'D': { + switch (direction) { + case EDirection::Incoming: + return ((const TPGDescribe&)message).Dump(); + case EDirection::Outgoing: + return ((const TPGDataRow&)message).Dump(); + } + } + case 'E': { + switch (direction) { + case EDirection::Incoming: + return ((const TPGExecute&)message).Dump(); + case EDirection::Outgoing: + return ((const TPGErrorResponse&)message).Dump(); + } + } + case 'B': + return ((const TPGBind&)message).Dump(); + case 'P': + return ((const TPGParse&)message).Dump(); + } return {}; } - void PrintMessage(const TStringBuf& prefix, const TPGMessage& message) { - BLOG_D(prefix << "'" << message.Message << "' \"" << GetMessageName(message) << "\" Size(" << ntohl(message.Length) << ") " << GetMessageDump(message)); + void PrintMessage(EDirection direction, const TPGMessage& message) { + TStringBuilder prefix; + switch (direction) { + case EDirection::Incoming: + prefix << "-> [" << IncomingSequenceNumber << "] "; + break; + case EDirection::Outgoing: + prefix << "<- [" << OutgoingSequenceNumber << "] "; + break; + } + BLOG_D(prefix << "'" << message.Message << "' \"" << GetMessageName(direction, message) << "\" Size(" << message.GetDataSize() << ") " << GetMessageDump(direction, message)); } template<typename TMessage> void SendMessage(const TMessage& message) { - PrintMessage("<- ", message); + PrintMessage(EDirection::Outgoing, message); BufferOutput.Append(reinterpret_cast<const char*>(&message), sizeof(message)); } template<typename TMessage> - void SendStream(TPGStream<TMessage>& message) { + void SendStream(TPGStreamOutput<TMessage>& message) { message.UpdateLength(); const TPGMessage& header = *reinterpret_cast<const TPGMessage*>(message.Data()); - PrintMessage("<- ", header); + PrintMessage(EDirection::Outgoing, header); BufferOutput.Append(message.Data(), message.Size()); } void SendAuthOk() { - TPGStream<TPGAuth> authOk; + TPGStreamOutput<TPGAuth> authOk; authOk << uint32_t(TPGAuth::EAuthCode::OK); SendStream(authOk); } void SendAuthClearText() { - TPGStream<TPGAuth> authClearText; + TPGStreamOutput<TPGAuth> authClearText; authClearText << uint32_t(TPGAuth::EAuthCode::ClearText); SendStream(authClearText); } @@ -211,19 +276,19 @@ protected: 0x0190: 0045 7463 2f55 5443 004b 0000 000c 0004 .Etc/UTC.K...... */ void SendParameterStatus(TStringBuf name, TStringBuf value) { - TPGStream<TPGParameterStatus> param; + TPGStreamOutput<TPGParameterStatus> param; param << name << '\0' << value << '\0'; SendStream(param); } void SendReadyForQuery() { - TPGStream<TPGReadyForQuery> readyForQuery; + TPGStreamOutput<TPGReadyForQuery> readyForQuery; readyForQuery << 'I'; SendStream(readyForQuery); } void SendAuthError(const TString& error) { - TPGStream<TPGErrorResponse> errorResponse; + TPGStreamOutput<TPGErrorResponse> errorResponse; errorResponse << 'S' << "FATAL" << '\0' << 'V' << "FATAL" << '\0' @@ -244,7 +309,8 @@ protected: } void HandleMessage(const TPGInitial* message) { - if (message->Protocol == 0x2f16d204) { // 790024708 SSL handshake + uint32_t protocol = message->GetProtocol(); + if (protocol == 0x2f16d204) { // 790024708 SSL handshake if (IsSslSupported) { BufferOutput.Append('S'); if (!FlushOutput()) { @@ -266,19 +332,19 @@ protected: BufferInput.Append('i'); // initial packet pseudo-message return; } - if (message->Protocol == 0x2e16d204) { // 80877102 cancellation message + if (protocol == 0x2e16d204) { // 80877102 cancellation message BLOG_D("cancellation message"); CloseConnection = true; return; } - if (message->Protocol != 0x300) { - BLOG_W("invalid protocol version (" << Hex(message->Protocol) << ")"); + if (protocol != 0x300) { + BLOG_W("invalid protocol version (" << Hex(protocol) << ")"); CloseConnection = true; return; } InitialMessage = MakePGMessageCopy(message); if (IsAuthRequired) { - Send(DatabaseProxy, new TEvPGEvents::TEvAuthRequest(InitialMessage)); + Send(DatabaseProxy, new TEvPGEvents::TEvAuth(InitialMessage), 0, IncomingSequenceNumber++); } else { SendAuthOk(); FinishHandshake(); @@ -287,7 +353,7 @@ protected: void HandleMessage(const TPGPasswordMessage* message) { PasswordWasSupplied = true; - Send(DatabaseProxy, new TEvPGEvents::TEvAuthRequest(InitialMessage, MakePGMessageCopy(message))); + Send(DatabaseProxy, new TEvPGEvents::TEvAuth(InitialMessage, MakePGMessageCopy(message)), 0, IncomingSequenceNumber++); return; } @@ -295,98 +361,265 @@ protected: if (message->GetQuery().empty()) { SendMessage(TPGEmptyQueryResponse()); } else { - Send(DatabaseProxy, new TEvPGEvents::TEvQuery(MakePGMessageCopy(message))); + Send(DatabaseProxy, new TEvPGEvents::TEvQuery(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++); } } + void HandleMessage(const TPGParse* message) { + if (message->GetQueryData().Query.empty()) { + SendMessage(TPGEmptyQueryResponse()); + } else { + Send(DatabaseProxy, new TEvPGEvents::TEvParse(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++); + } + } + + void HandleMessage(const TPGParameterStatus* message) { + Y_UNUSED(message); + } + + void HandleMessage(const TPGBind* message) { + Send(DatabaseProxy, new TEvPGEvents::TEvBind(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++); + } + + // void HandleMessage(const TPGDescribe* message) { + // Y_UNUSED(message); + // // describe current statement + // auto ev = std::make_unique<TEvPGEvents::TEvRowDescription>(); + // ev->Fields.push_back({ + // .Name = "column1", + // }); + // Send(SelfId(), ev.release()); + // } + + // void HandleMessage(const TPGExecute* message) { + // Y_UNUSED(message); + // // execute current statement + // auto ev = std::make_unique<NPG::TEvPGEvents::TEvDataRows>(); + // { + // ev->Rows.emplace_back(); + // auto& row = ev->Rows.back(); + // row.resize(1); + // row[0] = "345"; + // } + // { + // ev->Rows.emplace_back(); + // auto& row = ev->Rows.back(); + // row.resize(1); + // row[0] = "456"; + // } + // Send(SelfId(), ev.release()); + // // + // Send(SelfId(), new NPG::TEvPGEvents::TEvCommandComplete("OK")); + // } + + void HandleMessage(const TPGDescribe* message) { + Send(DatabaseProxy, new TEvPGEvents::TEvDescribe(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++); + } + + void HandleMessage(const TPGExecute* message) { + Send(DatabaseProxy, new TEvPGEvents::TEvExecute(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++); + } + void HandleMessage(const TPGTerminate*) { CloseConnection = true; } - void HandleConnected(TEvPGEvents::TEvAuthResponse::TPtr& ev) { - if (ev->Get()->Error) { - if (PasswordWasSupplied) { - SendAuthError(ev->Get()->Error); - CloseConnection = true; - } else { - SendAuthClearText(); - } - } else { - SendAuthOk(); - FinishHandshake(); + bool FlushAndPoll() { + if (FlushOutput()) { + RequestPoller(); + return true; } + return false; + } - if (!FlushOutput()) { - return; + struct TEventsComparator { + bool operator ()(const TAutoPtr<IEventHandle>& ev1, const TAutoPtr<IEventHandle>& ev2) const { + return ev1->Cookie < ev2->Cookie; } - RequestPoller(); + }; + + template<typename TEv> + bool IsEventExpected(const TAutoPtr<TEventHandle<TEv>>& ev) { + return (ev->Cookie == 0) || (ev->Cookie == OutgoingSequenceNumber); } - void HandleConnected(TEvPGEvents::TEvRowDescription::TPtr& ev) { - TPGStream<TPGRowDescription> rowDescription; - rowDescription << uint16_t(ev->Get()->Fields.size()); // number of fields - for (const auto& field : ev->Get()->Fields) { - rowDescription - << TStringBuf(field.Name) << '\0' - << uint32_t(field.TableId) - << uint16_t(field.ColumnId) - << uint32_t(field.DataType) - << uint16_t(field.DataTypeSize) - << uint32_t(0xffffffff) // type modifier - << uint16_t(0) // format text - ; - } - SendStream(rowDescription); + template<typename TEv> + void PostponeEvent(const TAutoPtr<TEventHandle<TEv>>& ev) { + TAutoPtr<IEventHandle> evb = ev.Release(); + BLOG_D("Postpone event " << evb->Cookie); + auto it = std::upper_bound(PostponedEvents.begin(), PostponedEvents.end(), evb, TEventsComparator()); + PostponedEvents.insert(it, evb); + } - if (!FlushOutput()) { - return; + void ReplayPostponedEvents() { + if (!PostponedEvents.empty()) { + auto event = PostponedEvents.front(); + PostponedEvents.pop_front(); + StateConnected(event, TActivationContext::AsActorContext()); } - RequestPoller(); } - void HandleConnected(TEvPGEvents::TEvDataRows::TPtr& ev) { - for (const auto& row : ev->Get()->Rows) { - TPGStream<TPGDataRow> dataRow; - dataRow << uint16_t(row.size()); // number of fields - for (const auto& item : row) { - dataRow << uint32_t(item.size()) << item; + void HandleConnected(TEvPGEvents::TEvAuthResponse::TPtr& ev) { + if (IsEventExpected(ev)) { + if (ev->Get()->Error) { + if (PasswordWasSupplied) { + SendAuthError(ev->Get()->Error); + CloseConnection = true; + } else { + SendAuthClearText(); + } + } else { + SendAuthOk(); + FinishHandshake(); } - SendStream(dataRow); + ++OutgoingSequenceNumber; + ReplayPostponedEvents(); + FlushAndPoll(); + } else { + PostponeEvent(ev); } + } - if (!FlushOutput()) { - return; + void HandleConnected(TEvPGEvents::TEvQueryResponse::TPtr& ev) { + if (IsEventExpected(ev)) { + if (ev->Get()->ErrorFields.empty()) { + TString tag = "OK"; + { // rowDescription + TPGStreamOutput<TPGRowDescription> rowDescription; + rowDescription << uint16_t(ev->Get()->DataFields.size()); // number of fields + for (const auto& field : ev->Get()->DataFields) { + rowDescription + << TStringBuf(field.Name) << '\0' + << uint32_t(field.TableId) + << uint16_t(field.ColumnId) + << uint32_t(field.DataType) + << uint16_t(field.DataTypeSize) + << uint32_t(0xffffffff) // type modifier + << uint16_t(0) // format text + ; + } + SendStream(rowDescription); + } + { // dataFields + for (const auto& row : ev->Get()->DataRows) { + TPGStreamOutput<TPGDataRow> dataRow; + dataRow << uint16_t(row.size()); // number of fields + for (const auto& item : row) { + dataRow << uint32_t(item.size()) << item; + } + SendStream(dataRow); + } + } + { // commandComplete + TPGStreamOutput<TPGCommandComplete> commandComplete; + commandComplete << tag << '\0'; + SendStream(commandComplete); + } + } else { + // error response + TPGStreamOutput<TPGErrorResponse> errorResponse; + for (const auto& field : ev->Get()->ErrorFields) { + errorResponse << field.first << field.second << '\0'; + } + errorResponse << '\0'; + SendStream(errorResponse); + } + SendReadyForQuery(); + ++OutgoingSequenceNumber; + ReplayPostponedEvents(); + FlushAndPoll(); + } else { + PostponeEvent(ev); } - RequestPoller(); } - void HandleConnected(TEvPGEvents::TEvCommandComplete::TPtr& ev) { - TPGStream<TPGCommandComplete> commandComplete; - commandComplete << ev->Get()->Tag << '\0'; - SendStream(commandComplete); - - SendReadyForQuery(); - - if (!FlushOutput()) { - return; + void HandleConnected(TEvPGEvents::TEvDescribeResponse::TPtr& ev) { + if (IsEventExpected(ev)) { + TString tag = "OK"; + { // rowDescription + TPGStreamOutput<TPGRowDescription> rowDescription; + rowDescription << uint16_t(ev->Get()->DataFields.size()); // number of fields + for (const auto& field : ev->Get()->DataFields) { + rowDescription + << TStringBuf(field.Name) << '\0' + << uint32_t(field.TableId) + << uint16_t(field.ColumnId) + << uint32_t(field.DataType) + << uint16_t(field.DataTypeSize) + << uint32_t(0xffffffff) // type modifier + << uint16_t(0) // format text + ; + } + SendStream(rowDescription); + } + ++OutgoingSequenceNumber; + ReplayPostponedEvents(); + FlushAndPoll(); + } else { + PostponeEvent(ev); } - RequestPoller(); } - void HandleConnected(TEvPGEvents::TEvErrorResponse::TPtr& ev) { - TPGStream<TPGErrorResponse> errorResponse; - for (const auto& field : ev->Get()->ErrorFields) { - errorResponse << field.first << field.second << '\0'; + void HandleConnected(TEvPGEvents::TEvExecuteResponse::TPtr& ev) { + if (IsEventExpected(ev)) { + if (ev->Get()->ErrorFields.empty()) { + TString tag = "OK"; + { // dataFields + for (const auto& row : ev->Get()->DataRows) { + TPGStreamOutput<TPGDataRow> dataRow; + dataRow << uint16_t(row.size()); // number of fields + for (const auto& item : row) { + dataRow << uint32_t(item.size()) << item; + } + SendStream(dataRow); + } + } + { // commandComplete + TPGStreamOutput<TPGCommandComplete> commandComplete; + commandComplete << tag << '\0'; + SendStream(commandComplete); + } + } else { + // error response + TPGStreamOutput<TPGErrorResponse> errorResponse; + for (const auto& field : ev->Get()->ErrorFields) { + errorResponse << field.first << field.second << '\0'; + } + errorResponse << '\0'; + SendStream(errorResponse); + } + SendReadyForQuery(); + ++OutgoingSequenceNumber; + ReplayPostponedEvents(); + FlushAndPoll(); + } else { + PostponeEvent(ev); } - errorResponse << '\0'; - SendStream(errorResponse); + } - SendReadyForQuery(); + void HandleConnected(TEvPGEvents::TEvParseResponse::TPtr& ev) { + if (IsEventExpected(ev)) { + TPGStreamOutput<TPGParseComplete> parseComplete; + SendStream(parseComplete); + SendReadyForQuery(); + ++OutgoingSequenceNumber; + ReplayPostponedEvents(); + FlushAndPoll(); + } else { + PostponeEvent(ev); + } + } - if (!FlushOutput()) { - return; + void HandleConnected(TEvPGEvents::TEvBindResponse::TPtr& ev) { + if (IsEventExpected(ev)) { + TPGStreamOutput<TPGBindComplete> bindComplete; + SendStream(bindComplete); + ++OutgoingSequenceNumber; + ReplayPostponedEvents(); + FlushAndPoll(); + } else { + PostponeEvent(ev); } - RequestPoller(); } bool HasInputMessage() const { @@ -422,7 +655,7 @@ protected: BufferInput.Advance(res); while (HasInputMessage()) { const TPGMessage* message = GetInputMessage(); - PrintMessage("-> ", *message); + PrintMessage(EDirection::Incoming, *message); switch (static_cast<EMessageCode>(message->Message)) { case EMessageCode::Initial: HandleMessage(static_cast<const TPGInitial*>(message)); @@ -436,8 +669,23 @@ protected: case EMessageCode::PasswordMessage: HandleMessage(static_cast<const TPGPasswordMessage*>(message)); break; + case EMessageCode::Parse: + HandleMessage(static_cast<const TPGParse*>(message)); + break; + case EMessageCode::ParameterStatus: + HandleMessage(static_cast<const TPGParameterStatus*>(message)); + break; + case EMessageCode::Bind: + HandleMessage(static_cast<const TPGBind*>(message)); + break; + case EMessageCode::Describe: + HandleMessage(static_cast<const TPGDescribe*>(message)); + break; + case EMessageCode::Execute: + HandleMessage(static_cast<const TPGExecute*>(message)); + break; default: - BLOG_W("invalid message (" << message->Message << ")"); + BLOG_ERROR("invalid message (" << message->Message << ")"); CloseConnection = true; break; } @@ -536,10 +784,11 @@ protected: hFunc(TEvPollerRegisterResult, HandleConnected); hFunc(TEvPGEvents::TEvAuthResponse, HandleConnected); - hFunc(TEvPGEvents::TEvRowDescription, HandleConnected); - hFunc(TEvPGEvents::TEvDataRows, HandleConnected); - hFunc(TEvPGEvents::TEvCommandComplete, HandleConnected); - hFunc(TEvPGEvents::TEvErrorResponse, HandleConnected); + hFunc(TEvPGEvents::TEvQueryResponse, HandleConnected); + hFunc(TEvPGEvents::TEvParseResponse, HandleConnected); + hFunc(TEvPGEvents::TEvBindResponse, HandleConnected); + hFunc(TEvPGEvents::TEvDescribeResponse, HandleConnected); + hFunc(TEvPGEvents::TEvExecuteResponse, HandleConnected); } } }; diff --git a/ydb/core/pgproxy/pg_proxy_events.h b/ydb/core/pgproxy/pg_proxy_events.h index a1464d503a1..7fdedffb944 100644 --- a/ydb/core/pgproxy/pg_proxy_events.h +++ b/ydb/core/pgproxy/pg_proxy_events.h @@ -10,18 +10,35 @@ struct TEvPGEvents { enum EEv { EvConnectionOpened = EventSpaceBegin(NActors::TEvents::ES_PGWIRE), EvConnectionClosed, - EvAuthRequest, + EvAuth, EvAuthResponse, EvQuery, - EvRowDescription, - EvDataRows, - EvCommandComplete, - EvErrorResponse, + EvQueryResponse, + EvParse, + EvParseResponse, + EvBind, + EvBindResponse, + EvDescribe, + EvDescribeResponse, + EvExecute, + EvExecuteResponse, EvEnd }; static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PGWIRE), "ES_PGWIRE event space is too small."); + struct TRowDescriptionField { + TString Name; + uint32_t TableId = 0; + uint16_t ColumnId = 0; + uint32_t DataType; + uint16_t DataTypeSize; + //uint32_t DataTypeModifier; + //uint16_t Format; + }; + + using TDataRow = std::vector<TString>; + struct TEvConnectionOpened : NActors::TEventLocal<TEvConnectionOpened, EvConnectionOpened> { std::shared_ptr<TPGInitial> Message; @@ -33,15 +50,15 @@ struct TEvPGEvents { struct TEvConnectionClosed : NActors::TEventLocal<TEvConnectionClosed, EvConnectionClosed> { }; - struct TEvAuthRequest : NActors::TEventLocal<TEvAuthRequest, EvAuthRequest> { + struct TEvAuth : NActors::TEventLocal<TEvAuth, EvAuth> { std::shared_ptr<TPGInitial> InitialMessage; std::unique_ptr<TPGPasswordMessage> PasswordMessage; - TEvAuthRequest(std::shared_ptr<TPGInitial> initialMessage) + TEvAuth(std::shared_ptr<TPGInitial> initialMessage) : InitialMessage(std::move(initialMessage)) {} - TEvAuthRequest(std::shared_ptr<TPGInitial> initialMessage, std::unique_ptr<TPGPasswordMessage> message) + TEvAuth(std::shared_ptr<TPGInitial> initialMessage, std::unique_ptr<TPGPasswordMessage> message) : InitialMessage(std::move(initialMessage)) , PasswordMessage(std::move(message)) {} @@ -72,30 +89,10 @@ struct TEvPGEvents { {} }; - struct TEvRowDescription : NActors::TEventLocal<TEvRowDescription, EvRowDescription> { - struct TField { - TString Name; - uint32_t TableId = 0; - uint16_t ColumnId = 0; - uint32_t DataType; - uint16_t DataTypeSize; - //uint32_t DataTypeModifier; - //uint16_t Format; - }; - std::vector<TField> Fields; - }; - - struct TEvDataRows : NActors::TEventLocal<TEvDataRows, EvDataRows> { - using TDataRow = std::vector<TString>; - std::vector<TDataRow> Rows; - }; - - struct TEvCommandComplete : NActors::TEventLocal<TEvCommandComplete, EvCommandComplete> { - TString Tag; - - TEvCommandComplete(const TString& tag) - : Tag(tag) - {} + struct TEvQueryResponse : NActors::TEventLocal<TEvQueryResponse, EvQueryResponse> { + std::vector<TRowDescriptionField> DataFields; + std::vector<TDataRow> DataRows; + std::vector<std::pair<char, TString>> ErrorFields; }; /* @@ -127,7 +124,69 @@ struct TEvPGEvents { R = "scanner_yyerror" */ - struct TEvErrorResponse : NActors::TEventLocal<TEvErrorResponse, EvErrorResponse> { + struct TEvParseResponse : NActors::TEventLocal<TEvParseResponse, EvParseResponse> { + std::unique_ptr<TPGParse> OriginalMessage; + + TEvParseResponse(std::unique_ptr<TPGParse> originalMessage) + : OriginalMessage(std::move(originalMessage)) + {} + }; + + struct TEvParse : NActors::TEventLocal<TEvParse, EvParse> { + std::unique_ptr<TPGParse> Message; + + TEvParse(std::unique_ptr<TPGParse> message) + : Message(std::move(message)) + {} + + std::unique_ptr<TEvParseResponse> Reply() { + return std::make_unique<TEvParseResponse>(std::move(Message)); + } + }; + + struct TEvBindResponse : NActors::TEventLocal<TEvBindResponse, EvBindResponse> { + std::unique_ptr<TPGBind> OriginalMessage; + + TEvBindResponse(std::unique_ptr<TPGBind> originalMessage) + : OriginalMessage(std::move(originalMessage)) + {} + }; + + struct TEvBind : NActors::TEventLocal<TEvBind, EvBind> { + std::unique_ptr<TPGBind> Message; + + TEvBind(std::unique_ptr<TPGBind> message) + : Message(std::move(message)) + {} + + std::unique_ptr<TEvBindResponse> Reply() { + return std::make_unique<TEvBindResponse>(std::move(Message)); + } + }; + + struct TEvDescribe : NActors::TEventLocal<TEvDescribe, EvDescribe> { + std::unique_ptr<TPGDescribe> Message; + + TEvDescribe(std::unique_ptr<TPGDescribe> message) + : Message(std::move(message)) + {} + }; + + struct TEvDescribeResponse : NActors::TEventLocal<TEvDescribeResponse, EvDescribeResponse> { + std::vector<TRowDescriptionField> DataFields; + std::vector<std::pair<char, TString>> ErrorFields; + }; + + struct TEvExecute : NActors::TEventLocal<TEvExecute, EvExecute> { + std::unique_ptr<TPGExecute> Message; + + TEvExecute(std::unique_ptr<TPGExecute> message) + : Message(std::move(message)) + {} + }; + + struct TEvExecuteResponse : NActors::TEventLocal<TEvExecuteResponse, EvExecuteResponse> { + std::vector<TDataRow> DataRows; std::vector<std::pair<char, TString>> ErrorFields; }; }; diff --git a/ydb/core/pgproxy/pg_proxy_types.cpp b/ydb/core/pgproxy/pg_proxy_types.cpp new file mode 100644 index 00000000000..acb836d83db --- /dev/null +++ b/ydb/core/pgproxy/pg_proxy_types.cpp @@ -0,0 +1,207 @@ +#include "pg_proxy_types.h" +#include "pg_stream.h" + +namespace NPG { + +TString TPGInitial::Dump() const { + TPGStreamInput stream(*this); + TStringBuilder text; + uint32_t protocol = 0; + stream >> protocol; + protocol = htonl(protocol); + if (protocol == 773247492) { // 80877102 cancellation message + uint32_t pid = 0; + uint32_t key = 0; + stream >> pid >> key; + text << "cancellation PID " << pid << " KEY " << key; + } else if (protocol == 790024708) { // 790024708 SSL handshake + text << "SSL handshake"; + } else { + text << "protocol(" << Hex(protocol) << ") "; + while (!stream.Empty()) { + TStringBuf key; + TStringBuf value; + stream >> key >> value; + if (key.empty()) { + break; + } + text << key << "=" << value << " "; + } + } + return text; +} + +uint32_t TPGInitial::GetProtocol() const { + TPGStreamInput stream(*this); + uint32_t protocol = 0; + stream >> protocol; + protocol = htonl(protocol); + return protocol; +} + +std::unordered_map<TString, TString> TPGInitial::GetClientParams() const { + std::unordered_map<TString, TString> params; + TPGStreamInput stream(*this); + TStringBuilder text; + uint32_t protocol = 0; + stream >> protocol; + while (!stream.Empty()) { + TStringBuf key; + TStringBuf value; + stream >> key >> value; + if (key.empty()) { + break; + } + params[TString(key)] = value; + } + return params; +} + +TString TPGErrorResponse::Dump() const { + TPGStreamInput stream(*this); + TStringBuilder text; + + while (!stream.Empty()) { + char code; + TString message; + stream >> code; + if (code == 0) { + break; + } + stream >> message; + if (!text.empty()) { + text << ' '; + } + text << code << "=\"" << message << "\""; + } + return text; +} + +TString TPGParse::Dump() const { + TPGStreamInput stream(*this); + TStringBuf name; + stream >> name; + return TStringBuilder() << "Name:" << name; +} + +TPGParse::TQueryData TPGParse::GetQueryData() const { + TQueryData queryData; + TPGStreamInput stream(*this); + stream >> queryData.Name; + stream >> queryData.Query; + uint16_t numberOfParameterTypes = 0; + stream >> numberOfParameterTypes; + queryData.ParametersTypes.reserve(numberOfParameterTypes); + for (uint16_t n = 0; n < numberOfParameterTypes; ++n) { + uint32_t param = 0; + stream >> param; + queryData.ParametersTypes.emplace_back(param); + } + return queryData; +} + +TPGBind::TBindData TPGBind::GetBindData() const { + TBindData bindData; + TPGStreamInput stream(*this); + stream >> bindData.PortalName; + stream >> bindData.StatementName; + uint16_t numberOfParameterFormats = 0; + stream >> numberOfParameterFormats; + bindData.ParametersFormat.reserve(numberOfParameterFormats); + for (uint16_t n = 0; n < numberOfParameterFormats; ++n) { + uint16_t format = 0; + stream >> format; + bindData.ParametersFormat.emplace_back(format); + } + uint16_t numberOfParameterValues = 0; + stream >> numberOfParameterValues; + bindData.ParametersValue.reserve(numberOfParameterValues); + for (uint16_t n = 0; n < numberOfParameterValues; ++n) { + uint32_t size = 0; + stream >> size; + std::vector<uint8_t> value; + stream.Read(value, size); + bindData.ParametersValue.emplace_back(std::move(value)); + } + uint16_t numberOfResultFormats = 0; + stream >> numberOfResultFormats; + bindData.ResultsFormat.reserve(numberOfResultFormats); + for (uint16_t n = 0; n < numberOfResultFormats; ++n) { + uint16_t format = 0; + stream >> format; + bindData.ResultsFormat.emplace_back(format); + } + return bindData; +} + +TString TPGBind::Dump() const { + TStringBuilder text; + TPGStreamInput stream(*this); + TStringBuf portalName; + TStringBuf statementName; + stream >> portalName; + stream >> statementName; + if (portalName) { + text << "Portal: " << portalName; + } else if (statementName) { + text << "Statement: " << statementName; + } + return text; +} + +TString TPGDataRow::Dump() const { + TPGStreamInput stream(*this); + uint16_t numberOfColumns = 0; + stream >> numberOfColumns; + return TStringBuilder() << "Columns: " << numberOfColumns; +} + +TPGDescribe::TDescribeData TPGDescribe::GetDescribeData() const { + TPGStreamInput stream(*this); + TDescribeData data; + char describeType = 0; + stream >> describeType; + data.Type = static_cast<TDescribeData::EDescribeType>(describeType); + stream >> data.Name; + return data; +} + +TString TPGDescribe::Dump() const { + TPGStreamInput stream(*this); + TStringBuilder text; + char describeType = 0; + stream >> describeType; + text << "Type:" << describeType; + TStringBuf name; + stream >> name; + text << " Name:" << name; + return text; +} + +TPGExecute::TExecuteData TPGExecute::GetExecuteData() const { + TPGStreamInput stream(*this); + TExecuteData data; + stream >> data.PortalName >> data.MaxRows; + return data; +} + +TString TPGExecute::Dump() const { + TPGStreamInput stream(*this); + TStringBuilder text; + TStringBuf name; + uint32_t maxRows = 0; + stream >> name >> maxRows; + if (name) { + text << "Name: " << name; + } + if (maxRows) { + if (!text.empty()) { + text << ' '; + } + text << "MaxRows: " << maxRows; + } + return text; +} + + +}
\ No newline at end of file diff --git a/ydb/core/pgproxy/pg_proxy_types.h b/ydb/core/pgproxy/pg_proxy_types.h index caf7bacf081..00ac72724a7 100644 --- a/ydb/core/pgproxy/pg_proxy_types.h +++ b/ydb/core/pgproxy/pg_proxy_types.h @@ -1,7 +1,10 @@ #pragma once #include <cstdint> +#include <unordered_map> +#include <vector> #include <util/stream/format.h> +#include <util/string/builder.h> #include <arpa/inet.h> namespace NPG { @@ -30,78 +33,20 @@ struct TPGMessage { void operator delete(void* p) { ::operator delete(p); } + + bool Empty() const { + return GetDataSize() == 0; + } }; struct TPGInitial : TPGMessage { // it's not true, because we don't receive message code from a network, but imply it on the start - uint32_t Protocol; - TPGInitial() { Message = 'i'; // fake code } - const char* GetData() const { - return reinterpret_cast<const char*>(this) + sizeof(*this); - } - - TString Dump() const { - TStringBuilder text; - if (Protocol == 773247492) { // 80877102 cancellation message - const uint32_t* data = reinterpret_cast<const uint32_t*>(GetData()); - if (GetDataSize() >= 8) { - uint32_t pid = data[0]; - uint32_t key = data[1]; - text << "cancellation PID " << pid << " KEY " << key; - } - } else if (Protocol == 790024708) { // 790024708 SSL handshake - text << "SSL handshake"; - } else { - text << "protocol(" << Hex(Protocol) << ") "; - const char* begin = GetData(); - const char* end = GetData() + GetDataSize(); - for (const char* ptr = begin; ptr < end;) { - TStringBuf key; - TStringBuf value; - size_t size = strnlen(ptr, end - ptr); - key = TStringBuf(ptr, size); - if (key.empty()) { - break; - } - ptr += size + 1; - if (ptr >= end) { - break; - } - size = strnlen(ptr, end - ptr); - value = TStringBuf(ptr, size); - ptr += size + 1; - text << key << "=" << value << " "; - } - } - return text; - } - - std::unordered_map<TString, TString> GetClientParams() const { - std::unordered_map<TString, TString> params; - const char* begin = GetData(); - const char* end = GetData() + GetDataSize(); - for (const char* ptr = begin; ptr < end;) { - TString key; - TString value; - size_t size = strnlen(ptr, end - ptr); - key = TStringBuf(ptr, size); - if (key.empty()) { - break; - } - ptr += size + 1; - if (ptr >= end) { - break; - } - size = strnlen(ptr, end - ptr); - value = TStringBuf(ptr, size); - ptr += size + 1; - params[key] = value; - } - return params; - } + TString Dump() const; + std::unordered_map<TString, TString> GetClientParams() const; + uint32_t GetProtocol() const; }; struct TPGAuth : TPGMessage { @@ -166,7 +111,11 @@ struct TPGParameterStatus : TPGMessage { } TString Dump() const { - return TStringBuilder() << GetName() << "=" << GetValue(); + if (!Empty()) { + return TStringBuilder() << GetName() << "=" << GetValue(); + } else { + return {}; + } } }; @@ -221,6 +170,8 @@ struct TPGErrorResponse : TPGMessage { TPGErrorResponse() { Message = 'E'; } + + TString Dump() const; }; struct TPGTerminate : TPGMessage { @@ -239,6 +190,8 @@ struct TPGDataRow : TPGMessage { TPGDataRow() { Message = 'D'; } + + TString Dump() const; }; struct TPGEmptyQueryResponse : TPGMessage { @@ -247,6 +200,81 @@ struct TPGEmptyQueryResponse : TPGMessage { } }; +struct TPGParse : TPGMessage { + TPGParse() { + Message = 'P'; + } + + struct TQueryData { + TString Name; + TString Query; + std::vector<int32_t> ParametersTypes; // types + }; + + TQueryData GetQueryData() const; + TString Dump() const; +}; + +struct TPGParseComplete : TPGMessage { + TPGParseComplete() { + Message = '1'; + } +}; + +struct TPGBind : TPGMessage { + TPGBind() { + Message = 'B'; + } + + struct TBindData { + TString PortalName; + TString StatementName; + std::vector<int16_t> ParametersFormat; // format codes 0=text, 1=binary + std::vector<std::vector<uint8_t>> ParametersValue; // parameters content + std::vector<int16_t> ResultsFormat; // result format codes 0=text, 1=binary + }; + + TBindData GetBindData() const; + TString Dump() const; +}; + +struct TPGBindComplete : TPGMessage { + TPGBindComplete() { + Message = '2'; + } +}; + +struct TPGDescribe : TPGMessage { + TPGDescribe() { + Message = 'D'; + } + + struct TDescribeData { + enum class EDescribeType : char { + Portal = 'P', + Statement = 'S', + }; + EDescribeType Type; + TString Name; + }; + + TDescribeData GetDescribeData() const; + TString Dump() const; +}; + +struct TPGExecute : TPGMessage { + TPGExecute() { + Message = 'E'; + } + + struct TExecuteData { + TString PortalName; + uint32_t MaxRows; + }; + + TExecuteData GetExecuteData() const; + TString Dump() const; +}; #pragma pack(pop) template<typename TPGMessageType> diff --git a/ydb/core/pgproxy/pg_proxy_ut.cpp b/ydb/core/pgproxy/pg_proxy_ut.cpp index 472806b2fb3..2e02e98deee 100644 --- a/ydb/core/pgproxy/pg_proxy_ut.cpp +++ b/ydb/core/pgproxy/pg_proxy_ut.cpp @@ -64,7 +64,7 @@ Y_UNIT_TEST_SUITE(TPGTest) { } TSocket s(TNetworkAddress("::", port)); Send(s, "0000001300030000" "7573657200757365720000"); // user=user - NPG::TEvPGEvents::TEvAuthRequest* authRequest = actorSystem.GrabEdgeEvent<NPG::TEvPGEvents::TEvAuthRequest>(handle); + NPG::TEvPGEvents::TEvAuth* authRequest = actorSystem.GrabEdgeEvent<NPG::TEvPGEvents::TEvAuth>(handle); UNIT_ASSERT(authRequest); UNIT_ASSERT_VALUES_EQUAL(authRequest->InitialMessage->GetClientParams()["user"], "user"); actorSystem.Send(new NActors::IEventHandle(handle->Sender, database, new NPG::TEvPGEvents::TEvAuthResponse())); diff --git a/ydb/core/pgproxy/pg_stream.h b/ydb/core/pgproxy/pg_stream.h index bb7ea938287..c1eb55c0287 100644 --- a/ydb/core/pgproxy/pg_stream.h +++ b/ydb/core/pgproxy/pg_stream.h @@ -1,15 +1,17 @@ #pragma once #include "pg_proxy_types.h" +#include <util/generic/buffer.h> +#include <util/generic/strbuf.h> namespace NPG { template<typename TMessage> -class TPGStream : public TBuffer { +class TPGStreamOutput : public TBuffer { public: using TBase = TBuffer; - TPGStream() { + TPGStreamOutput() { TMessage header; TBase::Append(reinterpret_cast<const char*>(&header), sizeof(header)); } @@ -19,27 +21,90 @@ public: header->Length = htonl(TBase::Size() - sizeof(char)); } - TPGStream& operator <<(uint16_t v) { + TPGStreamOutput& operator <<(uint16_t v) { v = htons(v); TBase::Append(reinterpret_cast<const char*>(&v), sizeof(v)); return *this; } - TPGStream& operator <<(uint32_t v) { + TPGStreamOutput& operator <<(uint32_t v) { v = htonl(v); TBase::Append(reinterpret_cast<const char*>(&v), sizeof(v)); return *this; } - TPGStream& operator <<(char v) { + TPGStreamOutput& operator <<(char v) { TBase::Append(v); return *this; } - TPGStream& operator <<(TStringBuf s) { + TPGStreamOutput& operator <<(TStringBuf s) { TBase::Append(s.data(), s.size()); return *this; } }; +class TPGStreamInput { +public: + TPGStreamInput(const TPGMessage& message) + : Buffer(message.GetData(), message.GetDataSize()) + { + } + + TPGStreamInput& operator >>(TString& s) { + s = Buffer.NextTok('\0'); + return *this; + } + + TPGStreamInput& operator >>(TStringBuf& s) { + s = Buffer.NextTok('\0'); + return *this; + } + + TPGStreamInput& operator >>(char& v) { + if (Buffer.size() >= sizeof(v)) { + v = *reinterpret_cast<const char*>(Buffer.data()); + Buffer.Skip(sizeof(v)); + } else { + v = {}; + } + return *this; + } + + TPGStreamInput& operator >>(uint16_t& v) { + if (Buffer.size() >= sizeof(v)) { + v = ntohs(*reinterpret_cast<const uint16_t*>(Buffer.data())); + Buffer.Skip(sizeof(v)); + } else { + v = {}; + } + return *this; + } + + TPGStreamInput& operator >>(uint32_t& v) { + if (Buffer.size() >= sizeof(v)) { + v = ntohl(*reinterpret_cast<const uint32_t*>(Buffer.data())); + Buffer.Skip(sizeof(v)); + } else { + v = {}; + } + return *this; + } + + TPGStreamInput& Read(std::vector<uint8_t>& data, uint32_t size) { + size = std::min<uint32_t>(size, Buffer.size()); + data.resize(size); + memcpy(data.data(), Buffer.data(), size); + Buffer.Skip(size); + return *this; + } + + bool Empty() const { + return Buffer.Empty(); + } + +protected: + TStringBuf Buffer; +}; + } |