diff options
author | vitalyisaev <vitalyisaev@ydb.tech> | 2023-11-14 09:58:56 +0300 |
---|---|---|
committer | vitalyisaev <vitalyisaev@ydb.tech> | 2023-11-14 10:20:20 +0300 |
commit | c2b2dfd9827a400a8495e172a56343462e3ceb82 (patch) | |
tree | cd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/Server | |
parent | d4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff) | |
download | ydb-c2b2dfd9827a400a8495e172a56343462e3ceb82.tar.gz |
YQ Connector: move tests from yql to ydb (OSS)
Перенос папки с тестами на Коннектор из папки yql в папку ydb (синхронизируется с github).
Diffstat (limited to 'contrib/clickhouse/src/Server')
80 files changed, 12607 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/Server/CertificateReloader.cpp b/contrib/clickhouse/src/Server/CertificateReloader.cpp new file mode 100644 index 0000000000..8795d4807d --- /dev/null +++ b/contrib/clickhouse/src/Server/CertificateReloader.cpp @@ -0,0 +1,131 @@ +#include "CertificateReloader.h" + +#if USE_SSL + +#include <Common/logger_useful.h> +#include <base/errnoToString.h> +#include <Poco/Net/Context.h> +#include <Poco/Net/SSLManager.h> +#include <Poco/Net/Utility.h> + + +namespace DB +{ + +namespace +{ +/// Call set process for certificate. +int callSetCertificate(SSL * ssl, [[maybe_unused]] void * arg) +{ + return CertificateReloader::instance().setCertificate(ssl); +} + +} + + +/// This is callback for OpenSSL. It will be called on every connection to obtain a certificate and private key. +int CertificateReloader::setCertificate(SSL * ssl) +{ + auto current = data.get(); + if (!current) + return -1; + + SSL_use_certificate(ssl, const_cast<X509 *>(current->cert.certificate())); + SSL_use_PrivateKey(ssl, const_cast<EVP_PKEY *>(static_cast<const EVP_PKEY *>(current->key))); + + int err = SSL_check_private_key(ssl); + if (err != 1) + { + std::string msg = Poco::Net::Utility::getLastError(); + LOG_ERROR(log, "Unusable key-pair {}", msg); + return -1; + } + + return 1; +} + + +void CertificateReloader::init() +{ + LOG_DEBUG(log, "Initializing certificate reloader."); + + /// Set a callback for OpenSSL to allow get the updated cert and key. 
+ + auto* ctx = Poco::Net::SSLManager::instance().defaultServerContext()->sslContext(); + SSL_CTX_set_cert_cb(ctx, callSetCertificate, nullptr); + init_was_not_made = false; +} + + +void CertificateReloader::tryLoad(const Poco::Util::AbstractConfiguration & config) +{ + /// If at least one of the files is modified - recreate + + std::string new_cert_path = config.getString("openSSL.server.certificateFile", ""); + std::string new_key_path = config.getString("openSSL.server.privateKeyFile", ""); + + /// For empty paths (that means, that user doesn't want to use certificates) + /// no processing required + + if (new_cert_path.empty() || new_key_path.empty()) + { + LOG_INFO(log, "One of paths is empty. Cannot apply new configuration for certificates. Fill all paths and try again."); + } + else + { + bool cert_file_changed = cert_file.changeIfModified(std::move(new_cert_path), log); + bool key_file_changed = key_file.changeIfModified(std::move(new_key_path), log); + std::string pass_phrase = config.getString("openSSL.server.privateKeyPassphraseHandler.options.password", ""); + + if (cert_file_changed || key_file_changed) + { + LOG_DEBUG(log, "Reloading certificate ({}) and key ({}).", cert_file.path, key_file.path); + data.set(std::make_unique<const Data>(cert_file.path, key_file.path, pass_phrase)); + LOG_INFO(log, "Reloaded certificate ({}) and key ({}).", cert_file.path, key_file.path); + } + + /// If callback is not set yet + try + { + if (init_was_not_made) + init(); + } + catch (...) 
+ { + init_was_not_made = true; + LOG_ERROR(log, getCurrentExceptionMessageAndPattern(/* with_stacktrace */ false)); + } + } +} + + +CertificateReloader::Data::Data(std::string cert_path, std::string key_path, std::string pass_phrase) + : cert(cert_path), key(/* public key */ "", /* private key */ key_path, pass_phrase) +{ +} + + +bool CertificateReloader::File::changeIfModified(std::string new_path, Poco::Logger * logger) +{ + std::error_code ec; + std::filesystem::file_time_type new_modification_time = std::filesystem::last_write_time(new_path, ec); + if (ec) + { + LOG_ERROR(logger, "Cannot obtain modification time for {} file {}, skipping update. {}", + description, new_path, errnoToString(ec.value())); + return false; + } + + if (new_path != path || new_modification_time != modification_time) + { + path = new_path; + modification_time = new_modification_time; + return true; + } + + return false; +} + +} + +#endif diff --git a/contrib/clickhouse/src/Server/CertificateReloader.h b/contrib/clickhouse/src/Server/CertificateReloader.h new file mode 100644 index 0000000000..e4db674c14 --- /dev/null +++ b/contrib/clickhouse/src/Server/CertificateReloader.h @@ -0,0 +1,84 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_SSL + +#include <string> +#include <filesystem> + +#include <Poco/Logger.h> +#include <Poco/Util/AbstractConfiguration.h> +#include <openssl/ssl.h> +#include <openssl/x509v3.h> +#include <Poco/Crypto/RSAKey.h> +#include <Poco/Crypto/X509Certificate.h> +#include <Common/MultiVersion.h> + + +namespace DB +{ + +/// The CertificateReloader singleton performs 2 functions: +/// 1. Dynamic reloading of TLS key-pair when requested by server: +/// Server config reloader notifies CertificateReloader when the config changes. +/// On changed config, CertificateReloader reloads certs from disk. +/// 2. Implement `SSL_CTX_set_cert_cb` to set certificate for a new connection: +/// OpenSSL invokes a callback to setup a connection. 
+class CertificateReloader +{ +public: + using stat_t = struct stat; + + /// Singleton + CertificateReloader(CertificateReloader const &) = delete; + void operator=(CertificateReloader const &) = delete; + static CertificateReloader & instance() + { + static CertificateReloader instance; + return instance; + } + + /// Initialize the callback and perform the initial cert loading + void init(); + + /// Handle configuration reload + void tryLoad(const Poco::Util::AbstractConfiguration & config); + + /// A callback for OpenSSL + int setCertificate(SSL * ssl); + +private: + CertificateReloader() = default; + + Poco::Logger * log = &Poco::Logger::get("CertificateReloader"); + + struct File + { + const char * description; + explicit File(const char * description_) : description(description_) {} + + std::string path; + std::filesystem::file_time_type modification_time; + + bool changeIfModified(std::string new_path, Poco::Logger * logger); + }; + + File cert_file{"certificate"}; + File key_file{"key"}; + + struct Data + { + Poco::Crypto::X509Certificate cert; + Poco::Crypto::EVPPKey key; + + Data(std::string cert_path, std::string key_path, std::string pass_phrase); + }; + + MultiVersion<Data> data; + bool init_was_not_made = true; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/Server/GRPCServer.cpp b/contrib/clickhouse/src/Server/GRPCServer.cpp new file mode 100644 index 0000000000..4d4be31e23 --- /dev/null +++ b/contrib/clickhouse/src/Server/GRPCServer.cpp @@ -0,0 +1,1898 @@ +#include "GRPCServer.h" +#include <limits> +#include <memory> +#if USE_GRPC + +#include <Columns/ColumnString.h> +#include <Columns/ColumnsNumber.h> +#include <Common/CurrentThread.h> +#include <Common/SettingsChanges.h> +#include <Common/setThreadName.h> +#include <Common/Stopwatch.h> +#include <Common/ThreadPool.h> +#include <DataTypes/DataTypeFactory.h> +#include <QueryPipeline/ProfileInfo.h> +#include <Interpreters/Context.h> +#include <Interpreters/InternalTextLogsQueue.h> +#include 
<Interpreters/executeQuery.h> +#include <Interpreters/Session.h> +#include <IO/CompressionMethod.h> +#include <IO/ConcatReadBuffer.h> +#include <IO/ReadBufferFromString.h> +#include <IO/ReadHelpers.h> +#include <Parsers/parseQuery.h> +#include <Parsers/ASTIdentifier_fwd.h> +#include <Parsers/ASTInsertQuery.h> +#include <Parsers/ASTQueryWithOutput.h> +#include <Parsers/ParserQuery.h> +#include <Processors/Executors/PullingAsyncPipelineExecutor.h> +#include <Processors/Executors/PullingPipelineExecutor.h> +#include <Processors/Executors/PushingPipelineExecutor.h> +#include <Processors/Executors/CompletedPipelineExecutor.h> +#include <Processors/Executors/PipelineExecutor.h> +#include <Processors/Formats/IInputFormat.h> +#include <Processors/Formats/IOutputFormat.h> +#include <Processors/Sinks/SinkToStorage.h> +#include <Processors/Sinks/EmptySink.h> +#include <QueryPipeline/QueryPipelineBuilder.h> +#include <Server/IServer.h> +#include <Storages/IStorage.h> +#include <Poco/FileStream.h> +#include <Poco/StreamCopier.h> +#include <Poco/Util/LayeredConfiguration.h> +#include <base/range.h> +#include <Common/logger_useful.h> +#error #include <grpc++/security/server_credentials.h> +#error #include <grpc++/server.h> +#error #include <grpc++/server_builder.h> + + +using GRPCService = clickhouse::grpc::ClickHouse::AsyncService; +using GRPCQueryInfo = clickhouse::grpc::QueryInfo; +using GRPCResult = clickhouse::grpc::Result; +using GRPCException = clickhouse::grpc::Exception; +using GRPCProgress = clickhouse::grpc::Progress; +using GRPCObsoleteTransportCompression = clickhouse::grpc::ObsoleteTransportCompression; + +namespace DB +{ +namespace ErrorCodes +{ + extern const int INVALID_CONFIG_PARAMETER; + extern const int INVALID_GRPC_QUERY_INFO; + extern const int INVALID_SESSION_TIMEOUT; + extern const int LOGICAL_ERROR; + extern const int NETWORK_ERROR; + extern const int NO_DATA_TO_INSERT; + extern const int SUPPORT_IS_DISABLED; + extern const int BAD_REQUEST_PARAMETER; +} + 
+namespace +{ + /// Make grpc to pass logging messages to ClickHouse logging system. + void initGRPCLogging(const Poco::Util::AbstractConfiguration & config) + { + static std::once_flag once_flag; + std::call_once(once_flag, [&config] + { + static Poco::Logger * logger = &Poco::Logger::get("grpc"); + gpr_set_log_function([](gpr_log_func_args* args) + { + if (args->severity == GPR_LOG_SEVERITY_DEBUG) + LOG_DEBUG(logger, "{} ({}:{})", args->message, args->file, args->line); + else if (args->severity == GPR_LOG_SEVERITY_INFO) + LOG_INFO(logger, "{} ({}:{})", args->message, args->file, args->line); + else if (args->severity == GPR_LOG_SEVERITY_ERROR) + LOG_ERROR(logger, "{} ({}:{})", args->message, args->file, args->line); + }); + + if (config.getBool("grpc.verbose_logs", false)) + { + gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); + grpc_tracer_set_enabled("all", true); + } + else if (logger->is(Poco::Message::PRIO_DEBUG)) + { + gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); + } + else if (logger->is(Poco::Message::PRIO_INFORMATION)) + { + gpr_set_log_verbosity(GPR_LOG_SEVERITY_INFO); + } + }); + } + + /// Gets file's contents as a string, throws an exception if failed. + String readFile(const String & filepath) + { + Poco::FileInputStream ifs(filepath); + String res; + Poco::StreamCopier::copyToString(ifs, res); + return res; + } + + /// Makes credentials based on the server config. 
+ std::shared_ptr<grpc::ServerCredentials> makeCredentials(const Poco::Util::AbstractConfiguration & config) + { + if (config.getBool("grpc.enable_ssl", false)) + { +#if USE_SSL + grpc::SslServerCredentialsOptions options; + grpc::SslServerCredentialsOptions::PemKeyCertPair key_cert_pair; + key_cert_pair.private_key = readFile(config.getString("grpc.ssl_key_file")); + key_cert_pair.cert_chain = readFile(config.getString("grpc.ssl_cert_file")); + options.pem_key_cert_pairs.emplace_back(std::move(key_cert_pair)); + if (config.getBool("grpc.ssl_require_client_auth", false)) + { + options.client_certificate_request = GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY; + if (config.has("grpc.ssl_ca_cert_file")) + options.pem_root_certs = readFile(config.getString("grpc.ssl_ca_cert_file")); + } + return grpc::SslServerCredentials(options); +#else + throw DB::Exception(DB::ErrorCodes::SUPPORT_IS_DISABLED, "Can't use SSL in grpc, because ClickHouse was built without SSL library"); +#endif + } + return grpc::InsecureServerCredentials(); + } + + /// Transport compression makes gRPC library to compress packed Result messages before sending them through network. + struct TransportCompression + { + grpc_compression_algorithm algorithm; + grpc_compression_level level; + + /// Extracts the settings of transport compression from a query info if possible. 
+ static std::optional<TransportCompression> fromQueryInfo(const GRPCQueryInfo & query_info) + { + TransportCompression res; + if (!query_info.transport_compression_type().empty()) + { + res.setAlgorithm(query_info.transport_compression_type(), ErrorCodes::INVALID_GRPC_QUERY_INFO); + res.setLevel(query_info.transport_compression_level(), ErrorCodes::INVALID_GRPC_QUERY_INFO); + return res; + } + + if (query_info.has_obsolete_result_compression()) + { + switch (query_info.obsolete_result_compression().algorithm()) + { + case GRPCObsoleteTransportCompression::NO_COMPRESSION: res.algorithm = GRPC_COMPRESS_NONE; break; + case GRPCObsoleteTransportCompression::DEFLATE: res.algorithm = GRPC_COMPRESS_DEFLATE; break; + case GRPCObsoleteTransportCompression::GZIP: res.algorithm = GRPC_COMPRESS_GZIP; break; + case GRPCObsoleteTransportCompression::STREAM_GZIP: res.algorithm = GRPC_COMPRESS_STREAM_GZIP; break; + default: throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "Unknown compression algorithm: {}", GRPCObsoleteTransportCompression::CompressionAlgorithm_Name(query_info.obsolete_result_compression().algorithm())); + } + + switch (query_info.obsolete_result_compression().level()) + { + case GRPCObsoleteTransportCompression::COMPRESSION_NONE: res.level = GRPC_COMPRESS_LEVEL_NONE; break; + case GRPCObsoleteTransportCompression::COMPRESSION_LOW: res.level = GRPC_COMPRESS_LEVEL_LOW; break; + case GRPCObsoleteTransportCompression::COMPRESSION_MEDIUM: res.level = GRPC_COMPRESS_LEVEL_MED; break; + case GRPCObsoleteTransportCompression::COMPRESSION_HIGH: res.level = GRPC_COMPRESS_LEVEL_HIGH; break; + default: throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "Unknown compression level: {}", GRPCObsoleteTransportCompression::CompressionLevel_Name(query_info.obsolete_result_compression().level())); + } + return res; + } + + return std::nullopt; + } + + /// Extracts the settings of transport compression from the server configuration. 
+ static TransportCompression fromConfiguration(const Poco::Util::AbstractConfiguration & config) + { + TransportCompression res; + if (config.has("grpc.transport_compression_type")) + { + res.setAlgorithm(config.getString("grpc.transport_compression_type"), ErrorCodes::INVALID_CONFIG_PARAMETER); + res.setLevel(config.getInt("grpc.transport_compression_level", 0), ErrorCodes::INVALID_CONFIG_PARAMETER); + } + else + { + res.setAlgorithm(config.getString("grpc.compression", "none"), ErrorCodes::INVALID_CONFIG_PARAMETER); + res.setLevel(config.getString("grpc.compression_level", "none"), ErrorCodes::INVALID_CONFIG_PARAMETER); + } + return res; + } + + private: + void setAlgorithm(const String & str, int error_code) + { + if (str == "none") + algorithm = GRPC_COMPRESS_NONE; + else if (str == "deflate") + algorithm = GRPC_COMPRESS_DEFLATE; + else if (str == "gzip") + algorithm = GRPC_COMPRESS_GZIP; + else if (str == "stream_gzip") + algorithm = GRPC_COMPRESS_STREAM_GZIP; + else + throw Exception(error_code, "Unknown compression algorithm: '{}'", str); + } + + void setLevel(const String & str, int error_code) + { + if (str == "none") + level = GRPC_COMPRESS_LEVEL_NONE; + else if (str == "low") + level = GRPC_COMPRESS_LEVEL_LOW; + else if (str == "medium") + level = GRPC_COMPRESS_LEVEL_MED; + else if (str == "high") + level = GRPC_COMPRESS_LEVEL_HIGH; + else + throw Exception(error_code, "Unknown compression level: '{}'", str); + } + + void setLevel(int level_, int error_code) + { + if (0 <= level_ && level_ < GRPC_COMPRESS_LEVEL_COUNT) + level = static_cast<grpc_compression_level>(level_); + else + throw Exception(error_code, "Compression level {} is out of range 0..{}", level_, GRPC_COMPRESS_LEVEL_COUNT - 1); + } + }; + + /// Gets session's timeout from query info or from the server config. 
+ std::chrono::steady_clock::duration getSessionTimeout(const GRPCQueryInfo & query_info, const Poco::Util::AbstractConfiguration & config) + { + auto session_timeout = query_info.session_timeout(); + if (session_timeout) + { + auto max_session_timeout = config.getUInt("max_session_timeout", 3600); + if (session_timeout > max_session_timeout) + throw Exception(ErrorCodes::INVALID_SESSION_TIMEOUT, "Session timeout '{}' is larger than max_session_timeout: {}. " + "Maximum session timeout could be modified in configuration file.", + std::to_string(session_timeout), std::to_string(max_session_timeout)); + } + else + session_timeout = config.getInt("default_session_timeout", 60); + return std::chrono::seconds(session_timeout); + } + + /// Generates a description of a query by a specified query info. + /// This description is used for logging only. + String getQueryDescription(const GRPCQueryInfo & query_info) + { + String str; + if (!query_info.query().empty()) + { + std::string_view query = query_info.query(); + constexpr size_t max_query_length_to_log = 64; + if (query.length() > max_query_length_to_log) + query.remove_suffix(query.length() - max_query_length_to_log); + if (size_t format_pos = query.find(" FORMAT "); format_pos != String::npos) + query.remove_suffix(query.length() - format_pos - strlen(" FORMAT ")); + str.append("\"").append(query); + if (query != query_info.query()) + str.append("..."); + str.append("\""); + } + if (!query_info.query_id().empty()) + str.append(str.empty() ? "" : ", ").append("query_id: ").append(query_info.query_id()); + if (!query_info.input_data().empty()) + str.append(str.empty() ? "" : ", ").append("input_data: ").append(std::to_string(query_info.input_data().size())).append(" bytes"); + if (query_info.external_tables_size()) + str.append(str.empty() ? "" : ", ").append("external tables: ").append(std::to_string(query_info.external_tables_size())); + return str; + } + + /// Generates a description of a result. 
+ /// This description is used for logging only. + String getResultDescription(const GRPCResult & result) + { + String str; + if (!result.output().empty()) + str.append("output: ").append(std::to_string(result.output().size())).append(" bytes"); + if (!result.totals().empty()) + str.append(str.empty() ? "" : ", ").append("totals"); + if (!result.extremes().empty()) + str.append(str.empty() ? "" : ", ").append("extremes"); + if (result.has_progress()) + str.append(str.empty() ? "" : ", ").append("progress"); + if (result.logs_size()) + str.append(str.empty() ? "" : ", ").append("logs: ").append(std::to_string(result.logs_size())).append(" entries"); + if (result.cancelled()) + str.append(str.empty() ? "" : ", ").append("cancelled"); + if (result.has_exception()) + str.append(str.empty() ? "" : ", ").append("exception"); + return str; + } + + using CompletionCallback = std::function<void(bool)>; + + /// Requests a connection and provides low-level interface for reading and writing. + class BaseResponder + { + public: + virtual ~BaseResponder() = default; + + virtual void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) = 0; + + virtual void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) = 0; + virtual void write(const GRPCResult & result, const CompletionCallback & callback) = 0; + virtual void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) = 0; + + Poco::Net::SocketAddress getClientAddress() const + { + String peer = grpc_context.peer(); + return Poco::Net::SocketAddress{peer.substr(peer.find(':') + 1)}; + } + + std::optional<String> getClientHeader(const String & key) const + { + const auto & client_metadata = grpc_context.client_metadata(); + auto it = client_metadata.find(key); + if (it != client_metadata.end()) + return String{it->second.data(), 
it->second.size()}; + return std::nullopt; + } + + void setTransportCompression(const TransportCompression & transport_compression) + { + grpc_context.set_compression_algorithm(transport_compression.algorithm); + grpc_context.set_compression_level(transport_compression.level); + } + + protected: + CompletionCallback * getCallbackPtr(const CompletionCallback & callback) + { + /// It would be better to pass callbacks to gRPC calls. + /// However gRPC calls can be tagged with `void *` tags only. + /// The map `callbacks` here is used to keep callbacks until they're called. + std::lock_guard lock{mutex}; + size_t callback_id = next_callback_id++; + auto & callback_in_map = callbacks[callback_id]; + callback_in_map = [this, callback, callback_id](bool ok) + { + CompletionCallback callback_to_call; + { + std::lock_guard lock2{mutex}; + callback_to_call = callback; + callbacks.erase(callback_id); + } + callback_to_call(ok); + }; + return &callback_in_map; + } + + grpc::ServerContext grpc_context; + + private: + grpc::ServerAsyncReaderWriter<GRPCResult, GRPCQueryInfo> reader_writer{&grpc_context}; + std::unordered_map<size_t, CompletionCallback> callbacks; + size_t next_callback_id = 0; + std::mutex mutex; + }; + + enum CallType + { + CALL_SIMPLE, /// ExecuteQuery() call + CALL_WITH_STREAM_INPUT, /// ExecuteQueryWithStreamInput() call + CALL_WITH_STREAM_OUTPUT, /// ExecuteQueryWithStreamOutput() call + CALL_WITH_STREAM_IO, /// ExecuteQueryWithStreamIO() call + CALL_MAX, + }; + + const char * getCallName(CallType call_type) + { + switch (call_type) + { + case CALL_SIMPLE: return "ExecuteQuery()"; + case CALL_WITH_STREAM_INPUT: return "ExecuteQueryWithStreamInput()"; + case CALL_WITH_STREAM_OUTPUT: return "ExecuteQueryWithStreamOutput()"; + case CALL_WITH_STREAM_IO: return "ExecuteQueryWithStreamIO()"; + case CALL_MAX: break; + } + UNREACHABLE(); + } + + bool isInputStreaming(CallType call_type) + { + return (call_type == CALL_WITH_STREAM_INPUT) || (call_type == 
CALL_WITH_STREAM_IO); + } + + bool isOutputStreaming(CallType call_type) + { + return (call_type == CALL_WITH_STREAM_OUTPUT) || (call_type == CALL_WITH_STREAM_IO); + } + + template <enum CallType call_type> + class Responder; + + template<> + class Responder<CALL_SIMPLE> : public BaseResponder + { + public: + void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) override + { + grpc_service.RequestExecuteQuery(&grpc_context, &query_info.emplace(), &response_writer, &new_call_queue, ¬ification_queue, getCallbackPtr(callback)); + } + + void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override + { + if (!query_info.has_value()) + callback(false); + query_info_ = std::move(query_info).value(); + query_info.reset(); + callback(true); + } + + void write(const GRPCResult &, const CompletionCallback &) override + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Responder<CALL_SIMPLE>::write() should not be called"); + } + + void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override + { + response_writer.Finish(result, status, getCallbackPtr(callback)); + } + + private: + grpc::ServerAsyncResponseWriter<GRPCResult> response_writer{&grpc_context}; + std::optional<GRPCQueryInfo> query_info; + }; + + template<> + class Responder<CALL_WITH_STREAM_INPUT> : public BaseResponder + { + public: + void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) override + { + grpc_service.RequestExecuteQueryWithStreamInput(&grpc_context, &reader, &new_call_queue, ¬ification_queue, getCallbackPtr(callback)); + } + + void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override + { + reader.Read(&query_info_, getCallbackPtr(callback)); + } + + 
void write(const GRPCResult &, const CompletionCallback &) override + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Responder<CALL_WITH_STREAM_INPUT>::write() should not be called"); + } + + void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override + { + reader.Finish(result, status, getCallbackPtr(callback)); + } + + private: + grpc::ServerAsyncReader<GRPCResult, GRPCQueryInfo> reader{&grpc_context}; + }; + + template<> + class Responder<CALL_WITH_STREAM_OUTPUT> : public BaseResponder + { + public: + void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) override + { + grpc_service.RequestExecuteQueryWithStreamOutput(&grpc_context, &query_info.emplace(), &writer, &new_call_queue, ¬ification_queue, getCallbackPtr(callback)); + } + + void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override + { + if (!query_info.has_value()) + callback(false); + query_info_ = std::move(query_info).value(); + query_info.reset(); + callback(true); + } + + void write(const GRPCResult & result, const CompletionCallback & callback) override + { + writer.Write(result, getCallbackPtr(callback)); + } + + void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override + { + writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback)); + } + + private: + grpc::ServerAsyncWriter<GRPCResult> writer{&grpc_context}; + std::optional<GRPCQueryInfo> query_info; + }; + + template<> + class Responder<CALL_WITH_STREAM_IO> : public BaseResponder + { + public: + void start(GRPCService & grpc_service, + grpc::ServerCompletionQueue & new_call_queue, + grpc::ServerCompletionQueue & notification_queue, + const CompletionCallback & callback) override + { + grpc_service.RequestExecuteQueryWithStreamIO(&grpc_context, &reader_writer, 
&new_call_queue, ¬ification_queue, getCallbackPtr(callback)); + } + + void read(GRPCQueryInfo & query_info_, const CompletionCallback & callback) override + { + reader_writer.Read(&query_info_, getCallbackPtr(callback)); + } + + void write(const GRPCResult & result, const CompletionCallback & callback) override + { + reader_writer.Write(result, getCallbackPtr(callback)); + } + + void writeAndFinish(const GRPCResult & result, const grpc::Status & status, const CompletionCallback & callback) override + { + reader_writer.WriteAndFinish(result, {}, status, getCallbackPtr(callback)); + } + + private: + grpc::ServerAsyncReaderWriter<GRPCResult, GRPCQueryInfo> reader_writer{&grpc_context}; + }; + + std::unique_ptr<BaseResponder> makeResponder(CallType call_type) + { + switch (call_type) + { + case CALL_SIMPLE: return std::make_unique<Responder<CALL_SIMPLE>>(); + case CALL_WITH_STREAM_INPUT: return std::make_unique<Responder<CALL_WITH_STREAM_INPUT>>(); + case CALL_WITH_STREAM_OUTPUT: return std::make_unique<Responder<CALL_WITH_STREAM_OUTPUT>>(); + case CALL_WITH_STREAM_IO: return std::make_unique<Responder<CALL_WITH_STREAM_IO>>(); + case CALL_MAX: break; + } + UNREACHABLE(); + } + + + /// Implementation of ReadBuffer, which just calls a callback. + class ReadBufferFromCallback : public ReadBuffer + { + public: + explicit ReadBufferFromCallback(const std::function<std::pair<const void *, size_t>(void)> & callback_) + : ReadBuffer(nullptr, 0), callback(callback_) {} + + private: + bool nextImpl() override + { + const void * new_pos; + size_t new_size; + std::tie(new_pos, new_size) = callback(); + if (!new_size) + return false; + BufferBase::set(static_cast<BufferBase::Position>(const_cast<void *>(new_pos)), new_size, 0); + return true; + } + + std::function<std::pair<const void *, size_t>(void)> callback; + }; + + + /// A boolean state protected by mutex able to wait until other thread sets it to a specific value. 
+ class BoolState + { + public: + explicit BoolState(bool initial_value) : value(initial_value) {} + + bool get() const + { + std::lock_guard lock{mutex}; + return value; + } + + void set(bool new_value) + { + std::lock_guard lock{mutex}; + if (value == new_value) + return; + value = new_value; + changed.notify_all(); + } + + void wait(bool wanted_value) const + { + std::unique_lock lock{mutex}; + changed.wait(lock, [this, wanted_value]() { return value == wanted_value; }); + } + + private: + bool value; + mutable std::mutex mutex; + mutable std::condition_variable changed; + }; + + + /// Handles a connection after a responder is started (i.e. after getting a new call). + class Call + { + public: + Call(CallType call_type_, std::unique_ptr<BaseResponder> responder_, IServer & iserver_, Poco::Logger * log_); + ~Call(); + + void start(const std::function<void(void)> & on_finish_call_callback); + + private: + void run(); + + void receiveQuery(); + void executeQuery(); + + void processInput(); + void initializePipeline(const Block & header); + void createExternalTables(); + + void generateOutput(); + + void finishQuery(); + void onException(const Exception & exception); + void onFatalError(); + void releaseQueryIDAndSessionID(); + void close(); + + void readQueryInfo(); + void throwIfFailedToReadQueryInfo(); + bool isQueryCancelled(); + + void addQueryDetailsToResult(); + void addOutputFormatToResult(); + void addOutputColumnsNamesAndTypesToResult(const Block & headers); + void addProgressToResult(); + void addTotalsToResult(const Block & totals); + void addExtremesToResult(const Block & extremes); + void addProfileInfoToResult(const ProfileInfo & info); + void addLogsToResult(); + void sendResult(); + void throwIfFailedToSendResult(); + void sendException(const Exception & exception); + + const CallType call_type; + std::unique_ptr<BaseResponder> responder; + IServer & iserver; + Poco::Logger * log = nullptr; + + std::optional<Session> session; + ContextMutablePtr 
query_context; + std::optional<CurrentThread::QueryScope> query_scope; + OpenTelemetry::TracingContextHolderPtr thread_trace_context; + String query_text; + ASTPtr ast; + ASTInsertQuery * insert_query = nullptr; + String input_format; + String input_data_delimiter; + CompressionMethod input_compression_method = CompressionMethod::None; + PODArray<char> output; + String output_format; + bool send_output_columns_names_and_types = false; + CompressionMethod output_compression_method = CompressionMethod::None; + int output_compression_level = 0; + + uint64_t interactive_delay = 100000; + bool send_exception_with_stacktrace = true; + bool input_function_is_used = false; + + BlockIO io; + Progress progress; + InternalTextLogsQueuePtr logs_queue; + + GRPCQueryInfo query_info; /// We reuse the same messages multiple times. + GRPCResult result; + + bool initial_query_info_read = false; + bool finalize = false; + bool responder_finished = false; + bool cancelled = false; + + std::unique_ptr<ReadBuffer> read_buffer; + std::unique_ptr<WriteBuffer> write_buffer; + WriteBufferFromVector<PODArray<char>> * nested_write_buffer = nullptr; + WriteBuffer * compressing_write_buffer = nullptr; + std::unique_ptr<QueryPipeline> pipeline; + std::unique_ptr<PullingPipelineExecutor> pipeline_executor; + std::shared_ptr<IOutputFormat> output_format_processor; + bool need_input_data_from_insert_query = true; + bool need_input_data_from_query_info = true; + bool need_input_data_delimiter = false; + + Stopwatch query_time; + UInt64 waited_for_client_reading = 0; + UInt64 waited_for_client_writing = 0; + + /// The following fields are accessed both from call_thread and queue_thread. 
+ BoolState reading_query_info{false}; + std::atomic<bool> failed_to_read_query_info = false; + GRPCQueryInfo next_query_info_while_reading; + std::atomic<bool> want_to_cancel = false; + std::atomic<bool> check_query_info_contains_cancel_only = false; + BoolState sending_result{false}; + std::atomic<bool> failed_to_send_result = false; + + ThreadFromGlobalPool call_thread; + }; + + Call::Call(CallType call_type_, std::unique_ptr<BaseResponder> responder_, IServer & iserver_, Poco::Logger * log_) + : call_type(call_type_), responder(std::move(responder_)), iserver(iserver_), log(log_) + { + } + + Call::~Call() + { + if (call_thread.joinable()) + call_thread.join(); + } + + void Call::start(const std::function<void(void)> & on_finish_call_callback) + { + auto runner_function = [this, on_finish_call_callback] + { + try + { + run(); + } + catch (...) + { + tryLogCurrentException("GRPCServer"); + } + on_finish_call_callback(); + }; + call_thread = ThreadFromGlobalPool(runner_function); + } + + void Call::run() + { + try + { + setThreadName("GRPCServerCall"); + receiveQuery(); + executeQuery(); + processInput(); + generateOutput(); + finishQuery(); + } + catch (Exception & exception) + { + onException(exception); + } + catch (Poco::Exception & exception) + { + onException(Exception{Exception::CreateFromPocoTag{}, exception}); + } + catch (std::exception & exception) + { + onException(Exception{Exception::CreateFromSTDTag{}, exception}); + } + } + + void Call::receiveQuery() + { + LOG_INFO(log, "Handling call {}", getCallName(call_type)); + + readQueryInfo(); + + if (query_info.cancel()) + throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "Initial query info cannot set the 'cancel' field"); + + LOG_DEBUG(log, "Received initial QueryInfo: {}", getQueryDescription(query_info)); + } + + void Call::executeQuery() + { + /// Retrieve user credentials. 
+ std::string user = query_info.user_name(); + std::string password = query_info.password(); + std::string quota_key = query_info.quota(); + Poco::Net::SocketAddress user_address = responder->getClientAddress(); + + if (user.empty()) + { + user = "default"; + password = ""; + } + + /// Authentication. + session.emplace(iserver.context(), ClientInfo::Interface::GRPC); + session->authenticate(user, password, user_address); + session->setQuotaClientKey(quota_key); + + ClientInfo client_info = session->getClientInfo(); + + /// Parse the OpenTelemetry traceparent header. + auto traceparent = responder->getClientHeader("traceparent"); + if (traceparent) + { + String error; + if (!client_info.client_trace_context.parseTraceparentHeader(traceparent.value(), error)) + { + throw Exception(ErrorCodes::BAD_REQUEST_PARAMETER, + "Failed to parse OpenTelemetry traceparent header '{}': {}", + traceparent.value(), error); + } + auto tracestate = responder->getClientHeader("tracestate"); + client_info.client_trace_context.tracestate = tracestate.value_or(""); + } + + /// The user could specify session identifier and session timeout. + /// It allows to modify settings, create temporary tables and reuse them in subsequent requests. + if (!query_info.session_id().empty()) + { + session->makeSessionContext( + query_info.session_id(), getSessionTimeout(query_info, iserver.config()), query_info.session_check()); + } + + query_context = session->makeQueryContext(std::move(client_info)); + + /// Prepare settings. 
+ SettingsChanges settings_changes; + for (const auto & [key, value] : query_info.settings()) + { + settings_changes.push_back({key, value}); + } + query_context->checkSettingsConstraints(settings_changes, SettingSource::QUERY); + query_context->applySettingsChanges(settings_changes); + + query_context->setCurrentQueryId(query_info.query_id()); + query_scope.emplace(query_context, /* fatal_error_callback */ [this]{ onFatalError(); }); + + /// Set up tracing context for this query on current thread + thread_trace_context = std::make_unique<OpenTelemetry::TracingContextHolder>("GRPCServer", + query_context->getClientInfo().client_trace_context, + query_context->getSettingsRef(), + query_context->getOpenTelemetrySpanLog()); + thread_trace_context->root_span.kind = OpenTelemetry::SERVER; + + /// Prepare for sending exceptions and logs. + const Settings & settings = query_context->getSettingsRef(); + send_exception_with_stacktrace = settings.calculate_text_stack_trace; + const auto client_logs_level = settings.send_logs_level; + if (client_logs_level != LogsLevel::none) + { + logs_queue = std::make_shared<InternalTextLogsQueue>(); + logs_queue->max_priority = Poco::Logger::parseLevel(client_logs_level.toString()); + logs_queue->setSourceRegexp(settings.send_logs_source_regexp); + CurrentThread::attachInternalTextLogsQueue(logs_queue, client_logs_level); + } + + /// Set the current database if specified. + if (!query_info.database().empty()) + query_context->setCurrentDatabase(query_info.database()); + + /// Apply transport compression for this call. + if (auto transport_compression = TransportCompression::fromQueryInfo(query_info)) + responder->setTransportCompression(*transport_compression); + + /// The interactive delay will be used to show progress. + interactive_delay = settings.interactive_delay; + query_context->setProgressCallback([this](const Progress & value) { return progress.incrementPiecewiseAtomically(value); }); + + /// Parse the query. 
+ query_text = std::move(*(query_info.mutable_query())); + const char * begin = query_text.data(); + const char * end = begin + query_text.size(); + ParserQuery parser(end, settings.allow_settings_after_format_in_insert); + ast = parseQuery(parser, begin, end, "", settings.max_query_size, settings.max_parser_depth); + + /// Choose input format. + insert_query = ast->as<ASTInsertQuery>(); + if (insert_query) + { + input_format = insert_query->format; + if (input_format.empty()) + input_format = "Values"; + } + + input_data_delimiter = query_info.input_data_delimiter(); + + /// Choose output format. + query_context->setDefaultFormat(query_info.output_format()); + if (const auto * ast_query_with_output = dynamic_cast<const ASTQueryWithOutput *>(ast.get()); + ast_query_with_output && ast_query_with_output->format) + { + output_format = getIdentifierName(ast_query_with_output->format); + } + if (output_format.empty()) + output_format = query_context->getDefaultFormat(); + + send_output_columns_names_and_types = query_info.send_output_columns(); + + /// Choose compression. 
+ String input_compression_method_str = query_info.input_compression_type(); + if (input_compression_method_str.empty()) + input_compression_method_str = query_info.obsolete_compression_type(); + input_compression_method = chooseCompressionMethod("", input_compression_method_str); + + String output_compression_method_str = query_info.output_compression_type(); + if (output_compression_method_str.empty()) + output_compression_method_str = query_info.obsolete_compression_type(); + output_compression_method = chooseCompressionMethod("", output_compression_method_str); + output_compression_level = query_info.output_compression_level(); + + /// Set callback to create and fill external tables + query_context->setExternalTablesInitializer([this] (ContextPtr context) + { + if (context != query_context) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in external tables initializer"); + createExternalTables(); + }); + + /// Set callbacks to execute function input(). + query_context->setInputInitializer([this] (ContextPtr context, const StoragePtr & input_storage) + { + if (context != query_context) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in Input initializer"); + input_function_is_used = true; + initializePipeline(input_storage->getInMemoryMetadataPtr()->getSampleBlock()); + }); + + query_context->setInputBlocksReaderCallback([this](ContextPtr context) -> Block + { + if (context != query_context) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in InputBlocksReader"); + + Block block; + while (!block && pipeline_executor->pull(block)); + + return block; + }); + + /// Start executing the query. 
+ const auto * query_end = end; + if (insert_query && insert_query->data) + { + query_end = insert_query->data; + } + String query(begin, query_end); + io = ::DB::executeQuery(true, query, query_context); + } + + void Call::processInput() + { + if (!io.pipeline.pushing()) + return; + + bool has_data_to_insert = (insert_query && insert_query->data) + || !query_info.input_data().empty() || query_info.next_query_info(); + if (!has_data_to_insert) + { + if (!insert_query) + throw Exception(ErrorCodes::NO_DATA_TO_INSERT, "Query requires data to insert, but it is not an INSERT query"); + else + { + const auto & settings = query_context->getSettingsRef(); + if (settings.throw_if_no_data_to_insert) + throw Exception(ErrorCodes::NO_DATA_TO_INSERT, "No data to insert"); + else + return; + } + } + + /// This is significant, because parallel parsing may be used. + /// So we mustn't touch the input stream from other thread. + initializePipeline(io.pipeline.getHeader()); + + PushingPipelineExecutor executor(io.pipeline); + executor.start(); + + Block block; + while (pipeline_executor->pull(block)) + { + if (block) + executor.push(block); + } + + if (isQueryCancelled()) + executor.cancel(); + else + executor.finish(); + } + + void Call::initializePipeline(const Block & header) + { + assert(!read_buffer); + read_buffer = std::make_unique<ReadBufferFromCallback>([this]() -> std::pair<const void *, size_t> + { + if (need_input_data_from_insert_query) + { + need_input_data_from_insert_query = false; + if (insert_query && insert_query->data && (insert_query->data != insert_query->end)) + { + need_input_data_delimiter = !input_data_delimiter.empty(); + return {insert_query->data, insert_query->end - insert_query->data}; + } + } + + while (true) + { + if (need_input_data_from_query_info) + { + if (need_input_data_delimiter && !query_info.input_data().empty()) + { + need_input_data_delimiter = false; + return {input_data_delimiter.data(), input_data_delimiter.size()}; + } + 
need_input_data_from_query_info = false; + if (!query_info.input_data().empty()) + { + need_input_data_delimiter = !input_data_delimiter.empty(); + return {query_info.input_data().data(), query_info.input_data().size()}; + } + } + + if (!query_info.next_query_info()) + break; + + if (!isInputStreaming(call_type)) + throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "next_query_info is allowed to be set only for streaming input"); + + readQueryInfo(); + if (!query_info.query().empty() || !query_info.query_id().empty() || !query_info.settings().empty() + || !query_info.database().empty() || !query_info.input_data_delimiter().empty() || !query_info.output_format().empty() + || query_info.external_tables_size() || !query_info.user_name().empty() || !query_info.password().empty() + || !query_info.quota().empty() || !query_info.session_id().empty()) + { + throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, + "Extra query infos can be used only to add more input data. " + "Only the following fields can be set: input_data, next_query_info, cancel"); + } + + if (isQueryCancelled()) + break; + + LOG_DEBUG(log, "Received extra QueryInfo: input_data: {} bytes", query_info.input_data().size()); + need_input_data_from_query_info = true; + } + + return {nullptr, 0}; /// no more input data + }); + + read_buffer = wrapReadBufferWithCompressionMethod(std::move(read_buffer), input_compression_method); + + assert(!pipeline); + auto source = query_context->getInputFormat( + input_format, *read_buffer, header, query_context->getSettings().max_insert_block_size); + + pipeline = std::make_unique<QueryPipeline>(std::move(source)); + pipeline_executor = std::make_unique<PullingPipelineExecutor>(*pipeline); + } + + void Call::createExternalTables() + { + while (true) + { + for (const auto & external_table : query_info.external_tables()) + { + String name = external_table.name(); + if (name.empty()) + name = "_data"; + auto temporary_id = StorageID::createEmpty(); + temporary_id.table_name 
= name; + + /// If such a table does not exist, create it. + StoragePtr storage; + if (auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal)) + { + storage = DatabaseCatalog::instance().getTable(resolved, query_context); + } + else + { + NamesAndTypesList columns; + for (size_t column_idx : collections::range(external_table.columns_size())) + { + /// TODO: consider changing protocol + const auto & name_and_type = external_table.columns(static_cast<int>(column_idx)); + NameAndTypePair column; + column.name = name_and_type.name(); + if (column.name.empty()) + column.name = "_" + std::to_string(column_idx + 1); + column.type = DataTypeFactory::instance().get(name_and_type.type()); + columns.emplace_back(std::move(column)); + } + auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {}); + storage = temporary_table.getTable(); + query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table)); + } + + if (!external_table.data().empty()) + { + /// The data will be written directly to the table. 
+ auto metadata_snapshot = storage->getInMemoryMetadataPtr(); + auto sink = storage->write(ASTPtr(), metadata_snapshot, query_context, /*async_insert=*/false); + + std::unique_ptr<ReadBuffer> buf = std::make_unique<ReadBufferFromMemory>(external_table.data().data(), external_table.data().size()); + buf = wrapReadBufferWithCompressionMethod(std::move(buf), chooseCompressionMethod("", external_table.compression_type())); + + String format = external_table.format(); + if (format.empty()) + format = "TabSeparated"; + ContextMutablePtr external_table_context = query_context; + ContextMutablePtr temp_context; + if (!external_table.settings().empty()) + { + temp_context = Context::createCopy(query_context); + external_table_context = temp_context; + SettingsChanges settings_changes; + for (const auto & [key, value] : external_table.settings()) + settings_changes.push_back({key, value}); + external_table_context->checkSettingsConstraints(settings_changes, SettingSource::QUERY); + external_table_context->applySettingsChanges(settings_changes); + } + auto in = external_table_context->getInputFormat( + format, *buf, metadata_snapshot->getSampleBlock(), + external_table_context->getSettings().max_insert_block_size); + + QueryPipelineBuilder cur_pipeline; + cur_pipeline.init(Pipe(std::move(in))); + cur_pipeline.addTransform(std::move(sink)); + cur_pipeline.setSinks([&](const Block & header, Pipe::StreamType) + { + return std::make_shared<EmptySink>(header); + }); + + auto executor = cur_pipeline.execute(); + executor->execute(1, false); + } + } + + if (!query_info.input_data().empty()) + { + /// External tables must be created before executing query, + /// so all external tables must be send no later sending any input data. 
+ break; + } + + if (!query_info.next_query_info()) + break; + + if (!isInputStreaming(call_type)) + throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, "next_query_info is allowed to be set only for streaming input"); + + readQueryInfo(); + if (!query_info.query().empty() || !query_info.query_id().empty() || !query_info.settings().empty() + || !query_info.database().empty() || !query_info.input_data_delimiter().empty() + || !query_info.output_format().empty() || !query_info.user_name().empty() || !query_info.password().empty() + || !query_info.quota().empty() || !query_info.session_id().empty()) + { + throw Exception(ErrorCodes::INVALID_GRPC_QUERY_INFO, + "Extra query infos can be used only " + "to add more data to input or more external tables. " + "Only the following fields can be set: " + "input_data, external_tables, next_query_info, cancel"); + } + if (isQueryCancelled()) + break; + LOG_DEBUG(log, "Received extra QueryInfo: external tables: {}", query_info.external_tables_size()); + } + } + + void Call::generateOutput() + { + /// We add query_id and time_zone to the first result anyway. + addQueryDetailsToResult(); + + if (!io.pipeline.initialized() || io.pipeline.pushing()) + return; + + Block header; + if (io.pipeline.pulling()) + header = io.pipeline.getHeader(); + + if (output_compression_method != CompressionMethod::None) + output.resize(DBMS_DEFAULT_BUFFER_SIZE); /// Must have enough space for compressed data. 
+ write_buffer = std::make_unique<WriteBufferFromVector<PODArray<char>>>(output); + nested_write_buffer = static_cast<WriteBufferFromVector<PODArray<char>> *>(write_buffer.get()); + if (output_compression_method != CompressionMethod::None) + { + write_buffer = wrapWriteBufferWithCompressionMethod(std::move(write_buffer), output_compression_method, output_compression_level); + compressing_write_buffer = write_buffer.get(); + } + + auto has_output = [&] { return (nested_write_buffer->position() != output.data()) || (compressing_write_buffer && compressing_write_buffer->offset()); }; + + output_format_processor = query_context->getOutputFormat(output_format, *write_buffer, header); + Stopwatch after_send_progress; + + /// Unless the input() function is used we are not going to receive input data anymore. + if (!input_function_is_used) + check_query_info_contains_cancel_only = true; + + if (io.pipeline.pulling()) + { + auto executor = std::make_shared<PullingAsyncPipelineExecutor>(io.pipeline); + auto check_for_cancel = [&] + { + if (isQueryCancelled()) + { + executor->cancel(); + return false; + } + return true; + }; + + addOutputFormatToResult(); + addOutputColumnsNamesAndTypesToResult(header); + + Block block; + while (check_for_cancel()) + { + if (!executor->pull(block, interactive_delay / 1000)) + break; + + throwIfFailedToSendResult(); + if (!check_for_cancel()) + break; + + if (block && !io.null_format) + output_format_processor->write(materializeBlock(block)); + + if (after_send_progress.elapsedMicroseconds() >= interactive_delay) + { + addProgressToResult(); + after_send_progress.restart(); + } + + addLogsToResult(); + + if (has_output() || result.has_progress() || result.logs_size()) + sendResult(); + + throwIfFailedToSendResult(); + if (!check_for_cancel()) + break; + } + + if (!isQueryCancelled()) + { + addTotalsToResult(executor->getTotalsBlock()); + addExtremesToResult(executor->getExtremesBlock()); + addProfileInfoToResult(executor->getProfileInfo()); + 
} + } + else + { + auto executor = std::make_shared<CompletedPipelineExecutor>(io.pipeline); + auto callback = [&]() -> bool + { + throwIfFailedToSendResult(); + addProgressToResult(); + addLogsToResult(); + + if (has_output() || result.has_progress() || result.logs_size()) + sendResult(); + + throwIfFailedToSendResult(); + + return isQueryCancelled(); + }; + executor->setCancelCallback(std::move(callback), interactive_delay / 1000); + executor->execute(); + } + + output_format_processor->finalize(); + } + + void Call::finishQuery() + { + finalize = true; + io.onFinish(); + addProgressToResult(); + query_scope->logPeakMemoryUsage(); + addLogsToResult(); + releaseQueryIDAndSessionID(); + sendResult(); + close(); + + LOG_INFO( + log, + "Finished call {} in {} secs. (including reading by client: {}, writing by client: {})", + getCallName(call_type), + query_time.elapsedSeconds(), + static_cast<double>(waited_for_client_reading) / 1000000000ULL, + static_cast<double>(waited_for_client_writing) / 1000000000ULL); + } + + void Call::onException(const Exception & exception) + { + io.onException(); + + LOG_ERROR(log, getExceptionMessageAndPattern(exception, send_exception_with_stacktrace)); + + if (responder && !responder_finished) + { + try + { + /// Try to send logs to client, but it could be risky too. + addLogsToResult(); + } + catch (...) + { + LOG_WARNING(log, "Couldn't send logs to client"); + } + + releaseQueryIDAndSessionID(); + + try + { + sendException(exception); + } + catch (...) + { + LOG_WARNING(log, "Couldn't send exception information to the client"); + } + } + + close(); + } + + void Call::onFatalError() + { + if (responder && !responder_finished) + { + try + { + result.mutable_exception()->set_name("FatalError"); + addLogsToResult(); + sendResult(); + } + catch (...) 
+ { + } + } + } + + void Call::releaseQueryIDAndSessionID() + { + /// releaseQueryIDAndSessionID() should be called before sending the final result to the client + /// because the client may decide to send another query with the same query ID or session ID + /// immediately after it receives our final result, and it's prohibited to have + /// two queries executed at the same time with the same query ID or session ID. + io.process_list_entry.reset(); + if (query_context) + query_context->setProcessListElement(nullptr); + if (session) + session->releaseSessionID(); + } + + void Call::close() + { + responder.reset(); + pipeline_executor.reset(); + pipeline.reset(); + output_format_processor.reset(); + read_buffer.reset(); + write_buffer.reset(); + nested_write_buffer = nullptr; + compressing_write_buffer = nullptr; + io = {}; + query_scope.reset(); + query_context.reset(); + thread_trace_context.reset(); + session.reset(); + } + + void Call::readQueryInfo() + { + auto start_reading = [&] + { + reading_query_info.set(true); + responder->read(next_query_info_while_reading, [this](bool ok) + { + /// Called on queue_thread. + if (ok) + { + const auto & nqi = next_query_info_while_reading; + if (check_query_info_contains_cancel_only) + { + if (!nqi.query().empty() || !nqi.query_id().empty() || !nqi.settings().empty() || !nqi.database().empty() + || !nqi.input_data().empty() || !nqi.input_data_delimiter().empty() || !nqi.output_format().empty() + || !nqi.user_name().empty() || !nqi.password().empty() || !nqi.quota().empty() || !nqi.session_id().empty()) + { + LOG_WARNING(log, "Cannot add extra information to a query which is already executing. Only the 'cancel' field can be set"); + } + } + if (nqi.cancel()) + want_to_cancel = true; + } + else + { + /// We cannot throw an exception right here because this code is executed + /// on queue_thread. 
+ failed_to_read_query_info = true; + } + reading_query_info.set(false); + }); + }; + + auto finish_reading = [&] + { + if (reading_query_info.get()) + { + Stopwatch client_writing_watch; + reading_query_info.wait(false); + waited_for_client_writing += client_writing_watch.elapsedNanoseconds(); + } + throwIfFailedToReadQueryInfo(); + query_info = std::move(next_query_info_while_reading); + initial_query_info_read = true; + }; + + if (!initial_query_info_read) + { + /// Initial query info hasn't been read yet, so we're going to read it now. + start_reading(); + } + + /// Maybe it's reading a query info right now. Let it finish. + finish_reading(); + + if (isInputStreaming(call_type)) + { + /// Next query info can contain more input data. Now we start reading a next query info, + /// so another call of readQueryInfo() in the future will probably be able to take it. + start_reading(); + } + } + + void Call::throwIfFailedToReadQueryInfo() + { + if (failed_to_read_query_info) + { + if (initial_query_info_read) + throw Exception(ErrorCodes::NETWORK_ERROR, "Failed to read extra QueryInfo"); + else + throw Exception(ErrorCodes::NETWORK_ERROR, "Failed to read initial QueryInfo"); + } + } + + bool Call::isQueryCancelled() + { + if (cancelled) + { + result.set_cancelled(true); + return true; + } + + if (want_to_cancel) + { + LOG_INFO(log, "Query cancelled"); + cancelled = true; + result.set_cancelled(true); + return true; + } + + return false; + } + + void Call::addQueryDetailsToResult() + { + *result.mutable_query_id() = query_context->getClientInfo().current_query_id; + *result.mutable_time_zone() = DateLUT::instance().getTimeZone(); + } + + void Call::addOutputFormatToResult() + { + *result.mutable_output_format() = output_format; + } + + void Call::addOutputColumnsNamesAndTypesToResult(const Block & header) + { + if (!send_output_columns_names_and_types) + return; + for (const auto & column : header) + { + auto & name_and_type = *result.add_output_columns(); + 
*name_and_type.mutable_name() = column.name; + *name_and_type.mutable_type() = column.type->getName(); + } + } + + void Call::addProgressToResult() + { + auto values = progress.fetchValuesAndResetPiecewiseAtomically(); + if (!values.read_rows && !values.read_bytes && !values.total_rows_to_read && !values.written_rows && !values.written_bytes) + return; + auto & grpc_progress = *result.mutable_progress(); + /// Sum is used because we need to accumulate values for the case if streaming output is disabled. + grpc_progress.set_read_rows(grpc_progress.read_rows() + values.read_rows); + grpc_progress.set_read_bytes(grpc_progress.read_bytes() + values.read_bytes); + grpc_progress.set_total_rows_to_read(grpc_progress.total_rows_to_read() + values.total_rows_to_read); + grpc_progress.set_written_rows(grpc_progress.written_rows() + values.written_rows); + grpc_progress.set_written_bytes(grpc_progress.written_bytes() + values.written_bytes); + } + + void Call::addTotalsToResult(const Block & totals) + { + if (!totals) + return; + + PODArray<char> memory; + if (output_compression_method != CompressionMethod::None) + memory.resize(DBMS_DEFAULT_BUFFER_SIZE); /// Must have enough space for compressed data. + std::unique_ptr<WriteBuffer> buf = std::make_unique<WriteBufferFromVector<PODArray<char>>>(memory); + buf = wrapWriteBufferWithCompressionMethod(std::move(buf), output_compression_method, output_compression_level); + auto format = query_context->getOutputFormat(output_format, *buf, totals); + format->write(materializeBlock(totals)); + format->finalize(); + buf->finalize(); + + result.mutable_totals()->assign(memory.data(), memory.size()); + } + + void Call::addExtremesToResult(const Block & extremes) + { + if (!extremes) + return; + + PODArray<char> memory; + if (output_compression_method != CompressionMethod::None) + memory.resize(DBMS_DEFAULT_BUFFER_SIZE); /// Must have enough space for compressed data. 
+ std::unique_ptr<WriteBuffer> buf = std::make_unique<WriteBufferFromVector<PODArray<char>>>(memory); + buf = wrapWriteBufferWithCompressionMethod(std::move(buf), output_compression_method, output_compression_level); + auto format = query_context->getOutputFormat(output_format, *buf, extremes); + format->write(materializeBlock(extremes)); + format->finalize(); + buf->finalize(); + + result.mutable_extremes()->assign(memory.data(), memory.size()); + } + + void Call::addProfileInfoToResult(const ProfileInfo & info) + { + auto & stats = *result.mutable_stats(); + stats.set_rows(info.rows); + stats.set_blocks(info.blocks); + stats.set_allocated_bytes(info.bytes); + stats.set_applied_limit(info.hasAppliedLimit()); + stats.set_rows_before_limit(info.getRowsBeforeLimit()); + } + + void Call::addLogsToResult() + { + if (!logs_queue) + return; + + static_assert(::clickhouse::grpc::LOG_NONE == 0); + static_assert(::clickhouse::grpc::LOG_FATAL == static_cast<int>(Poco::Message::PRIO_FATAL)); + static_assert(::clickhouse::grpc::LOG_CRITICAL == static_cast<int>(Poco::Message::PRIO_CRITICAL)); + static_assert(::clickhouse::grpc::LOG_ERROR == static_cast<int>(Poco::Message::PRIO_ERROR)); + static_assert(::clickhouse::grpc::LOG_WARNING == static_cast<int>(Poco::Message::PRIO_WARNING)); + static_assert(::clickhouse::grpc::LOG_NOTICE == static_cast<int>(Poco::Message::PRIO_NOTICE)); + static_assert(::clickhouse::grpc::LOG_INFORMATION == static_cast<int>(Poco::Message::PRIO_INFORMATION)); + static_assert(::clickhouse::grpc::LOG_DEBUG == static_cast<int>(Poco::Message::PRIO_DEBUG)); + static_assert(::clickhouse::grpc::LOG_TRACE == static_cast<int>(Poco::Message::PRIO_TRACE)); + + MutableColumns columns; + while (logs_queue->tryPop(columns)) + { + if (columns.empty() || columns[0]->empty()) + continue; + + const auto & column_time = typeid_cast<const ColumnUInt32 &>(*columns[0]); + const auto & column_time_microseconds = typeid_cast<const ColumnUInt32 &>(*columns[1]); + const auto & 
column_query_id = typeid_cast<const ColumnString &>(*columns[3]); + const auto & column_thread_id = typeid_cast<const ColumnUInt64 &>(*columns[4]); + const auto & column_level = typeid_cast<const ColumnInt8 &>(*columns[5]); + const auto & column_source = typeid_cast<const ColumnString &>(*columns[6]); + const auto & column_text = typeid_cast<const ColumnString &>(*columns[7]); + size_t num_rows = column_time.size(); + + for (size_t row = 0; row != num_rows; ++row) + { + auto & log_entry = *result.add_logs(); + log_entry.set_time(column_time.getElement(row)); + log_entry.set_time_microseconds(column_time_microseconds.getElement(row)); + std::string_view query_id = column_query_id.getDataAt(row).toView(); + log_entry.set_query_id(query_id.data(), query_id.size()); + log_entry.set_thread_id(column_thread_id.getElement(row)); + log_entry.set_level(static_cast<::clickhouse::grpc::LogsLevel>(column_level.getElement(row))); + std::string_view source = column_source.getDataAt(row).toView(); + log_entry.set_source(source.data(), source.size()); + std::string_view text = column_text.getDataAt(row).toView(); + log_entry.set_text(text.data(), text.size()); + } + } + } + + void Call::sendResult() + { + /// gRPC doesn't allow to write anything to a finished responder. + if (responder_finished) + return; + + /// If output is not streaming then only the final result can be sent. + bool send_final_message = finalize || result.has_exception() || result.cancelled(); + if (!send_final_message && !isOutputStreaming(call_type)) + return; + + /// Copy output to `result.output`, with optional compressing. 
+ if (write_buffer) + { + size_t output_size; + if (send_final_message) + { + if (compressing_write_buffer) + LOG_DEBUG(log, "Compressing final {} bytes", compressing_write_buffer->offset()); + write_buffer->finalize(); + output_size = output.size(); + } + else + { + if (compressing_write_buffer && compressing_write_buffer->offset()) + { + LOG_DEBUG(log, "Compressing {} bytes", compressing_write_buffer->offset()); + compressing_write_buffer->sync(); + } + output_size = nested_write_buffer->position() - output.data(); + } + + if (output_size) + { + result.mutable_output()->assign(output.data(), output_size); + nested_write_buffer->restart(); /// We're going to reuse the same buffer again for next block of data. + } + } + + if (!send_final_message && result.output().empty() && result.totals().empty() && result.extremes().empty() && !result.logs_size() + && !result.has_progress() && !result.has_stats() && !result.has_exception() && !result.cancelled()) + return; /// Nothing to send. + + /// Wait for previous write to finish. + /// (gRPC doesn't allow to start sending another result while the previous is still being sending.) + if (sending_result.get()) + { + Stopwatch client_reading_watch; + sending_result.wait(false); + waited_for_client_reading += client_reading_watch.elapsedNanoseconds(); + } + throwIfFailedToSendResult(); + + /// Start sending the result. + LOG_DEBUG(log, "Sending {} result to the client: {}", (send_final_message ? "final" : "intermediate"), getResultDescription(result)); + + sending_result.set(true); + auto callback = [this](bool ok) + { + /// Called on queue_thread. + if (!ok) + failed_to_send_result = true; + sending_result.set(false); + }; + + Stopwatch client_reading_final_watch; + if (send_final_message) + { + responder_finished = true; + responder->writeAndFinish(result, {}, callback); + } + else + responder->write(result, callback); + + /// gRPC has already retrieved all data from `result`, so we don't have to keep it. 
+ result.Clear(); + + if (send_final_message) + { + /// Wait until the result is actually sent. + sending_result.wait(false); + waited_for_client_reading += client_reading_final_watch.elapsedNanoseconds(); + throwIfFailedToSendResult(); + LOG_TRACE(log, "Final result has been sent to the client"); + } + } + + void Call::throwIfFailedToSendResult() + { + if (failed_to_send_result) + throw Exception(ErrorCodes::NETWORK_ERROR, "Failed to send result to the client"); + } + + void Call::sendException(const Exception & exception) + { + auto & grpc_exception = *result.mutable_exception(); + grpc_exception.set_code(exception.code()); + grpc_exception.set_name(exception.name()); + grpc_exception.set_display_text(exception.displayText()); + if (send_exception_with_stacktrace) + grpc_exception.set_stack_trace(exception.getStackTraceString()); + sendResult(); + } +} + + +class GRPCServer::Runner +{ +public: + explicit Runner(GRPCServer & owner_) : owner(owner_) {} + + ~Runner() + { + if (queue_thread.joinable()) + queue_thread.join(); + } + + void start() + { + startReceivingNewCalls(); + + /// We run queue in a separate thread. + auto runner_function = [this] + { + try + { + run(); + } + catch (...) + { + tryLogCurrentException("GRPCServer"); + } + }; + queue_thread = ThreadFromGlobalPool{runner_function}; + } + + void stop() { stopReceivingNewCalls(); } + + size_t getNumCurrentCalls() const + { + std::lock_guard lock{mutex}; + return current_calls.size(); + } + +private: + void startReceivingNewCalls() + { + std::lock_guard lock{mutex}; + responders_for_new_calls.resize(CALL_MAX); + for (CallType call_type : collections::range(CALL_MAX)) + makeResponderForNewCall(call_type); + } + + void makeResponderForNewCall(CallType call_type) + { + /// `mutex` is already locked. 
+ responders_for_new_calls[call_type] = makeResponder(call_type); + + responders_for_new_calls[call_type]->start( + owner.grpc_service, *owner.queue, *owner.queue, + [this, call_type](bool ok) { onNewCall(call_type, ok); }); + } + + void stopReceivingNewCalls() + { + std::lock_guard lock{mutex}; + should_stop = true; + } + + void onNewCall(CallType call_type, bool responder_started_ok) + { + std::lock_guard lock{mutex}; + auto responder = std::move(responders_for_new_calls[call_type]); + if (should_stop) + return; + makeResponderForNewCall(call_type); + if (responder_started_ok) + { + /// Connection established and the responder has been started. + /// So we pass this responder to a Call and make another responder for next connection. + auto new_call = std::make_unique<Call>(call_type, std::move(responder), owner.iserver, owner.log); + auto * new_call_ptr = new_call.get(); + current_calls[new_call_ptr] = std::move(new_call); + new_call_ptr->start([this, new_call_ptr]() { onFinishCall(new_call_ptr); }); + } + } + + void onFinishCall(Call * call) + { + /// Called on call_thread. That's why we can't destroy the `call` right now + /// (thread can't join to itself). Thus here we only move the `call` from + /// `current_calls` to `finished_calls` and run() will actually destroy the `call`. + std::lock_guard lock{mutex}; + auto it = current_calls.find(call); + finished_calls.push_back(std::move(it->second)); + current_calls.erase(it); + } + + void run() + { + setThreadName("GRPCServerQueue"); + while (true) + { + { + std::lock_guard lock{mutex}; + finished_calls.clear(); /// Destroy finished calls. + + /// If (should_stop == true) we continue processing until there is no active calls. 
+ if (should_stop && current_calls.empty()) + { + /// A slot becomes empty when onNewCall() takes the responder out of it and, because `should_stop` is set, does not create a new one. + /// The predicate only inspects the elements, so it takes them by const reference. + bool all_responders_gone = std::all_of( + responders_for_new_calls.begin(), responders_for_new_calls.end(), + [](const std::unique_ptr<BaseResponder> & responder) { return !responder; }); + if (all_responders_gone) + break; + } + } + + bool ok = false; + void * tag = nullptr; + if (!owner.queue->Next(&tag, &ok)) + { + /// The queue has been shut down. + break; + } + + /// Every tag put into the queue is a pointer to a CompletionCallback. + auto & callback = *static_cast<CompletionCallback *>(tag); + callback(ok); + } + } + + GRPCServer & owner; + ThreadFromGlobalPool queue_thread; + std::vector<std::unique_ptr<BaseResponder>> responders_for_new_calls; + std::map<Call *, std::unique_ptr<Call>> current_calls; + std::vector<std::unique_ptr<Call>> finished_calls; + bool should_stop = false; + mutable std::mutex mutex; +}; + + +GRPCServer::GRPCServer(IServer & iserver_, const Poco::Net::SocketAddress & address_to_listen_) + : iserver(iserver_) + , address_to_listen(address_to_listen_) + , log(&Poco::Logger::get("GRPCServer")) + , runner(std::make_unique<Runner>(*this)) +{} + +GRPCServer::~GRPCServer() +{ + /// Server should be shutdown before CompletionQueue. + if (grpc_server) + grpc_server->Shutdown(); + + /// Completion Queue should be shutdown before destroying the runner, + /// because the runner is now probably executing CompletionQueue::Next() on queue_thread + /// which is blocked until an event is available or the queue is shutting down. 
+ if (queue) + queue->Shutdown(); + + runner.reset(); +} + +void GRPCServer::start() +{ + initGRPCLogging(iserver.config()); + grpc::ServerBuilder builder; + builder.AddListeningPort(address_to_listen.toString(), makeCredentials(iserver.config())); + builder.RegisterService(&grpc_service); + builder.SetMaxSendMessageSize(iserver.config().getInt("grpc.max_send_message_size", -1)); + builder.SetMaxReceiveMessageSize(iserver.config().getInt("grpc.max_receive_message_size", -1)); + auto default_transport_compression = TransportCompression::fromConfiguration(iserver.config()); + builder.SetDefaultCompressionAlgorithm(default_transport_compression.algorithm); + builder.SetDefaultCompressionLevel(default_transport_compression.level); + + queue = builder.AddCompletionQueue(); + grpc_server = builder.BuildAndStart(); + if (nullptr == grpc_server) + { + throw DB::Exception(DB::ErrorCodes::NETWORK_ERROR, "Can't start grpc server, there is a port conflict"); + } + + runner->start(); +} + + +void GRPCServer::stop() +{ + /// Stop receiving new calls. + /// This only asks the runner to stop; the gRPC server and the completion queue are shut down in the destructor. + runner->stop(); +} + +/// The number of connections is reported as the number of calls currently tracked by the runner. +size_t GRPCServer::currentConnections() const +{ + return runner->getNumCurrentCalls(); +} + +} +#endif diff --git a/contrib/clickhouse/src/Server/GRPCServer.h b/contrib/clickhouse/src/Server/GRPCServer.h new file mode 100644 index 0000000000..b77f22f046 --- /dev/null +++ b/contrib/clickhouse/src/Server/GRPCServer.h @@ -0,0 +1,56 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_GRPC +#include <Poco/Net/SocketAddress.h> +#include <base/types.h> +#error #include "clickhouse_grpc.grpc.pb.h" + +namespace Poco { class Logger; } + +namespace grpc +{ +class Server; +class ServerCompletionQueue; +} + +namespace DB +{ +class IServer; + +/// Serves the clickhouse::grpc::ClickHouse async service on the given socket address. +class GRPCServer +{ +public: + GRPCServer(IServer & iserver_, const Poco::Net::SocketAddress & address_to_listen_); + ~GRPCServer(); + + /// Starts the server. A new thread will be created that waits for and accepts incoming connections. 
+ void start(); + + /// Stops the server. No new connections will be accepted. + void stop(); + + /// Returns the port this server is listening on. + UInt16 portNumber() const { return address_to_listen.port(); } + + /// Returns the number of currently handled connections. + size_t currentConnections() const; + + /// Returns the number of current threads, reported as equal to the number of current connections. + size_t currentThreads() const { return currentConnections(); } + +private: + using GRPCService = clickhouse::grpc::ClickHouse::AsyncService; + class Runner; + + IServer & iserver; + const Poco::Net::SocketAddress address_to_listen; + Poco::Logger * log; + GRPCService grpc_service; + std::unique_ptr<grpc::Server> grpc_server; + std::unique_ptr<grpc::ServerCompletionQueue> queue; + std::unique_ptr<Runner> runner; +}; +} +#endif diff --git a/contrib/clickhouse/src/Server/HTTP/HTMLForm.cpp b/contrib/clickhouse/src/Server/HTTP/HTMLForm.cpp new file mode 100644 index 0000000000..1abf9e5b83 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTMLForm.cpp @@ -0,0 +1,347 @@ +#include <Server/HTTP/HTMLForm.h> + +#include <Core/Settings.h> +#include <IO/EmptyReadBuffer.h> +#include <IO/ReadBufferFromString.h> +#include <Server/HTTP/ReadHeaders.h> + +#include <Poco/CountingStream.h> +#include <Poco/Net/MultipartReader.h> +#include <Poco/Net/MultipartWriter.h> +#include <Poco/Net/NetException.h> +#include <Poco/Net/NullPartHandler.h> +#include <Poco/NullStream.h> +#include <Poco/StreamCopier.h> +#include <Poco/UTF8String.h> + +#include <sstream> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_READ_ALL_DATA; +} + +namespace +{ + +/// Part handler whose handlePart() does nothing; used by load() when the caller supplies no handler. +class NullPartHandler : public HTMLForm::PartHandler +{ +public: + void handlePart(const Poco::Net::MessageHeader &, ReadBuffer &) override {} +}; + +} + +const std::string HTMLForm::ENCODING_URL = "application/x-www-form-urlencoded"; +const std::string HTMLForm::ENCODING_MULTIPART = "multipart/form-data"; +const int HTMLForm::UNKNOWN_CONTENT_LENGTH = -1; + + 
+HTMLForm::HTMLForm(const Settings & settings) + : max_fields_number(settings.http_max_fields) + , max_field_name_size(settings.http_max_field_name_size) + , max_field_value_size(settings.http_max_field_value_size) + , encoding(ENCODING_URL) +{ +} + + +HTMLForm::HTMLForm(const Settings & settings, const std::string & encoding_) : HTMLForm(settings) +{ + encoding = encoding_; +} + + +HTMLForm::HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody, PartHandler & handler) + : HTMLForm(settings) +{ + load(request, requestBody, handler); +} + + +HTMLForm::HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody) : HTMLForm(settings) +{ + load(request, requestBody); +} + + +HTMLForm::HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request) : HTMLForm(settings, Poco::URI(request.getURI())) +{ +} + +HTMLForm::HTMLForm(const Settings & settings, const Poco::URI & uri) : HTMLForm(settings) +{ + ReadBufferFromString istr(uri.getRawQuery()); // STYLE_CHECK_ALLOW_STD_STRING_STREAM + readQuery(istr); +} + + +/// Clears the form, then reads fields from the query string of the request URI (if it is not empty) and, +/// for POST and PUT requests, from the body: as multipart data if the media type of Content-Type is +/// "multipart/form-data", otherwise as a URL-encoded query. +void HTMLForm::load(const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody, PartHandler & handler) +{ + clear(); + + Poco::URI uri(request.getURI()); + const std::string & query = uri.getRawQuery(); + if (!query.empty()) + { + ReadBufferFromString istr(query); + readQuery(istr); + } + + if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST || request.getMethod() == Poco::Net::HTTPRequest::HTTP_PUT) + { + std::string media_type; + NameValueCollection params; + Poco::Net::MessageHeader::splitParameters(request.getContentType(), media_type, params); + encoding = media_type; + if (encoding == ENCODING_MULTIPART) + { + boundary = params["boundary"]; + readMultipart(requestBody, handler); + } + else + { + readQuery(requestBody); + } + } +} + + +/// Same as the overload above, but parts that would go to a handler are ignored. +void HTMLForm::load(const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody) +{ + NullPartHandler nah; + 
load(request, requestBody, nah); +} + + +void HTMLForm::read(ReadBuffer & in) +{ + readQuery(in); +} + + +void HTMLForm::readQuery(ReadBuffer & in) +{ + size_t fields = 0; + char ch = 0; // silence "uninitialized" warning from gcc-* + bool is_first = true; + + while (true) + { + /// At most max_fields_number fields are accepted; zero disables this limit. + if (max_fields_number > 0 && fields == max_fields_number) + throw Poco::Net::HTMLFormException("Too many form fields"); + + std::string name; + std::string value; + + /// Read the name up to '=' or '&'; '+' stands for a space. + /// The size limits apply to the raw text, which is percent-decoded only afterwards. + while (in.read(ch) && ch != '=' && ch != '&') + { + if (ch == '+') + ch = ' '; + if (name.size() < max_field_name_size) + name += ch; + else + throw Poco::Net::HTMLFormException("Field name too long"); + } + + /// A value is present only if the name was terminated by '='. + if (ch == '=') + { + while (in.read(ch) && ch != '&') + { + if (ch == '+') + ch = ' '; + if (value.size() < max_field_value_size) + value += ch; + else + throw Poco::Net::HTMLFormException("Field value too long"); + } + } + + // Remove UTF-8 BOM from first name, if present + if (is_first) + Poco::UTF8::removeBOM(name); + + std::string decoded_name; + std::string decoded_value; + Poco::URI::decode(name, decoded_name); + Poco::URI::decode(value, decoded_value); + add(decoded_name, decoded_value); + ++fields; + + is_first = false; + + if (in.eof()) + break; + } +} + + +void HTMLForm::readMultipart(ReadBuffer & in_, PartHandler & handler) +{ + /// Assume there is always a boundary provided. 
+ assert(!boundary.empty()); + + size_t fields = 0; + MultipartReadBuffer in(in_, boundary); + + if (!in.skipToNextBoundary()) + throw Poco::Net::HTMLFormException("No boundary line found"); + + /// Read each part until next boundary (or last boundary) + while (!in.eof()) + { + /// Accept at most max_fields_number parts (zero disables the limit), the same bound readQuery() applies to URL-encoded fields. + if (max_fields_number && fields >= max_fields_number) + throw Poco::Net::HTMLFormException("Too many form fields"); + + Poco::Net::MessageHeader header; + readHeaders(header, in, max_fields_number, max_field_name_size, max_field_value_size); + skipToNextLineOrEOF(in); + + NameValueCollection params; + if (header.has("Content-Disposition")) + { + std::string unused; + Poco::Net::MessageHeader::splitParameters(header.get("Content-Disposition"), unused, params); + } + + /// Parts with a "filename" parameter go to the handler, all other parts become form fields. + if (params.has("filename")) + handler.handlePart(header, in); + else + { + std::string name = params["name"]; + std::string value; + char ch = 0; + + while (in.read(ch)) + { + /// Refuse to grow the value beyond max_field_value_size bytes, the same bound readQuery() applies. + if (value.size() >= max_field_value_size) + throw Poco::Net::HTMLFormException("Field value too long"); + value += ch; + } + + add(name, value); + } + + ++fields; + + /// If we already encountered EOF for the buffer |in|, it's possible that the next symbol is a start of boundary line. + /// In this case reading the boundary line will reset the EOF state, potentially breaking invariant of EOF idempotency - + /// if there is such invariant in the first place. + if (!in.skipToNextBoundary()) + break; + } + + /// It's important to check, because we could get "fake" EOF and incomplete request if a client suddenly died in the middle. 
+ if (!in.isActualEOF()) + throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Unexpected EOF, " + "did not find the last boundary while parsing a multipart HTTP request"); +} + + +HTMLForm::MultipartReadBuffer::MultipartReadBuffer(ReadBuffer & in_, const std::string & boundary_) + : ReadBuffer(nullptr, 0), in(in_), boundary("--" + boundary_) +{ + /// For consistency with |nextImpl()| + position() = in.position(); +} + +/// Reads lines from the underlying buffer until one starts with the boundary. +/// Returns true if that line is a regular boundary; returns false if it is the closing boundary +/// (the boundary followed by "--") or if the input ended before a boundary line was found. +bool HTMLForm::MultipartReadBuffer::skipToNextBoundary() +{ + if (in.eof()) + return false; + + chassert(boundary_hit); + chassert(!found_last_boundary); + + boundary_hit = false; + + while (!in.eof()) + { + auto line = readLine(true); + if (startsWith(line, boundary)) + { + set(in.position(), 0); + next(); /// We need to restrict our buffer to size of next available line. + found_last_boundary = startsWith(line, boundary + "--"); + return !found_last_boundary; + } + } + + return false; +} + +/// Reads from the underlying buffer up to the next CRLF or to the end of input. +/// The terminating CRLF is always consumed but is appended to the result only if `append_crlf` is true. +/// A '\r' that is not followed by '\n' is kept as a part of the line. +std::string HTMLForm::MultipartReadBuffer::readLine(bool append_crlf) +{ + std::string line; + char ch = 0; // silence "uninitialized" warning from gcc-* + + /// If we don't append CRLF, it means that we may have to prepend CRLF from previous content line, which wasn't the boundary. + if (in.read(ch)) + line += ch; + if (in.read(ch)) + line += ch; + if (append_crlf && line == "\r\n") + return line; + + while (!in.eof()) + { + while (in.read(ch) && ch != '\r') + line += ch; + + if (in.eof()) break; + + assert(ch == '\r'); + + if (in.peek(ch) && ch == '\n') + { + in.ignore(); + if (append_crlf) line += "\r\n"; + break; + } + + line += ch; + } + + return line; +} + +bool HTMLForm::MultipartReadBuffer::nextImpl() +{ + if (boundary_hit) + return false; + + assert(position() >= in.position()); + + in.position() = position(); + + /// We expect to start from the first symbol after EOL, so we can put checkpoint + /// and safely try to read until the next EOL and check for boundary. 
+ in.setCheckpoint(); + + /// FIXME: there is an extra copy because we cannot traverse PeekableBuffer from checkpoint to position() + /// since it may store different data parts in different sub-buffers, + /// anyway calling makeContinuousMemoryFromCheckpointToPos() will also make an extra copy. + /// According to RFC2046 the preceding CRLF is a part of boundary line. + std::string line = readLine(false); + boundary_hit = startsWith(line, "\r\n" + boundary); + bool has_next = !boundary_hit && !line.empty(); + + if (has_next) + /// If we don't make sure that memory is contiguous then situation may happen, when part of the line is inside internal memory + /// and other part is inside sub-buffer, thus we'll be unable to setup our working buffer properly. + in.makeContinuousMemoryFromCheckpointToPos(); + + in.rollbackToCheckpoint(true); + + /// Rolling back to checkpoint may change underlying buffers. + /// Limit readable data to a single line. + BufferBase::set(in.position(), line.size(), 0); + + /// True if a non-empty line that is not a boundary line has been made available for reading. + return has_next; +} + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTMLForm.h b/contrib/clickhouse/src/Server/HTTP/HTMLForm.h new file mode 100644 index 0000000000..c75dafccaf --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTMLForm.h @@ -0,0 +1,124 @@ +#pragma once + +#include <IO/PeekableReadBuffer.h> +#include <IO/ReadHelpers.h> + +#include <boost/noncopyable.hpp> +#include <Poco/Net/HTTPRequest.h> +#include <Poco/Net/NameValueCollection.h> +#include <Poco/Net/PartSource.h> +#include <Poco/URI.h> + +namespace DB +{ + +struct Settings; + +/// Name-value collection filled from the query string and/or the body of an HTTP request. +/// The number of fields and the sizes of field names and values are limited by values taken from Settings. +class HTMLForm : public Poco::Net::NameValueCollection, private boost::noncopyable +{ +public: + class PartHandler; + + enum Options + { + OPT_USE_CONTENT_LENGTH = 0x01, /// don't use Chunked Transfer-Encoding for multipart requests. + }; + + /// Creates an empty HTMLForm and sets the + /// encoding to "application/x-www-form-urlencoded". 
+ explicit HTMLForm(const Settings & settings); + + /// Creates an empty HTMLForm that uses the given encoding. + /// Encoding must be either "application/x-www-form-urlencoded" (which is the default) or "multipart/form-data". + explicit HTMLForm(const Settings & settings, const std::string & encoding); + + /// Creates a HTMLForm from the given HTTP request. + /// Uploaded files are passed to the given PartHandler. + HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody, PartHandler & handler); + + /// Creates a HTMLForm from the given HTTP request. + /// Uploaded files are silently discarded. + HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody); + + /// Creates a HTMLForm from the given HTTP request. + /// The request must be a GET request and the form data must be in the query string (URL encoded). + /// For POST requests, you must use one of the constructors taking an additional input stream for the request body. + explicit HTMLForm(const Settings & settings, const Poco::Net::HTTPRequest & request); + + /// Creates a HTMLForm from the raw query string of the given URI. + explicit HTMLForm(const Settings & settings, const Poco::URI & uri); + + /// Returns the value stored under `key` parsed as T, or `default_value` if there is no such key. + template <typename T> + T getParsed(const std::string & key, T default_value) + { + auto it = find(key); + return (it != end()) ? DB::parse<T>(it->second) : default_value; + } + + /// Reads the form data from the given HTTP request. + /// Uploaded files are passed to the given PartHandler. + void load(const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody, PartHandler & handler); + + /// Reads the form data from the given HTTP request. + /// Uploaded files are silently discarded. + void load(const Poco::Net::HTTPRequest & request, ReadBuffer & requestBody); + + /// Reads the URL-encoded form data from the given input stream. + /// Note that read() does not clear the form before reading the new values. 
+ void read(ReadBuffer & in); + + static const std::string ENCODING_URL; /// "application/x-www-form-urlencoded" + static const std::string ENCODING_MULTIPART; /// "multipart/form-data" + static const int UNKNOWN_CONTENT_LENGTH; /// -1 + +protected: + /// Reads URL-encoded "name=value" pairs separated by '&' and adds them to the form. + void readQuery(ReadBuffer & in); + /// Reads multipart data delimited by `boundary`: parts with a "filename" parameter go to `handler`, other parts are added to the form. + void readMultipart(ReadBuffer & in, PartHandler & handler); + +private: + /// This buffer provides data line by line to check for boundary line in a convenient way. + class MultipartReadBuffer; + + struct Part + { + std::string name; + std::unique_ptr<Poco::Net::PartSource> source; + }; + + using PartVec = std::vector<Part>; + + /// Limits, initialized from the http_max_fields, http_max_field_name_size and http_max_field_value_size settings. + const size_t max_fields_number, max_field_name_size, max_field_value_size; + + std::string encoding; + std::string boundary; + PartVec parts; +}; + +/// Receives the parts of a multipart form that carry a "filename" parameter in Content-Disposition. +class HTMLForm::PartHandler +{ +public: + virtual ~PartHandler() = default; + virtual void handlePart(const Poco::Net::MessageHeader &, ReadBuffer &) = 0; +}; + +class HTMLForm::MultipartReadBuffer : public ReadBuffer +{ +public: + MultipartReadBuffer(ReadBuffer & in, const std::string & boundary); + + /// Returns false if the last boundary was found or if the input ended before any boundary line. 
+ bool skipToNextBoundary(); + + /* True once skipToNextBoundary() has read the closing boundary line. */ + bool isActualEOF() const { return found_last_boundary; } + +private: + PeekableReadBuffer in; + const std::string boundary; + bool boundary_hit = true; + bool found_last_boundary = false; + + std::string readLine(bool append_crlf); + + bool nextImpl() override; +}; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPContext.h b/contrib/clickhouse/src/Server/HTTP/HTTPContext.h new file mode 100644 index 0000000000..09c46ed188 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPContext.h @@ -0,0 +1,24 @@ +#pragma once + +#include <Poco/Timespan.h> + +namespace DB +{ + +/* Interface through which the HTTP server code obtains its limits and timeouts. */ +/* NOTE(review): uses uint64_t and std::shared_ptr while including only <Poco/Timespan.h>; presumably relies on transitive includes - confirm. */ +struct IHTTPContext +{ + virtual uint64_t getMaxHstsAge() const = 0; + virtual uint64_t getMaxUriSize() const = 0; + virtual uint64_t getMaxFields() const = 0; + virtual uint64_t getMaxFieldNameSize() const = 0; + virtual uint64_t getMaxFieldValueSize() const = 0; + virtual uint64_t getMaxChunkSize() const = 0; + virtual Poco::Timespan getReceiveTimeout() const = 0; + virtual Poco::Timespan getSendTimeout() const = 0; + + virtual ~IHTTPContext() = default; +}; + +using HTTPContextPtr = std::shared_ptr<IHTTPContext>; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPRequest.h b/contrib/clickhouse/src/Server/HTTP/HTTPRequest.h new file mode 100644 index 0000000000..40839cbcdd --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPRequest.h @@ -0,0 +1,10 @@ +#pragma once + +#include <Poco/Net/HTTPRequest.h> + +namespace DB +{ + +using HTTPRequest = Poco::Net::HTTPRequest; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandler.h b/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandler.h new file mode 100644 index 0000000000..19340866bb --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandler.h @@ -0,0 +1,19 @@ +#pragma once + +#include <Server/HTTP/HTTPServerRequest.h> +#include <Server/HTTP/HTTPServerResponse.h> + +#include <boost/noncopyable.hpp> + +namespace DB +{ + +class HTTPRequestHandler : private 
boost::noncopyable +{ +public: + virtual ~HTTPRequestHandler() = default; + + /* Handles one request; the reply is produced through `response`. */ + virtual void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) = 0; +}; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandlerFactory.h b/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandlerFactory.h new file mode 100644 index 0000000000..3d50bf0a2e --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPRequestHandlerFactory.h @@ -0,0 +1,20 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandler.h> + +#include <boost/noncopyable.hpp> + +namespace DB +{ + +class HTTPRequestHandlerFactory : private boost::noncopyable +{ +public: + virtual ~HTTPRequestHandlerFactory() = default; + + /* Returns a handler for `request`. HTTPServerConnection::run() answers 501 Not Implemented when the returned pointer is null. */ + virtual std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) = 0; +}; + +using HTTPRequestHandlerFactoryPtr = std::shared_ptr<HTTPRequestHandlerFactory>; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPResponse.h b/contrib/clickhouse/src/Server/HTTP/HTTPResponse.h new file mode 100644 index 0000000000..c73bcec6c3 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPResponse.h @@ -0,0 +1,10 @@ +#pragma once + +#include <Poco/Net/HTTPResponse.h> + +namespace DB +{ + +using HTTPResponse = Poco::Net::HTTPResponse; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServer.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServer.cpp new file mode 100644 index 0000000000..4673493326 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServer.cpp @@ -0,0 +1,30 @@ +#include <Server/HTTP/HTTPServer.h> + +#include <Server/HTTP/HTTPServerConnectionFactory.h> + + +namespace DB +{ +/* Wraps `factory_` into a new HTTPServerConnectionFactory that is passed to the TCPServer base, and also keeps `factory_` in a member. */ +HTTPServer::HTTPServer( + HTTPContextPtr context, + HTTPRequestHandlerFactoryPtr factory_, + Poco::ThreadPool & thread_pool, + Poco::Net::ServerSocket & socket_, + Poco::Net::HTTPServerParams::Ptr params) + : TCPServer(new HTTPServerConnectionFactory(context, params, factory_), thread_pool, socket_, params), factory(factory_) +{ +} + 
+HTTPServer::~HTTPServer() +{ + /// We should call stop and join thread here instead of destructor of parent TCPServer, + /// because there's possible race on 'vptr' between this virtual destructor and 'run' method. + stop(); +} + +/// Stops the server; the argument is ignored. +void HTTPServer::stopAll(bool /* abortCurrent */) +{ + stop(); +} + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServer.h b/contrib/clickhouse/src/Server/HTTP/HTTPServer.h new file mode 100644 index 0000000000..adfb21e7c6 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServer.h @@ -0,0 +1,33 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandlerFactory.h> +#include <Server/HTTP/HTTPContext.h> +#include <Server/TCPServer.h> + +#include <Poco/Net/HTTPServerParams.h> + +#include <base/types.h> + + +namespace DB +{ + +/// TCPServer whose connections are created by an HTTPServerConnectionFactory. +class HTTPServer : public TCPServer +{ +public: + explicit HTTPServer( + HTTPContextPtr context, + HTTPRequestHandlerFactoryPtr factory, + Poco::ThreadPool & thread_pool, + Poco::Net::ServerSocket & socket, + Poco::Net::HTTPServerParams::Ptr params); + + ~HTTPServer() override; + + /// Calls stop(); `abort_current` is not used by the implementation. + void stopAll(bool abort_current = false); + +private: + HTTPRequestHandlerFactoryPtr factory; +}; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.cpp new file mode 100644 index 0000000000..ad17bc4348 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.cpp @@ -0,0 +1,121 @@ +#include <Server/HTTP/HTTPServerConnection.h> +#include <Server/TCPServer.h> + +#include <Poco/Net/NetException.h> + +namespace DB +{ + +HTTPServerConnection::HTTPServerConnection( + HTTPContextPtr context_, + TCPServer & tcp_server_, + const Poco::Net::StreamSocket & socket, + Poco::Net::HTTPServerParams::Ptr params_, + HTTPRequestHandlerFactoryPtr factory_) + : TCPServerConnection(socket), context(std::move(context_)), tcp_server(tcp_server_), params(params_), factory(factory_), stopped(false) +{ + poco_check_ptr(factory); +} + +void 
HTTPServerConnection::run() +{ + /* Serves requests on this connection one after another while the TCP server is open and the session is connected and has more requests. */ + std::string server = params->getSoftwareVersion(); + Poco::Net::HTTPServerSession session(socket(), params); + + while (!stopped && tcp_server.isOpen() && session.hasMoreRequests() && session.connected()) + { + try + { + /* `mutex` stays locked for the whole handling of the request, not only while the handler is created. */ + std::lock_guard lock(mutex); + if (!stopped && tcp_server.isOpen() && session.connected()) + { + HTTPServerResponse response(session); + HTTPServerRequest request(context, response, session); + + Poco::Timestamp now; + + if (!forwarded_for.empty()) + request.set("X-Forwarded-For", forwarded_for); + + /* On secure connections add the Strict-Transport-Security header if a positive max age is configured. */ + if (request.isSecure()) + { + size_t hsts_max_age = context->getMaxHstsAge(); + + if (hsts_max_age > 0) + response.add("Strict-Transport-Security", "max-age=" + std::to_string(hsts_max_age)); + + } + + response.setDate(now); + response.setVersion(request.getVersion()); + response.setKeepAlive(params->getKeepAlive() && request.getKeepAlive() && session.canKeepAlive()); + if (!server.empty()) + response.set("Server", server); + try + { + if (!tcp_server.isOpen()) + { + sendErrorResponse(session, Poco::Net::HTTPResponse::HTTP_SERVICE_UNAVAILABLE); + break; + } + std::unique_ptr<HTTPRequestHandler> handler(factory->createRequestHandler(request)); + + if (handler) + { + if (request.getExpectContinue() && response.getStatus() == Poco::Net::HTTPResponse::HTTP_OK) + response.sendContinue(); + + handler->handleRequest(request, response); + session.setKeepAlive(params->getKeepAlive() && response.getKeepAlive() && session.canKeepAlive()); + } + else + sendErrorResponse(session, Poco::Net::HTTPResponse::HTTP_NOT_IMPLEMENTED); + } + catch (Poco::Exception &) + { + /* Try to report the failure to the client if nothing has been sent yet, then rethrow the original exception. */ + if (!response.sent()) + { + try + { + sendErrorResponse(session, Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR); + } + catch (...) 
+ { + } + } + throw; + } + } + } + catch (const Poco::Net::NoMessageException &) + { + break; + } + catch (const Poco::Net::MessageException &) + { + sendErrorResponse(session, Poco::Net::HTTPResponse::HTTP_BAD_REQUEST); + } + catch (const Poco::Exception &) + { + if (session.networkException()) + { + session.networkException()->rethrow(); + } + else + throw; + } + } +} + +// static +/// Sends an HTTP/1.1 response with the given status and disables keep-alive for both the response and the session. +void HTTPServerConnection::sendErrorResponse(Poco::Net::HTTPServerSession & session, Poco::Net::HTTPResponse::HTTPStatus status) +{ + HTTPServerResponse response(session); + response.setVersion(Poco::Net::HTTPMessage::HTTP_1_1); + response.setStatusAndReason(status); + response.setKeepAlive(false); + response.send(); + session.setKeepAlive(false); +} + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.h b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.h new file mode 100644 index 0000000000..7087f8d5a2 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnection.h @@ -0,0 +1,51 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandlerFactory.h> +#include <Server/HTTP/HTTPContext.h> + +#include <Poco/Net/HTTPServerParams.h> +#include <Poco/Net/HTTPServerSession.h> +#include <Poco/Net/TCPServerConnection.h> + +namespace DB +{ +class TCPServer; + +/// Serves the HTTP requests that arrive on one TCP connection. +class HTTPServerConnection : public Poco::Net::TCPServerConnection +{ +public: + HTTPServerConnection( + HTTPContextPtr context, + TCPServer & tcp_server, + const Poco::Net::StreamSocket & socket, + Poco::Net::HTTPServerParams::Ptr params, + HTTPRequestHandlerFactoryPtr factory); + + /// Same as the constructor above, but also stores `forwarded_for_`; + /// if it is not empty, run() sets it as the X-Forwarded-For header of every request. + HTTPServerConnection( + HTTPContextPtr context_, + TCPServer & tcp_server_, + const Poco::Net::StreamSocket & socket_, + Poco::Net::HTTPServerParams::Ptr params_, + HTTPRequestHandlerFactoryPtr factory_, + const String & forwarded_for_) + : HTTPServerConnection(context_, tcp_server_, socket_, params_, factory_) + { + forwarded_for = forwarded_for_; + } + + void run() override; + +protected: + static void 
sendErrorResponse(Poco::Net::HTTPServerSession & session, Poco::Net::HTTPResponse::HTTPStatus status); + +private: + HTTPContextPtr context; + TCPServer & tcp_server; + Poco::Net::HTTPServerParams::Ptr params; + HTTPRequestHandlerFactoryPtr factory; + String forwarded_for; + bool stopped; + std::mutex mutex; // guards the |factory| with assumption that creating handlers is not thread-safe. + /// NOTE(review): `stopped` is set to false by the constructor and no code visible in this patch sets it to true - confirm it is still needed. + /// NOTE(review): std::mutex is used here but <mutex> is not included by this header directly - presumably it comes in transitively; confirm. +}; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.cpp new file mode 100644 index 0000000000..2c9ac0cda2 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.cpp @@ -0,0 +1,24 @@ +#include <Server/HTTP/HTTPServerConnectionFactory.h> + +#include <Server/HTTP/HTTPServerConnection.h> + +namespace DB +{ +HTTPServerConnectionFactory::HTTPServerConnectionFactory( + HTTPContextPtr context_, Poco::Net::HTTPServerParams::Ptr params_, HTTPRequestHandlerFactoryPtr factory_) + : context(std::move(context_)), params(params_), factory(factory_) +{ + poco_check_ptr(factory); +} + +/// Creates a connection for `socket`. The object is allocated with `new` and returned as a raw pointer +/// (presumably the caller takes ownership - confirm against TCPServer). +Poco::Net::TCPServerConnection * HTTPServerConnectionFactory::createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) +{ + return new HTTPServerConnection(context, tcp_server, socket, params, factory); +} + +/// Same as above, but passes `stack_data.forwarded_for` to the connection. +Poco::Net::TCPServerConnection * HTTPServerConnectionFactory::createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData & stack_data) +{ + return new HTTPServerConnection(context, tcp_server, socket, params, factory, stack_data.forwarded_for); +} + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.h b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.h new file mode 100644 index 0000000000..e18249da4d --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerConnectionFactory.h @@ -0,0 +1,26 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandlerFactory.h> 
+#include <Server/HTTP/HTTPContext.h> +#include <Server/TCPServerConnectionFactory.h> + +#include <Poco/Net/HTTPServerParams.h> + +namespace DB +{ + +/* Creates HTTPServerConnection objects that share one context, one set of server parameters and one request handler factory. */ +class HTTPServerConnectionFactory : public TCPServerConnectionFactory +{ +public: + HTTPServerConnectionFactory(HTTPContextPtr context, Poco::Net::HTTPServerParams::Ptr params, HTTPRequestHandlerFactoryPtr factory); + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override; + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData & stack_data) override; + +private: + HTTPContextPtr context; + Poco::Net::HTTPServerParams::Ptr params; + HTTPRequestHandlerFactoryPtr factory; +}; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.cpp new file mode 100644 index 0000000000..891ac39c93 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.cpp @@ -0,0 +1,174 @@ +#include <Server/HTTP/HTTPServerRequest.h> + +#include <IO/EmptyReadBuffer.h> +#include <IO/HTTPChunkedReadBuffer.h> +#include <IO/LimitReadBuffer.h> +#include <IO/ReadBufferFromPocoSocket.h> +#include <IO/ReadHelpers.h> +#include <Server/HTTP/HTTPServerResponse.h> +#include <Server/HTTP/ReadHeaders.h> + +#include <Poco/Net/HTTPHeaderStream.h> +#include <Poco/Net/HTTPStream.h> +#include <Poco/Net/NetException.h> + +#include <Common/logger_useful.h> + +#if USE_SSL +#include <Poco/Net/SecureStreamSocketImpl.h> +#include <Poco/Net/SSLException.h> +#include <Poco/Net/X509Certificate.h> +#endif + +namespace DB +{ +/* Reads the request line and headers from the session socket and selects a body stream according to the transfer encoding, content length and method. */ +HTTPServerRequest::HTTPServerRequest(HTTPContextPtr context, HTTPServerResponse & response, Poco::Net::HTTPServerSession & session) + : max_uri_size(context->getMaxUriSize()) + , max_fields_number(context->getMaxFields()) + , max_field_name_size(context->getMaxFieldNameSize()) + , 
max_field_value_size(context->getMaxFieldValueSize()) +{ + response.attachRequest(this); + + /// Now that we know socket is still connected, obtain addresses + client_address = session.clientAddress(); + server_address = session.serverAddress(); + secure = session.socket().secure(); + + /// Apply the receive and send timeouts from the context to the session socket. + auto receive_timeout = context->getReceiveTimeout(); + auto send_timeout = context->getSendTimeout(); + + session.socket().setReceiveTimeout(receive_timeout); + session.socket().setSendTimeout(send_timeout); + + auto in = std::make_unique<ReadBufferFromPocoSocket>(session.socket()); + socket = session.socket().impl(); + + readRequest(*in); /// Try to parse according to RFC7230 + + /// If a client crashes, most systems will gracefully terminate the connection with FIN just like it's done on close(). + /// So we will get 0 from recv(...) and will not be able to understand that something went wrong (well, we probably + /// will get RST later on attempt to write to the socket that closed on the other side, but it will happen when the query is finished). + /// If we are extremely unlucky and data format is TSV, for example, then we may stop parsing exactly between rows + /// and decide that it's EOF (but it is not). It may break deduplication, because clients cannot control it + /// and retry with exactly the same (incomplete) set of rows. + /// That's why we have to check body size if it's provided. 
+ if (getChunkedTransferEncoding()) + stream = std::make_unique<HTTPChunkedReadBuffer>(std::move(in), context->getMaxChunkSize()); + else if (hasContentLength()) + { + size_t content_length = getContentLength(); + stream = std::make_unique<LimitReadBuffer>(std::move(in), content_length, + /* trow_exception */ true, /* exact_limit */ content_length); + } + else if (getMethod() != HTTPRequest::HTTP_GET && getMethod() != HTTPRequest::HTTP_HEAD && getMethod() != HTTPRequest::HTTP_DELETE) + { + stream = std::move(in); + if (!startsWith(getContentType(), "multipart/form-data")) + LOG_WARNING(LogFrequencyLimiter(&Poco::Logger::get("HTTPServerRequest"), 10), "Got an HTTP request with no content length " + "and no chunked/multipart encoding, it may be impossible to distinguish graceful EOF from abnormal connection loss"); + } + else + /// We have to distinguish empty buffer and nullptr. + stream = std::make_unique<EmptyReadBuffer>(); +} + +bool HTTPServerRequest::checkPeerConnected() const +{ + try + { + char b; + if (!socket->receiveBytes(&b, 1, MSG_DONTWAIT | MSG_PEEK)) + return false; + } + catch (Poco::TimeoutException &) + { + } + catch (...) 
+ { + return false; + } + + return true; +} + +#if USE_SSL +bool HTTPServerRequest::havePeerCertificate() const +{ + if (!secure) + return false; + + const Poco::Net::SecureStreamSocketImpl * secure_socket = dynamic_cast<const Poco::Net::SecureStreamSocketImpl *>(socket); + if (!secure_socket) + return false; + + return secure_socket->havePeerCertificate(); +} + +Poco::Net::X509Certificate HTTPServerRequest::peerCertificate() const +{ + if (secure) + { + const Poco::Net::SecureStreamSocketImpl * secure_socket = dynamic_cast<const Poco::Net::SecureStreamSocketImpl *>(socket); + if (secure_socket) + return secure_socket->peerCertificate(); + } + throw Poco::Net::SSLException("No certificate available"); +} +#endif + +void HTTPServerRequest::readRequest(ReadBuffer & in) +{ + char ch; + std::string method; + std::string uri; + std::string version; + + method.reserve(16); + uri.reserve(64); + version.reserve(16); + + if (in.eof()) + throw Poco::Net::NoMessageException(); + + skipWhitespaceIfAny(in); + + if (in.eof()) + throw Poco::Net::MessageException("No HTTP request header"); + + while (in.read(ch) && !Poco::Ascii::isSpace(ch) && method.size() <= MAX_METHOD_LENGTH) + method += ch; + + if (method.size() > MAX_METHOD_LENGTH) + throw Poco::Net::MessageException("HTTP request method invalid or too long"); + + skipWhitespaceIfAny(in); + + while (in.read(ch) && !Poco::Ascii::isSpace(ch) && uri.size() <= max_uri_size) + uri += ch; + + if (uri.size() > max_uri_size) + throw Poco::Net::MessageException("HTTP request URI invalid or too long"); + + skipWhitespaceIfAny(in); + + while (in.read(ch) && !Poco::Ascii::isSpace(ch) && version.size() <= MAX_VERSION_LENGTH) + version += ch; + + if (version.size() > MAX_VERSION_LENGTH) + throw Poco::Net::MessageException("Invalid HTTP version string"); + + // since HTTP always use Windows-style EOL '\r\n' we always can safely skip to '\n' + + skipToNextLineOrEOF(in); + + readHeaders(*this, in, max_fields_number, max_field_name_size, 
max_field_value_size); + + skipToNextLineOrEOF(in); + + setMethod(method); + setURI(uri); + setVersion(version); +} + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.h b/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.h new file mode 100644 index 0000000000..da0e498b0d --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerRequest.h @@ -0,0 +1,73 @@ +#pragma once + +#include <Interpreters/Context_fwd.h> +#include <IO/ReadBuffer.h> +#include <Server/HTTP/HTTPRequest.h> +#include <Server/HTTP/HTTPContext.h> +#include "clickhouse_config.h" + +#include <Poco/Net/HTTPServerSession.h> + +namespace Poco::Net { class X509Certificate; } + +namespace DB +{ + +class HTTPServerResponse; +class ReadBufferFromPocoSocket; + +class HTTPServerRequest : public HTTPRequest +{ +public: + HTTPServerRequest(HTTPContextPtr context, HTTPServerResponse & response, Poco::Net::HTTPServerSession & session); + + /// FIXME: it's a little bit inconvenient interface. The rationale is that all other ReadBuffer's wrap each other + /// via unique_ptr - but we can't inherit HTTPServerRequest from ReadBuffer and pass it around, + /// since we also need it in other places. + + /// Returns the input stream for reading the request body. + ReadBuffer & getStream() + { + poco_check_ptr(stream); + return *stream; + } + + bool checkPeerConnected() const; + + bool isSecure() const { return secure; } + + /// Returns the client's address. + const Poco::Net::SocketAddress & clientAddress() const { return client_address; } + + /// Returns the server's address. 
+ const Poco::Net::SocketAddress & serverAddress() const { return server_address; } + +#if USE_SSL + bool havePeerCertificate() const; + Poco::Net::X509Certificate peerCertificate() const; +#endif + +private: + /// Limits for basic sanity checks when reading a header + enum Limits + { + MAX_METHOD_LENGTH = 32, + MAX_VERSION_LENGTH = 8, + }; + + const size_t max_uri_size; + const size_t max_fields_number; + const size_t max_field_name_size; + const size_t max_field_value_size; + + std::unique_ptr<ReadBuffer> stream; + Poco::Net::SocketImpl * socket; + Poco::Net::SocketAddress client_address; + Poco::Net::SocketAddress server_address; + + bool secure; + + void readRequest(ReadBuffer & in); +}; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.cpp b/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.cpp new file mode 100644 index 0000000000..25e7604a51 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.cpp @@ -0,0 +1,121 @@ +#include <Server/HTTP/HTTPServerResponse.h> +#include <Server/HTTP/HTTPServerRequest.h> +#include <Poco/CountingStream.h> +#include <Poco/DateTimeFormat.h> +#include <Poco/DateTimeFormatter.h> +#include <Poco/FileStream.h> +#include <Poco/Net/HTTPChunkedStream.h> +#include <Poco/Net/HTTPFixedLengthStream.h> +#include <Poco/Net/HTTPHeaderStream.h> +#include <Poco/Net/HTTPStream.h> +#include <Poco/StreamCopier.h> + + +namespace DB +{ + +HTTPServerResponse::HTTPServerResponse(Poco::Net::HTTPServerSession & session_) : session(session_) +{ +} + +void HTTPServerResponse::sendContinue() +{ + Poco::Net::HTTPHeaderOutputStream hs(session); + hs << getVersion() << " 100 Continue\r\n\r\n"; +} + +std::shared_ptr<std::ostream> HTTPServerResponse::send() +{ + poco_assert(!stream); + + if ((request && request->getMethod() == HTTPRequest::HTTP_HEAD) || getStatus() < 200 || getStatus() == HTTPResponse::HTTP_NO_CONTENT + || getStatus() == HTTPResponse::HTTP_NOT_MODIFIED) + { + Poco::CountingOutputStream cs; + 
write(cs); + stream = std::make_shared<Poco::Net::HTTPFixedLengthOutputStream>(session, cs.chars()); + write(*stream); + } + else if (getChunkedTransferEncoding()) + { + Poco::Net::HTTPHeaderOutputStream hs(session); + write(hs); + stream = std::make_shared<Poco::Net::HTTPChunkedOutputStream>(session); + } + else if (hasContentLength()) + { + Poco::CountingOutputStream cs; + write(cs); + stream = std::make_shared<Poco::Net::HTTPFixedLengthOutputStream>(session, getContentLength64() + cs.chars()); + write(*stream); + } + else + { + stream = std::make_shared<Poco::Net::HTTPOutputStream>(session); + setKeepAlive(false); + write(*stream); + } + + return stream; +} + +std::pair<std::shared_ptr<std::ostream>, std::shared_ptr<std::ostream>> HTTPServerResponse::beginSend() +{ + poco_assert(!stream); + poco_assert(!header_stream); + + /// NOTE: Code is not exception safe. + + if ((request && request->getMethod() == HTTPRequest::HTTP_HEAD) || getStatus() < 200 || getStatus() == HTTPResponse::HTTP_NO_CONTENT + || getStatus() == HTTPResponse::HTTP_NOT_MODIFIED) + { + throw Poco::Exception("HTTPServerResponse::beginSend is invalid for HEAD request"); + } + else if (getChunkedTransferEncoding()) + { + header_stream = std::make_shared<Poco::Net::HTTPHeaderOutputStream>(session); + beginWrite(*header_stream); + stream = std::make_shared<Poco::Net::HTTPChunkedOutputStream>(session); + } + else if (hasContentLength()) + { + throw Poco::Exception("HTTPServerResponse::beginSend is invalid for response with Content-Length header"); + } + else + { + stream = std::make_shared<Poco::Net::HTTPOutputStream>(session); + header_stream = stream; + setKeepAlive(false); + beginWrite(*stream); + } + + return std::make_pair(header_stream, stream); +} + +void HTTPServerResponse::sendBuffer(const void * buffer, std::size_t length) +{ + poco_assert(!stream); + + setContentLength(static_cast<int>(length)); + setChunkedTransferEncoding(false); + + stream = 
std::make_shared<Poco::Net::HTTPHeaderOutputStream>(session); + write(*stream); + if (request && request->getMethod() != HTTPRequest::HTTP_HEAD) + { + stream->write(static_cast<const char *>(buffer), static_cast<std::streamsize>(length)); + } +} + +void HTTPServerResponse::requireAuthentication(const std::string & realm) +{ + poco_assert(!stream); + + setStatusAndReason(HTTPResponse::HTTP_UNAUTHORIZED); + std::string auth("Basic realm=\""); + auth.append(realm); + auth.append("\""); + set("WWW-Authenticate", auth); +} + +} diff --git a/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.h b/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.h new file mode 100644 index 0000000000..236a56e232 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/HTTPServerResponse.h @@ -0,0 +1,70 @@ +#pragma once + +#include <Server/HTTP/HTTPResponse.h> + +#include <Poco/Net/HTTPServerSession.h> +#include <Poco/Net/HTTPResponse.h> + +#include <memory> + + +namespace DB +{ + +class HTTPServerRequest; + +class HTTPServerResponse : public HTTPResponse +{ +public: + explicit HTTPServerResponse(Poco::Net::HTTPServerSession & session); + + void sendContinue(); /// Sends a 100 Continue response to the client. + + /// Sends the response header to the client and + /// returns an output stream for sending the + /// response body. + /// + /// Must not be called after beginSend(), sendFile(), sendBuffer() + /// or redirect() has been called. + std::shared_ptr<std::ostream> send(); /// TODO: use some WriteBuffer implementation here. + + /// Sends the response headers to the client + /// but do not finish headers with \r\n, + /// allowing to continue sending additional header fields. + /// + /// Must not be called after send(), sendFile(), sendBuffer() + /// or redirect() has been called. + std::pair<std::shared_ptr<std::ostream>, std::shared_ptr<std::ostream>> beginSend(); /// TODO: use some WriteBuffer implementation here. 
+ + /// Sends the response header to the client, followed + /// by the contents of the given buffer. + /// + /// The Content-Length header of the response is set + /// to length and chunked transfer encoding is disabled. + /// + /// If both the HTTP message header and body (from the + /// given buffer) fit into one single network packet, the + /// complete response can be sent in one network packet. + /// + /// Must not be called after send(), sendFile() + /// or redirect() has been called. + void sendBuffer(const void * pBuffer, std::size_t length); /// FIXME: do we need this one? + + void requireAuthentication(const std::string & realm); + /// Sets the status code to 401 (Unauthorized) + /// and sets the "WWW-Authenticate" header field + /// according to the given realm. + + /// Returns true if the response (header) has been sent. + bool sent() const { return !!stream; } + + void attachRequest(HTTPServerRequest * request_) { request = request_; } + +private: + Poco::Net::HTTPServerSession & session; + HTTPServerRequest * request = nullptr; + std::shared_ptr<std::ostream> stream; + std::shared_ptr<std::ostream> header_stream; +}; + +} diff --git a/contrib/clickhouse/src/Server/HTTP/README.md b/contrib/clickhouse/src/Server/HTTP/README.md new file mode 100644 index 0000000000..7173096278 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/README.md @@ -0,0 +1,3 @@ +# Notice + +The source code located in this folder is based on some files from the POCO project, from here `contrib/poco/Net/src`. 
diff --git a/contrib/clickhouse/src/Server/HTTP/ReadHeaders.cpp b/contrib/clickhouse/src/Server/HTTP/ReadHeaders.cpp new file mode 100644 index 0000000000..b705750106 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/ReadHeaders.cpp @@ -0,0 +1,85 @@ +#include <Server/HTTP/ReadHeaders.h> + +#include <IO/ReadBuffer.h> +#include <IO/ReadHelpers.h> + +#include <Poco/Net/NetException.h> + +namespace DB +{ + +void readHeaders( + Poco::Net::MessageHeader & headers, ReadBuffer & in, size_t max_fields_number, size_t max_name_length, size_t max_value_length) +{ + char ch = 0; // silence uninitialized warning from gcc-* + std::string name; + std::string value; + + name.reserve(32); + value.reserve(64); + + size_t fields = 0; + + while (true) + { + if (fields > max_fields_number) + throw Poco::Net::MessageException("Too many header fields"); + + name.clear(); + value.clear(); + + /// Field name + while (in.peek(ch) && ch != ':' && !Poco::Ascii::isSpace(ch) && name.size() <= max_name_length) + { + name += ch; + in.ignore(); + } + + if (in.eof()) + throw Poco::Net::MessageException("Field is invalid"); + + if (name.empty()) + { + if (ch == '\r') + /// Start of the empty-line delimiter + break; + if (ch == ':') + throw Poco::Net::MessageException("Field name is empty"); + } + else + { + if (name.size() > max_name_length) + throw Poco::Net::MessageException("Field name is too long"); + if (ch != ':') + throw Poco::Net::MessageException(fmt::format("Field name is invalid or no colon found: \"{}\"", name)); + } + + in.ignore(); + + skipWhitespaceIfAny(in, true); + + if (in.eof()) + throw Poco::Net::MessageException("Field is invalid"); + + /// Field value - folded values not supported. 
+ while (in.read(ch) && ch != '\r' && ch != '\n' && value.size() <= max_value_length) + value += ch; + + if (in.eof()) + throw Poco::Net::MessageException("Field is invalid"); + + if (ch == '\n') + throw Poco::Net::MessageException("No CRLF found"); + + if (value.size() > max_value_length) + throw Poco::Net::MessageException("Field value is too long"); + + skipToNextLineOrEOF(in); + + Poco::trimRightInPlace(value); + headers.add(name, headers.decodeWord(value)); + ++fields; + } +} + +} diff --git a/contrib/clickhouse/src/Server/HTTP/ReadHeaders.h b/contrib/clickhouse/src/Server/HTTP/ReadHeaders.h new file mode 100644 index 0000000000..1b0e627f77 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/ReadHeaders.h @@ -0,0 +1,13 @@ +#pragma once + +#include <Poco/Net/MessageHeader.h> + +namespace DB +{ + +class ReadBuffer; + +void readHeaders( + Poco::Net::MessageHeader & headers, ReadBuffer & in, size_t max_fields_number, size_t max_name_length, size_t max_value_length); + +} diff --git a/contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.cpp b/contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.cpp new file mode 100644 index 0000000000..046feee017 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTP/WriteBufferFromHTTPServerResponse.cpp @@ -0,0 +1,204 @@ +#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h> + +#include <IO/HTTPCommon.h> +#include <IO/Progress.h> +#include <IO/WriteBufferFromString.h> +#include <IO/WriteHelpers.h> + +namespace DB +{ + +namespace ErrorCodes +{ +} + + +void WriteBufferFromHTTPServerResponse::startSendHeaders() +{ + if (!headers_started_sending) + { + headers_started_sending = true; + + if (add_cors_header) + response.set("Access-Control-Allow-Origin", "*"); + + setResponseDefaultHeaders(response, keep_alive_timeout); + + if (!is_http_method_head) + std::tie(response_header_ostr, response_body_ostr) = response.beginSend(); + } +} + +void 
WriteBufferFromHTTPServerResponse::writeHeaderProgressImpl(const char * header_name) +{ + if (headers_finished_sending) + return; + + WriteBufferFromOwnString progress_string_writer; + + accumulated_progress.writeJSON(progress_string_writer); + + if (response_header_ostr) + *response_header_ostr << header_name << progress_string_writer.str() << "\r\n" << std::flush; +} + +void WriteBufferFromHTTPServerResponse::writeHeaderSummary() +{ + writeHeaderProgressImpl("X-ClickHouse-Summary: "); +} + +void WriteBufferFromHTTPServerResponse::writeHeaderProgress() +{ + writeHeaderProgressImpl("X-ClickHouse-Progress: "); +} + +void WriteBufferFromHTTPServerResponse::writeExceptionCode() +{ + if (headers_finished_sending || !exception_code) + return; + if (response_header_ostr) + *response_header_ostr << "X-ClickHouse-Exception-Code: " << exception_code << "\r\n" << std::flush; +} + +void WriteBufferFromHTTPServerResponse::finishSendHeaders() +{ + if (!headers_finished_sending) + { + writeHeaderSummary(); + writeExceptionCode(); + headers_finished_sending = true; + + if (!is_http_method_head) + { + /// Send end of headers delimiter. + if (response_header_ostr) + *response_header_ostr << "\r\n" << std::flush; + } + else + { + if (!response_body_ostr) + response_body_ostr = response.send(); + } + } +} + + +void WriteBufferFromHTTPServerResponse::nextImpl() +{ + if (!initialized) + { + std::lock_guard lock(mutex); + + /// Initialize as early as possible since if the code throws, + /// next() should not be called anymore. + initialized = true; + + startSendHeaders(); + + if (!out && !is_http_method_head) + { + if (compress) + { + auto content_encoding_name = toContentEncodingName(compression_method); + + *response_header_ostr << "Content-Encoding: " << content_encoding_name << "\r\n"; + } + + /// We reuse our buffer in "out" to avoid extra allocations and copies. 
+ + if (compress) + out = wrapWriteBufferWithCompressionMethod( + std::make_unique<WriteBufferFromOStream>(*response_body_ostr), + compress ? compression_method : CompressionMethod::None, + compression_level, + working_buffer.size(), + working_buffer.begin()); + else + out = std::make_unique<WriteBufferFromOStream>( + *response_body_ostr, + working_buffer.size(), + working_buffer.begin()); + } + + finishSendHeaders(); + } + + if (out) + { + out->buffer() = buffer(); + out->position() = position(); + out->next(); + } +} + + +WriteBufferFromHTTPServerResponse::WriteBufferFromHTTPServerResponse( + HTTPServerResponse & response_, + bool is_http_method_head_, + size_t keep_alive_timeout_, + bool compress_, + CompressionMethod compression_method_) + : BufferWithOwnMemory<WriteBuffer>(DBMS_DEFAULT_BUFFER_SIZE) + , response(response_) + , is_http_method_head(is_http_method_head_) + , keep_alive_timeout(keep_alive_timeout_) + , compress(compress_) + , compression_method(compression_method_) +{ +} + + +void WriteBufferFromHTTPServerResponse::onProgress(const Progress & progress) +{ + std::lock_guard lock(mutex); + + /// Cannot add new headers if body was started to send. + if (headers_finished_sending) + return; + + accumulated_progress.incrementPiecewiseAtomically(progress); + if (send_progress && progress_watch.elapsed() >= send_progress_interval_ms * 1000000) + { + progress_watch.restart(); + + /// Send all common headers before our special progress headers. + startSendHeaders(); + writeHeaderProgress(); + } +} + +WriteBufferFromHTTPServerResponse::~WriteBufferFromHTTPServerResponse() +{ + finalize(); +} + +void WriteBufferFromHTTPServerResponse::finalizeImpl() +{ + try + { + next(); + if (out) + out->finalize(); + out.reset(); + /// Catch write-after-finalize bugs. + set(nullptr, 0); + } + catch (...) 
/// The difference from WriteBufferFromOStream is that this buffer gets the underlying std::ostream
/// (from response.beginSend(), or from response.send() for HEAD requests) only after data is flushed
/// for the first time. This is needed in HTTP servers to change some HTTP headers (e.g. response code)
/// before any data is sent to the client (headers can't be changed after they have been sent).
///
/// In short, it allows delaying the moment when the response headers are sent.
///
/// Additionally, supports HTTP response compression (in this case the corresponding Content-Encoding
/// header will be set).
///
/// Also this class writes and flushes special X-ClickHouse-Progress HTTP headers
/// if no data was sent at the time of progress notification.
/// This allows implementing a progress bar in HTTP clients.
class WriteBufferFromHTTPServerResponse final : public BufferWithOwnMemory<WriteBuffer>
{
public:
    WriteBufferFromHTTPServerResponse(
        HTTPServerResponse & response_,
        bool is_http_method_head_,
        size_t keep_alive_timeout_,
        bool compress_ = false,        /// If true - set Content-Encoding header and compress the result.
        CompressionMethod compression_method_ = CompressionMethod::None);

    /// Finalizes the buffer.
    ~WriteBufferFromHTTPServerResponse() override;

    /// Accumulates progress and, when sending progress is enabled and the configured interval has
    /// elapsed, writes it as an X-ClickHouse-Progress header.
    /// Does nothing once the headers have been finished.
    void onProgress(const Progress & progress);

    /// Turn compression on or off.
    /// The setting takes effect only if HTTP headers haven't been sent yet.
    void setCompression(bool enable_compression)
    {
        compress = enable_compression;
    }

    /// Set compression level if the compression is turned on.
    /// The setting takes effect only if HTTP headers haven't been sent yet.
    void setCompressionLevel(int level)
    {
        compression_level = level;
    }

    /// Turn CORS on or off.
    /// The setting takes effect only if HTTP headers haven't been sent yet.
    void addHeaderCORS(bool enable_cors)
    {
        add_cors_header = enable_cors;
    }

    /// Enable or disable sending of progress headers.
    void setSendProgress(bool send_progress_) { send_progress = send_progress_; }

    /// Minimum interval, in milliseconds, between two progress headers.
    void setSendProgressInterval(size_t send_progress_interval_ms_)
    {
        send_progress_interval_ms = send_progress_interval_ms_;
    }

    /// A non-zero code is written as the X-ClickHouse-Exception-Code header when headers are finished.
    void setExceptionCode(int exception_code_) { exception_code = exception_code_; }

private:
    /// Send at least HTTP headers if no data has been sent yet.
    /// Use after the data has possibly been sent and no error happened (and thus you do not plan
    /// to change the response HTTP code).
    /// This method is idempotent.
    void finalizeImpl() override;

    /// Must be called under locked mutex.
    /// This method sends headers, if this was not done already,
    /// but does not finish them with \r\n, allowing to send more headers subsequently.
    void startSendHeaders();

    /// Used to write the header X-ClickHouse-Progress / X-ClickHouse-Summary
    void writeHeaderProgressImpl(const char * header_name);
    /// Used to write the header X-ClickHouse-Progress
    void writeHeaderProgress();
    /// Used to write the header X-ClickHouse-Summary
    void writeHeaderSummary();
    /// Used to write the header X-ClickHouse-Exception-Code even when progress has been sent
    void writeExceptionCode();

    /// This method finishes headers with \r\n, allowing to start sending the body.
    void finishSendHeaders();

    void nextImpl() override;

    HTTPServerResponse & response;

    bool is_http_method_head;
    bool add_cors_header = false;
    size_t keep_alive_timeout = 0;
    bool compress = false;
    CompressionMethod compression_method;
    int compression_level = 1;

    std::shared_ptr<std::ostream> response_body_ostr;
    std::shared_ptr<std::ostream> response_header_ostr;

    /// Created on the first flush; forwards data to response_body_ostr.
    std::unique_ptr<WriteBuffer> out;
    bool initialized = false;

    bool headers_started_sending = false;
    bool headers_finished_sending = false;  /// If true, you could not add any headers.

    Progress accumulated_progress;
    bool send_progress = false;
    size_t send_progress_interval_ms = 100;
    Stopwatch progress_watch;

    int exception_code = 0;

    std::mutex mutex;  /// progress callback could be called from different threads.
};
+ extern const int CANNOT_PARSE_ESCAPE_SEQUENCE; + extern const int CANNOT_PARSE_QUOTED_STRING; + extern const int CANNOT_PARSE_DATE; + extern const int CANNOT_PARSE_DATETIME; + extern const int CANNOT_PARSE_NUMBER; + extern const int CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING; + extern const int CANNOT_PARSE_IPV4; + extern const int CANNOT_PARSE_IPV6; + extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED; + extern const int CANNOT_OPEN_FILE; + extern const int CANNOT_COMPILE_REGEXP; + + extern const int UNKNOWN_ELEMENT_IN_AST; + extern const int UNKNOWN_TYPE_OF_AST_NODE; + extern const int TOO_DEEP_AST; + extern const int TOO_BIG_AST; + extern const int UNEXPECTED_AST_STRUCTURE; + + extern const int SYNTAX_ERROR; + + extern const int INCORRECT_DATA; + extern const int TYPE_MISMATCH; + + extern const int UNKNOWN_TABLE; + extern const int UNKNOWN_FUNCTION; + extern const int UNKNOWN_IDENTIFIER; + extern const int UNKNOWN_TYPE; + extern const int UNKNOWN_STORAGE; + extern const int UNKNOWN_DATABASE; + extern const int UNKNOWN_SETTING; + extern const int UNKNOWN_DIRECTION_OF_SORTING; + extern const int UNKNOWN_AGGREGATE_FUNCTION; + extern const int UNKNOWN_FORMAT; + extern const int UNKNOWN_DATABASE_ENGINE; + extern const int UNKNOWN_TYPE_OF_QUERY; + extern const int NO_ELEMENTS_IN_CONFIG; + + extern const int QUERY_IS_TOO_LARGE; + + extern const int NOT_IMPLEMENTED; + extern const int SOCKET_TIMEOUT; + + extern const int UNKNOWN_USER; + extern const int WRONG_PASSWORD; + extern const int REQUIRED_PASSWORD; + extern const int AUTHENTICATION_FAILED; + + extern const int INVALID_SESSION_TIMEOUT; + extern const int HTTP_LENGTH_REQUIRED; + extern const int SUPPORT_IS_DISABLED; + + extern const int TIMEOUT_EXCEEDED; +} + +namespace +{ +bool tryAddHttpOptionHeadersFromConfig(HTTPServerResponse & response, const Poco::Util::LayeredConfiguration & config) +{ + if (config.has("http_options_response")) + { + Strings config_keys; + config.keys("http_options_response", config_keys); 
+ for (const std::string & config_key : config_keys) + { + if (config_key == "header" || config_key.starts_with("header[")) + { + /// If there is empty header name, it will not be processed and message about it will be in logs + if (config.getString("http_options_response." + config_key + ".name", "").empty()) + LOG_WARNING(&Poco::Logger::get("processOptionsRequest"), "Empty header was found in config. It will not be processed."); + else + response.add(config.getString("http_options_response." + config_key + ".name", ""), + config.getString("http_options_response." + config_key + ".value", "")); + + } + } + return true; + } + return false; +} + +/// Process options request. Useful for CORS. +void processOptionsRequest(HTTPServerResponse & response, const Poco::Util::LayeredConfiguration & config) +{ + /// If can add some headers from config + if (tryAddHttpOptionHeadersFromConfig(response, config)) + { + response.setKeepAlive(false); + response.setStatusAndReason(HTTPResponse::HTTP_NO_CONTENT); + response.send(); + } +} +} + +static String base64Decode(const String & encoded) +{ + String decoded; + Poco::MemoryInputStream istr(encoded.data(), encoded.size()); + Poco::Base64Decoder decoder(istr); + Poco::StreamCopier::copyToString(decoder, decoded); + return decoded; +} + +static String base64Encode(const String & decoded) +{ + std::ostringstream ostr; // STYLE_CHECK_ALLOW_STD_STRING_STREAM + ostr.exceptions(std::ios::failbit); + Poco::Base64Encoder encoder(ostr); + encoder.rdbuf()->setLineLength(0); + encoder << decoded; + encoder.close(); + return ostr.str(); +} + +static Poco::Net::HTTPResponse::HTTPStatus exceptionCodeToHTTPStatus(int exception_code) +{ + using namespace Poco::Net; + + if (exception_code == ErrorCodes::REQUIRED_PASSWORD) + { + return HTTPResponse::HTTP_UNAUTHORIZED; + } + else if (exception_code == ErrorCodes::UNKNOWN_USER || + exception_code == ErrorCodes::WRONG_PASSWORD || + exception_code == ErrorCodes::AUTHENTICATION_FAILED) + { + return 
HTTPResponse::HTTP_FORBIDDEN; + } + else if (exception_code == ErrorCodes::CANNOT_PARSE_TEXT || + exception_code == ErrorCodes::CANNOT_PARSE_ESCAPE_SEQUENCE || + exception_code == ErrorCodes::CANNOT_PARSE_QUOTED_STRING || + exception_code == ErrorCodes::CANNOT_PARSE_DATE || + exception_code == ErrorCodes::CANNOT_PARSE_DATETIME || + exception_code == ErrorCodes::CANNOT_PARSE_NUMBER || + exception_code == ErrorCodes::CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING || + exception_code == ErrorCodes::CANNOT_PARSE_IPV4 || + exception_code == ErrorCodes::CANNOT_PARSE_IPV6 || + exception_code == ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED || + exception_code == ErrorCodes::UNKNOWN_ELEMENT_IN_AST || + exception_code == ErrorCodes::UNKNOWN_TYPE_OF_AST_NODE || + exception_code == ErrorCodes::TOO_DEEP_AST || + exception_code == ErrorCodes::TOO_BIG_AST || + exception_code == ErrorCodes::UNEXPECTED_AST_STRUCTURE || + exception_code == ErrorCodes::SYNTAX_ERROR || + exception_code == ErrorCodes::INCORRECT_DATA || + exception_code == ErrorCodes::TYPE_MISMATCH) + { + return HTTPResponse::HTTP_BAD_REQUEST; + } + else if (exception_code == ErrorCodes::UNKNOWN_TABLE || + exception_code == ErrorCodes::UNKNOWN_FUNCTION || + exception_code == ErrorCodes::UNKNOWN_IDENTIFIER || + exception_code == ErrorCodes::UNKNOWN_TYPE || + exception_code == ErrorCodes::UNKNOWN_STORAGE || + exception_code == ErrorCodes::UNKNOWN_DATABASE || + exception_code == ErrorCodes::UNKNOWN_SETTING || + exception_code == ErrorCodes::UNKNOWN_DIRECTION_OF_SORTING || + exception_code == ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION || + exception_code == ErrorCodes::UNKNOWN_FORMAT || + exception_code == ErrorCodes::UNKNOWN_DATABASE_ENGINE || + exception_code == ErrorCodes::UNKNOWN_TYPE_OF_QUERY) + { + return HTTPResponse::HTTP_NOT_FOUND; + } + else if (exception_code == ErrorCodes::QUERY_IS_TOO_LARGE) + { + return HTTPResponse::HTTP_REQUESTENTITYTOOLARGE; + } + else if (exception_code == ErrorCodes::NOT_IMPLEMENTED) + { + return 
HTTPResponse::HTTP_NOT_IMPLEMENTED; + } + else if (exception_code == ErrorCodes::SOCKET_TIMEOUT || + exception_code == ErrorCodes::CANNOT_OPEN_FILE) + { + return HTTPResponse::HTTP_SERVICE_UNAVAILABLE; + } + else if (exception_code == ErrorCodes::HTTP_LENGTH_REQUIRED) + { + return HTTPResponse::HTTP_LENGTH_REQUIRED; + } + else if (exception_code == ErrorCodes::TIMEOUT_EXCEEDED) + { + return HTTPResponse::HTTP_REQUEST_TIMEOUT; + } + + return HTTPResponse::HTTP_INTERNAL_SERVER_ERROR; +} + + +static std::chrono::steady_clock::duration parseSessionTimeout( + const Poco::Util::AbstractConfiguration & config, + const HTMLForm & params) +{ + unsigned session_timeout = config.getInt("default_session_timeout", 60); + + if (params.has("session_timeout")) + { + unsigned max_session_timeout = config.getUInt("max_session_timeout", 3600); + std::string session_timeout_str = params.get("session_timeout"); + + ReadBufferFromString buf(session_timeout_str); + if (!tryReadIntText(session_timeout, buf) || !buf.eof()) + throw Exception(ErrorCodes::INVALID_SESSION_TIMEOUT, "Invalid session timeout: '{}'", session_timeout_str); + + if (session_timeout > max_session_timeout) + throw Exception(ErrorCodes::INVALID_SESSION_TIMEOUT, "Session timeout '{}' is larger than max_session_timeout: {}. 
" + "Maximum session timeout could be modified in configuration file.", + session_timeout_str, max_session_timeout); + } + + return std::chrono::seconds(session_timeout); +} + + +void HTTPHandler::pushDelayedResults(Output & used_output) +{ + std::vector<WriteBufferPtr> write_buffers; + ConcatReadBuffer::Buffers read_buffers; + + auto * cascade_buffer = typeid_cast<CascadeWriteBuffer *>(used_output.out_maybe_delayed_and_compressed.get()); + if (!cascade_buffer) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected CascadeWriteBuffer"); + + cascade_buffer->getResultBuffers(write_buffers); + + if (write_buffers.empty()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "At least one buffer is expected to overwrite result into HTTP response"); + + for (auto & write_buf : write_buffers) + { + if (!write_buf) + continue; + + IReadableWriteBuffer * write_buf_concrete = dynamic_cast<IReadableWriteBuffer *>(write_buf.get()); + if (write_buf_concrete) + { + ReadBufferPtr reread_buf = write_buf_concrete->tryGetReadBuffer(); + if (reread_buf) + read_buffers.emplace_back(wrapReadBufferPointer(reread_buf)); + } + } + + if (!read_buffers.empty()) + { + ConcatReadBuffer concat_read_buffer(std::move(read_buffers)); + copyData(concat_read_buffer, *used_output.out_maybe_compressed); + } +} + + +HTTPHandler::HTTPHandler(IServer & server_, const std::string & name, const std::optional<String> & content_type_override_) + : server(server_) + , log(&Poco::Logger::get(name)) + , default_settings(server.context()->getSettingsRef()) + , content_type_override(content_type_override_) +{ + server_display_name = server.config().getString("display_name", getFQDNOrHostName()); +} + + +/// We need d-tor to be present in this translation unit to make it play well with some +/// forward decls in the header. Other than that, the default d-tor would be OK. 
+HTTPHandler::~HTTPHandler() = default; + + +bool HTTPHandler::authenticateUser( + HTTPServerRequest & request, + HTMLForm & params, + HTTPServerResponse & response) +{ + using namespace Poco::Net; + + /// The user and password can be passed by headers (similar to X-Auth-*), + /// which is used by load balancers to pass authentication information. + std::string user = request.get("X-ClickHouse-User", ""); + std::string password = request.get("X-ClickHouse-Key", ""); + std::string quota_key = request.get("X-ClickHouse-Quota", ""); + + /// The header 'X-ClickHouse-SSL-Certificate-Auth: on' enables checking the common name + /// extracted from the SSL certificate used for this connection instead of checking password. + bool has_ssl_certificate_auth = (request.get("X-ClickHouse-SSL-Certificate-Auth", "") == "on"); + bool has_auth_headers = !user.empty() || !password.empty() || !quota_key.empty() || has_ssl_certificate_auth; + + /// User name and password can be passed using HTTP Basic auth or query parameters + /// (both methods are insecure). + bool has_http_credentials = request.hasCredentials(); + bool has_credentials_in_query_params = params.has("user") || params.has("password") || params.has("quota_key"); + + std::string spnego_challenge; + std::string certificate_common_name; + + if (has_auth_headers) + { + /// It is prohibited to mix different authorization schemes. 
+ if (has_http_credentials) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, + "Invalid authentication: it is not allowed " + "to use SSL certificate authentication and Authorization HTTP header simultaneously"); + if (has_credentials_in_query_params) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, + "Invalid authentication: it is not allowed " + "to use SSL certificate authentication and authentication via parameters simultaneously simultaneously"); + + if (has_ssl_certificate_auth) + { +#if USE_SSL + if (!password.empty()) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, + "Invalid authentication: it is not allowed " + "to use SSL certificate authentication and authentication via password simultaneously"); + + if (request.havePeerCertificate()) + certificate_common_name = request.peerCertificate().commonName(); + + if (certificate_common_name.empty()) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, + "Invalid authentication: SSL certificate authentication requires nonempty certificate's Common Name"); +#else + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, + "SSL certificate authentication disabled because ClickHouse was built without SSL library"); +#endif + } + } + else if (has_http_credentials) + { + /// It is prohibited to mix different authorization schemes. 
+ if (has_credentials_in_query_params) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, + "Invalid authentication: it is not allowed " + "to use Authorization HTTP header and authentication via parameters simultaneously"); + + std::string scheme; + std::string auth_info; + request.getCredentials(scheme, auth_info); + + if (Poco::icompare(scheme, "Basic") == 0) + { + HTTPBasicCredentials credentials(auth_info); + user = credentials.getUsername(); + password = credentials.getPassword(); + } + else if (Poco::icompare(scheme, "Negotiate") == 0) + { + spnego_challenge = auth_info; + + if (spnego_challenge.empty()) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: SPNEGO challenge is empty"); + } + else + { + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: '{}' HTTP Authorization scheme is not supported", scheme); + } + + quota_key = params.get("quota_key", ""); + } + else + { + /// If the user name is not set we assume it's the 'default' user. 
+ user = params.get("user", "default"); + password = params.get("password", ""); + quota_key = params.get("quota_key", ""); + } + + if (!certificate_common_name.empty()) + { + if (!request_credentials) + request_credentials = std::make_unique<SSLCertificateCredentials>(user, certificate_common_name); + + auto * certificate_credentials = dynamic_cast<SSLCertificateCredentials *>(request_credentials.get()); + if (!certificate_credentials) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: expected SSL certificate authorization scheme"); + } + else if (!spnego_challenge.empty()) + { + if (!request_credentials) + request_credentials = server.context()->makeGSSAcceptorContext(); + + auto * gss_acceptor_context = dynamic_cast<GSSAcceptorContext *>(request_credentials.get()); + if (!gss_acceptor_context) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: unexpected 'Negotiate' HTTP Authorization scheme expected"); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunreachable-code" + const auto spnego_response = base64Encode(gss_acceptor_context->processToken(base64Decode(spnego_challenge), log)); +#pragma clang diagnostic pop + + if (!spnego_response.empty()) + response.set("WWW-Authenticate", "Negotiate " + spnego_response); + + if (!gss_acceptor_context->isFailed() && !gss_acceptor_context->isReady()) + { + if (spnego_response.empty()) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: 'Negotiate' HTTP Authorization failure"); + + response.setStatusAndReason(HTTPResponse::HTTP_UNAUTHORIZED); + response.send(); + return false; + } + } + else // I.e., now using user name and password strings ("Basic"). 
+ { + if (!request_credentials) + request_credentials = std::make_unique<BasicCredentials>(); + + auto * basic_credentials = dynamic_cast<BasicCredentials *>(request_credentials.get()); + if (!basic_credentials) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Invalid authentication: expected 'Basic' HTTP Authorization scheme"); + + basic_credentials->setUserName(user); + basic_credentials->setPassword(password); + } + + /// Set client info. It will be used for quota accounting parameters in 'setUser' method. + + ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; + if (request.getMethod() == HTTPServerRequest::HTTP_GET) + http_method = ClientInfo::HTTPMethod::GET; + else if (request.getMethod() == HTTPServerRequest::HTTP_POST) + http_method = ClientInfo::HTTPMethod::POST; + + session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", "")); + session->setForwardedFor(request.get("X-Forwarded-For", "")); + session->setQuotaClientKey(quota_key); + + /// Extract the last entry from comma separated list of forwarded_for addresses. + /// Only the last proxy can be trusted (if any). 
+ String forwarded_address = session->getClientInfo().getLastForwardedFor(); + try + { + if (!forwarded_address.empty() && server.config().getBool("auth_use_forwarded_address", false)) + session->authenticate(*request_credentials, Poco::Net::SocketAddress(forwarded_address, request.clientAddress().port())); + else + session->authenticate(*request_credentials, request.clientAddress()); + } + catch (const Authentication::Require<BasicCredentials> & required_credentials) + { + request_credentials = std::make_unique<BasicCredentials>(); + + if (required_credentials.getRealm().empty()) + response.set("WWW-Authenticate", "Basic"); + else + response.set("WWW-Authenticate", "Basic realm=\"" + required_credentials.getRealm() + "\""); + + response.setStatusAndReason(HTTPResponse::HTTP_UNAUTHORIZED); + response.send(); + return false; + } + catch (const Authentication::Require<GSSAcceptorContext> & required_credentials) + { + request_credentials = server.context()->makeGSSAcceptorContext(); + + if (required_credentials.getRealm().empty()) + response.set("WWW-Authenticate", "Negotiate"); + else + response.set("WWW-Authenticate", "Negotiate realm=\"" + required_credentials.getRealm() + "\""); + + response.setStatusAndReason(HTTPResponse::HTTP_UNAUTHORIZED); + response.send(); + return false; + } + + request_credentials.reset(); + return true; +} + + +void HTTPHandler::processQuery( + HTTPServerRequest & request, + HTMLForm & params, + HTTPServerResponse & response, + Output & used_output, + std::optional<CurrentThread::QueryScope> & query_scope) +{ + using namespace Poco::Net; + + LOG_TRACE(log, "Request URI: {}", request.getURI()); + + if (!authenticateUser(request, params, response)) + return; // '401 Unauthorized' response with 'Negotiate' has been sent at this point. + + /// The user could specify session identifier and session timeout. + /// It allows to modify settings, create temporary tables and reuse them in subsequent requests. 
+ String session_id; + std::chrono::steady_clock::duration session_timeout; + bool session_is_set = params.has("session_id"); + const auto & config = server.config(); + + if (session_is_set) + { + session_id = params.get("session_id"); + session_timeout = parseSessionTimeout(config, params); + std::string session_check = params.get("session_check", ""); + session->makeSessionContext(session_id, session_timeout, session_check == "1"); + } + else + { + /// We should create it even if we don't have a session_id + session->makeSessionContext(); + } + + auto context = session->makeQueryContext(); + + /// This parameter is used to tune the behavior of output formats (such as Native) for compatibility. + if (params.has("client_protocol_version")) + { + UInt64 version_param = parse<UInt64>(params.get("client_protocol_version")); + context->setClientProtocolVersion(version_param); + } + + /// The client can pass a HTTP header indicating supported compression method (gzip or deflate). + String http_response_compression_methods = request.get("Accept-Encoding", ""); + CompressionMethod http_response_compression_method = CompressionMethod::None; + + if (!http_response_compression_methods.empty()) + http_response_compression_method = chooseHTTPCompressionMethod(http_response_compression_methods); + + bool client_supports_http_compression = http_response_compression_method != CompressionMethod::None; + + /// Client can pass a 'compress' flag in the query string. In this case the query result is + /// compressed using internal algorithm. This is not reflected in HTTP headers. + bool internal_compression = params.getParsed<bool>("compress", false); + + /// At least, we should postpone sending of first buffer_size result bytes + size_t buffer_size_total = std::max( + params.getParsed<size_t>("buffer_size", context->getSettingsRef().http_response_buffer_size), + static_cast<size_t>(DBMS_DEFAULT_BUFFER_SIZE)); + + /// If it is specified, the whole result will be buffered. 
+ /// First ~buffer_size bytes will be buffered in memory, the remaining bytes will be stored in temporary file. + bool buffer_until_eof = params.getParsed<bool>("wait_end_of_query", context->getSettingsRef().http_wait_end_of_query); + + size_t buffer_size_http = DBMS_DEFAULT_BUFFER_SIZE; + size_t buffer_size_memory = (buffer_size_total > buffer_size_http) ? buffer_size_total : 0; + + unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10); + + used_output.out = std::make_shared<WriteBufferFromHTTPServerResponse>( + response, + request.getMethod() == HTTPRequest::HTTP_HEAD, + keep_alive_timeout, + client_supports_http_compression, + http_response_compression_method); + + if (internal_compression) + used_output.out_maybe_compressed = std::make_shared<CompressedWriteBuffer>(*used_output.out); + else + used_output.out_maybe_compressed = used_output.out; + + if (buffer_size_memory > 0 || buffer_until_eof) + { + CascadeWriteBuffer::WriteBufferPtrs cascade_buffer1; + CascadeWriteBuffer::WriteBufferConstructors cascade_buffer2; + + if (buffer_size_memory > 0) + cascade_buffer1.emplace_back(std::make_shared<MemoryWriteBuffer>(buffer_size_memory)); + + if (buffer_until_eof) + { + auto tmp_data = std::make_shared<TemporaryDataOnDisk>(server.context()->getTempDataOnDisk()); + + auto create_tmp_disk_buffer = [tmp_data] (const WriteBufferPtr &) -> WriteBufferPtr + { + return tmp_data->createRawStream(); + }; + + cascade_buffer2.emplace_back(std::move(create_tmp_disk_buffer)); + } + else + { + auto push_memory_buffer_and_continue = [next_buffer = used_output.out_maybe_compressed] (const WriteBufferPtr & prev_buf) + { + auto * prev_memory_buffer = typeid_cast<MemoryWriteBuffer *>(prev_buf.get()); + if (!prev_memory_buffer) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected MemoryWriteBuffer"); + + auto rdbuf = prev_memory_buffer->tryGetReadBuffer(); + copyData(*rdbuf, *next_buffer); + + return next_buffer; + }; + + 
cascade_buffer2.emplace_back(push_memory_buffer_and_continue); + } + + used_output.out_maybe_delayed_and_compressed = std::make_shared<CascadeWriteBuffer>( + std::move(cascade_buffer1), std::move(cascade_buffer2)); + } + else + { + used_output.out_maybe_delayed_and_compressed = used_output.out_maybe_compressed; + } + + /// Request body can be compressed using algorithm specified in the Content-Encoding header. + String http_request_compression_method_str = request.get("Content-Encoding", ""); + int zstd_window_log_max = static_cast<int>(context->getSettingsRef().zstd_window_log_max); + auto in_post = wrapReadBufferWithCompressionMethod( + wrapReadBufferReference(request.getStream()), + chooseCompressionMethod({}, http_request_compression_method_str), zstd_window_log_max); + + /// The data can also be compressed using incompatible internal algorithm. This is indicated by + /// 'decompress' query parameter. + std::unique_ptr<ReadBuffer> in_post_maybe_compressed; + bool in_post_compressed = false; + if (params.getParsed<bool>("decompress", false)) + { + in_post_maybe_compressed = std::make_unique<CompressedReadBuffer>(*in_post); + in_post_compressed = true; + } + else + in_post_maybe_compressed = std::move(in_post); + + std::unique_ptr<ReadBuffer> in; + + static const NameSet reserved_param_names{"compress", "decompress", "user", "password", "quota_key", "query_id", "stacktrace", + "buffer_size", "wait_end_of_query", "session_id", "session_timeout", "session_check", "client_protocol_version", "close_session"}; + + Names reserved_param_suffixes; + + auto param_could_be_skipped = [&] (const String & name) + { + /// Empty parameter appears when URL like ?&a=b or a=b&&c=d. Just skip them for user's convenience. 
+ if (name.empty()) + return true; + + if (reserved_param_names.contains(name)) + return true; + + for (const String & suffix : reserved_param_suffixes) + { + if (endsWith(name, suffix)) + return true; + } + + return false; + }; + + /// Settings can be overridden in the query. + /// Some parameters (database, default_format, everything used in the code above) do not + /// belong to the Settings class. + + /// 'readonly' setting values mean: + /// readonly = 0 - any query is allowed, client can change any setting. + /// readonly = 1 - only readonly queries are allowed, client can't change settings. + /// readonly = 2 - only readonly queries are allowed, client can change any setting except 'readonly'. + + /// In theory if initially readonly = 0, the client can change any setting and then set readonly + /// to some other value. + const auto & settings = context->getSettingsRef(); + + /// Only readonly queries are allowed for HTTP GET requests. + if (request.getMethod() == HTTPServerRequest::HTTP_GET) + { + if (settings.readonly == 0) + context->setSetting("readonly", 2); + } + + bool has_external_data = startsWith(request.getContentType(), "multipart/form-data"); + + if (has_external_data) + { + /// Skip unneeded parameters to avoid confusing them later with context settings or query parameters. + reserved_param_suffixes.reserve(3); + /// It is a bug and ambiguity with `date_time_input_format` and `low_cardinality_allow_in_native_format` formats/settings. 
+ reserved_param_suffixes.emplace_back("_format"); + reserved_param_suffixes.emplace_back("_types"); + reserved_param_suffixes.emplace_back("_structure"); + } + + std::string database = request.get("X-ClickHouse-Database", ""); + std::string default_format = request.get("X-ClickHouse-Format", ""); + + SettingsChanges settings_changes; + for (const auto & [key, value] : params) + { + if (key == "database") + { + if (database.empty()) + database = value; + } + else if (key == "default_format") + { + if (default_format.empty()) + default_format = value; + } + else if (param_could_be_skipped(key)) + { + } + else + { + /// Other than query parameters are treated as settings. + if (!customizeQueryParam(context, key, value)) + settings_changes.push_back({key, value}); + } + } + + if (!database.empty()) + context->setCurrentDatabase(database); + + if (!default_format.empty()) + context->setDefaultFormat(default_format); + + /// For external data we also want settings + context->checkSettingsConstraints(settings_changes, SettingSource::QUERY); + context->applySettingsChanges(settings_changes); + + /// Set the query id supplied by the user, if any, and also update the OpenTelemetry fields. + context->setCurrentQueryId(params.get("query_id", request.get("X-ClickHouse-Query-Id", ""))); + + /// Initialize query scope, once query_id is initialized. + /// (To track as much allocations as possible) + query_scope.emplace(context); + + /// NOTE: this may create pretty huge allocations that will not be accounted in trace_log, + /// because memory_profiler_sample_probability/memory_profiler_step are not applied yet, + /// they will be applied in ProcessList::insert() from executeQuery() itself. 
+ const auto & query = getQuery(request, params, context); + std::unique_ptr<ReadBuffer> in_param = std::make_unique<ReadBufferFromString>(query); + + /// HTTP response compression is turned on only if the client signalled that they support it + /// (using Accept-Encoding header) and 'enable_http_compression' setting is turned on. + used_output.out->setCompression(client_supports_http_compression && settings.enable_http_compression); + if (client_supports_http_compression) + used_output.out->setCompressionLevel(static_cast<int>(settings.http_zlib_compression_level)); + + used_output.out->setSendProgress(settings.send_progress_in_http_headers); + used_output.out->setSendProgressInterval(settings.http_headers_progress_interval_ms); + + /// If 'http_native_compression_disable_checksumming_on_decompress' setting is turned on, + /// checksums of client data compressed with internal algorithm are not checked. + if (in_post_compressed && settings.http_native_compression_disable_checksumming_on_decompress) + static_cast<CompressedReadBuffer &>(*in_post_maybe_compressed).disableChecksumming(); + + /// Add CORS header if 'add_http_cors_header' setting is turned on send * in Access-Control-Allow-Origin + /// Note that whether the header is added is determined by the settings, and we can only get the user settings after authentication. + /// Once the authentication fails, the header can't be added. + if (settings.add_http_cors_header && !request.get("Origin", "").empty() && !config.has("http_options_response")) + used_output.out->addHeaderCORS(true); + + auto append_callback = [my_context = context] (ProgressCallback callback) + { + auto prev = my_context->getProgressCallback(); + + my_context->setProgressCallback([prev, callback] (const Progress & progress) + { + if (prev) + prev(progress); + + callback(progress); + }); + }; + + /// While still no data has been sent, we will report about query execution progress by sending HTTP headers. 
+ /// Note that we add it unconditionally so the progress is available for `X-ClickHouse-Summary` + append_callback([&used_output](const Progress & progress) + { + used_output.out->onProgress(progress); + }); + + if (settings.readonly > 0 && settings.cancel_http_readonly_queries_on_client_close) + { + append_callback([&context, &request](const Progress &) + { + /// Assume that at the point this method is called no one is reading data from the socket any more: + /// should be true for read-only queries. + if (!request.checkPeerConnected()) + context->killCurrentQuery(); + }); + } + + customizeContext(request, context, *in_post_maybe_compressed); + in = has_external_data ? std::move(in_param) : std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed); + + executeQuery(*in, *used_output.out_maybe_delayed_and_compressed, /* allow_into_outfile = */ false, context, + [&response, this] (const QueryResultDetails & details) + { + response.add("X-ClickHouse-Query-Id", details.query_id); + + if (content_type_override) + response.setContentType(*content_type_override); + else if (details.content_type) + response.setContentType(*details.content_type); + + if (details.format) + response.add("X-ClickHouse-Format", *details.format); + + if (details.timezone) + response.add("X-ClickHouse-Timezone", *details.timezone); + } + ); + + if (used_output.hasDelayed()) + { + /// TODO: set Content-Length if possible + pushDelayedResults(used_output); + } + + /// Send HTTP headers with code 200 if no exception happened and the data is still not sent to the client. 
+ used_output.finalize(); +} + +void HTTPHandler::trySendExceptionToClient( + const std::string & s, int exception_code, HTTPServerRequest & request, HTTPServerResponse & response, Output & used_output) +try +{ + /// In case data has already been sent, like progress headers, try using the output buffer to + /// set the exception code since it will be able to append it if it hasn't finished writing headers + if (response.sent() && used_output.out) + used_output.out->setExceptionCode(exception_code); + else + response.set("X-ClickHouse-Exception-Code", toString<int>(exception_code)); + + /// FIXME: make sure that no one else is reading from the same stream at the moment. + + /// If HTTP method is POST and Keep-Alive is turned on, we should read the whole request body + /// to avoid reading part of the current request body in the next request. + if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST + && response.getKeepAlive() + && exception_code != ErrorCodes::HTTP_LENGTH_REQUIRED + && !request.getStream().eof()) + { + request.getStream().ignoreAll(); + } + + if (exception_code == ErrorCodes::REQUIRED_PASSWORD) + { + response.requireAuthentication("ClickHouse server HTTP API"); + } + else + { + response.setStatusAndReason(exceptionCodeToHTTPStatus(exception_code)); + } + + if (!response.sent() && !used_output.out_maybe_compressed) + { + /// If nothing was sent yet and we don't even know if we must compress the response. 
+ *response.send() << s << std::endl; + } + else if (used_output.out_maybe_compressed) + { + /// Destroy CascadeBuffer to actualize buffers' positions and reset extra references + if (used_output.hasDelayed()) + { + /// do not call finalize here for CascadeWriteBuffer used_output.out_maybe_delayed_and_compressed, + /// exception is written into used_output.out_maybe_compressed later + /// HTTPHandler::trySendExceptionToClient is called with exception context, it is Ok to destroy buffers + used_output.out_maybe_delayed_and_compressed.reset(); + } + + /// Send the error message into already used (and possibly compressed) stream. + /// Note that the error message will possibly be sent after some data. + /// Also HTTP code 200 could have already been sent. + + /// If buffer has data, and that data wasn't sent yet, then no need to send that data + bool data_sent = used_output.out->count() != used_output.out->offset(); + + if (!data_sent) + { + used_output.out_maybe_compressed->position() = used_output.out_maybe_compressed->buffer().begin(); + used_output.out->position() = used_output.out->buffer().begin(); + } + + writeString(s, *used_output.out_maybe_compressed); + writeChar('\n', *used_output.out_maybe_compressed); + + used_output.out_maybe_compressed->next(); + } + else + { + UNREACHABLE(); + } + + used_output.finalize(); +} +catch (...) +{ + tryLogCurrentException(log, "Cannot send exception to client"); + + try + { + used_output.finalize(); + } + catch (...) 
+ { + tryLogCurrentException(log, "Cannot flush data to client (after sending exception)"); + } +} + + +void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) +{ + setThreadName("HTTPHandler"); + ThreadStatus thread_status; + + session = std::make_unique<Session>(server.context(), ClientInfo::Interface::HTTP, request.isSecure()); + SCOPE_EXIT({ session.reset(); }); + std::optional<CurrentThread::QueryScope> query_scope; + + Output used_output; + + /// In case of exception, send stack trace to client. + bool with_stacktrace = false; + /// Close http session (if any) after processing the request + bool close_session = false; + String session_id; + + SCOPE_EXIT_SAFE({ + if (close_session && !session_id.empty()) + session->closeSession(session_id); + }); + + OpenTelemetry::TracingContextHolderPtr thread_trace_context; + SCOPE_EXIT({ + // make sure the response status is recorded + if (thread_trace_context) + thread_trace_context->root_span.addAttribute("clickhouse.http_status", response.getStatus()); + }); + + try + { + if (request.getMethod() == HTTPServerRequest::HTTP_OPTIONS) + { + processOptionsRequest(response, server.config()); + return; + } + + // Parse the OpenTelemetry traceparent header. 
+ auto & client_trace_context = session->getClientTraceContext(); + if (request.has("traceparent")) + { + std::string opentelemetry_traceparent = request.get("traceparent"); + std::string error; + if (!client_trace_context.parseTraceparentHeader(opentelemetry_traceparent, error)) + { + LOG_DEBUG(log, "Failed to parse OpenTelemetry traceparent header '{}': {}", opentelemetry_traceparent, error); + } + client_trace_context.tracestate = request.get("tracestate", ""); + } + + // Setup tracing context for this thread + auto context = session->sessionOrGlobalContext(); + thread_trace_context = std::make_unique<OpenTelemetry::TracingContextHolder>("HTTPHandler", + client_trace_context, + context->getSettingsRef(), + context->getOpenTelemetrySpanLog()); + thread_trace_context->root_span.kind = OpenTelemetry::SERVER; + thread_trace_context->root_span.addAttribute("clickhouse.uri", request.getURI()); + + response.setContentType("text/plain; charset=UTF-8"); + response.set("X-ClickHouse-Server-Display-Name", server_display_name); + + if (!request.get("Origin", "").empty()) + tryAddHttpOptionHeadersFromConfig(response, server.config()); + + /// For keep-alive to work. + if (request.getVersion() == HTTPServerRequest::HTTP_1_1) + response.setChunkedTransferEncoding(true); + + HTMLForm params(default_settings, request); + with_stacktrace = params.getParsed<bool>("stacktrace", false); + close_session = params.getParsed<bool>("close_session", false); + if (close_session) + session_id = params.get("session_id"); + + /// FIXME: maybe this check is already unnecessary. + /// Workaround. Poco does not detect 411 Length Required case. 
+ if (request.getMethod() == HTTPRequest::HTTP_POST && !request.getChunkedTransferEncoding() && !request.hasContentLength()) + { + throw Exception(ErrorCodes::HTTP_LENGTH_REQUIRED, + "The Transfer-Encoding is not chunked and there " + "is no Content-Length header for POST request"); + } + + processQuery(request, params, response, used_output, query_scope); + if (request_credentials) + LOG_DEBUG(log, "Authentication in progress..."); + else + LOG_DEBUG(log, "Done processing query"); + } + catch (...) + { + SCOPE_EXIT({ + request_credentials.reset(); // ...so that the next requests on the connection have to always start afresh in case of exceptions. + }); + + /// Check if exception was thrown in used_output.finalize(). + /// In this case used_output can be in invalid state and we + /// cannot write in it anymore. So, just log this exception. + if (used_output.isFinalized()) + { + if (thread_trace_context) + thread_trace_context->root_span.addAttribute("clickhouse.exception", "Cannot flush data to client"); + + tryLogCurrentException(log, "Cannot flush data to client"); + return; + } + + tryLogCurrentException(log); + + /** If exception is received from remote server, then stack trace is embedded in message. + * If exception is thrown on local server, then stack trace is in separate field. 
+ */ + ExecutionStatus status = ExecutionStatus::fromCurrentException("", with_stacktrace); + trySendExceptionToClient(status.message, status.code, request, response, used_output); + + if (thread_trace_context) + thread_trace_context->root_span.addAttribute(status); + } + + used_output.finalize(); +} + +DynamicQueryHandler::DynamicQueryHandler(IServer & server_, const std::string & param_name_, const std::optional<String>& content_type_override_) + : HTTPHandler(server_, "DynamicQueryHandler", content_type_override_), param_name(param_name_) +{ +} + +bool DynamicQueryHandler::customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value) +{ + if (key == param_name) + return true; /// do nothing + + if (startsWith(key, QUERY_PARAMETER_NAME_PREFIX)) + { + /// Save name and values of substitution in dictionary. + const String parameter_name = key.substr(strlen(QUERY_PARAMETER_NAME_PREFIX)); + + if (!context->getQueryParameters().contains(parameter_name)) + context->setQueryParameter(parameter_name, value); + return true; + } + + return false; +} + +std::string DynamicQueryHandler::getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context) +{ + if (likely(!startsWith(request.getContentType(), "multipart/form-data"))) + { + /// Part of the query can be passed in the 'query' parameter and the rest in the request body + /// (http method need not necessarily be POST). In this case the entire query consists of the + /// contents of the 'query' parameter, a line break and the request body. + std::string query_param = params.get(param_name, ""); + return query_param.empty() ? query_param : query_param + "\n"; + } + + /// Support for "external data for query processing". + /// Used in case of POST request with form-data, but it isn't expected to be deleted after that scope. 
+ ExternalTablesHandler handler(context, params); + params.load(request, request.getStream(), handler); + + std::string full_query; + /// Params are of both form params POST and uri (GET params) + for (const auto & it : params) + { + if (it.first == param_name) + { + full_query += it.second; + } + else + { + customizeQueryParam(context, it.first, it.second); + } + } + + return full_query; +} + +PredefinedQueryHandler::PredefinedQueryHandler( + IServer & server_, + const NameSet & receive_params_, + const std::string & predefined_query_, + const CompiledRegexPtr & url_regex_, + const std::unordered_map<String, CompiledRegexPtr> & header_name_with_regex_, + const std::optional<String> & content_type_override_) + : HTTPHandler(server_, "PredefinedQueryHandler", content_type_override_) + , receive_params(receive_params_) + , predefined_query(predefined_query_) + , url_regex(url_regex_) + , header_name_with_capture_regex(header_name_with_regex_) +{ +} + +bool PredefinedQueryHandler::customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value) +{ + if (receive_params.contains(key)) + { + context->setQueryParameter(key, value); + return true; + } + + return false; +} + +void PredefinedQueryHandler::customizeContext(HTTPServerRequest & request, ContextMutablePtr context, ReadBuffer & body) +{ + /// If in the configuration file, the handler's header is regex and contains named capture group + /// We will extract regex named capture groups as query parameters + + const auto & set_query_params = [&](const char * begin, const char * end, const CompiledRegexPtr & compiled_regex) + { + int num_captures = compiled_regex->NumberOfCapturingGroups() + 1; + + std::string_view matches[num_captures]; + std::string_view input(begin, end - begin); + if (compiled_regex->Match(input, 0, end - begin, re2::RE2::Anchor::ANCHOR_BOTH, matches, num_captures)) + { + for (const auto & [capturing_name, capturing_index] : compiled_regex->NamedCapturingGroups()) 
+ { + const auto & capturing_value = matches[capturing_index]; + + if (capturing_value.data()) + context->setQueryParameter(capturing_name, String(capturing_value.data(), capturing_value.size())); + } + } + }; + + if (url_regex) + { + const auto & uri = request.getURI(); + set_query_params(uri.data(), find_first_symbols<'?'>(uri.data(), uri.data() + uri.size()), url_regex); + } + + for (const auto & [header_name, regex] : header_name_with_capture_regex) + { + const auto & header_value = request.get(header_name); + set_query_params(header_value.data(), header_value.data() + header_value.size(), regex); + } + + if (unlikely(receive_params.contains("_request_body") && !context->getQueryParameters().contains("_request_body"))) + { + WriteBufferFromOwnString value; + const auto & settings = context->getSettingsRef(); + + copyDataMaxBytes(body, value, settings.http_max_request_param_data_size); + context->setQueryParameter("_request_body", value.str()); + } +} + +std::string PredefinedQueryHandler::getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context) +{ + if (unlikely(startsWith(request.getContentType(), "multipart/form-data"))) + { + /// Support for "external data for query processing". 
+ ExternalTablesHandler handler(context, params); + params.load(request, request.getStream(), handler); + } + + return predefined_query; +} + +HTTPRequestHandlerFactoryPtr createDynamicHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix) +{ + auto query_param_name = config.getString(config_prefix + ".handler.query_param_name", "query"); + + std::optional<String> content_type_override; + if (config.has(config_prefix + ".handler.content_type")) + content_type_override = config.getString(config_prefix + ".handler.content_type"); + + auto factory = std::make_shared<HandlingRuleHTTPHandlerFactory<DynamicQueryHandler>>( + server, std::move(query_param_name), std::move(content_type_override)); + + factory->addFiltersFromConfig(config, config_prefix); + + return factory; +} + +static inline bool capturingNamedQueryParam(NameSet receive_params, const CompiledRegexPtr & compiled_regex) +{ + const auto & capturing_names = compiled_regex->NamedCapturingGroups(); + return std::count_if(capturing_names.begin(), capturing_names.end(), [&](const auto & iterator) + { + return std::count_if(receive_params.begin(), receive_params.end(), + [&](const auto & param_name) { return param_name == iterator.first; }); + }); +} + +static inline CompiledRegexPtr getCompiledRegex(const std::string & expression) +{ + auto compiled_regex = std::make_shared<const re2::RE2>(expression); + + if (!compiled_regex->ok()) + throw Exception(ErrorCodes::CANNOT_COMPILE_REGEXP, "Cannot compile re2: {} for http handling rule, error: {}. 
" + "Look at https://github.com/google/re2/wiki/Syntax for reference.", expression, compiled_regex->error()); + + return compiled_regex; +} + +HTTPRequestHandlerFactoryPtr createPredefinedHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix) +{ + if (!config.has(config_prefix + ".handler.query")) + throw Exception(ErrorCodes::NO_ELEMENTS_IN_CONFIG, "There is no path '{}.handler.query' in configuration file.", config_prefix); + + std::string predefined_query = config.getString(config_prefix + ".handler.query"); + NameSet analyze_receive_params = analyzeReceiveQueryParams(predefined_query); + + std::unordered_map<String, CompiledRegexPtr> headers_name_with_regex; + Poco::Util::AbstractConfiguration::Keys headers_name; + config.keys(config_prefix + ".headers", headers_name); + + for (const auto & header_name : headers_name) + { + auto expression = config.getString(config_prefix + ".headers." + header_name); + + if (!startsWith(expression, "regex:")) + continue; + + expression = expression.substr(6); + auto regex = getCompiledRegex(expression); + if (capturingNamedQueryParam(analyze_receive_params, regex)) + headers_name_with_regex.emplace(std::make_pair(header_name, regex)); + } + + std::optional<String> content_type_override; + if (config.has(config_prefix + ".handler.content_type")) + content_type_override = config.getString(config_prefix + ".handler.content_type"); + + std::shared_ptr<HandlingRuleHTTPHandlerFactory<PredefinedQueryHandler>> factory; + + if (config.has(config_prefix + ".url")) + { + auto url_expression = config.getString(config_prefix + ".url"); + + if (startsWith(url_expression, "regex:")) + url_expression = url_expression.substr(6); + + auto regex = getCompiledRegex(url_expression); + if (capturingNamedQueryParam(analyze_receive_params, regex)) + { + factory = std::make_shared<HandlingRuleHTTPHandlerFactory<PredefinedQueryHandler>>( + server, + std::move(analyze_receive_params), + 
std::move(predefined_query), + std::move(regex), + std::move(headers_name_with_regex), + std::move(content_type_override)); + factory->addFiltersFromConfig(config, config_prefix); + return factory; + } + } + + factory = std::make_shared<HandlingRuleHTTPHandlerFactory<PredefinedQueryHandler>>( + server, + std::move(analyze_receive_params), + std::move(predefined_query), + CompiledRegexPtr{}, + std::move(headers_name_with_regex), + std::move(content_type_override)); + factory->addFiltersFromConfig(config, config_prefix); + + return factory; +} + +} diff --git a/contrib/clickhouse/src/Server/HTTPHandler.h b/contrib/clickhouse/src/Server/HTTPHandler.h new file mode 100644 index 0000000000..5eda592753 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTPHandler.h @@ -0,0 +1,173 @@ +#pragma once + +#include <Core/Names.h> +#include <Server/HTTP/HTMLForm.h> +#include <Server/HTTP/HTTPRequestHandler.h> +#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h> +#include <Common/CurrentMetrics.h> +#include <Common/CurrentThread.h> + +#include <re2/re2.h> + +namespace CurrentMetrics +{ + extern const Metric HTTPConnection; +} + +namespace Poco { class Logger; } + +namespace DB +{ + +class Session; +class Credentials; +class IServer; +struct Settings; +class WriteBufferFromHTTPServerResponse; + +using CompiledRegexPtr = std::shared_ptr<const re2::RE2>; + +class HTTPHandler : public HTTPRequestHandler +{ +public: + HTTPHandler(IServer & server_, const std::string & name, const std::optional<String> & content_type_override_); + virtual ~HTTPHandler() override; + + void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; + + /// This method is called right before the query execution. 
+ virtual void customizeContext(HTTPServerRequest & /* request */, ContextMutablePtr /* context */, ReadBuffer & /* body */) {} + + virtual bool customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value) = 0; + + virtual std::string getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context) = 0; + +private: + struct Output + { + /* Raw data + * ↓ + * CascadeWriteBuffer out_maybe_delayed_and_compressed (optional) + * ↓ (forwards data if an overflow is occur or explicitly via pushDelayedResults) + * CompressedWriteBuffer out_maybe_compressed (optional) + * ↓ + * WriteBufferFromHTTPServerResponse out + */ + + std::shared_ptr<WriteBufferFromHTTPServerResponse> out; + /// Points to 'out' or to CompressedWriteBuffer(*out), depending on settings. + std::shared_ptr<WriteBuffer> out_maybe_compressed; + /// Points to 'out' or to CompressedWriteBuffer(*out) or to CascadeWriteBuffer. + std::shared_ptr<WriteBuffer> out_maybe_delayed_and_compressed; + + bool finalized = false; + + inline bool hasDelayed() const + { + return out_maybe_delayed_and_compressed != out_maybe_compressed; + } + + inline void finalize() + { + if (finalized) + return; + finalized = true; + + if (out_maybe_delayed_and_compressed) + out_maybe_delayed_and_compressed->finalize(); + if (out_maybe_compressed) + out_maybe_compressed->finalize(); + if (out) + out->finalize(); + } + + inline bool isFinalized() const + { + return finalized; + } + }; + + IServer & server; + Poco::Logger * log; + + /// It is the name of the server that will be sent in an http-header X-ClickHouse-Server-Display-Name. + String server_display_name; + + CurrentMetrics::Increment metric_increment{CurrentMetrics::HTTPConnection}; + + /// Reference to the immutable settings in the global context. + /// Those settings are used only to extract a http request's parameters. + /// See settings http_max_fields, http_max_field_name_size, http_max_field_value_size in HTMLForm. 
+ const Settings & default_settings; + + /// Overrides Content-Type provided by the format of the response. + std::optional<String> content_type_override; + + // session is reset at the end of each request/response. + std::unique_ptr<Session> session; + + // The request_credential instance may outlive a single request/response loop. + // This happens only when the authentication mechanism requires more than a single request/response exchange (e.g., SPNEGO). + std::unique_ptr<Credentials> request_credentials; + + // Returns true when the user successfully authenticated, + // the session instance will be configured accordingly, and the request_credentials instance will be dropped. + // Returns false when the user is not authenticated yet, and the 'Negotiate' response is sent, + // the session and request_credentials instances are preserved. + // Throws an exception if authentication failed. + bool authenticateUser( + HTTPServerRequest & request, + HTMLForm & params, + HTTPServerResponse & response); + + /// Also initializes 'used_output'. 
+ void processQuery( + HTTPServerRequest & request, + HTMLForm & params, + HTTPServerResponse & response, + Output & used_output, + std::optional<CurrentThread::QueryScope> & query_scope); + + void trySendExceptionToClient( + const std::string & s, + int exception_code, + HTTPServerRequest & request, + HTTPServerResponse & response, + Output & used_output); + + static void pushDelayedResults(Output & used_output); +}; + +class DynamicQueryHandler : public HTTPHandler +{ +private: + std::string param_name; +public: + explicit DynamicQueryHandler(IServer & server_, const std::string & param_name_ = "query", const std::optional<String>& content_type_override_ = std::nullopt); + + std::string getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context) override; + + bool customizeQueryParam(ContextMutablePtr context, const std::string &key, const std::string &value) override; +}; + +class PredefinedQueryHandler : public HTTPHandler +{ +private: + NameSet receive_params; + std::string predefined_query; + CompiledRegexPtr url_regex; + std::unordered_map<String, CompiledRegexPtr> header_name_with_capture_regex; +public: + PredefinedQueryHandler( + IServer & server_, const NameSet & receive_params_, const std::string & predefined_query_ + , const CompiledRegexPtr & url_regex_, const std::unordered_map<String, CompiledRegexPtr> & header_name_with_regex_ + , const std::optional<std::string> & content_type_override_); + + void customizeContext(HTTPServerRequest & request, ContextMutablePtr context, ReadBuffer & body) override; + + std::string getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context) override; + + bool customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value) override; +}; + +} diff --git a/contrib/clickhouse/src/Server/HTTPHandlerFactory.cpp b/contrib/clickhouse/src/Server/HTTPHandlerFactory.cpp new file mode 100644 index 0000000000..1c911034da --- /dev/null +++ 
b/contrib/clickhouse/src/Server/HTTPHandlerFactory.cpp @@ -0,0 +1,186 @@ +#include <Server/HTTPHandlerFactory.h> + +#include <Server/HTTP/HTTPRequestHandler.h> +#include <Server/IServer.h> +#include <Access/Credentials.h> + +#include <Poco/Util/AbstractConfiguration.h> + +#include "HTTPHandler.h" +#include "NotFoundHandler.h" +#include "StaticRequestHandler.h" +#include "ReplicasStatusHandler.h" +#include "InterserverIOHTTPHandler.h" +#include "PrometheusRequestHandler.h" +#include "WebUIRequestHandler.h" + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int UNKNOWN_ELEMENT_IN_CONFIG; + extern const int INVALID_CONFIG_PARAMETER; +} + +static void addCommonDefaultHandlersFactory(HTTPRequestHandlerFactoryMain & factory, IServer & server); +static void addDefaultHandlersFactory( + HTTPRequestHandlerFactoryMain & factory, + IServer & server, + const Poco::Util::AbstractConfiguration & config, + AsynchronousMetrics & async_metrics); + +static inline auto createHandlersFactoryFromConfig( + IServer & server, + const Poco::Util::AbstractConfiguration & config, + const std::string & name, + const String & prefix, + AsynchronousMetrics & async_metrics) +{ + auto main_handler_factory = std::make_shared<HTTPRequestHandlerFactoryMain>(name); + + Poco::Util::AbstractConfiguration::Keys keys; + config.keys(prefix, keys); + + for (const auto & key : keys) + { + if (key == "defaults") + { + addDefaultHandlersFactory(*main_handler_factory, server, config, async_metrics); + } + else if (startsWith(key, "rule")) + { + const auto & handler_type = config.getString(prefix + "." + key + ".handler.type", ""); + + if (handler_type.empty()) + throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER, "Handler type in config is not specified here: " + "{}.{}.handler.type", prefix, key); + + if (handler_type == "static") + main_handler_factory->addHandler(createStaticHandlerFactory(server, config, prefix + "." 
+ key)); + else if (handler_type == "dynamic_query_handler") + main_handler_factory->addHandler(createDynamicHandlerFactory(server, config, prefix + "." + key)); + else if (handler_type == "predefined_query_handler") + main_handler_factory->addHandler(createPredefinedHandlerFactory(server, config, prefix + "." + key)); + else if (handler_type == "prometheus") + main_handler_factory->addHandler(createPrometheusHandlerFactory(server, config, async_metrics, prefix + "." + key)); + else if (handler_type == "replicas_status") + main_handler_factory->addHandler(createReplicasStatusHandlerFactory(server, config, prefix + "." + key)); + else + throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER, "Unknown handler type '{}' in config here: {}.{}.handler.type", + handler_type, prefix, key); + } + else + throw Exception(ErrorCodes::UNKNOWN_ELEMENT_IN_CONFIG, "Unknown element in config: " + "{}.{}, must be 'rule' or 'defaults'", prefix, key); + } + + return main_handler_factory; +} + +static inline HTTPRequestHandlerFactoryPtr +createHTTPHandlerFactory(IServer & server, const Poco::Util::AbstractConfiguration & config, const std::string & name, AsynchronousMetrics & async_metrics) +{ + if (config.has("http_handlers")) + { + return createHandlersFactoryFromConfig(server, config, name, "http_handlers", async_metrics); + } + else + { + auto factory = std::make_shared<HTTPRequestHandlerFactoryMain>(name); + addDefaultHandlersFactory(*factory, server, config, async_metrics); + return factory; + } +} + +static inline HTTPRequestHandlerFactoryPtr createInterserverHTTPHandlerFactory(IServer & server, const std::string & name) +{ + auto factory = std::make_shared<HTTPRequestHandlerFactoryMain>(name); + addCommonDefaultHandlersFactory(*factory, server); + + auto main_handler = std::make_shared<HandlingRuleHTTPHandlerFactory<InterserverIOHTTPHandler>>(server); + main_handler->allowPostAndGetParamsAndOptionsRequest(); + factory->addHandler(main_handler); + + return factory; +} + 
+HTTPRequestHandlerFactoryPtr createHandlerFactory(IServer & server, const Poco::Util::AbstractConfiguration & config, AsynchronousMetrics & async_metrics, const std::string & name) +{ + if (name == "HTTPHandler-factory" || name == "HTTPSHandler-factory") + return createHTTPHandlerFactory(server, config, name, async_metrics); + else if (name == "InterserverIOHTTPHandler-factory" || name == "InterserverIOHTTPSHandler-factory") + return createInterserverHTTPHandlerFactory(server, name); + else if (name == "PrometheusHandler-factory") + return createPrometheusMainHandlerFactory(server, config, async_metrics, name); + + throw Exception(ErrorCodes::LOGICAL_ERROR, "LOGICAL ERROR: Unknown HTTP handler factory name."); +} + +static const auto ping_response_expression = "Ok.\n"; +static const auto root_response_expression = "config://http_server_default_response"; + +void addCommonDefaultHandlersFactory(HTTPRequestHandlerFactoryMain & factory, IServer & server) +{ + auto root_handler = std::make_shared<HandlingRuleHTTPHandlerFactory<StaticRequestHandler>>(server, root_response_expression); + root_handler->attachStrictPath("/"); + root_handler->allowGetAndHeadRequest(); + factory.addHandler(root_handler); + + auto ping_handler = std::make_shared<HandlingRuleHTTPHandlerFactory<StaticRequestHandler>>(server, ping_response_expression); + ping_handler->attachStrictPath("/ping"); + ping_handler->allowGetAndHeadRequest(); + factory.addPathToHints("/ping"); + factory.addHandler(ping_handler); + + auto replicas_status_handler = std::make_shared<HandlingRuleHTTPHandlerFactory<ReplicasStatusHandler>>(server); + replicas_status_handler->attachNonStrictPath("/replicas_status"); + replicas_status_handler->allowGetAndHeadRequest(); + factory.addPathToHints("/replicas_status"); + factory.addHandler(replicas_status_handler); + + auto play_handler = std::make_shared<HandlingRuleHTTPHandlerFactory<WebUIRequestHandler>>(server); + play_handler->attachNonStrictPath("/play"); + 
play_handler->allowGetAndHeadRequest(); + factory.addPathToHints("/play"); + factory.addHandler(play_handler); + + auto dashboard_handler = std::make_shared<HandlingRuleHTTPHandlerFactory<WebUIRequestHandler>>(server); + dashboard_handler->attachNonStrictPath("/dashboard"); + dashboard_handler->allowGetAndHeadRequest(); + factory.addPathToHints("/dashboard"); + factory.addHandler(dashboard_handler); + + auto js_handler = std::make_shared<HandlingRuleHTTPHandlerFactory<WebUIRequestHandler>>(server); + js_handler->attachNonStrictPath("/js/"); + js_handler->allowGetAndHeadRequest(); + factory.addHandler(js_handler); +} + +void addDefaultHandlersFactory( + HTTPRequestHandlerFactoryMain & factory, + IServer & server, + const Poco::Util::AbstractConfiguration & config, + AsynchronousMetrics & async_metrics) +{ + addCommonDefaultHandlersFactory(factory, server); + + auto query_handler = std::make_shared<HandlingRuleHTTPHandlerFactory<DynamicQueryHandler>>(server, "query"); + query_handler->allowPostAndGetParamsAndOptionsRequest(); + factory.addHandler(query_handler); + + /// We check that prometheus handler will be served on current (default) port. + /// Otherwise it will be created separately, see createHandlerFactory(...). 
+ if (config.has("prometheus") && config.getInt("prometheus.port", 0) == 0) + { + auto prometheus_handler = std::make_shared<HandlingRuleHTTPHandlerFactory<PrometheusRequestHandler>>( + server, PrometheusMetricsWriter(config, "prometheus", async_metrics)); + prometheus_handler->attachStrictPath(config.getString("prometheus.endpoint", "/metrics")); + prometheus_handler->allowGetAndHeadRequest(); + factory.addHandler(prometheus_handler); + } +} + +} diff --git a/contrib/clickhouse/src/Server/HTTPHandlerFactory.h b/contrib/clickhouse/src/Server/HTTPHandlerFactory.h new file mode 100644 index 0000000000..fe11833dc3 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTPHandlerFactory.h @@ -0,0 +1,148 @@ +#pragma once + +#include <Common/AsynchronousMetrics.h> +#include <Server/HTTP/HTMLForm.h> +#include <Server/HTTP/HTTPRequestHandlerFactory.h> +#include <Server/HTTPHandlerRequestFilter.h> +#include <Server/HTTPRequestHandlerFactoryMain.h> +#include <Common/StringUtils/StringUtils.h> + +#include <Poco/Util/AbstractConfiguration.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int UNKNOWN_ELEMENT_IN_CONFIG; +} + +class IServer; + +template <typename TEndpoint> +class HandlingRuleHTTPHandlerFactory : public HTTPRequestHandlerFactory +{ +public: + using Filter = std::function<bool(const HTTPServerRequest &)>; + + template <typename... TArgs> + explicit HandlingRuleHTTPHandlerFactory(TArgs &&... args) + { + creator = [my_args = std::tuple<TArgs...>(std::forward<TArgs>(args) ...)]() + { + return std::apply([&](auto && ... endpoint_args) + { + return std::make_unique<TEndpoint>(std::forward<decltype(endpoint_args)>(endpoint_args)...); + }, std::move(my_args)); + }; + } + + void addFilter(Filter cur_filter) + { + Filter prev_filter = filter; + filter = [prev_filter, cur_filter](const auto & request) + { + return prev_filter ? 
prev_filter(request) && cur_filter(request) : cur_filter(request); + }; + } + + void addFiltersFromConfig(const Poco::Util::AbstractConfiguration & config, const std::string & prefix) + { + Poco::Util::AbstractConfiguration::Keys filters_type; + config.keys(prefix, filters_type); + + for (const auto & filter_type : filters_type) + { + if (filter_type == "handler") + continue; + else if (filter_type == "url") + addFilter(urlFilter(config, prefix + ".url")); + else if (filter_type == "headers") + addFilter(headersFilter(config, prefix + ".headers")); + else if (filter_type == "methods") + addFilter(methodsFilter(config, prefix + ".methods")); + else + throw Exception(ErrorCodes::UNKNOWN_ELEMENT_IN_CONFIG, "Unknown element in config: {}.{}", prefix, filter_type); + } + } + + void attachStrictPath(const String & strict_path) + { + addFilter([strict_path](const auto & request) { return request.getURI() == strict_path; }); + } + + void attachNonStrictPath(const String & non_strict_path) + { + addFilter([non_strict_path](const auto & request) { return startsWith(request.getURI(), non_strict_path); }); + } + + /// Handle GET or HEAD endpoint on specified path + void allowGetAndHeadRequest() + { + addFilter([](const auto & request) + { + return request.getMethod() == Poco::Net::HTTPRequest::HTTP_GET + || request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD; + }); + } + + /// Handle Post request or (Get or Head) with params or OPTIONS requests + void allowPostAndGetParamsAndOptionsRequest() + { + addFilter([](const auto & request) + { + return (request.getURI().find('?') != std::string::npos + && (request.getMethod() == Poco::Net::HTTPRequest::HTTP_GET + || request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD)) + || request.getMethod() == Poco::Net::HTTPRequest::HTTP_OPTIONS + || request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST; + }); + } + + std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) override + { + return 
filter(request) ? creator() : nullptr; + } + +private: + Filter filter; + std::function<std::unique_ptr<HTTPRequestHandler> ()> creator; +}; + +HTTPRequestHandlerFactoryPtr createStaticHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix); + +HTTPRequestHandlerFactoryPtr createDynamicHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix); + +HTTPRequestHandlerFactoryPtr createPredefinedHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix); + +HTTPRequestHandlerFactoryPtr createReplicasStatusHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix); + +HTTPRequestHandlerFactoryPtr +createPrometheusHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + AsynchronousMetrics & async_metrics, + const std::string & config_prefix); + +HTTPRequestHandlerFactoryPtr +createPrometheusMainHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + AsynchronousMetrics & async_metrics, + const std::string & name); + +/// @param server - used in handlers to check IServer::isCancelled() +/// @param config - not the same as server.config(), since it can be newer +/// @param async_metrics - used for prometheus (in case of prometheus.asynchronous_metrics=true) +HTTPRequestHandlerFactoryPtr createHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + AsynchronousMetrics & async_metrics, + const std::string & name); + +} diff --git a/contrib/clickhouse/src/Server/HTTPHandlerRequestFilter.h b/contrib/clickhouse/src/Server/HTTPHandlerRequestFilter.h new file mode 100644 index 0000000000..25cbb95087 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTPHandlerRequestFilter.h @@ -0,0 +1,102 @@ +#pragma once + +#include 
<Server/HTTP/HTTPServerRequest.h> +#include <Common/Exception.h> +#include <Common/StringUtils/StringUtils.h> +#include <base/find_symbols.h> + +#include <re2/re2.h> +#include <Poco/StringTokenizer.h> +#include <Poco/Util/LayeredConfiguration.h> + +#include <unordered_map> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_COMPILE_REGEXP; +} + +using CompiledRegexPtr = std::shared_ptr<const re2::RE2>; + +static inline bool checkRegexExpression(std::string_view match_str, const CompiledRegexPtr & compiled_regex) +{ + int num_captures = compiled_regex->NumberOfCapturingGroups() + 1; + + std::string_view matches[num_captures]; + return compiled_regex->Match({match_str.data(), match_str.size()}, 0, match_str.size(), re2::RE2::Anchor::ANCHOR_BOTH, matches, num_captures); +} + +static inline bool checkExpression(std::string_view match_str, const std::pair<String, CompiledRegexPtr> & expression) +{ + if (expression.second) + return checkRegexExpression(match_str, expression.second); + + return match_str == expression.first; +} + +static inline auto methodsFilter(const Poco::Util::AbstractConfiguration & config, const std::string & config_path) /// NOLINT +{ + std::vector<String> methods; + Poco::StringTokenizer tokenizer(config.getString(config_path), ","); + + for (const auto & iterator : tokenizer) + methods.emplace_back(Poco::toUpper(Poco::trim(iterator))); + + return [methods](const HTTPServerRequest & request) { return std::count(methods.begin(), methods.end(), request.getMethod()); }; +} + +static inline auto getExpression(const std::string & expression) +{ + if (!startsWith(expression, "regex:")) + return std::make_pair(expression, CompiledRegexPtr{}); + + auto compiled_regex = std::make_shared<const re2::RE2>(expression.substr(6)); + + if (!compiled_regex->ok()) + throw Exception(ErrorCodes::CANNOT_COMPILE_REGEXP, "cannot compile re2: {} for http handling rule, error: {}. 
" + "Look at https://github.com/google/re2/wiki/Syntax for reference.", + expression, compiled_regex->error()); + return std::make_pair(expression, compiled_regex); +} + +static inline auto urlFilter(const Poco::Util::AbstractConfiguration & config, const std::string & config_path) /// NOLINT +{ + return [expression = getExpression(config.getString(config_path))](const HTTPServerRequest & request) + { + const auto & uri = request.getURI(); + const auto & end = find_first_symbols<'?'>(uri.data(), uri.data() + uri.size()); + + return checkExpression(std::string_view(uri.data(), end - uri.data()), expression); + }; +} + +static inline auto headersFilter(const Poco::Util::AbstractConfiguration & config, const std::string & prefix) /// NOLINT +{ + std::unordered_map<String, std::pair<String, CompiledRegexPtr>> headers_expression; + Poco::Util::AbstractConfiguration::Keys headers_name; + config.keys(prefix, headers_name); + + for (const auto & header_name : headers_name) + { + const auto & expression = getExpression(config.getString(prefix + "." 
+ header_name)); + checkExpression("", expression); /// Check expression syntax is correct + headers_expression.emplace(std::make_pair(header_name, expression)); + } + + /// The returned filter accepts a request only if every configured header matches its expression; + /// an absent header is checked as the empty string. + return [headers_expression](const HTTPServerRequest & request) + { + for (const auto & [header_name, header_expression] : headers_expression) + { + /// Copy the value instead of binding a reference: when the header is absent, get() hands back + /// its default argument, which here is a temporary string that dies at the end of this statement. + const String header_value = request.get(header_name, ""); + if (!checkExpression(std::string_view(header_value.data(), header_value.size()), header_expression)) + return false; + } + + return true; + }; +} + +} diff --git a/contrib/clickhouse/src/Server/HTTPPathHints.cpp b/contrib/clickhouse/src/Server/HTTPPathHints.cpp new file mode 100644 index 0000000000..51ef3eabff --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTPPathHints.cpp @@ -0,0 +1,16 @@ +#include <Server/HTTPPathHints.h> + +namespace DB +{ + +void HTTPPathHints::add(const String & http_path) +{ + http_paths.push_back(http_path); +} + +std::vector<String> HTTPPathHints::getAllRegisteredNames() const +{ + return http_paths; +} + +} diff --git a/contrib/clickhouse/src/Server/HTTPPathHints.h b/contrib/clickhouse/src/Server/HTTPPathHints.h new file mode 100644 index 0000000000..708816ebf0 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTPPathHints.h @@ -0,0 +1,22 @@ +#pragma once + +#include <base/types.h> + +#include <Common/NamePrompter.h> + +namespace DB +{ + +class HTTPPathHints : public IHints<1, HTTPPathHints> +{ +public: + std::vector<String> getAllRegisteredNames() const override; + void add(const String & http_path); + +private: + std::vector<String> http_paths; +}; + +using HTTPPathHintsPtr = std::shared_ptr<HTTPPathHints>; + +} diff --git a/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.cpp b/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.cpp new file mode 100644 index 0000000000..5481bcd508 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.cpp @@ -0,0 +1,38 @@ +#include <Server/HTTPRequestHandlerFactoryMain.h> +#include 
<Server/NotFoundHandler.h> + +#include <Common/logger_useful.h> + +namespace DB +{ + +HTTPRequestHandlerFactoryMain::HTTPRequestHandlerFactoryMain(const std::string & name_) + : log(&Poco::Logger::get(name_)), name(name_) +{ +} + +std::unique_ptr<HTTPRequestHandler> HTTPRequestHandlerFactoryMain::createRequestHandler(const HTTPServerRequest & request) +{ + LOG_TRACE(log, "HTTP Request for {}. Method: {}, Address: {}, User-Agent: {}{}, Content Type: {}, Transfer Encoding: {}, X-Forwarded-For: {}", + name, request.getMethod(), request.clientAddress().toString(), request.get("User-Agent", "(none)"), + (request.hasContentLength() ? (", Length: " + std::to_string(request.getContentLength())) : ("")), + request.getContentType(), request.getTransferEncoding(), request.get("X-Forwarded-For", "(none)")); + + for (auto & handler_factory : child_factories) + { + auto handler = handler_factory->createRequestHandler(request); + if (handler) + return handler; + } + + if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_GET + || request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD + || request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST) + { + return std::unique_ptr<HTTPRequestHandler>(new NotFoundHandler(hints.getHints(request.getURI()))); + } + + return nullptr; +} + +} diff --git a/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.h b/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.h new file mode 100644 index 0000000000..07b278d831 --- /dev/null +++ b/contrib/clickhouse/src/Server/HTTPRequestHandlerFactoryMain.h @@ -0,0 +1,31 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandlerFactory.h> +#include <Server/HTTPPathHints.h> + +#include <vector> + +namespace DB +{ + +/// Handle request using child handlers +class HTTPRequestHandlerFactoryMain : public HTTPRequestHandlerFactory +{ +public: + explicit HTTPRequestHandlerFactoryMain(const std::string & name_); + + void addHandler(HTTPRequestHandlerFactoryPtr child_factory) { 
child_factories.emplace_back(child_factory); } + + void addPathToHints(const std::string & http_path) { hints.add(http_path); } + + std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) override; + +private: + Poco::Logger * log; + std::string name; + HTTPPathHints hints; + + std::vector<HTTPRequestHandlerFactoryPtr> child_factories; +}; + +} diff --git a/contrib/clickhouse/src/Server/IServer.h b/contrib/clickhouse/src/Server/IServer.h new file mode 100644 index 0000000000..c55b045d2a --- /dev/null +++ b/contrib/clickhouse/src/Server/IServer.h @@ -0,0 +1,39 @@ +#pragma once + +#include <Interpreters/Context_fwd.h> + +namespace Poco +{ + +namespace Util +{ +class LayeredConfiguration; +} + +class Logger; + +} + + +namespace DB +{ + +class IServer +{ +public: + /// Returns the application's configuration. + virtual Poco::Util::LayeredConfiguration & config() const = 0; + + /// Returns the application's logger. + virtual Poco::Logger & logger() const = 0; + + /// Returns global application's context. + virtual ContextMutablePtr context() const = 0; + + /// Returns true if shutdown signaled. 
+ virtual bool isCancelled() const = 0; + + virtual ~IServer() = default; +}; + +} diff --git a/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.cpp b/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.cpp new file mode 100644 index 0000000000..9741592868 --- /dev/null +++ b/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.cpp @@ -0,0 +1,163 @@ +#include <Server/InterserverIOHTTPHandler.h> + +#include <Server/IServer.h> + +#include <Compression/CompressedWriteBuffer.h> +#include <IO/ReadBufferFromIStream.h> +#include <Interpreters/Context.h> +#include <Interpreters/InterserverIOHandler.h> +#include <Server/HTTP/HTMLForm.h> +#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h> +#include <Common/setThreadName.h> +#include <Common/logger_useful.h> + +#include <Poco/Net/HTTPBasicCredentials.h> +#include <Poco/Util/LayeredConfiguration.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ABORTED; + extern const int TOO_MANY_SIMULTANEOUS_QUERIES; +} + +std::pair<String, bool> InterserverIOHTTPHandler::checkAuthentication(HTTPServerRequest & request) const +{ + auto server_credentials = server.context()->getInterserverCredentials(); + if (server_credentials) + { + if (!request.hasCredentials()) + return server_credentials->isValidUser("", ""); + + String scheme, info; + request.getCredentials(scheme, info); + + if (scheme != "Basic") + return {"Server requires HTTP Basic authentication but client provides another method", false}; + + Poco::Net::HTTPBasicCredentials credentials(info); + return server_credentials->isValidUser(credentials.getUsername(), credentials.getPassword()); + } + else if (request.hasCredentials()) + { + return {"Client requires HTTP Basic authentication, but server doesn't provide it", false}; + } + + return {"", true}; +} + +void InterserverIOHTTPHandler::processQuery(HTTPServerRequest & request, HTTPServerResponse & response, Output & used_output) +{ + HTMLForm params(server.context()->getSettingsRef(), request); 
+ + LOG_TRACE(log, "Request URI: {}", request.getURI()); + + String endpoint_name = params.get("endpoint"); + bool compress = params.get("compress") == "true"; + + auto & body = request.getStream(); + + auto endpoint = server.context()->getInterserverIOHandler().getEndpoint(endpoint_name); + /// Locked for read while query processing + std::shared_lock lock(endpoint->rwlock); + if (endpoint->blocker.isCancelled()) + throw Exception(ErrorCodes::ABORTED, "Transferring part to replica was cancelled"); + + if (compress) + { + CompressedWriteBuffer compressed_out(*used_output.out); + endpoint->processQuery(params, body, compressed_out, response); + } + else + { + endpoint->processQuery(params, body, *used_output.out, response); + } +} + + +void InterserverIOHTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) +{ + setThreadName("IntersrvHandler"); + ThreadStatus thread_status; + + /// In order to work keep-alive. + if (request.getVersion() == HTTPServerRequest::HTTP_1_1) + response.setChunkedTransferEncoding(true); + + Output used_output; + const auto & config = server.config(); + unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10); + used_output.out = std::make_shared<WriteBufferFromHTTPServerResponse>( + response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout); + + auto write_response = [&](const std::string & message) + { + auto & out = *used_output.out; + if (response.sent()) + { + out.finalize(); + return; + } + + try + { + writeString(message, out); + out.finalize(); + } + catch (...) 
+ { + tryLogCurrentException(log); + out.finalize(); + } + }; + + try + { + if (auto [message, success] = checkAuthentication(request); success) + { + processQuery(request, response, used_output); + used_output.out->finalize(); + LOG_DEBUG(log, "Done processing query"); + } + else + { + response.setStatusAndReason(HTTPServerResponse::HTTP_UNAUTHORIZED); + write_response(message); + LOG_WARNING(log, "Query processing failed request: '{}' authentication failed", request.getURI()); + } + } + catch (Exception & e) + { + if (e.code() == ErrorCodes::TOO_MANY_SIMULTANEOUS_QUERIES) + { + used_output.out->finalize(); + return; + } + + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR); + + /// Sending to remote server was cancelled due to server shutdown or drop table. + bool is_real_error = e.code() != ErrorCodes::ABORTED; + + PreformattedMessage message = getCurrentExceptionMessageAndPattern(is_real_error); + write_response(message.text); + + if (is_real_error) + LOG_ERROR(log, message); + else + LOG_INFO(log, message); + } + catch (...) 
+ { + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR); + PreformattedMessage message = getCurrentExceptionMessageAndPattern(/* with_stacktrace */ false); + write_response(message.text); + + LOG_ERROR(log, message); + } +} + + +} diff --git a/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.h b/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.h new file mode 100644 index 0000000000..da5b286b9e --- /dev/null +++ b/contrib/clickhouse/src/Server/InterserverIOHTTPHandler.h @@ -0,0 +1,51 @@ +#pragma once + +#include <Interpreters/InterserverCredentials.h> +#include <Server/HTTP/HTTPRequestHandler.h> +#include <Common/CurrentMetrics.h> + +#include <Poco/Logger.h> + +#include <memory> +#include <string> + + +namespace CurrentMetrics +{ + extern const Metric InterserverConnection; +} + +namespace DB +{ + +class IServer; +class WriteBufferFromHTTPServerResponse; + +class InterserverIOHTTPHandler : public HTTPRequestHandler +{ +public: + explicit InterserverIOHTTPHandler(IServer & server_) + : server(server_) + , log(&Poco::Logger::get("InterserverIOHTTPHandler")) + { + } + + void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; + +private: + struct Output + { + std::shared_ptr<WriteBufferFromHTTPServerResponse> out; + }; + + IServer & server; + Poco::Logger * log; + + CurrentMetrics::Increment metric_increment{CurrentMetrics::InterserverConnection}; + + void processQuery(HTTPServerRequest & request, HTTPServerResponse & response, Output & used_output); + + std::pair<String, bool> checkAuthentication(HTTPServerRequest & request) const; +}; + +} diff --git a/contrib/clickhouse/src/Server/KeeperTCPHandler.cpp b/contrib/clickhouse/src/Server/KeeperTCPHandler.cpp new file mode 100644 index 0000000000..84ed738850 --- /dev/null +++ b/contrib/clickhouse/src/Server/KeeperTCPHandler.cpp @@ -0,0 +1,695 @@ +#include <Server/KeeperTCPHandler.h> + +#if USE_NURAFT + +#include <Common/ZooKeeper/ZooKeeperIO.h> 
+#include <Core/Types.h> +#include <IO/WriteBufferFromPocoSocket.h> +#include <IO/ReadBufferFromPocoSocket.h> +#include <Poco/Net/NetException.h> +#include <Common/CurrentThread.h> +#include <Common/Stopwatch.h> +#include <Common/NetException.h> +#include <Common/setThreadName.h> +#include <Common/logger_useful.h> +#include <base/defines.h> +#include <chrono> +#include <Common/PipeFDs.h> +#include <Poco/Util/AbstractConfiguration.h> +#include <IO/ReadBufferFromFileDescriptor.h> +#include <queue> +#include <mutex> +#include <Coordination/FourLetterCommand.h> +#include <base/hex.h> + + +#ifdef POCO_HAVE_FD_EPOLL + #include <sys/epoll.h> +#else + #include <poll.h> +#endif + + +namespace DB +{ + +struct LastOp +{ +public: + String name{"NA"}; + int64_t last_cxid{-1}; + int64_t last_zxid{-1}; + int64_t last_response_time{0}; +}; + +static const LastOp EMPTY_LAST_OP {"NA", -1, -1, 0}; + +namespace ErrorCodes +{ + extern const int SYSTEM_ERROR; + extern const int LOGICAL_ERROR; + extern const int UNEXPECTED_PACKET_FROM_CLIENT; + extern const int TIMEOUT_EXCEEDED; +} + +struct PollResult +{ + size_t responses_count{0}; + bool has_requests{false}; + bool error{false}; +}; + +struct SocketInterruptablePollWrapper +{ + int sockfd; + PipeFDs pipe; + ReadBufferFromFileDescriptor response_in; + +#if defined(POCO_HAVE_FD_EPOLL) + int epollfd; + epoll_event socket_event{}; + epoll_event pipe_event{}; +#endif + + using InterruptCallback = std::function<void()>; + + explicit SocketInterruptablePollWrapper(const Poco::Net::StreamSocket & poco_socket_) + : sockfd(poco_socket_.impl()->sockfd()) + , response_in(pipe.fds_rw[0]) + { + pipe.setNonBlockingReadWrite(); + +#if defined(POCO_HAVE_FD_EPOLL) + epollfd = epoll_create(2); + if (epollfd < 0) + throwFromErrno("Cannot epoll_create", ErrorCodes::SYSTEM_ERROR); + + socket_event.events = EPOLLIN | EPOLLERR | EPOLLPRI; + socket_event.data.fd = sockfd; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, &socket_event) < 0) + { + int err = 
::close(epollfd); + chassert(!err || errno == EINTR); + + throwFromErrno("Cannot insert socket into epoll queue", ErrorCodes::SYSTEM_ERROR); + } + pipe_event.events = EPOLLIN | EPOLLERR | EPOLLPRI; + pipe_event.data.fd = pipe.fds_rw[0]; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, pipe.fds_rw[0], &pipe_event) < 0) + { + int err = ::close(epollfd); + chassert(!err || errno == EINTR); + + throwFromErrno("Cannot insert socket into epoll queue", ErrorCodes::SYSTEM_ERROR); + } +#endif + } + + int getResponseFD() const + { + return pipe.fds_rw[1]; + } + + PollResult poll(Poco::Timespan remaining_time, const std::shared_ptr<ReadBufferFromPocoSocket> & in) + { + + bool socket_ready = false; + bool fd_ready = false; + + if (in->available() != 0) + socket_ready = true; + + if (response_in.available() != 0) + fd_ready = true; + + int rc = 0; + if (!fd_ready) + { +#if defined(POCO_HAVE_FD_EPOLL) + epoll_event evout[2]; + evout[0].data.fd = evout[1].data.fd = -1; + do + { + Poco::Timestamp start; + /// TODO: use epoll_pwait() for more precise timers + rc = epoll_wait(epollfd, evout, 2, static_cast<int>(remaining_time.totalMilliseconds())); + if (rc < 0 && errno == EINTR) + { + Poco::Timestamp end; + Poco::Timespan waited = end - start; + if (waited < remaining_time) + remaining_time -= waited; + else + remaining_time = 0; + } + } + while (rc < 0 && errno == EINTR); + + for (int i = 0; i < rc; ++i) + { + if (evout[i].data.fd == sockfd) + socket_ready = true; + if (evout[i].data.fd == pipe.fds_rw[0]) + fd_ready = true; + } +#else + pollfd poll_buf[2]; + poll_buf[0].fd = sockfd; + poll_buf[0].events = POLLIN; + poll_buf[1].fd = pipe.fds_rw[0]; + poll_buf[1].events = POLLIN; + + do + { + Poco::Timestamp start; + rc = ::poll(poll_buf, 2, static_cast<int>(remaining_time.totalMilliseconds())); + if (rc < 0 && errno == POCO_EINTR) + { + Poco::Timestamp end; + Poco::Timespan waited = end - start; + if (waited < remaining_time) + remaining_time -= waited; + else + remaining_time = 0; + } + } 
+ while (rc < 0 && errno == POCO_EINTR); + + if (rc >= 1) + { + if (poll_buf[0].revents & POLLIN) + socket_ready = true; + if (poll_buf[1].revents & POLLIN) + fd_ready = true; + } +#endif + } + + PollResult result{}; + result.has_requests = socket_ready; + if (fd_ready) + { + UInt8 dummy; + readIntBinary(dummy, response_in); + result.responses_count = 1; + auto available = response_in.available(); + response_in.ignore(available); + result.responses_count += available; + } + + if (rc < 0) + result.error = true; + + return result; + } + +#if defined(POCO_HAVE_FD_EPOLL) + ~SocketInterruptablePollWrapper() + { + int err = ::close(epollfd); + chassert(!err || errno == EINTR); + } +#endif +}; + +KeeperTCPHandler::KeeperTCPHandler( + const Poco::Util::AbstractConfiguration & config_ref, + std::shared_ptr<KeeperDispatcher> keeper_dispatcher_, + Poco::Timespan receive_timeout_, + Poco::Timespan send_timeout_, + const Poco::Net::StreamSocket & socket_) + : Poco::Net::TCPServerConnection(socket_) + , log(&Poco::Logger::get("KeeperTCPHandler")) + , keeper_dispatcher(keeper_dispatcher_) + , operation_timeout( + 0, + config_ref.getUInt( + "keeper_server.coordination_settings.operation_timeout_ms", Coordination::DEFAULT_OPERATION_TIMEOUT_MS) * 1000) + , min_session_timeout( + 0, + config_ref.getUInt( + "keeper_server.coordination_settings.min_session_timeout_ms", Coordination::DEFAULT_MIN_SESSION_TIMEOUT_MS) * 1000) + , max_session_timeout( + 0, + config_ref.getUInt( + "keeper_server.coordination_settings.session_timeout_ms", Coordination::DEFAULT_MAX_SESSION_TIMEOUT_MS) * 1000) + , poll_wrapper(std::make_unique<SocketInterruptablePollWrapper>(socket_)) + , send_timeout(send_timeout_) + , receive_timeout(receive_timeout_) + , responses(std::make_unique<ThreadSafeResponseQueue>(std::numeric_limits<size_t>::max())) + , last_op(std::make_unique<LastOp>(EMPTY_LAST_OP)) +{ + KeeperTCPHandler::registerConnection(this); +} + +void KeeperTCPHandler::sendHandshake(bool has_leader) +{ + 
Coordination::write(Coordination::SERVER_HANDSHAKE_LENGTH, *out); + if (has_leader) + { + Coordination::write(Coordination::ZOOKEEPER_PROTOCOL_VERSION, *out); + } + else + { + /// Ignore connections if we are not leader, client will throw exception + /// and reconnect to another replica faster. ClickHouse client provide + /// clear message for such protocol version. + Coordination::write(Coordination::KEEPER_PROTOCOL_VERSION_CONNECTION_REJECT, *out); + } + + Coordination::write(static_cast<int32_t>(session_timeout.totalMilliseconds()), *out); + Coordination::write(session_id, *out); + std::array<char, Coordination::PASSWORD_LENGTH> passwd{}; + Coordination::write(passwd, *out); + out->next(); +} + +void KeeperTCPHandler::run() +{ + runImpl(); +} + +Poco::Timespan KeeperTCPHandler::receiveHandshake(int32_t handshake_length) +{ + int32_t protocol_version; + int64_t last_zxid_seen; + int32_t timeout_ms; + int64_t previous_session_id = 0; /// We don't support session restore. So previous session_id is always zero. 
+ std::array<char, Coordination::PASSWORD_LENGTH> passwd {}; + + if (!isHandShake(handshake_length)) + throw Exception(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected handshake length received: {}", toString(handshake_length)); + + Coordination::read(protocol_version, *in); + + if (protocol_version != Coordination::ZOOKEEPER_PROTOCOL_VERSION) + throw Exception(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected protocol version: {}", toString(protocol_version)); + + Coordination::read(last_zxid_seen, *in); + Coordination::read(timeout_ms, *in); + + /// TODO Stop ignoring this value + Coordination::read(previous_session_id, *in); + Coordination::read(passwd, *in); + + int8_t readonly; + if (handshake_length == Coordination::CLIENT_HANDSHAKE_LENGTH_WITH_READONLY) + Coordination::read(readonly, *in); + + return Poco::Timespan(timeout_ms * 1000); +} + + +void KeeperTCPHandler::runImpl() +{ + setThreadName("KeeperHandler"); + ThreadStatus thread_status; + + socket().setReceiveTimeout(receive_timeout); + socket().setSendTimeout(send_timeout); + socket().setNoDelay(true); + + in = std::make_shared<ReadBufferFromPocoSocket>(socket()); + out = std::make_shared<WriteBufferFromPocoSocket>(socket()); + + if (in->eof()) + { + LOG_WARNING(log, "Client has not sent any data."); + return; + } + + int32_t header; + try + { + Coordination::read(header, *in); + } + catch (const Exception & e) + { + LOG_WARNING(log, "Error while read connection header {}", e.displayText()); + return; + } + + /// All four letter word command code is larger than 2^24 or lower than 0. + /// Hand shake package length must be lower than 2^24 and larger than 0. + /// So collision never happens. 
+ int32_t four_letter_cmd = header; + if (!isHandShake(four_letter_cmd)) + { + connected.store(true, std::memory_order_relaxed); + tryExecuteFourLetterWordCmd(four_letter_cmd); + return; + } + + try + { + int32_t handshake_length = header; + auto client_timeout = receiveHandshake(handshake_length); + + if (client_timeout.totalMilliseconds() == 0) + client_timeout = Poco::Timespan(Coordination::DEFAULT_SESSION_TIMEOUT_MS * Poco::Timespan::MILLISECONDS); + session_timeout = std::max(client_timeout, min_session_timeout); + session_timeout = std::min(session_timeout, max_session_timeout); + } + catch (const Exception & e) /// Typical for an incorrect username, password, or address. + { + LOG_WARNING(log, "Cannot receive handshake {}", e.displayText()); + return; + } + + if (keeper_dispatcher->isServerActive()) + { + try + { + LOG_INFO(log, "Requesting session ID for the new client"); + session_id = keeper_dispatcher->getSessionID(session_timeout.totalMilliseconds()); + LOG_INFO(log, "Received session ID {}", session_id); + } + catch (const Exception & e) + { + LOG_WARNING(log, "Cannot receive session id {}", e.displayText()); + sendHandshake(false); + return; + + } + + sendHandshake(true); + } + else + { + LOG_WARNING(log, "Ignoring user request, because the server is not active yet"); + sendHandshake(false); + return; + } + + auto response_fd = poll_wrapper->getResponseFD(); + auto response_callback = [responses = this->responses, response_fd](const Coordination::ZooKeeperResponsePtr & response) + { + if (!responses->push(response)) + throw Exception(ErrorCodes::SYSTEM_ERROR, + "Could not push response with xid {} and zxid {}", + response->xid, + response->zxid); + + UInt8 single_byte = 1; + [[maybe_unused]] ssize_t result = write(response_fd, &single_byte, sizeof(single_byte)); + }; + keeper_dispatcher->registerSession(session_id, response_callback); + + Stopwatch logging_stopwatch; + auto log_long_operation = [&](const String & operation) + { + constexpr UInt64 
operation_max_ms = 500; + auto elapsed_ms = logging_stopwatch.elapsedMilliseconds(); + if (operation_max_ms < elapsed_ms) + LOG_TEST(log, "{} for session {} took {} ms", operation, session_id, elapsed_ms); + logging_stopwatch.restart(); + }; + + session_stopwatch.start(); + connected.store(true, std::memory_order_release); + bool close_received = false; + + try + { + while (true) + { + using namespace std::chrono_literals; + + PollResult result = poll_wrapper->poll(session_timeout, in); + log_long_operation("Polling socket"); + if (result.has_requests && !close_received) + { + if (in->eof()) + { + LOG_DEBUG(log, "Client closed connection, session id #{}", session_id); + keeper_dispatcher->finishSession(session_id); + break; + } + + auto [received_op, received_xid] = receiveRequest(); + packageReceived(); + log_long_operation("Receiving request"); + + if (received_op == Coordination::OpNum::Close) + { + LOG_DEBUG(log, "Received close event with xid {} for session id #{}", received_xid, session_id); + close_xid = received_xid; + close_received = true; + } + else if (received_op == Coordination::OpNum::Heartbeat) + { + LOG_TRACE(log, "Received heartbeat for session #{}", session_id); + } + else + operations[received_xid] = Poco::Timestamp(); + + /// Each request restarts session stopwatch + session_stopwatch.restart(); + } + + /// Process exact amount of responses from pipe + /// otherwise state of responses queue and signaling pipe + /// became inconsistent and race condition is possible. + while (result.responses_count != 0) + { + Coordination::ZooKeeperResponsePtr response; + + if (!responses->tryPop(response)) + throw Exception(ErrorCodes::LOGICAL_ERROR, "We must have ready response, but queue is empty. 
It's a bug."); + log_long_operation("Waiting for response to be ready"); + + if (response->xid == close_xid) + { + LOG_DEBUG(log, "Session #{} successfully closed", session_id); + return; + } + + updateStats(response); + packageSent(); + + response->write(*out); + log_long_operation("Sending response"); + if (response->error == Coordination::Error::ZSESSIONEXPIRED) + { + LOG_DEBUG(log, "Session #{} expired because server shutting down or quorum is not alive", session_id); + keeper_dispatcher->finishSession(session_id); + return; + } + + result.responses_count--; + } + + if (result.error) + throw Exception(ErrorCodes::SYSTEM_ERROR, "Exception happened while reading from socket"); + + if (session_stopwatch.elapsedMicroseconds() > static_cast<UInt64>(session_timeout.totalMicroseconds())) + { + LOG_DEBUG(log, "Session #{} expired", session_id); + keeper_dispatcher->finishSession(session_id); + break; + } + } + } + catch (const Exception & ex) + { + log_long_operation("Unknown operation"); + LOG_TRACE(log, "Has {} responses in the queue", responses->size()); + LOG_INFO(log, "Got exception processing session #{}: {}", session_id, getExceptionMessage(ex, true)); + keeper_dispatcher->finishSession(session_id); + } +} + +bool KeeperTCPHandler::isHandShake(int32_t handshake_length) +{ + return handshake_length == Coordination::CLIENT_HANDSHAKE_LENGTH + || handshake_length == Coordination::CLIENT_HANDSHAKE_LENGTH_WITH_READONLY; +} + +bool KeeperTCPHandler::tryExecuteFourLetterWordCmd(int32_t command) +{ + if (!FourLetterCommandFactory::instance().isKnown(command)) + { + LOG_WARNING(log, "invalid four letter command {}", IFourLetterCommand::toName(command)); + return false; + } + else if (!FourLetterCommandFactory::instance().isEnabled(command)) + { + LOG_WARNING(log, "Not enabled four letter command {}", IFourLetterCommand::toName(command)); + return false; + } + else + { + auto command_ptr = FourLetterCommandFactory::instance().get(command); + LOG_DEBUG(log, "Receive four 
letter command {}", command_ptr->name()); + + try + { + String res = command_ptr->run(); + out->write(res.data(), res.size()); + out->next(); + } + catch (...) + { + tryLogCurrentException(log, "Error when executing four letter command " + command_ptr->name()); + } + + return true; + } +} + +std::pair<Coordination::OpNum, Coordination::XID> KeeperTCPHandler::receiveRequest() +{ + int32_t length; + Coordination::read(length, *in); + int32_t xid; + Coordination::read(xid, *in); + + Coordination::OpNum opnum; + Coordination::read(opnum, *in); + + Coordination::ZooKeeperRequestPtr request = Coordination::ZooKeeperRequestFactory::instance().get(opnum); + request->xid = xid; + request->readImpl(*in); + + if (!keeper_dispatcher->putRequest(request, session_id)) + throw Exception(ErrorCodes::TIMEOUT_EXCEEDED, "Session {} already disconnected", session_id); + return std::make_pair(opnum, xid); +} + +void KeeperTCPHandler::packageSent() +{ + conn_stats.incrementPacketsSent(); + keeper_dispatcher->incrementPacketsSent(); +} + +void KeeperTCPHandler::packageReceived() +{ + conn_stats.incrementPacketsReceived(); + keeper_dispatcher->incrementPacketsReceived(); +} + +void KeeperTCPHandler::updateStats(Coordination::ZooKeeperResponsePtr & response) +{ + /// update statistics ignoring watch response and heartbeat. 
+ if (response->xid != Coordination::WATCH_XID && response->getOpNum() != Coordination::OpNum::Heartbeat) + { + Int64 elapsed = (Poco::Timestamp() - operations[response->xid]) / 1000; + conn_stats.updateLatency(elapsed); + + operations.erase(response->xid); + keeper_dispatcher->updateKeeperStatLatency(elapsed); + + last_op.set(std::make_unique<LastOp>(LastOp{ + .name = Coordination::toString(response->getOpNum()), + .last_cxid = response->xid, + .last_zxid = response->zxid, + .last_response_time = Poco::Timestamp().epochMicroseconds() / 1000, + })); + } + +} + +KeeperConnectionStats & KeeperTCPHandler::getConnectionStats() +{ + return conn_stats; +} + +void KeeperTCPHandler::dumpStats(WriteBufferFromOwnString & buf, bool brief) +{ + if (!connected.load(std::memory_order_acquire)) + return; + + auto & stats = getConnectionStats(); + + writeText(' ', buf); + writeText(socket().peerAddress().toString(), buf); + writeText("(recved=", buf); + writeIntText(stats.getPacketsReceived(), buf); + writeText(",sent=", buf); + writeIntText(stats.getPacketsSent(), buf); + if (!brief) + { + if (session_id != 0) + { + writeText(",sid=0x", buf); + writeText(getHexUIntLowercase(session_id), buf); + + writeText(",lop=", buf); + LastOpPtr op = last_op.get(); + writeText(op->name, buf); + writeText(",est=", buf); + writeIntText(established.epochMicroseconds() / 1000, buf); + writeText(",to=", buf); + writeIntText(session_timeout.totalMilliseconds(), buf); + int64_t last_cxid = op->last_cxid; + if (last_cxid >= 0) + { + writeText(",lcxid=0x", buf); + writeText(getHexUIntLowercase(last_cxid), buf); + } + writeText(",lzxid=0x", buf); + writeText(getHexUIntLowercase(op->last_zxid), buf); + writeText(",lresp=", buf); + writeIntText(op->last_response_time, buf); + + writeText(",llat=", buf); + writeIntText(stats.getLastLatency(), buf); + writeText(",minlat=", buf); + writeIntText(stats.getMinLatency(), buf); + writeText(",avglat=", buf); + writeIntText(stats.getAvgLatency(), buf); + 
writeText(",maxlat=", buf); + writeIntText(stats.getMaxLatency(), buf); + } + } + writeText(')', buf); + writeText('\n', buf); +} + +void KeeperTCPHandler::resetStats() +{ + conn_stats.reset(); + last_op.set(std::make_unique<LastOp>(EMPTY_LAST_OP)); +} + +KeeperTCPHandler::~KeeperTCPHandler() +{ + KeeperTCPHandler::unregisterConnection(this); +} + +std::mutex KeeperTCPHandler::conns_mutex; +std::unordered_set<KeeperTCPHandler *> KeeperTCPHandler::connections; + +void KeeperTCPHandler::registerConnection(KeeperTCPHandler * conn) +{ + std::lock_guard lock(conns_mutex); + connections.insert(conn); +} + +void KeeperTCPHandler::unregisterConnection(KeeperTCPHandler * conn) +{ + std::lock_guard lock(conns_mutex); + connections.erase(conn); +} + +void KeeperTCPHandler::dumpConnections(WriteBufferFromOwnString & buf, bool brief) +{ + std::lock_guard lock(conns_mutex); + for (auto * conn : connections) + { + conn->dumpStats(buf, brief); + } +} + +void KeeperTCPHandler::resetConnsStats() +{ + std::lock_guard lock(conns_mutex); + for (auto * conn : connections) + { + conn->resetStats(); + } +} + +} + +#endif diff --git a/contrib/clickhouse/src/Server/KeeperTCPHandler.h b/contrib/clickhouse/src/Server/KeeperTCPHandler.h new file mode 100644 index 0000000000..653a599de6 --- /dev/null +++ b/contrib/clickhouse/src/Server/KeeperTCPHandler.h @@ -0,0 +1,113 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_NURAFT + +#include <Poco/Net/TCPServerConnection.h> +#include <Common/MultiVersion.h> +#include "IServer.h" +#include <Common/Stopwatch.h> +#include <Common/ZooKeeper/ZooKeeperCommon.h> +#include <Common/ZooKeeper/ZooKeeperConstants.h> +#include <Common/ConcurrentBoundedQueue.h> +#include <Coordination/KeeperDispatcher.h> +#include <IO/WriteBufferFromPocoSocket.h> +#include <IO/ReadBufferFromPocoSocket.h> +#include <unordered_map> +#include <Coordination/KeeperConnectionStats.h> +#include <Poco/Timestamp.h> + +namespace DB +{ + +struct SocketInterruptablePollWrapper; 
+using SocketInterruptablePollWrapperPtr = std::unique_ptr<SocketInterruptablePollWrapper>; + +using ThreadSafeResponseQueue = ConcurrentBoundedQueue<Coordination::ZooKeeperResponsePtr>; +using ThreadSafeResponseQueuePtr = std::shared_ptr<ThreadSafeResponseQueue>; + +struct LastOp; +using LastOpMultiVersion = MultiVersion<LastOp>; +using LastOpPtr = LastOpMultiVersion::Version; + +class KeeperTCPHandler : public Poco::Net::TCPServerConnection +{ +public: + static void registerConnection(KeeperTCPHandler * conn); + static void unregisterConnection(KeeperTCPHandler * conn); + /// dump all connections statistics + static void dumpConnections(WriteBufferFromOwnString & buf, bool brief); + static void resetConnsStats(); + +private: + static std::mutex conns_mutex; + /// all connections + static std::unordered_set<KeeperTCPHandler *> connections; + +public: + KeeperTCPHandler( + const Poco::Util::AbstractConfiguration & config_ref, + std::shared_ptr<KeeperDispatcher> keeper_dispatcher_, + Poco::Timespan receive_timeout_, + Poco::Timespan send_timeout_, + const Poco::Net::StreamSocket & socket_); + void run() override; + + KeeperConnectionStats & getConnectionStats(); + void dumpStats(WriteBufferFromOwnString & buf, bool brief); + void resetStats(); + + ~KeeperTCPHandler() override; + +private: + Poco::Logger * log; + std::shared_ptr<KeeperDispatcher> keeper_dispatcher; + Poco::Timespan operation_timeout; + Poco::Timespan min_session_timeout; + Poco::Timespan max_session_timeout; + Poco::Timespan session_timeout; + int64_t session_id{-1}; + Stopwatch session_stopwatch; + SocketInterruptablePollWrapperPtr poll_wrapper; + Poco::Timespan send_timeout; + Poco::Timespan receive_timeout; + + ThreadSafeResponseQueuePtr responses; + + Coordination::XID close_xid = Coordination::CLOSE_XID; + + /// Streams for reading/writing from/to client connection socket. 
+ std::shared_ptr<ReadBufferFromPocoSocket> in; + std::shared_ptr<WriteBufferFromPocoSocket> out; + + std::atomic<bool> connected{false}; + + void runImpl(); + + void sendHandshake(bool has_leader); + Poco::Timespan receiveHandshake(int32_t handshake_length); + + static bool isHandShake(int32_t handshake_length); + bool tryExecuteFourLetterWordCmd(int32_t command); + + std::pair<Coordination::OpNum, Coordination::XID> receiveRequest(); + + void packageSent(); + void packageReceived(); + + void updateStats(Coordination::ZooKeeperResponsePtr & response); + + Poco::Timestamp established; + + using Operations = std::unordered_map<Coordination::XID, Poco::Timestamp>; + Operations operations; + + LastOpMultiVersion last_op; + + KeeperConnectionStats conn_stats; + +}; + +} +#endif diff --git a/contrib/clickhouse/src/Server/MySQLHandler.cpp b/contrib/clickhouse/src/Server/MySQLHandler.cpp new file mode 100644 index 0000000000..f98b86e6cf --- /dev/null +++ b/contrib/clickhouse/src/Server/MySQLHandler.cpp @@ -0,0 +1,508 @@ +#include "MySQLHandler.h" + +#include <limits> +#include <Common/NetException.h> +#include <Common/OpenSSLHelpers.h> +#include <Core/MySQL/PacketsGeneric.h> +#include <Core/MySQL/PacketsConnection.h> +#include <Core/MySQL/PacketsProtocolText.h> +#include <Core/NamesAndTypes.h> +#include <Interpreters/Session.h> +#include <Interpreters/executeQuery.h> +#include <IO/copyData.h> +#include <IO/LimitReadBuffer.h> +#include <IO/ReadBufferFromPocoSocket.h> +#include <IO/ReadBufferFromString.h> +#include <IO/WriteBufferFromPocoSocket.h> +#include <IO/WriteBufferFromString.h> +#include <IO/ReadHelpers.h> +#include <Server/TCPServer.h> +#include <Storages/IStorage.h> +#include <regex> +#include <Common/setThreadName.h> +#include <Core/MySQL/Authentication.h> +#include <Common/logger_useful.h> +#include <base/scope_guard.h> + +#include "config_version.h" + +#if USE_SSL +# include <Poco/Crypto/RSAKey.h> +# include <Poco/Net/SSLManager.h> +# include 
<Poco/Net/SecureStreamSocket.h> + +#endif + +namespace DB +{ + +using namespace MySQLProtocol; +using namespace MySQLProtocol::Generic; +using namespace MySQLProtocol::ProtocolText; +using namespace MySQLProtocol::ConnectionPhase; + +#if USE_SSL +using Poco::Net::SecureStreamSocket; +using Poco::Net::SSLManager; +#endif + +namespace ErrorCodes +{ + extern const int CANNOT_READ_ALL_DATA; + extern const int NOT_IMPLEMENTED; + extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES; + extern const int SUPPORT_IS_DISABLED; + extern const int UNSUPPORTED_METHOD; +} + + +static const size_t PACKET_HEADER_SIZE = 4; +static const size_t SSL_REQUEST_PAYLOAD_SIZE = 32; + +static String selectEmptyReplacementQuery(const String & query); +static String showTableStatusReplacementQuery(const String & query); +static String killConnectionIdReplacementQuery(const String & query); +static String selectLimitReplacementQuery(const String & query); + +MySQLHandler::MySQLHandler( + IServer & server_, + TCPServer & tcp_server_, + const Poco::Net::StreamSocket & socket_, + bool ssl_enabled, uint32_t connection_id_) + : Poco::Net::TCPServerConnection(socket_) + , server(server_) + , tcp_server(tcp_server_) + , log(&Poco::Logger::get("MySQLHandler")) + , connection_id(connection_id_) + , auth_plugin(new MySQLProtocol::Authentication::Native41()) +{ + server_capabilities = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF; + if (ssl_enabled) + server_capabilities |= CLIENT_SSL; + + replacements.emplace("KILL QUERY", killConnectionIdReplacementQuery); + replacements.emplace("SHOW TABLE STATUS LIKE", showTableStatusReplacementQuery); + replacements.emplace("SHOW VARIABLES", selectEmptyReplacementQuery); + replacements.emplace("SET SQL_SELECT_LIMIT", selectLimitReplacementQuery); +} + +void MySQLHandler::run() +{ + setThreadName("MySQLHandler"); + ThreadStatus thread_status; + + session = 
std::make_unique<Session>(server.context(), ClientInfo::Interface::MYSQL); + SCOPE_EXIT({ session.reset(); }); + + session->setClientConnectionId(connection_id); + + in = std::make_shared<ReadBufferFromPocoSocket>(socket()); + out = std::make_shared<WriteBufferFromPocoSocket>(socket()); + packet_endpoint = std::make_shared<MySQLProtocol::PacketEndpoint>(*in, *out, sequence_id); + + try + { + Handshake handshake(server_capabilities, connection_id, VERSION_STRING + String("-") + VERSION_NAME, + auth_plugin->getName(), auth_plugin->getAuthPluginData(), CharacterSet::utf8_general_ci); + packet_endpoint->sendPacket<Handshake>(handshake, true); + + LOG_TRACE(log, "Sent handshake"); + + HandshakeResponse handshake_response; + finishHandshake(handshake_response); + client_capabilities = handshake_response.capability_flags; + max_packet_size = handshake_response.max_packet_size ? handshake_response.max_packet_size : MAX_PACKET_LENGTH; + + LOG_TRACE(log, + "Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: {}", + handshake_response.capability_flags, + handshake_response.max_packet_size, + static_cast<int>(handshake_response.character_set), + handshake_response.username, + handshake_response.auth_response.length(), + handshake_response.database, + handshake_response.auth_plugin_name); + + if (!(client_capabilities & CLIENT_PROTOCOL_41)) + throw Exception(ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES, "Required capability: CLIENT_PROTOCOL_41."); + + authenticate(handshake_response.username, handshake_response.auth_plugin_name, handshake_response.auth_response); + + try + { + session->makeSessionContext(); + session->sessionContext()->setDefaultFormat("MySQLWire"); + if (!handshake_response.database.empty()) + session->sessionContext()->setCurrentDatabase(handshake_response.database); + } + catch (const Exception & exc) + { + log->log(exc); + packet_endpoint->sendPacket(ERRPacket(exc.code(), 
"00000", exc.message()), true); + } + + OKPacket ok_packet(0, handshake_response.capability_flags, 0, 0, 0); + packet_endpoint->sendPacket(ok_packet, true); + + while (tcp_server.isOpen()) + { + packet_endpoint->resetSequenceId(); + MySQLPacketPayloadReadBuffer payload = packet_endpoint->getPayload(); + + while (!in->poll(1000000)) + if (!tcp_server.isOpen()) + return; + char command = 0; + payload.readStrict(command); + + // For commands which are executed without MemoryTracker. + LimitReadBuffer limited_payload(payload, 10000, /* trow_exception */ true, /* exact_limit */ {}, "too long MySQL packet."); + + LOG_DEBUG(log, "Received command: {}. Connection id: {}.", + static_cast<int>(static_cast<unsigned char>(command)), connection_id); + + if (!tcp_server.isOpen()) + return; + try + { + switch (command) + { + case COM_QUIT: + return; + case COM_INIT_DB: + comInitDB(limited_payload); + break; + case COM_QUERY: + comQuery(payload); + break; + case COM_FIELD_LIST: + comFieldList(limited_payload); + break; + case COM_PING: + comPing(); + break; + default: + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Command {} is not implemented.", command); + } + } + catch (const NetException & exc) + { + log->log(exc); + throw; + } + catch (...) + { + tryLogCurrentException(log, "MySQLHandler: Cannot read packet: "); + packet_endpoint->sendPacket(ERRPacket(getCurrentExceptionCode(), "00000", getCurrentExceptionMessage(false)), true); + } + } + } + catch (const Poco::Exception & exc) + { + log->log(exc); + } +} + +/** Reads 3 bytes, finds out whether it is SSLRequest or HandshakeResponse packet, starts secure connection, if it is SSLRequest. + * Reading is performed from socket instead of ReadBuffer to prevent reading part of SSL handshake. + * If we read it from socket, it will be impossible to start SSL connection using Poco. Size of SSLRequest packet payload is 32 bytes, thus we can read at most 36 bytes. 
+ */ +void MySQLHandler::finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) +{ + size_t packet_size = PACKET_HEADER_SIZE + SSL_REQUEST_PAYLOAD_SIZE; + + /// Buffer for SSLRequest or part of HandshakeResponse. + char buf[packet_size]; + size_t pos = 0; + + /// Reads at least count and at most packet_size bytes. + auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void { + while (pos < count) + { + int ret = socket().receiveBytes(buf + pos, static_cast<uint32_t>(packet_size - pos)); + if (ret == 0) + { + throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Cannot read all data. Bytes read: {}. Bytes expected: 3", std::to_string(pos)); + } + pos += ret; + } + }; + read_bytes(3); /// We can find out whether it is SSLRequest of HandshakeResponse by first 3 bytes. + + size_t payload_size = unalignedLoad<uint32_t>(buf) & 0xFFFFFFu; + LOG_TRACE(log, "payload size: {}", payload_size); + + if (payload_size == SSL_REQUEST_PAYLOAD_SIZE) + { + finishHandshakeSSL(packet_size, buf, pos, read_bytes, packet); + } + else + { + /// Reading rest of HandshakeResponse. + packet_size = PACKET_HEADER_SIZE + payload_size; + WriteBufferFromOwnString buf_for_handshake_response; + buf_for_handshake_response.write(buf, pos); + copyData(*packet_endpoint->in, buf_for_handshake_response, packet_size - pos); + ReadBufferFromString payload(buf_for_handshake_response.str()); + payload.ignore(PACKET_HEADER_SIZE); + packet.readPayloadWithUnpacked(payload); + packet_endpoint->sequence_id++; + } +} + +void MySQLHandler::authenticate(const String & user_name, const String & auth_plugin_name, const String & initial_auth_response) +{ + try + { + // For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used. 
+ if (session->getAuthenticationTypeOrLogInFailure(user_name) == DB::AuthenticationType::SHA256_PASSWORD) + { + authPluginSSL(); + } + + std::optional<String> auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt; + auth_plugin->authenticate(user_name, *session, auth_response, packet_endpoint, secure_connection, socket().peerAddress()); + } + catch (const Exception & exc) + { + LOG_ERROR(log, "Authentication for user {} failed.", user_name); + packet_endpoint->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true); + throw; + } + LOG_DEBUG(log, "Authentication for user {} succeeded.", user_name); +} + +void MySQLHandler::comInitDB(ReadBuffer & payload) +{ + String database; + readStringUntilEOF(database, payload); + LOG_DEBUG(log, "Setting current database to {}", database); + session->sessionContext()->setCurrentDatabase(database); + packet_endpoint->sendPacket(OKPacket(0, client_capabilities, 0, 0, 1), true); +} + +void MySQLHandler::comFieldList(ReadBuffer & payload) +{ + ComFieldList packet; + packet.readPayloadWithUnpacked(payload); + const auto session_context = session->sessionContext(); + String database = session_context->getCurrentDatabase(); + StoragePtr table_ptr = DatabaseCatalog::instance().getTable({database, packet.table}, session_context); + auto metadata_snapshot = table_ptr->getInMemoryMetadataPtr(); + for (const NameAndTypePair & column : metadata_snapshot->getColumns().getAll()) + { + ColumnDefinition column_definition( + database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0, true + ); + packet_endpoint->sendPacket(column_definition); + } + packet_endpoint->sendPacket(OKPacket(0xfe, client_capabilities, 0, 0, 0), true); +} + +void MySQLHandler::comPing() +{ + packet_endpoint->sendPacket(OKPacket(0x0, client_capabilities, 0, 0, 0), true); +} + +static bool isFederatedServerSetupSetCommand(const 
String & query); + +void MySQLHandler::comQuery(ReadBuffer & payload) +{ + String query = String(payload.position(), payload.buffer().end()); + + // This is a workaround in order to support adding ClickHouse to MySQL using federated server. + // As Clickhouse doesn't support these statements, we just send OK packet in response. + if (isFederatedServerSetupSetCommand(query)) + { + packet_endpoint->sendPacket(OKPacket(0x00, client_capabilities, 0, 0, 0), true); + } + else + { + String replacement_query; + bool should_replace = false; + bool with_output = false; + + for (auto const & x : replacements) + { + if (0 == strncasecmp(x.first.c_str(), query.c_str(), x.first.size())) + { + should_replace = true; + replacement_query = x.second(query); + break; + } + } + + ReadBufferFromString replacement(replacement_query); + + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(fmt::format("mysql:{}:{}", connection_id, toString(UUIDHelpers::generateV4()))); + CurrentThread::QueryScope query_scope{query_context}; + + std::atomic<size_t> affected_rows {0}; + auto prev = query_context->getProgressCallback(); + query_context->setProgressCallback([&, my_prev = prev](const Progress & progress) + { + if (my_prev) + my_prev(progress); + + affected_rows += progress.written_rows; + }); + + FormatSettings format_settings; + format_settings.mysql_wire.client_capabilities = client_capabilities; + format_settings.mysql_wire.max_packet_size = max_packet_size; + format_settings.mysql_wire.sequence_id = &sequence_id; + + auto set_result_details = [&with_output](const QueryResultDetails & details) + { + if (details.format) + { + if (*details.format != "MySQLWire") + throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats"); + + with_output = true; + } + }; + + executeQuery(should_replace ? 
replacement : payload, *out, false, query_context, set_result_details, format_settings); + + if (!with_output) + packet_endpoint->sendPacket(OKPacket(0x00, client_capabilities, affected_rows, 0, 0), true); + } +} + +void MySQLHandler::authPluginSSL() +{ + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, + "ClickHouse was built without SSL support. Try specifying password using double SHA1 in users.xml."); +} + +void MySQLHandler::finishHandshakeSSL( + [[maybe_unused]] size_t packet_size, [[maybe_unused]] char * buf, [[maybe_unused]] size_t pos, + [[maybe_unused]] std::function<void(size_t)> read_bytes, [[maybe_unused]] MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) +{ + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "Client requested SSL, while it is disabled."); +} + +#if USE_SSL +MySQLHandlerSSL::MySQLHandlerSSL( + IServer & server_, + TCPServer & tcp_server_, + const Poco::Net::StreamSocket & socket_, + bool ssl_enabled, + uint32_t connection_id_, + RSA & public_key_, + RSA & private_key_) + : MySQLHandler(server_, tcp_server_, socket_, ssl_enabled, connection_id_) + , public_key(public_key_) + , private_key(private_key_) +{} + +void MySQLHandlerSSL::authPluginSSL() +{ + auth_plugin = std::make_unique<MySQLProtocol::Authentication::Sha256Password>(public_key, private_key, log); +} + +void MySQLHandlerSSL::finishHandshakeSSL( + size_t packet_size, char *buf, size_t pos, std::function<void(size_t)> read_bytes, + MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) +{ + read_bytes(packet_size); /// Reading rest SSLRequest. + SSLRequest ssl_request; + ReadBufferFromMemory payload(buf, pos); + payload.ignore(PACKET_HEADER_SIZE); + ssl_request.readPayloadWithUnpacked(payload); + client_capabilities = ssl_request.capability_flags; + max_packet_size = ssl_request.max_packet_size ? 
ssl_request.max_packet_size : MAX_PACKET_LENGTH; + secure_connection = true; + ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext())); + in = std::make_shared<ReadBufferFromPocoSocket>(*ss); + out = std::make_shared<WriteBufferFromPocoSocket>(*ss); + sequence_id = 2; + packet_endpoint = std::make_shared<MySQLProtocol::PacketEndpoint>(*in, *out, sequence_id); + packet_endpoint->receivePacket(packet); /// Reading HandshakeResponse from secure socket. +} + +#endif + +static bool isFederatedServerSetupSetCommand(const String & query) +{ + static const std::regex expr{ + "(^(SET NAMES(.*)))" + "|(^(SET character_set_results(.*)))" + "|(^(SET FOREIGN_KEY_CHECKS(.*)))" + "|(^(SET AUTOCOMMIT(.*)))" + "|(^(SET sql_mode(.*)))" + "|(^(SET @@(.*)))" + "|(^(SET SESSION TRANSACTION ISOLATION LEVEL(.*)))" + , std::regex::icase}; + return 1 == std::regex_match(query, expr); +} + +/// Replace "[query(such as SHOW VARIABLES...)]" into "". +static String selectEmptyReplacementQuery(const String & query) +{ + std::ignore = query; + return "select ''"; +} + +/// Replace "SHOW TABLE STATUS LIKE 'xx'" into "SELECT ... FROM system.tables WHERE name LIKE 'xx'". 
/// The text that follows the fixed "SHOW TABLE STATUS LIKE " prefix (normally the quoted pattern)
/// is appended verbatim after "WHERE name LIKE ". A query that is not longer than the prefix
/// carries no pattern and is returned unchanged.
static String showTableStatusReplacementQuery(const String & query)
{
    static const String prefix = "SHOW TABLE STATUS LIKE ";
    if (query.size() <= prefix.size())
        return query;

    static const String select_clause =
        "SELECT"
        " name AS Name,"
        " engine AS Engine,"
        " '10' AS Version,"
        " 'Dynamic' AS Row_format,"
        " 0 AS Rows,"
        " 0 AS Avg_row_length,"
        " 0 AS Data_length,"
        " 0 AS Max_data_length,"
        " 0 AS Index_length,"
        " 0 AS Data_free,"
        " 'NULL' AS Auto_increment,"
        " metadata_modification_time AS Create_time,"
        " metadata_modification_time AS Update_time,"
        " metadata_modification_time AS Check_time,"
        " 'utf8_bin' AS Collation,"
        " 'NULL' AS Checksum,"
        " '' AS Create_options,"
        " '' AS Comment"
        " FROM system.tables"
        " WHERE name LIKE ";

    /// substr() keeps the whole remainder of the query, including any embedded NUL bytes.
    return select_clause + query.substr(prefix.size());
}

/// Replace "SET SQL_SELECT_LIMIT<rest>" into "SET limit<rest>".
/// The prefix is compared case-insensitively, the same way MySQLHandler::comQuery selects
/// this replacement (strncasecmp), so "set sql_select_limit = 5" is rewritten as well.
/// Any other query is returned unchanged.
static String selectLimitReplacementQuery(const String & query)
{
    static const String prefix = "SET SQL_SELECT_LIMIT";
    if (query.size() >= prefix.size() && 0 == strncasecmp(query.data(), prefix.data(), prefix.size()))
        return "SET limit" + query.substr(prefix.size());
    return query;
}

/// Replace "KILL QUERY [connection_id]" into "KILL QUERY WHERE query_id LIKE 'mysql:[connection_id]:xxx'".
+static String killConnectionIdReplacementQuery(const String & query) +{ + const String prefix = "KILL QUERY "; + if (query.size() > prefix.size()) + { + String suffix = query.data() + prefix.length(); + static const std::regex expr{"^[0-9]"}; + if (std::regex_match(suffix, expr)) + { + String replacement = fmt::format("KILL QUERY WHERE query_id LIKE 'mysql:{}:%'", suffix); + return replacement; + } + } + return query; +} + +} diff --git a/contrib/clickhouse/src/Server/MySQLHandler.h b/contrib/clickhouse/src/Server/MySQLHandler.h new file mode 100644 index 0000000000..a4de5c4590 --- /dev/null +++ b/contrib/clickhouse/src/Server/MySQLHandler.h @@ -0,0 +1,111 @@ +#pragma once + +#include <Poco/Net/TCPServerConnection.h> +#include <base/getFQDNOrHostName.h> +#include <Common/CurrentMetrics.h> +#include <Core/MySQL/Authentication.h> +#include <Core/MySQL/PacketsGeneric.h> +#include <Core/MySQL/PacketsConnection.h> +#include <Core/MySQL/PacketsProtocolText.h> +#include "IServer.h" + +#include "clickhouse_config.h" + +#if USE_SSL +# include <Poco/Net/SecureStreamSocket.h> +#endif + +#include <memory> + +namespace CurrentMetrics +{ + extern const Metric MySQLConnection; +} + +namespace DB +{ +class ReadBufferFromPocoSocket; +class TCPServer; + +/// Handler for MySQL wire protocol connections. Allows to connect to ClickHouse using MySQL client. +class MySQLHandler : public Poco::Net::TCPServerConnection +{ +public: + MySQLHandler( + IServer & server_, + TCPServer & tcp_server_, + const Poco::Net::StreamSocket & socket_, + bool ssl_enabled, + uint32_t connection_id_); + + void run() final; + +protected: + CurrentMetrics::Increment metric_increment{CurrentMetrics::MySQLConnection}; + + /// Enables SSL, if client requested. 
+ void finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResponse &); + + void comQuery(ReadBuffer & payload); + + void comFieldList(ReadBuffer & payload); + + void comPing(); + + void comInitDB(ReadBuffer & payload); + + void authenticate(const String & user_name, const String & auth_plugin_name, const String & auth_response); + + virtual void authPluginSSL(); + virtual void finishHandshakeSSL(size_t packet_size, char * buf, size_t pos, std::function<void(size_t)> read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet); + + IServer & server; + TCPServer & tcp_server; + Poco::Logger * log; + uint32_t connection_id = 0; + + uint32_t server_capabilities = 0; + uint32_t client_capabilities = 0; + size_t max_packet_size = 0; + uint8_t sequence_id = 0; + + MySQLProtocol::PacketEndpointPtr packet_endpoint; + std::unique_ptr<Session> session; + + using ReplacementFn = std::function<String(const String & query)>; + using Replacements = std::unordered_map<std::string, ReplacementFn>; + Replacements replacements; + + std::unique_ptr<MySQLProtocol::Authentication::IPlugin> auth_plugin; + std::shared_ptr<ReadBufferFromPocoSocket> in; + std::shared_ptr<WriteBuffer> out; + bool secure_connection = false; +}; + +#if USE_SSL +class MySQLHandlerSSL : public MySQLHandler +{ +public: + MySQLHandlerSSL( + IServer & server_, + TCPServer & tcp_server_, + const Poco::Net::StreamSocket & socket_, + bool ssl_enabled, + uint32_t connection_id_, + RSA & public_key_, + RSA & private_key_); + +private: + void authPluginSSL() override; + + void finishHandshakeSSL( + size_t packet_size, char * buf, size_t pos, + std::function<void(size_t)> read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) override; + + RSA & public_key; + RSA & private_key; + std::shared_ptr<Poco::Net::SecureStreamSocket> ss; +}; +#endif + +} diff --git a/contrib/clickhouse/src/Server/MySQLHandlerFactory.cpp b/contrib/clickhouse/src/Server/MySQLHandlerFactory.cpp new file mode 100644 
index 0000000000..deadb10f9a --- /dev/null +++ b/contrib/clickhouse/src/Server/MySQLHandlerFactory.cpp @@ -0,0 +1,140 @@ +#include "MySQLHandlerFactory.h" +#include <Common/OpenSSLHelpers.h> +#include <Poco/Net/TCPServerConnectionFactory.h> +#include <Poco/Util/Application.h> +#include <Common/logger_useful.h> +#include <base/scope_guard.h> +#include <Server/MySQLHandler.h> + +#if USE_SSL +# include <Poco/Net/SSLManager.h> +#endif + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_OPEN_FILE; + extern const int CANNOT_CLOSE_FILE; + extern const int NO_ELEMENTS_IN_CONFIG; + extern const int OPENSSL_ERROR; +} + +MySQLHandlerFactory::MySQLHandlerFactory(IServer & server_) + : server(server_) + , log(&Poco::Logger::get("MySQLHandlerFactory")) +{ +#if USE_SSL + try + { + Poco::Net::SSLManager::instance().defaultServerContext(); + } + catch (...) + { + LOG_TRACE(log, "Failed to create SSL context. SSL will be disabled. Error: {}", getCurrentExceptionMessage(false)); + ssl_enabled = false; + } + + /// Reading rsa keys for SHA256 authentication plugin. + try + { + readRSAKeys(); + } + catch (...) + { + LOG_TRACE(log, "Failed to read RSA key pair from server certificate. 
Error: {}", getCurrentExceptionMessage(false)); + generateRSAKeys(); + } +#endif +} + +#if USE_SSL +void MySQLHandlerFactory::readRSAKeys() +{ + const Poco::Util::LayeredConfiguration & config = Poco::Util::Application::instance().config(); + String certificate_file_property = "openSSL.server.certificateFile"; + String private_key_file_property = "openSSL.server.privateKeyFile"; + + if (!config.has(certificate_file_property)) + throw Exception(ErrorCodes::NO_ELEMENTS_IN_CONFIG, "Certificate file is not set."); + + if (!config.has(private_key_file_property)) + throw Exception(ErrorCodes::NO_ELEMENTS_IN_CONFIG, "Private key file is not set."); + + { + String certificate_file = config.getString(certificate_file_property); + FILE * fp = fopen(certificate_file.data(), "r"); + if (fp == nullptr) + throw Exception(ErrorCodes::CANNOT_OPEN_FILE, "Cannot open certificate file: {}.", certificate_file); + SCOPE_EXIT( + if (0 != fclose(fp)) + throwFromErrno("Cannot close file with the certificate in MySQLHandlerFactory", ErrorCodes::CANNOT_CLOSE_FILE); + ); + + X509 * x509 = PEM_read_X509(fp, nullptr, nullptr, nullptr); + SCOPE_EXIT(X509_free(x509)); + if (x509 == nullptr) + throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to read PEM certificate from {}. Error: {}", certificate_file, getOpenSSLErrors()); + + EVP_PKEY * p = X509_get_pubkey(x509); + if (p == nullptr) + throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to get RSA key from X509. Error: {}", getOpenSSLErrors()); + SCOPE_EXIT(EVP_PKEY_free(p)); + + public_key.reset(EVP_PKEY_get1_RSA(p)); + if (public_key.get() == nullptr) + throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to get RSA key from ENV_PKEY. 
Error: {}", getOpenSSLErrors()); + } + + { + String private_key_file = config.getString(private_key_file_property); + + FILE * fp = fopen(private_key_file.data(), "r"); + if (fp == nullptr) + throw Exception(ErrorCodes::CANNOT_OPEN_FILE, "Cannot open private key file {}.", private_key_file); + SCOPE_EXIT( + if (0 != fclose(fp)) + throwFromErrno("Cannot close file with the certificate in MySQLHandlerFactory", ErrorCodes::CANNOT_CLOSE_FILE); + ); + + private_key.reset(PEM_read_RSAPrivateKey(fp, nullptr, nullptr, nullptr)); + if (!private_key) + throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to read RSA private key from {}. Error: {}", private_key_file, getOpenSSLErrors()); + } +} + +void MySQLHandlerFactory::generateRSAKeys() +{ + LOG_TRACE(log, "Generating new RSA key pair."); + public_key.reset(RSA_new()); + if (!public_key) + throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to allocate RSA key. Error: {}", getOpenSSLErrors()); + + BIGNUM * e = BN_new(); + if (!e) + throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to allocate BIGNUM. Error: {}", getOpenSSLErrors()); + SCOPE_EXIT(BN_free(e)); + + if (!BN_set_word(e, 65537) || !RSA_generate_key_ex(public_key.get(), 2048, e, nullptr)) + throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to generate RSA key. Error: {}", getOpenSSLErrors()); + + private_key.reset(RSAPrivateKey_dup(public_key.get())); + if (!private_key) + throw Exception(ErrorCodes::OPENSSL_ERROR, "Failed to copy RSA key. Error: {}", getOpenSSLErrors()); +} +#endif + +Poco::Net::TCPServerConnection * MySQLHandlerFactory::createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) +{ + uint32_t connection_id = last_connection_id++; + LOG_TRACE(log, "MySQL connection. Id: {}. 
Address: {}", connection_id, socket.peerAddress().toString()); +#if USE_SSL + return new MySQLHandlerSSL(server, tcp_server, socket, ssl_enabled, connection_id, *public_key, *private_key); +#else + return new MySQLHandler(server, tcp_server, socket, ssl_enabled, connection_id); +#endif + +} + +} diff --git a/contrib/clickhouse/src/Server/MySQLHandlerFactory.h b/contrib/clickhouse/src/Server/MySQLHandlerFactory.h new file mode 100644 index 0000000000..edc53e5367 --- /dev/null +++ b/contrib/clickhouse/src/Server/MySQLHandlerFactory.h @@ -0,0 +1,50 @@ +#pragma once + +#include <atomic> +#include <memory> +#include <Server/IServer.h> +#include <Server/TCPServerConnectionFactory.h> + +#include "clickhouse_config.h" + +#if USE_SSL +# include <openssl/rsa.h> +#endif + +namespace DB +{ +class TCPServer; + +class MySQLHandlerFactory : public TCPServerConnectionFactory +{ +private: + IServer & server; + Poco::Logger * log; + +#if USE_SSL + struct RSADeleter + { + void operator()(RSA * ptr) { RSA_free(ptr); } + }; + using RSAPtr = std::unique_ptr<RSA, RSADeleter>; + + RSAPtr public_key; + RSAPtr private_key; + + bool ssl_enabled = true; +#else + bool ssl_enabled = false; +#endif + + std::atomic<unsigned> last_connection_id = 0; +public: + explicit MySQLHandlerFactory(IServer & server_); + + void readRSAKeys(); + + void generateRSAKeys(); + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override; +}; + +} diff --git a/contrib/clickhouse/src/Server/NotFoundHandler.cpp b/contrib/clickhouse/src/Server/NotFoundHandler.cpp new file mode 100644 index 0000000000..5b1db50855 --- /dev/null +++ b/contrib/clickhouse/src/Server/NotFoundHandler.cpp @@ -0,0 +1,31 @@ +#include <Server/NotFoundHandler.h> + +#include <IO/HTTPCommon.h> +#include <Common/Exception.h> + +namespace DB +{ +void NotFoundHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) +{ + try + { + 
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_NOT_FOUND); + *response.send() << "There is no handle " << request.getURI() + << (!hints.empty() ? fmt::format(". Maybe you meant {}.", hints.front()) : "") << "\n\n" + << "Use / or /ping for health checks.\n" + << "Or /replicas_status for more sophisticated health checks.\n\n" + << "Send queries from your program with POST method or GET /?query=...\n\n" + << "Use clickhouse-client:\n\n" + << "For interactive data analysis:\n" + << " clickhouse-client\n\n" + << "For batch query processing:\n" + << " clickhouse-client --query='SELECT 1' > result\n" + << " clickhouse-client < query > result\n"; + } + catch (...) + { + tryLogCurrentException("NotFoundHandler"); + } +} + +} diff --git a/contrib/clickhouse/src/Server/NotFoundHandler.h b/contrib/clickhouse/src/Server/NotFoundHandler.h new file mode 100644 index 0000000000..1cbfcd57f8 --- /dev/null +++ b/contrib/clickhouse/src/Server/NotFoundHandler.h @@ -0,0 +1,18 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandler.h> + +namespace DB +{ + +/// Response with 404 and verbose description. 
+class NotFoundHandler : public HTTPRequestHandler +{ +public: + NotFoundHandler(std::vector<std::string> hints_) : hints(std::move(hints_)) {} + void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; +private: + std::vector<std::string> hints; +}; + +} diff --git a/contrib/clickhouse/src/Server/PostgreSQLHandler.cpp b/contrib/clickhouse/src/Server/PostgreSQLHandler.cpp new file mode 100644 index 0000000000..7b07815425 --- /dev/null +++ b/contrib/clickhouse/src/Server/PostgreSQLHandler.cpp @@ -0,0 +1,329 @@ +#include <IO/ReadBufferFromPocoSocket.h> +#include <IO/ReadHelpers.h> +#include <IO/ReadBufferFromString.h> +#include <IO/WriteBufferFromPocoSocket.h> +#include <Interpreters/Context.h> +#include <Interpreters/executeQuery.h> +#include "PostgreSQLHandler.h" +#include <Parsers/parseQuery.h> +#include <Server/TCPServer.h> +#include <Common/setThreadName.h> +#include <base/scope_guard.h> +#include <random> + +#include "config_version.h" + +#if USE_SSL +# include <Poco/Net/SecureStreamSocket.h> +# include <Poco/Net/SSLManager.h> +#endif + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int SYNTAX_ERROR; +} + +PostgreSQLHandler::PostgreSQLHandler( + const Poco::Net::StreamSocket & socket_, + IServer & server_, + TCPServer & tcp_server_, + bool ssl_enabled_, + Int32 connection_id_, + std::vector<std::shared_ptr<PostgreSQLProtocol::PGAuthentication::AuthenticationMethod>> & auth_methods_) + : Poco::Net::TCPServerConnection(socket_) + , server(server_) + , tcp_server(tcp_server_) + , ssl_enabled(ssl_enabled_) + , connection_id(connection_id_) + , authentication_manager(auth_methods_) +{ + changeIO(socket()); +} + +void PostgreSQLHandler::changeIO(Poco::Net::StreamSocket & socket) +{ + in = std::make_shared<ReadBufferFromPocoSocket>(socket); + out = std::make_shared<WriteBufferFromPocoSocket>(socket); + message_transport = std::make_shared<PostgreSQLProtocol::Messaging::MessageTransport>(in.get(), out.get()); +} + +void 
PostgreSQLHandler::run() +{ + setThreadName("PostgresHandler"); + ThreadStatus thread_status; + + session = std::make_unique<Session>(server.context(), ClientInfo::Interface::POSTGRESQL); + SCOPE_EXIT({ session.reset(); }); + + session->setClientConnectionId(connection_id); + + try + { + if (!startup()) + return; + + while (tcp_server.isOpen()) + { + message_transport->send(PostgreSQLProtocol::Messaging::ReadyForQuery(), true); + + constexpr size_t connection_check_timeout = 1; // 1 second + while (!in->poll(1000000 * connection_check_timeout)) + if (!tcp_server.isOpen()) + return; + PostgreSQLProtocol::Messaging::FrontMessageType message_type = message_transport->receiveMessageType(); + + if (!tcp_server.isOpen()) + return; + switch (message_type) + { + case PostgreSQLProtocol::Messaging::FrontMessageType::QUERY: + processQuery(); + break; + case PostgreSQLProtocol::Messaging::FrontMessageType::TERMINATE: + LOG_DEBUG(log, "Client closed the connection"); + return; + case PostgreSQLProtocol::Messaging::FrontMessageType::PARSE: + case PostgreSQLProtocol::Messaging::FrontMessageType::BIND: + case PostgreSQLProtocol::Messaging::FrontMessageType::DESCRIBE: + case PostgreSQLProtocol::Messaging::FrontMessageType::SYNC: + case PostgreSQLProtocol::Messaging::FrontMessageType::FLUSH: + case PostgreSQLProtocol::Messaging::FrontMessageType::CLOSE: + message_transport->send( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, + "0A000", + "ClickHouse doesn't support extended query mechanism"), + true); + LOG_ERROR(log, "Client tried to access via extended query protocol"); + message_transport->dropMessage(); + break; + default: + message_transport->send( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, + "0A000", + "Command is not supported"), + true); + LOG_ERROR(log, "Command is not supported. 
Command code {:d}", static_cast<Int32>(message_type)); + message_transport->dropMessage(); + } + } + } + catch (const Poco::Exception &exc) + { + log->log(exc); + } + +} + +bool PostgreSQLHandler::startup() +{ + Int32 payload_size; + Int32 info; + establishSecureConnection(payload_size, info); + + if (static_cast<PostgreSQLProtocol::Messaging::FrontMessageType>(info) == PostgreSQLProtocol::Messaging::FrontMessageType::CANCEL_REQUEST) + { + LOG_DEBUG(log, "Client issued request canceling"); + cancelRequest(); + return false; + } + + std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> start_up_msg = receiveStartupMessage(payload_size); + const auto & user_name = start_up_msg->user; + authentication_manager.authenticate(user_name, *session, *message_transport, socket().peerAddress()); + + try + { + session->makeSessionContext(); + session->sessionContext()->setDefaultFormat("PostgreSQLWire"); + if (!start_up_msg->database.empty()) + session->sessionContext()->setCurrentDatabase(start_up_msg->database); + } + catch (const Exception & exc) + { + message_transport->send( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "XX000", exc.message()), + true); + throw; + } + + sendParameterStatusData(*start_up_msg); + + message_transport->send( + PostgreSQLProtocol::Messaging::BackendKeyData(connection_id, secret_key), true); + + LOG_DEBUG(log, "Successfully finished Startup stage"); + return true; +} + +void PostgreSQLHandler::establishSecureConnection(Int32 & payload_size, Int32 & info) +{ + bool was_encryption_req = true; + readBinaryBigEndian(payload_size, *in); + readBinaryBigEndian(info, *in); + + switch (static_cast<PostgreSQLProtocol::Messaging::FrontMessageType>(info)) + { + case PostgreSQLProtocol::Messaging::FrontMessageType::SSL_REQUEST: + LOG_DEBUG(log, "Client requested SSL"); + if (ssl_enabled) + makeSecureConnectionSSL(); + else + message_transport->send('N', true); + break; + case 
PostgreSQLProtocol::Messaging::FrontMessageType::GSSENC_REQUEST: + LOG_DEBUG(log, "Client requested GSSENC"); + message_transport->send('N', true); + break; + default: + was_encryption_req = false; + } + if (was_encryption_req) + { + readBinaryBigEndian(payload_size, *in); + readBinaryBigEndian(info, *in); + } +} + +#if USE_SSL +void PostgreSQLHandler::makeSecureConnectionSSL() +{ + message_transport->send('S'); + ss = std::make_shared<Poco::Net::SecureStreamSocket>( + Poco::Net::SecureStreamSocket::attach(socket(), Poco::Net::SSLManager::instance().defaultServerContext())); + changeIO(*ss); +} +#else +void PostgreSQLHandler::makeSecureConnectionSSL() {} +#endif + +void PostgreSQLHandler::sendParameterStatusData(PostgreSQLProtocol::Messaging::StartupMessage & start_up_message) +{ + std::unordered_map<String, String> & parameters = start_up_message.parameters; + + if (parameters.find("application_name") != parameters.end()) + message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("application_name", parameters["application_name"])); + if (parameters.find("client_encoding") != parameters.end()) + message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("client_encoding", parameters["client_encoding"])); + else + message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("client_encoding", "UTF8")); + + message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("server_version", VERSION_STRING)); + message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("server_encoding", "UTF8")); + message_transport->send(PostgreSQLProtocol::Messaging::ParameterStatus("DateStyle", "ISO")); + message_transport->flush(); +} + +void PostgreSQLHandler::cancelRequest() +{ + std::unique_ptr<PostgreSQLProtocol::Messaging::CancelRequest> msg = + message_transport->receiveWithPayloadSize<PostgreSQLProtocol::Messaging::CancelRequest>(8); + + String query = fmt::format("KILL QUERY WHERE query_id = 'postgres:{:d}:{:d}'", 
msg->process_id, msg->secret_key); + ReadBufferFromString replacement(query); + + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(""); + executeQuery(replacement, *out, true, query_context, {}); +} + +inline std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> PostgreSQLHandler::receiveStartupMessage(int payload_size) +{ + std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> message; + try + { + message = message_transport->receiveWithPayloadSize<PostgreSQLProtocol::Messaging::StartupMessage>(payload_size - 8); + } + catch (const Exception &) + { + message_transport->send( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "08P01", "Can't correctly handle Startup message"), + true); + throw; + } + + LOG_DEBUG(log, "Successfully received Startup message"); + return message; +} + +void PostgreSQLHandler::processQuery() +{ + try + { + std::unique_ptr<PostgreSQLProtocol::Messaging::Query> query = + message_transport->receive<PostgreSQLProtocol::Messaging::Query>(); + + if (isEmptyQuery(query->query)) + { + message_transport->send(PostgreSQLProtocol::Messaging::EmptyQueryResponse()); + return; + } + + bool psycopg2_cond = query->query == "BEGIN" || query->query == "COMMIT"; // psycopg2 starts and ends queries with BEGIN/COMMIT commands + bool jdbc_cond = query->query.find("SET extra_float_digits") != String::npos || query->query.find("SET application_name") != String::npos; // jdbc starts with setting this parameter + if (psycopg2_cond || jdbc_cond) + { + message_transport->send( + PostgreSQLProtocol::Messaging::CommandComplete( + PostgreSQLProtocol::Messaging::CommandComplete::classifyQuery(query->query), 0)); + return; + } + + const auto & settings = session->sessionContext()->getSettingsRef(); + std::vector<String> queries; + auto parse_res = splitMultipartQuery(query->query, queries, + settings.max_query_size, + settings.max_parser_depth, + 
settings.allow_settings_after_format_in_insert); + if (!parse_res.second) + throw Exception(ErrorCodes::SYNTAX_ERROR, "Cannot parse and execute the following part of query: {}", String(parse_res.first)); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<Int32> dis(0, INT32_MAX); + + for (const auto & spl_query : queries) + { + secret_key = dis(gen); + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(fmt::format("postgres:{:d}:{:d}", connection_id, secret_key)); + + CurrentThread::QueryScope query_scope{query_context}; + ReadBufferFromString read_buf(spl_query); + executeQuery(read_buf, *out, false, query_context, {}); + + PostgreSQLProtocol::Messaging::CommandComplete::Command command = + PostgreSQLProtocol::Messaging::CommandComplete::classifyQuery(spl_query); + message_transport->send(PostgreSQLProtocol::Messaging::CommandComplete(command, 0), true); + } + + } + catch (const Exception & e) + { + message_transport->send( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse( + PostgreSQLProtocol::Messaging::ErrorOrNoticeResponse::ERROR, "2F000", "Query execution failed.\n" + e.displayText()), + true); + throw; + } +} + +bool PostgreSQLHandler::isEmptyQuery(const String & query) +{ + if (query.empty()) + return true; + /// golang driver pgx sends ";" + if (query == ";") + return true; + + Poco::RegularExpression regex(R"(\A\s*\z)"); + return regex.match(query); +} + +} diff --git a/contrib/clickhouse/src/Server/PostgreSQLHandler.h b/contrib/clickhouse/src/Server/PostgreSQLHandler.h new file mode 100644 index 0000000000..4de5977cc4 --- /dev/null +++ b/contrib/clickhouse/src/Server/PostgreSQLHandler.h @@ -0,0 +1,81 @@ +#pragma once + +#include <Common/CurrentMetrics.h> +#include "clickhouse_config.h" +#include <Core/PostgreSQLProtocol.h> +#include <Poco/Net/TCPServerConnection.h> +#include "IServer.h" + +#if USE_SSL +# include <Poco/Net/SecureStreamSocket.h> +#endif + +namespace CurrentMetrics +{ 
+ extern const Metric PostgreSQLConnection; +} + +namespace DB +{ +class ReadBufferFromPocoSocket; +class Session; +class TCPServer; + +/** PostgreSQL wire protocol implementation. + * For more info see https://www.postgresql.org/docs/current/protocol.html + */ +class PostgreSQLHandler : public Poco::Net::TCPServerConnection +{ +public: + PostgreSQLHandler( + const Poco::Net::StreamSocket & socket_, + IServer & server_, + TCPServer & tcp_server_, + bool ssl_enabled_, + Int32 connection_id_, + std::vector<std::shared_ptr<PostgreSQLProtocol::PGAuthentication::AuthenticationMethod>> & auth_methods_); + + void run() final; + +private: + Poco::Logger * log = &Poco::Logger::get("PostgreSQLHandler"); + + IServer & server; + TCPServer & tcp_server; + std::unique_ptr<Session> session; + bool ssl_enabled = false; + Int32 connection_id = 0; + Int32 secret_key = 0; + + std::shared_ptr<ReadBufferFromPocoSocket> in; + std::shared_ptr<WriteBuffer> out; + std::shared_ptr<PostgreSQLProtocol::Messaging::MessageTransport> message_transport; + +#if USE_SSL + std::shared_ptr<Poco::Net::SecureStreamSocket> ss; +#endif + + PostgreSQLProtocol::PGAuthentication::AuthenticationManager authentication_manager; + + CurrentMetrics::Increment metric_increment{CurrentMetrics::PostgreSQLConnection}; + + void changeIO(Poco::Net::StreamSocket & socket); + + bool startup(); + + void establishSecureConnection(Int32 & payload_size, Int32 & info); + + void makeSecureConnectionSSL(); + + void sendParameterStatusData(PostgreSQLProtocol::Messaging::StartupMessage & start_up_message); + + void cancelRequest(); + + std::unique_ptr<PostgreSQLProtocol::Messaging::StartupMessage> receiveStartupMessage(int payload_size); + + void processQuery(); + + static bool isEmptyQuery(const String & query); +}; + +} diff --git a/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.cpp b/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.cpp new file mode 100644 index 0000000000..6f2124861e --- /dev/null +++ 
b/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.cpp @@ -0,0 +1,26 @@ +#include "PostgreSQLHandlerFactory.h" +#include <memory> +#include <Server/PostgreSQLHandler.h> + +namespace DB +{ + +PostgreSQLHandlerFactory::PostgreSQLHandlerFactory(IServer & server_) + : server(server_) + , log(&Poco::Logger::get("PostgreSQLHandlerFactory")) +{ + auth_methods = + { + std::make_shared<PostgreSQLProtocol::PGAuthentication::NoPasswordAuth>(), + std::make_shared<PostgreSQLProtocol::PGAuthentication::CleartextPasswordAuth>(), + }; +} + +Poco::Net::TCPServerConnection * PostgreSQLHandlerFactory::createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) +{ + Int32 connection_id = last_connection_id++; + LOG_TRACE(log, "PostgreSQL connection. Id: {}. Address: {}", connection_id, socket.peerAddress().toString()); + return new PostgreSQLHandler(socket, server, tcp_server, ssl_enabled, connection_id, auth_methods); +} + +} diff --git a/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.h b/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.h new file mode 100644 index 0000000000..08b7ad1d85 --- /dev/null +++ b/contrib/clickhouse/src/Server/PostgreSQLHandlerFactory.h @@ -0,0 +1,33 @@ +#pragma once + +#include <atomic> +#include <memory> +#include <Server/IServer.h> +#include <Server/TCPServerConnectionFactory.h> +#include <Core/PostgreSQLProtocol.h> +#include "clickhouse_config.h" + +namespace DB +{ + +class PostgreSQLHandlerFactory : public TCPServerConnectionFactory +{ +private: + IServer & server; + Poco::Logger * log; + +#if USE_SSL + bool ssl_enabled = true; +#else + bool ssl_enabled = false; +#endif + + std::atomic<Int32> last_connection_id = 0; + std::vector<std::shared_ptr<PostgreSQLProtocol::PGAuthentication::AuthenticationMethod>> auth_methods; + +public: + explicit PostgreSQLHandlerFactory(IServer & server_); + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & server) override; +}; 
+} diff --git a/contrib/clickhouse/src/Server/PrometheusMetricsWriter.cpp b/contrib/clickhouse/src/Server/PrometheusMetricsWriter.cpp new file mode 100644 index 0000000000..2331e45522 --- /dev/null +++ b/contrib/clickhouse/src/Server/PrometheusMetricsWriter.cpp @@ -0,0 +1,161 @@ +#include "PrometheusMetricsWriter.h" + +#include <IO/WriteHelpers.h> +#include <Common/StatusInfo.h> +#include <regex> /// TODO: this library is harmful. +#include <algorithm> + + +namespace +{ + +template <typename T> +void writeOutLine(DB::WriteBuffer & wb, T && val) +{ + DB::writeText(std::forward<T>(val), wb); + DB::writeChar('\n', wb); +} + +template <typename T, typename... TArgs> +void writeOutLine(DB::WriteBuffer & wb, T && val, TArgs &&... args) +{ + DB::writeText(std::forward<T>(val), wb); + DB::writeChar(' ', wb); + writeOutLine(wb, std::forward<TArgs>(args)...); +} + +/// Returns false if name is not valid +bool replaceInvalidChars(std::string & metric_name) +{ + /// dirty solution + metric_name = std::regex_replace(metric_name, std::regex("[^a-zA-Z0-9_:]"), "_"); + metric_name = std::regex_replace(metric_name, std::regex("^[^a-zA-Z]*"), ""); + return !metric_name.empty(); +} + +void convertHelpToSingleLine(std::string & help) +{ + std::replace(help.begin(), help.end(), '\n', ' '); +} + +} + + +namespace DB +{ + +PrometheusMetricsWriter::PrometheusMetricsWriter( + const Poco::Util::AbstractConfiguration & config, const std::string & config_name, + const AsynchronousMetrics & async_metrics_) + : async_metrics(async_metrics_) + , send_events(config.getBool(config_name + ".events", true)) + , send_metrics(config.getBool(config_name + ".metrics", true)) + , send_asynchronous_metrics(config.getBool(config_name + ".asynchronous_metrics", true)) + , send_status_info(config.getBool(config_name + ".status_info", true)) +{ +} + +void PrometheusMetricsWriter::write(WriteBuffer & wb) const +{ + if (send_events) + { + for (ProfileEvents::Event i = ProfileEvents::Event(0), end = 
ProfileEvents::end(); i < end; ++i) + { + const auto counter = ProfileEvents::global_counters[i].load(std::memory_order_relaxed); + + std::string metric_name{ProfileEvents::getName(static_cast<ProfileEvents::Event>(i))}; + std::string metric_doc{ProfileEvents::getDocumentation(static_cast<ProfileEvents::Event>(i))}; + + convertHelpToSingleLine(metric_doc); + + if (!replaceInvalidChars(metric_name)) + continue; + std::string key{profile_events_prefix + metric_name}; + + writeOutLine(wb, "# HELP", key, metric_doc); + writeOutLine(wb, "# TYPE", key, "counter"); + writeOutLine(wb, key, counter); + } + } + + if (send_metrics) + { + for (size_t i = 0, end = CurrentMetrics::end(); i < end; ++i) + { + const auto value = CurrentMetrics::values[i].load(std::memory_order_relaxed); + + std::string metric_name{CurrentMetrics::getName(static_cast<CurrentMetrics::Metric>(i))}; + std::string metric_doc{CurrentMetrics::getDocumentation(static_cast<CurrentMetrics::Metric>(i))}; + + convertHelpToSingleLine(metric_doc); + + if (!replaceInvalidChars(metric_name)) + continue; + std::string key{current_metrics_prefix + metric_name}; + + writeOutLine(wb, "# HELP", key, metric_doc); + writeOutLine(wb, "# TYPE", key, "gauge"); + writeOutLine(wb, key, value); + } + } + + if (send_asynchronous_metrics) + { + auto async_metrics_values = async_metrics.getValues(); + for (const auto & name_value : async_metrics_values) + { + std::string key{asynchronous_metrics_prefix + name_value.first}; + + if (!replaceInvalidChars(key)) + continue; + + auto value = name_value.second; + + std::string metric_doc{value.documentation}; + convertHelpToSingleLine(metric_doc); + + // TODO: add HELP section? 
asynchronous_metrics contains only key and value + writeOutLine(wb, "# HELP", key, metric_doc); + writeOutLine(wb, "# TYPE", key, "gauge"); + writeOutLine(wb, key, value.value); + } + } + + if (send_status_info) + { + for (size_t i = 0, end = CurrentStatusInfo::end(); i < end; ++i) + { + std::lock_guard lock(CurrentStatusInfo::locks[static_cast<CurrentStatusInfo::Status>(i)]); + std::string metric_name{CurrentStatusInfo::getName(static_cast<CurrentStatusInfo::Status>(i))}; + std::string metric_doc{CurrentStatusInfo::getDocumentation(static_cast<CurrentStatusInfo::Status>(i))}; + + convertHelpToSingleLine(metric_doc); + + if (!replaceInvalidChars(metric_name)) + continue; + std::string key{current_status_prefix + metric_name}; + + writeOutLine(wb, "# HELP", key, metric_doc); + writeOutLine(wb, "# TYPE", key, "gauge"); + + for (const auto & value: CurrentStatusInfo::values[i]) + { + for (const auto & enum_value: CurrentStatusInfo::getAllPossibleValues(static_cast<CurrentStatusInfo::Status>(i))) + { + DB::writeText(key, wb); + DB::writeChar('{', wb); + DB::writeText(key, wb); + DB::writeChar('=', wb); + writeDoubleQuotedString(enum_value.first, wb); + DB::writeText(",name=", wb); + writeDoubleQuotedString(value.first, wb); + DB::writeText("} ", wb); + DB::writeText(value.second == enum_value.second, wb); + DB::writeChar('\n', wb); + } + } + } + } +} + +} diff --git a/contrib/clickhouse/src/Server/PrometheusMetricsWriter.h b/contrib/clickhouse/src/Server/PrometheusMetricsWriter.h new file mode 100644 index 0000000000..b4f6ab57de --- /dev/null +++ b/contrib/clickhouse/src/Server/PrometheusMetricsWriter.h @@ -0,0 +1,38 @@ +#pragma once + +#include <string> + +#include <Common/AsynchronousMetrics.h> +#include <IO/WriteBuffer.h> + +#include <Poco/Util/AbstractConfiguration.h> + + +namespace DB +{ + +/// Write metrics in Prometheus format +class PrometheusMetricsWriter +{ +public: + PrometheusMetricsWriter( + const Poco::Util::AbstractConfiguration & config, const 
std::string & config_name, + const AsynchronousMetrics & async_metrics_); + + void write(WriteBuffer & wb) const; + +private: + const AsynchronousMetrics & async_metrics; + + const bool send_events; + const bool send_metrics; + const bool send_asynchronous_metrics; + const bool send_status_info; + + static inline constexpr auto profile_events_prefix = "ClickHouseProfileEvents_"; + static inline constexpr auto current_metrics_prefix = "ClickHouseMetrics_"; + static inline constexpr auto asynchronous_metrics_prefix = "ClickHouseAsyncMetrics_"; + static inline constexpr auto current_status_prefix = "ClickHouseStatusInfo_"; +}; + +} diff --git a/contrib/clickhouse/src/Server/PrometheusRequestHandler.cpp b/contrib/clickhouse/src/Server/PrometheusRequestHandler.cpp new file mode 100644 index 0000000000..7902562420 --- /dev/null +++ b/contrib/clickhouse/src/Server/PrometheusRequestHandler.cpp @@ -0,0 +1,71 @@ +#include <Server/PrometheusRequestHandler.h> + +#include <IO/HTTPCommon.h> +#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h> +#include <Server/HTTPHandlerFactory.h> +#include <Server/IServer.h> +#include <Common/CurrentMetrics.h> +#include <Common/Exception.h> +#include <Common/ProfileEvents.h> + +#include <Poco/Util/LayeredConfiguration.h> + + +namespace DB +{ +void PrometheusRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) +{ + try + { + const auto & config = server.config(); + unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10); + + setResponseDefaultHeaders(response, keep_alive_timeout); + + response.setContentType("text/plain; version=0.0.4; charset=UTF-8"); + + WriteBufferFromHTTPServerResponse wb(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout); + try + { + metrics_writer.write(wb); + wb.finalize(); + } + catch (...) + { + wb.finalize(); + } + } + catch (...) 
+ { + tryLogCurrentException("PrometheusRequestHandler"); + } +} + +HTTPRequestHandlerFactoryPtr +createPrometheusHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + AsynchronousMetrics & async_metrics, + const std::string & config_prefix) +{ + auto factory = std::make_shared<HandlingRuleHTTPHandlerFactory<PrometheusRequestHandler>>( + server, PrometheusMetricsWriter(config, config_prefix + ".handler", async_metrics)); + factory->addFiltersFromConfig(config, config_prefix); + return factory; +} + +HTTPRequestHandlerFactoryPtr +createPrometheusMainHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + AsynchronousMetrics & async_metrics, + const std::string & name) +{ + auto factory = std::make_shared<HTTPRequestHandlerFactoryMain>(name); + auto handler = std::make_shared<HandlingRuleHTTPHandlerFactory<PrometheusRequestHandler>>( + server, PrometheusMetricsWriter(config, "prometheus", async_metrics)); + handler->attachStrictPath(config.getString("prometheus.endpoint", "/metrics")); + handler->allowGetAndHeadRequest(); + factory->addHandler(handler); + return factory; +} + +} diff --git a/contrib/clickhouse/src/Server/PrometheusRequestHandler.h b/contrib/clickhouse/src/Server/PrometheusRequestHandler.h new file mode 100644 index 0000000000..1fb3d9f0f5 --- /dev/null +++ b/contrib/clickhouse/src/Server/PrometheusRequestHandler.h @@ -0,0 +1,28 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandler.h> + +#include "PrometheusMetricsWriter.h" + +namespace DB +{ + +class IServer; + +class PrometheusRequestHandler : public HTTPRequestHandler +{ +private: + IServer & server; + const PrometheusMetricsWriter & metrics_writer; + +public: + explicit PrometheusRequestHandler(IServer & server_, const PrometheusMetricsWriter & metrics_writer_) + : server(server_) + , metrics_writer(metrics_writer_) + { + } + + void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; +}; + +} diff 
--git a/contrib/clickhouse/src/Server/ProtocolServerAdapter.cpp b/contrib/clickhouse/src/Server/ProtocolServerAdapter.cpp new file mode 100644 index 0000000000..8d14a84989 --- /dev/null +++ b/contrib/clickhouse/src/Server/ProtocolServerAdapter.cpp @@ -0,0 +1,75 @@ +#include <Server/ProtocolServerAdapter.h> +#include <Server/TCPServer.h> + +#if USE_GRPC && !defined(CLICKHOUSE_KEEPER_STANDALONE_BUILD) +#include <Server/GRPCServer.h> +#endif + + +namespace DB +{ +class ProtocolServerAdapter::TCPServerAdapterImpl : public Impl +{ +public: + explicit TCPServerAdapterImpl(std::unique_ptr<TCPServer> tcp_server_) : tcp_server(std::move(tcp_server_)) {} + ~TCPServerAdapterImpl() override = default; + + void start() override { tcp_server->start(); } + void stop() override { tcp_server->stop(); } + bool isStopping() const override { return !tcp_server->isOpen(); } + UInt16 portNumber() const override { return tcp_server->portNumber(); } + size_t currentConnections() const override { return tcp_server->currentConnections(); } + size_t currentThreads() const override { return tcp_server->currentThreads(); } + +private: + std::unique_ptr<TCPServer> tcp_server; +}; + +ProtocolServerAdapter::ProtocolServerAdapter( + const std::string & listen_host_, + const char * port_name_, + const std::string & description_, + std::unique_ptr<TCPServer> tcp_server_) + : listen_host(listen_host_) + , port_name(port_name_) + , description(description_) + , impl(std::make_unique<TCPServerAdapterImpl>(std::move(tcp_server_))) +{ +} + +#if USE_GRPC && !defined(CLICKHOUSE_KEEPER_STANDALONE_BUILD) +class ProtocolServerAdapter::GRPCServerAdapterImpl : public Impl +{ +public: + explicit GRPCServerAdapterImpl(std::unique_ptr<GRPCServer> grpc_server_) : grpc_server(std::move(grpc_server_)) {} + ~GRPCServerAdapterImpl() override = default; + + void start() override { grpc_server->start(); } + void stop() override + { + is_stopping = true; + grpc_server->stop(); + } + bool isStopping() const override { 
return is_stopping; } + UInt16 portNumber() const override { return grpc_server->portNumber(); } + size_t currentConnections() const override { return grpc_server->currentConnections(); } + size_t currentThreads() const override { return grpc_server->currentThreads(); } + +private: + std::unique_ptr<GRPCServer> grpc_server; + bool is_stopping = false; +}; + +ProtocolServerAdapter::ProtocolServerAdapter( + const std::string & listen_host_, + const char * port_name_, + const std::string & description_, + std::unique_ptr<GRPCServer> grpc_server_) + : listen_host(listen_host_) + , port_name(port_name_) + , description(description_) + , impl(std::make_unique<GRPCServerAdapterImpl>(std::move(grpc_server_))) +{ +} +#endif +} diff --git a/contrib/clickhouse/src/Server/ProtocolServerAdapter.h b/contrib/clickhouse/src/Server/ProtocolServerAdapter.h new file mode 100644 index 0000000000..9497ec3e8e --- /dev/null +++ b/contrib/clickhouse/src/Server/ProtocolServerAdapter.h @@ -0,0 +1,74 @@ +#pragma once + +#include "clickhouse_config.h" + +#include <base/types.h> +#include <memory> +#include <string> + + +namespace DB +{ + +class GRPCServer; +class TCPServer; + +/// Provides an unified interface to access a protocol implementing server +/// no matter what type it has (HTTPServer, TCPServer, MySQLServer, GRPCServer, ...). +class ProtocolServerAdapter +{ + friend class ProtocolServers; +public: + ProtocolServerAdapter(ProtocolServerAdapter && src) = default; + ProtocolServerAdapter & operator =(ProtocolServerAdapter && src) = default; + ProtocolServerAdapter(const std::string & listen_host_, const char * port_name_, const std::string & description_, std::unique_ptr<TCPServer> tcp_server_); + +#if USE_GRPC && !defined(CLICKHOUSE_KEEPER_STANDALONE_BUILD) + ProtocolServerAdapter(const std::string & listen_host_, const char * port_name_, const std::string & description_, std::unique_ptr<GRPCServer> grpc_server_); +#endif + + /// Starts the server. 
A new thread will be created that waits for and accepts incoming connections. + void start() { impl->start(); } + + /// Stops the server. No new connections will be accepted. + void stop() { impl->stop(); } + + bool isStopping() const { return impl->isStopping(); } + + /// Returns the number of currently handled connections. + size_t currentConnections() const { return impl->currentConnections(); } + + /// Returns the number of current threads. + size_t currentThreads() const { return impl->currentThreads(); } + + /// Returns the port this server is listening to. + UInt16 portNumber() const { return impl->portNumber(); } + + const std::string & getListenHost() const { return listen_host; } + + const std::string & getPortName() const { return port_name; } + + const std::string & getDescription() const { return description; } + +private: + class Impl + { + public: + virtual ~Impl() = default; + virtual void start() = 0; + virtual void stop() = 0; + virtual bool isStopping() const = 0; + virtual UInt16 portNumber() const = 0; + virtual size_t currentConnections() const = 0; + virtual size_t currentThreads() const = 0; + }; + class TCPServerAdapterImpl; + class GRPCServerAdapterImpl; + + std::string listen_host; + std::string port_name; + std::string description; + std::unique_ptr<Impl> impl; +}; + +} diff --git a/contrib/clickhouse/src/Server/ProxyV1Handler.cpp b/contrib/clickhouse/src/Server/ProxyV1Handler.cpp new file mode 100644 index 0000000000..56621940a2 --- /dev/null +++ b/contrib/clickhouse/src/Server/ProxyV1Handler.cpp @@ -0,0 +1,127 @@ +#include <Server/ProxyV1Handler.h> +#include <Poco/Net/NetException.h> +#include <Common/NetException.h> +#include <Common/logger_useful.h> +#include <Interpreters/Context.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NETWORK_ERROR; + extern const int SOCKET_TIMEOUT; + extern const int CANNOT_READ_FROM_SOCKET; + extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED; +} + +void ProxyV1Handler::run() +{ + 
const auto & settings = server.context()->getSettingsRef(); + socket().setReceiveTimeout(settings.receive_timeout); + + std::string word; + bool eol; + + // Read PROXYv1 protocol header + // http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt + + // read "PROXY" + if (!readWord(5, word, eol) || word != "PROXY" || eol) + throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation"); + + // read "TCP4" or "TCP6" or "UNKNOWN" + if (!readWord(7, word, eol)) + throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation"); + + if (word != "TCP4" && word != "TCP6" && word != "UNKNOWN") + throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation"); + + if (word == "UNKNOWN" && eol) + return; + + if (eol) + throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation"); + + // read address + if (!readWord(39, word, eol) || eol) + throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation"); + + stack_data.forwarded_for = std::move(word); + + // read address + if (!readWord(39, word, eol) || eol) + throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation"); + + // read port + if (!readWord(5, word, eol) || eol) + throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation"); + + // read port and "\r\n" + if (!readWord(5, word, eol) || !eol) + throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "PROXY protocol violation"); + + if (!stack_data.forwarded_for.empty()) + LOG_TRACE(log, "Forwarded client address from PROXY header: {}", stack_data.forwarded_for); +} + +bool ProxyV1Handler::readWord(int max_len, std::string & word, bool & eol) +{ + word.clear(); + eol = false; + + char ch = 0; + int n = 0; + bool is_cr = false; + try + { + for (++max_len; max_len > 0 || is_cr; --max_len) + 
{ + n = socket().receiveBytes(&ch, 1); + if (n == 0) + { + socket().shutdown(); + return false; + } + if (n < 0) + break; + + if (is_cr) + return ch == 0x0A; + + if (ch == 0x0D) + { + is_cr = true; + eol = true; + continue; + } + + if (ch == ' ') + return true; + + word.push_back(ch); + } + } + catch (const Poco::Net::NetException & e) + { + throw NetException(ErrorCodes::NETWORK_ERROR, "{}, while reading from socket ({})", e.displayText(), socket().peerAddress().toString()); + } + catch (const Poco::TimeoutException &) + { + throw NetException(ErrorCodes::SOCKET_TIMEOUT, "Timeout exceeded while reading from socket ({}, {} ms)", + socket().peerAddress().toString(), + socket().getReceiveTimeout().totalMilliseconds()); + } + catch (const Poco::IOException & e) + { + throw NetException(ErrorCodes::NETWORK_ERROR, "{}, while reading from socket ({})", e.displayText(), socket().peerAddress().toString()); + } + + if (n < 0) + throw NetException(ErrorCodes::CANNOT_READ_FROM_SOCKET, "Cannot read from socket ({})", socket().peerAddress().toString()); + + return false; +} + +} diff --git a/contrib/clickhouse/src/Server/ProxyV1Handler.h b/contrib/clickhouse/src/Server/ProxyV1Handler.h new file mode 100644 index 0000000000..b50c2acbc5 --- /dev/null +++ b/contrib/clickhouse/src/Server/ProxyV1Handler.h @@ -0,0 +1,30 @@ +#pragma once + +#include <Poco/Net/TCPServerConnection.h> +#include <Server/IServer.h> +#include <Server/TCPProtocolStackData.h> + + +namespace DB +{ + +class ProxyV1Handler : public Poco::Net::TCPServerConnection +{ + using StreamSocket = Poco::Net::StreamSocket; +public: + explicit ProxyV1Handler(const StreamSocket & socket, IServer & server_, const std::string & conf_name_, TCPProtocolStackData & stack_data_) + : Poco::Net::TCPServerConnection(socket), log(&Poco::Logger::get("ProxyV1Handler")), server(server_), conf_name(conf_name_), stack_data(stack_data_) {} + + void run() override; + +protected: + bool readWord(int max_len, std::string & word, bool & eol); + 
+private: + Poco::Logger * log; + IServer & server; + std::string conf_name; + TCPProtocolStackData & stack_data; +}; + +} diff --git a/contrib/clickhouse/src/Server/ProxyV1HandlerFactory.h b/contrib/clickhouse/src/Server/ProxyV1HandlerFactory.h new file mode 100644 index 0000000000..028596d745 --- /dev/null +++ b/contrib/clickhouse/src/Server/ProxyV1HandlerFactory.h @@ -0,0 +1,56 @@ +#pragma once + +#include <Poco/Net/NetException.h> +#include <Poco/Net/TCPServerConnection.h> +#include <Server/ProxyV1Handler.h> +#include <Common/logger_useful.h> +#include <Server/IServer.h> +#include <Server/TCPServer.h> +#include <Server/TCPProtocolStackData.h> + + +namespace DB +{ + +class ProxyV1HandlerFactory : public TCPServerConnectionFactory +{ +private: + IServer & server; + Poco::Logger * log; + std::string conf_name; + + class DummyTCPHandler : public Poco::Net::TCPServerConnection + { + public: + using Poco::Net::TCPServerConnection::TCPServerConnection; + void run() override {} + }; + +public: + explicit ProxyV1HandlerFactory(IServer & server_, const std::string & conf_name_) + : server(server_), log(&Poco::Logger::get("ProxyV1HandlerFactory")), conf_name(conf_name_) + { + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override + { + TCPProtocolStackData stack_data; + return createConnection(socket, tcp_server, stack_data); + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/, TCPProtocolStackData & stack_data) override + { + try + { + LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString()); + return new ProxyV1Handler(socket, server, conf_name, stack_data); + } + catch (const Poco::Net::NetException &) + { + LOG_TRACE(log, "TCP Request. 
Client is not connected (most likely RST packet was sent)."); + return new DummyTCPHandler(socket); + } + } +}; + +} diff --git a/contrib/clickhouse/src/Server/ReplicasStatusHandler.cpp b/contrib/clickhouse/src/Server/ReplicasStatusHandler.cpp new file mode 100644 index 0000000000..8c0ab0c1a3 --- /dev/null +++ b/contrib/clickhouse/src/Server/ReplicasStatusHandler.cpp @@ -0,0 +1,128 @@ +#include <Server/ReplicasStatusHandler.h> + +#include <Databases/IDatabase.h> +#include <IO/HTTPCommon.h> +#include <Interpreters/Context.h> +#include <Server/HTTP/HTMLForm.h> +#include <Server/HTTPHandlerFactory.h> +#include <Server/HTTPHandlerRequestFilter.h> +#include <Server/IServer.h> +#include <Storages/StorageReplicatedMergeTree.h> +#include <Common/typeid_cast.h> + +#include <Poco/Net/HTTPRequestHandlerFactory.h> +#include <Poco/Net/HTTPServerRequest.h> +#include <Poco/Net/HTTPServerResponse.h> + + +namespace DB +{ + +ReplicasStatusHandler::ReplicasStatusHandler(IServer & server) : WithContext(server.context()) +{ +} + +void ReplicasStatusHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) +{ + try + { + HTMLForm params(getContext()->getSettingsRef(), request); + + /// Even if lag is small, output detailed information about the lag. + bool verbose = params.get("verbose", "") == "1"; + + const MergeTreeSettings & settings = getContext()->getReplicatedMergeTreeSettings(); + + bool ok = true; + WriteBufferFromOwnString message; + + auto databases = DatabaseCatalog::instance().getDatabases(); + + /// Iterate through all the replicated tables. 
+ for (const auto & db : databases) + { + /// Check if database can contain replicated tables + if (!db.second->canContainMergeTreeTables()) + continue; + + for (auto iterator = db.second->getTablesIterator(getContext()); iterator->isValid(); iterator->next()) + { + const auto & table = iterator->table(); + if (!table) + continue; + + StorageReplicatedMergeTree * table_replicated = dynamic_cast<StorageReplicatedMergeTree *>(table.get()); + + if (!table_replicated) + continue; + + time_t absolute_delay = 0; + time_t relative_delay = 0; + + if (!table_replicated->isTableReadOnly()) + { + table_replicated->getReplicaDelays(absolute_delay, relative_delay); + + if ((settings.min_absolute_delay_to_close && absolute_delay >= static_cast<time_t>(settings.min_absolute_delay_to_close)) + || (settings.min_relative_delay_to_close && relative_delay >= static_cast<time_t>(settings.min_relative_delay_to_close))) + ok = false; + + message << backQuoteIfNeed(db.first) << "." << backQuoteIfNeed(iterator->name()) + << ":\tAbsolute delay: " << absolute_delay << ". Relative delay: " << relative_delay << ".\n"; + } + else + { + message << backQuoteIfNeed(db.first) << "." << backQuoteIfNeed(iterator->name()) + << ":\tis readonly. \n"; + } + } + } + + const auto & config = getContext()->getConfigRef(); + setResponseDefaultHeaders(response, config.getUInt("keep_alive_timeout", 10)); + + if (!ok) + { + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_SERVICE_UNAVAILABLE); + verbose = true; + } + + if (verbose) + *response.send() << message.str(); + else + { + const char * data = "Ok.\n"; + response.sendBuffer(data, strlen(data)); + } + } + catch (...) + { + tryLogCurrentException("ReplicasStatusHandler"); + + try + { + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR); + + if (!response.sent()) + { + /// We have not sent anything yet and we don't even know if we need to compress response. 
+ *response.send() << getCurrentExceptionMessage(false) << std::endl; + } + } + catch (...) + { + LOG_ERROR((&Poco::Logger::get("ReplicasStatusHandler")), "Cannot send exception to client"); + } + } +} + +HTTPRequestHandlerFactoryPtr createReplicasStatusHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix) +{ + auto factory = std::make_shared<HandlingRuleHTTPHandlerFactory<ReplicasStatusHandler>>(server); + factory->addFiltersFromConfig(config, config_prefix); + return factory; +} + +} diff --git a/contrib/clickhouse/src/Server/ReplicasStatusHandler.h b/contrib/clickhouse/src/Server/ReplicasStatusHandler.h new file mode 100644 index 0000000000..1a5388aa2a --- /dev/null +++ b/contrib/clickhouse/src/Server/ReplicasStatusHandler.h @@ -0,0 +1,21 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandler.h> + +namespace DB +{ + +class Context; +class IServer; + +/// Replies "Ok.\n" if all replicas on this server don't lag too much. Otherwise output lag information. 
+class ReplicasStatusHandler : public HTTPRequestHandler, WithContext +{ +public: + explicit ReplicasStatusHandler(IServer & server_); + + void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; +}; + + +} diff --git a/contrib/clickhouse/src/Server/ServerType.cpp b/contrib/clickhouse/src/Server/ServerType.cpp new file mode 100644 index 0000000000..fb052e7d6e --- /dev/null +++ b/contrib/clickhouse/src/Server/ServerType.cpp @@ -0,0 +1,153 @@ +#include <Server/ServerType.h> + +#include <vector> +#include <algorithm> + +#include <magic_enum.hpp> + + +namespace DB +{ + +namespace +{ + std::vector<std::string> getTypeIndexToTypeName() + { + constexpr std::size_t types_size = magic_enum::enum_count<ServerType::Type>(); + + std::vector<std::string> type_index_to_type_name; + type_index_to_type_name.resize(types_size); + + auto entries = magic_enum::enum_entries<ServerType::Type>(); + for (const auto & [entry, str] : entries) + { + auto str_copy = String(str); + std::replace(str_copy.begin(), str_copy.end(), '_', ' '); + type_index_to_type_name[static_cast<UInt64>(entry)] = std::move(str_copy); + } + + return type_index_to_type_name; + } +} + +const char * ServerType::serverTypeToString(ServerType::Type type) +{ + /** During parsing if SystemQuery is not parsed properly it is added to Expected variants as description check IParser.h. + * Description string must be statically allocated. 
+ */ + static std::vector<std::string> type_index_to_type_name = getTypeIndexToTypeName(); + const auto & type_name = type_index_to_type_name[static_cast<UInt64>(type)]; + return type_name.data(); +} + +bool ServerType::shouldStart(Type server_type, const std::string & server_custom_name) const +{ + auto is_type_default = [](Type current_type) + { + switch (current_type) + { + case Type::TCP: + case Type::TCP_WITH_PROXY: + case Type::TCP_SECURE: + case Type::HTTP: + case Type::HTTPS: + case Type::MYSQL: + case Type::GRPC: + case Type::POSTGRESQL: + case Type::PROMETHEUS: + case Type::INTERSERVER_HTTP: + case Type::INTERSERVER_HTTPS: + return true; + default: + return false; + } + }; + + if (exclude_types.contains(Type::QUERIES_ALL)) + return false; + + if (exclude_types.contains(Type::QUERIES_DEFAULT) && is_type_default(server_type)) + return false; + + if (exclude_types.contains(Type::QUERIES_CUSTOM) && server_type == Type::CUSTOM) + return false; + + if (exclude_types.contains(server_type)) + { + if (server_type != Type::CUSTOM) + return false; + + if (exclude_custom_names.contains(server_custom_name)) + return false; + } + + if (type == Type::QUERIES_ALL) + return true; + + if (type == Type::QUERIES_DEFAULT) + return is_type_default(server_type); + + if (type == Type::QUERIES_CUSTOM) + return server_type == Type::CUSTOM; + + if (type == Type::CUSTOM) + return server_type == type && server_custom_name == custom_name; + + return server_type == type; +} + +bool ServerType::shouldStop(const std::string & port_name) const +{ + Type port_type; + std::string port_custom_name; + + if (port_name == "http_port") + port_type = Type::HTTP; + + else if (port_name == "https_port") + port_type = Type::HTTPS; + + else if (port_name == "tcp_port") + port_type = Type::TCP; + + else if (port_name == "tcp_with_proxy_port") + port_type = Type::TCP_WITH_PROXY; + + else if (port_name == "tcp_port_secure") + port_type = Type::TCP_SECURE; + + else if (port_name == "mysql_port") + 
port_type = Type::MYSQL; + + else if (port_name == "postgresql_port") + port_type = Type::POSTGRESQL; + + else if (port_name == "grpc_port") + port_type = Type::GRPC; + + else if (port_name == "prometheus.port") + port_type = Type::PROMETHEUS; + + else if (port_name == "interserver_http_port") + port_type = Type::INTERSERVER_HTTP; + + else if (port_name == "interserver_https_port") + port_type = Type::INTERSERVER_HTTPS; + + else if (port_name.starts_with("protocols.") && port_name.ends_with(".port")) + { + port_type = Type::CUSTOM; + + constexpr size_t protocols_size = std::string_view("protocols.").size(); + constexpr size_t ports_size = std::string_view(".ports").size(); + + port_custom_name = port_name.substr(protocols_size, port_name.size() - protocols_size - ports_size + 1); + } + + else + return false; + + return shouldStart(port_type, port_custom_name); +} + +} diff --git a/contrib/clickhouse/src/Server/ServerType.h b/contrib/clickhouse/src/Server/ServerType.h new file mode 100644 index 0000000000..e3544fe6a2 --- /dev/null +++ b/contrib/clickhouse/src/Server/ServerType.h @@ -0,0 +1,60 @@ +#pragma once + +#include <base/types.h> +#include <unordered_set> + +namespace DB +{ + +class ServerType +{ +public: + enum Type + { + TCP, + TCP_WITH_PROXY, + TCP_SECURE, + HTTP, + HTTPS, + MYSQL, + GRPC, + POSTGRESQL, + PROMETHEUS, + CUSTOM, + INTERSERVER_HTTP, + INTERSERVER_HTTPS, + QUERIES_ALL, + QUERIES_DEFAULT, + QUERIES_CUSTOM, + END + }; + + using Types = std::unordered_set<Type>; + using CustomNames = std::unordered_set<String>; + + ServerType() = default; + + explicit ServerType( + Type type_, + const std::string & custom_name_ = "", + const Types & exclude_types_ = {}, + const CustomNames exclude_custom_names_ = {}) + : type(type_), + custom_name(custom_name_), + exclude_types(exclude_types_), + exclude_custom_names(exclude_custom_names_) {} + + static const char * serverTypeToString(Type type); + + /// Checks whether provided in the arguments type should be 
started or stopped based on current server type. + bool shouldStart(Type server_type, const std::string & server_custom_name = "") const; + bool shouldStop(const std::string & port_name) const; + + Type type; + std::string custom_name; + + Types exclude_types; + CustomNames exclude_custom_names; +}; + +} diff --git a/contrib/clickhouse/src/Server/StaticRequestHandler.cpp b/contrib/clickhouse/src/Server/StaticRequestHandler.cpp new file mode 100644 index 0000000000..13a01ba813 --- /dev/null +++ b/contrib/clickhouse/src/Server/StaticRequestHandler.cpp @@ -0,0 +1,179 @@ +#include "StaticRequestHandler.h" +#include "IServer.h" + +#include "HTTPHandlerFactory.h" +#include "HTTPHandlerRequestFilter.h" + +#include <IO/HTTPCommon.h> +#include <IO/ReadBufferFromFile.h> +#include <IO/WriteBufferFromString.h> +#include <IO/copyData.h> +#include <IO/WriteHelpers.h> +#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h> +#include <Interpreters/Context.h> + +#include <Common/Exception.h> + +#include <Poco/Net/HTTPServerRequest.h> +#include <Poco/Net/HTTPServerResponse.h> +#include <Poco/Net/HTTPRequestHandlerFactory.h> +#include <Poco/Util/LayeredConfiguration.h> +#include <filesystem> + + +namespace fs = std::filesystem; + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INCORRECT_FILE_NAME; + extern const int HTTP_LENGTH_REQUIRED; + extern const int INVALID_CONFIG_PARAMETER; +} + +static inline WriteBufferPtr +responseWriteBuffer(HTTPServerRequest & request, HTTPServerResponse & response, unsigned int keep_alive_timeout) +{ + /// The client can pass a HTTP header indicating supported compression method (gzip or deflate). 
+ String http_response_compression_methods = request.get("Accept-Encoding", ""); + CompressionMethod http_response_compression_method = CompressionMethod::None; + + if (!http_response_compression_methods.empty()) + http_response_compression_method = chooseHTTPCompressionMethod(http_response_compression_methods); + + bool client_supports_http_compression = http_response_compression_method != CompressionMethod::None; + + return std::make_shared<WriteBufferFromHTTPServerResponse>( + response, + request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, + keep_alive_timeout, + client_supports_http_compression, + http_response_compression_method); +} + +static inline void trySendExceptionToClient( + const std::string & s, int exception_code, HTTPServerRequest & request, HTTPServerResponse & response, WriteBuffer & out) +{ + try + { + response.set("X-ClickHouse-Exception-Code", toString<int>(exception_code)); + + /// If HTTP method is POST and Keep-Alive is turned on, we should read the whole request body + /// to avoid reading part of the current request body in the next request. + if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST + && response.getKeepAlive() && !request.getStream().eof() && exception_code != ErrorCodes::HTTP_LENGTH_REQUIRED) + request.getStream().ignore(std::numeric_limits<std::streamsize>::max()); + + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR); + + if (!response.sent()) + *response.send() << s << std::endl; + else + { + if (out.count() != out.offset()) + out.position() = out.buffer().begin(); + + writeString(s, out); + writeChar('\n', out); + + out.next(); + out.finalize(); + } + } + catch (...) 
+ { + tryLogCurrentException("StaticRequestHandler", "Cannot send exception to client"); + } +} + +void StaticRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) +{ + auto keep_alive_timeout = server.config().getUInt("keep_alive_timeout", 10); + const auto & out = responseWriteBuffer(request, response, keep_alive_timeout); + + try + { + response.setContentType(content_type); + + if (request.getVersion() == Poco::Net::HTTPServerRequest::HTTP_1_1) + response.setChunkedTransferEncoding(true); + + /// Workaround. Poco does not detect 411 Length Required case. + if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST && !request.getChunkedTransferEncoding() && !request.hasContentLength()) + throw Exception(ErrorCodes::HTTP_LENGTH_REQUIRED, + "The Transfer-Encoding is not chunked and there " + "is no Content-Length header for POST request"); + + setResponseDefaultHeaders(response, keep_alive_timeout); + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTPStatus(status)); + writeResponse(*out); + } + catch (...) 
+ { + tryLogCurrentException("StaticRequestHandler"); + + int exception_code = getCurrentExceptionCode(); + std::string exception_message = getCurrentExceptionMessage(false, true); + trySendExceptionToClient(exception_message, exception_code, request, response, *out); + } + + out->finalize(); +} + +void StaticRequestHandler::writeResponse(WriteBuffer & out) +{ + static const String file_prefix = "file://"; + static const String config_prefix = "config://"; + + if (startsWith(response_expression, file_prefix)) + { + auto file_name = response_expression.substr(file_prefix.size(), response_expression.size() - file_prefix.size()); + if (file_name.starts_with('/')) + file_name = file_name.substr(1); + + fs::path user_files_absolute_path = fs::canonical(fs::path(server.context()->getUserFilesPath())); + String file_path = fs::weakly_canonical(user_files_absolute_path / file_name); + + if (!fs::exists(file_path)) + throw Exception(ErrorCodes::INCORRECT_FILE_NAME, "Invalid file name {} for static HTTPHandler. 
", file_path); + + ReadBufferFromFile in(file_path); + copyData(in, out); + } + else if (startsWith(response_expression, config_prefix)) + { + if (response_expression.size() <= config_prefix.size()) + throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER, + "Static handling rule handler must contain a complete configuration path, for example: " + "config://config_key"); + + const auto & config_path = response_expression.substr(config_prefix.size(), response_expression.size() - config_prefix.size()); + writeString(server.config().getRawString(config_path, "Ok.\n"), out); + } + else + writeString(response_expression, out); +} + +StaticRequestHandler::StaticRequestHandler(IServer & server_, const String & expression, int status_, const String & content_type_) + : server(server_), status(status_), content_type(content_type_), response_expression(expression) +{ +} + +HTTPRequestHandlerFactoryPtr createStaticHandlerFactory(IServer & server, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix) +{ + int status = config.getInt(config_prefix + ".handler.status", 200); + std::string response_content = config.getRawString(config_prefix + ".handler.response_content", "Ok.\n"); + std::string response_content_type = config.getString(config_prefix + ".handler.content_type", "text/plain; charset=UTF-8"); + auto factory = std::make_shared<HandlingRuleHTTPHandlerFactory<StaticRequestHandler>>( + server, std::move(response_content), std::move(status), std::move(response_content_type)); + + factory->addFiltersFromConfig(config, config_prefix); + + return factory; +} + +} diff --git a/contrib/clickhouse/src/Server/StaticRequestHandler.h b/contrib/clickhouse/src/Server/StaticRequestHandler.h new file mode 100644 index 0000000000..df9374d440 --- /dev/null +++ b/contrib/clickhouse/src/Server/StaticRequestHandler.h @@ -0,0 +1,35 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandler.h> +#include <base/types.h> + + +namespace DB +{ + +class IServer; 
+class WriteBuffer; + +/// Response with custom string. Can be used for browser. +class StaticRequestHandler : public HTTPRequestHandler +{ +private: + IServer & server; + + int status; + String content_type; + String response_expression; + +public: + StaticRequestHandler( + IServer & server, + const String & expression, + int status_ = 200, + const String & content_type_ = "text/html; charset=UTF-8"); + + void writeResponse(WriteBuffer & out); + + void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; +}; + +} diff --git a/contrib/clickhouse/src/Server/TCPHandler.cpp b/contrib/clickhouse/src/Server/TCPHandler.cpp new file mode 100644 index 0000000000..136f2dd953 --- /dev/null +++ b/contrib/clickhouse/src/Server/TCPHandler.cpp @@ -0,0 +1,2151 @@ +#include <algorithm> +#include <iterator> +#include <memory> +#include <mutex> +#include <vector> +#include <string_view> +#include <cstring> +#include <base/types.h> +#include <base/scope_guard.h> +#include <Poco/Net/NetException.h> +#include <Poco/Net/SocketAddress.h> +#include <Poco/Util/LayeredConfiguration.h> +#include <Common/CurrentThread.h> +#include <Common/Stopwatch.h> +#include <Common/NetException.h> +#include <Common/setThreadName.h> +#include <Common/OpenSSLHelpers.h> +#include <IO/Progress.h> +#include <Compression/CompressedReadBuffer.h> +#include <Compression/CompressedWriteBuffer.h> +#include <IO/ReadBufferFromPocoSocket.h> +#include <IO/WriteBufferFromPocoSocket.h> +#include <IO/LimitReadBuffer.h> +#include <IO/ReadHelpers.h> +#include <IO/WriteHelpers.h> +#include <Formats/NativeReader.h> +#include <Formats/NativeWriter.h> +#include <Interpreters/executeQuery.h> +#include <Interpreters/TablesStatus.h> +#include <Interpreters/InternalTextLogsQueue.h> +#include <Interpreters/OpenTelemetrySpanLog.h> +#include <Interpreters/Session.h> +#include <Server/TCPServer.h> +#include <Storages/StorageReplicatedMergeTree.h> +#include <Storages/MergeTree/MergeTreeDataPartUUID.h> 
+#include <Storages/StorageS3Cluster.h> +#include <Core/ExternalTable.h> +#include <Core/ServerSettings.h> +#include <Access/AccessControl.h> +#include <Access/Credentials.h> +#include <DataTypes/DataTypeLowCardinality.h> +#include <Compression/CompressionFactory.h> +#include <Common/logger_useful.h> +#include <Common/CurrentMetrics.h> +#include <Common/thread_local_rng.h> +#include <fmt/format.h> + +#include <Processors/Executors/PullingAsyncPipelineExecutor.h> +#include <Processors/Executors/PushingPipelineExecutor.h> +#include <Processors/Executors/PushingAsyncPipelineExecutor.h> +#include <Processors/Executors/CompletedPipelineExecutor.h> +#include <Processors/Sinks/SinkToStorage.h> + +#if USE_SSL +# include <Poco/Net/SecureStreamSocket.h> +# include <Poco/Net/SecureStreamSocketImpl.h> +#endif + +#include "Core/Protocol.h" +#include "Storages/MergeTree/RequestResponse.h" +#include "TCPHandler.h" + +#include "config_version.h" + +using namespace std::literals; +using namespace DB; + + +namespace CurrentMetrics +{ + extern const Metric QueryThread; + extern const Metric ReadTaskRequestsSent; + extern const Metric MergeTreeReadTaskRequestsSent; + extern const Metric MergeTreeAllRangesAnnouncementsSent; +} + +namespace ProfileEvents +{ + extern const Event ReadTaskRequestsSent; + extern const Event MergeTreeReadTaskRequestsSent; + extern const Event MergeTreeAllRangesAnnouncementsSent; + extern const Event ReadTaskRequestsSentElapsedMicroseconds; + extern const Event MergeTreeReadTaskRequestsSentElapsedMicroseconds; + extern const Event MergeTreeAllRangesAnnouncementsSentElapsedMicroseconds; +} + +namespace DB::ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int ATTEMPT_TO_READ_AFTER_EOF; + extern const int CLIENT_HAS_CONNECTED_TO_WRONG_PORT; + extern const int UNKNOWN_EXCEPTION; + extern const int UNKNOWN_PACKET_FROM_CLIENT; + extern const int POCO_EXCEPTION; + extern const int SOCKET_TIMEOUT; + extern const int UNEXPECTED_PACKET_FROM_CLIENT; + 
extern const int UNKNOWN_PROTOCOL; + extern const int AUTHENTICATION_FAILED; + extern const int QUERY_WAS_CANCELLED; + extern const int CLIENT_INFO_DOES_NOT_MATCH; +} + +namespace +{ +NameToNameMap convertToQueryParameters(const Settings & passed_params) +{ + NameToNameMap query_parameters; + for (const auto & param : passed_params) + { + std::string value; + ReadBufferFromOwnString buf(param.getValueString()); + readQuoted(value, buf); + query_parameters.emplace(param.getName(), value); + } + return query_parameters; +} + +// This function corrects the wrong client_name from the old client. +// Old clients 28.7 and some intermediate versions of 28.7 were sending different ClientInfo.client_name +// "ClickHouse client" was sent with the hello message. +// "ClickHouse" or "ClickHouse " was sent with the query message. +void correctQueryClientInfo(const ClientInfo & session_client_info, ClientInfo & client_info) +{ + if (client_info.getVersionNumber() <= VersionNumber(23, 8, 1) && + session_client_info.client_name == "ClickHouse client" && + (client_info.client_name == "ClickHouse" || client_info.client_name == "ClickHouse ")) + { + client_info.client_name = "ClickHouse client"; + } +} + +void validateClientInfo(const ClientInfo & session_client_info, const ClientInfo & client_info) +{ + // Secondary query may contain different client_info. + // In the case of select from distributed table or 'select * from remote' from non-tcp handler. Server sends the initial client_info data. + // + // Example 1: curl -q -s --max-time 60 -sS "http://127.0.0.1:8123/?" -d "SELECT 1 FROM remote('127.0.0.1', system.one)" + // HTTP handler initiates TCP connection with remote 127.0.0.1 (session on remote 127.0.0.1 use TCP interface) + // HTTP handler sends client_info with HTTP interface and HTTP data by TCP protocol in Protocol::Client::Query message. 
+ // + // Example 2: select * from <distributed_table> --host shard_1 // distributed table has 2 shards: shard_1, shard_2 + // shard_1 receives a message with 'ClickHouse client' client_name + // shard_1 initiates TCP connection with shard_2 with 'ClickHouse server' client_name. + // shard_1 sends 'ClickHouse client' client_name in Protocol::Client::Query message to shard_2. + if (client_info.query_kind == ClientInfo::QueryKind::SECONDARY_QUERY) + return; + + if (session_client_info.interface != client_info.interface) + { + throw Exception( + DB::ErrorCodes::CLIENT_INFO_DOES_NOT_MATCH, + "Client info's interface does not match: {} not equal to {}", + toString(session_client_info.interface), + toString(client_info.interface)); + } + + if (session_client_info.interface == ClientInfo::Interface::TCP) + { + if (session_client_info.client_name != client_info.client_name) + throw Exception( + DB::ErrorCodes::CLIENT_INFO_DOES_NOT_MATCH, + "Client info's client_name does not match: {} not equal to {}", + session_client_info.client_name, + client_info.client_name); + + // TCP handler got patch version 0 always for backward compatibility. + if (!session_client_info.clientVersionEquals(client_info, false)) + throw Exception( + DB::ErrorCodes::CLIENT_INFO_DOES_NOT_MATCH, + "Client info's version does not match: {} not equal to {}", + session_client_info.getVersionStr(), + client_info.getVersionStr()); + + // os_user, quota_key, client_trace_context can be different. 
+ } +} +} + +namespace DB +{ + +TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool parse_proxy_protocol_, std::string server_display_name_) + : Poco::Net::TCPServerConnection(socket_) + , server(server_) + , tcp_server(tcp_server_) + , parse_proxy_protocol(parse_proxy_protocol_) + , log(&Poco::Logger::get("TCPHandler")) + , server_display_name(std::move(server_display_name_)) +{ +} + +TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, TCPProtocolStackData & stack_data, std::string server_display_name_) +: Poco::Net::TCPServerConnection(socket_) + , server(server_) + , tcp_server(tcp_server_) + , log(&Poco::Logger::get("TCPHandler")) + , forwarded_for(stack_data.forwarded_for) + , certificate(stack_data.certificate) + , default_database(stack_data.default_database) + , server_display_name(std::move(server_display_name_)) +{ + if (!forwarded_for.empty()) + LOG_TRACE(log, "Forwarded client address: {}", forwarded_for); +} + +TCPHandler::~TCPHandler() +{ + try + { + state.reset(); + if (out) + out->next(); + } + catch (...) + { + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + +void TCPHandler::runImpl() +{ + setThreadName("TCPHandler"); + ThreadStatus thread_status; + + extractConnectionSettingsFromContext(server.context()); + + socket().setReceiveTimeout(receive_timeout); + socket().setSendTimeout(send_timeout); + socket().setNoDelay(true); + + in = std::make_shared<ReadBufferFromPocoSocket>(socket()); + out = std::make_shared<WriteBufferFromPocoSocket>(socket()); + + /// Support for PROXY protocol + if (parse_proxy_protocol && !receiveProxyHeader()) + return; + + if (in->eof()) + { + LOG_INFO(log, "Client has not sent any data."); + return; + } + + /// User will be authenticated here. It will also set settings from user profile into connection_context. 
+ try + { + receiveHello(); + + /// In interserver mode queries are executed without a session context. + if (!is_interserver_mode) + session->makeSessionContext(); + + sendHello(); + if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM) + receiveAddendum(); + + if (!is_interserver_mode) + { + /// If session created, then settings in session context has been updated. + /// So it's better to update the connection settings for flexibility. + extractConnectionSettingsFromContext(session->sessionContext()); + + /// When connecting, the default database could be specified. + if (!default_database.empty()) + session->sessionContext()->setCurrentDatabase(default_database); + } + } + catch (const Exception & e) /// Typical for an incorrect username, password, or address. + { + if (e.code() == ErrorCodes::CLIENT_HAS_CONNECTED_TO_WRONG_PORT) + { + LOG_DEBUG(log, "Client has connected to wrong port."); + return; + } + + if (e.code() == ErrorCodes::ATTEMPT_TO_READ_AFTER_EOF) + { + LOG_INFO(log, "Client has gone away."); + return; + } + + try + { + /// We try to send error information to the client. + sendException(e, send_exception_with_stack_trace); + } + catch (...) {} + + throw; + } + + while (tcp_server.isOpen()) + { + /// We are waiting for a packet from the client. Thus, every `poll_interval` seconds check whether we need to shut down. + { + Stopwatch idle_time; + UInt64 timeout_ms = std::min(poll_interval, idle_connection_timeout) * 1000000; + while (tcp_server.isOpen() && !server.isCancelled() && !static_cast<ReadBufferFromPocoSocket &>(*in).poll(timeout_ms)) + { + if (idle_time.elapsedSeconds() > idle_connection_timeout) + { + LOG_TRACE(log, "Closing idle connection"); + return; + } + } + } + + /// If we need to shut down, or client disconnects. 
+ if (!tcp_server.isOpen() || server.isCancelled() || in->eof()) + { + LOG_TEST(log, "Closing connection (open: {}, cancelled: {}, eof: {})", tcp_server.isOpen(), server.isCancelled(), in->eof()); + break; + } + + state.reset(); + + /// Initialized later. + std::optional<CurrentThread::QueryScope> query_scope; + OpenTelemetry::TracingContextHolderPtr thread_trace_context; + + /** An exception during the execution of request (it must be sent over the network to the client). + * The client will be able to accept it, if it did not happen while sending another packet and the client has not disconnected yet. + */ + std::unique_ptr<DB::Exception> exception; + bool network_error = false; + bool query_duration_already_logged = false; + auto log_query_duration = [this, &query_duration_already_logged]() + { + if (query_duration_already_logged) + return; + query_duration_already_logged = true; + auto elapsed_sec = state.watch.elapsedSeconds(); + /// We already logged more detailed info if we read some rows + if (elapsed_sec < 1.0 && state.progress.read_rows) + return; + LOG_DEBUG(log, "Processed in {} sec.", elapsed_sec); + }; + + try + { + /// If a user passed query-local timeouts, reset socket to initial state at the end of the query + SCOPE_EXIT({state.timeout_setter.reset();}); + + /** If Query - process it. If Ping or Cancel - go back to the beginning. + * There may come settings for a separate query that modify `query_context`. + * It's possible to receive part uuids packet before the query, so then receivePacket has to be called twice. + */ + if (!receivePacket()) + continue; + + /** If part_uuids got received in previous packet, trying to read again. 
+ */ + if (state.empty() && state.part_uuids_to_ignore && !receivePacket()) + continue; + + /// Set up tracing context for this query on current thread + thread_trace_context = std::make_unique<OpenTelemetry::TracingContextHolder>("TCPHandler", + query_context->getClientInfo().client_trace_context, + query_context->getSettingsRef(), + query_context->getOpenTelemetrySpanLog()); + thread_trace_context->root_span.kind = OpenTelemetry::SERVER; + + query_scope.emplace(query_context, /* fatal_error_callback */ [this] + { + std::lock_guard lock(fatal_error_mutex); + sendLogs(); + }); + + /// If query received, then settings in query_context has been updated. + /// So it's better to update the connection settings for flexibility. + extractConnectionSettingsFromContext(query_context); + + /// Sync timeouts on client and server during current query to avoid dangling queries on server + /// NOTE: We use send_timeout for the receive timeout and vice versa (change arguments ordering in TimeoutSetter), + /// because send_timeout is client-side setting which has opposite meaning on the server side. + /// NOTE: these settings are applied only for current connection (not for distributed tables' connections) + state.timeout_setter = std::make_unique<TimeoutSetter>(socket(), receive_timeout, send_timeout); + + /// Should we send internal logs to client? 
+ const auto client_logs_level = query_context->getSettingsRef().send_logs_level; + if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SERVER_LOGS + && client_logs_level != LogsLevel::none) + { + state.logs_queue = std::make_shared<InternalTextLogsQueue>(); + state.logs_queue->max_priority = Poco::Logger::parseLevel(client_logs_level.toString()); + state.logs_queue->setSourceRegexp(query_context->getSettingsRef().send_logs_source_regexp); + CurrentThread::attachInternalTextLogsQueue(state.logs_queue, client_logs_level); + } + if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_INCREMENTAL_PROFILE_EVENTS) + { + state.profile_queue = std::make_shared<InternalProfileEventsQueue>(std::numeric_limits<int>::max()); + CurrentThread::attachInternalProfileEventsQueue(state.profile_queue); + } + + query_context->setExternalTablesInitializer([this] (ContextPtr context) + { + if (context != query_context) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in external tables initializer"); + + /// Get blocks of temporary tables + readData(); + + /// Reset the input stream, as we received an empty block while receiving external table data. + /// So, the stream has been marked as cancelled and we can't read from it anymore. + state.block_in.reset(); + state.maybe_compressed_in.reset(); /// For more accurate accounting by MemoryTracker. + }); + + /// Send structure of columns to client for function input() + query_context->setInputInitializer([this] (ContextPtr context, const StoragePtr & input_storage) + { + if (context != query_context) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in Input initializer"); + + auto metadata_snapshot = input_storage->getInMemoryMetadataPtr(); + state.need_receive_data_for_input = true; + + /// Send ColumnsDescription for input storage. 
+ if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_COLUMN_DEFAULTS_METADATA + && query_context->getSettingsRef().input_format_defaults_for_omitted_fields) + { + sendTableColumns(metadata_snapshot->getColumns()); + } + + /// Send block to the client - input storage structure. + state.input_header = metadata_snapshot->getSampleBlock(); + sendData(state.input_header); + sendTimezone(); + }); + + query_context->setInputBlocksReaderCallback([this] (ContextPtr context) -> Block + { + if (context != query_context) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected context in InputBlocksReader"); + + if (!readDataNext()) + { + state.block_in.reset(); + state.maybe_compressed_in.reset(); + return Block(); + } + return state.block_for_input; + }); + + customizeContext(query_context); + + /// This callback is needed for requesting read tasks inside pipeline for distributed processing + query_context->setReadTaskCallback([this]() -> String + { + Stopwatch watch; + CurrentMetrics::Increment callback_metric_increment(CurrentMetrics::ReadTaskRequestsSent); + + std::lock_guard lock(task_callback_mutex); + + if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED) + return {}; + + sendReadTaskRequestAssumeLocked(); + ProfileEvents::increment(ProfileEvents::ReadTaskRequestsSent); + auto res = receiveReadTaskResponseAssumeLocked(); + ProfileEvents::increment(ProfileEvents::ReadTaskRequestsSentElapsedMicroseconds, watch.elapsedMicroseconds()); + return res; + }); + + query_context->setMergeTreeAllRangesCallback([this](InitialAllRangesAnnouncement announcement) + { + Stopwatch watch; + CurrentMetrics::Increment callback_metric_increment(CurrentMetrics::MergeTreeAllRangesAnnouncementsSent); + std::lock_guard lock(task_callback_mutex); + + if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED) + return; + + sendMergeTreeAllRangesAnnounecementAssumeLocked(announcement); + 
ProfileEvents::increment(ProfileEvents::MergeTreeAllRangesAnnouncementsSent); + ProfileEvents::increment(ProfileEvents::MergeTreeAllRangesAnnouncementsSentElapsedMicroseconds, watch.elapsedMicroseconds()); + }); + + query_context->setMergeTreeReadTaskCallback([this](ParallelReadRequest request) -> std::optional<ParallelReadResponse> + { + Stopwatch watch; + CurrentMetrics::Increment callback_metric_increment(CurrentMetrics::MergeTreeReadTaskRequestsSent); + std::lock_guard lock(task_callback_mutex); + + if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED) + return std::nullopt; + + sendMergeTreeReadTaskRequestAssumeLocked(std::move(request)); + ProfileEvents::increment(ProfileEvents::MergeTreeReadTaskRequestsSent); + auto res = receivePartitionMergeTreeReadTaskResponseAssumeLocked(); + ProfileEvents::increment(ProfileEvents::MergeTreeReadTaskRequestsSentElapsedMicroseconds, watch.elapsedMicroseconds()); + return res; + }); + + /// Processing Query + state.io = executeQuery(state.query, query_context, false, state.stage); + + after_check_cancelled.restart(); + after_send_progress.restart(); + + auto finish_or_cancel = [this]() + { + if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED) + state.io.onCancelOrConnectionLoss(); + else + state.io.onFinish(); + }; + + if (state.io.pipeline.pushing()) + { + /// FIXME: check explicitly that insert query suggests to receive data via native protocol, + state.need_receive_data_for_insert = true; + processInsertQuery(); + finish_or_cancel(); + } + else if (state.io.pipeline.pulling()) + { + processOrdinaryQueryWithProcessors(); + finish_or_cancel(); + } + else if (state.io.pipeline.completed()) + { + { + CompletedPipelineExecutor executor(state.io.pipeline); + + /// Should not check for cancel in case of input. 
+ if (!state.need_receive_data_for_input) + { + auto callback = [this]() + { + std::scoped_lock lock(task_callback_mutex, fatal_error_mutex); + + if (getQueryCancellationStatus() == CancellationStatus::FULLY_CANCELLED) + return true; + + sendProgress(); + sendSelectProfileEvents(); + sendLogs(); + + return false; + }; + + executor.setCancelCallback(callback, interactive_delay / 1000); + } + executor.execute(); + } + + finish_or_cancel(); + + std::lock_guard lock(task_callback_mutex); + + /// Send final progress after calling onFinish(), since it will update the progress. + /// + /// NOTE: we cannot send Progress for regular INSERT (with VALUES) + /// without breaking protocol compatibility, but it can be done + /// by increasing revision. + sendProgress(); + sendSelectProfileEvents(); + } + else + { + finish_or_cancel(); + } + + /// Do it before sending end of stream, to have a chance to show log message in client. + query_scope->logPeakMemoryUsage(); + log_query_duration(); + + if (state.is_connection_closed) + break; + + { + std::lock_guard lock(task_callback_mutex); + sendLogs(); + sendEndOfStream(); + } + + /// QueryState should be cleared before QueryScope, since otherwise + /// the MemoryTracker will be wrong for possible deallocations. + /// (i.e. deallocations from the Aggregator with two-level aggregation) + state.reset(); + last_sent_snapshots = ProfileEvents::ThreadIdToCountersSnapshot{}; + query_scope.reset(); + thread_trace_context.reset(); + } + catch (const Exception & e) + { + state.io.onException(); + exception.reset(e.clone()); + + if (e.code() == ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT) + throw; + + /// If there is UNEXPECTED_PACKET_FROM_CLIENT emulate network_error + /// to break the loop, but do not throw to send the exception to + /// the client. 
+ if (e.code() == ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT) + network_error = true; + + /// If a timeout occurred, try to inform client about it and close the session + if (e.code() == ErrorCodes::SOCKET_TIMEOUT) + network_error = true; + + if (network_error) + LOG_TEST(log, "Going to close connection due to exception: {}", e.message()); + } + catch (const Poco::Net::NetException & e) + { + /** We can get here if there was an error during connection to the client, + * or in connection with a remote server that was used to process the request. + * It is not possible to distinguish between these two cases. + * Although in one of them, we have to send exception to the client, but in the other - we can not. + * We will try to send exception to the client in any case - see below. + */ + state.io.onException(); + exception = std::make_unique<DB::Exception>(Exception::CreateFromPocoTag{}, e); + } + catch (const Poco::Exception & e) + { + state.io.onException(); + exception = std::make_unique<DB::Exception>(Exception::CreateFromPocoTag{}, e); + } +// Server should die on std logic errors in debug, like with assert() +// or ErrorCodes::LOGICAL_ERROR. This helps catch these errors in +// tests. +#ifdef ABORT_ON_LOGICAL_ERROR + catch (const std::logic_error & e) + { + state.io.onException(); + exception = std::make_unique<DB::Exception>(Exception::CreateFromSTDTag{}, e); + sendException(*exception, send_exception_with_stack_trace); + std::abort(); + } +#endif + catch (const std::exception & e) + { + state.io.onException(); + exception = std::make_unique<DB::Exception>(Exception::CreateFromSTDTag{}, e); + } + catch (...) 
+ { + state.io.onException(); + exception = std::make_unique<DB::Exception>(ErrorCodes::UNKNOWN_EXCEPTION, "Unknown exception"); + } + + try + { + if (exception) + { + if (thread_trace_context) + thread_trace_context->root_span.addAttribute(*exception); + + try + { + /// Try to send logs to client, but it could be risky too + /// Assume that we can't break output here + sendLogs(); + } + catch (...) + { + tryLogCurrentException(log, "Can't send logs to client"); + } + + const auto & e = *exception; + LOG_ERROR(log, getExceptionMessageAndPattern(e, send_exception_with_stack_trace)); + sendException(*exception, send_exception_with_stack_trace); + } + } + catch (...) + { + /** Could not send exception information to the client. */ + network_error = true; + LOG_WARNING(log, "Client has gone away."); + } + + try + { + /// A query packet is always followed by one or more data packets. + /// If some of those data packets are left, try to skip them. + if (exception && !state.empty() && !state.read_all_data) + skipData(); + } + catch (...) + { + network_error = true; + LOG_WARNING(log, "Can't skip data packets after query failure."); + } + + log_query_duration(); + + /// QueryState should be cleared before QueryScope, since otherwise + /// the MemoryTracker will be wrong for possible deallocations. + /// (i.e. deallocations from the Aggregator with two-level aggregation) + state.reset(); + query_scope.reset(); + thread_trace_context.reset(); + + /// It is important to destroy query context here. We do not want it to live arbitrarily longer than the query. + query_context.reset(); + + if (is_interserver_mode) + { + /// We don't really have session in interserver mode, new one is created for each query. It's better to reset it now. 
+ session.reset(); + } + + if (network_error) + break; + } +} + + +void TCPHandler::extractConnectionSettingsFromContext(const ContextPtr & context) +{ + const auto & settings = context->getSettingsRef(); + send_exception_with_stack_trace = settings.calculate_text_stack_trace; + send_timeout = settings.send_timeout; + receive_timeout = settings.receive_timeout; + poll_interval = settings.poll_interval; + idle_connection_timeout = settings.idle_connection_timeout; + interactive_delay = settings.interactive_delay; + sleep_in_send_tables_status = settings.sleep_in_send_tables_status_ms; + unknown_packet_in_send_data = settings.unknown_packet_in_send_data; + sleep_after_receiving_query = settings.sleep_after_receiving_query_ms; +} + + +bool TCPHandler::readDataNext() +{ + Stopwatch watch(CLOCK_MONOTONIC_COARSE); + + /// Poll interval should not be greater than receive_timeout + constexpr UInt64 min_timeout_us = 5000; // 5 ms + UInt64 timeout_us = std::max(min_timeout_us, std::min(poll_interval * 1000000, static_cast<UInt64>(receive_timeout.totalMicroseconds()))); + bool read_ok = false; + + /// We are waiting for a packet from the client. Thus, every `POLL_INTERVAL` seconds check whether we need to shut down. + while (true) + { + if (static_cast<ReadBufferFromPocoSocket &>(*in).poll(timeout_us)) + { + /// If client disconnected. + if (in->eof()) + { + LOG_INFO(log, "Client has dropped the connection, cancel the query."); + state.is_connection_closed = true; + state.cancellation_status = CancellationStatus::FULLY_CANCELLED; + break; + } + + /// We accept and process data. + read_ok = receivePacket(); + break; + } + + /// Do we need to shut down? + if (server.isCancelled()) + break; + + /** Have we waited for data for too long? + * If we periodically poll, the receive_timeout of the socket itself does not work. + * Therefore, an additional check is added. 
+ */ + Float64 elapsed = watch.elapsedSeconds(); + if (elapsed > static_cast<Float64>(receive_timeout.totalSeconds())) + { + throw Exception(ErrorCodes::SOCKET_TIMEOUT, + "Timeout exceeded while receiving data from client. Waited for {} seconds, timeout is {} seconds.", + static_cast<size_t>(elapsed), receive_timeout.totalSeconds()); + } + } + + if (read_ok) + { + sendLogs(); + sendInsertProfileEvents(); + } + else + state.read_all_data = true; + + return read_ok; +} + + +void TCPHandler::readData() +{ + sendLogs(); + + while (readDataNext()) + ; + + if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED) + throw Exception(ErrorCodes::QUERY_WAS_CANCELLED, "Query was cancelled"); +} + + +void TCPHandler::skipData() +{ + state.skipping_data = true; + SCOPE_EXIT({ state.skipping_data = false; }); + + while (readDataNext()) + ; + + if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED) + throw Exception(ErrorCodes::QUERY_WAS_CANCELLED, "Query was cancelled"); +} + + +void TCPHandler::processInsertQuery() +{ + size_t num_threads = state.io.pipeline.getNumThreads(); + + auto run_executor = [&](auto & executor) + { + /// Made above the rest of the lines, + /// so that in case of `writePrefix` function throws an exception, + /// client receive exception before sending data. + executor.start(); + + /// Send ColumnsDescription for insertion table + if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_COLUMN_DEFAULTS_METADATA) + { + const auto & table_id = query_context->getInsertionTable(); + if (query_context->getSettingsRef().input_format_defaults_for_omitted_fields) + { + if (!table_id.empty()) + { + auto storage_ptr = DatabaseCatalog::instance().getTable(table_id, query_context); + sendTableColumns(storage_ptr->getInMemoryMetadataPtr()->getColumns()); + } + } + } + + /// Send block to the client - table structure. 
+ sendData(executor.getHeader()); + sendLogs(); + + while (readDataNext()) + executor.push(std::move(state.block_for_insert)); + + if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED) + executor.cancel(); + else + executor.finish(); + }; + + if (num_threads > 1) + { + PushingAsyncPipelineExecutor executor(state.io.pipeline); + run_executor(executor); + } + else + { + PushingPipelineExecutor executor(state.io.pipeline); + run_executor(executor); + } + + sendInsertProfileEvents(); +} + + +void TCPHandler::processOrdinaryQueryWithProcessors() +{ + auto & pipeline = state.io.pipeline; + + if (query_context->getSettingsRef().allow_experimental_query_deduplication) + { + std::lock_guard lock(task_callback_mutex); + sendPartUUIDs(); + } + + /// Send header-block, to allow client to prepare output format for data to send. + { + const auto & header = pipeline.getHeader(); + + if (header) + { + std::lock_guard lock(task_callback_mutex); + sendData(header); + } + } + + /// Defer locking to cover a part of the scope below and everything after it + std::unique_lock progress_lock(task_callback_mutex, std::defer_lock); + + { + PullingAsyncPipelineExecutor executor(pipeline); + CurrentMetrics::Increment query_thread_metric_increment{CurrentMetrics::QueryThread}; + + Block block; + while (executor.pull(block, interactive_delay / 1000)) + { + std::unique_lock lock(task_callback_mutex); + + auto cancellation_status = getQueryCancellationStatus(); + if (cancellation_status == CancellationStatus::FULLY_CANCELLED) + { + /// Several callback like callback for parallel reading could be called from inside the pipeline + /// and we have to unlock the mutex from our side to prevent deadlock. + lock.unlock(); + /// A packet was received requesting to stop execution of the request. 
+ executor.cancel(); + break; + } + else if (cancellation_status == CancellationStatus::READ_CANCELLED) + { + executor.cancelReading(); + } + + if (after_send_progress.elapsed() / 1000 >= interactive_delay) + { + /// Some time passed and there is a progress. + after_send_progress.restart(); + sendProgress(); + sendSelectProfileEvents(); + } + + sendLogs(); + + if (block) + { + if (!state.io.null_format) + sendData(block); + } + } + + /// This lock wasn't acquired before and we make .lock() call here + /// so everything under this line is covered even together + /// with sendProgress() out of the scope + progress_lock.lock(); + + /** If data has run out, we will send the profiling data and total values to + * the last zero block to be able to use + * this information in the suffix output of stream. + * If the request was interrupted, then `sendTotals` and other methods could not be called, + * because we have not read all the data yet, + * and there could be ongoing calculations in other threads at the same time. + */ + if (getQueryCancellationStatus() != CancellationStatus::FULLY_CANCELLED) + { + sendTotals(executor.getTotalsBlock()); + sendExtremes(executor.getExtremesBlock()); + sendProfileInfo(executor.getProfileInfo()); + sendProgress(); + sendLogs(); + sendSelectProfileEvents(); + } + + if (state.is_connection_closed) + return; + + sendData({}); + last_sent_snapshots.clear(); + } + + sendProgress(); +} + + +void TCPHandler::processTablesStatusRequest() +{ + TablesStatusRequest request; + request.read(*in, client_tcp_protocol_version); + + ContextPtr context_to_resolve_table_names; + if (is_interserver_mode) + { + /// In interserver mode session context does not exists, because authentication is done for each query. + /// We also cannot create query context earlier, because it cannot be created before authentication, + /// but query is not received yet. So we have to do this trick. 
+ ContextMutablePtr fake_interserver_context = Context::createCopy(server.context()); + if (!default_database.empty()) + fake_interserver_context->setCurrentDatabase(default_database); + context_to_resolve_table_names = fake_interserver_context; + } + else + { + assert(session); + context_to_resolve_table_names = session->sessionContext(); + } + + TablesStatusResponse response; + for (const QualifiedTableName & table_name: request.tables) + { + auto resolved_id = context_to_resolve_table_names->tryResolveStorageID({table_name.database, table_name.table}); + StoragePtr table = DatabaseCatalog::instance().tryGetTable(resolved_id, context_to_resolve_table_names); + if (!table) + continue; + + TableStatus status; + if (auto * replicated_table = dynamic_cast<StorageReplicatedMergeTree *>(table.get())) + { + status.is_replicated = true; + status.absolute_delay = static_cast<UInt32>(replicated_table->getAbsoluteDelay()); + } + else + status.is_replicated = false; + + response.table_states_by_id.emplace(table_name, std::move(status)); + } + + + writeVarUInt(Protocol::Server::TablesStatusResponse, *out); + + /// For testing hedged requests + if (unlikely(sleep_in_send_tables_status.totalMilliseconds())) + { + out->next(); + std::chrono::milliseconds ms(sleep_in_send_tables_status.totalMilliseconds()); + std::this_thread::sleep_for(ms); + } + + response.write(*out, client_tcp_protocol_version); +} + +void TCPHandler::receiveUnexpectedTablesStatusRequest() +{ + TablesStatusRequest skip_request; + skip_request.read(*in, client_tcp_protocol_version); + + throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet TablesStatusRequest received from client"); +} + +void TCPHandler::sendPartUUIDs() +{ + auto uuids = query_context->getPartUUIDs()->get(); + if (!uuids.empty()) + { + for (const auto & uuid : uuids) + LOG_TRACE(log, "Sending UUID: {}", toString(uuid)); + + writeVarUInt(Protocol::Server::PartUUIDs, *out); + writeVectorBinary(uuids, *out); + 
out->next(); + } +} + + +void TCPHandler::sendReadTaskRequestAssumeLocked() +{ + writeVarUInt(Protocol::Server::ReadTaskRequest, *out); + out->next(); +} + + +void TCPHandler::sendMergeTreeAllRangesAnnounecementAssumeLocked(InitialAllRangesAnnouncement announcement) +{ + writeVarUInt(Protocol::Server::MergeTreeAllRangesAnnounecement, *out); + announcement.serialize(*out); + out->next(); +} + + +void TCPHandler::sendMergeTreeReadTaskRequestAssumeLocked(ParallelReadRequest request) +{ + writeVarUInt(Protocol::Server::MergeTreeReadTaskRequest, *out); + request.serialize(*out); + out->next(); +} + + +void TCPHandler::sendProfileInfo(const ProfileInfo & info) +{ + writeVarUInt(Protocol::Server::ProfileInfo, *out); + info.write(*out); + out->next(); +} + + +void TCPHandler::sendTotals(const Block & totals) +{ + if (totals) + { + initBlockOutput(totals); + + writeVarUInt(Protocol::Server::Totals, *out); + writeStringBinary("", *out); + + state.block_out->write(totals); + state.maybe_compressed_out->next(); + out->next(); + } +} + + +void TCPHandler::sendExtremes(const Block & extremes) +{ + if (extremes) + { + initBlockOutput(extremes); + + writeVarUInt(Protocol::Server::Extremes, *out); + writeStringBinary("", *out); + + state.block_out->write(extremes); + state.maybe_compressed_out->next(); + out->next(); + } +} + +void TCPHandler::sendProfileEvents() +{ + Block block; + ProfileEvents::getProfileEvents(server_display_name, state.profile_queue, block, last_sent_snapshots); + if (block.rows() != 0) + { + initProfileEventsBlockOutput(block); + + writeVarUInt(Protocol::Server::ProfileEvents, *out); + writeStringBinary("", *out); + + state.profile_events_block_out->write(block); + out->next(); + } +} + +void TCPHandler::sendSelectProfileEvents() +{ + if (client_tcp_protocol_version < DBMS_MIN_PROTOCOL_VERSION_WITH_INCREMENTAL_PROFILE_EVENTS) + return; + + sendProfileEvents(); +} + +void TCPHandler::sendInsertProfileEvents() +{ + if (client_tcp_protocol_version < 
DBMS_MIN_PROTOCOL_VERSION_WITH_PROFILE_EVENTS_IN_INSERT) + return; + if (query_kind != ClientInfo::QueryKind::INITIAL_QUERY) + return; + + sendProfileEvents(); +} + +void TCPHandler::sendTimezone() +{ + if (client_tcp_protocol_version < DBMS_MIN_PROTOCOL_VERSION_WITH_TIMEZONE_UPDATES) + return; + + const String & tz = query_context->getSettingsRef().session_timezone.value; + + LOG_DEBUG(log, "TCPHandler::sendTimezone(): {}", tz); + writeVarUInt(Protocol::Server::TimezoneUpdate, *out); + writeStringBinary(tz, *out); + out->next(); +} + + +bool TCPHandler::receiveProxyHeader() +{ + if (in->eof()) + { + LOG_WARNING(log, "Client has not sent any data."); + return false; + } + + String forwarded_address; + + /// Only PROXYv1 is supported. + /// Validation of protocol is not fully performed. + + LimitReadBuffer limit_in(*in, 107, /* trow_exception */ true, /* exact_limit */ {}); /// Maximum length from the specs. + + assertString("PROXY ", limit_in); + + if (limit_in.eof()) + { + LOG_WARNING(log, "Incomplete PROXY header is received."); + return false; + } + + /// TCP4 / TCP6 / UNKNOWN + if ('T' == *limit_in.position()) + { + assertString("TCP", limit_in); + + if (limit_in.eof()) + { + LOG_WARNING(log, "Incomplete PROXY header is received."); + return false; + } + + if ('4' != *limit_in.position() && '6' != *limit_in.position()) + { + LOG_WARNING(log, "Unexpected protocol in PROXY header is received."); + return false; + } + + ++limit_in.position(); + assertChar(' ', limit_in); + + /// Read the first field and ignore other. + readStringUntilWhitespace(forwarded_address, limit_in); + + /// Skip until \r\n + while (!limit_in.eof() && *limit_in.position() != '\r') + ++limit_in.position(); + assertString("\r\n", limit_in); + } + else if (checkString("UNKNOWN", limit_in)) + { + /// This is just a health check, there is no subsequent data in this connection. 
+ + while (!limit_in.eof() && *limit_in.position() != '\r') + ++limit_in.position(); + assertString("\r\n", limit_in); + return false; + } + else + { + LOG_WARNING(log, "Unexpected protocol in PROXY header is received."); + return false; + } + + LOG_TRACE(log, "Forwarded client address from PROXY header: {}", forwarded_address); + forwarded_for = std::move(forwarded_address); + return true; +} + + +namespace +{ + +std::string formatHTTPErrorResponseWhenUserIsConnectedToWrongPort(const Poco::Util::AbstractConfiguration& config) +{ + std::string result = fmt::format( + "HTTP/1.0 400 Bad Request\r\n\r\n" + "Port {} is for clickhouse-client program\r\n", + config.getString("tcp_port")); + + if (config.has("http_port")) + { + result += fmt::format( + "You must use port {} for HTTP.\r\n", + config.getString("http_port")); + } + + return result; +} + +} + +std::unique_ptr<Session> TCPHandler::makeSession() +{ + auto interface = is_interserver_mode ? ClientInfo::Interface::TCP_INTERSERVER : ClientInfo::Interface::TCP; + + auto res = std::make_unique<Session>(server.context(), interface, socket().secure(), certificate); + + res->setForwardedFor(forwarded_for); + res->setClientName(client_name); + res->setClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version); + res->setConnectionClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version); + res->setClientInterface(interface); + + return res; +} + +void TCPHandler::receiveHello() +{ + /// Receive `hello` packet. + UInt64 packet_type = 0; + String user; + String password; + String default_db; + + readVarUInt(packet_type, *in); + if (packet_type != Protocol::Client::Hello) + { + /** If you accidentally accessed the HTTP protocol for a port destined for an internal TCP protocol, + * Then instead of the packet type, there will be G (GET) or P (POST), in most cases. 
+ */ + if (packet_type == 'G' || packet_type == 'P') + { + writeString(formatHTTPErrorResponseWhenUserIsConnectedToWrongPort(server.config()), *out); + throw Exception(ErrorCodes::CLIENT_HAS_CONNECTED_TO_WRONG_PORT, "Client has connected to wrong port"); + } + else + throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, + "Unexpected packet from client (expected Hello, got {})", packet_type); + } + + readStringBinary(client_name, *in); + readVarUInt(client_version_major, *in); + readVarUInt(client_version_minor, *in); + // NOTE For backward compatibility of the protocol, client cannot send its version_patch. + readVarUInt(client_tcp_protocol_version, *in); + readStringBinary(default_db, *in); + if (!default_db.empty()) + default_database = default_db; + readStringBinary(user, *in); + readStringBinary(password, *in); + + if (user.empty()) + throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet from client (no user in Hello package)"); + + LOG_DEBUG(log, "Connected {} version {}.{}.{}, revision: {}{}{}.", + client_name, + client_version_major, client_version_minor, client_version_patch, + client_tcp_protocol_version, + (!default_database.empty() ? ", database: " + default_database : ""), + (!user.empty() ? ", user: " + user : "") + ); + + is_interserver_mode = (user == USER_INTERSERVER_MARKER) && password.empty(); + if (is_interserver_mode) + { + if (client_tcp_protocol_version < DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET_V2) + LOG_WARNING(LogFrequencyLimiter(log, 10), + "Using deprecated interserver protocol because the client is too old. 
Consider upgrading all nodes in cluster."); + receiveClusterNameAndSalt(); + return; + } + + session = makeSession(); + const auto & client_info = session->getClientInfo(); + +#if USE_SSL + /// Authentication with SSL user certificate + if (dynamic_cast<Poco::Net::SecureStreamSocketImpl*>(socket().impl())) + { + Poco::Net::SecureStreamSocket secure_socket(socket()); + if (secure_socket.havePeerCertificate()) + { + try + { + session->authenticate( + SSLCertificateCredentials{user, secure_socket.peerCertificate().commonName()}, + getClientAddress(client_info)); + return; + } + catch (...) + { + tryLogCurrentException(log, "SSL authentication failed, falling back to password authentication"); + } + } + } +#endif + + session->authenticate(user, password, getClientAddress(client_info)); +} + +void TCPHandler::receiveAddendum() +{ + if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_QUOTA_KEY) + readStringBinary(quota_key, *in); + + if (!is_interserver_mode) + session->setQuotaClientKey(quota_key); +} + + +void TCPHandler::receiveUnexpectedHello() +{ + UInt64 skip_uint_64; + String skip_string; + + readStringBinary(skip_string, *in); + readVarUInt(skip_uint_64, *in); + readVarUInt(skip_uint_64, *in); + readVarUInt(skip_uint_64, *in); + readStringBinary(skip_string, *in); + readStringBinary(skip_string, *in); + readStringBinary(skip_string, *in); + + throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet Hello received from client"); +} + + +void TCPHandler::sendHello() +{ + writeVarUInt(Protocol::Server::Hello, *out); + writeStringBinary(VERSION_NAME, *out); + writeVarUInt(VERSION_MAJOR, *out); + writeVarUInt(VERSION_MINOR, *out); + writeVarUInt(DBMS_TCP_PROTOCOL_VERSION, *out); + if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE) + writeStringBinary(DateLUT::instance().getTimeZone(), *out); + if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME) + 
writeStringBinary(server_display_name, *out); + if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_VERSION_PATCH) + writeVarUInt(VERSION_PATCH, *out); + if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_PASSWORD_COMPLEXITY_RULES) + { + auto rules = server.context()->getAccessControl().getPasswordComplexityRules(); + + writeVarUInt(rules.size(), *out); + for (const auto & [original_pattern, exception_message] : rules) + { + writeStringBinary(original_pattern, *out); + writeStringBinary(exception_message, *out); + } + } + if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET_V2) + { + chassert(!nonce.has_value()); + /// Contains lots of stuff (including time), so this should be enough for NONCE. + nonce.emplace(thread_local_rng()); + writeIntBinary(nonce.value(), *out); + } + out->next(); +} + + +bool TCPHandler::receivePacket() +{ + UInt64 packet_type = 0; + readVarUInt(packet_type, *in); + + switch (packet_type) + { + case Protocol::Client::IgnoredPartUUIDs: + /// Part uuids packet if any comes before query. 
+ if (!state.empty() || state.part_uuids_to_ignore) + receiveUnexpectedIgnoredPartUUIDs(); + receiveIgnoredPartUUIDs(); + return true; + + case Protocol::Client::Query: + if (!state.empty()) + receiveUnexpectedQuery(); + receiveQuery(); + return true; + + case Protocol::Client::Data: + case Protocol::Client::Scalar: + if (state.skipping_data) + return receiveUnexpectedData(false); + if (state.empty()) + receiveUnexpectedData(true); + return receiveData(packet_type == Protocol::Client::Scalar); + + case Protocol::Client::Ping: + writeVarUInt(Protocol::Server::Pong, *out); + out->next(); + return false; + + case Protocol::Client::Cancel: + decreaseCancellationStatus("Received 'Cancel' packet from the client, canceling the query."); + return false; + + case Protocol::Client::Hello: + receiveUnexpectedHello(); + + case Protocol::Client::TablesStatusRequest: + if (!state.empty()) + receiveUnexpectedTablesStatusRequest(); + processTablesStatusRequest(); + out->next(); + return false; + + default: + throw Exception(ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT, "Unknown packet {} from client", toString(packet_type)); + } +} + + +void TCPHandler::receiveIgnoredPartUUIDs() +{ + readVectorBinary(state.part_uuids_to_ignore.emplace(), *in); +} + + +void TCPHandler::receiveUnexpectedIgnoredPartUUIDs() +{ + std::vector<UUID> skip_part_uuids; + readVectorBinary(skip_part_uuids, *in); + throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet IgnoredPartUUIDs received from client"); +} + + +String TCPHandler::receiveReadTaskResponseAssumeLocked() +{ + UInt64 packet_type = 0; + readVarUInt(packet_type, *in); + if (packet_type != Protocol::Client::ReadTaskResponse) + { + if (packet_type == Protocol::Client::Cancel) + { + decreaseCancellationStatus("Received 'Cancel' packet from the client, canceling the read task."); + return {}; + } + else + { + throw Exception(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Received {} packet after requesting read task", + 
Protocol::Client::toString(packet_type)); + } + } + UInt64 version; + readVarUInt(version, *in); + if (version != DBMS_CLUSTER_PROCESSING_PROTOCOL_VERSION) + throw Exception(ErrorCodes::UNKNOWN_PROTOCOL, "Protocol version for distributed processing mismatched"); + String response; + readStringBinary(response, *in); + return response; +} + + +std::optional<ParallelReadResponse> TCPHandler::receivePartitionMergeTreeReadTaskResponseAssumeLocked() +{ + UInt64 packet_type = 0; + readVarUInt(packet_type, *in); + if (packet_type != Protocol::Client::MergeTreeReadTaskResponse) + { + if (packet_type == Protocol::Client::Cancel) + { + decreaseCancellationStatus("Received 'Cancel' packet from the client, canceling the MergeTree read task."); + return std::nullopt; + } + else + { + throw Exception(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Received {} packet after requesting read task", + Protocol::Client::toString(packet_type)); + } + } + ParallelReadResponse response; + response.deserialize(*in); + return response; +} + + +void TCPHandler::receiveClusterNameAndSalt() +{ + readStringBinary(cluster, *in); + readStringBinary(salt, *in, 32); +} + +void TCPHandler::receiveQuery() +{ + UInt64 stage = 0; + UInt64 compression = 0; + + state.is_empty = false; + readStringBinary(state.query_id, *in); + + /// In interserver mode, + /// initial_user can be empty in case of Distributed INSERT via Buffer/Kafka, + /// (i.e. when the INSERT is done with the global context without user), + /// so it is better to reset session to avoid using old user. + if (is_interserver_mode) + { + session = makeSession(); + } + + /// Read client info. 
+ ClientInfo client_info = session->getClientInfo(); + if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) + { + client_info.read(*in, client_tcp_protocol_version); + + correctQueryClientInfo(session->getClientInfo(), client_info); + const auto & config_ref = Context::getGlobalContextInstance()->getServerSettings(); + if (config_ref.validate_tcp_client_information) + validateClientInfo(session->getClientInfo(), client_info); + } + + /// Per query settings are also passed via TCP. + /// We need to check them before applying due to they can violate the settings constraints. + auto settings_format = (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS) + ? SettingsWriteFormat::STRINGS_WITH_FLAGS + : SettingsWriteFormat::BINARY; + Settings passed_settings; + passed_settings.read(*in, settings_format); + + /// Interserver secret. + std::string received_hash; + if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET) + { + readStringBinary(received_hash, *in, 32); + } + + readVarUInt(stage, *in); + state.stage = QueryProcessingStage::Enum(stage); + + readVarUInt(compression, *in); + state.compression = static_cast<Protocol::Compression>(compression); + last_block_in.compression = state.compression; + + readStringBinary(state.query, *in); + + Settings passed_params; + if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS) + passed_params.read(*in, settings_format); + + if (is_interserver_mode) + { + client_info.interface = ClientInfo::Interface::TCP_INTERSERVER; +#if USE_SSL + String cluster_secret = server.context()->getCluster(cluster)->getSecret(); + + if (salt.empty() || cluster_secret.empty()) + { + auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED, "Interserver authentication failed (no salt/cluster secret)"); + session->onAuthenticationFailure(/* user_name= */ std::nullopt, socket().peerAddress(), exception); + throw exception; /// NOLINT + } + + if 
(client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET_V2 && !nonce.has_value()) + { + auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED, "Interserver authentication failed (no nonce)"); + session->onAuthenticationFailure(/* user_name= */ std::nullopt, socket().peerAddress(), exception); + throw exception; /// NOLINT + } + + std::string data(salt); + // For backward compatibility + if (nonce.has_value()) + data += std::to_string(nonce.value()); + data += cluster_secret; + data += state.query; + data += state.query_id; + data += client_info.initial_user; + + std::string calculated_hash = encodeSHA256(data); + assert(calculated_hash.size() == 32); + + /// TODO maybe also check that peer address actually belongs to the cluster? + if (calculated_hash != received_hash) + { + auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED, "Interserver authentication failed"); + session->onAuthenticationFailure(/* user_name */ std::nullopt, socket().peerAddress(), exception); + throw exception; /// NOLINT + } + + /// NOTE Usually we get some fields of client_info (including initial_address and initial_user) from user input, + /// so we should not rely on that. However, in this particular case we got client_info from other clickhouse-server, so it's ok. + if (client_info.initial_user.empty()) + { + LOG_DEBUG(log, "User (no user, interserver mode) (client: {})", getClientAddress(client_info).toString()); + } + else + { + LOG_DEBUG(log, "User (initial, interserver mode): {} (client: {})", client_info.initial_user, getClientAddress(client_info).toString()); + /// In case of inter-server mode authorization is done with the + /// initial address of the client, not the real address from which + /// the query was come, since the real address is the address of + /// the initiator server, while we are interested in client's + /// address. 
+ session->authenticate(AlwaysAllowCredentials{client_info.initial_user}, client_info.initial_address); + } +#else + auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED, + "Inter-server secret support is disabled, because ClickHouse was built without SSL library"); + session->onAuthenticationFailure(/* user_name */ std::nullopt, socket().peerAddress(), exception); + throw exception; /// NOLINT +#endif + } + + query_context = session->makeQueryContext(std::move(client_info)); + + /// Sets the default database if it wasn't set earlier for the session context. + if (is_interserver_mode && !default_database.empty()) + query_context->setCurrentDatabase(default_database); + + if (state.part_uuids_to_ignore) + query_context->getIgnoredPartUUIDs()->add(*state.part_uuids_to_ignore); + + query_context->setProgressCallback([this] (const Progress & value) { return this->updateProgress(value); }); + query_context->setFileProgressCallback([this](const FileProgress & value) { this->updateProgress(Progress(value)); }); + + /// + /// Settings + /// + auto settings_changes = passed_settings.changes(); + query_kind = query_context->getClientInfo().query_kind; + if (query_kind == ClientInfo::QueryKind::INITIAL_QUERY) + { + /// Throw an exception if the passed settings violate the constraints. + query_context->checkSettingsConstraints(settings_changes, SettingSource::QUERY); + } + else + { + /// Quietly clamp to the constraints if it's not an initial query. + query_context->clampToSettingsConstraints(settings_changes, SettingSource::QUERY); + } + query_context->applySettingsChanges(settings_changes); + + /// Use the received query id, or generate a random default. It is convenient + /// to also generate the default OpenTelemetry trace id at the same time, and + /// set the trace parent. + /// Notes: + /// 1) ClientInfo might contain upstream trace id, so we decide whether to use + /// the default ids after we have received the ClientInfo. 
+ /// 2) There is the opentelemetry_start_trace_probability setting that + /// controls when we start a new trace. It can be changed via Native protocol, + /// so we have to apply the changes first. + query_context->setCurrentQueryId(state.query_id); + + query_context->addQueryParameters(convertToQueryParameters(passed_params)); + + /// For testing hedged requests + if (unlikely(sleep_after_receiving_query.totalMilliseconds())) + { + std::chrono::milliseconds ms(sleep_after_receiving_query.totalMilliseconds()); + std::this_thread::sleep_for(ms); + } +} + +void TCPHandler::receiveUnexpectedQuery() +{ + UInt64 skip_uint_64; + String skip_string; + + readStringBinary(skip_string, *in); + + ClientInfo skip_client_info; + if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) + skip_client_info.read(*in, client_tcp_protocol_version); + + Settings skip_settings; + auto settings_format = (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS) ? SettingsWriteFormat::STRINGS_WITH_FLAGS + : SettingsWriteFormat::BINARY; + skip_settings.read(*in, settings_format); + + std::string skip_hash; + bool interserver_secret = client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET; + if (interserver_secret) + readStringBinary(skip_hash, *in, 32); + + readVarUInt(skip_uint_64, *in); + + readVarUInt(skip_uint_64, *in); + last_block_in.compression = static_cast<Protocol::Compression>(skip_uint_64); + + readStringBinary(skip_string, *in); + + if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS) + skip_settings.read(*in, settings_format); + + throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet Query received from client"); +} + +bool TCPHandler::receiveData(bool scalar) +{ + initBlockInput(); + + /// The name of the temporary table for writing data, default to empty string + auto temporary_id = StorageID::createEmpty(); + readStringBinary(temporary_id.table_name, *in); + + 
/// Read one block from the network and write it down + Block block = state.block_in->read(); + + if (!block) + { + state.read_all_data = true; + return false; + } + + if (scalar) + { + /// Scalar value + query_context->addScalar(temporary_id.table_name, block); + } + else if (!state.need_receive_data_for_insert && !state.need_receive_data_for_input) + { + /// Data for external tables + + auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal); + StoragePtr storage; + /// If such a table does not exist, create it. + if (resolved) + { + storage = DatabaseCatalog::instance().getTable(resolved, query_context); + } + else + { + NamesAndTypesList columns = block.getNamesAndTypesList(); + auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {}); + storage = temporary_table.getTable(); + query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table)); + } + auto metadata_snapshot = storage->getInMemoryMetadataPtr(); + /// The data will be written directly to the table. + QueryPipeline temporary_table_out(storage->write(ASTPtr(), metadata_snapshot, query_context, /*async_insert=*/false)); + PushingPipelineExecutor executor(temporary_table_out); + executor.start(); + executor.push(block); + executor.finish(); + } + else if (state.need_receive_data_for_input) + { + /// 'input' table function. + state.block_for_input = block; + } + else + { + /// INSERT query. 
+ state.block_for_insert = block; + } + return true; +} + + +bool TCPHandler::receiveUnexpectedData(bool throw_exception) +{ + String skip_external_table_name; + readStringBinary(skip_external_table_name, *in); + + std::shared_ptr<ReadBuffer> maybe_compressed_in; + if (last_block_in.compression == Protocol::Compression::Enable) + maybe_compressed_in = std::make_shared<CompressedReadBuffer>(*in, /* allow_different_codecs */ true); + else + maybe_compressed_in = in; + + auto skip_block_in = std::make_shared<NativeReader>(*maybe_compressed_in, client_tcp_protocol_version); + bool read_ok = !!skip_block_in->read(); + + if (!read_ok) + state.read_all_data = true; + + if (throw_exception) + throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet Data received from client"); + + return read_ok; +} + +void TCPHandler::initBlockInput() +{ + if (!state.block_in) + { + /// 'allow_different_codecs' is set to true, because some parts of compressed data can be precompressed in advance + /// with another codec that the rest of the data. Example: data sent by Distributed tables. 
+ + if (state.compression == Protocol::Compression::Enable) + state.maybe_compressed_in = std::make_shared<CompressedReadBuffer>(*in, /* allow_different_codecs */ true); + else + state.maybe_compressed_in = in; + + Block header; + if (state.io.pipeline.pushing()) + header = state.io.pipeline.getHeader(); + else if (state.need_receive_data_for_input) + header = state.input_header; + + state.block_in = std::make_unique<NativeReader>( + *state.maybe_compressed_in, + header, + client_tcp_protocol_version); + } +} + + +void TCPHandler::initBlockOutput(const Block & block) +{ + if (!state.block_out) + { + const Settings & query_settings = query_context->getSettingsRef(); + if (!state.maybe_compressed_out) + { + std::string method = Poco::toUpper(query_settings.network_compression_method.toString()); + std::optional<int> level; + if (method == "ZSTD") + level = query_settings.network_zstd_compression_level; + + if (state.compression == Protocol::Compression::Enable) + { + CompressionCodecFactory::instance().validateCodec(method, level, !query_settings.allow_suspicious_codecs, query_settings.allow_experimental_codecs, query_settings.enable_deflate_qpl_codec); + + state.maybe_compressed_out = std::make_shared<CompressedWriteBuffer>( + *out, CompressionCodecFactory::instance().get(method, level)); + } + else + state.maybe_compressed_out = out; + } + + state.block_out = std::make_unique<NativeWriter>( + *state.maybe_compressed_out, + client_tcp_protocol_version, + block.cloneEmpty(), + !query_settings.low_cardinality_allow_in_native_format); + } +} + +void TCPHandler::initLogsBlockOutput(const Block & block) +{ + if (!state.logs_block_out) + { + /// Use uncompressed stream since log blocks usually contain only one row + const Settings & query_settings = query_context->getSettingsRef(); + state.logs_block_out = std::make_unique<NativeWriter>( + *out, + client_tcp_protocol_version, + block.cloneEmpty(), + !query_settings.low_cardinality_allow_in_native_format); + } +} + + +void 
TCPHandler::initProfileEventsBlockOutput(const Block & block) +{ + if (!state.profile_events_block_out) + { + const Settings & query_settings = query_context->getSettingsRef(); + state.profile_events_block_out = std::make_unique<NativeWriter>( + *out, + client_tcp_protocol_version, + block.cloneEmpty(), + !query_settings.low_cardinality_allow_in_native_format); + } +} + +void TCPHandler::decreaseCancellationStatus(const std::string & log_message) +{ + auto prev_status = magic_enum::enum_name(state.cancellation_status); + + bool partial_result_on_first_cancel = false; + if (query_context) + { + const auto & settings = query_context->getSettingsRef(); + partial_result_on_first_cancel = settings.partial_result_on_first_cancel; + } + + if (partial_result_on_first_cancel && state.cancellation_status == CancellationStatus::NOT_CANCELLED) + { + state.cancellation_status = CancellationStatus::READ_CANCELLED; + } + else + { + state.cancellation_status = CancellationStatus::FULLY_CANCELLED; + } + + auto current_status = magic_enum::enum_name(state.cancellation_status); + LOG_INFO(log, "Change cancellation status from {} to {}. Log message: {}", prev_status, current_status, log_message); +} + +QueryState::CancellationStatus TCPHandler::getQueryCancellationStatus() +{ + if (state.cancellation_status == CancellationStatus::FULLY_CANCELLED || state.sent_all_data) + return CancellationStatus::FULLY_CANCELLED; + + if (after_check_cancelled.elapsed() / 1000 < interactive_delay) + return state.cancellation_status; + + after_check_cancelled.restart(); + + /// During request execution the only packet that can come from the client is stopping the query. 
+ if (static_cast<ReadBufferFromPocoSocket &>(*in).poll(0)) + { + if (in->eof()) + { + LOG_INFO(log, "Client has dropped the connection, cancel the query."); + state.cancellation_status = CancellationStatus::FULLY_CANCELLED; + state.is_connection_closed = true; + return CancellationStatus::FULLY_CANCELLED; + } + + UInt64 packet_type = 0; + readVarUInt(packet_type, *in); + + switch (packet_type) + { + case Protocol::Client::Cancel: + if (state.empty()) + throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT, "Unexpected packet Cancel received from client"); + + decreaseCancellationStatus("Query was cancelled."); + + return state.cancellation_status; + + default: + throw NetException(ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT, "Unknown packet from client {}", toString(packet_type)); + } + } + + return state.cancellation_status; +} + + +void TCPHandler::sendData(const Block & block) +{ + initBlockOutput(block); + + size_t prev_bytes_written_out = out->count(); + size_t prev_bytes_written_compressed_out = state.maybe_compressed_out->count(); + + try + { + /// For testing hedged requests + if (unknown_packet_in_send_data) + { + constexpr UInt64 marker = (1ULL<<63) - 1; + --unknown_packet_in_send_data; + if (unknown_packet_in_send_data == 0) + writeVarUInt(marker, *out); + } + + writeVarUInt(Protocol::Server::Data, *out); + /// Send external table name (empty name is the main table) + writeStringBinary("", *out); + + /// For testing hedged requests + if (block.rows() > 0 && query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds()) + { + out->next(); + std::chrono::milliseconds ms(query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds()); + std::this_thread::sleep_for(ms); + } + + state.block_out->write(block); + state.maybe_compressed_out->next(); + out->next(); + } + catch (...) + { + /// In case of unsuccessful write, if the buffer with written data was not flushed, + /// we will rollback write to avoid breaking the protocol. 
+ /// (otherwise the client will not be able to receive exception after unfinished data + /// as it will expect the continuation of the data). + /// It looks like hangs on client side or a message like "Data compressed with different methods". + + if (state.compression == Protocol::Compression::Enable) + { + auto extra_bytes_written_compressed = state.maybe_compressed_out->count() - prev_bytes_written_compressed_out; + if (state.maybe_compressed_out->offset() >= extra_bytes_written_compressed) + state.maybe_compressed_out->position() -= extra_bytes_written_compressed; + } + + auto extra_bytes_written_out = out->count() - prev_bytes_written_out; + if (out->offset() >= extra_bytes_written_out) + out->position() -= extra_bytes_written_out; + + throw; + } +} + + +void TCPHandler::sendLogData(const Block & block) +{ + initLogsBlockOutput(block); + + writeVarUInt(Protocol::Server::Log, *out); + /// Send log tag (empty tag is the default tag) + writeStringBinary("", *out); + + state.logs_block_out->write(block); + out->next(); +} + +void TCPHandler::sendTableColumns(const ColumnsDescription & columns) +{ + writeVarUInt(Protocol::Server::TableColumns, *out); + + /// Send external table name (empty name is the main table) + writeStringBinary("", *out); + writeStringBinary(columns.toString(), *out); + + out->next(); +} + +void TCPHandler::sendException(const Exception & e, bool with_stack_trace) +{ + state.io.setAllDataSent(); + + writeVarUInt(Protocol::Server::Exception, *out); + writeException(e, *out, with_stack_trace); + out->next(); +} + + +void TCPHandler::sendEndOfStream() +{ + state.sent_all_data = true; + state.io.setAllDataSent(); + + writeVarUInt(Protocol::Server::EndOfStream, *out); + out->next(); +} + + +void TCPHandler::updateProgress(const Progress & value) +{ + state.progress.incrementPiecewiseAtomically(value); +} + + +void TCPHandler::sendProgress() +{ + writeVarUInt(Protocol::Server::Progress, *out); + auto increment = 
state.progress.fetchValuesAndResetPiecewiseAtomically(); + UInt64 current_elapsed_ns = state.watch.elapsedNanoseconds(); + increment.elapsed_ns = current_elapsed_ns - state.prev_elapsed_ns; + state.prev_elapsed_ns = current_elapsed_ns; + increment.write(*out, client_tcp_protocol_version); + out->next(); +} + + +void TCPHandler::sendLogs() +{ + if (!state.logs_queue) + return; + + MutableColumns logs_columns; + MutableColumns curr_logs_columns; + size_t rows = 0; + + for (; state.logs_queue->tryPop(curr_logs_columns); ++rows) + { + if (rows == 0) + { + logs_columns = std::move(curr_logs_columns); + } + else + { + for (size_t j = 0; j < logs_columns.size(); ++j) + logs_columns[j]->insertRangeFrom(*curr_logs_columns[j], 0, curr_logs_columns[j]->size()); + } + } + + if (rows > 0) + { + Block block = InternalTextLogsQueue::getSampleBlock(); + block.setColumns(std::move(logs_columns)); + sendLogData(block); + } +} + + +void TCPHandler::run() +{ + try + { + runImpl(); + + LOG_DEBUG(log, "Done processing connection."); + } + catch (Poco::Exception & e) + { + /// Timeout - not an error. + if (e.what() == "Timeout"sv) + { + LOG_DEBUG(log, "Poco::Exception. Code: {}, e.code() = {}, e.displayText() = {}, e.what() = {}", ErrorCodes::POCO_EXCEPTION, e.code(), e.displayText(), e.what()); + } + else + throw; + } +} + +Poco::Net::SocketAddress TCPHandler::getClientAddress(const ClientInfo & client_info) +{ + /// Extract the last entry from comma separated list of forwarded_for addresses. + /// Only the last proxy can be trusted (if any). 
+ String forwarded_address = client_info.getLastForwardedFor(); + if (!forwarded_address.empty() && server.config().getBool("auth_use_forwarded_address", false)) + return Poco::Net::SocketAddress(forwarded_address, socket().peerAddress().port()); + else + return socket().peerAddress(); +} + +} diff --git a/contrib/clickhouse/src/Server/TCPHandler.h b/contrib/clickhouse/src/Server/TCPHandler.h new file mode 100644 index 0000000000..235f634afe --- /dev/null +++ b/contrib/clickhouse/src/Server/TCPHandler.h @@ -0,0 +1,294 @@ +#pragma once + +#include <optional> +#include <Poco/Net/TCPServerConnection.h> + +#include <base/getFQDNOrHostName.h> +#include <Common/ProfileEvents.h> +#include <Common/CurrentMetrics.h> +#include <Common/Stopwatch.h> +#include <Common/ThreadStatus.h> +#include <Core/Protocol.h> +#include <Core/QueryProcessingStage.h> +#include <IO/Progress.h> +#include <IO/TimeoutSetter.h> +#include <QueryPipeline/BlockIO.h> +#include <Interpreters/InternalTextLogsQueue.h> +#include <Interpreters/Context_fwd.h> +#include <Interpreters/ClientInfo.h> +#include <Interpreters/ProfileEventsExt.h> +#include <Formats/NativeReader.h> +#include <Formats/NativeWriter.h> + +#include "IServer.h" +#include "Server/TCPProtocolStackData.h" +#include "Storages/MergeTree/RequestResponse.h" +#include "base/types.h" + + +namespace CurrentMetrics +{ + extern const Metric TCPConnection; +} + +namespace Poco { class Logger; } + +namespace DB +{ + +class Session; +struct Settings; +class ColumnsDescription; +struct ProfileInfo; +class TCPServer; +class NativeWriter; +class NativeReader; + +/// State of query processing. +struct QueryState +{ + /// Identifier of the query. + String query_id; + + QueryProcessingStage::Enum stage = QueryProcessingStage::Complete; + Protocol::Compression compression = Protocol::Compression::Disable; + + /// A queue with internal logs that will be passed to client. 
It must be + /// destroyed after input/output blocks, because they may contain other + /// threads that use this queue. + InternalTextLogsQueuePtr logs_queue; + std::unique_ptr<NativeWriter> logs_block_out; + + InternalProfileEventsQueuePtr profile_queue; + std::unique_ptr<NativeWriter> profile_events_block_out; + + /// From where to read data for INSERT. + std::shared_ptr<ReadBuffer> maybe_compressed_in; + std::unique_ptr<NativeReader> block_in; + + /// Where to write result data. + std::shared_ptr<WriteBuffer> maybe_compressed_out; + std::unique_ptr<NativeWriter> block_out; + Block block_for_insert; + + /// Query text. + String query; + /// Streams of blocks, that are processing the query. + BlockIO io; + + enum class CancellationStatus: UInt8 + { + FULLY_CANCELLED, + READ_CANCELLED, + NOT_CANCELLED + }; + + /// Is request cancelled + CancellationStatus cancellation_status = CancellationStatus::NOT_CANCELLED; + bool is_connection_closed = false; + /// empty or not + bool is_empty = true; + /// Data was sent. + bool sent_all_data = false; + /// Request requires data from the client (INSERT, but not INSERT SELECT). + bool need_receive_data_for_insert = false; + /// Data was read. + bool read_all_data = false; + + /// A state got uuids to exclude from a query + std::optional<std::vector<UUID>> part_uuids_to_ignore; + + /// Request requires data from client for function input() + bool need_receive_data_for_input = false; + /// temporary place for incoming data block for input() + Block block_for_input; + /// sample block from StorageInput + Block input_header; + + /// If true, the data packets will be skipped instead of reading. Used to recover after errors. + bool skipping_data = false; + + /// To output progress, the difference after the previous sending of progress. 
+ Progress progress; + Stopwatch watch; + UInt64 prev_elapsed_ns = 0; + + /// Timeouts setter for current query + std::unique_ptr<TimeoutSetter> timeout_setter; + + void reset() + { + *this = QueryState(); + } + + bool empty() const + { + return is_empty; + } +}; + + +struct LastBlockInputParameters +{ + Protocol::Compression compression = Protocol::Compression::Disable; +}; + +class TCPHandler : public Poco::Net::TCPServerConnection +{ +public: + /** parse_proxy_protocol_ - if true, expect and parse the header of PROXY protocol in every connection + * and set the information about forwarded address accordingly. + * See https://github.com/wolfeidau/proxyv2/blob/master/docs/proxy-protocol.txt + * + * Note: immediate IP address is always used for access control (accept-list of IP networks), + * because it allows to check the IP ranges of the trusted proxy. + * Proxy-forwarded (original client) IP address is used for quota accounting if quota is keyed by forwarded IP. + */ + TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool parse_proxy_protocol_, std::string server_display_name_); + TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, TCPProtocolStackData & stack_data, std::string server_display_name_); + ~TCPHandler() override; + + void run() override; + + /// This method is called right before the query execution. + virtual void customizeContext(ContextMutablePtr /*context*/) {} + +private: + IServer & server; + TCPServer & tcp_server; + bool parse_proxy_protocol = false; + Poco::Logger * log; + + String forwarded_for; + String certificate; + + String client_name; + UInt64 client_version_major = 0; + UInt64 client_version_minor = 0; + UInt64 client_version_patch = 0; + UInt32 client_tcp_protocol_version = 0; + String quota_key; + + /// Connection settings, which are extracted from a context. 
+ bool send_exception_with_stack_trace = true; + Poco::Timespan send_timeout = Poco::Timespan(DBMS_DEFAULT_SEND_TIMEOUT_SEC, 0); + Poco::Timespan receive_timeout = Poco::Timespan(DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC, 0); + UInt64 poll_interval = DBMS_DEFAULT_POLL_INTERVAL; + UInt64 idle_connection_timeout = 3600; + UInt64 interactive_delay = 100000; + Poco::Timespan sleep_in_send_tables_status; + UInt64 unknown_packet_in_send_data = 0; + Poco::Timespan sleep_after_receiving_query; + + std::unique_ptr<Session> session; + ContextMutablePtr query_context; + ClientInfo::QueryKind query_kind = ClientInfo::QueryKind::NO_QUERY; + + /// Streams for reading/writing from/to client connection socket. + std::shared_ptr<ReadBuffer> in; + std::shared_ptr<WriteBuffer> out; + + /// Time after the last check to stop the request and send the progress. + Stopwatch after_check_cancelled; + Stopwatch after_send_progress; + + String default_database; + + /// For inter-server secret (remote_server.*.secret) + bool is_interserver_mode = false; + /// For DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET + String salt; + /// For DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET_V2 + std::optional<UInt64> nonce; + String cluster; + + std::mutex task_callback_mutex; + std::mutex fatal_error_mutex; + + /// At the moment, only one ongoing query in the connection is supported at a time. + QueryState state; + + /// Last block input parameters are saved to be able to receive unexpected data packet sent after exception. + LastBlockInputParameters last_block_in; + + CurrentMetrics::Increment metric_increment{CurrentMetrics::TCPConnection}; + + ProfileEvents::ThreadIdToCountersSnapshot last_sent_snapshots; + + /// It is the name of the server that will be sent to the client. 
+ String server_display_name; + + void runImpl(); + + void extractConnectionSettingsFromContext(const ContextPtr & context); + + std::unique_ptr<Session> makeSession(); + + bool receiveProxyHeader(); + void receiveHello(); + void receiveAddendum(); + bool receivePacket(); + void receiveQuery(); + void receiveIgnoredPartUUIDs(); + String receiveReadTaskResponseAssumeLocked(); + std::optional<ParallelReadResponse> receivePartitionMergeTreeReadTaskResponseAssumeLocked(); + bool receiveData(bool scalar); + bool readDataNext(); + void readData(); + void skipData(); + void receiveClusterNameAndSalt(); + + bool receiveUnexpectedData(bool throw_exception = true); + [[noreturn]] void receiveUnexpectedQuery(); + [[noreturn]] void receiveUnexpectedIgnoredPartUUIDs(); + [[noreturn]] void receiveUnexpectedHello(); + [[noreturn]] void receiveUnexpectedTablesStatusRequest(); + + /// Process INSERT query + void processInsertQuery(); + + /// Process a request that does not require the receiving of data blocks from the client + void processOrdinaryQuery(); + + void processOrdinaryQueryWithProcessors(); + + void processTablesStatusRequest(); + + void sendHello(); + void sendData(const Block & block); /// Write a block to the network. 
+ void sendLogData(const Block & block); + void sendTableColumns(const ColumnsDescription & columns); + void sendException(const Exception & e, bool with_stack_trace); + void sendProgress(); + void sendLogs(); + void sendEndOfStream(); + void sendPartUUIDs(); + void sendReadTaskRequestAssumeLocked(); + void sendMergeTreeAllRangesAnnounecementAssumeLocked(InitialAllRangesAnnouncement announcement); + void sendMergeTreeReadTaskRequestAssumeLocked(ParallelReadRequest request); + void sendProfileInfo(const ProfileInfo & info); + void sendTotals(const Block & totals); + void sendExtremes(const Block & extremes); + void sendProfileEvents(); + void sendSelectProfileEvents(); + void sendInsertProfileEvents(); + void sendTimezone(); + + /// Creates state.block_in/block_out for blocks read/write, depending on whether compression is enabled. + void initBlockInput(); + void initBlockOutput(const Block & block); + void initLogsBlockOutput(const Block & block); + void initProfileEventsBlockOutput(const Block & block); + + using CancellationStatus = QueryState::CancellationStatus; + + void decreaseCancellationStatus(const std::string & log_message); + CancellationStatus getQueryCancellationStatus(); + + /// This function is called from different threads. 
+ void updateProgress(const Progress & value); + + Poco::Net::SocketAddress getClientAddress(const ClientInfo & client_info); +}; + +} diff --git a/contrib/clickhouse/src/Server/TCPHandlerFactory.h b/contrib/clickhouse/src/Server/TCPHandlerFactory.h new file mode 100644 index 0000000000..fde04c6e0a --- /dev/null +++ b/contrib/clickhouse/src/Server/TCPHandlerFactory.h @@ -0,0 +1,74 @@ +#pragma once + +#include <Poco/Net/NetException.h> +#include <Poco/Util/LayeredConfiguration.h> +#include <Common/logger_useful.h> +#include "Server/TCPProtocolStackData.h" +#include <Server/IServer.h> +#include <Server/TCPHandler.h> +#include <Server/TCPServerConnectionFactory.h> + +namespace Poco { class Logger; } + +namespace DB +{ + +class TCPHandlerFactory : public TCPServerConnectionFactory +{ +private: + IServer & server; + bool parse_proxy_protocol = false; + Poco::Logger * log; + std::string server_display_name; + + class DummyTCPHandler : public Poco::Net::TCPServerConnection + { + public: + using Poco::Net::TCPServerConnection::TCPServerConnection; + void run() override {} + }; + +public: + /** parse_proxy_protocol_ - if true, expect and parse the header of PROXY protocol in every connection + * and set the information about forwarded address accordingly. + * See https://github.com/wolfeidau/proxyv2/blob/master/docs/proxy-protocol.txt + */ + TCPHandlerFactory(IServer & server_, bool secure_, bool parse_proxy_protocol_) + : server(server_), parse_proxy_protocol(parse_proxy_protocol_) + , log(&Poco::Logger::get(std::string("TCP") + (secure_ ? "S" : "") + "HandlerFactory")) + { + server_display_name = server.config().getString("display_name", getFQDNOrHostName()); + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override + { + try + { + LOG_TRACE(log, "TCP Request. 
Address: {}", socket.peerAddress().toString()); + + return new TCPHandler(server, tcp_server, socket, parse_proxy_protocol, server_display_name); + } + catch (const Poco::Net::NetException &) + { + LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent)."); + return new DummyTCPHandler(socket); + } + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData & stack_data) override + { + try + { + LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString()); + + return new TCPHandler(server, tcp_server, socket, stack_data, server_display_name); + } + catch (const Poco::Net::NetException &) + { + LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent)."); + return new DummyTCPHandler(socket); + } + } +}; + +} diff --git a/contrib/clickhouse/src/Server/TCPProtocolStackData.h b/contrib/clickhouse/src/Server/TCPProtocolStackData.h new file mode 100644 index 0000000000..4ad401e723 --- /dev/null +++ b/contrib/clickhouse/src/Server/TCPProtocolStackData.h @@ -0,0 +1,22 @@ +#pragma once + +#include <string> +#include <Poco/Net/StreamSocket.h> + +namespace DB +{ + +// Data to communicate between protocol layers +struct TCPProtocolStackData +{ + // socket implementation can be replaced by some layer - TLS as an example + Poco::Net::StreamSocket socket; + // host from PROXY layer + std::string forwarded_for; + // certificate path from TLS layer to TCP layer + std::string certificate; + // default database from endpoint configuration to TCP layer + std::string default_database; +}; + +} diff --git a/contrib/clickhouse/src/Server/TCPProtocolStackFactory.h b/contrib/clickhouse/src/Server/TCPProtocolStackFactory.h new file mode 100644 index 0000000000..7373e6e1c4 --- /dev/null +++ b/contrib/clickhouse/src/Server/TCPProtocolStackFactory.h @@ -0,0 +1,92 @@ +#pragma once + +#include <Server/TCPServerConnectionFactory.h> 
+#include <Server/IServer.h> +#include <Server/TCPProtocolStackHandler.h> +#include <Poco/Logger.h> +#include <Poco/Net/NetException.h> +#include <Common/logger_useful.h> +#include <Access/Common/AllowedClientHosts.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int UNKNOWN_ADDRESS_PATTERN_TYPE; + extern const int IP_ADDRESS_NOT_ALLOWED; +} + + +class TCPProtocolStackFactory : public TCPServerConnectionFactory +{ +private: + IServer & server [[maybe_unused]]; + Poco::Logger * log; + std::string conf_name; + std::vector<TCPServerConnectionFactory::Ptr> stack; + AllowedClientHosts allowed_client_hosts; + + class DummyTCPHandler : public Poco::Net::TCPServerConnection + { + public: + using Poco::Net::TCPServerConnection::TCPServerConnection; + void run() override {} + }; + +public: + template <typename... T> + explicit TCPProtocolStackFactory(IServer & server_, const std::string & conf_name_, T... factory) + : server(server_), log(&Poco::Logger::get("TCPProtocolStackFactory")), conf_name(conf_name_), stack({factory...}) + { + const auto & config = server.config(); + /// Fill list of allowed hosts. + const auto networks_config = conf_name + ".networks"; + if (config.has(networks_config)) + { + Poco::Util::AbstractConfiguration::Keys keys; + config.keys(networks_config, keys); + for (const String & key : keys) + { + String value = config.getString(networks_config + "." 
+ key); + if (key.starts_with("ip")) + allowed_client_hosts.addSubnet(value); + else if (key.starts_with("host_regexp")) + allowed_client_hosts.addNameRegexp(value); + else if (key.starts_with("host")) + allowed_client_hosts.addName(value); + else + throw Exception(ErrorCodes::UNKNOWN_ADDRESS_PATTERN_TYPE, "Unknown address pattern type: {}", key); + } + } + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override + { + if (!allowed_client_hosts.empty() && !allowed_client_hosts.contains(socket.peerAddress().host())) + throw Exception(ErrorCodes::IP_ADDRESS_NOT_ALLOWED, "Connections from {} are not allowed", socket.peerAddress().toString()); + + try + { + LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString()); + return new TCPProtocolStackHandler(server, tcp_server, socket, stack, conf_name); + } + catch (const Poco::Net::NetException &) + { + LOG_TRACE(log, "TCP Request. Client is not connected (most likely RST packet was sent)."); + return new DummyTCPHandler(socket); + } + } + + void append(TCPServerConnectionFactory::Ptr factory) + { + stack.push_back(std::move(factory)); + } + + size_t size() { return stack.size(); } + bool empty() { return stack.empty(); } +}; + + +} diff --git a/contrib/clickhouse/src/Server/TCPProtocolStackHandler.h b/contrib/clickhouse/src/Server/TCPProtocolStackHandler.h new file mode 100644 index 0000000000..e16a6b6b2c --- /dev/null +++ b/contrib/clickhouse/src/Server/TCPProtocolStackHandler.h @@ -0,0 +1,46 @@ +#pragma once + +#include <Server/TCPServerConnectionFactory.h> +#include <Server/TCPServer.h> +#include <Poco/Util/LayeredConfiguration.h> +#include <Server/IServer.h> +#include <Server/TCPProtocolStackData.h> + + +namespace DB +{ + + +class TCPProtocolStackHandler : public Poco::Net::TCPServerConnection +{ + using StreamSocket = Poco::Net::StreamSocket; + using TCPServerConnection = Poco::Net::TCPServerConnection; +private: + IServer & 
server; + TCPServer & tcp_server; + std::vector<TCPServerConnectionFactory::Ptr> stack; + std::string conf_name; + +public: + TCPProtocolStackHandler(IServer & server_, TCPServer & tcp_server_, const StreamSocket & socket, const std::vector<TCPServerConnectionFactory::Ptr> & stack_, const std::string & conf_name_) + : TCPServerConnection(socket), server(server_), tcp_server(tcp_server_), stack(stack_), conf_name(conf_name_) + {} + + void run() override + { + const auto & conf = server.config(); + TCPProtocolStackData stack_data; + stack_data.socket = socket(); + stack_data.default_database = conf.getString(conf_name + ".default_database", ""); + for (auto & factory : stack) + { + std::unique_ptr<TCPServerConnection> connection(factory->createConnection(socket(), tcp_server, stack_data)); + connection->run(); + if (stack_data.socket != socket()) + socket() = stack_data.socket; + } + } +}; + + +} diff --git a/contrib/clickhouse/src/Server/TCPServer.cpp b/contrib/clickhouse/src/Server/TCPServer.cpp new file mode 100644 index 0000000000..380c4ef992 --- /dev/null +++ b/contrib/clickhouse/src/Server/TCPServer.cpp @@ -0,0 +1,36 @@ +#include <Poco/Net/TCPServerConnectionFactory.h> +#include <Server/TCPServer.h> + +namespace DB +{ + +class TCPServerConnectionFactoryImpl : public Poco::Net::TCPServerConnectionFactory +{ +public: + TCPServerConnectionFactoryImpl(TCPServer & tcp_server_, DB::TCPServerConnectionFactory::Ptr factory_) + : tcp_server(tcp_server_) + , factory(factory_) + {} + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket) override + { + return factory->createConnection(socket, tcp_server); + } +private: + TCPServer & tcp_server; + DB::TCPServerConnectionFactory::Ptr factory; +}; + +TCPServer::TCPServer( + TCPServerConnectionFactory::Ptr factory_, + Poco::ThreadPool & thread_pool, + Poco::Net::ServerSocket & socket_, + Poco::Net::TCPServerParams::Ptr params) + : Poco::Net::TCPServer(new 
TCPServerConnectionFactoryImpl(*this, factory_), thread_pool, socket_, params) + , factory(factory_) + , socket(socket_) + , is_open(true) + , port_number(socket.address().port()) +{} + +} diff --git a/contrib/clickhouse/src/Server/TCPServer.h b/contrib/clickhouse/src/Server/TCPServer.h new file mode 100644 index 0000000000..219fed5342 --- /dev/null +++ b/contrib/clickhouse/src/Server/TCPServer.h @@ -0,0 +1,47 @@ +#pragma once + +#include <Poco/Net/TCPServer.h> + +#include <base/types.h> +#include <Server/TCPServerConnectionFactory.h> + + +namespace DB +{ +class Context; + +class TCPServer : public Poco::Net::TCPServer +{ +public: + explicit TCPServer( + TCPServerConnectionFactory::Ptr factory, + Poco::ThreadPool & thread_pool, + Poco::Net::ServerSocket & socket, + Poco::Net::TCPServerParams::Ptr params = new Poco::Net::TCPServerParams); + + /// Close the socket and ask existing connections to stop serving queries + void stop() + { + Poco::Net::TCPServer::stop(); + // This notifies already established connections that they should stop serving + // queries and close their socket as soon as they can. + is_open = false; + // Poco's stop() stops listening on the socket but leaves it open. + // To be able to hand over control of the listening port to a new server, and + // to get fast connection refusal instead of timeouts, we also need to close + // the listening socket. 
+ socket.close(); + } + + bool isOpen() const { return is_open; } + + UInt16 portNumber() const { return port_number; } + +private: + TCPServerConnectionFactory::Ptr factory; + Poco::Net::ServerSocket socket; + std::atomic<bool> is_open; + UInt16 port_number; +}; + +} diff --git a/contrib/clickhouse/src/Server/TCPServerConnectionFactory.h b/contrib/clickhouse/src/Server/TCPServerConnectionFactory.h new file mode 100644 index 0000000000..18b30557b0 --- /dev/null +++ b/contrib/clickhouse/src/Server/TCPServerConnectionFactory.h @@ -0,0 +1,32 @@ +#pragma once + +#include <Poco/SharedPtr.h> +#include <Server/TCPProtocolStackData.h> + +namespace Poco +{ +namespace Net +{ + class StreamSocket; + class TCPServerConnection; +} +} +namespace DB +{ +class TCPServer; + +class TCPServerConnectionFactory +{ +public: + using Ptr = Poco::SharedPtr<TCPServerConnectionFactory>; + + virtual ~TCPServerConnectionFactory() = default; + + /// Same as Poco::Net::TCPServerConnectionFactory except we can pass the TCPServer + virtual Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) = 0; + virtual Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server, TCPProtocolStackData &/* stack_data */) + { + return createConnection(socket, tcp_server); + } +}; +} diff --git a/contrib/clickhouse/src/Server/TLSHandler.h b/contrib/clickhouse/src/Server/TLSHandler.h new file mode 100644 index 0000000000..d4d7584d12 --- /dev/null +++ b/contrib/clickhouse/src/Server/TLSHandler.h @@ -0,0 +1,58 @@ +#pragma once + +#include <Poco/Net/TCPServerConnection.h> +#include <Poco/SharedPtr.h> +#include <Common/Exception.h> +#include <Server/TCPProtocolStackData.h> + +#if USE_SSL +# include <Poco/Net/Context.h> +# include <Poco/Net/SecureStreamSocket.h> +# include <Poco/Net/SSLManager.h> +#endif + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int SUPPORT_IS_DISABLED; +} + +class 
TLSHandler : public Poco::Net::TCPServerConnection +{ +#if USE_SSL + using SecureStreamSocket = Poco::Net::SecureStreamSocket; + using SSLManager = Poco::Net::SSLManager; + using Context = Poco::Net::Context; +#endif + using StreamSocket = Poco::Net::StreamSocket; +public: + explicit TLSHandler(const StreamSocket & socket, const std::string & key_, const std::string & certificate_, TCPProtocolStackData & stack_data_) + : Poco::Net::TCPServerConnection(socket) + , key(key_) + , certificate(certificate_) + , stack_data(stack_data_) + {} + + void run() override + { +#if USE_SSL + auto ctx = SSLManager::instance().defaultServerContext(); + // if (!key.empty() && !certificate.empty()) + // ctx = new Context(Context::Usage::SERVER_USE, key, certificate, ctx->getCAPaths().caLocation); + socket() = SecureStreamSocket::attach(socket(), ctx); + stack_data.socket = socket(); + stack_data.certificate = certificate; +#else + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "SSL support for TCP protocol is disabled because Poco library was built without NetSSL support."); +#endif + } +private: + std::string key [[maybe_unused]]; + std::string certificate [[maybe_unused]]; + TCPProtocolStackData & stack_data [[maybe_unused]]; +}; + + +} diff --git a/contrib/clickhouse/src/Server/TLSHandlerFactory.h b/contrib/clickhouse/src/Server/TLSHandlerFactory.h new file mode 100644 index 0000000000..9e3002d297 --- /dev/null +++ b/contrib/clickhouse/src/Server/TLSHandlerFactory.h @@ -0,0 +1,64 @@ +#pragma once + +#include <Poco/Logger.h> +#include <Poco/Net/TCPServerConnection.h> +#include <Poco/Net/NetException.h> +#include <Poco/Util/LayeredConfiguration.h> +#include <Server/TLSHandler.h> +#include <Server/IServer.h> +#include <Server/TCPServer.h> +#include <Server/TCPProtocolStackData.h> +#include <Common/logger_useful.h> + + +namespace DB +{ + + +class TLSHandlerFactory : public TCPServerConnectionFactory +{ +private: + IServer & server; + Poco::Logger * log; + std::string conf_name; + + 
class DummyTCPHandler : public Poco::Net::TCPServerConnection + { + public: + using Poco::Net::TCPServerConnection::TCPServerConnection; + void run() override {} + }; + +public: + explicit TLSHandlerFactory(IServer & server_, const std::string & conf_name_) + : server(server_), log(&Poco::Logger::get("TLSHandlerFactory")), conf_name(conf_name_) + { + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer & tcp_server) override + { + TCPProtocolStackData stack_data; + return createConnection(socket, tcp_server, stack_data); + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &/* tcp_server*/, TCPProtocolStackData & stack_data) override + { + try + { + LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString()); + return new TLSHandler( + socket, + server.config().getString(conf_name + ".privateKeyFile", ""), + server.config().getString(conf_name + ".certificateFile", ""), + stack_data); + } + catch (const Poco::Net::NetException &) + { + LOG_TRACE(log, "TCP Request. 
Client is not connected (most likely RST packet was sent)."); + return new DummyTCPHandler(socket); + } + } +}; + + +} diff --git a/contrib/clickhouse/src/Server/WebUIRequestHandler.cpp b/contrib/clickhouse/src/Server/WebUIRequestHandler.cpp new file mode 100644 index 0000000000..131badbe83 --- /dev/null +++ b/contrib/clickhouse/src/Server/WebUIRequestHandler.cpp @@ -0,0 +1,75 @@ +#include "WebUIRequestHandler.h" +#include "IServer.h" + +#include <Poco/Net/HTTPServerRequest.h> +#include <Poco/Net/HTTPServerResponse.h> +#include <Poco/Util/LayeredConfiguration.h> + +#include <IO/HTTPCommon.h> + +#include <re2/re2.h> + +#include <incbin.h> + +#include "clickhouse_config.h" + +/// Embedded HTML pages +INCBIN(resource_play_html, SOURCE_DIR "/programs/server/play.html"); +INCBIN(resource_dashboard_html, SOURCE_DIR "/programs/server/dashboard.html"); +INCBIN(resource_uplot_js, SOURCE_DIR "/programs/server/js/uplot.js"); + + +namespace DB +{ + +WebUIRequestHandler::WebUIRequestHandler(IServer & server_) + : server(server_) +{ +} + + +void WebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) +{ + auto keep_alive_timeout = server.config().getUInt("keep_alive_timeout", 10); + + response.setContentType("text/html; charset=UTF-8"); + + if (request.getVersion() == HTTPServerRequest::HTTP_1_1) + response.setChunkedTransferEncoding(true); + + setResponseDefaultHeaders(response, keep_alive_timeout); + + if (request.getURI().starts_with("/play")) + { + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_OK); + *response.send() << std::string_view(reinterpret_cast<const char *>(gresource_play_htmlData), gresource_play_htmlSize); + } + else if (request.getURI().starts_with("/dashboard")) + { + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_OK); + + std::string html(reinterpret_cast<const char *>(gresource_dashboard_htmlData), gresource_dashboard_htmlSize); + + /// Replace a link to external JavaScript file to embedded file. 
+ /// This allows to open the HTML without running a server and to host it on server. + /// Note: we can embed the JavaScript file inline to the HTML, + /// but we don't do it to keep the "view-source" perfectly readable. + + static re2::RE2 uplot_url = R"(https://[^\s"'`]+u[Pp]lot[^\s"'`]*\.js)"; + RE2::Replace(&html, uplot_url, "/js/uplot.js"); + + *response.send() << html; + } + else if (request.getURI() == "/js/uplot.js") + { + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_OK); + *response.send() << std::string_view(reinterpret_cast<const char *>(gresource_uplot_jsData), gresource_uplot_jsSize); + } + else + { + response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_NOT_FOUND); + *response.send() << "Not found.\n"; + } +} + +} diff --git a/contrib/clickhouse/src/Server/WebUIRequestHandler.h b/contrib/clickhouse/src/Server/WebUIRequestHandler.h new file mode 100644 index 0000000000..09fe62d41c --- /dev/null +++ b/contrib/clickhouse/src/Server/WebUIRequestHandler.h @@ -0,0 +1,22 @@ +#pragma once + +#include <Server/HTTP/HTTPRequestHandler.h> + + +namespace DB +{ + +class IServer; + +/// Response with HTML page that allows to send queries and show results in browser. 
+class WebUIRequestHandler : public HTTPRequestHandler +{ +private: + IServer & server; + +public: + WebUIRequestHandler(IServer & server_); + void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response) override; +}; + +} diff --git a/contrib/clickhouse/src/Server/waitServersToFinish.cpp b/contrib/clickhouse/src/Server/waitServersToFinish.cpp new file mode 100644 index 0000000000..3b07c08206 --- /dev/null +++ b/contrib/clickhouse/src/Server/waitServersToFinish.cpp @@ -0,0 +1,39 @@ +#include <Server/waitServersToFinish.h> +#include <Server/ProtocolServerAdapter.h> +#include <base/sleep.h> + +namespace DB +{ + +size_t waitServersToFinish(std::vector<DB::ProtocolServerAdapter> & servers, std::mutex & mutex, size_t seconds_to_wait) +{ + const size_t sleep_max_ms = 1000 * seconds_to_wait; + const size_t sleep_one_ms = 100; + size_t sleep_current_ms = 0; + size_t current_connections = 0; + for (;;) + { + current_connections = 0; + + { + std::scoped_lock lock{mutex}; + for (auto & server : servers) + { + server.stop(); + current_connections += server.currentConnections(); + } + } + + if (!current_connections) + break; + + sleep_current_ms += sleep_one_ms; + if (sleep_current_ms < sleep_max_ms) + sleepForMilliseconds(sleep_one_ms); + else + break; + } + return current_connections; +} + +} diff --git a/contrib/clickhouse/src/Server/waitServersToFinish.h b/contrib/clickhouse/src/Server/waitServersToFinish.h new file mode 100644 index 0000000000..b6daa02596 --- /dev/null +++ b/contrib/clickhouse/src/Server/waitServersToFinish.h @@ -0,0 +1,10 @@ +#pragma once +#include <Core/Types.h> + +namespace DB +{ +class ProtocolServerAdapter; + +size_t waitServersToFinish(std::vector<ProtocolServerAdapter> & servers, std::mutex & mutex, size_t seconds_to_wait); + +} |