aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorxenoxeno <xeno@ydb.tech>2022-12-30 18:37:20 +0300
committerxenoxeno <xeno@ydb.tech>2022-12-30 18:37:20 +0300
commit20070209a718431338a2eed7e45673d99c9f3738 (patch)
tree851516f961d2252654973bc02f3394b2605a7351
parent08fb420704acef4bf445c80b67bc9f9fc648583f (diff)
downloadydb-20070209a718431338a2eed7e45673d99c9f3738.tar.gz
extended query protocol
-rw-r--r--ydb/core/base/ticket_parser.h2
-rw-r--r--ydb/core/pgproxy/CMakeLists.darwin.txt1
-rw-r--r--ydb/core/pgproxy/CMakeLists.linux-aarch64.txt1
-rw-r--r--ydb/core/pgproxy/CMakeLists.linux.txt1
-rw-r--r--ydb/core/pgproxy/pg_connection.cpp427
-rw-r--r--ydb/core/pgproxy/pg_proxy_events.h125
-rw-r--r--ydb/core/pgproxy/pg_proxy_types.cpp207
-rw-r--r--ydb/core/pgproxy/pg_proxy_types.h160
-rw-r--r--ydb/core/pgproxy/pg_proxy_ut.cpp2
-rw-r--r--ydb/core/pgproxy/pg_stream.h77
10 files changed, 807 insertions, 196 deletions
diff --git a/ydb/core/base/ticket_parser.h b/ydb/core/base/ticket_parser.h
index d1b8454e6b9..3e5c14d92ee 100644
--- a/ydb/core/base/ticket_parser.h
+++ b/ydb/core/base/ticket_parser.h
@@ -1,4 +1,4 @@
-#pragma once
+ #pragma once
#include <library/cpp/containers/stack_vector/stack_vec.h>
#include <ydb/core/base/defs.h>
#include <ydb/core/base/events.h>
diff --git a/ydb/core/pgproxy/CMakeLists.darwin.txt b/ydb/core/pgproxy/CMakeLists.darwin.txt
index c2f7d21243d..a6984d0aa65 100644
--- a/ydb/core/pgproxy/CMakeLists.darwin.txt
+++ b/ydb/core/pgproxy/CMakeLists.darwin.txt
@@ -19,5 +19,6 @@ target_link_libraries(ydb-core-pgproxy PUBLIC
target_sources(ydb-core-pgproxy PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_connection.cpp
${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_listener.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy_types.cpp
${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy.cpp
)
diff --git a/ydb/core/pgproxy/CMakeLists.linux-aarch64.txt b/ydb/core/pgproxy/CMakeLists.linux-aarch64.txt
index 783f7a53b77..0db70f26363 100644
--- a/ydb/core/pgproxy/CMakeLists.linux-aarch64.txt
+++ b/ydb/core/pgproxy/CMakeLists.linux-aarch64.txt
@@ -20,5 +20,6 @@ target_link_libraries(ydb-core-pgproxy PUBLIC
target_sources(ydb-core-pgproxy PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_connection.cpp
${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_listener.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy_types.cpp
${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy.cpp
)
diff --git a/ydb/core/pgproxy/CMakeLists.linux.txt b/ydb/core/pgproxy/CMakeLists.linux.txt
index 783f7a53b77..0db70f26363 100644
--- a/ydb/core/pgproxy/CMakeLists.linux.txt
+++ b/ydb/core/pgproxy/CMakeLists.linux.txt
@@ -20,5 +20,6 @@ target_link_libraries(ydb-core-pgproxy PUBLIC
target_sources(ydb-core-pgproxy PRIVATE
${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_connection.cpp
${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_listener.cpp
+ ${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy_types.cpp
${CMAKE_SOURCE_DIR}/ydb/core/pgproxy/pg_proxy.cpp
)
diff --git a/ydb/core/pgproxy/pg_connection.cpp b/ydb/core/pgproxy/pg_connection.cpp
index 135ccd92240..dbd2015ebcf 100644
--- a/ydb/core/pgproxy/pg_connection.cpp
+++ b/ydb/core/pgproxy/pg_connection.cpp
@@ -12,11 +12,17 @@ using namespace NActors;
class TPGConnection : public TActorBootstrapped<TPGConnection>, public TNetworkConfig {
public:
+ // incoming messages only
enum class EMessageCode : char {
Initial = 'i',
Query = 'Q',
Terminate = 'X',
PasswordMessage = 'p',
+ Parse = 'P',
+ ParameterStatus = 'S',
+ Bind = 'B',
+ Describe = 'D',
+ Execute = 'E',
};
using TBase = TActorBootstrapped<TPGConnection>;
@@ -43,6 +49,9 @@ public:
TSocketBuffer BufferOutput;
TActorId DatabaseProxy;
std::shared_ptr<TPGInitial> InitialMessage;
+ ui64 IncomingSequenceNumber = 1;
+ ui64 OutgoingSequenceNumber = 1;
+ std::deque<TAutoPtr<IEventHandle>> PostponedEvents;
TPGConnection(TIntrusivePtr<TSocketDescriptor> socket, TNetworkConfig::TSocketAddressType address, const TActorId& databaseProxy)
: Socket(std::move(socket))
@@ -117,7 +126,12 @@ protected:
OnAccept();
}
- TStringBuf GetMessageName(const TPGMessage& message) const {
+ enum class EDirection {
+ Incoming,
+ Outgoing,
+ };
+
+ TStringBuf GetMessageName(EDirection direction, const TPGMessage& message) const {
static const std::unordered_map<char, TStringBuf> messageName = {
{'i', "Initial"},
{'R', "Auth"},
@@ -126,20 +140,41 @@ protected:
{'C', "CommandComplete"},
{'X', "Terminate"},
{'T', "RowDescription"},
- {'D', "DataRow"},
{'S', "ParameterStatus"},
- {'E', "ErrorResponse"},
{'I', "EmptyQueryResponse"},
{'p', "PasswordMessage"},
+ {'P', "Parse"},
+ {'1', "ParseComplete"},
+ {'B', "Bind"},
+ {'2', "BindComplete"},
};
auto itMessageName = messageName.find(message.Message);
if (itMessageName != messageName.end()) {
return itMessageName->second;
}
+ static const std::unordered_map<char, TStringBuf> incomingMessageName = {
+ {'E', "Execute"},
+ {'D', "Describe"},
+ };
+ static const std::unordered_map<char, TStringBuf> outgoingMessageName = {
+ {'E', "ErrorResponse"},
+ {'D', "DataRow"},
+ };
+ switch (direction) {
+ case EDirection::Incoming:
+ itMessageName = incomingMessageName.find(message.Message);
+ break;
+ case EDirection::Outgoing:
+ itMessageName = outgoingMessageName.find(message.Message);
+ break;
+ }
+ if (itMessageName != messageName.end()) {
+ return itMessageName->second;
+ }
return {};
}
- TString GetMessageDump(const TPGMessage& message) const {
+ TString GetMessageDump(EDirection direction, const TPGMessage& message) const {
switch (message.Message) {
case 'i':
return ((const TPGInitial&)message).Dump();
@@ -153,36 +188,66 @@ protected:
return ((const TPGCommandComplete&)message).Dump();
case 'R':
return ((const TPGAuth&)message).Dump();
+ case 'D': {
+ switch (direction) {
+ case EDirection::Incoming:
+ return ((const TPGDescribe&)message).Dump();
+ case EDirection::Outgoing:
+ return ((const TPGDataRow&)message).Dump();
+ }
+ }
+ case 'E': {
+ switch (direction) {
+ case EDirection::Incoming:
+ return ((const TPGExecute&)message).Dump();
+ case EDirection::Outgoing:
+ return ((const TPGErrorResponse&)message).Dump();
+ }
+ }
+ case 'B':
+ return ((const TPGBind&)message).Dump();
+ case 'P':
+ return ((const TPGParse&)message).Dump();
+
}
return {};
}
- void PrintMessage(const TStringBuf& prefix, const TPGMessage& message) {
- BLOG_D(prefix << "'" << message.Message << "' \"" << GetMessageName(message) << "\" Size(" << ntohl(message.Length) << ") " << GetMessageDump(message));
+ void PrintMessage(EDirection direction, const TPGMessage& message) {
+ TStringBuilder prefix;
+ switch (direction) {
+ case EDirection::Incoming:
+ prefix << "-> [" << IncomingSequenceNumber << "] ";
+ break;
+ case EDirection::Outgoing:
+ prefix << "<- [" << OutgoingSequenceNumber << "] ";
+ break;
+ }
+ BLOG_D(prefix << "'" << message.Message << "' \"" << GetMessageName(direction, message) << "\" Size(" << message.GetDataSize() << ") " << GetMessageDump(direction, message));
}
template<typename TMessage>
void SendMessage(const TMessage& message) {
- PrintMessage("<- ", message);
+ PrintMessage(EDirection::Outgoing, message);
BufferOutput.Append(reinterpret_cast<const char*>(&message), sizeof(message));
}
template<typename TMessage>
- void SendStream(TPGStream<TMessage>& message) {
+ void SendStream(TPGStreamOutput<TMessage>& message) {
message.UpdateLength();
const TPGMessage& header = *reinterpret_cast<const TPGMessage*>(message.Data());
- PrintMessage("<- ", header);
+ PrintMessage(EDirection::Outgoing, header);
BufferOutput.Append(message.Data(), message.Size());
}
void SendAuthOk() {
- TPGStream<TPGAuth> authOk;
+ TPGStreamOutput<TPGAuth> authOk;
authOk << uint32_t(TPGAuth::EAuthCode::OK);
SendStream(authOk);
}
void SendAuthClearText() {
- TPGStream<TPGAuth> authClearText;
+ TPGStreamOutput<TPGAuth> authClearText;
authClearText << uint32_t(TPGAuth::EAuthCode::ClearText);
SendStream(authClearText);
}
@@ -211,19 +276,19 @@ protected:
0x0190: 0045 7463 2f55 5443 004b 0000 000c 0004 .Etc/UTC.K......
*/
void SendParameterStatus(TStringBuf name, TStringBuf value) {
- TPGStream<TPGParameterStatus> param;
+ TPGStreamOutput<TPGParameterStatus> param;
param << name << '\0' << value << '\0';
SendStream(param);
}
void SendReadyForQuery() {
- TPGStream<TPGReadyForQuery> readyForQuery;
+ TPGStreamOutput<TPGReadyForQuery> readyForQuery;
readyForQuery << 'I';
SendStream(readyForQuery);
}
void SendAuthError(const TString& error) {
- TPGStream<TPGErrorResponse> errorResponse;
+ TPGStreamOutput<TPGErrorResponse> errorResponse;
errorResponse
<< 'S' << "FATAL" << '\0'
<< 'V' << "FATAL" << '\0'
@@ -244,7 +309,8 @@ protected:
}
void HandleMessage(const TPGInitial* message) {
- if (message->Protocol == 0x2f16d204) { // 790024708 SSL handshake
+ uint32_t protocol = message->GetProtocol();
+ if (protocol == 0x2f16d204) { // 790024708 SSL handshake
if (IsSslSupported) {
BufferOutput.Append('S');
if (!FlushOutput()) {
@@ -266,19 +332,19 @@ protected:
BufferInput.Append('i'); // initial packet pseudo-message
return;
}
- if (message->Protocol == 0x2e16d204) { // 80877102 cancellation message
+ if (protocol == 0x2e16d204) { // 80877102 cancellation message
BLOG_D("cancellation message");
CloseConnection = true;
return;
}
- if (message->Protocol != 0x300) {
- BLOG_W("invalid protocol version (" << Hex(message->Protocol) << ")");
+ if (protocol != 0x300) {
+ BLOG_W("invalid protocol version (" << Hex(protocol) << ")");
CloseConnection = true;
return;
}
InitialMessage = MakePGMessageCopy(message);
if (IsAuthRequired) {
- Send(DatabaseProxy, new TEvPGEvents::TEvAuthRequest(InitialMessage));
+ Send(DatabaseProxy, new TEvPGEvents::TEvAuth(InitialMessage), 0, IncomingSequenceNumber++);
} else {
SendAuthOk();
FinishHandshake();
@@ -287,7 +353,7 @@ protected:
void HandleMessage(const TPGPasswordMessage* message) {
PasswordWasSupplied = true;
- Send(DatabaseProxy, new TEvPGEvents::TEvAuthRequest(InitialMessage, MakePGMessageCopy(message)));
+ Send(DatabaseProxy, new TEvPGEvents::TEvAuth(InitialMessage, MakePGMessageCopy(message)), 0, IncomingSequenceNumber++);
return;
}
@@ -295,98 +361,265 @@ protected:
if (message->GetQuery().empty()) {
SendMessage(TPGEmptyQueryResponse());
} else {
- Send(DatabaseProxy, new TEvPGEvents::TEvQuery(MakePGMessageCopy(message)));
+ Send(DatabaseProxy, new TEvPGEvents::TEvQuery(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++);
}
}
+ void HandleMessage(const TPGParse* message) {
+ if (message->GetQueryData().Query.empty()) {
+ SendMessage(TPGEmptyQueryResponse());
+ } else {
+ Send(DatabaseProxy, new TEvPGEvents::TEvParse(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++);
+ }
+ }
+
+ void HandleMessage(const TPGParameterStatus* message) {
+ Y_UNUSED(message);
+ }
+
+ void HandleMessage(const TPGBind* message) {
+ Send(DatabaseProxy, new TEvPGEvents::TEvBind(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++);
+ }
+
+ // void HandleMessage(const TPGDescribe* message) {
+ // Y_UNUSED(message);
+ // // describe current statement
+ // auto ev = std::make_unique<TEvPGEvents::TEvRowDescription>();
+ // ev->Fields.push_back({
+ // .Name = "column1",
+ // });
+ // Send(SelfId(), ev.release());
+ // }
+
+ // void HandleMessage(const TPGExecute* message) {
+ // Y_UNUSED(message);
+ // // execute current statement
+ // auto ev = std::make_unique<NPG::TEvPGEvents::TEvDataRows>();
+ // {
+ // ev->Rows.emplace_back();
+ // auto& row = ev->Rows.back();
+ // row.resize(1);
+ // row[0] = "345";
+ // }
+ // {
+ // ev->Rows.emplace_back();
+ // auto& row = ev->Rows.back();
+ // row.resize(1);
+ // row[0] = "456";
+ // }
+ // Send(SelfId(), ev.release());
+ // //
+ // Send(SelfId(), new NPG::TEvPGEvents::TEvCommandComplete("OK"));
+ // }
+
+ void HandleMessage(const TPGDescribe* message) {
+ Send(DatabaseProxy, new TEvPGEvents::TEvDescribe(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++);
+ }
+
+ void HandleMessage(const TPGExecute* message) {
+ Send(DatabaseProxy, new TEvPGEvents::TEvExecute(MakePGMessageCopy(message)), 0, IncomingSequenceNumber++);
+ }
+
void HandleMessage(const TPGTerminate*) {
CloseConnection = true;
}
- void HandleConnected(TEvPGEvents::TEvAuthResponse::TPtr& ev) {
- if (ev->Get()->Error) {
- if (PasswordWasSupplied) {
- SendAuthError(ev->Get()->Error);
- CloseConnection = true;
- } else {
- SendAuthClearText();
- }
- } else {
- SendAuthOk();
- FinishHandshake();
+ bool FlushAndPoll() {
+ if (FlushOutput()) {
+ RequestPoller();
+ return true;
}
+ return false;
+ }
- if (!FlushOutput()) {
- return;
+ struct TEventsComparator {
+ bool operator ()(const TAutoPtr<IEventHandle>& ev1, const TAutoPtr<IEventHandle>& ev2) const {
+ return ev1->Cookie < ev2->Cookie;
}
- RequestPoller();
+ };
+
+ template<typename TEv>
+ bool IsEventExpected(const TAutoPtr<TEventHandle<TEv>>& ev) {
+ return (ev->Cookie == 0) || (ev->Cookie == OutgoingSequenceNumber);
}
- void HandleConnected(TEvPGEvents::TEvRowDescription::TPtr& ev) {
- TPGStream<TPGRowDescription> rowDescription;
- rowDescription << uint16_t(ev->Get()->Fields.size()); // number of fields
- for (const auto& field : ev->Get()->Fields) {
- rowDescription
- << TStringBuf(field.Name) << '\0'
- << uint32_t(field.TableId)
- << uint16_t(field.ColumnId)
- << uint32_t(field.DataType)
- << uint16_t(field.DataTypeSize)
- << uint32_t(0xffffffff) // type modifier
- << uint16_t(0) // format text
- ;
- }
- SendStream(rowDescription);
+ template<typename TEv>
+ void PostponeEvent(const TAutoPtr<TEventHandle<TEv>>& ev) {
+ TAutoPtr<IEventHandle> evb = ev.Release();
+ BLOG_D("Postpone event " << evb->Cookie);
+ auto it = std::upper_bound(PostponedEvents.begin(), PostponedEvents.end(), evb, TEventsComparator());
+ PostponedEvents.insert(it, evb);
+ }
- if (!FlushOutput()) {
- return;
+ void ReplayPostponedEvents() {
+ if (!PostponedEvents.empty()) {
+ auto event = PostponedEvents.front();
+ PostponedEvents.pop_front();
+ StateConnected(event, TActivationContext::AsActorContext());
}
- RequestPoller();
}
- void HandleConnected(TEvPGEvents::TEvDataRows::TPtr& ev) {
- for (const auto& row : ev->Get()->Rows) {
- TPGStream<TPGDataRow> dataRow;
- dataRow << uint16_t(row.size()); // number of fields
- for (const auto& item : row) {
- dataRow << uint32_t(item.size()) << item;
+ void HandleConnected(TEvPGEvents::TEvAuthResponse::TPtr& ev) {
+ if (IsEventExpected(ev)) {
+ if (ev->Get()->Error) {
+ if (PasswordWasSupplied) {
+ SendAuthError(ev->Get()->Error);
+ CloseConnection = true;
+ } else {
+ SendAuthClearText();
+ }
+ } else {
+ SendAuthOk();
+ FinishHandshake();
}
- SendStream(dataRow);
+ ++OutgoingSequenceNumber;
+ ReplayPostponedEvents();
+ FlushAndPoll();
+ } else {
+ PostponeEvent(ev);
}
+ }
- if (!FlushOutput()) {
- return;
+ void HandleConnected(TEvPGEvents::TEvQueryResponse::TPtr& ev) {
+ if (IsEventExpected(ev)) {
+ if (ev->Get()->ErrorFields.empty()) {
+ TString tag = "OK";
+ { // rowDescription
+ TPGStreamOutput<TPGRowDescription> rowDescription;
+ rowDescription << uint16_t(ev->Get()->DataFields.size()); // number of fields
+ for (const auto& field : ev->Get()->DataFields) {
+ rowDescription
+ << TStringBuf(field.Name) << '\0'
+ << uint32_t(field.TableId)
+ << uint16_t(field.ColumnId)
+ << uint32_t(field.DataType)
+ << uint16_t(field.DataTypeSize)
+ << uint32_t(0xffffffff) // type modifier
+ << uint16_t(0) // format text
+ ;
+ }
+ SendStream(rowDescription);
+ }
+ { // dataFields
+ for (const auto& row : ev->Get()->DataRows) {
+ TPGStreamOutput<TPGDataRow> dataRow;
+ dataRow << uint16_t(row.size()); // number of fields
+ for (const auto& item : row) {
+ dataRow << uint32_t(item.size()) << item;
+ }
+ SendStream(dataRow);
+ }
+ }
+ { // commandComplete
+ TPGStreamOutput<TPGCommandComplete> commandComplete;
+ commandComplete << tag << '\0';
+ SendStream(commandComplete);
+ }
+ } else {
+ // error response
+ TPGStreamOutput<TPGErrorResponse> errorResponse;
+ for (const auto& field : ev->Get()->ErrorFields) {
+ errorResponse << field.first << field.second << '\0';
+ }
+ errorResponse << '\0';
+ SendStream(errorResponse);
+ }
+ SendReadyForQuery();
+ ++OutgoingSequenceNumber;
+ ReplayPostponedEvents();
+ FlushAndPoll();
+ } else {
+ PostponeEvent(ev);
}
- RequestPoller();
}
- void HandleConnected(TEvPGEvents::TEvCommandComplete::TPtr& ev) {
- TPGStream<TPGCommandComplete> commandComplete;
- commandComplete << ev->Get()->Tag << '\0';
- SendStream(commandComplete);
-
- SendReadyForQuery();
-
- if (!FlushOutput()) {
- return;
+ void HandleConnected(TEvPGEvents::TEvDescribeResponse::TPtr& ev) {
+ if (IsEventExpected(ev)) {
+ TString tag = "OK";
+ { // rowDescription
+ TPGStreamOutput<TPGRowDescription> rowDescription;
+ rowDescription << uint16_t(ev->Get()->DataFields.size()); // number of fields
+ for (const auto& field : ev->Get()->DataFields) {
+ rowDescription
+ << TStringBuf(field.Name) << '\0'
+ << uint32_t(field.TableId)
+ << uint16_t(field.ColumnId)
+ << uint32_t(field.DataType)
+ << uint16_t(field.DataTypeSize)
+ << uint32_t(0xffffffff) // type modifier
+ << uint16_t(0) // format text
+ ;
+ }
+ SendStream(rowDescription);
+ }
+ ++OutgoingSequenceNumber;
+ ReplayPostponedEvents();
+ FlushAndPoll();
+ } else {
+ PostponeEvent(ev);
}
- RequestPoller();
}
- void HandleConnected(TEvPGEvents::TEvErrorResponse::TPtr& ev) {
- TPGStream<TPGErrorResponse> errorResponse;
- for (const auto& field : ev->Get()->ErrorFields) {
- errorResponse << field.first << field.second << '\0';
+ void HandleConnected(TEvPGEvents::TEvExecuteResponse::TPtr& ev) {
+ if (IsEventExpected(ev)) {
+ if (ev->Get()->ErrorFields.empty()) {
+ TString tag = "OK";
+ { // dataFields
+ for (const auto& row : ev->Get()->DataRows) {
+ TPGStreamOutput<TPGDataRow> dataRow;
+ dataRow << uint16_t(row.size()); // number of fields
+ for (const auto& item : row) {
+ dataRow << uint32_t(item.size()) << item;
+ }
+ SendStream(dataRow);
+ }
+ }
+ { // commandComplete
+ TPGStreamOutput<TPGCommandComplete> commandComplete;
+ commandComplete << tag << '\0';
+ SendStream(commandComplete);
+ }
+ } else {
+ // error response
+ TPGStreamOutput<TPGErrorResponse> errorResponse;
+ for (const auto& field : ev->Get()->ErrorFields) {
+ errorResponse << field.first << field.second << '\0';
+ }
+ errorResponse << '\0';
+ SendStream(errorResponse);
+ }
+ SendReadyForQuery();
+ ++OutgoingSequenceNumber;
+ ReplayPostponedEvents();
+ FlushAndPoll();
+ } else {
+ PostponeEvent(ev);
}
- errorResponse << '\0';
- SendStream(errorResponse);
+ }
- SendReadyForQuery();
+ void HandleConnected(TEvPGEvents::TEvParseResponse::TPtr& ev) {
+ if (IsEventExpected(ev)) {
+ TPGStreamOutput<TPGParseComplete> parseComplete;
+ SendStream(parseComplete);
+ SendReadyForQuery();
+ ++OutgoingSequenceNumber;
+ ReplayPostponedEvents();
+ FlushAndPoll();
+ } else {
+ PostponeEvent(ev);
+ }
+ }
- if (!FlushOutput()) {
- return;
+ void HandleConnected(TEvPGEvents::TEvBindResponse::TPtr& ev) {
+ if (IsEventExpected(ev)) {
+ TPGStreamOutput<TPGBindComplete> bindComplete;
+ SendStream(bindComplete);
+ ++OutgoingSequenceNumber;
+ ReplayPostponedEvents();
+ FlushAndPoll();
+ } else {
+ PostponeEvent(ev);
}
- RequestPoller();
}
bool HasInputMessage() const {
@@ -422,7 +655,7 @@ protected:
BufferInput.Advance(res);
while (HasInputMessage()) {
const TPGMessage* message = GetInputMessage();
- PrintMessage("-> ", *message);
+ PrintMessage(EDirection::Incoming, *message);
switch (static_cast<EMessageCode>(message->Message)) {
case EMessageCode::Initial:
HandleMessage(static_cast<const TPGInitial*>(message));
@@ -436,8 +669,23 @@ protected:
case EMessageCode::PasswordMessage:
HandleMessage(static_cast<const TPGPasswordMessage*>(message));
break;
+ case EMessageCode::Parse:
+ HandleMessage(static_cast<const TPGParse*>(message));
+ break;
+ case EMessageCode::ParameterStatus:
+ HandleMessage(static_cast<const TPGParameterStatus*>(message));
+ break;
+ case EMessageCode::Bind:
+ HandleMessage(static_cast<const TPGBind*>(message));
+ break;
+ case EMessageCode::Describe:
+ HandleMessage(static_cast<const TPGDescribe*>(message));
+ break;
+ case EMessageCode::Execute:
+ HandleMessage(static_cast<const TPGExecute*>(message));
+ break;
default:
- BLOG_W("invalid message (" << message->Message << ")");
+ BLOG_ERROR("invalid message (" << message->Message << ")");
CloseConnection = true;
break;
}
@@ -536,10 +784,11 @@ protected:
hFunc(TEvPollerRegisterResult, HandleConnected);
hFunc(TEvPGEvents::TEvAuthResponse, HandleConnected);
- hFunc(TEvPGEvents::TEvRowDescription, HandleConnected);
- hFunc(TEvPGEvents::TEvDataRows, HandleConnected);
- hFunc(TEvPGEvents::TEvCommandComplete, HandleConnected);
- hFunc(TEvPGEvents::TEvErrorResponse, HandleConnected);
+ hFunc(TEvPGEvents::TEvQueryResponse, HandleConnected);
+ hFunc(TEvPGEvents::TEvParseResponse, HandleConnected);
+ hFunc(TEvPGEvents::TEvBindResponse, HandleConnected);
+ hFunc(TEvPGEvents::TEvDescribeResponse, HandleConnected);
+ hFunc(TEvPGEvents::TEvExecuteResponse, HandleConnected);
}
}
};
diff --git a/ydb/core/pgproxy/pg_proxy_events.h b/ydb/core/pgproxy/pg_proxy_events.h
index a1464d503a1..7fdedffb944 100644
--- a/ydb/core/pgproxy/pg_proxy_events.h
+++ b/ydb/core/pgproxy/pg_proxy_events.h
@@ -10,18 +10,35 @@ struct TEvPGEvents {
enum EEv {
EvConnectionOpened = EventSpaceBegin(NActors::TEvents::ES_PGWIRE),
EvConnectionClosed,
- EvAuthRequest,
+ EvAuth,
EvAuthResponse,
EvQuery,
- EvRowDescription,
- EvDataRows,
- EvCommandComplete,
- EvErrorResponse,
+ EvQueryResponse,
+ EvParse,
+ EvParseResponse,
+ EvBind,
+ EvBindResponse,
+ EvDescribe,
+ EvDescribeResponse,
+ EvExecute,
+ EvExecuteResponse,
EvEnd
};
static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PGWIRE), "ES_PGWIRE event space is too small.");
+ struct TRowDescriptionField {
+ TString Name;
+ uint32_t TableId = 0;
+ uint16_t ColumnId = 0;
+ uint32_t DataType;
+ uint16_t DataTypeSize;
+ //uint32_t DataTypeModifier;
+ //uint16_t Format;
+ };
+
+ using TDataRow = std::vector<TString>;
+
struct TEvConnectionOpened : NActors::TEventLocal<TEvConnectionOpened, EvConnectionOpened> {
std::shared_ptr<TPGInitial> Message;
@@ -33,15 +50,15 @@ struct TEvPGEvents {
struct TEvConnectionClosed : NActors::TEventLocal<TEvConnectionClosed, EvConnectionClosed> {
};
- struct TEvAuthRequest : NActors::TEventLocal<TEvAuthRequest, EvAuthRequest> {
+ struct TEvAuth : NActors::TEventLocal<TEvAuth, EvAuth> {
std::shared_ptr<TPGInitial> InitialMessage;
std::unique_ptr<TPGPasswordMessage> PasswordMessage;
- TEvAuthRequest(std::shared_ptr<TPGInitial> initialMessage)
+ TEvAuth(std::shared_ptr<TPGInitial> initialMessage)
: InitialMessage(std::move(initialMessage))
{}
- TEvAuthRequest(std::shared_ptr<TPGInitial> initialMessage, std::unique_ptr<TPGPasswordMessage> message)
+ TEvAuth(std::shared_ptr<TPGInitial> initialMessage, std::unique_ptr<TPGPasswordMessage> message)
: InitialMessage(std::move(initialMessage))
, PasswordMessage(std::move(message))
{}
@@ -72,30 +89,10 @@ struct TEvPGEvents {
{}
};
- struct TEvRowDescription : NActors::TEventLocal<TEvRowDescription, EvRowDescription> {
- struct TField {
- TString Name;
- uint32_t TableId = 0;
- uint16_t ColumnId = 0;
- uint32_t DataType;
- uint16_t DataTypeSize;
- //uint32_t DataTypeModifier;
- //uint16_t Format;
- };
- std::vector<TField> Fields;
- };
-
- struct TEvDataRows : NActors::TEventLocal<TEvDataRows, EvDataRows> {
- using TDataRow = std::vector<TString>;
- std::vector<TDataRow> Rows;
- };
-
- struct TEvCommandComplete : NActors::TEventLocal<TEvCommandComplete, EvCommandComplete> {
- TString Tag;
-
- TEvCommandComplete(const TString& tag)
- : Tag(tag)
- {}
+ struct TEvQueryResponse : NActors::TEventLocal<TEvQueryResponse, EvQueryResponse> {
+ std::vector<TRowDescriptionField> DataFields;
+ std::vector<TDataRow> DataRows;
+ std::vector<std::pair<char, TString>> ErrorFields;
};
/*
@@ -127,7 +124,69 @@ struct TEvPGEvents {
R = "scanner_yyerror"
*/
- struct TEvErrorResponse : NActors::TEventLocal<TEvErrorResponse, EvErrorResponse> {
+ struct TEvParseResponse : NActors::TEventLocal<TEvParseResponse, EvParseResponse> {
+ std::unique_ptr<TPGParse> OriginalMessage;
+
+ TEvParseResponse(std::unique_ptr<TPGParse> originalMessage)
+ : OriginalMessage(std::move(originalMessage))
+ {}
+ };
+
+ struct TEvParse : NActors::TEventLocal<TEvParse, EvParse> {
+ std::unique_ptr<TPGParse> Message;
+
+ TEvParse(std::unique_ptr<TPGParse> message)
+ : Message(std::move(message))
+ {}
+
+ std::unique_ptr<TEvParseResponse> Reply() {
+ return std::make_unique<TEvParseResponse>(std::move(Message));
+ }
+ };
+
+ struct TEvBindResponse : NActors::TEventLocal<TEvBindResponse, EvBindResponse> {
+ std::unique_ptr<TPGBind> OriginalMessage;
+
+ TEvBindResponse(std::unique_ptr<TPGBind> originalMessage)
+ : OriginalMessage(std::move(originalMessage))
+ {}
+ };
+
+ struct TEvBind : NActors::TEventLocal<TEvBind, EvBind> {
+ std::unique_ptr<TPGBind> Message;
+
+ TEvBind(std::unique_ptr<TPGBind> message)
+ : Message(std::move(message))
+ {}
+
+ std::unique_ptr<TEvBindResponse> Reply() {
+ return std::make_unique<TEvBindResponse>(std::move(Message));
+ }
+ };
+
+ struct TEvDescribe : NActors::TEventLocal<TEvDescribe, EvDescribe> {
+ std::unique_ptr<TPGDescribe> Message;
+
+ TEvDescribe(std::unique_ptr<TPGDescribe> message)
+ : Message(std::move(message))
+ {}
+ };
+
+ struct TEvDescribeResponse : NActors::TEventLocal<TEvDescribeResponse, EvDescribeResponse> {
+ std::vector<TRowDescriptionField> DataFields;
+ std::vector<std::pair<char, TString>> ErrorFields;
+ };
+
+ struct TEvExecute : NActors::TEventLocal<TEvExecute, EvExecute> {
+ std::unique_ptr<TPGExecute> Message;
+
+ TEvExecute(std::unique_ptr<TPGExecute> message)
+ : Message(std::move(message))
+ {}
+ };
+
+ struct TEvExecuteResponse : NActors::TEventLocal<TEvExecuteResponse, EvExecuteResponse> {
+ std::vector<TDataRow> DataRows;
std::vector<std::pair<char, TString>> ErrorFields;
};
};
diff --git a/ydb/core/pgproxy/pg_proxy_types.cpp b/ydb/core/pgproxy/pg_proxy_types.cpp
new file mode 100644
index 00000000000..acb836d83db
--- /dev/null
+++ b/ydb/core/pgproxy/pg_proxy_types.cpp
@@ -0,0 +1,207 @@
+#include "pg_proxy_types.h"
+#include "pg_stream.h"
+
+namespace NPG {
+
+TString TPGInitial::Dump() const {
+ TPGStreamInput stream(*this);
+ TStringBuilder text;
+ uint32_t protocol = 0;
+ stream >> protocol;
+ protocol = htonl(protocol);
+ if (protocol == 773247492) { // 80877102 cancellation message
+ uint32_t pid = 0;
+ uint32_t key = 0;
+ stream >> pid >> key;
+ text << "cancellation PID " << pid << " KEY " << key;
+ } else if (protocol == 790024708) { // 790024708 SSL handshake
+ text << "SSL handshake";
+ } else {
+ text << "protocol(" << Hex(protocol) << ") ";
+ while (!stream.Empty()) {
+ TStringBuf key;
+ TStringBuf value;
+ stream >> key >> value;
+ if (key.empty()) {
+ break;
+ }
+ text << key << "=" << value << " ";
+ }
+ }
+ return text;
+}
+
+uint32_t TPGInitial::GetProtocol() const {
+ TPGStreamInput stream(*this);
+ uint32_t protocol = 0;
+ stream >> protocol;
+ protocol = htonl(protocol);
+ return protocol;
+}
+
+std::unordered_map<TString, TString> TPGInitial::GetClientParams() const {
+ std::unordered_map<TString, TString> params;
+ TPGStreamInput stream(*this);
+ TStringBuilder text;
+ uint32_t protocol = 0;
+ stream >> protocol;
+ while (!stream.Empty()) {
+ TStringBuf key;
+ TStringBuf value;
+ stream >> key >> value;
+ if (key.empty()) {
+ break;
+ }
+ params[TString(key)] = value;
+ }
+ return params;
+}
+
+TString TPGErrorResponse::Dump() const {
+ TPGStreamInput stream(*this);
+ TStringBuilder text;
+
+ while (!stream.Empty()) {
+ char code;
+ TString message;
+ stream >> code;
+ if (code == 0) {
+ break;
+ }
+ stream >> message;
+ if (!text.empty()) {
+ text << ' ';
+ }
+ text << code << "=\"" << message << "\"";
+ }
+ return text;
+}
+
+TString TPGParse::Dump() const {
+ TPGStreamInput stream(*this);
+ TStringBuf name;
+ stream >> name;
+ return TStringBuilder() << "Name:" << name;
+}
+
+TPGParse::TQueryData TPGParse::GetQueryData() const {
+ TQueryData queryData;
+ TPGStreamInput stream(*this);
+ stream >> queryData.Name;
+ stream >> queryData.Query;
+ uint16_t numberOfParameterTypes = 0;
+ stream >> numberOfParameterTypes;
+ queryData.ParametersTypes.reserve(numberOfParameterTypes);
+ for (uint16_t n = 0; n < numberOfParameterTypes; ++n) {
+ uint32_t param = 0;
+ stream >> param;
+ queryData.ParametersTypes.emplace_back(param);
+ }
+ return queryData;
+}
+
+TPGBind::TBindData TPGBind::GetBindData() const {
+ TBindData bindData;
+ TPGStreamInput stream(*this);
+ stream >> bindData.PortalName;
+ stream >> bindData.StatementName;
+ uint16_t numberOfParameterFormats = 0;
+ stream >> numberOfParameterFormats;
+ bindData.ParametersFormat.reserve(numberOfParameterFormats);
+ for (uint16_t n = 0; n < numberOfParameterFormats; ++n) {
+ uint16_t format = 0;
+ stream >> format;
+ bindData.ParametersFormat.emplace_back(format);
+ }
+ uint16_t numberOfParameterValues = 0;
+ stream >> numberOfParameterValues;
+ bindData.ParametersValue.reserve(numberOfParameterValues);
+ for (uint16_t n = 0; n < numberOfParameterValues; ++n) {
+ uint32_t size = 0;
+ stream >> size;
+ std::vector<uint8_t> value;
+ stream.Read(value, size);
+ bindData.ParametersValue.emplace_back(std::move(value));
+ }
+ uint16_t numberOfResultFormats = 0;
+ stream >> numberOfResultFormats;
+ bindData.ResultsFormat.reserve(numberOfResultFormats);
+ for (uint16_t n = 0; n < numberOfResultFormats; ++n) {
+ uint16_t format = 0;
+ stream >> format;
+ bindData.ResultsFormat.emplace_back(format);
+ }
+ return bindData;
+}
+
+TString TPGBind::Dump() const {
+ TStringBuilder text;
+ TPGStreamInput stream(*this);
+ TStringBuf portalName;
+ TStringBuf statementName;
+ stream >> portalName;
+ stream >> statementName;
+ if (portalName) {
+ text << "Portal: " << portalName;
+ } else if (statementName) {
+ text << "Statement: " << statementName;
+ }
+ return text;
+}
+
+TString TPGDataRow::Dump() const {
+ TPGStreamInput stream(*this);
+ uint16_t numberOfColumns = 0;
+ stream >> numberOfColumns;
+ return TStringBuilder() << "Columns: " << numberOfColumns;
+}
+
+TPGDescribe::TDescribeData TPGDescribe::GetDescribeData() const {
+ TPGStreamInput stream(*this);
+ TDescribeData data;
+ char describeType = 0;
+ stream >> describeType;
+ data.Type = static_cast<TDescribeData::EDescribeType>(describeType);
+ stream >> data.Name;
+ return data;
+}
+
+TString TPGDescribe::Dump() const {
+ TPGStreamInput stream(*this);
+ TStringBuilder text;
+ char describeType = 0;
+ stream >> describeType;
+ text << "Type:" << describeType;
+ TStringBuf name;
+ stream >> name;
+ text << " Name:" << name;
+ return text;
+}
+
+TPGExecute::TExecuteData TPGExecute::GetExecuteData() const {
+ TPGStreamInput stream(*this);
+ TExecuteData data;
+ stream >> data.PortalName >> data.MaxRows;
+ return data;
+}
+
+TString TPGExecute::Dump() const {
+ TPGStreamInput stream(*this);
+ TStringBuilder text;
+ TStringBuf name;
+ uint32_t maxRows = 0;
+ stream >> name >> maxRows;
+ if (name) {
+ text << "Name: " << name;
+ }
+ if (maxRows) {
+ if (!text.empty()) {
+ text << ' ';
+ }
+ text << "MaxRows: " << maxRows;
+ }
+ return text;
+}
+
+
+} \ No newline at end of file
diff --git a/ydb/core/pgproxy/pg_proxy_types.h b/ydb/core/pgproxy/pg_proxy_types.h
index caf7bacf081..00ac72724a7 100644
--- a/ydb/core/pgproxy/pg_proxy_types.h
+++ b/ydb/core/pgproxy/pg_proxy_types.h
@@ -1,7 +1,10 @@
#pragma once
#include <cstdint>
+#include <unordered_map>
+#include <vector>
#include <util/stream/format.h>
+#include <util/string/builder.h>
#include <arpa/inet.h>
namespace NPG {
@@ -30,78 +33,20 @@ struct TPGMessage {
void operator delete(void* p) {
::operator delete(p);
}
+
+ bool Empty() const {
+ return GetDataSize() == 0;
+ }
};
struct TPGInitial : TPGMessage { // it's not true, because we don't receive message code from a network, but imply it on the start
- uint32_t Protocol;
-
TPGInitial() {
Message = 'i'; // fake code
}
- const char* GetData() const {
- return reinterpret_cast<const char*>(this) + sizeof(*this);
- }
-
- TString Dump() const {
- TStringBuilder text;
- if (Protocol == 773247492) { // 80877102 cancellation message
- const uint32_t* data = reinterpret_cast<const uint32_t*>(GetData());
- if (GetDataSize() >= 8) {
- uint32_t pid = data[0];
- uint32_t key = data[1];
- text << "cancellation PID " << pid << " KEY " << key;
- }
- } else if (Protocol == 790024708) { // 790024708 SSL handshake
- text << "SSL handshake";
- } else {
- text << "protocol(" << Hex(Protocol) << ") ";
- const char* begin = GetData();
- const char* end = GetData() + GetDataSize();
- for (const char* ptr = begin; ptr < end;) {
- TStringBuf key;
- TStringBuf value;
- size_t size = strnlen(ptr, end - ptr);
- key = TStringBuf(ptr, size);
- if (key.empty()) {
- break;
- }
- ptr += size + 1;
- if (ptr >= end) {
- break;
- }
- size = strnlen(ptr, end - ptr);
- value = TStringBuf(ptr, size);
- ptr += size + 1;
- text << key << "=" << value << " ";
- }
- }
- return text;
- }
-
- std::unordered_map<TString, TString> GetClientParams() const {
- std::unordered_map<TString, TString> params;
- const char* begin = GetData();
- const char* end = GetData() + GetDataSize();
- for (const char* ptr = begin; ptr < end;) {
- TString key;
- TString value;
- size_t size = strnlen(ptr, end - ptr);
- key = TStringBuf(ptr, size);
- if (key.empty()) {
- break;
- }
- ptr += size + 1;
- if (ptr >= end) {
- break;
- }
- size = strnlen(ptr, end - ptr);
- value = TStringBuf(ptr, size);
- ptr += size + 1;
- params[key] = value;
- }
- return params;
- }
+ TString Dump() const;
+ std::unordered_map<TString, TString> GetClientParams() const;
+ uint32_t GetProtocol() const;
};
struct TPGAuth : TPGMessage {
@@ -166,7 +111,11 @@ struct TPGParameterStatus : TPGMessage {
}
TString Dump() const {
- return TStringBuilder() << GetName() << "=" << GetValue();
+ if (!Empty()) {
+ return TStringBuilder() << GetName() << "=" << GetValue();
+ } else {
+ return {};
+ }
}
};
@@ -221,6 +170,8 @@ struct TPGErrorResponse : TPGMessage {
TPGErrorResponse() {
Message = 'E';
}
+
+ TString Dump() const;
};
struct TPGTerminate : TPGMessage {
@@ -239,6 +190,8 @@ struct TPGDataRow : TPGMessage {
TPGDataRow() {
Message = 'D';
}
+
+ TString Dump() const;
};
struct TPGEmptyQueryResponse : TPGMessage {
@@ -247,6 +200,81 @@ struct TPGEmptyQueryResponse : TPGMessage {
}
};
+struct TPGParse : TPGMessage {
+ TPGParse() {
+ Message = 'P';
+ }
+
+ struct TQueryData {
+ TString Name;
+ TString Query;
+ std::vector<int32_t> ParametersTypes; // types
+ };
+
+ TQueryData GetQueryData() const;
+ TString Dump() const;
+};
+
+struct TPGParseComplete : TPGMessage {
+ TPGParseComplete() {
+ Message = '1';
+ }
+};
+
+struct TPGBind : TPGMessage {
+ TPGBind() {
+ Message = 'B';
+ }
+
+ struct TBindData {
+ TString PortalName;
+ TString StatementName;
+ std::vector<int16_t> ParametersFormat; // format codes 0=text, 1=binary
+ std::vector<std::vector<uint8_t>> ParametersValue; // parameters content
+ std::vector<int16_t> ResultsFormat; // result format codes 0=text, 1=binary
+ };
+
+ TBindData GetBindData() const;
+ TString Dump() const;
+};
+
+struct TPGBindComplete : TPGMessage {
+ TPGBindComplete() {
+ Message = '2';
+ }
+};
+
+struct TPGDescribe : TPGMessage {
+ TPGDescribe() {
+ Message = 'D';
+ }
+
+ struct TDescribeData {
+ enum class EDescribeType : char {
+ Portal = 'P',
+ Statement = 'S',
+ };
+ EDescribeType Type;
+ TString Name;
+ };
+
+ TDescribeData GetDescribeData() const;
+ TString Dump() const;
+};
+
+struct TPGExecute : TPGMessage {
+ TPGExecute() {
+ Message = 'E';
+ }
+
+ struct TExecuteData {
+ TString PortalName;
+ uint32_t MaxRows;
+ };
+
+ TExecuteData GetExecuteData() const;
+ TString Dump() const;
+};
#pragma pack(pop)
template<typename TPGMessageType>
diff --git a/ydb/core/pgproxy/pg_proxy_ut.cpp b/ydb/core/pgproxy/pg_proxy_ut.cpp
index 472806b2fb3..2e02e98deee 100644
--- a/ydb/core/pgproxy/pg_proxy_ut.cpp
+++ b/ydb/core/pgproxy/pg_proxy_ut.cpp
@@ -64,7 +64,7 @@ Y_UNIT_TEST_SUITE(TPGTest) {
}
TSocket s(TNetworkAddress("::", port));
Send(s, "0000001300030000" "7573657200757365720000"); // user=user
- NPG::TEvPGEvents::TEvAuthRequest* authRequest = actorSystem.GrabEdgeEvent<NPG::TEvPGEvents::TEvAuthRequest>(handle);
+ NPG::TEvPGEvents::TEvAuth* authRequest = actorSystem.GrabEdgeEvent<NPG::TEvPGEvents::TEvAuth>(handle);
UNIT_ASSERT(authRequest);
UNIT_ASSERT_VALUES_EQUAL(authRequest->InitialMessage->GetClientParams()["user"], "user");
actorSystem.Send(new NActors::IEventHandle(handle->Sender, database, new NPG::TEvPGEvents::TEvAuthResponse()));
diff --git a/ydb/core/pgproxy/pg_stream.h b/ydb/core/pgproxy/pg_stream.h
index bb7ea938287..c1eb55c0287 100644
--- a/ydb/core/pgproxy/pg_stream.h
+++ b/ydb/core/pgproxy/pg_stream.h
@@ -1,15 +1,17 @@
#pragma once
#include "pg_proxy_types.h"
+#include <util/generic/buffer.h>
+#include <util/generic/strbuf.h>
namespace NPG {
template<typename TMessage>
-class TPGStream : public TBuffer {
+class TPGStreamOutput : public TBuffer {
public:
using TBase = TBuffer;
- TPGStream() {
+ TPGStreamOutput() {
TMessage header;
TBase::Append(reinterpret_cast<const char*>(&header), sizeof(header));
}
@@ -19,27 +21,90 @@ public:
header->Length = htonl(TBase::Size() - sizeof(char));
}
- TPGStream& operator <<(uint16_t v) {
+ TPGStreamOutput& operator <<(uint16_t v) {
v = htons(v);
TBase::Append(reinterpret_cast<const char*>(&v), sizeof(v));
return *this;
}
- TPGStream& operator <<(uint32_t v) {
+ TPGStreamOutput& operator <<(uint32_t v) {
v = htonl(v);
TBase::Append(reinterpret_cast<const char*>(&v), sizeof(v));
return *this;
}
- TPGStream& operator <<(char v) {
+ TPGStreamOutput& operator <<(char v) {
TBase::Append(v);
return *this;
}
- TPGStream& operator <<(TStringBuf s) {
+ TPGStreamOutput& operator <<(TStringBuf s) {
TBase::Append(s.data(), s.size());
return *this;
}
};
+class TPGStreamInput {
+public:
+ TPGStreamInput(const TPGMessage& message)
+ : Buffer(message.GetData(), message.GetDataSize())
+ {
+ }
+
+ TPGStreamInput& operator >>(TString& s) {
+ s = Buffer.NextTok('\0');
+ return *this;
+ }
+
+ TPGStreamInput& operator >>(TStringBuf& s) {
+ s = Buffer.NextTok('\0');
+ return *this;
+ }
+
+ TPGStreamInput& operator >>(char& v) {
+ if (Buffer.size() >= sizeof(v)) {
+ v = *reinterpret_cast<const char*>(Buffer.data());
+ Buffer.Skip(sizeof(v));
+ } else {
+ v = {};
+ }
+ return *this;
+ }
+
+ TPGStreamInput& operator >>(uint16_t& v) {
+ if (Buffer.size() >= sizeof(v)) {
+ v = ntohs(*reinterpret_cast<const uint16_t*>(Buffer.data()));
+ Buffer.Skip(sizeof(v));
+ } else {
+ v = {};
+ }
+ return *this;
+ }
+
+ TPGStreamInput& operator >>(uint32_t& v) {
+ if (Buffer.size() >= sizeof(v)) {
+ v = ntohl(*reinterpret_cast<const uint32_t*>(Buffer.data()));
+ Buffer.Skip(sizeof(v));
+ } else {
+ v = {};
+ }
+ return *this;
+ }
+
+ TPGStreamInput& Read(std::vector<uint8_t>& data, uint32_t size) {
+ size = std::min<uint32_t>(size, Buffer.size());
+ data.resize(size);
+ memcpy(data.data(), Buffer.data(), size);
+ Buffer.Skip(size);
+ return *this;
+ }
+
+ bool Empty() const {
+ return Buffer.Empty();
+ }
+
+protected:
+ TStringBuf Buffer;
+};
+
}