#include "client.h"
#include "protocol.h"
#include <library/cpp/clickhouse/client/base/coded.h>
#include <library/cpp/clickhouse/client/base/compressed.h>
#include <library/cpp/clickhouse/client/base/wire_format.h>
#include <library/cpp/clickhouse/client/columns/factory.h>
#include <library/cpp/openssl/io/stream.h>
#include <util/generic/buffer.h>
#include <util/generic/vector.h>
#include <util/network/socket.h>
#include <util/random/random.h>
#include <util/stream/buffered.h>
#include <util/stream/buffer.h>
#include <util/stream/mem.h>
#include <util/string/builder.h>
#include <util/string/cast.h>
#include <util/system/unaligned_mem.h>
#include <contrib/libs/lz4/lz4.h>
#include <contrib/restricted/cityhash-1.0.2/city.h>

#include <memory>
// Client identity announced in the Hello packet (SendHello). The version numbers
// and REVISION are also sent in the client info part of each query (SendQuery).
#define DBMS_NAME "ClickHouse"
#define DBMS_VERSION_MAJOR 1
#define DBMS_VERSION_MINOR 1
// Protocol revision advertised by this client.
#define REVISION 54126
// Protocol revisions at which optional wire-format fields appear. Each is compared
// either against REVISION or against the server revision received in the Hello
// reply (ServerInfo_.Revision) before the corresponding field is read or written.
#define DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES 50264
#define DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS 51554
#define DBMS_MIN_REVISION_WITH_BLOCK_INFO 51903
#define DBMS_MIN_REVISION_WITH_CLIENT_INFO 54032
#define DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE 54058
#define DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO 54060
namespace NClickHouse {
/// Client information serialized into every Query packet by SendQuery() when the
/// server revision is at least DBMS_MIN_REVISION_WITH_CLIENT_INFO.
/// All members have defaults so a default-constructed instance is fully initialized.
struct TClientInfo {
    ui8 IfaceType = 1; // TCP
    /// Query kind; SendQuery() sets it to 1 (presumably "initial query" — verify
    /// against the ClickHouse protocol before relying on other values).
    ui8 QueryKind = 0;
    TString InitialUser;
    TString InitialQueryId;
    TString QuotaKey;
    TString OsUser;
    TString ClientHostname;
    TString ClientName;
    TString InitialAddress = "[::ffff:127.0.0.1]:0";
    ui64 ClientVersionMajor = 0;
    ui64 ClientVersionMinor = 0;
    ui32 ClientRevision = 0;
};
/// Server identity received in the Hello reply (filled by ReceiveHello()).
/// Numeric members start at 0 so a server that has not completed the handshake
/// is treated as the oldest possible revision instead of reading garbage.
struct TServerInfo {
    TString Name;
    /// Only sent by servers with revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE.
    TString Timezone;
    ui64 VersionMajor = 0;
    ui64 VersionMinor = 0;
    /// Server protocol revision; gates optional fields in packets the client writes.
    ui64 Revision = 0;
};
/// Connection-level implementation behind TClient: owns the socket and the stream
/// stack built on it (optional TLS, buffering, wire-format coding) and implements
/// the request/response exchanges of the native protocol.
class TClient::TImpl {
public:
    explicit TImpl(const TClientOptions& opts);
    ~TImpl();

    void ExecuteQuery(TQuery query);
    void Insert(const TString& table_name, const TBlock& block);
    void Ping();
    void ResetConnection();

private:
    bool Handshake();
    bool ReceivePacket(ui64* server_packet = nullptr);
    void SendQuery(const TString& query);
    void SendData(const TBlock& block);
    bool SendHello();
    bool ReadBlock(TBlock* block, TCodedInputStream* input);
    bool ReceiveHello();
    /// Reads a data packet from the input stream.
    bool ReceiveData();
    /// Reads an exception packet from the input stream.
    bool ReceiveException(bool rethrow = false);
    void WriteBlock(const TBlock& block, TCodedOutputStream* output);

private:
    void Disconnect() {
        Socket_ = TSocket();
    }

    /// In case of network errors tries to reconnect to the server and
    /// call func again, up to Options_.SendRetries more times.
    void RetryGuard(std::function<void()> func);

private:
    /// Scope guard: sets *ptr to ev on construction and back to nullptr on
    /// destruction (a null ptr makes it a no-op). Not copyable, since a copy
    /// would clear the slot a second time.
    class EnsureNull {
    public:
        inline EnsureNull(TQueryEvents* ev, TQueryEvents** ptr)
            : ptr_(ptr)
        {
            if (ptr_) {
                *ptr_ = ev;
            }
        }

