| author | vitalyisaev <vitalyisaev@ydb.tech> | 2023-11-14 09:58:56 +0300 |
|---|---|---|
| committer | vitalyisaev <vitalyisaev@ydb.tech> | 2023-11-14 10:20:20 +0300 |
| commit | c2b2dfd9827a400a8495e172a56343462e3ceb82 (patch) | |
| tree | cd4e4f597d01bede4c82dffeb2d780d0a9046bd0 /contrib/clickhouse/src/IO | |
| parent | d4ae8f119e67808cb0cf776ba6e0cf95296f2df7 (diff) | |
| download | ydb-c2b2dfd9827a400a8495e172a56343462e3ceb82.tar.gz | |
YQ Connector: move tests from yql to ydb (OSS)
Move the Connector tests folder from the yql folder to the ydb folder (kept in sync with github).
Diffstat (limited to 'contrib/clickhouse/src/IO')
268 files changed, 34177 insertions, 0 deletions
diff --git a/contrib/clickhouse/src/IO/AIO.cpp b/contrib/clickhouse/src/IO/AIO.cpp new file mode 100644 index 0000000000..7088be633e --- /dev/null +++ b/contrib/clickhouse/src/IO/AIO.cpp @@ -0,0 +1,148 @@ +#include <IO/AIO.h> + +#if defined(OS_LINUX) + +# include <Common/Exception.h> + +# include <sys/syscall.h> +# include <unistd.h> +# include <utility> + + +/** Small wrappers for asynchronous I/O. + */ + +namespace DB +{ + namespace ErrorCodes + { + extern const int CANNOT_IOSETUP; + } +} + + +int io_setup(unsigned nr, aio_context_t * ctxp) +{ + return static_cast<int>(syscall(__NR_io_setup, nr, ctxp)); +} + +int io_destroy(aio_context_t ctx) +{ + return static_cast<int>(syscall(__NR_io_destroy, ctx)); +} + +int io_submit(aio_context_t ctx, long nr, struct iocb * iocbpp[]) // NOLINT +{ + return static_cast<int>(syscall(__NR_io_submit, ctx, nr, iocbpp)); +} + +int io_getevents(aio_context_t ctx, long min_nr, long max_nr, io_event * events, struct timespec * timeout) // NOLINT +{ + return static_cast<int>(syscall(__NR_io_getevents, ctx, min_nr, max_nr, events, timeout)); +} + + +AIOContext::AIOContext(unsigned int nr_events) +{ + ctx = 0; + if (io_setup(nr_events, &ctx) < 0) + DB::throwFromErrno("io_setup failed", DB::ErrorCodes::CANNOT_IOSETUP); +} + +AIOContext::~AIOContext() +{ + if (ctx) + io_destroy(ctx); +} + +AIOContext::AIOContext(AIOContext && rhs) noexcept +{ + *this = std::move(rhs); +} + +AIOContext & AIOContext::operator=(AIOContext && rhs) noexcept +{ + std::swap(ctx, rhs.ctx); + return *this; +} + +#elif defined(OS_FREEBSD) + +# include <Common/Exception.h> + + +/** Small wrappers for asynchronous I/O. + */ + +namespace DB +{ +namespace ErrorCodes +{ + extern const int CANNOT_IOSETUP; +} +} + + +int io_setup(void) +{ + return kqueue(); +} + +int io_destroy(int ctx) +{ + return close(ctx); +} + +int io_submit(int ctx, long nr, struct iocb * iocbpp[]) +{ + for (long i = 0; i < nr; ++i) + { + struct aiocb * iocb = &iocbpp[i]->aio; + + struct sigevent * se = &iocb->aio_sigevent; + se->sigev_notify_kqueue = ctx; + se->sigev_notify_kevent_flags = 0; + se->sigev_notify = SIGEV_KEVENT; + se->sigev_value.sival_ptr = iocbpp[i]; + + switch (iocb->aio_lio_opcode) + { + case LIO_READ: + { + int r = aio_read(iocb); + if (r < 0) + return r; + break; + } + case LIO_WRITE: + { + int r = aio_write(iocb); + if (r < 0) + return r; + break; + } + } + } + + return static_cast<int>(nr); +} + +int io_getevents(int ctx, long, long max_nr, struct kevent * events, struct timespec * timeout) +{ + return kevent(ctx, nullptr, 0, events, static_cast<int>(max_nr), timeout); +} + + +AIOContext::AIOContext(unsigned int) +{ + ctx = io_setup(); + if (ctx < 0) + DB::throwFromErrno("io_setup failed", DB::ErrorCodes::CANNOT_IOSETUP); +} + +AIOContext::~AIOContext() +{ + io_destroy(ctx); +} + +#endif diff --git a/contrib/clickhouse/src/IO/AIO.h b/contrib/clickhouse/src/IO/AIO.h new file mode 100644 index 0000000000..202939638b --- /dev/null +++ b/contrib/clickhouse/src/IO/AIO.h @@ -0,0 +1,79 @@ +#pragma once + +#include <boost/noncopyable.hpp> + +#if defined(OS_LINUX) + +/// https://stackoverflow.com/questions/20759750/resolving-redefinition-of-timespec-in-time-h +# define timespec linux_timespec +# define timeval linux_timeval +# define itimerspec linux_itimerspec +# define sigset_t linux_sigset_t + +# include <linux/aio_abi.h> + +# undef timespec +# undef timeval +# undef itimerspec +# undef sigset_t + + +/** Small wrappers for asynchronous I/O. 
+ */ + +int io_setup(unsigned nr, aio_context_t * ctxp); + +int io_destroy(aio_context_t ctx); + +/// last argument is an array of pointers technically speaking +int io_submit(aio_context_t ctx, long nr, struct iocb * iocbpp[]); /// NOLINT + +int io_getevents(aio_context_t ctx, long min_nr, long max_nr, io_event * events, struct timespec * timeout); /// NOLINT + + +struct AIOContext : private boost::noncopyable +{ + aio_context_t ctx = 0; + + AIOContext() = default; + explicit AIOContext(unsigned int nr_events); + ~AIOContext(); + AIOContext(AIOContext && rhs) noexcept; + AIOContext & operator=(AIOContext && rhs) noexcept; +}; + +#elif defined(OS_FREEBSD) + +# include <aio.h> +# include <sys/event.h> +# include <sys/time.h> +# include <sys/types.h> + +typedef struct kevent io_event; +typedef int aio_context_t; + +struct iocb +{ + struct aiocb aio; + long aio_data; +}; + +int io_setup(void); + +int io_destroy(void); + +/// last argument is an array of pointers technically speaking +int io_submit(int ctx, long nr, struct iocb * iocbpp[]); + +int io_getevents(int ctx, long min_nr, long max_nr, struct kevent * events, struct timespec * timeout); + + +struct AIOContext : private boost::noncopyable +{ + int ctx; + + AIOContext(unsigned int nr_events = 128); + ~AIOContext(); +}; + +#endif diff --git a/contrib/clickhouse/src/IO/Archives/ArchiveUtils.h b/contrib/clickhouse/src/IO/Archives/ArchiveUtils.h new file mode 100644 index 0000000000..00bebcc890 --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/ArchiveUtils.h @@ -0,0 +1,14 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_LIBARCHIVE + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-macro-identifier" + +#error #include <archive.h> +#error #include <archive_entry.h> +#endif +#endif diff --git a/contrib/clickhouse/src/IO/Archives/IArchiveReader.h b/contrib/clickhouse/src/IO/Archives/IArchiveReader.h new file mode 100644 index 0000000000..84a1dc21f5 --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/IArchiveReader.h @@ -0,0 +1,69 @@ +#pragma once + +#include <boost/noncopyable.hpp> +#include <base/types.h> +#include <functional> +#include <memory> + + +namespace DB +{ +class ReadBuffer; +class ReadBufferFromFileBase; +class SeekableReadBuffer; + +/// Interface for reading an archive. +class IArchiveReader : public std::enable_shared_from_this<IArchiveReader>, boost::noncopyable +{ +public: + virtual ~IArchiveReader() = default; + + /// Returns true if there is a specified file in the archive. + virtual bool fileExists(const String & filename) = 0; + + struct FileInfo + { + UInt64 uncompressed_size; + UInt64 compressed_size; + bool is_encrypted; + }; + + /// Returns the information about a file stored in the archive. + virtual FileInfo getFileInfo(const String & filename) = 0; + + class FileEnumerator + { + public: + virtual ~FileEnumerator() = default; + virtual const String & getFileName() const = 0; + virtual const FileInfo & getFileInfo() const = 0; + virtual bool nextFile() = 0; + }; + + virtual const std::string & getPath() const = 0; + + /// Starts enumerating files in the archive. + virtual std::unique_ptr<FileEnumerator> firstFile() = 0; + + using NameFilter = std::function<bool(const std::string &)>; + + /// Starts reading a file from the archive. The function returns a read buffer, + /// you can read that buffer to extract uncompressed data from the archive. + /// Several read buffers can be used at the same time in parallel. 
+ virtual std::unique_ptr<ReadBufferFromFileBase> readFile(const String & filename, bool throw_on_not_found) = 0; + virtual std::unique_ptr<ReadBufferFromFileBase> readFile(NameFilter filter, bool throw_on_not_found) = 0; + + /// It's possible to convert a file enumerator to a read buffer and vice versa. + virtual std::unique_ptr<ReadBufferFromFileBase> readFile(std::unique_ptr<FileEnumerator> enumerator) = 0; + virtual std::unique_ptr<FileEnumerator> nextFile(std::unique_ptr<ReadBuffer> read_buffer) = 0; + + virtual std::vector<std::string> getAllFiles() = 0; + virtual std::vector<std::string> getAllFiles(NameFilter filter) = 0; + + /// Sets password used to decrypt files in the archive. + virtual void setPassword(const String & /* password */) {} + + using ReadArchiveFunction = std::function<std::unique_ptr<SeekableReadBuffer>()>; +}; + +} diff --git a/contrib/clickhouse/src/IO/Archives/IArchiveWriter.h b/contrib/clickhouse/src/IO/Archives/IArchiveWriter.h new file mode 100644 index 0000000000..d7ff038e7b --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/IArchiveWriter.h @@ -0,0 +1,39 @@ +#pragma once + +#include <boost/noncopyable.hpp> +#include <base/types.h> +#include <memory> + + +namespace DB +{ +class WriteBufferFromFileBase; + +/// Interface for writing an archive. +class IArchiveWriter : public std::enable_shared_from_this<IArchiveWriter>, boost::noncopyable +{ +public: + /// Destructors finalizes writing the archive. + virtual ~IArchiveWriter() = default; + + /// Starts writing a file to the archive. The function returns a write buffer, + /// any data written to that buffer will be compressed and then put to the archive. + /// You can keep only one such buffer at a time, a buffer returned by previous call + /// of the function `writeFile()` should be destroyed before next call of `writeFile()`. + virtual std::unique_ptr<WriteBufferFromFileBase> writeFile(const String & filename) = 0; + + /// Returns true if there is an active instance of WriteBuffer returned by writeFile(). + /// This function should be used mostly for debugging purposes. + virtual bool isWritingFile() const = 0; + + static constexpr const int kDefaultCompressionLevel = -1; + + /// Sets compression method and level. + /// Changing them will affect next file in the archive. + virtual void setCompression(const String & /* compression_method */, int /* compression_level */ = kDefaultCompressionLevel) {} + + /// Sets password. If the password is not empty it will enable encryption in the archive. 
+ virtual void setPassword(const String & /* password */) {} +}; + +} diff --git a/contrib/clickhouse/src/IO/Archives/LibArchiveReader.cpp b/contrib/clickhouse/src/IO/Archives/LibArchiveReader.cpp new file mode 100644 index 0000000000..a411b4bb4b --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/LibArchiveReader.cpp @@ -0,0 +1,354 @@ +#include <IO/Archives/LibArchiveReader.h> +#include <IO/ReadBufferFromFileBase.h> +#include <Common/quoteString.h> +#include <Common/scope_guard_safe.h> + +#include <IO/Archives/ArchiveUtils.h> + +#include <mutex> + +namespace DB +{ + +#if USE_LIBARCHIVE + +namespace ErrorCodes +{ + extern const int CANNOT_UNPACK_ARCHIVE; + extern const int LOGICAL_ERROR; + extern const int CANNOT_READ_ALL_DATA; + extern const int UNSUPPORTED_METHOD; +} + +class LibArchiveReader::Handle +{ +public: + explicit Handle(std::string path_to_archive_, bool lock_on_reading_) + : path_to_archive(path_to_archive_), lock_on_reading(lock_on_reading_) + { + current_archive = open(path_to_archive); + } + + Handle(const Handle &) = delete; + Handle(Handle && other) noexcept + : current_archive(other.current_archive) + , current_entry(other.current_entry) + , lock_on_reading(other.lock_on_reading) + { + other.current_archive = nullptr; + other.current_entry = nullptr; + } + + ~Handle() + { + close(current_archive); + } + + bool locateFile(const std::string & filename) + { + return locateFile([&](const std::string & file) { return file == filename; }); + } + + bool locateFile(NameFilter filter) + { + resetFileInfo(); + int err = ARCHIVE_OK; + while (true) + { + err = readNextHeader(current_archive, ¤t_entry); + + if (err == ARCHIVE_RETRY) + continue; + + if (err != ARCHIVE_OK) + break; + + if (filter(archive_entry_pathname(current_entry))) + return true; + } + + checkError(err); + return false; + } + + bool nextFile() + { + resetFileInfo(); + int err = ARCHIVE_OK; + do + { + err = readNextHeader(current_archive, ¤t_entry); + } while (err == ARCHIVE_RETRY); + + checkError(err); + return err == ARCHIVE_OK; + } + + std::vector<std::string> getAllFiles(NameFilter filter) + { + auto * archive = open(path_to_archive); + SCOPE_EXIT( + close(archive); + ); + + struct archive_entry * entry = nullptr; + + std::vector<std::string> files; + int error = readNextHeader(archive, &entry); + while (error == ARCHIVE_OK || error == ARCHIVE_RETRY) + { + chassert(entry != nullptr); + std::string name = archive_entry_pathname(entry); + if (!filter || filter(name)) + files.push_back(std::move(name)); + + error = readNextHeader(archive, &entry); + } + + checkError(error); + return files; + } + + const String & getFileName() const + { + chassert(current_entry); + if (!file_name) + file_name.emplace(archive_entry_pathname(current_entry)); + + return *file_name; + } + + const FileInfo & getFileInfo() const + { + chassert(current_entry); + if (!file_info) + { + file_info.emplace(); + file_info->uncompressed_size = archive_entry_size(current_entry); + file_info->compressed_size = archive_entry_size(current_entry); + file_info->is_encrypted = false; + } + + return *file_info; + } + + struct archive * current_archive; + struct archive_entry * current_entry = nullptr; +private: + void checkError(int error) const + { + if (error == ARCHIVE_FATAL) + throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Failed to read archive while fetching all files: {}", archive_error_string(current_archive)); + } + + void resetFileInfo() + { + file_name.reset(); + file_info.reset(); + } + + static struct archive * open(const String & 
path_to_archive) + { + auto * archive = archive_read_new(); + try + { + archive_read_support_filter_all(archive); + archive_read_support_format_all(archive); + if (archive_read_open_filename(archive, path_to_archive.c_str(), 10240) != ARCHIVE_OK) + throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Couldn't open archive {}: {}", quoteString(path_to_archive), archive_error_string(archive)); + } + catch (...) + { + close(archive); + throw; + } + + return archive; + } + + static void close(struct archive * archive) + { + if (archive) + { + archive_read_close(archive); + archive_read_free(archive); + } + } + + int readNextHeader(struct archive * archive, struct archive_entry ** entry) const + { + std::unique_lock lock(Handle::read_lock, std::defer_lock); + if (lock_on_reading) + lock.lock(); + + return archive_read_next_header(archive, entry); + } + + const String path_to_archive; + + /// for some archive types when we are reading headers static variables are used + /// which are not thread-safe + const bool lock_on_reading; + static inline std::mutex read_lock; + + mutable std::optional<String> file_name; + mutable std::optional<FileInfo> file_info; +}; + +class LibArchiveReader::FileEnumeratorImpl : public FileEnumerator +{ +public: + explicit FileEnumeratorImpl(Handle handle_) : handle(std::move(handle_)) {} + + const String & getFileName() const override { return handle.getFileName(); } + const FileInfo & getFileInfo() const override { return handle.getFileInfo(); } + bool nextFile() override { return handle.nextFile(); } + + /// Releases owned handle to pass it to a read buffer. + Handle releaseHandle() && { return std::move(handle); } +private: + Handle handle; +}; + +class LibArchiveReader::ReadBufferFromLibArchive : public ReadBufferFromFileBase +{ +public: + explicit ReadBufferFromLibArchive(Handle handle_, std::string path_to_archive_) + : ReadBufferFromFileBase(DBMS_DEFAULT_BUFFER_SIZE, nullptr, 0) + , handle(std::move(handle_)) + , path_to_archive(std::move(path_to_archive_)) + {} + + off_t seek(off_t /* off */, int /* whence */) override + { + throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Seek is not supported when reading from archive"); + } + + off_t getPosition() override + { + throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "getPosition not supported when reading from archive"); + } + + String getFileName() const override { return handle.getFileName(); } + + size_t getFileSize() override { return handle.getFileInfo().uncompressed_size; } + + Handle releaseHandle() && + { + return std::move(handle); + } + +private: + bool nextImpl() override + { + auto bytes_read = archive_read_data(handle.current_archive, internal_buffer.begin(), static_cast<int>(internal_buffer.size())); + + if (bytes_read < 0) + throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Failed to read file {} from {}: {}", handle.getFileName(), path_to_archive, archive_error_string(handle.current_archive)); + + if (!bytes_read) + return false; + + total_bytes_read += bytes; + + working_buffer = internal_buffer; + working_buffer.resize(bytes_read); + return true; + } + + Handle handle; + const String path_to_archive; + size_t total_bytes_read = 0; +}; + +LibArchiveReader::LibArchiveReader(std::string archive_name_, bool lock_on_reading_, std::string path_to_archive_) + : archive_name(std::move(archive_name_)), lock_on_reading(lock_on_reading_), path_to_archive(std::move(path_to_archive_)) +{} + +LibArchiveReader::~LibArchiveReader() = default; + +const std::string & LibArchiveReader::getPath() const +{ + return 
path_to_archive; +} + +bool LibArchiveReader::fileExists(const String & filename) +{ + Handle handle(path_to_archive, lock_on_reading); + return handle.locateFile(filename); +} + +LibArchiveReader::FileInfo LibArchiveReader::getFileInfo(const String & filename) +{ + Handle handle(path_to_archive, lock_on_reading); + if (!handle.locateFile(filename)) + throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Couldn't unpack archive {}: file not found", path_to_archive); + return handle.getFileInfo(); +} + +std::unique_ptr<LibArchiveReader::FileEnumerator> LibArchiveReader::firstFile() +{ + Handle handle(path_to_archive, lock_on_reading); + if (!handle.nextFile()) + return nullptr; + + return std::make_unique<FileEnumeratorImpl>(std::move(handle)); +} + +std::unique_ptr<ReadBufferFromFileBase> LibArchiveReader::readFile(const String & filename, bool throw_on_not_found) +{ + return readFile([&](const std::string & file) { return file == filename; }, throw_on_not_found); +} + +std::unique_ptr<ReadBufferFromFileBase> LibArchiveReader::readFile(NameFilter filter, bool throw_on_not_found) +{ + Handle handle(path_to_archive, lock_on_reading); + if (!handle.locateFile(filter)) + { + if (throw_on_not_found) + throw Exception( + ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Couldn't unpack archive {}: no file found satisfying the filter", path_to_archive); + return nullptr; + } + return std::make_unique<ReadBufferFromLibArchive>(std::move(handle), path_to_archive); +} + +std::unique_ptr<ReadBufferFromFileBase> LibArchiveReader::readFile(std::unique_ptr<FileEnumerator> enumerator) +{ + if (!dynamic_cast<FileEnumeratorImpl *>(enumerator.get())) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong enumerator passed to readFile()"); + auto enumerator_impl = std::unique_ptr<FileEnumeratorImpl>(static_cast<FileEnumeratorImpl *>(enumerator.release())); + auto handle = std::move(*enumerator_impl).releaseHandle(); + return std::make_unique<ReadBufferFromLibArchive>(std::move(handle), path_to_archive); +} + +std::unique_ptr<LibArchiveReader::FileEnumerator> LibArchiveReader::nextFile(std::unique_ptr<ReadBuffer> read_buffer) +{ + if (!dynamic_cast<ReadBufferFromLibArchive *>(read_buffer.get())) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong ReadBuffer passed to nextFile()"); + auto read_buffer_from_libarchive = std::unique_ptr<ReadBufferFromLibArchive>(static_cast<ReadBufferFromLibArchive *>(read_buffer.release())); + auto handle = std::move(*read_buffer_from_libarchive).releaseHandle(); + if (!handle.nextFile()) + return nullptr; + return std::make_unique<FileEnumeratorImpl>(std::move(handle)); +} + +std::vector<std::string> LibArchiveReader::getAllFiles() +{ + return getAllFiles({}); +} + +std::vector<std::string> LibArchiveReader::getAllFiles(NameFilter filter) +{ + Handle handle(path_to_archive, lock_on_reading); + return handle.getAllFiles(filter); +} + +void LibArchiveReader::setPassword(const String & /*password_*/) +{ + throw Exception(ErrorCodes::LOGICAL_ERROR, "Can not set password to {} archive", archive_name); +} + +#endif + +} diff --git a/contrib/clickhouse/src/IO/Archives/LibArchiveReader.h b/contrib/clickhouse/src/IO/Archives/LibArchiveReader.h new file mode 100644 index 0000000000..5d5e9a5a25 --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/LibArchiveReader.h @@ -0,0 +1,78 @@ +#pragma once + +#include "clickhouse_config.h" + +#include <IO/Archives/IArchiveReader.h> + + +namespace DB +{ + +#if USE_LIBARCHIVE + +class ReadBuffer; +class ReadBufferFromFileBase; +class SeekableReadBuffer; + +/// 
Implementation of IArchiveReader for reading archives using libarchive. +class LibArchiveReader : public IArchiveReader +{ +public: + ~LibArchiveReader() override; + + const std::string & getPath() const override; + + /// Returns true if there is a specified file in the archive. + bool fileExists(const String & filename) override; + + /// Returns the information about a file stored in the archive. + FileInfo getFileInfo(const String & filename) override; + + /// Starts enumerating files in the archive. + std::unique_ptr<FileEnumerator> firstFile() override; + + /// Starts reading a file from the archive. The function returns a read buffer, + /// you can read that buffer to extract uncompressed data from the archive. + /// Several read buffers can be used at the same time in parallel. + std::unique_ptr<ReadBufferFromFileBase> readFile(const String & filename, bool throw_on_not_found) override; + std::unique_ptr<ReadBufferFromFileBase> readFile(NameFilter filter, bool throw_on_not_found) override; + + /// It's possible to convert a file enumerator to a read buffer and vice versa. + std::unique_ptr<ReadBufferFromFileBase> readFile(std::unique_ptr<FileEnumerator> enumerator) override; + std::unique_ptr<FileEnumerator> nextFile(std::unique_ptr<ReadBuffer> read_buffer) override; + + std::vector<std::string> getAllFiles() override; + std::vector<std::string> getAllFiles(NameFilter filter) override; + + /// Sets password used to decrypt the contents of the files in the archive. + void setPassword(const String & password_) override; + +protected: + /// Constructs an archive's reader that will read from a file in the local filesystem. + LibArchiveReader(std::string archive_name_, bool lock_on_reading_, std::string path_to_archive_); + +private: + class ReadBufferFromLibArchive; + class Handle; + class FileEnumeratorImpl; + + const std::string archive_name; + const bool lock_on_reading; + const String path_to_archive; +}; + +class TarArchiveReader : public LibArchiveReader +{ +public: + explicit TarArchiveReader(std::string path_to_archive) : LibArchiveReader("tar", /*lock_on_reading_=*/ true, std::move(path_to_archive)) { } +}; + +class SevenZipArchiveReader : public LibArchiveReader +{ +public: + explicit SevenZipArchiveReader(std::string path_to_archive) : LibArchiveReader("7z", /*lock_on_reading_=*/ false, std::move(path_to_archive)) { } +}; + +#endif + +} diff --git a/contrib/clickhouse/src/IO/Archives/ZipArchiveReader.cpp b/contrib/clickhouse/src/IO/Archives/ZipArchiveReader.cpp new file mode 100644 index 0000000000..970211f06b --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/ZipArchiveReader.cpp @@ -0,0 +1,662 @@ +#include <IO/Archives/ZipArchiveReader.h> + +#if USE_MINIZIP +#include <IO/Archives/ZipArchiveWriter.h> +#include <IO/ReadBufferFromFileBase.h> +#include <Common/quoteString.h> +#include <base/errnoToString.h> +#error #include <unzip.h> + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int CANNOT_UNPACK_ARCHIVE; + extern const int LOGICAL_ERROR; + extern const int SEEK_POSITION_OUT_OF_BOUND; +} + +using RawHandle = unzFile; + + +namespace +{ + void checkCompressionMethodIsEnabled(int compression_method_) + { + ZipArchiveWriter::checkCompressionMethodIsEnabled(compression_method_); + } + + void checkEncryptionIsEnabled() + { + ZipArchiveWriter::checkEncryptionIsEnabled(); + } +} + + +/// Holds a raw handle, calls acquireRawHandle() in the constructor and releaseRawHandle() in the destructor. 
+class ZipArchiveReader::HandleHolder +{ +public: + HandleHolder() = default; + + explicit HandleHolder(const std::shared_ptr<ZipArchiveReader> & reader_) : reader(reader_), raw_handle(reader->acquireRawHandle()) { } + + ~HandleHolder() + { + if (raw_handle) + { + try + { + closeFile(); + } + catch (...) + { + tryLogCurrentException("ZipArchiveReader"); + } + reader->releaseRawHandle(raw_handle); + } + } + + HandleHolder(HandleHolder && src) noexcept + { + *this = std::move(src); + } + + HandleHolder & operator=(HandleHolder && src) noexcept + { + reader = std::exchange(src.reader, nullptr); + raw_handle = std::exchange(src.raw_handle, nullptr); + file_name = std::exchange(src.file_name, {}); + file_info = std::exchange(src.file_info, {}); + return *this; + } + + RawHandle getRawHandle() const { return raw_handle; } + std::shared_ptr<ZipArchiveReader> getReader() const { return reader; } + + bool locateFile(const String & file_name_) + { + resetFileInfo(); + bool case_sensitive = true; + int err = unzLocateFile(raw_handle, file_name_.c_str(), reinterpret_cast<unzFileNameComparer>(static_cast<size_t>(case_sensitive))); + if (err == UNZ_END_OF_LIST_OF_FILE) + return false; + file_name = file_name_; + return true; + } + + bool locateFile(NameFilter filter) + { + int err = unzGoToFirstFile(raw_handle); + if (err == UNZ_END_OF_LIST_OF_FILE) + return false; + + do + { + checkResult(err); + resetFileInfo(); + retrieveFileInfo(); + if (filter(getFileName())) + return true; + + err = unzGoToNextFile(raw_handle); + } while (err != UNZ_END_OF_LIST_OF_FILE); + + return false; + } + + bool tryLocateFile(const String & file_name_) + { + resetFileInfo(); + bool case_sensitive = true; + int err = unzLocateFile(raw_handle, file_name_.c_str(), reinterpret_cast<unzFileNameComparer>(static_cast<size_t>(case_sensitive))); + if (err == UNZ_END_OF_LIST_OF_FILE) + return false; + checkResult(err); + file_name = file_name_; + return true; + } + + bool firstFile() + { + resetFileInfo(); + int err = unzGoToFirstFile(raw_handle); + if (err == UNZ_END_OF_LIST_OF_FILE) + return false; + checkResult(err); + return true; + } + + bool nextFile() + { + resetFileInfo(); + int err = unzGoToNextFile(raw_handle); + if (err == UNZ_END_OF_LIST_OF_FILE) + return false; + checkResult(err); + return true; + } + + const String & getFileName() const + { + if (!file_name) + retrieveFileInfo(); + return *file_name; + } + + const FileInfoImpl & getFileInfo() const + { + if (!file_info) + retrieveFileInfo(); + return *file_info; + } + + std::vector<std::string> getAllFiles(NameFilter filter) + { + std::vector<std::string> files; + resetFileInfo(); + int err = unzGoToFirstFile(raw_handle); + if (err == UNZ_END_OF_LIST_OF_FILE) + return files; + + do + { + checkResult(err); + resetFileInfo(); + retrieveFileInfo(); + if (!filter || filter(getFileName())) + files.push_back(*file_name); + err = unzGoToNextFile(raw_handle); + } while (err != UNZ_END_OF_LIST_OF_FILE); + + return files; + } + + void closeFile() + { + int err = unzCloseCurrentFile(raw_handle); + /// If err == UNZ_PARAMERROR the file is already closed. 
+ if (err != UNZ_PARAMERROR) + checkResult(err); + } + + void checkResult(int code) const { reader->checkResult(code); } + [[noreturn]] void showError(const String & message) const { reader->showError(message); } + +private: + void retrieveFileInfo() const + { + if (file_name && file_info) + return; + unz_file_info64 finfo; + int err = unzGetCurrentFileInfo64(raw_handle, &finfo, nullptr, 0, nullptr, 0, nullptr, 0); + if (err == UNZ_PARAMERROR) + showError("No current file"); + checkResult(err); + if (!file_info) + { + file_info.emplace(); + file_info->uncompressed_size = finfo.uncompressed_size; + file_info->compressed_size = finfo.compressed_size; + file_info->compression_method = finfo.compression_method; + file_info->is_encrypted = (finfo.flag & MZ_ZIP_FLAG_ENCRYPTED); + } + if (!file_name) + { + file_name.emplace(); + file_name->resize(finfo.size_filename); + checkResult(unzGetCurrentFileInfo64(raw_handle, nullptr, file_name->data(), finfo.size_filename, nullptr, 0, nullptr, 0)); + } + } + + void resetFileInfo() + { + file_info.reset(); + file_name.reset(); + } + + std::shared_ptr<ZipArchiveReader> reader; + RawHandle raw_handle = nullptr; + mutable std::optional<String> file_name; + mutable std::optional<FileInfoImpl> file_info; +}; + + +/// This class represents a ReadBuffer actually returned by readFile(). +class ZipArchiveReader::ReadBufferFromZipArchive : public ReadBufferFromFileBase +{ +public: + explicit ReadBufferFromZipArchive(HandleHolder && handle_) + : ReadBufferFromFileBase(DBMS_DEFAULT_BUFFER_SIZE, nullptr, 0) + , handle(std::move(handle_)) + { + const auto & file_info = handle.getFileInfo(); + checkCompressionMethodIsEnabled(file_info.compression_method); + + const char * password_cstr = nullptr; + if (file_info.is_encrypted) + { + const auto & password_str = handle.getReader()->password; + if (password_str.empty()) + showError("Password is required"); + password_cstr = password_str.c_str(); + checkEncryptionIsEnabled(); + } + + RawHandle raw_handle = handle.getRawHandle(); + int err = unzOpenCurrentFilePassword(raw_handle, password_cstr); + if (err == MZ_PASSWORD_ERROR) + showError("Wrong password"); + checkResult(err); + } + + off_t seek(off_t off, int whence) override + { + off_t current_pos = getPosition(); + off_t new_pos; + if (whence == SEEK_SET) + new_pos = off; + else if (whence == SEEK_CUR) + new_pos = off + current_pos; + else + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Only SEEK_SET and SEEK_CUR seek modes allowed."); + + if (new_pos == current_pos) + return current_pos; /// The position is the same. + + if (new_pos < 0) + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Seek position is out of bound"); + + off_t working_buffer_start_pos = current_pos - offset(); + off_t working_buffer_end_pos = current_pos + available(); + + if ((working_buffer_start_pos <= new_pos) && (new_pos <= working_buffer_end_pos)) + { + /// The new position is still inside the buffer. + position() += new_pos - current_pos; + return new_pos; + } + + RawHandle raw_handle = handle.getRawHandle(); + + /// Check that the new position is now beyond the end of the file. + const auto & file_info = handle.getFileInfo(); + if (new_pos > static_cast<off_t>(file_info.uncompressed_size)) + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Seek position is out of bound"); + + if (file_info.compression_method == MZ_COMPRESS_METHOD_STORE) + { + /// unzSeek64() works only for non-compressed files. 
+ checkResult(unzSeek64(raw_handle, off, whence)); + return unzTell64(raw_handle); + } + + /// As a last try we go slow way, we're going to simply ignore all data before the new position. + if (new_pos < current_pos) + { + checkResult(unzCloseCurrentFile(raw_handle)); + checkResult(unzOpenCurrentFile(raw_handle)); + current_pos = 0; + } + + ignore(new_pos - current_pos); + return new_pos; + } + + off_t getPosition() override + { + RawHandle raw_handle = handle.getRawHandle(); + return unzTell64(raw_handle) - available(); + } + + String getFileName() const override { return handle.getFileName(); } + + size_t getFileSize() override { return handle.getFileInfo().uncompressed_size; } + + /// Releases owned handle to pass it to an enumerator. + HandleHolder releaseHandle() && + { + handle.closeFile(); + return std::move(handle); + } + +private: + bool nextImpl() override + { + RawHandle raw_handle = handle.getRawHandle(); + auto bytes_read = unzReadCurrentFile(raw_handle, internal_buffer.begin(), static_cast<int>(internal_buffer.size())); + + if (bytes_read < 0) + checkResult(bytes_read); + + if (!bytes_read) + return false; + + working_buffer = internal_buffer; + working_buffer.resize(bytes_read); + return true; + } + + void checkResult(int code) const { handle.checkResult(code); } + [[noreturn]] void showError(const String & message) const { handle.showError(message); } + + HandleHolder handle; +}; + + +class ZipArchiveReader::FileEnumeratorImpl : public FileEnumerator +{ +public: + explicit FileEnumeratorImpl(HandleHolder && handle_) : handle(std::move(handle_)) {} + + const String & getFileName() const override { return handle.getFileName(); } + const FileInfo & getFileInfo() const override { return handle.getFileInfo(); } + bool nextFile() override { return handle.nextFile(); } + + /// Releases owned handle to pass it to a read buffer. + HandleHolder releaseHandle() && { return std::move(handle); } + +private: + HandleHolder handle; +}; + + +namespace +{ + /// Provides a set of functions allowing the minizip library to read its input + /// from a SeekableReadBuffer instead of an ordinary file in the local filesystem. 
+ class StreamFromReadBuffer + { + public: + static RawHandle open(std::unique_ptr<SeekableReadBuffer> archive_read_buffer, UInt64 archive_size) + { + StreamFromReadBuffer::Opaque opaque{std::move(archive_read_buffer), archive_size}; + + zlib_filefunc64_def func_def; + func_def.zopen64_file = &StreamFromReadBuffer::openFileFunc; + func_def.zclose_file = &StreamFromReadBuffer::closeFileFunc; + func_def.zread_file = &StreamFromReadBuffer::readFileFunc; + func_def.zwrite_file = &StreamFromReadBuffer::writeFileFunc; + func_def.zseek64_file = &StreamFromReadBuffer::seekFunc; + func_def.ztell64_file = &StreamFromReadBuffer::tellFunc; + func_def.zerror_file = &StreamFromReadBuffer::testErrorFunc; + func_def.opaque = &opaque; + + return unzOpen2_64(/* path= */ nullptr, + &func_def); + } + + private: + std::unique_ptr<SeekableReadBuffer> read_buffer; + UInt64 start_offset = 0; + UInt64 total_size = 0; + bool at_end = false; + + struct Opaque + { + std::unique_ptr<SeekableReadBuffer> read_buffer; + UInt64 total_size = 0; + }; + + static void * openFileFunc(void * opaque, const void *, int) + { + auto & opq = *reinterpret_cast<Opaque *>(opaque); + return new StreamFromReadBuffer(std::move(opq.read_buffer), opq.total_size); + } + + StreamFromReadBuffer(std::unique_ptr<SeekableReadBuffer> read_buffer_, UInt64 total_size_) + : read_buffer(std::move(read_buffer_)), start_offset(read_buffer->getPosition()), total_size(total_size_) {} + + static int closeFileFunc(void *, void * stream) + { + delete reinterpret_cast<StreamFromReadBuffer *>(stream); + return ZIP_OK; + } + + static StreamFromReadBuffer & get(void * ptr) + { + return *reinterpret_cast<StreamFromReadBuffer *>(ptr); + } + + static int testErrorFunc(void *, void *) + { + return ZIP_OK; + } + + static unsigned long readFileFunc(void *, void * stream, void * buf, unsigned long size) // NOLINT(google-runtime-int) + { + auto & strm = get(stream); + if (strm.at_end) + return 0; + auto read_bytes = strm.read_buffer->read(reinterpret_cast<char *>(buf), size); + return read_bytes; + } + + static ZPOS64_T tellFunc(void *, void * stream) + { + auto & strm = get(stream); + if (strm.at_end) + return strm.total_size; + auto pos = strm.read_buffer->getPosition() - strm.start_offset; + return pos; + } + + static long seekFunc(void *, void * stream, ZPOS64_T offset, int origin) // NOLINT(google-runtime-int) + { + auto & strm = get(stream); + if (origin == SEEK_END) + { + /// Our implementations of SeekableReadBuffer don't support SEEK_END, + /// but the minizip library needs it, so we have to simulate it here. + strm.at_end = true; + return ZIP_OK; + } + strm.at_end = false; + if (origin == SEEK_SET) + offset += strm.start_offset; + strm.read_buffer->seek(offset, origin); + return ZIP_OK; + } + + static unsigned long writeFileFunc(void *, void *, const void *, unsigned long) // NOLINT(google-runtime-int) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "StreamFromReadBuffer::writeFile must not be called"); + } + }; +} + + +ZipArchiveReader::ZipArchiveReader(const String & path_to_archive_) + : path_to_archive(path_to_archive_) +{ + init(); + +} + +ZipArchiveReader::ZipArchiveReader( + const String & path_to_archive_, const ReadArchiveFunction & archive_read_function_, UInt64 archive_size_) + : path_to_archive(path_to_archive_), archive_read_function(archive_read_function_), archive_size(archive_size_) +{ + init(); +} + +void ZipArchiveReader::init() +{ + /// Prepare the first handle in `free_handles` and check that the archive can be read. 
+ releaseRawHandle(acquireRawHandle()); +} + +ZipArchiveReader::~ZipArchiveReader() +{ + /// Close all `free_handles`. + for (RawHandle free_handle : free_handles) + { + try + { + checkResult(unzClose(free_handle)); + } + catch (...) + { + tryLogCurrentException("ZipArchiveReader"); + } + } +} + +const std::string & ZipArchiveReader::getPath() const +{ + return path_to_archive; +} + +bool ZipArchiveReader::fileExists(const String & filename) +{ + return acquireHandle().tryLocateFile(filename); +} + +ZipArchiveReader::FileInfo ZipArchiveReader::getFileInfo(const String & filename) +{ + auto handle = acquireHandle(); + if (!handle.locateFile(filename)) + showError(fmt::format("File {} was not found in archive", quoteString(filename))); + + return handle.getFileInfo(); +} + +std::unique_ptr<ZipArchiveReader::FileEnumerator> ZipArchiveReader::firstFile() +{ + auto handle = acquireHandle(); + if (!handle.firstFile()) + return nullptr; + return std::make_unique<FileEnumeratorImpl>(std::move(handle)); +} + +std::unique_ptr<ReadBufferFromFileBase> ZipArchiveReader::readFile(const String & filename, bool throw_on_not_found) +{ + auto handle = acquireHandle(); + if (!handle.locateFile(filename)) + { + if (throw_on_not_found) + showError(fmt::format("File {} was not found in archive", quoteString(filename))); + + return nullptr; + } + + return std::make_unique<ReadBufferFromZipArchive>(std::move(handle)); +} + +std::unique_ptr<ReadBufferFromFileBase> ZipArchiveReader::readFile(NameFilter filter, bool throw_on_not_found) +{ + auto handle = acquireHandle(); + if (!handle.locateFile(filter)) + { + if (throw_on_not_found) + showError(fmt::format("No file satisfying filter in archive")); + + return nullptr; + } + + return std::make_unique<ReadBufferFromZipArchive>(std::move(handle)); +} + +std::unique_ptr<ReadBufferFromFileBase> ZipArchiveReader::readFile(std::unique_ptr<FileEnumerator> enumerator) +{ + if (!dynamic_cast<FileEnumeratorImpl *>(enumerator.get())) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong enumerator passed to readFile()"); + auto enumerator_impl = std::unique_ptr<FileEnumeratorImpl>(static_cast<FileEnumeratorImpl *>(enumerator.release())); + auto handle = std::move(*enumerator_impl).releaseHandle(); + return std::make_unique<ReadBufferFromZipArchive>(std::move(handle)); +} + +std::unique_ptr<ZipArchiveReader::FileEnumerator> ZipArchiveReader::nextFile(std::unique_ptr<ReadBuffer> read_buffer) +{ + if (!dynamic_cast<ReadBufferFromZipArchive *>(read_buffer.get())) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong ReadBuffer passed to nextFile()"); + auto read_buffer_from_zip = std::unique_ptr<ReadBufferFromZipArchive>(static_cast<ReadBufferFromZipArchive *>(read_buffer.release())); + auto handle = std::move(*read_buffer_from_zip).releaseHandle(); + if (!handle.nextFile()) + return nullptr; + return std::make_unique<FileEnumeratorImpl>(std::move(handle)); +} + +std::vector<std::string> ZipArchiveReader::getAllFiles() +{ + return getAllFiles({}); +} + +std::vector<std::string> ZipArchiveReader::getAllFiles(NameFilter filter) +{ + auto handle = acquireHandle(); + return handle.getAllFiles(filter); +} + +void ZipArchiveReader::setPassword(const String & password_) +{ + std::lock_guard lock{mutex}; + password = password_; +} + +ZipArchiveReader::HandleHolder ZipArchiveReader::acquireHandle() +{ + return HandleHolder{std::static_pointer_cast<ZipArchiveReader>(shared_from_this())}; +} + +ZipArchiveReader::RawHandle ZipArchiveReader::acquireRawHandle() +{ + std::lock_guard lock{mutex}; + + 
if (!free_handles.empty()) + { + RawHandle free_handle = free_handles.back(); + free_handles.pop_back(); + return free_handle; + } + + RawHandle new_handle = nullptr; + if (archive_read_function) + new_handle = StreamFromReadBuffer::open(archive_read_function(), archive_size); + else + new_handle = unzOpen64(path_to_archive.c_str()); + + if (!new_handle) + throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Couldn't open zip archive {}", quoteString(path_to_archive)); + + return new_handle; +} + +void ZipArchiveReader::releaseRawHandle(RawHandle handle_) +{ + if (!handle_) + return; + + std::lock_guard lock{mutex}; + free_handles.push_back(handle_); +} + +void ZipArchiveReader::checkResult(int code) const +{ + if (code >= UNZ_OK) + return; + + String message = "Code = "; + switch (code) + { + case UNZ_OK: return; + case UNZ_ERRNO: message += "ERRNO, errno = " + errnoToString(); break; + case UNZ_PARAMERROR: message += "PARAMERROR"; break; + case UNZ_BADZIPFILE: message += "BADZIPFILE"; break; + case UNZ_INTERNALERROR: message += "INTERNALERROR"; break; + case UNZ_CRCERROR: message += "CRCERROR"; break; + case UNZ_BADPASSWORD: message += "BADPASSWORD"; break; + default: message += std::to_string(code); break; + } + showError(message); +} + +void ZipArchiveReader::showError(const String & message) const +{ + throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Couldn't unpack zip archive {}: {}", quoteString(path_to_archive), message); +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/Archives/ZipArchiveReader.h b/contrib/clickhouse/src/IO/Archives/ZipArchiveReader.h new file mode 100644 index 0000000000..74fa26b6fe --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/ZipArchiveReader.h @@ -0,0 +1,87 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_MINIZIP +#include <IO/Archives/IArchiveReader.h> +#include <mutex> +#include <vector> + + +namespace DB +{ +class ReadBuffer; +class ReadBufferFromFileBase; +class SeekableReadBuffer; + +/// Implementation of IArchiveReader for reading zip archives. +class ZipArchiveReader : public IArchiveReader +{ +public: + /// Constructs an archive's reader that will read from a file in the local filesystem. + explicit ZipArchiveReader(const String & path_to_archive_); + + /// Constructs an archive's reader that will read by making a read buffer by using + /// a specified function. + ZipArchiveReader(const String & path_to_archive_, const ReadArchiveFunction & archive_read_function_, UInt64 archive_size_); + + ~ZipArchiveReader() override; + + const std::string & getPath() const override; + + /// Returns true if there is a specified file in the archive. + bool fileExists(const String & filename) override; + + /// Returns the information about a file stored in the archive. + FileInfo getFileInfo(const String & filename) override; + + /// Starts enumerating files in the archive. + std::unique_ptr<FileEnumerator> firstFile() override; + + /// Starts reading a file from the archive. The function returns a read buffer, + /// you can read that buffer to extract uncompressed data from the archive. + /// Several read buffers can be used at the same time in parallel. + std::unique_ptr<ReadBufferFromFileBase> readFile(const String & filename, bool throw_on_not_found) override; + std::unique_ptr<ReadBufferFromFileBase> readFile(NameFilter filter, bool throw_on_not_found) override; + + /// It's possible to convert a file enumerator to a read buffer and vice versa. 
+ std::unique_ptr<ReadBufferFromFileBase> readFile(std::unique_ptr<FileEnumerator> enumerator) override; + std::unique_ptr<FileEnumerator> nextFile(std::unique_ptr<ReadBuffer> read_buffer) override; + + std::vector<std::string> getAllFiles() override; + std::vector<std::string> getAllFiles(NameFilter filter) override; + + /// Sets password used to decrypt the contents of the files in the archive. + void setPassword(const String & password_) override; + +private: + class ReadBufferFromZipArchive; + class FileEnumeratorImpl; + class HandleHolder; + using RawHandle = void *; + + void init(); + + struct FileInfoImpl : public FileInfo + { + int compression_method; + }; + + HandleHolder acquireHandle(); + RawHandle acquireRawHandle(); + void releaseRawHandle(RawHandle handle_); + + void checkResult(int code) const; + [[noreturn]] void showError(const String & message) const; + + const String path_to_archive; + const ReadArchiveFunction archive_read_function; + const UInt64 archive_size = 0; + String password; + std::vector<RawHandle> free_handles; + mutable std::mutex mutex; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/Archives/ZipArchiveWriter.cpp b/contrib/clickhouse/src/IO/Archives/ZipArchiveWriter.cpp new file mode 100644 index 0000000000..4f8aa27df6 --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/ZipArchiveWriter.cpp @@ -0,0 +1,407 @@ +#include <IO/Archives/ZipArchiveWriter.h> + +#if USE_MINIZIP +#include <IO/WriteBufferFromFileBase.h> +#include <Common/quoteString.h> +#include <base/errnoToString.h> +#error #include <zip.h> +#include <boost/algorithm/string/predicate.hpp> + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int CANNOT_PACK_ARCHIVE; + extern const int SUPPORT_IS_DISABLED; + extern const int LOGICAL_ERROR; +} + +using RawHandle = zipFile; + + +/// Holds a raw handle, calls acquireRawHandle() in the constructor and releaseRawHandle() in the destructor. +class ZipArchiveWriter::HandleHolder +{ +public: + HandleHolder() = default; + + explicit HandleHolder(const std::shared_ptr<ZipArchiveWriter> & writer_) : writer(writer_), raw_handle(writer->acquireRawHandle()) { } + + ~HandleHolder() + { + if (raw_handle) + { + try + { + int err = zipCloseFileInZip(raw_handle); + /// If err == ZIP_PARAMERROR the file is already closed. + if (err != ZIP_PARAMERROR) + checkResult(err); + } + catch (...) + { + tryLogCurrentException("ZipArchiveWriter"); + } + writer->releaseRawHandle(raw_handle); + } + } + + HandleHolder(HandleHolder && src) noexcept + { + *this = std::move(src); + } + + HandleHolder & operator=(HandleHolder && src) noexcept + { + writer = std::exchange(src.writer, nullptr); + raw_handle = std::exchange(src.raw_handle, nullptr); + return *this; + } + + RawHandle getRawHandle() const { return raw_handle; } + std::shared_ptr<ZipArchiveWriter> getWriter() const { return writer; } + + void checkResult(int code) const { writer->checkResult(code); } + +private: + std::shared_ptr<ZipArchiveWriter> writer; + RawHandle raw_handle = nullptr; +}; + + +/// This class represents a WriteBuffer actually returned by writeFile(). 
+class ZipArchiveWriter::WriteBufferFromZipArchive : public WriteBufferFromFileBase +{ +public: + WriteBufferFromZipArchive(HandleHolder && handle_, const String & filename_) + : WriteBufferFromFileBase(DBMS_DEFAULT_BUFFER_SIZE, nullptr, 0) + , handle(std::move(handle_)) + , filename(filename_) + { + auto compress_method = handle.getWriter()->compression_method; + auto compress_level = handle.getWriter()->compression_level; + checkCompressionMethodIsEnabled(compress_method); + + const char * password_cstr = nullptr; + const String & password_str = handle.getWriter()->password; + if (!password_str.empty()) + { + checkEncryptionIsEnabled(); + password_cstr = password_str.c_str(); + } + + RawHandle raw_handle = handle.getRawHandle(); + + checkResult(zipOpenNewFileInZip3_64( + raw_handle, + filename_.c_str(), + /* zipfi= */ nullptr, + /* extrafield_local= */ nullptr, + /* size_extrafield_local= */ 0, + /* extrafield_global= */ nullptr, + /* size_extrafield_global= */ 0, + /* comment= */ nullptr, + compress_method, + compress_level, + /* raw= */ false, + /* windowBits= */ 0, + /* memLevel= */ 0, + /* strategy= */ 0, + password_cstr, + /* crc_for_crypting= */ 0, + /* zip64= */ true)); + } + + ~WriteBufferFromZipArchive() override + { + try + { + finalize(); + } + catch (...) + { + tryLogCurrentException("ZipArchiveWriter"); + } + } + + void sync() override { next(); } + std::string getFileName() const override { return filename; } + +private: + void nextImpl() override + { + if (!offset()) + return; + RawHandle raw_handle = handle.getRawHandle(); + int code = zipWriteInFileInZip(raw_handle, working_buffer.begin(), static_cast<uint32_t>(offset())); + checkResult(code); + } + + void checkResult(int code) const { handle.checkResult(code); } + + HandleHolder handle; + String filename; +}; + + +namespace +{ + /// Provides a set of functions allowing the minizip library to write its output + /// to a WriteBuffer instead of an ordinary file in the local filesystem. 
+ class StreamFromWriteBuffer + { + public: + static RawHandle open(std::unique_ptr<WriteBuffer> archive_write_buffer) + { + Opaque opaque{std::move(archive_write_buffer)}; + + zlib_filefunc64_def func_def; + func_def.zopen64_file = &StreamFromWriteBuffer::openFileFunc; + func_def.zclose_file = &StreamFromWriteBuffer::closeFileFunc; + func_def.zread_file = &StreamFromWriteBuffer::readFileFunc; + func_def.zwrite_file = &StreamFromWriteBuffer::writeFileFunc; + func_def.zseek64_file = &StreamFromWriteBuffer::seekFunc; + func_def.ztell64_file = &StreamFromWriteBuffer::tellFunc; + func_def.zerror_file = &StreamFromWriteBuffer::testErrorFunc; + func_def.opaque = &opaque; + + return zipOpen2_64( + /* path= */ nullptr, + /* append= */ false, + /* globalcomment= */ nullptr, + &func_def); + } + + private: + std::unique_ptr<WriteBuffer> write_buffer; + UInt64 start_offset = 0; + + struct Opaque + { + std::unique_ptr<WriteBuffer> write_buffer; + }; + + static void * openFileFunc(void * opaque, const void *, int) + { + Opaque & opq = *reinterpret_cast<Opaque *>(opaque); + return new StreamFromWriteBuffer(std::move(opq.write_buffer)); + } + + explicit StreamFromWriteBuffer(std::unique_ptr<WriteBuffer> write_buffer_) + : write_buffer(std::move(write_buffer_)), start_offset(write_buffer->count()) {} + + ~StreamFromWriteBuffer() + { + write_buffer->finalize(); + } + + static int closeFileFunc(void *, void * stream) + { + delete reinterpret_cast<StreamFromWriteBuffer *>(stream); + return ZIP_OK; + } + + static StreamFromWriteBuffer & get(void * ptr) + { + return *reinterpret_cast<StreamFromWriteBuffer *>(ptr); + } + + static unsigned long writeFileFunc(void *, void * stream, const void * buf, unsigned long size) // NOLINT(google-runtime-int) + { + auto & strm = get(stream); + strm.write_buffer->write(reinterpret_cast<const char *>(buf), size); + return size; + } + + static int testErrorFunc(void *, void *) + { + return ZIP_OK; + } + + static ZPOS64_T tellFunc(void *, void * stream) + { + auto & strm = get(stream); + auto pos = strm.write_buffer->count() - strm.start_offset; + return pos; + } + + static long seekFunc(void *, void *, ZPOS64_T, int) // NOLINT(google-runtime-int) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "StreamFromWriteBuffer::seek must not be called"); + } + + static unsigned long readFileFunc(void *, void *, void *, unsigned long) // NOLINT(google-runtime-int) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "StreamFromWriteBuffer::readFile must not be called"); + } + }; +} + + +ZipArchiveWriter::ZipArchiveWriter(const String & path_to_archive_) + : ZipArchiveWriter(path_to_archive_, nullptr) +{ +} + +ZipArchiveWriter::ZipArchiveWriter(const String & path_to_archive_, std::unique_ptr<WriteBuffer> archive_write_buffer_) + : path_to_archive(path_to_archive_), compression_method(MZ_COMPRESS_METHOD_DEFLATE) +{ + if (archive_write_buffer_) + handle = StreamFromWriteBuffer::open(std::move(archive_write_buffer_)); + else + handle = zipOpen64(path_to_archive.c_str(), /* append= */ false); + if (!handle) + throw Exception(ErrorCodes::CANNOT_PACK_ARCHIVE, "Couldn't create zip archive {}", quoteString(path_to_archive)); + +} + +ZipArchiveWriter::~ZipArchiveWriter() +{ + if (handle) + { + try + { + checkResult(zipClose(handle, /* global_comment= */ nullptr)); + } + catch (...) 
+ { + tryLogCurrentException("ZipArchiveWriter"); + } + } +} + +std::unique_ptr<WriteBufferFromFileBase> ZipArchiveWriter::writeFile(const String & filename) +{ + return std::make_unique<WriteBufferFromZipArchive>(acquireHandle(), filename); +} + +bool ZipArchiveWriter::isWritingFile() const +{ + std::lock_guard lock{mutex}; + return !handle; +} + +void ZipArchiveWriter::setCompression(const String & compression_method_, int compression_level_) +{ + std::lock_guard lock{mutex}; + compression_method = compressionMethodToInt(compression_method_); + compression_level = compression_level_; +} + +void ZipArchiveWriter::setPassword(const String & password_) +{ + std::lock_guard lock{mutex}; + password = password_; +} + +int ZipArchiveWriter::compressionMethodToInt(const String & compression_method_) +{ + if (compression_method_.empty()) + return MZ_COMPRESS_METHOD_DEFLATE; /// By default the compression method is "deflate". + else if (compression_method_ == kStore) + return MZ_COMPRESS_METHOD_STORE; + else if (compression_method_ == kDeflate) + return MZ_COMPRESS_METHOD_DEFLATE; + else if (compression_method_ == kBzip2) + return MZ_COMPRESS_METHOD_BZIP2; + else if (compression_method_ == kLzma) + return MZ_COMPRESS_METHOD_LZMA; + else if (compression_method_ == kZstd) + return MZ_COMPRESS_METHOD_ZSTD; + else if (compression_method_ == kXz) + return MZ_COMPRESS_METHOD_XZ; + else + throw Exception(ErrorCodes::CANNOT_PACK_ARCHIVE, "Unknown compression method specified for a zip archive: {}", compression_method_); +} + +String ZipArchiveWriter::intToCompressionMethod(int compression_method_) +{ + switch (compression_method_) + { + case MZ_COMPRESS_METHOD_STORE: return kStore; + case MZ_COMPRESS_METHOD_DEFLATE: return kDeflate; + case MZ_COMPRESS_METHOD_BZIP2: return kBzip2; + case MZ_COMPRESS_METHOD_LZMA: return kLzma; + case MZ_COMPRESS_METHOD_ZSTD: return kZstd; + case MZ_COMPRESS_METHOD_XZ: return kXz; + } + throw Exception(ErrorCodes::CANNOT_PACK_ARCHIVE, "Unknown compression method specified for a zip archive: {}", compression_method_); +} + +/// Checks that a passed compression method can be used. +void ZipArchiveWriter::checkCompressionMethodIsEnabled(int compression_method_) +{ + switch (compression_method_) + { + case MZ_COMPRESS_METHOD_STORE: [[fallthrough]]; + case MZ_COMPRESS_METHOD_DEFLATE: + case MZ_COMPRESS_METHOD_LZMA: + case MZ_COMPRESS_METHOD_ZSTD: + case MZ_COMPRESS_METHOD_XZ: + return; + + case MZ_COMPRESS_METHOD_BZIP2: + { +#if USE_BZIP2 + return; +#else + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "bzip2 compression method is disabled"); +#endif + } + } + throw Exception(ErrorCodes::CANNOT_PACK_ARCHIVE, "Unknown compression method specified for a zip archive: {}", compression_method_); +} + +/// Checks that encryption is enabled. 
+void ZipArchiveWriter::checkEncryptionIsEnabled() +{ +#if !USE_SSL + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "Encryption in zip archive is disabled"); +#endif +} + +ZipArchiveWriter::HandleHolder ZipArchiveWriter::acquireHandle() +{ + return HandleHolder{std::static_pointer_cast<ZipArchiveWriter>(shared_from_this())}; +} + +RawHandle ZipArchiveWriter::acquireRawHandle() +{ + std::lock_guard lock{mutex}; + if (!handle) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot have more than one write buffer while writing a zip archive"); + return std::exchange(handle, nullptr); +} + +void ZipArchiveWriter::releaseRawHandle(RawHandle raw_handle_) +{ + std::lock_guard lock{mutex}; + handle = raw_handle_; +} + +void ZipArchiveWriter::checkResult(int code) const +{ + if (code >= ZIP_OK) + return; + + String message = "Code = "; + switch (code) + { + case ZIP_ERRNO: message += "ERRNO, errno = " + errnoToString(); break; + case ZIP_PARAMERROR: message += "PARAMERROR"; break; + case ZIP_BADZIPFILE: message += "BADZIPFILE"; break; + case ZIP_INTERNALERROR: message += "INTERNALERROR"; break; + default: message += std::to_string(code); break; + } + showError(message); +} + +void ZipArchiveWriter::showError(const String & message) const +{ + throw Exception(ErrorCodes::CANNOT_PACK_ARCHIVE, "Couldn't pack zip archive {}: {}", quoteString(path_to_archive), message); +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/Archives/ZipArchiveWriter.h b/contrib/clickhouse/src/IO/Archives/ZipArchiveWriter.h new file mode 100644 index 0000000000..6650705fca --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/ZipArchiveWriter.h @@ -0,0 +1,92 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_MINIZIP +#include <IO/Archives/IArchiveWriter.h> +#include <mutex> + + +namespace DB +{ +class WriteBuffer; +class WriteBufferFromFileBase; + +/// Implementation of IArchiveWriter for writing zip archives. +class ZipArchiveWriter : public IArchiveWriter +{ +public: + /// Constructs an archive that will be written as a file in the local filesystem. + explicit ZipArchiveWriter(const String & path_to_archive_); + + /// Constructs an archive that will be written by using a specified `archive_write_buffer_`. + ZipArchiveWriter(const String & path_to_archive_, std::unique_ptr<WriteBuffer> archive_write_buffer_); + + /// Destructors finalizes writing the archive. + ~ZipArchiveWriter() override; + + /// Starts writing a file to the archive. The function returns a write buffer, + /// any data written to that buffer will be compressed and then put to the archive. + /// You can keep only one such buffer at a time, a buffer returned by previous call + /// of the function `writeFile()` should be destroyed before next call of `writeFile()`. + std::unique_ptr<WriteBufferFromFileBase> writeFile(const String & filename) override; + + /// Returns true if there is an active instance of WriteBuffer returned by writeFile(). + /// This function should be used mostly for debugging purposes. + bool isWritingFile() const override; + + /// Supported compression methods. + static constexpr const char kStore[] = "store"; + static constexpr const char kDeflate[] = "deflate"; + static constexpr const char kBzip2[] = "bzip2"; + static constexpr const char kLzma[] = "lzma"; + static constexpr const char kZstd[] = "zstd"; + static constexpr const char kXz[] = "xz"; + + /// Some compression levels. 
+ enum class CompressionLevels + { + kDefault = kDefaultCompressionLevel, + kFast = 2, + kNormal = 6, + kBest = 9, + }; + + /// Sets compression method and level. + /// Changing them will affect next file in the archive. + void setCompression(const String & compression_method_, int compression_level_) override; + + /// Sets password. Only contents of the files are encrypted, + /// names of files are not encrypted. + /// Changing the password will affect next file in the archive. + void setPassword(const String & password_) override; + + /// Utility functions. + static int compressionMethodToInt(const String & compression_method_); + static String intToCompressionMethod(int compression_method_); + static void checkCompressionMethodIsEnabled(int compression_method_); + static void checkEncryptionIsEnabled(); + +private: + class WriteBufferFromZipArchive; + class HandleHolder; + using RawHandle = void *; + + HandleHolder acquireHandle(); + RawHandle acquireRawHandle(); + void releaseRawHandle(RawHandle raw_handle_); + + void checkResult(int code) const; + [[noreturn]] void showError(const String & message) const; + + const String path_to_archive; + int compression_method; /// By default the compression method is "deflate". + int compression_level = kDefaultCompressionLevel; + String password; + RawHandle handle = nullptr; + mutable std::mutex mutex; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/Archives/createArchiveReader.cpp b/contrib/clickhouse/src/IO/Archives/createArchiveReader.cpp new file mode 100644 index 0000000000..0c998971de --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/createArchiveReader.cpp @@ -0,0 +1,70 @@ +#include <IO/Archives/createArchiveReader.h> +#include <IO/Archives/ZipArchiveReader.h> +#include <IO/Archives/LibArchiveReader.h> +#include <Common/Exception.h> + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int CANNOT_UNPACK_ARCHIVE; + extern const int SUPPORT_IS_DISABLED; +} + + +std::shared_ptr<IArchiveReader> createArchiveReader(const String & path_to_archive) +{ + return createArchiveReader(path_to_archive, {}, 0); +} + + +std::shared_ptr<IArchiveReader> createArchiveReader( + const String & path_to_archive, + [[maybe_unused]] const std::function<std::unique_ptr<SeekableReadBuffer>()> & archive_read_function, + [[maybe_unused]] size_t archive_size) +{ + using namespace std::literals; + static constexpr std::array tar_extensions + { + ".tar"sv, + ".tar.gz"sv, + ".tgz"sv, + ".tar.zst"sv, + ".tzst"sv, + ".tar.xz"sv, + ".tar.bz2"sv + }; + + if (path_to_archive.ends_with(".zip") || path_to_archive.ends_with(".zipx")) + { +#if USE_MINIZIP + return std::make_shared<ZipArchiveReader>(path_to_archive, archive_read_function, archive_size); +#else + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "minizip library is disabled"); +#endif + } + else if (std::any_of( + tar_extensions.begin(), tar_extensions.end(), [&](const auto extension) { return path_to_archive.ends_with(extension); })) + { +#if USE_LIBARCHIVE + return std::make_shared<TarArchiveReader>(path_to_archive); +#else + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "libarchive library is disabled"); +#endif + } + else if (path_to_archive.ends_with(".7z")) + { +#if USE_LIBARCHIVE + return std::make_shared<SevenZipArchiveReader>(path_to_archive); +#else + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "libarchive library is disabled"); +#endif + } + else + { + throw Exception(ErrorCodes::CANNOT_UNPACK_ARCHIVE, "Cannot determine the type of archive {}", path_to_archive); + } +} + +} 
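Editorial usage sketch for the factory above, not part of the original diff. The read-buffer overload lets the archive bytes come from any seekable source, while the path is only used for extension dispatch; the backup path is an assumption.

#include <IO/Archives/createArchiveReader.h>
#include <IO/ReadBufferFromFile.h>
#include <filesystem>
#include <memory>

std::shared_ptr<DB::IArchiveReader> openArchiveForReading(const std::string & path) /// e.g. "/backups/data.zip"
{
    auto buffer_factory = [path] { return std::make_unique<DB::ReadBufferFromFile>(path); };
    return DB::createArchiveReader(path, buffer_factory, std::filesystem::file_size(path));
}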
diff --git a/contrib/clickhouse/src/IO/Archives/createArchiveReader.h b/contrib/clickhouse/src/IO/Archives/createArchiveReader.h new file mode 100644 index 0000000000..64eb4c8eab --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/createArchiveReader.h @@ -0,0 +1,23 @@ +#pragma once + +#include <base/types.h> +#include <functional> +#include <memory> + + +namespace DB +{ +class IArchiveReader; +class SeekableReadBuffer; + +/// Starts reading a specified archive in the local filesystem. +std::shared_ptr<IArchiveReader> createArchiveReader(const String & path_to_archive); + +/// Starts reading a specified archive, the archive is read by using a specified read buffer, +/// `path_to_archive` is used only to determine the archive's type. +std::shared_ptr<IArchiveReader> createArchiveReader( + const String & path_to_archive, + const std::function<std::unique_ptr<SeekableReadBuffer>()> & archive_read_function, + size_t archive_size); + +} diff --git a/contrib/clickhouse/src/IO/Archives/createArchiveWriter.cpp b/contrib/clickhouse/src/IO/Archives/createArchiveWriter.cpp new file mode 100644 index 0000000000..807fe66e6a --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/createArchiveWriter.cpp @@ -0,0 +1,38 @@ +#include <IO/Archives/createArchiveWriter.h> +#include <IO/Archives/ZipArchiveWriter.h> +#include <IO/WriteBuffer.h> +#include <Common/Exception.h> + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int CANNOT_PACK_ARCHIVE; + extern const int SUPPORT_IS_DISABLED; +} + + +std::shared_ptr<IArchiveWriter> createArchiveWriter(const String & path_to_archive) +{ + return createArchiveWriter(path_to_archive, nullptr); +} + + +std::shared_ptr<IArchiveWriter> createArchiveWriter( + const String & path_to_archive, + [[maybe_unused]] std::unique_ptr<WriteBuffer> archive_write_buffer) +{ + if (path_to_archive.ends_with(".zip") || path_to_archive.ends_with(".zipx")) + { +#if USE_MINIZIP + return std::make_shared<ZipArchiveWriter>(path_to_archive, std::move(archive_write_buffer)); +#else + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "minizip library is disabled"); +#endif + } + else + throw Exception(ErrorCodes::CANNOT_PACK_ARCHIVE, "Cannot determine the type of archive {}", path_to_archive); +} + +} diff --git a/contrib/clickhouse/src/IO/Archives/createArchiveWriter.h b/contrib/clickhouse/src/IO/Archives/createArchiveWriter.h new file mode 100644 index 0000000000..51ffd4d114 --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/createArchiveWriter.h @@ -0,0 +1,19 @@ +#pragma once + +#include <base/types.h> +#include <memory> + + +namespace DB +{ +class IArchiveWriter; +class WriteBuffer; + +/// Starts writing a specified archive in the local filesystem. +std::shared_ptr<IArchiveWriter> createArchiveWriter(const String & path_to_archive); + +/// Starts writing a specified archive, the archive is written by using a specified write buffer, +/// `path_to_archive` is used only to determine the archive's type. 
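Editorial sketch of driving the writer factory declared in this header, not part of the original sources; the archive path, entry name and payload are illustrative. As the comments above note, only one writeFile() buffer may exist at a time, and the archive itself is finalized when the writer is destroyed.

#include <IO/Archives/createArchiveWriter.h>
#include <IO/Archives/IArchiveWriter.h>
#include <IO/WriteBufferFromFileBase.h>

void exampleWriteZip()
{
    auto writer = DB::createArchiveWriter("/tmp/example.zip"); /// the ".zip" suffix selects ZipArchiveWriter
    auto out = writer->writeFile("hello.txt");
    out->write("hello world", 11);
    out->finalize();
} /// `writer` is destroyed here, which finalizes the archive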
+std::shared_ptr<IArchiveWriter> createArchiveWriter(const String & path_to_archive, std::unique_ptr<WriteBuffer> archive_write_buffer); + +} diff --git a/contrib/clickhouse/src/IO/Archives/hasRegisteredArchiveFileExtension.cpp b/contrib/clickhouse/src/IO/Archives/hasRegisteredArchiveFileExtension.cpp new file mode 100644 index 0000000000..6b2ef29d05 --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/hasRegisteredArchiveFileExtension.cpp @@ -0,0 +1,12 @@ +#include <IO/Archives/hasRegisteredArchiveFileExtension.h> + + +namespace DB +{ + +bool hasRegisteredArchiveFileExtension(const String & path) +{ + return path.ends_with(".zip") || path.ends_with(".zipx"); +} + +} diff --git a/contrib/clickhouse/src/IO/Archives/hasRegisteredArchiveFileExtension.h b/contrib/clickhouse/src/IO/Archives/hasRegisteredArchiveFileExtension.h new file mode 100644 index 0000000000..cab938aa0b --- /dev/null +++ b/contrib/clickhouse/src/IO/Archives/hasRegisteredArchiveFileExtension.h @@ -0,0 +1,12 @@ +#pragma once + +#include <base/types.h> + + +namespace DB +{ + +/// Returns true if a specified path has one of the registered file extensions for an archive. +bool hasRegisteredArchiveFileExtension(const String & path); + +} diff --git a/contrib/clickhouse/src/IO/AsyncReadCounters.cpp b/contrib/clickhouse/src/IO/AsyncReadCounters.cpp new file mode 100644 index 0000000000..816da0d331 --- /dev/null +++ b/contrib/clickhouse/src/IO/AsyncReadCounters.cpp @@ -0,0 +1,37 @@ +#include <IO/AsyncReadCounters.h> + +namespace DB +{ + +void AsyncReadCounters::dumpToMapColumn(IColumn * column) const +{ + auto * column_map = column ? &typeid_cast<DB::ColumnMap &>(*column) : nullptr; + if (!column_map) + return; + + auto & offsets = column_map->getNestedColumn().getOffsets(); + auto & tuple_column = column_map->getNestedData(); + auto & key_column = tuple_column.getColumn(0); + auto & value_column = tuple_column.getColumn(1); + + size_t size = 0; + auto load_if_not_empty = [&](const auto & key, const auto & value) + { + if (value) + { + key_column.insert(key); + value_column.insert(value); + ++size; + } + }; + + std::lock_guard lock(mutex); + + load_if_not_empty("max_parallel_read_tasks", max_parallel_read_tasks); + load_if_not_empty("max_parallel_prefetch_tasks", max_parallel_prefetch_tasks); + load_if_not_empty("total_prefetch_tasks", total_prefetch_tasks); + + offsets.push_back(offsets.back() + size); +} + +} diff --git a/contrib/clickhouse/src/IO/AsyncReadCounters.h b/contrib/clickhouse/src/IO/AsyncReadCounters.h new file mode 100644 index 0000000000..1f84b2a214 --- /dev/null +++ b/contrib/clickhouse/src/IO/AsyncReadCounters.h @@ -0,0 +1,32 @@ +#pragma once +#include <Core/Types.h> +#include <Columns/ColumnMap.h> + +namespace DB +{ + +/// Metrics for asynchronous reading feature. +struct AsyncReadCounters +{ + /// Count current and max number of tasks in a asynchronous read pool. + /// The tasks are requests to read the data. + size_t max_parallel_read_tasks = 0; + size_t current_parallel_read_tasks = 0; + + /// Count current and max number of tasks in a reader prefetch read pool. + /// The tasks are calls to IMergeTreeReader::prefetch(), which does not do + /// any reading but creates a request for read. But as we need to wait for + /// marks to be loaded during this prefetch, we do it in a threadpool too. 
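Editorial note: dumpToMapColumn() in the .cpp above appends one Map(String, UInt64) row per call and silently skips zero-valued counters. A hedged caller sketch, assuming the column is built from a DataTypeMap in the usual way:

#include <IO/AsyncReadCounters.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>

void exampleDumpCounters(const DB::AsyncReadCounters & counters)
{
    auto map_type = std::make_shared<DB::DataTypeMap>(
        std::make_shared<DB::DataTypeString>(), std::make_shared<DB::DataTypeUInt64>());
    auto column = map_type->createColumn();
    counters.dumpToMapColumn(column.get()); /// appends a single row, e.g. {'total_prefetch_tasks': 12}
}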
+ size_t max_parallel_prefetch_tasks = 0; + size_t current_parallel_prefetch_tasks = 0; + size_t total_prefetch_tasks = 0; + + mutable std::mutex mutex; + + AsyncReadCounters() = default; + + void dumpToMapColumn(IColumn * column) const; +}; +using AsyncReadCountersPtr = std::shared_ptr<AsyncReadCounters>; + +} diff --git a/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFile.cpp b/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFile.cpp new file mode 100644 index 0000000000..0e6c8090cb --- /dev/null +++ b/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFile.cpp @@ -0,0 +1,109 @@ +#include <fcntl.h> + +#include <IO/AsynchronousReadBufferFromFile.h> +#include <IO/WriteHelpers.h> +#include <Common/ProfileEvents.h> +#include <base/defines.h> +#include <cerrno> + + +namespace ProfileEvents +{ + extern const Event FileOpen; +} + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int FILE_DOESNT_EXIST; + extern const int CANNOT_OPEN_FILE; + extern const int CANNOT_CLOSE_FILE; +} + + +AsynchronousReadBufferFromFile::AsynchronousReadBufferFromFile( + IAsynchronousReader & reader_, + Priority priority_, + const std::string & file_name_, + size_t buf_size, + int flags, + char * existing_memory, + size_t alignment, + std::optional<size_t> file_size_) + : AsynchronousReadBufferFromFileDescriptor(reader_, priority_, -1, buf_size, existing_memory, alignment, file_size_) + , file_name(file_name_) +{ + ProfileEvents::increment(ProfileEvents::FileOpen); + +#ifdef OS_DARWIN + bool o_direct = (flags != -1) && (flags & O_DIRECT); + if (o_direct) + flags = flags & ~O_DIRECT; +#endif + fd = ::open(file_name.c_str(), flags == -1 ? O_RDONLY | O_CLOEXEC : flags | O_CLOEXEC); + + if (-1 == fd) + throwFromErrnoWithPath("Cannot open file " + file_name, file_name, + errno == ENOENT ? ErrorCodes::FILE_DOESNT_EXIST : ErrorCodes::CANNOT_OPEN_FILE); +#ifdef OS_DARWIN + if (o_direct) + { + if (fcntl(fd, F_NOCACHE, 1) == -1) + throwFromErrnoWithPath("Cannot set F_NOCACHE on file " + file_name, file_name, ErrorCodes::CANNOT_OPEN_FILE); + } +#endif +} + + +AsynchronousReadBufferFromFile::AsynchronousReadBufferFromFile( + IAsynchronousReader & reader_, + Priority priority_, + int & fd_, + const std::string & original_file_name, + size_t buf_size, + char * existing_memory, + size_t alignment, + std::optional<size_t> file_size_) + : AsynchronousReadBufferFromFileDescriptor(reader_, priority_, fd_, buf_size, existing_memory, alignment, file_size_) + , file_name(original_file_name.empty() ? "(fd = " + toString(fd_) + ")" : original_file_name) +{ + fd_ = -1; +} + + +AsynchronousReadBufferFromFile::~AsynchronousReadBufferFromFile() +{ + /// Must wait for events in flight before closing the file. + finalize(); + + if (fd < 0) + return; + + int err = ::close(fd); + chassert(!err || errno == EINTR); +} + + +void AsynchronousReadBufferFromFile::close() +{ + if (fd < 0) + return; + + if (0 != ::close(fd)) + throw Exception(ErrorCodes::CANNOT_CLOSE_FILE, "Cannot close file"); + + fd = -1; +} + + +AsynchronousReadBufferFromFileWithDescriptorsCache::~AsynchronousReadBufferFromFileWithDescriptorsCache() +{ + /// Must wait for events in flight before potentially closing the file by destroying OpenedFilePtr. 
+ finalize(); +} + + +} diff --git a/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFile.h b/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFile.h new file mode 100644 index 0000000000..d3b7ffbc7d --- /dev/null +++ b/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFile.h @@ -0,0 +1,84 @@ +#pragma once + +#include <Common/Throttler_fwd.h> +#include <IO/AsynchronousReadBufferFromFileDescriptor.h> +#include <IO/OpenedFileCache.h> + + +namespace DB +{ + +/* NOTE: Unused */ +class AsynchronousReadBufferFromFile : public AsynchronousReadBufferFromFileDescriptor +{ +protected: + std::string file_name; + +public: + explicit AsynchronousReadBufferFromFile( + IAsynchronousReader & reader_, + Priority priority_, + const std::string & file_name_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + int flags = -1, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt); + + /// Use pre-opened file descriptor. + explicit AsynchronousReadBufferFromFile( + IAsynchronousReader & reader_, + Priority priority_, + int & fd, /// Will be set to -1 if constructor didn't throw and ownership of file descriptor is passed to the object. + const std::string & original_file_name = {}, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt); + + ~AsynchronousReadBufferFromFile() override; + + /// Close file before destruction of object. + void close(); + + std::string getFileName() const override + { + return file_name; + } +}; + +/** Similar to AsynchronousReadBufferFromFile but also transparently shares open file descriptors. + */ +class AsynchronousReadBufferFromFileWithDescriptorsCache : public AsynchronousReadBufferFromFileDescriptor +{ +private: + std::string file_name; + OpenedFileCache::OpenedFilePtr file; + +public: + AsynchronousReadBufferFromFileWithDescriptorsCache( + IAsynchronousReader & reader_, + Priority priority_, + const std::string & file_name_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + int flags = -1, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt, + ThrottlerPtr throttler_ = {}) + : AsynchronousReadBufferFromFileDescriptor(reader_, priority_, -1, buf_size, existing_memory, alignment, file_size_, throttler_) + , file_name(file_name_) + { + file = OpenedFileCache::instance().get(file_name, flags); + fd = file->getFD(); + } + + ~AsynchronousReadBufferFromFileWithDescriptorsCache() override; + + std::string getFileName() const override + { + return file_name; + } +}; + +} diff --git a/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFileDescriptor.cpp b/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFileDescriptor.cpp new file mode 100644 index 0000000000..d30773f88f --- /dev/null +++ b/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFileDescriptor.cpp @@ -0,0 +1,272 @@ +#include <cerrno> +#include <ctime> +#include <optional> +#include <Common/ProfileEvents.h> +#include <Common/Stopwatch.h> +#include <Common/Exception.h> +#include <Common/CurrentMetrics.h> +#include <Common/Throttler.h> +#include <Common/filesystemHelpers.h> +#include <IO/AsynchronousReadBufferFromFileDescriptor.h> +#include <IO/WriteHelpers.h> + + +namespace ProfileEvents +{ + extern const Event AsynchronousReadWaitMicroseconds; + extern const Event LocalReadThrottlerBytes; + extern const Event LocalReadThrottlerSleepMicroseconds; +} + +namespace CurrentMetrics +{ + extern const Metric 
AsynchronousReadWait; +} + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ARGUMENT_OUT_OF_BOUND; + extern const int LOGICAL_ERROR; +} + + +std::string AsynchronousReadBufferFromFileDescriptor::getFileName() const +{ + return "(fd = " + toString(fd) + ")"; +} + + +std::future<IAsynchronousReader::Result> AsynchronousReadBufferFromFileDescriptor::asyncReadInto(char * data, size_t size, Priority priority) +{ + IAsynchronousReader::Request request; + request.descriptor = std::make_shared<IAsynchronousReader::LocalFileDescriptor>(fd); + request.buf = data; + request.size = size; + request.offset = file_offset_of_buffer_end; + request.priority = Priority{base_priority.value + priority.value}; + request.ignore = bytes_to_ignore; + bytes_to_ignore = 0; + + /// This is a workaround of a read pass EOF bug in linux kernel with pread() + if (file_size.has_value() && file_offset_of_buffer_end >= *file_size) + { + return std::async(std::launch::deferred, [] { return IAsynchronousReader::Result{.size = 0, .offset = 0}; }); + } + + return reader.submit(request); +} + + +void AsynchronousReadBufferFromFileDescriptor::prefetch(Priority priority) +{ + if (prefetch_future.valid()) + return; + + /// Will request the same amount of data that is read in nextImpl. + prefetch_buffer.resize(internal_buffer.size()); + prefetch_future = asyncReadInto(prefetch_buffer.data(), prefetch_buffer.size(), priority); +} + + +bool AsynchronousReadBufferFromFileDescriptor::nextImpl() +{ + if (prefetch_future.valid()) + { + /// Read request already in flight. Wait for its completion. + + size_t size = 0; + size_t offset = 0; + { + Stopwatch watch; + CurrentMetrics::Increment metric_increment{CurrentMetrics::AsynchronousReadWait}; + auto result = prefetch_future.get(); + ProfileEvents::increment(ProfileEvents::AsynchronousReadWaitMicroseconds, watch.elapsedMicroseconds()); + size = result.size; + offset = result.offset; + assert(offset < size || size == 0); + } + + prefetch_future = {}; + file_offset_of_buffer_end += size; + + assert(offset <= size); + size_t bytes_read = size - offset; + if (throttler) + throttler->add(bytes_read, ProfileEvents::LocalReadThrottlerBytes, ProfileEvents::LocalReadThrottlerSleepMicroseconds); + + if (bytes_read) + { + prefetch_buffer.swap(memory); + /// Adjust the working buffer so that it ignores `offset` bytes. + internal_buffer = Buffer(memory.data(), memory.data() + memory.size()); + working_buffer = Buffer(memory.data() + offset, memory.data() + size); + pos = working_buffer.begin(); + return true; + } + + return false; + } + else + { + /// No pending request. Do synchronous read. + + Stopwatch watch; + auto [size, offset, _] = asyncReadInto(memory.data(), memory.size(), DEFAULT_PREFETCH_PRIORITY).get(); + ProfileEvents::increment(ProfileEvents::AsynchronousReadWaitMicroseconds, watch.elapsedMicroseconds()); + + file_offset_of_buffer_end += size; + + assert(offset <= size); + size_t bytes_read = size - offset; + if (throttler) + throttler->add(bytes_read, ProfileEvents::LocalReadThrottlerBytes, ProfileEvents::LocalReadThrottlerSleepMicroseconds); + + if (bytes_read) + { + /// Adjust the working buffer so that it ignores `offset` bytes. 
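/// (Editorial note: `offset` originates from Request::ignore. After a seek() to a position that is
/// not aligned to required_alignment, the read request starts at the aligned position, so the first
/// `offset` bytes of the returned data lie before the position the caller actually asked for and
/// must be skipped here.)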
+ internal_buffer = Buffer(memory.data(), memory.data() + memory.size()); + working_buffer = Buffer(memory.data() + offset, memory.data() + size); + pos = working_buffer.begin(); + return true; + } + + return false; + } +} + + +void AsynchronousReadBufferFromFileDescriptor::finalize() +{ + if (prefetch_future.valid()) + { + prefetch_future.wait(); + prefetch_future = {}; + } +} + + +AsynchronousReadBufferFromFileDescriptor::AsynchronousReadBufferFromFileDescriptor( + IAsynchronousReader & reader_, + Priority priority_, + int fd_, + size_t buf_size, + char * existing_memory, + size_t alignment, + std::optional<size_t> file_size_, + ThrottlerPtr throttler_) + : ReadBufferFromFileBase(buf_size, existing_memory, alignment, file_size_) + , reader(reader_) + , base_priority(priority_) + , required_alignment(alignment) + , fd(fd_) + , throttler(throttler_) +{ + if (required_alignment > buf_size) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Too large alignment. Cannot have required_alignment greater than buf_size: {} > {}. It is a bug", + required_alignment, + buf_size); + + prefetch_buffer.alignment = alignment; +} + +AsynchronousReadBufferFromFileDescriptor::~AsynchronousReadBufferFromFileDescriptor() +{ + finalize(); +} + + +/// If 'offset' is small enough to stay in buffer after seek, then true seek in file does not happen. +off_t AsynchronousReadBufferFromFileDescriptor::seek(off_t offset, int whence) +{ + size_t new_pos; + if (whence == SEEK_SET) + { + assert(offset >= 0); + new_pos = offset; + } + else if (whence == SEEK_CUR) + { + new_pos = file_offset_of_buffer_end - (working_buffer.end() - pos) + offset; + } + else + { + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "ReadBufferFromFileDescriptor::seek expects SEEK_SET or SEEK_CUR as whence"); + } + + /// Position is unchanged. + if (new_pos + (working_buffer.end() - pos) == file_offset_of_buffer_end) + return new_pos; + + while (true) + { + if (file_offset_of_buffer_end - working_buffer.size() <= new_pos && new_pos <= file_offset_of_buffer_end) + { + /// Position is still inside the buffer. + /// Probably it is at the end of the buffer - then we will load data on the following 'next' call. + + pos = working_buffer.end() - file_offset_of_buffer_end + new_pos; + assert(pos >= working_buffer.begin()); + assert(pos <= working_buffer.end()); + + return new_pos; + } + else if (prefetch_future.valid()) + { + /// Read from prefetch buffer and recheck if the new position is valid inside. + if (nextImpl()) + continue; + } + + break; + } + + assert(!prefetch_future.valid()); + + /// Position is out of the buffer, we need to do real seek. + off_t seek_pos = required_alignment > 1 + ? new_pos / required_alignment * required_alignment + : new_pos; + + /// First reset the buffer so the next read will fetch new data to the buffer. + resetWorkingBuffer(); + + /// Just update the info about the next position in file. + + file_offset_of_buffer_end = seek_pos; + bytes_to_ignore = new_pos - seek_pos; + + if (bytes_to_ignore >= internal_buffer.size()) + throw Exception(ErrorCodes::LOGICAL_ERROR, + "Logical error in AsynchronousReadBufferFromFileDescriptor, bytes_to_ignore ({}" + ") >= internal_buffer.size() ({})", bytes_to_ignore, internal_buffer.size()); + + return seek_pos; +} + + +void AsynchronousReadBufferFromFileDescriptor::rewind() +{ + if (prefetch_future.valid()) + { + prefetch_future.wait(); + prefetch_future = {}; + } + + /// Clearing the buffer with existing data. New data will be read on subsequent call to 'next'. 
+ working_buffer.resize(0); + pos = working_buffer.begin(); + file_offset_of_buffer_end = 0; +} + +size_t AsynchronousReadBufferFromFileDescriptor::getFileSize() +{ + return getSizeFromFileDescriptor(fd, getFileName()); +} + +} diff --git a/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFileDescriptor.h b/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFileDescriptor.h new file mode 100644 index 0000000000..4a4130ebab --- /dev/null +++ b/contrib/clickhouse/src/IO/AsynchronousReadBufferFromFileDescriptor.h @@ -0,0 +1,77 @@ +#pragma once + +#include <IO/ReadBufferFromFileBase.h> +#include <IO/AsynchronousReader.h> +#include <Interpreters/Context.h> +#include <Common/Throttler_fwd.h> +#include <Common/Priority.h> + +#include <optional> +#include <unistd.h> + + +namespace DB +{ + +/** Use ready file descriptor. Does not open or close a file. + */ +class AsynchronousReadBufferFromFileDescriptor : public ReadBufferFromFileBase +{ +protected: + IAsynchronousReader & reader; + Priority base_priority; + + Memory<> prefetch_buffer; + std::future<IAsynchronousReader::Result> prefetch_future; + + const size_t required_alignment = 0; /// For O_DIRECT both file offsets and memory addresses have to be aligned. + size_t file_offset_of_buffer_end = 0; /// What offset in file corresponds to working_buffer.end(). + size_t bytes_to_ignore = 0; /// How many bytes should we ignore upon a new read request. + int fd; + ThrottlerPtr throttler; + + bool nextImpl() override; + + /// Name or some description of file. + std::string getFileName() const override; + + void finalize(); + +public: + AsynchronousReadBufferFromFileDescriptor( + IAsynchronousReader & reader_, + Priority priority_, + int fd_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt, + ThrottlerPtr throttler_ = {}); + + ~AsynchronousReadBufferFromFileDescriptor() override; + + void prefetch(Priority priority) override; + + int getFD() const + { + return fd; + } + + off_t getPosition() override + { + return file_offset_of_buffer_end - (working_buffer.end() - pos); + } + + /// If 'offset' is small enough to stay in buffer after seek, then true seek in file does not happen. + off_t seek(off_t off, int whence) override; + + /// Seek to the beginning, discarding already read data if any. Useful to reread file that changes on every read. + void rewind(); + + size_t getFileSize() override; + +private: + std::future<IAsynchronousReader::Result> asyncReadInto(char * data, size_t size, Priority priority); +}; + +} diff --git a/contrib/clickhouse/src/IO/AsynchronousReader.h b/contrib/clickhouse/src/IO/AsynchronousReader.h new file mode 100644 index 0000000000..467a3f1d6a --- /dev/null +++ b/contrib/clickhouse/src/IO/AsynchronousReader.h @@ -0,0 +1,87 @@ +#pragma once + +#include <Core/Types.h> +#include <optional> +#include <memory> +#include <future> +#include <boost/noncopyable.hpp> +#include <Common/Stopwatch.h> +#include <Common/Priority.h> + + +namespace DB +{ + +/** Interface for asynchronous reads from file descriptors. + * It can abstract Linux AIO, io_uring or normal reads from separate thread pool, + * and also reads from non-local filesystems. + * The implementation not necessarily to be efficient for large number of small requests, + * instead it should be ok for moderate number of sufficiently large requests + * (e.g. read 1 MB of data 50 000 times per seconds; BTW this is normal performance for reading from page cache). 
+ * For example, this interface may not suffice if you want to serve 10 000 000 of 4 KiB requests per second. + * This interface is fairly limited. + */ +class IAsynchronousReader : private boost::noncopyable +{ +public: + /// For local filesystems, the file descriptor is simply integer + /// but it can be arbitrary opaque object for remote filesystems. + struct IFileDescriptor + { + virtual ~IFileDescriptor() = default; + }; + + using FileDescriptorPtr = std::shared_ptr<IFileDescriptor>; + + struct LocalFileDescriptor : public IFileDescriptor + { + explicit LocalFileDescriptor(int fd_) : fd(fd_) {} + int fd; + }; + + /// Read from file descriptor at specified offset up to size bytes into buf. + /// Some implementations may require alignment and it is responsibility of + /// the caller to provide conforming requests. + struct Request + { + FileDescriptorPtr descriptor; + size_t offset = 0; + size_t size = 0; + char * buf = nullptr; + Priority priority; + size_t ignore = 0; + }; + + struct Result + { + /// size + /// Less than requested amount of data can be returned. + /// If size is zero - the file has ended. + /// (for example, EINTR must be handled by implementation automatically) + size_t size = 0; + + /// offset + /// Optional. Useful when implementation needs to do ignore(). + size_t offset = 0; + + std::unique_ptr<Stopwatch> execution_watch = {}; + + operator std::tuple<size_t &, size_t &>() { return {size, offset}; } + }; + + /// Submit request and obtain a handle. This method don't perform any waits. + /// If this method did not throw, the caller must wait for the result with 'wait' method + /// or destroy the whole reader before destroying the buffer for request. + /// The method can be called concurrently from multiple threads. + virtual std::future<Result> submit(Request request) = 0; + + virtual void wait() = 0; + + /// Destructor must wait for all not completed request and ignore the results. + /// It may also cancel the requests. + virtual ~IAsynchronousReader() = default; +}; + +using AsynchronousReaderPtr = std::shared_ptr<IAsynchronousReader>; + +} diff --git a/contrib/clickhouse/src/IO/BitHelpers.h b/contrib/clickhouse/src/IO/BitHelpers.h new file mode 100644 index 0000000000..a384da0a95 --- /dev/null +++ b/contrib/clickhouse/src/IO/BitHelpers.h @@ -0,0 +1,236 @@ +#pragma once + +#include <bit> +#include <base/types.h> +#include <Common/BitHelpers.h> +#include <Common/Exception.h> + +#include <cstring> +#include <cassert> + + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int CANNOT_WRITE_AFTER_END_OF_BUFFER; +extern const int ATTEMPT_TO_READ_AFTER_EOF; +} + +/** Reads data from underlying ReadBuffer bit by bit, max 64 bits at once. 
+ * + * reads MSB bits first, imagine that you have a data: + * 11110000 10101010 00100100 11111110 + * + * Given that r is BitReader created with a ReadBuffer that reads from data above: + * r.readBits(3) => 0b111 + * r.readBit() => 0b1 + * r.readBits(8) => 0b1010 // 4 leading zero-bits are not shown + * r.readBit() => 0b1 + * r.readBit() => 0b0 + * r.readBits(15) => 0b10001001001111111 + * r.readBit() => 0b0 +**/ + +class BitReader +{ + const char * const source_begin; + const char * const source_end; + const char * source_current; + + using BufferType = unsigned __int128; + BufferType bits_buffer = 0; + + UInt8 bits_count = 0; + +public: + BitReader(const char * begin, size_t size) + : source_begin(begin) + , source_end(begin + size) + , source_current(begin) + {} + + ~BitReader() = default; + + // reads bits_to_read high-bits from bits_buffer + ALWAYS_INLINE UInt64 readBits(UInt8 bits_to_read) + { + if (bits_to_read > bits_count) + fillBitBuffer(); + + return getBitsFromBitBuffer<CONSUME>(bits_to_read); + } + + UInt8 peekByte() + { + if (bits_count < 8) + fillBitBuffer(); + + return getBitsFromBitBuffer<PEEK>(8); + } + + ALWAYS_INLINE UInt8 readBit() + { + return static_cast<UInt8>(readBits(1)); + } + + // skip bits from bits_buffer + void skipBufferedBits(UInt8 bits) + { + bits_buffer <<= bits; + bits_count -= bits; + } + + + bool eof() const + { + return bits_count == 0 && source_current >= source_end; + } + + // number of bits that was already read by clients with readBits() + UInt64 count() const + { + return (source_current - source_begin) * 8 - bits_count; + } + + UInt64 remaining() const + { + return (source_end - source_current) * 8 + bits_count; + } + +private: + enum GetBitsMode {CONSUME, PEEK}; + // read data from internal buffer, if it has not enough bits, result is undefined. 
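A round-trip sketch (editorial, not from the sources) pairing this reader with the BitWriter defined further below; the buffer size and bit patterns are arbitrary.

#include <IO/BitHelpers.h>
#include <cassert>

void exampleBitRoundTrip()
{
    char buf[8] = {};

    DB::BitWriter writer(buf, sizeof(buf));
    writer.writeBits(3, 0b101);  /// writes the low 3 bits of the value, MSB-first
    writer.writeBits(9, 0x1FF);
    writer.flush();              /// pads the last byte with zero bits

    DB::BitReader reader(buf, writer.count() / 8); /// count() is in bits and is byte-aligned after flush()
    assert(reader.readBits(3) == 0b101);
    assert(reader.readBits(9) == 0x1FF);
}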
+ template <GetBitsMode mode> + UInt64 getBitsFromBitBuffer(UInt8 bits_to_read) + { + assert(bits_to_read > 0); + + // push down the high-bits + const UInt64 result = static_cast<UInt64>(bits_buffer >> (sizeof(bits_buffer) * 8 - bits_to_read)); + + if constexpr (mode == CONSUME) + { + // 'erase' high-bits that were have read + skipBufferedBits(bits_to_read); + } + + return result; + } + + + // Fills internal bits_buffer with data from source, reads at most 64 bits + ALWAYS_INLINE size_t fillBitBuffer() + { + const size_t available = source_end - source_current; + const auto bytes_to_read = std::min<size_t>(64 / 8, available); + if (available == 0) + { + if (bytes_to_read == 0) + return 0; + + throw Exception(ErrorCodes::ATTEMPT_TO_READ_AFTER_EOF, "Buffer is empty, but requested to read {} more bytes.", + bytes_to_read); + } + + UInt64 tmp_buffer = 0; + memcpy(&tmp_buffer, source_current, bytes_to_read); + source_current += bytes_to_read; + + if constexpr (std::endian::native == std::endian::little) + tmp_buffer = DB::byteswap(tmp_buffer); + + bits_buffer |= BufferType(tmp_buffer) << ((sizeof(BufferType) - sizeof(tmp_buffer)) * 8 - bits_count); + bits_count += static_cast<UInt8>(bytes_to_read) * 8; + + return bytes_to_read; + } +}; + +class BitWriter +{ + char * dest_begin; + char * dest_end; + char * dest_current; + + using BufferType = unsigned __int128; + BufferType bits_buffer = 0; + + UInt8 bits_count = 0; + + static constexpr UInt8 BIT_BUFFER_SIZE = sizeof(bits_buffer) * 8; + +public: + BitWriter(char * begin, size_t size) + : dest_begin(begin) + , dest_end(begin + size) + , dest_current(begin) + {} + + ~BitWriter() + { + flush(); + } + + // write `bits_to_write` low-bits of `value` to the buffer + void writeBits(UInt8 bits_to_write, UInt64 value) + { + assert(bits_to_write > 0); + + UInt32 capacity = BIT_BUFFER_SIZE - bits_count; + if (capacity < bits_to_write) + { + doFlush(); + capacity = BIT_BUFFER_SIZE - bits_count; + } + + // write low bits of value as high bits of bits_buffer + const UInt64 mask = maskLowBits<UInt64>(bits_to_write); + BufferType v = value & mask; + v <<= capacity - bits_to_write; + + bits_buffer |= v; + bits_count += bits_to_write; + } + + // flush contents of bits_buffer to the dest_current, partial bytes are completed with zeroes. + void flush() + { + bits_count = (bits_count + 8 - 1) & ~(8 - 1); // align up to 8-bytes, so doFlush will write all data from bits_buffer + while (bits_count != 0) + doFlush(); + } + + UInt64 count() const + { + return (dest_current - dest_begin) * 8 + bits_count; + } + +private: + void doFlush() + { + // write whole bytes to the dest_current, leaving partial bits in bits_buffer + const size_t available = dest_end - dest_current; + const size_t to_write = std::min<size_t>(sizeof(UInt64), bits_count / 8); // align to 8-bit boundary + + if (available < to_write) + { + throw Exception(ErrorCodes::CANNOT_WRITE_AFTER_END_OF_BUFFER, + "Can not write past end of buffer. 
Space available is {} bytes, required to write {} bytes.", + available, to_write); + } + + UInt64 tmp_buffer = static_cast<UInt64>(bits_buffer >> (sizeof(bits_buffer) - sizeof(UInt64)) * 8); + if constexpr (std::endian::native == std::endian::little) + tmp_buffer = DB::byteswap(tmp_buffer); + + memcpy(dest_current, &tmp_buffer, to_write); + dest_current += to_write; + + bits_buffer <<= to_write * 8; + bits_count -= to_write * 8; + } +}; + +} diff --git a/contrib/clickhouse/src/IO/BoundedReadBuffer.cpp b/contrib/clickhouse/src/IO/BoundedReadBuffer.cpp new file mode 100644 index 0000000000..bda79d82ad --- /dev/null +++ b/contrib/clickhouse/src/IO/BoundedReadBuffer.cpp @@ -0,0 +1,66 @@ +#include "BoundedReadBuffer.h" +#include <IO/SwapHelper.h> + +namespace DB +{ + +BoundedReadBuffer::BoundedReadBuffer(std::unique_ptr<SeekableReadBuffer> impl_) + : ReadBufferFromFileDecorator(std::move(impl_)) +{ +} + +void BoundedReadBuffer::setReadUntilPosition(size_t position) +{ + read_until_position = position; +} + +void BoundedReadBuffer::setReadUntilEnd() +{ + read_until_position.reset(); +} + +off_t BoundedReadBuffer::getPosition() +{ + return file_offset_of_buffer_end - (working_buffer.end() - pos); +} + +bool BoundedReadBuffer::nextImpl() +{ + if (read_until_position && file_offset_of_buffer_end == *read_until_position) + return false; + + bool result; + { + SwapHelper swap(*this, *impl); + result = impl->next(); + } + chassert(file_offset_of_buffer_end + available() == impl->getFileOffsetOfBufferEnd()); + if (result && read_until_position) + { + size_t remaining_size_to_read = *read_until_position - file_offset_of_buffer_end; + if (working_buffer.size() > remaining_size_to_read) + { + /// file: [______________________________] + /// working buffer: [_______________] + /// ^ + /// read_until_position + /// ^ + /// file_offset_of_buffer_end + working_buffer.resize(remaining_size_to_read); + } + } + file_offset_of_buffer_end += available(); + return result; +} + +off_t BoundedReadBuffer::seek(off_t off, int whence) +{ + swap(*impl); + auto result = impl->seek(off, whence); + swap(*impl); + + file_offset_of_buffer_end = impl->getFileOffsetOfBufferEnd(); + return result; +} + +} diff --git a/contrib/clickhouse/src/IO/BoundedReadBuffer.h b/contrib/clickhouse/src/IO/BoundedReadBuffer.h new file mode 100644 index 0000000000..eb65857e83 --- /dev/null +++ b/contrib/clickhouse/src/IO/BoundedReadBuffer.h @@ -0,0 +1,38 @@ +#pragma once +#include <IO/ReadBufferFromFileDecorator.h> + + +namespace DB +{ + +/// A buffer which allows to make an underlying buffer as right bounded, +/// e.g. the buffer cannot return data beyond offset specified in `setReadUntilPosition`. +class BoundedReadBuffer : public ReadBufferFromFileDecorator +{ +public: + explicit BoundedReadBuffer(std::unique_ptr<SeekableReadBuffer> impl_); + + bool supportsRightBoundedReads() const override { return true; } + + void setReadUntilPosition(size_t position) override; + + void setReadUntilEnd() override; + + bool nextImpl() override; + + off_t seek(off_t off, int whence) override; + + size_t getFileOffsetOfBufferEnd() const override { return file_offset_of_buffer_end; } + + /// file_offset_of_buffer_end can differ from impl's file_offset_of_buffer_end + /// because of resizing of the tail. => Need to also override getPosition() as + /// it uses file_offset_of_buffer_end. + off_t getPosition() override; + +private: + std::optional<size_t> read_until_position; + /// atomic because can be used in log or exception messages while being updated. 
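Editorial usage sketch, not part of the original sources: capping an underlying file buffer so reads stop at a fixed offset. The path and the byte limit are assumptions.

#include <IO/BoundedReadBuffer.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/WriteBufferFromString.h>
#include <IO/copyData.h>
#include <memory>

std::string exampleReadPrefix(const std::string & path, size_t limit)
{
    DB::BoundedReadBuffer in(std::make_unique<DB::ReadBufferFromFile>(path));
    in.setReadUntilPosition(limit); /// nextImpl() will not return data past this absolute offset

    DB::WriteBufferFromOwnString out;
    DB::copyData(in, out);          /// copies at most `limit` bytes because the source reports EOF there
    return out.str();
}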
+ std::atomic<size_t> file_offset_of_buffer_end = 0; +}; + +} diff --git a/contrib/clickhouse/src/IO/BrotliReadBuffer.cpp b/contrib/clickhouse/src/IO/BrotliReadBuffer.cpp new file mode 100644 index 0000000000..effdfb4b8c --- /dev/null +++ b/contrib/clickhouse/src/IO/BrotliReadBuffer.cpp @@ -0,0 +1,109 @@ +#include "clickhouse_config.h" + +#if USE_BROTLI +# error #include <brotli/decode.h> +# include "BrotliReadBuffer.h" +# include <IO/WithFileName.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BROTLI_READ_FAILED; +} + + +class BrotliReadBuffer::BrotliStateWrapper +{ +public: + BrotliStateWrapper() + : state(BrotliDecoderCreateInstance(nullptr, nullptr, nullptr)) + , result(BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT) + { + } + + ~BrotliStateWrapper() + { + BrotliDecoderDestroyInstance(state); + } + + BrotliDecoderState * state; + BrotliDecoderResult result; +}; + +BrotliReadBuffer::BrotliReadBuffer(std::unique_ptr<ReadBuffer> in_, size_t buf_size, char *existing_memory, size_t alignment) + : CompressedReadBufferWrapper(std::move(in_), buf_size, existing_memory, alignment) + , brotli(std::make_unique<BrotliStateWrapper>()) + , in_available(0) + , in_data(nullptr) + , out_capacity(0) + , out_data(nullptr) + , eof_flag(false) +{ +} + +BrotliReadBuffer::~BrotliReadBuffer() = default; + +bool BrotliReadBuffer::nextImpl() +{ + if (eof_flag) + return false; + + do + { + if (!in_available) + { + in->nextIfAtEnd(); + in_available = in->buffer().end() - in->position(); + in_data = reinterpret_cast<uint8_t *>(in->position()); + } + + if (brotli->result == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && (!in_available || in->eof())) + { + throw Exception( + ErrorCodes::BROTLI_READ_FAILED, + "brotli decode error{}", + getExceptionEntryWithFileName(*in)); + } + + out_capacity = internal_buffer.size(); + out_data = reinterpret_cast<uint8_t *>(internal_buffer.begin()); + + brotli->result = BrotliDecoderDecompressStream(brotli->state, &in_available, &in_data, &out_capacity, &out_data, nullptr); + + in->position() = in->buffer().end() - in_available; + } + while (brotli->result == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT && out_capacity == internal_buffer.size()); + + working_buffer.resize(internal_buffer.size() - out_capacity); + + if (brotli->result == BROTLI_DECODER_RESULT_SUCCESS) + { + if (in->eof()) + { + eof_flag = true; + return !working_buffer.empty(); + } + else + { + throw Exception( + ErrorCodes::BROTLI_READ_FAILED, + "brotli decode error{}", + getExceptionEntryWithFileName(*in)); + } + } + + if (brotli->result == BROTLI_DECODER_RESULT_ERROR) + { + throw Exception( + ErrorCodes::BROTLI_READ_FAILED, + "brotli decode error{}", + getExceptionEntryWithFileName(*in)); + } + + return true; +} +} + +#endif diff --git a/contrib/clickhouse/src/IO/BrotliReadBuffer.h b/contrib/clickhouse/src/IO/BrotliReadBuffer.h new file mode 100644 index 0000000000..8583d6892e --- /dev/null +++ b/contrib/clickhouse/src/IO/BrotliReadBuffer.h @@ -0,0 +1,37 @@ +#pragma once + +#include <IO/ReadBuffer.h> +#include <IO/CompressedReadBufferWrapper.h> + + +namespace DB +{ + +class BrotliReadBuffer : public CompressedReadBufferWrapper +{ +public: + explicit BrotliReadBuffer( + std::unique_ptr<ReadBuffer> in_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~BrotliReadBuffer() override; + +private: + bool nextImpl() override; + + class BrotliStateWrapper; + std::unique_ptr<BrotliStateWrapper> brotli; + + size_t in_available; + const uint8_t * in_data; + + 
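Editorial note: the decompression loop in the .cpp above follows the usual streaming pattern, feeding whatever the wrapped buffer currently holds, letting Brotli fill the output, and treating BROTLI_DECODER_RESULT_SUCCESS as final only once the source is exhausted. A hedged usage sketch, assuming a build with brotli enabled and illustrative file names:

#include <IO/BrotliReadBuffer.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/WriteBufferFromFile.h>
#include <IO/copyData.h>
#include <memory>

void exampleDecompressBrotliFile()
{
    DB::BrotliReadBuffer in(std::make_unique<DB::ReadBufferFromFile>("/tmp/data.br"));
    DB::WriteBufferFromFile out("/tmp/data.raw");
    DB::copyData(in, out); /// decompresses transparently while copying
    out.finalize();
}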
size_t out_capacity; + uint8_t * out_data; + + bool eof_flag; +}; + +} + diff --git a/contrib/clickhouse/src/IO/BrotliWriteBuffer.cpp b/contrib/clickhouse/src/IO/BrotliWriteBuffer.cpp new file mode 100644 index 0000000000..6ec427049c --- /dev/null +++ b/contrib/clickhouse/src/IO/BrotliWriteBuffer.cpp @@ -0,0 +1,126 @@ +#include "clickhouse_config.h" + +#if USE_BROTLI +# include <IO/BrotliWriteBuffer.h> +# error #include <brotli/encode.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BROTLI_WRITE_FAILED; +} + + +class BrotliWriteBuffer::BrotliStateWrapper +{ +public: + BrotliStateWrapper() + : state(BrotliEncoderCreateInstance(nullptr, nullptr, nullptr)) + { + } + + ~BrotliStateWrapper() + { + BrotliEncoderDestroyInstance(state); + } + + BrotliEncoderState * state; +}; + +BrotliWriteBuffer::BrotliWriteBuffer(std::unique_ptr<WriteBuffer> out_, int compression_level, size_t buf_size, char * existing_memory, size_t alignment) + : WriteBufferWithOwnMemoryDecorator(std::move(out_), buf_size, existing_memory, alignment) + , brotli(std::make_unique<BrotliStateWrapper>()) + , in_available(0) + , in_data(nullptr) + , out_capacity(0) + , out_data(nullptr) +{ + BrotliEncoderSetParameter(brotli->state, BROTLI_PARAM_QUALITY, static_cast<uint32_t>(compression_level)); + // Set LZ77 window size. According to brotli sources default value is 24 (c/tools/brotli.c:81) + BrotliEncoderSetParameter(brotli->state, BROTLI_PARAM_LGWIN, 24); +} + +BrotliWriteBuffer::~BrotliWriteBuffer() = default; + +void BrotliWriteBuffer::nextImpl() +{ + if (!offset()) + { + return; + } + + in_data = reinterpret_cast<unsigned char *>(working_buffer.begin()); + in_available = offset(); + + try + { + do + { + out->nextIfAtEnd(); + out_data = reinterpret_cast<unsigned char *>(out->position()); + out_capacity = out->buffer().end() - out->position(); + + int result = BrotliEncoderCompressStream( + brotli->state, + in_available ? BROTLI_OPERATION_PROCESS : BROTLI_OPERATION_FINISH, + &in_available, + &in_data, + &out_capacity, + &out_data, + nullptr); + + out->position() = out->buffer().end() - out_capacity; + + if (result == 0) + { + throw Exception(ErrorCodes::BROTLI_WRITE_FAILED, "brotli compress failed"); + } + } + while (in_available > 0); + } + catch (...) + { + /// Do not try to write next time after exception. 
+ out->position() = out->buffer().begin(); + throw; + } +} + +void BrotliWriteBuffer::finalizeBefore() +{ + next(); + + while (true) + { + out->nextIfAtEnd(); + out_data = reinterpret_cast<unsigned char *>(out->position()); + out_capacity = out->buffer().end() - out->position(); + + int result = BrotliEncoderCompressStream( + brotli->state, + BROTLI_OPERATION_FINISH, + &in_available, + &in_data, + &out_capacity, + &out_data, + nullptr); + + out->position() = out->buffer().end() - out_capacity; + + if (BrotliEncoderIsFinished(brotli->state)) + { + return; + } + + if (result == 0) + { + throw Exception(ErrorCodes::BROTLI_WRITE_FAILED, "brotli compress failed"); + } + } +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/BrotliWriteBuffer.h b/contrib/clickhouse/src/IO/BrotliWriteBuffer.h new file mode 100644 index 0000000000..8cbc78bd9e --- /dev/null +++ b/contrib/clickhouse/src/IO/BrotliWriteBuffer.h @@ -0,0 +1,38 @@ +#pragma once + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> +#include <IO/WriteBufferDecorator.h> + +namespace DB +{ + +class BrotliWriteBuffer : public WriteBufferWithOwnMemoryDecorator +{ +public: + BrotliWriteBuffer( + std::unique_ptr<WriteBuffer> out_, + int compression_level, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~BrotliWriteBuffer() override; + +private: + void nextImpl() override; + + void finalizeBefore() override; + + class BrotliStateWrapper; + std::unique_ptr<BrotliStateWrapper> brotli; + + + size_t in_available; + const uint8_t * in_data; + + size_t out_capacity; + uint8_t * out_data; +}; + +} diff --git a/contrib/clickhouse/src/IO/BufferBase.h b/contrib/clickhouse/src/IO/BufferBase.h new file mode 100644 index 0000000000..7a59687fa5 --- /dev/null +++ b/contrib/clickhouse/src/IO/BufferBase.h @@ -0,0 +1,132 @@ +#pragma once + +#include <Core/Defines.h> +#include <algorithm> + + +namespace DB +{ + + +/** Base class for ReadBuffer and WriteBuffer. + * Contains common types, variables, and functions. + * + * ReadBuffer and WriteBuffer are similar to istream and ostream, respectively. + * They have to be used, because using iostreams it is impossible to effectively implement some operations. + * For example, using istream, you can not quickly read string values from a tab-separated file, + * so that after reading, the position remains immediately after the read value. + * (The only option is to call the std::istream::get() function on each byte, but this slows down due to several virtual calls.) + * + * Read/WriteBuffers provide direct access to the internal buffer, so the necessary operations are implemented more efficiently. + * Only one virtual function nextImpl() is used, which is rarely called: + * - in the case of ReadBuffer - fill in the buffer with new data from the source; + * - in the case of WriteBuffer - write data from the buffer into the receiver. + * + * Read/WriteBuffer can own or not own an own piece of memory. + * In the second case, you can effectively read from an already existing piece of memory / std::string without copying it. + */ +class BufferBase +{ +public: + /** Cursor in the buffer. The position of write or read. */ + using Position = char *; + + /** A reference to the range of memory. 
*/ + struct Buffer + { + Buffer(Position begin_pos_, Position end_pos_) : begin_pos(begin_pos_), end_pos(end_pos_) {} + + inline Position begin() const { return begin_pos; } + inline Position end() const { return end_pos; } + inline size_t size() const { return size_t(end_pos - begin_pos); } + inline void resize(size_t size) { end_pos = begin_pos + size; } + inline bool empty() const { return size() == 0; } + + inline void swap(Buffer & other) + { + std::swap(begin_pos, other.begin_pos); + std::swap(end_pos, other.end_pos); + } + + private: + Position begin_pos; + Position end_pos; /// 1 byte after the end of the buffer + }; + + /** The constructor takes a range of memory to use for the buffer. + * offset - the starting point of the cursor. ReadBuffer must set it to the end of the range, and WriteBuffer - to the beginning. + */ + BufferBase(Position ptr, size_t size, size_t offset) + : pos(ptr + offset), working_buffer(ptr, ptr + size), internal_buffer(ptr, ptr + size) {} + + void set(Position ptr, size_t size, size_t offset) + { + internal_buffer = Buffer(ptr, ptr + size); + working_buffer = Buffer(ptr, ptr + size); + pos = ptr + offset; + } + + /// get buffer + inline Buffer & internalBuffer() { return internal_buffer; } + + /// get the part of the buffer from which you can read / write data + inline Buffer & buffer() { return working_buffer; } + + /// get (for reading and modifying) the position in the buffer + inline Position & position() { return pos; } + + /// offset in bytes of the cursor from the beginning of the buffer + inline size_t offset() const { return size_t(pos - working_buffer.begin()); } + + /// How many bytes are available for read/write + inline size_t available() const { return size_t(working_buffer.end() - pos); } + + inline void swap(BufferBase & other) + { + internal_buffer.swap(other.internal_buffer); + working_buffer.swap(other.working_buffer); + std::swap(pos, other.pos); + } + + /** How many bytes have been read/written, counting those that are still in the buffer. */ + size_t count() const { return bytes + offset(); } + + /** Check that there is more bytes in buffer after cursor. */ + bool ALWAYS_INLINE hasPendingData() const { return available() > 0; } + + bool isPadded() const { return padded; } + +protected: + void resetWorkingBuffer() + { + /// Move position to the end of buffer to trigger call of 'next' on next reading. + /// Discard all data in current working buffer to prevent wrong assumptions on content + /// of buffer, e.g. for optimizations of seeks in seekable buffers. + working_buffer.resize(0); + pos = working_buffer.end(); + } + + /// Read/write position. + Position pos; + + /** How many bytes have been read/written, not counting those that are now in the buffer. + * (counting those that were already used and "removed" from the buffer) + */ + size_t bytes = 0; + + /** A piece of memory that you can use. + * For example, if internal_buffer is 1MB, and from a file for reading it was loaded into the buffer + * only 10 bytes, then working_buffer will be 10 bytes in size + * (working_buffer.end() will point to the position immediately after the 10 bytes that can be read). + */ + Buffer working_buffer; + + /// A reference to a piece of memory for the buffer. 
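A tiny editorial illustration of the cursor arithmetic defined above (offset(), available(), count()); the byte counts are arbitrary.

#include <IO/BufferBase.h>

void exampleCursorArithmetic()
{
    char data[16];
    DB::BufferBase buf(data, sizeof(data), /* offset = */ 0);
    /// initially: buf.offset() == 0, buf.available() == 16, buf.count() == 0

    buf.position() += 10; /// pretend 10 bytes were consumed by a reader or produced by a writer
    /// now: buf.offset() == 10, buf.available() == 6, buf.count() == 10
}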
+ Buffer internal_buffer; + + /// Indicator of 15 bytes pad_right + bool padded{false}; +}; + + +} diff --git a/contrib/clickhouse/src/IO/BufferWithOwnMemory.h b/contrib/clickhouse/src/IO/BufferWithOwnMemory.h new file mode 100644 index 0000000000..39c83e9167 --- /dev/null +++ b/contrib/clickhouse/src/IO/BufferWithOwnMemory.h @@ -0,0 +1,196 @@ +#pragma once + +#include <boost/noncopyable.hpp> + +#include <Common/ProfileEvents.h> +#include <Common/Allocator.h> + +#include <Common/Exception.h> +#include <Core/Defines.h> + +#include <base/arithmeticOverflow.h> + + +namespace ProfileEvents +{ + extern const Event IOBufferAllocs; + extern const Event IOBufferAllocBytes; +} + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ARGUMENT_OUT_OF_BOUND; +} + + +/** Replacement for std::vector<char> to use in buffers. + * Differs in that is doesn't do unneeded memset. (And also tries to do as little as possible.) + * Also allows to allocate aligned piece of memory (to use with O_DIRECT, for example). + */ +template <typename Allocator = Allocator<false>> +struct Memory : boost::noncopyable, Allocator +{ + static constexpr size_t pad_right = PADDING_FOR_SIMD - 1; + + size_t m_capacity = 0; /// With padding. + size_t m_size = 0; + char * m_data = nullptr; + size_t alignment = 0; + + Memory() = default; + + /// If alignment != 0, then allocate memory aligned to specified value. + explicit Memory(size_t size_, size_t alignment_ = 0) : alignment(alignment_) + { + alloc(size_); + } + + ~Memory() + { + dealloc(); + } + + void swap(Memory & rhs) noexcept + { + std::swap(m_capacity, rhs.m_capacity); + std::swap(m_size, rhs.m_size); + std::swap(m_data, rhs.m_data); + std::swap(alignment, rhs.alignment); + } + + Memory(Memory && rhs) noexcept + { + swap(rhs); + } + + Memory & operator=(Memory && rhs) noexcept + { + swap(rhs); + return *this; + } + + size_t size() const { return m_size; } + const char & operator[](size_t i) const { return m_data[i]; } + char & operator[](size_t i) { return m_data[i]; } + const char * data() const { return m_data; } + char * data() { return m_data; } + + void resize(size_t new_size) + { + if (!m_data) + { + alloc(new_size); + return; + } + + if (new_size <= m_capacity - pad_right) + { + m_size = new_size; + return; + } + + size_t new_capacity = withPadding(new_size); + + size_t diff = new_capacity - m_capacity; + ProfileEvents::increment(ProfileEvents::IOBufferAllocBytes, diff); + + m_data = static_cast<char *>(Allocator::realloc(m_data, m_capacity, new_capacity, alignment)); + m_capacity = new_capacity; + m_size = new_size; + } + +private: + static size_t withPadding(size_t value) + { + size_t res = 0; + + if (common::addOverflow<size_t>(value, pad_right, res)) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "value is too big to apply padding"); + + return res; + } + + void alloc(size_t new_size) + { + if (!new_size) + { + m_data = nullptr; + return; + } + + size_t new_capacity = withPadding(new_size); + + ProfileEvents::increment(ProfileEvents::IOBufferAllocs); + ProfileEvents::increment(ProfileEvents::IOBufferAllocBytes, new_capacity); + + m_data = static_cast<char *>(Allocator::alloc(new_capacity, alignment)); + m_capacity = new_capacity; + m_size = new_size; + } + + void dealloc() + { + if (!m_data) + return; + + Allocator::free(m_data, m_capacity); + m_data = nullptr; /// To avoid double free if next alloc will throw an exception. + } +}; + + +/** Buffer that could own its working memory. 
+ * Template parameter: ReadBuffer or WriteBuffer + */ +template <typename Base> +class BufferWithOwnMemory : public Base +{ +protected: + Memory<> memory; +public: + /// If non-nullptr 'existing_memory' is passed, then buffer will not create its own memory and will use existing_memory without ownership. + explicit BufferWithOwnMemory(size_t size = DBMS_DEFAULT_BUFFER_SIZE, char * existing_memory = nullptr, size_t alignment = 0) + : Base(nullptr, 0), memory(existing_memory ? 0 : size, alignment) + { + Base::set(existing_memory ? existing_memory : memory.data(), size); + Base::padded = !existing_memory; + } +}; + + +/** Buffer that could write data to external memory which came from outside + * Template parameter: ReadBuffer or WriteBuffer + */ +template <typename Base> +class BufferWithOutsideMemory : public Base +{ +protected: + Memory<> & memory; +public: + + explicit BufferWithOutsideMemory(Memory<> & memory_) + : Base(memory_.data(), memory_.size()), memory(memory_) + { + Base::set(memory.data(), memory.size(), 0); + Base::padded = false; + } + + size_t getActualSize() + { + return Base::count(); + } + +private: + void nextImpl() final + { + const size_t prev_size = Base::position() - memory.data(); + memory.resize(2 * prev_size + 1); + Base::set(memory.data() + prev_size, memory.size() - prev_size, 0); + } +}; + +} diff --git a/contrib/clickhouse/src/IO/Bzip2ReadBuffer.cpp b/contrib/clickhouse/src/IO/Bzip2ReadBuffer.cpp new file mode 100644 index 0000000000..a08367dedc --- /dev/null +++ b/contrib/clickhouse/src/IO/Bzip2ReadBuffer.cpp @@ -0,0 +1,139 @@ +#include "clickhouse_config.h" + +#if USE_BZIP2 +# include <IO/Bzip2ReadBuffer.h> +# error #include <bzlib.h> +# include <IO/WithFileName.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BZIP2_STREAM_DECODER_FAILED; + extern const int UNEXPECTED_END_OF_FILE; +} + + +class Bzip2ReadBuffer::Bzip2StateWrapper +{ +public: + Bzip2StateWrapper() + { + memset(&stream, 0, sizeof(stream)); + + int ret = BZ2_bzDecompressInit(&stream, 0, 0); + + if (ret != BZ_OK) + throw Exception( + ErrorCodes::BZIP2_STREAM_DECODER_FAILED, + "bzip2 stream encoder init failed: error code: {}", + ret); + } + + ~Bzip2StateWrapper() + { + BZ2_bzDecompressEnd(&stream); + } + + void reinitialize() + { + auto avail_out = stream.avail_out; + auto * next_out = stream.next_out; + + int ret = BZ2_bzDecompressEnd(&stream); + + if (ret != BZ_OK) + throw Exception( + ErrorCodes::BZIP2_STREAM_DECODER_FAILED, + "bzip2 stream encoder reinit decompress end failed: error code: {}", + ret); + + memset(&stream, 0, sizeof(bz->stream)); + + ret = BZ2_bzDecompressInit(&stream, 0, 0); + + if (ret != BZ_OK) + throw Exception( + ErrorCodes::BZIP2_STREAM_DECODER_FAILED, + "bzip2 stream encoder reinit failed: error code: {}", + ret); + + stream.avail_out = avail_out; + stream.next_out = next_out; + } + + bz_stream stream; +}; + +Bzip2ReadBuffer::Bzip2ReadBuffer(std::unique_ptr<ReadBuffer> in_, size_t buf_size, char *existing_memory, size_t alignment) + : CompressedReadBufferWrapper(std::move(in_), buf_size, existing_memory, alignment) + , bz(std::make_unique<Bzip2StateWrapper>()) + , eof_flag(false) +{ +} + +Bzip2ReadBuffer::~Bzip2ReadBuffer() = default; + +bool Bzip2ReadBuffer::nextImpl() +{ + if (eof_flag) + return false; + + int ret; + do + { + if (!bz->stream.avail_in) + { + in->nextIfAtEnd(); + bz->stream.avail_in = static_cast<unsigned>(in->buffer().end() - in->position()); + bz->stream.next_in = in->position(); + } + + bz->stream.avail_out = 
static_cast<unsigned>(internal_buffer.size()); + bz->stream.next_out = internal_buffer.begin(); + + ret = BZ2_bzDecompress(&bz->stream); + + in->position() = in->buffer().end() - bz->stream.avail_in; + + if (ret == BZ_STREAM_END && !in->eof()) + { + bz->reinitialize(); + bz->stream.avail_in = static_cast<unsigned>(in->buffer().end() - in->position()); + bz->stream.next_in = in->position(); + + ret = BZ_OK; + } + } + while (bz->stream.avail_out == internal_buffer.size() && ret == BZ_OK && !in->eof()); + + working_buffer.resize(internal_buffer.size() - bz->stream.avail_out); + + if (ret == BZ_STREAM_END && in->eof()) + { + eof_flag = true; + return !working_buffer.empty(); + } + + if (ret != BZ_OK) + throw Exception( + ErrorCodes::BZIP2_STREAM_DECODER_FAILED, + "bzip2 stream decoder failed: error code: {}{}", + ret, + getExceptionEntryWithFileName(*in)); + + if (in->eof()) + { + eof_flag = true; + throw Exception( + ErrorCodes::UNEXPECTED_END_OF_FILE, + "Unexpected end of bzip2 archive{}", + getExceptionEntryWithFileName(*in)); + } + + return true; +} +} + +#endif diff --git a/contrib/clickhouse/src/IO/Bzip2ReadBuffer.h b/contrib/clickhouse/src/IO/Bzip2ReadBuffer.h new file mode 100644 index 0000000000..9131bf780b --- /dev/null +++ b/contrib/clickhouse/src/IO/Bzip2ReadBuffer.h @@ -0,0 +1,31 @@ +#pragma once + +#include <IO/ReadBuffer.h> +#include <IO/CompressedReadBufferWrapper.h> + + +namespace DB +{ + +class Bzip2ReadBuffer : public CompressedReadBufferWrapper +{ +public: + explicit Bzip2ReadBuffer( + std::unique_ptr<ReadBuffer> in_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~Bzip2ReadBuffer() override; + +private: + bool nextImpl() override; + + class Bzip2StateWrapper; + std::unique_ptr<Bzip2StateWrapper> bz; + + bool eof_flag; +}; + +} + diff --git a/contrib/clickhouse/src/IO/Bzip2WriteBuffer.cpp b/contrib/clickhouse/src/IO/Bzip2WriteBuffer.cpp new file mode 100644 index 0000000000..6bcbd872a3 --- /dev/null +++ b/contrib/clickhouse/src/IO/Bzip2WriteBuffer.cpp @@ -0,0 +1,110 @@ +#include "clickhouse_config.h" + +#if USE_BZIP2 +# include <IO/Bzip2WriteBuffer.h> +# error #include <bzlib.h> + +#include <Common/MemoryTracker.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BZIP2_STREAM_ENCODER_FAILED; +} + + +class Bzip2WriteBuffer::Bzip2StateWrapper +{ +public: + explicit Bzip2StateWrapper(int compression_level) + { + memset(&stream, 0, sizeof(stream)); + + int ret = BZ2_bzCompressInit(&stream, compression_level, 0, 0); + + if (ret != BZ_OK) + throw Exception( + ErrorCodes::BZIP2_STREAM_ENCODER_FAILED, + "bzip2 stream encoder init failed: error code: {}", + ret); + } + + ~Bzip2StateWrapper() + { + BZ2_bzCompressEnd(&stream); + } + + bz_stream stream; +}; + +Bzip2WriteBuffer::Bzip2WriteBuffer(std::unique_ptr<WriteBuffer> out_, int compression_level, size_t buf_size, char * existing_memory, size_t alignment) + : WriteBufferWithOwnMemoryDecorator(std::move(out_), buf_size, existing_memory, alignment) + , bz(std::make_unique<Bzip2StateWrapper>(compression_level)) +{ +} + +Bzip2WriteBuffer::~Bzip2WriteBuffer() = default; + +void Bzip2WriteBuffer::nextImpl() +{ + if (!offset()) + { + return; + } + + bz->stream.next_in = working_buffer.begin(); + bz->stream.avail_in = static_cast<unsigned>(offset()); + + try + { + do + { + out->nextIfAtEnd(); + bz->stream.next_out = out->position(); + bz->stream.avail_out = static_cast<unsigned>(out->buffer().end() - out->position()); + + int ret = BZ2_bzCompress(&bz->stream, 
BZ_RUN); + + out->position() = out->buffer().end() - bz->stream.avail_out; + + if (ret != BZ_RUN_OK) + throw Exception( + ErrorCodes::BZIP2_STREAM_ENCODER_FAILED, + "bzip2 stream encoder failed: error code: {}", + ret); + + } + while (bz->stream.avail_in > 0); + } + catch (...) + { + /// Do not try to write next time after exception. + out->position() = out->buffer().begin(); + throw; + } +} + +void Bzip2WriteBuffer::finalizeBefore() +{ + next(); + + out->nextIfAtEnd(); + bz->stream.next_out = out->position(); + bz->stream.avail_out = static_cast<unsigned>(out->buffer().end() - out->position()); + + int ret = BZ2_bzCompress(&bz->stream, BZ_FINISH); + + out->position() = out->buffer().end() - bz->stream.avail_out; + + if (ret != BZ_STREAM_END && ret != BZ_FINISH_OK) + throw Exception( + ErrorCodes::BZIP2_STREAM_ENCODER_FAILED, + "bzip2 stream encoder failed: error code: {}", + ret); +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/Bzip2WriteBuffer.h b/contrib/clickhouse/src/IO/Bzip2WriteBuffer.h new file mode 100644 index 0000000000..d037190348 --- /dev/null +++ b/contrib/clickhouse/src/IO/Bzip2WriteBuffer.h @@ -0,0 +1,31 @@ +#pragma once + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> +#include <IO/WriteBufferDecorator.h> + +namespace DB +{ + +class Bzip2WriteBuffer : public WriteBufferWithOwnMemoryDecorator +{ +public: + Bzip2WriteBuffer( + std::unique_ptr<WriteBuffer> out_, + int compression_level, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~Bzip2WriteBuffer() override; + +private: + void nextImpl() override; + + void finalizeBefore() override; + + class Bzip2StateWrapper; + std::unique_ptr<Bzip2StateWrapper> bz; +}; + +} diff --git a/contrib/clickhouse/src/IO/CascadeWriteBuffer.cpp b/contrib/clickhouse/src/IO/CascadeWriteBuffer.cpp new file mode 100644 index 0000000000..91a42e77fd --- /dev/null +++ b/contrib/clickhouse/src/IO/CascadeWriteBuffer.cpp @@ -0,0 +1,119 @@ +#include <IO/CascadeWriteBuffer.h> +#include <IO/MemoryReadWriteBuffer.h> +#include <Common/Exception.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CURRENT_WRITE_BUFFER_IS_EXHAUSTED; + extern const int CANNOT_WRITE_AFTER_END_OF_BUFFER; + extern const int CANNOT_CREATE_IO_BUFFER; +} + +CascadeWriteBuffer::CascadeWriteBuffer(WriteBufferPtrs && prepared_sources_, WriteBufferConstructors && lazy_sources_) + : WriteBuffer(nullptr, 0), prepared_sources(std::move(prepared_sources_)), lazy_sources(std::move(lazy_sources_)) +{ + first_lazy_source_num = prepared_sources.size(); + num_sources = first_lazy_source_num + lazy_sources.size(); + + /// fill lazy sources by nullptr + prepared_sources.resize(num_sources); + + curr_buffer_num = 0; + curr_buffer = setNextBuffer(); + set(curr_buffer->buffer().begin(), curr_buffer->buffer().size()); +} + + +void CascadeWriteBuffer::nextImpl() +{ + if (!curr_buffer) + return; + try + { + curr_buffer->position() = position(); + curr_buffer->next(); + } + catch (const MemoryWriteBuffer::CurrentBufferExhausted &) + { + if (curr_buffer_num < num_sources) + { + /// TODO: protocol should require set(position(), 0) before Exception + + /// good situation, fetch next WriteBuffer + ++curr_buffer_num; + curr_buffer = setNextBuffer(); + } + else + throw Exception(ErrorCodes::CURRENT_WRITE_BUFFER_IS_EXHAUSTED, "MemoryWriteBuffer limit is exhausted"); + } + + set(curr_buffer->position(), curr_buffer->buffer().end() - curr_buffer->position()); +} + + +void 
CascadeWriteBuffer::getResultBuffers(WriteBufferPtrs & res) +{ + finalize(); + + /// Sync position with underlying buffer before invalidating + curr_buffer->position() = position(); + + res = std::move(prepared_sources); + + curr_buffer = nullptr; + curr_buffer_num = num_sources = 0; + prepared_sources.clear(); + lazy_sources.clear(); +} + +void CascadeWriteBuffer::finalizeImpl() +{ + if (curr_buffer) + curr_buffer->position() = position(); + + for (auto & buf : prepared_sources) + { + if (buf) + { + buf->finalize(); + } + } +} + +WriteBuffer * CascadeWriteBuffer::setNextBuffer() +{ + if (first_lazy_source_num <= curr_buffer_num && curr_buffer_num < num_sources) + { + if (!prepared_sources[curr_buffer_num]) + { + WriteBufferPtr prev_buf = (curr_buffer_num > 0) ? prepared_sources[curr_buffer_num - 1] : nullptr; + prepared_sources[curr_buffer_num] = lazy_sources[curr_buffer_num - first_lazy_source_num](prev_buf); + } + } + else if (curr_buffer_num >= num_sources) + throw Exception(ErrorCodes::CANNOT_WRITE_AFTER_END_OF_BUFFER, "There are no WriteBuffers to write result"); + + WriteBuffer * res = prepared_sources[curr_buffer_num].get(); + if (!res) + throw Exception(ErrorCodes::CANNOT_CREATE_IO_BUFFER, "Required WriteBuffer is not created"); + + /// Check that returned buffer isn't empty + if (!res->hasPendingData()) + res->next(); + + return res; +} + + +CascadeWriteBuffer::~CascadeWriteBuffer() +{ + /// Sync position with underlying buffer before exit + if (curr_buffer) + curr_buffer->position() = position(); +} + + +} diff --git a/contrib/clickhouse/src/IO/CascadeWriteBuffer.h b/contrib/clickhouse/src/IO/CascadeWriteBuffer.h new file mode 100644 index 0000000000..a003d11bd8 --- /dev/null +++ b/contrib/clickhouse/src/IO/CascadeWriteBuffer.h @@ -0,0 +1,63 @@ +#pragma once +#include <functional> +#include <IO/WriteBuffer.h> + + +namespace DB +{ + +namespace ErrorCodes +{ +} + +/* The buffer is similar to ConcatReadBuffer, but writes data + * + * It has WriteBuffers sequence [prepared_sources, lazy_sources] + * (lazy_sources contains not pointers themself, but their delayed constructors) + * + * Firtly, CascadeWriteBuffer redirects data to first buffer of the sequence + * If current WriteBuffer cannot receive data anymore, it throws special exception MemoryWriteBuffer::CurrentBufferExhausted in nextImpl() body, + * CascadeWriteBuffer prepare next buffer and continuously redirects data to it. + * If there are no buffers anymore CascadeWriteBuffer throws an exception. + * + * NOTE: If you use one of underlying WriteBuffers buffers outside, you need sync its position() with CascadeWriteBuffer's position(). + * The sync is performed into nextImpl(), getResultBuffers() and destructor. 
+ */ +class CascadeWriteBuffer : public WriteBuffer +{ +public: + + using WriteBufferPtrs = std::vector<WriteBufferPtr>; + using WriteBufferConstructor = std::function<WriteBufferPtr (const WriteBufferPtr & prev_buf)>; + using WriteBufferConstructors = std::vector<WriteBufferConstructor>; + + explicit CascadeWriteBuffer(WriteBufferPtrs && prepared_sources_, WriteBufferConstructors && lazy_sources_ = {}); + + void nextImpl() override; + + /// Should be called once + void getResultBuffers(WriteBufferPtrs & res); + + const WriteBuffer * getCurrentBuffer() const + { + return curr_buffer; + } + + ~CascadeWriteBuffer() override; + +private: + + void finalizeImpl() override; + + WriteBuffer * setNextBuffer(); + + WriteBufferPtrs prepared_sources; + WriteBufferConstructors lazy_sources; + size_t first_lazy_source_num; + size_t num_sources; + + WriteBuffer * curr_buffer; + size_t curr_buffer_num; +}; + +} diff --git a/contrib/clickhouse/src/IO/CompressedReadBufferWrapper.h b/contrib/clickhouse/src/IO/CompressedReadBufferWrapper.h new file mode 100644 index 0000000000..bb58a7bfeb --- /dev/null +++ b/contrib/clickhouse/src/IO/CompressedReadBufferWrapper.h @@ -0,0 +1,28 @@ +#pragma once +#include <IO/BufferWithOwnMemory.h> +#include <IO/ReadBuffer.h> + +namespace DB +{ + +class CompressedReadBufferWrapper : public BufferWithOwnMemory<ReadBuffer> +{ +public: + CompressedReadBufferWrapper( + std::unique_ptr<ReadBuffer> in_, + size_t buf_size, + char * existing_memory, + size_t alignment) + : BufferWithOwnMemory<ReadBuffer>(buf_size, existing_memory, alignment) + , in(std::move(in_)) {} + + const ReadBuffer & getWrappedReadBuffer() const { return *in; } + ReadBuffer & getWrappedReadBuffer() { return *in; } + + void prefetch(Priority priority) override { in->prefetch(priority); } + +protected: + std::unique_ptr<ReadBuffer> in; +}; + +} diff --git a/contrib/clickhouse/src/IO/CompressionMethod.cpp b/contrib/clickhouse/src/IO/CompressionMethod.cpp new file mode 100644 index 0000000000..e873f5dc8e --- /dev/null +++ b/contrib/clickhouse/src/IO/CompressionMethod.cpp @@ -0,0 +1,205 @@ +#include <IO/CompressionMethod.h> + +#include <IO/BrotliReadBuffer.h> +#include <IO/BrotliWriteBuffer.h> +#include <IO/LZMADeflatingWriteBuffer.h> +#include <IO/LZMAInflatingReadBuffer.h> +#include <IO/ReadBuffer.h> +#include <IO/WriteBuffer.h> +#include <IO/ZlibDeflatingWriteBuffer.h> +#include <IO/ZlibInflatingReadBuffer.h> +#include <IO/ZstdDeflatingWriteBuffer.h> +#include <IO/ZstdInflatingReadBuffer.h> +#include <IO/Lz4DeflatingWriteBuffer.h> +#include <IO/Lz4InflatingReadBuffer.h> +#include <IO/Bzip2ReadBuffer.h> +#include <IO/Bzip2WriteBuffer.h> +#include <IO/HadoopSnappyReadBuffer.h> + +#include "clickhouse_config.h" + +#include <boost/algorithm/string/case_conv.hpp> + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NOT_IMPLEMENTED; +} + + +std::string toContentEncodingName(CompressionMethod method) +{ + switch (method) + { + case CompressionMethod::Gzip: + return "gzip"; + case CompressionMethod::Zlib: + return "deflate"; + case CompressionMethod::Brotli: + return "br"; + case CompressionMethod::Xz: + return "xz"; + case CompressionMethod::Zstd: + return "zstd"; + case CompressionMethod::Lz4: + return "lz4"; + case CompressionMethod::Bzip2: + return "bz2"; + case CompressionMethod::Snappy: + return "snappy"; + case CompressionMethod::None: + return ""; + } + UNREACHABLE(); +} + +CompressionMethod chooseHTTPCompressionMethod(const std::string & list) +{ + /// The compression methods are ordered from most to 
least preferred. + + if (std::string::npos != list.find("zstd")) + return CompressionMethod::Zstd; + else if (std::string::npos != list.find("br")) + return CompressionMethod::Brotli; + else if (std::string::npos != list.find("lz4")) + return CompressionMethod::Lz4; + else if (std::string::npos != list.find("snappy")) + return CompressionMethod::Snappy; + else if (std::string::npos != list.find("gzip")) + return CompressionMethod::Gzip; + else if (std::string::npos != list.find("deflate")) + return CompressionMethod::Zlib; + else if (std::string::npos != list.find("xz")) + return CompressionMethod::Xz; + else if (std::string::npos != list.find("bz2")) + return CompressionMethod::Bzip2; + else + return CompressionMethod::None; +} + +CompressionMethod chooseCompressionMethod(const std::string & path, const std::string & hint) +{ + std::string file_extension; + if (hint.empty() || hint == "auto") + { + auto pos = path.find_last_of('.'); + if (pos != std::string::npos) + file_extension = path.substr(pos + 1, std::string::npos); + } + + std::string method_str; + + if (file_extension.empty()) + method_str = hint; + else + method_str = std::move(file_extension); + + boost::algorithm::to_lower(method_str); + + if (method_str == "gzip" || method_str == "gz") + return CompressionMethod::Gzip; + if (method_str == "deflate") + return CompressionMethod::Zlib; + if (method_str == "brotli" || method_str == "br") + return CompressionMethod::Brotli; + if (method_str == "lzma" || method_str == "xz") + return CompressionMethod::Xz; + if (method_str == "zstd" || method_str == "zst") + return CompressionMethod::Zstd; + if (method_str == "lz4") + return CompressionMethod::Lz4; + if (method_str == "bz2") + return CompressionMethod::Bzip2; + if (method_str == "snappy") + return CompressionMethod::Snappy; + if (hint.empty() || hint == "auto" || hint == "none") + return CompressionMethod::None; + + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unknown compression method '{}'. 
" + "Only 'auto', 'none', 'gzip', 'deflate', 'br', 'xz', 'zstd', 'lz4', 'bz2', 'snappy' are supported as compression methods", hint); +} + +std::pair<uint64_t, uint64_t> getCompressionLevelRange(const CompressionMethod & method) +{ + switch (method) + { + case CompressionMethod::Zstd: + return {1, 22}; + case CompressionMethod::Lz4: + return {1, 12}; + default: + return {1, 9}; + } +} + +static std::unique_ptr<CompressedReadBufferWrapper> createCompressedWrapper( + std::unique_ptr<ReadBuffer> nested, CompressionMethod method, size_t buf_size, char * existing_memory, size_t alignment, int zstd_window_log_max) +{ + if (method == CompressionMethod::Gzip || method == CompressionMethod::Zlib) + return std::make_unique<ZlibInflatingReadBuffer>(std::move(nested), method, buf_size, existing_memory, alignment); +#if USE_BROTLI + if (method == CompressionMethod::Brotli) + return std::make_unique<BrotliReadBuffer>(std::move(nested), buf_size, existing_memory, alignment); +#endif + if (method == CompressionMethod::Xz) + return std::make_unique<LZMAInflatingReadBuffer>(std::move(nested), buf_size, existing_memory, alignment); + if (method == CompressionMethod::Zstd) + return std::make_unique<ZstdInflatingReadBuffer>(std::move(nested), buf_size, existing_memory, alignment, zstd_window_log_max); + if (method == CompressionMethod::Lz4) + return std::make_unique<Lz4InflatingReadBuffer>(std::move(nested), buf_size, existing_memory, alignment); +#if USE_BZIP2 + if (method == CompressionMethod::Bzip2) + return std::make_unique<Bzip2ReadBuffer>(std::move(nested), buf_size, existing_memory, alignment); +#endif +#if USE_SNAPPY + if (method == CompressionMethod::Snappy) + return std::make_unique<HadoopSnappyReadBuffer>(std::move(nested), buf_size, existing_memory, alignment); +#endif + + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported compression method"); +} + +std::unique_ptr<ReadBuffer> wrapReadBufferWithCompressionMethod( + std::unique_ptr<ReadBuffer> nested, CompressionMethod method, int zstd_window_log_max, size_t buf_size, char * existing_memory, size_t alignment) +{ + if (method == CompressionMethod::None) + return nested; + return createCompressedWrapper(std::move(nested), method, buf_size, existing_memory, alignment, zstd_window_log_max); +} + +std::unique_ptr<WriteBuffer> wrapWriteBufferWithCompressionMethod( + std::unique_ptr<WriteBuffer> nested, CompressionMethod method, int level, size_t buf_size, char * existing_memory, size_t alignment) +{ + if (method == DB::CompressionMethod::Gzip || method == CompressionMethod::Zlib) + return std::make_unique<ZlibDeflatingWriteBuffer>(std::move(nested), method, level, buf_size, existing_memory, alignment); + +#if USE_BROTLI + if (method == DB::CompressionMethod::Brotli) + return std::make_unique<BrotliWriteBuffer>(std::move(nested), level, buf_size, existing_memory, alignment); +#endif + if (method == CompressionMethod::Xz) + return std::make_unique<LZMADeflatingWriteBuffer>(std::move(nested), level, buf_size, existing_memory, alignment); + + if (method == CompressionMethod::Zstd) + return std::make_unique<ZstdDeflatingWriteBuffer>(std::move(nested), level, buf_size, existing_memory, alignment); + + if (method == CompressionMethod::Lz4) + return std::make_unique<Lz4DeflatingWriteBuffer>(std::move(nested), level, buf_size, existing_memory, alignment); + +#if USE_BZIP2 + if (method == CompressionMethod::Bzip2) + return std::make_unique<Bzip2WriteBuffer>(std::move(nested), level, buf_size, existing_memory, alignment); +#endif +#if USE_SNAPPY + if (method 
== CompressionMethod::Snappy) + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported compression method"); +#endif + if (method == CompressionMethod::None) + return nested; + + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported compression method"); +} + +} diff --git a/contrib/clickhouse/src/IO/CompressionMethod.h b/contrib/clickhouse/src/IO/CompressionMethod.h new file mode 100644 index 0000000000..c142531cd0 --- /dev/null +++ b/contrib/clickhouse/src/IO/CompressionMethod.h @@ -0,0 +1,73 @@ +#pragma once + +#include <memory> +#include <string> + +#include <Core/Defines.h> + +namespace DB +{ +class ReadBuffer; +class WriteBuffer; + +/** These are "generally recognizable" compression methods for data import/export. + * Do not mess with more efficient compression methods used by ClickHouse internally + * (they use non-standard framing, indexes, checksums...) + */ + +enum class CompressionMethod +{ + None, + /// DEFLATE compression with gzip header and CRC32 checksum. + /// This option corresponds to files produced by gzip(1) or HTTP Content-Encoding: gzip. + Gzip, + /// DEFLATE compression with zlib header and Adler32 checksum. + /// This option corresponds to HTTP Content-Encoding: deflate. + Zlib, + /// LZMA2-based content compression + /// This option corresponds to HTTP Content-Encoding: xz + Xz, + /// Zstd compressor + /// This option corresponds to HTTP Content-Encoding: zstd + Zstd, + Brotli, + Lz4, + Bzip2, + Snappy, +}; + +/// How the compression method is named in HTTP. +std::string toContentEncodingName(CompressionMethod method); + +/** Choose compression method from path and hint. + * if hint is "auto" or empty string, then path is analyzed, + * otherwise path parameter is ignored and hint is used as compression method name. + * path is arbitrary string that will be analyzed for file extension (gz, br...) that determines compression. + */ +CompressionMethod chooseCompressionMethod(const std::string & path, const std::string & hint); + +/** Choose a compression method from HTTP header list of supported compression methods. + */ +CompressionMethod chooseHTTPCompressionMethod(const std::string & list); + +/// Get a range of the valid compression levels for the compression method. 
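The helpers above are meant to be composed: pick a method from the file name, clamp the level to that codec's supported range, and wrap a plain buffer so compression happens transparently. A minimal sketch of that flow (illustrative only, not part of this commit; writeCompressed, the path and the payload are hypothetical):

    #include <IO/CompressionMethod.h>
    #include <IO/WriteBufferFromFile.h>
    #include <algorithm>
    #include <memory>
    #include <string>

    void writeCompressed(const std::string & path, const std::string & payload, int requested_level)
    {
        using namespace DB;

        /// "auto" (or an empty hint) means: infer the codec from the file extension, e.g. ".gz" -> Gzip.
        CompressionMethod method = chooseCompressionMethod(path, "auto");

        /// Clamp the requested level to what this codec supports (zstd: 1..22, lz4: 1..12, others: 1..9).
        auto [min_level, max_level] = getCompressionLevelRange(method);
        int level = std::clamp<int>(requested_level, static_cast<int>(min_level), static_cast<int>(max_level));

        /// Everything written through `out` is compressed on the fly; CompressionMethod::None returns the nested buffer unchanged.
        auto out = wrapWriteBufferWithCompressionMethod(std::make_unique<WriteBufferFromFile>(path), method, level);
        out->write(payload.data(), payload.size());
        out->finalize();
    }
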
+std::pair<uint64_t, uint64_t> getCompressionLevelRange(const CompressionMethod & method); + +std::unique_ptr<ReadBuffer> wrapReadBufferWithCompressionMethod( + std::unique_ptr<ReadBuffer> nested, + CompressionMethod method, + int zstd_window_log_max = 0, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + +std::unique_ptr<WriteBuffer> wrapWriteBufferWithCompressionMethod( + std::unique_ptr<WriteBuffer> nested, + CompressionMethod method, + int level, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + +} diff --git a/contrib/clickhouse/src/IO/ConcatReadBuffer.h b/contrib/clickhouse/src/IO/ConcatReadBuffer.h new file mode 100644 index 0000000000..3f44181a6e --- /dev/null +++ b/contrib/clickhouse/src/IO/ConcatReadBuffer.h @@ -0,0 +1,86 @@ +#pragma once + +#include <vector> + +#include <IO/ReadBuffer.h> + + +namespace DB +{ + +/// Reads from the concatenation of multiple ReadBuffer's +class ConcatReadBuffer : public ReadBuffer +{ +public: + using Buffers = std::vector<std::unique_ptr<ReadBuffer>>; + + ConcatReadBuffer() : ReadBuffer(nullptr, 0), current(buffers.end()) + { + } + + explicit ConcatReadBuffer(Buffers && buffers_) : ReadBuffer(nullptr, 0), buffers(std::move(buffers_)), current(buffers.begin()) + { + assert(!buffers.empty()); + } + + ConcatReadBuffer(std::unique_ptr<ReadBuffer> buf1, std::unique_ptr<ReadBuffer> buf2) : ConcatReadBuffer() + { + appendBuffer(std::move(buf1)); + appendBuffer(std::move(buf2)); + } + + ConcatReadBuffer(ReadBuffer & buf1, ReadBuffer & buf2) : ConcatReadBuffer() + { + appendBuffer(wrapReadBufferReference(buf1)); + appendBuffer(wrapReadBufferReference(buf2)); + } + + void appendBuffer(std::unique_ptr<ReadBuffer> buffer) + { + assert(!count()); + buffers.push_back(std::move(buffer)); + current = buffers.begin(); + } + +protected: + Buffers buffers; + Buffers::iterator current; + + bool nextImpl() override + { + if (buffers.end() == current) + return false; + + /// First reading + if (working_buffer.empty()) + { + if ((*current)->hasPendingData()) + { + working_buffer = Buffer((*current)->position(), (*current)->buffer().end()); + return true; + } + } + else + (*current)->position() = position(); + + if (!(*current)->next()) + { + ++current; + if (buffers.end() == current) + return false; + + /// We skip the filled up buffers; if the buffer is not filled in, but the cursor is at the end, then read the next piece of data. 
+ while ((*current)->eof()) + { + ++current; + if (buffers.end() == current) + return false; + } + } + + working_buffer = Buffer((*current)->position(), (*current)->buffer().end()); + return true; + } +}; + +} diff --git a/contrib/clickhouse/src/IO/ConcatSeekableReadBuffer.cpp b/contrib/clickhouse/src/IO/ConcatSeekableReadBuffer.cpp new file mode 100644 index 0000000000..ec2793898f --- /dev/null +++ b/contrib/clickhouse/src/IO/ConcatSeekableReadBuffer.cpp @@ -0,0 +1,144 @@ +#include <IO/ConcatSeekableReadBuffer.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ARGUMENT_OUT_OF_BOUND; +} + +ConcatSeekableReadBuffer::BufferInfo::BufferInfo(BufferInfo && src) noexcept + : in(std::exchange(src.in, nullptr)), own_in(std::exchange(src.own_in, false)), size(std::exchange(src.size, 0)) +{ +} + +ConcatSeekableReadBuffer::BufferInfo::~BufferInfo() +{ + if (own_in) + delete in; +} + +ConcatSeekableReadBuffer::ConcatSeekableReadBuffer(std::unique_ptr<SeekableReadBuffer> buf1, size_t size1, std::unique_ptr<SeekableReadBuffer> buf2, size_t size2) : ConcatSeekableReadBuffer() +{ + appendBuffer(std::move(buf1), size1); + appendBuffer(std::move(buf2), size2); +} + +ConcatSeekableReadBuffer::ConcatSeekableReadBuffer(SeekableReadBuffer & buf1, size_t size1, SeekableReadBuffer & buf2, size_t size2) : ConcatSeekableReadBuffer() +{ + appendBuffer(buf1, size1); + appendBuffer(buf2, size2); +} + +void ConcatSeekableReadBuffer::appendBuffer(std::unique_ptr<SeekableReadBuffer> buffer, size_t size) +{ + appendBuffer(buffer.release(), true, size); +} + +void ConcatSeekableReadBuffer::appendBuffer(SeekableReadBuffer & buffer, size_t size) +{ + appendBuffer(&buffer, false, size); +} + +void ConcatSeekableReadBuffer::appendBuffer(SeekableReadBuffer * buffer, bool own, size_t size) +{ + BufferInfo info; + info.in = buffer; + info.own_in = own; + info.size = size; + + if (!size) + return; + + buffers.emplace_back(std::move(info)); + total_size += size; + + if (current == buffers.size() - 1) + { + working_buffer = buffers[current].in->buffer(); + pos = buffers[current].in->position(); + } +} + +bool ConcatSeekableReadBuffer::nextImpl() +{ + if (current < buffers.size()) + { + buffers[current].in->position() = pos; + while ((current < buffers.size()) && buffers[current].in->eof()) + { + current_start_pos += buffers[current++].size; + if (current < buffers.size()) + buffers[current].in->seek(0, SEEK_SET); + } + } + + if (current >= buffers.size()) + { + current_start_pos = total_size; + set(nullptr, 0); + return false; + } + + working_buffer = buffers[current].in->buffer(); + pos = buffers[current].in->position(); + return true; +} + +off_t ConcatSeekableReadBuffer::getPosition() +{ + size_t current_pos = current_start_pos; + if (current < buffers.size()) + current_pos += buffers[current].in->getPosition() + offset(); + return current_pos; +} + +off_t ConcatSeekableReadBuffer::seek(off_t off, int whence) +{ + off_t new_position; + off_t current_position = getPosition(); + if (whence == SEEK_SET) + new_position = off; + else if (whence == SEEK_CUR) + new_position = current_position + off; + else + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "ConcatSeekableReadBuffer::seek expects SEEK_SET or SEEK_CUR as whence"); + + if (new_position < 0) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "SEEK_SET underflow: off = {}", off); + if (static_cast<UInt64>(new_position) > total_size) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "SEEK_CUR shift out of bounds"); + + if 
(static_cast<UInt64>(new_position) == total_size) + { + current = buffers.size(); + current_start_pos = total_size; + set(nullptr, 0); + return new_position; + } + + off_t change_position = new_position - current_position; + if ((working_buffer.begin() <= pos + change_position) && (pos + change_position <= working_buffer.end())) + { + /// Position is still inside the same working buffer. + pos += change_position; + assert(pos >= working_buffer.begin()); + assert(pos <= working_buffer.end()); + return new_position; + } + + while (new_position < static_cast<off_t>(current_start_pos)) + current_start_pos -= buffers[--current].size; + + while (new_position >= static_cast<off_t>(current_start_pos + buffers[current].size)) + current_start_pos += buffers[current++].size; + + buffers[current].in->seek(new_position - current_start_pos, SEEK_SET); + working_buffer = buffers[current].in->buffer(); + pos = buffers[current].in->position(); + return new_position; +} + +} diff --git a/contrib/clickhouse/src/IO/ConcatSeekableReadBuffer.h b/contrib/clickhouse/src/IO/ConcatSeekableReadBuffer.h new file mode 100644 index 0000000000..c8c16c5d88 --- /dev/null +++ b/contrib/clickhouse/src/IO/ConcatSeekableReadBuffer.h @@ -0,0 +1,46 @@ +#pragma once + +#include <IO/SeekableReadBuffer.h> +#include <vector> + + +namespace DB +{ + +/// Reads from the concatenation of multiple SeekableReadBuffer's +class ConcatSeekableReadBuffer : public SeekableReadBuffer, public WithFileSize +{ +public: + ConcatSeekableReadBuffer() : SeekableReadBuffer(nullptr, 0) { } + ConcatSeekableReadBuffer(std::unique_ptr<SeekableReadBuffer> buf1, size_t size1, std::unique_ptr<SeekableReadBuffer> buf2, size_t size2); + ConcatSeekableReadBuffer(SeekableReadBuffer & buf1, size_t size1, SeekableReadBuffer & buf2, size_t size2); + + void appendBuffer(std::unique_ptr<SeekableReadBuffer> buffer, size_t size); + void appendBuffer(SeekableReadBuffer & buffer, size_t size); + + off_t seek(off_t off, int whence) override; + off_t getPosition() override; + + size_t getFileSize() override { return total_size; } + +private: + bool nextImpl() override; + void appendBuffer(SeekableReadBuffer * buffer, bool own, size_t size); + + struct BufferInfo + { + BufferInfo() = default; + BufferInfo(BufferInfo && src) noexcept; + ~BufferInfo(); + SeekableReadBuffer * in = nullptr; + bool own_in = false; + size_t size = 0; + }; + + std::vector<BufferInfo> buffers; + size_t total_size = 0; + size_t current = 0; + size_t current_start_pos = 0; /// Position of the current buffer's begin. 
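To make the bookkeeping above concrete, here is a small usage sketch (a hypothetical helper with placeholder file names, not part of this commit): two files are exposed as one seekable stream, and a seek to the concatenation boundary transparently switches to the second underlying buffer.

    #include <IO/ConcatSeekableReadBuffer.h>
    #include <IO/ReadBufferFromFile.h>
    #include <cstdio>
    #include <memory>
    #include <string>

    char readFirstByteOfSecondPart(const std::string & path1, size_t size1, const std::string & path2, size_t size2)
    {
        DB::ConcatSeekableReadBuffer concat(
            std::make_unique<DB::ReadBufferFromFile>(path1), size1,
            std::make_unique<DB::ReadBufferFromFile>(path2), size2);

        concat.seek(size1, SEEK_SET);   /// Absolute offset size1 is the first byte of path2.
        char c = 0;
        concat.read(&c, 1);             /// nextImpl() advances `current` to the second buffer as needed.
        return c;
    }
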
+}; + +} diff --git a/contrib/clickhouse/src/IO/ConnectionTimeouts.cpp b/contrib/clickhouse/src/IO/ConnectionTimeouts.cpp new file mode 100644 index 0000000000..01fbaa4f81 --- /dev/null +++ b/contrib/clickhouse/src/IO/ConnectionTimeouts.cpp @@ -0,0 +1,136 @@ +#include <IO/ConnectionTimeouts.h> +#include <Poco/Util/AbstractConfiguration.h> +#include <Interpreters/Context.h> + +namespace DB +{ + +ConnectionTimeouts::ConnectionTimeouts( + Poco::Timespan connection_timeout_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_) + : connection_timeout(connection_timeout_) + , send_timeout(send_timeout_) + , receive_timeout(receive_timeout_) + , tcp_keep_alive_timeout(0) + , http_keep_alive_timeout(0) + , secure_connection_timeout(connection_timeout) + , hedged_connection_timeout(receive_timeout_) + , receive_data_timeout(receive_timeout_) + , handshake_timeout(receive_timeout_) +{ +} + +ConnectionTimeouts::ConnectionTimeouts( + Poco::Timespan connection_timeout_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_, + Poco::Timespan tcp_keep_alive_timeout_, + Poco::Timespan handshake_timeout_) + : connection_timeout(connection_timeout_) + , send_timeout(send_timeout_) + , receive_timeout(receive_timeout_) + , tcp_keep_alive_timeout(tcp_keep_alive_timeout_) + , http_keep_alive_timeout(0) + , secure_connection_timeout(connection_timeout) + , hedged_connection_timeout(receive_timeout_) + , receive_data_timeout(receive_timeout_) + , handshake_timeout(handshake_timeout_) +{ +} + +ConnectionTimeouts::ConnectionTimeouts( + Poco::Timespan connection_timeout_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_, + Poco::Timespan tcp_keep_alive_timeout_, + Poco::Timespan http_keep_alive_timeout_, + Poco::Timespan handshake_timeout_) + : connection_timeout(connection_timeout_) + , send_timeout(send_timeout_) + , receive_timeout(receive_timeout_) + , tcp_keep_alive_timeout(tcp_keep_alive_timeout_) + , http_keep_alive_timeout(http_keep_alive_timeout_) + , secure_connection_timeout(connection_timeout) + , hedged_connection_timeout(receive_timeout_) + , receive_data_timeout(receive_timeout_) + , handshake_timeout(handshake_timeout_) +{ +} + +ConnectionTimeouts::ConnectionTimeouts( + Poco::Timespan connection_timeout_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_, + Poco::Timespan tcp_keep_alive_timeout_, + Poco::Timespan http_keep_alive_timeout_, + Poco::Timespan secure_connection_timeout_, + Poco::Timespan hedged_connection_timeout_, + Poco::Timespan receive_data_timeout_, + Poco::Timespan handshake_timeout_) + : connection_timeout(connection_timeout_) + , send_timeout(send_timeout_) + , receive_timeout(receive_timeout_) + , tcp_keep_alive_timeout(tcp_keep_alive_timeout_) + , http_keep_alive_timeout(http_keep_alive_timeout_) + , secure_connection_timeout(secure_connection_timeout_) + , hedged_connection_timeout(hedged_connection_timeout_) + , receive_data_timeout(receive_data_timeout_) + , handshake_timeout(handshake_timeout_) +{ +} + +Poco::Timespan ConnectionTimeouts::saturate(Poco::Timespan timespan, Poco::Timespan limit) +{ + if (limit.totalMicroseconds() == 0) + return timespan; + else + return (timespan > limit) ? 
limit : timespan; +} + +ConnectionTimeouts ConnectionTimeouts::getSaturated(Poco::Timespan limit) const +{ + return ConnectionTimeouts(saturate(connection_timeout, limit), + saturate(send_timeout, limit), + saturate(receive_timeout, limit), + saturate(tcp_keep_alive_timeout, limit), + saturate(http_keep_alive_timeout, limit), + saturate(secure_connection_timeout, limit), + saturate(hedged_connection_timeout, limit), + saturate(receive_data_timeout, limit), + saturate(handshake_timeout, limit)); +} + +/// Timeouts for the case when we have just single attempt to connect. +ConnectionTimeouts ConnectionTimeouts::getTCPTimeoutsWithoutFailover(const Settings & settings) +{ + return ConnectionTimeouts(settings.connect_timeout, settings.send_timeout, settings.receive_timeout, settings.tcp_keep_alive_timeout, settings.handshake_timeout_ms); +} + +/// Timeouts for the case when we will try many addresses in a loop. +ConnectionTimeouts ConnectionTimeouts::getTCPTimeoutsWithFailover(const Settings & settings) +{ + return ConnectionTimeouts( + settings.connect_timeout_with_failover_ms, + settings.send_timeout, + settings.receive_timeout, + settings.tcp_keep_alive_timeout, + 0, + settings.connect_timeout_with_failover_secure_ms, + settings.hedged_connection_timeout_ms, + settings.receive_data_timeout_ms, + settings.handshake_timeout_ms); +} + +ConnectionTimeouts ConnectionTimeouts::getHTTPTimeouts(const Settings & settings, Poco::Timespan http_keep_alive_timeout) +{ + return ConnectionTimeouts( + settings.http_connection_timeout, + settings.http_send_timeout, + settings.http_receive_timeout, + settings.tcp_keep_alive_timeout, + http_keep_alive_timeout, + settings.http_receive_timeout); +} + +} diff --git a/contrib/clickhouse/src/IO/ConnectionTimeouts.h b/contrib/clickhouse/src/IO/ConnectionTimeouts.h new file mode 100644 index 0000000000..684af42827 --- /dev/null +++ b/contrib/clickhouse/src/IO/ConnectionTimeouts.h @@ -0,0 +1,72 @@ +#pragma once + +#include <Core/Defines.h> +#include <Interpreters/Context_fwd.h> + +#include <Poco/Timespan.h> + +namespace DB +{ + +struct Settings; + +struct ConnectionTimeouts +{ + Poco::Timespan connection_timeout; + Poco::Timespan send_timeout; + Poco::Timespan receive_timeout; + Poco::Timespan tcp_keep_alive_timeout; + Poco::Timespan http_keep_alive_timeout; + Poco::Timespan secure_connection_timeout; + + /// Timeouts for HedgedConnections + Poco::Timespan hedged_connection_timeout; + Poco::Timespan receive_data_timeout; + + /// Timeout for receiving HELLO packet + Poco::Timespan handshake_timeout; + + /// Timeout for synchronous request-result protocol call (like Ping or TablesStatus) + Poco::Timespan sync_request_timeout = Poco::Timespan(DBMS_DEFAULT_SYNC_REQUEST_TIMEOUT_SEC, 0); + + ConnectionTimeouts() = default; + + ConnectionTimeouts(Poco::Timespan connection_timeout_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_); + + ConnectionTimeouts(Poco::Timespan connection_timeout_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_, + Poco::Timespan tcp_keep_alive_timeout_, + Poco::Timespan handshake_timeout_); + + ConnectionTimeouts(Poco::Timespan connection_timeout_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_, + Poco::Timespan tcp_keep_alive_timeout_, + Poco::Timespan http_keep_alive_timeout_, + Poco::Timespan handshake_timeout_); + + ConnectionTimeouts(Poco::Timespan connection_timeout_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_, + Poco::Timespan tcp_keep_alive_timeout_, + 
Poco::Timespan http_keep_alive_timeout_, + Poco::Timespan secure_connection_timeout_, + Poco::Timespan hedged_connection_timeout_, + Poco::Timespan receive_data_timeout_, + Poco::Timespan handshake_timeout_); + + static Poco::Timespan saturate(Poco::Timespan timespan, Poco::Timespan limit); + ConnectionTimeouts getSaturated(Poco::Timespan limit) const; + + /// Timeouts for the case when we have just single attempt to connect. + static ConnectionTimeouts getTCPTimeoutsWithoutFailover(const Settings & settings); + + /// Timeouts for the case when we will try many addresses in a loop. + static ConnectionTimeouts getTCPTimeoutsWithFailover(const Settings & settings); + static ConnectionTimeouts getHTTPTimeouts(const Settings & settings, Poco::Timespan http_keep_alive_timeout); +}; + +} diff --git a/contrib/clickhouse/src/IO/DoubleConverter.cpp b/contrib/clickhouse/src/IO/DoubleConverter.cpp new file mode 100644 index 0000000000..911da5eabc --- /dev/null +++ b/contrib/clickhouse/src/IO/DoubleConverter.cpp @@ -0,0 +1,16 @@ +#include <IO/DoubleConverter.h> + +namespace DB +{ +template <bool emit_decimal_point> +const double_conversion::DoubleToStringConverter & DoubleConverter<emit_decimal_point>::instance() +{ + static const double_conversion::DoubleToStringConverter instance{ + DoubleToStringConverterFlags<emit_decimal_point>::flags, "inf", "nan", 'e', -6, 21, 6, 1}; + + return instance; +} + +template class DoubleConverter<true>; +template class DoubleConverter<false>; +} diff --git a/contrib/clickhouse/src/IO/DoubleConverter.h b/contrib/clickhouse/src/IO/DoubleConverter.h new file mode 100644 index 0000000000..18cbe4e3a1 --- /dev/null +++ b/contrib/clickhouse/src/IO/DoubleConverter.h @@ -0,0 +1,46 @@ +#pragma once + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdouble-promotion" +#endif + +#include <base/defines.h> +#include <double-conversion/double-conversion.h> +#include <boost/noncopyable.hpp> + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + + +namespace DB +{ + +template <bool emit_decimal_point> struct DoubleToStringConverterFlags +{ + static constexpr auto flags = double_conversion::DoubleToStringConverter::NO_FLAGS; +}; + +template <> struct DoubleToStringConverterFlags<true> +{ + static constexpr auto flags = double_conversion::DoubleToStringConverter::EMIT_TRAILING_DECIMAL_POINT; +}; + +template <bool emit_decimal_point> +class DoubleConverter : private boost::noncopyable +{ + DoubleConverter() = default; + +public: + /// Sign (1 byte) + DigitsBeforePoint + point (1 byte) + DigitsAfterPoint + zero byte. + /// See comment to DoubleToStringConverter::ToFixed method for explanation. + static constexpr auto MAX_REPRESENTATION_LENGTH = + 1 + double_conversion::DoubleToStringConverter::kMaxFixedDigitsBeforePoint + + 1 + double_conversion::DoubleToStringConverter::kMaxFixedDigitsAfterPoint + 1; + using BufferType = char[MAX_REPRESENTATION_LENGTH]; + + static const double_conversion::DoubleToStringConverter & instance(); +}; + +} diff --git a/contrib/clickhouse/src/IO/EmptyReadBuffer.h b/contrib/clickhouse/src/IO/EmptyReadBuffer.h new file mode 100644 index 0000000000..e2189b9943 --- /dev/null +++ b/contrib/clickhouse/src/IO/EmptyReadBuffer.h @@ -0,0 +1,18 @@ +#pragma once + +#include <IO/ReadBuffer.h> + +namespace DB +{ + +/// Just a stub - reads nothing from nowhere. 
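As a usage sketch for the DoubleConverter singleton defined above (illustrative; formatDouble is a hypothetical helper, not part of this commit), a double can be rendered into the fixed-size BufferType via double-conversion's StringBuilder:

    #include <IO/DoubleConverter.h>
    #include <string>

    std::string formatDouble(double x)
    {
        DB::DoubleConverter<false>::BufferType buf;
        double_conversion::StringBuilder builder(buf, static_cast<int>(sizeof(buf)));

        /// ToShortest emits the shortest decimal form that round-trips to the same double.
        DB::DoubleConverter<false>::instance().ToShortest(x, &builder);
        return {buf, static_cast<size_t>(builder.position())};
    }
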
+class EmptyReadBuffer : public ReadBuffer +{ +public: + EmptyReadBuffer() : ReadBuffer(nullptr, 0) {} + +private: + bool nextImpl() override { return false; } +}; + +} diff --git a/contrib/clickhouse/src/IO/FileEncryptionCommon.cpp b/contrib/clickhouse/src/IO/FileEncryptionCommon.cpp new file mode 100644 index 0000000000..c354a1c8df --- /dev/null +++ b/contrib/clickhouse/src/IO/FileEncryptionCommon.cpp @@ -0,0 +1,465 @@ +#include <IO/FileEncryptionCommon.h> + +#if USE_SSL +#include <IO/ReadBuffer.h> +#include <IO/ReadHelpers.h> +#include <IO/WriteBuffer.h> +#include <IO/WriteHelpers.h> +#include <Common/SipHash.h> +#include <Common/safe_cast.h> + +# include <cassert> +# include <boost/algorithm/string/predicate.hpp> + +# include <openssl/err.h> +# include <openssl/rand.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int DATA_ENCRYPTION_ERROR; + extern const int OPENSSL_ERROR; +} + +namespace FileEncryption +{ + +namespace +{ + const EVP_CIPHER * getCipher(Algorithm algorithm) + { + switch (algorithm) + { + case Algorithm::AES_128_CTR: return EVP_aes_128_ctr(); + case Algorithm::AES_192_CTR: return EVP_aes_192_ctr(); + case Algorithm::AES_256_CTR: return EVP_aes_256_ctr(); + case Algorithm::MAX: break; + } + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Encryption algorithm {} is not supported, specify one of the following: aes_128_ctr, aes_192_ctr, aes_256_ctr", + static_cast<int>(algorithm)); + } + + void checkKeySize(const EVP_CIPHER * evp_cipher, size_t key_size) + { + if (!key_size) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Encryption key must not be empty"); + size_t expected_key_size = static_cast<size_t>(EVP_CIPHER_key_length(evp_cipher)); + if (key_size != expected_key_size) + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Got an encryption key with unexpected size {}, the size should be {}", + key_size, expected_key_size); + } + + void checkInitVectorSize(const EVP_CIPHER * evp_cipher) + { + size_t expected_iv_length = static_cast<size_t>(EVP_CIPHER_iv_length(evp_cipher)); + if (InitVector::kSize != expected_iv_length) + throw Exception( + ErrorCodes::DATA_ENCRYPTION_ERROR, + "Got an initialization vector with unexpected size {}, the size should be {}", + InitVector::kSize, + expected_iv_length); + } + + constexpr const size_t kBlockSize = 16; + + size_t blockOffset(size_t pos) { return pos % kBlockSize; } + size_t blocks(size_t pos) { return pos / kBlockSize; } + + size_t partBlockSize(size_t size, size_t off) + { + assert(off < kBlockSize); + /// write the part as usual block + if (off == 0) + return 0; + return off + size <= kBlockSize ? 
size : (kBlockSize - off) % kBlockSize; + } + + size_t encryptBlocks(EVP_CIPHER_CTX * evp_ctx, const char * data, size_t size, WriteBuffer & out) + { + const uint8_t * in = reinterpret_cast<const uint8_t *>(data); + size_t in_size = 0; + size_t out_size = 0; + + while (in_size < size) + { + out.nextIfAtEnd(); + + size_t part_size = std::min(size - in_size, out.available()); + part_size = std::min<size_t>(part_size, INT_MAX); + + uint8_t * ciphertext = reinterpret_cast<uint8_t *>(out.position()); + int ciphertext_size = 0; + if (!EVP_EncryptUpdate(evp_ctx, ciphertext, &ciphertext_size, &in[in_size], static_cast<int>(part_size))) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to encrypt"); + + in_size += part_size; + if (ciphertext_size) + { + out.position() += ciphertext_size; + out_size += ciphertext_size; + } + } + + return out_size; + } + + size_t encryptBlockWithPadding(EVP_CIPHER_CTX * evp_ctx, const char * data, size_t size, size_t pad_left, WriteBuffer & out) + { + assert((size <= kBlockSize) && (size + pad_left <= kBlockSize)); + uint8_t padded_data[kBlockSize] = {}; + memcpy(&padded_data[pad_left], data, size); + size_t padded_data_size = pad_left + size; + + uint8_t ciphertext[kBlockSize]; + int ciphertext_size = 0; + if (!EVP_EncryptUpdate(evp_ctx, ciphertext, &ciphertext_size, padded_data, safe_cast<int>(padded_data_size))) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to encrypt"); + + if (!ciphertext_size) + return 0; + + if (static_cast<size_t>(ciphertext_size) < pad_left) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Unexpected size of encrypted data: {} < {}", ciphertext_size, pad_left); + + uint8_t * ciphertext_begin = &ciphertext[pad_left]; + ciphertext_size -= pad_left; + out.write(reinterpret_cast<const char *>(ciphertext_begin), ciphertext_size); + return ciphertext_size; + } + + size_t encryptFinal(EVP_CIPHER_CTX * evp_ctx, WriteBuffer & out) + { + uint8_t ciphertext[kBlockSize]; + int ciphertext_size = 0; + if (!EVP_EncryptFinal_ex(evp_ctx, + ciphertext, &ciphertext_size)) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to finalize encrypting"); + if (ciphertext_size) + out.write(reinterpret_cast<const char *>(ciphertext), ciphertext_size); + return ciphertext_size; + } + + size_t decryptBlocks(EVP_CIPHER_CTX * evp_ctx, const char * data, size_t size, char * out) + { + const uint8_t * in = reinterpret_cast<const uint8_t *>(data); + uint8_t * plaintext = reinterpret_cast<uint8_t *>(out); + int plaintext_size = 0; + if (!EVP_DecryptUpdate(evp_ctx, plaintext, &plaintext_size, in, safe_cast<int>(size))) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to decrypt"); + return plaintext_size; + } + + size_t decryptBlockWithPadding(EVP_CIPHER_CTX * evp_ctx, const char * data, size_t size, size_t pad_left, char * out) + { + assert((size <= kBlockSize) && (size + pad_left <= kBlockSize)); + uint8_t padded_data[kBlockSize] = {}; + memcpy(&padded_data[pad_left], data, size); + size_t padded_data_size = pad_left + size; + uint8_t plaintext[kBlockSize]; + int plaintext_size = 0; + if (!EVP_DecryptUpdate(evp_ctx, plaintext, &plaintext_size, padded_data, safe_cast<int>(padded_data_size))) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to decrypt"); + + if (!plaintext_size) + return 0; + + if (static_cast<size_t>(plaintext_size) < pad_left) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Unexpected size of decrypted data: {} < {}", plaintext_size, pad_left); + + const uint8_t * plaintext_begin = 
&plaintext[pad_left]; + plaintext_size -= pad_left; + memcpy(out, plaintext_begin, plaintext_size); + return plaintext_size; + } + + size_t decryptFinal(EVP_CIPHER_CTX * evp_ctx, char * out) + { + uint8_t plaintext[kBlockSize]; + int plaintext_size = 0; + if (!EVP_DecryptFinal_ex(evp_ctx, plaintext, &plaintext_size)) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to finalize decrypting"); + if (plaintext_size) + memcpy(out, plaintext, plaintext_size); + return plaintext_size; + } + + constexpr const std::string_view kHeaderSignature = "ENC"; + + UInt128 calculateV1KeyFingerprint(UInt8 small_key_hash, UInt64 key_id) + { + /// In the version 1 we stored {key_id, very_small_hash(key)} instead of a fingerprint. + return static_cast<UInt128>(key_id) | (static_cast<UInt128>(small_key_hash) << 64); + } +} + +String toString(Algorithm algorithm) +{ + switch (algorithm) + { + case Algorithm::AES_128_CTR: return "aes_128_ctr"; + case Algorithm::AES_192_CTR: return "aes_192_ctr"; + case Algorithm::AES_256_CTR: return "aes_256_ctr"; + case Algorithm::MAX: break; + } + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Encryption algorithm {} is not supported, specify one of the following: aes_128_ctr, aes_192_ctr, aes_256_ctr", + static_cast<int>(algorithm)); +} + +Algorithm parseAlgorithmFromString(const String & str) +{ + if (boost::iequals(str, "aes_128_ctr")) + return Algorithm::AES_128_CTR; + else if (boost::iequals(str, "aes_192_ctr")) + return Algorithm::AES_192_CTR; + else if (boost::iequals(str, "aes_256_ctr")) + return Algorithm::AES_256_CTR; + else + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Encryption algorithm '{}' is not supported, specify one of the following: aes_128_ctr, aes_192_ctr, aes_256_ctr", + str); +} + +void checkKeySize(size_t key_size, Algorithm algorithm) { checkKeySize(getCipher(algorithm), key_size); } + + +String InitVector::toString() const +{ + static_assert(sizeof(counter) == InitVector::kSize); + WriteBufferFromOwnString out; + writeBinaryBigEndian(counter, out); + return std::move(out.str()); +} + +InitVector InitVector::fromString(const String & str) +{ + if (str.length() != InitVector::kSize) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expected iv with size {}, got iv with size {}", InitVector::kSize, str.length()); + ReadBufferFromMemory in{str.data(), str.length()}; + UInt128 counter; + readBinaryBigEndian(counter, in); + return InitVector{counter}; +} + +void InitVector::read(ReadBuffer & in) +{ + readBinaryBigEndian(counter, in); +} + +void InitVector::write(WriteBuffer & out) const +{ + writeBinaryBigEndian(counter, out); +} + +InitVector InitVector::random() +{ + UInt128 counter; + auto * buf = reinterpret_cast<unsigned char *>(counter.items); + auto ret = RAND_bytes(buf, sizeof(counter.items)); + if (ret != 1) + throw Exception(DB::ErrorCodes::OPENSSL_ERROR, "OpenSSL error code: {}", ERR_get_error()); + return InitVector{counter}; +} + + +Encryptor::Encryptor(Algorithm algorithm_, const String & key_, const InitVector & iv_) + : key(key_) + , init_vector(iv_) + , evp_cipher(getCipher(algorithm_)) +{ + checkKeySize(evp_cipher, key.size()); + checkInitVectorSize(evp_cipher); +} + +void Encryptor::encrypt(const char * data, size_t size, WriteBuffer & out) +{ + if (!size) + return; + + auto current_iv = (init_vector + blocks(offset)).toString(); + + auto evp_ctx_ptr = std::unique_ptr<EVP_CIPHER_CTX, decltype(&::EVP_CIPHER_CTX_free)>(EVP_CIPHER_CTX_new(), &EVP_CIPHER_CTX_free); + auto * evp_ctx = evp_ctx_ptr.get(); + + if 
(!EVP_EncryptInit_ex(evp_ctx, evp_cipher, nullptr, nullptr, nullptr)) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to initialize encryption context with cipher"); + + if (!EVP_EncryptInit_ex(evp_ctx, nullptr, nullptr, + reinterpret_cast<const uint8_t*>(key.c_str()), reinterpret_cast<const uint8_t*>(current_iv.c_str()))) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to set key and IV for encryption"); + + size_t in_size = 0; + size_t out_size = 0; + + auto off = blockOffset(offset); + if (off) + { + size_t in_part_size = partBlockSize(size, off); + size_t out_part_size = encryptBlockWithPadding(evp_ctx, &data[in_size], in_part_size, off, out); + in_size += in_part_size; + out_size += out_part_size; + } + + if (in_size < size) + { + size_t in_part_size = size - in_size; + size_t out_part_size = encryptBlocks(evp_ctx, &data[in_size], in_part_size, out); + in_size += in_part_size; + out_size += out_part_size; + } + + out_size += encryptFinal(evp_ctx, out); + + if (out_size != in_size) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Only part of the data was encrypted"); + offset += in_size; +} + +void Encryptor::decrypt(const char * data, size_t size, char * out) +{ + if (!size) + return; + + auto current_iv = (init_vector + blocks(offset)).toString(); + + auto evp_ctx_ptr = std::unique_ptr<EVP_CIPHER_CTX, decltype(&::EVP_CIPHER_CTX_free)>(EVP_CIPHER_CTX_new(), &EVP_CIPHER_CTX_free); + auto * evp_ctx = evp_ctx_ptr.get(); + + if (!EVP_DecryptInit_ex(evp_ctx, evp_cipher, nullptr, nullptr, nullptr)) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to initialize decryption context with cipher"); + + if (!EVP_DecryptInit_ex(evp_ctx, nullptr, nullptr, + reinterpret_cast<const uint8_t*>(key.c_str()), reinterpret_cast<const uint8_t*>(current_iv.c_str()))) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Failed to set key and IV for decryption"); + + size_t in_size = 0; + size_t out_size = 0; + + auto off = blockOffset(offset); + if (off) + { + size_t in_part_size = partBlockSize(size, off); + size_t out_part_size = decryptBlockWithPadding(evp_ctx, &data[in_size], in_part_size, off, &out[out_size]); + in_size += in_part_size; + out_size += out_part_size; + } + + if (in_size < size) + { + size_t in_part_size = size - in_size; + size_t out_part_size = decryptBlocks(evp_ctx, &data[in_size], in_part_size, &out[out_size]); + in_size += in_part_size; + out_size += out_part_size; + } + + out_size += decryptFinal(evp_ctx, &out[out_size]); + + if (out_size != in_size) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Only part of the data was decrypted"); + offset += in_size; +} + + +void Header::read(ReadBuffer & in) +{ + char signature[kHeaderSignature.length()]; + in.readStrict(signature, kHeaderSignature.length()); + if (memcmp(signature, kHeaderSignature.data(), kHeaderSignature.length()) != 0) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Wrong signature, this is not an encrypted file"); + + /// The endianness of how the header is written. + /// Starting from version 2 the header is always in little endian. + std::endian endian = std::endian::little; + + readBinaryLittleEndian(version, in); + + if (version == 0x0100ULL) + { + /// Version 1 could write the header of an encrypted file in either little-endian or big-endian. + /// So now if we read the version as little-endian and it's 256 that means two things: the version is actually 1 and the whole header is in big endian. 
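/// (In bytes: version 1 written big-endian is the pair {0x00, 0x01}; read back as a little-endian UInt16 that is 0x0100 == 256, which is exactly the value tested above.)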
+ endian = std::endian::big; + version = 1; + } + + if (version < 1 || version > kCurrentVersion) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Version {} of the header is not supported", version); + + UInt16 algorithm_u16; + readPODBinary(algorithm_u16, in); + if (std::endian::native != endian) + algorithm_u16 = DB::byteswap(algorithm_u16); + if (algorithm_u16 >= static_cast<UInt16>(Algorithm::MAX)) + throw Exception(ErrorCodes::DATA_ENCRYPTION_ERROR, "Algorithm {} is not supported", algorithm_u16); + algorithm = static_cast<Algorithm>(algorithm_u16); + + size_t bytes_to_skip = kSize - kHeaderSignature.length() - sizeof(version) - sizeof(algorithm_u16) - InitVector::kSize; + + if (version < 2) + { + UInt64 key_id; + UInt8 small_key_hash; + readPODBinary(key_id, in); + readPODBinary(small_key_hash, in); + bytes_to_skip -= sizeof(key_id) + sizeof(small_key_hash); + if (std::endian::native != endian) + key_id = DB::byteswap(key_id); + key_fingerprint = calculateV1KeyFingerprint(small_key_hash, key_id); + } + else + { + readBinaryLittleEndian(key_fingerprint, in); + bytes_to_skip -= sizeof(key_fingerprint); + } + + init_vector.read(in); + + chassert(bytes_to_skip < kSize); + in.ignore(bytes_to_skip); +} + +void Header::write(WriteBuffer & out) const +{ + writeString(kHeaderSignature, out); + + writeBinaryLittleEndian(version, out); + + UInt16 algorithm_u16 = static_cast<UInt16>(algorithm); + writeBinaryLittleEndian(algorithm_u16, out); + + writeBinaryLittleEndian(key_fingerprint, out); + + init_vector.write(out); + + constexpr size_t reserved_size = kSize - kHeaderSignature.length() - sizeof(version) - sizeof(algorithm_u16) - sizeof(key_fingerprint) - InitVector::kSize; + static_assert(reserved_size < kSize); + char zero_bytes[reserved_size] = {}; + out.write(zero_bytes, reserved_size); +} + +UInt128 calculateKeyFingerprint(const String & key) +{ + const UInt64 seed0 = 0x4368456E63727970ULL; // ChEncryp + const UInt64 seed1 = 0x7465644469736B46ULL; // tedDiskF + return sipHash128Keyed(seed0, seed1, key.data(), key.size()); +} + +UInt128 calculateV1KeyFingerprint(const String & key, UInt64 key_id) +{ + /// In the version 1 we stored {key_id, very_small_hash(key)} instead of a fingerprint. + UInt8 small_key_hash = sipHash64(key.data(), key.size()) & 0x0F; + return calculateV1KeyFingerprint(small_key_hash, key_id); +} + +} +} + +#endif diff --git a/contrib/clickhouse/src/IO/FileEncryptionCommon.h b/contrib/clickhouse/src/IO/FileEncryptionCommon.h new file mode 100644 index 0000000000..777d171157 --- /dev/null +++ b/contrib/clickhouse/src/IO/FileEncryptionCommon.h @@ -0,0 +1,154 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_SSL +#include <Core/Types.h> +#include <openssl/evp.h> + +namespace DB +{ +class ReadBuffer; +class WriteBuffer; + +namespace FileEncryption +{ + +/// Encryption algorithm. +/// We chose to use CTR cipther algorithms because they have the following features which are important for us: +/// - No right padding, so we can append encrypted files without deciphering; +/// - One byte is always ciphered as one byte, so we get random access to encrypted files easily. +enum class Algorithm +{ + AES_128_CTR, /// Size of key is 16 bytes. + AES_192_CTR, /// Size of key is 24 bytes. + AES_256_CTR, /// Size of key is 32 bytes. + MAX +}; + +String toString(Algorithm algorithm); +Algorithm parseAlgorithmFromString(const String & str); + +/// Throws an exception if a specified key size doesn't correspond a specified encryption algorithm. 
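Putting these primitives together, an encrypt/decrypt round trip might look like the following sketch (the 32-byte key and the payload are placeholders; roundTrip is a hypothetical helper, not part of this commit):

    #include <IO/FileEncryptionCommon.h>
    #include <IO/WriteBufferFromString.h>
    #include <string>

    std::string roundTrip()
    {
        using namespace DB::FileEncryption;

        const std::string key(32, '\x01');                              /// AES_256_CTR expects a 32-byte key.
        Algorithm algorithm = parseAlgorithmFromString("AES_256_CTR");  /// Parsing is case-insensitive.
        checkKeySize(key.size(), algorithm);                            /// Throws BAD_ARGUMENTS on a size mismatch.

        InitVector iv = InitVector::random();
        Encryptor encryptor(algorithm, key, iv);

        const std::string plain = "secret payload";
        DB::WriteBufferFromOwnString encrypted;
        encryptor.encrypt(plain.data(), plain.size(), encrypted);

        /// CTR ciphers one byte as one byte, so decrypting the same range from offset 0 restores the plaintext.
        std::string decrypted(plain.size(), '\0');
        encryptor.setOffset(0);
        encryptor.decrypt(encrypted.str().data(), plain.size(), decrypted.data());
        return decrypted;
    }
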
+void checkKeySize(size_t key_size, Algorithm algorithm); + + +/// Initialization vector. Its size is always 16 bytes. +class InitVector +{ +public: + static constexpr const size_t kSize = 16; + + InitVector() = default; + explicit InitVector(const UInt128 & counter_) { set(counter_); } + + void set(const UInt128 & counter_) { counter = counter_; } + UInt128 get() const { return counter; } + + void read(ReadBuffer & in); + void write(WriteBuffer & out) const; + + /// Write 16 bytes of the counter to a string in big endian order. + /// We need big endian because the used cipher algorithms treat an initialization vector as a counter in big endian. + String toString() const; + + /// Converts a string of 16 bytes length in big endian order to a counter. + static InitVector fromString(const String & str_); + + /// Adds a specified offset to the counter. + InitVector & operator++() { ++counter; return *this; } + InitVector operator++(int) { InitVector res = *this; ++counter; return res; } /// NOLINT + InitVector & operator+=(size_t offset) { counter += offset; return *this; } + InitVector operator+(size_t offset) const { InitVector res = *this; return res += offset; } + + /// Generates a random initialization vector. + static InitVector random(); + +private: + UInt128 counter = 0; +}; + + +/// Encrypts or decrypts data. +class Encryptor +{ +public: + /// The `key` should have size 16 or 24 or 32 bytes depending on which `algorithm` is specified. + Encryptor(Algorithm algorithm_, const String & key_, const InitVector & iv_); + + /// Sets the current position in the data stream from the very beginning of data. + /// It affects how the data will be encrypted or decrypted because + /// the initialization vector is increased by an index of the current block + /// and the index of the current block is calculated from this offset. + void setOffset(size_t offset_) { offset = offset_; } + size_t getOffset() const { return offset; } + + /// Encrypts some data. + /// Also the function moves `offset` by `size` (for successive encryptions). + void encrypt(const char * data, size_t size, WriteBuffer & out); + + /// Decrypts some data. + /// The used cipher algorithms generate the same number of bytes in output as they were in input, + /// so the function always writes `size` bytes of the plaintext to `out`. + /// Also the function moves `offset` by `size` (for successive decryptions). + void decrypt(const char * data, size_t size, char * out); + +private: + const String key; + const InitVector init_vector; + const EVP_CIPHER * const evp_cipher; + + /// The current position in the data stream from the very beginning of data. + size_t offset = 0; +}; + + +/// File header which is stored at the beginning of encrypted files. 
+/// +/// The format of that header is following: +/// +--------+------+--------------------------------------------------------------------------+ +/// | offset | size | description | +/// +--------+------+--------------------------------------------------------------------------+ +/// | 0 | 3 | 'E', 'N', 'C' (file's signature) | +/// | 3 | 2 | version of this header (1..2) | +/// | 5 | 2 | encryption algorithm (0..2, 0=AES_128_CTR, 1=AES_192_CTR, 2=AES_256_CTR) | +/// | 7 | 16 | fingerprint of encryption key (SipHash) | +/// | 23 | 16 | initialization vector (randomly generated) | +/// | 39 | 25 | reserved for future use | +/// +--------+------+--------------------------------------------------------------------------+ +/// +struct Header +{ + /// Versions: + /// 1 - Initial version + /// 2 - The header of an encrypted file contains the fingerprint of a used encryption key instead of a pair {key_id, very_small_hash(key)}. + /// The header is always stored in little endian. + static constexpr const UInt16 kCurrentVersion = 2; + + UInt16 version = kCurrentVersion; + + /// Encryption algorithm. + Algorithm algorithm = Algorithm::AES_128_CTR; + + /// Fingerprint of a key. + UInt128 key_fingerprint = 0; + + InitVector init_vector; + + /// The size of this header in bytes, including reserved bytes. + static constexpr const size_t kSize = 64; + + void read(ReadBuffer & in); + void write(WriteBuffer & out) const; +}; + +/// Calculates the fingerprint of a passed encryption key. +UInt128 calculateKeyFingerprint(const String & key); + +/// Calculates kind of the fingerprint of a passed encryption key & key ID as it was implemented in version 1. +UInt128 calculateV1KeyFingerprint(const String & key, UInt64 key_id); + +} +} + +#endif diff --git a/contrib/clickhouse/src/IO/ForkWriteBuffer.cpp b/contrib/clickhouse/src/IO/ForkWriteBuffer.cpp new file mode 100644 index 0000000000..8e11b9ff59 --- /dev/null +++ b/contrib/clickhouse/src/IO/ForkWriteBuffer.cpp @@ -0,0 +1,60 @@ +#include <IO/ForkWriteBuffer.h> +#include <Common/Exception.h> + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int CANNOT_CREATE_IO_BUFFER; +} + +ForkWriteBuffer::ForkWriteBuffer(WriteBufferPtrs && sources_) + : WriteBuffer(nullptr, 0), sources(std::move(sources_)) +{ + if (sources.empty()) + { + throw Exception(ErrorCodes::CANNOT_CREATE_IO_BUFFER, "Expected non-zero number of buffers for `ForkWriteBuffer`"); + } + set(sources.front()->buffer().begin(), sources.front()->buffer().size()); +} + + +void ForkWriteBuffer::nextImpl() +{ + sources.front()->position() = position(); + + try + { + auto & source_buffer = sources.front(); + for (auto it = sources.begin() + 1; it != sources.end(); ++it) + { + auto & buffer = *it; + buffer->write(source_buffer->buffer().begin(), source_buffer->offset()); + buffer->next(); + } + source_buffer->next(); + } + catch (Exception & exception) + { + exception.addMessage("While writing to ForkWriteBuffer"); + throw; + } + +} + +void ForkWriteBuffer::finalizeImpl() +{ + for (const WriteBufferPtr & buffer : sources) + { + buffer->finalize(); + } +} + +ForkWriteBuffer::~ForkWriteBuffer() +{ + finalize(); +} + + +} diff --git a/contrib/clickhouse/src/IO/ForkWriteBuffer.h b/contrib/clickhouse/src/IO/ForkWriteBuffer.h new file mode 100644 index 0000000000..17fc82028a --- /dev/null +++ b/contrib/clickhouse/src/IO/ForkWriteBuffer.h @@ -0,0 +1,34 @@ +#pragma once +#include <IO/WriteBuffer.h> + + +namespace DB +{ + +namespace ErrorCodes +{ +} + +/** ForkWriteBuffer takes a vector of WriteBuffer and 
writes data to all of them + * If the vector of WriteBufferPts is empty, then it throws an error + * It uses the buffer of the first element as its buffer and copies data from + * first buffer to all the other buffers + **/ +class ForkWriteBuffer : public WriteBuffer +{ +public: + + using WriteBufferPtrs = std::vector<WriteBufferPtr>; + + explicit ForkWriteBuffer(WriteBufferPtrs && sources_); + ~ForkWriteBuffer() override; + +protected: + void nextImpl() override; + void finalizeImpl() override; + +private: + WriteBufferPtrs sources; +}; + +} diff --git a/contrib/clickhouse/src/IO/HTTPChunkedReadBuffer.cpp b/contrib/clickhouse/src/IO/HTTPChunkedReadBuffer.cpp new file mode 100644 index 0000000000..29034b35e1 --- /dev/null +++ b/contrib/clickhouse/src/IO/HTTPChunkedReadBuffer.cpp @@ -0,0 +1,91 @@ +#include <IO/HTTPChunkedReadBuffer.h> + +#include <IO/ReadHelpers.h> +#include <Common/StringUtils/StringUtils.h> +#include <base/hex.h> +#include <base/arithmeticOverflow.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ARGUMENT_OUT_OF_BOUND; + extern const int UNEXPECTED_END_OF_FILE; + extern const int CORRUPTED_DATA; +} + +size_t HTTPChunkedReadBuffer::readChunkHeader() +{ + if (in->eof()) + throw Exception(ErrorCodes::UNEXPECTED_END_OF_FILE, "Unexpected end of file while reading chunk header of HTTP chunked data"); + + if (!isHexDigit(*in->position())) + throw Exception(ErrorCodes::CORRUPTED_DATA, "Unexpected data instead of HTTP chunk header"); + + size_t res = 0; + do + { + if (common::mulOverflow(res, 16ul, res) || common::addOverflow<size_t>(res, unhex(*in->position()), res)) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Chunk size is out of bounds"); + ++in->position(); + } while (!in->eof() && isHexDigit(*in->position())); + + if (res > max_chunk_size) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Chunk size exceeded the limit (max size: {})", max_chunk_size); + + /// NOTE: If we want to read any chunk extensions, it should be done here. + + skipToCarriageReturnOrEOF(*in); + + if (in->eof()) + throw Exception(ErrorCodes::UNEXPECTED_END_OF_FILE, "Unexpected end of file while reading chunk header of HTTP chunked data"); + + assertString("\n", *in); + return res; +} + +void HTTPChunkedReadBuffer::readChunkFooter() +{ + assertString("\r\n", *in); +} + +bool HTTPChunkedReadBuffer::nextImpl() +{ + if (!in) + return false; + + /// The footer of previous chunk. + if (count()) + readChunkFooter(); + + size_t chunk_size = readChunkHeader(); + if (0 == chunk_size) + { + readChunkFooter(); + in.reset(); // prevent double-eof situation. + return false; + } + + if (in->available() >= chunk_size) + { + /// Zero-copy read from input. + working_buffer = Buffer(in->position(), in->position() + chunk_size); + in->position() += chunk_size; + } + else + { + /// Chunk is not completely in buffer, copy it to scratch space. + memory.resize(chunk_size); + in->readStrict(memory.data(), chunk_size); + working_buffer = Buffer(memory.data(), memory.data() + chunk_size); + } + + /// NOTE: We postpone reading the footer to the next iteration, because it may not be completely in buffer, + /// but we need to keep the current data in buffer available. 
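+ /// For illustration, a chunked payload handled by this buffer looks like
+ /// (hex size, CRLF, data, CRLF; a zero-size chunk terminates the stream):
+ ///     5\r\nClick\r\n5\r\nHouse\r\n0\r\n\r\n
+ /// Two 5-byte chunks "Click" and "House" are returned to the caller as "ClickHouse";
+ /// the footer of each chunk is consumed at the beginning of the following call.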
+ + return true; +} + +} diff --git a/contrib/clickhouse/src/IO/HTTPChunkedReadBuffer.h b/contrib/clickhouse/src/IO/HTTPChunkedReadBuffer.h new file mode 100644 index 0000000000..68d90e470f --- /dev/null +++ b/contrib/clickhouse/src/IO/HTTPChunkedReadBuffer.h @@ -0,0 +1,27 @@ +#pragma once + +#include <IO/BufferWithOwnMemory.h> +#include <IO/ReadBuffer.h> + +namespace DB +{ + +/// Reads data with HTTP Chunked Transfer Encoding. +class HTTPChunkedReadBuffer : public BufferWithOwnMemory<ReadBuffer> +{ +public: + explicit HTTPChunkedReadBuffer(std::unique_ptr<ReadBuffer> in_, size_t max_chunk_size_) + : max_chunk_size(max_chunk_size_), in(std::move(in_)) + {} + +private: + const size_t max_chunk_size; + std::unique_ptr<ReadBuffer> in; + + size_t readChunkHeader(); + void readChunkFooter(); + + bool nextImpl() override; +}; + +} diff --git a/contrib/clickhouse/src/IO/HTTPCommon.cpp b/contrib/clickhouse/src/IO/HTTPCommon.cpp new file mode 100644 index 0000000000..077adfd863 --- /dev/null +++ b/contrib/clickhouse/src/IO/HTTPCommon.cpp @@ -0,0 +1,449 @@ +#include <IO/HTTPCommon.h> + +#include <Server/HTTP/HTTPServerResponse.h> +#include <Poco/Any.h> +#include <Common/Concepts.h> +#include <Common/DNSResolver.h> +#include <Common/Exception.h> +#include <Common/MemoryTrackerSwitcher.h> +#include <Common/PoolBase.h> +#include <Common/ProfileEvents.h> +#include <Common/SipHash.h> + +#include "clickhouse_config.h" + +#if USE_SSL +# include <Poco/Net/AcceptCertificateHandler.h> +# include <Poco/Net/Context.h> +# include <Poco/Net/HTTPSClientSession.h> +# include <Poco/Net/InvalidCertificateHandler.h> +# include <Poco/Net/PrivateKeyPassphraseHandler.h> +# include <Poco/Net/RejectCertificateHandler.h> +# include <Poco/Net/SSLManager.h> +# include <Poco/Net/SecureStreamSocket.h> +#endif + +#include <Poco/Util/Application.h> + +#include <sstream> +#include <tuple> +#include <unordered_map> + + +namespace ProfileEvents +{ + extern const Event CreatedHTTPConnections; +} + +namespace DB +{ +namespace ErrorCodes +{ + extern const int RECEIVED_ERROR_FROM_REMOTE_IO_SERVER; + extern const int RECEIVED_ERROR_TOO_MANY_REQUESTS; + extern const int FEATURE_IS_NOT_ENABLED_AT_BUILD_TIME; + extern const int UNSUPPORTED_URI_SCHEME; + extern const int LOGICAL_ERROR; +} + + +namespace +{ + void setTimeouts(Poco::Net::HTTPClientSession & session, const ConnectionTimeouts & timeouts) + { + session.setTimeout(timeouts.connection_timeout, timeouts.send_timeout, timeouts.receive_timeout); + session.setKeepAliveTimeout(timeouts.http_keep_alive_timeout); + } + + template <typename Session> + requires std::derived_from<Session, Poco::Net::HTTPClientSession> + class HTTPSessionAdapter : public Session + { + static_assert(std::has_virtual_destructor_v<Session>, "The base class must have a virtual destructor"); + + public: + HTTPSessionAdapter(const std::string & host, UInt16 port) : Session(host, port), log{&Poco::Logger::get("HTTPSessionAdapter")} { } + ~HTTPSessionAdapter() override = default; + + protected: +#if 0 + void reconnect() override + { + // First of all will try to establish connection with last used addr. + if (!Session::getResolvedHost().empty()) + { + try + { + Session::reconnect(); + return; + } + catch (...) + { + Session::close(); + LOG_TRACE( + log, + "Last ip ({}) is unreachable for {}:{}. 
Will try another resolved address.", + Session::getResolvedHost(), + Session::getHost(), + Session::getPort()); + } + } + + const auto endpoinds = DNSResolver::instance().resolveHostAll(Session::getHost()); + + for (auto it = endpoinds.begin();;) + { + try + { + Session::setResolvedHost(it->toString()); + Session::reconnect(); + + LOG_TRACE( + log, + "Created HTTP(S) session with {}:{} ({}:{})", + Session::getHost(), + Session::getPort(), + it->toString(), + Session::getPort()); + + break; + } + catch (...) + { + Session::close(); + if (++it == endpoinds.end()) + { + Session::setResolvedHost(""); + throw; + } + LOG_TRACE( + log, + "Failed to create connection with {}:{}, Will try another resolved address. {}", + Session::getResolvedHost(), + Session::getPort(), + getCurrentExceptionMessage(false)); + } + } + } +#endif + Poco::Logger * log; + }; + + bool isHTTPS(const Poco::URI & uri) + { + if (uri.getScheme() == "https") + return true; + else if (uri.getScheme() == "http") + return false; + else + throw Exception(ErrorCodes::UNSUPPORTED_URI_SCHEME, "Unsupported scheme in URI '{}'", uri.toString()); + } + + HTTPSessionPtr makeHTTPSessionImpl( + const std::string & host, + UInt16 port, + bool https, + bool keep_alive, + Poco::Net::HTTPClientSession::ProxyConfig proxy_config = {}) + { + HTTPSessionPtr session; + + if (https) + { +#if USE_SSL + session = std::make_shared<HTTPSessionAdapter<Poco::Net::HTTPSClientSession>>(host, port); +#else + throw Exception(ErrorCodes::FEATURE_IS_NOT_ENABLED_AT_BUILD_TIME, "ClickHouse was built without HTTPS support"); +#endif + } + else + { + session = std::make_shared<HTTPSessionAdapter<Poco::Net::HTTPClientSession>>(host, port); + } + + ProfileEvents::increment(ProfileEvents::CreatedHTTPConnections); + + /// doesn't work properly without patch + session->setKeepAlive(keep_alive); + + session->setProxyConfig(proxy_config); + + return session; + } + + class SingleEndpointHTTPSessionPool : public PoolBase<Poco::Net::HTTPClientSession> + { + private: + const std::string host; + const UInt16 port; + const bool https; + const String proxy_host; + const UInt16 proxy_port; + const bool proxy_https; + + using Base = PoolBase<Poco::Net::HTTPClientSession>; + + ObjectPtr allocObject() override + { + /// Pool is global, we shouldn't attribute this memory to query/user. + MemoryTrackerSwitcher switcher{&total_memory_tracker}; + + auto session = makeHTTPSessionImpl(host, port, https, true); + if (!proxy_host.empty()) + { + const String proxy_scheme = proxy_https ? "https" : "http"; + session->setProxyHost(proxy_host); + session->setProxyPort(proxy_port); + +#if 0 + session->setProxyProtocol(proxy_scheme); + + /// Turn on tunnel mode if proxy scheme is HTTP while endpoint scheme is HTTPS. + session->setProxyTunnel(!proxy_https && https); +#endif + } + return session; + } + + public: + SingleEndpointHTTPSessionPool( + const std::string & host_, + UInt16 port_, + bool https_, + const std::string & proxy_host_, + UInt16 proxy_port_, + bool proxy_https_, + size_t max_pool_size_, + bool wait_on_pool_size_limit) + : Base( + static_cast<unsigned>(max_pool_size_), + &Poco::Logger::get("HTTPSessionPool"), + wait_on_pool_size_limit ? 
BehaviourOnLimit::Wait : BehaviourOnLimit::AllocateNewBypassingPool) + , host(host_) + , port(port_) + , https(https_) + , proxy_host(proxy_host_) + , proxy_port(proxy_port_) + , proxy_https(proxy_https_) + { + } + }; + + class HTTPSessionPool : private boost::noncopyable + { + public: + struct Key + { + String target_host; + UInt16 target_port; + bool is_target_https; + String proxy_host; + UInt16 proxy_port; + bool is_proxy_https; + bool wait_on_pool_size_limit; + + bool operator ==(const Key & rhs) const + { + return std::tie(target_host, target_port, is_target_https, proxy_host, proxy_port, is_proxy_https, wait_on_pool_size_limit) + == std::tie(rhs.target_host, rhs.target_port, rhs.is_target_https, rhs.proxy_host, rhs.proxy_port, rhs.is_proxy_https, rhs.wait_on_pool_size_limit); + } + }; + + private: + using PoolPtr = std::shared_ptr<SingleEndpointHTTPSessionPool>; + using Entry = SingleEndpointHTTPSessionPool::Entry; + + struct Hasher + { + size_t operator()(const Key & k) const + { + SipHash s; + s.update(k.target_host); + s.update(k.target_port); + s.update(k.is_target_https); + s.update(k.proxy_host); + s.update(k.proxy_port); + s.update(k.is_proxy_https); + s.update(k.wait_on_pool_size_limit); + return s.get64(); + } + }; + + std::mutex mutex; + std::unordered_map<Key, PoolPtr, Hasher> endpoints_pool; + + protected: + HTTPSessionPool() = default; + + public: + static auto & instance() + { + static HTTPSessionPool instance; + return instance; + } + + Entry getSession( + const Poco::URI & uri, + const Poco::URI & proxy_uri, + const ConnectionTimeouts & timeouts, + size_t max_connections_per_endpoint, + bool wait_on_pool_size_limit) + { + std::unique_lock lock(mutex); + const std::string & host = uri.getHost(); + UInt16 port = uri.getPort(); + bool https = isHTTPS(uri); + + String proxy_host; + UInt16 proxy_port = 0; + bool proxy_https = false; + if (!proxy_uri.empty()) + { + proxy_host = proxy_uri.getHost(); + proxy_port = proxy_uri.getPort(); + proxy_https = isHTTPS(proxy_uri); + } + + HTTPSessionPool::Key key{host, port, https, proxy_host, proxy_port, proxy_https, wait_on_pool_size_limit}; + auto pool_ptr = endpoints_pool.find(key); + if (pool_ptr == endpoints_pool.end()) + std::tie(pool_ptr, std::ignore) = endpoints_pool.emplace( + key, + std::make_shared<SingleEndpointHTTPSessionPool>( + host, + port, + https, + proxy_host, + proxy_port, + proxy_https, + max_connections_per_endpoint, + wait_on_pool_size_limit)); + + /// Some routines held session objects until the end of its lifetime. Also this routines may create another sessions in this time frame. + /// If some other session holds `lock` because it waits on another lock inside `pool_ptr->second->get` it isn't possible to create any + /// new session and thus finish routine, return session to the pool and unlock the thread waiting inside `pool_ptr->second->get`. + /// To avoid such a deadlock we unlock `lock` before entering `pool_ptr->second->get`. 
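+ /// Note that `pool_ptr->second->get()` below may block waiting for a free session
+ /// (when the pool was created with BehaviourOnLimit::Wait), which is another reason
+ /// not to hold the map lock here; the pool object itself is internally synchronized
+ /// (see PoolBase).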
+ lock.unlock(); + + auto retry_timeout = timeouts.connection_timeout.totalMicroseconds(); + auto session = pool_ptr->second->get(retry_timeout); + + setTimeouts(*session, timeouts); + + return session; + } + }; +} + +void setResponseDefaultHeaders(HTTPServerResponse & response, size_t keep_alive_timeout) +{ + if (!response.getKeepAlive()) + return; + + Poco::Timespan timeout(keep_alive_timeout, 0); + if (timeout.totalSeconds()) + response.set("Keep-Alive", "timeout=" + std::to_string(timeout.totalSeconds())); +} + +HTTPSessionPtr makeHTTPSession( + const Poco::URI & uri, + const ConnectionTimeouts & timeouts, + Poco::Net::HTTPClientSession::ProxyConfig proxy_config +) +{ + const std::string & host = uri.getHost(); + UInt16 port = uri.getPort(); + bool https = isHTTPS(uri); + + auto session = makeHTTPSessionImpl(host, port, https, false, proxy_config); + setTimeouts(*session, timeouts); + return session; +} + + +PooledHTTPSessionPtr makePooledHTTPSession( + const Poco::URI & uri, + const ConnectionTimeouts & timeouts, + size_t per_endpoint_pool_size, + bool wait_on_pool_size_limit) +{ + return makePooledHTTPSession(uri, {}, timeouts, per_endpoint_pool_size, wait_on_pool_size_limit); +} + +PooledHTTPSessionPtr makePooledHTTPSession( + const Poco::URI & uri, + const Poco::URI & proxy_uri, + const ConnectionTimeouts & timeouts, + size_t per_endpoint_pool_size, + bool wait_on_pool_size_limit) +{ + return HTTPSessionPool::instance().getSession(uri, proxy_uri, timeouts, per_endpoint_pool_size, wait_on_pool_size_limit); +} + +bool isRedirect(const Poco::Net::HTTPResponse::HTTPStatus status) { return status == Poco::Net::HTTPResponse::HTTP_MOVED_PERMANENTLY || status == Poco::Net::HTTPResponse::HTTP_FOUND || status == Poco::Net::HTTPResponse::HTTP_SEE_OTHER || status == Poco::Net::HTTPResponse::HTTP_TEMPORARY_REDIRECT; } + +std::istream * receiveResponse( + Poco::Net::HTTPClientSession & session, const Poco::Net::HTTPRequest & request, Poco::Net::HTTPResponse & response, const bool allow_redirects) +{ + auto & istr = session.receiveResponse(response); + assertResponseIsOk(request, response, istr, allow_redirects); + return &istr; +} + +void assertResponseIsOk(const Poco::Net::HTTPRequest & request, Poco::Net::HTTPResponse & response, std::istream & istr, const bool allow_redirects) +{ + auto status = response.getStatus(); + + if (!(status == Poco::Net::HTTPResponse::HTTP_OK + || status == Poco::Net::HTTPResponse::HTTP_CREATED + || status == Poco::Net::HTTPResponse::HTTP_ACCEPTED + || status == Poco::Net::HTTPResponse::HTTP_PARTIAL_CONTENT /// Reading with Range header was successful. + || (isRedirect(status) && allow_redirects))) + { + int code = status == Poco::Net::HTTPResponse::HTTP_TOO_MANY_REQUESTS + ? ErrorCodes::RECEIVED_ERROR_TOO_MANY_REQUESTS + : ErrorCodes::RECEIVED_ERROR_FROM_REMOTE_IO_SERVER; + + std::stringstream body; // STYLE_CHECK_ALLOW_STD_STRING_STREAM + body.exceptions(std::ios::failbit); + body << istr.rdbuf(); + + throw HTTPException(code, request.getURI(), status, response.getReason(), body.str()); + } +} + +Exception HTTPException::makeExceptionMessage( + int code, + const std::string & uri, + Poco::Net::HTTPResponse::HTTPStatus http_status, + const std::string & reason, + const std::string & body) +{ + return Exception(code, + "Received error from remote server {}. 
" + "HTTP status code: {} {}, " + "body: {}", + uri, static_cast<int>(http_status), reason, body); +} + +void markSessionForReuse(Poco::Net::HTTPSession & session) +{ + const auto & session_data = session.sessionData(); + if (!session_data.empty() && !Poco::AnyCast<HTTPSessionReuseTag>(&session_data)) + throw Exception( + ErrorCodes::LOGICAL_ERROR, "Data of an unexpected type ({}) is attached to the session", session_data.type().name()); + + session.attachSessionData(HTTPSessionReuseTag{}); +} + +void markSessionForReuse(HTTPSessionPtr session) +{ + markSessionForReuse(*session); +} + +void markSessionForReuse(PooledHTTPSessionPtr session) +{ + markSessionForReuse(static_cast<Poco::Net::HTTPSession &>(*session)); +} + +} diff --git a/contrib/clickhouse/src/IO/HTTPCommon.h b/contrib/clickhouse/src/IO/HTTPCommon.h new file mode 100644 index 0000000000..caf2fa361d --- /dev/null +++ b/contrib/clickhouse/src/IO/HTTPCommon.h @@ -0,0 +1,104 @@ +#pragma once + +#include <memory> +#include <mutex> + +#include <Poco/Net/HTTPClientSession.h> +#include <Poco/Net/HTTPRequest.h> +#include <Poco/Net/HTTPResponse.h> +#include <Poco/URI.h> +#include <Common/PoolBase.h> +#include <Poco/URIStreamFactory.h> + +#include <IO/ConnectionTimeouts.h> + + +namespace DB +{ + +class HTTPServerResponse; + +class HTTPException : public Exception +{ +public: + HTTPException( + int code, + const std::string & uri, + Poco::Net::HTTPResponse::HTTPStatus http_status_, + const std::string & reason, + const std::string & body + ) + : Exception(makeExceptionMessage(code, uri, http_status_, reason, body)) + , http_status(http_status_) + {} + + HTTPException * clone() const override { return new HTTPException(*this); } + void rethrow() const override { throw *this; } + + int getHTTPStatus() const { return http_status; } + +private: + Poco::Net::HTTPResponse::HTTPStatus http_status{}; + + static Exception makeExceptionMessage( + int code, + const std::string & uri, + Poco::Net::HTTPResponse::HTTPStatus http_status, + const std::string & reason, + const std::string & body); + + const char * name() const noexcept override { return "DB::HTTPException"; } + const char * className() const noexcept override { return "DB::HTTPException"; } +}; + +using PooledHTTPSessionPtr = PoolBase<Poco::Net::HTTPClientSession>::Entry; // SingleEndpointHTTPSessionPool::Entry +using HTTPSessionPtr = std::shared_ptr<Poco::Net::HTTPClientSession>; + +/// If a session have this tag attached, it will be reused without calling `reset()` on it. +/// All pooled sessions don't have this tag attached after being taken from a pool. +/// If the request and the response were fully written/read, the client code should add this tag +/// explicitly by calling `markSessionForReuse()`. +struct HTTPSessionReuseTag +{ +}; + +void markSessionForReuse(HTTPSessionPtr session); +void markSessionForReuse(PooledHTTPSessionPtr session); + + +void setResponseDefaultHeaders(HTTPServerResponse & response, size_t keep_alive_timeout); + +/// Create session object to perform requests and set required parameters. +HTTPSessionPtr makeHTTPSession( + const Poco::URI & uri, + const ConnectionTimeouts & timeouts, + Poco::Net::HTTPClientSession::ProxyConfig proxy_config = {} +); + +/// As previous method creates session, but tooks it from pool, without and with proxy uri. 
+PooledHTTPSessionPtr makePooledHTTPSession( + const Poco::URI & uri, + const ConnectionTimeouts & timeouts, + size_t per_endpoint_pool_size, + bool wait_on_pool_size_limit = true); + +PooledHTTPSessionPtr makePooledHTTPSession( + const Poco::URI & uri, + const Poco::URI & proxy_uri, + const ConnectionTimeouts & timeouts, + size_t per_endpoint_pool_size, + bool wait_on_pool_size_limit = true); + +bool isRedirect(Poco::Net::HTTPResponse::HTTPStatus status); + +/** Used to receive response (response headers and possibly body) + * after sending data (request headers and possibly body). + * Throws exception in case of non HTTP_OK (200) response code. + * Returned istream lives in 'session' object. + */ +std::istream * receiveResponse( + Poco::Net::HTTPClientSession & session, const Poco::Net::HTTPRequest & request, Poco::Net::HTTPResponse & response, bool allow_redirects); + +void assertResponseIsOk( + const Poco::Net::HTTPRequest & request, Poco::Net::HTTPResponse & response, std::istream & istr, bool allow_redirects = false); +} diff --git a/contrib/clickhouse/src/IO/HTTPHeaderEntries.h b/contrib/clickhouse/src/IO/HTTPHeaderEntries.h new file mode 100644 index 0000000000..5862f1ead1 --- /dev/null +++ b/contrib/clickhouse/src/IO/HTTPHeaderEntries.h @@ -0,0 +1,18 @@ +#pragma once +#include <string> + +namespace DB +{ + +struct HTTPHeaderEntry +{ + std::string name; + std::string value; + + HTTPHeaderEntry(const std::string & name_, const std::string & value_) : name(name_), value(value_) {} + inline bool operator==(const HTTPHeaderEntry & other) const { return name == other.name && value == other.value; } +}; + +using HTTPHeaderEntries = std::vector<HTTPHeaderEntry>; + +} diff --git a/contrib/clickhouse/src/IO/HadoopSnappyReadBuffer.cpp b/contrib/clickhouse/src/IO/HadoopSnappyReadBuffer.cpp new file mode 100644 index 0000000000..37b709bc89 --- /dev/null +++ b/contrib/clickhouse/src/IO/HadoopSnappyReadBuffer.cpp @@ -0,0 +1,240 @@ +#include "clickhouse_config.h" + +#if USE_SNAPPY +#include <fcntl.h> +#include <sys/types.h> +#include <memory> +#include <string> +#include <cstring> + +#include <snappy-c.h> + +#include "HadoopSnappyReadBuffer.h" + +#include <IO/WithFileName.h> + +namespace DB +{ +namespace ErrorCodes +{ + extern const int SNAPPY_UNCOMPRESS_FAILED; +} + + +inline bool HadoopSnappyDecoder::checkBufferLength(int max) const +{ + return buffer_length >= 0 && buffer_length < max; +} + +inline bool HadoopSnappyDecoder::checkAvailIn(size_t avail_in, int min) +{ + return avail_in >= static_cast<size_t>(min); +} + +inline void HadoopSnappyDecoder::copyToBuffer(size_t * avail_in, const char ** next_in) +{ + assert(*avail_in + buffer_length <= sizeof(buffer)); + + memcpy(buffer + buffer_length, *next_in, *avail_in); + + buffer_length += *avail_in; + *next_in += *avail_in; + *avail_in = 0; +} + + +inline uint32_t HadoopSnappyDecoder::readLength(const char * in) +{ + uint32_t b1 = *(reinterpret_cast<const uint8_t *>(in)); + uint32_t b2 = *(reinterpret_cast<const uint8_t *>(in + 1)); + uint32_t b3 = *(reinterpret_cast<const uint8_t *>(in + 2)); + uint32_t b4 = *(reinterpret_cast<const uint8_t *>(in + 3)); + uint32_t res = ((b1 << 24) + (b2 << 16) + (b3 << 8) + b4); + return res; +} + + +inline HadoopSnappyDecoder::Status HadoopSnappyDecoder::readLength(size_t * avail_in, const char ** next_in, int * length) +{ + char tmp[4] = {0}; + + if (!checkBufferLength(4)) + return Status::INVALID_INPUT; + memcpy(tmp, buffer, buffer_length); + + if (!checkAvailIn(*avail_in, 4 - buffer_length)) + { + 
copyToBuffer(avail_in, next_in); + return Status::NEEDS_MORE_INPUT; + } + memcpy(tmp + buffer_length, *next_in, 4 - buffer_length); + + *avail_in -= 4 - buffer_length; + *next_in += 4 - buffer_length; + buffer_length = 0; + *length = readLength(tmp); + return Status::OK; +} + +inline HadoopSnappyDecoder::Status HadoopSnappyDecoder::readBlockLength(size_t * avail_in, const char ** next_in) +{ + if (block_length < 0) + { + return readLength(avail_in, next_in, &block_length); + } + return Status::OK; +} + +inline HadoopSnappyDecoder::Status HadoopSnappyDecoder::readCompressedLength(size_t * avail_in, const char ** next_in) +{ + if (compressed_length < 0) + { + auto status = readLength(avail_in, next_in, &compressed_length); + if (unlikely(compressed_length > 0 && static_cast<size_t>(compressed_length) > sizeof(buffer))) + return Status::TOO_LARGE_COMPRESSED_BLOCK; + + return status; + } + return Status::OK; +} + +inline HadoopSnappyDecoder::Status +HadoopSnappyDecoder::readCompressedData(size_t * avail_in, const char ** next_in, size_t * avail_out, char ** next_out) +{ + if (!checkBufferLength(compressed_length)) + return Status::INVALID_INPUT; + + if (!checkAvailIn(*avail_in, compressed_length - buffer_length)) + { + copyToBuffer(avail_in, next_in); + return Status::NEEDS_MORE_INPUT; + } + + const char * compressed = nullptr; + if (buffer_length > 0) + { + compressed = buffer; + memcpy(buffer + buffer_length, *next_in, compressed_length - buffer_length); + } + else + { + compressed = const_cast<char *>(*next_in); + } + size_t uncompressed_length = *avail_out; + auto status = snappy_uncompress(compressed, compressed_length, *next_out, &uncompressed_length); + if (status != SNAPPY_OK) + { + return Status(status); + } + + *avail_in -= compressed_length - buffer_length; + *next_in += compressed_length - buffer_length; + *avail_out -= uncompressed_length; + *next_out += uncompressed_length; + + total_uncompressed_length += uncompressed_length; + compressed_length = -1; + buffer_length = 0; + return Status::OK; +} + +HadoopSnappyDecoder::Status HadoopSnappyDecoder::readBlock(size_t * avail_in, const char ** next_in, size_t * avail_out, char ** next_out) +{ + if (*avail_in == 0) + { + if (buffer_length == 0 && block_length < 0 && compressed_length < 0) + return Status::OK; + return Status::NEEDS_MORE_INPUT; + } + + HadoopSnappyDecoder::Status status = readBlockLength(avail_in, next_in); + if (status != Status::OK) + return status; + + while (total_uncompressed_length < block_length) + { + status = readCompressedLength(avail_in, next_in); + if (status != Status::OK) + return status; + + status = readCompressedData(avail_in, next_in, avail_out, next_out); + if (status != Status::OK) + return status; + } + if (total_uncompressed_length != block_length) + { + return Status::INVALID_INPUT; + } + return Status::OK; +} + +HadoopSnappyReadBuffer::HadoopSnappyReadBuffer(std::unique_ptr<ReadBuffer> in_, size_t buf_size, char * existing_memory, size_t alignment) + : CompressedReadBufferWrapper(std::move(in_), buf_size, existing_memory, alignment) + , decoder(std::make_unique<HadoopSnappyDecoder>()) + , in_available(0) + , in_data(nullptr) + , out_capacity(0) + , out_data(nullptr) + , eof(false) +{ +} + +HadoopSnappyReadBuffer::~HadoopSnappyReadBuffer() = default; + +bool HadoopSnappyReadBuffer::nextImpl() +{ + if (eof) + return false; + + do + { + if (!in_available) + { + in->nextIfAtEnd(); + in_available = in->buffer().end() - in->position(); + in_data = in->position(); + } + + if (decoder->result == 
Status::NEEDS_MORE_INPUT && (!in_available || in->eof())) + { + throw Exception( + ErrorCodes::SNAPPY_UNCOMPRESS_FAILED, + "hadoop snappy decode error: {}{}", + statusToString(decoder->result), + getExceptionEntryWithFileName(*in)); + } + + out_capacity = internal_buffer.size(); + out_data = internal_buffer.begin(); + decoder->result = decoder->readBlock(&in_available, &in_data, &out_capacity, &out_data); + + in->position() = in->buffer().end() - in_available; + } + while (decoder->result == Status::NEEDS_MORE_INPUT); + + working_buffer.resize(internal_buffer.size() - out_capacity); + + if (decoder->result == Status::OK) + { + decoder->reset(); + if (in->eof()) + { + eof = true; + return !working_buffer.empty(); + } + return true; + } + else if (decoder->result != Status::NEEDS_MORE_INPUT) + { + throw Exception( + ErrorCodes::SNAPPY_UNCOMPRESS_FAILED, + "hadoop snappy decode error: {}{}", + statusToString(decoder->result), + getExceptionEntryWithFileName(*in)); + } + return true; +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/HadoopSnappyReadBuffer.h b/contrib/clickhouse/src/IO/HadoopSnappyReadBuffer.h new file mode 100644 index 0000000000..bcc438489d --- /dev/null +++ b/contrib/clickhouse/src/IO/HadoopSnappyReadBuffer.h @@ -0,0 +1,117 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_SNAPPY + +#include <memory> +#include <IO/ReadBuffer.h> +#include <IO/CompressedReadBufferWrapper.h> + +namespace DB +{ + + +/* + * Hadoop-snappy format is one of the compression formats base on Snappy used in Hadoop. It uses its own framing format as follows: + * 1. A compressed file consists of one or more blocks. + * 2. A block consists of uncompressed length (big endian 4 byte integer) and one or more subblocks. + * 3. A subblock consists of compressed length (big endian 4 byte integer) and raw compressed data. + * + * HadoopSnappyDecoder implements the decompression of data compressed with hadoop-snappy format. + */ +class HadoopSnappyDecoder +{ +public: + enum class Status : int + { + OK = 0, + INVALID_INPUT = 1, + BUFFER_TOO_SMALL = 2, + NEEDS_MORE_INPUT = 3, + TOO_LARGE_COMPRESSED_BLOCK = 4, + }; + + HadoopSnappyDecoder() = default; + ~HadoopSnappyDecoder() = default; + + Status readBlock(size_t * avail_in, const char ** next_in, size_t * avail_out, char ** next_out); + + inline void reset() + { + buffer_length = 0; + block_length = -1; + compressed_length = -1; + total_uncompressed_length = 0; + } + + Status result = Status::OK; + +private: + inline bool checkBufferLength(int max) const; + inline static bool checkAvailIn(size_t avail_in, int min); + + inline void copyToBuffer(size_t * avail_in, const char ** next_in); + + inline static uint32_t readLength(const char * in); + inline Status readLength(size_t * avail_in, const char ** next_in, int * length); + inline Status readBlockLength(size_t * avail_in, const char ** next_in); + inline Status readCompressedLength(size_t * avail_in, const char ** next_in); + inline Status readCompressedData(size_t * avail_in, const char ** next_in, size_t * avail_out, char ** next_out); + + char buffer[DBMS_DEFAULT_BUFFER_SIZE] = {0}; + int buffer_length = 0; + + int block_length = -1; + int compressed_length = -1; + int total_uncompressed_length = 0; +}; + +/// HadoopSnappyReadBuffer implements read buffer for data compressed with hadoop-snappy format. 
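+///
+/// A minimal usage sketch (the file name is illustrative; ReadBufferFromFile comes from <IO/ReadBufferFromFile.h>,
+/// readStringUntilEOF from <IO/ReadHelpers.h>):
+///
+///     auto compressed = std::make_unique<ReadBufferFromFile>("part-00000.snappy");
+///     HadoopSnappyReadBuffer decompressed(std::move(compressed));
+///     String all_data;
+///     readStringUntilEOF(all_data, decompressed);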
+class HadoopSnappyReadBuffer : public CompressedReadBufferWrapper +{ +public: + using Status = HadoopSnappyDecoder::Status; + + inline static String statusToString(Status status) + { + switch (status) + { + case Status::OK: + return "OK"; + case Status::INVALID_INPUT: + return "INVALID_INPUT"; + case Status::BUFFER_TOO_SMALL: + return "BUFFER_TOO_SMALL"; + case Status::NEEDS_MORE_INPUT: + return "NEEDS_MORE_INPUT"; + case Status::TOO_LARGE_COMPRESSED_BLOCK: + return "TOO_LARGE_COMPRESSED_BLOCK"; + } + UNREACHABLE(); + } + + explicit HadoopSnappyReadBuffer( + std::unique_ptr<ReadBuffer> in_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~HadoopSnappyReadBuffer() override; + +private: + bool nextImpl() override; + + std::unique_ptr<HadoopSnappyDecoder> decoder; + + size_t in_available; + const char * in_data; + + size_t out_capacity; + char * out_data; + + bool eof; +}; + +} +#endif diff --git a/contrib/clickhouse/src/IO/HashingReadBuffer.h b/contrib/clickhouse/src/IO/HashingReadBuffer.h new file mode 100644 index 0000000000..a0a029e6f8 --- /dev/null +++ b/contrib/clickhouse/src/IO/HashingReadBuffer.h @@ -0,0 +1,55 @@ +#pragma once + +#include <IO/HashingWriteBuffer.h> +#include <IO/ReadBuffer.h> + +namespace DB +{ + +/* + * Calculates the hash from the read data. When reading, the data is read from the nested ReadBuffer. + * Small pieces are copied into its own memory. + */ +class HashingReadBuffer : public IHashingBuffer<ReadBuffer> +{ +public: + explicit HashingReadBuffer(ReadBuffer & in_, size_t block_size_ = DBMS_DEFAULT_HASHING_BLOCK_SIZE) + : IHashingBuffer<ReadBuffer>(block_size_), in(in_) + { + working_buffer = in.buffer(); + pos = in.position(); + hashing_begin = pos; + } + + uint128 getHash() + { + if (pos > hashing_begin) + { + calculateHash(hashing_begin, pos - hashing_begin); + hashing_begin = pos; + } + return IHashingBuffer<ReadBuffer>::getHash(); + } + +private: + bool nextImpl() override + { + if (pos > hashing_begin) + calculateHash(hashing_begin, pos - hashing_begin); + + in.position() = pos; + bool res = in.next(); + working_buffer = in.buffer(); + + // `pos` may be different from working_buffer.begin() when using sophisticated ReadBuffers. 
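+ // At this point everything in [hashing_begin, pos) has already been folded into the state,
+ // so the hash covers exactly the bytes the caller has consumed; hashing resumes at the new position.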
+ pos = in.position(); + hashing_begin = pos; + + return res; + } + + ReadBuffer & in; + BufferBase::Position hashing_begin; +}; + +} diff --git a/contrib/clickhouse/src/IO/HashingWriteBuffer.cpp b/contrib/clickhouse/src/IO/HashingWriteBuffer.cpp new file mode 100644 index 0000000000..d2461d4f52 --- /dev/null +++ b/contrib/clickhouse/src/IO/HashingWriteBuffer.cpp @@ -0,0 +1,54 @@ +#include <IO/HashingWriteBuffer.h> +#include <iomanip> + + +namespace DB +{ + +/// computation of the hash depends on the partitioning of blocks +/// so you need to compute a hash of n complete pieces and one incomplete +template <typename Buffer> +void IHashingBuffer<Buffer>::calculateHash(DB::BufferBase::Position data, size_t len) +{ + if (len) + { + /// if the data is less than `block_size`, then put them into buffer and calculate hash later + if (block_pos + len < block_size) + { + memcpy(&BufferWithOwnMemory<Buffer>::memory[block_pos], data, len); + block_pos += len; + } + else + { + /// if something is already written to the buffer, then we'll add it + if (block_pos) + { + size_t n = block_size - block_pos; + memcpy(&BufferWithOwnMemory<Buffer>::memory[block_pos], data, n); + append(&BufferWithOwnMemory<Buffer>::memory[0]); + len -= n; + data += n; + block_pos = 0; + } + + while (len >= block_size) + { + append(data); + len -= block_size; + data += block_size; + } + + /// write the remainder to its buffer + if (len) + { + memcpy(&BufferWithOwnMemory<Buffer>::memory[0], data, len); + block_pos = len; + } + } + } +} + +template class IHashingBuffer<DB::ReadBuffer>; +template class IHashingBuffer<DB::WriteBuffer>; + +} diff --git a/contrib/clickhouse/src/IO/HashingWriteBuffer.h b/contrib/clickhouse/src/IO/HashingWriteBuffer.h new file mode 100644 index 0000000000..8edfa45a6b --- /dev/null +++ b/contrib/clickhouse/src/IO/HashingWriteBuffer.h @@ -0,0 +1,92 @@ +#pragma once + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> +#include <IO/ReadHelpers.h> +#include <city.h> + +#define DBMS_DEFAULT_HASHING_BLOCK_SIZE 2048ULL + + +namespace DB +{ + +template <typename Buffer> +class IHashingBuffer : public BufferWithOwnMemory<Buffer> +{ +public: + using uint128 = CityHash_v1_0_2::uint128; + + explicit IHashingBuffer(size_t block_size_ = DBMS_DEFAULT_HASHING_BLOCK_SIZE) + : BufferWithOwnMemory<Buffer>(block_size_), block_pos(0), block_size(block_size_), state(0, 0) + { + } + + uint128 getHash() + { + if (block_pos) + return CityHash_v1_0_2::CityHash128WithSeed(BufferWithOwnMemory<Buffer>::memory.data(), block_pos, state); + else + return state; + } + + void append(DB::BufferBase::Position data) + { + state = CityHash_v1_0_2::CityHash128WithSeed(data, block_size, state); + } + + /// computation of the hash depends on the partitioning of blocks + /// so you need to compute a hash of n complete pieces and one incomplete + void calculateHash(DB::BufferBase::Position data, size_t len); + +protected: + size_t block_pos; + size_t block_size; + uint128 state; +}; + +/** Computes the hash from the data to write and passes it to the specified WriteBuffer. + * The buffer of the nested WriteBuffer is used as the main buffer. 
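+ *
+ * A minimal usage sketch (WriteBufferFromFile is illustrative, from <IO/WriteBufferFromFile.h>):
+ *
+ *     WriteBufferFromFile out("data.bin");
+ *     HashingWriteBuffer hashing_out(out);
+ *     hashing_out.write("payload", 7);
+ *     auto hash = hashing_out.getHash(); /// flushes to `out`, returns CityHash128 of the written bytes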
+ */ +class HashingWriteBuffer : public IHashingBuffer<WriteBuffer> +{ +private: + WriteBuffer & out; + + void nextImpl() override + { + size_t len = offset(); + + Position data = working_buffer.begin(); + calculateHash(data, len); + + out.position() = pos; + out.next(); + working_buffer = out.buffer(); + } + +public: + explicit HashingWriteBuffer( + WriteBuffer & out_, + size_t block_size_ = DBMS_DEFAULT_HASHING_BLOCK_SIZE) + : IHashingBuffer<DB::WriteBuffer>(block_size_), out(out_) + { + out.next(); /// If something has already been written to `out` before us, we will not let the remains of this data affect the hash. + working_buffer = out.buffer(); + pos = working_buffer.begin(); + state = uint128(0, 0); + } + + void sync() override + { + out.sync(); + } + + uint128 getHash() + { + next(); + return IHashingBuffer<WriteBuffer>::getHash(); + } +}; + +} diff --git a/contrib/clickhouse/src/IO/IReadableWriteBuffer.h b/contrib/clickhouse/src/IO/IReadableWriteBuffer.h new file mode 100644 index 0000000000..539825e3a8 --- /dev/null +++ b/contrib/clickhouse/src/IO/IReadableWriteBuffer.h @@ -0,0 +1,32 @@ +#pragma once +#include <memory> +#include <IO/ReadBuffer.h> + +namespace DB +{ + +struct IReadableWriteBuffer +{ + /// At the first time returns getReadBufferImpl(). Next calls return nullptr. + inline std::shared_ptr<ReadBuffer> tryGetReadBuffer() + { + if (!can_reread) + return nullptr; + + can_reread = false; + return getReadBufferImpl(); + } + + virtual ~IReadableWriteBuffer() = default; + +protected: + + /// Creates read buffer from current write buffer. + /// Returned buffer points to the first byte of original buffer. + /// Original stream becomes invalid. + virtual std::shared_ptr<ReadBuffer> getReadBufferImpl() = 0; + + bool can_reread = true; +}; + +} diff --git a/contrib/clickhouse/src/IO/IResourceManager.h b/contrib/clickhouse/src/IO/IResourceManager.h new file mode 100644 index 0000000000..f084a903cb --- /dev/null +++ b/contrib/clickhouse/src/IO/IResourceManager.h @@ -0,0 +1,53 @@ +#pragma once + +#include <IO/ResourceLink.h> + +#include <Poco/Util/AbstractConfiguration.h> + +#include <boost/noncopyable.hpp> + +#include <memory> +#include <unordered_map> + +namespace DB +{ + +/* + * Instance of derived class holds everything required for resource consumption, + * including resources currently registered at `SchedulerRoot`. This is required to avoid + * problems during configuration update. Do not hold instances longer than required. + * Should be created on query start and destructed when query is done. + */ +class IClassifier : private boost::noncopyable +{ +public: + virtual ~IClassifier() {} + + /// Returns ResouceLink that should be used to access resource. + /// Returned link is valid until classifier destruction. + virtual ResourceLink get(const String & resource_name) = 0; +}; + +using ClassifierPtr = std::shared_ptr<IClassifier>; + +/* + * Represents control plane of resource scheduling. Derived class is responsible for reading + * configuration, creating all required `ISchedulerNode` objects and + * managing their lifespan. + */ +class IResourceManager : private boost::noncopyable +{ +public: + virtual ~IResourceManager() {} + + /// Initialize or reconfigure manager. + virtual void updateConfiguration(const Poco::Util::AbstractConfiguration & config) = 0; + + /// Obtain a classifier instance required to get access to resources. + /// Note that it holds resource configuration, so should be destructed when query is done. 
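+ /// For illustration (the classifier and resource names are hypothetical and come from configuration):
+ ///
+ ///     ClassifierPtr classifier = resource_manager->acquire("admin");
+ ///     ResourceLink link = classifier->get("network_read");
+ ///     /// `link` is then attached to resource requests issued on behalf of the query.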
+ virtual ClassifierPtr acquire(const String & classifier_name) = 0; +}; + +using ResourceManagerPtr = std::shared_ptr<IResourceManager>; + +} diff --git a/contrib/clickhouse/src/IO/ISchedulerConstraint.h b/contrib/clickhouse/src/IO/ISchedulerConstraint.h new file mode 100644 index 0000000000..47f6905e26 --- /dev/null +++ b/contrib/clickhouse/src/IO/ISchedulerConstraint.h @@ -0,0 +1,55 @@ +#pragma once + +#include <IO/ISchedulerNode.h> + +namespace DB +{ + +/* + * Constraint defined on the set of requests in consumption state. + * It allows to track two events: + * - dequeueRequest(): resource consumption begins + * - finishRequest(): resource consumption finishes + * This allows to keep track of in-flight requests and implement different constraints (e.g. in-flight limit). + * When constraint is violated, node must be deactivated by dequeueRequest() returning `false`. + * When constraint is again satisfied, scheduleActivation() is called from finishRequest(). + * + * Derived class behaviour requirements: + * - dequeueRequest() must fill `request->constraint` iff it is nullptr; + * - finishRequest() must be recursive: call to `parent_constraint->finishRequest()`. + */ +class ISchedulerConstraint : public ISchedulerNode +{ +public: + ISchedulerConstraint(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {}) + : ISchedulerNode(event_queue_, config, config_prefix) + {} + + /// Resource consumption by `request` is finished. + /// Should be called outside of scheduling subsystem, implementation must be thread-safe. + virtual void finishRequest(ResourceRequest * request) = 0; + + void setParent(ISchedulerNode * parent_) override + { + ISchedulerNode::setParent(parent_); + + // Assign `parent_constraint` to the nearest parent derived from ISchedulerConstraint + for (ISchedulerNode * node = parent_; node != nullptr; node = node->parent) + { + if (auto * constraint = dynamic_cast<ISchedulerConstraint *>(node)) + { + parent_constraint = constraint; + break; + } + } + } + +protected: + // Reference to nearest parent that is also derived from ISchedulerConstraint. + // Request can traverse through multiple constraints while being dequeue from hierarchy, + // while finishing request should traverse the same chain in reverse order. 
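+ // For example, with a chain `outer_constraint -> inner_constraint -> queue`, a dequeued request
+ // ends up with `request->constraint` pointing at the innermost constraint, and finishRequest()
+ // then walks back to the root through `parent_constraint`.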
+ // NOTE: it must be immutable after initialization, because it is accessed in not thread-safe way from finishRequest() + ISchedulerConstraint * parent_constraint = nullptr; +}; + +} diff --git a/contrib/clickhouse/src/IO/ISchedulerNode.h b/contrib/clickhouse/src/IO/ISchedulerNode.h new file mode 100644 index 0000000000..1c33c03374 --- /dev/null +++ b/contrib/clickhouse/src/IO/ISchedulerNode.h @@ -0,0 +1,222 @@ +#pragma once + +#include <Common/ErrorCodes.h> +#include <Common/Exception.h> +#include <Common/Priority.h> + +#include <IO/ResourceRequest.h> +#include <Poco/Util/AbstractConfiguration.h> +#include <Poco/Util/XMLConfiguration.h> + +#include <boost/noncopyable.hpp> + +#include <deque> +#include <functional> +#include <memory> +#include <mutex> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_SCHEDULER_NODE; +} + +class ISchedulerNode; + +inline const Poco::Util::AbstractConfiguration & emptyConfig() +{ + static Poco::AutoPtr<Poco::Util::XMLConfiguration> config = new Poco::Util::XMLConfiguration(); + return *config; +} + +/* + * Info read and write for scheduling purposes by parent + */ +struct SchedulerNodeInfo +{ + double weight = 1.0; /// Weight of this node among it's siblings + Priority priority; /// Priority of this node among it's siblings (lower value means higher priority) + + /// Arbitrary data accessed/stored by parent + union { + size_t idx; + void * ptr; + } parent; + + SchedulerNodeInfo() = default; + + explicit SchedulerNodeInfo(const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {}) + { + setWeight(config.getDouble(config_prefix + ".weight", weight)); + setPriority(config.getInt64(config_prefix + ".priority", priority)); + } + + void setWeight(double value) + { + if (value <= 0 || !isfinite(value)) + throw Exception( + ErrorCodes::INVALID_SCHEDULER_NODE, + "Negative and non-finite node weights are not allowed: {}", + value); + weight = value; + } + + void setPriority(Int64 value) + { + priority.value = value; + } +}; + +/* + * Simple waitable thread-safe FIFO task queue. + * Intended to hold postponed events for later handling (usually by scheduler thread). + */ +class EventQueue +{ +public: + using Event = std::function<void()>; + + void enqueue(Event&& event) + { + std::unique_lock lock{mutex}; + bool was_empty = queue.empty(); + queue.emplace_back(event); + if (was_empty) + pending.notify_one(); + } + + /// Process single event if it exists + /// Returns `true` iff event has been processed + bool tryProcess() + { + std::unique_lock lock{mutex}; + if (queue.empty()) + return false; + Event event = std::move(queue.front()); + queue.pop_front(); + lock.unlock(); // do not hold queue mutext while processing events + event(); + return true; + } + + /// Wait for single event (if not available) and process it + void process() + { + std::unique_lock lock{mutex}; + pending.wait(lock, [&] { return !queue.empty(); }); + Event event = std::move(queue.front()); + queue.pop_front(); + lock.unlock(); // do not hold queue mutext while processing events + event(); + } + +private: + std::mutex mutex; + std::condition_variable pending; + std::deque<Event> queue; +}; + +/* + * Node of hierarchy for scheduling requests for resource. Base class for all + * kinds of scheduling elements (queues, policies, constraints and schedulers). + * + * Root node is a scheduler, which has it's thread to dequeue requests, + * execute requests (see ResourceRequest) and process events in a thread-safe manner. 
+ * Immediate children of the scheduler represent independent resources. + * Each resource has it's own hierarchy to achieve required scheduling policies. + * Non-leaf nodes do not hold requests, but keep scheduling state + * (e.g. consumption history, amount of in-flight requests, etc). + * Leafs of hierarchy are queues capable of holding pending requests. + * + * scheduler (SchedulerRoot) + * / \ + * constraint constraint (SemaphoreConstraint) + * | | + * policy policy (PriorityPolicy) + * / \ / \ + * q1 q2 q3 q4 (FifoQueue) + * + * Dequeueing request from an inner node will dequeue request from one of active leaf-queues in its subtree. + * Node is considered to be active iff: + * - it has at least one pending request in one of leaves of it's subtree; + * - and enforced constraints, if any, are satisfied + * (e.g. amount of concurrent requests is not greater than some number). + * + * All methods must be called only from scheduler thread for thread-safety. + */ +class ISchedulerNode : private boost::noncopyable +{ +public: + ISchedulerNode(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {}) + : event_queue(event_queue_) + , info(config, config_prefix) + {} + + virtual ~ISchedulerNode() {} + + // Checks if two nodes configuration is equal + virtual bool equals(ISchedulerNode * other) = 0; + + /// Attach new child + virtual void attachChild(const std::shared_ptr<ISchedulerNode> & child) = 0; + + /// Detach and destroy child + virtual void removeChild(ISchedulerNode * child) = 0; + + /// Get attached child by name + virtual ISchedulerNode * getChild(const String & child_name) = 0; + + /// Activation of child due to the first pending request + /// Should be called on leaf node (i.e. queue) to propagate activation signal through chain to the root + virtual void activateChild(ISchedulerNode * child) = 0; + + /// Returns true iff node is active + virtual bool isActive() = 0; + + /// Returns the first request to be executed as the first component of resuting pair. + /// The second pair component is `true` iff node is still active after dequeueing. + virtual std::pair<ResourceRequest *, bool> dequeueRequest() = 0; + + /// Returns full path string using names of every parent + String getPath() + { + String result; + ISchedulerNode * ptr = this; + while (ptr->parent) + { + result = "/" + ptr->basename + result; + ptr = ptr->parent; + } + return result.empty() ? "/" : result; + } + + /// Attach to a parent (used by attachChild) + virtual void setParent(ISchedulerNode * parent_) + { + parent = parent_; + } + +protected: + /// Notify parents about the first pending request or constraint becoming satisfied. + /// Postponed to be handled in scheduler thread, so it is intended to be called from outside. 
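+ /// The enqueued event simply calls parent->activateChild(this); implementations of activateChild()
+ /// are expected to propagate the signal up the hierarchy until the root is reached.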
+ void scheduleActivation() + { + if (likely(parent)) + { + event_queue->enqueue([this] { parent->activateChild(this); }); + } + } + +public: + EventQueue * const event_queue; + String basename; + SchedulerNodeInfo info; + ISchedulerNode * parent = nullptr; +}; + +using SchedulerNodePtr = std::shared_ptr<ISchedulerNode>; + +} diff --git a/contrib/clickhouse/src/IO/ISchedulerQueue.h b/contrib/clickhouse/src/IO/ISchedulerQueue.h new file mode 100644 index 0000000000..fc2f3943d2 --- /dev/null +++ b/contrib/clickhouse/src/IO/ISchedulerQueue.h @@ -0,0 +1,60 @@ +#pragma once + +#include <IO/ISchedulerNode.h> +#include <IO/ResourceBudget.h> +#include <IO/ResourceRequest.h> + +#include <memory> + + +namespace DB +{ + +/* + * Queue for pending requests for specific resource, leaf of hierarchy. + * Note that every queue has budget associated with it. + */ +class ISchedulerQueue : public ISchedulerNode +{ +public: + explicit ISchedulerQueue(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {}) + : ISchedulerNode(event_queue_, config, config_prefix) + {} + + // Wrapper for `enqueueRequest()` that should be used to account for available resource budget + void enqueueRequestUsingBudget(ResourceRequest * request) + { + request->cost = budget.ask(request->cost); + enqueueRequest(request); + } + + // Should be called to account for difference between real and estimated costs + void adjustBudget(ResourceCost estimated_cost, ResourceCost real_cost) + { + budget.adjust(estimated_cost, real_cost); + } + + // Adjust budget to account for extra consumption of `cost` resource units + void consumeBudget(ResourceCost cost) + { + adjustBudget(0, cost); + } + + // Adjust budget to account for requested, but not consumed `cost` resource units + void accumulateBudget(ResourceCost cost) + { + adjustBudget(cost, 0); + } + + /// Enqueue new request to be executed using underlying resource. + /// Should be called outside of scheduling subsystem, implementation must be thread-safe. + virtual void enqueueRequest(ResourceRequest * request) = 0; + +private: + // Allows multiple consumers to synchronize with common "debit/credit" balance. + // 1) (positive) to avoid wasting of allocated but not used resource (e.g in case of a failure); + // 2) (negative) to account for overconsumption (e.g. if cost is not know in advance and estimation from below is applied). 
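+ // For example, if a request was enqueued with an estimated cost of 1000 units but the real cost
+ // turned out to be 100, adjustBudget(1000, 100) credits the unused 900 units back, so later calls
+ // to enqueueRequestUsingBudget() ask for correspondingly less.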
+ ResourceBudget budget; +}; + +} diff --git a/contrib/clickhouse/src/IO/LZMADeflatingWriteBuffer.cpp b/contrib/clickhouse/src/IO/LZMADeflatingWriteBuffer.cpp new file mode 100644 index 0000000000..c70ec1507c --- /dev/null +++ b/contrib/clickhouse/src/IO/LZMADeflatingWriteBuffer.cpp @@ -0,0 +1,123 @@ +#include <IO/LZMADeflatingWriteBuffer.h> + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LZMA_STREAM_ENCODER_FAILED; +} + +LZMADeflatingWriteBuffer::LZMADeflatingWriteBuffer( + std::unique_ptr<WriteBuffer> out_, int compression_level, size_t buf_size, char * existing_memory, size_t alignment) + : WriteBufferWithOwnMemoryDecorator(std::move(out_), buf_size, existing_memory, alignment) +{ + + lstr = LZMA_STREAM_INIT; + lstr.allocator = nullptr; + lstr.next_in = nullptr; + lstr.avail_in = 0; + lstr.next_out = nullptr; + lstr.avail_out = 0; + + // options for further compression + lzma_options_lzma opt_lzma2; + if (lzma_lzma_preset(&opt_lzma2, compression_level)) + throw Exception(ErrorCodes::LZMA_STREAM_ENCODER_FAILED, "lzma preset failed: lzma version: {}", LZMA_VERSION_STRING); + + + // LZMA_FILTER_X86 - + // LZMA2 - codec for *.xz files compression; LZMA is not suitable for this purpose + // VLI - variable length integer (in *.xz most integers encoded as VLI) + // LZMA_VLI_UNKNOWN (UINT64_MAX) - VLI value to denote that the value is unknown + lzma_filter filters[] = { + {.id = LZMA_FILTER_X86, .options = nullptr}, + {.id = LZMA_FILTER_LZMA2, .options = &opt_lzma2}, + {.id = LZMA_VLI_UNKNOWN, .options = nullptr}, + }; + lzma_ret ret = lzma_stream_encoder(&lstr, filters, LZMA_CHECK_CRC64); + + if (ret != LZMA_OK) + throw Exception( + ErrorCodes::LZMA_STREAM_ENCODER_FAILED, + "lzma stream encoder init failed: error code: {} lzma version: {}", + ret, + LZMA_VERSION_STRING); +} + +LZMADeflatingWriteBuffer::~LZMADeflatingWriteBuffer() = default; + +void LZMADeflatingWriteBuffer::nextImpl() +{ + if (!offset()) + return; + + lstr.next_in = reinterpret_cast<unsigned char *>(working_buffer.begin()); + lstr.avail_in = offset(); + + try + { + lzma_action action = LZMA_RUN; + do + { + out->nextIfAtEnd(); + lstr.next_out = reinterpret_cast<unsigned char *>(out->position()); + lstr.avail_out = out->buffer().end() - out->position(); + + lzma_ret ret = lzma_code(&lstr, action); + out->position() = out->buffer().end() - lstr.avail_out; + + if (ret == LZMA_STREAM_END) + return; + + if (ret != LZMA_OK) + throw Exception( + ErrorCodes::LZMA_STREAM_ENCODER_FAILED, + "lzma stream encoding failed: error code: {}; lzma_version: {}", + ret, + LZMA_VERSION_STRING); + + } while (lstr.avail_in > 0 || lstr.avail_out == 0); + } + catch (...) + { + /// Do not try to write next time after exception. 
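+ /// Rewinding the position drops whatever this failed call placed into `out`'s working buffer,
+ /// so partially produced compressed data will not be flushed to the underlying stream later.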
+ out->position() = out->buffer().begin(); + throw; + } +} + +void LZMADeflatingWriteBuffer::finalizeBefore() +{ + next(); + + do + { + out->nextIfAtEnd(); + lstr.next_out = reinterpret_cast<unsigned char *>(out->position()); + lstr.avail_out = out->buffer().end() - out->position(); + + lzma_ret ret = lzma_code(&lstr, LZMA_FINISH); + out->position() = out->buffer().end() - lstr.avail_out; + + if (ret == LZMA_STREAM_END) + { + return; + } + + if (ret != LZMA_OK) + throw Exception( + ErrorCodes::LZMA_STREAM_ENCODER_FAILED, + "lzma stream encoding failed: error code: {}; lzma version: {}", + ret, + LZMA_VERSION_STRING); + + } while (lstr.avail_out == 0); +} + +void LZMADeflatingWriteBuffer::finalizeAfter() +{ + lzma_end(&lstr); +} + +} + diff --git a/contrib/clickhouse/src/IO/LZMADeflatingWriteBuffer.h b/contrib/clickhouse/src/IO/LZMADeflatingWriteBuffer.h new file mode 100644 index 0000000000..2e135455e0 --- /dev/null +++ b/contrib/clickhouse/src/IO/LZMADeflatingWriteBuffer.h @@ -0,0 +1,35 @@ +#pragma once + +#include <IO/BufferWithOwnMemory.h> +#include <IO/WriteBuffer.h> +#include <IO/WriteBufferDecorator.h> + +#include <lzma.h> + + +namespace DB +{ + +/// Performs compression using lzma library and writes compressed data to out_ WriteBuffer. +class LZMADeflatingWriteBuffer : public WriteBufferWithOwnMemoryDecorator +{ +public: + LZMADeflatingWriteBuffer( + std::unique_ptr<WriteBuffer> out_, + int compression_level, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~LZMADeflatingWriteBuffer() override; + +private: + void nextImpl() override; + + void finalizeBefore() override; + void finalizeAfter() override; + + lzma_stream lstr; +}; + +} diff --git a/contrib/clickhouse/src/IO/LZMAInflatingReadBuffer.cpp b/contrib/clickhouse/src/IO/LZMAInflatingReadBuffer.cpp new file mode 100644 index 0000000000..a6f3c74ae7 --- /dev/null +++ b/contrib/clickhouse/src/IO/LZMAInflatingReadBuffer.cpp @@ -0,0 +1,99 @@ +#include <IO/LZMAInflatingReadBuffer.h> +#include <IO/WithFileName.h> + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LZMA_STREAM_DECODER_FAILED; +} + +LZMAInflatingReadBuffer::LZMAInflatingReadBuffer(std::unique_ptr<ReadBuffer> in_, size_t buf_size, char * existing_memory, size_t alignment) + : CompressedReadBufferWrapper(std::move(in_), buf_size, existing_memory, alignment), eof_flag(false) +{ + lstr = LZMA_STREAM_INIT; + lstr.allocator = nullptr; + lstr.next_in = nullptr; + lstr.avail_in = 0; + lstr.next_out = nullptr; + lstr.avail_out = 0; + + // 500 mb + uint64_t memlimit = 500ULL << 20; + + lzma_ret ret = lzma_stream_decoder(&lstr, memlimit, LZMA_CONCATENATED); + // lzma does not provide api for converting error code to string unlike zlib + if (ret != LZMA_OK) + throw Exception( + ErrorCodes::LZMA_STREAM_DECODER_FAILED, + "lzma_stream_decoder initialization failed: error code: {}; lzma version: {}", + ret, + LZMA_VERSION_STRING); +} + +LZMAInflatingReadBuffer::~LZMAInflatingReadBuffer() +{ + lzma_end(&lstr); +} + +bool LZMAInflatingReadBuffer::nextImpl() +{ + if (eof_flag) + return false; + + lzma_action action = LZMA_RUN; + lzma_ret ret; + + do + { + if (!lstr.avail_in) + { + in->nextIfAtEnd(); + lstr.next_in = reinterpret_cast<unsigned char *>(in->position()); + lstr.avail_in = in->buffer().end() - in->position(); + } + + if (in->eof()) + { + action = LZMA_FINISH; + } + + lstr.next_out = reinterpret_cast<unsigned char *>(internal_buffer.begin()); + lstr.avail_out = internal_buffer.size(); + + ret = 
lzma_code(&lstr, action); + in->position() = in->buffer().end() - lstr.avail_in; + } + while (ret == LZMA_OK && lstr.avail_out == internal_buffer.size()); + + working_buffer.resize(internal_buffer.size() - lstr.avail_out); + + if (ret == LZMA_STREAM_END) + { + if (in->eof()) + { + eof_flag = true; + return !working_buffer.empty(); + } + else + { + throw Exception( + ErrorCodes::LZMA_STREAM_DECODER_FAILED, + "lzma decoder finished, but input stream has not exceeded: error code: {}; lzma version: {}{}", + ret, + LZMA_VERSION_STRING, + getExceptionEntryWithFileName(*in)); + } + } + + if (ret != LZMA_OK) + throw Exception( + ErrorCodes::LZMA_STREAM_DECODER_FAILED, + "lzma_stream_decoder failed: error code: error code {}; lzma version: {}{}", + ret, + LZMA_VERSION_STRING, + getExceptionEntryWithFileName(*in)); + + return true; +} +} diff --git a/contrib/clickhouse/src/IO/LZMAInflatingReadBuffer.h b/contrib/clickhouse/src/IO/LZMAInflatingReadBuffer.h new file mode 100644 index 0000000000..5fd29d6f7d --- /dev/null +++ b/contrib/clickhouse/src/IO/LZMAInflatingReadBuffer.h @@ -0,0 +1,29 @@ +#pragma once + +#include <IO/CompressedReadBufferWrapper.h> +#include <IO/ReadBuffer.h> + +#include <lzma.h> + +namespace DB +{ + +class LZMAInflatingReadBuffer : public CompressedReadBufferWrapper +{ +public: + explicit LZMAInflatingReadBuffer( + std::unique_ptr<ReadBuffer> in_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~LZMAInflatingReadBuffer() override; + +private: + bool nextImpl() override; + + lzma_stream lstr; + bool eof_flag; +}; + +} diff --git a/contrib/clickhouse/src/IO/LimitReadBuffer.cpp b/contrib/clickhouse/src/IO/LimitReadBuffer.cpp new file mode 100644 index 0000000000..e14112f8d1 --- /dev/null +++ b/contrib/clickhouse/src/IO/LimitReadBuffer.cpp @@ -0,0 +1,99 @@ +#include <IO/LimitReadBuffer.h> + +#include <Common/Exception.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LIMIT_EXCEEDED; + extern const int CANNOT_READ_ALL_DATA; +} + + +bool LimitReadBuffer::nextImpl() +{ + assert(position() >= in->position()); + + /// Let underlying buffer calculate read bytes in `next()` call. + in->position() = position(); + + if (bytes >= limit) + { + if (exact_limit && bytes == *exact_limit) + return false; + + if (exact_limit && bytes != *exact_limit) + throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Unexpected data, got {} bytes, expected {}", bytes, *exact_limit); + + if (throw_exception) + throw Exception(ErrorCodes::LIMIT_EXCEEDED, "Limit for LimitReadBuffer exceeded: {}", exception_message); + + return false; + } + + if (!in->next()) + { + if (exact_limit && bytes != *exact_limit) + throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Unexpected EOF, got {} of {} bytes", bytes, *exact_limit); + /// Clearing the buffer with existing data. + set(in->position(), 0); + return false; + } + + working_buffer = in->buffer(); + + if (limit - bytes < working_buffer.size()) + working_buffer.resize(limit - bytes); + + return true; +} + + +LimitReadBuffer::LimitReadBuffer(ReadBuffer * in_, bool owns, UInt64 limit_, bool throw_exception_, + std::optional<size_t> exact_limit_, std::string exception_message_) + : ReadBuffer(in_ ? 
in_->position() : nullptr, 0) + , in(in_) + , owns_in(owns) + , limit(limit_) + , throw_exception(throw_exception_) + , exact_limit(exact_limit_) + , exception_message(std::move(exception_message_)) +{ + assert(in); + + size_t remaining_bytes_in_buffer = in->buffer().end() - in->position(); + if (remaining_bytes_in_buffer > limit) + remaining_bytes_in_buffer = limit; + + working_buffer = Buffer(in->position(), in->position() + remaining_bytes_in_buffer); +} + + +LimitReadBuffer::LimitReadBuffer(ReadBuffer & in_, UInt64 limit_, bool throw_exception_, + std::optional<size_t> exact_limit_, std::string exception_message_) + : LimitReadBuffer(&in_, false, limit_, throw_exception_, exact_limit_, exception_message_) +{ +} + + +LimitReadBuffer::LimitReadBuffer(std::unique_ptr<ReadBuffer> in_, UInt64 limit_, bool throw_exception_, + std::optional<size_t> exact_limit_, std::string exception_message_) + : LimitReadBuffer(in_.release(), true, limit_, throw_exception_, exact_limit_, exception_message_) +{ +} + + +LimitReadBuffer::~LimitReadBuffer() +{ + /// Update underlying buffer's position in case when limit wasn't reached. + if (!working_buffer.empty()) + in->position() = position(); + + if (owns_in) + delete in; +} + +} diff --git a/contrib/clickhouse/src/IO/LimitReadBuffer.h b/contrib/clickhouse/src/IO/LimitReadBuffer.h new file mode 100644 index 0000000000..15885c1d85 --- /dev/null +++ b/contrib/clickhouse/src/IO/LimitReadBuffer.h @@ -0,0 +1,36 @@ +#pragma once + +#include <base/types.h> +#include <IO/ReadBuffer.h> + + +namespace DB +{ + +/** Allows to read from another ReadBuffer no more than the specified number of bytes. + * Note that the nested ReadBuffer may read slightly more data internally to fill its buffer. + */ +class LimitReadBuffer : public ReadBuffer +{ +public: + LimitReadBuffer(ReadBuffer & in_, UInt64 limit_, bool throw_exception_, + std::optional<size_t> exact_limit_, std::string exception_message_ = {}); + LimitReadBuffer(std::unique_ptr<ReadBuffer> in_, UInt64 limit_, bool throw_exception_, std::optional<size_t> exact_limit_, + std::string exception_message_ = {}); + ~LimitReadBuffer() override; + +private: + ReadBuffer * in; + bool owns_in; + + UInt64 limit; + bool throw_exception; + std::optional<size_t> exact_limit; + std::string exception_message; + + LimitReadBuffer(ReadBuffer * in_, bool owns, UInt64 limit_, bool throw_exception_, std::optional<size_t> exact_limit_, std::string exception_message_); + + bool nextImpl() override; +}; + +} diff --git a/contrib/clickhouse/src/IO/LimitSeekableReadBuffer.cpp b/contrib/clickhouse/src/IO/LimitSeekableReadBuffer.cpp new file mode 100644 index 0000000000..587138cb2c --- /dev/null +++ b/contrib/clickhouse/src/IO/LimitSeekableReadBuffer.cpp @@ -0,0 +1,125 @@ +#include <IO/LimitSeekableReadBuffer.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ARGUMENT_OUT_OF_BOUND; +} + +LimitSeekableReadBuffer::LimitSeekableReadBuffer(SeekableReadBuffer & in_, UInt64 start_offset_, UInt64 limit_size_) + : LimitSeekableReadBuffer(wrapSeekableReadBufferReference(in_), start_offset_, limit_size_) +{ +} + +LimitSeekableReadBuffer::LimitSeekableReadBuffer(std::unique_ptr<SeekableReadBuffer> in_, UInt64 start_offset_, UInt64 limit_size_) + : SeekableReadBuffer(in_->position(), 0) + , in(std::move(in_)) + , min_offset(start_offset_) + , max_offset(start_offset_ + limit_size_) + , need_seek(min_offset) /// We always start reading from `min_offset`. 
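A minimal usage sketch for the LimitReadBuffer declared above (not part of the patch; ReadBufferFromString and readStringUntilEOF are assumed to be available from the same IO library):

    #include <IO/LimitReadBuffer.h>
    #include <IO/ReadBufferFromString.h>
    #include <IO/ReadHelpers.h>
    #include <optional>
    #include <string>

    void readPrefix()
    {
        DB::ReadBufferFromString source("abcdefgh");
        /// At most 4 bytes are served; with throw_exception = false the buffer simply reports EOF afterwards.
        DB::LimitReadBuffer limited(source, /*limit=*/ 4, /*throw_exception=*/ false, /*exact_limit=*/ std::nullopt);
        std::string prefix;
        DB::readStringUntilEOF(prefix, limited); /// prefix == "abcd"
    }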
+{ +} + +bool LimitSeekableReadBuffer::nextImpl() +{ + /// First let the nested buffer know the current position in the buffer (otherwise `in->eof()` or `in->seek()` below can work incorrectly). + in->position() = position(); + + if (need_seek) + { + /// Do actual seek. + if (in->seek(*need_seek, SEEK_SET) != static_cast<off_t>(*need_seek)) + { + /// Failed to seek, maybe because the new seek position is located after EOF. + set(in->position(), 0); + return false; + } + need_seek.reset(); + } + + off_t seek_pos = in->getPosition(); + off_t offset_after_min = seek_pos - min_offset; + off_t available_before_max = max_offset - seek_pos; + + if (offset_after_min < 0 || available_before_max <= 0) + { + /// Limit reached. + set(in->position(), 0); + return false; + } + + if (in->eof()) /// `in->eof()` can call `in->next()` + { + /// EOF reached. + set(in->position(), 0); + return false; + } + + /// in->eof() shouldn't change the seek position. + chassert(seek_pos == in->getPosition()); + + /// Adjust the beginning and the end of the working buffer. + /// Because we don't want to read before `min_offset` or after `max_offset`. + auto * ptr = in->position(); + auto * begin = in->buffer().begin(); + auto * end = in->buffer().end(); + + if (ptr - begin > offset_after_min) + begin = ptr - offset_after_min; + if (end - ptr > available_before_max) + end = ptr + available_before_max; + + BufferBase::set(begin, end - begin, ptr - begin); + chassert(position() == ptr && available()); + + return true; +} + +off_t LimitSeekableReadBuffer::seek(off_t off, int whence) +{ + off_t new_position; + off_t current_position = getPosition(); + if (whence == SEEK_SET) + new_position = off; + else if (whence == SEEK_CUR) + new_position = current_position + off; + else + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Seek expects SEEK_SET or SEEK_CUR as whence"); + + if (new_position < 0 || new_position + min_offset > max_offset) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Seek shift out of bounds"); + + off_t position_change = new_position - current_position; + if ((buffer().begin() <= pos + position_change) && (pos + position_change <= buffer().end())) + { + /// Position is still inside the buffer. + pos += position_change; + chassert(pos >= working_buffer.begin()); + chassert(pos <= working_buffer.end()); + return new_position; + } + + /// Actual seek in the nested buffer will be performed in nextImpl(). + need_seek = new_position + min_offset; + + /// Set the size of the working buffer to zero so next call next() would call nextImpl(). + set(in->position(), 0); + + return new_position; +} + +off_t LimitSeekableReadBuffer::getPosition() +{ + if (need_seek) + return *need_seek - min_offset; + + /// We have to do that because `in->getPosition()` below most likely needs to know the current position in the buffer. + in->position() = position(); + + return in->getPosition() - min_offset; +} + +} diff --git a/contrib/clickhouse/src/IO/LimitSeekableReadBuffer.h b/contrib/clickhouse/src/IO/LimitSeekableReadBuffer.h new file mode 100644 index 0000000000..61b307c522 --- /dev/null +++ b/contrib/clickhouse/src/IO/LimitSeekableReadBuffer.h @@ -0,0 +1,33 @@ +#pragma once + +#include <base/types.h> +#include <IO/SeekableReadBuffer.h> + + +namespace DB +{ + +/** Allows to read from another SeekableReadBuffer up to `limit_size` bytes starting from `start_offset`. + * Note that the nested buffer may read slightly more data internally to fill its buffer. 
+ */ +class LimitSeekableReadBuffer : public SeekableReadBuffer +{ +public: + LimitSeekableReadBuffer(SeekableReadBuffer & in_, UInt64 start_offset_, UInt64 limit_size_); + LimitSeekableReadBuffer(std::unique_ptr<SeekableReadBuffer> in_, UInt64 start_offset_, UInt64 limit_size_); + + /// Returns adjusted position, i.e. returns `3` if the position in the nested buffer is `start_offset + 3`. + off_t getPosition() override; + + off_t seek(off_t off, int whence) override; + +private: + std::unique_ptr<SeekableReadBuffer> in; + off_t min_offset; + off_t max_offset; + std::optional<off_t> need_seek; + + bool nextImpl() override; +}; + +} diff --git a/contrib/clickhouse/src/IO/Lz4DeflatingWriteBuffer.cpp b/contrib/clickhouse/src/IO/Lz4DeflatingWriteBuffer.cpp new file mode 100644 index 0000000000..27c945f92c --- /dev/null +++ b/contrib/clickhouse/src/IO/Lz4DeflatingWriteBuffer.cpp @@ -0,0 +1,161 @@ +#include <IO/Lz4DeflatingWriteBuffer.h> +#include <Common/Exception.h> + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LZ4_ENCODER_FAILED; +} + +Lz4DeflatingWriteBuffer::Lz4DeflatingWriteBuffer( + std::unique_ptr<WriteBuffer> out_, int compression_level, size_t buf_size, char * existing_memory, size_t alignment) + : WriteBufferWithOwnMemoryDecorator(std::move(out_), buf_size, existing_memory, alignment) + , in_data(nullptr) + , out_data(nullptr) + , in_capacity(0) + , out_capacity(0) +{ + kPrefs = { + {LZ4F_max256KB, + LZ4F_blockLinked, + LZ4F_noContentChecksum, + LZ4F_frame, + 0 /* unknown content size */, + 0 /* no dictID */, + LZ4F_noBlockChecksum}, + compression_level, /* compression level; 0 == default */ + 1, /* autoflush */ + 0, /* favor decompression speed */ + {0, 0, 0}, /* reserved, must be set to 0 */ + }; + + size_t ret = LZ4F_createCompressionContext(&ctx, LZ4F_VERSION); + + if (LZ4F_isError(ret)) + throw Exception( + ErrorCodes::LZ4_ENCODER_FAILED, + "creation of LZ4 compression context failed. LZ4F version: {}", + LZ4F_VERSION); +} + +Lz4DeflatingWriteBuffer::~Lz4DeflatingWriteBuffer() = default; + +void Lz4DeflatingWriteBuffer::nextImpl() +{ + if (!offset()) + return; + + in_data = reinterpret_cast<void *>(working_buffer.begin()); + in_capacity = offset(); + + out_capacity = out->buffer().end() - out->position(); + out_data = reinterpret_cast<void *>(out->position()); + + try + { + if (first_time) + { + if (out_capacity < LZ4F_HEADER_SIZE_MAX) + { + out->next(); + out_capacity = out->buffer().end() - out->position(); + out_data = reinterpret_cast<void *>(out->position()); + } + + /// write frame header and check for errors + size_t header_size = LZ4F_compressBegin(ctx, out_data, out_capacity, &kPrefs); + + if (LZ4F_isError(header_size)) + throw Exception( + ErrorCodes::LZ4_ENCODER_FAILED, + "LZ4 failed to start stream encoding. 
LZ4F version: {}", + LZ4F_VERSION); + + out_capacity -= header_size; + out->position() = out->buffer().end() - out_capacity; + out_data = reinterpret_cast<void *>(out->position()); + + first_time = false; + } + + do + { + /// Ensure that there is enough space for compressed block of minimal size + size_t min_compressed_block_size = LZ4F_compressBound(1, &kPrefs); + if (out_capacity < min_compressed_block_size) + { + out->next(); + out_capacity = out->buffer().end() - out->position(); + out_data = reinterpret_cast<void *>(out->position()); + } + + /// LZ4F_compressUpdate compresses whole input buffer at once so we need to shink it manually + size_t cur_buffer_size = in_capacity; + if (out_capacity >= min_compressed_block_size) /// We cannot shrink the input buffer if it's already too small. + { + while (out_capacity < LZ4F_compressBound(cur_buffer_size, &kPrefs)) + cur_buffer_size /= 2; + } + + size_t compressed_size = LZ4F_compressUpdate(ctx, out_data, out_capacity, in_data, cur_buffer_size, nullptr); + + if (LZ4F_isError(compressed_size)) + throw Exception( + ErrorCodes::LZ4_ENCODER_FAILED, + "LZ4 failed to encode stream. LZ4F version: {}", + LZ4F_VERSION); + + in_capacity -= cur_buffer_size; + in_data = reinterpret_cast<void *>(working_buffer.end() - in_capacity); + + out_capacity -= compressed_size; + out->position() = out->buffer().end() - out_capacity; + out_data = reinterpret_cast<void *>(out->position()); + } + while (in_capacity > 0); + } + catch (...) + { + out->position() = out->buffer().begin(); + throw; + } + out->next(); + out_capacity = out->buffer().end() - out->position(); +} + +void Lz4DeflatingWriteBuffer::finalizeBefore() +{ + next(); + + out_capacity = out->buffer().end() - out->position(); + out_data = reinterpret_cast<void *>(out->position()); + + if (out_capacity < LZ4F_compressBound(0, &kPrefs)) + { + out->next(); + out_capacity = out->buffer().end() - out->position(); + out_data = reinterpret_cast<void *>(out->position()); + } + + /// compression end + size_t end_size = LZ4F_compressEnd(ctx, out_data, out_capacity, nullptr); + + if (LZ4F_isError(end_size)) + throw Exception( + ErrorCodes::LZ4_ENCODER_FAILED, + "LZ4 failed to end stream encoding. LZ4F version: {}", + LZ4F_VERSION); + + out_capacity -= end_size; + out->position() = out->buffer().end() - out_capacity; + out_data = reinterpret_cast<void *>(out->position()); +} + +void Lz4DeflatingWriteBuffer::finalizeAfter() +{ + LZ4F_freeCompressionContext(ctx); +} + +} diff --git a/contrib/clickhouse/src/IO/Lz4DeflatingWriteBuffer.h b/contrib/clickhouse/src/IO/Lz4DeflatingWriteBuffer.h new file mode 100644 index 0000000000..68873b5f8e --- /dev/null +++ b/contrib/clickhouse/src/IO/Lz4DeflatingWriteBuffer.h @@ -0,0 +1,43 @@ +#pragma once + +#include <IO/BufferWithOwnMemory.h> +#include <IO/CompressionMethod.h> +#include <IO/WriteBuffer.h> +#include <IO/WriteBufferDecorator.h> + +#include <lz4.h> +#include <lz4frame.h> + +namespace DB +{ +/// Performs compression using lz4 library and writes compressed data to out_ WriteBuffer. 
+class Lz4DeflatingWriteBuffer : public WriteBufferWithOwnMemoryDecorator +{ +public: + Lz4DeflatingWriteBuffer( + std::unique_ptr<WriteBuffer> out_, + int compression_level, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~Lz4DeflatingWriteBuffer() override; + +private: + void nextImpl() override; + + void finalizeBefore() override; + void finalizeAfter() override; + + LZ4F_preferences_t kPrefs; /// NOLINT + LZ4F_compressionContext_t ctx; + + void * in_data; + void * out_data; + + size_t in_capacity; + size_t out_capacity; + + bool first_time = true; +}; +} diff --git a/contrib/clickhouse/src/IO/Lz4InflatingReadBuffer.cpp b/contrib/clickhouse/src/IO/Lz4InflatingReadBuffer.cpp new file mode 100644 index 0000000000..eaa71048e7 --- /dev/null +++ b/contrib/clickhouse/src/IO/Lz4InflatingReadBuffer.cpp @@ -0,0 +1,89 @@ +#include <IO/Lz4InflatingReadBuffer.h> +#include <IO/WithFileName.h> + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LZ4_DECODER_FAILED; +} + +Lz4InflatingReadBuffer::Lz4InflatingReadBuffer(std::unique_ptr<ReadBuffer> in_, size_t buf_size, char * existing_memory, size_t alignment) + : CompressedReadBufferWrapper(std::move(in_), buf_size, existing_memory, alignment) + , in_data(nullptr) + , out_data(nullptr) + , in_available(0) + , out_available(0) +{ + size_t ret = LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION); + + if (LZ4F_isError(ret)) + throw Exception( + ErrorCodes::LZ4_DECODER_FAILED, + "LZ4 failed create decompression context LZ4F_dctx. LZ4F version: {}. Error: {}", + LZ4F_VERSION, + LZ4F_getErrorName(ret)); +} + +Lz4InflatingReadBuffer::~Lz4InflatingReadBuffer() +{ + LZ4F_freeDecompressionContext(dctx); +} + +bool Lz4InflatingReadBuffer::nextImpl() +{ + if (eof_flag) + return false; + + bool need_more_input = false; + size_t ret; + + do + { + if (!in_available) + { + in->nextIfAtEnd(); + in_available = in->buffer().end() - in->position(); + } + + in_data = reinterpret_cast<void *>(in->position()); + out_data = reinterpret_cast<void *>(internal_buffer.begin()); + + out_available = internal_buffer.size(); + + size_t bytes_read = in_available; + size_t bytes_written = out_available; + + ret = LZ4F_decompress(dctx, out_data, &bytes_written, in_data, &bytes_read, /* LZ4F_decompressOptions_t */ nullptr); + + in_available -= bytes_read; + out_available -= bytes_written; + + /// It may happen that we didn't get new uncompressed data + /// (for example if we read the end of frame). Load new data + /// in this case. + need_more_input = bytes_written == 0; + + in->position() = in->buffer().end() - in_available; + } + while (need_more_input && !LZ4F_isError(ret) && !in->eof()); + + working_buffer.resize(internal_buffer.size() - out_available); + + if (LZ4F_isError(ret)) + throw Exception( + ErrorCodes::LZ4_DECODER_FAILED, + "LZ4 decompression failed. LZ4F version: {}. 
Error: {}{}", + LZ4F_VERSION, + LZ4F_getErrorName(ret), + getExceptionEntryWithFileName(*in)); + + if (in->eof()) + { + eof_flag = true; + return !working_buffer.empty(); + } + + return true; +} +} diff --git a/contrib/clickhouse/src/IO/Lz4InflatingReadBuffer.h b/contrib/clickhouse/src/IO/Lz4InflatingReadBuffer.h new file mode 100644 index 0000000000..46bdc40670 --- /dev/null +++ b/contrib/clickhouse/src/IO/Lz4InflatingReadBuffer.h @@ -0,0 +1,39 @@ +#pragma once + +#include <IO/CompressedReadBufferWrapper.h> +#include <IO/CompressionMethod.h> +#include <IO/ReadBuffer.h> + +#include <lz4.h> +#include <lz4frame.h> + + +namespace DB +{ + +class Lz4InflatingReadBuffer : public CompressedReadBufferWrapper +{ +public: + explicit Lz4InflatingReadBuffer( + std::unique_ptr<ReadBuffer> in_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~Lz4InflatingReadBuffer() override; + +private: + bool nextImpl() override; + + LZ4F_dctx* dctx; + + void * in_data; + void * out_data; + + size_t in_available; + size_t out_available; + + bool eof_flag = false; +}; + +} diff --git a/contrib/clickhouse/src/IO/MMapReadBufferFromFile.cpp b/contrib/clickhouse/src/IO/MMapReadBufferFromFile.cpp new file mode 100644 index 0000000000..86e05d7ae4 --- /dev/null +++ b/contrib/clickhouse/src/IO/MMapReadBufferFromFile.cpp @@ -0,0 +1,79 @@ +#include <unistd.h> +#include <fcntl.h> + +#include <Common/ProfileEvents.h> +#include <Common/formatReadable.h> +#include <IO/MMapReadBufferFromFile.h> + + +namespace ProfileEvents +{ + extern const Event FileOpen; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int FILE_DOESNT_EXIST; + extern const int CANNOT_OPEN_FILE; + extern const int CANNOT_CLOSE_FILE; +} + + +void MMapReadBufferFromFile::open() +{ + ProfileEvents::increment(ProfileEvents::FileOpen); + + fd = ::open(file_name.c_str(), O_RDONLY | O_CLOEXEC); + + if (-1 == fd) + throwFromErrnoWithPath("Cannot open file " + file_name, file_name, + errno == ENOENT ? ErrorCodes::FILE_DOESNT_EXIST : ErrorCodes::CANNOT_OPEN_FILE); +} + + +std::string MMapReadBufferFromFile::getFileName() const +{ + return file_name; +} + + +MMapReadBufferFromFile::MMapReadBufferFromFile(const std::string & file_name_, size_t offset, size_t length_) + : file_name(file_name_) +{ + open(); + mapped.set(fd, offset, length_); + init(); +} + + +MMapReadBufferFromFile::MMapReadBufferFromFile(const std::string & file_name_, size_t offset) + : file_name(file_name_) +{ + open(); + mapped.set(fd, offset); + init(); +} + + +MMapReadBufferFromFile::~MMapReadBufferFromFile() +{ + if (fd != -1) + close(); /// Exceptions will lead to std::terminate and that's Ok. 
+} + + +void MMapReadBufferFromFile::close() +{ + finish(); + + if (0 != ::close(fd)) + throw Exception(ErrorCodes::CANNOT_CLOSE_FILE, "Cannot close file"); + + fd = -1; + metric_increment.destroy(); +} + +} diff --git a/contrib/clickhouse/src/IO/MMapReadBufferFromFile.h b/contrib/clickhouse/src/IO/MMapReadBufferFromFile.h new file mode 100644 index 0000000000..bc566a0489 --- /dev/null +++ b/contrib/clickhouse/src/IO/MMapReadBufferFromFile.h @@ -0,0 +1,40 @@ +#pragma once + +#include <Common/CurrentMetrics.h> +#include <IO/MMapReadBufferFromFileDescriptor.h> + + +namespace CurrentMetrics +{ + extern const Metric OpenFileForRead; +} + + +namespace DB +{ + +class MMapReadBufferFromFile : public MMapReadBufferFromFileDescriptor +{ +public: + MMapReadBufferFromFile(const std::string & file_name_, size_t offset, size_t length_); + + /// Map till end of file. + MMapReadBufferFromFile(const std::string & file_name_, size_t offset); + + ~MMapReadBufferFromFile() override; + + void close(); + + std::string getFileName() const override; + +private: + int fd = -1; + std::string file_name; + + CurrentMetrics::Increment metric_increment{CurrentMetrics::OpenFileForRead}; + + void open(); +}; + +} + diff --git a/contrib/clickhouse/src/IO/MMapReadBufferFromFileDescriptor.cpp b/contrib/clickhouse/src/IO/MMapReadBufferFromFileDescriptor.cpp new file mode 100644 index 0000000000..9b1c132cc0 --- /dev/null +++ b/contrib/clickhouse/src/IO/MMapReadBufferFromFileDescriptor.cpp @@ -0,0 +1,105 @@ +#include <sys/mman.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <Common/ProfileEvents.h> +#include <Common/formatReadable.h> +#include <Common/Exception.h> +#include <Common/filesystemHelpers.h> +#include <base/getPageSize.h> +#include <IO/WriteHelpers.h> +#include <IO/MMapReadBufferFromFileDescriptor.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ARGUMENT_OUT_OF_BOUND; + extern const int CANNOT_SEEK_THROUGH_FILE; +} + + +void MMapReadBufferFromFileDescriptor::init() +{ + size_t length = mapped.getLength(); + BufferBase::set(mapped.getData(), length, 0); + + size_t page_size = static_cast<size_t>(::getPageSize()); + ReadBuffer::padded = (length % page_size) > 0 && (length % page_size) <= (page_size - (PADDING_FOR_SIMD - 1)); +} + + +MMapReadBufferFromFileDescriptor::MMapReadBufferFromFileDescriptor(int fd, size_t offset, size_t length) + : mapped(fd, offset, length) +{ + init(); +} + + +MMapReadBufferFromFileDescriptor::MMapReadBufferFromFileDescriptor(int fd, size_t offset) + : mapped(fd, offset) +{ + init(); +} + + +void MMapReadBufferFromFileDescriptor::finish() +{ + mapped.finish(); +} + + +std::string MMapReadBufferFromFileDescriptor::getFileName() const +{ + return "(fd = " + toString(mapped.getFD()) + ")"; +} + +int MMapReadBufferFromFileDescriptor::getFD() const +{ + return mapped.getFD(); +} + +off_t MMapReadBufferFromFileDescriptor::getPosition() +{ + return count(); +} + +off_t MMapReadBufferFromFileDescriptor::seek(off_t offset, int whence) +{ + off_t new_pos; + if (whence == SEEK_SET) + new_pos = offset; + else if (whence == SEEK_CUR) + new_pos = count() + offset; + else + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "MMapReadBufferFromFileDescriptor::seek expects SEEK_SET or SEEK_CUR as whence"); + + working_buffer = internal_buffer; + if (new_pos < 0 || new_pos > off_t(working_buffer.size())) + throw Exception(ErrorCodes::CANNOT_SEEK_THROUGH_FILE, + "Cannot seek through file {} because seek position ({}) is out of bounds [0, {}]", + 
getFileName(), new_pos, working_buffer.size()); + + position() = working_buffer.begin() + new_pos; + return new_pos; +} + +size_t MMapReadBufferFromFileDescriptor::getFileSize() +{ + return getSizeFromFileDescriptor(getFD(), getFileName()); +} + +size_t MMapReadBufferFromFileDescriptor::readBigAt(char * to, size_t n, size_t offset, const std::function<bool(size_t)> &) +{ + if (offset >= mapped.getLength()) + return 0; + + n = std::min(n, mapped.getLength() - offset); + memcpy(to, mapped.getData() + offset, n); + return n; +} + +} diff --git a/contrib/clickhouse/src/IO/MMapReadBufferFromFileDescriptor.h b/contrib/clickhouse/src/IO/MMapReadBufferFromFileDescriptor.h new file mode 100644 index 0000000000..2a039e0497 --- /dev/null +++ b/contrib/clickhouse/src/IO/MMapReadBufferFromFileDescriptor.h @@ -0,0 +1,47 @@ +#pragma once + +#include <IO/ReadBufferFromFileBase.h> +#include <IO/MMappedFileDescriptor.h> + + +namespace DB +{ + +/** MMap range in a file and represent it as a ReadBuffer. + * Please note that mmap is not always the optimal way to read file. + * Also you cannot control whether and how long actual IO take place, + * so this method is not manageable and not recommended for anything except benchmarks. + */ +class MMapReadBufferFromFileDescriptor : public ReadBufferFromFileBase +{ +public: + off_t seek(off_t off, int whence) override; + +protected: + MMapReadBufferFromFileDescriptor() = default; + void init(); + + MMappedFileDescriptor mapped; + +public: + MMapReadBufferFromFileDescriptor(int fd_, size_t offset_, size_t length_); + + /// Map till end of file. + MMapReadBufferFromFileDescriptor(int fd_, size_t offset_); + + /// unmap memory before call to destructor + void finish(); + + off_t getPosition() override; + + std::string getFileName() const override; + + int getFD() const; + + size_t getFileSize() override; + + size_t readBigAt(char * to, size_t n, size_t offset, const std::function<bool(size_t)> &) override; + bool supportsReadAt() override { return true; } +}; + +} diff --git a/contrib/clickhouse/src/IO/MMapReadBufferFromFileWithCache.cpp b/contrib/clickhouse/src/IO/MMapReadBufferFromFileWithCache.cpp new file mode 100644 index 0000000000..d13cf5db2f --- /dev/null +++ b/contrib/clickhouse/src/IO/MMapReadBufferFromFileWithCache.cpp @@ -0,0 +1,78 @@ +#include <IO/MMapReadBufferFromFileWithCache.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ARGUMENT_OUT_OF_BOUND; + extern const int CANNOT_SEEK_THROUGH_FILE; +} + + +void MMapReadBufferFromFileWithCache::init() +{ + size_t length = mapped->getLength(); + BufferBase::set(mapped->getData(), length, 0); + + size_t page_size = static_cast<size_t>(::getPageSize()); + ReadBuffer::padded = (length % page_size) > 0 && (length % page_size) <= (page_size - (PADDING_FOR_SIMD - 1)); + ReadBufferFromFileBase::file_size = length; +} + + +MMapReadBufferFromFileWithCache::MMapReadBufferFromFileWithCache( + MMappedFileCache & cache, const std::string & file_name, size_t offset, size_t length) +{ + mapped = cache.getOrSet(cache.hash(file_name, offset, length), [&] + { + return std::make_shared<MMappedFile>(file_name, offset, length); + }); + + init(); +} + +MMapReadBufferFromFileWithCache::MMapReadBufferFromFileWithCache( + MMappedFileCache & cache, const std::string & file_name, size_t offset) +{ + mapped = cache.getOrSet(cache.hash(file_name, offset, -1), [&] + { + return std::make_shared<MMappedFile>(file_name, offset); + }); + + init(); +} + + +std::string MMapReadBufferFromFileWithCache::getFileName() const +{ + 
return mapped->getFileName(); +} + +off_t MMapReadBufferFromFileWithCache::getPosition() +{ + return count(); +} + +off_t MMapReadBufferFromFileWithCache::seek(off_t offset, int whence) +{ + off_t new_pos; + if (whence == SEEK_SET) + new_pos = offset; + else if (whence == SEEK_CUR) + new_pos = count() + offset; + else + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "MMapReadBufferFromFileWithCache::seek expects SEEK_SET or SEEK_CUR as whence"); + + working_buffer = internal_buffer; + if (new_pos < 0 || new_pos > off_t(working_buffer.size())) + throw Exception(ErrorCodes::CANNOT_SEEK_THROUGH_FILE, + "Cannot seek through file {} because seek position ({}) is out of bounds [0, {}]", + getFileName(), new_pos, working_buffer.size()); + + position() = working_buffer.begin() + new_pos; + return new_pos; +} + +} diff --git a/contrib/clickhouse/src/IO/MMapReadBufferFromFileWithCache.h b/contrib/clickhouse/src/IO/MMapReadBufferFromFileWithCache.h new file mode 100644 index 0000000000..ff84f81610 --- /dev/null +++ b/contrib/clickhouse/src/IO/MMapReadBufferFromFileWithCache.h @@ -0,0 +1,29 @@ +#pragma once + +#include <IO/ReadBufferFromFileBase.h> +#include <IO/MMappedFileCache.h> +#include <IO/MMapReadBufferFromFileDescriptor.h> + + +namespace DB +{ + +class MMapReadBufferFromFileWithCache : public ReadBufferFromFileBase +{ +public: + MMapReadBufferFromFileWithCache(MMappedFileCache & cache, const std::string & file_name, size_t offset, size_t length); + + /// Map till end of file. + MMapReadBufferFromFileWithCache(MMappedFileCache & cache, const std::string & file_name, size_t offset); + + off_t getPosition() override; + std::string getFileName() const override; + off_t seek(off_t offset, int whence) override; + +private: + MMappedFileCache::MappedPtr mapped; + + void init(); +}; + +} diff --git a/contrib/clickhouse/src/IO/MMappedFile.cpp b/contrib/clickhouse/src/IO/MMappedFile.cpp new file mode 100644 index 0000000000..9e45140d5f --- /dev/null +++ b/contrib/clickhouse/src/IO/MMappedFile.cpp @@ -0,0 +1,78 @@ +#include <unistd.h> +#include <fcntl.h> + +#include <Common/ProfileEvents.h> +#include <Common/formatReadable.h> +#include <Common/Exception.h> +#include <IO/MMappedFile.h> + + +namespace ProfileEvents +{ + extern const Event FileOpen; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int FILE_DOESNT_EXIST; + extern const int CANNOT_OPEN_FILE; + extern const int CANNOT_CLOSE_FILE; +} + + +void MMappedFile::open() +{ + ProfileEvents::increment(ProfileEvents::FileOpen); + + fd = ::open(file_name.c_str(), O_RDONLY | O_CLOEXEC); + + if (-1 == fd) + throwFromErrnoWithPath("Cannot open file " + file_name, file_name, + errno == ENOENT ? ErrorCodes::FILE_DOESNT_EXIST : ErrorCodes::CANNOT_OPEN_FILE); +} + + +std::string MMappedFile::getFileName() const +{ + return file_name; +} + + +MMappedFile::MMappedFile(const std::string & file_name_, size_t offset_, size_t length_) + : file_name(file_name_) +{ + open(); + set(fd, offset_, length_); +} + + +MMappedFile::MMappedFile(const std::string & file_name_, size_t offset_) + : file_name(file_name_) +{ + open(); + set(fd, offset_); +} + + +MMappedFile::~MMappedFile() +{ + if (fd != -1) + close(); /// Exceptions will lead to std::terminate and that's Ok. 
+} + + +void MMappedFile::close() +{ + finish(); + + if (0 != ::close(fd)) + throw Exception(ErrorCodes::CANNOT_CLOSE_FILE, "Cannot close file"); + + fd = -1; + metric_increment.destroy(); +} + +} diff --git a/contrib/clickhouse/src/IO/MMappedFile.h b/contrib/clickhouse/src/IO/MMappedFile.h new file mode 100644 index 0000000000..6ecf988fa9 --- /dev/null +++ b/contrib/clickhouse/src/IO/MMappedFile.h @@ -0,0 +1,40 @@ +#pragma once + +#include <Common/CurrentMetrics.h> +#include <IO/MMappedFileDescriptor.h> +#include <cstddef> + + +namespace CurrentMetrics +{ + extern const Metric OpenFileForRead; +} + + +namespace DB +{ + +/// Opens a file and mmaps a region in it (or a whole file) into memory. Unmaps and closes in destructor. +class MMappedFile : public MMappedFileDescriptor +{ +public: + MMappedFile(const std::string & file_name_, size_t offset_, size_t length_); + + /// Map till end of file. + MMappedFile(const std::string & file_name_, size_t offset_); + + ~MMappedFile() override; + + void close(); + + std::string getFileName() const; + +private: + std::string file_name; + + CurrentMetrics::Increment metric_increment{CurrentMetrics::OpenFileForRead}; + + void open(); +}; + +} diff --git a/contrib/clickhouse/src/IO/MMappedFileCache.h b/contrib/clickhouse/src/IO/MMappedFileCache.h new file mode 100644 index 0000000000..bb30829ed6 --- /dev/null +++ b/contrib/clickhouse/src/IO/MMappedFileCache.h @@ -0,0 +1,60 @@ +#pragma once + +#include <Core/Types.h> +#include <Common/HashTable/Hash.h> +#include <Common/CacheBase.h> +#include <Common/SipHash.h> +#include <Common/ProfileEvents.h> +#include <IO/MMappedFile.h> + + +namespace ProfileEvents +{ + extern const Event MMappedFileCacheHits; + extern const Event MMappedFileCacheMisses; +} + +namespace DB +{ + + +/** Cache of opened and mmapped files for reading. + * mmap/munmap is heavy operation and better to keep mapped file to subsequent use than to map/unmap every time. + */ +class MMappedFileCache : public CacheBase<UInt128, MMappedFile, UInt128TrivialHash> +{ +private: + using Base = CacheBase<UInt128, MMappedFile, UInt128TrivialHash>; + +public: + explicit MMappedFileCache(size_t max_size_in_bytes) + : Base(max_size_in_bytes) {} + + /// Calculate key from path to file and offset. 
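For context, a sketch (not from the patch) of how this cache is combined with MMapReadBufferFromFileWithCache, added earlier in this diff, so that repeated reads of the same region reuse one mapping; the path is a placeholder and readStringUntilEOF is assumed to come from the same IO library:

    #include <IO/MMapReadBufferFromFileWithCache.h>
    #include <IO/MMappedFileCache.h>
    #include <IO/ReadHelpers.h>
    #include <string>

    void readWholeFile(const std::string & path)
    {
        DB::MMappedFileCache cache(1000);                                   /// max_size_in_bytes of the underlying CacheBase
        DB::MMapReadBufferFromFileWithCache in(cache, path, /*offset=*/ 0); /// maps till end of file
        std::string payload;
        DB::readStringUntilEOF(payload, in);
    }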
+ static UInt128 hash(const String & path_to_file, size_t offset, ssize_t length = -1) + { + SipHash hash; + hash.update(path_to_file.data(), path_to_file.size() + 1); + hash.update(offset); + hash.update(length); + + return hash.get128(); + } + + template <typename LoadFunc> + MappedPtr getOrSet(const Key & key, LoadFunc && load) + { + auto result = Base::getOrSet(key, load); + if (result.second) + ProfileEvents::increment(ProfileEvents::MMappedFileCacheMisses); + else + ProfileEvents::increment(ProfileEvents::MMappedFileCacheHits); + + return result.first; + } +}; + +using MMappedFileCachePtr = std::shared_ptr<MMappedFileCache>; + +} + diff --git a/contrib/clickhouse/src/IO/MMappedFileDescriptor.cpp b/contrib/clickhouse/src/IO/MMappedFileDescriptor.cpp new file mode 100644 index 0000000000..9cc1aaf656 --- /dev/null +++ b/contrib/clickhouse/src/IO/MMappedFileDescriptor.cpp @@ -0,0 +1,107 @@ +#include <sys/mman.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <fmt/format.h> + +#include <Common/formatReadable.h> +#include <Common/Exception.h> +#include <base/getPageSize.h> +#include <IO/MMappedFileDescriptor.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_ALLOCATE_MEMORY; + extern const int CANNOT_MUNMAP; + extern const int CANNOT_STAT; + extern const int BAD_ARGUMENTS; + extern const int LOGICAL_ERROR; +} + + +static size_t getFileSize(int fd) +{ + struct stat stat_res {}; + if (0 != fstat(fd, &stat_res)) + throwFromErrno("MMappedFileDescriptor: Cannot fstat.", ErrorCodes::CANNOT_STAT); + + off_t file_size = stat_res.st_size; + + if (file_size < 0) + throw Exception(ErrorCodes::LOGICAL_ERROR, "MMappedFileDescriptor: fstat returned negative file size"); + + return file_size; +} + + +MMappedFileDescriptor::MMappedFileDescriptor(int fd_, size_t offset_, size_t length_) +{ + set(fd_, offset_, length_); +} + +MMappedFileDescriptor::MMappedFileDescriptor(int fd_, size_t offset_) + : fd(fd_), offset(offset_) +{ + set(fd_, offset_); +} + +void MMappedFileDescriptor::set(int fd_, size_t offset_, size_t length_) +{ + finish(); + + fd = fd_; + offset = offset_; + length = length_; + + if (!length) + return; + + void * buf = mmap(nullptr, length, PROT_READ, MAP_PRIVATE, fd, offset); + if (MAP_FAILED == buf) + throwFromErrno(fmt::format("MMappedFileDescriptor: Cannot mmap {}.", ReadableSize(length)), + ErrorCodes::CANNOT_ALLOCATE_MEMORY); + + data = static_cast<char *>(buf); + + files_metric_increment.changeTo(1); + bytes_metric_increment.changeTo(length); +} + +void MMappedFileDescriptor::set(int fd_, size_t offset_) +{ + size_t file_size = getFileSize(fd_); + + if (offset > static_cast<size_t>(file_size)) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "MMappedFileDescriptor: requested offset is greater than file size"); + + set(fd_, offset_, file_size - offset); +} + +void MMappedFileDescriptor::finish() +{ + if (!length) + return; + + if (0 != munmap(data, length)) + throwFromErrno(fmt::format("MMappedFileDescriptor: Cannot munmap {}.", ReadableSize(length)), + ErrorCodes::CANNOT_MUNMAP); + + length = 0; + + files_metric_increment.changeTo(0); + bytes_metric_increment.changeTo(0); +} + +MMappedFileDescriptor::~MMappedFileDescriptor() +{ + finish(); /// Exceptions will lead to std::terminate and that's Ok. 
+} + +} + + diff --git a/contrib/clickhouse/src/IO/MMappedFileDescriptor.h b/contrib/clickhouse/src/IO/MMappedFileDescriptor.h new file mode 100644 index 0000000000..2611093643 --- /dev/null +++ b/contrib/clickhouse/src/IO/MMappedFileDescriptor.h @@ -0,0 +1,60 @@ +#pragma once + +#include <cstddef> +#include <Common/CurrentMetrics.h> + +namespace CurrentMetrics +{ + extern const Metric MMappedFiles; + extern const Metric MMappedFileBytes; +} + + +namespace DB +{ + +/// MMaps a region in file (or a whole file) into memory. Unmaps in destructor. +/// Does not open or close file. +class MMappedFileDescriptor +{ +public: + MMappedFileDescriptor(int fd_, size_t offset_, size_t length_); + MMappedFileDescriptor(int fd_, size_t offset_); + + /// Makes empty object that can be initialized with `set`. + MMappedFileDescriptor() = default; + + virtual ~MMappedFileDescriptor(); + + char * getData() { return data; } + const char * getData() const { return data; } + + int getFD() const { return fd; } + size_t getOffset() const { return offset; } + size_t getLength() const { return length; } + + /// Unmap memory before call to destructor + void finish(); + + /// Initialize or reset to another fd. + void set(int fd_, size_t offset_, size_t length_); + void set(int fd_, size_t offset_); + + MMappedFileDescriptor(const MMappedFileDescriptor &) = delete; + MMappedFileDescriptor(MMappedFileDescriptor &&) = delete; + +protected: + + void init(); + + int fd = -1; + size_t offset = 0; + size_t length = 0; + char * data = nullptr; + + CurrentMetrics::Increment files_metric_increment{CurrentMetrics::MMappedFiles, 0}; + CurrentMetrics::Increment bytes_metric_increment{CurrentMetrics::MMappedFileBytes, 0}; +}; + +} + diff --git a/contrib/clickhouse/src/IO/MemoryReadWriteBuffer.cpp b/contrib/clickhouse/src/IO/MemoryReadWriteBuffer.cpp new file mode 100644 index 0000000000..415a6c6fad --- /dev/null +++ b/contrib/clickhouse/src/IO/MemoryReadWriteBuffer.cpp @@ -0,0 +1,147 @@ +#include <IO/MemoryReadWriteBuffer.h> +#include <boost/noncopyable.hpp> + + +namespace DB +{ + +class ReadBufferFromMemoryWriteBuffer : public ReadBuffer, boost::noncopyable, private Allocator<false> +{ +public: + explicit ReadBufferFromMemoryWriteBuffer(MemoryWriteBuffer && origin) + : ReadBuffer(nullptr, 0), + chunk_list(std::move(origin.chunk_list)), + end_pos(origin.position()) + { + chunk_head = chunk_list.begin(); + setChunk(); + } + + bool nextImpl() override + { + if (chunk_head == chunk_list.end()) + return false; + + ++chunk_head; + return setChunk(); + } + + ~ReadBufferFromMemoryWriteBuffer() override + { + for (const auto & range : chunk_list) + free(range.begin(), range.size()); + } + +private: + + /// update buffers and position according to chunk_head pointer + bool setChunk() + { + if (chunk_head != chunk_list.end()) + { + internalBuffer() = *chunk_head; + + /// It is last chunk, it should be truncated + if (std::next(chunk_head) != chunk_list.end()) + buffer() = internalBuffer(); + else + buffer() = Buffer(internalBuffer().begin(), end_pos); + + position() = buffer().begin(); + } + else + { + buffer() = internalBuffer() = Buffer(nullptr, nullptr); + position() = nullptr; + } + + return !buffer().empty(); + } + + using Container = std::forward_list<BufferBase::Buffer>; + + Container chunk_list; + Container::iterator chunk_head; + Position end_pos; +}; + + +MemoryWriteBuffer::MemoryWriteBuffer(size_t max_total_size_, size_t initial_chunk_size_, double growth_rate_, size_t max_chunk_size_) + : WriteBuffer(nullptr, 0), + 
max_total_size(max_total_size_), + initial_chunk_size(initial_chunk_size_), + max_chunk_size(max_chunk_size_), + growth_rate(growth_rate_) +{ + addChunk(); +} + + +void MemoryWriteBuffer::nextImpl() +{ + if (unlikely(hasPendingData())) + { + /// ignore flush + buffer() = Buffer(pos, buffer().end()); + return; + } + + addChunk(); +} + + +void MemoryWriteBuffer::addChunk() +{ + size_t next_chunk_size; + if (chunk_list.empty()) + { + chunk_tail = chunk_list.before_begin(); + next_chunk_size = initial_chunk_size; + } + else + { + next_chunk_size = std::max(1uz, static_cast<size_t>(chunk_tail->size() * growth_rate)); + next_chunk_size = std::min(next_chunk_size, max_chunk_size); + } + + if (max_total_size) + { + if (total_chunks_size + next_chunk_size > max_total_size) + next_chunk_size = max_total_size - total_chunks_size; + + if (0 == next_chunk_size) + { + set(position(), 0); + throw MemoryWriteBuffer::CurrentBufferExhausted(); + } + } + + Position begin = reinterpret_cast<Position>(alloc(next_chunk_size)); + chunk_tail = chunk_list.emplace_after(chunk_tail, begin, begin + next_chunk_size); + total_chunks_size += next_chunk_size; + + set(chunk_tail->begin(), chunk_tail->size()); +} + + +std::shared_ptr<ReadBuffer> MemoryWriteBuffer::getReadBufferImpl() +{ + finalize(); + + auto res = std::make_shared<ReadBufferFromMemoryWriteBuffer>(std::move(*this)); + + /// invalidate members + chunk_list.clear(); + chunk_tail = chunk_list.begin(); + + return res; +} + + +MemoryWriteBuffer::~MemoryWriteBuffer() +{ + for (const auto & range : chunk_list) + free(range.begin(), range.size()); +} + +} diff --git a/contrib/clickhouse/src/IO/MemoryReadWriteBuffer.h b/contrib/clickhouse/src/IO/MemoryReadWriteBuffer.h new file mode 100644 index 0000000000..d6bf231c22 --- /dev/null +++ b/contrib/clickhouse/src/IO/MemoryReadWriteBuffer.h @@ -0,0 +1,60 @@ +#pragma once +#include <forward_list> + +#include <IO/WriteBuffer.h> +#include <IO/IReadableWriteBuffer.h> +#include <Common/Allocator.h> +#include <Core/Defines.h> +#include <boost/noncopyable.hpp> + + +namespace DB +{ + +/// Stores data in memory chunks, size of chunks are exponentially increasing during write +/// Written data could be reread after write +class MemoryWriteBuffer : public WriteBuffer, public IReadableWriteBuffer, boost::noncopyable, private Allocator<false> +{ +public: + /// Special exception to throw when the current WriteBuffer cannot receive data + class CurrentBufferExhausted : public std::exception + { + public: + const char * what() const noexcept override { return "MemoryWriteBuffer limit is exhausted"; } + }; + + /// Use max_total_size_ = 0 for unlimited storage + explicit MemoryWriteBuffer( + size_t max_total_size_ = 0, + size_t initial_chunk_size_ = DBMS_DEFAULT_BUFFER_SIZE, + double growth_rate_ = 2.0, + size_t max_chunk_size_ = 128 * DBMS_DEFAULT_BUFFER_SIZE); + + ~MemoryWriteBuffer() override; + +protected: + + void nextImpl() override; + + void finalizeImpl() override { /* no op */ } + + std::shared_ptr<ReadBuffer> getReadBufferImpl() override; + + const size_t max_total_size; + const size_t initial_chunk_size; + const size_t max_chunk_size; + const double growth_rate; + + using Container = std::forward_list<BufferBase::Buffer>; + + Container chunk_list; + Container::iterator chunk_tail; + size_t total_chunks_size = 0; + + void addChunk(); + + friend class ReadBufferFromMemoryWriteBuffer; +}; + + +} diff --git a/contrib/clickhouse/src/IO/MySQLBinlogEventReadBuffer.cpp b/contrib/clickhouse/src/IO/MySQLBinlogEventReadBuffer.cpp new 
file mode 100644 index 0000000000..9f05c5b5e0 --- /dev/null +++ b/contrib/clickhouse/src/IO/MySQLBinlogEventReadBuffer.cpp @@ -0,0 +1,80 @@ +#include <IO/MySQLBinlogEventReadBuffer.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +MySQLBinlogEventReadBuffer::MySQLBinlogEventReadBuffer(ReadBuffer & in_, size_t checksum_signature_length_) + : ReadBuffer(nullptr, 0, 0), in(in_), checksum_signature_length(checksum_signature_length_) +{ + if (checksum_signature_length > MAX_CHECKSUM_SIGNATURE_LENGTH) + throw Exception(ErrorCodes::LOGICAL_ERROR, + "LOGICAL ERROR: checksum_signature_length must be less than MAX_CHECKSUM_SIGNATURE_LENGTH. " + "It is a bug."); + + nextIfAtEnd(); +} + +bool MySQLBinlogEventReadBuffer::nextImpl() +{ + if (hasPendingData()) + return true; + + if (in.eof()) + return false; + + if (checksum_buff_size == checksum_buff_limit) + { + if (likely(in.available() > checksum_signature_length)) + { + working_buffer = ReadBuffer::Buffer(in.position(), in.buffer().end() - checksum_signature_length); + in.ignore(working_buffer.size()); + return true; + } + + in.readStrict(checksum_buf, checksum_signature_length); + checksum_buff_size = checksum_buff_limit = checksum_signature_length; + } + else + { + for (size_t index = 0; index < checksum_buff_size - checksum_buff_limit; ++index) + checksum_buf[index] = checksum_buf[checksum_buff_limit + index]; + + checksum_buff_size -= checksum_buff_limit; + size_t read_bytes = checksum_signature_length - checksum_buff_size; + in.readStrict(checksum_buf + checksum_buff_size, read_bytes); /// Minimum checksum_signature_length bytes + checksum_buff_size = checksum_buff_limit = checksum_signature_length; + } + + if (in.eof()) + return false; + + if (in.available() < checksum_signature_length) + { + size_t left_move_size = checksum_signature_length - in.available(); + checksum_buff_limit = checksum_buff_size - left_move_size; + } + + working_buffer = ReadBuffer::Buffer(checksum_buf, checksum_buf + checksum_buff_limit); + return true; +} + +MySQLBinlogEventReadBuffer::~MySQLBinlogEventReadBuffer() +{ + try + { + /// ignore last checksum_signature_length bytes + nextIfAtEnd(); + } + catch (...) 
+ { + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + +} diff --git a/contrib/clickhouse/src/IO/MySQLBinlogEventReadBuffer.h b/contrib/clickhouse/src/IO/MySQLBinlogEventReadBuffer.h new file mode 100644 index 0000000000..7212a54884 --- /dev/null +++ b/contrib/clickhouse/src/IO/MySQLBinlogEventReadBuffer.h @@ -0,0 +1,30 @@ +#pragma once + +#include <IO/ReadBuffer.h> + +namespace DB +{ + +class MySQLBinlogEventReadBuffer : public ReadBuffer +{ +protected: + static const size_t MAX_CHECKSUM_SIGNATURE_LENGTH = 4; + + ReadBuffer & in; + size_t checksum_signature_length = 0; + + size_t checksum_buff_size = 0; + size_t checksum_buff_limit = 0; + char checksum_buf[MAX_CHECKSUM_SIGNATURE_LENGTH]{}; + + bool nextImpl() override; + +public: + ~MySQLBinlogEventReadBuffer() override; + + MySQLBinlogEventReadBuffer(ReadBuffer & in_, size_t checksum_signature_length_); + +}; + + +} diff --git a/contrib/clickhouse/src/IO/MySQLPacketPayloadReadBuffer.cpp b/contrib/clickhouse/src/IO/MySQLPacketPayloadReadBuffer.cpp new file mode 100644 index 0000000000..2c5167ed03 --- /dev/null +++ b/contrib/clickhouse/src/IO/MySQLPacketPayloadReadBuffer.cpp @@ -0,0 +1,61 @@ +#include <IO/MySQLPacketPayloadReadBuffer.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int UNKNOWN_PACKET_FROM_CLIENT; +} + +const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb + +MySQLPacketPayloadReadBuffer::MySQLPacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_) + : ReadBuffer(in_.position(), 0), in(in_), sequence_id(sequence_id_) // not in.buffer().begin(), because working buffer may include previous packet +{ +} + +bool MySQLPacketPayloadReadBuffer::nextImpl() +{ + if (!has_read_header || (payload_length == MAX_PACKET_LENGTH && offset == payload_length)) + { + has_read_header = true; + working_buffer.resize(0); + offset = 0; + payload_length = 0; + in.readStrict(reinterpret_cast<char *>(&payload_length), 3); + + if (payload_length > MAX_PACKET_LENGTH) + throw Exception(ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT, + "Received packet with payload larger than max_packet_size: {}", payload_length); + + size_t packet_sequence_id = 0; + in.readStrict(reinterpret_cast<char &>(packet_sequence_id)); + if (packet_sequence_id != sequence_id) + throw Exception(ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT, + "Received packet with wrong sequence-id: {}. Expected: {}.", packet_sequence_id, static_cast<unsigned int>(sequence_id)); + sequence_id++; + + if (payload_length == 0) + return false; + } + else if (offset == payload_length) + { + return false; + } + + in.nextIfAtEnd(); + /// Don't return a buffer when no bytes available + if (!in.hasPendingData()) + return false; + working_buffer = ReadBuffer::Buffer(in.position(), in.buffer().end()); + size_t count = std::min(in.available(), payload_length - offset); + working_buffer.resize(count); + in.ignore(count); + + offset += count; + + return true; +} + +} diff --git a/contrib/clickhouse/src/IO/MySQLPacketPayloadReadBuffer.h b/contrib/clickhouse/src/IO/MySQLPacketPayloadReadBuffer.h new file mode 100644 index 0000000000..f90a34ba93 --- /dev/null +++ b/contrib/clickhouse/src/IO/MySQLPacketPayloadReadBuffer.h @@ -0,0 +1,33 @@ +#pragma once + +#include <IO/ReadBuffer.h> + +namespace DB +{ + +/** Reading packets. + * Internally, it calls (if no more data) next() method of the underlying ReadBufferFromPocoSocket, and sets the working buffer to the rest part of the current packet payload. 
+ */ +class MySQLPacketPayloadReadBuffer : public ReadBuffer +{ +private: + ReadBuffer & in; + uint8_t & sequence_id; + + bool has_read_header = false; + + // Size of packet which is being read now. + size_t payload_length = 0; + + // Offset in packet payload. + size_t offset = 0; + +protected: + bool nextImpl() override; + +public: + MySQLPacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_); +}; + +} + diff --git a/contrib/clickhouse/src/IO/MySQLPacketPayloadWriteBuffer.cpp b/contrib/clickhouse/src/IO/MySQLPacketPayloadWriteBuffer.cpp new file mode 100644 index 0000000000..425e1b8d08 --- /dev/null +++ b/contrib/clickhouse/src/IO/MySQLPacketPayloadWriteBuffer.cpp @@ -0,0 +1,61 @@ +#include <IO/MySQLPacketPayloadWriteBuffer.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_WRITE_AFTER_END_OF_BUFFER; +} + +const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb + +MySQLPacketPayloadWriteBuffer::MySQLPacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_) + : WriteBuffer(out_.position(), 0), out(out_), sequence_id(sequence_id_), total_left(payload_length_) +{ + startNewPacket(); + setWorkingBuffer(); + pos = out.position(); +} + +void MySQLPacketPayloadWriteBuffer::startNewPacket() +{ + payload_length = std::min(total_left, MAX_PACKET_LENGTH); + bytes_written = 0; + total_left -= payload_length; + + out.write(reinterpret_cast<char *>(&payload_length), 3); + out.write(sequence_id++); + bytes += 4; +} + +void MySQLPacketPayloadWriteBuffer::setWorkingBuffer() +{ + out.nextIfAtEnd(); + working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available())); + + if (payload_length - bytes_written == 0) + { + /// Finished writing packet. Due to an implementation of WriteBuffer, working_buffer cannot be empty. Further write attempts will throw Exception. + eof = true; + working_buffer.resize(1); + } +} + +void MySQLPacketPayloadWriteBuffer::nextImpl() +{ + size_t written = pos - working_buffer.begin(); + if (eof) + throw Exception(ErrorCodes::CANNOT_WRITE_AFTER_END_OF_BUFFER, "Cannot write after end of buffer."); + + out.position() += written; + bytes_written += written; + + /// Packets of size greater than MAX_PACKET_LENGTH are split into few packets of size MAX_PACKET_LENGTH and las packet of size < MAX_PACKET_LENGTH. + if (bytes_written == payload_length && (total_left > 0 || payload_length == MAX_PACKET_LENGTH)) + startNewPacket(); + + setWorkingBuffer(); +} + +} diff --git a/contrib/clickhouse/src/IO/MySQLPacketPayloadWriteBuffer.h b/contrib/clickhouse/src/IO/MySQLPacketPayloadWriteBuffer.h new file mode 100644 index 0000000000..d4ce8a8955 --- /dev/null +++ b/contrib/clickhouse/src/IO/MySQLPacketPayloadWriteBuffer.h @@ -0,0 +1,36 @@ +#pragma once + +#include <IO/WriteBuffer.h> + +namespace DB +{ + +/** Writing packets. + * https://dev.mysql.com/doc/internals/en/mysql-packet.html + */ +class MySQLPacketPayloadWriteBuffer : public WriteBuffer +{ +public: + MySQLPacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_); + + bool remainingPayloadSize() const { return total_left; } + +protected: + void nextImpl() override; + +private: + WriteBuffer & out; + uint8_t & sequence_id; + + size_t total_left = 0; + size_t payload_length = 0; + size_t bytes_written = 0; + bool eof = false; + + void startNewPacket(); + + /// Sets working buffer to the rest of current packet payload. 
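An illustrative sketch (not from the patch) of the write side: the caller states the total payload length up front and the buffer splits it into MySQL packets of at most 2^24 - 1 bytes, itself emitting the 3-byte length and 1-byte sequence-id header of each packet via startNewPacket():

    #include <IO/MySQLPacketPayloadWriteBuffer.h>
    #include <string>

    void writePayload(DB::WriteBuffer & out, const std::string & payload)
    {
        uint8_t sequence_id = 0;
        DB::MySQLPacketPayloadWriteBuffer packet_out(out, payload.size(), sequence_id);
        packet_out.write(payload.data(), payload.size());
        packet_out.next(); /// flush the tail of the last packet into `out`
    }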
+ void setWorkingBuffer(); +}; + +} diff --git a/contrib/clickhouse/src/IO/NullWriteBuffer.cpp b/contrib/clickhouse/src/IO/NullWriteBuffer.cpp new file mode 100644 index 0000000000..295c53ef7c --- /dev/null +++ b/contrib/clickhouse/src/IO/NullWriteBuffer.cpp @@ -0,0 +1,16 @@ +#include <IO/NullWriteBuffer.h> + + +namespace DB +{ + +NullWriteBuffer::NullWriteBuffer() + : WriteBuffer(data, sizeof(data)) +{ +} + +void NullWriteBuffer::nextImpl() +{ +} + +} diff --git a/contrib/clickhouse/src/IO/NullWriteBuffer.h b/contrib/clickhouse/src/IO/NullWriteBuffer.h new file mode 100644 index 0000000000..f14c74ff72 --- /dev/null +++ b/contrib/clickhouse/src/IO/NullWriteBuffer.h @@ -0,0 +1,21 @@ +#pragma once + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> +#include <boost/noncopyable.hpp> + +namespace DB +{ + +/// Simply do nothing, can be used to measure amount of written bytes. +class NullWriteBuffer : public WriteBuffer, boost::noncopyable +{ +public: + NullWriteBuffer(); + void nextImpl() override; + +private: + char data[128]; +}; + +} diff --git a/contrib/clickhouse/src/IO/OpenedFile.cpp b/contrib/clickhouse/src/IO/OpenedFile.cpp new file mode 100644 index 0000000000..b75e087e5c --- /dev/null +++ b/contrib/clickhouse/src/IO/OpenedFile.cpp @@ -0,0 +1,77 @@ +#include <mutex> +#include <unistd.h> +#include <fcntl.h> + +#include <Common/ProfileEvents.h> +#include <Common/Exception.h> +#include <IO/OpenedFile.h> + + +namespace ProfileEvents +{ + extern const Event FileOpen; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int FILE_DOESNT_EXIST; + extern const int CANNOT_OPEN_FILE; + extern const int CANNOT_CLOSE_FILE; +} + + +void OpenedFile::open() const +{ + ProfileEvents::increment(ProfileEvents::FileOpen); + + fd = ::open(file_name.c_str(), (flags == -1 ? 0 : flags) | O_RDONLY | O_CLOEXEC); + + if (-1 == fd) + throwFromErrnoWithPath("Cannot open file " + file_name, file_name, + errno == ENOENT ? ErrorCodes::FILE_DOESNT_EXIST : ErrorCodes::CANNOT_OPEN_FILE); +} + +int OpenedFile::getFD() const +{ + std::lock_guard l(mutex); + if (fd == -1) + open(); + return fd; +} + +std::string OpenedFile::getFileName() const +{ + return file_name; +} + + +OpenedFile::OpenedFile(const std::string & file_name_, int flags_) + : file_name(file_name_), flags(flags_) +{ +} + + +OpenedFile::~OpenedFile() +{ + close(); /// Exceptions will lead to std::terminate and that's Ok. +} + + +void OpenedFile::close() +{ + std::lock_guard l(mutex); + if (fd == -1) + return; + + if (0 != ::close(fd)) + throw Exception(ErrorCodes::CANNOT_CLOSE_FILE, "Cannot close file"); + + fd = -1; + metric_increment.destroy(); +} + +} + diff --git a/contrib/clickhouse/src/IO/OpenedFile.h b/contrib/clickhouse/src/IO/OpenedFile.h new file mode 100644 index 0000000000..10c36d9e1d --- /dev/null +++ b/contrib/clickhouse/src/IO/OpenedFile.h @@ -0,0 +1,43 @@ +#pragma once + +#include <Common/CurrentMetrics.h> +#include <memory> +#include <mutex> + + +namespace CurrentMetrics +{ + extern const Metric OpenFileForRead; +} + + +namespace DB +{ + +/// RAII for readonly opened file descriptor. +class OpenedFile +{ +public: + OpenedFile(const std::string & file_name_, int flags_); + ~OpenedFile(); + + /// Close prematurally. 
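As an aside (a sketch, not from the patch), OpenedFile instances are normally obtained through the OpenedFileCache defined below, which shares one read-only descriptor per (path, flags) pair among concurrent readers; flags = -1 means "no extra flags", as handled in OpenedFile::open():

    #include <IO/OpenedFileCache.h>
    #include <string>

    DB::OpenedFileCache::OpenedFilePtr openShared(const std::string & path)
    {
        /// The returned shared_ptr keeps the descriptor alive; the cache hands the same
        /// OpenedFile to every concurrent caller asking for this (path, flags) pair.
        return DB::OpenedFileCache::instance().get(path, /*flags=*/ -1);
    }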
+ void close(); + + int getFD() const; + std::string getFileName() const; + +private: + std::string file_name; + int flags = 0; + + mutable int fd = -1; + mutable std::mutex mutex; + + CurrentMetrics::Increment metric_increment{CurrentMetrics::OpenFileForRead}; + + void open() const; +}; + +} + diff --git a/contrib/clickhouse/src/IO/OpenedFileCache.h b/contrib/clickhouse/src/IO/OpenedFileCache.h new file mode 100644 index 0000000000..2cecc675af --- /dev/null +++ b/contrib/clickhouse/src/IO/OpenedFileCache.h @@ -0,0 +1,116 @@ +#pragma once + +#include <map> +#include <mutex> + +#include <Core/Types.h> +#include <IO/OpenedFile.h> +#include <Common/ElapsedTimeProfileEventIncrement.h> +#include <Common/ProfileEvents.h> + +#include <city.h> + + +namespace ProfileEvents +{ + extern const Event OpenedFileCacheHits; + extern const Event OpenedFileCacheMisses; + extern const Event OpenedFileCacheMicroseconds; +} + +namespace DB +{ + + +/** Cache of opened files for reading. + * It allows to share file descriptors when doing reading with 'pread' syscalls on readonly files. + * Note: open/close of files is very cheap on Linux and we should not bother doing it 10 000 times a second. + * (This may not be the case on Windows with WSL. This is also not the case if strace is active. Neither when some eBPF is loaded). + * But sometimes we may end up opening one file multiple times, that increases chance exhausting opened files limit. + */ +class OpenedFileCache +{ + class OpenedFileMap + { + using Key = std::pair<std::string /* path */, int /* flags */>; + + using OpenedFileWeakPtr = std::weak_ptr<OpenedFile>; + using Files = std::map<Key, OpenedFileWeakPtr>; + + Files files; + std::mutex mutex; + + public: + using OpenedFilePtr = std::shared_ptr<OpenedFile>; + + OpenedFilePtr get(const std::string & path, int flags) + { + Key key(path, flags); + + std::lock_guard lock(mutex); + + auto [it, inserted] = files.emplace(key, OpenedFilePtr{}); + if (!inserted) + { + if (auto res = it->second.lock()) + { + ProfileEvents::increment(ProfileEvents::OpenedFileCacheHits); + return res; + } + } + ProfileEvents::increment(ProfileEvents::OpenedFileCacheMisses); + + OpenedFilePtr res + { + new OpenedFile(path, flags), + [key, this](auto ptr) + { + { + std::lock_guard another_lock(mutex); + files.erase(key); + } + delete ptr; + } + }; + + it->second = res; + return res; + } + + void remove(const std::string & path, int flags) + { + Key key(path, flags); + std::lock_guard lock(mutex); + files.erase(key); + } + }; + + static constexpr size_t buckets = 1024; + std::vector<OpenedFileMap> impls{buckets}; + +public: + using OpenedFilePtr = OpenedFileMap::OpenedFilePtr; + + OpenedFilePtr get(const std::string & path, int flags) + { + ProfileEventTimeIncrement<Microseconds> watch(ProfileEvents::OpenedFileCacheMicroseconds); + const auto bucket = CityHash_v1_0_2::CityHash64(path.data(), path.length()) % buckets; + return impls[bucket].get(path, flags); + } + + void remove(const std::string & path, int flags) + { + ProfileEventTimeIncrement<Microseconds> watch(ProfileEvents::OpenedFileCacheMicroseconds); + const auto bucket = CityHash_v1_0_2::CityHash64(path.data(), path.length()) % buckets; + impls[bucket].remove(path, flags); + } + + static OpenedFileCache & instance() + { + static OpenedFileCache res; + return res; + } +}; + +using OpenedFileCachePtr = std::shared_ptr<OpenedFileCache>; +} diff --git a/contrib/clickhouse/src/IO/Operators.h b/contrib/clickhouse/src/IO/Operators.h new file mode 100644 index 0000000000..185745e841 
--- /dev/null +++ b/contrib/clickhouse/src/IO/Operators.h @@ -0,0 +1,98 @@ +#pragma once + +#include <IO/ReadBuffer.h> +#include <IO/WriteBuffer.h> +#include <IO/ReadHelpers.h> +#include <IO/WriteHelpers.h> + +#include <functional> + + +namespace DB +{ + +/** Implements the ability to write and read data in/from WriteBuffer/ReadBuffer + * with the help of << and >> operators and also manipulators, + * providing a way of using, similar to iostreams. + * + * It is neither a subset nor an extension of iostreams. + * + * Example usage: + * + * DB::WriteBufferFromFileDescriptor buf(STDOUT_FILENO); + * buf << DB::double_quote << "Hello, world!" << '\n' << DB::flush; + * + * Outputs `char` type (usually it's Int8) as a symbol, not as a number. + */ + +/// Manipulators. +enum EscapeManip { escape }; /// For strings - escape special characters. In the rest, as usual. +enum QuoteManip { quote }; /// For strings, dates, datetimes - enclose in single quotes with escaping. In the rest, as usual. +enum DoubleQuoteManip { double_quote }; /// For strings, dates, datetimes - enclose in double quotes with escaping. In the rest, as usual. +enum BinaryManip { binary }; /// Output in binary format. +enum XMLManip { xml }; /// Output strings with XML escaping. + +struct EscapeManipWriteBuffer : std::reference_wrapper<WriteBuffer> { using std::reference_wrapper<WriteBuffer>::reference_wrapper; }; +struct QuoteManipWriteBuffer : std::reference_wrapper<WriteBuffer> { using std::reference_wrapper<WriteBuffer>::reference_wrapper; }; +struct DoubleQuoteManipWriteBuffer : std::reference_wrapper<WriteBuffer> { using std::reference_wrapper<WriteBuffer>::reference_wrapper; }; +struct BinaryManipWriteBuffer : std::reference_wrapper<WriteBuffer> { using std::reference_wrapper<WriteBuffer>::reference_wrapper; }; +struct XMLManipWriteBuffer : std::reference_wrapper<WriteBuffer> { using std::reference_wrapper<WriteBuffer>::reference_wrapper; }; + +struct EscapeManipReadBuffer : std::reference_wrapper<ReadBuffer> { using std::reference_wrapper<ReadBuffer>::reference_wrapper; }; +struct QuoteManipReadBuffer : std::reference_wrapper<ReadBuffer> { using std::reference_wrapper<ReadBuffer>::reference_wrapper; }; +struct DoubleQuoteManipReadBuffer : std::reference_wrapper<ReadBuffer> { using std::reference_wrapper<ReadBuffer>::reference_wrapper; }; +struct BinaryManipReadBuffer : std::reference_wrapper<ReadBuffer> { using std::reference_wrapper<ReadBuffer>::reference_wrapper; }; + +inline WriteBuffer & operator<<(WriteBuffer & buf, const auto & x) { writeText(x, buf); return buf; } +inline WriteBuffer & operator<<(WriteBuffer & buf, const pcg32_fast & x) { PcgSerializer::serializePcg32(x, buf); return buf; } + +inline EscapeManipWriteBuffer operator<< (WriteBuffer & buf, EscapeManip) { return buf; } +inline QuoteManipWriteBuffer operator<< (WriteBuffer & buf, QuoteManip) { return buf; } +inline DoubleQuoteManipWriteBuffer operator<< (WriteBuffer & buf, DoubleQuoteManip) { return buf; } +inline BinaryManipWriteBuffer operator<< (WriteBuffer & buf, BinaryManip) { return buf; } +inline XMLManipWriteBuffer operator<< (WriteBuffer & buf, XMLManip) { return buf; } + +template <typename T> WriteBuffer & operator<< (EscapeManipWriteBuffer buf, const T & x) { writeText(x, buf.get()); return buf; } +template <typename T> WriteBuffer & operator<< (QuoteManipWriteBuffer buf, const T & x) { writeQuoted(x, buf.get()); return buf; } +template <typename T> WriteBuffer & operator<< (DoubleQuoteManipWriteBuffer buf, const T & x) { 
writeDoubleQuoted(x, buf.get()); return buf; } +template <typename T> WriteBuffer & operator<< (BinaryManipWriteBuffer buf, const T & x) { writeBinary(x, buf.get()); return buf; } +template <typename T> WriteBuffer & operator<< (XMLManipWriteBuffer buf, const T & x) { writeText(x, buf.get()); return buf; } + +inline WriteBuffer & operator<< (EscapeManipWriteBuffer buf, const String & x) { writeEscapedString(x, buf); return buf; } +inline WriteBuffer & operator<< (EscapeManipWriteBuffer buf, std::string_view x) { writeEscapedString(x, buf); return buf; } +inline WriteBuffer & operator<< (EscapeManipWriteBuffer buf, StringRef x) { writeEscapedString(x.toView(), buf); return buf; } +inline WriteBuffer & operator<< (EscapeManipWriteBuffer buf, const char * x) { writeEscapedString(x, strlen(x), buf); return buf; } + +inline WriteBuffer & operator<< (QuoteManipWriteBuffer buf, const char * x) { writeAnyQuotedString<'\''>(x, x + strlen(x), buf.get()); return buf; } +inline WriteBuffer & operator<< (DoubleQuoteManipWriteBuffer buf, const char * x) { writeAnyQuotedString<'"'>(x, x + strlen(x), buf.get()); return buf; } +inline WriteBuffer & operator<< (BinaryManipWriteBuffer buf, const char * x) { writeStringBinary(x, buf.get()); return buf; } + +inline WriteBuffer & operator<< (XMLManipWriteBuffer buf, std::string_view x) { writeXMLStringForTextElementOrAttributeValue(x, buf); return buf; } +inline WriteBuffer & operator<< (XMLManipWriteBuffer buf, StringRef x) { writeXMLStringForTextElementOrAttributeValue(x.toView(), buf); return buf; } +inline WriteBuffer & operator<< (XMLManipWriteBuffer buf, const char * x) { writeXMLStringForTextElementOrAttributeValue(std::string_view(x), buf); return buf; } + +/// The manipulator calls the WriteBuffer method `next` - this makes the buffer reset. For nested buffers, the reset is not recursive. +enum FlushManip { flush }; + +inline WriteBuffer & operator<< (WriteBuffer & buf, FlushManip) { buf.next(); return buf; } + + +template <typename T> ReadBuffer & operator>> (ReadBuffer & buf, T & x) { readText(x, buf); return buf; } +template <> inline ReadBuffer & operator>> (ReadBuffer & buf, String & x) { readString(x, buf); return buf; } +template <> inline ReadBuffer & operator>> (ReadBuffer & buf, char & x) { readChar(x, buf); return buf; } +template <> inline ReadBuffer & operator>> (ReadBuffer & buf, pcg32_fast & x) { PcgDeserializer::deserializePcg32(x, buf); return buf; } + +/// If you specify a string literal for reading, this will mean - make sure there is a sequence of bytes and skip it. 
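/// For example (illustrative), parsing input of the form "rows: <number>\n":
///
///     UInt64 rows = 0;
///     in >> "rows: " >> rows >> "\n";
///
/// This throws if the input does not start with the literal "rows: " or if the number is not followed by a newline.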
+inline ReadBuffer & operator>> (ReadBuffer & buf, const char * x) { assertString(x, buf); return buf; } + +inline EscapeManipReadBuffer operator>> (ReadBuffer & buf, EscapeManip) { return buf; } +inline QuoteManipReadBuffer operator>> (ReadBuffer & buf, QuoteManip) { return buf; } +inline DoubleQuoteManipReadBuffer operator>> (ReadBuffer & buf, DoubleQuoteManip) { return buf; } +inline BinaryManipReadBuffer operator>> (ReadBuffer & buf, BinaryManip) { return buf; } + +template <typename T> ReadBuffer & operator>> (EscapeManipReadBuffer buf, T & x) { readText(x, buf.get()); return buf; } +template <typename T> ReadBuffer & operator>> (QuoteManipReadBuffer buf, T & x) { readQuoted(x, buf.get()); return buf; } +template <typename T> ReadBuffer & operator>> (DoubleQuoteManipReadBuffer buf, T & x) { readDoubleQuoted(x, buf.get()); return buf; } +template <typename T> ReadBuffer & operator>> (BinaryManipReadBuffer buf, T & x) { readBinary(x, buf.get()); return buf; } + +} diff --git a/contrib/clickhouse/src/IO/ParallelReadBuffer.cpp b/contrib/clickhouse/src/IO/ParallelReadBuffer.cpp new file mode 100644 index 0000000000..8d73f22174 --- /dev/null +++ b/contrib/clickhouse/src/IO/ParallelReadBuffer.cpp @@ -0,0 +1,307 @@ +#include <IO/ParallelReadBuffer.h> +#include <IO/SharedThreadPools.h> +#include <Poco/Logger.h> +#include <Common/logger_useful.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int UNEXPECTED_END_OF_FILE; + extern const int CANNOT_SEEK_THROUGH_FILE; + extern const int SEEK_POSITION_OUT_OF_BOUND; + +} + +// A subrange of the input, read by one thread. +struct ParallelReadBuffer::ReadWorker +{ + ReadWorker(SeekableReadBuffer & input_, size_t offset, size_t size) + : input(input_), start_offset(offset), segment(size) + { + chassert(size); + chassert(segment.size() == size); + } + + bool hasBytesToConsume() const { return bytes_produced > bytes_consumed; } + bool hasBytesToProduce() const { return bytes_produced < segment.size(); } + + SeekableReadBuffer & input; + const size_t start_offset; // start of the segment + + Memory<> segment; + /// Reader thread produces data, nextImpl() consumes it. + /// segment[bytes_consumed..bytes_produced-1] is data waiting to be picked up by nextImpl() + /// segment[bytes_produced..] 
needs to be read from the input ReadBuffer + size_t bytes_produced = 0; + size_t bytes_consumed = 0; + + std::atomic_bool cancel{false}; + std::mutex worker_mutex; +}; + +ParallelReadBuffer::ParallelReadBuffer( + SeekableReadBuffer & input_, ThreadPoolCallbackRunner<void> schedule_, size_t max_working_readers_, size_t range_step_, size_t file_size_) + : SeekableReadBuffer(nullptr, 0) + , max_working_readers(max_working_readers_) + , schedule(std::move(schedule_)) + , input(input_) + , file_size(file_size_) + , range_step(std::max(1ul, range_step_)) +{ + LOG_TRACE(&Poco::Logger::get("ParallelReadBuffer"), "Parallel reading is used"); + + try + { + addReaders(); + } + catch (const Exception &) + { + finishAndWait(); + throw; + } +} + +bool ParallelReadBuffer::addReaderToPool() +{ + if (next_range_start >= file_size) + return false; + size_t range_start = next_range_start; + size_t size = std::min(range_step, file_size - range_start); + next_range_start += size; + + auto worker = read_workers.emplace_back(std::make_shared<ReadWorker>(input, range_start, size)); + + ++active_working_readers; + schedule([this, my_worker = std::move(worker)]() mutable { readerThreadFunction(std::move(my_worker)); }, Priority{}); + + return true; +} + +void ParallelReadBuffer::addReaders() +{ + while (read_workers.size() < max_working_readers && addReaderToPool()) + ; +} + +off_t ParallelReadBuffer::seek(off_t offset, int whence) +{ + if (whence != SEEK_SET) + throw Exception(ErrorCodes::CANNOT_SEEK_THROUGH_FILE, "Only SEEK_SET mode is allowed."); + + if (offset < 0) + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Seek position is out of bounds. Offset: {}", offset); + + if (!working_buffer.empty() && static_cast<size_t>(offset) >= current_position - working_buffer.size() && offset < current_position) + { + pos = working_buffer.end() - (current_position - offset); + assert(pos >= working_buffer.begin()); + assert(pos <= working_buffer.end()); + + return offset; + } + + const auto offset_is_in_range + = [&](const auto & worker) { return static_cast<size_t>(offset) >= worker->start_offset && static_cast<size_t>(offset) < worker->start_offset + worker->segment.size(); }; + + while (!read_workers.empty() && !offset_is_in_range(read_workers.front())) + { + read_workers.front()->cancel = true; + read_workers.pop_front(); + } + + if (!read_workers.empty()) + { + auto & w = read_workers.front(); + size_t diff = static_cast<size_t>(offset) - w->start_offset; + while (true) + { + std::unique_lock lock{w->worker_mutex}; + + if (emergency_stop) + handleEmergencyStop(); + + if (w->bytes_produced > diff) + { + working_buffer = internal_buffer = Buffer( + w->segment.data(), w->segment.data() + w->bytes_produced); + pos = working_buffer.begin() + diff; + w->bytes_consumed = w->bytes_produced; + current_position = w->start_offset + w->bytes_consumed; + addReaders(); + return offset; + } + + next_condvar.wait_for(lock, std::chrono::seconds(10)); + } + } + + finishAndWait(); + + read_workers.clear(); + + next_range_start = offset; + current_position = offset; + resetWorkingBuffer(); + + emergency_stop = false; + + addReaders(); + return offset; +} + +size_t ParallelReadBuffer::getFileSize() +{ + return file_size; +} + +off_t ParallelReadBuffer::getPosition() +{ + return current_position - available(); +} + +void ParallelReadBuffer::handleEmergencyStop() +{ + // this can only be called from the main thread when there is an exception + assert(background_exception); + std::rethrow_exception(background_exception); +} + +bool 
ParallelReadBuffer::nextImpl() +{ + while (true) + { + /// All readers processed, stop + if (read_workers.empty()) + { + chassert(next_range_start >= file_size); + return false; + } + + auto * w = read_workers.front().get(); + + std::unique_lock lock{w->worker_mutex}; + + if (emergency_stop) + handleEmergencyStop(); // throws + + /// Read data from front reader + if (w->bytes_produced > w->bytes_consumed) + { + chassert(w->start_offset + w->bytes_consumed == static_cast<size_t>(current_position)); + + working_buffer = internal_buffer = Buffer( + w->segment.data() + w->bytes_consumed, w->segment.data() + w->bytes_produced); + current_position += working_buffer.size(); + w->bytes_consumed = w->bytes_produced; + + return true; + } + + /// Front reader is done, remove it and add another + if (!w->hasBytesToProduce()) + { + lock.unlock(); + read_workers.pop_front(); + addReaders(); + + continue; + } + + /// Nothing to do right now, wait for something to change. + /// + /// The timeout is a workaround for a race condition. + /// emergency_stop is assigned while holding a *different* mutex from the one we're holding + /// (exception_mutex vs worker_mutex). So it's possible that our emergency_stop check (above) + /// happens before a onBackgroundException() call, but our wait(lock) happens after it. + /// Then the wait may get stuck forever. + /// + /// Note that using wait(lock, [&]{ return emergency_stop || ...; }) wouldn't help because + /// it does effectively the same "check, then wait" sequence. + /// + /// One possible proper fix would be to make onBackgroundException() lock all read_workers + /// mutexes too (not necessarily simultaneously - just locking+unlocking them one by one + /// between the emergency_stop change and the notify_all() would be enough), but then we + /// need another mutex to protect read_workers itself... + next_condvar.wait_for(lock, std::chrono::seconds(10)); + } + chassert(false); + return false; +} + +void ParallelReadBuffer::readerThreadFunction(ReadWorkerPtr read_worker) +{ + SCOPE_EXIT({ + if (active_working_readers.fetch_sub(1) == 1) + active_working_readers.notify_all(); + }); + + try + { + auto on_progress = [&](size_t bytes_read) -> bool + { + if (emergency_stop || read_worker->cancel) + return true; + + std::lock_guard lock(read_worker->worker_mutex); + if (bytes_read <= read_worker->bytes_produced) + return false; + + bool need_notify = read_worker->bytes_produced == read_worker->bytes_consumed; + read_worker->bytes_produced = bytes_read; + if (need_notify) + next_condvar.notify_all(); + + return false; + }; + + size_t r = input.readBigAt(read_worker->segment.data(), read_worker->segment.size(), read_worker->start_offset, on_progress); + + if (!on_progress(r) && r < read_worker->segment.size()) + throw Exception( + ErrorCodes::UNEXPECTED_END_OF_FILE, + "Failed to read all the data from the reader at offset {}, got {}/{} bytes", + read_worker->start_offset, r, read_worker->segment.size()); + } + catch (...) 
+ { + onBackgroundException(); + } +} + +void ParallelReadBuffer::onBackgroundException() +{ + std::lock_guard lock{exception_mutex}; + if (!background_exception) + background_exception = std::current_exception(); + + emergency_stop = true; + next_condvar.notify_all(); +} + +void ParallelReadBuffer::finishAndWait() +{ + emergency_stop = true; + + size_t active_readers = active_working_readers.load(); + while (active_readers != 0) + { + active_working_readers.wait(active_readers); + active_readers = active_working_readers.load(); + } +} + +std::unique_ptr<ParallelReadBuffer> wrapInParallelReadBufferIfSupported( + ReadBuffer & buf, ThreadPoolCallbackRunner<void> schedule, size_t max_working_readers, + size_t range_step, size_t file_size) +{ + auto * seekable = dynamic_cast<SeekableReadBuffer*>(&buf); + if (!seekable || !seekable->supportsReadAt()) + return nullptr; + + return std::make_unique<ParallelReadBuffer>( + *seekable, schedule, max_working_readers, range_step, file_size); +} + +} diff --git a/contrib/clickhouse/src/IO/ParallelReadBuffer.h b/contrib/clickhouse/src/IO/ParallelReadBuffer.h new file mode 100644 index 0000000000..e76b40f77b --- /dev/null +++ b/contrib/clickhouse/src/IO/ParallelReadBuffer.h @@ -0,0 +1,100 @@ +#pragma once + +#include <IO/BufferWithOwnMemory.h> +#include <IO/ReadBuffer.h> +#include <IO/SeekableReadBuffer.h> +#include <Interpreters/threadPoolCallbackRunner.h> +#include <Common/ArenaWithFreeLists.h> + +namespace DB +{ + +/** + * Reads from multiple positions in a ReadBuffer in parallel. + * Then reassembles the data into one stream in the original order. + * + * Each working reader reads its segment of data into a buffer. + * + * ParallelReadBuffer in nextImpl method take first available segment from first reader in deque and reports it it to user. + * When first reader finishes reading, they will be removed from worker deque and data from next reader consumed. + * + * Number of working readers limited by max_working_readers. 
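 *
 * A usage sketch (illustrative; `source`, `pool` and `file_size` are assumed to exist, and the
 * threadPoolCallbackRunner helper comes from Interpreters/threadPoolCallbackRunner.h):
 *
 *     /// `source` is a SeekableReadBuffer with supportsReadAt() == true
 *     ThreadPoolCallbackRunner<void> schedule = threadPoolCallbackRunner<void>(pool, "ParallelRead");
 *     auto parallel = wrapInParallelReadBufferIfSupported(
 *         source, schedule, /* max_working_readers */ 4, /* range_step */ 8 * 1024 * 1024, file_size);
 *     ReadBuffer * effective = parallel ? parallel.get() : &source;   /// nullptr means parallel reading is unsupported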
+ */ +class ParallelReadBuffer : public SeekableReadBuffer, public WithFileSize +{ +private: + /// Blocks until data occurred in the first reader or this reader indicate finishing + /// Finished readers removed from queue and data from next readers processed + bool nextImpl() override; + +public: + ParallelReadBuffer(SeekableReadBuffer & input, ThreadPoolCallbackRunner<void> schedule_, size_t max_working_readers, size_t range_step_, size_t file_size); + + ~ParallelReadBuffer() override { finishAndWait(); } + + off_t seek(off_t off, int whence) override; + size_t getFileSize() override; + off_t getPosition() override; + + const SeekableReadBuffer & getReadBuffer() const { return input; } + SeekableReadBuffer & getReadBuffer() { return input; } + +private: + /// Reader in progress with a buffer for the segment + struct ReadWorker; + using ReadWorkerPtr = std::shared_ptr<ReadWorker>; + + /// First worker in deque have new data or processed all available amount + bool currentWorkerReady() const; + /// First worker in deque processed and flushed all data + bool currentWorkerCompleted() const; + + [[noreturn]] void handleEmergencyStop(); + + void addReaders(); + bool addReaderToPool(); + + /// Process read_worker, read data and save into the buffer + void readerThreadFunction(ReadWorkerPtr read_worker); + + void onBackgroundException(); + void finishAndWait(); + + size_t max_working_readers; + std::atomic_size_t active_working_readers{0}; + + ThreadPoolCallbackRunner<void> schedule; + + SeekableReadBuffer & input; + size_t file_size; + size_t range_step; + size_t next_range_start{0}; + + /** + * FIFO queue of readers. + * Each worker contains a buffer for the downloaded segment. + * After all data for the segment is read and delivered to the user, the reader will be removed + * from deque and data from next reader will be delivered. + * After removing from deque, call addReaders(). + */ + std::deque<ReadWorkerPtr> read_workers; + + /// Triggered when new data available + std::condition_variable next_condvar; + + std::mutex exception_mutex; + std::exception_ptr background_exception = nullptr; + std::atomic_bool emergency_stop{false}; + + off_t current_position{0}; // end of working_buffer + + bool all_completed{false}; +}; + +/// If `buf` is a SeekableReadBuffer with supportsReadAt() == true, creates a ParallelReadBuffer +/// from it. 
Otherwise returns nullptr; +std::unique_ptr<ParallelReadBuffer> wrapInParallelReadBufferIfSupported( + ReadBuffer & buf, ThreadPoolCallbackRunner<void> schedule, size_t max_working_readers, + size_t range_step, size_t file_size); + +} diff --git a/contrib/clickhouse/src/IO/PeekableReadBuffer.cpp b/contrib/clickhouse/src/IO/PeekableReadBuffer.cpp new file mode 100644 index 0000000000..ce9c20e7a5 --- /dev/null +++ b/contrib/clickhouse/src/IO/PeekableReadBuffer.cpp @@ -0,0 +1,378 @@ +#include <IO/PeekableReadBuffer.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +PeekableReadBuffer::PeekableReadBuffer(ReadBuffer & sub_buf_, size_t start_size_ /*= 0*/) + : BufferWithOwnMemory(start_size_), sub_buf(&sub_buf_) +{ + padded &= sub_buf->isPadded(); + /// Read from sub-buffer + Buffer & sub_working = sub_buf->buffer(); + BufferBase::set(sub_working.begin(), sub_working.size(), sub_buf->offset()); + + checkStateCorrect(); +} + +void PeekableReadBuffer::reset() +{ + checkStateCorrect(); +} + +void PeekableReadBuffer::setSubBuffer(ReadBuffer & sub_buf_) +{ + sub_buf = &sub_buf_; + resetImpl(); +} + +void PeekableReadBuffer::resetImpl() +{ + peeked_size = 0; + checkpoint = std::nullopt; + checkpoint_in_own_memory = false; + use_stack_memory = true; + + if (!currentlyReadFromOwnMemory()) + sub_buf->position() = pos; + + Buffer & sub_working = sub_buf->buffer(); + BufferBase::set(sub_working.begin(), sub_working.size(), sub_buf->offset()); + + checkStateCorrect(); +} + +bool PeekableReadBuffer::peekNext() +{ + checkStateCorrect(); + + Position copy_from = pos; + size_t bytes_to_copy = sub_buf->available(); + if (useSubbufferOnly()) + { + /// Don't have to copy all data from sub-buffer if there is no data in own memory (checkpoint and pos are in sub-buffer) + if (checkpoint) + copy_from = *checkpoint; + bytes_to_copy = sub_buf->buffer().end() - copy_from; + if (!bytes_to_copy) + { + sub_buf->position() = copy_from; + + /// Both checkpoint and pos are at the end of sub-buffer. Just load next part of data. 
+ bool res = sub_buf->next(); + BufferBase::set(sub_buf->buffer().begin(), sub_buf->buffer().size(), sub_buf->offset()); + if (checkpoint) + checkpoint.emplace(pos); + + checkStateCorrect(); + return res; + } + } + + /// May throw an exception + resizeOwnMemoryIfNecessary(bytes_to_copy); + + if (useSubbufferOnly()) + { + sub_buf->position() = copy_from; + } + + char * memory_data = getMemoryData(); + + /// Save unread data from sub-buffer to own memory + memcpy(memory_data + peeked_size, sub_buf->position(), bytes_to_copy); + + /// If useSubbufferOnly() is false, then checkpoint is in own memory and it was updated in resizeOwnMemoryIfNecessary + /// Otherwise, checkpoint now at the beginning of own memory + if (checkpoint && useSubbufferOnly()) + { + checkpoint.emplace(memory_data); + checkpoint_in_own_memory = true; + } + + if (currentlyReadFromOwnMemory()) + { + /// Update buffer size + BufferBase::set(memory_data, peeked_size + bytes_to_copy, offset()); + } + else + { + /// Switch to reading from own memory + size_t pos_offset = peeked_size + this->offset(); + if (useSubbufferOnly()) + { + if (checkpoint) + pos_offset = bytes_to_copy; + else + pos_offset = 0; + } + BufferBase::set(memory_data, peeked_size + bytes_to_copy, pos_offset); + } + + peeked_size += bytes_to_copy; + sub_buf->position() += bytes_to_copy; + + checkStateCorrect(); + return sub_buf->next(); +} + +void PeekableReadBuffer::rollbackToCheckpoint(bool drop) +{ + checkStateCorrect(); + + assert(checkpoint); + + if (recursive_checkpoints_offsets.empty()) + { + if (checkpointInOwnMemory() == currentlyReadFromOwnMemory()) + { + /// Both checkpoint and position are in the same buffer. + pos = *checkpoint; + } + else + { + /// Checkpoint is in own memory and position is not. + assert(checkpointInOwnMemory()); + + char * memory_data = getMemoryData(); + /// Switch to reading from own memory. + BufferBase::set(memory_data, peeked_size, *checkpoint - memory_data); + } + } + else + { + size_t offset_from_checkpoint = recursive_checkpoints_offsets.top(); + if (checkpointInOwnMemory() == currentlyReadFromOwnMemory()) + { + /// Both checkpoint and position are in the same buffer. + pos = *checkpoint + offset_from_checkpoint; + } + else + { + /// Checkpoint is in own memory and position is not. + assert(checkpointInOwnMemory()); + + size_t offset_from_checkpoint_in_own_memory = offsetFromCheckpointInOwnMemory(); + if (offset_from_checkpoint >= offset_from_checkpoint_in_own_memory) + { + /// Recursive checkpoint is in sub buffer with current position. + /// Just move position to the recursive checkpoint + pos = buffer().begin() + (offset_from_checkpoint - offset_from_checkpoint_in_own_memory); + } + else + { + /// Recursive checkpoint is in own memory and position is not. + /// Switch to reading from own memory. + char * memory_data = getMemoryData(); + BufferBase::set(memory_data, peeked_size, *checkpoint - memory_data + offset_from_checkpoint); + } + } + } + + if (drop) + dropCheckpoint(); + + checkStateCorrect(); +} + +bool PeekableReadBuffer::nextImpl() +{ + /// FIXME: wrong bytes count because it can read the same data again after rollbackToCheckpoint() + /// however, changing bytes count on every call of next() (even after rollback) allows to determine + /// if some pointers were invalidated. 
+ + checkStateCorrect(); + bool res; + bool checkpoint_at_end = checkpoint && *checkpoint == working_buffer.end() && currentlyReadFromOwnMemory(); + + if (checkpoint) + { + if (currentlyReadFromOwnMemory()) + res = sub_buf->hasPendingData() || sub_buf->next(); + else + res = peekNext(); + } + else + { + if (useSubbufferOnly()) + { + /// Load next data to sub_buf + sub_buf->position() = position(); + res = sub_buf->next(); + } + else + { + /// All copied data have been read from own memory, continue reading from sub_buf + peeked_size = 0; + res = sub_buf->hasPendingData() || sub_buf->next(); + } + } + + /// Switch to reading from sub_buf (or just update it if already switched) + Buffer & sub_working = sub_buf->buffer(); + BufferBase::set(sub_working.begin(), sub_working.size(), sub_buf->offset()); + nextimpl_working_buffer_offset = sub_buf->offset(); + + if (checkpoint_at_end) + { + checkpoint.emplace(position()); + peeked_size = 0; + checkpoint_in_own_memory = false; + } + + checkStateCorrect(); + return res; +} + + +void PeekableReadBuffer::checkStateCorrect() const +{ + if (checkpoint) + { + if (checkpointInOwnMemory()) + { + if (!peeked_size) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Checkpoint in empty own buffer"); + if (currentlyReadFromOwnMemory() && pos < *checkpoint) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Current position in own buffer before checkpoint in own buffer"); + if (!currentlyReadFromOwnMemory() && pos < sub_buf->position()) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Current position in subbuffer less than sub_buf->position()"); + } + else + { + if (peeked_size) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Own buffer is not empty"); + if (currentlyReadFromOwnMemory()) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Current position in own buffer before checkpoint in subbuffer"); + if (pos < *checkpoint) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Current position in subbuffer before checkpoint in subbuffer"); + } + } + else + { + if (!currentlyReadFromOwnMemory() && peeked_size) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Own buffer is not empty"); + } + if (currentlyReadFromOwnMemory() && !peeked_size) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Pos in empty own buffer"); +} + +void PeekableReadBuffer::resizeOwnMemoryIfNecessary(size_t bytes_to_append) +{ + checkStateCorrect(); + bool need_update_checkpoint = checkpointInOwnMemory(); + bool need_update_pos = currentlyReadFromOwnMemory(); + size_t offset = 0; + if (need_update_checkpoint) + { + char * memory_data = getMemoryData(); + offset = *checkpoint - memory_data; + } + else if (need_update_pos) + offset = this->offset(); + + size_t new_size = peeked_size + bytes_to_append; + + if (use_stack_memory) + { + /// If stack memory is still enough, do nothing. + if (sizeof(stack_memory) >= new_size) + return; + + /// Stack memory is not enough, allocate larger buffer. 
+ use_stack_memory = false; + memory.resize(std::max(static_cast<size_t>(DBMS_DEFAULT_BUFFER_SIZE), new_size)); + memcpy(memory.data(), stack_memory, sizeof(stack_memory)); + if (need_update_checkpoint) + checkpoint.emplace(memory.data() + offset); + if (need_update_pos) + BufferBase::set(memory.data(), peeked_size, pos - stack_memory); + } + else if (memory.size() < new_size) + { + if (bytes_to_append < offset && 2 * (peeked_size - offset) <= memory.size()) + { + /// Move unread data to the beginning of own memory instead of resize own memory + peeked_size -= offset; + memmove(memory.data(), memory.data() + offset, peeked_size); + + if (need_update_checkpoint) + *checkpoint -= offset; + if (need_update_pos) + pos -= offset; + } + else + { + size_t pos_offset = pos - memory.data(); + + size_t new_size_amortized = memory.size() * 2; + if (new_size_amortized < new_size) + new_size_amortized = new_size; + memory.resize(new_size_amortized); + + if (need_update_checkpoint) + checkpoint.emplace(memory.data() + offset); + if (need_update_pos) + { + BufferBase::set(memory.data(), peeked_size, pos_offset); + } + } + } + checkStateCorrect(); +} + +void PeekableReadBuffer::makeContinuousMemoryFromCheckpointToPos() +{ + if (!checkpoint) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "There is no checkpoint"); + checkStateCorrect(); + + if (!checkpointInOwnMemory() || currentlyReadFromOwnMemory()) + return; /// it's already continuous + + size_t bytes_to_append = pos - sub_buf->position(); + resizeOwnMemoryIfNecessary(bytes_to_append); + char * memory_data = getMemoryData(); + memcpy(memory_data + peeked_size, sub_buf->position(), bytes_to_append); + sub_buf->position() = pos; + peeked_size += bytes_to_append; + BufferBase::set(memory_data, peeked_size, peeked_size); +} + +PeekableReadBuffer::~PeekableReadBuffer() +{ + if (!currentlyReadFromOwnMemory()) + sub_buf->position() = pos; +} + +bool PeekableReadBuffer::hasUnreadData() const +{ + return peeked_size && pos != getMemoryData() + peeked_size; +} + +size_t PeekableReadBuffer::offsetFromCheckpointInOwnMemory() const +{ + return peeked_size - (*checkpoint - getMemoryData()); +} + +size_t PeekableReadBuffer::offsetFromCheckpoint() const +{ + if (!checkpoint) + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "There is no checkpoint"); + + if (checkpointInOwnMemory() == currentlyReadFromOwnMemory()) + { + /// Checkpoint and pos are in the same buffer. + return pos - *checkpoint; + } + + /// Checkpoint is in own memory, position is in sub buffer. + return offset() + offsetFromCheckpointInOwnMemory(); +} + +} diff --git a/contrib/clickhouse/src/IO/PeekableReadBuffer.h b/contrib/clickhouse/src/IO/PeekableReadBuffer.h new file mode 100644 index 0000000000..78cb319327 --- /dev/null +++ b/contrib/clickhouse/src/IO/PeekableReadBuffer.h @@ -0,0 +1,143 @@ +#pragma once +#include <IO/ReadBuffer.h> +#include <IO/BufferWithOwnMemory.h> +#include <stack> + +namespace DB +{ + +/// Also allows to set checkpoint at some position in stream and come back to this position later. +/// When next() is called, saves data between checkpoint and current position to own memory and loads next data to sub-buffer +/// Sub-buffer should not be accessed directly during the lifetime of peekable buffer (unless +/// you reset() the state of peekable buffer after each change of underlying buffer) +/// If position() of peekable buffer is explicitly set to some position before checkpoint +/// (e.g. by istr.position() = prev_pos), behavior is undefined. 
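///
/// A usage sketch (illustrative): look one byte ahead and roll back.
///
///     PeekableReadBuffer peekable(in);                     /// `in` is any ReadBuffer
///     peekable.setCheckpoint();
///     char c;
///     bool starts_with_minus = peekable.read(c) && c == '-';
///     peekable.rollbackToCheckpoint(/* drop = */ true);    /// the byte (if any) can be read again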
+class PeekableReadBuffer : public BufferWithOwnMemory<ReadBuffer> +{ + friend class PeekableReadBufferCheckpoint; +public: + explicit PeekableReadBuffer(ReadBuffer & sub_buf_, size_t start_size_ = 0); + + ~PeekableReadBuffer() override; + + void prefetch(Priority priority) override { sub_buf->prefetch(priority); } + + /// Sets checkpoint at current position + ALWAYS_INLINE inline void setCheckpoint() + { + if (checkpoint) + { + /// Recursive checkpoints. We just remember offset from the + /// first checkpoint to the current position. + recursive_checkpoints_offsets.push(offsetFromCheckpoint()); + return; + } + + checkpoint_in_own_memory = currentlyReadFromOwnMemory(); + if (!checkpoint_in_own_memory) + { + /// Don't need to store unread data anymore + peeked_size = 0; + } + checkpoint.emplace(pos); + } + + /// Forget checkpoint and all data between checkpoint and position + ALWAYS_INLINE inline void dropCheckpoint() + { + assert(checkpoint); + + if (!recursive_checkpoints_offsets.empty()) + { + recursive_checkpoints_offsets.pop(); + return; + } + + if (!currentlyReadFromOwnMemory()) + { + /// Don't need to store unread data anymore + peeked_size = 0; + } + checkpoint = std::nullopt; + checkpoint_in_own_memory = false; + } + + /// Sets position at checkpoint. + /// All pointers (such as this->buffer().end()) may be invalidated + void rollbackToCheckpoint(bool drop = false); + + /// If checkpoint and current position are in different buffers, appends data from sub-buffer to own memory, + /// so data between checkpoint and position will be in continuous memory. + void makeContinuousMemoryFromCheckpointToPos(); + + /// Returns true if there unread data extracted from sub-buffer in own memory. + /// This data will be lost after destruction of peekable buffer. + bool hasUnreadData() const; + + // for streaming reading (like in Kafka) we need to restore initial state of the buffer + // without recreating the buffer. + void reset(); + + void setSubBuffer(ReadBuffer & sub_buf_); + + const ReadBuffer & getSubBuffer() const { return *sub_buf; } + +private: + bool nextImpl() override; + + void resetImpl(); + + bool peekNext(); + + inline bool useSubbufferOnly() const { return !peeked_size; } + inline bool currentlyReadFromOwnMemory() const { return working_buffer.begin() != sub_buf->buffer().begin(); } + inline bool checkpointInOwnMemory() const { return checkpoint_in_own_memory; } + + void checkStateCorrect() const; + + /// Makes possible to append `bytes_to_append` bytes to data in own memory. + /// Updates all invalidated pointers and sizes. + void resizeOwnMemoryIfNecessary(size_t bytes_to_append); + + char * getMemoryData() { return use_stack_memory ? stack_memory : memory.data(); } + const char * getMemoryData() const { return use_stack_memory ? stack_memory : memory.data(); } + + size_t offsetFromCheckpointInOwnMemory() const; + size_t offsetFromCheckpoint() const; + + + ReadBuffer * sub_buf; + size_t peeked_size = 0; + std::optional<Position> checkpoint = std::nullopt; + bool checkpoint_in_own_memory = false; + + /// To prevent expensive and in some cases unnecessary memory allocations on PeekableReadBuffer + /// creation (for example if PeekableReadBuffer is often created or if we need to remember small amount of + /// data after checkpoint), at the beginning we will use small amount of memory on stack and allocate + /// larger buffer only if reserved memory is not enough. 
+ char stack_memory[PADDING_FOR_SIMD]; + bool use_stack_memory = true; + + std::stack<size_t> recursive_checkpoints_offsets; +}; + + +class PeekableReadBufferCheckpoint : boost::noncopyable +{ + PeekableReadBuffer & buf; + bool auto_rollback; +public: + explicit PeekableReadBufferCheckpoint(PeekableReadBuffer & buf_, bool auto_rollback_ = false) + : buf(buf_), auto_rollback(auto_rollback_) { buf.setCheckpoint(); } + ~PeekableReadBufferCheckpoint() + { + if (!buf.checkpoint) + return; + if (auto_rollback) + buf.rollbackToCheckpoint(); + buf.dropCheckpoint(); + } + +}; + +} diff --git a/contrib/clickhouse/src/IO/Progress.cpp b/contrib/clickhouse/src/IO/Progress.cpp new file mode 100644 index 0000000000..145281e140 --- /dev/null +++ b/contrib/clickhouse/src/IO/Progress.cpp @@ -0,0 +1,240 @@ +#include "Progress.h" + +#include <IO/ReadBuffer.h> +#include <IO/WriteBuffer.h> +#include <IO/ReadHelpers.h> +#include <IO/WriteHelpers.h> +#include <Core/ProtocolDefines.h> + + +namespace DB +{ + +namespace +{ + UInt64 getApproxTotalRowsToRead(UInt64 read_rows, UInt64 read_bytes, UInt64 total_bytes_to_read) + { + if (!read_rows || !read_bytes) + return 0; + + auto bytes_per_row = std::ceil(static_cast<double>(read_bytes) / read_rows); + return static_cast<UInt64>(std::ceil(static_cast<double>(total_bytes_to_read) / bytes_per_row)); + } +} + +void ProgressValues::read(ReadBuffer & in, UInt64 server_revision) +{ + readVarUInt(read_rows, in); + readVarUInt(read_bytes, in); + readVarUInt(total_rows_to_read, in); + if (server_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_TOTAL_BYTES_IN_PROGRESS) + { + readVarUInt(total_bytes_to_read, in); + } + if (server_revision >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO) + { + readVarUInt(written_rows, in); + readVarUInt(written_bytes, in); + } + if (server_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_SERVER_QUERY_TIME_IN_PROGRESS) + { + readVarUInt(elapsed_ns, in); + } +} + + +void ProgressValues::write(WriteBuffer & out, UInt64 client_revision) const +{ + writeVarUInt(read_rows, out); + writeVarUInt(read_bytes, out); + /// In new TCP protocol we can send total_bytes_to_read without total_rows_to_read. + /// If client doesn't support total_bytes_to_read, send approx total_rows_to_read + /// to indicate at least approx progress. + if (client_revision < DBMS_MIN_PROTOCOL_VERSION_WITH_TOTAL_BYTES_IN_PROGRESS && total_bytes_to_read && !total_rows_to_read) + writeVarUInt(getApproxTotalRowsToRead(read_rows, read_bytes, total_bytes_to_read), out); + else + writeVarUInt(total_rows_to_read, out); + if (client_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_TOTAL_BYTES_IN_PROGRESS) + { + writeVarUInt(total_bytes_to_read, out); + } + if (client_revision >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO) + { + writeVarUInt(written_rows, out); + writeVarUInt(written_bytes, out); + } + if (client_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_SERVER_QUERY_TIME_IN_PROGRESS) + { + writeVarUInt(elapsed_ns, out); + } +} + +void ProgressValues::writeJSON(WriteBuffer & out) const +{ + /// Numbers are written in double quotes (as strings) to avoid loss of precision + /// of 64-bit integers after interpretation by JavaScript. 
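    /// The result is a single line such as (illustrative values):
    /// {"read_rows":"100","read_bytes":"8192","written_rows":"0","written_bytes":"0","total_rows_to_read":"1000","result_rows":"0","result_bytes":"0"}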
+ + writeCString("{", out); + writeCString("\"read_rows\":\"", out); + writeText(read_rows, out); + writeCString("\",\"read_bytes\":\"", out); + writeText(read_bytes, out); + writeCString("\",\"written_rows\":\"", out); + writeText(written_rows, out); + writeCString("\",\"written_bytes\":\"", out); + writeText(written_bytes, out); + writeCString("\",\"total_rows_to_read\":\"", out); + writeText(total_rows_to_read, out); + writeCString("\",\"result_rows\":\"", out); + writeText(result_rows, out); + writeCString("\",\"result_bytes\":\"", out); + writeText(result_bytes, out); + writeCString("\"", out); + writeCString("}", out); +} + +bool Progress::incrementPiecewiseAtomically(const Progress & rhs) +{ + read_rows += rhs.read_rows; + read_bytes += rhs.read_bytes; + + total_rows_to_read += rhs.total_rows_to_read; + total_bytes_to_read += rhs.total_bytes_to_read; + + written_rows += rhs.written_rows; + written_bytes += rhs.written_bytes; + + result_rows += rhs.result_rows; + result_bytes += rhs.result_bytes; + + elapsed_ns += rhs.elapsed_ns; + + return rhs.read_rows || rhs.written_rows; +} + +void Progress::reset() +{ + read_rows = 0; + read_bytes = 0; + + total_rows_to_read = 0; + total_bytes_to_read = 0; + + written_rows = 0; + written_bytes = 0; + + result_rows = 0; + result_bytes = 0; + + elapsed_ns = 0; +} + +ProgressValues Progress::getValues() const +{ + ProgressValues res; + + res.read_rows = read_rows.load(std::memory_order_relaxed); + res.read_bytes = read_bytes.load(std::memory_order_relaxed); + + res.total_rows_to_read = total_rows_to_read.load(std::memory_order_relaxed); + res.total_bytes_to_read = total_bytes_to_read.load(std::memory_order_relaxed); + + res.written_rows = written_rows.load(std::memory_order_relaxed); + res.written_bytes = written_bytes.load(std::memory_order_relaxed); + + res.result_rows = result_rows.load(std::memory_order_relaxed); + res.result_bytes = result_bytes.load(std::memory_order_relaxed); + + res.elapsed_ns = elapsed_ns.load(std::memory_order_relaxed); + + return res; +} + +ProgressValues Progress::fetchValuesAndResetPiecewiseAtomically() +{ + ProgressValues res; + + res.read_rows = read_rows.fetch_and(0); + res.read_bytes = read_bytes.fetch_and(0); + + res.total_rows_to_read = total_rows_to_read.fetch_and(0); + res.total_bytes_to_read = total_bytes_to_read.fetch_and(0); + + res.written_rows = written_rows.fetch_and(0); + res.written_bytes = written_bytes.fetch_and(0); + + res.result_rows = result_rows.fetch_and(0); + res.result_bytes = result_bytes.fetch_and(0); + + res.elapsed_ns = elapsed_ns.fetch_and(0); + + return res; +} + +Progress Progress::fetchAndResetPiecewiseAtomically() +{ + Progress res; + + res.read_rows = read_rows.fetch_and(0); + res.read_bytes = read_bytes.fetch_and(0); + + res.total_rows_to_read = total_rows_to_read.fetch_and(0); + res.total_bytes_to_read = total_bytes_to_read.fetch_and(0); + + res.written_rows = written_rows.fetch_and(0); + res.written_bytes = written_bytes.fetch_and(0); + + res.result_rows = result_rows.fetch_and(0); + res.result_bytes = result_bytes.fetch_and(0); + + res.elapsed_ns = elapsed_ns.fetch_and(0); + + return res; +} + +Progress & Progress::operator=(Progress && other) noexcept +{ + read_rows = other.read_rows.load(std::memory_order_relaxed); + read_bytes = other.read_bytes.load(std::memory_order_relaxed); + + total_rows_to_read = other.total_rows_to_read.load(std::memory_order_relaxed); + total_bytes_to_read = other.total_bytes_to_read.load(std::memory_order_relaxed); + + written_rows = 
other.written_rows.load(std::memory_order_relaxed); + written_bytes = other.written_bytes.load(std::memory_order_relaxed); + + result_rows = other.result_rows.load(std::memory_order_relaxed); + result_bytes = other.result_bytes.load(std::memory_order_relaxed); + + elapsed_ns = other.elapsed_ns.load(std::memory_order_relaxed); + + return *this; +} + +void Progress::read(ReadBuffer & in, UInt64 server_revision) +{ + ProgressValues values; + values.read(in, server_revision); + + read_rows.store(values.read_rows, std::memory_order_relaxed); + read_bytes.store(values.read_bytes, std::memory_order_relaxed); + total_rows_to_read.store(values.total_rows_to_read, std::memory_order_relaxed); + total_bytes_to_read.store(values.total_bytes_to_read, std::memory_order_relaxed); + + written_rows.store(values.written_rows, std::memory_order_relaxed); + written_bytes.store(values.written_bytes, std::memory_order_relaxed); + + elapsed_ns.store(values.elapsed_ns, std::memory_order_relaxed); +} + +void Progress::write(WriteBuffer & out, UInt64 client_revision) const +{ + getValues().write(out, client_revision); +} + +void Progress::writeJSON(WriteBuffer & out) const +{ + getValues().writeJSON(out); +} + +} diff --git a/contrib/clickhouse/src/IO/Progress.h b/contrib/clickhouse/src/IO/Progress.h new file mode 100644 index 0000000000..a68ff9bc5c --- /dev/null +++ b/contrib/clickhouse/src/IO/Progress.h @@ -0,0 +1,151 @@ +#pragma once + +#include <atomic> +#include <cstddef> +#include <functional> +#include <base/types.h> + +#include <Core/Defines.h> + +namespace DB +{ + +class ReadBuffer; +class WriteBuffer; + +/// See Progress. +struct ProgressValues +{ + UInt64 read_rows = 0; + UInt64 read_bytes = 0; + + UInt64 total_rows_to_read = 0; + UInt64 total_bytes_to_read = 0; + + UInt64 written_rows = 0; + UInt64 written_bytes = 0; + + UInt64 result_rows = 0; + UInt64 result_bytes = 0; + + UInt64 elapsed_ns = 0; + + void read(ReadBuffer & in, UInt64 server_revision); + void write(WriteBuffer & out, UInt64 client_revision) const; + void writeJSON(WriteBuffer & out) const; +}; + +struct ReadProgress +{ + UInt64 read_rows = 0; + UInt64 read_bytes = 0; + UInt64 total_rows_to_read = 0; + UInt64 total_bytes_to_read = 0; + + ReadProgress(UInt64 read_rows_, UInt64 read_bytes_, UInt64 total_rows_to_read_ = 0, UInt64 total_bytes_to_read_ = 0) + : read_rows(read_rows_), read_bytes(read_bytes_), total_rows_to_read(total_rows_to_read_), total_bytes_to_read(total_bytes_to_read_) {} +}; + +struct WriteProgress +{ + UInt64 written_rows = 0; + UInt64 written_bytes = 0; + + WriteProgress(UInt64 written_rows_, UInt64 written_bytes_) + : written_rows(written_rows_), written_bytes(written_bytes_) {} +}; + +struct ResultProgress +{ + UInt64 result_rows = 0; + UInt64 result_bytes = 0; + + ResultProgress(UInt64 result_rows_, UInt64 result_bytes_) + : result_rows(result_rows_), result_bytes(result_bytes_) {} +}; + +struct FileProgress +{ + /// Here read_bytes (raw bytes) - do not equal ReadProgress::read_bytes, which are calculated according to column types. + UInt64 read_bytes = 0; + UInt64 total_bytes_to_read = 0; + + explicit FileProgress(UInt64 read_bytes_, UInt64 total_bytes_to_read_ = 0) : read_bytes(read_bytes_), total_bytes_to_read(total_bytes_to_read_) {} +}; + + +/** Progress of query execution. + * Values, transferred over network are deltas - how much was done after previously sent value. + * The same struct is also used for summarized values. 
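 *
 * For example (illustrative): two deltas of 100 and 50 read rows sum to 150, which is what
 * incrementPiecewiseAtomically() implements field by field:
 *
 *     Progress total;
 *     total.incrementPiecewiseAtomically(Progress(ReadProgress(100, 8192)));
 *     total.incrementPiecewiseAtomically(Progress(ReadProgress(50, 4096)));
 *     /// total.getValues().read_rows == 150, read_bytes == 12288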
+ */ +struct Progress +{ + std::atomic<UInt64> read_rows {0}; /// Rows (source) processed. + std::atomic<UInt64> read_bytes {0}; /// Bytes (uncompressed, source) processed. + + /** How much rows/bytes must be processed, in total, approximately. Non-zero value is sent when there is information about + * some new part of job. Received values must be summed to get estimate of total rows to process. + */ + std::atomic<UInt64> total_rows_to_read {0}; + std::atomic<UInt64> total_bytes_to_read {0}; + + std::atomic<UInt64> written_rows {0}; + std::atomic<UInt64> written_bytes {0}; + + std::atomic<UInt64> result_rows {0}; + std::atomic<UInt64> result_bytes {0}; + + std::atomic<UInt64> elapsed_ns {0}; + + Progress() = default; + + Progress(UInt64 read_rows_, UInt64 read_bytes_, UInt64 total_rows_to_read_ = 0, UInt64 total_bytes_to_read_ = 0) + : read_rows(read_rows_), read_bytes(read_bytes_), total_rows_to_read(total_rows_to_read_), total_bytes_to_read(total_bytes_to_read_) {} + + explicit Progress(ReadProgress read_progress) + : read_rows(read_progress.read_rows), read_bytes(read_progress.read_bytes), total_rows_to_read(read_progress.total_rows_to_read) {} + + explicit Progress(WriteProgress write_progress) + : written_rows(write_progress.written_rows), written_bytes(write_progress.written_bytes) {} + + explicit Progress(ResultProgress result_progress) + : result_rows(result_progress.result_rows), result_bytes(result_progress.result_bytes) {} + + explicit Progress(FileProgress file_progress) + : read_bytes(file_progress.read_bytes), total_bytes_to_read(file_progress.total_bytes_to_read) {} + + void read(ReadBuffer & in, UInt64 server_revision); + + void write(WriteBuffer & out, UInt64 client_revision) const; + + /// Progress in JSON format (single line, without whitespaces) is used in HTTP headers. + void writeJSON(WriteBuffer & out) const; + + /// Each value separately is changed atomically (but not whole object). + bool incrementPiecewiseAtomically(const Progress & rhs); + + void reset(); + + ProgressValues getValues() const; + + ProgressValues fetchValuesAndResetPiecewiseAtomically(); + + Progress fetchAndResetPiecewiseAtomically(); + + Progress & operator=(Progress && other) noexcept; + + Progress(Progress && other) noexcept + { + *this = std::move(other); + } +}; + + +/** Callback to track the progress of the query. + * Used in QueryPipeline and Context. + * The function takes the number of rows in the last block, the number of bytes in the last block. + * Note that the callback can be called from different threads. 
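 *
 * A sketch (illustrative) of a callback that merges deltas into a shared accumulator:
 *
 *     Progress accumulated;
 *     ProgressCallback callback = [&accumulated](const Progress & delta)
 *     {
 *         accumulated.incrementPiecewiseAtomically(delta);   /// every field is a separate atomic
 *     };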
+ */ +using ProgressCallback = std::function<void(const Progress & progress)>; + +} diff --git a/contrib/clickhouse/src/IO/ReadBuffer.cpp b/contrib/clickhouse/src/IO/ReadBuffer.cpp new file mode 100644 index 0000000000..bf054d0842 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBuffer.cpp @@ -0,0 +1,47 @@ +#include <IO/ReadBuffer.h> + + +namespace DB +{ + +namespace +{ + template <typename CustomData> + class ReadBufferWrapper : public ReadBuffer + { + public: + ReadBufferWrapper(ReadBuffer & in_, CustomData && custom_data_) + : ReadBuffer(in_.buffer().begin(), in_.buffer().size(), in_.offset()), in(in_), custom_data(std::move(custom_data_)) + { + } + + private: + ReadBuffer & in; + CustomData custom_data; + + bool nextImpl() override + { + in.position() = position(); + if (!in.next()) + { + set(in.position(), 0); + return false; + } + BufferBase::set(in.buffer().begin(), in.buffer().size(), in.offset()); + return true; + } + }; +} + + +std::unique_ptr<ReadBuffer> wrapReadBufferReference(ReadBuffer & ref) +{ + return std::make_unique<ReadBufferWrapper<nullptr_t>>(ref, nullptr); +} + +std::unique_ptr<ReadBuffer> wrapReadBufferPointer(ReadBufferPtr ptr) +{ + return std::make_unique<ReadBufferWrapper<ReadBufferPtr>>(*ptr, ReadBufferPtr{ptr}); +} + +} diff --git a/contrib/clickhouse/src/IO/ReadBuffer.h b/contrib/clickhouse/src/IO/ReadBuffer.h new file mode 100644 index 0000000000..a4ae12f506 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBuffer.h @@ -0,0 +1,277 @@ +#pragma once + +#include <cassert> +#include <cstring> +#include <algorithm> +#include <memory> + +#include <Common/Exception.h> +#include <Common/Priority.h> +#include <IO/BufferBase.h> +#include <IO/AsynchronousReader.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ATTEMPT_TO_READ_AFTER_EOF; + extern const int CANNOT_READ_ALL_DATA; + extern const int NOT_IMPLEMENTED; +} + +static constexpr auto DEFAULT_PREFETCH_PRIORITY = Priority{0}; + +/** A simple abstract class for buffered data reading (char sequences) from somewhere. + * Unlike std::istream, it provides access to the internal buffer, + * and also allows you to manually manage the position inside the buffer. + * + * Note! `char *`, not `const char *` is used + * (so that you can take out the common code into BufferBase, and also so that you can fill the buffer in with new data). + * This causes inconveniences - for example, when using ReadBuffer to read from a chunk of memory const char *, + * you have to use const_cast. + * + * Derived classes must implement the nextImpl() method. + */ +class ReadBuffer : public BufferBase +{ +public: + /** Creates a buffer and sets a piece of available data to read to zero size, + * so that the next() function is called to load the new data portion into the buffer at the first try. + */ + ReadBuffer(Position ptr, size_t size) : BufferBase(ptr, size, 0) { working_buffer.resize(0); } + + /** Used when the buffer is already full of data that can be read. + * (in this case, pass 0 as an offset) + */ + ReadBuffer(Position ptr, size_t size, size_t offset) : BufferBase(ptr, size, offset) {} + + // Copying the read buffers can be dangerous because they can hold a lot of + // memory or open files, so better to disable the copy constructor to prevent + // accidental copying. + ReadBuffer(const ReadBuffer &) = delete; + + // FIXME: behavior differs greately from `BufferBase::set()` and it's very confusing. 
+ void set(Position ptr, size_t size) { BufferBase::set(ptr, size, 0); working_buffer.resize(0); } + + /** read next data and fill a buffer with it; set position to the beginning of the new data + * (but not necessarily to the beginning of working_buffer!); + * return `false` in case of end, `true` otherwise; throw an exception, if something is wrong; + * + * if an exception was thrown, is the ReadBuffer left in a usable state? this varies across implementations; + * can the caller retry next() after an exception, or call other methods? not recommended + */ + bool next() + { + assert(!hasPendingData()); + assert(position() <= working_buffer.end()); + + bytes += offset(); + bool res = nextImpl(); + if (!res) + working_buffer = Buffer(pos, pos); + else + { + pos = working_buffer.begin() + nextimpl_working_buffer_offset; + assert(position() != working_buffer.end()); + } + nextimpl_working_buffer_offset = 0; + + assert(position() <= working_buffer.end()); + + return res; + } + + + inline void nextIfAtEnd() + { + if (!hasPendingData()) + next(); + } + + virtual ~ReadBuffer() = default; + + + /** Unlike std::istream, it returns true if all data was read + * (and not in case there was an attempt to read after the end). + * If at the moment the position is at the end of the buffer, it calls the next() method. + * That is, it has a side effect - if the buffer is over, then it updates it and set the position to the beginning. + * + * Try to read after the end should throw an exception. + */ + bool ALWAYS_INLINE eof() + { + return !hasPendingData() && !next(); + } + + void ignore() + { + if (!eof()) + ++pos; + else + throwReadAfterEOF(); + } + + void ignore(size_t n) + { + while (n != 0 && !eof()) + { + size_t bytes_to_ignore = std::min(static_cast<size_t>(working_buffer.end() - pos), n); + pos += bytes_to_ignore; + n -= bytes_to_ignore; + } + + if (n) + throwReadAfterEOF(); + } + + /// You could call this method `ignore`, and `ignore` call `ignoreStrict`. + size_t tryIgnore(size_t n) + { + size_t bytes_ignored = 0; + + while (bytes_ignored < n && !eof()) + { + size_t bytes_to_ignore = std::min(static_cast<size_t>(working_buffer.end() - pos), n - bytes_ignored); + pos += bytes_to_ignore; + bytes_ignored += bytes_to_ignore; + } + + return bytes_ignored; + } + + void ignoreAll() + { + tryIgnore(std::numeric_limits<size_t>::max()); + } + + /// Peeks a single byte. + bool ALWAYS_INLINE peek(char & c) + { + if (eof()) + return false; + c = *pos; + return true; + } + + /// Reads a single byte. + [[nodiscard]] bool ALWAYS_INLINE read(char & c) + { + if (peek(c)) + { + ++pos; + return true; + } + + return false; + } + + void ALWAYS_INLINE readStrict(char & c) + { + if (read(c)) + return; + throwReadAfterEOF(); + } + + /** Reads as many as there are, no more than n bytes. */ + [[nodiscard]] size_t read(char * to, size_t n) + { + size_t bytes_copied = 0; + + while (bytes_copied < n && !eof()) + { + size_t bytes_to_copy = std::min(static_cast<size_t>(working_buffer.end() - pos), n - bytes_copied); + ::memcpy(to + bytes_copied, pos, bytes_to_copy); + pos += bytes_to_copy; + bytes_copied += bytes_to_copy; + } + + return bytes_copied; + } + + /** Reads n bytes, if there are less - throws an exception. */ + void readStrict(char * to, size_t n) + { + auto read_bytes = read(to, n); + if (n != read_bytes) + throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, + "Cannot read all data. Bytes read: {}. 
Bytes expected: {}.", read_bytes, std::to_string(n)); + } + + /** A method that can be more efficiently implemented in derived classes, in the case of reading large enough blocks. + * The implementation can read data directly into `to`, without superfluous copying, if in `to` there is enough space for work. + * For example, a CompressedReadBuffer can decompress the data directly into `to`, if the entire decompressed block fits there. + * By default - the same as read. + * Don't use for small reads. + */ + [[nodiscard]] virtual size_t readBig(char * to, size_t n) { return read(to, n); } + + /** Do something to allow faster subsequent call to 'nextImpl' if possible. + * It's used for asynchronous readers with double-buffering. + * `priority` is the `ThreadPool` priority, with which the prefetch task will be scheduled. + * Lower value means higher priority. + */ + virtual void prefetch(Priority) {} + + /** + * Set upper bound for read range [..., position). + * Useful for reading from remote filesystem, when it matters how much we read. + * Doesn't affect getFileSize(). + * See also: SeekableReadBuffer::supportsRightBoundedReads(). + * + * Behavior in weird cases is currently implementation-defined: + * - setReadUntilPosition() below current position, + * - setReadUntilPosition() above the end of the file, + * - seek() to a position above the until position (even if you setReadUntilPosition() to a + * higher value right after the seek!), + * + * Typical implementations discard any current buffers and connections, even if the position is + * adjusted only a little. + * + * Typical usage is to call it right after creating the ReadBuffer, before it started doing any + * work. + */ + virtual void setReadUntilPosition(size_t /* position */) {} + + virtual void setReadUntilEnd() {} + + /// Read at most `size` bytes into data at specified offset `offset`. First ignore `ignore` bytes if `ignore` > 0. + /// Notice: this function only need to be implemented in synchronous read buffers to be wrapped in asynchronous read. + /// Such as ReadBufferFromRemoteFSGather and AsynchronousReadIndirectBufferFromRemoteFS. + virtual IAsynchronousReader::Result readInto(char * /*data*/, size_t /*size*/, size_t /*offset*/, size_t /*ignore*/) + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "readInto not implemented"); + } + +protected: + /// The number of bytes to ignore from the initial position of `working_buffer` + /// buffer. Apparently this is an additional out-parameter for nextImpl(), + /// not a real field. + size_t nextimpl_working_buffer_offset = 0; + +private: + /** Read the next data and fill a buffer with it. + * Return `false` in case of the end, `true` otherwise. + * Throw an exception if something is wrong. + */ + virtual bool nextImpl() { return false; } + + [[noreturn]] static void throwReadAfterEOF() + { + throw Exception(ErrorCodes::ATTEMPT_TO_READ_AFTER_EOF, "Attempt to read after eof"); + } +}; + + +using ReadBufferPtr = std::shared_ptr<ReadBuffer>; + +/// Due to inconsistencies in ReadBuffer-family interfaces: +/// - some require to fully wrap underlying buffer and own it, +/// - some just wrap the reference without ownership, +/// we need to be able to wrap reference-only buffers with movable transparent proxy-buffer. +/// The uniqueness of such wraps is responsibility of the code author. 
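A rough usage sketch of the interface above and of the wrap helpers declared just below (ReadBufferFromString comes from this same directory; `consume` is a hypothetical caller-side handler):

    DB::ReadBufferFromString in(std::string_view{"abc"});
    char c;
    while (in.read(c))                                   /// next()/eof() are driven implicitly by read()
        consume(c);                                      /// hypothetical handler, not part of this header

    /// Hand out a reference-only buffer as a movable proxy without transferring ownership:
    std::unique_ptr<DB::ReadBuffer> proxy = DB::wrapReadBufferReference(in);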
+std::unique_ptr<ReadBuffer> wrapReadBufferReference(ReadBuffer & ref); +std::unique_ptr<ReadBuffer> wrapReadBufferPointer(ReadBufferPtr ptr); + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromEmptyFile.h b/contrib/clickhouse/src/IO/ReadBufferFromEmptyFile.h new file mode 100644 index 0000000000..f21f2f507d --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromEmptyFile.h @@ -0,0 +1,25 @@ +#pragma once + +#include <IO/ReadBuffer.h> +#include <IO/ReadBufferFromFileBase.h> + +namespace DB +{ + +/// In case of empty file it does not make any sense to read it. +/// +/// Plus regular readers from file has an assert that buffer is not empty, that will fail: +/// - ReadBufferFromFileDescriptor +/// - SynchronousReader +/// - ThreadPoolReader +class ReadBufferFromEmptyFile : public ReadBufferFromFileBase +{ +private: + bool nextImpl() override { return false; } + std::string getFileName() const override { return "<empty>"; } + off_t seek(off_t /*off*/, int /*whence*/) override { return 0; } + off_t getPosition() override { return 0; } + size_t getFileSize() override { return 0; } +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromEncryptedFile.cpp b/contrib/clickhouse/src/IO/ReadBufferFromEncryptedFile.cpp new file mode 100644 index 0000000000..f9cf159715 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromEncryptedFile.cpp @@ -0,0 +1,106 @@ +#include <IO/ReadBufferFromEncryptedFile.h> + +#if USE_SSL + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ARGUMENT_OUT_OF_BOUND; +} + +ReadBufferFromEncryptedFile::ReadBufferFromEncryptedFile( + size_t buffer_size_, + std::unique_ptr<ReadBufferFromFileBase> in_, + const String & key_, + const FileEncryption::Header & header_, + size_t offset_) + : ReadBufferFromFileBase(buffer_size_, nullptr, 0) + , in(std::move(in_)) + , encrypted_buffer(buffer_size_) + , encryptor(header_.algorithm, key_, header_.init_vector) +{ + offset = offset_; + need_seek = true; +} + +off_t ReadBufferFromEncryptedFile::seek(off_t off, int whence) +{ + off_t new_pos; + if (whence == SEEK_SET) + { + if (off < 0) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "SEEK_SET underflow: off = {}", off); + new_pos = off; + } + else if (whence == SEEK_CUR) + { + if (off < 0 && -off > getPosition()) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "SEEK_CUR shift out of bounds"); + new_pos = getPosition() + off; + } + else + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "ReadBufferFromFileEncrypted::seek expects SEEK_SET or SEEK_CUR as whence"); + + if ((offset - static_cast<off_t>(working_buffer.size()) <= new_pos) && (new_pos <= offset) && !need_seek) + { + /// Position is still inside buffer. + pos = working_buffer.end() - offset + new_pos; + assert(pos >= working_buffer.begin()); + assert(pos <= working_buffer.end()); + } + else + { + need_seek = true; + offset = new_pos; + + /// No more reading from the current working buffer until next() is called. + resetWorkingBuffer(); + assert(!hasPendingData()); + } + + return new_pos; +} + +off_t ReadBufferFromEncryptedFile::getPosition() +{ + return offset - available(); +} + +bool ReadBufferFromEncryptedFile::nextImpl() +{ + if (need_seek) + { + off_t raw_offset = offset + FileEncryption::Header::kSize; + if (in->seek(raw_offset, SEEK_SET) != raw_offset) + return false; + need_seek = false; + } + + if (in->eof()) + return false; + + /// Read up to the size of `encrypted_buffer`. 
+ size_t bytes_read = 0; + while (bytes_read < encrypted_buffer.size() && !in->eof()) + { + bytes_read += in->read(encrypted_buffer.data() + bytes_read, encrypted_buffer.size() - bytes_read); + } + + /// The used cipher algorithms generate the same number of bytes in output as it were in input, + /// so after deciphering the numbers of bytes will be still `bytes_read`. + working_buffer.resize(bytes_read); + + /// The decryptor needs to know what the current offset is (because it's used in the decryption algorithm). + encryptor.setOffset(offset); + + encryptor.decrypt(encrypted_buffer.data(), bytes_read, working_buffer.begin()); + + offset += bytes_read; + pos = working_buffer.begin(); + return true; +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/ReadBufferFromEncryptedFile.h b/contrib/clickhouse/src/IO/ReadBufferFromEncryptedFile.h new file mode 100644 index 0000000000..155dc2ccce --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromEncryptedFile.h @@ -0,0 +1,50 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_SSL +#include <IO/ReadBufferFromFileBase.h> +#include <IO/FileEncryptionCommon.h> + + +namespace DB +{ + +/// Reads data from the underlying read buffer and decrypts it. +class ReadBufferFromEncryptedFile : public ReadBufferFromFileBase +{ +public: + ReadBufferFromEncryptedFile( + size_t buffer_size_, + std::unique_ptr<ReadBufferFromFileBase> in_, + const String & key_, + const FileEncryption::Header & header_, + size_t offset_ = 0); + + off_t seek(off_t off, int whence) override; + off_t getPosition() override; + + std::string getFileName() const override { return in->getFileName(); } + + void setReadUntilPosition(size_t position) override { in->setReadUntilPosition(position + FileEncryption::Header::kSize); } + + void setReadUntilEnd() override { in->setReadUntilEnd(); } + + size_t getFileSize() override { return in->getFileSize(); } + +private: + bool nextImpl() override; + + std::unique_ptr<ReadBufferFromFileBase> in; + + off_t offset = 0; + + bool need_seek = false; + + Memory<> encrypted_buffer; + FileEncryption::Encryptor encryptor; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/ReadBufferFromFile.cpp b/contrib/clickhouse/src/IO/ReadBufferFromFile.cpp new file mode 100644 index 0000000000..79ac62c642 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromFile.cpp @@ -0,0 +1,97 @@ +#include <fcntl.h> + +#include <IO/ReadBufferFromFile.h> +#include <IO/WriteHelpers.h> +#include <Common/ProfileEvents.h> +#include <base/defines.h> +#include <cerrno> + + +namespace ProfileEvents +{ + extern const Event FileOpen; +} + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int FILE_DOESNT_EXIST; + extern const int CANNOT_OPEN_FILE; + extern const int CANNOT_CLOSE_FILE; +} + + +ReadBufferFromFile::ReadBufferFromFile( + const std::string & file_name_, + size_t buf_size, + int flags, + char * existing_memory, + size_t alignment, + std::optional<size_t> file_size_, + ThrottlerPtr throttler_) + : ReadBufferFromFileDescriptor(-1, buf_size, existing_memory, alignment, file_size_, throttler_) + , file_name(file_name_) +{ + ProfileEvents::increment(ProfileEvents::FileOpen); + +#ifdef OS_DARWIN + bool o_direct = (flags != -1) && (flags & O_DIRECT); + if (o_direct) + flags = flags & ~O_DIRECT; +#endif + fd = ::open(file_name.c_str(), flags == -1 ? O_RDONLY | O_CLOEXEC : flags | O_CLOEXEC); + + if (-1 == fd) + throwFromErrnoWithPath("Cannot open file " + file_name, file_name, + errno == ENOENT ? 
ErrorCodes::FILE_DOESNT_EXIST : ErrorCodes::CANNOT_OPEN_FILE); +#ifdef OS_DARWIN + if (o_direct) + { + if (fcntl(fd, F_NOCACHE, 1) == -1) + throwFromErrnoWithPath("Cannot set F_NOCACHE on file " + file_name, file_name, ErrorCodes::CANNOT_OPEN_FILE); + } +#endif +} + + +ReadBufferFromFile::ReadBufferFromFile( + int & fd_, + const std::string & original_file_name, + size_t buf_size, + char * existing_memory, + size_t alignment, + std::optional<size_t> file_size_, + ThrottlerPtr throttler_) + : ReadBufferFromFileDescriptor(fd_, buf_size, existing_memory, alignment, file_size_, throttler_) + , file_name(original_file_name.empty() ? "(fd = " + toString(fd_) + ")" : original_file_name) +{ + fd_ = -1; +} + + +ReadBufferFromFile::~ReadBufferFromFile() +{ + if (fd < 0) + return; + + int err = ::close(fd); + chassert(!err || errno == EINTR); +} + + +void ReadBufferFromFile::close() +{ + if (fd < 0) + return; + + if (0 != ::close(fd)) + throw Exception(ErrorCodes::CANNOT_CLOSE_FILE, "Cannot close file"); + + fd = -1; + metric_increment.destroy(); +} + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromFile.h b/contrib/clickhouse/src/IO/ReadBufferFromFile.h new file mode 100644 index 0000000000..462453d974 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromFile.h @@ -0,0 +1,108 @@ +#pragma once + +#include <IO/ReadBufferFromFileDescriptor.h> +#include <IO/OpenedFileCache.h> +#include <Common/CurrentMetrics.h> + + +namespace CurrentMetrics +{ + extern const Metric OpenFileForRead; +} + +namespace DB +{ + +/** Accepts path to file and opens it, or pre-opened file descriptor. + * Closes file by himself (thus "owns" a file descriptor). + */ +class ReadBufferFromFile : public ReadBufferFromFileDescriptor +{ +protected: + std::string file_name; + CurrentMetrics::Increment metric_increment{CurrentMetrics::OpenFileForRead}; + +public: + explicit ReadBufferFromFile( + const std::string & file_name_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + int flags = -1, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt, + ThrottlerPtr throttler = {}); + + /// Use pre-opened file descriptor. + explicit ReadBufferFromFile( + int & fd, /// Will be set to -1 if constructor didn't throw and ownership of file descriptor is passed to the object. + const std::string & original_file_name = {}, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt, + ThrottlerPtr throttler = {}); + + ~ReadBufferFromFile() override; + + /// Close file before destruction of object. + void close(); + + std::string getFileName() const override + { + return file_name; + } + + size_t getFileOffsetOfBufferEnd() const override { return file_offset_of_buffer_end; } +}; + + +/** Similar to ReadBufferFromFile but it is using 'pread' instead of 'read'. + */ +class ReadBufferFromFilePRead : public ReadBufferFromFile +{ +public: + explicit ReadBufferFromFilePRead( + const std::string & file_name_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + int flags = -1, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt) + : ReadBufferFromFile(file_name_, buf_size, flags, existing_memory, alignment, file_size_) + { + use_pread = true; + } +}; + + +/** Similar to ReadBufferFromFilePRead but also transparently shares open file descriptors. 
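 * Descriptors are obtained from OpenedFileCache (see the constructor below), so several buffers reading
 * the same file can reuse a single open file descriptor instead of paying for open()/close() every time.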
+ */ +class ReadBufferFromFilePReadWithDescriptorsCache : public ReadBufferFromFileDescriptorPRead +{ +private: + std::string file_name; + OpenedFileCache::OpenedFilePtr file; + +public: + explicit ReadBufferFromFilePReadWithDescriptorsCache( + const std::string & file_name_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + int flags = -1, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt, + ThrottlerPtr throttler_ = {}) + : ReadBufferFromFileDescriptorPRead(-1, buf_size, existing_memory, alignment, file_size_, throttler_) + , file_name(file_name_) + { + file = OpenedFileCache::instance().get(file_name, flags); + fd = file->getFD(); + } + + std::string getFileName() const override + { + return file_name; + } +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromFileBase.cpp b/contrib/clickhouse/src/IO/ReadBufferFromFileBase.cpp new file mode 100644 index 0000000000..4ac3f984f7 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromFileBase.cpp @@ -0,0 +1,49 @@ +#include <IO/ReadBufferFromFileBase.h> +#include <IO/Progress.h> +#include <Interpreters/Context.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int UNKNOWN_FILE_SIZE; +} + +ReadBufferFromFileBase::ReadBufferFromFileBase() : BufferWithOwnMemory<SeekableReadBuffer>(0) +{ +} + +ReadBufferFromFileBase::ReadBufferFromFileBase( + size_t buf_size, + char * existing_memory, + size_t alignment, + std::optional<size_t> file_size_) + : BufferWithOwnMemory<SeekableReadBuffer>(buf_size, existing_memory, alignment) + , file_size(file_size_) +{ +} + +ReadBufferFromFileBase::~ReadBufferFromFileBase() = default; + +size_t ReadBufferFromFileBase::getFileSize() +{ + if (file_size) + return *file_size; + throw Exception(ErrorCodes::UNKNOWN_FILE_SIZE, "Cannot find out file size for read buffer"); +} + +void ReadBufferFromFileBase::setProgressCallback(ContextPtr context) +{ + auto file_progress_callback = context->getFileProgressCallback(); + + if (!file_progress_callback) + return; + + setProfileCallback([file_progress_callback](const ProfileInfo & progress) + { + file_progress_callback(FileProgress(progress.bytes_read)); + }); +} + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromFileBase.h b/contrib/clickhouse/src/IO/ReadBufferFromFileBase.h new file mode 100644 index 0000000000..b77db29bc2 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromFileBase.h @@ -0,0 +1,63 @@ +#pragma once + +#include <IO/BufferWithOwnMemory.h> +#include <IO/SeekableReadBuffer.h> +#include <IO/WithFileName.h> +#include <Interpreters/Context_fwd.h> +#include <base/time.h> + +#include <functional> +#include <utility> +#include <string> + +#include <sys/stat.h> +#include <sys/types.h> +#include <fcntl.h> + +#ifndef O_DIRECT +#define O_DIRECT 00040000 +#endif + + +namespace DB +{ + +class ReadBufferFromFileBase : public BufferWithOwnMemory<SeekableReadBuffer>, public WithFileName, public WithFileSize +{ +public: + ReadBufferFromFileBase(); + ReadBufferFromFileBase( + size_t buf_size, + char * existing_memory, + size_t alignment, + std::optional<size_t> file_size_ = std::nullopt); + ~ReadBufferFromFileBase() override; + + /// It is possible to get information about the time of each reading. + struct ProfileInfo + { + size_t bytes_requested; + size_t bytes_read; + size_t nanoseconds; + }; + + using ProfileCallback = std::function<void(ProfileInfo)>; + + /// CLOCK_MONOTONIC_COARSE is more than enough to track long reads - for example, hanging for a second. 
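A minimal illustration of the callback registered via setProfileCallback() just below (the buffer `buf` and `reportSlowRead` are assumptions, not part of this header):

    buf.setProfileCallback([](const DB::ReadBufferFromFileBase::ProfileInfo & info)
    {
        if (info.nanoseconds > 1'000'000'000ULL)                      /// reads that hung for more than a second
            reportSlowRead(info.bytes_read, info.bytes_requested);    /// hypothetical reporting hook
    });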
+ void setProfileCallback(const ProfileCallback & profile_callback_, clockid_t clock_type_ = CLOCK_MONOTONIC_COARSE) + { + profile_callback = profile_callback_; + clock_type = clock_type_; + } + + size_t getFileSize() override; + + void setProgressCallback(ContextPtr context); + +protected: + std::optional<size_t> file_size; + ProfileCallback profile_callback; + clockid_t clock_type{}; +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromFileDecorator.cpp b/contrib/clickhouse/src/IO/ReadBufferFromFileDecorator.cpp new file mode 100644 index 0000000000..9ac0fb4e47 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromFileDecorator.cpp @@ -0,0 +1,60 @@ +#include <IO/ReadBufferFromFileDecorator.h> + + +namespace DB +{ + +ReadBufferFromFileDecorator::ReadBufferFromFileDecorator(std::unique_ptr<SeekableReadBuffer> impl_) + : ReadBufferFromFileDecorator(std::move(impl_), "") +{ +} + + +ReadBufferFromFileDecorator::ReadBufferFromFileDecorator(std::unique_ptr<SeekableReadBuffer> impl_, const String & file_name_) + : impl(std::move(impl_)), file_name(file_name_) +{ + swap(*impl); +} + + +std::string ReadBufferFromFileDecorator::getFileName() const +{ + if (!file_name.empty()) + return file_name; + + return getFileNameFromReadBuffer(*impl); +} + + +off_t ReadBufferFromFileDecorator::getPosition() +{ + swap(*impl); + auto position = impl->getPosition(); + swap(*impl); + return position; +} + + +off_t ReadBufferFromFileDecorator::seek(off_t off, int whence) +{ + swap(*impl); + auto result = impl->seek(off, whence); + swap(*impl); + return result; +} + + +bool ReadBufferFromFileDecorator::nextImpl() +{ + swap(*impl); + auto result = impl->next(); + swap(*impl); + return result; +} + +size_t ReadBufferFromFileDecorator::getFileSize() +{ + return getFileSizeFromReadBuffer(*impl); +} + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromFileDecorator.h b/contrib/clickhouse/src/IO/ReadBufferFromFileDecorator.h new file mode 100644 index 0000000000..6e62c7f741 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromFileDecorator.h @@ -0,0 +1,37 @@ +#pragma once + +#include <IO/ReadBufferFromFileBase.h> + + +namespace DB +{ + +/// Delegates all reads to underlying buffer. Doesn't have own memory. 
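/// Delegation in the .cpp above is implemented with BufferBase::swap(): buffer state is exchanged with
/// `impl` around every forwarded call, so the decorator itself holds no memory and `impl` always sees
/// an up-to-date position.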
+class ReadBufferFromFileDecorator : public ReadBufferFromFileBase +{ +public: + explicit ReadBufferFromFileDecorator(std::unique_ptr<SeekableReadBuffer> impl_); + ReadBufferFromFileDecorator(std::unique_ptr<SeekableReadBuffer> impl_, const String & file_name_); + + std::string getFileName() const override; + + off_t getPosition() override; + + off_t seek(off_t off, int whence) override; + + bool nextImpl() override; + + bool isWithFileSize() const { return dynamic_cast<const WithFileSize *>(impl.get()) != nullptr; } + + const ReadBuffer & getWrappedReadBuffer() const { return *impl; } + + ReadBuffer & getWrappedReadBuffer() { return *impl; } + + size_t getFileSize() override; + +protected: + std::unique_ptr<SeekableReadBuffer> impl; + String file_name; +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromFileDescriptor.cpp b/contrib/clickhouse/src/IO/ReadBufferFromFileDescriptor.cpp new file mode 100644 index 0000000000..6c0c1681a4 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromFileDescriptor.cpp @@ -0,0 +1,288 @@ +#include <cerrno> +#include <ctime> +#include <optional> +#include <Common/ProfileEvents.h> +#include <Common/Stopwatch.h> +#include <Common/Exception.h> +#include <Common/CurrentMetrics.h> +#include <Common/Throttler.h> +#include <IO/ReadBufferFromFileDescriptor.h> +#include <IO/WriteHelpers.h> +#include <Common/filesystemHelpers.h> +#include <sys/stat.h> +#include <Interpreters/Context.h> + + +#pragma clang diagnostic ignored "-Wreserved-identifier" + +namespace ProfileEvents +{ + extern const Event ReadBufferFromFileDescriptorRead; + extern const Event ReadBufferFromFileDescriptorReadFailed; + extern const Event ReadBufferFromFileDescriptorReadBytes; + extern const Event DiskReadElapsedMicroseconds; + extern const Event Seek; + extern const Event LocalReadThrottlerBytes; + extern const Event LocalReadThrottlerSleepMicroseconds; +} + +namespace CurrentMetrics +{ + extern const Metric Read; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_READ_FROM_FILE_DESCRIPTOR; + extern const int ARGUMENT_OUT_OF_BOUND; + extern const int CANNOT_SEEK_THROUGH_FILE; + extern const int CANNOT_SELECT; + extern const int CANNOT_ADVISE; +} + + +std::string ReadBufferFromFileDescriptor::getFileName() const +{ + return "(fd = " + toString(fd) + ")"; +} + + +size_t ReadBufferFromFileDescriptor::readImpl(char * to, size_t min_bytes, size_t max_bytes, size_t offset) +{ + chassert(min_bytes <= max_bytes); + + /// This is a workaround of a read past EOF bug in linux kernel with pread() + if (file_size.has_value() && offset >= *file_size) + return 0; + + size_t bytes_read = 0; + while (bytes_read < min_bytes) + { + ProfileEvents::increment(ProfileEvents::ReadBufferFromFileDescriptorRead); + + Stopwatch watch(profile_callback ? 
clock_type : CLOCK_MONOTONIC); + + ssize_t res = 0; + size_t to_read = max_bytes - bytes_read; + { + CurrentMetrics::Increment metric_increment{CurrentMetrics::Read}; + + if (use_pread) + res = ::pread(fd, to + bytes_read, to_read, offset + bytes_read); + else + res = ::read(fd, to + bytes_read, to_read); + } + if (!res) + break; + + if (-1 == res && errno != EINTR) + { + ProfileEvents::increment(ProfileEvents::ReadBufferFromFileDescriptorReadFailed); + throwFromErrnoWithPath("Cannot read from file: " + getFileName(), getFileName(), ErrorCodes::CANNOT_READ_FROM_FILE_DESCRIPTOR); + } + + if (res > 0) + { + bytes_read += res; + if (throttler) + throttler->add(res, ProfileEvents::LocalReadThrottlerBytes, ProfileEvents::LocalReadThrottlerSleepMicroseconds); + } + + + /// It reports real time spent including the time spent while thread was preempted doing nothing. + /// And it is Ok for the purpose of this watch (it is used to lower the number of threads to read from tables). + /// Sometimes it is better to use taskstats::blkio_delay_total, but it is quite expensive to get it + /// (NetlinkMetricsProvider has about 500K RPS). + watch.stop(); + ProfileEvents::increment(ProfileEvents::DiskReadElapsedMicroseconds, watch.elapsedMicroseconds()); + + if (profile_callback) + { + ProfileInfo info; + info.bytes_requested = to_read; + info.bytes_read = res; + info.nanoseconds = watch.elapsed(); + profile_callback(info); + } + } + + if (bytes_read) + ProfileEvents::increment(ProfileEvents::ReadBufferFromFileDescriptorReadBytes, bytes_read); + + return bytes_read; +} + + +bool ReadBufferFromFileDescriptor::nextImpl() +{ + /// If internal_buffer size is empty, then read() cannot be distinguished from EOF + assert(!internal_buffer.empty()); + + size_t bytes_read = readImpl(internal_buffer.begin(), 1, internal_buffer.size(), file_offset_of_buffer_end); + + file_offset_of_buffer_end += bytes_read; + + if (bytes_read) + { + working_buffer = internal_buffer; + working_buffer.resize(bytes_read); + } + else + return false; + + return true; +} + + +void ReadBufferFromFileDescriptor::prefetch(Priority) +{ +#if defined(POSIX_FADV_WILLNEED) + /// For direct IO, loading data into page cache is pointless. + if (required_alignment) + return; + + /// Ask OS to prefetch data into page cache. + if (0 != posix_fadvise(fd, file_offset_of_buffer_end, internal_buffer.size(), POSIX_FADV_WILLNEED)) + throwFromErrno("Cannot posix_fadvise", ErrorCodes::CANNOT_ADVISE); +#endif +} + + +/// If 'offset' is small enough to stay in buffer after seek, then true seek in file does not happen. +off_t ReadBufferFromFileDescriptor::seek(off_t offset, int whence) +{ + size_t new_pos; + if (whence == SEEK_SET) + { + assert(offset >= 0); + new_pos = offset; + } + else if (whence == SEEK_CUR) + { + new_pos = file_offset_of_buffer_end - (working_buffer.end() - pos) + offset; + } + else + { + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "ReadBufferFromFileDescriptor::seek expects SEEK_SET or SEEK_CUR as whence"); + } + + /// Position is unchanged. + if (new_pos + (working_buffer.end() - pos) == file_offset_of_buffer_end) + return new_pos; + + if (file_offset_of_buffer_end - working_buffer.size() <= static_cast<size_t>(new_pos) + && new_pos <= file_offset_of_buffer_end) + { + /// Position is still inside the buffer. + /// Probably it is at the end of the buffer - then we will load data on the following 'next' call. 
+ + pos = working_buffer.end() - file_offset_of_buffer_end + new_pos; + assert(pos >= working_buffer.begin()); + assert(pos <= working_buffer.end()); + + return new_pos; + } + else + { + /// Position is out of the buffer, we need to do real seek. + off_t seek_pos = required_alignment > 1 + ? new_pos / required_alignment * required_alignment + : new_pos; + + off_t offset_after_seek_pos = new_pos - seek_pos; + + /// First reset the buffer so the next read will fetch new data to the buffer. + resetWorkingBuffer(); + + /// In case of using 'pread' we just update the info about the next position in file. + /// In case of using 'read' we call 'lseek'. + + /// We account both cases as seek event as it leads to non-contiguous reads from file. + ProfileEvents::increment(ProfileEvents::Seek); + + if (!use_pread) + { + Stopwatch watch(profile_callback ? clock_type : CLOCK_MONOTONIC); + + off_t res = ::lseek(fd, seek_pos, SEEK_SET); + if (-1 == res) + throwFromErrnoWithPath(fmt::format("Cannot seek through file {} at offset {}", getFileName(), seek_pos), getFileName(), + ErrorCodes::CANNOT_SEEK_THROUGH_FILE); + + /// Also note that seeking past the file size is not allowed. + if (res != seek_pos) + throw Exception(ErrorCodes::CANNOT_SEEK_THROUGH_FILE, + "The 'lseek' syscall returned value ({}) that is not expected ({})", res, seek_pos); + + watch.stop(); + ProfileEvents::increment(ProfileEvents::DiskReadElapsedMicroseconds, watch.elapsedMicroseconds()); + } + + file_offset_of_buffer_end = seek_pos; + + if (offset_after_seek_pos > 0) + ignore(offset_after_seek_pos); + + return seek_pos; + } +} + + +void ReadBufferFromFileDescriptor::rewind() +{ + if (!use_pread) + { + ProfileEvents::increment(ProfileEvents::Seek); + off_t res = ::lseek(fd, 0, SEEK_SET); + if (-1 == res) + throwFromErrnoWithPath("Cannot seek through file " + getFileName(), getFileName(), + ErrorCodes::CANNOT_SEEK_THROUGH_FILE); + } + /// In case of pread, the ProfileEvents::Seek is not accounted, but it's Ok. + + /// Clearing the buffer with existing data. New data will be read on subsequent call to 'next'. + working_buffer.resize(0); + pos = working_buffer.begin(); + file_offset_of_buffer_end = 0; +} + + +/// Assuming file descriptor supports 'select', check that we have data to read or wait until timeout. 
+bool ReadBufferFromFileDescriptor::poll(size_t timeout_microseconds) const +{ + fd_set fds; + FD_ZERO(&fds); + FD_SET(fd, &fds); + timeval timeout = { time_t(timeout_microseconds / 1000000), suseconds_t(timeout_microseconds % 1000000) }; + + int res = select(1, &fds, nullptr, nullptr, &timeout); + + if (-1 == res) + throwFromErrno("Cannot select", ErrorCodes::CANNOT_SELECT); + + return res > 0; +} + + +size_t ReadBufferFromFileDescriptor::getFileSize() +{ + return getSizeFromFileDescriptor(fd, getFileName()); +} + +bool ReadBufferFromFileDescriptor::checkIfActuallySeekable() +{ + struct stat stat; + auto res = ::fstat(fd, &stat); + return res == 0 && S_ISREG(stat.st_mode); +} + +size_t ReadBufferFromFileDescriptor::readBigAt(char * to, size_t n, size_t offset, const std::function<bool(size_t)> &) +{ + chassert(use_pread); + return readImpl(to, n, n, offset); +} + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromFileDescriptor.h b/contrib/clickhouse/src/IO/ReadBufferFromFileDescriptor.h new file mode 100644 index 0000000000..64340770cf --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromFileDescriptor.h @@ -0,0 +1,103 @@ +#pragma once + +#include <IO/ReadBufferFromFileBase.h> +#include <Interpreters/Context_fwd.h> +#include <Common/Throttler_fwd.h> + +#include <unistd.h> + + +namespace DB +{ + +/** Use ready file descriptor. Does not open or close a file. + */ +class ReadBufferFromFileDescriptor : public ReadBufferFromFileBase +{ +protected: + const size_t required_alignment = 0; /// For O_DIRECT both file offsets and memory addresses have to be aligned. + bool use_pread = false; /// To access one fd from multiple threads, use 'pread' syscall instead of 'read'. + + size_t file_offset_of_buffer_end = 0; /// What offset in file corresponds to working_buffer.end(). + + int fd; + + ThrottlerPtr throttler; + + bool nextImpl() override; + void prefetch(Priority priority) override; + + /// Name or some description of file. + std::string getFileName() const override; + + /// Does the read()/pread(), with all the metric increments, error handling, throttling, etc. + /// Doesn't seek (`offset` must match fd's position if !use_pread). + /// Stops after min_bytes or eof. Returns 0 if eof. + /// Thread safe. + size_t readImpl(char * to, size_t min_bytes, size_t max_bytes, size_t offset); + +public: + explicit ReadBufferFromFileDescriptor( + int fd_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt, + ThrottlerPtr throttler_ = {}) + : ReadBufferFromFileBase(buf_size, existing_memory, alignment, file_size_) + , required_alignment(alignment) + , fd(fd_) + , throttler(throttler_) + { + } + + int getFD() const + { + return fd; + } + + off_t getPosition() override + { + return file_offset_of_buffer_end - (working_buffer.end() - pos); + } + + size_t getFileOffsetOfBufferEnd() const override { return file_offset_of_buffer_end; } + + /// If 'offset' is small enough to stay in buffer after seek, then true seek in file does not happen. + off_t seek(off_t off, int whence) override; + + /// Seek to the beginning, discarding already read data if any. Useful to reread file that changes on every read. 
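A short sketch of seek() above and rewind() declared below (the file name is an assumption):

    DB::ReadBufferFromFile buf("/tmp/data.bin");
    buf.seek(128, SEEK_SET);                 /// stays inside working_buffer when possible, no extra syscall
    char header[16];
    buf.readStrict(header, sizeof(header));
    buf.rewind();                            /// back to offset 0, any buffered data is dropped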
+ void rewind(); + + size_t getFileSize() override; + + bool checkIfActuallySeekable() override; + + size_t readBigAt(char * to, size_t n, size_t offset, const std::function<bool(size_t)> &) override; + bool supportsReadAt() override { return use_pread; } + +private: + /// Assuming file descriptor supports 'select', check that we have data to read or wait until timeout. + bool poll(size_t timeout_microseconds) const; +}; + + +/** Similar to ReadBufferFromFileDescriptor but it is using 'pread' allowing multiple concurrent reads from the same fd. + */ +class ReadBufferFromFileDescriptorPRead : public ReadBufferFromFileDescriptor +{ +public: + explicit ReadBufferFromFileDescriptorPRead( + int fd_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0, + std::optional<size_t> file_size_ = std::nullopt, + ThrottlerPtr throttler_ = {}) + : ReadBufferFromFileDescriptor(fd_, buf_size, existing_memory, alignment, file_size_, throttler_) + { + use_pread = true; + } +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromIStream.cpp b/contrib/clickhouse/src/IO/ReadBufferFromIStream.cpp new file mode 100644 index 0000000000..e0c966fb70 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromIStream.cpp @@ -0,0 +1,39 @@ +#include <IO/ReadBufferFromIStream.h> +#include <Common/Exception.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_READ_FROM_ISTREAM; +} + +bool ReadBufferFromIStream::nextImpl() +{ + istr.read(internal_buffer.begin(), internal_buffer.size()); + size_t gcount = istr.gcount(); + + if (!gcount) + { + if (istr.eof()) + return false; + + if (istr.fail()) + throw Exception(ErrorCodes::CANNOT_READ_FROM_ISTREAM, "Cannot read from istream at offset {}", count()); + + throw Exception(ErrorCodes::CANNOT_READ_FROM_ISTREAM, "Unexpected state of istream at offset {}", count()); + } + else + working_buffer.resize(gcount); + + return true; +} + +ReadBufferFromIStream::ReadBufferFromIStream(std::istream & istr_, size_t size) + : BufferWithOwnMemory<ReadBuffer>(size), istr(istr_) +{ +} + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromIStream.h b/contrib/clickhouse/src/IO/ReadBufferFromIStream.h new file mode 100644 index 0000000000..8c3f62728b --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromIStream.h @@ -0,0 +1,21 @@ +#pragma once + +#include <IO/ReadBuffer.h> +#include <IO/BufferWithOwnMemory.h> + + +namespace DB +{ + +class ReadBufferFromIStream : public BufferWithOwnMemory<ReadBuffer> +{ +private: + std::istream & istr; + + bool nextImpl() override; + +public: + explicit ReadBufferFromIStream(std::istream & istr_, size_t size = DBMS_DEFAULT_BUFFER_SIZE); +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromMemory.cpp b/contrib/clickhouse/src/IO/ReadBufferFromMemory.cpp new file mode 100644 index 0000000000..ede2c531e4 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromMemory.cpp @@ -0,0 +1,47 @@ +#include "ReadBufferFromMemory.h" + +namespace DB +{ +namespace ErrorCodes +{ + extern const int CANNOT_SEEK_THROUGH_FILE; + extern const int SEEK_POSITION_OUT_OF_BOUND; +} + +off_t ReadBufferFromMemory::seek(off_t offset, int whence) +{ + if (whence == SEEK_SET) + { + if (offset >= 0 && internal_buffer.begin() + offset <= internal_buffer.end()) + { + pos = internal_buffer.begin() + offset; + working_buffer = internal_buffer; /// We need to restore `working_buffer` in case the position was at EOF before this seek(). 
+ return static_cast<size_t>(pos - internal_buffer.begin()); + } + else + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Seek position is out of bounds. Offset: {}, Max: {}", + offset, std::to_string(static_cast<size_t>(internal_buffer.end() - internal_buffer.begin()))); + } + else if (whence == SEEK_CUR) + { + Position new_pos = pos + offset; + if (new_pos >= internal_buffer.begin() && new_pos <= internal_buffer.end()) + { + pos = new_pos; + working_buffer = internal_buffer; /// We need to restore `working_buffer` in case the position was at EOF before this seek(). + return static_cast<size_t>(pos - internal_buffer.begin()); + } + else + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Seek position is out of bounds. Offset: {}, Max: {}", + offset, std::to_string(static_cast<size_t>(internal_buffer.end() - internal_buffer.begin()))); + } + else + throw Exception(ErrorCodes::CANNOT_SEEK_THROUGH_FILE, "Only SEEK_SET and SEEK_CUR seek modes allowed."); +} + +off_t ReadBufferFromMemory::getPosition() +{ + return pos - internal_buffer.begin(); +} + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromMemory.h b/contrib/clickhouse/src/IO/ReadBufferFromMemory.h new file mode 100644 index 0000000000..ad96e4bfa2 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromMemory.h @@ -0,0 +1,27 @@ +#pragma once + +#include "SeekableReadBuffer.h" + + +namespace DB +{ +/** Allows to read from memory range. + * In comparison with just ReadBuffer, it only adds convenient constructors, that do const_cast. + * In fact, ReadBuffer will not modify data in buffer, but it requires non-const pointer. + */ +class ReadBufferFromMemory : public SeekableReadBuffer +{ +public: + template <typename CharT> + requires (sizeof(CharT) == 1) + ReadBufferFromMemory(const CharT * buf, size_t size) + : SeekableReadBuffer(const_cast<char *>(reinterpret_cast<const char *>(buf)), size, 0) {} + explicit ReadBufferFromMemory(const std::string_view&& str) + : SeekableReadBuffer(const_cast<char *>(str.data()), str.size(), 0) {} + + off_t seek(off_t off, int whence) override; + + off_t getPosition() override; +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromPocoSocket.cpp b/contrib/clickhouse/src/IO/ReadBufferFromPocoSocket.cpp new file mode 100644 index 0000000000..ff72dc5386 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromPocoSocket.cpp @@ -0,0 +1,129 @@ +#include <Poco/Net/NetException.h> + +#include <base/scope_guard.h> + +#include <IO/ReadBufferFromPocoSocket.h> +#include <Common/Exception.h> +#include <Common/NetException.h> +#include <Common/Stopwatch.h> +#include <Common/ProfileEvents.h> +#include <Common/CurrentMetrics.h> +#include <Common/AsyncTaskExecutor.h> +#include <Common/checkSSLReturnCode.h> + +namespace ProfileEvents +{ + extern const Event NetworkReceiveElapsedMicroseconds; + extern const Event NetworkReceiveBytes; +} + +namespace CurrentMetrics +{ + extern const Metric NetworkReceive; +} + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NETWORK_ERROR; + extern const int SOCKET_TIMEOUT; + extern const int CANNOT_READ_FROM_SOCKET; + extern const int LOGICAL_ERROR; +} + +bool ReadBufferFromPocoSocket::nextImpl() +{ + ssize_t bytes_read = 0; + Stopwatch watch; + + SCOPE_EXIT({ + /// NOTE: it is quite inaccurate on high loads since the thread could be replaced by another one + ProfileEvents::increment(ProfileEvents::NetworkReceiveElapsedMicroseconds, watch.elapsedMicroseconds()); + ProfileEvents::increment(ProfileEvents::NetworkReceiveBytes, bytes_read); + 
}); + + /// Add more details to exceptions. + try + { + CurrentMetrics::Increment metric_increment(CurrentMetrics::NetworkReceive); + + if (internal_buffer.size() > INT_MAX) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Buffer overflow"); + + /// If async_callback is specified, set socket to non-blocking mode + /// and try to read data from it, if socket is not ready for reading, + /// run async_callback and try again later. + /// It is expected that file descriptor may be polled externally. + /// Note that send timeout is not checked here. External code should check it while polling. + if (async_callback) + { + socket.setBlocking(false); + SCOPE_EXIT(socket.setBlocking(true)); + bool secure = socket.secure(); + bytes_read = socket.impl()->receiveBytes(internal_buffer.begin(), static_cast<int>(internal_buffer.size())); + + /// Check EAGAIN and ERR_SSL_WANT_READ/ERR_SSL_WANT_WRITE for secure socket (reading from secure socket can write too). + while (bytes_read < 0 && (errno == EAGAIN || (secure && (checkSSLWantRead(bytes_read) || checkSSLWantWrite(bytes_read))))) + { + /// In case of ERR_SSL_WANT_WRITE we should wait for socket to be ready for writing, otherwise - for reading. + if (secure && checkSSLWantWrite(bytes_read)) + async_callback(socket.impl()->sockfd(), socket.getSendTimeout(), AsyncEventTimeoutType::SEND, socket_description, AsyncTaskExecutor::Event::WRITE | AsyncTaskExecutor::Event::ERROR); + else + async_callback(socket.impl()->sockfd(), socket.getReceiveTimeout(), AsyncEventTimeoutType::RECEIVE, socket_description, AsyncTaskExecutor::Event::READ | AsyncTaskExecutor::Event::ERROR); + + /// Try to read again. + bytes_read = socket.impl()->receiveBytes(internal_buffer.begin(), static_cast<int>(internal_buffer.size())); + } + } + else + { + bytes_read = socket.impl()->receiveBytes(internal_buffer.begin(), static_cast<int>(internal_buffer.size())); + } + } + catch (const Poco::Net::NetException & e) + { + throw NetException(ErrorCodes::NETWORK_ERROR, "{}, while reading from socket ({})", e.displayText(), peer_address.toString()); + } + catch (const Poco::TimeoutException &) + { + throw NetException(ErrorCodes::SOCKET_TIMEOUT, "Timeout exceeded while reading from socket ({}, {} ms)", + peer_address.toString(), + socket.impl()->getReceiveTimeout().totalMilliseconds()); + } + catch (const Poco::IOException & e) + { + throw NetException(ErrorCodes::NETWORK_ERROR, "{}, while reading from socket ({})", e.displayText(), peer_address.toString()); + } + + if (bytes_read < 0) + throw NetException(ErrorCodes::CANNOT_READ_FROM_SOCKET, "Cannot read from socket ({})", peer_address.toString()); + + if (bytes_read) + working_buffer.resize(bytes_read); + else + return false; + + return true; +} + +ReadBufferFromPocoSocket::ReadBufferFromPocoSocket(Poco::Net::Socket & socket_, size_t buf_size) + : BufferWithOwnMemory<ReadBuffer>(buf_size) + , socket(socket_) + , peer_address(socket.peerAddress()) + , socket_description("socket (" + peer_address.toString() + ")") +{ +} + +bool ReadBufferFromPocoSocket::poll(size_t timeout_microseconds) const +{ + if (available()) + return true; + + Stopwatch watch; + bool res = socket.poll(timeout_microseconds, Poco::Net::Socket::SELECT_READ | Poco::Net::Socket::SELECT_ERROR); + ProfileEvents::increment(ProfileEvents::NetworkReceiveElapsedMicroseconds, watch.elapsedMicroseconds()); + return res; +} + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromPocoSocket.h b/contrib/clickhouse/src/IO/ReadBufferFromPocoSocket.h new file mode 100644 index 
0000000000..dab4ac8629 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromPocoSocket.h @@ -0,0 +1,37 @@ +#pragma once + +#include <IO/BufferWithOwnMemory.h> +#include <IO/ReadBuffer.h> +#include <Common/AsyncTaskExecutor.h> +#include <Poco/Net/Socket.h> + +namespace DB +{ + +/// Works with the ready Poco::Net::Socket. Blocking operations. +class ReadBufferFromPocoSocket : public BufferWithOwnMemory<ReadBuffer> +{ +protected: + Poco::Net::Socket & socket; + + /** For error messages. It is necessary to receive this address in advance, because, + * for example, if the connection is broken, the address will not be received anymore + * (getpeername will return an error). + */ + Poco::Net::SocketAddress peer_address; + + bool nextImpl() override; + +public: + explicit ReadBufferFromPocoSocket(Poco::Net::Socket & socket_, size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE); + + bool poll(size_t timeout_microseconds) const; + + void setAsyncCallback(AsyncCallback async_callback_) { async_callback = std::move(async_callback_); } + +private: + AsyncCallback async_callback; + std::string socket_description; +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadBufferFromS3.cpp b/contrib/clickhouse/src/IO/ReadBufferFromS3.cpp new file mode 100644 index 0000000000..1658f03f85 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromS3.cpp @@ -0,0 +1,498 @@ +#include <IO/HTTPCommon.h> +#include <IO/S3Common.h> +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <IO/ReadBufferFromIStream.h> +#include <IO/ReadBufferFromS3.h> +#include <IO/ResourceGuard.h> +#include <IO/S3/getObjectInfo.h> +#include <IO/S3/Requests.h> + +#include <Common/Stopwatch.h> +#include <Common/Throttler.h> +#include <Common/logger_useful.h> +#include <Common/ElapsedTimeProfileEventIncrement.h> +#include <base/sleep.h> + +#include <utility> + + +namespace ProfileEvents +{ + extern const Event ReadBufferFromS3Microseconds; + extern const Event ReadBufferFromS3InitMicroseconds; + extern const Event ReadBufferFromS3Bytes; + extern const Event ReadBufferFromS3RequestsErrors; + extern const Event ReadBufferFromS3ResetSessions; + extern const Event ReadBufferFromS3PreservedSessions; + extern const Event ReadBufferSeekCancelConnection; + extern const Event S3GetObject; + extern const Event DiskS3GetObject; + extern const Event RemoteReadThrottlerBytes; + extern const Event RemoteReadThrottlerSleepMicroseconds; +} + +namespace +{ +DB::PooledHTTPSessionPtr getSession(Aws::S3::Model::GetObjectResult & read_result) +{ + if (auto * session_aware_stream = dynamic_cast<DB::S3::SessionAwareIOStream<DB::PooledHTTPSessionPtr> *>(&read_result.GetBody())) + return static_cast<DB::PooledHTTPSessionPtr &>(session_aware_stream->getSession()); + + if (dynamic_cast<DB::S3::SessionAwareIOStream<DB::HTTPSessionPtr> *>(&read_result.GetBody())) + return {}; + + /// accept result from S# mock in gtest_writebuffer_s3.cpp + if (dynamic_cast<Aws::Utils::Stream::DefaultUnderlyingStream *>(&read_result.GetBody())) + return {}; + + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Session of unexpected type encountered"); +} + +void resetSession(Aws::S3::Model::GetObjectResult & read_result) +{ + if (auto session = getSession(read_result); !session.isNull()) + { + auto & http_session = static_cast<Poco::Net::HTTPClientSession &>(*session); + http_session.reset(); + } +} + +void resetSessionIfNeeded(bool read_all_range_successfully, std::optional<Aws::S3::Model::GetObjectResult> & read_result) +{ + if (!read_result) + return; + + if 
(!read_all_range_successfully) + { + /// When we abandon a session with an ongoing GetObject request and there is another one trying to delete the same object this delete + /// operation will hang until GetObject's session idle timeouts. So we have to call `reset()` on GetObject's session session immediately. + resetSession(*read_result); + ProfileEvents::increment(ProfileEvents::ReadBufferFromS3ResetSessions); + } + else if (auto session = getSession(*read_result); !session.isNull()) + { + DB::markSessionForReuse(session); + ProfileEvents::increment(ProfileEvents::ReadBufferFromS3PreservedSessions); + } +} +} + +namespace DB +{ +namespace ErrorCodes +{ + extern const int S3_ERROR; + extern const int CANNOT_SEEK_THROUGH_FILE; + extern const int SEEK_POSITION_OUT_OF_BOUND; + extern const int LOGICAL_ERROR; + extern const int CANNOT_ALLOCATE_MEMORY; +} + + +ReadBufferFromS3::ReadBufferFromS3( + std::shared_ptr<const S3::Client> client_ptr_, + const String & bucket_, + const String & key_, + const String & version_id_, + const S3Settings::RequestSettings & request_settings_, + const ReadSettings & settings_, + bool use_external_buffer_, + size_t offset_, + size_t read_until_position_, + bool restricted_seek_, + std::optional<size_t> file_size_) + : ReadBufferFromFileBase(use_external_buffer_ ? 0 : settings_.remote_fs_buffer_size, nullptr, 0, file_size_) + , client_ptr(std::move(client_ptr_)) + , bucket(bucket_) + , key(key_) + , version_id(version_id_) + , request_settings(request_settings_) + , offset(offset_) + , read_until_position(read_until_position_) + , read_settings(settings_) + , use_external_buffer(use_external_buffer_) + , restricted_seek(restricted_seek_) +{ +} + +bool ReadBufferFromS3::nextImpl() +{ + if (read_until_position) + { + if (read_until_position == offset) + return false; + + if (read_until_position < offset) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to read beyond right offset ({} > {})", offset, read_until_position - 1); + } + + bool next_result = false; + + if (impl) + { + if (use_external_buffer) + { + /** + * use_external_buffer -- means we read into the buffer which + * was passed to us from somewhere else. We do not check whether + * previously returned buffer was read or not (no hasPendingData() check is needed), + * because this branch means we are prefetching data, + * each nextImpl() call we can fill a different buffer. + */ + impl->set(internal_buffer.begin(), internal_buffer.size()); + assert(working_buffer.begin() != nullptr); + assert(!internal_buffer.empty()); + } + else + { + /** + * impl was initialized before, pass position() to it to make + * sure there is no pending data which was not read. + */ + impl->position() = position(); + assert(!impl->hasPendingData()); + } + } + + size_t sleep_time_with_backoff_milliseconds = 100; + for (size_t attempt = 0; !next_result; ++attempt) + { + bool last_attempt = attempt + 1 >= request_settings.max_single_read_retries; + + ProfileEventTimeIncrement<Microseconds> watch(ProfileEvents::ReadBufferFromS3Microseconds); + + try + { + if (!impl) + { + impl = initialize(); + + if (use_external_buffer) + { + impl->set(internal_buffer.begin(), internal_buffer.size()); + assert(working_buffer.begin() != nullptr); + assert(!internal_buffer.empty()); + } + else + { + /// use the buffer returned by `impl` + BufferBase::set(impl->buffer().begin(), impl->buffer().size(), impl->offset()); + } + } + + /// Try to read a next portion of data. 
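/// If next() succeeds we leave the retry loop right away; on exception, processException() decides
/// whether another attempt (after the exponential backoff below) makes sense.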
+ next_result = impl->next(); + break; + } + catch (Exception & e) + { + if (!processException(e, getPosition(), attempt) || last_attempt) + throw; + + /// Pause before next attempt. + sleepForMilliseconds(sleep_time_with_backoff_milliseconds); + sleep_time_with_backoff_milliseconds *= 2; + + /// Try to reinitialize `impl`. + resetWorkingBuffer(); + impl.reset(); + } + } + + if (!next_result) + { + read_all_range_successfully = true; + return false; + } + + BufferBase::set(impl->buffer().begin(), impl->buffer().size(), impl->offset()); + + ProfileEvents::increment(ProfileEvents::ReadBufferFromS3Bytes, working_buffer.size()); + offset += working_buffer.size(); + if (read_settings.remote_throttler) + read_settings.remote_throttler->add(working_buffer.size(), ProfileEvents::RemoteReadThrottlerBytes, ProfileEvents::RemoteReadThrottlerSleepMicroseconds); + + return true; +} + + +size_t ReadBufferFromS3::readBigAt(char * to, size_t n, size_t range_begin, const std::function<bool(size_t)> & progress_callback) +{ + if (n == 0) + return 0; + + size_t sleep_time_with_backoff_milliseconds = 100; + for (size_t attempt = 0;; ++attempt) + { + bool last_attempt = attempt + 1 >= request_settings.max_single_read_retries; + + ProfileEventTimeIncrement<Microseconds> watch(ProfileEvents::ReadBufferFromS3Microseconds); + + try + { + auto result = sendRequest(range_begin, range_begin + n - 1); + std::istream & istr = result.GetBody(); + + size_t bytes = copyFromIStreamWithProgressCallback(istr, to, n, progress_callback); + + ProfileEvents::increment(ProfileEvents::ReadBufferFromS3Bytes, bytes); + + if (read_settings.remote_throttler) + read_settings.remote_throttler->add(bytes, ProfileEvents::RemoteReadThrottlerBytes, ProfileEvents::RemoteReadThrottlerSleepMicroseconds); + + return bytes; + } + catch (Poco::Exception & e) + { + if (!processException(e, range_begin, attempt) || last_attempt) + throw; + + sleepForMilliseconds(sleep_time_with_backoff_milliseconds); + sleep_time_with_backoff_milliseconds *= 2; + } + } +} + +bool ReadBufferFromS3::processException(Poco::Exception & e, size_t read_offset, size_t attempt) const +{ + ProfileEvents::increment(ProfileEvents::ReadBufferFromS3RequestsErrors, 1); + + LOG_DEBUG( + log, + "Caught exception while reading S3 object. Bucket: {}, Key: {}, Version: {}, Offset: {}, " + "Attempt: {}, Message: {}", + bucket, key, version_id.empty() ? 
"Latest" : version_id, read_offset, attempt, e.message()); + + + if (auto * s3_exception = dynamic_cast<S3Exception *>(&e)) + { + /// It doesn't make sense to retry Access Denied or No Such Key + if (!s3_exception->isRetryableError()) + { + s3_exception->addMessage("while reading key: {}, from bucket: {}", key, bucket); + return false; + } + } + + /// It doesn't make sense to retry allocator errors + if (e.code() == ErrorCodes::CANNOT_ALLOCATE_MEMORY) + { + tryLogCurrentException(log); + return false; + } + + return true; +} + + +off_t ReadBufferFromS3::seek(off_t offset_, int whence) +{ + if (offset_ == getPosition() && whence == SEEK_SET) + return offset_; + + read_all_range_successfully = false; + + if (impl && restricted_seek) + { + throw Exception( + ErrorCodes::CANNOT_SEEK_THROUGH_FILE, + "Seek is allowed only before first read attempt from the buffer (current offset: " + "{}, new offset: {}, reading until position: {}, available: {})", + getPosition(), offset_, read_until_position, available()); + } + + if (whence != SEEK_SET) + throw Exception(ErrorCodes::CANNOT_SEEK_THROUGH_FILE, "Only SEEK_SET mode is allowed."); + + if (offset_ < 0) + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Seek position is out of bounds. Offset: {}", offset_); + + if (!restricted_seek) + { + if (!working_buffer.empty() + && static_cast<size_t>(offset_) >= offset - working_buffer.size() + && offset_ < offset) + { + pos = working_buffer.end() - (offset - offset_); + assert(pos >= working_buffer.begin()); + assert(pos < working_buffer.end()); + + return getPosition(); + } + + off_t position = getPosition(); + if (impl && offset_ > position) + { + size_t diff = offset_ - position; + if (diff < read_settings.remote_read_min_bytes_for_seek) + { + ignore(diff); + return offset_; + } + } + + resetWorkingBuffer(); + if (impl) + { + if (!atEndOfRequestedRangeGuess()) + ProfileEvents::increment(ProfileEvents::ReadBufferSeekCancelConnection); + impl.reset(); + } + } + + offset = offset_; + return offset; +} + +size_t ReadBufferFromS3::getFileSize() +{ + if (file_size) + return *file_size; + + auto object_size = S3::getObjectSize(*client_ptr, bucket, key, version_id, request_settings, /* for_disk_s3= */ read_settings.for_object_storage); + + file_size = object_size; + return *file_size; +} + +off_t ReadBufferFromS3::getPosition() +{ + return offset - available(); +} + +void ReadBufferFromS3::setReadUntilPosition(size_t position) +{ + if (position != static_cast<size_t>(read_until_position)) + { + read_all_range_successfully = false; + + if (impl) + { + if (!atEndOfRequestedRangeGuess()) + ProfileEvents::increment(ProfileEvents::ReadBufferSeekCancelConnection); + offset = getPosition(); + resetWorkingBuffer(); + impl.reset(); + } + read_until_position = position; + } +} + +void ReadBufferFromS3::setReadUntilEnd() +{ + if (read_until_position) + { + read_all_range_successfully = false; + + read_until_position = 0; + if (impl) + { + if (!atEndOfRequestedRangeGuess()) + ProfileEvents::increment(ProfileEvents::ReadBufferSeekCancelConnection); + offset = getPosition(); + resetWorkingBuffer(); + impl.reset(); + } + } +} + +bool ReadBufferFromS3::atEndOfRequestedRangeGuess() +{ + if (!impl) + return true; + if (read_until_position) + return getPosition() >= read_until_position; + if (file_size) + return getPosition() >= static_cast<off_t>(*file_size); + return false; +} + +ReadBufferFromS3::~ReadBufferFromS3() +{ + try + { + resetSessionIfNeeded(readAllRangeSuccessfully(), read_result); + } + catch (...) 
+ { + tryLogCurrentException(log); + } +} + +std::unique_ptr<ReadBuffer> ReadBufferFromS3::initialize() +{ + resetSessionIfNeeded(readAllRangeSuccessfully(), read_result); + read_all_range_successfully = false; + + /** + * If remote_filesystem_read_method = 'threadpool', then for MergeTree family tables + * exact byte ranges to read are always passed here. + */ + if (read_until_position && offset >= read_until_position) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to read beyond right offset ({} > {})", offset, read_until_position - 1); + + read_result = sendRequest(offset, read_until_position ? std::make_optional(read_until_position - 1) : std::nullopt); + + size_t buffer_size = use_external_buffer ? 0 : read_settings.remote_fs_buffer_size; + return std::make_unique<ReadBufferFromIStream>(read_result->GetBody(), buffer_size); +} + +Aws::S3::Model::GetObjectResult ReadBufferFromS3::sendRequest(size_t range_begin, std::optional<size_t> range_end_incl) const +{ + S3::GetObjectRequest req; + req.SetBucket(bucket); + req.SetKey(key); + if (!version_id.empty()) + req.SetVersionId(version_id); + + if (range_end_incl) + { + req.SetRange(fmt::format("bytes={}-{}", range_begin, *range_end_incl)); + LOG_TEST( + log, "Read S3 object. Bucket: {}, Key: {}, Version: {}, Range: {}-{}", + bucket, key, version_id.empty() ? "Latest" : version_id, range_begin, *range_end_incl); + } + else if (range_begin) + { + req.SetRange(fmt::format("bytes={}-", range_begin)); + LOG_TEST( + log, "Read S3 object. Bucket: {}, Key: {}, Version: {}, Offset: {}", + bucket, key, version_id.empty() ? "Latest" : version_id, range_begin); + } + + ProfileEvents::increment(ProfileEvents::S3GetObject); + if (read_settings.for_object_storage) + ProfileEvents::increment(ProfileEvents::DiskS3GetObject); + + ProfileEventTimeIncrement<Microseconds> watch(ProfileEvents::ReadBufferFromS3InitMicroseconds); + + // We do not know in advance how many bytes we are going to consume, to avoid blocking estimated it from below + constexpr ResourceCost estimated_cost = 1; + ResourceGuard rlock(read_settings.resource_link, estimated_cost); + Aws::S3::Model::GetObjectOutcome outcome = client_ptr->GetObject(req); + rlock.unlock(); + + if (outcome.IsSuccess()) + { + ResourceCost bytes_read = outcome.GetResult().GetContentLength(); + read_settings.resource_link.adjust(estimated_cost, bytes_read); + return outcome.GetResultWithOwnership(); + } + else + { + read_settings.resource_link.accumulate(estimated_cost); + const auto & error = outcome.GetError(); + throw S3Exception(error.GetMessage(), error.GetErrorType()); + } +} + +bool ReadBufferFromS3::readAllRangeSuccessfully() const +{ + return read_until_position ? offset == read_until_position : read_all_range_successfully; +} +} + +#endif diff --git a/contrib/clickhouse/src/IO/ReadBufferFromS3.h b/contrib/clickhouse/src/IO/ReadBufferFromS3.h new file mode 100644 index 0000000000..94697df1a0 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromS3.h @@ -0,0 +1,108 @@ +#pragma once + +#include <Storages/StorageS3Settings.h> +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <memory> + +#include <IO/HTTPCommon.h> +#include <IO/ParallelReadBuffer.h> +#include <IO/ReadBuffer.h> +#include <IO/ReadSettings.h> +#include <IO/ReadBufferFromFileBase.h> +#include <IO/WithFileName.h> + +#include <aws/s3/model/GetObjectResult.h> + +namespace DB +{ +/** + * Perform S3 HTTP GET request and provide response to read. 
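 *
 * A construction sketch (the client, bucket/key and settings objects are assumed to exist elsewhere):
 *
 *     ReadBufferFromS3 buf(client, "my-bucket", "path/to/key", "", request_settings, read_settings);
 *     String content;
 *     char chunk[4096];
 *     while (size_t n = buf.read(chunk, sizeof(chunk)))
 *         content.append(chunk, n);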
+ */ +class ReadBufferFromS3 : public ReadBufferFromFileBase +{ +private: + std::shared_ptr<const S3::Client> client_ptr; + String bucket; + String key; + String version_id; + const S3Settings::RequestSettings request_settings; + + /// These variables are atomic because they can be used for `logging only` + /// (where it is not important to get consistent result) + /// from separate thread other than the one which uses the buffer for s3 reading. + std::atomic<off_t> offset = 0; + std::atomic<off_t> read_until_position = 0; + + std::optional<Aws::S3::Model::GetObjectResult> read_result; + std::unique_ptr<ReadBuffer> impl; + + Poco::Logger * log = &Poco::Logger::get("ReadBufferFromS3"); + +public: + ReadBufferFromS3( + std::shared_ptr<const S3::Client> client_ptr_, + const String & bucket_, + const String & key_, + const String & version_id_, + const S3Settings::RequestSettings & request_settings_, + const ReadSettings & settings_, + bool use_external_buffer = false, + size_t offset_ = 0, + size_t read_until_position_ = 0, + bool restricted_seek_ = false, + std::optional<size_t> file_size = std::nullopt); + + ~ReadBufferFromS3() override; + + bool nextImpl() override; + + off_t seek(off_t off, int whence) override; + + off_t getPosition() override; + + size_t getFileSize() override; + + void setReadUntilPosition(size_t position) override; + void setReadUntilEnd() override; + + size_t getFileOffsetOfBufferEnd() const override { return offset; } + + bool supportsRightBoundedReads() const override { return true; } + + String getFileName() const override { return bucket + "/" + key; } + + size_t readBigAt(char * to, size_t n, size_t range_begin, const std::function<bool(size_t)> & progress_callback) override; + + bool supportsReadAt() override { return true; } + +private: + std::unique_ptr<ReadBuffer> initialize(); + + /// If true, if we destroy impl now, no work was wasted. Just for metrics. + bool atEndOfRequestedRangeGuess(); + + /// Call inside catch() block if GetObject fails. Bumps metrics, logs the error. + /// Returns true if the error looks retriable. + bool processException(Poco::Exception & e, size_t read_offset, size_t attempt) const; + + Aws::S3::Model::GetObjectResult sendRequest(size_t range_begin, std::optional<size_t> range_end_incl) const; + + bool readAllRangeSuccessfully() const; + + ReadSettings read_settings; + + bool use_external_buffer; + + /// There is different seek policy for disk seek and for non-disk seek + /// (non-disk seek is applied for seekable input formats: orc, arrow, parquet). + bool restricted_seek; + + bool read_all_range_successfully = false; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/ReadBufferFromString.h b/contrib/clickhouse/src/IO/ReadBufferFromString.h new file mode 100644 index 0000000000..f20e319b93 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadBufferFromString.h @@ -0,0 +1,28 @@ +#pragma once + +#include <IO/ReadBufferFromMemory.h> + +namespace DB +{ + +/// Allows to read from std::string-like object. 
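+/// Illustrative usage sketch (not part of the upstream sources):
+///     String data = "123\tabc\n";
+///     ReadBufferFromString in(data);
+///     UInt64 x;
+///     readIntText(x, in);    /// x == 123
+///     assertChar('\t', in);
+///     String word;
+///     readString(word, in);  /// word == "abc"; readString stops at '\t' or '\n'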
+class ReadBufferFromString : public ReadBufferFromMemory +{ +public: + /// std::string or something similar + template <typename S> + explicit ReadBufferFromString(const S & s) : ReadBufferFromMemory(s.data(), s.size()) {} + + explicit ReadBufferFromString(std::string_view s) : ReadBufferFromMemory(s.data(), s.size()) {} +}; + +class ReadBufferFromOwnString : public String, public ReadBufferFromString +{ +public: + template <typename S> + explicit ReadBufferFromOwnString(S && s_) : String(std::forward<S>(s_)), ReadBufferFromString(*this) + { + } +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadHelpers.cpp b/contrib/clickhouse/src/IO/ReadHelpers.cpp new file mode 100644 index 0000000000..bf3215d582 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadHelpers.cpp @@ -0,0 +1,1712 @@ +#include <Core/Defines.h> +#include <base/hex.h> +#include <Common/PODArray.h> +#include <Common/StringUtils/StringUtils.h> +#include <Common/memcpySmall.h> +#include <Formats/FormatSettings.h> +#include <IO/WriteBufferFromString.h> +#include <IO/BufferWithOwnMemory.h> +#include <IO/readFloatText.h> +#include <IO/Operators.h> +#include <base/find_symbols.h> +#include <cstdlib> +#include <bit> + +#include <base/simd.h> + +#ifdef __SSE2__ + #include <emmintrin.h> +#endif + +#if defined(__aarch64__) && defined(__ARM_NEON) +# include <arm_neon.h> +# pragma clang diagnostic ignored "-Wreserved-identifier" +#endif + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED; + extern const int CANNOT_PARSE_ESCAPE_SEQUENCE; + extern const int CANNOT_PARSE_QUOTED_STRING; + extern const int CANNOT_PARSE_DATETIME; + extern const int CANNOT_PARSE_DATE; + extern const int CANNOT_PARSE_UUID; + extern const int INCORRECT_DATA; + extern const int ATTEMPT_TO_READ_AFTER_EOF; + extern const int LOGICAL_ERROR; + extern const int BAD_ARGUMENTS; +} + +template <size_t num_bytes, typename IteratorSrc, typename IteratorDst> +inline void parseHex(IteratorSrc src, IteratorDst dst) +{ + size_t src_pos = 0; + size_t dst_pos = 0; + for (; dst_pos < num_bytes; ++dst_pos, src_pos += 2) + dst[dst_pos] = unhex2(reinterpret_cast<const char *>(&src[src_pos])); +} + +UUID parseUUID(std::span<const UInt8> src) +{ + UUID uuid; + const auto * src_ptr = src.data(); + const auto size = src.size(); + +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + const std::reverse_iterator dst(reinterpret_cast<UInt8 *>(&uuid) + sizeof(UUID)); +#else + auto * dst = reinterpret_cast<UInt8 *>(&uuid); +#endif + if (size == 36) + { + parseHex<4>(src_ptr, dst + 8); + parseHex<2>(src_ptr + 9, dst + 12); + parseHex<2>(src_ptr + 14, dst + 14); + parseHex<2>(src_ptr + 19, dst); + parseHex<6>(src_ptr + 24, dst + 2); + } + else if (size == 32) + { + parseHex<8>(src_ptr, dst + 8); + parseHex<8>(src_ptr + 16, dst); + } + else + throw Exception(ErrorCodes::CANNOT_PARSE_UUID, "Unexpected length when trying to parse UUID ({})", size); + + return uuid; +} + +void NO_INLINE throwAtAssertionFailed(const char * s, ReadBuffer & buf) +{ + WriteBufferFromOwnString out; + out << quote << s; + + if (buf.eof()) + out << " at end of stream."; + else + out << " before: " << quote << String(buf.position(), std::min(SHOW_CHARS_ON_SYNTAX_ERROR, buf.buffer().end() - buf.position())); + + throw ParsingException(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "Cannot parse input: expected {}", out.str()); +} + + +bool checkString(const char * s, ReadBuffer & buf) +{ + for (; *s; ++s) + { + if (buf.eof() || *buf.position() != *s) + return false; + 
++buf.position(); + } + return true; +} + + +bool checkStringCaseInsensitive(const char * s, ReadBuffer & buf) +{ + for (; *s; ++s) + { + if (buf.eof()) + return false; + + char c = *buf.position(); + if (!equalsCaseInsensitive(*s, c)) + return false; + + ++buf.position(); + } + return true; +} + + +void assertString(const char * s, ReadBuffer & buf) +{ + if (!checkString(s, buf)) + throwAtAssertionFailed(s, buf); +} + + +void assertEOF(ReadBuffer & buf) +{ + if (!buf.eof()) + throwAtAssertionFailed("eof", buf); +} + +void assertNotEOF(ReadBuffer & buf) +{ + if (buf.eof()) + throw Exception(ErrorCodes::ATTEMPT_TO_READ_AFTER_EOF, "Attempt to read after EOF"); +} + + +void assertStringCaseInsensitive(const char * s, ReadBuffer & buf) +{ + if (!checkStringCaseInsensitive(s, buf)) + throwAtAssertionFailed(s, buf); +} + + +bool checkStringByFirstCharacterAndAssertTheRest(const char * s, ReadBuffer & buf) +{ + if (buf.eof() || *buf.position() != *s) + return false; + + assertString(s, buf); + return true; +} + +bool checkStringByFirstCharacterAndAssertTheRestCaseInsensitive(const char * s, ReadBuffer & buf) +{ + if (buf.eof()) + return false; + + char c = *buf.position(); + if (!equalsCaseInsensitive(*s, c)) + return false; + + assertStringCaseInsensitive(s, buf); + return true; +} + + +template <typename T> +static void appendToStringOrVector(T & s, ReadBuffer & rb, const char * end) +{ + s.append(rb.position(), end - rb.position()); +} + +template <> +inline void appendToStringOrVector(PaddedPODArray<UInt8> & s, ReadBuffer & rb, const char * end) +{ + if (rb.isPadded()) + s.insertSmallAllowReadWriteOverflow15(rb.position(), end); + else + s.insert(rb.position(), end); +} + +template <> +inline void appendToStringOrVector(PODArray<char> & s, ReadBuffer & rb, const char * end) +{ + s.insert(rb.position(), end); +} + +template <char... 
chars, typename Vector> +void readStringUntilCharsInto(Vector & s, ReadBuffer & buf) +{ + while (!buf.eof()) + { + char * next_pos = find_first_symbols<chars...>(buf.position(), buf.buffer().end()); + + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + + if (buf.hasPendingData()) + return; + } +} + +template <typename Vector> +void readStringInto(Vector & s, ReadBuffer & buf) +{ + readStringUntilCharsInto<'\t', '\n'>(s, buf); +} + +template <typename Vector> +void readStringUntilWhitespaceInto(Vector & s, ReadBuffer & buf) +{ + readStringUntilCharsInto<' '>(s, buf); +} + +template <typename Vector> +void readStringUntilNewlineInto(Vector & s, ReadBuffer & buf) +{ + readStringUntilCharsInto<'\n'>(s, buf); +} + +template void readStringUntilNewlineInto<PaddedPODArray<UInt8>>(PaddedPODArray<UInt8> & s, ReadBuffer & buf); +template void readStringUntilNewlineInto<String>(String & s, ReadBuffer & buf); + +template <typename Vector> +void readNullTerminated(Vector & s, ReadBuffer & buf) +{ + readStringUntilCharsInto<'\0'>(s, buf); + buf.ignore(); +} + +void readStringUntilWhitespace(String & s, ReadBuffer & buf) +{ + s.clear(); + readStringUntilWhitespaceInto(s, buf); +} + +template void readNullTerminated<PODArray<char>>(PODArray<char> & s, ReadBuffer & buf); +template void readNullTerminated<String>(String & s, ReadBuffer & buf); + +void readString(String & s, ReadBuffer & buf) +{ + s.clear(); + readStringInto(s, buf); +} + +template void readStringInto<PaddedPODArray<UInt8>>(PaddedPODArray<UInt8> & s, ReadBuffer & buf); +template void readStringInto<String>(String & s, ReadBuffer & buf); +template void readStringInto<NullOutput>(NullOutput & s, ReadBuffer & buf); + +template <typename Vector> +void readStringUntilEOFInto(Vector & s, ReadBuffer & buf) +{ + while (!buf.eof()) + { + appendToStringOrVector(s, buf, buf.buffer().end()); + buf.position() = buf.buffer().end(); + } +} + + +void readStringUntilEOF(String & s, ReadBuffer & buf) +{ + s.clear(); + readStringUntilEOFInto(s, buf); +} + +template <typename Vector> +void readEscapedStringUntilEOLInto(Vector & s, ReadBuffer & buf) +{ + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\n', '\\'>(buf.position(), buf.buffer().end()); + + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == '\n') + return; + + if (*buf.position() == '\\') + parseComplexEscapeSequence(s, buf); + } +} + + +void readEscapedStringUntilEOL(String & s, ReadBuffer & buf) +{ + s.clear(); + readEscapedStringUntilEOLInto(s, buf); +} + +template void readStringUntilEOFInto<PaddedPODArray<UInt8>>(PaddedPODArray<UInt8> & s, ReadBuffer & buf); + + +/** Parse the escape sequence, which can be simple (one character after backslash) or more complex (multiple characters). 
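+ * A few illustrative cases (not from the original comment): "\t" appends a tab, "\xAB" appends the byte 0xAB,
+ * "\N" appends nothing (it denotes NULL), and "\%" keeps the backslash, appending "\%".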
+ * It is assumed that the cursor is located on the `\` symbol + */ +template <typename Vector, typename ReturnType = void> +static ReturnType parseComplexEscapeSequence(Vector & s, ReadBuffer & buf) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + auto error = [](const char * message [[maybe_unused]], int code [[maybe_unused]]) + { + if constexpr (throw_exception) + throw Exception::createDeprecated(message, code); + return ReturnType(false); + }; + + ++buf.position(); + + if (buf.eof()) + { + return error("Cannot parse escape sequence", ErrorCodes::CANNOT_PARSE_ESCAPE_SEQUENCE); + } + + char char_after_backslash = *buf.position(); + + if (char_after_backslash == 'x') + { + ++buf.position(); + /// escape sequence of the form \xAA + char hex_code[2]; + + auto bytes_read = buf.read(hex_code, sizeof(hex_code)); + + if (bytes_read != sizeof(hex_code)) + { + return error("Cannot parse escape sequence", ErrorCodes::CANNOT_PARSE_ESCAPE_SEQUENCE); + } + + s.push_back(unhex2(hex_code)); + } + else if (char_after_backslash == 'N') + { + /// Support for NULLs: \N sequence must be parsed as empty string. + ++buf.position(); + } + else + { + /// The usual escape sequence of a single character. + char decoded_char = parseEscapeSequence(char_after_backslash); + + /// For convenience using LIKE and regular expressions, + /// we leave backslash when user write something like 'Hello 100\%': + /// it is parsed like Hello 100\% instead of Hello 100% + if (decoded_char != '\\' + && decoded_char != '\'' + && decoded_char != '"' + && decoded_char != '`' /// MySQL style identifiers + && decoded_char != '/' /// JavaScript in HTML + && decoded_char != '=' /// TSKV format invented somewhere + && !isControlASCII(decoded_char)) + { + s.push_back('\\'); + } + + s.push_back(decoded_char); + ++buf.position(); + } + + return ReturnType(true); +} + +bool parseComplexEscapeSequence(String & s, ReadBuffer & buf) +{ + return parseComplexEscapeSequence<String, bool>(s, buf); +} + +template <typename Vector, typename ReturnType> +static ReturnType parseJSONEscapeSequence(Vector & s, ReadBuffer & buf) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + auto error = [](const char * message [[maybe_unused]], int code [[maybe_unused]]) + { + if constexpr (throw_exception) + throw Exception::createDeprecated(message, code); + return ReturnType(false); + }; + + ++buf.position(); + + if (buf.eof()) + return error("Cannot parse escape sequence", ErrorCodes::CANNOT_PARSE_ESCAPE_SEQUENCE); + + assert(buf.hasPendingData()); + + switch (*buf.position()) + { + case '"': + s.push_back('"'); + break; + case '\\': + s.push_back('\\'); + break; + case '/': + s.push_back('/'); + break; + case 'b': + s.push_back('\b'); + break; + case 'f': + s.push_back('\f'); + break; + case 'n': + s.push_back('\n'); + break; + case 'r': + s.push_back('\r'); + break; + case 't': + s.push_back('\t'); + break; + case 'u': + { + ++buf.position(); + + char hex_code[4]; + if (4 != buf.read(hex_code, 4)) + return error("Cannot parse escape sequence: less than four bytes after \\u", ErrorCodes::CANNOT_PARSE_ESCAPE_SEQUENCE); + + /// \u0000 - special case + if (0 == memcmp(hex_code, "0000", 4)) + { + s.push_back(0); + return ReturnType(true); + } + + UInt16 code_point = unhex4(hex_code); + + if (code_point <= 0x7F) + { + s.push_back(code_point); + } + else if (code_point <= 0x07FF) + { + s.push_back(((code_point >> 6) & 0x1F) | 0xC0); + s.push_back((code_point & 0x3F) | 0x80); + } + else + { + /// 
Surrogate pair. + if (code_point >= 0xD800 && code_point <= 0xDBFF) + { + if (!checkString("\\u", buf)) + return error("Cannot parse escape sequence: missing second part of surrogate pair", ErrorCodes::CANNOT_PARSE_ESCAPE_SEQUENCE); + + char second_hex_code[4]; + if (4 != buf.read(second_hex_code, 4)) + return error("Cannot parse escape sequence: less than four bytes after \\u of second part of surrogate pair", + ErrorCodes::CANNOT_PARSE_ESCAPE_SEQUENCE); + + UInt16 second_code_point = unhex4(second_hex_code); + + if (second_code_point >= 0xDC00 && second_code_point <= 0xDFFF) + { + UInt32 full_code_point = 0x10000 + (code_point - 0xD800) * 1024 + (second_code_point - 0xDC00); + + s.push_back(((full_code_point >> 18) & 0x07) | 0xF0); + s.push_back(((full_code_point >> 12) & 0x3F) | 0x80); + s.push_back(((full_code_point >> 6) & 0x3F) | 0x80); + s.push_back((full_code_point & 0x3F) | 0x80); + } + else + return error("Incorrect surrogate pair of unicode escape sequences in JSON", ErrorCodes::CANNOT_PARSE_ESCAPE_SEQUENCE); + } + else + { + s.push_back(((code_point >> 12) & 0x0F) | 0xE0); + s.push_back(((code_point >> 6) & 0x3F) | 0x80); + s.push_back((code_point & 0x3F) | 0x80); + } + } + + return ReturnType(true); + } + default: + s.push_back(*buf.position()); + break; + } + + ++buf.position(); + return ReturnType(true); +} + + +template <typename Vector, bool parse_complex_escape_sequence> +void readEscapedStringIntoImpl(Vector & s, ReadBuffer & buf) +{ + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\t', '\n', '\\'>(buf.position(), buf.buffer().end()); + + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == '\t' || *buf.position() == '\n') + return; + + if (*buf.position() == '\\') + { + if constexpr (parse_complex_escape_sequence) + { + parseComplexEscapeSequence(s, buf); + } + else + { + s.push_back(*buf.position()); + ++buf.position(); + if (!buf.eof()) + { + s.push_back(*buf.position()); + ++buf.position(); + } + } + } + } +} + +template <typename Vector> +void readEscapedStringInto(Vector & s, ReadBuffer & buf) +{ + readEscapedStringIntoImpl<Vector, true>(s, buf); +} + + +void readEscapedString(String & s, ReadBuffer & buf) +{ + s.clear(); + readEscapedStringInto(s, buf); +} + +template void readEscapedStringInto<PaddedPODArray<UInt8>>(PaddedPODArray<UInt8> & s, ReadBuffer & buf); +template void readEscapedStringInto<NullOutput>(NullOutput & s, ReadBuffer & buf); + + +/** If enable_sql_style_quoting == true, + * strings like 'abc''def' will be parsed as abc'def. + * Please note, that even with SQL style quoting enabled, + * backslash escape sequences are also parsed, + * that could be slightly confusing. + */ +template <char quote, bool enable_sql_style_quoting, typename Vector, typename ReturnType = void> +static ReturnType readAnyQuotedStringInto(Vector & s, ReadBuffer & buf) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + if (buf.eof() || *buf.position() != quote) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_QUOTED_STRING, + "Cannot parse quoted string: expected opening quote '{}', got '{}'", + std::string{quote}, buf.eof() ? 
"EOF" : std::string{*buf.position()}); + else + return ReturnType(false); + } + + ++buf.position(); + + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\\', quote>(buf.position(), buf.buffer().end()); + + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == quote) + { + ++buf.position(); + + if (enable_sql_style_quoting && !buf.eof() && *buf.position() == quote) + { + s.push_back(quote); + ++buf.position(); + continue; + } + + return ReturnType(true); + } + + if (*buf.position() == '\\') + { + if constexpr (throw_exception) + parseComplexEscapeSequence<Vector, ReturnType>(s, buf); + else + { + if (!parseComplexEscapeSequence<Vector, ReturnType>(s, buf)) + return ReturnType(false); + } + } + } + + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_QUOTED_STRING, "Cannot parse quoted string: expected closing quote"); + else + return ReturnType(false); +} + +template <bool enable_sql_style_quoting, typename Vector> +void readQuotedStringInto(Vector & s, ReadBuffer & buf) +{ + readAnyQuotedStringInto<'\'', enable_sql_style_quoting>(s, buf); +} + +template <typename Vector> +bool tryReadQuotedStringInto(Vector & s, ReadBuffer & buf) +{ + return readAnyQuotedStringInto<'\'', false, Vector, bool>(s, buf); +} + +template bool tryReadQuotedStringInto(String & s, ReadBuffer & buf); + +template <bool enable_sql_style_quoting, typename Vector> +void readDoubleQuotedStringInto(Vector & s, ReadBuffer & buf) +{ + readAnyQuotedStringInto<'"', enable_sql_style_quoting>(s, buf); +} + +template <bool enable_sql_style_quoting, typename Vector> +void readBackQuotedStringInto(Vector & s, ReadBuffer & buf) +{ + readAnyQuotedStringInto<'`', enable_sql_style_quoting>(s, buf); +} + + +void readQuotedString(String & s, ReadBuffer & buf) +{ + s.clear(); + readQuotedStringInto<false>(s, buf); +} + +void readQuotedStringWithSQLStyle(String & s, ReadBuffer & buf) +{ + s.clear(); + readQuotedStringInto<true>(s, buf); +} + + +template void readQuotedStringInto<true>(PaddedPODArray<UInt8> & s, ReadBuffer & buf); +template void readQuotedStringInto<true>(String & s, ReadBuffer & buf); +template void readQuotedStringInto<false>(String & s, ReadBuffer & buf); +template void readDoubleQuotedStringInto<false>(NullOutput & s, ReadBuffer & buf); +template void readDoubleQuotedStringInto<false>(String & s, ReadBuffer & buf); +template void readBackQuotedStringInto<false>(String & s, ReadBuffer & buf); + +void readDoubleQuotedString(String & s, ReadBuffer & buf) +{ + s.clear(); + readDoubleQuotedStringInto<false>(s, buf); +} + +void readDoubleQuotedStringWithSQLStyle(String & s, ReadBuffer & buf) +{ + s.clear(); + readDoubleQuotedStringInto<true>(s, buf); +} + +void readBackQuotedString(String & s, ReadBuffer & buf) +{ + s.clear(); + readBackQuotedStringInto<false>(s, buf); +} + +void readBackQuotedStringWithSQLStyle(String & s, ReadBuffer & buf) +{ + s.clear(); + readBackQuotedStringInto<true>(s, buf); +} + +template<typename T> +concept WithResize = requires (T value) +{ + { value.resize(1) }; + { value.size() } -> std::integral<>; +}; + +template <typename Vector, bool include_quotes> +void readCSVStringInto(Vector & s, ReadBuffer & buf, const FormatSettings::CSV & settings) +{ + /// Empty string + if (buf.eof()) + return; + + const char delimiter = settings.delimiter; + const char maybe_quote = *buf.position(); + const String & custom_delimiter = settings.custom_delimiter; + + /// Emptiness and 
not even in quotation marks. + if (custom_delimiter.empty() && maybe_quote == delimiter) + return; + + if ((settings.allow_single_quotes && maybe_quote == '\'') || (settings.allow_double_quotes && maybe_quote == '"')) + { + if constexpr (include_quotes) + s.push_back(maybe_quote); + + ++buf.position(); + + /// The quoted case. We are looking for the next quotation mark. + while (!buf.eof()) + { + char * next_pos = reinterpret_cast<char *>(memchr(buf.position(), maybe_quote, buf.buffer().end() - buf.position())); + + if (nullptr == next_pos) + next_pos = buf.buffer().end(); + + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if constexpr (include_quotes) + s.push_back(maybe_quote); + + /// Now there is a quotation mark under the cursor. Is there any following? + ++buf.position(); + + if (buf.eof()) + return; + + if (*buf.position() == maybe_quote) + { + s.push_back(maybe_quote); + ++buf.position(); + continue; + } + + return; + } + } + else + { + /// If custom_delimiter is specified, we should read until first occurrences of + /// custom_delimiter in buffer. + if (!custom_delimiter.empty()) + { + PeekableReadBuffer * peekable_buf = dynamic_cast<PeekableReadBuffer *>(&buf); + if (!peekable_buf) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Reading CSV string with custom delimiter is allowed only when using PeekableReadBuffer"); + + while (true) + { + if (peekable_buf->eof()) + throw Exception(ErrorCodes::INCORRECT_DATA, "Unexpected EOF while reading CSV string, expected custom delimiter \"{}\"", custom_delimiter); + + char * next_pos = reinterpret_cast<char *>(memchr(peekable_buf->position(), custom_delimiter[0], peekable_buf->available())); + if (!next_pos) + next_pos = peekable_buf->buffer().end(); + + appendToStringOrVector(s, *peekable_buf, next_pos); + peekable_buf->position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + { + PeekableReadBufferCheckpoint checkpoint{*peekable_buf, true}; + if (checkString(custom_delimiter, *peekable_buf)) + return; + } + + s.push_back(*peekable_buf->position()); + ++peekable_buf->position(); + } + + return; + } + + /// Unquoted case. Look for delimiter or \r or \n. 
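+ /// Illustrative example (not from the original comments): with delimiter ',' and input "abc,def\n",
+ /// the scan below copies "abc" into s and leaves the buffer positioned at the ',', so the caller
+ /// can consume the delimiter or line break itself.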
+ while (!buf.eof()) + { + char * next_pos = buf.position(); + + [&]() + { +#ifdef __SSE2__ + auto rc = _mm_set1_epi8('\r'); + auto nc = _mm_set1_epi8('\n'); + auto dc = _mm_set1_epi8(delimiter); + for (; next_pos + 15 < buf.buffer().end(); next_pos += 16) + { + __m128i bytes = _mm_loadu_si128(reinterpret_cast<const __m128i *>(next_pos)); + auto eq = _mm_or_si128(_mm_or_si128(_mm_cmpeq_epi8(bytes, rc), _mm_cmpeq_epi8(bytes, nc)), _mm_cmpeq_epi8(bytes, dc)); + uint16_t bit_mask = _mm_movemask_epi8(eq); + if (bit_mask) + { + next_pos += std::countr_zero(bit_mask); + return; + } + } +#elif defined(__aarch64__) && defined(__ARM_NEON) + auto rc = vdupq_n_u8('\r'); + auto nc = vdupq_n_u8('\n'); + auto dc = vdupq_n_u8(delimiter); + for (; next_pos + 15 < buf.buffer().end(); next_pos += 16) + { + uint8x16_t bytes = vld1q_u8(reinterpret_cast<const uint8_t *>(next_pos)); + auto eq = vorrq_u8(vorrq_u8(vceqq_u8(bytes, rc), vceqq_u8(bytes, nc)), vceqq_u8(bytes, dc)); + uint64_t bit_mask = getNibbleMask(eq); + if (bit_mask) + { + next_pos += std::countr_zero(bit_mask) >> 2; + return; + } + } +#endif + while (next_pos < buf.buffer().end() + && *next_pos != delimiter && *next_pos != '\r' && *next_pos != '\n') + ++next_pos; + }(); + + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if constexpr (WithResize<Vector>) + { + if (settings.trim_whitespaces) [[likely]] + { + /** CSV format can contain insignificant spaces and tabs. + * Usually the task of skipping them is for the calling code. + * But in this case, it will be difficult to do this, so remove the trailing whitespace by ourself. + */ + size_t size = s.size(); + while (size > 0 && (s[size - 1] == ' ' || s[size - 1] == '\t')) + --size; + + s.resize(size); + } + } + return; + } + } +} + +void readCSVString(String & s, ReadBuffer & buf, const FormatSettings::CSV & settings) +{ + s.clear(); + readCSVStringInto(s, buf, settings); +} + +void readCSVField(String & s, ReadBuffer & buf, const FormatSettings::CSV & settings) +{ + s.clear(); + readCSVStringInto<String, true>(s, buf, settings); +} + +void readCSVWithTwoPossibleDelimitersImpl(String & s, PeekableReadBuffer & buf, const String & first_delimiter, const String & second_delimiter) +{ + /// Check that delimiters are not empty. 
+ if (first_delimiter.empty() || second_delimiter.empty()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Cannot read CSV field with two possible delimiters, one " + "of delimiters '{}' and '{}' is empty", first_delimiter, second_delimiter); + + /// Read all data until first_delimiter or second_delimiter + while (true) + { + if (buf.eof()) + throw Exception(ErrorCodes::INCORRECT_DATA, R"(Unexpected EOF while reading CSV string, expected one of delimiters "{}" or "{}")", first_delimiter, second_delimiter); + + char * next_pos = buf.position(); + while (next_pos != buf.buffer().end() && *next_pos != first_delimiter[0] && *next_pos != second_delimiter[0]) + ++next_pos; + + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == first_delimiter[0]) + { + PeekableReadBufferCheckpoint checkpoint(buf, true); + if (checkString(first_delimiter, buf)) + return; + } + + if (*buf.position() == second_delimiter[0]) + { + PeekableReadBufferCheckpoint checkpoint(buf, true); + if (checkString(second_delimiter, buf)) + return; + } + + s.push_back(*buf.position()); + ++buf.position(); + } +} + +String readCSVStringWithTwoPossibleDelimiters(PeekableReadBuffer & buf, const FormatSettings::CSV & settings, const String & first_delimiter, const String & second_delimiter) +{ + String res; + + /// If value is quoted, use regular CSV reading since we need to read only data inside quotes. + if (!buf.eof() && ((settings.allow_single_quotes && *buf.position() == '\'') || (settings.allow_double_quotes && *buf.position() == '"'))) + readCSVStringInto(res, buf, settings); + else + readCSVWithTwoPossibleDelimitersImpl(res, buf, first_delimiter, second_delimiter); + + return res; +} + +String readCSVFieldWithTwoPossibleDelimiters(PeekableReadBuffer & buf, const FormatSettings::CSV & settings, const String & first_delimiter, const String & second_delimiter) +{ + String res; + + /// If value is quoted, use regular CSV reading since we need to read only data inside quotes. 
+ if (!buf.eof() && ((settings.allow_single_quotes && *buf.position() == '\'') || (settings.allow_double_quotes && *buf.position() == '"'))) + readCSVField(res, buf, settings); + else + readCSVWithTwoPossibleDelimitersImpl(res, buf, first_delimiter, second_delimiter); + + return res; +} + +template void readCSVStringInto<PaddedPODArray<UInt8>>(PaddedPODArray<UInt8> & s, ReadBuffer & buf, const FormatSettings::CSV & settings); +template void readCSVStringInto<NullOutput>(NullOutput & s, ReadBuffer & buf, const FormatSettings::CSV & settings); + + +template <typename Vector, typename ReturnType> +ReturnType readJSONStringInto(Vector & s, ReadBuffer & buf) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + auto error = [](FormatStringHelper<> message [[maybe_unused]], int code [[maybe_unused]]) + { + if constexpr (throw_exception) + throw ParsingException(code, std::move(message)); + return ReturnType(false); + }; + + if (buf.eof() || *buf.position() != '"') + return error("Cannot parse JSON string: expected opening quote", ErrorCodes::CANNOT_PARSE_QUOTED_STRING); + ++buf.position(); + + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\\', '"'>(buf.position(), buf.buffer().end()); + + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == '"') + { + ++buf.position(); + return ReturnType(true); + } + + if (*buf.position() == '\\') + parseJSONEscapeSequence<Vector, ReturnType>(s, buf); + } + + return error("Cannot parse JSON string: expected closing quote", ErrorCodes::CANNOT_PARSE_QUOTED_STRING); +} + +void readJSONString(String & s, ReadBuffer & buf) +{ + s.clear(); + readJSONStringInto(s, buf); +} + +template void readJSONStringInto<PaddedPODArray<UInt8>, void>(PaddedPODArray<UInt8> & s, ReadBuffer & buf); +template bool readJSONStringInto<PaddedPODArray<UInt8>, bool>(PaddedPODArray<UInt8> & s, ReadBuffer & buf); +template void readJSONStringInto<NullOutput>(NullOutput & s, ReadBuffer & buf); +template void readJSONStringInto<String>(String & s, ReadBuffer & buf); +template bool readJSONStringInto<String, bool>(String & s, ReadBuffer & buf); + +template <typename Vector, typename ReturnType> +ReturnType readJSONObjectPossiblyInvalid(Vector & s, ReadBuffer & buf) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + auto error = [](FormatStringHelper<> message [[maybe_unused]], int code [[maybe_unused]]) + { + if constexpr (throw_exception) + throw ParsingException(code, std::move(message)); + return ReturnType(false); + }; + + if (buf.eof() || *buf.position() != '{') + return error("JSON should start from opening curly bracket", ErrorCodes::INCORRECT_DATA); + + s.push_back(*buf.position()); + ++buf.position(); + + Int64 balance = 1; + bool quotes = false; + + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\\', '{', '}', '"'>(buf.position(), buf.buffer().end()); + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + s.push_back(*buf.position()); + + if (*buf.position() == '\\') + { + ++buf.position(); + if (!buf.eof()) + { + s.push_back(*buf.position()); + ++buf.position(); + } + + continue; + } + + if (*buf.position() == '"') + quotes = !quotes; + else if (!quotes) // can be only '{' or '}' + balance += *buf.position() == '{' ? 
1 : -1; + + ++buf.position(); + + if (balance == 0) + return ReturnType(true); + + if (balance < 0) + break; + } + + return error("JSON should have equal number of opening and closing brackets", ErrorCodes::INCORRECT_DATA); +} + +template void readJSONObjectPossiblyInvalid<String>(String & s, ReadBuffer & buf); + +template <typename ReturnType> +ReturnType readDateTextFallback(LocalDate & date, ReadBuffer & buf) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + auto error = [] + { + if constexpr (throw_exception) + throw Exception(ErrorCodes::CANNOT_PARSE_DATE, "Cannot parse date: value is too short"); + return ReturnType(false); + }; + + auto append_digit = [&](auto & x) + { + if (!buf.eof() && isNumericASCII(*buf.position())) + { + x = x * 10 + (*buf.position() - '0'); + ++buf.position(); + return true; + } + else + return false; + }; + + UInt16 year = 0; + UInt8 month = 0; + UInt8 day = 0; + + if (!append_digit(year) + || !append_digit(year) // NOLINT + || !append_digit(year) // NOLINT + || !append_digit(year)) // NOLINT + return error(); + + if (buf.eof()) + return error(); + + if (isNumericASCII(*buf.position())) + { + /// YYYYMMDD + if (!append_digit(month) + || !append_digit(month) // NOLINT + || !append_digit(day) + || !append_digit(day)) // NOLINT + return error(); + } + else + { + ++buf.position(); + + if (!append_digit(month)) + return error(); + append_digit(month); + + if (!buf.eof() && !isNumericASCII(*buf.position())) + ++buf.position(); + else + return error(); + + if (!append_digit(day)) + return error(); + append_digit(day); + } + + date = LocalDate(year, month, day); + return ReturnType(true); +} + +template void readDateTextFallback<void>(LocalDate &, ReadBuffer &); +template bool readDateTextFallback<bool>(LocalDate &, ReadBuffer &); + + +template <typename ReturnType> +ReturnType readDateTimeTextFallback(time_t & datetime, ReadBuffer & buf, const DateLUTImpl & date_lut) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + /// YYYY-MM-DD + static constexpr auto date_broken_down_length = 10; + /// hh:mm:ss + static constexpr auto time_broken_down_length = 8; + /// YYYY-MM-DD hh:mm:ss + static constexpr auto date_time_broken_down_length = date_broken_down_length + 1 + time_broken_down_length; + + char s[date_time_broken_down_length]; + char * s_pos = s; + + /** Read characters, that could represent unix timestamp. + * Only unix timestamp of at least 5 characters is supported. + * Then look at 5th character. If it is a number - treat whole as unix timestamp. + * If it is not a number - then parse datetime in YYYY-MM-DD hh:mm:ss or YYYY-MM-DD format. + */ + + /// A piece similar to unix timestamp, maybe scaled to subsecond precision. 
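+ /// Illustrative example (not from the original comments): "1420070400" is treated as a unix timestamp,
+ /// while "2015-01-01 01:02:03" and "2015-01-01" take the broken-down branch below; the choice depends
+ /// on whether exactly four leading digits are followed by a non-digit character.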
+ while (s_pos < s + date_time_broken_down_length && !buf.eof() && isNumericASCII(*buf.position())) + { + *s_pos = *buf.position(); + ++s_pos; + ++buf.position(); + } + + /// 2015-01-01 01:02:03 or 2015-01-01 + if (s_pos == s + 4 && !buf.eof() && !isNumericASCII(*buf.position())) + { + const auto already_read_length = s_pos - s; + const size_t remaining_date_size = date_broken_down_length - already_read_length; + + size_t size = buf.read(s_pos, remaining_date_size); + if (size != remaining_date_size) + { + s_pos[size] = 0; + + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot parse DateTime {}", s); + else + return false; + } + + UInt16 year = (s[0] - '0') * 1000 + (s[1] - '0') * 100 + (s[2] - '0') * 10 + (s[3] - '0'); + UInt8 month = (s[5] - '0') * 10 + (s[6] - '0'); + UInt8 day = (s[8] - '0') * 10 + (s[9] - '0'); + + UInt8 hour = 0; + UInt8 minute = 0; + UInt8 second = 0; + + if (!buf.eof() && (*buf.position() == ' ' || *buf.position() == 'T')) + { + ++buf.position(); + size = buf.read(s, time_broken_down_length); + + if (size != time_broken_down_length) + { + s_pos[size] = 0; + + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot parse time component of DateTime {}", s); + else + return false; + } + + hour = (s[0] - '0') * 10 + (s[1] - '0'); + minute = (s[3] - '0') * 10 + (s[4] - '0'); + second = (s[6] - '0') * 10 + (s[7] - '0'); + } + + if (unlikely(year == 0)) + datetime = 0; + else + datetime = date_lut.makeDateTime(year, month, day, hour, minute, second); + } + else + { + if (s_pos - s >= 5) + { + /// Not very efficient. + datetime = 0; + for (const char * digit_pos = s; digit_pos < s_pos; ++digit_pos) + datetime = datetime * 10 + *digit_pos - '0'; + } + else + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot parse datetime"); + else + return false; + } + } + + return ReturnType(true); +} + +template void readDateTimeTextFallback<void>(time_t &, ReadBuffer &, const DateLUTImpl &); +template bool readDateTimeTextFallback<bool>(time_t &, ReadBuffer &, const DateLUTImpl &); + + +void skipJSONField(ReadBuffer & buf, StringRef name_of_field) +{ + if (buf.eof()) + throw Exception(ErrorCodes::INCORRECT_DATA, "Unexpected EOF for key '{}'", name_of_field.toString()); + else if (*buf.position() == '"') /// skip double-quoted string + { + NullOutput sink; + readJSONStringInto(sink, buf); + } + else if (isNumericASCII(*buf.position()) || *buf.position() == '-' || *buf.position() == '+' || *buf.position() == '.') /// skip number + { + if (*buf.position() == '+') + ++buf.position(); + + double v; + if (!tryReadFloatText(v, buf)) + throw Exception(ErrorCodes::INCORRECT_DATA, "Expected a number field for key '{}'", name_of_field.toString()); + } + else if (*buf.position() == 'n') /// skip null + { + assertString("null", buf); + } + else if (*buf.position() == 't') /// skip true + { + assertString("true", buf); + } + else if (*buf.position() == 'f') /// skip false + { + assertString("false", buf); + } + else if (*buf.position() == '[') + { + ++buf.position(); + skipWhitespaceIfAny(buf); + + if (!buf.eof() && *buf.position() == ']') /// skip empty array + { + ++buf.position(); + return; + } + + while (true) + { + skipJSONField(buf, name_of_field); + skipWhitespaceIfAny(buf); + + if (!buf.eof() && *buf.position() == ',') + { + ++buf.position(); + skipWhitespaceIfAny(buf); + } + else if (!buf.eof() && *buf.position() == ']') + { + ++buf.position(); + 
break; + } + else + throw Exception(ErrorCodes::INCORRECT_DATA, "Unexpected symbol for key '{}'", name_of_field.toString()); + } + } + else if (*buf.position() == '{') /// skip whole object + { + ++buf.position(); + skipWhitespaceIfAny(buf); + + while (!buf.eof() && *buf.position() != '}') + { + // field name + if (*buf.position() == '"') + { + NullOutput sink; + readJSONStringInto(sink, buf); + } + else + throw Exception(ErrorCodes::INCORRECT_DATA, "Unexpected symbol for key '{}'", name_of_field.toString()); + + // ':' + skipWhitespaceIfAny(buf); + if (buf.eof() || !(*buf.position() == ':')) + throw Exception(ErrorCodes::INCORRECT_DATA, "Unexpected symbol for key '{}'", name_of_field.toString()); + ++buf.position(); + skipWhitespaceIfAny(buf); + + skipJSONField(buf, name_of_field); + skipWhitespaceIfAny(buf); + + // optional ',' + if (!buf.eof() && *buf.position() == ',') + { + ++buf.position(); + skipWhitespaceIfAny(buf); + } + } + + if (buf.eof()) + throw Exception(ErrorCodes::INCORRECT_DATA, "Unexpected EOF for key '{}'", name_of_field.toString()); + ++buf.position(); + } + else + { + throw Exception(ErrorCodes::INCORRECT_DATA, "Unexpected symbol '{}' for key '{}'", + std::string(1, *buf.position()), name_of_field.toString()); + } +} + + +Exception readException(ReadBuffer & buf, const String & additional_message, bool remote_exception) +{ + int code = 0; + String name; + String message; + String stack_trace; + bool has_nested = false; /// Obsolete + + readBinaryLittleEndian(code, buf); + readBinary(name, buf); + readBinary(message, buf); + readBinary(stack_trace, buf); + readBinary(has_nested, buf); + + WriteBufferFromOwnString out; + + if (!additional_message.empty()) + out << additional_message << ". "; + + if (name != "DB::Exception") + out << name << ". "; + + out << message << "."; + + if (!stack_trace.empty()) + out << " Stack trace:\n\n" << stack_trace; + + return Exception::createDeprecated(out.str(), code, remote_exception); +} + +void readAndThrowException(ReadBuffer & buf, const String & additional_message) +{ + readException(buf, additional_message).rethrow(); +} + + +void skipToCarriageReturnOrEOF(ReadBuffer & buf) +{ + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\r'>(buf.position(), buf.buffer().end()); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == '\r') + { + ++buf.position(); + return; + } + } +} + + +void skipToNextLineOrEOF(ReadBuffer & buf) +{ + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\n'>(buf.position(), buf.buffer().end()); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == '\n') + { + ++buf.position(); + return; + } + } +} + + +void skipToUnescapedNextLineOrEOF(ReadBuffer & buf) +{ + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\n', '\\'>(buf.position(), buf.buffer().end()); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == '\n') + { + ++buf.position(); + return; + } + + if (*buf.position() == '\\') + { + ++buf.position(); + if (buf.eof()) + return; + + /// Skip escaped character. We do not consider escape sequences with more than one character after backslash (\x01). + /// It's ok for the purpose of this function, because we are interested only in \n and \\. 
+ ++buf.position(); + continue; + } + } +} + +void skipNullTerminated(ReadBuffer & buf) +{ + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\0'>(buf.position(), buf.buffer().end()); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == '\0') + { + ++buf.position(); + return; + } + } +} + + +void saveUpToPosition(ReadBuffer & in, Memory<> & memory, char * current) +{ + assert(current >= in.position()); + assert(current <= in.buffer().end()); + + const size_t old_bytes = memory.size(); + const size_t additional_bytes = current - in.position(); + const size_t new_bytes = old_bytes + additional_bytes; + + /// There are no new bytes to add to memory. + /// No need to do extra stuff. + if (new_bytes == 0) + return; + + assert(in.position() + additional_bytes <= in.buffer().end()); + memory.resize(new_bytes); + memcpy(memory.data() + old_bytes, in.position(), additional_bytes); + in.position() = current; +} + +bool loadAtPosition(ReadBuffer & in, Memory<> & memory, char * & current) +{ + assert(current <= in.buffer().end()); + + if (current < in.buffer().end()) + return true; + + saveUpToPosition(in, memory, current); + + bool loaded_more = !in.eof(); + // A sanity check. Buffer position may be in the beginning of the buffer + // (normal case), or have some offset from it (AIO). + assert(in.position() >= in.buffer().begin()); + assert(in.position() <= in.buffer().end()); + current = in.position(); + + return loaded_more; +} + +/// Searches for delimiter in input stream and sets buffer position after delimiter (if found) or EOF (if not) +static void findAndSkipNextDelimiter(PeekableReadBuffer & buf, const String & delimiter) +{ + if (delimiter.empty()) + return; + + while (!buf.eof()) + { + void * pos = memchr(buf.position(), delimiter[0], buf.available()); + if (!pos) + { + buf.position() += buf.available(); + continue; + } + + buf.position() = static_cast<ReadBuffer::Position>(pos); + + PeekableReadBufferCheckpoint checkpoint{buf}; + if (checkString(delimiter, buf)) + return; + + buf.rollbackToCheckpoint(); + ++buf.position(); + } +} + +void skipToNextRowOrEof(PeekableReadBuffer & buf, const String & row_after_delimiter, const String & row_between_delimiter, bool skip_spaces) +{ + if (row_after_delimiter.empty()) + { + findAndSkipNextDelimiter(buf, row_between_delimiter); + return; + } + + while (true) + { + findAndSkipNextDelimiter(buf, row_after_delimiter); + + if (skip_spaces) + skipWhitespaceIfAny(buf); + + if (checkString(row_between_delimiter, buf)) + break; + } +} + +// Use PeekableReadBuffer to copy field to string after parsing. 
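+// Illustrative sketch of the mechanism (not from the original comments): a checkpoint is set, parse_func
+// consumes the value, and the consumed bytes are appended to s verbatim. For example, readJSONField(s, buf)
+// on the input {"a":1} leaves s == {"a":1} without re-serializing it.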
+template <typename Vector, typename ParseFunc> +static void readParsedValueInto(Vector & s, ReadBuffer & buf, ParseFunc parse_func) +{ + PeekableReadBuffer peekable_buf(buf); + peekable_buf.setCheckpoint(); + parse_func(peekable_buf); + peekable_buf.makeContinuousMemoryFromCheckpointToPos(); + auto * end = peekable_buf.position(); + peekable_buf.rollbackToCheckpoint(); + s.append(peekable_buf.position(), end); + peekable_buf.position() = end; +} + +template <typename Vector> +static void readQuotedStringFieldInto(Vector & s, ReadBuffer & buf) +{ + assertChar('\'', buf); + s.push_back('\''); + while (!buf.eof()) + { + char * next_pos = find_first_symbols<'\\', '\''>(buf.position(), buf.buffer().end()); + + s.append(buf.position(), next_pos); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == '\'') + break; + + s.push_back(*buf.position()); + if (*buf.position() == '\\') + { + ++buf.position(); + if (!buf.eof()) + { + s.push_back(*buf.position()); + ++buf.position(); + } + } + } + + if (buf.eof()) + return; + + ++buf.position(); + s.push_back('\''); +} + +template <char opening_bracket, char closing_bracket, typename Vector> +static void readQuotedFieldInBracketsInto(Vector & s, ReadBuffer & buf) +{ + assertChar(opening_bracket, buf); + s.push_back(opening_bracket); + + size_t balance = 1; + + while (!buf.eof() && balance) + { + char * next_pos = find_first_symbols<'\'', opening_bracket, closing_bracket>(buf.position(), buf.buffer().end()); + appendToStringOrVector(s, buf, next_pos); + buf.position() = next_pos; + + if (!buf.hasPendingData()) + continue; + + if (*buf.position() == '\'') + { + readQuotedStringFieldInto(s, buf); + } + else if (*buf.position() == opening_bracket) + { + s.push_back(opening_bracket); + ++balance; + ++buf.position(); + } + else if (*buf.position() == closing_bracket) + { + s.push_back(closing_bracket); + --balance; + ++buf.position(); + } + } +} + +template <typename Vector> +void readQuotedFieldInto(Vector & s, ReadBuffer & buf) +{ + if (buf.eof()) + return; + + /// Possible values in 'Quoted' field: + /// - Strings: '...' + /// - Arrays: [...] + /// - Tuples: (...) + /// - Maps: {...} + /// - NULL + /// - Bool: true/false + /// - Number: integer, float, decimal. + + if (*buf.position() == '\'') + readQuotedStringFieldInto(s, buf); + else if (*buf.position() == '[') + readQuotedFieldInBracketsInto<'[', ']'>(s, buf); + else if (*buf.position() == '(') + readQuotedFieldInBracketsInto<'(', ')'>(s, buf); + else if (*buf.position() == '{') + readQuotedFieldInBracketsInto<'{', '}'>(s, buf); + else if (checkCharCaseInsensitive('n', buf)) + { + /// NULL or NaN + if (checkCharCaseInsensitive('u', buf)) + { + assertStringCaseInsensitive("ll", buf); + s.append("NULL"); + } + else + { + assertStringCaseInsensitive("an", buf); + s.append("NaN"); + } + } + else if (checkCharCaseInsensitive('t', buf)) + { + assertStringCaseInsensitive("rue", buf); + s.append("true"); + } + else if (checkCharCaseInsensitive('f', buf)) + { + assertStringCaseInsensitive("alse", buf); + s.append("false"); + } + else + { + /// It's an integer, float or decimal. They all can be parsed as float. 
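+ /// Illustrative example (not from the original comments): inputs such as 42, -3.5 or 1e10 are
+ /// validated by readFloatText below while the original characters are copied into s unchanged.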
+ auto parse_func = [](ReadBuffer & in) + { + Float64 tmp; + readFloatText(tmp, in); + }; + readParsedValueInto(s, buf, parse_func); + } +} + +template void readQuotedFieldInto<NullOutput>(NullOutput & s, ReadBuffer & buf); + +void readQuotedField(String & s, ReadBuffer & buf) +{ + s.clear(); + readQuotedFieldInto(s, buf); +} + +void readJSONField(String & s, ReadBuffer & buf) +{ + s.clear(); + auto parse_func = [](ReadBuffer & in) { skipJSONField(in, "json_field"); }; + readParsedValueInto(s, buf, parse_func); +} + +void readTSVField(String & s, ReadBuffer & buf) +{ + s.clear(); + readEscapedStringIntoImpl<String, false>(s, buf); +} + +} diff --git a/contrib/clickhouse/src/IO/ReadHelpers.h b/contrib/clickhouse/src/IO/ReadHelpers.h new file mode 100644 index 0000000000..f99c78fdf1 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadHelpers.h @@ -0,0 +1,1619 @@ +#pragma once + +#include <cmath> +#include <cstring> +#include <string> +#include <string_view> +#include <limits> +#include <algorithm> +#include <iterator> +#include <bit> +#include <span> + +#include <type_traits> + +#include <Common/StackTrace.h> +#include <Common/formatIPv6.h> +#include <Common/DateLUT.h> +#include <Common/LocalDate.h> +#include <Common/LocalDateTime.h> +#include <Common/TransformEndianness.hpp> +#include <base/StringRef.h> +#include <base/arithmeticOverflow.h> +#include <base/sort.h> +#include <base/unit.h> + +#include <Core/Types.h> +#include <Core/DecimalFunctions.h> +#include <Core/UUID.h> +#include <base/IPv4andIPv6.h> + +#include <Common/Allocator.h> +#include <Common/Exception.h> +#include <Common/StringUtils/StringUtils.h> +#include <Common/intExp.h> + +#include <Formats/FormatSettings.h> + +#include <IO/CompressionMethod.h> +#include <IO/ReadBuffer.h> +#include <IO/ReadBufferFromMemory.h> +#include <IO/PeekableReadBuffer.h> +#include <IO/VarInt.h> + +#include <double-conversion/double-conversion.h> + +static constexpr auto DEFAULT_MAX_STRING_SIZE = 1_GiB; + +namespace DB +{ + +template <typename Allocator> +struct Memory; + +namespace ErrorCodes +{ + extern const int CANNOT_PARSE_DATE; + extern const int CANNOT_PARSE_BOOL; + extern const int CANNOT_PARSE_DATETIME; + extern const int CANNOT_PARSE_UUID; + extern const int CANNOT_PARSE_IPV4; + extern const int CANNOT_PARSE_IPV6; + extern const int CANNOT_READ_ARRAY_FROM_TEXT; + extern const int CANNOT_PARSE_NUMBER; + extern const int INCORRECT_DATA; + extern const int TOO_LARGE_STRING_SIZE; + extern const int TOO_LARGE_ARRAY_SIZE; + extern const int SIZE_OF_FIXED_STRING_DOESNT_MATCH; +} + +/// Helper functions for formatted input. + +inline char parseEscapeSequence(char c) +{ + switch (c) + { + case 'a': + return '\a'; + case 'b': + return '\b'; + case 'e': + return '\x1B'; /// \e escape sequence is non standard for C and C++ but supported by gcc and clang. 
+ case 'f': + return '\f'; + case 'n': + return '\n'; + case 'r': + return '\r'; + case 't': + return '\t'; + case 'v': + return '\v'; + case '0': + return '\0'; + default: + return c; + } +} + + +/// Function throwReadAfterEOF is located in VarInt.h + + +inline void readChar(char & x, ReadBuffer & buf) +{ + if (buf.eof()) [[unlikely]] + throwReadAfterEOF(); + x = *buf.position(); + ++buf.position(); +} + + +/// Read POD-type in native format +template <typename T> +inline void readPODBinary(T & x, ReadBuffer & buf) +{ + buf.readStrict(reinterpret_cast<char *>(&x), sizeof(x)); /// NOLINT +} + +inline void readUUIDBinary(UUID & x, ReadBuffer & buf) +{ + auto & uuid = x.toUnderType(); + readPODBinary(uuid.items[0], buf); + readPODBinary(uuid.items[1], buf); +} + +template <typename T> +inline void readIntBinary(T & x, ReadBuffer & buf) +{ + readPODBinary(x, buf); +} + +template <typename T> +inline void readFloatBinary(T & x, ReadBuffer & buf) +{ + readPODBinary(x, buf); +} + +inline void readStringBinary(std::string & s, ReadBuffer & buf, size_t max_string_size = DEFAULT_MAX_STRING_SIZE) +{ + size_t size = 0; + readVarUInt(size, buf); + + if (size > max_string_size) + throw Exception(ErrorCodes::TOO_LARGE_STRING_SIZE, "Too large string size."); + + s.resize(size); + buf.readStrict(s.data(), size); +} + +/// For historical reasons we store IPv6 as a String +inline void readIPv6Binary(IPv6 & ip, ReadBuffer & buf) +{ + size_t size = 0; + readVarUInt(size, buf); + + if (size != IPV6_BINARY_LENGTH) + throw Exception(ErrorCodes::SIZE_OF_FIXED_STRING_DOESNT_MATCH, + "Size of the string {} doesn't match size of binary IPv6 {}", size, IPV6_BINARY_LENGTH); + + buf.readStrict(reinterpret_cast<char*>(&ip.toUnderType()), size); +} + +template <typename T> +void readVectorBinary(std::vector<T> & v, ReadBuffer & buf) +{ + size_t size = 0; + readVarUInt(size, buf); + + if (size > DEFAULT_MAX_STRING_SIZE) + throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, + "Too large array size (maximum: {})", DEFAULT_MAX_STRING_SIZE); + + v.resize(size); + for (size_t i = 0; i < size; ++i) + readBinary(v[i], buf); +} + + +void assertString(const char * s, ReadBuffer & buf); +void assertEOF(ReadBuffer & buf); +void assertNotEOF(ReadBuffer & buf); + +[[noreturn]] void throwAtAssertionFailed(const char * s, ReadBuffer & buf); + +inline bool checkChar(char c, ReadBuffer & buf) +{ + char a; + if (!buf.peek(a) || a != c) + return false; + buf.ignore(); + return true; +} + +inline void assertChar(char symbol, ReadBuffer & buf) +{ + if (!checkChar(symbol, buf)) + { + char err[2] = {symbol, '\0'}; + throwAtAssertionFailed(err, buf); + } +} + +inline bool checkCharCaseInsensitive(char c, ReadBuffer & buf) +{ + char a; + if (!buf.peek(a) || !equalsCaseInsensitive(a, c)) + return false; + buf.ignore(); + return true; +} + +inline void assertString(const String & s, ReadBuffer & buf) +{ + assertString(s.c_str(), buf); +} + +bool checkString(const char * s, ReadBuffer & buf); +inline bool checkString(const String & s, ReadBuffer & buf) +{ + return checkString(s.c_str(), buf); +} + +bool checkStringCaseInsensitive(const char * s, ReadBuffer & buf); +inline bool checkStringCaseInsensitive(const String & s, ReadBuffer & buf) +{ + return checkStringCaseInsensitive(s.c_str(), buf); +} + +void assertStringCaseInsensitive(const char * s, ReadBuffer & buf); +inline void assertStringCaseInsensitive(const String & s, ReadBuffer & buf) +{ + return assertStringCaseInsensitive(s.c_str(), buf); +} + +/** Check that next character in buf matches 
first character of s. + * If true, then check all characters in s and throw exception if it doesn't match. + * If false, then return false, and leave position in buffer unchanged. + */ +bool checkStringByFirstCharacterAndAssertTheRest(const char * s, ReadBuffer & buf); +bool checkStringByFirstCharacterAndAssertTheRestCaseInsensitive(const char * s, ReadBuffer & buf); + +inline bool checkStringByFirstCharacterAndAssertTheRest(const String & s, ReadBuffer & buf) +{ + return checkStringByFirstCharacterAndAssertTheRest(s.c_str(), buf); +} + +inline bool checkStringByFirstCharacterAndAssertTheRestCaseInsensitive(const String & s, ReadBuffer & buf) +{ + return checkStringByFirstCharacterAndAssertTheRestCaseInsensitive(s.c_str(), buf); +} + + +inline void readBoolText(bool & x, ReadBuffer & buf) +{ + char tmp = '0'; + readChar(tmp, buf); + x = tmp != '0'; +} + +inline void readBoolTextWord(bool & x, ReadBuffer & buf, bool support_upper_case = false) +{ + if (buf.eof()) [[unlikely]] + throwReadAfterEOF(); + + switch (*buf.position()) + { + case 't': + assertString("true", buf); + x = true; + break; + case 'f': + assertString("false", buf); + x = false; + break; + case 'T': + { + if (support_upper_case) + { + assertString("TRUE", buf); + x = true; + break; + } + else + [[fallthrough]]; + } + case 'F': + { + if (support_upper_case) + { + assertString("FALSE", buf); + x = false; + break; + } + else + [[fallthrough]]; + } + default: + throw ParsingException(ErrorCodes::CANNOT_PARSE_BOOL, "Unexpected Bool value"); + } +} + +enum class ReadIntTextCheckOverflow +{ + DO_NOT_CHECK_OVERFLOW, + CHECK_OVERFLOW, +}; + +template <typename T, typename ReturnType = void, ReadIntTextCheckOverflow check_overflow = ReadIntTextCheckOverflow::DO_NOT_CHECK_OVERFLOW> +ReturnType readIntTextImpl(T & x, ReadBuffer & buf) +{ + using UnsignedT = make_unsigned_t<T>; + + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + bool negative = false; + UnsignedT res{}; + if (buf.eof()) [[unlikely]] + { + if constexpr (throw_exception) + throwReadAfterEOF(); + else + return ReturnType(false); + } + + const size_t initial_pos = buf.count(); + bool has_sign = false; + bool has_number = false; + while (!buf.eof()) + { + switch (*buf.position()) + { + case '+': + { + /// 123+ or +123+, just stop after 123 or +123. + if (has_number) + goto end; + + /// No digits read yet, but we already read sign, like ++, -+. 
+ if (has_sign) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, + "Cannot parse number with multiple sign (+/-) characters"); + else + return ReturnType(false); + } + + has_sign = true; + break; + } + case '-': + { + if (has_number) + goto end; + + if (has_sign) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, + "Cannot parse number with multiple sign (+/-) characters"); + else + return ReturnType(false); + } + + if constexpr (is_signed_v<T>) + negative = true; + else + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Unsigned type must not contain '-' symbol"); + else + return ReturnType(false); + } + has_sign = true; + break; + } + case '0': [[fallthrough]]; + case '1': [[fallthrough]]; + case '2': [[fallthrough]]; + case '3': [[fallthrough]]; + case '4': [[fallthrough]]; + case '5': [[fallthrough]]; + case '6': [[fallthrough]]; + case '7': [[fallthrough]]; + case '8': [[fallthrough]]; + case '9': + { + has_number = true; + if constexpr (check_overflow == ReadIntTextCheckOverflow::CHECK_OVERFLOW && !is_big_int_v<T>) + { + /// Perform relativelly slow overflow check only when + /// number of decimal digits so far is close to the max for given type. + /// Example: 20 * 10 will overflow Int8. + + if (buf.count() - initial_pos + 1 >= std::numeric_limits<T>::max_digits10) + { + if (negative) + { + T signed_res = -res; + if (common::mulOverflow<T>(signed_res, 10, signed_res) || + common::subOverflow<T>(signed_res, (*buf.position() - '0'), signed_res)) + return ReturnType(false); + + res = -static_cast<UnsignedT>(signed_res); + } + else + { + T signed_res = res; + if (common::mulOverflow<T>(signed_res, 10, signed_res) || + common::addOverflow<T>(signed_res, (*buf.position() - '0'), signed_res)) + return ReturnType(false); + + res = signed_res; + } + break; + } + } + res *= 10; + res += *buf.position() - '0'; + break; + } + default: + goto end; + } + ++buf.position(); + } + +end: + if (has_sign && !has_number) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, + "Cannot parse number with a sign character but without any numeric character"); + else + return ReturnType(false); + } + x = res; + if constexpr (is_signed_v<T>) + { + if (negative) + { + if constexpr (check_overflow == ReadIntTextCheckOverflow::CHECK_OVERFLOW) + { + if (common::mulOverflow<UnsignedT, Int8, T>(res, -1, x)) + return ReturnType(false); + } + else + x = -res; + } + } + + return ReturnType(true); +} + +template <ReadIntTextCheckOverflow check_overflow = ReadIntTextCheckOverflow::DO_NOT_CHECK_OVERFLOW, typename T> +void readIntText(T & x, ReadBuffer & buf) +{ + if constexpr (is_decimal<T>) + { + readIntText<check_overflow>(x.value, buf); + } + else + { + readIntTextImpl<T, void, check_overflow>(x, buf); + } +} + +template <ReadIntTextCheckOverflow check_overflow = ReadIntTextCheckOverflow::CHECK_OVERFLOW, typename T> +bool tryReadIntText(T & x, ReadBuffer & buf) +{ + return readIntTextImpl<T, bool, check_overflow>(x, buf); +} + + +/** More efficient variant (about 1.5 times on real dataset). 
+ * Differs in following: + * - for numbers starting with zero, parsed only zero; + * - symbol '+' before number is not supported; + */ +template <typename T, bool throw_on_error = true> +void readIntTextUnsafe(T & x, ReadBuffer & buf) +{ + bool negative = false; + make_unsigned_t<T> res = 0; + + auto on_error = [] + { + if (throw_on_error) + throwReadAfterEOF(); + }; + + if (buf.eof()) [[unlikely]] + return on_error(); + + if (is_signed_v<T> && *buf.position() == '-') + { + ++buf.position(); + negative = true; + if (buf.eof()) [[unlikely]] + return on_error(); + } + + if (*buf.position() == '0') /// There are many zeros in real datasets. + { + ++buf.position(); + x = 0; + return; + } + + while (!buf.eof()) + { + unsigned char value = *buf.position() - '0'; + + if (value < 10) + { + res *= 10; + res += value; + ++buf.position(); + } + else + break; + } + + /// See note about undefined behaviour above. + x = is_signed_v<T> && negative ? -res : res; +} + +template <typename T> +void tryReadIntTextUnsafe(T & x, ReadBuffer & buf) +{ + return readIntTextUnsafe<T, false>(x, buf); +} + + +/// Look at readFloatText.h +template <typename T> void readFloatText(T & x, ReadBuffer & in); +template <typename T> bool tryReadFloatText(T & x, ReadBuffer & in); + +template <typename T> void readFloatTextPrecise(T & x, ReadBuffer & in); +template <typename T> bool tryReadFloatTextPrecise(T & x, ReadBuffer & in); +template <typename T> void readFloatTextFast(T & x, ReadBuffer & in); +template <typename T> bool tryReadFloatTextFast(T & x, ReadBuffer & in); + + +/// simple: all until '\n' or '\t' +void readString(String & s, ReadBuffer & buf); + +void readEscapedString(String & s, ReadBuffer & buf); + +void readQuotedString(String & s, ReadBuffer & buf); +void readQuotedStringWithSQLStyle(String & s, ReadBuffer & buf); + +void readDoubleQuotedString(String & s, ReadBuffer & buf); +void readDoubleQuotedStringWithSQLStyle(String & s, ReadBuffer & buf); + +void readJSONString(String & s, ReadBuffer & buf); + +void readBackQuotedString(String & s, ReadBuffer & buf); +void readBackQuotedStringWithSQLStyle(String & s, ReadBuffer & buf); + +void readStringUntilEOF(String & s, ReadBuffer & buf); + +// Reads the line until EOL, unescaping backslash escape sequences. +// Buffer pointer is left at EOL, don't forget to advance it. +void readEscapedStringUntilEOL(String & s, ReadBuffer & buf); + +/// Only 0x20 as whitespace character +void readStringUntilWhitespace(String & s, ReadBuffer & buf); + + +/** Read string in CSV format. + * Parsing rules: + * - string could be placed in quotes; quotes could be single: ' if FormatSettings::CSV::allow_single_quotes is true + * or double: " if FormatSettings::CSV::allow_double_quotes is true; + * - or string could be unquoted - this is determined by first character; + * - if string is unquoted, then: + * - If settings.custom_delimiter is not specified, it is read until next settings.delimiter, either until end of line (CR or LF) or until end of stream; + * - If settings.custom_delimiter is specified it reads until first occurrences of settings.custom_delimiter in buffer. + * This works only if provided buffer is PeekableReadBuffer. + * but spaces and tabs at begin and end of unquoted string are consumed but ignored (note that this behaviour differs from RFC). 
+ * - if string is in quotes, then it will be read until closing quote, + * but sequences of two consecutive quotes are parsed as single quote inside string; + */ +void readCSVString(String & s, ReadBuffer & buf, const FormatSettings::CSV & settings); + +/// Differ from readCSVString in that it doesn't remove quotes around field if any. +void readCSVField(String & s, ReadBuffer & buf, const FormatSettings::CSV & settings); + +/// Read string in CSV format until the first occurrence of first_delimiter or second_delimiter. +/// Similar to readCSVString if string is in quotes, we read only data in quotes. +String readCSVStringWithTwoPossibleDelimiters(PeekableReadBuffer & buf, const FormatSettings::CSV & settings, const String & first_delimiter, const String & second_delimiter); + +/// Same as above but includes quotes in the result if any. +String readCSVFieldWithTwoPossibleDelimiters(PeekableReadBuffer & buf, const FormatSettings::CSV & settings, const String & first_delimiter, const String & second_delimiter); + +/// Read and append result to array of characters. +template <typename Vector> +void readStringInto(Vector & s, ReadBuffer & buf); + +template <typename Vector> +void readNullTerminated(Vector & s, ReadBuffer & buf); + +template <typename Vector> +void readEscapedStringInto(Vector & s, ReadBuffer & buf); + +template <bool enable_sql_style_quoting, typename Vector> +void readQuotedStringInto(Vector & s, ReadBuffer & buf); + +template <bool enable_sql_style_quoting, typename Vector> +void readDoubleQuotedStringInto(Vector & s, ReadBuffer & buf); + +template <bool enable_sql_style_quoting, typename Vector> +void readBackQuotedStringInto(Vector & s, ReadBuffer & buf); + +template <typename Vector> +void readStringUntilEOFInto(Vector & s, ReadBuffer & buf); + +template <typename Vector, bool include_quotes = false> +void readCSVStringInto(Vector & s, ReadBuffer & buf, const FormatSettings::CSV & settings); + +/// ReturnType is either bool or void. If bool, the function will return false instead of throwing an exception. +template <typename Vector, typename ReturnType = void> +ReturnType readJSONStringInto(Vector & s, ReadBuffer & buf); + +template <typename Vector> +bool tryReadJSONStringInto(Vector & s, ReadBuffer & buf) +{ + return readJSONStringInto<Vector, bool>(s, buf); +} + +template <typename Vector> +bool tryReadQuotedStringInto(Vector & s, ReadBuffer & buf); + +/// Reads chunk of data between {} in that way, +/// that it has balanced parentheses sequence of {}. +/// So, it may form a JSON object, but it can be incorrenct. +template <typename Vector, typename ReturnType = void> +ReturnType readJSONObjectPossiblyInvalid(Vector & s, ReadBuffer & buf); + +template <typename Vector> +void readStringUntilWhitespaceInto(Vector & s, ReadBuffer & buf); + +template <typename Vector> +void readStringUntilNewlineInto(Vector & s, ReadBuffer & buf); + +/// This could be used as template parameter for functions above, if you want to just skip data. +struct NullOutput +{ + void append(const char *, size_t) {} + void append(const char *) {} + void append(const char *, const char *) {} + void push_back(char) {} /// NOLINT +}; + +template <typename ReturnType> +ReturnType readDateTextFallback(LocalDate & date, ReadBuffer & buf); + +/// In YYYY-MM-DD format. +/// For convenience, Month and Day parts can have single digit instead of two digits. +/// Any separators other than '-' are supported. 
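+///
+/// Usage sketch (illustrative only; ReadBufferFromString and the buffer contents are assumptions of the example):
+///     LocalDate date;
+///     ReadBufferFromString in("2023-11-14");
+///     if (tryReadDateText(date, in))
+///     {
+///         /// date.year() == 2023, date.month() == 11, date.day() == 14
+///     }
+/// Inputs like "2023-1-5" (single-digit month/day) or "20231114" (no separators) are accepted as well.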
+template <typename ReturnType = void> +inline ReturnType readDateTextImpl(LocalDate & date, ReadBuffer & buf) +{ + /// Optimistic path, when whole value is in buffer. + if (!buf.eof() && buf.position() + 10 <= buf.buffer().end()) + { + char * pos = buf.position(); + + /// YYYY-MM-DD + /// YYYY-MM-D + /// YYYY-M-DD + /// YYYY-M-D + /// YYYYMMDD + + /// The delimiters can be arbitrary characters, like YYYY/MM!DD, but obviously not digits. + + UInt16 year = (pos[0] - '0') * 1000 + (pos[1] - '0') * 100 + (pos[2] - '0') * 10 + (pos[3] - '0'); + UInt8 month; + UInt8 day; + pos += 5; + + if (isNumericASCII(pos[-1])) + { + /// YYYYMMDD + month = (pos[-1] - '0') * 10 + (pos[0] - '0'); + day = (pos[1] - '0') * 10 + (pos[2] - '0'); + pos += 3; + } + else + { + month = pos[0] - '0'; + if (isNumericASCII(pos[1])) + { + month = month * 10 + pos[1] - '0'; + pos += 3; + } + else + pos += 2; + + if (isNumericASCII(pos[-1])) + return ReturnType(false); + + day = pos[0] - '0'; + if (isNumericASCII(pos[1])) + { + day = day * 10 + pos[1] - '0'; + pos += 2; + } + else + pos += 1; + } + + buf.position() = pos; + date = LocalDate(year, month, day); + return ReturnType(true); + } + else + return readDateTextFallback<ReturnType>(date, buf); +} + +inline void convertToDayNum(DayNum & date, ExtendedDayNum & from) +{ + if (unlikely(from < 0)) + date = 0; + else if (unlikely(from > 0xFFFF)) + date = 0xFFFF; + else + date = from; +} + +template <typename ReturnType = void> +inline ReturnType readDateTextImpl(DayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + LocalDate local_date; + + if constexpr (throw_exception) + readDateTextImpl<ReturnType>(local_date, buf); + else if (!readDateTextImpl<ReturnType>(local_date, buf)) + return false; + + ExtendedDayNum ret = date_lut.makeDayNum(local_date.year(), local_date.month(), local_date.day()); + convertToDayNum(date, ret); + return ReturnType(true); +} + +template <typename ReturnType = void> +inline ReturnType readDateTextImpl(ExtendedDayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + LocalDate local_date; + + if constexpr (throw_exception) + readDateTextImpl<ReturnType>(local_date, buf); + else if (!readDateTextImpl<ReturnType>(local_date, buf)) + return false; + + /// When the parameter is out of rule or out of range, Date32 uses 1925-01-01 as the default value (-DateLUT::instance().getDayNumOffsetEpoch(), -16436) and Date uses 1970-01-01. 
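+    /// For example, parsing '1800-01-01' into Date32 yields that default value rather than throwing,
+    /// because makeDayNum() falls back to the passed default when the components are outside the LUT range (illustrative note).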
+ date = date_lut.makeDayNum(local_date.year(), local_date.month(), local_date.day(), -static_cast<Int32>(date_lut.getDayNumOffsetEpoch())); + return ReturnType(true); +} + + +inline void readDateText(LocalDate & date, ReadBuffer & buf) +{ + readDateTextImpl<void>(date, buf); +} + +inline void readDateText(DayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::instance()) +{ + readDateTextImpl<void>(date, buf, date_lut); +} + +inline void readDateText(ExtendedDayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::instance()) +{ + readDateTextImpl<void>(date, buf, date_lut); +} + +inline bool tryReadDateText(LocalDate & date, ReadBuffer & buf) +{ + return readDateTextImpl<bool>(date, buf); +} + +inline bool tryReadDateText(DayNum & date, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +{ + return readDateTextImpl<bool>(date, buf, time_zone); +} + +inline bool tryReadDateText(ExtendedDayNum & date, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +{ + return readDateTextImpl<bool>(date, buf, time_zone); +} + +UUID parseUUID(std::span<const UInt8> src); + +template <typename ReturnType = void> +inline ReturnType readUUIDTextImpl(UUID & uuid, ReadBuffer & buf) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + char s[36]; + size_t size = buf.read(s, 32); + + if (size == 32) + { + if (s[8] == '-') + { + size += buf.read(&s[32], 4); + + if (size != 36) + { + s[size] = 0; + + if constexpr (throw_exception) + { + throw ParsingException(ErrorCodes::CANNOT_PARSE_UUID, "Cannot parse uuid {}", s); + } + else + { + return ReturnType(false); + } + } + } + + uuid = parseUUID({reinterpret_cast<const UInt8 *>(s), size}); + return ReturnType(true); + } + else + { + s[size] = 0; + + if constexpr (throw_exception) + { + throw ParsingException(ErrorCodes::CANNOT_PARSE_UUID, "Cannot parse uuid {}", s); + } + else + { + return ReturnType(false); + } + } +} + +inline void readUUIDText(UUID & uuid, ReadBuffer & buf) +{ + return readUUIDTextImpl<void>(uuid, buf); +} + +inline bool tryReadUUIDText(UUID & uuid, ReadBuffer & buf) +{ + return readUUIDTextImpl<bool>(uuid, buf); +} + +template <typename ReturnType = void> +inline ReturnType readIPv4TextImpl(IPv4 & ip, ReadBuffer & buf) +{ + if (parseIPv4(buf.position(), [&buf](){ return buf.eof(); }, reinterpret_cast<unsigned char *>(&ip.toUnderType()))) + return ReturnType(true); + + if constexpr (std::is_same_v<ReturnType, void>) + throw ParsingException(ErrorCodes::CANNOT_PARSE_IPV4, "Cannot parse IPv4 {}", std::string_view(buf.position(), buf.available())); + else + return ReturnType(false); +} + +inline void readIPv4Text(IPv4 & ip, ReadBuffer & buf) +{ + return readIPv4TextImpl<void>(ip, buf); +} + +inline bool tryReadIPv4Text(IPv4 & ip, ReadBuffer & buf) +{ + return readIPv4TextImpl<bool>(ip, buf); +} + +template <typename ReturnType = void> +inline ReturnType readIPv6TextImpl(IPv6 & ip, ReadBuffer & buf) +{ + if (parseIPv6orIPv4(buf.position(), [&buf](){ return buf.eof(); }, reinterpret_cast<unsigned char *>(ip.toUnderType().items))) + return ReturnType(true); + + if constexpr (std::is_same_v<ReturnType, void>) + throw ParsingException(ErrorCodes::CANNOT_PARSE_IPV6, "Cannot parse IPv6 {}", std::string_view(buf.position(), buf.available())); + else + return ReturnType(false); +} + +inline void readIPv6Text(IPv6 & ip, ReadBuffer & buf) +{ + return readIPv6TextImpl<void>(ip, buf); +} + +inline bool tryReadIPv6Text(IPv6 & ip, ReadBuffer & buf) +{ + return 
readIPv6TextImpl<bool>(ip, buf); +} + +template <typename T> +inline T parse(const char * data, size_t size); + +template <typename T> +inline T parseFromString(std::string_view str) +{ + return parse<T>(str.data(), str.size()); +} + + +template <typename ReturnType = void> +ReturnType readDateTimeTextFallback(time_t & datetime, ReadBuffer & buf, const DateLUTImpl & date_lut); + +/** In YYYY-MM-DD hh:mm:ss or YYYY-MM-DD format, according to specified time zone. + * As an exception, also supported parsing of unix timestamp in form of decimal number. + */ +template <typename ReturnType = void> +inline ReturnType readDateTimeTextImpl(time_t & datetime, ReadBuffer & buf, const DateLUTImpl & date_lut) +{ + /// Optimistic path, when whole value is in buffer. + const char * s = buf.position(); + + /// YYYY-MM-DD hh:mm:ss + static constexpr auto date_time_broken_down_length = 19; + /// YYYY-MM-DD + static constexpr auto date_broken_down_length = 10; + bool optimistic_path_for_date_time_input = s + date_time_broken_down_length <= buf.buffer().end(); + + if (optimistic_path_for_date_time_input) + { + if (s[4] < '0' || s[4] > '9') + { + UInt16 year = (s[0] - '0') * 1000 + (s[1] - '0') * 100 + (s[2] - '0') * 10 + (s[3] - '0'); + UInt8 month = (s[5] - '0') * 10 + (s[6] - '0'); + UInt8 day = (s[8] - '0') * 10 + (s[9] - '0'); + + UInt8 hour = 0; + UInt8 minute = 0; + UInt8 second = 0; + + /// Simply determine whether it is YYYY-MM-DD hh:mm:ss or YYYY-MM-DD by the content of the tenth character in an optimistic scenario + bool dt_long = (s[10] == ' ' || s[10] == 'T'); + if (dt_long) + { + hour = (s[11] - '0') * 10 + (s[12] - '0'); + minute = (s[14] - '0') * 10 + (s[15] - '0'); + second = (s[17] - '0') * 10 + (s[18] - '0'); + } + + if (unlikely(year == 0)) + datetime = 0; + else + datetime = date_lut.makeDateTime(year, month, day, hour, minute, second); + + if (dt_long) + buf.position() += date_time_broken_down_length; + else + buf.position() += date_broken_down_length; + + return ReturnType(true); + } + else + /// Why not readIntTextUnsafe? Because for needs of AdFox, parsing of unix timestamp with leading zeros is supported: 000...NNNN. + return readIntTextImpl<time_t, ReturnType, ReadIntTextCheckOverflow::CHECK_OVERFLOW>(datetime, buf); + } + else + return readDateTimeTextFallback<ReturnType>(datetime, buf, date_lut); +} + +template <typename ReturnType> +inline ReturnType readDateTimeTextImpl(DateTime64 & datetime64, UInt32 scale, ReadBuffer & buf, const DateLUTImpl & date_lut) +{ + time_t whole; + if (!readDateTimeTextImpl<bool>(whole, buf, date_lut)) + { + return ReturnType(false); + } + + int negative_multiplier = 1; + + DB::DecimalUtils::DecimalComponents<DateTime64> components{static_cast<DateTime64::NativeType>(whole), 0}; + + if (!buf.eof() && *buf.position() == '.') + { + ++buf.position(); + + /// Read digits, up to 'scale' positions. + for (size_t i = 0; i < scale; ++i) + { + if (!buf.eof() && isNumericASCII(*buf.position())) + { + components.fractional *= 10; + components.fractional += *buf.position() - '0'; + ++buf.position(); + } + else + { + /// Adjust to scale. + components.fractional *= 10; + } + } + + /// Ignore digits that are out of precision. 
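+        /// For example, with scale = 3 the input "2023-11-14 09:58:56.123456" keeps only 123 as the fractional part;
+        /// the remaining digits are consumed below and discarded (illustrative note).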
+ while (!buf.eof() && isNumericASCII(*buf.position())) + ++buf.position(); + + /// Fractional part (subseconds) is treated as positive by users + /// (as DateTime64 itself is a positive, although underlying decimal is negative) + /// setting fractional part to be negative when whole is 0 results in wrong value, + /// so we multiply result by -1. + if (components.whole < 0 && components.fractional != 0) + { + const auto scale_multiplier = DecimalUtils::scaleMultiplier<DateTime64::NativeType>(scale); + ++components.whole; + components.fractional = scale_multiplier - components.fractional; + if (!components.whole) + { + negative_multiplier = -1; + } + } + } + /// 10413792000 is time_t value for 2300-01-01 UTC (a bit over the last year supported by DateTime64) + else if (whole >= 10413792000LL) + { + /// Unix timestamp with subsecond precision, already scaled to integer. + /// For disambiguation we support only time since 2001-09-09 01:46:40 UTC and less than 30 000 years in future. + components.fractional = components.whole % common::exp10_i32(scale); + components.whole = components.whole / common::exp10_i32(scale); + } + + bool is_ok = true; + if constexpr (std::is_same_v<ReturnType, void>) + { + datetime64 = DecimalUtils::decimalFromComponents<DateTime64>(components, scale) * negative_multiplier; + } + else + { + is_ok = DecimalUtils::tryGetDecimalFromComponents<DateTime64>(components, scale, datetime64); + if (is_ok) + datetime64 *= negative_multiplier; + } + + return ReturnType(is_ok); +} + +inline void readDateTimeText(time_t & datetime, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +{ + readDateTimeTextImpl<void>(datetime, buf, time_zone); +} + +inline void readDateTime64Text(DateTime64 & datetime64, UInt32 scale, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::instance()) +{ + readDateTimeTextImpl<void>(datetime64, scale, buf, date_lut); +} + +inline bool tryReadDateTimeText(time_t & datetime, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +{ + return readDateTimeTextImpl<bool>(datetime, buf, time_zone); +} + +inline bool tryReadDateTime64Text(DateTime64 & datetime64, UInt32 scale, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::instance()) +{ + return readDateTimeTextImpl<bool>(datetime64, scale, buf, date_lut); +} + +inline void readDateTimeText(LocalDateTime & datetime, ReadBuffer & buf) +{ + char s[10]; + size_t size = buf.read(s, 10); + if (10 != size) + { + s[size] = 0; + throw ParsingException(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot parse DateTime {}", s); + } + + datetime.year((s[0] - '0') * 1000 + (s[1] - '0') * 100 + (s[2] - '0') * 10 + (s[3] - '0')); + datetime.month((s[5] - '0') * 10 + (s[6] - '0')); + datetime.day((s[8] - '0') * 10 + (s[9] - '0')); + + /// Allow to read Date as DateTime + if (buf.eof() || !(*buf.position() == ' ' || *buf.position() == 'T')) + return; + + ++buf.position(); + size = buf.read(s, 8); + if (8 != size) + { + s[size] = 0; + throw ParsingException(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot parse time component of DateTime {}", s); + } + + datetime.hour((s[0] - '0') * 10 + (s[1] - '0')); + datetime.minute((s[3] - '0') * 10 + (s[4] - '0')); + datetime.second((s[6] - '0') * 10 + (s[7] - '0')); +} + + +/// Generic methods to read value in native binary format. 
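+///
+/// Minimal round-trip sketch (assumes the bytes were produced by the matching writeBinary() from WriteHelpers.h):
+///     UInt64 value = 0;
+///     readBinary(value, buf);   /// reads sizeof(UInt64) raw bytes in host byte order
+///     String name;
+///     readBinary(name, buf);    /// reads a varint-encoded length followed by that many bytes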
+template <typename T> +requires is_arithmetic_v<T> +inline void readBinary(T & x, ReadBuffer & buf) { readPODBinary(x, buf); } + +inline void readBinary(bool & x, ReadBuffer & buf) +{ + /// When deserializing a bool it might trigger UBSAN if the input is not 0 or 1, so it's better to treat it as an Int8 + static_assert(sizeof(bool) == sizeof(Int8)); + Int8 flag = 0; + readBinary(flag, buf); + x = (flag != 0); +} + +inline void readBinary(String & x, ReadBuffer & buf) { readStringBinary(x, buf); } +inline void readBinary(Decimal32 & x, ReadBuffer & buf) { readPODBinary(x, buf); } +inline void readBinary(Decimal64 & x, ReadBuffer & buf) { readPODBinary(x, buf); } +inline void readBinary(Decimal128 & x, ReadBuffer & buf) { readPODBinary(x, buf); } +inline void readBinary(Decimal256 & x, ReadBuffer & buf) { readPODBinary(x.value, buf); } +inline void readBinary(LocalDate & x, ReadBuffer & buf) { readPODBinary(x, buf); } +inline void readBinary(IPv4 & x, ReadBuffer & buf) { readPODBinary(x, buf); } +inline void readBinary(IPv6 & x, ReadBuffer & buf) { readPODBinary(x, buf); } + +inline void readBinary(UUID & x, ReadBuffer & buf) +{ + readUUIDBinary(x, buf); +} + +inline void readBinary(CityHash_v1_0_2::uint128 & x, ReadBuffer & buf) +{ + readPODBinary(x.low64, buf); + readPODBinary(x.high64, buf); +} + +inline void readBinary(StackTrace::FramePointers & x, ReadBuffer & buf) { readPODBinary(x, buf); } + +template <std::endian endian, typename T> +inline void readBinaryEndian(T & x, ReadBuffer & buf) +{ + readBinary(x, buf); + transformEndianness<endian>(x); +} + +template <typename T> +inline void readBinaryLittleEndian(T & x, ReadBuffer & buf) +{ + readBinaryEndian<std::endian::little>(x, buf); +} + +template <typename T> +inline void readBinaryBigEndian(T & x, ReadBuffer & buf) +{ + readBinaryEndian<std::endian::big>(x, buf); +} + + +/// Generic methods to read value in text tab-separated format. + +inline void readText(is_integer auto & x, ReadBuffer & buf) +{ + if constexpr (std::is_same_v<decltype(x), bool &>) + readBoolText(x, buf); + else + readIntText(x, buf); +} + +inline bool tryReadText(is_integer auto & x, ReadBuffer & buf) +{ + return tryReadIntText(x, buf); +} + +inline bool tryReadText(UUID & x, ReadBuffer & buf) { return tryReadUUIDText(x, buf); } +inline bool tryReadText(IPv4 & x, ReadBuffer & buf) { return tryReadIPv4Text(x, buf); } +inline bool tryReadText(IPv6 & x, ReadBuffer & buf) { return tryReadIPv6Text(x, buf); } + +inline void readText(is_floating_point auto & x, ReadBuffer & buf) { readFloatText(x, buf); } + +inline void readText(String & x, ReadBuffer & buf) { readEscapedString(x, buf); } + +inline void readText(DayNum & x, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) { readDateText(x, buf, time_zone); } + +inline void readText(LocalDate & x, ReadBuffer & buf) { readDateText(x, buf); } +inline void readText(LocalDateTime & x, ReadBuffer & buf) { readDateTimeText(x, buf); } +inline void readText(UUID & x, ReadBuffer & buf) { readUUIDText(x, buf); } +inline void readText(IPv4 & x, ReadBuffer & buf) { readIPv4Text(x, buf); } +inline void readText(IPv6 & x, ReadBuffer & buf) { readIPv6Text(x, buf); } + +/// Generic methods to read value in text format, +/// possibly in single quotes (only for data types that use quotes in VALUES format of INSERT statement in SQL). 
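+///
+/// Illustrative difference (not exhaustive): for arithmetic types readQuoted() reads the bare value,
+/// while for quoted types it expects the quotes of the VALUES format:
+///     UInt64 n; readQuoted(n, buf);     /// accepts: 123
+///     LocalDate d; readQuoted(d, buf);  /// accepts: '2023-11-14' (the single quotes are required)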
+template <typename T> +requires is_arithmetic_v<T> +inline void readQuoted(T & x, ReadBuffer & buf) { readText(x, buf); } + +template <typename T> +requires is_arithmetic_v<T> +inline void readQuoted(T & x, ReadBuffer & buf, const DateLUTImpl & time_zone) { readText(x, buf, time_zone); } + +inline void readQuoted(String & x, ReadBuffer & buf) { readQuotedString(x, buf); } + +inline void readQuoted(LocalDate & x, ReadBuffer & buf) +{ + assertChar('\'', buf); + readDateText(x, buf); + assertChar('\'', buf); +} + +inline void readQuoted(LocalDateTime & x, ReadBuffer & buf) +{ + assertChar('\'', buf); + readDateTimeText(x, buf); + assertChar('\'', buf); +} + +inline void readQuoted(UUID & x, ReadBuffer & buf) +{ + assertChar('\'', buf); + readUUIDText(x, buf); + assertChar('\'', buf); +} + +inline void readQuoted(IPv4 & x, ReadBuffer & buf) +{ + assertChar('\'', buf); + readIPv4Text(x, buf); + assertChar('\'', buf); +} + +inline void readQuoted(IPv6 & x, ReadBuffer & buf) +{ + assertChar('\'', buf); + readIPv6Text(x, buf); + assertChar('\'', buf); +} + +/// Same as above, but in double quotes. +template <typename T> +requires is_arithmetic_v<T> +inline void readDoubleQuoted(T & x, ReadBuffer & buf) { readText(x, buf); } + +template <typename T> +requires is_arithmetic_v<T> +inline void readDoubleQuoted(T & x, ReadBuffer & buf, const DateLUTImpl & time_zone) { readText(x, buf, time_zone); } + +inline void readDoubleQuoted(String & x, ReadBuffer & buf) { readDoubleQuotedString(x, buf); } + +inline void readDoubleQuoted(LocalDate & x, ReadBuffer & buf) +{ + assertChar('"', buf); + readDateText(x, buf); + assertChar('"', buf); +} + +inline void readDoubleQuoted(LocalDateTime & x, ReadBuffer & buf) +{ + assertChar('"', buf); + readDateTimeText(x, buf); + assertChar('"', buf); +} + +/// CSV for numbers: quotes are optional, no special escaping rules. 
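+///
+/// Illustrative example: the CSV cells 123, "123" and '123' all parse to the same UInt64 via readCSV();
+/// if an opening quote is present, the same closing quote is asserted after the value.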
+template <typename T> +inline void readCSVSimple(T & x, ReadBuffer & buf) +{ + if (buf.eof()) [[unlikely]] + throwReadAfterEOF(); + + char maybe_quote = *buf.position(); + + if (maybe_quote == '\'' || maybe_quote == '\"') + ++buf.position(); + + readText(x, buf); + + if (maybe_quote == '\'' || maybe_quote == '\"') + assertChar(maybe_quote, buf); +} + +// standalone overload for dates: to avoid instantiating DateLUTs while parsing other types +template <typename T> +inline void readCSVSimple(T & x, ReadBuffer & buf, const DateLUTImpl & time_zone) +{ + if (buf.eof()) [[unlikely]] + throwReadAfterEOF(); + + char maybe_quote = *buf.position(); + + if (maybe_quote == '\'' || maybe_quote == '\"') + ++buf.position(); + + readText(x, buf, time_zone); + + if (maybe_quote == '\'' || maybe_quote == '\"') + assertChar(maybe_quote, buf); +} + +template <typename T> +requires is_arithmetic_v<T> +inline void readCSV(T & x, ReadBuffer & buf) +{ + readCSVSimple(x, buf); +} + +inline void readCSV(String & x, ReadBuffer & buf, const FormatSettings::CSV & settings) { readCSVString(x, buf, settings); } +inline void readCSV(LocalDate & x, ReadBuffer & buf) { readCSVSimple(x, buf); } +inline void readCSV(DayNum & x, ReadBuffer & buf) { readCSVSimple(x, buf); } +inline void readCSV(DayNum & x, ReadBuffer & buf, const DateLUTImpl & time_zone) { readCSVSimple(x, buf, time_zone); } +inline void readCSV(LocalDateTime & x, ReadBuffer & buf) { readCSVSimple(x, buf); } +inline void readCSV(UUID & x, ReadBuffer & buf) { readCSVSimple(x, buf); } +inline void readCSV(IPv4 & x, ReadBuffer & buf) { readCSVSimple(x, buf); } +inline void readCSV(IPv6 & x, ReadBuffer & buf) { readCSVSimple(x, buf); } +inline void readCSV(UInt128 & x, ReadBuffer & buf) { readCSVSimple(x, buf); } +inline void readCSV(Int128 & x, ReadBuffer & buf) { readCSVSimple(x, buf); } +inline void readCSV(UInt256 & x, ReadBuffer & buf) { readCSVSimple(x, buf); } +inline void readCSV(Int256 & x, ReadBuffer & buf) { readCSVSimple(x, buf); } + +template <typename T> +void readBinary(std::vector<T> & x, ReadBuffer & buf) +{ + size_t size = 0; + readVarUInt(size, buf); + + if (size > DEFAULT_MAX_STRING_SIZE) + throw Poco::Exception("Too large vector size."); + + x.resize(size); + for (size_t i = 0; i < size; ++i) + readBinary(x[i], buf); +} + +template <typename T> +void readQuoted(std::vector<T> & x, ReadBuffer & buf) +{ + bool first = true; + assertChar('[', buf); + while (!buf.eof() && *buf.position() != ']') + { + if (!first) + { + if (*buf.position() == ',') + ++buf.position(); + else + throw ParsingException(ErrorCodes::CANNOT_READ_ARRAY_FROM_TEXT, "Cannot read array from text"); + } + + first = false; + + x.push_back(T()); + readQuoted(x.back(), buf); + } + assertChar(']', buf); +} + +template <typename T> +void readDoubleQuoted(std::vector<T> & x, ReadBuffer & buf) +{ + bool first = true; + assertChar('[', buf); + while (!buf.eof() && *buf.position() != ']') + { + if (!first) + { + if (*buf.position() == ',') + ++buf.position(); + else + throw ParsingException(ErrorCodes::CANNOT_READ_ARRAY_FROM_TEXT, "Cannot read array from text"); + } + + first = false; + + x.push_back(T()); + readDoubleQuoted(x.back(), buf); + } + assertChar(']', buf); +} + +template <typename T> +void readText(std::vector<T> & x, ReadBuffer & buf) +{ + readQuoted(x, buf); +} + + +/// Skip whitespace characters. 
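+/// With one_line = true only whitespace that does not end the line (spaces, tabs, etc.) is skipped,
+/// so the cursor stops at '\n'; with the default one_line = false newlines are skipped as well.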
+inline void skipWhitespaceIfAny(ReadBuffer & buf, bool one_line = false) +{ + if (!one_line) + while (!buf.eof() && isWhitespaceASCII(*buf.position())) + ++buf.position(); + else + while (!buf.eof() && isWhitespaceASCIIOneLine(*buf.position())) + ++buf.position(); +} + +/// Skips json value. +void skipJSONField(ReadBuffer & buf, StringRef name_of_field); + + +/** Read serialized exception. + * During serialization/deserialization some information is lost + * (type is cut to base class, 'message' replaced by 'displayText', and stack trace is appended to 'message') + * Some additional message could be appended to exception (example: you could add information about from where it was received). + */ +Exception readException(ReadBuffer & buf, const String & additional_message = "", bool remote_exception = false); +void readAndThrowException(ReadBuffer & buf, const String & additional_message = ""); + + +/** Helper function for implementation. + */ +template <ReadIntTextCheckOverflow check_overflow = ReadIntTextCheckOverflow::CHECK_OVERFLOW, typename T> +static inline const char * tryReadIntText(T & x, const char * pos, const char * end) +{ + ReadBufferFromMemory in(pos, end - pos); + tryReadIntText<check_overflow>(x, in); + return pos + in.count(); +} + + +/// Convenient methods for reading something from string in text format. +template <typename T> +inline T parse(const char * data, size_t size) +{ + T res; + ReadBufferFromMemory buf(data, size); + readText(res, buf); + return res; +} + +template <typename T> +inline bool tryParse(T & res, const char * data, size_t size) +{ + ReadBufferFromMemory buf(data, size); + return tryReadText(res, buf); +} + +template <typename T> +inline void readTextWithSizeSuffix(T & x, ReadBuffer & buf) { readText(x, buf); } + +template <is_integer T> +inline void readTextWithSizeSuffix(T & x, ReadBuffer & buf) +{ + readIntText(x, buf); + if (buf.eof()) + return; + + /// Updates x depending on the suffix + auto finish = [&buf, &x] (UInt64 base, int power_of_two) mutable + { + ++buf.position(); + if (buf.eof()) + { + x *= base; /// For decimal suffixes, such as k, M, G etc. + } + else if (*buf.position() == 'i') + { + x = (x << power_of_two); // NOLINT /// For binary suffixes, such as ki, Mi, Gi, etc. + ++buf.position(); + } + return; + }; + + switch (*buf.position()) + { + case 'k': [[fallthrough]]; + case 'K': + finish(1000, 10); + break; + case 'M': + finish(1000000, 20); + break; + case 'G': + finish(1000000000, 30); + break; + case 'T': + finish(1000000000000ULL, 40); + break; + default: + return; + } +} + +/// Read something from text format and trying to parse the suffix. 
+/// If the suffix is not valid gives an error +/// For example: 723145 -- ok, 213MB -- not ok, but 213Mi -- ok +template <typename T> +inline T parseWithSizeSuffix(const char * data, size_t size) +{ + T res; + ReadBufferFromMemory buf(data, size); + readTextWithSizeSuffix(res, buf); + assertEOF(buf); + return res; +} + +template <typename T> +inline T parseWithSizeSuffix(std::string_view s) +{ + return parseWithSizeSuffix<T>(s.data(), s.size()); +} + +template <typename T> +inline T parseWithSizeSuffix(const char * data) +{ + return parseWithSizeSuffix<T>(data, strlen(data)); +} + +template <typename T> +inline T parse(const char * data) +{ + return parse<T>(data, strlen(data)); +} + +template <typename T> +inline T parse(const String & s) +{ + return parse<T>(s.data(), s.size()); +} + +template <typename T> +inline T parse(std::string_view s) +{ + return parse<T>(s.data(), s.size()); +} + +template <typename T> +inline bool tryParse(T & res, const char * data) +{ + return tryParse(res, data, strlen(data)); +} + +template <typename T> +inline bool tryParse(T & res, const String & s) +{ + return tryParse(res, s.data(), s.size()); +} + +template <typename T> +inline bool tryParse(T & res, std::string_view s) +{ + return tryParse(res, s.data(), s.size()); +} + + +/** Skip UTF-8 BOM if it is under cursor. + * As BOM is usually located at start of stream, and buffer size is usually larger than three bytes, + * the function expects, that all three bytes of BOM is fully in buffer (otherwise it don't skip anything). + */ +inline void skipBOMIfExists(ReadBuffer & buf) +{ + if (!buf.eof() + && buf.position() + 3 < buf.buffer().end() + && buf.position()[0] == '\xEF' + && buf.position()[1] == '\xBB' + && buf.position()[2] == '\xBF') + { + buf.position() += 3; + } +} + + +/// Skip to next character after next \n. If no \n in stream, skip to end. +void skipToNextLineOrEOF(ReadBuffer & buf); + +/// Skip to next character after next \r. If no \r in stream, skip to end. +void skipToCarriageReturnOrEOF(ReadBuffer & buf); + +/// Skip to next character after next unescaped \n. If no \n in stream, skip to end. Does not throw on invalid escape sequences. +void skipToUnescapedNextLineOrEOF(ReadBuffer & buf); + +/// Skip to next character after next \0. If no \0 in stream, skip to end. +void skipNullTerminated(ReadBuffer & buf); + +/** This function just copies the data from buffer's position (in.position()) + * to current position (from arguments) appending into memory. + */ +void saveUpToPosition(ReadBuffer & in, Memory<Allocator<false>> & memory, char * current); + +/** This function is negative to eof(). + * In fact it returns whether the data was loaded to internal ReadBuffers's buffer or not. + * And saves data from buffer's position to current if there is no pending data in buffer. + * Why we have to use this strange function? Consider we have buffer's internal position in the middle + * of our buffer and the current cursor in the end of the buffer. When we call eof() it calls next(). + * And this function can fill the buffer with new data, so we will lose the data from previous buffer state. + */ +bool loadAtPosition(ReadBuffer & in, Memory<Allocator<false>> & memory, char * & current); + +/// Skip data until start of the next row or eof (the end of row is determined by two delimiters: +/// row_after_delimiter and row_between_delimiter). 
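+/// Used, for example, to resynchronize the input on a malformed row in formats with configurable row delimiters (illustrative note).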
+void skipToNextRowOrEof(PeekableReadBuffer & buf, const String & row_after_delimiter, const String & row_between_delimiter, bool skip_spaces); + +struct PcgDeserializer +{ + static void deserializePcg32(pcg32_fast & rng, ReadBuffer & buf) + { + decltype(rng.state_) multiplier, increment, state; + readText(multiplier, buf); + assertChar(' ', buf); + readText(increment, buf); + assertChar(' ', buf); + readText(state, buf); + + if (multiplier != rng.multiplier()) + throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect multiplier in pcg32: expected {}, got {}", rng.multiplier(), multiplier); + if (increment != rng.increment()) + throw Exception(ErrorCodes::INCORRECT_DATA, "Incorrect increment in pcg32: expected {}, got {}", rng.increment(), increment); + + rng.state_ = state; + } +}; + +template <typename Vector> +void readQuotedFieldInto(Vector & s, ReadBuffer & buf); + +void readQuotedField(String & s, ReadBuffer & buf); + +void readJSONField(String & s, ReadBuffer & buf); + +void readTSVField(String & s, ReadBuffer & buf); + +/** Parse the escape sequence, which can be simple (one character after backslash) or more complex (multiple characters). + * It is assumed that the cursor is located on the `\` symbol + */ +bool parseComplexEscapeSequence(String & s, ReadBuffer & buf); + +} diff --git a/contrib/clickhouse/src/IO/ReadHelpersArena.h b/contrib/clickhouse/src/IO/ReadHelpersArena.h new file mode 100644 index 0000000000..b88d5c037d --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadHelpersArena.h @@ -0,0 +1,33 @@ +#pragma once + +#include <IO/ReadBuffer.h> +#include <IO/ReadHelpers.h> +#include <IO/VarInt.h> +#include <base/StringRef.h> +#include <Common/Arena.h> + + +namespace DB +{ + + +namespace ErrorCodes +{ + extern const int TOO_LARGE_STRING_SIZE; +} + +inline StringRef readStringBinaryInto(Arena & arena, ReadBuffer & buf) +{ + size_t size = 0; + readVarUInt(size, buf); + + if (unlikely(size > DEFAULT_MAX_STRING_SIZE)) + throw Exception(ErrorCodes::TOO_LARGE_STRING_SIZE, "Too large string size."); + + char * data = arena.alloc(size); + buf.readStrict(data, size); + + return StringRef(data, size); +} + +} diff --git a/contrib/clickhouse/src/IO/ReadSettings.h b/contrib/clickhouse/src/IO/ReadSettings.h new file mode 100644 index 0000000000..87f249823b --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadSettings.h @@ -0,0 +1,136 @@ +#pragma once + +#include <cstddef> +#include <string> +#include <Core/Defines.h> +#include <Interpreters/Cache/FileCache_fwd.h> +#include <Common/Throttler_fwd.h> +#include <Common/Priority.h> +#include <IO/ResourceLink.h> + +namespace DB +{ +enum class LocalFSReadMethod +{ + /** + * Simple synchronous reads with 'read'. + * Can use direct IO after specified size. + * Can use prefetch by asking OS to perform readahead. + */ + read, + + /** + * Simple synchronous reads with 'pread'. + * In contrast to 'read', shares single file descriptor from multiple threads. + * Can use direct IO after specified size. + * Can use prefetch by asking OS to perform readahead. + */ + pread, + + /** + * Use mmap after specified size or simple synchronous reads with 'pread'. + * Can use prefetch by asking OS to perform readahead. + */ + mmap, + + /** + * Use the io_uring Linux subsystem for asynchronous reads. + * Can use direct IO after specified size. + * Can do prefetch with double buffering. + */ + io_uring, + + /** + * Checks if data is in page cache with 'preadv2' on modern Linux kernels. + * If data is in page cache, read from the same thread. 
+ * If not, offload IO to separate threadpool. + * Can do prefetch with double buffering. + * Can use specified priorities and limit the number of concurrent reads. + */ + pread_threadpool, + + /// Use asynchronous reader with fake backend that in fact synchronous. + /// @attention Use only for testing purposes. + pread_fake_async +}; + +enum class RemoteFSReadMethod +{ + read, + threadpool, +}; + +class MMappedFileCache; + +struct ReadSettings +{ + /// Method to use reading from local filesystem. + LocalFSReadMethod local_fs_method = LocalFSReadMethod::pread; + /// Method to use reading from remote filesystem. + RemoteFSReadMethod remote_fs_method = RemoteFSReadMethod::threadpool; + + /// https://eklitzke.org/efficient-file-copying-on-linux + size_t local_fs_buffer_size = 128 * 1024; + + size_t remote_fs_buffer_size = DBMS_DEFAULT_BUFFER_SIZE; + size_t prefetch_buffer_size = DBMS_DEFAULT_BUFFER_SIZE; + + bool local_fs_prefetch = false; + bool remote_fs_prefetch = false; + + /// For 'read', 'pread' and 'pread_threadpool' methods. + size_t direct_io_threshold = 0; + + /// For 'mmap' method. + size_t mmap_threshold = 0; + MMappedFileCache * mmap_cache = nullptr; + + /// For 'pread_threadpool'/'io_uring' method. Lower value is higher priority. + Priority priority; + + bool load_marks_asynchronously = true; + + size_t remote_fs_read_max_backoff_ms = 10000; + size_t remote_fs_read_backoff_max_tries = 4; + + bool enable_filesystem_read_prefetches_log = false; + + bool enable_filesystem_cache = true; + bool read_from_filesystem_cache_if_exists_otherwise_bypass_cache = false; + bool enable_filesystem_cache_log = false; + /// Don't populate cache when the read is not part of query execution (e.g. background thread). + bool avoid_readthrough_cache_outside_query_context = true; + + size_t filesystem_cache_max_download_size = (128UL * 1024 * 1024 * 1024); + bool skip_download_if_exceeds_query_cache = true; + + size_t remote_read_min_bytes_for_seek = DBMS_DEFAULT_BUFFER_SIZE; + + FileCachePtr remote_fs_cache; + + /// Bandwidth throttler to use during reading + ThrottlerPtr remote_throttler; + ThrottlerPtr local_throttler; + + // Resource to be used during reading + ResourceLink resource_link; + + size_t http_max_tries = 1; + size_t http_retry_initial_backoff_ms = 100; + size_t http_retry_max_backoff_ms = 1600; + bool http_skip_not_found_url_for_globs = true; + + /// Monitoring + bool for_object_storage = false; // to choose which profile events should be incremented + + ReadSettings adjustBufferSize(size_t file_size) const + { + ReadSettings res = *this; + res.local_fs_buffer_size = std::min(std::max(1ul, file_size), local_fs_buffer_size); + res.remote_fs_buffer_size = std::min(std::max(1ul, file_size), remote_fs_buffer_size); + res.prefetch_buffer_size = std::min(std::max(1ul, file_size), prefetch_buffer_size); + return res; + } +}; + +} diff --git a/contrib/clickhouse/src/IO/ReadWriteBufferFromHTTP.cpp b/contrib/clickhouse/src/IO/ReadWriteBufferFromHTTP.cpp new file mode 100644 index 0000000000..7e5c0d37c8 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadWriteBufferFromHTTP.cpp @@ -0,0 +1,940 @@ +#include "ReadWriteBufferFromHTTP.h" + +#include <IO/HTTPCommon.h> + +namespace ProfileEvents +{ +extern const Event ReadBufferSeekCancelConnection; +extern const Event ReadWriteBufferFromHTTPPreservedSessions; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int TOO_MANY_REDIRECTS; + extern const int HTTP_RANGE_NOT_SATISFIABLE; + extern const int BAD_ARGUMENTS; + extern const int 
CANNOT_SEEK_THROUGH_FILE; + extern const int SEEK_POSITION_OUT_OF_BOUND; + extern const int UNKNOWN_FILE_SIZE; +} + +template <typename TSessionFactory> +UpdatableSession<TSessionFactory>::UpdatableSession(const Poco::URI & uri, UInt64 max_redirects_, std::shared_ptr<TSessionFactory> session_factory_) + : max_redirects{max_redirects_} + , initial_uri(uri) + , session_factory(std::move(session_factory_)) +{ + session = session_factory->buildNewSession(uri); +} + +template <typename TSessionFactory> +typename UpdatableSession<TSessionFactory>::SessionPtr UpdatableSession<TSessionFactory>::getSession() { return session; } + +template <typename TSessionFactory> +void UpdatableSession<TSessionFactory>::updateSession(const Poco::URI & uri) +{ + ++redirects; + if (redirects <= max_redirects) + session = session_factory->buildNewSession(uri); + else + throw Exception(ErrorCodes::TOO_MANY_REDIRECTS, + "Too many redirects while trying to access {}." + " You can {} redirects by changing the setting 'max_http_get_redirects'." + " Example: `SET max_http_get_redirects = 10`." + " Redirects are restricted to prevent possible attack when a malicious server redirects to an internal resource, bypassing the authentication or firewall.", + initial_uri.toString(), max_redirects ? "increase the allowed maximum number of" : "allow"); +} + +template <typename TSessionFactory> +typename UpdatableSession<TSessionFactory>::SessionPtr UpdatableSession<TSessionFactory>::createDetachedSession(const Poco::URI & uri) +{ + return session_factory->buildNewSession(uri); +} + +template <typename TSessionFactory> +std::shared_ptr<UpdatableSession<TSessionFactory>> UpdatableSession<TSessionFactory>::clone(const Poco::URI & uri) +{ + return std::make_shared<UpdatableSession<TSessionFactory>>(uri, max_redirects, session_factory); +} + + +namespace detail +{ + +static bool isRetriableError(const Poco::Net::HTTPResponse::HTTPStatus http_status) noexcept +{ + static constexpr std::array non_retriable_errors{ + Poco::Net::HTTPResponse::HTTPStatus::HTTP_BAD_REQUEST, + Poco::Net::HTTPResponse::HTTPStatus::HTTP_UNAUTHORIZED, + Poco::Net::HTTPResponse::HTTPStatus::HTTP_NOT_FOUND, + Poco::Net::HTTPResponse::HTTPStatus::HTTP_FORBIDDEN, + Poco::Net::HTTPResponse::HTTPStatus::HTTP_NOT_IMPLEMENTED, + Poco::Net::HTTPResponse::HTTPStatus::HTTP_METHOD_NOT_ALLOWED}; + + return std::all_of( + non_retriable_errors.begin(), non_retriable_errors.end(), [&](const auto status) { return http_status != status; }); +} + +static Poco::URI getUriAfterRedirect(const Poco::URI & prev_uri, Poco::Net::HTTPResponse & response) +{ + auto location = response.get("Location"); + auto location_uri = Poco::URI(location); + if (!location_uri.isRelative()) + return location_uri; + /// Location header contains relative path. So we need to concatenate it + /// with path from the original URI and normalize it. + auto path = std::filesystem::weakly_canonical(std::filesystem::path(prev_uri.getPath()) / location); + location_uri = prev_uri; + location_uri.setPath(path); + return location_uri; +} + +template <typename UpdatableSessionPtr> +bool ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::withPartialContent(const HTTPRange & range) const +{ + /** + * Add range header if we have some passed range + * or if we want to retry GET request on purpose. 
+ */ + return range.begin || range.end || retry_with_range_header; +} + +template <typename UpdatableSessionPtr> +size_t ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::getOffset() const { return read_range.begin.value_or(0) + offset_from_begin_pos; } + +template <typename UpdatableSessionPtr> +void ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::prepareRequest(Poco::Net::HTTPRequest & request, Poco::URI uri_, std::optional<HTTPRange> range) const +{ + request.setHost(uri_.getHost()); // use original, not resolved host name in header + + if (out_stream_callback) + request.setChunkedTransferEncoding(true); + else if (method == Poco::Net::HTTPRequest::HTTP_POST) + request.setContentLength(0); /// No callback - no body + + for (const auto & [header, value] : http_header_entries) + request.set(header, value); + + if (range) + { + String range_header_value; + if (range->end) + range_header_value = fmt::format("bytes={}-{}", *range->begin, *range->end); + else + range_header_value = fmt::format("bytes={}-", *range->begin); + LOG_TEST(log, "Adding header: Range: {}", range_header_value); + request.set("Range", range_header_value); + } + + if (!credentials.getUsername().empty()) + credentials.authenticate(request); +} + +template <typename UpdatableSessionPtr> +std::istream * ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::callImpl( + UpdatableSessionPtr & current_session, Poco::URI uri_, Poco::Net::HTTPResponse & response, const std::string & method_, bool for_object_info) +{ + // With empty path poco will send "POST HTTP/1.1" its bug. + if (uri_.getPath().empty()) + uri_.setPath("/"); + + std::optional<HTTPRange> range; + if (!for_object_info) + { + if (withPartialContent(read_range)) + range = HTTPRange{getOffset(), read_range.end}; + } + + Poco::Net::HTTPRequest request(method_, uri_.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1); + prepareRequest(request, uri_, range); + + LOG_TRACE(log, "Sending request to {}", uri_.toString()); + + auto sess = current_session->getSession(); + auto & stream_out = sess->sendRequest(request); + + if (out_stream_callback) + out_stream_callback(stream_out); + + auto result_istr = receiveResponse(*sess, request, response, true); + response.getCookies(cookies); + + /// we can fetch object info while the request is being processed + /// and we don't want to override any context used by it + if (!for_object_info) + content_encoding = response.get("Content-Encoding", ""); + + return result_istr; +} + +template <typename UpdatableSessionPtr> +size_t ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::getFileSize() +{ + if (!file_info) + file_info = getFileInfo(); + + if (file_info->file_size) + return *file_info->file_size; + + throw Exception(ErrorCodes::UNKNOWN_FILE_SIZE, "Cannot find out file size for: {}", uri.toString()); +} + +template <typename UpdatableSessionPtr> +bool ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::supportsReadAt() +{ + if (!file_info) + file_info = getFileInfo(); + return method == Poco::Net::HTTPRequest::HTTP_GET && file_info->seekable; +} + +template <typename UpdatableSessionPtr> +bool ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::checkIfActuallySeekable() +{ + if (!file_info) + file_info = getFileInfo(); + return file_info->seekable; +} + +template <typename UpdatableSessionPtr> +String ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::getFileName() const { return uri.toString(); } + +template <typename UpdatableSessionPtr> +void ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::getHeadResponse(Poco::Net::HTTPResponse & 
response) +{ + for (size_t i = 0; i < settings.http_max_tries; ++i) + { + try + { + callWithRedirects(response, Poco::Net::HTTPRequest::HTTP_HEAD, true, true); + break; + } + catch (const Poco::Exception & e) + { + if (i == settings.http_max_tries - 1 || !isRetriableError(response.getStatus())) + throw; + + LOG_ERROR(log, "Failed to make HTTP_HEAD request to {}. Error: {}", uri.toString(), e.displayText()); + } + } +} + +template <typename UpdatableSessionPtr> +void ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::setupExternalBuffer() +{ + /** + * use_external_buffer -- means we read into the buffer which + * was passed to us from somewhere else. We do not check whether + * previously returned buffer was read or not (no hasPendingData() check is needed), + * because this branch means we are prefetching data, + * each nextImpl() call we can fill a different buffer. + */ + impl->set(internal_buffer.begin(), internal_buffer.size()); + assert(working_buffer.begin() != nullptr); + assert(!internal_buffer.empty()); +} + +template <typename UpdatableSessionPtr> +ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::ReadWriteBufferFromHTTPBase( + UpdatableSessionPtr session_, + Poco::URI uri_, + const Poco::Net::HTTPBasicCredentials & credentials_, + const std::string & method_, + OutStreamCallback out_stream_callback_, + size_t buffer_size_, + const ReadSettings & settings_, + HTTPHeaderEntries http_header_entries_, + const RemoteHostFilter * remote_host_filter_, + bool delay_initialization, + bool use_external_buffer_, + bool http_skip_not_found_url_, + std::optional<HTTPFileInfo> file_info_, + Poco::Net::HTTPClientSession::ProxyConfig proxy_config_) + : SeekableReadBuffer(nullptr, 0) + , uri {uri_} + , method {!method_.empty() ? method_ : out_stream_callback_ ? Poco::Net::HTTPRequest::HTTP_POST : Poco::Net::HTTPRequest::HTTP_GET} + , session {session_} + , out_stream_callback {out_stream_callback_} + , credentials {credentials_} + , http_header_entries {std::move(http_header_entries_)} + , remote_host_filter {remote_host_filter_} + , buffer_size {buffer_size_} + , use_external_buffer {use_external_buffer_} + , file_info(file_info_) + , http_skip_not_found_url(http_skip_not_found_url_) + , settings {settings_} + , log(&Poco::Logger::get("ReadWriteBufferFromHTTP")) + , proxy_config(proxy_config_) +{ + if (settings.http_max_tries <= 0 || settings.http_retry_initial_backoff_ms <= 0 + || settings.http_retry_initial_backoff_ms >= settings.http_retry_max_backoff_ms) + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Invalid setting for http backoff, " + "must be http_max_tries >= 1 (current is {}) and " + "0 < http_retry_initial_backoff_ms < settings.http_retry_max_backoff_ms (now 0 < {} < {})", + settings.http_max_tries, + settings.http_retry_initial_backoff_ms, + settings.http_retry_max_backoff_ms); + + // Configure User-Agent if it not already set. 
+ const std::string user_agent = "User-Agent"; + auto iter = std::find_if( + http_header_entries.begin(), + http_header_entries.end(), + [&user_agent](const HTTPHeaderEntry & entry) { return entry.name == user_agent; }); + + if (iter == http_header_entries.end()) + { + http_header_entries.emplace_back("User-Agent", fmt::format("ClickHouse/{}", VERSION_STRING)); + } + + if (!delay_initialization) + { + initialize(); + if (exception) + std::rethrow_exception(exception); + } +} + +template <typename UpdatableSessionPtr> +void ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::callWithRedirects(Poco::Net::HTTPResponse & response, const String & method_, bool throw_on_all_errors, bool for_object_info) +{ + UpdatableSessionPtr current_session = nullptr; + + /// we can fetch object info while the request is being processed + /// and we don't want to override any context used by it + if (for_object_info) + current_session = session->clone(uri); + else + current_session = session; + + call(current_session, response, method_, throw_on_all_errors, for_object_info); + saved_uri_redirect = uri; + + while (isRedirect(response.getStatus())) + { + Poco::URI uri_redirect = getUriAfterRedirect(*saved_uri_redirect, response); + saved_uri_redirect = uri_redirect; + if (remote_host_filter) + remote_host_filter->checkURL(uri_redirect); + + current_session->updateSession(uri_redirect); + + /// we can fetch object info while the request is being processed + /// and we don't want to override any context used by it + auto result_istr = callImpl(current_session, uri_redirect, response, method, for_object_info); + if (!for_object_info) + istr = result_istr; + } +} + +template <typename UpdatableSessionPtr> +void ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::call(UpdatableSessionPtr & current_session, Poco::Net::HTTPResponse & response, const String & method_, bool throw_on_all_errors, bool for_object_info) +{ + try + { + /// we can fetch object info while the request is being processed + /// and we don't want to override any context used by it + auto result_istr = callImpl(current_session, saved_uri_redirect ? *saved_uri_redirect : uri, response, method_, for_object_info); + if (!for_object_info) + istr = result_istr; + } + catch (...) 
+ { + /// we can fetch object info while the request is being processed + /// and we don't want to override any context used by it + if (for_object_info) + throw; + + if (throw_on_all_errors) + throw; + + auto http_status = response.getStatus(); + + if (http_status == Poco::Net::HTTPResponse::HTTPStatus::HTTP_NOT_FOUND && http_skip_not_found_url) + { + initialization_error = InitializeError::SKIP_NOT_FOUND_URL; + } + else if (!isRetriableError(http_status)) + { + initialization_error = InitializeError::NON_RETRYABLE_ERROR; + exception = std::current_exception(); + } + else + { + throw; + } + } +} + +template <typename UpdatableSessionPtr> +void ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::initialize() +{ + Poco::Net::HTTPResponse response; + + call(session, response, method); + if (initialization_error != InitializeError::NONE) + return; + + while (isRedirect(response.getStatus())) + { + Poco::URI uri_redirect = getUriAfterRedirect(saved_uri_redirect.value_or(uri), response); + if (remote_host_filter) + remote_host_filter->checkURL(uri_redirect); + + session->updateSession(uri_redirect); + + istr = callImpl(session, uri_redirect, response, method); + saved_uri_redirect = uri_redirect; + } + + if (response.hasContentLength()) + LOG_DEBUG(log, "Received response with content length: {}", response.getContentLength()); + + if (withPartialContent(read_range) && response.getStatus() != Poco::Net::HTTPResponse::HTTPStatus::HTTP_PARTIAL_CONTENT) + { + /// Having `200 OK` instead of `206 Partial Content` is acceptable in case we retried with range.begin == 0. + if (getOffset() != 0) + { + if (!exception) + { + exception = std::make_exception_ptr(Exception( + ErrorCodes::HTTP_RANGE_NOT_SATISFIABLE, + "Cannot read with range: [{}, {}] (response status: {}, reason: {})", + *read_range.begin, + read_range.end ? toString(*read_range.end) : "-", + toString(response.getStatus()), response.getReason())); + } + + /// Retry 200OK + if (response.getStatus() == Poco::Net::HTTPResponse::HTTPStatus::HTTP_OK) + initialization_error = InitializeError::RETRYABLE_ERROR; + else + initialization_error = InitializeError::NON_RETRYABLE_ERROR; + + return; + } + else if (read_range.end) + { + /// We could have range.begin == 0 and range.end != 0 in case of DiskWeb and failing to read with partial content + /// will affect only performance, so a warning is enough. + LOG_WARNING(log, "Unable to read with range header: [{}, {}]", read_range.begin.value_or(0), *read_range.end); + } + } + + // Remember file size. It'll be used to report eof in next nextImpl() call. + if (!read_range.end && response.hasContentLength()) + file_info = parseFileInfo(response, withPartialContent(read_range) ? getOffset() : 0); + + impl = std::make_unique<ReadBufferFromIStream>(*istr, buffer_size); + + if (use_external_buffer) + setupExternalBuffer(); +} + +template <typename UpdatableSessionPtr> +bool ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::nextImpl() +{ + if (initialization_error == InitializeError::SKIP_NOT_FOUND_URL) + return false; + assert(initialization_error == InitializeError::NONE); + + if (next_callback) + next_callback(count()); + + if ((read_range.end && getOffset() > read_range.end.value()) || + (file_info && file_info->file_size && getOffset() >= file_info->file_size.value())) + { + /// Response was fully read. 
+ markSessionForReuse(session->getSession()); + ProfileEvents::increment(ProfileEvents::ReadWriteBufferFromHTTPPreservedSessions); + return false; + } + + if (impl) + { + if (use_external_buffer) + { + setupExternalBuffer(); + } + else + { + /** + * impl was initialized before, pass position() to it to make + * sure there is no pending data which was not read. + */ + if (!working_buffer.empty()) + impl->position() = position(); + } + } + + bool result = false; + size_t milliseconds_to_wait = settings.http_retry_initial_backoff_ms; + bool last_attempt = false; + + auto on_retriable_error = [&]() + { + retry_with_range_header = true; + impl.reset(); + auto http_session = session->getSession(); + http_session->reset(); + if (!last_attempt) + { + sleepForMilliseconds(milliseconds_to_wait); + milliseconds_to_wait = std::min(milliseconds_to_wait * 2, settings.http_retry_max_backoff_ms); + } + }; + + for (size_t i = 0;; ++i) + { + if (last_attempt) + break; + last_attempt = i + 1 >= settings.http_max_tries; + + exception = nullptr; + initialization_error = InitializeError::NONE; + + try + { + if (!impl) + { + initialize(); + + if (initialization_error == InitializeError::NON_RETRYABLE_ERROR) + { + assert(exception); + break; + } + else if (initialization_error == InitializeError::SKIP_NOT_FOUND_URL) + { + return false; + } + else if (initialization_error == InitializeError::RETRYABLE_ERROR) + { + LOG_ERROR( + log, + "HTTP request to `{}` failed at try {}/{} with bytes read: {}/{}. " + "(Current backoff wait is {}/{} ms)", + uri.toString(), i + 1, settings.http_max_tries, getOffset(), + read_range.end ? toString(*read_range.end) : "unknown", + milliseconds_to_wait, settings.http_retry_max_backoff_ms); + + assert(exception); + on_retriable_error(); + continue; + } + + assert(!exception); + + if (use_external_buffer) + { + setupExternalBuffer(); + } + } + + result = impl->next(); + exception = nullptr; + break; + } + catch (const Poco::Exception & e) + { + /// Too many open files - non-retryable. + if (e.code() == POCO_EMFILE) + throw; + + /** Retry request unconditionally if nothing has been read yet. + * Otherwise if it is GET method retry with range header. + */ + bool can_retry_request = !offset_from_begin_pos || method == Poco::Net::HTTPRequest::HTTP_GET; + if (!can_retry_request) + throw; + + LOG_ERROR( + log, + "HTTP request to `{}` failed at try {}/{} with bytes read: {}/{}. " + "Error: {}. (Current backoff wait is {}/{} ms)", + uri.toString(), + i + 1, + settings.http_max_tries, + getOffset(), + read_range.end ? toString(*read_range.end) : "unknown", + e.displayText(), + milliseconds_to_wait, + settings.http_retry_max_backoff_ms); + + on_retriable_error(); + exception = std::current_exception(); + } + } + + if (exception) + std::rethrow_exception(exception); + + if (!result) + { + /// Eof is reached, i.e response was fully read. + markSessionForReuse(session->getSession()); + ProfileEvents::increment(ProfileEvents::ReadWriteBufferFromHTTPPreservedSessions); + return false; + } + + internal_buffer = impl->buffer(); + working_buffer = internal_buffer; + offset_from_begin_pos += working_buffer.size(); + return true; +} + +template <typename UpdatableSessionPtr> +size_t ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::readBigAt(char * to, size_t n, size_t offset, const std::function<bool(size_t)> & progress_callback) +{ + /// Caller must have checked supportsReadAt(). + /// This ensures we've sent at least one HTTP request and populated saved_uri_redirect. 
+ chassert(file_info && file_info->seekable); + + if (n == 0) + return 0; + + Poco::URI uri_ = saved_uri_redirect.value_or(uri); + if (uri_.getPath().empty()) + uri_.setPath("/"); + + size_t milliseconds_to_wait = settings.http_retry_initial_backoff_ms; + + for (size_t attempt = 0;; ++attempt) + { + bool last_attempt = attempt + 1 >= settings.http_max_tries; + + Poco::Net::HTTPRequest request(method, uri_.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1); + prepareRequest(request, uri_, HTTPRange { .begin = offset, .end = offset + n - 1}); + + LOG_TRACE(log, "Sending request to {} for range [{}, {})", uri_.toString(), offset, offset + n); + + auto sess = session->createDetachedSession(uri_); + + Poco::Net::HTTPResponse response; + std::istream * result_istr; + + try + { + sess->sendRequest(request); + result_istr = receiveResponse(*sess, request, response, /*allow_redirects*/ false); + + if (response.getStatus() != Poco::Net::HTTPResponse::HTTPStatus::HTTP_PARTIAL_CONTENT && + (offset != 0 || offset + n < *file_info->file_size)) + throw Exception( + ErrorCodes::HTTP_RANGE_NOT_SATISFIABLE, + "Expected 206 Partial Content, got {} when reading {} range [{}, {})", + toString(response.getStatus()), uri_.toString(), offset, offset + n); + + bool cancelled; + size_t r = copyFromIStreamWithProgressCallback(*result_istr, to, n, progress_callback, &cancelled); + + if (!cancelled) + { + /// Response was fully read. + markSessionForReuse(sess); + ProfileEvents::increment(ProfileEvents::ReadWriteBufferFromHTTPPreservedSessions); + } + + return r; + } + catch (const Poco::Exception & e) + { + LOG_ERROR( + log, + "HTTP request (positioned) to `{}` with range [{}, {}) failed at try {}/{}: {}", + uri_.toString(), offset, offset + n, attempt + 1, settings.http_max_tries, + e.what()); + + /// Decide whether to retry. + + if (last_attempt) + throw; + + /// Too many open files - non-retryable. + if (e.code() == POCO_EMFILE) + throw; + + if (const auto * h = dynamic_cast<const HTTPException*>(&e); + h && !isRetriableError(static_cast<Poco::Net::HTTPResponse::HTTPStatus>(h->getHTTPStatus()))) + throw; + + sleepForMilliseconds(milliseconds_to_wait); + milliseconds_to_wait = std::min(milliseconds_to_wait * 2, settings.http_retry_max_backoff_ms); + continue; + } + } +} + +template <typename UpdatableSessionPtr> +off_t ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::getPosition() { return getOffset() - available(); } + +template <typename UpdatableSessionPtr> +off_t ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::seek(off_t offset_, int whence) +{ + if (whence != SEEK_SET) + throw Exception(ErrorCodes::CANNOT_SEEK_THROUGH_FILE, "Only SEEK_SET mode is allowed."); + + if (offset_ < 0) + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Seek position is out of bounds. 
Offset: {}", + offset_); + + off_t current_offset = getOffset(); + if (!working_buffer.empty() && size_t(offset_) >= current_offset - working_buffer.size() && offset_ < current_offset) + { + pos = working_buffer.end() - (current_offset - offset_); + assert(pos >= working_buffer.begin()); + assert(pos < working_buffer.end()); + + return getPosition(); + } + + if (impl) + { + auto position = getPosition(); + if (offset_ > position) + { + size_t diff = offset_ - position; + if (diff < settings.remote_read_min_bytes_for_seek) + { + ignore(diff); + return offset_; + } + } + + if (!atEndOfRequestedRangeGuess()) + ProfileEvents::increment(ProfileEvents::ReadBufferSeekCancelConnection); + impl.reset(); + } + + resetWorkingBuffer(); + read_range.begin = offset_; + offset_from_begin_pos = 0; + + return offset_; +} + +template <typename UpdatableSessionPtr> +void ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::setReadUntilPosition(size_t until) +{ + until = std::max(until, 1ul); + if (read_range.end && *read_range.end + 1 == until) + return; + read_range.end = until - 1; + read_range.begin = getPosition(); + resetWorkingBuffer(); + if (impl) + { + if (!atEndOfRequestedRangeGuess()) + ProfileEvents::increment(ProfileEvents::ReadBufferSeekCancelConnection); + impl.reset(); + } +} + +template <typename UpdatableSessionPtr> +void ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::setReadUntilEnd() +{ + if (!read_range.end) + return; + read_range.end.reset(); + read_range.begin = getPosition(); + resetWorkingBuffer(); + if (impl) + { + if (!atEndOfRequestedRangeGuess()) + ProfileEvents::increment(ProfileEvents::ReadBufferSeekCancelConnection); + impl.reset(); + } +} + +template <typename UpdatableSessionPtr> +bool ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::supportsRightBoundedReads() const { return true; } + +template <typename UpdatableSessionPtr> +bool ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::atEndOfRequestedRangeGuess() +{ + if (!impl) + return true; + if (read_range.end) + return getPosition() > static_cast<off_t>(*read_range.end); + if (file_info && file_info->file_size) + return getPosition() >= static_cast<off_t>(*file_info->file_size); + return false; +} + +template <typename UpdatableSessionPtr> +std::string ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::getResponseCookie(const std::string & name, const std::string & def) const +{ + for (const auto & cookie : cookies) + if (cookie.getName() == name) + return cookie.getValue(); + return def; +} + +template <typename UpdatableSessionPtr> +void ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::setNextCallback(NextCallback next_callback_) +{ + next_callback = next_callback_; + /// Some data maybe already read + next_callback(count()); +} + +template <typename UpdatableSessionPtr> +const std::string & ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::getCompressionMethod() const { return content_encoding; } + +template <typename UpdatableSessionPtr> +std::optional<time_t> ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::tryGetLastModificationTime() +{ + if (!file_info) + { + try + { + file_info = getFileInfo(); + } + catch (...) + { + return std::nullopt; + } + } + + return file_info->last_modified; +} + +template <typename UpdatableSessionPtr> +HTTPFileInfo ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::getFileInfo() +{ + Poco::Net::HTTPResponse response; + try + { + getHeadResponse(response); + } + catch (HTTPException & e) + { + /// Maybe the web server doesn't support HEAD requests. + /// E.g. webhdfs reports status 400. 
+ /// We should proceed in hopes that the actual GET request will succeed. + /// (Unless the error in transient. Don't want to nondeterministically sometimes + /// fall back to slow whole-file reads when HEAD is actually supported; that sounds + /// like a nightmare to debug.) + if (e.getHTTPStatus() >= 400 && e.getHTTPStatus() <= 499 && + e.getHTTPStatus() != Poco::Net::HTTPResponse::HTTP_TOO_MANY_REQUESTS) + return HTTPFileInfo{}; + + throw; + } + return parseFileInfo(response, 0); +} + +template <typename UpdatableSessionPtr> +HTTPFileInfo ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::parseFileInfo(const Poco::Net::HTTPResponse & response, size_t requested_range_begin) +{ + HTTPFileInfo res; + + if (response.hasContentLength()) + { + res.file_size = response.getContentLength(); + + if (response.getStatus() == Poco::Net::HTTPResponse::HTTPStatus::HTTP_PARTIAL_CONTENT) + { + *res.file_size += requested_range_begin; + res.seekable = true; + } + else + { + res.seekable = response.has("Accept-Ranges") && response.get("Accept-Ranges") == "bytes"; + } + } + + if (response.has("Last-Modified")) + { + String date_str = response.get("Last-Modified"); + struct tm info; + char * end = strptime(date_str.data(), "%a, %d %b %Y %H:%M:%S %Z", &info); + if (end == date_str.data() + date_str.size()) + res.last_modified = timegm(&info); + } + + return res; +} + +} + +SessionFactory::SessionFactory(const ConnectionTimeouts & timeouts_, Poco::Net::HTTPClientSession::ProxyConfig proxy_config_) + : timeouts(timeouts_), proxy_config(proxy_config_) {} + +SessionFactory::SessionType SessionFactory::buildNewSession(const Poco::URI & uri) +{ + return makeHTTPSession(uri, timeouts, proxy_config); +} + +ReadWriteBufferFromHTTP::ReadWriteBufferFromHTTP( + Poco::URI uri_, + const std::string & method_, + OutStreamCallback out_stream_callback_, + const ConnectionTimeouts & timeouts, + const Poco::Net::HTTPBasicCredentials & credentials_, + const UInt64 max_redirects, + size_t buffer_size_, + const ReadSettings & settings_, + const HTTPHeaderEntries & http_header_entries_, + const RemoteHostFilter * remote_host_filter_, + bool delay_initialization_, + bool use_external_buffer_, + bool skip_not_found_url_, + std::optional<HTTPFileInfo> file_info_, + Poco::Net::HTTPClientSession::ProxyConfig proxy_config_) + : Parent( + std::make_shared<SessionType>(uri_, max_redirects, std::make_shared<SessionFactory>(timeouts, proxy_config_)), + uri_, + credentials_, + method_, + out_stream_callback_, + buffer_size_, + settings_, + http_header_entries_, + remote_host_filter_, + delay_initialization_, + use_external_buffer_, + skip_not_found_url_, + file_info_, + proxy_config_) {} + + +PooledSessionFactory::PooledSessionFactory( + const ConnectionTimeouts & timeouts_, size_t per_endpoint_pool_size_) + : timeouts(timeouts_) + , per_endpoint_pool_size(per_endpoint_pool_size_) {} + +PooledSessionFactory::SessionType PooledSessionFactory::buildNewSession(const Poco::URI & uri) +{ + return makePooledHTTPSession(uri, timeouts, per_endpoint_pool_size); +} + + +PooledReadWriteBufferFromHTTP::PooledReadWriteBufferFromHTTP( + Poco::URI uri_, + const std::string & method_, + OutStreamCallback out_stream_callback_, + const ConnectionTimeouts & timeouts_, + const Poco::Net::HTTPBasicCredentials & credentials_, + size_t buffer_size_, + const UInt64 max_redirects, + size_t max_connections_per_endpoint) + : Parent( + std::make_shared<SessionType>(uri_, max_redirects, std::make_shared<PooledSessionFactory>(timeouts_, max_connections_per_endpoint)), + 
uri_, + credentials_, + method_, + out_stream_callback_, + buffer_size_) {} + + +template class UpdatableSession<SessionFactory>; +template class UpdatableSession<PooledSessionFactory>; +template class detail::ReadWriteBufferFromHTTPBase<std::shared_ptr<UpdatableSession<SessionFactory>>>; +template class detail::ReadWriteBufferFromHTTPBase<std::shared_ptr<UpdatableSession<PooledSessionFactory>>>; + +} diff --git a/contrib/clickhouse/src/IO/ReadWriteBufferFromHTTP.h b/contrib/clickhouse/src/IO/ReadWriteBufferFromHTTP.h new file mode 100644 index 0000000000..ae02292446 --- /dev/null +++ b/contrib/clickhouse/src/IO/ReadWriteBufferFromHTTP.h @@ -0,0 +1,291 @@ +#pragma once + +#include <functional> +#include <IO/ConnectionTimeouts.h> +#include <IO/HTTPCommon.h> +#include <IO/ParallelReadBuffer.h> +#include <IO/ReadBuffer.h> +#include <IO/ReadBufferFromIStream.h> +#include <IO/ReadHelpers.h> +#include <IO/ReadSettings.h> +#include <IO/WithFileName.h> +#include <IO/HTTPHeaderEntries.h> +#include <Common/logger_useful.h> +#include <base/sleep.h> +#include <base/types.h> +#include <Poco/Any.h> +#include <Poco/Net/HTTPBasicCredentials.h> +#include <Poco/Net/HTTPClientSession.h> +#include <Poco/Net/HTTPRequest.h> +#include <Poco/Net/HTTPResponse.h> +#include <Poco/URI.h> +#include <Poco/URIStreamFactory.h> +#include <Common/DNSResolver.h> +#include <Common/RemoteHostFilter.h> +#include "clickhouse_config.h" +#include "config_version.h" + +#include <filesystem> + +namespace DB +{ + +template <typename TSessionFactory> +class UpdatableSession +{ +public: + using SessionPtr = typename TSessionFactory::SessionType; + + explicit UpdatableSession(const Poco::URI & uri, UInt64 max_redirects_, std::shared_ptr<TSessionFactory> session_factory_); + + SessionPtr getSession(); + + void updateSession(const Poco::URI & uri); + + /// Thread safe. + SessionPtr createDetachedSession(const Poco::URI & uri); + + std::shared_ptr<UpdatableSession<TSessionFactory>> clone(const Poco::URI & uri); + +private: + SessionPtr session; + UInt64 redirects{0}; + UInt64 max_redirects; + Poco::URI initial_uri; + std::shared_ptr<TSessionFactory> session_factory; +}; + + +/// Information from HTTP response header. +struct HTTPFileInfo +{ + // nullopt if the server doesn't report it. + std::optional<size_t> file_size; + std::optional<time_t> last_modified; + bool seekable = false; +}; + + +namespace detail +{ + /// Byte range, including right bound [begin, end]. + struct HTTPRange + { + std::optional<size_t> begin; + std::optional<size_t> end; + }; + + template <typename UpdatableSessionPtr> + class ReadWriteBufferFromHTTPBase : public SeekableReadBuffer, public WithFileName, public WithFileSize + { + protected: + Poco::URI uri; + std::string method; + std::string content_encoding; + + UpdatableSessionPtr session; + std::istream * istr; /// owned by session + std::unique_ptr<ReadBuffer> impl; + std::function<void(std::ostream &)> out_stream_callback; + const Poco::Net::HTTPBasicCredentials & credentials; + std::vector<Poco::Net::HTTPCookie> cookies; + HTTPHeaderEntries http_header_entries; + const RemoteHostFilter * remote_host_filter = nullptr; + std::function<void(size_t)> next_callback; + + size_t buffer_size; + bool use_external_buffer; + + size_t offset_from_begin_pos = 0; + HTTPRange read_range; + std::optional<HTTPFileInfo> file_info; + + /// Delayed exception in case retries with partial content are not satisfiable. 
+ std::exception_ptr exception; + bool retry_with_range_header = false; + /// In case of redirects, save result uri to use it if we retry the request. + std::optional<Poco::URI> saved_uri_redirect; + + bool http_skip_not_found_url; + + ReadSettings settings; + Poco::Logger * log; + + Poco::Net::HTTPClientSession::ProxyConfig proxy_config; + + bool withPartialContent(const HTTPRange & range) const; + + size_t getOffset() const; + + void prepareRequest(Poco::Net::HTTPRequest & request, Poco::URI uri_, std::optional<HTTPRange> range) const; + + std::istream * callImpl(UpdatableSessionPtr & current_session, Poco::URI uri_, Poco::Net::HTTPResponse & response, const std::string & method_, bool for_object_info = false); + + size_t getFileSize() override; + + bool supportsReadAt() override; + + bool checkIfActuallySeekable() override; + + String getFileName() const override; + + enum class InitializeError + { + RETRYABLE_ERROR, + /// If error is not retriable, `exception` variable must be set. + NON_RETRYABLE_ERROR, + /// Allows to skip not found urls for globs + SKIP_NOT_FOUND_URL, + NONE, + }; + + InitializeError initialization_error = InitializeError::NONE; + + private: + void getHeadResponse(Poco::Net::HTTPResponse & response); + + void setupExternalBuffer(); + + public: + using NextCallback = std::function<void(size_t)>; + using OutStreamCallback = std::function<void(std::ostream &)>; + + explicit ReadWriteBufferFromHTTPBase( + UpdatableSessionPtr session_, + Poco::URI uri_, + const Poco::Net::HTTPBasicCredentials & credentials_, + const std::string & method_ = {}, + OutStreamCallback out_stream_callback_ = {}, + size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE, + const ReadSettings & settings_ = {}, + HTTPHeaderEntries http_header_entries_ = {}, + const RemoteHostFilter * remote_host_filter_ = nullptr, + bool delay_initialization = false, + bool use_external_buffer_ = false, + bool http_skip_not_found_url_ = false, + std::optional<HTTPFileInfo> file_info_ = std::nullopt, + Poco::Net::HTTPClientSession::ProxyConfig proxy_config_ = {}); + + void callWithRedirects(Poco::Net::HTTPResponse & response, const String & method_, bool throw_on_all_errors = false, bool for_object_info = false); + + void call(UpdatableSessionPtr & current_session, Poco::Net::HTTPResponse & response, const String & method_, bool throw_on_all_errors = false, bool for_object_info = false); + + /** + * Throws if error is retryable, otherwise sets initialization_error = NON_RETRYABLE_ERROR and + * saves exception into `exception` variable. In case url is not found and skip_not_found_url == true, + * sets initialization_error = SKIP_NOT_FOUND_URL, otherwise throws. + */ + void initialize(); + + bool nextImpl() override; + + size_t readBigAt(char * to, size_t n, size_t offset, const std::function<bool(size_t)> & progress_callback) override; + + off_t getPosition() override; + + off_t seek(off_t offset_, int whence) override; + + void setReadUntilPosition(size_t until) override; + + void setReadUntilEnd() override; + + bool supportsRightBoundedReads() const override; + + // If true, if we destroy impl now, no work was wasted. Just for metrics. + bool atEndOfRequestedRangeGuess(); + + std::string getResponseCookie(const std::string & name, const std::string & def) const; + + /// Set function to call on each nextImpl, useful when you need to track + /// progress. 
+ /// NOTE: parameter on each call is not incremental -- it's all bytes count + /// passed through the buffer + void setNextCallback(NextCallback next_callback_); + + const std::string & getCompressionMethod() const; + + std::optional<time_t> tryGetLastModificationTime(); + + HTTPFileInfo getFileInfo(); + + HTTPFileInfo parseFileInfo(const Poco::Net::HTTPResponse & response, size_t requested_range_begin); + }; +} + +class SessionFactory +{ +public: + explicit SessionFactory(const ConnectionTimeouts & timeouts_, Poco::Net::HTTPClientSession::ProxyConfig proxy_config_ = {}); + + using SessionType = HTTPSessionPtr; + + SessionType buildNewSession(const Poco::URI & uri); +private: + ConnectionTimeouts timeouts; + Poco::Net::HTTPClientSession::ProxyConfig proxy_config; +}; + +class ReadWriteBufferFromHTTP : public detail::ReadWriteBufferFromHTTPBase<std::shared_ptr<UpdatableSession<SessionFactory>>> +{ + using SessionType = UpdatableSession<SessionFactory>; + using Parent = detail::ReadWriteBufferFromHTTPBase<std::shared_ptr<SessionType>>; + +public: + ReadWriteBufferFromHTTP( + Poco::URI uri_, + const std::string & method_, + OutStreamCallback out_stream_callback_, + const ConnectionTimeouts & timeouts, + const Poco::Net::HTTPBasicCredentials & credentials_, + const UInt64 max_redirects = 0, + size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE, + const ReadSettings & settings_ = {}, + const HTTPHeaderEntries & http_header_entries_ = {}, + const RemoteHostFilter * remote_host_filter_ = nullptr, + bool delay_initialization_ = true, + bool use_external_buffer_ = false, + bool skip_not_found_url_ = false, + std::optional<HTTPFileInfo> file_info_ = std::nullopt, + Poco::Net::HTTPClientSession::ProxyConfig proxy_config_ = {}); +}; + +class PooledSessionFactory +{ +public: + explicit PooledSessionFactory( + const ConnectionTimeouts & timeouts_, size_t per_endpoint_pool_size_); + + using SessionType = PooledHTTPSessionPtr; + + /// Thread safe. 
+ SessionType buildNewSession(const Poco::URI & uri); + +private: + ConnectionTimeouts timeouts; + size_t per_endpoint_pool_size; +}; + +class PooledReadWriteBufferFromHTTP : public detail::ReadWriteBufferFromHTTPBase<std::shared_ptr<UpdatableSession<PooledSessionFactory>>> +{ + using SessionType = UpdatableSession<PooledSessionFactory>; + using Parent = detail::ReadWriteBufferFromHTTPBase<std::shared_ptr<SessionType>>; + +public: + explicit PooledReadWriteBufferFromHTTP( + Poco::URI uri_, + const std::string & method_ = {}, + OutStreamCallback out_stream_callback_ = {}, + const ConnectionTimeouts & timeouts_ = {}, + const Poco::Net::HTTPBasicCredentials & credentials_ = {}, + size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE, + const UInt64 max_redirects = 0, + size_t max_connections_per_endpoint = DEFAULT_COUNT_OF_HTTP_CONNECTIONS_PER_ENDPOINT); +}; + + +extern template class UpdatableSession<SessionFactory>; +extern template class UpdatableSession<PooledSessionFactory>; +extern template class detail::ReadWriteBufferFromHTTPBase<std::shared_ptr<UpdatableSession<SessionFactory>>>; +extern template class detail::ReadWriteBufferFromHTTPBase<std::shared_ptr<UpdatableSession<PooledSessionFactory>>>; + +} diff --git a/contrib/clickhouse/src/IO/Resource/ClassifiersConfig.cpp b/contrib/clickhouse/src/IO/Resource/ClassifiersConfig.cpp new file mode 100644 index 0000000000..fcd4655e2e --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/ClassifiersConfig.cpp @@ -0,0 +1,40 @@ +#include <IO/Resource/ClassifiersConfig.h> + +#include <Common/Exception.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int RESOURCE_NOT_FOUND; +} + +ClassifierDescription::ClassifierDescription(const Poco::Util::AbstractConfiguration & config, const String & config_prefix) +{ + Poco::Util::AbstractConfiguration::Keys keys; + config.keys(config_prefix, keys); + for (const auto & key : keys) + emplace(key, config.getString(config_prefix + "." + key)); +} + +ClassifiersConfig::ClassifiersConfig(const Poco::Util::AbstractConfiguration & config) +{ + Poco::Util::AbstractConfiguration::Keys keys; + const String config_prefix = "classifiers"; + config.keys(config_prefix, keys); + for (const auto & key : keys) + classifiers.emplace(std::piecewise_construct, + std::forward_as_tuple(key), + std::forward_as_tuple(config, config_prefix + "." + key)); +} + +const ClassifierDescription & ClassifiersConfig::get(const String & classifier_name) +{ + if (auto it = classifiers.find(classifier_name); it != classifiers.end()) + return it->second; + else + throw Exception(ErrorCodes::RESOURCE_NOT_FOUND, "Unknown classifier '{}' to access resources", classifier_name); +} + +} diff --git a/contrib/clickhouse/src/IO/Resource/ClassifiersConfig.h b/contrib/clickhouse/src/IO/Resource/ClassifiersConfig.h new file mode 100644 index 0000000000..96e2bd0f0b --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/ClassifiersConfig.h @@ -0,0 +1,39 @@ +#pragma once + +#include <base/types.h> +#include <Poco/Util/AbstractConfiguration.h> +#include <unordered_map> + +namespace DB +{ + +/// Mapping of resource name into path string (e.g. "disk1" -> "/path/to/class") +struct ClassifierDescription : std::unordered_map<String, String> +{ + ClassifierDescription(const Poco::Util::AbstractConfiguration & config, const String & config_prefix); +}; + +/* + * Loads a config with the following format: + * <classifiers> + * <classifier1> + * <resource1>/path/to/queue</resource1> + * <resource2>/path/to/another/queue</resource2> + * </classifier1> + * ... 
+ * <classifierN>...</classifierN> + * </classifiers> + */ +class ClassifiersConfig +{ +public: + ClassifiersConfig() = default; + explicit ClassifiersConfig(const Poco::Util::AbstractConfiguration & config); + + const ClassifierDescription & get(const String & classifier_name); + +private: + std::unordered_map<String, ClassifierDescription> classifiers; // by classifier_name +}; + +} diff --git a/contrib/clickhouse/src/IO/Resource/DynamicResourceManager.cpp b/contrib/clickhouse/src/IO/Resource/DynamicResourceManager.cpp new file mode 100644 index 0000000000..df0de6575f --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/DynamicResourceManager.cpp @@ -0,0 +1,234 @@ +#include <IO/Resource/DynamicResourceManager.h> + +#include <IO/SchedulerNodeFactory.h> +#include <IO/ResourceManagerFactory.h> +#include <IO/ISchedulerQueue.h> + +#include <Common/Exception.h> +#include <Common/StringUtils/StringUtils.h> + +#include <map> +#include <tuple> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int RESOURCE_ACCESS_DENIED; + extern const int RESOURCE_NOT_FOUND; + extern const int INVALID_SCHEDULER_NODE; +} + +DynamicResourceManager::State::State(EventQueue * event_queue, const Poco::Util::AbstractConfiguration & config) + : classifiers(config) +{ + Poco::Util::AbstractConfiguration::Keys keys; + const String config_prefix = "resources"; + config.keys(config_prefix, keys); + + // Create resource for every element under <resources> tag + for (const auto & key : keys) + { + resources.emplace(key, std::make_shared<Resource>(key, event_queue, config, config_prefix + "." + key)); + } +} + +DynamicResourceManager::State::Resource::Resource( + const String & name, + EventQueue * event_queue, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix) +{ + Poco::Util::AbstractConfiguration::Keys keys; + config.keys(config_prefix, keys); + + // Sort nodes by path to create parents before children + std::map<String, String> path2key; + for (const auto & key : keys) + { + if (!startsWith(key, "node")) + continue; + String path = config.getString(config_prefix + "." + key + "[@path]", ""); + if (path.empty()) + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Attribute 'path' must be specified in all nodes for resource '{}'", name); + if (path[0] != '/') + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Path must start with '/' for resource '{}'", name); + if (auto [_, inserted] = path2key.emplace(path, key); !inserted) + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Duplicate path '{}' for resource '{}'", path, name); + } + + // Create nodes + bool has_root = false; + for (auto [path, key] : path2key) + { + // Validate path + size_t slash = path.rfind('/'); + if (slash == String::npos) + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Invalid scheduler node path '{}' for resource '{}'", path, name); + + // Create node + String basename = path.substr(slash + 1); // root name is empty string + auto [iter, _] = nodes.emplace(path, Node(basename, event_queue, config, config_prefix + "." 
+ key)); + if (path == "/") + { + has_root = true; + continue; + } + + // Attach created node to parent (if not root) + // NOTE: resource root is attached to the scheduler using event queue for thread-safety + String parent_path = path.substr(0, slash); + if (parent_path.empty()) + parent_path = "/"; + if (auto parent = nodes.find(parent_path); parent != nodes.end()) + parent->second.ptr->attachChild(iter->second.ptr); + else + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Parent node doesn't exist for path '{}' for resource '{}'", path, name); + } + + if (!has_root) + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "undefined root node path '/' for resource '{}'", name); +} + +DynamicResourceManager::State::Resource::~Resource() +{ + // NOTE: we should rely on `attached_to` and cannot use `parent`, + // NOTE: because `parent` can be `nullptr` in case attachment is still in event queue + if (attached_to != nullptr) + { + ISchedulerNode * root = nodes.find("/")->second.ptr.get(); + attached_to->event_queue->enqueue([my_scheduler = attached_to, root] + { + my_scheduler->removeChild(root); + }); + } +} + +DynamicResourceManager::State::Node::Node(const String & name, EventQueue * event_queue, const Poco::Util::AbstractConfiguration & config, const std::string & config_prefix) + : type(config.getString(config_prefix + ".type", "fifo")) + , ptr(SchedulerNodeFactory::instance().get(type, event_queue, config, config_prefix)) +{ + ptr->basename = name; +} + +bool DynamicResourceManager::State::Resource::equals(const DynamicResourceManager::State::Resource & o) const +{ + if (nodes.size() != o.nodes.size()) + return false; + + for (const auto & [path, o_node] : o.nodes) + { + auto iter = nodes.find(path); + if (iter == nodes.end()) + return false; + if (!iter->second.equals(o_node)) + return false; + } + + return true; +} + +bool DynamicResourceManager::State::Node::equals(const DynamicResourceManager::State::Node & o) const +{ + if (type != o.type) + return false; + return ptr->equals(o.ptr.get()); +} + +DynamicResourceManager::Classifier::Classifier(const DynamicResourceManager::StatePtr & state_, const String & classifier_name) + : state(state_) +{ + // State is immutable, but nodes are mutable and thread-safe + // So it's safe to obtain node pointers w/o lock + for (auto [resource_name, path] : state->classifiers.get(classifier_name)) + { + if (auto resource_iter = state->resources.find(resource_name); resource_iter != state->resources.end()) + { + const auto & resource = resource_iter->second; + if (auto node_iter = resource->nodes.find(path); node_iter != resource->nodes.end()) + { + if (auto * queue = dynamic_cast<ISchedulerQueue *>(node_iter->second.ptr.get())) + resources.emplace(resource_name, ResourceLink{.queue = queue}); + else + throw Exception(ErrorCodes::RESOURCE_NOT_FOUND, "Unable to access non-queue node at path '{}' for resource '{}'", path, resource_name); + } + else + throw Exception(ErrorCodes::RESOURCE_NOT_FOUND, "Path '{}' for resource '{}' does not exist", path, resource_name); + } + else + resources.emplace(resource_name, ResourceLink{}); // resource not configured yet - use unlimited resource + } +} + +ResourceLink DynamicResourceManager::Classifier::get(const String & resource_name) +{ + if (auto iter = resources.find(resource_name); iter != resources.end()) + return iter->second; + else + throw Exception(ErrorCodes::RESOURCE_ACCESS_DENIED, "Access denied to resource '{}'", resource_name); +} + +DynamicResourceManager::DynamicResourceManager() + : state(new 
State())
+{
+    scheduler.start();
+}
+
+void DynamicResourceManager::updateConfiguration(const Poco::Util::AbstractConfiguration & config)
+{
+    StatePtr new_state = std::make_shared<State>(scheduler.event_queue, config);
+
+    std::lock_guard lock{mutex};
+
+    // Resource update leads to loss of runtime data of nodes and may lead to temporary violation of constraints (e.g. limits)
+    // Try to minimise this by reusing "equal" resources (initialized with the same configuration).
+    for (auto & [name, new_resource] : new_state->resources)
+    {
+        if (auto iter = state->resources.find(name); iter != state->resources.end()) // Resource update
+        {
+            State::ResourcePtr old_resource = iter->second;
+            if (old_resource->equals(*new_resource))
+                new_resource = old_resource; // Reuse the older version to avoid loss of runtime data
+        }
+    }
+
+    // Commit new state
+    // NOTE: the dtor will detach from the scheduler any old resources that are no longer in use
+    state = new_state;
+
+    // Attach new and updated resources to the scheduler
+    for (auto & [name, resource] : new_state->resources)
+    {
+        const SchedulerNodePtr & root = resource->nodes.find("/")->second.ptr;
+        if (root->parent == nullptr)
+        {
+            resource->attached_to = &scheduler;
+            scheduler.event_queue->enqueue([this, root]
+            {
+                scheduler.attachChild(root);
+            });
+        }
+    }
+
+    // NOTE: after the mutex is unlocked, `state` becomes available to Classifier(s) and must be immutable
+}
+
+ClassifierPtr DynamicResourceManager::acquire(const String & classifier_name)
+{
+    // Acquire a reference to the current state
+    StatePtr state_;
+    {
+        std::lock_guard lock{mutex};
+        state_ = state;
+    }
+
+    return std::make_shared<Classifier>(state_, classifier_name);
+}
+
+void registerDynamicResourceManager(ResourceManagerFactory & factory)
+{
+    factory.registerMethod<DynamicResourceManager>("dynamic");
+}
+
+}
diff --git a/contrib/clickhouse/src/IO/Resource/DynamicResourceManager.h b/contrib/clickhouse/src/IO/Resource/DynamicResourceManager.h
new file mode 100644
index 0000000000..aa1147f1fb
--- /dev/null
+++ b/contrib/clickhouse/src/IO/Resource/DynamicResourceManager.h
@@ -0,0 +1,93 @@
+#pragma once
+
+#include <IO/IResourceManager.h>
+#include <IO/SchedulerRoot.h>
+#include <IO/Resource/ClassifiersConfig.h>
+
+#include <mutex>
+
+namespace DB
+{
+
+/*
+ * Implementation of `IResourceManager` supporting an arbitrary dynamic hierarchy of scheduler nodes.
+ * All resources are controlled by a single root `SchedulerRoot`.
+ *
+ * The state of the manager is the set of resources attached to the scheduler. States are referenced by classifiers.
+ * Classifiers are used (1) to access resources and (2) to keep shared ownership of resources with pending
+ * resource requests. This allows `ResourceRequest` and `ResourceLink` to hold raw pointers as long as
+ * `ClassifierPtr` is acquired and held.
+ *
+ * The manager can update its configuration after initialization. During an update, new versions of resources are
+ * also attached to the scheduler, so multiple versions can coexist for a short period. This temporarily violates
+ * constraints (e.g. the in-flight limit), because different versions have independent nodes to impose constraints;
+ * the same applies to fairness. An old version exists as long as at least one classifier instance references it.
+ * Classifiers are typically attached to queries and will be destructed with them.
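+ *
+ * For illustration, a hypothetical minimal `<resources>` section might look like the following
+ * (the resource name `network_read` and the node paths are arbitrary examples; the `path` attribute,
+ * the `type` element defaulting to "fifo", and the `max_requests` limit match what the parsing code
+ * in DynamicResourceManager.cpp and SemaphoreConstraint.h reads):
+ *
+ *     <resources>
+ *         <network_read>
+ *             <node path="/"><type>inflight_limit</type><max_requests>100</max_requests></node>
+ *             <node path="/fair"><type>fair</type></node>
+ *             <node path="/fair/prod"><type>fifo</type></node>
+ *             <node path="/fair/dev"><type>fifo</type></node>
+ *         </network_read>
+ *     </resources>
+ *
+ * A matching `<classifiers>` section (see ClassifiersConfig.h) would then map the resource name to a
+ * leaf queue path such as "/fair/prod".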
+ */ +class DynamicResourceManager : public IResourceManager +{ +public: + DynamicResourceManager(); + void updateConfiguration(const Poco::Util::AbstractConfiguration & config) override; + ClassifierPtr acquire(const String & classifier_name) override; + +private: + /// Holds everything required to work with one specific configuration + struct State + { + struct Node + { + String type; + SchedulerNodePtr ptr; + + Node( + const String & name, + EventQueue * event_queue, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix); + bool equals(const Node & o) const; + }; + + struct Resource + { + std::unordered_map<String, Node> nodes; // by path + SchedulerRoot * attached_to = nullptr; + + Resource( + const String & name, + EventQueue * event_queue, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix); + ~Resource(); // unregisters resource from scheduler + bool equals(const Resource & o) const; + }; + + using ResourcePtr = std::shared_ptr<Resource>; + + std::unordered_map<String, ResourcePtr> resources; // by name + ClassifiersConfig classifiers; + + State() = default; + explicit State(EventQueue * event_queue, const Poco::Util::AbstractConfiguration & config); + }; + + using StatePtr = std::shared_ptr<State>; + + /// Created per query, holds State used by that query + class Classifier : public IClassifier + { + public: + Classifier(const StatePtr & state_, const String & classifier_name); + ResourceLink get(const String & resource_name) override; + private: + std::unordered_map<String, ResourceLink> resources; // accessible resources by names + StatePtr state; // hold state to avoid ResourceLink invalidation due to resource deregistration from SchedulerRoot + }; + +private: + SchedulerRoot scheduler; + std::mutex mutex; + StatePtr state; +}; + +} diff --git a/contrib/clickhouse/src/IO/Resource/FairPolicy.cpp b/contrib/clickhouse/src/IO/Resource/FairPolicy.cpp new file mode 100644 index 0000000000..248ff04cbd --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/FairPolicy.cpp @@ -0,0 +1,13 @@ +#include <IO/Resource/FairPolicy.h> + +#include <IO/SchedulerNodeFactory.h> + +namespace DB +{ + +void registerFairPolicy(SchedulerNodeFactory & factory) +{ + factory.registerMethod<FairPolicy>("fair"); +} + +} diff --git a/contrib/clickhouse/src/IO/Resource/FairPolicy.h b/contrib/clickhouse/src/IO/Resource/FairPolicy.h new file mode 100644 index 0000000000..9c0c78f057 --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/FairPolicy.h @@ -0,0 +1,232 @@ +#pragma once + +#include <IO/ISchedulerQueue.h> +#include <IO/SchedulerRoot.h> + +#include <Common/Stopwatch.h> + +#include <algorithm> +#include <unordered_map> +#include <vector> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_SCHEDULER_NODE; +} + +/* + * Scheduler node that implements weight-based fair scheduling policy. + * Based on Start-time Fair Queueing (SFQ) algorithm. + * + * Algorithm description. + * Virtual runtime (total consumed cost divided by child weight) is tracked for every child. + * Active child with minimum vruntime is selected to be dequeued next. On activation, initial vruntime + * of a child is set to vruntime of "start" of the last request. This guarantees immediate processing + * of at least single request of newly activated children and thus best isolation and scheduling latency. 
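+ *
+ * A small worked example of the arithmetic (hypothetical numbers, for illustration): suppose two
+ * active children A and B have weights 1 and 3 and every request has cost 1. Dequeuing a request
+ * advances its owner's vruntime by cost / weight, i.e. by 1.0 for A and by ~0.33 for B, and the
+ * child with the minimum vruntime is dequeued next. As a result B is served roughly three times as
+ * often as A, which is exactly the 1:3 weighted share the policy is meant to provide.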
+ */ +class FairPolicy : public ISchedulerNode +{ + /// Scheduling state of a child + struct Item + { + ISchedulerNode * child = nullptr; + double vruntime = 0; /// total consumed cost divided by child weight + + /// For min-heap by vruntime + bool operator<(const Item & rhs) const noexcept + { + return vruntime > rhs.vruntime; + } + }; + +public: + explicit FairPolicy(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {}) + : ISchedulerNode(event_queue_, config, config_prefix) + {} + + bool equals(ISchedulerNode * other) override + { + if (auto * o = dynamic_cast<FairPolicy *>(other)) + return true; + return false; + } + + void attachChild(const SchedulerNodePtr & child) override + { + // Take ownership + if (auto [it, inserted] = children.emplace(child->basename, child); !inserted) + throw Exception( + ErrorCodes::INVALID_SCHEDULER_NODE, + "Can't add another child with the same path: {}", + it->second->getPath()); + + // Attach + child->setParent(this); + + // At first attach as inactive child. + // Inactive attached child must have `info.parent.idx` equal it's index inside `items` array. + // This is needed to avoid later scanning through inactive `items` in O(N). Important optimization. + // NOTE: vruntime must be equal to `system_vruntime` for fairness. + child->info.parent.idx = items.size(); + items.emplace_back(Item{child.get(), system_vruntime}); + + // Activate child if it is not empty + if (child->isActive()) + activateChildImpl(items.size() - 1); + } + + void removeChild(ISchedulerNode * child) override + { + if (auto iter = children.find(child->basename); iter != children.end()) + { + SchedulerNodePtr removed = iter->second; + + // Deactivate: detach is not very common operation, so we can afford O(N) here + size_t child_idx = 0; + [[ maybe_unused ]] bool found = false; + for (; child_idx != items.size(); child_idx++) + { + if (items[child_idx].child == removed.get()) + { + found = true; + break; + } + } + assert(found); + if (child_idx < heap_size) // Detach of active child requires deactivation at first + { + heap_size--; + std::swap(items[child_idx], items[heap_size]); + // Element was removed from inside of heap -- heap must be rebuilt + std::make_heap(items.begin(), items.begin() + heap_size); + child_idx = heap_size; + } + + // Now detach inactive child + if (child_idx != items.size() - 1) + { + std::swap(items[child_idx], items.back()); + items[child_idx].child->info.parent.idx = child_idx; + } + items.pop_back(); + + // Detach + removed->setParent(nullptr); + + // Get rid of ownership + children.erase(iter); + } + } + + ISchedulerNode * getChild(const String & child_name) override + { + if (auto iter = children.find(child_name); iter != children.end()) + return iter->second.get(); + else + return nullptr; + } + + std::pair<ResourceRequest *, bool> dequeueRequest() override + { + if (heap_size == 0) + return {nullptr, false}; + + // Recursively pull request from child + auto [request, child_active] = items.front().child->dequeueRequest(); + assert(request != nullptr); + std::pop_heap(items.begin(), items.begin() + heap_size); + Item & current = items[heap_size - 1]; + + // SFQ fairness invariant: system vruntime equals last served request start-time + assert(current.vruntime >= system_vruntime); + system_vruntime = current.vruntime; + + // By definition vruntime is amount of consumed resource (cost) divided by weight + current.vruntime += double(request->cost) / current.child->info.weight; + 
max_vruntime = std::max(max_vruntime, current.vruntime);
+
+        if (child_active) // Put active child back in heap after vruntime update
+        {
+            std::push_heap(items.begin(), items.begin() + heap_size);
+        }
+        else // Deactivate child if it is empty, but remember its vruntime for later activations
+        {
+            heap_size--;
+
+            // Store index of this inactive child in `parent.idx`
+            // This enables O(1) search of inactive children instead of O(n)
+            current.child->info.parent.idx = heap_size;
+        }
+
+        // Reset any difference between children on busy period end
+        if (heap_size == 0)
+        {
+            // Reset vruntime to zero to avoid floating-point error accumulation,
+            // but do not reset too often, because it's O(N)
+            UInt64 ns = clock_gettime_ns();
+            if (last_reset_ns + 1000000000 < ns)
+            {
+                last_reset_ns = ns;
+                for (Item & item : items)
+                    item.vruntime = 0;
+                max_vruntime = 0;
+            }
+            system_vruntime = max_vruntime;
+        }
+
+        return {request, heap_size > 0};
+    }
+
+    bool isActive() override
+    {
+        return heap_size > 0;
+    }
+
+    void activateChild(ISchedulerNode * child) override
+    {
+        // Find this child; this is O(1), thanks to the inactive index we hold in `parent.idx`
+        activateChildImpl(child->info.parent.idx);
+    }
+
+private:
+    void activateChildImpl(size_t inactive_idx)
+    {
+        bool activate_parent = heap_size == 0;
+
+        if (heap_size != inactive_idx)
+        {
+            std::swap(items[heap_size], items[inactive_idx]);
+            items[inactive_idx].child->info.parent.idx = inactive_idx;
+        }
+
+        // Newly activated child should have at least `system_vruntime` to keep fairness
+        items[heap_size].vruntime = std::max(system_vruntime, items[heap_size].vruntime);
+        heap_size++;
+        std::push_heap(items.begin(), items.begin() + heap_size);
+
+        // Recursive activation
+        if (activate_parent && parent)
+            parent->activateChild(this);
+    }
+
+private:
+    /// The beginning of the `items` vector is a heap of active children: [0; `heap_size`).
+    /// Inactive children follow in unsorted order.
+    /// NOTE: we have to track the vruntime of inactive children for max-min fairness.
+ std::vector<Item> items; + size_t heap_size = 0; + + /// Last request vruntime + double system_vruntime = 0; + double max_vruntime = 0; + UInt64 last_reset_ns = 0; + + /// All children with ownership + std::unordered_map<String, SchedulerNodePtr> children; // basename -> child +}; + +} diff --git a/contrib/clickhouse/src/IO/Resource/FifoQueue.cpp b/contrib/clickhouse/src/IO/Resource/FifoQueue.cpp new file mode 100644 index 0000000000..f4b0e9c332 --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/FifoQueue.cpp @@ -0,0 +1,13 @@ +#include <IO/Resource/FifoQueue.h> + +#include <IO/SchedulerNodeFactory.h> + +namespace DB +{ + +void registerFifoQueue(SchedulerNodeFactory & factory) +{ + factory.registerMethod<FifoQueue>("fifo"); +} + +} diff --git a/contrib/clickhouse/src/IO/Resource/FifoQueue.h b/contrib/clickhouse/src/IO/Resource/FifoQueue.h new file mode 100644 index 0000000000..f3ff15ad46 --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/FifoQueue.h @@ -0,0 +1,91 @@ +#pragma once + +#include <Common/Stopwatch.h> + +#include <IO/ISchedulerQueue.h> + +#include <Poco/Util/AbstractConfiguration.h> + +#include <deque> +#include <mutex> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_SCHEDULER_NODE; +} + +/* + * FIFO queue to hold pending resource requests + */ +class FifoQueue : public ISchedulerQueue +{ +public: + FifoQueue(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config, const String & config_prefix) + : ISchedulerQueue(event_queue_, config, config_prefix) + {} + + bool equals(ISchedulerNode * other) override + { + if (auto * o = dynamic_cast<FifoQueue *>(other)) + return true; + return false; + } + + void enqueueRequest(ResourceRequest * request) override + { + std::unique_lock lock(mutex); + request->enqueue_ns = clock_gettime_ns(); + bool was_empty = requests.empty(); + requests.push_back(request); + if (was_empty) + scheduleActivation(); + } + + std::pair<ResourceRequest *, bool> dequeueRequest() override + { + std::unique_lock lock(mutex); + if (requests.empty()) + return {nullptr, false}; + ResourceRequest * result = requests.front(); + requests.pop_front(); + return {result, !requests.empty()}; + } + + bool isActive() override + { + std::unique_lock lock(mutex); + return !requests.empty(); + } + + void activateChild(ISchedulerNode *) override + { + assert(false); // queue cannot have children + } + + void attachChild(const SchedulerNodePtr &) override + { + throw Exception( + ErrorCodes::INVALID_SCHEDULER_NODE, + "Cannot add child to leaf scheduler queue: {}", + getPath()); + } + + void removeChild(ISchedulerNode *) override + { + } + + ISchedulerNode * getChild(const String &) override + { + return nullptr; + } + +private: + std::mutex mutex; + std::deque<ResourceRequest *> requests; +}; + +} diff --git a/contrib/clickhouse/src/IO/Resource/PriorityPolicy.cpp b/contrib/clickhouse/src/IO/Resource/PriorityPolicy.cpp new file mode 100644 index 0000000000..bee9a6d5dd --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/PriorityPolicy.cpp @@ -0,0 +1,13 @@ +#include <IO/Resource/PriorityPolicy.h> + +#include <IO/SchedulerNodeFactory.h> + +namespace DB +{ + +void registerPriorityPolicy(SchedulerNodeFactory & factory) +{ + factory.registerMethod<PriorityPolicy>("priority"); +} + +} diff --git a/contrib/clickhouse/src/IO/Resource/PriorityPolicy.h b/contrib/clickhouse/src/IO/Resource/PriorityPolicy.h new file mode 100644 index 0000000000..3c091dcc85 --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/PriorityPolicy.h @@ -0,0 
+1,143 @@ +#pragma once + +#include <IO/ISchedulerQueue.h> +#include <IO/SchedulerRoot.h> + +#include <algorithm> +#include <unordered_map> +#include <vector> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_SCHEDULER_NODE; +} + +/* + * Scheduler node that implements priority scheduling policy. + * Requests are scheduled in order of priorities. + */ +class PriorityPolicy : public ISchedulerNode +{ + /// Scheduling state of a child + struct Item + { + ISchedulerNode * child = nullptr; + Priority priority; // lower value means higher priority + + /// For max-heap by priority + bool operator<(const Item& rhs) const noexcept + { + return priority > rhs.priority; // Reversed for heap top to yield highest priority (lowest value) child first + } + }; + +public: + PriorityPolicy(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {}) + : ISchedulerNode(event_queue_, config, config_prefix) + {} + + bool equals(ISchedulerNode * other) override + { + if (auto * o = dynamic_cast<PriorityPolicy *>(other)) + return true; + return false; + } + + void attachChild(const SchedulerNodePtr & child) override + { + // Take ownership + chassert(child->parent == nullptr); + if (auto [it, inserted] = children.emplace(child->basename, child); !inserted) + throw Exception( + ErrorCodes::INVALID_SCHEDULER_NODE, + "Can't add another child with the same path: {}", + it->second->getPath()); + + // Attach + child->setParent(this); + + // Activate child if it is not empty + if (child->isActive()) + activateChild(child.get()); + } + + void removeChild(ISchedulerNode * child) override + { + if (auto iter = children.find(child->basename); iter != children.end()) + { + SchedulerNodePtr removed = iter->second; + + // Deactivate: detach is not very common operation, so we can afford O(N) here + for (auto i = items.begin(), e = items.end(); i != e; ++i) + { + if (i->child == removed.get()) + { + items.erase(i); + // Element was removed from inside of heap -- heap must be rebuilt + std::make_heap(items.begin(), items.end()); + break; + } + } + + // Detach + removed->setParent(nullptr); + + // Get rid of ownership + children.erase(iter); + } + } + + ISchedulerNode * getChild(const String & child_name) override + { + if (auto iter = children.find(child_name); iter != children.end()) + return iter->second.get(); + else + return nullptr; + } + + std::pair<ResourceRequest *, bool> dequeueRequest() override + { + if (items.empty()) + return {nullptr, false}; + + // Recursively pull request from child + auto [request, child_active] = items.front().child->dequeueRequest(); + assert(request != nullptr); + + // Deactivate child if it is empty + if (!child_active) + { + std::pop_heap(items.begin(), items.end()); + items.pop_back(); + } + + return {request, !items.empty()}; + } + + bool isActive() override + { + return !items.empty(); + } + + void activateChild(ISchedulerNode * child) override + { + bool activate_parent = items.empty(); + items.emplace_back(Item{child, child->info.priority}); + std::push_heap(items.begin(), items.end()); + if (activate_parent && parent) + parent->activateChild(this); + } + +private: + /// Heap of active children + std::vector<Item> items; + + /// All children with ownership + std::unordered_map<String, SchedulerNodePtr> children; // basename -> child +}; + +} diff --git a/contrib/clickhouse/src/IO/Resource/SemaphoreConstraint.cpp b/contrib/clickhouse/src/IO/Resource/SemaphoreConstraint.cpp new file mode 100644 
index 0000000000..2135fd65a8 --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/SemaphoreConstraint.cpp @@ -0,0 +1,13 @@ +#include <IO/Resource/SemaphoreConstraint.h> + +#include <IO/SchedulerNodeFactory.h> + +namespace DB +{ + +void registerSemaphoreConstraint(SchedulerNodeFactory & factory) +{ + factory.registerMethod<SemaphoreConstraint>("inflight_limit"); +} + +} diff --git a/contrib/clickhouse/src/IO/Resource/SemaphoreConstraint.h b/contrib/clickhouse/src/IO/Resource/SemaphoreConstraint.h new file mode 100644 index 0000000000..237e63eadd --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/SemaphoreConstraint.h @@ -0,0 +1,138 @@ +#pragma once + +#include <IO/ISchedulerConstraint.h> +#include <IO/SchedulerRoot.h> + +#include <mutex> +#include <limits> +#include <utility> + +namespace DB +{ + +/* + * Limited concurrency constraint. + * Blocks if either number of concurrent in-flight requests exceeds `max_requests`, or their total cost exceeds `max_cost` + */ +class SemaphoreConstraint : public ISchedulerConstraint +{ + static constexpr Int64 default_max_requests = std::numeric_limits<Int64>::max(); + static constexpr Int64 default_max_cost = std::numeric_limits<Int64>::max(); +public: + SemaphoreConstraint(EventQueue * event_queue_, const Poco::Util::AbstractConfiguration & config = emptyConfig(), const String & config_prefix = {}) + : ISchedulerConstraint(event_queue_, config, config_prefix) + , max_requests(config.getInt64(config_prefix + ".max_requests", default_max_requests)) + , max_cost(config.getInt64(config_prefix + ".max_cost", config.getInt64(config_prefix + ".max_bytes", default_max_cost))) + {} + + bool equals(ISchedulerNode * other) override + { + if (auto * o = dynamic_cast<SemaphoreConstraint *>(other)) + return max_requests == o->max_requests && max_cost == o->max_cost; + return false; + } + + void attachChild(const std::shared_ptr<ISchedulerNode> & child_) override + { + // Take ownership + child = child_; + child->setParent(this); + + // Activate if required + if (child->isActive()) + activateChild(child.get()); + } + + void removeChild(ISchedulerNode * child_) override + { + if (child.get() == child_) + { + child_active = false; // deactivate + child->setParent(nullptr); // detach + child.reset(); + } + } + + ISchedulerNode * getChild(const String & child_name) override + { + if (child->basename == child_name) + return child.get(); + else + return nullptr; + } + + std::pair<ResourceRequest *, bool> dequeueRequest() override + { + // Dequeue request from the child + auto [request, child_now_active] = child->dequeueRequest(); + if (!request) + return {nullptr, false}; + + // Request has reference to the first (closest to leaf) `constraint`, which can have `parent_constraint`. + // The former is initialized here dynamically and the latter is initialized once during hierarchy construction. 
+ if (!request->constraint) + request->constraint = this; + + // Update state on request arrival + std::unique_lock lock(mutex); + requests++; + cost += request->cost; + child_active = child_now_active; + + return {request, active()}; + } + + void finishRequest(ResourceRequest * request) override + { + // Recursive traverse of parent flow controls in reverse order + if (parent_constraint) + parent_constraint->finishRequest(request); + + // Update state on request departure + std::unique_lock lock(mutex); + bool was_active = active(); + requests--; + cost -= request->cost; + + // Schedule activation on transition from inactive state + if (!was_active && active()) + scheduleActivation(); + } + + void activateChild(ISchedulerNode * child_) override + { + std::unique_lock lock(mutex); + if (child_ == child.get()) + if (!std::exchange(child_active, true) && satisfied() && parent) + parent->activateChild(this); + } + + bool isActive() override + { + std::unique_lock lock(mutex); + return active(); + } + +private: + bool satisfied() const + { + return requests < max_requests && cost < max_cost; + } + + bool active() const + { + return satisfied() && child_active; + } + +private: + std::mutex mutex; + Int64 requests = 0; + Int64 cost = 0; + bool child_active = false; + + SchedulerNodePtr child; + Int64 max_requests = default_max_requests; + Int64 max_cost = default_max_cost; +}; + +} diff --git a/contrib/clickhouse/src/IO/Resource/StaticResourceManager.cpp b/contrib/clickhouse/src/IO/Resource/StaticResourceManager.cpp new file mode 100644 index 0000000000..a79e8148f9 --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/StaticResourceManager.cpp @@ -0,0 +1,138 @@ +#include <IO/Resource/StaticResourceManager.h> + +#include <IO/SchedulerNodeFactory.h> +#include <IO/ResourceManagerFactory.h> +#include <IO/ISchedulerQueue.h> + +#include <Common/Exception.h> +#include <Common/StringUtils/StringUtils.h> + +#include <map> +#include <tuple> +#include <algorithm> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int RESOURCE_ACCESS_DENIED; + extern const int RESOURCE_NOT_FOUND; + extern const int INVALID_SCHEDULER_NODE; +} + +StaticResourceManager::Resource::Resource( + const String & name, + EventQueue * event_queue, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix) +{ + // Initialize scheduler nodes + Poco::Util::AbstractConfiguration::Keys keys; + std::sort(keys.begin(), keys.end()); // for parents to appear before children + config.keys(config_prefix, keys); + for (const auto & key : keys) + { + if (!startsWith(key, "node")) + continue; + + // Validate path + String path = config.getString(config_prefix + "." + key + "[@path]", ""); + if (path.empty()) + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Attribute 'path' must be specified in all nodes for resource '{}'", name); + if (path[0] != '/') + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "path must start with '/' for resource '{}'", name); + + // Create node + String type = config.getString(config_prefix + "." + key + ".type", "fifo"); + SchedulerNodePtr node = SchedulerNodeFactory::instance().get(type, event_queue, config, config_prefix + "." 
+ key); + node->basename = path.substr(1); + + // Take ownership + if (auto [_, inserted] = nodes.emplace(path, node); !inserted) + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Duplicate path '{}' for resource '{}'", path, name); + + // Attach created node to parent (if not root) + if (path != "/") + { + String parent_path = path.substr(0, path.rfind('/')); + if (parent_path.empty()) + parent_path = "/"; + if (auto parent = nodes.find(parent_path); parent != nodes.end()) + parent->second->attachChild(node); + else + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Parent doesn't exist for path '{}' for resource '{}'", path, name); + } + } + + if (nodes.find("/") == nodes.end()) + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "undefined root node path '/' for resource '{}'", name); +} + +StaticResourceManager::Classifier::Classifier(const StaticResourceManager & manager, const ClassifierDescription & cfg) +{ + for (auto [resource_name, path] : cfg) + { + if (auto resource_iter = manager.resources.find(resource_name); resource_iter != manager.resources.end()) + { + const Resource & resource = resource_iter->second; + if (auto node_iter = resource.nodes.find(path); node_iter != resource.nodes.end()) + { + if (auto * queue = dynamic_cast<ISchedulerQueue *>(node_iter->second.get())) + resources.emplace(resource_name, ResourceLink{.queue = queue}); + else + throw Exception(ErrorCodes::RESOURCE_NOT_FOUND, "Unable to access non-queue node at path '{}' for resource '{}'", path, resource_name); + } + else + throw Exception(ErrorCodes::RESOURCE_NOT_FOUND, "Path '{}' for resource '{}' does not exist", path, resource_name); + } + else + resources.emplace(resource_name, ResourceLink{}); // resource not configured - unlimited + } +} + +ResourceLink StaticResourceManager::Classifier::get(const String & resource_name) +{ + if (auto iter = resources.find(resource_name); iter != resources.end()) + return iter->second; + else + throw Exception(ErrorCodes::RESOURCE_ACCESS_DENIED, "Access denied to resource '{}'", resource_name); +} + +void StaticResourceManager::updateConfiguration(const Poco::Util::AbstractConfiguration & config) +{ + if (!resources.empty()) + return; // already initialized, configuration update is not supported + + Poco::Util::AbstractConfiguration::Keys keys; + const String config_prefix = "resources"; + config.keys(config_prefix, keys); + + // Create resource for every element under <resources> tag + for (const auto & key : keys) + { + auto [iter, _] = resources.emplace(std::piecewise_construct, + std::forward_as_tuple(key), + std::forward_as_tuple(key, scheduler.event_queue, config, config_prefix + "." 
+ key)); + // Attach root of resource to scheduler + scheduler.attachChild(iter->second.nodes.find("/")->second); + } + + // Initialize classifiers + classifiers = std::make_unique<ClassifiersConfig>(config); + + // Run scheduler thread + scheduler.start(); +} + +ClassifierPtr StaticResourceManager::acquire(const String & classifier_name) +{ + return std::make_shared<Classifier>(*this, classifiers->get(classifier_name)); +} + +void registerStaticResourceManager(ResourceManagerFactory & factory) +{ + factory.registerMethod<StaticResourceManager>("static"); +} + +} diff --git a/contrib/clickhouse/src/IO/Resource/StaticResourceManager.h b/contrib/clickhouse/src/IO/Resource/StaticResourceManager.h new file mode 100644 index 0000000000..066dbf4ebf --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/StaticResourceManager.h @@ -0,0 +1,49 @@ +#pragma once + +#include <IO/IResourceManager.h> +#include <IO/SchedulerRoot.h> +#include <IO/Resource/ClassifiersConfig.h> + +#include <mutex> + +namespace DB +{ + +/* + * Reads `<resources>` from config at startup and registers them in single `SchedulerRoot`. + * Do not support configuration updates, server restart is required. + */ +class StaticResourceManager : public IResourceManager +{ +public: + // Just initialization, any further updates are ignored for the sake of simplicity + // NOTE: manager must be initialized before any acquire() calls to avoid races + void updateConfiguration(const Poco::Util::AbstractConfiguration & config) override; + + ClassifierPtr acquire(const String & classifier_name) override; + +private: + struct Resource + { + std::unordered_map<String, SchedulerNodePtr> nodes; // by paths + + Resource( + const String & name, + EventQueue * event_queue, + const Poco::Util::AbstractConfiguration & config, + const std::string & config_prefix); + }; + + struct Classifier : public IClassifier + { + Classifier(const StaticResourceManager & manager, const ClassifierDescription & cfg); + ResourceLink get(const String & resource_name) override; + std::unordered_map<String, ResourceLink> resources; // accessible resources by names + }; + + SchedulerRoot scheduler; + std::unordered_map<String, Resource> resources; // by name + std::unique_ptr<ClassifiersConfig> classifiers; +}; + +} diff --git a/contrib/clickhouse/src/IO/Resource/registerResourceManagers.cpp b/contrib/clickhouse/src/IO/Resource/registerResourceManagers.cpp new file mode 100644 index 0000000000..5217bcdfbe --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/registerResourceManagers.cpp @@ -0,0 +1,17 @@ +#include <IO/Resource/registerResourceManagers.h> +#include <IO/ResourceManagerFactory.h> + +namespace DB +{ + +void registerDynamicResourceManager(ResourceManagerFactory &); +void registerStaticResourceManager(ResourceManagerFactory &); + +void registerResourceManagers() +{ + auto & factory = ResourceManagerFactory::instance(); + registerDynamicResourceManager(factory); + registerStaticResourceManager(factory); +} + +} diff --git a/contrib/clickhouse/src/IO/Resource/registerResourceManagers.h b/contrib/clickhouse/src/IO/Resource/registerResourceManagers.h new file mode 100644 index 0000000000..243b25a958 --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/registerResourceManagers.h @@ -0,0 +1,8 @@ +#pragma once + +namespace DB +{ + +void registerResourceManagers(); + +} diff --git a/contrib/clickhouse/src/IO/Resource/registerSchedulerNodes.cpp b/contrib/clickhouse/src/IO/Resource/registerSchedulerNodes.cpp new file mode 100644 index 0000000000..896f96d7f5 --- /dev/null 
+++ b/contrib/clickhouse/src/IO/Resource/registerSchedulerNodes.cpp @@ -0,0 +1,30 @@ +#include <IO/Resource/registerSchedulerNodes.h> + +#include <IO/ISchedulerNode.h> +#include <IO/ISchedulerConstraint.h> +#include <IO/SchedulerNodeFactory.h> + +namespace DB +{ + +void registerPriorityPolicy(SchedulerNodeFactory &); +void registerFairPolicy(SchedulerNodeFactory &); +void registerSemaphoreConstraint(SchedulerNodeFactory &); +void registerFifoQueue(SchedulerNodeFactory &); + +void registerSchedulerNodes() +{ + auto & factory = SchedulerNodeFactory::instance(); + + // ISchedulerNode + registerPriorityPolicy(factory); + registerFairPolicy(factory); + + // ISchedulerConstraint + registerSemaphoreConstraint(factory); + + // ISchedulerQueue + registerFifoQueue(factory); +} + +} diff --git a/contrib/clickhouse/src/IO/Resource/registerSchedulerNodes.h b/contrib/clickhouse/src/IO/Resource/registerSchedulerNodes.h new file mode 100644 index 0000000000..1e2092aaf4 --- /dev/null +++ b/contrib/clickhouse/src/IO/Resource/registerSchedulerNodes.h @@ -0,0 +1,8 @@ +#pragma once + +namespace DB +{ + +void registerSchedulerNodes(); + +} diff --git a/contrib/clickhouse/src/IO/ResourceBudget.h b/contrib/clickhouse/src/IO/ResourceBudget.h new file mode 100644 index 0000000000..7f67f9cfc1 --- /dev/null +++ b/contrib/clickhouse/src/IO/ResourceBudget.h @@ -0,0 +1,55 @@ +#pragma once + +#include <IO/ResourceRequest.h> +#include <atomic> + +namespace DB +{ + +/* + * Helper class to keep track of requested and consumed amount of resource. + * Useful if real amount of consumed resource can differ from requested amount of resource (e.g. in case of failures). + * Can be safely used from multiple threads. + * Usage example: + * ResourceBudget budget; + * while (!stop) { + * ResourceCost est_cost = myEstimateOfCostOrJustUseOne(); + * myAllocateResource(budget.ask(est_cost)); // Ask external system to allocate resource for you + * ResourceCost real_cost = mySynchronousConsumptionOfResource(); // Real consumption can differ from est_cost + * budget.adjust(est_cost, real_cost); // Adjust balance according to the actual cost, may affect the next iteration + * } + */ +class ResourceBudget +{ +public: + // Returns amount of resource to be requested according to current balance and estimated cost of new consumption + ResourceCost ask(ResourceCost estimated_cost) + { + ResourceCost budget = available.load(); + while (true) + { + // Valid resource request must have positive `cost`. Also takes consumption history into account. + ResourceCost cost = std::max<ResourceCost>(1ll, estimated_cost - budget); + + // Assume every request is satisfied (no resource request cancellation is possible now) + // So we requested additional `cost` units and are going to consume `estimated_cost` + ResourceCost new_budget = budget + cost - estimated_cost; + + // Try to commit this transaction + if (new_budget == budget || available.compare_exchange_strong(budget, new_budget)) + return cost; + } + } + + // Should be called to account for difference between real and estimated costs + // Optional. May be skipped if `real_cost` is known in advance (equals `estimated_cost`). 
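To make the credit model above concrete, here is a small worked sketch of how ask() and adjust() interact; the numbers are made up and the helper function is hypothetical, not part of the diff:

#include <IO/ResourceBudget.h>

void creditModelSketch()
{
    DB::ResourceBudget budget;                      // available = 0
    DB::ResourceCost asked = budget.ask(10);        // nothing on balance: asked == 10, available stays 0
    DB::ResourceCost real = 25;                     // the I/O actually consumed more than estimated
    budget.adjust(/*estimated_cost=*/ 10, real);    // available becomes -15, i.e. 15 units of debt
    asked = budget.ask(10);                         // asked == 25: the debt is repaid and available returns to 0
}
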
+ void adjust(ResourceCost estimated_cost, ResourceCost real_cost) + { + available.fetch_add(estimated_cost - real_cost); + } + +private: + std::atomic<ResourceCost> available = 0; // requested - consumed +}; + +} diff --git a/contrib/clickhouse/src/IO/ResourceGuard.h b/contrib/clickhouse/src/IO/ResourceGuard.h new file mode 100644 index 0000000000..92f25b40f6 --- /dev/null +++ b/contrib/clickhouse/src/IO/ResourceGuard.h @@ -0,0 +1,139 @@ +#pragma once + +#include <base/types.h> + +#include <IO/ResourceRequest.h> +#include <IO/ResourceLink.h> +#include <IO/ISchedulerConstraint.h> + +#include <condition_variable> +#include <mutex> + + +namespace DB +{ + +/* + * Scoped resource guard. + * Waits for resource to be available in constructor and releases resource in destructor + * IMPORTANT: multiple resources should not be locked concurrently by a single thread + */ +class ResourceGuard +{ +public: + enum ResourceGuardCtor + { + LockStraightAway, /// Locks inside constructor (default) + + // WARNING: Only for tests. It is not exception-safe because `lock()` must be called after construction. + PostponeLocking /// Don't lock in constructor, but send request + }; + + enum RequestState + { + Finished, // Last request has already finished; no concurrent access is possible + Enqueued, // Enqueued into the scheduler; thread-safe access is required + Dequeued // Dequeued from the scheduler and is in consumption state; no concurrent access is possible + }; + + class Request : public ResourceRequest + { + public: + void enqueue(ResourceCost cost_, ResourceLink link_) + { + // lock(mutex) is not required because `Finished` request cannot be used by the scheduler thread + chassert(state == Finished); + state = Enqueued; + ResourceRequest::reset(cost_); + link_.queue->enqueueRequestUsingBudget(this); + } + + // This function is executed inside scheduler thread and wakes thread issued this `request`. + // That thread will continue execution and do real consumption of requested resource synchronously. 
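As an aside, a minimal usage sketch of the scoped guard described above, assuming the `ResourceLink` comes from a classifier and with `readFromRemote` as a hypothetical stand-in for the actual I/O call:

#include <IO/ResourceGuard.h>

void readFromRemote(char * buf, size_t bytes); // hypothetical I/O call, consumes the resource

void throttledRead(DB::ResourceLink link, char * buf, size_t bytes)
{
    DB::ResourceGuard guard(link, static_cast<DB::ResourceCost>(bytes)); // blocks until the scheduler dequeues the request
    readFromRemote(buf, bytes);                                          // real consumption happens here
}                                                                        // ~ResourceGuard() calls unlock() and reports completion
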
+ void execute() override + { + { + std::unique_lock lock(mutex); + chassert(state == Enqueued); + state = Dequeued; + } + dequeued_cv.notify_one(); + } + + void wait() + { + std::unique_lock lock(mutex); + dequeued_cv.wait(lock, [this] { return state == Dequeued; }); + } + + void finish() + { + // lock(mutex) is not required because `Dequeued` request cannot be used by the scheduler thread + chassert(state == Dequeued); + state = Finished; + if (constraint) + constraint->finishRequest(this); + } + + static Request & local() + { + // Since single thread cannot use more than one resource request simultaneously, + // we can reuse thread-local request to avoid allocations + static thread_local Request instance; + return instance; + } + + private: + std::mutex mutex; + std::condition_variable dequeued_cv; + RequestState state = Finished; + }; + + /// Creates pending request for resource; blocks while resource is not available (unless `PostponeLocking`) + explicit ResourceGuard(ResourceLink link_, ResourceCost cost = 1, ResourceGuardCtor ctor = LockStraightAway) + : link(link_) + , request(Request::local()) + { + if (cost == 0) + link.queue = nullptr; // Ignore zero-cost requests + else if (link.queue) + { + request.enqueue(cost, link); + if (ctor == LockStraightAway) + request.wait(); + } + } + + ~ResourceGuard() + { + unlock(); + } + + /// Blocks until resource is available + void lock() + { + if (link.queue) + request.wait(); + } + + /// Report resource consumption has finished + void unlock() + { + if (link.queue) + { + request.finish(); + link.queue = nullptr; + } + } + + /// Mark request as unsuccessful; by default request is considered to be successful + void setFailure() + { + request.successful = false; + } + + ResourceLink link; + Request & request; +}; + +} diff --git a/contrib/clickhouse/src/IO/ResourceLink.h b/contrib/clickhouse/src/IO/ResourceLink.h new file mode 100644 index 0000000000..2da5e75fcb --- /dev/null +++ b/contrib/clickhouse/src/IO/ResourceLink.h @@ -0,0 +1,39 @@ +#pragma once + +#include <base/types.h> + +#include <IO/ResourceRequest.h> +#include <IO/ISchedulerQueue.h> + + +namespace DB +{ + +/* + * Everything required for resource consumption. Connection to a specific resource queue. 
+ */ +struct ResourceLink +{ + ISchedulerQueue * queue = nullptr; + bool operator==(const ResourceLink &) const = default; + + void adjust(ResourceCost estimated_cost, ResourceCost real_cost) const + { + if (queue) + queue->adjustBudget(estimated_cost, real_cost); + } + + void consumed(ResourceCost cost) const + { + if (queue) + queue->consumeBudget(cost); + } + + void accumulate(ResourceCost cost) const + { + if (queue) + queue->accumulateBudget(cost); + } +}; + +} diff --git a/contrib/clickhouse/src/IO/ResourceManagerFactory.h b/contrib/clickhouse/src/IO/ResourceManagerFactory.h new file mode 100644 index 0000000000..8e972f0564 --- /dev/null +++ b/contrib/clickhouse/src/IO/ResourceManagerFactory.h @@ -0,0 +1,55 @@ +#pragma once + +#include <Common/ErrorCodes.h> +#include <Common/Exception.h> + +#include <IO/IResourceManager.h> + +#include <boost/noncopyable.hpp> + +#include <memory> +#include <mutex> +#include <unordered_map> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_SCHEDULER_NODE; +} + +class ResourceManagerFactory : private boost::noncopyable +{ +public: + static ResourceManagerFactory & instance() + { + static ResourceManagerFactory ret; + return ret; + } + + ResourceManagerPtr get(const String & name) + { + std::lock_guard lock{mutex}; + if (auto iter = methods.find(name); iter != methods.end()) + return iter->second(); + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Unknown scheduler node type: {}", name); + } + + template <class TDerived> + void registerMethod(const String & name) + { + std::lock_guard lock{mutex}; + methods[name] = [] () + { + return std::make_shared<TDerived>(); + }; + } + +private: + std::mutex mutex; + using Method = std::function<ResourceManagerPtr()>; + std::unordered_map<String, Method> methods; +}; + +} diff --git a/contrib/clickhouse/src/IO/ResourceRequest.h b/contrib/clickhouse/src/IO/ResourceRequest.h new file mode 100644 index 0000000000..989349148c --- /dev/null +++ b/contrib/clickhouse/src/IO/ResourceRequest.h @@ -0,0 +1,89 @@ +#pragma once + +#include <base/types.h> +#include <limits> + +namespace DB +{ + +// Forward declarations +class ISchedulerQueue; +class ISchedulerConstraint; + +/// Cost in terms of used resource (e.g. bytes for network IO) +using ResourceCost = Int64; +constexpr ResourceCost ResourceCostMax = std::numeric_limits<int>::max(); + +/// Timestamps (nanoseconds since epoch) +using ResourceNs = UInt64; + +/* + * Request for a resource consumption. The main moving part of the scheduling subsystem. + * Resource requests processing workflow: + * + * ----1=2222222222222=3=4=555555555555555=6-----> time + * ^ ^ ^ ^ ^ ^ + * | | | | | | + * enqueue wait dequeue execute consume finish + * + * 1) Request is enqueued using ISchedulerQueue::enqueueRequest(). + * 2) Request competes with others for access to a resource; effectively just waiting in a queue. + * 3) Scheduler calls ISchedulerNode::dequeueRequest() that returns the request. + * 4) Callback ResourceRequest::execute() is called to provide access to the resource. + * 5) The resource consumption is happening outside of the scheduling subsystem. + * 6) request->constraint->finishRequest() is called when consumption is finished. + * + * Steps (5) and (6) can be omitted if constraint is not used by the resource. + * + * Request can be created on stack or heap. + * Request ownership is done outside of the scheduling subsystem. + * After (6) request can be destructed safely. + * + * Request cancelling is not supported yet. 
+ */ +class ResourceRequest +{ +public: + /// Cost of request execution; should be filled before request enqueueing. + /// NOTE: If cost is not known in advance, credit model can be used: + /// NOTE: for the first request use 1 and + ResourceCost cost; + + /// Request outcome + /// Should be filled during resource consumption + bool successful; + + /// Scheduler node to be notified on consumption finish + /// Auto-filled during request enqueue/dequeue + ISchedulerConstraint * constraint; + + /// Timestamps for introspection + ResourceNs enqueue_ns; + ResourceNs execute_ns; + ResourceNs finish_ns; + + explicit ResourceRequest(ResourceCost cost_ = 1) + { + reset(cost_); + } + + void reset(ResourceCost cost_) + { + cost = cost_; + successful = true; + constraint = nullptr; + enqueue_ns = 0; + execute_ns = 0; + finish_ns = 0; + } + + virtual ~ResourceRequest() = default; + + /// Callback to trigger resource consumption. + /// IMPORTANT: it is called from scheduler thread and must be fast, + /// just triggering start of a consumption, not doing the consumption itself + /// (e.g. setting an std::promise or creating a job in a thread pool) + virtual void execute() = 0; +}; + +} diff --git a/contrib/clickhouse/src/IO/S3/AWSLogger.cpp b/contrib/clickhouse/src/IO/S3/AWSLogger.cpp new file mode 100644 index 0000000000..d6162823ae --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/AWSLogger.cpp @@ -0,0 +1,78 @@ +#include <IO/S3/AWSLogger.h> + +#if USE_AWS_S3 + +#include <Core/SettingsEnums.h> +#include <Common/logger_useful.h> +#include <aws/core/utils/logging/LogLevel.h> +#include <Poco/Logger.h> + +namespace +{ + +const char * S3_LOGGER_TAG_NAMES[][2] = { + {"AWSClient", "AWSClient"}, + {"AWSAuthV4Signer", "AWSClient (AWSAuthV4Signer)"}, +}; + +const std::pair<DB::LogsLevel, Poco::Message::Priority> & convertLogLevel(Aws::Utils::Logging::LogLevel log_level) +{ + /// We map levels to our own logger 1 to 1 except WARN+ levels. In most cases we failover such errors with retries + /// and don't want to see them as Errors in our logs. + static const std::unordered_map<Aws::Utils::Logging::LogLevel, std::pair<DB::LogsLevel, Poco::Message::Priority>> mapping = + { + {Aws::Utils::Logging::LogLevel::Off, {DB::LogsLevel::none, Poco::Message::PRIO_INFORMATION}}, + {Aws::Utils::Logging::LogLevel::Fatal, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}}, + {Aws::Utils::Logging::LogLevel::Error, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}}, + {Aws::Utils::Logging::LogLevel::Warn, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}}, + {Aws::Utils::Logging::LogLevel::Info, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}}, + {Aws::Utils::Logging::LogLevel::Debug, {DB::LogsLevel::debug, Poco::Message::PRIO_TEST}}, + {Aws::Utils::Logging::LogLevel::Trace, {DB::LogsLevel::trace, Poco::Message::PRIO_TEST}}, + }; + return mapping.at(log_level); +} + +} + +namespace DB::S3 +{ + +AWSLogger::AWSLogger(bool enable_s3_requests_logging_) + : enable_s3_requests_logging(enable_s3_requests_logging_) +{ + for (auto [tag, name] : S3_LOGGER_TAG_NAMES) + tag_loggers[tag] = &Poco::Logger::get(name); + + default_logger = tag_loggers[S3_LOGGER_TAG_NAMES[0][0]]; +} + +Aws::Utils::Logging::LogLevel AWSLogger::GetLogLevel() const +{ + if (enable_s3_requests_logging) + return Aws::Utils::Logging::LogLevel::Trace; + else + return Aws::Utils::Logging::LogLevel::Info; +} + +void AWSLogger::Log(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * format_str, ...) 
// NOLINT +{ + callLogImpl(log_level, tag, format_str); /// FIXME. Variadic arguments? +} + +void AWSLogger::LogStream(Aws::Utils::Logging::LogLevel log_level, const char * tag, const Aws::OStringStream & message_stream) +{ + callLogImpl(log_level, tag, message_stream.str().c_str()); +} + +void AWSLogger::callLogImpl(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * message) +{ + const auto & [level, prio] = convertLogLevel(log_level); + if (tag_loggers.contains(tag)) + LOG_IMPL(tag_loggers[tag], level, prio, fmt::runtime(message)); + else + LOG_IMPL(default_logger, level, prio, "{}: {}", tag, message); +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/AWSLogger.h b/contrib/clickhouse/src/IO/S3/AWSLogger.h new file mode 100644 index 0000000000..7c31ea469f --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/AWSLogger.h @@ -0,0 +1,39 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_AWS_S3 +#include <aws/core/utils/logging/LogSystemInterface.h> +#include <base/types.h> +#include <unordered_map> + +namespace Poco { class Logger; } + +namespace DB::S3 +{ +class AWSLogger final : public Aws::Utils::Logging::LogSystemInterface +{ +public: + explicit AWSLogger(bool enable_s3_requests_logging_); + + ~AWSLogger() final = default; + + Aws::Utils::Logging::LogLevel GetLogLevel() const final; + + void Log(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * format_str, ...) final; // NOLINT + + void LogStream(Aws::Utils::Logging::LogLevel log_level, const char * tag, const Aws::OStringStream & message_stream) final; + + void callLogImpl(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * message); + + void Flush() final {} + +private: + Poco::Logger * default_logger; + bool enable_s3_requests_logging; + std::unordered_map<String, Poco::Logger *> tag_loggers; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/Client.cpp b/contrib/clickhouse/src/IO/S3/Client.cpp new file mode 100644 index 0000000000..104fc2dd5b --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/Client.cpp @@ -0,0 +1,903 @@ +#include <IO/S3/Client.h> + +#if USE_AWS_S3 + +#include <aws/core/client/CoreErrors.h> +#include <aws/core/client/DefaultRetryStrategy.h> +#include <aws/s3/model/HeadBucketRequest.h> +#include <aws/s3/model/GetObjectRequest.h> +#include <aws/s3/model/HeadObjectRequest.h> +#include <aws/s3/model/ListObjectsV2Request.h> +#include <aws/core/client/AWSErrorMarshaller.h> +// #include <aws/core/endpoint/EndpointParameter.h> +#include <aws/core/utils/HashingUtils.h> +// #include <aws/core/utils/logging/ErrorMacros.h> + +#include <Poco/Net/NetException.h> + +#include <IO/S3Common.h> +#include <IO/S3/Requests.h> +#include <IO/S3/PocoHTTPClientFactory.h> +#include <IO/S3/AWSLogger.h> +#include <IO/S3/Credentials.h> + +#include <Common/assert_cast.h> + +#include <Common/logger_useful.h> +#include <Common/ProxyConfigurationResolverProvider.h> + + +namespace ProfileEvents +{ + extern const Event S3WriteRequestsErrors; + extern const Event S3ReadRequestsErrors; + + extern const Event DiskS3WriteRequestsErrors; + extern const Event DiskS3ReadRequestsErrors; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int TOO_MANY_REDIRECTS; +} + +namespace S3 +{ + +Client::RetryStrategy::RetryStrategy(std::shared_ptr<Aws::Client::RetryStrategy> wrapped_strategy_) + : wrapped_strategy(std::move(wrapped_strategy_)) +{ + // if (!wrapped_strategy) + // wrapped_strategy = Aws::Client::InitRetryStrategy(); +} + 
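For context, the decorator above is meant to wrap whatever retry strategy is already present in the HTTP client configuration before the client is constructed, which is the pattern ClientFactory::create() follows further down in this file. A hedged sketch of that wiring; the helper function and its caller-supplied arguments are assumptions, not part of the diff:

#include <IO/S3/Client.h>

std::unique_ptr<DB::S3::Client> makeClientWithRetryDecorator(
    DB::S3::PocoHTTPClientConfiguration cfg,
    const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider)
{
    // Client::create() verifies the strategy is a Client::RetryStrategy, so wrap the existing one first
    cfg.retryStrategy = std::make_shared<DB::S3::Client::RetryStrategy>(std::move(cfg.retryStrategy));
    return DB::S3::Client::create(
        cfg.s3_max_redirects,
        DB::S3::ServerSideEncryptionKMSConfig{},   // no SSE-KMS in this sketch
        credentials_provider,
        cfg,
        Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
        /*use_virtual_addressing=*/ true);
}
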
+/// NOLINTNEXTLINE(google-runtime-int) +bool Client::RetryStrategy::ShouldRetry(const Aws::Client::AWSError<Aws::Client::CoreErrors>& error, long attemptedRetries) const +{ + if (error.GetResponseCode() == Aws::Http::HttpResponseCode::MOVED_PERMANENTLY) + return false; + + return wrapped_strategy->ShouldRetry(error, attemptedRetries); +} + +/// NOLINTNEXTLINE(google-runtime-int) +long Client::RetryStrategy::CalculateDelayBeforeNextRetry(const Aws::Client::AWSError<Aws::Client::CoreErrors>& error, long attemptedRetries) const +{ + return wrapped_strategy->CalculateDelayBeforeNextRetry(error, attemptedRetries); +} + +/// NOLINTNEXTLINE(google-runtime-int) +long Client::RetryStrategy::GetMaxAttempts() const +{ + return wrapped_strategy->GetMaxAttempts(); +} + +void Client::RetryStrategy::GetSendToken() +{ + return wrapped_strategy->GetSendToken(); +} + +// bool Client::RetryStrategy::HasSendToken() +// { +// return wrapped_strategy->HasSendToken(); +// } + +void Client::RetryStrategy::RequestBookkeeping(const Aws::Client::HttpResponseOutcome& httpResponseOutcome) +{ + return wrapped_strategy->RequestBookkeeping(httpResponseOutcome); +} + +void Client::RetryStrategy::RequestBookkeeping(const Aws::Client::HttpResponseOutcome& httpResponseOutcome, const Aws::Client::AWSError<Aws::Client::CoreErrors>& lastError) +{ + return wrapped_strategy->RequestBookkeeping(httpResponseOutcome, lastError); +} + +namespace +{ + +void verifyClientConfiguration(const Aws::Client::ClientConfiguration & client_config) +{ + if (!client_config.retryStrategy) + throw Exception(ErrorCodes::LOGICAL_ERROR, "The S3 client can only be used with Client::RetryStrategy, define it in the client configuration"); + + assert_cast<const Client::RetryStrategy &>(*client_config.retryStrategy); +} + +} + +std::unique_ptr<Client> Client::create( + size_t max_redirects_, + ServerSideEncryptionKMSConfig sse_kms_config_, + const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider, + const PocoHTTPClientConfiguration & client_configuration, + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads, + bool use_virtual_addressing) +{ + verifyClientConfiguration(client_configuration); + return std::unique_ptr<Client>( + new Client(max_redirects_, std::move(sse_kms_config_), credentials_provider, client_configuration, sign_payloads, use_virtual_addressing)); +} + +std::unique_ptr<Client> Client::clone( + std::optional<std::shared_ptr<RetryStrategy>> override_retry_strategy, + std::optional<Int64> override_request_timeout_ms) const +{ + PocoHTTPClientConfiguration new_configuration = client_configuration; + if (override_retry_strategy.has_value()) + new_configuration.retryStrategy = *override_retry_strategy; + if (override_request_timeout_ms.has_value()) + new_configuration.requestTimeoutMs = *override_request_timeout_ms; + return std::unique_ptr<Client>(new Client(*this, new_configuration)); +} + +namespace +{ + +ProviderType deduceProviderType(const std::string & url) +{ + if (url.find(".amazonaws.com") != std::string::npos) + return ProviderType::AWS; + + if (url.find("storage.googleapis.com") != std::string::npos) + return ProviderType::GCS; + + return ProviderType::UNKNOWN; +} + +} + +Client::Client( + size_t max_redirects_, + ServerSideEncryptionKMSConfig sse_kms_config_, + const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider_, + const PocoHTTPClientConfiguration & client_configuration_, + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads_, + bool 
use_virtual_addressing_) + : Aws::S3::S3Client(credentials_provider_, client_configuration_, sign_payloads_, use_virtual_addressing_) + , credentials_provider(credentials_provider_) + , client_configuration(client_configuration_) + , sign_payloads(sign_payloads_) + , use_virtual_addressing(use_virtual_addressing_) + , max_redirects(max_redirects_) + , sse_kms_config(std::move(sse_kms_config_)) + , log(&Poco::Logger::get("S3Client")) +{ +#if 0 + auto * endpoint_provider = dynamic_cast<Aws::S3::Endpoint::S3DefaultEpProviderBase *>(accessEndpointProvider().get()); + endpoint_provider->GetBuiltInParameters().GetParameter("Region").GetString(explicit_region); + endpoint_provider->GetBuiltInParameters().GetParameter("Endpoint").GetString(initial_endpoint); +#endif + + provider_type = deduceProviderType(initial_endpoint); + LOG_TRACE(log, "Provider type: {}", toString(provider_type)); + + if (provider_type == ProviderType::GCS) + { + /// GCS can operate in 2 modes for header and query params names: + /// - with both x-amz and x-goog prefixes allowed (but cannot mix different prefixes in same request) + /// - only with x-goog prefix + /// first mode is allowed only with HMAC (or unsigned requests) so when we + /// find credential keys we can simply behave as the underlying storage is S3 + /// otherwise, we need to be aware we are making requests to GCS + /// and replace all headers with a valid prefix when needed + if (credentials_provider) + { + auto credentials = credentials_provider->GetAWSCredentials(); + if (credentials.IsEmpty()) + api_mode = ApiMode::GCS; + } + } + + LOG_TRACE(log, "API mode of the S3 client: {}", api_mode); + + detect_region = provider_type == ProviderType::AWS && explicit_region == Aws::Region::AWS_GLOBAL; + + cache = std::make_shared<ClientCache>(); + ClientCacheRegistry::instance().registerClient(cache); +} + +Client::Client( + const Client & other, const PocoHTTPClientConfiguration & client_configuration_) + : Aws::S3::S3Client(other.credentials_provider, client_configuration_, other.sign_payloads, + other.use_virtual_addressing) + , initial_endpoint(other.initial_endpoint) + , credentials_provider(other.credentials_provider) + , client_configuration(client_configuration_) + , sign_payloads(other.sign_payloads) + , use_virtual_addressing(other.use_virtual_addressing) + , explicit_region(other.explicit_region) + , detect_region(other.detect_region) + , provider_type(other.provider_type) + , max_redirects(other.max_redirects) + , sse_kms_config(other.sse_kms_config) + , log(&Poco::Logger::get("S3Client")) +{ + cache = std::make_shared<ClientCache>(*other.cache); + ClientCacheRegistry::instance().registerClient(cache); +} + +Aws::Auth::AWSCredentials Client::getCredentials() const +{ + return credentials_provider->GetAWSCredentials(); +} + +bool Client::checkIfWrongRegionDefined(const std::string & bucket, const Aws::S3::S3Error & error, std::string & region) const +{ + if (detect_region) + return false; + + if (error.GetResponseCode() == Aws::Http::HttpResponseCode::BAD_REQUEST && error.GetExceptionName() == "AuthorizationHeaderMalformed") + { + region = GetErrorMarshaller()->ExtractRegion(error); + + if (region.empty()) + region = getRegionForBucket(bucket, /*force_detect*/ true); + + assert(!explicit_region.empty()); + if (region == explicit_region) + return false; + + insertRegionOverride(bucket, region); + return true; + } + + return false; +} + +void Client::insertRegionOverride(const std::string & bucket, const std::string & region) const +{ + std::lock_guard 
lock(cache->region_cache_mutex);
+ auto [it, inserted] = cache->region_for_bucket_cache.emplace(bucket, region);
+ if (inserted)
+ LOG_INFO(log, "Detected different region ('{}') for bucket {} than the one defined ('{}')", region, bucket, explicit_region);
+}
+
+template <typename RequestType>
+void Client::setKMSHeaders(RequestType & request) const
+{
+ // Don't do anything unless a key ID was specified
+ if (sse_kms_config.key_id)
+ {
+ request.SetServerSideEncryption(Model::ServerSideEncryption::aws_kms);
+ // If the key ID was specified but is empty, treat it as using the AWS managed key and omit the header
+ if (!sse_kms_config.key_id->empty())
+ request.SetSSEKMSKeyId(*sse_kms_config.key_id);
+ if (sse_kms_config.encryption_context)
+ request.SetSSEKMSEncryptionContext(*sse_kms_config.encryption_context);
+ if (sse_kms_config.bucket_key_enabled)
+ request.SetBucketKeyEnabled(*sse_kms_config.bucket_key_enabled);
+ }
+}
+
+// Explicitly instantiate this method only for the request types that support KMS headers
+template void Client::setKMSHeaders<CreateMultipartUploadRequest>(CreateMultipartUploadRequest & request) const;
+template void Client::setKMSHeaders<CopyObjectRequest>(CopyObjectRequest & request) const;
+template void Client::setKMSHeaders<PutObjectRequest>(PutObjectRequest & request) const;
+
+Model::HeadObjectOutcome Client::HeadObject(const HeadObjectRequest & request) const
+{
+ const auto & bucket = request.GetBucket();
+
+ request.setApiMode(api_mode);
+
+ if (auto region = getRegionForBucket(bucket); !region.empty())
+ {
+ if (!detect_region)
+ LOG_INFO(log, "Using region override {} for bucket {}", region, bucket);
+ request.overrideRegion(std::move(region));
+ }
+
+ if (auto uri = getURIForBucket(bucket); uri.has_value())
+ request.overrideURI(std::move(*uri));
+
+ auto result = HeadObject(static_cast<const Model::HeadObjectRequest&>(request));
+ if (result.IsSuccess())
+ return result;
+
+ const auto & error = result.GetError();
+
+ std::string new_region;
+ if (checkIfWrongRegionDefined(bucket, error, new_region))
+ {
+ request.overrideRegion(new_region);
+ return Aws::S3::S3Client::HeadObject(request);
+ }
+
+ if (error.GetResponseCode() != Aws::Http::HttpResponseCode::MOVED_PERMANENTLY)
+ return result;
+
+ // maybe we can detect the correct region
+ if (!detect_region)
+ {
+ if (auto region = GetErrorMarshaller()->ExtractRegion(error); !region.empty() && region != explicit_region)
+ {
+ request.overrideRegion(region);
+ insertRegionOverride(bucket, region);
+ }
+ }
+
+ auto bucket_uri = getURIForBucket(bucket);
+ if (!bucket_uri)
+ {
+ if (auto maybe_error = updateURIForBucketForHead(bucket); maybe_error.has_value())
+ return *maybe_error;
+
+ if (auto region = getRegionForBucket(bucket); !region.empty())
+ {
+ if (!detect_region)
+ LOG_INFO(log, "Using region override {} for bucket {}", region, bucket);
+ request.overrideRegion(std::move(region));
+ }
+
+ bucket_uri = getURIForBucket(bucket);
+ if (!bucket_uri)
+ {
+ LOG_ERROR(log, "Missing resolved URI for bucket {}, maybe the cache was cleaned", bucket);
+ return result;
+ }
+ }
+
+ const auto & current_uri_override = request.getURIOverride();
+ /// we already tried with this URI
+ if (current_uri_override && current_uri_override->uri == bucket_uri->uri)
+ {
+ LOG_INFO(log, "Getting redirected to the same invalid location {}", bucket_uri->uri.toString());
+ return result;
+ }
+
+ request.overrideURI(std::move(*bucket_uri));
+
+ /// The next call is NOT a recursive call
+ /// This is a virtual call Aws::S3::S3Client::HeadObject(const Model::HeadObjectRequest&)
+ return HeadObject(static_cast<const Model::HeadObjectRequest&>(request));
+}
+
+/// For each request type, we wrap the corresponding function from Aws::S3::Client with doRequest.
+/// doRequest calls the virtual function from Aws::S3::Client, while DB::S3::Client itself has no virtual method per request type.
+
+Model::ListObjectsV2Outcome Client::ListObjectsV2(const ListObjectsV2Request & request) const
+{
+ return doRequestWithRetryNetworkErrors</*IsReadMethod*/ true>(
+ request, [this](const Model::ListObjectsV2Request & req) { return ListObjectsV2(req); });
+}
+
+Model::ListObjectsOutcome Client::ListObjects(const ListObjectsRequest & request) const
+{
+ return doRequestWithRetryNetworkErrors</*IsReadMethod*/ true>(
+ request, [this](const Model::ListObjectsRequest & req) { return ListObjects(req); });
+}
+
+Model::GetObjectOutcome Client::GetObject(const GetObjectRequest & request) const
+{
+ return doRequest(request, [this](const Model::GetObjectRequest & req) { return GetObject(req); });
+}
+
+Model::AbortMultipartUploadOutcome Client::AbortMultipartUpload(const AbortMultipartUploadRequest & request) const
+{
+ return doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>(
+ request, [this](const Model::AbortMultipartUploadRequest & req) { return AbortMultipartUpload(req); });
+}
+
+Model::CreateMultipartUploadOutcome Client::CreateMultipartUpload(const CreateMultipartUploadRequest & request) const
+{
+ return doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>(
+ request, [this](const Model::CreateMultipartUploadRequest & req) { return CreateMultipartUpload(req); });
+}
+
+Model::CompleteMultipartUploadOutcome Client::CompleteMultipartUpload(const CompleteMultipartUploadRequest & request) const
+{
+ auto outcome = doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>(
+ request, [this](const Model::CompleteMultipartUploadRequest & req) { return CompleteMultipartUpload(req); });
+
+ if (!outcome.IsSuccess() || provider_type != ProviderType::GCS)
+ return outcome;
+
+ const auto & key = request.GetKey();
+ const auto & bucket = request.GetBucket();
+
+ /// For GCS we will try to compose object at the end, otherwise we cannot do a native copy
+ /// for the object (e.g. for backups)
+ /// We don't care if the compose fails, because the upload was still successful, only the
+ /// performance for copying the object will be affected
+ S3::ComposeObjectRequest compose_req;
+ compose_req.SetBucket(bucket);
+ compose_req.SetKey(key);
+ compose_req.SetComponentNames({key});
+ compose_req.SetContentType("binary/octet-stream");
+#if 0
+ auto compose_outcome = ComposeObject(compose_req);
+
+ if (compose_outcome.IsSuccess())
+ LOG_TRACE(log, "Composing object was successful");
+ else
+ LOG_INFO(log, "Failed to compose object. 
Message: {}, Key: {}, Bucket: {}", compose_outcome.GetError().GetMessage(), key, bucket); +#endif + + return outcome; +} + +Model::CopyObjectOutcome Client::CopyObject(const CopyObjectRequest & request) const +{ + return doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>( + request, [this](const Model::CopyObjectRequest & req) { return CopyObject(req); }); +} + +Model::PutObjectOutcome Client::PutObject(const PutObjectRequest & request) const +{ + return doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>( + request, [this](const Model::PutObjectRequest & req) { return PutObject(req); }); +} + +Model::UploadPartOutcome Client::UploadPart(const UploadPartRequest & request) const +{ + return doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>( + request, [this](const Model::UploadPartRequest & req) { return UploadPart(req); }); +} + +Model::UploadPartCopyOutcome Client::UploadPartCopy(const UploadPartCopyRequest & request) const +{ + return doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>( + request, [this](const Model::UploadPartCopyRequest & req) { return UploadPartCopy(req); }); +} + +Model::DeleteObjectOutcome Client::DeleteObject(const DeleteObjectRequest & request) const +{ + return doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>( + request, [this](const Model::DeleteObjectRequest & req) { return DeleteObject(req); }); +} + +Model::DeleteObjectsOutcome Client::DeleteObjects(const DeleteObjectsRequest & request) const +{ + return doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>( + request, [this](const Model::DeleteObjectsRequest & req) { return DeleteObjects(req); }); +} + +#if 0 +Client::ComposeObjectOutcome Client::ComposeObject(const ComposeObjectRequest & request) const +{ + auto request_fn = [this](const ComposeObjectRequest & req) + { + auto & endpoint_provider = const_cast<Client &>(*this).accessEndpointProvider(); + AWS_OPERATION_CHECK_PTR(endpoint_provider, ComposeObject, Aws::Client::CoreErrors, Aws::Client::CoreErrors::ENDPOINT_RESOLUTION_FAILURE); + + if (!req.BucketHasBeenSet()) + { + AWS_LOGSTREAM_ERROR("ComposeObject", "Required field: Bucket, is not set") + return ComposeObjectOutcome(Aws::Client::AWSError<Aws::S3::S3Errors>(Aws::S3::S3Errors::MISSING_PARAMETER, "MISSING_PARAMETER", "Missing required field [Bucket]", false)); + } + + if (!req.KeyHasBeenSet()) + { + AWS_LOGSTREAM_ERROR("ComposeObject", "Required field: Key, is not set") + return ComposeObjectOutcome(Aws::Client::AWSError<Aws::S3::S3Errors>(Aws::S3::S3Errors::MISSING_PARAMETER, "MISSING_PARAMETER", "Missing required field [Key]", false)); + } + + auto endpointResolutionOutcome = endpoint_provider->ResolveEndpoint(req.GetEndpointContextParams()); + AWS_OPERATION_CHECK_SUCCESS(endpointResolutionOutcome, ComposeObject, Aws::Client::CoreErrors, Aws::Client::CoreErrors::ENDPOINT_RESOLUTION_FAILURE, endpointResolutionOutcome.GetError().GetMessage()); + endpointResolutionOutcome.GetResult().AddPathSegments(req.GetKey()); + endpointResolutionOutcome.GetResult().SetQueryString("?compose"); + return ComposeObjectOutcome(MakeRequest(req, endpointResolutionOutcome.GetResult(), Aws::Http::HttpMethod::HTTP_PUT)); + }; + + return doRequestWithRetryNetworkErrors</*IsReadMethod*/ false>( + request, request_fn); +} +#endif + +template <typename RequestType, typename RequestFn> +std::invoke_result_t<RequestFn, RequestType> +Client::doRequest(const RequestType & request, RequestFn request_fn) const +{ + const auto & bucket = request.GetBucket(); + request.setApiMode(api_mode); + + if (auto 
region = getRegionForBucket(bucket); !region.empty()) + { + if (!detect_region) + LOG_INFO(log, "Using region override {} for bucket {}", region, bucket); + + request.overrideRegion(std::move(region)); + } + + if (auto uri = getURIForBucket(bucket); uri.has_value()) + request.overrideURI(std::move(*uri)); + + + bool found_new_endpoint = false; + // if we found correct endpoint after 301 responses, update the cache for future requests + SCOPE_EXIT( + if (found_new_endpoint) + { + auto uri_override = request.getURIOverride(); + assert(uri_override.has_value()); + updateURIForBucket(bucket, std::move(*uri_override)); + } + ); + + for (size_t attempt = 0; attempt <= max_redirects; ++attempt) + { + auto result = request_fn(request); + if (result.IsSuccess()) + return result; + + const auto & error = result.GetError(); + + std::string new_region; + if (checkIfWrongRegionDefined(bucket, error, new_region)) + { + request.overrideRegion(new_region); + continue; + } + + if (error.GetResponseCode() != Aws::Http::HttpResponseCode::MOVED_PERMANENTLY) + return result; + + // maybe we detect a correct region + if (!detect_region) + { + if (auto region = GetErrorMarshaller()->ExtractRegion(error); !region.empty() && region != explicit_region) + { + request.overrideRegion(region); + insertRegionOverride(bucket, region); + } + } + + // we possibly got new location, need to try with that one + auto new_uri = getURIFromError(error); + if (!new_uri) + return result; + + const auto & current_uri_override = request.getURIOverride(); + /// we already tried with this URI + if (current_uri_override && current_uri_override->uri == new_uri->uri) + { + LOG_INFO(log, "Getting redirected to the same invalid location {}", new_uri->uri.toString()); + return result; + } + + found_new_endpoint = true; + request.overrideURI(*new_uri); + } + + throw Exception(ErrorCodes::TOO_MANY_REDIRECTS, "Too many redirects"); +} + +template <bool IsReadMethod, typename RequestType, typename RequestFn> +std::invoke_result_t<RequestFn, RequestType> +Client::doRequestWithRetryNetworkErrors(const RequestType & request, RequestFn request_fn) const +{ + auto with_retries = [this, request_fn_ = std::move(request_fn)] (const RequestType & request_) + { + chassert(client_configuration.retryStrategy); + const Int64 max_attempts = client_configuration.retryStrategy->GetMaxAttempts(); + std::exception_ptr last_exception = nullptr; + for (Int64 attempt_no = 0; attempt_no < max_attempts; ++attempt_no) + { + try + { + /// S3 does retries network errors actually. + /// But it is matter when errors occur. + /// This code retries a specific case when + /// network error happens when XML document is being read from the response body. + /// Hence, the response body is a stream, network errors are possible at reading. + /// S3 doesn't retry them. + + /// Not all requests can be retried in that way. + /// Requests that read out response body to build the result are possible to retry. + /// Requests that expose the response stream as an answer are not retried with that code. E.g. GetObject. 
+ return request_fn_(request_); + } + catch (Poco::Net::ConnectionResetException &) + { + + if constexpr (IsReadMethod) + { + if (client_configuration.for_disk_s3) + ProfileEvents::increment(ProfileEvents::DiskS3ReadRequestsErrors); + else + ProfileEvents::increment(ProfileEvents::S3ReadRequestsErrors); + } + else + { + if (client_configuration.for_disk_s3) + ProfileEvents::increment(ProfileEvents::DiskS3WriteRequestsErrors); + else + ProfileEvents::increment(ProfileEvents::S3WriteRequestsErrors); + } + + tryLogCurrentException(log, "Will retry"); + last_exception = std::current_exception(); + + auto error = Aws::Client::AWSError<Aws::Client::CoreErrors>(Aws::Client::CoreErrors::NETWORK_CONNECTION, /*retry*/ true); + client_configuration.retryStrategy->CalculateDelayBeforeNextRetry(error, attempt_no); + continue; + } + } + + chassert(last_exception); + std::rethrow_exception(last_exception); + }; + + return doRequest(request, with_retries); +} + +bool Client::supportsMultiPartCopy() const +{ + return provider_type != ProviderType::GCS; +} + +void Client::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request, + const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest) const +{ + Aws::S3::S3Client::BuildHttpRequest(request, httpRequest); + + if (api_mode == ApiMode::GCS) + { + /// some GCS requests don't like S3 specific headers that the client sets + httpRequest->DeleteHeader("x-amz-api-version"); + httpRequest->DeleteHeader("amz-sdk-invocation-id"); + httpRequest->DeleteHeader("amz-sdk-request"); + } +} + +std::string Client::getRegionForBucket(const std::string & bucket, bool force_detect) const +{ + std::lock_guard lock(cache->region_cache_mutex); + if (auto it = cache->region_for_bucket_cache.find(bucket); it != cache->region_for_bucket_cache.end()) + return it->second; + + if (!force_detect && !detect_region) + return ""; + + LOG_INFO(log, "Resolving region for bucket {}", bucket); + Aws::S3::Model::HeadBucketRequest req; + req.SetBucket(bucket); + + std::string region; + auto outcome = HeadBucket(req); + if (outcome.IsSuccess()) + { + const auto & result = outcome.GetResult(); + // region = result.GetRegion(); + } + else + { + static const std::string region_header = "x-amz-bucket-region"; + const auto & headers = outcome.GetError().GetResponseHeaders(); + if (auto it = headers.find(region_header); it != headers.end()) + region = it->second; + } + + if (region.empty()) + { + LOG_INFO(log, "Failed resolving region for bucket {}", bucket); + return ""; + } + + LOG_INFO(log, "Found region {} for bucket {}", region, bucket); + + auto [it, _] = cache->region_for_bucket_cache.emplace(bucket, std::move(region)); + + return it->second; +} + +std::optional<S3::URI> Client::getURIFromError(const Aws::S3::S3Error & error) const +{ + return std::nullopt; +#if 0 + auto endpoint = GetErrorMarshaller()->ExtractEndpoint(error); + if (endpoint.empty()) + return std::nullopt; + + auto & s3_client = const_cast<Client &>(*this); + const auto * endpoint_provider = dynamic_cast<Aws::S3::Endpoint::S3DefaultEpProviderBase *>(s3_client.accessEndpointProvider().get()); + auto resolved_endpoint = endpoint_provider->ResolveEndpoint({}); + + if (!resolved_endpoint.IsSuccess()) + return std::nullopt; + + auto uri = resolved_endpoint.GetResult().GetURI(); + uri.SetAuthority(endpoint); + + return S3::URI(uri.GetURIString()); +#endif +} + +// Do a list request because head requests don't have body in response +std::optional<Aws::S3::S3Error> Client::updateURIForBucketForHead(const std::string & bucket) const 
+{ + ListObjectsV2Request req; + req.SetBucket(bucket); + req.SetMaxKeys(1); + auto result = ListObjectsV2(req); + if (result.IsSuccess()) + return std::nullopt; + return result.GetError(); +} + +std::optional<S3::URI> Client::getURIForBucket(const std::string & bucket) const +{ + std::lock_guard lock(cache->uri_cache_mutex); + if (auto it = cache->uri_for_bucket_cache.find(bucket); it != cache->uri_for_bucket_cache.end()) + return it->second; + + return std::nullopt; +} + +void Client::updateURIForBucket(const std::string & bucket, S3::URI new_uri) const +{ + std::lock_guard lock(cache->uri_cache_mutex); + if (auto it = cache->uri_for_bucket_cache.find(bucket); it != cache->uri_for_bucket_cache.end()) + { + if (it->second.uri == new_uri.uri) + return; + + LOG_INFO(log, "Updating URI for bucket {} to {}", bucket, new_uri.uri.toString()); + it->second = std::move(new_uri); + + return; + } + + LOG_INFO(log, "Updating URI for bucket {} to {}", bucket, new_uri.uri.toString()); + cache->uri_for_bucket_cache.emplace(bucket, std::move(new_uri)); +} + + +void ClientCache::clearCache() +{ + { + std::lock_guard lock(region_cache_mutex); + region_for_bucket_cache.clear(); + } + { + std::lock_guard lock(uri_cache_mutex); + uri_for_bucket_cache.clear(); + } +} + +void ClientCacheRegistry::registerClient(const std::shared_ptr<ClientCache> & client_cache) +{ + std::lock_guard lock(clients_mutex); + auto [it, inserted] = client_caches.emplace(client_cache.get(), client_cache); + if (!inserted) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Same S3 client registered twice"); +} + +void ClientCacheRegistry::unregisterClient(ClientCache * client) +{ + std::lock_guard lock(clients_mutex); + auto erased = client_caches.erase(client); + if (erased == 0) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Can't unregister S3 client, either it was already unregistered or not registered at all"); +} + +void ClientCacheRegistry::clearCacheForAll() +{ + std::lock_guard lock(clients_mutex); + + for (auto it = client_caches.begin(); it != client_caches.end();) + { + if (auto locked_client = it->second.lock(); locked_client) + { + locked_client->clearCache(); + ++it; + } + else + { + LOG_INFO(&Poco::Logger::get("ClientCacheRegistry"), "Deleting leftover S3 client cache"); + it = client_caches.erase(it); + } + } + +} + +ClientFactory::ClientFactory() +{ + aws_options = Aws::SDKOptions{}; + Aws::InitAPI(aws_options); + // Aws::Utils::Logging::InitializeAWSLogging(std::make_shared<AWSLogger>(false)); + Aws::Http::SetHttpClientFactory(std::make_shared<PocoHTTPClientFactory>()); +} + +ClientFactory::~ClientFactory() +{ + // Aws::Utils::Logging::ShutdownAWSLogging(); + Aws::ShutdownAPI(aws_options); +} + +ClientFactory & ClientFactory::instance() +{ + static ClientFactory ret; + return ret; +} + +std::unique_ptr<S3::Client> ClientFactory::create( // NOLINT + const PocoHTTPClientConfiguration & cfg_, + bool is_virtual_hosted_style, + const String & access_key_id, + const String & secret_access_key, + const String & server_side_encryption_customer_key_base64, + ServerSideEncryptionKMSConfig sse_kms_config, + HTTPHeaderEntries headers, + CredentialsConfiguration credentials_configuration, + const String & session_token) +{ + PocoHTTPClientConfiguration client_configuration = cfg_; + client_configuration.updateSchemeAndRegion(); + + if (!server_side_encryption_customer_key_base64.empty()) + { + /// See Client::GeneratePresignedUrlWithSSEC(). 
+ + headers.push_back({Aws::S3::SSEHeaders::SERVER_SIDE_ENCRYPTION_CUSTOMER_ALGORITHM, + Aws::S3::Model::ServerSideEncryptionMapper::GetNameForServerSideEncryption(Aws::S3::Model::ServerSideEncryption::AES256)}); + + headers.push_back({Aws::S3::SSEHeaders::SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY, + server_side_encryption_customer_key_base64}); + + Aws::Utils::ByteBuffer buffer = Aws::Utils::HashingUtils::Base64Decode(server_side_encryption_customer_key_base64); + String str_buffer(reinterpret_cast<char *>(buffer.GetUnderlyingData()), buffer.GetLength()); + headers.push_back({Aws::S3::SSEHeaders::SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY_MD5, + Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateMD5(str_buffer))}); + } + + // These will be added after request signing + client_configuration.extra_headers = std::move(headers); + + Aws::Auth::AWSCredentials credentials(access_key_id, secret_access_key, session_token); + auto credentials_provider = std::make_shared<S3CredentialsProviderChain>( + client_configuration, + std::move(credentials), + credentials_configuration); + + client_configuration.retryStrategy = std::make_shared<Client::RetryStrategy>(std::move(client_configuration.retryStrategy)); + return Client::create( + client_configuration.s3_max_redirects, + std::move(sse_kms_config), + credentials_provider, + client_configuration, // Client configuration. + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, + is_virtual_hosted_style || client_configuration.endpointOverride.empty() /// Use virtual addressing if endpoint is not specified. + ); +} + +PocoHTTPClientConfiguration ClientFactory::createClientConfiguration( // NOLINT + const String & force_region, + const RemoteHostFilter & remote_host_filter, + unsigned int s3_max_redirects, + bool enable_s3_requests_logging, + bool for_disk_s3, + const ThrottlerPtr & get_request_throttler, + const ThrottlerPtr & put_request_throttler, + const String & protocol) +{ + auto proxy_configuration_resolver = DB::ProxyConfigurationResolverProvider::get(DB::ProxyConfiguration::protocolFromString(protocol)); + + auto per_request_configuration = [=] () { return proxy_configuration_resolver->resolve(); }; + auto error_report = [=] (const DB::ProxyConfiguration & req) { proxy_configuration_resolver->errorReport(req); }; + + auto config = PocoHTTPClientConfiguration( + per_request_configuration, + force_region, + remote_host_filter, + s3_max_redirects, + enable_s3_requests_logging, + for_disk_s3, + get_request_throttler, + put_request_throttler, + error_report); + + config.scheme = Aws::Http::SchemeMapper::FromString(protocol.c_str()); + + return config; +} + +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/Client.h b/contrib/clickhouse/src/IO/S3/Client.h new file mode 100644 index 0000000000..721b8dd944 --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/Client.h @@ -0,0 +1,329 @@ +#pragma once + +#include <optional> +#include <base/types.h> + +#include "clickhouse_config.h" + +namespace DB::S3 +{ + +/// See https://docs.aws.amazon.com/AmazonS3/latest/userguide/specifying-kms-encryption.html +/// Needed by S3Common.h even if USE_AWS_S3 is 0 +struct ServerSideEncryptionKMSConfig +{ + // If key_id is non-null, enable SSE-KMS. 
If key_id is "", use the AWS managed key + std::optional<String> key_id = std::nullopt; + std::optional<String> encryption_context = std::nullopt; + std::optional<bool> bucket_key_enabled = std::nullopt; + + bool operator==(const ServerSideEncryptionKMSConfig & other) const = default; +}; + +} + +#if USE_AWS_S3 + +#include <Common/assert_cast.h> +#include <base/scope_guard.h> + +#include <IO/S3/URI.h> +#include <IO/S3/Requests.h> +#include <IO/S3/PocoHTTPClient.h> +#include <IO/S3/Credentials.h> +#include <IO/S3/ProviderType.h> + +#include <aws/core/Aws.h> +#include <aws/core/client/DefaultRetryStrategy.h> +#include <aws/s3/S3Client.h> +// #include <aws/s3/S3ServiceClientModel.h> +#include <aws/core/client/AWSErrorMarshaller.h> +#include <aws/core/client/RetryStrategy.h> + +namespace MockS3 +{ + struct Client; +} + +namespace DB::S3 +{ + +namespace Model = Aws::S3::Model; + +struct ClientCache +{ + ClientCache() = default; + + ClientCache(const ClientCache & other) + : region_for_bucket_cache(other.region_for_bucket_cache) + , uri_for_bucket_cache(other.uri_for_bucket_cache) + {} + + ClientCache(ClientCache && other) = delete; + + ClientCache & operator=(const ClientCache &) = delete; + ClientCache & operator=(ClientCache &&) = delete; + + void clearCache(); + + std::mutex region_cache_mutex; + std::unordered_map<std::string, std::string> region_for_bucket_cache; + + std::mutex uri_cache_mutex; + std::unordered_map<std::string, URI> uri_for_bucket_cache; +}; + +class ClientCacheRegistry +{ +public: + static ClientCacheRegistry & instance() + { + static ClientCacheRegistry registry; + return registry; + } + + void registerClient(const std::shared_ptr<ClientCache> & client_cache); + void unregisterClient(ClientCache * client); + void clearCacheForAll(); +private: + ClientCacheRegistry() = default; + + std::mutex clients_mutex; + std::unordered_map<ClientCache *, std::weak_ptr<ClientCache>> client_caches; +}; + +/// Client that improves the client from the AWS SDK +/// - inject region and URI into requests so they are rerouted to the correct destination if needed +/// - automatically detect endpoint and regions for each bucket and cache them +/// +/// For this client to work correctly both Client::RetryStrategy and Requests defined in <IO/S3/Requests.h> should be used. +/// +/// To add support for new type of request +/// - ExtendedRequest should be defined inside IO/S3/Requests.h +/// - new method accepting that request should be defined in this Client (check other requests for reference) +/// - method handling the request from Aws::S3::S3Client should be left to private so we don't use it by accident +class Client : private Aws::S3::S3Client +{ +public: + class RetryStrategy; + + /// we use a factory method to verify arguments before creating a client because + /// there are certain requirements on arguments for it to work correctly + /// e.g. Client::RetryStrategy should be used + static std::unique_ptr<Client> create( + size_t max_redirects_, + ServerSideEncryptionKMSConfig sse_kms_config_, + const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider, + const PocoHTTPClientConfiguration & client_configuration, + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads, + bool use_virtual_addressing); + + /// Create a client with adjusted settings: + /// * override_retry_strategy can be used to disable retries to avoid nested retries when we have + /// a retry loop outside of S3 client. Specifically, for read and write buffers. 
Currently not + /// actually used. + /// * override_request_timeout_ms is used to increase timeout for CompleteMultipartUploadRequest + /// because it often sits idle for 10 seconds: https://github.com/ClickHouse/ClickHouse/pull/42321 + std::unique_ptr<Client> clone( + std::optional<std::shared_ptr<RetryStrategy>> override_retry_strategy = std::nullopt, + std::optional<Int64> override_request_timeout_ms = std::nullopt) const; + + Client & operator=(const Client &) = delete; + + Client(Client && other) = delete; + Client & operator=(Client &&) = delete; + + ~Client() override + { + try + { + ClientCacheRegistry::instance().unregisterClient(cache.get()); + } + catch (...) + { + tryLogCurrentException(log); + throw; + } + } + + /// Returns the initial endpoint. + const String & getInitialEndpoint() const { return initial_endpoint; } + const String & getRegion() const { return explicit_region; } + + Aws::Auth::AWSCredentials getCredentials() const; + + /// Decorator for RetryStrategy needed for this client to work correctly. + /// We want to manually handle permanent moves (status code 301) because: + /// - redirect location is written in XML format inside the response body something that doesn't exist for HEAD + /// requests so we need to manually find the correct location + /// - we want to cache the new location to decrease number of roundtrips for future requests + /// This decorator doesn't retry if 301 is detected and fallbacks to the inner retry strategy otherwise. + class RetryStrategy : public Aws::Client::RetryStrategy + { + public: + explicit RetryStrategy(std::shared_ptr<Aws::Client::RetryStrategy> wrapped_strategy_); + + /// NOLINTNEXTLINE(google-runtime-int) + bool ShouldRetry(const Aws::Client::AWSError<Aws::Client::CoreErrors>& error, long attemptedRetries) const override; + + /// NOLINTNEXTLINE(google-runtime-int) + long CalculateDelayBeforeNextRetry(const Aws::Client::AWSError<Aws::Client::CoreErrors>& error, long attemptedRetries) const override; + + /// NOLINTNEXTLINE(google-runtime-int) + long GetMaxAttempts() const override; + + void GetSendToken() override; + + // bool HasSendToken() override; + + void RequestBookkeeping(const Aws::Client::HttpResponseOutcome& httpResponseOutcome) override; + void RequestBookkeeping(const Aws::Client::HttpResponseOutcome& httpResponseOutcome, const Aws::Client::AWSError<Aws::Client::CoreErrors>& lastError) override; + private: + std::shared_ptr<Aws::Client::RetryStrategy> wrapped_strategy; + }; + + /// SSE-KMS headers MUST be signed, so they need to be added before the SDK signs the message + /// (before sending the request with one of the methods below). + /// Per the docs (https://docs.aws.amazon.com/AmazonS3/latest/userguide/specifying-kms-encryption.html), + /// the headers should only be set for PutObject, CopyObject, POST Object, and CreateMultipartUpload. 
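As an illustration of where these headers come from, a small sketch of filling ServerSideEncryptionKMSConfig before handing it to ClientFactory::create(); the key id and flag values are placeholders, not taken from the diff:

#include <IO/S3/Client.h>

DB::S3::ServerSideEncryptionKMSConfig makeSseKmsConfig()
{
    DB::S3::ServerSideEncryptionKMSConfig sse;
    sse.key_id = "alias/example-key";   // an empty string would select the AWS managed key
    sse.bucket_key_enabled = true;      // applied by setKMSHeaders() in Client.cpp above
    return sse;
}
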
+ template <typename RequestType> + void setKMSHeaders(RequestType & request) const; + + Model::HeadObjectOutcome HeadObject(const HeadObjectRequest & request) const; + Model::ListObjectsV2Outcome ListObjectsV2(const ListObjectsV2Request & request) const; + Model::ListObjectsOutcome ListObjects(const ListObjectsRequest & request) const; + Model::GetObjectOutcome GetObject(const GetObjectRequest & request) const; + + Model::AbortMultipartUploadOutcome AbortMultipartUpload(const AbortMultipartUploadRequest & request) const; + Model::CreateMultipartUploadOutcome CreateMultipartUpload(const CreateMultipartUploadRequest & request) const; + Model::CompleteMultipartUploadOutcome CompleteMultipartUpload(const CompleteMultipartUploadRequest & request) const; + Model::UploadPartOutcome UploadPart(const UploadPartRequest & request) const; + Model::UploadPartCopyOutcome UploadPartCopy(const UploadPartCopyRequest & request) const; + + Model::CopyObjectOutcome CopyObject(const CopyObjectRequest & request) const; + Model::PutObjectOutcome PutObject(const PutObjectRequest & request) const; + Model::DeleteObjectOutcome DeleteObject(const DeleteObjectRequest & request) const; + Model::DeleteObjectsOutcome DeleteObjects(const DeleteObjectsRequest & request) const; + + using ComposeObjectOutcome = Aws::Utils::Outcome<Aws::NoResult, Aws::S3::S3Error>; + ComposeObjectOutcome ComposeObject(const ComposeObjectRequest & request) const; + + using Aws::S3::S3Client::EnableRequestProcessing; + using Aws::S3::S3Client::DisableRequestProcessing; + + void BuildHttpRequest(const Aws::AmazonWebServiceRequest& request, + const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest) const override; + + bool supportsMultiPartCopy() const; +private: + friend struct ::MockS3::Client; + + Client(size_t max_redirects_, + ServerSideEncryptionKMSConfig sse_kms_config_, + const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider_, + const PocoHTTPClientConfiguration & client_configuration, + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads, + bool use_virtual_addressing); + + Client( + const Client & other, const PocoHTTPClientConfiguration & client_configuration); + + /// Leave regular functions private so we don't accidentally use them + /// otherwise region and endpoint redirection won't work + using Aws::S3::S3Client::HeadObject; + using Aws::S3::S3Client::ListObjectsV2; + using Aws::S3::S3Client::ListObjects; + using Aws::S3::S3Client::GetObject; + + using Aws::S3::S3Client::AbortMultipartUpload; + using Aws::S3::S3Client::CreateMultipartUpload; + using Aws::S3::S3Client::CompleteMultipartUpload; + using Aws::S3::S3Client::UploadPart; + using Aws::S3::S3Client::UploadPartCopy; + + using Aws::S3::S3Client::CopyObject; + using Aws::S3::S3Client::PutObject; + using Aws::S3::S3Client::DeleteObject; + using Aws::S3::S3Client::DeleteObjects; + + template <typename RequestType, typename RequestFn> + std::invoke_result_t<RequestFn, RequestType> + doRequest(const RequestType & request, RequestFn request_fn) const; + + template <bool IsReadMethod, typename RequestType, typename RequestFn> + std::invoke_result_t<RequestFn, RequestType> + doRequestWithRetryNetworkErrors(const RequestType & request, RequestFn request_fn) const; + + void updateURIForBucket(const std::string & bucket, S3::URI new_uri) const; + std::optional<S3::URI> getURIFromError(const Aws::S3::S3Error & error) const; + std::optional<Aws::S3::S3Error> updateURIForBucketForHead(const std::string & bucket) const; + + std::string 
getRegionForBucket(const std::string & bucket, bool force_detect = false) const; + std::optional<S3::URI> getURIForBucket(const std::string & bucket) const; + + bool checkIfWrongRegionDefined(const std::string & bucket, const Aws::S3::S3Error & error, std::string & region) const; + void insertRegionOverride(const std::string & bucket, const std::string & region) const; + + String initial_endpoint; + std::shared_ptr<Aws::Auth::AWSCredentialsProvider> credentials_provider; + PocoHTTPClientConfiguration client_configuration; + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads; + bool use_virtual_addressing; + + std::string explicit_region; + mutable bool detect_region = true; + + /// provider type can determine if some functionality is supported + /// but for same provider, we would need to generate different headers depending on the + /// mode + /// E.g. GCS can work in AWS mode in some cases and accept headers with x-amz prefix + ProviderType provider_type{ProviderType::UNKNOWN}; + ApiMode api_mode{ApiMode::AWS}; + + mutable std::shared_ptr<ClientCache> cache; + + const size_t max_redirects; + + const ServerSideEncryptionKMSConfig sse_kms_config; + + Poco::Logger * log; +}; + +class ClientFactory +{ +public: + ~ClientFactory(); + + static ClientFactory & instance(); + + std::unique_ptr<S3::Client> create( + const PocoHTTPClientConfiguration & cfg, + bool is_virtual_hosted_style, + const String & access_key_id, + const String & secret_access_key, + const String & server_side_encryption_customer_key_base64, + ServerSideEncryptionKMSConfig sse_kms_config, + HTTPHeaderEntries headers, + CredentialsConfiguration credentials_configuration, + const String & session_token = ""); + + PocoHTTPClientConfiguration createClientConfiguration( + const String & force_region, + const RemoteHostFilter & remote_host_filter, + unsigned int s3_max_redirects, + bool enable_s3_requests_logging, + bool for_disk_s3, + const ThrottlerPtr & get_request_throttler, + const ThrottlerPtr & put_request_throttler, + const String & protocol = "https"); + +private: + ClientFactory(); + + Aws::SDKOptions aws_options; + std::atomic<bool> s3_requests_logging_enabled; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/Credentials.cpp b/contrib/clickhouse/src/IO/S3/Credentials.cpp new file mode 100644 index 0000000000..c8820496bf --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/Credentials.cpp @@ -0,0 +1,574 @@ +#include <IO/S3/Credentials.h> + +#if USE_AWS_S3 + +# include <aws/core/Version.h> +# include <aws/core/platform/OSVersionInfo.h> +# include <aws/core/auth/STSCredentialsProvider.h> +# include <aws/core/platform/Environment.h> +# include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h> +# include <aws/core/utils/json/JsonSerializer.h> +# include <aws/core/utils/UUID.h> +# include <aws/core/http/HttpClientFactory.h> + +# include <Common/logger_useful.h> + +# include <IO/S3/PocoHTTPClient.h> +# include <IO/S3/Client.h> + +# include <fstream> +# include <base/EnumReflection.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int AWS_ERROR; +} + +namespace S3 +{ + +namespace +{ + +bool areCredentialsEmptyOrExpired(const Aws::Auth::AWSCredentials & credentials, uint64_t expiration_window_seconds) +{ + if (credentials.IsEmpty()) + return true; + + const Aws::Utils::DateTime now = Aws::Utils::DateTime::Now(); + return now >= credentials.GetExpiration() - std::chrono::seconds(expiration_window_seconds); +} + +} + +AWSEC2MetadataClient::AWSEC2MetadataClient(const 
Aws::Client::ClientConfiguration & client_configuration, const char * endpoint_) + : Aws::Internal::AWSHttpResourceClient(client_configuration) + , endpoint(endpoint_) + , logger(&Poco::Logger::get("AWSEC2InstanceProfileConfigLoader")) +{ +} + +Aws::String AWSEC2MetadataClient::GetResource(const char * resource_path) const +{ + return GetResource(endpoint.c_str(), resource_path, nullptr/*authToken*/); +} + +Aws::String AWSEC2MetadataClient::getDefaultCredentials() const +{ + String credentials_string; + { + std::lock_guard locker(token_mutex); + + LOG_TRACE(logger, "Getting default credentials for ec2 instance from {}", endpoint); + auto result = GetResourceWithAWSWebServiceResult(endpoint.c_str(), EC2_SECURITY_CREDENTIALS_RESOURCE, nullptr); + credentials_string = result.GetPayload(); + if (result.GetResponseCode() == Aws::Http::HttpResponseCode::UNAUTHORIZED) + { + return {}; + } + } + + String trimmed_credentials_string = Aws::Utils::StringUtils::Trim(credentials_string.c_str()); + if (trimmed_credentials_string.empty()) + return {}; + + std::vector<String> security_credentials = Aws::Utils::StringUtils::Split(trimmed_credentials_string, '\n'); + + LOG_DEBUG(logger, "Calling EC2MetadataService resource, {} returned credential string {}.", + EC2_SECURITY_CREDENTIALS_RESOURCE, trimmed_credentials_string); + + if (security_credentials.empty()) + { + LOG_WARNING(logger, "Initial call to EC2MetadataService to get credentials failed."); + return {}; + } + + Aws::StringStream ss; + ss << EC2_SECURITY_CREDENTIALS_RESOURCE << "/" << security_credentials[0]; + LOG_DEBUG(logger, "Calling EC2MetadataService resource {}.", ss.str()); + return GetResource(ss.str().c_str()); +} + +Aws::String AWSEC2MetadataClient::awsComputeUserAgentString() +{ + Aws::StringStream ss; + ss << "aws-sdk-cpp/" << Aws::Version::GetVersionString() << " " << Aws::OSVersionInfo::ComputeOSVersionString() + << " " << Aws::Version::GetCompilerVersionString(); + return ss.str(); +} + +Aws::String AWSEC2MetadataClient::getDefaultCredentialsSecurely() const +{ + String user_agent_string = awsComputeUserAgentString(); + auto [new_token, response_code] = getEC2MetadataToken(user_agent_string); + if (response_code == Aws::Http::HttpResponseCode::BAD_REQUEST) + return {}; + else if (response_code != Aws::Http::HttpResponseCode::OK || new_token.empty()) + { + LOG_TRACE(logger, "Calling EC2MetadataService to get token failed, " + "falling back to less secure way. 
HTTP response code: {}", response_code); + return getDefaultCredentials(); + } + + token = std::move(new_token); + String url = endpoint + EC2_SECURITY_CREDENTIALS_RESOURCE; + std::shared_ptr<Aws::Http::HttpRequest> profile_request(Aws::Http::CreateHttpRequest(url, + Aws::Http::HttpMethod::HTTP_GET, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + profile_request->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, token); + profile_request->SetUserAgent(user_agent_string); + String profile_string = GetResourceWithAWSWebServiceResult(profile_request).GetPayload(); + + String trimmed_profile_string = Aws::Utils::StringUtils::Trim(profile_string.c_str()); + std::vector<String> security_credentials = Aws::Utils::StringUtils::Split(trimmed_profile_string, '\n'); + + LOG_DEBUG(logger, "Calling EC2MetadataService resource, {} with token returned profile string {}.", + EC2_SECURITY_CREDENTIALS_RESOURCE, trimmed_profile_string); + + if (security_credentials.empty()) + { + LOG_WARNING(logger, "Calling EC2Metadataservice to get profiles failed."); + return {}; + } + + Aws::StringStream ss; + ss << endpoint << EC2_SECURITY_CREDENTIALS_RESOURCE << "/" << security_credentials[0]; + std::shared_ptr<Aws::Http::HttpRequest> credentials_request(Aws::Http::CreateHttpRequest(ss.str(), + Aws::Http::HttpMethod::HTTP_GET, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + credentials_request->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, token); + credentials_request->SetUserAgent(user_agent_string); + LOG_DEBUG(logger, "Calling EC2MetadataService resource {} with token.", ss.str()); + return GetResourceWithAWSWebServiceResult(credentials_request).GetPayload(); +} + +Aws::String AWSEC2MetadataClient::getCurrentAvailabilityZone() const +{ + String user_agent_string = awsComputeUserAgentString(); + auto [new_token, response_code] = getEC2MetadataToken(user_agent_string); + if (response_code != Aws::Http::HttpResponseCode::OK || new_token.empty()) + throw DB::Exception(ErrorCodes::AWS_ERROR, + "Failed to make token request. HTTP response code: {}", response_code); + + token = std::move(new_token); + const String url = endpoint + EC2_AVAILABILITY_ZONE_RESOURCE; + std::shared_ptr<Aws::Http::HttpRequest> profile_request( + Aws::Http::CreateHttpRequest(url, Aws::Http::HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + + profile_request->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, token); + profile_request->SetUserAgent(user_agent_string); + + const auto result = GetResourceWithAWSWebServiceResult(profile_request); + if (result.GetResponseCode() != Aws::Http::HttpResponseCode::OK) + throw DB::Exception(ErrorCodes::AWS_ERROR, + "Failed to get availability zone. 
HTTP response code: {}", result.GetResponseCode()); + + return Aws::Utils::StringUtils::Trim(result.GetPayload().c_str()); +} + +std::pair<Aws::String, Aws::Http::HttpResponseCode> AWSEC2MetadataClient::getEC2MetadataToken(const std::string & user_agent_string) const +{ + std::lock_guard locker(token_mutex); + + Aws::StringStream ss; + ss << endpoint << EC2_IMDS_TOKEN_RESOURCE; + std::shared_ptr<Aws::Http::HttpRequest> token_request( + Aws::Http::CreateHttpRequest( + ss.str(), Aws::Http::HttpMethod::HTTP_PUT, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + token_request->SetHeaderValue(EC2_IMDS_TOKEN_TTL_HEADER, EC2_IMDS_TOKEN_TTL_DEFAULT_VALUE); + token_request->SetUserAgent(user_agent_string); + + LOG_TRACE(logger, "Calling EC2MetadataService to get token."); + const auto result = GetResourceWithAWSWebServiceResult(token_request); + const auto & token_string = result.GetPayload(); + return { Aws::Utils::StringUtils::Trim(token_string.c_str()), result.GetResponseCode() }; +} + +Aws::String AWSEC2MetadataClient::getCurrentRegion() const +{ + return Aws::Region::AWS_GLOBAL; +} + +std::shared_ptr<AWSEC2MetadataClient> InitEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration) +{ + Aws::String ec2_metadata_service_endpoint = Aws::Environment::GetEnv("AWS_EC2_METADATA_SERVICE_ENDPOINT"); + auto * logger = &Poco::Logger::get("AWSEC2InstanceProfileConfigLoader"); + if (ec2_metadata_service_endpoint.empty()) + { + Aws::String ec2_metadata_service_endpoint_mode = Aws::Environment::GetEnv("AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE"); + if (ec2_metadata_service_endpoint_mode.length() == 0) + { + ec2_metadata_service_endpoint = "http://169.254.169.254"; //default to IPv4 default endpoint + } + else + { + if (ec2_metadata_service_endpoint_mode.length() == 4) + { + if (Aws::Utils::StringUtils::CaselessCompare(ec2_metadata_service_endpoint_mode.c_str(), "ipv4")) + { + ec2_metadata_service_endpoint = "http://169.254.169.254"; //default to IPv4 default endpoint + } + else if (Aws::Utils::StringUtils::CaselessCompare(ec2_metadata_service_endpoint_mode.c_str(), "ipv6")) + { + ec2_metadata_service_endpoint = "http://[fd00:ec2::254]"; + } + else + { + LOG_ERROR(logger, "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE can only be set to ipv4 or ipv6, received: {}", ec2_metadata_service_endpoint_mode); + } + } + else + { + LOG_ERROR(logger, "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE can only be set to ipv4 or ipv6, received: {}", ec2_metadata_service_endpoint_mode); + } + } + } + LOG_INFO(logger, "Using IMDS endpoint: {}", ec2_metadata_service_endpoint); + return std::make_shared<AWSEC2MetadataClient>(client_configuration, ec2_metadata_service_endpoint.c_str()); +} + +AWSEC2InstanceProfileConfigLoader::AWSEC2InstanceProfileConfigLoader(const std::shared_ptr<AWSEC2MetadataClient> & client_, bool use_secure_pull_) + : client(client_) + , use_secure_pull(use_secure_pull_) + , logger(&Poco::Logger::get("AWSEC2InstanceProfileConfigLoader")) +{ +} + +bool AWSEC2InstanceProfileConfigLoader::LoadInternal() +{ + auto credentials_str = use_secure_pull ? client->getDefaultCredentialsSecurely() : client->getDefaultCredentials(); + + /// See EC2InstanceProfileConfigLoader. 
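+    /// For illustration: the credentials document returned by the metadata service is JSON of roughly this
+    /// shape (values are placeholders), which is why the fields below are read by these exact names:
+    ///     {
+    ///       "Code" : "Success",
+    ///       "AccessKeyId" : "ASIA...",
+    ///       "SecretAccessKey" : "...",
+    ///       "Token" : "...",
+    ///       "Expiration" : "2024-01-01T00:00:00Z"
+    ///     }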
+ if (credentials_str.empty()) + return false; + + Aws::Utils::Json::JsonValue credentials_doc(credentials_str); + if (!credentials_doc.WasParseSuccessful()) + { + LOG_ERROR(logger, "Failed to parse output from EC2MetadataService."); + return false; + } + String access_key, secret_key, token; + + auto credentials_view = credentials_doc.View(); + access_key = credentials_view.GetString("AccessKeyId"); + LOG_TRACE(logger, "Successfully pulled credentials from EC2MetadataService with access key."); + + secret_key = credentials_view.GetString("SecretAccessKey"); + token = credentials_view.GetString("Token"); + + auto region = client->getCurrentRegion(); + + Aws::Config::Profile profile; + profile.SetCredentials(Aws::Auth::AWSCredentials(access_key, secret_key, token)); + profile.SetRegion(region); + profile.SetName(Aws::Config::INSTANCE_PROFILE_KEY); + + m_profiles[Aws::Config::INSTANCE_PROFILE_KEY] = profile; + + return true; +} + +AWSInstanceProfileCredentialsProvider::AWSInstanceProfileCredentialsProvider(const std::shared_ptr<AWSEC2InstanceProfileConfigLoader> & config_loader) + : ec2_metadata_config_loader(config_loader) + , load_frequency_ms(Aws::Auth::REFRESH_THRESHOLD) + , logger(&Poco::Logger::get("AWSInstanceProfileCredentialsProvider")) +{ + LOG_INFO(logger, "Creating Instance with injected EC2MetadataClient and refresh rate."); +} + +Aws::Auth::AWSCredentials AWSInstanceProfileCredentialsProvider::GetAWSCredentials() +{ + refreshIfExpired(); + Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock); + auto profile_it = ec2_metadata_config_loader->GetProfiles().find(Aws::Config::INSTANCE_PROFILE_KEY); + + if (profile_it != ec2_metadata_config_loader->GetProfiles().end()) + { + return profile_it->second.GetCredentials(); + } + + return Aws::Auth::AWSCredentials(); +} + +void AWSInstanceProfileCredentialsProvider::Reload() +{ + LOG_INFO(logger, "Credentials have expired attempting to repull from EC2 Metadata Service."); + ec2_metadata_config_loader->Load(); + AWSCredentialsProvider::Reload(); +} + +void AWSInstanceProfileCredentialsProvider::refreshIfExpired() +{ + Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock); + if (!IsTimeToRefresh(load_frequency_ms)) + { + return; + } + + guard.UpgradeToWriterLock(); + if (!IsTimeToRefresh(load_frequency_ms)) // double-checked lock to avoid refreshing twice + { + return; + } + Reload(); +} + +AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider::AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider( + DB::S3::PocoHTTPClientConfiguration & aws_client_configuration, uint64_t expiration_window_seconds_) + : logger(&Poco::Logger::get("AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider")) + , expiration_window_seconds(expiration_window_seconds_) +{ + // check environment variables + String tmp_region = Aws::Environment::GetEnv("AWS_DEFAULT_REGION"); + role_arn = Aws::Environment::GetEnv("AWS_ROLE_ARN"); + token_file = Aws::Environment::GetEnv("AWS_WEB_IDENTITY_TOKEN_FILE"); + session_name = Aws::Environment::GetEnv("AWS_ROLE_SESSION_NAME"); + + // check profile_config if either m_roleArn or m_tokenFile is not loaded from environment variable + // region source is not enforced, but we need it to construct sts endpoint, if we can't find from environment, we should check if it's set in config file. 
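+    // For reference (not required by this code): on EKS with IAM Roles for Service Accounts the pod usually
+    // has these injected automatically, e.g.
+    //     AWS_ROLE_ARN=arn:aws:iam::111122223333:role/example-role
+    //     AWS_WEB_IDENTITY_TOKEN_FILE=/var/run/secrets/eks.amazonaws.com/serviceaccount/token
+    // so the shared config profile below is typically only consulted for the missing pieces (often just the region).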
+ if (role_arn.empty() || token_file.empty() || tmp_region.empty()) + { + auto profile = Aws::Config::GetCachedConfigProfile(Aws::Auth::GetConfigProfileName()); + if (tmp_region.empty()) + { + tmp_region = profile.GetRegion(); + } + // If either of these two were not found from environment, use whatever found for all three in config file + if (role_arn.empty() || token_file.empty()) + { + role_arn = profile.GetRoleArn(); + token_file = profile.GetValue("web_identity_token_file"); + session_name = profile.GetValue("role_session_name"); + } + } + + if (token_file.empty()) + { + LOG_WARNING(logger, "Token file must be specified to use STS AssumeRole web identity creds provider."); + return; // No need to do further constructing + } + else + { + LOG_DEBUG(logger, "Resolved token_file from profile_config or environment variable to be {}", token_file); + } + + if (role_arn.empty()) + { + LOG_WARNING(logger, "RoleArn must be specified to use STS AssumeRole web identity creds provider."); + return; // No need to do further constructing + } + else + { + LOG_DEBUG(logger, "Resolved role_arn from profile_config or environment variable to be {}", role_arn); + } + + if (tmp_region.empty()) + { + tmp_region = Aws::Region::US_EAST_1; + } + else + { + LOG_DEBUG(logger, "Resolved region from profile_config or environment variable to be {}", tmp_region); + } + + if (session_name.empty()) + { + session_name = Aws::Utils::UUID::RandomUUID(); + } + else + { + LOG_DEBUG(logger, "Resolved session_name from profile_config or environment variable to be {}", session_name); + } + + aws_client_configuration.scheme = Aws::Http::Scheme::HTTPS; + aws_client_configuration.region = tmp_region; + + std::vector<String> retryable_errors; + retryable_errors.push_back("IDPCommunicationError"); + retryable_errors.push_back("InvalidIdentityToken"); + + aws_client_configuration.retryStrategy = std::make_shared<Aws::Client::SpecifiedRetryableErrorsRetryStrategy>( + retryable_errors, /* maxRetries = */3); + + client = std::make_unique<Aws::Internal::STSCredentialsClient>(aws_client_configuration); + initialized = true; + LOG_INFO(logger, "Creating STS AssumeRole with web identity creds provider."); +} + +Aws::Auth::AWSCredentials AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials() +{ + // A valid client means required information like role arn and token file were constructed correctly. + // We can use this provider to load creds, otherwise, we can just return empty creds. 
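+    // Returning empty credentials is harmless here: AWSCredentialsProviderChain skips providers that
+    // yield empty credentials and simply tries the next provider in the chain.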
+ if (!initialized) + { + return Aws::Auth::AWSCredentials(); + } + refreshIfExpired(); + Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock); + return credentials; +} + +void AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider::Reload() +{ + LOG_INFO(logger, "Credentials have expired, attempting to renew from STS."); + + std::ifstream token_stream(token_file.data()); + if (token_stream) + { + String token_string((std::istreambuf_iterator<char>(token_stream)), std::istreambuf_iterator<char>()); + token = token_string; + } + else + { + LOG_INFO(logger, "Can't open token file: {}", token_file); + return; + } + Aws::Internal::STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request{session_name, role_arn, token}; + + auto result = client->GetAssumeRoleWithWebIdentityCredentials(request); + LOG_TRACE(logger, "Successfully retrieved credentials."); + credentials = result.creds; +} + +void AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider::refreshIfExpired() +{ + Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock); + if (!areCredentialsEmptyOrExpired(credentials, expiration_window_seconds)) + return; + + guard.UpgradeToWriterLock(); + if (!areCredentialsEmptyOrExpired(credentials, expiration_window_seconds)) // double-checked lock to avoid refreshing twice + return; + + Reload(); +} + +S3CredentialsProviderChain::S3CredentialsProviderChain( + const DB::S3::PocoHTTPClientConfiguration & configuration, + const Aws::Auth::AWSCredentials & credentials, + CredentialsConfiguration credentials_configuration) +{ + auto * logger = &Poco::Logger::get("S3CredentialsProviderChain"); + + /// we don't provide any credentials to avoid signing + if (credentials_configuration.no_sign_request) + return; + + /// add explicit credentials to the front of the chain + /// because it's manually defined by the user + if (!credentials.IsEmpty()) + { + AddProvider(std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(credentials)); + return; + } + + if (credentials_configuration.use_environment_credentials) + { + static const char AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI[] = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"; + static const char AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI[] = "AWS_CONTAINER_CREDENTIALS_FULL_URI"; + static const char AWS_ECS_CONTAINER_AUTHORIZATION_TOKEN[] = "AWS_CONTAINER_AUTHORIZATION_TOKEN"; + static const char AWS_EC2_METADATA_DISABLED[] = "AWS_EC2_METADATA_DISABLED"; + + /// The only difference from DefaultAWSCredentialsProviderChain::DefaultAWSCredentialsProviderChain() + /// is that this chain uses custom ClientConfiguration. Also we removed process provider because it's useless in our case. + /// + /// AWS API tries credentials providers one by one. Some of providers (like ProfileConfigFileAWSCredentialsProvider) can be + /// quite verbose even if nobody configured them. So we use our provider first and only after it use default providers. 
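+        /// Summarising the code below, with use_environment_credentials enabled the chain becomes:
+        ///   1. STS AssumeRoleWithWebIdentity (our provider, built on top of our HTTP client)
+        ///   2. Environment variables (EnvironmentAWSCredentialsProvider)
+        ///   3. ECS task role via the container credentials URI, or, when that is not set and IMDS is not
+        ///      disabled, the EC2 instance profile
+        ///   4. Shared config/credentials file (appended last in any case, because it is the noisiest)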
+ { + DB::S3::PocoHTTPClientConfiguration aws_client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration( + configuration.region, + configuration.remote_host_filter, + configuration.s3_max_redirects, + configuration.enable_s3_requests_logging, + configuration.for_disk_s3, + configuration.get_request_throttler, + configuration.put_request_throttler); + AddProvider(std::make_shared<AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider>(aws_client_configuration, credentials_configuration.expiration_window_seconds)); + } + + AddProvider(std::make_shared<Aws::Auth::EnvironmentAWSCredentialsProvider>()); + + + /// ECS TaskRole Credentials only available when ENVIRONMENT VARIABLE is set. + const auto relative_uri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI); + LOG_DEBUG(logger, "The environment variable value {} is {}", AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI, + relative_uri); + + const auto absolute_uri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI); + LOG_DEBUG(logger, "The environment variable value {} is {}", AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI, + absolute_uri); + + const auto ec2_metadata_disabled = Aws::Environment::GetEnv(AWS_EC2_METADATA_DISABLED); + LOG_DEBUG(logger, "The environment variable value {} is {}", AWS_EC2_METADATA_DISABLED, + ec2_metadata_disabled); + + if (!relative_uri.empty()) + { + AddProvider(std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>(relative_uri.c_str())); + LOG_INFO(logger, "Added ECS metadata service credentials provider with relative path: [{}] to the provider chain.", + relative_uri); + } + else if (!absolute_uri.empty()) + { + const auto token = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_AUTHORIZATION_TOKEN); + AddProvider(std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>(absolute_uri.c_str(), token.c_str())); + + /// DO NOT log the value of the authorization token for security purposes. + LOG_INFO(logger, "Added ECS credentials provider with URI: [{}] to the provider chain with a{} authorization token.", + absolute_uri, token.empty() ? "n empty" : " non-empty"); + } + else if (Aws::Utils::StringUtils::ToLower(ec2_metadata_disabled.c_str()) != "true") + { + DB::S3::PocoHTTPClientConfiguration aws_client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration( + configuration.region, + configuration.remote_host_filter, + configuration.s3_max_redirects, + configuration.enable_s3_requests_logging, + configuration.for_disk_s3, + configuration.get_request_throttler, + configuration.put_request_throttler, + Aws::Http::SchemeMapper::ToString(Aws::Http::Scheme::HTTP)); + + /// See MakeDefaultHttpResourceClientConfiguration(). + /// This is part of EC2 metadata client, but unfortunately it can't be accessed from outside + /// of contrib/aws/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp + aws_client_configuration.maxConnections = 2; + + /// Explicitly set the proxy settings to empty/zero to avoid relying on defaults that could potentially change + /// in the future. + aws_client_configuration.proxyHost = ""; + aws_client_configuration.proxyUserName = ""; + aws_client_configuration.proxyPassword = ""; + aws_client_configuration.proxyPort = 0; + + /// EC2MetadataService throttles by delaying the response so the service client should set a large read timeout. + /// EC2MetadataService delay is in order of seconds so it only make sense to retry after a couple of seconds. 
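+            /// Hence the conservative values below: 1 second connect/request timeouts and a single retry
+            /// with a 1000 ms backoff scale factor.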
+ aws_client_configuration.connectTimeoutMs = 1000; + aws_client_configuration.requestTimeoutMs = 1000; + + aws_client_configuration.retryStrategy = std::make_shared<Aws::Client::DefaultRetryStrategy>(1, 1000); + + auto ec2_metadata_client = InitEC2MetadataClient(aws_client_configuration); + auto config_loader = std::make_shared<AWSEC2InstanceProfileConfigLoader>(ec2_metadata_client, !credentials_configuration.use_insecure_imds_request); + + AddProvider(std::make_shared<AWSInstanceProfileCredentialsProvider>(config_loader)); + LOG_INFO(logger, "Added EC2 metadata service credentials provider to the provider chain."); + } + } + + /// Quite verbose provider (argues if file with credentials doesn't exist) so iut's the last one + /// in chain. + AddProvider(std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>()); +} + +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/Credentials.h b/contrib/clickhouse/src/IO/S3/Credentials.h new file mode 100644 index 0000000000..429941cd84 --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/Credentials.h @@ -0,0 +1,146 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +# include <aws/core/client/ClientConfiguration.h> +# include <aws/core/internal/AWSHttpResourceClient.h> +# include <aws/core/config/AWSProfileConfigLoader.h> +# include <aws/core/auth/AWSCredentialsProviderChain.h> + +# include <IO/S3/PocoHTTPClient.h> + + +namespace DB::S3 +{ + +inline static constexpr uint64_t DEFAULT_EXPIRATION_WINDOW_SECONDS = 120; + +class AWSEC2MetadataClient : public Aws::Internal::AWSHttpResourceClient +{ + static constexpr char EC2_SECURITY_CREDENTIALS_RESOURCE[] = "/latest/meta-data/iam/security-credentials"; + static constexpr char EC2_AVAILABILITY_ZONE_RESOURCE[] = "/latest/meta-data/placement/availability-zone"; + static constexpr char EC2_IMDS_TOKEN_RESOURCE[] = "/latest/api/token"; + static constexpr char EC2_IMDS_TOKEN_HEADER[] = "x-aws-ec2-metadata-token"; + static constexpr char EC2_IMDS_TOKEN_TTL_DEFAULT_VALUE[] = "21600"; + static constexpr char EC2_IMDS_TOKEN_TTL_HEADER[] = "x-aws-ec2-metadata-token-ttl-seconds"; + +public: + /// See EC2MetadataClient. 
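+    /// Sketch of the IMDSv2 flow implemented below (see getEC2MetadataToken / getDefaultCredentialsSecurely):
+    ///     PUT {endpoint}/latest/api/token                                  with x-aws-ec2-metadata-token-ttl-seconds: 21600
+    ///     GET {endpoint}/latest/meta-data/iam/security-credentials         with x-aws-ec2-metadata-token: <token>  -> role name
+    ///     GET {endpoint}/latest/meta-data/iam/security-credentials/<role>  with x-aws-ec2-metadata-token: <token>  -> credentials JSON
+    /// If the token request fails with 400 the secure path gives up; on other failures it falls back to the
+    /// unauthenticated (IMDSv1) path.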
+ + explicit AWSEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration, const char * endpoint_); + + AWSEC2MetadataClient& operator =(const AWSEC2MetadataClient & rhs) = delete; + AWSEC2MetadataClient(const AWSEC2MetadataClient & rhs) = delete; + AWSEC2MetadataClient& operator =(const AWSEC2MetadataClient && rhs) = delete; + AWSEC2MetadataClient(const AWSEC2MetadataClient && rhs) = delete; + + ~AWSEC2MetadataClient() override = default; + + using Aws::Internal::AWSHttpResourceClient::GetResource; + + virtual Aws::String GetResource(const char * resource_path) const; + virtual Aws::String getDefaultCredentials() const; + + static Aws::String awsComputeUserAgentString(); + + virtual Aws::String getDefaultCredentialsSecurely() const; + + virtual Aws::String getCurrentRegion() const; + + virtual Aws::String getCurrentAvailabilityZone() const; + +private: + std::pair<Aws::String, Aws::Http::HttpResponseCode> getEC2MetadataToken(const std::string & user_agent_string) const; + + const Aws::String endpoint; + mutable std::recursive_mutex token_mutex; + mutable Aws::String token; + Poco::Logger * logger; +}; + +std::shared_ptr<AWSEC2MetadataClient> InitEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration); + +class AWSEC2InstanceProfileConfigLoader : public Aws::Config::AWSProfileConfigLoader +{ +public: + explicit AWSEC2InstanceProfileConfigLoader(const std::shared_ptr<AWSEC2MetadataClient> & client_, bool use_secure_pull_); + + ~AWSEC2InstanceProfileConfigLoader() override = default; + +protected: + bool LoadInternal() override; + +private: + std::shared_ptr<AWSEC2MetadataClient> client; + bool use_secure_pull; + Poco::Logger * logger; +}; + +class AWSInstanceProfileCredentialsProvider : public Aws::Auth::AWSCredentialsProvider +{ +public: + /// See InstanceProfileCredentialsProvider. + + explicit AWSInstanceProfileCredentialsProvider(const std::shared_ptr<AWSEC2InstanceProfileConfigLoader> & config_loader); + + Aws::Auth::AWSCredentials GetAWSCredentials() override; +protected: + void Reload() override; + +private: + void refreshIfExpired(); + + std::shared_ptr<AWSEC2InstanceProfileConfigLoader> ec2_metadata_config_loader; + Int64 load_frequency_ms; + Poco::Logger * logger; +}; + +class AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider : public Aws::Auth::AWSCredentialsProvider +{ + /// See STSAssumeRoleWebIdentityCredentialsProvider. 
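+    /// Credentials are (re)loaded by reading the web identity token file and calling STS AssumeRoleWithWebIdentity;
+    /// a refresh is triggered when the cached credentials are empty or within expiration_window_seconds
+    /// (DEFAULT_EXPIRATION_WINDOW_SECONDS = 120 by default) of their expiration time.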
+ +public: + explicit AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider( + DB::S3::PocoHTTPClientConfiguration & aws_client_configuration, uint64_t expiration_window_seconds_); + + Aws::Auth::AWSCredentials GetAWSCredentials() override; + +protected: + void Reload() override; + +private: + void refreshIfExpired(); + + std::unique_ptr<Aws::Internal::STSCredentialsClient> client; + Aws::Auth::AWSCredentials credentials; + Aws::String role_arn; + Aws::String token_file; + Aws::String session_name; + Aws::String token; + bool initialized = false; + Poco::Logger * logger; + uint64_t expiration_window_seconds; +}; + +struct CredentialsConfiguration +{ + bool use_environment_credentials = false; + bool use_insecure_imds_request = false; + uint64_t expiration_window_seconds = DEFAULT_EXPIRATION_WINDOW_SECONDS; + bool no_sign_request = false; +}; + +class S3CredentialsProviderChain : public Aws::Auth::AWSCredentialsProviderChain +{ +public: + S3CredentialsProviderChain( + const DB::S3::PocoHTTPClientConfiguration & configuration, + const Aws::Auth::AWSCredentials & credentials, + CredentialsConfiguration credentials_configuration); +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/PocoHTTPClient.cpp b/contrib/clickhouse/src/IO/S3/PocoHTTPClient.cpp new file mode 100644 index 0000000000..a61f88c4af --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/PocoHTTPClient.cpp @@ -0,0 +1,558 @@ +#include <Poco/Timespan.h> +#include "Common/DNSResolver.h" +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include "PocoHTTPClient.h" + +#include <utility> +#include <algorithm> +#include <functional> + +#include <Common/logger_useful.h> +#include <Common/Stopwatch.h> +#include <Common/Throttler.h> +#include <IO/HTTPCommon.h> +#include <IO/WriteBufferFromString.h> +#include <IO/Operators.h> +#include <IO/S3/ProviderType.h> + +#include <aws/core/http/HttpRequest.h> +#include <aws/core/http/HttpResponse.h> +#include <aws/core/utils/xml/XmlSerializer.h> +#include <aws/core/monitoring/HttpClientMetrics.h> +#include <aws/core/utils/ratelimiter/RateLimiterInterface.h> +#include "Poco/StreamCopier.h" +#include <Poco/Net/HTTPRequest.h> +#include <Poco/Net/HTTPResponse.h> +#include <re2/re2.h> + +#include <boost/algorithm/string.hpp> + +static const int SUCCESS_RESPONSE_MIN = 200; +static const int SUCCESS_RESPONSE_MAX = 299; + +namespace ProfileEvents +{ + extern const Event S3ReadMicroseconds; + extern const Event S3ReadRequestsCount; + extern const Event S3ReadRequestsErrors; + extern const Event S3ReadRequestsThrottling; + extern const Event S3ReadRequestsRedirects; + + extern const Event S3WriteMicroseconds; + extern const Event S3WriteRequestsCount; + extern const Event S3WriteRequestsErrors; + extern const Event S3WriteRequestsThrottling; + extern const Event S3WriteRequestsRedirects; + + extern const Event DiskS3ReadMicroseconds; + extern const Event DiskS3ReadRequestsCount; + extern const Event DiskS3ReadRequestsErrors; + extern const Event DiskS3ReadRequestsThrottling; + extern const Event DiskS3ReadRequestsRedirects; + + extern const Event DiskS3WriteMicroseconds; + extern const Event DiskS3WriteRequestsCount; + extern const Event DiskS3WriteRequestsErrors; + extern const Event DiskS3WriteRequestsThrottling; + extern const Event DiskS3WriteRequestsRedirects; + + extern const Event S3GetRequestThrottlerCount; + extern const Event S3GetRequestThrottlerSleepMicroseconds; + extern const Event S3PutRequestThrottlerCount; + extern const Event S3PutRequestThrottlerSleepMicroseconds; + + extern const Event 
DiskS3GetRequestThrottlerCount; + extern const Event DiskS3GetRequestThrottlerSleepMicroseconds; + extern const Event DiskS3PutRequestThrottlerCount; + extern const Event DiskS3PutRequestThrottlerSleepMicroseconds; +} + +namespace CurrentMetrics +{ + extern const Metric S3Requests; +} + +namespace DB::ErrorCodes +{ + extern const int NOT_IMPLEMENTED; + extern const int TOO_MANY_REDIRECTS; +} + +namespace DB::S3 +{ + +PocoHTTPClientConfiguration::PocoHTTPClientConfiguration( + std::function<DB::ProxyConfiguration()> per_request_configuration_, + const String & force_region_, + const RemoteHostFilter & remote_host_filter_, + unsigned int s3_max_redirects_, + bool enable_s3_requests_logging_, + bool for_disk_s3_, + const ThrottlerPtr & get_request_throttler_, + const ThrottlerPtr & put_request_throttler_, + std::function<void(const DB::ProxyConfiguration &)> error_report_) + : per_request_configuration(per_request_configuration_) + , force_region(force_region_) + , remote_host_filter(remote_host_filter_) + , s3_max_redirects(s3_max_redirects_) + , enable_s3_requests_logging(enable_s3_requests_logging_) + , for_disk_s3(for_disk_s3_) + , get_request_throttler(get_request_throttler_) + , put_request_throttler(put_request_throttler_) + , error_report(error_report_) +{ +} + +void PocoHTTPClientConfiguration::updateSchemeAndRegion() +{ + if (!endpointOverride.empty()) + { + static const RE2 region_pattern(R"(^s3[.\-]([a-z0-9\-]+)\.amazonaws\.)"); + Poco::URI uri(endpointOverride); + if (uri.getScheme() == "http") + scheme = Aws::Http::Scheme::HTTP; + + if (force_region.empty()) + { + String matched_region; + if (re2::RE2::PartialMatch(uri.getHost(), region_pattern, &matched_region)) + { + boost::algorithm::to_lower(matched_region); + region = matched_region; + } + else + { + /// In global mode AWS C++ SDK send `us-east-1` but accept switching to another one if being suggested. + region = Aws::Region::AWS_GLOBAL; + } + } + else + { + region = force_region; + } + } +} + + +PocoHTTPClient::PocoHTTPClient(const PocoHTTPClientConfiguration & client_configuration) + : per_request_configuration(client_configuration.per_request_configuration) + , error_report(client_configuration.error_report) + , timeouts(ConnectionTimeouts( + Poco::Timespan(client_configuration.connectTimeoutMs * 1000), /// connection timeout. + Poco::Timespan(client_configuration.requestTimeoutMs * 1000), /// send timeout. + Poco::Timespan(client_configuration.requestTimeoutMs * 1000), /// receive timeout. + Poco::Timespan(client_configuration.enableTcpKeepAlive ? 
client_configuration.tcpKeepAliveIntervalMs * 1000 : 0), + Poco::Timespan(client_configuration.http_keep_alive_timeout_ms * 1000))) /// flag indicating whether keep-alive is enabled is set to each session upon creation + , remote_host_filter(client_configuration.remote_host_filter) + , s3_max_redirects(client_configuration.s3_max_redirects) + , enable_s3_requests_logging(client_configuration.enable_s3_requests_logging) + , for_disk_s3(client_configuration.for_disk_s3) + , get_request_throttler(client_configuration.get_request_throttler) + , put_request_throttler(client_configuration.put_request_throttler) + , extra_headers(client_configuration.extra_headers) + , http_connection_pool_size(client_configuration.http_connection_pool_size) + , wait_on_pool_size_limit(client_configuration.wait_on_pool_size_limit) +{ +} + +std::shared_ptr<Aws::Http::HttpResponse> PocoHTTPClient::MakeRequest( + const std::shared_ptr<Aws::Http::HttpRequest> & request, + Aws::Utils::RateLimits::RateLimiterInterface * readLimiter, + Aws::Utils::RateLimits::RateLimiterInterface * writeLimiter) const +{ + try + { + auto response = Aws::MakeShared<PocoHTTPResponse>("PocoHTTPClient", request); + makeRequestInternal(*request, response, readLimiter, writeLimiter); + return response; + } + catch (const Exception &) + { + throw; + } + catch (const Poco::Exception & e) + { + throw Exception(Exception::CreateFromPocoTag{}, e); + } + catch (const std::exception & e) + { + throw Exception(Exception::CreateFromSTDTag{}, e); + } + catch (...) + { + tryLogCurrentException(__PRETTY_FUNCTION__); + throw; + } +} + +namespace +{ + /// No comments: + /// 1) https://aws.amazon.com/premiumsupport/knowledge-center/s3-resolve-200-internalerror/ + /// 2) https://github.com/aws/aws-sdk-cpp/issues/658 + bool checkRequestCanReturn2xxAndErrorInBody(Aws::Http::HttpRequest & request) + { + auto query_params = request.GetQueryStringParameters(); + if (request.HasHeader("x-amz-copy-source") || request.HasHeader("x-goog-copy-source")) + { + /// CopyObject https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html + if (query_params.empty()) + return true; + + /// UploadPartCopy https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPartCopy.html + if (query_params.contains("partNumber") && query_params.contains("uploadId")) + return true; + + } + else + { + /// CompleteMultipartUpload https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html + if (query_params.size() == 1 && query_params.contains("uploadId")) + return true; + } + + return false; + } +} + +PocoHTTPClient::S3MetricKind PocoHTTPClient::getMetricKind(const Aws::Http::HttpRequest & request) +{ + switch (request.GetMethod()) + { + case Aws::Http::HttpMethod::HTTP_GET: + case Aws::Http::HttpMethod::HTTP_HEAD: + return S3MetricKind::Read; + case Aws::Http::HttpMethod::HTTP_POST: + case Aws::Http::HttpMethod::HTTP_DELETE: + case Aws::Http::HttpMethod::HTTP_PUT: + case Aws::Http::HttpMethod::HTTP_PATCH: + return S3MetricKind::Write; + } + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported request method"); +} + +void PocoHTTPClient::addMetric(const Aws::Http::HttpRequest & request, S3MetricType type, ProfileEvents::Count amount) const +{ + const ProfileEvents::Event events_map[static_cast<size_t>(S3MetricType::EnumSize)][static_cast<size_t>(S3MetricKind::EnumSize)] = { + {ProfileEvents::S3ReadMicroseconds, ProfileEvents::S3WriteMicroseconds}, + {ProfileEvents::S3ReadRequestsCount, ProfileEvents::S3WriteRequestsCount}, + 
{ProfileEvents::S3ReadRequestsErrors, ProfileEvents::S3WriteRequestsErrors}, + {ProfileEvents::S3ReadRequestsThrottling, ProfileEvents::S3WriteRequestsThrottling}, + {ProfileEvents::S3ReadRequestsRedirects, ProfileEvents::S3WriteRequestsRedirects}, + }; + + const ProfileEvents::Event disk_s3_events_map[static_cast<size_t>(S3MetricType::EnumSize)][static_cast<size_t>(S3MetricKind::EnumSize)] = { + {ProfileEvents::DiskS3ReadMicroseconds, ProfileEvents::DiskS3WriteMicroseconds}, + {ProfileEvents::DiskS3ReadRequestsCount, ProfileEvents::DiskS3WriteRequestsCount}, + {ProfileEvents::DiskS3ReadRequestsErrors, ProfileEvents::DiskS3WriteRequestsErrors}, + {ProfileEvents::DiskS3ReadRequestsThrottling, ProfileEvents::DiskS3WriteRequestsThrottling}, + {ProfileEvents::DiskS3ReadRequestsRedirects, ProfileEvents::DiskS3WriteRequestsRedirects}, + }; + + S3MetricKind kind = getMetricKind(request); + + ProfileEvents::increment(events_map[static_cast<unsigned int>(type)][static_cast<unsigned int>(kind)], amount); + if (for_disk_s3) + ProfileEvents::increment(disk_s3_events_map[static_cast<unsigned int>(type)][static_cast<unsigned int>(kind)], amount); +} + +void PocoHTTPClient::makeRequestInternal( + Aws::Http::HttpRequest & request, + std::shared_ptr<PocoHTTPResponse> & response, + Aws::Utils::RateLimits::RateLimiterInterface * readLimiter, + Aws::Utils::RateLimits::RateLimiterInterface * writeLimiter) const +{ + /// Most sessions in pool are already connected and it is not possible to set proxy host/port to a connected session. + const auto request_configuration = per_request_configuration(); + if (http_connection_pool_size && request_configuration.host.empty()) + makeRequestInternalImpl<true>(request, request_configuration, response, readLimiter, writeLimiter); + else + makeRequestInternalImpl<false>(request, request_configuration, response, readLimiter, writeLimiter); +} + +template <bool pooled> +void PocoHTTPClient::makeRequestInternalImpl( + Aws::Http::HttpRequest & request, + const DB::ProxyConfiguration & request_configuration, + std::shared_ptr<PocoHTTPResponse> & response, + Aws::Utils::RateLimits::RateLimiterInterface *, + Aws::Utils::RateLimits::RateLimiterInterface *) const +{ + using SessionPtr = std::conditional_t<pooled, PooledHTTPSessionPtr, HTTPSessionPtr>; + + Poco::Logger * log = &Poco::Logger::get("AWSClient"); + + auto uri = request.GetUri().GetURIString(); + + if (enable_s3_requests_logging) + LOG_TEST(log, "Make request to: {}", uri); + + switch (request.GetMethod()) + { + case Aws::Http::HttpMethod::HTTP_GET: + case Aws::Http::HttpMethod::HTTP_HEAD: + if (get_request_throttler) + { + UInt64 sleep_us = get_request_throttler->add(1, ProfileEvents::S3GetRequestThrottlerCount, ProfileEvents::S3GetRequestThrottlerSleepMicroseconds); + if (for_disk_s3) + { + ProfileEvents::increment(ProfileEvents::DiskS3GetRequestThrottlerCount); + ProfileEvents::increment(ProfileEvents::DiskS3GetRequestThrottlerSleepMicroseconds, sleep_us); + } + } + break; + case Aws::Http::HttpMethod::HTTP_PUT: + case Aws::Http::HttpMethod::HTTP_POST: + case Aws::Http::HttpMethod::HTTP_PATCH: + if (put_request_throttler) + { + UInt64 sleep_us = put_request_throttler->add(1, ProfileEvents::S3PutRequestThrottlerCount, ProfileEvents::S3PutRequestThrottlerSleepMicroseconds); + if (for_disk_s3) + { + ProfileEvents::increment(ProfileEvents::DiskS3PutRequestThrottlerCount); + ProfileEvents::increment(ProfileEvents::DiskS3PutRequestThrottlerSleepMicroseconds, sleep_us); + } + } + break; + case 
Aws::Http::HttpMethod::HTTP_DELETE: + break; // Not throttled + } + + addMetric(request, S3MetricType::Count); + CurrentMetrics::Increment metric_increment{CurrentMetrics::S3Requests}; + + try + { + for (unsigned int attempt = 0; attempt <= s3_max_redirects; ++attempt) + { + Poco::URI target_uri(uri); + SessionPtr session; + + if (!request_configuration.host.empty()) + { + if (enable_s3_requests_logging) + LOG_TEST(log, "Due to reverse proxy host name ({}) won't be resolved on ClickHouse side", uri); + + /// Reverse proxy can replace host header with resolved ip address instead of host name. + /// This can lead to request signature difference on S3 side. + if constexpr (pooled) + session = makePooledHTTPSession( + target_uri, timeouts, http_connection_pool_size, wait_on_pool_size_limit); + else + session = makeHTTPSession(target_uri, timeouts); + bool use_tunnel = request_configuration.protocol == DB::ProxyConfiguration::Protocol::HTTP && target_uri.getScheme() == "https"; + + // session->setProxy( + // request_configuration.proxy_host, + // request_configuration.proxy_port + // ); + } + else + { + if constexpr (pooled) + session = makePooledHTTPSession( + target_uri, timeouts, http_connection_pool_size, wait_on_pool_size_limit); + else + session = makeHTTPSession(target_uri, timeouts); + } + + /// In case of error this address will be written to logs + // request.SetResolvedRemoteHost(session->getResolvedAddress()); + + Poco::Net::HTTPRequest poco_request(Poco::Net::HTTPRequest::HTTP_1_1); + + /** According to RFC-2616, Request-URI is allowed to be encoded. + * However, there is no clear agreement on which exact symbols must be encoded. + * Effectively, `Poco::URI` chooses smaller subset of characters to encode, + * whereas Amazon S3 and Google Cloud Storage expects another one. + * In order to successfully execute a request, a path must be exact representation + * of decoded path used by `AWSAuthSigner`. + * Therefore we shall encode some symbols "manually" to fit the signatures. + */ + + std::string path_and_query; + const std::string & query = target_uri.getRawQuery(); + const std::string reserved = "?#:;+@&=%"; /// Poco::URI::RESERVED_QUERY_PARAM without '/' plus percent sign. + Poco::URI::encode(target_uri.getPath(), reserved, path_and_query); + + if (!query.empty()) + { + path_and_query += '?'; + path_and_query += query; + } + + /// `target_uri.getPath()` could return an empty string, but a proper HTTP request must + /// always contain a non-empty URI in its first line (e.g. "POST / HTTP/1.1"). + if (path_and_query.empty()) + path_and_query = "/"; + + poco_request.setURI(path_and_query); + + switch (request.GetMethod()) + { + case Aws::Http::HttpMethod::HTTP_GET: + poco_request.setMethod(Poco::Net::HTTPRequest::HTTP_GET); + break; + case Aws::Http::HttpMethod::HTTP_POST: + poco_request.setMethod(Poco::Net::HTTPRequest::HTTP_POST); + break; + case Aws::Http::HttpMethod::HTTP_DELETE: + poco_request.setMethod(Poco::Net::HTTPRequest::HTTP_DELETE); + break; + case Aws::Http::HttpMethod::HTTP_PUT: + poco_request.setMethod(Poco::Net::HTTPRequest::HTTP_PUT); + break; + case Aws::Http::HttpMethod::HTTP_HEAD: + poco_request.setMethod(Poco::Net::HTTPRequest::HTTP_HEAD); + break; + case Aws::Http::HttpMethod::HTTP_PATCH: + poco_request.setMethod(Poco::Net::HTTPRequest::HTTP_PATCH); + break; + } + + /// Headers coming from SDK are lower-cased. 
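+            /// Extra headers are therefore lower-cased below as well, e.g. a configured "X-Amz-Storage-Class"
+            /// is sent as "x-amz-storage-class", so casing stays consistent with the SDK-provided headers.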
+ for (const auto & [header_name, header_value] : request.GetHeaders()) + poco_request.set(header_name, header_value); + for (const auto & [header_name, header_value] : extra_headers) + poco_request.set(boost::algorithm::to_lower_copy(header_name), header_value); + + Poco::Net::HTTPResponse poco_response; + + Stopwatch watch; + + auto & request_body_stream = session->sendRequest(poco_request); + + if (request.GetContentBody()) + { + if (enable_s3_requests_logging) + LOG_TEST(log, "Writing request body."); + + /// Rewind content body buffer. + /// NOTE: we should do that always (even if `attempt == 0`) because the same request can be retried also by AWS, + /// see retryStrategy in Aws::Client::ClientConfiguration. + request.GetContentBody()->clear(); + request.GetContentBody()->seekg(0); + + auto size = Poco::StreamCopier::copyStream(*request.GetContentBody(), request_body_stream); + if (enable_s3_requests_logging) + LOG_TEST(log, "Written {} bytes to request body", size); + } + + if (enable_s3_requests_logging) + LOG_TEST(log, "Receiving response..."); + auto & response_body_stream = session->receiveResponse(poco_response); + + watch.stop(); + addMetric(request, S3MetricType::Microseconds, watch.elapsedMicroseconds()); + + int status_code = static_cast<int>(poco_response.getStatus()); + + if (status_code >= SUCCESS_RESPONSE_MIN && status_code <= SUCCESS_RESPONSE_MAX) + { + if (enable_s3_requests_logging) + LOG_TEST(log, "Response status: {}, {}", status_code, poco_response.getReason()); + } + else + { + /// Error statuses are more important so we show them even if `enable_s3_requests_logging == false`. + LOG_INFO(log, "Response status: {}, {}", status_code, poco_response.getReason()); + } + + if (poco_response.getStatus() == Poco::Net::HTTPResponse::HTTP_TEMPORARY_REDIRECT) + { + auto location = poco_response.get("location"); + remote_host_filter.checkURL(Poco::URI(location)); + uri = location; + if (enable_s3_requests_logging) + LOG_TEST(log, "Redirecting request to new location: {}", location); + + addMetric(request, S3MetricType::Redirects); + + continue; + } + + response->SetResponseCode(static_cast<Aws::Http::HttpResponseCode>(status_code)); + response->SetContentType(poco_response.getContentType()); + + if (enable_s3_requests_logging) + { + WriteBufferFromOwnString headers_ss; + for (const auto & [header_name, header_value] : poco_response) + { + response->AddHeader(header_name, header_value); + headers_ss << header_name << ": " << header_value << "; "; + } + LOG_TEST(log, "Received headers: {}", headers_ss.str()); + } + else + { + for (const auto & [header_name, header_value] : poco_response) + response->AddHeader(header_name, header_value); + } + + /// Request is successful but for some special requests we can have actual error message in body + if (status_code >= SUCCESS_RESPONSE_MIN && status_code <= SUCCESS_RESPONSE_MAX && checkRequestCanReturn2xxAndErrorInBody(request)) + { + std::string response_string((std::istreambuf_iterator<char>(response_body_stream)), + std::istreambuf_iterator<char>()); + + /// Just trim string so it will not be so long + LOG_TRACE(log, "Got dangerous response with successful code {}, checking its body: '{}'", status_code, response_string.substr(0, 300)); + const static std::string_view needle = "<Error>"; + if (auto it = std::search(response_string.begin(), response_string.end(), std::default_searcher(needle.begin(), needle.end())); it != response_string.end()) + { + LOG_WARNING(log, "Response for request contain <Error> tag in body, settings internal 
server error (500 code)"); + response->SetResponseCode(Aws::Http::HttpResponseCode::INTERNAL_SERVER_ERROR); + + addMetric(request, S3MetricType::Errors); + if (error_report) + error_report(request_configuration); + + } + + /// Set response from string + response->SetResponseBody(response_string); + } + else + { + + if (status_code == 429 || status_code == 503) + { // API throttling + addMetric(request, S3MetricType::Throttling); + } + else if (status_code >= 300) + { + addMetric(request, S3MetricType::Errors); + if (status_code >= 500 && error_report) + error_report(request_configuration); + } + response->SetResponseBody(response_body_stream, session); + } + + return; + } + throw Exception(ErrorCodes::TOO_MANY_REDIRECTS, "Too many redirects while trying to access {}", request.GetUri().GetURIString()); + } + catch (...) + { + auto error_message = getCurrentExceptionMessageAndPattern(/* with_stacktrace */ true); + error_message.text = fmt::format("Failed to make request to: {}: {}", uri, error_message.text); + LOG_INFO(log, error_message); + + response->SetClientErrorType(Aws::Client::CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage(getCurrentExceptionMessage(false)); + + addMetric(request, S3MetricType::Errors); + + /// Probably this is socket timeout or something more or less related to DNS + /// Let's just remove this host from DNS cache to be more safe + DNSResolver::instance().removeHostFromCache(Poco::URI(uri).getHost()); + } +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/PocoHTTPClient.h b/contrib/clickhouse/src/IO/S3/PocoHTTPClient.h new file mode 100644 index 0000000000..92680072b2 --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/PocoHTTPClient.h @@ -0,0 +1,198 @@ +#pragma once + +#include "clickhouse_config.h" + +#include <string> +#include <vector> + +#if USE_AWS_S3 + +#include <Common/RemoteHostFilter.h> +#include <Common/Throttler_fwd.h> +#include <Common/ProxyConfiguration.h> +#include <IO/ConnectionTimeouts.h> +#include <IO/HTTPCommon.h> +#include <IO/HTTPHeaderEntries.h> +#include <IO/S3/SessionAwareIOStream.h> + +#include <aws/core/client/ClientConfiguration.h> +#include <aws/core/http/HttpClient.h> +#include <aws/core/http/HttpRequest.h> +#include <aws/core/http/standard/StandardHttpResponse.h> + +namespace Aws::Http::Standard +{ +class StandardHttpResponse; +} + +namespace DB +{ + +class Context; +} + +namespace DB::S3 +{ +class ClientFactory; + +struct PocoHTTPClientConfiguration : public Aws::Client::ClientConfiguration +{ + std::function<DB::ProxyConfiguration()> per_request_configuration; + String force_region; + const RemoteHostFilter & remote_host_filter; + unsigned int s3_max_redirects; + bool enable_s3_requests_logging; + bool for_disk_s3; + ThrottlerPtr get_request_throttler; + ThrottlerPtr put_request_throttler; + HTTPHeaderEntries extra_headers; + + /// Not a client parameter in terms of HTTP and we won't send it to the server. Used internally to determine when connection have to be re-established. + uint32_t http_keep_alive_timeout_ms = 0; + /// Zero means pooling will not be used. 
+ size_t http_connection_pool_size = 0; + /// See PoolBase::BehaviourOnLimit + bool wait_on_pool_size_limit = true; + + void updateSchemeAndRegion(); + + std::function<void(const DB::ProxyConfiguration &)> error_report; + +private: + PocoHTTPClientConfiguration( + std::function<DB::ProxyConfiguration()> per_request_configuration_, + const String & force_region_, + const RemoteHostFilter & remote_host_filter_, + unsigned int s3_max_redirects_, + bool enable_s3_requests_logging_, + bool for_disk_s3_, + const ThrottlerPtr & get_request_throttler_, + const ThrottlerPtr & put_request_throttler_, + std::function<void(const DB::ProxyConfiguration &)> error_report_ + ); + + /// Constructor of Aws::Client::ClientConfiguration must be called after AWS SDK initialization. + friend ClientFactory; +}; + +class PocoHTTPResponse : public Aws::Http::Standard::StandardHttpResponse +{ +public: + using SessionPtr = HTTPSessionPtr; + + explicit PocoHTTPResponse(const std::shared_ptr<const Aws::Http::HttpRequest> request) + : Aws::Http::Standard::StandardHttpResponse(request) + , body_stream(request->GetResponseStreamFactory()) + { + } + + void SetResponseBody(Aws::IStream & incoming_stream, SessionPtr & session_) /// NOLINT + { + body_stream = Aws::Utils::Stream::ResponseStream( + Aws::New<SessionAwareIOStream<SessionPtr>>("http result streambuf", session_, incoming_stream.rdbuf()) + ); + } + + void SetResponseBody(Aws::IStream & incoming_stream, PooledHTTPSessionPtr & session_) /// NOLINT + { + body_stream = Aws::Utils::Stream::ResponseStream( + Aws::New<SessionAwareIOStream<PooledHTTPSessionPtr>>("http result streambuf", session_, incoming_stream.rdbuf())); + } + + void SetResponseBody(std::string & response_body) /// NOLINT + { + auto stream = Aws::New<std::stringstream>("http result buf", response_body); // STYLE_CHECK_ALLOW_STD_STRING_STREAM + stream->exceptions(std::ios::failbit); + body_stream = Aws::Utils::Stream::ResponseStream(std::move(stream)); + } + + Aws::IOStream & GetResponseBody() const override + { + return body_stream.GetUnderlyingStream(); + } + + Aws::Utils::Stream::ResponseStream && SwapResponseStreamOwnership() override + { + return std::move(body_stream); + } + +private: + Aws::Utils::Stream::ResponseStream body_stream; +}; + +class PocoHTTPClient : public Aws::Http::HttpClient +{ +public: + explicit PocoHTTPClient(const PocoHTTPClientConfiguration & client_configuration); + ~PocoHTTPClient() override = default; + + std::shared_ptr<Aws::Http::HttpResponse> MakeRequest( + const std::shared_ptr<Aws::Http::HttpRequest> & request, + Aws::Utils::RateLimits::RateLimiterInterface * readLimiter, + Aws::Utils::RateLimits::RateLimiterInterface * writeLimiter) const override; + +private: + + void makeRequestInternal( + Aws::Http::HttpRequest & request, + std::shared_ptr<PocoHTTPResponse> & response, + Aws::Utils::RateLimits::RateLimiterInterface * readLimiter, + Aws::Utils::RateLimits::RateLimiterInterface * writeLimiter) const; + + enum class S3MetricType + { + Microseconds, + Count, + Errors, + Throttling, + Redirects, + + EnumSize, + }; + + enum class S3MetricKind + { + Read, + Write, + + EnumSize, + }; + + template <bool pooled> + void makeRequestInternalImpl( + Aws::Http::HttpRequest & request, + const DB::ProxyConfiguration & per_request_configuration, + std::shared_ptr<PocoHTTPResponse> & response, + Aws::Utils::RateLimits::RateLimiterInterface * readLimiter, + Aws::Utils::RateLimits::RateLimiterInterface * writeLimiter) const; + +protected: + static S3MetricKind getMetricKind(const 
Aws::Http::HttpRequest & request); + void addMetric(const Aws::Http::HttpRequest & request, S3MetricType type, ProfileEvents::Count amount = 1) const; + + std::function<DB::ProxyConfiguration()> per_request_configuration; + std::function<void(const DB::ProxyConfiguration &)> error_report; + ConnectionTimeouts timeouts; + const RemoteHostFilter & remote_host_filter; + unsigned int s3_max_redirects; + bool enable_s3_requests_logging; + bool for_disk_s3; + + /// Limits get request per second rate for GET, SELECT and all other requests, excluding throttled by put throttler + /// (i.e. throttles GetObject, HeadObject) + ThrottlerPtr get_request_throttler; + + /// Limits put request per second rate for PUT, COPY, POST, LIST requests + /// (i.e. throttles PutObject, CopyObject, ListObjects, CreateMultipartUpload, UploadPartCopy, UploadPart, CompleteMultipartUpload) + /// NOTE: DELETE and CANCEL requests are not throttled by either put or get throttler + ThrottlerPtr put_request_throttler; + + const HTTPHeaderEntries extra_headers; + + size_t http_connection_pool_size = 0; + bool wait_on_pool_size_limit = true; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/PocoHTTPClientFactory.cpp b/contrib/clickhouse/src/IO/S3/PocoHTTPClientFactory.cpp new file mode 100644 index 0000000000..87854b8f6e --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/PocoHTTPClientFactory.cpp @@ -0,0 +1,40 @@ +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include "PocoHTTPClientFactory.h" + +#include <IO/S3/PocoHTTPClient.h> +#include <aws/core/client/ClientConfiguration.h> +#include <aws/core/http/HttpRequest.h> +#include <aws/core/http/HttpResponse.h> +#include <aws/core/http/standard/StandardHttpRequest.h> + +namespace DB::S3 +{ +std::shared_ptr<Aws::Http::HttpClient> +PocoHTTPClientFactory::CreateHttpClient(const Aws::Client::ClientConfiguration & clientConfiguration) const +{ + return std::make_shared<PocoHTTPClient>(static_cast<const PocoHTTPClientConfiguration &>(clientConfiguration)); +} + +std::shared_ptr<Aws::Http::HttpRequest> PocoHTTPClientFactory::CreateHttpRequest( + const Aws::String & uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory & streamFactory) const +{ + return CreateHttpRequest(Aws::Http::URI(uri), method, streamFactory); +} + +std::shared_ptr<Aws::Http::HttpRequest> PocoHTTPClientFactory::CreateHttpRequest( + const Aws::Http::URI & uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory &) const +{ + auto request = Aws::MakeShared<Aws::Http::Standard::StandardHttpRequest>("PocoHTTPClientFactory", uri, method); + + /// Don't create default response stream. Actual response stream will be set later in PocoHTTPClient. 
+ request->SetResponseStreamFactory(null_factory); + + return request; +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/PocoHTTPClientFactory.h b/contrib/clickhouse/src/IO/S3/PocoHTTPClientFactory.h new file mode 100644 index 0000000000..4e555f0550 --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/PocoHTTPClientFactory.h @@ -0,0 +1,28 @@ +#pragma once + +#include <aws/core/http/HttpClientFactory.h> + +namespace Aws::Http +{ +class HttpClient; +class HttpRequest; +} + +namespace DB::S3 +{ +class PocoHTTPClientFactory : public Aws::Http::HttpClientFactory +{ +public: + ~PocoHTTPClientFactory() override = default; + [[nodiscard]] std::shared_ptr<Aws::Http::HttpClient> + CreateHttpClient(const Aws::Client::ClientConfiguration & clientConfiguration) const override; + [[nodiscard]] std::shared_ptr<Aws::Http::HttpRequest> + CreateHttpRequest(const Aws::String & uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory & streamFactory) const override; + [[nodiscard]] std::shared_ptr<Aws::Http::HttpRequest> + CreateHttpRequest(const Aws::Http::URI & uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory & streamFactory) const override; + +private: + const Aws::IOStreamFactory null_factory = []() { return nullptr; }; +}; + +} diff --git a/contrib/clickhouse/src/IO/S3/ProviderType.cpp b/contrib/clickhouse/src/IO/S3/ProviderType.cpp new file mode 100644 index 0000000000..5987701db6 --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/ProviderType.cpp @@ -0,0 +1,40 @@ +#include <IO/S3/ProviderType.h> + +#if USE_AWS_S3 + +#include <string> + +namespace DB::S3 +{ + +std::string_view toString(ProviderType provider_type) +{ + using enum ProviderType; + + switch (provider_type) + { + case AWS: + return "AWS"; + case GCS: + return "GCS"; + case UNKNOWN: + return "Unknown"; + } +} + +std::string_view toString(ApiMode api_mode) +{ + using enum ApiMode; + + switch (api_mode) + { + case AWS: + return "AWS"; + case GCS: + return "GCS"; + } +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/ProviderType.h b/contrib/clickhouse/src/IO/S3/ProviderType.h new file mode 100644 index 0000000000..3e0ff3f36d --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/ProviderType.h @@ -0,0 +1,44 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <string_view> +#include <cstdint> + +namespace DB::S3 +{ + +/// Provider type defines the platform containing the object +/// we are trying to access +/// This information is useful for determining general support for +/// some feature like multipart copy which is currently supported by AWS +/// but not by GCS +enum class ProviderType : uint8_t +{ + AWS, + GCS, + UNKNOWN +}; + +std::string_view toString(ProviderType provider_type); + +/// Mode in which we can use the XML API +/// This value can be same as the provider type but there can be a difference +/// For example, GCS can work in both +/// AWS compatible mode (accept headers starting with x-amz) +/// and GCS mode (accept only headers starting with x-goog) +/// Because GCS mode is enforced when some features are used we +/// need to have support for both. 
+enum class ApiMode : uint8_t +{ + AWS, + GCS +}; + +std::string_view toString(ApiMode api_mode); + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/Requests.cpp b/contrib/clickhouse/src/IO/S3/Requests.cpp new file mode 100644 index 0000000000..2f2f8637ef --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/Requests.cpp @@ -0,0 +1,156 @@ +#include <IO/S3/Requests.h> + +#if USE_AWS_S3 + +#include <Common/logger_useful.h> +// #include <aws/core/endpoint/EndpointParameter.h> +#include <aws/core/utils/xml/XmlSerializer.h> + +namespace DB::S3 +{ + +Aws::Http::HeaderValueCollection CopyObjectRequest::GetRequestSpecificHeaders() const +{ + auto headers = Model::CopyObjectRequest::GetRequestSpecificHeaders(); + if (api_mode != ApiMode::GCS) + return headers; + + /// GCS supports same headers as S3 but with a prefix x-goog instead of x-amz + /// we have to replace all the prefixes client set internally + const auto replace_with_gcs_header = [&](const std::string & amz_header, const std::string & gcs_header) + { + if (const auto it = headers.find(amz_header); it != headers.end()) + { + auto header_value = std::move(it->second); + headers.erase(it); + headers.emplace(gcs_header, std::move(header_value)); + } + }; + + replace_with_gcs_header("x-amz-copy-source", "x-goog-copy-source"); + replace_with_gcs_header("x-amz-metadata-directive", "x-goog-metadata-directive"); + replace_with_gcs_header("x-amz-storage-class", "x-goog-storage-class"); + + /// replace all x-amz-meta- headers + std::vector<std::pair<std::string, std::string>> new_meta_headers; + for (auto it = headers.begin(); it != headers.end();) + { + if (it->first.starts_with("x-amz-meta-")) + { + auto value = std::move(it->second); + auto header = "x-goog" + it->first.substr(/* x-amz */ 5); + new_meta_headers.emplace_back(std::pair{std::move(header), std::move(value)}); + it = headers.erase(it); + } + else + ++it; + } + + for (auto & [header, value] : new_meta_headers) + headers.emplace(std::move(header), std::move(value)); + + return headers; +} + +Aws::String ComposeObjectRequest::SerializePayload() const +{ + if (component_names.empty()) + return {}; + + Aws::Utils::Xml::XmlDocument payload_doc = Aws::Utils::Xml::XmlDocument::CreateWithRootNode("ComposeRequest"); + auto root_node = payload_doc.GetRootElement(); + + for (const auto & name : component_names) + { + auto component_node = root_node.CreateChildElement("Component"); + auto name_node = component_node.CreateChildElement("Name"); + name_node.SetText(name); + } + + return payload_doc.ConvertToString(); +} + +void ComposeObjectRequest::AddQueryStringParameters(Aws::Http::URI & /*uri*/) const +{ +} + +Aws::Http::HeaderValueCollection ComposeObjectRequest::GetRequestSpecificHeaders() const +{ + if (content_type.empty()) + return {}; + + return {Aws::Http::HeaderValuePair(Aws::Http::CONTENT_TYPE_HEADER, content_type)}; +} + +// Aws::Endpoint::EndpointParameters ComposeObjectRequest::GetEndpointContextParams() const +// { +// EndpointParameters parameters; +// if (BucketHasBeenSet()) +// parameters.emplace_back("Bucket", GetBucket(), Aws::Endpoint::EndpointParameter::ParameterOrigin::OPERATION_CONTEXT); + +// return parameters; +// } + +const Aws::String & ComposeObjectRequest::GetBucket() const +{ + return bucket; +} + +bool ComposeObjectRequest::BucketHasBeenSet() const +{ + return !bucket.empty(); +} + +void ComposeObjectRequest::SetBucket(const Aws::String & value) +{ + bucket = value; +} + +void ComposeObjectRequest::SetBucket(Aws::String && value) +{ + bucket = std::move(value); +} 
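The GCS compatibility logic in CopyObjectRequest::GetRequestSpecificHeaders() above is essentially a prefix rewrite over a header map. A minimal standalone sketch of that idea follows; the function name rewriteAmzHeadersForGCS is illustrative and not part of this patch, and unlike the real code, which rewrites a fixed set of headers plus the x-amz-meta-* family, this sketch generalizes to any x-amz-* prefix:

#include <map>
#include <string>
#include <utility>
#include <vector>

// Rewrite every "x-amz-*" header name to its "x-goog-*" counterpart,
// e.g. "x-amz-storage-class" -> "x-goog-storage-class",
//      "x-amz-meta-foo"      -> "x-goog-meta-foo".
std::map<std::string, std::string> rewriteAmzHeadersForGCS(std::map<std::string, std::string> headers)
{
    static const std::string amz_prefix = "x-amz-";
    std::vector<std::pair<std::string, std::string>> rewritten;

    for (auto it = headers.begin(); it != headers.end();)
    {
        if (it->first.starts_with(amz_prefix))
        {
            // Collect the renamed header and drop the original entry.
            rewritten.emplace_back("x-goog-" + it->first.substr(amz_prefix.size()), std::move(it->second));
            it = headers.erase(it);
        }
        else
            ++it;
    }

    // Re-insert the renamed headers after iteration to avoid invalidating iterators.
    for (auto & [name, value] : rewritten)
        headers.emplace(std::move(name), std::move(value));

    return headers;
}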
+ +void ComposeObjectRequest::SetBucket(const char * value) +{ + bucket.assign(value); +} + +const Aws::String & ComposeObjectRequest::GetKey() const +{ + return key; +} + +bool ComposeObjectRequest::KeyHasBeenSet() const +{ + return !key.empty(); +} + +void ComposeObjectRequest::SetKey(const Aws::String & value) +{ + key = value; +} + +void ComposeObjectRequest::SetKey(Aws::String && value) +{ + key = std::move(value); +} + +void ComposeObjectRequest::SetKey(const char * value) +{ + key.assign(value); +} + +void ComposeObjectRequest::SetComponentNames(std::vector<Aws::String> component_names_) +{ + component_names = std::move(component_names_); +} + +void ComposeObjectRequest::SetContentType(Aws::String value) +{ + content_type = std::move(value); +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/Requests.h b/contrib/clickhouse/src/IO/S3/Requests.h new file mode 100644 index 0000000000..5d0b930e01 --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/Requests.h @@ -0,0 +1,135 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <IO/S3/URI.h> +#include <IO/S3/ProviderType.h> + +// #include <aws/core/endpoint/EndpointParameter.h> +#include <aws/s3/model/HeadObjectRequest.h> +#include <aws/s3/model/ListObjectsV2Request.h> +#include <aws/s3/model/ListObjectsRequest.h> +#include <aws/s3/model/GetObjectRequest.h> +#include <aws/s3/model/AbortMultipartUploadRequest.h> +#include <aws/s3/model/CreateMultipartUploadRequest.h> +#include <aws/s3/model/CompleteMultipartUploadRequest.h> +#include <aws/s3/model/CopyObjectRequest.h> +#include <aws/s3/model/PutObjectRequest.h> +#include <aws/s3/model/UploadPartRequest.h> +#include <aws/s3/model/UploadPartCopyRequest.h> +#include <aws/s3/model/DeleteObjectRequest.h> +#include <aws/s3/model/DeleteObjectsRequest.h> + +namespace DB::S3 +{ + +namespace Model = Aws::S3::Model; + +template <typename BaseRequest> +class ExtendedRequest : public BaseRequest +{ +public: + // Aws::Endpoint::EndpointParameters GetEndpointContextParams() const override + // { + // auto params = BaseRequest::GetEndpointContextParams(); + // if (!region_override.empty()) + // params.emplace_back("Region", region_override); + + // if (uri_override.has_value()) + // { + // static const Aws::String AWS_S3_FORCE_PATH_STYLE = "ForcePathStyle"; + // params.emplace_back(AWS_S3_FORCE_PATH_STYLE, !uri_override->is_virtual_hosted_style); + // params.emplace_back("Endpoint", uri_override->endpoint); + // } + + // return params; + // } + + void overrideRegion(std::string region) const + { + region_override = std::move(region); + } + + void overrideURI(S3::URI uri) const + { + uri_override = std::move(uri); + } + + const auto & getURIOverride() const + { + return uri_override; + } + + void setApiMode(ApiMode api_mode_) const + { + api_mode = api_mode_; + } + +protected: + mutable std::string region_override; + mutable std::optional<S3::URI> uri_override; + mutable ApiMode api_mode{ApiMode::AWS}; +}; + +class CopyObjectRequest : public ExtendedRequest<Model::CopyObjectRequest> +{ +public: + Aws::Http::HeaderValueCollection GetRequestSpecificHeaders() const override; +}; + +using HeadObjectRequest = ExtendedRequest<Model::HeadObjectRequest>; +using ListObjectsV2Request = ExtendedRequest<Model::ListObjectsV2Request>; +using ListObjectsRequest = ExtendedRequest<Model::ListObjectsRequest>; +using GetObjectRequest = ExtendedRequest<Model::GetObjectRequest>; + +using CreateMultipartUploadRequest = ExtendedRequest<Model::CreateMultipartUploadRequest>; +using 
CompleteMultipartUploadRequest = ExtendedRequest<Model::CompleteMultipartUploadRequest>; +using AbortMultipartUploadRequest = ExtendedRequest<Model::AbortMultipartUploadRequest>; +using UploadPartRequest = ExtendedRequest<Model::UploadPartRequest>; +using UploadPartCopyRequest = ExtendedRequest<Model::UploadPartCopyRequest>; + +using PutObjectRequest = ExtendedRequest<Model::PutObjectRequest>; +using DeleteObjectRequest = ExtendedRequest<Model::DeleteObjectRequest>; +using DeleteObjectsRequest = ExtendedRequest<Model::DeleteObjectsRequest>; + + +class ComposeObjectRequest : public ExtendedRequest<Aws::S3::S3Request> +{ +public: + inline const char * GetServiceRequestName() const override { return "ComposeObject"; } + + AWS_S3_API Aws::String SerializePayload() const override; + + AWS_S3_API void AddQueryStringParameters(Aws::Http::URI & uri) const override; + + AWS_S3_API Aws::Http::HeaderValueCollection GetRequestSpecificHeaders() const override; + + // AWS_S3_API EndpointParameters GetEndpointContextParams() const override; + + const Aws::String & GetBucket() const; + bool BucketHasBeenSet() const; + void SetBucket(const Aws::String & value); + void SetBucket(Aws::String && value); + void SetBucket(const char* value); + + const Aws::String & GetKey() const; + bool KeyHasBeenSet() const; + void SetKey(const Aws::String & value); + void SetKey(Aws::String && value); + void SetKey(const char * value); + + void SetComponentNames(std::vector<Aws::String> component_names_); + + void SetContentType(Aws::String value); +private: + Aws::String bucket; + Aws::String key; + std::vector<Aws::String> component_names; + Aws::String content_type; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/SessionAwareIOStream.h b/contrib/clickhouse/src/IO/S3/SessionAwareIOStream.h new file mode 100644 index 0000000000..babe52545d --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/SessionAwareIOStream.h @@ -0,0 +1,30 @@ +#pragma once + +#include <iosfwd> + + +namespace DB::S3 +{ +/** + * Wrapper of IOStream to store response stream and corresponding HTTP session. + */ +template <typename Session> +class SessionAwareIOStream : public std::iostream +{ +public: + SessionAwareIOStream(Session session_, std::streambuf * sb) + : std::iostream(sb) + , session(std::move(session_)) + { + } + + Session & getSession() { return session; } + + const Session & getSession() const { return session; } + +private: + /// Poco HTTP session is holder of response stream. + Session session; +}; + +} diff --git a/contrib/clickhouse/src/IO/S3/URI.cpp b/contrib/clickhouse/src/IO/S3/URI.cpp new file mode 100644 index 0000000000..34590df539 --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/URI.cpp @@ -0,0 +1,119 @@ +#include <IO/S3/URI.h> + +#if USE_AWS_S3 +#include <Common/Exception.h> +#include <Common/quoteString.h> + +#include <boost/algorithm/string/case_conv.hpp> +#include <re2/re2.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} + +namespace S3 +{ + +URI::URI(const std::string & uri_) +{ + /// Case when bucket name represented in domain name of S3 URL. + /// E.g. (https://bucket-name.s3.Region.amazonaws.com/key) + /// https://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html#virtual-hosted-style-access + static const RE2 virtual_hosted_style_pattern(R"((.+)\.(s3|cos|obs|oss)([.\-][a-z0-9\-.:]+))"); + + /// Case when bucket name and key represented in path of S3 URL. + /// E.g. 
(https://s3.Region.amazonaws.com/bucket-name/key) + /// https://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html#path-style-access + static const RE2 path_style_pattern("^/([^/]*)/(.*)"); + + static constexpr auto S3 = "S3"; + static constexpr auto COSN = "COSN"; + static constexpr auto COS = "COS"; + static constexpr auto OBS = "OBS"; + static constexpr auto OSS = "OSS"; + + uri = Poco::URI(uri_); + + storage_name = S3; + + if (uri.getHost().empty()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Host is empty in S3 URI."); + + /// Extract object version ID from query string. + bool has_version_id = false; + for (const auto & [query_key, query_value] : uri.getQueryParameters()) + if (query_key == "versionId") + { + version_id = query_value; + has_version_id = true; + } + + /// Poco::URI will ignore '?' when parsing the path, but if there is a versionId in the http parameter, + /// '?' can not be used as a wildcard, otherwise it will be ambiguous. + /// If no "versionId" in the http parameter, '?' can be used as a wildcard. + /// It is necessary to encode '?' to avoid deletion during parsing path. + if (!has_version_id && uri_.find('?') != String::npos) + { + String uri_with_question_mark_encode; + Poco::URI::encode(uri_, "?", uri_with_question_mark_encode); + uri = Poco::URI(uri_with_question_mark_encode); + } + + String name; + String endpoint_authority_from_uri; + + if (re2::RE2::FullMatch(uri.getAuthority(), virtual_hosted_style_pattern, &bucket, &name, &endpoint_authority_from_uri)) + { + is_virtual_hosted_style = true; + endpoint = uri.getScheme() + "://" + name + endpoint_authority_from_uri; + validateBucket(bucket, uri); + + if (!uri.getPath().empty()) + { + /// Remove leading '/' from path to extract key. + key = uri.getPath().substr(1); + } + + boost::to_upper(name); + if (name != S3 && name != COS && name != OBS && name != OSS) + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Object storage system name is unrecognized in virtual hosted style S3 URI: {}", + quoteString(name)); + + if (name == S3) + storage_name = name; + else if (name == OBS) + storage_name = OBS; + else if (name == OSS) + storage_name = OSS; + else + storage_name = COSN; + } + else if (re2::RE2::PartialMatch(uri.getPath(), path_style_pattern, &bucket, &key)) + { + is_virtual_hosted_style = false; + endpoint = uri.getScheme() + "://" + uri.getAuthority(); + validateBucket(bucket, uri); + } + else + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bucket or key name are invalid in S3 URI."); +} + +void URI::validateBucket(const String & bucket, const Poco::URI & uri) +{ + /// S3 specification requires at least 3 and at most 63 characters in bucket name. + /// https://docs.aws.amazon.com/awscloudtrail/latest/userguide/cloudtrail-s3-bucket-naming-requirements.html + if (bucket.length() < 3 || bucket.length() > 63) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bucket name length is out of bounds in virtual hosted style S3 URI: {}{}", + quoteString(bucket), !uri.empty() ? " (" + uri.toString() + ")" : ""); +} + +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/URI.h b/contrib/clickhouse/src/IO/S3/URI.h new file mode 100644 index 0000000000..9df075c732 --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/URI.h @@ -0,0 +1,41 @@ +#pragma once + +#include <string> + +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <Poco/URI.h> + +namespace DB::S3 +{ + +/** + * Represents S3 URI. 
+ * + * The following patterns are allowed: + * s3://bucket/key + * http(s)://endpoint/bucket/key + */ +struct URI +{ + Poco::URI uri; + // Custom endpoint if URI scheme is not S3. + std::string endpoint; + std::string bucket; + std::string key; + std::string version_id; + std::string storage_name; + + bool is_virtual_hosted_style; + + URI() = default; + explicit URI(const std::string & uri_); + + static void validateBucket(const std::string & bucket, const Poco::URI & uri); +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/copyS3File.cpp b/contrib/clickhouse/src/IO/S3/copyS3File.cpp new file mode 100644 index 0000000000..002b8dde56 --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/copyS3File.cpp @@ -0,0 +1,850 @@ +#include <IO/S3/copyS3File.h> + +#if USE_AWS_S3 + +#include <Common/ProfileEvents.h> +#include <Common/typeid_cast.h> +#include <Interpreters/Context.h> +#include <IO/LimitSeekableReadBuffer.h> +#include <IO/S3/getObjectInfo.h> +#include <IO/SeekableReadBuffer.h> +#include <IO/StdStreamFromReadBuffer.h> +#include <IO/ReadBufferFromS3.h> + +#include <IO/S3/Requests.h> + +namespace ProfileEvents +{ + extern const Event WriteBufferFromS3Bytes; + extern const Event WriteBufferFromS3Microseconds; + extern const Event WriteBufferFromS3RequestsErrors; + + extern const Event S3CreateMultipartUpload; + extern const Event S3CompleteMultipartUpload; + extern const Event S3PutObject; + extern const Event S3CopyObject; + extern const Event S3UploadPart; + extern const Event S3UploadPartCopy; + + extern const Event DiskS3CreateMultipartUpload; + extern const Event DiskS3CompleteMultipartUpload; + extern const Event DiskS3PutObject; + extern const Event DiskS3CopyObject; + extern const Event DiskS3UploadPart; + extern const Event DiskS3UploadPartCopy; +} + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int S3_ERROR; + extern const int INVALID_CONFIG_PARAMETER; + extern const int LOGICAL_ERROR; +} + + +namespace +{ + class UploadHelper + { + public: + UploadHelper( + const std::shared_ptr<const S3::Client> & client_ptr_, + const std::shared_ptr<const S3::Client> & client_with_long_timeout_ptr_, + const String & dest_bucket_, + const String & dest_key_, + const S3Settings::RequestSettings & request_settings_, + const std::optional<std::map<String, String>> & object_metadata_, + ThreadPoolCallbackRunner<void> schedule_, + bool for_disk_s3_, + const Poco::Logger * log_) + : client_ptr(client_ptr_) + , client_with_long_timeout_ptr(client_with_long_timeout_ptr_) + , dest_bucket(dest_bucket_) + , dest_key(dest_key_) + , request_settings(request_settings_) + , upload_settings(request_settings.getUploadSettings()) + , object_metadata(object_metadata_) + , schedule(schedule_) + , for_disk_s3(for_disk_s3_) + , log(log_) + { + } + + virtual ~UploadHelper() = default; + + protected: + std::shared_ptr<const S3::Client> client_ptr; + std::shared_ptr<const S3::Client> client_with_long_timeout_ptr; + const String & dest_bucket; + const String & dest_key; + const S3Settings::RequestSettings & request_settings; + const S3Settings::RequestSettings::PartUploadSettings & upload_settings; + const std::optional<std::map<String, String>> & object_metadata; + ThreadPoolCallbackRunner<void> schedule; + bool for_disk_s3; + const Poco::Logger * log; + + struct UploadPartTask + { + std::unique_ptr<Aws::AmazonWebServiceRequest> req; + bool is_finished = false; + String tag; + std::exception_ptr exception; + }; + + size_t normal_part_size; + String multipart_upload_id; + std::atomic<bool> 
multipart_upload_aborted = false; + Strings part_tags; + + std::list<UploadPartTask> TSA_GUARDED_BY(bg_tasks_mutex) bg_tasks; + int num_added_bg_tasks TSA_GUARDED_BY(bg_tasks_mutex) = 0; + int num_finished_bg_tasks TSA_GUARDED_BY(bg_tasks_mutex) = 0; + std::mutex bg_tasks_mutex; + std::condition_variable bg_tasks_condvar; + + void fillCreateMultipartRequest(S3::CreateMultipartUploadRequest & request) + { + request.SetBucket(dest_bucket); + request.SetKey(dest_key); + + /// If we don't do it, AWS SDK can mistakenly set it to application/xml, see https://github.com/aws/aws-sdk-cpp/issues/1840 + request.SetContentType("binary/octet-stream"); + + if (object_metadata.has_value()) + request.SetMetadata(object_metadata.value()); + + const auto & storage_class_name = upload_settings.storage_class_name; + if (!storage_class_name.empty()) + request.SetStorageClass(Aws::S3::Model::StorageClassMapper::GetStorageClassForName(storage_class_name)); + + client_ptr->setKMSHeaders(request); + } + + void createMultipartUpload() + { + S3::CreateMultipartUploadRequest request; + fillCreateMultipartRequest(request); + + ProfileEvents::increment(ProfileEvents::S3CreateMultipartUpload); + if (for_disk_s3) + ProfileEvents::increment(ProfileEvents::DiskS3CreateMultipartUpload); + + auto outcome = client_ptr->CreateMultipartUpload(request); + + if (outcome.IsSuccess()) + { + multipart_upload_id = outcome.GetResult().GetUploadId(); + LOG_TRACE(log, "Multipart upload has been created. Bucket: {}, Key: {}, Upload id: {}", dest_bucket, dest_key, multipart_upload_id); + } + else + { + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3RequestsErrors, 1); + throw S3Exception(outcome.GetError().GetMessage(), outcome.GetError().GetErrorType()); + } + } + + void completeMultipartUpload() + { + if (multipart_upload_aborted) + return; + + LOG_TRACE(log, "Completing multipart upload. Bucket: {}, Key: {}, Upload_id: {}, Parts: {}", dest_bucket, dest_key, multipart_upload_id, part_tags.size()); + + if (part_tags.empty()) + throw Exception(ErrorCodes::S3_ERROR, "Failed to complete multipart upload. No parts have been uploaded"); + + S3::CompleteMultipartUploadRequest request; + request.SetBucket(dest_bucket); + request.SetKey(dest_key); + request.SetUploadId(multipart_upload_id); + + Aws::S3::Model::CompletedMultipartUpload multipart_upload; + for (size_t i = 0; i < part_tags.size(); ++i) + { + Aws::S3::Model::CompletedPart part; + multipart_upload.AddParts(part.WithETag(part_tags[i]).WithPartNumber(static_cast<int>(i + 1))); + } + + request.SetMultipartUpload(multipart_upload); + + size_t max_retries = std::max(request_settings.max_unexpected_write_error_retries, 1UL); + for (size_t retries = 1;; ++retries) + { + ProfileEvents::increment(ProfileEvents::S3CompleteMultipartUpload); + if (for_disk_s3) + ProfileEvents::increment(ProfileEvents::DiskS3CompleteMultipartUpload); + + auto outcome = client_with_long_timeout_ptr->CompleteMultipartUpload(request); + + if (outcome.IsSuccess()) + { + LOG_TRACE(log, "Multipart upload has completed. 
Bucket: {}, Key: {}, Upload_id: {}, Parts: {}", dest_bucket, dest_key, multipart_upload_id, part_tags.size()); + break; + } + + if ((outcome.GetError().GetErrorType() == Aws::S3::S3Errors::NO_SUCH_KEY) && (retries < max_retries)) + { + /// For unknown reason, at least MinIO can respond with NO_SUCH_KEY for put requests + /// BTW, NO_SUCH_UPLOAD is expected error and we shouldn't retry it + LOG_INFO(log, "Multipart upload failed with NO_SUCH_KEY error for Bucket: {}, Key: {}, Upload_id: {}, Parts: {}, will retry", dest_bucket, dest_key, multipart_upload_id, part_tags.size()); + continue; /// will retry + } + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3RequestsErrors, 1); + throw S3Exception( + outcome.GetError().GetErrorType(), + "Message: {}, Key: {}, Bucket: {}, Tags: {}", + outcome.GetError().GetMessage(), dest_key, dest_bucket, fmt::join(part_tags.begin(), part_tags.end(), " ")); + } + } + + void abortMultipartUpload() + { + LOG_TRACE(log, "Aborting multipart upload. Bucket: {}, Key: {}, Upload_id: {}", dest_bucket, dest_key, multipart_upload_id); + S3::AbortMultipartUploadRequest abort_request; + abort_request.SetBucket(dest_bucket); + abort_request.SetKey(dest_key); + abort_request.SetUploadId(multipart_upload_id); + client_ptr->AbortMultipartUpload(abort_request); + multipart_upload_aborted = true; + } + + void checkObjectAfterUpload() + { + LOG_TRACE(log, "Checking object {} exists after upload", dest_key); + S3::checkObjectExists(*client_ptr, dest_bucket, dest_key, {}, request_settings, {}, "Immediately after upload"); + LOG_TRACE(log, "Object {} exists after upload", dest_key); + } + + void performMultipartUpload(size_t start_offset, size_t size) + { + calculatePartSize(size); + createMultipartUpload(); + + size_t position = start_offset; + size_t end_position = start_offset + size; + + try + { + for (size_t part_number = 1; position < end_position; ++part_number) + { + if (multipart_upload_aborted) + break; /// No more part uploads. + + size_t next_position = std::min(position + normal_part_size, end_position); + size_t part_size = next_position - position; /// `part_size` is either `normal_part_size` or smaller if it's the final part. + + Stopwatch watch; + uploadPart(part_number, position, part_size); + watch.stop(); + + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Bytes, part_size); + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Microseconds, watch.elapsedMicroseconds()); + + position = next_position; + } + } + catch (...) + { + tryLogCurrentException(__PRETTY_FUNCTION__); + // Multipart upload failed because it wasn't possible to schedule all the tasks. + // To avoid execution of already scheduled tasks we abort MultipartUpload. + abortMultipartUpload(); + waitForAllBackgroundTasks(); + throw; + } + + waitForAllBackgroundTasks(); + completeMultipartUpload(); + } + + void calculatePartSize(size_t total_size) + { + if (!total_size) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Chosen multipart upload for an empty file. 
This must not happen"); + + auto max_part_number = upload_settings.max_part_number; + auto min_upload_part_size = upload_settings.min_upload_part_size; + auto max_upload_part_size = upload_settings.max_upload_part_size; + + if (!max_part_number) + throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER, "max_part_number must not be 0"); + else if (!min_upload_part_size) + throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER, "min_upload_part_size must not be 0"); + else if (max_upload_part_size < min_upload_part_size) + throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER, "max_upload_part_size must not be less than min_upload_part_size"); + + size_t part_size = min_upload_part_size; + size_t num_parts = (total_size + part_size - 1) / part_size; + + if (num_parts > max_part_number) + { + part_size = (total_size + max_part_number - 1) / max_part_number; + num_parts = (total_size + part_size - 1) / part_size; + } + + if (part_size > max_upload_part_size) + { + part_size = max_upload_part_size; + num_parts = (total_size + part_size - 1) / part_size; + } + + if (num_parts < 1 || num_parts > max_part_number || part_size < min_upload_part_size || part_size > max_upload_part_size) + { + String msg; + if (num_parts < 1) + msg = "Number of parts is zero"; + else if (num_parts > max_part_number) + msg = fmt::format("Number of parts {} exceeds {}", num_parts, max_part_number); + else if (part_size < min_upload_part_size) + msg = fmt::format("Size of a part {} is less than {}", part_size, min_upload_part_size); + else + msg = fmt::format("Size of a part {} exceeds {}", part_size, max_upload_part_size); + + throw Exception( + ErrorCodes::INVALID_CONFIG_PARAMETER, + "{} while writing {} bytes to S3. Check max_part_number = {}, " + "min_upload_part_size = {}, max_upload_part_size = {}", + msg, total_size, max_part_number, min_upload_part_size, max_upload_part_size); + } + + /// We've calculated the size of a normal part (the final part can be smaller). + normal_part_size = part_size; + } + + void uploadPart(size_t part_number, size_t part_offset, size_t part_size) + { + LOG_TRACE(log, "Writing part. Bucket: {}, Key: {}, Upload_id: {}, Size: {}", dest_bucket, dest_key, multipart_upload_id, part_size); + + if (!part_size) + { + LOG_TRACE(log, "Skipping writing an empty part."); + return; + } + + if (schedule) + { + UploadPartTask * task = nullptr; + + { + std::lock_guard lock(bg_tasks_mutex); + task = &bg_tasks.emplace_back(); + ++num_added_bg_tasks; + } + + /// Notify the waiting thread when the task has finished + auto task_finish_notify = [this, task]() + { + std::lock_guard lock(bg_tasks_mutex); + task->is_finished = true; + ++num_finished_bg_tasks; + + /// Notification under mutex is important here. + /// Otherwise, WriteBuffer could be destroyed in between + /// releasing the lock and the condvar notification. + bg_tasks_condvar.notify_one(); + }; + + try + { + task->req = fillUploadPartRequest(part_number, part_offset, part_size); + + schedule([this, task, task_finish_notify]() + { + try + { + processUploadTask(*task); + } + catch (...) + { + task->exception = std::current_exception(); + } + task_finish_notify(); + }, Priority{}); + } + catch (...) + { + task_finish_notify(); + throw; + } + } + else + { + UploadPartTask task; + task.req = fillUploadPartRequest(part_number, part_offset, part_size); + processUploadTask(task); + part_tags.push_back(task.tag); + } + } + + void processUploadTask(UploadPartTask & task) + { + if (multipart_upload_aborted) + return; /// Already aborted. 
+ + auto tag = processUploadPartRequest(*task.req); + + std::lock_guard lock(bg_tasks_mutex); /// Protect bg_tasks from race + task.tag = tag; + LOG_TRACE(log, "Writing part finished. Bucket: {}, Key: {}, Upload_id: {}, Etag: {}, Parts: {}", dest_bucket, dest_key, multipart_upload_id, task.tag, bg_tasks.size()); + } + + virtual std::unique_ptr<Aws::AmazonWebServiceRequest> fillUploadPartRequest(size_t part_number, size_t part_offset, size_t part_size) = 0; + virtual String processUploadPartRequest(Aws::AmazonWebServiceRequest & request) = 0; + + void waitForAllBackgroundTasks() + { + if (!schedule) + return; + + std::unique_lock lock(bg_tasks_mutex); + /// Suppress warnings because bg_tasks_mutex is actually hold, but tsa annotations do not understand std::unique_lock + bg_tasks_condvar.wait(lock, [this]() {return TSA_SUPPRESS_WARNING_FOR_READ(num_added_bg_tasks) == TSA_SUPPRESS_WARNING_FOR_READ(num_finished_bg_tasks); }); + + auto & tasks = TSA_SUPPRESS_WARNING_FOR_WRITE(bg_tasks); + for (auto & task : tasks) + { + if (task.exception) + { + /// abortMultipartUpload() might be called already, see processUploadPartRequest(). + /// However if there were concurrent uploads at that time, those part uploads might or might not succeed. + /// As a result, it might be necessary to abort a given multipart upload multiple times in order to completely free + /// all storage consumed by all parts. + abortMultipartUpload(); + + std::rethrow_exception(task.exception); + } + + part_tags.push_back(task.tag); + } + } + }; + + /// Helper class to help implementing copyDataToS3File(). + class CopyDataToFileHelper : public UploadHelper + { + public: + CopyDataToFileHelper( + const CreateReadBuffer & create_read_buffer_, + size_t offset_, + size_t size_, + const std::shared_ptr<const S3::Client> & client_ptr_, + const std::shared_ptr<const S3::Client> & client_with_long_timeout_ptr_, + const String & dest_bucket_, + const String & dest_key_, + const S3Settings::RequestSettings & request_settings_, + const std::optional<std::map<String, String>> & object_metadata_, + ThreadPoolCallbackRunner<void> schedule_, + bool for_disk_s3_) + : UploadHelper(client_ptr_, client_with_long_timeout_ptr_, dest_bucket_, dest_key_, request_settings_, object_metadata_, schedule_, for_disk_s3_, &Poco::Logger::get("copyDataToS3File")) + , create_read_buffer(create_read_buffer_) + , offset(offset_) + , size(size_) + { + } + + void performCopy() + { + if (size <= upload_settings.max_single_part_upload_size) + performSinglepartUpload(); + else + performMultipartUpload(); + + if (request_settings.check_objects_after_upload) + checkObjectAfterUpload(); + } + + private: + std::function<std::unique_ptr<SeekableReadBuffer>()> create_read_buffer; + size_t offset; + size_t size; + + void performSinglepartUpload() + { + S3::PutObjectRequest request; + fillPutRequest(request); + processPutRequest(request); + } + + void fillPutRequest(S3::PutObjectRequest & request) + { + auto read_buffer = std::make_unique<LimitSeekableReadBuffer>(create_read_buffer(), offset, size); + + request.SetBucket(dest_bucket); + request.SetKey(dest_key); + request.SetContentLength(size); + request.SetBody(std::make_unique<StdStreamFromReadBuffer>(std::move(read_buffer), size)); + + if (object_metadata.has_value()) + request.SetMetadata(object_metadata.value()); + + const auto & storage_class_name = upload_settings.storage_class_name; + if (!storage_class_name.empty()) + 
request.SetStorageClass(Aws::S3::Model::StorageClassMapper::GetStorageClassForName(storage_class_name)); + + /// If we don't do it, AWS SDK can mistakenly set it to application/xml, see https://github.com/aws/aws-sdk-cpp/issues/1840 + request.SetContentType("binary/octet-stream"); + + client_ptr->setKMSHeaders(request); + } + + void processPutRequest(const S3::PutObjectRequest & request) + { + size_t max_retries = std::max(request_settings.max_unexpected_write_error_retries, 1UL); + for (size_t retries = 1;; ++retries) + { + ProfileEvents::increment(ProfileEvents::S3PutObject); + if (for_disk_s3) + ProfileEvents::increment(ProfileEvents::DiskS3PutObject); + + Stopwatch watch; + auto outcome = client_ptr->PutObject(request); + watch.stop(); + + if (outcome.IsSuccess()) + { + Int64 object_size = request.GetContentLength(); + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Bytes, object_size); + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Microseconds, watch.elapsedMicroseconds()); + LOG_TRACE( + log, + "Single part upload has completed. Bucket: {}, Key: {}, Object size: {}", + dest_bucket, + dest_key, + object_size); + break; + } + + if (outcome.GetError().GetExceptionName() == "EntityTooLarge" || outcome.GetError().GetExceptionName() == "InvalidRequest") + { + // Can't come here with MinIO, MinIO allows single part upload for large objects. + LOG_INFO( + log, + "Single part upload failed with error {} for Bucket: {}, Key: {}, Object size: {}, will retry with multipart upload", + outcome.GetError().GetExceptionName(), + dest_bucket, + dest_key, + size); + performMultipartUpload(); + break; + } + + if ((outcome.GetError().GetErrorType() == Aws::S3::S3Errors::NO_SUCH_KEY) && (retries < max_retries)) + { + /// For unknown reason, at least MinIO can respond with NO_SUCH_KEY for put requests + LOG_INFO( + log, + "Single part upload failed with NO_SUCH_KEY error for Bucket: {}, Key: {}, Object size: {}, will retry", + dest_bucket, + dest_key, + request.GetContentLength()); + continue; /// will retry + } + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3RequestsErrors, 1); + throw S3Exception( + outcome.GetError().GetErrorType(), + "Message: {}, Key: {}, Bucket: {}, Object size: {}", + outcome.GetError().GetMessage(), + dest_key, + dest_bucket, + request.GetContentLength()); + } + } + + void performMultipartUpload() { UploadHelper::performMultipartUpload(offset, size); } + + std::unique_ptr<Aws::AmazonWebServiceRequest> fillUploadPartRequest(size_t part_number, size_t part_offset, size_t part_size) override + { + auto read_buffer = std::make_unique<LimitSeekableReadBuffer>(create_read_buffer(), part_offset, part_size); + + /// Setup request. 
+ auto request = std::make_unique<S3::UploadPartRequest>(); + request->SetBucket(dest_bucket); + request->SetKey(dest_key); + request->SetPartNumber(static_cast<int>(part_number)); + request->SetUploadId(multipart_upload_id); + request->SetContentLength(part_size); + request->SetBody(std::make_unique<StdStreamFromReadBuffer>(std::move(read_buffer), part_size)); + + /// If we don't do it, AWS SDK can mistakenly set it to application/xml, see https://github.com/aws/aws-sdk-cpp/issues/1840 + request->SetContentType("binary/octet-stream"); + + return request; + } + + String processUploadPartRequest(Aws::AmazonWebServiceRequest & request) override + { + auto & req = typeid_cast<S3::UploadPartRequest &>(request); + + ProfileEvents::increment(ProfileEvents::S3UploadPart); + if (for_disk_s3) + ProfileEvents::increment(ProfileEvents::DiskS3UploadPart); + + auto outcome = client_ptr->UploadPart(req); + if (!outcome.IsSuccess()) + { + abortMultipartUpload(); + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3RequestsErrors, 1); + throw S3Exception(outcome.GetError().GetMessage(), outcome.GetError().GetErrorType()); + } + + return outcome.GetResult().GetETag(); + } + }; + + /// Helper class to help implementing copyS3File(). + class CopyFileHelper : public UploadHelper + { + public: + CopyFileHelper( + const std::shared_ptr<const S3::Client> & client_ptr_, + const std::shared_ptr<const S3::Client> & client_with_long_timeout_ptr_, + const String & src_bucket_, + const String & src_key_, + size_t src_offset_, + size_t src_size_, + const String & dest_bucket_, + const String & dest_key_, + const S3Settings::RequestSettings & request_settings_, + const std::optional<std::map<String, String>> & object_metadata_, + ThreadPoolCallbackRunner<void> schedule_, + bool for_disk_s3_) + : UploadHelper(client_ptr_, client_with_long_timeout_ptr_, dest_bucket_, dest_key_, request_settings_, object_metadata_, schedule_, for_disk_s3_, &Poco::Logger::get("copyS3File")) + , src_bucket(src_bucket_) + , src_key(src_key_) + , offset(src_offset_) + , size(src_size_) + , supports_multipart_copy(client_ptr_->supportsMultiPartCopy()) + { + } + + void performCopy() + { + if (!supports_multipart_copy || size <= upload_settings.max_single_operation_copy_size) + performSingleOperationCopy(); + else + performMultipartUploadCopy(); + + if (request_settings.check_objects_after_upload) + checkObjectAfterUpload(); + } + + private: + const String & src_bucket; + const String & src_key; + size_t offset; + size_t size; + bool supports_multipart_copy; + + CreateReadBuffer getSourceObjectReadBuffer() + { + return [&] + { + return std::make_unique<ReadBufferFromS3>(client_ptr, src_bucket, src_key, "", request_settings, Context::getGlobalContextInstance()->getReadSettings()); + }; + } + + void performSingleOperationCopy() + { + S3::CopyObjectRequest request; + fillCopyRequest(request); + processCopyRequest(request); + } + + void fillCopyRequest(S3::CopyObjectRequest & request) + { + request.SetCopySource(src_bucket + "/" + src_key); + request.SetBucket(dest_bucket); + request.SetKey(dest_key); + + if (object_metadata.has_value()) + { + request.SetMetadata(object_metadata.value()); + request.SetMetadataDirective(Aws::S3::Model::MetadataDirective::REPLACE); + } + + const auto & storage_class_name = upload_settings.storage_class_name; + if (!storage_class_name.empty()) + request.SetStorageClass(Aws::S3::Model::StorageClassMapper::GetStorageClassForName(storage_class_name)); + + /// If we don't do it, AWS SDK can mistakenly set it to 
application/xml, see https://github.com/aws/aws-sdk-cpp/issues/1840 + request.SetContentType("binary/octet-stream"); + + client_with_long_timeout_ptr->setKMSHeaders(request); + } + + void processCopyRequest(const S3::CopyObjectRequest & request) + { + size_t max_retries = std::max(request_settings.max_unexpected_write_error_retries, 1UL); + for (size_t retries = 1;; ++retries) + { + ProfileEvents::increment(ProfileEvents::S3CopyObject); + if (for_disk_s3) + ProfileEvents::increment(ProfileEvents::DiskS3CopyObject); + + auto outcome = client_with_long_timeout_ptr->CopyObject(request); + if (outcome.IsSuccess()) + { + LOG_TRACE( + log, + "Single operation copy has completed. Bucket: {}, Key: {}, Object size: {}", + dest_bucket, + dest_key, + size); + break; + } + + if (outcome.GetError().GetExceptionName() == "EntityTooLarge" || outcome.GetError().GetExceptionName() == "InvalidRequest" || outcome.GetError().GetExceptionName() == "InvalidArgument") + { + if (!supports_multipart_copy) + { + LOG_INFO(log, "Multipart upload using copy is not supported, will try regular upload for Bucket: {}, Key: {}, Object size: {}", + dest_bucket, + dest_key, + size); + copyDataToS3File( + getSourceObjectReadBuffer(), + offset, + size, + client_ptr, + client_with_long_timeout_ptr, + dest_bucket, + dest_key, + request_settings, + object_metadata, + schedule, + for_disk_s3); + break; + } + else + { + // Can't come here with MinIO, MinIO allows single part upload for large objects. + LOG_INFO( + log, + "Single operation copy failed with error {} for Bucket: {}, Key: {}, Object size: {}, will retry with multipart " + "upload copy", + outcome.GetError().GetExceptionName(), + dest_bucket, + dest_key, + size); + + performMultipartUploadCopy(); + break; + } + } + + if ((outcome.GetError().GetErrorType() == Aws::S3::S3Errors::NO_SUCH_KEY) && (retries < max_retries)) + { + /// TODO: Is it true for copy requests? + /// For unknown reason, at least MinIO can respond with NO_SUCH_KEY for put requests + LOG_INFO( + log, + "Single operation copy failed with NO_SUCH_KEY error for Bucket: {}, Key: {}, Object size: {}, will retry", + dest_bucket, + dest_key, + size); + continue; /// will retry + } + + throw S3Exception( + outcome.GetError().GetErrorType(), + "Message: {}, Key: {}, Bucket: {}, Object size: {}", + outcome.GetError().GetMessage(), + dest_key, + dest_bucket, + size); + } + } + + void performMultipartUploadCopy() { UploadHelper::performMultipartUpload(offset, size); } + + std::unique_ptr<Aws::AmazonWebServiceRequest> fillUploadPartRequest(size_t part_number, size_t part_offset, size_t part_size) override + { + auto request = std::make_unique<S3::UploadPartCopyRequest>(); + + /// Make a copy request to copy a part. 
+ request->SetCopySource(src_bucket + "/" + src_key); + request->SetBucket(dest_bucket); + request->SetKey(dest_key); + request->SetUploadId(multipart_upload_id); + request->SetPartNumber(static_cast<int>(part_number)); + request->SetCopySourceRange(fmt::format("bytes={}-{}", part_offset, part_offset + part_size - 1)); + + return request; + } + + String processUploadPartRequest(Aws::AmazonWebServiceRequest & request) override + { + auto & req = typeid_cast<S3::UploadPartCopyRequest &>(request); + + ProfileEvents::increment(ProfileEvents::S3UploadPartCopy); + if (for_disk_s3) + ProfileEvents::increment(ProfileEvents::DiskS3UploadPartCopy); + + auto outcome = client_with_long_timeout_ptr->UploadPartCopy(req); + if (!outcome.IsSuccess()) + { + abortMultipartUpload(); + throw S3Exception(outcome.GetError().GetMessage(), outcome.GetError().GetErrorType()); + } + + return outcome.GetResult().GetCopyPartResult().GetETag(); + } + }; +} + + +void copyDataToS3File( + const std::function<std::unique_ptr<SeekableReadBuffer>()> & create_read_buffer, + size_t offset, + size_t size, + const std::shared_ptr<const S3::Client> & dest_s3_client, + const std::shared_ptr<const S3::Client> & dest_s3_client_with_long_timeout, + const String & dest_bucket, + const String & dest_key, + const S3Settings::RequestSettings & settings, + const std::optional<std::map<String, String>> & object_metadata, + ThreadPoolCallbackRunner<void> schedule, + bool for_disk_s3) +{ + CopyDataToFileHelper helper{create_read_buffer, offset, size, dest_s3_client, dest_s3_client_with_long_timeout, dest_bucket, dest_key, settings, object_metadata, schedule, for_disk_s3}; + helper.performCopy(); +} + + +void copyS3File( + const std::shared_ptr<const S3::Client> & s3_client, + const std::shared_ptr<const S3::Client> & s3_client_with_long_timeout, + const String & src_bucket, + const String & src_key, + size_t src_offset, + size_t src_size, + const String & dest_bucket, + const String & dest_key, + const S3Settings::RequestSettings & settings, + const std::optional<std::map<String, String>> & object_metadata, + ThreadPoolCallbackRunner<void> schedule, + bool for_disk_s3) +{ + if (settings.allow_native_copy) + { + CopyFileHelper helper{s3_client, s3_client_with_long_timeout, src_bucket, src_key, src_offset, src_size, dest_bucket, dest_key, settings, object_metadata, schedule, for_disk_s3}; + helper.performCopy(); + } + else + { + auto create_read_buffer = [&] + { + return std::make_unique<ReadBufferFromS3>(s3_client, src_bucket, src_key, "", settings, Context::getGlobalContextInstance()->getReadSettings()); + }; + copyDataToS3File(create_read_buffer, src_offset, src_size, s3_client, s3_client_with_long_timeout, dest_bucket, dest_key, settings, object_metadata, schedule, for_disk_s3); + } +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/copyS3File.h b/contrib/clickhouse/src/IO/S3/copyS3File.h new file mode 100644 index 0000000000..b39b7469eb --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/copyS3File.h @@ -0,0 +1,68 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <Storages/StorageS3Settings.h> +#include <Interpreters/threadPoolCallbackRunner.h> +#include <base/types.h> +#include <functional> +#include <memory> + + +namespace DB +{ +class SeekableReadBuffer; + +using CreateReadBuffer = std::function<std::unique_ptr<SeekableReadBuffer>()>; + +/// Copies a file from S3 to S3. 
+/// The same functionality can be done by using the function copyData() and the classes ReadBufferFromS3 and WriteBufferFromS3, +/// however copyS3File() is faster and uses less network traffic and memory. +/// The parameters `src_offset` and `src_size` specify a part in the source to copy. +/// +/// Note that it tries to copy the file using native copy (CopyObject), but if native copy +/// has been disabled (with settings.allow_native_copy) or the request failed +/// because of a known issue, it falls back to read-write copy +/// (copyDataToS3File()). +/// +/// s3_client_with_long_timeout (may be equal to s3_client) is used for native copy and +/// CompleteMultipartUpload requests. These requests need a longer timeout because S3 servers often +/// block on them for multiple seconds without sending or receiving data from us (perhaps the servers +/// are copying data internally, or throttling). +void copyS3File( + const std::shared_ptr<const S3::Client> & s3_client, + const std::shared_ptr<const S3::Client> & s3_client_with_long_timeout, + const String & src_bucket, + const String & src_key, + size_t src_offset, + size_t src_size, + const String & dest_bucket, + const String & dest_key, + const S3Settings::RequestSettings & settings, + const std::optional<std::map<String, String>> & object_metadata = std::nullopt, + ThreadPoolCallbackRunner<void> schedule_ = {}, + bool for_disk_s3 = false); + +/// Copies data from any seekable source to S3. +/// The same functionality can be done by using the function copyData() and the class WriteBufferFromS3, +/// however copyDataToS3File() is faster and uses less memory. +/// The callback `create_read_buffer` can be called from multiple threads in parallel, so it should be thread-safe. +/// The parameters `offset` and `size` specify a part in the source to copy. 
+void copyDataToS3File( + const CreateReadBuffer & create_read_buffer, + size_t offset, + size_t size, + const std::shared_ptr<const S3::Client> & dest_s3_client, + const std::shared_ptr<const S3::Client> & dest_s3_client_with_long_timeout, + const String & dest_bucket, + const String & dest_key, + const S3Settings::RequestSettings & settings, + const std::optional<std::map<String, String>> & object_metadata = std::nullopt, + ThreadPoolCallbackRunner<void> schedule_ = {}, + bool for_disk_s3 = false); + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/getObjectInfo.cpp b/contrib/clickhouse/src/IO/S3/getObjectInfo.cpp new file mode 100644 index 0000000000..88f79f8d8d --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/getObjectInfo.cpp @@ -0,0 +1,144 @@ +#include <IO/S3/getObjectInfo.h> + +#if USE_AWS_S3 + +namespace ErrorCodes +{ + extern const int S3_ERROR; +} + + +namespace ProfileEvents +{ + extern const Event S3GetObject; + extern const Event S3GetObjectAttributes; + extern const Event S3HeadObject; + extern const Event DiskS3GetObject; + extern const Event DiskS3GetObjectAttributes; + extern const Event DiskS3HeadObject; +} + + +namespace DB::S3 +{ + +namespace +{ + Aws::S3::Model::HeadObjectOutcome headObject( + const S3::Client & client, const String & bucket, const String & key, const String & version_id, bool for_disk_s3) + { + ProfileEvents::increment(ProfileEvents::S3HeadObject); + if (for_disk_s3) + ProfileEvents::increment(ProfileEvents::DiskS3HeadObject); + + S3::HeadObjectRequest req; + req.SetBucket(bucket); + req.SetKey(key); + + if (!version_id.empty()) + req.SetVersionId(version_id); + + return client.HeadObject(req); + } + + /// Performs a request to get the size and last modification time of an object. + std::pair<std::optional<ObjectInfo>, Aws::S3::S3Error> tryGetObjectInfo( + const S3::Client & client, const String & bucket, const String & key, const String & version_id, + const S3Settings::RequestSettings & /*request_settings*/, bool with_metadata, bool for_disk_s3) + { + auto outcome = headObject(client, bucket, key, version_id, for_disk_s3); + if (!outcome.IsSuccess()) + return {std::nullopt, outcome.GetError()}; + + const auto & result = outcome.GetResult(); + ObjectInfo object_info; + object_info.size = static_cast<size_t>(result.GetContentLength()); + object_info.last_modification_time = result.GetLastModified().Millis() / 1000; + + if (with_metadata) + object_info.metadata = result.GetMetadata(); + + return {object_info, {}}; + } +} + + +bool isNotFoundError(Aws::S3::S3Errors error) +{ + return error == Aws::S3::S3Errors::RESOURCE_NOT_FOUND || error == Aws::S3::S3Errors::NO_SUCH_KEY; +} + +ObjectInfo getObjectInfo( + const S3::Client & client, + const String & bucket, + const String & key, + const String & version_id, + const S3Settings::RequestSettings & request_settings, + bool with_metadata, + bool for_disk_s3, + bool throw_on_error) +{ + auto [object_info, error] = tryGetObjectInfo(client, bucket, key, version_id, request_settings, with_metadata, for_disk_s3); + if (object_info) + { + return *object_info; + } + else if (throw_on_error) + { + throw S3Exception(error.GetErrorType(), + "Failed to get object info: {}. 
HTTP response code: {}", + error.GetMessage(), static_cast<size_t>(error.GetResponseCode())); + } + return {}; +} + +size_t getObjectSize( + const S3::Client & client, + const String & bucket, + const String & key, + const String & version_id, + const S3Settings::RequestSettings & request_settings, + bool for_disk_s3, + bool throw_on_error) +{ + return getObjectInfo(client, bucket, key, version_id, request_settings, {}, for_disk_s3, throw_on_error).size; +} + +bool objectExists( + const S3::Client & client, + const String & bucket, + const String & key, + const String & version_id, + const S3Settings::RequestSettings & request_settings, + bool for_disk_s3) +{ + auto [object_info, error] = tryGetObjectInfo(client, bucket, key, version_id, request_settings, {}, for_disk_s3); + if (object_info) + return true; + + if (isNotFoundError(error.GetErrorType())) + return false; + + throw S3Exception(error.GetErrorType(), + "Failed to check existence of key {} in bucket {}: {}", + key, bucket, error.GetMessage()); +} + +void checkObjectExists( + const S3::Client & client, + const String & bucket, + const String & key, + const String & version_id, + const S3Settings::RequestSettings & request_settings, + bool for_disk_s3, + std::string_view description) +{ + auto [object_info, error] = tryGetObjectInfo(client, bucket, key, version_id, request_settings, {}, for_disk_s3); + if (object_info) + return; + throw S3Exception(error.GetErrorType(), "{}Object {} in bucket {} suddenly disappeared: {}", + (description.empty() ? "" : (String(description) + ": ")), key, bucket, error.GetMessage()); +} +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3/getObjectInfo.h b/contrib/clickhouse/src/IO/S3/getObjectInfo.h new file mode 100644 index 0000000000..8804a9494e --- /dev/null +++ b/contrib/clickhouse/src/IO/S3/getObjectInfo.h @@ -0,0 +1,63 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_AWS_S3 +#include <Storages/StorageS3Settings.h> +#include <base/types.h> +#include <IO/S3/Client.h> + + +namespace DB::S3 +{ + +struct ObjectInfo +{ + size_t size = 0; + time_t last_modification_time = 0; + + std::map<String, String> metadata = {}; /// Set only if getObjectInfo() is called with `with_metadata = true`. +}; + +ObjectInfo getObjectInfo( + const S3::Client & client, + const String & bucket, + const String & key, + const String & version_id = {}, + const S3Settings::RequestSettings & request_settings = {}, + bool with_metadata = false, + bool for_disk_s3 = false, + bool throw_on_error = true); + +size_t getObjectSize( + const S3::Client & client, + const String & bucket, + const String & key, + const String & version_id = {}, + const S3Settings::RequestSettings & request_settings = {}, + bool for_disk_s3 = false, + bool throw_on_error = true); + +bool objectExists( + const S3::Client & client, + const String & bucket, + const String & key, + const String & version_id = {}, + const S3Settings::RequestSettings & request_settings = {}, + bool for_disk_s3 = false); + +/// Throws an exception if a specified object doesn't exist. `description` is used as a part of the error message. 
+void checkObjectExists( + const S3::Client & client, + const String & bucket, + const String & key, + const String & version_id = {}, + const S3Settings::RequestSettings & request_settings = {}, + bool for_disk_s3 = false, + std::string_view description = {}); + +bool isNotFoundError(Aws::S3::S3Errors error); + +} + +#endif diff --git a/contrib/clickhouse/src/IO/S3Common.cpp b/contrib/clickhouse/src/IO/S3Common.cpp new file mode 100644 index 0000000000..115877530f --- /dev/null +++ b/contrib/clickhouse/src/IO/S3Common.cpp @@ -0,0 +1,178 @@ +#include <IO/S3Common.h> + +#include <Common/Exception.h> +#include <Poco/Util/AbstractConfiguration.h> +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +# include <Common/quoteString.h> + +# include <IO/WriteBufferFromString.h> +# include <IO/HTTPHeaderEntries.h> +# include <Storages/StorageS3Settings.h> + +# include <IO/S3/PocoHTTPClientFactory.h> +# include <IO/S3/PocoHTTPClient.h> +# include <IO/S3/Client.h> +# include <IO/S3/URI.h> +# include <IO/S3/Requests.h> +# include <IO/S3/Credentials.h> +# include <Common/logger_useful.h> + +# include <fstream> + +namespace ProfileEvents +{ + extern const Event S3GetObjectAttributes; + extern const Event S3GetObjectMetadata; + extern const Event S3HeadObject; + extern const Event DiskS3GetObjectAttributes; + extern const Event DiskS3GetObjectMetadata; + extern const Event DiskS3HeadObject; +} + +namespace DB +{ + +bool S3Exception::isRetryableError() const +{ + /// Looks like these list is quite conservative, add more codes if you wish + static const std::unordered_set<Aws::S3::S3Errors> unretryable_errors = { + Aws::S3::S3Errors::NO_SUCH_KEY, + Aws::S3::S3Errors::ACCESS_DENIED, + Aws::S3::S3Errors::INVALID_ACCESS_KEY_ID, + Aws::S3::S3Errors::INVALID_SIGNATURE, + Aws::S3::S3Errors::NO_SUCH_UPLOAD, + Aws::S3::S3Errors::NO_SUCH_BUCKET, + }; + + return !unretryable_errors.contains(code); +} + +} + +namespace DB::ErrorCodes +{ + extern const int S3_ERROR; +} + +#endif + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_CONFIG_PARAMETER; +} + +namespace S3 +{ + +HTTPHeaderEntries getHTTPHeaders(const std::string & config_elem, const Poco::Util::AbstractConfiguration & config) +{ + HTTPHeaderEntries headers; + Poco::Util::AbstractConfiguration::Keys subconfig_keys; + config.keys(config_elem, subconfig_keys); + for (const std::string & subkey : subconfig_keys) + { + if (subkey.starts_with("header")) + { + auto header_str = config.getString(config_elem + "." 
+ subkey); + auto delimiter = header_str.find(':'); + if (delimiter == std::string::npos) + throw Exception(ErrorCodes::INVALID_CONFIG_PARAMETER, "Malformed s3 header value"); + headers.emplace_back(header_str.substr(0, delimiter), header_str.substr(delimiter + 1, String::npos)); + } + } + return headers; +} + +ServerSideEncryptionKMSConfig getSSEKMSConfig(const std::string & config_elem, const Poco::Util::AbstractConfiguration & config) +{ + ServerSideEncryptionKMSConfig sse_kms_config; + + if (config.has(config_elem + ".server_side_encryption_kms_key_id")) + sse_kms_config.key_id = config.getString(config_elem + ".server_side_encryption_kms_key_id"); + + if (config.has(config_elem + ".server_side_encryption_kms_encryption_context")) + sse_kms_config.encryption_context = config.getString(config_elem + ".server_side_encryption_kms_encryption_context"); + + if (config.has(config_elem + ".server_side_encryption_kms_bucket_key_enabled")) + sse_kms_config.bucket_key_enabled = config.getBool(config_elem + ".server_side_encryption_kms_bucket_key_enabled"); + + return sse_kms_config; +} + +AuthSettings AuthSettings::loadFromConfig(const std::string & config_elem, const Poco::Util::AbstractConfiguration & config) +{ + auto access_key_id = config.getString(config_elem + ".access_key_id", ""); + auto secret_access_key = config.getString(config_elem + ".secret_access_key", ""); + auto region = config.getString(config_elem + ".region", ""); + auto server_side_encryption_customer_key_base64 = config.getString(config_elem + ".server_side_encryption_customer_key_base64", ""); + + std::optional<bool> use_environment_credentials; + if (config.has(config_elem + ".use_environment_credentials")) + use_environment_credentials = config.getBool(config_elem + ".use_environment_credentials"); + + std::optional<bool> use_insecure_imds_request; + if (config.has(config_elem + ".use_insecure_imds_request")) + use_insecure_imds_request = config.getBool(config_elem + ".use_insecure_imds_request"); + + std::optional<uint64_t> expiration_window_seconds; + if (config.has(config_elem + ".expiration_window_seconds")) + expiration_window_seconds = config.getUInt64(config_elem + ".expiration_window_seconds"); + + std::optional<bool> no_sign_request; + if (config.has(config_elem + ".no_sign_request")) + no_sign_request = config.getBool(config_elem + ".no_sign_request"); + + HTTPHeaderEntries headers = getHTTPHeaders(config_elem, config); + ServerSideEncryptionKMSConfig sse_kms_config = getSSEKMSConfig(config_elem, config); + + return AuthSettings + { + std::move(access_key_id), std::move(secret_access_key), + std::move(region), + std::move(server_side_encryption_customer_key_base64), + std::move(sse_kms_config), + std::move(headers), + use_environment_credentials, + use_insecure_imds_request, + expiration_window_seconds, + no_sign_request + }; +} + + +void AuthSettings::updateFrom(const AuthSettings & from) +{ + /// Update with check for emptyness only parameters which + /// can be passed not only from config, but via ast. 
+ + if (!from.access_key_id.empty()) + access_key_id = from.access_key_id; + if (!from.secret_access_key.empty()) + secret_access_key = from.secret_access_key; + + headers = from.headers; + region = from.region; + server_side_encryption_customer_key_base64 = from.server_side_encryption_customer_key_base64; + server_side_encryption_kms_config = from.server_side_encryption_kms_config; + + if (from.use_environment_credentials.has_value()) + use_environment_credentials = from.use_environment_credentials; + + if (from.use_insecure_imds_request.has_value()) + use_insecure_imds_request = from.use_insecure_imds_request; + + if (from.expiration_window_seconds.has_value()) + expiration_window_seconds = from.expiration_window_seconds; + + if (from.no_sign_request.has_value()) + no_sign_request = *from.no_sign_request; +} + +} +} diff --git a/contrib/clickhouse/src/IO/S3Common.h b/contrib/clickhouse/src/IO/S3Common.h new file mode 100644 index 0000000000..881edfcc9b --- /dev/null +++ b/contrib/clickhouse/src/IO/S3Common.h @@ -0,0 +1,98 @@ +#pragma once + +#include <IO/S3/Client.h> +#include <IO/S3/PocoHTTPClient.h> +#include <IO/HTTPHeaderEntries.h> + +#include <string> +#include <optional> + +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <base/types.h> +#include <Common/Exception.h> +#include <Common/Throttler_fwd.h> + +#include <IO/S3/URI.h> + +#include <aws/core/Aws.h> +#include <aws/s3/S3Errors.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int S3_ERROR; +} + +class RemoteHostFilter; + +class S3Exception : public Exception +{ +public: + + // Format message with fmt::format, like the logging functions. + template <typename... Args> + S3Exception(Aws::S3::S3Errors code_, fmt::format_string<Args...> fmt, Args &&... args) + : Exception(fmt::format(fmt, std::forward<Args>(args)...), ErrorCodes::S3_ERROR) + , code(code_) + { + } + + S3Exception(const std::string & msg, Aws::S3::S3Errors code_) + : Exception(msg, ErrorCodes::S3_ERROR) + , code(code_) + {} + + Aws::S3::S3Errors getS3ErrorCode() const + { + return code; + } + + bool isRetryableError() const; + +private: + Aws::S3::S3Errors code; +}; +} + +#endif + +namespace Poco::Util +{ + class AbstractConfiguration; +}; + +namespace DB::S3 +{ + +HTTPHeaderEntries getHTTPHeaders(const std::string & config_elem, const Poco::Util::AbstractConfiguration & config); + +ServerSideEncryptionKMSConfig getSSEKMSConfig(const std::string & config_elem, const Poco::Util::AbstractConfiguration & config); + +struct AuthSettings +{ + static AuthSettings loadFromConfig(const std::string & config_elem, const Poco::Util::AbstractConfiguration & config); + + std::string access_key_id; + std::string secret_access_key; + std::string region; + std::string server_side_encryption_customer_key_base64; + ServerSideEncryptionKMSConfig server_side_encryption_kms_config; + + HTTPHeaderEntries headers; + + std::optional<bool> use_environment_credentials; + std::optional<bool> use_insecure_imds_request; + std::optional<uint64_t> expiration_window_seconds; + std::optional<bool> no_sign_request; + + bool operator==(const AuthSettings & other) const = default; + + void updateFrom(const AuthSettings & from); +}; + +} diff --git a/contrib/clickhouse/src/IO/SchedulerNodeFactory.h b/contrib/clickhouse/src/IO/SchedulerNodeFactory.h new file mode 100644 index 0000000000..5c31534a9b --- /dev/null +++ b/contrib/clickhouse/src/IO/SchedulerNodeFactory.h @@ -0,0 +1,57 @@ +#pragma once + +#include <Common/ErrorCodes.h> +#include <Common/Exception.h> + +#include 
<IO/ISchedulerNode.h> + +#include <Poco/Util/AbstractConfiguration.h> + +#include <boost/noncopyable.hpp> + +#include <memory> +#include <mutex> +#include <unordered_map> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_SCHEDULER_NODE; +} + +class SchedulerNodeFactory : private boost::noncopyable +{ +public: + static SchedulerNodeFactory & instance() + { + static SchedulerNodeFactory ret; + return ret; + } + + SchedulerNodePtr get(const String & name, EventQueue * event_queue, const Poco::Util::AbstractConfiguration & config, const String & config_prefix) + { + std::lock_guard lock{mutex}; + if (auto iter = methods.find(name); iter != methods.end()) + return iter->second(event_queue, config, config_prefix); + throw Exception(ErrorCodes::INVALID_SCHEDULER_NODE, "Unknown scheduler node type: {}", name); + } + + template <class TDerived> + void registerMethod(const String & name) + { + std::lock_guard lock{mutex}; + methods[name] = [] (EventQueue * event_queue, const Poco::Util::AbstractConfiguration & config, const String & config_prefix) + { + return std::make_shared<TDerived>(event_queue, config, config_prefix); + }; + } + +private: + std::mutex mutex; + using Method = std::function<SchedulerNodePtr(EventQueue * event_queue, const Poco::Util::AbstractConfiguration & config, const String & config_prefix)>; + std::unordered_map<String, Method> methods; +}; + +} diff --git a/contrib/clickhouse/src/IO/SchedulerRoot.h b/contrib/clickhouse/src/IO/SchedulerRoot.h new file mode 100644 index 0000000000..f9af2099b8 --- /dev/null +++ b/contrib/clickhouse/src/IO/SchedulerRoot.h @@ -0,0 +1,250 @@ +#pragma once + +#include <base/defines.h> + +#include <Common/Stopwatch.h> +#include <Common/ThreadPool.h> + +#include <IO/ISchedulerNode.h> +#include <IO/ISchedulerConstraint.h> + +#include <Poco/Util/XMLConfiguration.h> + +#include <unordered_map> +#include <map> +#include <memory> +#include <atomic> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_SCHEDULER_NODE; +} + +/* + * Resource scheduler root node with a dedicated thread. + * Immediate children correspond to different resources. 
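+ *
+ * Illustrative usage sketch (not part of the original sources; `resource_root` is a
+ * hypothetical SchedulerNodePtr built elsewhere, e.g. via SchedulerNodeFactory):
+ *
+ *     SchedulerRoot root;
+ *     root.attachChild(resource_root);   // one immediate child per resource
+ *     root.start();                      // spawns the dedicated scheduler thread
+ *     ...                                // requests are dequeued round-robin across resources
+ *     root.stop();                       // joins the thread, draining pending requests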
+ */ +class SchedulerRoot : public ISchedulerNode +{ +private: + struct TResource + { + SchedulerNodePtr root; + + // Intrusive cyclic list of active resources + TResource * next = nullptr; + TResource * prev = nullptr; + + explicit TResource(const SchedulerNodePtr & root_) + : root(root_) + { + root->info.parent.ptr = this; + } + + // Get pointer stored by ctor in info + static TResource * get(SchedulerNodeInfo & info) + { + return reinterpret_cast<TResource *>(info.parent.ptr); + } + }; + +public: + SchedulerRoot() + : ISchedulerNode(&events) + {} + + ~SchedulerRoot() override + { + stop(); + } + + /// Runs separate scheduler thread + void start() + { + if (!scheduler.joinable()) + scheduler = ThreadFromGlobalPool([this] { schedulerThread(); }); + } + + /// Joins scheduler threads and execute every pending request iff graceful + void stop(bool graceful = true) + { + if (scheduler.joinable()) + { + stop_flag.store(true); + events.enqueue([]{}); // just to wake up thread + scheduler.join(); + if (graceful) + { + // Do the same cycle as schedulerThread() but never block, just exit instead + bool has_work = true; + while (has_work) + { + auto [request, _] = dequeueRequest(); + if (request) + execute(request); + else + has_work = false; + while (events.tryProcess()) + has_work = true; + } + } + } + } + + bool equals(ISchedulerNode * other) override + { + if (auto * o = dynamic_cast<SchedulerRoot *>(other)) + return true; + return false; + } + + void attachChild(const SchedulerNodePtr & child) override + { + // Take ownership + assert(child->parent == nullptr); + if (auto [it, inserted] = children.emplace(child.get(), child); !inserted) + throw Exception( + ErrorCodes::INVALID_SCHEDULER_NODE, + "Can't add the same scheduler node twice"); + + // Attach + child->setParent(this); + + // Activate child if required + if (child->isActive()) + activateChild(child.get()); + } + + void removeChild(ISchedulerNode * child) override + { + if (auto iter = children.find(child); iter != children.end()) + { + SchedulerNodePtr removed = iter->second.root; + + // Deactivate if required + deactivate(&iter->second); + + // Detach + removed->setParent(nullptr); + + // Remove ownership + children.erase(iter); + } + } + + ISchedulerNode * getChild(const String &) override + { + abort(); // scheduler is allowed to have multiple children with the same name + } + + std::pair<ResourceRequest *, bool> dequeueRequest() override + { + if (current == nullptr) // No active resources + return {nullptr, false}; + + // Dequeue request from current resource + auto [request, resource_active] = current->root->dequeueRequest(); + assert(request != nullptr); + + // Deactivate resource if required + if (!resource_active) + deactivate(current); + else + current = current->next; // Just move round-robin pointer + + return {request, current != nullptr}; + } + + bool isActive() override + { + return current != nullptr; + } + + void activateChild(ISchedulerNode * child) override + { + activate(TResource::get(child->info)); + } + + void setParent(ISchedulerNode *) override + { + abort(); // scheduler must be the root and this function should not be called + } + +private: + void activate(TResource * value) + { + assert(value->next == nullptr && value->prev == nullptr); + if (current == nullptr) // No active children + { + current = value; + value->prev = value; + value->next = value; + } + else + { + current->prev->next = value; + value->prev = current->prev; + current->prev = value; + value->next = current; + } + } + + void 
deactivate(TResource * value) + { + if (value->next == nullptr) + return; // Already deactivated + assert(current != nullptr); + if (current == value) + { + if (current->next == current) // We are going to remove the last active child + { + value->next = nullptr; + value->prev = nullptr; + current = nullptr; + return; + } + else // Just move current to next to avoid invalidation + current = current->next; + } + value->prev->next = value->next; + value->next->prev = value->prev; + value->prev = nullptr; + value->next = nullptr; + } + +private: + void schedulerThread() + { + while (!stop_flag.load()) + { + // Dequeue and execute single request + auto [request, _] = dequeueRequest(); + if (request) + execute(request); + else // No more requests -- block until any event happens + events.process(); + + // Process all events before dequeuing to ensure fair competition + while (events.tryProcess()) {} + } + } + + void execute(ResourceRequest * request) + { + request->execute_ns = clock_gettime_ns(); + request->execute(); + } + +private: + TResource * current = nullptr; // round-robin pointer + std::unordered_map<ISchedulerNode *, TResource> children; // resources by pointer + std::atomic<bool> stop_flag = false; + EventQueue events; + ThreadFromGlobalPool scheduler; +}; + +} diff --git a/contrib/clickhouse/src/IO/SeekableReadBuffer.cpp b/contrib/clickhouse/src/IO/SeekableReadBuffer.cpp new file mode 100644 index 0000000000..b83e382db0 --- /dev/null +++ b/contrib/clickhouse/src/IO/SeekableReadBuffer.cpp @@ -0,0 +1,109 @@ +#include <IO/SeekableReadBuffer.h> + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int CANNOT_READ_FROM_ISTREAM; +} + +namespace +{ + template <typename CustomData> + class SeekableReadBufferWrapper : public SeekableReadBuffer + { + public: + SeekableReadBufferWrapper(SeekableReadBuffer & in_, CustomData && custom_data_) + : SeekableReadBuffer(in_.buffer().begin(), in_.buffer().size(), in_.offset()) + , in(in_) + , custom_data(std::move(custom_data_)) + { + } + + private: + SeekableReadBuffer & in; + CustomData custom_data; + + bool nextImpl() override + { + in.position() = position(); + if (!in.next()) + { + set(in.position(), 0); + return false; + } + BufferBase::set(in.buffer().begin(), in.buffer().size(), in.offset()); + return true; + } + + off_t seek(off_t off, int whence) override + { + in.position() = position(); + off_t new_pos = in.seek(off, whence); + BufferBase::set(in.buffer().begin(), in.buffer().size(), in.offset()); + return new_pos; + } + + off_t getPosition() override + { + in.position() = position(); + return in.getPosition(); + } + }; +} + + +std::unique_ptr<SeekableReadBuffer> wrapSeekableReadBufferReference(SeekableReadBuffer & ref) +{ + return std::make_unique<SeekableReadBufferWrapper<nullptr_t>>(ref, nullptr); +} + +std::unique_ptr<SeekableReadBuffer> wrapSeekableReadBufferPointer(SeekableReadBufferPtr ptr) +{ + return std::make_unique<SeekableReadBufferWrapper<SeekableReadBufferPtr>>(*ptr, SeekableReadBufferPtr{ptr}); +} + +size_t copyFromIStreamWithProgressCallback(std::istream & istr, char * to, size_t n, const std::function<bool(size_t)> & progress_callback, bool * out_cancelled) +{ + const size_t chunk = DBMS_DEFAULT_BUFFER_SIZE; + if (out_cancelled) + *out_cancelled = false; + + size_t copied = 0; + while (copied < n) + { + size_t to_copy = std::min(chunk, n - copied); + istr.read(to + copied, to_copy); + size_t gcount = istr.gcount(); + + copied += gcount; + + bool cancelled = false; + if (gcount && progress_callback) + cancelled = 
progress_callback(copied); + + if (gcount != to_copy) + { + if (!istr.eof()) + throw Exception( + ErrorCodes::CANNOT_READ_FROM_ISTREAM, + "{} at offset {}", + istr.fail() ? "Cannot read from istream" : "Unexpected state of istream", + copied); + + break; + } + + if (cancelled) + { + if (out_cancelled != nullptr) + *out_cancelled = true; + break; + } + } + + return copied; +} + +} diff --git a/contrib/clickhouse/src/IO/SeekableReadBuffer.h b/contrib/clickhouse/src/IO/SeekableReadBuffer.h new file mode 100644 index 0000000000..5770948be2 --- /dev/null +++ b/contrib/clickhouse/src/IO/SeekableReadBuffer.h @@ -0,0 +1,103 @@ +#pragma once + +#include <IO/ReadBuffer.h> +#include <IO/WithFileSize.h> +#include <optional> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NOT_IMPLEMENTED; +} + + +class SeekableReadBuffer : public ReadBuffer +{ +public: + SeekableReadBuffer(Position ptr, size_t size) + : ReadBuffer(ptr, size) {} + SeekableReadBuffer(Position ptr, size_t size, size_t offset) + : ReadBuffer(ptr, size, offset) {} + + /** + * Shifts buffer current position to given offset. + * @param off Offset. + * @param whence Seek mode (@see SEEK_SET, @see SEEK_CUR). + * @return New position from the beginning of underlying buffer / file. + * + * What happens if you seek above the end of the file? Implementation-defined. + */ + virtual off_t seek(off_t off, int whence) = 0; + + /** + * Keep in mind that seekable buffer may encounter eof() once and the working buffer + * may get into inconsistent state. Don't forget to reset it on the first nextImpl() + * after seek(). + */ + + /** + * @return Offset from the begin of the underlying buffer / file corresponds to the buffer current position. + */ + virtual off_t getPosition() = 0; + + virtual String getInfoForLog() { return ""; } + + virtual size_t getFileOffsetOfBufferEnd() const { throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method getFileOffsetOfBufferEnd() not implemented"); } + + /// If true, setReadUntilPosition() guarantees that eof will be reported at the given position. + virtual bool supportsRightBoundedReads() const { return false; } + + /// Returns true if seek() actually works, false if seek() will always throw (or make subsequent + /// nextImpl() calls throw). + /// + /// This is needed because: + /// * Sometimes there's no cheap way to know in advance whether the buffer is really seekable. + /// Specifically, HTTP read buffer needs to send a request to check whether the server + /// supports byte ranges. + /// * Sometimes when we create such buffer we don't know in advance whether we'll need it to be + /// seekable or not. So we don't want to pay the price for this check in advance. + virtual bool checkIfActuallySeekable() { return true; } + + /// Unbuffered positional read. + /// Doesn't affect the buffer state (position, working_buffer, etc). + /// + /// `progress_callback` may be called periodically during the read, reporting that to[0..m-1] + /// has been filled. If it returns true, reading is stopped, and readBigAt() returns bytes read + /// so far. Called only from inside readBigAt(), from the same thread, with increasing m. + /// + /// Stops either after n bytes, or at end of file, or on exception. Returns number of bytes read. + /// If offset is past the end of file, may return 0 or throw exception. + /// + /// Caller needs to be careful: + /// * supportsReadAt() must be checked (called and return true) before calling readBigAt(). + /// Otherwise readBigAt() may crash. 
+ /// * Thread safety: multiple readBigAt() calls may be performed in parallel. + /// But readBigAt() may not be called in parallel with any other methods + /// (e.g. next() or supportsReadAt()). + /// * Performance: there's no buffering. Each readBigAt() call typically translates into actual + /// IO operation (e.g. HTTP request). Don't use it for small adjacent reads. + virtual size_t readBigAt(char * /*to*/, size_t /*n*/, size_t /*offset*/, const std::function<bool(size_t m)> & /*progress_callback*/ = nullptr) + { throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method readBigAt() not implemented"); } + + /// Checks if readBigAt() is allowed. May be slow, may throw (e.g. it may do an HTTP request or an fstat). + virtual bool supportsReadAt() { return false; } + + /// We do some tricks to avoid seek cost. E.g we read more data and than ignore it (see remote_read_min_bytes_for_seek). + /// Sometimes however seek is basically free because underlying read buffer wasn't yet initialised (or re-initialised after reset). + virtual bool seekIsCheap() { return false; } +}; + + +using SeekableReadBufferPtr = std::shared_ptr<SeekableReadBuffer>; + +/// Wraps a reference to a SeekableReadBuffer into an unique pointer to SeekableReadBuffer. +/// This function is like wrapReadBufferReference() but for SeekableReadBuffer. +std::unique_ptr<SeekableReadBuffer> wrapSeekableReadBufferReference(SeekableReadBuffer & ref); +std::unique_ptr<SeekableReadBuffer> wrapSeekableReadBufferPointer(SeekableReadBufferPtr ptr); + +/// Helper for implementing readBigAt(). +size_t copyFromIStreamWithProgressCallback(std::istream & istr, char * to, size_t n, const std::function<bool(size_t)> & progress_callback, bool * out_cancelled = nullptr); + +} diff --git a/contrib/clickhouse/src/IO/SharedThreadPools.cpp b/contrib/clickhouse/src/IO/SharedThreadPools.cpp new file mode 100644 index 0000000000..6a0e953f0e --- /dev/null +++ b/contrib/clickhouse/src/IO/SharedThreadPools.cpp @@ -0,0 +1,141 @@ +#include <IO/SharedThreadPools.h> +#include <Common/CurrentMetrics.h> +#include <Common/ThreadPool.h> +#include <Core/Field.h> + +namespace CurrentMetrics +{ + extern const Metric IOThreads; + extern const Metric IOThreadsActive; + extern const Metric BackupsIOThreads; + extern const Metric BackupsIOThreadsActive; + extern const Metric MergeTreePartsLoaderThreads; + extern const Metric MergeTreePartsLoaderThreadsActive; + extern const Metric MergeTreePartsCleanerThreads; + extern const Metric MergeTreePartsCleanerThreadsActive; + extern const Metric MergeTreeOutdatedPartsLoaderThreads; + extern const Metric MergeTreeOutdatedPartsLoaderThreadsActive; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + + +StaticThreadPool::StaticThreadPool( + const String & name_, + CurrentMetrics::Metric threads_metric_, + CurrentMetrics::Metric threads_active_metric_) + : name(name_) + , threads_metric(threads_metric_) + , threads_active_metric(threads_active_metric_) +{ +} + +void StaticThreadPool::initialize(size_t max_threads, size_t max_free_threads, size_t queue_size) +{ + if (instance) + throw Exception(ErrorCodes::LOGICAL_ERROR, "The {} is initialized twice", name); + + /// By default enabling "turbo mode" won't affect the number of threads anyhow + max_threads_turbo = max_threads; + max_threads_normal = max_threads; + instance = std::make_unique<ThreadPool>( + threads_metric, + threads_active_metric, + max_threads, + max_free_threads, + queue_size, + /* shutdown_on_exception= */ false); +} + +void 
StaticThreadPool::reloadConfiguration(size_t max_threads, size_t max_free_threads, size_t queue_size) +{ + if (!instance) + throw Exception(ErrorCodes::LOGICAL_ERROR, "The {} is not initialized", name); + + instance->setMaxThreads(turbo_mode_enabled > 0 ? max_threads_turbo : max_threads); + instance->setMaxFreeThreads(max_free_threads); + instance->setQueueSize(queue_size); +} + + +ThreadPool & StaticThreadPool::get() +{ + if (!instance) + throw Exception(ErrorCodes::LOGICAL_ERROR, "The {} is not initialized", name); + + return *instance; +} + +void StaticThreadPool::enableTurboMode() +{ + if (!instance) + throw Exception(ErrorCodes::LOGICAL_ERROR, "The {} is not initialized", name); + + std::lock_guard lock(mutex); + + ++turbo_mode_enabled; + if (turbo_mode_enabled == 1) + instance->setMaxThreads(max_threads_turbo); +} + +void StaticThreadPool::disableTurboMode() +{ + if (!instance) + throw Exception(ErrorCodes::LOGICAL_ERROR, "The {} is not initialized", name); + + std::lock_guard lock(mutex); + + --turbo_mode_enabled; + if (turbo_mode_enabled == 0) + instance->setMaxThreads(max_threads_normal); +} + +void StaticThreadPool::setMaxTurboThreads(size_t max_threads_turbo_) +{ + if (!instance) + throw Exception(ErrorCodes::LOGICAL_ERROR, "The {} is not initialized", name); + + std::lock_guard lock(mutex); + + max_threads_turbo = max_threads_turbo_; + if (turbo_mode_enabled > 0) + instance->setMaxThreads(max_threads_turbo); +} + +StaticThreadPool & getIOThreadPool() +{ + static StaticThreadPool instance("IOThreadPool", CurrentMetrics::IOThreads, CurrentMetrics::IOThreadsActive); + return instance; +} + +StaticThreadPool & getBackupsIOThreadPool() +{ + static StaticThreadPool instance("BackupsIOThreadPool", CurrentMetrics::BackupsIOThreads, CurrentMetrics::BackupsIOThreadsActive); + return instance; +} + +StaticThreadPool & getActivePartsLoadingThreadPool() +{ + static StaticThreadPool instance("MergeTreePartsLoaderThreadPool", CurrentMetrics::MergeTreePartsLoaderThreads, CurrentMetrics::MergeTreePartsLoaderThreadsActive); + return instance; +} + +StaticThreadPool & getPartsCleaningThreadPool() +{ + static StaticThreadPool instance("MergeTreePartsCleanerThreadPool", CurrentMetrics::MergeTreePartsCleanerThreads, CurrentMetrics::MergeTreePartsCleanerThreadsActive); + return instance; +} + +StaticThreadPool & getOutdatedPartsLoadingThreadPool() +{ + static StaticThreadPool instance("MergeTreeOutdatedPartsLoaderThreadPool", CurrentMetrics::MergeTreeOutdatedPartsLoaderThreads, CurrentMetrics::MergeTreeOutdatedPartsLoaderThreadsActive); + return instance; +} + +} diff --git a/contrib/clickhouse/src/IO/SharedThreadPools.h b/contrib/clickhouse/src/IO/SharedThreadPools.h new file mode 100644 index 0000000000..188a2a4f00 --- /dev/null +++ b/contrib/clickhouse/src/IO/SharedThreadPools.h @@ -0,0 +1,64 @@ +#pragma once + +#include <base/types.h> +#include <Common/ThreadPool_fwd.h> +#include <Common/CurrentMetrics.h> + +#include <cstdlib> +#include <memory> +#include <mutex> + +namespace DB +{ + +class StaticThreadPool +{ +public: + StaticThreadPool( + const String & name_, + CurrentMetrics::Metric threads_metric_, + CurrentMetrics::Metric threads_active_metric_); + + ThreadPool & get(); + + void initialize(size_t max_threads, size_t max_free_threads, size_t queue_size); + void reloadConfiguration(size_t max_threads, size_t max_free_threads, size_t queue_size); + + /// At runtime we can increase the number of threads up the specified limit + /// This is needed to utilize as much a possible resources to 
accomplish some task. + void setMaxTurboThreads(size_t max_threads_turbo_); + void enableTurboMode(); + void disableTurboMode(); + +private: + const String name; + const CurrentMetrics::Metric threads_metric; + const CurrentMetrics::Metric threads_active_metric; + + std::unique_ptr<ThreadPool> instance; + std::mutex mutex; + size_t max_threads_turbo = 0; + size_t max_threads_normal = 0; + /// If this counter is > 0 - this specific mode is enabled + size_t turbo_mode_enabled = 0; +}; + +/// ThreadPool used for the IO. +StaticThreadPool & getIOThreadPool(); + +/// ThreadPool used for the Backup IO. +StaticThreadPool & getBackupsIOThreadPool(); + +/// ThreadPool used for the loading of Outdated data parts for MergeTree tables. +StaticThreadPool & getActivePartsLoadingThreadPool(); + +/// ThreadPool used for deleting data parts for MergeTree tables. +StaticThreadPool & getPartsCleaningThreadPool(); + +/// This ThreadPool is used for the loading of Outdated data parts for MergeTree tables. +/// Normally we will just load Outdated data parts concurrently in background, but in +/// case when we need to synchronously wait for the loading to be finished, we can increase +/// the number of threads by calling enableTurboMode() :-) +StaticThreadPool & getOutdatedPartsLoadingThreadPool(); + +} diff --git a/contrib/clickhouse/src/IO/SnappyReadBuffer.cpp b/contrib/clickhouse/src/IO/SnappyReadBuffer.cpp new file mode 100644 index 0000000000..74a1784d6d --- /dev/null +++ b/contrib/clickhouse/src/IO/SnappyReadBuffer.cpp @@ -0,0 +1,75 @@ +#include "clickhouse_config.h" + +#if USE_SNAPPY +#include <memory> +#include <fcntl.h> +#include <sys/types.h> + +#include <snappy.h> + +#include <IO/copyData.h> +#include <IO/WriteBufferFromString.h> +#include <IO/WriteHelpers.h> + +#include "SnappyReadBuffer.h" + +namespace DB +{ +namespace ErrorCodes +{ + extern const int SNAPPY_UNCOMPRESS_FAILED; + extern const int SEEK_POSITION_OUT_OF_BOUND; +} + + +SnappyReadBuffer::SnappyReadBuffer(std::unique_ptr<ReadBuffer> in_, size_t buf_size, char * existing_memory, size_t alignment) + : BufferWithOwnMemory<SeekableReadBuffer>(buf_size, existing_memory, alignment), in(std::move(in_)) +{ +} + +bool SnappyReadBuffer::nextImpl() +{ + if (compress_buffer.empty() && uncompress_buffer.empty()) + { + WriteBufferFromString wb(compress_buffer); + copyData(*in, wb); + + bool success = snappy::Uncompress(compress_buffer.data(), wb.count(), &uncompress_buffer); + if (!success) + { + throw Exception(ErrorCodes::SNAPPY_UNCOMPRESS_FAILED, "snappy uncomress failed: "); + } + BufferBase::set(const_cast<char *>(uncompress_buffer.data()), uncompress_buffer.size(), 0); + return true; + } + return false; +} + +SnappyReadBuffer::~SnappyReadBuffer() = default; + +off_t SnappyReadBuffer::seek(off_t off, int whence) +{ + off_t new_pos; + if (whence == SEEK_SET) + new_pos = off; + else if (whence == SEEK_CUR) + new_pos = count() + off; + else + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Only SEEK_SET and SEEK_CUR seek modes allowed."); + + working_buffer = internal_buffer; + if (new_pos < 0 || new_pos > off_t(working_buffer.size())) + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, + "Cannot seek through buffer because seek position ({}) is out of bounds [0, {}]", + new_pos, working_buffer.size()); + position() = working_buffer.begin() + new_pos; + return new_pos; +} + +off_t SnappyReadBuffer::getPosition() +{ + return count(); +} + +} +#endif diff --git a/contrib/clickhouse/src/IO/SnappyReadBuffer.h 
b/contrib/clickhouse/src/IO/SnappyReadBuffer.h new file mode 100644 index 0000000000..532fcb14e8 --- /dev/null +++ b/contrib/clickhouse/src/IO/SnappyReadBuffer.h @@ -0,0 +1,35 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_SNAPPY + +#include <IO/ReadBuffer.h> +#include <IO/SeekableReadBuffer.h> +#include <IO/BufferWithOwnMemory.h> + +namespace DB +{ +class SnappyReadBuffer : public BufferWithOwnMemory<SeekableReadBuffer> +{ +public: + explicit SnappyReadBuffer( + std::unique_ptr<ReadBuffer> in_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~SnappyReadBuffer() override; + + bool nextImpl() override; + off_t seek(off_t off, int whence) override; + off_t getPosition() override; + +private: + std::unique_ptr<ReadBuffer> in; + String compress_buffer; + String uncompress_buffer; +}; + +} +#endif diff --git a/contrib/clickhouse/src/IO/SnappyWriteBuffer.cpp b/contrib/clickhouse/src/IO/SnappyWriteBuffer.cpp new file mode 100644 index 0000000000..4a27615f24 --- /dev/null +++ b/contrib/clickhouse/src/IO/SnappyWriteBuffer.cpp @@ -0,0 +1,92 @@ +#include "clickhouse_config.h" + +#if USE_SNAPPY +#include <cstring> + +#include <snappy.h> + +#include <Common/ErrorCodes.h> +#include "SnappyWriteBuffer.h" + +namespace DB +{ +namespace ErrorCodes +{ + extern const int SNAPPY_COMPRESS_FAILED; +} + +SnappyWriteBuffer::SnappyWriteBuffer(std::unique_ptr<WriteBuffer> out_, size_t buf_size, char * existing_memory, size_t alignment) + : BufferWithOwnMemory<WriteBuffer>(buf_size, existing_memory, alignment), out(std::move(out_)) +{ +} + +SnappyWriteBuffer::~SnappyWriteBuffer() +{ + finish(); +} + +void SnappyWriteBuffer::nextImpl() +{ + if (!offset()) + { + return; + } + + const char * in_data = reinterpret_cast<const char *>(working_buffer.begin()); + size_t in_available = offset(); + uncompress_buffer.append(in_data, in_available); +} + +void SnappyWriteBuffer::finish() +{ + if (finished) + return; + + try + { + finishImpl(); + out->finalize(); + finished = true; + } + catch (...) + { + /// Do not try to flush next time after exception. + out->position() = out->buffer().begin(); + finished = true; + throw; + } +} + +void SnappyWriteBuffer::finishImpl() +{ + next(); + + bool success = snappy::Compress(uncompress_buffer.data(), uncompress_buffer.size(), &compress_buffer); + if (!success) + { + throw Exception(ErrorCodes::SNAPPY_COMPRESS_FAILED, "snappy compress failed: "); + } + + char * in_data = compress_buffer.data(); + size_t in_available = compress_buffer.size(); + char * out_data = nullptr; + size_t out_capacity = 0; + size_t len = 0; + while (in_available > 0) + { + out->nextIfAtEnd(); + out_data = out->position(); + out_capacity = out->buffer().end() - out->position(); + len = in_available > out_capacity ? out_capacity : in_available; + + memcpy(out_data, in_data, len); + in_data += len; + in_available -= len; + out->position() += len; + } +} + +} + +#endif + diff --git a/contrib/clickhouse/src/IO/SnappyWriteBuffer.h b/contrib/clickhouse/src/IO/SnappyWriteBuffer.h new file mode 100644 index 0000000000..73652f33a5 --- /dev/null +++ b/contrib/clickhouse/src/IO/SnappyWriteBuffer.h @@ -0,0 +1,41 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_SNAPPY +#include <IO/BufferWithOwnMemory.h> +#include <IO/WriteBuffer.h> + +namespace DB +{ +/// Performs compression using snappy library and write compressed data to the underlying buffer. 
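+///
+/// Illustrative usage sketch (not part of the original sources; the output file name is
+/// hypothetical, and WriteBufferFromFile / writeString are assumed to be available):
+///
+///     auto file_out = std::make_unique<WriteBufferFromFile>("data.snappy");
+///     SnappyWriteBuffer snappy_out(std::move(file_out));
+///     writeString("hello, snappy", snappy_out);
+///     snappy_out.finalize();   // compresses the accumulated data and writes it downstream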
+class SnappyWriteBuffer : public BufferWithOwnMemory<WriteBuffer> +{ +public: + explicit SnappyWriteBuffer( + std::unique_ptr<WriteBuffer> out_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~SnappyWriteBuffer() override; + + void finalizeImpl() override { finish(); } + +private: + void nextImpl() override; + + void finishImpl(); + void finish(); + + std::unique_ptr<WriteBuffer> out; + bool finished = false; + + String uncompress_buffer; + String compress_buffer; +}; + +} + +#endif + diff --git a/contrib/clickhouse/src/IO/StdIStreamFromMemory.cpp b/contrib/clickhouse/src/IO/StdIStreamFromMemory.cpp new file mode 100644 index 0000000000..3242a7e638 --- /dev/null +++ b/contrib/clickhouse/src/IO/StdIStreamFromMemory.cpp @@ -0,0 +1,62 @@ +#include <IO/StdIStreamFromMemory.h> + +namespace DB +{ + +StdIStreamFromMemory::MemoryBuf::MemoryBuf(char * begin_, size_t size_) + : begin(begin_) + , size(size_) +{ + this->setg(begin, begin, begin + size); +} + +StdIStreamFromMemory::MemoryBuf::int_type StdIStreamFromMemory::MemoryBuf::underflow() +{ + if (gptr() < egptr()) + return traits_type::to_int_type(*gptr()); + return traits_type::eof(); +} + +StdIStreamFromMemory::MemoryBuf::pos_type +StdIStreamFromMemory::MemoryBuf::seekoff(off_type off, std::ios_base::seekdir way, + std::ios_base::openmode mode) +{ + bool out_mode = (std::ios_base::out & mode) != 0; + if (out_mode) + return off_type(-1); + + off_type ret(-1); + + if (way == std::ios_base::beg) + ret = 0; + else if (way == std::ios_base::cur) + ret = gptr() - begin; + else if (way == std::ios_base::end) + ret = size; + + if (ret == off_type(-1)) + return ret; + + ret += off; + if (!(ret >= 0 && size_t(ret) <= size)) + return off_type(-1); + + this->setg(begin, begin + ret, begin + size); + + return pos_type(ret); +} + +StdIStreamFromMemory::MemoryBuf::pos_type StdIStreamFromMemory::MemoryBuf::seekpos(pos_type sp, + std::ios_base::openmode mode) +{ + return seekoff(off_type(sp), std::ios_base::beg, mode); +} + +StdIStreamFromMemory::StdIStreamFromMemory(char * begin_, size_t size_) + : std::iostream(nullptr) + , mem_buf(begin_, size_) +{ + init(&mem_buf); +} + +} diff --git a/contrib/clickhouse/src/IO/StdIStreamFromMemory.h b/contrib/clickhouse/src/IO/StdIStreamFromMemory.h new file mode 100644 index 0000000000..64b147fd29 --- /dev/null +++ b/contrib/clickhouse/src/IO/StdIStreamFromMemory.h @@ -0,0 +1,36 @@ +#pragma once + +#include <iostream> + +namespace DB +{ + +/// StdIStreamFromMemory is used in WriteBufferFromS3 as a stream which is passed to the S3::Client +/// It provides istream interface (only reading) over the memory. 
+/// However S3::Client requires iostream interface it only reads from the stream + +class StdIStreamFromMemory : public std::iostream +{ + struct MemoryBuf: std::streambuf + { + MemoryBuf(char * begin_, size_t size_); + + int_type underflow() override; + + pos_type seekoff(off_type off, std::ios_base::seekdir way, + std::ios_base::openmode mode) override; + + pos_type seekpos(pos_type sp, + std::ios_base::openmode mode) override; + + char * begin = nullptr; + size_t size = 0; + }; + + MemoryBuf mem_buf; + +public: + StdIStreamFromMemory(char * begin_, size_t size_); +}; + +} diff --git a/contrib/clickhouse/src/IO/StdStreamBufFromReadBuffer.cpp b/contrib/clickhouse/src/IO/StdStreamBufFromReadBuffer.cpp new file mode 100644 index 0000000000..a814dff040 --- /dev/null +++ b/contrib/clickhouse/src/IO/StdStreamBufFromReadBuffer.cpp @@ -0,0 +1,111 @@ +#include <IO/StdStreamBufFromReadBuffer.h> +#include <IO/SeekableReadBuffer.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int SEEK_POSITION_OUT_OF_BOUND; +} + + +StdStreamBufFromReadBuffer::StdStreamBufFromReadBuffer(std::unique_ptr<ReadBuffer> read_buffer_, size_t size_) + : read_buffer(std::move(read_buffer_)), seekable_read_buffer(dynamic_cast<SeekableReadBuffer *>(read_buffer.get())), size(size_) +{ +} + +StdStreamBufFromReadBuffer::StdStreamBufFromReadBuffer(ReadBuffer & read_buffer_, size_t size_) : size(size_) +{ + if (dynamic_cast<SeekableReadBuffer *>(&read_buffer_)) + { + read_buffer = wrapSeekableReadBufferReference(static_cast<SeekableReadBuffer &>(read_buffer_)); + seekable_read_buffer = static_cast<SeekableReadBuffer *>(read_buffer.get()); + } + else + { + read_buffer = wrapReadBufferReference(read_buffer_); + } +} + +StdStreamBufFromReadBuffer::~StdStreamBufFromReadBuffer() = default; + +int StdStreamBufFromReadBuffer::underflow() +{ + char c; + if (!read_buffer->peek(c)) + return std::char_traits<char>::eof(); + return c; +} + +std::streamsize StdStreamBufFromReadBuffer::showmanyc() +{ + return read_buffer->available(); +} + +std::streamsize StdStreamBufFromReadBuffer::xsgetn(char_type* s, std::streamsize count) +{ + return read_buffer->read(s, count); +} + +std::streampos StdStreamBufFromReadBuffer::seekoff(std::streamoff off, std::ios_base::seekdir dir, std::ios_base::openmode which) +{ + if (dir == std::ios_base::beg) + return seekpos(off, which); + else if (dir == std::ios_base::cur) + return seekpos(getCurrentPosition() + off, which); + else if (dir == std::ios_base::end) + return seekpos(size + off, which); + else + throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong seek's base {}", static_cast<int>(dir)); +} + +std::streampos StdStreamBufFromReadBuffer::seekpos(std::streampos pos, std::ios_base::openmode which) +{ + if (!(which & std::ios_base::in)) + throw Exception( + ErrorCodes::LOGICAL_ERROR, "Wrong seek mode {}", static_cast<int>(which)); + + std::streamoff offset = pos - getCurrentPosition(); + if (!offset) + return pos; + + if ((read_buffer->buffer().begin() <= read_buffer->position() + offset) && (read_buffer->position() + offset <= read_buffer->buffer().end())) + { + read_buffer->position() += offset; + return pos; + } + + if (seekable_read_buffer) + return seekable_read_buffer->seek(pos, SEEK_SET); + + if (offset > 0) + { + read_buffer->ignore(offset); + return pos; + } + + throw Exception(ErrorCodes::SEEK_POSITION_OUT_OF_BOUND, "Seek's offset {} is out of bound", pos); +} + +std::streampos StdStreamBufFromReadBuffer::getCurrentPosition() const +{ + if 
(seekable_read_buffer) + return seekable_read_buffer->getPosition(); + else + return read_buffer->count(); +} + +std::streamsize StdStreamBufFromReadBuffer::xsputn(const char*, std::streamsize) +{ + throw Exception(ErrorCodes::LOGICAL_ERROR, "StdStreamBufFromReadBuffer cannot be used for output"); +} + +int StdStreamBufFromReadBuffer::overflow(int) +{ + throw Exception(ErrorCodes::LOGICAL_ERROR, "StdStreamBufFromReadBuffer cannot be used for output"); +} + +} diff --git a/contrib/clickhouse/src/IO/StdStreamBufFromReadBuffer.h b/contrib/clickhouse/src/IO/StdStreamBufFromReadBuffer.h new file mode 100644 index 0000000000..ff16b91e98 --- /dev/null +++ b/contrib/clickhouse/src/IO/StdStreamBufFromReadBuffer.h @@ -0,0 +1,39 @@ +#pragma once + +#include <memory> +#include <streambuf> + + +namespace DB +{ +class ReadBuffer; +class SeekableReadBuffer; + +/// `std::streambuf`-compatible wrapper around a ReadBuffer. +class StdStreamBufFromReadBuffer : public std::streambuf +{ +public: + using Base = std::streambuf; + + explicit StdStreamBufFromReadBuffer(std::unique_ptr<ReadBuffer> read_buffer_, size_t size_); + explicit StdStreamBufFromReadBuffer(ReadBuffer & read_buffer_, size_t size_); + ~StdStreamBufFromReadBuffer() override; + +private: + int underflow() override; + std::streamsize showmanyc() override; + std::streamsize xsgetn(char* s, std::streamsize count) override; + std::streampos seekoff(std::streamoff off, std::ios_base::seekdir dir, std::ios_base::openmode which) override; + std::streampos seekpos(std::streampos pos, std::ios_base::openmode which) override; + + std::streamsize xsputn(const char* s, std::streamsize n) override; + int overflow(int c) override; + + std::streampos getCurrentPosition() const; + + std::unique_ptr<ReadBuffer> read_buffer; + SeekableReadBuffer * seekable_read_buffer = nullptr; + size_t size; +}; + +} diff --git a/contrib/clickhouse/src/IO/StdStreamFromReadBuffer.h b/contrib/clickhouse/src/IO/StdStreamFromReadBuffer.h new file mode 100644 index 0000000000..ff327dc342 --- /dev/null +++ b/contrib/clickhouse/src/IO/StdStreamFromReadBuffer.h @@ -0,0 +1,38 @@ +#pragma once + +#include <IO/StdStreamBufFromReadBuffer.h> +#include <memory> + + +namespace DB +{ +class ReadBuffer; + +/// `std::istream`-compatible wrapper around a ReadBuffer. +class StdIStreamFromReadBuffer : public std::istream +{ +public: + using Base = std::istream; + StdIStreamFromReadBuffer(std::unique_ptr<ReadBuffer> buf, size_t size) : Base(&stream_buf), stream_buf(std::move(buf), size) { } + StdIStreamFromReadBuffer(ReadBuffer & buf, size_t size) : Base(&stream_buf), stream_buf(buf, size) { } + StdStreamBufFromReadBuffer * rdbuf() const { return const_cast<StdStreamBufFromReadBuffer *>(&stream_buf); } + +private: + StdStreamBufFromReadBuffer stream_buf; +}; + + +/// `std::iostream`-compatible wrapper around a ReadBuffer. 
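+///
+/// Illustrative usage sketch (not part of the original sources; `read_buffer` is an existing
+/// ReadBuffer and `file_size` its known size):
+///
+///     StdStreamFromReadBuffer stream(read_buffer, file_size);
+///     std::string line;
+///     std::getline(stream, line);   // standard iostream API backed by the ReadBuffer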
+class StdStreamFromReadBuffer : public std::iostream +{ +public: + using Base = std::iostream; + StdStreamFromReadBuffer(std::unique_ptr<ReadBuffer> buf, size_t size) : Base(&stream_buf), stream_buf(std::move(buf), size) { } + StdStreamFromReadBuffer(ReadBuffer & buf, size_t size) : Base(&stream_buf), stream_buf(buf, size) { } + StdStreamBufFromReadBuffer * rdbuf() const { return const_cast<StdStreamBufFromReadBuffer *>(&stream_buf); } + +private: + StdStreamBufFromReadBuffer stream_buf; +}; + +} diff --git a/contrib/clickhouse/src/IO/SwapHelper.cpp b/contrib/clickhouse/src/IO/SwapHelper.cpp new file mode 100644 index 0000000000..4a1cc8acf4 --- /dev/null +++ b/contrib/clickhouse/src/IO/SwapHelper.cpp @@ -0,0 +1,17 @@ +#include <IO/SwapHelper.h> + +namespace DB +{ + +SwapHelper::SwapHelper(BufferBase & b1_, BufferBase & b2_) + : b1(b1_), b2(b2_) +{ + b1.swap(b2); +} + +SwapHelper::~SwapHelper() +{ + b1.swap(b2); +} + +} diff --git a/contrib/clickhouse/src/IO/SwapHelper.h b/contrib/clickhouse/src/IO/SwapHelper.h new file mode 100644 index 0000000000..fcf32927f2 --- /dev/null +++ b/contrib/clickhouse/src/IO/SwapHelper.h @@ -0,0 +1,19 @@ +#pragma once + +#include <IO/BufferBase.h> + +namespace DB +{ + +class SwapHelper +{ +public: + SwapHelper(BufferBase & b1_, BufferBase & b2_); + ~SwapHelper(); + +private: + BufferBase & b1; + BufferBase & b2; +}; + +} diff --git a/contrib/clickhouse/src/IO/SynchronousReader.cpp b/contrib/clickhouse/src/IO/SynchronousReader.cpp new file mode 100644 index 0000000000..e1c654e48a --- /dev/null +++ b/contrib/clickhouse/src/IO/SynchronousReader.cpp @@ -0,0 +1,89 @@ +#include <IO/SynchronousReader.h> +#include <Common/assert_cast.h> +#include <Common/Exception.h> +#include <Common/CurrentMetrics.h> +#include <Common/ProfileEvents.h> +#include <Common/Stopwatch.h> +#include <base/errnoToString.h> +#include <unordered_map> +#include <mutex> +#include <unistd.h> +#include <fcntl.h> + + +namespace ProfileEvents +{ + extern const Event ReadBufferFromFileDescriptorRead; + extern const Event ReadBufferFromFileDescriptorReadFailed; + extern const Event ReadBufferFromFileDescriptorReadBytes; + extern const Event DiskReadElapsedMicroseconds; +} + +namespace CurrentMetrics +{ + extern const Metric Read; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_READ_FROM_FILE_DESCRIPTOR; + extern const int CANNOT_ADVISE; +} + + +std::future<IAsynchronousReader::Result> SynchronousReader::submit(Request request) +{ + /// If size is zero, then read() cannot be distinguished from EOF + assert(request.size); + + int fd = assert_cast<const LocalFileDescriptor &>(*request.descriptor).fd; + +#if defined(POSIX_FADV_WILLNEED) + if (0 != posix_fadvise(fd, request.offset, request.size, POSIX_FADV_WILLNEED)) + throwFromErrno("Cannot posix_fadvise", ErrorCodes::CANNOT_ADVISE); +#endif + + return std::async(std::launch::deferred, [fd, request] + { + ProfileEvents::increment(ProfileEvents::ReadBufferFromFileDescriptorRead); + Stopwatch watch(CLOCK_MONOTONIC); + + size_t bytes_read = 0; + while (!bytes_read) + { + ssize_t res = 0; + + { + CurrentMetrics::Increment metric_increment{CurrentMetrics::Read}; + res = ::pread(fd, request.buf, request.size, request.offset); + } + if (!res) + break; + + if (-1 == res && errno != EINTR) + { + ProfileEvents::increment(ProfileEvents::ReadBufferFromFileDescriptorReadFailed); + throwFromErrno(fmt::format("Cannot read from file {}", fd), ErrorCodes::CANNOT_READ_FROM_FILE_DESCRIPTOR); + } + + if (res > 0) + bytes_read += res; + } + + 
ProfileEvents::increment(ProfileEvents::ReadBufferFromFileDescriptorReadBytes, bytes_read); + + /// It reports real time spent including the time spent while thread was preempted doing nothing. + /// And it is Ok for the purpose of this watch (it is used to lower the number of threads to read from tables). + /// Sometimes it is better to use taskstats::blkio_delay_total, but it is quite expensive to get it + /// (NetlinkMetricsProvider has about 500K RPS). + watch.stop(); + ProfileEvents::increment(ProfileEvents::DiskReadElapsedMicroseconds, watch.elapsedMicroseconds()); + + return Result{ .size = bytes_read, .offset = request.ignore }; + }); +} + +} diff --git a/contrib/clickhouse/src/IO/SynchronousReader.h b/contrib/clickhouse/src/IO/SynchronousReader.h new file mode 100644 index 0000000000..238d6e9371 --- /dev/null +++ b/contrib/clickhouse/src/IO/SynchronousReader.h @@ -0,0 +1,20 @@ +#pragma once + +#include <IO/AsynchronousReader.h> + + +namespace DB +{ + +/** Implementation of IAsynchronousReader that in fact synchronous. + * The only addition is posix_fadvise. + */ +class SynchronousReader final : public IAsynchronousReader +{ +public: + std::future<Result> submit(Request request) override; + + void wait() override {} +}; + +} diff --git a/contrib/clickhouse/src/IO/TimeoutSetter.cpp b/contrib/clickhouse/src/IO/TimeoutSetter.cpp new file mode 100644 index 0000000000..b8b7a81470 --- /dev/null +++ b/contrib/clickhouse/src/IO/TimeoutSetter.cpp @@ -0,0 +1,56 @@ +#include <IO/TimeoutSetter.h> + +#include <Common/logger_useful.h> + + +namespace DB +{ + +TimeoutSetter::TimeoutSetter(Poco::Net::StreamSocket & socket_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_, + bool limit_max_timeout) + : socket(socket_), send_timeout(send_timeout_), receive_timeout(receive_timeout_) +{ + old_send_timeout = socket.getSendTimeout(); + old_receive_timeout = socket.getReceiveTimeout(); + + if (!limit_max_timeout || old_send_timeout > send_timeout) + socket.setSendTimeout(send_timeout); + + if (!limit_max_timeout || old_receive_timeout > receive_timeout) + socket.setReceiveTimeout(receive_timeout); +} + +TimeoutSetter::TimeoutSetter(Poco::Net::StreamSocket & socket_, Poco::Timespan timeout_, bool limit_max_timeout) + : TimeoutSetter(socket_, timeout_, timeout_, limit_max_timeout) +{ +} + +TimeoutSetter::~TimeoutSetter() +{ + if (was_reset) + return; + + try + { + reset(); + } + catch (...) + { + tryLogCurrentException("Client", "TimeoutSetter: Can't reset timeouts"); + } +} + +void TimeoutSetter::reset() +{ + bool connected = socket.impl()->initialized(); + if (!connected) + return; + + socket.setSendTimeout(old_send_timeout); + socket.setReceiveTimeout(old_receive_timeout); + was_reset = true; +} + +} diff --git a/contrib/clickhouse/src/IO/TimeoutSetter.h b/contrib/clickhouse/src/IO/TimeoutSetter.h new file mode 100644 index 0000000000..3479986d7f --- /dev/null +++ b/contrib/clickhouse/src/IO/TimeoutSetter.h @@ -0,0 +1,34 @@ +#pragma once + +#include <Poco/Net/StreamSocket.h> +#include <Poco/Timespan.h> + + +namespace DB +{ +/// Temporarily overrides socket send/receive timeouts and reset them back into destructor (or manually by calling reset method) +/// If "limit_max_timeout" is true, timeouts could be only decreased (maxed by previous value). 
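+///
+/// Illustrative usage sketch (not part of the original sources; `socket` is an existing
+/// Poco::Net::StreamSocket):
+///
+///     {
+///         TimeoutSetter timeout_setter(socket, Poco::Timespan(5, 0));  // 5 s for send and receive
+///         ... perform the request ...
+///     }   // destructor restores the previous timeouts (unless reset() was already called)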
+struct TimeoutSetter +{ + TimeoutSetter(Poco::Net::StreamSocket & socket_, + Poco::Timespan send_timeout_, + Poco::Timespan receive_timeout_, + bool limit_max_timeout = false); + + TimeoutSetter(Poco::Net::StreamSocket & socket_, Poco::Timespan timeout_, bool limit_max_timeout = false); + + ~TimeoutSetter(); + + /// Reset timeouts back. + void reset(); + + Poco::Net::StreamSocket & socket; + + Poco::Timespan send_timeout; + Poco::Timespan receive_timeout; + + Poco::Timespan old_send_timeout; + Poco::Timespan old_receive_timeout; + bool was_reset = false; +}; +} diff --git a/contrib/clickhouse/src/IO/UncompressedCache.h b/contrib/clickhouse/src/IO/UncompressedCache.h new file mode 100644 index 0000000000..702804cdda --- /dev/null +++ b/contrib/clickhouse/src/IO/UncompressedCache.h @@ -0,0 +1,80 @@ +#pragma once + +#include <Common/SipHash.h> +#include <Common/ProfileEvents.h> +#include <Common/HashTable/Hash.h> +#include <IO/BufferWithOwnMemory.h> +#include <Common/CacheBase.h> + + +namespace ProfileEvents +{ + extern const Event UncompressedCacheHits; + extern const Event UncompressedCacheMisses; + extern const Event UncompressedCacheWeightLost; +} + +namespace DB +{ + + +struct UncompressedCacheCell +{ + Memory<> data; + size_t compressed_size; + UInt32 additional_bytes; +}; + +struct UncompressedSizeWeightFunction +{ + size_t operator()(const UncompressedCacheCell & x) const + { + return x.data.size(); + } +}; + + +/** Cache of decompressed blocks for implementation of CachedCompressedReadBuffer. thread-safe. + */ +class UncompressedCache : public CacheBase<UInt128, UncompressedCacheCell, UInt128TrivialHash, UncompressedSizeWeightFunction> +{ +private: + using Base = CacheBase<UInt128, UncompressedCacheCell, UInt128TrivialHash, UncompressedSizeWeightFunction>; + +public: + UncompressedCache(const String & cache_policy, size_t max_size_in_bytes, double size_ratio) + : Base(cache_policy, max_size_in_bytes, 0, size_ratio) {} + + /// Calculate key from path to file and offset. 
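+    ///
+    /// Illustrative usage sketch (not part of the original sources; the path and offset are
+    /// hypothetical, `cache` is an UncompressedCache instance and `decompressBlock` a
+    /// user-supplied callable returning a MappedPtr):
+    ///
+    ///     auto key = UncompressedCache::hash("/var/lib/data/part.bin", 1024);
+    ///     auto cell = cache.getOrSet(key, decompressBlock);   // counts a cache hit or miss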
+ static UInt128 hash(const String & path_to_file, size_t offset) + { + SipHash hash; + hash.update(path_to_file.data(), path_to_file.size() + 1); + hash.update(offset); + + return hash.get128(); + } + + template <typename LoadFunc> + MappedPtr getOrSet(const Key & key, LoadFunc && load) + { + auto result = Base::getOrSet(key, std::forward<LoadFunc>(load)); + + if (result.second) + ProfileEvents::increment(ProfileEvents::UncompressedCacheMisses); + else + ProfileEvents::increment(ProfileEvents::UncompressedCacheHits); + + return result.first; + } + +private: + void onRemoveOverflowWeightLoss(size_t weight_loss) override + { + ProfileEvents::increment(ProfileEvents::UncompressedCacheWeightLost, weight_loss); + } +}; + +using UncompressedCachePtr = std::shared_ptr<UncompressedCache>; + +} diff --git a/contrib/clickhouse/src/IO/UseSSL.cpp b/contrib/clickhouse/src/IO/UseSSL.cpp new file mode 100644 index 0000000000..7a2ff928e0 --- /dev/null +++ b/contrib/clickhouse/src/IO/UseSSL.cpp @@ -0,0 +1,24 @@ +#include "UseSSL.h" + +#include "clickhouse_config.h" + +#if USE_SSL +# include <Poco/Net/SSLManager.h> +#endif + +namespace DB +{ +UseSSL::UseSSL() +{ +#if USE_SSL + Poco::Net::initializeSSL(); +#endif +} + +UseSSL::~UseSSL() +{ +#if USE_SSL + Poco::Net::uninitializeSSL(); +#endif +} +} diff --git a/contrib/clickhouse/src/IO/UseSSL.h b/contrib/clickhouse/src/IO/UseSSL.h new file mode 100644 index 0000000000..324f318edb --- /dev/null +++ b/contrib/clickhouse/src/IO/UseSSL.h @@ -0,0 +1,13 @@ +#pragma once + +#include <boost/noncopyable.hpp> + +namespace DB +{ +// http://stackoverflow.com/questions/18315472/https-request-in-c-using-poco +struct UseSSL : private boost::noncopyable +{ + UseSSL(); + ~UseSSL(); +}; +} diff --git a/contrib/clickhouse/src/IO/VarInt.cpp b/contrib/clickhouse/src/IO/VarInt.cpp new file mode 100644 index 0000000000..a4b249b01d --- /dev/null +++ b/contrib/clickhouse/src/IO/VarInt.cpp @@ -0,0 +1,16 @@ +#include <IO/VarInt.h> +#include <Common/Exception.h> + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ATTEMPT_TO_READ_AFTER_EOF; +} + +void throwReadAfterEOF() +{ + throw Exception(ErrorCodes::ATTEMPT_TO_READ_AFTER_EOF, "Attempt to read after eof"); +} + +} diff --git a/contrib/clickhouse/src/IO/VarInt.h b/contrib/clickhouse/src/IO/VarInt.h new file mode 100644 index 0000000000..8d10055a3d --- /dev/null +++ b/contrib/clickhouse/src/IO/VarInt.h @@ -0,0 +1,218 @@ +#pragma once + +#include <base/types.h> +#include <base/defines.h> +#include <IO/ReadBuffer.h> +#include <IO/WriteBuffer.h> + + +namespace DB +{ + +/// Variable-Length Quantity (VLQ) Base-128 compression, also known as Variable Byte (VB) or Varint encoding. 
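+///
+/// Each byte stores 7 payload bits, least-significant group first; the high bit is set on every
+/// byte except the last. Illustrative round-trip sketch (not part of the original sources;
+/// WriteBufferFromOwnString and ReadBufferFromString are assumed to be available):
+///
+///     WriteBufferFromOwnString out;
+///     writeVarUInt(300, out);            // 300 = 0b10'0101100 -> bytes 0xAC 0x02
+///     ReadBufferFromString in(out.str());
+///     UInt64 x = 0;
+///     readVarUInt(x, in);                // x == 300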
+ +[[noreturn]] void throwReadAfterEOF(); + + +inline void writeVarUInt(UInt64 x, WriteBuffer & ostr) +{ + while (x > 0x7F) + { + uint8_t byte = 0x80 | (x & 0x7F); + + ostr.nextIfAtEnd(); + *ostr.position() = byte; + ++ostr.position(); + + x >>= 7; + } + + uint8_t final_byte = static_cast<uint8_t>(x); + + ostr.nextIfAtEnd(); + *ostr.position() = final_byte; + ++ostr.position(); +} + +inline void writeVarUInt(UInt64 x, std::ostream & ostr) +{ + while (x > 0x7F) + { + uint8_t byte = 0x80 | (x & 0x7F); + ostr.put(byte); + + x >>= 7; + } + + uint8_t final_byte = static_cast<uint8_t>(x); + ostr.put(final_byte); +} + +inline char * writeVarUInt(UInt64 x, char * ostr) +{ + while (x > 0x7F) + { + uint8_t byte = 0x80 | (x & 0x7F); + + *ostr = byte; + ++ostr; + + x >>= 7; + } + + uint8_t final_byte = static_cast<uint8_t>(x); + + *ostr = final_byte; + ++ostr; + + return ostr; +} + +template <typename Out> +inline void writeVarInt(Int64 x, Out & ostr) +{ + writeVarUInt(static_cast<UInt64>((x << 1) ^ (x >> 63)), ostr); +} + +inline char * writeVarInt(Int64 x, char * ostr) +{ + return writeVarUInt(static_cast<UInt64>((x << 1) ^ (x >> 63)), ostr); +} + +namespace impl +{ + +template <bool check_eof> +inline void readVarUInt(UInt64 & x, ReadBuffer & istr) +{ + x = 0; + for (size_t i = 0; i < 10; ++i) + { + if constexpr (check_eof) + if (istr.eof()) [[unlikely]] + throwReadAfterEOF(); + + UInt64 byte = *istr.position(); + ++istr.position(); + x |= (byte & 0x7F) << (7 * i); + + if (!(byte & 0x80)) + return; + } +} + +} + +inline void readVarUInt(UInt64 & x, ReadBuffer & istr) +{ + if (istr.buffer().end() - istr.position() >= 10) + return impl::readVarUInt<false>(x, istr); + return impl::readVarUInt<true>(x, istr); +} + +inline void readVarUInt(UInt64 & x, std::istream & istr) +{ + x = 0; + for (size_t i = 0; i < 10; ++i) + { + UInt64 byte = istr.get(); + x |= (byte & 0x7F) << (7 * i); + + if (!(byte & 0x80)) + return; + } +} + +inline const char * readVarUInt(UInt64 & x, const char * istr, size_t size) +{ + const char * end = istr + size; + + x = 0; + for (size_t i = 0; i < 10; ++i) + { + if (istr == end) [[unlikely]] + throwReadAfterEOF(); + + UInt64 byte = *istr; + ++istr; + x |= (byte & 0x7F) << (7 * i); + + if (!(byte & 0x80)) + return istr; + } + + return istr; +} + +template <typename In> +inline void readVarInt(Int64 & x, In & istr) +{ + readVarUInt(*reinterpret_cast<UInt64*>(&x), istr); + x = (static_cast<UInt64>(x) >> 1) ^ -(x & 1); +} + +inline const char * readVarInt(Int64 & x, const char * istr, size_t size) +{ + const char * res = readVarUInt(*reinterpret_cast<UInt64*>(&x), istr, size); + x = (static_cast<UInt64>(x) >> 1) ^ -(x & 1); + return res; +} + +inline void readVarUInt(UInt32 & x, ReadBuffer & istr) +{ + UInt64 tmp; + readVarUInt(tmp, istr); + x = static_cast<UInt32>(tmp); +} + +inline void readVarInt(Int32 & x, ReadBuffer & istr) +{ + Int64 tmp; + readVarInt(tmp, istr); + x = static_cast<Int32>(tmp); +} + +inline void readVarUInt(UInt16 & x, ReadBuffer & istr) +{ + UInt64 tmp; + readVarUInt(tmp, istr); + x = tmp; +} + +inline void readVarInt(Int16 & x, ReadBuffer & istr) +{ + Int64 tmp; + readVarInt(tmp, istr); + x = tmp; +} + +template <typename T> +requires (!std::is_same_v<T, UInt64>) +inline void readVarUInt(T & x, ReadBuffer & istr) +{ + UInt64 tmp; + readVarUInt(tmp, istr); + x = tmp; +} + +inline size_t getLengthOfVarUInt(UInt64 x) +{ + return x < (1ULL << 7) ? 1 + : (x < (1ULL << 14) ? 2 + : (x < (1ULL << 21) ? 3 + : (x < (1ULL << 28) ? 4 + : (x < (1ULL << 35) ? 
5 + : (x < (1ULL << 42) ? 6 + : (x < (1ULL << 49) ? 7 + : (x < (1ULL << 56) ? 8 + : (x < (1ULL << 63) ? 9 + : 10)))))))); +} + + +inline size_t getLengthOfVarInt(Int64 x) +{ + return getLengthOfVarUInt(static_cast<UInt64>((x << 1) ^ (x >> 63))); +} + +} diff --git a/contrib/clickhouse/src/IO/WithFileName.cpp b/contrib/clickhouse/src/IO/WithFileName.cpp new file mode 100644 index 0000000000..2383182f7e --- /dev/null +++ b/contrib/clickhouse/src/IO/WithFileName.cpp @@ -0,0 +1,39 @@ +#include <IO/WithFileName.h> +#include <IO/CompressedReadBufferWrapper.h> +#include <IO/ParallelReadBuffer.h> +#include <IO/PeekableReadBuffer.h> + +namespace DB +{ + +template <typename T> +static String getFileName(const T & entry) +{ + if (const auto * with_file_name = dynamic_cast<const WithFileName *>(&entry)) + return with_file_name->getFileName(); + return ""; +} + +String getFileNameFromReadBuffer(const ReadBuffer & in) +{ + if (const auto * compressed = dynamic_cast<const CompressedReadBufferWrapper *>(&in)) + return getFileName(compressed->getWrappedReadBuffer()); + else if (const auto * parallel = dynamic_cast<const ParallelReadBuffer *>(&in)) + return getFileName(parallel->getReadBuffer()); + else if (const auto * peekable = dynamic_cast<const PeekableReadBuffer *>(&in)) + return getFileNameFromReadBuffer(peekable->getSubBuffer()); + else + return getFileName(in); +} + +String getExceptionEntryWithFileName(const ReadBuffer & in) +{ + auto filename = getFileNameFromReadBuffer(in); + + if (filename.empty()) + return ""; + + return fmt::format(": While reading from: {}", filename); +} + +} diff --git a/contrib/clickhouse/src/IO/WithFileName.h b/contrib/clickhouse/src/IO/WithFileName.h new file mode 100644 index 0000000000..595f1a768c --- /dev/null +++ b/contrib/clickhouse/src/IO/WithFileName.h @@ -0,0 +1,19 @@ +#pragma once +#include <base/types.h> + +namespace DB +{ + +class ReadBuffer; + +class WithFileName +{ +public: + virtual String getFileName() const = 0; + virtual ~WithFileName() = default; +}; + +String getFileNameFromReadBuffer(const ReadBuffer & in); +String getExceptionEntryWithFileName(const ReadBuffer & in); + +} diff --git a/contrib/clickhouse/src/IO/WithFileSize.cpp b/contrib/clickhouse/src/IO/WithFileSize.cpp new file mode 100644 index 0000000000..3660d962c0 --- /dev/null +++ b/contrib/clickhouse/src/IO/WithFileSize.cpp @@ -0,0 +1,86 @@ +#include "WithFileSize.h" +#include <IO/ReadBufferFromFile.h> +#include <IO/CompressedReadBufferWrapper.h> +#include <IO/ParallelReadBuffer.h> +#include <IO/ReadBufferFromFileDecorator.h> +#include <IO/PeekableReadBuffer.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int UNKNOWN_FILE_SIZE; +} + +template <typename T> +static size_t getFileSize(T & in) +{ + if (auto * with_file_size = dynamic_cast<WithFileSize *>(&in)) + { + return with_file_size->getFileSize(); + } + + throw Exception(ErrorCodes::UNKNOWN_FILE_SIZE, "Cannot find out file size"); +} + +size_t getFileSizeFromReadBuffer(ReadBuffer & in) +{ + if (auto * delegate = dynamic_cast<ReadBufferFromFileDecorator *>(&in)) + { + return getFileSize(delegate->getWrappedReadBuffer()); + } + else if (auto * compressed = dynamic_cast<CompressedReadBufferWrapper *>(&in)) + { + return getFileSize(compressed->getWrappedReadBuffer()); + } + + return getFileSize(in); +} + +std::optional<size_t> tryGetFileSizeFromReadBuffer(ReadBuffer & in) +{ + try + { + return getFileSizeFromReadBuffer(in); + } + catch (...) 
+ { + return std::nullopt; + } +} + +bool isBufferWithFileSize(const ReadBuffer & in) +{ + if (const auto * delegate = dynamic_cast<const ReadBufferFromFileDecorator *>(&in)) + { + return delegate->isWithFileSize(); + } + else if (const auto * compressed = dynamic_cast<const CompressedReadBufferWrapper *>(&in)) + { + return isBufferWithFileSize(compressed->getWrappedReadBuffer()); + } + + return dynamic_cast<const WithFileSize *>(&in) != nullptr; +} + +size_t getDataOffsetMaybeCompressed(const ReadBuffer & in) +{ + if (const auto * delegate = dynamic_cast<const ReadBufferFromFileDecorator *>(&in)) + { + return getDataOffsetMaybeCompressed(delegate->getWrappedReadBuffer()); + } + else if (const auto * compressed = dynamic_cast<const CompressedReadBufferWrapper *>(&in)) + { + return getDataOffsetMaybeCompressed(compressed->getWrappedReadBuffer()); + } + else if (const auto * peekable = dynamic_cast<const PeekableReadBuffer *>(&in)) + { + return getDataOffsetMaybeCompressed(peekable->getSubBuffer()); + } + + return in.count(); +} + + +} diff --git a/contrib/clickhouse/src/IO/WithFileSize.h b/contrib/clickhouse/src/IO/WithFileSize.h new file mode 100644 index 0000000000..0ae3af98ea --- /dev/null +++ b/contrib/clickhouse/src/IO/WithFileSize.h @@ -0,0 +1,26 @@ +#pragma once +#include <base/types.h> +#include <optional> + +namespace DB +{ + +class ReadBuffer; + +class WithFileSize +{ +public: + virtual size_t getFileSize() = 0; + virtual ~WithFileSize() = default; +}; + +bool isBufferWithFileSize(const ReadBuffer & in); + +size_t getFileSizeFromReadBuffer(ReadBuffer & in); + +/// Return nullopt if couldn't find out file size; +std::optional<size_t> tryGetFileSizeFromReadBuffer(ReadBuffer & in); + +size_t getDataOffsetMaybeCompressed(const ReadBuffer & in); + +} diff --git a/contrib/clickhouse/src/IO/WriteBuffer.cpp b/contrib/clickhouse/src/IO/WriteBuffer.cpp new file mode 100644 index 0000000000..61fdd31e16 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBuffer.cpp @@ -0,0 +1,33 @@ +#include "WriteBuffer.h" + +#include <Common/logger_useful.h> + +namespace DB +{ + +/// Calling finalize() in the destructor of derived classes is a bad practice. +/// This causes objects to be left on the remote FS when a write operation is rolled back. +/// Do call finalize() explicitly, before this call you have no guarantee that the file has been written +WriteBuffer::~WriteBuffer() +{ + // That destructor could be call with finalized=false in case of exceptions + if (count() > 0 && !finalized) + { + /// It is totally OK to destroy instance without finalization when an exception occurs + /// However it is suspicious to destroy instance without finalization at the green path + if (!std::uncaught_exceptions() && std::current_exception() == nullptr) + { + Poco::Logger * log = &Poco::Logger::get("WriteBuffer"); + LOG_ERROR( + log, + "WriteBuffer is not finalized when destructor is called. " + "No exceptions in flight are detected. " + "The file might not be written at all or might be truncated. 
" + "Stack trace: {}", + StackTrace().toString()); + chassert(false && "WriteBuffer is not finalized in destructor."); + } + } +} + +} diff --git a/contrib/clickhouse/src/IO/WriteBuffer.h b/contrib/clickhouse/src/IO/WriteBuffer.h new file mode 100644 index 0000000000..d29ca6d5c6 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBuffer.h @@ -0,0 +1,181 @@ +#pragma once + +#include <algorithm> +#include <memory> +#include <cassert> +#include <cstring> + +#include <Common/Exception.h> +#include <Common/LockMemoryExceptionInThread.h> +#include <IO/BufferBase.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_WRITE_AFTER_END_OF_BUFFER; + extern const int LOGICAL_ERROR; +} + + +/** A simple abstract class for buffered data writing (char sequences) somewhere. + * Unlike std::ostream, it provides access to the internal buffer, + * and also allows you to manually manage the position inside the buffer. + * + * Derived classes must implement the nextImpl() method. + */ +class WriteBuffer : public BufferBase +{ +public: + using BufferBase::set; + using BufferBase::position; + void set(Position ptr, size_t size) { BufferBase::set(ptr, size, 0); } + + /** write the data in the buffer (from the beginning of the buffer to the current position); + * set the position to the beginning; throw an exception, if something is wrong + */ + inline void next() + { + if (!offset()) + return; + + auto bytes_in_buffer = offset(); + + try + { + nextImpl(); + } + catch (...) + { + /** If the nextImpl() call was unsuccessful, move the cursor to the beginning, + * so that later (for example, when the stack was expanded) there was no second attempt to write data. + */ + pos = working_buffer.begin(); + bytes += bytes_in_buffer; + throw; + } + + bytes += bytes_in_buffer; + pos = working_buffer.begin(); + } + + /// Calling finalize() in the destructor of derived classes is a bad practice. + virtual ~WriteBuffer(); + + inline void nextIfAtEnd() + { + if (!hasPendingData()) + next(); + } + + + void write(const char * from, size_t n) + { + if (finalized) + throw Exception{ErrorCodes::LOGICAL_ERROR, "Cannot write to finalized buffer"}; + + size_t bytes_copied = 0; + + /// Produces endless loop + assert(!working_buffer.empty()); + + while (bytes_copied < n) + { + nextIfAtEnd(); + size_t bytes_to_copy = std::min(static_cast<size_t>(working_buffer.end() - pos), n - bytes_copied); + memcpy(pos, from + bytes_copied, bytes_to_copy); + pos += bytes_to_copy; + bytes_copied += bytes_to_copy; + } + } + + inline void write(char x) + { + if (finalized) + throw Exception{ErrorCodes::LOGICAL_ERROR, "Cannot write to finalized buffer"}; + + nextIfAtEnd(); + *pos = x; + ++pos; + } + + /// This method may be called before finalize() to tell there would not be any more data written. + /// Used does not have to call it, implementation should check it itself if needed. + /// + /// The idea is similar to prefetch. In case if all data is written, we can flush the buffer + /// and start sending data asynchronously. It may improve writing performance in case you have + /// multiple files to finalize. Mainly, for blob storage, finalization has high latency, + /// and calling preFinalize in a loop may parallelize it. + virtual void preFinalize() { next(); } + + /// Write the last data. + void finalize() + { + if (finalized) + return; + + LockMemoryExceptionInThread lock(VariableContext::Global); + try + { + finalizeImpl(); + finalized = true; + } + catch (...) 
+ { + pos = working_buffer.begin(); + finalized = true; + throw; + } + } + + /// Wait for data to be reliably written. Mainly, call fsync for fd. + /// May be called after finalize() if needed. + virtual void sync() + { + next(); + } + +protected: + WriteBuffer(Position ptr, size_t size) : BufferBase(ptr, size, 0) {} + + virtual void finalizeImpl() + { + next(); + } + + bool finalized = false; + +private: + /** Write the data in the buffer (from the beginning of the buffer to the current position). + * Throw an exception if something is wrong. + */ + virtual void nextImpl() + { + throw Exception(ErrorCodes::CANNOT_WRITE_AFTER_END_OF_BUFFER, "Cannot write after end of buffer."); + } +}; + + +using WriteBufferPtr = std::shared_ptr<WriteBuffer>; + + +class WriteBufferFromPointer : public WriteBuffer +{ +public: + WriteBufferFromPointer(Position ptr, size_t size) : WriteBuffer(ptr, size) {} + +private: + virtual void finalizeImpl() override + { + /// no op + } + + virtual void sync() override + { + /// no on + } +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferDecorator.h b/contrib/clickhouse/src/IO/WriteBufferDecorator.h new file mode 100644 index 0000000000..7c984eeea8 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferDecorator.h @@ -0,0 +1,55 @@ +#pragma once + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> +#include <utility> +#include <memory> + +namespace DB +{ + +class WriteBuffer; + +/// WriteBuffer that decorates data and delegates it to underlying buffer. +/// It's used for writing compressed and encrypted data +template <class Base> +class WriteBufferDecorator : public Base +{ +public: + template <class ... BaseArgs> + explicit WriteBufferDecorator(std::unique_ptr<WriteBuffer> out_, BaseArgs && ... args) + : Base(std::forward<BaseArgs>(args)...), out(std::move(out_)) + { + } + + void finalizeImpl() override + { + try + { + finalizeBefore(); + out->finalize(); + finalizeAfter(); + } + catch (...) + { + /// Do not try to flush next time after exception. + out->position() = out->buffer().begin(); + throw; + } + } + + WriteBuffer * getNestedBuffer() { return out.get(); } + +protected: + /// Do some finalization before finalization of underlying buffer. + virtual void finalizeBefore() {} + + /// Do some finalization after finalization of underlying buffer. + virtual void finalizeAfter() {} + + std::unique_ptr<WriteBuffer> out; +}; + +using WriteBufferWithOwnMemoryDecorator = WriteBufferDecorator<BufferWithOwnMemory<WriteBuffer>>; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromArena.h b/contrib/clickhouse/src/IO/WriteBufferFromArena.h new file mode 100644 index 0000000000..8e9276496b --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromArena.h @@ -0,0 +1,73 @@ +#pragma once + +#include <Common/Arena.h> +#include <base/StringRef.h> +#include <IO/WriteBuffer.h> + + +namespace DB +{ + +/** Writes data contiguously into Arena. + * As it will be located in contiguous memory segment, it can be read back with ReadBufferFromMemory. + * + * While using this object, no other allocations in arena are possible. + */ +class WriteBufferFromArena final : public WriteBuffer +{ +public: + /// begin_ - start of previously used contiguous memory segment or nullptr (see Arena::allocContinue method). + WriteBufferFromArena(Arena & arena_, const char *& begin_) + : WriteBuffer(nullptr, 0), arena(arena_), begin(begin_) + { + nextImpl(); + pos = working_buffer.begin(); + } + + StringRef complete() + { + /// Return over-allocated memory back into arena. 
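/// Editor's note (worked example, not part of this patch): suppose, roughly, that 10 bytes were
/// written into a chunk where allocContinue() handed out 64 bytes. Then rollback(buffer().end() -
/// position()) returns the unused 54 bytes to the arena, and the StringRef below,
/// { position() - count(), count() }, covers exactly the 10 written bytes, laid out contiguously
/// starting at the (possibly relocated) `begin` pointer.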
+ arena.rollback(buffer().end() - position()); + /// Reference to written data. + return { position() - count(), count() }; + } + +private: + Arena & arena; + const char *& begin; + + void nextImpl() override + { + /// Allocate more memory. At least same size as used before (this gives 2x growth ratio), + /// and at most grab all remaining size in current chunk of arena. + /// + /// FIXME this class just doesn't make sense -- WriteBuffer is not + /// a unified interface for everything, it doesn't work well with + /// Arena::allocContinue -- we lose the size of data and then use a + /// heuristic to guess it back? and make a virtual call while we're at it? + /// I don't even.. + /// Being so ill-defined as it is, no wonder that the following line had + /// a bug leading to a very rare infinite loop. Just hack around it in + /// the most stupid way possible, because the real fix for this is to + /// tear down the entire WriteBuffer thing and implement it again, + /// properly. + size_t continuation_size = std::max(size_t(1), + std::max(count(), arena.remainingSpaceInCurrentMemoryChunk())); + + /// allocContinue method will possibly move memory region to new place and modify "begin" pointer. + + char * continuation = arena.allocContinue(continuation_size, begin); + char * end = continuation + continuation_size; + + /// internal buffer points to whole memory segment and working buffer - to free space for writing. + internalBuffer() = Buffer(const_cast<char *>(begin), end); + buffer() = Buffer(continuation, end); + } + + /// it is super strange, + /// but addition next call changes the data in serializeValueIntoArena result + virtual void finalizeImpl() override { /* no op */ } +}; + +} + diff --git a/contrib/clickhouse/src/IO/WriteBufferFromEncryptedFile.cpp b/contrib/clickhouse/src/IO/WriteBufferFromEncryptedFile.cpp new file mode 100644 index 0000000000..5bca0dc68d --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromEncryptedFile.cpp @@ -0,0 +1,61 @@ +#include <IO/WriteBufferFromEncryptedFile.h> + +#if USE_SSL + +namespace DB +{ + +WriteBufferFromEncryptedFile::WriteBufferFromEncryptedFile( + size_t buffer_size_, + std::unique_ptr<WriteBufferFromFileBase> out_, + const String & key_, + const FileEncryption::Header & header_, + size_t old_file_size) + : WriteBufferDecorator<WriteBufferFromFileBase>(std::move(out_), buffer_size_, nullptr, 0) + , header(header_) + , flush_header(!old_file_size) + , encryptor(header.algorithm, key_, header.init_vector) +{ + encryptor.setOffset(old_file_size); +} + +WriteBufferFromEncryptedFile::~WriteBufferFromEncryptedFile() +{ + finalize(); +} + +void WriteBufferFromEncryptedFile::finalizeBefore() +{ + /// If buffer has pending data - write it. + next(); + + /// Note that if there is no data to write an empty file will be written, even without the initialization vector + /// (see nextImpl(): it writes the initialization vector only if there is some data ready to write). + /// That's fine because DiskEncrypted allows files without initialization vectors when they're empty. +} + +void WriteBufferFromEncryptedFile::sync() +{ + /// If buffer has pending data - write it. 
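/// Editor's note (clarifying comment, not part of this patch): flushing our own buffer first matters
/// here. nextImpl() is what encrypts the pending plaintext (and writes the encryption header on the
/// first non-empty flush), so next() must run before out->sync(); otherwise the underlying file
/// would be synced without the bytes still sitting in this buffer.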
+ next(); + + out->sync(); +} + +void WriteBufferFromEncryptedFile::nextImpl() +{ + if (!offset()) + return; + + if (flush_header) + { + header.write(*out); + flush_header = false; + } + + encryptor.encrypt(working_buffer.begin(), offset(), *out); +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/WriteBufferFromEncryptedFile.h b/contrib/clickhouse/src/IO/WriteBufferFromEncryptedFile.h new file mode 100644 index 0000000000..12c1ba5f6f --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromEncryptedFile.h @@ -0,0 +1,46 @@ +#pragma once + +#include "clickhouse_config.h" +#include <Common/assert_cast.h> + +#if USE_SSL +#include <IO/WriteBufferFromFileBase.h> +#include <IO/FileEncryptionCommon.h> +#include <IO/WriteBufferDecorator.h> + + +namespace DB +{ + +/// Encrypts data and writes the encrypted data to the underlying write buffer. +class WriteBufferFromEncryptedFile : public WriteBufferDecorator<WriteBufferFromFileBase> +{ +public: + /// `old_file_size` should be set to non-zero if we're going to append an existing file. + WriteBufferFromEncryptedFile( + size_t buffer_size_, + std::unique_ptr<WriteBufferFromFileBase> out_, + const String & key_, + const FileEncryption::Header & header_, + size_t old_file_size = 0); + + ~WriteBufferFromEncryptedFile() override; + + void sync() override; + + std::string getFileName() const override { return assert_cast<WriteBufferFromFileBase *>(out.get())->getFileName(); } + +private: + void nextImpl() override; + + void finalizeBefore() override; + + FileEncryption::Header header; + bool flush_header = false; + + FileEncryption::Encryptor encryptor; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFile.cpp b/contrib/clickhouse/src/IO/WriteBufferFromFile.cpp new file mode 100644 index 0000000000..97059ff8f4 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFile.cpp @@ -0,0 +1,115 @@ +#include <sys/stat.h> +#include <fcntl.h> +#include <cerrno> + +#include <Common/ProfileEvents.h> +#include <base/defines.h> + +#include <IO/WriteBufferFromFile.h> +#include <IO/WriteHelpers.h> + + +namespace ProfileEvents +{ + extern const Event FileOpen; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int FILE_DOESNT_EXIST; + extern const int CANNOT_OPEN_FILE; + extern const int CANNOT_CLOSE_FILE; +} + + +WriteBufferFromFile::WriteBufferFromFile( + const std::string & file_name_, + size_t buf_size, + int flags, + ThrottlerPtr throttler_, + mode_t mode, + char * existing_memory, + size_t alignment) + : WriteBufferFromFileDescriptor(-1, buf_size, existing_memory, throttler_, alignment, file_name_) +{ + ProfileEvents::increment(ProfileEvents::FileOpen); + +#ifdef OS_DARWIN + bool o_direct = (flags != -1) && (flags & O_DIRECT); + if (o_direct) + flags = flags & ~O_DIRECT; +#endif + + fd = ::open(file_name.c_str(), flags == -1 ? O_WRONLY | O_TRUNC | O_CREAT | O_CLOEXEC : flags | O_CLOEXEC, mode); + + if (-1 == fd) + throwFromErrnoWithPath("Cannot open file " + file_name, file_name, + errno == ENOENT ? ErrorCodes::FILE_DOESNT_EXIST : ErrorCodes::CANNOT_OPEN_FILE); + +#ifdef OS_DARWIN + if (o_direct) + { + if (fcntl(fd, F_NOCACHE, 1) == -1) + throwFromErrnoWithPath("Cannot set F_NOCACHE on file " + file_name, file_name, ErrorCodes::CANNOT_OPEN_FILE); + } +#endif +} + + +/// Use pre-opened file descriptor. 
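/// Editor's note (hypothetical usage sketch, not part of this patch): this constructor takes the
/// descriptor by reference and assumes ownership, resetting the caller's variable to -1 so the fd
/// cannot be closed twice. Something like:
///
///     int fd = ::open("/tmp/example.bin", O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC, 0666);  /// hypothetical path
///     DB::WriteBufferFromFile out(fd, "/tmp/example.bin");  /// after this line fd == -1
///     writeString("hello", out);                            /// writeString() from IO/WriteHelpers.h
///     out.finalize();                                       /// flush explicitly before destruction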
+WriteBufferFromFile::WriteBufferFromFile( + int & fd_, + const std::string & original_file_name, + size_t buf_size, + ThrottlerPtr throttler_, + char * existing_memory, + size_t alignment) + : WriteBufferFromFileDescriptor(fd_, buf_size, existing_memory, throttler_, alignment, original_file_name) +{ + fd_ = -1; +} + +WriteBufferFromFile::~WriteBufferFromFile() +{ + if (fd < 0) + return; + + finalize(); + int err = ::close(fd); + /// Everything except for EBADF should be ignored in dtor, since all of + /// others (EINTR/EIO/ENOSPC/EDQUOT) could be possible during writing to + /// fd, and then write already failed and the error had been reported to + /// the user/caller. + /// + /// Note, that for close() on Linux, EINTR should *not* be retried. + chassert(!(err && errno == EBADF)); +} + +void WriteBufferFromFile::finalizeImpl() +{ + if (fd < 0) + return; + + next(); +} + + +/// Close file before destruction of object. +void WriteBufferFromFile::close() +{ + if (fd < 0) + return; + + finalize(); + + if (0 != ::close(fd)) + throw Exception(ErrorCodes::CANNOT_CLOSE_FILE, "Cannot close file"); + + fd = -1; + metric_increment.destroy(); +} + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFile.h b/contrib/clickhouse/src/IO/WriteBufferFromFile.h new file mode 100644 index 0000000000..57847d893a --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFile.h @@ -0,0 +1,64 @@ +#pragma once + +#include <sys/types.h> + +#include <Common/CurrentMetrics.h> +#include <Common/Throttler_fwd.h> +#include <IO/WriteBufferFromFileDescriptor.h> + + +namespace CurrentMetrics +{ + extern const Metric OpenFileForWrite; +} + + +#ifndef O_DIRECT +#define O_DIRECT 00040000 +#endif + +namespace DB +{ + +/** Accepts path to file and opens it, or pre-opened file descriptor. + * Closes file by himself (thus "owns" a file descriptor). + */ +class WriteBufferFromFile : public WriteBufferFromFileDescriptor +{ +protected: + CurrentMetrics::Increment metric_increment{CurrentMetrics::OpenFileForWrite}; + +public: + explicit WriteBufferFromFile( + const std::string & file_name_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + int flags = -1, + ThrottlerPtr throttler_ = {}, + mode_t mode = 0666, + char * existing_memory = nullptr, + size_t alignment = 0); + + /// Use pre-opened file descriptor. + explicit WriteBufferFromFile( + int & fd, /// Will be set to -1 if constructor didn't throw and ownership of file descriptor is passed to the object. + const std::string & original_file_name = {}, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + ThrottlerPtr throttler_ = {}, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~WriteBufferFromFile() override; + + /// Close file before destruction of object. 
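/// Editor's note (clarifying comment, not part of this patch): close() finalizes first and then
/// closes the descriptor, throwing on failure, so it is the call to use when write errors must be
/// observed; the destructor also finalizes, but it deliberately ignores close() errors other than
/// asserting on EBADF, because throwing from a destructor is not an option.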
+ void close(); + + std::string getFileName() const override + { + return file_name; + } + +private: + void finalizeImpl() override; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFileBase.cpp b/contrib/clickhouse/src/IO/WriteBufferFromFileBase.cpp new file mode 100644 index 0000000000..2b9cbb88cd --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFileBase.cpp @@ -0,0 +1,11 @@ +#include <IO/WriteBufferFromFileBase.h> + +namespace DB +{ + +WriteBufferFromFileBase::WriteBufferFromFileBase(size_t buf_size, char * existing_memory, size_t alignment) + : BufferWithOwnMemory<WriteBuffer>(buf_size, existing_memory, alignment) +{ +} + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFileBase.h b/contrib/clickhouse/src/IO/WriteBufferFromFileBase.h new file mode 100644 index 0000000000..d6e2144bcc --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFileBase.h @@ -0,0 +1,21 @@ +#pragma once + +#include <string> +#include <fcntl.h> + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> + +namespace DB +{ + +class WriteBufferFromFileBase : public BufferWithOwnMemory<WriteBuffer> +{ +public: + WriteBufferFromFileBase(size_t buf_size, char * existing_memory, size_t alignment); + + void sync() override = 0; + virtual std::string getFileName() const = 0; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFileDecorator.cpp b/contrib/clickhouse/src/IO/WriteBufferFromFileDecorator.cpp new file mode 100644 index 0000000000..0e4e5e13a8 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFileDecorator.cpp @@ -0,0 +1,74 @@ +#include "WriteBufferFromFileDecorator.h" + +#include <IO/WriteBuffer.h> +#include <IO/SwapHelper.h> + +namespace DB +{ + +WriteBufferFromFileDecorator::WriteBufferFromFileDecorator(std::unique_ptr<WriteBuffer> impl_) + : WriteBufferFromFileBase(0, nullptr, 0), impl(std::move(impl_)) +{ + swap(*impl); +} + +void WriteBufferFromFileDecorator::finalizeImpl() +{ + + /// In case of exception in preFinalize as a part of finalize call + /// WriteBufferFromFileDecorator.finalized is set as true + /// but impl->finalized is remain false + /// That leads to situation when the destructor of impl is called with impl->finalized equal false. 
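/// Editor's note (clarifying comment, not part of this patch): the order below is: make sure
/// preFinalize() has run (it pushes any pending data via next()), then swap buffers with impl so
/// that impl->finalize() flushes through its own working buffer. SwapHelper swaps the two buffers
/// for the scope of the block and restores them on exit, which is why the same pattern appears in
/// sync(), preFinalize() and nextImpl().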
+ if (!is_prefinalized) + WriteBufferFromFileDecorator::preFinalize(); + + { + SwapHelper swap(*this, *impl); + impl->finalize(); + } +} + +WriteBufferFromFileDecorator::~WriteBufferFromFileDecorator() +{ + /// It is not a mistake that swap is called here + /// Swap has been called at constructor, it should be called at destructor + /// In oreder to provide valid buffer for impl's d-tor call + swap(*impl); +} + +void WriteBufferFromFileDecorator::sync() +{ + next(); + + { + SwapHelper swap(*this, *impl); + impl->sync(); + } +} + +std::string WriteBufferFromFileDecorator::getFileName() const +{ + if (WriteBufferFromFileBase * buffer = dynamic_cast<WriteBufferFromFileBase*>(impl.get())) + return buffer->getFileName(); + return std::string(); +} + +void WriteBufferFromFileDecorator::preFinalize() +{ + next(); + + { + SwapHelper swap(*this, *impl); + impl->preFinalize(); + } + + is_prefinalized = true; +} + +void WriteBufferFromFileDecorator::nextImpl() +{ + SwapHelper swap(*this, *impl); + impl->next(); +} + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFileDecorator.h b/contrib/clickhouse/src/IO/WriteBufferFromFileDecorator.h new file mode 100644 index 0000000000..5344bb1425 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFileDecorator.h @@ -0,0 +1,35 @@ +#pragma once + +#include <IO/WriteBufferFromFileBase.h> + +namespace DB +{ + +/// Delegates all writes to underlying buffer. Doesn't have own memory. +class WriteBufferFromFileDecorator : public WriteBufferFromFileBase +{ +public: + explicit WriteBufferFromFileDecorator(std::unique_ptr<WriteBuffer> impl_); + + ~WriteBufferFromFileDecorator() override; + + void sync() override; + + std::string getFileName() const override; + + void preFinalize() override; + + const WriteBuffer & getImpl() const { return *impl; } + +protected: + void finalizeImpl() override; + + std::unique_ptr<WriteBuffer> impl; + +private: + void nextImpl() override; + + bool is_prefinalized = false; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptor.cpp b/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptor.cpp new file mode 100644 index 0000000000..135ff60896 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptor.cpp @@ -0,0 +1,179 @@ +#include <unistd.h> +#include <cerrno> +#include <cassert> +#include <sys/stat.h> + +#include <Common/Throttler.h> +#include <Common/Exception.h> +#include <Common/ProfileEvents.h> +#include <Common/CurrentMetrics.h> +#include <Common/Stopwatch.h> + +#include <IO/WriteBufferFromFileDescriptor.h> +#include <IO/WriteHelpers.h> + + +namespace ProfileEvents +{ + extern const Event WriteBufferFromFileDescriptorWrite; + extern const Event WriteBufferFromFileDescriptorWriteFailed; + extern const Event WriteBufferFromFileDescriptorWriteBytes; + extern const Event DiskWriteElapsedMicroseconds; + extern const Event FileSync; + extern const Event FileSyncElapsedMicroseconds; + extern const Event LocalWriteThrottlerBytes; + extern const Event LocalWriteThrottlerSleepMicroseconds; +} + +namespace CurrentMetrics +{ + extern const Metric Write; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_WRITE_TO_FILE_DESCRIPTOR; + extern const int CANNOT_FSYNC; + extern const int CANNOT_SEEK_THROUGH_FILE; + extern const int CANNOT_TRUNCATE_FILE; + extern const int CANNOT_FSTAT; +} + + +void WriteBufferFromFileDescriptor::nextImpl() +{ + if (!offset()) + return; + + Stopwatch watch; + + size_t bytes_written = 0; + while (bytes_written != offset()) + { + 
ProfileEvents::increment(ProfileEvents::WriteBufferFromFileDescriptorWrite); + + ssize_t res = 0; + { + CurrentMetrics::Increment metric_increment{CurrentMetrics::Write}; + res = ::write(fd, working_buffer.begin() + bytes_written, offset() - bytes_written); + } + + if ((-1 == res || 0 == res) && errno != EINTR) + { + ProfileEvents::increment(ProfileEvents::WriteBufferFromFileDescriptorWriteFailed); + + /// Don't use getFileName() here because this method can be called from destructor + String error_file_name = file_name; + if (error_file_name.empty()) + error_file_name = "(fd = " + toString(fd) + ")"; + throwFromErrnoWithPath("Cannot write to file " + error_file_name, error_file_name, + ErrorCodes::CANNOT_WRITE_TO_FILE_DESCRIPTOR); + } + + if (res > 0) + { + bytes_written += res; + if (throttler) + throttler->add(res, ProfileEvents::LocalWriteThrottlerBytes, ProfileEvents::LocalWriteThrottlerSleepMicroseconds); + } + } + + ProfileEvents::increment(ProfileEvents::DiskWriteElapsedMicroseconds, watch.elapsedMicroseconds()); + ProfileEvents::increment(ProfileEvents::WriteBufferFromFileDescriptorWriteBytes, bytes_written); +} + +/// NOTE: This class can be used as a very low-level building block, for example +/// in trace collector. In such places allocations of memory can be dangerous, +/// so don't allocate anything in this constructor. +WriteBufferFromFileDescriptor::WriteBufferFromFileDescriptor( + int fd_, + size_t buf_size, + char * existing_memory, + ThrottlerPtr throttler_, + size_t alignment, + std::string file_name_) + : WriteBufferFromFileBase(buf_size, existing_memory, alignment) + , fd(fd_) + , throttler(throttler_) + , file_name(std::move(file_name_)) +{ +} + + +WriteBufferFromFileDescriptor::~WriteBufferFromFileDescriptor() +{ + finalize(); +} + +void WriteBufferFromFileDescriptor::finalizeImpl() +{ + if (fd < 0) + { + assert(!offset() && "attempt to write after close"); + return; + } + + next(); +} + +void WriteBufferFromFileDescriptor::sync() +{ + /// If buffer has pending data - write it. + next(); + + ProfileEvents::increment(ProfileEvents::FileSync); + + Stopwatch watch; + + /// Request OS to sync data with storage medium. 
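/// Editor's note (clarifying comment, not part of this patch): fdatasync() is used on Linux because
/// it flushes the file data plus only the metadata needed to read it back, skipping pure attribute
/// updates such as timestamps; on Darwin the code falls back to fsync(), presumably because
/// fdatasync() is not reliably available there.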
+#if defined(OS_DARWIN) + int res = ::fsync(fd); +#else + int res = ::fdatasync(fd); +#endif + ProfileEvents::increment(ProfileEvents::FileSyncElapsedMicroseconds, watch.elapsedMicroseconds()); + + if (-1 == res) + throwFromErrnoWithPath("Cannot fsync " + getFileName(), getFileName(), ErrorCodes::CANNOT_FSYNC); +} + + +off_t WriteBufferFromFileDescriptor::seek(off_t offset, int whence) // NOLINT +{ + off_t res = lseek(fd, offset, whence); + if (-1 == res) + throwFromErrnoWithPath("Cannot seek through file " + getFileName(), getFileName(), + ErrorCodes::CANNOT_SEEK_THROUGH_FILE); + return res; +} + +void WriteBufferFromFileDescriptor::truncate(off_t length) // NOLINT +{ + int res = ftruncate(fd, length); + if (-1 == res) + throwFromErrnoWithPath("Cannot truncate file " + getFileName(), getFileName(), ErrorCodes::CANNOT_TRUNCATE_FILE); +} + + +off_t WriteBufferFromFileDescriptor::size() const +{ + struct stat buf; + int res = fstat(fd, &buf); + if (-1 == res) + throwFromErrnoWithPath("Cannot execute fstat " + getFileName(), getFileName(), ErrorCodes::CANNOT_FSTAT); + return buf.st_size; +} + +std::string WriteBufferFromFileDescriptor::getFileName() const +{ + if (file_name.empty()) + return "(fd = " + toString(fd) + ")"; + + return file_name; +} + + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptor.h b/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptor.h new file mode 100644 index 0000000000..cb73b1e1d0 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptor.h @@ -0,0 +1,62 @@ +#pragma once + +#include <IO/WriteBufferFromFileBase.h> +#include <Common/Throttler_fwd.h> + + +namespace DB +{ + +/** Use ready file descriptor. Does not open or close a file. + */ +class WriteBufferFromFileDescriptor : public WriteBufferFromFileBase +{ +public: + explicit WriteBufferFromFileDescriptor( + int fd_ = -1, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + ThrottlerPtr throttler_ = {}, + size_t alignment = 0, + std::string file_name_ = ""); + + /** Could be used before initialization if needed 'fd' was not passed to constructor. + * It's not possible to change 'fd' during work. + */ + void setFD(int fd_) + { + fd = fd_; + } + + ~WriteBufferFromFileDescriptor() override; + + int getFD() const + { + return fd; + } + + void sync() override; + + /// clang-tidy wants these methods to be const, but + /// they are not const semantically + off_t seek(off_t offset, int whence); // NOLINT + void truncate(off_t length); // NOLINT + + /// Name or some description of file. 
+ std::string getFileName() const override; + + off_t size() const; + +protected: + void nextImpl() override; + + int fd; + ThrottlerPtr throttler; + + /// If file has name contains filename, otherwise contains string "(fd=...)" + std::string file_name; + + void finalizeImpl() override; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptorDiscardOnFailure.cpp b/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptorDiscardOnFailure.cpp new file mode 100644 index 0000000000..69be24f0fa --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptorDiscardOnFailure.cpp @@ -0,0 +1,32 @@ +#include <IO/WriteBufferFromFileDescriptorDiscardOnFailure.h> + +namespace ProfileEvents +{ + extern const Event CannotWriteToWriteBufferDiscard; +} + +namespace DB +{ + +void WriteBufferFromFileDescriptorDiscardOnFailure::nextImpl() +{ + size_t bytes_written = 0; + while (bytes_written != offset()) + { + ssize_t res = ::write(fd, working_buffer.begin() + bytes_written, offset() - bytes_written); + + if ((-1 == res || 0 == res) && errno != EINTR) + { + /// Never send this profile event to trace log because it may cause another + /// write into the same fd and likely will trigger the same error + /// and will lead to infinite recursion. + ProfileEvents::incrementNoTrace(ProfileEvents::CannotWriteToWriteBufferDiscard); + break; /// Discard + } + + if (res > 0) + bytes_written += res; + } +} + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptorDiscardOnFailure.h b/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptorDiscardOnFailure.h new file mode 100644 index 0000000000..2803dd4e8b --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromFileDescriptorDiscardOnFailure.h @@ -0,0 +1,23 @@ +#pragma once + +#include <IO/WriteBufferFromFileDescriptor.h> + + +namespace DB +{ + +/** Write to file descriptor but drop the data if write would block or fail. + * To use within signal handler. Motivating example: a signal handler invoked during execution of malloc + * should not block because some mutex (or even worse - a spinlock) may be held. 
+ */ +class WriteBufferFromFileDescriptorDiscardOnFailure : public WriteBufferFromFileDescriptor +{ +protected: + void nextImpl() override; + +public: + using WriteBufferFromFileDescriptor::WriteBufferFromFileDescriptor; + ~WriteBufferFromFileDescriptorDiscardOnFailure() override = default; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromHTTP.cpp b/contrib/clickhouse/src/IO/WriteBufferFromHTTP.cpp new file mode 100644 index 0000000000..056b965266 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromHTTP.cpp @@ -0,0 +1,50 @@ +#include <IO/WriteBufferFromHTTP.h> + +#include <Common/logger_useful.h> + + +namespace DB +{ + +WriteBufferFromHTTP::WriteBufferFromHTTP( + const Poco::URI & uri, + const std::string & method, + const std::string & content_type, + const std::string & content_encoding, + const HTTPHeaderEntries & additional_headers, + const ConnectionTimeouts & timeouts, + size_t buffer_size_, + Poco::Net::HTTPClientSession::ProxyConfig proxy_configuration) + : WriteBufferFromOStream(buffer_size_) + , session{makeHTTPSession(uri, timeouts, proxy_configuration)} + , request{method, uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1} +{ + request.setHost(uri.getHost()); + request.setChunkedTransferEncoding(true); + + if (!content_type.empty()) + { + request.set("Content-Type", content_type); + } + + if (!content_encoding.empty()) + request.set("Content-Encoding", content_encoding); + + for (const auto & header: additional_headers) + request.add(header.name, header.value); + + LOG_TRACE((&Poco::Logger::get("WriteBufferToHTTP")), "Sending request to {}", uri.toString()); + + ostr = &session->sendRequest(request); +} + +void WriteBufferFromHTTP::finalizeImpl() +{ + // Make sure the content in the buffer has been flushed + this->next(); + + receiveResponse(*session, request, response, false); + /// TODO: Response body is ignored. +} + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromHTTP.h b/contrib/clickhouse/src/IO/WriteBufferFromHTTP.h new file mode 100644 index 0000000000..65dc10213d --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromHTTP.h @@ -0,0 +1,40 @@ +#pragma once + +#include <IO/ConnectionTimeouts.h> +#include <IO/WriteBuffer.h> +#include <IO/WriteBufferFromOStream.h> +#include <IO/HTTPCommon.h> +#include <IO/HTTPHeaderEntries.h> +#include <Poco/Net/HTTPClientSession.h> +#include <Poco/Net/HTTPRequest.h> +#include <Poco/Net/HTTPResponse.h> +#include <Poco/URI.h> + + +namespace DB +{ + +/* Perform HTTP POST/PUT request. + */ +class WriteBufferFromHTTP : public WriteBufferFromOStream +{ +public: + explicit WriteBufferFromHTTP(const Poco::URI & uri, + const std::string & method = Poco::Net::HTTPRequest::HTTP_POST, // POST or PUT only + const std::string & content_type = "", + const std::string & content_encoding = "", + const HTTPHeaderEntries & additional_headers = {}, + const ConnectionTimeouts & timeouts = {}, + size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE, + Poco::Net::HTTPClientSession::ProxyConfig proxy_configuration = {}); + +private: + /// Receives response from the server after sending all data. 
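/// Editor's note (hypothetical usage sketch, not part of this patch): the buffer streams the request
/// body with chunked transfer encoding and only reads the response when finalize() runs, e.g.:
///
///     Poco::URI uri("http://example.com/upload");                        /// hypothetical endpoint
///     DB::WriteBufferFromHTTP out(uri, Poco::Net::HTTPRequest::HTTP_PUT);
///     writeString("payload", out);                                       /// writeString() from IO/WriteHelpers.h
///     out.finalize();                                                    /// flushes and receives the response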
+ void finalizeImpl() override; + + HTTPSessionPtr session; + Poco::Net::HTTPRequest request; + Poco::Net::HTTPResponse response; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromOStream.cpp b/contrib/clickhouse/src/IO/WriteBufferFromOStream.cpp new file mode 100644 index 0000000000..ffc3e62e9a --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromOStream.cpp @@ -0,0 +1,42 @@ +#include <IO/WriteBufferFromOStream.h> +#include <Common/logger_useful.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_WRITE_TO_OSTREAM; +} + +void WriteBufferFromOStream::nextImpl() +{ + if (!offset()) + return; + + ostr->write(working_buffer.begin(), offset()); + ostr->flush(); + + if (!ostr->good()) + throw Exception(ErrorCodes::CANNOT_WRITE_TO_OSTREAM, "Cannot write to ostream at offset {}", count()); +} + +WriteBufferFromOStream::WriteBufferFromOStream( + size_t size, + char * existing_memory, + size_t alignment) + : BufferWithOwnMemory<WriteBuffer>(size, existing_memory, alignment) +{ +} + +WriteBufferFromOStream::WriteBufferFromOStream( + std::ostream & ostr_, + size_t size, + char * existing_memory, + size_t alignment) + : BufferWithOwnMemory<WriteBuffer>(size, existing_memory, alignment), ostr(&ostr_) +{ +} + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromOStream.h b/contrib/clickhouse/src/IO/WriteBufferFromOStream.h new file mode 100644 index 0000000000..3f9d3ee3d9 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromOStream.h @@ -0,0 +1,29 @@ +#pragma once + +#include <iosfwd> + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> + + +namespace DB +{ + +class WriteBufferFromOStream : public BufferWithOwnMemory<WriteBuffer> +{ +public: + explicit WriteBufferFromOStream( + std::ostream & ostr_, + size_t size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + +protected: + explicit WriteBufferFromOStream(size_t size = DBMS_DEFAULT_BUFFER_SIZE, char * existing_memory = nullptr, size_t alignment = 0); + + void nextImpl() override; + + std::ostream * ostr{}; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromPocoSocket.cpp b/contrib/clickhouse/src/IO/WriteBufferFromPocoSocket.cpp new file mode 100644 index 0000000000..171e7f1ce6 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromPocoSocket.cpp @@ -0,0 +1,140 @@ +#include <Poco/Net/NetException.h> + +#include <base/scope_guard.h> + +#include <IO/WriteBufferFromPocoSocket.h> + +#include <Common/Exception.h> +#include <Common/NetException.h> +#include <Common/Stopwatch.h> +#include <Common/ProfileEvents.h> +#include <Common/CurrentMetrics.h> +#include <Common/AsyncTaskExecutor.h> +#include <Common/checkSSLReturnCode.h> + +namespace ProfileEvents +{ + extern const Event NetworkSendElapsedMicroseconds; + extern const Event NetworkSendBytes; +} + +namespace CurrentMetrics +{ + extern const Metric NetworkSend; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NETWORK_ERROR; + extern const int SOCKET_TIMEOUT; + extern const int CANNOT_WRITE_TO_SOCKET; + extern const int LOGICAL_ERROR; +} + +void WriteBufferFromPocoSocket::nextImpl() +{ + if (!offset()) + return; + + Stopwatch watch; + size_t bytes_written = 0; + + SCOPE_EXIT({ + ProfileEvents::increment(ProfileEvents::NetworkSendElapsedMicroseconds, watch.elapsedMicroseconds()); + ProfileEvents::increment(ProfileEvents::NetworkSendBytes, bytes_written); + }); + + while (bytes_written < offset()) + { + ssize_t res = 0; + + /// Add more details to exceptions. 
+ try + { + CurrentMetrics::Increment metric_increment(CurrentMetrics::NetworkSend); + char * pos = working_buffer.begin() + bytes_written; + size_t size = offset() - bytes_written; + if (size > INT_MAX) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Buffer overflow"); + + /// If async_callback is specified, set socket to non-blocking mode + /// and try to write data to it, if socket is not ready for writing, + /// run async_callback and try again later. + /// It is expected that file descriptor may be polled externally. + /// Note that send timeout is not checked here. External code should check it while polling. + if (async_callback) + { + socket.setBlocking(false); + /// Set socket to blocking mode at the end. + SCOPE_EXIT(socket.setBlocking(true)); + bool secure = socket.secure(); + res = socket.impl()->sendBytes(pos, static_cast<int>(size)); + + /// Check EAGAIN and ERR_SSL_WANT_WRITE/ERR_SSL_WANT_READ for secure socket (writing to secure socket can read too). + while (res < 0 && (errno == EAGAIN || (secure && (checkSSLWantRead(res) || checkSSLWantWrite(res))))) + { + /// In case of ERR_SSL_WANT_READ we should wait for socket to be ready for reading, otherwise - for writing. + if (secure && checkSSLWantRead(res)) + async_callback(socket.impl()->sockfd(), socket.getReceiveTimeout(), AsyncEventTimeoutType::RECEIVE, socket_description, AsyncTaskExecutor::Event::READ | AsyncTaskExecutor::Event::ERROR); + else + async_callback(socket.impl()->sockfd(), socket.getSendTimeout(), AsyncEventTimeoutType::SEND, socket_description, AsyncTaskExecutor::Event::WRITE | AsyncTaskExecutor::Event::ERROR); + + /// Try to write again. + res = socket.impl()->sendBytes(pos, static_cast<int>(size)); + } + } + else + { + res = socket.impl()->sendBytes(pos, static_cast<int>(size)); + } + } + catch (const Poco::Net::NetException & e) + { + throw NetException(ErrorCodes::NETWORK_ERROR, "{}, while writing to socket ({} -> {})", e.displayText(), + our_address.toString(), peer_address.toString()); + } + catch (const Poco::TimeoutException &) + { + throw NetException(ErrorCodes::SOCKET_TIMEOUT, "Timeout exceeded while writing to socket ({}, {} ms)", + peer_address.toString(), + socket.impl()->getSendTimeout().totalMilliseconds()); + } + catch (const Poco::IOException & e) + { + throw NetException(ErrorCodes::NETWORK_ERROR, "{}, while writing to socket ({} -> {})", e.displayText(), + our_address.toString(), peer_address.toString()); + } + + if (res < 0) + throw NetException(ErrorCodes::CANNOT_WRITE_TO_SOCKET, "Cannot write to socket ({} -> {})", + our_address.toString(), peer_address.toString()); + + bytes_written += res; + } +} + +WriteBufferFromPocoSocket::WriteBufferFromPocoSocket(Poco::Net::Socket & socket_, size_t buf_size) + : BufferWithOwnMemory<WriteBuffer>(buf_size) + , socket(socket_) + , peer_address(socket.peerAddress()) + , our_address(socket.address()) + , socket_description("socket (" + peer_address.toString() + ")") +{ +} + +WriteBufferFromPocoSocket::~WriteBufferFromPocoSocket() +{ + try + { + finalize(); + } + catch (...) 
+ { + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromPocoSocket.h b/contrib/clickhouse/src/IO/WriteBufferFromPocoSocket.h new file mode 100644 index 0000000000..ecb6102035 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromPocoSocket.h @@ -0,0 +1,42 @@ +#pragma once + +#include <Poco/Net/Socket.h> + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> +#include <Common/AsyncTaskExecutor.h> + +namespace DB +{ + +using AsyncCallback = std::function<void(int, Poco::Timespan, AsyncEventTimeoutType, const std::string &, uint32_t)>; + +/** Works with the ready Poco::Net::Socket. Blocking operations. + */ +class WriteBufferFromPocoSocket : public BufferWithOwnMemory<WriteBuffer> +{ +public: + explicit WriteBufferFromPocoSocket(Poco::Net::Socket & socket_, size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE); + + ~WriteBufferFromPocoSocket() override; + + void setAsyncCallback(AsyncCallback async_callback_) { async_callback = std::move(async_callback_); } + +protected: + void nextImpl() override; + + Poco::Net::Socket & socket; + + /** For error messages. It is necessary to receive this address in advance, because, + * for example, if the connection is broken, the address will not be received anymore + * (getpeername will return an error). + */ + Poco::Net::SocketAddress peer_address; + Poco::Net::SocketAddress our_address; + +private: + AsyncCallback async_callback; + std::string socket_description; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromS3.cpp b/contrib/clickhouse/src/IO/WriteBufferFromS3.cpp new file mode 100644 index 0000000000..824d0ae00a --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromS3.cpp @@ -0,0 +1,699 @@ +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include "StdIStreamFromMemory.h" +#include "WriteBufferFromS3.h" +#include "WriteBufferFromS3TaskTracker.h" + +#include <Common/logger_useful.h> +#include <Common/ProfileEvents.h> +#include <Common/Throttler.h> +#include <Interpreters/Cache/FileCache.h> + +#include <IO/ResourceGuard.h> +#include <IO/WriteHelpers.h> +#include <IO/S3Common.h> +#include <IO/S3/Requests.h> +#include <IO/S3/getObjectInfo.h> +#include <Interpreters/Context.h> + +#include <aws/s3/model/StorageClass.h> + +#include <utility> + + +namespace ProfileEvents +{ + extern const Event WriteBufferFromS3Bytes; + extern const Event WriteBufferFromS3Microseconds; + extern const Event WriteBufferFromS3RequestsErrors; + extern const Event S3WriteBytes; + + extern const Event S3CreateMultipartUpload; + extern const Event S3CompleteMultipartUpload; + extern const Event S3AbortMultipartUpload; + extern const Event S3UploadPart; + extern const Event S3PutObject; + + extern const Event DiskS3CreateMultipartUpload; + extern const Event DiskS3CompleteMultipartUpload; + extern const Event DiskS3AbortMultipartUpload; + extern const Event DiskS3UploadPart; + extern const Event DiskS3PutObject; + + extern const Event RemoteWriteThrottlerBytes; + extern const Event RemoteWriteThrottlerSleepMicroseconds; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int S3_ERROR; + extern const int INVALID_CONFIG_PARAMETER; + extern const int LOGICAL_ERROR; +} + +// struct WriteBufferFromS3::PartData +// { +// Memory<> memory; +// size_t data_size = 0; + +// std::shared_ptr<std::iostream> createAwsBuffer() +// { +// auto buffer = std::make_shared<StdIStreamFromMemory>(memory.data(), data_size); +// buffer->exceptions(std::ios::badbit); +// return buffer; +// } + +// bool 
isEmpty() const +// { +// return data_size == 0; +// } +// }; + +std::shared_ptr<std::iostream> WriteBufferFromS3::PartData::createAwsBuffer() +{ + auto buffer = std::make_shared<StdIStreamFromMemory>(memory.data(), data_size); + buffer->exceptions(std::ios::badbit); + return buffer; +} + +WriteBufferFromS3::WriteBufferFromS3( + std::shared_ptr<const S3::Client> client_ptr_, + std::shared_ptr<const S3::Client> client_with_long_timeout_ptr_, + const String & bucket_, + const String & key_, + size_t buf_size_, + const S3Settings::RequestSettings & request_settings_, + std::optional<std::map<String, String>> object_metadata_, + ThreadPoolCallbackRunner<void> schedule_, + const WriteSettings & write_settings_) + : WriteBufferFromFileBase(buf_size_, nullptr, 0) + , bucket(bucket_) + , key(key_) + , request_settings(request_settings_) + , upload_settings(request_settings.getUploadSettings()) + , write_settings(write_settings_) + , client_ptr(std::move(client_ptr_)) + , client_with_long_timeout_ptr(std::move(client_with_long_timeout_ptr_)) + , object_metadata(std::move(object_metadata_)) + , buffer_allocation_policy(ChooseBufferPolicy(upload_settings)) + , task_tracker( + std::make_unique<WriteBufferFromS3::TaskTracker>( + std::move(schedule_), + upload_settings.max_inflight_parts_for_one_file, + limitedLog)) +{ + LOG_TRACE(limitedLog, "Create WriteBufferFromS3, {}", getShortLogDetails()); + + allocateBuffer(); +} + +void WriteBufferFromS3::nextImpl() +{ + if (is_prefinalized) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Cannot write to prefinalized buffer for S3, the file could have been created with PutObjectRequest"); + + /// Make sense to call waitIfAny before adding new async task to check if there is an exception + /// The faster the exception is propagated the lesser time is spent for cancellation + /// Despite the fact that `task_tracker->add()` collects tasks statuses and propagates their exceptions + /// that call is necessary for the case when the is no in-flight limitation and therefore `task_tracker->add()` doesn't wait anything + task_tracker->waitIfAny(); + + hidePartialData(); + + reallocateFirstBuffer(); + + if (available() > 0) + return; + + detachBuffer(); + + if (!multipart_upload_id.empty() || detached_part_data.size() > 1) + writeMultipartUpload(); + + allocateBuffer(); +} + +void WriteBufferFromS3::preFinalize() +{ + if (is_prefinalized) + return; + + LOG_TEST(limitedLog, "preFinalize WriteBufferFromS3. {}", getShortLogDetails()); + + /// This function should not be run again if an exception has occurred + is_prefinalized = true; + + hidePartialData(); + + if (hidden_size > 0) + detachBuffer(); + setFakeBufferWhenPreFinalized(); + + bool do_single_part_upload = false; + + if (multipart_upload_id.empty() && detached_part_data.size() <= 1) + { + if (detached_part_data.empty() || detached_part_data.front().data_size <= upload_settings.max_single_part_upload_size) + do_single_part_upload = true; + } + + if (do_single_part_upload) + { + if (detached_part_data.empty()) + { + makeSinglepartUpload({}); + } + else + { + makeSinglepartUpload(std::move(detached_part_data.front())); + detached_part_data.pop_front(); + } + } + else + { + writeMultipartUpload(); + } +} + +void WriteBufferFromS3::finalizeImpl() +{ + LOG_TRACE(limitedLog, "finalizeImpl WriteBufferFromS3. 
{}.", getShortLogDetails()); + + if (!is_prefinalized) + preFinalize(); + + chassert(offset() == 0); + chassert(hidden_size == 0); + + task_tracker->waitAll(); + + if (!multipart_upload_id.empty()) + { + completeMultipartUpload(); + multipart_upload_finished = true; + } + + if (request_settings.check_objects_after_upload) + { + S3::checkObjectExists(*client_ptr, bucket, key, {}, request_settings, /* for_disk_s3= */ write_settings.for_object_storage, "Immediately after upload"); + + size_t actual_size = S3::getObjectSize(*client_ptr, bucket, key, {}, request_settings, /* for_disk_s3= */ write_settings.for_object_storage); + if (actual_size != total_size) + throw Exception( + ErrorCodes::S3_ERROR, + "Object {} from bucket {} has unexpected size {} after upload, expected size {}, it's a bug in S3 or S3 API.", + key, bucket, actual_size, total_size); + } +} + +String WriteBufferFromS3::getVerboseLogDetails() const +{ + String multipart_upload_details; + if (!multipart_upload_id.empty()) + multipart_upload_details = fmt::format(", upload id {}, upload has finished {}" + , multipart_upload_id, multipart_upload_finished); + + return fmt::format("Details: bucket {}, key {}, total size {}, count {}, hidden_size {}, offset {}, with pool: {}, prefinalized {}, finalized {}{}", + bucket, key, total_size, count(), hidden_size, offset(), task_tracker->isAsync(), is_prefinalized, finalized, multipart_upload_details); +} + +String WriteBufferFromS3::getShortLogDetails() const +{ + String multipart_upload_details; + if (!multipart_upload_id.empty()) + multipart_upload_details = fmt::format(", upload id {}" + , multipart_upload_id); + + return fmt::format("Details: bucket {}, key {}{}", + bucket, key, multipart_upload_details); +} + +void WriteBufferFromS3::tryToAbortMultipartUpload() +{ + try + { + task_tracker->safeWaitAll(); + abortMultipartUpload(); + } + catch (...) + { + LOG_ERROR(log, "Multipart upload hasn't aborted. {}", getVerboseLogDetails()); + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + +WriteBufferFromS3::~WriteBufferFromS3() +{ + LOG_TRACE(limitedLog, "Close WriteBufferFromS3. {}.", getShortLogDetails()); + + /// That destructor could be call with finalized=false in case of exceptions + if (!finalized) + { + LOG_INFO( + log, + "WriteBufferFromS3 is not finalized in destructor. " + "The file might not be written to S3. " + "{}.", + getVerboseLogDetails()); + } + + task_tracker->safeWaitAll(); + + if (!multipart_upload_id.empty() && !multipart_upload_finished) + { + LOG_WARNING(log, "WriteBufferFromS3 was neither finished nor aborted, try to abort upload in destructor. 
{}.", getVerboseLogDetails()); + tryToAbortMultipartUpload(); + } +} + +void WriteBufferFromS3::hidePartialData() +{ + if (write_settings.remote_throttler) + write_settings.remote_throttler->add(offset(), ProfileEvents::RemoteWriteThrottlerBytes, ProfileEvents::RemoteWriteThrottlerSleepMicroseconds); + + chassert(memory.size() >= hidden_size + offset()); + + hidden_size += offset(); + chassert(memory.data() + hidden_size == working_buffer.begin() + offset()); + chassert(memory.data() + hidden_size == position()); + + WriteBuffer::set(memory.data() + hidden_size, memory.size() - hidden_size); + chassert(offset() == 0); +} + +void WriteBufferFromS3::reallocateFirstBuffer() +{ + chassert(offset() == 0); + + if (buffer_allocation_policy->getBufferNumber() > 1 || available() > 0) + return; + + const size_t max_first_buffer = buffer_allocation_policy->getBufferSize(); + if (memory.size() == max_first_buffer) + return; + + size_t size = std::min(memory.size() * 2, max_first_buffer); + memory.resize(size); + + WriteBuffer::set(memory.data() + hidden_size, memory.size() - hidden_size); + + chassert(offset() == 0); +} + +void WriteBufferFromS3::detachBuffer() +{ + size_t data_size = size_t(position() - memory.data()); + chassert(data_size == hidden_size); + + auto buf = std::move(memory); + + WriteBuffer::set(nullptr, 0); + total_size += hidden_size; + hidden_size = 0; + + detached_part_data.push_back({std::move(buf), data_size}); +} + +void WriteBufferFromS3::allocateFirstBuffer() +{ + const auto max_first_buffer = buffer_allocation_policy->getBufferSize(); + const auto size = std::min(size_t(DBMS_DEFAULT_BUFFER_SIZE), max_first_buffer); + memory = Memory(size); + WriteBuffer::set(memory.data(), memory.size()); +} + +void WriteBufferFromS3::allocateBuffer() +{ + buffer_allocation_policy->nextBuffer(); + chassert(0 == hidden_size); + + if (buffer_allocation_policy->getBufferNumber() == 1) + return allocateFirstBuffer(); + + memory = Memory(buffer_allocation_policy->getBufferSize()); + WriteBuffer::set(memory.data(), memory.size()); +} + +void WriteBufferFromS3::setFakeBufferWhenPreFinalized() +{ + WriteBuffer::set(fake_buffer_when_prefinalized, sizeof(fake_buffer_when_prefinalized)); +} + +void WriteBufferFromS3::writeMultipartUpload() +{ + if (multipart_upload_id.empty()) + { + createMultipartUpload(); + } + + while (!detached_part_data.empty()) + { + writePart(std::move(detached_part_data.front())); + detached_part_data.pop_front(); + } +} + +void WriteBufferFromS3::createMultipartUpload() +{ + LOG_TEST(limitedLog, "Create multipart upload. 
{}", getShortLogDetails()); + + S3::CreateMultipartUploadRequest req; + + req.SetBucket(bucket); + req.SetKey(key); + + /// If we don't do it, AWS SDK can mistakenly set it to application/xml, see https://github.com/aws/aws-sdk-cpp/issues/1840 + req.SetContentType("binary/octet-stream"); + + if (object_metadata.has_value()) + req.SetMetadata(object_metadata.value()); + + client_ptr->setKMSHeaders(req); + + ProfileEvents::increment(ProfileEvents::S3CreateMultipartUpload); + if (write_settings.for_object_storage) + ProfileEvents::increment(ProfileEvents::DiskS3CreateMultipartUpload); + + Stopwatch watch; + auto outcome = client_ptr->CreateMultipartUpload(req); + watch.stop(); + + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Microseconds, watch.elapsedMicroseconds()); + + if (!outcome.IsSuccess()) + { + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3RequestsErrors, 1); + throw S3Exception(outcome.GetError().GetMessage(), outcome.GetError().GetErrorType()); + } + + multipart_upload_id = outcome.GetResult().GetUploadId(); + LOG_TRACE(limitedLog, "Multipart upload has created. {}", getShortLogDetails()); +} + +void WriteBufferFromS3::abortMultipartUpload() +{ + if (multipart_upload_id.empty()) + { + LOG_WARNING(log, "Nothing to abort. {}", getVerboseLogDetails()); + return; + } + + LOG_WARNING(log, "Abort multipart upload. {}", getVerboseLogDetails()); + + S3::AbortMultipartUploadRequest req; + req.SetBucket(bucket); + req.SetKey(key); + req.SetUploadId(multipart_upload_id); + + ProfileEvents::increment(ProfileEvents::S3AbortMultipartUpload); + if (write_settings.for_object_storage) + ProfileEvents::increment(ProfileEvents::DiskS3AbortMultipartUpload); + + Stopwatch watch; + auto outcome = client_ptr->AbortMultipartUpload(req); + watch.stop(); + + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Microseconds, watch.elapsedMicroseconds()); + + if (!outcome.IsSuccess()) + { + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3RequestsErrors, 1); + throw S3Exception(outcome.GetError().GetMessage(), outcome.GetError().GetErrorType()); + } + + LOG_WARNING(log, "Multipart upload has aborted successfully. {}", getVerboseLogDetails()); +} + +S3::UploadPartRequest WriteBufferFromS3::getUploadRequest(size_t part_number, PartData & data) +{ + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Bytes, data.data_size); + + S3::UploadPartRequest req; + + /// Setup request. 
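    /// Part numbers in UploadPart requests are 1-based; writePart() below checks them against
    /// upload_settings.max_part_number. Note that the S3 API itself allows at most 10000 parts
    /// per multipart upload and expects every part except the last one to be at least 5 MiB.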
+ req.SetBucket(bucket); + req.SetKey(key); + req.SetPartNumber(static_cast<int>(part_number)); + req.SetUploadId(multipart_upload_id); + req.SetContentLength(data.data_size); + req.SetBody(data.createAwsBuffer()); + /// If we don't do it, AWS SDK can mistakenly set it to application/xml, see https://github.com/aws/aws-sdk-cpp/issues/1840 + req.SetContentType("binary/octet-stream"); + + return req; +} + +void WriteBufferFromS3::writePart(WriteBufferFromS3::PartData && data) +{ + if (data.data_size == 0) + { + LOG_TEST(limitedLog, "Skipping writing part as empty {}", getShortLogDetails()); + return; + } + + multipart_tags.push_back({}); + size_t part_number = multipart_tags.size(); + LOG_TEST(limitedLog, "writePart {}, part size {}, part number {}", getShortLogDetails(), data.data_size, part_number); + + if (multipart_upload_id.empty()) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Unable to write a part without multipart_upload_id, details: WriteBufferFromS3 created for bucket {}, key {}", + bucket, key); + + if (part_number > upload_settings.max_part_number) + { + throw Exception( + ErrorCodes::INVALID_CONFIG_PARAMETER, + "Part number exceeded {} while writing {} bytes to S3. Check min_upload_part_size = {}, max_upload_part_size = {}, " + "upload_part_size_multiply_factor = {}, upload_part_size_multiply_parts_count_threshold = {}, max_single_part_upload_size = {}", + upload_settings.max_part_number, count(), upload_settings.min_upload_part_size, upload_settings.max_upload_part_size, + upload_settings.upload_part_size_multiply_factor, upload_settings.upload_part_size_multiply_parts_count_threshold, + upload_settings.max_single_part_upload_size); + } + + if (data.data_size > upload_settings.max_upload_part_size) + { + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Part size exceeded max_upload_part_size. 
{}, part number {}, part size {}, max_upload_part_size {}", + getShortLogDetails(), + part_number, + data.data_size, + upload_settings.max_upload_part_size + ); + } + + auto req = getUploadRequest(part_number, data); + auto worker_data = std::make_shared<std::tuple<S3::UploadPartRequest, WriteBufferFromS3::PartData>>(std::move(req), std::move(data)); + + auto upload_worker = [&, worker_data, part_number] () + { + auto & data_size = std::get<1>(*worker_data).data_size; + + LOG_TEST(limitedLog, "Write part started {}, part size {}, part number {}", + getShortLogDetails(), data_size, part_number); + + ProfileEvents::increment(ProfileEvents::S3UploadPart); + if (write_settings.for_object_storage) + ProfileEvents::increment(ProfileEvents::DiskS3UploadPart); + + auto & request = std::get<0>(*worker_data); + + ResourceCost cost = request.GetContentLength(); + ResourceGuard rlock(write_settings.resource_link, cost); + Stopwatch watch; + auto outcome = client_ptr->UploadPart(request); + watch.stop(); + rlock.unlock(); // Avoid acquiring other locks under resource lock + + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Microseconds, watch.elapsedMicroseconds()); + + if (!outcome.IsSuccess()) + { + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3RequestsErrors, 1); + write_settings.resource_link.accumulate(cost); // We assume no resource was used in case of failure + throw S3Exception(outcome.GetError().GetMessage(), outcome.GetError().GetErrorType()); + } + + multipart_tags[part_number-1] = outcome.GetResult().GetETag(); + + LOG_TEST(limitedLog, "Write part succeeded {}, part size {}, part number {}, etag {}", + getShortLogDetails(), data_size, part_number, multipart_tags[part_number-1]); + }; + + task_tracker->add(std::move(upload_worker)); +} + +void WriteBufferFromS3::completeMultipartUpload() +{ + LOG_TEST(limitedLog, "Completing multipart upload. {}, Parts: {}", getShortLogDetails(), multipart_tags.size()); + + if (multipart_tags.empty()) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Failed to complete multipart upload. No parts have uploaded"); + + for (size_t i = 0; i < multipart_tags.size(); ++i) + { + const auto tag = multipart_tags.at(i); + if (tag.empty()) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Failed to complete multipart upload. Part {} haven't been uploaded.", i); + } + + S3::CompleteMultipartUploadRequest req; + req.SetBucket(bucket); + req.SetKey(key); + req.SetUploadId(multipart_upload_id); + + Aws::S3::Model::CompletedMultipartUpload multipart_upload; + for (size_t i = 0; i < multipart_tags.size(); ++i) + { + Aws::S3::Model::CompletedPart part; + multipart_upload.AddParts(part.WithETag(multipart_tags[i]).WithPartNumber(static_cast<int>(i + 1))); + } + + req.SetMultipartUpload(multipart_upload); + + size_t max_retry = std::max(request_settings.max_unexpected_write_error_retries, 1UL); + for (size_t i = 0; i < max_retry; ++i) + { + ProfileEvents::increment(ProfileEvents::S3CompleteMultipartUpload); + if (write_settings.for_object_storage) + ProfileEvents::increment(ProfileEvents::DiskS3CompleteMultipartUpload); + + Stopwatch watch; + auto outcome = client_with_long_timeout_ptr->CompleteMultipartUpload(req); + watch.stop(); + + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Microseconds, watch.elapsedMicroseconds()); + + if (outcome.IsSuccess()) + { + LOG_TRACE(limitedLog, "Multipart upload has completed. 
{}, Parts: {}", getShortLogDetails(), multipart_tags.size()); + return; + } + + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3RequestsErrors, 1); + + if (outcome.GetError().GetErrorType() == Aws::S3::S3Errors::NO_SUCH_KEY) + { + /// For unknown reason, at least MinIO can respond with NO_SUCH_KEY for put requests + /// BTW, NO_SUCH_UPLOAD is expected error and we shouldn't retry it + LOG_INFO(log, "Multipart upload failed with NO_SUCH_KEY error, will retry. {}, Parts: {}", getVerboseLogDetails(), multipart_tags.size()); + } + else + { + throw S3Exception( + outcome.GetError().GetErrorType(), + "Message: {}, Key: {}, Bucket: {}, Tags: {}", + outcome.GetError().GetMessage(), key, bucket, fmt::join(multipart_tags.begin(), multipart_tags.end(), " ")); + } + } + + throw S3Exception( + Aws::S3::S3Errors::NO_SUCH_KEY, + "Message: Multipart upload failed with NO_SUCH_KEY error, retries {}, Key: {}, Bucket: {}", + max_retry, key, bucket); +} + +S3::PutObjectRequest WriteBufferFromS3::getPutRequest(PartData & data) +{ + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Bytes, data.data_size); + + S3::PutObjectRequest req; + + req.SetBucket(bucket); + req.SetKey(key); + req.SetContentLength(data.data_size); + req.SetBody(data.createAwsBuffer()); + if (object_metadata.has_value()) + req.SetMetadata(object_metadata.value()); + if (!upload_settings.storage_class_name.empty()) + req.SetStorageClass(Aws::S3::Model::StorageClassMapper::GetStorageClassForName(upload_settings.storage_class_name)); + + /// If we don't do it, AWS SDK can mistakenly set it to application/xml, see https://github.com/aws/aws-sdk-cpp/issues/1840 + req.SetContentType("binary/octet-stream"); + + client_ptr->setKMSHeaders(req); + + return req; +} + +void WriteBufferFromS3::makeSinglepartUpload(WriteBufferFromS3::PartData && data) +{ + LOG_TEST(limitedLog, "Making single part upload. {}, size {}", getShortLogDetails(), data.data_size); + + auto req = getPutRequest(data); + auto worker_data = std::make_shared<std::tuple<S3::PutObjectRequest, WriteBufferFromS3::PartData>>(std::move(req), std::move(data)); + + auto upload_worker = [&, worker_data] () + { + LOG_TEST(limitedLog, "writing single part upload started. {}", getShortLogDetails()); + + auto & request = std::get<0>(*worker_data); + size_t content_length = request.GetContentLength(); + + size_t max_retry = std::max(request_settings.max_unexpected_write_error_retries, 1UL); + for (size_t i = 0; i < max_retry; ++i) + { + ProfileEvents::increment(ProfileEvents::S3PutObject); + if (write_settings.for_object_storage) + ProfileEvents::increment(ProfileEvents::DiskS3PutObject); + + ResourceCost cost = request.GetContentLength(); + ResourceGuard rlock(write_settings.resource_link, cost); + Stopwatch watch; + auto outcome = client_ptr->PutObject(request); + watch.stop(); + rlock.unlock(); + + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3Microseconds, watch.elapsedMicroseconds()); + + if (outcome.IsSuccess()) + { + LOG_TRACE(limitedLog, "Single part upload has completed. {}, size {}", getShortLogDetails(), content_length); + return; + } + + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3RequestsErrors, 1); + write_settings.resource_link.accumulate(cost); // We assume no resource was used in case of failure + + if (outcome.GetError().GetErrorType() == Aws::S3::S3Errors::NO_SUCH_KEY) + { + + /// For unknown reason, at least MinIO can respond with NO_SUCH_KEY for put requests + LOG_INFO(log, "Single part upload failed with NO_SUCH_KEY error. 
{}, size {}, will retry", getShortLogDetails(), content_length); + } + else + { + LOG_ERROR(log, "S3Exception name {}, Message: {}, bucket {}, key {}, object size {}", + outcome.GetError().GetExceptionName(), outcome.GetError().GetMessage(), bucket, key, content_length); + throw S3Exception( + outcome.GetError().GetErrorType(), + "Message: {}, bucket {}, key {}, object size {}", + outcome.GetError().GetMessage(), bucket, key, content_length); + } + } + + throw S3Exception( + Aws::S3::S3Errors::NO_SUCH_KEY, + "Message: Single part upload failed with NO_SUCH_KEY error, retries {}, Key: {}, Bucket: {}", + max_retry, key, bucket); + }; + + task_tracker->add(std::move(upload_worker)); +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/WriteBufferFromS3.h b/contrib/clickhouse/src/IO/WriteBufferFromS3.h new file mode 100644 index 0000000000..0fdf771e1f --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromS3.h @@ -0,0 +1,140 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <base/types.h> +#include <Common/logger_useful.h> +#include <IO/WriteBufferFromFileBase.h> +#include <IO/WriteBuffer.h> +#include <IO/WriteSettings.h> +#include <Storages/StorageS3Settings.h> +#include <Interpreters/threadPoolCallbackRunner.h> + +#include <memory> +#include <vector> +#include <list> + +namespace DB +{ +/** + * Buffer to write a data to a S3 object with specified bucket and key. + * If data size written to the buffer is less than 'max_single_part_upload_size' write is performed using singlepart upload. + * In another case multipart upload is used: + * Data is divided on chunks with size greater than 'minimum_upload_part_size'. Last chunk can be less than this threshold. + * Each chunk is written as a part to S3. + */ +class WriteBufferFromS3 final : public WriteBufferFromFileBase +{ +public: + WriteBufferFromS3( + std::shared_ptr<const S3::Client> client_ptr_, + /// for CompleteMultipartUploadRequest, because it blocks on recv() for a few seconds on big uploads + std::shared_ptr<const S3::Client> client_with_long_timeout_ptr_, + const String & bucket_, + const String & key_, + size_t buf_size_, + const S3Settings::RequestSettings & request_settings_, + std::optional<std::map<String, String>> object_metadata_ = std::nullopt, + ThreadPoolCallbackRunner<void> schedule_ = {}, + const WriteSettings & write_settings_ = {}); + + ~WriteBufferFromS3() override; + void nextImpl() override; + void preFinalize() override; + std::string getFileName() const override { return key; } + void sync() override { next(); } + + class IBufferAllocationPolicy + { + public: + virtual size_t getBufferNumber() const = 0; + virtual size_t getBufferSize() const = 0; + virtual void nextBuffer() = 0; + virtual ~IBufferAllocationPolicy() = 0; + }; + using IBufferAllocationPolicyPtr = std::unique_ptr<IBufferAllocationPolicy>; + + static IBufferAllocationPolicyPtr ChooseBufferPolicy(const S3Settings::RequestSettings::PartUploadSettings & settings_); + +private: + /// Receives response from the server after sending all data. 
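    /// finalizeImpl() (defined in the .cpp above) waits for all scheduled part uploads, completes
    /// the multipart upload if one was started and, when check_objects_after_upload is enabled,
    /// verifies that the object exists and has the expected size. Callers reach it through
    /// WriteBuffer::finalize() rather than invoking it directly.
    ///
    /// Illustrative usage sketch (client and request_settings construction are assumed, not shown here):
    ///
    ///     WriteBufferFromS3 out(client, client, "my-bucket", "path/to/key",
    ///                           DBMS_DEFAULT_BUFFER_SIZE, request_settings);
    ///     out.write(data, size);
    ///     out.finalize();  /// small writes become a single PutObject, larger ones a multipart upload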
+ void finalizeImpl() override; + + String getVerboseLogDetails() const; + String getShortLogDetails() const; + + struct PartData + { + Memory<> memory; + size_t data_size = 0; + + std::shared_ptr<std::iostream> createAwsBuffer(); + + bool isEmpty() const + { + return data_size == 0; + } + }; + + void hidePartialData(); + void allocateFirstBuffer(); + void reallocateFirstBuffer(); + void detachBuffer(); + void allocateBuffer(); + void setFakeBufferWhenPreFinalized(); + + S3::UploadPartRequest getUploadRequest(size_t part_number, PartData & data); + void writePart(PartData && data); + void writeMultipartUpload(); + void createMultipartUpload(); + void completeMultipartUpload(); + void abortMultipartUpload(); + void tryToAbortMultipartUpload(); + + S3::PutObjectRequest getPutRequest(PartData & data); + void makeSinglepartUpload(PartData && data); + + const String bucket; + const String key; + const S3Settings::RequestSettings request_settings; + const S3Settings::RequestSettings::PartUploadSettings & upload_settings; + const WriteSettings write_settings; + const std::shared_ptr<const S3::Client> client_ptr; + const std::shared_ptr<const S3::Client> client_with_long_timeout_ptr; + const std::optional<std::map<String, String>> object_metadata; + Poco::Logger * log = &Poco::Logger::get("WriteBufferFromS3"); + LogSeriesLimiterPtr limitedLog = std::make_shared<LogSeriesLimiter>(log, 1, 5); + + IBufferAllocationPolicyPtr buffer_allocation_policy; + + /// Upload in S3 is made in parts. + /// We initiate upload, then upload each part and get ETag as a response, and then finalizeImpl() upload with listing all our parts. + String multipart_upload_id; + std::deque<String> multipart_tags; + bool multipart_upload_finished = false; + + /// Track that prefinalize() is called only once + bool is_prefinalized = false; + + /// First fully filled buffer has to be delayed + /// There are two ways after: + /// First is to call prefinalize/finalize, which leads to single part upload + /// Second is to write more data, which leads to multi part upload + std::deque<PartData> detached_part_data; + char fake_buffer_when_prefinalized[1] = {}; + + /// offset() and count() are unstable inside nextImpl + /// For example nextImpl changes position hence offset() and count() is changed + /// This vars are dedicated to store information about sizes when offset() and count() are unstable + size_t total_size = 0; + size_t hidden_size = 0; + + class TaskTracker; + std::unique_ptr<TaskTracker> task_tracker; +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/WriteBufferFromS3BufferAllocationPolicy.cpp b/contrib/clickhouse/src/IO/WriteBufferFromS3BufferAllocationPolicy.cpp new file mode 100644 index 0000000000..e64ea82c48 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromS3BufferAllocationPolicy.cpp @@ -0,0 +1,112 @@ +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <IO/WriteBufferFromS3.h> + +#include <memory> + +namespace +{ + +class FixedSizeBufferAllocationPolicy : public DB::WriteBufferFromS3::IBufferAllocationPolicy +{ + const size_t buffer_size = 0; + size_t buffer_number = 0; + +public: + explicit FixedSizeBufferAllocationPolicy(const DB::S3Settings::RequestSettings::PartUploadSettings & settings_) + : buffer_size(settings_.strict_upload_part_size) + { + chassert(buffer_size > 0); + } + + size_t getBufferNumber() const override { return buffer_number; } + + size_t getBufferSize() const override + { + chassert(buffer_number > 0); + return buffer_size; + } + + void nextBuffer() override + { + 
++buffer_number; + } +}; + + +class ExpBufferAllocationPolicy : public DB::WriteBufferFromS3::IBufferAllocationPolicy +{ + const size_t first_size = 0; + const size_t second_size = 0; + + const size_t multiply_factor = 0; + const size_t multiply_threshold = 0; + const size_t max_size = 0; + + size_t current_size = 0; + size_t buffer_number = 0; + +public: + explicit ExpBufferAllocationPolicy(const DB::S3Settings::RequestSettings::PartUploadSettings & settings_) + : first_size(std::max(settings_.max_single_part_upload_size, settings_.min_upload_part_size)) + , second_size(settings_.min_upload_part_size) + , multiply_factor(settings_.upload_part_size_multiply_factor) + , multiply_threshold(settings_.upload_part_size_multiply_parts_count_threshold) + , max_size(settings_.max_upload_part_size) + { + chassert(first_size > 0); + chassert(second_size > 0); + chassert(multiply_factor >= 1); + chassert(multiply_threshold > 0); + chassert(max_size > 0); + } + + size_t getBufferNumber() const override { return buffer_number; } + + size_t getBufferSize() const override + { + chassert(buffer_number > 0); + return current_size; + } + + void nextBuffer() override + { + ++buffer_number; + + if (1 == buffer_number) + { + current_size = first_size; + return; + } + + if (2 == buffer_number) + current_size = second_size; + + if (0 == ((buffer_number - 1) % multiply_threshold)) + { + current_size *= multiply_factor; + current_size = std::min(current_size, max_size); + } + } +}; + +} + +namespace DB +{ + +WriteBufferFromS3::IBufferAllocationPolicy::~IBufferAllocationPolicy() = default; + +WriteBufferFromS3::IBufferAllocationPolicyPtr WriteBufferFromS3::ChooseBufferPolicy(const S3Settings::RequestSettings::PartUploadSettings & settings_) +{ + if (settings_.strict_upload_part_size > 0) + return std::make_unique<FixedSizeBufferAllocationPolicy>(settings_); + else + return std::make_unique<ExpBufferAllocationPolicy>(settings_); +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/WriteBufferFromS3TaskTracker.cpp b/contrib/clickhouse/src/IO/WriteBufferFromS3TaskTracker.cpp new file mode 100644 index 0000000000..ed63d0c530 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromS3TaskTracker.cpp @@ -0,0 +1,176 @@ +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include <IO/WriteBufferFromS3TaskTracker.h> + +namespace ProfileEvents +{ + extern const Event WriteBufferFromS3WaitInflightLimitMicroseconds; +} + +namespace DB +{ + +WriteBufferFromS3::TaskTracker::TaskTracker(ThreadPoolCallbackRunner<void> scheduler_, size_t max_tasks_inflight_, LogSeriesLimiterPtr limitedLog_) + : is_async(bool(scheduler_)) + , scheduler(scheduler_ ? 
std::move(scheduler_) : syncRunner()) + , max_tasks_inflight(max_tasks_inflight_) + , limitedLog(limitedLog_) +{} + +WriteBufferFromS3::TaskTracker::~TaskTracker() +{ + safeWaitAll(); +} + +ThreadPoolCallbackRunner<void> WriteBufferFromS3::TaskTracker::syncRunner() +{ + return [](Callback && callback, int64_t) mutable -> std::future<void> + { + auto package = std::packaged_task<void()>(std::move(callback)); + /// No exceptions are propagated, exceptions are packed to future + package(); + return package.get_future(); + }; +} + +void WriteBufferFromS3::TaskTracker::waitAll() +{ + /// Exceptions are propagated + for (auto & future : futures) + { + future.get(); + } + futures.clear(); + + std::lock_guard lock(mutex); + finished_futures.clear(); +} + +void WriteBufferFromS3::TaskTracker::safeWaitAll() +{ + for (auto & future : futures) + { + if (future.valid()) + { + try + { + /// Exceptions are not propagated + future.get(); + } catch (...) + { + /// But at least they are printed + tryLogCurrentException(__PRETTY_FUNCTION__); + } + } + } + futures.clear(); + + std::lock_guard lock(mutex); + finished_futures.clear(); +} + +void WriteBufferFromS3::TaskTracker::waitIfAny() +{ + if (futures.empty()) + return; + + Stopwatch watch; + + { + std::lock_guard lock(mutex); + for (auto & it : finished_futures) + { + /// actually that call might lock this thread until the future is set finally + /// however that won't lock us for long, the task is about to finish when the pointer appears in the `finished_futures` + it->get(); + + /// in case of exception in `it->get()` + /// it it not necessary to remove `it` from list `futures` + /// `TaskTracker` has to be destroyed after any exception occurs, for this `safeWaitAll` is called. + /// `safeWaitAll` handles invalid futures in the list `futures` + futures.erase(it); + } + finished_futures.clear(); + } + + watch.stop(); + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3WaitInflightLimitMicroseconds, watch.elapsedMicroseconds()); +} + +void WriteBufferFromS3::TaskTracker::add(Callback && func) +{ + /// All this fuzz is about 2 things. This is the most critical place of TaskTracker. + /// The first is not to fail insertion in the list `futures`. + /// In order to face it, the element is allocated at the end of the list `futures` in advance. + /// The second is not to fail the notification of the task. + /// In order to face it, the list element, which would be inserted to the list `finished_futures`, + /// is allocated in advance as an other list `pre_allocated_finished` with one element inside. 
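    /// Both pre-allocations rely on std::list semantics: the placeholder is emplaced here, where
    /// throwing is still harmless, while the completion callback only splices the pre-allocated
    /// node into `finished_futures`. list::splice moves nodes without allocating or copying,
    /// so the callback stays allocation-free under the DENY_ALLOCATIONS_IN_SCOPE guard below.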
+ + /// preallocation for the first issue + futures.emplace_back(); + auto future_placeholder = std::prev(futures.end()); + + /// preallocation for the second issue + FinishedList pre_allocated_finished {future_placeholder}; + + Callback func_with_notification = [&, my_func = std::move(func), my_pre_allocated_finished = std::move(pre_allocated_finished)]() mutable + { + SCOPE_EXIT({ + DENY_ALLOCATIONS_IN_SCOPE; + + std::lock_guard lock(mutex); + finished_futures.splice(finished_futures.end(), my_pre_allocated_finished); + has_finished.notify_one(); + }); + + my_func(); + }; + + /// this move is nothrow + *future_placeholder = scheduler(std::move(func_with_notification), Priority{}); + + waitTilInflightShrink(); +} + +void WriteBufferFromS3::TaskTracker::waitTilInflightShrink() +{ + if (!max_tasks_inflight) + return; + + if (futures.size() >= max_tasks_inflight) + LOG_TEST(limitedLog, "have to wait some tasks finish, in queue {}, limit {}", futures.size(), max_tasks_inflight); + + Stopwatch watch; + + /// Alternative approach is to wait until at least futures.size() - max_tasks_inflight element are finished + /// However the faster finished task is collected the faster CH checks if there is an exception + /// The faster an exception is propagated the lesser time is spent for cancellation + while (futures.size() >= max_tasks_inflight) + { + std::unique_lock lock(mutex); + + has_finished.wait(lock, [this] () TSA_REQUIRES(mutex) { return !finished_futures.empty(); }); + + for (auto & it : finished_futures) + { + it->get(); + futures.erase(it); + } + + finished_futures.clear(); + } + + watch.stop(); + ProfileEvents::increment(ProfileEvents::WriteBufferFromS3WaitInflightLimitMicroseconds, watch.elapsedMicroseconds()); +} + +bool WriteBufferFromS3::TaskTracker::isAsync() const +{ + return is_async; +} + +} + +#endif diff --git a/contrib/clickhouse/src/IO/WriteBufferFromS3TaskTracker.h b/contrib/clickhouse/src/IO/WriteBufferFromS3TaskTracker.h new file mode 100644 index 0000000000..21daea22c0 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromS3TaskTracker.h @@ -0,0 +1,72 @@ +#pragma once + +#include "clickhouse_config.h" + +#if USE_AWS_S3 + +#include "WriteBufferFromS3.h" + +#include <Common/logger_useful.h> + +#include <list> + +namespace DB +{ + +/// That class is used only in WriteBufferFromS3 for now. +/// Therefore it declared as a part of WriteBufferFromS3. +/// TaskTracker takes a Callback which is run by scheduler in some external shared ThreadPool. +/// TaskTracker brings the methods waitIfAny, waitAll/safeWaitAll +/// to help with coordination of the running tasks. + +/// Basic exception safety is provided. If exception occurred the object has to be destroyed. +/// No thread safety is provided. Use this object with no concurrency. 
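/// How WriteBufferFromS3 drives it (illustrative sketch):
///
///     task_tracker->add([] { /* upload one part */ });  /// writePart(); may block in waitTilInflightShrink()
///     task_tracker->waitIfAny();                        /// opportunistically collect already finished tasks
///     task_tracker->waitAll();                          /// finalizeImpl(): wait for everything, rethrow failures
///     task_tracker->safeWaitAll();                      /// error paths and the destructor: wait and log, never throw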
+ +class WriteBufferFromS3::TaskTracker +{ +public: + using Callback = std::function<void()>; + + TaskTracker(ThreadPoolCallbackRunner<void> scheduler_, size_t max_tasks_inflight_, LogSeriesLimiterPtr limitedLog_); + ~TaskTracker(); + + static ThreadPoolCallbackRunner<void> syncRunner(); + + bool isAsync() const; + + /// waitIfAny collects statuses from already finished tasks + /// There could be no finished tasks yet, so waitIfAny do nothing useful in that case + /// the first exception is thrown if any task has failed + void waitIfAny(); + + /// Well, waitAll waits all the tasks until they finish and collects their statuses + void waitAll(); + + /// safeWaitAll does the same as waitAll but mutes the exceptions + void safeWaitAll(); + + void add(Callback && func); + +private: + /// waitTilInflightShrink waits til the number of in-flight tasks beyond the limit `max_tasks_inflight`. + void waitTilInflightShrink() TSA_NO_THREAD_SAFETY_ANALYSIS; + + void collectFinishedFutures(bool propagate_exceptions) TSA_REQUIRES(mutex); + + const bool is_async; + ThreadPoolCallbackRunner<void> scheduler; + const size_t max_tasks_inflight; + + using FutureList = std::list<std::future<void>>; + FutureList futures; + LogSeriesLimiterPtr limitedLog; + + std::mutex mutex; + std::condition_variable has_finished TSA_GUARDED_BY(mutex); + using FinishedList = std::list<FutureList::iterator>; + FinishedList finished_futures TSA_GUARDED_BY(mutex); +}; + +} + +#endif diff --git a/contrib/clickhouse/src/IO/WriteBufferFromString.h b/contrib/clickhouse/src/IO/WriteBufferFromString.h new file mode 100644 index 0000000000..1f813b1070 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromString.h @@ -0,0 +1,42 @@ +#pragma once + +#include <string> +#include <IO/WriteBufferFromVector.h> +#include <base/StringRef.h> + + +namespace DB +{ + +/** Writes the data to a string. + * Note: before using the resulting string, destroy this object. + */ +using WriteBufferFromString = WriteBufferFromVector<std::string>; + + +namespace detail +{ + /// For correct order of initialization. + class StringHolder + { + protected: + std::string value; + }; +} + +/// Creates the string by itself and allows to get it. +class WriteBufferFromOwnString : public detail::StringHolder, public WriteBufferFromString +{ +public: + WriteBufferFromOwnString() : WriteBufferFromString(value) {} + + std::string_view stringView() const { return isFinished() ? std::string_view(value) : std::string_view(value.data(), pos - value.data()); } + + std::string & str() + { + finalize(); + return value; + } +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferFromVector.h b/contrib/clickhouse/src/IO/WriteBufferFromVector.h new file mode 100644 index 0000000000..a2ecc34f1a --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferFromVector.h @@ -0,0 +1,103 @@ +#pragma once + +#include <vector> + +#include <IO/WriteBuffer.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_WRITE_AFTER_END_OF_BUFFER; +} + +struct AppendModeTag {}; + +/** Writes data to existing std::vector or similar type. When not enough space, it doubles vector size. + * + * In destructor, vector is cut to the size of written data. + * You can call 'finalize' to resize earlier. + * + * The vector should live until this object is destroyed or until the 'finalizeImpl()' method is called. 
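 *
 * Illustrative example (WriteBufferFromString from WriteBufferFromString.h is WriteBufferFromVector<std::string>):
 *
 *     std::string res;
 *     WriteBufferFromString buf(res);
 *     buf.write("hello", 5);
 *     buf.finalize();   /// 'res' is cut to the 5 written bytes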
+ */ +template <typename VectorType> +class WriteBufferFromVector : public WriteBuffer +{ +public: + using ValueType = typename VectorType::value_type; + explicit WriteBufferFromVector(VectorType & vector_) + : WriteBuffer(reinterpret_cast<Position>(vector_.data()), vector_.size()), vector(vector_) + { + if (vector.empty()) + { + vector.resize(initial_size); + set(reinterpret_cast<Position>(vector.data()), vector.size()); + } + } + + /// Append to vector instead of rewrite. + WriteBufferFromVector(VectorType & vector_, AppendModeTag) + : WriteBuffer(nullptr, 0), vector(vector_) + { + size_t old_size = vector.size(); + size_t size = (old_size < initial_size) ? initial_size + : ((old_size < vector.capacity()) ? vector.capacity() + : vector.capacity() * size_multiplier); + vector.resize(size); + set(reinterpret_cast<Position>(vector.data() + old_size), (size - old_size) * sizeof(typename VectorType::value_type)); + } + + bool isFinished() const { return finalized; } + + void restart(std::optional<size_t> max_capacity = std::nullopt) + { + if (max_capacity && vector.capacity() > max_capacity) + VectorType(initial_size, ValueType()).swap(vector); + else if (vector.empty()) + vector.resize(initial_size); + set(reinterpret_cast<Position>(vector.data()), vector.size()); + finalized = false; + } + + ~WriteBufferFromVector() override + { + finalize(); + } + +private: + void finalizeImpl() override + { + vector.resize( + ((position() - reinterpret_cast<Position>(vector.data())) /// NOLINT + + sizeof(ValueType) - 1) /// Align up. + / sizeof(ValueType)); + + /// Prevent further writes. + set(nullptr, 0); + } + + void nextImpl() override + { + if (finalized) + throw Exception(ErrorCodes::CANNOT_WRITE_AFTER_END_OF_BUFFER, "WriteBufferFromVector is finalized"); + + size_t old_size = vector.size(); + /// pos may not be equal to vector.data() + old_size, because WriteBuffer::next() can be used to flush data + size_t pos_offset = pos - reinterpret_cast<Position>(vector.data()); + if (pos_offset == old_size) + { + vector.resize(old_size * size_multiplier); + } + internal_buffer = Buffer(reinterpret_cast<Position>(vector.data() + pos_offset), reinterpret_cast<Position>(vector.data() + vector.size())); + working_buffer = internal_buffer; + } + + VectorType & vector; + + static constexpr size_t initial_size = 32; + static constexpr size_t size_multiplier = 2; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferValidUTF8.cpp b/contrib/clickhouse/src/IO/WriteBufferValidUTF8.cpp new file mode 100644 index 0000000000..d611befac3 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferValidUTF8.cpp @@ -0,0 +1,157 @@ +#include <Poco/UTF8Encoding.h> +#include <IO/WriteBufferValidUTF8.h> +#include <base/types.h> +#include <base/simd.h> + +#ifdef __SSE2__ + #include <emmintrin.h> +#endif + +#if defined(__aarch64__) && defined(__ARM_NEON) +# include <arm_neon.h> +# pragma clang diagnostic ignored "-Wreserved-identifier" +#endif + +namespace DB +{ + +const size_t WriteBufferValidUTF8::DEFAULT_SIZE = 4096; + +/** Index into the table below with the first byte of a UTF-8 sequence to + * get the number of trailing bytes that are supposed to follow it. + * Note that *legal* UTF-8 values can't have 4 or 5-bytes. The table is + * left as-is for anyone who may want to do such conversion, which was + * allowed in earlier algorithms. 
+ */ +extern const UInt8 length_of_utf8_sequence[256] = +{ + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, + 3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3, 4,4,4,4,4,4,4,4,5,5,5,5,6,6,6,6 +}; + + +WriteBufferValidUTF8::WriteBufferValidUTF8( + WriteBuffer & output_buffer_, bool group_replacements_, const char * replacement_, size_t size) + : BufferWithOwnMemory<WriteBuffer>(std::max(static_cast<size_t>(32), size)), output_buffer(output_buffer_), + group_replacements(group_replacements_), replacement(replacement_) +{ +} + + +inline void WriteBufferValidUTF8::putReplacement() +{ + if (replacement.empty() || (group_replacements && just_put_replacement)) + return; + + just_put_replacement = true; + output_buffer.write(replacement.data(), replacement.size()); +} + + +inline void WriteBufferValidUTF8::putValid(char *data, size_t len) +{ + if (len == 0) + return; + + just_put_replacement = false; + output_buffer.write(data, len); +} + + +void WriteBufferValidUTF8::nextImpl() +{ + char * p = memory.data(); + char * valid_start = p; + + while (p < pos) + { +#ifdef __SSE2__ + /// Fast skip of ASCII for x86. + static constexpr size_t SIMD_BYTES = 16; + const char * simd_end = p + (pos - p) / SIMD_BYTES * SIMD_BYTES; + + while (p < simd_end && !_mm_movemask_epi8(_mm_loadu_si128(reinterpret_cast<const __m128i*>(p)))) + p += SIMD_BYTES; + + if (!(p < pos)) + break; +#elif defined(__aarch64__) && defined(__ARM_NEON) + /// Fast skip of ASCII for aarch64. + static constexpr size_t SIMD_BYTES = 16; + const char * simd_end = p + (pos - p) / SIMD_BYTES * SIMD_BYTES; + /// Other options include + /// vmaxvq_u8(input) < 0b10000000; + /// Used by SIMDJSON, has latency 3 for M1, 6 for everything else + /// SIMDJSON uses it for 64 byte masks, so it's a little different. + /// vmaxvq_u32(vandq_u32(input, vdupq_n_u32(0x80808080))) // u32 version has latency 3 + /// shrn version has universally <=3 cycles, on servers 2 cycles. + while (p < simd_end && getNibbleMask(vcgeq_u8(vld1q_u8(reinterpret_cast<const uint8_t *>(p)), vdupq_n_u8(0x80))) == 0) + p += SIMD_BYTES; + + if (!(p < pos)) + break; +#endif + + UInt8 len = length_of_utf8_sequence[static_cast<unsigned char>(*p)]; + + if (len > 4) + { // NOLINT + /// Invalid start of sequence. Skip one byte. + putValid(valid_start, p - valid_start); + putReplacement(); + ++p; + valid_start = p; + } + else if (p + len > pos) + { + /// Sequence was not fully written to this buffer. + break; + } + else if (Poco::UTF8Encoding::isLegal(reinterpret_cast<unsigned char *>(p), len)) + { + /// Valid sequence. + p += len; + } + else + { + /// Invalid sequence. Skip just first byte. + putValid(valid_start, p - valid_start); + putReplacement(); + ++p; + valid_start = p; + } + } + + putValid(valid_start, p - valid_start); + + size_t cnt = pos - p; + + /// Shift unfinished sequence to start of buffer. 
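    /// The remaining bytes may be the beginning of a multi-byte sequence whose continuation
    /// arrives with the next chunk, so they are carried over instead of being emitted or
    /// replaced now. If the stream ends mid-sequence, finalizeImpl() writes the replacement.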
+ for (size_t i = 0; i < cnt; ++i) + memory[i] = p[i]; + + working_buffer = Buffer(&memory[cnt], memory.data() + memory.size()); +} + +WriteBufferValidUTF8::~WriteBufferValidUTF8() +{ + finalize(); +} + +void WriteBufferValidUTF8::finalizeImpl() +{ + /// Write all complete sequences from buffer. + nextImpl(); + + /// If unfinished sequence at end, then write replacement. + if (working_buffer.begin() != memory.data()) + putReplacement(); +} + +} diff --git a/contrib/clickhouse/src/IO/WriteBufferValidUTF8.h b/contrib/clickhouse/src/IO/WriteBufferValidUTF8.h new file mode 100644 index 0000000000..daaf0427f8 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteBufferValidUTF8.h @@ -0,0 +1,41 @@ +#pragma once + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> + + +namespace DB +{ + +/** Writes the data to another buffer, replacing the invalid UTF-8 sequences with the specified sequence. + * If the valid UTF-8 is already written, it works faster. + * Note: before using the resulting string, destroy this object. + */ +class WriteBufferValidUTF8 final : public BufferWithOwnMemory<WriteBuffer> +{ +public: + static const size_t DEFAULT_SIZE; + + explicit WriteBufferValidUTF8( + WriteBuffer & output_buffer_, + bool group_replacements_ = true, + const char * replacement_ = "\xEF\xBF\xBD", + size_t size = DEFAULT_SIZE); + + ~WriteBufferValidUTF8() override; + +private: + void putReplacement(); + void putValid(char * data, size_t len); + + void nextImpl() override; + void finalizeImpl() override; + + WriteBuffer & output_buffer; + bool group_replacements; + /// The last recorded character was `replacement`. + bool just_put_replacement = false; + std::string replacement; +}; + +} diff --git a/contrib/clickhouse/src/IO/WriteHelpers.cpp b/contrib/clickhouse/src/IO/WriteHelpers.cpp new file mode 100644 index 0000000000..34eabe55d7 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteHelpers.cpp @@ -0,0 +1,125 @@ +#include <IO/WriteHelpers.h> +#include <cinttypes> +#include <utility> +#include <Common/formatIPv6.h> +#include <base/hex.h> + + +namespace DB +{ + +template <typename IteratorSrc, typename IteratorDst> +void formatHex(IteratorSrc src, IteratorDst dst, size_t num_bytes) +{ + size_t src_pos = 0; + size_t dst_pos = 0; + for (; src_pos < num_bytes; ++src_pos) + { + writeHexByteLowercase(src[src_pos], &dst[dst_pos]); + dst_pos += 2; + } +} + +std::array<char, 36> formatUUID(const UUID & uuid) +{ + std::array<char, 36> dst; + auto * dst_ptr = dst.data(); + +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + const auto * src_ptr = reinterpret_cast<const UInt8 *>(&uuid); + const std::reverse_iterator src(src_ptr + 16); +#else + const auto * src = reinterpret_cast<const UInt8 *>(&uuid); +#endif + formatHex(src + 8, dst_ptr, 4); + dst[8] = '-'; + formatHex(src + 12, dst_ptr + 9, 2); + dst[13] = '-'; + formatHex(src + 14, dst_ptr + 14, 2); + dst[18] = '-'; + formatHex(src, dst_ptr + 19, 2); + dst[23] = '-'; + formatHex(src + 2, dst_ptr + 24, 6); + + return dst; +} + +void writeIPv4Text(const IPv4 & ip, WriteBuffer & buf) +{ + size_t idx = (ip >> 24); + buf.write(one_byte_to_string_lookup_table[idx].first, one_byte_to_string_lookup_table[idx].second); + buf.write('.'); + idx = (ip >> 16) & 0xFF; + buf.write(one_byte_to_string_lookup_table[idx].first, one_byte_to_string_lookup_table[idx].second); + buf.write('.'); + idx = (ip >> 8) & 0xFF; + buf.write(one_byte_to_string_lookup_table[idx].first, one_byte_to_string_lookup_table[idx].second); + buf.write('.'); + idx = ip & 0xFF; + 
buf.write(one_byte_to_string_lookup_table[idx].first, one_byte_to_string_lookup_table[idx].second); +} + +void writeIPv6Text(const IPv6 & ip, WriteBuffer & buf) +{ + char addr[IPV6_MAX_TEXT_LENGTH + 1] {}; + char * paddr = addr; + + formatIPv6(reinterpret_cast<const unsigned char *>(&ip), paddr); + buf.write(addr, paddr - addr - 1); +} + +void writeException(const Exception & e, WriteBuffer & buf, bool with_stack_trace) +{ + writeBinaryLittleEndian(e.code(), buf); + writeBinary(String(e.name()), buf); + writeBinary(e.displayText() + getExtraExceptionInfo(e), buf); + + if (with_stack_trace) + writeBinary(e.getStackTraceString(), buf); + else + writeBinary(String(), buf); + + bool has_nested = false; + writeBinary(has_nested, buf); +} + + +/// The same, but quotes apply only if there are characters that do not match the identifier without quotes +template <typename F> +static inline void writeProbablyQuotedStringImpl(StringRef s, WriteBuffer & buf, F && write_quoted_string) +{ + if (isValidIdentifier(s.toView()) + /// This are valid identifiers but are problematic if present unquoted in SQL query. + && !(s.size == strlen("distinct") && 0 == strncasecmp(s.data, "distinct", strlen("distinct"))) + && !(s.size == strlen("all") && 0 == strncasecmp(s.data, "all", strlen("all")))) + { + writeString(s, buf); + } + else + write_quoted_string(s, buf); +} + +void writeProbablyBackQuotedString(StringRef s, WriteBuffer & buf) +{ + writeProbablyQuotedStringImpl(s, buf, [](StringRef s_, WriteBuffer & buf_) { return writeBackQuotedString(s_, buf_); }); +} + +void writeProbablyDoubleQuotedString(StringRef s, WriteBuffer & buf) +{ + writeProbablyQuotedStringImpl(s, buf, [](StringRef s_, WriteBuffer & buf_) { return writeDoubleQuotedString(s_, buf_); }); +} + +void writeProbablyBackQuotedStringMySQL(StringRef s, WriteBuffer & buf) +{ + writeProbablyQuotedStringImpl(s, buf, [](StringRef s_, WriteBuffer & buf_) { return writeBackQuotedStringMySQL(s_, buf_); }); +} + +void writePointerHex(const void * ptr, WriteBuffer & buf) +{ + writeString("0x", buf); + char hex_str[2 * sizeof(ptr)]; + writeHexUIntLowercase(reinterpret_cast<uintptr_t>(ptr), hex_str); + buf.write(hex_str, 2 * sizeof(ptr)); +} + +} diff --git a/contrib/clickhouse/src/IO/WriteHelpers.h b/contrib/clickhouse/src/IO/WriteHelpers.h new file mode 100644 index 0000000000..57337e7bb9 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteHelpers.h @@ -0,0 +1,1274 @@ +#pragma once + +#include <cstring> +#include <cstdio> +#include <limits> +#include <algorithm> +#include <iterator> +#include <concepts> +#include <bit> + +#include <pcg-random/pcg_random.hpp> + +#include <Common/StackTrace.h> +#include <Common/formatIPv6.h> +#include <Common/DateLUT.h> +#include <Common/LocalDate.h> +#include <Common/LocalDateTime.h> +#include <Common/TransformEndianness.hpp> +#include <base/find_symbols.h> +#include <base/StringRef.h> +#include <base/DecomposedFloat.h> +#include <base/EnumReflection.h> + +#include <Core/DecimalFunctions.h> +#include <Core/Types.h> +#include <Core/UUID.h> +#include <base/IPv4andIPv6.h> + +#include <Common/Exception.h> +#include <Common/StringUtils/StringUtils.h> +#include <Common/NaNUtils.h> + +#include <IO/CompressionMethod.h> +#include <IO/WriteBuffer.h> +#include <IO/WriteIntText.h> +#include <IO/VarInt.h> +#include <IO/DoubleConverter.h> +#include <IO/WriteBufferFromString.h> + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-parameter" +#pragma clang diagnostic ignored "-Wsign-compare" +#endif 
+#include <dragonbox/dragonbox_to_chars.h> +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#include <Formats/FormatSettings.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_PRINT_FLOAT_OR_DOUBLE_NUMBER; +} + + +/// Helper functions for formatted and binary output. + +inline void writeChar(char x, WriteBuffer & buf) +{ + buf.nextIfAtEnd(); + *buf.position() = x; + ++buf.position(); +} + +/// Write the same character n times. +inline void writeChar(char c, size_t n, WriteBuffer & buf) +{ + while (n) + { + buf.nextIfAtEnd(); + size_t count = std::min(n, buf.available()); + memset(buf.position(), c, count); + n -= count; + buf.position() += count; + } +} + +/// Write POD-type in native format. It's recommended to use only with packed (dense) data types. +template <typename T> +inline void writePODBinary(const T & x, WriteBuffer & buf) +{ + buf.write(reinterpret_cast<const char *>(&x), sizeof(x)); /// NOLINT +} + +inline void writeUUIDBinary(const UUID & x, WriteBuffer & buf) +{ + const auto & uuid = x.toUnderType(); + writePODBinary(uuid.items[0], buf); + writePODBinary(uuid.items[1], buf); +} + +template <typename T> +inline void writeIntBinary(const T & x, WriteBuffer & buf) +{ + writePODBinary(x, buf); +} + +template <typename T> +inline void writeFloatBinary(const T & x, WriteBuffer & buf) +{ + writePODBinary(x, buf); +} + + +inline void writeStringBinary(const std::string & s, WriteBuffer & buf) +{ + writeVarUInt(s.size(), buf); + buf.write(s.data(), s.size()); +} + +/// For historical reasons we store IPv6 as a String +inline void writeIPv6Binary(const IPv6 & ip, WriteBuffer & buf) +{ + writeVarUInt(IPV6_BINARY_LENGTH, buf); + buf.write(reinterpret_cast<const char *>(&ip.toUnderType()), IPV6_BINARY_LENGTH); +} + +inline void writeStringBinary(StringRef s, WriteBuffer & buf) +{ + writeVarUInt(s.size, buf); + buf.write(s.data, s.size); +} + +inline void writeStringBinary(const char * s, WriteBuffer & buf) +{ + writeStringBinary(StringRef{s}, buf); +} + +inline void writeStringBinary(std::string_view s, WriteBuffer & buf) +{ + writeStringBinary(StringRef{s}, buf); +} + + +template <typename T> +void writeVectorBinary(const std::vector<T> & v, WriteBuffer & buf) +{ + writeVarUInt(v.size(), buf); + + for (typename std::vector<T>::const_iterator it = v.begin(); it != v.end(); ++it) + writeBinary(*it, buf); +} + + +inline void writeBoolText(bool x, WriteBuffer & buf) +{ + writeChar(x ? '1' : '0', buf); +} + + +template <typename T> +inline size_t writeFloatTextFastPath(T x, char * buffer) +{ + Int64 result = 0; + + if constexpr (std::is_same_v<T, double>) + { + /// The library Ryu has low performance on integers. + /// This workaround improves performance 6..10 times. 
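        /// For example, 1234567.0 is integer-valued and fits into Int64, so it is printed via
        /// itoa as "1234567"; a value like 0.5 falls through to dragonbox.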
+ + if (DecomposedFloat64(x).isIntegerInRepresentableRange()) + result = itoa(Int64(x), buffer) - buffer; + else + result = jkj::dragonbox::to_chars_n(x, buffer) - buffer; + } + else + { + if (DecomposedFloat32(x).isIntegerInRepresentableRange()) + result = itoa(Int32(x), buffer) - buffer; + else + result = jkj::dragonbox::to_chars_n(x, buffer) - buffer; + } + + if (result <= 0) + throw Exception(ErrorCodes::CANNOT_PRINT_FLOAT_OR_DOUBLE_NUMBER, "Cannot print floating point number"); + return result; +} + +template <typename T> +inline void writeFloatText(T x, WriteBuffer & buf) +{ + static_assert(std::is_same_v<T, double> || std::is_same_v<T, float>, "Argument for writeFloatText must be float or double"); + + using Converter = DoubleConverter<false>; + if (likely(buf.available() >= Converter::MAX_REPRESENTATION_LENGTH)) + { + buf.position() += writeFloatTextFastPath(x, buf.position()); + return; + } + + Converter::BufferType buffer; + size_t result = writeFloatTextFastPath(x, buffer); + buf.write(buffer, result); +} + + +inline void writeString(const char * data, size_t size, WriteBuffer & buf) +{ + buf.write(data, size); +} + +// Otherwise StringRef and string_view overloads are ambiguous when passing string literal. Prefer std::string_view +void writeString(std::same_as<StringRef> auto ref, WriteBuffer & buf) +{ + writeString(ref.data, ref.size, buf); +} + +inline void writeString(std::string_view ref, WriteBuffer & buf) +{ + writeString(ref.data(), ref.size(), buf); +} + +/** Writes a C-string without creating a temporary object. If the string is a literal, then `strlen` is executed at the compilation stage. + * Use when the string is a literal. + */ +#define writeCString(s, buf) \ + (buf).write((s), strlen(s)) + +/** Writes a string for use in the JSON format: + * - the string is written in double quotes + * - slash character '/' is escaped for compatibility with JavaScript + * - bytes from the range 0x00-0x1F except `\b', '\f', '\n', '\r', '\t' are escaped as \u00XX + * - code points U+2028 and U+2029 (byte sequences in UTF-8: e2 80 a8, e2 80 a9) are escaped as \u2028 and \u2029 + * - it is assumed that string is in UTF-8, the invalid UTF-8 is not processed + * - all other non-ASCII characters remain as is + */ +inline void writeJSONString(const char * begin, const char * end, WriteBuffer & buf, const FormatSettings & settings) +{ + writeChar('"', buf); + for (const char * it = begin; it != end; ++it) + { + switch (*it) + { + case '\b': + writeChar('\\', buf); + writeChar('b', buf); + break; + case '\f': + writeChar('\\', buf); + writeChar('f', buf); + break; + case '\n': + writeChar('\\', buf); + writeChar('n', buf); + break; + case '\r': + writeChar('\\', buf); + writeChar('r', buf); + break; + case '\t': + writeChar('\\', buf); + writeChar('t', buf); + break; + case '\\': + writeChar('\\', buf); + writeChar('\\', buf); + break; + case '/': + if (settings.json.escape_forward_slashes) + writeChar('\\', buf); + writeChar('/', buf); + break; + case '"': + writeChar('\\', buf); + writeChar('"', buf); + break; + default: + UInt8 c = *it; + if (c <= 0x1F) + { + /// Escaping of ASCII control characters. 
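                    /// E.g. the byte 0x1B (ESC) gives higher_half = 0x1 and lower_half = 0xB,
                    /// which is written as \u001B.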
+ + UInt8 higher_half = c >> 4; + UInt8 lower_half = c & 0xF; + + writeCString("\\u00", buf); + writeChar('0' + higher_half, buf); + + if (lower_half <= 9) + writeChar('0' + lower_half, buf); + else + writeChar('A' + lower_half - 10, buf); + } + else if (end - it >= 3 && it[0] == '\xE2' && it[1] == '\x80' && (it[2] == '\xA8' || it[2] == '\xA9')) + { + /// This is for compatibility with JavaScript, because unescaped line separators are prohibited in string literals, + /// and these code points are alternative line separators. + + if (it[2] == '\xA8') + writeCString("\\u2028", buf); + if (it[2] == '\xA9') + writeCString("\\u2029", buf); + + /// Byte sequence is 3 bytes long. We have additional two bytes to skip. + it += 2; + } + else + writeChar(*it, buf); + } + } + writeChar('"', buf); +} + + +/** Will escape quote_character and a list of special characters('\b', '\f', '\n', '\r', '\t', '\0', '\\'). + * - when escape_quote_with_quote is true, use backslash to escape list of special characters, + * and use quote_character to escape quote_character. such as: 'hello''world' + * otherwise use backslash to escape list of special characters and quote_character + * - when escape_backslash_with_backslash is true, backslash is escaped with another backslash + */ +template <char quote_character, bool escape_quote_with_quote = false, bool escape_backslash_with_backslash = true> +void writeAnyEscapedString(const char * begin, const char * end, WriteBuffer & buf) +{ + const char * pos = begin; + while (true) + { + /// On purpose we will escape more characters than minimally necessary. + const char * next_pos = find_first_symbols<'\b', '\f', '\n', '\r', '\t', '\0', '\\', quote_character>(pos, end); + + if (next_pos == end) + { + buf.write(pos, next_pos - pos); + break; + } + else + { + buf.write(pos, next_pos - pos); + pos = next_pos; + switch (*pos) + { + case quote_character: + { + if constexpr (escape_quote_with_quote) + writeChar(quote_character, buf); + else + writeChar('\\', buf); + writeChar(quote_character, buf); + break; + } + case '\b': + writeChar('\\', buf); + writeChar('b', buf); + break; + case '\f': + writeChar('\\', buf); + writeChar('f', buf); + break; + case '\n': + writeChar('\\', buf); + writeChar('n', buf); + break; + case '\r': + writeChar('\\', buf); + writeChar('r', buf); + break; + case '\t': + writeChar('\\', buf); + writeChar('t', buf); + break; + case '\0': + writeChar('\\', buf); + writeChar('0', buf); + break; + case '\\': + if constexpr (escape_backslash_with_backslash) + writeChar('\\', buf); + writeChar('\\', buf); + break; + default: + writeChar(*pos, buf); + } + ++pos; + } + } +} + + +inline void writeJSONString(std::string_view s, WriteBuffer & buf, const FormatSettings & settings) +{ + writeJSONString(s.data(), s.data() + s.size(), buf, settings); +} + +template <typename T> +void writeJSONNumber(T x, WriteBuffer & ostr, const FormatSettings & settings) +{ + bool is_finite = isFinite(x); + + const bool need_quote = (is_integer<T> && (sizeof(T) >= 8) && settings.json.quote_64bit_integers) + || (settings.json.quote_denormals && !is_finite) || (is_floating_point<T> && (sizeof(T) >= 8) && settings.json.quote_64bit_floats); + + if (need_quote) + writeChar('"', ostr); + + if (is_finite) + writeText(x, ostr); + else if (!settings.json.quote_denormals) + writeCString("null", ostr); + else + { + if constexpr (std::is_floating_point_v<T>) + { + if (std::signbit(x)) + { + if (isNaN(x)) + writeCString("-nan", ostr); + else + writeCString("-inf", ostr); + } + else + { + if 
(isNaN(x)) + writeCString("nan", ostr); + else + writeCString("inf", ostr); + } + } + } + + if (need_quote) + writeChar('"', ostr); +} + + +template <char c> +void writeAnyEscapedString(std::string_view s, WriteBuffer & buf) +{ + writeAnyEscapedString<c>(s.data(), s.data() + s.size(), buf); +} + + +inline void writeEscapedString(const char * str, size_t size, WriteBuffer & buf) +{ + writeAnyEscapedString<'\''>(str, str + size, buf); +} + +inline void writeEscapedString(std::string_view ref, WriteBuffer & buf) +{ + writeEscapedString(ref.data(), ref.size(), buf); +} + +template <char quote_character> +void writeAnyQuotedString(const char * begin, const char * end, WriteBuffer & buf) +{ + writeChar(quote_character, buf); + writeAnyEscapedString<quote_character>(begin, end, buf); + writeChar(quote_character, buf); +} + + +template <char quote_character> +void writeAnyQuotedString(std::string_view ref, WriteBuffer & buf) +{ + writeAnyQuotedString<quote_character>(ref.data(), ref.data() + ref.size(), buf); +} + + +inline void writeQuotedString(const String & s, WriteBuffer & buf) +{ + writeAnyQuotedString<'\''>(s, buf); +} + +inline void writeQuotedString(StringRef ref, WriteBuffer & buf) +{ + writeAnyQuotedString<'\''>(ref.toView(), buf); +} + +inline void writeQuotedString(std::string_view ref, WriteBuffer & buf) +{ + writeAnyQuotedString<'\''>(ref.data(), ref.data() + ref.size(), buf); +} + +inline void writeQuotedStringPostgreSQL(std::string_view ref, WriteBuffer & buf) +{ + writeChar('\'', buf); + writeAnyEscapedString<'\'', true, false>(ref.data(), ref.data() + ref.size(), buf); + writeChar('\'', buf); +} + +inline void writeDoubleQuotedString(const String & s, WriteBuffer & buf) +{ + writeAnyQuotedString<'"'>(s, buf); +} + +inline void writeDoubleQuotedString(StringRef s, WriteBuffer & buf) +{ + writeAnyQuotedString<'"'>(s.toView(), buf); +} + +inline void writeDoubleQuotedString(std::string_view s, WriteBuffer & buf) +{ + writeAnyQuotedString<'"'>(s.data(), s.data() + s.size(), buf); +} + +/// Outputs a string in backquotes. +inline void writeBackQuotedString(StringRef s, WriteBuffer & buf) +{ + writeAnyQuotedString<'`'>(s.toView(), buf); +} + +/// Outputs a string in backquotes for MySQL. +inline void writeBackQuotedStringMySQL(StringRef s, WriteBuffer & buf) +{ + writeChar('`', buf); + writeAnyEscapedString<'`', true>(s.data, s.data + s.size, buf); + writeChar('`', buf); +} + + +/// Write quoted if the string doesn't look like and identifier. +void writeProbablyBackQuotedString(StringRef s, WriteBuffer & buf); +void writeProbablyDoubleQuotedString(StringRef s, WriteBuffer & buf); +void writeProbablyBackQuotedStringMySQL(StringRef s, WriteBuffer & buf); + + +/** Outputs the string in for the CSV format. + * Rules: + * - the string is outputted in quotation marks; + * - the quotation mark inside the string is outputted as two quotation marks in sequence. + */ +template <char quote = '"'> +void writeCSVString(const char * begin, const char * end, WriteBuffer & buf) +{ + writeChar(quote, buf); + + const char * pos = begin; + while (true) + { + const char * next_pos = find_first_symbols<quote>(pos, end); + + if (next_pos == end) + { + buf.write(pos, end - pos); + break; + } + else /// Quotation. 
+ { + ++next_pos; + buf.write(pos, next_pos - pos); + writeChar(quote, buf); + } + + pos = next_pos; + } + + writeChar(quote, buf); +} + +template <char quote = '"'> +void writeCSVString(const String & s, WriteBuffer & buf) +{ + writeCSVString<quote>(s.data(), s.data() + s.size(), buf); +} + +template <char quote = '"'> +void writeCSVString(StringRef s, WriteBuffer & buf) +{ + writeCSVString<quote>(s.data, s.data + s.size, buf); +} + +inline void writeXMLStringForTextElementOrAttributeValue(const char * begin, const char * end, WriteBuffer & buf) +{ + const char * pos = begin; + while (true) + { + const char * next_pos = find_first_symbols<'<', '&', '>', '"', '\''>(pos, end); + + if (next_pos == end) + { + buf.write(pos, end - pos); + break; + } + else if (*next_pos == '<') + { + buf.write(pos, next_pos - pos); + ++next_pos; + writeCString("<", buf); + } + else if (*next_pos == '&') + { + buf.write(pos, next_pos - pos); + ++next_pos; + writeCString("&", buf); + } + else if (*next_pos == '>') + { + buf.write(pos, next_pos - pos); + ++next_pos; + writeCString(">", buf); + } + else if (*next_pos == '"') + { + buf.write(pos, next_pos - pos); + ++next_pos; + writeCString(""", buf); + } + else if (*next_pos == '\'') + { + buf.write(pos, next_pos - pos); + ++next_pos; + writeCString("'", buf); + } + + pos = next_pos; + } +} + +inline void writeXMLStringForTextElementOrAttributeValue(std::string_view s, WriteBuffer & buf) +{ + writeXMLStringForTextElementOrAttributeValue(s.data(), s.data() + s.size(), buf); +} + +/// Writing a string to a text node in XML (not into an attribute - otherwise you need more escaping). +inline void writeXMLStringForTextElement(const char * begin, const char * end, WriteBuffer & buf) +{ + const char * pos = begin; + while (true) + { + /// NOTE Perhaps for some XML parsers, you need to escape the zero byte and some control characters. + const char * next_pos = find_first_symbols<'<', '&'>(pos, end); + + if (next_pos == end) + { + buf.write(pos, end - pos); + break; + } + else if (*next_pos == '<') + { + buf.write(pos, next_pos - pos); + ++next_pos; + writeCString("<", buf); + } + else if (*next_pos == '&') + { + buf.write(pos, next_pos - pos); + ++next_pos; + writeCString("&", buf); + } + + pos = next_pos; + } +} + +inline void writeXMLStringForTextElement(std::string_view s, WriteBuffer & buf) +{ + writeXMLStringForTextElement(s.data(), s.data() + s.size(), buf); +} + +/// @brief Serialize `uuid` into an array of characters in big-endian byte order. +/// @param uuid UUID to serialize. +/// @return Array of characters in big-endian byte order. 
+std::array<char, 36> formatUUID(const UUID & uuid); + +inline void writeUUIDText(const UUID & uuid, WriteBuffer & buf) +{ + const auto serialized_uuid = formatUUID(uuid); + buf.write(serialized_uuid.data(), serialized_uuid.size()); +} + +void writeIPv4Text(const IPv4 & ip, WriteBuffer & buf); +void writeIPv6Text(const IPv6 & ip, WriteBuffer & buf); + +template <typename DecimalType> +inline void writeDateTime64FractionalText(typename DecimalType::NativeType fractional, UInt32 scale, WriteBuffer & buf) +{ + static constexpr UInt32 MaxScale = DecimalUtils::max_precision<DecimalType>; + + char data[20] = {'0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0'}; + static_assert(sizeof(data) >= MaxScale); + + for (Int32 pos = scale - 1; pos >= 0 && fractional; --pos, fractional /= DateTime64(10)) + data[pos] += fractional % DateTime64(10); + + writeString(&data[0], static_cast<size_t>(scale), buf); +} + +static const char digits100[201] = + "00010203040506070809" + "10111213141516171819" + "20212223242526272829" + "30313233343536373839" + "40414243444546474849" + "50515253545556575859" + "60616263646566676869" + "70717273747576777879" + "80818283848586878889" + "90919293949596979899"; + +/// in YYYY-MM-DD format +template <char delimiter = '-'> +inline void writeDateText(const LocalDate & date, WriteBuffer & buf) +{ + if (reinterpret_cast<intptr_t>(buf.position()) + 10 <= reinterpret_cast<intptr_t>(buf.buffer().end())) + { + memcpy(buf.position(), &digits100[date.year() / 100 * 2], 2); + buf.position() += 2; + memcpy(buf.position(), &digits100[date.year() % 100 * 2], 2); + buf.position() += 2; + *buf.position() = delimiter; + ++buf.position(); + memcpy(buf.position(), &digits100[date.month() * 2], 2); + buf.position() += 2; + *buf.position() = delimiter; + ++buf.position(); + memcpy(buf.position(), &digits100[date.day() * 2], 2); + buf.position() += 2; + } + else + { + buf.write(&digits100[date.year() / 100 * 2], 2); + buf.write(&digits100[date.year() % 100 * 2], 2); + buf.write(delimiter); + buf.write(&digits100[date.month() * 2], 2); + buf.write(delimiter); + buf.write(&digits100[date.day() * 2], 2); + } +} + +template <char delimiter = '-'> +inline void writeDateText(DayNum date, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +{ + writeDateText<delimiter>(LocalDate(date, time_zone), buf); +} + +template <char delimiter = '-'> +inline void writeDateText(ExtendedDayNum date, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +{ + writeDateText<delimiter>(LocalDate(date, time_zone), buf); +} + +/// In the format YYYY-MM-DD HH:MM:SS +template <char date_delimeter = '-', char time_delimeter = ':', char between_date_time_delimiter = ' '> +inline void writeDateTimeText(const LocalDateTime & datetime, WriteBuffer & buf) +{ + if (reinterpret_cast<intptr_t>(buf.position()) + 19 <= reinterpret_cast<intptr_t>(buf.buffer().end())) + { + memcpy(buf.position(), &digits100[datetime.year() / 100 * 2], 2); + buf.position() += 2; + memcpy(buf.position(), &digits100[datetime.year() % 100 * 2], 2); + buf.position() += 2; + *buf.position() = date_delimeter; + ++buf.position(); + memcpy(buf.position(), &digits100[datetime.month() * 2], 2); + buf.position() += 2; + *buf.position() = date_delimeter; + ++buf.position(); + memcpy(buf.position(), &digits100[datetime.day() * 2], 2); + buf.position() += 2; + *buf.position() = between_date_time_delimiter; + ++buf.position(); + memcpy(buf.position(), &digits100[datetime.hour() * 2], 2); + 
buf.position() += 2; + *buf.position() = time_delimeter; + ++buf.position(); + memcpy(buf.position(), &digits100[datetime.minute() * 2], 2); + buf.position() += 2; + *buf.position() = time_delimeter; + ++buf.position(); + memcpy(buf.position(), &digits100[datetime.second() * 2], 2); + buf.position() += 2; + } + else + { + buf.write(&digits100[datetime.year() / 100 * 2], 2); + buf.write(&digits100[datetime.year() % 100 * 2], 2); + buf.write(date_delimeter); + buf.write(&digits100[datetime.month() * 2], 2); + buf.write(date_delimeter); + buf.write(&digits100[datetime.day() * 2], 2); + buf.write(between_date_time_delimiter); + buf.write(&digits100[datetime.hour() * 2], 2); + buf.write(time_delimeter); + buf.write(&digits100[datetime.minute() * 2], 2); + buf.write(time_delimeter); + buf.write(&digits100[datetime.second() * 2], 2); + } +} + +/// In the format YYYY-MM-DD HH:MM:SS, according to the specified time zone. +template <char date_delimeter = '-', char time_delimeter = ':', char between_date_time_delimiter = ' '> +inline void writeDateTimeText(time_t datetime, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +{ + writeDateTimeText<date_delimeter, time_delimeter, between_date_time_delimiter>(LocalDateTime(datetime, time_zone), buf); +} + +/// In the format YYYY-MM-DD HH:MM:SS.NNNNNNNNN, according to the specified time zone. +template <char date_delimeter = '-', char time_delimeter = ':', char between_date_time_delimiter = ' ', char fractional_time_delimiter = '.'> +inline void writeDateTimeText(DateTime64 datetime64, UInt32 scale, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +{ + static constexpr UInt32 MaxScale = DecimalUtils::max_precision<DateTime64>; + scale = scale > MaxScale ? MaxScale : scale; + + auto components = DecimalUtils::split(datetime64, scale); + /// Case1: + /// -127914467.877 + /// => whole = -127914467, fraction = 877(After DecimalUtils::split) + /// => new whole = -127914468(1965-12-12 12:12:12), new fraction = 1000 - 877 = 123(.123) + /// => 1965-12-12 12:12:12.123 + /// + /// Case2: + /// -0.877 + /// => whole = 0, fractional = -877(After DecimalUtils::split) + /// => whole = -1(1969-12-31 23:59:59), fractional = 1000 + (-877) = 123(.123) + using T = typename DateTime64::NativeType; + if (datetime64.value < 0 && components.fractional) + { + components.fractional = DecimalUtils::scaleMultiplier<T>(scale) + (components.whole ? T(-1) : T(1)) * components.fractional; + --components.whole; + } + + writeDateTimeText<date_delimeter, time_delimeter, between_date_time_delimiter>(LocalDateTime(components.whole, time_zone), buf); + + if (scale > 0) + { + buf.write(fractional_time_delimiter); + writeDateTime64FractionalText<DateTime64>(components.fractional, scale, buf); + } +} + +/// In the RFC 1123 format: "Tue, 03 Dec 2019 00:11:50 GMT". You must provide GMT DateLUT. +/// This is needed for HTTP requests. 
+inline void writeDateTimeTextRFC1123(time_t datetime, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +{ + const auto & values = time_zone.getValues(datetime); + + static const char week_days[3 * 8 + 1] = "XXX" "Mon" "Tue" "Wed" "Thu" "Fri" "Sat" "Sun"; + static const char months[3 * 13 + 1] = "XXX" "Jan" "Feb" "Mar" "Apr" "May" "Jun" "Jul" "Aug" "Sep" "Oct" "Nov" "Dec"; + + buf.write(&week_days[values.day_of_week * 3], 3); + buf.write(", ", 2); + buf.write(&digits100[values.day_of_month * 2], 2); + buf.write(' '); + buf.write(&months[values.month * 3], 3); + buf.write(' '); + buf.write(&digits100[values.year / 100 * 2], 2); + buf.write(&digits100[values.year % 100 * 2], 2); + buf.write(' '); + buf.write(&digits100[time_zone.toHour(datetime) * 2], 2); + buf.write(':'); + buf.write(&digits100[time_zone.toMinute(datetime) * 2], 2); + buf.write(':'); + buf.write(&digits100[time_zone.toSecond(datetime) * 2], 2); + buf.write(" GMT", 4); +} + +inline void writeDateTimeTextISO(time_t datetime, WriteBuffer & buf, const DateLUTImpl & utc_time_zone) +{ + writeDateTimeText<'-', ':', 'T'>(datetime, buf, utc_time_zone); + buf.write('Z'); +} + +inline void writeDateTimeTextISO(DateTime64 datetime64, UInt32 scale, WriteBuffer & buf, const DateLUTImpl & utc_time_zone) +{ + writeDateTimeText<'-', ':', 'T'>(datetime64, scale, buf, utc_time_zone); + buf.write('Z'); +} + +inline void writeDateTimeUnixTimestamp(DateTime64 datetime64, UInt32 scale, WriteBuffer & buf) +{ + static constexpr UInt32 MaxScale = DecimalUtils::max_precision<DateTime64>; + scale = scale > MaxScale ? MaxScale : scale; + + auto components = DecimalUtils::split(datetime64, scale); + writeIntText(components.whole, buf); + + if (scale > 0) + { + buf.write('.'); + writeDateTime64FractionalText<DateTime64>(components.fractional, scale, buf); + } +} + +/// Methods for output in binary format. +template <typename T> +requires is_arithmetic_v<T> +inline void writeBinary(const T & x, WriteBuffer & buf) { writePODBinary(x, buf); } + +inline void writeBinary(const String & x, WriteBuffer & buf) { writeStringBinary(x, buf); } +inline void writeBinary(StringRef x, WriteBuffer & buf) { writeStringBinary(x, buf); } +inline void writeBinary(std::string_view x, WriteBuffer & buf) { writeStringBinary(x, buf); } +inline void writeBinary(const Decimal32 & x, WriteBuffer & buf) { writePODBinary(x, buf); } +inline void writeBinary(const Decimal64 & x, WriteBuffer & buf) { writePODBinary(x, buf); } +inline void writeBinary(const Decimal128 & x, WriteBuffer & buf) { writePODBinary(x, buf); } +inline void writeBinary(const Decimal256 & x, WriteBuffer & buf) { writePODBinary(x.value, buf); } +inline void writeBinary(const LocalDate & x, WriteBuffer & buf) { writePODBinary(x, buf); } +inline void writeBinary(const LocalDateTime & x, WriteBuffer & buf) { writePODBinary(x, buf); } +inline void writeBinary(const IPv4 & x, WriteBuffer & buf) { writePODBinary(x, buf); } +inline void writeBinary(const IPv6 & x, WriteBuffer & buf) { writePODBinary(x, buf); } + +inline void writeBinary(const UUID & x, WriteBuffer & buf) +{ + writeUUIDBinary(x, buf); +} + +inline void writeBinary(const CityHash_v1_0_2::uint128 & x, WriteBuffer & buf) +{ + writePODBinary(x.low64, buf); + writePODBinary(x.high64, buf); +} + +inline void writeBinary(const StackTrace::FramePointers & x, WriteBuffer & buf) { writePODBinary(x, buf); } + +/// Methods for outputting the value in text form for a tab-separated format. 
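Before the tab-separated writeText overloads that follow, a hedged sketch of how the date/time writers above are typically driven. Hypothetical usage: it assumes WriteBufferFromOwnString from <IO/WriteBufferFromString.h> and the DateLUT::instance(name) accessor for picking a time zone.

#include <IO/WriteBufferFromString.h>
#include <IO/WriteHelpers.h>
#include <string>

std::string format_epoch(time_t t)
{
    const DateLUTImpl & utc = DateLUT::instance("UTC");
    DB::WriteBufferFromOwnString buf;
    DB::writeDateTimeText(t, buf, utc);        /// "YYYY-MM-DD HH:MM:SS"
    DB::writeChar(' ', buf);
    DB::writeDateTimeTextISO(t, buf, utc);     /// "YYYY-MM-DDTHH:MM:SSZ"
    return buf.str();
}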
+ +inline void writeText(is_integer auto x, WriteBuffer & buf) +{ + if constexpr (std::is_same_v<decltype(x), bool>) + writeBoolText(x, buf); + else if constexpr (std::is_same_v<decltype(x), char>) + writeChar(x, buf); + else + writeIntText(x, buf); +} + +inline void writeText(is_floating_point auto x, WriteBuffer & buf) { writeFloatText(x, buf); } + +inline void writeText(is_enum auto x, WriteBuffer & buf) { writeText(magic_enum::enum_name(x), buf); } + +inline void writeText(std::string_view x, WriteBuffer & buf) { writeString(x.data(), x.size(), buf); } + +inline void writeText(const DayNum & x, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) { writeDateText(LocalDate(x, time_zone), buf); } +inline void writeText(const LocalDate & x, WriteBuffer & buf) { writeDateText(x, buf); } +inline void writeText(const LocalDateTime & x, WriteBuffer & buf) { writeDateTimeText(x, buf); } +inline void writeText(const UUID & x, WriteBuffer & buf) { writeUUIDText(x, buf); } +inline void writeText(const IPv4 & x, WriteBuffer & buf) { writeIPv4Text(x, buf); } +inline void writeText(const IPv6 & x, WriteBuffer & buf) { writeIPv6Text(x, buf); } + +template <typename T> +void writeDecimalFractional(const T & x, UInt32 scale, WriteBuffer & ostr, bool trailing_zeros, + bool fixed_fractional_length, UInt32 fractional_length) +{ + /// If it's big integer, but the number of digits is small, + /// use the implementation for smaller integers for more efficient arithmetic. + if constexpr (std::is_same_v<T, Int256>) + { + if (x <= std::numeric_limits<UInt32>::max()) + { + writeDecimalFractional(static_cast<UInt32>(x), scale, ostr, trailing_zeros, fixed_fractional_length, fractional_length); + return; + } + else if (x <= std::numeric_limits<UInt64>::max()) + { + writeDecimalFractional(static_cast<UInt64>(x), scale, ostr, trailing_zeros, fixed_fractional_length, fractional_length); + return; + } + else if (x <= std::numeric_limits<UInt128>::max()) + { + writeDecimalFractional(static_cast<UInt128>(x), scale, ostr, trailing_zeros, fixed_fractional_length, fractional_length); + return; + } + } + else if constexpr (std::is_same_v<T, Int128>) + { + if (x <= std::numeric_limits<UInt32>::max()) + { + writeDecimalFractional(static_cast<UInt32>(x), scale, ostr, trailing_zeros, fixed_fractional_length, fractional_length); + return; + } + else if (x <= std::numeric_limits<UInt64>::max()) + { + writeDecimalFractional(static_cast<UInt64>(x), scale, ostr, trailing_zeros, fixed_fractional_length, fractional_length); + return; + } + } + + constexpr size_t max_digits = std::numeric_limits<UInt256>::digits10; + assert(scale <= max_digits); + assert(fractional_length <= max_digits); + + char buf[max_digits]; + memset(buf, '0', std::max(scale, fractional_length)); + + T value = x; + Int32 last_nonzero_pos = 0; + + if (fixed_fractional_length && fractional_length < scale) + { + T new_value = value / DecimalUtils::scaleMultiplier<Int256>(scale - fractional_length - 1); + auto round_carry = new_value % 10; + value = new_value / 10; + if (round_carry >= 5) + value += 1; + } + + for (Int32 pos = fixed_fractional_length ? std::min(scale - 1, fractional_length - 1) : scale - 1; pos >= 0; --pos) + { + auto remainder = value % 10; + value /= 10; + + if (remainder != 0 && last_nonzero_pos == 0) + last_nonzero_pos = pos; + + buf[pos] += static_cast<char>(remainder); + } + + writeChar('.', ostr); + ostr.write(buf, fixed_fractional_length ? fractional_length : (trailing_zeros ? 
scale : last_nonzero_pos + 1)); +} + +template <typename T> +void writeText(Decimal<T> x, UInt32 scale, WriteBuffer & ostr, bool trailing_zeros, + bool fixed_fractional_length = false, UInt32 fractional_length = 0) +{ + T part = DecimalUtils::getWholePart(x, scale); + + if (x.value < 0 && part == 0) + { + writeChar('-', ostr); /// avoid crop leading minus when whole part is zero + } + + writeIntText(part, ostr); + + if (scale || (fixed_fractional_length && fractional_length > 0)) + { + part = DecimalUtils::getFractionalPart(x, scale); + if (part || trailing_zeros) + { + if (part < 0) + part *= T(-1); + + writeDecimalFractional(part, scale, ostr, trailing_zeros, fixed_fractional_length, fractional_length); + } + } +} + +/// String, date, datetime are in single quotes with C-style escaping. Numbers - without. +template <typename T> +requires is_arithmetic_v<T> +inline void writeQuoted(const T & x, WriteBuffer & buf) { writeText(x, buf); } + +inline void writeQuoted(const String & x, WriteBuffer & buf) { writeQuotedString(x, buf); } + +inline void writeQuoted(std::string_view x, WriteBuffer & buf) { writeQuotedString(x, buf); } + +inline void writeQuoted(StringRef x, WriteBuffer & buf) { writeQuotedString(x, buf); } + +inline void writeQuoted(const LocalDate & x, WriteBuffer & buf) +{ + writeChar('\'', buf); + writeDateText(x, buf); + writeChar('\'', buf); +} + +inline void writeQuoted(const LocalDateTime & x, WriteBuffer & buf) +{ + writeChar('\'', buf); + writeDateTimeText(x, buf); + writeChar('\'', buf); +} + +inline void writeQuoted(const UUID & x, WriteBuffer & buf) +{ + writeChar('\'', buf); + writeText(x, buf); + writeChar('\'', buf); +} + +inline void writeQuoted(const IPv4 & x, WriteBuffer & buf) +{ + writeChar('\'', buf); + writeText(x, buf); + writeChar('\'', buf); +} + +inline void writeQuoted(const IPv6 & x, WriteBuffer & buf) +{ + writeChar('\'', buf); + writeText(x, buf); + writeChar('\'', buf); +} + +/// String, date, datetime are in double quotes with C-style escaping. Numbers - without. +template <typename T> +requires is_arithmetic_v<T> +inline void writeDoubleQuoted(const T & x, WriteBuffer & buf) { writeText(x, buf); } + +inline void writeDoubleQuoted(const String & x, WriteBuffer & buf) { writeDoubleQuotedString(x, buf); } + +inline void writeDoubleQuoted(std::string_view x, WriteBuffer & buf) { writeDoubleQuotedString(x, buf); } + +inline void writeDoubleQuoted(StringRef x, WriteBuffer & buf) { writeDoubleQuotedString(x, buf); } + +inline void writeDoubleQuoted(const LocalDate & x, WriteBuffer & buf) +{ + writeChar('"', buf); + writeDateText(x, buf); + writeChar('"', buf); +} + +inline void writeDoubleQuoted(const LocalDateTime & x, WriteBuffer & buf) +{ + writeChar('"', buf); + writeDateTimeText(x, buf); + writeChar('"', buf); +} + +inline void writeDoubleQuoted(const UUID & x, WriteBuffer & buf) +{ + writeChar('"', buf); + writeText(x, buf); + writeChar('"', buf); +} + +inline void writeDoubleQuoted(const IPv4 & x, WriteBuffer & buf) +{ + writeChar('"', buf); + writeText(x, buf); + writeChar('"', buf); +} + +inline void writeDoubleQuoted(const IPv6 & x, WriteBuffer & buf) +{ + writeChar('"', buf); + writeText(x, buf); + writeChar('"', buf); +} + +/// String - in double quotes and with CSV-escaping; date, datetime - in double quotes. Numbers - without. 
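Before the CSV writers that the comment above introduces, a short hedged sketch of the quoting families defined earlier. Hypothetical usage: WriteBufferFromOwnString is assumed as in the previous snippets, and LocalDate's (year, month, day) constructor is assumed from base/LocalDate.h.

#include <IO/WriteBufferFromString.h>
#include <IO/WriteHelpers.h>
#include <string_view>

void quoting_example()
{
    DB::WriteBufferFromOwnString buf;
    DB::writeQuoted(LocalDate(2023, 11, 14), buf);          /// '2023-11-14' -- single quotes
    DB::writeChar(' ', buf);
    DB::writeDoubleQuoted(std::string_view("hello"), buf);  /// "hello" -- double quotes
    DB::writeChar(' ', buf);
    DB::writeQuoted(UInt64(42), buf);                       /// 42 -- arithmetic types stay unquoted
}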
+template <typename T> +requires is_arithmetic_v<T> +inline void writeCSV(const T & x, WriteBuffer & buf) { writeText(x, buf); } + +inline void writeCSV(const String & x, WriteBuffer & buf) { writeCSVString<>(x, buf); } +inline void writeCSV(const LocalDate & x, WriteBuffer & buf) { writeDoubleQuoted(x, buf); } +inline void writeCSV(const LocalDateTime & x, WriteBuffer & buf) { writeDoubleQuoted(x, buf); } +inline void writeCSV(const UUID & x, WriteBuffer & buf) { writeDoubleQuoted(x, buf); } +inline void writeCSV(const IPv4 & x, WriteBuffer & buf) { writeDoubleQuoted(x, buf); } +inline void writeCSV(const IPv6 & x, WriteBuffer & buf) { writeDoubleQuoted(x, buf); } + +template <typename T> +void writeBinary(const std::vector<T> & x, WriteBuffer & buf) +{ + size_t size = x.size(); + writeVarUInt(size, buf); + for (size_t i = 0; i < size; ++i) + writeBinary(x[i], buf); +} + +template <typename T> +void writeQuoted(const std::vector<T> & x, WriteBuffer & buf) +{ + writeChar('[', buf); + for (size_t i = 0, size = x.size(); i < size; ++i) + { + if (i != 0) + writeChar(',', buf); + writeQuoted(x[i], buf); + } + writeChar(']', buf); +} + +template <typename T> +void writeDoubleQuoted(const std::vector<T> & x, WriteBuffer & buf) +{ + writeChar('[', buf); + for (size_t i = 0, size = x.size(); i < size; ++i) + { + if (i != 0) + writeChar(',', buf); + writeDoubleQuoted(x[i], buf); + } + writeChar(']', buf); +} + +template <typename T> +void writeText(const std::vector<T> & x, WriteBuffer & buf) +{ + writeQuoted(x, buf); +} + + +/// Serialize exception (so that it can be transferred over the network) +void writeException(const Exception & e, WriteBuffer & buf, bool with_stack_trace); + + +/// An easy-to-use method for converting something to a string in text form. 
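The std::vector overloads above compose with the scalar writers, and toString, declared right below this note, is the usual convenience entry point. A hedged sketch under the same hypothetical setup as the earlier snippets:

#include <IO/WriteBufferFromString.h>
#include <IO/WriteHelpers.h>
#include <vector>

void vector_example()
{
    std::vector<UInt32> xs{1, 2, 3};
    DB::WriteBufferFromOwnString buf;
    DB::writeText(xs, buf);   /// delegates to writeQuoted(vector) and produces "[1,2,3]"
    /// Equivalent one-liner via the helper declared below: auto s = DB::toString(xs);
}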
+template <typename T> +inline String toString(const T & x) +{ + WriteBufferFromOwnString buf; + writeText(x, buf); + return buf.str(); +} + +inline String toString(const CityHash_v1_0_2::uint128 & hash) +{ + WriteBufferFromOwnString buf; + writeText(hash.low64, buf); + writeChar('_', buf); + writeText(hash.high64, buf); + return buf.str(); +} + +template <typename T> +inline String toStringWithFinalSeparator(const std::vector<T> & x, const String & final_sep) +{ + WriteBufferFromOwnString buf; + for (auto it = x.begin(); it != x.end(); ++it) + { + if (it != x.begin()) + { + if (std::next(it) == x.end()) + writeString(final_sep, buf); + else + writeString(", ", buf); + } + writeQuoted(*it, buf); + } + + return buf.str(); +} + +inline void writeNullTerminatedString(const String & s, WriteBuffer & buffer) +{ + /// c_str is guaranteed to return zero-terminated string + buffer.write(s.c_str(), s.size() + 1); +} + +template <std::endian endian, typename T> +inline void writeBinaryEndian(T x, WriteBuffer & buf) +{ + transformEndianness<endian>(x); + writeBinary(x, buf); +} + +template <typename T> +inline void writeBinaryLittleEndian(T x, WriteBuffer & buf) +{ + writeBinaryEndian<std::endian::little>(x, buf); +} + +template <typename T> +inline void writeBinaryBigEndian(T x, WriteBuffer & buf) +{ + writeBinaryEndian<std::endian::big>(x, buf); +} + + +struct PcgSerializer +{ + static void serializePcg32(const pcg32_fast & rng, WriteBuffer & buf) + { + writeText(rng.multiplier(), buf); + writeChar(' ', buf); + writeText(rng.increment(), buf); + writeChar(' ', buf); + writeText(rng.state_, buf); + } +}; + +void writePointerHex(const void * ptr, WriteBuffer & buf); + +} + +template<> +struct fmt::formatter<DB::UUID> +{ + template<typename ParseContext> + constexpr auto parse(ParseContext & context) + { + return context.begin(); + } + + template<typename FormatContext> + auto format(const DB::UUID & uuid, FormatContext & context) + { + return fmt::format_to(context.out(), "{}", toString(uuid)); + } +}; diff --git a/contrib/clickhouse/src/IO/WriteIntText.h b/contrib/clickhouse/src/IO/WriteIntText.h new file mode 100644 index 0000000000..c9a4cb0241 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteIntText.h @@ -0,0 +1,46 @@ +#pragma once + +#include <Core/Defines.h> +#include <IO/WriteBuffer.h> +#include <base/itoa.h> + + +template <typename T> constexpr size_t max_int_width = 20; +template <> inline constexpr size_t max_int_width<UInt8> = 3; /// 255 +template <> inline constexpr size_t max_int_width<Int8> = 4; /// -128 +template <> inline constexpr size_t max_int_width<UInt16> = 5; /// 65535 +template <> inline constexpr size_t max_int_width<Int16> = 6; /// -32768 +template <> inline constexpr size_t max_int_width<UInt32> = 10; /// 4294967295 +template <> inline constexpr size_t max_int_width<Int32> = 11; /// -2147483648 +template <> inline constexpr size_t max_int_width<UInt64> = 20; /// 18446744073709551615 +template <> inline constexpr size_t max_int_width<Int64> = 20; /// -9223372036854775808 +template <> inline constexpr size_t max_int_width<UInt128> = 39; /// 340282366920938463463374607431768211455 +template <> inline constexpr size_t max_int_width<Int128> = 40; /// -170141183460469231731687303715884105728 +template <> inline constexpr size_t max_int_width<UInt256> = 78; /// 115792089237316195423570985008687907853269984665640564039457584007913129639935 +template <> inline constexpr size_t max_int_width<Int256> = 78; /// 
-57896044618658097711785492504343953926634992332820282019728792003956564819968 + + +namespace DB +{ + +namespace detail +{ + template <typename T> + void NO_INLINE writeUIntTextFallback(T x, WriteBuffer & buf) + { + char tmp[max_int_width<T>]; + char * end = itoa(x, tmp); + buf.write(tmp, end - tmp); + } +} + +template <typename T> +void writeIntText(T x, WriteBuffer & buf) +{ + if (likely(reinterpret_cast<uintptr_t>(buf.position()) + max_int_width<T> < reinterpret_cast<uintptr_t>(buf.buffer().end()))) + buf.position() = itoa(x, buf.position()); + else + detail::writeUIntTextFallback(x, buf); +} + +} diff --git a/contrib/clickhouse/src/IO/WriteSettings.h b/contrib/clickhouse/src/IO/WriteSettings.h new file mode 100644 index 0000000000..8f22e44145 --- /dev/null +++ b/contrib/clickhouse/src/IO/WriteSettings.h @@ -0,0 +1,32 @@ +#pragma once + +#include <Common/Throttler_fwd.h> +#include <IO/ResourceLink.h> + +namespace DB +{ + +/// Settings to be passed to IDisk::writeFile() +struct WriteSettings +{ + /// Bandwidth throttler to use during writing + ThrottlerPtr remote_throttler; + ThrottlerPtr local_throttler; + + // Resource to be used during reading + ResourceLink resource_link; + + /// Filesystem cache settings + bool enable_filesystem_cache_on_write_operations = false; + bool enable_filesystem_cache_log = false; + bool throw_on_error_from_cache = false; + + bool s3_allow_parallel_part_upload = true; + + /// Monitoring + bool for_object_storage = false; // to choose which profile events should be incremented + + bool operator==(const WriteSettings & other) const = default; +}; + +} diff --git a/contrib/clickhouse/src/IO/ZlibDeflatingWriteBuffer.cpp b/contrib/clickhouse/src/IO/ZlibDeflatingWriteBuffer.cpp new file mode 100644 index 0000000000..5455adcb7c --- /dev/null +++ b/contrib/clickhouse/src/IO/ZlibDeflatingWriteBuffer.cpp @@ -0,0 +1,131 @@ +#include <IO/ZlibDeflatingWriteBuffer.h> +#include <Common/Exception.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ZLIB_DEFLATE_FAILED; +} + + +ZlibDeflatingWriteBuffer::ZlibDeflatingWriteBuffer( + std::unique_ptr<WriteBuffer> out_, + CompressionMethod compression_method, + int compression_level, + size_t buf_size, + char * existing_memory, + size_t alignment) + : WriteBufferWithOwnMemoryDecorator(std::move(out_), buf_size, existing_memory, alignment) +{ + zstr.zalloc = nullptr; + zstr.zfree = nullptr; + zstr.opaque = nullptr; + zstr.next_in = nullptr; + zstr.avail_in = 0; + zstr.next_out = nullptr; + zstr.avail_out = 0; + + int window_bits = 15; + if (compression_method == CompressionMethod::Gzip) + { + window_bits += 16; + } + + int rc = deflateInit2(&zstr, compression_level, Z_DEFLATED, window_bits, 8, Z_DEFAULT_STRATEGY); + + if (rc != Z_OK) + throw Exception(ErrorCodes::ZLIB_DEFLATE_FAILED, "deflateInit2 failed: {}; zlib version: {}", zError(rc), ZLIB_VERSION); +} + +void ZlibDeflatingWriteBuffer::nextImpl() +{ + if (!offset()) + return; + + zstr.next_in = reinterpret_cast<unsigned char *>(working_buffer.begin()); + zstr.avail_in = static_cast<unsigned>(offset()); + + try + { + do + { + out->nextIfAtEnd(); + zstr.next_out = reinterpret_cast<unsigned char *>(out->position()); + zstr.avail_out = static_cast<unsigned>(out->buffer().end() - out->position()); + + int rc = deflate(&zstr, Z_NO_FLUSH); + out->position() = out->buffer().end() - zstr.avail_out; + + if (rc != Z_OK) + throw Exception(ErrorCodes::ZLIB_DEFLATE_FAILED, "deflate failed: {}", zError(rc)); + } + while (zstr.avail_in > 0 || zstr.avail_out == 0); + } + 
catch (...) + { + /// Do not try to write next time after exception. + out->position() = out->buffer().begin(); + throw; + } +} + +ZlibDeflatingWriteBuffer::~ZlibDeflatingWriteBuffer() = default; + +void ZlibDeflatingWriteBuffer::finalizeBefore() +{ + next(); + + /// https://github.com/zlib-ng/zlib-ng/issues/494 + do + { + out->nextIfAtEnd(); + zstr.next_out = reinterpret_cast<unsigned char *>(out->position()); + zstr.avail_out = static_cast<unsigned>(out->buffer().end() - out->position()); + + int rc = deflate(&zstr, Z_FULL_FLUSH); + out->position() = out->buffer().end() - zstr.avail_out; + + if (rc != Z_OK) + throw Exception(ErrorCodes::ZLIB_DEFLATE_FAILED, "deflate failed: {}", zError(rc)); + } + while (zstr.avail_out == 0); + + while (true) + { + out->nextIfAtEnd(); + zstr.next_out = reinterpret_cast<unsigned char *>(out->position()); + zstr.avail_out = static_cast<unsigned>(out->buffer().end() - out->position()); + + int rc = deflate(&zstr, Z_FINISH); + out->position() = out->buffer().end() - zstr.avail_out; + + if (rc == Z_STREAM_END) + { + return; + } + + if (rc != Z_OK) + throw Exception(ErrorCodes::ZLIB_DEFLATE_FAILED, "deflate finalizeImpl() failed: {}", zError(rc)); + } +} + +void ZlibDeflatingWriteBuffer::finalizeAfter() +{ + try + { + int rc = deflateEnd(&zstr); + if (rc != Z_OK) + throw Exception(ErrorCodes::ZLIB_DEFLATE_FAILED, "deflateEnd failed: {}", zError(rc)); + } + catch (...) + { + /// It is OK not to terminate under an error from deflateEnd() + /// since all data already written to the stream. + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + +} diff --git a/contrib/clickhouse/src/IO/ZlibDeflatingWriteBuffer.h b/contrib/clickhouse/src/IO/ZlibDeflatingWriteBuffer.h new file mode 100644 index 0000000000..58e709b54e --- /dev/null +++ b/contrib/clickhouse/src/IO/ZlibDeflatingWriteBuffer.h @@ -0,0 +1,41 @@ +#pragma once + +#include <IO/WriteBuffer.h> +#include <IO/BufferWithOwnMemory.h> +#include <IO/CompressionMethod.h> +#include <IO/WriteBufferDecorator.h> + + +#include <zlib.h> + + +namespace DB +{ + +/// Performs compression using zlib library and writes compressed data to out_ WriteBuffer. +class ZlibDeflatingWriteBuffer : public WriteBufferWithOwnMemoryDecorator +{ +public: + ZlibDeflatingWriteBuffer( + std::unique_ptr<WriteBuffer> out_, + CompressionMethod compression_method, + int compression_level, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~ZlibDeflatingWriteBuffer() override; + +private: + void nextImpl() override; + + /// Flush all pending data and write zlib footer to the underlying buffer. + /// After the first call to this function, subsequent calls will have no effect and + /// an attempt to write to this buffer will result in exception. 
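A hedged sketch of how this deflating buffer is normally composed with a file buffer. Hypothetical usage: the output path is made up, WriteBufferFromFile is assumed from the same src/IO directory, and finalize() is called explicitly, as the comments around finalizeBefore just below require.

#include <IO/WriteBufferFromFile.h>
#include <IO/WriteHelpers.h>
#include <IO/ZlibDeflatingWriteBuffer.h>
#include <memory>

void gzip_example()
{
    auto file = std::make_unique<DB::WriteBufferFromFile>("/tmp/out.gz");
    DB::ZlibDeflatingWriteBuffer gz(std::move(file), DB::CompressionMethod::Gzip, /*compression_level*/ 3);
    DB::writeString("hello, zlib\n", gz);
    gz.finalize();   /// flushes pending data and writes the gzip footer via finalizeBefore()
}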
+ virtual void finalizeBefore() override; + virtual void finalizeAfter() override; + + z_stream zstr; +}; + +} diff --git a/contrib/clickhouse/src/IO/ZlibInflatingReadBuffer.cpp b/contrib/clickhouse/src/IO/ZlibInflatingReadBuffer.cpp new file mode 100644 index 0000000000..b43dda1bfc --- /dev/null +++ b/contrib/clickhouse/src/IO/ZlibInflatingReadBuffer.cpp @@ -0,0 +1,125 @@ +#include <IO/ZlibInflatingReadBuffer.h> +#include <IO/WithFileName.h> + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ZLIB_INFLATE_FAILED; + extern const int ARGUMENT_OUT_OF_BOUND; +} + +ZlibInflatingReadBuffer::ZlibInflatingReadBuffer( + std::unique_ptr<ReadBuffer> in_, + CompressionMethod compression_method, + size_t buf_size, + char * existing_memory, + size_t alignment) + : CompressedReadBufferWrapper(std::move(in_), buf_size, existing_memory, alignment) + , eof_flag(false) +{ + if (buf_size > max_buffer_size) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, + "Zlib does not support decompression with buffer size greater than {}, got buffer size: {}", + max_buffer_size, buf_size); + + zstr.zalloc = nullptr; + zstr.zfree = nullptr; + zstr.opaque = nullptr; + zstr.next_in = nullptr; + zstr.avail_in = 0; + zstr.next_out = nullptr; + zstr.avail_out = 0; + + int window_bits = 15; + if (compression_method == CompressionMethod::Gzip) + { + window_bits += 16; + } + + int rc = inflateInit2(&zstr, window_bits); + + if (rc != Z_OK) + throw Exception(ErrorCodes::ZLIB_INFLATE_FAILED, "inflateInit2 failed: {}; zlib version: {}.", zError(rc), ZLIB_VERSION); +} + +ZlibInflatingReadBuffer::~ZlibInflatingReadBuffer() +{ + inflateEnd(&zstr); +} + +bool ZlibInflatingReadBuffer::nextImpl() +{ + /// Need do-while loop to prevent situation, when + /// eof was not reached, but working buffer became empty (when nothing was decompressed in current iteration) + /// (this happens with compression algorithms, same idea is implemented in ZstdInflatingReadBuffer) + do + { + /// if we already found eof, we shouldn't do anything + if (eof_flag) + return false; + + /// if there is no available bytes in zstr, move ptr to next available data + if (!zstr.avail_in) + { + in->nextIfAtEnd(); + zstr.next_in = reinterpret_cast<unsigned char *>(in->position()); + zstr.avail_in = static_cast<BufferSizeType>(std::min( + static_cast<UInt64>(in->buffer().end() - in->position()), + static_cast<UInt64>(max_buffer_size))); + } + + /// init output bytes (place, where decompressed data will be) + zstr.next_out = reinterpret_cast<unsigned char *>(internal_buffer.begin()); + zstr.avail_out = static_cast<BufferSizeType>(internal_buffer.size()); + + size_t old_total_in = zstr.total_in; + int rc = inflate(&zstr, Z_NO_FLUSH); + + /// move in stream on place, where reading stopped + size_t bytes_read = zstr.total_in - old_total_in; + in->position() += bytes_read; + + /// change size of working buffer (it's size equal to internal_buffer size without unused uncompressed values) + working_buffer.resize(internal_buffer.size() - zstr.avail_out); + + /// If end was reached, it can be end of file or end of part (for example, chunk) + if (rc == Z_STREAM_END) + { + /// if it is end of file, remember this and return + /// * true if we can work with working buffer (we still have something to read, so next must return true) + /// * false if there is no data in working buffer + if (in->eof()) + { + eof_flag = true; + return !working_buffer.empty(); + } + /// If it is not end of file, we need to reset zstr and return true, because we still have some data to read + else 
+ { + rc = inflateReset(&zstr); + if (rc != Z_OK) + throw Exception( + ErrorCodes::ZLIB_INFLATE_FAILED, + "inflateReset failed: {}{}", + zError(rc), + getExceptionEntryWithFileName(*in)); + return true; + } + } + + /// If it is not end and not OK, something went wrong, throw exception + if (rc != Z_OK) + throw Exception( + ErrorCodes::ZLIB_INFLATE_FAILED, + "inflate failed: {}{}", + zError(rc), + getExceptionEntryWithFileName(*in)); + } + while (working_buffer.empty()); + + /// if code reach this section, working buffer is not empty, so we have some data to process + return true; +} + +} diff --git a/contrib/clickhouse/src/IO/ZlibInflatingReadBuffer.h b/contrib/clickhouse/src/IO/ZlibInflatingReadBuffer.h new file mode 100644 index 0000000000..d9ca4c6126 --- /dev/null +++ b/contrib/clickhouse/src/IO/ZlibInflatingReadBuffer.h @@ -0,0 +1,44 @@ +#pragma once + +#include <IO/ReadBuffer.h> +#include <IO/CompressedReadBufferWrapper.h> +#include <IO/CompressionMethod.h> + +#include <limits> +#include <zlib.h> + + +namespace DB +{ + +namespace ErrorCodes +{ +} + +/// Reads compressed data from ReadBuffer in_ and performs decompression using zlib library. +/// This buffer is able to seamlessly decompress multiple concatenated zlib streams. +class ZlibInflatingReadBuffer : public CompressedReadBufferWrapper +{ +public: + ZlibInflatingReadBuffer( + std::unique_ptr<ReadBuffer> in_, + CompressionMethod compression_method, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~ZlibInflatingReadBuffer() override; + +private: + bool nextImpl() override; + + z_stream zstr; + bool eof_flag; + + /// Limit size of buffer because zlib uses + /// UInt32 for sizes of internal buffers. + using BufferSizeType = decltype(zstr.avail_in); + static constexpr auto max_buffer_size = std::numeric_limits<BufferSizeType>::max(); +}; + +} diff --git a/contrib/clickhouse/src/IO/ZstdDeflatingAppendableWriteBuffer.cpp b/contrib/clickhouse/src/IO/ZstdDeflatingAppendableWriteBuffer.cpp new file mode 100644 index 0000000000..81be8d8ce4 --- /dev/null +++ b/contrib/clickhouse/src/IO/ZstdDeflatingAppendableWriteBuffer.cpp @@ -0,0 +1,223 @@ +#include <IO/ZstdDeflatingAppendableWriteBuffer.h> +#include <Common/Exception.h> +#include "IO/ReadBufferFromFileBase.h" +#include <IO/ReadBufferFromFile.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ZSTD_ENCODER_FAILED; +} + +ZstdDeflatingAppendableWriteBuffer::ZstdDeflatingAppendableWriteBuffer( + std::unique_ptr<WriteBufferFromFileBase> out_, + int compression_level, + bool append_to_existing_file_, + std::function<std::unique_ptr<ReadBufferFromFileBase>()> read_buffer_creator_, + size_t buf_size, + char * existing_memory, + size_t alignment) + : BufferWithOwnMemory(buf_size, existing_memory, alignment) + , out(std::move(out_)) + , read_buffer_creator(std::move(read_buffer_creator_)) + , append_to_existing_file(append_to_existing_file_) +{ + cctx = ZSTD_createCCtx(); + if (cctx == nullptr) + throw Exception(ErrorCodes::ZSTD_ENCODER_FAILED, "ZSTD stream encoder init failed: ZSTD version: {}", ZSTD_VERSION_STRING); + size_t ret = ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, compression_level); + if (ZSTD_isError(ret)) + throw Exception(ErrorCodes::ZSTD_ENCODER_FAILED, + "ZSTD stream encoder option setting failed: error code: {}; zstd version: {}", + ret, ZSTD_VERSION_STRING); + + input = {nullptr, 0, 0}; + output = {nullptr, 0, 0}; +} + +void ZstdDeflatingAppendableWriteBuffer::nextImpl() +{ + if (!offset()) 
+ return; + + input.src = reinterpret_cast<unsigned char *>(working_buffer.begin()); + input.size = offset(); + input.pos = 0; + + if (first_write && append_to_existing_file && isNeedToAddEmptyBlock()) + { + addEmptyBlock(); + first_write = false; + } + + try + { + bool ended = false; + do + { + out->nextIfAtEnd(); + + output.dst = reinterpret_cast<unsigned char *>(out->buffer().begin()); + output.size = out->buffer().size(); + output.pos = out->offset(); + + size_t compression_result = ZSTD_compressStream2(cctx, &output, &input, ZSTD_e_flush); + if (ZSTD_isError(compression_result)) + throw Exception( + ErrorCodes::ZSTD_ENCODER_FAILED, + "ZSTD stream decoding failed: error code: {}; ZSTD version: {}", + ZSTD_getErrorName(compression_result), ZSTD_VERSION_STRING); + + first_write = false; + out->position() = out->buffer().begin() + output.pos; + + bool everything_was_compressed = (input.pos == input.size); + bool everything_was_flushed = compression_result == 0; + + ended = everything_was_compressed && everything_was_flushed; + } while (!ended); + } + catch (...) + { + /// Do not try to write next time after exception. + out->position() = out->buffer().begin(); + throw; + } + +} + +ZstdDeflatingAppendableWriteBuffer::~ZstdDeflatingAppendableWriteBuffer() +{ + finalize(); +} + +void ZstdDeflatingAppendableWriteBuffer::finalizeImpl() +{ + if (first_write) + { + /// To free cctx + finalizeZstd(); + /// Nothing was written + } + else + { + try + { + finalizeBefore(); + out->finalize(); + finalizeAfter(); + } + catch (...) + { + /// Do not try to flush next time after exception. + out->position() = out->buffer().begin(); + throw; + } + } +} + +void ZstdDeflatingAppendableWriteBuffer::finalizeBefore() +{ + next(); + + out->nextIfAtEnd(); + + input.src = reinterpret_cast<unsigned char *>(working_buffer.begin()); + input.size = offset(); + input.pos = 0; + + output.dst = reinterpret_cast<unsigned char *>(out->buffer().begin()); + output.size = out->buffer().size(); + output.pos = out->offset(); + + /// Actually we can use ZSTD_e_flush here and add empty termination + /// block on each new buffer creation for non-empty file unconditionally (without isNeedToAddEmptyBlock). + /// However ZSTD_decompressStream is able to read non-terminated frame (we use it in reader buffer), + /// but console zstd utility cannot. + size_t remaining = ZSTD_compressStream2(cctx, &output, &input, ZSTD_e_end); + while (remaining != 0) + { + if (ZSTD_isError(remaining)) + throw Exception(ErrorCodes::ZSTD_ENCODER_FAILED, + "ZSTD stream encoder end failed: error: '{}' ZSTD version: {}", + ZSTD_getErrorName(remaining), ZSTD_VERSION_STRING); + + remaining = ZSTD_compressStream2(cctx, &output, &input, ZSTD_e_end); + + out->position() = out->buffer().begin() + output.pos; + + if (!out->hasPendingData()) + { + out->next(); + output.dst = reinterpret_cast<unsigned char *>(out->buffer().begin()); + output.size = out->buffer().size(); + output.pos = out->offset(); + } + } +} + +void ZstdDeflatingAppendableWriteBuffer::finalizeAfter() +{ + finalizeZstd(); +} + +void ZstdDeflatingAppendableWriteBuffer::finalizeZstd() +{ + try + { + size_t err = ZSTD_freeCCtx(cctx); + /// This is just in case, since it is impossible to get an error by using this wrapper. + if (unlikely(err)) + throw Exception(ErrorCodes::ZSTD_ENCODER_FAILED, "ZSTD_freeCCtx failed: error: '{}'; zstd version: {}", + ZSTD_getErrorName(err), ZSTD_VERSION_STRING); + } + catch (...) 
+ { + /// It is OK not to terminate under an error from ZSTD_freeCCtx() + /// since all data already written to the stream. + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + +void ZstdDeflatingAppendableWriteBuffer::addEmptyBlock() +{ + /// HACK: https://github.com/facebook/zstd/issues/2090#issuecomment-620158967 + + if (out->buffer().size() - out->offset() < ZSTD_CORRECT_TERMINATION_LAST_BLOCK.size()) + out->next(); + + std::memcpy(out->buffer().begin() + out->offset(), + ZSTD_CORRECT_TERMINATION_LAST_BLOCK.data(), ZSTD_CORRECT_TERMINATION_LAST_BLOCK.size()); + + out->position() = out->buffer().begin() + out->offset() + ZSTD_CORRECT_TERMINATION_LAST_BLOCK.size(); +} + + +bool ZstdDeflatingAppendableWriteBuffer::isNeedToAddEmptyBlock() +{ + auto reader = read_buffer_creator(); + auto fsize = reader->getFileSize(); + if (fsize > 3) + { + std::array<char, 3> result; + reader->seek(fsize - 3, SEEK_SET); + reader->readStrict(result.data(), 3); + + /// If we don't have correct block in the end, then we need to add it manually. + /// NOTE: maybe we can have the same bytes in case of data corruption/unfinished write. + /// But in this case file still corrupted and we have to remove it. + return result != ZSTD_CORRECT_TERMINATION_LAST_BLOCK; + } + else if (fsize > 0) + { + throw Exception( + ErrorCodes::ZSTD_ENCODER_FAILED, + "Trying to write to non-empty file '{}' with tiny size {}. It can lead to data corruption", + out->getFileName(), fsize); + } + return false; +} + +} diff --git a/contrib/clickhouse/src/IO/ZstdDeflatingAppendableWriteBuffer.h b/contrib/clickhouse/src/IO/ZstdDeflatingAppendableWriteBuffer.h new file mode 100644 index 0000000000..d9c4f32d6d --- /dev/null +++ b/contrib/clickhouse/src/IO/ZstdDeflatingAppendableWriteBuffer.h @@ -0,0 +1,84 @@ +#pragma once + +#include <IO/BufferWithOwnMemory.h> +#include <IO/CompressionMethod.h> +#include <IO/WriteBuffer.h> +#include <IO/WriteBufferDecorator.h> +#include <IO/WriteBufferFromFile.h> +#include <IO/ReadBufferFromFileBase.h> + +#include <zstd.h> + + +namespace DB +{ + +/// Performs stream compression using zstd library and writes compressed data to out_ WriteBuffer. +/// Main differences from ZstdDeflatingWriteBuffer: +/// 1) Allows to continue to write to the same output even if finalize() (or destructor) was not called, for example +/// when server was killed with 9 signal. Natively zstd doesn't support such feature because +/// ZSTD_decompressStream expect to see empty block (3 bytes 0x01, 0x00, 0x00) at the end of each frame. There is not API function for it +/// so we just use HACK and add empty block manually on the first write (see addEmptyBlock). Maintainers of zstd +/// said that there is no risks of compatibility issues https://github.com/facebook/zstd/issues/2090#issuecomment-620158967. +/// 2) Doesn't support internal ZSTD check-summing, because ZSTD checksums written at the end of frame (frame epilogue). +/// +class ZstdDeflatingAppendableWriteBuffer : public BufferWithOwnMemory<WriteBuffer> +{ +public: + using ZSTDLastBlock = const std::array<char, 3>; + /// Frame end block. If we read non-empty file and see no such flag we should add it. 
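The comment above refers to the ZSTD_CORRECT_TERMINATION_LAST_BLOCK constant defined immediately below. To make the append-after-restart behaviour concrete, a hedged sketch of how a caller might reopen such a compressed log. This is hypothetical: the path is made up, ReadBufferFromFile/WriteBufferFromFile come from the same src/IO directory, and WriteBufferFromFile's (name, buf_size, flags) constructor is assumed so the file can be opened with O_APPEND.

#include <IO/ReadBufferFromFile.h>
#include <IO/WriteBufferFromFile.h>
#include <IO/WriteHelpers.h>
#include <IO/ZstdDeflatingAppendableWriteBuffer.h>
#include <fcntl.h>
#include <memory>
#include <string>

void append_example()
{
    const std::string path = "/tmp/changelog.zst";
    auto read_creator = [path] { return std::make_unique<DB::ReadBufferFromFile>(path); };

    auto out = std::make_unique<DB::WriteBufferFromFile>(path, DBMS_DEFAULT_BUFFER_SIZE, O_WRONLY | O_APPEND | O_CREAT);
    DB::ZstdDeflatingAppendableWriteBuffer buf(
        std::move(out), /*compression_level*/ 3, /*append_to_existing_file_*/ true, read_creator);

    DB::writeString("next record\n", buf);
    buf.finalize();   /// writes the frame epilogue (finalizeBefore uses ZSTD_e_end)
}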
+ static inline constexpr ZSTDLastBlock ZSTD_CORRECT_TERMINATION_LAST_BLOCK = {0x01, 0x00, 0x00}; + + ZstdDeflatingAppendableWriteBuffer( + std::unique_ptr<WriteBufferFromFileBase> out_, + int compression_level, + bool append_to_existing_file_, + std::function<std::unique_ptr<ReadBufferFromFileBase>()> read_buffer_creator_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~ZstdDeflatingAppendableWriteBuffer() override; + + void sync() override + { + next(); + out->sync(); + } + + WriteBuffer * getNestedBuffer() { return out.get(); } + +private: + /// NOTE: will fill compressed data to the out.working_buffer, but will not call out.next method until the buffer is full + void nextImpl() override; + + /// Write terminating ZSTD_e_end: empty block + frame epilogue. BTW it + /// should be almost noop, because frame epilogue contains only checksums, + /// and they are disabled for this buffer. + /// Flush all pending data and write zstd footer to the underlying buffer. + /// After the first call to this function, subsequent calls will have no effect and + /// an attempt to write to this buffer will result in exception. + void finalizeImpl() override; + void finalizeBefore(); + void finalizeAfter(); + void finalizeZstd(); + + /// Read three last bytes from non-empty compressed file and compares them with + /// ZSTD_CORRECT_TERMINATION_LAST_BLOCK. + bool isNeedToAddEmptyBlock(); + + /// Adding zstd empty block (ZSTD_CORRECT_TERMINATION_LAST_BLOCK) to out.working_buffer + void addEmptyBlock(); + + std::unique_ptr<WriteBufferFromFileBase> out; + std::function<std::unique_ptr<ReadBufferFromFileBase>()> read_buffer_creator; + + bool append_to_existing_file = false; + ZSTD_CCtx * cctx; + ZSTD_inBuffer input; + ZSTD_outBuffer output; + /// Flipped on the first nextImpl call + bool first_write = true; +}; + +} diff --git a/contrib/clickhouse/src/IO/ZstdDeflatingWriteBuffer.cpp b/contrib/clickhouse/src/IO/ZstdDeflatingWriteBuffer.cpp new file mode 100644 index 0000000000..83d8487e3e --- /dev/null +++ b/contrib/clickhouse/src/IO/ZstdDeflatingWriteBuffer.cpp @@ -0,0 +1,104 @@ +#include <IO/ZstdDeflatingWriteBuffer.h> +#include <Common/Exception.h> + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ZSTD_ENCODER_FAILED; +} + +ZstdDeflatingWriteBuffer::ZstdDeflatingWriteBuffer( + std::unique_ptr<WriteBuffer> out_, int compression_level, size_t buf_size, char * existing_memory, size_t alignment) + : WriteBufferWithOwnMemoryDecorator(std::move(out_), buf_size, existing_memory, alignment) +{ + cctx = ZSTD_createCCtx(); + if (cctx == nullptr) + throw Exception(ErrorCodes::ZSTD_ENCODER_FAILED, "zstd stream encoder init failed: zstd version: {}", ZSTD_VERSION_STRING); + size_t ret = ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, compression_level); + if (ZSTD_isError(ret)) + throw Exception(ErrorCodes::ZSTD_ENCODER_FAILED, + "zstd stream encoder option setting failed: error code: {}; zstd version: {}", + ret, ZSTD_VERSION_STRING); + ret = ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, 1); + if (ZSTD_isError(ret)) + throw Exception(ErrorCodes::ZSTD_ENCODER_FAILED, + "zstd stream encoder option setting failed: error code: {}; zstd version: {}", + ret, ZSTD_VERSION_STRING); + + input = {nullptr, 0, 0}; + output = {nullptr, 0, 0}; +} + +ZstdDeflatingWriteBuffer::~ZstdDeflatingWriteBuffer() = default; + +void ZstdDeflatingWriteBuffer::flush(ZSTD_EndDirective mode) +{ + input.src = reinterpret_cast<unsigned char *>(working_buffer.begin()); 
+ input.size = offset(); + input.pos = 0; + + try + { + bool ended = false; + do + { + out->nextIfAtEnd(); + + output.dst = reinterpret_cast<unsigned char *>(out->buffer().begin()); + output.size = out->buffer().size(); + output.pos = out->offset(); + + size_t compression_result = ZSTD_compressStream2(cctx, &output, &input, mode); + if (ZSTD_isError(compression_result)) + throw Exception( + ErrorCodes::ZSTD_ENCODER_FAILED, + "ZSTD stream encoding failed: error: '{}'; zstd version: {}", + ZSTD_getErrorName(compression_result), ZSTD_VERSION_STRING); + + out->position() = out->buffer().begin() + output.pos; + + bool everything_was_compressed = (input.pos == input.size); + bool everything_was_flushed = compression_result == 0; + + ended = everything_was_compressed && everything_was_flushed; + } while (!ended); + } + catch (...) + { + /// Do not try to write next time after exception. + out->position() = out->buffer().begin(); + throw; + } +} + +void ZstdDeflatingWriteBuffer::nextImpl() +{ + if (offset()) + flush(ZSTD_e_flush); +} + +void ZstdDeflatingWriteBuffer::finalizeBefore() +{ + flush(ZSTD_e_end); +} + +void ZstdDeflatingWriteBuffer::finalizeAfter() +{ + try + { + size_t err = ZSTD_freeCCtx(cctx); + /// This is just in case, since it is impossible to get an error by using this wrapper. + if (unlikely(err)) + throw Exception(ErrorCodes::ZSTD_ENCODER_FAILED, "ZSTD_freeCCtx failed: error: '{}'; zstd version: {}", + ZSTD_getErrorName(err), ZSTD_VERSION_STRING); + } + catch (...) + { + /// It is OK not to terminate under an error from ZSTD_freeCCtx() + /// since all data already written to the stream. + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + +} diff --git a/contrib/clickhouse/src/IO/ZstdDeflatingWriteBuffer.h b/contrib/clickhouse/src/IO/ZstdDeflatingWriteBuffer.h new file mode 100644 index 0000000000..a66d6085a7 --- /dev/null +++ b/contrib/clickhouse/src/IO/ZstdDeflatingWriteBuffer.h @@ -0,0 +1,47 @@ +#pragma once + +#include <IO/BufferWithOwnMemory.h> +#include <IO/CompressionMethod.h> +#include <IO/WriteBuffer.h> +#include <IO/WriteBufferDecorator.h> + +#include <zstd.h> + +namespace DB +{ + +/// Performs compression using zstd library and writes compressed data to out_ WriteBuffer. +class ZstdDeflatingWriteBuffer : public WriteBufferWithOwnMemoryDecorator +{ +public: + ZstdDeflatingWriteBuffer( + std::unique_ptr<WriteBuffer> out_, + int compression_level, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0); + + ~ZstdDeflatingWriteBuffer() override; + + void sync() override + { + out->sync(); + } + +private: + void nextImpl() override; + + /// Flush all pending data and write zstd footer to the underlying buffer. + /// After the first call to this function, subsequent calls will have no effect and + /// an attempt to write to this buffer will result in exception. 
+ void finalizeBefore() override; + void finalizeAfter() override; + + void flush(ZSTD_EndDirective mode); + + ZSTD_CCtx * cctx; + ZSTD_inBuffer input; + ZSTD_outBuffer output; +}; + +} diff --git a/contrib/clickhouse/src/IO/ZstdInflatingReadBuffer.cpp b/contrib/clickhouse/src/IO/ZstdInflatingReadBuffer.cpp new file mode 100644 index 0000000000..2b663ec714 --- /dev/null +++ b/contrib/clickhouse/src/IO/ZstdInflatingReadBuffer.cpp @@ -0,0 +1,95 @@ +#include <IO/ZstdInflatingReadBuffer.h> +#include <IO/WithFileName.h> +#include <zstd_errors.h> + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ZSTD_DECODER_FAILED; +} + +ZstdInflatingReadBuffer::ZstdInflatingReadBuffer(std::unique_ptr<ReadBuffer> in_, size_t buf_size, char * existing_memory, size_t alignment, int zstd_window_log_max) + : CompressedReadBufferWrapper(std::move(in_), buf_size, existing_memory, alignment) +{ + dctx = ZSTD_createDCtx(); + input = {nullptr, 0, 0}; + output = {nullptr, 0, 0}; + + if (dctx == nullptr) + { + throw Exception(ErrorCodes::ZSTD_DECODER_FAILED, "zstd_stream_decoder init failed: zstd version: {}", ZSTD_VERSION_STRING); + } + + size_t ret = ZSTD_DCtx_setParameter(dctx, ZSTD_d_windowLogMax, zstd_window_log_max); + if (ZSTD_isError(ret)) + { + throw Exception(ErrorCodes::ZSTD_DECODER_FAILED, "zstd_stream_decoder init failed: {}", ZSTD_getErrorName(ret)); + } +} + +ZstdInflatingReadBuffer::~ZstdInflatingReadBuffer() +{ + ZSTD_freeDCtx(dctx); +} + +bool ZstdInflatingReadBuffer::nextImpl() +{ + do + { + // If it is known that end of file was reached, return false + if (eof_flag) + return false; + + /// If end was reached, get next part + if (input.pos >= input.size) + { + in->nextIfAtEnd(); + input.src = reinterpret_cast<unsigned char *>(in->position()); + input.pos = 0; + input.size = in->buffer().end() - in->position(); + } + + /// fill output + output.dst = reinterpret_cast<unsigned char *>(internal_buffer.begin()); + output.size = internal_buffer.size(); + output.pos = 0; + + /// Decompress data and check errors. + size_t ret = ZSTD_decompressStream(dctx, &output, &input); + if (ZSTD_getErrorCode(ret)) + { + throw Exception( + ErrorCodes::ZSTD_DECODER_FAILED, + "ZSTD stream decoding failed: error '{}'{}; ZSTD version: {}{}", + ZSTD_getErrorName(ret), + ZSTD_error_frameParameter_windowTooLarge == ret + ? ". You can increase the maximum window size with the 'zstd_window_log_max' setting in ClickHouse. Example: 'SET zstd_window_log_max = 31'" + : "", + ZSTD_VERSION_STRING, + getExceptionEntryWithFileName(*in)); + } + + /// Check that something has changed after decompress (input or output position) + assert(in->eof() || output.pos > 0 || in->position() < in->buffer().begin() + input.pos); + + /// move position to the end of read data + in->position() = in->buffer().begin() + input.pos; + working_buffer.resize(output.pos); + + /// If end of file is reached, fill eof variable and return true if there is some data in buffer, otherwise return false + if (in->eof()) + { + eof_flag = true; + return !working_buffer.empty(); + } + /// It is possible, that input buffer is not at eof yet, but nothing was decompressed in current iteration. + /// But there are cases, when such behaviour is not allowed - i.e. if input buffer is not eof, then + /// it has to be guaranteed that working_buffer is not empty. So if it is empty, continue. 
+ } while (output.pos == 0); + + return true; +} + +} diff --git a/contrib/clickhouse/src/IO/ZstdInflatingReadBuffer.h b/contrib/clickhouse/src/IO/ZstdInflatingReadBuffer.h new file mode 100644 index 0000000000..faa6231d4e --- /dev/null +++ b/contrib/clickhouse/src/IO/ZstdInflatingReadBuffer.h @@ -0,0 +1,37 @@ +#pragma once + +#include <IO/CompressedReadBufferWrapper.h> +#include <IO/CompressionMethod.h> +#include <IO/ReadBuffer.h> + +#include <zstd.h> + + +namespace DB +{ +namespace ErrorCodes +{ +} + +class ZstdInflatingReadBuffer : public CompressedReadBufferWrapper +{ +public: + explicit ZstdInflatingReadBuffer( + std::unique_ptr<ReadBuffer> in_, + size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, + char * existing_memory = nullptr, + size_t alignment = 0, + int zstd_window_log_max = 0); + + ~ZstdInflatingReadBuffer() override; + +private: + bool nextImpl() override; + + ZSTD_DCtx * dctx; + ZSTD_inBuffer input; + ZSTD_outBuffer output; + bool eof_flag = false; +}; + +} diff --git a/contrib/clickhouse/src/IO/copyData.cpp b/contrib/clickhouse/src/IO/copyData.cpp new file mode 100644 index 0000000000..07222a930b --- /dev/null +++ b/contrib/clickhouse/src/IO/copyData.cpp @@ -0,0 +1,112 @@ +#include <Common/Exception.h> +#include <Common/Throttler.h> +#include <IO/ReadBuffer.h> +#include <IO/WriteBuffer.h> +#include <IO/copyData.h> + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ATTEMPT_TO_READ_AFTER_EOF; + extern const int CANNOT_READ_ALL_DATA; +} + +namespace +{ + +void copyDataImpl(ReadBuffer & from, WriteBuffer & to, bool check_bytes, size_t bytes, const std::atomic<int> * is_cancelled, ThrottlerPtr throttler) +{ + /// If read to the end of the buffer, eof() either fills the buffer with new data and moves the cursor to the beginning, or returns false. + while (bytes > 0 && !from.eof()) + { + if (is_cancelled && *is_cancelled) + return; + + /// buffer() - a piece of data available for reading; position() - the cursor of the place to which you have already read. + size_t count = std::min(bytes, static_cast<size_t>(from.buffer().end() - from.position())); + to.write(from.position(), count); + from.position() += count; + bytes -= count; + + if (throttler) + throttler->add(count); + } + + if (check_bytes && bytes > 0) + throw Exception(ErrorCodes::ATTEMPT_TO_READ_AFTER_EOF, "Attempt to read after EOF."); +} + +void copyDataImpl(ReadBuffer & from, WriteBuffer & to, bool check_bytes, size_t bytes, std::function<void()> cancellation_hook, ThrottlerPtr throttler) +{ + /// If read to the end of the buffer, eof() either fills the buffer with new data and moves the cursor to the beginning, or returns false. + while (bytes > 0 && !from.eof()) + { + if (cancellation_hook) + cancellation_hook(); + + /// buffer() - a piece of data available for reading; position() - the cursor of the place to which you have already read. 
+ size_t count = std::min(bytes, static_cast<size_t>(from.buffer().end() - from.position())); + to.write(from.position(), count); + from.position() += count; + bytes -= count; + + if (throttler) + throttler->add(count); + } + + if (check_bytes && bytes > 0) + throw Exception(ErrorCodes::ATTEMPT_TO_READ_AFTER_EOF, "Attempt to read after EOF."); +} + +} + +void copyData(ReadBuffer & from, WriteBuffer & to) +{ + copyDataImpl(from, to, false, std::numeric_limits<size_t>::max(), nullptr, nullptr); +} + +void copyData(ReadBuffer & from, WriteBuffer & to, const std::atomic<int> & is_cancelled) +{ + copyDataImpl(from, to, false, std::numeric_limits<size_t>::max(), &is_cancelled, nullptr); +} + +void copyData(ReadBuffer & from, WriteBuffer & to, std::function<void()> cancellation_hook) +{ + copyDataImpl(from, to, false, std::numeric_limits<size_t>::max(), cancellation_hook, nullptr); +} + +void copyData(ReadBuffer & from, WriteBuffer & to, size_t bytes) +{ + copyDataImpl(from, to, true, bytes, nullptr, nullptr); +} + +void copyData(ReadBuffer & from, WriteBuffer & to, size_t bytes, const std::atomic<int> & is_cancelled) +{ + copyDataImpl(from, to, true, bytes, &is_cancelled, nullptr); +} + +void copyData(ReadBuffer & from, WriteBuffer & to, size_t bytes, std::function<void()> cancellation_hook) +{ + copyDataImpl(from, to, true, bytes, cancellation_hook, nullptr); +} + +void copyDataMaxBytes(ReadBuffer & from, WriteBuffer & to, size_t max_bytes) +{ + copyDataImpl(from, to, false, max_bytes, nullptr, nullptr); + if (!from.eof()) + throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Cannot read all data, max readable size reached."); +} + +void copyDataWithThrottler(ReadBuffer & from, WriteBuffer & to, const std::atomic<int> & is_cancelled, ThrottlerPtr throttler) +{ + copyDataImpl(from, to, false, std::numeric_limits<size_t>::max(), &is_cancelled, throttler); +} + +void copyDataWithThrottler(ReadBuffer & from, WriteBuffer & to, size_t bytes, const std::atomic<int> & is_cancelled, ThrottlerPtr throttler) +{ + copyDataImpl(from, to, true, bytes, &is_cancelled, throttler); +} + +} diff --git a/contrib/clickhouse/src/IO/copyData.h b/contrib/clickhouse/src/IO/copyData.h new file mode 100644 index 0000000000..b67088d8e4 --- /dev/null +++ b/contrib/clickhouse/src/IO/copyData.h @@ -0,0 +1,37 @@ +#pragma once + +#include <atomic> +#include <functional> + + +namespace DB +{ + +class ReadBuffer; +class WriteBuffer; +class Throttler; + +using ThrottlerPtr = std::shared_ptr<Throttler>; + + +/// Copies data from ReadBuffer to WriteBuffer, all that is. +void copyData(ReadBuffer & from, WriteBuffer & to); + +/// Copies `bytes` bytes from ReadBuffer to WriteBuffer. If there are no `bytes` bytes, then throws an exception. +void copyData(ReadBuffer & from, WriteBuffer & to, size_t bytes); + +/// The same, with the condition to cancel. +void copyData(ReadBuffer & from, WriteBuffer & to, const std::atomic<int> & is_cancelled); +void copyData(ReadBuffer & from, WriteBuffer & to, size_t bytes, const std::atomic<int> & is_cancelled); + +void copyData(ReadBuffer & from, WriteBuffer & to, std::function<void()> cancellation_hook); +void copyData(ReadBuffer & from, WriteBuffer & to, size_t bytes, std::function<void()> cancellation_hook); + +/// Copies at most `max_bytes` bytes from ReadBuffer to WriteBuffer. If there are more bytes, then throws an exception. 
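A hedged usage sketch for these copy helpers (hypothetical: the file names are made up, ReadBufferFromFile/WriteBufferFromFile come from the same src/IO directory, and copyDataMaxBytes is the declaration that follows):

#include <IO/ReadBufferFromFile.h>
#include <IO/WriteBufferFromFile.h>
#include <IO/copyData.h>
#include <atomic>

void copy_example()
{
    DB::ReadBufferFromFile in("/tmp/source.bin");
    DB::WriteBufferFromFile out("/tmp/copy.bin");

    std::atomic<int> is_cancelled{0};
    DB::copyData(in, out, is_cancelled);   /// copies until EOF unless is_cancelled becomes non-zero
    out.finalize();
}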
+void copyDataMaxBytes(ReadBuffer & from, WriteBuffer & to, size_t max_bytes); + +/// Same as above but also use throttler to limit maximum speed +void copyDataWithThrottler(ReadBuffer & from, WriteBuffer & to, const std::atomic<int> & is_cancelled, ThrottlerPtr throttler); +void copyDataWithThrottler(ReadBuffer & from, WriteBuffer & to, size_t bytes, const std::atomic<int> & is_cancelled, ThrottlerPtr throttler); + +} diff --git a/contrib/clickhouse/src/IO/parseDateTimeBestEffort.cpp b/contrib/clickhouse/src/IO/parseDateTimeBestEffort.cpp new file mode 100644 index 0000000000..6bdba251c3 --- /dev/null +++ b/contrib/clickhouse/src/IO/parseDateTimeBestEffort.cpp @@ -0,0 +1,723 @@ +#include <Common/DateLUTImpl.h> +#include <Common/StringUtils/StringUtils.h> + +#include <IO/ReadBuffer.h> +#include <IO/ReadHelpers.h> +#include <IO/WriteHelpers.h> +#include <IO/parseDateTimeBestEffort.h> + +#include <limits> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int CANNOT_PARSE_DATETIME; +} + + +namespace +{ + +inline size_t readDigits(char * res, size_t max_chars, ReadBuffer & in) +{ + size_t num_chars = 0; + while (!in.eof() && isNumericASCII(*in.position()) && num_chars < max_chars) + { + res[num_chars] = *in.position() - '0'; + ++num_chars; + ++in.position(); + } + return num_chars; +} + +inline size_t readAlpha(char * res, size_t max_chars, ReadBuffer & in) +{ + size_t num_chars = 0; + while (!in.eof() && isAlphaASCII(*in.position()) && num_chars < max_chars) + { + res[num_chars] = *in.position(); + ++num_chars; + ++in.position(); + } + return num_chars; +} + +template <size_t digit, size_t power_of_ten, typename T> +inline void readDecimalNumberImpl(T & res, const char * src) +{ + res += src[digit] * power_of_ten; + if constexpr (digit > 0) + readDecimalNumberImpl<digit - 1, power_of_ten * 10>(res, src); +} + +template <size_t num_digits, typename T> +inline void readDecimalNumber(T & res, const char * src) +{ + readDecimalNumberImpl<num_digits - 1, 1>(res, src); +} + +template <typename T> +inline void readDecimalNumber(T & res, size_t num_digits, const char * src) +{ +#define READ_DECIMAL_NUMBER(N) do { res *= common::exp10_i32(N); readDecimalNumber<N>(res, src); src += (N); num_digits -= (N); } while (false) + while (num_digits) + { + switch (num_digits) + { + case 3: READ_DECIMAL_NUMBER(3); break; + case 2: READ_DECIMAL_NUMBER(2); break; + case 1: READ_DECIMAL_NUMBER(1); break; + default: READ_DECIMAL_NUMBER(4); break; + } + } +#undef READ_DECIMAL_NUMBER +} + +struct DateTimeSubsecondPart +{ + Int64 value; + UInt8 digits; +}; + +template <typename ReturnType, bool is_us_style> +ReturnType parseDateTimeBestEffortImpl( + time_t & res, + ReadBuffer & in, + const DateLUTImpl & local_time_zone, + const DateLUTImpl & utc_time_zone, + DateTimeSubsecondPart * fractional) +{ + auto on_error = [&]<typename... 
FmtArgs>(int error_code [[maybe_unused]], + FormatStringHelper<FmtArgs...> fmt_string [[maybe_unused]], + FmtArgs && ...fmt_args [[maybe_unused]]) + { + if constexpr (std::is_same_v<ReturnType, void>) + throw ParsingException(error_code, std::move(fmt_string), std::forward<FmtArgs>(fmt_args)...); + else + return false; + }; + + res = 0; + UInt16 year = 0; + UInt8 month = 0; + UInt8 day_of_month = 0; + UInt8 hour = 0; + UInt8 minute = 0; + UInt8 second = 0; + + bool has_time = false; + + bool has_time_zone_offset = false; + bool time_zone_offset_negative = false; + UInt8 time_zone_offset_hour = 0; + UInt8 time_zone_offset_minute = 0; + + bool is_am = false; + bool is_pm = false; + + bool has_comma_between_date_and_time = false; + + auto read_alpha_month = [&month] (const auto & alpha) + { + if (0 == strncasecmp(alpha, "Jan", 3)) month = 1; + else if (0 == strncasecmp(alpha, "Feb", 3)) month = 2; + else if (0 == strncasecmp(alpha, "Mar", 3)) month = 3; + else if (0 == strncasecmp(alpha, "Apr", 3)) month = 4; + else if (0 == strncasecmp(alpha, "May", 3)) month = 5; + else if (0 == strncasecmp(alpha, "Jun", 3)) month = 6; + else if (0 == strncasecmp(alpha, "Jul", 3)) month = 7; + else if (0 == strncasecmp(alpha, "Aug", 3)) month = 8; + else if (0 == strncasecmp(alpha, "Sep", 3)) month = 9; + else if (0 == strncasecmp(alpha, "Oct", 3)) month = 10; + else if (0 == strncasecmp(alpha, "Nov", 3)) month = 11; + else if (0 == strncasecmp(alpha, "Dec", 3)) month = 12; + else + return false; + return true; + }; + + while (!in.eof()) + { + if ((year && !has_time) || (!year && has_time)) + { + if (*in.position() == ',') + { + has_comma_between_date_and_time = true; + ++in.position(); + } + } + + char digits[std::numeric_limits<UInt64>::digits10]; + + size_t num_digits = 0; + + if (!year || !has_time) + { + num_digits = readDigits(digits, sizeof(digits), in); + + if (num_digits == 13 && !year && !has_time) + { + /// This is unix timestamp with millisecond. + readDecimalNumber<10>(res, digits); + if (fractional) + { + fractional->digits = 3; + readDecimalNumber<3>(fractional->value, digits + 10); + } + return ReturnType(true); + } + else if (num_digits == 10 && !year && !has_time) + { + /// This is unix timestamp. + readDecimalNumber<10>(res, digits); + return ReturnType(true); + } + else if (num_digits == 9 && !year && !has_time) + { + /// This is unix timestamp. 
+ readDecimalNumber<9>(res, digits); + return ReturnType(true); + } + else if (num_digits == 14 && !year && !has_time) + { + /// This is YYYYMMDDhhmmss + readDecimalNumber<4>(year, digits); + readDecimalNumber<2>(month, digits + 4); + readDecimalNumber<2>(day_of_month, digits + 6); + readDecimalNumber<2>(hour, digits + 8); + readDecimalNumber<2>(minute, digits + 10); + readDecimalNumber<2>(second, digits + 12); + has_time = true; + } + else if (num_digits == 8 && !year) + { + /// This is YYYYMMDD + readDecimalNumber<4>(year, digits); + readDecimalNumber<2>(month, digits + 4); + readDecimalNumber<2>(day_of_month, digits + 6); + } + else if (num_digits == 6) + { + /// This is YYYYMM or hhmmss + if (!year && !month) + { + readDecimalNumber<4>(year, digits); + readDecimalNumber<2>(month, digits + 4); + } + else if (!has_time) + { + readDecimalNumber<2>(hour, digits); + readDecimalNumber<2>(minute, digits + 2); + readDecimalNumber<2>(second, digits + 4); + has_time = true; + } + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: ambiguous 6 digits, it can be YYYYMM or hhmmss"); + } + else if (num_digits == 4 && !year) + { + /// YYYY + /// YYYY*MM + /// YYYY*MM*DD + /// YYYY*M + /// YYYY*M*DD + /// YYYY*M*D + + readDecimalNumber<4>(year, digits); + + if (!in.eof()) + { + char delimiter_after_year = *in.position(); + + if (delimiter_after_year < 0x20 + || delimiter_after_year == ',' + || delimiter_after_year == ';' + || delimiter_after_year == '\'' + || delimiter_after_year == '"') + break; + + if (month) + continue; + + ++in.position(); + + num_digits = readDigits(digits, sizeof(digits), in); + + if (num_digits == 2) + readDecimalNumber<2>(month, digits); + else if (num_digits == 1) + readDecimalNumber<1>(month, digits); + else if (delimiter_after_year == ' ') + continue; + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits after year: {}", num_digits); + + /// Only the same delimiter. 
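/// e.g. for '2021-01-02' the next check consumes '-02' as the day; with a mismatched delimiter such as '2021-01/02' the day is not read here.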
+ if (!day_of_month && checkChar(delimiter_after_year, in)) + { + num_digits = readDigits(digits, sizeof(digits), in); + + if (num_digits == 2) + readDecimalNumber<2>(day_of_month, digits); + else if (num_digits == 1) + readDecimalNumber<1>(day_of_month, digits); + else if (delimiter_after_year == ' ') + continue; + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits after year and month: {}", num_digits); + } + } + } + else if (num_digits == 2 || num_digits == 1) + { + /// hh:mm:ss + /// hh:mm + /// hh - only if already have day of month + /// DD/MM/YYYY + /// DD/MM/YY + /// DD.MM.YYYY + /// DD.MM.YY + /// DD-MM-YYYY + /// DD-MM-YY + /// DD + + UInt8 hour_or_day_of_month_or_month = 0; + if (num_digits == 2) + readDecimalNumber<2>(hour_or_day_of_month_or_month, digits); + else if (num_digits == 1) + readDecimalNumber<1>(hour_or_day_of_month_or_month, digits); + else + return on_error(ErrorCodes::LOGICAL_ERROR, "Cannot read DateTime: logical error, unexpected branch in code"); + + if (checkChar(':', in)) + { + if (has_time) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: time component is duplicated"); + + hour = hour_or_day_of_month_or_month; + has_time = true; + + num_digits = readDigits(digits, sizeof(digits), in); + + if (num_digits == 2) + readDecimalNumber<2>(minute, digits); + else if (num_digits == 1) + readDecimalNumber<1>(minute, digits); + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits after hour: {}", num_digits); + + if (checkChar(':', in)) + { + num_digits = readDigits(digits, sizeof(digits), in); + + if (num_digits == 2) + readDecimalNumber<2>(second, digits); + else if (num_digits == 1) + readDecimalNumber<1>(second, digits); + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits after hour and minute: {}", num_digits); + } + } + else if (checkChar(',', in)) + { + if (month && !day_of_month) + day_of_month = hour_or_day_of_month_or_month; + } + else if (checkChar('/', in) || checkChar('.', in) || checkChar('-', in)) + { + if (day_of_month) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: day of month is duplicated"); + + if (month) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: month is duplicated"); + + if constexpr (is_us_style) + { + month = hour_or_day_of_month_or_month; + num_digits = readDigits(digits, sizeof(digits), in); + if (num_digits == 2) + readDecimalNumber<2>(day_of_month, digits); + else if (num_digits == 1) + readDecimalNumber<1>(day_of_month, digits); + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits after month: {}", num_digits); + } + else + { + day_of_month = hour_or_day_of_month_or_month; + + num_digits = readDigits(digits, sizeof(digits), in); + + if (num_digits == 2) + readDecimalNumber<2>(month, digits); + else if (num_digits == 1) + readDecimalNumber<1>(month, digits); + else if (num_digits == 0) + { + /// Month in alphabetical form + + char alpha[9]; /// The longest month name: September + size_t num_alpha = readAlpha(alpha, sizeof(alpha), in); + + if (num_alpha < 3) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of alphabetical characters after day of month: {}", num_alpha); + + if (!read_alpha_month(alpha)) + return 
on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: alphabetical characters after day of month don't look like month: {}", std::string(alpha, 3)); + } + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits after day of month: {}", num_digits); + } + + if (month > 12) + std::swap(month, day_of_month); + + if (checkChar('/', in) || checkChar('.', in) || checkChar('-', in)) + { + if (year) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: year component is duplicated"); + + num_digits = readDigits(digits, sizeof(digits), in); + + if (num_digits == 4) + readDecimalNumber<4>(year, digits); + else if (num_digits == 2) + { + readDecimalNumber<2>(year, digits); + + if (year >= 70) + year += 1900; + else + year += 2000; + } + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits after day of month and month: {}", num_digits); + } + } + else + { + if (day_of_month) + hour = hour_or_day_of_month_or_month; + else + day_of_month = hour_or_day_of_month_or_month; + } + } + else if (num_digits != 0) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits: {}", num_digits); + } + + if (num_digits == 0) + { + char c = *in.position(); + + /// 'T' is a separator between date and time according to ISO 8601. + /// But don't skip it if we didn't read the date part yet, because 'T' is also a prefix for 'Tue' and 'Thu'. + + if (c == ' ' || (c == 'T' && year && !has_time)) + { + ++in.position(); + } + else if (c == 'Z') + { + ++in.position(); + has_time_zone_offset = true; + } + else if (c == '.') /// We don't support comma (ISO 8601:2004) for fractional part of second to not mess up with CSV separator. + { + if (!has_time) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected point symbol"); + + ++in.position(); + num_digits = readDigits(digits, sizeof(digits), in); + if (fractional) + { + using FractionalType = typename std::decay<decltype(fractional->value)>::type; + // Reading more decimal digits than fits into FractionalType would case an + // overflow, so it is better to skip all digits from the right side that do not + // fit into result type. To provide less precise value rather than bogus one. 
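// For DateTime64 the fractional part is stored in an Int64, so at most 18 significant digits are kept.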
+ num_digits = std::min(static_cast<size_t>(std::numeric_limits<FractionalType>::digits10), num_digits); + + fractional->digits = num_digits; + readDecimalNumber(fractional->value, num_digits, digits); + } + } + else if (c == '+' || c == '-') + { + ++in.position(); + num_digits = readDigits(digits, sizeof(digits), in); + + if (num_digits == 6 && !has_time && year && month && day_of_month) + { + /// It looks like hhmmss + readDecimalNumber<2>(hour, digits); + readDecimalNumber<2>(minute, digits + 2); + readDecimalNumber<2>(second, digits + 4); + has_time = true; + } + else + { + /// It looks like time zone offset + has_time_zone_offset = true; + if (c == '-') + time_zone_offset_negative = true; + + if (num_digits == 4) + { + readDecimalNumber<2>(time_zone_offset_hour, digits); + readDecimalNumber<2>(time_zone_offset_minute, digits + 2); + } + else if (num_digits == 3) + { + readDecimalNumber<1>(time_zone_offset_hour, digits); + readDecimalNumber<2>(time_zone_offset_minute, digits + 1); + } + else if (num_digits == 2) + { + readDecimalNumber<2>(time_zone_offset_hour, digits); + } + else if (num_digits == 1) + { + readDecimalNumber<1>(time_zone_offset_hour, digits); + } + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits for time zone offset: {}", num_digits); + + if (num_digits < 3 && checkChar(':', in)) + { + num_digits = readDigits(digits, sizeof(digits), in); + + if (num_digits == 2) + { + readDecimalNumber<2>(time_zone_offset_minute, digits); + } + else if (num_digits == 1) + { + readDecimalNumber<1>(time_zone_offset_minute, digits); + } + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected number of decimal digits for time zone offset in minutes: {}", num_digits); + } + } + } + else + { + char alpha[3]; + + size_t num_alpha = readAlpha(alpha, sizeof(alpha), in); + + if (!num_alpha) + { + break; + } + else if (num_alpha == 1) + { + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected alphabetical character"); + } + else if (num_alpha == 2) + { + if (alpha[1] == 'M' || alpha[1] == 'm') + { + if (alpha[0] == 'A' || alpha[0] == 'a') + { + is_am = true; + } + else if (alpha[0] == 'P' || alpha[0] == 'p') + { + is_pm = true; + } + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected word"); + } + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected word"); + } + else if (num_alpha == 3) + { + bool has_day_of_week = false; + + if (read_alpha_month(alpha)) + { + } + else if (0 == strncasecmp(alpha, "UTC", 3)) has_time_zone_offset = true; // NOLINT + else if (0 == strncasecmp(alpha, "GMT", 3)) has_time_zone_offset = true; + else if (0 == strncasecmp(alpha, "MSK", 3)) { has_time_zone_offset = true; time_zone_offset_hour = 3; } + else if (0 == strncasecmp(alpha, "MSD", 3)) { has_time_zone_offset = true; time_zone_offset_hour = 4; } + + else if (0 == strncasecmp(alpha, "Mon", 3)) has_day_of_week = true; // NOLINT + else if (0 == strncasecmp(alpha, "Tue", 3)) has_day_of_week = true; + else if (0 == strncasecmp(alpha, "Wed", 3)) has_day_of_week = true; + else if (0 == strncasecmp(alpha, "Thu", 3)) has_day_of_week = true; + else if (0 == strncasecmp(alpha, "Fri", 3)) has_day_of_week = true; + else if (0 == strncasecmp(alpha, "Sat", 3)) has_day_of_week = true; + else if (0 == strncasecmp(alpha, "Sun", 3)) has_day_of_week = true; + + else + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, 
"Cannot read DateTime: unexpected word"); + + while (!in.eof() && isAlphaASCII(*in.position())) + ++in.position(); + + /// For RFC 2822 + if (has_day_of_week) + checkChar(',', in); + } + else + return on_error(ErrorCodes::LOGICAL_ERROR, "Cannot read DateTime: logical error, unexpected branch in code"); + } + } + } + + //// Date like '2022/03/04, ' should parse fail? + if (has_comma_between_date_and_time && (!has_time || !year || !month || !day_of_month)) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected word after Date"); + + /// If neither Date nor Time is parsed successfully, it should fail + if (!year && !month && !day_of_month && !has_time) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: neither Date nor Time was parsed successfully"); + + if (!day_of_month) + day_of_month = 1; + if (!month) + month = 1; + if (!year) + { + time_t now = time(nullptr); + UInt16 curr_year = local_time_zone.toYear(now); + year = now < local_time_zone.makeDateTime(curr_year, month, day_of_month, hour, minute, second) ? curr_year - 1 : curr_year; + } + + auto is_leap_year = (year % 400 == 0) || (year % 100 != 0 && year % 4 == 0); + + auto check_date = [](const auto & is_leap_year_, const auto & month_, const auto & day_) + { + if ((month_ == 1 || month_ == 3 || month_ == 5 || month_ == 7 || month_ == 8 || month_ == 10 || month_ == 12) && day_ >= 1 && day_ <= 31) + return true; + else if (month_ == 2 && ((is_leap_year_ && day_ >= 1 && day_ <= 29) || (!is_leap_year_ && day_ >= 1 && day_ <= 28))) + return true; + else if ((month_ == 4 || month_ == 6 || month_ == 9 || month_ == 11) && day_ >= 1 && day_ <= 30) + return true; + return false; + }; + + if (!check_date(is_leap_year, month, day_of_month)) + return on_error(ErrorCodes::CANNOT_PARSE_DATETIME, "Cannot read DateTime: unexpected date: {}-{}-{}", + year, static_cast<UInt16>(month), static_cast<UInt16>(day_of_month)); + + if (is_am && hour == 12) + hour = 0; + + if (is_pm && hour < 12) + hour += 12; + + auto adjust_time_zone = [&] + { + if (time_zone_offset_hour) + { + if (time_zone_offset_negative) + res += time_zone_offset_hour * 3600; + else + res -= time_zone_offset_hour * 3600; + } + + if (time_zone_offset_minute) + { + if (time_zone_offset_negative) + res += time_zone_offset_minute * 60; + else + res -= time_zone_offset_minute * 60; + } + }; + + if (has_time_zone_offset) + { + res = utc_time_zone.makeDateTime(year, month, day_of_month, hour, minute, second); + adjust_time_zone(); + } + else + { + res = local_time_zone.makeDateTime(year, month, day_of_month, hour, minute, second); + } + + return ReturnType(true); +} + +template <typename ReturnType, bool is_us_style> +ReturnType parseDateTime64BestEffortImpl(DateTime64 & res, UInt32 scale, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone) +{ + time_t whole; + DateTimeSubsecondPart subsecond = {0, 0}; // needs to be explicitly initialized sine it could be missing from input string + + if constexpr (std::is_same_v<ReturnType, bool>) + { + if (!parseDateTimeBestEffortImpl<bool, is_us_style>(whole, in, local_time_zone, utc_time_zone, &subsecond)) + return false; + } + else + { + parseDateTimeBestEffortImpl<ReturnType, is_us_style>(whole, in, local_time_zone, utc_time_zone, &subsecond); + } + + + DateTime64::NativeType fractional = subsecond.value; + if (scale < subsecond.digits) + { + fractional /= common::exp10_i64(subsecond.digits - scale); + } + else if (scale > subsecond.digits) + { + fractional *= 
common::exp10_i64(scale - subsecond.digits); + } + + if constexpr (std::is_same_v<ReturnType, bool>) + return DecimalUtils::tryGetDecimalFromComponents<DateTime64>(whole, fractional, scale, res); + + res = DecimalUtils::decimalFromComponents<DateTime64>(whole, fractional, scale); + return ReturnType(true); +} + +} + +void parseDateTimeBestEffort(time_t & res, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone) +{ + parseDateTimeBestEffortImpl<void, false>(res, in, local_time_zone, utc_time_zone, nullptr); +} + +void parseDateTimeBestEffortUS(time_t & res, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone) +{ + parseDateTimeBestEffortImpl<void, true>(res, in, local_time_zone, utc_time_zone, nullptr); +} + +bool tryParseDateTimeBestEffort(time_t & res, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone) +{ + return parseDateTimeBestEffortImpl<bool, false>(res, in, local_time_zone, utc_time_zone, nullptr); +} + +bool tryParseDateTimeBestEffortUS(time_t & res, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone) +{ + return parseDateTimeBestEffortImpl<bool, true>(res, in, local_time_zone, utc_time_zone, nullptr); +} + +void parseDateTime64BestEffort(DateTime64 & res, UInt32 scale, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone) +{ + return parseDateTime64BestEffortImpl<void, false>(res, scale, in, local_time_zone, utc_time_zone); +} + +void parseDateTime64BestEffortUS(DateTime64 & res, UInt32 scale, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone) +{ + return parseDateTime64BestEffortImpl<void, true>(res, scale, in, local_time_zone, utc_time_zone); +} + +bool tryParseDateTime64BestEffort(DateTime64 & res, UInt32 scale, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone) +{ + return parseDateTime64BestEffortImpl<bool, false>(res, scale, in, local_time_zone, utc_time_zone); +} + +bool tryParseDateTime64BestEffortUS(DateTime64 & res, UInt32 scale, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone) +{ + return parseDateTime64BestEffortImpl<bool, true>(res, scale, in, local_time_zone, utc_time_zone); +} + +} diff --git a/contrib/clickhouse/src/IO/parseDateTimeBestEffort.h b/contrib/clickhouse/src/IO/parseDateTimeBestEffort.h new file mode 100644 index 0000000000..22af44f9e7 --- /dev/null +++ b/contrib/clickhouse/src/IO/parseDateTimeBestEffort.h @@ -0,0 +1,66 @@ +#pragma once +#include <stddef.h> +#include <time.h> + +#include <Core/Types.h> + +class DateLUTImpl; + +namespace DB +{ + +class ReadBuffer; + +/** https://xkcd.com/1179/ + * + * The existence of this function is an example of bad practice + * and contradicts our development principles. + * + * This function will recognize the following patterns: + * + * NNNNNNNNNN - 9..10 digits is a unix timestamp + * + * YYYYMMDDhhmmss - 14 numbers is always interpreted this way + * + * YYYYMMDD - 8 digits in a row + * YYYY*MM*DD - or with any delimiter after first 4-digit year component and after month. 
+ * + * DD/MM/YY + * DD/MM/YYYY - when '/' separator is used, these are the only possible forms + * + * hh:mm:ss - when ':' separator is used, it is always time + * hh:mm - it can be specified without seconds + * + * YYYY - 4 digits is always year + * + * YYYYMM - 6 digits is a year, month if year was not already read + * hhmmss - 6 digits is a time if year was already read + * + * .nnnnnnn - any number of digits after point is fractional part of second, if it is not YYYY.MM.DD or DD.MM.YYYY + * + * T - means that time will follow + * + * Z - means zero UTC offset + * + * +hhmm + * +hh:mm + * +hh + * -... - time zone offset + * + * single whitespace can be used as a separator + * + * AM/PM - AM means: subtract 12 hours if a value is 12 and PM means: add 12 hours if a value is less than 12. + * + * Jan/Feb/Mar/Apr/May/Jun/Jul/Aug/Sep/Oct/Nov/Dec - allowed to specify month + * Mon/Tue/Wed/Thu/Fri/Sat/Sun - simply ignored. + */ + +void parseDateTimeBestEffort(time_t & res, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone); +bool tryParseDateTimeBestEffort(time_t & res, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone); +void parseDateTimeBestEffortUS(time_t & res, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone); +bool tryParseDateTimeBestEffortUS(time_t & res, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone); +void parseDateTime64BestEffort(DateTime64 & res, UInt32 scale, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone); +bool tryParseDateTime64BestEffort(DateTime64 & res, UInt32 scale, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone); +void parseDateTime64BestEffortUS(DateTime64 & res, UInt32 scale, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone); +bool tryParseDateTime64BestEffortUS(DateTime64 & res, UInt32 scale, ReadBuffer & in, const DateLUTImpl & local_time_zone, const DateLUTImpl & utc_time_zone); +} diff --git a/contrib/clickhouse/src/IO/readDecimalText.h b/contrib/clickhouse/src/IO/readDecimalText.h new file mode 100644 index 0000000000..9fd9c439b8 --- /dev/null +++ b/contrib/clickhouse/src/IO/readDecimalText.h @@ -0,0 +1,227 @@ +#pragma once + +#include <limits> +#include <IO/ReadHelpers.h> +#include <Common/intExp.h> +#include <base/wide_integer_to_string.h> + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_PARSE_NUMBER; + extern const int ARGUMENT_OUT_OF_BOUND; +} + +/// Try to read Decimal into underlying type T from ReadBuffer. Throws if 'digits_only' is set and there's unexpected symbol in input. +/// Returns integer 'exponent' factor that x should be multiplied by to get correct Decimal value: result = x * 10^exponent. +/// Use 'digits' input as max allowed meaning decimal digits in result. Place actual number of meaning digits in 'digits' output. +/// Does not care about decimal scale, only about meaningful digits in decimal text representation. 
+template <bool _throw_on_error, typename T> +inline bool readDigits(ReadBuffer & buf, T & x, uint32_t & digits, int32_t & exponent, bool digits_only = false) +{ + x = T(0); + exponent = 0; + uint32_t max_digits = digits; + digits = 0; + uint32_t places = 0; + typename T::NativeType sign = 1; + bool leading_zeroes = true; + bool after_point = false; + + if (buf.eof()) + { + if constexpr (_throw_on_error) + throwReadAfterEOF(); + return false; + } + + switch (*buf.position()) + { + case '-': + sign = -1; + [[fallthrough]]; + case '+': + ++buf.position(); + break; + } + + bool stop = false; + while (!buf.eof() && !stop) + { + const char & byte = *buf.position(); + switch (byte) + { + case '.': + after_point = true; + leading_zeroes = false; + break; + case '0': + { + if (leading_zeroes) + break; + + if (after_point) + { + ++places; /// Count trailing zeroes. They would be used only if there's some other digit after them. + break; + } + [[fallthrough]]; + } + case '1': [[fallthrough]]; + case '2': [[fallthrough]]; + case '3': [[fallthrough]]; + case '4': [[fallthrough]]; + case '5': [[fallthrough]]; + case '6': [[fallthrough]]; + case '7': [[fallthrough]]; + case '8': [[fallthrough]]; + case '9': + { + leading_zeroes = false; + + ++places; // num zeroes before + current digit + if (digits + places > max_digits) + { + if (after_point) + { + /// Simply cut excessive digits. + break; + } + else + { + if constexpr (_throw_on_error) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Too many digits ({} > {}) in decimal value", + std::to_string(digits + places), std::to_string(max_digits)); + + return false; + } + } + else + { + digits += places; + if (after_point) + exponent -= places; + + // TODO: accurate shift10 for big integers + x *= intExp10OfSize<typename T::NativeType>(places); + places = 0; + + x += (byte - '0'); + break; + } + } + case 'e': [[fallthrough]]; + case 'E': + { + ++buf.position(); + Int32 addition_exp = 0; + if (!tryReadIntText(addition_exp, buf)) + { + if constexpr (_throw_on_error) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Cannot parse exponent while reading decimal"); + else + return false; + } + exponent += addition_exp; + stop = true; + continue; + } + + default: + if (digits_only) + { + if constexpr (_throw_on_error) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Unexpected symbol while reading decimal"); + return false; + } + stop = true; + continue; + } + ++buf.position(); + } + + x *= sign; + return true; +} + +template <typename T, typename ReturnType=void> +inline ReturnType readDecimalText(ReadBuffer & buf, T & x, uint32_t precision, uint32_t & scale, bool digits_only = false) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + uint32_t digits = precision; + int32_t exponent; + auto ok = readDigits<throw_exception>(buf, x, digits, exponent, digits_only); + + if (!throw_exception && !ok) + return ReturnType(false); + + if (static_cast<int32_t>(digits) + exponent > static_cast<int32_t>(precision - scale)) + { + if constexpr (throw_exception) + { + static constexpr auto pattern = "Decimal value is too big: {} digits were read: {}e{}." 
+ " Expected to read decimal with scale {} and precision {}"; + + if constexpr (is_big_int_v<typename T::NativeType>) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, pattern, digits, x.value, exponent, scale, precision); + else + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, pattern, digits, x, exponent, scale, precision); + } + else + return ReturnType(false); + } + + if (static_cast<int32_t>(scale) + exponent < 0) + { + auto divisor_exp = -exponent - static_cast<int32_t>(scale); + + if (divisor_exp >= std::numeric_limits<typename T::NativeType>::digits10) + { + /// Too big negative exponent + x.value = 0; + scale = 0; + return ReturnType(true); + } + else + { + /// Too many digits after point. Just cut off excessive digits. + auto divisor = intExp10OfSize<typename T::NativeType>(divisor_exp); + assert(divisor > 0); /// This is for Clang Static Analyzer. It is not smart enough to infer it automatically. + x.value /= divisor; + scale = 0; + return ReturnType(true); + } + } + + scale += exponent; + return ReturnType(true); +} + +template <typename T> +inline bool tryReadDecimalText(ReadBuffer & buf, T & x, uint32_t precision, uint32_t & scale) +{ + return readDecimalText<T, bool>(buf, x, precision, scale, true); +} + +template <typename T> +inline void readCSVDecimalText(ReadBuffer & buf, T & x, uint32_t precision, uint32_t & scale) +{ + if (buf.eof()) + throwReadAfterEOF(); + + char maybe_quote = *buf.position(); + + if (maybe_quote == '\'' || maybe_quote == '\"') + ++buf.position(); + + readDecimalText(buf, x, precision, scale, false); + + if (maybe_quote == '\'' || maybe_quote == '\"') + assertChar(maybe_quote, buf); +} + +} diff --git a/contrib/clickhouse/src/IO/readFloatText.cpp b/contrib/clickhouse/src/IO/readFloatText.cpp new file mode 100644 index 0000000000..d1143f7c62 --- /dev/null +++ b/contrib/clickhouse/src/IO/readFloatText.cpp @@ -0,0 +1,70 @@ +#include <IO/readFloatText.h> + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED; +} + +/** Must successfully parse inf, INF and Infinity. + * All other variants in different cases are also parsed for simplicity. + */ +bool parseInfinity(ReadBuffer & buf) +{ + if (!checkStringCaseInsensitive("inf", buf)) + return false; + + /// Just inf. + if (buf.eof() || !isWordCharASCII(*buf.position())) + return true; + + /// If word characters after inf, it should be infinity. + return checkStringCaseInsensitive("inity", buf); +} + + +/** Must successfully parse nan, NAN and NaN. + * All other variants in different cases are also parsed for simplicity. 
+ */ +bool parseNaN(ReadBuffer & buf) +{ + return checkStringCaseInsensitive("nan", buf); +} + + +void assertInfinity(ReadBuffer & buf) +{ + if (!parseInfinity(buf)) + throw Exception(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "Cannot parse infinity."); +} + +void assertNaN(ReadBuffer & buf) +{ + if (!parseNaN(buf)) + throw Exception(ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED, "Cannot parse NaN."); +} + + +template void readFloatTextPrecise<Float32>(Float32 &, ReadBuffer &); +template void readFloatTextPrecise<Float64>(Float64 &, ReadBuffer &); +template bool tryReadFloatTextPrecise<Float32>(Float32 &, ReadBuffer &); +template bool tryReadFloatTextPrecise<Float64>(Float64 &, ReadBuffer &); + +template void readFloatTextFast<Float32>(Float32 &, ReadBuffer &); +template void readFloatTextFast<Float64>(Float64 &, ReadBuffer &); +template bool tryReadFloatTextFast<Float32>(Float32 &, ReadBuffer &); +template bool tryReadFloatTextFast<Float64>(Float64 &, ReadBuffer &); + +template void readFloatTextSimple<Float32>(Float32 &, ReadBuffer &); +template void readFloatTextSimple<Float64>(Float64 &, ReadBuffer &); +template bool tryReadFloatTextSimple<Float32>(Float32 &, ReadBuffer &); +template bool tryReadFloatTextSimple<Float64>(Float64 &, ReadBuffer &); + +template void readFloatText<Float32>(Float32 &, ReadBuffer &); +template void readFloatText<Float64>(Float64 &, ReadBuffer &); +template bool tryReadFloatText<Float32>(Float32 &, ReadBuffer &); +template bool tryReadFloatText<Float64>(Float64 &, ReadBuffer &); + +} diff --git a/contrib/clickhouse/src/IO/readFloatText.h b/contrib/clickhouse/src/IO/readFloatText.h new file mode 100644 index 0000000000..da4719b8dc --- /dev/null +++ b/contrib/clickhouse/src/IO/readFloatText.h @@ -0,0 +1,596 @@ +#pragma once +#include <type_traits> +#include <IO/ReadHelpers.h> +#include <Core/Defines.h> +#include <base/shift10.h> +#include <Common/StringUtils/StringUtils.h> +#include <double-conversion/double-conversion.h> + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunneeded-internal-declaration" +#endif +#include <fast_float/fast_float.h> +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +/** Methods for reading floating point numbers from text with decimal representation. + * There are "precise", "fast" and "simple" implementations. + * + * Neither of methods support hexadecimal numbers (0xABC), binary exponent (1p100), leading plus sign. + * + * Precise method always returns a number that is the closest machine representable number to the input. + * + * Fast method is faster (up to 3 times) and usually return the same value, + * but in rare cases result may differ by lest significant bit (for Float32) + * and by up to two least significant bits (for Float64) from precise method. + * Also fast method may parse some garbage as some other unspecified garbage. + * + * Simple method is little faster for cases of parsing short (few digit) integers, but less precise and slower in other cases. + * It's not recommended to use simple method and it is left only for reference. + * + * For performance test, look at 'read_float_perf' test. + * + * For precision test. 
+ * Parse all existing Float32 numbers: + +CREATE TABLE test.floats ENGINE = Log AS SELECT reinterpretAsFloat32(reinterpretAsString(toUInt32(number))) AS x FROM numbers(0x100000000); + +WITH + toFloat32(toString(x)) AS y, + reinterpretAsUInt32(reinterpretAsString(x)) AS bin_x, + reinterpretAsUInt32(reinterpretAsString(y)) AS bin_y, + abs(bin_x - bin_y) AS diff +SELECT + diff, + count() +FROM test.floats +WHERE NOT isNaN(x) +GROUP BY diff +ORDER BY diff ASC +LIMIT 100 + + * Here are the results: + * + Precise: + ┌─diff─┬────count()─┐ + │ 0 │ 4278190082 │ + └──────┴────────────┘ + (100% roundtrip property) + + Fast: + ┌─diff─┬────count()─┐ + │ 0 │ 3685260580 │ + │ 1 │ 592929502 │ + └──────┴────────────┘ + (The difference is 1 in least significant bit in 13.8% of numbers.) + + Simple: + ┌─diff─┬────count()─┐ + │ 0 │ 2169879994 │ + │ 1 │ 1807178292 │ + │ 2 │ 269505944 │ + │ 3 │ 28826966 │ + │ 4 │ 2566488 │ + │ 5 │ 212878 │ + │ 6 │ 18276 │ + │ 7 │ 1214 │ + │ 8 │ 30 │ + └──────┴────────────┘ + + * Parse random Float64 numbers: + +WITH + rand64() AS bin_x, + reinterpretAsFloat64(reinterpretAsString(bin_x)) AS x, + toFloat64(toString(x)) AS y, + reinterpretAsUInt64(reinterpretAsString(y)) AS bin_y, + abs(bin_x - bin_y) AS diff +SELECT + diff, + count() +FROM numbers(100000000) +WHERE NOT isNaN(x) +GROUP BY diff +ORDER BY diff ASC +LIMIT 100 + + */ + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CANNOT_PARSE_NUMBER; +} + + +/// Returns true, iff parsed. +bool parseInfinity(ReadBuffer & buf); +bool parseNaN(ReadBuffer & buf); + +void assertInfinity(ReadBuffer & buf); +void assertNaN(ReadBuffer & buf); + + +template <bool throw_exception> +bool assertOrParseInfinity(ReadBuffer & buf) +{ + if constexpr (throw_exception) + { + assertInfinity(buf); + return true; + } + else + return parseInfinity(buf); +} + +template <bool throw_exception> +bool assertOrParseNaN(ReadBuffer & buf) +{ + if constexpr (throw_exception) + { + assertNaN(buf); + return true; + } + else + return parseNaN(buf); +} + + +template <typename T, typename ReturnType> +ReturnType readFloatTextPreciseImpl(T & x, ReadBuffer & buf) +{ + static_assert(std::is_same_v<T, double> || std::is_same_v<T, float>, "Argument for readFloatTextPreciseImpl must be float or double"); + static_assert('a' > '.' && 'A' > '.' && '\n' < '.' && '\t' < '.' && '\'' < '.' && '"' < '.', "Layout of char is not like ASCII"); + + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + /// Fast path (avoid copying) if the buffer have at least MAX_LENGTH bytes. + static constexpr int MAX_LENGTH = 316; + + if (likely(!buf.eof() && buf.position() + MAX_LENGTH <= buf.buffer().end())) + { + auto * initial_position = buf.position(); + auto res = fast_float::from_chars(initial_position, buf.buffer().end(), x); + + if (unlikely(res.ec != std::errc())) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Cannot read floating point value"); + else + return ReturnType(false); + } + + buf.position() += res.ptr - initial_position; + + return ReturnType(true); + } + else + { + /// Slow path. Copy characters that may be present in floating point number to temporary buffer. 
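/// Sign, 'inf' and 'nan' are handled first; then only characters that can occur in a textual float ([0-9], '+', '-', '.', 'e', 'E') are copied, at most MAX_LENGTH of them, and handed to fast_float.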
+ bool negative = false; + + /// We check eof here because we can parse +inf +nan + while (!buf.eof()) + { + switch (*buf.position()) + { + case '+': + ++buf.position(); + continue; + + case '-': + { + negative = true; + ++buf.position(); + continue; + } + + case 'i': [[fallthrough]]; + case 'I': + { + if (assertOrParseInfinity<throw_exception>(buf)) + { + x = std::numeric_limits<T>::infinity(); + if (negative) + x = -x; + return ReturnType(true); + } + return ReturnType(false); + } + + case 'n': [[fallthrough]]; + case 'N': + { + if (assertOrParseNaN<throw_exception>(buf)) + { + x = std::numeric_limits<T>::quiet_NaN(); + if (negative) + x = -x; + return ReturnType(true); + } + return ReturnType(false); + } + + default: + break; + } + + break; + } + + + char tmp_buf[MAX_LENGTH]; + int num_copied_chars = 0; + + while (!buf.eof() && num_copied_chars < MAX_LENGTH) + { + char c = *buf.position(); + if (!(isNumericASCII(c) || c == '-' || c == '+' || c == '.' || c == 'e' || c == 'E')) + break; + + tmp_buf[num_copied_chars] = c; + ++buf.position(); + ++num_copied_chars; + } + + auto res = fast_float::from_chars(tmp_buf, tmp_buf + num_copied_chars, x); + + if (unlikely(res.ec != std::errc())) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Cannot read floating point value"); + else + return ReturnType(false); + } + + if (negative) + x = -x; + + return ReturnType(true); + } +} + + +// credit: https://johnnylee-sde.github.io/Fast-numeric-string-to-int/ +static inline bool is_made_of_eight_digits_fast(uint64_t val) noexcept +{ + return (((val & 0xF0F0F0F0F0F0F0F0) | (((val + 0x0606060606060606) & 0xF0F0F0F0F0F0F0F0) >> 4)) == 0x3333333333333333); +} + +static inline bool is_made_of_eight_digits_fast(const char * chars) noexcept +{ + uint64_t val; + ::memcpy(&val, chars, 8); + return is_made_of_eight_digits_fast(val); +} + +template <size_t N, typename T> +static inline void readUIntTextUpToNSignificantDigits(T & x, ReadBuffer & buf) +{ + /// In optimistic case we can skip bound checking for first loop. + if (buf.position() + N <= buf.buffer().end()) + { + for (size_t i = 0; i < N; ++i) + { + if (isNumericASCII(*buf.position())) + { + x *= 10; + x += *buf.position() & 0x0F; + ++buf.position(); + } + else + return; + } + } + else + { + for (size_t i = 0; i < N; ++i) + { + if (!buf.eof() && isNumericASCII(*buf.position())) + { + x *= 10; + x += *buf.position() & 0x0F; + ++buf.position(); + } + else + return; + } + } + + while (!buf.eof() && (buf.position() + 8 <= buf.buffer().end()) && + is_made_of_eight_digits_fast(buf.position())) + { + buf.position() += 8; + } + + while (!buf.eof() && isNumericASCII(*buf.position())) + ++buf.position(); +} + + +template <typename T, typename ReturnType> +ReturnType readFloatTextFastImpl(T & x, ReadBuffer & in) +{ + static_assert(std::is_same_v<T, double> || std::is_same_v<T, float>, "Argument for readFloatTextImpl must be float or double"); + static_assert('a' > '.' && 'A' > '.' && '\n' < '.' && '\t' < '.' && '\'' < '.' 
&& '"' < '.', "Layout of char is not like ASCII"); + + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + bool negative = false; + x = 0; + UInt64 before_point = 0; + UInt64 after_point = 0; + int after_point_exponent = 0; + int exponent = 0; + + if (in.eof()) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Cannot read floating point value"); + else + return false; + } + + if (*in.position() == '-') + { + negative = true; + ++in.position(); + } + else if (*in.position() == '+') + ++in.position(); + + auto count_after_sign = in.count(); + + constexpr int significant_digits = std::numeric_limits<UInt64>::digits10; + readUIntTextUpToNSignificantDigits<significant_digits>(before_point, in); + + size_t read_digits = in.count() - count_after_sign; + + if (unlikely(read_digits > significant_digits)) + { + int before_point_additional_exponent = static_cast<int>(read_digits) - significant_digits; + x = static_cast<T>(shift10(before_point, before_point_additional_exponent)); + } + else + { + x = before_point; + + /// Shortcut for the common case when there is an integer that fit in Int64. + if (read_digits && (in.eof() || *in.position() < '.')) + { + if (negative) + x = -x; + return ReturnType(true); + } + } + + if (checkChar('.', in)) + { + auto after_point_count = in.count(); + + while (!in.eof() && *in.position() == '0') + ++in.position(); + + auto after_leading_zeros_count = in.count(); + int after_point_num_leading_zeros = static_cast<int>(after_leading_zeros_count - after_point_count); + + readUIntTextUpToNSignificantDigits<significant_digits>(after_point, in); + read_digits = in.count() - after_leading_zeros_count; + after_point_exponent = (read_digits > significant_digits ? -significant_digits : static_cast<int>(-read_digits)) - after_point_num_leading_zeros; + } + + if (checkChar('e', in) || checkChar('E', in)) + { + if (in.eof()) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Cannot read floating point value: nothing after exponent"); + else + return false; + } + + bool exponent_negative = false; + if (*in.position() == '-') + { + exponent_negative = true; + ++in.position(); + } + else if (*in.position() == '+') + { + ++in.position(); + } + + readUIntTextUpToNSignificantDigits<4>(exponent, in); + if (exponent_negative) + exponent = -exponent; + } + + if (after_point) + x += static_cast<T>(shift10(after_point, after_point_exponent)); + + if (exponent) + x = static_cast<T>(shift10(x, exponent)); + + if (negative) + x = -x; + + auto num_characters_without_sign = in.count() - count_after_sign; + + /// Denormals. At most one character is read before denormal and it is '-'. 
+ if (num_characters_without_sign == 0) + { + if (in.eof()) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Cannot read floating point value: no digits read"); + else + return false; + } + + if (*in.position() == '+') + { + ++in.position(); + if (in.eof()) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Cannot read floating point value: nothing after plus sign"); + else + return false; + } + else if (negative) + { + if constexpr (throw_exception) + throw ParsingException(ErrorCodes::CANNOT_PARSE_NUMBER, "Cannot read floating point value: plus after minus sign"); + else + return false; + } + } + + if (*in.position() == 'i' || *in.position() == 'I') + { + if (assertOrParseInfinity<throw_exception>(in)) + { + x = std::numeric_limits<T>::infinity(); + if (negative) + x = -x; + return ReturnType(true); + } + return ReturnType(false); + } + else if (*in.position() == 'n' || *in.position() == 'N') + { + if (assertOrParseNaN<throw_exception>(in)) + { + x = std::numeric_limits<T>::quiet_NaN(); + if (negative) + x = -x; + return ReturnType(true); + } + return ReturnType(false); + } + } + + return ReturnType(true); +} + +template <typename T, typename ReturnType> +ReturnType readFloatTextSimpleImpl(T & x, ReadBuffer & buf) +{ + static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; + + bool negative = false; + x = 0; + bool after_point = false; + T power_of_ten = 1; + + if (buf.eof()) + throwReadAfterEOF(); + + while (!buf.eof()) + { + switch (*buf.position()) + { + case '+': + break; + case '-': + negative = true; + break; + case '.': + after_point = true; + break; + case '0': [[fallthrough]]; + case '1': [[fallthrough]]; + case '2': [[fallthrough]]; + case '3': [[fallthrough]]; + case '4': [[fallthrough]]; + case '5': [[fallthrough]]; + case '6': [[fallthrough]]; + case '7': [[fallthrough]]; + case '8': [[fallthrough]]; + case '9': + if (after_point) + { + power_of_ten /= 10; + x += (*buf.position() - '0') * power_of_ten; + } + else + { + x *= 10; + x += *buf.position() - '0'; + } + break; + case 'e': [[fallthrough]]; + case 'E': + { + ++buf.position(); + Int32 exponent = 0; + readIntText(exponent, buf); + x = shift10(x, exponent); + if (negative) + x = -x; + return ReturnType(true); + } + + case 'i': [[fallthrough]]; + case 'I': + { + if (assertOrParseInfinity<throw_exception>(buf)) + { + x = std::numeric_limits<T>::infinity(); + if (negative) + x = -x; + return ReturnType(true); + } + return ReturnType(false); + } + + case 'n': [[fallthrough]]; + case 'N': + { + if (assertOrParseNaN<throw_exception>(buf)) + { + x = std::numeric_limits<T>::quiet_NaN(); + if (negative) + x = -x; + return ReturnType(true); + } + return ReturnType(false); + } + + default: + { + if (negative) + x = -x; + return ReturnType(true); + } + } + ++buf.position(); + } + + if (negative) + x = -x; + + return ReturnType(true); +} + +template <typename T> void readFloatTextPrecise(T & x, ReadBuffer & in) { readFloatTextPreciseImpl<T, void>(x, in); } +template <typename T> bool tryReadFloatTextPrecise(T & x, ReadBuffer & in) { return readFloatTextPreciseImpl<T, bool>(x, in); } + +template <typename T> void readFloatTextFast(T & x, ReadBuffer & in) { readFloatTextFastImpl<T, void>(x, in); } +template <typename T> bool tryReadFloatTextFast(T & x, ReadBuffer & in) { return readFloatTextFastImpl<T, bool>(x, in); } + +template <typename T> void readFloatTextSimple(T & x, ReadBuffer & in) { readFloatTextSimpleImpl<T, 
void>(x, in); } +template <typename T> bool tryReadFloatTextSimple(T & x, ReadBuffer & in) { return readFloatTextSimpleImpl<T, bool>(x, in); } + + +/// Implementation that is selected as default. + +template <typename T> void readFloatText(T & x, ReadBuffer & in) { readFloatTextFast(x, in); } +template <typename T> bool tryReadFloatText(T & x, ReadBuffer & in) { return tryReadFloatTextFast(x, in); } + +}
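For orientation, below is a minimal usage sketch of the helpers imported in this part of the diff (copyData, parseDateTimeBestEffort, readFloatText). It is illustrative only and not part of the upstream sources; it assumes the usual ClickHouse buffer wrappers ReadBufferFromString and WriteBufferFromOwnString and the DateLUT singleton, which live elsewhere in src/IO and src/Common and do not appear in this changeset.

#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromString.h>
#include <IO/copyData.h>
#include <IO/parseDateTimeBestEffort.h>
#include <IO/readFloatText.h>
#include <Common/DateLUT.h>

using namespace DB;

void example()
{
    /// Copy a whole ReadBuffer into a WriteBuffer.
    ReadBufferFromString src("hello");
    WriteBufferFromOwnString dst;
    copyData(src, dst);                 /// dst.str() now holds "hello"

    /// Best-effort parsing of a loosely formatted timestamp.
    ReadBufferFromString dt("2023-11-14 09:58:56 +0300");
    time_t t = 0;
    parseDateTimeBestEffort(t, dt, DateLUT::instance(), DateLUT::instance("UTC"));

    /// Float parsing; readFloatText dispatches to the "fast" implementation by default.
    ReadBufferFromString num("3.14159e0");
    Float64 x = 0;
    readFloatText(x, num);
}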