author     monster <monster@ydb.tech>    2022-07-07 14:41:37 +0300
committer  monster <monster@ydb.tech>    2022-07-07 14:41:37 +0300
commit     06e5c21a835c0e923506c4ff27929f34e00761c2 (patch)
tree       75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/neh
parent     03f024c4412e3aa613bb543cf1660176320ba8f4 (diff)
download   ydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz
fix ya.make
Diffstat (limited to 'library/cpp/neh')
64 files changed, 15089 insertions(+), 0 deletions(-)
diff --git a/library/cpp/neh/README.md b/library/cpp/neh/README.md new file mode 100644 index 0000000000..962b3f67bc --- /dev/null +++ b/library/cpp/neh/README.md @@ -0,0 +1,15 @@ +Транспортная библиотека neh +=========================== + +Обеспечивает простой интерфейс для осуществления запросов по сети (request/response - client/server). Обеспечивает лёгкую смену транспортного протокола. +Есть несколько реализаций транспорта, каждая со своими плюсами/минусами. + +Документация +============ +https://wiki.yandex-team.ru/development/poisk/arcadia/library/neh/ + +FAQ +=== +Q: А давайте прикрутим SSL (поддержку https)! +A: ~~Этого не будет. neh - низкоуровневая шина, там не место ssl. Подробнее тут: https://clubs.at.yandex-team.ru/stackoverflow/5634~~ +A: Сделали diff --git a/library/cpp/neh/asio/asio.cpp b/library/cpp/neh/asio/asio.cpp new file mode 100644 index 0000000000..8b6cf383ea --- /dev/null +++ b/library/cpp/neh/asio/asio.cpp @@ -0,0 +1,187 @@ +#include "io_service_impl.h" +#include "deadline_timer_impl.h" +#include "tcp_socket_impl.h" +#include "tcp_acceptor_impl.h" + +using namespace NDns; +using namespace NAsio; + +namespace NAsio { + TIOService::TWork::TWork(TWork& w) + : Srv_(w.Srv_) + { + Srv_.GetImpl().WorkStarted(); + } + + TIOService::TWork::TWork(TIOService& srv) + : Srv_(srv) + { + Srv_.GetImpl().WorkStarted(); + } + + TIOService::TWork::~TWork() { + Srv_.GetImpl().WorkFinished(); + } + + TIOService::TIOService() + : Impl_(new TImpl()) + { + } + + TIOService::~TIOService() { + } + + void TIOService::Run() { + Impl_->Run(); + } + + void TIOService::Post(TCompletionHandler h) { + Impl_->Post(std::move(h)); + } + + void TIOService::Abort() { + Impl_->Abort(); + } + + TDeadlineTimer::TDeadlineTimer(TIOService& srv) noexcept + : Srv_(srv) + , Impl_(nullptr) + { + } + + TDeadlineTimer::~TDeadlineTimer() { + if (Impl_) { + Srv_.GetImpl().ScheduleOp(new TUnregisterTimerOperation(Impl_)); + } + } + + void TDeadlineTimer::AsyncWaitExpireAt(TDeadline deadline, THandler h) { + if (!Impl_) { + Impl_ = new TDeadlineTimer::TImpl(Srv_.GetImpl()); + Srv_.GetImpl().ScheduleOp(new TRegisterTimerOperation(Impl_)); + } + Impl_->AsyncWaitExpireAt(deadline, h); + } + + void TDeadlineTimer::Cancel() { + Impl_->Cancel(); + } + + TTcpSocket::TTcpSocket(TIOService& srv) noexcept + : Srv_(srv) + , Impl_(new TImpl(srv.GetImpl())) + { + } + + TTcpSocket::~TTcpSocket() { + } + + void TTcpSocket::AsyncConnect(const TEndpoint& ep, TTcpSocket::TConnectHandler h, TDeadline deadline) { + Impl_->AsyncConnect(ep, h, deadline); + } + + void TTcpSocket::AsyncWrite(TSendedData& d, TTcpSocket::TWriteHandler h, TDeadline deadline) { + Impl_->AsyncWrite(d, h, deadline); + } + + void TTcpSocket::AsyncWrite(TContIOVector* vec, TWriteHandler h, TDeadline deadline) { + Impl_->AsyncWrite(vec, h, deadline); + } + + void TTcpSocket::AsyncWrite(const void* data, size_t size, TWriteHandler h, TDeadline deadline) { + class TBuffers: public IBuffers { + public: + TBuffers(const void* theData, size_t theSize) + : Part(theData, theSize) + , IOVec(&Part, 1) + { + } + + TContIOVector* GetIOvec() override { + return &IOVec; + } + + IOutputStream::TPart Part; + TContIOVector IOVec; + }; + + TSendedData d(new TBuffers(data, size)); + Impl_->AsyncWrite(d, h, deadline); + } + + void TTcpSocket::AsyncRead(void* buff, size_t size, TTcpSocket::TReadHandler h, TDeadline deadline) { + Impl_->AsyncRead(buff, size, h, deadline); + } + + void TTcpSocket::AsyncReadSome(void* buff, size_t size, TTcpSocket::TReadHandler h, TDeadline deadline) { 
+ Impl_->AsyncReadSome(buff, size, h, deadline); + } + + void TTcpSocket::AsyncPollRead(TTcpSocket::TPollHandler h, TDeadline deadline) { + Impl_->AsyncPollRead(h, deadline); + } + + void TTcpSocket::AsyncPollWrite(TTcpSocket::TPollHandler h, TDeadline deadline) { + Impl_->AsyncPollWrite(h, deadline); + } + + void TTcpSocket::AsyncCancel() { + return Impl_->AsyncCancel(); + } + + size_t TTcpSocket::WriteSome(TContIOVector& d, TErrorCode& ec) noexcept { + return Impl_->WriteSome(d, ec); + } + + size_t TTcpSocket::WriteSome(const void* buff, size_t size, TErrorCode& ec) noexcept { + return Impl_->WriteSome(buff, size, ec); + } + + size_t TTcpSocket::ReadSome(void* buff, size_t size, TErrorCode& ec) noexcept { + return Impl_->ReadSome(buff, size, ec); + } + + bool TTcpSocket::IsOpen() const noexcept { + return Native() != INVALID_SOCKET; + } + + void TTcpSocket::Shutdown(TShutdownMode what, TErrorCode& ec) { + return Impl_->Shutdown(what, ec); + } + + SOCKET TTcpSocket::Native() const noexcept { + return Impl_->Fd(); + } + + TEndpoint TTcpSocket::RemoteEndpoint() const { + return Impl_->RemoteEndpoint(); + } + + ////////////////////////////////// + + TTcpAcceptor::TTcpAcceptor(TIOService& srv) noexcept + : Srv_(srv) + , Impl_(new TImpl(srv.GetImpl())) + { + } + + TTcpAcceptor::~TTcpAcceptor() { + } + + void TTcpAcceptor::Bind(TEndpoint& ep, TErrorCode& ec) noexcept { + return Impl_->Bind(ep, ec); + } + + void TTcpAcceptor::Listen(int backlog, TErrorCode& ec) noexcept { + return Impl_->Listen(backlog, ec); + } + + void TTcpAcceptor::AsyncAccept(TTcpSocket& s, TTcpAcceptor::TAcceptHandler h, TDeadline deadline) { + return Impl_->AsyncAccept(s, h, deadline); + } + + void TTcpAcceptor::AsyncCancel() { + Impl_->AsyncCancel(); + } + +} diff --git a/library/cpp/neh/asio/asio.h b/library/cpp/neh/asio/asio.h new file mode 100644 index 0000000000..a902d663cf --- /dev/null +++ b/library/cpp/neh/asio/asio.h @@ -0,0 +1,280 @@ +#pragma once + +// +//primary header for work with asio +// + +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/generic/vector.h> +#include <util/network/socket.h> +#include <util/network/endpoint.h> +#include <util/system/error.h> +#include <util/stream/output.h> +#include <functional> + +#include <library/cpp/dns/cache.h> + +//#define DEBUG_ASIO + +class TContIOVector; + +namespace NAsio { + class TErrorCode { + public: + inline TErrorCode(int val = 0) noexcept + : Val_(val) + { + } + + typedef void (*TUnspecifiedBoolType)(); + + static void UnspecifiedBoolTrue() { + } + + //safe cast to bool value + operator TUnspecifiedBoolType() const noexcept { // true if error + return Val_ == 0 ? 
nullptr : UnspecifiedBoolTrue; + } + + bool operator!() const noexcept { + return Val_ == 0; + } + + void Assign(int val) noexcept { + Val_ = val; + } + + int Value() const noexcept { + return Val_; + } + + TString Text() const { + if (!Val_) { + return TString(); + } + return LastSystemErrorText(Val_); + } + + void Check() { + if (Val_) { + throw TSystemError(Val_); + } + } + + private: + int Val_; + }; + + //wrapper for TInstant, for enabling use TDuration (+TInstant::Now()) as deadline + class TDeadline: public TInstant { + public: + TDeadline() + : TInstant(TInstant::Max()) + { + } + + TDeadline(const TInstant& t) + : TInstant(t) + { + } + + TDeadline(const TDuration& d) + : TInstant(TInstant::Now() + d) + { + } + }; + + class IHandlingContext { + public: + virtual ~IHandlingContext() { + } + + //if handler throw exception, call this function be ignored + virtual void ContinueUseHandler(TDeadline deadline = TDeadline()) = 0; + }; + + typedef std::function<void()> TCompletionHandler; + + class TIOService: public TNonCopyable { + public: + TIOService(); + ~TIOService(); + + void Run(); + void Post(TCompletionHandler); //call handler in Run() thread-executor + void Abort(); //in Run() all exist async i/o operations + timers receive error = ECANCELED, Run() exited + + //counterpart boost::asio::io_service::work + class TWork { + public: + TWork(TWork&); + TWork(TIOService&); + ~TWork(); + + private: + void operator=(const TWork&); //disable + + TIOService& Srv_; + }; + + class TImpl; + + TImpl& GetImpl() noexcept { + return *Impl_; + } + + private: + THolder<TImpl> Impl_; + }; + + class TDeadlineTimer: public TNonCopyable { + public: + typedef std::function<void(const TErrorCode& err, IHandlingContext&)> THandler; + + TDeadlineTimer(TIOService&) noexcept; + ~TDeadlineTimer(); + + void AsyncWaitExpireAt(TDeadline, THandler); + void Cancel(); + + TIOService& GetIOService() const noexcept { + return Srv_; + } + + class TImpl; + + private: + TIOService& Srv_; + TImpl* Impl_; + }; + + class TTcpSocket: public TNonCopyable { + public: + class IBuffers { + public: + virtual ~IBuffers() { + } + virtual TContIOVector* GetIOvec() = 0; + }; + typedef TAutoPtr<IBuffers> TSendedData; + + typedef std::function<void(const TErrorCode& err, IHandlingContext&)> THandler; + typedef THandler TConnectHandler; + typedef std::function<void(const TErrorCode& err, size_t amount, IHandlingContext&)> TWriteHandler; + typedef std::function<void(const TErrorCode& err, size_t amount, IHandlingContext&)> TReadHandler; + typedef THandler TPollHandler; + + enum TShutdownMode { + ShutdownReceive = SHUT_RD, + ShutdownSend = SHUT_WR, + ShutdownBoth = SHUT_RDWR + }; + + TTcpSocket(TIOService&) noexcept; + ~TTcpSocket(); + + void AsyncConnect(const TEndpoint& ep, TConnectHandler, TDeadline deadline = TDeadline()); + void AsyncWrite(TSendedData&, TWriteHandler, TDeadline deadline = TDeadline()); + void AsyncWrite(TContIOVector* buff, TWriteHandler, TDeadline deadline = TDeadline()); + void AsyncWrite(const void* buff, size_t size, TWriteHandler, TDeadline deadline = TDeadline()); + void AsyncRead(void* buff, size_t size, TReadHandler, TDeadline deadline = TDeadline()); + void AsyncReadSome(void* buff, size_t size, TReadHandler, TDeadline deadline = TDeadline()); + void AsyncPollWrite(TPollHandler, TDeadline deadline = TDeadline()); + void AsyncPollRead(TPollHandler, TDeadline deadline = TDeadline()); + void AsyncCancel(); + + //sync, but non blocked methods + size_t WriteSome(TContIOVector&, TErrorCode&) noexcept; + size_t 
WriteSome(const void* buff, size_t size, TErrorCode&) noexcept; + size_t ReadSome(void* buff, size_t size, TErrorCode&) noexcept; + + bool IsOpen() const noexcept; + void Shutdown(TShutdownMode mode, TErrorCode& ec); + + TIOService& GetIOService() const noexcept { + return Srv_; + } + + SOCKET Native() const noexcept; + + TEndpoint RemoteEndpoint() const; + + inline size_t WriteSome(TContIOVector& v) { + TErrorCode ec; + size_t n = WriteSome(v, ec); + ec.Check(); + return n; + } + + inline size_t WriteSome(const void* buff, size_t size) { + TErrorCode ec; + size_t n = WriteSome(buff, size, ec); + ec.Check(); + return n; + } + + inline size_t ReadSome(void* buff, size_t size) { + TErrorCode ec; + size_t n = ReadSome(buff, size, ec); + ec.Check(); + return n; + } + + void Shutdown(TShutdownMode mode) { + TErrorCode ec; + Shutdown(mode, ec); + ec.Check(); + } + + class TImpl; + + TImpl& GetImpl() const noexcept { + return *Impl_; + } + + private: + TIOService& Srv_; + TIntrusivePtr<TImpl> Impl_; + }; + + class TTcpAcceptor: public TNonCopyable { + public: + typedef std::function<void(const TErrorCode& err, IHandlingContext&)> TAcceptHandler; + + TTcpAcceptor(TIOService&) noexcept; + ~TTcpAcceptor(); + + void Bind(TEndpoint&, TErrorCode&) noexcept; + void Listen(int backlog, TErrorCode&) noexcept; + + void AsyncAccept(TTcpSocket&, TAcceptHandler, TDeadline deadline = TDeadline()); + + void AsyncCancel(); + + inline void Bind(TEndpoint& ep) { + TErrorCode ec; + Bind(ep, ec); + ec.Check(); + } + inline void Listen(int backlog) { + TErrorCode ec; + Listen(backlog, ec); + ec.Check(); + } + + TIOService& GetIOService() const noexcept { + return Srv_; + } + + class TImpl; + + TImpl& GetImpl() const noexcept { + return *Impl_; + } + + private: + TIOService& Srv_; + TIntrusivePtr<TImpl> Impl_; + }; +} diff --git a/library/cpp/neh/asio/deadline_timer_impl.cpp b/library/cpp/neh/asio/deadline_timer_impl.cpp new file mode 100644 index 0000000000..399a4338fb --- /dev/null +++ b/library/cpp/neh/asio/deadline_timer_impl.cpp @@ -0,0 +1 @@ +#include "deadline_timer_impl.h" diff --git a/library/cpp/neh/asio/deadline_timer_impl.h b/library/cpp/neh/asio/deadline_timer_impl.h new file mode 100644 index 0000000000..d9db625c94 --- /dev/null +++ b/library/cpp/neh/asio/deadline_timer_impl.h @@ -0,0 +1,110 @@ +#pragma once + +#include "io_service_impl.h" + +namespace NAsio { + class TTimerOperation: public TOperation { + public: + TTimerOperation(TIOService::TImpl::TTimer* t, TInstant deadline) + : TOperation(deadline) + , T_(t) + { + } + + void AddOp(TIOService::TImpl&) override { + Y_ASSERT(0); + } + + void Finalize() override { + DBGOUT("TTimerDeadlineOperation::Finalize()"); + T_->DelOp(this); + } + + protected: + TIOService::TImpl::TTimer* T_; + }; + + class TRegisterTimerOperation: public TTimerOperation { + public: + TRegisterTimerOperation(TIOService::TImpl::TTimer* t, TInstant deadline = TInstant::Max()) + : TTimerOperation(t, deadline) + { + Speculative_ = true; + } + + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + T_->GetIOServiceImpl().SyncRegisterTimer(T_); + return true; + } + }; + + class TTimerDeadlineOperation: public TTimerOperation { + public: + TTimerDeadlineOperation(TIOService::TImpl::TTimer* t, TDeadlineTimer::THandler h, TInstant deadline) + : TTimerOperation(t, deadline) + , H_(h) + { + } + + void AddOp(TIOService::TImpl&) override { + T_->AddOp(this); + } + + bool Execute(int errorCode) override { + DBGOUT("TTimerDeadlineOperation::Execute(" << errorCode << ")"); + 
H_(errorCode == ETIMEDOUT ? 0 : errorCode, *this); + return true; + } + + private: + TDeadlineTimer::THandler H_; + }; + + class TCancelTimerOperation: public TTimerOperation { + public: + TCancelTimerOperation(TIOService::TImpl::TTimer* t) + : TTimerOperation(t, TInstant::Max()) + { + Speculative_ = true; + } + + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + T_->FailOperations(ECANCELED); + return true; + } + }; + + class TUnregisterTimerOperation: public TTimerOperation { + public: + TUnregisterTimerOperation(TIOService::TImpl::TTimer* t, TInstant deadline = TInstant::Max()) + : TTimerOperation(t, deadline) + { + Speculative_ = true; + } + + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + DBGOUT("TUnregisterTimerOperation::Execute(" << errorCode << ")"); + T_->GetIOServiceImpl().SyncUnregisterAndDestroyTimer(T_); + return true; + } + }; + + class TDeadlineTimer::TImpl: public TIOService::TImpl::TTimer { + public: + TImpl(TIOService::TImpl& srv) + : TIOService::TImpl::TTimer(srv) + { + } + + void AsyncWaitExpireAt(TDeadline d, TDeadlineTimer::THandler h) { + Srv_.ScheduleOp(new TTimerDeadlineOperation(this, h, d)); + } + + void Cancel() { + Srv_.ScheduleOp(new TCancelTimerOperation(this)); + } + }; +} diff --git a/library/cpp/neh/asio/executor.cpp b/library/cpp/neh/asio/executor.cpp new file mode 100644 index 0000000000..03b26bf847 --- /dev/null +++ b/library/cpp/neh/asio/executor.cpp @@ -0,0 +1 @@ +#include "executor.h" diff --git a/library/cpp/neh/asio/executor.h b/library/cpp/neh/asio/executor.h new file mode 100644 index 0000000000..4f6549044d --- /dev/null +++ b/library/cpp/neh/asio/executor.h @@ -0,0 +1,76 @@ +#pragma once + +#include "asio.h" + +#include <library/cpp/deprecated/atomic/atomic.h> + +#include <util/thread/factory.h> +#include <util/system/thread.h> + +namespace NAsio { + class TIOServiceExecutor: public IThreadFactory::IThreadAble { + public: + TIOServiceExecutor() + : Work_(new TIOService::TWork(Srv_)) + { + T_ = SystemThreadFactory()->Run(this); + } + + ~TIOServiceExecutor() override { + SyncShutdown(); + } + + void DoExecute() override { + TThread::SetCurrentThreadName("NehAsioExecutor"); + Srv_.Run(); + } + + inline TIOService& GetIOService() noexcept { + return Srv_; + } + + void SyncShutdown() { + if (Work_) { + Work_.Destroy(); + Srv_.Abort(); //cancel all async operations, break Run() execution + T_->Join(); + } + } + + private: + TIOService Srv_; + TAutoPtr<TIOService::TWork> Work_; + typedef TAutoPtr<IThreadFactory::IThread> IThreadRef; + IThreadRef T_; + }; + + class TExecutorsPool { + public: + TExecutorsPool(size_t executors) + : C_(0) + { + for (size_t i = 0; i < executors; ++i) { + E_.push_back(new TIOServiceExecutor()); + } + } + + inline size_t Size() const noexcept { + return E_.size(); + } + + inline TIOServiceExecutor& GetExecutor() noexcept { + TAtomicBase next = AtomicIncrement(C_); + return *E_[next % E_.size()]; + } + + void SyncShutdown() { + for (size_t i = 0; i < E_.size(); ++i) { + E_[i]->SyncShutdown(); + } + } + + private: + TAtomic C_; + TVector<TAutoPtr<TIOServiceExecutor>> E_; + }; +} diff --git a/library/cpp/neh/asio/io_service_impl.cpp b/library/cpp/neh/asio/io_service_impl.cpp new file mode 100644 index 0000000000..d49b3fb03e --- /dev/null +++ b/library/cpp/neh/asio/io_service_impl.cpp @@ -0,0 +1,161 @@ +#include "io_service_impl.h" + +#include <library/cpp/coroutine/engine/poller.h> + +using namespace NAsio; + +void TFdOperation::AddOp(TIOService::TImpl& srv) { + srv.AddOp(this); +} + +void 
TFdOperation::Finalize() { + (*PH_)->DelOp(this); +} + +void TPollFdEventHandler::ExecuteOperations(TFdOperations& oprs, int errorCode) { + TFdOperations::iterator it = oprs.begin(); + + try { + while (it != oprs.end()) { + TFdOperation* op = it->Get(); + + if (op->Execute(errorCode)) { // throw ? + if (op->IsRequiredRepeat()) { + Srv_.UpdateOpDeadline(op); + ++it; //operation completed, but want be repeated + } else { + FinishedOperations_.push_back(*it); + it = oprs.erase(it); + } + } else { + ++it; //operation not completed + } + } + } catch (...) { + if (it != oprs.end()) { + FinishedOperations_.push_back(*it); + oprs.erase(it); + } + throw; + } +} + +void TPollFdEventHandler::DelOp(TFdOperation* op) { + TAutoPtr<TPollFdEventHandler>& evh = *op->PH_; + + if (op->IsPollRead()) { + Y_ASSERT(FinishOp(ReadOperations_, op)); + } else { + Y_ASSERT(FinishOp(WriteOperations_, op)); + } + Srv_.FixHandledEvents(evh); //alarm, - 'this' can be destroyed here! +} + +void TInterrupterHandler::OnFdEvent(int status, ui16 filter) { + if (!status && (filter & CONT_POLL_READ)) { + PI_.Reset(); + } +} + +void TIOService::TImpl::Run() { + TEvh& iEvh = Evh_.Get(I_.Fd()); + iEvh.Reset(new TInterrupterHandler(*this, I_)); + + TInterrupterKeeper ik(*this, iEvh); + Y_UNUSED(ik); + IPollerFace::TEvents evs; + AtomicSet(NeedCheckOpQueue_, 1); + TInstant deadline; + + while (Y_LIKELY(!Aborted_ && (AtomicGet(OutstandingWork_) || FdEventHandlersCnt_ > 1 || TimersOpCnt_ || AtomicGet(NeedCheckOpQueue_)))) { + //while + // expected work (external flag) + // or have event handlers (exclude interrupter) + // or have not completed timer operation + // or have any operation in queues + + AtomicIncrement(IsWaiting_); + if (!AtomicGet(NeedCheckOpQueue_)) { + P_->Wait(evs, deadline); + } + AtomicDecrement(IsWaiting_); + + if (evs.size()) { + for (IPollerFace::TEvents::const_iterator iev = evs.begin(); iev != evs.end() && !Aborted_; ++iev) { + const IPollerFace::TEvent& ev = *iev; + TEvh& evh = *(TEvh*)ev.Data; + + if (!evh) { + continue; //op. cancel (see ProcessOpQueue) can destroy evh + } + + int status = ev.Status; + if (ev.Status == EIO) { + int error = status; + if (GetSockOpt(evh->Fd(), SOL_SOCKET, SO_ERROR, error) == 0) { + status = error; + } + } + + OnFdEvent(evh, status, ev.Filter); //here handle fd events + //immediatly after handling events for one descriptor check op. 
queue + //often queue can contain another operation for this fd (next async read as sample) + //so we can optimize redundant epoll_ctl (or similar) calls + ProcessOpQueue(); + } + + evs.clear(); + } else { + ProcessOpQueue(); + } + + deadline = DeadlinesQueue_.NextDeadline(); //here handle timeouts/process timers + } +} + +void TIOService::TImpl::Abort() { + class TAbortOperation: public TNoneOperation { + public: + TAbortOperation(TIOService::TImpl& srv) + : TNoneOperation() + , Srv_(srv) + { + Speculative_ = true; + } + + private: + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + Srv_.ProcessAbort(); + return true; + } + + TIOService::TImpl& Srv_; + }; + AtomicSet(HasAbort_, 1); + ScheduleOp(new TAbortOperation(*this)); +} + +void TIOService::TImpl::ProcessAbort() { + Aborted_ = true; + + for (int fd = 0; fd <= MaxFd_; ++fd) { + TEvh& evh = Evh_.Get(fd); + if (!!evh && evh->Fd() != I_.Fd()) { + OnFdEvent(evh, ECANCELED, CONT_POLL_READ | CONT_POLL_WRITE); + } + } + + for (auto t : Timers_) { + t->FailOperations(ECANCELED); + } + + TOperationPtr op; + while (OpQueue_.Dequeue(&op)) { //cancel all enqueued operations + try { + op->Execute(ECANCELED); + } catch (...) { + } + op.Destroy(); + } +} diff --git a/library/cpp/neh/asio/io_service_impl.h b/library/cpp/neh/asio/io_service_impl.h new file mode 100644 index 0000000000..46fa9f9ee1 --- /dev/null +++ b/library/cpp/neh/asio/io_service_impl.h @@ -0,0 +1,744 @@ +#pragma once + +#include "asio.h" +#include "poll_interrupter.h" + +#include <library/cpp/neh/lfqueue.h> +#include <library/cpp/neh/pipequeue.h> + +#include <library/cpp/dns/cache.h> + +#include <util/generic/hash_set.h> +#include <util/network/iovec.h> +#include <util/network/pollerimpl.h> +#include <util/thread/lfqueue.h> +#include <util/thread/factory.h> + +#ifdef DEBUG_ASIO +#define DBGOUT(args) Cout << args << Endl; +#else +#define DBGOUT(args) +#endif + +namespace NAsio { + //TODO: copypaste from neh, - need fix + template <class T> + class TLockFreeSequence { + public: + inline TLockFreeSequence() { + memset((void*)T_, 0, sizeof(T_)); + } + + inline ~TLockFreeSequence() { + for (size_t i = 0; i < Y_ARRAY_SIZE(T_); ++i) { + delete[] T_[i]; + } + } + + inline T& Get(size_t n) { + const size_t i = GetValueBitCount(n + 1) - 1; + + return GetList(i)[n + 1 - (((size_t)1) << i)]; + } + + private: + inline T* GetList(size_t n) { + T* volatile* t = T_ + n; + + while (!*t) { + TArrayHolder<T> nt(new T[((size_t)1) << n]); + + if (AtomicCas(t, nt.Get(), nullptr)) { + return nt.Release(); + } + } + + return *t; + } + + private: + T* volatile T_[sizeof(size_t) * 8]; + }; + + struct TOperationCompare { + template <class T> + static inline bool Compare(const T& l, const T& r) noexcept { + return l.DeadLine() < r.DeadLine() || (l.DeadLine() == r.DeadLine() && &l < &r); + } + }; + + //async operation, execute in contex TIOService()::Run() thread-executor + //usualy used for call functors/callbacks + class TOperation: public TRbTreeItem<TOperation, TOperationCompare>, public IHandlingContext { + public: + TOperation(TInstant deadline = TInstant::Max()) + : D_(deadline) + , Speculative_(false) + , RequiredRepeatExecution_(false) + , ND_(deadline) + { + } + + //register this operation in svc.impl. 
+ virtual void AddOp(TIOService::TImpl&) = 0; + + //return false, if operation not completed + virtual bool Execute(int errorCode = 0) = 0; + + void ContinueUseHandler(TDeadline deadline) override { + RequiredRepeatExecution_ = true; + ND_ = deadline; + } + + virtual void Finalize() = 0; + + inline TInstant Deadline() const noexcept { + return D_; + } + + inline TInstant DeadLine() const noexcept { + return D_; + } + + inline bool Speculative() const noexcept { + return Speculative_; + } + + inline bool IsRequiredRepeat() const noexcept { + return RequiredRepeatExecution_; + } + + inline void PrepareReExecution() noexcept { + RequiredRepeatExecution_ = false; + D_ = ND_; + } + + protected: + TInstant D_; + bool Speculative_; //if true, operation will be runned immediately after dequeue (even without wating any event) + //as sample used for optimisation writing, - obviously in buffers exist space for write + bool RequiredRepeatExecution_; //set to true, if required re-exec operation + TInstant ND_; //new deadline (for re-exec operation) + }; + + typedef TAutoPtr<TOperation> TOperationPtr; + + class TNoneOperation: public TOperation { + public: + TNoneOperation(TInstant deadline = TInstant::Max()) + : TOperation(deadline) + { + } + + void AddOp(TIOService::TImpl&) override { + Y_ASSERT(0); + } + + void Finalize() override { + } + }; + + class TPollFdEventHandler; + + //descriptor use operation + class TFdOperation: public TOperation { + public: + enum TPollType { + PollRead, + PollWrite + }; + + TFdOperation(SOCKET fd, TPollType pt, TInstant deadline = TInstant::Max()) + : TOperation(deadline) + , Fd_(fd) + , PT_(pt) + , PH_(nullptr) + { + Y_ASSERT(Fd() != INVALID_SOCKET); + } + + inline SOCKET Fd() const noexcept { + return Fd_; + } + + inline bool IsPollRead() const noexcept { + return PT_ == PollRead; + } + + void AddOp(TIOService::TImpl& srv) override; + + void Finalize() override; + + protected: + SOCKET Fd_; + TPollType PT_; + + public: + TAutoPtr<TPollFdEventHandler>* PH_; + }; + + typedef TAutoPtr<TFdOperation> TFdOperationPtr; + + class TPollFdEventHandler { + public: + TPollFdEventHandler(SOCKET fd, TIOService::TImpl& srv) + : Fd_(fd) + , HandledEvents_(0) + , Srv_(srv) + { + } + + virtual ~TPollFdEventHandler() { + Y_ASSERT(ReadOperations_.size() == 0); + Y_ASSERT(WriteOperations_.size() == 0); + } + + inline void AddReadOp(TFdOperationPtr op) { + ReadOperations_.push_back(op); + } + + inline void AddWriteOp(TFdOperationPtr op) { + WriteOperations_.push_back(op); + } + + virtual void OnFdEvent(int status, ui16 filter) { + DBGOUT("PollEvent(fd=" << Fd_ << ", " << status << ", " << filter << ")"); + if (status) { + ExecuteOperations(ReadOperations_, status); + ExecuteOperations(WriteOperations_, status); + } else { + if (filter & CONT_POLL_READ) { + ExecuteOperations(ReadOperations_, status); + } + if (filter & CONT_POLL_WRITE) { + ExecuteOperations(WriteOperations_, status); + } + } + } + + typedef TVector<TFdOperationPtr> TFdOperations; + + void ExecuteOperations(TFdOperations& oprs, int errorCode); + + //return true if filter handled events changed and require re-configure events poller + virtual bool FixHandledEvents() noexcept { + DBGOUT("TPollFdEventHandler::FixHandledEvents()"); + ui16 filter = 0; + + if (WriteOperations_.size()) { + filter |= CONT_POLL_WRITE; + } + if (ReadOperations_.size()) { + filter |= CONT_POLL_READ; + } + + if (Y_LIKELY(HandledEvents_ == filter)) { + return false; + } + + HandledEvents_ = filter; + return true; + } + + inline bool 
FinishOp(TFdOperations& oprs, TFdOperation* op) noexcept { + for (TFdOperations::iterator it = oprs.begin(); it != oprs.end(); ++it) { + if (it->Get() == op) { + FinishedOperations_.push_back(*it); + oprs.erase(it); + return true; + } + } + return false; + } + + void DelOp(TFdOperation* op); + + inline SOCKET Fd() const noexcept { + return Fd_; + } + + inline ui16 HandledEvents() const noexcept { + return HandledEvents_; + } + + inline void AddHandlingEvent(ui16 ev) noexcept { + HandledEvents_ |= ev; + } + + inline void DestroyFinishedOperations() { + FinishedOperations_.clear(); + } + + TIOService::TImpl& GetServiceImpl() const noexcept { + return Srv_; + } + + protected: + SOCKET Fd_; + ui16 HandledEvents_; + TIOService::TImpl& Srv_; + + private: + TVector<TFdOperationPtr> ReadOperations_; + TVector<TFdOperationPtr> WriteOperations_; + // we can't immediatly destroy finished operations, this can cause closing used socket descriptor Fd_ + // (on cascade deletion operation object-handler), but later we use Fd_ for modify handled events at poller, + // so we collect here finished operations and destroy it only after update poller, - + // call FixHandledEvents(TPollFdEventHandlerPtr&) + TVector<TFdOperationPtr> FinishedOperations_; + }; + + //additional descriptor for poller, used for interrupt current poll wait + class TInterrupterHandler: public TPollFdEventHandler { + public: + TInterrupterHandler(TIOService::TImpl& srv, TPollInterrupter& pi) + : TPollFdEventHandler(pi.Fd(), srv) + , PI_(pi) + { + HandledEvents_ = CONT_POLL_READ; + } + + ~TInterrupterHandler() override { + DBGOUT("~TInterrupterHandler"); + } + + void OnFdEvent(int status, ui16 filter) override; + + bool FixHandledEvents() noexcept override { + DBGOUT("TInterrupterHandler::FixHandledEvents()"); + return false; + } + + private: + TPollInterrupter& PI_; + }; + + namespace { + inline TAutoPtr<IPollerFace> CreatePoller() { + try { +#if defined(_linux_) + return IPollerFace::Construct(TStringBuf("epoll")); +#endif +#if defined(_freebsd_) || defined(_darwin_) + return IPollerFace::Construct(TStringBuf("kqueue")); +#endif + } catch (...) { + Cdbg << CurrentExceptionMessage() << Endl; + } + return IPollerFace::Default(); + } + } + + //some equivalent TContExecutor + class TIOService::TImpl: public TNonCopyable { + public: + typedef TAutoPtr<TPollFdEventHandler> TEvh; + typedef TLockFreeSequence<TEvh> TEventHandlers; + + class TTimer { + public: + typedef THashSet<TOperation*> TOperations; + + TTimer(TIOService::TImpl& srv) + : Srv_(srv) + { + } + + virtual ~TTimer() { + FailOperations(ECANCELED); + } + + void AddOp(TOperation* op) { + THolder<TOperation> tmp(op); + Operations_.insert(op); + Y_UNUSED(tmp.Release()); + Srv_.RegisterOpDeadline(op); + Srv_.IncTimersOp(); + } + + void DelOp(TOperation* op) { + TOperations::iterator it = Operations_.find(op); + if (it != Operations_.end()) { + Srv_.DecTimersOp(); + delete op; + Operations_.erase(it); + } + } + + inline void FailOperations(int ec) { + for (auto operation : Operations_) { + try { + operation->Execute(ec); //throw ? + } catch (...) 
{ + } + Srv_.DecTimersOp(); + delete operation; + } + Operations_.clear(); + } + + TIOService::TImpl& GetIOServiceImpl() const noexcept { + return Srv_; + } + + protected: + TIOService::TImpl& Srv_; + THashSet<TOperation*> Operations_; + }; + + class TTimers: public THashSet<TTimer*> { + public: + ~TTimers() { + for (auto it : *this) { + delete it; + } + } + }; + + TImpl() + : P_(CreatePoller()) + , DeadlinesQueue_(*this) + { + } + + ~TImpl() { + TOperationPtr op; + + while (OpQueue_.Dequeue(&op)) { //cancel all enqueued operations + try { + op->Execute(ECANCELED); + } catch (...) { + } + op.Destroy(); + } + } + + //similar TContExecutor::Execute() or io_service::run() + //process event loop (exit if none to do (no timers or event handlers)) + void Run(); + + //enqueue functor fo call in Run() eventloop (thread safing) + inline void Post(TCompletionHandler h) { + class TFuncOperation: public TNoneOperation { + public: + TFuncOperation(TCompletionHandler completionHandler) + : TNoneOperation() + , H_(std::move(completionHandler)) + { + Speculative_ = true; + } + + private: + //return false, if operation not completed + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + H_(); + return true; + } + + TCompletionHandler H_; + }; + + ScheduleOp(new TFuncOperation(std::move(h))); + } + + //cancel all current operations (handlers be called with errorCode == ECANCELED) + void Abort(); + bool HasAbort() { + return AtomicGet(HasAbort_); + } + + inline void ScheduleOp(TOperationPtr op) { //throw std::bad_alloc + Y_ASSERT(!Aborted_); + Y_ASSERT(!!op); + OpQueue_.Enqueue(op); + Interrupt(); + } + + inline void Interrupt() noexcept { + AtomicSet(NeedCheckOpQueue_, 1); + if (AtomicAdd(IsWaiting_, 0) == 1) { + I_.Interrupt(); + } + } + + inline void UpdateOpDeadline(TOperation* op) { + TInstant oldDeadline = op->Deadline(); + op->PrepareReExecution(); + + if (oldDeadline == op->Deadline()) { + return; + } + + if (oldDeadline != TInstant::Max()) { + op->UnLink(); + } + if (op->Deadline() != TInstant::Max()) { + DeadlinesQueue_.Register(op); + } + } + + void SyncRegisterTimer(TTimer* t) { + Timers_.insert(t); + } + + inline void SyncUnregisterAndDestroyTimer(TTimer* t) { + Timers_.erase(t); + delete t; + } + + inline void IncTimersOp() noexcept { + ++TimersOpCnt_; + } + + inline void DecTimersOp() noexcept { + --TimersOpCnt_; + } + + inline void WorkStarted() { + AtomicIncrement(OutstandingWork_); + } + + inline void WorkFinished() { + if (AtomicDecrement(OutstandingWork_) == 0) { + Interrupt(); + } + } + + private: + void ProcessAbort(); + + inline TEvh& EnsureGetEvh(SOCKET fd) { + TEvh& evh = Evh_.Get(fd); + if (!evh) { + evh.Reset(new TPollFdEventHandler(fd, *this)); + } + return evh; + } + + inline void OnTimeoutOp(TOperation* op) { + DBGOUT("OnTimeoutOp"); + try { + op->Execute(ETIMEDOUT); //throw ? + } catch (...) 
{ + op->Finalize(); + throw; + } + + if (op->IsRequiredRepeat()) { + //operation not completed + UpdateOpDeadline(op); + } else { + //destroy operation structure + op->Finalize(); + } + } + + public: + inline void FixHandledEvents(TEvh& evh) { + if (!!evh) { + if (evh->FixHandledEvents()) { + if (!evh->HandledEvents()) { + DelEventHandler(evh); + evh.Destroy(); + } else { + ModEventHandler(evh); + evh->DestroyFinishedOperations(); + } + } else { + evh->DestroyFinishedOperations(); + } + } + } + + private: + inline TEvh& GetHandlerForOp(TFdOperation* op) { + TEvh& evh = EnsureGetEvh(op->Fd()); + op->PH_ = &evh; + return evh; + } + + void ProcessOpQueue() { + if (!AtomicGet(NeedCheckOpQueue_)) { + return; + } + AtomicSet(NeedCheckOpQueue_, 0); + + TOperationPtr op; + + while (OpQueue_.Dequeue(&op)) { + if (op->Speculative()) { + if (op->Execute(Y_UNLIKELY(Aborted_) ? ECANCELED : 0)) { + op.Destroy(); + continue; //operation completed + } + + if (!op->IsRequiredRepeat()) { + op->PrepareReExecution(); + } + } + RegisterOpDeadline(op.Get()); + op.Get()->AddOp(*this); // ... -> AddOp() + Y_UNUSED(op.Release()); + } + } + + inline void RegisterOpDeadline(TOperation* op) { + if (op->DeadLine() != TInstant::Max()) { + DeadlinesQueue_.Register(op); + } + } + + public: + inline void AddOp(TFdOperation* op) { + DBGOUT("AddOp<Fd>(" << op->Fd() << ")"); + TEvh& evh = GetHandlerForOp(op); + if (op->IsPollRead()) { + evh->AddReadOp(op); + EnsureEventHandled(evh, CONT_POLL_READ); + } else { + evh->AddWriteOp(op); + EnsureEventHandled(evh, CONT_POLL_WRITE); + } + } + + private: + inline void EnsureEventHandled(TEvh& evh, ui16 ev) { + if (!evh->HandledEvents()) { + evh->AddHandlingEvent(ev); + AddEventHandler(evh); + } else { + if ((evh->HandledEvents() & ev) == 0) { + evh->AddHandlingEvent(ev); + ModEventHandler(evh); + } + } + } + + public: + //cancel all current operations for socket + //method MUST be called from Run() thread-executor + void CancelFdOp(SOCKET fd) { + TEvh& evh = Evh_.Get(fd); + if (!evh) { + return; + } + + OnFdEvent(evh, ECANCELED, CONT_POLL_READ | CONT_POLL_WRITE); + } + + private: + //helper for fixing handled events even in case exception + struct TExceptionProofFixerHandledEvents { + TExceptionProofFixerHandledEvents(TIOService::TImpl& srv, TEvh& iEvh) + : Srv_(srv) + , Evh_(iEvh) + { + } + + ~TExceptionProofFixerHandledEvents() { + Srv_.FixHandledEvents(Evh_); + } + + TIOService::TImpl& Srv_; + TEvh& Evh_; + }; + + inline void OnFdEvent(TEvh& evh, int status, ui16 filter) { + TExceptionProofFixerHandledEvents fixer(*this, evh); + Y_UNUSED(fixer); + evh->OnFdEvent(status, filter); + } + + inline void AddEventHandler(TEvh& evh) { + if (evh->Fd() > MaxFd_) { + MaxFd_ = evh->Fd(); + } + SetEventHandler(&evh, evh->Fd(), evh->HandledEvents()); + ++FdEventHandlersCnt_; + } + + inline void ModEventHandler(TEvh& evh) { + SetEventHandler(&evh, evh->Fd(), evh->HandledEvents()); + } + + inline void DelEventHandler(TEvh& evh) { + SetEventHandler(&evh, evh->Fd(), 0); + --FdEventHandlersCnt_; + } + + inline void SetEventHandler(void* h, int fd, ui16 flags) { + DBGOUT("SetEventHandler(" << fd << ", " << flags << ")"); + P_->Set(h, fd, flags); + } + + //exception safe call DelEventHandler + struct TInterrupterKeeper { + TInterrupterKeeper(TImpl& srv, TEvh& iEvh) + : Srv_(srv) + , Evh_(iEvh) + { + Srv_.AddEventHandler(Evh_); + } + + ~TInterrupterKeeper() { + Srv_.DelEventHandler(Evh_); + } + + TImpl& Srv_; + TEvh& Evh_; + }; + + TAutoPtr<IPollerFace> P_; + TPollInterrupter I_; + TAtomic IsWaiting_ = 
0; + TAtomic NeedCheckOpQueue_ = 0; + TAtomic OutstandingWork_ = 0; + + NNeh::TAutoLockFreeQueue<TOperation> OpQueue_; + + TEventHandlers Evh_; //i/o event handlers + TTimers Timers_; //timeout event handlers + + size_t FdEventHandlersCnt_ = 0; //i/o event handlers counter + size_t TimersOpCnt_ = 0; //timers op counter + SOCKET MaxFd_ = 0; //max used descriptor num + TAtomic HasAbort_ = 0; + bool Aborted_ = false; + + class TDeadlinesQueue { + public: + TDeadlinesQueue(TIOService::TImpl& srv) + : Srv_(srv) + { + } + + inline void Register(TOperation* op) { + Deadlines_.Insert(op); + } + + TInstant NextDeadline() { + TDeadlines::TIterator it = Deadlines_.Begin(); + + while (it != Deadlines_.End()) { + if (it->DeadLine() > TInstant::Now()) { + DBGOUT("TDeadlinesQueue::NewDeadline:" << (it->DeadLine().GetValue() - TInstant::Now().GetValue())); + return it->DeadLine(); + } + + TOperation* op = &*(it++); + Srv_.OnTimeoutOp(op); + } + + return Deadlines_.Empty() ? TInstant::Max() : Deadlines_.Begin()->DeadLine(); + } + + private: + typedef TRbTree<TOperation, TOperationCompare> TDeadlines; + TDeadlines Deadlines_; + TIOService::TImpl& Srv_; + }; + + TDeadlinesQueue DeadlinesQueue_; + }; +} diff --git a/library/cpp/neh/asio/poll_interrupter.cpp b/library/cpp/neh/asio/poll_interrupter.cpp new file mode 100644 index 0000000000..c96d40c4f3 --- /dev/null +++ b/library/cpp/neh/asio/poll_interrupter.cpp @@ -0,0 +1 @@ +#include "poll_interrupter.h" diff --git a/library/cpp/neh/asio/poll_interrupter.h b/library/cpp/neh/asio/poll_interrupter.h new file mode 100644 index 0000000000..faf815c512 --- /dev/null +++ b/library/cpp/neh/asio/poll_interrupter.h @@ -0,0 +1,107 @@ +#pragma once + +#include <util/system/defaults.h> +#include <util/generic/yexception.h> +#include <util/network/socket.h> +#include <util/system/pipe.h> + +#ifdef _linux_ +#include <sys/eventfd.h> +#endif + +#if defined(_bionic_) && !defined(EFD_SEMAPHORE) +#define EFD_SEMAPHORE 1 +#endif + +namespace NAsio { +#ifdef _linux_ + class TEventFdPollInterrupter { + public: + inline TEventFdPollInterrupter() { + F_ = eventfd(0, EFD_NONBLOCK | EFD_SEMAPHORE); + if (F_ < 0) { + ythrow TFileError() << "failed to create a eventfd"; + } + } + + inline ~TEventFdPollInterrupter() { + close(F_); + } + + inline void Interrupt() const noexcept { + const static eventfd_t ev(1); + ssize_t res = ::write(F_, &ev, sizeof ev); + Y_UNUSED(res); + } + + inline bool Reset() const noexcept { + eventfd_t ev(0); + + for (;;) { + ssize_t res = ::read(F_, &ev, sizeof ev); + if (res && res == EINTR) { + continue; + } + + return res > 0; + } + } + + int Fd() { + return F_; + } + + private: + int F_; + }; +#endif + + class TPipePollInterrupter { + public: + TPipePollInterrupter() { + TPipeHandle::Pipe(S_[0], S_[1]); + + SetNonBlock(S_[0]); + SetNonBlock(S_[1]); + } + + inline void Interrupt() const noexcept { + char byte = 0; + ssize_t res = S_[1].Write(&byte, 1); + Y_UNUSED(res); + } + + inline bool Reset() const noexcept { + char buff[256]; + + for (;;) { + ssize_t r = S_[0].Read(buff, sizeof buff); + + if (r < 0 && r == EINTR) { + continue; + } + + bool wasInterrupted = r > 0; + + while (r == sizeof buff) { + r = S_[0].Read(buff, sizeof buff); + } + + return wasInterrupted; + } + } + + PIPEHANDLE Fd() const noexcept { + return S_[0]; + } + + private: + TPipeHandle S_[2]; + }; + +#ifdef _linux_ + typedef TEventFdPollInterrupter TPollInterrupter; //more effective than pipe, but only linux impl. 
+#else + typedef TPipePollInterrupter TPollInterrupter; +#endif +} diff --git a/library/cpp/neh/asio/tcp_acceptor_impl.cpp b/library/cpp/neh/asio/tcp_acceptor_impl.cpp new file mode 100644 index 0000000000..7e1d75fcf5 --- /dev/null +++ b/library/cpp/neh/asio/tcp_acceptor_impl.cpp @@ -0,0 +1,25 @@ +#include "tcp_acceptor_impl.h" + +using namespace NAsio; + +bool TOperationAccept::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, *this); + + return true; + } + + struct sockaddr_storage addr; + socklen_t sz = sizeof(addr); + + SOCKET res = ::accept(Fd(), (sockaddr*)&addr, &sz); + + if (res == INVALID_SOCKET) { + H_(LastSystemError(), *this); + } else { + NS_.Assign(res, TEndpoint(new NAddr::TOpaqueAddr((sockaddr*)&addr))); + H_(0, *this); + } + + return true; +} diff --git a/library/cpp/neh/asio/tcp_acceptor_impl.h b/library/cpp/neh/asio/tcp_acceptor_impl.h new file mode 100644 index 0000000000..c990236efc --- /dev/null +++ b/library/cpp/neh/asio/tcp_acceptor_impl.h @@ -0,0 +1,76 @@ +#pragma once + +#include "asio.h" + +#include "tcp_socket_impl.h" + +namespace NAsio { + class TOperationAccept: public TFdOperation { + public: + TOperationAccept(SOCKET fd, TTcpSocket::TImpl& newSocket, TTcpAcceptor::TAcceptHandler h, TInstant deadline) + : TFdOperation(fd, PollRead, deadline) + , H_(h) + , NS_(newSocket) + { + } + + bool Execute(int errorCode) override; + + TTcpAcceptor::TAcceptHandler H_; + TTcpSocket::TImpl& NS_; + }; + + class TTcpAcceptor::TImpl: public TThrRefBase { + public: + TImpl(TIOService::TImpl& srv) noexcept + : Srv_(srv) + { + } + + inline void Bind(TEndpoint& ep, TErrorCode& ec) noexcept { + TSocketHolder s(socket(ep.SockAddr()->sa_family, SOCK_STREAM, 0)); + + if (s == INVALID_SOCKET) { + ec.Assign(LastSystemError()); + } + + FixIPv6ListenSocket(s); + CheckedSetSockOpt(s, SOL_SOCKET, SO_REUSEADDR, 1, "reuse addr"); + SetNonBlock(s); + + if (::bind(s, ep.SockAddr(), ep.SockAddrLen())) { + ec.Assign(LastSystemError()); + return; + } + + S_.Swap(s); + } + + inline void Listen(int backlog, TErrorCode& ec) noexcept { + if (::listen(S_, backlog)) { + ec.Assign(LastSystemError()); + return; + } + } + + inline void AsyncAccept(TTcpSocket& s, TTcpAcceptor::TAcceptHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationAccept((SOCKET)S_, s.GetImpl(), h, deadline)); //set callback + } + + inline void AsyncCancel() { + Srv_.ScheduleOp(new TOperationCancel<TTcpAcceptor::TImpl>(this)); + } + + inline TIOService::TImpl& GetIOServiceImpl() const noexcept { + return Srv_; + } + + inline SOCKET Fd() const noexcept { + return S_; + } + + private: + TIOService::TImpl& Srv_; + TSocketHolder S_; + }; +} diff --git a/library/cpp/neh/asio/tcp_socket_impl.cpp b/library/cpp/neh/asio/tcp_socket_impl.cpp new file mode 100644 index 0000000000..98cef97561 --- /dev/null +++ b/library/cpp/neh/asio/tcp_socket_impl.cpp @@ -0,0 +1,117 @@ +#include "tcp_socket_impl.h" + +using namespace NAsio; + +TSocketOperation::TSocketOperation(TTcpSocket::TImpl& s, TPollType pt, TInstant deadline) + : TFdOperation(s.Fd(), pt, deadline) + , S_(s) +{ +} + +bool TOperationWrite::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, Written_, *this); + + return true; //op. 
completed + } + + TErrorCode ec; + TContIOVector& iov = *Buffs_->GetIOvec(); + + size_t n = S_.WriteSome(iov, ec); + + if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) { + H_(ec, Written_ + n, *this); + + return true; + } + + if (n) { + Written_ += n; + iov.Proceed(n); + if (!iov.Bytes()) { + H_(ec, Written_, *this); + + return true; //op. completed + } + } + + return false; //operation not compleled +} + +bool TOperationWriteVector::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, Written_, *this); + + return true; //op. completed + } + + TErrorCode ec; + + size_t n = S_.WriteSome(V_, ec); + + if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) { + H_(ec, Written_ + n, *this); + + return true; + } + + if (n) { + Written_ += n; + V_.Proceed(n); + if (!V_.Bytes()) { + H_(ec, Written_, *this); + + return true; //op. completed + } + } + + return false; //operation not compleled +} + +bool TOperationReadSome::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, 0, *this); + + return true; //op. completed + } + + TErrorCode ec; + + H_(ec, S_.ReadSome(Buff_, Size_, ec), *this); + + return true; +} + +bool TOperationRead::Execute(int errorCode) { + if (errorCode) { + H_(errorCode, Read_, *this); + + return true; //op. completed + } + + TErrorCode ec; + size_t n = S_.ReadSome(Buff_, Size_, ec); + Read_ += n; + + if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) { + H_(ec, Read_, *this); + + return true; //op. completed + } + + if (n) { + Size_ -= n; + if (!Size_) { + H_(ec, Read_, *this); + + return true; + } + Buff_ += n; + } else if (!ec) { // EOF while read not all + H_(ec, Read_, *this); + return true; + } + + return false; +} diff --git a/library/cpp/neh/asio/tcp_socket_impl.h b/library/cpp/neh/asio/tcp_socket_impl.h new file mode 100644 index 0000000000..44f8f42d87 --- /dev/null +++ b/library/cpp/neh/asio/tcp_socket_impl.h @@ -0,0 +1,332 @@ +#pragma once + +#include "asio.h" +#include "io_service_impl.h" + +#include <sys/uio.h> + +#if defined(_bionic_) +# define IOV_MAX 1024 +#endif + +namespace NAsio { + // ownership/keep-alive references: + // Handlers <- TOperation...(TFdOperation) <- TPollFdEventHandler <- TIOService + + class TSocketOperation: public TFdOperation { + public: + TSocketOperation(TTcpSocket::TImpl& s, TPollType pt, TInstant deadline); + + protected: + TTcpSocket::TImpl& S_; + }; + + class TOperationConnect: public TSocketOperation { + public: + TOperationConnect(TTcpSocket::TImpl& s, TTcpSocket::TConnectHandler h, TInstant deadline) + : TSocketOperation(s, PollWrite, deadline) + , H_(h) + { + } + + bool Execute(int errorCode) override { + H_(errorCode, *this); + + return true; + } + + TTcpSocket::TConnectHandler H_; + }; + + class TOperationConnectFailed: public TSocketOperation { + public: + TOperationConnectFailed(TTcpSocket::TImpl& s, TTcpSocket::TConnectHandler h, int errorCode, TInstant deadline) + : TSocketOperation(s, PollWrite, deadline) + , H_(h) + , ErrorCode_(errorCode) + { + Speculative_ = true; + } + + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + H_(ErrorCode_, *this); + + return true; + } + + TTcpSocket::TConnectHandler H_; + int ErrorCode_; + }; + + class TOperationWrite: public TSocketOperation { + public: + TOperationWrite(TTcpSocket::TImpl& s, NAsio::TTcpSocket::TSendedData& buffs, TTcpSocket::TWriteHandler h, TInstant deadline) + : TSocketOperation(s, PollWrite, deadline) + , H_(h) + , Buffs_(buffs) + , Written_(0) + { + Speculative_ = true; + } + + //return true, if not need write more 
data + bool Execute(int errorCode) override; + + private: + TTcpSocket::TWriteHandler H_; + NAsio::TTcpSocket::TSendedData Buffs_; + size_t Written_; + }; + + class TOperationWriteVector: public TSocketOperation { + public: + TOperationWriteVector(TTcpSocket::TImpl& s, TContIOVector* v, TTcpSocket::TWriteHandler h, TInstant deadline) + : TSocketOperation(s, PollWrite, deadline) + , H_(h) + , V_(*v) + , Written_(0) + { + Speculative_ = true; + } + + //return true, if not need write more data + bool Execute(int errorCode) override; + + private: + TTcpSocket::TWriteHandler H_; + TContIOVector& V_; + size_t Written_; + }; + + class TOperationReadSome: public TSocketOperation { + public: + TOperationReadSome(TTcpSocket::TImpl& s, void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) + : TSocketOperation(s, PollRead, deadline) + , H_(h) + , Buff_(static_cast<char*>(buff)) + , Size_(size) + { + } + + //return true, if not need read more data + bool Execute(int errorCode) override; + + protected: + TTcpSocket::TReadHandler H_; + char* Buff_; + size_t Size_; + }; + + class TOperationRead: public TOperationReadSome { + public: + TOperationRead(TTcpSocket::TImpl& s, void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) + : TOperationReadSome(s, buff, size, h, deadline) + , Read_(0) + { + } + + bool Execute(int errorCode) override; + + private: + size_t Read_; + }; + + class TOperationPoll: public TSocketOperation { + public: + TOperationPoll(TTcpSocket::TImpl& s, TPollType pt, TTcpSocket::TPollHandler h, TInstant deadline) + : TSocketOperation(s, pt, deadline) + , H_(h) + { + } + + bool Execute(int errorCode) override { + H_(errorCode, *this); + + return true; + } + + private: + TTcpSocket::TPollHandler H_; + }; + + template <class T> + class TOperationCancel: public TNoneOperation { + public: + TOperationCancel(T* s) + : TNoneOperation() + , S_(s) + { + Speculative_ = true; + } + + ~TOperationCancel() override { + } + + private: + bool Execute(int errorCode) override { + Y_UNUSED(errorCode); + if (!errorCode && S_->Fd() != INVALID_SOCKET) { + S_->GetIOServiceImpl().CancelFdOp(S_->Fd()); + } + return true; + } + + TIntrusivePtr<T> S_; + }; + + class TTcpSocket::TImpl: public TNonCopyable, public TThrRefBase { + public: + typedef TTcpSocket::TSendedData TSendedData; + + TImpl(TIOService::TImpl& srv) noexcept + : Srv_(srv) + { + } + + ~TImpl() override { + DBGOUT("TSocket::~TImpl()"); + } + + void Assign(SOCKET fd, TEndpoint ep) { + TSocketHolder(fd).Swap(S_); + RemoteEndpoint_ = ep; + } + + void AsyncConnect(const TEndpoint& ep, TTcpSocket::TConnectHandler h, TInstant deadline) { + TSocketHolder s(socket(ep.SockAddr()->sa_family, SOCK_STREAM, 0)); + + if (Y_UNLIKELY(s == INVALID_SOCKET || Srv_.HasAbort())) { + throw TSystemError() << TStringBuf("can't create socket"); + } + + SetNonBlock(s); + + int err; + do { + err = connect(s, ep.SockAddr(), (int)ep.SockAddrLen()); + if (Y_LIKELY(err)) { + err = LastSystemError(); + } +#if defined(_freebsd_) + if (Y_UNLIKELY(err == EINTR)) { + err = EINPROGRESS; + } + } while (0); +#elif defined(_linux_) + } while (Y_UNLIKELY(err == EINTR)); +#else + } while (0); +#endif + + RemoteEndpoint_ = ep; + S_.Swap(s); + + DBGOUT("AsyncConnect(): " << err); + if (Y_LIKELY(err == EINPROGRESS || err == EWOULDBLOCK || err == 0)) { + Srv_.ScheduleOp(new TOperationConnect(*this, h, deadline)); //set callback + } else { + Srv_.ScheduleOp(new TOperationConnectFailed(*this, h, err, deadline)); //set callback + } + } + + inline void 
AsyncWrite(TSendedData& d, TTcpSocket::TWriteHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationWrite(*this, d, h, deadline)); + } + + inline void AsyncWrite(TContIOVector* v, TTcpSocket::TWriteHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationWriteVector(*this, v, h, deadline)); + } + + inline void AsyncRead(void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationRead(*this, buff, size, h, deadline)); + } + + inline void AsyncReadSome(void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationReadSome(*this, buff, size, h, deadline)); + } + + inline void AsyncPollWrite(TTcpSocket::TPollHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationPoll(*this, TOperationPoll::PollWrite, h, deadline)); + } + + inline void AsyncPollRead(TTcpSocket::TPollHandler h, TInstant deadline) { + Srv_.ScheduleOp(new TOperationPoll(*this, TOperationPoll::PollRead, h, deadline)); + } + + inline void AsyncCancel() { + if (Y_UNLIKELY(Srv_.HasAbort())) { + return; + } + Srv_.ScheduleOp(new TOperationCancel<TTcpSocket::TImpl>(this)); + } + + inline bool SysCallHasResult(ssize_t& n, TErrorCode& ec) noexcept { + if (n >= 0) { + return true; + } + + int errn = LastSystemError(); + if (errn == EINTR) { + return false; + } + + ec.Assign(errn); + n = 0; + return true; + } + + size_t WriteSome(TContIOVector& iov, TErrorCode& ec) noexcept { + for (;;) { + ssize_t n = writev(S_, (const iovec*)iov.Parts(), Min(IOV_MAX, (int)iov.Count())); + DBGOUT("WriteSome(): n=" << n); + if (SysCallHasResult(n, ec)) { + return n; + } + } + } + + size_t WriteSome(const void* buff, size_t size, TErrorCode& ec) noexcept { + for (;;) { + ssize_t n = send(S_, (char*)buff, size, 0); + DBGOUT("WriteSome(): n=" << n); + if (SysCallHasResult(n, ec)) { + return n; + } + } + } + + size_t ReadSome(void* buff, size_t size, TErrorCode& ec) noexcept { + for (;;) { + ssize_t n = recv(S_, (char*)buff, size, 0); + DBGOUT("ReadSome(): n=" << n); + if (SysCallHasResult(n, ec)) { + return n; + } + } + } + + inline void Shutdown(TTcpSocket::TShutdownMode mode, TErrorCode& ec) { + if (shutdown(S_, mode)) { + ec.Assign(LastSystemError()); + } + } + + TIOService::TImpl& GetIOServiceImpl() const noexcept { + return Srv_; + } + + inline SOCKET Fd() const noexcept { + return S_; + } + + TEndpoint RemoteEndpoint() const { + return RemoteEndpoint_; + } + + private: + TIOService::TImpl& Srv_; + TSocketHolder S_; + TEndpoint RemoteEndpoint_; + }; +} diff --git a/library/cpp/neh/conn_cache.cpp b/library/cpp/neh/conn_cache.cpp new file mode 100644 index 0000000000..a2f0868733 --- /dev/null +++ b/library/cpp/neh/conn_cache.cpp @@ -0,0 +1 @@ +#include "conn_cache.h" diff --git a/library/cpp/neh/conn_cache.h b/library/cpp/neh/conn_cache.h new file mode 100644 index 0000000000..944ba4cfec --- /dev/null +++ b/library/cpp/neh/conn_cache.h @@ -0,0 +1,149 @@ +#pragma once + +#include <string.h> + +#include <util/generic/ptr.h> +#include <util/generic/singleton.h> +#include <util/generic/string.h> +#include <library/cpp/deprecated/atomic/atomic.h> +#include <util/thread/lfqueue.h> + +#include "http_common.h" + +namespace NNeh { + namespace NHttp2 { + // TConn must be refcounted and contain methods: + // void SetCached(bool) noexcept + // bool IsValid() const noexcept + // void Close() noexcept + template <class TConn> + class TConnCache { + struct TCounter : TAtomicCounter { + inline void IncCount(const TConn* const&) { + Inc(); + } + + inline void 
DecCount(const TConn* const&) { + Dec(); + } + }; + + public: + typedef TIntrusivePtr<TConn> TConnRef; + + class TConnList: public TLockFreeQueue<TConn*, TCounter> { + public: + ~TConnList() { + Clear(); + } + + inline void Clear() { + TConn* conn; + + while (this->Dequeue(&conn)) { + conn->Close(); + conn->UnRef(); + } + } + + inline size_t Size() { + return this->GetCounter().Val(); + } + }; + + inline void Put(TConnRef& conn, size_t addrId) { + conn->SetCached(true); + ConnList(addrId).Enqueue(conn.Get()); + conn->Ref(); + Y_UNUSED(conn.Release()); + CachedConn_.Inc(); + } + + bool Get(TConnRef& conn, size_t addrId) { + TConnList& connList = ConnList(addrId); + TConn* connTmp; + + while (connList.Dequeue(&connTmp)) { + connTmp->SetCached(false); + CachedConn_.Dec(); + if (connTmp->IsValid()) { + TConnRef(connTmp).Swap(conn); + connTmp->DecRef(); + + return true; + } else { + connTmp->UnRef(); + } + } + return false; + } + + inline size_t Size() const noexcept { + return CachedConn_.Val(); + } + + inline size_t Validate(size_t addrId) { + TConnList& cl = Lst_.Get(addrId); + return Validate(cl); + } + + //close/remove part of the connections from cache + size_t Purge(size_t addrId, size_t frac256) { + TConnList& cl = Lst_.Get(addrId); + size_t qsize = cl.Size(); + if (!qsize) { + return 0; + } + + size_t purgeCounter = ((qsize * frac256) >> 8); + if (!purgeCounter && qsize >= 2) { + purgeCounter = 1; + } + + size_t pc = 0; + { + TConn* conn; + while (purgeCounter-- && cl.Dequeue(&conn)) { + conn->SetCached(false); + if (conn->IsValid()) { + conn->Close(); + } + CachedConn_.Dec(); + conn->UnRef(); + ++pc; + } + } + pc += Validate(cl); + + return pc; + } + + private: + inline TConnList& ConnList(size_t addrId) { + return Lst_.Get(addrId); + } + + inline size_t Validate(TConnList& cl) { + size_t pc = 0; + size_t nc = cl.Size(); + + TConn* conn; + while (nc-- && cl.Dequeue(&conn)) { + if (conn->IsValid()) { + cl.Enqueue(conn); + } else { + ++pc; + conn->SetCached(false); + CachedConn_.Dec(); + conn->UnRef(); + } + } + + return pc; + } + + NNeh::NHttp::TLockFreeSequence<TConnList> Lst_; + TAtomicCounter CachedConn_; + }; + } +} diff --git a/library/cpp/neh/details.h b/library/cpp/neh/details.h new file mode 100644 index 0000000000..cf3b6fd765 --- /dev/null +++ b/library/cpp/neh/details.h @@ -0,0 +1,99 @@ +#pragma once + +#include "neh.h" +#include <library/cpp/neh/utils.h> + +#include <library/cpp/http/io/headers.h> + +#include <util/generic/singleton.h> +#include <library/cpp/deprecated/atomic/atomic.h> + +namespace NNeh { + class TNotifyHandle: public THandle { + public: + inline TNotifyHandle(IOnRecv* r, const TMessage& msg, TStatCollector* s = nullptr) noexcept + : THandle(r, s) + , Msg_(msg) + , StartTime_(TInstant::Now()) + { + } + + void NotifyResponse(const TString& resp, const TString& firstLine = {}, const THttpHeaders& headers = Default<THttpHeaders>()) { + Notify(new TResponse(Msg_, resp, ExecDuration(), firstLine, headers)); + } + + void NotifyError(const TString& errorText) { + Notify(TResponse::FromError(Msg_, new TError(errorText), ExecDuration())); + } + + void NotifyError(TErrorRef error) { + Notify(TResponse::FromError(Msg_, error, ExecDuration())); + } + + /** Calls when asnwer is received and reponse has headers and first line. 
+ */ + void NotifyError(TErrorRef error, const TString& data, const TString& firstLine, const THttpHeaders& headers) { + Notify(TResponse::FromError(Msg_, error, data, ExecDuration(), firstLine, headers)); + } + + const TMessage& Message() const noexcept { + return Msg_; + } + + private: + inline TDuration ExecDuration() const { + TInstant now = TInstant::Now(); + if (now > StartTime_) { + return now - StartTime_; + } + + return TDuration::Zero(); + } + + const TMessage Msg_; + const TInstant StartTime_; + }; + + typedef TIntrusivePtr<TNotifyHandle> TNotifyHandleRef; + + class TSimpleHandle: public TNotifyHandle { + public: + inline TSimpleHandle(IOnRecv* r, const TMessage& msg, TStatCollector* s = nullptr) noexcept + : TNotifyHandle(r, msg, s) + , SendComplete_(false) + , Canceled_(false) + { + } + + bool MessageSendedCompletely() const noexcept override { + return SendComplete_; + } + + void Cancel() noexcept override { + Canceled_ = true; + THandle::Cancel(); + } + + inline void SetSendComplete() noexcept { + SendComplete_ = true; + } + + inline bool Canceled() const noexcept { + return Canceled_; + } + + inline const TAtomicBool* CanceledPtr() const noexcept { + return &Canceled_; + } + + void ResetOnRecv() noexcept { + F_ = nullptr; + } + + private: + TAtomicBool SendComplete_; + TAtomicBool Canceled_; + }; + + typedef TIntrusivePtr<TSimpleHandle> TSimpleHandleRef; +} diff --git a/library/cpp/neh/factory.cpp b/library/cpp/neh/factory.cpp new file mode 100644 index 0000000000..5b9fd700e6 --- /dev/null +++ b/library/cpp/neh/factory.cpp @@ -0,0 +1,67 @@ +#include "factory.h" +#include "udp.h" +#include "netliba.h" +#include "https.h" +#include "http2.h" +#include "inproc.h" +#include "tcp.h" +#include "tcp2.h" + +#include <util/generic/hash.h> +#include <util/generic/strbuf.h> +#include <util/generic/singleton.h> + +using namespace NNeh; + +namespace { + class TProtocolFactory: public IProtocolFactory, public THashMap<TStringBuf, IProtocol*> { + public: + inline TProtocolFactory() { + Register(NetLibaProtocol()); + Register(Http1Protocol()); + Register(Post1Protocol()); + Register(Full1Protocol()); + Register(UdpProtocol()); + Register(InProcProtocol()); + Register(TcpProtocol()); + Register(Tcp2Protocol()); + Register(Http2Protocol()); + Register(Post2Protocol()); + Register(Full2Protocol()); + Register(SSLGetProtocol()); + Register(SSLPostProtocol()); + Register(SSLFullProtocol()); + Register(UnixSocketGetProtocol()); + Register(UnixSocketPostProtocol()); + Register(UnixSocketFullProtocol()); + } + + IProtocol* Protocol(const TStringBuf& proto) override { + const_iterator it = find(proto); + + if (it == end()) { + ythrow yexception() << "unsupported scheme " << proto; + } + + return it->second; + } + + void Register(IProtocol* proto) override { + (*this)[proto->Scheme()] = proto; + } + }; + + IProtocolFactory* GLOBAL_FACTORY = nullptr; +} + +void NNeh::SetGlobalFactory(IProtocolFactory* factory) { + GLOBAL_FACTORY = factory; +} + +IProtocolFactory* NNeh::ProtocolFactory() { + if (GLOBAL_FACTORY) { + return GLOBAL_FACTORY; + } + + return Singleton<TProtocolFactory>(); +} diff --git a/library/cpp/neh/factory.h b/library/cpp/neh/factory.h new file mode 100644 index 0000000000..17bebef8ed --- /dev/null +++ b/library/cpp/neh/factory.h @@ -0,0 +1,37 @@ +#pragma once + +#include "neh.h" +#include "rpc.h" + +namespace NNeh { + struct TParsedLocation; + + class IProtocol { + public: + virtual ~IProtocol() { + } + virtual IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) 
= 0; + virtual THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef&) = 0; + virtual THandleRef ScheduleAsyncRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& statRef, bool useAsyncSendRequest = false) { + Y_UNUSED(useAsyncSendRequest); + return ScheduleRequest(msg, fallback, statRef); + } + virtual TStringBuf Scheme() const noexcept = 0; + virtual bool SetOption(TStringBuf name, TStringBuf value) { + Y_UNUSED(name); + Y_UNUSED(value); + return false; + } + }; + + class IProtocolFactory { + public: + virtual IProtocol* Protocol(const TStringBuf& scheme) = 0; + virtual void Register(IProtocol* proto) = 0; + virtual ~IProtocolFactory() { + } + }; + + void SetGlobalFactory(IProtocolFactory* factory); + IProtocolFactory* ProtocolFactory(); +} diff --git a/library/cpp/neh/http2.cpp b/library/cpp/neh/http2.cpp new file mode 100644 index 0000000000..0bba29cf22 --- /dev/null +++ b/library/cpp/neh/http2.cpp @@ -0,0 +1,2102 @@ +#include "http2.h" + +#include "conn_cache.h" +#include "details.h" +#include "factory.h" +#include "http_common.h" +#include "smart_ptr.h" +#include "utils.h" + +#include <library/cpp/http/push_parser/http_parser.h> +#include <library/cpp/http/misc/httpcodes.h> +#include <library/cpp/http/misc/parsed_request.h> +#include <library/cpp/neh/asio/executor.h> + +#include <util/generic/singleton.h> +#include <util/generic/vector.h> +#include <util/network/iovec.h> +#include <util/stream/output.h> +#include <util/stream/zlib.h> +#include <util/system/condvar.h> +#include <util/system/mutex.h> +#include <util/system/spinlock.h> +#include <util/system/yassert.h> +#include <util/thread/factory.h> +#include <util/thread/singleton.h> +#include <util/system/sanitizers.h> + +#include <atomic> + +#if defined(_unix_) +#include <sys/ioctl.h> +#endif + +#if defined(_linux_) +#undef SIOCGSTAMP +#undef SIOCGSTAMPNS +#include <linux/sockios.h> +#define FIONWRITE SIOCOUTQ +#endif + +//#define DEBUG_HTTP2 + +#ifdef DEBUG_HTTP2 +#define DBGOUT(args) Cout << args << Endl; +#else +#define DBGOUT(args) +#endif + +using namespace NDns; +using namespace NAsio; +using namespace NNeh; +using namespace NNeh::NHttp; +using namespace NNeh::NHttp2; +using namespace std::placeholders; + +// +// has complex keep-alive references between entities in multi-thread enviroment, +// this create risks for races/memory leak, etc.. 
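As a usage sketch for the IProtocol/IProtocolFactory interface above (not part of the original sources): a caller can obtain a transport by its scheme string through NNeh::ProtocolFactory() and tune it via SetOption(). The scheme name "http2" and the option name "ConnectTimeout" come from this library; the function name TuneHttp2Transport is illustrative only.

    #include <library/cpp/neh/factory.h>

    #include <util/stream/output.h>

    // Look up a registered protocol by scheme and set one of its options.
    // Protocol() throws yexception for an unknown scheme (see TProtocolFactory).
    void TuneHttp2Transport() {
        NNeh::IProtocol* proto = NNeh::ProtocolFactory()->Protocol("http2");

        // SetOption() returns false if the protocol does not recognize the name.
        if (!proto->SetOption("ConnectTimeout", "300ms")) {
            Cerr << "ConnectTimeout is not supported by this scheme" << Endl;
        }
    }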
+// so connecting/disconnecting entities must be doing carefully +// +// handler <=-> request <==> connection(socket) <= handlers, stored in io_service +// ^ +// +== connections_cache +// '=>' -- shared/intrusive ptr +// '->' -- weak_ptr +// + +static TDuration FixTimeoutForSanitizer(const TDuration timeout) { + ui64 multiplier = 1; + if (NSan::ASanIsOn()) { + // https://github.com/google/sanitizers/wiki/AddressSanitizer + multiplier = 4; + } else if (NSan::MSanIsOn()) { + // via https://github.com/google/sanitizers/wiki/MemorySanitizer + multiplier = 3; + } else if (NSan::TSanIsOn()) { + // via https://clang.llvm.org/docs/ThreadSanitizer.html + multiplier = 15; + } + + return TDuration::FromValue(timeout.GetValue() * multiplier); +} + +TDuration THttp2Options::ConnectTimeout = FixTimeoutForSanitizer(TDuration::MilliSeconds(1000)); +TDuration THttp2Options::InputDeadline = TDuration::Max(); +TDuration THttp2Options::OutputDeadline = TDuration::Max(); +TDuration THttp2Options::SymptomSlowConnect = FixTimeoutForSanitizer(TDuration::MilliSeconds(10)); +size_t THttp2Options::InputBufferSize = 16 * 1024; +bool THttp2Options::KeepInputBufferForCachedConnections = false; +size_t THttp2Options::AsioThreads = 4; +size_t THttp2Options::AsioServerThreads = 4; +bool THttp2Options::EnsureSendingCompleteByAck = false; +int THttp2Options::Backlog = 100; +TDuration THttp2Options::ServerInputDeadline = FixTimeoutForSanitizer(TDuration::MilliSeconds(500)); +TDuration THttp2Options::ServerOutputDeadline = TDuration::Max(); +TDuration THttp2Options::ServerInputDeadlineKeepAliveMax = FixTimeoutForSanitizer(TDuration::Seconds(120)); +TDuration THttp2Options::ServerInputDeadlineKeepAliveMin = FixTimeoutForSanitizer(TDuration::Seconds(10)); +bool THttp2Options::ServerUseDirectWrite = false; +bool THttp2Options::UseResponseAsErrorMessage = false; +bool THttp2Options::FullHeadersAsErrorMessage = false; +bool THttp2Options::ErrorDetailsAsResponseBody = false; +bool THttp2Options::RedirectionNotError = false; +bool THttp2Options::AnyResponseIsNotError = false; +bool THttp2Options::TcpKeepAlive = false; +i32 THttp2Options::LimitRequestsPerConnection = -1; +bool THttp2Options::QuickAck = false; +bool THttp2Options::UseAsyncSendRequest = false; + +bool THttp2Options::Set(TStringBuf name, TStringBuf value) { +#define HTTP2_TRY_SET(optType, optName) \ + if (name == TStringBuf(#optName)) { \ + optName = FromString<optType>(value); \ + } + + HTTP2_TRY_SET(TDuration, ConnectTimeout) + else HTTP2_TRY_SET(TDuration, InputDeadline) + else HTTP2_TRY_SET(TDuration, OutputDeadline) + else HTTP2_TRY_SET(TDuration, SymptomSlowConnect) else HTTP2_TRY_SET(size_t, InputBufferSize) else HTTP2_TRY_SET(bool, KeepInputBufferForCachedConnections) else HTTP2_TRY_SET(size_t, AsioThreads) else HTTP2_TRY_SET(size_t, AsioServerThreads) else HTTP2_TRY_SET(bool, EnsureSendingCompleteByAck) else HTTP2_TRY_SET(int, Backlog) else HTTP2_TRY_SET(TDuration, ServerInputDeadline) else HTTP2_TRY_SET(TDuration, ServerOutputDeadline) else HTTP2_TRY_SET(TDuration, ServerInputDeadlineKeepAliveMax) else HTTP2_TRY_SET(TDuration, ServerInputDeadlineKeepAliveMin) else HTTP2_TRY_SET(bool, ServerUseDirectWrite) else HTTP2_TRY_SET(bool, UseResponseAsErrorMessage) else HTTP2_TRY_SET(bool, FullHeadersAsErrorMessage) else HTTP2_TRY_SET(bool, ErrorDetailsAsResponseBody) else HTTP2_TRY_SET(bool, RedirectionNotError) else HTTP2_TRY_SET(bool, AnyResponseIsNotError) else HTTP2_TRY_SET(bool, TcpKeepAlive) else HTTP2_TRY_SET(i32, LimitRequestsPerConnection) else 
HTTP2_TRY_SET(bool, QuickAck) + else HTTP2_TRY_SET(bool, UseAsyncSendRequest) else { + return false; + } + return true; +} + +namespace NNeh { + const NDns::TResolvedHost* Resolve(const TStringBuf host, ui16 port, NHttp::EResolverType resolverType); +} + +namespace { +//#define DEBUG_STAT + +#ifdef DEBUG_STAT + struct TDebugStat { + static std::atomic<size_t> ConnTotal; + static std::atomic<size_t> ConnActive; + static std::atomic<size_t> ConnCached; + static std::atomic<size_t> ConnDestroyed; + static std::atomic<size_t> ConnFailed; + static std::atomic<size_t> ConnConnCanceled; + static std::atomic<size_t> ConnSlow; + static std::atomic<size_t> Conn2Success; + static std::atomic<size_t> ConnPurgedInCache; + static std::atomic<size_t> ConnDestroyedInCache; + static std::atomic<size_t> RequestTotal; + static std::atomic<size_t> RequestSuccessed; + static std::atomic<size_t> RequestFailed; + static void Print() { + Cout << "ct=" << ConnTotal.load(std::memory_order_acquire) + << " ca=" << ConnActive.load(std::memory_order_acquire) + << " cch=" << ConnCached.load(std::memory_order_acquire) + << " cd=" << ConnDestroyed.load(std::memory_order_acquire) + << " cf=" << ConnFailed.load(std::memory_order_acquire) + << " ccc=" << ConnConnCanceled.load(std::memory_order_acquire) + << " csl=" << ConnSlow.load(std::memory_order_acquire) + << " c2s=" << Conn2Success.load(std::memory_order_acquire) + << " cpc=" << ConnPurgedInCache.load(std::memory_order_acquire) + << " cdc=" << ConnDestroyedInCache.load(std::memory_order_acquire) + << " rt=" << RequestTotal.load(std::memory_order_acquire) + << " rs=" << RequestSuccessed.load(std::memory_order_acquire) + << " rf=" << RequestFailed.load(std::memory_order_acquire) + << Endl; + } + }; + std::atomic<size_t> TDebugStat::ConnTotal = 0; + std::atomic<size_t> TDebugStat::ConnActive = 0; + std::atomic<size_t> TDebugStat::ConnCached = 0; + std::atomic<size_t> TDebugStat::ConnDestroyed = 0; + std::atomic<size_t> TDebugStat::ConnFailed = 0; + std::atomic<size_t> TDebugStat::ConnConnCanceled = 0; + std::atomic<size_t> TDebugStat::ConnSlow = 0; + std::atomic<size_t> TDebugStat::Conn2Success = 0; + std::atomic<size_t> TDebugStat::ConnPurgedInCache = 0; + std::atomic<size_t> TDebugStat::ConnDestroyedInCache = 0; + std::atomic<size_t> TDebugStat::RequestTotal = 0; + std::atomic<size_t> TDebugStat::RequestSuccessed = 0; + std::atomic<size_t> TDebugStat::RequestFailed = 0; +#endif + + inline void PrepareSocket(SOCKET s, const TRequestSettings& requestSettings = TRequestSettings()) { + if (requestSettings.NoDelay) { + SetNoDelay(s, true); + } + } + + bool Compress(TData& data, const TString& compressionScheme) { + if (compressionScheme == "gzip") { + try { + TData gzipped(data.size()); + TMemoryOutput out(gzipped.data(), gzipped.size()); + TZLibCompress c(&out, ZLib::GZip); + c.Write(data.data(), data.size()); + c.Finish(); + gzipped.resize(out.Buf() - gzipped.data()); + data.swap(gzipped); + return true; + } catch (yexception&) { + // gzipped data occupies more space than original data + } + } + return false; + } + + class THttpRequestBuffers: public NAsio::TTcpSocket::IBuffers { + public: + THttpRequestBuffers(TRequestData::TPtr rd) + : Req_(rd) + , Parts_(Req_->Parts()) + , IOvec_(Parts_.data(), Parts_.size()) + { + } + + TContIOVector* GetIOvec() override { + return &IOvec_; + } + + private: + TRequestData::TPtr Req_; + TVector<IOutputStream::TPart> Parts_; + TContIOVector IOvec_; + }; + + struct TRequestGet1: public TRequestGet { + static inline TStringBuf Name() 
noexcept { + return TStringBuf("http"); + } + }; + + struct TRequestPost1: public TRequestPost { + static inline TStringBuf Name() noexcept { + return TStringBuf("post"); + } + }; + + struct TRequestFull1: public TRequestFull { + static inline TStringBuf Name() noexcept { + return TStringBuf("full"); + } + }; + + struct TRequestGet2: public TRequestGet { + static inline TStringBuf Name() noexcept { + return TStringBuf("http2"); + } + }; + + struct TRequestPost2: public TRequestPost { + static inline TStringBuf Name() noexcept { + return TStringBuf("post2"); + } + }; + + struct TRequestFull2: public TRequestFull { + static inline TStringBuf Name() noexcept { + return TStringBuf("full2"); + } + }; + + struct TRequestUnixSocketGet: public TRequestGet { + static inline TStringBuf Name() noexcept { + return TStringBuf("http+unix"); + } + + static TRequestSettings RequestSettings() { + return TRequestSettings() + .SetNoDelay(false) + .SetResolverType(EResolverType::EUNIXSOCKET); + } + }; + + struct TRequestUnixSocketPost: public TRequestPost { + static inline TStringBuf Name() noexcept { + return TStringBuf("post+unix"); + } + + static TRequestSettings RequestSettings() { + return TRequestSettings() + .SetNoDelay(false) + .SetResolverType(EResolverType::EUNIXSOCKET); + } + }; + + struct TRequestUnixSocketFull: public TRequestFull { + static inline TStringBuf Name() noexcept { + return TStringBuf("full+unix"); + } + + static TRequestSettings RequestSettings() { + return TRequestSettings() + .SetNoDelay(false) + .SetResolverType(EResolverType::EUNIXSOCKET); + } + }; + + typedef TAutoPtr<THttpRequestBuffers> THttpRequestBuffersPtr; + + class THttpRequest; + typedef TSharedPtrB<THttpRequest> THttpRequestRef; + + class THttpConn; + typedef TIntrusivePtr<THttpConn> THttpConnRef; + + typedef std::function<TRequestData::TPtr(const TMessage&, const TParsedLocation&)> TRequestBuilder; + + class THttpRequest { + public: + class THandle: public TSimpleHandle { + public: + THandle(IOnRecv* f, const TMessage& msg, TStatCollector* s) noexcept + : TSimpleHandle(f, msg, s) + { + } + + bool MessageSendedCompletely() const noexcept override { + if (TSimpleHandle::MessageSendedCompletely()) { + return true; + } + + THttpRequestRef req(GetRequest()); + if (!!req && req->RequestSendedCompletely()) { + const_cast<THandle*>(this)->SetSendComplete(); + } + + return TSimpleHandle::MessageSendedCompletely(); + } + + void Cancel() noexcept override { + if (TSimpleHandle::Canceled()) { + return; + } + + THttpRequestRef req(GetRequest()); + if (!!req) { + TSimpleHandle::Cancel(); + req->Cancel(); + } + } + + void NotifyError(TErrorRef error, const THttpParser* rsp = nullptr) { +#ifdef DEBUG_STAT + ++TDebugStat::RequestFailed; +#endif + if (rsp) { + TSimpleHandle::NotifyError(error, rsp->DecodedContent(), rsp->FirstLine(), rsp->Headers()); + } else { + TSimpleHandle::NotifyError(error); + } + + ReleaseRequest(); + } + + //not thread safe! 
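The Compress() helper shown earlier relies on TMemoryOutput throwing once the gzipped form would not fit into a buffer of the original size; that is how "compression did not help" is detected. Below is the same idea restated as a self-contained sketch, with TVector<char> standing in for the library's TData type; only the name GzipInPlace is new.

    #include <util/generic/vector.h>
    #include <util/generic/yexception.h>
    #include <util/stream/mem.h>
    #include <util/stream/zlib.h>

    // Gzip the buffer in place; keep the original data if compression does not
    // make it smaller (TMemoryOutput throws when the fixed-size buffer overflows).
    bool GzipInPlace(TVector<char>& data) {
        try {
            TVector<char> gzipped(data.size());
            TMemoryOutput out(gzipped.data(), gzipped.size());
            TZLibCompress compressor(&out, ZLib::GZip);
            compressor.Write(data.data(), data.size());
            compressor.Finish();
            gzipped.resize(out.Buf() - gzipped.data());
            data.swap(gzipped);
            return true;
        } catch (const yexception&) {
            return false;
        }
    }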
+ void SetRequest(const TWeakPtrB<THttpRequest>& r) noexcept { + Req_ = r; + } + + private: + THttpRequestRef GetRequest() const noexcept { + TGuard<TSpinLock> g(SP_); + return Req_; + } + + void ReleaseRequest() noexcept { + TWeakPtrB<THttpRequest> tmp; + TGuard<TSpinLock> g(SP_); + tmp.Swap(Req_); + } + + mutable TSpinLock SP_; + TWeakPtrB<THttpRequest> Req_; + }; + + typedef TIntrusivePtr<THandle> THandleRef; + + static void Run(THandleRef& h, const TMessage& msg, TRequestBuilder f, const TRequestSettings& s) { + THttpRequestRef req(new THttpRequest(h, msg, f, s)); + req->WeakThis_ = req; + h->SetRequest(req->WeakThis_); + req->Run(req); + } + + ~THttpRequest() { + DBGOUT("~THttpRequest()"); + } + + private: + THttpRequest(THandleRef& h, TMessage msg, TRequestBuilder f, const TRequestSettings& s) + : Hndl_(h) + , RequestBuilder_(f) + , RequestSettings_(s) + , Msg_(std::move(msg)) + , Loc_(Msg_.Addr) + , Addr_(Resolve(Loc_.Host, Loc_.GetPort(), RequestSettings_.ResolverType)) + , AddrIter_(Addr_->Addr.Begin()) + , Canceled_(false) + , RequestSendedCompletely_(false) + { + } + + void Run(THttpRequestRef& req); + + public: + THttpRequestBuffersPtr BuildRequest() { + return new THttpRequestBuffers(RequestBuilder_(Msg_, Loc_)); + } + + TRequestSettings RequestSettings() { + return RequestSettings_; + } + + //can create a spare socket in an attempt to decrease connecting time + void OnDetectSlowConnecting(); + + //remove extra connection on success connec + void OnConnect(THttpConn* c); + + //have some response input + void OnBeginRead() noexcept { + RequestSendedCompletely_ = true; + } + + void OnResponse(TAutoPtr<THttpParser>& rsp); + + void OnConnectFailed(THttpConn* c, const TErrorCode& ec); + void OnSystemError(THttpConn* c, const TErrorCode& ec); + void OnError(THttpConn* c, const TString& errorText); + + bool RequestSendedCompletely() noexcept; + + void Cancel() noexcept; + + private: + void NotifyResponse(const TString& resp, const TString& firstLine, const THttpHeaders& headers) { + THandleRef h(ReleaseHandler()); + if (!!h) { + h->NotifyResponse(resp, firstLine, headers); + } + } + + void NotifyError( + const TString& errorText, + TError::TType errorType = TError::UnknownType, + i32 errorCode = 0, i32 systemErrorCode = 0) { + NotifyError(new TError(errorText, errorType, errorCode, systemErrorCode)); + } + + void NotifyError(TErrorRef error, const THttpParser* rsp = nullptr) { + THandleRef h(ReleaseHandler()); + if (!!h) { + h->NotifyError(error, rsp); + } + } + + void Finalize(THttpConn* skipConn = nullptr) noexcept; + + inline THandleRef ReleaseHandler() noexcept { + THandleRef h; + { + TGuard<TSpinLock> g(SL_); + h.Swap(Hndl_); + } + return h; + } + + inline THttpConnRef GetConn() noexcept { + TGuard<TSpinLock> g(SL_); + return Conn_; + } + + inline THttpConnRef ReleaseConn() noexcept { + THttpConnRef c; + { + TGuard<TSpinLock> g(SL_); + c.Swap(Conn_); + } + return c; + } + + inline THttpConnRef ReleaseConn2() noexcept { + THttpConnRef c; + { + TGuard<TSpinLock> g(SL_); + c.Swap(Conn2_); + } + return c; + } + + TSpinLock SL_; //guaranted calling notify() only once (prevent race between asio thread and current) + THandleRef Hndl_; + TRequestBuilder RequestBuilder_; + TRequestSettings RequestSettings_; + const TMessage Msg_; + const TParsedLocation Loc_; + const TResolvedHost* Addr_; + TNetworkAddress::TIterator AddrIter_; + THttpConnRef Conn_; + THttpConnRef Conn2_; //concurrent connection used, if detected slow connecting on first connection + TWeakPtrB<THttpRequest> WeakThis_; + 
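The ReleaseHandler()/ReleaseRequest() helpers above implement a notify-once discipline: whichever thread (an asio callback or the caller's Cancel()) swaps the pointer out under the lock first is the only one that gets to notify. A simplified illustration of that pattern with standard types follows; the library itself uses TIntrusivePtr/TWeakPtrB and TSpinLock, and TNotifyOnce is a made-up name.

    #include <functional>
    #include <mutex>
    #include <string>
    #include <utility>

    // Notify-once holder: concurrent Notify() calls race for the stored callback,
    // but only the first one obtains it, so the callback runs exactly once.
    class TNotifyOnce {
    public:
        explicit TNotifyOnce(std::function<void(const std::string&)> cb)
            : Callback_(std::move(cb))
        {
        }

        void Notify(const std::string& what) {
            std::function<void(const std::string&)> cb;
            {
                std::lock_guard<std::mutex> guard(Lock_);
                cb.swap(Callback_); // after the swap the stored callback is empty
            }
            if (cb) {
                cb(what);
            }
        }

    private:
        std::mutex Lock_;
        std::function<void(const std::string&)> Callback_;
    };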
TAtomicBool Canceled_; + TAtomicBool RequestSendedCompletely_; + }; + + TAtomicCounter* HttpOutConnCounter(); + + class THttpConn: public TThrRefBase { + public: + static THttpConnRef Create(TIOService& srv); + + ~THttpConn() override { + DBGOUT("~THttpConn()"); + Req_.Reset(); + HttpOutConnCounter()->Dec(); +#ifdef DEBUG_STAT + ++TDebugStat::ConnDestroyed; +#endif + } + + void StartRequest(THttpRequestRef req, const TEndpoint& ep, size_t addrId, TDuration slowConn, bool useAsyncSendRequest = false) { + { + //thread safe linking connection->request + TGuard<TSpinLock> g(SL_); + Req_ = req; + } + AddrId_ = addrId; + try { + TDuration connectDeadline(THttp2Options::ConnectTimeout); + if (THttp2Options::ConnectTimeout > slowConn) { + //use append non fatal connect deadline, so on first timedout + //report about slow connecting to THttpRequest, and continue wait ConnectDeadline_ period + connectDeadline = slowConn; + ConnectDeadline_ = THttp2Options::ConnectTimeout - slowConn; + } + DBGOUT("AsyncConnect to " << ep.IpToString()); + AS_.AsyncConnect(ep, std::bind(&THttpConn::OnConnect, THttpConnRef(this), _1, _2, useAsyncSendRequest), connectDeadline); + } catch (...) { + ReleaseRequest(); + throw; + } + } + + //start next request on keep-alive connection + bool StartNextRequest(THttpRequestRef& req, bool useAsyncSendRequest = false) { + if (Finalized_) { + return false; + } + + { + //thread safe linking connection->request + TGuard<TSpinLock> g(SL_); + Req_ = req; + } + + RequestWritten_ = false; + BeginReadResponse_ = false; + + try { + if (!useAsyncSendRequest) { + TErrorCode ec; + SendRequest(req->BuildRequest(), ec); //throw std::bad_alloc + if (ec.Value() == ECANCELED) { + OnCancel(); + } else if (ec) { + OnError(ec); + } + } else { + SendRequestAsync(req->BuildRequest()); //throw std::bad_alloc + } + } catch (...) { + OnError(CurrentExceptionMessage()); + throw; + } + return true; + } + + //connection received from cache must be validated before using + //(process removing closed conection from cache consume some time) + inline bool IsValid() const noexcept { + return !Finalized_; + } + + void SetCached(bool v) noexcept { + Cached_ = v; + } + + void Close() noexcept { + try { + Cancel(); + } catch (...) { + } + } + + void DetachRequest() noexcept { + ReleaseRequest(); + } + + void Cancel() { //throw std::bad_alloc + if (!Canceled_) { + Canceled_ = true; + Finalized_ = true; + OnCancel(); + AS_.AsyncCancel(); + } + } + + void OnCancel() { + THttpRequestRef r(ReleaseRequest()); + if (!!r) { + static const TString reqCanceled("request canceled"); + r->OnError(this, reqCanceled); + } + } + + bool RequestSendedCompletely() const noexcept { + DBGOUT("RequestSendedCompletely()"); + if (!Connected_ || !RequestWritten_) { + return false; + } + if (BeginReadResponse_) { + return true; + } +#if defined(FIONWRITE) + if (THttp2Options::EnsureSendingCompleteByAck) { + int nbytes = Max<int>(); + int err = ioctl(AS_.Native(), FIONWRITE, &nbytes); + return err ? 
false : nbytes == 0; + } +#endif + return true; + } + + TIOService& GetIOService() const noexcept { + return AS_.GetIOService(); + } + + private: + THttpConn(TIOService& srv) + : AddrId_(0) + , AS_(srv) + , BuffSize_(THttp2Options::InputBufferSize) + , Connected_(false) + , Cached_(false) + , Canceled_(false) + , Finalized_(false) + , InAsyncRead_(false) + , RequestWritten_(false) + , BeginReadResponse_(false) + { + HttpOutConnCounter()->Inc(); + } + + //can be called only from asio + void OnConnect(const TErrorCode& ec, IHandlingContext& ctx, bool useAsyncSendRequest = false) { + DBGOUT("THttpConn::OnConnect: " << ec.Value()); + if (Y_UNLIKELY(ec)) { + if (ec.Value() == ETIMEDOUT && ConnectDeadline_.GetValue()) { + //detect slow connecting (yet not reached final timeout) + DBGOUT("OnConnectTimingCheck"); + THttpRequestRef req(GetRequest()); + if (!req) { + return; //cancel from client thread can ahead us + } + TDuration newDeadline(ConnectDeadline_); + ConnectDeadline_ = TDuration::Zero(); //next timeout is final + + req->OnDetectSlowConnecting(); + //continue wait connect + ctx.ContinueUseHandler(newDeadline); + + return; + } +#ifdef DEBUG_STAT + if (ec.Value() != ECANCELED) { + ++TDebugStat::ConnFailed; + } else { + ++TDebugStat::ConnConnCanceled; + } +#endif + if (ec.Value() == EIO) { + //try get more detail error info + char buf[1]; + TErrorCode errConnect; + AS_.ReadSome(buf, 1, errConnect); + OnConnectFailed(errConnect.Value() ? errConnect : ec); + } else if (ec.Value() == ECANCELED) { + // not try connecting to next host ip addr, simple fail + OnError(ec); + } else { + OnConnectFailed(ec); + } + } else { + Connected_ = true; + + THttpRequestRef req(GetRequest()); + if (!req || Canceled_) { + return; + } + + try { + PrepareSocket(AS_.Native(), req->RequestSettings()); + if (THttp2Options::TcpKeepAlive) { + SetKeepAlive(AS_.Native(), true); + } + } catch (TSystemError& err) { + TErrorCode ec2(err.Status()); + OnError(ec2); + return; + } + + req->OnConnect(this); + + THttpRequestBuffersPtr ptr(req->BuildRequest()); + PrepareParser(); + + if (!useAsyncSendRequest) { + TErrorCode ec3; + SendRequest(ptr, ec3); + if (ec3) { + OnError(ec3); + } + } else { + SendRequestAsync(ptr); + } + } + } + + void PrepareParser() { + Prs_ = new THttpParser(); + Prs_->Prepare(); + } + + void SendRequest(const THttpRequestBuffersPtr& bfs, TErrorCode& ec) { //throw std::bad_alloc + if (!THttp2Options::UseAsyncSendRequest) { + size_t amount = AS_.WriteSome(*bfs->GetIOvec(), ec); + + if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK && ec.Value() != EINPROGRESS) { + return; + } + ec.Assign(0); + + bfs->GetIOvec()->Proceed(amount); + + if (bfs->GetIOvec()->Complete()) { + RequestWritten_ = true; + StartRead(); + } else { + SendRequestAsync(bfs); + } + } else { + SendRequestAsync(bfs); + } + } + + void SendRequestAsync(const THttpRequestBuffersPtr& bfs) { + NAsio::TTcpSocket::TSendedData sd(bfs.Release()); + AS_.AsyncWrite(sd, std::bind(&THttpConn::OnWrite, THttpConnRef(this), _1, _2, _3), THttp2Options::OutputDeadline); + } + + void OnWrite(const TErrorCode& err, size_t amount, IHandlingContext& ctx) { + Y_UNUSED(amount); + Y_UNUSED(ctx); + if (err) { + OnError(err); + } else { + DBGOUT("OnWrite()"); + RequestWritten_ = true; + StartRead(); + } + } + + inline void StartRead() { + if (!InAsyncRead_ && !Canceled_) { + InAsyncRead_ = true; + AS_.AsyncPollRead(std::bind(&THttpConn::OnCanRead, THttpConnRef(this), _1, _2), THttp2Options::InputDeadline); + } + } + + //can be called only from asio + void 
OnReadSome(const TErrorCode& err, size_t bytes, IHandlingContext& ctx) { + if (Y_UNLIKELY(err)) { + OnError(err); + return; + } + if (!BeginReadResponse_) { + //used in MessageSendedCompletely() + BeginReadResponse_ = true; + THttpRequestRef r(GetRequest()); + if (!!r) { + r->OnBeginRead(); + } + } + DBGOUT("receive:" << TStringBuf(Buff_.Get(), bytes)); + try { + if (!Prs_) { + throw yexception() << TStringBuf("receive some data while not in request"); + } + +#if defined(_linux_) + if (THttp2Options::QuickAck) { + SetSockOpt(AS_.Native(), SOL_TCP, TCP_QUICKACK, (int)1); + } +#endif + + DBGOUT("parse:"); + while (!Prs_->Parse(Buff_.Get(), bytes)) { + if (BuffSize_ == bytes) { + TErrorCode ec; + bytes = AS_.ReadSome(Buff_.Get(), BuffSize_, ec); + + if (!ec) { + continue; + } + + if (ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) { + OnError(ec); + + return; + } + } + //continue async. read from socket + ctx.ContinueUseHandler(THttp2Options::InputDeadline); + + return; + } + + //succesfully reach end of http response + THttpRequestRef r(ReleaseRequest()); + if (!r) { + //lost race to req. canceling + DBGOUT("connection failed"); + return; + } + + DBGOUT("response:"); + bool keepALive = Prs_->IsKeepAlive(); + + r->OnResponse(Prs_); + + if (!keepALive) { + return; + } + + //continue use connection (keep-alive mode) + PrepareParser(); + + if (!THttp2Options::KeepInputBufferForCachedConnections) { + Buff_.Destroy(); + } + //continue async. read from socket + ctx.ContinueUseHandler(THttp2Options::InputDeadline); + + PutSelfToCache(); + } catch (...) { + OnError(CurrentExceptionMessage()); + } + } + + void PutSelfToCache(); + + //method for reaction on input data for re-used keep-alive connection, + //which free/release buffer when was placed in cache + void OnCanRead(const TErrorCode& err, IHandlingContext& ctx) { + if (Y_UNLIKELY(err)) { + OnError(err); + } else { + if (!Buff_) { + Buff_.Reset(new char[BuffSize_]); + } + TErrorCode ec; + OnReadSome(ec, AS_.ReadSome(Buff_.Get(), BuffSize_, ec), ctx); + } + } + + //unlink connection and request, thread-safe mark connection as non valid + inline THttpRequestRef GetRequest() noexcept { + TGuard<TSpinLock> g(SL_); + return Req_; + } + + inline THttpRequestRef ReleaseRequest() noexcept { + THttpRequestRef r; + { + TGuard<TSpinLock> g(SL_); + r.Swap(Req_); + } + return r; + } + + void OnConnectFailed(const TErrorCode& ec); + + inline void OnError(const TErrorCode& ec) { + OnError(ec.Text()); + } + + inline void OnError(const TString& errText); + + size_t AddrId_; + NAsio::TTcpSocket AS_; + TArrayHolder<char> Buff_; //input buffer + const size_t BuffSize_; + + TAutoPtr<THttpParser> Prs_; //input data parser & parsed info storage + + TSpinLock SL_; + THttpRequestRef Req_; //current request + TDuration ConnectDeadline_; + TAtomicBool Connected_; + TAtomicBool Cached_; + TAtomicBool Canceled_; + TAtomicBool Finalized_; + + bool InAsyncRead_; + TAtomicBool RequestWritten_; + TAtomicBool BeginReadResponse_; + }; + + //conn limits monitoring, cache clean, contain used in http clients asio threads/executors + class THttpConnManager: public IThreadFactory::IThreadAble { + public: + THttpConnManager() + : TotalConn(0) + , EP_(THttp2Options::AsioThreads) + , InPurging_(0) + , MaxConnId_(0) + , Shutdown_(false) + { + T_ = SystemThreadFactory()->Run(this); + Limits.SetSoft(40000); + Limits.SetHard(50000); + } + + ~THttpConnManager() override { + { + TGuard<TMutex> g(PurgeMutex_); + + Shutdown_ = true; + CondPurge_.Signal(); + } + + EP_.SyncShutdown(); + + 
T_->Join(); + } + + inline void SetLimits(size_t softLimit, size_t hardLimit) noexcept { + Limits.SetSoft(softLimit); + Limits.SetHard(hardLimit); + } + + inline std::pair<size_t, size_t> GetLimits() const noexcept { + return {Limits.Soft(), Limits.Hard()}; + } + + inline void CheckLimits() { + if (ExceedSoftLimit()) { + SuggestPurgeCache(); + + if (ExceedHardLimit()) { + Y_FAIL("neh::http2 output connections limit reached"); + //ythrow yexception() << "neh::http2 output connections limit reached"; + } + } + } + + inline bool Get(THttpConnRef& conn, size_t addrId) { +#ifdef DEBUG_STAT + TDebugStat::ConnTotal.store(TotalConn.Val(), std::memory_order_release); + TDebugStat::ConnActive(Active(), std::memory_order_release); + TDebugStat::ConnCached(Cache_.Size(), std::memory_order_release); +#endif + return Cache_.Get(conn, addrId); + } + + inline void Put(THttpConnRef& conn, size_t addrId) { + if (Y_LIKELY(!Shutdown_ && !ExceedHardLimit() && !CacheDisabled())) { + if (Y_UNLIKELY(addrId > (size_t)AtomicGet(MaxConnId_))) { + AtomicSet(MaxConnId_, addrId); + } + Cache_.Put(conn, addrId); + } else { + conn->Close(); + conn.Drop(); + } + } + + inline size_t OnConnError(size_t addrId) { + return Cache_.Validate(addrId); + } + + TIOService& GetIOService() { + return EP_.GetExecutor().GetIOService(); + } + + bool CacheDisabled() const { + return Limits.Soft() == 0; + } + + bool IsShutdown() const noexcept { + return Shutdown_; + } + + TAtomicCounter TotalConn; + + private: + inline size_t Total() const noexcept { + return TotalConn.Val(); + } + + inline size_t Active() const noexcept { + return TFdLimits::ExceedLimit(Total(), Cache_.Size()); + } + + inline size_t ExceedSoftLimit() const noexcept { + return TFdLimits::ExceedLimit(Total(), Limits.Soft()); + } + + inline size_t ExceedHardLimit() const noexcept { + return TFdLimits::ExceedLimit(Total(), Limits.Hard()); + } + + void SuggestPurgeCache() { + if (AtomicTryLock(&InPurging_)) { + //evaluate the usefulness of purging the cache + //если в кеше мало соединений (< MaxConnId_/16 или 64), не чистим кеш + if (Cache_.Size() > (Min((size_t)AtomicGet(MaxConnId_), (size_t)1024U) >> 4)) { + //по мере приближения к hardlimit нужда в чистке cache приближается к 100% + size_t closenessToHardLimit256 = ((Active() + 1) << 8) / (Limits.Delta() + 1); + //чем больше соединений в кеше, а не в работе, тем менее нужен кеш (можно его почистить) + size_t cacheUselessness256 = ((Cache_.Size() + 1) << 8) / (Active() + 1); + + //итого, - пороги срабатывания: + //при достижении soft-limit, если соединения в кеше, а не в работе + //на полпути от soft-limit к hard-limit, если в кеше больше половины соединений + //при приближении к hardlimit пытаться почистить кеш почти постоянно + if ((closenessToHardLimit256 + cacheUselessness256) >= 256U) { + TGuard<TMutex> g(PurgeMutex_); + + CondPurge_.Signal(); + return; //memo: thread MUST unlock InPurging_ (see DoExecute()) + } + } + AtomicUnlock(&InPurging_); + } + } + + void DoExecute() override { + while (true) { + { + TGuard<TMutex> g(PurgeMutex_); + + if (Shutdown_) + return; + + CondPurge_.WaitI(PurgeMutex_); + } + + PurgeCache(); + + AtomicUnlock(&InPurging_); + } + } + + void PurgeCache() noexcept { + //try remove at least ExceedSoftLimit() oldest connections from cache + //вычисляем долю кеша, которую нужно почистить (в 256 долях) (но не менее 1/32 кеша) + size_t frac256 = Min(size_t(Max(size_t(256U / 32U), (ExceedSoftLimit() << 8) / (Cache_.Size() + 1))), (size_t)256U); + + size_t processed = 0; + size_t maxConnId = 
AtomicGet(MaxConnId_); + for (size_t i = 0; i <= maxConnId && !Shutdown_; ++i) { + processed += Cache_.Purge(i, frac256); + if (processed > 32) { +#ifdef DEBUG_STAT + TDebugStat::ConnPurgedInCache += processed; +#endif + processed = 0; + Sleep(TDuration::MilliSeconds(10)); //prevent big spike cpu/system usage + } + } + } + + TFdLimits Limits; + TExecutorsPool EP_; + + TConnCache<THttpConn> Cache_; + TAtomic InPurging_; + TAtomic MaxConnId_; + + TAutoPtr<IThreadFactory::IThread> T_; + TCondVar CondPurge_; + TMutex PurgeMutex_; + TAtomicBool Shutdown_; + }; + + THttpConnManager* HttpConnManager() { + return Singleton<THttpConnManager>(); + } + + TAtomicCounter* HttpOutConnCounter() { + return &HttpConnManager()->TotalConn; + } + + THttpConnRef THttpConn::Create(TIOService& srv) { + if (HttpConnManager()->IsShutdown()) { + throw yexception() << "can't create connection with shutdowned service"; + } + + return new THttpConn(srv); + } + + void THttpConn::PutSelfToCache() { + THttpConnRef c(this); + HttpConnManager()->Put(c, AddrId_); + } + + void THttpConn::OnConnectFailed(const TErrorCode& ec) { + THttpRequestRef r(GetRequest()); + if (!!r) { + r->OnConnectFailed(this, ec); + } + OnError(ec); + } + + void THttpConn::OnError(const TString& errText) { + Finalized_ = true; + if (Connected_) { + Connected_ = false; + TErrorCode ec; + AS_.Shutdown(NAsio::TTcpSocket::ShutdownBoth, ec); + } else { + if (AS_.IsOpen()) { + AS_.AsyncCancel(); + } + } + THttpRequestRef r(ReleaseRequest()); + if (!!r) { + r->OnError(this, errText); + } else { + if (Cached_) { + size_t res = HttpConnManager()->OnConnError(AddrId_); + Y_UNUSED(res); +#ifdef DEBUG_STAT + TDebugStat::ConnDestroyedInCache += res; +#endif + } + } + } + + void THttpRequest::Run(THttpRequestRef& req) { +#ifdef DEBUG_STAT + if ((++TDebugStat::RequestTotal & 0xFFF) == 0) { + TDebugStat::Print(); + } +#endif + try { + while (!Canceled_) { + THttpConnRef conn; + if (HttpConnManager()->Get(conn, Addr_->Id)) { + DBGOUT("Use connection from cache"); + Conn_ = conn; //thread magic + if (!conn->StartNextRequest(req, RequestSettings_.UseAsyncSendRequest)) { + continue; //if use connection from cache, ignore write error and try another conn + } + } else { + HttpConnManager()->CheckLimits(); //here throw exception if reach hard limit (or atexit() state) + Conn_ = THttpConn::Create(HttpConnManager()->GetIOService()); + TEndpoint ep(new NAddr::TAddrInfo(&*AddrIter_)); + Conn_->StartRequest(req, ep, Addr_->Id, THttp2Options::SymptomSlowConnect); // can throw + } + break; + } + } catch (...) { + Conn_.Reset(); + throw; + } + } + + //it seems we have lost TCP SYN packet, create extra connection for decrease response time + void THttpRequest::OnDetectSlowConnecting() { +#ifdef DEBUG_STAT + ++TDebugStat::ConnSlow; +#endif + //use some io_service (Run() thread-executor), from first conn. for more thread safety + THttpConnRef conn = GetConn(); + + if (!conn) { + return; + } + + THttpConnRef conn2; + try { + conn2 = THttpConn::Create(conn->GetIOService()); + } catch (...) { + return; // cant create spare connection, simple continue use only main + } + + { + TGuard<TSpinLock> g(SL_); + Conn2_ = conn2; + } + + if (Y_UNLIKELY(Canceled_)) { + ReleaseConn2(); + } else { + //use connect timeout for disable detecting slow connecting on second conn. + TEndpoint ep(new NAddr::TAddrInfo(&*Addr_->Addr.Begin())); + try { + conn2->StartRequest(WeakThis_, ep, Addr_->Id, THttp2Options::ConnectTimeout); + } catch (...) 
{ + // ignore errors on spare connection + ReleaseConn2(); + } + } + } + + void THttpRequest::OnConnect(THttpConn* c) { + THttpConnRef extraConn; + { + TGuard<TSpinLock> g(SL_); + if (Y_UNLIKELY(!!Conn2_)) { + //has pair concurrent conn, 'should stay only one' + if (Conn2_.Get() == c) { +#ifdef DEBUG_STAT + ++TDebugStat::Conn2Success; +#endif + Conn2_.Swap(Conn_); + } + extraConn.Swap(Conn2_); + } + } + if (!!extraConn) { + extraConn->DetachRequest(); //prevent call OnError() + extraConn->Close(); + } + } + + void THttpRequest::OnResponse(TAutoPtr<THttpParser>& rsp) { + DBGOUT("THttpRequest::OnResponse()"); + ReleaseConn(); + if (Y_LIKELY(((rsp->RetCode() >= 200 && rsp->RetCode() < (!THttp2Options::RedirectionNotError ? 300 : 400)) || THttp2Options::AnyResponseIsNotError))) { + NotifyResponse(rsp->DecodedContent(), rsp->FirstLine(), rsp->Headers()); + } else { + TString message; + + if (THttp2Options::FullHeadersAsErrorMessage) { + TStringStream err; + err << rsp->FirstLine(); + + THttpHeaders hdrs = rsp->Headers(); + for (auto h = hdrs.begin(); h < hdrs.end(); h++) { + err << h->ToString() << TStringBuf("\r\n"); + } + + message = err.Str(); + } else if (THttp2Options::UseResponseAsErrorMessage) { + message = rsp->DecodedContent(); + } else { + TStringStream err; + err << TStringBuf("request failed(") << rsp->FirstLine() << TStringBuf(")"); + message = err.Str(); + } + + NotifyError(new TError(message, TError::ProtocolSpecific, rsp->RetCode()), rsp.Get()); + } + } + + void THttpRequest::OnConnectFailed(THttpConn* c, const TErrorCode& ec) { + DBGOUT("THttpRequest::OnConnectFailed()"); + //detach/discard failed conn, try connect to next ip addr (if can) + THttpConnRef cc(GetConn()); + if (c != cc.Get() || AddrIter_ == Addr_->Addr.End() || ++AddrIter_ == Addr_->Addr.End() || Canceled_) { + return OnSystemError(c, ec); + } + // can try next host addr + c->DetachRequest(); + c->Close(); + THttpConnRef nextConn; + try { + nextConn = THttpConn::Create(HttpConnManager()->GetIOService()); + } catch (...) { + OnSystemError(nullptr, ec); + return; + } + { + THttpConnRef nc = nextConn; + TGuard<TSpinLock> g(SL_); + Conn_.Swap(nc); + } + TEndpoint ep(new NAddr::TAddrInfo(&*AddrIter_)); + try { + nextConn->StartRequest(WeakThis_, ep, Addr_->Id, THttp2Options::SymptomSlowConnect); + } catch (...) { + OnError(nullptr, CurrentExceptionMessage()); + return; + } + + if (Canceled_) { + OnError(nullptr, "canceled"); + } + } + + void THttpRequest::OnSystemError(THttpConn* c, const TErrorCode& ec) { + DBGOUT("THttpRequest::OnSystemError()"); + NotifyError(ec.Text(), TError::TType::UnknownType, 0, ec.Value()); + Finalize(c); + } + + void THttpRequest::OnError(THttpConn* c, const TString& errorText) { + DBGOUT("THttpRequest::OnError()"); + NotifyError(errorText); + Finalize(c); + } + + bool THttpRequest::RequestSendedCompletely() noexcept { + if (RequestSendedCompletely_) { + return true; + } + + THttpConnRef c(GetConn()); + return !!c ? c->RequestSendedCompletely() : false; + } + + void THttpRequest::Cancel() noexcept { + if (!Canceled_) { + Canceled_ = true; + try { + static const TString canceled("Canceled"); + NotifyError(canceled, TError::Cancelled); + Finalize(); + } catch (...) 
{ + } + } + } + + inline void FinalizeConn(THttpConnRef& c, THttpConn* skipConn) noexcept { + if (!!c && c.Get() != skipConn) { + c->DetachRequest(); + c->Close(); + } + } + + void THttpRequest::Finalize(THttpConn* skipConn) noexcept { + THttpConnRef c1(ReleaseConn()); + FinalizeConn(c1, skipConn); + THttpConnRef c2(ReleaseConn2()); + FinalizeConn(c2, skipConn); + } + + /////////////////////////////////// server side //////////////////////////////////// + + TAtomicCounter* HttpInConnCounter() { + return Singleton<TAtomicCounter>(); + } + + TFdLimits* HttpInConnLimits() { + return Singleton<TFdLimits>(); + } + + class THttpServer: public IRequester { + typedef TAutoPtr<TTcpAcceptor> TTcpAcceptorPtr; + typedef TAtomicSharedPtr<TTcpSocket> TTcpSocketRef; + class TConn; + typedef TSharedPtrB<TConn> TConnRef; + + class TRequest: public IHttpRequest { + public: + TRequest(TWeakPtrB<TConn>& c, TAutoPtr<THttpParser>& p) + : C_(c) + , P_(p) + , RemoteHost_(C_->RemoteHost()) + , CompressionScheme_(P_->GetBestCompressionScheme()) + , H_(TStringBuf(P_->FirstLine())) + { + } + + ~TRequest() override { + if (!!C_) { + try { + C_->SendError(Id(), 503, "service unavailable (request ignored)", P_->HttpVersion(), {}); + } catch (...) { + DBGOUT("~TRequest()::SendFail() exception"); + } + } + } + + TAtomicBase Id() const { + return Id_; + } + + protected: + TStringBuf Scheme() const override { + return TStringBuf("http"); + } + + TString RemoteHost() const override { + return RemoteHost_; + } + + TStringBuf Service() const override { + return TStringBuf(H_.Path).Skip(1); + } + + const THttpHeaders& Headers() const override { + return P_->Headers(); + } + + TStringBuf Method() const override { + return H_.Method; + } + + TStringBuf Body() const override { + return P_->DecodedContent(); + } + + TStringBuf Cgi() const override { + return H_.Cgi; + } + + TStringBuf RequestId() const override { + return TStringBuf(); + } + + bool Canceled() const override { + if (!C_) { + return false; + } + return C_->IsCanceled(); + } + + void SendReply(TData& data) override { + SendReply(data, TString(), HttpCodes::HTTP_OK); + } + + void SendReply(TData& data, const TString& headers, int httpCode) override { + if (!!C_) { + C_->Send(Id(), data, CompressionScheme_, P_->HttpVersion(), headers, httpCode); + C_.Reset(); + } + } + + void SendError(TResponseError err, const THttpErrorDetails& details) override { + static const unsigned errorToHttpCode[IRequest::MaxResponseError] = + { + 400, + 403, + 404, + 429, + 500, + 501, + 502, + 503, + 509}; + + if (!!C_) { + C_->SendError(Id(), errorToHttpCode[err], details.Details, P_->HttpVersion(), details.Headers); + C_.Reset(); + } + } + + static TAtomicBase NextId() { + static TAtomic idGenerator = 0; + TAtomicBase id = 0; + do { + id = AtomicIncrement(idGenerator); + } while (!id); + return id; + } + + TSharedPtrB<TConn> C_; + TAutoPtr<THttpParser> P_; + TString RemoteHost_; + TString CompressionScheme_; + TParsedHttpFull H_; + TAtomicBase Id_ = NextId(); + }; + + class TRequestGet: public TRequest { + public: + TRequestGet(TWeakPtrB<TConn>& c, TAutoPtr<THttpParser> p) + : TRequest(c, p) + { + } + + TStringBuf Data() const override { + return H_.Cgi; + } + }; + + class TRequestPost: public TRequest { + public: + TRequestPost(TWeakPtrB<TConn>& c, TAutoPtr<THttpParser> p) + : TRequest(c, p) + { + } + + TStringBuf Data() const override { + return P_->DecodedContent(); + } + }; + + class TConn { + private: + TConn(THttpServer& hs, const TTcpSocketRef& s) + : HS_(hs) + , AS_(s) + , 
RemoteHost_(NNeh::PrintHostByRfc(*AS_->RemoteEndpoint().Addr())) + , BuffSize_(THttp2Options::InputBufferSize) + , Buff_(new char[BuffSize_]) + , Canceled_(false) + , LeftRequestsToDisconnect_(hs.LimitRequestsPerConnection) + { + DBGOUT("THttpServer::TConn()"); + HS_.OnCreateConn(); + } + + inline TConnRef SelfRef() noexcept { + return WeakThis_; + } + + public: + static void Create(THttpServer& hs, const TTcpSocketRef& s) { + TSharedPtrB<TConn> conn(new TConn(hs, s)); + conn->WeakThis_ = conn; + conn->ExpectNewRequest(); + conn->AS_->AsyncPollRead(std::bind(&TConn::OnCanRead, conn, _1, _2), THttp2Options::ServerInputDeadline); + } + + ~TConn() { + DBGOUT("~THttpServer::TConn(" << (!AS_ ? -666 : AS_->Native()) << ")"); + HS_.OnDestroyConn(); + } + + private: + void ExpectNewRequest() { + P_.Reset(new THttpParser(THttpParser::Request)); + P_->Prepare(); + } + + void OnCanRead(const TErrorCode& ec, IHandlingContext& ctx) { + if (ec) { + OnError(); + } else { + TErrorCode ec2; + OnReadSome(ec2, AS_->ReadSome(Buff_.Get(), BuffSize_, ec2), ctx); + } + } + + void OnError() { + DBGOUT("Srv OnError(" << (!AS_ ? -666 : AS_->Native()) << ")"); + Canceled_ = true; + AS_->AsyncCancel(); + } + + void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) { + if (ec || !amount) { + OnError(); + + return; + } + + DBGOUT("ReadSome(" << (!AS_ ? -666 : AS_->Native()) << "): " << amount); + try { + size_t buffPos = 0; + //DBGOUT("receive and parse: " << TStringBuf(Buff_.Get(), amount)); + while (P_->Parse(Buff_.Get() + buffPos, amount - buffPos)) { + SeenMessageWithoutKeepalive_ |= !P_->IsKeepAlive() || LeftRequestsToDisconnect_ == 1; + char rt = *P_->FirstLine().data(); + const size_t extraDataSize = P_->GetExtraDataSize(); + if (rt == 'P' || rt == 'p') { + OnRequest(new TRequestPost(WeakThis_, P_)); + } else { + OnRequest(new TRequestGet(WeakThis_, P_)); + } + if (extraDataSize) { + // has http pipelining + buffPos = amount - extraDataSize; + ExpectNewRequest(); + } else { + ExpectNewRequest(); + ctx.ContinueUseHandler(HS_.GetKeepAliveTimeout()); + return; + } + } + ctx.ContinueUseHandler(THttp2Options::ServerInputDeadline); + } catch (...) 
{ + OnError(); + } + } + + void OnRequest(TRequest* r) { + DBGOUT("OnRequest()"); + if (AtomicGet(PrimaryResponse_)) { + // has pipelining + PipelineOrder_.Enqueue(r->Id()); + } else { + AtomicSet(PrimaryResponse_, r->Id()); + } + HS_.OnRequest(r); + OnRequestDone(); + } + + void OnRequestDone() { + DBGOUT("OnRequestDone()"); + if (LeftRequestsToDisconnect_ > 0) { + --LeftRequestsToDisconnect_; + } + } + + static void PrintHttpVersion(IOutputStream& out, const THttpVersion& ver) { + out << TStringBuf("HTTP/") << ver.Major << TStringBuf(".") << ver.Minor; + } + + struct TResponseData : TThrRefBase { + TResponseData(size_t reqId, TTcpSocket::TSendedData data) + : RequestId_(reqId) + , Data_(data) + { + } + + size_t RequestId_; + TTcpSocket::TSendedData Data_; + }; + typedef TIntrusivePtr<TResponseData> TResponseDataRef; + + public: + //called non thread-safe (from outside thread) + void Send(TAtomicBase requestId, TData& data, const TString& compressionScheme, const THttpVersion& ver, const TString& headers, int httpCode) { + class THttpResponseFormatter { + public: + THttpResponseFormatter(TData& theData, const TString& contentEncoding, const THttpVersion& theVer, const TString& theHeaders, int theHttpCode, bool closeConnection) { + Header.Reserve(128 + contentEncoding.size() + theHeaders.size()); + PrintHttpVersion(Header, theVer); + Header << TStringBuf(" ") << theHttpCode << ' ' << HttpCodeStr(theHttpCode); + if (Compress(theData, contentEncoding)) { + Header << TStringBuf("\r\nContent-Encoding: ") << contentEncoding; + } + Header << TStringBuf("\r\nContent-Length: ") << theData.size(); + if (closeConnection) { + Header << TStringBuf("\r\nConnection: close"); + } else if (Y_LIKELY(theVer.Major > 1 || theVer.Minor > 0)) { + // since HTTP/1.1 Keep-Alive is default behaviour + Header << TStringBuf("\r\nConnection: Keep-Alive"); + } + if (theHeaders) { + Header << theHeaders; + } + Header << TStringBuf("\r\n\r\n"); + + Body.swap(theData); + + Parts[0].buf = Header.Data(); + Parts[0].len = Header.Size(); + Parts[1].buf = Body.data(); + Parts[1].len = Body.size(); + } + + TStringStream Header; + TData Body; + IOutputStream::TPart Parts[2]; + }; + + class TBuffers: public THttpResponseFormatter, public TTcpSocket::IBuffers { + public: + TBuffers(TData& theData, const TString& contentEncoding, const THttpVersion& theVer, const TString& theHeaders, int theHttpCode, bool closeConnection) + : THttpResponseFormatter(theData, contentEncoding, theVer, theHeaders, theHttpCode, closeConnection) + , IOVec(Parts, 2) + { + } + + TContIOVector* GetIOvec() override { + return &IOVec; + } + + TContIOVector IOVec; + }; + + TTcpSocket::TSendedData sd(new TBuffers(data, compressionScheme, ver, headers, httpCode, SeenMessageWithoutKeepalive_)); + SendData(requestId, sd); + } + + //called non thread-safe (from outside thread) + void SendError(TAtomicBase requestId, unsigned httpCode, const TString& descr, const THttpVersion& ver, const TString& headers) { + if (Canceled_) { + return; + } + + class THttpErrorResponseFormatter { + public: + THttpErrorResponseFormatter(unsigned theHttpCode, const TString& theDescr, const THttpVersion& theVer, bool closeConnection, const TString& headers) { + PrintHttpVersion(Answer, theVer); + Answer << TStringBuf(" ") << theHttpCode << TStringBuf(" "); + if (theDescr.size() && !THttp2Options::ErrorDetailsAsResponseBody) { + // Reason-Phrase = *<TEXT, excluding CR, LF> + // replace bad chars to '.' 
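THttpResponseFormatter above shows the exact header layout the server emits: status line, optional Content-Encoding, Content-Length, and a Connection header chosen from the keep-alive state. A stripped-down sketch of that layout follows, reduced to the uncompressed case with no extra user headers and no HttpCodeStr(); FormatResponseHeader is an illustrative name only. The error path with its sanitized reason phrase continues below.

    #include <util/generic/strbuf.h>
    #include <util/generic/string.h>
    #include <util/stream/str.h>

    // Build the same kind of response header as THttpResponseFormatter,
    // for a plain, uncompressed reply.
    TString FormatResponseHeader(int httpCode, TStringBuf reason, size_t contentLength, bool keepAlive) {
        TStringStream header;
        header << "HTTP/1.1 " << httpCode << ' ' << reason
               << "\r\nContent-Length: " << contentLength
               << "\r\nConnection: " << (keepAlive ? "Keep-Alive" : "close")
               << "\r\n\r\n";
        return header.Str();
    }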
+ TString reasonPhrase = theDescr; + for (TString::iterator it = reasonPhrase.begin(); it != reasonPhrase.end(); ++it) { + char& ch = *it; + if (ch == ' ') { + continue; + } + if (((ch & 31) == ch) || static_cast<unsigned>(ch) == 127 || (static_cast<unsigned>(ch) & 0x80)) { + //CTLs || DEL(127) || non ascii + // (ch <= 32) || (ch >= 127) + ch = '.'; + } + } + Answer << reasonPhrase; + } else { + Answer << HttpCodeStr(static_cast<int>(theHttpCode)); + } + + if (closeConnection) { + Answer << TStringBuf("\r\nConnection: close"); + } + + if (headers) { + Answer << "\r\n" << headers; + } + + if (THttp2Options::ErrorDetailsAsResponseBody) { + Answer << TStringBuf("\r\nContent-Length:") << theDescr.size() << "\r\n\r\n" << theDescr; + } else { + Answer << "\r\n" + "Content-Length:0\r\n\r\n"sv; + } + + Parts[0].buf = Answer.Data(); + Parts[0].len = Answer.Size(); + } + + TStringStream Answer; + IOutputStream::TPart Parts[1]; + }; + + class TBuffers: public THttpErrorResponseFormatter, public TTcpSocket::IBuffers { + public: + TBuffers( + unsigned theHttpCode, + const TString& theDescr, + const THttpVersion& theVer, + bool closeConnection, + const TString& headers + ) + : THttpErrorResponseFormatter(theHttpCode, theDescr, theVer, closeConnection, headers) + , IOVec(Parts, 1) + { + } + + TContIOVector* GetIOvec() override { + return &IOVec; + } + + TContIOVector IOVec; + }; + + TTcpSocket::TSendedData sd(new TBuffers(httpCode, descr, ver, SeenMessageWithoutKeepalive_, headers)); + SendData(requestId, sd); + } + + void ProcessPipeline() { + // on successfull response to current (PrimaryResponse_) request + TAtomicBase requestId; + if (PipelineOrder_.Dequeue(&requestId)) { + TAtomicBase oldReqId; + do { + oldReqId = AtomicGet(PrimaryResponse_); + Y_VERIFY(oldReqId, "race inside http pipelining"); + } while (!AtomicCas(&PrimaryResponse_, requestId, oldReqId)); + + ProcessResponsesData(); + } else { + TAtomicBase oldReqId = AtomicGet(PrimaryResponse_); + if (oldReqId) { + while (!AtomicCas(&PrimaryResponse_, 0, oldReqId)) { + Y_VERIFY(oldReqId == AtomicGet(PrimaryResponse_), "race inside http pipelining [2]"); + } + } + } + } + + void ProcessResponsesData() { + // process responses data queue, send response (if already have next PrimaryResponse_) + TResponseDataRef rd; + while (ResponsesDataQueue_.Dequeue(&rd)) { + ResponsesData_[rd->RequestId_] = rd; + } + TAtomicBase requestId = AtomicGet(PrimaryResponse_); + if (requestId) { + THashMap<TAtomicBase, TResponseDataRef>::iterator it = ResponsesData_.find(requestId); + if (it != ResponsesData_.end()) { + // has next primary response + rd = it->second; + ResponsesData_.erase(it); + AS_->AsyncWrite(rd->Data_, std::bind(&TConn::OnSend, SelfRef(), _1, _2, _3), THttp2Options::ServerOutputDeadline); + } + } + } + + private: + void SendData(TAtomicBase requestId, TTcpSocket::TSendedData sd) { + TContIOVector& vec = *sd->GetIOvec(); + + if (requestId != AtomicGet(PrimaryResponse_)) { + // already has another request for response first, so push this to queue + // + enqueue event for safe checking queue (at local/transport thread) + TResponseDataRef rdr = new TResponseData(requestId, sd); + ResponsesDataQueue_.Enqueue(rdr); + AS_->GetIOService().Post(std::bind(&TConn::ProcessResponsesData, SelfRef())); + return; + } + if (THttp2Options::ServerUseDirectWrite) { + vec.Proceed(AS_->WriteSome(vec)); + } + if (!vec.Complete()) { + DBGOUT("AsyncWrite()"); + AS_->AsyncWrite(sd, std::bind(&TConn::OnSend, SelfRef(), _1, _2, _3), THttp2Options::ServerOutputDeadline); + } 
else { + // run ProcessPipeline at safe thread + AS_->GetIOService().Post(std::bind(&TConn::ProcessPipeline, SelfRef())); + } + } + + void OnSend(const TErrorCode& ec, size_t amount, IHandlingContext&) { + Y_UNUSED(amount); + if (ec) { + OnError(); + } else { + ProcessPipeline(); + } + + if (SeenMessageWithoutKeepalive_) { + TErrorCode shutdown_ec; + AS_->Shutdown(TTcpSocket::ShutdownBoth, shutdown_ec); + } + } + + public: + bool IsCanceled() const noexcept { + return Canceled_; + } + + const TString& RemoteHost() const noexcept { + return RemoteHost_; + } + + private: + TWeakPtrB<TConn> WeakThis_; + THttpServer& HS_; + TTcpSocketRef AS_; + TString RemoteHost_; + size_t BuffSize_; + TArrayHolder<char> Buff_; + TAutoPtr<THttpParser> P_; + // pipeline supporting + TAtomic PrimaryResponse_ = 0; + TLockFreeQueue<TAtomicBase> PipelineOrder_; + TLockFreeQueue<TResponseDataRef> ResponsesDataQueue_; + THashMap<TAtomicBase, TResponseDataRef> ResponsesData_; + + TAtomicBool Canceled_; + bool SeenMessageWithoutKeepalive_ = false; + + i32 LeftRequestsToDisconnect_ = -1; + }; + + /////////////////////////////////////////////////////////// + + public: + THttpServer(IOnRequest* cb, const TParsedLocation& loc) + : E_(THttp2Options::AsioServerThreads) + , CB_(cb) + , LimitRequestsPerConnection(THttp2Options::LimitRequestsPerConnection) + { + TNetworkAddress addr(loc.GetPort()); + + for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) { + TEndpoint ep(new NAddr::TAddrInfo(&*it)); + TTcpAcceptorPtr a(new TTcpAcceptor(AcceptExecutor_.GetIOService())); + DBGOUT("bind:" << ep.IpToString() << ":" << ep.Port()); + a->Bind(ep); + a->Listen(THttp2Options::Backlog); + StartAccept(a.Get()); + A_.push_back(a); + } + } + + ~THttpServer() override { + AcceptExecutor_.SyncShutdown(); //cancel operation for all current sockets (include acceptors) + A_.clear(); //stop listening + E_.SyncShutdown(); + } + + void OnAccept(TTcpAcceptor* a, TAtomicSharedPtr<TTcpSocket> s, const TErrorCode& ec, IHandlingContext&) { + if (Y_UNLIKELY(ec)) { + if (ec.Value() == ECANCELED) { + return; + } else if (ec.Value() == EMFILE || ec.Value() == ENFILE || ec.Value() == ENOMEM || ec.Value() == ENOBUFS) { + //reach some os limit, suspend accepting + TAtomicSharedPtr<TDeadlineTimer> dt(new TDeadlineTimer(a->GetIOService())); + dt->AsyncWaitExpireAt(TDuration::Seconds(30), std::bind(&THttpServer::OnTimeoutSuspendAccept, this, a, dt, _1, _2)); + return; + } else { + Cdbg << "acc: " << ec.Text() << Endl; + } + } else { + if (static_cast<size_t>(HttpInConnCounter()->Val()) < HttpInConnLimits()->Hard()) { + try { + SetNonBlock(s->Native()); + PrepareSocket(s->Native()); + TConn::Create(*this, s); + } catch (TSystemError& err) { + TErrorCode ec2(err.Status()); + Cdbg << "acc: " << ec2.Text() << Endl; + } + } //else accepted socket will be closed + } + StartAccept(a); //continue accepting + } + + void OnTimeoutSuspendAccept(TTcpAcceptor* a, TAtomicSharedPtr<TDeadlineTimer>, const TErrorCode& ec, IHandlingContext&) { + if (!ec) { + DBGOUT("resume acceptor") + StartAccept(a); + } + } + + void OnRequest(IRequest* r) { + try { + CB_->OnRequest(r); + } catch (...) 
{ + Cdbg << CurrentExceptionMessage() << Endl; + } + } + + protected: + void OnCreateConn() noexcept { + HttpInConnCounter()->Inc(); + } + + void OnDestroyConn() noexcept { + HttpInConnCounter()->Dec(); + } + + TDuration GetKeepAliveTimeout() const noexcept { + size_t cc = HttpInConnCounter()->Val(); + TFdLimits lim(*HttpInConnLimits()); + + if (!TFdLimits::ExceedLimit(cc, lim.Soft())) { + return THttp2Options::ServerInputDeadlineKeepAliveMax; + } + + if (cc > lim.Hard()) { + cc = lim.Hard(); + } + TDuration::TValue softTuneRange = THttp2Options::ServerInputDeadlineKeepAliveMax.Seconds() - THttp2Options::ServerInputDeadlineKeepAliveMin.Seconds(); + + return TDuration::Seconds((softTuneRange * (cc - lim.Soft())) / (lim.Hard() - lim.Soft() + 1)) + THttp2Options::ServerInputDeadlineKeepAliveMin; + } + + private: + void StartAccept(TTcpAcceptor* a) { + TAtomicSharedPtr<TTcpSocket> s(new TTcpSocket(E_.Size() ? E_.GetExecutor().GetIOService() : AcceptExecutor_.GetIOService())); + a->AsyncAccept(*s, std::bind(&THttpServer::OnAccept, this, a, s, _1, _2)); + } + + TIOServiceExecutor AcceptExecutor_; + TVector<TTcpAcceptorPtr> A_; + TExecutorsPool E_; + IOnRequest* CB_; + + public: + const i32 LimitRequestsPerConnection; + }; + + template <class T> + class THttp2Protocol: public IProtocol { + public: + IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override { + return new THttpServer(cb, loc); + } + + THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override { + THttpRequest::THandleRef ret(new THttpRequest::THandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss))); + try { + THttpRequest::Run(ret, msg, &T::Build, T::RequestSettings()); + } catch (...) { + ret->ResetOnRecv(); + throw; + } + return ret.Get(); + } + + THandleRef ScheduleAsyncRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss, bool useAsyncSendRequest) override { + THttpRequest::THandleRef ret(new THttpRequest::THandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss))); + try { + auto requestSettings = T::RequestSettings(); + requestSettings.SetUseAsyncSendRequest(useAsyncSendRequest); + THttpRequest::Run(ret, msg, &T::Build, requestSettings); + } catch (...) 
{ + ret->ResetOnRecv(); + throw; + } + return ret.Get(); + } + + TStringBuf Scheme() const noexcept override { + return T::Name(); + } + + bool SetOption(TStringBuf name, TStringBuf value) override { + return THttp2Options::Set(name, value); + } + }; +} + +namespace NNeh { + IProtocol* Http1Protocol() { + return Singleton<THttp2Protocol<TRequestGet1>>(); + } + IProtocol* Post1Protocol() { + return Singleton<THttp2Protocol<TRequestPost1>>(); + } + IProtocol* Full1Protocol() { + return Singleton<THttp2Protocol<TRequestFull1>>(); + } + IProtocol* Http2Protocol() { + return Singleton<THttp2Protocol<TRequestGet2>>(); + } + IProtocol* Post2Protocol() { + return Singleton<THttp2Protocol<TRequestPost2>>(); + } + IProtocol* Full2Protocol() { + return Singleton<THttp2Protocol<TRequestFull2>>(); + } + IProtocol* UnixSocketGetProtocol() { + return Singleton<THttp2Protocol<TRequestUnixSocketGet>>(); + } + IProtocol* UnixSocketPostProtocol() { + return Singleton<THttp2Protocol<TRequestUnixSocketPost>>(); + } + IProtocol* UnixSocketFullProtocol() { + return Singleton<THttp2Protocol<TRequestUnixSocketFull>>(); + } + + void SetHttp2OutputConnectionsLimits(size_t softLimit, size_t hardLimit) { + HttpConnManager()->SetLimits(softLimit, hardLimit); + } + + void SetHttp2InputConnectionsLimits(size_t softLimit, size_t hardLimit) { + HttpInConnLimits()->SetSoft(softLimit); + HttpInConnLimits()->SetHard(hardLimit); + } + + TAtomicBase GetHttpOutputConnectionCount() { + return HttpOutConnCounter()->Val(); + } + + std::pair<size_t, size_t> GetHttpOutputConnectionLimits() { + return HttpConnManager()->GetLimits(); + } + + TAtomicBase GetHttpInputConnectionCount() { + return HttpInConnCounter()->Val(); + } + + void SetHttp2InputConnectionsTimeouts(unsigned minSeconds, unsigned maxSeconds) { + THttp2Options::ServerInputDeadlineKeepAliveMin = TDuration::Seconds(minSeconds); + THttp2Options::ServerInputDeadlineKeepAliveMax = TDuration::Seconds(maxSeconds); + } + + class TUnixSocketResolver { + public: + NDns::TResolvedHost* Resolve(const TString& path) { + TString unixSocketPath = path; + if (path.size() > 2 && path[0] == '[' && path[path.size() - 1] == ']') { + unixSocketPath = path.substr(1, path.size() - 2); + } + + if (auto resolvedUnixSocket = ResolvedUnixSockets_.FindPtr(unixSocketPath)) { + return resolvedUnixSocket->Get(); + } + + TNetworkAddress na{TUnixSocketPath(unixSocketPath)}; + ResolvedUnixSockets_[unixSocketPath] = MakeHolder<NDns::TResolvedHost>(unixSocketPath, na); + + return ResolvedUnixSockets_[unixSocketPath].Get(); + } + + private: + THashMap<TString, THolder<NDns::TResolvedHost>> ResolvedUnixSockets_; + }; + + TUnixSocketResolver* UnixSocketResolver() { + return FastTlsSingleton<TUnixSocketResolver>(); + } + + const NDns::TResolvedHost* Resolve(const TStringBuf host, ui16 port, NHttp::EResolverType resolverType) { + if (resolverType == EResolverType::EUNIXSOCKET) { + return UnixSocketResolver()->Resolve(TString(host)); + } + return NDns::CachedResolve(NDns::TResolveInfo(host, port)); + + } +} diff --git a/library/cpp/neh/http2.h b/library/cpp/neh/http2.h new file mode 100644 index 0000000000..7bc16affa0 --- /dev/null +++ b/library/cpp/neh/http2.h @@ -0,0 +1,119 @@ +#pragma once + +#include "factory.h" +#include "http_common.h" + +#include <util/datetime/base.h> +#include <library/cpp/dns/cache.h> +#include <utility> + +namespace NNeh { + IProtocol* Http1Protocol(); + IProtocol* Post1Protocol(); + IProtocol* Full1Protocol(); + IProtocol* Http2Protocol(); + IProtocol* Post2Protocol(); + IProtocol* 
Full2Protocol(); + IProtocol* UnixSocketGetProtocol(); + IProtocol* UnixSocketPostProtocol(); + IProtocol* UnixSocketFullProtocol(); + + //global options + struct THttp2Options { + //connect timeout + static TDuration ConnectTimeout; + + //input and output timeouts + static TDuration InputDeadline; + static TDuration OutputDeadline; + + //when detected slow connection, will be runned concurrent parallel connection + //not used, if SymptomSlowConnect > ConnectTimeout + static TDuration SymptomSlowConnect; + + //input buffer size + static size_t InputBufferSize; + + //http client input buffer politic + static bool KeepInputBufferForCachedConnections; + + //asio threads + static size_t AsioThreads; + + //asio server threads, - if == 0, use acceptor thread for read/parse incoming requests + //esle use one thread for accepting + AsioServerThreads for process established tcp connections + static size_t AsioServerThreads; + + //use ACK for ensure completely sending request (call ioctl() for checking emptiness output buffer) + //reliable check, but can spend to much time (40ms or like it) (see Wikipedia: TCP delayed acknowledgment) + //disabling this option reduce sending validation to established connection and written all request data to socket buffer + static bool EnsureSendingCompleteByAck; + + //listen socket queue limit + static int Backlog; + + //expecting receiving request data right after connect or inside receiving request data + static TDuration ServerInputDeadline; + + //timelimit for sending response data + static TDuration ServerOutputDeadline; + + //expecting receiving request for keep-alived socket + //(Max - if not reached SoftLimit, Min, if reached Hard limit) + static TDuration ServerInputDeadlineKeepAliveMax; + static TDuration ServerInputDeadlineKeepAliveMin; + + //try write data into socket fd in contex handler->SendReply() call + //(instead moving write job to asio thread) + //this reduce sys_cpu load (less sys calls), but increase user_cpu and response time + static bool ServerUseDirectWrite; + + //use http response body as error message + static bool UseResponseAsErrorMessage; + + //pass all http response headers as error message + static bool FullHeadersAsErrorMessage; + + //use details (SendError argument) as response body + static bool ErrorDetailsAsResponseBody; + + //consider responses with 3xx code as successful + static bool RedirectionNotError; + + //consider response with any code as successful + static bool AnyResponseIsNotError; + + //enable tcp keepalive for outgoing requests sockets + static bool TcpKeepAlive; + + //enable limit requests per keep alive connection + static i32 LimitRequestsPerConnection; + + //enable TCP_QUICKACK + static bool QuickAck; + + // enable write to socket via ScheduleOp + static bool UseAsyncSendRequest; + + //set option, - return false, if option name not recognized + static bool Set(TStringBuf name, TStringBuf value); + }; + + /// if exceed soft limit, reduce quantity unused connections in cache + void SetHttp2OutputConnectionsLimits(size_t softLimit, size_t hardLimit); + + /// if exceed soft limit, reduce quantity unused connections in cache + void SetHttp2InputConnectionsLimits(size_t softLimit, size_t hardLimit); + + /// for debug and monitoring purposes + TAtomicBase GetHttpOutputConnectionCount(); + TAtomicBase GetHttpInputConnectionCount(); + std::pair<size_t, size_t> GetHttpOutputConnectionLimits(); + + /// unused input sockets keepalive timeouts + /// real(used) timeout: + /// - max, if not reached soft limit + /// - min, 
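// A hedged sketch of switching THttp2Options by name; how the string values are
// parsed by Set() is not shown in this file, so the literals below are assumed
// (FromString-style) examples only.
#include <library/cpp/neh/http2.h>

bool ConfigureNehHttp2Options() {
    using NNeh::THttp2Options;
    bool ok = true;
    ok &= THttp2Options::Set("ConnectTimeout", "300ms"); // assumed TDuration-style literal
    ok &= THttp2Options::Set("AsioThreads", "4");
    ok &= THttp2Options::Set("TcpKeepAlive", "1");
    return ok; // Set() returns false for an unrecognized option name
}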
if reached hard limit + /// - approx. linear changed[max..min], while conn. count in range [soft..hard] + void SetHttp2InputConnectionsTimeouts(unsigned minSeconds, unsigned maxSeconds); +} diff --git a/library/cpp/neh/http_common.cpp b/library/cpp/neh/http_common.cpp new file mode 100644 index 0000000000..7ae466c31a --- /dev/null +++ b/library/cpp/neh/http_common.cpp @@ -0,0 +1,235 @@ +#include "http_common.h" + +#include "location.h" +#include "http_headers.h" + +#include <util/generic/array_ref.h> +#include <util/generic/singleton.h> +#include <util/stream/length.h> +#include <util/stream/null.h> +#include <util/stream/str.h> +#include <util/string/ascii.h> + +using NNeh::NHttp::ERequestType; + +namespace { + bool IsEmpty(const TStringBuf url) { + return url.empty(); + } + + void WriteImpl(const TStringBuf url, IOutputStream& out) { + out << url; + } + + bool IsEmpty(const TConstArrayRef<TString> urlParts) { + return urlParts.empty(); + } + + void WriteImpl(const TConstArrayRef<TString> urlParts, IOutputStream& out) { + NNeh::NHttp::JoinUrlParts(urlParts, out); + } + + template <typename T> + size_t GetLength(const T& urlParts) { + TCountingOutput out(&Cnull); + WriteImpl(urlParts, out); + return out.Counter(); + } + + template <typename T> + void WriteUrl(const T& urlParts, IOutputStream& out) { + if (!IsEmpty(urlParts)) { + out << '?'; + WriteImpl(urlParts, out); + } + } +} + +namespace NNeh { + namespace NHttp { + size_t GetUrlPartsLength(const TConstArrayRef<TString> urlParts) { + size_t res = 0; + + for (const auto& u : urlParts) { + res += u.length(); + } + + if (urlParts.size() > 0) { + res += urlParts.size() - 1; //'&' between parts + } + + return res; + } + + void JoinUrlParts(const TConstArrayRef<TString> urlParts, IOutputStream& out) { + if (urlParts.empty()) { + return; + } + + out << urlParts[0]; + + for (size_t i = 1; i < urlParts.size(); ++i) { + out << '&' << urlParts[i]; + } + } + + void WriteUrlParts(const TConstArrayRef<TString> urlParts, IOutputStream& out) { + WriteUrl(urlParts, out); + } + } +} + +namespace { + const TStringBuf schemeHttps = "https"; + const TStringBuf schemeHttp = "http"; + const TStringBuf schemeHttp2 = "http2"; + const TStringBuf schemePost = "post"; + const TStringBuf schemePosts = "posts"; + const TStringBuf schemePost2 = "post2"; + const TStringBuf schemeFull = "full"; + const TStringBuf schemeFulls = "fulls"; + const TStringBuf schemeHttpUnix = "http+unix"; + const TStringBuf schemePostUnix = "post+unix"; + + /* + @brief SafeWriteHeaders write headers from hdrs to out with some checks: + - filter out Content-Lenthgh because we'll add it ourselfs later. 
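// Worked example for the helpers above: with urlParts = {"text=neh", "lr=213"}
// GetUrlPartsLength() returns 15 (14 characters of data plus one '&'),
// JoinUrlParts() writes "text=neh&lr=213" and WriteUrlParts() writes
// "?text=neh&lr=213"; for an empty array nothing is written at all.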
+ + @todo ensure headers right formatted (now receive from perl report bad format headers) + */ + void SafeWriteHeaders(IOutputStream& out, TStringBuf hdrs) { + NNeh::NHttp::THeaderSplitter splitter(hdrs); + TStringBuf msgHdr; + while (splitter.Next(msgHdr)) { + if (!AsciiHasPrefixIgnoreCase(msgHdr, TStringBuf("Content-Length"))) { + out << msgHdr << TStringBuf("\r\n"); + } + } + } + + template <typename T, typename W> + TString BuildRequest(const NNeh::TParsedLocation& loc, const T& urlParams, const TStringBuf headers, const W& content, const TStringBuf contentType, ERequestType requestType, NNeh::NHttp::ERequestFlags requestFlags) { + const bool isAbsoluteUri = requestFlags.HasFlags(NNeh::NHttp::ERequestFlag::AbsoluteUri); + + const auto contentLength = GetLength(content); + TStringStream out; + out.Reserve(loc.Service.length() + loc.Host.length() + GetLength(urlParams) + headers.length() + contentType.length() + contentLength + (isAbsoluteUri ? (loc.Host.length() + 13) : 0) // 13 - is a max port number length + scheme length + + 96); //just some extra space + + Y_ASSERT(requestType != ERequestType::Any); + out << requestType; + out << ' '; + if (isAbsoluteUri) { + out << loc.Scheme << TStringBuf("://") << loc.Host; + if (loc.Port) { + out << ':' << loc.Port; + } + } + out << '/' << loc.Service; + + WriteUrl(urlParams, out); + out << TStringBuf(" HTTP/1.1\r\n"); + + NNeh::NHttp::WriteHostHeaderIfNot(out, loc.Host, loc.Port, headers); + SafeWriteHeaders(out, headers); + if (!IsEmpty(content)) { + if (!!contentType && headers.find(TStringBuf("Content-Type:")) == TString::npos) { + out << TStringBuf("Content-Type: ") << contentType << TStringBuf("\r\n"); + } + out << TStringBuf("Content-Length: ") << contentLength << TStringBuf("\r\n"); + out << TStringBuf("\r\n"); + WriteImpl(content, out); + } else { + out << TStringBuf("\r\n"); + } + return out.Str(); + } + + bool NeedGetRequestFor(TStringBuf scheme) { + return scheme == schemeHttp2 || scheme == schemeHttp || scheme == schemeHttps || scheme == schemeHttpUnix; + } + + bool NeedPostRequestFor(TStringBuf scheme) { + return scheme == schemePost2 || scheme == schemePost || scheme == schemePosts || scheme == schemePostUnix; + } + + inline ERequestType ChooseReqType(ERequestType userReqType, ERequestType defaultReqType) { + Y_ASSERT(defaultReqType != ERequestType::Any); + return userReqType != ERequestType::Any ? 
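// Illustration of SafeWriteHeaders() above: given
// "X-Req-Id: 1\r\nContent-Length: 42\r\nAccept: */*\r\n" it forwards only
// "X-Req-Id: 1" and "Accept: */*" (each terminated with CRLF); the caller's
// Content-Length is dropped because BuildRequest() below writes its own value
// computed from the actual content size.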
userReqType : defaultReqType; + } +} + +namespace NNeh { + namespace NHttp { + const TStringBuf DefaultContentType = "application/x-www-form-urlencoded"; + + template <typename T> + bool MakeFullRequestImpl(TMessage& msg, const TStringBuf proxy, const T& urlParams, const TStringBuf headers, const TStringBuf content, const TStringBuf contentType, ERequestType reqType, ERequestFlags reqFlags) { + NNeh::TParsedLocation loc(msg.Addr); + + if (content.size()) { + //content MUST be placed inside POST requests + if (!IsEmpty(urlParams)) { + if (NeedGetRequestFor(loc.Scheme)) { + msg.Data = BuildRequest(loc, urlParams, headers, content, contentType, ChooseReqType(reqType, ERequestType::Post), reqFlags); + } else { + // cannot place in first header line potentially unsafe data from POST message + // (can contain forbidden for url-path characters) + // so support such mutation only for GET requests + return false; + } + } else { + if (NeedGetRequestFor(loc.Scheme) || NeedPostRequestFor(loc.Scheme)) { + msg.Data = BuildRequest(loc, urlParams, headers, content, contentType, ChooseReqType(reqType, ERequestType::Post), reqFlags); + } else { + return false; + } + } + } else { + if (NeedGetRequestFor(loc.Scheme)) { + msg.Data = BuildRequest(loc, urlParams, headers, "", "", ChooseReqType(reqType, ERequestType::Get), reqFlags); + } else if (NeedPostRequestFor(loc.Scheme)) { + msg.Data = BuildRequest(loc, TString(), headers, urlParams, contentType, ChooseReqType(reqType, ERequestType::Post), reqFlags); + } else { + return false; + } + } + + if (proxy.IsInited()) { + loc = NNeh::TParsedLocation(proxy); + msg.Addr = proxy; + } + + TString schemePostfix = ""; + if (loc.Scheme.EndsWith("+unix")) { + schemePostfix = "+unix"; + } + + // ugly but still... https2 will break it :( + if ('s' == loc.Scheme[loc.Scheme.size() - 1]) { + msg.Addr.replace(0, loc.Scheme.size(), schemeFulls + schemePostfix); + } else { + msg.Addr.replace(0, loc.Scheme.size(), schemeFull + schemePostfix); + } + + return true; + } + + bool MakeFullRequest(TMessage& msg, const TStringBuf headers, const TStringBuf content, const TStringBuf contentType, ERequestType reqType, ERequestFlags reqFlags) { + return MakeFullRequestImpl(msg, {}, msg.Data, headers, content, contentType, reqType, reqFlags); + } + + bool MakeFullRequest(TMessage& msg, const TConstArrayRef<TString> urlParts, const TStringBuf headers, const TStringBuf content, const TStringBuf contentType, ERequestType reqType, ERequestFlags reqFlags) { + return MakeFullRequestImpl(msg, {}, urlParts, headers, content, contentType, reqType, reqFlags); + } + + bool MakeFullProxyRequest(TMessage& msg, TStringBuf proxyAddr, TStringBuf headers, TStringBuf content, TStringBuf contentType, ERequestType reqType, ERequestFlags flags) { + return MakeFullRequestImpl(msg, proxyAddr, msg.Data, headers, content, contentType, reqType, flags | ERequestFlag::AbsoluteUri); + } + + bool IsHttpScheme(TStringBuf scheme) { + return NeedGetRequestFor(scheme) || NeedPostRequestFor(scheme); + } + } +} + diff --git a/library/cpp/neh/http_common.h b/library/cpp/neh/http_common.h new file mode 100644 index 0000000000..91b4ca1356 --- /dev/null +++ b/library/cpp/neh/http_common.h @@ -0,0 +1,305 @@ +#pragma once + +#include <util/generic/array_ref.h> +#include <util/generic/flags.h> +#include <util/generic/ptr.h> +#include <util/generic/vector.h> +#include <util/stream/mem.h> +#include <util/stream/output.h> +#include <library/cpp/deprecated/atomic/atomic.h> + +#include "location.h" +#include "neh.h" +#include "rpc.h" + 
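// A minimal sketch of turning a plain "http" message into a "full" one with
// MakeFullRequest() defined above; the address, header and query are made-up
// examples.
#include <library/cpp/neh/http_common.h>

bool BuildSearchMessage(NNeh::TMessage& msg) {
    msg.Addr = "http://example.net:8080/search"; // "http" scheme + empty content => GET
    msg.Data = "text=neh";                       // used as the query string
    // on success msg.Addr is rewritten to the "full" scheme and msg.Data holds
    // the serialized "GET /search?text=neh HTTP/1.1" request
    return NNeh::NHttp::MakeFullRequest(msg, /*headers=*/"X-Req-Id: 1\r\n", /*content=*/"");
}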
+#include <atomic> + +//common primitives for http/http2 + +namespace NNeh { + struct THttpErrorDetails { + TString Details = {}; + TString Headers = {}; + }; + + class IHttpRequest: public IRequest { + public: + using IRequest::SendReply; + virtual void SendReply(TData& data, const TString& headers, int httpCode = 200) = 0; + virtual const THttpHeaders& Headers() const = 0; + virtual TStringBuf Method() const = 0; + virtual TStringBuf Body() const = 0; + virtual TStringBuf Cgi() const = 0; + void SendError(TResponseError err, const TString& details = TString()) override final { + SendError(err, THttpErrorDetails{.Details = details}); + } + + virtual void SendError(TResponseError err, const THttpErrorDetails& details) = 0; + }; + + namespace NHttp { + enum class EResolverType { + ETCP = 0, + EUNIXSOCKET = 1 + }; + + struct TFdLimits { + public: + TFdLimits() + : Soft_(10000) + , Hard_(15000) + { + } + + TFdLimits(const TFdLimits& other) { + Soft_.store(other.Soft(), std::memory_order_release); + Hard_.store(other.Hard(), std::memory_order_release); + } + + inline size_t Delta() const noexcept { + return ExceedLimit(Hard_.load(std::memory_order_acquire), Soft_.load(std::memory_order_acquire)); + } + + inline static size_t ExceedLimit(size_t val, size_t limit) noexcept { + return val > limit ? val - limit : 0; + } + + void SetSoft(size_t value) noexcept { + Soft_.store(value, std::memory_order_release); + } + + void SetHard(size_t value) noexcept { + Hard_.store(value, std::memory_order_release); + } + + size_t Soft() const noexcept { + return Soft_.load(std::memory_order_acquire); + } + + size_t Hard() const noexcept { + return Hard_.load(std::memory_order_acquire); + } + + private: + std::atomic<size_t> Soft_; + std::atomic<size_t> Hard_; + }; + + template <class T> + class TLockFreeSequence { + public: + inline TLockFreeSequence() { + memset((void*)T_, 0, sizeof(T_)); + } + + inline ~TLockFreeSequence() { + for (size_t i = 0; i < Y_ARRAY_SIZE(T_); ++i) { + delete[] T_[i]; + } + } + + inline T& Get(size_t n) { + const size_t i = GetValueBitCount(n + 1) - 1; + + return GetList(i)[n + 1 - (((size_t)1) << i)]; + } + + private: + inline T* GetList(size_t n) { + T* volatile* t = T_ + n; + + T* result; + while (!(result = AtomicGet(*t))) { + TArrayHolder<T> nt(new T[((size_t)1) << n]); + + if (AtomicCas(t, nt.Get(), nullptr)) { + return nt.Release(); + } + } + + return result; + } + + private: + T* volatile T_[sizeof(size_t) * 8]; + }; + + class TRequestData: public TNonCopyable { + public: + using TPtr = TAutoPtr<TRequestData>; + using TParts = TVector<IOutputStream::TPart>; + + inline TRequestData(size_t memSize) + : Mem(memSize) + { + } + + inline void SendTo(IOutputStream& io) const { + io.Write(Parts_.data(), Parts_.size()); + } + + inline void AddPart(const void* buf, size_t len) noexcept { + Parts_.push_back(IOutputStream::TPart(buf, len)); + } + + const TParts& Parts() const noexcept { + return Parts_; + } + + TVector<char> Mem; + + private: + TParts Parts_; + }; + + struct TRequestSettings { + bool NoDelay = true; + EResolverType ResolverType = EResolverType::ETCP; + bool UseAsyncSendRequest = false; + + TRequestSettings& SetNoDelay(bool noDelay) { + NoDelay = noDelay; + return *this; + } + + TRequestSettings& SetResolverType(EResolverType resolverType) { + ResolverType = resolverType; + return *this; + } + + TRequestSettings& SetUseAsyncSendRequest(bool useAsyncSendRequest) { + UseAsyncSendRequest = useAsyncSendRequest; + return *this; + } + }; + + struct TRequestGet { + static 
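// Worked example for TFdLimits above: the defaults are soft = 10000 and
// hard = 15000, so Delta() == 5000; ExceedLimit(val, limit) returns how far val
// is above limit, e.g. ExceedLimit(12000, 10000) == 2000 and
// ExceedLimit(9000, 10000) == 0.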
TRequestData::TPtr Build(const TMessage& msg, const TParsedLocation& loc) { + TRequestData::TPtr req(new TRequestData(50 + loc.Service.size() + msg.Data.size() + loc.Host.size())); + TMemoryOutput out(req->Mem.data(), req->Mem.size()); + + out << TStringBuf("GET /") << loc.Service; + + if (!!msg.Data) { + out << '?' << msg.Data; + } + + out << TStringBuf(" HTTP/1.1\r\nHost: ") << loc.Host; + + if (!!loc.Port) { + out << TStringBuf(":") << loc.Port; + } + + out << TStringBuf("\r\n\r\n"); + + req->AddPart(req->Mem.data(), out.Buf() - req->Mem.data()); + return req; + } + + static inline TStringBuf Name() noexcept { + return TStringBuf("http"); + } + + static TRequestSettings RequestSettings() { + return TRequestSettings(); + } + }; + + struct TRequestPost { + static TRequestData::TPtr Build(const TMessage& msg, const TParsedLocation& loc) { + TRequestData::TPtr req(new TRequestData(100 + loc.Service.size() + loc.Host.size())); + TMemoryOutput out(req->Mem.data(), req->Mem.size()); + + out << TStringBuf("POST /") << loc.Service + << TStringBuf(" HTTP/1.1\r\nHost: ") << loc.Host; + + if (!!loc.Port) { + out << TStringBuf(":") << loc.Port; + } + + out << TStringBuf("\r\nContent-Length: ") << msg.Data.size() << TStringBuf("\r\n\r\n"); + + req->AddPart(req->Mem.data(), out.Buf() - req->Mem.data()); + req->AddPart(msg.Data.data(), msg.Data.size()); + return req; + } + + static inline TStringBuf Name() noexcept { + return TStringBuf("post"); + } + + static TRequestSettings RequestSettings() { + return TRequestSettings(); + } + }; + + struct TRequestFull { + static TRequestData::TPtr Build(const TMessage& msg, const TParsedLocation&) { + TRequestData::TPtr req(new TRequestData(0)); + req->AddPart(msg.Data.data(), msg.Data.size()); + return req; + } + + static inline TStringBuf Name() noexcept { + return TStringBuf("full"); + } + + static TRequestSettings RequestSettings() { + return TRequestSettings(); + } + }; + + enum class ERequestType { + Any = 0 /* "ANY" */, + Post /* "POST" */, + Get /* "GET" */, + Put /* "PUT" */, + Delete /* "DELETE" */, + Patch /* "PATCH" */, + }; + + enum class ERequestFlag { + None = 0, + /** use absoulte uri for proxy requests in the first request line + * POST http://ya.ru HTTP/1.1 + * @see https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 + */ + AbsoluteUri = 1, + }; + + Y_DECLARE_FLAGS(ERequestFlags, ERequestFlag) + Y_DECLARE_OPERATORS_FOR_FLAGS(ERequestFlags) + + static constexpr ERequestType DefaultRequestType = ERequestType::Any; + + extern const TStringBuf DefaultContentType; + + /// @brief `MakeFullRequest` transmutes http/post/http2/post2 message to full/full2 with + /// additional HTTP headers and/or content data. + /// + /// If reqType is `Any`, then request type is POST, unless content is empty and schema + /// prefix is http/https/http2, in that case request type is GET. + /// + /// @msg[in] Will get URL from `msg.Data`. + bool MakeFullRequest(TMessage& msg, TStringBuf headers, TStringBuf content, TStringBuf contentType = DefaultContentType, ERequestType reqType = DefaultRequestType, ERequestFlags flags = ERequestFlag::None); + + /// @see `MakeFullrequest`. + /// + /// @urlParts[in] Will construct url from `urlParts`, `msg.Data` is not used. 
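// Illustration of the builders above: for msg.Data == "a=1" and a location
// parsed from "http://example.net:8080/ping", TRequestGet::Build() serializes
//   "GET /ping?a=1 HTTP/1.1\r\nHost: example.net:8080\r\n\r\n";
// TRequestPost::Build() writes a POST request line without the query string,
// adds Content-Length and sends msg.Data as the body, while TRequestFull passes
// msg.Data through verbatim as the whole request.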
+ bool MakeFullRequest(TMessage& msg, TConstArrayRef<TString> urlParts, TStringBuf headers, TStringBuf content, TStringBuf contentType = DefaultContentType, ERequestType reqType = DefaultRequestType, ERequestFlags flags = ERequestFlag::None); + + /// Same as `MakeFullRequest` but it will add ERequestFlag::AbsoluteUri to the @a flags + /// and replace msg.Addr with @a proxyAddr + /// + /// @see `MakeFullrequest`. + bool MakeFullProxyRequest(TMessage& msg, TStringBuf proxyAddr, TStringBuf headers, TStringBuf content, TStringBuf contentType = DefaultContentType, ERequestType reqType = DefaultRequestType, ERequestFlags flags = ERequestFlag::None); + + size_t GetUrlPartsLength(TConstArrayRef<TString> urlParts); + //part1&part2&... + void JoinUrlParts(TConstArrayRef<TString> urlParts, IOutputStream& out); + //'?' + JoinUrlParts + void WriteUrlParts(TConstArrayRef<TString> urlParts, IOutputStream& out); + + bool IsHttpScheme(TStringBuf scheme); + } +} diff --git a/library/cpp/neh/http_headers.cpp b/library/cpp/neh/http_headers.cpp new file mode 100644 index 0000000000..79dc823e7e --- /dev/null +++ b/library/cpp/neh/http_headers.cpp @@ -0,0 +1 @@ +#include "http_headers.h" diff --git a/library/cpp/neh/http_headers.h b/library/cpp/neh/http_headers.h new file mode 100644 index 0000000000..70cf3a9fbe --- /dev/null +++ b/library/cpp/neh/http_headers.h @@ -0,0 +1,55 @@ +#pragma once + +#include <util/generic/strbuf.h> +#include <util/stream/output.h> +#include <util/string/ascii.h> + +namespace NNeh { + namespace NHttp { + template <typename Port> + void WriteHostHeader(IOutputStream& out, TStringBuf host, Port port) { + out << TStringBuf("Host: ") << host; + if (port) { + out << TStringBuf(":") << port; + } + out << TStringBuf("\r\n"); + } + + class THeaderSplitter { + public: + THeaderSplitter(TStringBuf headers) + : Headers_(headers) + { + } + + bool Next(TStringBuf& header) { + while (Headers_.ReadLine(header)) { + if (!header.Empty()) { + return true; + } + } + return false; + } + private: + TStringBuf Headers_; + }; + + inline bool HasHostHeader(TStringBuf headers) { + THeaderSplitter splitter(headers); + TStringBuf header; + while (splitter.Next(header)) { + if (AsciiHasPrefixIgnoreCase(header, "Host:")) { + return true; + } + } + return false; + } + + template <typename Port> + void WriteHostHeaderIfNot(IOutputStream& out, TStringBuf host, Port port, TStringBuf headers) { + if (!NNeh::NHttp::HasHostHeader(headers)) { + NNeh::NHttp::WriteHostHeader(out, host, port); + } + } + } +} diff --git a/library/cpp/neh/https.cpp b/library/cpp/neh/https.cpp new file mode 100644 index 0000000000..d0e150e778 --- /dev/null +++ b/library/cpp/neh/https.cpp @@ -0,0 +1,1936 @@ +#include "https.h" + +#include "details.h" +#include "factory.h" +#include "http_common.h" +#include "jobqueue.h" +#include "location.h" +#include "multi.h" +#include "pipequeue.h" +#include "utils.h" + +#include <contrib/libs/openssl/include/openssl/ssl.h> +#include <contrib/libs/openssl/include/openssl/err.h> +#include <contrib/libs/openssl/include/openssl/bio.h> +#include <contrib/libs/openssl/include/openssl/x509v3.h> + +#include <library/cpp/openssl/init/init.h> +#include <library/cpp/openssl/method/io.h> +#include <library/cpp/coroutine/listener/listen.h> +#include <library/cpp/dns/cache.h> +#include <library/cpp/http/misc/parsed_request.h> +#include <library/cpp/http/misc/httpcodes.h> +#include <library/cpp/http/io/stream.h> + +#include <util/generic/cast.h> +#include <util/generic/list.h> +#include <util/generic/utility.h> 
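// A small sketch exercising the http_headers.h helpers above; the header
// strings and host are made-up examples.
#include <library/cpp/neh/http_headers.h>
#include <util/stream/str.h>

TString EnsureHostHeader() {
    TStringStream head;
    const TStringBuf userHeaders = "Accept: */*\r\n"; // no Host header here
    // appends "Host: example.net:8080\r\n" because HasHostHeader(userHeaders) is false
    NNeh::NHttp::WriteHostHeaderIfNot(head, "example.net", 8080, userHeaders);
    return head.Str();
}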
+#include <util/network/socket.h> +#include <util/stream/str.h> +#include <util/stream/zlib.h> +#include <util/string/builder.h> +#include <util/string/cast.h> +#include <util/system/condvar.h> +#include <util/system/error.h> +#include <util/system/types.h> +#include <util/thread/factory.h> + +#include <atomic> + +#if defined(_unix_) +#include <sys/ioctl.h> +#endif + +#if defined(_linux_) +#undef SIOCGSTAMP +#undef SIOCGSTAMPNS +#include <linux/sockios.h> +#define FIONWRITE SIOCOUTQ +#endif + +using namespace NDns; +using namespace NAddr; + +namespace NNeh { + TString THttpsOptions::CAFile; + TString THttpsOptions::CAPath; + TString THttpsOptions::ClientCertificate; + TString THttpsOptions::ClientPrivateKey; + TString THttpsOptions::ClientPrivateKeyPassword; + bool THttpsOptions::EnableSslServerDebug = false; + bool THttpsOptions::EnableSslClientDebug = false; + bool THttpsOptions::CheckCertificateHostname = false; + THttpsOptions::TVerifyCallback THttpsOptions::ClientVerifyCallback = nullptr; + THttpsOptions::TPasswordCallback THttpsOptions::KeyPasswdCallback = nullptr; + bool THttpsOptions::RedirectionNotError = false; + + bool THttpsOptions::Set(TStringBuf name, TStringBuf value) { +#define YNDX_NEH_HTTPS_TRY_SET(optName) \ + if (name == TStringBuf(#optName)) { \ + optName = FromString<decltype(optName)>(value); \ + return true; \ + } + + YNDX_NEH_HTTPS_TRY_SET(CAFile); + YNDX_NEH_HTTPS_TRY_SET(CAPath); + YNDX_NEH_HTTPS_TRY_SET(ClientCertificate); + YNDX_NEH_HTTPS_TRY_SET(ClientPrivateKey); + YNDX_NEH_HTTPS_TRY_SET(ClientPrivateKeyPassword); + YNDX_NEH_HTTPS_TRY_SET(EnableSslServerDebug); + YNDX_NEH_HTTPS_TRY_SET(EnableSslClientDebug); + YNDX_NEH_HTTPS_TRY_SET(CheckCertificateHostname); + YNDX_NEH_HTTPS_TRY_SET(RedirectionNotError); + +#undef YNDX_NEH_HTTPS_TRY_SET + + return false; + } +} + +namespace NNeh { + namespace NHttps { + namespace { + // force ssl_write/ssl_read functions to return this value via BIO_method_read/write that means request is canceled + constexpr int SSL_RVAL_TIMEOUT = -42; + + struct TInputConnections { + TInputConnections() + : Counter(0) + , MaxUnusedConnKeepaliveTimeout(120) + , MinUnusedConnKeepaliveTimeout(10) + { + } + + inline size_t ExceedSoftLimit() const noexcept { + return NHttp::TFdLimits::ExceedLimit(Counter.Val(), Limits.Soft()); + } + + inline size_t ExceedHardLimit() const noexcept { + return NHttp::TFdLimits::ExceedLimit(Counter.Val(), Limits.Hard()); + } + + inline size_t DeltaLimit() const noexcept { + return Limits.Delta(); + } + + unsigned UnusedConnKeepaliveTimeout() const { + if (size_t e = ExceedSoftLimit()) { + size_t d = DeltaLimit(); + size_t leftAvailableFd = NHttp::TFdLimits::ExceedLimit(d, e); + unsigned r = static_cast<unsigned>(MaxUnusedConnKeepaliveTimeout.load(std::memory_order_acquire) * leftAvailableFd / (d + 1)); + return Max(r, (unsigned)MinUnusedConnKeepaliveTimeout.load(std::memory_order_acquire)); + } + return MaxUnusedConnKeepaliveTimeout.load(std::memory_order_acquire); + } + + void SetFdLimits(size_t soft, size_t hard) { + Limits.SetSoft(soft); + Limits.SetHard(hard); + } + + NHttp::TFdLimits Limits; + TAtomicCounter Counter; + std::atomic<unsigned> MaxUnusedConnKeepaliveTimeout; //in seconds + std::atomic<unsigned> MinUnusedConnKeepaliveTimeout; //in seconds + }; + + TInputConnections* InputConnections() { + return Singleton<TInputConnections>(); + } + + struct TSharedSocket: public TSocketHolder, public TAtomicRefCount<TSharedSocket> { + inline TSharedSocket(TSocketHolder& s) + : TSocketHolder(s.Release()) + { + 
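// A hedged usage sketch of THttpsOptions::Set() defined above; the CA path is a
// made-up example and the values are parsed with FromString, as in the
// YNDX_NEH_HTTPS_TRY_SET macro.
#include <library/cpp/neh/https.h>

void ConfigureNehHttps() {
    NNeh::THttpsOptions::Set("CAFile", "/etc/ssl/certs/ca-certificates.crt"); // hypothetical path
    NNeh::THttpsOptions::Set("CheckCertificateHostname", "1");
    NNeh::THttpsOptions::Set("RedirectionNotError", "1"); // treat 3xx replies as success
}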
InputConnections()->Counter.Inc(); + } + + ~TSharedSocket() { + InputConnections()->Counter.Dec(); + } + }; + + using TSocketRef = TIntrusivePtr<TSharedSocket>; + + struct TX509Deleter { + static void Destroy(X509* cert) { + X509_free(cert); + } + }; + using TX509Holder = THolder<X509, TX509Deleter>; + + struct TSslSessionDeleter { + static void Destroy(SSL_SESSION* sess) { + SSL_SESSION_free(sess); + } + }; + using TSslSessionHolder = THolder<SSL_SESSION, TSslSessionDeleter>; + + struct TSslDeleter { + static void Destroy(SSL* ssl) { + SSL_free(ssl); + } + }; + using TSslHolder = THolder<SSL, TSslDeleter>; + + // read from bio and write via operator<<() to dst + template <typename T> + class TBIOInput : public NOpenSSL::TAbstractIO { + public: + TBIOInput(T& dst) + : Dst_(dst) + { + } + + int Write(const char* data, size_t dlen, size_t* written) override { + Dst_ << TStringBuf(data, dlen); + *written = dlen; + return 1; + } + + int Read(char* data, size_t dlen, size_t* readbytes) override { + Y_UNUSED(data); + Y_UNUSED(dlen); + Y_UNUSED(readbytes); + return -1; + } + + int Puts(const char* buf) override { + Y_UNUSED(buf); + return -1; + } + + int Gets(char* buf, int len) override { + Y_UNUSED(buf); + Y_UNUSED(len); + return -1; + } + + void Flush() override { + } + + private: + T& Dst_; + }; + } + + class TSslException: public yexception { + public: + TSslException() = default; + + TSslException(TStringBuf f) { + *this << f << Endl; + InitErr(); + } + + TSslException(TStringBuf f, const SSL* ssl, int ret) { + *this << f << TStringBuf(" error type: "); + const int etype = SSL_get_error(ssl, ret); + switch (etype) { + case SSL_ERROR_ZERO_RETURN: + *this << TStringBuf("SSL_ERROR_ZERO_RETURN"); + break; + case SSL_ERROR_WANT_READ: + *this << TStringBuf("SSL_ERROR_WANT_READ"); + break; + case SSL_ERROR_WANT_WRITE: + *this << TStringBuf("SSL_ERROR_WANT_WRITE"); + break; + case SSL_ERROR_WANT_CONNECT: + *this << TStringBuf("SSL_ERROR_WANT_CONNECT"); + break; + case SSL_ERROR_WANT_ACCEPT: + *this << TStringBuf("SSL_ERROR_WANT_ACCEPT"); + break; + case SSL_ERROR_WANT_X509_LOOKUP: + *this << TStringBuf("SSL_ERROR_WANT_X509_LOOKUP"); + break; + case SSL_ERROR_SYSCALL: + *this << TStringBuf("SSL_ERROR_SYSCALL ret: ") << ret << TStringBuf(", errno: ") << errno; + break; + case SSL_ERROR_SSL: + *this << TStringBuf("SSL_ERROR_SSL"); + break; + } + *this << ' '; + InitErr(); + } + + private: + void InitErr() { + TBIOInput<TSslException> bio(*this); + ERR_print_errors(bio); + } + }; + + namespace { + enum EMatchResult { + MATCH_FOUND, + NO_MATCH, + NO_EXTENSION, + ERROR + }; + bool EqualNoCase(TStringBuf a, TStringBuf b) { + return (a.size() == b.size()) && ToString(a).to_lower() == ToString(b).to_lower(); + } + bool MatchDomainName(TStringBuf tmpl, TStringBuf name) { + // match wildcards only in the left-most part + // do not support (optional according to RFC) partial wildcards (ww*.yandex.ru) + // see RFC-6125 + TStringBuf tmplRest = tmpl; + TStringBuf tmplFirst = tmplRest.NextTok('.'); + if (tmplFirst == "*") { + tmpl = tmplRest; + name.NextTok('.'); + } + return EqualNoCase(tmpl, name); + } + + EMatchResult MatchCertAltNames(X509* cert, TStringBuf hostname) { + EMatchResult result = NO_MATCH; + STACK_OF(GENERAL_NAME)* names = (STACK_OF(GENERAL_NAME)*)X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, NULL); + if (!names) { + return NO_EXTENSION; + } + + int namesCt = sk_GENERAL_NAME_num(names); + for (int i = 0; i < namesCt; ++i) { + const GENERAL_NAME* name = sk_GENERAL_NAME_value(names, i); + + 
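// Expected behaviour of MatchDomainName() above (wildcards are honoured only in
// the left-most label and the comparison is case-insensitive):
//   MatchDomainName("*.yandex.ru",   "mail.yandex.ru") -> true
//   MatchDomainName("ww*.yandex.ru", "www.yandex.ru")  -> false (partial wildcards unsupported)
//   MatchDomainName("yandex.ru",     "YANDEX.RU")      -> true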
if (name->type == GEN_DNS) { + TStringBuf dnsName((const char*)ASN1_STRING_get0_data(name->d.dNSName), ASN1_STRING_length(name->d.dNSName)); + if (MatchDomainName(dnsName, hostname)) { + result = MATCH_FOUND; + break; + } + } + } + sk_GENERAL_NAME_pop_free(names, GENERAL_NAME_free); + return result; + } + + EMatchResult MatchCertCommonName(X509* cert, TStringBuf hostname) { + int commonNameLoc = X509_NAME_get_index_by_NID(X509_get_subject_name(cert), NID_commonName, -1); + if (commonNameLoc < 0) { + return ERROR; + } + + X509_NAME_ENTRY* commonNameEntry = X509_NAME_get_entry(X509_get_subject_name(cert), commonNameLoc); + if (!commonNameEntry) { + return ERROR; + } + + ASN1_STRING* commonNameAsn1 = X509_NAME_ENTRY_get_data(commonNameEntry); + if (!commonNameAsn1) { + return ERROR; + } + + TStringBuf commonName((const char*)ASN1_STRING_get0_data(commonNameAsn1), ASN1_STRING_length(commonNameAsn1)); + + return MatchDomainName(commonName, hostname) + ? MATCH_FOUND + : NO_MATCH; + } + + bool CheckCertHostname(X509* cert, TStringBuf hostname) { + switch (MatchCertAltNames(cert, hostname)) { + case MATCH_FOUND: + return true; + break; + case NO_EXTENSION: + return MatchCertCommonName(cert, hostname) == MATCH_FOUND; + break; + default: + return false; + } + } + + void ParseUserInfo(const TParsedLocation& loc, TString& cert, TString& pvtKey) { + if (!loc.UserInfo) { + return; + } + + TStringBuf kws = loc.UserInfo; + while (kws) { + TStringBuf name = kws.NextTok('='); + TStringBuf value = kws.NextTok(';'); + if (TStringBuf("cert") == name) { + cert = value; + } else if (TStringBuf("key") == name) { + pvtKey = value; + } + } + } + + struct TSSLInit { + inline TSSLInit() { + InitOpenSSL(); + } + } SSL_INIT; + } + + static inline void PrepareSocket(SOCKET s) { + SetNoDelay(s, true); + } + + class TConnCache; + static TConnCache* SocketCache(); + + class TConnCache: public IThreadFactory::IThreadAble { + public: + typedef TAutoLockFreeQueue<TSocketHolder> TConnList; + typedef TAutoPtr<TSocketHolder> TSocketRef; + + struct TConnection { + inline TConnection(TSocketRef& s, bool reUsed, const TResolvedHost* host) noexcept + : Socket(s) + , ReUsed(reUsed) + , Host(host) + { + SocketCache()->ActiveSockets.Inc(); + } + + inline ~TConnection() { + if (!!Socket) { + SocketCache()->ActiveSockets.Dec(); + } + } + + SOCKET Fd() { + return *Socket; + } + + protected: + friend class TConnCache; + TSocketRef Socket; + + public: + const bool ReUsed; + const TResolvedHost* Host; + }; + + TConnCache() + : InPurging_(0) + , MaxConnId_(0) + , Shutdown_(false) + { + T_ = SystemThreadFactory()->Run(this); + } + + ~TConnCache() override { + { + TGuard<TMutex> g(PurgeMutex_); + + Shutdown_ = true; + CondPurge_.Signal(); + } + + T_->Join(); + } + + //used for forwarding filling cache + class TConnector: public IJob { + public: + //create fresh connection + TConnector(const TResolvedHost* host) + : Host_(host) + { + } + + //continue connecting exist socket + TConnector(const TResolvedHost* host, TSocketRef& s) + : Host_(host) + , S_(s) + { + } + + void DoRun(TCont* c) override { + THolder<TConnector> This(this); + + try { + if (!S_) { + TSocketRef res(new TSocketHolder()); + + for (TNetworkAddress::TIterator it = Host_->Addr.Begin(); it != Host_->Addr.End(); ++it) { + int ret = NCoro::ConnectD(c, *res, *it, TDuration::MilliSeconds(300).ToDeadLine()); + + if (!ret) { + TConnection tc(res, false, Host_); + SocketCache()->Release(tc); + return; + } + + if (ret == ECANCELED) { + return; + } + } + } else { + if (!NCoro::PollT(c, 
*S_, CONT_POLL_WRITE, TDuration::MilliSeconds(300))) { + TConnection tc(S_, false, Host_); + SocketCache()->Release(tc); + } + } + } catch (...) { + } + } + + private: + const TResolvedHost* Host_; + TSocketRef S_; + }; + + TConnection* Connect(TCont* c, const TString& msgAddr, const TResolvedHost* addr, TErrorRef* error) { + if (ExceedHardLimit()) { + if (error) { + *error = new TError("neh::https output connections limit reached", TError::TType::UnknownType); + } + return nullptr; + } + + TSocketRef res; + TConnList& connList = ConnList(addr); + + while (connList.Dequeue(&res)) { + CachedSockets.Dec(); + + if (IsNotSocketClosedByOtherSide(*res)) { + if (connList.Size() == 0) { + //available connections exhausted - try create yet one (reserve) + TAutoPtr<IJob> job(new TConnector(addr)); + + if (c) { + try { + c->Executor()->Create(*job, "https-con"); + Y_UNUSED(job.Release()); + } catch (...) { + } + } else { + JobQueue()->Schedule(job); + } + } + return new TConnection(res, true, addr); + } + } + + if (!c) { + if (error) { + *error = new TError("directo connection failed"); + } + return nullptr; + } + + try { + //run reserve/concurrent connecting + TAutoPtr<IJob> job(new TConnector(addr)); + + c->Executor()->Create(*job, "https-con"); + Y_UNUSED(job.Release()); + } catch (...) { + } + + TNetworkAddress::TIterator ait = addr->Addr.Begin(); + res.Reset(new TSocketHolder(NCoro::Socket(*ait))); + const TInstant now(TInstant::Now()); + const TInstant deadline(now + TDuration::Seconds(10)); + TDuration delay = TDuration::MilliSeconds(8); + TInstant checkpoint = Min(deadline, now + delay); + int ret = NCoro::ConnectD(c, *res, ait->ai_addr, ait->ai_addrlen, checkpoint); + + if (ret) { + do { + if ((ret == ETIMEDOUT || ret == EINTR) && checkpoint < deadline) { + delay += delay; + checkpoint = Min(deadline, now + delay); + + TSocketRef res2; + + if (connList.Dequeue(&res2)) { + CachedSockets.Dec(); + + if (IsNotSocketClosedByOtherSide(*res2)) { + try { + TAutoPtr<IJob> job(new TConnector(addr, res)); + + c->Executor()->Create(*job, "https-con"); + Y_UNUSED(job.Release()); + } catch (...) 
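// Timing note for the connect path above: the poll checkpoint grows as
// now + 8 ms, + 16 ms, + 32 ms, ... (the delay doubles on every ETIMEDOUT/EINTR)
// up to a 10-second deadline, and between retries the loop also checks whether a
// cached socket for the same host has appeared and switches to it, handing the
// half-connected socket to a background TConnector job.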
{ + } + + res = res2; + + break; + } + } + } else { + if (error) { + *error = new TError(TStringBuilder() << TStringBuf("can not connect to ") << msgAddr); + } + return nullptr; + } + } while (ret = NCoro::PollD(c, *res, CONT_POLL_WRITE, checkpoint)); + } + + PrepareSocket(*res); + + return new TConnection(res, false, addr); + } + + inline void Release(TConnection& conn) { + if (!ExceedHardLimit()) { + size_t maxConnId = MaxConnId_.load(std::memory_order_acquire); + + while (maxConnId < conn.Host->Id) { + MaxConnId_.compare_exchange_strong( + maxConnId, + conn.Host->Id, + std::memory_order_seq_cst, + std::memory_order_seq_cst); + maxConnId = MaxConnId_.load(std::memory_order_acquire); + } + + CachedSockets.Inc(); + ActiveSockets.Dec(); + + ConnList(conn.Host).Enqueue(conn.Socket); + } + + if (CachedSockets.Val() && ExceedSoftLimit()) { + SuggestPurgeCache(); + } + } + + void SetFdLimits(size_t soft, size_t hard) { + Limits.SetSoft(soft); + Limits.SetHard(hard); + } + + private: + void SuggestPurgeCache() { + if (AtomicTryLock(&InPurging_)) { + //evaluate the usefulness of purging the cache + //если в кеше мало соединений (< MaxConnId_/16 или 64), не чистим кеш + if ((size_t)CachedSockets.Val() > (Min((size_t)MaxConnId_.load(std::memory_order_acquire), (size_t)1024U) >> 4)) { + //по мере приближения к hardlimit нужда в чистке cache приближается к 100% + size_t closenessToHardLimit256 = ((ActiveSockets.Val() + 1) << 8) / (Limits.Delta() + 1); + //чем больше соединений в кеше, а не в работе, тем менее нужен кеш (можно его почистить) + size_t cacheUselessness256 = ((CachedSockets.Val() + 1) << 8) / (ActiveSockets.Val() + 1); + + //итого, - пороги срабатывания: + //при достижении soft-limit, если соединения в кеше, а не в работе + //на полпути от soft-limit к hard-limit, если в кеше больше половины соединений + //при приближении к hardlimit пытаться почистить кеш почти постоянно + if ((closenessToHardLimit256 + cacheUselessness256) >= 256U) { + TGuard<TMutex> g(PurgeMutex_); + + CondPurge_.Signal(); + return; //memo: thread MUST unlock InPurging_ (see DoExecute()) + } + } + AtomicUnlock(&InPurging_); + } + } + + void DoExecute() override { + while (true) { + { + TGuard<TMutex> g(PurgeMutex_); + + if (Shutdown_) + return; + + CondPurge_.WaitI(PurgeMutex_); + } + + PurgeCache(); + + AtomicUnlock(&InPurging_); + } + } + + inline void OnPurgeSocket(ui64& processed) { + CachedSockets.Dec(); + if ((processed++ & 0x3f) == 0) { + //suspend execution every 64 processed socket (clean rate ~= 6400 sockets/sec) + Sleep(TDuration::MilliSeconds(10)); + } + } + + void PurgeCache() noexcept { + //try remove at least ExceedSoftLimit() oldest connections from cache + //вычисляем долю кеша, которую нужно почистить (в 256 долях) (но не менее 1/32 кеша) + size_t frac256 = Min(size_t(Max(size_t(256U / 32U), (ExceedSoftLimit() << 8) / (CachedSockets.Val() + 1))), (size_t)256U); + TSocketRef tmp; + + ui64 processed = 0; + for (size_t i = 0; i < MaxConnId_.load(std::memory_order_acquire) && !Shutdown_; i++) { + TConnList& tc = Lst_.Get(i); + if (size_t qsize = tc.Size()) { + //в каждой очереди чистим вычисленную долю + size_t purgeCounter = ((qsize * frac256) >> 8); + + if (!purgeCounter && qsize) { + if (qsize <= 2) { + TSocketRef res; + if (tc.Dequeue(&res)) { + if (IsNotSocketClosedByOtherSide(*res)) { + tc.Enqueue(res); + } else { + OnPurgeSocket(processed); + } + } + } else { + purgeCounter = 1; + } + } + while (purgeCounter-- && tc.Dequeue(&tmp)) { + OnPurgeSocket(processed); + } + } + } + } + + inline TConnList& 
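// Worked numbers for the purge heuristic above (soft = 10000, hard = 15000, so
// Limits.Delta() == 5000; 8000 active and 4000 cached sockets):
//   closenessToHardLimit256 = ((8000 + 1) << 8) / (5000 + 1) = 409
//   cacheUselessness256     = ((4000 + 1) << 8) / (8000 + 1) = 128
// 409 + 128 >= 256, so CondPurge_ is signalled and PurgeCache() tries to drop at
// least ExceedSoftLimit() == 2000 of the oldest cached connections.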
ConnList(const TResolvedHost* addr) { + return Lst_.Get(addr->Id); + } + + inline size_t TotalSockets() const noexcept { + return ActiveSockets.Val() + CachedSockets.Val(); + } + + inline size_t ExceedSoftLimit() const noexcept { + return NHttp::TFdLimits::ExceedLimit(TotalSockets(), Limits.Soft()); + } + + inline size_t ExceedHardLimit() const noexcept { + return NHttp::TFdLimits::ExceedLimit(TotalSockets(), Limits.Hard()); + } + + NHttp::TFdLimits Limits; + TAtomicCounter ActiveSockets; + TAtomicCounter CachedSockets; + + NHttp::TLockFreeSequence<TConnList> Lst_; + + TAtomic InPurging_; + std::atomic<size_t> MaxConnId_; + + TAutoPtr<IThreadFactory::IThread> T_; + TCondVar CondPurge_; + TMutex PurgeMutex_; + TAtomicBool Shutdown_; + }; + + class TSslCtx: public TThrRefBase { + protected: + TSslCtx() + : SslCtx_(nullptr) + { + } + + public: + ~TSslCtx() override { + SSL_CTX_free(SslCtx_); + } + + operator SSL_CTX*() { + return SslCtx_; + } + + protected: + SSL_CTX* SslCtx_; + }; + using TSslCtxPtr = TIntrusivePtr<TSslCtx>; + + class TSslCtxServer: public TSslCtx { + struct TPasswordCallbackUserData { + TParsedLocation Location; + TString CertFileName; + TString KeyFileName; + }; + class TUserDataHolder { + public: + TUserDataHolder(SSL_CTX* ctx, const TParsedLocation& location, const TString& certFileName, const TString& keyFileName) + : SslCtx_(ctx) + , Data_{location, certFileName, keyFileName} + { + SSL_CTX_set_default_passwd_cb_userdata(SslCtx_, &Data_); + } + ~TUserDataHolder() { + SSL_CTX_set_default_passwd_cb_userdata(SslCtx_, nullptr); + } + private: + SSL_CTX* SslCtx_; + TPasswordCallbackUserData Data_; + }; + public: + TSslCtxServer(const TParsedLocation& loc) { + const SSL_METHOD* method = SSLv23_server_method(); + if (Y_UNLIKELY(!method)) { + ythrow TSslException(TStringBuf("SSLv23_server_method")); + } + + SslCtx_ = SSL_CTX_new(method); + if (Y_UNLIKELY(!SslCtx_)) { + ythrow TSslException(TStringBuf("SSL_CTX_new(server)")); + } + + TString cert, key; + ParseUserInfo(loc, cert, key); + + TUserDataHolder holder(SslCtx_, loc, cert, key); + + SSL_CTX_set_default_passwd_cb(SslCtx_, [](char* buf, int size, int rwflag, void* userData) -> int { + Y_UNUSED(rwflag); + Y_UNUSED(userData); + + if (THttpsOptions::KeyPasswdCallback == nullptr || userData == nullptr) { + return 0; + } + + auto data = static_cast<TPasswordCallbackUserData*>(userData); + const auto& passwd = THttpsOptions::KeyPasswdCallback(data->Location, data->CertFileName, data->KeyFileName); + + if (size < static_cast<int>(passwd.size())) { + return -1; + } + + return passwd.copy(buf, size, 0); + }); + + if (!cert || !key) { + ythrow TSslException() << TStringBuf("no certificate or private key is specified for server"); + } + + if (1 != SSL_CTX_use_certificate_chain_file(SslCtx_, cert.data())) { + ythrow TSslException(TStringBuf("SSL_CTX_use_certificate_chain_file (server)")); + } + + if (1 != SSL_CTX_use_PrivateKey_file(SslCtx_, key.data(), SSL_FILETYPE_PEM)) { + ythrow TSslException(TStringBuf("SSL_CTX_use_PrivateKey_file (server)")); + } + + if (1 != SSL_CTX_check_private_key(SslCtx_)) { + ythrow TSslException(TStringBuf("SSL_CTX_check_private_key (server)")); + } + } + }; + + class TSslCtxClient: public TSslCtx { + public: + TSslCtxClient() { + const SSL_METHOD* method = SSLv23_client_method(); + if (Y_UNLIKELY(!method)) { + ythrow TSslException(TStringBuf("SSLv23_client_method")); + } + + SslCtx_ = SSL_CTX_new(method); + if (Y_UNLIKELY(!SslCtx_)) { + ythrow TSslException(TStringBuf("SSL_CTX_new(client)")); + } + + 
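// The server-side SSL context above takes its certificate and private key from
// the "cert=<path>;key=<path>" pairs that ParseUserInfo() extracts from
// TParsedLocation::UserInfo, so an illustrative (entirely made-up) listen
// address would look like
//   "https://cert=/path/server.pem;key=/path/server.key@somehost:443/service".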
const TString& caFile = THttpsOptions::CAFile; + const TString& caPath = THttpsOptions::CAPath; + if (caFile || caPath) { + if (!SSL_CTX_load_verify_locations(SslCtx_, caFile ? caFile.data() : nullptr, caPath ? caPath.data() : nullptr)) { + ythrow TSslException(TStringBuf("SSL_CTX_load_verify_locations(client)")); + } + } + + SSL_CTX_set_options(SslCtx_, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION); + if (THttpsOptions::ClientVerifyCallback) { + SSL_CTX_set_verify(SslCtx_, SSL_VERIFY_PEER, THttpsOptions::ClientVerifyCallback); + } else { + SSL_CTX_set_verify(SslCtx_, SSL_VERIFY_NONE, nullptr); + } + + const TString& clientCertificate = THttpsOptions::ClientCertificate; + const TString& clientPrivateKey = THttpsOptions::ClientPrivateKey; + if (clientCertificate && clientPrivateKey) { + SSL_CTX_set_default_passwd_cb(SslCtx_, [](char* buf, int size, int rwflag, void* userData) -> int { + Y_UNUSED(rwflag); + Y_UNUSED(userData); + + const TString& clientPrivateKeyPwd = THttpsOptions::ClientPrivateKeyPassword; + if (!clientPrivateKeyPwd) { + return 0; + } + if (size < static_cast<int>(clientPrivateKeyPwd.size())) { + return -1; + } + + return clientPrivateKeyPwd.copy(buf, size, 0); + }); + if (1 != SSL_CTX_use_certificate_chain_file(SslCtx_, clientCertificate.c_str())) { + ythrow TSslException(TStringBuf("SSL_CTX_use_certificate_chain_file (client)")); + } + if (1 != SSL_CTX_use_PrivateKey_file(SslCtx_, clientPrivateKey.c_str(), SSL_FILETYPE_PEM)) { + ythrow TSslException(TStringBuf("SSL_CTX_use_PrivateKey_file (client)")); + } + if (1 != SSL_CTX_check_private_key(SslCtx_)) { + ythrow TSslException(TStringBuf("SSL_CTX_check_private_key (client)")); + } + } else if (clientCertificate || clientPrivateKey) { + ythrow TSslException() << TStringBuf("both certificate and private key must be specified for client"); + } + } + + static TSslCtxClient& Instance() { + return *Singleton<TSslCtxClient>(); + } + }; + + class TContBIO : public NOpenSSL::TAbstractIO { + public: + TContBIO(SOCKET s, const TAtomicBool* canceled = nullptr) + : Timeout_(TDuration::MicroSeconds(10000)) + , S_(s) + , Canceled_(canceled) + , Cont_(nullptr) + { + } + + SOCKET Socket() { + return S_; + } + + int PollT(int what, const TDuration& timeout) { + return NCoro::PollT(Cont_, Socket(), what, timeout); + } + + void WaitUntilWritten() { +#if defined(FIONWRITE) + if (Y_LIKELY(Cont_)) { + int err; + int nbytes = Max<int>(); + TDuration tout = TDuration::MilliSeconds(10); + + while (((err = ioctl(S_, FIONWRITE, &nbytes)) == 0) && nbytes) { + err = NCoro::PollT(Cont_, S_, CONT_POLL_READ, tout); + + if (!err) { + //wait complete, cause have some data + break; + } + + if (err != ETIMEDOUT) { + ythrow TSystemError(err) << TStringBuf("request failed"); + } + + tout = tout * 2; + } + + if (err) { + ythrow TSystemError() << TStringBuf("ioctl() failed"); + } + } else { + ythrow TSslException() << TStringBuf("No cont available"); + } +#endif + } + + void AcquireCont(TCont* c) { + Cont_ = c; + } + void ReleaseCont() { + Cont_ = nullptr; + } + + int Write(const char* data, size_t dlen, size_t* written) override { + if (Y_UNLIKELY(!Cont_)) { + return -1; + } + + while (true) { + auto done = NCoro::WriteI(Cont_, S_, data, dlen); + if (done.Status() != EAGAIN) { + *written = done.Checked(); + return 1; + } + } + } + + int Read(char* data, size_t dlen, size_t* readbytes) override { + if (Y_UNLIKELY(!Cont_)) { + return -1; + } + + if (!Canceled_) { + while (true) { + auto done = NCoro::ReadI(Cont_, S_, data, dlen); + if (EAGAIN != 
done.Status()) { + *readbytes = done.Processed(); + return 1; + } + } + } + + while (true) { + if (*Canceled_) { + return SSL_RVAL_TIMEOUT; + } + + TContIOStatus ioStat(NCoro::ReadT(Cont_, S_, data, dlen, Timeout_)); + if (ioStat.Status() == ETIMEDOUT) { + //increase to 1.5 times every iteration (to 1sec floor) + Timeout_ = TDuration::MicroSeconds(Min<ui64>(1000000, Timeout_.MicroSeconds() + (Timeout_.MicroSeconds() >> 1))); + continue; + } + + *readbytes = ioStat.Processed(); + return 1; + } + } + + int Puts(const char* buf) override { + Y_UNUSED(buf); + return -1; + } + + int Gets(char* buf, int size) override { + Y_UNUSED(buf); + Y_UNUSED(size); + return -1; + } + + void Flush() override { + } + + private: + TDuration Timeout_; + SOCKET S_; + const TAtomicBool* Canceled_; + TCont* Cont_; + }; + + class TSslIOStream: public IInputStream, public IOutputStream { + protected: + TSslIOStream(TSslCtx& sslCtx, TAutoPtr<TContBIO> connection) + : Connection_(connection) + , SslCtx_(sslCtx) + , Ssl_(nullptr) + { + } + + virtual void Handshake() = 0; + + public: + void WaitUntilWritten() { + if (Connection_) { + Connection_->WaitUntilWritten(); + } + } + + int PollReadT(const TDuration& timeout) { + if (!Connection_) { + return -1; + } + + while (true) { + const int rpoll = Connection_->PollT(CONT_POLL_READ, timeout); + if (!Ssl_ || rpoll) { + return rpoll; + } + + char c = 0; + const int rpeek = SSL_peek(Ssl_.Get(), &c, sizeof(c)); + if (rpeek < 0) { + return -1; + } else if (rpeek > 0) { + return 0; + } else { + if ((SSL_get_shutdown(Ssl_.Get()) & SSL_RECEIVED_SHUTDOWN) != 0) { + Shutdown(); // wait until shutdown is finished + return EIO; + } + } + } + } + + void Shutdown() { + if (Ssl_ && Connection_) { + for (size_t i = 0; i < 2; ++i) { + bool rval = SSL_shutdown(Ssl_.Get()); + if (0 == rval) { + continue; + } else if (1 == rval) { + break; + } + } + } + } + + inline void AcquireCont(TCont* c) { + if (Y_UNLIKELY(!Connection_)) { + ythrow TSslException() << TStringBuf("no connection provided"); + } + + Connection_->AcquireCont(c); + } + + inline void ReleaseCont() { + if (Connection_) { + Connection_->ReleaseCont(); + } + } + + TContIOStatus WriteVectorI(const TList<IOutputStream::TPart>& vec) { + for (const auto& p : vec) { + Write(p.buf, p.len); + } + return TContIOStatus::Success(vec.size()); + } + + SOCKET Socket() { + if (Y_UNLIKELY(!Connection_)) { + ythrow TSslException() << TStringBuf("no connection provided"); + } + + return Connection_->Socket(); + } + + private: + void DoWrite(const void* buf, size_t len) override { + if (Y_UNLIKELY(!Connection_)) { + ythrow TSslException() << TStringBuf("DoWrite() no connection provided"); + } + + const int rval = SSL_write(Ssl_.Get(), buf, len); + if (rval <= 0) { + ythrow TSslException(TStringBuf("SSL_write"), Ssl_.Get(), rval); + } + } + + size_t DoRead(void* buf, size_t len) override { + if (Y_UNLIKELY(!Connection_)) { + ythrow TSslException() << TStringBuf("DoRead() no connection provided"); + } + + const int rval = SSL_read(Ssl_.Get(), buf, len); + if (rval < 0) { + if (SSL_RVAL_TIMEOUT == rval) { + ythrow TSystemError(ECANCELED) << TStringBuf(" http request canceled"); + } + ythrow TSslException(TStringBuf("SSL_read"), Ssl_.Get(), rval); + } else if (0 == rval) { + if ((SSL_get_shutdown(Ssl_.Get()) & SSL_RECEIVED_SHUTDOWN) != 0) { + return rval; + } else { + const int err = SSL_get_error(Ssl_.Get(), rval); + if (SSL_ERROR_ZERO_RETURN != err) { + ythrow TSslException(TStringBuf("SSL_read"), Ssl_.Get(), rval); + } + } + } + + return 
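// Note on the cancellation-aware read above: when a Canceled_ flag is supplied,
// the read polls in slices that grow by 1.5x per timeout (10 ms, 15 ms, 22.5 ms,
// ...) and are capped at 1 second, so a cancelled request is noticed within
// roughly a second even while still waiting for server data.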
static_cast<size_t>(rval); + } + + protected: + // just for ssl debug + static void InfoCB(const SSL* s, int where, int ret) { + TStringBuf str; + const int w = where & ~SSL_ST_MASK; + if (w & SSL_ST_CONNECT) { + str = TStringBuf("SSL_connect"); + } else if (w & SSL_ST_ACCEPT) { + str = TStringBuf("SSL_accept"); + } else { + str = TStringBuf("undefined"); + } + + if (where & SSL_CB_LOOP) { + Cerr << str << ':' << SSL_state_string_long(s) << Endl; + } else if (where & SSL_CB_ALERT) { + Cerr << TStringBuf("SSL3 alert ") << ((where & SSL_CB_READ) ? TStringBuf("read") : TStringBuf("write")) << ' ' << SSL_alert_type_string_long(ret) << ':' << SSL_alert_desc_string_long(ret) << Endl; + } else if (where & SSL_CB_EXIT) { + if (ret == 0) { + Cerr << str << TStringBuf(":failed in ") << SSL_state_string_long(s) << Endl; + } else if (ret < 0) { + Cerr << str << TStringBuf(":error in ") << SSL_state_string_long(s) << Endl; + } + } + } + + protected: + THolder<TContBIO> Connection_; + TSslCtx& SslCtx_; + TSslHolder Ssl_; + }; + + class TContBIOWatcher { + public: + TContBIOWatcher(TSslIOStream& io, TCont* c) noexcept + : IO_(io) + { + IO_.AcquireCont(c); + } + + ~TContBIOWatcher() noexcept { + IO_.ReleaseCont(); + } + + private: + TSslIOStream& IO_; + }; + + class TSslClientIOStream: public TSslIOStream { + public: + TSslClientIOStream(TSslCtxClient& sslCtx, const TParsedLocation& loc, SOCKET s, const TAtomicBool* canceled) + : TSslIOStream(sslCtx, new TContBIO(s, canceled)) + , Location_(loc) + { + } + + void Handshake() override { + Ssl_.Reset(SSL_new(SslCtx_)); + if (THttpsOptions::EnableSslClientDebug) { + SSL_set_info_callback(Ssl_.Get(), InfoCB); + } + + BIO_up_ref(*Connection_); // SSL_set_bio consumes only one reference if rbio and wbio are the same + SSL_set_bio(Ssl_.Get(), *Connection_, *Connection_); + + const TString hostname(Location_.Host); + const int rev = SSL_set_tlsext_host_name(Ssl_.Get(), hostname.data()); + if (Y_UNLIKELY(1 != rev)) { + ythrow TSslException(TStringBuf("SSL_set_tlsext_host_name(client)"), Ssl_.Get(), rev); + } + + TString cert, pvtKey; + ParseUserInfo(Location_, cert, pvtKey); + + if (cert && (1 != SSL_use_certificate_file(Ssl_.Get(), cert.data(), SSL_FILETYPE_PEM))) { + ythrow TSslException(TStringBuf("SSL_use_certificate_file(client)")); + } + + if (pvtKey) { + if (1 != SSL_use_PrivateKey_file(Ssl_.Get(), pvtKey.data(), SSL_FILETYPE_PEM)) { + ythrow TSslException(TStringBuf("SSL_use_PrivateKey_file(client)")); + } + + if (1 != SSL_check_private_key(Ssl_.Get())) { + ythrow TSslException(TStringBuf("SSL_check_private_key(client)")); + } + } + + SSL_set_connect_state(Ssl_.Get()); + + // TODO restore session if reconnect + const int rval = SSL_do_handshake(Ssl_.Get()); + if (1 != rval) { + if (rval == SSL_RVAL_TIMEOUT) { + ythrow TSystemError(ECANCELED) << TStringBuf("canceled"); + } else { + ythrow TSslException(TStringBuf("BIO_do_handshake(client)"), Ssl_.Get(), rval); + } + } + + if (THttpsOptions::CheckCertificateHostname) { + TX509Holder peerCert(SSL_get_peer_certificate(Ssl_.Get())); + if (!peerCert) { + ythrow TSslException(TStringBuf("SSL_get_peer_certificate(client)")); + } + + if (!CheckCertHostname(peerCert.Get(), Location_.Host)) { + ythrow TSslException(TStringBuf("CheckCertHostname(client)")); + } + } + } + + private: + const TParsedLocation Location_; + //TSslSessionHolder Session_; + }; + + static TConnCache* SocketCache() { + return Singleton<TConnCache>(); + } + + //some templates magic + template <class T> + static inline TAutoPtr<T> AutoPtr(T* t) 
noexcept { + return t; + } + + static inline TString ReadAll(THttpInput& in) { + TString ret; + ui64 clin; + + if (in.GetContentLength(clin)) { + const size_t cl = SafeIntegerCast<size_t>(clin); + + ret.ReserveAndResize(cl); + size_t sz = in.Load(ret.begin(), cl); + if (sz != cl) { + throw yexception() << TStringBuf("not full content: ") << sz << TStringBuf(" bytes from ") << cl; + } + } else if (in.HasContent()) { + TVector<char> buff(9500); //common jumbo frame size + + while (size_t len = in.Read(buff.data(), buff.size())) { + ret.AppendNoAlias(buff.data(), len); + } + } + + return ret; + } + + template <class TRequestType> + class THttpsRequest: public IJob { + public: + inline THttpsRequest(TSimpleHandleRef hndl, TMessage msg) + : Hndl_(hndl) + , Msg_(std::move(msg)) + , Loc_(Msg_.Addr) + , Addr_(CachedThrResolve(TResolveInfo(Loc_.Host, Loc_.GetPort()))) + { + } + + void DoRun(TCont* c) override { + THolder<THttpsRequest> This(this); + + if (c->Cancelled()) { + Hndl_->NotifyError(new TError("canceled", TError::TType::Cancelled)); + return; + } + + TErrorRef error; + THolder<TConnCache::TConnection> s(SocketCache()->Connect(c, Msg_.Addr, Addr_, &error)); + if (!s) { + Hndl_->NotifyError(error); + return; + } + + TSslClientIOStream io(TSslCtxClient::Instance(), Loc_, s->Fd(), Hndl_->CanceledPtr()); + TContBIOWatcher w(io, c); + TString received; + THttpHeaders headers; + TString firstLine; + + try { + io.Handshake(); + RequestData().SendTo(io); + Req_.Destroy(); + error = ProcessRecv(io, &received, &headers, &firstLine); + } catch (const TSystemError& e) { + if (c->Cancelled() || e.Status() == ECANCELED) { + error = new TError("canceled", TError::TType::Cancelled); + } else { + error = new TError(CurrentExceptionMessage()); + } + } catch (...) { + if (c->Cancelled()) { + error = new TError("canceled", TError::TType::Cancelled); + } else { + error = new TError(CurrentExceptionMessage()); + } + } + + if (error) { + Hndl_->NotifyError(error, received, firstLine, headers); + } else { + io.Shutdown(); + SocketCache()->Release(*s); + Hndl_->NotifyResponse(received, firstLine, headers); + } + } + + TErrorRef ProcessRecv(TSslClientIOStream& io, TString* data, THttpHeaders* headers, TString* firstLine) { + io.WaitUntilWritten(); + + Hndl_->SetSendComplete(); + + THttpInput in(&io); + *data = ReadAll(in); + *firstLine = in.FirstLine(); + *headers = in.Headers(); + + i32 code = ParseHttpRetCode(in.FirstLine()); + if (code < 200 || code > (!THttpsOptions::RedirectionNotError ? 
299 : 399)) { + return new TError(TStringBuilder() << TStringBuf("request failed(") << in.FirstLine() << ')', TError::TType::ProtocolSpecific, code); + } + + return nullptr; + } + + const NHttp::TRequestData& RequestData() { + if (!Req_) { + Req_ = TRequestType::Build(Msg_, Loc_); + } + return *Req_; + } + + private: + TSimpleHandleRef Hndl_; + const TMessage Msg_; + const TParsedLocation Loc_; + const TResolvedHost* Addr_; + NHttp::TRequestData::TPtr Req_; + }; + + class TServer: public IRequester, public TContListener::ICallBack { + class TSslServerIOStream: public TSslIOStream, public TThrRefBase { + public: + TSslServerIOStream(TSslCtxServer& sslCtx, TSocketRef s) + : TSslIOStream(sslCtx, new TContBIO(*s)) + , S_(s) + { + } + + void Close(bool shutdown) { + if (shutdown) { + Shutdown(); + } + S_->Close(); + } + + void Handshake() override { + if (!Ssl_) { + Ssl_.Reset(SSL_new(SslCtx_)); + if (THttpsOptions::EnableSslServerDebug) { + SSL_set_info_callback(Ssl_.Get(), InfoCB); + } + + BIO_up_ref(*Connection_); // SSL_set_bio consumes only one reference if rbio and wbio are the same + SSL_set_bio(Ssl_.Get(), *Connection_, *Connection_); + + const int rc = SSL_accept(Ssl_.Get()); + if (1 != rc) { + ythrow TSslException(TStringBuf("SSL_accept"), Ssl_.Get(), rc); + } + } + + if (!SSL_is_init_finished(Ssl_.Get())) { + const int rc = SSL_do_handshake(Ssl_.Get()); + if (rc != 1) { + ythrow TSslException(TStringBuf("SSL_do_handshake"), Ssl_.Get(), rc); + } + } + } + + private: + TSocketRef S_; + }; + + class TJobsQueue: public TAutoOneConsumerPipeQueue<IJob>, public TThrRefBase { + }; + + typedef TIntrusivePtr<TJobsQueue> TJobsQueueRef; + + class TWrite: public IJob, public TData { + private: + template <class T> + static void WriteHeader(IOutputStream& os, TStringBuf name, T value) { + os << name << TStringBuf(": ") << value << TStringBuf("\r\n"); + } + + static void WriteHttpCode(IOutputStream& os, TMaybe<IRequest::TResponseError> error) { + if (!error.Defined()) { + os << HttpCodeStrEx(HttpCodes::HTTP_OK); + return; + } + + switch (*error) { + case IRequest::TResponseError::BadRequest: + os << HttpCodeStrEx(HttpCodes::HTTP_BAD_REQUEST); + break; + case IRequest::TResponseError::Forbidden: + os << HttpCodeStrEx(HttpCodes::HTTP_FORBIDDEN); + break; + case IRequest::TResponseError::NotExistService: + os << HttpCodeStrEx(HttpCodes::HTTP_NOT_FOUND); + break; + case IRequest::TResponseError::TooManyRequests: + os << HttpCodeStrEx(HttpCodes::HTTP_TOO_MANY_REQUESTS); + break; + case IRequest::TResponseError::InternalError: + os << HttpCodeStrEx(HttpCodes::HTTP_INTERNAL_SERVER_ERROR); + break; + case IRequest::TResponseError::NotImplemented: + os << HttpCodeStrEx(HttpCodes::HTTP_NOT_IMPLEMENTED); + break; + case IRequest::TResponseError::BadGateway: + os << HttpCodeStrEx(HttpCodes::HTTP_BAD_GATEWAY); + break; + case IRequest::TResponseError::ServiceUnavailable: + os << HttpCodeStrEx(HttpCodes::HTTP_SERVICE_UNAVAILABLE); + break; + case IRequest::TResponseError::BandwidthLimitExceeded: + os << HttpCodeStrEx(HttpCodes::HTTP_BANDWIDTH_LIMIT_EXCEEDED); + break; + case IRequest::TResponseError::MaxResponseError: + ythrow yexception() << TStringBuf("unknow type of error"); + } + } + + public: + inline TWrite(TData& data, const TString& compressionScheme, TIntrusivePtr<TSslServerIOStream> io, TServer* server, const TString& headers, int httpCode) + : CompressionScheme_(compressionScheme) + , IO_(io) + , Server_(server) + , Error_(TMaybe<IRequest::TResponseError>()) + , Headers_(headers) + , 
HttpCode_(httpCode) + { + swap(data); + } + + inline TWrite(TData& data, const TString& compressionScheme, TIntrusivePtr<TSslServerIOStream> io, TServer* server, IRequest::TResponseError error, const TString& headers) + : CompressionScheme_(compressionScheme) + , IO_(io) + , Server_(server) + , Error_(error) + , Headers_(headers) + , HttpCode_(0) + { + swap(data); + } + + void DoRun(TCont* c) override { + THolder<TWrite> This(this); + + try { + TContBIOWatcher w(*IO_, c); + + PrepareSocket(IO_->Socket()); + + char buf[128]; + TMemoryOutput mo(buf, sizeof(buf)); + + mo << TStringBuf("HTTP/1.1 "); + if (HttpCode_) { + mo << HttpCodeStrEx(HttpCode_); + } else { + WriteHttpCode(mo, Error_); + } + mo << TStringBuf("\r\n"); + + if (!CompressionScheme_.empty()) { + WriteHeader(mo, TStringBuf("Content-Encoding"), TStringBuf(CompressionScheme_)); + } + WriteHeader(mo, TStringBuf("Connection"), TStringBuf("Keep-Alive")); + WriteHeader(mo, TStringBuf("Content-Length"), size()); + + mo << Headers_; + + mo << TStringBuf("\r\n"); + + IO_->Write(buf, mo.Buf() - buf); + if (size()) { + IO_->Write(data(), size()); + } + + Server_->Enqueue(new TRead(IO_, Server_)); + } catch (...) { + } + } + + private: + const TString CompressionScheme_; + TIntrusivePtr<TSslServerIOStream> IO_; + TServer* Server_; + TMaybe<IRequest::TResponseError> Error_; + TString Headers_; + int HttpCode_; + }; + + class TRequest: public IHttpRequest { + public: + inline TRequest(THttpInput& in, TIntrusivePtr<TSslServerIOStream> io, TServer* server) + : IO_(io) + , Tmp_(in.FirstLine()) + , CompressionScheme_(in.BestCompressionScheme()) + , RemoteHost_(PrintHostByRfc(*GetPeerAddr(IO_->Socket()))) + , Headers_(in.Headers()) + , H_(Tmp_) + , Server_(server) + { + } + + ~TRequest() override { + if (!!IO_) { + try { + Server_->Enqueue(new TFail(IO_, Server_)); + } catch (...) { + } + } + } + + TStringBuf Scheme() const override { + return TStringBuf("https"); + } + + TString RemoteHost() const override { + return RemoteHost_; + } + + const THttpHeaders& Headers() const override { + return Headers_; + } + + TStringBuf Method() const override { + return H_.Method; + } + + TStringBuf Cgi() const override { + return H_.Cgi; + } + + TStringBuf Service() const override { + return TStringBuf(H_.Path).Skip(1); + } + + TStringBuf RequestId() const override { + return TStringBuf(); + } + + bool Canceled() const override { + if (!IO_) { + return false; + } + return !IsNotSocketClosedByOtherSide(IO_->Socket()); + } + + void SendReply(TData& data) override { + SendReply(data, TString(), HttpCodes::HTTP_OK); + } + + void SendReply(TData& data, const TString& headers, int httpCode) override { + const bool compressed = Compress(data); + Server_->Enqueue(new TWrite(data, compressed ? 
CompressionScheme_ : TString(), IO_, Server_, headers, httpCode)); + Y_UNUSED(IO_.Release()); + } + + void SendError(TResponseError error, const THttpErrorDetails& details) override { + TData data; + Server_->Enqueue(new TWrite(data, TString(), IO_, Server_, error, details.Headers)); + Y_UNUSED(IO_.Release()); + } + + private: + bool Compress(TData& data) const { + if (CompressionScheme_ == TStringBuf("gzip")) { + try { + TData gzipped(data.size()); + TMemoryOutput out(gzipped.data(), gzipped.size()); + TZLibCompress c(&out, ZLib::GZip); + c.Write(data.data(), data.size()); + c.Finish(); + gzipped.resize(out.Buf() - gzipped.data()); + data.swap(gzipped); + return true; + } catch (yexception&) { + // gzipped data occupies more space than original data + } + } + return false; + } + + private: + TIntrusivePtr<TSslServerIOStream> IO_; + const TString Tmp_; + const TString CompressionScheme_; + const TString RemoteHost_; + const THttpHeaders Headers_; + + protected: + TParsedHttpFull H_; + TServer* Server_; + }; + + class TGetRequest: public TRequest { + public: + inline TGetRequest(THttpInput& in, TIntrusivePtr<TSslServerIOStream> io, TServer* server) + : TRequest(in, io, server) + { + } + + TStringBuf Data() const override { + return H_.Cgi; + } + + TStringBuf Body() const override { + return TStringBuf(); + } + }; + + class TPostRequest: public TRequest { + public: + inline TPostRequest(THttpInput& in, TIntrusivePtr<TSslServerIOStream> io, TServer* server) + : TRequest(in, io, server) + , Data_(ReadAll(in)) + { + } + + TStringBuf Data() const override { + return Data_; + } + + TStringBuf Body() const override { + return Data_; + } + + private: + TString Data_; + }; + + class TFail: public IJob { + public: + inline TFail(TIntrusivePtr<TSslServerIOStream> io, TServer* server) + : IO_(io) + , Server_(server) + { + } + + void DoRun(TCont* c) override { + THolder<TFail> This(this); + constexpr TStringBuf answer = "HTTP/1.1 503 Service unavailable\r\n" + "Content-Length: 0\r\n\r\n"sv; + + try { + TContBIOWatcher w(*IO_, c); + IO_->Write(answer); + Server_->Enqueue(new TRead(IO_, Server_)); + } catch (...) { + } + } + + private: + TIntrusivePtr<TSslServerIOStream> IO_; + TServer* Server_; + }; + + class TRead: public IJob { + public: + TRead(TIntrusivePtr<TSslServerIOStream> io, TServer* server, bool selfRemove = false) + : IO_(io) + , Server_(server) + , SelfRemove(selfRemove) + { + } + + inline void operator()(TCont* c) { + try { + TContBIOWatcher w(*IO_, c); + + if (IO_->PollReadT(TDuration::Seconds(InputConnections()->UnusedConnKeepaliveTimeout()))) { + IO_->Close(true); + return; + } + + IO_->Handshake(); + THttpInput in(IO_.Get()); + + const char sym = *in.FirstLine().data(); + + if (sym == 'p' || sym == 'P') { + Server_->OnRequest(new TPostRequest(in, IO_, Server_)); + } else { + Server_->OnRequest(new TGetRequest(in, IO_, Server_)); + } + } catch (...) 
{ + IO_->Close(false); + } + + if (SelfRemove) { + delete this; + } + } + + private: + void DoRun(TCont* c) override { + THolder<TRead> This(this); + (*this)(c); + } + + private: + TIntrusivePtr<TSslServerIOStream> IO_; + TServer* Server_; + bool SelfRemove = false; + }; + + public: + inline TServer(IOnRequest* cb, const TParsedLocation& loc) + : CB_(cb) + , E_(RealStackSize(16000)) + , L_(new TContListener(this, &E_, TContListener::TOptions().SetDeferAccept(true))) + , JQ_(new TJobsQueue()) + , SslCtx_(loc) + { + L_->Bind(TNetworkAddress(loc.GetPort())); + E_.Create<TServer, &TServer::RunDispatcher>(this, "dispatcher"); + Thrs_.push_back(Spawn<TServer, &TServer::Run>(this)); + } + + ~TServer() override { + JQ_->Enqueue(nullptr); + + for (size_t i = 0; i < Thrs_.size(); ++i) { + Thrs_[i]->Join(); + } + } + + void Run() { + //SetHighestThreadPriority(); + L_->Listen(); + E_.Execute(); + } + + inline void OnRequest(const IRequestRef& req) { + CB_->OnRequest(req); + } + + TJobsQueueRef& JobQueue() noexcept { + return JQ_; + } + + void Enqueue(IJob* j) { + JQ_->EnqueueSafe(TAutoPtr<IJob>(j)); + } + + void RunDispatcher(TCont* c) { + for (;;) { + TAutoPtr<IJob> job(JQ_->Dequeue(c)); + + if (!job) { + break; + } + + try { + c->Executor()->Create(*job, "https-job"); + Y_UNUSED(job.Release()); + } catch (...) { + } + } + + JQ_->Enqueue(nullptr); + c->Executor()->Abort(); + } + + void OnAcceptFull(const TAcceptFull& a) override { + try { + TSocketRef s(new TSharedSocket(*a.S)); + + if (InputConnections()->ExceedHardLimit()) { + s->Close(); + return; + } + + THolder<TRead> read(new TRead(new TSslServerIOStream(SslCtx_, s), this, /* selfRemove */ true)); + E_.Create(*read, "https-response"); + Y_UNUSED(read.Release()); + E_.Running()->Yield(); + } catch (...) { + } + } + + void OnError() override { + try { + throw; + } catch (const TSystemError& e) { + //crutch for prevent 100% busyloop (simple suspend listener/accepter) + if (e.Status() == EMFILE) { + E_.Running()->SleepT(TDuration::MilliSeconds(500)); + } + } + } + + private: + IOnRequest* CB_; + TContExecutor E_; + THolder<TContListener> L_; + TVector<TThreadRef> Thrs_; + TJobsQueueRef JQ_; + TSslCtxServer SslCtx_; + }; + + template <class T> + class THttpsProtocol: public IProtocol { + public: + IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override { + return new TServer(cb, loc); + } + + THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override { + TSimpleHandleRef ret(new TSimpleHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss))); + try { + TAutoPtr<THttpsRequest<T>> req(new THttpsRequest<T>(ret, msg)); + JobQueue()->Schedule(req); + return ret.Get(); + } catch (...) 
{ + ret->ResetOnRecv(); + throw; + } + } + + TStringBuf Scheme() const noexcept override { + return T::Name(); + } + + bool SetOption(TStringBuf name, TStringBuf value) override { + return THttpsOptions::Set(name, value); + } + }; + + struct TRequestGet: public NHttp::TRequestGet { + static inline TStringBuf Name() noexcept { + return TStringBuf("https"); + } + }; + + struct TRequestFull: public NHttp::TRequestFull { + static inline TStringBuf Name() noexcept { + return TStringBuf("fulls"); + } + }; + + struct TRequestPost: public NHttp::TRequestPost { + static inline TStringBuf Name() noexcept { + return TStringBuf("posts"); + } + }; + + } +} + +namespace NNeh { + IProtocol* SSLGetProtocol() { + return Singleton<NHttps::THttpsProtocol<NNeh::NHttps::TRequestGet>>(); + } + + IProtocol* SSLPostProtocol() { + return Singleton<NHttps::THttpsProtocol<NNeh::NHttps::TRequestPost>>(); + } + + IProtocol* SSLFullProtocol() { + return Singleton<NHttps::THttpsProtocol<NNeh::NHttps::TRequestFull>>(); + } + + void SetHttpOutputConnectionsLimits(size_t softLimit, size_t hardLimit) { + Y_VERIFY( + hardLimit > softLimit, + "invalid output fd limits; hardLimit=%" PRISZT ", softLimit=%" PRISZT, + hardLimit, softLimit); + + NHttps::SocketCache()->SetFdLimits(softLimit, hardLimit); + } + + void SetHttpInputConnectionsLimits(size_t softLimit, size_t hardLimit) { + Y_VERIFY( + hardLimit > softLimit, + "invalid output fd limits; hardLimit=%" PRISZT ", softLimit=%" PRISZT, + hardLimit, softLimit); + + NHttps::InputConnections()->SetFdLimits(softLimit, hardLimit); + } + + void SetHttpInputConnectionsTimeouts(unsigned minSec, unsigned maxSec) { + Y_VERIFY( + maxSec > minSec, + "invalid input fd limits timeouts; maxSec=%u, minSec=%u", + maxSec, minSec); + + NHttps::InputConnections()->MinUnusedConnKeepaliveTimeout.store(minSec, std::memory_order_release); + NHttps::InputConnections()->MaxUnusedConnKeepaliveTimeout.store(maxSec, std::memory_order_release); + } +} diff --git a/library/cpp/neh/https.h b/library/cpp/neh/https.h new file mode 100644 index 0000000000..6dbb5370d7 --- /dev/null +++ b/library/cpp/neh/https.h @@ -0,0 +1,47 @@ +#pragma once + +#include <contrib/libs/openssl/include/openssl/ossl_typ.h> + +#include <util/generic/string.h> +#include <util/generic/strbuf.h> + +#include <functional> + +namespace NNeh { + class IProtocol; + struct TParsedLocation; + + IProtocol* SSLGetProtocol(); + IProtocol* SSLPostProtocol(); + IProtocol* SSLFullProtocol(); + + /// if exceed soft limit, reduce quantity unused connections in cache + void SetHttpOutputConnectionsLimits(size_t softLimit, size_t hardLimit); + + /// if exceed soft limit, reduce keepalive time for unused connections + void SetHttpInputConnectionsLimits(size_t softLimit, size_t hardLimit); + + /// unused input sockets keepalive timeouts + /// real(used) timeout: + /// - max, if not reached soft limit + /// - min, if reached hard limit + /// - approx. linear changed[max..min], while conn. 
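A rough usage sketch for the https knobs declared in https.h below: the THttpsOptions members are plain static fields, and the connection-limit setters shown above require the hard limit to exceed the soft one (Y_VERIFY). The certificate path and the concrete numbers here are placeholders, not recommendations.

    #include <library/cpp/neh/https.h>

    void ConfigureHttps() {
        // Static fields can be assigned directly; like the other globals they are
        // expected to be set up before the first https request is issued.
        NNeh::THttpsOptions::CAFile = "/etc/ssl/certs/ca-certificates.crt"; // placeholder path
        NNeh::THttpsOptions::CheckCertificateHostname = true;
        NNeh::THttpsOptions::RedirectionNotError = true; // 3xx responses are then not reported as errors

        // Hard limits must be strictly greater than soft limits (enforced by Y_VERIFY above).
        NNeh::SetHttpOutputConnectionsLimits(100, 200);
        NNeh::SetHttpInputConnectionsLimits(100, 200);
        NNeh::SetHttpInputConnectionsTimeouts(5, 60); // min/max keepalive for unused input connections, seconds
    }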
count in range [soft..hard] + void SetHttpInputConnectionsTimeouts(unsigned minSeconds, unsigned maxSeconds); + + struct THttpsOptions { + using TVerifyCallback = int (*)(int, X509_STORE_CTX*); + using TPasswordCallback = std::function<TString (const TParsedLocation&, const TString&, const TString&)>; + static TString CAFile; + static TString CAPath; + static TString ClientCertificate; + static TString ClientPrivateKey; + static TString ClientPrivateKeyPassword; + static bool CheckCertificateHostname; + static bool EnableSslServerDebug; + static bool EnableSslClientDebug; + static TVerifyCallback ClientVerifyCallback; + static TPasswordCallback KeyPasswdCallback; + static bool RedirectionNotError; + static bool Set(TStringBuf name, TStringBuf value); + }; +} diff --git a/library/cpp/neh/inproc.cpp b/library/cpp/neh/inproc.cpp new file mode 100644 index 0000000000..b124f38d17 --- /dev/null +++ b/library/cpp/neh/inproc.cpp @@ -0,0 +1,212 @@ +#include "inproc.h" + +#include "details.h" +#include "neh.h" +#include "location.h" +#include "utils.h" +#include "factory.h" + +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/generic/singleton.h> +#include <util/stream/output.h> +#include <util/string/cast.h> + +using namespace NNeh; + +namespace { + const TString canceled = "canceled"; + + struct TInprocHandle: public TNotifyHandle { + inline TInprocHandle(const TMessage& msg, IOnRecv* r, TStatCollector* sc) noexcept + : TNotifyHandle(r, msg, sc) + , Canceled_(false) + , NotifyCnt_(0) + { + } + + bool MessageSendedCompletely() const noexcept override { + return true; + } + + void Cancel() noexcept override { + THandle::Cancel(); //inform stat collector + Canceled_ = true; + try { + if (MarkReplied()) { + NotifyError(new TError(canceled, TError::Cancelled)); + } + } catch (...) { + Cdbg << "inproc canc. 
" << CurrentExceptionMessage() << Endl; + } + } + + inline void SendReply(const TString& resp) { + if (MarkReplied()) { + NotifyResponse(resp); + } + } + + inline void SendError(const TString& details) { + if (MarkReplied()) { + NotifyError(new TError{details, TError::ProtocolSpecific, 1}); + } + } + + void Disable() { + F_ = nullptr; + MarkReplied(); + } + + inline bool Canceled() const noexcept { + return Canceled_; + } + + private: + //return true when mark first reply + inline bool MarkReplied() { + return AtomicAdd(NotifyCnt_, 1) == 1; + } + + private: + TAtomicBool Canceled_; + TAtomic NotifyCnt_; + }; + + typedef TIntrusivePtr<TInprocHandle> TInprocHandleRef; + + class TInprocLocation: public TParsedLocation { + public: + TInprocLocation(const TStringBuf& addr) + : TParsedLocation(addr) + { + Service.Split('?', InprocService, InprocId); + } + + TStringBuf InprocService; + TStringBuf InprocId; + }; + + class TRequest: public IRequest { + public: + TRequest(const TInprocHandleRef& hndl) + : Location(hndl->Message().Addr) + , Handle_(hndl) + { + } + + TStringBuf Scheme() const override { + return TStringBuf("inproc"); + } + + TString RemoteHost() const override { + return TString(); + } + + TStringBuf Service() const override { + return Location.InprocService; + } + + TStringBuf Data() const override { + return Handle_->Message().Data; + } + + TStringBuf RequestId() const override { + return Location.InprocId; + } + + bool Canceled() const override { + return Handle_->Canceled(); + } + + void SendReply(TData& data) override { + Handle_->SendReply(TString(data.data(), data.size())); + } + + void SendError(TResponseError, const TString& details) override { + Handle_->SendError(details); + } + + const TMessage Request; + const TInprocLocation Location; + + private: + TInprocHandleRef Handle_; + }; + + class TInprocRequester: public IRequester { + public: + TInprocRequester(IOnRequest*& rqcb) + : RegisteredCallback_(rqcb) + { + } + + ~TInprocRequester() override { + RegisteredCallback_ = nullptr; + } + + private: + IOnRequest*& RegisteredCallback_; + }; + + class TInprocRequesterStg: public IProtocol { + public: + inline TInprocRequesterStg() { + V_.resize(1 + (size_t)Max<ui16>()); + } + + IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override { + IOnRequest*& rqcb = Find(loc); + + if (!rqcb) { + rqcb = cb; + } else if (rqcb != cb) { + ythrow yexception() << "shit happen - already registered"; + } + + return new TInprocRequester(rqcb); + } + + THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override { + TInprocHandleRef hndl(new TInprocHandle(msg, fallback, !ss ? nullptr : new TStatCollector(ss))); + try { + TAutoPtr<TRequest> req(new TRequest(hndl)); + + if (IOnRequest* cb = Find(req->Location)) { + cb->OnRequest(req.Release()); + } else { + throw yexception() << TStringBuf("not found inproc location"); + } + } catch (...) 
{ + hndl->Disable(); + throw; + } + + return THandleRef(hndl.Get()); + } + + TStringBuf Scheme() const noexcept override { + return TStringBuf("inproc"); + } + + private: + static inline ui16 Id(const TParsedLocation& loc) { + return loc.GetPort(); + } + + inline IOnRequest*& Find(const TParsedLocation& loc) { + return Find(Id(loc)); + } + + inline IOnRequest*& Find(ui16 id) { + return V_[id]; + } + + private: + TVector<IOnRequest*> V_; + }; +} + +IProtocol* NNeh::InProcProtocol() { + return Singleton<TInprocRequesterStg>(); +} diff --git a/library/cpp/neh/inproc.h b/library/cpp/neh/inproc.h new file mode 100644 index 0000000000..546ef915ab --- /dev/null +++ b/library/cpp/neh/inproc.h @@ -0,0 +1,7 @@ +#pragma once + +namespace NNeh { + class IProtocol; + + IProtocol* InProcProtocol(); +} diff --git a/library/cpp/neh/jobqueue.cpp b/library/cpp/neh/jobqueue.cpp new file mode 100644 index 0000000000..8d02fd9bbf --- /dev/null +++ b/library/cpp/neh/jobqueue.cpp @@ -0,0 +1,79 @@ +#include "utils.h" +#include "lfqueue.h" +#include "jobqueue.h" +#include "pipequeue.h" + +#include <util/thread/factory.h> +#include <util/generic/singleton.h> +#include <util/system/thread.h> + +using namespace NNeh; + +namespace { + class TExecThread: public IThreadFactory::IThreadAble, public IJob { + public: + TExecThread() + : T_(SystemThreadFactory()->Run(this)) + { + } + + ~TExecThread() override { + Enqueue(this); + T_->Join(); + } + + inline void Enqueue(IJob* job) { + Q_.Enqueue(job); + } + + private: + void DoRun(TCont* c) override { + c->Executor()->Abort(); + } + + void DoExecute() override { + SetHighestThreadPriority(); + + TContExecutor e(RealStackSize(20000)); + + e.Execute<TExecThread, &TExecThread::Dispatcher>(this); + } + + inline void Dispatcher(TCont* c) { + IJob* job; + + while ((job = Q_.Dequeue(c))) { + try { + c->Executor()->Create(*job, "job"); + } catch (...) { + (*job)(c); + } + } + } + + typedef TAutoPtr<IThreadFactory::IThread> IThreadRef; + TOneConsumerPipeQueue<IJob> Q_; + IThreadRef T_; + }; + + class TJobScatter: public IJobQueue { + public: + inline TJobScatter() { + for (size_t i = 0; i < 2; ++i) { + E_.push_back(new TExecThread()); + } + } + + void ScheduleImpl(IJob* job) override { + E_[TThread::CurrentThreadId() % E_.size()]->Enqueue(job); + } + + private: + typedef TAutoPtr<TExecThread> TExecThreadRef; + TVector<TExecThreadRef> E_; + }; +} + +IJobQueue* NNeh::JobQueue() { + return Singleton<TJobScatter>(); +} diff --git a/library/cpp/neh/jobqueue.h b/library/cpp/neh/jobqueue.h new file mode 100644 index 0000000000..ec86458cf0 --- /dev/null +++ b/library/cpp/neh/jobqueue.h @@ -0,0 +1,41 @@ +#pragma once + +#include <library/cpp/coroutine/engine/impl.h> + +#include <util/generic/yexception.h> +#include <util/stream/output.h> + +namespace NNeh { + class IJob { + public: + inline void operator()(TCont* c) noexcept { + try { + DoRun(c); + } catch (...) 
{ + Cdbg << "run " << CurrentExceptionMessage() << Endl; + } + } + + virtual ~IJob() { + } + + private: + virtual void DoRun(TCont* c) = 0; + }; + + class IJobQueue { + public: + template <class T> + inline void Schedule(T req) { + ScheduleImpl(req.Get()); + Y_UNUSED(req.Release()); + } + + virtual void ScheduleImpl(IJob* job) = 0; + + virtual ~IJobQueue() { + } + }; + + IJobQueue* JobQueue(); +} diff --git a/library/cpp/neh/lfqueue.h b/library/cpp/neh/lfqueue.h new file mode 100644 index 0000000000..c957047a99 --- /dev/null +++ b/library/cpp/neh/lfqueue.h @@ -0,0 +1,53 @@ +#pragma once + +#include <util/thread/lfqueue.h> +#include <util/generic/ptr.h> + +namespace NNeh { + template <class T> + class TAutoLockFreeQueue { + struct TCounter : TAtomicCounter { + inline void IncCount(const T* const&) { + Inc(); + } + + inline void DecCount(const T* const&) { + Dec(); + } + }; + + public: + typedef TAutoPtr<T> TRef; + + inline ~TAutoLockFreeQueue() { + TRef tmp; + + while (Dequeue(&tmp)) { + } + } + + inline bool Dequeue(TRef* t) { + T* res = nullptr; + + if (Q_.Dequeue(&res)) { + t->Reset(res); + + return true; + } + + return false; + } + + inline void Enqueue(TRef& t) { + Q_.Enqueue(t.Get()); + Y_UNUSED(t.Release()); + } + + inline size_t Size() { + return Q_.GetCounter().Val(); + } + + private: + TLockFreeQueue<T*, TCounter> Q_; + }; +} diff --git a/library/cpp/neh/location.cpp b/library/cpp/neh/location.cpp new file mode 100644 index 0000000000..36c0846673 --- /dev/null +++ b/library/cpp/neh/location.cpp @@ -0,0 +1,50 @@ +#include "location.h" + +#include <util/string/cast.h> + +using namespace NNeh; + +TParsedLocation::TParsedLocation(TStringBuf path) { + path.Split(':', Scheme, path); + path.Skip(2); + + const size_t pos = path.find_first_of(TStringBuf("?@")); + + if (TStringBuf::npos != pos && '@' == path[pos]) { + path.SplitAt(pos, UserInfo, path); + path.Skip(1); + } + + auto checkRange = [](size_t b, size_t e){ + return b != TStringBuf::npos && e != TStringBuf::npos && b < e; + }; + + size_t oBracket = path.find_first_of('['); + size_t cBracket = path.find_first_of(']'); + size_t endEndPointPos = path.find_first_of('/'); + if (checkRange(oBracket, cBracket)) { + endEndPointPos = path.find_first_of('/', cBracket); + } + EndPoint = path.SubStr(0, endEndPointPos); + Host = EndPoint; + + size_t lastColon = EndPoint.find_last_of(':'); + if (checkRange(cBracket, lastColon) + || (cBracket == TStringBuf::npos && lastColon != TStringBuf::npos)) + { + Host = EndPoint.SubStr(0, lastColon); + Port = EndPoint.SubStr(lastColon + 1, EndPoint.size() - lastColon + 1); + } + + if (endEndPointPos != TStringBuf::npos) { + Service = path.SubStr(endEndPointPos + 1, path.size() - endEndPointPos + 1); + } +} + +ui16 TParsedLocation::GetPort() const { + if (!Port) { + return TStringBuf("https") == Scheme || TStringBuf("fulls") == Scheme || TStringBuf("posts") == Scheme ? 
443 : 80; + } + + return FromString<ui16>(Port); +} diff --git a/library/cpp/neh/location.h b/library/cpp/neh/location.h new file mode 100644 index 0000000000..8a9154785f --- /dev/null +++ b/library/cpp/neh/location.h @@ -0,0 +1,18 @@ +#pragma once + +#include <util/generic/strbuf.h> + +namespace NNeh { + struct TParsedLocation { + TParsedLocation(TStringBuf path); + + ui16 GetPort() const; + + TStringBuf Scheme; + TStringBuf UserInfo; + TStringBuf EndPoint; + TStringBuf Host; + TStringBuf Port; + TStringBuf Service; + }; +} diff --git a/library/cpp/neh/multi.cpp b/library/cpp/neh/multi.cpp new file mode 100644 index 0000000000..4929d4efd9 --- /dev/null +++ b/library/cpp/neh/multi.cpp @@ -0,0 +1,35 @@ +#include "rpc.h" +#include "multi.h" +#include "location.h" +#include "factory.h" + +#include <util/string/cast.h> +#include <util/generic/hash.h> + +using namespace NNeh; + +namespace { + namespace NMulti { + class TRequester: public IRequester { + public: + inline TRequester(const TListenAddrs& addrs, IOnRequest* cb) { + for (const auto& addr : addrs) { + TParsedLocation loc(addr); + IRequesterRef& req = R_[ToString(loc.Scheme) + ToString(loc.GetPort())]; + + if (!req) { + req = ProtocolFactory()->Protocol(loc.Scheme)->CreateRequester(cb, loc); + } + } + } + + private: + typedef THashMap<TString, IRequesterRef> TRequesters; + TRequesters R_; + }; + } +} + +IRequesterRef NNeh::MultiRequester(const TListenAddrs& addrs, IOnRequest* cb) { + return new NMulti::TRequester(addrs, cb); +} diff --git a/library/cpp/neh/multi.h b/library/cpp/neh/multi.h new file mode 100644 index 0000000000..ab99dd78f9 --- /dev/null +++ b/library/cpp/neh/multi.h @@ -0,0 +1,12 @@ +#pragma once + +#include "rpc.h" + +#include <util/generic/vector.h> +#include <util/generic/string.h> + +namespace NNeh { + typedef TVector<TString> TListenAddrs; + + IRequesterRef MultiRequester(const TListenAddrs& addrs, IOnRequest* rq); +} diff --git a/library/cpp/neh/multiclient.cpp b/library/cpp/neh/multiclient.cpp new file mode 100644 index 0000000000..cb7672755e --- /dev/null +++ b/library/cpp/neh/multiclient.cpp @@ -0,0 +1,378 @@ +#include "multiclient.h" +#include "utils.h" + +#include <library/cpp/containers/intrusive_rb_tree/rb_tree.h> + +#include <atomic> + +namespace { + using namespace NNeh; + + struct TCompareDeadline { + template <class T> + static inline bool Compare(const T& l, const T& r) noexcept { + return l.Deadline() < r.Deadline() || (l.Deadline() == r.Deadline() && &l < &r); + } + }; + + class TMultiClient: public IMultiClient, public TThrRefBase { + class TRequestSupervisor: public TRbTreeItem<TRequestSupervisor, TCompareDeadline>, public IOnRecv, public TThrRefBase, public TNonCopyable { + private: + TRequestSupervisor() { + } //disable + + public: + inline TRequestSupervisor(const TRequest& request, TMultiClient* mc) noexcept + : MC_(mc) + , Request_(request) + , Maked_(0) + , FinishOnMakeRequest_(0) + , Handled_(0) + , Dequeued_(false) + { + } + + inline TInstant Deadline() const noexcept { + return Request_.Deadline; + } + + //not thread safe (can be called at some time from TMultiClient::Request() and TRequestSupervisor::OnNotify()) + void OnMakeRequest(THandleRef h) noexcept { + //request can be mark as maked only once, so only one/first call set handle + if (AtomicCas(&Maked_, 1, 0)) { + H_.Swap(h); + //[paranoid mode on] make sure handle be initiated before return + AtomicSet(FinishOnMakeRequest_, 1); + } else { + while (!AtomicGet(FinishOnMakeRequest_)) { + SpinLockPause(); + } + //[paranoid mode off] + } 
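A small illustration of what the location parser above produces; the URLs are made-up examples and the assertions only restate the parsing rules visible in location.cpp (scheme before the first ':', bracketed IPv6 hosts, port after the last ':', service after the first '/').

    #include <library/cpp/neh/location.h>

    #include <util/system/yassert.h>

    void ParsedLocationExample() {
        // "scheme://host:port/service"; port and service are optional.
        NNeh::TParsedLocation loc("https://[::1]:8443/search");

        Y_ASSERT(loc.Scheme == "https");
        Y_ASSERT(loc.Host == "[::1]");   // brackets stay part of the host
        Y_ASSERT(loc.GetPort() == 8443);
        Y_ASSERT(loc.Service == "search");

        // Without an explicit port the scheme default applies: 443 for https/posts/fulls, 80 otherwise.
        NNeh::TParsedLocation def("http://example.org/ping");
        Y_ASSERT(def.GetPort() == 80);
    }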
+ } + + void FillEvent(TEvent& ev) noexcept { + ev.Hndl = H_; + FillEventUserData(ev); + } + + void FillEventUserData(TEvent& ev) noexcept { + ev.UserData = Request_.UserData; + } + + void ResetRequest() noexcept { //destroy keepaliving cross-ref TRequestSupervisor<->THandle + H_.Drop(); + } + + //method OnProcessRequest() & OnProcessResponse() executed from Wait() context (thread) + void OnEndProcessRequest() { + Dequeued_ = true; + if (Y_UNLIKELY(IsHandled())) { + ResetRequest(); //race - response already handled before processing request from queue + } else { + MC_->RegisterRequest(this); + } + } + + void OnEndProcessResponse() { + if (Y_LIKELY(Dequeued_)) { + UnLink(); + ResetRequest(); + } //else request yet not dequeued/registered, so we not need unlink request + //(when we later dequeue request OnEndProcessRequest()...IsHandled() return true and we reset request) + } + + //IOnRecv interface + void OnNotify(THandle& h) override { + if (Y_LIKELY(MarkAsHandled())) { + THandleRef hr(&h); + OnMakeRequest(hr); //fix race with receiving response before return control from NNeh::Request() + MC_->ScheduleResponse(this, hr); + } + } + + void OnRecv(THandle&) noexcept override { + UnRef(); + } + + void OnEnd() noexcept override { + UnRef(); + } + // + + //request can be handled only once, so only one/first call MarkAsHandled() return true + bool MarkAsHandled() noexcept { + return AtomicCas(&Handled_, 1, 0); + } + + bool IsHandled() const noexcept { + return AtomicGet(Handled_); + } + + private: + TIntrusivePtr<TMultiClient> MC_; + TRequest Request_; + THandleRef H_; + TAtomic Maked_; + TAtomic FinishOnMakeRequest_; + TAtomic Handled_; + bool Dequeued_; + }; + + typedef TRbTree<TRequestSupervisor, TCompareDeadline> TRequestsSupervisors; + typedef TIntrusivePtr<TRequestSupervisor> TRequestSupervisorRef; + + public: + TMultiClient() + : Interrupt_(false) + , NearDeadline_(TInstant::Max().GetValue()) + , E_(::TSystemEvent::rAuto) + , Shutdown_(false) + { + } + + struct TResetRequest { + inline void operator()(TRequestSupervisor& rs) const noexcept { + rs.ResetRequest(); + } + }; + + void Shutdown() { + //reset THandleRef's for all exist supervisors and jobs queue (+prevent creating new) + //- so we break crossref-chain, which prevent destroy this object THande->TRequestSupervisor->TMultiClient) + Shutdown_ = true; + RS_.ForEachNoOrder(TResetRequest()); + RS_.Clear(); + CleanQueue(); + } + + private: + class IJob { + public: + virtual ~IJob() { + } + virtual bool Process(TEvent&) = 0; + virtual void Cancel() = 0; + }; + typedef TAutoPtr<IJob> TJobPtr; + + class TNewRequest: public IJob { + public: + TNewRequest(TRequestSupervisorRef& rs) + : RS_(rs) + { + } + + private: + bool Process(TEvent&) override { + RS_->OnEndProcessRequest(); + return false; + } + + void Cancel() override { + RS_->ResetRequest(); + } + + TRequestSupervisorRef RS_; + }; + + class TNewResponse: public IJob { + public: + TNewResponse(TRequestSupervisor* rs, THandleRef& h) noexcept + : RS_(rs) + , H_(h) + { + } + + private: + bool Process(TEvent& ev) override { + ev.Type = TEvent::Response; + ev.Hndl = H_; + RS_->FillEventUserData(ev); + RS_->OnEndProcessResponse(); + return true; + } + + void Cancel() override { + RS_->ResetRequest(); + } + + TRequestSupervisorRef RS_; + THandleRef H_; + }; + + public: + THandleRef Request(const TRequest& request) override { + TIntrusivePtr<TRequestSupervisor> rs(new TRequestSupervisor(request, this)); + THandleRef h; + try { + rs->Ref(); + h = NNeh::Request(request.Msg, rs.Get()); + 
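For orientation while reading the dispatcher internals here, a minimal client-side sketch of the IMultiClient interface (declared in multiclient.h further down). The address is a placeholder, the scheme must be one registered in the protocol factory, and error handling is trimmed.

    #include <library/cpp/neh/multiclient.h>
    #include <library/cpp/neh/neh.h>

    #include <util/stream/output.h>

    void MultiClientExample() {
        NNeh::TMultiClientPtr mc = NNeh::CreateMultiClient();

        // Schedule a request with a per-request deadline; UserData would be passed back in the event.
        NNeh::TMessage msg("http://localhost:8080/ping", "" /* data */); // placeholder address
        mc->Request(NNeh::IMultiClient::TRequest(msg, TDuration::Seconds(1).ToDeadLine()));

        // Wait() must be driven from a single thread; it reports either a response or a timeout.
        NNeh::IMultiClient::TEvent ev;
        if (mc->Wait(ev)) {
            if (ev.Type == NNeh::IMultiClient::TEvent::Timeout) {
                Cerr << "request deadline exceeded" << Endl;
            } else {
                NNeh::TResponseRef resp = ev.Hndl->Get();
                if (resp && !resp->IsError()) {
                    Cout << resp->Data << Endl;
                }
            }
        }
    }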
//accurately handle race when processing new request event + //(we already can receive response (call OnNotify) before we schedule info about new request here) + } catch (...) { + rs->UnRef(); + throw; + } + rs->OnMakeRequest(h); + ScheduleRequest(rs, h, request.Deadline); + return h; + } + + bool Wait(TEvent& ev, const TInstant deadline_ = TInstant::Max()) override { + while (!Interrupt_) { + TInstant deadline = deadline_; + const TInstant now = TInstant::Now(); + if (deadline != TInstant::Max() && now >= deadline) { + break; + } + + { //process jobs queue (requests/responses info) + TAutoPtr<IJob> j; + while (JQ_.Dequeue(&j)) { + if (j->Process(ev)) { + return true; + } + } + } + + if (!RS_.Empty()) { + TRequestSupervisor* nearRS = &*RS_.Begin(); + if (nearRS->Deadline() <= now) { + if (!nearRS->MarkAsHandled()) { + //race with notify, - now in queue must exist response job for this request + continue; + } + ev.Type = TEvent::Timeout; + nearRS->FillEvent(ev); + nearRS->ResetRequest(); + nearRS->UnLink(); + return true; + } + deadline = Min(nearRS->Deadline(), deadline); + } + + if (SetNearDeadline(deadline)) { + continue; //update deadline to more far time, so need re-check queue for avoiding race + } + + E_.WaitD(deadline); + } + Interrupt_ = false; + return false; + } + + void Interrupt() override { + Interrupt_ = true; + Signal(); + } + + size_t QueueSize() override { + return JQ_.Size(); + } + + private: + void Signal() { + //TODO:try optimize - hack with skipping signaling if not have waiters (reduce mutex usage) + E_.Signal(); + } + + void ScheduleRequest(TIntrusivePtr<TRequestSupervisor>& rs, const THandleRef& h, const TInstant& deadline) { + TJobPtr j(new TNewRequest(rs)); + JQ_.Enqueue(j); + if (!h->Signalled) { + if (deadline.GetValue() < GetNearDeadline_()) { + Signal(); + } + } + } + + void ScheduleResponse(TRequestSupervisor* rs, THandleRef& h) { + TJobPtr j(new TNewResponse(rs, h)); + JQ_.Enqueue(j); + if (Y_UNLIKELY(Shutdown_)) { + CleanQueue(); + } else { + Signal(); + } + } + + //return true, if deadline re-installed to more late time + bool SetNearDeadline(const TInstant& deadline) { + bool deadlineMovedFurther = deadline.GetValue() > GetNearDeadline_(); + SetNearDeadline_(deadline.GetValue()); + return deadlineMovedFurther; + } + + //used only from Wait() + void RegisterRequest(TRequestSupervisor* rs) { + if (rs->Deadline() != TInstant::Max()) { + RS_.Insert(rs); + } else { + rs->ResetRequest(); //prevent blocking destruction 'endless' requests + } + } + + void CleanQueue() { + TAutoPtr<IJob> j; + while (JQ_.Dequeue(&j)) { + j->Cancel(); + } + } + + private: + void SetNearDeadline_(const TInstant::TValue& v) noexcept { + TGuard<TAdaptiveLock> g(NDLock_); + NearDeadline_.store(v, std::memory_order_release); + } + + TInstant::TValue GetNearDeadline_() const noexcept { + TGuard<TAdaptiveLock> g(NDLock_); + return NearDeadline_.load(std::memory_order_acquire); + } + + NNeh::TAutoLockFreeQueue<IJob> JQ_; + TAtomicBool Interrupt_; + TRequestsSupervisors RS_; + TAdaptiveLock NDLock_; + std::atomic<TInstant::TValue> NearDeadline_; + ::TSystemEvent E_; + TAtomicBool Shutdown_; + }; + + class TMultiClientAutoShutdown: public IMultiClient { + public: + TMultiClientAutoShutdown() + : MC_(new TMultiClient()) + { + } + + ~TMultiClientAutoShutdown() override { + MC_->Shutdown(); + } + + size_t QueueSize() override { + return MC_->QueueSize(); + } + + private: + THandleRef Request(const TRequest& req) override { + return MC_->Request(req); + } + + bool Wait(TEvent& ev, TInstant 
deadline = TInstant::Max()) override { + return MC_->Wait(ev, deadline); + } + + void Interrupt() override { + return MC_->Interrupt(); + } + + private: + TIntrusivePtr<TMultiClient> MC_; + }; +} + +TMultiClientPtr NNeh::CreateMultiClient() { + return new TMultiClientAutoShutdown(); +} diff --git a/library/cpp/neh/multiclient.h b/library/cpp/neh/multiclient.h new file mode 100644 index 0000000000..e12b73dcd9 --- /dev/null +++ b/library/cpp/neh/multiclient.h @@ -0,0 +1,72 @@ +#pragma once + +#include "neh.h" + +namespace NNeh { + /// thread-safe dispacher for processing multiple neh requests + /// (method Wait() MUST be called from single thread, methods Request and Interrupt are thread-safe) + class IMultiClient { + public: + virtual ~IMultiClient() { + } + + struct TRequest { + TRequest() + : Deadline(TInstant::Max()) + , UserData(nullptr) + { + } + + TRequest(const TMessage& msg, TInstant deadline = TInstant::Max(), void* userData = nullptr) + : Msg(msg) + , Deadline(deadline) + , UserData(userData) + { + } + + TMessage Msg; + TInstant Deadline; + void* UserData; + }; + + /// WARNING: + /// Wait(event) called from another thread can return Event + /// for this request before this call return control + virtual THandleRef Request(const TRequest& req) = 0; + + virtual size_t QueueSize() = 0; + + struct TEvent { + enum TType { + Timeout, + Response, + SizeEventType + }; + + TEvent() + : Type(SizeEventType) + , UserData(nullptr) + { + } + + TEvent(TType t, void* userData) + : Type(t) + , UserData(userData) + { + } + + TType Type; + THandleRef Hndl; + void* UserData; + }; + + /// return false if interrupted + virtual bool Wait(TEvent&, TInstant = TInstant::Max()) = 0; + /// interrupt guaranteed breaking execution Wait(), but few interrupts can be handled as one + virtual void Interrupt() = 0; + }; + + typedef TAutoPtr<IMultiClient> TMultiClientPtr; + + TMultiClientPtr CreateMultiClient(); +} diff --git a/library/cpp/neh/neh.cpp b/library/cpp/neh/neh.cpp new file mode 100644 index 0000000000..2a3eef5023 --- /dev/null +++ b/library/cpp/neh/neh.cpp @@ -0,0 +1,146 @@ +#include "neh.h" + +#include "details.h" +#include "factory.h" + +#include <util/generic/list.h> +#include <util/generic/hash_set.h> +#include <util/digest/numeric.h> +#include <util/string/cast.h> + +using namespace NNeh; + +namespace { + class TMultiRequester: public IMultiRequester { + struct TOps { + template <class T> + inline bool operator()(const T& l, const T& r) const noexcept { + return l.Get() == r.Get(); + } + + template <class T> + inline size_t operator()(const T& t) const noexcept { + return NumericHash(t.Get()); + } + }; + + struct TOnComplete { + TMultiRequester* Parent; + bool Signalled; + + inline TOnComplete(TMultiRequester* parent) + : Parent(parent) + , Signalled(false) + { + } + + inline void operator()(TWaitHandle* wh) { + THandleRef req(static_cast<THandle*>(wh)); + + Signalled = true; + Parent->OnComplete(req); + } + }; + + public: + void Add(const THandleRef& req) override { + Reqs_.insert(req); + } + + void Del(const THandleRef& req) override { + Reqs_.erase(req); + } + + bool Wait(THandleRef& req, TInstant deadLine) override { + while (Complete_.empty()) { + if (Reqs_.empty()) { + return false; + } + + TOnComplete cb(this); + + WaitForMultipleObj(Reqs_.begin(), Reqs_.end(), deadLine, cb); + + if (!cb.Signalled) { + return false; + } + } + + req = *Complete_.begin(); + Complete_.pop_front(); + + return true; + } + + bool IsEmpty() const override { + return Reqs_.empty() && Complete_.empty(); + } + + inline 
void OnComplete(const THandleRef& req) { + Complete_.push_back(req); + Reqs_.erase(req); + } + + private: + typedef THashSet<THandleRef, TOps, TOps> TReqs; + typedef TList<THandleRef> TComplete; + TReqs Reqs_; + TComplete Complete_; + }; + + inline IProtocol* ProtocolForMessage(const TMessage& msg) { + return ProtocolFactory()->Protocol(TStringBuf(msg.Addr).Before(':')); + } +} + +NNeh::TMessage NNeh::TMessage::FromString(const TStringBuf req) { + TStringBuf addr; + TStringBuf data; + + req.Split('?', addr, data); + return TMessage(ToString(addr), ToString(data)); +} + +namespace { + const TString svcFail = "service status: failed"; +} + +THandleRef NNeh::Request(const TMessage& msg, IOnRecv* fallback, bool useAsyncSendRequest) { + TServiceStatRef ss; + + if (TServiceStat::Disabled()) { + return ProtocolForMessage(msg)->ScheduleAsyncRequest(msg, fallback, ss, useAsyncSendRequest); + } + + ss = GetServiceStat(msg.Addr); + TServiceStat::EStatus es = ss->GetStatus(); + + if (es == TServiceStat::Ok) { + return ProtocolForMessage(msg)->ScheduleAsyncRequest(msg, fallback, ss, useAsyncSendRequest); + } + + if (es == TServiceStat::ReTry) { + //send empty data request for validating service (update TServiceStat info) + TMessage validator; + + validator.Addr = msg.Addr; + + ProtocolForMessage(msg)->ScheduleAsyncRequest(validator, nullptr, ss, useAsyncSendRequest); + } + + TNotifyHandleRef h(new TNotifyHandle(fallback, msg)); + h->NotifyError(new TError(svcFail)); + return h.Get(); +} + +THandleRef NNeh::Request(const TString& req, IOnRecv* fallback) { + return Request(TMessage::FromString(req), fallback); +} + +IMultiRequesterRef NNeh::CreateRequester() { + return new TMultiRequester(); +} + +bool NNeh::SetProtocolOption(TStringBuf protoOption, TStringBuf value) { + return ProtocolFactory()->Protocol(protoOption.Before('/'))->SetOption(protoOption.After('/'), value); +} diff --git a/library/cpp/neh/neh.h b/library/cpp/neh/neh.h new file mode 100644 index 0000000000..e0211a7dff --- /dev/null +++ b/library/cpp/neh/neh.h @@ -0,0 +1,320 @@ +#pragma once + +#include "wfmo.h" +#include "stat.h" + +#include <library/cpp/http/io/headers.h> + +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/datetime/base.h> + +#include <utility> + +namespace NNeh { + struct TMessage { + TMessage() = default; + + inline TMessage(TString addr, TString data) + : Addr(std::move(addr)) + , Data(std::move(data)) + { + } + + static TMessage FromString(TStringBuf request); + + TString Addr; + TString Data; + }; + + using TMessageRef = TAutoPtr<TMessage>; + + struct TError { + public: + enum TType { + UnknownType, + Cancelled, + ProtocolSpecific + }; + + TError(TString text, TType type = UnknownType, i32 code = 0, i32 systemCode = 0) + : Type(std::move(type)) + , Code(code) + , Text(text) + , SystemCode(systemCode) + { + } + + TType Type = UnknownType; + i32 Code = 0; // protocol specific code (example(http): 404) + TString Text; + i32 SystemCode = 0; // system error code + }; + + using TErrorRef = TAutoPtr<TError>; + + struct TResponse; + using TResponseRef = TAutoPtr<TResponse>; + + struct TResponse { + inline TResponse(TMessage req, + TString data, + const TDuration duration) + : TResponse(std::move(req), std::move(data), duration, {} /* firstLine */, {} /* headers */, {} /* error */) + { + } + + inline TResponse(TMessage req, + TString data, + const TDuration duration, + TString firstLine, + THttpHeaders headers) + : TResponse(std::move(req), std::move(data), duration, std::move(firstLine), 
std::move(headers), {} /* error */) + { + } + + inline TResponse(TMessage req, + TString data, + const TDuration duration, + TString firstLine, + THttpHeaders headers, + TErrorRef error) + : Request(std::move(req)) + , Data(std::move(data)) + , Duration(duration) + , FirstLine(std::move(firstLine)) + , Headers(std::move(headers)) + , Error_(std::move(error)) + { + } + + inline static TResponseRef FromErrorText(TMessage msg, TString error, const TDuration duration) { + return new TResponse(std::move(msg), {} /* data */, duration, {} /* firstLine */, {} /* headers */, new TError(std::move(error))); + } + + inline static TResponseRef FromError(TMessage msg, TErrorRef error, const TDuration duration) { + return new TResponse(std::move(msg), {} /* data */, duration, {} /* firstLine */, {} /* headers */, error); + } + + inline static TResponseRef FromError(TMessage msg, TErrorRef error, const TDuration duration, + TString data, TString firstLine, THttpHeaders headers) + { + return new TResponse(std::move(msg), std::move(data), duration, std::move(firstLine), std::move(headers), error); + } + + inline static TResponseRef FromError( + TMessage msg, + TErrorRef error, + TString data, + const TDuration duration, + TString firstLine, + THttpHeaders headers) + { + return new TResponse(std::move(msg), std::move(data), duration, std::move(firstLine), std::move(headers), error); + } + + inline bool IsError() const { + return Error_.Get(); + } + + inline TError::TType GetErrorType() const { + return Error_.Get() ? Error_->Type : TError::UnknownType; + } + + inline i32 GetErrorCode() const { + return Error_.Get() ? Error_->Code : 0; + } + + inline i32 GetSystemErrorCode() const { + return Error_.Get() ? Error_->SystemCode : 0; + } + + inline TString GetErrorText() const { + return Error_.Get() ? Error_->Text : TString(); + } + + const TMessage Request; + const TString Data; + const TDuration Duration; + const TString FirstLine; + THttpHeaders Headers; + + private: + THolder<TError> Error_; + }; + + class THandle; + + class IOnRecv { + public: + virtual ~IOnRecv() = default; + + virtual void OnNotify(THandle&) { + } //callback on receive response + virtual void OnEnd() { + } //response was extracted by Wait() method, - OnRecv() will not be called + virtual void OnRecv(THandle& resp) = 0; //callback on destroy handler + }; + + class THandle: public TThrRefBase, public TWaitHandle { + public: + inline THandle(IOnRecv* f, TStatCollector* s = nullptr) noexcept + : F_(f) + , Stat_(s) + { + } + + ~THandle() override { + if (F_) { + try { + F_->OnRecv(*this); + } catch (...) 
{ + } + } + } + + virtual bool MessageSendedCompletely() const noexcept { + //TODO + return true; + } + + virtual void Cancel() noexcept { + //TODO + if (!!Stat_) + Stat_->OnCancel(); + } + + inline const TResponse* Response() const noexcept { + return R_.Get(); + } + + //method MUST be called only after success Wait() for this handle or from callback IOnRecv::OnRecv() + //else exist chance for memory leak (race between Get()/Notify()) + inline TResponseRef Get() noexcept { + return R_; + } + + inline bool Wait(TResponseRef& msg, const TInstant deadLine) { + if (WaitForOne(*this, deadLine)) { + if (F_) { + F_->OnEnd(); + F_ = nullptr; + } + msg = Get(); + + return true; + } + + return false; + } + + inline bool Wait(TResponseRef& msg, const TDuration timeOut) { + return Wait(msg, timeOut.ToDeadLine()); + } + + inline bool Wait(TResponseRef& msg) { + return Wait(msg, TInstant::Max()); + } + + inline TResponseRef Wait(const TInstant deadLine) { + TResponseRef ret; + + Wait(ret, deadLine); + + return ret; + } + + inline TResponseRef Wait(const TDuration timeOut) { + return Wait(timeOut.ToDeadLine()); + } + + inline TResponseRef Wait() { + return Wait(TInstant::Max()); + } + + protected: + inline void Notify(TResponseRef resp) { + if (!!Stat_) { + if (!resp || resp->IsError()) { + Stat_->OnFail(); + } else { + Stat_->OnSuccess(); + } + } + R_.Swap(resp); + if (F_) { + try { + F_->OnNotify(*this); + } catch (...) { + } + } + Signal(); + } + + IOnRecv* F_; + + private: + TResponseRef R_; + THolder<TStatCollector> Stat_; + }; + + using THandleRef = TIntrusivePtr<THandle>; + + THandleRef Request(const TMessage& msg, IOnRecv* fallback, bool useAsyncSendRequest = false); + + inline THandleRef Request(const TMessage& msg) { + return Request(msg, nullptr); + } + + THandleRef Request(const TString& req, IOnRecv* fallback); + + inline THandleRef Request(const TString& req) { + return Request(req, nullptr); + } + + class IMultiRequester { + public: + virtual ~IMultiRequester() = default; + + virtual void Add(const THandleRef& req) = 0; + virtual void Del(const THandleRef& req) = 0; + virtual bool Wait(THandleRef& req, TInstant deadLine) = 0; + virtual bool IsEmpty() const = 0; + + inline void Schedule(const TString& req) { + Add(Request(req)); + } + + inline bool Wait(THandleRef& req, TDuration timeOut) { + return Wait(req, timeOut.ToDeadLine()); + } + + inline bool Wait(THandleRef& req) { + return Wait(req, TInstant::Max()); + } + + inline bool Wait(TResponseRef& resp, TInstant deadLine) { + THandleRef req; + + while (Wait(req, deadLine)) { + resp = req->Get(); + + if (!!resp) { + return true; + } + } + + return false; + } + + inline bool Wait(TResponseRef& resp) { + return Wait(resp, TInstant::Max()); + } + }; + + using IMultiRequesterRef = TAutoPtr<IMultiRequester>; + + IMultiRequesterRef CreateRequester(); + + bool SetProtocolOption(TStringBuf protoOption, TStringBuf value); +} diff --git a/library/cpp/neh/netliba.cpp b/library/cpp/neh/netliba.cpp new file mode 100644 index 0000000000..f69906f3ba --- /dev/null +++ b/library/cpp/neh/netliba.cpp @@ -0,0 +1,508 @@ +#include "details.h" +#include "factory.h" +#include "http_common.h" +#include "location.h" +#include "multi.h" +#include "netliba.h" +#include "netliba_udp_http.h" +#include "lfqueue.h" +#include "utils.h" + +#include <library/cpp/dns/cache.h> + +#include <util/generic/hash.h> +#include <util/generic/singleton.h> +#include <util/generic/vector.h> +#include <util/generic/yexception.h> +#include <util/string/cast.h> +#include 
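A minimal synchronous sketch of the Request()/Wait() pair declared in neh.h above; the address is a placeholder and assumes the plain http protocol is registered in the factory alongside the schemes shown in this diff.

    #include <library/cpp/neh/neh.h>

    #include <util/stream/output.h>

    void SimpleRequestExample() {
        // "addr?data": everything after the first '?' becomes the message data (TMessage::FromString).
        NNeh::THandleRef h = NNeh::Request("http://localhost:8080/ping?hello"); // placeholder address

        NNeh::TResponseRef resp;
        if (h->Wait(resp, TDuration::Seconds(1)) && resp && !resp->IsError()) {
            Cout << resp->FirstLine << Endl;
            Cout << resp->Data << Endl;
        } else {
            Cerr << (resp ? resp->GetErrorText() : TString("no response within deadline")) << Endl;
        }
    }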
<util/system/yassert.h> + +#include <atomic> + +using namespace NDns; +using namespace NNeh; +namespace NNeh { + size_t TNetLibaOptions::ClientThreads = 4; + TDuration TNetLibaOptions::AckTailEffect = TDuration::Seconds(30); + + bool TNetLibaOptions::Set(TStringBuf name, TStringBuf value) { +#define NETLIBA_TRY_SET(optType, optName) \ + if (name == TStringBuf(#optName)) { \ + optName = FromString<optType>(value); \ + } + + NETLIBA_TRY_SET(size_t, ClientThreads) + else NETLIBA_TRY_SET(TDuration, AckTailEffect) else { + return false; + } + return true; + } +} + +namespace { + namespace NNetLiba { + using namespace NNetliba; + using namespace NNehNetliba; + + typedef NNehNetliba::IRequester INetLibaRequester; + typedef TAutoPtr<TUdpHttpRequest> TUdpHttpRequestPtr; + typedef TAutoPtr<TUdpHttpResponse> TUdpHttpResponsePtr; + + static inline const addrinfo* FindIPBase(const TNetworkAddress* addr, int family) { + for (TNetworkAddress::TIterator it = addr->Begin(); it != addr->End(); ++it) { + if (it->ai_family == family) { + return &*it; + } + } + + return nullptr; + } + + static inline const sockaddr_in6& FindIP(const TNetworkAddress* addr) { + //prefer ipv6 + const addrinfo* ret = FindIPBase(addr, AF_INET6); + + if (!ret) { + ret = FindIPBase(addr, AF_INET); + } + + if (!ret) { + ythrow yexception() << "ip not supported by " << *addr; + } + + return *(const sockaddr_in6*)(ret->ai_addr); + } + + class TLastAckTimes { + struct TTimeVal { + TTimeVal() + : Val(0) + { + } + + std::atomic<TInstant::TValue> Val; + }; + + public: + TInstant::TValue Get(size_t idAddr) { + return Tm_.Get(idAddr).Val.load(std::memory_order_acquire); + } + + void Set(size_t idAddr) { + Tm_.Get(idAddr).Val.store(TInstant::Now().GetValue(), std::memory_order_release); + } + + static TLastAckTimes& Common() { + return *Singleton<TLastAckTimes>(); + } + + private: + NNeh::NHttp::TLockFreeSequence<TTimeVal> Tm_; + }; + + class TRequest: public TSimpleHandle { + public: + inline TRequest(TIntrusivePtr<INetLibaRequester>& r, size_t idAddr, const TMessage& msg, IOnRecv* cb, TStatCollector* s) + : TSimpleHandle(cb, msg, s) + , R_(r) + , IdAddr_(idAddr) + , Notified_(false) + { + CreateGuid(&Guid_); + } + + void Cancel() noexcept override { + TSimpleHandle::Cancel(); + R_->CancelRequest(Guid_); + } + + inline const TString& Addr() const noexcept { + return Message().Addr; + } + + inline const TGUID& Guid() const noexcept { + return Guid_; + } + + //return false if already notifie + inline bool SetNotified() noexcept { + bool ret = Notified_; + Notified_ = true; + return !ret; + } + + void OnSend() { + if (TNetLibaOptions::AckTailEffect.GetValue() && TLastAckTimes::Common().Get(IdAddr_) + TNetLibaOptions::AckTailEffect.GetValue() > TInstant::Now().GetValue()) { + //fake(predicted) completing detection + SetSendComplete(); + } + } + + void OnRequestAck() { + if (TNetLibaOptions::AckTailEffect.GetValue()) { + TLastAckTimes::Common().Set(IdAddr_); + } + SetSendComplete(); + } + + private: + TIntrusivePtr<INetLibaRequester> R_; + size_t IdAddr_; + TGUID Guid_; + bool Notified_; + }; + + typedef TIntrusivePtr<TRequest> TRequestRef; + + class TNetLibaBus { + class TEventsHandler: public IEventsCollector { + typedef THashMap<TGUID, TRequestRef, TGUIDHash> TInFly; + + public: + inline void OnSend(TRequestRef& req) { + Q_.Enqueue(req); + req->OnSend(); + } + + private: + void UpdateInFly() { + TRequestRef req; + + while (Q_.Dequeue(&req)) { + if (!req) { + return; + } + + InFly_[req->Guid()] = req; + } + } + + void AddRequest(TUdpHttpRequest* 
req) override { + //ignore received requests in client + delete req; + } + + void AddResponse(TUdpHttpResponse* resp) override { + TUdpHttpResponsePtr ptr(resp); + + UpdateInFly(); + TInFly::iterator it = InFly_.find(resp->ReqId); + + Y_VERIFY(it != InFly_.end(), "incorrect incoming message"); + + TRequestRef& req = it->second; + + if (req->SetNotified()) { + if (resp->Ok == TUdpHttpResponse::OK) { + req->NotifyResponse(TString(resp->Data.data(), resp->Data.size())); + } else { + if (resp->Ok == TUdpHttpResponse::CANCELED) { + req->NotifyError(new TError(resp->Error, TError::Cancelled)); + } else { + req->NotifyError(new TError(resp->Error)); + } + } + } + + InFly_.erase(it); + } + + void AddCancel(const TGUID& guid) override { + UpdateInFly(); + TInFly::iterator it = InFly_.find(guid); + + if (it != InFly_.end() && it->second->SetNotified()) { + it->second->NotifyError("Canceled (before ack)"); + } + } + + void AddRequestAck(const TGUID& guid) override { + UpdateInFly(); + TInFly::iterator it = InFly_.find(guid); + + Y_VERIFY(it != InFly_.end(), "incorrect complete notification"); + + it->second->OnRequestAck(); + } + + private: + TLockFreeQueue<TRequestRef> Q_; + TInFly InFly_; + }; + + struct TClientThread { + TClientThread(int physicalCpu) + : EH_(new TEventsHandler()) + , R_(CreateHttpUdpRequester(0, IEventsCollectorRef(EH_.Get()), physicalCpu)) + { + R_->EnableReportRequestAck(); + } + + ~TClientThread() { + R_->StopNoWait(); + } + + TIntrusivePtr<TEventsHandler> EH_; + TIntrusivePtr<INetLibaRequester> R_; + }; + + public: + TNetLibaBus() { + for (size_t i = 0; i < TNetLibaOptions::ClientThreads; ++i) { + Clnt_.push_back(new TClientThread(i)); + } + } + + inline THandleRef Schedule(const TMessage& msg, IOnRecv* cb, TServiceStatRef& ss) { + TParsedLocation loc(msg.Addr); + TUdpAddress addr; + + const TResolvedHost* resHost = CachedResolve(TResolveInfo(loc.Host, loc.GetPort())); + GetUdpAddress(&addr, FindIP(&resHost->Addr)); + + TClientThread& clnt = *Clnt_[resHost->Id % Clnt_.size()]; + TIntrusivePtr<INetLibaRequester> rr = clnt.R_; + TRequestRef req(new TRequest(rr, resHost->Id, msg, cb, !ss ? 
nullptr : new TStatCollector(ss))); + + clnt.EH_->OnSend(req); + rr->SendRequest(addr, ToString(loc.Service), msg.Data, req->Guid()); + + return THandleRef(req.Get()); + } + + private: + TVector<TAutoPtr<TClientThread>> Clnt_; + }; + + //server + class TRequester: public TThrRefBase { + struct TSrvRequestState: public TAtomicRefCount<TSrvRequestState> { + TSrvRequestState() + : Canceled(false) + { + } + + TAtomicBool Canceled; + }; + + class TRequest: public IRequest { + public: + inline TRequest(TAutoPtr<TUdpHttpRequest> req, TIntrusivePtr<TSrvRequestState> state, TRequester* parent) + : R_(req) + , S_(state) + , P_(parent) + { + } + + ~TRequest() override { + if (!!P_) { + P_->RequestProcessed(this); + } + } + + TStringBuf Scheme() const override { + return TStringBuf("netliba"); + } + + TString RemoteHost() const override { + if (!H_) { + TUdpAddress tmp(R_->PeerAddress); + tmp.Scope = 0; //discard scope from serialized addr + + TString addr = GetAddressAsString(tmp); + + TStringBuf host, port; + + TStringBuf(addr).RSplit(':', host, port); + H_ = host; + } + return H_; + } + + TStringBuf Service() const override { + return TStringBuf(R_->Url.c_str(), R_->Url.length()); + } + + TStringBuf Data() const override { + return TStringBuf((const char*)R_->Data.data(), R_->Data.size()); + } + + TStringBuf RequestId() const override { + const TGUID& g = R_->ReqId; + + return TStringBuf((const char*)g.dw, sizeof(g.dw)); + } + + bool Canceled() const override { + return S_->Canceled; + } + + void SendReply(TData& data) override { + TIntrusivePtr<TRequester> p; + p.Swap(P_); + if (!!p) { + if (!Canceled()) { + p->R_->SendResponse(R_->ReqId, &data); + } + p->RequestProcessed(this); + } + } + + void SendError(TResponseError, const TString&) override { + // TODO + } + + inline const TGUID& RequestGuid() const noexcept { + return R_->ReqId; + } + + private: + TAutoPtr<TUdpHttpRequest> R_; + mutable TString H_; + TIntrusivePtr<TSrvRequestState> S_; + TIntrusivePtr<TRequester> P_; + }; + + class TEventsHandler: public IEventsCollector { + public: + TEventsHandler(TRequester* parent) + { + P_.store(parent, std::memory_order_release); + } + + void RequestProcessed(const TRequest* r) { + FinishedReqs_.Enqueue(r->RequestGuid()); + } + + //thread safe method for disable proxy callbacks to parent (OnRequest(...)) + void SyncStop() { + P_.store(nullptr, std::memory_order_release); + while (!RequesterPtrPotector_.TryAcquire()) { + Sleep(TDuration::MicroSeconds(100)); + } + RequesterPtrPotector_.Release(); + } + + private: + typedef THashMap<TGUID, TIntrusivePtr<TSrvRequestState>, TGUIDHash> TStatesInProcessRequests; + + void AddRequest(TUdpHttpRequest* req) override { + TUdpHttpRequestPtr ptr(req); + + TSrvRequestState* state = new TSrvRequestState(); + + InProcess_[req->ReqId] = state; + try { + TGuard<TSpinLock> m(RequesterPtrPotector_); + if (TRequester* p = P_.load(std::memory_order_acquire)) { + p->OnRequest(ptr, state); //move req. owning to parent + } + } catch (...) 
{ + Cdbg << "ignore exc.: " << CurrentExceptionMessage() << Endl; + } + } + + void AddResponse(TUdpHttpResponse*) override { + Y_FAIL("unexpected response in neh netliba server"); + } + + void AddCancel(const TGUID& guid) override { + UpdateInProcess(); + TStatesInProcessRequests::iterator ustate = InProcess_.find(guid); + if (ustate != InProcess_.end()) + ustate->second->Canceled = true; + } + + void AddRequestAck(const TGUID&) override { + Y_FAIL("unexpected acc in neh netliba server"); + } + + void UpdateInProcess() { + TGUID guid; + + while (FinishedReqs_.Dequeue(&guid)) { + InProcess_.erase(guid); + } + } + + private: + TLockFreeStack<TGUID> FinishedReqs_; //processed requests (responded or destroyed) + TStatesInProcessRequests InProcess_; + TSpinLock RequesterPtrPotector_; + std::atomic<TRequester*> P_; + }; + + public: + inline TRequester(IOnRequest* cb, ui16 port) + : CB_(cb) + , EH_(new TEventsHandler(this)) + , R_(CreateHttpUdpRequester(port, EH_.Get())) + { + R_->EnableReportRequestCancel(); + } + + ~TRequester() override { + Shutdown(); + } + + void Shutdown() noexcept { + if (!Shutdown_) { + Shutdown_ = true; + R_->StopNoWait(); + EH_->SyncStop(); + } + } + + void OnRequest(TUdpHttpRequestPtr req, TSrvRequestState* state) { + CB_->OnRequest(new TRequest(req, state, this)); + } + + void RequestProcessed(const TRequest* r) { + EH_->RequestProcessed(r); + } + + private: + IOnRequest* CB_; + TIntrusivePtr<TEventsHandler> EH_; + TIntrusivePtr<INetLibaRequester> R_; + bool Shutdown_ = false; + }; + + typedef TIntrusivePtr<TRequester> TRequesterRef; + + class TRequesterAutoShutdown: public NNeh::IRequester { + public: + TRequesterAutoShutdown(const TRequesterRef& r) + : R_(r) + { + } + + ~TRequesterAutoShutdown() override { + R_->Shutdown(); + } + + private: + TRequesterRef R_; + }; + + class TProtocol: public IProtocol { + public: + THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override { + return Singleton<TNetLibaBus>()->Schedule(msg, fallback, ss); + } + + NNeh::IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override { + TRequesterRef r(new TRequester(cb, loc.GetPort())); + return new TRequesterAutoShutdown(r); + } + + TStringBuf Scheme() const noexcept override { + return TStringBuf("netliba"); + } + }; + } +} + +IProtocol* NNeh::NetLibaProtocol() { + return Singleton<NNetLiba::TProtocol>(); +} diff --git a/library/cpp/neh/netliba.h b/library/cpp/neh/netliba.h new file mode 100644 index 0000000000..9635d29753 --- /dev/null +++ b/library/cpp/neh/netliba.h @@ -0,0 +1,20 @@ +#pragma once + +#include <util/datetime/base.h> + +namespace NNeh { + //global options + struct TNetLibaOptions { + static size_t ClientThreads; + + //period for quick send complete confirmation for next request to some address after receiving request ack + static TDuration AckTailEffect; + + //set option, - return false, if option name not recognized + static bool Set(TStringBuf name, TStringBuf value); + }; + + class IProtocol; + + IProtocol* NetLibaProtocol(); +} diff --git a/library/cpp/neh/netliba_udp_http.cpp b/library/cpp/neh/netliba_udp_http.cpp new file mode 100644 index 0000000000..a4df426f02 --- /dev/null +++ b/library/cpp/neh/netliba_udp_http.cpp @@ -0,0 +1,808 @@ +#include "netliba_udp_http.h" +#include "utils.h" + +#include <library/cpp/netliba/v6/cpu_affinity.h> +#include <library/cpp/netliba/v6/stdafx.h> +#include <library/cpp/netliba/v6/udp_client_server.h> +#include <library/cpp/netliba/v6/udp_socket.h> + +#include 
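A short sketch of configuring the netliba transport options declared in netliba.h above. The keys are the field names accepted by the NETLIBA_TRY_SET macro earlier in this file; the values here are arbitrary examples, and the duration string relies on FromString<TDuration> accepting suffixed literals.

    #include <library/cpp/neh/netliba.h>

    void ConfigureNetLiba() {
        // ClientThreads is consumed when the client bus singleton is created, i.e. on the
        // first "netliba://host:port/service" request, so set the options before that point.
        NNeh::TNetLibaOptions::Set("ClientThreads", "8");
        NNeh::TNetLibaOptions::Set("AckTailEffect", "10s"); // parsed with FromString<TDuration>
    }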
<library/cpp/netliba/v6/block_chain.h> // depend on another headers + +#include <util/system/hp_timer.h> +#include <util/system/shmat.h> +#include <util/system/spinlock.h> +#include <util/system/thread.h> +#include <util/system/types.h> +#include <util/system/yassert.h> +#include <util/thread/lfqueue.h> + +#include <atomic> + +#if !defined(_win_) +#include <signal.h> +#include <pthread.h> +#endif + +using namespace NNetliba; + +namespace { + const float HTTP_TIMEOUT = 15.0f; + const size_t MIN_SHARED_MEM_PACKET = 1000; + const size_t MAX_PACKET_SIZE = 0x70000000; + + NNeh::TAtomicBool PanicAttack; + std::atomic<NHPTimer::STime> LastHeartbeat; + std::atomic<double> HeartbeatTimeout; + + bool IsLocal(const TUdpAddress& addr) { + return addr.IsIPv4() ? IsLocalIPv4(addr.GetIPv4()) : IsLocalIPv6(addr.Network, addr.Interface); + } + + void StopAllNetLibaThreads() { + PanicAttack = true; // AAAA!!!! + } + + void ReadShm(TSharedMemory* shm, TVector<char>* data) { + Y_ASSERT(shm); + int dataSize = shm->GetSize(); + data->yresize(dataSize); + memcpy(&(*data)[0], shm->GetPtr(), dataSize); + } + + void ReadShm(TSharedMemory* shm, TString* data) { + Y_ASSERT(shm); + size_t dataSize = shm->GetSize(); + data->ReserveAndResize(dataSize); + memcpy(data->begin(), shm->GetPtr(), dataSize); + } + + template <class T> + void EraseList(TLockFreeQueue<T*>* data) { + T* ptr = 0; + while (data->Dequeue(&ptr)) { + delete ptr; + } + } + + enum EHttpPacket { + PKT_REQUEST, + PKT_PING, + PKT_PING_RESPONSE, + PKT_RESPONSE, + PKT_LOCAL_REQUEST, + PKT_LOCAL_RESPONSE, + PKT_CANCEL, + }; +} + +namespace NNehNetliba { + TUdpHttpMessage::TUdpHttpMessage(const TGUID& reqId, const TUdpAddress& peerAddr) + : ReqId(reqId) + , PeerAddress(peerAddr) + { + } + + TUdpHttpRequest::TUdpHttpRequest(TAutoPtr<TRequest>& dataHolder, const TGUID& reqId, const TUdpAddress& peerAddr) + : TUdpHttpMessage(reqId, peerAddr) + { + TBlockChainIterator reqData(dataHolder->Data->GetChain()); + char pktType; + reqData.Read(&pktType, 1); + ReadArr(&reqData, &Url); + if (pktType == PKT_REQUEST) { + ReadYArr(&reqData, &Data); + } else if (pktType == PKT_LOCAL_REQUEST) { + ReadShm(dataHolder->Data->GetSharedData(), &Data); + } else { + Y_ASSERT(0); + } + + if (reqData.HasFailed()) { + Y_ASSERT(0 && "wrong format, memory corruption suspected"); + Url = ""; + Data.clear(); + } + } + + TUdpHttpResponse::TUdpHttpResponse(TAutoPtr<TRequest>& dataHolder, const TGUID& reqId, const TUdpAddress& peerAddr, EResult result, const char* error) + : TUdpHttpMessage(reqId, peerAddr) + , Ok(result) + { + if (result == TUdpHttpResponse::FAILED) { + Error = error ? error : "request failed"; + } else if (result == TUdpHttpResponse::CANCELED) { + Error = error ? 
error : "request cancelled"; + } else { + TBlockChainIterator reqData(dataHolder->Data->GetChain()); + if (Y_UNLIKELY(reqData.HasFailed())) { + Y_ASSERT(0 && "wrong format, memory corruption suspected"); + Ok = TUdpHttpResponse::FAILED; + Data.clear(); + Error = "wrong response format"; + } else { + char pktType; + reqData.Read(&pktType, 1); + TGUID guid; + reqData.Read(&guid, sizeof(guid)); + Y_ASSERT(ReqId == guid); + if (pktType == PKT_RESPONSE) { + ReadArr<TString>(&reqData, &Data); + } else if (pktType == PKT_LOCAL_RESPONSE) { + ReadShm(dataHolder->Data->GetSharedData(), &Data); + } else { + Y_ASSERT(0); + } + } + } + } + + class TUdpHttp: public IRequester { + enum EDir { + DIR_OUT, + DIR_IN + }; + + struct TInRequestState { + enum EState { + S_WAITING, + S_RESPONSE_SENDING, + S_CANCELED, + }; + + TInRequestState() + : State(S_WAITING) + { + } + + TInRequestState(const TUdpAddress& address) + : State(S_WAITING) + , Address(address) + { + } + + EState State; + TUdpAddress Address; + }; + + struct TOutRequestState { + enum EState { + S_SENDING, + S_WAITING, + S_WAITING_PING_SENDING, + S_WAITING_PING_SENT, + S_CANCEL_AFTER_SENDING + }; + + TOutRequestState() + : State(S_SENDING) + , TimePassed(0) + , PingTransferId(-1) + { + } + + EState State; + TUdpAddress Address; + double TimePassed; + int PingTransferId; + IEventsCollectorRef EventsCollector; + }; + + struct TTransferPurpose { + EDir Dir; + TGUID Guid; + + TTransferPurpose() + : Dir(DIR_OUT) + { + } + + TTransferPurpose(EDir dir, TGUID guid) + : Dir(dir) + , Guid(guid) + { + } + }; + + struct TSendRequest { + TSendRequest() = default; + + TSendRequest(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket>* data, const TGUID& reqGuid, const IEventsCollectorRef& eventsCollector) + : Addr(addr) + , Data(*data) + , ReqGuid(reqGuid) + , EventsCollector(eventsCollector) + , Crc32(CalcChecksum(Data->GetChain())) + { + } + + TUdpAddress Addr; + TAutoPtr<TRopeDataPacket> Data; + TGUID ReqGuid; + IEventsCollectorRef EventsCollector; + ui32 Crc32; + }; + + struct TSendResponse { + TSendResponse() = default; + + TSendResponse(const TGUID& reqGuid, EPacketPriority prior, TVector<char>* data) + : ReqGuid(reqGuid) + , DataCrc32(0) + , Priority(prior) + { + if (data && !data->empty()) { + data->swap(Data); + DataCrc32 = TIncrementalChecksumCalcer::CalcBlockSum(&Data[0], Data.ysize()); + } + } + + TVector<char> Data; + TGUID ReqGuid; + ui32 DataCrc32; + EPacketPriority Priority; + }; + + typedef THashMap<TGUID, TOutRequestState, TGUIDHash> TOutRequestHash; + typedef THashMap<TGUID, TInRequestState, TGUIDHash> TInRequestHash; + + public: + TUdpHttp(const IEventsCollectorRef& eventsCollector) + : MyThread_(ExecServerThread, (void*)this) + , AbortTransactions_(false) + , Port_(0) + , EventCollector_(eventsCollector) + , ReportRequestCancel_(false) + , ReporRequestAck_(false) + , PhysicalCpu_(-1) + { + } + + ~TUdpHttp() override { + if (MyThread_.Running()) { + AtomicSet(KeepRunning_, 0); + MyThread_.Join(); + } + } + + bool Start(int port, int physicalCpu) { + Y_ASSERT(Host_.Get() == nullptr); + Port_ = port; + PhysicalCpu_ = physicalCpu; + MyThread_.Start(); + HasStarted_.Wait(); + return Host_.Get() != nullptr; + } + + void EnableReportRequestCancel() override { + ReportRequestCancel_ = true; + } + + void EnableReportRequestAck() override { + ReporRequestAck_ = true; + } + + void SendRequest(const TUdpAddress& addr, const TString& url, const TString& data, const TGUID& reqId) override { + Y_VERIFY( + data.size() < MAX_PACKET_SIZE, + "data size is too 
large; data.size()=%" PRISZT ", MAX_PACKET_SIZE=%" PRISZT, + data.size(), MAX_PACKET_SIZE); + + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + if (data.size() > MIN_SHARED_MEM_PACKET && IsLocal(addr)) { + TIntrusivePtr<TSharedMemory> shm = new TSharedMemory; + if (shm->Create(data.size())) { + ms->Write((char)PKT_LOCAL_REQUEST); + ms->WriteStroka(url); + memcpy(shm->GetPtr(), data.begin(), data.size()); + ms->AttachSharedData(shm); + } + } + if (ms->GetSharedData() == nullptr) { + ms->Write((char)PKT_REQUEST); + ms->WriteStroka(url); + struct TStrokaStorage: public TThrRefBase, public TString { + TStrokaStorage(const TString& s) + : TString(s) + { + } + }; + TStrokaStorage* ss = new TStrokaStorage(data); + ms->Write((int)ss->size()); + ms->AddBlock(ss, ss->begin(), ss->size()); + } + + SendReqList_.Enqueue(new TSendRequest(addr, &ms, reqId, EventCollector_)); + Host_->CancelWait(); + } + + void CancelRequest(const TGUID& reqId) override { + CancelReqList_.Enqueue(reqId); + Host_->CancelWait(); + } + + void SendResponse(const TGUID& reqId, TVector<char>* data) override { + if (data && data->size() > MAX_PACKET_SIZE) { + Y_FAIL( + "data size is too large; data->size()=%" PRISZT ", MAX_PACKET_SIZE=%" PRISZT, + data->size(), MAX_PACKET_SIZE); + } + SendRespList_.Enqueue(new TSendResponse(reqId, PP_NORMAL, data)); + Host_->CancelWait(); + } + + void StopNoWait() override { + AbortTransactions_ = true; + AtomicSet(KeepRunning_, 0); + // calcel all outgoing requests + TGuard<TSpinLock> lock(Spn_); + while (!OutRequests_.empty()) { + // cancel without informing peer that we are cancelling the request + FinishRequest(OutRequests_.begin(), TUdpHttpResponse::CANCELED, nullptr, "request canceled: inside TUdpHttp::StopNoWait()"); + } + } + + private: + void FinishRequest(TOutRequestHash::iterator i, TUdpHttpResponse::EResult ok, TRequestPtr data, const char* error = nullptr) { + TOutRequestState& s = i->second; + s.EventsCollector->AddResponse(new TUdpHttpResponse(data, i->first, s.Address, ok, error)); + OutRequests_.erase(i); + } + + int SendWithHighPriority(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket> data) { + ui32 crc32 = CalcChecksum(data->GetChain()); + return Host_->Send(addr, data.Release(), crc32, nullptr, PP_HIGH); + } + + void ProcessIncomingPackets() { + TVector<TGUID> failedRequests; + for (;;) { + TAutoPtr<TRequest> req = Host_->GetRequest(); + if (req.Get() == nullptr) + break; + + TBlockChainIterator reqData(req->Data->GetChain()); + char pktType; + reqData.Read(&pktType, 1); + switch (pktType) { + case PKT_REQUEST: + case PKT_LOCAL_REQUEST: { + TGUID reqId = req->Guid; + TInRequestHash::iterator z = InRequests_.find(reqId); + if (z != InRequests_.end()) { + // oops, this request already exists! 
+ // might happen if request can be stored in single packet + // and this packet had source IP broken during transmission and managed to pass crc checks + // since we already reported wrong source address for this request to the user + // the best thing we can do is to stop the program to avoid further complications + // but we just report the accident to stderr + fprintf(stderr, "Jackpot, same request %s received twice from %s and earlier from %s\n", + GetGuidAsString(reqId).c_str(), GetAddressAsString(z->second.Address).c_str(), + GetAddressAsString(req->Address).c_str()); + } else { + InRequests_[reqId] = TInRequestState(req->Address); + EventCollector_->AddRequest(new TUdpHttpRequest(req, reqId, req->Address)); + } + } break; + case PKT_PING: { + TGUID guid; + reqData.Read(&guid, sizeof(guid)); + bool ok = InRequests_.find(guid) != InRequests_.end(); + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + ms->Write((char)PKT_PING_RESPONSE); + ms->Write(guid); + ms->Write(ok); + SendWithHighPriority(req->Address, ms.Release()); + } break; + case PKT_PING_RESPONSE: { + TGUID guid; + bool ok; + reqData.Read(&guid, sizeof(guid)); + reqData.Read(&ok, sizeof(ok)); + TOutRequestHash::iterator i = OutRequests_.find(guid); + if (i == OutRequests_.end()) + ; //Y_ASSERT(0); // actually possible with some packet orders + else { + if (!ok) { + // can not delete request at this point + // since we can receive failed ping and response at the same moment + // consider sequence: client sends ping, server sends response + // and replies false to ping as reply is sent + // we can not receive failed ping_response earlier then response itself + // but we can receive them simultaneously + failedRequests.push_back(guid); + } else { + TOutRequestState& s = i->second; + switch (s.State) { + case TOutRequestState::S_WAITING_PING_SENDING: { + Y_ASSERT(s.PingTransferId >= 0); + TTransferHash::iterator k = TransferHash_.find(s.PingTransferId); + if (k != TransferHash_.end()) + TransferHash_.erase(k); + else + Y_ASSERT(0); + s.PingTransferId = -1; + s.TimePassed = 0; + s.State = TOutRequestState::S_WAITING; + } break; + case TOutRequestState::S_WAITING_PING_SENT: + s.TimePassed = 0; + s.State = TOutRequestState::S_WAITING; + break; + default: + Y_ASSERT(0); + break; + } + } + } + } break; + case PKT_RESPONSE: + case PKT_LOCAL_RESPONSE: { + TGUID guid; + reqData.Read(&guid, sizeof(guid)); + TOutRequestHash::iterator i = OutRequests_.find(guid); + if (i == OutRequests_.end()) { + ; //Y_ASSERT(0); // does happen + } else { + FinishRequest(i, TUdpHttpResponse::OK, req); + } + } break; + case PKT_CANCEL: { + TGUID guid; + reqData.Read(&guid, sizeof(guid)); + TInRequestHash::iterator i = InRequests_.find(guid); + if (i == InRequests_.end()) { + ; //Y_ASSERT(0); // may happen + } else { + TInRequestState& s = i->second; + if (s.State != TInRequestState::S_CANCELED && ReportRequestCancel_) + EventCollector_->AddCancel(guid); + s.State = TInRequestState::S_CANCELED; + } + } break; + default: + Y_ASSERT(0); + } + } + // cleanup failed requests + for (size_t k = 0; k < failedRequests.size(); ++k) { + const TGUID& guid = failedRequests[k]; + TOutRequestHash::iterator i = OutRequests_.find(guid); + if (i != OutRequests_.end()) + FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "failed udp ping"); + } + } + + void AnalyzeSendResults() { + TSendResult res; + while (Host_->GetSendResult(&res)) { + TTransferHash::iterator k = TransferHash_.find(res.TransferId); + if (k != TransferHash_.end()) { + const TTransferPurpose& tp = 
k->second; + switch (tp.Dir) { + case DIR_OUT: { + TOutRequestHash::iterator i = OutRequests_.find(tp.Guid); + if (i != OutRequests_.end()) { + const TGUID& reqId = i->first; + TOutRequestState& s = i->second; + switch (s.State) { + case TOutRequestState::S_SENDING: + if (!res.Success) { + FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: state S_SENDING"); + } else { + if (ReporRequestAck_ && !!s.EventsCollector) { + s.EventsCollector->AddRequestAck(reqId); + } + s.State = TOutRequestState::S_WAITING; + s.TimePassed = 0; + } + break; + case TOutRequestState::S_CANCEL_AFTER_SENDING: + DoSendCancel(s.Address, reqId); + FinishRequest(i, TUdpHttpResponse::CANCELED, nullptr, "request failed: state S_CANCEL_AFTER_SENDING"); + break; + case TOutRequestState::S_WAITING: + case TOutRequestState::S_WAITING_PING_SENT: + Y_ASSERT(0); + break; + case TOutRequestState::S_WAITING_PING_SENDING: + Y_ASSERT(s.PingTransferId >= 0 && s.PingTransferId == res.TransferId); + if (!res.Success) { + FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: state S_WAITING_PING_SENDING"); + } else { + s.PingTransferId = -1; + s.State = TOutRequestState::S_WAITING_PING_SENT; + s.TimePassed = 0; + } + break; + default: + Y_ASSERT(0); + break; + } + } + } break; + case DIR_IN: { + TInRequestHash::iterator i = InRequests_.find(tp.Guid); + if (i != InRequests_.end()) { + Y_ASSERT(i->second.State == TInRequestState::S_RESPONSE_SENDING || i->second.State == TInRequestState::S_CANCELED); + InRequests_.erase(i); + } + } break; + default: + Y_ASSERT(0); + break; + } + TransferHash_.erase(k); + } + } + } + + void SendPingsIfNeeded() { + NHPTimer::STime tChk = PingsSendT_; + float deltaT = (float)NHPTimer::GetTimePassed(&tChk); + if (deltaT < 0.05) { + return; + } + PingsSendT_ = tChk; + deltaT = ClampVal(deltaT, 0.0f, HTTP_TIMEOUT / 3); + + { + for (TOutRequestHash::iterator i = OutRequests_.begin(); i != OutRequests_.end();) { + TOutRequestHash::iterator curIt = i++; + TOutRequestState& s = curIt->second; + const TGUID& guid = curIt->first; + switch (s.State) { + case TOutRequestState::S_WAITING: + s.TimePassed += deltaT; + if (s.TimePassed > HTTP_TIMEOUT) { + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + ms->Write((char)PKT_PING); + ms->Write(guid); + int transId = SendWithHighPriority(s.Address, ms.Release()); + TransferHash_[transId] = TTransferPurpose(DIR_OUT, guid); + s.State = TOutRequestState::S_WAITING_PING_SENDING; + s.PingTransferId = transId; + } + break; + case TOutRequestState::S_WAITING_PING_SENT: + s.TimePassed += deltaT; + if (s.TimePassed > HTTP_TIMEOUT) { + FinishRequest(curIt, TUdpHttpResponse::FAILED, nullptr, "request failed: http timeout in state S_WAITING_PING_SENT"); + } + break; + default: + break; + } + } + } + } + + void Step() { + { + TGuard<TSpinLock> lock(Spn_); + DoSends(); + } + Host_->Step(); + { + TGuard<TSpinLock> lock(Spn_); + DoSends(); + ProcessIncomingPackets(); + AnalyzeSendResults(); + SendPingsIfNeeded(); + } + } + + void Wait() { + Host_->Wait(0.1f); + } + + void DoSendCancel(const TUdpAddress& addr, const TGUID& req) { + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + ms->Write((char)PKT_CANCEL); + ms->Write(req); + SendWithHighPriority(addr, ms); + } + + void DoSends() { + { + // cancelling requests + TGUID reqGuid; + while (CancelReqList_.Dequeue(&reqGuid)) { + TOutRequestHash::iterator i = OutRequests_.find(reqGuid); + if (i == OutRequests_.end()) { + AnticipateCancels_.insert(reqGuid); + continue; // cancelling non existing request is 
ok + } + TOutRequestState& s = i->second; + if (s.State == TOutRequestState::S_SENDING) { + // we are in trouble - have not sent request and we already have to cancel it, wait send + s.State = TOutRequestState::S_CANCEL_AFTER_SENDING; + s.EventsCollector->AddCancel(i->first); + } else { + DoSendCancel(s.Address, reqGuid); + FinishRequest(i, TUdpHttpResponse::CANCELED, nullptr, "request canceled: notify requested side"); + } + } + } + { + // sending replies + for (TSendResponse* rd = nullptr; SendRespList_.Dequeue(&rd); delete rd) { + TInRequestHash::iterator i = InRequests_.find(rd->ReqGuid); + if (i == InRequests_.end()) { + Y_ASSERT(0); + continue; + } + TInRequestState& s = i->second; + if (s.State == TInRequestState::S_CANCELED) { + // need not send response for the canceled request + InRequests_.erase(i); + continue; + } + + Y_ASSERT(s.State == TInRequestState::S_WAITING); + s.State = TInRequestState::S_RESPONSE_SENDING; + + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + ui32 crc32 = 0; + int dataSize = rd->Data.ysize(); + if (rd->Data.size() > MIN_SHARED_MEM_PACKET && IsLocal(s.Address)) { + TIntrusivePtr<TSharedMemory> shm = new TSharedMemory; + if (shm->Create(dataSize)) { + ms->Write((char)PKT_LOCAL_RESPONSE); + ms->Write(rd->ReqGuid); + memcpy(shm->GetPtr(), &rd->Data[0], dataSize); + TVector<char> empty; + rd->Data.swap(empty); + ms->AttachSharedData(shm); + crc32 = CalcChecksum(ms->GetChain()); + } + } + if (ms->GetSharedData() == nullptr) { + ms->Write((char)PKT_RESPONSE); + ms->Write(rd->ReqGuid); + + // to offload crc calcs from inner thread, crc of data[] is calced outside and passed in DataCrc32 + // this means that we are calculating crc when shared memory is used + // it is hard to avoid since in SendResponse() we don't know if shared mem will be used + // (peer address is not available there) + TIncrementalChecksumCalcer csCalcer; + AddChain(&csCalcer, ms->GetChain()); + // here we are replicating the way WriteDestructive serializes data + csCalcer.AddBlock(&dataSize, sizeof(dataSize)); + csCalcer.AddBlockSum(rd->DataCrc32, dataSize); + crc32 = csCalcer.CalcChecksum(); + + ms->WriteDestructive(&rd->Data); + //ui32 chkCrc = CalcChecksum(ms->GetChain()); // can not use since its slow for large responses + //Y_ASSERT(chkCrc == crc32); + } + + int transId = Host_->Send(s.Address, ms.Release(), crc32, nullptr, rd->Priority); + TransferHash_[transId] = TTransferPurpose(DIR_IN, rd->ReqGuid); + } + } + { + // sending requests + for (TSendRequest* rd = nullptr; SendReqList_.Dequeue(&rd); delete rd) { + Y_ASSERT(OutRequests_.find(rd->ReqGuid) == OutRequests_.end()); + + { + TOutRequestState& s = OutRequests_[rd->ReqGuid]; + s.State = TOutRequestState::S_SENDING; + s.Address = rd->Addr; + s.EventsCollector = rd->EventsCollector; + } + + if (AnticipateCancels_.find(rd->ReqGuid) != AnticipateCancels_.end()) { + FinishRequest(OutRequests_.find(rd->ReqGuid), TUdpHttpResponse::CANCELED, nullptr, "Canceled (before transmit)"); + } else { + TGUID pktGuid = rd->ReqGuid; // request packet id should match request id + int transId = Host_->Send(rd->Addr, rd->Data.Release(), rd->Crc32, &pktGuid, PP_NORMAL); + TransferHash_[transId] = TTransferPurpose(DIR_OUT, rd->ReqGuid); + } + } + } + if (!AnticipateCancels_.empty()) { + AnticipateCancels_.clear(); + } + } + + void FinishOutstandingTransactions() { + // wait all pending requests, all new requests are canceled + while ((!OutRequests_.empty() || !InRequests_.empty() || !SendRespList_.IsEmpty() || !SendReqList_.IsEmpty()) && 
!PanicAttack) { + Step(); + sleep(0); + } + } + + static void* ExecServerThread(void* param) { + TUdpHttp* pThis = (TUdpHttp*)param; + if (pThis->GetPhysicalCpu() >= 0) { + BindToSocket(pThis->GetPhysicalCpu()); + } + SetHighestThreadPriority(); + + TIntrusivePtr<NNetlibaSocket::ISocket> socket = NNetlibaSocket::CreateSocket(); + socket->Open(pThis->Port_); + if (socket->IsValid()) { + pThis->Port_ = socket->GetPort(); + pThis->Host_ = CreateUdpHost(socket); + } else { + pThis->Host_ = nullptr; + } + + pThis->HasStarted_.Signal(); + if (!pThis->Host_) + return nullptr; + + NHPTimer::GetTime(&pThis->PingsSendT_); + while (AtomicGet(pThis->KeepRunning_) && !PanicAttack) { + if (HeartbeatTimeout.load(std::memory_order_acquire) > 0) { + NHPTimer::STime chk = LastHeartbeat.load(std::memory_order_acquire); + if (NHPTimer::GetTimePassed(&chk) > HeartbeatTimeout.load(std::memory_order_acquire)) { + StopAllNetLibaThreads(); +#ifndef _win_ + killpg(0, SIGKILL); +#endif + abort(); + break; + } + } + pThis->Step(); + pThis->Wait(); + } + if (!pThis->AbortTransactions_ && !PanicAttack) { + pThis->FinishOutstandingTransactions(); + } + pThis->Host_ = nullptr; + return nullptr; + } + + int GetPhysicalCpu() const noexcept { + return PhysicalCpu_; + } + + private: + TThread MyThread_; + TAtomic KeepRunning_ = 1; + bool AbortTransactions_; + TSpinLock Spn_; + TSystemEvent HasStarted_; + + NHPTimer::STime PingsSendT_; + + TIntrusivePtr<IUdpHost> Host_; + int Port_; + TOutRequestHash OutRequests_; + TInRequestHash InRequests_; + + typedef THashMap<int, TTransferPurpose> TTransferHash; + TTransferHash TransferHash_; + + // hold it here to not construct on every DoSends() + typedef THashSet<TGUID, TGUIDHash> TAnticipateCancels; + TAnticipateCancels AnticipateCancels_; + + TLockFreeQueue<TSendRequest*> SendReqList_; + TLockFreeQueue<TSendResponse*> SendRespList_; + TLockFreeQueue<TGUID> CancelReqList_; + + TIntrusivePtr<IEventsCollector> EventCollector_; + + bool ReportRequestCancel_; + bool ReporRequestAck_; + int PhysicalCpu_; + }; + + IRequesterRef CreateHttpUdpRequester(int port, const IEventsCollectorRef& ec, int physicalCpu) { + TUdpHttp* udpHttp = new TUdpHttp(ec); + IRequesterRef res(udpHttp); + if (!udpHttp->Start(port, physicalCpu)) { + if (port) { + ythrow yexception() << "netliba can't bind port=" << port; + } else { + ythrow yexception() << "netliba can't bind random port"; + } + } + return res; + } +} diff --git a/library/cpp/neh/netliba_udp_http.h b/library/cpp/neh/netliba_udp_http.h new file mode 100644 index 0000000000..324366c10c --- /dev/null +++ b/library/cpp/neh/netliba_udp_http.h @@ -0,0 +1,79 @@ +#pragma once + +#include <library/cpp/netliba/v6/net_queue_stat.h> +#include <library/cpp/netliba/v6/udp_address.h> +#include <library/cpp/netliba/v6/udp_debug.h> + +#include <util/generic/guid.h> +#include <util/generic/ptr.h> +#include <util/network/init.h> +#include <util/system/event.h> + +namespace NNetliba { + struct TRequest; +} + +namespace NNehNetliba { + using namespace NNetliba; + + typedef TAutoPtr<TRequest> TRequestPtr; + + class TUdpHttpMessage { + public: + TUdpHttpMessage(const TGUID& reqId, const TUdpAddress& peerAddr); + + TGUID ReqId; + TUdpAddress PeerAddress; + }; + + class TUdpHttpRequest: public TUdpHttpMessage { + public: + TUdpHttpRequest(TRequestPtr& dataHolder, const TGUID& reqId, const TUdpAddress& peerAddr); + + TString Url; + TVector<char> Data; + }; + + class TUdpHttpResponse: public TUdpHttpMessage { + public: + enum EResult { + FAILED = 0, + OK = 1, + CANCELED = 2 
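+            // The three results are delivered to IEventsCollector::AddResponse(),
+            // which takes ownership of the response object; Error is meaningful
+            // only when Ok != OK, otherwise the payload is in Data. An illustrative
+            // sketch of a collector implementation (Consume()/ReportError() are
+            // placeholders, not part of this library):
+            //
+            //   void AddResponse(TUdpHttpResponse* r) override {
+            //       THolder<TUdpHttpResponse> resp(r);    // collector owns it now
+            //       if (resp->Ok == TUdpHttpResponse::OK) {
+            //           Consume(resp->Data);
+            //       } else {
+            //           ReportError(resp->Error);         // FAILED or CANCELED
+            //       }
+            //   }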
+ }; + + TUdpHttpResponse(TRequestPtr& dataHolder, const TGUID& reqId, const TUdpAddress& peerAddr, EResult result, const char* error); + + EResult Ok; + TString Data; + TString Error; + }; + + class IRequester: public TThrRefBase { + public: + virtual void EnableReportRequestCancel() = 0; + virtual void EnableReportRequestAck() = 0; + + // vector<char> *data - vector will be cleared upon call + virtual void SendRequest(const TUdpAddress&, const TString& url, const TString& data, const TGUID&) = 0; + virtual void CancelRequest(const TGUID&) = 0; + virtual void SendResponse(const TGUID&, TVector<char>* data) = 0; + + virtual void StopNoWait() = 0; + }; + + class IEventsCollector: public TThrRefBase { + public: + // move ownership request/response object to event collector + virtual void AddRequest(TUdpHttpRequest*) = 0; + virtual void AddResponse(TUdpHttpResponse*) = 0; + virtual void AddCancel(const TGUID&) = 0; + virtual void AddRequestAck(const TGUID&) = 0; + }; + + typedef TIntrusivePtr<IEventsCollector> IEventsCollectorRef; + typedef TIntrusivePtr<IRequester> IRequesterRef; + + // throw exception, if can't bind port + IRequesterRef CreateHttpUdpRequester(int port, const IEventsCollectorRef&, int physicalCpu = -1); +} diff --git a/library/cpp/neh/pipequeue.cpp b/library/cpp/neh/pipequeue.cpp new file mode 100644 index 0000000000..551e584228 --- /dev/null +++ b/library/cpp/neh/pipequeue.cpp @@ -0,0 +1 @@ +#include "pipequeue.h" diff --git a/library/cpp/neh/pipequeue.h b/library/cpp/neh/pipequeue.h new file mode 100644 index 0000000000..bed8d44bd2 --- /dev/null +++ b/library/cpp/neh/pipequeue.h @@ -0,0 +1,207 @@ +#pragma once + +#include "lfqueue.h" + +#include <library/cpp/coroutine/engine/impl.h> +#include <library/cpp/coroutine/engine/network.h> +#include <library/cpp/deprecated/atomic/atomic.h> +#include <util/system/pipe.h> + +#ifdef _linux_ +#include <sys/eventfd.h> +#endif + +#if defined(_bionic_) && !defined(EFD_SEMAPHORE) +#define EFD_SEMAPHORE 1 +#endif + +namespace NNeh { +#ifdef _linux_ + class TSemaphoreEventFd { + public: + inline TSemaphoreEventFd() { + F_ = eventfd(0, EFD_NONBLOCK | EFD_SEMAPHORE); + if (F_ < 0) { + ythrow TFileError() << "failed to create a eventfd"; + } + } + + inline ~TSemaphoreEventFd() { + close(F_); + } + + inline size_t Acquire(TCont* c) { + ui64 ev; + return NCoro::ReadI(c, F_, &ev, sizeof ev).Processed(); + } + + inline void Release() { + const static ui64 ev(1); + (void)write(F_, &ev, sizeof ev); + } + + private: + int F_; + }; +#endif + + class TSemaphorePipe { + public: + inline TSemaphorePipe() { + TPipeHandle::Pipe(S_[0], S_[1]); + + SetNonBlock(S_[0]); + SetNonBlock(S_[1]); + } + + inline size_t Acquire(TCont* c) { + char ch; + return NCoro::ReadI(c, S_[0], &ch, 1).Processed(); + } + + inline size_t Acquire(TCont* c, char* buff, size_t buflen) { + return NCoro::ReadI(c, S_[0], buff, buflen).Processed(); + } + + inline void Release() { + char ch = 13; + S_[1].Write(&ch, 1); + } + + private: + TPipeHandle S_[2]; + }; + + class TPipeQueueBase { + public: + inline void Enqueue(void* job) { + Q_.Enqueue(job); + S_.Release(); + } + + inline void* Dequeue(TCont* c, char* ch, size_t buflen) { + void* ret = nullptr; + + while (!Q_.Dequeue(&ret) && S_.Acquire(c, ch, buflen)) { + } + + return ret; + } + + inline void* Dequeue() noexcept { + void* ret = nullptr; + + Q_.Dequeue(&ret); + + return ret; + } + + private: + TLockFreeQueue<void*> Q_; + TSemaphorePipe S_; + }; + + template <class T, size_t buflen = 1> + class TPipeQueue { + public: + template 
<class TPtr> + inline void EnqueueSafe(TPtr req) { + Enqueue(req.Get()); + req.Release(); + } + + inline void Enqueue(T* req) { + Q_.Enqueue(req); + } + + template <class TPtr> + inline void DequeueSafe(TCont* c, TPtr& ret) { + ret.Reset(Dequeue(c)); + } + + inline T* Dequeue(TCont* c) { + char ch[buflen]; + + return (T*)Q_.Dequeue(c, ch, sizeof(ch)); + } + + protected: + TPipeQueueBase Q_; + }; + + //optimized for avoiding unnecessary usage semaphore + use eventfd on linux + template <class T> + struct TOneConsumerPipeQueue { + inline TOneConsumerPipeQueue() + : Signaled_(0) + , SkipWait_(0) + { + } + + inline void Enqueue(T* job) { + Q_.Enqueue(job); + + AtomicSet(SkipWait_, 1); + if (AtomicCas(&Signaled_, 1, 0)) { + S_.Release(); + } + } + + inline T* Dequeue(TCont* c) { + T* ret = nullptr; + + while (!Q_.Dequeue(&ret)) { + AtomicSet(Signaled_, 0); + if (!AtomicCas(&SkipWait_, 0, 1)) { + if (!S_.Acquire(c)) { + break; + } + } + AtomicSet(Signaled_, 1); + } + + return ret; + } + + template <class TPtr> + inline void EnqueueSafe(TPtr req) { + Enqueue(req.Get()); + Y_UNUSED(req.Release()); + } + + template <class TPtr> + inline void DequeueSafe(TCont* c, TPtr& ret) { + ret.Reset(Dequeue(c)); + } + + protected: + TLockFreeQueue<T*> Q_; +#ifdef _linux_ + TSemaphoreEventFd S_; +#else + TSemaphorePipe S_; +#endif + TAtomic Signaled_; + TAtomic SkipWait_; + }; + + template <class T, size_t buflen = 1> + struct TAutoPipeQueue: public TPipeQueue<T, buflen> { + ~TAutoPipeQueue() { + while (T* t = (T*)TPipeQueue<T, buflen>::Q_.Dequeue()) { + delete t; + } + } + }; + + template <class T> + struct TAutoOneConsumerPipeQueue: public TOneConsumerPipeQueue<T> { + ~TAutoOneConsumerPipeQueue() { + T* ret = nullptr; + + while (TOneConsumerPipeQueue<T>::Q_.Dequeue(&ret)) { + delete ret; + } + } + }; +} diff --git a/library/cpp/neh/rpc.cpp b/library/cpp/neh/rpc.cpp new file mode 100644 index 0000000000..13813bea1f --- /dev/null +++ b/library/cpp/neh/rpc.cpp @@ -0,0 +1,322 @@ +#include "rpc.h" +#include "rq.h" +#include "multi.h" +#include "location.h" + +#include <library/cpp/threading/thread_local/thread_local.h> + +#include <util/generic/hash.h> +#include <util/thread/factory.h> +#include <util/system/spinlock.h> + +using namespace NNeh; + +namespace { + typedef std::pair<TString, IServiceRef> TServiceDescr; + typedef TVector<TServiceDescr> TServicesBase; + + class TServices: public TServicesBase, public TThrRefBase, public IOnRequest { + typedef THashMap<TStringBuf, IServiceRef> TSrvs; + + struct TVersionedServiceMap { + TSrvs Srvs; + i64 Version = 0; + }; + + + struct TFunc: public IThreadFactory::IThreadAble { + inline TFunc(TServices* parent) + : Parent(parent) + { + } + + void DoExecute() override { + TThread::SetCurrentThreadName("NehTFunc"); + TVersionedServiceMap mp; + while (true) { + IRequestRef req = Parent->RQ_->Next(); + + if (!req) { + break; + } + + Parent->ServeRequest(mp, req); + } + + Parent->RQ_->Schedule(nullptr); + } + + TServices* Parent; + }; + + public: + inline TServices() + : RQ_(CreateRequestQueue()) + { + } + + inline TServices(TCheck check) + : RQ_(CreateRequestQueue()) + , C_(check) + { + } + + inline ~TServices() override { + LF_.Destroy(); + } + + inline void Add(const TString& service, IServiceRef srv) { + TGuard<TSpinLock> guard(L_); + + push_back(std::make_pair(service, srv)); + AtomicIncrement(SelfVersion_); + } + + inline void Listen() { + Y_ENSURE(!HasLoop_ || !*HasLoop_); + HasLoop_ = false; + RR_ = MultiRequester(ListenAddrs(), this); + } + + inline void Loop(size_t 
threads) { + Y_ENSURE(!HasLoop_ || *HasLoop_); + HasLoop_ = true; + + TIntrusivePtr<TServices> self(this); + IRequesterRef rr = MultiRequester(ListenAddrs(), this); + TFunc func(this); + + typedef TAutoPtr<IThreadFactory::IThread> IThreadRef; + TVector<IThreadRef> thrs; + + for (size_t i = 1; i < threads; ++i) { + thrs.push_back(SystemThreadFactory()->Run(&func)); + } + + func.Execute(); + + for (size_t i = 0; i < thrs.size(); ++i) { + thrs[i]->Join(); + } + RQ_->Clear(); + } + + inline void ForkLoop(size_t threads) { + Y_ENSURE(!HasLoop_ || *HasLoop_); + HasLoop_ = true; + //here we can have trouble with binding port(s), so expect exceptions + IRequesterRef rr = MultiRequester(ListenAddrs(), this); + LF_.Reset(new TLoopFunc(this, threads, rr)); + } + + inline void Stop() { + RQ_->Schedule(nullptr); + } + + inline void SyncStopFork() { + Stop(); + if (LF_) { + LF_->SyncStop(); + } + RQ_->Clear(); + LF_.Destroy(); + } + + void OnRequest(IRequestRef req) override { + if (C_) { + if (auto error = C_(req)) { + req->SendError(*error); + return; + } + } + if (!*HasLoop_) { + ServeRequest(LocalMap_.GetRef(), req); + } else { + RQ_->Schedule(req); + } + } + + private: + class TLoopFunc: public TFunc { + public: + TLoopFunc(TServices* parent, size_t threads, IRequesterRef& rr) + : TFunc(parent) + , RR_(rr) + { + T_.reserve(threads); + + try { + for (size_t i = 0; i < threads; ++i) { + T_.push_back(SystemThreadFactory()->Run(this)); + } + } catch (...) { + //paranoid mode on + SyncStop(); + throw; + } + } + + ~TLoopFunc() override { + try { + SyncStop(); + } catch (...) { + Cdbg << TStringBuf("neh rpc ~loop_func: ") << CurrentExceptionMessage() << Endl; + } + } + + void SyncStop() { + if (!T_) { + return; + } + + Parent->Stop(); + + for (size_t i = 0; i < T_.size(); ++i) { + T_[i]->Join(); + } + T_.clear(); + } + + private: + typedef TAutoPtr<IThreadFactory::IThread> IThreadRef; + TVector<IThreadRef> T_; + IRequesterRef RR_; + }; + + inline void ServeRequest(TVersionedServiceMap& mp, IRequestRef req) { + if (!req) { + return; + } + + const TStringBuf name = req->Service(); + TSrvs::const_iterator it = mp.Srvs.find(name); + + if (Y_UNLIKELY(it == mp.Srvs.end())) { + if (UpdateServices(mp.Srvs, mp.Version)) { + it = mp.Srvs.find(name); + } + } + + if (Y_UNLIKELY(it == mp.Srvs.end())) { + it = mp.Srvs.find(TStringBuf("*")); + } + + if (Y_UNLIKELY(it == mp.Srvs.end())) { + req->SendError(IRequest::NotExistService); + } else { + try { + it->second->ServeRequest(req); + } catch (...) 
{ + Cdbg << CurrentExceptionMessage() << Endl; + } + } + } + + inline bool UpdateServices(TSrvs& srvs, i64& version) const { + if (AtomicGet(SelfVersion_) == version) { + return false; + } + + srvs.clear(); + + TGuard<TSpinLock> guard(L_); + + for (const auto& it : *this) { + srvs[TParsedLocation(it.first).Service] = it.second; + } + version = AtomicGet(SelfVersion_); + + return true; + } + + inline TListenAddrs ListenAddrs() const { + TListenAddrs addrs; + + { + TGuard<TSpinLock> guard(L_); + + for (const auto& it : *this) { + addrs.push_back(it.first); + } + } + + return addrs; + } + + TSpinLock L_; + IRequestQueueRef RQ_; + THolder<TLoopFunc> LF_; + TAtomic SelfVersion_ = 1; + TCheck C_; + + NThreading::TThreadLocalValue<TVersionedServiceMap> LocalMap_; + + IRequesterRef RR_; + TMaybe<bool> HasLoop_; + }; + + class TServicesFace: public IServices { + public: + inline TServicesFace() + : S_(new TServices()) + { + } + + inline TServicesFace(TCheck check) + : S_(new TServices(check)) + { + } + + void DoAdd(const TString& service, IServiceRef srv) override { + S_->Add(service, srv); + } + + void Loop(size_t threads) override { + S_->Loop(threads); + } + + void ForkLoop(size_t threads) override { + S_->ForkLoop(threads); + } + + void SyncStopFork() override { + S_->SyncStopFork(); + } + + void Stop() override { + S_->Stop(); + } + + void Listen() override { + S_->Listen(); + } + + private: + TIntrusivePtr<TServices> S_; + }; +} + +IServiceRef NNeh::Wrap(const TServiceFunction& func) { + struct TWrapper: public IService { + inline TWrapper(const TServiceFunction& f) + : F(f) + { + } + + void ServeRequest(const IRequestRef& request) override { + F(request); + } + + TServiceFunction F; + }; + + return new TWrapper(func); +} + +IServicesRef NNeh::CreateLoop() { + return new TServicesFace(); +} + +IServicesRef NNeh::CreateLoop(TCheck check) { + return new TServicesFace(check); +} diff --git a/library/cpp/neh/rpc.h b/library/cpp/neh/rpc.h new file mode 100644 index 0000000000..482ff7ce53 --- /dev/null +++ b/library/cpp/neh/rpc.h @@ -0,0 +1,155 @@ +#pragma once + +#include <util/generic/vector.h> +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/generic/strbuf.h> +#include <util/generic/maybe.h> +#include <util/stream/output.h> +#include <util/datetime/base.h> +#include <functional> + +namespace NNeh { + using TData = TVector<char>; + + class TDataSaver: public TData, public IOutputStream { + public: + TDataSaver() = default; + ~TDataSaver() override = default; + TDataSaver(TDataSaver&&) noexcept = default; + TDataSaver& operator=(TDataSaver&&) noexcept = default; + + void DoWrite(const void* buf, size_t len) override { + insert(end(), (const char*)buf, (const char*)buf + len); + } + }; + + class IRequest { + public: + IRequest() + : ArrivalTime_(TInstant::Now()) + { + } + + virtual ~IRequest() = default; + + virtual TStringBuf Scheme() const = 0; + virtual TString RemoteHost() const = 0; //IP-literal / IPv4address / reg-name() + virtual TStringBuf Service() const = 0; + virtual TStringBuf Data() const = 0; + virtual TStringBuf RequestId() const = 0; + virtual bool Canceled() const = 0; + virtual void SendReply(TData& data) = 0; + enum TResponseError { + BadRequest, // bad request data - http_code 400 + Forbidden, // forbidden request - http_code 403 + NotExistService, // not found request handler - http_code 404 + TooManyRequests, // too many requests for the handler - http_code 429 + InternalError, // s...amthing happen - http_code 500 + NotImplemented, // not 
implemented - http_code 501 + BadGateway, // remote backend not available - http_code 502 + ServiceUnavailable, // overload - http_code 503 + BandwidthLimitExceeded, // 5xx version of 429 + MaxResponseError // count error types + }; + virtual void SendError(TResponseError err, const TString& details = TString()) = 0; + virtual TInstant ArrivalTime() const { + return ArrivalTime_; + } + + private: + TInstant ArrivalTime_; + }; + + using IRequestRef = TAutoPtr<IRequest>; + + struct IOnRequest { + virtual void OnRequest(IRequestRef req) = 0; + }; + + class TRequestOut: public TDataSaver { + public: + inline TRequestOut(IRequest* req) + : Req_(req) + { + } + + ~TRequestOut() override { + try { + Finish(); + } catch (...) { + } + } + + void DoFinish() override { + if (Req_) { + Req_->SendReply(*this); + Req_ = nullptr; + } + } + + private: + IRequest* Req_; + }; + + struct IRequester { + virtual ~IRequester() = default; + }; + + using IRequesterRef = TAtomicSharedPtr<IRequester>; + + struct IService: public TThrRefBase { + virtual void ServeRequest(const IRequestRef& request) = 0; + }; + + using IServiceRef = TIntrusivePtr<IService>; + using TServiceFunction = std::function<void(const IRequestRef&)>; + + IServiceRef Wrap(const TServiceFunction& func); + + class IServices { + public: + virtual ~IServices() = default; + + /// use current thread and run #threads-1 in addition + virtual void Loop(size_t threads) = 0; + /// run #threads and return control + virtual void ForkLoop(size_t threads) = 0; + /// send stopping request and wait stopping all services + virtual void SyncStopFork() = 0; + /// send stopping request and return control (async call) + virtual void Stop() = 0; + /// just listen, don't start any threads + virtual void Listen() = 0; + + inline IServices& Add(const TString& service, IServiceRef srv) { + DoAdd(service, srv); + + return *this; + } + + inline IServices& Add(const TString& service, const TServiceFunction& func) { + return Add(service, Wrap(func)); + } + + template <class T> + inline IServices& Add(const TString& service, T& t) { + return this->Add(service, std::bind(&T::ServeRequest, std::ref(t), std::placeholders::_1)); + } + + template <class T, void (T::*M)(const IRequestRef&)> + inline IServices& Add(const TString& service, T& t) { + return this->Add(service, std::bind(M, std::ref(t), std::placeholders::_1)); + } + + private: + virtual void DoAdd(const TString& service, IServiceRef srv) = 0; + }; + + using IServicesRef = TAutoPtr<IServices>; + using TCheck = std::function<TMaybe<IRequest::TResponseError>(const IRequestRef&)>; + + IServicesRef CreateLoop(); + // if request fails check it will be cancelled + IServicesRef CreateLoop(TCheck check); +} diff --git a/library/cpp/neh/rq.cpp b/library/cpp/neh/rq.cpp new file mode 100644 index 0000000000..b3ae8f470d --- /dev/null +++ b/library/cpp/neh/rq.cpp @@ -0,0 +1,312 @@ +#include "rq.h" +#include "lfqueue.h" + +#include <library/cpp/threading/atomic/bool.h> + +#include <util/system/tls.h> +#include <util/system/pipe.h> +#include <util/system/event.h> +#include <util/system/mutex.h> +#include <util/system/condvar.h> +#include <util/system/guard.h> +#include <util/network/socket.h> +#include <util/generic/deque.h> + +using namespace NNeh; + +namespace { + class TBaseLockFreeRequestQueue: public IRequestQueue { + public: + void Clear() override { + IRequestRef req; + while (Q_.Dequeue(&req)) { + } + } + + protected: + NNeh::TAutoLockFreeQueue<IRequest> Q_; + }; + + class TFdRequestQueue: public TBaseLockFreeRequestQueue { + 
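+        // Wake-up over a pipe: the write end W_ is non-blocking, the read end R_ is
+        // left blocking, so an empty queue parks the consumer in a 1-byte read.
+        // Condensed view of the protocol implemented by the members below:
+        //
+        //   Schedule(req):  Q_.Enqueue(req);  char ch = 42;  W_.Write(&ch, 1);
+        //   Next():         while (!Q_.Dequeue(&ret)) { char ch; R_.Read(&ch, 1); }
+        //                   return ret;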
public: + inline TFdRequestQueue() { + TPipeHandle::Pipe(R_, W_); + SetNonBlock(W_); + } + + void Schedule(IRequestRef req) override { + Q_.Enqueue(req); + char ch = 42; + W_.Write(&ch, 1); + } + + IRequestRef Next() override { + IRequestRef ret; + +#if 0 + for (size_t i = 0; i < 20; ++i) { + if (Q_.Dequeue(&ret)) { + return ret; + } + + //asm volatile ("pause;"); + } +#endif + + while (!Q_.Dequeue(&ret)) { + char ch; + + R_.Read(&ch, 1); + } + + return ret; + } + + private: + TPipeHandle R_; + TPipeHandle W_; + }; + + struct TNehFdEvent { + inline TNehFdEvent() { + TPipeHandle::Pipe(R, W); + SetNonBlock(W); + } + + inline void Signal() noexcept { + char ch = 21; + W.Write(&ch, 1); + } + + inline void Wait() noexcept { + char buf[128]; + R.Read(buf, sizeof(buf)); + } + + TPipeHandle R; + TPipeHandle W; + }; + + template <class TEvent> + class TEventRequestQueue: public TBaseLockFreeRequestQueue { + public: + void Schedule(IRequestRef req) override { + Q_.Enqueue(req); + E_.Signal(); + } + + IRequestRef Next() override { + IRequestRef ret; + + while (!Q_.Dequeue(&ret)) { + E_.Wait(); + } + + E_.Signal(); + + return ret; + } + + private: + TEvent E_; + }; + + template <class TEvent> + class TLazyEventRequestQueue: public TBaseLockFreeRequestQueue { + public: + void Schedule(IRequestRef req) override { + Q_.Enqueue(req); + if (C_.Val()) { + E_.Signal(); + } + } + + IRequestRef Next() override { + IRequestRef ret; + + C_.Inc(); + while (!Q_.Dequeue(&ret)) { + E_.Wait(); + } + C_.Dec(); + + if (Q_.Size() && C_.Val()) { + E_.Signal(); + } + + return ret; + } + + private: + TEvent E_; + TAtomicCounter C_; + }; + + class TCondVarRequestQueue: public IRequestQueue { + public: + void Clear() override { + TGuard<TMutex> g(M_); + Q_.clear(); + } + + void Schedule(IRequestRef req) override { + { + TGuard<TMutex> g(M_); + + Q_.push_back(req); + } + + C_.Signal(); + } + + IRequestRef Next() override { + TGuard<TMutex> g(M_); + + while (Q_.empty()) { + C_.Wait(M_); + } + + IRequestRef ret = *Q_.begin(); + Q_.pop_front(); + + return ret; + } + + private: + TDeque<IRequestRef> Q_; + TMutex M_; + TCondVar C_; + }; + + class TBusyRequestQueue: public TBaseLockFreeRequestQueue { + public: + void Schedule(IRequestRef req) override { + Q_.Enqueue(req); + } + + IRequestRef Next() override { + IRequestRef ret; + + while (!Q_.Dequeue(&ret)) { + } + + return ret; + } + }; + + class TSleepRequestQueue: public TBaseLockFreeRequestQueue { + public: + void Schedule(IRequestRef req) override { + Q_.Enqueue(req); + } + + IRequestRef Next() override { + IRequestRef ret; + + while (!Q_.Dequeue(&ret)) { + usleep(1); + } + + return ret; + } + }; + + struct TStupidEvent { + inline TStupidEvent() + : InWait(false) + { + } + + inline bool Signal() noexcept { + const bool ret = InWait; + Ev.Signal(); + + return ret; + } + + inline void Wait() noexcept { + InWait = true; + Ev.Wait(); + InWait = false; + } + + TAutoEvent Ev; + NAtomic::TBool InWait; + }; + + template <class TEvent> + class TLFRequestQueue: public TBaseLockFreeRequestQueue { + struct TLocalQueue: public TEvent { + }; + + public: + void Schedule(IRequestRef req) override { + Q_.Enqueue(req); + + for (TLocalQueue* lq = 0; FQ_.Dequeue(&lq) && !lq->Signal();) { + } + } + + IRequestRef Next() override { + while (true) { + IRequestRef ret; + + if (Q_.Dequeue(&ret)) { + return ret; + } + + TLocalQueue* lq = LocalQueue(); + + FQ_.Enqueue(lq); + + if (Q_.Dequeue(&ret)) { + TLocalQueue* besttry; + + if (FQ_.Dequeue(&besttry)) { + if (besttry == lq) { + //huraay, get rid of 
spurious wakeup + } else { + FQ_.Enqueue(besttry); + } + } + + return ret; + } + + lq->Wait(); + } + } + + private: + static inline TLocalQueue* LocalQueue() noexcept { + Y_POD_STATIC_THREAD(TLocalQueue*) + lq((TLocalQueue*)0); + + if (!lq) { + Y_STATIC_THREAD(TLocalQueue) + slq; + + lq = &(TLocalQueue&)slq; + } + + return lq; + } + + private: + TLockFreeStack<TLocalQueue*> FQ_; + }; +} + +IRequestQueueRef NNeh::CreateRequestQueue() { +//return new TCondVarRequestQueue(); +//return new TSleepRequestQueue(); +//return new TBusyRequestQueue(); +//return new TLFRequestQueue<TStupidEvent>(); +#if defined(_freebsd_) + return new TFdRequestQueue(); +#endif + //return new TFdRequestQueue(); + return new TLazyEventRequestQueue<TAutoEvent>(); + //return new TEventRequestQueue<TAutoEvent>(); + //return new TEventRequestQueue<TNehFdEvent>(); +} diff --git a/library/cpp/neh/rq.h b/library/cpp/neh/rq.h new file mode 100644 index 0000000000..39db6c7124 --- /dev/null +++ b/library/cpp/neh/rq.h @@ -0,0 +1,18 @@ +#pragma once + +#include "rpc.h" + +namespace NNeh { + class IRequestQueue { + public: + virtual ~IRequestQueue() { + } + + virtual void Clear() = 0; + virtual void Schedule(IRequestRef req) = 0; + virtual IRequestRef Next() = 0; + }; + + typedef TAutoPtr<IRequestQueue> IRequestQueueRef; + IRequestQueueRef CreateRequestQueue(); +} diff --git a/library/cpp/neh/smart_ptr.cpp b/library/cpp/neh/smart_ptr.cpp new file mode 100644 index 0000000000..62a540c871 --- /dev/null +++ b/library/cpp/neh/smart_ptr.cpp @@ -0,0 +1 @@ +#include "smart_ptr.h" diff --git a/library/cpp/neh/smart_ptr.h b/library/cpp/neh/smart_ptr.h new file mode 100644 index 0000000000..1ec4653304 --- /dev/null +++ b/library/cpp/neh/smart_ptr.h @@ -0,0 +1,332 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <library/cpp/deprecated/atomic/atomic.h> + +namespace NNeh { + //limited emulation shared_ptr/weak_ptr from boost lib. + //the main value means the weak_ptr functionality, else recommended use types from util/generic/ptr.h + + //smart pointer counter shared between shared and weak ptrs. 
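+    // Illustrative use of the pair defined below (TFoo/DoStuff() are placeholders):
+    // hold a TWeakPtrB to an object owned elsewhere and "lock" it by constructing a
+    // TSharedPtrB from it; if the last owner has already dropped the object, the
+    // promotion yields an empty pointer instead of a dangling one.
+    //
+    //   TSharedPtrB<TFoo> owner(new TFoo());
+    //   TWeakPtrB<TFoo> weak(owner);
+    //   ...
+    //   TSharedPtrB<TFoo> locked(weak);     // promotes via TSPCounted::TryInc()
+    //   if (!!locked) {
+    //       locked->DoStuff();
+    //   }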
+ class TSPCounted: public TThrRefBase { + public: + inline TSPCounted() noexcept + : C_(0) + { + } + + inline void Inc() noexcept { + AtomicIncrement(C_); + } + + //return false if C_ already 0, else increment and return true + inline bool TryInc() noexcept { + for (;;) { + intptr_t curVal(AtomicGet(C_)); + + if (!curVal) { + return false; + } + + intptr_t newVal(curVal + 1); + + if (AtomicCas(&C_, newVal, curVal)) { + return true; + } + } + } + + inline intptr_t Dec() noexcept { + return AtomicDecrement(C_); + } + + inline intptr_t Value() const noexcept { + return AtomicGet(C_); + } + + private: + TAtomic C_; + }; + + typedef TIntrusivePtr<TSPCounted> TSPCountedRef; + + class TWeakCount; + + class TSPCount { + public: + TSPCount(TSPCounted* c = nullptr) noexcept + : C_(c) + { + } + + inline void Swap(TSPCount& r) noexcept { + DoSwap(C_, r.C_); + } + + inline size_t UseCount() const noexcept { + if (!C_) { + return 0; + } + return C_->Value(); + } + + inline bool operator!() const noexcept { + return !C_; + } + + inline TSPCounted* GetCounted() const noexcept { + return C_.Get(); + } + + inline void Reset() noexcept { + if (!!C_) { + C_.Drop(); + } + } + + protected: + TIntrusivePtr<TSPCounted> C_; + }; + + class TSharedCount: public TSPCount { + public: + inline TSharedCount() noexcept { + } + + /// @throws std::bad_alloc + inline explicit TSharedCount(const TSharedCount& r) + : TSPCount(r.C_.Get()) + { + if (!!C_) { + (C_->Inc()); + } + } + + //'c' must exist and has already increased ref + inline explicit TSharedCount(TSPCounted* c) noexcept + : TSPCount(c) + { + } + + public: + /// @throws std::bad_alloc + inline void Inc() { + if (!C_) { + TSPCountedRef(new TSPCounted()).Swap(C_); + } + C_->Inc(); + } + + inline bool TryInc() noexcept { + if (!C_) { + return false; + } + return C_->TryInc(); + } + + inline intptr_t Dec() noexcept { + if (!C_) { + Y_ASSERT(0); + return 0; + } + return C_->Dec(); + } + + void Drop() noexcept { + C_.Drop(); + } + + protected: + template <class Y> + friend class TSharedPtrB; + + // 'c' MUST BE already incremented + void Assign(TSPCounted* c) noexcept { + TSPCountedRef(c).Swap(C_); + } + + private: + TSharedCount& operator=(const TSharedCount&); //disable + }; + + class TWeakCount: public TSPCount { + public: + inline TWeakCount() noexcept { + } + + inline explicit TWeakCount(const TWeakCount& r) noexcept + : TSPCount(r.GetCounted()) + { + } + + inline explicit TWeakCount(const TSharedCount& r) noexcept + : TSPCount(r.GetCounted()) + { + } + + private: + TWeakCount& operator=(const TWeakCount&); //disable + }; + + template <class T> + class TWeakPtrB; + + template <class T> + class TSharedPtrB { + public: + inline TSharedPtrB() noexcept + : T_(nullptr) + { + } + + /// @throws std::bad_alloc + inline TSharedPtrB(T* t) + : T_(nullptr) + { + if (t) { + THolder<T> h(t); + C_.Inc(); + T_ = h.Release(); + } + } + + inline TSharedPtrB(const TSharedPtrB<T>& r) noexcept + : T_(r.T_) + , C_(r.C_) + { + Y_ASSERT((!!T_ && !!C_.UseCount()) || (!T_ && !C_.UseCount())); + } + + inline TSharedPtrB(const TWeakPtrB<T>& r) noexcept + : T_(r.T_) + { + if (T_) { + TSPCounted* spc = r.C_.GetCounted(); + + if (spc && spc->TryInc()) { + C_.Assign(spc); + } else { //obsolete ptr + T_ = nullptr; + } + } + } + + inline ~TSharedPtrB() { + Reset(); + } + + TSharedPtrB& operator=(const TSharedPtrB<T>& r) noexcept { + TSharedPtrB<T>(r).Swap(*this); + return *this; + } + + TSharedPtrB& operator=(const TWeakPtrB<T>& r) noexcept { + TSharedPtrB<T>(r).Swap(*this); + return *this; + } + + 
void Swap(TSharedPtrB<T>& r) noexcept { + DoSwap(T_, r.T_); + DoSwap(C_, r.C_); + Y_ASSERT((!!T_ && !!UseCount()) || (!T_ && !UseCount())); + } + + inline bool operator!() const noexcept { + return !T_; + } + + inline T* Get() noexcept { + return T_; + } + + inline T* operator->() noexcept { + return T_; + } + + inline T* operator->() const noexcept { + return T_; + } + + inline T& operator*() noexcept { + return *T_; + } + + inline T& operator*() const noexcept { + return *T_; + } + + inline void Reset() noexcept { + if (T_) { + if (C_.Dec() == 0) { + delete T_; + } + T_ = nullptr; + C_.Drop(); + } + } + + inline size_t UseCount() const noexcept { + return C_.UseCount(); + } + + protected: + template <class Y> + friend class TWeakPtrB; + + T* T_; + TSharedCount C_; + }; + + template <class T> + class TWeakPtrB { + public: + inline TWeakPtrB() noexcept + : T_(nullptr) + { + } + + inline TWeakPtrB(const TWeakPtrB<T>& r) noexcept + : T_(r.T_) + , C_(r.C_) + { + } + + inline TWeakPtrB(const TSharedPtrB<T>& r) noexcept + : T_(r.T_) + , C_(r.C_) + { + } + + TWeakPtrB& operator=(const TWeakPtrB<T>& r) noexcept { + TWeakPtrB(r).Swap(*this); + return *this; + } + + TWeakPtrB& operator=(const TSharedPtrB<T>& r) noexcept { + TWeakPtrB(r).Swap(*this); + return *this; + } + + inline void Swap(TWeakPtrB<T>& r) noexcept { + DoSwap(T_, r.T_); + DoSwap(C_, r.C_); + } + + inline void Reset() noexcept { + T_ = 0; + C_.Reset(); + } + + inline size_t UseCount() const noexcept { + return C_.UseCount(); + } + + protected: + template <class Y> + friend class TSharedPtrB; + + T* T_; + TWeakCount C_; + }; + +} diff --git a/library/cpp/neh/stat.cpp b/library/cpp/neh/stat.cpp new file mode 100644 index 0000000000..ef6422fb52 --- /dev/null +++ b/library/cpp/neh/stat.cpp @@ -0,0 +1,114 @@ +#include "stat.h" + +#include <util/generic/hash.h> +#include <util/generic/singleton.h> +#include <util/system/spinlock.h> +#include <util/system/tls.h> + +using namespace NNeh; + +volatile TAtomic NNeh::TServiceStat::MaxContinuousErrors_ = 0; //by default disabled +volatile TAtomic NNeh::TServiceStat::ReSendValidatorPeriod_ = 100; + +NNeh::TServiceStat::EStatus NNeh::TServiceStat::GetStatus() { + if (!AtomicGet(MaxContinuousErrors_) || AtomicGet(LastContinuousErrors_) < AtomicGet(MaxContinuousErrors_)) { + return Ok; + } + + if (RequestsInProcess_.Val() != 0) + return Fail; + + if (AtomicIncrement(SendValidatorCounter_) != AtomicGet(ReSendValidatorPeriod_)) { + return Fail; + } + + //time for refresh service status (send validation request) + AtomicSet(SendValidatorCounter_, 0); + + return ReTry; +} + +void NNeh::TServiceStat::DbgOut(IOutputStream& out) const { + out << "----------------------------------------------------" << '\n';; + out << "RequestsInProcess: " << RequestsInProcess_.Val() << '\n'; + out << "LastContinuousErrors: " << AtomicGet(LastContinuousErrors_) << '\n'; + out << "SendValidatorCounter: " << AtomicGet(SendValidatorCounter_) << '\n'; + out << "ReSendValidatorPeriod: " << AtomicGet(ReSendValidatorPeriod_) << Endl; +} + +void NNeh::TServiceStat::OnBegin() { + RequestsInProcess_.Inc(); +} + +void NNeh::TServiceStat::OnSuccess() { + RequestsInProcess_.Dec(); + AtomicSet(LastContinuousErrors_, 0); +} + +void NNeh::TServiceStat::OnCancel() { + RequestsInProcess_.Dec(); +} + +void NNeh::TServiceStat::OnFail() { + RequestsInProcess_.Dec(); + if (AtomicIncrement(LastContinuousErrors_) == AtomicGet(MaxContinuousErrors_)) { + AtomicSet(SendValidatorCounter_, 0); + } +} + +namespace { + class TGlobalServicesStat { + 
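+        // Cross-thread registry: this map (guarded by Lock_) is consulted only when a
+        // thread's own TServicesStat cache misses, so repeated lookups for the same
+        // address stay lock-free. Caller-side sketch (addr is a placeholder):
+        //
+        //   TServiceStat::ConfigureValidator(5 /*maxContinuousErrors*/, 100 /*reSendValidatorPeriod*/);
+        //   TServiceStatRef ss = GetServiceStat(addr);
+        //   TStatCollector coll(ss);   // counts the request as "in process"
+        //   ...
+        //   coll.OnSuccess();          // or OnFail()/OnCancel(); the dtor defaults to OnFail()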
public: + inline TServiceStatRef ServiceStat(const TStringBuf addr) noexcept { + const auto guard = Guard(Lock_); + + TServiceStatRef& ss = SS_[addr]; + + if (!ss) { + TServiceStatRef tmp(new TServiceStat()); + + ss.Swap(tmp); + } + return ss; + } + + protected: + TAdaptiveLock Lock_; + THashMap<TString, TServiceStatRef> SS_; + }; + + class TServicesStat { + public: + inline TServiceStatRef ServiceStat(const TStringBuf addr) noexcept { + TServiceStatRef& ss = SS_[addr]; + + if (!ss) { + TServiceStatRef tmp(Singleton<TGlobalServicesStat>()->ServiceStat(addr)); + + ss.Swap(tmp); + } + return ss; + } + + protected: + THashMap<TString, TServiceStatRef> SS_; + }; + + inline TServicesStat* ThrServiceStat() { + Y_POD_STATIC_THREAD(TServicesStat*) + ss; + + if (!ss) { + Y_STATIC_THREAD(TServicesStat) + tss; + + ss = &(TServicesStat&)tss; + } + + return ss; + } +} + +TServiceStatRef NNeh::GetServiceStat(const TStringBuf addr) { + return ThrServiceStat()->ServiceStat(addr); +} diff --git a/library/cpp/neh/stat.h b/library/cpp/neh/stat.h new file mode 100644 index 0000000000..803e8d2974 --- /dev/null +++ b/library/cpp/neh/stat.h @@ -0,0 +1,96 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <util/stream/output.h> +#include <library/cpp/deprecated/atomic/atomic.h> +#include <library/cpp/deprecated/atomic/atomic_ops.h> + +namespace NNeh { + class TStatCollector; + + /// NEH service workability statistics collector. + /// + /// Disabled by default, use `TServiceStat::ConfigureValidator` to set `maxContinuousErrors` + /// different from zero. + class TServiceStat: public TThrRefBase { + public: + static void ConfigureValidator(unsigned maxContinuousErrors, unsigned reSendValidatorPeriod) noexcept { + AtomicSet(MaxContinuousErrors_, maxContinuousErrors); + AtomicSet(ReSendValidatorPeriod_, reSendValidatorPeriod); + } + static bool Disabled() noexcept { + return !AtomicGet(MaxContinuousErrors_); + } + + enum EStatus { + Ok, + Fail, + ReTry //time for sending request-validator to service + }; + + EStatus GetStatus(); + + void DbgOut(IOutputStream&) const; + + protected: + friend class TStatCollector; + + virtual void OnBegin(); + virtual void OnSuccess(); + virtual void OnCancel(); + virtual void OnFail(); + + static TAtomic MaxContinuousErrors_; + static TAtomic ReSendValidatorPeriod_; + TAtomicCounter RequestsInProcess_; + TAtomic LastContinuousErrors_ = 0; + TAtomic SendValidatorCounter_ = 0; + }; + + using TServiceStatRef = TIntrusivePtr<TServiceStat>; + + //thread safe (race protected) service stat updater + class TStatCollector { + public: + TStatCollector(TServiceStatRef& ss) + : SS_(ss) + { + ss->OnBegin(); + } + + ~TStatCollector() { + if (CanInformSS()) { + SS_->OnFail(); + } + } + + void OnCancel() noexcept { + if (CanInformSS()) { + SS_->OnCancel(); + } + } + + void OnFail() noexcept { + if (CanInformSS()) { + SS_->OnFail(); + } + } + + void OnSuccess() noexcept { + if (CanInformSS()) { + SS_->OnSuccess(); + } + } + + private: + inline bool CanInformSS() noexcept { + return AtomicGet(CanInformSS_) && AtomicCas(&CanInformSS_, 0, 1); + } + + TServiceStatRef SS_; + TAtomic CanInformSS_ = 1; + }; + + TServiceStatRef GetServiceStat(TStringBuf addr); + +} diff --git a/library/cpp/neh/tcp.cpp b/library/cpp/neh/tcp.cpp new file mode 100644 index 0000000000..80f464dac2 --- /dev/null +++ b/library/cpp/neh/tcp.cpp @@ -0,0 +1,676 @@ +#include "tcp.h" + +#include "details.h" +#include "factory.h" +#include "location.h" +#include "pipequeue.h" +#include "utils.h" + +#include 
<library/cpp/coroutine/listener/listen.h> +#include <library/cpp/coroutine/engine/events.h> +#include <library/cpp/coroutine/engine/sockpool.h> +#include <library/cpp/dns/cache.h> + +#include <util/ysaveload.h> +#include <util/generic/buffer.h> +#include <util/generic/guid.h> +#include <util/generic/hash.h> +#include <util/generic/intrlist.h> +#include <util/generic/ptr.h> +#include <util/generic/vector.h> +#include <util/system/yassert.h> +#include <util/system/unaligned_mem.h> +#include <util/stream/buffered.h> +#include <util/stream/mem.h> + +using namespace NDns; +using namespace NNeh; + +using TNehMessage = TMessage; + +template <> +struct TSerializer<TGUID> { + static inline void Save(IOutputStream* out, const TGUID& g) { + out->Write(&g.dw, sizeof(g.dw)); + } + + static inline void Load(IInputStream* in, TGUID& g) { + in->Load(&g.dw, sizeof(g.dw)); + } +}; + +namespace { + namespace NNehTCP { + typedef IOutputStream::TPart TPart; + + static inline ui64 LocalGuid(const TGUID& g) { + return ReadUnaligned<ui64>(g.dw); + } + + static inline TString LoadStroka(IInputStream& input, size_t len) { + TString tmp; + + tmp.ReserveAndResize(len); + input.Load(tmp.begin(), tmp.size()); + + return tmp; + } + + struct TParts: public TVector<TPart> { + template <class T> + inline void Push(const T& t) { + Push(TPart(t)); + } + + inline void Push(const TPart& part) { + if (part.len) { + push_back(part); + } + } + + inline void Clear() noexcept { + clear(); + } + }; + + template <class T> + struct TMessageQueue { + inline TMessageQueue(TContExecutor* e) + : Ev(e) + { + } + + template <class TPtr> + inline void Enqueue(TPtr p) noexcept { + L.PushBack(p.Release()); + Ev.Signal(); + } + + template <class TPtr> + inline bool Dequeue(TPtr& p) noexcept { + do { + if (TryDequeue(p)) { + return true; + } + } while (Ev.WaitI() != ECANCELED); + + return false; + } + + template <class TPtr> + inline bool TryDequeue(TPtr& p) noexcept { + if (L.Empty()) { + return false; + } + + p.Reset(L.PopFront()); + + return true; + } + + inline TContExecutor* Executor() const noexcept { + return Ev.Executor(); + } + + TIntrusiveListWithAutoDelete<T, TDelete> L; + TContSimpleEvent Ev; + }; + + template <class Q, class C> + inline bool Dequeue(Q& q, C& c, size_t len) { + typename C::value_type t; + size_t slen = 0; + + if (q.Dequeue(t)) { + slen += t->Length(); + c.push_back(t); + + while (slen < len && q.TryDequeue(t)) { + slen += t->Length(); + c.push_back(t); + } + + return true; + } + + return false; + } + + struct TServer: public IRequester, public TContListener::ICallBack { + struct TLink; + typedef TIntrusivePtr<TLink> TLinkRef; + + struct TResponce: public TIntrusiveListItem<TResponce> { + inline TResponce(const TLinkRef& link, TData& data, TStringBuf reqid) + : Link(link) + { + Data.swap(data); + + TMemoryOutput out(Buf, sizeof(Buf)); + + ::Save(&out, (ui32)(reqid.size() + Data.size())); + out.Write(reqid.data(), reqid.size()); + + Y_ASSERT(reqid.size() == 16); + + Len = out.Buf() - Buf; + } + + inline void Serialize(TParts& parts) { + parts.Push(TStringBuf(Buf, Len)); + parts.Push(TStringBuf(Data.data(), Data.size())); + } + + inline size_t Length() const noexcept { + return Len + Data.size(); + } + + TLinkRef Link; + TData Data; + char Buf[32]; + size_t Len; + }; + + typedef TAutoPtr<TResponce> TResponcePtr; + + struct TRequest: public IRequest { + inline TRequest(const TLinkRef& link, IInputStream& in, size_t len) + : Link(link) + { + Buf.Proceed(len); + in.Load(Buf.Data(), Buf.Size()); + if ((ServiceBegin() - 
Buf.Data()) + ServiceLen() > Buf.Size()) { + throw yexception() << "invalid request (service len)"; + } + } + + TStringBuf Scheme() const override { + return TStringBuf("tcp"); + } + + TString RemoteHost() const override { + return Link->RemoteHost; + } + + TStringBuf Service() const override { + return TStringBuf(ServiceBegin(), ServiceLen()); + } + + TStringBuf Data() const override { + return TStringBuf(Service().end(), Buf.End()); + } + + TStringBuf RequestId() const override { + return TStringBuf(Buf.Data(), 16); + } + + bool Canceled() const override { + //TODO + return false; + } + + void SendReply(TData& data) override { + Link->P->Schedule(new TResponce(Link, data, RequestId())); + } + + void SendError(TResponseError, const TString&) override { + // TODO + } + + size_t ServiceLen() const noexcept { + const char* ptr = RequestId().end(); + return *(ui32*)ptr; + } + + const char* ServiceBegin() const noexcept { + return RequestId().end() + sizeof(ui32); + } + + TBuffer Buf; + TLinkRef Link; + }; + + struct TLink: public TAtomicRefCount<TLink> { + inline TLink(TServer* parent, const TAcceptFull& a) + : P(parent) + , MQ(Executor()) + { + S.Swap(*a.S); + SetNoDelay(S, true); + + RemoteHost = PrintHostByRfc(*GetPeerAddr(S)); + + TLinkRef self(this); + + Executor()->Create<TLink, &TLink::RecvCycle>(this, "recv"); + Executor()->Create<TLink, &TLink::SendCycle>(this, "send"); + + Executor()->Running()->Yield(); + } + + inline void Enqueue(TResponcePtr res) { + MQ.Enqueue(res); + } + + inline TContExecutor* Executor() const noexcept { + return P->E.Get(); + } + + void SendCycle(TCont* c) { + TLinkRef self(this); + + try { + DoSendCycle(c); + } catch (...) { + Cdbg << "neh/tcp/1: " << CurrentExceptionMessage() << Endl; + } + } + + inline void DoSendCycle(TCont* c) { + TVector<TResponcePtr> responses; + TParts parts; + + while (Dequeue(MQ, responses, 7000)) { + for (size_t i = 0; i < responses.size(); ++i) { + responses[i]->Serialize(parts); + } + + { + TContIOVector iovec(parts.data(), parts.size()); + NCoro::WriteVectorI(c, S, &iovec); + } + + parts.Clear(); + responses.clear(); + } + } + + void RecvCycle(TCont* c) { + TLinkRef self(this); + + try { + DoRecvCycle(c); + } catch (...) 
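+ //coroutine cancellation is the normal way this receive cycle ends, so only unexpected failures are logged below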
{ + if (!c->Cancelled()) { + Cdbg << "neh/tcp/2: " << CurrentExceptionMessage() << Endl; + } + } + } + + inline void DoRecvCycle(TCont* c) { + TContIO io(S, c); + TBufferedInput input(&io, 8192 * 4); + + while (true) { + ui32 len; + + try { + ::Load(&input, len); + } catch (TLoadEOF&) { + return; + } + + P->CB->OnRequest(new TRequest(this, input, len)); + } + } + + TServer* P; + TMessageQueue<TResponce> MQ; + TSocketHolder S; + TString RemoteHost; + }; + + inline TServer(IOnRequest* cb, ui16 port) + : CB(cb) + , Addr(port) + { + Thrs.push_back(Spawn<TServer, &TServer::Run>(this)); + } + + ~TServer() override { + Schedule(nullptr); + + for (size_t i = 0; i < Thrs.size(); ++i) { + Thrs[i]->Join(); + } + } + + void Run() { + E = MakeHolder<TContExecutor>(RealStackSize(32000)); + THolder<TContListener> L(new TContListener(this, E.Get(), TContListener::TOptions().SetDeferAccept(true))); + //SetHighestThreadPriority(); + L->Bind(Addr); + E->Create<TServer, &TServer::RunDispatcher>(this, "dispatcher"); + L->Listen(); + E->Execute(); + } + + void OnAcceptFull(const TAcceptFull& a) override { + //I love such code + new TLink(this, a); + } + + void OnError() override { + Cerr << CurrentExceptionMessage() << Endl; + } + + inline void Schedule(TResponcePtr res) { + PQ.EnqueueSafe(res); + } + + void RunDispatcher(TCont* c) { + while (true) { + TResponcePtr res; + + PQ.DequeueSafe(c, res); + + if (!res) { + break; + } + + TLinkRef link = res->Link; + + link->Enqueue(res); + } + + c->Executor()->Abort(); + } + THolder<TContExecutor> E; + IOnRequest* CB; + TNetworkAddress Addr; + TOneConsumerPipeQueue<TResponce> PQ; + TVector<TThreadRef> Thrs; + }; + + struct TClient { + struct TRequest: public TIntrusiveListItem<TRequest> { + inline TRequest(const TSimpleHandleRef& hndl, const TNehMessage& msg) + : Hndl(hndl) + , Msg(msg) + , Loc(Msg.Addr) + , RI(CachedThrResolve(TResolveInfo(Loc.Host, Loc.GetPort()))) + { + CreateGuid(&Guid); + } + + inline void Serialize(TParts& parts) { + TMemoryOutput out(Buf, sizeof(Buf)); + + ::Save(&out, (ui32)MsgLen()); + ::Save(&out, Guid); + ::Save(&out, (ui32) Loc.Service.size()); + + if (Loc.Service.size() > out.Avail()) { + parts.Push(TStringBuf(Buf, out.Buf())); + parts.Push(Loc.Service); + } else { + out.Write(Loc.Service.data(), Loc.Service.size()); + parts.Push(TStringBuf(Buf, out.Buf())); + } + + parts.Push(Msg.Data); + } + + inline size_t Length() const noexcept { + return sizeof(ui32) + MsgLen(); + } + + inline size_t MsgLen() const noexcept { + return sizeof(Guid.dw) + sizeof(ui32) + Loc.Service.size() + Msg.Data.size(); + } + + void OnError(const TString& errText) { + Hndl->NotifyError(errText); + } + + TSimpleHandleRef Hndl; + TNehMessage Msg; + TGUID Guid; + const TParsedLocation Loc; + const TResolvedHost* RI; + char Buf[128]; + }; + + typedef TAutoPtr<TRequest> TRequestPtr; + + struct TChannel { + struct TLink: public TIntrusiveListItem<TLink>, public TSimpleRefCount<TLink> { + inline TLink(TChannel* parent) + : P(parent) + { + Executor()->Create<TLink, &TLink::SendCycle>(this, "send"); + } + + void SendCycle(TCont* c) { + TIntrusivePtr<TLink> self(this); + + try { + DoSendCycle(c); + OnError("shutdown"); + } catch (...) 
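+ //OnError() fails every in-fly request and drains the channel queue; the "shutdown" above is the normal exit once the executor cancels the queue wait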
{ + OnError(CurrentExceptionMessage()); + } + + Unlink(); + } + + inline void DoSendCycle(TCont* c) { + if (int ret = NCoro::ConnectI(c, S, P->RI->Addr)) { + ythrow TSystemError(ret) << "can't connect"; + } + SetNoDelay(S, true); + Executor()->Create<TLink, &TLink::RecvCycle>(this, "recv"); + + TVector<TRequestPtr> reqs; + TParts parts; + + while (Dequeue(P->Q, reqs, 7000)) { + for (size_t i = 0; i < reqs.size(); ++i) { + TRequestPtr& req = reqs[i]; + + req->Serialize(parts); + InFly[LocalGuid(req->Guid)] = req; + } + + { + TContIOVector vec(parts.data(), parts.size()); + NCoro::WriteVectorI(c, S, &vec); + } + + reqs.clear(); + parts.Clear(); + } + } + + void RecvCycle(TCont* c) { + TIntrusivePtr<TLink> self(this); + + try { + DoRecvCycle(c); + OnError("service close connection"); + } catch (...) { + OnError(CurrentExceptionMessage()); + } + } + + inline void DoRecvCycle(TCont* c) { + TContIO io(S, c); + TBufferedInput input(&io, 8192 * 4); + + while (true) { + ui32 len; + TGUID g; + + try { + ::Load(&input, len); + } catch (TLoadEOF&) { + return; + } + ::Load(&input, g); + const TString data(LoadStroka(input, len - sizeof(g.dw))); + + TInFly::iterator it = InFly.find(LocalGuid(g)); + + if (it == InFly.end()) { + continue; + } + + TRequestPtr req = it->second; + + InFly.erase(it); + req->Hndl->NotifyResponse(data); + } + } + + inline TContExecutor* Executor() const noexcept { + return P->Q.Executor(); + } + + void OnError(const TString& errText) { + for (auto& it : InFly) { + it.second->OnError(errText); + } + InFly.clear(); + + TRequestPtr req; + while (P->Q.TryDequeue(req)) { + req->OnError(errText); + } + } + + TChannel* P; + TSocketHolder S; + typedef THashMap<ui64, TRequestPtr> TInFly; + TInFly InFly; + }; + + inline TChannel(TContExecutor* e, const TResolvedHost* ri) + : Q(e) + , RI(ri) + { + } + + inline void Enqueue(TRequestPtr req) { + Q.Enqueue(req); + + if (Links.Empty()) { + for (size_t i = 0; i < 1; ++i) { + SpawnLink(); + } + } + } + + inline void SpawnLink() { + Links.PushBack(new TLink(this)); + } + + TMessageQueue<TRequest> Q; + TIntrusiveList<TLink> Links; + const TResolvedHost* RI; + }; + + typedef TAutoPtr<TChannel> TChannelPtr; + + inline TClient() { + Thr = Spawn<TClient, &TClient::RunExecutor>(this); + } + + inline ~TClient() { + Reqs.Enqueue(nullptr); + Thr->Join(); + } + + inline THandleRef Schedule(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) { + TSimpleHandleRef ret(new TSimpleHandle(fallback, msg, !ss ? 
nullptr : new TStatCollector(ss))); + + Reqs.Enqueue(new TRequest(ret, msg)); + + return ret.Get(); + } + + void RunExecutor() { + //SetHighestThreadPriority(); + TContExecutor e(RealStackSize(32000)); + + e.Create<TClient, &TClient::RunDispatcher>(this, "dispatcher"); + e.Execute(); + } + + void RunDispatcher(TCont* c) { + TRequestPtr req; + + while (true) { + Reqs.DequeueSafe(c, req); + + if (!req) { + break; + } + + TChannelPtr& ch = Channels.Get(req->RI->Id); + + if (!ch) { + ch.Reset(new TChannel(c->Executor(), req->RI)); + } + + ch->Enqueue(req); + } + + c->Executor()->Abort(); + } + + TThreadRef Thr; + TOneConsumerPipeQueue<TRequest> Reqs; + TSocketMap<TChannelPtr> Channels; + }; + + struct TMultiClient { + inline TMultiClient() + : Next(0) + { + for (size_t i = 0; i < 2; ++i) { + Clients.push_back(new TClient()); + } + } + + inline THandleRef Schedule(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) { + return Clients[AtomicIncrement(Next) % Clients.size()]->Schedule(msg, fallback, ss); + } + + TVector<TAutoPtr<TClient>> Clients; + TAtomic Next; + }; + +#if 0 + static inline TMultiClient* Client() { + return Singleton<NNehTCP::TMultiClient>(); + } +#else + static inline TClient* Client() { + return Singleton<NNehTCP::TClient>(); + } +#endif + + class TTcpProtocol: public IProtocol { + public: + inline TTcpProtocol() { + InitNetworkSubSystem(); + } + + IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override { + return new TServer(cb, loc.GetPort()); + } + + THandleRef ScheduleRequest(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override { + return Client()->Schedule(msg, fallback, ss); + } + + TStringBuf Scheme() const noexcept override { + return TStringBuf("tcp"); + } + }; + } +} + +IProtocol* NNeh::TcpProtocol() { + return Singleton<NNehTCP::TTcpProtocol>(); +} diff --git a/library/cpp/neh/tcp.h b/library/cpp/neh/tcp.h new file mode 100644 index 0000000000..e0d25d0bac --- /dev/null +++ b/library/cpp/neh/tcp.h @@ -0,0 +1,7 @@ +#pragma once + +namespace NNeh { + class IProtocol; + + IProtocol* TcpProtocol(); +} diff --git a/library/cpp/neh/tcp2.cpp b/library/cpp/neh/tcp2.cpp new file mode 100644 index 0000000000..3dad055af1 --- /dev/null +++ b/library/cpp/neh/tcp2.cpp @@ -0,0 +1,1656 @@ +#include "tcp2.h" + +#include "details.h" +#include "factory.h" +#include "http_common.h" +#include "neh.h" +#include "utils.h" + +#include <library/cpp/dns/cache.h> +#include <library/cpp/neh/asio/executor.h> +#include <library/cpp/threading/atomic/bool.h> + +#include <util/generic/buffer.h> +#include <util/generic/hash.h> +#include <util/generic/singleton.h> +#include <util/network/endpoint.h> +#include <util/network/init.h> +#include <util/network/iovec.h> +#include <util/network/socket.h> +#include <util/string/cast.h> + +#include <atomic> + +//#define DEBUG_TCP2 1 +#ifdef DEBUG_TCP2 +TSpinLock OUT_LOCK; +#define DBGOUT(args) \ + { \ + TGuard<TSpinLock> m(OUT_LOCK); \ + Cout << TInstant::Now().GetValue() << " " << args << Endl; \ + } +#else +#define DBGOUT(args) +#endif + +using namespace std::placeholders; + +namespace NNeh { + TDuration TTcp2Options::ConnectTimeout = TDuration::MilliSeconds(300); + size_t TTcp2Options::InputBufferSize = 16000; + size_t TTcp2Options::AsioClientThreads = 4; + size_t TTcp2Options::AsioServerThreads = 4; + int TTcp2Options::Backlog = 100; + bool TTcp2Options::ClientUseDirectWrite = true; + bool TTcp2Options::ServerUseDirectWrite = true; + TDuration TTcp2Options::ServerInputDeadline = 
TDuration::Seconds(3600); + TDuration TTcp2Options::ServerOutputDeadline = TDuration::Seconds(10); + + bool TTcp2Options::Set(TStringBuf name, TStringBuf value) { +#define TCP2_TRY_SET(optType, optName) \ + if (name == TStringBuf(#optName)) { \ + optName = FromString<optType>(value); \ + } + + TCP2_TRY_SET(TDuration, ConnectTimeout) + else TCP2_TRY_SET(size_t, InputBufferSize) else TCP2_TRY_SET(size_t, AsioClientThreads) else TCP2_TRY_SET(size_t, AsioServerThreads) else TCP2_TRY_SET(int, Backlog) else TCP2_TRY_SET(bool, ClientUseDirectWrite) else TCP2_TRY_SET(bool, ServerUseDirectWrite) else TCP2_TRY_SET(TDuration, ServerInputDeadline) else TCP2_TRY_SET(TDuration, ServerOutputDeadline) else { + return false; + } + return true; + } +} + +namespace { + namespace NNehTcp2 { + using namespace NAsio; + using namespace NDns; + using namespace NNeh; + + const TString canceled = "canceled"; + const TString emptyReply = "empty reply"; + + inline void PrepareSocket(SOCKET s) { + SetNoDelay(s, true); + } + + typedef ui64 TRequestId; + +#pragma pack(push, 1) //disable align struct members (structs mapped to data transmitted other network) + struct TBaseHeader { + enum TMessageType { + Request = 1, + Response = 2, + Cancel = 3, + MaxMessageType + }; + + TBaseHeader(TRequestId id, ui32 headerLength, ui8 version, ui8 mType) + : Id(id) + , HeaderLength(headerLength) + , Version(version) + , Type(mType) + { + } + + TRequestId Id; //message id, - monotonic inc. sequence (skip nil value) + ui32 HeaderLength; + ui8 Version; //current version: 1 + ui8 Type; //<- TMessageType (+ in future possible ForceResponse,etc) + }; + + struct TRequestHeader: public TBaseHeader { + TRequestHeader(TRequestId reqId, size_t servicePathLength, size_t dataSize) + : TBaseHeader(reqId, sizeof(TRequestHeader) + servicePathLength, 1, (ui8)Request) + , ContentLength(dataSize) + { + } + + ui32 ContentLength; + }; + + struct TResponseHeader: public TBaseHeader { + enum TErrorCode { + Success = 0, + EmptyReply = 1 //not found such service or service not sent response + , + MaxErrorCode + }; + + TResponseHeader(TRequestId reqId, TErrorCode code, size_t dataSize) + : TBaseHeader(reqId, sizeof(TResponseHeader), 1, (ui8)Response) + , ErrorCode((ui16)code) + , ContentLength(dataSize) + { + } + + TString ErrorDescription() const { + if (ErrorCode == (ui16)EmptyReply) { + return emptyReply; + } + + TStringStream ss; + ss << TStringBuf("tcp2 err_code=") << ErrorCode; + return ss.Str(); + } + + ui16 ErrorCode; + ui32 ContentLength; + }; + + struct TCancelHeader: public TBaseHeader { + TCancelHeader(TRequestId reqId) + : TBaseHeader(reqId, sizeof(TCancelHeader), 1, (ui8)Cancel) + { + } + }; +#pragma pack(pop) + + static const size_t maxHeaderSize = sizeof(TResponseHeader); + + //buffer for read input data, - header + message data + struct TTcp2Message { + TTcp2Message() + : Loader_(&TTcp2Message::LoadBaseHeader) + , RequireBytesForComplete_(sizeof(TBaseHeader)) + , Header_(sizeof(TBaseHeader)) + { + } + + void Clear() { + Loader_ = &TTcp2Message::LoadBaseHeader; + RequireBytesForComplete_ = sizeof(TBaseHeader); + Header_.Clear(); + Content_.clear(); + } + + TBuffer& Header() noexcept { + return Header_; + } + + const TString& Content() const noexcept { + return Content_; + } + + bool IsComplete() const noexcept { + return RequireBytesForComplete_ == 0; + } + + size_t LoadFrom(const char* buf, size_t len) { + return (this->*Loader_)(buf, len); + } + + const TBaseHeader& BaseHeader() const { + return *reinterpret_cast<const 
TBaseHeader*>(Header_.Data()); + } + + const TRequestHeader& RequestHeader() const { + return *reinterpret_cast<const TRequestHeader*>(Header_.Data()); + } + + const TResponseHeader& ResponseHeader() const { + return *reinterpret_cast<const TResponseHeader*>(Header_.Data()); + } + + private: + size_t LoadBaseHeader(const char* buf, size_t len) { + size_t useBytes = Min<size_t>(sizeof(TBaseHeader) - Header_.Size(), len); + Header_.Append(buf, useBytes); + if (Y_UNLIKELY(sizeof(TBaseHeader) > Header_.Size())) { + //base header yet not complete + return useBytes; + } + { + const TBaseHeader& hdr = BaseHeader(); + if (BaseHeader().HeaderLength > 32000) { //some heuristic header size limit + throw yexception() << TStringBuf("to large neh/tcp2 header size: ") << BaseHeader().HeaderLength; + } + //header completed + Header_.Reserve(hdr.HeaderLength); + } + const TBaseHeader& hdr = BaseHeader(); //reallocation can move Header_ data to another place, so use fresh 'hdr' + if (Y_UNLIKELY(hdr.Version != 1)) { + throw yexception() << TStringBuf("unsupported protocol version: ") << static_cast<unsigned>(hdr.Version); + } + RequireBytesForComplete_ = hdr.HeaderLength - sizeof(TBaseHeader); + return useBytes + LoadHeader(buf + useBytes, len - useBytes); + } + + size_t LoadHeader(const char* buf, size_t len) { + size_t useBytes = Min<size_t>(RequireBytesForComplete_, len); + Header_.Append(buf, useBytes); + RequireBytesForComplete_ -= useBytes; + if (RequireBytesForComplete_) { + //continue load header + Loader_ = &TTcp2Message::LoadHeader; + return useBytes; + } + + const TBaseHeader& hdr = *reinterpret_cast<const TBaseHeader*>(Header_.Data()); + + if (hdr.Type == TBaseHeader::Request) { + if (Header_.Size() < sizeof(TRequestHeader)) { + throw yexception() << TStringBuf("invalid request header size"); + } + InitContentLoading(RequestHeader().ContentLength); + } else if (hdr.Type == TBaseHeader::Response) { + if (Header_.Size() < sizeof(TResponseHeader)) { + throw yexception() << TStringBuf("invalid response header size"); + } + InitContentLoading(ResponseHeader().ContentLength); + } else if (hdr.Type == TBaseHeader::Cancel) { + if (Header_.Size() < sizeof(TCancelHeader)) { + throw yexception() << TStringBuf("invalid cancel header size"); + } + return useBytes; + } else { + throw yexception() << TStringBuf("unsupported request type: ") << static_cast<unsigned>(hdr.Type); + } + return useBytes + (this->*Loader_)(buf + useBytes, len - useBytes); + } + + void InitContentLoading(size_t contentLength) { + RequireBytesForComplete_ = contentLength; + Content_.ReserveAndResize(contentLength); + Loader_ = &TTcp2Message::LoadContent; + } + + size_t LoadContent(const char* buf, size_t len) { + size_t curContentSize = Content_.size() - RequireBytesForComplete_; + size_t useBytes = Min<size_t>(RequireBytesForComplete_, len); + memcpy(Content_.begin() + curContentSize, buf, useBytes); + RequireBytesForComplete_ -= useBytes; + return useBytes; + } + + private: + typedef size_t (TTcp2Message::*TLoader)(const char*, size_t); + + TLoader Loader_; //current loader (stages - base-header/header/content) + size_t RequireBytesForComplete_; + TBuffer Header_; + TString Content_; + }; + + //base storage for output data + class TMultiBuffers { + public: + TMultiBuffers() + : IOVec_(nullptr, 0) + , DataSize_(0) + , PoolBytes_(0) + { + } + + void Clear() noexcept { + Parts_.clear(); + DataSize_ = 0; + PoolBytes_ = 0; + } + + bool HasFreeSpace() const noexcept { + return DataSize_ < 64000 && (PoolBytes_ < (MemPoolSize_ - 
maxHeaderSize)); + } + + bool HasData() const noexcept { + return Parts_.size(); + } + + TContIOVector* GetIOvec() noexcept { + return &IOVec_; + } + + protected: + void AddPart(const void* buf, size_t len) { + Parts_.push_back(IOutputStream::TPart(buf, len)); + DataSize_ += len; + } + + //used for allocate header (MUST be POD type) + template <typename T> + inline T* Allocate() noexcept { + size_t poolBytes = PoolBytes_; + PoolBytes_ += sizeof(T); + return (T*)(MemPool_ + poolBytes); + } + + //used for allocate header (MUST be POD type) + some tail + template <typename T> + inline T* AllocatePlus(size_t tailSize) noexcept { + Y_ASSERT(tailSize <= MemPoolReserve_); + size_t poolBytes = PoolBytes_; + PoolBytes_ += sizeof(T) + tailSize; + return (T*)(MemPool_ + poolBytes); + } + + protected: + TContIOVector IOVec_; + TVector<IOutputStream::TPart> Parts_; + static const size_t MemPoolSize_ = maxHeaderSize * 100; + static const size_t MemPoolReserve_ = 32; + size_t DataSize_; + size_t PoolBytes_; + char MemPool_[MemPoolSize_ + MemPoolReserve_]; + }; + + //protector for limit usage tcp connection output (and used data) only from one thread at same time + class TOutputLock { + public: + TOutputLock() noexcept + : Lock_(0) + { + } + + bool TryAquire() noexcept { + do { + if (AtomicTryLock(&Lock_)) { + return true; + } + } while (!AtomicGet(Lock_)); //without magic loop atomic lock some unreliable + return false; + } + + void Release() noexcept { + AtomicUnlock(&Lock_); + } + + bool IsFree() const noexcept { + return !AtomicGet(Lock_); + } + + private: + TAtomic Lock_; + }; + + class TClient { + class TRequest; + class TConnection; + typedef TIntrusivePtr<TRequest> TRequestRef; + typedef TIntrusivePtr<TConnection> TConnectionRef; + + class TRequest: public TThrRefBase, public TNonCopyable { + public: + class THandle: public TSimpleHandle { + public: + THandle(IOnRecv* f, const TMessage& msg, TStatCollector* s) noexcept + : TSimpleHandle(f, msg, s) + { + } + + bool MessageSendedCompletely() const noexcept override { + if (TSimpleHandle::MessageSendedCompletely()) { + return true; + } + + TRequestRef req = GetRequest(); + if (!!req && req->RequestSendedCompletely()) { + const_cast<THandle*>(this)->SetSendComplete(); + } + + return TSimpleHandle::MessageSendedCompletely(); + } + + void Cancel() noexcept override { + if (TSimpleHandle::Canceled()) { + return; + } + + TRequestRef req = GetRequest(); + if (!!req) { + req->Cancel(); + TSimpleHandle::Cancel(); + } + } + + void NotifyResponse(const TString& resp) { + TNotifyHandle::NotifyResponse(resp); + + ReleaseRequest(); + } + + void NotifyError(const TString& error) { + TNotifyHandle::NotifyError(error); + + ReleaseRequest(); + } + + void NotifyError(TErrorRef error) { + TNotifyHandle::NotifyError(error); + + ReleaseRequest(); + } + + //not thread safe! 
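+ //(SetRequest is called exactly once, from TRequest::Run(), before the request is handed to a connection, so Req_ has no concurrent readers at that point; all later access goes through GetRequest()/ReleaseRequest() under SP_)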
+ void SetRequest(const TRequestRef& r) noexcept { + Req_ = r; + } + + void ReleaseRequest() noexcept { + TRequestRef tmp; + TGuard<TSpinLock> g(SP_); + tmp.Swap(Req_); + } + + private: + TRequestRef GetRequest() const noexcept { + TGuard<TSpinLock> g(SP_); + return Req_; + } + + mutable TSpinLock SP_; + TRequestRef Req_; + }; + + typedef TIntrusivePtr<THandle> THandleRef; + + static void Run(THandleRef& h, const TMessage& msg, TClient& clnt) { + TRequestRef req(new TRequest(h, msg, clnt)); + h->SetRequest(req); + req->Run(req); + } + + ~TRequest() override { + DBGOUT("TClient::~TRequest()"); + } + + private: + TRequest(THandleRef& h, TMessage msg, TClient& clnt) + : Hndl_(h) + , Clnt_(clnt) + , Msg_(std::move(msg)) + , Loc_(Msg_.Addr) + , Addr_(CachedResolve(TResolveInfo(Loc_.Host, Loc_.GetPort()))) + , Canceled_(false) + , Id_(0) + { + DBGOUT("TClient::TRequest()"); + } + + void Run(TRequestRef& req) { + TDestination& dest = Clnt_.Dest_.Get(Addr_->Id); + dest.Run(req); + } + + public: + void OnResponse(TTcp2Message& msg) { + DBGOUT("TRequest::OnResponse: " << msg.ResponseHeader().Id); + THandleRef h = ReleaseHandler(); + if (!h) { + return; + } + + const TResponseHeader& respHdr = msg.ResponseHeader(); + if (Y_LIKELY(!respHdr.ErrorCode)) { + h->NotifyResponse(msg.Content()); + } else { + h->NotifyError(new TError(respHdr.ErrorDescription(), TError::ProtocolSpecific, respHdr.ErrorCode)); + } + ReleaseConn(); + } + + void OnError(const TString& err, const i32 systemCode = 0) { + DBGOUT("TRequest::OnError: " << Id_.load(std::memory_order_acquire)); + THandleRef h = ReleaseHandler(); + if (!h) { + return; + } + + h->NotifyError(new TError(err, TError::UnknownType, 0, systemCode)); + ReleaseConn(); + } + + void SetConnection(TConnection* conn) noexcept { + auto g = Guard(AL_); + Conn_ = conn; + } + + bool Canceled() const noexcept { + return Canceled_; + } + + const TResolvedHost* Addr() const noexcept { + return Addr_; + } + + TStringBuf Service() const noexcept { + return Loc_.Service; + } + + const TString& Data() const noexcept { + return Msg_.Data; + } + + TClient& Client() noexcept { + return Clnt_; + } + + bool RequestSendedCompletely() const noexcept { + if (Id_.load(std::memory_order_acquire) == 0) { + return false; + } + + TConnectionRef conn = GetConn(); + if (!conn) { + return false; + } + + TRequestId lastSendedReqId = conn->LastSendedRequestId(); + if (lastSendedReqId >= Id_.load(std::memory_order_acquire)) { + return true; + } else if (Y_UNLIKELY((Id_.load(std::memory_order_acquire) - lastSendedReqId) > (Max<TRequestId>() - Max<ui32>()))) { + //overflow req-id value + return true; + } + return false; + } + + void Cancel() noexcept { + Canceled_ = true; + THandleRef h = ReleaseHandler(); + if (!h) { + return; + } + + TConnectionRef conn = ReleaseConn(); + if (!!conn && Id_.load(std::memory_order_acquire)) { + conn->Cancel(Id_.load(std::memory_order_acquire)); + } + h->NotifyError(new TError(canceled, TError::Cancelled)); + } + + void SetReqId(TRequestId reqId) noexcept { + auto guard = Guard(IdLock_); + Id_.store(reqId, std::memory_order_release); + } + + TRequestId ReqId() const noexcept { + return Id_.load(std::memory_order_acquire); + } + + private: + inline THandleRef ReleaseHandler() noexcept { + THandleRef h; + { + auto g = Guard(AL_); + h.Swap(Hndl_); + } + return h; + } + + inline TConnectionRef GetConn() const noexcept { + auto g = Guard(AL_); + return Conn_; + } + + inline TConnectionRef ReleaseConn() noexcept { + TConnectionRef c; + { + auto g = Guard(AL_); + 
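//Conn_ is swapped out under AL_ so that OnResponse()/OnError()/Cancel() can each obtain the connection reference at most once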
c.Swap(Conn_); + } + return c; + } + + mutable TAdaptiveLock AL_; //guaranted calling notify() only once (prevent race between asio thread and current) + THandleRef Hndl_; + TClient& Clnt_; + const TMessage Msg_; + const TParsedLocation Loc_; + const TResolvedHost* Addr_; + TConnectionRef Conn_; + NAtomic::TBool Canceled_; + TSpinLock IdLock_; + std::atomic<TRequestId> Id_; + }; + + class TConnection: public TThrRefBase { + enum TState { + Init, + Connecting, + Connected, + Closed, + MaxState + }; + typedef THashMap<TRequestId, TRequestRef> TReqsInFly; + + public: + class TOutputBuffers: public TMultiBuffers { + public: + void AddRequest(const TRequestRef& req) { + Requests_.push_back(req); + if (req->Service().size() > MemPoolReserve_) { + TRequestHeader* hdr = new (Allocate<TRequestHeader>()) TRequestHeader(req->ReqId(), req->Service().size(), req->Data().size()); + AddPart(hdr, sizeof(TRequestHeader)); + AddPart(req->Service().data(), req->Service().size()); + } else { + TRequestHeader* hdr = new (AllocatePlus<TRequestHeader>(req->Service().size())) TRequestHeader(req->ReqId(), req->Service().size(), req->Data().size()); + AddPart(hdr, sizeof(TRequestHeader) + req->Service().size()); + memmove(++hdr, req->Service().data(), req->Service().size()); + } + AddPart(req->Data().data(), req->Data().size()); + IOVec_ = TContIOVector(Parts_.data(), Parts_.size()); + } + + void AddCancelRequest(TRequestId reqId) { + TCancelHeader* hdr = new (Allocate<TCancelHeader>()) TCancelHeader(reqId); + AddPart(hdr, sizeof(TCancelHeader)); + IOVec_ = TContIOVector(Parts_.data(), Parts_.size()); + } + + void Clear() { + TMultiBuffers::Clear(); + Requests_.clear(); + } + + private: + TVector<TRequestRef> Requests_; + }; + + TConnection(TIOService& srv) + : AS_(srv) + , State_(Init) + , BuffSize_(TTcp2Options::InputBufferSize) + , Buff_(new char[BuffSize_]) + , NeedCheckReqsQueue_(0) + , NeedCheckCancelsQueue_(0) + , GenReqId_(0) + , LastSendedReqId_(0) + { + } + + ~TConnection() override { + try { + DBGOUT("TClient::~TConnection()"); + OnError("~"); + } catch (...) { + Cdbg << "tcp2::~cln_conn: " << CurrentExceptionMessage() << Endl; + } + } + + //called from client thread + bool Run(TRequestRef& req) { + if (Y_UNLIKELY(AtomicGet(State_) == Closed)) { + return false; + } + + req->Ref(); + try { + Reqs_.Enqueue(req.Get()); + } catch (...) { + req->UnRef(); + throw; + } + + AtomicSet(NeedCheckReqsQueue_, 1); + req->SetConnection(this); + TAtomicBase state = AtomicGet(State_); + if (Y_LIKELY(state == Connected)) { + ProcessOutputReqsQueue(); + return true; + } + + if (state == Init) { + if (AtomicCas(&State_, Connecting, Init)) { + try { + TEndpoint addr(new NAddr::TAddrInfo(&*req->Addr()->Addr.Begin())); + AS_.AsyncConnect(addr, std::bind(&TConnection::OnConnect, TConnectionRef(this), _1, _2), TTcp2Options::ConnectTimeout); + } catch (...) 
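+ //a failed connect attempt is reported via OnErrorCallback posted to the asio thread, because OnError() and ReqsInFly_ may only be touched from that thread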
{ + AS_.GetIOService().Post(std::bind(&TConnection::OnErrorCallback, TConnectionRef(this), CurrentExceptionMessage())); + } + return true; + } + } + state = AtomicGet(State_); + if (state == Connected) { + ProcessOutputReqsQueue(); + } else if (state == Closed) { + SafeOnError(); + } + return true; + } + + //called from client thread + void Cancel(TRequestId id) { + Cancels_.Enqueue(id); + AtomicSet(NeedCheckCancelsQueue_, 1); + if (Y_LIKELY(AtomicGet(State_) == Connected)) { + ProcessOutputCancelsQueue(); + } + } + + void ProcessOutputReqsQueue() { + if (OutputLock_.TryAquire()) { + SendMessages(false); + } + } + + void ProcessOutputCancelsQueue() { + if (OutputLock_.TryAquire()) { + AS_.GetIOService().Post(std::bind(&TConnection::SendMessages, TConnectionRef(this), true)); + return; + } + } + + //must be called only from asio thread + void ProcessReqsInFlyQueue() { + if (AtomicGet(State_) == Closed) { + return; + } + + TRequest* reqPtr; + + while (ReqsInFlyQueue_.Dequeue(&reqPtr)) { + TRequestRef reqTmp(reqPtr); + reqPtr->UnRef(); + ReqsInFly_[reqPtr->ReqId()].Swap(reqTmp); + } + } + + //must be called only from asio thread + void OnConnect(const TErrorCode& ec, IHandlingContext&) { + DBGOUT("TConnect::OnConnect: " << ec.Value()); + if (Y_UNLIKELY(ec)) { + if (ec.Value() == EIO) { + //try get more detail error info + char buf[1]; + TErrorCode errConnect; + AS_.ReadSome(buf, 1, errConnect); + OnErrorCode(errConnect.Value() ? errConnect : ec); + } else { + OnErrorCode(ec); + } + } else { + try { + PrepareSocket(AS_.Native()); + AtomicSet(State_, Connected); + AS_.AsyncPollRead(std::bind(&TConnection::OnCanRead, TConnectionRef(this), _1, _2)); + if (OutputLock_.TryAquire()) { + SendMessages(true); + return; + } + } catch (...) { + OnError(CurrentExceptionMessage()); + } + } + } + + //must be called only after succes aquiring output + void SendMessages(bool asioThread) { + //DBGOUT("SendMessages"); + if (Y_UNLIKELY(AtomicGet(State_) == Closed)) { + if (asioThread) { + OnError(Error_); + } else { + SafeOnError(); + } + return; + } + + do { + if (asioThread) { + AtomicSet(NeedCheckCancelsQueue_, 0); + TRequestId reqId; + + ProcessReqsInFlyQueue(); + while (Cancels_.Dequeue(&reqId)) { + TReqsInFly::iterator it = ReqsInFly_.find(reqId); + if (it == ReqsInFly_.end()) { + continue; + } + + ReqsInFly_.erase(it); + OutputBuffers_.AddCancelRequest(reqId); + if (Y_UNLIKELY(!OutputBuffers_.HasFreeSpace())) { + if (!FlushOutputBuffers(asioThread, 0)) { + return; + } + } + } + } else if (AtomicGet(NeedCheckCancelsQueue_)) { + AS_.GetIOService().Post(std::bind(&TConnection::SendMessages, TConnectionRef(this), true)); + return; + } + + TRequestId lastReqId = 0; + { + AtomicSet(NeedCheckReqsQueue_, 0); + TRequest* reqPtr; + + while (Reqs_.Dequeue(&reqPtr)) { + TRequestRef reqTmp(reqPtr); + reqPtr->UnRef(); + reqPtr->SetReqId(GenerateReqId()); + if (reqPtr->Canceled()) { + continue; + } + lastReqId = reqPtr->ReqId(); + if (asioThread) { + TRequestRef& req = ReqsInFly_[(TRequestId)reqPtr->ReqId()]; + req.Swap(reqTmp); + OutputBuffers_.AddRequest(req); + } else { //can access to ReqsInFly_ only from asio thread, so enqueue req to update ReqsInFly_ queue + try { + reqTmp->Ref(); + ReqsInFlyQueue_.Enqueue(reqPtr); + } catch (...) 
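+ //the lock-free queues keep raw pointers, so the manual Ref() above is paired with an UnRef() on the asio thread (ProcessReqsInFlyQueue/SafeOnError), or right here if Enqueue throws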
{ + reqTmp->UnRef(); + throw; + } + OutputBuffers_.AddRequest(reqTmp); + } + if (Y_UNLIKELY(!OutputBuffers_.HasFreeSpace())) { + if (!FlushOutputBuffers(asioThread, lastReqId)) { + return; + } + } + } + } + + if (OutputBuffers_.HasData()) { + if (!FlushOutputBuffers(asioThread, lastReqId)) { + return; + } + } + + OutputLock_.Release(); + + if (!AtomicGet(NeedCheckReqsQueue_) && !AtomicGet(NeedCheckCancelsQueue_)) { + DBGOUT("TClient::SendMessages(exit2)"); + return; + } + } while (OutputLock_.TryAquire()); + DBGOUT("TClient::SendMessages(exit1)"); + } + + TRequestId GenerateReqId() noexcept { + TRequestId reqId; + { + auto guard = Guard(GenReqIdLock_); + reqId = ++GenReqId_; + } + return Y_LIKELY(reqId) ? reqId : GenerateReqId(); + } + + //called non thread-safe (from outside thread) + bool FlushOutputBuffers(bool asioThread, TRequestId reqId) { + if (asioThread || TTcp2Options::ClientUseDirectWrite) { + TContIOVector& vec = *OutputBuffers_.GetIOvec(); + TErrorCode err; + vec.Proceed(AS_.WriteSome(vec, err)); + + if (Y_UNLIKELY(err)) { + if (asioThread) { + OnErrorCode(err); + } else { + AS_.GetIOService().Post(std::bind(&TConnection::OnErrorCode, TConnectionRef(this), err)); + } + return false; + } + + if (vec.Complete()) { + LastSendedReqId_.store(reqId, std::memory_order_release); + DBGOUT("Client::FlushOutputBuffers(" << reqId << ")"); + OutputBuffers_.Clear(); + return true; + } + } + + DBGOUT("Client::AsyncWrite(" << reqId << ")"); + AS_.AsyncWrite(OutputBuffers_.GetIOvec(), std::bind(&TConnection::OnSend, TConnectionRef(this), reqId, _1, _2, _3), TTcp2Options::ServerOutputDeadline); + return false; + } + + //must be called only from asio thread + void OnSend(TRequestId reqId, const TErrorCode& ec, size_t amount, IHandlingContext&) { + Y_UNUSED(amount); + if (Y_UNLIKELY(ec)) { + OnErrorCode(ec); + } else { + if (Y_LIKELY(reqId)) { + DBGOUT("Client::OnSend(" << reqId << ")"); + LastSendedReqId_.store(reqId, std::memory_order_release); + } + //output already aquired, used asio thread + OutputBuffers_.Clear(); + SendMessages(true); + } + } + + //must be called only from asio thread + void OnCanRead(const TErrorCode& ec, IHandlingContext& ctx) { + //DBGOUT("OnCanRead(" << ec.Value() << ")"); + if (Y_UNLIKELY(ec)) { + OnErrorCode(ec); + } else { + TErrorCode ec2; + OnReadSome(ec2, AS_.ReadSome(Buff_.Get(), BuffSize_, ec2), ctx); + } + } + + //must be called only from asio thread + void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) { + //DBGOUT("OnReadSome(" << ec.Value() << ", " << amount << ")"); + if (Y_UNLIKELY(ec)) { + OnErrorCode(ec); + + return; + } + + while (1) { + if (Y_UNLIKELY(!amount)) { + OnError("tcp conn. closed"); + + return; + } + + try { + const char* buff = Buff_.Get(); + size_t leftBytes = amount; + do { + size_t useBytes = Msg_.LoadFrom(buff, leftBytes); + leftBytes -= useBytes; + buff += useBytes; + if (Msg_.IsComplete()) { + //DBGOUT("OnReceiveMessage(" << Msg_.BaseHeader().Id << "): " << leftBytes); + OnReceiveMessage(); + Msg_.Clear(); + } + } while (leftBytes); + + if (amount == BuffSize_) { + //try decrease system calls, - re-run ReadSome if has full filled buffer + TErrorCode ecR; + amount = AS_.ReadSome(Buff_.Get(), BuffSize_, ecR); + if (!ecR) { + continue; //process next input data + } + if (ecR.Value() == EAGAIN || ecR.Value() == EWOULDBLOCK) { + ctx.ContinueUseHandler(); + } else { + OnErrorCode(ec); + } + } else { + ctx.ContinueUseHandler(); + } + } catch (...) 
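+ //TTcp2Message::LoadFrom() throws on malformed input (unsupported version or type, oversized header), and any such framing error tears down the whole connection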
{ + OnError(CurrentExceptionMessage()); + } + + return; + } + } + + //must be called only from asio thread + void OnErrorCode(TErrorCode ec) { + OnError(ec.Text(), ec.Value()); + } + + //must be called only from asio thread + void OnErrorCallback(TString err) { + OnError(err); + } + + //must be called only from asio thread + void OnError(const TString& err, const i32 systemCode = 0) { + if (AtomicGet(State_) != Closed) { + Error_ = err; + SystemCode_ = systemCode; + AtomicSet(State_, Closed); + AS_.AsyncCancel(); + } + SafeOnError(); + for (auto& it : ReqsInFly_) { + it.second->OnError(err); + } + ReqsInFly_.clear(); + } + + void SafeOnError() { + TRequest* reqPtr; + + while (Reqs_.Dequeue(&reqPtr)) { + TRequestRef req(reqPtr); + reqPtr->UnRef(); + //DBGOUT("err queue(" << AS_.Native() << "):" << size_t(reqPtr)); + req->OnError(Error_, SystemCode_); + } + + while (ReqsInFlyQueue_.Dequeue(&reqPtr)) { + TRequestRef req(reqPtr); + reqPtr->UnRef(); + //DBGOUT("err fly queue(" << AS_.Native() << "):" << size_t(reqPtr)); + req->OnError(Error_, SystemCode_); + } + } + + //must be called only from asio thread + void OnReceiveMessage() { + //DBGOUT("OnReceiveMessage"); + const TBaseHeader& hdr = Msg_.BaseHeader(); + + if (hdr.Type == TBaseHeader::Response) { + ProcessReqsInFlyQueue(); + TReqsInFly::iterator it = ReqsInFly_.find(hdr.Id); + if (it == ReqsInFly_.end()) { + DBGOUT("ignore response: " << hdr.Id); + return; + } + + it->second->OnResponse(Msg_); + ReqsInFly_.erase(it); + } else { + throw yexception() << TStringBuf("unsupported message type: ") << hdr.Type; + } + } + + TRequestId LastSendedRequestId() const noexcept { + return LastSendedReqId_.load(std::memory_order_acquire); + } + + private: + NAsio::TTcpSocket AS_; + TAtomic State_; //state machine status (TState) + TString Error_; + i32 SystemCode_ = 0; + + //input + size_t BuffSize_; + TArrayHolder<char> Buff_; + TTcp2Message Msg_; + + //output + TOutputLock OutputLock_; + TAtomic NeedCheckReqsQueue_; + TLockFreeQueue<TRequest*> Reqs_; + TAtomic NeedCheckCancelsQueue_; + TLockFreeQueue<TRequestId> Cancels_; + TAdaptiveLock GenReqIdLock_; + std::atomic<TRequestId> GenReqId_; + std::atomic<TRequestId> LastSendedReqId_; + TLockFreeQueue<TRequest*> ReqsInFlyQueue_; + TReqsInFly ReqsInFly_; + TOutputBuffers OutputBuffers_; + }; + + class TDestination { + public: + void Run(TRequestRef& req) { + while (1) { + TConnectionRef conn = GetConnection(); + if (!!conn && conn->Run(req)) { + return; + } + + DBGOUT("TDestination CreateConnection"); + CreateConnection(conn, req->Client().ExecutorsPool().GetExecutor().GetIOService()); + } + } + + private: + TConnectionRef GetConnection() { + TGuard<TSpinLock> g(L_); + return Conn_; + } + + void CreateConnection(TConnectionRef& oldConn, TIOService& srv) { + TConnectionRef conn(new TConnection(srv)); + TGuard<TSpinLock> g(L_); + if (Conn_ == oldConn) { + Conn_.Swap(conn); + } + } + + TSpinLock L_; + TConnectionRef Conn_; + }; + + //////////// TClient ///////// + + public: + TClient() + : EP_(TTcp2Options::AsioClientThreads) + { + } + + ~TClient() { + EP_.SyncShutdown(); + } + + THandleRef Schedule(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) { + //find exist connection or create new + TRequest::THandleRef hndl(new TRequest::THandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss))); + try { + TRequest::Run(hndl, msg, *this); + } catch (...) 
{ + hndl->ResetOnRecv(); + hndl->ReleaseRequest(); + throw; + } + return hndl.Get(); + } + + TExecutorsPool& ExecutorsPool() { + return EP_; + } + + private: + NNeh::NHttp::TLockFreeSequence<TDestination> Dest_; + TExecutorsPool EP_; + }; + + ////////// server side //////////////////////////////////////////////////////////////////////////////////////////// + + class TServer: public IRequester { + typedef TAutoPtr<TTcpAcceptor> TTcpAcceptorPtr; + typedef TAtomicSharedPtr<TTcpSocket> TTcpSocketRef; + class TConnection; + typedef TIntrusivePtr<TConnection> TConnectionRef; + + struct TRequest: public IRequest { + struct TState: public TThrRefBase { + TState() + : Canceled(false) + { + } + + TAtomicBool Canceled; + }; + typedef TIntrusivePtr<TState> TStateRef; + + TRequest(const TConnectionRef& conn, TBuffer& buf, const TString& content); + ~TRequest() override; + + TStringBuf Scheme() const override { + return TStringBuf("tcp2"); + } + + TString RemoteHost() const override; + + TStringBuf Service() const override { + return TStringBuf(Buf.Data() + sizeof(TRequestHeader), Buf.End()); + } + + TStringBuf Data() const override { + return TStringBuf(Content_); + } + + TStringBuf RequestId() const override { + return TStringBuf(); + } + + bool Canceled() const override { + return State->Canceled; + } + + void SendReply(TData& data) override; + + void SendError(TResponseError, const TString&) override { + // TODO + } + + const TRequestHeader& RequestHeader() const noexcept { + return *reinterpret_cast<const TRequestHeader*>(Buf.Data()); + } + + private: + TConnectionRef Conn; + TBuffer Buf; //service-name + message-data + TString Content_; + TAtomic Replied_; + + public: + TIntrusivePtr<TState> State; + }; + + class TConnection: public TThrRefBase { + private: + TConnection(TServer& srv, const TTcpSocketRef& sock) + : Srv_(srv) + , AS_(sock) + , Canceled_(false) + , RemoteHost_(NNeh::PrintHostByRfc(*AS_->RemoteEndpoint().Addr())) + , BuffSize_(TTcp2Options::InputBufferSize) + , Buff_(new char[BuffSize_]) + , NeedCheckOutputQueue_(0) + { + DBGOUT("TServer::TConnection()"); + } + + public: + class TOutputBuffers: public TMultiBuffers { + public: + void AddResponse(TRequestId reqId, TData& data) { + TResponseHeader* hdr = new (Allocate<TResponseHeader>()) TResponseHeader(reqId, TResponseHeader::Success, data.size()); + ResponseData_.push_back(TAutoPtr<TData>(new TData())); + TData& movedData = *ResponseData_.back(); + movedData.swap(data); + AddPart(hdr, sizeof(TResponseHeader)); + AddPart(movedData.data(), movedData.size()); + IOVec_ = TContIOVector(Parts_.data(), Parts_.size()); + } + + void AddError(TRequestId reqId, TResponseHeader::TErrorCode errCode) { + TResponseHeader* hdr = new (Allocate<TResponseHeader>()) TResponseHeader(reqId, errCode, 0); + AddPart(hdr, sizeof(TResponseHeader)); + IOVec_ = TContIOVector(Parts_.data(), Parts_.size()); + } + + void Clear() { + TMultiBuffers::Clear(); + ResponseData_.clear(); + } + + private: + TVector<TAutoPtr<TData>> ResponseData_; + }; + + static void Create(TServer& srv, const TTcpSocketRef& sock) { + TConnectionRef conn(new TConnection(srv, sock)); + conn->AS_->AsyncPollRead(std::bind(&TConnection::OnCanRead, conn, _1, _2), TTcp2Options::ServerInputDeadline); + } + + ~TConnection() override { + DBGOUT("~TServer::TConnection(" << (!AS_ ? 
-666 : AS_->Native()) << ")"); + } + + private: + void OnCanRead(const TErrorCode& ec, IHandlingContext& ctx) { + if (ec) { + OnError(); + } else { + TErrorCode ec2; + OnReadSome(ec2, AS_->ReadSome(Buff_.Get(), BuffSize_, ec2), ctx); + } + } + + void OnError() { + DBGOUT("Srv OnError(" << (!AS_ ? -666 : AS_->Native()) << ")" + << " c=" << (size_t)this); + Canceled_ = true; + AS_->AsyncCancel(); + } + + void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) { + while (1) { + if (ec || !amount) { + OnError(); + return; + } + + try { + const char* buff = Buff_.Get(); + size_t leftBytes = amount; + do { + size_t useBytes = Msg_.LoadFrom(buff, leftBytes); + leftBytes -= useBytes; + buff += useBytes; + if (Msg_.IsComplete()) { + OnReceiveMessage(); + } + } while (leftBytes); + + if (amount == BuffSize_) { + //try decrease system calls, - re-run ReadSome if has full filled buffer + TErrorCode ecR; + amount = AS_->ReadSome(Buff_.Get(), BuffSize_, ecR); + if (!ecR) { + continue; + } + if (ecR.Value() == EAGAIN || ecR.Value() == EWOULDBLOCK) { + ctx.ContinueUseHandler(); + } else { + OnError(); + } + } else { + ctx.ContinueUseHandler(); + } + } catch (...) { + DBGOUT("exc. " << CurrentExceptionMessage()); + OnError(); + } + return; + } + } + + void OnReceiveMessage() { + DBGOUT("OnReceiveMessage()"); + const TBaseHeader& hdr = Msg_.BaseHeader(); + + if (hdr.Type == TBaseHeader::Request) { + TRequest* reqPtr = new TRequest(TConnectionRef(this), Msg_.Header(), Msg_.Content()); + IRequestRef req(reqPtr); + ReqsState_[reqPtr->RequestHeader().Id] = reqPtr->State; + OnRequest(req); + } else if (hdr.Type == TBaseHeader::Cancel) { + OnCancelRequest(hdr.Id); + } else { + throw yexception() << "unsupported message type: " << (ui32)hdr.Type; + } + Msg_.Clear(); + { + TRequestId reqId; + while (FinReqs_.Dequeue(&reqId)) { + ReqsState_.erase(reqId); + } + } + } + + void OnRequest(IRequestRef& r) { + DBGOUT("OnRequest()"); + Srv_.OnRequest(r); + } + + void OnCancelRequest(TRequestId reqId) { + THashMap<TRequestId, TRequest::TStateRef>::iterator it = ReqsState_.find(reqId); + if (it == ReqsState_.end()) { + return; + } + + it->second->Canceled = true; + } + + public: + class TOutputData { + public: + TOutputData(TRequestId reqId) + : ReqId(reqId) + { + } + + virtual ~TOutputData() { + } + + virtual void MoveTo(TOutputBuffers& bufs) = 0; + + TRequestId ReqId; + }; + + class TResponseData: public TOutputData { + public: + TResponseData(TRequestId reqId, TData& data) + : TOutputData(reqId) + { + Data.swap(data); + } + + void MoveTo(TOutputBuffers& bufs) override { + bufs.AddResponse(ReqId, Data); + } + + TData Data; + }; + + class TResponseErrorData: public TOutputData { + public: + TResponseErrorData(TRequestId reqId, TResponseHeader::TErrorCode errorCode) + : TOutputData(reqId) + , ErrorCode(errorCode) + { + } + + void MoveTo(TOutputBuffers& bufs) override { + bufs.AddError(ReqId, ErrorCode); + } + + TResponseHeader::TErrorCode ErrorCode; + }; + + //called non thread-safe (from client thread) + void SendResponse(TRequestId reqId, TData& data) { + DBGOUT("SendResponse: " << reqId << " " << (size_t)~data << " c=" << (size_t)this); + TAutoPtr<TOutputData> od(new TResponseData(reqId, data)); + OutputData_.Enqueue(od); + ProcessOutputQueue(); + } + + //called non thread-safe (from outside thread) + void SendError(TRequestId reqId, TResponseHeader::TErrorCode err) { + DBGOUT("SendResponseError: " << reqId << " c=" << (size_t)this); + TAutoPtr<TOutputData> od(new TResponseErrorData(reqId, err)); + 
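//error replies (including the EmptyReply sent from TRequest's destructor when a request is dropped without an answer) share the OutputData_ queue and the OutputLock_-protected send path with normal responses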
OutputData_.Enqueue(od); + ProcessOutputQueue(); + } + + void ProcessOutputQueue() { + AtomicSet(NeedCheckOutputQueue_, 1); + if (OutputLock_.TryAquire()) { + SendMessages(false); + return; + } + DBGOUT("ProcessOutputQueue: !AquireOutputOwnership: " << (int)OutputLock_.IsFree()); + } + + //must be called only after success aquiring output + void SendMessages(bool asioThread) { + DBGOUT("TServer::SendMessages(enter)"); + try { + do { + AtomicUnlock(&NeedCheckOutputQueue_); + TAutoPtr<TOutputData> d; + while (OutputData_.Dequeue(&d)) { + d->MoveTo(OutputBuffers_); + if (!OutputBuffers_.HasFreeSpace()) { + if (!FlushOutputBuffers(asioThread)) { + return; + } + } + } + + if (OutputBuffers_.HasData()) { + if (!FlushOutputBuffers(asioThread)) { + return; + } + } + + OutputLock_.Release(); + + if (!AtomicGet(NeedCheckOutputQueue_)) { + DBGOUT("Server::SendMessages(exit2): " << (int)OutputLock_.IsFree()); + return; + } + } while (OutputLock_.TryAquire()); + DBGOUT("Server::SendMessages(exit1)"); + } catch (...) { + OnError(); + } + } + + bool FlushOutputBuffers(bool asioThread) { + DBGOUT("FlushOutputBuffers: cnt=" << OutputBuffers_.GetIOvec()->Count() << " c=" << (size_t)this); + //TODO:reseach direct write efficiency + if (asioThread || TTcp2Options::ServerUseDirectWrite) { + TContIOVector& vec = *OutputBuffers_.GetIOvec(); + + vec.Proceed(AS_->WriteSome(vec)); + + if (vec.Complete()) { + OutputBuffers_.Clear(); + //DBGOUT("WriteResponse: " << " c=" << (size_t)this); + return true; + } + } + + //socket buffer filled - use async write for sending left data + DBGOUT("AsyncWriteResponse: " + << " [" << OutputBuffers_.GetIOvec()->Bytes() << "]" + << " c=" << (size_t)this); + AS_->AsyncWrite(OutputBuffers_.GetIOvec(), std::bind(&TConnection::OnSend, TConnectionRef(this), _1, _2, _3), TTcp2Options::ServerOutputDeadline); + return false; + } + + void OnFinishRequest(TRequestId reqId) { + if (Y_LIKELY(!Canceled_)) { + FinReqs_.Enqueue(reqId); + } + } + + private: + void OnSend(const TErrorCode& ec, size_t amount, IHandlingContext&) { + Y_UNUSED(amount); + DBGOUT("TServer::OnSend(" << ec.Value() << ", " << amount << ")"); + if (ec) { + OnError(); + } else { + OutputBuffers_.Clear(); + SendMessages(true); + } + } + + public: + bool IsCanceled() const noexcept { + return Canceled_; + } + + const TString& RemoteHost() const noexcept { + return RemoteHost_; + } + + private: + TServer& Srv_; + TTcpSocketRef AS_; + NAtomic::TBool Canceled_; + TString RemoteHost_; + + //input + size_t BuffSize_; + TArrayHolder<char> Buff_; + TTcp2Message Msg_; + THashMap<TRequestId, TRequest::TStateRef> ReqsState_; + TLockFreeQueue<TRequestId> FinReqs_; + + //output + TOutputLock OutputLock_; //protect socket/buffers from simultaneous access from few threads + TAtomic NeedCheckOutputQueue_; + NNeh::TAutoLockFreeQueue<TOutputData> OutputData_; + TOutputBuffers OutputBuffers_; + }; + + //////////// TServer ///////// + public: + TServer(IOnRequest* cb, ui16 port) + : EP_(TTcp2Options::AsioServerThreads) + , CB_(cb) + { + TNetworkAddress addr(port); + + for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) { + TEndpoint ep(new NAddr::TAddrInfo(&*it)); + TTcpAcceptorPtr a(new TTcpAcceptor(EA_.GetIOService())); + //DBGOUT("bind:" << ep.IpToString() << ":" << ep.Port()); + a->Bind(ep); + a->Listen(TTcp2Options::Backlog); + StartAccept(a.Get()); + A_.push_back(a); + } + } + + ~TServer() override { + EA_.SyncShutdown(); //cancel accepting connections + A_.clear(); //stop listening + EP_.SyncShutdown(); //close all 
exist connections + } + + void StartAccept(TTcpAcceptor* a) { + const auto s = MakeAtomicShared<TTcpSocket>(EP_.Size() ? EP_.GetExecutor().GetIOService() : EA_.GetIOService()); + a->AsyncAccept(*s, std::bind(&TServer::OnAccept, this, a, s, _1, _2)); + } + + void OnAccept(TTcpAcceptor* a, TTcpSocketRef s, const TErrorCode& ec, IHandlingContext&) { + if (Y_UNLIKELY(ec)) { + if (ec.Value() == ECANCELED) { + return; + } else if (ec.Value() == EMFILE || ec.Value() == ENFILE || ec.Value() == ENOMEM || ec.Value() == ENOBUFS) { + //reach some os limit, suspend accepting for preventing busyloop (100% cpu usage) + TSimpleSharedPtr<TDeadlineTimer> dt(new TDeadlineTimer(a->GetIOService())); + dt->AsyncWaitExpireAt(TDuration::Seconds(30), std::bind(&TServer::OnTimeoutSuspendAccept, this, a, dt, _1, _2)); + } else { + Cdbg << "acc: " << ec.Text() << Endl; + } + } else { + SetNonBlock(s->Native()); + PrepareSocket(s->Native()); + TConnection::Create(*this, s); + } + StartAccept(a); //continue accepting + } + + void OnTimeoutSuspendAccept(TTcpAcceptor* a, TSimpleSharedPtr<TDeadlineTimer>, const TErrorCode& ec, IHandlingContext&) { + if (!ec) { + DBGOUT("resume acceptor"); + StartAccept(a); + } + } + + void OnRequest(IRequestRef& r) { + try { + CB_->OnRequest(r); + } catch (...) { + Cdbg << CurrentExceptionMessage() << Endl; + } + } + + private: + TVector<TTcpAcceptorPtr> A_; + TIOServiceExecutor EA_; //thread, where accepted incoming tcp connections + TExecutorsPool EP_; //threads, for process write/read data to/from tcp connections (if empty, use EA_ for r/w) + IOnRequest* CB_; + }; + + TServer::TRequest::TRequest(const TConnectionRef& conn, TBuffer& buf, const TString& content) + : Conn(conn) + , Content_(content) + , Replied_(0) + , State(new TState()) + { + DBGOUT("TServer::TRequest()"); + Buf.Swap(buf); + } + + TServer::TRequest::~TRequest() { + DBGOUT("TServer::~TRequest()"); + if (!AtomicGet(Replied_)) { + Conn->SendError(RequestHeader().Id, TResponseHeader::EmptyReply); + } + Conn->OnFinishRequest(RequestHeader().Id); + } + + TString TServer::TRequest::RemoteHost() const { + return Conn->RemoteHost(); + } + + void TServer::TRequest::SendReply(TData& data) { + do { + if (AtomicCas(&Replied_, 1, 0)) { + Conn->SendResponse(RequestHeader().Id, data); + return; + } + } while (AtomicGet(Replied_) == 0); + } + + class TProtocol: public IProtocol { + public: + inline TProtocol() { + InitNetworkSubSystem(); + } + + IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override { + return new TServer(cb, loc.GetPort()); + } + + THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override { + return Singleton<TClient>()->Schedule(msg, fallback, ss); + } + + TStringBuf Scheme() const noexcept override { + return TStringBuf("tcp2"); + } + + bool SetOption(TStringBuf name, TStringBuf value) override { + return TTcp2Options::Set(name, value); + } + }; + } +} + +NNeh::IProtocol* NNeh::Tcp2Protocol() { + return Singleton<NNehTcp2::TProtocol>(); +} diff --git a/library/cpp/neh/tcp2.h b/library/cpp/neh/tcp2.h new file mode 100644 index 0000000000..bd6d8c25bd --- /dev/null +++ b/library/cpp/neh/tcp2.h @@ -0,0 +1,44 @@ +#pragma once + +#include <util/datetime/base.h> +#include <util/system/defaults.h> + +namespace NNeh { + //global options + struct TTcp2Options { + //connect timeout + static TDuration ConnectTimeout; + + //input buffer size + static size_t InputBufferSize; + + //asio client threads + static size_t AsioClientThreads; + + //asio server threads, - 
if == 0, use acceptor thread for read/parse incoming requests + //esle use one thread for accepting + AsioServerThreads for process established tcp connections + static size_t AsioServerThreads; + + //listen socket queue limit + static int Backlog; + + //try call non block write to socket from client thread (for decrease latency) + static bool ClientUseDirectWrite; + + //try call non block write to socket from client thread (for decrease latency) + static bool ServerUseDirectWrite; + + //expecting receiving request data right after connect or inside receiving request data + static TDuration ServerInputDeadline; + + //timelimit for sending response data + static TDuration ServerOutputDeadline; + + //set option, - return false, if option name not recognized + static bool Set(TStringBuf name, TStringBuf value); + }; + + class IProtocol; + + IProtocol* Tcp2Protocol(); +} diff --git a/library/cpp/neh/udp.cpp b/library/cpp/neh/udp.cpp new file mode 100644 index 0000000000..13250a2493 --- /dev/null +++ b/library/cpp/neh/udp.cpp @@ -0,0 +1,691 @@ +#include "udp.h" +#include "details.h" +#include "neh.h" +#include "location.h" +#include "utils.h" +#include "factory.h" + +#include <library/cpp/dns/cache.h> + +#include <util/network/socket.h> +#include <util/network/address.h> +#include <util/generic/deque.h> +#include <util/generic/hash.h> +#include <util/generic/string.h> +#include <util/generic/buffer.h> +#include <util/generic/singleton.h> +#include <util/digest/murmur.h> +#include <util/random/random.h> +#include <util/ysaveload.h> +#include <util/system/thread.h> +#include <util/system/pipe.h> +#include <util/system/error.h> +#include <util/stream/mem.h> +#include <util/stream/buffer.h> +#include <util/string/cast.h> + +using namespace NNeh; +using namespace NDns; +using namespace NAddr; + +namespace { + namespace NUdp { + enum EPacketType { + PT_REQUEST = 1, + PT_RESPONSE = 2, + PT_STOP = 3, + PT_TIMEOUT = 4 + }; + + struct TUdpHandle: public TNotifyHandle { + inline TUdpHandle(IOnRecv* r, const TMessage& msg, TStatCollector* sc) noexcept + : TNotifyHandle(r, msg, sc) + { + } + + void Cancel() noexcept override { + THandle::Cancel(); //inform stat collector + } + + bool MessageSendedCompletely() const noexcept override { + //TODO + return true; + } + }; + + static inline IRemoteAddrPtr GetSendAddr(SOCKET s) { + IRemoteAddrPtr local = GetSockAddr(s); + const sockaddr* addr = local->Addr(); + + switch (addr->sa_family) { + case AF_INET: { + const TIpAddress a = *(const sockaddr_in*)addr; + + return MakeHolder<TIPv4Addr>(TIpAddress(InetToHost(INADDR_LOOPBACK), a.Port())); + } + + case AF_INET6: { + sockaddr_in6 a = *(const sockaddr_in6*)addr; + + a.sin6_addr = in6addr_loopback; + + return MakeHolder<TIPv6Addr>(a); + } + } + + ythrow yexception() << "unsupported"; + } + + typedef ui32 TCheckSum; + + static inline TString GenerateGuid() { + const ui64 res[2] = { + RandomNumber<ui64>(), RandomNumber<ui64>()}; + + return TString((const char*)res, sizeof(res)); + } + + static inline TCheckSum Sum(const TStringBuf& s) noexcept { + return HostToInet(MurmurHash<TCheckSum>(s.data(), s.size())); + } + + struct TPacket; + + template <class T> + static inline void Serialize(TPacket& p, const T& t); + + struct TPacket { + inline TPacket(IRemoteAddrPtr addr) + : Addr(std::move(addr)) + { + } + + template <class T> + inline TPacket(const T& t, IRemoteAddrPtr addr) + : Addr(std::move(addr)) + { + NUdp::Serialize(*this, t); + } + + inline TPacket(TSocketHolder& s, TBuffer& tmp) { + TAutoPtr<TOpaqueAddr> addr(new 
TOpaqueAddr()); + + retry_on_intr : { + const int rv = recvfrom(s, tmp.Data(), tmp.size(), MSG_WAITALL, addr->MutableAddr(), addr->LenPtr()); + + if (rv < 0) { + int err = LastSystemError(); + if (err == EAGAIN || err == EWOULDBLOCK) { + Data.Resize(sizeof(TCheckSum) + 1); + *(Data.data() + sizeof(TCheckSum)) = static_cast<char>(PT_TIMEOUT); + } else if (err == EINTR) { + goto retry_on_intr; + } else { + ythrow TSystemError() << "recv failed"; + } + } else { + Data.Append(tmp.Data(), (size_t)rv); + Addr.Reset(addr.Release()); + CheckSign(); + } + } + } + + inline void SendTo(TSocketHolder& s) { + Sign(); + + if (sendto(s, Data.data(), Data.size(), 0, Addr->Addr(), Addr->Len()) < 0) { + Cdbg << LastSystemErrorText() << Endl; + } + } + + IRemoteAddrPtr Addr; + TBuffer Data; + + inline void Sign() { + const TCheckSum sum = CalcSign(); + + memcpy(Data.Data(), &sum, sizeof(sum)); + } + + inline char Type() const noexcept { + return *(Data.data() + sizeof(TCheckSum)); + } + + inline void CheckSign() const { + if (Data.size() < 16) { + ythrow yexception() << "small packet"; + } + + if (StoredSign() != CalcSign()) { + ythrow yexception() << "bad checksum"; + } + } + + inline TCheckSum StoredSign() const noexcept { + TCheckSum sum; + + memcpy(&sum, Data.Data(), sizeof(sum)); + + return sum; + } + + inline TCheckSum CalcSign() const noexcept { + return Sum(Body()); + } + + inline TStringBuf Body() const noexcept { + return TStringBuf(Data.data() + sizeof(TCheckSum), Data.End()); + } + }; + + typedef TAutoPtr<TPacket> TPacketRef; + + class TPacketInput: public TMemoryInput { + public: + inline TPacketInput(const TPacket& p) + : TMemoryInput(p.Body().data(), p.Body().size()) + { + } + }; + + class TPacketOutput: public TBufferOutput { + public: + inline TPacketOutput(TPacket& p) + : TBufferOutput(p.Data) + { + p.Data.Proceed(sizeof(TCheckSum)); + } + }; + + template <class T> + static inline void Serialize(TPacketOutput* out, const T& t) { + Save(out, t.Type()); + t.Serialize(out); + } + + template <class T> + static inline void Serialize(TPacket& p, const T& t) { + TPacketOutput out(p); + + NUdp::Serialize(&out, t); + } + + namespace NPrivate { + template <class T> + static inline void Deserialize(TPacketInput* in, T& t) { + char type; + Load(in, type); + + if (type != t.Type()) { + ythrow yexception() << "unsupported packet"; + } + + t.Deserialize(in); + } + + template <class T> + static inline void Deserialize(const TPacket& p, T& t) { + TPacketInput in(p); + + Deserialize(&in, t); + } + } + + struct TRequestPacket { + TString Guid; + TString Service; + TString Data; + + inline TRequestPacket(const TPacket& p) { + NPrivate::Deserialize(p, *this); + } + + inline TRequestPacket(const TString& srv, const TString& data) + : Guid(GenerateGuid()) + , Service(srv) + , Data(data) + { + } + + inline char Type() const noexcept { + return static_cast<char>(PT_REQUEST); + } + + inline void Serialize(TPacketOutput* out) const { + Save(out, Guid); + Save(out, Service); + Save(out, Data); + } + + inline void Deserialize(TPacketInput* in) { + Load(in, Guid); + Load(in, Service); + Load(in, Data); + } + }; + + template <class TStore> + struct TResponsePacket { + TString Guid; + TStore Data; + + inline TResponsePacket(const TString& guid, TStore& data) + : Guid(guid) + { + Data.swap(data); + } + + inline TResponsePacket(const TPacket& p) { + NPrivate::Deserialize(p, *this); + } + + inline char Type() const noexcept { + return static_cast<char>(PT_RESPONSE); + } + + inline void Serialize(TPacketOutput* out) const { + 
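+ //wire format (built by NUdp::Serialize via TPacketOutput above): [4-byte checksum][1-byte packet type][body]; for a response the body is Save(Guid) then Save(Data); the murmur-based checksum over everything after the first 4 bytes is written by TPacket::Sign() just before sendto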
Save(out, Guid); + Save(out, Data); + } + + inline void Deserialize(TPacketInput* in) { + Load(in, Guid); + Load(in, Data); + } + }; + + struct TStopPacket { + inline char Type() const noexcept { + return static_cast<char>(PT_STOP); + } + + inline void Serialize(TPacketOutput* out) const { + Save(out, TString("stop packet")); + } + }; + + struct TBindError: public TSystemError { + }; + + struct TSocketDescr { + inline TSocketDescr(TSocketHolder& s, int family) + : S(s.Release()) + , Family(family) + { + } + + TSocketHolder S; + int Family; + }; + + typedef TAutoPtr<TSocketDescr> TSocketRef; + typedef TVector<TSocketRef> TSockets; + + static inline void CreateSocket(TSocketHolder& s, const IRemoteAddr& addr) { + TSocketHolder res(socket(addr.Addr()->sa_family, SOCK_DGRAM, IPPROTO_UDP)); + + if (!res) { + ythrow TSystemError() << "can not create socket"; + } + + FixIPv6ListenSocket(res); + + if (bind(res, addr.Addr(), addr.Len()) != 0) { + ythrow TBindError() << "can not bind " << PrintHostAndPort(addr); + } + + res.Swap(s); + } + + static inline void CreateSockets(TSockets& s, ui16 port) { + TNetworkAddress addr(port); + + for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) { + TSocketHolder res; + + CreateSocket(res, TAddrInfo(&*it)); + + s.push_back(new TSocketDescr(res, it->ai_family)); + } + } + + static inline void CreateSocketsOnRandomPort(TSockets& s) { + while (true) { + try { + TSockets tmp; + + CreateSockets(tmp, 5000 + (RandomNumber<ui16>() % 1000)); + tmp.swap(s); + + return; + } catch (const TBindError&) { + } + } + } + + typedef ui64 TTimeStamp; + + static inline TTimeStamp TimeStamp() noexcept { + return GetCycleCount() >> 31; + } + + struct TRequestDescr: public TIntrusiveListItem<TRequestDescr> { + inline TRequestDescr(const TString& guid, const TNotifyHandleRef& hndl, const TMessage& msg) + : Guid(guid) + , Hndl(hndl) + , Msg(msg) + , TS(TimeStamp()) + { + } + + TString Guid; + TNotifyHandleRef Hndl; + TMessage Msg; + TTimeStamp TS; + }; + + typedef TAutoPtr<TRequestDescr> TRequestDescrRef; + + class TProto { + class TRequest: public IRequest, public TRequestPacket { + public: + inline TRequest(TPacket& p, TProto* parent) + : TRequestPacket(p) + , Addr_(std::move(p.Addr)) + , H_(PrintHostByRfc(*Addr_)) + , P_(parent) + { + } + + TStringBuf Scheme() const override { + return TStringBuf("udp"); + } + + TString RemoteHost() const override { + return H_; + } + + TStringBuf Service() const override { + return ((TRequestPacket&)(*this)).Service; + } + + TStringBuf Data() const override { + return ((TRequestPacket&)(*this)).Data; + } + + TStringBuf RequestId() const override { + return ((TRequestPacket&)(*this)).Guid; + } + + bool Canceled() const override { + //TODO ? 
+ return false; + } + + void SendReply(TData& data) override { + P_->Schedule(new TPacket(TResponsePacket<TData>(Guid, data), std::move(Addr_))); + } + + void SendError(TResponseError, const TString&) override { + // TODO + } + + private: + IRemoteAddrPtr Addr_; + TString H_; + TProto* P_; + }; + + public: + inline TProto(IOnRequest* cb, TSocketHolder& s) + : CB_(cb) + , ToSendEv_(TSystemEvent::rAuto) + , S_(s.Release()) + { + SetSocketTimeout(S_, 10); + Thrs_.push_back(Spawn<TProto, &TProto::ExecuteRecv>(this)); + Thrs_.push_back(Spawn<TProto, &TProto::ExecuteSend>(this)); + } + + inline ~TProto() { + Schedule(new TPacket(TStopPacket(), GetSendAddr(S_))); + + for (size_t i = 0; i < Thrs_.size(); ++i) { + Thrs_[i]->Join(); + } + } + + inline TPacketRef Recv() { + TBuffer tmp; + + tmp.Resize(128 * 1024); + + while (true) { + try { + return new TPacket(S_, tmp); + } catch (...) { + Cdbg << CurrentExceptionMessage() << Endl; + + continue; + } + } + } + + typedef THashMap<TString, TRequestDescrRef> TInFlyBase; + + struct TInFly: public TInFlyBase, public TIntrusiveList<TRequestDescr> { + typedef TInFlyBase::iterator TIter; + typedef TInFlyBase::const_iterator TContsIter; + + inline void Insert(TRequestDescrRef& d) { + PushBack(d.Get()); + (*this)[d->Guid] = d; + } + + inline void EraseStale() noexcept { + const TTimeStamp now = TimeStamp(); + + for (TIterator it = Begin(); (it != End()) && (it->TS < now) && ((now - it->TS) > 120);) { + it->Hndl->NotifyError("request timeout"); + TString safe_key = (it++)->Guid; + erase(safe_key); + } + } + }; + + inline void ExecuteRecv() { + SetHighestThreadPriority(); + + TInFly infly; + + while (true) { + TPacketRef p = Recv(); + + switch (static_cast<EPacketType>(p->Type())) { + case PT_REQUEST: + if (CB_) { + CB_->OnRequest(new TRequest(*p, this)); + } else { + //skip request in case of client + } + + break; + + case PT_RESPONSE: { + CancelStaleRequests(infly); + + TResponsePacket<TString> rp(*p); + + TInFly::TIter it = static_cast<TInFlyBase&>(infly).find(rp.Guid); + + if (it == static_cast<TInFlyBase&>(infly).end()) { + break; + } + + const TRequestDescrRef& d = it->second; + d->Hndl->NotifyResponse(rp.Data); + + infly.erase(it); + + break; + } + + case PT_STOP: + Schedule(nullptr); + + return; + + case PT_TIMEOUT: + CancelStaleRequests(infly); + + break; + } + } + } + + inline void ExecuteSend() { + SetHighestThreadPriority(); + + while (true) { + TPacketRef p; + + while (!ToSend_.Dequeue(&p)) { + ToSendEv_.Wait(); + } + + //shutdown + if (!p) { + return; + } + + p->SendTo(S_); + } + } + + inline void Schedule(TPacketRef p) { + ToSend_.Enqueue(p); + ToSendEv_.Signal(); + } + + inline void Schedule(TRequestDescrRef dsc, TPacketRef p) { + ScheduledReqs_.Enqueue(dsc); + Schedule(p); + } + + protected: + void CancelStaleRequests(TInFly& infly) { + TRequestDescrRef d; + + while (ScheduledReqs_.Dequeue(&d)) { + infly.Insert(d); + } + + infly.EraseStale(); + } + + IOnRequest* CB_; + NNeh::TAutoLockFreeQueue<TPacket> ToSend_; + NNeh::TAutoLockFreeQueue<TRequestDescr> ScheduledReqs_; + TSystemEvent ToSendEv_; + TSocketHolder S_; + TVector<TThreadRef> Thrs_; + }; + + class TProtos { + public: + inline TProtos() { + TSockets s; + + CreateSocketsOnRandomPort(s); + Init(nullptr, s); + } + + inline TProtos(IOnRequest* cb, ui16 port) { + TSockets s; + + CreateSockets(s, port); + Init(cb, s); + } + + static inline TProtos* Instance() { + return Singleton<TProtos>(); + } + + inline void Schedule(const TMessage& msg, const TNotifyHandleRef& hndl) { + TParsedLocation 
loc(msg.Addr); + const TNetworkAddress* addr = &CachedThrResolve(TResolveInfo(loc.Host, loc.GetPort()))->Addr; + + for (TNetworkAddress::TIterator ai = addr->Begin(); ai != addr->End(); ++ai) { + TProto* proto = Find(ai->ai_family); + + if (proto) { + TRequestPacket rp(ToString(loc.Service), msg.Data); + TRequestDescrRef rd(new TRequestDescr(rp.Guid, hndl, msg)); + IRemoteAddrPtr raddr(new TAddrInfo(&*ai)); + TPacketRef p(new TPacket(rp, std::move(raddr))); + + proto->Schedule(rd, p); + + return; + } + } + + ythrow yexception() << "unsupported protocol family"; + } + + private: + inline void Init(IOnRequest* cb, TSockets& s) { + for (auto& it : s) { + P_[it->Family] = new TProto(cb, it->S); + } + } + + inline TProto* Find(int family) const { + TProtoStorage::const_iterator it = P_.find(family); + + if (it == P_.end()) { + return nullptr; + } + + return it->second.Get(); + } + + private: + typedef TAutoPtr<TProto> TProtoRef; + typedef THashMap<int, TProtoRef> TProtoStorage; + TProtoStorage P_; + }; + + class TRequester: public IRequester, public TProtos { + public: + inline TRequester(IOnRequest* cb, ui16 port) + : TProtos(cb, port) + { + } + }; + + class TProtocol: public IProtocol { + public: + IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override { + return new TRequester(cb, loc.GetPort()); + } + + THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override { + TNotifyHandleRef ret(new TUdpHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss))); + + TProtos::Instance()->Schedule(msg, ret); + + return ret.Get(); + } + + TStringBuf Scheme() const noexcept override { + return TStringBuf("udp"); + } + }; + } +} + +IProtocol* NNeh::UdpProtocol() { + return Singleton<NUdp::TProtocol>(); +} diff --git a/library/cpp/neh/udp.h b/library/cpp/neh/udp.h new file mode 100644 index 0000000000..491df042f2 --- /dev/null +++ b/library/cpp/neh/udp.h @@ -0,0 +1,7 @@ +#pragma once + +namespace NNeh { + class IProtocol; + + IProtocol* UdpProtocol(); +} diff --git a/library/cpp/neh/utils.cpp b/library/cpp/neh/utils.cpp new file mode 100644 index 0000000000..2f8671c581 --- /dev/null +++ b/library/cpp/neh/utils.cpp @@ -0,0 +1,47 @@ +#include "utils.h" + +#include <util/generic/utility.h> +#include <util/stream/output.h> +#include <util/stream/str.h> +#include <util/system/error.h> + +#if defined(_unix_) +#include <pthread.h> +#endif + +#if defined(_win_) +#include <windows.h> +#endif + +using namespace NNeh; + +size_t NNeh::RealStackSize(size_t len) noexcept { +#if defined(NDEBUG) && !defined(_san_enabled_) + return len; +#else + return Max<size_t>(len, 64000); +#endif +} + +TString NNeh::PrintHostByRfc(const NAddr::IRemoteAddr& addr) { + TStringStream ss; + + if (addr.Addr()->sa_family == AF_INET) { + NAddr::PrintHost(ss, addr); + } else if (addr.Addr()->sa_family == AF_INET6) { + ss << '['; + NAddr::PrintHost(ss, addr); + ss << ']'; + } + return ss.Str(); +} + +NAddr::IRemoteAddrPtr NNeh::GetPeerAddr(SOCKET s) { + TAutoPtr<NAddr::TOpaqueAddr> addr(new NAddr::TOpaqueAddr()); + + if (getpeername(s, addr->MutableAddr(), addr->LenPtr()) < 0) { + ythrow TSystemError() << "getpeername() failed"; + } + + return NAddr::IRemoteAddrPtr(addr.Release()); +} diff --git a/library/cpp/neh/utils.h b/library/cpp/neh/utils.h new file mode 100644 index 0000000000..ff1f63c2df --- /dev/null +++ b/library/cpp/neh/utils.h @@ -0,0 +1,39 @@ +#pragma once + +#include <library/cpp/threading/atomic/bool.h> + +#include <util/network/address.h> +#include 
<util/system/thread.h> +#include <util/generic/cast.h> +#include <library/cpp/deprecated/atomic/atomic.h> + +namespace NNeh { + typedef TAutoPtr<TThread> TThreadRef; + + template <typename T, void (T::*M)()> + static void* HelperMemberFunc(void* arg) { + T* obj = reinterpret_cast<T*>(arg); + (obj->*M)(); + return nullptr; + } + + template <typename T, void (T::*M)()> + static TThreadRef Spawn(T* t) { + TThreadRef thr(new TThread(HelperMemberFunc<T, M>, t)); + + thr->Start(); + + return thr; + } + + size_t RealStackSize(size_t len) noexcept; + + //from rfc3986: + //host = IP-literal / IPv4address / reg-name + //IP-literal = "[" ( IPv6address / IPvFuture ) "]" + TString PrintHostByRfc(const NAddr::IRemoteAddr& addr); + + NAddr::IRemoteAddrPtr GetPeerAddr(SOCKET s); + + using TAtomicBool = NAtomic::TBool; +} diff --git a/library/cpp/neh/wfmo.h b/library/cpp/neh/wfmo.h new file mode 100644 index 0000000000..11f32dda22 --- /dev/null +++ b/library/cpp/neh/wfmo.h @@ -0,0 +1,140 @@ +#pragma once + +#include "lfqueue.h" + +#include <library/cpp/threading/atomic/bool.h> + +#include <util/generic/vector.h> +#include <library/cpp/deprecated/atomic/atomic.h> +#include <library/cpp/deprecated/atomic/atomic_ops.h> +#include <util/system/event.h> +#include <util/system/spinlock.h> + +namespace NNeh { + template <class T> + class TBlockedQueue: public TLockFreeQueue<T>, public TSystemEvent { + public: + inline TBlockedQueue() noexcept + : TSystemEvent(TSystemEvent::rAuto) + { + } + + inline void Notify(T t) noexcept { + this->Enqueue(t); + Signal(); + } + }; + + class TWaitQueue { + public: + struct TWaitHandle { + inline TWaitHandle() noexcept + : Signalled(false) + , Parent(nullptr) + { + } + + inline void Signal() noexcept { + TGuard<TSpinLock> lock(M_); + + Signalled = true; + + if (Parent) { + Parent->Notify(this); + } + } + + inline void Register(TWaitQueue* parent) noexcept { + TGuard<TSpinLock> lock(M_); + + Parent = parent; + + if (Signalled) { + if (Parent) { + Parent->Notify(this); + } + } + } + + NAtomic::TBool Signalled; + TWaitQueue* Parent; + TSpinLock M_; + }; + + inline ~TWaitQueue() { + for (size_t i = 0; i < H_.size(); ++i) { + H_[i]->Register(nullptr); + } + } + + inline void Register(TWaitHandle& ev) { + H_.push_back(&ev); + ev.Register(this); + } + + template <class T> + inline void Register(const T& ev) { + Register(static_cast<TWaitHandle&>(*ev)); + } + + inline bool Wait(const TInstant& deadLine) noexcept { + return Q_.WaitD(deadLine); + } + + inline void Notify(TWaitHandle* wq) noexcept { + Q_.Notify(wq); + } + + inline bool Dequeue(TWaitHandle** wq) noexcept { + return Q_.Dequeue(wq); + } + + private: + TBlockedQueue<TWaitHandle*> Q_; + TVector<TWaitHandle*> H_; + }; + + typedef TWaitQueue::TWaitHandle TWaitHandle; + + template <class It, class T> + static inline void WaitForMultipleObj(It b, It e, const TInstant& deadLine, T& func) { + TWaitQueue hndl; + + while (b != e) { + hndl.Register(*b++); + } + + do { + TWaitHandle* ret = nullptr; + + if (hndl.Dequeue(&ret)) { + do { + func(ret); + } while (hndl.Dequeue(&ret)); + + return; + } + } while (hndl.Wait(deadLine)); + } + + struct TSignalled { + inline TSignalled() + : Signalled(false) + { + } + + inline void operator()(const TWaitHandle*) noexcept { + Signalled = true; + } + + bool Signalled; + }; + + static inline bool WaitForOne(TWaitHandle& wh, const TInstant& deadLine) { + TSignalled func; + + WaitForMultipleObj(&wh, &wh + 1, deadLine, func); + + return func.Signalled; + } +} |
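
The TPacket class in udp.cpp frames every datagram as a checksum header, a one-byte packet type, and the payload; Sign() stamps the checksum of everything after the header into the first bytes, and CheckSign() recomputes it on receive and rejects undersized or corrupted packets. Below is a minimal standalone sketch of that framing only, assuming a 32-bit additive checksum in place of the library's own TCheckSum/Sum() (those are defined in another file of this diff), with sketch-specific names so it is not mistaken for the real type:

#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>

// Assumption: a 32-bit rolling checksum; the real TCheckSum/Sum() may differ.
using TCheckSumSketch = uint32_t;

static TCheckSumSketch CalcSum(const char* data, std::size_t len) {
    TCheckSumSketch s = 0;
    for (std::size_t i = 0; i < len; ++i) {
        s = s * 31 + static_cast<unsigned char>(data[i]);
    }
    return s;
}

struct TPacketSketch {
    // Wire layout mirrors TPacket: [checksum][type byte][payload bytes...]
    std::vector<char> Data;

    void Build(char type, const std::string& payload) {
        Data.assign(sizeof(TCheckSumSketch), 0); // reserve space for the checksum
        Data.push_back(type);
        Data.insert(Data.end(), payload.begin(), payload.end());
        Sign();
    }

    void Sign() {
        // Checksum covers everything after the header, exactly like Body() in TPacket.
        const TCheckSumSketch sum =
            CalcSum(Data.data() + sizeof(TCheckSumSketch), Data.size() - sizeof(TCheckSumSketch));
        std::memcpy(Data.data(), &sum, sizeof(sum));
    }

    void CheckSign() const {
        if (Data.size() < sizeof(TCheckSumSketch) + 1) {
            throw std::runtime_error("small packet");
        }
        TCheckSumSketch stored;
        std::memcpy(&stored, Data.data(), sizeof(stored));
        const TCheckSumSketch calc =
            CalcSum(Data.data() + sizeof(TCheckSumSketch), Data.size() - sizeof(TCheckSumSketch));
        if (stored != calc) {
            throw std::runtime_error("bad checksum");
        }
    }
};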
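
CreateSocketsOnRandomPort() keeps picking a port in the 5000-5999 range and retrying until bind() stops failing with TBindError. A reduced, POSIX-only sketch of that retry idea, using raw sockets and a plain error return instead of TSocketHolder/TBindError, and binding a single AF_INET socket instead of walking every address family:

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cstdint>
#include <cstdlib>
#include <ctime>

// Returns a bound UDP socket descriptor and stores the chosen port, or -1 if
// socket() itself fails. bind() failures (port already taken) are retried.
static int BindUdpOnRandomPort(uint16_t* boundPort) {
    std::srand(static_cast<unsigned>(std::time(nullptr))); // sketch-level seeding

    for (;;) {
        const uint16_t port = static_cast<uint16_t>(5000 + std::rand() % 1000);

        const int fd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
        if (fd < 0) {
            return -1; // nothing to retry if we cannot even create a socket
        }

        sockaddr_in addr{};
        addr.sin_family = AF_INET;
        addr.sin_addr.s_addr = htonl(INADDR_ANY);
        addr.sin_port = htons(port);

        if (bind(fd, reinterpret_cast<const sockaddr*>(&addr), sizeof(addr)) == 0) {
            *boundPort = port;
            return fd; // success: caller owns the descriptor
        }

        close(fd); // port taken (or other bind error): pick another port and retry
    }
}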
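
ExecuteSend() is a single-consumer send loop: producers enqueue packets via Schedule() and signal an event, the send thread drains the queue, and a null packet is the shutdown sentinel. A portable sketch of that pattern, with std::mutex/std::condition_variable standing in for TAutoLockFreeQueue and TSystemEvent and an illustrative TFakePacket type:

#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>

struct TFakePacket {
    void SendTo(int /*socket*/) {
        // the real packet calls sendto() on its stored address here
    }
};

class TSendLoopSketch {
public:
    // Producer side: enqueue and wake the sender (plays the role of ToSendEv_.Signal()).
    void Schedule(std::unique_ptr<TFakePacket> p) {
        {
            std::lock_guard<std::mutex> g(M_);
            Q_.push(std::move(p));
        }
        CV_.notify_one();
    }

    // Consumer side: the dedicated send thread runs this until it sees the sentinel.
    void Run(int socket) {
        for (;;) {
            std::unique_ptr<TFakePacket> p;
            {
                std::unique_lock<std::mutex> g(M_);
                CV_.wait(g, [this] { return !Q_.empty(); });
                p = std::move(Q_.front());
                Q_.pop();
            }
            if (!p) {
                return; // null packet == shutdown request, as in ExecuteSend()
            }
            p->SendTo(socket);
        }
    }

private:
    std::mutex M_;
    std::condition_variable CV_;
    std::queue<std::unique_ptr<TFakePacket>> Q_;
};

Using a null entry as the shutdown sentinel keeps the loop free of a separate stop flag; it is the same trick ExecuteRecv() relies on when it forwards Schedule(nullptr) after receiving PT_STOP.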
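
PrintHostByRfc() in utils.cpp follows RFC 3986: an IPv6 literal is wrapped in square brackets so that a trailing ":port" stays unambiguous, while IPv4 addresses are printed bare. A tiny sketch of the same rule over a raw sockaddr, assuming the numeric rendering of the address has already been produced elsewhere:

#include <netinet/in.h>
#include <sys/socket.h>

#include <string>

// `host` is assumed to be the already-rendered numeric address of `sa`.
static std::string FormatHostForUrl(const sockaddr* sa, const std::string& host) {
    if (sa->sa_family == AF_INET6) {
        return "[" + host + "]"; // e.g. "[::1]", so "[::1]:80" parses correctly
    }
    return host; // AF_INET and anything else: no brackets
}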
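
wfmo.h implements a small wait-for-multiple-objects primitive: each TWaitHandle remembers the queue it is registered with and, once signalled, pushes itself onto that queue and wakes the waiter; WaitForMultipleObj() then just drains the queue under a deadline. A portable sketch of that registration/notification handshake, with std:: primitives in place of TLockFreeQueue, TSystemEvent and TSpinLock and sketch-only names:

#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>

struct TWaitQueueSketch;

struct TWaitHandleSketch {
    void Signal();
    void Register(TWaitQueueSketch* parent);

    bool Signalled = false;
    TWaitQueueSketch* Parent = nullptr;
    std::mutex M;
};

struct TWaitQueueSketch {
    void Notify(TWaitHandleSketch* h) {
        std::lock_guard<std::mutex> g(M);
        Ready.push_back(h);
        CV.notify_one();
    }

    // Returns the next signalled handle, or nullptr if the deadline expires first.
    TWaitHandleSketch* WaitOne(std::chrono::steady_clock::time_point deadline) {
        std::unique_lock<std::mutex> g(M);
        if (!CV.wait_until(g, deadline, [this] { return !Ready.empty(); })) {
            return nullptr;
        }
        TWaitHandleSketch* h = Ready.front();
        Ready.pop_front();
        return h;
    }

    std::deque<TWaitHandleSketch*> Ready;
    std::mutex M;
    std::condition_variable CV;
};

void TWaitHandleSketch::Signal() {
    std::lock_guard<std::mutex> g(M);
    Signalled = true;
    if (Parent) {
        Parent->Notify(this); // forward the event to whoever is waiting
    }
}

void TWaitHandleSketch::Register(TWaitQueueSketch* parent) {
    std::lock_guard<std::mutex> g(M);
    Parent = parent;
    if (Signalled && Parent) {
        Parent->Notify(this); // re-deliver a signal that arrived before registration
    }
}

The re-notification in Register() is what keeps the real implementation from losing wakeups that race with registration: a handle signalled before it joins the queue is still reported to the waiter.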