diff options
author | galaxycrab <UgnineSirdis@ydb.tech> | 2023-11-23 11:26:33 +0300 |
---|---|---|
committer | galaxycrab <UgnineSirdis@ydb.tech> | 2023-11-23 12:01:57 +0300 |
commit | 44354d0fc55926c1d4510d1d2c9c9f6a1a5e9300 (patch) | |
tree | cb4d75cd1c6dbc3da0ed927337fd8d1b6ed9da84 /library/cpp/clickhouse/client/client.cpp | |
parent | 0e69bf615395fdd48ecee032faaec81bc468b0b8 (diff) | |
download | ydb-44354d0fc55926c1d4510d1d2c9c9f6a1a5e9300.tar.gz |
YQ Connector:test INNER JOIN
Diffstat (limited to 'library/cpp/clickhouse/client/client.cpp')
-rw-r--r-- | library/cpp/clickhouse/client/client.cpp | 767 |
1 files changed, 767 insertions, 0 deletions
diff --git a/library/cpp/clickhouse/client/client.cpp b/library/cpp/clickhouse/client/client.cpp new file mode 100644 index 0000000000..b0b2613bb5 --- /dev/null +++ b/library/cpp/clickhouse/client/client.cpp @@ -0,0 +1,767 @@ +#include "client.h" +#include "protocol.h" + +#include <library/cpp/clickhouse/client/base/coded.h> +#include <library/cpp/clickhouse/client/base/compressed.h> +#include <library/cpp/clickhouse/client/base/wire_format.h> +#include <library/cpp/clickhouse/client/columns/factory.h> +#include <library/cpp/openssl/io/stream.h> + +#include <util/generic/buffer.h> +#include <util/generic/vector.h> +#include <util/network/socket.h> +#include <util/random/random.h> +#include <util/stream/buffered.h> +#include <util/stream/buffer.h> +#include <util/stream/mem.h> +#include <util/string/builder.h> +#include <util/string/cast.h> +#include <util/system/unaligned_mem.h> + +#include <contrib/libs/lz4/lz4.h> +#include <contrib/restricted/cityhash-1.0.2/city.h> + +#define DBMS_NAME "ClickHouse" +#define DBMS_VERSION_MAJOR 1 +#define DBMS_VERSION_MINOR 1 +#define REVISION 54126 + +#define DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES 50264 +#define DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS 51554 +#define DBMS_MIN_REVISION_WITH_BLOCK_INFO 51903 +#define DBMS_MIN_REVISION_WITH_CLIENT_INFO 54032 +#define DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE 54058 +#define DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO 54060 + +namespace NClickHouse { + struct TClientInfo { + ui8 IfaceType = 1; // TCP + ui8 QueryKind; + TString InitialUser; + TString InitialQueryId; + TString QuotaKey; + TString OsUser; + TString ClientHostname; + TString ClientName; + TString InitialAddress = "[::ffff:127.0.0.1]:0"; + ui64 ClientVersionMajor = 0; + ui64 ClientVersionMinor = 0; + ui32 ClientRevision = 0; + }; + + struct TServerInfo { + TString Name; + TString Timezone; + ui64 VersionMajor; + ui64 VersionMinor; + ui64 Revision; + }; + + class TClient::TImpl { + public: + TImpl(const TClientOptions& opts); + ~TImpl(); + + void ExecuteQuery(TQuery query); + + void Insert(const TString& table_name, const TBlock& block); + + void Ping(); + + void ResetConnection(); + + private: + bool Handshake(); + + bool ReceivePacket(ui64* server_packet = nullptr); + + void SendQuery(const TString& query); + + void SendData(const TBlock& block); + + bool SendHello(); + + bool ReadBlock(TBlock* block, TCodedInputStream* input); + + bool ReceiveHello(); + + /// Reads data packet form input stream. + bool ReceiveData(); + + /// Reads exception packet form input stream. + bool ReceiveException(bool rethrow = false); + + void WriteBlock(const TBlock& block, TCodedOutputStream* output); + + private: + void Disconnect() { + Socket_ = TSocket(); + } + + /// In case of network errors tries to reconnect to server and + /// call fuc several times. + void RetryGuard(std::function<void()> fuc); + + private: + class EnsureNull { + public: + inline EnsureNull(TQueryEvents* ev, TQueryEvents** ptr) + : ptr_(ptr) + { + if (ptr_) { + *ptr_ = ev; + } + } + + inline ~EnsureNull() { + if (ptr_) { + *ptr_ = nullptr; + } + } + + private: + TQueryEvents** ptr_; + }; + + const TClientOptions Options_; + TQueryEvents* Events_; + int Compression_ = CompressionState::Disable; + + TSocket Socket_; + + TSocketInput SocketInput_; + TSocketOutput SocketOutput_; + THolder<TBufferedInput> BufferedInput_; + THolder<TBufferedOutput> BufferedOutput_; + THolder<TOpenSslClientIO> SslClient_; + + TCodedInputStream Input_; + TCodedOutputStream Output_; + + TServerInfo ServerInfo_; + }; + + TClient::TImpl::TImpl(const TClientOptions& opts) + : Options_(opts) + , Events_(nullptr) + , Socket_(TNetworkAddress(opts.Host, opts.Port), Options_.ConnectTimeout) + , SocketInput_(Socket_) + , SocketOutput_(Socket_) + { + if (opts.UseSsl) { + SslClient_ = MakeHolder<TOpenSslClientIO>(&SocketInput_, &SocketOutput_, opts.SslOptions); + BufferedInput_ = MakeHolder<TBufferedInput>(SslClient_.Get()); + BufferedOutput_ = MakeHolder<TBufferedOutput>(SslClient_.Get()); + } else { + BufferedInput_ = MakeHolder<TBufferedInput>(&SocketInput_); + BufferedOutput_ = MakeHolder<TBufferedOutput>(&SocketOutput_); + } + Input_ = TCodedInputStream(BufferedInput_.Get()); + Output_ = TCodedOutputStream(BufferedOutput_.Get()); + + if (Options_.RequestTimeout.Seconds()) { + Socket_.SetSocketTimeout(Options_.RequestTimeout.Seconds()); + } + + if (!Handshake()) { + ythrow yexception() << "fail to connect to " << Options_.Host; + } + + if (Options_.CompressionMethod != ECompressionMethod::None) { + Compression_ = CompressionState::Enable; + } + } + + TClient::TImpl::~TImpl() { + Disconnect(); + } + + void TClient::TImpl::ExecuteQuery(TQuery query) { + EnsureNull en(static_cast<TQueryEvents*>(&query), &Events_); + + if (Options_.PingBeforeQuery) { + RetryGuard([this]() { Ping(); }); + } + + SendQuery(query.GetText()); + + ui64 server_packet = 0; + while (ReceivePacket(&server_packet)) { + ; + } + if (server_packet != ServerCodes::EndOfStream && server_packet != ServerCodes::Exception) { + ythrow yexception() << "unexpected packet from server while receiving end of query (got: " + << (server_packet ? ToString(server_packet) : "nothing") << ")"; + } + } + + void TClient::TImpl::Insert(const TString& table_name, const TBlock& block) { + if (Options_.PingBeforeQuery) { + RetryGuard([this]() { Ping(); }); + } + TVector<TString> fields; + fields.reserve(block.GetColumnCount()); + + // Enumerate all fields + for (TBlock::TIterator bi(block); bi.IsValid(); bi.Next()) { + fields.push_back(bi.Name()); + } + + TStringBuilder fields_section; + for (auto elem = fields.begin(); elem != fields.end(); ++elem) { + if (std::distance(elem, fields.end()) == 1) { + fields_section << *elem; + } else { + fields_section << *elem << ","; + } + } + + SendQuery("INSERT INTO " + table_name + " ( " + fields_section + " ) VALUES"); + + ui64 server_packet(0); + // Receive data packet. + while (true) { + bool ret = ReceivePacket(&server_packet); + + if (!ret) { + ythrow yexception() << "unable to receive data packet"; + } + if (server_packet == ServerCodes::Data) { + break; + } + if (server_packet == ServerCodes::Progress) { + continue; + } + } + + // Send data. + SendData(block); + // Send empty block as marker of + // end of data. + SendData(TBlock()); + + // Wait for EOS. + ui64 eos_packet{0}; + while (ReceivePacket(&eos_packet)) { + ; + } + + if (eos_packet != ServerCodes::EndOfStream && eos_packet != ServerCodes::Exception + && eos_packet != ServerCodes::Log && Options_.RethrowExceptions) { + ythrow yexception() << "unexpected packet from server while receiving end of query, expected (expected Exception, EndOfStream or Log, got: " + << (eos_packet ? ToString(eos_packet) : "nothing") << ")"; + } + } + + void TClient::TImpl::Ping() { + TWireFormat::WriteUInt64(&Output_, ClientCodes::Ping); + Output_.Flush(); + + ui64 server_packet; + const bool ret = ReceivePacket(&server_packet); + + if (!ret || server_packet != ServerCodes::Pong) { + ythrow yexception() << "fail to ping server"; + } + } + + void TClient::TImpl::ResetConnection() { + Socket_ = TSocket(TNetworkAddress(Options_.Host, Options_.Port), Options_.ConnectTimeout); + + if (Options_.UseSsl) { + SslClient_.Reset(new TOpenSslClientIO(&SocketInput_, &SocketOutput_, Options_.SslOptions)); + BufferedInput_.Reset(new TBufferedInput(SslClient_.Get())); + BufferedOutput_.Reset(new TBufferedOutput(SslClient_.Get())); + } else { + BufferedInput_.Reset(new TBufferedInput(&SocketInput_)); + BufferedOutput_.Reset(new TBufferedOutput(&SocketOutput_)); + } + + SocketInput_ = TSocketInput(Socket_); + SocketOutput_ = TSocketOutput(Socket_); + + Input_ = TCodedInputStream(BufferedInput_.Get()); + Output_ = TCodedOutputStream(BufferedOutput_.Get()); + + if (Options_.RequestTimeout.Seconds()) { + Socket_.SetSocketTimeout(Options_.RequestTimeout.Seconds()); + } + + if (!Handshake()) { + ythrow yexception() << "fail to connect to " << Options_.Host; + } + } + + bool TClient::TImpl::Handshake() { + if (!SendHello()) { + return false; + } + if (!ReceiveHello()) { + return false; + } + return true; + } + + bool TClient::TImpl::ReceivePacket(ui64* server_packet) { + ui64 packet_type = 0; + + if (!Input_.ReadVarint64(&packet_type)) { + return false; + } + if (server_packet) { + *server_packet = packet_type; + } + + switch (packet_type) { + case ServerCodes::Totals: + case ServerCodes::Data: { + if (!ReceiveData()) { + ythrow yexception() << "can't read data packet from input stream"; + } + return true; + } + + case ServerCodes::Exception: { + ReceiveException(); + return false; + } + + case ServerCodes::ProfileInfo: { + TProfile profile; + + if (!TWireFormat::ReadUInt64(&Input_, &profile.rows)) { + return false; + } + if (!TWireFormat::ReadUInt64(&Input_, &profile.blocks)) { + return false; + } + if (!TWireFormat::ReadUInt64(&Input_, &profile.bytes)) { + return false; + } + if (!TWireFormat::ReadFixed(&Input_, &profile.applied_limit)) { + return false; + } + if (!TWireFormat::ReadUInt64(&Input_, &profile.rows_before_limit)) { + return false; + } + if (!TWireFormat::ReadFixed(&Input_, &profile.calculated_rows_before_limit)) { + return false; + } + + if (Events_) { + Events_->OnProfile(profile); + } + + return true; + } + + case ServerCodes::Progress: { + TProgress info; + + if (!TWireFormat::ReadUInt64(&Input_, &info.rows)) { + return false; + } + if (!TWireFormat::ReadUInt64(&Input_, &info.bytes)) { + return false; + } + if (REVISION >= DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS) { + if (!TWireFormat::ReadUInt64(&Input_, &info.total_rows)) { + return false; + } + } + + if (Events_) { + Events_->OnProgress(info); + } + + return true; + } + + case ServerCodes::Pong: { + return true; + } + + case ServerCodes::EndOfStream: { + if (Events_) { + Events_->OnFinish(); + } + return false; + } + + default: + ythrow yexception() << "unimplemented " << (int)packet_type; + break; + } + + return false; + } + + bool TClient::TImpl::ReadBlock(TBlock* block, TCodedInputStream* input) { + // Additional information about block. + if (REVISION >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) { + ui64 num; + TBlockInfo info; + + // BlockInfo + if (!TWireFormat::ReadUInt64(input, &num)) { + return false; + } + if (!TWireFormat::ReadFixed(input, &info.IsOverflows)) { + return false; + } + if (!TWireFormat::ReadUInt64(input, &num)) { + return false; + } + if (!TWireFormat::ReadFixed(input, &info.BucketNum)) { + return false; + } + if (!TWireFormat::ReadUInt64(input, &num)) { + return false; + } + + // TODO use data + } + + ui64 num_columns = 0; + ui64 num_rows = 0; + + if (!TWireFormat::ReadUInt64(input, &num_columns)) { + return false; + } + if (!TWireFormat::ReadUInt64(input, &num_rows)) { + return false; + } + + for (size_t i = 0; i < num_columns; ++i) { + TString name; + TString type; + + if (!TWireFormat::ReadString(input, &name)) { + return false; + } + if (!TWireFormat::ReadString(input, &type)) { + return false; + } + + if (TColumnRef col = CreateColumnByType(type)) { + if (num_rows && !col->Load(input, num_rows)) { + ythrow yexception() << "can't load"; + } + + block->AppendColumn(name, col); + } else { + ythrow yexception() << "unsupported column type: " << type; + } + } + + return true; + } + + bool TClient::TImpl::ReceiveData() { + TBlock block; + + if (REVISION >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) { + TString table_name; + + if (!TWireFormat::ReadString(&Input_, &table_name)) { + return false; + } + } + + if (Compression_ == CompressionState::Enable) { + TCompressedInput compressed(&Input_); + TCodedInputStream coded(&compressed); + + if (!ReadBlock(&block, &coded)) { + return false; + } + } else { + if (!ReadBlock(&block, &Input_)) { + return false; + } + } + + if (Events_) { + Events_->OnData(block); + } + + return true; + } + + bool TClient::TImpl::ReceiveException(bool rethrow) { + std::unique_ptr<TException> e(new TException); + TException* current = e.get(); + + bool exception_received = true; + do { + bool has_nested = false; + + if (!TWireFormat::ReadFixed(&Input_, ¤t->Code)) { + exception_received = false; + break; + } + if (!TWireFormat::ReadString(&Input_, ¤t->Name)) { + exception_received = false; + break; + } + if (!TWireFormat::ReadString(&Input_, ¤t->DisplayText)) { + exception_received = false; + break; + } + if (!TWireFormat::ReadString(&Input_, ¤t->StackTrace)) { + exception_received = false; + break; + } + if (!TWireFormat::ReadFixed(&Input_, &has_nested)) { + exception_received = false; + break; + } + + if (has_nested) { + current->Nested.reset(new TException); + current = current->Nested.get(); + } else { + break; + } + } while (true); + + if (Events_) { + Events_->OnServerException(*e); + } + + if (rethrow || Options_.RethrowExceptions) { + throw TServerException(std::move(e)); + } + + return exception_received; + } + + void TClient::TImpl::SendQuery(const TString& query) { + TWireFormat::WriteUInt64(&Output_, ClientCodes::Query); + TWireFormat::WriteString(&Output_, TString()); + + /// Client info. + if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) { + TClientInfo info; + + info.QueryKind = 1; + info.ClientName = "ClickHouse client"; + info.ClientVersionMajor = DBMS_VERSION_MAJOR; + info.ClientVersionMinor = DBMS_VERSION_MINOR; + info.ClientRevision = REVISION; + + TWireFormat::WriteFixed(&Output_, info.QueryKind); + TWireFormat::WriteString(&Output_, info.InitialUser); + TWireFormat::WriteString(&Output_, info.InitialQueryId); + TWireFormat::WriteString(&Output_, info.InitialAddress); + TWireFormat::WriteFixed(&Output_, info.IfaceType); + + TWireFormat::WriteString(&Output_, info.OsUser); + TWireFormat::WriteString(&Output_, info.ClientHostname); + TWireFormat::WriteString(&Output_, info.ClientName); + TWireFormat::WriteUInt64(&Output_, info.ClientVersionMajor); + TWireFormat::WriteUInt64(&Output_, info.ClientVersionMinor); + TWireFormat::WriteUInt64(&Output_, info.ClientRevision); + + if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO) + TWireFormat::WriteString(&Output_, info.QuotaKey); + } + + /// Per query settings. + //if (settings) + // settings->serialize(*out); + //else + TWireFormat::WriteString(&Output_, TString()); + + TWireFormat::WriteUInt64(&Output_, Stages::Complete); + TWireFormat::WriteUInt64(&Output_, Compression_); + TWireFormat::WriteString(&Output_, query); + // Send empty block as marker of + // end of data + SendData(TBlock()); + + Output_.Flush(); + } + + void TClient::TImpl::WriteBlock(const TBlock& block, TCodedOutputStream* output) { + /// Дополнительная информация о блоке. + if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) { + TWireFormat::WriteUInt64(output, 1); + TWireFormat::WriteFixed(output, block.Info().IsOverflows); + TWireFormat::WriteUInt64(output, 2); + TWireFormat::WriteFixed(output, block.Info().BucketNum); + TWireFormat::WriteUInt64(output, 0); + } + + TWireFormat::WriteUInt64(output, block.GetColumnCount()); + TWireFormat::WriteUInt64(output, block.GetRowCount()); + + for (TBlock::TIterator bi(block); bi.IsValid(); bi.Next()) { + TWireFormat::WriteString(output, bi.Name()); + TWireFormat::WriteString(output, bi.Type()->GetName()); + + bi.Column()->Save(output); + } + } + + void TClient::TImpl::SendData(const TBlock& block) { + TWireFormat::WriteUInt64(&Output_, ClientCodes::Data); + + if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) { + TWireFormat::WriteString(&Output_, TString()); + } + + if (Compression_ == CompressionState::Enable) { + switch (Options_.CompressionMethod) { + case ECompressionMethod::None: { + Y_ABORT_UNLESS(false, "invalid state"); + break; + } + + case ECompressionMethod::LZ4: { + TBufferOutput tmp; + + // Serialize block's data + { + TCodedOutputStream out(&tmp); + WriteBlock(block, &out); + } + // Reserver space for data + TBuffer buf; + buf.Resize(9 + LZ4_compressBound(tmp.Buffer().Size())); + + // Compress data + int size = LZ4_compress(tmp.Buffer().Data(), buf.Data() + 9, tmp.Buffer().Size()); + buf.Resize(9 + size); + + // Fill header + ui8* p = (ui8*)buf.Data(); + // Compression method + WriteUnaligned<ui8>(p, (ui8)0x82); + p += 1; + // Compressed data size with header + WriteUnaligned<ui32>(p, (ui32)buf.Size()); + p += 4; + // Original data size + WriteUnaligned<ui32>(p, (ui32)tmp.Buffer().Size()); + + TWireFormat::WriteFixed(&Output_, CityHash_v1_0_2::CityHash128( + buf.Data(), buf.Size())); + TWireFormat::WriteBytes(&Output_, buf.Data(), buf.Size()); + break; + } + } + } else { + WriteBlock(block, &Output_); + } + + Output_.Flush(); + } + + bool TClient::TImpl::SendHello() { + TWireFormat::WriteUInt64(&Output_, ClientCodes::Hello); + TWireFormat::WriteString(&Output_, TString(DBMS_NAME) + " client"); + TWireFormat::WriteUInt64(&Output_, DBMS_VERSION_MAJOR); + TWireFormat::WriteUInt64(&Output_, DBMS_VERSION_MINOR); + TWireFormat::WriteUInt64(&Output_, REVISION); + TWireFormat::WriteString(&Output_, Options_.DefaultDatabase); + TWireFormat::WriteString(&Output_, Options_.User); + TWireFormat::WriteString(&Output_, Options_.Password); + + Output_.Flush(); + + return true; + } + + bool TClient::TImpl::ReceiveHello() { + ui64 packet_type = 0; + + if (!Input_.ReadVarint64(&packet_type)) { + return false; + } + + if (packet_type == ServerCodes::Hello) { + if (!TWireFormat::ReadString(&Input_, &ServerInfo_.Name)) { + return false; + } + if (!TWireFormat::ReadUInt64(&Input_, &ServerInfo_.VersionMajor)) { + return false; + } + if (!TWireFormat::ReadUInt64(&Input_, &ServerInfo_.VersionMinor)) { + return false; + } + if (!TWireFormat::ReadUInt64(&Input_, &ServerInfo_.Revision)) { + return false; + } + + if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE) { + if (!TWireFormat::ReadString(&Input_, &ServerInfo_.Timezone)) { + return false; + } + } + + return true; + } else if (packet_type == ServerCodes::Exception) { + ReceiveException(true); + return false; + } + + return false; + } + + void TClient::TImpl::RetryGuard(std::function<void()> func) { + for (int i = 0; i <= Options_.SendRetries; ++i) { + try { + func(); + return; + } catch (const yexception&) { + bool ok = true; + + try { + Sleep(Options_.RetryTimeout); + ResetConnection(); + } catch (...) { + ok = false; + } + + if (!ok) { + throw; + } + } + } + } + + TClient::TClient(const TClientOptions& opts) + : Options_(opts) + , Impl_(new TImpl(opts)) + { + } + + TClient::~TClient() { + } + + void TClient::Execute(const TQuery& query) { + Impl_->ExecuteQuery(query); + } + + void TClient::Select(const TString& query, TSelectCallback cb) { + Execute(TQuery(query).OnData(cb)); + } + + void TClient::Select(const TQuery& query) { + Execute(query); + } + + void TClient::Insert(const TString& table_name, const TBlock& block) { + Impl_->Insert(table_name, block); + } + + void TClient::Ping() { + Impl_->Ping(); + } + + void TClient::ResetConnection() { + Impl_->ResetConnection(); + } + +} |