aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/Server
diff options
context:
space:
mode:
author vitalyisaev <vitalyisaev@ydb.tech> 2023-11-14 09:58:56 +0300
committer vitalyisaev <vitalyisaev@ydb.tech> 2023-11-14 10:20:20 +0300
commit c2b2dfd9827a400a8495e172a56343462e3ceb82 (patch)
tree cd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/Server
parent d4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff)
download ydb-c2b2dfd9827a400a8495e172a56343462e3ceb82.tar.gz
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/Server')
-rw-r--r--contrib/clickhouse/src/Server/CertificateReloader.cpp131
-rw-r--r--contrib/clickhouse/src/Server/CertificateReloader.h84
-rw-r--r--contrib/clickhouse/src/Server/GRPCServer.cpp1898
-rw-r--r--contrib/clickhouse/src/Server/GRPCServer.h56
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTMLForm.cpp347
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTMLForm.h124
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPContext.h24
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPRequest.h10
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPRequestHandler.h19
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPRequestHandlerFactory.h20
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPResponse.h10
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServer.cpp30
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServer.h33
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.cpp121
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.h51
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.cpp24
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.h26
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.cpp174
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.h73
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.cpp121
-rw-r--r--contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.h70
-rw-r--r--contrib/clickhouse/src/Server/HTTP/README.md3
-rw-r--r--contrib/clickhouse/src/Server/HTTP/ReadHeaders.cpp85
-rw-r--r--contrib/clickhouse/src/Server/HTTP/ReadHeaders.h13
-rw-r--r--contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.cpp204
-rw-r--r--contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.h134
-rw-r--r--contrib/clickhouse/src/Server/HTTPHandler.cpp1320
-rw-r--r--contrib/clickhouse/src/Server/HTTPHandler.h173
-rw-r--r--contrib/clickhouse/src/Server/HTTPHandlerFactory.cpp186
-rw-r--r--contrib/clickhouse/src/Server/HTTPHandlerFactory.h148
-rw-r--r--contrib/clickhouse/src/Server/HTTPHandlerRequestFilter.h102
-rw-r--r--contrib/clickhouse/src/Server/HTTPPathHints.cpp16
-rw-r--r--contrib/clickhouse/src/Server/HTTPPathHints.h22
-rw-r--r--contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.cpp38
-rw-r--r--contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.h31
-rw-r--r--contrib/clickhouse/src/Server/IServer.h39
-rw-r--r--contrib/clickhouse/src/Server/InterserverIOHTTPHandler.cpp163
-rw-r--r--contrib/clickhouse/src/Server/InterserverIOHTTPHandler.h51
-rw-r--r--contrib/clickhouse/src/Server/KeeperTCPHandler.cpp695
-rw-r--r--contrib/clickhouse/src/Server/KeeperTCPHandler.h113
-rw-r--r--contrib/clickhouse/src/Server/MySQLHandler.cpp508
-rw-r--r--contrib/clickhouse/src/Server/MySQLHandler.h111
-rw-r--r--contrib/clickhouse/src/Server/MySQLHandlerFactory.cpp140
-rw-r--r--contrib/clickhouse/src/Server/MySQLHandlerFactory.h50
-rw-r--r--contrib/clickhouse/src/Server/NotFoundHandler.cpp31
-rw-r--r--contrib/clickhouse/src/Server/NotFoundHandler.h18
-rw-r--r--contrib/clickhouse/src/Server/PostgreSQLHandler.cpp329
-rw-r--r--contrib/clickhouse/src/Server/PostgreSQLHandler.h81
-rw-r--r--contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.cpp26
-rw-r--r--contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.h33
-rw-r--r--contrib/clickhouse/src/Server/PrometheusMetricsWriter.cpp161
-rw-r--r--contrib/clickhouse/src/Server/PrometheusMetricsWriter.h38
-rw-r--r--contrib/clickhouse/src/Server/PrometheusRequestHandler.cpp71
-rw-r--r--contrib/clickhouse/src/Server/PrometheusRequestHandler.h28
-rw-r--r--contrib/clickhouse/src/Server/ProtocolServerAdapter.cpp75
-rw-r--r--contrib/clickhouse/src/Server/ProtocolServerAdapter.h74
-rw-r--r--contrib/clickhouse/src/Server/ProxyV1Handler.cpp127
-rw-r--r--contrib/clickhouse/src/Server/ProxyV1Handler.h30
-rw-r--r--contrib/clickhouse/src/Server/ProxyV1HandlerFactory.h56
-rw-r--r--contrib/clickhouse/src/Server/ReplicasStatusHandler.cpp128
-rw-r--r--contrib/clickhouse/src/Server/ReplicasStatusHandler.h21
-rw-r--r--contrib/clickhouse/src/Server/ServerType.cpp153
-rw-r--r--contrib/clickhouse/src/Server/ServerType.h60
-rw-r--r--contrib/clickhouse/src/Server/StaticRequestHandler.cpp179
-rw-r--r--contrib/clickhouse/src/Server/StaticRequestHandler.h35
-rw-r--r--contrib/clickhouse/src/Server/TCPHandler.cpp2151
-rw-r--r--contrib/clickhouse/src/Server/TCPHandler.h294
-rw-r--r--contrib/clickhouse/src/Server/TCPHandlerFactory.h74
-rw-r--r--contrib/clickhouse/src/Server/TCPProtocolStackData.h22
-rw-r--r--contrib/clickhouse/src/Server/TCPProtocolStackFactory.h92
-rw-r--r--contrib/clickhouse/src/Server/TCPProtocolStackHandler.h46
-rw-r--r--contrib/clickhouse/src/Server/TCPServer.cpp36
-rw-r--r--contrib/clickhouse/src/Server/TCPServer.h47
-rw-r--r--contrib/clickhouse/src/Server/TCPServerConnectionFactory.h32
-rw-r--r--contrib/clickhouse/src/Server/TLSHandler.h58
-rw-r--r--contrib/clickhouse/src/Server/TLSHandlerFactory.h64
-rw-r--r--contrib/clickhouse/src/Server/WebUIRequestHandler.cpp75
-rw-r--r--contrib/clickhouse/src/Server/WebUIRequestHandler.h22
-rw-r--r--contrib/clickhouse/src/Server/waitServersToFinish.cpp39
-rw-r--r--contrib/clickhouse/src/Server/waitServersToFinish.h10
80 files changed, 12607 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/Server/CertificateReloader.cpp b/contrib/clickhouse/src/Server/CertificateReloader.cpp
new file mode 100644
index 0000000000..8795d4807d
--- /dev/null
+++ b/contrib/clickhouse/src/Server/CertificateReloader.cpp
@@ -0,0 +1,131 @@
+#include "CertificateReloader.h"
+
+#if USE_SSL
+
+#include <Common/logger_useful.h>
+#include <base/errnoToString.h>
+#include <Poco/Net/Context.h>
+#include <Poco/Net/SSLManager.h>
+#include <Poco/Net/Utility.h>
+
+
+namespace DB
+{
+
+namespace
+{
+/// Trampoline registered with SSL_CTX_set_cert_cb: forwards to the CertificateReloader singleton; `arg` is unused.
+int callSetCertificate(SSL * ssl, [[maybe_unused]] void * arg)
+{
+ return CertificateReloader::instance().setCertificate(ssl);
+}
+
+}
+
+
+/// Callback for OpenSSL, called on every connection: installs the current certificate and private key on `ssl`. Returns 1 on success, -1 when no key-pair is loaded or it cannot be used.
+int CertificateReloader::setCertificate(SSL * ssl)
+{
+    auto current = data.get();
+    if (!current)
+        return -1;
+
+    /// Each of these OpenSSL calls returns 1 on success; stop at the first failure instead of ignoring it.
+    const bool usable = SSL_use_certificate(ssl, const_cast<X509 *>(current->cert.certificate())) == 1
+        && SSL_use_PrivateKey(ssl, const_cast<EVP_PKEY *>(static_cast<const EVP_PKEY *>(current->key))) == 1
+        && SSL_check_private_key(ssl) == 1;
+    if (!usable)
+    {
+        std::string msg = Poco::Net::Utility::getLastError();
+        LOG_ERROR(log, "Unusable key-pair {}", msg);
+        return -1;
+    }
+
+    return 1;
+}
+
+
+void CertificateReloader::init()
+{
+ LOG_DEBUG(log, "Initializing certificate reloader.");
+
+ /// Register a callback on the default server SSL context, so that OpenSSL obtains the current
+ /// certificate and key from this reloader. No user data is passed: the callback uses the singleton.
+ auto* ctx = Poco::Net::SSLManager::instance().defaultServerContext()->sslContext();
+ SSL_CTX_set_cert_cb(ctx, callSetCertificate, nullptr);
+ init_was_not_made = false; /// The callback is installed; tryLoad() will not call init() again.
+}
+
+
+void CertificateReloader::tryLoad(const Poco::Util::AbstractConfiguration & config)
+{
+ /// Reads the certificate and key paths from the config and reloads the key-pair if a path or a file's modification time changed.
+
+ std::string new_cert_path = config.getString("openSSL.server.certificateFile", "");
+ std::string new_key_path = config.getString("openSSL.server.privateKeyFile", "");
+
+ /// Empty paths mean that the user doesn't want to use certificates:
+ /// nothing is loaded and the OpenSSL callback is not installed.
+
+ if (new_cert_path.empty() || new_key_path.empty())
+ {
+ LOG_INFO(log, "One of paths is empty. Cannot apply new configuration for certificates. Fill all paths and try again.");
+ }
+ else
+ {
+ bool cert_file_changed = cert_file.changeIfModified(std::move(new_cert_path), log);
+ bool key_file_changed = key_file.changeIfModified(std::move(new_key_path), log);
+ std::string pass_phrase = config.getString("openSSL.server.privateKeyPassphraseHandler.options.password", "");
+ /// NOTE(review): changeIfModified() has already stored the new path/time; if loading below throws, the next call will not retry until a file changes again - confirm this is intended.
+ if (cert_file_changed || key_file_changed)
+ {
+ LOG_DEBUG(log, "Reloading certificate ({}) and key ({}).", cert_file.path, key_file.path);
+ data.set(std::make_unique<const Data>(cert_file.path, key_file.path, pass_phrase));
+ LOG_INFO(log, "Reloaded certificate ({}) and key ({}).", cert_file.path, key_file.path);
+ }
+
+ /// Install the OpenSSL callback if it is not installed yet. A failure is only logged, and init() is retried on the next call.
+ try
+ {
+ if (init_was_not_made)
+ init();
+ }
+ catch (...)
+ {
+ init_was_not_made = true;
+ LOG_ERROR(log, getCurrentExceptionMessageAndPattern(/* with_stacktrace */ false));
+ }
+ }
+}
+
+
+CertificateReloader::Data::Data(std::string cert_path, std::string key_path, std::string pass_phrase)
+ : cert(cert_path), key(/* public key */ "", /* private key */ key_path, pass_phrase) /// Only a private key file is given; the public key argument is left empty.
+{
+}
+
+
+/// Stores `new_path` and its last-write time if either differs from the stored values. Returns true if updated; false if unchanged or if the time cannot be read (logged).
+bool CertificateReloader::File::changeIfModified(std::string new_path, Poco::Logger * logger)
+{
+    std::error_code ec;
+    const std::filesystem::file_time_type new_modification_time = std::filesystem::last_write_time(new_path, ec);
+    if (ec)
+    {
+        LOG_ERROR(logger, "Cannot obtain modification time for {} file {}, skipping update. {}",
+            description, new_path, errnoToString(ec.value()));
+        return false;
+    }
+
+    if (new_path == path && new_modification_time == modification_time)
+        return false;
+
+    /// `new_path` is a by-value sink parameter that is not used afterwards, so move it instead of copying.
+    path = std::move(new_path);
+    modification_time = new_modification_time;
+    return true;
+}
+
+}
+
+#endif
diff --git a/contrib/clickhouse/src/Server/CertificateReloader.h b/contrib/clickhouse/src/Server/CertificateReloader.h
new file mode 100644
index 0000000000..e4db674c14
--- /dev/null
+++ b/contrib/clickhouse/src/Server/CertificateReloader.h
@@ -0,0 +1,84 @@
+#pragma once
+
+#include "clickhouse_config.h"
+
+#if USE_SSL
+
+#include <string>
+#include <filesystem>
+
+#include <Poco/Logger.h>
+#include <Poco/Util/AbstractConfiguration.h>
+#include <openssl/ssl.h>
+#include <openssl/x509v3.h>
+#include <Poco/Crypto/RSAKey.h>
+#include <Poco/Crypto/X509Certificate.h>
+#include <Common/MultiVersion.h>
+
+
+namespace DB
+{
+
+/// The CertificateReloader singleton performs 2 functions:
+/// 1. Dynamic reloading of TLS key-pair when requested by server:
+/// Server config reloader notifies CertificateReloader when the config changes.
+/// On changed config, CertificateReloader reloads certs from disk.
+/// 2. Implement `SSL_CTX_set_cert_cb` to set certificate for a new connection:
+/// OpenSSL invokes a callback to setup a connection.
+class CertificateReloader
+{
+public:
+ using stat_t = struct stat; /// NOTE(review): not referenced anywhere else in this class - possibly a leftover.
+
+ /// Singleton: not copyable, accessed through instance().
+ CertificateReloader(CertificateReloader const &) = delete;
+ void operator=(CertificateReloader const &) = delete;
+ static CertificateReloader & instance()
+ {
+ static CertificateReloader instance;
+ return instance;
+ }
+
+ /// Install the OpenSSL certificate callback on the default server context (certificates themselves are loaded by tryLoad()).
+ void init();
+
+ /// Handle configuration reload: reload the key-pair if its paths or files changed, and call init() if that was not done yet.
+ void tryLoad(const Poco::Util::AbstractConfiguration & config);
+
+ /// A callback for OpenSSL: sets the current certificate and private key for `ssl`. Returns 1 on success, -1 on failure.
+ int setCertificate(SSL * ssl);
+
+private:
+ CertificateReloader() = default;
+
+ Poco::Logger * log = &Poco::Logger::get("CertificateReloader");
+
+ struct File /// A watched file: its path and the last seen modification time.
+ {
+ const char * description; /// Used in log messages only.
+ explicit File(const char * description_) : description(description_) {}
+
+ std::string path;
+ std::filesystem::file_time_type modification_time;
+
+ bool changeIfModified(std::string new_path, Poco::Logger * logger); /// Returns true (and stores the new values) if the path or modification time changed.
+ };
+
+ File cert_file{"certificate"};
+ File key_file{"key"};
+
+ struct Data /// A loaded key-pair.
+ {
+ Poco::Crypto::X509Certificate cert;
+ Poco::Crypto::EVPPKey key;
+
+ Data(std::string cert_path, std::string key_path, std::string pass_phrase);
+ };
+
+ MultiVersion<Data> data; /// Current key-pair; may be empty (see the check in setCertificate()).
+ bool init_was_not_made = true; /// True until init() has installed the OpenSSL callback.
+};
+
+}
+
+#endif
diff --git a/contrib/clickhouse/src/Server/GRPCServer.cpp b/contrib/clickhouse/src/Server/GRPCServer.cpp
new file mode 100644
index 0000000000..4d4be31e23
--- /dev/null
+++ b/contrib/clickhouse/src/Server/GRPCServer.cpp
@@ -0,0 +1,1898 @@
+#include "GRPCServer.h"
+#include <limits>
+#include <memory>
+#if USE_GRPC
+
+#include <Columns/ColumnString.h>
+#include <Columns/ColumnsNumber.h>
+#include <Common/CurrentThread.h>
+#include <Common/SettingsChanges.h>
+#include <Common/setThreadName.h>
+#include <Common/Stopwatch.h>
+#include <Common/ThreadPool.h>
+#include <DataTypes/DataTypeFactory.h>
+#include <QueryPipeline/ProfileInfo.h>
+#include <Interpreters/Context.h>
+#include <Interpreters/InternalTextLogsQueue.h>
+#include <Interpreters/executeQuery.h>
+#include <Interpreters/Session.h>
+#include <IO/CompressionMethod.h>
+#include <IO/ConcatReadBuffer.h>
+#include <IO/ReadBufferFromString.h>
+#include <IO/ReadHelpers.h>
+#include <Parsers/parseQuery.h>
+#include <Parsers/ASTIdentifier_fwd.h>
+#include <Parsers/ASTInsertQuery.h>
+#include <Parsers/ASTQueryWithOutput.h>
+#include <Parsers/ParserQuery.h>
+#include <Processors/Executors/PullingAsyncPipelineExecutor.h>
+#include <Processors/Executors/PullingPipelineExecutor.h>
+#include <Processors/Executors/PushingPipelineExecutor.h>
+#include <Processors/Executors/CompletedPipelineExecutor.h>
+#include <Processors/Executors/PipelineExecutor.h>
+#include <Processors/Formats/IInputFormat.h>
+#include <Processors/Formats/IOutputFormat.h>
+#include <Processors/Sinks/SinkToStorage.h>
+#include <Processors/Sinks/EmptySink.h>
+#include <QueryPipeline/QueryPipelineBuilder.h>
+#include <Server/IServer.h>
+#include <Storages/IStorage.h>
+#include <Poco/FileStream.h>
+#include <Poco/StreamCopier.h>
+#include <Poco/Util/LayeredConfiguration.h>
+#include <base/range.h>
+#include <Common/logger_useful.h>
+#error #include <grpc++/security/server_credentials.h>
+#error #include <grpc++/server.h>
+#error #include <grpc++/server_builder.h>
+
+
+using GRPCService = clickhouse::grpc::ClickHouse::AsyncService;
+using GRPCQueryInfo = clickhouse::grpc::QueryInfo;
+using GRPCResult = clickhouse::grpc::Result;
+using GRPCException = clickhouse::grpc::Exception;
+using GRPCProgress = clickhouse::grpc::Progress;
+using GRPCObsoleteTransportCompression = clickhouse::grpc::ObsoleteTransportCompression;
+
+namespace DB
+{
+namespace ErrorCodes
+{
+ extern const int INVALID_CONFIG_PARAMETER;
+ extern const int INVALID_GRPC_QUERY_INFO;
+ extern const int INVALID_SESSION_TIMEOUT;
+ extern const int LOGICAL_ERROR;
+ extern const int NETWORK_ERROR;
+ extern const int NO_DATA_TO_INSERT;
+ extern const int SUPPORT_IS_DISABLED;
+ extern const int BAD_REQUEST_PARAMETER;
+}
+
+namespace
+{
+ /// Makes the gRPC library pass its log messages to the ClickHouse logging system (set up only once, via std::call_once).
+ void initGRPCLogging(const Poco::Util::AbstractConfiguration & config)
+ {
+ static std::once_flag once_flag;
+ std::call_once(once_flag, [&config]
+ {
+ static Poco::Logger * logger = &Poco::Logger::get("grpc"); /// Static, so that the capture-less lambda below can use it.
+ gpr_set_log_function([](gpr_log_func_args* args)
+ {
+ if (args->severity == GPR_LOG_SEVERITY_DEBUG)
+ LOG_DEBUG(logger, "{} ({}:{})", args->message, args->file, args->line);
+ else if (args->severity == GPR_LOG_SEVERITY_INFO)
+ LOG_INFO(logger, "{} ({}:{})", args->message, args->file, args->line);
+ else if (args->severity == GPR_LOG_SEVERITY_ERROR)
+ LOG_ERROR(logger, "{} ({}:{})", args->message, args->file, args->line);
+ });
+ /// Choose the gRPC verbosity: `grpc.verbose_logs` wins, otherwise follow the "grpc" logger level; if neither applies, the verbosity is left untouched.
+ if (config.getBool("grpc.verbose_logs", false))
+ {
+ gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG);
+ grpc_tracer_set_enabled("all", true);
+ }
+ else if (logger->is(Poco::Message::PRIO_DEBUG))
+ {
+ gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG);
+ }
+ else if (logger->is(Poco::Message::PRIO_INFORMATION))
+ {
+ gpr_set_log_verbosity(GPR_LOG_SEVERITY_INFO);
+ }
+ });
+ }
+
+ /// Reads the whole file into a string, throws an exception if failed.
+ String readFile(const String & filepath)
+ {
+     String contents;
+     Poco::FileInputStream input_stream(filepath);
+     Poco::StreamCopier::copyToString(input_stream, contents);
+     return contents;
+ }
+
+ /// Makes server credentials based on the server config: SSL credentials if `grpc.enable_ssl` is set, insecure credentials otherwise.
+ std::shared_ptr<grpc::ServerCredentials> makeCredentials(const Poco::Util::AbstractConfiguration & config)
+ {
+ if (config.getBool("grpc.enable_ssl", false))
+ {
+#if USE_SSL
+ grpc::SslServerCredentialsOptions options;
+ grpc::SslServerCredentialsOptions::PemKeyCertPair key_cert_pair;
+ key_cert_pair.private_key = readFile(config.getString("grpc.ssl_key_file")); /// Both paths are mandatory here: getString() is called without a default.
+ key_cert_pair.cert_chain = readFile(config.getString("grpc.ssl_cert_file"));
+ options.pem_key_cert_pairs.emplace_back(std::move(key_cert_pair));
+ if (config.getBool("grpc.ssl_require_client_auth", false)) /// The CA file is only read when client certificates are required.
+ {
+ options.client_certificate_request = GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY;
+ if (config.has("grpc.ssl_ca_cert_file"))
+ options.pem_root_certs = readFile(config.getString("grpc.ssl_ca_cert_file"));
+ }
+ return grpc::SslServerCredentials(options);
+#else
+ throw DB::Exception(DB::ErrorCodes::SUPPORT_IS_DISABLED, "Can't use SSL in grpc, because ClickHouse was built without SSL library");
+#endif
+ }
+ return grpc::InsecureServerCredentials(); /// SSL is not enabled in the config.
+ }
+
+ /// Transport compression makes the gRPC library compress packed Result messages before sending them over the network.
+ struct TransportCompression
+ {
+ grpc_compression_algorithm algorithm;
+ grpc_compression_level level;
+
+ /// Extracts the settings of transport compression from a query info if possible: `transport_compression_type`/`_level` take precedence over the obsolete `obsolete_result_compression`; returns std::nullopt if neither is set.
+ static std::optional<TransportCompression> fromQueryInfo(const GRPCQueryInfo & query_info)
+ {
+ TransportCompression res;
+ if (!query_info.transport_compression_type().empty())
+ {
+ res.setAlgorithm(query_info.transport_compression_type(), ErrorCodes::INVALID_GRPC_QUERY_INFO);
+ res.setLevel(query_info.transport_compression_level(), ErrorCodes::INVALID_GRPC_QUERY_INFO);
+ return res;
+ }
+
+ if (query_info.has_obsolete_result_compression())
+ {
+ switch (query_info.obsolete_result_compression().algorithm())
+ {
+ case GRPCObsoleteTransportCompression::NO_COMPRESSION: res.algorithm = GRPC_COMPRESS_NONE; break;
+ case GRPCObsoleteTransportCompression::DEFLATE: res.algorithm = GRPC_COMPRESS_DEFLATE; break;
+ case GRPCObsoleteTransportCompression::GZIP: res.algorithm = GRPC_COMPRESS_GZIP; break;
+ case GRPCObsoleteTransportCompression::STREAM_GZIP: res.algorithm = GRPC_COMPRESS_STREAM_GZIP; break;
+ default: throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "Unknown compression algorithm: {}", GRPCObsoleteTransportCompression::CompressionAlgorithm_Name(query_info.obsolete_result_compression().algorithm()));
+ }
+
+ switch (query_info.obsolete_result_compression().level())
+ {
+ case GRPCObsoleteTransportCompression::COMPRESSION_NONE: res.level = GRPC_COMPRESS_LEVEL_NONE; break;
+ case GRPCObsoleteTransportCompression::COMPRESSION_LOW: res.level = GRPC_COMPRESS_LEVEL_LOW; break;
+ case GRPCObsoleteTransportCompression::COMPRESSION_MEDIUM: res.level = GRPC_COMPRESS_LEVEL_MED; break;
+ case GRPCObsoleteTransportCompression::COMPRESSION_HIGH: res.level = GRPC_COMPRESS_LEVEL_HIGH; break;
+ default: throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "Unknown compression level: {}", GRPCObsoleteTransportCompression::CompressionLevel_Name(query_info.obsolete_result_compression().level()));
+ }
+ return res;
+ }
+
+ return std::nullopt;
+ }
+
+ /// Extracts the settings of transport compression from the server configuration: `grpc.transport_compression_type`/`_level` if the type is present, otherwise `grpc.compression`/`grpc.compression_level` (both default to "none").
+ static TransportCompression fromConfiguration(const Poco::Util::AbstractConfiguration & config)
+ {
+ TransportCompression res;
+ if (config.has("grpc.transport_compression_type"))
+ {
+ res.setAlgorithm(config.getString("grpc.transport_compression_type"), ErrorCodes::INVALID_CONFIG_PARAMETER);
+ res.setLevel(config.getInt("grpc.transport_compression_level", 0), ErrorCodes::INVALID_CONFIG_PARAMETER);
+ }
+ else
+ {
+ res.setAlgorithm(config.getString("grpc.compression", "none"), ErrorCodes::INVALID_CONFIG_PARAMETER);
+ res.setLevel(config.getString("grpc.compression_level", "none"), ErrorCodes::INVALID_CONFIG_PARAMETER);
+ }
+ return res;
+ }
+
+ private:
+ void setAlgorithm(const String & str, int error_code) /// Parses an algorithm name; throws an Exception with `error_code` if the name is unknown.
+ {
+ if (str == "none")
+ algorithm = GRPC_COMPRESS_NONE;
+ else if (str == "deflate")
+ algorithm = GRPC_COMPRESS_DEFLATE;
+ else if (str == "gzip")
+ algorithm = GRPC_COMPRESS_GZIP;
+ else if (str == "stream_gzip")
+ algorithm = GRPC_COMPRESS_STREAM_GZIP;
+ else
+ throw Exception(error_code, "Unknown compression algorithm: '{}'", str);
+ }
+
+ void setLevel(const String & str, int error_code) /// Parses a level name; throws an Exception with `error_code` if the name is unknown.
+ {
+ if (str == "none")
+ level = GRPC_COMPRESS_LEVEL_NONE;
+ else if (str == "low")
+ level = GRPC_COMPRESS_LEVEL_LOW;
+ else if (str == "medium")
+ level = GRPC_COMPRESS_LEVEL_MED;
+ else if (str == "high")
+ level = GRPC_COMPRESS_LEVEL_HIGH;
+ else
+ throw Exception(error_code, "Unknown compression level: '{}'", str);
+ }
+
+ void setLevel(int level_, int error_code) /// Numeric form: accepts only values in [0, GRPC_COMPRESS_LEVEL_COUNT).
+ {
+ if (0 <= level_ && level_ < GRPC_COMPRESS_LEVEL_COUNT)
+ level = static_cast<grpc_compression_level>(level_);
+ else
+ throw Exception(error_code, "Compression level {} is out of range 0..{}", level_, GRPC_COMPRESS_LEVEL_COUNT - 1);
+ }
+ };
+
+ /// Gets the session timeout in seconds: the value from the query info if it is non-zero (it must not exceed `max_session_timeout` from the config), otherwise `default_session_timeout` from the config.
+ std::chrono::steady_clock::duration getSessionTimeout(const GRPCQueryInfo & query_info, const Poco::Util::AbstractConfiguration & config)
+ {
+ auto session_timeout = query_info.session_timeout();
+ if (session_timeout)
+ {
+ auto max_session_timeout = config.getUInt("max_session_timeout", 3600);
+ if (session_timeout > max_session_timeout)
+ throw Exception(ErrorCodes::INVALID_SESSION_TIMEOUT, "Session timeout '{}' is larger than max_session_timeout: {}. "
+ "Maximum session timeout could be modified in configuration file.",
+ std::to_string(session_timeout), std::to_string(max_session_timeout));
+ }
+ else
+ session_timeout = config.getInt("default_session_timeout", 60); /// NOTE(review): read with getInt() while the maximum uses getUInt(); a negative value would be converted to the type of `session_timeout` - confirm.
+ return std::chrono::seconds(session_timeout);
+ }
+
+ /// Generates a short description of a query from a specified query info: the (possibly shortened) query text, query_id, input data size and number of external tables.
+ /// This description is used for logging only.
+ String getQueryDescription(const GRPCQueryInfo & query_info)
+ {
+ String str;
+ if (!query_info.query().empty())
+ {
+ std::string_view query = query_info.query();
+ constexpr size_t max_query_length_to_log = 64;
+ if (query.length() > max_query_length_to_log)
+ query.remove_suffix(query.length() - max_query_length_to_log); /// Keep at most the first 64 bytes.
+ if (size_t format_pos = query.find(" FORMAT "); format_pos != String::npos)
+ query.remove_suffix(query.length() - format_pos - strlen(" FORMAT ")); /// Cut everything after " FORMAT " (searched in the already shortened text).
+ str.append("\"").append(query);
+ if (query != query_info.query())
+ str.append("..."); /// Mark that the logged text was shortened.
+ str.append("\"");
+ }
+ if (!query_info.query_id().empty())
+ str.append(str.empty() ? "" : ", ").append("query_id: ").append(query_info.query_id());
+ if (!query_info.input_data().empty())
+ str.append(str.empty() ? "" : ", ").append("input_data: ").append(std::to_string(query_info.input_data().size())).append(" bytes");
+ if (query_info.external_tables_size())
+ str.append(str.empty() ? "" : ", ").append("external tables: ").append(std::to_string(query_info.external_tables_size()));
+ return str;
+ }
+
+ /// Builds a short comma-separated summary of what a result carries. This description is used for logging only.
+ String getResultDescription(const GRPCResult & result)
+ {
+     String description;
+     auto add_part = [&description](const String & part) { description.append(description.empty() ? "" : ", ").append(part); };
+     if (!result.output().empty())
+         add_part("output: " + std::to_string(result.output().size()) + " bytes");
+     if (!result.totals().empty())
+         add_part("totals");
+     if (!result.extremes().empty())
+         add_part("extremes");
+     if (result.has_progress())
+         add_part("progress");
+     if (result.logs_size())
+         add_part("logs: " + std::to_string(result.logs_size()) + " entries");
+     if (result.cancelled())
+         add_part("cancelled");
+     if (result.has_exception())
+         add_part("exception");
+     return description;
+ }
+
+ using CompletionCallback = std::function<void(bool)>;
+
+ /// Requests a connection and provides low-level interface for reading and writing. Completion of every asynchronous operation is reported through a CompletionCallback.
+ class BaseResponder
+ {
+ public:
+ virtual ~BaseResponder() = default;
+
+ virtual void start(GRPCService & grpc_service,
+ grpc::ServerCompletionQueue & new_call_queue,
+ grpc::ServerCompletionQueue & notification_queue,
+ const CompletionCallback & callback) = 0;
+
+ virtual void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) = 0;
+ virtual void write(const GRPCResult & result, const CompletionCallback & callback) = 0;
+ virtual void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) = 0;
+
+ Poco::Net::SocketAddress getClientAddress() const
+ {
+ String peer = grpc_context.peer(); /// Presumably looks like "ipv4:127.0.0.1:8000".
+ return Poco::Net::SocketAddress{peer.substr(peer.find(':') + 1)}; /// Drops the prefix up to the first ':'. NOTE(review): some gRPC versions percent-encode IPv6 peers ("ipv6:%5B::1%5D:8000") - confirm no decoding is needed here.
+ }
+
+ std::optional<String> getClientHeader(const String & key) const /// Returns the value of a client metadata entry, or std::nullopt if the key is absent.
+ {
+ const auto & client_metadata = grpc_context.client_metadata();
+ auto it = client_metadata.find(key);
+ if (it != client_metadata.end())
+ return String{it->second.data(), it->second.size()};
+ return std::nullopt;
+ }
+
+ void setTransportCompression(const TransportCompression & transport_compression)
+ {
+ grpc_context.set_compression_algorithm(transport_compression.algorithm);
+ grpc_context.set_compression_level(transport_compression.level);
+ }
+
+ protected:
+ CompletionCallback * getCallbackPtr(const CompletionCallback & callback)
+ {
+ /// It would be better to pass callbacks to gRPC calls.
+ /// However gRPC calls can be tagged with `void *` tags only.
+ /// The map `callbacks` here is used to keep callbacks until they're called.
+ std::lock_guard lock{mutex};
+ size_t callback_id = next_callback_id++;
+ auto & callback_in_map = callbacks[callback_id]; /// References to unordered_map elements stay valid on later insertions, so the returned pointer is usable until erase().
+ callback_in_map = [this, callback, callback_id](bool ok)
+ {
+ CompletionCallback callback_to_call;
+ {
+ std::lock_guard lock2{mutex};
+ callback_to_call = callback;
+ callbacks.erase(callback_id); /// This destroys the stored wrapper being executed (with its captures), which is why `callback` was copied out first.
+ }
+ callback_to_call(ok); /// Called outside of the lock.
+ };
+ return &callback_in_map;
+ }
+
+ grpc::ServerContext grpc_context;
+
+ private:
+ grpc::ServerAsyncReaderWriter<GRPCResult, GRPCQueryInfo> reader_writer{&grpc_context}; /// NOTE(review): not used within this class - possibly a leftover.
+ std::unordered_map<size_t, CompletionCallback> callbacks; /// Pending callbacks by id; each entry erases itself when called.
+ size_t next_callback_id = 0;
+ std::mutex mutex; /// Protects `callbacks` and `next_callback_id`.
+ };
+
+ enum CallType
+ {
+ CALL_SIMPLE, /// ExecuteQuery() call
+ CALL_WITH_STREAM_INPUT, /// ExecuteQueryWithStreamInput() call
+ CALL_WITH_STREAM_OUTPUT, /// ExecuteQueryWithStreamOutput() call
+ CALL_WITH_STREAM_IO, /// ExecuteQueryWithStreamIO() call
+ CALL_MAX, /// Not a real call type (getCallName() and makeResponder() treat it as unreachable).
+ };
+
+ const char * getCallName(CallType call_type) /// Returns the name of the gRPC method corresponding to a call type.
+ {
+ switch (call_type)
+ {
+ case CALL_SIMPLE: return "ExecuteQuery()";
+ case CALL_WITH_STREAM_INPUT: return "ExecuteQueryWithStreamInput()";
+ case CALL_WITH_STREAM_OUTPUT: return "ExecuteQueryWithStreamOutput()";
+ case CALL_WITH_STREAM_IO: return "ExecuteQueryWithStreamIO()";
+ case CALL_MAX: break;
+ }
+ UNREACHABLE(); /// CALL_MAX is not a valid call type.
+ }
+
+ bool isInputStreaming(CallType call_type)
+ {
+     return call_type == CALL_WITH_STREAM_IO || call_type == CALL_WITH_STREAM_INPUT; /// The two call types with "StreamInput"/"StreamIO" in their method names.
+ }
+
+ bool isOutputStreaming(CallType call_type)
+ {
+     return call_type == CALL_WITH_STREAM_IO || call_type == CALL_WITH_STREAM_OUTPUT; /// The two call types with "StreamOutput"/"StreamIO" in their method names.
+ }
+
+ template <enum CallType call_type>
+ class Responder;
+
+ template<>
+ class Responder<CALL_SIMPLE> : public BaseResponder /// ExecuteQuery(): one QueryInfo in (received by start()), one Result out (sent by writeAndFinish()).
+ {
+ public:
+     void start(GRPCService & grpc_service,
+                grpc::ServerCompletionQueue & new_call_queue,
+                grpc::ServerCompletionQueue & notification_queue,
+                const CompletionCallback & callback) override
+     {
+         grpc_service.RequestExecuteQuery(&grpc_context, &query_info.emplace(), &response_writer, &new_call_queue, &notification_queue, getCallbackPtr(callback));
+     }
+
+     void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override
+     {
+         if (!query_info.has_value())
+             return callback(false); /// Nothing (more) to read. Must return here: falling through would call value() on an empty optional and signal the callback twice.
+         query_info_ = std::move(query_info).value(); /// Hand over the message stored by start(); it can be read only once.
+         query_info.reset();
+         callback(true);
+     }
+
+     void write(const GRPCResult &, const CompletionCallback &) override
+     {
+         throw Exception(ErrorCodes::LOGICAL_ERROR, "Responder<CALL_SIMPLE>::write() should not be called");
+     }
+
+     void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override
+     {
+         response_writer.Finish(result, status, getCallbackPtr(callback));
+     }
+
+ private:
+     grpc::ServerAsyncResponseWriter<GRPCResult> response_writer{&grpc_context};
+     std::optional<GRPCQueryInfo> query_info;
+ };
+
+ template<>
+ class Responder<CALL_WITH_STREAM_INPUT> : public BaseResponder /// ExecuteQueryWithStreamInput(): QueryInfo messages are read one by one, a single Result is sent by writeAndFinish().
+ {
+ public:
+ void start(GRPCService & grpc_service,
+ grpc::ServerCompletionQueue & new_call_queue,
+ grpc::ServerCompletionQueue & notification_queue,
+ const CompletionCallback & callback) override
+ {
+ grpc_service.RequestExecuteQueryWithStreamInput(&grpc_context, &reader, &new_call_queue, &notification_queue, getCallbackPtr(callback));
+ }
+
+ void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override
+ {
+ reader.Read(&query_info_, getCallbackPtr(callback));
+ }
+
+ void write(const GRPCResult &, const CompletionCallback &) override /// Intermediate results cannot be sent for this call type.
+ {
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Responder<CALL_WITH_STREAM_INPUT>::write() should not be called");
+ }
+
+ void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override
+ {
+ reader.Finish(result, status, getCallbackPtr(callback));
+ }
+
+ private:
+ grpc::ServerAsyncReader<GRPCResult, GRPCQueryInfo> reader{&grpc_context};
+ };
+
+ template<>
+ class Responder<CALL_WITH_STREAM_OUTPUT> : public BaseResponder /// ExecuteQueryWithStreamOutput(): one QueryInfo in (received by start()), a stream of Results out.
+ {
+ public:
+     void start(GRPCService & grpc_service,
+                grpc::ServerCompletionQueue & new_call_queue,
+                grpc::ServerCompletionQueue & notification_queue,
+                const CompletionCallback & callback) override
+     {
+         grpc_service.RequestExecuteQueryWithStreamOutput(&grpc_context, &query_info.emplace(), &writer, &new_call_queue, &notification_queue, getCallbackPtr(callback));
+     }
+
+     void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override
+     {
+         if (!query_info.has_value())
+             return callback(false); /// Nothing (more) to read. Must return here: falling through would call value() on an empty optional and signal the callback twice.
+         query_info_ = std::move(query_info).value(); /// Hand over the message stored by start(); it can be read only once.
+         query_info.reset();
+         callback(true);
+     }
+
+     void write(const GRPCResult & result, const CompletionCallback & callback) override
+     {
+         writer.Write(result, getCallbackPtr(callback));
+     }
+
+     void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override
+     {
+         writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback));
+     }
+
+ private:
+     grpc::ServerAsyncWriter<GRPCResult> writer{&grpc_context};
+     std::optional<GRPCQueryInfo> query_info;
+ };
+
+ template<>
+ class Responder<CALL_WITH_STREAM_IO> : public BaseResponder /// ExecuteQueryWithStreamIO(): QueryInfo messages are read and Result messages are written through the same reader_writer.
+ {
+ public:
+ void start(GRPCService & grpc_service,
+ grpc::ServerCompletionQueue & new_call_queue,
+ grpc::ServerCompletionQueue & notification_queue,
+ const CompletionCallback & callback) override
+ {
+ grpc_service.RequestExecuteQueryWithStreamIO(&grpc_context, &reader_writer, &new_call_queue, &notification_queue, getCallbackPtr(callback));
+ }
+
+ void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override
+ {
+ reader_writer.Read(&query_info_, getCallbackPtr(callback));
+ }
+
+ void write(const GRPCResult & result, const CompletionCallback & callback) override
+ {
+ reader_writer.Write(result, getCallbackPtr(callback));
+ }
+
+ void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override
+ {
+ reader_writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback)); /// Sends the last result with default write options and finishes the call with `status`.
+ }
+
+ private:
+ grpc::ServerAsyncReaderWriter<GRPCResult, GRPCQueryInfo> reader_writer{&grpc_context}; /// Has the same name as a private member of BaseResponder; this is the one used by the methods above.
+ };
+
+ /// Creates the responder specialization matching `call_type`.
+ /// CALL_MAX (or any value without a responder) ends up in UNREACHABLE().
+ std::unique_ptr<BaseResponder> makeResponder(CallType call_type)
+ {
+     std::unique_ptr<BaseResponder> created;
+     switch (call_type)
+     {
+         case CALL_SIMPLE:
+             created = std::make_unique<Responder<CALL_SIMPLE>>();
+             break;
+         case CALL_WITH_STREAM_INPUT:
+             created = std::make_unique<Responder<CALL_WITH_STREAM_INPUT>>();
+             break;
+         case CALL_WITH_STREAM_OUTPUT:
+             created = std::make_unique<Responder<CALL_WITH_STREAM_OUTPUT>>();
+             break;
+         case CALL_WITH_STREAM_IO:
+             created = std::make_unique<Responder<CALL_WITH_STREAM_IO>>();
+             break;
+         case CALL_MAX:
+             break;
+     }
+     if (!created)
+         UNREACHABLE();
+     return created;
+ }
+
+
+ /// A ReadBuffer which obtains every next chunk of data by calling a callback.
+ class ReadBufferFromCallback : public ReadBuffer
+ {
+ public:
+     explicit ReadBufferFromCallback(const std::function<std::pair<const void *, size_t>(void)> & callback_)
+         : ReadBuffer(nullptr, 0), callback(callback_) {}
+
+ private:
+     /// Asks the callback for the next chunk; a chunk of size zero means there is no more data.
+     bool nextImpl() override
+     {
+         const auto [chunk_begin, chunk_size] = callback();
+         if (chunk_size == 0)
+             return false;
+
+         /// The buffer interface wants a mutable position, hence the const_cast.
+         auto * chunk_position = static_cast<BufferBase::Position>(const_cast<void *>(chunk_begin));
+         BufferBase::set(chunk_position, chunk_size, 0);
+         return true;
+     }
+
+     std::function<std::pair<const void *, size_t>(void)> callback;
+ };
+
+
+ /// A mutex-protected boolean flag; a thread can block until another thread sets the flag to the value it wants.
+ class BoolState
+ {
+ public:
+     explicit BoolState(bool initial_value) : value(initial_value) {}
+
+     /// Returns the current value.
+     bool get() const
+     {
+         std::unique_lock guard{mutex};
+         return value;
+     }
+
+     /// Stores `new_value`; waiting threads are notified only if the value actually changes.
+     void set(bool new_value)
+     {
+         std::unique_lock guard{mutex};
+         if (value != new_value)
+         {
+             value = new_value;
+             changed.notify_all();
+         }
+     }
+
+     /// Blocks until the value equals `wanted_value`; returns at once if it already does.
+     void wait(bool wanted_value) const
+     {
+         std::unique_lock guard{mutex};
+         while (value != wanted_value)
+             changed.wait(guard);
+     }
+
+ private:
+     bool value;
+     mutable std::mutex mutex;
+     mutable std::condition_variable changed;
+ };
+
+
+ /// Handles a connection after a responder is started (i.e. after getting a new call). The work is done by run() on call_thread, which start() launches.
+ class Call
+ {
+ public:
+ Call(CallType call_type_, std::unique_ptr<BaseResponder> responder_, IServer & iserver_, Poco::Logger * log_);
+ ~Call(); /// Joins call_thread if it is joinable.
+
+ void start(const std::function<void(void)> & on_finish_call_callback); /// Launches run() on call_thread; the callback is invoked on that thread after run() ends.
+
+ private:
+ void run(); /// Calls receiveQuery(), executeQuery(), processInput(), generateOutput(), finishQuery() in this order; exceptions are passed to onException().
+
+ void receiveQuery();
+ void executeQuery();
+
+ void processInput();
+ void initializePipeline(const Block & header);
+ void createExternalTables();
+
+ void generateOutput();
+
+ void finishQuery();
+ void onException(const Exception & exception);
+ void onFatalError();
+ void releaseQueryIDAndSessionID(); /// Has to be called before the final result is sent (see the comment in its definition).
+ void close(); /// Resets the responder and the per-query state.
+
+ void readQueryInfo(); /// Takes the next QueryInfo into `query_info`, waiting for a pending read if there is one.
+ void throwIfFailedToReadQueryInfo();
+ bool isQueryCancelled(); /// Also marks `result` as cancelled when it returns true.
+
+ void addQueryDetailsToResult();
+ void addOutputFormatToResult();
+ void addOutputColumnsNamesAndTypesToResult(const Block & headers);
+ void addProgressToResult();
+ void addTotalsToResult(const Block & totals);
+ void addExtremesToResult(const Block & extremes);
+ void addProfileInfoToResult(const ProfileInfo & info);
+ void addLogsToResult();
+ void sendResult();
+ void throwIfFailedToSendResult();
+ void sendException(const Exception & exception);
+
+ const CallType call_type;
+ std::unique_ptr<BaseResponder> responder; /// Reset by close(); onException() and onFatalError() check it before replying.
+ IServer & iserver;
+ Poco::Logger * log = nullptr;
+
+ std::optional<Session> session;
+ ContextMutablePtr query_context;
+ std::optional<CurrentThread::QueryScope> query_scope;
+ OpenTelemetry::TracingContextHolderPtr thread_trace_context;
+ String query_text;
+ ASTPtr ast;
+ ASTInsertQuery * insert_query = nullptr; /// Result of ast->as<ASTInsertQuery>() in executeQuery(); null for non-INSERT queries.
+ String input_format;
+ String input_data_delimiter;
+ CompressionMethod input_compression_method = CompressionMethod::None;
+ PODArray<char> output;
+ String output_format;
+ bool send_output_columns_names_and_types = false;
+ CompressionMethod output_compression_method = CompressionMethod::None;
+ int output_compression_level = 0;
+
+ uint64_t interactive_delay = 100000; /// Replaced by settings.interactive_delay in executeQuery(); compared with elapsed microseconds in generateOutput().
+ bool send_exception_with_stacktrace = true; /// Replaced by settings.calculate_text_stack_trace in executeQuery().
+ bool input_function_is_used = false; /// Set by the input initializer installed in executeQuery().
+
+ BlockIO io;
+ Progress progress;
+ InternalTextLogsQueuePtr logs_queue;
+
+ GRPCQueryInfo query_info; /// We reuse the same messages multiple times.
+ GRPCResult result;
+
+ bool initial_query_info_read = false; /// Set by readQueryInfo() once the first QueryInfo has been taken.
+ bool finalize = false; /// Set at the beginning of finishQuery().
+ bool responder_finished = false;
+ bool cancelled = false; /// Latched by isQueryCancelled() once want_to_cancel is seen.
+
+ std::unique_ptr<ReadBuffer> read_buffer;
+ std::unique_ptr<WriteBuffer> write_buffer;
+ WriteBufferFromVector<PODArray<char>> * nested_write_buffer = nullptr; /// Non-owning pointer to the WriteBufferFromVector created in generateOutput().
+ WriteBuffer * compressing_write_buffer = nullptr; /// Non-owning; set in generateOutput() only if output compression is used.
+ std::unique_ptr<QueryPipeline> pipeline;
+ std::unique_ptr<PullingPipelineExecutor> pipeline_executor;
+ std::shared_ptr<IOutputFormat> output_format_processor;
+ bool need_input_data_from_insert_query = true; /// The three flags below are the state of the input callback built in initializePipeline().
+ bool need_input_data_from_query_info = true;
+ bool need_input_data_delimiter = false;
+
+ Stopwatch query_time;
+ UInt64 waited_for_client_reading = 0;
+ UInt64 waited_for_client_writing = 0; /// Nanoseconds spent in readQueryInfo() waiting for a pending read to finish.
+
+ /// The following fields are accessed both from call_thread and queue_thread.
+ BoolState reading_query_info{false}; /// True while a read started by readQueryInfo() is pending.
+ std::atomic<bool> failed_to_read_query_info = false;
+ GRPCQueryInfo next_query_info_while_reading; /// Target of the pending read; moved into `query_info` by readQueryInfo().
+ std::atomic<bool> want_to_cancel = false; /// Set when a received QueryInfo has the 'cancel' field set.
+ std::atomic<bool> check_query_info_contains_cancel_only = false; /// When set, a received QueryInfo with other fields filled in causes a warning in the log.
+ BoolState sending_result{false};
+ std::atomic<bool> failed_to_send_result = false;
+
+ ThreadFromGlobalPool call_thread;
+ };
+
+ /// Only stores the arguments; nothing is executed until start() is called.
+ Call::Call(CallType call_type_, std::unique_ptr<BaseResponder> responder_, IServer & iserver_, Poco::Logger * log_)
+     : call_type(call_type_)
+     , responder(std::move(responder_))
+     , iserver(iserver_)
+     , log(log_)
+ {
+ }
+
+ /// Waits for the call thread (if it was started) before the object goes away.
+ Call::~Call()
+ {
+     if (!call_thread.joinable())
+         return;
+     call_thread.join();
+ }
+
+ /// Launches run() on `call_thread`. Whatever run() throws is only logged,
+ /// and `on_finish_call_callback` is invoked afterwards in either case.
+ void Call::start(const std::function<void(void)> & on_finish_call_callback)
+ {
+     auto thread_body = [this, on_finish_call_callback]
+     {
+         try
+         {
+             run();
+         }
+         catch (...)
+         {
+             tryLogCurrentException("GRPCServer");
+         }
+         on_finish_call_callback();
+     };
+     call_thread = ThreadFromGlobalPool(thread_body);
+ }
+
+ /// Goes through all stages of the call on the call thread.
+ /// DB, Poco and standard exceptions are converted to `Exception` and handed to onException().
+ void Call::run()
+ {
+     try
+     {
+         setThreadName("GRPCServerCall");
+
+         receiveQuery();
+         executeQuery();
+         processInput();
+         generateOutput();
+         finishQuery();
+     }
+     catch (Exception & db_exception)
+     {
+         onException(db_exception);
+     }
+     catch (Poco::Exception & poco_exception)
+     {
+         onException(Exception{Exception::CreateFromPocoTag{}, poco_exception});
+     }
+     catch (std::exception & std_exception)
+     {
+         onException(Exception{Exception::CreateFromSTDTag{}, std_exception});
+     }
+ }
+
+ /// Reads the initial QueryInfo and rejects it if it has the 'cancel' field set.
+ void Call::receiveQuery()
+ {
+     LOG_INFO(log, "Handling call {}", getCallName(call_type));
+
+     readQueryInfo();
+
+     const bool cancel_requested = query_info.cancel();
+     if (cancel_requested)
+         throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "Initial query info cannot set the 'cancel' field");
+
+     LOG_DEBUG(log, "Received initial QueryInfo: {}", getQueryDescription(query_info));
+ }
+
+ void Call::executeQuery() /// Authenticates the client, prepares the query context from the initial QueryInfo, parses the query and starts executing it (the outcome is stored in `io`).
+ {
+ /// Retrieve user credentials.
+ std::string user = query_info.user_name();
+ std::string password = query_info.password();
+ std::string quota_key = query_info.quota();
+ Poco::Net::SocketAddress user_address = responder->getClientAddress();
+
+ if (user.empty()) /// No user name given: use "default" with an empty password.
+ {
+ user = "default";
+ password = "";
+ }
+
+ /// Authentication.
+ session.emplace(iserver.context(), ClientInfo::Interface::GRPC);
+ session->authenticate(user, password, user_address);
+ session->setQuotaClientKey(quota_key);
+
+ ClientInfo client_info = session->getClientInfo();
+
+ /// Parse the OpenTelemetry traceparent header (the tracestate header is read only if traceparent is present).
+ auto traceparent = responder->getClientHeader("traceparent");
+ if (traceparent)
+ {
+ String error;
+ if (!client_info.client_trace_context.parseTraceparentHeader(traceparent.value(), error))
+ {
+ throw Exception(ErrorCodes::BAD_REQUEST_PARAMETER,
+ "Failed to parse OpenTelemetry traceparent header '{}': {}",
+ traceparent.value(), error);
+ }
+ auto tracestate = responder->getClientHeader("tracestate");
+ client_info.client_trace_context.tracestate = tracestate.value_or("");
+ }
+
+ /// The user could specify session identifier and session timeout.
+ /// This allows modifying settings, creating temporary tables and reusing them in subsequent requests.
+ if (!query_info.session_id().empty())
+ {
+ session->makeSessionContext(
+ query_info.session_id(), getSessionTimeout(query_info, iserver.config()), query_info.session_check());
+ }
+
+ query_context = session->makeQueryContext(std::move(client_info));
+
+ /// Prepare settings.
+ SettingsChanges settings_changes;
+ for (const auto & [key, value] : query_info.settings())
+ {
+ settings_changes.push_back({key, value});
+ }
+ query_context->checkSettingsConstraints(settings_changes, SettingSource::QUERY); /// Constraints are checked before the changes are applied.
+ query_context->applySettingsChanges(settings_changes);
+
+ query_context->setCurrentQueryId(query_info.query_id());
+ query_scope.emplace(query_context, /* fatal_error_callback */ [this]{ onFatalError(); });
+
+ /// Set up tracing context for this query on current thread
+ thread_trace_context = std::make_unique<OpenTelemetry::TracingContextHolder>("GRPCServer",
+ query_context->getClientInfo().client_trace_context,
+ query_context->getSettingsRef(),
+ query_context->getOpenTelemetrySpanLog());
+ thread_trace_context->root_span.kind = OpenTelemetry::SERVER;
+
+ /// Prepare for sending exceptions and logs.
+ const Settings & settings = query_context->getSettingsRef();
+ send_exception_with_stacktrace = settings.calculate_text_stack_trace;
+ const auto client_logs_level = settings.send_logs_level;
+ if (client_logs_level != LogsLevel::none) /// A logs queue is created only when the client asked for logs.
+ {
+ logs_queue = std::make_shared<InternalTextLogsQueue>();
+ logs_queue->max_priority = Poco::Logger::parseLevel(client_logs_level.toString());
+ logs_queue->setSourceRegexp(settings.send_logs_source_regexp);
+ CurrentThread::attachInternalTextLogsQueue(logs_queue, client_logs_level);
+ }
+
+ /// Set the current database if specified.
+ if (!query_info.database().empty())
+ query_context->setCurrentDatabase(query_info.database());
+
+ /// Apply transport compression for this call.
+ if (auto transport_compression = TransportCompression::fromQueryInfo(query_info))
+ responder->setTransportCompression(*transport_compression);
+
+ /// The interactive delay will be used to show progress.
+ interactive_delay = settings.interactive_delay;
+ query_context->setProgressCallback([this](const Progress & value) { return progress.incrementPiecewiseAtomically(value); });
+
+ /// Parse the query.
+ query_text = std::move(*(query_info.mutable_query())); /// The query text is moved out of query_info, which therefore no longer holds it.
+ const char * begin = query_text.data();
+ const char * end = begin + query_text.size();
+ ParserQuery parser(end, settings.allow_settings_after_format_in_insert);
+ ast = parseQuery(parser, begin, end, "", settings.max_query_size, settings.max_parser_depth);
+
+ /// Choose input format.
+ insert_query = ast->as<ASTInsertQuery>();
+ if (insert_query)
+ {
+ input_format = insert_query->format;
+ if (input_format.empty())
+ input_format = "Values";
+ }
+
+ input_data_delimiter = query_info.input_data_delimiter();
+
+ /// Choose output format: the format named in the query wins, otherwise the context's default format is used.
+ query_context->setDefaultFormat(query_info.output_format());
+ if (const auto * ast_query_with_output = dynamic_cast<const ASTQueryWithOutput *>(ast.get());
+ ast_query_with_output && ast_query_with_output->format)
+ {
+ output_format = getIdentifierName(ast_query_with_output->format);
+ }
+ if (output_format.empty())
+ output_format = query_context->getDefaultFormat();
+
+ send_output_columns_names_and_types = query_info.send_output_columns();
+
+ /// Choose compression. If a specific compression type is not set, the obsolete common field is used instead.
+ String input_compression_method_str = query_info.input_compression_type();
+ if (input_compression_method_str.empty())
+ input_compression_method_str = query_info.obsolete_compression_type();
+ input_compression_method = chooseCompressionMethod("", input_compression_method_str);
+
+ String output_compression_method_str = query_info.output_compression_type();
+ if (output_compression_method_str.empty())
+ output_compression_method_str = query_info.obsolete_compression_type();
+ output_compression_method = chooseCompressionMethod("", output_compression_method_str);
+ output_compression_level = query_info.output_compression_level();
+
+ /// Set callback to create and fill external tables
+ query_context->setExternalTablesInitializer([this] (ContextPtr context)
+ {
+ if (context != query_context)
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in external tables initializer");
+ createExternalTables();
+ });
+
+ /// Set callbacks to execute function input().
+ query_context->setInputInitializer([this] (ContextPtr context, const StoragePtr & input_storage)
+ {
+ if (context != query_context)
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in Input initializer");
+ input_function_is_used = true;
+ initializePipeline(input_storage->getInMemoryMetadataPtr()->getSampleBlock());
+ });
+
+ query_context->setInputBlocksReaderCallback([this](ContextPtr context) -> Block
+ {
+ if (context != query_context)
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in InputBlocksReader");
+
+ Block block;
+ while (!block && pipeline_executor->pull(block)); /// Skip empty blocks; stop when pull() returns false.
+
+ return block;
+ });
+
+ /// Start executing the query. For an INSERT with inline data only the text before that data is passed on.
+ const auto * query_end = end;
+ if (insert_query && insert_query->data)
+ {
+ query_end = insert_query->data;
+ }
+ String query(begin, query_end);
+ io = ::DB::executeQuery(true, query, query_context);
+ }
+
+ /// Feeds the input data (inline INSERT data and/or data from query infos) into a pushing pipeline.
+ /// Does nothing if the pipeline is not a pushing one.
+ void Call::processInput()
+ {
+     if (!io.pipeline.pushing())
+         return;
+
+     const bool has_inline_data = insert_query && insert_query->data;
+     const bool has_data_to_insert = has_inline_data || !query_info.input_data().empty() || query_info.next_query_info();
+     if (!has_data_to_insert)
+     {
+         if (!insert_query)
+             throw Exception(ErrorCodes::NO_DATA_TO_INSERT, "Query requires data to insert, but it is not an INSERT query");
+
+         if (query_context->getSettingsRef().throw_if_no_data_to_insert)
+             throw Exception(ErrorCodes::NO_DATA_TO_INSERT, "No data to insert");
+
+         return;
+     }
+
+     /// The pipeline has to be initialized right here, on this thread: parallel parsing may be used,
+     /// so the input stream mustn't be touched from another thread.
+     initializePipeline(io.pipeline.getHeader());
+
+     PushingPipelineExecutor pushing_executor(io.pipeline);
+     pushing_executor.start();
+
+     Block parsed_block;
+     while (pipeline_executor->pull(parsed_block))
+     {
+         /// Empty blocks are not pushed.
+         if (!parsed_block)
+             continue;
+         pushing_executor.push(parsed_block);
+     }
+
+     if (isQueryCancelled())
+         pushing_executor.cancel();
+     else
+         pushing_executor.finish();
+ }
+
+ void Call::initializePipeline(const Block & header) /// Builds `read_buffer` (fed by the callback below and wrapped according to input_compression_method) and the input-format `pipeline` with its `pipeline_executor` for the given header.
+ {
+ assert(!read_buffer);
+ read_buffer = std::make_unique<ReadBufferFromCallback>([this]() -> std::pair<const void *, size_t>
+ {
+ if (need_input_data_from_insert_query) /// First, the data embedded in the INSERT query text; this branch is entered at most once.
+ {
+ need_input_data_from_insert_query = false;
+ if (insert_query && insert_query->data && (insert_query->data != insert_query->end))
+ {
+ need_input_data_delimiter = !input_data_delimiter.empty();
+ return {insert_query->data, insert_query->end - insert_query->data};
+ }
+ }
+
+ while (true)
+ {
+ if (need_input_data_from_query_info)
+ {
+ if (need_input_data_delimiter && !query_info.input_data().empty()) /// A delimiter is returned between the previous chunk and the next non-empty input_data.
+ {
+ need_input_data_delimiter = false;
+ return {input_data_delimiter.data(), input_data_delimiter.size()};
+ }
+ need_input_data_from_query_info = false;
+ if (!query_info.input_data().empty())
+ {
+ need_input_data_delimiter = !input_data_delimiter.empty();
+ return {query_info.input_data().data(), query_info.input_data().size()};
+ }
+ }
+
+ if (!query_info.next_query_info()) /// The current QueryInfo does not announce another one, so the input is over.
+ break;
+
+ if (!isInputStreaming(call_type))
+ throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "next_query_info is allowed to be set only for streaming input");
+
+ readQueryInfo();
+ if (!query_info.query().empty() || !query_info.query_id().empty() || !query_info.settings().empty()
+ || !query_info.database().empty() || !query_info.input_data_delimiter().empty() || !query_info.output_format().empty()
+ || query_info.external_tables_size() || !query_info.user_name().empty() || !query_info.password().empty()
+ || !query_info.quota().empty() || !query_info.session_id().empty())
+ {
+ throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO,
+ "Extra query infos can be used only to add more input data. "
+ "Only the following fields can be set: input_data, next_query_info, cancel");
+ }
+
+ if (isQueryCancelled()) /// A cancelled query gets no more input data.
+ break;
+
+ LOG_DEBUG(log, "Received extra QueryInfo: input_data: {} bytes", query_info.input_data().size());
+ need_input_data_from_query_info = true;
+ }
+
+ return {nullptr, 0}; /// no more input data
+ });
+
+ read_buffer = wrapReadBufferWithCompressionMethod(std::move(read_buffer), input_compression_method);
+
+ assert(!pipeline);
+ auto source = query_context->getInputFormat(
+ input_format, *read_buffer, header, query_context->getSettings().max_insert_block_size);
+
+ pipeline = std::make_unique<QueryPipeline>(std::move(source));
+ pipeline_executor = std::make_unique<PullingPipelineExecutor>(*pipeline);
+ }
+
+ void Call::createExternalTables() /// Creates (or reuses) the external tables described in the query infos and fills them with the data sent along; reads further query infos while they carry no input data.
+ {
+ while (true) /// Each iteration handles the external tables of the current `query_info`.
+ {
+ for (const auto & external_table : query_info.external_tables())
+ {
+ String name = external_table.name();
+ if (name.empty())
+ name = "_data";
+ auto temporary_id = StorageID::createEmpty();
+ temporary_id.table_name = name;
+
+ /// Use the table with this name if it can be resolved, otherwise create a temporary table.
+ StoragePtr storage;
+ if (auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal))
+ {
+ storage = DatabaseCatalog::instance().getTable(resolved, query_context);
+ }
+ else
+ {
+ NamesAndTypesList columns;
+ for (size_t column_idx : collections::range(external_table.columns_size()))
+ {
+ /// TODO: consider changing protocol
+ const auto & name_and_type = external_table.columns(static_cast<int>(column_idx));
+ NameAndTypePair column;
+ column.name = name_and_type.name();
+ if (column.name.empty())
+ column.name = "_" + std::to_string(column_idx + 1); /// Unnamed columns are called "_1", "_2", and so on.
+ column.type = DataTypeFactory::instance().get(name_and_type.type());
+ columns.emplace_back(std::move(column));
+ }
+ auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {});
+ storage = temporary_table.getTable();
+ query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table));
+ }
+
+ if (!external_table.data().empty())
+ {
+ /// The data will be written directly to the table.
+ auto metadata_snapshot = storage->getInMemoryMetadataPtr();
+ auto sink = storage->write(ASTPtr(), metadata_snapshot, query_context, /*async_insert=*/false);
+
+ std::unique_ptr<ReadBuffer> buf = std::make_unique<ReadBufferFromMemory>(external_table.data().data(), external_table.data().size());
+ buf = wrapReadBufferWithCompressionMethod(std::move(buf), chooseCompressionMethod("", external_table.compression_type()));
+
+ String format = external_table.format();
+ if (format.empty())
+ format = "TabSeparated";
+ ContextMutablePtr external_table_context = query_context;
+ ContextMutablePtr temp_context;
+ if (!external_table.settings().empty()) /// Table-specific settings are applied to a copy of the query context.
+ {
+ temp_context = Context::createCopy(query_context);
+ external_table_context = temp_context;
+ SettingsChanges settings_changes;
+ for (const auto & [key, value] : external_table.settings())
+ settings_changes.push_back({key, value});
+ external_table_context->checkSettingsConstraints(settings_changes, SettingSource::QUERY);
+ external_table_context->applySettingsChanges(settings_changes);
+ }
+ auto in = external_table_context->getInputFormat(
+ format, *buf, metadata_snapshot->getSampleBlock(),
+ external_table_context->getSettings().max_insert_block_size);
+
+ QueryPipelineBuilder cur_pipeline;
+ cur_pipeline.init(Pipe(std::move(in)));
+ cur_pipeline.addTransform(std::move(sink));
+ cur_pipeline.setSinks([&](const Block & header, Pipe::StreamType)
+ {
+ return std::make_shared<EmptySink>(header);
+ });
+
+ auto executor = cur_pipeline.execute();
+ executor->execute(1, false);
+ }
+ }
+
+ if (!query_info.input_data().empty())
+ {
+ /// External tables must be created before executing the query,
+ /// so all external tables must be sent no later than any input data.
+ break;
+ }
+
+ if (!query_info.next_query_info())
+ break;
+
+ if (!isInputStreaming(call_type))
+ throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "next_query_info is allowed to be set only for streaming input");
+
+ readQueryInfo();
+ if (!query_info.query().empty() || !query_info.query_id().empty() || !query_info.settings().empty()
+ || !query_info.database().empty() || !query_info.input_data_delimiter().empty()
+ || !query_info.output_format().empty() || !query_info.user_name().empty() || !query_info.password().empty()
+ || !query_info.quota().empty() || !query_info.session_id().empty())
+ {
+ throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO,
+ "Extra query infos can be used only "
+ "to add more data to input or more external tables. "
+ "Only the following fields can be set: "
+ "input_data, external_tables, next_query_info, cancel");
+ }
+ if (isQueryCancelled())
+ break;
+ LOG_DEBUG(log, "Received extra QueryInfo: external tables: {}", query_info.external_tables_size());
+ }
+ }
+
+ void Call::generateOutput() /// Executes the query pipeline, writes its blocks in the output format into `output` and hands output, progress and logs to sendResult() while executing.
+ {
+ /// We add query_id and time_zone to the first result anyway.
+ addQueryDetailsToResult();
+
+ if (!io.pipeline.initialized() || io.pipeline.pushing()) /// Nothing is generated for an uninitialized or a pushing pipeline.
+ return;
+
+ Block header;
+ if (io.pipeline.pulling())
+ header = io.pipeline.getHeader();
+
+ if (output_compression_method != CompressionMethod::None)
+ output.resize(DBMS_DEFAULT_BUFFER_SIZE); /// Must have enough space for compressed data.
+ write_buffer = std::make_unique<WriteBufferFromVector<PODArray<char>>>(output);
+ nested_write_buffer = static_cast<WriteBufferFromVector<PODArray<char>> *>(write_buffer.get()); /// Remembered before write_buffer may be replaced by the compressing wrapper below.
+ if (output_compression_method != CompressionMethod::None)
+ {
+ write_buffer = wrapWriteBufferWithCompressionMethod(std::move(write_buffer), output_compression_method, output_compression_level);
+ compressing_write_buffer = write_buffer.get();
+ }
+
+ auto has_output = [&] { return (nested_write_buffer->position() != output.data()) || (compressing_write_buffer && compressing_write_buffer->offset()); }; /// True if something was written to `output` or is still held by the compressing buffer.
+
+ output_format_processor = query_context->getOutputFormat(output_format, *write_buffer, header);
+ Stopwatch after_send_progress;
+
+ /// Unless the input() function is used we are not going to receive input data anymore.
+ if (!input_function_is_used)
+ check_query_info_contains_cancel_only = true;
+
+ if (io.pipeline.pulling())
+ {
+ auto executor = std::make_shared<PullingAsyncPipelineExecutor>(io.pipeline);
+ auto check_for_cancel = [&] /// Returns false (after cancelling the executor) if the query has been cancelled.
+ {
+ if (isQueryCancelled())
+ {
+ executor->cancel();
+ return false;
+ }
+ return true;
+ };
+
+ addOutputFormatToResult();
+ addOutputColumnsNamesAndTypesToResult(header);
+
+ Block block;
+ while (check_for_cancel())
+ {
+ if (!executor->pull(block, interactive_delay / 1000)) /// interactive_delay is compared with microseconds below, so the divided value is presumably milliseconds - TODO confirm.
+ break;
+
+ throwIfFailedToSendResult();
+ if (!check_for_cancel())
+ break;
+
+ if (block && !io.null_format)
+ output_format_processor->write(materializeBlock(block));
+
+ if (after_send_progress.elapsedMicroseconds() >= interactive_delay)
+ {
+ addProgressToResult();
+ after_send_progress.restart();
+ }
+
+ addLogsToResult();
+
+ if (has_output() || result.has_progress() || result.logs_size()) /// A result is sent only if there is something in it.
+ sendResult();
+
+ throwIfFailedToSendResult();
+ if (!check_for_cancel())
+ break;
+ }
+
+ if (!isQueryCancelled()) /// Totals, extremes and profile info are added only if the query was not cancelled.
+ {
+ addTotalsToResult(executor->getTotalsBlock());
+ addExtremesToResult(executor->getExtremesBlock());
+ addProfileInfoToResult(executor->getProfileInfo());
+ }
+ }
+ else
+ {
+ auto executor = std::make_shared<CompletedPipelineExecutor>(io.pipeline);
+ auto callback = [&]() -> bool /// Sends progress and logs if there are any, and returns whether the query has been cancelled.
+ {
+ throwIfFailedToSendResult();
+ addProgressToResult();
+ addLogsToResult();
+
+ if (has_output() || result.has_progress() || result.logs_size())
+ sendResult();
+
+ throwIfFailedToSendResult();
+
+ return isQueryCancelled();
+ };
+ executor->setCancelCallback(std::move(callback), interactive_delay / 1000);
+ executor->execute();
+ }
+
+ output_format_processor->finalize();
+ }
+
+ /// Completes a successful call: collects the remaining progress and logs,
+ /// sends the result, closes the call and logs the timings.
+ void Call::finishQuery()
+ {
+     finalize = true;
+     io.onFinish();
+     addProgressToResult();
+     query_scope->logPeakMemoryUsage();
+     addLogsToResult();
+
+     /// The IDs have to be released before the result is sent.
+     releaseQueryIDAndSessionID();
+     sendResult();
+     close();
+
+     const double client_reading_secs = static_cast<double>(waited_for_client_reading) / 1000000000ULL;
+     const double client_writing_secs = static_cast<double>(waited_for_client_writing) / 1000000000ULL;
+
+     LOG_INFO(
+         log,
+         "Finished call {} in {} secs. (including reading by client: {}, writing by client: {})",
+         getCallName(call_type),
+         query_time.elapsedSeconds(),
+         client_reading_secs,
+         client_writing_secs);
+ }
+
+ /// Handles a failed call: logs the exception, tries to pass logs and the exception
+ /// to the client (if the responder can still be used) and closes the call.
+ void Call::onException(const Exception & exception)
+ {
+     io.onException();
+
+     LOG_ERROR(log, getExceptionMessageAndPattern(exception, send_exception_with_stacktrace));
+
+     const bool can_reply = responder && !responder_finished;
+     if (can_reply)
+     {
+         /// Adding logs is an optional step which could fail too.
+         try
+         {
+             addLogsToResult();
+         }
+         catch (...)
+         {
+             LOG_WARNING(log, "Couldn't send logs to client");
+         }
+
+         releaseQueryIDAndSessionID();
+
+         try
+         {
+             sendException(exception);
+         }
+         catch (...)
+         {
+             LOG_WARNING(log, "Couldn't send exception information to the client");
+         }
+     }
+
+     close();
+ }
+
+ /// Tries to send a result marked with the "FatalError" exception name together with the logs.
+ /// Any failure of this attempt is ignored on purpose.
+ void Call::onFatalError()
+ {
+     if (!responder || responder_finished)
+         return;
+
+     try
+     {
+         result.mutable_exception()->set_name("FatalError");
+         addLogsToResult();
+         sendResult();
+     }
+     catch (...)
+     {
+     }
+ }
+
+ /// Releases the query ID and the session ID held by this call.
+ void Call::releaseQueryIDAndSessionID()
+ {
+     /// This function has to be called before the final result is sent to the client.
+     /// Having got the final result, the client may immediately send another query
+     /// with the same query ID or session ID, and two queries with the same query ID
+     /// or session ID are not allowed to be executed at the same time.
+     io.process_list_entry.reset();
+
+     if (query_context)
+     {
+         query_context->setProcessListElement(nullptr);
+     }
+
+     if (session)
+     {
+         session->releaseSessionID();
+     }
+ }
+
+ void Call::close() /// Resets the responder and all per-query state. NOTE(review): the order of the resets presumably matters (the objects refer to each other) - confirm before reordering.
+ {
+ responder.reset(); /// After this onException() and onFatalError() no longer try to reply.
+ pipeline_executor.reset();
+ pipeline.reset();
+ output_format_processor.reset();
+ read_buffer.reset();
+ write_buffer.reset();
+ nested_write_buffer = nullptr; /// Both raw pointers were obtained from write_buffer in generateOutput().
+ compressing_write_buffer = nullptr;
+ io = {};
+ query_scope.reset();
+ query_context.reset();
+ thread_trace_context.reset();
+ session.reset();
+ }
+
+ void Call::readQueryInfo() /// Takes the next QueryInfo into `query_info`: starts the initial read if needed, waits for the pending read and, for streaming input, starts the next read ahead.
+ {
+ auto start_reading = [&] /// Marks the read as pending and asks the responder to read into next_query_info_while_reading.
+ {
+ reading_query_info.set(true);
+ responder->read(next_query_info_while_reading, [this](bool ok)
+ {
+ /// Called on queue_thread.
+ if (ok)
+ {
+ const auto & nqi = next_query_info_while_reading;
+ if (check_query_info_contains_cancel_only)
+ {
+ if (!nqi.query().empty() || !nqi.query_id().empty() || !nqi.settings().empty() || !nqi.database().empty()
+ || !nqi.input_data().empty() || !nqi.input_data_delimiter().empty() || !nqi.output_format().empty()
+ || !nqi.user_name().empty() || !nqi.password().empty() || !nqi.quota().empty() || !nqi.session_id().empty())
+ {
+ LOG_WARNING(log, "Cannot add extra information to a query which is already executing. Only the 'cancel' field can be set");
+ }
+ }
+ if (nqi.cancel())
+ want_to_cancel = true;
+ }
+ else
+ {
+ /// We cannot throw an exception right here because this code is executed
+ /// on queue_thread. The flag is checked later by throwIfFailedToReadQueryInfo().
+ failed_to_read_query_info = true;
+ }
+ reading_query_info.set(false); /// Wakes up finish_reading() if it is waiting.
+ });
+ };
+
+ auto finish_reading = [&] /// Waits for the pending read (if any), throws if it failed, then takes the message that was read.
+ {
+ if (reading_query_info.get())
+ {
+ Stopwatch client_writing_watch;
+ reading_query_info.wait(false);
+ waited_for_client_writing += client_writing_watch.elapsedNanoseconds();
+ }
+ throwIfFailedToReadQueryInfo();
+ query_info = std::move(next_query_info_while_reading);
+ initial_query_info_read = true;
+ };
+
+ if (!initial_query_info_read)
+ {
+ /// Initial query info hasn't been read yet, so we're going to read it now.
+ start_reading();
+ }
+
+ /// Maybe it's reading a query info right now. Let it finish.
+ finish_reading();
+
+ if (isInputStreaming(call_type))
+ {
+ /// The next query info can contain more input data. We start reading it now,
+ /// so a future call of readQueryInfo() will probably be able to take it.
+ start_reading();
+ }
+ }
+
+ /// Throws NETWORK_ERROR if a read of a query info has failed;
+ /// the message tells whether it was the initial query info or an extra one.
+ void Call::throwIfFailedToReadQueryInfo()
+ {
+     if (!failed_to_read_query_info)
+         return;
+
+     if (!initial_query_info_read)
+         throw Exception(ErrorCodes::NETWORK_ERROR, "Failed to read initial QueryInfo");
+
+     throw Exception(ErrorCodes::NETWORK_ERROR, "Failed to read extra QueryInfo");
+ }
+
+ /// Returns true if the query has been cancelled, in which case `result` is marked as cancelled too.
+ /// The cancellation is logged only once, when the request to cancel is noticed for the first time.
+ bool Call::isQueryCancelled()
+ {
+     if (!cancelled)
+     {
+         if (!want_to_cancel)
+             return false;
+
+         LOG_INFO(log, "Query cancelled");
+         cancelled = true;
+     }
+
+     result.set_cancelled(true);
+     return true;
+ }
+
+ /// Puts the current query ID and the time zone into `result`.
+ void Call::addQueryDetailsToResult()
+ {
+     auto & query_id_field = *result.mutable_query_id();
+     query_id_field = query_context->getClientInfo().current_query_id;
+
+     auto & time_zone_field = *result.mutable_time_zone();
+     time_zone_field = DateLUT::instance().getTimeZone();
+ }
+
+ /// Puts the name of the chosen output format into `result`.
+ void Call::addOutputFormatToResult()
+ {
+     auto & output_format_field = *result.mutable_output_format();
+     output_format_field = output_format;
+ }
+
+ /// Adds the name and the type name of every column of `header` to `result`,
+ /// provided that sending of output columns was requested.
+ void Call::addOutputColumnsNamesAndTypesToResult(const Block & header)
+ {
+     if (send_output_columns_names_and_types)
+     {
+         for (const auto & header_column : header)
+         {
+             auto * output_column = result.add_output_columns();
+             *output_column->mutable_name() = header_column.name;
+             *output_column->mutable_type() = header_column.type->getName();
+         }
+     }
+ }
+
+ /// Fetches (and resets) the progress collected so far and adds it to the progress stored in `result`.
+ /// Nothing is touched if all the fetched counters are zero.
+ void Call::addProgressToResult()
+ {
+     auto delta = progress.fetchValuesAndResetPiecewiseAtomically();
+     const bool has_changes = delta.read_rows || delta.read_bytes || delta.total_rows_to_read || delta.written_rows || delta.written_bytes;
+     if (!has_changes)
+         return;
+
+     /// The values are summed up, because they have to accumulate in case streaming output is disabled.
+     auto * grpc_progress = result.mutable_progress();
+     grpc_progress->set_read_rows(grpc_progress->read_rows() + delta.read_rows);
+     grpc_progress->set_read_bytes(grpc_progress->read_bytes() + delta.read_bytes);
+     grpc_progress->set_total_rows_to_read(grpc_progress->total_rows_to_read() + delta.total_rows_to_read);
+     grpc_progress->set_written_rows(grpc_progress->written_rows() + delta.written_rows);
+     grpc_progress->set_written_bytes(grpc_progress->written_bytes() + delta.written_bytes);
+ }
+
+ /// Formats the `totals` block (output format and output compression of the query are used)
+ /// and stores the bytes in `result`. An empty block is skipped.
+ void Call::addTotalsToResult(const Block & totals)
+ {
+     if (!totals)
+         return;
+
+     PODArray<char> formatted;
+     const bool with_compression = output_compression_method != CompressionMethod::None;
+     if (with_compression)
+         formatted.resize(DBMS_DEFAULT_BUFFER_SIZE); /// Must have enough space for compressed data.
+
+     std::unique_ptr<WriteBuffer> out = std::make_unique<WriteBufferFromVector<PODArray<char>>>(formatted);
+     out = wrapWriteBufferWithCompressionMethod(std::move(out), output_compression_method, output_compression_level);
+
+     auto totals_format = query_context->getOutputFormat(output_format, *out, totals);
+     totals_format->write(materializeBlock(totals));
+     totals_format->finalize();
+     out->finalize();
+
+     auto & totals_field = *result.mutable_totals();
+     totals_field.assign(formatted.data(), formatted.size());
+ }
+
+ /// Formats the `extremes` block (output format and output compression of the query are used)
+ /// and stores the bytes in `result`. An empty block is skipped.
+ void Call::addExtremesToResult(const Block & extremes)
+ {
+     if (!extremes)
+         return;
+
+     PODArray<char> formatted;
+     const bool with_compression = output_compression_method != CompressionMethod::None;
+     if (with_compression)
+         formatted.resize(DBMS_DEFAULT_BUFFER_SIZE); /// Must have enough space for compressed data.
+
+     std::unique_ptr<WriteBuffer> out = std::make_unique<WriteBufferFromVector<PODArray<char>>>(formatted);
+     out = wrapWriteBufferWithCompressionMethod(std::move(out), output_compression_method, output_compression_level);
+
+     auto extremes_format = query_context->getOutputFormat(output_format, *out, extremes);
+     extremes_format->write(materializeBlock(extremes));
+     extremes_format->finalize();
+     out->finalize();
+
+     auto & extremes_field = *result.mutable_extremes();
+     extremes_field.assign(formatted.data(), formatted.size());
+ }
+
+ /// Copies the counters of the profile info into the stats of `result`.
+ void Call::addProfileInfoToResult(const ProfileInfo & info)
+ {
+     auto * stats = result.mutable_stats();
+     stats->set_rows(info.rows);
+     stats->set_blocks(info.blocks);
+     stats->set_allocated_bytes(info.bytes);
+     stats->set_applied_limit(info.hasAppliedLimit());
+     stats->set_rows_before_limit(info.getRowsBeforeLimit());
+ }
+
+ /// Drains `logs_queue` and appends every queued log row to `result.logs`.
+ /// Does nothing if there is no logs queue.
+ void Call::addLogsToResult()
+ {
+ if (!logs_queue)
+ return;
+
+ /// The level column is cast directly to the protobuf enum below,
+ /// so the protobuf log levels must have the same numeric values as the Poco priorities.
+ static_assert(::clickhouse::grpc::LOG_NONE == 0);
+ static_assert(::clickhouse::grpc::LOG_FATAL == static_cast<int>(Poco::Message::PRIO_FATAL));
+ static_assert(::clickhouse::grpc::LOG_CRITICAL == static_cast<int>(Poco::Message::PRIO_CRITICAL));
+ static_assert(::clickhouse::grpc::LOG_ERROR == static_cast<int>(Poco::Message::PRIO_ERROR));
+ static_assert(::clickhouse::grpc::LOG_WARNING == static_cast<int>(Poco::Message::PRIO_WARNING));
+ static_assert(::clickhouse::grpc::LOG_NOTICE == static_cast<int>(Poco::Message::PRIO_NOTICE));
+ static_assert(::clickhouse::grpc::LOG_INFORMATION == static_cast<int>(Poco::Message::PRIO_INFORMATION));
+ static_assert(::clickhouse::grpc::LOG_DEBUG == static_cast<int>(Poco::Message::PRIO_DEBUG));
+ static_assert(::clickhouse::grpc::LOG_TRACE == static_cast<int>(Poco::Message::PRIO_TRACE));
+
+ MutableColumns columns;
+ while (logs_queue->tryPop(columns))
+ {
+ /// Skip blocks without rows.
+ if (columns.empty() || columns[0]->empty())
+ continue;
+
+ /// Columns are addressed by fixed position; the column at index 2 is not sent to the client.
+ /// NOTE(review): presumably the layout matches the sample block of the internal text log queue - confirm.
+ const auto & column_time = typeid_cast<const ColumnUInt32 &>(*columns[0]);
+ const auto & column_time_microseconds = typeid_cast<const ColumnUInt32 &>(*columns[1]);
+ const auto & column_query_id = typeid_cast<const ColumnString &>(*columns[3]);
+ const auto & column_thread_id = typeid_cast<const ColumnUInt64 &>(*columns[4]);
+ const auto & column_level = typeid_cast<const ColumnInt8 &>(*columns[5]);
+ const auto & column_source = typeid_cast<const ColumnString &>(*columns[6]);
+ const auto & column_text = typeid_cast<const ColumnString &>(*columns[7]);
+ size_t num_rows = column_time.size();
+
+ /// One protobuf log entry per row.
+ for (size_t row = 0; row != num_rows; ++row)
+ {
+ auto & log_entry = *result.add_logs();
+ log_entry.set_time(column_time.getElement(row));
+ log_entry.set_time_microseconds(column_time_microseconds.getElement(row));
+ std::string_view query_id = column_query_id.getDataAt(row).toView();
+ log_entry.set_query_id(query_id.data(), query_id.size());
+ log_entry.set_thread_id(column_thread_id.getElement(row));
+ log_entry.set_level(static_cast<::clickhouse::grpc::LogsLevel>(column_level.getElement(row)));
+ std::string_view source = column_source.getDataAt(row).toView();
+ log_entry.set_source(source.data(), source.size());
+ std::string_view text = column_text.getDataAt(row).toView();
+ log_entry.set_text(text.data(), text.size());
+ }
+ }
+ }
+
+ /// Sends the accumulated `result` to the client and clears it afterwards.
+ /// A message is final if `finalize` is set or the result carries an exception or the cancelled flag;
+ /// intermediate messages are sent only for call types with streaming output.
+ /// For the final message the function blocks until the write has completed.
+ /// Throws NETWORK_ERROR (see throwIfFailedToSendResult()) if a previous or the final write failed.
+ void Call::sendResult()
+ {
+ /// gRPC doesn't allow to write anything to a finished responder.
+ if (responder_finished)
+ return;
+
+ /// If output is not streaming then only the final result can be sent.
+ bool send_final_message = finalize || result.has_exception() || result.cancelled();
+ if (!send_final_message && !isOutputStreaming(call_type))
+ return;
+
+ /// Copy output to `result.output`, with optional compressing.
+ if (write_buffer)
+ {
+ size_t output_size;
+ if (send_final_message)
+ {
+ /// Final message: finalize the whole chain of write buffers, so everything written ends up in `output`.
+ if (compressing_write_buffer)
+ LOG_DEBUG(log, "Compressing final {} bytes", compressing_write_buffer->offset());
+ write_buffer->finalize();
+ output_size = output.size();
+ }
+ else
+ {
+ /// Intermediate message: only push pending data through the compressor,
+ /// then take the bytes written to the nested buffer so far.
+ if (compressing_write_buffer && compressing_write_buffer->offset())
+ {
+ LOG_DEBUG(log, "Compressing {} bytes", compressing_write_buffer->offset());
+ compressing_write_buffer->sync();
+ }
+ output_size = nested_write_buffer->position() - output.data();
+ }
+
+ if (output_size)
+ {
+ result.mutable_output()->assign(output.data(), output_size);
+ nested_write_buffer->restart(); /// We're going to reuse the same buffer again for next block of data.
+ }
+ }
+
+ if (!send_final_message && result.output().empty() && result.totals().empty() && result.extremes().empty() && !result.logs_size()
+ && !result.has_progress() && !result.has_stats() && !result.has_exception() && !result.cancelled())
+ return; /// Nothing to send.
+
+ /// Wait for previous write to finish.
+ /// (gRPC doesn't allow to start sending another result while the previous one is still being sent.)
+ if (sending_result.get())
+ {
+ Stopwatch client_reading_watch;
+ sending_result.wait(false);
+ waited_for_client_reading += client_reading_watch.elapsedNanoseconds();
+ }
+ throwIfFailedToSendResult();
+
+ /// Start sending the result.
+ LOG_DEBUG(log, "Sending {} result to the client: {}", (send_final_message ? "final" : "intermediate"), getResultDescription(result));
+
+ sending_result.set(true);
+ auto callback = [this](bool ok)
+ {
+ /// Called on queue_thread.
+ if (!ok)
+ failed_to_send_result = true;
+ sending_result.set(false);
+ };
+
+ Stopwatch client_reading_final_watch;
+ if (send_final_message)
+ {
+ /// Mark the responder as finished before the write, so no further sendResult() call writes to it.
+ responder_finished = true;
+ responder->writeAndFinish(result, {}, callback);
+ }
+ else
+ responder->write(result, callback);
+
+ /// gRPC has already retrieved all data from `result`, so we don't have to keep it.
+ result.Clear();
+
+ if (send_final_message)
+ {
+ /// Wait until the result is actually sent.
+ sending_result.wait(false);
+ waited_for_client_reading += client_reading_final_watch.elapsedNanoseconds();
+ throwIfFailedToSendResult();
+ LOG_TRACE(log, "Final result has been sent to the client");
+ }
+ }
+
+ /// Throws NETWORK_ERROR if a write to the client has been reported as failed.
+ void Call::throwIfFailedToSendResult()
+ {
+     if (!failed_to_send_result)
+         return;
+
+     throw Exception(ErrorCodes::NETWORK_ERROR, "Failed to send result to the client");
+ }
+
+ /// Stores `exception` in `result.exception` and sends the result.
+ /// A result carrying an exception is always sent as the final message (see sendResult()).
+ void Call::sendException(const Exception & exception)
+ {
+     auto * grpc_exception = result.mutable_exception();
+     grpc_exception->set_code(exception.code());
+     grpc_exception->set_name(exception.name());
+     grpc_exception->set_display_text(exception.displayText());
+
+     /// The stack trace is attached only if it was requested.
+     if (send_exception_with_stacktrace)
+         grpc_exception->set_stack_trace(exception.getStackTraceString());
+
+     sendResult();
+ }
+}
+
+
+/// Drives the gRPC completion queue of a GRPCServer on a dedicated thread.
+/// Keeps one pending responder per call type to accept new calls, owns the Call objects
+/// of the calls being handled and destroys them after they have finished.
+class GRPCServer::Runner
+{
+public:
+ explicit Runner(GRPCServer & owner_) : owner(owner_) {}
+
+ /// Waits for the queue thread to exit.
+ ~Runner()
+ {
+ if (queue_thread.joinable())
+ queue_thread.join();
+ }
+
+ /// Registers responders for new calls and starts the thread processing the completion queue.
+ void start()
+ {
+ startReceivingNewCalls();
+
+ /// We run queue in a separate thread.
+ auto runner_function = [this]
+ {
+ try
+ {
+ run();
+ }
+ catch (...)
+ {
+ /// Exceptions must not leave the thread function, so they are only logged.
+ tryLogCurrentException("GRPCServer");
+ }
+ };
+ queue_thread = ThreadFromGlobalPool{runner_function};
+ }
+
+ /// Only sets the stop flag: no new calls are accepted after that, current calls are not interrupted.
+ void stop() { stopReceivingNewCalls(); }
+
+ /// Returns the number of calls which have been started and have not finished yet.
+ size_t getNumCurrentCalls() const
+ {
+ std::lock_guard lock{mutex};
+ return current_calls.size();
+ }
+
+private:
+ /// Creates and starts one responder for each call type.
+ void startReceivingNewCalls()
+ {
+ std::lock_guard lock{mutex};
+ responders_for_new_calls.resize(CALL_MAX);
+ for (CallType call_type : collections::range(CALL_MAX))
+ makeResponderForNewCall(call_type);
+ }
+
+ /// Creates a responder waiting for the next call of the specified type;
+ /// onNewCall() is invoked through the completion queue when such a call arrives.
+ void makeResponderForNewCall(CallType call_type)
+ {
+ /// `mutex` is already locked.
+ responders_for_new_calls[call_type] = makeResponder(call_type);
+
+ responders_for_new_calls[call_type]->start(
+ owner.grpc_service, *owner.queue, *owner.queue,
+ [this, call_type](bool ok) { onNewCall(call_type, ok); });
+ }
+
+ void stopReceivingNewCalls()
+ {
+ std::lock_guard lock{mutex};
+ should_stop = true;
+ }
+
+ /// Completion callback of a responder created by makeResponderForNewCall().
+ void onNewCall(CallType call_type, bool responder_started_ok)
+ {
+ std::lock_guard lock{mutex};
+ /// The slot becomes empty here. While stopping it stays empty (the responder is destroyed on return),
+ /// which is what run() checks before exiting.
+ auto responder = std::move(responders_for_new_calls[call_type]);
+ if (should_stop)
+ return;
+ makeResponderForNewCall(call_type);
+ if (responder_started_ok)
+ {
+ /// Connection established and the responder has been started.
+ /// So we pass this responder to a Call and make another responder for next connection.
+ auto new_call = std::make_unique<Call>(call_type, std::move(responder), owner.iserver, owner.log);
+ auto * new_call_ptr = new_call.get();
+ current_calls[new_call_ptr] = std::move(new_call);
+ new_call_ptr->start([this, new_call_ptr]() { onFinishCall(new_call_ptr); });
+ }
+ }
+
+ void onFinishCall(Call * call)
+ {
+ /// Called on call_thread. That's why we can't destroy the `call` right now
+ /// (thread can't join to itself). Thus here we only move the `call` from
+ /// `current_calls` to `finished_calls` and run() will actually destroy the `call`.
+ std::lock_guard lock{mutex};
+ /// `call` is expected to be registered in `current_calls` (it is inserted in onNewCall() before the call is started).
+ auto it = current_calls.find(call);
+ finished_calls.push_back(std::move(it->second));
+ current_calls.erase(it);
+ }
+
+ /// Body of the queue thread: destroys finished calls and dispatches completion queue events
+ /// to their callbacks until the queue is shut down or the runner has been stopped and is idle.
+ void run()
+ {
+ setThreadName("GRPCServerQueue");
+ while (true)
+ {
+ {
+ std::lock_guard lock{mutex};
+ finished_calls.clear(); /// Destroy finished calls.
+
+ /// If (should_stop == true) we continue processing until there are no active calls.
+ if (should_stop && current_calls.empty())
+ {
+ bool all_responders_gone = std::all_of(
+ responders_for_new_calls.begin(), responders_for_new_calls.end(),
+ [](std::unique_ptr<BaseResponder> & responder) { return !responder; });
+ if (all_responders_gone)
+ break;
+ }
+ }
+
+ bool ok = false;
+ void * tag = nullptr;
+ if (!owner.queue->Next(&tag, &ok))
+ {
+ /// The queue has been shut down.
+ break;
+ }
+
+ /// Every tag put into the queue is a pointer to a CompletionCallback.
+ auto & callback = *static_cast<CompletionCallback *>(tag);
+ callback(ok);
+ }
+ }
+
+ GRPCServer & owner;
+ ThreadFromGlobalPool queue_thread;
+ /// Indexed by CallType; an element is null while no responder is waiting for a call of that type.
+ std::vector<std::unique_ptr<BaseResponder>> responders_for_new_calls;
+ std::map<Call *, std::unique_ptr<Call>> current_calls;
+ /// Calls waiting to be destroyed on the queue thread.
+ std::vector<std::unique_ptr<Call>> finished_calls;
+ bool should_stop = false;
+ /// Protects all the members above except `owner` and `queue_thread`.
+ mutable std::mutex mutex;
+};
+
+
+/// Only stores the parameters and creates the runner; nothing is started until start() is called.
+GRPCServer::GRPCServer(IServer & iserver_, const Poco::Net::SocketAddress & address_to_listen_)
+ : iserver(iserver_)
+ , address_to_listen(address_to_listen_)
+ , log(&Poco::Logger::get("GRPCServer"))
+ , runner(std::make_unique<Runner>(*this))
+{}
+
+/// Shuts everything down in a strict order: the server, then the completion queue, then the runner.
+GRPCServer::~GRPCServer()
+{
+ /// Server should be shutdown before CompletionQueue.
+ if (grpc_server)
+ grpc_server->Shutdown();
+
+ /// Completion Queue should be shutdown before destroying the runner,
+ /// because the runner is now probably executing CompletionQueue::Next() on queue_thread
+ /// which is blocked until an event is available or the queue is shutting down.
+ if (queue)
+ queue->Shutdown();
+
+ /// Destroying the runner joins its queue thread.
+ runner.reset();
+}
+
+/// Configures and starts the gRPC server on `address_to_listen`, then starts the runner.
+/// Throws NETWORK_ERROR if the gRPC server could not be built and started.
+void GRPCServer::start()
+{
+ initGRPCLogging(iserver.config());
+ grpc::ServerBuilder builder;
+ builder.AddListeningPort(address_to_listen.toString(), makeCredentials(iserver.config()));
+ builder.RegisterService(&grpc_service);
+ /// Message size limits are taken from the server configuration (-1 is passed if they are not set).
+ builder.SetMaxSendMessageSize(iserver.config().getInt("grpc.max_send_message_size", -1));
+ builder.SetMaxReceiveMessageSize(iserver.config().getInt("grpc.max_receive_message_size", -1));
+ auto default_transport_compression = TransportCompression::fromConfiguration(iserver.config());
+ builder.SetDefaultCompressionAlgorithm(default_transport_compression.algorithm);
+ builder.SetDefaultCompressionLevel(default_transport_compression.level);
+
+ /// The completion queue must be added to the builder before the server is built.
+ queue = builder.AddCompletionQueue();
+ grpc_server = builder.BuildAndStart();
+ if (nullptr == grpc_server)
+ {
+ /// NOTE(review): a port conflict is assumed to be the cause; BuildAndStart() gives no reason here.
+ throw DB::Exception(DB::ErrorCodes::NETWORK_ERROR, "Can't start grpc server, there is a port conflict");
+ }
+
+ runner->start();
+}
+
+
+/// Asks the runner to stop accepting calls. Does not shut down the gRPC server or the queue:
+/// that happens in the destructor.
+void GRPCServer::stop()
+{
+ /// Stop receiving new calls.
+ runner->stop();
+}
+
+/// The number of connections is reported as the number of calls currently tracked by the runner.
+size_t GRPCServer::currentConnections() const
+{
+ return runner->getNumCurrentCalls();
+}
+
+}
+#endif
diff --git a/contrib/clickhouse/src/Server/GRPCServer.h b/contrib/clickhouse/src/Server/GRPCServer.h
new file mode 100644
index 0000000000..b77f22f046
--- /dev/null
+++ b/contrib/clickhouse/src/Server/GRPCServer.h
@@ -0,0 +1,56 @@
+#pragma once
+
+#include "clickhouse_config.h"
+
+#if USE_GRPC
+#include <Poco/Net/SocketAddress.h>
+#include <base/types.h>
+#error #include "clickhouse_grpc.grpc.pb.h"
+
+namespace Poco { class Logger; }
+
+namespace grpc
+{
+class Server;
+class ServerCompletionQueue;
+}
+
+namespace DB
+{
+class IServer;
+
+/// Server accepting queries via the gRPC protocol on a single listening address.
+/// The completion queue is processed by the nested class Runner.
+class GRPCServer
+{
+public:
+ GRPCServer(IServer & iserver_, const Poco::Net::SocketAddress & address_to_listen_);
+ ~GRPCServer();
+
+ /// Starts the server. A new thread will be created that waits for and accepts incoming connections.
+ void start();
+
+ /// Stops the server. No new connections will be accepted.
+ void stop();
+
+ /// Returns the port this server is listening to.
+ UInt16 portNumber() const { return address_to_listen.port(); }
+
+ /// Returns the number of currently handled connections.
+ size_t currentConnections() const;
+
+ /// Returns the number of current threads.
+ /// (Reported as equal to the number of current connections.)
+ size_t currentThreads() const { return currentConnections(); }
+
+private:
+ using GRPCService = clickhouse::grpc::ClickHouse::AsyncService;
+ class Runner;
+
+ IServer & iserver;
+ const Poco::Net::SocketAddress address_to_listen;
+ Poco::Logger * log;
+ GRPCService grpc_service;
+ /// `grpc_server` and `queue` are created in start().
+ std::unique_ptr<grpc::Server> grpc_server;
+ std::unique_ptr<grpc::ServerCompletionQueue> queue;
+ std::unique_ptr<Runner> runner;
+};
+}
+#endif
diff --git a/contrib/clickhouse/src/Server/HTTP/HTMLForm.cpp b/contrib/clickhouse/src/Server/HTTP/HTMLForm.cpp
new file mode 100644
index 0000000000..1abf9e5b83
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTMLForm.cpp
@@ -0,0 +1,347 @@
+#include <Server/HTTP/HTMLForm.h>
+
+#include <Core/Settings.h>
+#include <IO/EmptyReadBuffer.h>
+#include <IO/ReadBufferFromString.h>
+#include <Server/HTTP/ReadHeaders.h>
+
+#include <Poco/CountingStream.h>
+#include <Poco/Net/MultipartReader.h>
+#include <Poco/Net/MultipartWriter.h>
+#include <Poco/Net/NetException.h>
+#include <Poco/Net/NullPartHandler.h>
+#include <Poco/NullStream.h>
+#include <Poco/StreamCopier.h>
+#include <Poco/UTF8String.h>
+
+#include <sstream>
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int CANNOT_READ_ALL_DATA;
+}
+
+namespace
+{
+
+/// Part handler which ignores uploaded files: it doesn't read anything from the part.
+class NullPartHandler : public HTMLForm::PartHandler
+{
+public:
+ void handlePart(const Poco::Net::MessageHeader &, ReadBuffer &) override {}
+};
+
+}
+
+/// Media types of form data: ENCODING_URL is the default encoding, ENCODING_MULTIPART selects multipart parsing in load().
+const std::string HTMLForm::ENCODING_URL = "application/x-www-form-urlencoded";
+const std::string HTMLForm::ENCODING_MULTIPART = "multipart/form-data";
+/// NOTE(review): not referenced in this file - presumably kept for compatibility with Poco::Net::HTMLForm; confirm.
+const int HTMLForm::UNKNOWN_CONTENT_LENGTH = -1;
+
+
+/// Creates an empty URL-encoded form. The limits on the number of fields and on the sizes
+/// of field names and values are taken from the `http_max_*` settings.
+HTMLForm::HTMLForm(const Settings & settings)
+ : max_fields_number(settings.http_max_fields)
+ , max_field_name_size(settings.http_max_field_name_size)
+ , max_field_value_size(settings.http_max_field_value_size)
+ , encoding(ENCODING_URL)
+{
+}
+
+
+/// Creates an empty form with the specified encoding. The encoding is stored as is, without validation.
+HTMLForm::HTMLForm(const Settings & settings, const std::string & encoding_) : HTMLForm(settings)
+{
+ encoding = encoding_;
+}
+
+
+/// Creates a form and immediately loads it from the request; uploaded files are passed to `handler`.
+HTMLForm::HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody, PartHandler & handler)
+ : HTMLForm(settings)
+{
+ load(request, requestBody, handler);
+}
+
+
+/// Creates a form and immediately loads it from the request; uploaded files are discarded.
+HTMLForm::HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody) : HTMLForm(settings)
+{
+ load(request, requestBody);
+}
+
+
+/// Creates a form from the query string of the request URI only; the request body is not read.
+HTMLForm::HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request) : HTMLForm(settings, Poco::URI(request.getURI()))
+{
+}
+
+/// Creates a form from the raw (not yet decoded) query string of `uri`.
+/// Unlike load(), the query is parsed even if it is empty.
+HTMLForm::HTMLForm(const Settings & settings, const Poco::URI & uri) : HTMLForm(settings)
+{
+ ReadBufferFromString istr(uri.getRawQuery()); // STYLE_CHECK_ALLOW_STD_STRING_STREAM
+ readQuery(istr);
+}
+
+
+/// Clears the form and fills it from `request`: first from the query string of the URI,
+/// then (for POST and PUT requests only) from the request body, which is parsed either as
+/// multipart/form-data or as URL-encoded data depending on the media type of Content-Type.
+/// Uploaded files found in a multipart body are passed to `handler`.
+void HTMLForm::load(const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody, PartHandler & handler)
+{
+    clear();
+
+    /// Fields from the URI come first.
+    Poco::URI uri(request.getURI());
+    const std::string & raw_query = uri.getRawQuery();
+    if (!raw_query.empty())
+    {
+        ReadBufferFromString query_buf(raw_query);
+        readQuery(query_buf);
+    }
+
+    /// Only POST and PUT requests carry form data in the body.
+    const bool is_post = request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST;
+    if (!is_post && request.getMethod() != Poco::Net::HTTPRequest::HTTP_PUT)
+        return;
+
+    std::string media_type;
+    NameValueCollection content_type_params;
+    Poco::Net::MessageHeader::splitParameters(request.getContentType(), media_type, content_type_params);
+    encoding = media_type;
+
+    if (encoding != ENCODING_MULTIPART)
+    {
+        /// Everything which is not multipart is parsed as URL-encoded data.
+        readQuery(requestBody);
+        return;
+    }
+
+    boundary = content_type_params["boundary"];
+    readMultipart(requestBody, handler);
+}
+
+
+/// Same as load() with a part handler, but uploaded files are ignored.
+void HTMLForm::load(const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody)
+{
+    NullPartHandler discarding_handler;
+    load(request, requestBody, discarding_handler);
+}
+
+
+/// Public entry point for reading URL-encoded data: forwards to readQuery() without clearing the form.
+void HTMLForm::read(ReadBuffer & in)
+{
+ readQuery(in);
+}
+
+
+/// Parses URL-encoded data (`name=value&name=value...`) from `in` until EOF and adds the fields to the form.
+/// '+' is replaced with a space, then names and values are percent-decoded.
+/// Throws Poco::Net::HTMLFormException if there are more than `max_fields_number` fields (when that limit is non-zero)
+/// or if a name or a value is longer than `max_field_name_size` / `max_field_value_size`.
+/// Note: a field is added on every iteration, so empty input produces one field with empty name and value.
+void HTMLForm::readQuery(ReadBuffer & in)
+{
+ size_t fields = 0;
+ char ch = 0; // silence "uninitialized" warning from gcc-*
+ bool is_first = true;
+
+ while (true)
+ {
+ if (max_fields_number > 0 && fields == max_fields_number)
+ throw Poco::Net::HTMLFormException("Too many form fields");
+
+ std::string name;
+ std::string value;
+
+ /// Read the name up to '=', '&' or EOF.
+ while (in.read(ch) && ch != '=' && ch != '&')
+ {
+ if (ch == '+')
+ ch = ' ';
+ if (name.size() < max_field_name_size)
+ name += ch;
+ else
+ throw Poco::Net::HTMLFormException("Field name too long");
+ }
+
+ /// `ch` keeps the last character read, so '=' here means that a value follows the name.
+ if (ch == '=')
+ {
+ while (in.read(ch) && ch != '&')
+ {
+ if (ch == '+')
+ ch = ' ';
+ if (value.size() < max_field_value_size)
+ value += ch;
+ else
+ throw Poco::Net::HTMLFormException("Field value too long");
+ }
+ }
+
+ // Remove UTF-8 BOM from first name, if present
+ if (is_first)
+ Poco::UTF8::removeBOM(name);
+
+ std::string decoded_name;
+ std::string decoded_value;
+ Poco::URI::decode(name, decoded_name);
+ Poco::URI::decode(value, decoded_value);
+ add(decoded_name, decoded_value);
+ ++fields;
+
+ is_first = false;
+
+ if (in.eof())
+ break;
+ }
+}
+
+
+/// Parses a multipart/form-data body from `in_` using the previously set `boundary`.
+/// Parts whose Content-Disposition header has a `filename` parameter are passed to `handler`,
+/// every other part is read completely and added to the form under its `name` parameter.
+/// Throws Poco::Net::HTMLFormException if no boundary line is found, if there are more than
+/// `max_fields_number` parts (when that limit is non-zero) or if a field value is longer than
+/// `max_field_value_size`. Throws Exception(CANNOT_READ_ALL_DATA) if the input ends before the last boundary.
+/// The limits are enforced exactly like in readQuery(): `max_fields_number` parts and
+/// values of `max_field_value_size` bytes are still accepted, anything beyond is rejected.
+void HTMLForm::readMultipart(ReadBuffer & in_, PartHandler & handler)
+{
+    /// Assume there is always a boundary provided.
+    assert(!boundary.empty());
+
+    size_t fields = 0;
+    MultipartReadBuffer in(in_, boundary);
+
+    if (!in.skipToNextBoundary())
+        throw Poco::Net::HTMLFormException("No boundary line found");
+
+    /// Read each part until next boundary (or last boundary)
+    while (!in.eof())
+    {
+        /// `fields` parts have been read already, so one more would exceed the limit.
+        if (max_fields_number > 0 && fields == max_fields_number)
+            throw Poco::Net::HTMLFormException("Too many form fields");
+
+        Poco::Net::MessageHeader header;
+        readHeaders(header, in, max_fields_number, max_field_name_size, max_field_value_size);
+        skipToNextLineOrEOF(in);
+
+        NameValueCollection params;
+        if (header.has("Content-Disposition"))
+        {
+            std::string unused;
+            Poco::Net::MessageHeader::splitParameters(header.get("Content-Disposition"), unused, params);
+        }
+
+        if (params.has("filename"))
+            handler.handlePart(header, in);
+        else
+        {
+            std::string name = params["name"];
+            std::string value;
+            char ch = 0;
+
+            while (in.read(ch))
+            {
+                /// Check before appending, so the value never grows beyond the limit.
+                if (value.size() >= max_field_value_size)
+                    throw Poco::Net::HTMLFormException("Field value too long");
+                value += ch;
+            }
+
+            add(name, value);
+        }
+
+        ++fields;
+
+        /// If we already encountered EOF for the buffer |in|, it's possible that the next symbol is a start of boundary line.
+        /// In this case reading the boundary line will reset the EOF state, potentially breaking invariant of EOF idempotency -
+        /// if there is such invariant in the first place.
+        if (!in.skipToNextBoundary())
+            break;
+    }
+
+    /// It's important to check, because we could get "fake" EOF and incomplete request if a client suddenly died in the middle.
+    if (!in.isActualEOF())
+        throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Unexpected EOF, "
+            "did not find the last boundary while parsing a multipart HTTP request");
+}
+
+
+/// Wraps `in_` for reading a multipart body. The delimiter looked for in the data is "--" followed by `boundary_`.
+/// The buffer starts with an empty working buffer; data becomes available after skipToNextBoundary().
+HTMLForm::MultipartReadBuffer::MultipartReadBuffer(ReadBuffer & in_, const std::string & boundary_)
+ : ReadBuffer(nullptr, 0), in(in_), boundary("--" + boundary_)
+{
+ /// For consistency with |nextImpl()|
+ position() = in.position();
+}
+
+/// Skips lines of the underlying buffer up to and including the next boundary line.
+/// Returns true if a boundary line starting a new part was found (the buffer is then ready to read that part).
+/// Returns false if the underlying buffer is at EOF, if the input ends without a boundary line,
+/// or if the closing boundary (the boundary followed by "--") was found.
+/// Must be called only after the current part has been read up to its boundary.
+bool HTMLForm::MultipartReadBuffer::skipToNextBoundary()
+{
+ if (in.eof())
+ return false;
+
+ chassert(boundary_hit);
+ chassert(!found_last_boundary);
+
+ boundary_hit = false;
+
+ while (!in.eof())
+ {
+ auto line = readLine(true);
+ if (startsWith(line, boundary))
+ {
+ /// Reset the working buffer to the position right after the boundary line.
+ set(in.position(), 0);
+ next(); /// We need to restrict our buffer to size of next available line.
+ found_last_boundary = startsWith(line, boundary + "--");
+ return !found_last_boundary;
+ }
+ }
+
+ return false;
+}
+
+/// Reads one line from the underlying buffer and returns it.
+/// The first two characters are always taken as is (they may be the CRLF which ended the previous line),
+/// then characters are read up to the next CRLF or EOF. A '\r' which is not followed by '\n' is kept in the line.
+/// If `append_crlf` is true the terminating CRLF is included in the result (and a line consisting
+/// of CRLF only is returned immediately), otherwise the terminating CRLF is consumed but not returned.
+std::string HTMLForm::MultipartReadBuffer::readLine(bool append_crlf)
+{
+ std::string line;
+ char ch = 0; // silence "uninitialized" warning from gcc-*
+
+ /// If we don't append CRLF, it means that we may have to prepend CRLF from previous content line, which wasn't the boundary.
+ if (in.read(ch))
+ line += ch;
+ if (in.read(ch))
+ line += ch;
+ if (append_crlf && line == "\r\n")
+ return line;
+
+ while (!in.eof())
+ {
+ while (in.read(ch) && ch != '\r')
+ line += ch;
+
+ if (in.eof()) break;
+
+ assert(ch == '\r');
+
+ if (in.peek(ch) && ch == '\n')
+ {
+ in.ignore();
+ if (append_crlf) line += "\r\n";
+ break;
+ }
+
+ /// Not a CRLF: `ch` holds the peeked (not yet consumed) character unless peek() failed, in which case it is still '\r'.
+ line += ch;
+ }
+
+ return line;
+}
+
+/// Makes the next line of the current part available in the working buffer.
+/// Returns false when the next line is a boundary line (the part is finished) or when nothing could be read.
+bool HTMLForm::MultipartReadBuffer::nextImpl()
+{
+ if (boundary_hit)
+ return false;
+
+ assert(position() >= in.position());
+
+ /// Everything read through this buffer has been consumed from the underlying buffer too.
+ in.position() = position();
+
+ /// We expect to start from the first symbol after EOL, so we can put checkpoint
+ /// and safely try to read til the next EOL and check for boundary.
+ in.setCheckpoint();
+
+ /// FIXME: there is an extra copy because we cannot traverse PeekableBuffer from checkpoint to position()
+ /// since it may store different data parts in different sub-buffers,
+ /// anyway calling makeContinuousMemoryFromCheckpointToPos() will also make an extra copy.
+ /// According to RFC2046 the preceding CRLF is a part of boundary line.
+ std::string line = readLine(false);
+ boundary_hit = startsWith(line, "\r\n" + boundary);
+ bool has_next = !boundary_hit && !line.empty();
+
+ if (has_next)
+ /// If we don't make sure that memory is contiguous then situation may happen, when part of the line is inside internal memory
+ /// and other part is inside sub-buffer, thus we'll be unable to setup our working buffer properly.
+ in.makeContinuousMemoryFromCheckpointToPos();
+
+ /// The line was only peeked: return to the checkpoint, so the data is read again through the working buffer.
+ in.rollbackToCheckpoint(true);
+
+ /// Rolling back to checkpoint may change underlying buffers.
+ /// Limit readable data to a single line.
+ BufferBase::set(in.position(), line.size(), 0);
+
+ return has_next;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTMLForm.h b/contrib/clickhouse/src/Server/HTTP/HTMLForm.h
new file mode 100644
index 0000000000..c75dafccaf
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTMLForm.h
@@ -0,0 +1,124 @@
+#pragma once
+
+#include <IO/PeekableReadBuffer.h>
+#include <IO/ReadHelpers.h>
+
+#include <boost/noncopyable.hpp>
+#include <Poco/Net/HTTPRequest.h>
+#include <Poco/Net/NameValueCollection.h>
+#include <Poco/Net/PartSource.h>
+#include <Poco/URI.h>
+
+namespace DB
+{
+
+struct Settings;
+
+/// Collection of HTML form fields (name-value pairs) parsed from the query string of a URI
+/// and/or from the body of an HTTP request, either URL-encoded or multipart/form-data.
+/// The number of fields and the sizes of names and values are limited by the settings passed to the constructor.
+class HTMLForm : public Poco::Net::NameValueCollection, private boost::noncopyable
+{
+public:
+ class PartHandler;
+
+ /// NOTE(review): not used by the code in HTMLForm.cpp - presumably kept for compatibility with Poco::Net::HTMLForm; confirm.
+ enum Options
+ {
+ OPT_USE_CONTENT_LENGTH = 0x01, /// don't use Chunked Transfer-Encoding for multipart requests.
+ };
+
+ /// Creates an empty HTMLForm and sets the
+ /// encoding to "application/x-www-form-urlencoded".
+ explicit HTMLForm(const Settings & settings);
+
+ /// Creates an empty HTMLForm that uses the given encoding.
+ /// Encoding must be either "application/x-www-form-urlencoded" (which is the default) or "multipart/form-data".
+ explicit HTMLForm(const Settings & settings, const std::string & encoding);
+
+ /// Creates a HTMLForm from the given HTTP request.
+ /// Uploaded files are passed to the given PartHandler.
+ HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody, PartHandler & handler);
+
+ /// Creates a HTMLForm from the given HTTP request.
+ /// Uploaded files are silently discarded.
+ HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody);
+
+ /// Creates a HTMLForm from the given HTTP request.
+ /// The request must be a GET request and the form data must be in the query string (URL encoded).
+ /// For POST requests, you must use one of the constructors taking an additional input stream for the request body.
+ explicit HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request);
+
+ /// Creates a HTMLForm from the query string of the given URI.
+ explicit HTMLForm(const Settings & settings, const Poco::URI & uri);
+
+ /// Returns the value of the field `key` parsed as T, or `default_value` if there is no such field.
+ template <typename T>
+ T getParsed(const std::string & key, T default_value)
+ {
+ auto it = find(key);
+ return (it != end()) ? DB::parse<T>(it->second) : default_value;
+ }
+
+ /// Reads the form data from the given HTTP request.
+ /// Uploaded files are passed to the given PartHandler.
+ void load(const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody, PartHandler & handler);
+
+ /// Reads the form data from the given HTTP request.
+ /// Uploaded files are silently discarded.
+ void load(const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody);
+
+ /// Reads the URL-encoded form data from the given input stream.
+ /// Note that read() does not clear the form before reading the new values.
+ void read(ReadBuffer & in);
+
+ static const std::string ENCODING_URL; /// "application/x-www-form-urlencoded"
+ static const std::string ENCODING_MULTIPART; /// "multipart/form-data"
+ static const int UNKNOWN_CONTENT_LENGTH;
+
+protected:
+ /// Parsers of URL-encoded and multipart data; both add the parsed fields to the form.
+ void readQuery(ReadBuffer & in);
+ void readMultipart(ReadBuffer & in, PartHandler & handler);
+
+private:
+ /// This buffer provides data line by line to check for boundary line in a convenient way.
+ class MultipartReadBuffer;
+
+ struct Part
+ {
+ std::string name;
+ std::unique_ptr<Poco::Net::PartSource> source;
+ };
+
+ using PartVec = std::vector<Part>;
+
+ /// Limits taken from the settings; see the constructor.
+ const size_t max_fields_number, max_field_name_size, max_field_value_size;
+
+ std::string encoding;
+ /// Boundary of the multipart body, taken from the Content-Type header in load().
+ std::string boundary;
+ /// NOTE(review): `parts` is not filled by the code in HTMLForm.cpp - confirm whether it is still needed.
+ PartVec parts;
+};
+
+/// Interface for processing uploaded files: handlePart() is called for every part of a multipart body
+/// which has a `filename` parameter, with the headers of the part and a buffer for reading its content.
+class HTMLForm::PartHandler
+{
+public:
+ virtual ~PartHandler() = default;
+ virtual void handlePart(const Poco::Net::MessageHeader &, ReadBuffer &) = 0;
+};
+
+/// Read buffer over a multipart body which exposes the content of one part at a time:
+/// reading stops (as on EOF) at the boundary line and continues with the next part after skipToNextBoundary().
+class HTMLForm::MultipartReadBuffer : public ReadBuffer
+{
+public:
+ MultipartReadBuffer(ReadBuffer & in, const std::string & boundary);
+
+ /// Returns true if the next part is ready for reading.
+ /// Returns false if the last boundary was found or if the input ended before any boundary line.
+ bool skipToNextBoundary();
+
+ /// True only if the last (closing) boundary was found, i.e. the body was not truncated.
+ bool isActualEOF() const { return found_last_boundary; }
+
+private:
+ PeekableReadBuffer in;
+ /// The delimiter including the leading "--".
+ const std::string boundary;
+ /// Set when the current part has been read up to its boundary line; initially true, so reading starts with skipToNextBoundary().
+ bool boundary_hit = true;
+ bool found_last_boundary = false;
+
+ std::string readLine(bool append_crlf);
+
+ bool nextImpl() override;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPContext.h b/contrib/clickhouse/src/Server/HTTP/HTTPContext.h
new file mode 100644
index 0000000000..09c46ed188
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPContext.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <Poco/Timespan.h>
+
+namespace DB
+{
+
+/// Interface providing the limits and timeouts used while handling HTTP connections.
+struct IHTTPContext
+{
+ /// Value for `max-age` of the Strict-Transport-Security header; the header is not added if it is 0.
+ virtual uint64_t getMaxHstsAge() const = 0;
+ virtual uint64_t getMaxUriSize() const = 0;
+ virtual uint64_t getMaxFields() const = 0;
+ virtual uint64_t getMaxFieldNameSize() const = 0;
+ virtual uint64_t getMaxFieldValueSize() const = 0;
+ virtual uint64_t getMaxChunkSize() const = 0;
+ virtual Poco::Timespan getReceiveTimeout() const = 0;
+ virtual Poco::Timespan getSendTimeout() const = 0;
+
+ virtual ~IHTTPContext() = default;
+};
+
+/// Shared ownership handle to an IHTTPContext implementation.
+using HTTPContextPtr = std::shared_ptr<IHTTPContext>;
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPRequest.h b/contrib/clickhouse/src/Server/HTTP/HTTPRequest.h
new file mode 100644
index 0000000000..40839cbcdd
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPRequest.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <Poco/Net/HTTPRequest.h>
+
+namespace DB
+{
+
+/// The Poco HTTP request class under a name in the DB namespace.
+using HTTPRequest = Poco::Net::HTTPRequest;
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandler.h b/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandler.h
new file mode 100644
index 0000000000..19340866bb
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandler.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <Server/HTTP/HTTPServerRequest.h>
+#include <Server/HTTP/HTTPServerResponse.h>
+
+#include <boost/noncopyable.hpp>
+
+namespace DB
+{
+
+/// Interface of an object handling a single HTTP request: it reads `request` and writes the answer to `response`.
+class HTTPRequestHandler : private boost::noncopyable
+{
+public:
+ virtual ~HTTPRequestHandler() = default;
+
+ virtual void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) = 0;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandlerFactory.h b/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandlerFactory.h
new file mode 100644
index 0000000000..3d50bf0a2e
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandlerFactory.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandler.h>
+
+#include <boost/noncopyable.hpp>
+
+namespace DB
+{
+
+/// Interface of a factory which chooses and creates a handler for an incoming HTTP request.
+class HTTPRequestHandlerFactory : private boost::noncopyable
+{
+public:
+ virtual ~HTTPRequestHandlerFactory() = default;
+
+ /// May return nullptr (callers check the result before using it).
+ virtual std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) = 0;
+};
+
+/// Shared ownership handle to a request handler factory.
+using HTTPRequestHandlerFactoryPtr = std::shared_ptr<HTTPRequestHandlerFactory>;
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPResponse.h b/contrib/clickhouse/src/Server/HTTP/HTTPResponse.h
new file mode 100644
index 0000000000..c73bcec6c3
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPResponse.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <Poco/Net/HTTPResponse.h>
+
+namespace DB
+{
+
+/// The Poco HTTP response class under a name in the DB namespace.
+using HTTPResponse = Poco::Net::HTTPResponse;
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServer.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServer.cpp
new file mode 100644
index 0000000000..4673493326
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServer.cpp
@@ -0,0 +1,30 @@
+#include <Server/HTTP/HTTPServer.h>
+
+#include <Server/HTTP/HTTPServerConnectionFactory.h>
+
+
+namespace DB
+{
+/// Creates a TCP server whose connections are handled as HTTP connections:
+/// a new HTTPServerConnectionFactory is passed to the base TCPServer,
+/// and a reference to the request handler factory is also kept in `factory`.
+HTTPServer::HTTPServer(
+ HTTPContextPtr context,
+ HTTPRequestHandlerFactoryPtr factory_,
+ Poco::ThreadPool & thread_pool,
+ Poco::Net::ServerSocket & socket_,
+ Poco::Net::HTTPServerParams::Ptr params)
+ : TCPServer(new HTTPServerConnectionFactory(context, params, factory_), thread_pool, socket_, params), factory(factory_)
+{
+}
+
+HTTPServer::~HTTPServer()
+{
+ /// We should call stop and join thread here instead of destructor of parent TCPServer,
+ /// because there's possible race on 'vptr' between this virtual destructor and 'run' method.
+ stop();
+}
+
+/// Stops the server. The argument is ignored: the behaviour is the same as that of stop().
+void HTTPServer::stopAll(bool /* abortCurrent */)
+{
+ stop();
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServer.h b/contrib/clickhouse/src/Server/HTTP/HTTPServer.h
new file mode 100644
index 0000000000..adfb21e7c6
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServer.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandlerFactory.h>
+#include <Server/HTTP/HTTPContext.h>
+#include <Server/TCPServer.h>
+
+#include <Poco/Net/HTTPServerParams.h>
+
+#include <base/types.h>
+
+
+namespace DB
+{
+
+/// TCP server which handles its connections as HTTP connections
+/// and passes the requests to handlers created by the given factory.
+class HTTPServer : public TCPServer
+{
+public:
+ explicit HTTPServer(
+ HTTPContextPtr context,
+ HTTPRequestHandlerFactoryPtr factory,
+ Poco::ThreadPool & thread_pool,
+ Poco::Net::ServerSocket & socket,
+ Poco::Net::HTTPServerParams::Ptr params);
+
+ /// Stops the server.
+ ~HTTPServer() override;
+
+ /// Stops the server; `abort_current` is ignored by the implementation.
+ void stopAll(bool abort_current = false);
+
+private:
+ HTTPRequestHandlerFactoryPtr factory;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.cpp
new file mode 100644
index 0000000000..ad17bc4348
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.cpp
@@ -0,0 +1,121 @@
+#include <Server/HTTP/HTTPServerConnection.h>
+#include <Server/TCPServer.h>
+
+#include <Poco/Net/NetException.h>
+
+namespace DB
+{
+
+/// Stores the collaborators needed to serve requests on `socket`.
+/// `stopped` starts out false; a null `factory` is rejected by poco_check_ptr.
+HTTPServerConnection::HTTPServerConnection(
+    HTTPContextPtr context_,
+    TCPServer & tcp_server_,
+    const Poco::Net::StreamSocket & socket,
+    Poco::Net::HTTPServerParams::Ptr params_,
+    HTTPRequestHandlerFactoryPtr factory_)
+    : TCPServerConnection(socket)
+    , context(std::move(context_))
+    , tcp_server(tcp_server_)
+    , params(params_)
+    , factory(factory_)
+    , stopped(false)
+{
+    poco_check_ptr(factory);
+}
+
+/// Serves HTTP requests on this connection's socket, one per loop iteration, until the
+/// connection is stopped, the TCP server is closed, the session has no more requests or is
+/// disconnected, or an exception leaves the loop.
+void HTTPServerConnection::run()
+{
+ /// Value for the "Server" response header; the header is skipped when this is empty.
+ std::string server = params->getSoftwareVersion();
+ Poco::Net::HTTPServerSession session(socket(), params);
+
+ /// NOTE(review): `stopped` is initialized to false in the constructor and is not modified
+ /// anywhere in this function.
+ while (!stopped && tcp_server.isOpen() && session.hasMoreRequests() && session.connected())
+ {
+ try
+ {
+ /// The lock is held for the whole handling of one request.
+ std::lock_guard lock(mutex);
+ /// Re-check the conditions after the lock has been taken.
+ if (!stopped && tcp_server.isOpen() && session.connected())
+ {
+ /// The request constructor reads and parses the request from the session.
+ HTTPServerResponse response(session);
+ HTTPServerRequest request(context, response, session);
+
+ Poco::Timestamp now;
+
+ /// Expose the forwarded address (if one was given at construction) as a request header.
+ if (!forwarded_for.empty())
+ request.set("X-Forwarded-For", forwarded_for);
+
+ /// On secure connections add an HSTS header when a positive max-age is configured.
+ if (request.isSecure())
+ {
+ size_t hsts_max_age = context->getMaxHstsAge();
+
+ if (hsts_max_age > 0)
+ response.add("Strict-Transport-Security", "max-age=" + std::to_string(hsts_max_age));
+
+ }
+
+ /// Default response attributes, set before the handler runs.
+ response.setDate(now);
+ response.setVersion(request.getVersion());
+ response.setKeepAlive(params->getKeepAlive() && request.getKeepAlive() && session.canKeepAlive());
+ if (!server.empty())
+ response.set("Server", server);
+ try
+ {
+ /// The server may have been closed while the request was being read: answer 503 and leave the loop.
+ if (!tcp_server.isOpen())
+ {
+ sendErrorResponse(session, Poco::Net::HTTPResponse::HTTP_SERVICE_UNAVAILABLE);
+ break;
+ }
+ std::unique_ptr<HTTPRequestHandler> handler(factory->createRequestHandler(request));
+
+ if (handler)
+ {
+ /// Honour "Expect: 100-continue" only while the response status is still 200 OK.
+ if (request.getExpectContinue() && response.getStatus() == Poco::Net::HTTPResponse::HTTP_OK)
+ response.sendContinue();
+
+ handler->handleRequest(request, response);
+ /// Keep-alive for the session is recomputed from the response after handling.
+ session.setKeepAlive(params->getKeepAlive() && response.getKeepAlive() && session.canKeepAlive());
+ }
+ else
+ /// The factory produced no handler for this request: answer 501.
+ sendErrorResponse(session, Poco::Net::HTTPResponse::HTTP_NOT_IMPLEMENTED);
+ }
+ catch (Poco::Exception &)
+ {
+ /// Best effort: if nothing has been sent yet, try to answer 500; failures of that
+ /// attempt are deliberately ignored. The original exception is rethrown either way.
+ if (!response.sent())
+ {
+ try
+ {
+ sendErrorResponse(session, Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
+ }
+ catch (...)
+ {
+ }
+ }
+ throw;
+ }
+ }
+ }
+ catch (const Poco::Net::NoMessageException &)
+ {
+ /// No request data was available (thrown by HTTPServerRequest::readRequest on EOF): stop serving.
+ break;
+ }
+ catch (const Poco::Net::MessageException &)
+ {
+ /// Malformed request: answer 400. The loop condition is then evaluated again.
+ sendErrorResponse(session, Poco::Net::HTTPResponse::HTTP_BAD_REQUEST);
+ }
+ catch (const Poco::Exception &)
+ {
+ /// Prefer the network exception recorded on the session, if there is one, over the caught exception.
+ if (session.networkException())
+ {
+ session.networkException()->rethrow();
+ }
+ else
+ throw;
+ }
+ }
+}
+
+// static
+/// Sends an HTTP/1.1 response carrying only `status` (with its standard reason phrase) on
+/// `session`. No body is written here, and keep-alive is switched off on both the response
+/// and the session.
+void HTTPServerConnection::sendErrorResponse(Poco::Net::HTTPServerSession & session, Poco::Net::HTTPResponse::HTTPStatus status)
+{
+    HTTPServerResponse error_response(session);
+
+    error_response.setVersion(Poco::Net::HTTPMessage::HTTP_1_1);
+    error_response.setStatusAndReason(status);
+    error_response.setKeepAlive(false);
+    error_response.send();
+
+    session.setKeepAlive(false);
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.h b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.h
new file mode 100644
index 0000000000..7087f8d5a2
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.h
@@ -0,0 +1,51 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandlerFactory.h>
+#include <Server/HTTP/HTTPContext.h>
+
+#include <Poco/Net/HTTPServerParams.h>
+#include <Poco/Net/HTTPServerSession.h>
+#include <Poco/Net/TCPServerConnection.h>
+
+namespace DB
+{
+class TCPServer;
+
+/// A single accepted connection that serves HTTP requests in run().
+class HTTPServerConnection : public Poco::Net::TCPServerConnection
+{
+public:
+ /// Creates a connection without a forwarded-for address.
+ HTTPServerConnection(
+ HTTPContextPtr context,
+ TCPServer & tcp_server,
+ const Poco::Net::StreamSocket & socket,
+ Poco::Net::HTTPServerParams::Ptr params,
+ HTTPRequestHandlerFactoryPtr factory);
+
+ /// Same as the constructor above, but also remembers `forwarded_for_`;
+ /// run() sets it as the "X-Forwarded-For" header of each request when it is non-empty.
+ HTTPServerConnection(
+ HTTPContextPtr context_,
+ TCPServer & tcp_server_,
+ const Poco::Net::StreamSocket & socket_,
+ Poco::Net::HTTPServerParams::Ptr params_,
+ HTTPRequestHandlerFactoryPtr factory_,
+ const String & forwarded_for_)
+ : HTTPServerConnection(context_, tcp_server_, socket_, params_, factory_)
+ {
+ forwarded_for = forwarded_for_;
+ }
+
+ /// Request-serving loop; see the definition for the error handling rules.
+ void run() override;
+
+protected:
+ /// Sends a response with the given status on `session` and disables keep-alive.
+ static void sendErrorResponse(Poco::Net::HTTPServerSession & session, Poco::Net::HTTPResponse::HTTPStatus status);
+
+private:
+ HTTPContextPtr context;
+ TCPServer & tcp_server;
+ Poco::Net::HTTPServerParams::Ptr params;
+ HTTPRequestHandlerFactoryPtr factory;
+ /// Empty unless provided through the six-argument constructor.
+ String forwarded_for;
+ /// Initialized to false by the constructor; checked by the loop in run().
+ bool stopped;
+ std::mutex mutex; // guards the |factory| with assumption that creating handlers is not thread-safe.
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.cpp
new file mode 100644
index 0000000000..2c9ac0cda2
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.cpp
@@ -0,0 +1,24 @@
+#include <Server/HTTP/HTTPServerConnectionFactory.h>
+
+#include <Server/HTTP/HTTPServerConnection.h>
+
+namespace DB
+{
+/// Remembers the context, server parameters and request handler factory that every created
+/// connection will receive. A null `factory_` is rejected by poco_check_ptr.
+HTTPServerConnectionFactory::HTTPServerConnectionFactory(
+    HTTPContextPtr context_,
+    Poco::Net::HTTPServerParams::Ptr params_,
+    HTTPRequestHandlerFactoryPtr factory_)
+    : context(std::move(context_))
+    , params(params_)
+    , factory(factory_)
+{
+    poco_check_ptr(factory);
+}
+
+/// Allocates a new HTTPServerConnection for `socket`. The raw pointer is returned to the caller.
+Poco::Net::TCPServerConnection * HTTPServerConnectionFactory::createConnection(
+    const Poco::Net::StreamSocket & socket, TCPServer & tcp_server)
+{
+    auto * connection = new HTTPServerConnection(context, tcp_server, socket, params, factory);
+    return connection;
+}
+
+/// Allocates a new HTTPServerConnection for `socket`, additionally passing along the
+/// `forwarded_for` value found in `stack_data`. The raw pointer is returned to the caller.
+Poco::Net::TCPServerConnection * HTTPServerConnectionFactory::createConnection(
+    const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData & stack_data)
+{
+    auto * connection = new HTTPServerConnection(context, tcp_server, socket, params, factory, stack_data.forwarded_for);
+    return connection;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.h b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.h
new file mode 100644
index 0000000000..e18249da4d
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandlerFactory.h>
+#include <Server/HTTP/HTTPContext.h>
+#include <Server/TCPServerConnectionFactory.h>
+
+#include <Poco/Net/HTTPServerParams.h>
+
+namespace DB
+{
+
+/// Connection factory that produces HTTPServerConnection objects for accepted sockets.
+class HTTPServerConnectionFactory : public TCPServerConnectionFactory
+{
+public:
+ HTTPServerConnectionFactory(HTTPContextPtr context, Poco::Net::HTTPServerParams::Ptr params, HTTPRequestHandlerFactoryPtr factory);
+
+ /// Creates a connection for `socket`.
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override;
+ /// Creates a connection for `socket`, also passing `stack_data.forwarded_for` to it.
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData & stack_data) override;
+
+private:
+ /// Values handed to every connection this factory creates.
+ HTTPContextPtr context;
+ Poco::Net::HTTPServerParams::Ptr params;
+ HTTPRequestHandlerFactoryPtr factory;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.cpp
new file mode 100644
index 0000000000..891ac39c93
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.cpp
@@ -0,0 +1,174 @@
+#include <Server/HTTP/HTTPServerRequest.h>
+
+#include <IO/EmptyReadBuffer.h>
+#include <IO/HTTPChunkedReadBuffer.h>
+#include <IO/LimitReadBuffer.h>
+#include <IO/ReadBufferFromPocoSocket.h>
+#include <IO/ReadHelpers.h>
+#include <Server/HTTP/HTTPServerResponse.h>
+#include <Server/HTTP/ReadHeaders.h>
+
+#include <Poco/Net/HTTPHeaderStream.h>
+#include <Poco/Net/HTTPStream.h>
+#include <Poco/Net/NetException.h>
+
+#include <Common/logger_useful.h>
+
+#if USE_SSL
+#include <Poco/Net/SecureStreamSocketImpl.h>
+#include <Poco/Net/SSLException.h>
+#include <Poco/Net/X509Certificate.h>
+#endif
+
+namespace DB
+{
+/// Reads one HTTP request (request line and headers) from the session socket and prepares
+/// `stream` for reading the request body.
+///
+/// The parsing limits and socket timeouts are taken from `context`. This object registers
+/// itself with `response` through attachRequest(). Exceptions thrown by readRequest()
+/// propagate out of the constructor.
+HTTPServerRequest::HTTPServerRequest(HTTPContextPtr context, HTTPServerResponse & response, Poco::Net::HTTPServerSession & session)
+ : max_uri_size(context->getMaxUriSize())
+ , max_fields_number(context->getMaxFields())
+ , max_field_name_size(context->getMaxFieldNameSize())
+ , max_field_value_size(context->getMaxFieldValueSize())
+{
+ response.attachRequest(this);
+
+ /// Now that we know socket is still connected, obtain addresses
+ client_address = session.clientAddress();
+ server_address = session.serverAddress();
+ secure = session.socket().secure();
+
+ auto receive_timeout = context->getReceiveTimeout();
+ auto send_timeout = context->getSendTimeout();
+
+ session.socket().setReceiveTimeout(receive_timeout);
+ session.socket().setSendTimeout(send_timeout);
+
+ auto in = std::make_unique<ReadBufferFromPocoSocket>(session.socket());
+ /// Raw pointer to the socket implementation; used later by checkPeerConnected() and the certificate accessors.
+ socket = session.socket().impl();
+
+ readRequest(*in); /// Try parse according to RFC7230
+
+ /// If a client crashes, most systems will gracefully terminate the connection with FIN just like it's done on close().
+ /// So we will get 0 from recv(...) and will not be able to understand that something went wrong (well, we probably
+ /// will get RST later on attempt to write to the socket that closed on the other side, but it will happen when the query is finished).
+ /// If we are extremely unlucky and data format is TSV, for example, then we may stop parsing exactly between rows
+ /// and decide that it's EOF (but it is not). It may break deduplication, because clients cannot control it
+ /// and retry with exactly the same (incomplete) set of rows.
+ /// That's why we have to check body size if it's provided.
+ if (getChunkedTransferEncoding())
+ stream = std::make_unique<HTTPChunkedReadBuffer>(std::move(in), context->getMaxChunkSize());
+ else if (hasContentLength())
+ {
+ size_t content_length = getContentLength();
+ stream = std::make_unique<LimitReadBuffer>(std::move(in), content_length,
+ /* throw_exception */ true, /* exact_limit */ content_length);
+ }
+ /// Neither chunked nor Content-Length, and the method is not GET/HEAD/DELETE: read the body straight from the socket buffer.
+ else if (getMethod() != HTTPRequest::HTTP_GET && getMethod() != HTTPRequest::HTTP_HEAD && getMethod() != HTTPRequest::HTTP_DELETE)
+ {
+ stream = std::move(in);
+ if (!startsWith(getContentType(), "multipart/form-data"))
+ LOG_WARNING(LogFrequencyLimiter(&Poco::Logger::get("HTTPServerRequest"), 10), "Got an HTTP request with no content length "
+ "and no chunked/multipart encoding, it may be impossible to distinguish graceful EOF from abnormal connection loss");
+ }
+ else
+ /// We have to distinguish empty buffer and nullptr.
+ stream = std::make_unique<EmptyReadBuffer>();
+}
+
+/// Probes the socket with a non-blocking one-byte peek.
+/// Returns false when the peek yields zero bytes or throws anything other than
+/// Poco::TimeoutException; returns true otherwise (including on timeout).
+bool HTTPServerRequest::checkPeerConnected() const
+{
+    try
+    {
+        char probe;
+        const auto peeked = socket->receiveBytes(&probe, 1, MSG_DONTWAIT | MSG_PEEK);
+        return peeked != 0;
+    }
+    catch (const Poco::TimeoutException &)
+    {
+        /// A timeout is not treated as a disconnect.
+        return true;
+    }
+    catch (...)
+    {
+        return false;
+    }
+}
+
+#if USE_SSL
+/// Returns true only when the connection is secure, the underlying socket is a
+/// SecureStreamSocketImpl, and that socket reports a peer certificate.
+bool HTTPServerRequest::havePeerCertificate() const
+{
+    if (!secure)
+        return false;
+
+    const auto * tls_socket = dynamic_cast<const Poco::Net::SecureStreamSocketImpl *>(socket);
+    return tls_socket != nullptr && tls_socket->havePeerCertificate();
+}
+
+/// Returns the peer certificate of a secure connection.
+/// Throws Poco::Net::SSLException when the connection is not secure or the underlying
+/// socket is not a SecureStreamSocketImpl.
+Poco::Net::X509Certificate HTTPServerRequest::peerCertificate() const
+{
+    const Poco::Net::SecureStreamSocketImpl * tls_socket
+        = secure ? dynamic_cast<const Poco::Net::SecureStreamSocketImpl *>(socket) : nullptr;
+
+    if (!tls_socket)
+        throw Poco::Net::SSLException("No certificate available");
+
+    return tls_socket->peerCertificate();
+}
+#endif
+
+/// Parses the request line (method, URI, version) and then the header fields from `in`.
+///
+/// Throws Poco::Net::NoMessageException when `in` is at EOF right away, and
+/// Poco::Net::MessageException when only whitespace is present or when the method, URI or
+/// version exceeds its length limit. Method, URI and version are stored on the request only
+/// after the headers have been read.
+void HTTPServerRequest::readRequest(ReadBuffer & in)
+{
+    /// Appends characters to `token` until whitespace is read, reading fails,
+    /// or the token has grown beyond `limit` characters.
+    auto read_token = [&in](std::string & token, size_t limit)
+    {
+        char c;
+        while (in.read(c) && !Poco::Ascii::isSpace(c) && token.size() <= limit)
+            token += c;
+    };
+
+    std::string method;
+    std::string uri;
+    std::string version;
+
+    method.reserve(16);
+    uri.reserve(64);
+    version.reserve(16);
+
+    if (in.eof())
+        throw Poco::Net::NoMessageException();
+
+    skipWhitespaceIfAny(in);
+
+    if (in.eof())
+        throw Poco::Net::MessageException("No HTTP request header");
+
+    read_token(method, static_cast<size_t>(MAX_METHOD_LENGTH));
+    if (method.size() > MAX_METHOD_LENGTH)
+        throw Poco::Net::MessageException("HTTP request method invalid or too long");
+
+    skipWhitespaceIfAny(in);
+
+    read_token(uri, max_uri_size);
+    if (uri.size() > max_uri_size)
+        throw Poco::Net::MessageException("HTTP request URI invalid or too long");
+
+    skipWhitespaceIfAny(in);
+
+    read_token(version, static_cast<size_t>(MAX_VERSION_LENGTH));
+    if (version.size() > MAX_VERSION_LENGTH)
+        throw Poco::Net::MessageException("Invalid HTTP version string");
+
+    /// HTTP lines end with '\r\n', so skipping to '\n' finishes the request line.
+    skipToNextLineOrEOF(in);
+
+    readHeaders(*this, in, max_fields_number, max_field_name_size, max_field_value_size);
+
+    /// Skip the rest of the line on which header parsing stopped.
+    skipToNextLineOrEOF(in);
+
+    setMethod(method);
+    setURI(uri);
+    setVersion(version);
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.h b/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.h
new file mode 100644
index 0000000000..da0e498b0d
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.h
@@ -0,0 +1,73 @@
+#pragma once
+
+#include <Interpreters/Context_fwd.h>
+#include <IO/ReadBuffer.h>
+#include <Server/HTTP/HTTPRequest.h>
+#include <Server/HTTP/HTTPContext.h>
+#include "clickhouse_config.h"
+
+#include <Poco/Net/HTTPServerSession.h>
+
+namespace Poco::Net { class X509Certificate; }
+
+namespace DB
+{
+
+class HTTPServerResponse;
+class ReadBufferFromPocoSocket;
+
+/// An HTTP request read from a server session: the parsed request line and headers (through
+/// the HTTPRequest base) plus a ReadBuffer over the request body.
+class HTTPServerRequest : public HTTPRequest
+{
+public:
+ /// Reads and parses the request from `session`; see the definition for details.
+ HTTPServerRequest(HTTPContextPtr context, HTTPServerResponse & response, Poco::Net::HTTPServerSession & session);
+
+ /// FIXME: it's a little bit inconvenient interface. The rationale is that all other ReadBuffer's wrap each other
+ /// via unique_ptr - but we can't inherit HTTPServerRequest from ReadBuffer and pass it around,
+ /// since we also need it in other places.
+
+ /// Returns the input stream for reading the request body.
+ ReadBuffer & getStream()
+ {
+ poco_check_ptr(stream);
+ return *stream;
+ }
+
+ /// Probes the socket with a non-blocking peek; returns false if the peer looks disconnected.
+ bool checkPeerConnected() const;
+
+ /// True when the session socket reported itself as secure at construction time.
+ bool isSecure() const { return secure; }
+
+ /// Returns the client's address.
+ const Poco::Net::SocketAddress & clientAddress() const { return client_address; }
+
+ /// Returns the server's address.
+ const Poco::Net::SocketAddress & serverAddress() const { return server_address; }
+
+#if USE_SSL
+ /// Peer certificate accessors; peerCertificate() throws when no certificate is available.
+ bool havePeerCertificate() const;
+ Poco::Net::X509Certificate peerCertificate() const;
+#endif
+
+private:
+ /// Limits for basic sanity checks when reading a header
+ enum Limits
+ {
+ MAX_METHOD_LENGTH = 32,
+ MAX_VERSION_LENGTH = 8,
+ };
+
+ /// Parsing limits, taken from the HTTP context in the constructor.
+ const size_t max_uri_size;
+ const size_t max_fields_number;
+ const size_t max_field_name_size;
+ const size_t max_field_value_size;
+
+ /// Body stream; assigned in the constructor.
+ std::unique_ptr<ReadBuffer> stream;
+ /// Non-owning pointer to the session socket's implementation object.
+ Poco::Net::SocketImpl * socket;
+ Poco::Net::SocketAddress client_address;
+ Poco::Net::SocketAddress server_address;
+
+ bool secure;
+
+ /// Parses the request line and headers from `in`.
+ void readRequest(ReadBuffer & in);
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.cpp
new file mode 100644
index 0000000000..25e7604a51
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.cpp
@@ -0,0 +1,121 @@
+#include <Server/HTTP/HTTPServerResponse.h>
+#include <Server/HTTP/HTTPServerRequest.h>
+#include <Poco/CountingStream.h>
+#include <Poco/DateTimeFormat.h>
+#include <Poco/DateTimeFormatter.h>
+#include <Poco/FileStream.h>
+#include <Poco/Net/HTTPChunkedStream.h>
+#include <Poco/Net/HTTPFixedLengthStream.h>
+#include <Poco/Net/HTTPHeaderStream.h>
+#include <Poco/Net/HTTPStream.h>
+#include <Poco/StreamCopier.h>
+
+
+namespace DB
+{
+
+/// Binds the response to the session it will be written to. Nothing is sent yet.
+HTTPServerResponse::HTTPServerResponse(Poco::Net::HTTPServerSession & session_)
+    : session(session_)
+{
+}
+
+/// Writes an interim "100 Continue" status line, using this response's HTTP version,
+/// directly to the session.
+void HTTPServerResponse::sendContinue()
+{
+    Poco::Net::HTTPHeaderOutputStream header_out(session);
+    header_out << getVersion() << " 100 Continue\r\n\r\n";
+}
+
+/// Writes the response header to the session and returns the stream for the response body.
+///
+/// The kind of stream depends on the response: a fixed-length stream sized to the header
+/// only for HEAD requests and 1xx/204/304 statuses, a chunked stream for chunked transfer
+/// encoding, a fixed-length stream sized header + Content-Length when a length is set, and
+/// otherwise a plain stream with keep-alive disabled. Must be called at most once
+/// (asserted through `stream`).
+std::shared_ptr<std::ostream> HTTPServerResponse::send()
+{
+    poco_assert(!stream);
+
+    const bool head_request = request && request->getMethod() == HTTPRequest::HTTP_HEAD;
+    const bool header_only = head_request
+        || getStatus() < 200
+        || getStatus() == HTTPResponse::HTTP_NO_CONTENT
+        || getStatus() == HTTPResponse::HTTP_NOT_MODIFIED;
+
+    if (header_only)
+    {
+        /// Measure the header first so that the fixed-length stream accepts exactly the header.
+        Poco::CountingOutputStream header_counter;
+        write(header_counter);
+        const auto header_size = header_counter.chars();
+        stream = std::make_shared<Poco::Net::HTTPFixedLengthOutputStream>(session, header_size);
+        write(*stream);
+    }
+    else if (getChunkedTransferEncoding())
+    {
+        Poco::Net::HTTPHeaderOutputStream header_out(session);
+        write(header_out);
+        stream = std::make_shared<Poco::Net::HTTPChunkedOutputStream>(session);
+    }
+    else if (hasContentLength())
+    {
+        /// The fixed-length stream carries both the header and the body.
+        Poco::CountingOutputStream header_counter;
+        write(header_counter);
+        const auto header_size = header_counter.chars();
+        stream = std::make_shared<Poco::Net::HTTPFixedLengthOutputStream>(session, getContentLength64() + header_size);
+        write(*stream);
+    }
+    else
+    {
+        /// Keep-alive is switched off before the header is written in this case.
+        stream = std::make_shared<Poco::Net::HTTPOutputStream>(session);
+        setKeepAlive(false);
+        write(*stream);
+    }
+
+    return stream;
+}
+
+/// Starts sending the response header through beginWrite() and returns the pair
+/// (header stream, body stream).
+///
+/// Throws Poco::Exception for HEAD requests and 1xx/204/304 statuses, and for non-chunked
+/// responses that have a Content-Length. Must not be called after another send method
+/// (asserted through `stream` and `header_stream`).
+std::pair<std::shared_ptr<std::ostream>, std::shared_ptr<std::ostream>> HTTPServerResponse::beginSend()
+{
+    poco_assert(!stream);
+    poco_assert(!header_stream);
+
+    /// NOTE: Code is not exception safe.
+
+    const bool head_request = request && request->getMethod() == HTTPRequest::HTTP_HEAD;
+    if (head_request
+        || getStatus() < 200
+        || getStatus() == HTTPResponse::HTTP_NO_CONTENT
+        || getStatus() == HTTPResponse::HTTP_NOT_MODIFIED)
+        throw Poco::Exception("HTTPServerResponse::beginSend is invalid for HEAD request");
+
+    const bool chunked = getChunkedTransferEncoding();
+    if (!chunked && hasContentLength())
+        throw Poco::Exception("HTTPServerResponse::beginSend is invalid for response with Content-Length header");
+
+    if (chunked)
+    {
+        /// Separate streams: headers go straight to the session, the body is chunk-encoded.
+        header_stream = std::make_shared<Poco::Net::HTTPHeaderOutputStream>(session);
+        beginWrite(*header_stream);
+        stream = std::make_shared<Poco::Net::HTTPChunkedOutputStream>(session);
+    }
+    else
+    {
+        /// A single stream serves both headers and body; keep-alive is switched off.
+        stream = std::make_shared<Poco::Net::HTTPOutputStream>(session);
+        header_stream = stream;
+        setKeepAlive(false);
+        beginWrite(*stream);
+    }
+
+    return std::make_pair(header_stream, stream);
+}
+
+/// Sends the response header followed by `length` bytes starting at `buffer`.
+///
+/// Content-Length is set to `length` and chunked transfer encoding is disabled.
+/// The body is omitted only when the attached request is a HEAD request; when no request
+/// is attached the body is sent, which is the same rule send() applies.
+/// Must not be called after another send method (asserted through `stream`).
+void HTTPServerResponse::sendBuffer(const void * buffer, std::size_t length)
+{
+    poco_assert(!stream);
+
+    /// Use the 64-bit setter: narrowing `length` to int would corrupt Content-Length for large buffers.
+    setContentLength64(static_cast<Poco::Int64>(length));
+    setChunkedTransferEncoding(false);
+
+    stream = std::make_shared<Poco::Net::HTTPHeaderOutputStream>(session);
+    write(*stream);
+
+    /// Content-Length has already been advertised above, so the body may be skipped only for HEAD.
+    const bool head_request = request && request->getMethod() == HTTPRequest::HTTP_HEAD;
+    if (!head_request)
+        stream->write(static_cast<const char *>(buffer), static_cast<std::streamsize>(length));
+}
+
+/// Turns the response into a 401 (Unauthorized) challenge: sets the status and adds a
+/// "WWW-Authenticate" header with the Basic scheme and `realm` placed inside double quotes
+/// as given. Must be called before anything is sent (asserted through `stream`).
+void HTTPServerResponse::requireAuthentication(const std::string & realm)
+{
+    poco_assert(!stream);
+
+    setStatusAndReason(HTTPResponse::HTTP_UNAUTHORIZED);
+
+    const std::string challenge = "Basic realm=\"" + realm + "\"";
+    set("WWW-Authenticate", challenge);
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.h b/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.h
new file mode 100644
index 0000000000..236a56e232
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.h
@@ -0,0 +1,70 @@
+#pragma once
+
+#include <Server/HTTP/HTTPResponse.h>
+
+#include <Poco/Net/HTTPServerSession.h>
+#include <Poco/Net/HTTPResponse.h>
+
+#include <memory>
+
+
+namespace DB
+{
+
+class HTTPServerRequest;
+
+/// An HTTP response that is written to a server session. The response counts as sent once
+/// one of the send methods has created the output stream.
+class HTTPServerResponse : public HTTPResponse
+{
+public:
+ /// Binds the response to `session`; nothing is written until a send method is called.
+ explicit HTTPServerResponse(Poco::Net::HTTPServerSession & session);
+
+ void sendContinue(); /// Sends a 100 Continue response to the client.
+
+ /// Sends the response header to the client and
+ /// returns an output stream for sending the
+ /// response body.
+ ///
+ /// Must not be called after beginSend(), sendFile(), sendBuffer()
+ /// or redirect() has been called.
+ std::shared_ptr<std::ostream> send(); /// TODO: use some WriteBuffer implementation here.
+
+ /// Sends the response headers to the client
+ /// but do not finish headers with \r\n,
+ /// allowing to continue sending additional header fields.
+ ///
+ /// Returns the pair (header stream, body stream).
+ ///
+ /// Must not be called after send(), sendFile(), sendBuffer()
+ /// or redirect() has been called.
+ std::pair<std::shared_ptr<std::ostream>, std::shared_ptr<std::ostream>> beginSend(); /// TODO: use some WriteBuffer implementation here.
+
+ /// Sends the response header to the client, followed
+ /// by the contents of the given buffer.
+ ///
+ /// The Content-Length header of the response is set
+ /// to length and chunked transfer encoding is disabled.
+ ///
+ /// If both the HTTP message header and body (from the
+ /// given buffer) fit into one single network packet, the
+ /// complete response can be sent in one network packet.
+ ///
+ /// Must not be called after send(), sendFile()
+ /// or redirect() has been called.
+ void sendBuffer(const void * pBuffer, std::size_t length); /// FIXME: do we need this one?
+
+ void requireAuthentication(const std::string & realm);
+ /// Sets the status code to 401 (Unauthorized)
+ /// and sets the "WWW-Authenticate" header field
+ /// according to the given realm.
+
+ /// Returns true if the response (header) has been sent.
+ bool sent() const { return !!stream; }
+
+ /// Associates the request this response answers; the send methods use it to detect HEAD requests.
+ void attachRequest(HTTPServerRequest * request_) { request = request_; }
+
+private:
+ Poco::Net::HTTPServerSession & session;
+ /// Non-owning; stays nullptr when no request has been attached.
+ HTTPServerRequest * request = nullptr;
+ /// Body stream; non-null once one of the send methods has run.
+ std::shared_ptr<std::ostream> stream;
+ /// Header stream; only assigned by beginSend().
+ std::shared_ptr<std::ostream> header_stream;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/README.md b/contrib/clickhouse/src/Server/HTTP/README.md
new file mode 100644
index 0000000000..7173096278
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/README.md
@@ -0,0 +1,3 @@
+# Notice
+
+The source code located in this folder is based on some files from the POCO project, specifically from `contrib/poco/Net/src`.
diff --git a/contrib/clickhouse/src/Server/HTTP/ReadHeaders.cpp b/contrib/clickhouse/src/Server/HTTP/ReadHeaders.cpp
new file mode 100644
index 0000000000..b705750106
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/ReadHeaders.cpp
@@ -0,0 +1,85 @@
+#include <Server/HTTP/ReadHeaders.h>
+
+#include <IO/ReadBuffer.h>
+#include <IO/ReadHelpers.h>
+
+#include <Poco/Net/NetException.h>
+
+namespace DB
+{
+
+/// Reads "name: value" header fields from `in` and adds them to `headers`.
+///
+/// Parsing stops, without consuming it, at a '\r' found where a field name would start
+/// (the beginning of the empty-line delimiter). Folded (multi-line) values are not supported.
+///
+/// Throws Poco::Net::MessageException when the input ends inside a field, a field name is
+/// empty or not followed by ':', a value line ends with a bare '\n', a name or value exceeds
+/// `max_name_length` / `max_value_length`, or more than `max_fields_number` fields were read.
+void readHeaders(
+ Poco::Net::MessageHeader & headers, ReadBuffer & in, size_t max_fields_number, size_t max_name_length, size_t max_value_length)
+{
+ char ch = 0; // silence uninitialized warning from gcc-*
+ std::string name;
+ std::string value;
+
+ name.reserve(32);
+ value.reserve(64);
+
+ size_t fields = 0;
+
+ while (true)
+ {
+ /// `fields` counts the fields added so far, so this fires once the limit has been exceeded.
+ if (fields > max_fields_number)
+ throw Poco::Net::MessageException("Too many header fields");
+
+ name.clear();
+ value.clear();
+
+ /// Field name
+ /// peek() leaves the terminating character in the buffer; it is inspected through `ch` below.
+ while (in.peek(ch) && ch != ':' && !Poco::Ascii::isSpace(ch) && name.size() <= max_name_length)
+ {
+ name += ch;
+ in.ignore();
+ }
+
+ if (in.eof())
+ throw Poco::Net::MessageException("Field is invalid");
+
+ if (name.empty())
+ {
+ if (ch == '\r')
+ /// Start of the empty-line delimiter
+ break;
+ if (ch == ':')
+ throw Poco::Net::MessageException("Field name is empty");
+ /// NOTE(review): an empty name followed by whitespace other than '\r' is not rejected here;
+ /// parsing continues and a field with an empty name is added - confirm this is intended.
+ }
+ else
+ {
+ if (name.size() > max_name_length)
+ throw Poco::Net::MessageException("Field name is too long");
+ if (ch != ':')
+ throw Poco::Net::MessageException(fmt::format("Field name is invalid or no colon found: \"{}\"", name));
+ }
+
+ /// Skip the character the name loop stopped at (the ':' for a well-formed field).
+ in.ignore();
+
+ skipWhitespaceIfAny(in, true);
+
+ if (in.eof())
+ throw Poco::Net::MessageException("Field is invalid");
+
+ /// Field value - folded values not supported.
+ while (in.read(ch) && ch != '\r' && ch != '\n' && value.size() <= max_value_length)
+ value += ch;
+
+ if (in.eof())
+ throw Poco::Net::MessageException("Field is invalid");
+
+ /// A value line must end with '\r'; a bare '\n' is rejected.
+ if (ch == '\n')
+ throw Poco::Net::MessageException("No CRLF found");
+
+ if (value.size() > max_value_length)
+ throw Poco::Net::MessageException("Field value is too long");
+
+ /// Skip the rest of the line.
+ skipToNextLineOrEOF(in);
+
+ Poco::trimRightInPlace(value);
+ headers.add(name, headers.decodeWord(value));
+ ++fields;
+ }
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/ReadHeaders.h b/contrib/clickhouse/src/Server/HTTP/ReadHeaders.h
new file mode 100644
index 0000000000..1b0e627f77
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/ReadHeaders.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <Poco/Net/MessageHeader.h>
+
+namespace DB
+{
+
+class ReadBuffer;
+
+/// Reads HTTP header fields from `in` and adds them to `headers`.
+/// `max_fields_number`, `max_name_length` and `max_value_length` bound the number of fields
+/// and the length of each field name and value; violations and malformed fields are reported
+/// with Poco::Net::MessageException (see the definition in ReadHeaders.cpp).
+void readHeaders(
+ Poco::Net::MessageHeader & headers, ReadBuffer & in, size_t max_fields_number, size_t max_name_length, size_t max_value_length);
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.cpp b/contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.cpp
new file mode 100644
index 0000000000..046feee017
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.cpp
@@ -0,0 +1,204 @@
+#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
+
+#include <IO/HTTPCommon.h>
+#include <IO/Progress.h>
+#include <IO/WriteBufferFromString.h>
+#include <IO/WriteHelpers.h>
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+}
+
+
+/// Starts sending the response headers; calls after the first one do nothing.
+/// Adds the CORS header when requested, applies the default response headers and, unless
+/// the request method is HEAD, obtains the header and body streams from response.beginSend().
+void WriteBufferFromHTTPServerResponse::startSendHeaders()
+{
+    if (headers_started_sending)
+        return;
+
+    headers_started_sending = true;
+
+    if (add_cors_header)
+        response.set("Access-Control-Allow-Origin", "*");
+
+    setResponseDefaultHeaders(response, keep_alive_timeout);
+
+    if (is_http_method_head)
+        return;
+
+    std::tie(response_header_ostr, response_body_ostr) = response.beginSend();
+}
+
+/// Writes one header line made of `header_name` followed by the accumulated progress
+/// serialized as JSON, then flushes. Does nothing once the headers are finished; the line
+/// is written only when a header stream exists.
+void WriteBufferFromHTTPServerResponse::writeHeaderProgressImpl(const char * header_name)
+{
+    if (headers_finished_sending)
+        return;
+
+    WriteBufferFromOwnString json_writer;
+    accumulated_progress.writeJSON(json_writer);
+
+    if (!response_header_ostr)
+        return;
+
+    *response_header_ostr << header_name << json_writer.str() << "\r\n" << std::flush;
+}
+
+/// Emits the accumulated progress as an "X-ClickHouse-Summary" header line.
+void WriteBufferFromHTTPServerResponse::writeHeaderSummary()
+{
+    writeHeaderProgressImpl("X-ClickHouse-Summary: ");
+}
+
+/// Emits the accumulated progress as an "X-ClickHouse-Progress" header line.
+void WriteBufferFromHTTPServerResponse::writeHeaderProgress()
+{
+    writeHeaderProgressImpl("X-ClickHouse-Progress: ");
+}
+
+/// Emits an "X-ClickHouse-Exception-Code" header line when an exception code is set,
+/// the headers are not finished yet and a header stream exists.
+void WriteBufferFromHTTPServerResponse::writeExceptionCode()
+{
+    const bool can_write = !headers_finished_sending && exception_code && response_header_ostr;
+    if (can_write)
+        *response_header_ostr << "X-ClickHouse-Exception-Code: " << exception_code << "\r\n" << std::flush;
+}
+
+/// Completes the header section; calls after the first one do nothing.
+/// The summary and exception-code headers are written first. For non-HEAD requests the
+/// end-of-headers delimiter is then written to the header stream (if there is one); for
+/// HEAD requests the response is sent through response.send() unless a body stream
+/// already exists.
+void WriteBufferFromHTTPServerResponse::finishSendHeaders()
+{
+    if (headers_finished_sending)
+        return;
+
+    writeHeaderSummary();
+    writeExceptionCode();
+    headers_finished_sending = true;
+
+    if (is_http_method_head)
+    {
+        if (!response_body_ostr)
+            response_body_ostr = response.send();
+        return;
+    }
+
+    /// Send end of headers delimiter.
+    if (response_header_ostr)
+        *response_header_ostr << "\r\n" << std::flush;
+}
+
+
+/// Flushes the buffered data. On the first call the headers are sent and, unless the request
+/// method is HEAD, the output buffer `out` (optionally compressing) is created on top of the
+/// response body stream. On every call, if `out` exists, the current buffer contents are
+/// handed to it.
+void WriteBufferFromHTTPServerResponse::nextImpl()
+{
+ if (!initialized)
+ {
+ std::lock_guard lock(mutex);
+
+ /// Initialize as early as possible since if the code throws,
+ /// next() should not be called anymore.
+ initialized = true;
+
+ startSendHeaders();
+
+ if (!out && !is_http_method_head)
+ {
+ /// The headers are not finished yet at this point, so Content-Encoding can still be added.
+ if (compress)
+ {
+ auto content_encoding_name = toContentEncodingName(compression_method);
+
+ *response_header_ostr << "Content-Encoding: " << content_encoding_name << "\r\n";
+ }
+
+ /// We reuse our buffer in "out" to avoid extra allocations and copies.
+
+ if (compress)
+ out = wrapWriteBufferWithCompressionMethod(
+ std::make_unique<WriteBufferFromOStream>(*response_body_ostr),
+ compress ? compression_method : CompressionMethod::None,
+ compression_level,
+ working_buffer.size(),
+ working_buffer.begin());
+ else
+ out = std::make_unique<WriteBufferFromOStream>(
+ *response_body_ostr,
+ working_buffer.size(),
+ working_buffer.begin());
+ }
+
+ finishSendHeaders();
+ }
+
+ if (out)
+ {
+ /// Point `out` at this buffer's memory and current position, then let it flush.
+ out->buffer() = buffer();
+ out->position() = position();
+ out->next();
+ }
+}
+
+
+/// Allocates a buffer of DBMS_DEFAULT_BUFFER_SIZE bytes and stores the settings.
+/// Nothing is sent to `response_` by the constructor.
+WriteBufferFromHTTPServerResponse::WriteBufferFromHTTPServerResponse(
+    HTTPServerResponse & response_,
+    bool is_http_method_head_,
+    size_t keep_alive_timeout_,
+    bool compress_,
+    CompressionMethod compression_method_)
+    : BufferWithOwnMemory<WriteBuffer>(DBMS_DEFAULT_BUFFER_SIZE)
+    , response(response_)
+    , is_http_method_head(is_http_method_head_)
+    , keep_alive_timeout(keep_alive_timeout_)
+    , compress(compress_)
+    , compression_method(compression_method_)
+{
+}
+
+
+/// Accumulates `progress` and, when progress sending is enabled and at least
+/// send_progress_interval_ms has elapsed on `progress_watch`, writes a progress header.
+/// Ignored once the headers are finished. Runs under `mutex`.
+void WriteBufferFromHTTPServerResponse::onProgress(const Progress & progress)
+{
+    std::lock_guard lock(mutex);
+
+    /// Cannot add new headers if body was started to send.
+    if (headers_finished_sending)
+        return;
+
+    accumulated_progress.incrementPiecewiseAtomically(progress);
+
+    const bool should_send = send_progress && progress_watch.elapsed() >= send_progress_interval_ms * 1000000;
+    if (!should_send)
+        return;
+
+    progress_watch.restart();
+
+    /// Send all common headers before our special progress headers.
+    startSendHeaders();
+    writeHeaderProgress();
+}
+
+/// Finalizes the buffer on destruction.
+WriteBufferFromHTTPServerResponse::~WriteBufferFromHTTPServerResponse()
+{
+    finalize();
+}
+
+/// Flushes the remaining data, finalizes and releases the underlying output buffer and
+/// makes sure the headers have been sent. If flushing or finalizing throws, `out` is
+/// released and the exception is rethrown.
+void WriteBufferFromHTTPServerResponse::finalizeImpl()
+{
+ try
+ {
+ next();
+ if (out)
+ out->finalize();
+ out.reset();
+ /// Catch write-after-finalize bugs.
+ set(nullptr, 0);
+ }
+ catch (...)
+ {
+ /// Avoid calling WriteBufferFromOStream::next() from dtor
+ /// (via WriteBufferFromHTTPServerResponse::next())
+ out.reset();
+ throw;
+ }
+
+ if (!offset())
+ {
+ /// If no remaining data, just send headers.
+ std::lock_guard lock(mutex);
+ startSendHeaders();
+ finishSendHeaders();
+ }
+}
+
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.h b/contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.h
new file mode 100644
index 0000000000..94202e1e0e
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.h
@@ -0,0 +1,134 @@
+#pragma once
+
+ #include <IO/BufferWithOwnMemory.h>
+ #include <IO/CompressionMethod.h>
+ #include <IO/HTTPCommon.h>
+ #include <IO/Progress.h>
+ #include <IO/WriteBuffer.h>
+ #include <IO/WriteBufferFromOStream.h>
+ #include <Server/HTTP/HTTPServerResponse.h>
+ #include <Common/Exception.h>
+ #include <Common/NetException.h>
+ #include <Common/Stopwatch.h>
+
+#include <mutex>
+#include <optional>
+
+
+namespace DB
+{
+
+ /// The difference from WriteBufferFromOStream is that this buffer gets the underlying std::ostream
+ /// (using response.send()) only after data is flushed for the first time. This is needed in HTTP
+ /// servers to change some HTTP headers (e.g. response code) before any data is sent to the client
+ /// (headers can't be changed after response.send() is called).
+ ///
+ /// In short, it allows delaying the call to response.send().
+ ///
+ /// Additionally, it supports HTTP response compression (in this case the corresponding
+ /// Content-Encoding header will be set).
+ ///
+ /// This class also writes and flushes special X-ClickHouse-Progress HTTP headers
+ /// if no data was sent at the time of progress notification.
+ /// This allows implementing a progress bar in HTTP clients.
+ class WriteBufferFromHTTPServerResponse final : public BufferWithOwnMemory<WriteBuffer>
+ {
+ public:
+ WriteBufferFromHTTPServerResponse(
+ HTTPServerResponse & response_,
+ bool is_http_method_head_,
+ size_t keep_alive_timeout_,
+ bool compress_ = false, /// If true - set Content-Encoding header and compress the result.
+ CompressionMethod compression_method_ = CompressionMethod::None);
+
+ ~WriteBufferFromHTTPServerResponse() override;
+
+ /// Writes progress in repeating HTTP headers.
+ void onProgress(const Progress & progress);
+
+ /// Turn compression on or off.
+ /// The setting takes effect only if HTTP headers haven't been sent yet.
+ void setCompression(bool enable_compression)
+ {
+ compress = enable_compression;
+ }
+
+ /// Set compression level if the compression is turned on.
+ /// The setting takes effect only if HTTP headers haven't been sent yet.
+ void setCompressionLevel(int level)
+ {
+ compression_level = level;
+ }
+
+ /// Turn CORS on or off.
+ /// The setting takes effect only if HTTP headers haven't been sent yet.
+ void addHeaderCORS(bool enable_cors)
+ {
+ add_cors_header = enable_cors;
+ }
+
+ /// Turn sending of progress headers (see onProgress) on or off.
+ void setSendProgress(bool send_progress_) { send_progress = send_progress_; }
+
+ /// Don't send HTTP headers with progress more frequently than once per this interval (in milliseconds).
+ void setSendProgressInterval(size_t send_progress_interval_ms_)
+ {
+ send_progress_interval_ms = send_progress_interval_ms_;
+ }
+
+ void setExceptionCode(int exception_code_) { exception_code = exception_code_; }
+
+ private:
+ /// Send at least HTTP headers if no data has been sent yet.
+ /// Use after the data has possibly been sent and no error happened (and thus you do not plan
+ /// to change the response HTTP code).
+ /// This method is idempotent.
+ void finalizeImpl() override;
+
+ /// Must be called under locked mutex.
+ /// This method sends headers, if this was not done already,
+ /// but does not finish them with \r\n, allowing to send more headers subsequently.
+ void startSendHeaders();
+
+ /// Used to write the header X-ClickHouse-Progress / X-ClickHouse-Summary
+ void writeHeaderProgressImpl(const char * header_name);
+ /// Used to write the header X-ClickHouse-Progress
+ void writeHeaderProgress();
+ /// Used to write the header X-ClickHouse-Summary
+ void writeHeaderSummary();
+ /// Used to write the header X-ClickHouse-Exception-Code even when progress has been sent
+ void writeExceptionCode();
+
+ /// This method finishes headers with \r\n, allowing to start sending the body.
+ void finishSendHeaders();
+
+ void nextImpl() override;
+
+ HTTPServerResponse & response;
+
+ bool is_http_method_head;
+ bool add_cors_header = false;
+ size_t keep_alive_timeout = 0;
+ bool compress = false;
+ CompressionMethod compression_method;
+ int compression_level = 1;
+
+ std::shared_ptr<std::ostream> response_body_ostr;
+ std::shared_ptr<std::ostream> response_header_ostr;
+
+ std::unique_ptr<WriteBuffer> out;
+ bool initialized = false;
+
+ bool headers_started_sending = false;
+ bool headers_finished_sending = false; /// If true, no more headers can be added.
+
+ Progress accumulated_progress;
+ bool send_progress = false;
+ size_t send_progress_interval_ms = 100;
+ Stopwatch progress_watch;
+
+ int exception_code = 0;
+
+ std::mutex mutex; /// The progress callback could be called from different threads.
+ };
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTPHandler.cpp b/contrib/clickhouse/src/Server/HTTPHandler.cpp
new file mode 100644
index 0000000000..41ed78bc69
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTPHandler.cpp
@@ -0,0 +1,1320 @@
+#include <Server/HTTPHandler.h>
+
+#include <Access/Authentication.h>
+#include <Access/Credentials.h>
+#include <Access/ExternalAuthenticators.h>
+#include <Compression/CompressedReadBuffer.h>
+#include <Compression/CompressedWriteBuffer.h>
+#include <Core/ExternalTable.h>
+#include <Disks/StoragePolicy.h>
+#include <IO/CascadeWriteBuffer.h>
+#include <IO/ConcatReadBuffer.h>
+#include <IO/MemoryReadWriteBuffer.h>
+#include <IO/ReadBufferFromString.h>
+#include <IO/WriteHelpers.h>
+#include <IO/copyData.h>
+#include <Interpreters/Context.h>
+#include <Interpreters/TemporaryDataOnDisk.h>
+#include <Parsers/QueryParameterVisitor.h>
+#include <Interpreters/executeQuery.h>
+#include <Interpreters/Session.h>
+#include <Server/HTTPHandlerFactory.h>
+#include <Server/HTTPHandlerRequestFilter.h>
+#include <Server/IServer.h>
+#include <Common/logger_useful.h>
+#include <Common/SettingsChanges.h>
+#include <Common/StringUtils/StringUtils.h>
+#include <Common/scope_guard_safe.h>
+#include <Common/setThreadName.h>
+#include <Common/typeid_cast.h>
+#include <Parsers/ASTSetQuery.h>
+
+#include <base/getFQDNOrHostName.h>
+#include <base/scope_guard.h>
+#include <Server/HTTP/HTTPResponse.h>
+
+#include "clickhouse_config.h"
+
+#include <Poco/Base64Decoder.h>
+#include <Poco/Base64Encoder.h>
+#include <Poco/Net/HTTPBasicCredentials.h>
+#include <Poco/Net/HTTPStream.h>
+#include <Poco/MemoryStream.h>
+#include <Poco/StreamCopier.h>
+#include <Poco/String.h>
+#include <Poco/Net/SocketAddress.h>
+
+#include <re2/re2.h>
+
+#include <chrono>
+#include <sstream>
+
+#if USE_SSL
+#include <Poco/Net/X509Certificate.h>
+#endif
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int LOGICAL_ERROR;
+ extern const int CANNOT_PARSE_TEXT;
+ extern const int CANNOT_PARSE_ESCAPE_SEQUENCE;
+ extern const int CANNOT_PARSE_QUOTED_STRING;
+ extern const int CANNOT_PARSE_DATE;
+ extern const int CANNOT_PARSE_DATETIME;
+ extern const int CANNOT_PARSE_NUMBER;
+ extern const int CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING;
+ extern const int CANNOT_PARSE_IPV4;
+ extern const int CANNOT_PARSE_IPV6;
+ extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED;
+ extern const int CANNOT_OPEN_FILE;
+ extern const int CANNOT_COMPILE_REGEXP;
+
+ extern const int UNKNOWN_ELEMENT_IN_AST;
+ extern const int UNKNOWN_TYPE_OF_AST_NODE;
+ extern const int TOO_DEEP_AST;
+ extern const int TOO_BIG_AST;
+ extern const int UNEXPECTED_AST_STRUCTURE;
+
+ extern const int SYNTAX_ERROR;
+
+ extern const int INCORRECT_DATA;
+ extern const int TYPE_MISMATCH;
+
+ extern const int UNKNOWN_TABLE;
+ extern const int UNKNOWN_FUNCTION;
+ extern const int UNKNOWN_IDENTIFIER;
+ extern const int UNKNOWN_TYPE;
+ extern const int UNKNOWN_STORAGE;
+ extern const int UNKNOWN_DATABASE;
+ extern const int UNKNOWN_SETTING;
+ extern const int UNKNOWN_DIRECTION_OF_SORTING;
+ extern const int UNKNOWN_AGGREGATE_FUNCTION;
+ extern const int UNKNOWN_FORMAT;
+ extern const int UNKNOWN_DATABASE_ENGINE;
+ extern const int UNKNOWN_TYPE_OF_QUERY;
+ extern const int NO_ELEMENTS_IN_CONFIG;
+
+ extern const int QUERY_IS_TOO_LARGE;
+
+ extern const int NOT_IMPLEMENTED;
+ extern const int SOCKET_TIMEOUT;
+
+ extern const int UNKNOWN_USER;
+ extern const int WRONG_PASSWORD;
+ extern const int REQUIRED_PASSWORD;
+ extern const int AUTHENTICATION_FAILED;
+
+ extern const int INVALID_SESSION_TIMEOUT;
+ extern const int HTTP_LENGTH_REQUIRED;
+ extern const int SUPPORT_IS_DISABLED;
+
+ extern const int TIMEOUT_EXCEEDED;
+}
+
+namespace
+{
+ /// Adds every `header` entry of the `http_options_response` config section to the response.
+ /// Returns false if the section is absent, and true otherwise (even if no header was added).
+ bool tryAddHttpOptionHeadersFromConfig(HTTPServerResponse & response, const Poco::Util::LayeredConfiguration & config)
+ {
+     if (!config.has("http_options_response"))
+         return false;
+
+     Strings section_keys;
+     config.keys("http_options_response", section_keys);
+
+     for (const std::string & section_key : section_keys)
+     {
+         const bool is_header_entry = section_key == "header" || section_key.starts_with("header[");
+         if (!is_header_entry)
+             continue;
+
+         const std::string entry_prefix = "http_options_response." + section_key;
+         const std::string header_name = config.getString(entry_prefix + ".name", "");
+
+         /// An entry with an empty header name is not processed; a message about it goes to the logs.
+         if (header_name.empty())
+         {
+             LOG_WARNING(&Poco::Logger::get("processOptionsRequest"), "Empty header was found in config. It will not be processed.");
+             continue;
+         }
+
+         response.add(header_name, config.getString(entry_prefix + ".value", ""));
+     }
+
+     return true;
+ }
+
+ /// Handles an OPTIONS request (useful for CORS): if the config provides headers for such
+ /// responses, they are added and a '204 No Content' response without keep-alive is sent.
+ void processOptionsRequest(HTTPServerResponse & response, const Poco::Util::LayeredConfiguration & config)
+ {
+     /// Nothing is sent when the config has no headers section for OPTIONS responses.
+     if (!tryAddHttpOptionHeadersFromConfig(response, config))
+         return;
+
+     response.setKeepAlive(false);
+     response.setStatusAndReason(HTTPResponse::HTTP_NO_CONTENT);
+     response.send();
+ }
+}
+
+ /// Decodes a Base64 string by reading it through Poco::Base64Decoder.
+ static String base64Decode(const String & encoded)
+ {
+     Poco::MemoryInputStream input(encoded.data(), encoded.size());
+     Poco::Base64Decoder decoder(input);
+
+     String result;
+     Poco::StreamCopier::copyToString(decoder, result);
+     return result;
+ }
+
+ /// Encodes a string as Base64 by writing it through Poco::Base64Encoder into a string stream.
+ static String base64Encode(const String & decoded)
+ {
+     std::ostringstream sink; // STYLE_CHECK_ALLOW_STD_STRING_STREAM
+     sink.exceptions(std::ios::failbit);
+
+     Poco::Base64Encoder encoder(sink);
+     /// NOTE(review): a line length of 0 presumably disables line breaks in the output - confirm in Poco docs.
+     encoder.rdbuf()->setLineLength(0);
+
+     encoder << decoded;
+     encoder.close();
+
+     return sink.str();
+ }
+
+ /// Maps an exception code to the HTTP status to be reported to the client.
+ /// Codes that are not listed below map to '500 Internal Server Error'.
+ static Poco::Net::HTTPResponse::HTTPStatus exceptionCodeToHTTPStatus(int exception_code)
+ {
+     using namespace Poco::Net;
+
+     /// True if exception_code is equal to at least one of the listed error codes.
+     const auto code_is_any_of = [exception_code](const auto &... candidates)
+     {
+         return ((exception_code == candidates) || ...);
+     };
+
+     if (code_is_any_of(ErrorCodes::REQUIRED_PASSWORD))
+         return HTTPResponse::HTTP_UNAUTHORIZED;
+
+     if (code_is_any_of(
+             ErrorCodes::UNKNOWN_USER,
+             ErrorCodes::WRONG_PASSWORD,
+             ErrorCodes::AUTHENTICATION_FAILED))
+         return HTTPResponse::HTTP_FORBIDDEN;
+
+     if (code_is_any_of(
+             ErrorCodes::CANNOT_PARSE_TEXT,
+             ErrorCodes::CANNOT_PARSE_ESCAPE_SEQUENCE,
+             ErrorCodes::CANNOT_PARSE_QUOTED_STRING,
+             ErrorCodes::CANNOT_PARSE_DATE,
+             ErrorCodes::CANNOT_PARSE_DATETIME,
+             ErrorCodes::CANNOT_PARSE_NUMBER,
+             ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING,
+             ErrorCodes::CANNOT_PARSE_IPV4,
+             ErrorCodes::CANNOT_PARSE_IPV6,
+             ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED,
+             ErrorCodes::UNKNOWN_ELEMENT_IN_AST,
+             ErrorCodes::UNKNOWN_TYPE_OF_AST_NODE,
+             ErrorCodes::TOO_DEEP_AST,
+             ErrorCodes::TOO_BIG_AST,
+             ErrorCodes::UNEXPECTED_AST_STRUCTURE,
+             ErrorCodes::SYNTAX_ERROR,
+             ErrorCodes::INCORRECT_DATA,
+             ErrorCodes::TYPE_MISMATCH))
+         return HTTPResponse::HTTP_BAD_REQUEST;
+
+     if (code_is_any_of(
+             ErrorCodes::UNKNOWN_TABLE,
+             ErrorCodes::UNKNOWN_FUNCTION,
+             ErrorCodes::UNKNOWN_IDENTIFIER,
+             ErrorCodes::UNKNOWN_TYPE,
+             ErrorCodes::UNKNOWN_STORAGE,
+             ErrorCodes::UNKNOWN_DATABASE,
+             ErrorCodes::UNKNOWN_SETTING,
+             ErrorCodes::UNKNOWN_DIRECTION_OF_SORTING,
+             ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION,
+             ErrorCodes::UNKNOWN_FORMAT,
+             ErrorCodes::UNKNOWN_DATABASE_ENGINE,
+             ErrorCodes::UNKNOWN_TYPE_OF_QUERY))
+         return HTTPResponse::HTTP_NOT_FOUND;
+
+     if (code_is_any_of(ErrorCodes::QUERY_IS_TOO_LARGE))
+         return HTTPResponse::HTTP_REQUESTENTITYTOOLARGE;
+
+     if (code_is_any_of(ErrorCodes::NOT_IMPLEMENTED))
+         return HTTPResponse::HTTP_NOT_IMPLEMENTED;
+
+     if (code_is_any_of(ErrorCodes::SOCKET_TIMEOUT, ErrorCodes::CANNOT_OPEN_FILE))
+         return HTTPResponse::HTTP_SERVICE_UNAVAILABLE;
+
+     if (code_is_any_of(ErrorCodes::HTTP_LENGTH_REQUIRED))
+         return HTTPResponse::HTTP_LENGTH_REQUIRED;
+
+     if (code_is_any_of(ErrorCodes::TIMEOUT_EXCEEDED))
+         return HTTPResponse::HTTP_REQUEST_TIMEOUT;
+
+     return HTTPResponse::HTTP_INTERNAL_SERVER_ERROR;
+ }
+
+
+ /// Returns the session timeout for the request.
+ /// The `default_session_timeout` config value (60 seconds if absent) is used unless the
+ /// `session_timeout` URL parameter is given. The parameter must be a whole number that does
+ /// not exceed `max_session_timeout` (3600 seconds if absent); otherwise an exception with the
+ /// INVALID_SESSION_TIMEOUT code is thrown.
+ static std::chrono::steady_clock::duration parseSessionTimeout(
+     const Poco::Util::AbstractConfiguration & config,
+     const HTMLForm & params)
+ {
+     /// Read as unsigned, like max_session_timeout below: with getInt() a negative configured
+     /// value would silently wrap around to a huge timeout when converted to unsigned.
+     unsigned session_timeout = config.getUInt("default_session_timeout", 60);
+
+     if (params.has("session_timeout"))
+     {
+         unsigned max_session_timeout = config.getUInt("max_session_timeout", 3600);
+         std::string session_timeout_str = params.get("session_timeout");
+
+         /// The whole parameter must be consumed by the integer parser.
+         ReadBufferFromString buf(session_timeout_str);
+         if (!tryReadIntText(session_timeout, buf) || !buf.eof())
+             throw Exception(ErrorCodes::INVALID_SESSION_TIMEOUT, "Invalid session timeout: '{}'", session_timeout_str);
+
+         if (session_timeout > max_session_timeout)
+             throw Exception(ErrorCodes::INVALID_SESSION_TIMEOUT, "Session timeout '{}' is larger than max_session_timeout: {}. "
+                 "Maximum session timeout could be modified in configuration file.",
+                 session_timeout_str, max_session_timeout);
+     }
+
+     return std::chrono::seconds(session_timeout);
+ }
+
+
+ /// Copies the data accumulated in the delayed (cascade) output buffers into
+ /// `used_output.out_maybe_compressed`.
+ /// Throws LOGICAL_ERROR if the delayed output is not a CascadeWriteBuffer or yields no buffers.
+ void HTTPHandler::pushDelayedResults(Output & used_output)
+ {
+     auto * cascade = typeid_cast<CascadeWriteBuffer *>(used_output.out_maybe_delayed_and_compressed.get());
+     if (!cascade)
+         throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected CascadeWriteBuffer");
+
+     std::vector<WriteBufferPtr> result_buffers;
+     cascade->getResultBuffers(result_buffers);
+
+     if (result_buffers.empty())
+         throw Exception(ErrorCodes::LOGICAL_ERROR, "At least one buffer is expected to overwrite result into HTTP response");
+
+     /// Collect a read buffer from every result buffer that can be read back.
+     ConcatReadBuffer::Buffers readable_parts;
+     for (auto & buffer : result_buffers)
+     {
+         if (!buffer)
+             continue;
+
+         auto * readable = dynamic_cast<IReadableWriteBuffer *>(buffer.get());
+         if (!readable)
+             continue;
+
+         ReadBufferPtr reread = readable->tryGetReadBuffer();
+         if (reread)
+             readable_parts.emplace_back(wrapReadBufferPointer(reread));
+     }
+
+     if (readable_parts.empty())
+         return;
+
+     ConcatReadBuffer concatenated(std::move(readable_parts));
+     copyData(concatenated, *used_output.out_maybe_compressed);
+ }
+
+
+ /// Stores the server reference, gets the logger with the given name, initializes the default
+ /// settings from the server context and keeps the optional content type override.
+ /// The display name comes from the `display_name` config key, falling back to the host name.
+ HTTPHandler::HTTPHandler(IServer & server_, const std::string & name, const std::optional<String> & content_type_override_)
+     : server(server_)
+     , log(&Poco::Logger::get(name))
+     , default_settings(server.context()->getSettingsRef())
+     , content_type_override(content_type_override_)
+ {
+     const auto & server_config = server.config();
+     server_display_name = server_config.getString("display_name", getFQDNOrHostName());
+ }
+
+
+ /// The destructor has to be defined in this translation unit so that it works with the
+ /// forward declarations used in the header. Other than that, the default destructor would be OK.
+ HTTPHandler::~HTTPHandler() = default;
+
+
+ /// Authenticates the HTTP request within the current session.
+ ///
+ /// Credentials come from one of three sources: the X-ClickHouse-* headers (optionally with SSL
+ /// certificate authentication), the Authorization HTTP header ('Basic' or 'Negotiate'), or the
+ /// URL parameters. Mixing the sources is rejected with AUTHENTICATION_FAILED.
+ ///
+ /// Returns true if the session has been authenticated; request_credentials is reset then.
+ /// Returns false if a '401 Unauthorized' response carrying a 'WWW-Authenticate' header has been
+ /// sent instead; request_credentials is left set in that case.
+ /// Throws AUTHENTICATION_FAILED on invalid input, and SUPPORT_IS_DISABLED if SSL certificate
+ /// authentication is requested in a build without SSL.
+ bool HTTPHandler::authenticateUser(
+     HTTPServerRequest & request,
+     HTMLForm & params,
+     HTTPServerResponse & response)
+ {
+     using namespace Poco::Net;
+
+     /// The user and password can be passed by headers (similar to X-Auth-*),
+     /// which is used by load balancers to pass authentication information.
+     std::string user = request.get("X-ClickHouse-User", "");
+     std::string password = request.get("X-ClickHouse-Key", "");
+     std::string quota_key = request.get("X-ClickHouse-Quota", "");
+
+     /// The header 'X-ClickHouse-SSL-Certificate-Auth: on' enables checking the common name
+     /// extracted from the SSL certificate used for this connection instead of checking password.
+     bool has_ssl_certificate_auth = (request.get("X-ClickHouse-SSL-Certificate-Auth", "") == "on");
+     bool has_auth_headers = !user.empty() || !password.empty() || !quota_key.empty() || has_ssl_certificate_auth;
+
+     /// User name and password can be passed using HTTP Basic auth or query parameters
+     /// (both methods are insecure).
+     bool has_http_credentials = request.hasCredentials();
+     bool has_credentials_in_query_params = params.has("user") || params.has("password") || params.has("quota_key");
+
+     std::string spnego_challenge;
+     std::string certificate_common_name;
+
+     if (has_auth_headers)
+     {
+         /// It is prohibited to mix different authorization schemes.
+         if (has_http_credentials)
+             throw Exception(ErrorCodes::AUTHENTICATION_FAILED,
+                             "Invalid authentication: it is not allowed "
+                             "to use SSL certificate authentication and Authorization HTTP header simultaneously");
+         if (has_credentials_in_query_params)
+             throw Exception(ErrorCodes::AUTHENTICATION_FAILED,
+                             "Invalid authentication: it is not allowed "
+                             "to use SSL certificate authentication and authentication via parameters simultaneously");
+
+         if (has_ssl_certificate_auth)
+         {
+ #if USE_SSL
+             if (!password.empty())
+                 throw Exception(ErrorCodes::AUTHENTICATION_FAILED,
+                                 "Invalid authentication: it is not allowed "
+                                 "to use SSL certificate authentication and authentication via password simultaneously");
+
+             if (request.havePeerCertificate())
+                 certificate_common_name = request.peerCertificate().commonName();
+
+             if (certificate_common_name.empty())
+                 throw Exception(ErrorCodes::AUTHENTICATION_FAILED,
+                                 "Invalid authentication: SSL certificate authentication requires nonempty certificate's Common Name");
+ #else
+             throw Exception(ErrorCodes::SUPPORT_IS_DISABLED,
+                             "SSL certificate authentication disabled because ClickHouse was built without SSL library");
+ #endif
+         }
+     }
+     else if (has_http_credentials)
+     {
+         /// It is prohibited to mix different authorization schemes.
+         if (has_credentials_in_query_params)
+             throw Exception(ErrorCodes::AUTHENTICATION_FAILED,
+                             "Invalid authentication: it is not allowed "
+                             "to use Authorization HTTP header and authentication via parameters simultaneously");
+
+         std::string scheme;
+         std::string auth_info;
+         request.getCredentials(scheme, auth_info);
+
+         if (Poco::icompare(scheme, "Basic") == 0)
+         {
+             HTTPBasicCredentials credentials(auth_info);
+             user = credentials.getUsername();
+             password = credentials.getPassword();
+         }
+         else if (Poco::icompare(scheme, "Negotiate") == 0)
+         {
+             spnego_challenge = auth_info;
+
+             if (spnego_challenge.empty())
+                 throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: SPNEGO challenge is empty");
+         }
+         else
+         {
+             throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: '{}' HTTP Authorization scheme is not supported", scheme);
+         }
+
+         quota_key = params.get("quota_key", "");
+     }
+     else
+     {
+         /// If the user name is not set we assume it's the 'default' user.
+         user = params.get("user", "default");
+         password = params.get("password", "");
+         quota_key = params.get("quota_key", "");
+     }
+
+     /// Create the credentials object for the chosen scheme, or check that the one left from
+     /// an earlier call (request_credentials is a member) is of the matching type.
+     if (!certificate_common_name.empty())
+     {
+         if (!request_credentials)
+             request_credentials = std::make_unique<SSLCertificateCredentials>(user, certificate_common_name);
+
+         auto * certificate_credentials = dynamic_cast<SSLCertificateCredentials *>(request_credentials.get());
+         if (!certificate_credentials)
+             throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: expected SSL certificate authorization scheme");
+     }
+     else if (!spnego_challenge.empty())
+     {
+         if (!request_credentials)
+             request_credentials = server.context()->makeGSSAcceptorContext();
+
+         auto * gss_acceptor_context = dynamic_cast<GSSAcceptorContext *>(request_credentials.get());
+         if (!gss_acceptor_context)
+             throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: expected 'Negotiate' HTTP Authorization scheme");
+
+ #pragma clang diagnostic push
+ #pragma clang diagnostic ignored "-Wunreachable-code"
+         const auto spnego_response = base64Encode(gss_acceptor_context->processToken(base64Decode(spnego_challenge), log));
+ #pragma clang diagnostic pop
+
+         if (!spnego_response.empty())
+             response.set("WWW-Authenticate", "Negotiate " + spnego_response);
+
+         /// The negotiation is neither failed nor complete: reply with 401 and the response token.
+         if (!gss_acceptor_context->isFailed() && !gss_acceptor_context->isReady())
+         {
+             if (spnego_response.empty())
+                 throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: 'Negotiate' HTTP Authorization failure");
+
+             response.setStatusAndReason(HTTPResponse::HTTP_UNAUTHORIZED);
+             response.send();
+             return false;
+         }
+     }
+     else // I.e., now using user name and password strings ("Basic").
+     {
+         if (!request_credentials)
+             request_credentials = std::make_unique<BasicCredentials>();
+
+         auto * basic_credentials = dynamic_cast<BasicCredentials *>(request_credentials.get());
+         if (!basic_credentials)
+             throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: expected 'Basic' HTTP Authorization scheme");
+
+         basic_credentials->setUserName(user);
+         basic_credentials->setPassword(password);
+     }
+
+     /// Set client info. It will be used for quota accounting parameters in 'setUser' method.
+
+     ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN;
+     if (request.getMethod() == HTTPServerRequest::HTTP_GET)
+         http_method = ClientInfo::HTTPMethod::GET;
+     else if (request.getMethod() == HTTPServerRequest::HTTP_POST)
+         http_method = ClientInfo::HTTPMethod::POST;
+
+     session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", ""));
+     session->setForwardedFor(request.get("X-Forwarded-For", ""));
+     session->setQuotaClientKey(quota_key);
+
+     /// Extract the last entry from comma separated list of forwarded_for addresses.
+     /// Only the last proxy can be trusted (if any).
+     String forwarded_address = session->getClientInfo().getLastForwardedFor();
+     try
+     {
+         if (!forwarded_address.empty() && server.config().getBool("auth_use_forwarded_address", false))
+             session->authenticate(*request_credentials, Poco::Net::SocketAddress(forwarded_address, request.clientAddress().port()));
+         else
+             session->authenticate(*request_credentials, request.clientAddress());
+     }
+     catch (const Authentication::Require<BasicCredentials> & required_credentials)
+     {
+         /// 'Basic' credentials are required: replace the credentials object and send the challenge.
+         request_credentials = std::make_unique<BasicCredentials>();
+
+         if (required_credentials.getRealm().empty())
+             response.set("WWW-Authenticate", "Basic");
+         else
+             response.set("WWW-Authenticate", "Basic realm=\"" + required_credentials.getRealm() + "\"");
+
+         response.setStatusAndReason(HTTPResponse::HTTP_UNAUTHORIZED);
+         response.send();
+         return false;
+     }
+     catch (const Authentication::Require<GSSAcceptorContext> & required_credentials)
+     {
+         /// 'Negotiate' credentials are required: replace the credentials object and send the challenge.
+         request_credentials = server.context()->makeGSSAcceptorContext();
+
+         if (required_credentials.getRealm().empty())
+             response.set("WWW-Authenticate", "Negotiate");
+         else
+             response.set("WWW-Authenticate", "Negotiate realm=\"" + required_credentials.getRealm() + "\"");
+
+         response.setStatusAndReason(HTTPResponse::HTTP_UNAUTHORIZED);
+         response.send();
+         return false;
+     }
+
+     request_credentials.reset();
+     return true;
+ }
+
+
+void HTTPHandler::processQuery(
+ HTTPServerRequest & request,
+ HTMLForm & params,
+ HTTPServerResponse & response,
+ Output & used_output,
+ std::optional<CurrentThread::QueryScope> & query_scope)
+{
+ using namespace Poco::Net;
+
+ LOG_TRACE(log, "Request URI: {}", request.getURI());
+
+ if (!authenticateUser(request, params, response))
+ return; // '401 Unauthorized' response with 'Negotiate' has been sent at this point.
+
+ /// The user could specify session identifier and session timeout.
+ /// It allows to modify settings, create temporary tables and reuse them in subsequent requests.
+ String session_id;
+ std::chrono::steady_clock::duration session_timeout;
+ bool session_is_set = params.has("session_id");
+ const auto & config = server.config();
+
+ if (session_is_set)
+ {
+ session_id = params.get("session_id");
+ session_timeout = parseSessionTimeout(config, params);
+ std::string session_check = params.get("session_check", "");
+ session->makeSessionContext(session_id, session_timeout, session_check == "1");
+ }
+ else
+ {
+ /// We should create it even if we don't have a session_id
+ session->makeSessionContext();
+ }
+
+ auto context = session->makeQueryContext();
+
+ /// This parameter is used to tune the behavior of output formats (such as Native) for compatibility.
+ if (params.has("client_protocol_version"))
+ {
+ UInt64 version_param = parse<UInt64>(params.get("client_protocol_version"));
+ context->setClientProtocolVersion(version_param);
+ }
+
+ /// The client can pass a HTTP header indicating supported compression method (gzip or deflate).
+ String http_response_compression_methods = request.get("Accept-Encoding", "");
+ CompressionMethod http_response_compression_method = CompressionMethod::None;
+
+ if (!http_response_compression_methods.empty())
+ http_response_compression_method = chooseHTTPCompressionMethod(http_response_compression_methods);
+
+ bool client_supports_http_compression = http_response_compression_method != CompressionMethod::None;
+
+ /// Client can pass a 'compress' flag in the query string. In this case the query result is
+ /// compressed using internal algorithm. This is not reflected in HTTP headers.
+ bool internal_compression = params.getParsed<bool>("compress", false);
+
+ /// At least, we should postpone sending of first buffer_size result bytes
+ size_t buffer_size_total = std::max(
+ params.getParsed<size_t>("buffer_size", context->getSettingsRef().http_response_buffer_size),
+ static_cast<size_t>(DBMS_DEFAULT_BUFFER_SIZE));
+
+ /// If it is specified, the whole result will be buffered.
+ /// First ~buffer_size bytes will be buffered in memory, the remaining bytes will be stored in temporary file.
+ bool buffer_until_eof = params.getParsed<bool>("wait_end_of_query", context->getSettingsRef().http_wait_end_of_query);
+
+ size_t buffer_size_http = DBMS_DEFAULT_BUFFER_SIZE;
+ size_t buffer_size_memory = (buffer_size_total > buffer_size_http) ? buffer_size_total : 0;
+
+ unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10);
+
+ used_output.out = std::make_shared<WriteBufferFromHTTPServerResponse>(
+ response,
+ request.getMethod() == HTTPRequest::HTTP_HEAD,
+ keep_alive_timeout,
+ client_supports_http_compression,
+ http_response_compression_method);
+
+ if (internal_compression)
+ used_output.out_maybe_compressed = std::make_shared<CompressedWriteBuffer>(*used_output.out);
+ else
+ used_output.out_maybe_compressed = used_output.out;
+
+ if (buffer_size_memory > 0 || buffer_until_eof)
+ {
+ CascadeWriteBuffer::WriteBufferPtrs cascade_buffer1;
+ CascadeWriteBuffer::WriteBufferConstructors cascade_buffer2;
+
+ if (buffer_size_memory > 0)
+ cascade_buffer1.emplace_back(std::make_shared<MemoryWriteBuffer>(buffer_size_memory));
+
+ if (buffer_until_eof)
+ {
+ auto tmp_data = std::make_shared<TemporaryDataOnDisk>(server.context()->getTempDataOnDisk());
+
+ auto create_tmp_disk_buffer = [tmp_data] (const WriteBufferPtr &) -> WriteBufferPtr
+ {
+ return tmp_data->createRawStream();
+ };
+
+ cascade_buffer2.emplace_back(std::move(create_tmp_disk_buffer));
+ }
+ else
+ {
+ auto push_memory_buffer_and_continue = [next_buffer = used_output.out_maybe_compressed] (const WriteBufferPtr & prev_buf)
+ {
+ auto * prev_memory_buffer = typeid_cast<MemoryWriteBuffer *>(prev_buf.get());
+ if (!prev_memory_buffer)
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected MemoryWriteBuffer");
+
+ auto rdbuf = prev_memory_buffer->tryGetReadBuffer();
+ copyData(*rdbuf, *next_buffer);
+
+ return next_buffer;
+ };
+
+ cascade_buffer2.emplace_back(push_memory_buffer_and_continue);
+ }
+
+ used_output.out_maybe_delayed_and_compressed = std::make_shared<CascadeWriteBuffer>(
+ std::move(cascade_buffer1), std::move(cascade_buffer2));
+ }
+ else
+ {
+ used_output.out_maybe_delayed_and_compressed = used_output.out_maybe_compressed;
+ }
+
+ /// Request body can be compressed using algorithm specified in the Content-Encoding header.
+ String http_request_compression_method_str = request.get("Content-Encoding", "");
+ int zstd_window_log_max = static_cast<int>(context->getSettingsRef().zstd_window_log_max);
+ auto in_post = wrapReadBufferWithCompressionMethod(
+ wrapReadBufferReference(request.getStream()),
+ chooseCompressionMethod({}, http_request_compression_method_str), zstd_window_log_max);
+
+ /// The data can also be compressed using incompatible internal algorithm. This is indicated by
+ /// 'decompress' query parameter.
+ std::unique_ptr<ReadBuffer> in_post_maybe_compressed;
+ bool in_post_compressed = false;
+ if (params.getParsed<bool>("decompress", false))
+ {
+ in_post_maybe_compressed = std::make_unique<CompressedReadBuffer>(*in_post);
+ in_post_compressed = true;
+ }
+ else
+ in_post_maybe_compressed = std::move(in_post);
+
+ std::unique_ptr<ReadBuffer> in;
+
+ static const NameSet reserved_param_names{"compress", "decompress", "user", "password", "quota_key", "query_id", "stacktrace",
+ "buffer_size", "wait_end_of_query", "session_id", "session_timeout", "session_check", "client_protocol_version", "close_session"};
+
+ Names reserved_param_suffixes;
+
+ auto param_could_be_skipped = [&] (const String & name)
+ {
+ /// Empty parameter appears when URL like ?&a=b or a=b&&c=d. Just skip them for user's convenience.
+ if (name.empty())
+ return true;
+
+ if (reserved_param_names.contains(name))
+ return true;
+
+ for (const String & suffix : reserved_param_suffixes)
+ {
+ if (endsWith(name, suffix))
+ return true;
+ }
+
+ return false;
+ };
+
+ /// Settings can be overridden in the query.
+ /// Some parameters (database, default_format, everything used in the code above) do not
+ /// belong to the Settings class.
+
+ /// 'readonly' setting values mean:
+ /// readonly = 0 - any query is allowed, client can change any setting.
+ /// readonly = 1 - only readonly queries are allowed, client can't change settings.
+ /// readonly = 2 - only readonly queries are allowed, client can change any setting except 'readonly'.
+
+ /// In theory if initially readonly = 0, the client can change any setting and then set readonly
+ /// to some other value.
+ const auto & settings = context->getSettingsRef();
+
+ /// Only readonly queries are allowed for HTTP GET requests.
+ if (request.getMethod() == HTTPServerRequest::HTTP_GET)
+ {
+ if (settings.readonly == 0)
+ context->setSetting("readonly", 2);
+ }
+
+ bool has_external_data = startsWith(request.getContentType(), "multipart/form-data");
+
+ if (has_external_data)
+ {
+ /// Skip unneeded parameters to avoid confusing them later with context settings or query parameters.
+ reserved_param_suffixes.reserve(3);
+ /// It is a bug and ambiguity with `date_time_input_format` and `low_cardinality_allow_in_native_format` formats/settings.
+ reserved_param_suffixes.emplace_back("_format");
+ reserved_param_suffixes.emplace_back("_types");
+ reserved_param_suffixes.emplace_back("_structure");
+ }
+
+ std::string database = request.get("X-ClickHouse-Database", "");
+ std::string default_format = request.get("X-ClickHouse-Format", "");
+
+ SettingsChanges settings_changes;
+ for (const auto & [key, value] : params)
+ {
+ if (key == "database")
+ {
+ if (database.empty())
+ database = value;
+ }
+ else if (key == "default_format")
+ {
+ if (default_format.empty())
+ default_format = value;
+ }
+ else if (param_could_be_skipped(key))
+ {
+ }
+ else
+ {
+ /// Other than query parameters are treated as settings.
+ if (!customizeQueryParam(context, key, value))
+ settings_changes.push_back({key, value});
+ }
+ }
+
+ if (!database.empty())
+ context->setCurrentDatabase(database);
+
+ if (!default_format.empty())
+ context->setDefaultFormat(default_format);
+
+ /// For external data we also want settings
+ context->checkSettingsConstraints(settings_changes, SettingSource::QUERY);
+ context->applySettingsChanges(settings_changes);
+
+ /// Set the query id supplied by the user, if any, and also update the OpenTelemetry fields.
+ context->setCurrentQueryId(params.get("query_id", request.get("X-ClickHouse-Query-Id", "")));
+
+ /// Initialize query scope, once query_id is initialized.
+ /// (To track as much allocations as possible)
+ query_scope.emplace(context);
+
+ /// NOTE: this may create pretty huge allocations that will not be accounted in trace_log,
+ /// because memory_profiler_sample_probability/memory_profiler_step are not applied yet,
+ /// they will be applied in ProcessList::insert() from executeQuery() itself.
+ const auto & query = getQuery(request, params, context);
+ std::unique_ptr<ReadBuffer> in_param = std::make_unique<ReadBufferFromString>(query);
+
+ /// HTTP response compression is turned on only if the client signalled that they support it
+ /// (using Accept-Encoding header) and 'enable_http_compression' setting is turned on.
+ used_output.out->setCompression(client_supports_http_compression && settings.enable_http_compression);
+ if (client_supports_http_compression)
+ used_output.out->setCompressionLevel(static_cast<int>(settings.http_zlib_compression_level));
+
+ used_output.out->setSendProgress(settings.send_progress_in_http_headers);
+ used_output.out->setSendProgressInterval(settings.http_headers_progress_interval_ms);
+
+ /// If 'http_native_compression_disable_checksumming_on_decompress' setting is turned on,
+ /// checksums of client data compressed with internal algorithm are not checked.
+ if (in_post_compressed && settings.http_native_compression_disable_checksumming_on_decompress)
+ static_cast<CompressedReadBuffer &>(*in_post_maybe_compressed).disableChecksumming();
+
+ /// Add CORS header if 'add_http_cors_header' setting is turned on send * in Access-Control-Allow-Origin
+ /// Note that whether the header is added is determined by the settings, and we can only get the user settings after authentication.
+ /// Once the authentication fails, the header can't be added.
+ if (settings.add_http_cors_header && !request.get("Origin", "").empty() && !config.has("http_options_response"))
+ used_output.out->addHeaderCORS(true);
+
+ auto append_callback = [my_context = context] (ProgressCallback callback)
+ {
+ auto prev = my_context->getProgressCallback();
+
+ my_context->setProgressCallback([prev, callback] (const Progress & progress)
+ {
+ if (prev)
+ prev(progress);
+
+ callback(progress);
+ });
+ };
+
+ /// While still no data has been sent, we will report about query execution progress by sending HTTP headers.
+ /// Note that we add it unconditionally so the progress is available for `X-ClickHouse-Summary`
+ append_callback([&used_output](const Progress & progress)
+ {
+ used_output.out->onProgress(progress);
+ });
+
+ if (settings.readonly > 0 && settings.cancel_http_readonly_queries_on_client_close)
+ {
+ append_callback([&context, &request](const Progress &)
+ {
+ /// Assume that at the point this method is called no one is reading data from the socket any more:
+ /// should be true for read-only queries.
+ if (!request.checkPeerConnected())
+ context->killCurrentQuery();
+ });
+ }
+
+ customizeContext(request, context, *in_post_maybe_compressed);
+ in = has_external_data ? std::move(in_param) : std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed);
+
+ executeQuery(*in, *used_output.out_maybe_delayed_and_compressed, /* allow_into_outfile = */ false, context,
+ [&response, this] (const QueryResultDetails & details)
+ {
+ response.add("X-ClickHouse-Query-Id", details.query_id);
+
+ if (content_type_override)
+ response.setContentType(*content_type_override);
+ else if (details.content_type)
+ response.setContentType(*details.content_type);
+
+ if (details.format)
+ response.add("X-ClickHouse-Format", *details.format);
+
+ if (details.timezone)
+ response.add("X-ClickHouse-Timezone", *details.timezone);
+ }
+ );
+
+ if (used_output.hasDelayed())
+ {
+ /// TODO: set Content-Length if possible
+ pushDelayedResults(used_output);
+ }
+
+ /// Send HTTP headers with code 200 if no exception happened and the data is still not sent to the client.
+ used_output.finalize();
+}
+
+/// Best-effort delivery of the exception text `s` (with numeric code `exception_code`) to the HTTP client.
+/// Records the exception code, chooses the HTTP status, writes the message either directly into the
+/// response or into the already created (possibly compressed) output stream, and finalizes `used_output`.
+/// Implemented as a function-try-block: a failure while sending is only logged, it is not rethrown.
+void HTTPHandler::trySendExceptionToClient(
+ const std::string & s, int exception_code, HTTPServerRequest & request, HTTPServerResponse & response, Output & used_output)
+try
+{
+ /// In case data has already been sent, like progress headers, try using the output buffer to
+ /// set the exception code since it will be able to append it if it hasn't finished writing headers
+ if (response.sent() && used_output.out)
+ used_output.out->setExceptionCode(exception_code);
+ else
+ response.set("X-ClickHouse-Exception-Code", toString<int>(exception_code));
+
+ /// FIXME: make sure that no one else is reading from the same stream at the moment.
+
+ /// If HTTP method is POST and Keep-Alive is turned on, we should read the whole request body
+ /// to avoid reading part of the current request body in the next request.
+ if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST
+ && response.getKeepAlive()
+ && exception_code != ErrorCodes::HTTP_LENGTH_REQUIRED
+ && !request.getStream().eof())
+ {
+ request.getStream().ignoreAll();
+ }
+
+ /// A missing password is answered through response.requireAuthentication();
+ /// every other exception code is mapped to an HTTP status by exceptionCodeToHTTPStatus().
+ if (exception_code == ErrorCodes::REQUIRED_PASSWORD)
+ {
+ response.requireAuthentication("ClickHouse server HTTP API");
+ }
+ else
+ {
+ response.setStatusAndReason(exceptionCodeToHTTPStatus(exception_code));
+ }
+
+ if (!response.sent() && !used_output.out_maybe_compressed)
+ {
+ /// If nothing was sent yet and we don't even know if we must compress the response.
+ *response.send() << s << std::endl;
+ }
+ else if (used_output.out_maybe_compressed)
+ {
+ /// Destroy CascadeBuffer to actualize buffers' positions and reset extra references
+ if (used_output.hasDelayed())
+ {
+ /// do not call finalize here for CascadeWriteBuffer used_output.out_maybe_delayed_and_compressed,
+ /// exception is written into used_output.out_maybe_compressed later
+ /// HTTPHandler::trySendExceptionToClient is called with exception context, it is Ok to destroy buffers
+ used_output.out_maybe_delayed_and_compressed.reset();
+ }
+
+ /// Send the error message into already used (and possibly compressed) stream.
+ /// Note that the error message will possibly be sent after some data.
+ /// Also HTTP code 200 could have already been sent.
+
+ /// If buffer has data, and that data wasn't sent yet, then no need to send that data
+ bool data_sent = used_output.out->count() != used_output.out->offset();
+
+ if (!data_sent)
+ {
+ /// Rewind both buffers to their beginning so the pending, unsent data is replaced by the error message.
+ used_output.out_maybe_compressed->position() = used_output.out_maybe_compressed->buffer().begin();
+ used_output.out->position() = used_output.out->buffer().begin();
+ }
+
+ writeString(s, *used_output.out_maybe_compressed);
+ writeChar('\n', *used_output.out_maybe_compressed);
+
+ used_output.out_maybe_compressed->next();
+ }
+ else
+ {
+ /// Remaining combination: the response was already sent but no output buffer exists.
+ /// The code treats this state as impossible.
+ UNREACHABLE();
+ }
+
+ used_output.finalize();
+}
+catch (...)
+{
+ tryLogCurrentException(log, "Cannot send exception to client");
+
+ /// Still attempt to finalize the output; a second failure is logged as well and swallowed.
+ try
+ {
+ used_output.finalize();
+ }
+ catch (...)
+ {
+ tryLogCurrentException(log, "Cannot flush data to client (after sending exception)");
+ }
+}
+
+
+/// Entry point for one HTTP request: creates the per-request session, sets up the tracing context and
+/// default response headers, parses the request parameters and delegates the work to processQuery().
+/// Any exception raised on the way is logged and reported to the client via trySendExceptionToClient().
+void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
+{
+ setThreadName("HTTPHandler");
+ ThreadStatus thread_status;
+
+ /// The session is created for this request only; the scope guard below drops it on every exit path.
+ session = std::make_unique<Session>(server.context(), ClientInfo::Interface::HTTP, request.isSecure());
+ SCOPE_EXIT({ session.reset(); });
+ std::optional<CurrentThread::QueryScope> query_scope;
+
+ Output used_output;
+
+ /// In case of exception, send stack trace to client.
+ bool with_stacktrace = false;
+ /// Close http session (if any) after processing the request
+ bool close_session = false;
+ String session_id;
+
+ SCOPE_EXIT_SAFE({
+ if (close_session && !session_id.empty())
+ session->closeSession(session_id);
+ });
+
+ OpenTelemetry::TracingContextHolderPtr thread_trace_context;
+ SCOPE_EXIT({
+ // make sure the response status is recorded
+ if (thread_trace_context)
+ thread_trace_context->root_span.addAttribute("clickhouse.http_status", response.getStatus());
+ });
+
+ try
+ {
+ /// OPTIONS requests are answered by processOptionsRequest() and return before any query processing.
+ if (request.getMethod() == HTTPServerRequest::HTTP_OPTIONS)
+ {
+ processOptionsRequest(response, server.config());
+ return;
+ }
+
+ // Parse the OpenTelemetry traceparent header.
+ auto & client_trace_context = session->getClientTraceContext();
+ if (request.has("traceparent"))
+ {
+ std::string opentelemetry_traceparent = request.get("traceparent");
+ std::string error;
+ /// A malformed traceparent header is only logged; the request is still processed.
+ if (!client_trace_context.parseTraceparentHeader(opentelemetry_traceparent, error))
+ {
+ LOG_DEBUG(log, "Failed to parse OpenTelemetry traceparent header '{}': {}", opentelemetry_traceparent, error);
+ }
+ client_trace_context.tracestate = request.get("tracestate", "");
+ }
+
+ // Setup tracing context for this thread
+ auto context = session->sessionOrGlobalContext();
+ thread_trace_context = std::make_unique<OpenTelemetry::TracingContextHolder>("HTTPHandler",
+ client_trace_context,
+ context->getSettingsRef(),
+ context->getOpenTelemetrySpanLog());
+ thread_trace_context->root_span.kind = OpenTelemetry::SERVER;
+ thread_trace_context->root_span.addAttribute("clickhouse.uri", request.getURI());
+
+ response.setContentType("text/plain; charset=UTF-8");
+ response.set("X-ClickHouse-Server-Display-Name", server_display_name);
+
+ /// Extra headers from the server config are only attempted for requests that carry an Origin header.
+ if (!request.get("Origin", "").empty())
+ tryAddHttpOptionHeadersFromConfig(response, server.config());
+
+ /// For keep-alive to work.
+ if (request.getVersion() == HTTPServerRequest::HTTP_1_1)
+ response.setChunkedTransferEncoding(true);
+
+ HTMLForm params(default_settings, request);
+ with_stacktrace = params.getParsed<bool>("stacktrace", false);
+ close_session = params.getParsed<bool>("close_session", false);
+ /// session_id is kept here only so the scope guard above can close the session afterwards.
+ if (close_session)
+ session_id = params.get("session_id");
+
+ /// FIXME: maybe this check is already unnecessary.
+ /// Workaround. Poco does not detect 411 Length Required case.
+ if (request.getMethod() == HTTPRequest::HTTP_POST && !request.getChunkedTransferEncoding() && !request.hasContentLength())
+ {
+ throw Exception(ErrorCodes::HTTP_LENGTH_REQUIRED,
+ "The Transfer-Encoding is not chunked and there "
+ "is no Content-Length header for POST request");
+ }
+
+ processQuery(request, params, response, used_output, query_scope);
+ /// request_credentials still being set after processQuery() is reported as an unfinished authentication exchange.
+ if (request_credentials)
+ LOG_DEBUG(log, "Authentication in progress...");
+ else
+ LOG_DEBUG(log, "Done processing query");
+ }
+ catch (...)
+ {
+ SCOPE_EXIT({
+ request_credentials.reset(); // ...so that the next requests on the connection have to always start afresh in case of exceptions.
+ });
+
+ /// Check if exception was thrown in used_output.finalize().
+ /// In this case used_output can be in invalid state and we
+ /// cannot write in it anymore. So, just log this exception.
+ if (used_output.isFinalized())
+ {
+ if (thread_trace_context)
+ thread_trace_context->root_span.addAttribute("clickhouse.exception", "Cannot flush data to client");
+
+ tryLogCurrentException(log, "Cannot flush data to client");
+ return;
+ }
+
+ tryLogCurrentException(log);
+
+ /** If exception is received from remote server, then stack trace is embedded in message.
+ * If exception is thrown on local server, then stack trace is in separate field.
+ */
+ ExecutionStatus status = ExecutionStatus::fromCurrentException("", with_stacktrace);
+ trySendExceptionToClient(status.message, status.code, request, response, used_output);
+
+ if (thread_trace_context)
+ thread_trace_context->root_span.addAttribute(status);
+ }
+
+ /// Reached on the normal path and after a reported exception.
+ used_output.finalize();
+}
+
+/// Creates a handler that looks for the query text in the request parameter named `param_name_`.
+/// `content_type_override_` is forwarded unchanged to the HTTPHandler base.
+DynamicQueryHandler::DynamicQueryHandler(
+    IServer & server_,
+    const std::string & param_name_,
+    const std::optional<String> & content_type_override_)
+    : HTTPHandler(server_, "DynamicQueryHandler", content_type_override_)
+    , param_name(param_name_)
+{
+}
+
+/// Decides whether the request parameter `key` is consumed by this handler.
+/// Returns true for the query parameter itself and for substitution parameters
+/// (names starting with QUERY_PARAMETER_NAME_PREFIX); returns false for everything else.
+bool DynamicQueryHandler::customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value)
+{
+    /// The parameter that carries the query text needs no further handling here.
+    if (key == param_name)
+        return true;
+
+    if (!startsWith(key, QUERY_PARAMETER_NAME_PREFIX))
+        return false;
+
+    /// Strip the prefix to obtain the substitution name; an already registered value is not overwritten.
+    const size_t prefix_length = strlen(QUERY_PARAMETER_NAME_PREFIX);
+    const String substitution_name = key.substr(prefix_length);
+
+    const bool already_registered = context->getQueryParameters().contains(substitution_name);
+    if (!already_registered)
+        context->setQueryParameter(substitution_name, value);
+
+    return true;
+}
+
+/// Returns the part of the query that is carried by the request parameters.
+/// For ordinary requests this is the value of the query parameter followed by a line break
+/// (or an empty string when the parameter is absent or empty).
+/// For multipart/form-data requests the form is loaded first (external tables), then all values of
+/// the query parameter are concatenated and every other parameter is passed to customizeQueryParam().
+std::string DynamicQueryHandler::getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context)
+{
+    const bool has_external_data = startsWith(request.getContentType(), "multipart/form-data");
+
+    if (likely(!has_external_data))
+    {
+        /// The query may be split between the query parameter and the request body
+        /// (the HTTP method does not have to be POST). The parameter part is terminated with
+        /// a line break so that the body can follow it.
+        std::string query_from_param = params.get(param_name, "");
+        if (!query_from_param.empty())
+            query_from_param += "\n";
+        return query_from_param;
+    }
+
+    /// Support for "external data for query processing".
+    /// Used for POST requests with form-data; the loaded data is not expected to be deleted after this scope.
+    ExternalTablesHandler external_tables(context, params);
+    params.load(request, request.getStream(), external_tables);
+
+    /// After loading, params contains both the form fields (POST) and the URI parameters (GET).
+    std::string assembled_query;
+    for (const auto & [name, value] : params)
+    {
+        if (name != param_name)
+        {
+            customizeQueryParam(context, name, value);
+            continue;
+        }
+
+        assembled_query += value;
+    }
+
+    return assembled_query;
+}
+
+/// Stores the predefined query, the set of parameter names it expects, and the regexes
+/// (URL and per-header) whose named capture groups are later turned into query parameters.
+PredefinedQueryHandler::PredefinedQueryHandler(
+    IServer & server_,
+    const NameSet & receive_params_,
+    const std::string & predefined_query_,
+    const CompiledRegexPtr & url_regex_,
+    const std::unordered_map<String, CompiledRegexPtr> & header_name_with_regex_,
+    const std::optional<String> & content_type_override_)
+    : HTTPHandler(server_, "PredefinedQueryHandler", content_type_override_),
+      receive_params(receive_params_),
+      predefined_query(predefined_query_),
+      url_regex(url_regex_),
+      header_name_with_capture_regex(header_name_with_regex_)
+{
+}
+
+/// Registers `key` as a query parameter when the predefined query expects a parameter with that name.
+/// Returns true if the parameter was consumed, false if the caller has to handle it.
+bool PredefinedQueryHandler::customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value)
+{
+    const bool expected_by_query = receive_params.contains(key);
+    if (!expected_by_query)
+        return false;
+
+    context->setQueryParameter(key, value);
+    return true;
+}
+
+/// Fills query parameters from the request right before the predefined query is executed:
+///  - named capture groups of the URL regex, matched against the URI up to (not including) the first '?';
+///  - named capture groups of every configured header regex, matched against the header value;
+///  - the request body as `_request_body`, if the query expects that parameter and it is not set yet
+///    (at most `http_max_request_param_data_size` bytes are copied).
+void PredefinedQueryHandler::customizeContext(HTTPServerRequest & request, ContextMutablePtr context, ReadBuffer & body)
+{
+    /// Matches [begin, end) against the whole regex (anchored at both ends) and stores every
+    /// named group that took part in the match as a query parameter.
+    const auto set_query_params = [&](const char * begin, const char * end, const CompiledRegexPtr & compiled_regex)
+    {
+        /// One extra slot: RE2 puts the whole match into element 0, the groups follow.
+        const int num_captures = compiled_regex->NumberOfCapturingGroups() + 1;
+
+        /// std::vector instead of a variable-length array: VLAs are a compiler extension in C++,
+        /// and the size here is derived from a regex that comes from the configuration.
+        std::vector<std::string_view> matches(static_cast<size_t>(num_captures));
+        const std::string_view input(begin, end - begin);
+
+        if (!compiled_regex->Match(input, 0, end - begin, re2::RE2::Anchor::ANCHOR_BOTH, matches.data(), num_captures))
+            return;
+
+        for (const auto & [capturing_name, capturing_index] : compiled_regex->NamedCapturingGroups())
+        {
+            const auto & capturing_value = matches[capturing_index];
+
+            /// A null data pointer is skipped: no parameter is set for that group.
+            if (capturing_value.data())
+                context->setQueryParameter(capturing_name, String(capturing_value.data(), capturing_value.size()));
+        }
+    };
+
+    if (url_regex)
+    {
+        const auto & uri = request.getURI();
+        set_query_params(uri.data(), find_first_symbols<'?'>(uri.data(), uri.data() + uri.size()), url_regex);
+    }
+
+    for (const auto & [header_name, regex] : header_name_with_capture_regex)
+    {
+        const auto & header_value = request.get(header_name);
+        set_query_params(header_value.data(), header_value.data() + header_value.size(), regex);
+    }
+
+    if (unlikely(receive_params.contains("_request_body") && !context->getQueryParameters().contains("_request_body")))
+    {
+        WriteBufferFromOwnString value;
+        const auto & settings = context->getSettingsRef();
+
+        copyDataMaxBytes(body, value, settings.http_max_request_param_data_size);
+        context->setQueryParameter("_request_body", value.str());
+    }
+}
+
+/// Always returns the predefined query from the configuration.
+/// For multipart/form-data requests the form is loaded first so that external tables become available.
+std::string PredefinedQueryHandler::getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context)
+{
+    const bool has_external_data = startsWith(request.getContentType(), "multipart/form-data");
+
+    if (unlikely(has_external_data))
+    {
+        /// Support for "external data for query processing".
+        ExternalTablesHandler external_tables(context, params);
+        params.load(request, request.getStream(), external_tables);
+    }
+
+    return predefined_query;
+}
+
+/// Builds the factory for a `dynamic_query_handler` rule located at `config_prefix`.
+/// Reads the optional `handler.query_param_name` (default "query") and `handler.content_type`
+/// keys, then attaches the request filters described by the same config section.
+HTTPRequestHandlerFactoryPtr createDynamicHandlerFactory(IServer & server,
+    const Poco::Util::AbstractConfiguration & config,
+    const std::string & config_prefix)
+{
+    using DynamicFactory = HandlingRuleHTTPHandlerFactory<DynamicQueryHandler>;
+
+    std::optional<String> content_type_override;
+    auto query_param_name = config.getString(config_prefix + ".handler.query_param_name", "query");
+
+    const bool has_content_type = config.has(config_prefix + ".handler.content_type");
+    if (has_content_type)
+        content_type_override = config.getString(config_prefix + ".handler.content_type");
+
+    /// The arguments are moved so that the factory keeps its own copies rather than references to these locals.
+    auto factory = std::make_shared<DynamicFactory>(server, std::move(query_param_name), std::move(content_type_override));
+    factory->addFiltersFromConfig(config, config_prefix);
+
+    return factory;
+}
+
+/// Returns true if at least one named capturing group of `compiled_regex` has the same name
+/// as one of the query parameters in `receive_params`.
+/// The set is taken by const reference (callers pass an lvalue, so no copy is made any more)
+/// and is queried with its own lookup instead of a linear scan per group.
+static inline bool capturingNamedQueryParam(const NameSet & receive_params, const CompiledRegexPtr & compiled_regex)
+{
+    const auto & capturing_names = compiled_regex->NamedCapturingGroups();
+    return std::any_of(capturing_names.begin(), capturing_names.end(), [&](const auto & name_and_index)
+    {
+        return receive_params.contains(name_and_index.first);
+    });
+}
+
+/// Compiles `expression` with re2 and returns the shared compiled regex.
+/// Throws CANNOT_COMPILE_REGEXP (including re2's error text) when the expression is not valid.
+static inline CompiledRegexPtr getCompiledRegex(const std::string & expression)
+{
+    CompiledRegexPtr regex = std::make_shared<const re2::RE2>(expression);
+
+    if (regex->ok())
+        return regex;
+
+    throw Exception(ErrorCodes::CANNOT_COMPILE_REGEXP, "Cannot compile re2: {} for http handling rule, error: {}. "
+        "Look at https://github.com/google/re2/wiki/Syntax for reference.", expression, regex->error());
+}
+
+/// Builds the factory for a `predefined_query_handler` rule located at `config_prefix`.
+/// Requires `handler.query`. Header expressions prefixed with "regex:" and the `url` expression are
+/// compiled, and a regex is handed to the handler only if one of its named capture groups matches a
+/// parameter name found in the predefined query. Request filters are attached from the same section.
+/// Throws NO_ELEMENTS_IN_CONFIG when `handler.query` is missing.
+HTTPRequestHandlerFactoryPtr createPredefinedHandlerFactory(IServer & server,
+    const Poco::Util::AbstractConfiguration & config,
+    const std::string & config_prefix)
+{
+    using PredefinedFactory = HandlingRuleHTTPHandlerFactory<PredefinedQueryHandler>;
+
+    if (!config.has(config_prefix + ".handler.query"))
+        throw Exception(ErrorCodes::NO_ELEMENTS_IN_CONFIG, "There is no path '{}.handler.query' in configuration file.", config_prefix);
+
+    std::string predefined_query = config.getString(config_prefix + ".handler.query");
+    NameSet analyze_receive_params = analyzeReceiveQueryParams(predefined_query);
+
+    /// Header rules: only "regex:" expressions are considered, and only those capturing a query parameter are kept.
+    Poco::Util::AbstractConfiguration::Keys headers_name;
+    config.keys(config_prefix + ".headers", headers_name);
+
+    std::unordered_map<String, CompiledRegexPtr> headers_name_with_regex;
+    for (const auto & header_name : headers_name)
+    {
+        const auto expression = config.getString(config_prefix + ".headers." + header_name);
+
+        if (startsWith(expression, "regex:"))
+        {
+            auto header_regex = getCompiledRegex(expression.substr(6));
+            if (capturingNamedQueryParam(analyze_receive_params, header_regex))
+                headers_name_with_regex.emplace(header_name, header_regex);
+        }
+    }
+
+    std::optional<String> content_type_override;
+    if (config.has(config_prefix + ".handler.content_type"))
+        content_type_override = config.getString(config_prefix + ".handler.content_type");
+
+    /// URL rule: the expression is compiled with or without the "regex:" prefix.
+    /// The regex stays null unless it captures one of the query parameters.
+    CompiledRegexPtr url_regex;
+    if (config.has(config_prefix + ".url"))
+    {
+        auto url_expression = config.getString(config_prefix + ".url");
+
+        if (startsWith(url_expression, "regex:"))
+            url_expression = url_expression.substr(6);
+
+        auto candidate_regex = getCompiledRegex(url_expression);
+        if (capturingNamedQueryParam(analyze_receive_params, candidate_regex))
+            url_regex = std::move(candidate_regex);
+    }
+
+    /// Everything is moved so that the factory keeps its own copies rather than references to these locals.
+    auto factory = std::make_shared<PredefinedFactory>(
+        server,
+        std::move(analyze_receive_params),
+        std::move(predefined_query),
+        std::move(url_regex),
+        std::move(headers_name_with_regex),
+        std::move(content_type_override));
+    factory->addFiltersFromConfig(config, config_prefix);
+
+    return factory;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTPHandler.h b/contrib/clickhouse/src/Server/HTTPHandler.h
new file mode 100644
index 0000000000..5eda592753
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTPHandler.h
@@ -0,0 +1,173 @@
+#pragma once
+
+#include <Core/Names.h>
+#include <Server/HTTP/HTMLForm.h>
+#include <Server/HTTP/HTTPRequestHandler.h>
+#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
+#include <Common/CurrentMetrics.h>
+#include <Common/CurrentThread.h>
+
+#include <re2/re2.h>
+
+namespace CurrentMetrics
+{
+ extern const Metric HTTPConnection;
+}
+
+namespace Poco { class Logger; }
+
+namespace DB
+{
+
+class Session;
+class Credentials;
+class IServer;
+struct Settings;
+class WriteBufferFromHTTPServerResponse;
+
+using CompiledRegexPtr = std::shared_ptr<const re2::RE2>;
+
+/// Base class of the HTTP query handlers declared in this header (DynamicQueryHandler, PredefinedQueryHandler).
+/// handleRequest() is the entry point; subclasses decide where the query text comes from
+/// and which request parameters they consume themselves.
+class HTTPHandler : public HTTPRequestHandler
+{
+public:
+ HTTPHandler(IServer & server_, const std::string & name, const std::optional<String> & content_type_override_);
+ virtual ~HTTPHandler() override;
+
+ void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
+
+ /// This method is called right before the query execution.
+ virtual void customizeContext(HTTPServerRequest & /* request */, ContextMutablePtr /* context */, ReadBuffer & /* body */) {}
+
+ /// Called for a request parameter; returning true means the handler consumed it,
+ /// returning false leaves it to the caller (which then treats it as a setting).
+ virtual bool customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value) = 0;
+
+ /// Returns the query text that the handler obtains from the request parameters or from its own configuration.
+ virtual std::string getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context) = 0;
+
+private:
+ /// Chain of output buffers used to write the response, plus a flag that makes finalize() run only once.
+ struct Output
+ {
+ /* Raw data
+ * ↓
+ * CascadeWriteBuffer out_maybe_delayed_and_compressed (optional)
+ * ↓ (forwards data if an overflow is occur or explicitly via pushDelayedResults)
+ * CompressedWriteBuffer out_maybe_compressed (optional)
+ * ↓
+ * WriteBufferFromHTTPServerResponse out
+ */
+
+ std::shared_ptr<WriteBufferFromHTTPServerResponse> out;
+ /// Points to 'out' or to CompressedWriteBuffer(*out), depending on settings.
+ std::shared_ptr<WriteBuffer> out_maybe_compressed;
+ /// Points to 'out' or to CompressedWriteBuffer(*out) or to CascadeWriteBuffer.
+ std::shared_ptr<WriteBuffer> out_maybe_delayed_and_compressed;
+
+ bool finalized = false;
+
+ /// True when out_maybe_delayed_and_compressed is a different buffer than out_maybe_compressed.
+ inline bool hasDelayed() const
+ {
+ return out_maybe_delayed_and_compressed != out_maybe_compressed;
+ }
+
+ /// Finalizes the existing buffers in the order of the chain above (outermost first).
+ /// The flag is set before the buffers are touched, so repeated calls do nothing,
+ /// even if an earlier call threw.
+ inline void finalize()
+ {
+ if (finalized)
+ return;
+ finalized = true;
+
+ if (out_maybe_delayed_and_compressed)
+ out_maybe_delayed_and_compressed->finalize();
+ if (out_maybe_compressed)
+ out_maybe_compressed->finalize();
+ if (out)
+ out->finalize();
+ }
+
+ /// True once finalize() has been entered.
+ inline bool isFinalized() const
+ {
+ return finalized;
+ }
+ };
+
+ IServer & server;
+ Poco::Logger * log;
+
+ /// It is the name of the server that will be sent in an http-header X-ClickHouse-Server-Display-Name.
+ String server_display_name;
+
+ CurrentMetrics::Increment metric_increment{CurrentMetrics::HTTPConnection};
+
+ /// Reference to the immutable settings in the global context.
+ /// Those settings are used only to extract a http request's parameters.
+ /// See settings http_max_fields, http_max_field_name_size, http_max_field_value_size in HTMLForm.
+ const Settings & default_settings;
+
+ /// Overrides Content-Type provided by the format of the response.
+ std::optional<String> content_type_override;
+
+ // session is reset at the end of each request/response.
+ std::unique_ptr<Session> session;
+
+ // The request_credential instance may outlive a single request/response loop.
+ // This happens only when the authentication mechanism requires more than a single request/response exchange (e.g., SPNEGO).
+ std::unique_ptr<Credentials> request_credentials;
+
+ // Returns true when the user successfully authenticated,
+ // the session instance will be configured accordingly, and the request_credentials instance will be dropped.
+ // Returns false when the user is not authenticated yet, and the 'Negotiate' response is sent,
+ // the session and request_credentials instances are preserved.
+ // Throws an exception if authentication failed.
+ bool authenticateUser(
+ HTTPServerRequest & request,
+ HTMLForm & params,
+ HTTPServerResponse & response);
+
+ /// Also initializes 'used_output'.
+ void processQuery(
+ HTTPServerRequest & request,
+ HTMLForm & params,
+ HTTPServerResponse & response,
+ Output & used_output,
+ std::optional<CurrentThread::QueryScope> & query_scope);
+
+ /// Tries to report the exception text `s` with code `exception_code` to the client; failures are logged, not rethrown.
+ void trySendExceptionToClient(
+ const std::string & s,
+ int exception_code,
+ HTTPServerRequest & request,
+ HTTPServerResponse & response,
+ Output & used_output);
+
+ static void pushDelayedResults(Output & used_output);
+};
+
+/// Handler that takes the query text from a request parameter whose name is configurable ("query" by default).
+class DynamicQueryHandler : public HTTPHandler
+{
+public:
+    explicit DynamicQueryHandler(
+        IServer & server_,
+        const std::string & param_name_ = "query",
+        const std::optional<String> & content_type_override_ = std::nullopt);
+
+    bool customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value) override;
+
+    std::string getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context) override;
+
+private:
+    /// Name of the request parameter that carries the query text.
+    std::string param_name;
+};
+
+/// Handler that always executes a query fixed in the configuration; request parameters,
+/// regex capture groups of the URL / headers and optionally the request body supply its query parameters.
+class PredefinedQueryHandler : public HTTPHandler
+{
+public:
+    PredefinedQueryHandler(
+        IServer & server_,
+        const NameSet & receive_params_,
+        const std::string & predefined_query_,
+        const CompiledRegexPtr & url_regex_,
+        const std::unordered_map<String, CompiledRegexPtr> & header_name_with_regex_,
+        const std::optional<std::string> & content_type_override_);
+
+    bool customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value) override;
+
+    void customizeContext(HTTPServerRequest & request, ContextMutablePtr context, ReadBuffer & body) override;
+
+    std::string getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context) override;
+
+private:
+    /// Names of the query parameters the predefined query expects.
+    NameSet receive_params;
+    /// Query text returned by getQuery().
+    std::string predefined_query;
+    /// Regex applied to the request URI; may be null.
+    CompiledRegexPtr url_regex;
+    /// Header name -> regex whose named capture groups become query parameters.
+    std::unordered_map<String, CompiledRegexPtr> header_name_with_capture_regex;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTPHandlerFactory.cpp b/contrib/clickhouse/src/Server/HTTPHandlerFactory.cpp
new file mode 100644
index 0000000000..1c911034da
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTPHandlerFactory.cpp
@@ -0,0 +1,186 @@
+#include <Server/HTTPHandlerFactory.h>
+
+#include <Server/HTTP/HTTPRequestHandler.h>
+#include <Server/IServer.h>
+#include <Access/Credentials.h>
+
+#include <Poco/Util/AbstractConfiguration.h>
+
+#include "HTTPHandler.h"
+#include "NotFoundHandler.h"
+#include "StaticRequestHandler.h"
+#include "ReplicasStatusHandler.h"
+#include "InterserverIOHTTPHandler.h"
+#include "PrometheusRequestHandler.h"
+#include "WebUIRequestHandler.h"
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int LOGICAL_ERROR;
+ extern const int UNKNOWN_ELEMENT_IN_CONFIG;
+ extern const int INVALID_CONFIG_PARAMETER;
+}
+
+/// Forward declarations: both functions are defined further down in this file
+/// but are already used by the factory builders that follow.
+static void addCommonDefaultHandlersFactory(HTTPRequestHandlerFactoryMain & factory, IServer & server);
+static void addDefaultHandlersFactory(
+ HTTPRequestHandlerFactoryMain & factory,
+ IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ AsynchronousMetrics & async_metrics);
+
+/// Builds the main handler factory from the config section `prefix`.
+/// Each child key must be either "defaults" (adds the built-in handlers) or start with "rule"
+/// (adds the handler selected by `<rule>.handler.type`).
+/// Throws INVALID_CONFIG_PARAMETER for a missing or unknown handler type and
+/// UNKNOWN_ELEMENT_IN_CONFIG for any other child key.
+static inline auto createHandlersFactoryFromConfig(
+    IServer & server,
+    const Poco::Util::AbstractConfiguration & config,
+    const std::string & name,
+    const String & prefix,
+    AsynchronousMetrics & async_metrics)
+{
+    auto main_handler_factory = std::make_shared<HTTPRequestHandlerFactoryMain>(name);
+
+    Poco::Util::AbstractConfiguration::Keys keys;
+    config.keys(prefix, keys);
+
+    for (const auto & key : keys)
+    {
+        if (key == "defaults")
+        {
+            addDefaultHandlersFactory(*main_handler_factory, server, config, async_metrics);
+            continue;
+        }
+
+        if (!startsWith(key, "rule"))
+            throw Exception(ErrorCodes::UNKNOWN_ELEMENT_IN_CONFIG, "Unknown element in config: "
+                "{}.{}, must be 'rule' or 'defaults'", prefix, key);
+
+        const std::string rule_prefix = prefix + "." + key;
+        const std::string handler_type = config.getString(rule_prefix + ".handler.type", "");
+
+        if (handler_type.empty())
+            throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER, "Handler type in config is not specified here: "
+                "{}.{}.handler.type", prefix, key);
+
+        /// Pick the factory for the configured handler type, then register it once.
+        HTTPRequestHandlerFactoryPtr rule_factory;
+        if (handler_type == "static")
+            rule_factory = createStaticHandlerFactory(server, config, rule_prefix);
+        else if (handler_type == "dynamic_query_handler")
+            rule_factory = createDynamicHandlerFactory(server, config, rule_prefix);
+        else if (handler_type == "predefined_query_handler")
+            rule_factory = createPredefinedHandlerFactory(server, config, rule_prefix);
+        else if (handler_type == "prometheus")
+            rule_factory = createPrometheusHandlerFactory(server, config, async_metrics, rule_prefix);
+        else if (handler_type == "replicas_status")
+            rule_factory = createReplicasStatusHandlerFactory(server, config, rule_prefix);
+        else
+            throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER, "Unknown handler type '{}' in config here: {}.{}.handler.type",
+                handler_type, prefix, key);
+
+        main_handler_factory->addHandler(rule_factory);
+    }
+
+    return main_handler_factory;
+}
+
+/// Creates the factory for the HTTP(S) query interface: built from the `http_handlers` config
+/// section when it exists, otherwise filled with the default handlers.
+static inline HTTPRequestHandlerFactoryPtr
+createHTTPHandlerFactory(IServer & server, const Poco::Util::AbstractConfiguration & config, const std::string & name, AsynchronousMetrics & async_metrics)
+{
+    if (!config.has("http_handlers"))
+    {
+        auto default_factory = std::make_shared<HTTPRequestHandlerFactoryMain>(name);
+        addDefaultHandlersFactory(*default_factory, server, config, async_metrics);
+        return default_factory;
+    }
+
+    return createHandlersFactoryFromConfig(server, config, name, "http_handlers", async_metrics);
+}
+
+/// Creates the factory for the interserver HTTP(S) port: the common default handlers
+/// followed by the InterserverIOHTTPHandler.
+static inline HTTPRequestHandlerFactoryPtr createInterserverHTTPHandlerFactory(IServer & server, const std::string & name)
+{
+    using InterserverFactory = HandlingRuleHTTPHandlerFactory<InterserverIOHTTPHandler>;
+
+    auto interserver_factory = std::make_shared<HTTPRequestHandlerFactoryMain>(name);
+    addCommonDefaultHandlersFactory(*interserver_factory, server);
+
+    auto io_handler = std::make_shared<InterserverFactory>(server);
+    io_handler->allowPostAndGetParamsAndOptionsRequest();
+    interserver_factory->addHandler(io_handler);
+
+    return interserver_factory;
+}
+
+/// Dispatches on the factory `name` and creates the matching handler factory
+/// (HTTP/HTTPS query interface, interserver HTTP/HTTPS, or the dedicated Prometheus port).
+/// Throws LOGICAL_ERROR for any other name.
+HTTPRequestHandlerFactoryPtr createHandlerFactory(IServer & server, const Poco::Util::AbstractConfiguration & config, AsynchronousMetrics & async_metrics, const std::string & name)
+{
+    const bool is_http = name == "HTTPHandler-factory" || name == "HTTPSHandler-factory";
+    if (is_http)
+        return createHTTPHandlerFactory(server, config, name, async_metrics);
+
+    const bool is_interserver = name == "InterserverIOHTTPHandler-factory" || name == "InterserverIOHTTPSHandler-factory";
+    if (is_interserver)
+        return createInterserverHTTPHandlerFactory(server, name);
+
+    if (name == "PrometheusHandler-factory")
+        return createPrometheusMainHandlerFactory(server, config, async_metrics, name);
+
+    throw Exception(ErrorCodes::LOGICAL_ERROR, "LOGICAL ERROR: Unknown HTTP handler factory name.");
+}
+
+/// Response expressions handed to the StaticRequestHandler factories for "/ping" and "/"
+/// in addCommonDefaultHandlersFactory().
+/// NOTE(review): the "config://" form presumably makes the handler read the response from the
+/// server configuration key of that name — confirm in StaticRequestHandler.
+static const auto ping_response_expression = "Ok.\n";
+static const auto root_response_expression = "config://http_server_default_response";
+
+/// Registers the handlers shared by every HTTP port, in this order:
+/// "/", "/ping", "/replicas_status", "/play", "/dashboard", "/js/".
+/// All of them accept GET and HEAD only; all paths except "/" and "/js/" are also added to the hints.
+void addCommonDefaultHandlersFactory(HTTPRequestHandlerFactoryMain & factory, IServer & server)
+{
+    using StaticFactory = HandlingRuleHTTPHandlerFactory<StaticRequestHandler>;
+    using ReplicasStatusFactory = HandlingRuleHTTPHandlerFactory<ReplicasStatusHandler>;
+    using WebUIFactory = HandlingRuleHTTPHandlerFactory<WebUIRequestHandler>;
+
+    /// Root page: exact path match.
+    {
+        auto handler = std::make_shared<StaticFactory>(server, root_response_expression);
+        handler->attachStrictPath("/");
+        handler->allowGetAndHeadRequest();
+        factory.addHandler(handler);
+    }
+
+    /// Liveness check: exact path match.
+    {
+        auto handler = std::make_shared<StaticFactory>(server, ping_response_expression);
+        handler->attachStrictPath("/ping");
+        handler->allowGetAndHeadRequest();
+        factory.addPathToHints("/ping");
+        factory.addHandler(handler);
+    }
+
+    /// Replication status.
+    {
+        auto handler = std::make_shared<ReplicasStatusFactory>(server);
+        handler->attachNonStrictPath("/replicas_status");
+        handler->allowGetAndHeadRequest();
+        factory.addPathToHints("/replicas_status");
+        factory.addHandler(handler);
+    }
+
+    /// Web UI: play.
+    {
+        auto handler = std::make_shared<WebUIFactory>(server);
+        handler->attachNonStrictPath("/play");
+        handler->allowGetAndHeadRequest();
+        factory.addPathToHints("/play");
+        factory.addHandler(handler);
+    }
+
+    /// Web UI: dashboard.
+    {
+        auto handler = std::make_shared<WebUIFactory>(server);
+        handler->attachNonStrictPath("/dashboard");
+        handler->allowGetAndHeadRequest();
+        factory.addPathToHints("/dashboard");
+        factory.addHandler(handler);
+    }
+
+    /// Web UI: scripts.
+    {
+        auto handler = std::make_shared<WebUIFactory>(server);
+        handler->attachNonStrictPath("/js/");
+        handler->allowGetAndHeadRequest();
+        factory.addHandler(handler);
+    }
+}
+
+/// Registers the default handler set of the HTTP query interface: the common handlers,
+/// the dynamic query handler reading the "query" parameter and, when configured for this port,
+/// the Prometheus metrics endpoint.
+void addDefaultHandlersFactory(
+    HTTPRequestHandlerFactoryMain & factory,
+    IServer & server,
+    const Poco::Util::AbstractConfiguration & config,
+    AsynchronousMetrics & async_metrics)
+{
+    using QueryFactory = HandlingRuleHTTPHandlerFactory<DynamicQueryHandler>;
+    using PrometheusFactory = HandlingRuleHTTPHandlerFactory<PrometheusRequestHandler>;
+
+    addCommonDefaultHandlersFactory(factory, server);
+
+    auto dynamic_handler = std::make_shared<QueryFactory>(server, "query");
+    dynamic_handler->allowPostAndGetParamsAndOptionsRequest();
+    factory.addHandler(dynamic_handler);
+
+    /// Prometheus is served here only when it has no port of its own (port absent or 0);
+    /// otherwise it gets a separate factory, see createHandlerFactory(...).
+    const bool prometheus_on_this_port = config.has("prometheus") && config.getInt("prometheus.port", 0) == 0;
+    if (!prometheus_on_this_port)
+        return;
+
+    /// The metrics writer is passed as a temporary so that the factory stores its own instance.
+    auto metrics_handler = std::make_shared<PrometheusFactory>(
+        server, PrometheusMetricsWriter(config, "prometheus", async_metrics));
+    metrics_handler->attachStrictPath(config.getString("prometheus.endpoint", "/metrics"));
+    metrics_handler->allowGetAndHeadRequest();
+    factory.addHandler(metrics_handler);
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTPHandlerFactory.h b/contrib/clickhouse/src/Server/HTTPHandlerFactory.h
new file mode 100644
index 0000000000..fe11833dc3
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTPHandlerFactory.h
@@ -0,0 +1,148 @@
+#pragma once
+
+#include <Common/AsynchronousMetrics.h>
+#include <Server/HTTP/HTMLForm.h>
+#include <Server/HTTP/HTTPRequestHandlerFactory.h>
+#include <Server/HTTPHandlerRequestFilter.h>
+#include <Server/HTTPRequestHandlerFactoryMain.h>
+#include <Common/StringUtils/StringUtils.h>
+
+#include <Poco/Util/AbstractConfiguration.h>
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int UNKNOWN_ELEMENT_IN_CONFIG;
+}
+
+class IServer;
+
+/// HTTP handler factory guarded by request filters: a TEndpoint handler is created
+/// only for requests that pass every filter attached to the factory.
+template <typename TEndpoint>
+class HandlingRuleHTTPHandlerFactory : public HTTPRequestHandlerFactory
+{
+public:
+    using Filter = std::function<bool(const HTTPServerRequest &)>;
+
+    /// Captures the constructor arguments (lvalues by reference, rvalues by value)
+    /// and uses them to build a TEndpoint each time a handler has to be created.
+    template <typename... TArgs>
+    explicit HandlingRuleHTTPHandlerFactory(TArgs &&... args)
+    {
+        /// The lambda is not mutable, so the captured tuple is const and std::move() below
+        /// cannot steal from it: the stored arguments stay usable for later calls.
+        creator = [my_args = std::tuple<TArgs...>(std::forward<TArgs>(args) ...)]()
+        {
+            return std::apply([&](auto && ... endpoint_args)
+            {
+                return std::make_unique<TEndpoint>(std::forward<decltype(endpoint_args)>(endpoint_args)...);
+            }, std::move(my_args));
+        };
+    }
+
+    /// Adds one more condition; a request must satisfy all conditions added so far.
+    void addFilter(Filter cur_filter)
+    {
+        Filter prev_filter = filter;
+        filter = [prev_filter, cur_filter](const auto & request)
+        {
+            return prev_filter ? prev_filter(request) && cur_filter(request) : cur_filter(request);
+        };
+    }
+
+    /// Adds a filter for every "url", "headers" and "methods" key found under `prefix`.
+    /// The "handler" key is skipped; any other key throws UNKNOWN_ELEMENT_IN_CONFIG.
+    void addFiltersFromConfig(const Poco::Util::AbstractConfiguration & config, const std::string & prefix)
+    {
+        Poco::Util::AbstractConfiguration::Keys filters_type;
+        config.keys(prefix, filters_type);
+
+        for (const auto & filter_type : filters_type)
+        {
+            if (filter_type == "handler")
+                continue;
+            else if (filter_type == "url")
+                addFilter(urlFilter(config, prefix + ".url"));
+            else if (filter_type == "headers")
+                addFilter(headersFilter(config, prefix + ".headers"));
+            else if (filter_type == "methods")
+                addFilter(methodsFilter(config, prefix + ".methods"));
+            else
+                throw Exception(ErrorCodes::UNKNOWN_ELEMENT_IN_CONFIG, "Unknown element in config: {}.{}", prefix, filter_type);
+        }
+    }
+
+    /// Accept only requests whose URI is exactly `strict_path`
+    void attachStrictPath(const String & strict_path)
+    {
+        addFilter([strict_path](const auto & request) { return request.getURI() == strict_path; });
+    }
+
+    /// Accept only requests whose URI starts with `non_strict_path`
+    void attachNonStrictPath(const String & non_strict_path)
+    {
+        addFilter([non_strict_path](const auto & request) { return startsWith(request.getURI(), non_strict_path); });
+    }
+
+    /// Handle GET or HEAD endpoint on specified path
+    void allowGetAndHeadRequest()
+    {
+        addFilter([](const auto & request)
+        {
+            return request.getMethod() == Poco::Net::HTTPRequest::HTTP_GET
+                || request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD;
+        });
+    }
+
+    /// Handle Post request or (Get or Head) with params or OPTIONS requests
+    void allowPostAndGetParamsAndOptionsRequest()
+    {
+        addFilter([](const auto & request)
+        {
+            return (request.getURI().find('?') != std::string::npos
+                && (request.getMethod() == Poco::Net::HTTPRequest::HTTP_GET
+                || request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD))
+                || request.getMethod() == Poco::Net::HTTPRequest::HTTP_OPTIONS
+                || request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST;
+        });
+    }
+
+    /// Returns a new TEndpoint handler when the request passes the filters, nullptr otherwise.
+    /// A factory that never received a filter accepts every request; invoking the
+    /// empty std::function instead would throw std::bad_function_call.
+    std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) override
+    {
+        if (filter && !filter(request))
+            return nullptr;
+
+        return creator();
+    }
+
+private:
+    Filter filter;
+    std::function<std::unique_ptr<HTTPRequestHandler> ()> creator;
+};
+
+HTTPRequestHandlerFactoryPtr createStaticHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ const std::string & config_prefix);
+
+HTTPRequestHandlerFactoryPtr createDynamicHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ const std::string & config_prefix);
+
+HTTPRequestHandlerFactoryPtr createPredefinedHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ const std::string & config_prefix);
+
+HTTPRequestHandlerFactoryPtr createReplicasStatusHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ const std::string & config_prefix);
+
+HTTPRequestHandlerFactoryPtr
+createPrometheusHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ AsynchronousMetrics & async_metrics,
+ const std::string & config_prefix);
+
+HTTPRequestHandlerFactoryPtr
+createPrometheusMainHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ AsynchronousMetrics & async_metrics,
+ const std::string & name);
+
+/// @param server - used in handlers to check IServer::isCancelled()
+/// @param config - not the same as server.config(), since it can be newer
+/// @param async_metrics - used for prometheus (in case of prometheus.asynchronous_metrics=true)
+HTTPRequestHandlerFactoryPtr createHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ AsynchronousMetrics & async_metrics,
+ const std::string & name);
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTPHandlerRequestFilter.h b/contrib/clickhouse/src/Server/HTTPHandlerRequestFilter.h
new file mode 100644
index 0000000000..25cbb95087
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTPHandlerRequestFilter.h
@@ -0,0 +1,102 @@
+#pragma once
+
+#include <Server/HTTP/HTTPServerRequest.h>
+#include <Common/Exception.h>
+#include <Common/StringUtils/StringUtils.h>
+#include <base/find_symbols.h>
+
+#include <re2/re2.h>
+#include <Poco/StringTokenizer.h>
+#include <Poco/Util/LayeredConfiguration.h>
+
+#include <unordered_map>
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int CANNOT_COMPILE_REGEXP;
+}
+
+using CompiledRegexPtr = std::shared_ptr<const re2::RE2>;
+
+/// Returns true when the whole of `match_str` matches `compiled_regex` (anchored at both ends).
+/// Only the boolean outcome is needed, so no submatches are requested from re2; this
+/// also avoids a variable-length array, which is not standard C++.
+static inline bool checkRegexExpression(std::string_view match_str, const CompiledRegexPtr & compiled_regex)
+{
+    return compiled_regex->Match({match_str.data(), match_str.size()}, 0, match_str.size(), re2::RE2::Anchor::ANCHOR_BOTH, nullptr, 0);
+}
+
+/// Checks `match_str` against one handling-rule expression: a full regex match when the
+/// expression carries a compiled pattern, plain string equality with its text otherwise.
+static inline bool checkExpression(std::string_view match_str, const std::pair<String, CompiledRegexPtr> & expression)
+{
+    const auto & [literal, regex] = expression;
+    return regex ? checkRegexExpression(match_str, regex) : match_str == literal;
+}
+
+/// Builds a request predicate from a comma-separated list of HTTP methods stored at `config_path`.
+/// Each list item is trimmed and upper-cased; the predicate is true when the
+/// request method equals one of the items.
+static inline auto methodsFilter(const Poco::Util::AbstractConfiguration & config, const std::string & config_path) /// NOLINT
+{
+    std::vector<String> methods;
+    Poco::StringTokenizer tokenizer(config.getString(config_path), ",");
+
+    for (const auto & token : tokenizer)
+        methods.emplace_back(Poco::toUpper(Poco::trim(token)));
+
+    /// Return a real bool (like the other filters) and stop at the first matching method
+    return [methods](const HTTPServerRequest & request)
+    {
+        return std::find(methods.begin(), methods.end(), request.getMethod()) != methods.end();
+    };
+}
+
+/// Turns the text of a handling rule into a (text, compiled regex) pair.
+/// Text starting with "regex:" has the remainder compiled with re2; any other
+/// text is paired with a null regex pointer.
+/// Throws CANNOT_COMPILE_REGEXP when re2 rejects the pattern.
+static inline auto getExpression(const std::string & expression)
+{
+    constexpr size_t regex_prefix_length = 6; /// Length of "regex:"
+
+    if (startsWith(expression, "regex:"))
+    {
+        auto compiled_regex = std::make_shared<const re2::RE2>(expression.substr(regex_prefix_length));
+        if (compiled_regex->ok())
+            return std::make_pair(expression, compiled_regex);
+
+        throw Exception(ErrorCodes::CANNOT_COMPILE_REGEXP, "cannot compile re2: {} for http handling rule, error: {}. "
+                        "Look at https://github.com/google/re2/wiki/Syntax for reference.",
+                        expression, compiled_regex->error());
+    }
+
+    return std::make_pair(expression, CompiledRegexPtr{});
+}
+
+/// Builds a request predicate from the expression stored at `config_path`.
+/// The expression is checked against the request URI up to (not including) the first '?'.
+static inline auto urlFilter(const Poco::Util::AbstractConfiguration & config, const std::string & config_path) /// NOLINT
+{
+    return [url_expression = getExpression(config.getString(config_path))](const HTTPServerRequest & request)
+    {
+        const auto & uri = request.getURI();
+        const char * path_begin = uri.data();
+        const char * path_end = find_first_symbols<'?'>(path_begin, path_begin + uri.size());
+
+        const std::string_view path(path_begin, path_end - path_begin);
+        return checkExpression(path, url_expression);
+    };
+}
+
+/// Builds a request predicate from the header expressions listed under `prefix`
+/// (one config key per header name). The predicate is true when every listed header
+/// satisfies its expression; a header missing from the request is checked as "".
+static inline auto headersFilter(const Poco::Util::AbstractConfiguration & config, const std::string & prefix) /// NOLINT
+{
+    std::unordered_map<String, std::pair<String, CompiledRegexPtr>> headers_expression;
+    Poco::Util::AbstractConfiguration::Keys headers_name;
+    config.keys(prefix, headers_name);
+
+    for (const auto & header_name : headers_name)
+    {
+        const auto & expression = getExpression(config.getString(prefix + "." + header_name));
+        checkExpression("", expression); /// Check expression syntax is correct
+        headers_expression.emplace(header_name, expression);
+    }
+
+    return [headers_expression](const HTTPServerRequest & request)
+    {
+        for (const auto & [header_name, header_expression] : headers_expression)
+        {
+            /// Copy the value: get() returns a reference, and for an absent header that
+            /// reference would point at the temporary default string, which is gone
+            /// once this statement ends.
+            const String header_value = request.get(header_name, "");
+            if (!checkExpression(std::string_view(header_value.data(), header_value.size()), header_expression))
+                return false;
+        }
+
+        return true;
+    };
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTPPathHints.cpp b/contrib/clickhouse/src/Server/HTTPPathHints.cpp
new file mode 100644
index 0000000000..51ef3eabff
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTPPathHints.cpp
@@ -0,0 +1,16 @@
+#include <Server/HTTPPathHints.h>
+
+namespace DB
+{
+
+/// Appends `http_path` to the stored list of paths.
+void HTTPPathHints::add(const String & http_path)
+{
+    http_paths.emplace_back(http_path);
+}
+
+/// Returns a copy of every path added so far, in insertion order.
+std::vector<String> HTTPPathHints::getAllRegisteredNames() const
+{
+    std::vector<String> registered_paths = http_paths;
+    return registered_paths;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTPPathHints.h b/contrib/clickhouse/src/Server/HTTPPathHints.h
new file mode 100644
index 0000000000..708816ebf0
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTPPathHints.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <base/types.h>
+
+#include <Common/NamePrompter.h>
+
+namespace DB
+{
+
+/// Stores HTTP paths and exposes them through the IHints interface.
+class HTTPPathHints : public IHints<1, HTTPPathHints>
+{
+public:
+    /// Remembers one more path.
+    void add(const String & http_path);
+
+    /// Returns all remembered paths.
+    std::vector<String> getAllRegisteredNames() const override;
+
+private:
+    std::vector<String> http_paths;
+};
+
+using HTTPPathHintsPtr = std::shared_ptr<HTTPPathHints>;
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.cpp b/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.cpp
new file mode 100644
index 0000000000..5481bcd508
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.cpp
@@ -0,0 +1,38 @@
+#include <Server/HTTPRequestHandlerFactoryMain.h>
+#include <Server/NotFoundHandler.h>
+
+#include <Common/logger_useful.h>
+
+namespace DB
+{
+
+/// Stores the factory name and obtains the Poco logger registered under that name.
+HTTPRequestHandlerFactoryMain::HTTPRequestHandlerFactoryMain(const std::string & name_)
+    : log{&Poco::Logger::get(name_)}
+    , name{name_}
+{
+}
+
+/// Logs the request and returns the handler of the first child factory that accepts it.
+/// When no child accepts it, GET, HEAD and POST requests get a NotFoundHandler carrying
+/// path hints; every other method gets nullptr.
+std::unique_ptr<HTTPRequestHandler> HTTPRequestHandlerFactoryMain::createRequestHandler(const HTTPServerRequest & request)
+{
+    LOG_TRACE(log, "HTTP Request for {}. Method: {}, Address: {}, User-Agent: {}{}, Content Type: {}, Transfer Encoding: {}, X-Forwarded-For: {}",
+        name, request.getMethod(), request.clientAddress().toString(), request.get("User-Agent", "(none)"),
+        (request.hasContentLength() ? (", Length: " + std::to_string(request.getContentLength())) : ("")),
+        request.getContentType(), request.getTransferEncoding(), request.get("X-Forwarded-For", "(none)"));
+
+    for (auto & child_factory : child_factories)
+    {
+        if (auto handler = child_factory->createRequestHandler(request))
+            return handler;
+    }
+
+    const auto & method = request.getMethod();
+    const bool reply_with_not_found = method == Poco::Net::HTTPRequest::HTTP_GET
+        || method == Poco::Net::HTTPRequest::HTTP_HEAD
+        || method == Poco::Net::HTTPRequest::HTTP_POST;
+
+    if (!reply_with_not_found)
+        return nullptr;
+
+    return std::make_unique<NotFoundHandler>(hints.getHints(request.getURI()));
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.h b/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.h
new file mode 100644
index 0000000000..07b278d831
--- /dev/null
+++ b/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandlerFactory.h>
+#include <Server/HTTPPathHints.h>
+
+#include <vector>
+
+namespace DB
+{
+
+/// Handle request using child handlers
+/// Handler factory that delegates request handling to the child factories registered in it.
+class HTTPRequestHandlerFactoryMain : public HTTPRequestHandlerFactory
+{
+public:
+    explicit HTTPRequestHandlerFactoryMain(const std::string & name_);
+
+    /// Registers one more child factory.
+    void addHandler(HTTPRequestHandlerFactoryPtr child_factory)
+    {
+        child_factories.push_back(child_factory);
+    }
+
+    /// Registers a path that can later be offered as a hint.
+    void addPathToHints(const std::string & http_path)
+    {
+        hints.add(http_path);
+    }
+
+    std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) override;
+
+private:
+    Poco::Logger * log;
+    std::string name;
+    HTTPPathHints hints;
+
+    std::vector<HTTPRequestHandlerFactoryPtr> child_factories;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/IServer.h b/contrib/clickhouse/src/Server/IServer.h
new file mode 100644
index 0000000000..c55b045d2a
--- /dev/null
+++ b/contrib/clickhouse/src/Server/IServer.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <Interpreters/Context_fwd.h>
+
+namespace Poco
+{
+
+namespace Util
+{
+class LayeredConfiguration;
+}
+
+class Logger;
+
+}
+
+
+namespace DB
+{
+
+class IServer /// Pure interface giving handlers access to server-wide objects.
+{
+public:
+ /// Returns the application's configuration.
+ virtual Poco::Util::LayeredConfiguration & config() const = 0;
+
+ /// Returns the application's logger.
+ virtual Poco::Logger & logger() const = 0;
+
+ /// Returns the application's global context.
+ virtual ContextMutablePtr context() const = 0;
+
+ /// Returns true if shutdown has been signaled.
+ virtual bool isCancelled() const = 0;
+
+ virtual ~IServer() = default; /// Virtual, so implementations can be destroyed through this interface
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.cpp b/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.cpp
new file mode 100644
index 0000000000..9741592868
--- /dev/null
+++ b/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.cpp
@@ -0,0 +1,163 @@
+#include <Server/InterserverIOHTTPHandler.h>
+
+#include <Server/IServer.h>
+
+#include <Compression/CompressedWriteBuffer.h>
+#include <IO/ReadBufferFromIStream.h>
+#include <Interpreters/Context.h>
+#include <Interpreters/InterserverIOHandler.h>
+#include <Server/HTTP/HTMLForm.h>
+#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
+#include <Common/setThreadName.h>
+#include <Common/logger_useful.h>
+
+#include <Poco/Net/HTTPBasicCredentials.h>
+#include <Poco/Util/LayeredConfiguration.h>
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int ABORTED;
+ extern const int TOO_MANY_SIMULTANEOUS_QUERIES;
+}
+
+/// Validates the credentials of an interserver request.
+/// Returns {message, success}; `message` explains a failure produced here, otherwise the
+/// pair comes straight from the server credentials' isValidUser().
+std::pair<String, bool> InterserverIOHTTPHandler::checkAuthentication(HTTPServerRequest & request) const
+{
+    auto server_credentials = server.context()->getInterserverCredentials();
+    const bool client_sent_credentials = request.hasCredentials();
+
+    /// Server has no credentials configured
+    if (!server_credentials)
+    {
+        if (client_sent_credentials)
+            return {"Client requires HTTP Basic authentication, but server doesn't provide it", false};
+
+        return {"", true};
+    }
+
+    /// Server has credentials, client sent none: validate the empty user
+    if (!client_sent_credentials)
+        return server_credentials->isValidUser("", "");
+
+    String scheme;
+    String info;
+    request.getCredentials(scheme, info);
+
+    if (scheme != "Basic")
+        return {"Server requires HTTP Basic authentication but client provides another method", false};
+
+    Poco::Net::HTTPBasicCredentials basic_credentials(info);
+    return server_credentials->isValidUser(basic_credentials.getUsername(), basic_credentials.getPassword());
+}
+
+/// Passes the request to the interserver endpoint named by the "endpoint" parameter.
+/// The endpoint output is wrapped in a CompressedWriteBuffer when "compress" is "true".
+/// Throws ABORTED when the endpoint's blocker reports cancellation.
+void InterserverIOHTTPHandler::processQuery(HTTPServerRequest & request, HTTPServerResponse & response, Output & used_output)
+{
+    HTMLForm params(server.context()->getSettingsRef(), request);
+
+    LOG_TRACE(log, "Request URI: {}", request.getURI());
+
+    String endpoint_name = params.get("endpoint");
+    const bool compress = params.get("compress") == "true";
+
+    auto & body = request.getStream();
+
+    auto endpoint = server.context()->getInterserverIOHandler().getEndpoint(endpoint_name);
+    /// The endpoint stays locked for read until processing is over
+    std::shared_lock lock(endpoint->rwlock);
+    if (endpoint->blocker.isCancelled())
+        throw Exception(ErrorCodes::ABORTED, "Transferring part to replica was cancelled");
+
+    if (!compress)
+    {
+        endpoint->processQuery(params, body, *used_output.out, response);
+        return;
+    }
+
+    CompressedWriteBuffer compressed_out(*used_output.out);
+    endpoint->processQuery(params, body, compressed_out, response);
+}
+
+
+void InterserverIOHTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) /// Entry point: authenticates the request, runs it and reports failures to the client
+{
+ setThreadName("IntersrvHandler");
+ ThreadStatus thread_status;
+
+ /// Chunked transfer encoding is needed for keep-alive to work.
+ if (request.getVersion() == HTTPServerRequest::HTTP_1_1)
+ response.setChunkedTransferEncoding(true);
+
+ Output used_output;
+ const auto & config = server.config();
+ unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10); /// 10 is the fallback when the key is absent
+ used_output.out = std::make_shared<WriteBufferFromHTTPServerResponse>(
+ response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
+
+ auto write_response = [&](const std::string & message) /// Writes `message` as the body unless the response was already sent; always finalizes the buffer
+ {
+ auto & out = *used_output.out;
+ if (response.sent())
+ {
+ out.finalize();
+ return;
+ }
+
+ try
+ {
+ writeString(message, out);
+ out.finalize();
+ }
+ catch (...)
+ {
+ tryLogCurrentException(log); /// The write failure is logged, not rethrown
+ out.finalize();
+ }
+ };
+
+ try
+ {
+ if (auto [message, success] = checkAuthentication(request); success)
+ {
+ processQuery(request, response, used_output);
+ used_output.out->finalize();
+ LOG_DEBUG(log, "Done processing query");
+ }
+ else
+ {
+ response.setStatusAndReason(HTTPServerResponse::HTTP_UNAUTHORIZED);
+ write_response(message); /// `message` is the failure text returned by checkAuthentication()
+ LOG_WARNING(log, "Query processing failed request: '{}' authentication failed", request.getURI());
+ }
+ }
+ catch (Exception & e)
+ {
+ if (e.code() == ErrorCodes::TOO_MANY_SIMULTANEOUS_QUERIES) /// This code is not reported: no status change, no body, no log line
+ {
+ used_output.out->finalize();
+ return;
+ }
+
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
+
+ /// Sending to remote server was cancelled due to server shutdown or drop table.
+ bool is_real_error = e.code() != ErrorCodes::ABORTED;
+
+ PreformattedMessage message = getCurrentExceptionMessageAndPattern(is_real_error);
+ write_response(message.text);
+
+ if (is_real_error)
+ LOG_ERROR(log, message);
+ else
+ LOG_INFO(log, message); /// ABORTED is logged at INFO level only
+ }
+ catch (...)
+ {
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
+ PreformattedMessage message = getCurrentExceptionMessageAndPattern(/* with_stacktrace */ false);
+ write_response(message.text);
+
+ LOG_ERROR(log, message);
+ }
+}
+
+
+}
diff --git a/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.h b/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.h
new file mode 100644
index 0000000000..da5b286b9e
--- /dev/null
+++ b/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.h
@@ -0,0 +1,51 @@
+#pragma once
+
+#include <Interpreters/InterserverCredentials.h>
+#include <Server/HTTP/HTTPRequestHandler.h>
+#include <Common/CurrentMetrics.h>
+
+#include <Poco/Logger.h>
+
+#include <memory>
+#include <string>
+
+
+namespace CurrentMetrics
+{
+ extern const Metric InterserverConnection;
+}
+
+namespace DB
+{
+
+class IServer;
+class WriteBufferFromHTTPServerResponse;
+
+class InterserverIOHTTPHandler : public HTTPRequestHandler /// HTTP request handler for interserver endpoints
+{
+public:
+ explicit InterserverIOHTTPHandler(IServer & server_)
+ : server(server_)
+ , log(&Poco::Logger::get("InterserverIOHTTPHandler"))
+ {
+ }
+
+ void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
+
+private:
+ struct Output /// Holder of the response write buffer shared by the private helpers
+ {
+ std::shared_ptr<WriteBufferFromHTTPServerResponse> out;
+ };
+
+ IServer & server;
+ Poco::Logger * log;
+
+ CurrentMetrics::Increment metric_increment{CurrentMetrics::InterserverConnection}; /// Tied to this handler object's lifetime
+
+ void processQuery(HTTPServerRequest & request, HTTPServerResponse & response, Output & used_output);
+
+ std::pair<String, bool> checkAuthentication(HTTPServerRequest & request) const; /// Returns {message, success}
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/KeeperTCPHandler.cpp b/contrib/clickhouse/src/Server/KeeperTCPHandler.cpp
new file mode 100644
index 0000000000..84ed738850
--- /dev/null
+++ b/contrib/clickhouse/src/Server/KeeperTCPHandler.cpp
@@ -0,0 +1,695 @@
+#include <Server/KeeperTCPHandler.h>
+
+#if USE_NURAFT
+
+#include <Common/ZooKeeper/ZooKeeperIO.h>
+#include <Core/Types.h>
+#include <IO/WriteBufferFromPocoSocket.h>
+#include <IO/ReadBufferFromPocoSocket.h>
+#include <Poco/Net/NetException.h>
+#include <Common/CurrentThread.h>
+#include <Common/Stopwatch.h>
+#include <Common/NetException.h>
+#include <Common/setThreadName.h>
+#include <Common/logger_useful.h>
+#include <base/defines.h>
+#include <chrono>
+#include <Common/PipeFDs.h>
+#include <Poco/Util/AbstractConfiguration.h>
+#include <IO/ReadBufferFromFileDescriptor.h>
+#include <queue>
+#include <mutex>
+#include <Coordination/FourLetterCommand.h>
+#include <base/hex.h>
+
+
+#ifdef POCO_HAVE_FD_EPOLL
+ #include <sys/epoll.h>
+#else
+ #include <poll.h>
+#endif
+
+
+namespace DB
+{
+
+/// Record describing an operation; the defaults are the "nothing recorded yet" values.
+struct LastOp
+{
+    String name = "NA";
+    int64_t last_cxid = -1;
+    int64_t last_zxid = -1;
+    int64_t last_response_time = 0;
+};
+
+/// Every field keeps its default value
+static const LastOp EMPTY_LAST_OP{};
+
+namespace ErrorCodes
+{
+ extern const int SYSTEM_ERROR;
+ extern const int LOGICAL_ERROR;
+ extern const int UNEXPECTED_PACKET_FROM_CLIENT;
+ extern const int TIMEOUT_EXCEEDED;
+}
+
+/// Outcome of one poll: number of signalled responses, whether the socket
+/// has data to read, and whether polling itself failed.
+struct PollResult
+{
+    size_t responses_count = 0;
+    bool has_requests = false;
+    bool error = false;
+};
+
+/// Waits on two descriptors at once: the client socket and the read end of an internal pipe.
+/// One byte written to getResponseFD() wakes up poll() and is counted as one response.
+/// Uses epoll when POCO_HAVE_FD_EPOLL is defined and poll(2) otherwise.
+struct SocketInterruptablePollWrapper
+{
+    int sockfd;
+    PipeFDs pipe;
+    ReadBufferFromFileDescriptor response_in;
+
+#if defined(POCO_HAVE_FD_EPOLL)
+    int epollfd;
+    epoll_event socket_event{};
+    epoll_event pipe_event{};
+#endif
+
+    using InterruptCallback = std::function<void()>;
+
+    /// Makes the pipe non-blocking and, with epoll, registers both descriptors.
+    /// Throws SYSTEM_ERROR (from errno) when the epoll instance cannot be set up.
+    explicit SocketInterruptablePollWrapper(const Poco::Net::StreamSocket & poco_socket_)
+        : sockfd(poco_socket_.impl()->sockfd())
+        , response_in(pipe.fds_rw[0])
+    {
+        pipe.setNonBlockingReadWrite();
+
+#if defined(POCO_HAVE_FD_EPOLL)
+        epollfd = epoll_create(2);
+        if (epollfd < 0)
+            throwFromErrno("Cannot epoll_create", ErrorCodes::SYSTEM_ERROR);
+
+        socket_event.events = EPOLLIN | EPOLLERR | EPOLLPRI;
+        socket_event.data.fd = sockfd;
+        if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, &socket_event) < 0)
+        {
+            /// The destructor does not run when the constructor throws, so close here
+            int err = ::close(epollfd);
+            chassert(!err || errno == EINTR);
+
+            throwFromErrno("Cannot insert socket into epoll queue", ErrorCodes::SYSTEM_ERROR);
+        }
+        pipe_event.events = EPOLLIN | EPOLLERR | EPOLLPRI;
+        pipe_event.data.fd = pipe.fds_rw[0];
+        if (epoll_ctl(epollfd, EPOLL_CTL_ADD, pipe.fds_rw[0], &pipe_event) < 0)
+        {
+            int err = ::close(epollfd);
+            chassert(!err || errno == EINTR);
+
+            throwFromErrno("Cannot insert pipe into epoll queue", ErrorCodes::SYSTEM_ERROR);
+        }
+#endif
+    }
+
+    /// With epoll this object owns the epoll descriptor, so a copy would close it twice.
+    SocketInterruptablePollWrapper(const SocketInterruptablePollWrapper &) = delete;
+    SocketInterruptablePollWrapper & operator=(const SocketInterruptablePollWrapper &) = delete;
+
+    /// Write end of the pipe: writing a byte here wakes up poll().
+    int getResponseFD() const
+    {
+        return pipe.fds_rw[1];
+    }
+
+    /// Waits up to `remaining_time` for data on the socket or on the pipe.
+    /// Data already buffered in `in` or in the pipe reader counts as ready; buffered pipe
+    /// data skips the wait entirely. A wait interrupted by a signal is restarted with the
+    /// time already spent subtracted.
+    PollResult poll(Poco::Timespan remaining_time, const std::shared_ptr<ReadBufferFromPocoSocket> & in)
+    {
+        bool socket_ready = false;
+        bool fd_ready = false;
+
+        if (in->available() != 0)
+            socket_ready = true;
+
+        if (response_in.available() != 0)
+            fd_ready = true;
+
+        int rc = 0;
+        if (!fd_ready)
+        {
+#if defined(POCO_HAVE_FD_EPOLL)
+            epoll_event evout[2];
+            evout[0].data.fd = evout[1].data.fd = -1;
+            do
+            {
+                Poco::Timestamp start;
+                /// TODO: use epoll_pwait() for more precise timers
+                rc = epoll_wait(epollfd, evout, 2, static_cast<int>(remaining_time.totalMilliseconds()));
+                if (rc < 0 && errno == EINTR)
+                {
+                    Poco::Timestamp end;
+                    Poco::Timespan waited = end - start;
+                    if (waited < remaining_time)
+                        remaining_time -= waited;
+                    else
+                        remaining_time = 0;
+                }
+            }
+            while (rc < 0 && errno == EINTR);
+
+            for (int i = 0; i < rc; ++i)
+            {
+                if (evout[i].data.fd == sockfd)
+                    socket_ready = true;
+                if (evout[i].data.fd == pipe.fds_rw[0])
+                    fd_ready = true;
+            }
+#else
+            pollfd poll_buf[2];
+            poll_buf[0].fd = sockfd;
+            poll_buf[0].events = POLLIN;
+            poll_buf[1].fd = pipe.fds_rw[0];
+            poll_buf[1].events = POLLIN;
+
+            do
+            {
+                Poco::Timestamp start;
+                rc = ::poll(poll_buf, 2, static_cast<int>(remaining_time.totalMilliseconds()));
+                if (rc < 0 && errno == POCO_EINTR)
+                {
+                    Poco::Timestamp end;
+                    Poco::Timespan waited = end - start;
+                    if (waited < remaining_time)
+                        remaining_time -= waited;
+                    else
+                        remaining_time = 0;
+                }
+            }
+            while (rc < 0 && errno == POCO_EINTR);
+
+            if (rc >= 1)
+            {
+                if (poll_buf[0].revents & POLLIN)
+                    socket_ready = true;
+                if (poll_buf[1].revents & POLLIN)
+                    fd_ready = true;
+            }
+#endif
+        }
+
+        PollResult result{};
+        result.has_requests = socket_ready;
+        if (fd_ready)
+        {
+            /// Consume one byte, then drop whatever else is already buffered:
+            /// every byte stands for one response.
+            UInt8 dummy;
+            readIntBinary(dummy, response_in);
+            result.responses_count = 1;
+            auto available = response_in.available();
+            response_in.ignore(available);
+            result.responses_count += available;
+        }
+
+        /// Still negative here only when the wait failed with something other than EINTR
+        if (rc < 0)
+            result.error = true;
+
+        return result;
+    }
+
+#if defined(POCO_HAVE_FD_EPOLL)
+    ~SocketInterruptablePollWrapper()
+    {
+        int err = ::close(epollfd);
+        chassert(!err || errno == EINTR);
+    }
+#endif
+};
+
+KeeperTCPHandler::KeeperTCPHandler( /// Reads the timeouts from config, creates the poll wrapper and response queue, then registers the connection
+ const Poco::Util::AbstractConfiguration & config_ref,
+ std::shared_ptr<KeeperDispatcher> keeper_dispatcher_,
+ Poco::Timespan receive_timeout_,
+ Poco::Timespan send_timeout_,
+ const Poco::Net::StreamSocket & socket_)
+ : Poco::Net::TCPServerConnection(socket_)
+ , log(&Poco::Logger::get("KeeperTCPHandler"))
+ , keeper_dispatcher(keeper_dispatcher_)
+ , operation_timeout(
+ 0,
+ config_ref.getUInt(
+ "keeper_server.coordination_settings.operation_timeout_ms", Coordination::DEFAULT_OPERATION_TIMEOUT_MS) * 1000) /// ms * 1000; presumably the (seconds, microseconds) Timespan constructor — confirm
+ , min_session_timeout(
+ 0,
+ config_ref.getUInt(
+ "keeper_server.coordination_settings.min_session_timeout_ms", Coordination::DEFAULT_MIN_SESSION_TIMEOUT_MS) * 1000)
+ , max_session_timeout(
+ 0,
+ config_ref.getUInt(
+ "keeper_server.coordination_settings.session_timeout_ms", Coordination::DEFAULT_MAX_SESSION_TIMEOUT_MS) * 1000) /// Note: the upper bound is read from "session_timeout_ms"
+ , poll_wrapper(std::make_unique<SocketInterruptablePollWrapper>(socket_))
+ , send_timeout(send_timeout_)
+ , receive_timeout(receive_timeout_)
+ , responses(std::make_unique<ThreadSafeResponseQueue>(std::numeric_limits<size_t>::max())) /// The largest size_t is passed as the queue's size argument
+ , last_op(std::make_unique<LastOp>(EMPTY_LAST_OP))
+{
+ KeeperTCPHandler::registerConnection(this);
+}
+
+/// Writes the server handshake to the client: handshake length, protocol version
+/// (the reject marker when `has_leader` is false), session timeout in milliseconds,
+/// session id and a zero-filled password; then flushes the buffer with next().
+void KeeperTCPHandler::sendHandshake(bool has_leader)
+{
+    auto & handshake_out = *out;
+
+    Coordination::write(Coordination::SERVER_HANDSHAKE_LENGTH, handshake_out);
+    if (!has_leader)
+    {
+        /// Ignore connections if we are not leader, client will throw exception
+        /// and reconnect to another replica faster. ClickHouse client provide
+        /// clear message for such protocol version.
+        Coordination::write(Coordination::KEEPER_PROTOCOL_VERSION_CONNECTION_REJECT, handshake_out);
+    }
+    else
+    {
+        Coordination::write(Coordination::ZOOKEEPER_PROTOCOL_VERSION, handshake_out);
+    }
+
+    const auto session_timeout_ms = static_cast<int32_t>(session_timeout.totalMilliseconds());
+    Coordination::write(session_timeout_ms, handshake_out);
+    Coordination::write(session_id, handshake_out);
+
+    std::array<char, Coordination::PASSWORD_LENGTH> empty_password{};
+    Coordination::write(empty_password, handshake_out);
+
+    handshake_out.next();
+}
+
+/// Connection entry point; all the work is done by runImpl().
+void KeeperTCPHandler::run()
+{
+    this->runImpl();
+}
+
+/// Reads the rest of the client handshake (its length has already been read by the caller).
+/// Returns the session timeout requested by the client.
+/// Throws UNEXPECTED_PACKET_FROM_CLIENT on an unknown handshake length or protocol version.
+Poco::Timespan KeeperTCPHandler::receiveHandshake(int32_t handshake_length)
+{
+    int32_t protocol_version;
+    int64_t last_zxid_seen;
+    int32_t timeout_ms;
+    int64_t previous_session_id = 0; /// We don't support session restore. So previous session_id is always zero.
+    std::array<char, Coordination::PASSWORD_LENGTH> passwd {};
+
+    if (!isHandShake(handshake_length))
+        throw Exception(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected handshake length received: {}", toString(handshake_length));
+
+    Coordination::read(protocol_version, *in);
+
+    if (protocol_version != Coordination::ZOOKEEPER_PROTOCOL_VERSION)
+        throw Exception(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected protocol version: {}", toString(protocol_version));
+
+    Coordination::read(last_zxid_seen, *in);
+    Coordination::read(timeout_ms, *in);
+
+    /// TODO Stop ignoring this value
+    Coordination::read(previous_session_id, *in);
+    Coordination::read(passwd, *in);
+
+    /// The flag is present only in the longer handshake; it is consumed and not used
+    int8_t readonly = 0;
+    if (handshake_length == Coordination::CLIENT_HANDSHAKE_LENGTH_WITH_READONLY)
+        Coordination::read(readonly, *in);
+
+    /// The timeout comes from the client: widen it before converting to microseconds,
+    /// otherwise values above ~2147483 ms overflow the 32-bit multiplication.
+    return Poco::Timespan(static_cast<Poco::Timespan::TimeDiff>(timeout_ms) * Poco::Timespan::MILLISECONDS);
+}
+
+
+void KeeperTCPHandler::runImpl() /// Serves one client connection: four-letter command or handshake, then the request/response loop
+{
+ setThreadName("KeeperHandler");
+ ThreadStatus thread_status;
+
+ socket().setReceiveTimeout(receive_timeout);
+ socket().setSendTimeout(send_timeout);
+ socket().setNoDelay(true);
+
+ in = std::make_shared<ReadBufferFromPocoSocket>(socket());
+ out = std::make_shared<WriteBufferFromPocoSocket>(socket());
+
+ if (in->eof())
+ {
+ LOG_WARNING(log, "Client has not sent any data.");
+ return;
+ }
+
+ int32_t header; /// First 4 bytes: either a four-letter command code or the handshake length
+ try
+ {
+ Coordination::read(header, *in);
+ }
+ catch (const Exception & e)
+ {
+ LOG_WARNING(log, "Error while read connection header {}", e.displayText());
+ return;
+ }
+
+ /// Every four-letter-word command code is larger than 2^24 or lower than 0.
+ /// The handshake package length must be lower than 2^24 and larger than 0.
+ /// So a collision never happens.
+ int32_t four_letter_cmd = header;
+ if (!isHandShake(four_letter_cmd))
+ {
+ connected.store(true, std::memory_order_relaxed);
+ tryExecuteFourLetterWordCmd(four_letter_cmd);
+ return;
+ }
+
+ try
+ {
+ int32_t handshake_length = header;
+ auto client_timeout = receiveHandshake(handshake_length);
+
+ if (client_timeout.totalMilliseconds() == 0) /// A zero timeout from the client is replaced by the default
+ client_timeout = Poco::Timespan(Coordination::DEFAULT_SESSION_TIMEOUT_MS * Poco::Timespan::MILLISECONDS);
+ session_timeout = std::max(client_timeout, min_session_timeout); /// Clamp the timeout into [min_session_timeout, max_session_timeout]
+ session_timeout = std::min(session_timeout, max_session_timeout);
+ }
+ catch (const Exception & e) /// Typical for an incorrect username, password, or address.
+ {
+ LOG_WARNING(log, "Cannot receive handshake {}", e.displayText());
+ return;
+ }
+
+ if (keeper_dispatcher->isServerActive())
+ {
+ try
+ {
+ LOG_INFO(log, "Requesting session ID for the new client");
+ session_id = keeper_dispatcher->getSessionID(session_timeout.totalMilliseconds());
+ LOG_INFO(log, "Received session ID {}", session_id);
+ }
+ catch (const Exception & e)
+ {
+ LOG_WARNING(log, "Cannot receive session id {}", e.displayText());
+ sendHandshake(false);
+ return;
+
+ }
+
+ sendHandshake(true);
+ }
+ else
+ {
+ LOG_WARNING(log, "Ignoring user request, because the server is not active yet");
+ sendHandshake(false);
+ return;
+ }
+
+ auto response_fd = poll_wrapper->getResponseFD();
+ auto response_callback = [responses = this->responses, response_fd](const Coordination::ZooKeeperResponsePtr & response) /// Queues the response, then wakes up the poll loop through the pipe
+ {
+ if (!responses->push(response))
+ throw Exception(ErrorCodes::SYSTEM_ERROR,
+ "Could not push response with xid {} and zxid {}",
+ response->xid,
+ response->zxid);
+
+ UInt8 single_byte = 1;
+ [[maybe_unused]] ssize_t result = write(response_fd, &single_byte, sizeof(single_byte)); /// One byte per queued response; the result of write() is not checked
+ };
+ keeper_dispatcher->registerSession(session_id, response_callback);
+
+ Stopwatch logging_stopwatch;
+ auto log_long_operation = [&](const String & operation) /// Logs steps that took longer than 500 ms and restarts the stopwatch
+ {
+ constexpr UInt64 operation_max_ms = 500;
+ auto elapsed_ms = logging_stopwatch.elapsedMilliseconds();
+ if (operation_max_ms < elapsed_ms)
+ LOG_TEST(log, "{} for session {} took {} ms", operation, session_id, elapsed_ms);
+ logging_stopwatch.restart();
+ };
+
+ session_stopwatch.start();
+ connected.store(true, std::memory_order_release);
+ bool close_received = false; /// Once set, no further requests are read from the socket
+
+ try
+ {
+ while (true)
+ {
+ using namespace std::chrono_literals;
+
+ PollResult result = poll_wrapper->poll(session_timeout, in);
+ log_long_operation("Polling socket");
+ if (result.has_requests && !close_received)
+ {
+ if (in->eof())
+ {
+ LOG_DEBUG(log, "Client closed connection, session id #{}", session_id);
+ keeper_dispatcher->finishSession(session_id);
+ break;
+ }
+
+ auto [received_op, received_xid] = receiveRequest();
+ packageReceived();
+ log_long_operation("Receiving request");
+
+ if (received_op == Coordination::OpNum::Close)
+ {
+ LOG_DEBUG(log, "Received close event with xid {} for session id #{}", received_xid, session_id);
+ close_xid = received_xid; /// The response carrying this xid ends the session, see below
+ close_received = true;
+ }
+ else if (received_op == Coordination::OpNum::Heartbeat)
+ {
+ LOG_TRACE(log, "Received heartbeat for session #{}", session_id);
+ }
+ else
+ operations[received_xid] = Poco::Timestamp(); /// Record when this request was received, keyed by xid
+
+ /// Each request restarts session stopwatch
+ session_stopwatch.restart();
+ }
+
+ /// Process exactly as many responses as were signalled through the pipe,
+ /// otherwise the state of the response queue and the signaling pipe
+ /// becomes inconsistent and a race condition is possible.
+ while (result.responses_count != 0)
+ {
+ Coordination::ZooKeeperResponsePtr response;
+
+ if (!responses->tryPop(response))
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "We must have ready response, but queue is empty. It's a bug.");
+ log_long_operation("Waiting for response to be ready");
+
+ if (response->xid == close_xid) /// The response to the close request is not written to the client
+ {
+ LOG_DEBUG(log, "Session #{} successfully closed", session_id);
+ return;
+ }
+
+ updateStats(response);
+ packageSent();
+
+ response->write(*out);
+ log_long_operation("Sending response");
+ if (response->error == Coordination::Error::ZSESSIONEXPIRED)
+ {
+ LOG_DEBUG(log, "Session #{} expired because server shutting down or quorum is not alive", session_id);
+ keeper_dispatcher->finishSession(session_id);
+ return;
+ }
+
+ result.responses_count--;
+ }
+
+ if (result.error)
+ throw Exception(ErrorCodes::SYSTEM_ERROR, "Exception happened while reading from socket");
+
+ if (session_stopwatch.elapsedMicroseconds() > static_cast<UInt64>(session_timeout.totalMicroseconds())) /// No request arrived within the session timeout
+ {
+ LOG_DEBUG(log, "Session #{} expired", session_id);
+ keeper_dispatcher->finishSession(session_id);
+ break;
+ }
+ }
+ }
+ catch (const Exception & ex)
+ {
+ log_long_operation("Unknown operation");
+ LOG_TRACE(log, "Has {} responses in the queue", responses->size());
+ LOG_INFO(log, "Got exception processing session #{}: {}", session_id, getExceptionMessage(ex, true));
+ keeper_dispatcher->finishSession(session_id);
+ }
+}
+
+/// Tells whether `handshake_length` is one of the two accepted client handshake lengths:
+/// CLIENT_HANDSHAKE_LENGTH or CLIENT_HANDSHAKE_LENGTH_WITH_READONLY.
+bool KeeperTCPHandler::isHandShake(int32_t handshake_length)
+{
+    if (handshake_length == Coordination::CLIENT_HANDSHAKE_LENGTH)
+        return true;
+
+    return handshake_length == Coordination::CLIENT_HANDSHAKE_LENGTH_WITH_READONLY;
+}
+
+/// Runs the four letter command identified by `command` and writes its output to the client.
+/// Returns false when the command is unknown or not enabled, true otherwise
+/// (also when running the command threw: the error is only logged).
+bool KeeperTCPHandler::tryExecuteFourLetterWordCmd(int32_t command)
+{
+    /// Commands the factory does not know are rejected first.
+    if (!FourLetterCommandFactory::instance().isKnown(command))
+    {
+        LOG_WARNING(log, "invalid four letter command {}", IFourLetterCommand::toName(command));
+        return false;
+    }
+
+    /// Known commands still have to be enabled.
+    if (!FourLetterCommandFactory::instance().isEnabled(command))
+    {
+        LOG_WARNING(log, "Not enabled four letter command {}", IFourLetterCommand::toName(command));
+        return false;
+    }
+
+    auto command_ptr = FourLetterCommandFactory::instance().get(command);
+    LOG_DEBUG(log, "Receive four letter command {}", command_ptr->name());
+
+    try
+    {
+        const String output = command_ptr->run();
+        out->write(output.data(), output.size());
+        out->next();
+    }
+    catch (...)
+    {
+        tryLogCurrentException(log, "Error when executing four letter command " + command_ptr->name());
+    }
+
+    return true;
+}
+
+/// Reads one request from the client stream and hands it to the dispatcher.
+/// Returns the operation number and xid of the request that was read.
+/// Throws TIMEOUT_EXCEEDED when the dispatcher refuses the request.
+std::pair<Coordination::OpNum, Coordination::XID> KeeperTCPHandler::receiveRequest()
+{
+    /// Fields are consumed in the order they are read here: length, xid, operation number, body.
+    /// The length value is read but not used afterwards.
+    int32_t length;
+    Coordination::read(length, *in);
+
+    int32_t xid;
+    Coordination::read(xid, *in);
+
+    Coordination::OpNum opnum;
+    Coordination::read(opnum, *in);
+
+    Coordination::ZooKeeperRequestPtr request = Coordination::ZooKeeperRequestFactory::instance().get(opnum);
+    request->xid = xid;
+    request->readImpl(*in);
+
+    const bool accepted = keeper_dispatcher->putRequest(request, session_id);
+    if (!accepted)
+        throw Exception(ErrorCodes::TIMEOUT_EXCEEDED, "Session {} already disconnected", session_id);
+
+    return {opnum, xid};
+}
+
+/// Counts one sent packet, both in this connection's stats and in the dispatcher.
+void KeeperTCPHandler::packageSent()
+{
+ conn_stats.incrementPacketsSent();
+ keeper_dispatcher->incrementPacketsSent();
+}
+
+/// Counts one received packet, both in this connection's stats and in the dispatcher.
+void KeeperTCPHandler::packageReceived()
+{
+ conn_stats.incrementPacketsReceived();
+ keeper_dispatcher->incrementPacketsReceived();
+}
+
+/// Updates latency statistics and the "last operation" record for `response`.
+/// Watch responses (xid == WATCH_XID) and heartbeats are ignored.
+void KeeperTCPHandler::updateStats(Coordination::ZooKeeperResponsePtr & response)
+{
+    const bool is_watch = response->xid == Coordination::WATCH_XID;
+    if (is_watch || response->getOpNum() == Coordination::OpNum::Heartbeat)
+        return;
+
+    /// Difference between now and the timestamp stored for this xid, divided by 1000.
+    Int64 elapsed = (Poco::Timestamp() - operations[response->xid]) / 1000;
+    conn_stats.updateLatency(elapsed);
+
+    /// The operation is finished, so its start timestamp is no longer needed.
+    operations.erase(response->xid);
+    keeper_dispatcher->updateKeeperStatLatency(elapsed);
+
+    last_op.set(std::make_unique<LastOp>(LastOp{
+        .name = Coordination::toString(response->getOpNum()),
+        .last_cxid = response->xid,
+        .last_zxid = response->zxid,
+        .last_response_time = Poco::Timestamp().epochMicroseconds() / 1000,
+    }));
+}
+
+/// Returns a mutable reference to this connection's statistics object.
+KeeperConnectionStats & KeeperTCPHandler::getConnectionStats()
+{
+ return conn_stats;
+}
+
+/// Writes a one-line textual description of this connection to `buf`:
+/// " <peer address>(recved=..,sent=..[,sid=..,lop=..,...])\n".
+/// Nothing is written until the `connected` flag has been set.
+/// With `brief` set, only the packet counters are printed.
+void KeeperTCPHandler::dumpStats(WriteBufferFromOwnString & buf, bool brief)
+{
+ if (!connected.load(std::memory_order_acquire))
+ return;
+
+ auto & stats = getConnectionStats();
+
+ writeText(' ', buf);
+ writeText(socket().peerAddress().toString(), buf);
+ writeText("(recved=", buf);
+ writeIntText(stats.getPacketsReceived(), buf);
+ writeText(",sent=", buf);
+ writeIntText(stats.getPacketsSent(), buf);
+ if (!brief)
+ {
+ // Session details are printed only when session_id is non-zero.
+ if (session_id != 0)
+ {
+ writeText(",sid=0x", buf);
+ writeText(getHexUIntLowercase(session_id), buf);
+
+ writeText(",lop=", buf);
+ // Take one snapshot of the last operation and use it for all fields below.
+ LastOpPtr op = last_op.get();
+ writeText(op->name, buf);
+ writeText(",est=", buf);
+ writeIntText(established.epochMicroseconds() / 1000, buf);
+ writeText(",to=", buf);
+ writeIntText(session_timeout.totalMilliseconds(), buf);
+ int64_t last_cxid = op->last_cxid;
+ // Negative cxid values are not printed.
+ if (last_cxid >= 0)
+ {
+ writeText(",lcxid=0x", buf);
+ writeText(getHexUIntLowercase(last_cxid), buf);
+ }
+ writeText(",lzxid=0x", buf);
+ writeText(getHexUIntLowercase(op->last_zxid), buf);
+ writeText(",lresp=", buf);
+ writeIntText(op->last_response_time, buf);
+
+ writeText(",llat=", buf);
+ writeIntText(stats.getLastLatency(), buf);
+ writeText(",minlat=", buf);
+ writeIntText(stats.getMinLatency(), buf);
+ writeText(",avglat=", buf);
+ writeIntText(stats.getAvgLatency(), buf);
+ writeText(",maxlat=", buf);
+ writeIntText(stats.getMaxLatency(), buf);
+ }
+ }
+ writeText(')', buf);
+ writeText('\n', buf);
+}
+
+/// Resets the connection counters and replaces the last-operation record with EMPTY_LAST_OP.
+void KeeperTCPHandler::resetStats()
+{
+ conn_stats.reset();
+ last_op.set(std::make_unique<LastOp>(EMPTY_LAST_OP));
+}
+
+/// Removes this handler from the static set of connections.
+KeeperTCPHandler::~KeeperTCPHandler()
+{
+ KeeperTCPHandler::unregisterConnection(this);
+}
+
+std::mutex KeeperTCPHandler::conns_mutex;
+std::unordered_set<KeeperTCPHandler *> KeeperTCPHandler::connections;
+
+/// Adds `conn` to the static set of connections while holding conns_mutex.
+void KeeperTCPHandler::registerConnection(KeeperTCPHandler * conn)
+{
+ std::lock_guard lock(conns_mutex);
+ connections.insert(conn);
+}
+
+/// Removes `conn` from the static set of connections while holding conns_mutex.
+void KeeperTCPHandler::unregisterConnection(KeeperTCPHandler * conn)
+{
+ std::lock_guard lock(conns_mutex);
+ connections.erase(conn);
+}
+
+/// Appends the statistics of every registered connection to `buf`; conns_mutex is held for the whole walk.
+void KeeperTCPHandler::dumpConnections(WriteBufferFromOwnString & buf, bool brief)
+{
+    std::lock_guard guard(conns_mutex);
+    for (KeeperTCPHandler * handler : connections)
+        handler->dumpStats(buf, brief);
+}
+
+/// Resets the statistics of every registered connection; conns_mutex is held for the whole walk.
+void KeeperTCPHandler::resetConnsStats()
+{
+    std::lock_guard guard(conns_mutex);
+    for (KeeperTCPHandler * handler : connections)
+        handler->resetStats();
+}
+
+}
+
+#endif
diff --git a/contrib/clickhouse/src/Server/KeeperTCPHandler.h b/contrib/clickhouse/src/Server/KeeperTCPHandler.h
new file mode 100644
index 0000000000..653a599de6
--- /dev/null
+++ b/contrib/clickhouse/src/Server/KeeperTCPHandler.h
@@ -0,0 +1,113 @@
+#pragma once
+
+#include "clickhouse_config.h"
+
+#if USE_NURAFT
+
+#include <Poco/Net/TCPServerConnection.h>
+#include <Common/MultiVersion.h>
+#include "IServer.h"
+#include <Common/Stopwatch.h>
+#include <Common/ZooKeeper/ZooKeeperCommon.h>
+#include <Common/ZooKeeper/ZooKeeperConstants.h>
+#include <Common/ConcurrentBoundedQueue.h>
+#include <Coordination/KeeperDispatcher.h>
+#include <IO/WriteBufferFromPocoSocket.h>
+#include <IO/ReadBufferFromPocoSocket.h>
+#include <unordered_map>
+#include <Coordination/KeeperConnectionStats.h>
+#include <Poco/Timestamp.h>
+
+namespace DB
+{
+
+struct SocketInterruptablePollWrapper;
+using SocketInterruptablePollWrapperPtr = std::unique_ptr<SocketInterruptablePollWrapper>;
+
+using ThreadSafeResponseQueue = ConcurrentBoundedQueue<Coordination::ZooKeeperResponsePtr>;
+using ThreadSafeResponseQueuePtr = std::shared_ptr<ThreadSafeResponseQueue>;
+
+struct LastOp;
+using LastOpMultiVersion = MultiVersion<LastOp>;
+using LastOpPtr = LastOpMultiVersion::Version;
+
+/// TCP connection handler for Keeper clients.
+/// Besides per-connection state it keeps a static, mutex-protected set of handlers
+/// that is used to dump or reset statistics of all connections at once.
+class KeeperTCPHandler : public Poco::Net::TCPServerConnection
+{
+public:
+ static void registerConnection(KeeperTCPHandler * conn);
+ static void unregisterConnection(KeeperTCPHandler * conn);
+ /// dump all connections statistics
+ static void dumpConnections(WriteBufferFromOwnString & buf, bool brief);
+ /// reset statistics of all connections
+ static void resetConnsStats();
+
+private:
+ /// protects `connections`
+ static std::mutex conns_mutex;
+ /// all connections
+ static std::unordered_set<KeeperTCPHandler *> connections;
+
+public:
+ KeeperTCPHandler(
+ const Poco::Util::AbstractConfiguration & config_ref,
+ std::shared_ptr<KeeperDispatcher> keeper_dispatcher_,
+ Poco::Timespan receive_timeout_,
+ Poco::Timespan send_timeout_,
+ const Poco::Net::StreamSocket & socket_);
+ void run() override;
+
+ KeeperConnectionStats & getConnectionStats();
+ /// Writes this connection's statistics to `buf`; `brief` limits the output to packet counters.
+ void dumpStats(WriteBufferFromOwnString & buf, bool brief);
+ void resetStats();
+
+ ~KeeperTCPHandler() override;
+
+private:
+ Poco::Logger * log;
+ std::shared_ptr<KeeperDispatcher> keeper_dispatcher;
+ Poco::Timespan operation_timeout;
+ Poco::Timespan min_session_timeout;
+ Poco::Timespan max_session_timeout;
+ Poco::Timespan session_timeout;
+ int64_t session_id{-1};
+ /// Restarted on each received request; compared against session_timeout to expire the session.
+ Stopwatch session_stopwatch;
+ SocketInterruptablePollWrapperPtr poll_wrapper;
+ Poco::Timespan send_timeout;
+ Poco::Timespan receive_timeout;
+
+ /// Queue of responses to be written to the client.
+ ThreadSafeResponseQueuePtr responses;
+
+ /// xid of the close request; a response carrying this xid ends the session.
+ Coordination::XID close_xid = Coordination::CLOSE_XID;
+
+ /// Streams for reading/writing from/to client connection socket.
+ std::shared_ptr<ReadBufferFromPocoSocket> in;
+ std::shared_ptr<WriteBufferFromPocoSocket> out;
+
+ /// dumpStats() prints nothing until this flag is set.
+ std::atomic<bool> connected{false};
+
+ void runImpl();
+
+ void sendHandshake(bool has_leader);
+ Poco::Timespan receiveHandshake(int32_t handshake_length);
+
+ static bool isHandShake(int32_t handshake_length);
+ bool tryExecuteFourLetterWordCmd(int32_t command);
+
+ std::pair<Coordination::OpNum, Coordination::XID> receiveRequest();
+
+ void packageSent();
+ void packageReceived();
+
+ void updateStats(Coordination::ZooKeeperResponsePtr & response);
+
+ Poco::Timestamp established;
+
+ /// Receive timestamp of each in-flight request, keyed by xid; used to compute latency.
+ using Operations = std::unordered_map<Coordination::XID, Poco::Timestamp>;
+ Operations operations;
+
+ LastOpMultiVersion last_op;
+
+ KeeperConnectionStats conn_stats;
+
+};
+
+}
+#endif
diff --git a/contrib/clickhouse/src/Server/MySQLHandler.cpp b/contrib/clickhouse/src/Server/MySQLHandler.cpp
new file mode 100644
index 0000000000..f98b86e6cf
--- /dev/null
+++ b/contrib/clickhouse/src/Server/MySQLHandler.cpp
@@ -0,0 +1,508 @@
+#include "MySQLHandler.h"
+
+#include <limits>
+#include <Common/NetException.h>
+#include <Common/OpenSSLHelpers.h>
+#include <Core/MySQL/PacketsGeneric.h>
+#include <Core/MySQL/PacketsConnection.h>
+#include <Core/MySQL/PacketsProtocolText.h>
+#include <Core/NamesAndTypes.h>
+#include <Interpreters/Session.h>
+#include <Interpreters/executeQuery.h>
+#include <IO/copyData.h>
+#include <IO/LimitReadBuffer.h>
+#include <IO/ReadBufferFromPocoSocket.h>
+#include <IO/ReadBufferFromString.h>
+#include <IO/WriteBufferFromPocoSocket.h>
+#include <IO/WriteBufferFromString.h>
+#include <IO/ReadHelpers.h>
+#include <Server/TCPServer.h>
+#include <Storages/IStorage.h>
+#include <regex>
+#include <Common/setThreadName.h>
+#include <Core/MySQL/Authentication.h>
+#include <Common/logger_useful.h>
+#include <base/scope_guard.h>
+
+#include "config_version.h"
+
+#if USE_SSL
+# include <Poco/Crypto/RSAKey.h>
+# include <Poco/Net/SSLManager.h>
+# include <Poco/Net/SecureStreamSocket.h>
+
+#endif
+
+namespace DB
+{
+
+using namespace MySQLProtocol;
+using namespace MySQLProtocol::Generic;
+using namespace MySQLProtocol::ProtocolText;
+using namespace MySQLProtocol::ConnectionPhase;
+
+#if USE_SSL
+using Poco::Net::SecureStreamSocket;
+using Poco::Net::SSLManager;
+#endif
+
+namespace ErrorCodes
+{
+ extern const int CANNOT_READ_ALL_DATA;
+ extern const int NOT_IMPLEMENTED;
+ extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES;
+ extern const int SUPPORT_IS_DISABLED;
+ extern const int UNSUPPORTED_METHOD;
+}
+
+
+static const size_t PACKET_HEADER_SIZE = 4;
+static const size_t SSL_REQUEST_PAYLOAD_SIZE = 32;
+
+static String selectEmptyReplacementQuery(const String & query);
+static String showTableStatusReplacementQuery(const String & query);
+static String killConnectionIdReplacementQuery(const String & query);
+static String selectLimitReplacementQuery(const String & query);
+
+/// Sets up the handler: the advertised server capabilities (CLIENT_SSL only when `ssl_enabled`),
+/// the default Native41 authentication plugin, and the table of query-prefix replacements.
+MySQLHandler::MySQLHandler(
+    IServer & server_,
+    TCPServer & tcp_server_,
+    const Poco::Net::StreamSocket & socket_,
+    bool ssl_enabled, uint32_t connection_id_)
+    : Poco::Net::TCPServerConnection(socket_)
+    , server(server_)
+    , tcp_server(tcp_server_)
+    , log(&Poco::Logger::get("MySQLHandler"))
+    , connection_id(connection_id_)
+    , auth_plugin(std::make_unique<MySQLProtocol::Authentication::Native41>())
+{
+    server_capabilities = CLIENT_PROTOCOL_41
+        | CLIENT_SECURE_CONNECTION
+        | CLIENT_PLUGIN_AUTH
+        | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
+        | CLIENT_CONNECT_WITH_DB
+        | CLIENT_DEPRECATE_EOF;
+
+    if (ssl_enabled)
+        server_capabilities |= CLIENT_SSL;
+
+    /// Query prefix -> function producing the query that is executed instead.
+    replacements.emplace("KILL QUERY", killConnectionIdReplacementQuery);
+    replacements.emplace("SHOW TABLE STATUS LIKE", showTableStatusReplacementQuery);
+    replacements.emplace("SHOW VARIABLES", selectEmptyReplacementQuery);
+    replacements.emplace("SET SQL_SELECT_LIMIT", selectLimitReplacementQuery);
+}
+
+/// Serves one MySQL client connection: sends the handshake, authenticates the user,
+/// creates the session context, then reads and dispatches commands until COM_QUIT,
+/// the TCP server is closed, or an exception ends the connection.
+void MySQLHandler::run()
+{
+ setThreadName("MySQLHandler");
+ ThreadStatus thread_status;
+
+ session = std::make_unique<Session>(server.context(), ClientInfo::Interface::MYSQL);
+ SCOPE_EXIT({ session.reset(); });
+
+ session->setClientConnectionId(connection_id);
+
+ in = std::make_shared<ReadBufferFromPocoSocket>(socket());
+ out = std::make_shared<WriteBufferFromPocoSocket>(socket());
+ packet_endpoint = std::make_shared<MySQLProtocol::PacketEndpoint>(*in, *out, sequence_id);
+
+ try
+ {
+ Handshake handshake(server_capabilities, connection_id, VERSION_STRING + String("-") + VERSION_NAME,
+ auth_plugin->getName(), auth_plugin->getAuthPluginData(), CharacterSet::utf8_general_ci);
+ packet_endpoint->sendPacket<Handshake>(handshake, true);
+
+ LOG_TRACE(log, "Sent handshake");
+
+ HandshakeResponse handshake_response;
+ finishHandshake(handshake_response);
+ client_capabilities = handshake_response.capability_flags;
+ // A zero max_packet_size from the client falls back to MAX_PACKET_LENGTH.
+ max_packet_size = handshake_response.max_packet_size ? handshake_response.max_packet_size : MAX_PACKET_LENGTH;
+
+ LOG_TRACE(log,
+ "Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: {}",
+ handshake_response.capability_flags,
+ handshake_response.max_packet_size,
+ static_cast<int>(handshake_response.character_set),
+ handshake_response.username,
+ handshake_response.auth_response.length(),
+ handshake_response.database,
+ handshake_response.auth_plugin_name);
+
+ if (!(client_capabilities & CLIENT_PROTOCOL_41))
+ throw Exception(ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES, "Required capability: CLIENT_PROTOCOL_41.");
+
+ authenticate(handshake_response.username, handshake_response.auth_plugin_name, handshake_response.auth_response);
+
+ try
+ {
+ session->makeSessionContext();
+ session->sessionContext()->setDefaultFormat("MySQLWire");
+ if (!handshake_response.database.empty())
+ session->sessionContext()->setCurrentDatabase(handshake_response.database);
+ }
+ catch (const Exception & exc)
+ {
+ // NOTE(review): the error is reported to the client, but execution continues
+ // and an OK packet is still sent below - confirm this is intended.
+ log->log(exc);
+ packet_endpoint->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
+ }
+
+ OKPacket ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
+ packet_endpoint->sendPacket(ok_packet, true);
+
+ // Command loop: one iteration per client command.
+ while (tcp_server.isOpen())
+ {
+ packet_endpoint->resetSequenceId();
+ MySQLPacketPayloadReadBuffer payload = packet_endpoint->getPayload();
+
+ // Poll in slices so that a closed server is noticed between polls.
+ while (!in->poll(1000000))
+ if (!tcp_server.isOpen())
+ return;
+ char command = 0;
+ payload.readStrict(command);
+
+ // For commands which are executed without MemoryTracker.
+ LimitReadBuffer limited_payload(payload, 10000, /* throw_exception */ true, /* exact_limit */ {}, "too long MySQL packet.");
+
+ LOG_DEBUG(log, "Received command: {}. Connection id: {}.",
+ static_cast<int>(static_cast<unsigned char>(command)), connection_id);
+
+ if (!tcp_server.isOpen())
+ return;
+ try
+ {
+ switch (command)
+ {
+ case COM_QUIT:
+ return;
+ case COM_INIT_DB:
+ comInitDB(limited_payload);
+ break;
+ case COM_QUERY:
+ // COM_QUERY is the only command that reads the unlimited payload.
+ comQuery(payload);
+ break;
+ case COM_FIELD_LIST:
+ comFieldList(limited_payload);
+ break;
+ case COM_PING:
+ comPing();
+ break;
+ default:
+ throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Command {} is not implemented.", command);
+ }
+ }
+ catch (const NetException & exc)
+ {
+ // Network errors are logged and rethrown, which ends the connection.
+ log->log(exc);
+ throw;
+ }
+ catch (...)
+ {
+ // Any other error is reported to the client and the loop continues.
+ tryLogCurrentException(log, "MySQLHandler: Cannot read packet: ");
+ packet_endpoint->sendPacket(ERRPacket(getCurrentExceptionCode(), "00000", getCurrentExceptionMessage(false)), true);
+ }
+ }
+ }
+ catch (const Poco::Exception & exc)
+ {
+ log->log(exc);
+ }
+}
+
+/** Reads at least 3 bytes, finds out whether it is SSLRequest or HandshakeResponse packet, starts secure connection, if it is SSLRequest.
+  * Reading is performed from socket instead of ReadBuffer to prevent reading part of SSL handshake.
+  * If we read it from ReadBuffer, it will be impossible to start SSL connection using Poco. Size of SSLRequest packet payload is 32 bytes, thus we can read at most 36 bytes.
+  *
+  * On return `packet` holds the parsed HandshakeResponse.
+  * Throws CANNOT_READ_ALL_DATA when the socket yields no more data before enough bytes were read,
+  * or when more bytes were already received than the packet header declares.
+  */
+void MySQLHandler::finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResponse & packet)
+{
+    size_t packet_size = PACKET_HEADER_SIZE + SSL_REQUEST_PAYLOAD_SIZE;
+
+    /// Buffer for SSLRequest or part of HandshakeResponse.
+    /// Its size is a compile-time constant, so a fixed array is used instead of a variable-length array.
+    char buf[PACKET_HEADER_SIZE + SSL_REQUEST_PAYLOAD_SIZE];
+    size_t pos = 0;
+
+    /// Reads at least count and at most packet_size bytes.
+    auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void
+    {
+        while (pos < count)
+        {
+            int ret = socket().receiveBytes(buf + pos, static_cast<uint32_t>(packet_size - pos));
+            /// Zero means the peer closed the connection; a negative value must never be added to the unsigned `pos`.
+            if (ret <= 0)
+                throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Cannot read all data. Bytes read: {}. Bytes expected: {}", pos, count);
+            pos += static_cast<size_t>(ret);
+        }
+    };
+    read_bytes(3); /// We can find out whether it is SSLRequest of HandshakeResponse by first 3 bytes.
+
+    /// The low 3 bytes of the header are the payload length.
+    size_t payload_size = unalignedLoad<uint32_t>(buf) & 0xFFFFFFu;
+    LOG_TRACE(log, "payload size: {}", payload_size);
+
+    if (payload_size == SSL_REQUEST_PAYLOAD_SIZE)
+    {
+        finishHandshakeSSL(packet_size, buf, pos, read_bytes, packet);
+    }
+    else
+    {
+        /// Reading rest of HandshakeResponse.
+        packet_size = PACKET_HEADER_SIZE + payload_size;
+
+        /// Up to 36 bytes may already be in `buf`. If the declared packet is shorter than that,
+        /// `packet_size - pos` would wrap around and request a huge read.
+        if (packet_size < pos)
+            throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA,
+                "Cannot read handshake response: {} bytes were received, but the declared packet size is {}", pos, packet_size);
+
+        WriteBufferFromOwnString buf_for_handshake_response;
+        buf_for_handshake_response.write(buf, pos);
+        copyData(*packet_endpoint->in, buf_for_handshake_response, packet_size - pos);
+        ReadBufferFromString payload(buf_for_handshake_response.str());
+        payload.ignore(PACKET_HEADER_SIZE);
+        packet.readPayloadWithUnpacked(payload);
+        packet_endpoint->sequence_id++;
+    }
+}
+
+/// Authenticates `user_name` through the current auth plugin.
+/// On failure an ERR packet is sent to the client and the exception is rethrown.
+void MySQLHandler::authenticate(const String & user_name, const String & auth_plugin_name, const String & initial_auth_response)
+{
+ try
+ {
+ // For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
+ if (session->getAuthenticationTypeOrLogInFailure(user_name) == DB::AuthenticationType::SHA256_PASSWORD)
+ {
+ authPluginSSL();
+ }
+
+ // The client's initial response is passed on only if it was produced for the plugin that is going to be used.
+ std::optional<String> auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt;
+ auth_plugin->authenticate(user_name, *session, auth_response, packet_endpoint, secure_connection, socket().peerAddress());
+ }
+ catch (const Exception & exc)
+ {
+ LOG_ERROR(log, "Authentication for user {} failed.", user_name);
+ packet_endpoint->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
+ throw;
+ }
+ LOG_DEBUG(log, "Authentication for user {} succeeded.", user_name);
+}
+
+/// COM_INIT_DB: the whole payload is the database name; makes it the current database
+/// of the session and answers with an OK packet.
+void MySQLHandler::comInitDB(ReadBuffer & payload)
+{
+    String db_name;
+    readStringUntilEOF(db_name, payload);
+
+    LOG_DEBUG(log, "Setting current database to {}", db_name);
+    session->sessionContext()->setCurrentDatabase(db_name);
+
+    OKPacket ok(0, client_capabilities, 0, 0, 1);
+    packet_endpoint->sendPacket(ok, true);
+}
+
+/// COM_FIELD_LIST: sends one column definition per column of the requested table
+/// (looked up in the session's current database), followed by a terminating OK packet.
+void MySQLHandler::comFieldList(ReadBuffer & payload)
+{
+    ComFieldList request;
+    request.readPayloadWithUnpacked(payload);
+
+    const auto context = session->sessionContext();
+    String current_database = context->getCurrentDatabase();
+    StoragePtr storage = DatabaseCatalog::instance().getTable({current_database, request.table}, context);
+    auto metadata = storage->getInMemoryMetadataPtr();
+
+    /// Every column is reported with the same definition parameters (string type, binary character set).
+    for (const NameAndTypePair & col : metadata->getColumns().getAll())
+    {
+        ColumnDefinition definition(
+            current_database, request.table, request.table, col.name, col.name,
+            CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0, true);
+        packet_endpoint->sendPacket(definition);
+    }
+
+    OKPacket terminator(0xfe, client_capabilities, 0, 0, 0);
+    packet_endpoint->sendPacket(terminator, true);
+}
+
+/// COM_PING: answers with an OK packet.
+void MySQLHandler::comPing()
+{
+    OKPacket pong(0x0, client_capabilities, 0, 0, 0);
+    packet_endpoint->sendPacket(pong, true);
+}
+
+static bool isFederatedServerSetupSetCommand(const String & query);
+
+/// COM_QUERY: executes the query contained in `payload`.
+/// Some SET statements are acknowledged without being executed, and queries starting with
+/// one of the registered prefixes are rewritten before execution.
+void MySQLHandler::comQuery(ReadBuffer & payload)
+{
+ // Copy of the bytes currently available in the buffer; used only for prefix checks and rewriting.
+ String query = String(payload.position(), payload.buffer().end());
+
+ // This is a workaround in order to support adding ClickHouse to MySQL using federated server.
+ // As Clickhouse doesn't support these statements, we just send OK packet in response.
+ if (isFederatedServerSetupSetCommand(query))
+ {
+ packet_endpoint->sendPacket(OKPacket(0x00, client_capabilities, 0, 0, 0), true);
+ }
+ else
+ {
+ String replacement_query;
+ bool should_replace = false;
+ bool with_output = false;
+
+ // Case-insensitive prefix match against the registered replacements; the first hit wins.
+ for (auto const & x : replacements)
+ {
+ if (0 == strncasecmp(x.first.c_str(), query.c_str(), x.first.size()))
+ {
+ should_replace = true;
+ replacement_query = x.second(query);
+ break;
+ }
+ }
+
+ ReadBufferFromString replacement(replacement_query);
+
+ auto query_context = session->makeQueryContext();
+ query_context->setCurrentQueryId(fmt::format("mysql:{}:{}", connection_id, toString(UUIDHelpers::generateV4())));
+ CurrentThread::QueryScope query_scope{query_context};
+
+ // Accumulate written rows from progress updates, chaining to any previously installed callback.
+ std::atomic<size_t> affected_rows {0};
+ auto prev = query_context->getProgressCallback();
+ query_context->setProgressCallback([&, my_prev = prev](const Progress & progress)
+ {
+ if (my_prev)
+ my_prev(progress);
+
+ affected_rows += progress.written_rows;
+ });
+
+ FormatSettings format_settings;
+ format_settings.mysql_wire.client_capabilities = client_capabilities;
+ format_settings.mysql_wire.max_packet_size = max_packet_size;
+ format_settings.mysql_wire.sequence_id = &sequence_id;
+
+ // Remembers whether the query produces output and rejects any format other than MySQLWire.
+ auto set_result_details = [&with_output](const QueryResultDetails & details)
+ {
+ if (details.format)
+ {
+ if (*details.format != "MySQLWire")
+ throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats");
+
+ with_output = true;
+ }
+ };
+
+ executeQuery(should_replace ? replacement : payload, *out, false, query_context, set_result_details, format_settings);
+
+ // Queries without output are acknowledged with an OK packet carrying the affected row count.
+ if (!with_output)
+ packet_endpoint->sendPacket(OKPacket(0x00, client_capabilities, affected_rows, 0, 0), true);
+ }
+}
+
+/// Base implementation: always throws SUPPORT_IS_DISABLED.
+/// MySQLHandlerSSL overrides it when USE_SSL is set.
+void MySQLHandler::authPluginSSL()
+{
+ throw Exception(ErrorCodes::SUPPORT_IS_DISABLED,
+ "ClickHouse was built without SSL support. Try specifying password using double SHA1 in users.xml.");
+}
+
+/// Base implementation: ignores all arguments and always throws SUPPORT_IS_DISABLED.
+/// MySQLHandlerSSL overrides it when USE_SSL is set.
+void MySQLHandler::finishHandshakeSSL(
+ [[maybe_unused]] size_t packet_size, [[maybe_unused]] char * buf, [[maybe_unused]] size_t pos,
+ [[maybe_unused]] std::function<void(size_t)> read_bytes, [[maybe_unused]] MySQLProtocol::ConnectionPhase::HandshakeResponse & packet)
+{
+ throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "Client requested SSL, while it is disabled.");
+}
+
+#if USE_SSL
+/// Forwards the common arguments to MySQLHandler and stores references to the RSA key pair.
+/// The keys are held by reference, so they must outlive the handler.
+MySQLHandlerSSL::MySQLHandlerSSL(
+ IServer & server_,
+ TCPServer & tcp_server_,
+ const Poco::Net::StreamSocket & socket_,
+ bool ssl_enabled,
+ uint32_t connection_id_,
+ RSA & public_key_,
+ RSA & private_key_)
+ : MySQLHandler(server_, tcp_server_, socket_, ssl_enabled, connection_id_)
+ , public_key(public_key_)
+ , private_key(private_key_)
+{}
+
+/// Replaces the current auth plugin with Sha256Password built from the stored RSA key pair.
+void MySQLHandlerSSL::authPluginSSL()
+{
+ auth_plugin = std::make_unique<MySQLProtocol::Authentication::Sha256Password>(public_key, private_key, log);
+}
+
+/// Completes the handshake for a client that sent SSLRequest: reads the rest of the request,
+/// switches the connection to a secure socket and reads HandshakeResponse into `packet` from it.
+void MySQLHandlerSSL::finishHandshakeSSL(
+ size_t packet_size, char *buf, size_t pos, std::function<void(size_t)> read_bytes,
+ MySQLProtocol::ConnectionPhase::HandshakeResponse & packet)
+{
+ read_bytes(packet_size); /// Reading rest SSLRequest.
+ SSLRequest ssl_request;
+ // NOTE(review): `pos` is passed by value, so this is the count from before read_bytes() above - confirm the buffer length is as intended.
+ ReadBufferFromMemory payload(buf, pos);
+ payload.ignore(PACKET_HEADER_SIZE);
+ ssl_request.readPayloadWithUnpacked(payload);
+ client_capabilities = ssl_request.capability_flags;
+ max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
+ secure_connection = true;
+ // From here on all traffic goes through the secure socket, so the buffers and the endpoint are rebuilt on top of it.
+ ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
+ in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
+ out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
+ sequence_id = 2;
+ packet_endpoint = std::make_shared<MySQLProtocol::PacketEndpoint>(*in, *out, sequence_id);
+ packet_endpoint->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
+}
+
+#endif
+
+/// Tells whether `query` is one of the SET statements that comQuery() acknowledges
+/// without executing. The whole query must match, case-insensitively.
+static bool isFederatedServerSetupSetCommand(const String & query)
+{
+    static const std::regex federated_set_pattern{
+        "(^(SET NAMES(.*)))"
+        "|(^(SET character_set_results(.*)))"
+        "|(^(SET FOREIGN_KEY_CHECKS(.*)))"
+        "|(^(SET AUTOCOMMIT(.*)))"
+        "|(^(SET sql_mode(.*)))"
+        "|(^(SET @@(.*)))"
+        "|(^(SET SESSION TRANSACTION ISOLATION LEVEL(.*)))"
+        , std::regex::icase};
+
+    const bool matched = std::regex_match(query, federated_set_pattern);
+    return matched;
+}
+
+/// Replace the whole query (such as SHOW VARIABLES...) with "select ''"; the input is ignored.
+static String selectEmptyReplacementQuery(const String & /* query */)
+{
+    return "select ''";
+}
+
+/// Replace "SHOW TABLE STATUS LIKE 'xx'" into "SELECT ... FROM system.tables WHERE name LIKE 'xx'".
+/// A query that is not longer than the prefix is returned unchanged.
+static String showTableStatusReplacementQuery(const String & query)
+{
+    const String prefix = "SHOW TABLE STATUS LIKE ";
+
+    if (query.size() <= prefix.size())
+        return query;
+
+    /// Everything after the prefix is appended verbatim as the LIKE pattern.
+    const String like_pattern = query.data() + prefix.length();
+
+    String select =
+        "SELECT"
+        " name AS Name,"
+        " engine AS Engine,"
+        " '10' AS Version,"
+        " 'Dynamic' AS Row_format,"
+        " 0 AS Rows,"
+        " 0 AS Avg_row_length,"
+        " 0 AS Data_length,"
+        " 0 AS Max_data_length,"
+        " 0 AS Index_length,"
+        " 0 AS Data_free,"
+        " 'NULL' AS Auto_increment,"
+        " metadata_modification_time AS Create_time,"
+        " metadata_modification_time AS Update_time,"
+        " metadata_modification_time AS Check_time,"
+        " 'utf8_bin' AS Collation,"
+        " 'NULL' AS Checksum,"
+        " '' AS Create_options,"
+        " '' AS Comment"
+        " FROM system.tables"
+        " WHERE name LIKE ";
+    select += like_pattern;
+    return select;
+}
+
+/// Replace "SET SQL_SELECT_LIMIT<rest>" into "SET limit<rest>".
+/// The prefix is compared case-insensitively, the same way comQuery() selects this replacement;
+/// a case-sensitive check would leave e.g. "set sql_select_limit=10" unreplaced.
+/// A query that does not start with the prefix is returned unchanged.
+static String selectLimitReplacementQuery(const String & query)
+{
+    const String prefix = "SET SQL_SELECT_LIMIT";
+    if (query.size() >= prefix.size() && 0 == strncasecmp(query.c_str(), prefix.c_str(), prefix.size()))
+        return "SET limit" + std::string(query.data() + prefix.length());
+    return query;
+}
+
+/// Replace "KILL QUERY [connection_id]" into "KILL QUERY WHERE query_id LIKE 'mysql:[connection_id]:xxx'".
+/// The rewrite happens only when everything after the prefix is a decimal number;
+/// any other query is returned unchanged.
+static String killConnectionIdReplacementQuery(const String & query)
+{
+    const String prefix = "KILL QUERY ";
+    if (query.size() > prefix.size())
+    {
+        String suffix = query.data() + prefix.length();
+        /// regex_match requires the whole suffix to match, so the pattern must accept
+        /// one or more digits: a single "[0-9]" would only recognise ids 0-9.
+        static const std::regex expr{"[0-9]+"};
+        if (std::regex_match(suffix, expr))
+            return fmt::format("KILL QUERY WHERE query_id LIKE 'mysql:{}:%'", suffix);
+    }
+    return query;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/MySQLHandler.h b/contrib/clickhouse/src/Server/MySQLHandler.h
new file mode 100644
index 0000000000..a4de5c4590
--- /dev/null
+++ b/contrib/clickhouse/src/Server/MySQLHandler.h
@@ -0,0 +1,111 @@
+#pragma once
+
+#include <Poco/Net/TCPServerConnection.h>
+#include <base/getFQDNOrHostName.h>
+#include <Common/CurrentMetrics.h>
+#include <Core/MySQL/Authentication.h>
+#include <Core/MySQL/PacketsGeneric.h>
+#include <Core/MySQL/PacketsConnection.h>
+#include <Core/MySQL/PacketsProtocolText.h>
+#include "IServer.h"
+
+#include "clickhouse_config.h"
+
+#if USE_SSL
+# include <Poco/Net/SecureStreamSocket.h>
+#endif
+
+#include <memory>
+
+namespace CurrentMetrics
+{
+ extern const Metric MySQLConnection;
+}
+
+namespace DB
+{
+class ReadBufferFromPocoSocket;
+class TCPServer;
+
+/// Handler for MySQL wire protocol connections. Allows to connect to ClickHouse using MySQL client.
+/// SSL-specific steps are virtual and overridden by MySQLHandlerSSL when USE_SSL is set.
+class MySQLHandler : public Poco::Net::TCPServerConnection
+{
+public:
+ MySQLHandler(
+ IServer & server_,
+ TCPServer & tcp_server_,
+ const Poco::Net::StreamSocket & socket_,
+ bool ssl_enabled,
+ uint32_t connection_id_);
+
+ void run() final;
+
+protected:
+ CurrentMetrics::Increment metric_increment{CurrentMetrics::MySQLConnection};
+
+ /// Enables SSL, if client requested.
+ void finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResponse &);
+
+ /// Handlers of individual client commands.
+ void comQuery(ReadBuffer & payload);
+
+ void comFieldList(ReadBuffer & payload);
+
+ void comPing();
+
+ void comInitDB(ReadBuffer & payload);
+
+ void authenticate(const String & user_name, const String & auth_plugin_name, const String & auth_response);
+
+ /// SSL hooks; the implementations in this class throw SUPPORT_IS_DISABLED.
+ virtual void authPluginSSL();
+ virtual void finishHandshakeSSL(size_t packet_size, char * buf, size_t pos, std::function<void(size_t)> read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet);
+
+ IServer & server;
+ TCPServer & tcp_server;
+ Poco::Logger * log;
+ uint32_t connection_id = 0;
+
+ uint32_t server_capabilities = 0;
+ uint32_t client_capabilities = 0;
+ size_t max_packet_size = 0;
+ uint8_t sequence_id = 0;
+
+ MySQLProtocol::PacketEndpointPtr packet_endpoint;
+ std::unique_ptr<Session> session;
+
+ /// Query prefix -> function producing the query that is executed instead.
+ using ReplacementFn = std::function<String(const String & query)>;
+ using Replacements = std::unordered_map<std::string, ReplacementFn>;
+ Replacements replacements;
+
+ std::unique_ptr<MySQLProtocol::Authentication::IPlugin> auth_plugin;
+ std::shared_ptr<ReadBufferFromPocoSocket> in;
+ std::shared_ptr<WriteBuffer> out;
+ /// Set once the connection has been switched to a secure socket.
+ bool secure_connection = false;
+};
+
+#if USE_SSL
+/// MySQLHandler with working SSL hooks; compiled only when USE_SSL is set.
+class MySQLHandlerSSL : public MySQLHandler
+{
+public:
+ MySQLHandlerSSL(
+ IServer & server_,
+ TCPServer & tcp_server_,
+ const Poco::Net::StreamSocket & socket_,
+ bool ssl_enabled,
+ uint32_t connection_id_,
+ RSA & public_key_,
+ RSA & private_key_);
+
+private:
+ void authPluginSSL() override;
+
+ void finishHandshakeSSL(
+ size_t packet_size, char * buf, size_t pos,
+ std::function<void(size_t)> read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) override;
+
+ /// RSA key pair held by reference; the keys must outlive the handler.
+ RSA & public_key;
+ RSA & private_key;
+ /// Secure socket created when the client requests SSL.
+ std::shared_ptr<Poco::Net::SecureStreamSocket> ss;
+};
+#endif
+
+}
diff --git a/contrib/clickhouse/src/Server/MySQLHandlerFactory.cpp b/contrib/clickhouse/src/Server/MySQLHandlerFactory.cpp
new file mode 100644
index 0000000000..deadb10f9a
--- /dev/null
+++ b/contrib/clickhouse/src/Server/MySQLHandlerFactory.cpp
@@ -0,0 +1,140 @@
+#include "MySQLHandlerFactory.h"
+#include <Common/OpenSSLHelpers.h>
+#include <Poco/Net/TCPServerConnectionFactory.h>
+#include <Poco/Util/Application.h>
+#include <Common/logger_useful.h>
+#include <base/scope_guard.h>
+#include <Server/MySQLHandler.h>
+
+#if USE_SSL
+# include <Poco/Net/SSLManager.h>
+#endif
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int CANNOT_OPEN_FILE;
+ extern const int CANNOT_CLOSE_FILE;
+ extern const int NO_ELEMENTS_IN_CONFIG;
+ extern const int OPENSSL_ERROR;
+}
+
+MySQLHandlerFactory::MySQLHandlerFactory(IServer & server_) /// Binds the server and logger; under USE_SSL also probes the SSL context and prepares the RSA key pair.
+ : server(server_)
+ , log(&Poco::Logger::get("MySQLHandlerFactory"))
+{
+#if USE_SSL
+ try
+ {
+ Poco::Net::SSLManager::instance().defaultServerContext(); /// Return value is discarded: the call only serves as a probe that throws when no server context can be obtained.
+ }
+ catch (...)
+ {
+ LOG_TRACE(log, "Failed to create SSL context. SSL will be disabled. Error: {}", getCurrentExceptionMessage(false));
+ ssl_enabled = false; /// Handlers are still created; they are simply told that SSL is unavailable.
+ }
+ /// The RSA key pair is prepared regardless of whether the SSL context probe above succeeded.
+ /// Reading rsa keys for SHA256 authentication plugin.
+ try
+ {
+ readRSAKeys();
+ }
+ catch (...)
+ {
+ LOG_TRACE(log, "Failed to read RSA key pair from server certificate. Error: {}", getCurrentExceptionMessage(false));
+ generateRSAKeys(); /// Fallback: generate a fresh key pair. An exception thrown here propagates out of the constructor.
+ }
+#endif
+}
+
+#if USE_SSL
+/// Loads the RSA key pair from the files named by the `openSSL.server.certificateFile` and
+/// `openSSL.server.privateKeyFile` configuration keys: the public key is extracted from the PEM
+/// certificate, the private key is read from the PEM private key file.
+///
+/// Throws NO_ELEMENTS_IN_CONFIG when either key is absent from the configuration, CANNOT_OPEN_FILE
+/// when a file cannot be opened and OPENSSL_ERROR when OpenSSL cannot parse or convert a key.
+/// A failing fclose() is reported as CANNOT_CLOSE_FILE from the scope-exit handler.
+void MySQLHandlerFactory::readRSAKeys()
+{
+    const Poco::Util::LayeredConfiguration & config = Poco::Util::Application::instance().config();
+    String certificate_file_property = "openSSL.server.certificateFile";
+    String private_key_file_property = "openSSL.server.privateKeyFile";
+
+    if (!config.has(certificate_file_property))
+        throw Exception(ErrorCodes::NO_ELEMENTS_IN_CONFIG, "Certificate file is not set.");
+
+    if (!config.has(private_key_file_property))
+        throw Exception(ErrorCodes::NO_ELEMENTS_IN_CONFIG, "Private key file is not set.");
+
+    /// Public key: certificate file -> X509 -> EVP_PKEY -> RSA.
+    {
+        String certificate_file = config.getString(certificate_file_property);
+        FILE * fp = fopen(certificate_file.data(), "r");
+        if (fp == nullptr)
+            throw Exception(ErrorCodes::CANNOT_OPEN_FILE, "Cannot open certificate file: {}.", certificate_file);
+        SCOPE_EXIT(
+            if (0 != fclose(fp))
+                throwFromErrno("Cannot close file with the certificate in MySQLHandlerFactory", ErrorCodes::CANNOT_CLOSE_FILE);
+        );
+
+        X509 * x509 = PEM_read_X509(fp, nullptr, nullptr, nullptr);
+        /// Registered before the null check on purpose: X509_free accepts a null pointer.
+        SCOPE_EXIT(X509_free(x509));
+        if (x509 == nullptr)
+            throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to read PEM certificate from {}. Error: {}", certificate_file, getOpenSSLErrors());
+
+        EVP_PKEY * p = X509_get_pubkey(x509);
+        if (p == nullptr)
+            throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to get RSA key from X509. Error: {}", getOpenSSLErrors());
+        SCOPE_EXIT(EVP_PKEY_free(p));
+
+        public_key.reset(EVP_PKEY_get1_RSA(p));
+        if (public_key.get() == nullptr)
+            throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to get RSA key from EVP_PKEY. Error: {}", getOpenSSLErrors());
+    }
+
+    /// Private key: read directly from the PEM private key file.
+    {
+        String private_key_file = config.getString(private_key_file_property);
+
+        FILE * fp = fopen(private_key_file.data(), "r");
+        if (fp == nullptr)
+            throw Exception(ErrorCodes::CANNOT_OPEN_FILE, "Cannot open private key file {}.", private_key_file);
+        SCOPE_EXIT(
+            if (0 != fclose(fp))
+                throwFromErrno("Cannot close file with the private key in MySQLHandlerFactory", ErrorCodes::CANNOT_CLOSE_FILE);
+        );
+
+        private_key.reset(PEM_read_RSAPrivateKey(fp, nullptr, nullptr, nullptr));
+        if (!private_key)
+            throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to read RSA private key from {}. Error: {}", private_key_file, getOpenSSLErrors());
+    }
+}
+
+/// Generates a fresh 2048-bit RSA key pair (public exponent 65537) and stores it in `public_key`;
+/// `private_key` receives a duplicate of the same key. Throws OPENSSL_ERROR if any OpenSSL call fails.
+void MySQLHandlerFactory::generateRSAKeys()
+{
+    LOG_TRACE(log, "Generating new RSA key pair.");
+
+    public_key.reset(RSA_new());
+    if (!public_key)
+        throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to allocate RSA key. Error: {}", getOpenSSLErrors());
+
+    BIGNUM * exponent = BN_new();
+    if (!exponent)
+        throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to allocate BIGNUM. Error: {}", getOpenSSLErrors());
+    SCOPE_EXIT(BN_free(exponent));
+
+    /// Key generation is attempted only once the exponent has been set successfully.
+    if (!BN_set_word(exponent, 65537))
+        throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to generate RSA key. Error: {}", getOpenSSLErrors());
+    if (!RSA_generate_key_ex(public_key.get(), 2048, exponent, nullptr))
+        throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to generate RSA key. Error: {}", getOpenSSLErrors());
+
+    private_key.reset(RSAPrivateKey_dup(public_key.get()));
+    if (!private_key)
+        throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to copy RSA key. Error: {}", getOpenSSLErrors());
+}
+#endif
+
+/// Allocates a handler for a newly accepted MySQL connection and gives it the next connection id.
+/// Under USE_SSL the SSL-capable handler is created and receives the factory's RSA key pair by reference.
+Poco::Net::TCPServerConnection * MySQLHandlerFactory::createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server)
+{
+    /// Same effect as a post-increment: take the current id, then advance the counter.
+    const uint32_t connection_id = last_connection_id.fetch_add(1);
+    LOG_TRACE(log, "MySQL connection. Id: {}. Address: {}", connection_id, socket.peerAddress().toString());
+
+#if USE_SSL
+    return new MySQLHandlerSSL(server, tcp_server, socket, ssl_enabled, connection_id, *public_key, *private_key);
+#else
+    return new MySQLHandler(server, tcp_server, socket, ssl_enabled, connection_id);
+#endif
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/MySQLHandlerFactory.h b/contrib/clickhouse/src/Server/MySQLHandlerFactory.h
new file mode 100644
index 0000000000..edc53e5367
--- /dev/null
+++ b/contrib/clickhouse/src/Server/MySQLHandlerFactory.h
@@ -0,0 +1,50 @@
+#pragma once
+
+#include <atomic>
+#include <memory>
+#include <Server/IServer.h>
+#include <Server/TCPServerConnectionFactory.h>
+
+#include "clickhouse_config.h"
+
+#if USE_SSL
+# include <openssl/rsa.h>
+#endif
+
+namespace DB
+{
+class TCPServer;
+
+class MySQLHandlerFactory : public TCPServerConnectionFactory /// Creates one MySQL protocol handler per accepted connection.
+{
+private:
+ IServer & server;
+ Poco::Logger * log;
+ /// SSL-only state: the RSA key pair that is passed by reference to every MySQLHandlerSSL.
+#if USE_SSL
+ struct RSADeleter
+ {
+ void operator()(RSA * ptr) { RSA_free(ptr); } /// Releases the key through OpenSSL's RSA_free.
+ };
+ using RSAPtr = std::unique_ptr<RSA, RSADeleter>;
+ /// Filled by readRSAKeys() or, when that fails in the constructor, by generateRSAKeys().
+ RSAPtr public_key;
+ RSAPtr private_key;
+ /// Starts as true; the constructor clears it when the default SSL server context cannot be obtained.
+ bool ssl_enabled = true;
+#else
+ bool ssl_enabled = false;
+#endif
+ /// Source of connection ids: createConnection() post-increments it for each new connection.
+ std::atomic<unsigned> last_connection_id = 0;
+public:
+ explicit MySQLHandlerFactory(IServer & server_);
+ /// Loads the key pair from the configured certificate and private key files; throws on any failure.
+ void readRSAKeys();
+ /// Generates a new RSA key pair. NOTE: like readRSAKeys(), it is declared unconditionally but defined only under USE_SSL in MySQLHandlerFactory.cpp.
+ void generateRSAKeys();
+ /// Returns a handler allocated with `new` (MySQLHandlerSSL under USE_SSL, MySQLHandler otherwise).
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/NotFoundHandler.cpp b/contrib/clickhouse/src/Server/NotFoundHandler.cpp
new file mode 100644
index 0000000000..5b1db50855
--- /dev/null
+++ b/contrib/clickhouse/src/Server/NotFoundHandler.cpp
@@ -0,0 +1,31 @@
+#include <Server/NotFoundHandler.h>
+
+#include <IO/HTTPCommon.h>
+#include <Common/Exception.h>
+
+namespace DB
+{
+/// Answers with HTTP 404 and a plain-text help page. When `hints` is not empty its first entry is
+/// offered as a suggestion. Any exception is logged and swallowed.
+void NotFoundHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
+{
+    try
+    {
+        response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_NOT_FOUND);
+
+        auto body = response.send();
+        *body << "There is no handle " << request.getURI();
+
+        /// The suggestion is emitted only when there is at least one hint.
+        if (!hints.empty())
+            *body << fmt::format(". Maybe you meant {}.", hints.front());
+
+        *body << "\n\n"
+              << "Use / or /ping for health checks.\n"
+              << "Or /replicas_status for more sophisticated health checks.\n\n"
+              << "Send queries from your program with POST method or GET /?query=...\n\n"
+              << "Use clickhouse-client:\n\n"
+              << "For interactive data analysis:\n"
+              << " clickhouse-client\n\n"
+              << "For batch query processing:\n"
+              << " clickhouse-client --query='SELECT 1' > result\n"
+              << " clickhouse-client < query > result\n";
+    }
+    catch (...)
+    {
+        tryLogCurrentException("NotFoundHandler");
+    }
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/NotFoundHandler.h b/contrib/clickhouse/src/Server/NotFoundHandler.h
new file mode 100644
index 0000000000..1cbfcd57f8
--- /dev/null
+++ b/contrib/clickhouse/src/Server/NotFoundHandler.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandler.h>
+
+namespace DB
+{
+
+/// Handler that responds with HTTP 404 and a verbose plain-text description of the available endpoints.
+class NotFoundHandler : public HTTPRequestHandler
+{
+public:
+ NotFoundHandler(std::vector<std::string> hints_) : hints(std::move(hints_)) {} /// `hints_` is taken by value and moved into the member.
+ void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; /// Sends the 404 response; defined in NotFoundHandler.cpp.
+private:
+ std::vector<std::string> hints; /// Suggested handler names; handleRequest() mentions only the first one.
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/PostgreSQLHandler.cpp b/contrib/clickhouse/src/Server/PostgreSQLHandler.cpp
new file mode 100644
index 0000000000..7b07815425
--- /dev/null
+++ b/contrib/clickhouse/src/Server/PostgreSQLHandler.cpp
@@ -0,0 +1,329 @@
+#include <IO/ReadBufferFromPocoSocket.h>
+#include <IO/ReadHelpers.h>
+#include <IO/ReadBufferFromString.h>
+#include <IO/WriteBufferFromPocoSocket.h>
+#include <Interpreters/Context.h>
+#include <Interpreters/executeQuery.h>
+#include "PostgreSQLHandler.h"
+#include <Parsers/parseQuery.h>
+#include <Server/TCPServer.h>
+#include <Common/setThreadName.h>
+#include <base/scope_guard.h>
+#include <random>
+
+#include "config_version.h"
+
+#if USE_SSL
+# include <Poco/Net/SecureStreamSocket.h>
+# include <Poco/Net/SSLManager.h>
+#endif
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int SYNTAX_ERROR;
+}
+
+PostgreSQLHandler::PostgreSQLHandler( /// Stores the connection parameters and sets up I/O over the plain (not yet encrypted) socket.
+ const Poco::Net::StreamSocket & socket_,
+ IServer & server_,
+ TCPServer & tcp_server_,
+ bool ssl_enabled_,
+ Int32 connection_id_,
+ std::vector<std::shared_ptr<PostgreSQLProtocol::PGAuthentication::AuthenticationMethod>> & auth_methods_)
+ : Poco::Net::TCPServerConnection(socket_)
+ , server(server_)
+ , tcp_server(tcp_server_)
+ , ssl_enabled(ssl_enabled_)
+ , connection_id(connection_id_)
+ , authentication_manager(auth_methods_)
+{
+ changeIO(socket()); /// Creates `in`, `out` and `message_transport` on top of the connection's own socket.
+}
+
+void PostgreSQLHandler::changeIO(Poco::Net::StreamSocket & socket) /// (Re)creates the read buffer, write buffer and message transport on top of `socket`.
+{
+ in = std::make_shared<ReadBufferFromPocoSocket>(socket);
+ out = std::make_shared<WriteBufferFromPocoSocket>(socket);
+ message_transport = std::make_shared<PostgreSQLProtocol::Messaging::MessageTransport>(in.get(), out.get()); /// The transport is given raw pointers to the buffers, so it has to be rebuilt whenever they are replaced.
+}
+
+void PostgreSQLHandler::run() /// Connection entry point: performs the startup phase, then serves client messages until the client terminates or the server closes.
+{
+ setThreadName("PostgresHandler");
+ ThreadStatus thread_status;
+ /// The session exists only for the duration of run(); SCOPE_EXIT destroys it on every exit path.
+ session = std::make_unique<Session>(server.context(), ClientInfo::Interface::POSTGRESQL);
+ SCOPE_EXIT({ session.reset(); });
+
+ session->setClientConnectionId(connection_id);
+
+ try
+ {
+ if (!startup()) /// startup() returns false after serving a cancel request; such a connection carries no queries.
+ return;
+
+ while (tcp_server.isOpen())
+ {
+ message_transport->send(PostgreSQLProtocol::Messaging::ReadyForQuery(), true); /// Sent at the top of every iteration, i.e. also after each handled message.
+ /// Wait for input in bounded slices so that a server shutdown is noticed while the client is idle.
+ constexpr size_t connection_check_timeout = 1; // 1 second
+ while (!in->poll(1000000 * connection_check_timeout)) /// Presumably poll() takes microseconds -- TODO confirm in ReadBufferFromPocoSocket.
+ if (!tcp_server.isOpen())
+ return;
+ PostgreSQLProtocol::Messaging::FrontMessageType message_type = message_transport->receiveMessageType();
+ /// Re-check after the read: the server may have been closed while the message type was being received.
+ if (!tcp_server.isOpen())
+ return;
+ switch (message_type)
+ {
+ case PostgreSQLProtocol::Messaging::FrontMessageType::QUERY:
+ processQuery();
+ break;
+ case PostgreSQLProtocol::Messaging::FrontMessageType::TERMINATE:
+ LOG_DEBUG(log, "Client closed the connection");
+ return;
+ case PostgreSQLProtocol::Messaging::FrontMessageType::PARSE: /// This and the following five types are rejected together with one error response.
+ case PostgreSQLProtocol::Messaging::FrontMessageType::BIND:
+ case PostgreSQLProtocol::Messaging::FrontMessageType::DESCRIBE:
+ case PostgreSQLProtocol::Messaging::FrontMessageType::SYNC:
+ case PostgreSQLProtocol::Messaging::FrontMessageType::FLUSH:
+ case PostgreSQLProtocol::Messaging::FrontMessageType::CLOSE:
+ message_transport->send(
+ PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+ PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+ "0A000",
+ "ClickHouse doesn't support extended query mechanism"),
+ true);
+ LOG_ERROR(log, "Client tried to access via extended query protocol");
+ message_transport->dropMessage(); /// The rejected message is dropped before the loop continues with the next one.
+ break;
+ default:
+ message_transport->send(
+ PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+ PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR,
+ "0A000",
+ "Command is not supported"),
+ true);
+ LOG_ERROR(log, "Command is not supported. Command code {:d}", static_cast<Int32>(message_type));
+ message_transport->dropMessage();
+ }
+ }
+ }
+ catch (const Poco::Exception &exc) /// Only Poco::Exception and types derived from it are handled here; the handler logs the exception and ends the connection.
+ {
+ log->log(exc);
+ }
+
+}
+
+bool PostgreSQLHandler::startup() /// Runs the startup phase. Returns false when the connection only carried a cancel request, true once the client may send queries.
+{
+ Int32 payload_size; /// Both values are filled in by establishSecureConnection().
+ Int32 info;
+ establishSecureConnection(payload_size, info);
+ /// After the optional encryption negotiation `info` holds the code of the first real message.
+ if (static_cast<PostgreSQLProtocol::Messaging::FrontMessageType>(info) == PostgreSQLProtocol::Messaging::FrontMessageType::CANCEL_REQUEST)
+ {
+ LOG_DEBUG(log, "Client issued request canceling");
+ cancelRequest();
+ return false;
+ }
+ /// Exceptions from receiving the startup message or from authentication propagate to the caller.
+ std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> start_up_msg = receiveStartupMessage(payload_size);
+ const auto & user_name = start_up_msg->user;
+ authentication_manager.authenticate(user_name, *session, *message_transport, socket().peerAddress());
+
+ try
+ {
+ session->makeSessionContext();
+ session->sessionContext()->setDefaultFormat("PostgreSQLWire");
+ if (!start_up_msg->database.empty()) /// The current database is changed only when the client named one.
+ session->sessionContext()->setCurrentDatabase(start_up_msg->database);
+ }
+ catch (const Exception & exc)
+ {
+ message_transport->send(
+ PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+ PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "XX000", exc.message()),
+ true);
+ throw; /// The client is told about the failure first, then the exception continues to the caller.
+ }
+
+ sendParameterStatusData(*start_up_msg);
+ /// NOTE(review): `secret_key` still holds its current member value here (0 unless processQuery() already ran) -- confirm this is the intended cancellation key.
+ message_transport->send(
+ PostgreSQLProtocol::Messaging::BackendKeyData(connection_id, secret_key), true);
+
+ LOG_DEBUG(log, "Successfully finished Startup stage");
+ return true;
+}
+
+/// Reads the first two big-endian Int32 fields sent by the client (payload size and request code) and
+/// answers an encryption request if that is what they carry. After an SSL or GSSENC request the pair
+/// is read a second time, so on return `payload_size` and `info` describe the message that follows.
+void PostgreSQLHandler::establishSecureConnection(Int32 & payload_size, Int32 & info)
+{
+    using FrontMessageType = PostgreSQLProtocol::Messaging::FrontMessageType;
+
+    readBinaryBigEndian(payload_size, *in);
+    readBinaryBigEndian(info, *in);
+
+    const auto request = static_cast<FrontMessageType>(info);
+    if (request == FrontMessageType::SSL_REQUEST)
+    {
+        LOG_DEBUG(log, "Client requested SSL");
+        if (!ssl_enabled)
+            message_transport->send('N', true);
+        else
+            makeSecureConnectionSSL();
+    }
+    else if (request == FrontMessageType::GSSENC_REQUEST)
+    {
+        LOG_DEBUG(log, "Client requested GSSENC");
+        message_transport->send('N', true);
+    }
+    else
+    {
+        /// Not an encryption request: the fields already belong to the first real message.
+        return;
+    }
+
+    /// `in` may have been replaced by makeSecureConnectionSSL(), so it is dereferenced afresh here.
+    readBinaryBigEndian(payload_size, *in);
+    readBinaryBigEndian(info, *in);
+}
+
+#if USE_SSL
+void PostgreSQLHandler::makeSecureConnectionSSL() /// Accepts the client's SSL request and switches all further I/O to a secure socket.
+{
+ message_transport->send('S'); /// Unlike the 'N' replies in establishSecureConnection(), no second argument is passed here.
+ ss = std::make_shared<Poco::Net::SecureStreamSocket>(
+ Poco::Net::SecureStreamSocket::attach(socket(), Poco::Net::SSLManager::instance().defaultServerContext()));
+ changeIO(*ss); /// Rebuilds `in`, `out` and `message_transport` on top of the secure socket.
+}
+#else
+void PostgreSQLHandler::makeSecureConnectionSSL() {} /// No-op without SSL support; PostgreSQLHandlerFactory sets ssl_enabled to false in such builds.
+#endif
+
+/// Sends the ParameterStatus messages that close the startup phase and flushes them in one go.
+/// `application_name` is echoed only when the client supplied it; `client_encoding` is echoed when
+/// supplied and reported as "UTF8" otherwise; the remaining three values are fixed.
+void PostgreSQLHandler::sendParameterStatusData(PostgreSQLProtocol::Messaging::StartupMessage & start_up_message)
+{
+    using PostgreSQLProtocol::Messaging::ParameterStatus;
+
+    std::unordered_map<String, String> & client_params = start_up_message.parameters;
+
+    if (auto app_name = client_params.find("application_name"); app_name != client_params.end())
+        message_transport->send(ParameterStatus("application_name", app_name->second));
+
+    if (auto encoding = client_params.find("client_encoding"); encoding != client_params.end())
+        message_transport->send(ParameterStatus("client_encoding", encoding->second));
+    else
+        message_transport->send(ParameterStatus("client_encoding", "UTF8"));
+
+    message_transport->send(ParameterStatus("server_version", VERSION_STRING));
+    message_transport->send(ParameterStatus("server_encoding", "UTF8"));
+    message_transport->send(ParameterStatus("DateStyle", "ISO"));
+    message_transport->flush();
+}
+
+/// Serves a cancel request: reads its 8-byte payload and runs a KILL QUERY for the query id built
+/// from the process id and secret key it carries (the same "postgres:<id>:<key>" form that
+/// processQuery() assigns to the queries it executes).
+void PostgreSQLHandler::cancelRequest()
+{
+    using CancelRequest = PostgreSQLProtocol::Messaging::CancelRequest;
+
+    std::unique_ptr<CancelRequest> cancel_msg = message_transport->receiveWithPayloadSize<CancelRequest>(8);
+
+    String kill_query = fmt::format(
+        "KILL QUERY WHERE query_id = 'postgres:{:d}:{:d}'", cancel_msg->process_id, cancel_msg->secret_key);
+    ReadBufferFromString kill_query_buf(kill_query);
+
+    auto kill_context = session->makeQueryContext();
+    kill_context->setCurrentQueryId("");
+    executeQuery(kill_query_buf, *out, true, kill_context, {});
+}
+
+/// Receives the startup message, whose body is `payload_size - 8` bytes long (the two Int32 fields
+/// consumed by establishSecureConnection() are subtracted). On a DB::Exception the client is sent
+/// an 08P01 error before the exception is rethrown.
+inline std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> PostgreSQLHandler::receiveStartupMessage(int payload_size)
+{
+    namespace Messaging = PostgreSQLProtocol::Messaging;
+
+    std::unique_ptr<Messaging::StartupMessage> startup_message;
+    try
+    {
+        startup_message = message_transport->receiveWithPayloadSize<Messaging::StartupMessage>(payload_size - 8);
+    }
+    catch (const Exception &)
+    {
+        message_transport->send(
+            Messaging::ErrorOrNoticeResponse(
+                Messaging::ErrorOrNoticeResponse::ERROR, "08P01", "Can't correctly handle Startup message"),
+            true);
+        throw;
+    }
+
+    LOG_DEBUG(log, "Successfully received Startup message");
+    return startup_message;
+}
+
+void PostgreSQLHandler::processQuery() /// Handles one QUERY message: splits it into statements and executes them one by one, writing results to `out`.
+{
+ try
+ {
+ std::unique_ptr<PostgreSQLProtocol::Messaging::Query> query =
+ message_transport->receive<PostgreSQLProtocol::Messaging::Query>();
+
+ if (isEmptyQuery(query->query))
+ {
+ message_transport->send(PostgreSQLProtocol::Messaging::EmptyQueryResponse());
+ return;
+ }
+ /// Driver housekeeping statements are acknowledged with CommandComplete without being executed.
+ bool psycopg2_cond = query->query == "BEGIN" || query->query == "COMMIT"; // psycopg2 starts and ends queries with BEGIN/COMMIT commands
+ bool jdbc_cond = query->query.find("SET extra_float_digits") != String::npos || query->query.find("SET application_name") != String::npos; // jdbc starts with setting this parameter
+ if (psycopg2_cond || jdbc_cond)
+ {
+ message_transport->send(
+ PostgreSQLProtocol::Messaging::CommandComplete(
+ PostgreSQLProtocol::Messaging::CommandComplete::classifyQuery(query->query), 0));
+ return;
+ }
+ /// Split the text into individual statements; a failed split is reported as a syntax error carrying the unparsed part.
+ const auto & settings = session->sessionContext()->getSettingsRef();
+ std::vector<String> queries;
+ auto parse_res = splitMultipartQuery(query->query, queries,
+ settings.max_query_size,
+ settings.max_parser_depth,
+ settings.allow_settings_after_format_in_insert);
+ if (!parse_res.second)
+ throw Exception(ErrorCodes::SYNTAX_ERROR, "Cannot parse and execute the following part of query: {}", String(parse_res.first));
+ /// A new random generator is seeded for every QUERY message; it supplies one secret key per statement.
+ std::random_device rd;
+ std::mt19937 gen(rd());
+ std::uniform_int_distribution<Int32> dis(0, INT32_MAX);
+
+ for (const auto & spl_query : queries)
+ {
+ secret_key = dis(gen); /// NOTE(review): startup() already sent BackendKeyData with the earlier value of secret_key -- confirm clients can learn the value that ends up in the query id.
+ auto query_context = session->makeQueryContext();
+ query_context->setCurrentQueryId(fmt::format("postgres:{:d}:{:d}", connection_id, secret_key)); /// Same id format that cancelRequest() uses in its KILL QUERY.
+
+ CurrentThread::QueryScope query_scope{query_context};
+ ReadBufferFromString read_buf(spl_query);
+ executeQuery(read_buf, *out, false, query_context, {});
+ /// Each statement gets its own CommandComplete.
+ PostgreSQLProtocol::Messaging::CommandComplete::Command command =
+ PostgreSQLProtocol::Messaging::CommandComplete::classifyQuery(spl_query);
+ message_transport->send(PostgreSQLProtocol::Messaging::CommandComplete(command, 0), true);
+ }
+
+ }
+ catch (const Exception & e)
+ {
+ message_transport->send(
+ PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse(
+ PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "2F000", "Query execution failed.\n" + e.displayText()),
+ true);
+ throw; /// The error is reported to the client and then rethrown to the caller (run()).
+ }
+}
+
+/// Returns true when `query` carries nothing to execute: it is empty, is the single character ";"
+/// (sent by the golang driver pgx), or consists solely of whitespace.
+///
+/// The whitespace test is a plain scan over the bytes instead of compiling a regular expression on
+/// every call. The accepted characters are exactly those matched by `\s` in the former
+/// `\A\s*\z` pattern: space, \t, \n, \v, \f and \r.
+bool PostgreSQLHandler::isEmptyQuery(const String & query)
+{
+    /// golang driver pgx sends ";"
+    if (query == ";")
+        return true;
+
+    /// An empty string falls through the loop and is reported as empty.
+    for (const char c : query)
+    {
+        const bool is_whitespace = c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r';
+        if (!is_whitespace)
+            return false;
+    }
+    return true;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/PostgreSQLHandler.h b/contrib/clickhouse/src/Server/PostgreSQLHandler.h
new file mode 100644
index 0000000000..4de5977cc4
--- /dev/null
+++ b/contrib/clickhouse/src/Server/PostgreSQLHandler.h
@@ -0,0 +1,81 @@
+#pragma once
+
+#include <Common/CurrentMetrics.h>
+#include "clickhouse_config.h"
+#include <Core/PostgreSQLProtocol.h>
+#include <Poco/Net/TCPServerConnection.h>
+#include "IServer.h"
+
+#if USE_SSL
+# include <Poco/Net/SecureStreamSocket.h>
+#endif
+
+namespace CurrentMetrics
+{
+ extern const Metric PostgreSQLConnection;
+}
+
+namespace DB
+{
+class ReadBufferFromPocoSocket;
+class Session;
+class TCPServer;
+
+/** PostgreSQL wire protocol implementation: one instance serves one client connection.
+ * For more info see https://www.postgresql.org/docs/current/protocol.html
+ */
+class PostgreSQLHandler : public Poco::Net::TCPServerConnection
+{
+public:
+ PostgreSQLHandler(
+ const Poco::Net::StreamSocket & socket_,
+ IServer & server_,
+ TCPServer & tcp_server_,
+ bool ssl_enabled_,
+ Int32 connection_id_,
+ std::vector<std::shared_ptr<PostgreSQLProtocol::PGAuthentication::AuthenticationMethod>> & auth_methods_);
+ /// Connection entry point: startup phase followed by the message loop.
+ void run() final;
+
+private:
+ Poco::Logger * log = &Poco::Logger::get("PostgreSQLHandler");
+
+ IServer & server;
+ TCPServer & tcp_server;
+ std::unique_ptr<Session> session; /// Created at the start of run() and reset when run() exits.
+ bool ssl_enabled = false;
+ Int32 connection_id = 0; /// First component of the "postgres:<connection_id>:<secret_key>" query ids.
+ Int32 secret_key = 0; /// Re-randomised by processQuery() for every executed statement.
+ /// All three objects below are (re)created together by changeIO().
+ std::shared_ptr<ReadBufferFromPocoSocket> in;
+ std::shared_ptr<WriteBuffer> out;
+ std::shared_ptr<PostgreSQLProtocol::Messaging::MessageTransport> message_transport;
+
+#if USE_SSL
+ std::shared_ptr<Poco::Net::SecureStreamSocket> ss; /// Set by makeSecureConnectionSSL() when the client's SSL request is accepted.
+#endif
+
+ PostgreSQLProtocol::PGAuthentication::AuthenticationManager authentication_manager;
+
+ CurrentMetrics::Increment metric_increment{CurrentMetrics::PostgreSQLConnection};
+ /// Points in/out/message_transport at `socket`.
+ void changeIO(Poco::Net::StreamSocket & socket);
+ /// Returns false when the connection only carried a cancel request.
+ bool startup();
+ /// Answers an SSL/GSSENC request if present; on return the arguments describe the first real message.
+ void establishSecureConnection(Int32 & payload_size, Int32 & info);
+ /// Replies 'S' and switches I/O to a secure socket; empty when built without SSL.
+ void makeSecureConnectionSSL();
+ /// Sends the ParameterStatus messages that end the startup phase.
+ void sendParameterStatusData(PostgreSQLProtocol::Messaging::StartupMessage & start_up_message);
+ /// Executes KILL QUERY for the query id named by a cancel request.
+ void cancelRequest();
+ /// Reads the startup message; reports error 08P01 to the client and rethrows on failure.
+ std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> receiveStartupMessage(int payload_size);
+ /// Executes the statements of one QUERY message.
+ void processQuery();
+ /// True for an empty string, a lone ";" or whitespace only.
+ static bool isEmptyQuery(const String & query);
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.cpp b/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.cpp
new file mode 100644
index 0000000000..6f2124861e
--- /dev/null
+++ b/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.cpp
@@ -0,0 +1,26 @@
+#include "PostgreSQLHandlerFactory.h"
+#include <memory>
+#include <Server/PostgreSQLHandler.h>
+
+namespace DB
+{
+
+/// Binds the server and logger and registers the two supported authentication methods, in this
+/// order: no password, cleartext password.
+PostgreSQLHandlerFactory::PostgreSQLHandlerFactory(IServer & server_)
+    : server(server_)
+    , log(&Poco::Logger::get("PostgreSQLHandlerFactory"))
+    , auth_methods{
+          std::make_shared<PostgreSQLProtocol::PGAuthentication::NoPasswordAuth>(),
+          std::make_shared<PostgreSQLProtocol::PGAuthentication::CleartextPasswordAuth>()}
+{
+}
+
+/// Allocates a handler for a newly accepted PostgreSQL connection, giving it the next connection id
+/// and a reference to the factory's list of authentication methods.
+Poco::Net::TCPServerConnection * PostgreSQLHandlerFactory::createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server)
+{
+    /// Same effect as a post-increment: take the current id, then advance the counter.
+    const Int32 connection_id = last_connection_id.fetch_add(1);
+    LOG_TRACE(log, "PostgreSQL connection. Id: {}. Address: {}", connection_id, socket.peerAddress().toString());
+
+    return new PostgreSQLHandler(socket, server, tcp_server, ssl_enabled, connection_id, auth_methods);
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.h b/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.h
new file mode 100644
index 0000000000..08b7ad1d85
--- /dev/null
+++ b/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include <atomic>
+#include <memory>
+#include <Server/IServer.h>
+#include <Server/TCPServerConnectionFactory.h>
+#include <Core/PostgreSQLProtocol.h>
+#include "clickhouse_config.h"
+
+namespace DB
+{
+
+class PostgreSQLHandlerFactory : public TCPServerConnectionFactory /// Creates one PostgreSQL protocol handler per accepted connection.
+{
+private:
+ IServer & server;
+ Poco::Logger * log;
+ /// Whether handlers may accept SSL requests is fixed at compile time by USE_SSL.
+#if USE_SSL
+ bool ssl_enabled = true;
+#else
+ bool ssl_enabled = false;
+#endif
+ /// Source of connection ids: createConnection() post-increments it for each new connection.
+ std::atomic<Int32> last_connection_id = 0;
+ std::vector<std::shared_ptr<PostgreSQLProtocol::PGAuthentication::AuthenticationMethod>> auth_methods; /// Filled by the constructor and passed by reference to every handler.
+
+public:
+ explicit PostgreSQLHandlerFactory(IServer & server_);
+ /// Returns a handler allocated with `new`. The second parameter is named `tcp_server` in the definition.
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & server) override;
+};
+}
diff --git a/contrib/clickhouse/src/Server/PrometheusMetricsWriter.cpp b/contrib/clickhouse/src/Server/PrometheusMetricsWriter.cpp
new file mode 100644
index 0000000000..2331e45522
--- /dev/null
+++ b/contrib/clickhouse/src/Server/PrometheusMetricsWriter.cpp
@@ -0,0 +1,161 @@
+#include "PrometheusMetricsWriter.h"
+
+#include <IO/WriteHelpers.h>
+#include <Common/StatusInfo.h>
+#include <regex> /// TODO: this library is harmful.
+#include <algorithm>
+
+
+namespace
+{
+
+/// Writes all values as text separated by single spaces and terminates the line with '\n'.
+/// At least one value is required.
+template <typename T, typename... TArgs>
+void writeOutLine(DB::WriteBuffer & wb, T && val, TArgs &&... args)
+{
+    DB::writeText(std::forward<T>(val), wb);
+
+    if constexpr (sizeof...(TArgs) == 0)
+    {
+        /// Last value: close the line.
+        DB::writeChar('\n', wb);
+    }
+    else
+    {
+        DB::writeChar(' ', wb);
+        writeOutLine(wb, std::forward<TArgs>(args)...);
+    }
+}
+
+/// Sanitises `metric_name` in place so that it can be used as a Prometheus metric name:
+///  1. every byte outside [a-zA-Z0-9_:] is replaced with '_';
+///  2. every leading byte that is not an ASCII letter is removed.
+/// Returns false if nothing is left, i.e. the name is not valid.
+///
+/// This is a byte-wise equivalent of the two regex substitutions "[^a-zA-Z0-9_:]" -> "_" and
+/// "^[^a-zA-Z]*" -> "", written as a scan so that no std::regex is compiled for every metric.
+bool replaceInvalidChars(std::string & metric_name)
+{
+    /// Explicit ASCII ranges keep the check independent of the current locale.
+    const auto is_letter = [](char c) { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); };
+    const auto is_digit = [](char c) { return c >= '0' && c <= '9'; };
+
+    for (char & c : metric_name)
+    {
+        if (!is_letter(c) && !is_digit(c) && c != '_' && c != ':')
+            c = '_';
+    }
+
+    /// Drop everything in front of the first letter (the whole string if there is no letter).
+    metric_name.erase(metric_name.begin(), std::find_if(metric_name.begin(), metric_name.end(), is_letter));
+    return !metric_name.empty();
+}
+
+/// Turns every line break in `help` into a space, in place, so that the text fits on one line.
+void convertHelpToSingleLine(std::string & help)
+{
+    for (char & ch : help)
+    {
+        if (ch == '\n')
+            ch = ' ';
+    }
+}
+
+}
+
+
+namespace DB
+{
+
+PrometheusMetricsWriter::PrometheusMetricsWriter( /// Reads the four on/off switches found under `config_name`; each of them defaults to true.
+ const Poco::Util::AbstractConfiguration & config, const std::string & config_name,
+ const AsynchronousMetrics & async_metrics_)
+ : async_metrics(async_metrics_) /// Stored as a reference, not copied.
+ , send_events(config.getBool(config_name + ".events", true))
+ , send_metrics(config.getBool(config_name + ".metrics", true))
+ , send_asynchronous_metrics(config.getBool(config_name + ".asynchronous_metrics", true))
+ , send_status_info(config.getBool(config_name + ".status_info", true))
+{
+}
+
+void PrometheusMetricsWriter::write(WriteBuffer & wb) const /// Writes every enabled metric group to `wb` as "# HELP" / "# TYPE" / value lines.
+{
+ if (send_events) /// Profile events are exported as counters.
+ {
+ for (ProfileEvents::Event i = ProfileEvents::Event(0), end = ProfileEvents::end(); i < end; ++i)
+ {
+ const auto counter = ProfileEvents::global_counters[i].load(std::memory_order_relaxed);
+
+ std::string metric_name{ProfileEvents::getName(static_cast<ProfileEvents::Event>(i))};
+ std::string metric_doc{ProfileEvents::getDocumentation(static_cast<ProfileEvents::Event>(i))};
+
+ convertHelpToSingleLine(metric_doc);
+
+ if (!replaceInvalidChars(metric_name)) /// Names that end up empty after sanitising are skipped.
+ continue;
+ std::string key{profile_events_prefix + metric_name};
+
+ writeOutLine(wb, "# HELP", key, metric_doc);
+ writeOutLine(wb, "# TYPE", key, "counter");
+ writeOutLine(wb, key, counter);
+ }
+ }
+
+ if (send_metrics) /// Current metrics are exported as gauges.
+ {
+ for (size_t i = 0, end = CurrentMetrics::end(); i < end; ++i)
+ {
+ const auto value = CurrentMetrics::values[i].load(std::memory_order_relaxed);
+
+ std::string metric_name{CurrentMetrics::getName(static_cast<CurrentMetrics::Metric>(i))};
+ std::string metric_doc{CurrentMetrics::getDocumentation(static_cast<CurrentMetrics::Metric>(i))};
+
+ convertHelpToSingleLine(metric_doc);
+
+ if (!replaceInvalidChars(metric_name))
+ continue;
+ std::string key{current_metrics_prefix + metric_name};
+
+ writeOutLine(wb, "# HELP", key, metric_doc);
+ writeOutLine(wb, "# TYPE", key, "gauge");
+ writeOutLine(wb, key, value);
+ }
+ }
+
+ if (send_asynchronous_metrics) /// Asynchronous metrics are exported as gauges.
+ {
+ auto async_metrics_values = async_metrics.getValues();
+ for (const auto & name_value : async_metrics_values)
+ {
+ std::string key{asynchronous_metrics_prefix + name_value.first};
+ /// Unlike the other groups, the prefix is added before sanitising, so the whole key is checked.
+ if (!replaceInvalidChars(key))
+ continue;
+
+ auto value = name_value.second;
+
+ std::string metric_doc{value.documentation};
+ convertHelpToSingleLine(metric_doc);
+
+ /// The HELP text comes from the documentation stored alongside the value.
+ writeOutLine(wb, "# HELP", key, metric_doc);
+ writeOutLine(wb, "# TYPE", key, "gauge");
+ writeOutLine(wb, key, value.value);
+ }
+ }
+
+ if (send_status_info) /// Status info is exported as one gauge sample per (entry, possible enum value) pair.
+ {
+ for (size_t i = 0, end = CurrentStatusInfo::end(); i < end; ++i)
+ {
+ std::lock_guard lock(CurrentStatusInfo::locks[static_cast<CurrentStatusInfo::Status>(i)]); /// Held for the rest of this iteration.
+ std::string metric_name{CurrentStatusInfo::getName(static_cast<CurrentStatusInfo::Status>(i))};
+ std::string metric_doc{CurrentStatusInfo::getDocumentation(static_cast<CurrentStatusInfo::Status>(i))};
+
+ convertHelpToSingleLine(metric_doc);
+
+ if (!replaceInvalidChars(metric_name))
+ continue;
+ std::string key{current_status_prefix + metric_name};
+
+ writeOutLine(wb, "# HELP", key, metric_doc);
+ writeOutLine(wb, "# TYPE", key, "gauge");
+ /// Emitted line shape: <key>{<key>="<enum name>",name="<entry name>"} <flag>; the metric key doubles as a label name.
+ for (const auto & value: CurrentStatusInfo::values[i])
+ {
+ for (const auto & enum_value: CurrentStatusInfo::getAllPossibleValues(static_cast<CurrentStatusInfo::Status>(i)))
+ {
+ DB::writeText(key, wb);
+ DB::writeChar('{', wb);
+ DB::writeText(key, wb);
+ DB::writeChar('=', wb);
+ writeDoubleQuotedString(enum_value.first, wb);
+ DB::writeText(",name=", wb);
+ writeDoubleQuotedString(value.first, wb);
+ DB::writeText("} ", wb);
+ DB::writeText(value.second == enum_value.second, wb); /// The sample value is the result of this comparison: set only for the enum value the entry currently has.
+ DB::writeChar('\n', wb);
+ }
+ }
+ }
+ }
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/PrometheusMetricsWriter.h b/contrib/clickhouse/src/Server/PrometheusMetricsWriter.h
new file mode 100644
index 0000000000..b4f6ab57de
--- /dev/null
+++ b/contrib/clickhouse/src/Server/PrometheusMetricsWriter.h
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <string>
+
+#include <Common/AsynchronousMetrics.h>
+#include <IO/WriteBuffer.h>
+
+#include <Poco/Util/AbstractConfiguration.h>
+
+
+namespace DB
+{
+
+/// Writes server metrics in the Prometheus text format; which metric groups are written is fixed at construction from the configuration.
+class PrometheusMetricsWriter
+{
+public:
+ PrometheusMetricsWriter(
+ const Poco::Util::AbstractConfiguration & config, const std::string & config_name,
+ const AsynchronousMetrics & async_metrics_);
+ /// Appends all enabled metric groups to `wb`.
+ void write(WriteBuffer & wb) const;
+
+private:
+ const AsynchronousMetrics & async_metrics; /// Held by reference: the referenced object must outlive this writer and any copy of it.
+ /// One switch per metric group, read from `<config_name>.events`, `.metrics`, `.asynchronous_metrics` and `.status_info`.
+ const bool send_events;
+ const bool send_metrics;
+ const bool send_asynchronous_metrics;
+ const bool send_status_info;
+ /// Prefixes prepended to the metric names of the respective groups.
+ static inline constexpr auto profile_events_prefix = "ClickHouseProfileEvents_";
+ static inline constexpr auto current_metrics_prefix = "ClickHouseMetrics_";
+ static inline constexpr auto asynchronous_metrics_prefix = "ClickHouseAsyncMetrics_";
+ static inline constexpr auto current_status_prefix = "ClickHouseStatusInfo_";
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/PrometheusRequestHandler.cpp b/contrib/clickhouse/src/Server/PrometheusRequestHandler.cpp
new file mode 100644
index 0000000000..7902562420
--- /dev/null
+++ b/contrib/clickhouse/src/Server/PrometheusRequestHandler.cpp
@@ -0,0 +1,71 @@
+#include <Server/PrometheusRequestHandler.h>
+
+#include <IO/HTTPCommon.h>
+#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
+#include <Server/HTTPHandlerFactory.h>
+#include <Server/IServer.h>
+#include <Common/CurrentMetrics.h>
+#include <Common/Exception.h>
+#include <Common/ProfileEvents.h>
+
+#include <Poco/Util/LayeredConfiguration.h>
+
+
+namespace DB
+{
+/// Serves the metrics page: sets the default response headers and the Prometheus text content type,
+/// then streams the output of the metrics writer into the response.
+/// Nothing is thrown to the caller; every failure is logged under "PrometheusRequestHandler".
+void PrometheusRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
+{
+    try
+    {
+        const auto & config = server.config();
+        unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10);
+
+        setResponseDefaultHeaders(response, keep_alive_timeout);
+
+        response.setContentType("text/plain; version=0.0.4; charset=UTF-8");
+
+        WriteBufferFromHTTPServerResponse wb(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
+        try
+        {
+            metrics_writer.write(wb);
+            wb.finalize();
+        }
+        catch (...)
+        {
+            /// Log while the exception is still current, so a failing metrics writer does not go
+            /// unnoticed; the buffer is finalized afterwards as before.
+            tryLogCurrentException("PrometheusRequestHandler");
+            wb.finalize();
+        }
+    }
+    catch (...)
+    {
+        tryLogCurrentException("PrometheusRequestHandler");
+    }
+}
+
+HTTPRequestHandlerFactoryPtr /// Builds a rule-based factory for PrometheusRequestHandler whose filters come from `config_prefix`.
+createPrometheusHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ AsynchronousMetrics & async_metrics,
+ const std::string & config_prefix)
+{
+ auto factory = std::make_shared<HandlingRuleHTTPHandlerFactory<PrometheusRequestHandler>>(
+ server, PrometheusMetricsWriter(config, config_prefix + ".handler", async_metrics)); /// The writer is passed as a temporary; its switches are read from `<config_prefix>.handler`. NOTE(review): the handler keeps only a reference to a writer -- confirm the factory stores its own copy.
+ factory->addFiltersFromConfig(config, config_prefix);
+ return factory;
+}
+
+HTTPRequestHandlerFactoryPtr /// Builds a main factory named `name` that contains a single Prometheus handler rule.
+createPrometheusMainHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ AsynchronousMetrics & async_metrics,
+ const std::string & name)
+{
+ auto factory = std::make_shared<HTTPRequestHandlerFactoryMain>(name);
+ auto handler = std::make_shared<HandlingRuleHTTPHandlerFactory<PrometheusRequestHandler>>(
+ server, PrometheusMetricsWriter(config, "prometheus", async_metrics)); /// Writer switches are read from the fixed "prometheus" config section.
+ handler->attachStrictPath(config.getString("prometheus.endpoint", "/metrics")); /// The endpoint path defaults to "/metrics".
+ handler->allowGetAndHeadRequest();
+ factory->addHandler(handler);
+ return factory;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/PrometheusRequestHandler.h b/contrib/clickhouse/src/Server/PrometheusRequestHandler.h
new file mode 100644
index 0000000000..1fb3d9f0f5
--- /dev/null
+++ b/contrib/clickhouse/src/Server/PrometheusRequestHandler.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandler.h>
+
+#include "PrometheusMetricsWriter.h"
+
+namespace DB
+{
+
+class IServer;
+
+/// HTTP request handler that writes metrics through a PrometheusMetricsWriter.
+class PrometheusRequestHandler : public HTTPRequestHandler
+{
+private:
+ /// Server whose configuration is read while handling a request.
+ IServer & server;
+ /// Non-owning reference: the writer must outlive this handler.
+ const PrometheusMetricsWriter & metrics_writer;
+
+public:
+ /// Stores references to both arguments; neither is copied.
+ explicit PrometheusRequestHandler(IServer & server_, const PrometheusMetricsWriter & metrics_writer_)
+ : server(server_)
+ , metrics_writer(metrics_writer_)
+ {
+ }
+
+ /// Writes the metrics into `response`.
+ void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/ProtocolServerAdapter.cpp b/contrib/clickhouse/src/Server/ProtocolServerAdapter.cpp
new file mode 100644
index 0000000000..8d14a84989
--- /dev/null
+++ b/contrib/clickhouse/src/Server/ProtocolServerAdapter.cpp
@@ -0,0 +1,75 @@
+#include <Server/ProtocolServerAdapter.h>
+#include <Server/TCPServer.h>
+
+#if USE_GRPC && !defined(CLICKHOUSE_KEEPER_STANDALONE_BUILD)
+#include <Server/GRPCServer.h>
+#endif
+
+
+namespace DB
+{
+/// Implementation of ProtocolServerAdapter::Impl that owns a TCPServer and forwards every call to it.
+class ProtocolServerAdapter::TCPServerAdapterImpl : public Impl
+{
+public:
+ /// Takes ownership of `tcp_server_`.
+ explicit TCPServerAdapterImpl(std::unique_ptr<TCPServer> tcp_server_) : tcp_server(std::move(tcp_server_)) {}
+ ~TCPServerAdapterImpl() override = default;
+
+ void start() override { tcp_server->start(); }
+ void stop() override { tcp_server->stop(); }
+ /// The server is reported as stopping as soon as it is no longer open.
+ bool isStopping() const override { return !tcp_server->isOpen(); }
+ UInt16 portNumber() const override { return tcp_server->portNumber(); }
+ size_t currentConnections() const override { return tcp_server->currentConnections(); }
+ size_t currentThreads() const override { return tcp_server->currentThreads(); }
+
+private:
+ std::unique_ptr<TCPServer> tcp_server;
+};
+
+/// Wraps a TCPServer. The listen host, port name and description are copied into the adapter;
+/// ownership of `tcp_server_` is transferred to the internal TCPServerAdapterImpl.
+ProtocolServerAdapter::ProtocolServerAdapter(
+ const std::string & listen_host_,
+ const char * port_name_,
+ const std::string & description_,
+ std::unique_ptr<TCPServer> tcp_server_)
+ : listen_host(listen_host_)
+ , port_name(port_name_)
+ , description(description_)
+ , impl(std::make_unique<TCPServerAdapterImpl>(std::move(tcp_server_)))
+{
+}
+
+#if USE_GRPC && !defined(CLICKHOUSE_KEEPER_STANDALONE_BUILD)
+/// Implementation of ProtocolServerAdapter::Impl that owns a GRPCServer and forwards calls to it.
+/// The stopping state is tracked locally in `is_stopping` instead of being queried from the server.
+class ProtocolServerAdapter::GRPCServerAdapterImpl : public Impl
+{
+public:
+ /// Takes ownership of `grpc_server_`.
+ explicit GRPCServerAdapterImpl(std::unique_ptr<GRPCServer> grpc_server_) : grpc_server(std::move(grpc_server_)) {}
+ ~GRPCServerAdapterImpl() override = default;
+
+ void start() override { grpc_server->start(); }
+ void stop() override
+ {
+ /// The flag is raised before the server is asked to stop.
+ is_stopping = true;
+ grpc_server->stop();
+ }
+ bool isStopping() const override { return is_stopping; }
+ UInt16 portNumber() const override { return grpc_server->portNumber(); }
+ size_t currentConnections() const override { return grpc_server->currentConnections(); }
+ size_t currentThreads() const override { return grpc_server->currentThreads(); }
+
+private:
+ std::unique_ptr<GRPCServer> grpc_server;
+ /// NOTE(review): plain bool without synchronization - confirm that stop() and isStopping() are not called concurrently.
+ bool is_stopping = false;
+};
+
+/// Wraps a GRPCServer. The listen host, port name and description are copied into the adapter;
+/// ownership of `grpc_server_` is transferred to the internal GRPCServerAdapterImpl.
+ProtocolServerAdapter::ProtocolServerAdapter(
+ const std::string & listen_host_,
+ const char * port_name_,
+ const std::string & description_,
+ std::unique_ptr<GRPCServer> grpc_server_)
+ : listen_host(listen_host_)
+ , port_name(port_name_)
+ , description(description_)
+ , impl(std::make_unique<GRPCServerAdapterImpl>(std::move(grpc_server_)))
+{
+}
+#endif
+}
diff --git a/contrib/clickhouse/src/Server/ProtocolServerAdapter.h b/contrib/clickhouse/src/Server/ProtocolServerAdapter.h
new file mode 100644
index 0000000000..9497ec3e8e
--- /dev/null
+++ b/contrib/clickhouse/src/Server/ProtocolServerAdapter.h
@@ -0,0 +1,74 @@
+#pragma once
+
+#include "clickhouse_config.h"
+
+#include <base/types.h>
+#include <memory>
+#include <string>
+
+
+namespace DB
+{
+
+class GRPCServer;
+class TCPServer;
+
+/// Provides a unified interface to access a protocol implementing server
+/// no matter what type it has (HTTPServer, TCPServer, MySQLServer, GRPCServer, ...).
+/// The adapter is move-only in practice: it owns the wrapped server through a unique_ptr.
+class ProtocolServerAdapter
+{
+ friend class ProtocolServers;
+public:
+ ProtocolServerAdapter(ProtocolServerAdapter && src) = default;
+ ProtocolServerAdapter & operator =(ProtocolServerAdapter && src) = default;
+ /// Wraps a TCPServer and takes ownership of it.
+ ProtocolServerAdapter(const std::string & listen_host_, const char * port_name_, const std::string & description_, std::unique_ptr<TCPServer> tcp_server_);
+
+#if USE_GRPC && !defined(CLICKHOUSE_KEEPER_STANDALONE_BUILD)
+ /// Wraps a GRPCServer and takes ownership of it.
+ ProtocolServerAdapter(const std::string & listen_host_, const char * port_name_, const std::string & description_, std::unique_ptr<GRPCServer> grpc_server_);
+#endif
+
+ /// Starts the server. A new thread will be created that waits for and accepts incoming connections.
+ void start() { impl->start(); }
+
+ /// Stops the server. No new connections will be accepted.
+ void stop() { impl->stop(); }
+
+ /// Returns whether the wrapped server reports itself as stopping.
+ bool isStopping() const { return impl->isStopping(); }
+
+ /// Returns the number of currently handled connections.
+ size_t currentConnections() const { return impl->currentConnections(); }
+
+ /// Returns the number of current threads.
+ size_t currentThreads() const { return impl->currentThreads(); }
+
+ /// Returns the port this server is listening to.
+ UInt16 portNumber() const { return impl->portNumber(); }
+
+ /// Returns the listen host passed to the constructor.
+ const std::string & getListenHost() const { return listen_host; }
+
+ /// Returns the port name passed to the constructor.
+ const std::string & getPortName() const { return port_name; }
+
+ /// Returns the description passed to the constructor.
+ const std::string & getDescription() const { return description; }
+
+private:
+ /// Interface implemented once per wrapped server type.
+ class Impl
+ {
+ public:
+ virtual ~Impl() = default;
+ virtual void start() = 0;
+ virtual void stop() = 0;
+ virtual bool isStopping() const = 0;
+ virtual UInt16 portNumber() const = 0;
+ virtual size_t currentConnections() const = 0;
+ virtual size_t currentThreads() const = 0;
+ };
+ class TCPServerAdapterImpl;
+ class GRPCServerAdapterImpl;
+
+ std::string listen_host;
+ std::string port_name;
+ std::string description;
+ std::unique_ptr<Impl> impl;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/ProxyV1Handler.cpp b/contrib/clickhouse/src/Server/ProxyV1Handler.cpp
new file mode 100644
index 0000000000..56621940a2
--- /dev/null
+++ b/contrib/clickhouse/src/Server/ProxyV1Handler.cpp
@@ -0,0 +1,127 @@
+#include <Server/ProxyV1Handler.h>
+#include <Poco/Net/NetException.h>
+#include <Common/NetException.h>
+#include <Common/logger_useful.h>
+#include <Interpreters/Context.h>
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int NETWORK_ERROR;
+ extern const int SOCKET_TIMEOUT;
+ extern const int CANNOT_READ_FROM_SOCKET;
+ extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED;
+}
+
+/// Reads and validates a PROXY protocol version 1 header from the socket.
+/// The expected line is "PROXY <TCP4|TCP6|UNKNOWN> <address> <address> <port> <port>\r\n";
+/// a bare "PROXY UNKNOWN\r\n" is accepted as well. The first address is stored in
+/// `stack_data.forwarded_for`; the remaining fields are read only to validate the header.
+/// Throws ParsingException on any deviation from that format; readWord() may throw NetException.
+void ProxyV1Handler::run()
+{
+ const auto & settings = server.context()->getSettingsRef();
+ socket().setReceiveTimeout(settings.receive_timeout);
+
+ std::string word;
+ bool eol;
+
+ // Read PROXYv1 protocol header
+ // http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
+
+ // read "PROXY"
+ if (!readWord(5, word, eol) || word != "PROXY" || eol)
+ throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation");
+
+ // read "TCP4" or "TCP6" or "UNKNOWN"
+ if (!readWord(7, word, eol))
+ throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation");
+
+ if (word != "TCP4" && word != "TCP6" && word != "UNKNOWN")
+ throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation");
+
+ // "UNKNOWN" directly followed by the end of line carries no addresses: nothing more to read
+ if (word == "UNKNOWN" && eol)
+ return;
+
+ // any other protocol must be followed by addresses and ports
+ if (eol)
+ throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation");
+
+ // read the first address (the source address according to the specification); at most 39 characters
+ if (!readWord(39, word, eol) || eol)
+ throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation");
+
+ stack_data.forwarded_for = std::move(word);
+
+ // read the second address (the destination address according to the specification); the value is not used
+ if (!readWord(39, word, eol) || eol)
+ throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation");
+
+ // read the first port; the value is not used
+ if (!readWord(5, word, eol) || eol)
+ throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation");
+
+ // read the second port and "\r\n"; the value is not used
+ if (!readWord(5, word, eol) || !eol)
+ throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation");
+
+ if (!stack_data.forwarded_for.empty())
+ LOG_TRACE(log, "Forwarded client address from PROXY header: {}", stack_data.forwarded_for);
+}
+
+/// Reads one word of the header from the socket, one byte at a time.
+/// @param max_len  maximum number of characters the word may contain (the terminator is read on top of that).
+/// @param word     receives the characters read before the terminator; cleared first.
+/// @param eol      set to true when the word was terminated by a carriage return.
+/// @return true when the word was terminated by a space, or by "\r\n" (then `eol` is true);
+///         false when the peer closed the connection, when "\r" was not followed by "\n",
+///         or when no terminator was found within the length limit.
+/// Throws NetException on socket errors and timeouts.
+bool ProxyV1Handler::readWord(int max_len, std::string & word, bool & eol)
+{
+ word.clear();
+ eol = false;
+
+ char ch = 0;
+ int n = 0;
+ bool is_cr = false;
+ try
+ {
+ /// One extra iteration is allowed for the terminator; after "\r" the loop always continues to read the next byte.
+ for (++max_len; max_len > 0 || is_cr; --max_len)
+ {
+ n = socket().receiveBytes(&ch, 1);
+ if (n == 0)
+ {
+ /// Zero bytes received: shut the socket down and report failure.
+ socket().shutdown();
+ return false;
+ }
+ if (n < 0)
+ break;
+
+ /// The previous byte was "\r": the word is valid only if this byte is "\n".
+ if (is_cr)
+ return ch == 0x0A;
+
+ if (ch == 0x0D)
+ {
+ is_cr = true;
+ eol = true;
+ continue;
+ }
+
+ if (ch == ' ')
+ return true;
+
+ word.push_back(ch);
+ }
+ }
+ catch (const Poco::Net::NetException & e)
+ {
+ throw NetException(ErrorCodes::NETWORK_ERROR, "{}, while reading from socket ({})", e.displayText(), socket().peerAddress().toString());
+ }
+ catch (const Poco::TimeoutException &)
+ {
+ throw NetException(ErrorCodes::SOCKET_TIMEOUT, "Timeout exceeded while reading from socket ({}, {} ms)",
+ socket().peerAddress().toString(),
+ socket().getReceiveTimeout().totalMilliseconds());
+ }
+ catch (const Poco::IOException & e)
+ {
+ throw NetException(ErrorCodes::NETWORK_ERROR, "{}, while reading from socket ({})", e.displayText(), socket().peerAddress().toString());
+ }
+
+ /// A negative result from receiveBytes() left the loop through `break`.
+ if (n < 0)
+ throw NetException(ErrorCodes::CANNOT_READ_FROM_SOCKET, "Cannot read from socket ({})", socket().peerAddress().toString());
+
+ /// The length limit was exhausted without finding a terminator.
+ return false;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/ProxyV1Handler.h b/contrib/clickhouse/src/Server/ProxyV1Handler.h
new file mode 100644
index 0000000000..b50c2acbc5
--- /dev/null
+++ b/contrib/clickhouse/src/Server/ProxyV1Handler.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <Poco/Net/TCPServerConnection.h>
+#include <Server/IServer.h>
+#include <Server/TCPProtocolStackData.h>
+
+
+namespace DB
+{
+
+/// Connection handler that consumes a PROXY protocol version 1 header from the socket
+/// and records the forwarded client address in the shared TCPProtocolStackData.
+class ProxyV1Handler : public Poco::Net::TCPServerConnection
+{
+ using StreamSocket = Poco::Net::StreamSocket;
+public:
+ /// `server_` and `stack_data_` are stored by reference and must outlive the handler; `conf_name_` is copied.
+ explicit ProxyV1Handler(const StreamSocket & socket, IServer & server_, const std::string & conf_name_, TCPProtocolStackData & stack_data_)
+ : Poco::Net::TCPServerConnection(socket), log(&Poco::Logger::get("ProxyV1Handler")), server(server_), conf_name(conf_name_), stack_data(stack_data_) {}
+
+ /// Reads and validates the header.
+ void run() override;
+
+protected:
+ /// Reads a single space- or CRLF-terminated word of at most `max_len` characters.
+ bool readWord(int max_len, std::string & word, bool & eol);
+
+private:
+ Poco::Logger * log;
+ IServer & server;
+ std::string conf_name;
+ /// Non-owning reference to data shared with the other handlers of the protocol stack.
+ TCPProtocolStackData & stack_data;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/ProxyV1HandlerFactory.h b/contrib/clickhouse/src/Server/ProxyV1HandlerFactory.h
new file mode 100644
index 0000000000..028596d745
--- /dev/null
+++ b/contrib/clickhouse/src/Server/ProxyV1HandlerFactory.h
@@ -0,0 +1,56 @@
+#pragma once
+
+#include <Poco/Net/NetException.h>
+#include <Poco/Net/TCPServerConnection.h>
+#include <Server/ProxyV1Handler.h>
+#include <Common/logger_useful.h>
+#include <Server/IServer.h>
+#include <Server/TCPServer.h>
+#include <Server/TCPProtocolStackData.h>
+
+
+namespace DB
+{
+
+/// Connection factory that creates ProxyV1Handler instances.
+class ProxyV1HandlerFactory : public TCPServerConnectionFactory
+{
+private:
+ IServer & server;
+ Poco::Logger * log;
+ std::string conf_name;
+
+ /// Handler that does nothing; returned when the real handler cannot be created.
+ class DummyTCPHandler : public Poco::Net::TCPServerConnection
+ {
+ public:
+ using Poco::Net::TCPServerConnection::TCPServerConnection;
+ void run() override {}
+ };
+
+public:
+ /// `server_` is stored by reference; `conf_name_` is copied.
+ explicit ProxyV1HandlerFactory(IServer & server_, const std::string & conf_name_)
+ : server(server_), log(&Poco::Logger::get("ProxyV1HandlerFactory")), conf_name(conf_name_)
+ {
+ }
+
+ /// Creates a handler with stack data local to this call.
+ /// NOTE(review): ProxyV1Handler stores a reference to `stack_data`, which is destroyed when this
+ /// function returns, so a handler created through this overload holds a dangling reference by the
+ /// time run() is called. Presumably only the three-argument overload is used in practice - confirm.
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override
+ {
+ TCPProtocolStackData stack_data;
+ return createConnection(socket, tcp_server, stack_data);
+ }
+
+ /// Creates a handler that shares `stack_data` with the caller; the caller must keep it alive.
+ /// Returns a no-op handler if a Poco network exception is thrown while the handler is being created.
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/, TCPProtocolStackData & stack_data) override
+ {
+ try
+ {
+ LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
+ return new ProxyV1Handler(socket, server, conf_name, stack_data);
+ }
+ catch (const Poco::Net::NetException &)
+ {
+ LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent).");
+ return new DummyTCPHandler(socket);
+ }
+ }
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/ReplicasStatusHandler.cpp b/contrib/clickhouse/src/Server/ReplicasStatusHandler.cpp
new file mode 100644
index 0000000000..8c0ab0c1a3
--- /dev/null
+++ b/contrib/clickhouse/src/Server/ReplicasStatusHandler.cpp
@@ -0,0 +1,128 @@
+#include <Server/ReplicasStatusHandler.h>
+
+#include <Databases/IDatabase.h>
+#include <IO/HTTPCommon.h>
+#include <Interpreters/Context.h>
+#include <Server/HTTP/HTMLForm.h>
+#include <Server/HTTPHandlerFactory.h>
+#include <Server/HTTPHandlerRequestFilter.h>
+#include <Server/IServer.h>
+#include <Storages/StorageReplicatedMergeTree.h>
+#include <Common/typeid_cast.h>
+
+#include <Poco/Net/HTTPRequestHandlerFactory.h>
+#include <Poco/Net/HTTPServerRequest.h>
+#include <Poco/Net/HTTPServerResponse.h>
+
+
+namespace DB
+{
+
+/// Initializes the WithContext base with the context of `server`.
+ReplicasStatusHandler::ReplicasStatusHandler(IServer & server) : WithContext(server.context())
+{
+}
+
+/// Reports the replication delays of all replicated tables on this server.
+/// Replies "Ok.\n" when no table exceeds the configured delay thresholds and the request does not
+/// ask for details ("verbose=1"); otherwise sends one line per replicated table. When a threshold is
+/// exceeded the status is set to 503 Service Unavailable. Exceptions are logged and, if nothing has
+/// been sent yet, reported to the client with status 500.
+void ReplicasStatusHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
+{
+ try
+ {
+ HTMLForm params(getContext()->getSettingsRef(), request);
+
+ /// Even if lag is small, output detailed information about the lag.
+ bool verbose = params.get("verbose", "") == "1";
+
+ const MergeTreeSettings & settings = getContext()->getReplicatedMergeTreeSettings();
+
+ bool ok = true;
+ WriteBufferFromOwnString message;
+
+ auto databases = DatabaseCatalog::instance().getDatabases();
+
+ /// Iterate through all the replicated tables.
+ for (const auto & db : databases)
+ {
+ /// Check if database can contain replicated tables
+ if (!db.second->canContainMergeTreeTables())
+ continue;
+
+ for (auto iterator = db.second->getTablesIterator(getContext()); iterator->isValid(); iterator->next())
+ {
+ const auto & table = iterator->table();
+ if (!table)
+ continue;
+
+ /// Tables of any other storage type are skipped.
+ StorageReplicatedMergeTree * table_replicated = dynamic_cast<StorageReplicatedMergeTree *>(table.get());
+
+ if (!table_replicated)
+ continue;
+
+ time_t absolute_delay = 0;
+ time_t relative_delay = 0;
+
+ if (!table_replicated->isTableReadOnly())
+ {
+ table_replicated->getReplicaDelays(absolute_delay, relative_delay);
+
+ /// A threshold of zero disables the corresponding check.
+ if ((settings.min_absolute_delay_to_close && absolute_delay >= static_cast<time_t>(settings.min_absolute_delay_to_close))
+ || (settings.min_relative_delay_to_close && relative_delay >= static_cast<time_t>(settings.min_relative_delay_to_close)))
+ ok = false;
+
+ message << backQuoteIfNeed(db.first) << "." << backQuoteIfNeed(iterator->name())
+ << ":\tAbsolute delay: " << absolute_delay << ". Relative delay: " << relative_delay << ".\n";
+ }
+ else
+ {
+ /// Read-only tables are listed but do not affect the overall status.
+ message << backQuoteIfNeed(db.first) << "." << backQuoteIfNeed(iterator->name())
+ << ":\tis readonly. \n";
+ }
+ }
+ }
+
+ const auto & config = getContext()->getConfigRef();
+ setResponseDefaultHeaders(response, config.getUInt("keep_alive_timeout", 10));
+
+ if (!ok)
+ {
+ /// A lagging replica always produces the detailed output.
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_SERVICE_UNAVAILABLE);
+ verbose = true;
+ }
+
+ if (verbose)
+ *response.send() << message.str();
+ else
+ {
+ const char * data = "Ok.\n";
+ response.sendBuffer(data, strlen(data));
+ }
+ }
+ catch (...)
+ {
+ tryLogCurrentException("ReplicasStatusHandler");
+
+ try
+ {
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
+
+ if (!response.sent())
+ {
+ /// We have not sent anything yet and we don't even know if we need to compress response.
+ *response.send() << getCurrentExceptionMessage(false) << std::endl;
+ }
+ }
+ catch (...)
+ {
+ LOG_ERROR((&Poco::Logger::get("ReplicasStatusHandler")), "Cannot send exception to client");
+ }
+ }
+}
+
+/// Creates a rule-based HTTP handler factory that produces ReplicasStatusHandler instances;
+/// the request filters of the factory are loaded from `config_prefix`.
+HTTPRequestHandlerFactoryPtr createReplicasStatusHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ const std::string & config_prefix)
+{
+ auto factory = std::make_shared<HandlingRuleHTTPHandlerFactory<ReplicasStatusHandler>>(server);
+ factory->addFiltersFromConfig(config, config_prefix);
+ return factory;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/ReplicasStatusHandler.h b/contrib/clickhouse/src/Server/ReplicasStatusHandler.h
new file mode 100644
index 0000000000..1a5388aa2a
--- /dev/null
+++ b/contrib/clickhouse/src/Server/ReplicasStatusHandler.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandler.h>
+
+namespace DB
+{
+
+class Context;
+class IServer;
+
+/// Replies "Ok.\n" if none of the replicas on this server lags too much. Otherwise outputs lag information.
+class ReplicasStatusHandler : public HTTPRequestHandler, WithContext
+{
+public:
+ /// Takes the context from `server_`.
+ explicit ReplicasStatusHandler(IServer & server_);
+
+ /// Writes the replica status into `response`.
+ void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
+};
+
+
+}
diff --git a/contrib/clickhouse/src/Server/ServerType.cpp b/contrib/clickhouse/src/Server/ServerType.cpp
new file mode 100644
index 0000000000..fb052e7d6e
--- /dev/null
+++ b/contrib/clickhouse/src/Server/ServerType.cpp
@@ -0,0 +1,153 @@
+#include <Server/ServerType.h>
+
+#include <vector>
+#include <algorithm>
+
+#include <magic_enum.hpp>
+
+
+namespace DB
+{
+
+namespace
+{
+    /// Builds a table indexed by the numeric value of ServerType::Type that holds the
+    /// human readable name of every enumerator (underscores replaced by spaces).
+    std::vector<std::string> getTypeIndexToTypeName()
+    {
+        std::vector<std::string> names(magic_enum::enum_count<ServerType::Type>());
+
+        for (const auto & [value, raw_name] : magic_enum::enum_entries<ServerType::Type>())
+        {
+            String readable(raw_name);
+            std::replace(readable.begin(), readable.end(), '_', ' ');
+            names[static_cast<UInt64>(value)] = std::move(readable);
+        }
+
+        return names;
+    }
+}
+
+/// Returns the human readable name of `type` as a C string that stays valid for the lifetime of the program.
+const char * ServerType::serverTypeToString(ServerType::Type type)
+{
+    /** The returned string may be registered by the parser as an "expected" description (see IParser.h),
+      * so it has to be statically allocated: the lookup table is built once and never destroyed early.
+      */
+    static std::vector<std::string> names_by_index = getTypeIndexToTypeName();
+    return names_by_index[static_cast<UInt64>(type)].data();
+}
+
+/// Tells whether a server of type `server_type` (with `server_custom_name` for custom protocols)
+/// is selected by this ServerType. Exclusions are checked first and always win; after that the
+/// own `type` of this object decides, where the QUERIES_* values select whole groups of types.
+bool ServerType::shouldStart(Type server_type, const std::string & server_custom_name) const
+{
+    /// Membership test for the group selected by QUERIES_DEFAULT.
+    const auto belongs_to_default_group = [](Type candidate)
+    {
+        switch (candidate)
+        {
+            case Type::TCP:
+            case Type::TCP_WITH_PROXY:
+            case Type::TCP_SECURE:
+            case Type::HTTP:
+            case Type::HTTPS:
+            case Type::MYSQL:
+            case Type::GRPC:
+            case Type::POSTGRESQL:
+            case Type::PROMETHEUS:
+            case Type::INTERSERVER_HTTP:
+            case Type::INTERSERVER_HTTPS:
+                return true;
+            default:
+                return false;
+        }
+    };
+
+    const bool is_custom = server_type == Type::CUSTOM;
+
+    /// Group exclusions.
+    if (exclude_types.contains(Type::QUERIES_ALL))
+        return false;
+
+    if (exclude_types.contains(Type::QUERIES_DEFAULT) && belongs_to_default_group(server_type))
+        return false;
+
+    if (exclude_types.contains(Type::QUERIES_CUSTOM) && is_custom)
+        return false;
+
+    /// Exclusion of the exact type; a custom protocol is excluded only when its name is listed too.
+    if (exclude_types.contains(server_type) && (!is_custom || exclude_custom_names.contains(server_custom_name)))
+        return false;
+
+    switch (type)
+    {
+        case Type::QUERIES_ALL:
+            return true;
+        case Type::QUERIES_DEFAULT:
+            return belongs_to_default_group(server_type);
+        case Type::QUERIES_CUSTOM:
+            return is_custom;
+        case Type::CUSTOM:
+            return is_custom && server_custom_name == custom_name;
+        default:
+            return server_type == type;
+    }
+}
+
+/// Decides whether the server configured under the key `port_name` is selected by this ServerType.
+/// The configuration key is mapped to a server type; keys of the form "protocols.<name>.port" map to
+/// Type::CUSTOM with <name> as the custom name. The decision itself is delegated to shouldStart().
+/// Returns false for keys that do not correspond to any known server type.
+bool ServerType::shouldStop(const std::string & port_name) const
+{
+    Type port_type;
+    std::string port_custom_name;
+
+    if (port_name == "http_port")
+        port_type = Type::HTTP;
+
+    else if (port_name == "https_port")
+        port_type = Type::HTTPS;
+
+    else if (port_name == "tcp_port")
+        port_type = Type::TCP;
+
+    else if (port_name == "tcp_with_proxy_port")
+        port_type = Type::TCP_WITH_PROXY;
+
+    else if (port_name == "tcp_port_secure")
+        port_type = Type::TCP_SECURE;
+
+    else if (port_name == "mysql_port")
+        port_type = Type::MYSQL;
+
+    else if (port_name == "postgresql_port")
+        port_type = Type::POSTGRESQL;
+
+    else if (port_name == "grpc_port")
+        port_type = Type::GRPC;
+
+    else if (port_name == "prometheus.port")
+        port_type = Type::PROMETHEUS;
+
+    else if (port_name == "interserver_http_port")
+        port_type = Type::INTERSERVER_HTTP;
+
+    else if (port_name == "interserver_https_port")
+        port_type = Type::INTERSERVER_HTTPS;
+
+    else if (port_name.starts_with("protocols.") && port_name.ends_with(".port"))
+    {
+        port_type = Type::CUSTOM;
+
+        /// The suffix length is taken from the same literal that ends_with() checks above.
+        constexpr size_t prefix_size = std::string_view("protocols.").size();
+        constexpr size_t suffix_size = std::string_view(".port").size();
+
+        /// "protocols.port" passes both checks with the prefix and the suffix sharing one dot.
+        /// The name is left empty in that case instead of letting the unsigned length wrap around.
+        if (port_name.size() >= prefix_size + suffix_size)
+            port_custom_name = port_name.substr(prefix_size, port_name.size() - prefix_size - suffix_size);
+    }
+
+    else
+        return false;
+
+    return shouldStart(port_type, port_custom_name);
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/ServerType.h b/contrib/clickhouse/src/Server/ServerType.h
new file mode 100644
index 0000000000..e3544fe6a2
--- /dev/null
+++ b/contrib/clickhouse/src/Server/ServerType.h
@@ -0,0 +1,60 @@
+#pragma once
+
+#include <base/types.h>
+#include <unordered_set>
+
+namespace DB
+{
+
+/// Selects a server protocol type (or a group of types) together with optional exclusions.
+/// shouldStart() / shouldStop() test whether a concrete server matches the selection.
+class ServerType
+{
+public:
+    enum Type
+    {
+        TCP,
+        TCP_WITH_PROXY,
+        TCP_SECURE,
+        HTTP,
+        HTTPS,
+        MYSQL,
+        GRPC,
+        POSTGRESQL,
+        PROMETHEUS,
+        CUSTOM,
+        INTERSERVER_HTTP,
+        INTERSERVER_HTTPS,
+        /// The QUERIES_* values select groups of types rather than a single type.
+        QUERIES_ALL,
+        QUERIES_DEFAULT,
+        QUERIES_CUSTOM,
+        END
+    };
+
+    using Types = std::unordered_set<Type>;
+    using CustomNames = std::unordered_set<String>;
+
+    /// NOTE(review): leaves `type` uninitialized; it must be assigned before shouldStart() / shouldStop() are called.
+    ServerType() = default;
+
+    /// @param type_                  selected type or group.
+    /// @param custom_name_           name of the custom protocol; compared only when `type_` is CUSTOM.
+    /// @param exclude_types_         types or groups that are never selected.
+    /// @param exclude_custom_names_  custom protocol names that are excluded when CUSTOM is in `exclude_types_`.
+    explicit ServerType(
+        Type type_,
+        const std::string & custom_name_ = "",
+        const Types & exclude_types_ = {},
+        const CustomNames & exclude_custom_names_ = {})
+        : type(type_),
+          custom_name(custom_name_),
+          exclude_types(exclude_types_),
+          exclude_custom_names(exclude_custom_names_) {}
+
+    /// Returns a statically allocated human readable name of `type`.
+    static const char * serverTypeToString(Type type);
+
+    /// Checks whether provided in the arguments type should be started or stopped based on current server type.
+    bool shouldStart(Type server_type, const std::string & server_custom_name = "") const;
+    bool shouldStop(const std::string & port_name) const;
+
+    Type type;
+    std::string custom_name;
+
+    Types exclude_types;
+    CustomNames exclude_custom_names;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/StaticRequestHandler.cpp b/contrib/clickhouse/src/Server/StaticRequestHandler.cpp
new file mode 100644
index 0000000000..13a01ba813
--- /dev/null
+++ b/contrib/clickhouse/src/Server/StaticRequestHandler.cpp
@@ -0,0 +1,179 @@
+#include "StaticRequestHandler.h"
+#include "IServer.h"
+
+#include "HTTPHandlerFactory.h"
+#include "HTTPHandlerRequestFilter.h"
+
+#include <IO/HTTPCommon.h>
+#include <IO/ReadBufferFromFile.h>
+#include <IO/WriteBufferFromString.h>
+#include <IO/copyData.h>
+#include <IO/WriteHelpers.h>
+#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
+#include <Interpreters/Context.h>
+
+#include <Common/Exception.h>
+
+#include <Poco/Net/HTTPServerRequest.h>
+#include <Poco/Net/HTTPServerResponse.h>
+#include <Poco/Net/HTTPRequestHandlerFactory.h>
+#include <Poco/Util/LayeredConfiguration.h>
+#include <filesystem>
+
+
+namespace fs = std::filesystem;
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int INCORRECT_FILE_NAME;
+ extern const int HTTP_LENGTH_REQUIRED;
+ extern const int INVALID_CONFIG_PARAMETER;
+}
+
+/// Creates the write buffer for `response`, enabling HTTP compression when the client
+/// advertises a supported method in the "Accept-Encoding" header (gzip, deflate, ...).
+static inline WriteBufferPtr
+responseWriteBuffer(HTTPServerRequest & request, HTTPServerResponse & response, unsigned int keep_alive_timeout)
+{
+    const String accepted_encodings = request.get("Accept-Encoding", "");
+
+    /// Without the header there is nothing to choose from.
+    const CompressionMethod compression_method = accepted_encodings.empty()
+        ? CompressionMethod::None
+        : chooseHTTPCompressionMethod(accepted_encodings);
+
+    const bool compression_supported = compression_method != CompressionMethod::None;
+    const bool is_head_request = request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD;
+
+    return std::make_shared<WriteBufferFromHTTPServerResponse>(
+        response, is_head_request, keep_alive_timeout, compression_supported, compression_method);
+}
+
+/// Best-effort delivery of the exception text `s` to the client with status 500.
+/// The exception code is put into the "X-ClickHouse-Exception-Code" header. If the response has not
+/// been sent yet the message is sent directly; otherwise it is written through `out`, which is then
+/// finalized. Any failure while doing so is logged and not propagated.
+static inline void trySendExceptionToClient(
+ const std::string & s, int exception_code, HTTPServerRequest & request, HTTPServerResponse & response, WriteBuffer & out)
+{
+ try
+ {
+ response.set("X-ClickHouse-Exception-Code", toString<int>(exception_code));
+
+ /// If HTTP method is POST and Keep-Alive is turned on, we should read the whole request body
+ /// to avoid reading part of the current request body in the next request.
+ if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST
+ && response.getKeepAlive() && !request.getStream().eof() && exception_code != ErrorCodes::HTTP_LENGTH_REQUIRED)
+ request.getStream().ignore(std::numeric_limits<std::streamsize>::max());
+
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
+
+ if (!response.sent())
+ *response.send() << s << std::endl;
+ else
+ {
+ /// Rewind the write position to the start of the buffer when count() and offset() differ
+ /// (presumably to drop data that is buffered but not yet sent - confirm).
+ if (out.count() != out.offset())
+ out.position() = out.buffer().begin();
+
+ writeString(s, out);
+ writeChar('\n', out);
+
+ out.next();
+ out.finalize();
+ }
+ }
+ catch (...)
+ {
+ tryLogCurrentException("StaticRequestHandler", "Cannot send exception to client");
+ }
+}
+
+/// Sends the configured static response: sets the content type, the default headers and the
+/// configured status, then writes the body through writeResponse(). On any exception the error
+/// is logged and reported to the client via trySendExceptionToClient().
+void StaticRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
+{
+ auto keep_alive_timeout = server.config().getUInt("keep_alive_timeout", 10);
+ /// The temporary shared pointer is bound to a const reference, which keeps it alive until the end of the function.
+ const auto & out = responseWriteBuffer(request, response, keep_alive_timeout);
+
+ try
+ {
+ response.setContentType(content_type);
+
+ if (request.getVersion() == Poco::Net::HTTPServerRequest::HTTP_1_1)
+ response.setChunkedTransferEncoding(true);
+
+ /// Workaround. Poco does not detect 411 Length Required case.
+ if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST && !request.getChunkedTransferEncoding() && !request.hasContentLength())
+ throw Exception(ErrorCodes::HTTP_LENGTH_REQUIRED,
+ "The Transfer-Encoding is not chunked and there "
+ "is no Content-Length header for POST request");
+
+ setResponseDefaultHeaders(response, keep_alive_timeout);
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTPStatus(status));
+ writeResponse(*out);
+ }
+ catch (...)
+ {
+ tryLogCurrentException("StaticRequestHandler");
+
+ int exception_code = getCurrentExceptionCode();
+ std::string exception_message = getCurrentExceptionMessage(false, true);
+ trySendExceptionToClient(exception_message, exception_code, request, response, *out);
+ }
+
+ /// NOTE(review): on the error path trySendExceptionToClient() may already have finalized the buffer;
+ /// presumably finalize() is safe to call twice - confirm.
+ out->finalize();
+}
+
+/// Writes the response body into `out`. The configured expression is interpreted as:
+///  - "file://<path>"   - contents of the file, resolved relative to the user files directory;
+///  - "config://<key>"  - raw value of the configuration key ("Ok.\n" when the key is absent);
+///  - anything else     - the expression itself.
+void StaticRequestHandler::writeResponse(WriteBuffer & out)
+{
+    static const String file_prefix = "file://";
+    static const String config_prefix = "config://";
+
+    if (startsWith(response_expression, file_prefix))
+    {
+        /// A leading slash is dropped so that the path is always appended to the user files directory.
+        String relative_name = response_expression.substr(file_prefix.size());
+        if (relative_name.starts_with('/'))
+            relative_name = relative_name.substr(1);
+
+        const fs::path user_files_root = fs::canonical(fs::path(server.context()->getUserFilesPath()));
+        String file_path = fs::weakly_canonical(user_files_root / relative_name);
+
+        if (!fs::exists(file_path))
+            throw Exception(ErrorCodes::INCORRECT_FILE_NAME, "Invalid file name {} for static HTTPHandler. ", file_path);
+
+        ReadBufferFromFile in(file_path);
+        copyData(in, out);
+        return;
+    }
+
+    if (startsWith(response_expression, config_prefix))
+    {
+        /// The prefix alone does not name a configuration key.
+        if (response_expression.size() <= config_prefix.size())
+            throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER,
+                "Static handling rule handler must contain a complete configuration path, for example: "
+                "config://config_key");
+
+        const String config_path = response_expression.substr(config_prefix.size());
+        writeString(server.config().getRawString(config_path, "Ok.\n"), out);
+        return;
+    }
+
+    writeString(response_expression, out);
+}
+
+/// Stores a reference to `server_` and copies the response expression, status and content type.
+StaticRequestHandler::StaticRequestHandler(IServer & server_, const String & expression, int status_, const String & content_type_)
+ : server(server_), status(status_), content_type(content_type_), response_expression(expression)
+{
+}
+
+/// Creates a rule-based HTTP handler factory that produces StaticRequestHandler instances.
+/// Status (default 200), response content (default "Ok.\n") and content type
+/// (default "text/plain; charset=UTF-8") are read from the "<config_prefix>.handler" section;
+/// the request filters of the factory are loaded from `config_prefix`.
+/// NOTE(review): all three values are passed as rvalues, including the int; presumably the factory
+/// deduces its stored argument types from the value category - confirm before removing any std::move.
+HTTPRequestHandlerFactoryPtr createStaticHandlerFactory(IServer & server,
+ const Poco::Util::AbstractConfiguration & config,
+ const std::string & config_prefix)
+{
+ int status = config.getInt(config_prefix + ".handler.status", 200);
+ std::string response_content = config.getRawString(config_prefix + ".handler.response_content", "Ok.\n");
+ std::string response_content_type = config.getString(config_prefix + ".handler.content_type", "text/plain; charset=UTF-8");
+ auto factory = std::make_shared<HandlingRuleHTTPHandlerFactory<StaticRequestHandler>>(
+ server, std::move(response_content), std::move(status), std::move(response_content_type));
+
+ factory->addFiltersFromConfig(config, config_prefix);
+
+ return factory;
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/StaticRequestHandler.h b/contrib/clickhouse/src/Server/StaticRequestHandler.h
new file mode 100644
index 0000000000..df9374d440
--- /dev/null
+++ b/contrib/clickhouse/src/Server/StaticRequestHandler.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandler.h>
+#include <base/types.h>
+
+
+namespace DB
+{
+
+class IServer;
+class WriteBuffer;
+
+/// Responds with a custom, statically configured string. Can be used to serve content to a browser.
+class StaticRequestHandler : public HTTPRequestHandler
+{
+private:
+ IServer & server;
+
+ /// HTTP status code of the response.
+ int status;
+ /// Value of the Content-Type header.
+ String content_type;
+ /// Literal body, or a "file://" / "config://" expression that writeResponse() resolves.
+ String response_expression;
+
+public:
+ StaticRequestHandler(
+ IServer & server,
+ const String & expression,
+ int status_ = 200,
+ const String & content_type_ = "text/html; charset=UTF-8");
+
+ /// Writes the response body into `out`.
+ void writeResponse(WriteBuffer & out);
+
+ /// Sends the complete response, including headers and status.
+ void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/TCPHandler.cpp b/contrib/clickhouse/src/Server/TCPHandler.cpp
new file mode 100644
index 0000000000..136f2dd953
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TCPHandler.cpp
@@ -0,0 +1,2151 @@
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <mutex>
+#include <vector>
+#include <string_view>
+#include <cstring>
+#include <base/types.h>
+#include <base/scope_guard.h>
+#include <Poco/Net/NetException.h>
+#include <Poco/Net/SocketAddress.h>
+#include <Poco/Util/LayeredConfiguration.h>
+#include <Common/CurrentThread.h>
+#include <Common/Stopwatch.h>
+#include <Common/NetException.h>
+#include <Common/setThreadName.h>
+#include <Common/OpenSSLHelpers.h>
+#include <IO/Progress.h>
+#include <Compression/CompressedReadBuffer.h>
+#include <Compression/CompressedWriteBuffer.h>
+#include <IO/ReadBufferFromPocoSocket.h>
+#include <IO/WriteBufferFromPocoSocket.h>
+#include <IO/LimitReadBuffer.h>
+#include <IO/ReadHelpers.h>
+#include <IO/WriteHelpers.h>
+#include <Formats/NativeReader.h>
+#include <Formats/NativeWriter.h>
+#include <Interpreters/executeQuery.h>
+#include <Interpreters/TablesStatus.h>
+#include <Interpreters/InternalTextLogsQueue.h>
+#include <Interpreters/OpenTelemetrySpanLog.h>
+#include <Interpreters/Session.h>
+#include <Server/TCPServer.h>
+#include <Storages/StorageReplicatedMergeTree.h>
+#include <Storages/MergeTree/MergeTreeDataPartUUID.h>
+#include <Storages/StorageS3Cluster.h>
+#include <Core/ExternalTable.h>
+#include <Core/ServerSettings.h>
+#include <Access/AccessControl.h>
+#include <Access/Credentials.h>
+#include <DataTypes/DataTypeLowCardinality.h>
+#include <Compression/CompressionFactory.h>
+#include <Common/logger_useful.h>
+#include <Common/CurrentMetrics.h>
+#include <Common/thread_local_rng.h>
+#include <fmt/format.h>
+
+#include <Processors/Executors/PullingAsyncPipelineExecutor.h>
+#include <Processors/Executors/PushingPipelineExecutor.h>
+#include <Processors/Executors/PushingAsyncPipelineExecutor.h>
+#include <Processors/Executors/CompletedPipelineExecutor.h>
+#include <Processors/Sinks/SinkToStorage.h>
+
+#if USE_SSL
+# include <Poco/Net/SecureStreamSocket.h>
+# include <Poco/Net/SecureStreamSocketImpl.h>
+#endif
+
+#include "Core/Protocol.h"
+#include "Storages/MergeTree/RequestResponse.h"
+#include "TCPHandler.h"
+
+#include "config_version.h"
+
+using namespace std::literals;
+using namespace DB;
+
+
+/// Metrics used by the handler below. They are only declared here (`extern`).
+namespace CurrentMetrics
+{
+ extern const Metric QueryThread;
+ extern const Metric ReadTaskRequestsSent;
+ extern const Metric MergeTreeReadTaskRequestsSent;
+ extern const Metric MergeTreeAllRangesAnnouncementsSent;
+}
+
+/// Profile events incremented by the task/announcement callbacks set up in TCPHandler::runImpl.
+/// They are only declared here (`extern`).
+namespace ProfileEvents
+{
+ extern const Event ReadTaskRequestsSent;
+ extern const Event MergeTreeReadTaskRequestsSent;
+ extern const Event MergeTreeAllRangesAnnouncementsSent;
+ extern const Event ReadTaskRequestsSentElapsedMicroseconds;
+ extern const Event MergeTreeReadTaskRequestsSentElapsedMicroseconds;
+ extern const Event MergeTreeAllRangesAnnouncementsSentElapsedMicroseconds;
+}
+
+/// Error codes thrown or inspected in this file. They are only declared here (`extern`).
+namespace DB::ErrorCodes
+{
+ extern const int LOGICAL_ERROR;
+ extern const int ATTEMPT_TO_READ_AFTER_EOF;
+ extern const int CLIENT_HAS_CONNECTED_TO_WRONG_PORT;
+ extern const int UNKNOWN_EXCEPTION;
+ extern const int UNKNOWN_PACKET_FROM_CLIENT;
+ extern const int POCO_EXCEPTION;
+ extern const int SOCKET_TIMEOUT;
+ extern const int UNEXPECTED_PACKET_FROM_CLIENT;
+ extern const int UNKNOWN_PROTOCOL;
+ extern const int AUTHENTICATION_FAILED;
+ extern const int QUERY_WAS_CANCELLED;
+ extern const int CLIENT_INFO_DOES_NOT_MATCH;
+}
+
+namespace
+{
+/// Converts settings passed by the client into a name -> value map of query parameters.
+/// Each setting's string value is parsed with readQuoted before being stored.
+/// If a name occurs more than once, the first occurrence wins (emplace does not overwrite).
+NameToNameMap convertToQueryParameters(const Settings & passed_params)
+{
+    NameToNameMap query_parameters;
+    for (const auto & param : passed_params)
+    {
+        std::string value;
+        ReadBufferFromOwnString buf(param.getValueString());
+        readQuoted(value, buf);
+        /// `value` is not used after this point, so hand it over instead of copying it.
+        query_parameters.emplace(param.getName(), std::move(value));
+    }
+    return query_parameters;
+}
+
+// This function corrects the wrong client_name from the old client.
+// Old clients were sending different ClientInfo.client_name in different messages:
+// "ClickHouse client" was sent with the hello message.
+// "ClickHouse" or "ClickHouse " was sent with the query message.
+// NOTE(review): this comment used to name the affected versions as "28.7", but the check below
+// applies to client versions up to and including 23.8.1 - presumably "23.8" was meant; confirm.
+//
+// `client_info.client_name` is overwritten with "ClickHouse client" only when all three conditions hold:
+// the version is in range, the session (hello) name is "ClickHouse client", and the query name is one of the two short forms.
+void correctQueryClientInfo(const ClientInfo & session_client_info, ClientInfo & client_info)
+{
+ if (client_info.getVersionNumber() <= VersionNumber(23, 8, 1) &&
+ session_client_info.client_name == "ClickHouse client" &&
+ (client_info.client_name == "ClickHouse" || client_info.client_name == "ClickHouse "))
+ {
+ client_info.client_name = "ClickHouse client";
+ }
+}
+
+/// Checks that the client info received with a query agrees with the client info of the session.
+/// Throws CLIENT_INFO_DOES_NOT_MATCH when the interface differs, or (for the TCP interface)
+/// when the client name or the client version differs. Secondary queries are not checked.
+void validateClientInfo(const ClientInfo & session_client_info, const ClientInfo & client_info)
+{
+    // A secondary query is allowed to carry a different client_info: for a select from a distributed table,
+    // or 'select * from remote' issued through a non-tcp handler, the server forwards the initial client_info.
+    //
+    // Example 1: curl -q -s --max-time 60 -sS "http://127.0.0.1:8123/?" -d "SELECT 1 FROM remote('127.0.0.1', system.one)"
+    //   The HTTP handler opens a TCP connection to remote 127.0.0.1 (the session there uses the TCP interface)
+    //   and sends client_info with the HTTP interface and HTTP data in the Protocol::Client::Query message.
+    //
+    // Example 2: select * from <distributed_table> --host shard_1 // the table has 2 shards: shard_1, shard_2
+    //   shard_1 receives a message with the 'ClickHouse client' client_name,
+    //   opens a TCP connection to shard_2 with the 'ClickHouse server' client_name,
+    //   and sends 'ClickHouse client' as client_name in the Protocol::Client::Query message to shard_2.
+    const bool is_secondary_query = client_info.query_kind == ClientInfo::QueryKind::SECONDARY_QUERY;
+    if (is_secondary_query)
+        return;
+
+    if (session_client_info.interface != client_info.interface)
+    {
+        throw Exception(
+            DB::ErrorCodes::CLIENT_INFO_DOES_NOT_MATCH,
+            "Client info's interface does not match: {} not equal to {}",
+            toString(session_client_info.interface),
+            toString(client_info.interface));
+    }
+
+    // The remaining checks apply to TCP sessions only.
+    if (session_client_info.interface != ClientInfo::Interface::TCP)
+        return;
+
+    if (session_client_info.client_name != client_info.client_name)
+    {
+        throw Exception(
+            DB::ErrorCodes::CLIENT_INFO_DOES_NOT_MATCH,
+            "Client info's client_name does not match: {} not equal to {}",
+            session_client_info.client_name,
+            client_info.client_name);
+    }
+
+    // The TCP handler always gets patch version 0 for backward compatibility, hence the `false` argument.
+    const bool same_version = session_client_info.clientVersionEquals(client_info, false);
+    if (!same_version)
+    {
+        throw Exception(
+            DB::ErrorCodes::CLIENT_INFO_DOES_NOT_MATCH,
+            "Client info's version does not match: {} not equal to {}",
+            session_client_info.getVersionStr(),
+            client_info.getVersionStr());
+    }
+
+    // os_user, quota_key and client_trace_context are allowed to differ.
+}
+}
+
+namespace DB
+{
+
+/// Constructs a handler for a plain connection.
+/// `parse_proxy_protocol_` controls whether runImpl() reads a PROXY protocol header before the hello packet.
+/// `server_display_name_` is taken by value and moved into the member.
+TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool parse_proxy_protocol_, std::string server_display_name_)
+ : Poco::Net::TCPServerConnection(socket_)
+ , server(server_)
+ , tcp_server(tcp_server_)
+ , parse_proxy_protocol(parse_proxy_protocol_)
+ , log(&Poco::Logger::get("TCPHandler"))
+ , server_display_name(std::move(server_display_name_))
+{
+}
+
+/// Constructs a handler from protocol stack data.
+/// The forwarded-for address, certificate and default database are copied from `stack_data`;
+/// `parse_proxy_protocol` is not set by this constructor.
+TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, TCPProtocolStackData & stack_data, std::string server_display_name_)
+: Poco::Net::TCPServerConnection(socket_)
+ , server(server_)
+ , tcp_server(tcp_server_)
+ , log(&Poco::Logger::get("TCPHandler"))
+ , forwarded_for(stack_data.forwarded_for)
+ , certificate(stack_data.certificate)
+ , default_database(stack_data.default_database)
+ , server_display_name(std::move(server_display_name_))
+{
+ /// Trace the forwarded client address when one was supplied.
+ if (!forwarded_for.empty())
+ LOG_TRACE(log, "Forwarded client address: {}", forwarded_for);
+}
+
+/// Resets the query state and flushes the output buffer if one was created.
+/// Any exception is logged and swallowed so that it does not escape the destructor.
+TCPHandler::~TCPHandler()
+{
+ try
+ {
+ state.reset();
+ /// `out` is created in runImpl(), so it may still be null here.
+ if (out)
+ out->next();
+ }
+ catch (...)
+ {
+ tryLogCurrentException(__PRETTY_FUNCTION__);
+ }
+}
+
+/// Main loop of a client connection.
+///
+/// 1. Applies connection settings to the socket and creates the in/out buffers.
+/// 2. Optionally reads the PROXY header, then performs the hello handshake (and the addendum for new enough clients).
+/// 3. While the TCP server is open: waits for a packet, runs the received query, and sends either
+///    the end-of-stream marker or the exception to the client.
+///
+/// The loop ends when the server is shutting down, the client disconnects, the connection is idle for longer than
+/// `idle_connection_timeout`, or a network error is detected. UNKNOWN_PACKET_FROM_CLIENT and handshake
+/// failures (other than the two handled codes) are rethrown to the caller.
+void TCPHandler::runImpl()
+{
+ setThreadName("TCPHandler");
+ ThreadStatus thread_status;
+
+ extractConnectionSettingsFromContext(server.context());
+
+ socket().setReceiveTimeout(receive_timeout);
+ socket().setSendTimeout(send_timeout);
+ socket().setNoDelay(true);
+
+ in = std::make_shared<ReadBufferFromPocoSocket>(socket());
+ out = std::make_shared<WriteBufferFromPocoSocket>(socket());
+
+ /// Support for PROXY protocol
+ if (parse_proxy_protocol && !receiveProxyHeader())
+ return;
+
+ if (in->eof())
+ {
+ LOG_INFO(log, "Client has not sent any data.");
+ return;
+ }
+
+ /// User will be authenticated here. It will also set settings from user profile into connection_context.
+ try
+ {
+ receiveHello();
+
+ /// In interserver mode queries are executed without a session context.
+ if (!is_interserver_mode)
+ session->makeSessionContext();
+
+ sendHello();
+ if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM)
+ receiveAddendum();
+
+ if (!is_interserver_mode)
+ {
+ /// If session created, then settings in session context has been updated.
+ /// So it's better to update the connection settings for flexibility.
+ extractConnectionSettingsFromContext(session->sessionContext());
+
+ /// When connecting, the default database could be specified.
+ if (!default_database.empty())
+ session->sessionContext()->setCurrentDatabase(default_database);
+ }
+ }
+ catch (const Exception & e) /// Typical for an incorrect username, password, or address.
+ {
+ if (e.code() == ErrorCodes::CLIENT_HAS_CONNECTED_TO_WRONG_PORT)
+ {
+ LOG_DEBUG(log, "Client has connected to wrong port.");
+ return;
+ }
+
+ if (e.code() == ErrorCodes::ATTEMPT_TO_READ_AFTER_EOF)
+ {
+ LOG_INFO(log, "Client has gone away.");
+ return;
+ }
+
+ try
+ {
+ /// We try to send error information to the client.
+ sendException(e, send_exception_with_stack_trace);
+ }
+ /// Best effort only: a failure to deliver the error must not replace the original exception rethrown below.
+ catch (...) {}
+
+ throw;
+ }
+
+ while (tcp_server.isOpen())
+ {
+ /// We are waiting for a packet from the client. Thus, every `poll_interval` seconds check whether we need to shut down.
+ {
+ Stopwatch idle_time;
+ /// Both settings are in seconds; the product is in microseconds (same formula as in readDataNext()).
+ UInt64 timeout_us = std::min(poll_interval, idle_connection_timeout) * 1000000;
+ while (tcp_server.isOpen() && !server.isCancelled() && !static_cast<ReadBufferFromPocoSocket &>(*in).poll(timeout_us))
+ {
+ if (idle_time.elapsedSeconds() > idle_connection_timeout)
+ {
+ LOG_TRACE(log, "Closing idle connection");
+ return;
+ }
+ }
+ }
+
+ /// If we need to shut down, or client disconnects.
+ if (!tcp_server.isOpen() || server.isCancelled() || in->eof())
+ {
+ LOG_TEST(log, "Closing connection (open: {}, cancelled: {}, eof: {})", tcp_server.isOpen(), server.isCancelled(), in->eof());
+ break;
+ }
+
+ state.reset();
+
+ /// Initialized later.
+ std::optional<CurrentThread::QueryScope> query_scope;
+ OpenTelemetry::TracingContextHolderPtr thread_trace_context;
+
+ /** An exception during the execution of request (it must be sent over the network to the client).
+ * The client will be able to accept it, if it did not happen while sending another packet and the client has not disconnected yet.
+ */
+ std::unique_ptr<DB::Exception> exception;
+ bool network_error = false;
+ bool query_duration_already_logged = false;
+ /// Logs the query duration at most once per loop iteration; it is called on both the success and the failure path.
+ auto log_query_duration = [this, &query_duration_already_logged]()
+ {
+ if (query_duration_already_logged)
+ return;
+ query_duration_already_logged = true;
+ auto elapsed_sec = state.watch.elapsedSeconds();
+ /// We already logged more detailed info if we read some rows
+ if (elapsed_sec < 1.0 && state.progress.read_rows)
+ return;
+ LOG_DEBUG(log, "Processed in {} sec.", elapsed_sec);
+ };
+
+ try
+ {
+ /// If a user passed query-local timeouts, reset socket to initial state at the end of the query
+ SCOPE_EXIT({state.timeout_setter.reset();});
+
+ /** If Query - process it. If Ping or Cancel - go back to the beginning.
+ * There may come settings for a separate query that modify `query_context`.
+ * It's possible to receive part uuids packet before the query, so then receivePacket has to be called twice.
+ */
+ if (!receivePacket())
+ continue;
+
+ /** If part_uuids got received in previous packet, trying to read again.
+ */
+ if (state.empty() && state.part_uuids_to_ignore && !receivePacket())
+ continue;
+
+ /// Set up tracing context for this query on current thread
+ thread_trace_context = std::make_unique<OpenTelemetry::TracingContextHolder>("TCPHandler",
+ query_context->getClientInfo().client_trace_context,
+ query_context->getSettingsRef(),
+ query_context->getOpenTelemetrySpanLog());
+ thread_trace_context->root_span.kind = OpenTelemetry::SERVER;
+
+ query_scope.emplace(query_context, /* fatal_error_callback */ [this]
+ {
+ std::lock_guard lock(fatal_error_mutex);
+ sendLogs();
+ });
+
+ /// If query received, then settings in query_context has been updated.
+ /// So it's better to update the connection settings for flexibility.
+ extractConnectionSettingsFromContext(query_context);
+
+ /// Sync timeouts on client and server during current query to avoid dangling queries on server
+ /// NOTE: We use send_timeout for the receive timeout and vice versa (change arguments ordering in TimeoutSetter),
+ /// because send_timeout is client-side setting which has opposite meaning on the server side.
+ /// NOTE: these settings are applied only for current connection (not for distributed tables' connections)
+ state.timeout_setter = std::make_unique<TimeoutSetter>(socket(), receive_timeout, send_timeout);
+
+ /// Should we send internal logs to client?
+ const auto client_logs_level = query_context->getSettingsRef().send_logs_level;
+ if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SERVER_LOGS
+ && client_logs_level != LogsLevel::none)
+ {
+ state.logs_queue = std::make_shared<InternalTextLogsQueue>();
+ state.logs_queue->max_priority = Poco::Logger::parseLevel(client_logs_level.toString());
+ state.logs_queue->setSourceRegexp(query_context->getSettingsRef().send_logs_source_regexp);
+ CurrentThread::attachInternalTextLogsQueue(state.logs_queue, client_logs_level);
+ }
+ if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_INCREMENTAL_PROFILE_EVENTS)
+ {
+ state.profile_queue = std::make_shared<InternalProfileEventsQueue>(std::numeric_limits<int>::max());
+ CurrentThread::attachInternalProfileEventsQueue(state.profile_queue);
+ }
+
+ query_context->setExternalTablesInitializer([this] (ContextPtr context)
+ {
+ if (context != query_context)
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in external tables initializer");
+
+ /// Get blocks of temporary tables
+ readData();
+
+ /// Reset the input stream, as we received an empty block while receiving external table data.
+ /// So, the stream has been marked as cancelled and we can't read from it anymore.
+ state.block_in.reset();
+ state.maybe_compressed_in.reset(); /// For more accurate accounting by MemoryTracker.
+ });
+
+ /// Send structure of columns to client for function input()
+ query_context->setInputInitializer([this] (ContextPtr context, const StoragePtr & input_storage)
+ {
+ if (context != query_context)
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in Input initializer");
+
+ auto metadata_snapshot = input_storage->getInMemoryMetadataPtr();
+ state.need_receive_data_for_input = true;
+
+ /// Send ColumnsDescription for input storage.
+ if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_COLUMN_DEFAULTS_METADATA
+ && query_context->getSettingsRef().input_format_defaults_for_omitted_fields)
+ {
+ sendTableColumns(metadata_snapshot->getColumns());
+ }
+
+ /// Send block to the client - input storage structure.
+ state.input_header = metadata_snapshot->getSampleBlock();
+ sendData(state.input_header);
+ sendTimezone();
+ });
+
+ query_context->setInputBlocksReaderCallback([this] (ContextPtr context) -> Block
+ {
+ if (context != query_context)
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in InputBlocksReader");
+
+ if (!readDataNext())
+ {
+ state.block_in.reset();
+ state.maybe_compressed_in.reset();
+ return Block();
+ }
+ return state.block_for_input;
+ });
+
+ customizeContext(query_context);
+
+ /// This callback is needed for requesting read tasks inside pipeline for distributed processing
+ query_context->setReadTaskCallback([this]() -> String
+ {
+ Stopwatch watch;
+ CurrentMetrics::Increment callback_metric_increment(CurrentMetrics::ReadTaskRequestsSent);
+
+ std::lock_guard lock(task_callback_mutex);
+
+ if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED)
+ return {};
+
+ sendReadTaskRequestAssumeLocked();
+ ProfileEvents::increment(ProfileEvents::ReadTaskRequestsSent);
+ auto res = receiveReadTaskResponseAssumeLocked();
+ ProfileEvents::increment(ProfileEvents::ReadTaskRequestsSentElapsedMicroseconds, watch.elapsedMicroseconds());
+ return res;
+ });
+
+ query_context->setMergeTreeAllRangesCallback([this](InitialAllRangesAnnouncement announcement)
+ {
+ Stopwatch watch;
+ CurrentMetrics::Increment callback_metric_increment(CurrentMetrics::MergeTreeAllRangesAnnouncementsSent);
+ std::lock_guard lock(task_callback_mutex);
+
+ if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED)
+ return;
+
+ sendMergeTreeAllRangesAnnounecementAssumeLocked(announcement);
+ ProfileEvents::increment(ProfileEvents::MergeTreeAllRangesAnnouncementsSent);
+ ProfileEvents::increment(ProfileEvents::MergeTreeAllRangesAnnouncementsSentElapsedMicroseconds, watch.elapsedMicroseconds());
+ });
+
+ query_context->setMergeTreeReadTaskCallback([this](ParallelReadRequest request) -> std::optional<ParallelReadResponse>
+ {
+ Stopwatch watch;
+ CurrentMetrics::Increment callback_metric_increment(CurrentMetrics::MergeTreeReadTaskRequestsSent);
+ std::lock_guard lock(task_callback_mutex);
+
+ if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED)
+ return std::nullopt;
+
+ sendMergeTreeReadTaskRequestAssumeLocked(std::move(request));
+ ProfileEvents::increment(ProfileEvents::MergeTreeReadTaskRequestsSent);
+ auto res = receivePartitionMergeTreeReadTaskResponseAssumeLocked();
+ ProfileEvents::increment(ProfileEvents::MergeTreeReadTaskRequestsSentElapsedMicroseconds, watch.elapsedMicroseconds());
+ return res;
+ });
+
+ /// Processing Query
+ state.io = executeQuery(state.query, query_context, false, state.stage);
+
+ after_check_cancelled.restart();
+ after_send_progress.restart();
+
+ /// Picks the pipeline finalizer that matches how the query ended.
+ auto finish_or_cancel = [this]()
+ {
+ if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED)
+ state.io.onCancelOrConnectionLoss();
+ else
+ state.io.onFinish();
+ };
+
+ if (state.io.pipeline.pushing())
+ {
+ /// FIXME: check explicitly that insert query suggests to receive data via native protocol,
+ state.need_receive_data_for_insert = true;
+ processInsertQuery();
+ finish_or_cancel();
+ }
+ else if (state.io.pipeline.pulling())
+ {
+ processOrdinaryQueryWithProcessors();
+ finish_or_cancel();
+ }
+ else if (state.io.pipeline.completed())
+ {
+ {
+ CompletedPipelineExecutor executor(state.io.pipeline);
+
+ /// Should not check for cancel in case of input.
+ if (!state.need_receive_data_for_input)
+ {
+ auto callback = [this]()
+ {
+ std::scoped_lock lock(task_callback_mutex, fatal_error_mutex);
+
+ if (getQueryCancellationStatus() == CancellationStatus::FULLY_CANCELLED)
+ return true;
+
+ sendProgress();
+ sendSelectProfileEvents();
+ sendLogs();
+
+ return false;
+ };
+
+ executor.setCancelCallback(callback, interactive_delay / 1000);
+ }
+ executor.execute();
+ }
+
+ finish_or_cancel();
+
+ std::lock_guard lock(task_callback_mutex);
+
+ /// Send final progress after calling onFinish(), since it will update the progress.
+ ///
+ /// NOTE: we cannot send Progress for regular INSERT (with VALUES)
+ /// without breaking protocol compatibility, but it can be done
+ /// by increasing revision.
+ sendProgress();
+ sendSelectProfileEvents();
+ }
+ else
+ {
+ finish_or_cancel();
+ }
+
+ /// Do it before sending end of stream, to have a chance to show log message in client.
+ query_scope->logPeakMemoryUsage();
+ log_query_duration();
+
+ if (state.is_connection_closed)
+ break;
+
+ {
+ std::lock_guard lock(task_callback_mutex);
+ sendLogs();
+ sendEndOfStream();
+ }
+
+ /// QueryState should be cleared before QueryScope, since otherwise
+ /// the MemoryTracker will be wrong for possible deallocations.
+ /// (i.e. deallocations from the Aggregator with two-level aggregation)
+ state.reset();
+ last_sent_snapshots = ProfileEvents::ThreadIdToCountersSnapshot{};
+ query_scope.reset();
+ thread_trace_context.reset();
+ }
+ catch (const Exception & e)
+ {
+ state.io.onException();
+ exception.reset(e.clone());
+
+ if (e.code() == ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT)
+ throw;
+
+ /// If there is UNEXPECTED_PACKET_FROM_CLIENT emulate network_error
+ /// to break the loop, but do not throw to send the exception to
+ /// the client.
+ if (e.code() == ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT)
+ network_error = true;
+
+ /// If a timeout occurred, try to inform client about it and close the session
+ if (e.code() == ErrorCodes::SOCKET_TIMEOUT)
+ network_error = true;
+
+ if (network_error)
+ LOG_TEST(log, "Going to close connection due to exception: {}", e.message());
+ }
+ catch (const Poco::Net::NetException & e)
+ {
+ /** We can get here if there was an error during connection to the client,
+ * or in connection with a remote server that was used to process the request.
+ * It is not possible to distinguish between these two cases.
+ * Although in one of them, we have to send exception to the client, but in the other - we can not.
+ * We will try to send exception to the client in any case - see below.
+ */
+ state.io.onException();
+ exception = std::make_unique<DB::Exception>(Exception::CreateFromPocoTag{}, e);
+ }
+ catch (const Poco::Exception & e)
+ {
+ state.io.onException();
+ exception = std::make_unique<DB::Exception>(Exception::CreateFromPocoTag{}, e);
+ }
+// Server should die on std logic errors in debug, like with assert()
+// or ErrorCodes::LOGICAL_ERROR. This helps catch these errors in
+// tests.
+#ifdef ABORT_ON_LOGICAL_ERROR
+ catch (const std::logic_error & e)
+ {
+ state.io.onException();
+ exception = std::make_unique<DB::Exception>(Exception::CreateFromSTDTag{}, e);
+ sendException(*exception, send_exception_with_stack_trace);
+ std::abort();
+ }
+#endif
+ catch (const std::exception & e)
+ {
+ state.io.onException();
+ exception = std::make_unique<DB::Exception>(Exception::CreateFromSTDTag{}, e);
+ }
+ catch (...)
+ {
+ state.io.onException();
+ exception = std::make_unique<DB::Exception>(ErrorCodes::UNKNOWN_EXCEPTION, "Unknown exception");
+ }
+
+ /// Report the captured exception (if any) to the client; a failure here is treated as a network error.
+ try
+ {
+ if (exception)
+ {
+ if (thread_trace_context)
+ thread_trace_context->root_span.addAttribute(*exception);
+
+ try
+ {
+ /// Try to send logs to client, but it could be risky too
+ /// Assume that we can't break output here
+ sendLogs();
+ }
+ catch (...)
+ {
+ tryLogCurrentException(log, "Can't send logs to client");
+ }
+
+ const auto & e = *exception;
+ LOG_ERROR(log, getExceptionMessageAndPattern(e, send_exception_with_stack_trace));
+ sendException(*exception, send_exception_with_stack_trace);
+ }
+ }
+ catch (...)
+ {
+ /** Could not send exception information to the client. */
+ network_error = true;
+ LOG_WARNING(log, "Client has gone away.");
+ }
+
+ try
+ {
+ /// A query packet is always followed by one or more data packets.
+ /// If some of those data packets are left, try to skip them.
+ if (exception && !state.empty() && !state.read_all_data)
+ skipData();
+ }
+ catch (...)
+ {
+ network_error = true;
+ LOG_WARNING(log, "Can't skip data packets after query failure.");
+ }
+
+ log_query_duration();
+
+ /// QueryState should be cleared before QueryScope, since otherwise
+ /// the MemoryTracker will be wrong for possible deallocations.
+ /// (i.e. deallocations from the Aggregator with two-level aggregation)
+ state.reset();
+ query_scope.reset();
+ thread_trace_context.reset();
+
+ /// It is important to destroy query context here. We do not want it to live arbitrarily longer than the query.
+ query_context.reset();
+
+ if (is_interserver_mode)
+ {
+ /// We don't really have session in interserver mode, new one is created for each query. It's better to reset it now.
+ session.reset();
+ }
+
+ if (network_error)
+ break;
+ }
+}
+
+
+/// Copies the connection-related settings of `context` into the handler's members.
+/// Called for the server context, then the session context, then each query context.
+void TCPHandler::extractConnectionSettingsFromContext(const ContextPtr & context)
+{
+    const auto & ctx_settings = context->getSettingsRef();
+
+    /// Socket timeouts and polling.
+    receive_timeout = ctx_settings.receive_timeout;
+    send_timeout = ctx_settings.send_timeout;
+    poll_interval = ctx_settings.poll_interval;
+    idle_connection_timeout = ctx_settings.idle_connection_timeout;
+    interactive_delay = ctx_settings.interactive_delay;
+
+    /// Error reporting.
+    send_exception_with_stack_trace = ctx_settings.calculate_text_stack_trace;
+
+    /// The remaining values are only copied here; where they are used is not visible in this function.
+    sleep_in_send_tables_status = ctx_settings.sleep_in_send_tables_status_ms;
+    unknown_packet_in_send_data = ctx_settings.unknown_packet_in_send_data;
+    sleep_after_receiving_query = ctx_settings.sleep_after_receiving_query_ms;
+}
+
+
+/// Waits for the next packet from the client and processes it.
+///
+/// Returns the result of receivePacket() if a packet arrived. Returns false if the client disconnected
+/// (the query is then marked FULLY_CANCELLED and the connection as closed) or the server is being cancelled.
+/// Whenever false is returned, `state.read_all_data` is set to true.
+/// Throws SOCKET_TIMEOUT if nothing arrived for more than `receive_timeout` seconds.
+bool TCPHandler::readDataNext()
+{
+ Stopwatch watch(CLOCK_MONOTONIC_COARSE);
+
+ /// Poll interval should not be greater than receive_timeout
+ constexpr UInt64 min_timeout_us = 5000; // 5 ms
+ /// The poll timeout is min(poll_interval, receive_timeout) in microseconds, but never below min_timeout_us.
+ UInt64 timeout_us = std::max(min_timeout_us, std::min(poll_interval * 1000000, static_cast<UInt64>(receive_timeout.totalMicroseconds())));
+ bool read_ok = false;
+
+ /// We are waiting for a packet from the client. Thus, every `POLL_INTERVAL` seconds check whether we need to shut down.
+ while (true)
+ {
+ if (static_cast<ReadBufferFromPocoSocket &>(*in).poll(timeout_us))
+ {
+ /// If client disconnected.
+ if (in->eof())
+ {
+ LOG_INFO(log, "Client has dropped the connection, cancel the query.");
+ state.is_connection_closed = true;
+ state.cancellation_status = CancellationStatus::FULLY_CANCELLED;
+ break;
+ }
+
+ /// We accept and process data.
+ read_ok = receivePacket();
+ break;
+ }
+
+ /// Do we need to shut down?
+ if (server.isCancelled())
+ break;
+
+ /** Have we waited for data for too long?
+ * If we periodically poll, the receive_timeout of the socket itself does not work.
+ * Therefore, an additional check is added.
+ */
+ Float64 elapsed = watch.elapsedSeconds();
+ if (elapsed > static_cast<Float64>(receive_timeout.totalSeconds()))
+ {
+ throw Exception(ErrorCodes::SOCKET_TIMEOUT,
+ "Timeout exceeded while receiving data from client. Waited for {} seconds, timeout is {} seconds.",
+ static_cast<size_t>(elapsed), receive_timeout.totalSeconds());
+ }
+ }
+
+ /// After a successfully processed packet, forward pending logs and insert profile events to the client.
+ if (read_ok)
+ {
+ sendLogs();
+ sendInsertProfileEvents();
+ }
+ else
+ state.read_all_data = true;
+
+ return read_ok;
+}
+
+
+/// Sends pending logs, then consumes data packets until readDataNext() reports there is nothing more to read.
+/// Throws QUERY_WAS_CANCELLED if the query ended up fully cancelled while reading.
+void TCPHandler::readData()
+{
+    sendLogs();
+
+    bool has_more_packets = true;
+    while (has_more_packets)
+        has_more_packets = readDataNext();
+
+    const bool fully_cancelled = state.cancellation_status == CancellationStatus::FULLY_CANCELLED;
+    if (fully_cancelled)
+        throw Exception(ErrorCodes::QUERY_WAS_CANCELLED, "Query was cancelled");
+}
+
+
+/// Consumes the remaining data packets with `state.skipping_data` set for the duration of the call.
+/// The flag is cleared again on every exit path, including exceptions.
+/// Throws QUERY_WAS_CANCELLED if the query ended up fully cancelled while skipping.
+void TCPHandler::skipData()
+{
+    state.skipping_data = true;
+    SCOPE_EXIT({ state.skipping_data = false; });
+
+    bool has_more_packets = true;
+    while (has_more_packets)
+        has_more_packets = readDataNext();
+
+    const bool fully_cancelled = state.cancellation_status == CancellationStatus::FULLY_CANCELLED;
+    if (fully_cancelled)
+        throw Exception(ErrorCodes::QUERY_WAS_CANCELLED, "Query was cancelled");
+}
+
+
+/// Runs the pushing (INSERT) pipeline: tells the client the table structure, then feeds every
+/// received block into the executor. An async executor is used when the pipeline wants more than one thread.
+void TCPHandler::processInsertQuery()
+{
+    const size_t pipeline_threads = state.io.pipeline.getNumThreads();
+
+    /// Common driver for both executor kinds.
+    auto drive_insert = [&](auto & executor)
+    {
+        /// Start first, so that an exception thrown while starting the pipeline reaches the client
+        /// before it begins sending data.
+        executor.start();
+
+        /// Send ColumnsDescription for the insertion table.
+        if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_COLUMN_DEFAULTS_METADATA)
+        {
+            const auto & table_id = query_context->getInsertionTable();
+            const bool defaults_requested = query_context->getSettingsRef().input_format_defaults_for_omitted_fields;
+            if (defaults_requested && !table_id.empty())
+            {
+                auto storage_ptr = DatabaseCatalog::instance().getTable(table_id, query_context);
+                sendTableColumns(storage_ptr->getInMemoryMetadataPtr()->getColumns());
+            }
+        }
+
+        /// Send the header block to the client - this is the table structure.
+        sendData(executor.getHeader());
+        sendLogs();
+
+        while (readDataNext())
+            executor.push(std::move(state.block_for_insert));
+
+        if (state.cancellation_status != CancellationStatus::FULLY_CANCELLED)
+            executor.finish();
+        else
+            executor.cancel();
+    };
+
+    if (pipeline_threads <= 1)
+    {
+        PushingPipelineExecutor executor(state.io.pipeline);
+        drive_insert(executor);
+    }
+    else
+    {
+        PushingAsyncPipelineExecutor executor(state.io.pipeline);
+        drive_insert(executor);
+    }
+
+    sendInsertProfileEvents();
+}
+
+
+/// Runs a pulling (SELECT-like) pipeline and streams its output to the client.
+///
+/// Order of packets: optional part UUIDs, the header block, then data blocks interleaved with
+/// progress/profile events/logs, then totals, extremes, profile info and a final empty block.
+/// All writes to the client are done while holding `task_callback_mutex`.
+void TCPHandler::processOrdinaryQueryWithProcessors()
+{
+ auto & pipeline = state.io.pipeline;
+
+ if (query_context->getSettingsRef().allow_experimental_query_deduplication)
+ {
+ std::lock_guard lock(task_callback_mutex);
+ sendPartUUIDs();
+ }
+
+ /// Send header-block, to allow client to prepare output format for data to send.
+ {
+ const auto & header = pipeline.getHeader();
+
+ if (header)
+ {
+ std::lock_guard lock(task_callback_mutex);
+ sendData(header);
+ }
+ }
+
+ /// Defer locking to cover a part of the scope below and everything after it
+ std::unique_lock progress_lock(task_callback_mutex, std::defer_lock);
+
+ {
+ PullingAsyncPipelineExecutor executor(pipeline);
+ CurrentMetrics::Increment query_thread_metric_increment{CurrentMetrics::QueryThread};
+
+ Block block;
+ /// pull() is given interactive_delay / 1000 as its timeout, so the loop body also runs when no block is ready
+ /// (hence the `if (block)` check below).
+ while (executor.pull(block, interactive_delay / 1000))
+ {
+ std::unique_lock lock(task_callback_mutex);
+
+ auto cancellation_status = getQueryCancellationStatus();
+ if (cancellation_status == CancellationStatus::FULLY_CANCELLED)
+ {
+ /// Several callback like callback for parallel reading could be called from inside the pipeline
+ /// and we have to unlock the mutex from our side to prevent deadlock.
+ lock.unlock();
+ /// A packet was received requesting to stop execution of the request.
+ executor.cancel();
+ break;
+ }
+ else if (cancellation_status == CancellationStatus::READ_CANCELLED)
+ {
+ executor.cancelReading();
+ }
+
+ if (after_send_progress.elapsed() / 1000 >= interactive_delay)
+ {
+ /// Some time passed and there is a progress.
+ after_send_progress.restart();
+ sendProgress();
+ sendSelectProfileEvents();
+ }
+
+ sendLogs();
+
+ if (block)
+ {
+ /// With null_format set the data blocks are not sent to the client.
+ if (!state.io.null_format)
+ sendData(block);
+ }
+ }
+
+ /// This lock wasn't acquired before and we make .lock() call here
+ /// so everything under this line is covered even together
+ /// with sendProgress() out of the scope
+ progress_lock.lock();
+
+ /** If data has run out, we will send the profiling data and total values to
+ * the last zero block to be able to use
+ * this information in the suffix output of stream.
+ * If the request was interrupted, then `sendTotals` and other methods could not be called,
+ * because we have not read all the data yet,
+ * and there could be ongoing calculations in other threads at the same time.
+ */
+ if (getQueryCancellationStatus() != CancellationStatus::FULLY_CANCELLED)
+ {
+ sendTotals(executor.getTotalsBlock());
+ sendExtremes(executor.getExtremesBlock());
+ sendProfileInfo(executor.getProfileInfo());
+ sendProgress();
+ sendLogs();
+ sendSelectProfileEvents();
+ }
+
+ /// Nothing more can be written once the client is gone.
+ if (state.is_connection_closed)
+ return;
+
+ /// An empty block is sent after all data.
+ sendData({});
+ last_sent_snapshots.clear();
+ }
+
+ sendProgress();
+}
+
+
+/// Reads a TablesStatusRequest from the client and answers with a TablesStatusResponse.
+/// Tables that cannot be found are left out of the response. For replicated tables the absolute delay is reported.
+void TCPHandler::processTablesStatusRequest()
+{
+ TablesStatusRequest request;
+ request.read(*in, client_tcp_protocol_version);
+
+ ContextPtr context_to_resolve_table_names;
+ if (is_interserver_mode)
+ {
+ /// In interserver mode session context does not exists, because authentication is done for each query.
+ /// We also cannot create query context earlier, because it cannot be created before authentication,
+ /// but query is not received yet. So we have to do this trick.
+ ContextMutablePtr fake_interserver_context = Context::createCopy(server.context());
+ if (!default_database.empty())
+ fake_interserver_context->setCurrentDatabase(default_database);
+ context_to_resolve_table_names = fake_interserver_context;
+ }
+ else
+ {
+ assert(session);
+ context_to_resolve_table_names = session->sessionContext();
+ }
+
+ TablesStatusResponse response;
+ for (const QualifiedTableName & table_name: request.tables)
+ {
+ auto resolved_id = context_to_resolve_table_names->tryResolveStorageID({table_name.database, table_name.table});
+ StoragePtr table = DatabaseCatalog::instance().tryGetTable(resolved_id, context_to_resolve_table_names);
+ if (!table)
+ continue;
+
+ TableStatus status;
+ if (auto * replicated_table = dynamic_cast<StorageReplicatedMergeTree *>(table.get()))
+ {
+ status.is_replicated = true;
+ status.absolute_delay = static_cast<UInt32>(replicated_table->getAbsoluteDelay());
+ }
+ else
+ status.is_replicated = false;
+
+ /// The response is keyed by the name as requested, not by the resolved id.
+ response.table_states_by_id.emplace(table_name, std::move(status));
+ }
+
+
+ writeVarUInt(Protocol::Server::TablesStatusResponse, *out);
+
+ /// For testing hedged requests
+ /// The packet type written above is flushed before sleeping; the response body follows after the sleep.
+ if (unlikely(sleep_in_send_tables_status.totalMilliseconds()))
+ {
+ out->next();
+ std::chrono::milliseconds ms(sleep_in_send_tables_status.totalMilliseconds());
+ std::this_thread::sleep_for(ms);
+ }
+
+ response.write(*out, client_tcp_protocol_version);
+}
+
+/// Consumes a TablesStatusRequest that arrived at the wrong moment, then always throws.
+void TCPHandler::receiveUnexpectedTablesStatusRequest()
+{
+    /// Read the whole packet body before reporting the error.
+    TablesStatusRequest discarded_request;
+    discarded_request.read(*in, client_tcp_protocol_version);
+
+    throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet TablesStatusRequest received from client");
+}
+
+/// Sends a PartUUIDs packet with the part UUIDs taken from the query context.
+/// Writes nothing when the list is empty.
+void TCPHandler::sendPartUUIDs()
+{
+    auto part_uuids = query_context->getPartUUIDs()->get();
+    if (part_uuids.empty())
+        return;
+
+    for (const auto & part_uuid : part_uuids)
+        LOG_TRACE(log, "Sending UUID: {}", toString(part_uuid));
+
+    writeVarUInt(Protocol::Server::PartUUIDs, *out);
+    writeVectorBinary(part_uuids, *out);
+    out->next();
+}
+
+
+/// Writes the ReadTaskRequest packet type and flushes the output buffer.
+/// NOTE(review): the name suggests the caller already holds the relevant lock; nothing is locked here.
+void TCPHandler::sendReadTaskRequestAssumeLocked()
+{
+    auto & output = *out;
+    writeVarUInt(Protocol::Server::ReadTaskRequest, output);
+    output.next();
+}
+
+
+/// Sends a MergeTreeAllRangesAnnounecement packet: the packet type, then the serialized announcement.
+/// NOTE(review): the name suggests the caller already holds the relevant lock; nothing is locked here.
+void TCPHandler::sendMergeTreeAllRangesAnnounecementAssumeLocked(InitialAllRangesAnnouncement announcement)
+{
+    auto & output = *out;
+    writeVarUInt(Protocol::Server::MergeTreeAllRangesAnnounecement, output);
+    announcement.serialize(output);
+    output.next();
+}
+
+
+/// Sends a MergeTreeReadTaskRequest packet: the packet type, then the serialized request.
+/// NOTE(review): the name suggests the caller already holds the relevant lock; nothing is locked here.
+void TCPHandler::sendMergeTreeReadTaskRequestAssumeLocked(ParallelReadRequest request)
+{
+    auto & output = *out;
+    writeVarUInt(Protocol::Server::MergeTreeReadTaskRequest, output);
+    request.serialize(output);
+    output.next();
+}
+
+
+/// Sends a ProfileInfo packet containing `info` and flushes the output buffer.
+void TCPHandler::sendProfileInfo(const ProfileInfo & info)
+{
+    auto & output = *out;
+    writeVarUInt(Protocol::Server::ProfileInfo, output);
+    info.write(output);
+    output.next();
+}
+
+
+/// Sends the totals block as a Totals packet. An empty block is not sent at all.
+void TCPHandler::sendTotals(const Block & totals)
+{
+    if (!totals)
+        return;
+
+    initBlockOutput(totals);
+
+    writeVarUInt(Protocol::Server::Totals, *out);
+    /// The string written before the block is always empty here.
+    writeStringBinary("", *out);
+
+    state.block_out->write(totals);
+    /// Flush the (possibly compressing) block buffer first, then the connection buffer.
+    state.maybe_compressed_out->next();
+    out->next();
+}
+
+
+/// Sends the extremes block as an Extremes packet. An empty block is not sent at all.
+void TCPHandler::sendExtremes(const Block & extremes)
+{
+    if (!extremes)
+        return;
+
+    initBlockOutput(extremes);
+
+    writeVarUInt(Protocol::Server::Extremes, *out);
+    /// The string written before the block is always empty here.
+    writeStringBinary("", *out);
+
+    state.block_out->write(extremes);
+    /// Flush the (possibly compressing) block buffer first, then the connection buffer.
+    state.maybe_compressed_out->next();
+    out->next();
+}
+
+/// Builds a block of profile events from the profile queue and sends it as a ProfileEvents packet.
+/// Nothing is sent when the resulting block has no rows.
+void TCPHandler::sendProfileEvents()
+{
+    Block events_block;
+    ProfileEvents::getProfileEvents(server_display_name, state.profile_queue, events_block, last_sent_snapshots);
+    if (events_block.rows() == 0)
+        return;
+
+    initProfileEventsBlockOutput(events_block);
+
+    writeVarUInt(Protocol::Server::ProfileEvents, *out);
+    writeStringBinary("", *out);
+
+    state.profile_events_block_out->write(events_block);
+    out->next();
+}
+
+/// Sends profile events, but only to clients whose protocol version is at least
+/// DBMS_MIN_PROTOCOL_VERSION_WITH_INCREMENTAL_PROFILE_EVENTS.
+void TCPHandler::sendSelectProfileEvents()
+{
+    if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_INCREMENTAL_PROFILE_EVENTS)
+        sendProfileEvents();
+}
+
+/// Sends profile events when the client protocol version is at least
+/// DBMS_MIN_PROTOCOL_VERSION_WITH_PROFILE_EVENTS_IN_INSERT and the query is an initial query.
+void TCPHandler::sendInsertProfileEvents()
+{
+    const bool client_is_new_enough = client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_PROFILE_EVENTS_IN_INSERT;
+    const bool is_initial_query = query_kind == ClientInfo::QueryKind::INITIAL_QUERY;
+
+    if (client_is_new_enough && is_initial_query)
+        sendProfileEvents();
+}
+
+/// Sends a TimezoneUpdate packet carrying the `session_timezone` setting of the query context.
+/// Clients older than DBMS_MIN_PROTOCOL_VERSION_WITH_TIMEZONE_UPDATES get nothing.
+void TCPHandler::sendTimezone()
+{
+    const bool client_accepts_updates = client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_TIMEZONE_UPDATES;
+    if (!client_accepts_updates)
+        return;
+
+    const String & session_tz = query_context->getSettingsRef().session_timezone.value;
+    LOG_DEBUG(log, "TCPHandler::sendTimezone(): {}", session_tz);
+
+    writeVarUInt(Protocol::Server::TimezoneUpdate, *out);
+    writeStringBinary(session_tz, *out);
+    out->next();
+}
+
+
+bool TCPHandler::receiveProxyHeader() /// Parses a PROXYv1 header from `in`; returns true and fills `forwarded_for` on success, false when the header is missing, incomplete or unusable.
+{
+ if (in->eof())
+ {
+ LOG_WARNING(log, "Client has not sent any data.");
+ return false;
+ }
+
+ String forwarded_address;
+
+ /// Only PROXYv1 is supported.
+ /// The protocol is not fully validated.
+
+ LimitReadBuffer limit_in(*in, 107, /* throw_exception */ true, /* exact_limit */ {}); /// 107 is the maximum header length from the specs.
+
+ assertString("PROXY ", limit_in); /// NOTE(review): assertString/assertChar presumably throw on mismatch — confirm.
+
+ if (limit_in.eof())
+ {
+ LOG_WARNING(log, "Incomplete PROXY header is received.");
+ return false;
+ }
+
+ /// TCP4 / TCP6 / UNKNOWN
+ if ('T' == *limit_in.position())
+ {
+ assertString("TCP", limit_in);
+
+ if (limit_in.eof())
+ {
+ LOG_WARNING(log, "Incomplete PROXY header is received.");
+ return false;
+ }
+
+ if ('4' != *limit_in.position() && '6' != *limit_in.position()) /// Only "TCP4" and "TCP6" are accepted after "TCP".
+ {
+ LOG_WARNING(log, "Unexpected protocol in PROXY header is received.");
+ return false;
+ }
+
+ ++limit_in.position(); /// Skip the '4' or '6' digit checked above.
+ assertChar(' ', limit_in);
+
+ /// Read the first field and ignore the others.
+ readStringUntilWhitespace(forwarded_address, limit_in);
+
+ /// Skip until \r\n
+ while (!limit_in.eof() && *limit_in.position() != '\r')
+ ++limit_in.position();
+ assertString("\r\n", limit_in);
+ }
+ else if (checkString("UNKNOWN", limit_in))
+ {
+ /// This is just a health check, there is no subsequent data in this connection.
+
+ while (!limit_in.eof() && *limit_in.position() != '\r')
+ ++limit_in.position();
+ assertString("\r\n", limit_in);
+ return false; /// The header line is consumed, but false is still returned for the UNKNOWN case.
+ }
+ else
+ {
+ LOG_WARNING(log, "Unexpected protocol in PROXY header is received.");
+ return false;
+ }
+
+ LOG_TRACE(log, "Forwarded client address from PROXY header: {}", forwarded_address);
+ forwarded_for = std::move(forwarded_address); /// Kept in the handler; makeSession() later passes it to Session::setForwardedFor.
+ return true;
+}
+
+
+namespace
+{
+
+/// Builds the HTTP 400 response text sent to a client that spoke HTTP on the native TCP port.
+/// It names the configured "tcp_port" and, when "http_port" is configured, points to it as well.
+std::string formatHTTPErrorResponseWhenUserIsConnectedToWrongPort(const Poco::Util::AbstractConfiguration& config)
+{
+    std::string response = fmt::format(
+        "HTTP/1.0 400 Bad Request\r\n\r\n"
+        "Port {} is for clickhouse-client program\r\n",
+        config.getString("tcp_port"));
+
+    if (!config.has("http_port"))
+        return response;
+
+    response += fmt::format(
+        "You must use port {} for HTTP.\r\n",
+        config.getString("http_port"));
+    return response;
+}
+
+}
+
+/// Creates a new Session for this connection and fills in what is already known about the client:
+/// the forwarded-for address, client name, client versions and the interface kind.
+std::unique_ptr<Session> TCPHandler::makeSession()
+{
+    auto interface = ClientInfo::Interface::TCP;
+    if (is_interserver_mode)
+        interface = ClientInfo::Interface::TCP_INTERSERVER;
+
+    auto new_session = std::make_unique<Session>(server.context(), interface, socket().secure(), certificate);
+
+    new_session->setForwardedFor(forwarded_for);
+    new_session->setClientName(client_name);
+    new_session->setClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version);
+    new_session->setConnectionClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version);
+    new_session->setClientInterface(interface);
+
+    return new_session;
+}
+
+void TCPHandler::receiveHello() /// Reads the client's Hello packet, records client name/versions/default database, then either enters interserver mode or creates a session and authenticates the user.
+{
+ /// Receive `hello` packet.
+ UInt64 packet_type = 0;
+ String user;
+ String password;
+ String default_db;
+
+ readVarUInt(packet_type, *in);
+ if (packet_type != Protocol::Client::Hello)
+ {
+ /** If you accidentally accessed the HTTP protocol for a port destined for an internal TCP protocol,
+ * Then instead of the packet type, there will be G (GET) or P (POST), in most cases.
+ */
+ if (packet_type == 'G' || packet_type == 'P')
+ {
+ writeString(formatHTTPErrorResponseWhenUserIsConnectedToWrongPort(server.config()), *out); /// Written to `out` but not flushed here.
+ throw Exception(ErrorCodes::CLIENT_HAS_CONNECTED_TO_WRONG_PORT, "Client has connected to wrong port");
+ }
+ else
+ throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT,
+ "Unexpected packet from client (expected Hello, got {})", packet_type);
+ }
+
+ readStringBinary(client_name, *in);
+ readVarUInt(client_version_major, *in);
+ readVarUInt(client_version_minor, *in);
+ // NOTE For backward compatibility of the protocol, client cannot send its version_patch.
+ readVarUInt(client_tcp_protocol_version, *in);
+ readStringBinary(default_db, *in);
+ if (!default_db.empty()) /// An empty database name in Hello keeps the current `default_database`.
+ default_database = default_db;
+ readStringBinary(user, *in);
+ readStringBinary(password, *in);
+
+ if (user.empty())
+ throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet from client (no user in Hello package)");
+
+ LOG_DEBUG(log, "Connected {} version {}.{}.{}, revision: {}{}{}.",
+ client_name,
+ client_version_major, client_version_minor, client_version_patch,
+ client_tcp_protocol_version,
+ (!default_database.empty() ? ", database: " + default_database : ""),
+ (!user.empty() ? ", user: " + user : "")
+ );
+
+ is_interserver_mode = (user == USER_INTERSERVER_MARKER) && password.empty(); /// Interserver mode: the marker user name together with an empty password.
+ if (is_interserver_mode)
+ {
+ if (client_tcp_protocol_version < DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET_V2)
+ LOG_WARNING(LogFrequencyLimiter(log, 10),
+ "Using deprecated interserver protocol because the client is too old. Consider upgrading all nodes in cluster.");
+ receiveClusterNameAndSalt();
+ return; /// No session is created here; receiveQuery() creates one per query in interserver mode.
+ }
+
+ session = makeSession();
+ const auto & client_info = session->getClientInfo();
+
+#if USE_SSL
+ /// Authentication with SSL user certificate
+ if (dynamic_cast<Poco::Net::SecureStreamSocketImpl*>(socket().impl()))
+ {
+ Poco::Net::SecureStreamSocket secure_socket(socket());
+ if (secure_socket.havePeerCertificate())
+ {
+ try
+ {
+ session->authenticate(
+ SSLCertificateCredentials{user, secure_socket.peerCertificate().commonName()},
+ getClientAddress(client_info));
+ return;
+ }
+ catch (...)
+ {
+ tryLogCurrentException(log, "SSL authentication failed, falling back to password authentication"); /// Deliberate best-effort: the error is logged and password authentication below is tried.
+ }
+ }
+ }
+#endif
+
+ session->authenticate(user, password, getClientAddress(client_info));
+}
+
+/// Reads the quota key, if the client protocol is new enough to send it,
+/// and passes it to the session unless the connection is in interserver mode.
+void TCPHandler::receiveAddendum()
+{
+    const bool client_sends_quota_key = client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_QUOTA_KEY;
+    if (client_sends_quota_key)
+        readStringBinary(quota_key, *in);
+
+    if (is_interserver_mode)
+        return;
+
+    session->setQuotaClientKey(quota_key);
+}
+
+
+/// Consumes the body of a Hello packet that arrived at the wrong moment, then always throws.
+/// The fields are read in the same order as in receiveHello() and thrown away.
+void TCPHandler::receiveUnexpectedHello()
+{
+    UInt64 discarded_number;
+    String discarded_text;
+
+    /// Client name.
+    readStringBinary(discarded_text, *in);
+
+    /// Major version, minor version and protocol version.
+    for (size_t i = 0; i < 3; ++i)
+        readVarUInt(discarded_number, *in);
+
+    /// Default database, user and password.
+    for (size_t i = 0; i < 3; ++i)
+        readStringBinary(discarded_text, *in);
+
+    throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet Hello received from client");
+}
+
+
+/// Sends the server Hello packet. The mandatory fields come first; every following field
+/// is written only if the client's protocol version is new enough to expect it.
+void TCPHandler::sendHello()
+{
+    const auto & client_revision = client_tcp_protocol_version;
+
+    writeVarUInt(Protocol::Server::Hello, *out);
+    writeStringBinary(VERSION_NAME, *out);
+    writeVarUInt(VERSION_MAJOR, *out);
+    writeVarUInt(VERSION_MINOR, *out);
+    writeVarUInt(DBMS_TCP_PROTOCOL_VERSION, *out);
+
+    if (client_revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE)
+        writeStringBinary(DateLUT::instance().getTimeZone(), *out);
+
+    if (client_revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME)
+        writeStringBinary(server_display_name, *out);
+
+    if (client_revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH)
+        writeVarUInt(VERSION_PATCH, *out);
+
+    if (client_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_PASSWORD_COMPLEXITY_RULES)
+    {
+        auto complexity_rules = server.context()->getAccessControl().getPasswordComplexityRules();
+
+        /// The number of rules, then a (pattern, message) pair per rule.
+        writeVarUInt(complexity_rules.size(), *out);
+        for (const auto & [pattern, message] : complexity_rules)
+        {
+            writeStringBinary(pattern, *out);
+            writeStringBinary(message, *out);
+        }
+    }
+
+    if (client_revision >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET_V2)
+    {
+        chassert(!nonce.has_value());
+        /// Contains lots of stuff (including time), so this should be enough for NONCE.
+        nonce.emplace(thread_local_rng());
+        writeIntBinary(nonce.value(), *out);
+    }
+
+    out->next();
+}
+
+
+bool TCPHandler::receivePacket() /// Reads one packet type from the client and dispatches on it; throws for unknown or out-of-place packets.
+{
+ UInt64 packet_type = 0;
+ readVarUInt(packet_type, *in);
+
+ switch (packet_type)
+ {
+ case Protocol::Client::IgnoredPartUUIDs:
+ /// Part uuids packet if any comes before query.
+ if (!state.empty() || state.part_uuids_to_ignore)
+ receiveUnexpectedIgnoredPartUUIDs(); /// Always throws after consuming the packet body.
+ receiveIgnoredPartUUIDs();
+ return true;
+
+ case Protocol::Client::Query:
+ if (!state.empty())
+ receiveUnexpectedQuery(); /// Always throws after consuming the packet body.
+ receiveQuery();
+ return true;
+
+ case Protocol::Client::Data:
+ case Protocol::Client::Scalar:
+ if (state.skipping_data)
+ return receiveUnexpectedData(false); /// Read and drop the block without throwing.
+ if (state.empty())
+ receiveUnexpectedData(true); /// Throws: data arrived while no query is in progress.
+ return receiveData(packet_type == Protocol::Client::Scalar);
+
+ case Protocol::Client::Ping:
+ writeVarUInt(Protocol::Server::Pong, *out);
+ out->next();
+ return false;
+
+ case Protocol::Client::Cancel:
+ decreaseCancellationStatus("Received 'Cancel' packet from the client, canceling the query.");
+ return false;
+
+ case Protocol::Client::Hello:
+ receiveUnexpectedHello(); /// Always throws, so control never falls through into the next case.
+
+ case Protocol::Client::TablesStatusRequest:
+ if (!state.empty())
+ receiveUnexpectedTablesStatusRequest(); /// Always throws after consuming the packet body.
+ processTablesStatusRequest();
+ out->next();
+ return false;
+
+ default:
+ throw Exception(ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT, "Unknown packet {} from client", toString(packet_type));
+ }
+}
+
+
+/// Reads the list of part UUIDs to ignore into a freshly emplaced `state.part_uuids_to_ignore`.
+void TCPHandler::receiveIgnoredPartUUIDs()
+{
+    auto & ignored_uuids = state.part_uuids_to_ignore.emplace();
+    readVectorBinary(ignored_uuids, *in);
+}
+
+
+/// Consumes an IgnoredPartUUIDs packet that arrived at the wrong moment, then always throws.
+void TCPHandler::receiveUnexpectedIgnoredPartUUIDs()
+{
+    /// Read the whole packet body before reporting the error.
+    std::vector<UUID> discarded_uuids;
+    readVectorBinary(discarded_uuids, *in);
+
+    throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet IgnoredPartUUIDs received from client");
+}
+
+
+/// Reads the client's answer to a read task request and returns the response string.
+/// A Cancel packet lowers the cancellation status and yields an empty string;
+/// any other packet type, or a mismatched cluster-processing protocol version, throws.
+/// NOTE(review): the name suggests the caller already holds the relevant lock; nothing is locked here.
+String TCPHandler::receiveReadTaskResponseAssumeLocked()
+{
+    UInt64 packet_type = 0;
+    readVarUInt(packet_type, *in);
+
+    const bool is_expected_response = (packet_type == Protocol::Client::ReadTaskResponse);
+    if (!is_expected_response)
+    {
+        if (packet_type != Protocol::Client::Cancel)
+            throw Exception(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Received {} packet after requesting read task",
+                Protocol::Client::toString(packet_type));
+
+        decreaseCancellationStatus("Received 'Cancel' packet from the client, canceling the read task.");
+        return {};
+    }
+
+    UInt64 protocol_version;
+    readVarUInt(protocol_version, *in);
+    if (protocol_version != DBMS_CLUSTER_PROCESSING_PROTOCOL_VERSION)
+        throw Exception(ErrorCodes::UNKNOWN_PROTOCOL, "Protocol version for distributed processing mismatched");
+
+    String task_response;
+    readStringBinary(task_response, *in);
+    return task_response;
+}
+
+
+/// Reads the client's answer to a MergeTree read task request and returns the deserialized response.
+/// A Cancel packet lowers the cancellation status and yields std::nullopt; any other packet type throws.
+/// NOTE(review): the name suggests the caller already holds the relevant lock; nothing is locked here.
+std::optional<ParallelReadResponse> TCPHandler::receivePartitionMergeTreeReadTaskResponseAssumeLocked()
+{
+    UInt64 packet_type = 0;
+    readVarUInt(packet_type, *in);
+
+    const bool is_expected_response = (packet_type == Protocol::Client::MergeTreeReadTaskResponse);
+    if (!is_expected_response)
+    {
+        if (packet_type != Protocol::Client::Cancel)
+            throw Exception(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Received {} packet after requesting read task",
+                Protocol::Client::toString(packet_type));
+
+        decreaseCancellationStatus("Received 'Cancel' packet from the client, canceling the MergeTree read task.");
+        return std::nullopt;
+    }
+
+    ParallelReadResponse read_response;
+    read_response.deserialize(*in);
+    return read_response;
+}
+
+
+/// Reads the cluster name and the salt sent by the client in interserver mode.
+void TCPHandler::receiveClusterNameAndSalt()
+{
+    /// Size argument passed to readStringBinary when reading the salt.
+    constexpr size_t salt_size_limit = 32;
+
+    readStringBinary(cluster, *in);
+    readStringBinary(salt, *in, salt_size_limit);
+}
+
+void TCPHandler::receiveQuery() /// Reads a Query packet, authenticates it in interserver mode, then creates and configures `query_context`.
+{
+ UInt64 stage = 0;
+ UInt64 compression = 0;
+
+ state.is_empty = false;
+ readStringBinary(state.query_id, *in);
+
+ /// In interserver mode,
+ /// initial_user can be empty in case of Distributed INSERT via Buffer/Kafka,
+ /// (i.e. when the INSERT is done with the global context without user),
+ /// so it is better to reset session to avoid using old user.
+ if (is_interserver_mode)
+ {
+ session = makeSession();
+ }
+
+ /// Read client info.
+ ClientInfo client_info = session->getClientInfo(); /// Start from a copy of the session's client info.
+ if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_CLIENT_INFO)
+ {
+ client_info.read(*in, client_tcp_protocol_version);
+
+ correctQueryClientInfo(session->getClientInfo(), client_info);
+ const auto & config_ref = Context::getGlobalContextInstance()->getServerSettings();
+ if (config_ref.validate_tcp_client_information)
+ validateClientInfo(session->getClientInfo(), client_info);
+ }
+
+ /// Per query settings are also passed via TCP.
+ /// We need to check them before applying because they can violate the settings constraints.
+ auto settings_format = (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS)
+ ? SettingsWriteFormat::STRINGS_WITH_FLAGS
+ : SettingsWriteFormat::BINARY;
+ Settings passed_settings;
+ passed_settings.read(*in, settings_format);
+
+ /// Interserver secret.
+ std::string received_hash;
+ if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET)
+ {
+ readStringBinary(received_hash, *in, 32);
+ }
+
+ readVarUInt(stage, *in);
+ state.stage = QueryProcessingStage::Enum(stage);
+
+ readVarUInt(compression, *in);
+ state.compression = static_cast<Protocol::Compression>(compression);
+ last_block_in.compression = state.compression; /// Also read by receiveUnexpectedData() when it has to skip a data block.
+
+ readStringBinary(state.query, *in);
+
+ Settings passed_params; /// Query parameters use the same wire format as settings.
+ if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS)
+ passed_params.read(*in, settings_format);
+
+ if (is_interserver_mode)
+ {
+ client_info.interface = ClientInfo::Interface::TCP_INTERSERVER;
+#if USE_SSL
+ String cluster_secret = server.context()->getCluster(cluster)->getSecret();
+
+ if (salt.empty() || cluster_secret.empty())
+ {
+ auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED, "Interserver authentication failed (no salt/cluster secret)");
+ session->onAuthenticationFailure(/* user_name= */ std::nullopt, socket().peerAddress(), exception);
+ throw exception; /// NOLINT
+ }
+
+ if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET_V2 && !nonce.has_value())
+ {
+ auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED, "Interserver authentication failed (no nonce)");
+ session->onAuthenticationFailure(/* user_name= */ std::nullopt, socket().peerAddress(), exception);
+ throw exception; /// NOLINT
+ }
+
+ std::string data(salt); /// Hash input, concatenated in this order: salt, nonce (if any), cluster secret, query, query id, initial user.
+ // For backward compatibility
+ if (nonce.has_value())
+ data += std::to_string(nonce.value());
+ data += cluster_secret;
+ data += state.query;
+ data += state.query_id;
+ data += client_info.initial_user;
+
+ std::string calculated_hash = encodeSHA256(data);
+ assert(calculated_hash.size() == 32);
+
+ /// TODO maybe also check that peer address actually belongs to the cluster?
+ if (calculated_hash != received_hash)
+ {
+ auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED, "Interserver authentication failed");
+ session->onAuthenticationFailure(/* user_name */ std::nullopt, socket().peerAddress(), exception);
+ throw exception; /// NOLINT
+ }
+
+ /// NOTE Usually we get some fields of client_info (including initial_address and initial_user) from user input,
+ /// so we should not rely on that. However, in this particular case we got client_info from other clickhouse-server, so it's ok.
+ if (client_info.initial_user.empty())
+ {
+ LOG_DEBUG(log, "User (no user, interserver mode) (client: {})", getClientAddress(client_info).toString());
+ }
+ else
+ {
+ LOG_DEBUG(log, "User (initial, interserver mode): {} (client: {})", client_info.initial_user, getClientAddress(client_info).toString());
+ /// In inter-server mode, authorization is done with the
+ /// initial address of the client, not the real address from which
+ /// the query came, since the real address is the address of
+ /// the initiator server, while we are interested in the client's
+ /// address.
+ session->authenticate(AlwaysAllowCredentials{client_info.initial_user}, client_info.initial_address);
+ }
+#else
+ auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED,
+ "Inter-server secret support is disabled, because ClickHouse was built without SSL library");
+ session->onAuthenticationFailure(/* user_name */ std::nullopt, socket().peerAddress(), exception);
+ throw exception; /// NOLINT
+#endif
+ }
+
+ query_context = session->makeQueryContext(std::move(client_info)); /// `client_info` must not be used after this line.
+
+ /// Sets the default database if it wasn't set earlier for the session context.
+ if (is_interserver_mode && !default_database.empty())
+ query_context->setCurrentDatabase(default_database);
+
+ if (state.part_uuids_to_ignore)
+ query_context->getIgnoredPartUUIDs()->add(*state.part_uuids_to_ignore);
+
+ query_context->setProgressCallback([this] (const Progress & value) { return this->updateProgress(value); });
+ query_context->setFileProgressCallback([this](const FileProgress & value) { this->updateProgress(Progress(value)); });
+
+ ///
+ /// Settings
+ ///
+ auto settings_changes = passed_settings.changes();
+ query_kind = query_context->getClientInfo().query_kind;
+ if (query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
+ {
+ /// Throw an exception if the passed settings violate the constraints.
+ query_context->checkSettingsConstraints(settings_changes, SettingSource::QUERY);
+ }
+ else
+ {
+ /// Quietly clamp to the constraints if it's not an initial query.
+ query_context->clampToSettingsConstraints(settings_changes, SettingSource::QUERY);
+ }
+ query_context->applySettingsChanges(settings_changes);
+
+ /// Use the received query id, or generate a random default. It is convenient
+ /// to also generate the default OpenTelemetry trace id at the same time, and
+ /// set the trace parent.
+ /// Notes:
+ /// 1) ClientInfo might contain upstream trace id, so we decide whether to use
+ /// the default ids after we have received the ClientInfo.
+ /// 2) There is the opentelemetry_start_trace_probability setting that
+ /// controls when we start a new trace. It can be changed via Native protocol,
+ /// so we have to apply the changes first.
+ query_context->setCurrentQueryId(state.query_id);
+
+ query_context->addQueryParameters(convertToQueryParameters(passed_params));
+
+ /// For testing hedged requests
+ if (unlikely(sleep_after_receiving_query.totalMilliseconds()))
+ {
+ std::chrono::milliseconds ms(sleep_after_receiving_query.totalMilliseconds());
+ std::this_thread::sleep_for(ms);
+ }
+}
+
+/// Consumes the body of a Query packet that arrived at the wrong moment, then always throws.
+/// Fields are read in the same order as in receiveQuery(); only the compression flag is kept
+/// (in `last_block_in.compression`).
+void TCPHandler::receiveUnexpectedQuery()
+{
+    UInt64 discarded_number;
+    String discarded_text;
+
+    /// Query id.
+    readStringBinary(discarded_text, *in);
+
+    ClientInfo discarded_client_info;
+    if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_CLIENT_INFO)
+        discarded_client_info.read(*in, client_tcp_protocol_version);
+
+    const bool settings_are_strings = client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS;
+    auto settings_format = settings_are_strings ? SettingsWriteFormat::STRINGS_WITH_FLAGS : SettingsWriteFormat::BINARY;
+
+    Settings discarded_settings;
+    discarded_settings.read(*in, settings_format);
+
+    /// Interserver secret hash, if the client sends one.
+    std::string discarded_hash;
+    if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET)
+        readStringBinary(discarded_hash, *in, 32);
+
+    /// Query processing stage.
+    readVarUInt(discarded_number, *in);
+
+    /// Compression flag: remembered so that following data blocks can be decoded.
+    readVarUInt(discarded_number, *in);
+    last_block_in.compression = static_cast<Protocol::Compression>(discarded_number);
+
+    /// Query text.
+    readStringBinary(discarded_text, *in);
+
+    /// Query parameters.
+    if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS)
+        discarded_settings.read(*in, settings_format);
+
+    throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet Query received from client");
+}
+
+bool TCPHandler::receiveData(bool scalar) /// Reads one data block from the client and routes it; returns false when the block is empty, true otherwise.
+{
+ initBlockInput();
+
+ /// The name of the temporary table for writing data, defaults to an empty string
+ auto temporary_id = StorageID::createEmpty();
+ readStringBinary(temporary_id.table_name, *in);
+
+ /// Read one block from the network and write it down
+ Block block = state.block_in->read();
+
+ if (!block)
+ {
+ state.read_all_data = true; /// An empty block sets the read_all_data flag.
+ return false;
+ }
+
+ if (scalar)
+ {
+ /// Scalar value
+ query_context->addScalar(temporary_id.table_name, block);
+ }
+ else if (!state.need_receive_data_for_insert && !state.need_receive_data_for_input)
+ {
+ /// Data for external tables
+
+ auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal);
+ StoragePtr storage;
+ /// If such a table does not exist, create it.
+ if (resolved)
+ {
+ storage = DatabaseCatalog::instance().getTable(resolved, query_context);
+ }
+ else
+ {
+ NamesAndTypesList columns = block.getNamesAndTypesList();
+ auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {});
+ storage = temporary_table.getTable(); /// Take the table pointer before the holder is moved into the context.
+ query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table));
+ }
+ auto metadata_snapshot = storage->getInMemoryMetadataPtr();
+ /// The data will be written directly to the table.
+ QueryPipeline temporary_table_out(storage->write(ASTPtr(), metadata_snapshot, query_context, /*async_insert=*/false));
+ PushingPipelineExecutor executor(temporary_table_out);
+ executor.start();
+ executor.push(block);
+ executor.finish();
+ }
+ else if (state.need_receive_data_for_input)
+ {
+ /// 'input' table function.
+ state.block_for_input = block;
+ }
+ else
+ {
+ /// INSERT query.
+ state.block_for_insert = block;
+ }
+ return true;
+}
+
+
+/// Reads and discards one data block from the client.
+/// Returns whether the block was non-empty; an empty block sets `state.read_all_data`.
+/// When `throw_exception` is true, throws after the block has been consumed.
+bool TCPHandler::receiveUnexpectedData(bool throw_exception)
+{
+    String discarded_table_name;
+    readStringBinary(discarded_table_name, *in);
+
+    /// Decode with the compression mode remembered from the last received query.
+    const bool is_compressed = (last_block_in.compression == Protocol::Compression::Enable);
+
+    std::shared_ptr<ReadBuffer> maybe_compressed_in;
+    if (!is_compressed)
+        maybe_compressed_in = in;
+    else
+        maybe_compressed_in = std::make_shared<CompressedReadBuffer>(*in, /* allow_different_codecs */ true);
+
+    auto discarding_reader = std::make_shared<NativeReader>(*maybe_compressed_in, client_tcp_protocol_version);
+    const bool got_non_empty_block = static_cast<bool>(discarding_reader->read());
+
+    if (!got_non_empty_block)
+        state.read_all_data = true;
+
+    if (throw_exception)
+        throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet Data received from client");
+
+    return got_non_empty_block;
+}
+
+/// Lazily creates `state.block_in`, the reader for blocks coming from the client.
+/// Does nothing if the reader already exists.
+void TCPHandler::initBlockInput()
+{
+    if (state.block_in)
+        return;
+
+    /// 'allow_different_codecs' is true because parts of the compressed data may have been
+    /// precompressed with a codec other than the one used for the rest.
+    /// Example: data sent by Distributed tables.
+    const bool is_compressed = (state.compression == Protocol::Compression::Enable);
+    if (!is_compressed)
+        state.maybe_compressed_in = in;
+    else
+        state.maybe_compressed_in = std::make_shared<CompressedReadBuffer>(*in, /* allow_different_codecs */ true);
+
+    /// Header comes from the pushing pipeline or from the 'input' function; otherwise it stays empty.
+    Block expected_header;
+    if (state.io.pipeline.pushing())
+        expected_header = state.io.pipeline.getHeader();
+    else if (state.need_receive_data_for_input)
+        expected_header = state.input_header;
+
+    state.block_in = std::make_unique<NativeReader>(
+        *state.maybe_compressed_in,
+        expected_header,
+        client_tcp_protocol_version);
+}
+
+
+/// Lazily creates `state.block_out`, the writer for result blocks, and, if it is still missing,
+/// the (optionally compressing) buffer it writes into. Does nothing if the writer already exists.
+void TCPHandler::initBlockOutput(const Block & block)
+{
+    if (state.block_out)
+        return;
+
+    const Settings & query_settings = query_context->getSettingsRef();
+
+    if (!state.maybe_compressed_out)
+    {
+        std::string codec_name = Poco::toUpper(query_settings.network_compression_method.toString());
+        std::optional<int> codec_level;
+        if (codec_name == "ZSTD")
+            codec_level = query_settings.network_zstd_compression_level;
+
+        const bool is_compressed = (state.compression == Protocol::Compression::Enable);
+        if (!is_compressed)
+        {
+            state.maybe_compressed_out = out;
+        }
+        else
+        {
+            CompressionCodecFactory::instance().validateCodec(codec_name, codec_level, !query_settings.allow_suspicious_codecs, query_settings.allow_experimental_codecs, query_settings.enable_deflate_qpl_codec);
+
+            state.maybe_compressed_out = std::make_shared<CompressedWriteBuffer>(
+                *out, CompressionCodecFactory::instance().get(codec_name, codec_level));
+        }
+    }
+
+    state.block_out = std::make_unique<NativeWriter>(
+        *state.maybe_compressed_out,
+        client_tcp_protocol_version,
+        block.cloneEmpty(),
+        !query_settings.low_cardinality_allow_in_native_format);
+}
+
+/// Lazily creates `state.logs_block_out`, the writer for log blocks.
+/// Does nothing if the writer already exists.
+void TCPHandler::initLogsBlockOutput(const Block & block)
+{
+    if (state.logs_block_out)
+        return;
+
+    /// Log blocks usually hold a single row, so they go to the uncompressed stream.
+    const Settings & query_settings = query_context->getSettingsRef();
+    state.logs_block_out = std::make_unique<NativeWriter>(
+        *out,
+        client_tcp_protocol_version,
+        block.cloneEmpty(),
+        !query_settings.low_cardinality_allow_in_native_format);
+}
+
+
+/// Lazily creates `state.profile_events_block_out`, the writer for profile event blocks.
+/// It writes straight to `out`. Does nothing if the writer already exists.
+void TCPHandler::initProfileEventsBlockOutput(const Block & block)
+{
+    if (state.profile_events_block_out)
+        return;
+
+    const Settings & query_settings = query_context->getSettingsRef();
+    state.profile_events_block_out = std::make_unique<NativeWriter>(
+        *out,
+        client_tcp_protocol_version,
+        block.cloneEmpty(),
+        !query_settings.low_cardinality_allow_in_native_format);
+}
+
+/// Moves the query one step towards cancellation and logs the transition.
+/// With `partial_result_on_first_cancel` enabled, the first cancel only switches NOT_CANCELLED
+/// to READ_CANCELLED; in every other case the status becomes FULLY_CANCELLED.
+void TCPHandler::decreaseCancellationStatus(const std::string & log_message)
+{
+    auto status_before = magic_enum::enum_name(state.cancellation_status);
+
+    /// The setting can only be consulted once a query context exists.
+    bool partial_result_on_first_cancel = false;
+    if (query_context)
+        partial_result_on_first_cancel = query_context->getSettingsRef().partial_result_on_first_cancel;
+
+    const bool is_first_cancel = (state.cancellation_status == CancellationStatus::NOT_CANCELLED);
+    state.cancellation_status = (partial_result_on_first_cancel && is_first_cancel)
+        ? CancellationStatus::READ_CANCELLED
+        : CancellationStatus::FULLY_CANCELLED;
+
+    auto status_after = magic_enum::enum_name(state.cancellation_status);
+    LOG_INFO(log, "Change cancellation status from {} to {}. Log message: {}", status_before, status_after, log_message);
+}
+
+QueryState::CancellationStatus TCPHandler::getQueryCancellationStatus() /// Returns the current cancellation status, first polling the socket for a Cancel packet or a dropped connection.
+{
+ if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED || state.sent_all_data) /// Once all data is sent, the query is reported as fully cancelled.
+ return CancellationStatus::FULLY_CANCELLED;
+
+ if (after_check_cancelled.elapsed() / 1000 < interactive_delay) /// Skip the socket check until interactive_delay has passed; NOTE(review): units are not visible here — confirm.
+ return state.cancellation_status;
+
+ after_check_cancelled.restart();
+
+ /// During request execution, the only packet that can come from the client is the one stopping the query.
+ if (static_cast<ReadBufferFromPocoSocket &>(*in).poll(0)) /// Zero timeout: only checks whether data is already available.
+ {
+ if (in->eof())
+ {
+ LOG_INFO(log, "Client has dropped the connection, cancel the query.");
+ state.cancellation_status = CancellationStatus::FULLY_CANCELLED;
+ state.is_connection_closed = true;
+ return CancellationStatus::FULLY_CANCELLED;
+ }
+
+ UInt64 packet_type = 0;
+ readVarUInt(packet_type, *in);
+
+ switch (packet_type)
+ {
+ case Protocol::Client::Cancel:
+ if (state.empty())
+ throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet Cancel received from client");
+
+ decreaseCancellationStatus("Query was cancelled.");
+
+ return state.cancellation_status;
+
+ default:
+ throw NetException(ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT, "Unknown packet from client {}", toString(packet_type));
+ }
+ }
+
+ return state.cancellation_status;
+}
+
+
+void TCPHandler::sendData(const Block & block) /// Sends `block` as a Data packet; on failure tries to rewind unflushed bytes so the protocol stream stays consistent, then rethrows.
+{
+ initBlockOutput(block);
+
+ size_t prev_bytes_written_out = out->count(); /// Byte counters before this packet, used by the rollback below.
+ size_t prev_bytes_written_compressed_out = state.maybe_compressed_out->count();
+
+ try
+ {
+ /// For testing hedged requests
+ if (unknown_packet_in_send_data)
+ {
+ constexpr UInt64 marker = (1ULL<<63) - 1;
+ --unknown_packet_in_send_data;
+ if (unknown_packet_in_send_data == 0) /// When the countdown reaches zero, the marker value is written before the Data packet.
+ writeVarUInt(marker, *out);
+ }
+
+ writeVarUInt(Protocol::Server::Data, *out);
+ /// Send external table name (empty name is the main table)
+ writeStringBinary("", *out);
+
+ /// For testing hedged requests
+ if (block.rows() > 0 && query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds())
+ {
+ out->next();
+ std::chrono::milliseconds ms(query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds());
+ std::this_thread::sleep_for(ms);
+ }
+
+ state.block_out->write(block);
+ state.maybe_compressed_out->next(); /// Flush the (possibly compressing) block buffer first, then the connection buffer.
+ out->next();
+ }
+ catch (...)
+ {
+ /// In case of an unsuccessful write, if the buffer with written data was not flushed,
+ /// we will roll back the write to avoid breaking the protocol.
+ /// (otherwise the client will not be able to receive the exception after unfinished data
+ /// as it will expect the continuation of the data).
+ /// Otherwise it looks like a hang on the client side or a message like "Data compressed with different methods".
+
+ if (state.compression == Protocol::Compression::Enable)
+ {
+ auto extra_bytes_written_compressed = state.maybe_compressed_out->count() - prev_bytes_written_compressed_out;
+ if (state.maybe_compressed_out->offset() >= extra_bytes_written_compressed) /// Rewind only if all the extra bytes are still in the buffer.
+ state.maybe_compressed_out->position() -= extra_bytes_written_compressed;
+ }
+
+ auto extra_bytes_written_out = out->count() - prev_bytes_written_out;
+ if (out->offset() >= extra_bytes_written_out) /// Same check for the connection buffer.
+ out->position() -= extra_bytes_written_out;
+
+ throw;
+ }
+}
+
+
+/// Sends `block` as a Log packet through the logs block writer and flushes the output buffer.
+void TCPHandler::sendLogData(const Block & block)
+{
+    initLogsBlockOutput(block);
+
+    auto & output = *out;
+    writeVarUInt(Protocol::Server::Log, output);
+    /// Log tag; the empty tag is the default one.
+    writeStringBinary("", output);
+
+    state.logs_block_out->write(block);
+    output.next();
+}
+
+/// Sends a TableColumns packet with the textual form of `columns` and flushes the output buffer.
+void TCPHandler::sendTableColumns(const ColumnsDescription & columns)
+{
+    auto & output = *out;
+    writeVarUInt(Protocol::Server::TableColumns, output);
+
+    /// External table name; the empty name means the main table.
+    writeStringBinary("", output);
+    writeStringBinary(columns.toString(), output);
+
+    output.next();
+}
+
+/// Calls `setAllDataSent()` on the query I/O state, then sends `e` to the client as an Exception packet,
+/// with a stack trace if `with_stack_trace` is set.
+void TCPHandler::sendException(const Exception & e, bool with_stack_trace)
+{
+    state.io.setAllDataSent();
+
+    auto & output = *out;
+    writeVarUInt(Protocol::Server::Exception, output);
+    writeException(e, output, with_stack_trace);
+    output.next();
+}
+
+
+/// Records that all data has been sent, then sends an EndOfStream packet and flushes the output buffer.
+void TCPHandler::sendEndOfStream()
+{
+    state.sent_all_data = true;
+    state.io.setAllDataSent();
+
+    auto & output = *out;
+    writeVarUInt(Protocol::Server::EndOfStream, output);
+    output.next();
+}
+
+
+/// Adds `value` to the progress accumulated for the current query.
+/// The accumulated progress is sent and reset by sendProgress().
+void TCPHandler::updateProgress(const Progress & value)
+{
+    auto & accumulated_progress = state.progress;
+    accumulated_progress.incrementPiecewiseAtomically(value);
+}
+
+
+/// Sends a Progress packet with the progress accumulated since the previous call,
+/// resetting the accumulator, and flushes the output buffer.
+void TCPHandler::sendProgress()
+{
+    writeVarUInt(Protocol::Server::Progress, *out);
+
+    auto progress_delta = state.progress.fetchValuesAndResetPiecewiseAtomically();
+
+    /// Elapsed time is reported as a delta too, so remember the last reported point.
+    const UInt64 elapsed_ns_now = state.watch.elapsedNanoseconds();
+    progress_delta.elapsed_ns = elapsed_ns_now - state.prev_elapsed_ns;
+    state.prev_elapsed_ns = elapsed_ns_now;
+
+    progress_delta.write(*out, client_tcp_protocol_version);
+    out->next();
+}
+
+
+/// Drains the internal text logs queue, merges everything popped into one block
+/// and sends it with sendLogData(). Does nothing without a queue or when the queue is empty.
+void TCPHandler::sendLogs()
+{
+    if (!state.logs_queue)
+        return;
+
+    MutableColumns merged_columns;
+    MutableColumns popped_columns;
+    size_t popped_batches = 0;
+
+    while (state.logs_queue->tryPop(popped_columns))
+    {
+        if (popped_batches == 0)
+        {
+            /// The first batch becomes the accumulator.
+            merged_columns = std::move(popped_columns);
+        }
+        else
+        {
+            /// Later batches are appended to it column by column.
+            for (size_t column_idx = 0; column_idx < merged_columns.size(); ++column_idx)
+                merged_columns[column_idx]->insertRangeFrom(*popped_columns[column_idx], 0, popped_columns[column_idx]->size());
+        }
+        ++popped_batches;
+    }
+
+    if (popped_batches == 0)
+        return;
+
+    Block logs_block = InternalTextLogsQueue::getSampleBlock();
+    logs_block.setColumns(std::move(merged_columns));
+    sendLogData(logs_block);
+}
+
+
+/// Entry point of the connection: runs runImpl() and logs when it completes.
+/// A Poco exception whose what() is "Timeout" is logged and swallowed; any other Poco exception is rethrown.
+void TCPHandler::run()
+{
+    try
+    {
+        runImpl();
+
+        LOG_DEBUG(log, "Done processing connection.");
+    }
+    catch (Poco::Exception & e)
+    {
+        /// Timeout - not an error.
+        const bool is_timeout = (e.what() == "Timeout"sv);
+        if (!is_timeout)
+            throw;
+
+        LOG_DEBUG(log, "Poco::Exception. Code: {}, e.code() = {}, e.displayText() = {}, e.what() = {}", ErrorCodes::POCO_EXCEPTION, e.code(), e.displayText(), e.what());
+    }
+}
+
+/// Returns the address to treat as the client's address.
+/// When `client_info` has a forwarded-for address and "auth_use_forwarded_address" is enabled in the
+/// server config, that address is used with the peer's port; otherwise the socket peer address is returned.
+Poco::Net::SocketAddress TCPHandler::getClientAddress(const ClientInfo & client_info)
+{
+    /// Only the last entry of the comma separated forwarded_for list is taken:
+    /// only the last proxy can be trusted (if any).
+    String last_forwarded = client_info.getLastForwardedFor();
+
+    const bool use_forwarded_address
+        = !last_forwarded.empty() && server.config().getBool("auth_use_forwarded_address", false);
+
+    if (!use_forwarded_address)
+        return socket().peerAddress();
+
+    return Poco::Net::SocketAddress(last_forwarded, socket().peerAddress().port());
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/TCPHandler.h b/contrib/clickhouse/src/Server/TCPHandler.h
new file mode 100644
index 0000000000..235f634afe
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TCPHandler.h
@@ -0,0 +1,294 @@
+#pragma once
+
+#include <optional>
+#include <Poco/Net/TCPServerConnection.h>
+
+#include <base/getFQDNOrHostName.h>
+#include <Common/ProfileEvents.h>
+#include <Common/CurrentMetrics.h>
+#include <Common/Stopwatch.h>
+#include <Common/ThreadStatus.h>
+#include <Core/Protocol.h>
+#include <Core/QueryProcessingStage.h>
+#include <IO/Progress.h>
+#include <IO/TimeoutSetter.h>
+#include <QueryPipeline/BlockIO.h>
+#include <Interpreters/InternalTextLogsQueue.h>
+#include <Interpreters/Context_fwd.h>
+#include <Interpreters/ClientInfo.h>
+#include <Interpreters/ProfileEventsExt.h>
+#include <Formats/NativeReader.h>
+#include <Formats/NativeWriter.h>
+
+#include "IServer.h"
+#include "Server/TCPProtocolStackData.h"
+#include "Storages/MergeTree/RequestResponse.h"
+#include "base/types.h"
+
+
+namespace CurrentMetrics
+{
+ extern const Metric TCPConnection;
+}
+
+namespace Poco { class Logger; }
+
+namespace DB
+{
+
+class Session;
+struct Settings;
+class ColumnsDescription;
+struct ProfileInfo;
+class TCPServer;
+class NativeWriter;
+class NativeReader;
+
+/// State of query processing.
+struct QueryState
+{
+ /// Identifier of the query.
+ String query_id;
+
+ QueryProcessingStage::Enum stage = QueryProcessingStage::Complete;
+ Protocol::Compression compression = Protocol::Compression::Disable;
+
+ /// A queue with internal logs that will be passed to the client. It must be
+ /// destroyed after the input/output blocks (hence it is declared before them),
+ /// because they may contain other threads that use this queue.
+ InternalTextLogsQueuePtr logs_queue;
+ std::unique_ptr<NativeWriter> logs_block_out;
+
+ InternalProfileEventsQueuePtr profile_queue;
+ std::unique_ptr<NativeWriter> profile_events_block_out;
+
+ /// From where to read data for INSERT.
+ std::shared_ptr<ReadBuffer> maybe_compressed_in;
+ std::unique_ptr<NativeReader> block_in;
+
+ /// Where to write result data.
+ std::shared_ptr<WriteBuffer> maybe_compressed_out;
+ std::unique_ptr<NativeWriter> block_out;
+ Block block_for_insert;
+
+ /// Query text.
+ String query;
+ /// Streams of blocks that are processing the query.
+ BlockIO io;
+
+ enum class CancellationStatus: UInt8
+ {
+ FULLY_CANCELLED,
+ READ_CANCELLED,
+ NOT_CANCELLED
+ };
+
+ /// Whether, and to which degree, the request has been cancelled.
+ CancellationStatus cancellation_status = CancellationStatus::NOT_CANCELLED;
+ bool is_connection_closed = false;
+ /// The value reported by empty(); a default-constructed state is empty.
+ bool is_empty = true;
+ /// Data was sent.
+ bool sent_all_data = false;
+ /// Request requires data from the client (INSERT, but not INSERT SELECT).
+ bool need_receive_data_for_insert = false;
+ /// Data was read.
+ bool read_all_data = false;
+
+ /// UUIDs of parts to exclude from the query; no value unless such a list was set.
+ std::optional<std::vector<UUID>> part_uuids_to_ignore;
+
+ /// Request requires data from client for function input()
+ bool need_receive_data_for_input = false;
+ /// temporary place for incoming data block for input()
+ Block block_for_input;
+ /// sample block from StorageInput
+ Block input_header;
+
+ /// If true, the data packets will be skipped instead of reading. Used to recover after errors.
+ bool skipping_data = false;
+
+ /// To output progress, the difference after the previous sending of progress.
+ Progress progress;
+ Stopwatch watch;
+ UInt64 prev_elapsed_ns = 0;
+
+ /// Timeouts setter for current query
+ std::unique_ptr<TimeoutSetter> timeout_setter;
+
+ void reset()
+ {
+ *this = QueryState(); /// Every field goes back to its default by assigning a fresh state.
+ }
+
+ bool empty() const
+ {
+ return is_empty;
+ }
+};
+
+
+struct LastBlockInputParameters
+{
+ Protocol::Compression compression{Protocol::Compression::Disable}; /// Compression is disabled by default.
+};
+
+class TCPHandler : public Poco::Net::TCPServerConnection /// Connection handler created per client socket of the TCP server.
+{
+public:
+ /** parse_proxy_protocol_ - if true, expect and parse the header of PROXY protocol in every connection
+ * and set the information about forwarded address accordingly.
+ * See https://github.com/wolfeidau/proxyv2/blob/master/docs/proxy-protocol.txt
+ *
+ * Note: the immediate IP address is always used for access control (accept-list of IP networks),
+ * because that makes it possible to check the IP ranges of the trusted proxy.
+ * Proxy-forwarded (original client) IP address is used for quota accounting if quota is keyed by forwarded IP.
+ */
+ TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool parse_proxy_protocol_, std::string server_display_name_);
+ TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, TCPProtocolStackData & stack_data, std::string server_display_name_); /// Variant that receives the protocol stack data instead of the PROXY flag.
+ ~TCPHandler() override;
+
+ void run() override;
+
+ /// This method is called right before the query execution.
+ virtual void customizeContext(ContextMutablePtr /*context*/) {}
+
+private:
+ IServer & server;
+ TCPServer & tcp_server;
+ bool parse_proxy_protocol = false;
+ Poco::Logger * log;
+
+ String forwarded_for;
+ String certificate;
+
+ String client_name;
+ UInt64 client_version_major = 0;
+ UInt64 client_version_minor = 0;
+ UInt64 client_version_patch = 0;
+ UInt32 client_tcp_protocol_version = 0;
+ String quota_key;
+
+ /// Connection settings, which are extracted from a context.
+ bool send_exception_with_stack_trace = true;
+ Poco::Timespan send_timeout = Poco::Timespan(DBMS_DEFAULT_SEND_TIMEOUT_SEC, 0);
+ Poco::Timespan receive_timeout = Poco::Timespan(DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC, 0);
+ UInt64 poll_interval = DBMS_DEFAULT_POLL_INTERVAL;
+ UInt64 idle_connection_timeout = 3600;
+ UInt64 interactive_delay = 100000;
+ Poco::Timespan sleep_in_send_tables_status;
+ UInt64 unknown_packet_in_send_data = 0;
+ Poco::Timespan sleep_after_receiving_query;
+
+ std::unique_ptr<Session> session;
+ ContextMutablePtr query_context;
+ ClientInfo::QueryKind query_kind = ClientInfo::QueryKind::NO_QUERY;
+
+ /// Streams for reading/writing from/to client connection socket.
+ std::shared_ptr<ReadBuffer> in;
+ std::shared_ptr<WriteBuffer> out;
+
+ /// Time after the last check to stop the request and send the progress.
+ Stopwatch after_check_cancelled;
+ Stopwatch after_send_progress;
+
+ String default_database;
+
+ /// For inter-server secret (remote_server.*.secret)
+ bool is_interserver_mode = false;
+ /// For DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET
+ String salt;
+ /// For DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET_V2
+ std::optional<UInt64> nonce;
+ String cluster;
+
+ std::mutex task_callback_mutex;
+ std::mutex fatal_error_mutex;
+
+ /// At the moment, only one ongoing query in the connection is supported at a time.
+ QueryState state;
+
+ /// Last block input parameters are saved to be able to receive unexpected data packet sent after exception.
+ LastBlockInputParameters last_block_in;
+
+ CurrentMetrics::Increment metric_increment{CurrentMetrics::TCPConnection};
+
+ ProfileEvents::ThreadIdToCountersSnapshot last_sent_snapshots;
+
+ /// It is the name of the server that will be sent to the client.
+ String server_display_name;
+
+ void runImpl();
+
+ void extractConnectionSettingsFromContext(const ContextPtr & context);
+
+ std::unique_ptr<Session> makeSession();
+
+ bool receiveProxyHeader();
+ void receiveHello();
+ void receiveAddendum();
+ bool receivePacket();
+ void receiveQuery();
+ void receiveIgnoredPartUUIDs();
+ String receiveReadTaskResponseAssumeLocked();
+ std::optional<ParallelReadResponse> receivePartitionMergeTreeReadTaskResponseAssumeLocked();
+ bool receiveData(bool scalar);
+ bool readDataNext();
+ void readData();
+ void skipData();
+ void receiveClusterNameAndSalt();
+
+ bool receiveUnexpectedData(bool throw_exception = true);
+ [[noreturn]] void receiveUnexpectedQuery();
+ [[noreturn]] void receiveUnexpectedIgnoredPartUUIDs();
+ [[noreturn]] void receiveUnexpectedHello();
+ [[noreturn]] void receiveUnexpectedTablesStatusRequest();
+
+ /// Process INSERT query
+ void processInsertQuery();
+
+ /// Process a request that does not require the receiving of data blocks from the client
+ void processOrdinaryQuery();
+
+ void processOrdinaryQueryWithProcessors();
+
+ void processTablesStatusRequest();
+
+ void sendHello();
+ void sendData(const Block & block); /// Write a block to the network.
+ void sendLogData(const Block & block);
+ void sendTableColumns(const ColumnsDescription & columns);
+ void sendException(const Exception & e, bool with_stack_trace);
+ void sendProgress();
+ void sendLogs();
+ void sendEndOfStream();
+ void sendPartUUIDs();
+ void sendReadTaskRequestAssumeLocked();
+ void sendMergeTreeAllRangesAnnounecementAssumeLocked(InitialAllRangesAnnouncement announcement); /// NOTE: "Announecement" is misspelled, but the name is kept as declared.
+ void sendMergeTreeReadTaskRequestAssumeLocked(ParallelReadRequest request);
+ void sendProfileInfo(const ProfileInfo & info);
+ void sendTotals(const Block & totals);
+ void sendExtremes(const Block & extremes);
+ void sendProfileEvents();
+ void sendSelectProfileEvents();
+ void sendInsertProfileEvents();
+ void sendTimezone();
+
+ /// Creates state.block_in/block_out for blocks read/write, depending on whether compression is enabled.
+ void initBlockInput();
+ void initBlockOutput(const Block & block);
+ void initLogsBlockOutput(const Block & block);
+ void initProfileEventsBlockOutput(const Block & block);
+
+ using CancellationStatus = QueryState::CancellationStatus; /// Shorthand for the enum declared in QueryState.
+
+ void decreaseCancellationStatus(const std::string & log_message);
+ CancellationStatus getQueryCancellationStatus();
+
+ /// This function is called from different threads.
+ void updateProgress(const Progress & value);
+
+ Poco::Net::SocketAddress getClientAddress(const ClientInfo & client_info);
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/TCPHandlerFactory.h b/contrib/clickhouse/src/Server/TCPHandlerFactory.h
new file mode 100644
index 0000000000..fde04c6e0a
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TCPHandlerFactory.h
@@ -0,0 +1,74 @@
+#pragma once
+
+#include <Poco/Net/NetException.h>
+#include <Poco/Util/LayeredConfiguration.h>
+#include <Common/logger_useful.h>
+#include "Server/TCPProtocolStackData.h"
+#include <Server/IServer.h>
+#include <Server/TCPHandler.h>
+#include <Server/TCPServerConnectionFactory.h>
+
+namespace Poco { class Logger; }
+
+namespace DB
+{
+
+class TCPHandlerFactory : public TCPServerConnectionFactory
+{
+private:
+ IServer & server;
+ bool parse_proxy_protocol = false;
+ Poco::Logger * log;
+ std::string server_display_name;
+
+ class DummyTCPHandler : public Poco::Net::TCPServerConnection
+ {
+ public:
+ using Poco::Net::TCPServerConnection::TCPServerConnection;
+ void run() override {} /// Does nothing: returned when the real handler could not be constructed.
+ };
+
+public:
+ /** parse_proxy_protocol_ - when true, every connection is expected to start with a PROXY protocol header,
+ * which is parsed to fill in the information about the forwarded address.
+ * See https://github.com/wolfeidau/proxyv2/blob/master/docs/proxy-protocol.txt
+ */
+ TCPHandlerFactory(IServer & server_, bool secure_, bool parse_proxy_protocol_)
+ : server(server_)
+ , parse_proxy_protocol(parse_proxy_protocol_)
+ , log(&Poco::Logger::get(std::string("TCP") + (secure_ ? "S" : "") + "HandlerFactory"))
+ , server_display_name(server.config().getString("display_name", getFQDNOrHostName()))
+ {}
+
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override
+ {
+ try
+ {
+ LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
+ /// The PROXY protocol flag is handed over to the handler.
+ return new TCPHandler(server, tcp_server, socket, parse_proxy_protocol, server_display_name);
+ }
+ catch (const Poco::Net::NetException &)
+ {
+ LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent).");
+ return new DummyTCPHandler(socket);
+ }
+ }
+
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData & stack_data) override
+ {
+ try
+ {
+ LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
+ /// Protocol stack variant: the handler receives the stack data instead of the PROXY flag.
+ return new TCPHandler(server, tcp_server, socket, stack_data, server_display_name);
+ }
+ catch (const Poco::Net::NetException &)
+ {
+ LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent).");
+ return new DummyTCPHandler(socket);
+ }
+ }
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/TCPProtocolStackData.h b/contrib/clickhouse/src/Server/TCPProtocolStackData.h
new file mode 100644
index 0000000000..4ad401e723
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TCPProtocolStackData.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <string>
+#include <Poco/Net/StreamSocket.h>
+
+namespace DB
+{
+
+// Data to communicate between protocol layers
+struct TCPProtocolStackData
+{
+ /// Socket of the connection. A layer may replace the socket implementation (the TLS layer, for example).
+ Poco::Net::StreamSocket socket;
+ /// Host reported by the PROXY layer.
+ std::string forwarded_for;
+ /// Certificate path, passed from the TLS layer to the TCP layer.
+ std::string certificate;
+ /// Default database, passed from the endpoint configuration to the TCP layer.
+ std::string default_database;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/TCPProtocolStackFactory.h b/contrib/clickhouse/src/Server/TCPProtocolStackFactory.h
new file mode 100644
index 0000000000..7373e6e1c4
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TCPProtocolStackFactory.h
@@ -0,0 +1,92 @@
+#pragma once
+
+#include <Server/TCPServerConnectionFactory.h>
+#include <Server/IServer.h>
+#include <Server/TCPProtocolStackHandler.h>
+#include <Poco/Logger.h>
+#include <Poco/Net/NetException.h>
+#include <Common/logger_useful.h>
+#include <Access/Common/AllowedClientHosts.h>
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int UNKNOWN_ADDRESS_PATTERN_TYPE;
+ extern const int IP_ADDRESS_NOT_ALLOWED;
+}
+
+
+class TCPProtocolStackFactory : public TCPServerConnectionFactory /// Creates handlers that run a configured list of protocol layers.
+{
+private:
+ IServer & server [[maybe_unused]];
+ Poco::Logger * log;
+ std::string conf_name;
+ std::vector<TCPServerConnectionFactory::Ptr> stack; /// Layer factories, in the order they were given or appended.
+ AllowedClientHosts allowed_client_hosts; /// Filled from `<conf_name>.networks`; left empty when that key is absent.
+
+ class DummyTCPHandler : public Poco::Net::TCPServerConnection
+ {
+ public:
+ using Poco::Net::TCPServerConnection::TCPServerConnection;
+ void run() override {} /// Does nothing: returned when the stack handler could not be constructed.
+ };
+
+public:
+ template <typename... T>
+ explicit TCPProtocolStackFactory(IServer & server_, const std::string & conf_name_, T... factory)
+ : server(server_), log(&Poco::Logger::get("TCPProtocolStackFactory")), conf_name(conf_name_), stack({factory...})
+ {
+ const auto & config = server.config();
+ /// Fill list of allowed hosts.
+ const auto networks_config = conf_name + ".networks";
+ if (config.has(networks_config))
+ {
+ Poco::Util::AbstractConfiguration::Keys keys;
+ config.keys(networks_config, keys);
+ for (const String & key : keys)
+ {
+ String value = config.getString(networks_config + "." + key);
+ if (key.starts_with("ip"))
+ allowed_client_hosts.addSubnet(value);
+ else if (key.starts_with("host_regexp")) /// Must be tested before "host", which is a prefix of it.
+ allowed_client_hosts.addNameRegexp(value);
+ else if (key.starts_with("host"))
+ allowed_client_hosts.addName(value);
+ else
+ throw Exception(ErrorCodes::UNKNOWN_ADDRESS_PATTERN_TYPE, "Unknown address pattern type: {}", key);
+ }
+ }
+ }
+
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override
+ {
+ if (!allowed_client_hosts.empty() && !allowed_client_hosts.contains(socket.peerAddress().host()))
+ throw Exception(ErrorCodes::IP_ADDRESS_NOT_ALLOWED, "Connections from {} are not allowed", socket.peerAddress().toString());
+
+ try
+ {
+ LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
+ return new TCPProtocolStackHandler(server, tcp_server, socket, stack, conf_name);
+ }
+ catch (const Poco::Net::NetException &)
+ {
+ LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent).");
+ return new DummyTCPHandler(socket);
+ }
+ }
+
+ void append(TCPServerConnectionFactory::Ptr factory)
+ {
+ stack.push_back(std::move(factory));
+ }
+
+ size_t size() const { return stack.size(); } /// Number of layers; const so it is usable on a const factory.
+ bool empty() const { return stack.empty(); } /// True when no layer has been added.
+};
+
+
+}
diff --git a/contrib/clickhouse/src/Server/TCPProtocolStackHandler.h b/contrib/clickhouse/src/Server/TCPProtocolStackHandler.h
new file mode 100644
index 0000000000..e16a6b6b2c
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TCPProtocolStackHandler.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <Server/TCPServerConnectionFactory.h>
+#include <Server/TCPServer.h>
+#include <Poco/Util/LayeredConfiguration.h>
+#include <Server/IServer.h>
+#include <Server/TCPProtocolStackData.h>
+
+
+namespace DB
+{
+
+
+class TCPProtocolStackHandler : public Poco::Net::TCPServerConnection
+{
+ using StreamSocket = Poco::Net::StreamSocket;
+ using TCPServerConnection = Poco::Net::TCPServerConnection;
+private:
+ IServer & server;
+ TCPServer & tcp_server;
+ std::vector<TCPServerConnectionFactory::Ptr> stack;
+ std::string conf_name;
+
+public:
+ TCPProtocolStackHandler(IServer & server_, TCPServer & tcp_server_, const StreamSocket & socket, const std::vector<TCPServerConnectionFactory::Ptr> & stack_, const std::string & conf_name_)
+ : TCPServerConnection(socket), server(server_), tcp_server(tcp_server_), stack(stack_), conf_name(conf_name_)
+ {}
+
+ void run() override
+ {
+ /// Runs the layers one after another; they exchange results through the shared stack data.
+ TCPProtocolStackData stack_data;
+ stack_data.socket = socket();
+ stack_data.default_database = server.config().getString(conf_name + ".default_database", "");
+ for (auto & layer_factory : stack)
+ {
+ std::unique_ptr<TCPServerConnection> layer(layer_factory->createConnection(socket(), tcp_server, stack_data));
+ layer->run();
+ if (stack_data.socket != socket()) /// The layer replaced the socket: continue on the new one.
+ socket() = stack_data.socket;
+ }
+ }
+};
+
+
+}
diff --git a/contrib/clickhouse/src/Server/TCPServer.cpp b/contrib/clickhouse/src/Server/TCPServer.cpp
new file mode 100644
index 0000000000..380c4ef992
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TCPServer.cpp
@@ -0,0 +1,36 @@
+#include <Poco/Net/TCPServerConnectionFactory.h>
+#include <Server/TCPServer.h>
+
+namespace DB
+{
+
+class TCPServerConnectionFactoryImpl : public Poco::Net::TCPServerConnectionFactory
+{
+ /// Server that is passed along with every connection request.
+ TCPServer & tcp_server;
+ /// The wrapped factory that actually builds the connections.
+ DB::TCPServerConnectionFactory::Ptr factory;
+
+public:
+ TCPServerConnectionFactoryImpl(TCPServer & tcp_server_, DB::TCPServerConnectionFactory::Ptr factory_)
+ : tcp_server(tcp_server_), factory(factory_)
+ {}
+
+ /// Adapts the Poco signature: forwards to the wrapped factory together with tcp_server.
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket) override
+ { return factory->createConnection(socket, tcp_server); }
+};
+
+TCPServer::TCPServer(
+ TCPServerConnectionFactory::Ptr factory_,
+ Poco::ThreadPool & thread_pool,
+ Poco::Net::ServerSocket & socket_,
+ Poco::Net::TCPServerParams::Ptr params)
+ : Poco::Net::TCPServer(new TCPServerConnectionFactoryImpl(*this, factory_), thread_pool, socket_, params) /// The adapter only stores the reference to *this.
+ , factory(factory_)
+ , socket(socket_)
+ , is_open(true)
+ , port_number(socket.address().port()) /// Read once here, from the already initialized socket member.
+{}
+
+}
diff --git a/contrib/clickhouse/src/Server/TCPServer.h b/contrib/clickhouse/src/Server/TCPServer.h
new file mode 100644
index 0000000000..219fed5342
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TCPServer.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include <Poco/Net/TCPServer.h>
+
+#include <base/types.h>
+#include <Server/TCPServerConnectionFactory.h>
+
+
+namespace DB
+{
+class Context;
+
+class TCPServer : public Poco::Net::TCPServer
+{
+public:
+ explicit TCPServer(
+ TCPServerConnectionFactory::Ptr factory,
+ Poco::ThreadPool & thread_pool,
+ Poco::Net::ServerSocket & socket,
+ Poco::Net::TCPServerParams::Ptr params = new Poco::Net::TCPServerParams);
+
+ /// Close the listening socket and ask existing connections to stop serving queries.
+ void stop()
+ {
+ Poco::Net::TCPServer::stop();
+ // This notifies already established connections that they should stop serving
+ // queries and close their socket as soon as they can.
+ is_open = false;
+ // Poco's stop() stops listening on the socket but leaves it open.
+ // To be able to hand over control of the listening port to a new server, and
+ // to get fast connection refusal instead of timeouts, we also need to close
+ // the listening socket.
+ socket.close();
+ }
+
+ bool isOpen() const { return is_open; } /// False once stop() has run.
+
+ UInt16 portNumber() const { return port_number; } /// Port stored by the constructor, so it does not depend on the socket still being open.
+
+private:
+ TCPServerConnectionFactory::Ptr factory;
+ Poco::Net::ServerSocket socket;
+ std::atomic<bool> is_open;
+ UInt16 port_number;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/TCPServerConnectionFactory.h b/contrib/clickhouse/src/Server/TCPServerConnectionFactory.h
new file mode 100644
index 0000000000..18b30557b0
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TCPServerConnectionFactory.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include <Poco/SharedPtr.h>
+#include <Server/TCPProtocolStackData.h>
+
+namespace Poco
+{
+namespace Net
+{
+ class StreamSocket;
+ class TCPServerConnection;
+}
+}
+namespace DB
+{
+class TCPServer;
+
+class TCPServerConnectionFactory
+{
+public:
+ using Ptr = Poco::SharedPtr<TCPServerConnectionFactory>;
+
+ virtual ~TCPServerConnectionFactory() = default;
+
+ /// Same as Poco::Net::TCPServerConnectionFactory::createConnection, except that the TCPServer is passed as well.
+ virtual Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) = 0;
+ virtual Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData &/* stack_data */)
+ {
+ return createConnection(socket, tcp_server); /// Default: the stack data is ignored.
+ }
+};
+}
diff --git a/contrib/clickhouse/src/Server/TLSHandler.h b/contrib/clickhouse/src/Server/TLSHandler.h
new file mode 100644
index 0000000000..d4d7584d12
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TLSHandler.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include <Poco/Net/TCPServerConnection.h>
+#include <Poco/SharedPtr.h>
+#include <Common/Exception.h>
+#include <Server/TCPProtocolStackData.h>
+
+#if USE_SSL
+# include <Poco/Net/Context.h>
+# include <Poco/Net/SecureStreamSocket.h>
+# include <Poco/Net/SSLManager.h>
+#endif
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int SUPPORT_IS_DISABLED;
+}
+
+class TLSHandler : public Poco::Net::TCPServerConnection /// Protocol layer that replaces the connection socket with a TLS one.
+{
+#if USE_SSL
+ using SecureStreamSocket = Poco::Net::SecureStreamSocket;
+ using SSLManager = Poco::Net::SSLManager;
+ using Context = Poco::Net::Context;
+#endif
+ using StreamSocket = Poco::Net::StreamSocket;
+public:
+ explicit TLSHandler(const StreamSocket & socket, const std::string & key_, const std::string & certificate_, TCPProtocolStackData & stack_data_)
+ : Poco::Net::TCPServerConnection(socket)
+ , key(key_)
+ , certificate(certificate_)
+ , stack_data(stack_data_)
+ {}
+
+ void run() override
+ {
+#if USE_SSL
+ auto ctx = SSLManager::instance().defaultServerContext(); /// key/certificate are not applied to the context: the code doing that is disabled below.
+ // if (!key.empty() && !certificate.empty())
+ // ctx = new Context(Context::Usage::SERVER_USE, key, certificate, ctx->getCAPaths().caLocation);
+ socket() = SecureStreamSocket::attach(socket(), ctx);
+ stack_data.socket = socket(); /// Hand the secured socket and the certificate path over through the stack data.
+ stack_data.certificate = certificate;
+#else
+ throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "SSL support for TCP protocol is disabled because Poco library was built without NetSSL support.");
+#endif
+ }
+private:
+ std::string key [[maybe_unused]];
+ std::string certificate [[maybe_unused]];
+ TCPProtocolStackData & stack_data [[maybe_unused]]; /// Not owned: the referenced object must outlive this handler.
+};
+
+
+}
diff --git a/contrib/clickhouse/src/Server/TLSHandlerFactory.h b/contrib/clickhouse/src/Server/TLSHandlerFactory.h
new file mode 100644
index 0000000000..9e3002d297
--- /dev/null
+++ b/contrib/clickhouse/src/Server/TLSHandlerFactory.h
@@ -0,0 +1,64 @@
+#pragma once
+
+#include <Poco/Logger.h>
+#include <Poco/Net/TCPServerConnection.h>
+#include <Poco/Net/NetException.h>
+#include <Poco/Util/LayeredConfiguration.h>
+#include <Server/TLSHandler.h>
+#include <Server/IServer.h>
+#include <Server/TCPServer.h>
+#include <Server/TCPProtocolStackData.h>
+#include <Common/logger_useful.h>
+
+
+namespace DB
+{
+
+
+class TLSHandlerFactory : public TCPServerConnectionFactory
+{
+private:
+ IServer & server;
+ Poco::Logger * log;
+ std::string conf_name;
+
+ /// Connection that does nothing; returned when the TLS handler could not be constructed.
+ class DummyTCPHandler : public Poco::Net::TCPServerConnection
+ {
+ public:
+ using Poco::Net::TCPServerConnection::TCPServerConnection;
+ void run() override {}
+ };
+
+ /// Storage for the stack data of a handler that is created without a caller-provided one.
+ /// It is a separate base class so that it is constructed before the TLSHandler base that refers to it.
+ struct OwnedStackData
+ {
+ TCPProtocolStackData owned_stack_data;
+ };
+
+ /// TLSHandler stores a reference to its stack data and writes through it in run(),
+ /// so the data has to live as long as the handler. This variant carries its own.
+ class StandaloneTLSHandler : private OwnedStackData, public TLSHandler
+ {
+ public:
+ StandaloneTLSHandler(const Poco::Net::StreamSocket & socket, const std::string & key, const std::string & certificate)
+ : TLSHandler(socket, key, certificate, owned_stack_data)
+ {}
+ };
+
+ /// Reads `<conf_name><suffix>` from the server configuration; an empty string when the key is absent.
+ std::string endpointSetting(const char * suffix) const
+ {
+ return server.config().getString(conf_name + suffix, "");
+ }
+
+public:
+ explicit TLSHandlerFactory(IServer & server_, const std::string & conf_name_)
+ : server(server_), log(&Poco::Logger::get("TLSHandlerFactory")), conf_name(conf_name_)
+ {
+ }
+
+ /// Creates a TLS handler without caller-provided stack data.
+ /// The handler owns its TCPProtocolStackData: handing it a function-local object by reference
+ /// would leave it with a dangling reference as soon as this function returns.
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/) override
+ {
+ try
+ {
+ LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
+ return new StandaloneTLSHandler(socket, endpointSetting(".privateKeyFile"), endpointSetting(".certificateFile"));
+ }
+ catch (const Poco::Net::NetException &)
+ {
+ LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent).");
+ return new DummyTCPHandler(socket);
+ }
+ }
+
+ /// Creates a TLS handler that reports into `stack_data`, which must outlive the returned handler.
+ Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/, TCPProtocolStackData & stack_data) override
+ {
+ try
+ {
+ LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
+ return new TLSHandler(
+ socket,
+ endpointSetting(".privateKeyFile"),
+ endpointSetting(".certificateFile"),
+ stack_data);
+ }
+ catch (const Poco::Net::NetException &)
+ {
+ LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent).");
+ return new DummyTCPHandler(socket);
+ }
+ }
+};
+
+
+}
diff --git a/contrib/clickhouse/src/Server/WebUIRequestHandler.cpp b/contrib/clickhouse/src/Server/WebUIRequestHandler.cpp
new file mode 100644
index 0000000000..131badbe83
--- /dev/null
+++ b/contrib/clickhouse/src/Server/WebUIRequestHandler.cpp
@@ -0,0 +1,75 @@
+#include "WebUIRequestHandler.h"
+#include "IServer.h"
+
+#include <Poco/Net/HTTPServerRequest.h>
+#include <Poco/Net/HTTPServerResponse.h>
+#include <Poco/Util/LayeredConfiguration.h>
+
+#include <IO/HTTPCommon.h>
+
+#include <re2/re2.h>
+
+#include <incbin.h>
+
+#include "clickhouse_config.h"
+
+/// Embedded HTML pages
+INCBIN(resource_play_html, SOURCE_DIR "/programs/server/play.html");
+INCBIN(resource_dashboard_html, SOURCE_DIR "/programs/server/dashboard.html");
+INCBIN(resource_uplot_js, SOURCE_DIR "/programs/server/js/uplot.js");
+
+
+namespace DB
+{
+
+/// Only stores the reference to the server.
+WebUIRequestHandler::WebUIRequestHandler(IServer & server_)
+ : server(server_)
+{}
+
+
+void WebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response)
+{
+ /// Serves the resources embedded into the binary; the one to send is chosen by the request URI.
+ const auto keep_alive_timeout = server.config().getUInt("keep_alive_timeout", 10);
+ const auto & uri = request.getURI();
+
+ /// Headers that are set for every answer, the "not found" one included.
+ response.setContentType("text/html; charset=UTF-8");
+ if (request.getVersion() == HTTPServerRequest::HTTP_1_1)
+ response.setChunkedTransferEncoding(true);
+
+ setResponseDefaultHeaders(response, keep_alive_timeout);
+
+ if (uri.starts_with("/play"))
+ {
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_OK);
+ *response.send() << std::string_view(reinterpret_cast<const char *>(gresource_play_htmlData), gresource_play_htmlSize);
+ }
+ else if (uri.starts_with("/dashboard"))
+ {
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_OK);
+ std::string dashboard_html(reinterpret_cast<const char *>(gresource_dashboard_htmlData), gresource_dashboard_htmlSize);
+
+ /// The page links to uPlot through an external https URL. Point that link at the copy
+ /// embedded in the server (served below as /js/uplot.js), so that the same HTML works both
+ /// when opened without a server and when hosted here. The script is deliberately not inlined
+ /// into the HTML, to keep the "view-source" perfectly readable.
+ static re2::RE2 uplot_url = R"(https://[^\s"'`]+u[Pp]lot[^\s"'`]*\.js)";
+ RE2::Replace(&dashboard_html, uplot_url, "/js/uplot.js");
+
+ *response.send() << dashboard_html;
+ }
+ else if (uri == "/js/uplot.js")
+ {
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_OK);
+ *response.send() << std::string_view(reinterpret_cast<const char *>(gresource_uplot_jsData), gresource_uplot_jsSize);
+ }
+ else
+ {
+ response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_NOT_FOUND);
+ *response.send() << "Not found.\n";
+ }
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/WebUIRequestHandler.h b/contrib/clickhouse/src/Server/WebUIRequestHandler.h
new file mode 100644
index 0000000000..09fe62d41c
--- /dev/null
+++ b/contrib/clickhouse/src/Server/WebUIRequestHandler.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <Server/HTTP/HTTPRequestHandler.h>
+
+
+namespace DB
+{
+
+class IServer;
+
+/// Responds with the HTML pages of the web UI, which let the user send queries and see the results in a browser.
+class WebUIRequestHandler : public HTTPRequestHandler
+{
+ /// Server whose configuration is read while handling a request.
+ IServer & server;
+
+public:
+ WebUIRequestHandler(IServer & server_);
+ void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override;
+};
+
+}
diff --git a/contrib/clickhouse/src/Server/waitServersToFinish.cpp b/contrib/clickhouse/src/Server/waitServersToFinish.cpp
new file mode 100644
index 0000000000..3b07c08206
--- /dev/null
+++ b/contrib/clickhouse/src/Server/waitServersToFinish.cpp
@@ -0,0 +1,39 @@
+#include <Server/waitServersToFinish.h>
+#include <Server/ProtocolServerAdapter.h>
+#include <base/sleep.h>
+
+namespace DB
+{
+
+size_t waitServersToFinish(std::vector<DB::ProtocolServerAdapter> & servers, std::mutex & mutex, size_t seconds_to_wait)
+{
+ /// Keeps stopping the servers and polling their connection counts until no connection is left
+ /// or the time budget is used up. Returns the number of connections seen by the last poll.
+
+ const size_t poll_interval_ms = 100;
+ const size_t time_budget_ms = 1000 * seconds_to_wait;
+ size_t accounted_ms = 0;
+
+ while (true)
+ {
+ size_t open_connections = 0;
+
+ {
+ std::scoped_lock lock{mutex};
+ for (auto & server : servers)
+ {
+ server.stop();
+ open_connections += server.currentConnections();
+ }
+ }
+
+ /// Finish when everything is closed, or when one more sleep would reach the time budget.
+ accounted_ms += poll_interval_ms;
+ if (open_connections == 0 || accounted_ms >= time_budget_ms)
+ return open_connections;
+
+ sleepForMilliseconds(poll_interval_ms);
+ }
+}
+
+}
diff --git a/contrib/clickhouse/src/Server/waitServersToFinish.h b/contrib/clickhouse/src/Server/waitServersToFinish.h
new file mode 100644
index 0000000000..b6daa02596
--- /dev/null
+++ b/contrib/clickhouse/src/Server/waitServersToFinish.h
@@ -0,0 +1,10 @@
+#pragma once
+#include <Core/Types.h>
+
+namespace DB
+{
+class ProtocolServerAdapter;
+
+size_t waitServersToFinish(std::vector<ProtocolServerAdapter> & servers, std::mutex & mutex, size_t seconds_to_wait); /// Returns the number of connections still open when the wait ends.
+
+}