        inline ~EnsureNull() {
            if (ptr_) {
                *ptr_ = nullptr;
            }
        }

        EnsureNull(const EnsureNull&) = delete;
        EnsureNull& operator=(const EnsureNull&) = delete;

    private:
        TQueryEvents** ptr_;
    };

    const TClientOptions Options_;
    /// Callbacks of the query being executed; nullptr outside ExecuteQuery().
    TQueryEvents* Events_ = nullptr;
    int Compression_ = CompressionState::Disable;

    // Transport stack, declared innermost first. Members are destroyed in reverse
    // declaration order, so every layer is torn down before the stream it wraps:
    // coded -> buffered -> TLS -> socket streams -> socket. SslClient_ must stay
    // declared before the buffered layers that point at it.
    TSocket Socket_;
    TSocketInput SocketInput_;
    TSocketOutput SocketOutput_;
    THolder<TOpenSslClientIO> SslClient_;
    THolder<TBufferedInput> BufferedInput_;
    THolder<TBufferedOutput> BufferedOutput_;
    TCodedInputStream Input_;
    TCodedOutputStream Output_;

    TServerInfo ServerInfo_;
};
/// Connects to opts.Host:opts.Port, builds the stream stack (optionally through
/// TLS) and performs the protocol handshake.
/// Throws yexception if the handshake does not succeed.
TClient::TImpl::TImpl(const TClientOptions& opts)
    : Options_(opts)
    , Events_(nullptr)
    , Socket_(TNetworkAddress(opts.Host, opts.Port), Options_.ConnectTimeout)
    , SocketInput_(Socket_)
    , SocketOutput_(Socket_)
{
    // Streams the buffering layer sits on: the raw socket streams, or a TLS
    // client wrapped around them when SSL is requested.
    IInputStream* transportIn = &SocketInput_;
    IOutputStream* transportOut = &SocketOutput_;
    if (opts.UseSsl) {
        SslClient_ = MakeHolder<TOpenSslClientIO>(&SocketInput_, &SocketOutput_, opts.SslOptions);
        transportIn = SslClient_.Get();
        transportOut = SslClient_.Get();
    }
    BufferedInput_ = MakeHolder<TBufferedInput>(transportIn);
    BufferedOutput_ = MakeHolder<TBufferedOutput>(transportOut);

    Input_ = TCodedInputStream(BufferedInput_.Get());
    Output_ = TCodedOutputStream(BufferedOutput_.Get());

    // A RequestTimeout below one whole second leaves the socket timeout untouched.
    const auto timeoutSeconds = Options_.RequestTimeout.Seconds();
    if (timeoutSeconds) {
        Socket_.SetSocketTimeout(timeoutSeconds);
    }

    if (!Handshake()) {
        ythrow yexception() << "fail to connect to " << Options_.Host;
    }

    // Compression applies to the queries that follow; it is switched on only
    // once the handshake has succeeded.
    if (Options_.CompressionMethod != ECompressionMethod::None) {
        Compression_ = CompressionState::Enable;
    }
}
/// Drops the connection; the stream layers are released afterwards by the
/// member destructors.
TClient::TImpl::~TImpl()
{
    Disconnect();
}
/// Sends `query` and consumes the server's response, dispatching packets to the
/// query's event callbacks.
/// Throws yexception if the response does not end with EndOfStream or Exception.
void TClient::TImpl::ExecuteQuery(TQuery query) {
    // Route server callbacks to this query for the duration of the call;
    // the guard clears Events_ again on every exit path.
    EnsureNull eventsScope(static_cast<TQueryEvents*>(&query), &Events_);

    if (Options_.PingBeforeQuery) {
        RetryGuard([this]() { Ping(); });
    }

    SendQuery(query.GetText());

    // Drain packets until ReceivePacket() reports the end of the response;
    // lastPacket keeps the type of the final packet seen (0 if none arrived).
    ui64 lastPacket = 0;
    for (;;) {
        if (!ReceivePacket(&lastPacket)) {
            break;
        }
    }

    const bool endedCleanly =
        lastPacket == ServerCodes::EndOfStream || lastPacket == ServerCodes::Exception;
    if (!endedCleanly) {
        ythrow yexception() << "unexpected packet from server while receiving end of query (got: "
                            << (lastPacket ? ToString(lastPacket) : "nothing") << ")";
    }
}
/// Inserts `block` into `table_name`. It issues
/// "INSERT INTO <table> ( <columns> ) VALUES" with the block's column names,
/// waits for the server's data packet, then sends the block followed by an empty
/// end-of-data block.
/// Throws yexception if no data packet arrives. When Options_.RethrowExceptions is
/// set, it also throws if the response ends with an unexpected packet.
void TClient::TImpl::Insert(const TString& table_name, const TBlock& block) {
    if (Options_.PingBeforeQuery) {
        RetryGuard([this]() { Ping(); });
    }

    // Comma-separated column names in block order, built in a single pass.
    TStringBuilder fields_section;
    bool first_field = true;
    for (TBlock::TIterator bi(block); bi.IsValid(); bi.Next()) {
        if (!first_field) {
            fields_section << ",";
        }
        fields_section << bi.Name();
        first_field = false;
    }

    SendQuery("INSERT INTO " + table_name + " ( " + fields_section + " ) VALUES");

    // Wait for the server's data packet (presumably describing the target columns
    // — verify). Progress and other non-terminal packets before it are skipped.
    ui64 server_packet = 0;
    do {
        if (!ReceivePacket(&server_packet)) {
            ythrow yexception() << "unable to receive data packet";
        }
    } while (server_packet != ServerCodes::Data);

    // Send the data, then an empty block as the end-of-data marker.
    SendData(block);
    SendData(TBlock());

    // Wait for the end of the response.
    ui64 eos_packet = 0;
    while (ReceivePacket(&eos_packet)) {
    }

    const bool expected_end = eos_packet == ServerCodes::EndOfStream
        || eos_packet == ServerCodes::Exception
        || eos_packet == ServerCodes::Log;
    if (!expected_end && Options_.RethrowExceptions) {
        ythrow yexception() << "unexpected packet from server while receiving end of query (expected Exception, EndOfStream or Log, got: "
                            << (eos_packet ? ToString(eos_packet) : "nothing") << ")";
    }
}
void TClient::TImpl::Ping() {
TWireFormat::WriteUInt64(&Output_, ClientCodes::Ping);
Output_.Flush();
ui64 server_packet;
const bool ret = ReceivePacket(&server_packet);
if (!ret || server_packet != ServerCodes::Pong) {
ythrow yexception() << "fail to ping server";
}
}
void TClient::TImpl::ResetConnection() {
Socket_ = TSocket(TNetworkAddress(Options_.Host, Options_.Port), Options_.ConnectTimeout);
if (Options_.UseSsl) {
SslClient_.Reset(new TOpenSslClientIO(&SocketInput_, &SocketOutput_, Options_.SslOptions));
BufferedInput_.Reset(new TBufferedInput(SslClient_.Get()));
BufferedOutput_.Reset(new TBufferedOutput(SslClient_.Get()));
} else {
BufferedInput_.Reset(new TBufferedInput(&SocketInput_));
BufferedOutput_.Reset(new TBufferedOutput(&SocketOutput_));
}
SocketInput_ = TSocketInput(Socket_);
SocketOutput_ = TSocketOutput(Socket_);
Input_ = TCodedInputStream(BufferedInput_.Get());
Output_ = TCodedOutputStream(BufferedOutput_.Get());
if (Options_.RequestTimeout.Seconds()) {
Socket_.SetSocketTimeout(Options_.RequestTimeout.Seconds());
}
if (!Handshake()) {
ythrow yexception() << "fail to connect to " << Options_.Host;
}
}
/// Exchanges Hello packets with the server.
/// Returns true only if our Hello was sent and the server's Hello was parsed.
/// ReceiveHello is not attempted when sending fails.
bool TClient::TImpl::Handshake() {
    const bool ok = SendHello() && ReceiveHello();
    return ok;
}
/// Reads one server packet and dispatches its payload to Events_ (if set).
/// The packet type is stored in *server_packet (when non-null) as soon as it is read.
/// Returns true while more packets of the current response are expected. Returns
/// false when the response ends (EndOfStream, Exception) or a read fails.
/// Throws on a malformed data packet or an unknown packet type.
bool TClient::TImpl::ReceivePacket(ui64* server_packet) {
    ui64 code = 0;
    if (!Input_.ReadVarint64(&code)) {
        return false;
    }
    if (server_packet != nullptr) {
        *server_packet = code;
    }

    switch (code) {
        case ServerCodes::Data:
        case ServerCodes::Totals: {
            if (!ReceiveData()) {
                ythrow yexception() << "can't read data packet from input stream";
            }
            return true;
        }

        case ServerCodes::Exception: {
            // ReceiveException() throws when Options_.RethrowExceptions is set;
            // otherwise the exception simply ends the response.
            ReceiveException();
            return false;
        }

        case ServerCodes::ProfileInfo: {
            TProfile profile;
            // && keeps the original read order and stops at the first failure.
            const bool complete =
                TWireFormat::ReadUInt64(&Input_, &profile.rows) &&
                TWireFormat::ReadUInt64(&Input_, &profile.blocks) &&
                TWireFormat::ReadUInt64(&Input_, &profile.bytes) &&
                TWireFormat::ReadFixed(&Input_, &profile.applied_limit) &&
                TWireFormat::ReadUInt64(&Input_, &profile.rows_before_limit) &&
                TWireFormat::ReadFixed(&Input_, &profile.calculated_rows_before_limit);
            if (!complete) {
                return false;
            }
            if (Events_) {
                Events_->OnProfile(profile);
            }
            return true;
        }

        case ServerCodes::Progress: {
            TProgress progress;
            bool complete =
                TWireFormat::ReadUInt64(&Input_, &progress.rows) &&
                TWireFormat::ReadUInt64(&Input_, &progress.bytes);
            if (complete && REVISION >= DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS) {
                complete = TWireFormat::ReadUInt64(&Input_, &progress.total_rows);
            }
            if (!complete) {
                return false;
            }
            if (Events_) {
                Events_->OnProgress(progress);
            }
            return true;
        }

        case ServerCodes::Pong:
            return true;

        case ServerCodes::EndOfStream: {
            if (Events_) {
                Events_->OnFinish();
            }
            return false;
        }

        default:
            break;
    }

    ythrow yexception() << "unimplemented " << static_cast<int>(code);
}
/// Parses one block from `input` into `block`. The block consists of the optional
/// block info, the column and row counts, and then each column's name, type and data.
/// Returns false if the stream ends prematurely. Throws on an unknown block info
/// field, an unsupported column type, or column data that fails to load.
bool TClient::TImpl::ReadBlock(TBlock* block, TCodedInputStream* input) {
    // Additional information about the block.
    if (REVISION >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
        // Block info is a list of (field number, value) pairs terminated by field
        // number 0. This is the layout WriteBlock() produces: 1 = IsOverflows,
        // 2 = BucketNum. Reading by field number, instead of assuming a fixed
        // sequence, keeps the stream in sync and rejects unknown fields loudly.
        TBlockInfo info;
        for (;;) {
            ui64 field_num = 0;
            if (!TWireFormat::ReadUInt64(input, &field_num)) {
                return false;
            }
            if (field_num == 0) {
                break;
            }
            bool field_read = false;
            if (field_num == 1) {
                field_read = TWireFormat::ReadFixed(input, &info.IsOverflows);
            } else if (field_num == 2) {
                field_read = TWireFormat::ReadFixed(input, &info.BucketNum);
            } else {
                ythrow yexception() << "unknown block info field number: " << field_num;
            }
            if (!field_read) {
                return false;
            }
        }
        // TODO use data
    }

    ui64 num_columns = 0;
    ui64 num_rows = 0;
    if (!TWireFormat::ReadUInt64(input, &num_columns)) {
        return false;
    }
    if (!TWireFormat::ReadUInt64(input, &num_rows)) {
        return false;
    }

    for (ui64 i = 0; i < num_columns; ++i) {
        TString name;
        TString type;
        if (!TWireFormat::ReadString(input, &name)) {
            return false;
        }
        if (!TWireFormat::ReadString(input, &type)) {
            return false;
        }

        TColumnRef col = CreateColumnByType(type);
        if (!col) {
            ythrow yexception() << "unsupported column type: " << type;
        }
        // A block with zero rows carries only column names and types, no data.
        if (num_rows && !col->Load(input, num_rows)) {
            ythrow yexception() << "can't load";
        }
        block->AppendColumn(name, col);
    }
    return true;
}
bool TClient::TImpl::ReceiveData() {
TBlock block;
if (REVISION >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
TString table_name;
if (!TWireFormat::ReadString(&Input_, &table_name)) {
return false;
}
}
if (Compression_ == CompressionState::Enable) {
TCompressedInput compressed(&Input_);
TCodedInputStream coded(&compressed);
if (!ReadBlock(&block, &coded)) {
return false;
}
} else {
if (!ReadBlock(&block, &Input_)) {
return false;
}
}
if (Events_) {
Events_->OnData(block);
}
return true;
}
bool TClient::TImpl::ReceiveException(bool rethrow) {
std::unique_ptr<TException> e(new TException);
TException* current = e.get();
bool exception_received = true;
do {
bool has_nested = false;
if (!TWireFormat::ReadFixed(&Input_, ¤t->Code)) {
exception_received = false;
break;
}
if (!TWireFormat::ReadString(&Input_, ¤t->Name)) {
exception_received = false;
break;
}
if (!TWireFormat::ReadString(&Input_, ¤t->DisplayText)) {
exception_received = false;
break;
}
if (!TWireFormat::ReadString(&Input_, ¤t->StackTrace)) {
exception_received = false;
break;
}
if (!TWireFormat::ReadFixed(&Input_, &has_nested)) {
exception_received = false;
break;
}
if (has_nested) {
current->Nested.reset(new TException);
current = current->Nested.get();
} else {
break;
}
} while (true);
if (Events_) {
Events_->OnServerException(*e);
}
if (rethrow || Options_.RethrowExceptions) {
throw TServerException(std::move(e));
}
return exception_received;
}
void TClient::TImpl::SendQuery(const TString& query) {
TWireFormat::WriteUInt64(&Output_, ClientCodes::Query);
TWireFormat::WriteString(&Output_, TString());
/// Client info.
if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) {
TClientInfo info;
info.QueryKind = 1;
info.ClientName = "ClickHouse client";
info.ClientVersionMajor = DBMS_VERSION_MAJOR;
info.ClientVersionMinor = DBMS_VERSION_MINOR;
info.ClientRevision = REVISION;
TWireFormat::WriteFixed(&Output_, info.QueryKind);
TWireFormat::WriteString(&Output_, info.InitialUser);
TWireFormat::WriteString(&Output_, info.InitialQueryId);
TWireFormat::WriteString(&Output_, info.InitialAddress);
TWireFormat::WriteFixed(&Output_, info.IfaceType);
TWireFormat::WriteString(&Output_, info.OsUser);
TWireFormat::WriteString(&Output_, info.ClientHostname);
TWireFormat::WriteString(&Output_, info.ClientName);
TWireFormat::WriteUInt64(&Output_, info.ClientVersionMajor);
TWireFormat::WriteUInt64(&Output_, info.ClientVersionMinor);
TWireFormat::WriteUInt64(&Output_, info.ClientRevision);
if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO)
TWireFormat::WriteString(&Output_, info.QuotaKey);
}
/// Per query settings.
//if (settings)
// settings->serialize(*out);
//else
TWireFormat::WriteString(&Output_, TString());
TWireFormat::WriteUInt64(&Output_, Stages::Complete);
TWireFormat::WriteUInt64(&Output_, Compression_);
TWireFormat::WriteString(&Output_, query);
// Send empty block as marker of
// end of data
SendData(TBlock());
Output_.Flush();
}
void TClient::TImpl::WriteBlock(const TBlock& block, TCodedOutputStream* output) {
/// Дополнительная информация о блоке.
if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
TWireFormat::WriteUInt64(output, 1);
TWireFormat::WriteFixed(output, block.Info().IsOverflows);
TWireFormat::WriteUInt64(output, 2);
TWireFormat::WriteFixed(output, block.Info().BucketNum);
TWireFormat::WriteUInt64(output, 0);
}
TWireFormat::WriteUInt64(output, block.GetColumnCount());
TWireFormat::WriteUInt64(output, block.GetRowCount());
for (TBlock::TIterator bi(block); bi.IsValid(); bi.Next()) {
TWireFormat::WriteString(output, bi.Name());
TWireFormat::WriteString(output, bi.Type()->GetName());
bi.Column()->Save(output);
}
}
void TClient::TImpl::SendData(const TBlock& block) {
TWireFormat::WriteUInt64(&Output_, ClientCodes::Data);
if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
TWireFormat::WriteString(&Output_, TString());
}
if (Compression_ == CompressionState::Enable) {
switch (Options_.CompressionMethod) {
case ECompressionMethod::None: {
Y_ABORT_UNLESS(false, "invalid state");
break;
}
case ECompressionMethod::LZ4: {
TBufferOutput tmp;
// Serialize block's data
{
TCodedOutputStream out(&tmp);
WriteBlock(block, &out);
}
// Reserver space for data
TBuffer buf;
buf.Resize(9 + LZ4_compressBound(tmp.Buffer().Size()));
// Compress data
int size = LZ4_compress(tmp.Buffer().Data(), buf.Data() + 9, tmp.Buffer().Size());
buf.Resize(9 + size);
// Fill header
ui8* p = (ui8*)buf.Data();
// Compression method
WriteUnaligned<ui8>(p, (ui8)0x82);
p += 1;
// Compressed data size with header
WriteUnaligned<ui32>(p, (ui32)buf.Size());
p += 4;
// Original data size
WriteUnaligned<ui32>(p, (ui32)tmp.Buffer().Size());
TWireFormat::WriteFixed(&Output_, CityHash_v1_0_2::CityHash128(
buf.Data(), buf.Size()));
TWireFormat::WriteBytes(&Output_, buf.Data(), buf.Size());
break;
}
}
} else {
WriteBlock(block, &Output_);
}
Output_.Flush();
}
bool TClient::TImpl::SendHello() {
TWireFormat::WriteUInt64(&Output_, ClientCodes::Hello);
TWireFormat::WriteString(&Output_, TString(DBMS_NAME) + " client");
TWireFormat::WriteUInt64(&Output_, DBMS_VERSION_MAJOR);
TWireFormat::WriteUInt64(&Output_, DBMS_VERSION_MINOR);
TWireFormat::WriteUInt64(&Output_, REVISION);
TWireFormat::WriteString(&Output_, Options_.DefaultDatabase);
TWireFormat::WriteString(&Output_, Options_.User);
TWireFormat::WriteString(&Output_, Options_.Password);
Output_.Flush();
return true;
}
bool TClient::TImpl::ReceiveHello() {
ui64 packet_type = 0;
if (!Input_.ReadVarint64(&packet_type)) {
return false;
}
if (packet_type == ServerCodes::Hello) {
if (!TWireFormat::ReadString(&Input_, &ServerInfo_.Name)) {
return false;
}
if (!TWireFormat::ReadUInt64(&Input_, &ServerInfo_.VersionMajor)) {
return false;
}
if (!TWireFormat::ReadUInt64(&Input_, &ServerInfo_.VersionMinor)) {
return false;
}
if (!TWireFormat::ReadUInt64(&Input_, &ServerInfo_.Revision)) {
return false;
}
if (ServerInfo_.Revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE) {
if (!TWireFormat::ReadString(&Input_, &ServerInfo_.Timezone)) {
return false;
}
}
return true;
} else if (packet_type == ServerCodes::Exception) {
ReceiveException(true);
return false;
}
return false;
}
/// Calls `func`, retrying on yexception up to Options_.SendRetries more times.
/// Before each retry it sleeps Options_.RetryTimeout and reconnects via
/// ResetConnection().
/// `func` is always called at least once.
/// Rethrows the last error from `func` when retries are exhausted or when
/// reconnecting fails. It never returns normally unless `func` has succeeded.
void TClient::TImpl::RetryGuard(std::function<void()> func) {
    for (int attempt = 0;; ++attempt) {
        try {
            func();
            return;
        } catch (const yexception&) {
            if (attempt >= Options_.SendRetries) {
                // No retries left: surface the failure instead of swallowing it.
                throw;
            }
            bool reconnected = true;
            try {
                Sleep(Options_.RetryTimeout);
                ResetConnection();
            } catch (...) {
                reconnected = false;
            }
            if (!reconnected) {
                // Rethrows the error from func(); the inner handler has completed.
                throw;
            }
        }
    }
}
/// Public facade: every call is forwarded to the connection implementation.
TClient::TClient(const TClientOptions& opts)
    : Options_(opts)
    , Impl_(new TImpl(opts))
{
}

// Defined here rather than in the header, presumably so that Impl_ can destroy
// the (complete) TImpl type.
TClient::~TClient() = default;

void TClient::Execute(const TQuery& query) {
    Impl_->ExecuteQuery(query);
}

/// Runs `query` with `cb` registered through TQuery::OnData.
void TClient::Select(const TString& query, TSelectCallback cb) {
    Execute(TQuery(query).OnData(cb));
}

void TClient::Select(const TQuery& query) {
    Execute(query);
}

void TClient::Insert(const TString& table_name, const TBlock& block) {
    Impl_->Insert(table_name, block);
}

void TClient::Ping() {
    Impl_->Ping();
}

void TClient::ResetConnection() {
    Impl_->ResetConnection();
}
}