author    monster <monster@ydb.tech> 2022-07-07 14:41:37 +0300
committer monster <monster@ydb.tech> 2022-07-07 14:41:37 +0300
commit    06e5c21a835c0e923506c4ff27929f34e00761c2 (patch)
tree      75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/neh
parent    03f024c4412e3aa613bb543cf1660176320ba8f4 (diff)
download  ydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz
fix ya.make
Diffstat (limited to 'library/cpp/neh')
-rw-r--r-- library/cpp/neh/README.md 15
-rw-r--r-- library/cpp/neh/asio/asio.cpp 187
-rw-r--r-- library/cpp/neh/asio/asio.h 280
-rw-r--r-- library/cpp/neh/asio/deadline_timer_impl.cpp 1
-rw-r--r-- library/cpp/neh/asio/deadline_timer_impl.h 110
-rw-r--r-- library/cpp/neh/asio/executor.cpp 1
-rw-r--r-- library/cpp/neh/asio/executor.h 76
-rw-r--r-- library/cpp/neh/asio/io_service_impl.cpp 161
-rw-r--r-- library/cpp/neh/asio/io_service_impl.h 744
-rw-r--r-- library/cpp/neh/asio/poll_interrupter.cpp 1
-rw-r--r-- library/cpp/neh/asio/poll_interrupter.h 107
-rw-r--r-- library/cpp/neh/asio/tcp_acceptor_impl.cpp 25
-rw-r--r-- library/cpp/neh/asio/tcp_acceptor_impl.h 76
-rw-r--r-- library/cpp/neh/asio/tcp_socket_impl.cpp 117
-rw-r--r-- library/cpp/neh/asio/tcp_socket_impl.h 332
-rw-r--r-- library/cpp/neh/conn_cache.cpp 1
-rw-r--r-- library/cpp/neh/conn_cache.h 149
-rw-r--r-- library/cpp/neh/details.h 99
-rw-r--r-- library/cpp/neh/factory.cpp 67
-rw-r--r-- library/cpp/neh/factory.h 37
-rw-r--r-- library/cpp/neh/http2.cpp 2102
-rw-r--r-- library/cpp/neh/http2.h 119
-rw-r--r-- library/cpp/neh/http_common.cpp 235
-rw-r--r-- library/cpp/neh/http_common.h 305
-rw-r--r-- library/cpp/neh/http_headers.cpp 1
-rw-r--r-- library/cpp/neh/http_headers.h 55
-rw-r--r-- library/cpp/neh/https.cpp 1936
-rw-r--r-- library/cpp/neh/https.h 47
-rw-r--r-- library/cpp/neh/inproc.cpp 212
-rw-r--r-- library/cpp/neh/inproc.h 7
-rw-r--r-- library/cpp/neh/jobqueue.cpp 79
-rw-r--r-- library/cpp/neh/jobqueue.h 41
-rw-r--r-- library/cpp/neh/lfqueue.h 53
-rw-r--r-- library/cpp/neh/location.cpp 50
-rw-r--r-- library/cpp/neh/location.h 18
-rw-r--r-- library/cpp/neh/multi.cpp 35
-rw-r--r-- library/cpp/neh/multi.h 12
-rw-r--r-- library/cpp/neh/multiclient.cpp 378
-rw-r--r-- library/cpp/neh/multiclient.h 72
-rw-r--r-- library/cpp/neh/neh.cpp 146
-rw-r--r-- library/cpp/neh/neh.h 320
-rw-r--r-- library/cpp/neh/netliba.cpp 508
-rw-r--r-- library/cpp/neh/netliba.h 20
-rw-r--r-- library/cpp/neh/netliba_udp_http.cpp 808
-rw-r--r-- library/cpp/neh/netliba_udp_http.h 79
-rw-r--r-- library/cpp/neh/pipequeue.cpp 1
-rw-r--r-- library/cpp/neh/pipequeue.h 207
-rw-r--r-- library/cpp/neh/rpc.cpp 322
-rw-r--r-- library/cpp/neh/rpc.h 155
-rw-r--r-- library/cpp/neh/rq.cpp 312
-rw-r--r-- library/cpp/neh/rq.h 18
-rw-r--r-- library/cpp/neh/smart_ptr.cpp 1
-rw-r--r-- library/cpp/neh/smart_ptr.h 332
-rw-r--r-- library/cpp/neh/stat.cpp 114
-rw-r--r-- library/cpp/neh/stat.h 96
-rw-r--r-- library/cpp/neh/tcp.cpp 676
-rw-r--r-- library/cpp/neh/tcp.h 7
-rw-r--r-- library/cpp/neh/tcp2.cpp 1656
-rw-r--r-- library/cpp/neh/tcp2.h 44
-rw-r--r-- library/cpp/neh/udp.cpp 691
-rw-r--r-- library/cpp/neh/udp.h 7
-rw-r--r-- library/cpp/neh/utils.cpp 47
-rw-r--r-- library/cpp/neh/utils.h 39
-rw-r--r-- library/cpp/neh/wfmo.h 140
64 files changed, 15089 insertions, 0 deletions
diff --git a/library/cpp/neh/README.md b/library/cpp/neh/README.md
new file mode 100644
index 0000000000..962b3f67bc
--- /dev/null
+++ b/library/cpp/neh/README.md
@@ -0,0 +1,15 @@
+The neh transport library
+=========================
+
+Provides a simple interface for making requests over the network (request/response - client/server) and makes it easy to switch the transport protocol.
+There are several transport implementations, each with its own pros and cons.
+
+Documentation
+=============
+https://wiki.yandex-team.ru/development/poisk/arcadia/library/neh/
+
+FAQ
+===
+Q: Let's add SSL (https support)!
+A: ~~Not going to happen. neh is a low-level bus; SSL does not belong there. Details: https://clubs.at.yandex-team.ru/stackoverflow/5634~~
+A: Done
diff --git a/library/cpp/neh/asio/asio.cpp b/library/cpp/neh/asio/asio.cpp
new file mode 100644
index 0000000000..8b6cf383ea
--- /dev/null
+++ b/library/cpp/neh/asio/asio.cpp
@@ -0,0 +1,187 @@
+#include "io_service_impl.h"
+#include "deadline_timer_impl.h"
+#include "tcp_socket_impl.h"
+#include "tcp_acceptor_impl.h"
+
+using namespace NDns;
+using namespace NAsio;
+
+namespace NAsio {
+ TIOService::TWork::TWork(TWork& w)
+ : Srv_(w.Srv_)
+ {
+ Srv_.GetImpl().WorkStarted();
+ }
+
+ TIOService::TWork::TWork(TIOService& srv)
+ : Srv_(srv)
+ {
+ Srv_.GetImpl().WorkStarted();
+ }
+
+ TIOService::TWork::~TWork() {
+ Srv_.GetImpl().WorkFinished();
+ }
+
+ TIOService::TIOService()
+ : Impl_(new TImpl())
+ {
+ }
+
+ TIOService::~TIOService() {
+ }
+
+ void TIOService::Run() {
+ Impl_->Run();
+ }
+
+ void TIOService::Post(TCompletionHandler h) {
+ Impl_->Post(std::move(h));
+ }
+
+ void TIOService::Abort() {
+ Impl_->Abort();
+ }
+
+ TDeadlineTimer::TDeadlineTimer(TIOService& srv) noexcept
+ : Srv_(srv)
+ , Impl_(nullptr)
+ {
+ }
+
+ TDeadlineTimer::~TDeadlineTimer() {
+ if (Impl_) {
+ Srv_.GetImpl().ScheduleOp(new TUnregisterTimerOperation(Impl_));
+ }
+ }
+
+ void TDeadlineTimer::AsyncWaitExpireAt(TDeadline deadline, THandler h) {
+ if (!Impl_) {
+ Impl_ = new TDeadlineTimer::TImpl(Srv_.GetImpl());
+ Srv_.GetImpl().ScheduleOp(new TRegisterTimerOperation(Impl_));
+ }
+ Impl_->AsyncWaitExpireAt(deadline, h);
+ }
+
+ void TDeadlineTimer::Cancel() {
+ Impl_->Cancel();
+ }
+
+ TTcpSocket::TTcpSocket(TIOService& srv) noexcept
+ : Srv_(srv)
+ , Impl_(new TImpl(srv.GetImpl()))
+ {
+ }
+
+ TTcpSocket::~TTcpSocket() {
+ }
+
+ void TTcpSocket::AsyncConnect(const TEndpoint& ep, TTcpSocket::TConnectHandler h, TDeadline deadline) {
+ Impl_->AsyncConnect(ep, h, deadline);
+ }
+
+ void TTcpSocket::AsyncWrite(TSendedData& d, TTcpSocket::TWriteHandler h, TDeadline deadline) {
+ Impl_->AsyncWrite(d, h, deadline);
+ }
+
+ void TTcpSocket::AsyncWrite(TContIOVector* vec, TWriteHandler h, TDeadline deadline) {
+ Impl_->AsyncWrite(vec, h, deadline);
+ }
+
+ void TTcpSocket::AsyncWrite(const void* data, size_t size, TWriteHandler h, TDeadline deadline) {
+ class TBuffers: public IBuffers {
+ public:
+ TBuffers(const void* theData, size_t theSize)
+ : Part(theData, theSize)
+ , IOVec(&Part, 1)
+ {
+ }
+
+ TContIOVector* GetIOvec() override {
+ return &IOVec;
+ }
+
+ IOutputStream::TPart Part;
+ TContIOVector IOVec;
+ };
+
+ TSendedData d(new TBuffers(data, size));
+ Impl_->AsyncWrite(d, h, deadline);
+ }
+
+ void TTcpSocket::AsyncRead(void* buff, size_t size, TTcpSocket::TReadHandler h, TDeadline deadline) {
+ Impl_->AsyncRead(buff, size, h, deadline);
+ }
+
+ void TTcpSocket::AsyncReadSome(void* buff, size_t size, TTcpSocket::TReadHandler h, TDeadline deadline) {
+ Impl_->AsyncReadSome(buff, size, h, deadline);
+ }
+
+ void TTcpSocket::AsyncPollRead(TTcpSocket::TPollHandler h, TDeadline deadline) {
+ Impl_->AsyncPollRead(h, deadline);
+ }
+
+ void TTcpSocket::AsyncPollWrite(TTcpSocket::TPollHandler h, TDeadline deadline) {
+ Impl_->AsyncPollWrite(h, deadline);
+ }
+
+ void TTcpSocket::AsyncCancel() {
+ return Impl_->AsyncCancel();
+ }
+
+ size_t TTcpSocket::WriteSome(TContIOVector& d, TErrorCode& ec) noexcept {
+ return Impl_->WriteSome(d, ec);
+ }
+
+ size_t TTcpSocket::WriteSome(const void* buff, size_t size, TErrorCode& ec) noexcept {
+ return Impl_->WriteSome(buff, size, ec);
+ }
+
+ size_t TTcpSocket::ReadSome(void* buff, size_t size, TErrorCode& ec) noexcept {
+ return Impl_->ReadSome(buff, size, ec);
+ }
+
+ bool TTcpSocket::IsOpen() const noexcept {
+ return Native() != INVALID_SOCKET;
+ }
+
+ void TTcpSocket::Shutdown(TShutdownMode what, TErrorCode& ec) {
+ return Impl_->Shutdown(what, ec);
+ }
+
+ SOCKET TTcpSocket::Native() const noexcept {
+ return Impl_->Fd();
+ }
+
+ TEndpoint TTcpSocket::RemoteEndpoint() const {
+ return Impl_->RemoteEndpoint();
+ }
+
+ //////////////////////////////////
+
+ TTcpAcceptor::TTcpAcceptor(TIOService& srv) noexcept
+ : Srv_(srv)
+ , Impl_(new TImpl(srv.GetImpl()))
+ {
+ }
+
+ TTcpAcceptor::~TTcpAcceptor() {
+ }
+
+ void TTcpAcceptor::Bind(TEndpoint& ep, TErrorCode& ec) noexcept {
+ return Impl_->Bind(ep, ec);
+ }
+
+ void TTcpAcceptor::Listen(int backlog, TErrorCode& ec) noexcept {
+ return Impl_->Listen(backlog, ec);
+ }
+
+ void TTcpAcceptor::AsyncAccept(TTcpSocket& s, TTcpAcceptor::TAcceptHandler h, TDeadline deadline) {
+ return Impl_->AsyncAccept(s, h, deadline);
+ }
+
+ void TTcpAcceptor::AsyncCancel() {
+ Impl_->AsyncCancel();
+ }
+
+}
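
The AsyncWrite(const void*, size_t, ...) overload above shows the intended pattern for keeping data alive until a write completes: wrap it in an IBuffers and hand it over as TSendedData. Below is a caller-side sketch of the same idea for a TString payload (illustrative only, not part of this commit; TStringBuffers is a hypothetical helper name):

    #include <library/cpp/neh/asio/asio.h>
    #include <util/network/iovec.h>

    class TStringBuffers: public NAsio::TTcpSocket::IBuffers {
    public:
        TStringBuffers(TString data)
            : Data_(std::move(data))
            , Part_(Data_.data(), Data_.size())
            , IOVec_(&Part_, 1)
        {
        }

        TContIOVector* GetIOvec() override {
            return &IOVec_;
        }

    private:
        const TString Data_; // owned copy, destroyed only together with the operation
        IOutputStream::TPart Part_;
        TContIOVector IOVec_;
    };

    // usage:
    //     NAsio::TTcpSocket::TSendedData data(new TStringBuffers(payload));
    //     socket.AsyncWrite(data, handler);
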
diff --git a/library/cpp/neh/asio/asio.h b/library/cpp/neh/asio/asio.h
new file mode 100644
index 0000000000..a902d663cf
--- /dev/null
+++ b/library/cpp/neh/asio/asio.h
@@ -0,0 +1,280 @@
+#pragma once
+
+//
+// primary header for working with asio
+//
+
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <util/generic/vector.h>
+#include <util/network/socket.h>
+#include <util/network/endpoint.h>
+#include <util/system/error.h>
+#include <util/stream/output.h>
+#include <functional>
+
+#include <library/cpp/dns/cache.h>
+
+//#define DEBUG_ASIO
+
+class TContIOVector;
+
+namespace NAsio {
+ class TErrorCode {
+ public:
+ inline TErrorCode(int val = 0) noexcept
+ : Val_(val)
+ {
+ }
+
+ typedef void (*TUnspecifiedBoolType)();
+
+ static void UnspecifiedBoolTrue() {
+ }
+
+ //safe cast to bool value
+ operator TUnspecifiedBoolType() const noexcept { // true if error
+ return Val_ == 0 ? nullptr : UnspecifiedBoolTrue;
+ }
+
+ bool operator!() const noexcept {
+ return Val_ == 0;
+ }
+
+ void Assign(int val) noexcept {
+ Val_ = val;
+ }
+
+ int Value() const noexcept {
+ return Val_;
+ }
+
+ TString Text() const {
+ if (!Val_) {
+ return TString();
+ }
+ return LastSystemErrorText(Val_);
+ }
+
+ void Check() {
+ if (Val_) {
+ throw TSystemError(Val_);
+ }
+ }
+
+ private:
+ int Val_;
+ };
+
+ // wrapper around TInstant that also allows using a TDuration (+ TInstant::Now()) as a deadline
+ class TDeadline: public TInstant {
+ public:
+ TDeadline()
+ : TInstant(TInstant::Max())
+ {
+ }
+
+ TDeadline(const TInstant& t)
+ : TInstant(t)
+ {
+ }
+
+ TDeadline(const TDuration& d)
+ : TInstant(TInstant::Now() + d)
+ {
+ }
+ };
+
+ class IHandlingContext {
+ public:
+ virtual ~IHandlingContext() {
+ }
+
+ // if the handler throws an exception, a call to this function is ignored
+ virtual void ContinueUseHandler(TDeadline deadline = TDeadline()) = 0;
+ };
+
+ typedef std::function<void()> TCompletionHandler;
+
+ class TIOService: public TNonCopyable {
+ public:
+ TIOService();
+ ~TIOService();
+
+ void Run();
+ void Post(TCompletionHandler); // call the handler in the Run() thread-executor
+ void Abort(); // all existing async i/o operations and timers in Run() receive error = ECANCELED, then Run() exits
+
+ // counterpart of boost::asio::io_service::work
+ class TWork {
+ public:
+ TWork(TWork&);
+ TWork(TIOService&);
+ ~TWork();
+
+ private:
+ void operator=(const TWork&); //disable
+
+ TIOService& Srv_;
+ };
+
+ class TImpl;
+
+ TImpl& GetImpl() noexcept {
+ return *Impl_;
+ }
+
+ private:
+ THolder<TImpl> Impl_;
+ };
+
+ class TDeadlineTimer: public TNonCopyable {
+ public:
+ typedef std::function<void(const TErrorCode& err, IHandlingContext&)> THandler;
+
+ TDeadlineTimer(TIOService&) noexcept;
+ ~TDeadlineTimer();
+
+ void AsyncWaitExpireAt(TDeadline, THandler);
+ void Cancel();
+
+ TIOService& GetIOService() const noexcept {
+ return Srv_;
+ }
+
+ class TImpl;
+
+ private:
+ TIOService& Srv_;
+ TImpl* Impl_;
+ };
+
+ class TTcpSocket: public TNonCopyable {
+ public:
+ class IBuffers {
+ public:
+ virtual ~IBuffers() {
+ }
+ virtual TContIOVector* GetIOvec() = 0;
+ };
+ typedef TAutoPtr<IBuffers> TSendedData;
+
+ typedef std::function<void(const TErrorCode& err, IHandlingContext&)> THandler;
+ typedef THandler TConnectHandler;
+ typedef std::function<void(const TErrorCode& err, size_t amount, IHandlingContext&)> TWriteHandler;
+ typedef std::function<void(const TErrorCode& err, size_t amount, IHandlingContext&)> TReadHandler;
+ typedef THandler TPollHandler;
+
+ enum TShutdownMode {
+ ShutdownReceive = SHUT_RD,
+ ShutdownSend = SHUT_WR,
+ ShutdownBoth = SHUT_RDWR
+ };
+
+ TTcpSocket(TIOService&) noexcept;
+ ~TTcpSocket();
+
+ void AsyncConnect(const TEndpoint& ep, TConnectHandler, TDeadline deadline = TDeadline());
+ void AsyncWrite(TSendedData&, TWriteHandler, TDeadline deadline = TDeadline());
+ void AsyncWrite(TContIOVector* buff, TWriteHandler, TDeadline deadline = TDeadline());
+ void AsyncWrite(const void* buff, size_t size, TWriteHandler, TDeadline deadline = TDeadline());
+ void AsyncRead(void* buff, size_t size, TReadHandler, TDeadline deadline = TDeadline());
+ void AsyncReadSome(void* buff, size_t size, TReadHandler, TDeadline deadline = TDeadline());
+ void AsyncPollWrite(TPollHandler, TDeadline deadline = TDeadline());
+ void AsyncPollRead(TPollHandler, TDeadline deadline = TDeadline());
+ void AsyncCancel();
+
+ // synchronous, but non-blocking methods
+ size_t WriteSome(TContIOVector&, TErrorCode&) noexcept;
+ size_t WriteSome(const void* buff, size_t size, TErrorCode&) noexcept;
+ size_t ReadSome(void* buff, size_t size, TErrorCode&) noexcept;
+
+ bool IsOpen() const noexcept;
+ void Shutdown(TShutdownMode mode, TErrorCode& ec);
+
+ TIOService& GetIOService() const noexcept {
+ return Srv_;
+ }
+
+ SOCKET Native() const noexcept;
+
+ TEndpoint RemoteEndpoint() const;
+
+ inline size_t WriteSome(TContIOVector& v) {
+ TErrorCode ec;
+ size_t n = WriteSome(v, ec);
+ ec.Check();
+ return n;
+ }
+
+ inline size_t WriteSome(const void* buff, size_t size) {
+ TErrorCode ec;
+ size_t n = WriteSome(buff, size, ec);
+ ec.Check();
+ return n;
+ }
+
+ inline size_t ReadSome(void* buff, size_t size) {
+ TErrorCode ec;
+ size_t n = ReadSome(buff, size, ec);
+ ec.Check();
+ return n;
+ }
+
+ void Shutdown(TShutdownMode mode) {
+ TErrorCode ec;
+ Shutdown(mode, ec);
+ ec.Check();
+ }
+
+ class TImpl;
+
+ TImpl& GetImpl() const noexcept {
+ return *Impl_;
+ }
+
+ private:
+ TIOService& Srv_;
+ TIntrusivePtr<TImpl> Impl_;
+ };
+
+ class TTcpAcceptor: public TNonCopyable {
+ public:
+ typedef std::function<void(const TErrorCode& err, IHandlingContext&)> TAcceptHandler;
+
+ TTcpAcceptor(TIOService&) noexcept;
+ ~TTcpAcceptor();
+
+ void Bind(TEndpoint&, TErrorCode&) noexcept;
+ void Listen(int backlog, TErrorCode&) noexcept;
+
+ void AsyncAccept(TTcpSocket&, TAcceptHandler, TDeadline deadline = TDeadline());
+
+ void AsyncCancel();
+
+ inline void Bind(TEndpoint& ep) {
+ TErrorCode ec;
+ Bind(ep, ec);
+ ec.Check();
+ }
+ inline void Listen(int backlog) {
+ TErrorCode ec;
+ Listen(backlog, ec);
+ ec.Check();
+ }
+
+ TIOService& GetIOService() const noexcept {
+ return Srv_;
+ }
+
+ class TImpl;
+
+ TImpl& GetImpl() const noexcept {
+ return *Impl_;
+ }
+
+ private:
+ TIOService& Srv_;
+ TIntrusivePtr<TImpl> Impl_;
+ };
+}
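
A minimal usage sketch (illustrative only, not part of this commit) of the primitives declared above: a single-threaded event loop with a deadline timer and a posted handler, assuming only the TIOService, TDeadlineTimer and TErrorCode interfaces from this header.

    #include <library/cpp/neh/asio/asio.h>

    void RunTimerExample() {
        NAsio::TIOService srv;

        NAsio::TDeadlineTimer timer(srv);
        // fire in 100 ms; TDeadline is built implicitly from TDuration (+ TInstant::Now())
        timer.AsyncWaitExpireAt(TDuration::MilliSeconds(100),
            [](const NAsio::TErrorCode& err, NAsio::IHandlingContext&) {
                // err is ECANCELED if the service is aborted before the timer expires
                Cout << "timer: " << (err ? err.Text() : TString("expired")) << Endl;
            });

        // any functor can be executed inside the Run() thread-executor
        srv.Post([]() { Cout << "posted handler" << Endl; });

        srv.Run(); // returns once no timers or event handlers remain
    }
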
diff --git a/library/cpp/neh/asio/deadline_timer_impl.cpp b/library/cpp/neh/asio/deadline_timer_impl.cpp
new file mode 100644
index 0000000000..399a4338fb
--- /dev/null
+++ b/library/cpp/neh/asio/deadline_timer_impl.cpp
@@ -0,0 +1 @@
+#include "deadline_timer_impl.h"
diff --git a/library/cpp/neh/asio/deadline_timer_impl.h b/library/cpp/neh/asio/deadline_timer_impl.h
new file mode 100644
index 0000000000..d9db625c94
--- /dev/null
+++ b/library/cpp/neh/asio/deadline_timer_impl.h
@@ -0,0 +1,110 @@
+#pragma once
+
+#include "io_service_impl.h"
+
+namespace NAsio {
+ class TTimerOperation: public TOperation {
+ public:
+ TTimerOperation(TIOService::TImpl::TTimer* t, TInstant deadline)
+ : TOperation(deadline)
+ , T_(t)
+ {
+ }
+
+ void AddOp(TIOService::TImpl&) override {
+ Y_ASSERT(0);
+ }
+
+ void Finalize() override {
+ DBGOUT("TTimerDeadlineOperation::Finalize()");
+ T_->DelOp(this);
+ }
+
+ protected:
+ TIOService::TImpl::TTimer* T_;
+ };
+
+ class TRegisterTimerOperation: public TTimerOperation {
+ public:
+ TRegisterTimerOperation(TIOService::TImpl::TTimer* t, TInstant deadline = TInstant::Max())
+ : TTimerOperation(t, deadline)
+ {
+ Speculative_ = true;
+ }
+
+ bool Execute(int errorCode) override {
+ Y_UNUSED(errorCode);
+ T_->GetIOServiceImpl().SyncRegisterTimer(T_);
+ return true;
+ }
+ };
+
+ class TTimerDeadlineOperation: public TTimerOperation {
+ public:
+ TTimerDeadlineOperation(TIOService::TImpl::TTimer* t, TDeadlineTimer::THandler h, TInstant deadline)
+ : TTimerOperation(t, deadline)
+ , H_(h)
+ {
+ }
+
+ void AddOp(TIOService::TImpl&) override {
+ T_->AddOp(this);
+ }
+
+ bool Execute(int errorCode) override {
+ DBGOUT("TTimerDeadlineOperation::Execute(" << errorCode << ")");
+ H_(errorCode == ETIMEDOUT ? 0 : errorCode, *this);
+ return true;
+ }
+
+ private:
+ TDeadlineTimer::THandler H_;
+ };
+
+ class TCancelTimerOperation: public TTimerOperation {
+ public:
+ TCancelTimerOperation(TIOService::TImpl::TTimer* t)
+ : TTimerOperation(t, TInstant::Max())
+ {
+ Speculative_ = true;
+ }
+
+ bool Execute(int errorCode) override {
+ Y_UNUSED(errorCode);
+ T_->FailOperations(ECANCELED);
+ return true;
+ }
+ };
+
+ class TUnregisterTimerOperation: public TTimerOperation {
+ public:
+ TUnregisterTimerOperation(TIOService::TImpl::TTimer* t, TInstant deadline = TInstant::Max())
+ : TTimerOperation(t, deadline)
+ {
+ Speculative_ = true;
+ }
+
+ bool Execute(int errorCode) override {
+ Y_UNUSED(errorCode);
+ DBGOUT("TUnregisterTimerOperation::Execute(" << errorCode << ")");
+ T_->GetIOServiceImpl().SyncUnregisterAndDestroyTimer(T_);
+ return true;
+ }
+ };
+
+ class TDeadlineTimer::TImpl: public TIOService::TImpl::TTimer {
+ public:
+ TImpl(TIOService::TImpl& srv)
+ : TIOService::TImpl::TTimer(srv)
+ {
+ }
+
+ void AsyncWaitExpireAt(TDeadline d, TDeadlineTimer::THandler h) {
+ Srv_.ScheduleOp(new TTimerDeadlineOperation(this, h, d));
+ }
+
+ void Cancel() {
+ Srv_.ScheduleOp(new TCancelTimerOperation(this));
+ }
+ };
+}
diff --git a/library/cpp/neh/asio/executor.cpp b/library/cpp/neh/asio/executor.cpp
new file mode 100644
index 0000000000..03b26bf847
--- /dev/null
+++ b/library/cpp/neh/asio/executor.cpp
@@ -0,0 +1 @@
+#include "executor.h"
diff --git a/library/cpp/neh/asio/executor.h b/library/cpp/neh/asio/executor.h
new file mode 100644
index 0000000000..4f6549044d
--- /dev/null
+++ b/library/cpp/neh/asio/executor.h
@@ -0,0 +1,76 @@
+#pragma once
+
+#include "asio.h"
+
+#include <library/cpp/deprecated/atomic/atomic.h>
+
+#include <util/thread/factory.h>
+#include <util/system/thread.h>
+
+namespace NAsio {
+ class TIOServiceExecutor: public IThreadFactory::IThreadAble {
+ public:
+ TIOServiceExecutor()
+ : Work_(new TIOService::TWork(Srv_))
+ {
+ T_ = SystemThreadFactory()->Run(this);
+ }
+
+ ~TIOServiceExecutor() override {
+ SyncShutdown();
+ }
+
+ void DoExecute() override {
+ TThread::SetCurrentThreadName("NehAsioExecutor");
+ Srv_.Run();
+ }
+
+ inline TIOService& GetIOService() noexcept {
+ return Srv_;
+ }
+
+ void SyncShutdown() {
+ if (Work_) {
+ Work_.Destroy();
+ Srv_.Abort(); //cancel all async operations, break Run() execution
+ T_->Join();
+ }
+ }
+
+ private:
+ TIOService Srv_;
+ TAutoPtr<TIOService::TWork> Work_;
+ typedef TAutoPtr<IThreadFactory::IThread> IThreadRef;
+ IThreadRef T_;
+ };
+
+ class TExecutorsPool {
+ public:
+ TExecutorsPool(size_t executors)
+ : C_(0)
+ {
+ for (size_t i = 0; i < executors; ++i) {
+ E_.push_back(new TIOServiceExecutor());
+ }
+ }
+
+ inline size_t Size() const noexcept {
+ return E_.size();
+ }
+
+ inline TIOServiceExecutor& GetExecutor() noexcept {
+ TAtomicBase next = AtomicIncrement(C_);
+ return *E_[next % E_.size()];
+ }
+
+ void SyncShutdown() {
+ for (size_t i = 0; i < E_.size(); ++i) {
+ E_[i]->SyncShutdown();
+ }
+ }
+
+ private:
+ TAtomic C_;
+ TVector<TAutoPtr<TIOServiceExecutor>> E_;
+ };
+}
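
A short sketch (illustrative only, not part of this commit) of how the pool above is meant to be used: work is distributed round-robin over the executor threads via Post().

    #include <library/cpp/neh/asio/executor.h>

    void RunPoolExample() {
        NAsio::TExecutorsPool pool(4); // four threads, each running its own TIOService

        for (size_t i = 0; i < 16; ++i) {
            // GetExecutor() rotates over the executors using an atomic counter
            pool.GetExecutor().GetIOService().Post([i]() {
                Cout << "task " << i << " on thread " << TThread::CurrentThreadId() << Endl;
            });
        }

        pool.SyncShutdown(); // destroys the TWork guards, aborts the services and joins the threads
    }
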
diff --git a/library/cpp/neh/asio/io_service_impl.cpp b/library/cpp/neh/asio/io_service_impl.cpp
new file mode 100644
index 0000000000..d49b3fb03e
--- /dev/null
+++ b/library/cpp/neh/asio/io_service_impl.cpp
@@ -0,0 +1,161 @@
+#include "io_service_impl.h"
+
+#include <library/cpp/coroutine/engine/poller.h>
+
+using namespace NAsio;
+
+void TFdOperation::AddOp(TIOService::TImpl& srv) {
+ srv.AddOp(this);
+}
+
+void TFdOperation::Finalize() {
+ (*PH_)->DelOp(this);
+}
+
+void TPollFdEventHandler::ExecuteOperations(TFdOperations& oprs, int errorCode) {
+ TFdOperations::iterator it = oprs.begin();
+
+ try {
+ while (it != oprs.end()) {
+ TFdOperation* op = it->Get();
+
+ if (op->Execute(errorCode)) { // throw ?
+ if (op->IsRequiredRepeat()) {
+ Srv_.UpdateOpDeadline(op);
+ ++it; // operation completed, but wants to be repeated
+ } else {
+ FinishedOperations_.push_back(*it);
+ it = oprs.erase(it);
+ }
+ } else {
+ ++it; //operation not completed
+ }
+ }
+ } catch (...) {
+ if (it != oprs.end()) {
+ FinishedOperations_.push_back(*it);
+ oprs.erase(it);
+ }
+ throw;
+ }
+}
+
+void TPollFdEventHandler::DelOp(TFdOperation* op) {
+ TAutoPtr<TPollFdEventHandler>& evh = *op->PH_;
+
+ if (op->IsPollRead()) {
+ Y_ASSERT(FinishOp(ReadOperations_, op));
+ } else {
+ Y_ASSERT(FinishOp(WriteOperations_, op));
+ }
+ Srv_.FixHandledEvents(evh); // warning: 'this' can be destroyed here!
+}
+
+void TInterrupterHandler::OnFdEvent(int status, ui16 filter) {
+ if (!status && (filter & CONT_POLL_READ)) {
+ PI_.Reset();
+ }
+}
+
+void TIOService::TImpl::Run() {
+ TEvh& iEvh = Evh_.Get(I_.Fd());
+ iEvh.Reset(new TInterrupterHandler(*this, I_));
+
+ TInterrupterKeeper ik(*this, iEvh);
+ Y_UNUSED(ik);
+ IPollerFace::TEvents evs;
+ AtomicSet(NeedCheckOpQueue_, 1);
+ TInstant deadline;
+
+ while (Y_LIKELY(!Aborted_ && (AtomicGet(OutstandingWork_) || FdEventHandlersCnt_ > 1 || TimersOpCnt_ || AtomicGet(NeedCheckOpQueue_)))) {
+ // keep looping while:
+ //   work is expected (external flag),
+ //   or event handlers exist (excluding the interrupter),
+ //   or there are uncompleted timer operations,
+ //   or any operation is still queued
+
+ AtomicIncrement(IsWaiting_);
+ if (!AtomicGet(NeedCheckOpQueue_)) {
+ P_->Wait(evs, deadline);
+ }
+ AtomicDecrement(IsWaiting_);
+
+ if (evs.size()) {
+ for (IPollerFace::TEvents::const_iterator iev = evs.begin(); iev != evs.end() && !Aborted_; ++iev) {
+ const IPollerFace::TEvent& ev = *iev;
+ TEvh& evh = *(TEvh*)ev.Data;
+
+ if (!evh) {
+ continue; //op. cancel (see ProcessOpQueue) can destroy evh
+ }
+
+ int status = ev.Status;
+ if (ev.Status == EIO) {
+ int error = status;
+ if (GetSockOpt(evh->Fd(), SOL_SOCKET, SO_ERROR, error) == 0) {
+ status = error;
+ }
+ }
+
+ OnFdEvent(evh, status, ev.Filter); // handle fd events here
+ // immediately after handling events for one descriptor, check the op. queue:
+ // it often contains another operation for the same fd (e.g. the next async read),
+ // so redundant epoll_ctl (or similar) calls can be avoided
+ ProcessOpQueue();
+ }
+
+ evs.clear();
+ } else {
+ ProcessOpQueue();
+ }
+
+ deadline = DeadlinesQueue_.NextDeadline(); //here handle timeouts/process timers
+ }
+}
+
+void TIOService::TImpl::Abort() {
+ class TAbortOperation: public TNoneOperation {
+ public:
+ TAbortOperation(TIOService::TImpl& srv)
+ : TNoneOperation()
+ , Srv_(srv)
+ {
+ Speculative_ = true;
+ }
+
+ private:
+ bool Execute(int errorCode) override {
+ Y_UNUSED(errorCode);
+ Srv_.ProcessAbort();
+ return true;
+ }
+
+ TIOService::TImpl& Srv_;
+ };
+ AtomicSet(HasAbort_, 1);
+ ScheduleOp(new TAbortOperation(*this));
+}
+
+void TIOService::TImpl::ProcessAbort() {
+ Aborted_ = true;
+
+ for (int fd = 0; fd <= MaxFd_; ++fd) {
+ TEvh& evh = Evh_.Get(fd);
+ if (!!evh && evh->Fd() != I_.Fd()) {
+ OnFdEvent(evh, ECANCELED, CONT_POLL_READ | CONT_POLL_WRITE);
+ }
+ }
+
+ for (auto t : Timers_) {
+ t->FailOperations(ECANCELED);
+ }
+
+ TOperationPtr op;
+ while (OpQueue_.Dequeue(&op)) { //cancel all enqueued operations
+ try {
+ op->Execute(ECANCELED);
+ } catch (...) {
+ }
+ op.Destroy();
+ }
+}
diff --git a/library/cpp/neh/asio/io_service_impl.h b/library/cpp/neh/asio/io_service_impl.h
new file mode 100644
index 0000000000..46fa9f9ee1
--- /dev/null
+++ b/library/cpp/neh/asio/io_service_impl.h
@@ -0,0 +1,744 @@
+#pragma once
+
+#include "asio.h"
+#include "poll_interrupter.h"
+
+#include <library/cpp/neh/lfqueue.h>
+#include <library/cpp/neh/pipequeue.h>
+
+#include <library/cpp/dns/cache.h>
+
+#include <util/generic/hash_set.h>
+#include <util/network/iovec.h>
+#include <util/network/pollerimpl.h>
+#include <util/thread/lfqueue.h>
+#include <util/thread/factory.h>
+
+#ifdef DEBUG_ASIO
+#define DBGOUT(args) Cout << args << Endl;
+#else
+#define DBGOUT(args)
+#endif
+
+namespace NAsio {
+ // TODO: copy-paste from neh - needs fixing
+ template <class T>
+ class TLockFreeSequence {
+ public:
+ inline TLockFreeSequence() {
+ memset((void*)T_, 0, sizeof(T_));
+ }
+
+ inline ~TLockFreeSequence() {
+ for (size_t i = 0; i < Y_ARRAY_SIZE(T_); ++i) {
+ delete[] T_[i];
+ }
+ }
+
+ inline T& Get(size_t n) {
+ const size_t i = GetValueBitCount(n + 1) - 1;
+
+ return GetList(i)[n + 1 - (((size_t)1) << i)];
+ }
+
+ private:
+ inline T* GetList(size_t n) {
+ T* volatile* t = T_ + n;
+
+ while (!*t) {
+ TArrayHolder<T> nt(new T[((size_t)1) << n]);
+
+ if (AtomicCas(t, nt.Get(), nullptr)) {
+ return nt.Release();
+ }
+ }
+
+ return *t;
+ }
+
+ private:
+ T* volatile T_[sizeof(size_t) * 8];
+ };
+
+ struct TOperationCompare {
+ template <class T>
+ static inline bool Compare(const T& l, const T& r) noexcept {
+ return l.DeadLine() < r.DeadLine() || (l.DeadLine() == r.DeadLine() && &l < &r);
+ }
+ };
+
+ // async operation, executed in the context of the TIOService::Run() thread-executor
+ // usually used for calling functors/callbacks
+ class TOperation: public TRbTreeItem<TOperation, TOperationCompare>, public IHandlingContext {
+ public:
+ TOperation(TInstant deadline = TInstant::Max())
+ : D_(deadline)
+ , Speculative_(false)
+ , RequiredRepeatExecution_(false)
+ , ND_(deadline)
+ {
+ }
+
+ //register this operation in svc.impl.
+ virtual void AddOp(TIOService::TImpl&) = 0;
+
+ // returns false if the operation is not completed
+ virtual bool Execute(int errorCode = 0) = 0;
+
+ void ContinueUseHandler(TDeadline deadline) override {
+ RequiredRepeatExecution_ = true;
+ ND_ = deadline;
+ }
+
+ virtual void Finalize() = 0;
+
+ inline TInstant Deadline() const noexcept {
+ return D_;
+ }
+
+ inline TInstant DeadLine() const noexcept {
+ return D_;
+ }
+
+ inline bool Speculative() const noexcept {
+ return Speculative_;
+ }
+
+ inline bool IsRequiredRepeat() const noexcept {
+ return RequiredRepeatExecution_;
+ }
+
+ inline void PrepareReExecution() noexcept {
+ RequiredRepeatExecution_ = false;
+ D_ = ND_;
+ }
+
+ protected:
+ TInstant D_;
+ bool Speculative_; // if true, the operation is run immediately after dequeue (without waiting for any event);
+ // used e.g. to optimize writes, when the buffers obviously have free space
+ bool RequiredRepeatExecution_; // set to true if the operation must be re-executed
+ TInstant ND_; // new deadline (for re-executed operations)
+ };
+
+ typedef TAutoPtr<TOperation> TOperationPtr;
+
+ class TNoneOperation: public TOperation {
+ public:
+ TNoneOperation(TInstant deadline = TInstant::Max())
+ : TOperation(deadline)
+ {
+ }
+
+ void AddOp(TIOService::TImpl&) override {
+ Y_ASSERT(0);
+ }
+
+ void Finalize() override {
+ }
+ };
+
+ class TPollFdEventHandler;
+
+ // operation bound to a descriptor
+ class TFdOperation: public TOperation {
+ public:
+ enum TPollType {
+ PollRead,
+ PollWrite
+ };
+
+ TFdOperation(SOCKET fd, TPollType pt, TInstant deadline = TInstant::Max())
+ : TOperation(deadline)
+ , Fd_(fd)
+ , PT_(pt)
+ , PH_(nullptr)
+ {
+ Y_ASSERT(Fd() != INVALID_SOCKET);
+ }
+
+ inline SOCKET Fd() const noexcept {
+ return Fd_;
+ }
+
+ inline bool IsPollRead() const noexcept {
+ return PT_ == PollRead;
+ }
+
+ void AddOp(TIOService::TImpl& srv) override;
+
+ void Finalize() override;
+
+ protected:
+ SOCKET Fd_;
+ TPollType PT_;
+
+ public:
+ TAutoPtr<TPollFdEventHandler>* PH_;
+ };
+
+ typedef TAutoPtr<TFdOperation> TFdOperationPtr;
+
+ class TPollFdEventHandler {
+ public:
+ TPollFdEventHandler(SOCKET fd, TIOService::TImpl& srv)
+ : Fd_(fd)
+ , HandledEvents_(0)
+ , Srv_(srv)
+ {
+ }
+
+ virtual ~TPollFdEventHandler() {
+ Y_ASSERT(ReadOperations_.size() == 0);
+ Y_ASSERT(WriteOperations_.size() == 0);
+ }
+
+ inline void AddReadOp(TFdOperationPtr op) {
+ ReadOperations_.push_back(op);
+ }
+
+ inline void AddWriteOp(TFdOperationPtr op) {
+ WriteOperations_.push_back(op);
+ }
+
+ virtual void OnFdEvent(int status, ui16 filter) {
+ DBGOUT("PollEvent(fd=" << Fd_ << ", " << status << ", " << filter << ")");
+ if (status) {
+ ExecuteOperations(ReadOperations_, status);
+ ExecuteOperations(WriteOperations_, status);
+ } else {
+ if (filter & CONT_POLL_READ) {
+ ExecuteOperations(ReadOperations_, status);
+ }
+ if (filter & CONT_POLL_WRITE) {
+ ExecuteOperations(WriteOperations_, status);
+ }
+ }
+ }
+
+ typedef TVector<TFdOperationPtr> TFdOperations;
+
+ void ExecuteOperations(TFdOperations& oprs, int errorCode);
+
+ // returns true if the set of handled events changed and the event poller must be reconfigured
+ virtual bool FixHandledEvents() noexcept {
+ DBGOUT("TPollFdEventHandler::FixHandledEvents()");
+ ui16 filter = 0;
+
+ if (WriteOperations_.size()) {
+ filter |= CONT_POLL_WRITE;
+ }
+ if (ReadOperations_.size()) {
+ filter |= CONT_POLL_READ;
+ }
+
+ if (Y_LIKELY(HandledEvents_ == filter)) {
+ return false;
+ }
+
+ HandledEvents_ = filter;
+ return true;
+ }
+
+ inline bool FinishOp(TFdOperations& oprs, TFdOperation* op) noexcept {
+ for (TFdOperations::iterator it = oprs.begin(); it != oprs.end(); ++it) {
+ if (it->Get() == op) {
+ FinishedOperations_.push_back(*it);
+ oprs.erase(it);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ void DelOp(TFdOperation* op);
+
+ inline SOCKET Fd() const noexcept {
+ return Fd_;
+ }
+
+ inline ui16 HandledEvents() const noexcept {
+ return HandledEvents_;
+ }
+
+ inline void AddHandlingEvent(ui16 ev) noexcept {
+ HandledEvents_ |= ev;
+ }
+
+ inline void DestroyFinishedOperations() {
+ FinishedOperations_.clear();
+ }
+
+ TIOService::TImpl& GetServiceImpl() const noexcept {
+ return Srv_;
+ }
+
+ protected:
+ SOCKET Fd_;
+ ui16 HandledEvents_;
+ TIOService::TImpl& Srv_;
+
+ private:
+ TVector<TFdOperationPtr> ReadOperations_;
+ TVector<TFdOperationPtr> WriteOperations_;
+ // we can't destroy finished operations immediately: that could close the socket descriptor Fd_ still in use
+ // (via cascade deletion of the operation object-handler), while Fd_ is needed later to modify the handled events in the poller,
+ // so finished operations are collected here and destroyed only after the poller has been updated -
+ // see FixHandledEvents(TPollFdEventHandlerPtr&)
+ TVector<TFdOperationPtr> FinishedOperations_;
+ };
+
+ // additional descriptor for the poller, used to interrupt the current poll wait
+ class TInterrupterHandler: public TPollFdEventHandler {
+ public:
+ TInterrupterHandler(TIOService::TImpl& srv, TPollInterrupter& pi)
+ : TPollFdEventHandler(pi.Fd(), srv)
+ , PI_(pi)
+ {
+ HandledEvents_ = CONT_POLL_READ;
+ }
+
+ ~TInterrupterHandler() override {
+ DBGOUT("~TInterrupterHandler");
+ }
+
+ void OnFdEvent(int status, ui16 filter) override;
+
+ bool FixHandledEvents() noexcept override {
+ DBGOUT("TInterrupterHandler::FixHandledEvents()");
+ return false;
+ }
+
+ private:
+ TPollInterrupter& PI_;
+ };
+
+ namespace {
+ inline TAutoPtr<IPollerFace> CreatePoller() {
+ try {
+#if defined(_linux_)
+ return IPollerFace::Construct(TStringBuf("epoll"));
+#endif
+#if defined(_freebsd_) || defined(_darwin_)
+ return IPollerFace::Construct(TStringBuf("kqueue"));
+#endif
+ } catch (...) {
+ Cdbg << CurrentExceptionMessage() << Endl;
+ }
+ return IPollerFace::Default();
+ }
+ }
+
+ // roughly an equivalent of TContExecutor
+ class TIOService::TImpl: public TNonCopyable {
+ public:
+ typedef TAutoPtr<TPollFdEventHandler> TEvh;
+ typedef TLockFreeSequence<TEvh> TEventHandlers;
+
+ class TTimer {
+ public:
+ typedef THashSet<TOperation*> TOperations;
+
+ TTimer(TIOService::TImpl& srv)
+ : Srv_(srv)
+ {
+ }
+
+ virtual ~TTimer() {
+ FailOperations(ECANCELED);
+ }
+
+ void AddOp(TOperation* op) {
+ THolder<TOperation> tmp(op);
+ Operations_.insert(op);
+ Y_UNUSED(tmp.Release());
+ Srv_.RegisterOpDeadline(op);
+ Srv_.IncTimersOp();
+ }
+
+ void DelOp(TOperation* op) {
+ TOperations::iterator it = Operations_.find(op);
+ if (it != Operations_.end()) {
+ Srv_.DecTimersOp();
+ delete op;
+ Operations_.erase(it);
+ }
+ }
+
+ inline void FailOperations(int ec) {
+ for (auto operation : Operations_) {
+ try {
+ operation->Execute(ec); //throw ?
+ } catch (...) {
+ }
+ Srv_.DecTimersOp();
+ delete operation;
+ }
+ Operations_.clear();
+ }
+
+ TIOService::TImpl& GetIOServiceImpl() const noexcept {
+ return Srv_;
+ }
+
+ protected:
+ TIOService::TImpl& Srv_;
+ THashSet<TOperation*> Operations_;
+ };
+
+ class TTimers: public THashSet<TTimer*> {
+ public:
+ ~TTimers() {
+ for (auto it : *this) {
+ delete it;
+ }
+ }
+ };
+
+ TImpl()
+ : P_(CreatePoller())
+ , DeadlinesQueue_(*this)
+ {
+ }
+
+ ~TImpl() {
+ TOperationPtr op;
+
+ while (OpQueue_.Dequeue(&op)) { //cancel all enqueued operations
+ try {
+ op->Execute(ECANCELED);
+ } catch (...) {
+ }
+ op.Destroy();
+ }
+ }
+
+ // similar to TContExecutor::Execute() or io_service::run()
+ // processes the event loop (exits when there is nothing left to do: no timers or event handlers)
+ void Run();
+
+ // enqueue a functor to be called in the Run() event loop (thread-safe)
+ inline void Post(TCompletionHandler h) {
+ class TFuncOperation: public TNoneOperation {
+ public:
+ TFuncOperation(TCompletionHandler completionHandler)
+ : TNoneOperation()
+ , H_(std::move(completionHandler))
+ {
+ Speculative_ = true;
+ }
+
+ private:
+ // returns false if the operation is not completed
+ bool Execute(int errorCode) override {
+ Y_UNUSED(errorCode);
+ H_();
+ return true;
+ }
+
+ TCompletionHandler H_;
+ };
+
+ ScheduleOp(new TFuncOperation(std::move(h)));
+ }
+
+ // cancel all current operations (handlers are called with errorCode == ECANCELED)
+ void Abort();
+ bool HasAbort() {
+ return AtomicGet(HasAbort_);
+ }
+
+ inline void ScheduleOp(TOperationPtr op) { //throw std::bad_alloc
+ Y_ASSERT(!Aborted_);
+ Y_ASSERT(!!op);
+ OpQueue_.Enqueue(op);
+ Interrupt();
+ }
+
+ inline void Interrupt() noexcept {
+ AtomicSet(NeedCheckOpQueue_, 1);
+ if (AtomicAdd(IsWaiting_, 0) == 1) {
+ I_.Interrupt();
+ }
+ }
+
+ inline void UpdateOpDeadline(TOperation* op) {
+ TInstant oldDeadline = op->Deadline();
+ op->PrepareReExecution();
+
+ if (oldDeadline == op->Deadline()) {
+ return;
+ }
+
+ if (oldDeadline != TInstant::Max()) {
+ op->UnLink();
+ }
+ if (op->Deadline() != TInstant::Max()) {
+ DeadlinesQueue_.Register(op);
+ }
+ }
+
+ void SyncRegisterTimer(TTimer* t) {
+ Timers_.insert(t);
+ }
+
+ inline void SyncUnregisterAndDestroyTimer(TTimer* t) {
+ Timers_.erase(t);
+ delete t;
+ }
+
+ inline void IncTimersOp() noexcept {
+ ++TimersOpCnt_;
+ }
+
+ inline void DecTimersOp() noexcept {
+ --TimersOpCnt_;
+ }
+
+ inline void WorkStarted() {
+ AtomicIncrement(OutstandingWork_);
+ }
+
+ inline void WorkFinished() {
+ if (AtomicDecrement(OutstandingWork_) == 0) {
+ Interrupt();
+ }
+ }
+
+ private:
+ void ProcessAbort();
+
+ inline TEvh& EnsureGetEvh(SOCKET fd) {
+ TEvh& evh = Evh_.Get(fd);
+ if (!evh) {
+ evh.Reset(new TPollFdEventHandler(fd, *this));
+ }
+ return evh;
+ }
+
+ inline void OnTimeoutOp(TOperation* op) {
+ DBGOUT("OnTimeoutOp");
+ try {
+ op->Execute(ETIMEDOUT); //throw ?
+ } catch (...) {
+ op->Finalize();
+ throw;
+ }
+
+ if (op->IsRequiredRepeat()) {
+ //operation not completed
+ UpdateOpDeadline(op);
+ } else {
+ //destroy operation structure
+ op->Finalize();
+ }
+ }
+
+ public:
+ inline void FixHandledEvents(TEvh& evh) {
+ if (!!evh) {
+ if (evh->FixHandledEvents()) {
+ if (!evh->HandledEvents()) {
+ DelEventHandler(evh);
+ evh.Destroy();
+ } else {
+ ModEventHandler(evh);
+ evh->DestroyFinishedOperations();
+ }
+ } else {
+ evh->DestroyFinishedOperations();
+ }
+ }
+ }
+
+ private:
+ inline TEvh& GetHandlerForOp(TFdOperation* op) {
+ TEvh& evh = EnsureGetEvh(op->Fd());
+ op->PH_ = &evh;
+ return evh;
+ }
+
+ void ProcessOpQueue() {
+ if (!AtomicGet(NeedCheckOpQueue_)) {
+ return;
+ }
+ AtomicSet(NeedCheckOpQueue_, 0);
+
+ TOperationPtr op;
+
+ while (OpQueue_.Dequeue(&op)) {
+ if (op->Speculative()) {
+ if (op->Execute(Y_UNLIKELY(Aborted_) ? ECANCELED : 0)) {
+ op.Destroy();
+ continue; //operation completed
+ }
+
+ if (!op->IsRequiredRepeat()) {
+ op->PrepareReExecution();
+ }
+ }
+ RegisterOpDeadline(op.Get());
+ op.Get()->AddOp(*this); // ... -> AddOp()
+ Y_UNUSED(op.Release());
+ }
+ }
+
+ inline void RegisterOpDeadline(TOperation* op) {
+ if (op->DeadLine() != TInstant::Max()) {
+ DeadlinesQueue_.Register(op);
+ }
+ }
+
+ public:
+ inline void AddOp(TFdOperation* op) {
+ DBGOUT("AddOp<Fd>(" << op->Fd() << ")");
+ TEvh& evh = GetHandlerForOp(op);
+ if (op->IsPollRead()) {
+ evh->AddReadOp(op);
+ EnsureEventHandled(evh, CONT_POLL_READ);
+ } else {
+ evh->AddWriteOp(op);
+ EnsureEventHandled(evh, CONT_POLL_WRITE);
+ }
+ }
+
+ private:
+ inline void EnsureEventHandled(TEvh& evh, ui16 ev) {
+ if (!evh->HandledEvents()) {
+ evh->AddHandlingEvent(ev);
+ AddEventHandler(evh);
+ } else {
+ if ((evh->HandledEvents() & ev) == 0) {
+ evh->AddHandlingEvent(ev);
+ ModEventHandler(evh);
+ }
+ }
+ }
+
+ public:
+ //cancel all current operations for socket
+ //method MUST be called from Run() thread-executor
+ void CancelFdOp(SOCKET fd) {
+ TEvh& evh = Evh_.Get(fd);
+ if (!evh) {
+ return;
+ }
+
+ OnFdEvent(evh, ECANCELED, CONT_POLL_READ | CONT_POLL_WRITE);
+ }
+
+ private:
+ // helper that fixes handled events even in case of an exception
+ struct TExceptionProofFixerHandledEvents {
+ TExceptionProofFixerHandledEvents(TIOService::TImpl& srv, TEvh& iEvh)
+ : Srv_(srv)
+ , Evh_(iEvh)
+ {
+ }
+
+ ~TExceptionProofFixerHandledEvents() {
+ Srv_.FixHandledEvents(Evh_);
+ }
+
+ TIOService::TImpl& Srv_;
+ TEvh& Evh_;
+ };
+
+ inline void OnFdEvent(TEvh& evh, int status, ui16 filter) {
+ TExceptionProofFixerHandledEvents fixer(*this, evh);
+ Y_UNUSED(fixer);
+ evh->OnFdEvent(status, filter);
+ }
+
+ inline void AddEventHandler(TEvh& evh) {
+ if (evh->Fd() > MaxFd_) {
+ MaxFd_ = evh->Fd();
+ }
+ SetEventHandler(&evh, evh->Fd(), evh->HandledEvents());
+ ++FdEventHandlersCnt_;
+ }
+
+ inline void ModEventHandler(TEvh& evh) {
+ SetEventHandler(&evh, evh->Fd(), evh->HandledEvents());
+ }
+
+ inline void DelEventHandler(TEvh& evh) {
+ SetEventHandler(&evh, evh->Fd(), 0);
+ --FdEventHandlersCnt_;
+ }
+
+ inline void SetEventHandler(void* h, int fd, ui16 flags) {
+ DBGOUT("SetEventHandler(" << fd << ", " << flags << ")");
+ P_->Set(h, fd, flags);
+ }
+
+ // exception-safe DelEventHandler call
+ struct TInterrupterKeeper {
+ TInterrupterKeeper(TImpl& srv, TEvh& iEvh)
+ : Srv_(srv)
+ , Evh_(iEvh)
+ {
+ Srv_.AddEventHandler(Evh_);
+ }
+
+ ~TInterrupterKeeper() {
+ Srv_.DelEventHandler(Evh_);
+ }
+
+ TImpl& Srv_;
+ TEvh& Evh_;
+ };
+
+ TAutoPtr<IPollerFace> P_;
+ TPollInterrupter I_;
+ TAtomic IsWaiting_ = 0;
+ TAtomic NeedCheckOpQueue_ = 0;
+ TAtomic OutstandingWork_ = 0;
+
+ NNeh::TAutoLockFreeQueue<TOperation> OpQueue_;
+
+ TEventHandlers Evh_; //i/o event handlers
+ TTimers Timers_; //timeout event handlers
+
+ size_t FdEventHandlersCnt_ = 0; //i/o event handlers counter
+ size_t TimersOpCnt_ = 0; //timers op counter
+ SOCKET MaxFd_ = 0; //max used descriptor num
+ TAtomic HasAbort_ = 0;
+ bool Aborted_ = false;
+
+ class TDeadlinesQueue {
+ public:
+ TDeadlinesQueue(TIOService::TImpl& srv)
+ : Srv_(srv)
+ {
+ }
+
+ inline void Register(TOperation* op) {
+ Deadlines_.Insert(op);
+ }
+
+ TInstant NextDeadline() {
+ TDeadlines::TIterator it = Deadlines_.Begin();
+
+ while (it != Deadlines_.End()) {
+ if (it->DeadLine() > TInstant::Now()) {
+ DBGOUT("TDeadlinesQueue::NewDeadline:" << (it->DeadLine().GetValue() - TInstant::Now().GetValue()));
+ return it->DeadLine();
+ }
+
+ TOperation* op = &*(it++);
+ Srv_.OnTimeoutOp(op);
+ }
+
+ return Deadlines_.Empty() ? TInstant::Max() : Deadlines_.Begin()->DeadLine();
+ }
+
+ private:
+ typedef TRbTree<TOperation, TOperationCompare> TDeadlines;
+ TDeadlines Deadlines_;
+ TIOService::TImpl& Srv_;
+ };
+
+ TDeadlinesQueue DeadlinesQueue_;
+ };
+}
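
A small standalone illustration (not part of this commit) of the bucket arithmetic behind TLockFreeSequence::Get() above: element n lives in bucket i = GetValueBitCount(n + 1) - 1, which holds 2^i slots, at offset n + 1 - 2^i; existing elements therefore never move, and new buckets can be installed lazily with a single CAS.

    #include <util/generic/bitops.h>
    #include <util/stream/output.h>

    void ShowLockFreeSequenceLayout() {
        for (size_t n = 0; n < 8; ++n) {
            const size_t i = GetValueBitCount(n + 1) - 1;
            const size_t offset = n + 1 - (size_t(1) << i);
            Cout << "n=" << n << " -> bucket " << i
                 << " (size " << (size_t(1) << i) << "), offset " << offset << Endl;
        }
    }
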
diff --git a/library/cpp/neh/asio/poll_interrupter.cpp b/library/cpp/neh/asio/poll_interrupter.cpp
new file mode 100644
index 0000000000..c96d40c4f3
--- /dev/null
+++ b/library/cpp/neh/asio/poll_interrupter.cpp
@@ -0,0 +1 @@
+#include "poll_interrupter.h"
diff --git a/library/cpp/neh/asio/poll_interrupter.h b/library/cpp/neh/asio/poll_interrupter.h
new file mode 100644
index 0000000000..faf815c512
--- /dev/null
+++ b/library/cpp/neh/asio/poll_interrupter.h
@@ -0,0 +1,107 @@
+#pragma once
+
+#include <util/system/defaults.h>
+#include <util/generic/yexception.h>
+#include <util/network/socket.h>
+#include <util/system/pipe.h>
+
+#ifdef _linux_
+#include <sys/eventfd.h>
+#endif
+
+#if defined(_bionic_) && !defined(EFD_SEMAPHORE)
+#define EFD_SEMAPHORE 1
+#endif
+
+namespace NAsio {
+#ifdef _linux_
+ class TEventFdPollInterrupter {
+ public:
+ inline TEventFdPollInterrupter() {
+ F_ = eventfd(0, EFD_NONBLOCK | EFD_SEMAPHORE);
+ if (F_ < 0) {
+ ythrow TFileError() << "failed to create an eventfd";
+ }
+ }
+
+ inline ~TEventFdPollInterrupter() {
+ close(F_);
+ }
+
+ inline void Interrupt() const noexcept {
+ const static eventfd_t ev(1);
+ ssize_t res = ::write(F_, &ev, sizeof ev);
+ Y_UNUSED(res);
+ }
+
+ inline bool Reset() const noexcept {
+ eventfd_t ev(0);
+
+ for (;;) {
+ ssize_t res = ::read(F_, &ev, sizeof ev);
+ if (res < 0 && LastSystemError() == EINTR) {
+ continue;
+ }
+
+ return res > 0;
+ }
+ }
+
+ int Fd() {
+ return F_;
+ }
+
+ private:
+ int F_;
+ };
+#endif
+
+ class TPipePollInterrupter {
+ public:
+ TPipePollInterrupter() {
+ TPipeHandle::Pipe(S_[0], S_[1]);
+
+ SetNonBlock(S_[0]);
+ SetNonBlock(S_[1]);
+ }
+
+ inline void Interrupt() const noexcept {
+ char byte = 0;
+ ssize_t res = S_[1].Write(&byte, 1);
+ Y_UNUSED(res);
+ }
+
+ inline bool Reset() const noexcept {
+ char buff[256];
+
+ for (;;) {
+ ssize_t r = S_[0].Read(buff, sizeof buff);
+
+ if (r < 0 && LastSystemError() == EINTR) {
+ continue;
+ }
+
+ bool wasInterrupted = r > 0;
+
+ while (r == sizeof buff) {
+ r = S_[0].Read(buff, sizeof buff);
+ }
+
+ return wasInterrupted;
+ }
+ }
+
+ PIPEHANDLE Fd() const noexcept {
+ return S_[0];
+ }
+
+ private:
+ TPipeHandle S_[2];
+ };
+
+#ifdef _linux_
+ typedef TEventFdPollInterrupter TPollInterrupter; // more efficient than a pipe, but a Linux-only implementation
+#else
+ typedef TPipePollInterrupter TPollInterrupter;
+#endif
+}
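
A minimal sketch (illustrative only, not part of this commit) of the interrupter contract: a waiting loop watches Fd() for readability, another thread calls Interrupt() to wake it up, and the woken loop calls Reset() to drain the notification before sleeping again. The waiting side here uses TSocketPoller from util/network/poller.h purely for illustration.

    #include <library/cpp/neh/asio/poll_interrupter.h>
    #include <util/network/poller.h>

    void WaitUntilInterrupted(NAsio::TPollInterrupter& interrupter) {
        TSocketPoller poller;
        poller.WaitRead(interrupter.Fd(), nullptr); // watch the interrupter descriptor

        poller.WaitI();      // blocks until Interrupt() is called from another thread
        interrupter.Reset(); // consume the wake-up so the next wait blocks again
    }
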
diff --git a/library/cpp/neh/asio/tcp_acceptor_impl.cpp b/library/cpp/neh/asio/tcp_acceptor_impl.cpp
new file mode 100644
index 0000000000..7e1d75fcf5
--- /dev/null
+++ b/library/cpp/neh/asio/tcp_acceptor_impl.cpp
@@ -0,0 +1,25 @@
+#include "tcp_acceptor_impl.h"
+
+using namespace NAsio;
+
+bool TOperationAccept::Execute(int errorCode) {
+ if (errorCode) {
+ H_(errorCode, *this);
+
+ return true;
+ }
+
+ struct sockaddr_storage addr;
+ socklen_t sz = sizeof(addr);
+
+ SOCKET res = ::accept(Fd(), (sockaddr*)&addr, &sz);
+
+ if (res == INVALID_SOCKET) {
+ H_(LastSystemError(), *this);
+ } else {
+ NS_.Assign(res, TEndpoint(new NAddr::TOpaqueAddr((sockaddr*)&addr)));
+ H_(0, *this);
+ }
+
+ return true;
+}
diff --git a/library/cpp/neh/asio/tcp_acceptor_impl.h b/library/cpp/neh/asio/tcp_acceptor_impl.h
new file mode 100644
index 0000000000..c990236efc
--- /dev/null
+++ b/library/cpp/neh/asio/tcp_acceptor_impl.h
@@ -0,0 +1,76 @@
+#pragma once
+
+#include "asio.h"
+
+#include "tcp_socket_impl.h"
+
+namespace NAsio {
+ class TOperationAccept: public TFdOperation {
+ public:
+ TOperationAccept(SOCKET fd, TTcpSocket::TImpl& newSocket, TTcpAcceptor::TAcceptHandler h, TInstant deadline)
+ : TFdOperation(fd, PollRead, deadline)
+ , H_(h)
+ , NS_(newSocket)
+ {
+ }
+
+ bool Execute(int errorCode) override;
+
+ TTcpAcceptor::TAcceptHandler H_;
+ TTcpSocket::TImpl& NS_;
+ };
+
+ class TTcpAcceptor::TImpl: public TThrRefBase {
+ public:
+ TImpl(TIOService::TImpl& srv) noexcept
+ : Srv_(srv)
+ {
+ }
+
+ inline void Bind(TEndpoint& ep, TErrorCode& ec) noexcept {
+ TSocketHolder s(socket(ep.SockAddr()->sa_family, SOCK_STREAM, 0));
+
+ if (s == INVALID_SOCKET) {
+ ec.Assign(LastSystemError());
+ return;
+ }
+
+ FixIPv6ListenSocket(s);
+ CheckedSetSockOpt(s, SOL_SOCKET, SO_REUSEADDR, 1, "reuse addr");
+ SetNonBlock(s);
+
+ if (::bind(s, ep.SockAddr(), ep.SockAddrLen())) {
+ ec.Assign(LastSystemError());
+ return;
+ }
+
+ S_.Swap(s);
+ }
+
+ inline void Listen(int backlog, TErrorCode& ec) noexcept {
+ if (::listen(S_, backlog)) {
+ ec.Assign(LastSystemError());
+ return;
+ }
+ }
+
+ inline void AsyncAccept(TTcpSocket& s, TTcpAcceptor::TAcceptHandler h, TInstant deadline) {
+ Srv_.ScheduleOp(new TOperationAccept((SOCKET)S_, s.GetImpl(), h, deadline)); //set callback
+ }
+
+ inline void AsyncCancel() {
+ Srv_.ScheduleOp(new TOperationCancel<TTcpAcceptor::TImpl>(this));
+ }
+
+ inline TIOService::TImpl& GetIOServiceImpl() const noexcept {
+ return Srv_;
+ }
+
+ inline SOCKET Fd() const noexcept {
+ return S_;
+ }
+
+ private:
+ TIOService::TImpl& Srv_;
+ TSocketHolder S_;
+ };
+}
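
A usage sketch (illustrative only, not part of this commit) for the acceptor: bind, listen and accept a single connection with the wrappers declared in asio.h. The endpoint to listen on is assumed to be prepared by the caller.

    #include <library/cpp/neh/asio/asio.h>

    void AcceptOne(NAsio::TIOService& srv, TEndpoint listenEp) {
        NAsio::TTcpAcceptor acceptor(srv);
        acceptor.Bind(listenEp); // throwing overload; the TErrorCode overloads avoid exceptions
        acceptor.Listen(128);

        NAsio::TTcpSocket accepted(srv);
        acceptor.AsyncAccept(accepted, [&](const NAsio::TErrorCode& err, NAsio::IHandlingContext&) {
            if (!err) {
                Cout << "accepted connection from port " << accepted.RemoteEndpoint().Port() << Endl;
            }
        });

        srv.Run(); // the accept (and anything else scheduled on srv) runs inside this call
    }
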
diff --git a/library/cpp/neh/asio/tcp_socket_impl.cpp b/library/cpp/neh/asio/tcp_socket_impl.cpp
new file mode 100644
index 0000000000..98cef97561
--- /dev/null
+++ b/library/cpp/neh/asio/tcp_socket_impl.cpp
@@ -0,0 +1,117 @@
+#include "tcp_socket_impl.h"
+
+using namespace NAsio;
+
+TSocketOperation::TSocketOperation(TTcpSocket::TImpl& s, TPollType pt, TInstant deadline)
+ : TFdOperation(s.Fd(), pt, deadline)
+ , S_(s)
+{
+}
+
+bool TOperationWrite::Execute(int errorCode) {
+ if (errorCode) {
+ H_(errorCode, Written_, *this);
+
+ return true; //op. completed
+ }
+
+ TErrorCode ec;
+ TContIOVector& iov = *Buffs_->GetIOvec();
+
+ size_t n = S_.WriteSome(iov, ec);
+
+ if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) {
+ H_(ec, Written_ + n, *this);
+
+ return true;
+ }
+
+ if (n) {
+ Written_ += n;
+ iov.Proceed(n);
+ if (!iov.Bytes()) {
+ H_(ec, Written_, *this);
+
+ return true; //op. completed
+ }
+ }
+
+ return false; // operation not completed
+}
+
+bool TOperationWriteVector::Execute(int errorCode) {
+ if (errorCode) {
+ H_(errorCode, Written_, *this);
+
+ return true; //op. completed
+ }
+
+ TErrorCode ec;
+
+ size_t n = S_.WriteSome(V_, ec);
+
+ if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) {
+ H_(ec, Written_ + n, *this);
+
+ return true;
+ }
+
+ if (n) {
+ Written_ += n;
+ V_.Proceed(n);
+ if (!V_.Bytes()) {
+ H_(ec, Written_, *this);
+
+ return true; //op. completed
+ }
+ }
+
+ return false; // operation not completed
+}
+
+bool TOperationReadSome::Execute(int errorCode) {
+ if (errorCode) {
+ H_(errorCode, 0, *this);
+
+ return true; //op. completed
+ }
+
+ TErrorCode ec;
+
+ H_(ec, S_.ReadSome(Buff_, Size_, ec), *this);
+
+ return true;
+}
+
+bool TOperationRead::Execute(int errorCode) {
+ if (errorCode) {
+ H_(errorCode, Read_, *this);
+
+ return true; //op. completed
+ }
+
+ TErrorCode ec;
+ size_t n = S_.ReadSome(Buff_, Size_, ec);
+ Read_ += n;
+
+ if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) {
+ H_(ec, Read_, *this);
+
+ return true; //op. completed
+ }
+
+ if (n) {
+ Size_ -= n;
+ if (!Size_) {
+ H_(ec, Read_, *this);
+
+ return true;
+ }
+ Buff_ += n;
+ } else if (!ec) { // EOF before all requested data was read
+ H_(ec, Read_, *this);
+ return true;
+ }
+
+ return false;
+}
diff --git a/library/cpp/neh/asio/tcp_socket_impl.h b/library/cpp/neh/asio/tcp_socket_impl.h
new file mode 100644
index 0000000000..44f8f42d87
--- /dev/null
+++ b/library/cpp/neh/asio/tcp_socket_impl.h
@@ -0,0 +1,332 @@
+#pragma once
+
+#include "asio.h"
+#include "io_service_impl.h"
+
+#include <sys/uio.h>
+
+#if defined(_bionic_)
+# define IOV_MAX 1024
+#endif
+
+namespace NAsio {
+ // ownership/keep-alive references:
+ // Handlers <- TOperation...(TFdOperation) <- TPollFdEventHandler <- TIOService
+
+ class TSocketOperation: public TFdOperation {
+ public:
+ TSocketOperation(TTcpSocket::TImpl& s, TPollType pt, TInstant deadline);
+
+ protected:
+ TTcpSocket::TImpl& S_;
+ };
+
+ class TOperationConnect: public TSocketOperation {
+ public:
+ TOperationConnect(TTcpSocket::TImpl& s, TTcpSocket::TConnectHandler h, TInstant deadline)
+ : TSocketOperation(s, PollWrite, deadline)
+ , H_(h)
+ {
+ }
+
+ bool Execute(int errorCode) override {
+ H_(errorCode, *this);
+
+ return true;
+ }
+
+ TTcpSocket::TConnectHandler H_;
+ };
+
+ class TOperationConnectFailed: public TSocketOperation {
+ public:
+ TOperationConnectFailed(TTcpSocket::TImpl& s, TTcpSocket::TConnectHandler h, int errorCode, TInstant deadline)
+ : TSocketOperation(s, PollWrite, deadline)
+ , H_(h)
+ , ErrorCode_(errorCode)
+ {
+ Speculative_ = true;
+ }
+
+ bool Execute(int errorCode) override {
+ Y_UNUSED(errorCode);
+ H_(ErrorCode_, *this);
+
+ return true;
+ }
+
+ TTcpSocket::TConnectHandler H_;
+ int ErrorCode_;
+ };
+
+ class TOperationWrite: public TSocketOperation {
+ public:
+ TOperationWrite(TTcpSocket::TImpl& s, NAsio::TTcpSocket::TSendedData& buffs, TTcpSocket::TWriteHandler h, TInstant deadline)
+ : TSocketOperation(s, PollWrite, deadline)
+ , H_(h)
+ , Buffs_(buffs)
+ , Written_(0)
+ {
+ Speculative_ = true;
+ }
+
+ // returns true if no more data needs to be written
+ bool Execute(int errorCode) override;
+
+ private:
+ TTcpSocket::TWriteHandler H_;
+ NAsio::TTcpSocket::TSendedData Buffs_;
+ size_t Written_;
+ };
+
+ class TOperationWriteVector: public TSocketOperation {
+ public:
+ TOperationWriteVector(TTcpSocket::TImpl& s, TContIOVector* v, TTcpSocket::TWriteHandler h, TInstant deadline)
+ : TSocketOperation(s, PollWrite, deadline)
+ , H_(h)
+ , V_(*v)
+ , Written_(0)
+ {
+ Speculative_ = true;
+ }
+
+ // returns true if no more data needs to be written
+ bool Execute(int errorCode) override;
+
+ private:
+ TTcpSocket::TWriteHandler H_;
+ TContIOVector& V_;
+ size_t Written_;
+ };
+
+ class TOperationReadSome: public TSocketOperation {
+ public:
+ TOperationReadSome(TTcpSocket::TImpl& s, void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline)
+ : TSocketOperation(s, PollRead, deadline)
+ , H_(h)
+ , Buff_(static_cast<char*>(buff))
+ , Size_(size)
+ {
+ }
+
+ // returns true if no more data needs to be read
+ bool Execute(int errorCode) override;
+
+ protected:
+ TTcpSocket::TReadHandler H_;
+ char* Buff_;
+ size_t Size_;
+ };
+
+ class TOperationRead: public TOperationReadSome {
+ public:
+ TOperationRead(TTcpSocket::TImpl& s, void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline)
+ : TOperationReadSome(s, buff, size, h, deadline)
+ , Read_(0)
+ {
+ }
+
+ bool Execute(int errorCode) override;
+
+ private:
+ size_t Read_;
+ };
+
+ class TOperationPoll: public TSocketOperation {
+ public:
+ TOperationPoll(TTcpSocket::TImpl& s, TPollType pt, TTcpSocket::TPollHandler h, TInstant deadline)
+ : TSocketOperation(s, pt, deadline)
+ , H_(h)
+ {
+ }
+
+ bool Execute(int errorCode) override {
+ H_(errorCode, *this);
+
+ return true;
+ }
+
+ private:
+ TTcpSocket::TPollHandler H_;
+ };
+
+ template <class T>
+ class TOperationCancel: public TNoneOperation {
+ public:
+ TOperationCancel(T* s)
+ : TNoneOperation()
+ , S_(s)
+ {
+ Speculative_ = true;
+ }
+
+ ~TOperationCancel() override {
+ }
+
+ private:
+ bool Execute(int errorCode) override {
+ Y_UNUSED(errorCode);
+ if (!errorCode && S_->Fd() != INVALID_SOCKET) {
+ S_->GetIOServiceImpl().CancelFdOp(S_->Fd());
+ }
+ return true;
+ }
+
+ TIntrusivePtr<T> S_;
+ };
+
+ class TTcpSocket::TImpl: public TNonCopyable, public TThrRefBase {
+ public:
+ typedef TTcpSocket::TSendedData TSendedData;
+
+ TImpl(TIOService::TImpl& srv) noexcept
+ : Srv_(srv)
+ {
+ }
+
+ ~TImpl() override {
+ DBGOUT("TSocket::~TImpl()");
+ }
+
+ void Assign(SOCKET fd, TEndpoint ep) {
+ TSocketHolder(fd).Swap(S_);
+ RemoteEndpoint_ = ep;
+ }
+
+ void AsyncConnect(const TEndpoint& ep, TTcpSocket::TConnectHandler h, TInstant deadline) {
+ TSocketHolder s(socket(ep.SockAddr()->sa_family, SOCK_STREAM, 0));
+
+ if (Y_UNLIKELY(s == INVALID_SOCKET || Srv_.HasAbort())) {
+ throw TSystemError() << TStringBuf("can't create socket");
+ }
+
+ SetNonBlock(s);
+
+ int err;
+ do {
+ err = connect(s, ep.SockAddr(), (int)ep.SockAddrLen());
+ if (Y_LIKELY(err)) {
+ err = LastSystemError();
+ }
+#if defined(_freebsd_)
+ if (Y_UNLIKELY(err == EINTR)) {
+ err = EINPROGRESS;
+ }
+ } while (0);
+#elif defined(_linux_)
+ } while (Y_UNLIKELY(err == EINTR));
+#else
+ } while (0);
+#endif
+
+ RemoteEndpoint_ = ep;
+ S_.Swap(s);
+
+ DBGOUT("AsyncConnect(): " << err);
+ if (Y_LIKELY(err == EINPROGRESS || err == EWOULDBLOCK || err == 0)) {
+ Srv_.ScheduleOp(new TOperationConnect(*this, h, deadline)); //set callback
+ } else {
+ Srv_.ScheduleOp(new TOperationConnectFailed(*this, h, err, deadline)); //set callback
+ }
+ }
+
+ inline void AsyncWrite(TSendedData& d, TTcpSocket::TWriteHandler h, TInstant deadline) {
+ Srv_.ScheduleOp(new TOperationWrite(*this, d, h, deadline));
+ }
+
+ inline void AsyncWrite(TContIOVector* v, TTcpSocket::TWriteHandler h, TInstant deadline) {
+ Srv_.ScheduleOp(new TOperationWriteVector(*this, v, h, deadline));
+ }
+
+ inline void AsyncRead(void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) {
+ Srv_.ScheduleOp(new TOperationRead(*this, buff, size, h, deadline));
+ }
+
+ inline void AsyncReadSome(void* buff, size_t size, TTcpSocket::TReadHandler h, TInstant deadline) {
+ Srv_.ScheduleOp(new TOperationReadSome(*this, buff, size, h, deadline));
+ }
+
+ inline void AsyncPollWrite(TTcpSocket::TPollHandler h, TInstant deadline) {
+ Srv_.ScheduleOp(new TOperationPoll(*this, TOperationPoll::PollWrite, h, deadline));
+ }
+
+ inline void AsyncPollRead(TTcpSocket::TPollHandler h, TInstant deadline) {
+ Srv_.ScheduleOp(new TOperationPoll(*this, TOperationPoll::PollRead, h, deadline));
+ }
+
+ inline void AsyncCancel() {
+ if (Y_UNLIKELY(Srv_.HasAbort())) {
+ return;
+ }
+ Srv_.ScheduleOp(new TOperationCancel<TTcpSocket::TImpl>(this));
+ }
+
+ inline bool SysCallHasResult(ssize_t& n, TErrorCode& ec) noexcept {
+ if (n >= 0) {
+ return true;
+ }
+
+ int errn = LastSystemError();
+ if (errn == EINTR) {
+ return false;
+ }
+
+ ec.Assign(errn);
+ n = 0;
+ return true;
+ }
+
+ size_t WriteSome(TContIOVector& iov, TErrorCode& ec) noexcept {
+ for (;;) {
+ ssize_t n = writev(S_, (const iovec*)iov.Parts(), Min(IOV_MAX, (int)iov.Count()));
+ DBGOUT("WriteSome(): n=" << n);
+ if (SysCallHasResult(n, ec)) {
+ return n;
+ }
+ }
+ }
+
+ size_t WriteSome(const void* buff, size_t size, TErrorCode& ec) noexcept {
+ for (;;) {
+ ssize_t n = send(S_, (char*)buff, size, 0);
+ DBGOUT("WriteSome(): n=" << n);
+ if (SysCallHasResult(n, ec)) {
+ return n;
+ }
+ }
+ }
+
+ size_t ReadSome(void* buff, size_t size, TErrorCode& ec) noexcept {
+ for (;;) {
+ ssize_t n = recv(S_, (char*)buff, size, 0);
+ DBGOUT("ReadSome(): n=" << n);
+ if (SysCallHasResult(n, ec)) {
+ return n;
+ }
+ }
+ }
+
+ inline void Shutdown(TTcpSocket::TShutdownMode mode, TErrorCode& ec) {
+ if (shutdown(S_, mode)) {
+ ec.Assign(LastSystemError());
+ }
+ }
+
+ TIOService::TImpl& GetIOServiceImpl() const noexcept {
+ return Srv_;
+ }
+
+ inline SOCKET Fd() const noexcept {
+ return S_;
+ }
+
+ TEndpoint RemoteEndpoint() const {
+ return RemoteEndpoint_;
+ }
+
+ private:
+ TIOService::TImpl& Srv_;
+ TSocketHolder S_;
+ TEndpoint RemoteEndpoint_;
+ };
+}
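
A client-side sketch (illustrative only, not part of this commit) that chains the async operations declared in asio.h (connect, then write, then read). All buffers and the socket outlive the exchange because Run() is called in the same scope, which matches the ownership notes at the top of this file.

    #include <library/cpp/neh/asio/asio.h>

    void PingServer(NAsio::TIOService& srv, TEndpoint serverEp) {
        NAsio::TTcpSocket sock(srv);
        const char request[] = "PING\n";
        char response[256];

        sock.AsyncConnect(serverEp, [&](const NAsio::TErrorCode& connErr, NAsio::IHandlingContext&) {
            if (connErr) {
                return; // connection failed or timed out
            }
            sock.AsyncWrite(request, sizeof(request) - 1,
                [&](const NAsio::TErrorCode&, size_t /*written*/, NAsio::IHandlingContext&) {
                    sock.AsyncReadSome(response, sizeof(response),
                        [&](const NAsio::TErrorCode&, size_t n, NAsio::IHandlingContext&) {
                            Cout << "got " << n << " bytes" << Endl;
                        });
                });
        }, TDuration::Seconds(5)); // per-operation deadline

        srv.Run(); // all three operations complete inside this call, so the locals stay valid
    }
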
diff --git a/library/cpp/neh/conn_cache.cpp b/library/cpp/neh/conn_cache.cpp
new file mode 100644
index 0000000000..a2f0868733
--- /dev/null
+++ b/library/cpp/neh/conn_cache.cpp
@@ -0,0 +1 @@
+#include "conn_cache.h"
diff --git a/library/cpp/neh/conn_cache.h b/library/cpp/neh/conn_cache.h
new file mode 100644
index 0000000000..944ba4cfec
--- /dev/null
+++ b/library/cpp/neh/conn_cache.h
@@ -0,0 +1,149 @@
+#pragma once
+
+#include <string.h>
+
+#include <util/generic/ptr.h>
+#include <util/generic/singleton.h>
+#include <util/generic/string.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+#include <util/thread/lfqueue.h>
+
+#include "http_common.h"
+
+namespace NNeh {
+ namespace NHttp2 {
+ // TConn must be refcounted and contain methods:
+ // void SetCached(bool) noexcept
+ // bool IsValid() const noexcept
+ // void Close() noexcept
+ template <class TConn>
+ class TConnCache {
+ struct TCounter : TAtomicCounter {
+ inline void IncCount(const TConn* const&) {
+ Inc();
+ }
+
+ inline void DecCount(const TConn* const&) {
+ Dec();
+ }
+ };
+
+ public:
+ typedef TIntrusivePtr<TConn> TConnRef;
+
+ class TConnList: public TLockFreeQueue<TConn*, TCounter> {
+ public:
+ ~TConnList() {
+ Clear();
+ }
+
+ inline void Clear() {
+ TConn* conn;
+
+ while (this->Dequeue(&conn)) {
+ conn->Close();
+ conn->UnRef();
+ }
+ }
+
+ inline size_t Size() {
+ return this->GetCounter().Val();
+ }
+ };
+
+ inline void Put(TConnRef& conn, size_t addrId) {
+ conn->SetCached(true);
+ ConnList(addrId).Enqueue(conn.Get());
+ conn->Ref();
+ Y_UNUSED(conn.Release());
+ CachedConn_.Inc();
+ }
+
+ bool Get(TConnRef& conn, size_t addrId) {
+ TConnList& connList = ConnList(addrId);
+ TConn* connTmp;
+
+ while (connList.Dequeue(&connTmp)) {
+ connTmp->SetCached(false);
+ CachedConn_.Dec();
+ if (connTmp->IsValid()) {
+ TConnRef(connTmp).Swap(conn);
+ connTmp->DecRef();
+
+ return true;
+ } else {
+ connTmp->UnRef();
+ }
+ }
+ return false;
+ }
+
+ inline size_t Size() const noexcept {
+ return CachedConn_.Val();
+ }
+
+ inline size_t Validate(size_t addrId) {
+ TConnList& cl = Lst_.Get(addrId);
+ return Validate(cl);
+ }
+
+ //close/remove part of the connections from cache
+ size_t Purge(size_t addrId, size_t frac256) {
+ TConnList& cl = Lst_.Get(addrId);
+ size_t qsize = cl.Size();
+ if (!qsize) {
+ return 0;
+ }
+
+ size_t purgeCounter = ((qsize * frac256) >> 8);
+ if (!purgeCounter && qsize >= 2) {
+ purgeCounter = 1;
+ }
+
+ size_t pc = 0;
+ {
+ TConn* conn;
+ while (purgeCounter-- && cl.Dequeue(&conn)) {
+ conn->SetCached(false);
+ if (conn->IsValid()) {
+ conn->Close();
+ }
+ CachedConn_.Dec();
+ conn->UnRef();
+ ++pc;
+ }
+ }
+ pc += Validate(cl);
+
+ return pc;
+ }
+
+ private:
+ inline TConnList& ConnList(size_t addrId) {
+ return Lst_.Get(addrId);
+ }
+
+ inline size_t Validate(TConnList& cl) {
+ size_t pc = 0;
+ size_t nc = cl.Size();
+
+ TConn* conn;
+ while (nc-- && cl.Dequeue(&conn)) {
+ if (conn->IsValid()) {
+ cl.Enqueue(conn);
+ } else {
+ ++pc;
+ conn->SetCached(false);
+ CachedConn_.Dec();
+ conn->UnRef();
+ }
+ }
+
+ return pc;
+ }
+
+ NNeh::NHttp::TLockFreeSequence<TConnList> Lst_;
+ TAtomicCounter CachedConn_;
+ };
+ }
+}
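The contract documented at the top of this header (a refcounted TConn exposing SetCached/IsValid/Close) can be illustrated with a minimal sketch. TDummyConn below is hypothetical and only shows the required surface; it assumes util/generic/ptr.h for TThrRefBase/TIntrusivePtr and the library/cpp/neh include path.

    #include <library/cpp/neh/conn_cache.h>

    #include <util/generic/ptr.h>

    // Hypothetical connection type satisfying the TConnCache requirements.
    class TDummyConn: public TThrRefBase {
    public:
        void SetCached(bool v) noexcept {
            Cached_ = v;
        }
        bool IsValid() const noexcept {
            return !Closed_;
        }
        void Close() noexcept {
            Closed_ = true;
        }

    private:
        bool Cached_ = false;
        bool Closed_ = false;
    };

    void CacheRoundTrip() {
        NNeh::NHttp2::TConnCache<TDummyConn> cache;
        TIntrusivePtr<TDummyConn> conn(new TDummyConn());
        cache.Put(conn, /*addrId=*/0); // ownership moves into the cache
        TIntrusivePtr<TDummyConn> reused;
        if (cache.Get(reused, /*addrId=*/0)) {
            // got a validated connection back for this address id
        }
    }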
diff --git a/library/cpp/neh/details.h b/library/cpp/neh/details.h
new file mode 100644
index 0000000000..cf3b6fd765
--- /dev/null
+++ b/library/cpp/neh/details.h
@@ -0,0 +1,99 @@
+#pragma once
+
+#include "neh.h"
+#include <library/cpp/neh/utils.h>
+
+#include <library/cpp/http/io/headers.h>
+
+#include <util/generic/singleton.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+
+namespace NNeh {
+ class TNotifyHandle: public THandle {
+ public:
+ inline TNotifyHandle(IOnRecv* r, const TMessage& msg, TStatCollector* s = nullptr) noexcept
+ : THandle(r, s)
+ , Msg_(msg)
+ , StartTime_(TInstant::Now())
+ {
+ }
+
+ void NotifyResponse(const TString& resp, const TString& firstLine = {}, const THttpHeaders& headers = Default<THttpHeaders>()) {
+ Notify(new TResponse(Msg_, resp, ExecDuration(), firstLine, headers));
+ }
+
+ void NotifyError(const TString& errorText) {
+ Notify(TResponse::FromError(Msg_, new TError(errorText), ExecDuration()));
+ }
+
+ void NotifyError(TErrorRef error) {
+ Notify(TResponse::FromError(Msg_, error, ExecDuration()));
+ }
+
+ /** Called when the answer is received and the response has headers and a first line.
+ */
+ void NotifyError(TErrorRef error, const TString& data, const TString& firstLine, const THttpHeaders& headers) {
+ Notify(TResponse::FromError(Msg_, error, data, ExecDuration(), firstLine, headers));
+ }
+
+ const TMessage& Message() const noexcept {
+ return Msg_;
+ }
+
+ private:
+ inline TDuration ExecDuration() const {
+ TInstant now = TInstant::Now();
+ if (now > StartTime_) {
+ return now - StartTime_;
+ }
+
+ return TDuration::Zero();
+ }
+
+ const TMessage Msg_;
+ const TInstant StartTime_;
+ };
+
+ typedef TIntrusivePtr<TNotifyHandle> TNotifyHandleRef;
+
+ class TSimpleHandle: public TNotifyHandle {
+ public:
+ inline TSimpleHandle(IOnRecv* r, const TMessage& msg, TStatCollector* s = nullptr) noexcept
+ : TNotifyHandle(r, msg, s)
+ , SendComplete_(false)
+ , Canceled_(false)
+ {
+ }
+
+ bool MessageSendedCompletely() const noexcept override {
+ return SendComplete_;
+ }
+
+ void Cancel() noexcept override {
+ Canceled_ = true;
+ THandle::Cancel();
+ }
+
+ inline void SetSendComplete() noexcept {
+ SendComplete_ = true;
+ }
+
+ inline bool Canceled() const noexcept {
+ return Canceled_;
+ }
+
+ inline const TAtomicBool* CanceledPtr() const noexcept {
+ return &Canceled_;
+ }
+
+ void ResetOnRecv() noexcept {
+ F_ = nullptr;
+ }
+
+ private:
+ TAtomicBool SendComplete_;
+ TAtomicBool Canceled_;
+ };
+
+ typedef TIntrusivePtr<TSimpleHandle> TSimpleHandleRef;
+}
diff --git a/library/cpp/neh/factory.cpp b/library/cpp/neh/factory.cpp
new file mode 100644
index 0000000000..5b9fd700e6
--- /dev/null
+++ b/library/cpp/neh/factory.cpp
@@ -0,0 +1,67 @@
+#include "factory.h"
+#include "udp.h"
+#include "netliba.h"
+#include "https.h"
+#include "http2.h"
+#include "inproc.h"
+#include "tcp.h"
+#include "tcp2.h"
+
+#include <util/generic/hash.h>
+#include <util/generic/strbuf.h>
+#include <util/generic/singleton.h>
+
+using namespace NNeh;
+
+namespace {
+ class TProtocolFactory: public IProtocolFactory, public THashMap<TStringBuf, IProtocol*> {
+ public:
+ inline TProtocolFactory() {
+ Register(NetLibaProtocol());
+ Register(Http1Protocol());
+ Register(Post1Protocol());
+ Register(Full1Protocol());
+ Register(UdpProtocol());
+ Register(InProcProtocol());
+ Register(TcpProtocol());
+ Register(Tcp2Protocol());
+ Register(Http2Protocol());
+ Register(Post2Protocol());
+ Register(Full2Protocol());
+ Register(SSLGetProtocol());
+ Register(SSLPostProtocol());
+ Register(SSLFullProtocol());
+ Register(UnixSocketGetProtocol());
+ Register(UnixSocketPostProtocol());
+ Register(UnixSocketFullProtocol());
+ }
+
+ IProtocol* Protocol(const TStringBuf& proto) override {
+ const_iterator it = find(proto);
+
+ if (it == end()) {
+ ythrow yexception() << "unsupported scheme " << proto;
+ }
+
+ return it->second;
+ }
+
+ void Register(IProtocol* proto) override {
+ (*this)[proto->Scheme()] = proto;
+ }
+ };
+
+ IProtocolFactory* GLOBAL_FACTORY = nullptr;
+}
+
+void NNeh::SetGlobalFactory(IProtocolFactory* factory) {
+ GLOBAL_FACTORY = factory;
+}
+
+IProtocolFactory* NNeh::ProtocolFactory() {
+ if (GLOBAL_FACTORY) {
+ return GLOBAL_FACTORY;
+ }
+
+ return Singleton<TProtocolFactory>();
+}
diff --git a/library/cpp/neh/factory.h b/library/cpp/neh/factory.h
new file mode 100644
index 0000000000..17bebef8ed
--- /dev/null
+++ b/library/cpp/neh/factory.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include "neh.h"
+#include "rpc.h"
+
+namespace NNeh {
+ struct TParsedLocation;
+
+ class IProtocol {
+ public:
+ virtual ~IProtocol() {
+ }
+ virtual IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) = 0;
+ virtual THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef&) = 0;
+ virtual THandleRef ScheduleAsyncRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& statRef, bool useAsyncSendRequest = false) {
+ Y_UNUSED(useAsyncSendRequest);
+ return ScheduleRequest(msg, fallback, statRef);
+ }
+ virtual TStringBuf Scheme() const noexcept = 0;
+ virtual bool SetOption(TStringBuf name, TStringBuf value) {
+ Y_UNUSED(name);
+ Y_UNUSED(value);
+ return false;
+ }
+ };
+
+ class IProtocolFactory {
+ public:
+ virtual IProtocol* Protocol(const TStringBuf& scheme) = 0;
+ virtual void Register(IProtocol* proto) = 0;
+ virtual ~IProtocolFactory() {
+ }
+ };
+
+ void SetGlobalFactory(IProtocolFactory* factory);
+ IProtocolFactory* ProtocolFactory();
+}
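A rough usage sketch for the factory interface above (not taken from the library's own callers): protocols are looked up by URL scheme through ProtocolFactory(), and additional IProtocol implementations become visible via Register(). TMyProtocol is hypothetical.

    #include <library/cpp/neh/factory.h>

    using namespace NNeh;

    // Sketch: dispatch by scheme; Protocol() throws for an unsupported scheme.
    IProtocol* PickProtocol(const TStringBuf scheme) {
        return ProtocolFactory()->Protocol(scheme);
    }

    // Making a custom protocol available before issuing requests
    // (TMyProtocol is a hypothetical IProtocol implementation):
    //     ProtocolFactory()->Register(Singleton<TMyProtocol>());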
diff --git a/library/cpp/neh/http2.cpp b/library/cpp/neh/http2.cpp
new file mode 100644
index 0000000000..0bba29cf22
--- /dev/null
+++ b/library/cpp/neh/http2.cpp
@@ -0,0 +1,2102 @@
+#include "http2.h"
+
+#include "conn_cache.h"
+#include "details.h"
+#include "factory.h"
+#include "http_common.h"
+#include "smart_ptr.h"
+#include "utils.h"
+
+#include <library/cpp/http/push_parser/http_parser.h>
+#include <library/cpp/http/misc/httpcodes.h>
+#include <library/cpp/http/misc/parsed_request.h>
+#include <library/cpp/neh/asio/executor.h>
+
+#include <util/generic/singleton.h>
+#include <util/generic/vector.h>
+#include <util/network/iovec.h>
+#include <util/stream/output.h>
+#include <util/stream/zlib.h>
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+#include <util/system/spinlock.h>
+#include <util/system/yassert.h>
+#include <util/thread/factory.h>
+#include <util/thread/singleton.h>
+#include <util/system/sanitizers.h>
+
+#include <atomic>
+
+#if defined(_unix_)
+#include <sys/ioctl.h>
+#endif
+
+#if defined(_linux_)
+#undef SIOCGSTAMP
+#undef SIOCGSTAMPNS
+#include <linux/sockios.h>
+#define FIONWRITE SIOCOUTQ
+#endif
+
+//#define DEBUG_HTTP2
+
+#ifdef DEBUG_HTTP2
+#define DBGOUT(args) Cout << args << Endl;
+#else
+#define DBGOUT(args)
+#endif
+
+using namespace NDns;
+using namespace NAsio;
+using namespace NNeh;
+using namespace NNeh::NHttp;
+using namespace NNeh::NHttp2;
+using namespace std::placeholders;
+
+//
+// has complex keep-alive references between entities in a multi-threaded environment,
+// which creates risks of races, memory leaks, etc.,
+// so connecting/disconnecting entities must be done carefully
+//
+// handler <=-> request <==> connection(socket) <= handlers, stored in io_service
+// ^
+// +== connections_cache
+// '=>' -- shared/intrusive ptr
+// '->' -- weak_ptr
+//
+
+static TDuration FixTimeoutForSanitizer(const TDuration timeout) {
+ ui64 multiplier = 1;
+ if (NSan::ASanIsOn()) {
+ // https://github.com/google/sanitizers/wiki/AddressSanitizer
+ multiplier = 4;
+ } else if (NSan::MSanIsOn()) {
+ // via https://github.com/google/sanitizers/wiki/MemorySanitizer
+ multiplier = 3;
+ } else if (NSan::TSanIsOn()) {
+ // via https://clang.llvm.org/docs/ThreadSanitizer.html
+ multiplier = 15;
+ }
+
+ return TDuration::FromValue(timeout.GetValue() * multiplier);
+}
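As a worked illustration of the multipliers above: a 1000 ms timeout stays 1000 ms in a normal build and grows under sanitizers.

    1000 ms * 4  = 4000 ms   (ASan)
    1000 ms * 3  = 3000 ms   (MSan)
    1000 ms * 15 = 15000 ms  (TSan)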
+
+TDuration THttp2Options::ConnectTimeout = FixTimeoutForSanitizer(TDuration::MilliSeconds(1000));
+TDuration THttp2Options::InputDeadline = TDuration::Max();
+TDuration THttp2Options::OutputDeadline = TDuration::Max();
+TDuration THttp2Options::SymptomSlowConnect = FixTimeoutForSanitizer(TDuration::MilliSeconds(10));
+size_t THttp2Options::InputBufferSize = 16 * 1024;
+bool THttp2Options::KeepInputBufferForCachedConnections = false;
+size_t THttp2Options::AsioThreads = 4;
+size_t THttp2Options::AsioServerThreads = 4;
+bool THttp2Options::EnsureSendingCompleteByAck = false;
+int THttp2Options::Backlog = 100;
+TDuration THttp2Options::ServerInputDeadline = FixTimeoutForSanitizer(TDuration::MilliSeconds(500));
+TDuration THttp2Options::ServerOutputDeadline = TDuration::Max();
+TDuration THttp2Options::ServerInputDeadlineKeepAliveMax = FixTimeoutForSanitizer(TDuration::Seconds(120));
+TDuration THttp2Options::ServerInputDeadlineKeepAliveMin = FixTimeoutForSanitizer(TDuration::Seconds(10));
+bool THttp2Options::ServerUseDirectWrite = false;
+bool THttp2Options::UseResponseAsErrorMessage = false;
+bool THttp2Options::FullHeadersAsErrorMessage = false;
+bool THttp2Options::ErrorDetailsAsResponseBody = false;
+bool THttp2Options::RedirectionNotError = false;
+bool THttp2Options::AnyResponseIsNotError = false;
+bool THttp2Options::TcpKeepAlive = false;
+i32 THttp2Options::LimitRequestsPerConnection = -1;
+bool THttp2Options::QuickAck = false;
+bool THttp2Options::UseAsyncSendRequest = false;
+
+bool THttp2Options::Set(TStringBuf name, TStringBuf value) {
+#define HTTP2_TRY_SET(optType, optName) \
+ if (name == TStringBuf(#optName)) { \
+ optName = FromString<optType>(value); \
+ }
+
+ HTTP2_TRY_SET(TDuration, ConnectTimeout)
+ else HTTP2_TRY_SET(TDuration, InputDeadline)
+ else HTTP2_TRY_SET(TDuration, OutputDeadline)
+ else HTTP2_TRY_SET(TDuration, SymptomSlowConnect)
+ else HTTP2_TRY_SET(size_t, InputBufferSize)
+ else HTTP2_TRY_SET(bool, KeepInputBufferForCachedConnections)
+ else HTTP2_TRY_SET(size_t, AsioThreads)
+ else HTTP2_TRY_SET(size_t, AsioServerThreads)
+ else HTTP2_TRY_SET(bool, EnsureSendingCompleteByAck)
+ else HTTP2_TRY_SET(int, Backlog)
+ else HTTP2_TRY_SET(TDuration, ServerInputDeadline)
+ else HTTP2_TRY_SET(TDuration, ServerOutputDeadline)
+ else HTTP2_TRY_SET(TDuration, ServerInputDeadlineKeepAliveMax)
+ else HTTP2_TRY_SET(TDuration, ServerInputDeadlineKeepAliveMin)
+ else HTTP2_TRY_SET(bool, ServerUseDirectWrite)
+ else HTTP2_TRY_SET(bool, UseResponseAsErrorMessage)
+ else HTTP2_TRY_SET(bool, FullHeadersAsErrorMessage)
+ else HTTP2_TRY_SET(bool, ErrorDetailsAsResponseBody)
+ else HTTP2_TRY_SET(bool, RedirectionNotError)
+ else HTTP2_TRY_SET(bool, AnyResponseIsNotError)
+ else HTTP2_TRY_SET(bool, TcpKeepAlive)
+ else HTTP2_TRY_SET(i32, LimitRequestsPerConnection)
+ else HTTP2_TRY_SET(bool, QuickAck)
+ else HTTP2_TRY_SET(bool, UseAsyncSendRequest)
+ else {
+ return false;
+ }
+ return true;
+}
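A minimal usage sketch for the option setter above, assuming it runs once at startup; the option values shown are arbitrary examples, not recommended settings.

    #include <library/cpp/neh/http2.h>

    #include <util/generic/yexception.h>

    // Sketch: tune neh http2 options by name before issuing requests.
    // THttp2Options::Set() returns false for an unknown option name.
    void TuneHttp2Options() {
        using NNeh::THttp2Options;
        const bool ok = THttp2Options::Set("ConnectTimeout", "300ms") // FromString<TDuration>
                     && THttp2Options::Set("AsioThreads", "8")        // FromString<size_t>
                     && THttp2Options::Set("TcpKeepAlive", "1");      // FromString<bool>
        Y_ENSURE(ok, "unknown neh http2 option");
    }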
+
+namespace NNeh {
+ const NDns::TResolvedHost* Resolve(const TStringBuf host, ui16 port, NHttp::EResolverType resolverType);
+}
+
+namespace {
+//#define DEBUG_STAT
+
+#ifdef DEBUG_STAT
+ struct TDebugStat {
+ static std::atomic<size_t> ConnTotal;
+ static std::atomic<size_t> ConnActive;
+ static std::atomic<size_t> ConnCached;
+ static std::atomic<size_t> ConnDestroyed;
+ static std::atomic<size_t> ConnFailed;
+ static std::atomic<size_t> ConnConnCanceled;
+ static std::atomic<size_t> ConnSlow;
+ static std::atomic<size_t> Conn2Success;
+ static std::atomic<size_t> ConnPurgedInCache;
+ static std::atomic<size_t> ConnDestroyedInCache;
+ static std::atomic<size_t> RequestTotal;
+ static std::atomic<size_t> RequestSuccessed;
+ static std::atomic<size_t> RequestFailed;
+ static void Print() {
+ Cout << "ct=" << ConnTotal.load(std::memory_order_acquire)
+ << " ca=" << ConnActive.load(std::memory_order_acquire)
+ << " cch=" << ConnCached.load(std::memory_order_acquire)
+ << " cd=" << ConnDestroyed.load(std::memory_order_acquire)
+ << " cf=" << ConnFailed.load(std::memory_order_acquire)
+ << " ccc=" << ConnConnCanceled.load(std::memory_order_acquire)
+ << " csl=" << ConnSlow.load(std::memory_order_acquire)
+ << " c2s=" << Conn2Success.load(std::memory_order_acquire)
+ << " cpc=" << ConnPurgedInCache.load(std::memory_order_acquire)
+ << " cdc=" << ConnDestroyedInCache.load(std::memory_order_acquire)
+ << " rt=" << RequestTotal.load(std::memory_order_acquire)
+ << " rs=" << RequestSuccessed.load(std::memory_order_acquire)
+ << " rf=" << RequestFailed.load(std::memory_order_acquire)
+ << Endl;
+ }
+ };
+ std::atomic<size_t> TDebugStat::ConnTotal = 0;
+ std::atomic<size_t> TDebugStat::ConnActive = 0;
+ std::atomic<size_t> TDebugStat::ConnCached = 0;
+ std::atomic<size_t> TDebugStat::ConnDestroyed = 0;
+ std::atomic<size_t> TDebugStat::ConnFailed = 0;
+ std::atomic<size_t> TDebugStat::ConnConnCanceled = 0;
+ std::atomic<size_t> TDebugStat::ConnSlow = 0;
+ std::atomic<size_t> TDebugStat::Conn2Success = 0;
+ std::atomic<size_t> TDebugStat::ConnPurgedInCache = 0;
+ std::atomic<size_t> TDebugStat::ConnDestroyedInCache = 0;
+ std::atomic<size_t> TDebugStat::RequestTotal = 0;
+ std::atomic<size_t> TDebugStat::RequestSuccessed = 0;
+ std::atomic<size_t> TDebugStat::RequestFailed = 0;
+#endif
+
+ inline void PrepareSocket(SOCKET s, const TRequestSettings& requestSettings = TRequestSettings()) {
+ if (requestSettings.NoDelay) {
+ SetNoDelay(s, true);
+ }
+ }
+
+ bool Compress(TData& data, const TString& compressionScheme) {
+ if (compressionScheme == "gzip") {
+ try {
+ TData gzipped(data.size());
+ TMemoryOutput out(gzipped.data(), gzipped.size());
+ TZLibCompress c(&out, ZLib::GZip);
+ c.Write(data.data(), data.size());
+ c.Finish();
+ gzipped.resize(out.Buf() - gzipped.data());
+ data.swap(gzipped);
+ return true;
+ } catch (yexception&) {
+ // gzipped data occupies more space than original data
+ }
+ }
+ return false;
+ }
+
+ class THttpRequestBuffers: public NAsio::TTcpSocket::IBuffers {
+ public:
+ THttpRequestBuffers(TRequestData::TPtr rd)
+ : Req_(rd)
+ , Parts_(Req_->Parts())
+ , IOvec_(Parts_.data(), Parts_.size())
+ {
+ }
+
+ TContIOVector* GetIOvec() override {
+ return &IOvec_;
+ }
+
+ private:
+ TRequestData::TPtr Req_;
+ TVector<IOutputStream::TPart> Parts_;
+ TContIOVector IOvec_;
+ };
+
+ struct TRequestGet1: public TRequestGet {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("http");
+ }
+ };
+
+ struct TRequestPost1: public TRequestPost {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("post");
+ }
+ };
+
+ struct TRequestFull1: public TRequestFull {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("full");
+ }
+ };
+
+ struct TRequestGet2: public TRequestGet {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("http2");
+ }
+ };
+
+ struct TRequestPost2: public TRequestPost {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("post2");
+ }
+ };
+
+ struct TRequestFull2: public TRequestFull {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("full2");
+ }
+ };
+
+ struct TRequestUnixSocketGet: public TRequestGet {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("http+unix");
+ }
+
+ static TRequestSettings RequestSettings() {
+ return TRequestSettings()
+ .SetNoDelay(false)
+ .SetResolverType(EResolverType::EUNIXSOCKET);
+ }
+ };
+
+ struct TRequestUnixSocketPost: public TRequestPost {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("post+unix");
+ }
+
+ static TRequestSettings RequestSettings() {
+ return TRequestSettings()
+ .SetNoDelay(false)
+ .SetResolverType(EResolverType::EUNIXSOCKET);
+ }
+ };
+
+ struct TRequestUnixSocketFull: public TRequestFull {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("full+unix");
+ }
+
+ static TRequestSettings RequestSettings() {
+ return TRequestSettings()
+ .SetNoDelay(false)
+ .SetResolverType(EResolverType::EUNIXSOCKET);
+ }
+ };
+
+ typedef TAutoPtr<THttpRequestBuffers> THttpRequestBuffersPtr;
+
+ class THttpRequest;
+ typedef TSharedPtrB<THttpRequest> THttpRequestRef;
+
+ class THttpConn;
+ typedef TIntrusivePtr<THttpConn> THttpConnRef;
+
+ typedef std::function<TRequestData::TPtr(const TMessage&, const TParsedLocation&)> TRequestBuilder;
+
+ class THttpRequest {
+ public:
+ class THandle: public TSimpleHandle {
+ public:
+ THandle(IOnRecv* f, const TMessage& msg, TStatCollector* s) noexcept
+ : TSimpleHandle(f, msg, s)
+ {
+ }
+
+ bool MessageSendedCompletely() const noexcept override {
+ if (TSimpleHandle::MessageSendedCompletely()) {
+ return true;
+ }
+
+ THttpRequestRef req(GetRequest());
+ if (!!req && req->RequestSendedCompletely()) {
+ const_cast<THandle*>(this)->SetSendComplete();
+ }
+
+ return TSimpleHandle::MessageSendedCompletely();
+ }
+
+ void Cancel() noexcept override {
+ if (TSimpleHandle::Canceled()) {
+ return;
+ }
+
+ THttpRequestRef req(GetRequest());
+ if (!!req) {
+ TSimpleHandle::Cancel();
+ req->Cancel();
+ }
+ }
+
+ void NotifyError(TErrorRef error, const THttpParser* rsp = nullptr) {
+#ifdef DEBUG_STAT
+ ++TDebugStat::RequestFailed;
+#endif
+ if (rsp) {
+ TSimpleHandle::NotifyError(error, rsp->DecodedContent(), rsp->FirstLine(), rsp->Headers());
+ } else {
+ TSimpleHandle::NotifyError(error);
+ }
+
+ ReleaseRequest();
+ }
+
+ //not thread safe!
+ void SetRequest(const TWeakPtrB<THttpRequest>& r) noexcept {
+ Req_ = r;
+ }
+
+ private:
+ THttpRequestRef GetRequest() const noexcept {
+ TGuard<TSpinLock> g(SP_);
+ return Req_;
+ }
+
+ void ReleaseRequest() noexcept {
+ TWeakPtrB<THttpRequest> tmp;
+ TGuard<TSpinLock> g(SP_);
+ tmp.Swap(Req_);
+ }
+
+ mutable TSpinLock SP_;
+ TWeakPtrB<THttpRequest> Req_;
+ };
+
+ typedef TIntrusivePtr<THandle> THandleRef;
+
+ static void Run(THandleRef& h, const TMessage& msg, TRequestBuilder f, const TRequestSettings& s) {
+ THttpRequestRef req(new THttpRequest(h, msg, f, s));
+ req->WeakThis_ = req;
+ h->SetRequest(req->WeakThis_);
+ req->Run(req);
+ }
+
+ ~THttpRequest() {
+ DBGOUT("~THttpRequest()");
+ }
+
+ private:
+ THttpRequest(THandleRef& h, TMessage msg, TRequestBuilder f, const TRequestSettings& s)
+ : Hndl_(h)
+ , RequestBuilder_(f)
+ , RequestSettings_(s)
+ , Msg_(std::move(msg))
+ , Loc_(Msg_.Addr)
+ , Addr_(Resolve(Loc_.Host, Loc_.GetPort(), RequestSettings_.ResolverType))
+ , AddrIter_(Addr_->Addr.Begin())
+ , Canceled_(false)
+ , RequestSendedCompletely_(false)
+ {
+ }
+
+ void Run(THttpRequestRef& req);
+
+ public:
+ THttpRequestBuffersPtr BuildRequest() {
+ return new THttpRequestBuffers(RequestBuilder_(Msg_, Loc_));
+ }
+
+ TRequestSettings RequestSettings() {
+ return RequestSettings_;
+ }
+
+ //can create a spare socket in an attempt to decrease connecting time
+ void OnDetectSlowConnecting();
+
+ //remove the extra connection on successful connect
+ void OnConnect(THttpConn* c);
+
+ //have some response input
+ void OnBeginRead() noexcept {
+ RequestSendedCompletely_ = true;
+ }
+
+ void OnResponse(TAutoPtr<THttpParser>& rsp);
+
+ void OnConnectFailed(THttpConn* c, const TErrorCode& ec);
+ void OnSystemError(THttpConn* c, const TErrorCode& ec);
+ void OnError(THttpConn* c, const TString& errorText);
+
+ bool RequestSendedCompletely() noexcept;
+
+ void Cancel() noexcept;
+
+ private:
+ void NotifyResponse(const TString& resp, const TString& firstLine, const THttpHeaders& headers) {
+ THandleRef h(ReleaseHandler());
+ if (!!h) {
+ h->NotifyResponse(resp, firstLine, headers);
+ }
+ }
+
+ void NotifyError(
+ const TString& errorText,
+ TError::TType errorType = TError::UnknownType,
+ i32 errorCode = 0, i32 systemErrorCode = 0) {
+ NotifyError(new TError(errorText, errorType, errorCode, systemErrorCode));
+ }
+
+ void NotifyError(TErrorRef error, const THttpParser* rsp = nullptr) {
+ THandleRef h(ReleaseHandler());
+ if (!!h) {
+ h->NotifyError(error, rsp);
+ }
+ }
+
+ void Finalize(THttpConn* skipConn = nullptr) noexcept;
+
+ inline THandleRef ReleaseHandler() noexcept {
+ THandleRef h;
+ {
+ TGuard<TSpinLock> g(SL_);
+ h.Swap(Hndl_);
+ }
+ return h;
+ }
+
+ inline THttpConnRef GetConn() noexcept {
+ TGuard<TSpinLock> g(SL_);
+ return Conn_;
+ }
+
+ inline THttpConnRef ReleaseConn() noexcept {
+ THttpConnRef c;
+ {
+ TGuard<TSpinLock> g(SL_);
+ c.Swap(Conn_);
+ }
+ return c;
+ }
+
+ inline THttpConnRef ReleaseConn2() noexcept {
+ THttpConnRef c;
+ {
+ TGuard<TSpinLock> g(SL_);
+ c.Swap(Conn2_);
+ }
+ return c;
+ }
+
+ TSpinLock SL_; //guarantees notify() is called only once (prevents a race between the asio thread and the current one)
+ THandleRef Hndl_;
+ TRequestBuilder RequestBuilder_;
+ TRequestSettings RequestSettings_;
+ const TMessage Msg_;
+ const TParsedLocation Loc_;
+ const TResolvedHost* Addr_;
+ TNetworkAddress::TIterator AddrIter_;
+ THttpConnRef Conn_;
+ THttpConnRef Conn2_; //concurrent connection, used if slow connecting is detected on the first connection
+ TWeakPtrB<THttpRequest> WeakThis_;
+ TAtomicBool Canceled_;
+ TAtomicBool RequestSendedCompletely_;
+ };
+
+ TAtomicCounter* HttpOutConnCounter();
+
+ class THttpConn: public TThrRefBase {
+ public:
+ static THttpConnRef Create(TIOService& srv);
+
+ ~THttpConn() override {
+ DBGOUT("~THttpConn()");
+ Req_.Reset();
+ HttpOutConnCounter()->Dec();
+#ifdef DEBUG_STAT
+ ++TDebugStat::ConnDestroyed;
+#endif
+ }
+
+ void StartRequest(THttpRequestRef req, const TEndpoint& ep, size_t addrId, TDuration slowConn, bool useAsyncSendRequest = false) {
+ {
+ //thread safe linking connection->request
+ TGuard<TSpinLock> g(SL_);
+ Req_ = req;
+ }
+ AddrId_ = addrId;
+ try {
+ TDuration connectDeadline(THttp2Options::ConnectTimeout);
+ if (THttp2Options::ConnectTimeout > slowConn) {
+ //use an additional non-fatal connect deadline, so on the first timeout
+ //we report slow connecting to THttpRequest and continue waiting for the ConnectDeadline_ period
+ connectDeadline = slowConn;
+ ConnectDeadline_ = THttp2Options::ConnectTimeout - slowConn;
+ }
+ DBGOUT("AsyncConnect to " << ep.IpToString());
+ AS_.AsyncConnect(ep, std::bind(&THttpConn::OnConnect, THttpConnRef(this), _1, _2, useAsyncSendRequest), connectDeadline);
+ } catch (...) {
+ ReleaseRequest();
+ throw;
+ }
+ }
+
+ //start next request on keep-alive connection
+ bool StartNextRequest(THttpRequestRef& req, bool useAsyncSendRequest = false) {
+ if (Finalized_) {
+ return false;
+ }
+
+ {
+ //thread safe linking connection->request
+ TGuard<TSpinLock> g(SL_);
+ Req_ = req;
+ }
+
+ RequestWritten_ = false;
+ BeginReadResponse_ = false;
+
+ try {
+ if (!useAsyncSendRequest) {
+ TErrorCode ec;
+ SendRequest(req->BuildRequest(), ec); //throw std::bad_alloc
+ if (ec.Value() == ECANCELED) {
+ OnCancel();
+ } else if (ec) {
+ OnError(ec);
+ }
+ } else {
+ SendRequestAsync(req->BuildRequest()); //throw std::bad_alloc
+ }
+ } catch (...) {
+ OnError(CurrentExceptionMessage());
+ throw;
+ }
+ return true;
+ }
+
+ //a connection received from the cache must be validated before use
+ //(removing a closed connection from the cache takes some time)
+ inline bool IsValid() const noexcept {
+ return !Finalized_;
+ }
+
+ void SetCached(bool v) noexcept {
+ Cached_ = v;
+ }
+
+ void Close() noexcept {
+ try {
+ Cancel();
+ } catch (...) {
+ }
+ }
+
+ void DetachRequest() noexcept {
+ ReleaseRequest();
+ }
+
+ void Cancel() { //throw std::bad_alloc
+ if (!Canceled_) {
+ Canceled_ = true;
+ Finalized_ = true;
+ OnCancel();
+ AS_.AsyncCancel();
+ }
+ }
+
+ void OnCancel() {
+ THttpRequestRef r(ReleaseRequest());
+ if (!!r) {
+ static const TString reqCanceled("request canceled");
+ r->OnError(this, reqCanceled);
+ }
+ }
+
+ bool RequestSendedCompletely() const noexcept {
+ DBGOUT("RequestSendedCompletely()");
+ if (!Connected_ || !RequestWritten_) {
+ return false;
+ }
+ if (BeginReadResponse_) {
+ return true;
+ }
+#if defined(FIONWRITE)
+ if (THttp2Options::EnsureSendingCompleteByAck) {
+ int nbytes = Max<int>();
+ int err = ioctl(AS_.Native(), FIONWRITE, &nbytes);
+ return err ? false : nbytes == 0;
+ }
+#endif
+ return true;
+ }
+
+ TIOService& GetIOService() const noexcept {
+ return AS_.GetIOService();
+ }
+
+ private:
+ THttpConn(TIOService& srv)
+ : AddrId_(0)
+ , AS_(srv)
+ , BuffSize_(THttp2Options::InputBufferSize)
+ , Connected_(false)
+ , Cached_(false)
+ , Canceled_(false)
+ , Finalized_(false)
+ , InAsyncRead_(false)
+ , RequestWritten_(false)
+ , BeginReadResponse_(false)
+ {
+ HttpOutConnCounter()->Inc();
+ }
+
+ //can be called only from asio
+ void OnConnect(const TErrorCode& ec, IHandlingContext& ctx, bool useAsyncSendRequest = false) {
+ DBGOUT("THttpConn::OnConnect: " << ec.Value());
+ if (Y_UNLIKELY(ec)) {
+ if (ec.Value() == ETIMEDOUT && ConnectDeadline_.GetValue()) {
+ //detect slow connecting (yet not reached final timeout)
+ DBGOUT("OnConnectTimingCheck");
+ THttpRequestRef req(GetRequest());
+ if (!req) {
+ return; //a cancel from the client thread can get ahead of us
+ }
+ TDuration newDeadline(ConnectDeadline_);
+ ConnectDeadline_ = TDuration::Zero(); //next timeout is final
+
+ req->OnDetectSlowConnecting();
+ //continue wait connect
+ ctx.ContinueUseHandler(newDeadline);
+
+ return;
+ }
+#ifdef DEBUG_STAT
+ if (ec.Value() != ECANCELED) {
+ ++TDebugStat::ConnFailed;
+ } else {
+ ++TDebugStat::ConnConnCanceled;
+ }
+#endif
+ if (ec.Value() == EIO) {
+ //try to get more detailed error info
+ char buf[1];
+ TErrorCode errConnect;
+ AS_.ReadSome(buf, 1, errConnect);
+ OnConnectFailed(errConnect.Value() ? errConnect : ec);
+ } else if (ec.Value() == ECANCELED) {
+ // do not try connecting to the next host ip addr, simply fail
+ OnError(ec);
+ } else {
+ OnConnectFailed(ec);
+ }
+ } else {
+ Connected_ = true;
+
+ THttpRequestRef req(GetRequest());
+ if (!req || Canceled_) {
+ return;
+ }
+
+ try {
+ PrepareSocket(AS_.Native(), req->RequestSettings());
+ if (THttp2Options::TcpKeepAlive) {
+ SetKeepAlive(AS_.Native(), true);
+ }
+ } catch (TSystemError& err) {
+ TErrorCode ec2(err.Status());
+ OnError(ec2);
+ return;
+ }
+
+ req->OnConnect(this);
+
+ THttpRequestBuffersPtr ptr(req->BuildRequest());
+ PrepareParser();
+
+ if (!useAsyncSendRequest) {
+ TErrorCode ec3;
+ SendRequest(ptr, ec3);
+ if (ec3) {
+ OnError(ec3);
+ }
+ } else {
+ SendRequestAsync(ptr);
+ }
+ }
+ }
+
+ void PrepareParser() {
+ Prs_ = new THttpParser();
+ Prs_->Prepare();
+ }
+
+ void SendRequest(const THttpRequestBuffersPtr& bfs, TErrorCode& ec) { //throw std::bad_alloc
+ if (!THttp2Options::UseAsyncSendRequest) {
+ size_t amount = AS_.WriteSome(*bfs->GetIOvec(), ec);
+
+ if (ec && ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK && ec.Value() != EINPROGRESS) {
+ return;
+ }
+ ec.Assign(0);
+
+ bfs->GetIOvec()->Proceed(amount);
+
+ if (bfs->GetIOvec()->Complete()) {
+ RequestWritten_ = true;
+ StartRead();
+ } else {
+ SendRequestAsync(bfs);
+ }
+ } else {
+ SendRequestAsync(bfs);
+ }
+ }
+
+ void SendRequestAsync(const THttpRequestBuffersPtr& bfs) {
+ NAsio::TTcpSocket::TSendedData sd(bfs.Release());
+ AS_.AsyncWrite(sd, std::bind(&THttpConn::OnWrite, THttpConnRef(this), _1, _2, _3), THttp2Options::OutputDeadline);
+ }
+
+ void OnWrite(const TErrorCode& err, size_t amount, IHandlingContext& ctx) {
+ Y_UNUSED(amount);
+ Y_UNUSED(ctx);
+ if (err) {
+ OnError(err);
+ } else {
+ DBGOUT("OnWrite()");
+ RequestWritten_ = true;
+ StartRead();
+ }
+ }
+
+ inline void StartRead() {
+ if (!InAsyncRead_ && !Canceled_) {
+ InAsyncRead_ = true;
+ AS_.AsyncPollRead(std::bind(&THttpConn::OnCanRead, THttpConnRef(this), _1, _2), THttp2Options::InputDeadline);
+ }
+ }
+
+ //can be called only from asio
+ void OnReadSome(const TErrorCode& err, size_t bytes, IHandlingContext& ctx) {
+ if (Y_UNLIKELY(err)) {
+ OnError(err);
+ return;
+ }
+ if (!BeginReadResponse_) {
+ //used in MessageSendedCompletely()
+ BeginReadResponse_ = true;
+ THttpRequestRef r(GetRequest());
+ if (!!r) {
+ r->OnBeginRead();
+ }
+ }
+ DBGOUT("receive:" << TStringBuf(Buff_.Get(), bytes));
+ try {
+ if (!Prs_) {
+ throw yexception() << TStringBuf("receive some data while not in request");
+ }
+
+#if defined(_linux_)
+ if (THttp2Options::QuickAck) {
+ SetSockOpt(AS_.Native(), SOL_TCP, TCP_QUICKACK, (int)1);
+ }
+#endif
+
+ DBGOUT("parse:");
+ while (!Prs_->Parse(Buff_.Get(), bytes)) {
+ if (BuffSize_ == bytes) {
+ TErrorCode ec;
+ bytes = AS_.ReadSome(Buff_.Get(), BuffSize_, ec);
+
+ if (!ec) {
+ continue;
+ }
+
+ if (ec.Value() != EAGAIN && ec.Value() != EWOULDBLOCK) {
+ OnError(ec);
+
+ return;
+ }
+ }
+ //continue async. read from socket
+ ctx.ContinueUseHandler(THttp2Options::InputDeadline);
+
+ return;
+ }
+
+ //successfully reached the end of the http response
+ THttpRequestRef r(ReleaseRequest());
+ if (!r) {
+ //lost the race to request canceling
+ DBGOUT("connection failed");
+ return;
+ }
+
+ DBGOUT("response:");
+ bool keepALive = Prs_->IsKeepAlive();
+
+ r->OnResponse(Prs_);
+
+ if (!keepALive) {
+ return;
+ }
+
+ //continue use connection (keep-alive mode)
+ PrepareParser();
+
+ if (!THttp2Options::KeepInputBufferForCachedConnections) {
+ Buff_.Destroy();
+ }
+ //continue async. read from socket
+ ctx.ContinueUseHandler(THttp2Options::InputDeadline);
+
+ PutSelfToCache();
+ } catch (...) {
+ OnError(CurrentExceptionMessage());
+ }
+ }
+
+ void PutSelfToCache();
+
+ //handles input data for a re-used keep-alive connection,
+ //whose buffer was freed/released when the connection was placed in the cache
+ void OnCanRead(const TErrorCode& err, IHandlingContext& ctx) {
+ if (Y_UNLIKELY(err)) {
+ OnError(err);
+ } else {
+ if (!Buff_) {
+ Buff_.Reset(new char[BuffSize_]);
+ }
+ TErrorCode ec;
+ OnReadSome(ec, AS_.ReadSome(Buff_.Get(), BuffSize_, ec), ctx);
+ }
+ }
+
+ //unlink connection and request, thread-safely mark the connection as invalid
+ inline THttpRequestRef GetRequest() noexcept {
+ TGuard<TSpinLock> g(SL_);
+ return Req_;
+ }
+
+ inline THttpRequestRef ReleaseRequest() noexcept {
+ THttpRequestRef r;
+ {
+ TGuard<TSpinLock> g(SL_);
+ r.Swap(Req_);
+ }
+ return r;
+ }
+
+ void OnConnectFailed(const TErrorCode& ec);
+
+ inline void OnError(const TErrorCode& ec) {
+ OnError(ec.Text());
+ }
+
+ inline void OnError(const TString& errText);
+
+ size_t AddrId_;
+ NAsio::TTcpSocket AS_;
+ TArrayHolder<char> Buff_; //input buffer
+ const size_t BuffSize_;
+
+ TAutoPtr<THttpParser> Prs_; //input data parser & parsed info storage
+
+ TSpinLock SL_;
+ THttpRequestRef Req_; //current request
+ TDuration ConnectDeadline_;
+ TAtomicBool Connected_;
+ TAtomicBool Cached_;
+ TAtomicBool Canceled_;
+ TAtomicBool Finalized_;
+
+ bool InAsyncRead_;
+ TAtomicBool RequestWritten_;
+ TAtomicBool BeginReadResponse_;
+ };
+
+ //monitors conn limits, cleans the cache, and contains the asio threads/executors used by http clients
+ class THttpConnManager: public IThreadFactory::IThreadAble {
+ public:
+ THttpConnManager()
+ : TotalConn(0)
+ , EP_(THttp2Options::AsioThreads)
+ , InPurging_(0)
+ , MaxConnId_(0)
+ , Shutdown_(false)
+ {
+ T_ = SystemThreadFactory()->Run(this);
+ Limits.SetSoft(40000);
+ Limits.SetHard(50000);
+ }
+
+ ~THttpConnManager() override {
+ {
+ TGuard<TMutex> g(PurgeMutex_);
+
+ Shutdown_ = true;
+ CondPurge_.Signal();
+ }
+
+ EP_.SyncShutdown();
+
+ T_->Join();
+ }
+
+ inline void SetLimits(size_t softLimit, size_t hardLimit) noexcept {
+ Limits.SetSoft(softLimit);
+ Limits.SetHard(hardLimit);
+ }
+
+ inline std::pair<size_t, size_t> GetLimits() const noexcept {
+ return {Limits.Soft(), Limits.Hard()};
+ }
+
+ inline void CheckLimits() {
+ if (ExceedSoftLimit()) {
+ SuggestPurgeCache();
+
+ if (ExceedHardLimit()) {
+ Y_FAIL("neh::http2 output connections limit reached");
+ //ythrow yexception() << "neh::http2 output connections limit reached";
+ }
+ }
+ }
+
+ inline bool Get(THttpConnRef& conn, size_t addrId) {
+#ifdef DEBUG_STAT
+ TDebugStat::ConnTotal.store(TotalConn.Val(), std::memory_order_release);
+ TDebugStat::ConnActive.store(Active(), std::memory_order_release);
+ TDebugStat::ConnCached.store(Cache_.Size(), std::memory_order_release);
+#endif
+ return Cache_.Get(conn, addrId);
+ }
+
+ inline void Put(THttpConnRef& conn, size_t addrId) {
+ if (Y_LIKELY(!Shutdown_ && !ExceedHardLimit() && !CacheDisabled())) {
+ if (Y_UNLIKELY(addrId > (size_t)AtomicGet(MaxConnId_))) {
+ AtomicSet(MaxConnId_, addrId);
+ }
+ Cache_.Put(conn, addrId);
+ } else {
+ conn->Close();
+ conn.Drop();
+ }
+ }
+
+ inline size_t OnConnError(size_t addrId) {
+ return Cache_.Validate(addrId);
+ }
+
+ TIOService& GetIOService() {
+ return EP_.GetExecutor().GetIOService();
+ }
+
+ bool CacheDisabled() const {
+ return Limits.Soft() == 0;
+ }
+
+ bool IsShutdown() const noexcept {
+ return Shutdown_;
+ }
+
+ TAtomicCounter TotalConn;
+
+ private:
+ inline size_t Total() const noexcept {
+ return TotalConn.Val();
+ }
+
+ inline size_t Active() const noexcept {
+ return TFdLimits::ExceedLimit(Total(), Cache_.Size());
+ }
+
+ inline size_t ExceedSoftLimit() const noexcept {
+ return TFdLimits::ExceedLimit(Total(), Limits.Soft());
+ }
+
+ inline size_t ExceedHardLimit() const noexcept {
+ return TFdLimits::ExceedLimit(Total(), Limits.Hard());
+ }
+
+ void SuggestPurgeCache() {
+ if (AtomicTryLock(&InPurging_)) {
+ //evaluate the usefulness of purging the cache
+ //if there are few connections in the cache (< MaxConnId_/16 or 64), do not purge the cache
+ if (Cache_.Size() > (Min((size_t)AtomicGet(MaxConnId_), (size_t)1024U) >> 4)) {
+ //as we approach the hard limit, the need to purge the cache approaches 100%
+ size_t closenessToHardLimit256 = ((Active() + 1) << 8) / (Limits.Delta() + 1);
+ //the more connections sit in the cache rather than in use, the less the cache is needed (it can be purged)
+ size_t cacheUselessness256 = ((Cache_.Size() + 1) << 8) / (Active() + 1);
+
+ //overall, the trigger thresholds are:
+ //on reaching the soft limit, if connections sit in the cache rather than in use
+ //halfway from the soft limit to the hard limit, if more than half of the connections are in the cache
+ //when approaching the hard limit, try to purge the cache almost constantly
+ if ((closenessToHardLimit256 + cacheUselessness256) >= 256U) {
+ TGuard<TMutex> g(PurgeMutex_);
+
+ CondPurge_.Signal();
+ return; //memo: thread MUST unlock InPurging_ (see DoExecute())
+ }
+ }
+ AtomicUnlock(&InPurging_);
+ }
+ }
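A worked example of the purge trigger above, under assumed numbers (soft limit 40000, hard limit 50000; Limits.Delta() is taken to be the gap between them, about 10000):

    Active() = 5000, Cache_.Size() = 3000:
        closenessToHardLimit256 = (5000 + 1) * 256 / (10000 + 1) ~= 128
        cacheUselessness256     = (3000 + 1) * 256 / (5000 + 1)  ~= 153
        128 + 153 = 281 >= 256  => purge is signaled
    With Cache_.Size() = 1000 the second term drops to ~51 (128 + 51 = 179 < 256), so no purge.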
+
+ void DoExecute() override {
+ while (true) {
+ {
+ TGuard<TMutex> g(PurgeMutex_);
+
+ if (Shutdown_)
+ return;
+
+ CondPurge_.WaitI(PurgeMutex_);
+ }
+
+ PurgeCache();
+
+ AtomicUnlock(&InPurging_);
+ }
+ }
+
+ void PurgeCache() noexcept {
+ //try to remove at least ExceedSoftLimit() of the oldest connections from the cache
+ //compute the fraction of the cache to purge (in 256ths, but at least 1/32 of the cache)
+ size_t frac256 = Min(size_t(Max(size_t(256U / 32U), (ExceedSoftLimit() << 8) / (Cache_.Size() + 1))), (size_t)256U);
+
+ size_t processed = 0;
+ size_t maxConnId = AtomicGet(MaxConnId_);
+ for (size_t i = 0; i <= maxConnId && !Shutdown_; ++i) {
+ processed += Cache_.Purge(i, frac256);
+ if (processed > 32) {
+#ifdef DEBUG_STAT
+ TDebugStat::ConnPurgedInCache += processed;
+#endif
+ processed = 0;
+ Sleep(TDuration::MilliSeconds(10)); //prevent big spike cpu/system usage
+ }
+ }
+ }
+
+ TFdLimits Limits;
+ TExecutorsPool EP_;
+
+ TConnCache<THttpConn> Cache_;
+ TAtomic InPurging_;
+ TAtomic MaxConnId_;
+
+ TAutoPtr<IThreadFactory::IThread> T_;
+ TCondVar CondPurge_;
+ TMutex PurgeMutex_;
+ TAtomicBool Shutdown_;
+ };
+
+ THttpConnManager* HttpConnManager() {
+ return Singleton<THttpConnManager>();
+ }
+
+ TAtomicCounter* HttpOutConnCounter() {
+ return &HttpConnManager()->TotalConn;
+ }
+
+ THttpConnRef THttpConn::Create(TIOService& srv) {
+ if (HttpConnManager()->IsShutdown()) {
+ throw yexception() << "can't create connection with shutdowned service";
+ }
+
+ return new THttpConn(srv);
+ }
+
+ void THttpConn::PutSelfToCache() {
+ THttpConnRef c(this);
+ HttpConnManager()->Put(c, AddrId_);
+ }
+
+ void THttpConn::OnConnectFailed(const TErrorCode& ec) {
+ THttpRequestRef r(GetRequest());
+ if (!!r) {
+ r->OnConnectFailed(this, ec);
+ }
+ OnError(ec);
+ }
+
+ void THttpConn::OnError(const TString& errText) {
+ Finalized_ = true;
+ if (Connected_) {
+ Connected_ = false;
+ TErrorCode ec;
+ AS_.Shutdown(NAsio::TTcpSocket::ShutdownBoth, ec);
+ } else {
+ if (AS_.IsOpen()) {
+ AS_.AsyncCancel();
+ }
+ }
+ THttpRequestRef r(ReleaseRequest());
+ if (!!r) {
+ r->OnError(this, errText);
+ } else {
+ if (Cached_) {
+ size_t res = HttpConnManager()->OnConnError(AddrId_);
+ Y_UNUSED(res);
+#ifdef DEBUG_STAT
+ TDebugStat::ConnDestroyedInCache += res;
+#endif
+ }
+ }
+ }
+
+ void THttpRequest::Run(THttpRequestRef& req) {
+#ifdef DEBUG_STAT
+ if ((++TDebugStat::RequestTotal & 0xFFF) == 0) {
+ TDebugStat::Print();
+ }
+#endif
+ try {
+ while (!Canceled_) {
+ THttpConnRef conn;
+ if (HttpConnManager()->Get(conn, Addr_->Id)) {
+ DBGOUT("Use connection from cache");
+ Conn_ = conn; //thread magic
+ if (!conn->StartNextRequest(req, RequestSettings_.UseAsyncSendRequest)) {
+ continue; //when using a connection from the cache, ignore the write error and try another conn
+ }
+ } else {
+ HttpConnManager()->CheckLimits(); //throws an exception here if the hard limit is reached (or in atexit() state)
+ Conn_ = THttpConn::Create(HttpConnManager()->GetIOService());
+ TEndpoint ep(new NAddr::TAddrInfo(&*AddrIter_));
+ Conn_->StartRequest(req, ep, Addr_->Id, THttp2Options::SymptomSlowConnect); // can throw
+ }
+ break;
+ }
+ } catch (...) {
+ Conn_.Reset();
+ throw;
+ }
+ }
+
+ //it seems we have lost a TCP SYN packet; create an extra connection to decrease response time
+ void THttpRequest::OnDetectSlowConnecting() {
+#ifdef DEBUG_STAT
+ ++TDebugStat::ConnSlow;
+#endif
+ //use the io_service (Run() thread-executor) from the first conn. for more thread safety
+ THttpConnRef conn = GetConn();
+
+ if (!conn) {
+ return;
+ }
+
+ THttpConnRef conn2;
+ try {
+ conn2 = THttpConn::Create(conn->GetIOService());
+ } catch (...) {
+ return; // can't create a spare connection, simply continue using only the main one
+ }
+
+ {
+ TGuard<TSpinLock> g(SL_);
+ Conn2_ = conn2;
+ }
+
+ if (Y_UNLIKELY(Canceled_)) {
+ ReleaseConn2();
+ } else {
+ //use the connect timeout to disable slow-connect detection on the second conn.
+ TEndpoint ep(new NAddr::TAddrInfo(&*Addr_->Addr.Begin()));
+ try {
+ conn2->StartRequest(WeakThis_, ep, Addr_->Id, THttp2Options::ConnectTimeout);
+ } catch (...) {
+ // ignore errors on spare connection
+ ReleaseConn2();
+ }
+ }
+ }
+
+ void THttpRequest::OnConnect(THttpConn* c) {
+ THttpConnRef extraConn;
+ {
+ TGuard<TSpinLock> g(SL_);
+ if (Y_UNLIKELY(!!Conn2_)) {
+ //has a paired concurrent conn, 'only one should remain'
+ if (Conn2_.Get() == c) {
+#ifdef DEBUG_STAT
+ ++TDebugStat::Conn2Success;
+#endif
+ Conn2_.Swap(Conn_);
+ }
+ extraConn.Swap(Conn2_);
+ }
+ }
+ if (!!extraConn) {
+ extraConn->DetachRequest(); //prevent call OnError()
+ extraConn->Close();
+ }
+ }
+
+ void THttpRequest::OnResponse(TAutoPtr<THttpParser>& rsp) {
+ DBGOUT("THttpRequest::OnResponse()");
+ ReleaseConn();
+ if (Y_LIKELY(((rsp->RetCode() >= 200 && rsp->RetCode() < (!THttp2Options::RedirectionNotError ? 300 : 400)) || THttp2Options::AnyResponseIsNotError))) {
+ NotifyResponse(rsp->DecodedContent(), rsp->FirstLine(), rsp->Headers());
+ } else {
+ TString message;
+
+ if (THttp2Options::FullHeadersAsErrorMessage) {
+ TStringStream err;
+ err << rsp->FirstLine();
+
+ THttpHeaders hdrs = rsp->Headers();
+ for (auto h = hdrs.begin(); h < hdrs.end(); h++) {
+ err << h->ToString() << TStringBuf("\r\n");
+ }
+
+ message = err.Str();
+ } else if (THttp2Options::UseResponseAsErrorMessage) {
+ message = rsp->DecodedContent();
+ } else {
+ TStringStream err;
+ err << TStringBuf("request failed(") << rsp->FirstLine() << TStringBuf(")");
+ message = err.Str();
+ }
+
+ NotifyError(new TError(message, TError::ProtocolSpecific, rsp->RetCode()), rsp.Get());
+ }
+ }
+
+ void THttpRequest::OnConnectFailed(THttpConn* c, const TErrorCode& ec) {
+ DBGOUT("THttpRequest::OnConnectFailed()");
+ //detach/discard the failed conn, try connecting to the next ip addr (if possible)
+ THttpConnRef cc(GetConn());
+ if (c != cc.Get() || AddrIter_ == Addr_->Addr.End() || ++AddrIter_ == Addr_->Addr.End() || Canceled_) {
+ return OnSystemError(c, ec);
+ }
+ // can try next host addr
+ c->DetachRequest();
+ c->Close();
+ THttpConnRef nextConn;
+ try {
+ nextConn = THttpConn::Create(HttpConnManager()->GetIOService());
+ } catch (...) {
+ OnSystemError(nullptr, ec);
+ return;
+ }
+ {
+ THttpConnRef nc = nextConn;
+ TGuard<TSpinLock> g(SL_);
+ Conn_.Swap(nc);
+ }
+ TEndpoint ep(new NAddr::TAddrInfo(&*AddrIter_));
+ try {
+ nextConn->StartRequest(WeakThis_, ep, Addr_->Id, THttp2Options::SymptomSlowConnect);
+ } catch (...) {
+ OnError(nullptr, CurrentExceptionMessage());
+ return;
+ }
+
+ if (Canceled_) {
+ OnError(nullptr, "canceled");
+ }
+ }
+
+ void THttpRequest::OnSystemError(THttpConn* c, const TErrorCode& ec) {
+ DBGOUT("THttpRequest::OnSystemError()");
+ NotifyError(ec.Text(), TError::TType::UnknownType, 0, ec.Value());
+ Finalize(c);
+ }
+
+ void THttpRequest::OnError(THttpConn* c, const TString& errorText) {
+ DBGOUT("THttpRequest::OnError()");
+ NotifyError(errorText);
+ Finalize(c);
+ }
+
+ bool THttpRequest::RequestSendedCompletely() noexcept {
+ if (RequestSendedCompletely_) {
+ return true;
+ }
+
+ THttpConnRef c(GetConn());
+ return !!c ? c->RequestSendedCompletely() : false;
+ }
+
+ void THttpRequest::Cancel() noexcept {
+ if (!Canceled_) {
+ Canceled_ = true;
+ try {
+ static const TString canceled("Canceled");
+ NotifyError(canceled, TError::Cancelled);
+ Finalize();
+ } catch (...) {
+ }
+ }
+ }
+
+ inline void FinalizeConn(THttpConnRef& c, THttpConn* skipConn) noexcept {
+ if (!!c && c.Get() != skipConn) {
+ c->DetachRequest();
+ c->Close();
+ }
+ }
+
+ void THttpRequest::Finalize(THttpConn* skipConn) noexcept {
+ THttpConnRef c1(ReleaseConn());
+ FinalizeConn(c1, skipConn);
+ THttpConnRef c2(ReleaseConn2());
+ FinalizeConn(c2, skipConn);
+ }
+
+ /////////////////////////////////// server side ////////////////////////////////////
+
+ TAtomicCounter* HttpInConnCounter() {
+ return Singleton<TAtomicCounter>();
+ }
+
+ TFdLimits* HttpInConnLimits() {
+ return Singleton<TFdLimits>();
+ }
+
+ class THttpServer: public IRequester {
+ typedef TAutoPtr<TTcpAcceptor> TTcpAcceptorPtr;
+ typedef TAtomicSharedPtr<TTcpSocket> TTcpSocketRef;
+ class TConn;
+ typedef TSharedPtrB<TConn> TConnRef;
+
+ class TRequest: public IHttpRequest {
+ public:
+ TRequest(TWeakPtrB<TConn>& c, TAutoPtr<THttpParser>& p)
+ : C_(c)
+ , P_(p)
+ , RemoteHost_(C_->RemoteHost())
+ , CompressionScheme_(P_->GetBestCompressionScheme())
+ , H_(TStringBuf(P_->FirstLine()))
+ {
+ }
+
+ ~TRequest() override {
+ if (!!C_) {
+ try {
+ C_->SendError(Id(), 503, "service unavailable (request ignored)", P_->HttpVersion(), {});
+ } catch (...) {
+ DBGOUT("~TRequest()::SendFail() exception");
+ }
+ }
+ }
+
+ TAtomicBase Id() const {
+ return Id_;
+ }
+
+ protected:
+ TStringBuf Scheme() const override {
+ return TStringBuf("http");
+ }
+
+ TString RemoteHost() const override {
+ return RemoteHost_;
+ }
+
+ TStringBuf Service() const override {
+ return TStringBuf(H_.Path).Skip(1);
+ }
+
+ const THttpHeaders& Headers() const override {
+ return P_->Headers();
+ }
+
+ TStringBuf Method() const override {
+ return H_.Method;
+ }
+
+ TStringBuf Body() const override {
+ return P_->DecodedContent();
+ }
+
+ TStringBuf Cgi() const override {
+ return H_.Cgi;
+ }
+
+ TStringBuf RequestId() const override {
+ return TStringBuf();
+ }
+
+ bool Canceled() const override {
+ if (!C_) {
+ return false;
+ }
+ return C_->IsCanceled();
+ }
+
+ void SendReply(TData& data) override {
+ SendReply(data, TString(), HttpCodes::HTTP_OK);
+ }
+
+ void SendReply(TData& data, const TString& headers, int httpCode) override {
+ if (!!C_) {
+ C_->Send(Id(), data, CompressionScheme_, P_->HttpVersion(), headers, httpCode);
+ C_.Reset();
+ }
+ }
+
+ void SendError(TResponseError err, const THttpErrorDetails& details) override {
+ static const unsigned errorToHttpCode[IRequest::MaxResponseError] =
+ {
+ 400,
+ 403,
+ 404,
+ 429,
+ 500,
+ 501,
+ 502,
+ 503,
+ 509};
+
+ if (!!C_) {
+ C_->SendError(Id(), errorToHttpCode[err], details.Details, P_->HttpVersion(), details.Headers);
+ C_.Reset();
+ }
+ }
+
+ static TAtomicBase NextId() {
+ static TAtomic idGenerator = 0;
+ TAtomicBase id = 0;
+ do {
+ id = AtomicIncrement(idGenerator);
+ } while (!id);
+ return id;
+ }
+
+ TSharedPtrB<TConn> C_;
+ TAutoPtr<THttpParser> P_;
+ TString RemoteHost_;
+ TString CompressionScheme_;
+ TParsedHttpFull H_;
+ TAtomicBase Id_ = NextId();
+ };
+
+ class TRequestGet: public TRequest {
+ public:
+ TRequestGet(TWeakPtrB<TConn>& c, TAutoPtr<THttpParser> p)
+ : TRequest(c, p)
+ {
+ }
+
+ TStringBuf Data() const override {
+ return H_.Cgi;
+ }
+ };
+
+ class TRequestPost: public TRequest {
+ public:
+ TRequestPost(TWeakPtrB<TConn>& c, TAutoPtr<THttpParser> p)
+ : TRequest(c, p)
+ {
+ }
+
+ TStringBuf Data() const override {
+ return P_->DecodedContent();
+ }
+ };
+
+ class TConn {
+ private:
+ TConn(THttpServer& hs, const TTcpSocketRef& s)
+ : HS_(hs)
+ , AS_(s)
+ , RemoteHost_(NNeh::PrintHostByRfc(*AS_->RemoteEndpoint().Addr()))
+ , BuffSize_(THttp2Options::InputBufferSize)
+ , Buff_(new char[BuffSize_])
+ , Canceled_(false)
+ , LeftRequestsToDisconnect_(hs.LimitRequestsPerConnection)
+ {
+ DBGOUT("THttpServer::TConn()");
+ HS_.OnCreateConn();
+ }
+
+ inline TConnRef SelfRef() noexcept {
+ return WeakThis_;
+ }
+
+ public:
+ static void Create(THttpServer& hs, const TTcpSocketRef& s) {
+ TSharedPtrB<TConn> conn(new TConn(hs, s));
+ conn->WeakThis_ = conn;
+ conn->ExpectNewRequest();
+ conn->AS_->AsyncPollRead(std::bind(&TConn::OnCanRead, conn, _1, _2), THttp2Options::ServerInputDeadline);
+ }
+
+ ~TConn() {
+ DBGOUT("~THttpServer::TConn(" << (!AS_ ? -666 : AS_->Native()) << ")");
+ HS_.OnDestroyConn();
+ }
+
+ private:
+ void ExpectNewRequest() {
+ P_.Reset(new THttpParser(THttpParser::Request));
+ P_->Prepare();
+ }
+
+ void OnCanRead(const TErrorCode& ec, IHandlingContext& ctx) {
+ if (ec) {
+ OnError();
+ } else {
+ TErrorCode ec2;
+ OnReadSome(ec2, AS_->ReadSome(Buff_.Get(), BuffSize_, ec2), ctx);
+ }
+ }
+
+ void OnError() {
+ DBGOUT("Srv OnError(" << (!AS_ ? -666 : AS_->Native()) << ")");
+ Canceled_ = true;
+ AS_->AsyncCancel();
+ }
+
+ void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) {
+ if (ec || !amount) {
+ OnError();
+
+ return;
+ }
+
+ DBGOUT("ReadSome(" << (!AS_ ? -666 : AS_->Native()) << "): " << amount);
+ try {
+ size_t buffPos = 0;
+ //DBGOUT("receive and parse: " << TStringBuf(Buff_.Get(), amount));
+ while (P_->Parse(Buff_.Get() + buffPos, amount - buffPos)) {
+ SeenMessageWithoutKeepalive_ |= !P_->IsKeepAlive() || LeftRequestsToDisconnect_ == 1;
+ char rt = *P_->FirstLine().data();
+ const size_t extraDataSize = P_->GetExtraDataSize();
+ if (rt == 'P' || rt == 'p') {
+ OnRequest(new TRequestPost(WeakThis_, P_));
+ } else {
+ OnRequest(new TRequestGet(WeakThis_, P_));
+ }
+ if (extraDataSize) {
+ // has http pipelining
+ buffPos = amount - extraDataSize;
+ ExpectNewRequest();
+ } else {
+ ExpectNewRequest();
+ ctx.ContinueUseHandler(HS_.GetKeepAliveTimeout());
+ return;
+ }
+ }
+ ctx.ContinueUseHandler(THttp2Options::ServerInputDeadline);
+ } catch (...) {
+ OnError();
+ }
+ }
+
+ void OnRequest(TRequest* r) {
+ DBGOUT("OnRequest()");
+ if (AtomicGet(PrimaryResponse_)) {
+ // has pipelining
+ PipelineOrder_.Enqueue(r->Id());
+ } else {
+ AtomicSet(PrimaryResponse_, r->Id());
+ }
+ HS_.OnRequest(r);
+ OnRequestDone();
+ }
+
+ void OnRequestDone() {
+ DBGOUT("OnRequestDone()");
+ if (LeftRequestsToDisconnect_ > 0) {
+ --LeftRequestsToDisconnect_;
+ }
+ }
+
+ static void PrintHttpVersion(IOutputStream& out, const THttpVersion& ver) {
+ out << TStringBuf("HTTP/") << ver.Major << TStringBuf(".") << ver.Minor;
+ }
+
+ struct TResponseData : TThrRefBase {
+ TResponseData(size_t reqId, TTcpSocket::TSendedData data)
+ : RequestId_(reqId)
+ , Data_(data)
+ {
+ }
+
+ size_t RequestId_;
+ TTcpSocket::TSendedData Data_;
+ };
+ typedef TIntrusivePtr<TResponseData> TResponseDataRef;
+
+ public:
+ //called non thread-safe (from outside thread)
+ void Send(TAtomicBase requestId, TData& data, const TString& compressionScheme, const THttpVersion& ver, const TString& headers, int httpCode) {
+ class THttpResponseFormatter {
+ public:
+ THttpResponseFormatter(TData& theData, const TString& contentEncoding, const THttpVersion& theVer, const TString& theHeaders, int theHttpCode, bool closeConnection) {
+ Header.Reserve(128 + contentEncoding.size() + theHeaders.size());
+ PrintHttpVersion(Header, theVer);
+ Header << TStringBuf(" ") << theHttpCode << ' ' << HttpCodeStr(theHttpCode);
+ if (Compress(theData, contentEncoding)) {
+ Header << TStringBuf("\r\nContent-Encoding: ") << contentEncoding;
+ }
+ Header << TStringBuf("\r\nContent-Length: ") << theData.size();
+ if (closeConnection) {
+ Header << TStringBuf("\r\nConnection: close");
+ } else if (Y_LIKELY(theVer.Major > 1 || theVer.Minor > 0)) {
+ // since HTTP/1.1 Keep-Alive is default behaviour
+ Header << TStringBuf("\r\nConnection: Keep-Alive");
+ }
+ if (theHeaders) {
+ Header << theHeaders;
+ }
+ Header << TStringBuf("\r\n\r\n");
+
+ Body.swap(theData);
+
+ Parts[0].buf = Header.Data();
+ Parts[0].len = Header.Size();
+ Parts[1].buf = Body.data();
+ Parts[1].len = Body.size();
+ }
+
+ TStringStream Header;
+ TData Body;
+ IOutputStream::TPart Parts[2];
+ };
+
+ class TBuffers: public THttpResponseFormatter, public TTcpSocket::IBuffers {
+ public:
+ TBuffers(TData& theData, const TString& contentEncoding, const THttpVersion& theVer, const TString& theHeaders, int theHttpCode, bool closeConnection)
+ : THttpResponseFormatter(theData, contentEncoding, theVer, theHeaders, theHttpCode, closeConnection)
+ , IOVec(Parts, 2)
+ {
+ }
+
+ TContIOVector* GetIOvec() override {
+ return &IOVec;
+ }
+
+ TContIOVector IOVec;
+ };
+
+ TTcpSocket::TSendedData sd(new TBuffers(data, compressionScheme, ver, headers, httpCode, SeenMessageWithoutKeepalive_));
+ SendData(requestId, sd);
+ }
+
+ //called non thread-safe (from outside thread)
+ void SendError(TAtomicBase requestId, unsigned httpCode, const TString& descr, const THttpVersion& ver, const TString& headers) {
+ if (Canceled_) {
+ return;
+ }
+
+ class THttpErrorResponseFormatter {
+ public:
+ THttpErrorResponseFormatter(unsigned theHttpCode, const TString& theDescr, const THttpVersion& theVer, bool closeConnection, const TString& headers) {
+ PrintHttpVersion(Answer, theVer);
+ Answer << TStringBuf(" ") << theHttpCode << TStringBuf(" ");
+ if (theDescr.size() && !THttp2Options::ErrorDetailsAsResponseBody) {
+ // Reason-Phrase = *<TEXT, excluding CR, LF>
+ // replace bad chars to '.'
+ TString reasonPhrase = theDescr;
+ for (TString::iterator it = reasonPhrase.begin(); it != reasonPhrase.end(); ++it) {
+ char& ch = *it;
+ if (ch == ' ') {
+ continue;
+ }
+ if (((ch & 31) == ch) || static_cast<unsigned>(ch) == 127 || (static_cast<unsigned>(ch) & 0x80)) {
+ //CTLs || DEL(127) || non ascii
+ // (ch <= 32) || (ch >= 127)
+ ch = '.';
+ }
+ }
+ Answer << reasonPhrase;
+ } else {
+ Answer << HttpCodeStr(static_cast<int>(theHttpCode));
+ }
+
+ if (closeConnection) {
+ Answer << TStringBuf("\r\nConnection: close");
+ }
+
+ if (headers) {
+ Answer << "\r\n" << headers;
+ }
+
+ if (THttp2Options::ErrorDetailsAsResponseBody) {
+ Answer << TStringBuf("\r\nContent-Length:") << theDescr.size() << "\r\n\r\n" << theDescr;
+ } else {
+ Answer << "\r\n"
+ "Content-Length:0\r\n\r\n"sv;
+ }
+
+ Parts[0].buf = Answer.Data();
+ Parts[0].len = Answer.Size();
+ }
+
+ TStringStream Answer;
+ IOutputStream::TPart Parts[1];
+ };
+
+ class TBuffers: public THttpErrorResponseFormatter, public TTcpSocket::IBuffers {
+ public:
+ TBuffers(
+ unsigned theHttpCode,
+ const TString& theDescr,
+ const THttpVersion& theVer,
+ bool closeConnection,
+ const TString& headers
+ )
+ : THttpErrorResponseFormatter(theHttpCode, theDescr, theVer, closeConnection, headers)
+ , IOVec(Parts, 1)
+ {
+ }
+
+ TContIOVector* GetIOvec() override {
+ return &IOVec;
+ }
+
+ TContIOVector IOVec;
+ };
+
+ TTcpSocket::TSendedData sd(new TBuffers(httpCode, descr, ver, SeenMessageWithoutKeepalive_, headers));
+ SendData(requestId, sd);
+ }
+
+ void ProcessPipeline() {
+ // on successful response to the current (PrimaryResponse_) request
+ TAtomicBase requestId;
+ if (PipelineOrder_.Dequeue(&requestId)) {
+ TAtomicBase oldReqId;
+ do {
+ oldReqId = AtomicGet(PrimaryResponse_);
+ Y_VERIFY(oldReqId, "race inside http pipelining");
+ } while (!AtomicCas(&PrimaryResponse_, requestId, oldReqId));
+
+ ProcessResponsesData();
+ } else {
+ TAtomicBase oldReqId = AtomicGet(PrimaryResponse_);
+ if (oldReqId) {
+ while (!AtomicCas(&PrimaryResponse_, 0, oldReqId)) {
+ Y_VERIFY(oldReqId == AtomicGet(PrimaryResponse_), "race inside http pipelining [2]");
+ }
+ }
+ }
+ }
+
+ void ProcessResponsesData() {
+ // process the responses data queue, send a response (if we already have the next PrimaryResponse_)
+ TResponseDataRef rd;
+ while (ResponsesDataQueue_.Dequeue(&rd)) {
+ ResponsesData_[rd->RequestId_] = rd;
+ }
+ TAtomicBase requestId = AtomicGet(PrimaryResponse_);
+ if (requestId) {
+ THashMap<TAtomicBase, TResponseDataRef>::iterator it = ResponsesData_.find(requestId);
+ if (it != ResponsesData_.end()) {
+ // has next primary response
+ rd = it->second;
+ ResponsesData_.erase(it);
+ AS_->AsyncWrite(rd->Data_, std::bind(&TConn::OnSend, SelfRef(), _1, _2, _3), THttp2Options::ServerOutputDeadline);
+ }
+ }
+ }
+
+ private:
+ void SendData(TAtomicBase requestId, TTcpSocket::TSendedData sd) {
+ TContIOVector& vec = *sd->GetIOvec();
+
+ if (requestId != AtomicGet(PrimaryResponse_)) {
+ // another request must be responded to first, so push this one to the queue
+ // + enqueue an event for safely checking the queue (on the local/transport thread)
+ TResponseDataRef rdr = new TResponseData(requestId, sd);
+ ResponsesDataQueue_.Enqueue(rdr);
+ AS_->GetIOService().Post(std::bind(&TConn::ProcessResponsesData, SelfRef()));
+ return;
+ }
+ if (THttp2Options::ServerUseDirectWrite) {
+ vec.Proceed(AS_->WriteSome(vec));
+ }
+ if (!vec.Complete()) {
+ DBGOUT("AsyncWrite()");
+ AS_->AsyncWrite(sd, std::bind(&TConn::OnSend, SelfRef(), _1, _2, _3), THttp2Options::ServerOutputDeadline);
+ } else {
+ // run ProcessPipeline on a safe thread
+ AS_->GetIOService().Post(std::bind(&TConn::ProcessPipeline, SelfRef()));
+ }
+ }
+
+ void OnSend(const TErrorCode& ec, size_t amount, IHandlingContext&) {
+ Y_UNUSED(amount);
+ if (ec) {
+ OnError();
+ } else {
+ ProcessPipeline();
+ }
+
+ if (SeenMessageWithoutKeepalive_) {
+ TErrorCode shutdown_ec;
+ AS_->Shutdown(TTcpSocket::ShutdownBoth, shutdown_ec);
+ }
+ }
+
+ public:
+ bool IsCanceled() const noexcept {
+ return Canceled_;
+ }
+
+ const TString& RemoteHost() const noexcept {
+ return RemoteHost_;
+ }
+
+ private:
+ TWeakPtrB<TConn> WeakThis_;
+ THttpServer& HS_;
+ TTcpSocketRef AS_;
+ TString RemoteHost_;
+ size_t BuffSize_;
+ TArrayHolder<char> Buff_;
+ TAutoPtr<THttpParser> P_;
+ // pipeline supporting
+ TAtomic PrimaryResponse_ = 0;
+ TLockFreeQueue<TAtomicBase> PipelineOrder_;
+ TLockFreeQueue<TResponseDataRef> ResponsesDataQueue_;
+ THashMap<TAtomicBase, TResponseDataRef> ResponsesData_;
+
+ TAtomicBool Canceled_;
+ bool SeenMessageWithoutKeepalive_ = false;
+
+ i32 LeftRequestsToDisconnect_ = -1;
+ };
+
+ ///////////////////////////////////////////////////////////
+
+ public:
+ THttpServer(IOnRequest* cb, const TParsedLocation& loc)
+ : E_(THttp2Options::AsioServerThreads)
+ , CB_(cb)
+ , LimitRequestsPerConnection(THttp2Options::LimitRequestsPerConnection)
+ {
+ TNetworkAddress addr(loc.GetPort());
+
+ for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) {
+ TEndpoint ep(new NAddr::TAddrInfo(&*it));
+ TTcpAcceptorPtr a(new TTcpAcceptor(AcceptExecutor_.GetIOService()));
+ DBGOUT("bind:" << ep.IpToString() << ":" << ep.Port());
+ a->Bind(ep);
+ a->Listen(THttp2Options::Backlog);
+ StartAccept(a.Get());
+ A_.push_back(a);
+ }
+ }
+
+ ~THttpServer() override {
+ AcceptExecutor_.SyncShutdown(); //cancel operation for all current sockets (include acceptors)
+ A_.clear(); //stop listening
+ E_.SyncShutdown();
+ }
+
+ void OnAccept(TTcpAcceptor* a, TAtomicSharedPtr<TTcpSocket> s, const TErrorCode& ec, IHandlingContext&) {
+ if (Y_UNLIKELY(ec)) {
+ if (ec.Value() == ECANCELED) {
+ return;
+ } else if (ec.Value() == EMFILE || ec.Value() == ENFILE || ec.Value() == ENOMEM || ec.Value() == ENOBUFS) {
+                    //reached some OS limit, suspend accepting
+ TAtomicSharedPtr<TDeadlineTimer> dt(new TDeadlineTimer(a->GetIOService()));
+ dt->AsyncWaitExpireAt(TDuration::Seconds(30), std::bind(&THttpServer::OnTimeoutSuspendAccept, this, a, dt, _1, _2));
+ return;
+ } else {
+ Cdbg << "acc: " << ec.Text() << Endl;
+ }
+ } else {
+ if (static_cast<size_t>(HttpInConnCounter()->Val()) < HttpInConnLimits()->Hard()) {
+ try {
+ SetNonBlock(s->Native());
+ PrepareSocket(s->Native());
+ TConn::Create(*this, s);
+ } catch (TSystemError& err) {
+ TErrorCode ec2(err.Status());
+ Cdbg << "acc: " << ec2.Text() << Endl;
+ }
+ } //else accepted socket will be closed
+ }
+ StartAccept(a); //continue accepting
+ }
+
+ void OnTimeoutSuspendAccept(TTcpAcceptor* a, TAtomicSharedPtr<TDeadlineTimer>, const TErrorCode& ec, IHandlingContext&) {
+ if (!ec) {
+ DBGOUT("resume acceptor")
+ StartAccept(a);
+ }
+ }
+
+ void OnRequest(IRequest* r) {
+ try {
+ CB_->OnRequest(r);
+ } catch (...) {
+ Cdbg << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ protected:
+ void OnCreateConn() noexcept {
+ HttpInConnCounter()->Inc();
+ }
+
+ void OnDestroyConn() noexcept {
+ HttpInConnCounter()->Dec();
+ }
+
+ TDuration GetKeepAliveTimeout() const noexcept {
+ size_t cc = HttpInConnCounter()->Val();
+ TFdLimits lim(*HttpInConnLimits());
+
+ if (!TFdLimits::ExceedLimit(cc, lim.Soft())) {
+ return THttp2Options::ServerInputDeadlineKeepAliveMax;
+ }
+
+ if (cc > lim.Hard()) {
+ cc = lim.Hard();
+ }
+ TDuration::TValue softTuneRange = THttp2Options::ServerInputDeadlineKeepAliveMax.Seconds() - THttp2Options::ServerInputDeadlineKeepAliveMin.Seconds();
+
+ return TDuration::Seconds((softTuneRange * (cc - lim.Soft())) / (lim.Hard() - lim.Soft() + 1)) + THttp2Options::ServerInputDeadlineKeepAliveMin;
+ }
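+        // Worked illustration of the formula above (hypothetical numbers, not necessarily the configured defaults):
+        // with KeepAliveMin = 10s, KeepAliveMax = 120s, soft = 10000, hard = 15000 and cc = 12500 connections,
+        // the timeout is Seconds((120 - 10) * (12500 - 10000) / (15000 - 10000 + 1)) + 10s = 54s + 10s = 64s,
+        // i.e. it degrades roughly linearly from Max to Min as the connection count moves from soft to hard.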
+
+ private:
+ void StartAccept(TTcpAcceptor* a) {
+ TAtomicSharedPtr<TTcpSocket> s(new TTcpSocket(E_.Size() ? E_.GetExecutor().GetIOService() : AcceptExecutor_.GetIOService()));
+ a->AsyncAccept(*s, std::bind(&THttpServer::OnAccept, this, a, s, _1, _2));
+ }
+
+ TIOServiceExecutor AcceptExecutor_;
+ TVector<TTcpAcceptorPtr> A_;
+ TExecutorsPool E_;
+ IOnRequest* CB_;
+
+ public:
+ const i32 LimitRequestsPerConnection;
+ };
+
+ template <class T>
+ class THttp2Protocol: public IProtocol {
+ public:
+ IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
+ return new THttpServer(cb, loc);
+ }
+
+ THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
+ THttpRequest::THandleRef ret(new THttpRequest::THandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));
+ try {
+ THttpRequest::Run(ret, msg, &T::Build, T::RequestSettings());
+ } catch (...) {
+ ret->ResetOnRecv();
+ throw;
+ }
+ return ret.Get();
+ }
+
+ THandleRef ScheduleAsyncRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss, bool useAsyncSendRequest) override {
+ THttpRequest::THandleRef ret(new THttpRequest::THandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));
+ try {
+ auto requestSettings = T::RequestSettings();
+ requestSettings.SetUseAsyncSendRequest(useAsyncSendRequest);
+ THttpRequest::Run(ret, msg, &T::Build, requestSettings);
+ } catch (...) {
+ ret->ResetOnRecv();
+ throw;
+ }
+ return ret.Get();
+ }
+
+ TStringBuf Scheme() const noexcept override {
+ return T::Name();
+ }
+
+ bool SetOption(TStringBuf name, TStringBuf value) override {
+ return THttp2Options::Set(name, value);
+ }
+ };
+}
+
+namespace NNeh {
+ IProtocol* Http1Protocol() {
+ return Singleton<THttp2Protocol<TRequestGet1>>();
+ }
+ IProtocol* Post1Protocol() {
+ return Singleton<THttp2Protocol<TRequestPost1>>();
+ }
+ IProtocol* Full1Protocol() {
+ return Singleton<THttp2Protocol<TRequestFull1>>();
+ }
+ IProtocol* Http2Protocol() {
+ return Singleton<THttp2Protocol<TRequestGet2>>();
+ }
+ IProtocol* Post2Protocol() {
+ return Singleton<THttp2Protocol<TRequestPost2>>();
+ }
+ IProtocol* Full2Protocol() {
+ return Singleton<THttp2Protocol<TRequestFull2>>();
+ }
+ IProtocol* UnixSocketGetProtocol() {
+ return Singleton<THttp2Protocol<TRequestUnixSocketGet>>();
+ }
+ IProtocol* UnixSocketPostProtocol() {
+ return Singleton<THttp2Protocol<TRequestUnixSocketPost>>();
+ }
+ IProtocol* UnixSocketFullProtocol() {
+ return Singleton<THttp2Protocol<TRequestUnixSocketFull>>();
+ }
+
+ void SetHttp2OutputConnectionsLimits(size_t softLimit, size_t hardLimit) {
+ HttpConnManager()->SetLimits(softLimit, hardLimit);
+ }
+
+ void SetHttp2InputConnectionsLimits(size_t softLimit, size_t hardLimit) {
+ HttpInConnLimits()->SetSoft(softLimit);
+ HttpInConnLimits()->SetHard(hardLimit);
+ }
+
+ TAtomicBase GetHttpOutputConnectionCount() {
+ return HttpOutConnCounter()->Val();
+ }
+
+ std::pair<size_t, size_t> GetHttpOutputConnectionLimits() {
+ return HttpConnManager()->GetLimits();
+ }
+
+ TAtomicBase GetHttpInputConnectionCount() {
+ return HttpInConnCounter()->Val();
+ }
+
+ void SetHttp2InputConnectionsTimeouts(unsigned minSeconds, unsigned maxSeconds) {
+ THttp2Options::ServerInputDeadlineKeepAliveMin = TDuration::Seconds(minSeconds);
+ THttp2Options::ServerInputDeadlineKeepAliveMax = TDuration::Seconds(maxSeconds);
+ }
+
+ class TUnixSocketResolver {
+ public:
+ NDns::TResolvedHost* Resolve(const TString& path) {
+ TString unixSocketPath = path;
+ if (path.size() > 2 && path[0] == '[' && path[path.size() - 1] == ']') {
+ unixSocketPath = path.substr(1, path.size() - 2);
+ }
+
+ if (auto resolvedUnixSocket = ResolvedUnixSockets_.FindPtr(unixSocketPath)) {
+ return resolvedUnixSocket->Get();
+ }
+
+ TNetworkAddress na{TUnixSocketPath(unixSocketPath)};
+ ResolvedUnixSockets_[unixSocketPath] = MakeHolder<NDns::TResolvedHost>(unixSocketPath, na);
+
+ return ResolvedUnixSockets_[unixSocketPath].Get();
+ }
+
+ private:
+ THashMap<TString, THolder<NDns::TResolvedHost>> ResolvedUnixSockets_;
+ };
+
+ TUnixSocketResolver* UnixSocketResolver() {
+ return FastTlsSingleton<TUnixSocketResolver>();
+ }
+
+ const NDns::TResolvedHost* Resolve(const TStringBuf host, ui16 port, NHttp::EResolverType resolverType) {
+ if (resolverType == EResolverType::EUNIXSOCKET) {
+ return UnixSocketResolver()->Resolve(TString(host));
+ }
+ return NDns::CachedResolve(NDns::TResolveInfo(host, port));
+
+ }
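+    // Illustrative sketch (assumed address form, based on the bracket handling above): for a location such as
+    // "http+unix://[/tmp/service.sock]/ping" the host part is "[/tmp/service.sock]", so Resolve() strips the
+    // brackets and resolves the unix socket path "/tmp/service.sock" via TNetworkAddress{TUnixSocketPath(...)},
+    // caching the result in the per-thread resolver.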
+}
diff --git a/library/cpp/neh/http2.h b/library/cpp/neh/http2.h
new file mode 100644
index 0000000000..7bc16affa0
--- /dev/null
+++ b/library/cpp/neh/http2.h
@@ -0,0 +1,119 @@
+#pragma once
+
+#include "factory.h"
+#include "http_common.h"
+
+#include <util/datetime/base.h>
+#include <library/cpp/dns/cache.h>
+#include <utility>
+
+namespace NNeh {
+ IProtocol* Http1Protocol();
+ IProtocol* Post1Protocol();
+ IProtocol* Full1Protocol();
+ IProtocol* Http2Protocol();
+ IProtocol* Post2Protocol();
+ IProtocol* Full2Protocol();
+ IProtocol* UnixSocketGetProtocol();
+ IProtocol* UnixSocketPostProtocol();
+ IProtocol* UnixSocketFullProtocol();
+
+ //global options
+ struct THttp2Options {
+ //connect timeout
+ static TDuration ConnectTimeout;
+
+ //input and output timeouts
+ static TDuration InputDeadline;
+ static TDuration OutputDeadline;
+
+ //when detected slow connection, will be runned concurrent parallel connection
+        //when a slow connection is detected, a concurrent parallel connection will be started
+        //not used if SymptomSlowConnect > ConnectTimeout
+
+ //input buffer size
+ static size_t InputBufferSize;
+
+        //http client input buffer policy
+ static bool KeepInputBufferForCachedConnections;
+
+ //asio threads
+ static size_t AsioThreads;
+
+        //asio server threads; if == 0, use the acceptor thread to read/parse incoming requests,
+        //else use one thread for accepting + AsioServerThreads threads to process established tcp connections
+ static size_t AsioServerThreads;
+
+        //use ACK to ensure the request has been sent completely (calls ioctl() to check that the output buffer is empty)
+        //a reliable check, but it can take too much time (about 40ms, see Wikipedia: TCP delayed acknowledgment)
+        //disabling this option reduces sending validation to having an established connection and all request data written to the socket buffer
+ static bool EnsureSendingCompleteByAck;
+
+ //listen socket queue limit
+ static int Backlog;
+
+        //deadline for receiving request data right after connect or while receiving request data
+ static TDuration ServerInputDeadline;
+
+        //time limit for sending response data
+ static TDuration ServerOutputDeadline;
+
+        //deadline for receiving a request on a keep-alive socket
+        //(Max if the soft limit has not been reached, Min if the hard limit has been reached)
+ static TDuration ServerInputDeadlineKeepAliveMax;
+ static TDuration ServerInputDeadlineKeepAliveMin;
+
+        //try writing data to the socket fd in the context of the handler->SendReply() call
+        //(instead of moving the write job to an asio thread)
+        //this reduces sys_cpu load (fewer sys calls), but increases user_cpu and response time
+ static bool ServerUseDirectWrite;
+
+ //use http response body as error message
+ static bool UseResponseAsErrorMessage;
+
+ //pass all http response headers as error message
+ static bool FullHeadersAsErrorMessage;
+
+ //use details (SendError argument) as response body
+ static bool ErrorDetailsAsResponseBody;
+
+ //consider responses with 3xx code as successful
+ static bool RedirectionNotError;
+
+ //consider response with any code as successful
+ static bool AnyResponseIsNotError;
+
+ //enable tcp keepalive for outgoing requests sockets
+ static bool TcpKeepAlive;
+
+        //limit on the number of requests per keep-alive connection
+ static i32 LimitRequestsPerConnection;
+
+ //enable TCP_QUICKACK
+ static bool QuickAck;
+
+ // enable write to socket via ScheduleOp
+ static bool UseAsyncSendRequest;
+
+        //set an option; returns false if the option name is not recognized
+ static bool Set(TStringBuf name, TStringBuf value);
+ };
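+    // Usage sketch (illustrative, values are arbitrary): the options are plain static fields,
+    // so they can be assigned directly before issuing requests, e.g.
+    //     NNeh::THttp2Options::ConnectTimeout = TDuration::MilliSeconds(300);
+    //     NNeh::THttp2Options::TcpKeepAlive = true;
+    // or set by name through Set(), which returns false for unrecognized names
+    // (assuming option names match the field names):
+    //     bool ok = NNeh::THttp2Options::Set("ConnectTimeout", "300ms");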
+
+    /// if the soft limit is exceeded, the number of unused connections in the cache is reduced
+ void SetHttp2OutputConnectionsLimits(size_t softLimit, size_t hardLimit);
+
+    /// if the soft limit is exceeded, the number of unused connections in the cache is reduced
+ void SetHttp2InputConnectionsLimits(size_t softLimit, size_t hardLimit);
+
+ /// for debug and monitoring purposes
+ TAtomicBase GetHttpOutputConnectionCount();
+ TAtomicBase GetHttpInputConnectionCount();
+ std::pair<size_t, size_t> GetHttpOutputConnectionLimits();
+
+    /// keepalive timeouts for unused input sockets
+    /// the real (used) timeout is:
+    /// - max, if the soft limit has not been reached
+    /// - min, if the hard limit has been reached
+    /// - changes approx. linearly within [max..min] while the connection count is in the range [soft..hard]
+ void SetHttp2InputConnectionsTimeouts(unsigned minSeconds, unsigned maxSeconds);
+}
diff --git a/library/cpp/neh/http_common.cpp b/library/cpp/neh/http_common.cpp
new file mode 100644
index 0000000000..7ae466c31a
--- /dev/null
+++ b/library/cpp/neh/http_common.cpp
@@ -0,0 +1,235 @@
+#include "http_common.h"
+
+#include "location.h"
+#include "http_headers.h"
+
+#include <util/generic/array_ref.h>
+#include <util/generic/singleton.h>
+#include <util/stream/length.h>
+#include <util/stream/null.h>
+#include <util/stream/str.h>
+#include <util/string/ascii.h>
+
+using NNeh::NHttp::ERequestType;
+
+namespace {
+ bool IsEmpty(const TStringBuf url) {
+ return url.empty();
+ }
+
+ void WriteImpl(const TStringBuf url, IOutputStream& out) {
+ out << url;
+ }
+
+ bool IsEmpty(const TConstArrayRef<TString> urlParts) {
+ return urlParts.empty();
+ }
+
+ void WriteImpl(const TConstArrayRef<TString> urlParts, IOutputStream& out) {
+ NNeh::NHttp::JoinUrlParts(urlParts, out);
+ }
+
+ template <typename T>
+ size_t GetLength(const T& urlParts) {
+ TCountingOutput out(&Cnull);
+ WriteImpl(urlParts, out);
+ return out.Counter();
+ }
+
+ template <typename T>
+ void WriteUrl(const T& urlParts, IOutputStream& out) {
+ if (!IsEmpty(urlParts)) {
+ out << '?';
+ WriteImpl(urlParts, out);
+ }
+ }
+}
+
+namespace NNeh {
+ namespace NHttp {
+ size_t GetUrlPartsLength(const TConstArrayRef<TString> urlParts) {
+ size_t res = 0;
+
+ for (const auto& u : urlParts) {
+ res += u.length();
+ }
+
+ if (urlParts.size() > 0) {
+ res += urlParts.size() - 1; //'&' between parts
+ }
+
+ return res;
+ }
+
+ void JoinUrlParts(const TConstArrayRef<TString> urlParts, IOutputStream& out) {
+ if (urlParts.empty()) {
+ return;
+ }
+
+ out << urlParts[0];
+
+ for (size_t i = 1; i < urlParts.size(); ++i) {
+ out << '&' << urlParts[i];
+ }
+ }
+
+ void WriteUrlParts(const TConstArrayRef<TString> urlParts, IOutputStream& out) {
+ WriteUrl(urlParts, out);
+ }
+ }
+}
+
+namespace {
+ const TStringBuf schemeHttps = "https";
+ const TStringBuf schemeHttp = "http";
+ const TStringBuf schemeHttp2 = "http2";
+ const TStringBuf schemePost = "post";
+ const TStringBuf schemePosts = "posts";
+ const TStringBuf schemePost2 = "post2";
+ const TStringBuf schemeFull = "full";
+ const TStringBuf schemeFulls = "fulls";
+ const TStringBuf schemeHttpUnix = "http+unix";
+ const TStringBuf schemePostUnix = "post+unix";
+
+ /*
+        @brief SafeWriteHeaders writes headers from hdrs to out with some checks:
+            - filters out Content-Length because we'll add it ourselves later.
+
+        @todo ensure headers are correctly formatted (currently headers received from perl reports may be badly formatted)
+ */
+ void SafeWriteHeaders(IOutputStream& out, TStringBuf hdrs) {
+ NNeh::NHttp::THeaderSplitter splitter(hdrs);
+ TStringBuf msgHdr;
+ while (splitter.Next(msgHdr)) {
+ if (!AsciiHasPrefixIgnoreCase(msgHdr, TStringBuf("Content-Length"))) {
+ out << msgHdr << TStringBuf("\r\n");
+ }
+ }
+ }
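+    // Illustrative example of the filtering above (hypothetical header values): for
+    // hdrs == "X-Req-Id: 1\r\nContent-Length: 7\r\n" only "X-Req-Id: 1\r\n" is written to out;
+    // the Content-Length header is dropped and re-added by BuildRequest() when there is content,
+    // based on the actual content size.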
+
+ template <typename T, typename W>
+ TString BuildRequest(const NNeh::TParsedLocation& loc, const T& urlParams, const TStringBuf headers, const W& content, const TStringBuf contentType, ERequestType requestType, NNeh::NHttp::ERequestFlags requestFlags) {
+ const bool isAbsoluteUri = requestFlags.HasFlags(NNeh::NHttp::ERequestFlag::AbsoluteUri);
+
+ const auto contentLength = GetLength(content);
+ TStringStream out;
+ out.Reserve(loc.Service.length() + loc.Host.length() + GetLength(urlParams) + headers.length() + contentType.length() + contentLength + (isAbsoluteUri ? (loc.Host.length() + 13) : 0) // 13 - is a max port number length + scheme length
+ + 96); //just some extra space
+
+ Y_ASSERT(requestType != ERequestType::Any);
+ out << requestType;
+ out << ' ';
+ if (isAbsoluteUri) {
+ out << loc.Scheme << TStringBuf("://") << loc.Host;
+ if (loc.Port) {
+ out << ':' << loc.Port;
+ }
+ }
+ out << '/' << loc.Service;
+
+ WriteUrl(urlParams, out);
+ out << TStringBuf(" HTTP/1.1\r\n");
+
+ NNeh::NHttp::WriteHostHeaderIfNot(out, loc.Host, loc.Port, headers);
+ SafeWriteHeaders(out, headers);
+ if (!IsEmpty(content)) {
+ if (!!contentType && headers.find(TStringBuf("Content-Type:")) == TString::npos) {
+ out << TStringBuf("Content-Type: ") << contentType << TStringBuf("\r\n");
+ }
+ out << TStringBuf("Content-Length: ") << contentLength << TStringBuf("\r\n");
+ out << TStringBuf("\r\n");
+ WriteImpl(content, out);
+ } else {
+ out << TStringBuf("\r\n");
+ }
+ return out.Str();
+ }
+
+ bool NeedGetRequestFor(TStringBuf scheme) {
+ return scheme == schemeHttp2 || scheme == schemeHttp || scheme == schemeHttps || scheme == schemeHttpUnix;
+ }
+
+ bool NeedPostRequestFor(TStringBuf scheme) {
+ return scheme == schemePost2 || scheme == schemePost || scheme == schemePosts || scheme == schemePostUnix;
+ }
+
+ inline ERequestType ChooseReqType(ERequestType userReqType, ERequestType defaultReqType) {
+ Y_ASSERT(defaultReqType != ERequestType::Any);
+ return userReqType != ERequestType::Any ? userReqType : defaultReqType;
+ }
+}
+
+namespace NNeh {
+ namespace NHttp {
+ const TStringBuf DefaultContentType = "application/x-www-form-urlencoded";
+
+ template <typename T>
+ bool MakeFullRequestImpl(TMessage& msg, const TStringBuf proxy, const T& urlParams, const TStringBuf headers, const TStringBuf content, const TStringBuf contentType, ERequestType reqType, ERequestFlags reqFlags) {
+ NNeh::TParsedLocation loc(msg.Addr);
+
+ if (content.size()) {
+ //content MUST be placed inside POST requests
+ if (!IsEmpty(urlParams)) {
+ if (NeedGetRequestFor(loc.Scheme)) {
+ msg.Data = BuildRequest(loc, urlParams, headers, content, contentType, ChooseReqType(reqType, ERequestType::Post), reqFlags);
+ } else {
+                        // we cannot place potentially unsafe data from a POST message into the first header line
+                        // (it may contain characters that are forbidden in a url path)
+                        // so we support such a mutation only for GET requests
+ return false;
+ }
+ } else {
+ if (NeedGetRequestFor(loc.Scheme) || NeedPostRequestFor(loc.Scheme)) {
+ msg.Data = BuildRequest(loc, urlParams, headers, content, contentType, ChooseReqType(reqType, ERequestType::Post), reqFlags);
+ } else {
+ return false;
+ }
+ }
+ } else {
+ if (NeedGetRequestFor(loc.Scheme)) {
+ msg.Data = BuildRequest(loc, urlParams, headers, "", "", ChooseReqType(reqType, ERequestType::Get), reqFlags);
+ } else if (NeedPostRequestFor(loc.Scheme)) {
+ msg.Data = BuildRequest(loc, TString(), headers, urlParams, contentType, ChooseReqType(reqType, ERequestType::Post), reqFlags);
+ } else {
+ return false;
+ }
+ }
+
+ if (proxy.IsInited()) {
+ loc = NNeh::TParsedLocation(proxy);
+ msg.Addr = proxy;
+ }
+
+ TString schemePostfix = "";
+ if (loc.Scheme.EndsWith("+unix")) {
+ schemePostfix = "+unix";
+ }
+
+ // ugly but still... https2 will break it :(
+ if ('s' == loc.Scheme[loc.Scheme.size() - 1]) {
+ msg.Addr.replace(0, loc.Scheme.size(), schemeFulls + schemePostfix);
+ } else {
+ msg.Addr.replace(0, loc.Scheme.size(), schemeFull + schemePostfix);
+ }
+
+ return true;
+ }
+
+ bool MakeFullRequest(TMessage& msg, const TStringBuf headers, const TStringBuf content, const TStringBuf contentType, ERequestType reqType, ERequestFlags reqFlags) {
+ return MakeFullRequestImpl(msg, {}, msg.Data, headers, content, contentType, reqType, reqFlags);
+ }
+
+ bool MakeFullRequest(TMessage& msg, const TConstArrayRef<TString> urlParts, const TStringBuf headers, const TStringBuf content, const TStringBuf contentType, ERequestType reqType, ERequestFlags reqFlags) {
+ return MakeFullRequestImpl(msg, {}, urlParts, headers, content, contentType, reqType, reqFlags);
+ }
+
+ bool MakeFullProxyRequest(TMessage& msg, TStringBuf proxyAddr, TStringBuf headers, TStringBuf content, TStringBuf contentType, ERequestType reqType, ERequestFlags flags) {
+ return MakeFullRequestImpl(msg, proxyAddr, msg.Data, headers, content, contentType, reqType, flags | ERequestFlag::AbsoluteUri);
+ }
+
+ bool IsHttpScheme(TStringBuf scheme) {
+ return NeedGetRequestFor(scheme) || NeedPostRequestFor(scheme);
+ }
+ }
+}
+
diff --git a/library/cpp/neh/http_common.h b/library/cpp/neh/http_common.h
new file mode 100644
index 0000000000..91b4ca1356
--- /dev/null
+++ b/library/cpp/neh/http_common.h
@@ -0,0 +1,305 @@
+#pragma once
+
+#include <util/generic/array_ref.h>
+#include <util/generic/flags.h>
+#include <util/generic/ptr.h>
+#include <util/generic/vector.h>
+#include <util/stream/mem.h>
+#include <util/stream/output.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+
+#include "location.h"
+#include "neh.h"
+#include "rpc.h"
+
+#include <atomic>
+
+//common primitives for http/http2
+
+namespace NNeh {
+ struct THttpErrorDetails {
+ TString Details = {};
+ TString Headers = {};
+ };
+
+ class IHttpRequest: public IRequest {
+ public:
+ using IRequest::SendReply;
+ virtual void SendReply(TData& data, const TString& headers, int httpCode = 200) = 0;
+ virtual const THttpHeaders& Headers() const = 0;
+ virtual TStringBuf Method() const = 0;
+ virtual TStringBuf Body() const = 0;
+ virtual TStringBuf Cgi() const = 0;
+ void SendError(TResponseError err, const TString& details = TString()) override final {
+ SendError(err, THttpErrorDetails{.Details = details});
+ }
+
+ virtual void SendError(TResponseError err, const THttpErrorDetails& details) = 0;
+ };
+
+ namespace NHttp {
+ enum class EResolverType {
+ ETCP = 0,
+ EUNIXSOCKET = 1
+ };
+
+ struct TFdLimits {
+ public:
+ TFdLimits()
+ : Soft_(10000)
+ , Hard_(15000)
+ {
+ }
+
+ TFdLimits(const TFdLimits& other) {
+ Soft_.store(other.Soft(), std::memory_order_release);
+ Hard_.store(other.Hard(), std::memory_order_release);
+ }
+
+ inline size_t Delta() const noexcept {
+ return ExceedLimit(Hard_.load(std::memory_order_acquire), Soft_.load(std::memory_order_acquire));
+ }
+
+ inline static size_t ExceedLimit(size_t val, size_t limit) noexcept {
+ return val > limit ? val - limit : 0;
+ }
+
+ void SetSoft(size_t value) noexcept {
+ Soft_.store(value, std::memory_order_release);
+ }
+
+ void SetHard(size_t value) noexcept {
+ Hard_.store(value, std::memory_order_release);
+ }
+
+ size_t Soft() const noexcept {
+ return Soft_.load(std::memory_order_acquire);
+ }
+
+ size_t Hard() const noexcept {
+ return Hard_.load(std::memory_order_acquire);
+ }
+
+ private:
+ std::atomic<size_t> Soft_;
+ std::atomic<size_t> Hard_;
+ };
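+        // Semantics sketch: ExceedLimit(v, l) == v - l when v > l and 0 otherwise, so
+        // Delta() == Hard - Soft (or 0 if Soft >= Hard). With the defaults above,
+        // ExceedLimit(12000, 10000) == 2000 and Delta() == 5000.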
+
+ template <class T>
+ class TLockFreeSequence {
+ public:
+ inline TLockFreeSequence() {
+ memset((void*)T_, 0, sizeof(T_));
+ }
+
+ inline ~TLockFreeSequence() {
+ for (size_t i = 0; i < Y_ARRAY_SIZE(T_); ++i) {
+ delete[] T_[i];
+ }
+ }
+
+ inline T& Get(size_t n) {
+ const size_t i = GetValueBitCount(n + 1) - 1;
+
+ return GetList(i)[n + 1 - (((size_t)1) << i)];
+ }
+
+ private:
+ inline T* GetList(size_t n) {
+ T* volatile* t = T_ + n;
+
+ T* result;
+ while (!(result = AtomicGet(*t))) {
+ TArrayHolder<T> nt(new T[((size_t)1) << n]);
+
+ if (AtomicCas(t, nt.Get(), nullptr)) {
+ return nt.Release();
+ }
+ }
+
+ return result;
+ }
+
+ private:
+ T* volatile T_[sizeof(size_t) * 8];
+ };
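+        // Index layout sketch for TLockFreeSequence::Get(): element n lives in block
+        // i = GetValueBitCount(n + 1) - 1 at offset n + 1 - 2^i, so blocks have sizes 1, 2, 4, 8, ...
+        // e.g. Get(0) -> block 0 slot 0, Get(1)/Get(2) -> block 1, Get(3)..Get(6) -> block 2;
+        // blocks are allocated lazily and published with a CAS.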
+
+ class TRequestData: public TNonCopyable {
+ public:
+ using TPtr = TAutoPtr<TRequestData>;
+ using TParts = TVector<IOutputStream::TPart>;
+
+ inline TRequestData(size_t memSize)
+ : Mem(memSize)
+ {
+ }
+
+ inline void SendTo(IOutputStream& io) const {
+ io.Write(Parts_.data(), Parts_.size());
+ }
+
+ inline void AddPart(const void* buf, size_t len) noexcept {
+ Parts_.push_back(IOutputStream::TPart(buf, len));
+ }
+
+ const TParts& Parts() const noexcept {
+ return Parts_;
+ }
+
+ TVector<char> Mem;
+
+ private:
+ TParts Parts_;
+ };
+
+ struct TRequestSettings {
+ bool NoDelay = true;
+ EResolverType ResolverType = EResolverType::ETCP;
+ bool UseAsyncSendRequest = false;
+
+ TRequestSettings& SetNoDelay(bool noDelay) {
+ NoDelay = noDelay;
+ return *this;
+ }
+
+ TRequestSettings& SetResolverType(EResolverType resolverType) {
+ ResolverType = resolverType;
+ return *this;
+ }
+
+ TRequestSettings& SetUseAsyncSendRequest(bool useAsyncSendRequest) {
+ UseAsyncSendRequest = useAsyncSendRequest;
+ return *this;
+ }
+ };
+
+ struct TRequestGet {
+ static TRequestData::TPtr Build(const TMessage& msg, const TParsedLocation& loc) {
+ TRequestData::TPtr req(new TRequestData(50 + loc.Service.size() + msg.Data.size() + loc.Host.size()));
+ TMemoryOutput out(req->Mem.data(), req->Mem.size());
+
+ out << TStringBuf("GET /") << loc.Service;
+
+ if (!!msg.Data) {
+ out << '?' << msg.Data;
+ }
+
+ out << TStringBuf(" HTTP/1.1\r\nHost: ") << loc.Host;
+
+ if (!!loc.Port) {
+ out << TStringBuf(":") << loc.Port;
+ }
+
+ out << TStringBuf("\r\n\r\n");
+
+ req->AddPart(req->Mem.data(), out.Buf() - req->Mem.data());
+ return req;
+ }
+
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("http");
+ }
+
+ static TRequestSettings RequestSettings() {
+ return TRequestSettings();
+ }
+ };
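+        // Illustrative result of TRequestGet::Build() (hypothetical location and data):
+        // for msg.Data == "text=hi" and a location parsed from "http://example.net:8080/search"
+        // the request is "GET /search?text=hi HTTP/1.1\r\nHost: example.net:8080\r\n\r\n".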
+
+ struct TRequestPost {
+ static TRequestData::TPtr Build(const TMessage& msg, const TParsedLocation& loc) {
+ TRequestData::TPtr req(new TRequestData(100 + loc.Service.size() + loc.Host.size()));
+ TMemoryOutput out(req->Mem.data(), req->Mem.size());
+
+ out << TStringBuf("POST /") << loc.Service
+ << TStringBuf(" HTTP/1.1\r\nHost: ") << loc.Host;
+
+ if (!!loc.Port) {
+ out << TStringBuf(":") << loc.Port;
+ }
+
+ out << TStringBuf("\r\nContent-Length: ") << msg.Data.size() << TStringBuf("\r\n\r\n");
+
+ req->AddPart(req->Mem.data(), out.Buf() - req->Mem.data());
+ req->AddPart(msg.Data.data(), msg.Data.size());
+ return req;
+ }
+
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("post");
+ }
+
+ static TRequestSettings RequestSettings() {
+ return TRequestSettings();
+ }
+ };
+
+ struct TRequestFull {
+ static TRequestData::TPtr Build(const TMessage& msg, const TParsedLocation&) {
+ TRequestData::TPtr req(new TRequestData(0));
+ req->AddPart(msg.Data.data(), msg.Data.size());
+ return req;
+ }
+
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("full");
+ }
+
+ static TRequestSettings RequestSettings() {
+ return TRequestSettings();
+ }
+ };
+
+ enum class ERequestType {
+ Any = 0 /* "ANY" */,
+ Post /* "POST" */,
+ Get /* "GET" */,
+ Put /* "PUT" */,
+ Delete /* "DELETE" */,
+ Patch /* "PATCH" */,
+ };
+
+ enum class ERequestFlag {
+ None = 0,
+            /** use absolute uri for proxy requests in the first request line
+ * POST http://ya.ru HTTP/1.1
+ * @see https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2
+ */
+ AbsoluteUri = 1,
+ };
+
+ Y_DECLARE_FLAGS(ERequestFlags, ERequestFlag)
+ Y_DECLARE_OPERATORS_FOR_FLAGS(ERequestFlags)
+
+ static constexpr ERequestType DefaultRequestType = ERequestType::Any;
+
+ extern const TStringBuf DefaultContentType;
+
+ /// @brief `MakeFullRequest` transmutes http/post/http2/post2 message to full/full2 with
+ /// additional HTTP headers and/or content data.
+ ///
+        /// If reqType is `Any`, the request type is POST, unless the content is empty and the scheme
+        /// prefix is http/https/http2, in which case the request type is GET.
+ ///
+ /// @msg[in] Will get URL from `msg.Data`.
+ bool MakeFullRequest(TMessage& msg, TStringBuf headers, TStringBuf content, TStringBuf contentType = DefaultContentType, ERequestType reqType = DefaultRequestType, ERequestFlags flags = ERequestFlag::None);
+
+        /// @see `MakeFullRequest`.
+ ///
+ /// @urlParts[in] Will construct url from `urlParts`, `msg.Data` is not used.
+ bool MakeFullRequest(TMessage& msg, TConstArrayRef<TString> urlParts, TStringBuf headers, TStringBuf content, TStringBuf contentType = DefaultContentType, ERequestType reqType = DefaultRequestType, ERequestFlags flags = ERequestFlag::None);
+
+ /// Same as `MakeFullRequest` but it will add ERequestFlag::AbsoluteUri to the @a flags
+ /// and replace msg.Addr with @a proxyAddr
+ ///
+        /// @see `MakeFullRequest`.
+ bool MakeFullProxyRequest(TMessage& msg, TStringBuf proxyAddr, TStringBuf headers, TStringBuf content, TStringBuf contentType = DefaultContentType, ERequestType reqType = DefaultRequestType, ERequestFlags flags = ERequestFlag::None);
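+        // Usage sketch (hypothetical address and params): given
+        //     NNeh::TMessage msg;
+        //     msg.Addr = "http://example.net:8080/search";
+        //     msg.Data = "text=hi";
+        //     NNeh::NHttp::MakeFullRequest(msg, /*headers=*/"X-Req-Id: 1\r\n", /*content=*/"");
+        // msg.Addr becomes "full://example.net:8080/search" and msg.Data holds the raw
+        // "GET /search?text=hi HTTP/1.1\r\n..." request that the full/full2 protocols send as-is.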
+
+ size_t GetUrlPartsLength(TConstArrayRef<TString> urlParts);
+ //part1&part2&...
+ void JoinUrlParts(TConstArrayRef<TString> urlParts, IOutputStream& out);
+ //'?' + JoinUrlParts
+ void WriteUrlParts(TConstArrayRef<TString> urlParts, IOutputStream& out);
+
+ bool IsHttpScheme(TStringBuf scheme);
+ }
+}
diff --git a/library/cpp/neh/http_headers.cpp b/library/cpp/neh/http_headers.cpp
new file mode 100644
index 0000000000..79dc823e7e
--- /dev/null
+++ b/library/cpp/neh/http_headers.cpp
@@ -0,0 +1 @@
+#include "http_headers.h"
diff --git a/library/cpp/neh/http_headers.h b/library/cpp/neh/http_headers.h
new file mode 100644
index 0000000000..70cf3a9fbe
--- /dev/null
+++ b/library/cpp/neh/http_headers.h
@@ -0,0 +1,55 @@
+#pragma once
+
+#include <util/generic/strbuf.h>
+#include <util/stream/output.h>
+#include <util/string/ascii.h>
+
+namespace NNeh {
+ namespace NHttp {
+ template <typename Port>
+ void WriteHostHeader(IOutputStream& out, TStringBuf host, Port port) {
+ out << TStringBuf("Host: ") << host;
+ if (port) {
+ out << TStringBuf(":") << port;
+ }
+ out << TStringBuf("\r\n");
+ }
+
+ class THeaderSplitter {
+ public:
+ THeaderSplitter(TStringBuf headers)
+ : Headers_(headers)
+ {
+ }
+
+ bool Next(TStringBuf& header) {
+ while (Headers_.ReadLine(header)) {
+ if (!header.Empty()) {
+ return true;
+ }
+ }
+ return false;
+ }
+ private:
+ TStringBuf Headers_;
+ };
+
+ inline bool HasHostHeader(TStringBuf headers) {
+ THeaderSplitter splitter(headers);
+ TStringBuf header;
+ while (splitter.Next(header)) {
+ if (AsciiHasPrefixIgnoreCase(header, "Host:")) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ template <typename Port>
+ void WriteHostHeaderIfNot(IOutputStream& out, TStringBuf host, Port port, TStringBuf headers) {
+ if (!NNeh::NHttp::HasHostHeader(headers)) {
+ NNeh::NHttp::WriteHostHeader(out, host, port);
+ }
+ }
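+        // Illustrative behaviour (hypothetical values): WriteHostHeaderIfNot(out, "example.net", 443, "Accept: */*\r\n")
+        // appends "Host: example.net:443\r\n", while a headers string that already contains "Host: ..."
+        // (matched case-insensitively by HasHostHeader) leaves the output untouched.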
+ }
+}
diff --git a/library/cpp/neh/https.cpp b/library/cpp/neh/https.cpp
new file mode 100644
index 0000000000..d0e150e778
--- /dev/null
+++ b/library/cpp/neh/https.cpp
@@ -0,0 +1,1936 @@
+#include "https.h"
+
+#include "details.h"
+#include "factory.h"
+#include "http_common.h"
+#include "jobqueue.h"
+#include "location.h"
+#include "multi.h"
+#include "pipequeue.h"
+#include "utils.h"
+
+#include <contrib/libs/openssl/include/openssl/ssl.h>
+#include <contrib/libs/openssl/include/openssl/err.h>
+#include <contrib/libs/openssl/include/openssl/bio.h>
+#include <contrib/libs/openssl/include/openssl/x509v3.h>
+
+#include <library/cpp/openssl/init/init.h>
+#include <library/cpp/openssl/method/io.h>
+#include <library/cpp/coroutine/listener/listen.h>
+#include <library/cpp/dns/cache.h>
+#include <library/cpp/http/misc/parsed_request.h>
+#include <library/cpp/http/misc/httpcodes.h>
+#include <library/cpp/http/io/stream.h>
+
+#include <util/generic/cast.h>
+#include <util/generic/list.h>
+#include <util/generic/utility.h>
+#include <util/network/socket.h>
+#include <util/stream/str.h>
+#include <util/stream/zlib.h>
+#include <util/string/builder.h>
+#include <util/string/cast.h>
+#include <util/system/condvar.h>
+#include <util/system/error.h>
+#include <util/system/types.h>
+#include <util/thread/factory.h>
+
+#include <atomic>
+
+#if defined(_unix_)
+#include <sys/ioctl.h>
+#endif
+
+#if defined(_linux_)
+#undef SIOCGSTAMP
+#undef SIOCGSTAMPNS
+#include <linux/sockios.h>
+#define FIONWRITE SIOCOUTQ
+#endif
+
+using namespace NDns;
+using namespace NAddr;
+
+namespace NNeh {
+ TString THttpsOptions::CAFile;
+ TString THttpsOptions::CAPath;
+ TString THttpsOptions::ClientCertificate;
+ TString THttpsOptions::ClientPrivateKey;
+ TString THttpsOptions::ClientPrivateKeyPassword;
+ bool THttpsOptions::EnableSslServerDebug = false;
+ bool THttpsOptions::EnableSslClientDebug = false;
+ bool THttpsOptions::CheckCertificateHostname = false;
+ THttpsOptions::TVerifyCallback THttpsOptions::ClientVerifyCallback = nullptr;
+ THttpsOptions::TPasswordCallback THttpsOptions::KeyPasswdCallback = nullptr;
+ bool THttpsOptions::RedirectionNotError = false;
+
+ bool THttpsOptions::Set(TStringBuf name, TStringBuf value) {
+#define YNDX_NEH_HTTPS_TRY_SET(optName) \
+ if (name == TStringBuf(#optName)) { \
+ optName = FromString<decltype(optName)>(value); \
+ return true; \
+ }
+
+ YNDX_NEH_HTTPS_TRY_SET(CAFile);
+ YNDX_NEH_HTTPS_TRY_SET(CAPath);
+ YNDX_NEH_HTTPS_TRY_SET(ClientCertificate);
+ YNDX_NEH_HTTPS_TRY_SET(ClientPrivateKey);
+ YNDX_NEH_HTTPS_TRY_SET(ClientPrivateKeyPassword);
+ YNDX_NEH_HTTPS_TRY_SET(EnableSslServerDebug);
+ YNDX_NEH_HTTPS_TRY_SET(EnableSslClientDebug);
+ YNDX_NEH_HTTPS_TRY_SET(CheckCertificateHostname);
+ YNDX_NEH_HTTPS_TRY_SET(RedirectionNotError);
+
+#undef YNDX_NEH_HTTPS_TRY_SET
+
+ return false;
+ }
+}
+
+namespace NNeh {
+ namespace NHttps {
+ namespace {
+            // force ssl_write/ssl_read functions to return this value via BIO_method_read/write to signal that the request has been canceled
+ constexpr int SSL_RVAL_TIMEOUT = -42;
+
+ struct TInputConnections {
+ TInputConnections()
+ : Counter(0)
+ , MaxUnusedConnKeepaliveTimeout(120)
+ , MinUnusedConnKeepaliveTimeout(10)
+ {
+ }
+
+ inline size_t ExceedSoftLimit() const noexcept {
+ return NHttp::TFdLimits::ExceedLimit(Counter.Val(), Limits.Soft());
+ }
+
+ inline size_t ExceedHardLimit() const noexcept {
+ return NHttp::TFdLimits::ExceedLimit(Counter.Val(), Limits.Hard());
+ }
+
+ inline size_t DeltaLimit() const noexcept {
+ return Limits.Delta();
+ }
+
+ unsigned UnusedConnKeepaliveTimeout() const {
+ if (size_t e = ExceedSoftLimit()) {
+ size_t d = DeltaLimit();
+ size_t leftAvailableFd = NHttp::TFdLimits::ExceedLimit(d, e);
+ unsigned r = static_cast<unsigned>(MaxUnusedConnKeepaliveTimeout.load(std::memory_order_acquire) * leftAvailableFd / (d + 1));
+ return Max(r, (unsigned)MinUnusedConnKeepaliveTimeout.load(std::memory_order_acquire));
+ }
+ return MaxUnusedConnKeepaliveTimeout.load(std::memory_order_acquire);
+ }
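+                // Worked example of the scaling above (hypothetical connection count): with the constructor
+                // defaults Max = 120s, Min = 10s and, say, soft = 10000, hard = 15000 and 12000 open input
+                // connections: e = 2000, d = 5000, leftAvailableFd = 3000, so the unused-connection keepalive
+                // timeout is max(120 * 3000 / 5001, 10) = 71 seconds.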
+
+ void SetFdLimits(size_t soft, size_t hard) {
+ Limits.SetSoft(soft);
+ Limits.SetHard(hard);
+ }
+
+ NHttp::TFdLimits Limits;
+ TAtomicCounter Counter;
+ std::atomic<unsigned> MaxUnusedConnKeepaliveTimeout; //in seconds
+ std::atomic<unsigned> MinUnusedConnKeepaliveTimeout; //in seconds
+ };
+
+ TInputConnections* InputConnections() {
+ return Singleton<TInputConnections>();
+ }
+
+ struct TSharedSocket: public TSocketHolder, public TAtomicRefCount<TSharedSocket> {
+ inline TSharedSocket(TSocketHolder& s)
+ : TSocketHolder(s.Release())
+ {
+ InputConnections()->Counter.Inc();
+ }
+
+ ~TSharedSocket() {
+ InputConnections()->Counter.Dec();
+ }
+ };
+
+ using TSocketRef = TIntrusivePtr<TSharedSocket>;
+
+ struct TX509Deleter {
+ static void Destroy(X509* cert) {
+ X509_free(cert);
+ }
+ };
+ using TX509Holder = THolder<X509, TX509Deleter>;
+
+ struct TSslSessionDeleter {
+ static void Destroy(SSL_SESSION* sess) {
+ SSL_SESSION_free(sess);
+ }
+ };
+ using TSslSessionHolder = THolder<SSL_SESSION, TSslSessionDeleter>;
+
+ struct TSslDeleter {
+ static void Destroy(SSL* ssl) {
+ SSL_free(ssl);
+ }
+ };
+ using TSslHolder = THolder<SSL, TSslDeleter>;
+
+ // read from bio and write via operator<<() to dst
+ template <typename T>
+ class TBIOInput : public NOpenSSL::TAbstractIO {
+ public:
+ TBIOInput(T& dst)
+ : Dst_(dst)
+ {
+ }
+
+ int Write(const char* data, size_t dlen, size_t* written) override {
+ Dst_ << TStringBuf(data, dlen);
+ *written = dlen;
+ return 1;
+ }
+
+ int Read(char* data, size_t dlen, size_t* readbytes) override {
+ Y_UNUSED(data);
+ Y_UNUSED(dlen);
+ Y_UNUSED(readbytes);
+ return -1;
+ }
+
+ int Puts(const char* buf) override {
+ Y_UNUSED(buf);
+ return -1;
+ }
+
+ int Gets(char* buf, int len) override {
+ Y_UNUSED(buf);
+ Y_UNUSED(len);
+ return -1;
+ }
+
+ void Flush() override {
+ }
+
+ private:
+ T& Dst_;
+ };
+ }
+
+ class TSslException: public yexception {
+ public:
+ TSslException() = default;
+
+ TSslException(TStringBuf f) {
+ *this << f << Endl;
+ InitErr();
+ }
+
+ TSslException(TStringBuf f, const SSL* ssl, int ret) {
+ *this << f << TStringBuf(" error type: ");
+ const int etype = SSL_get_error(ssl, ret);
+ switch (etype) {
+ case SSL_ERROR_ZERO_RETURN:
+ *this << TStringBuf("SSL_ERROR_ZERO_RETURN");
+ break;
+ case SSL_ERROR_WANT_READ:
+ *this << TStringBuf("SSL_ERROR_WANT_READ");
+ break;
+ case SSL_ERROR_WANT_WRITE:
+ *this << TStringBuf("SSL_ERROR_WANT_WRITE");
+ break;
+ case SSL_ERROR_WANT_CONNECT:
+ *this << TStringBuf("SSL_ERROR_WANT_CONNECT");
+ break;
+ case SSL_ERROR_WANT_ACCEPT:
+ *this << TStringBuf("SSL_ERROR_WANT_ACCEPT");
+ break;
+ case SSL_ERROR_WANT_X509_LOOKUP:
+ *this << TStringBuf("SSL_ERROR_WANT_X509_LOOKUP");
+ break;
+ case SSL_ERROR_SYSCALL:
+ *this << TStringBuf("SSL_ERROR_SYSCALL ret: ") << ret << TStringBuf(", errno: ") << errno;
+ break;
+ case SSL_ERROR_SSL:
+ *this << TStringBuf("SSL_ERROR_SSL");
+ break;
+ }
+ *this << ' ';
+ InitErr();
+ }
+
+ private:
+ void InitErr() {
+ TBIOInput<TSslException> bio(*this);
+ ERR_print_errors(bio);
+ }
+ };
+
+ namespace {
+ enum EMatchResult {
+ MATCH_FOUND,
+ NO_MATCH,
+ NO_EXTENSION,
+ ERROR
+ };
+ bool EqualNoCase(TStringBuf a, TStringBuf b) {
+ return (a.size() == b.size()) && ToString(a).to_lower() == ToString(b).to_lower();
+ }
+ bool MatchDomainName(TStringBuf tmpl, TStringBuf name) {
+ // match wildcards only in the left-most part
+ // do not support (optional according to RFC) partial wildcards (ww*.yandex.ru)
+ // see RFC-6125
+ TStringBuf tmplRest = tmpl;
+ TStringBuf tmplFirst = tmplRest.NextTok('.');
+ if (tmplFirst == "*") {
+ tmpl = tmplRest;
+ name.NextTok('.');
+ }
+ return EqualNoCase(tmpl, name);
+ }
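+            // Illustrative cases (hypothetical names): MatchDomainName("*.example.net", "mail.example.net") -> true,
+            // MatchDomainName("*.example.net", "example.net") -> false (the wildcard consumes exactly one label),
+            // and "ww*.example.net" style partial wildcards are deliberately not supported.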
+
+ EMatchResult MatchCertAltNames(X509* cert, TStringBuf hostname) {
+ EMatchResult result = NO_MATCH;
+ STACK_OF(GENERAL_NAME)* names = (STACK_OF(GENERAL_NAME)*)X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, NULL);
+ if (!names) {
+ return NO_EXTENSION;
+ }
+
+ int namesCt = sk_GENERAL_NAME_num(names);
+ for (int i = 0; i < namesCt; ++i) {
+ const GENERAL_NAME* name = sk_GENERAL_NAME_value(names, i);
+
+ if (name->type == GEN_DNS) {
+ TStringBuf dnsName((const char*)ASN1_STRING_get0_data(name->d.dNSName), ASN1_STRING_length(name->d.dNSName));
+ if (MatchDomainName(dnsName, hostname)) {
+ result = MATCH_FOUND;
+ break;
+ }
+ }
+ }
+ sk_GENERAL_NAME_pop_free(names, GENERAL_NAME_free);
+ return result;
+ }
+
+ EMatchResult MatchCertCommonName(X509* cert, TStringBuf hostname) {
+ int commonNameLoc = X509_NAME_get_index_by_NID(X509_get_subject_name(cert), NID_commonName, -1);
+ if (commonNameLoc < 0) {
+ return ERROR;
+ }
+
+ X509_NAME_ENTRY* commonNameEntry = X509_NAME_get_entry(X509_get_subject_name(cert), commonNameLoc);
+ if (!commonNameEntry) {
+ return ERROR;
+ }
+
+ ASN1_STRING* commonNameAsn1 = X509_NAME_ENTRY_get_data(commonNameEntry);
+ if (!commonNameAsn1) {
+ return ERROR;
+ }
+
+ TStringBuf commonName((const char*)ASN1_STRING_get0_data(commonNameAsn1), ASN1_STRING_length(commonNameAsn1));
+
+ return MatchDomainName(commonName, hostname)
+ ? MATCH_FOUND
+ : NO_MATCH;
+ }
+
+ bool CheckCertHostname(X509* cert, TStringBuf hostname) {
+ switch (MatchCertAltNames(cert, hostname)) {
+ case MATCH_FOUND:
+ return true;
+ break;
+ case NO_EXTENSION:
+ return MatchCertCommonName(cert, hostname) == MATCH_FOUND;
+ break;
+ default:
+ return false;
+ }
+ }
+
+ void ParseUserInfo(const TParsedLocation& loc, TString& cert, TString& pvtKey) {
+ if (!loc.UserInfo) {
+ return;
+ }
+
+ TStringBuf kws = loc.UserInfo;
+ while (kws) {
+ TStringBuf name = kws.NextTok('=');
+ TStringBuf value = kws.NextTok(';');
+ if (TStringBuf("cert") == name) {
+ cert = value;
+ } else if (TStringBuf("key") == name) {
+ pvtKey = value;
+ }
+ }
+ }
+
+ struct TSSLInit {
+ inline TSSLInit() {
+ InitOpenSSL();
+ }
+ } SSL_INIT;
+ }
+
+ static inline void PrepareSocket(SOCKET s) {
+ SetNoDelay(s, true);
+ }
+
+ class TConnCache;
+ static TConnCache* SocketCache();
+
+ class TConnCache: public IThreadFactory::IThreadAble {
+ public:
+ typedef TAutoLockFreeQueue<TSocketHolder> TConnList;
+ typedef TAutoPtr<TSocketHolder> TSocketRef;
+
+ struct TConnection {
+ inline TConnection(TSocketRef& s, bool reUsed, const TResolvedHost* host) noexcept
+ : Socket(s)
+ , ReUsed(reUsed)
+ , Host(host)
+ {
+ SocketCache()->ActiveSockets.Inc();
+ }
+
+ inline ~TConnection() {
+ if (!!Socket) {
+ SocketCache()->ActiveSockets.Dec();
+ }
+ }
+
+ SOCKET Fd() {
+ return *Socket;
+ }
+
+ protected:
+ friend class TConnCache;
+ TSocketRef Socket;
+
+ public:
+ const bool ReUsed;
+ const TResolvedHost* Host;
+ };
+
+ TConnCache()
+ : InPurging_(0)
+ , MaxConnId_(0)
+ , Shutdown_(false)
+ {
+ T_ = SystemThreadFactory()->Run(this);
+ }
+
+ ~TConnCache() override {
+ {
+ TGuard<TMutex> g(PurgeMutex_);
+
+ Shutdown_ = true;
+ CondPurge_.Signal();
+ }
+
+ T_->Join();
+ }
+
+            //used for forward-filling the cache (creating connections ahead of demand)
+ class TConnector: public IJob {
+ public:
+ //create fresh connection
+ TConnector(const TResolvedHost* host)
+ : Host_(host)
+ {
+ }
+
+                //continue connecting an existing socket
+ TConnector(const TResolvedHost* host, TSocketRef& s)
+ : Host_(host)
+ , S_(s)
+ {
+ }
+
+ void DoRun(TCont* c) override {
+ THolder<TConnector> This(this);
+
+ try {
+ if (!S_) {
+ TSocketRef res(new TSocketHolder());
+
+ for (TNetworkAddress::TIterator it = Host_->Addr.Begin(); it != Host_->Addr.End(); ++it) {
+ int ret = NCoro::ConnectD(c, *res, *it, TDuration::MilliSeconds(300).ToDeadLine());
+
+ if (!ret) {
+ TConnection tc(res, false, Host_);
+ SocketCache()->Release(tc);
+ return;
+ }
+
+ if (ret == ECANCELED) {
+ return;
+ }
+ }
+ } else {
+ if (!NCoro::PollT(c, *S_, CONT_POLL_WRITE, TDuration::MilliSeconds(300))) {
+ TConnection tc(S_, false, Host_);
+ SocketCache()->Release(tc);
+ }
+ }
+ } catch (...) {
+ }
+ }
+
+ private:
+ const TResolvedHost* Host_;
+ TSocketRef S_;
+ };
+
+ TConnection* Connect(TCont* c, const TString& msgAddr, const TResolvedHost* addr, TErrorRef* error) {
+ if (ExceedHardLimit()) {
+ if (error) {
+ *error = new TError("neh::https output connections limit reached", TError::TType::UnknownType);
+ }
+ return nullptr;
+ }
+
+ TSocketRef res;
+ TConnList& connList = ConnList(addr);
+
+ while (connList.Dequeue(&res)) {
+ CachedSockets.Dec();
+
+ if (IsNotSocketClosedByOtherSide(*res)) {
+ if (connList.Size() == 0) {
+                            //available connections exhausted - try to create one more (as a reserve)
+ TAutoPtr<IJob> job(new TConnector(addr));
+
+ if (c) {
+ try {
+ c->Executor()->Create(*job, "https-con");
+ Y_UNUSED(job.Release());
+ } catch (...) {
+ }
+ } else {
+ JobQueue()->Schedule(job);
+ }
+ }
+ return new TConnection(res, true, addr);
+ }
+ }
+
+ if (!c) {
+ if (error) {
+                        *error = new TError("direct connection failed");
+ }
+ return nullptr;
+ }
+
+ try {
+                //start a reserve/concurrent connection attempt
+ TAutoPtr<IJob> job(new TConnector(addr));
+
+ c->Executor()->Create(*job, "https-con");
+ Y_UNUSED(job.Release());
+ } catch (...) {
+ }
+
+ TNetworkAddress::TIterator ait = addr->Addr.Begin();
+ res.Reset(new TSocketHolder(NCoro::Socket(*ait)));
+ const TInstant now(TInstant::Now());
+ const TInstant deadline(now + TDuration::Seconds(10));
+ TDuration delay = TDuration::MilliSeconds(8);
+ TInstant checkpoint = Min(deadline, now + delay);
+ int ret = NCoro::ConnectD(c, *res, ait->ai_addr, ait->ai_addrlen, checkpoint);
+
+ if (ret) {
+ do {
+ if ((ret == ETIMEDOUT || ret == EINTR) && checkpoint < deadline) {
+ delay += delay;
+ checkpoint = Min(deadline, now + delay);
+
+ TSocketRef res2;
+
+ if (connList.Dequeue(&res2)) {
+ CachedSockets.Dec();
+
+ if (IsNotSocketClosedByOtherSide(*res2)) {
+ try {
+ TAutoPtr<IJob> job(new TConnector(addr, res));
+
+ c->Executor()->Create(*job, "https-con");
+ Y_UNUSED(job.Release());
+ } catch (...) {
+ }
+
+ res = res2;
+
+ break;
+ }
+ }
+ } else {
+ if (error) {
+ *error = new TError(TStringBuilder() << TStringBuf("can not connect to ") << msgAddr);
+ }
+ return nullptr;
+ }
+ } while (ret = NCoro::PollD(c, *res, CONT_POLL_WRITE, checkpoint));
+ }
+
+ PrepareSocket(*res);
+
+ return new TConnection(res, false, addr);
+ }
+
+ inline void Release(TConnection& conn) {
+ if (!ExceedHardLimit()) {
+ size_t maxConnId = MaxConnId_.load(std::memory_order_acquire);
+
+ while (maxConnId < conn.Host->Id) {
+ MaxConnId_.compare_exchange_strong(
+ maxConnId,
+ conn.Host->Id,
+ std::memory_order_seq_cst,
+ std::memory_order_seq_cst);
+ maxConnId = MaxConnId_.load(std::memory_order_acquire);
+ }
+
+ CachedSockets.Inc();
+ ActiveSockets.Dec();
+
+ ConnList(conn.Host).Enqueue(conn.Socket);
+ }
+
+ if (CachedSockets.Val() && ExceedSoftLimit()) {
+ SuggestPurgeCache();
+ }
+ }
+
+ void SetFdLimits(size_t soft, size_t hard) {
+ Limits.SetSoft(soft);
+ Limits.SetHard(hard);
+ }
+
+ private:
+ void SuggestPurgeCache() {
+ if (AtomicTryLock(&InPurging_)) {
+ //evaluate the usefulness of purging the cache
+                    //if there are few connections in the cache (< MaxConnId_/16 or 64), do not purge the cache
+ if ((size_t)CachedSockets.Val() > (Min((size_t)MaxConnId_.load(std::memory_order_acquire), (size_t)1024U) >> 4)) {
+                        //as we approach the hard limit, the need to purge the cache approaches 100%
+ size_t closenessToHardLimit256 = ((ActiveSockets.Val() + 1) << 8) / (Limits.Delta() + 1);
+                        //the more connections sit in the cache rather than being in use, the less the cache is needed (so it can be purged)
+ size_t cacheUselessness256 = ((CachedSockets.Val() + 1) << 8) / (ActiveSockets.Val() + 1);
+
+                        //resulting trigger thresholds:
+                        //when the soft limit is reached, if connections sit in the cache rather than being in use
+                        //halfway from the soft limit to the hard limit, if more than half of the connections are in the cache
+                        //when approaching the hard limit, try to purge the cache almost constantly
+ if ((closenessToHardLimit256 + cacheUselessness256) >= 256U) {
+ TGuard<TMutex> g(PurgeMutex_);
+
+ CondPurge_.Signal();
+ return; //memo: thread MUST unlock InPurging_ (see DoExecute())
+ }
+ }
+ AtomicUnlock(&InPurging_);
+ }
+ }
+
+ void DoExecute() override {
+ while (true) {
+ {
+ TGuard<TMutex> g(PurgeMutex_);
+
+ if (Shutdown_)
+ return;
+
+ CondPurge_.WaitI(PurgeMutex_);
+ }
+
+ PurgeCache();
+
+ AtomicUnlock(&InPurging_);
+ }
+ }
+
+ inline void OnPurgeSocket(ui64& processed) {
+ CachedSockets.Dec();
+ if ((processed++ & 0x3f) == 0) {
+                    //suspend execution every 64 processed sockets (cleanup rate ~= 6400 sockets/sec)
+ Sleep(TDuration::MilliSeconds(10));
+ }
+ }
+
+ void PurgeCache() noexcept {
+                //try to remove at least ExceedSoftLimit() of the oldest connections from the cache
+                //compute the fraction of the cache that needs to be purged (in 256ths, but no less than 1/32 of the cache)
+ size_t frac256 = Min(size_t(Max(size_t(256U / 32U), (ExceedSoftLimit() << 8) / (CachedSockets.Val() + 1))), (size_t)256U);
+ TSocketRef tmp;
+
+ ui64 processed = 0;
+ for (size_t i = 0; i < MaxConnId_.load(std::memory_order_acquire) && !Shutdown_; i++) {
+ TConnList& tc = Lst_.Get(i);
+ if (size_t qsize = tc.Size()) {
+                        //purge the computed fraction in each queue
+ size_t purgeCounter = ((qsize * frac256) >> 8);
+
+ if (!purgeCounter && qsize) {
+ if (qsize <= 2) {
+ TSocketRef res;
+ if (tc.Dequeue(&res)) {
+ if (IsNotSocketClosedByOtherSide(*res)) {
+ tc.Enqueue(res);
+ } else {
+ OnPurgeSocket(processed);
+ }
+ }
+ } else {
+ purgeCounter = 1;
+ }
+ }
+ while (purgeCounter-- && tc.Dequeue(&tmp)) {
+ OnPurgeSocket(processed);
+ }
+ }
+ }
+ }
+
+ inline TConnList& ConnList(const TResolvedHost* addr) {
+ return Lst_.Get(addr->Id);
+ }
+
+ inline size_t TotalSockets() const noexcept {
+ return ActiveSockets.Val() + CachedSockets.Val();
+ }
+
+ inline size_t ExceedSoftLimit() const noexcept {
+ return NHttp::TFdLimits::ExceedLimit(TotalSockets(), Limits.Soft());
+ }
+
+ inline size_t ExceedHardLimit() const noexcept {
+ return NHttp::TFdLimits::ExceedLimit(TotalSockets(), Limits.Hard());
+ }
+
+ NHttp::TFdLimits Limits;
+ TAtomicCounter ActiveSockets;
+ TAtomicCounter CachedSockets;
+
+ NHttp::TLockFreeSequence<TConnList> Lst_;
+
+ TAtomic InPurging_;
+ std::atomic<size_t> MaxConnId_;
+
+ TAutoPtr<IThreadFactory::IThread> T_;
+ TCondVar CondPurge_;
+ TMutex PurgeMutex_;
+ TAtomicBool Shutdown_;
+ };
+
+ class TSslCtx: public TThrRefBase {
+ protected:
+ TSslCtx()
+ : SslCtx_(nullptr)
+ {
+ }
+
+ public:
+ ~TSslCtx() override {
+ SSL_CTX_free(SslCtx_);
+ }
+
+ operator SSL_CTX*() {
+ return SslCtx_;
+ }
+
+ protected:
+ SSL_CTX* SslCtx_;
+ };
+ using TSslCtxPtr = TIntrusivePtr<TSslCtx>;
+
+ class TSslCtxServer: public TSslCtx {
+ struct TPasswordCallbackUserData {
+ TParsedLocation Location;
+ TString CertFileName;
+ TString KeyFileName;
+ };
+ class TUserDataHolder {
+ public:
+ TUserDataHolder(SSL_CTX* ctx, const TParsedLocation& location, const TString& certFileName, const TString& keyFileName)
+ : SslCtx_(ctx)
+ , Data_{location, certFileName, keyFileName}
+ {
+ SSL_CTX_set_default_passwd_cb_userdata(SslCtx_, &Data_);
+ }
+ ~TUserDataHolder() {
+ SSL_CTX_set_default_passwd_cb_userdata(SslCtx_, nullptr);
+ }
+ private:
+ SSL_CTX* SslCtx_;
+ TPasswordCallbackUserData Data_;
+ };
+ public:
+ TSslCtxServer(const TParsedLocation& loc) {
+ const SSL_METHOD* method = SSLv23_server_method();
+ if (Y_UNLIKELY(!method)) {
+ ythrow TSslException(TStringBuf("SSLv23_server_method"));
+ }
+
+ SslCtx_ = SSL_CTX_new(method);
+ if (Y_UNLIKELY(!SslCtx_)) {
+ ythrow TSslException(TStringBuf("SSL_CTX_new(server)"));
+ }
+
+ TString cert, key;
+ ParseUserInfo(loc, cert, key);
+
+ TUserDataHolder holder(SslCtx_, loc, cert, key);
+
+ SSL_CTX_set_default_passwd_cb(SslCtx_, [](char* buf, int size, int rwflag, void* userData) -> int {
+ Y_UNUSED(rwflag);
+ Y_UNUSED(userData);
+
+ if (THttpsOptions::KeyPasswdCallback == nullptr || userData == nullptr) {
+ return 0;
+ }
+
+ auto data = static_cast<TPasswordCallbackUserData*>(userData);
+ const auto& passwd = THttpsOptions::KeyPasswdCallback(data->Location, data->CertFileName, data->KeyFileName);
+
+ if (size < static_cast<int>(passwd.size())) {
+ return -1;
+ }
+
+ return passwd.copy(buf, size, 0);
+ });
+
+ if (!cert || !key) {
+ ythrow TSslException() << TStringBuf("no certificate or private key is specified for server");
+ }
+
+ if (1 != SSL_CTX_use_certificate_chain_file(SslCtx_, cert.data())) {
+ ythrow TSslException(TStringBuf("SSL_CTX_use_certificate_chain_file (server)"));
+ }
+
+ if (1 != SSL_CTX_use_PrivateKey_file(SslCtx_, key.data(), SSL_FILETYPE_PEM)) {
+ ythrow TSslException(TStringBuf("SSL_CTX_use_PrivateKey_file (server)"));
+ }
+
+ if (1 != SSL_CTX_check_private_key(SslCtx_)) {
+ ythrow TSslException(TStringBuf("SSL_CTX_check_private_key (server)"));
+ }
+ }
+ };
+
+ class TSslCtxClient: public TSslCtx {
+ public:
+ TSslCtxClient() {
+ const SSL_METHOD* method = SSLv23_client_method();
+ if (Y_UNLIKELY(!method)) {
+ ythrow TSslException(TStringBuf("SSLv23_client_method"));
+ }
+
+ SslCtx_ = SSL_CTX_new(method);
+ if (Y_UNLIKELY(!SslCtx_)) {
+ ythrow TSslException(TStringBuf("SSL_CTX_new(client)"));
+ }
+
+ const TString& caFile = THttpsOptions::CAFile;
+ const TString& caPath = THttpsOptions::CAPath;
+ if (caFile || caPath) {
+ if (!SSL_CTX_load_verify_locations(SslCtx_, caFile ? caFile.data() : nullptr, caPath ? caPath.data() : nullptr)) {
+ ythrow TSslException(TStringBuf("SSL_CTX_load_verify_locations(client)"));
+ }
+ }
+
+ SSL_CTX_set_options(SslCtx_, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION);
+ if (THttpsOptions::ClientVerifyCallback) {
+ SSL_CTX_set_verify(SslCtx_, SSL_VERIFY_PEER, THttpsOptions::ClientVerifyCallback);
+ } else {
+ SSL_CTX_set_verify(SslCtx_, SSL_VERIFY_NONE, nullptr);
+ }
+
+ const TString& clientCertificate = THttpsOptions::ClientCertificate;
+ const TString& clientPrivateKey = THttpsOptions::ClientPrivateKey;
+ if (clientCertificate && clientPrivateKey) {
+ SSL_CTX_set_default_passwd_cb(SslCtx_, [](char* buf, int size, int rwflag, void* userData) -> int {
+ Y_UNUSED(rwflag);
+ Y_UNUSED(userData);
+
+ const TString& clientPrivateKeyPwd = THttpsOptions::ClientPrivateKeyPassword;
+ if (!clientPrivateKeyPwd) {
+ return 0;
+ }
+ if (size < static_cast<int>(clientPrivateKeyPwd.size())) {
+ return -1;
+ }
+
+ return clientPrivateKeyPwd.copy(buf, size, 0);
+ });
+ if (1 != SSL_CTX_use_certificate_chain_file(SslCtx_, clientCertificate.c_str())) {
+ ythrow TSslException(TStringBuf("SSL_CTX_use_certificate_chain_file (client)"));
+ }
+ if (1 != SSL_CTX_use_PrivateKey_file(SslCtx_, clientPrivateKey.c_str(), SSL_FILETYPE_PEM)) {
+ ythrow TSslException(TStringBuf("SSL_CTX_use_PrivateKey_file (client)"));
+ }
+ if (1 != SSL_CTX_check_private_key(SslCtx_)) {
+ ythrow TSslException(TStringBuf("SSL_CTX_check_private_key (client)"));
+ }
+ } else if (clientCertificate || clientPrivateKey) {
+ ythrow TSslException() << TStringBuf("both certificate and private key must be specified for client");
+ }
+ }
+
+ static TSslCtxClient& Instance() {
+ return *Singleton<TSslCtxClient>();
+ }
+ };
+
+ class TContBIO : public NOpenSSL::TAbstractIO {
+ public:
+ TContBIO(SOCKET s, const TAtomicBool* canceled = nullptr)
+ : Timeout_(TDuration::MicroSeconds(10000))
+ , S_(s)
+ , Canceled_(canceled)
+ , Cont_(nullptr)
+ {
+ }
+
+ SOCKET Socket() {
+ return S_;
+ }
+
+ int PollT(int what, const TDuration& timeout) {
+ return NCoro::PollT(Cont_, Socket(), what, timeout);
+ }
+
+ void WaitUntilWritten() {
+#if defined(FIONWRITE)
+ if (Y_LIKELY(Cont_)) {
+ int err;
+ int nbytes = Max<int>();
+ TDuration tout = TDuration::MilliSeconds(10);
+
+ while (((err = ioctl(S_, FIONWRITE, &nbytes)) == 0) && nbytes) {
+ err = NCoro::PollT(Cont_, S_, CONT_POLL_READ, tout);
+
+ if (!err) {
+                            //wait completed because there is some data available
+ break;
+ }
+
+ if (err != ETIMEDOUT) {
+ ythrow TSystemError(err) << TStringBuf("request failed");
+ }
+
+ tout = tout * 2;
+ }
+
+ if (err) {
+ ythrow TSystemError() << TStringBuf("ioctl() failed");
+ }
+ } else {
+ ythrow TSslException() << TStringBuf("No cont available");
+ }
+#endif
+ }
+
+ void AcquireCont(TCont* c) {
+ Cont_ = c;
+ }
+ void ReleaseCont() {
+ Cont_ = nullptr;
+ }
+
+ int Write(const char* data, size_t dlen, size_t* written) override {
+ if (Y_UNLIKELY(!Cont_)) {
+ return -1;
+ }
+
+ while (true) {
+ auto done = NCoro::WriteI(Cont_, S_, data, dlen);
+ if (done.Status() != EAGAIN) {
+ *written = done.Checked();
+ return 1;
+ }
+ }
+ }
+
+ int Read(char* data, size_t dlen, size_t* readbytes) override {
+ if (Y_UNLIKELY(!Cont_)) {
+ return -1;
+ }
+
+ if (!Canceled_) {
+ while (true) {
+ auto done = NCoro::ReadI(Cont_, S_, data, dlen);
+ if (EAGAIN != done.Status()) {
+ *readbytes = done.Processed();
+ return 1;
+ }
+ }
+ }
+
+ while (true) {
+ if (*Canceled_) {
+ return SSL_RVAL_TIMEOUT;
+ }
+
+ TContIOStatus ioStat(NCoro::ReadT(Cont_, S_, data, dlen, Timeout_));
+ if (ioStat.Status() == ETIMEDOUT) {
+                        //increase by 1.5x every iteration (capped at 1 second)
+ Timeout_ = TDuration::MicroSeconds(Min<ui64>(1000000, Timeout_.MicroSeconds() + (Timeout_.MicroSeconds() >> 1)));
+ continue;
+ }
+
+ *readbytes = ioStat.Processed();
+ return 1;
+ }
+ }
+
+ int Puts(const char* buf) override {
+ Y_UNUSED(buf);
+ return -1;
+ }
+
+ int Gets(char* buf, int size) override {
+ Y_UNUSED(buf);
+ Y_UNUSED(size);
+ return -1;
+ }
+
+ void Flush() override {
+ }
+
+ private:
+ TDuration Timeout_;
+ SOCKET S_;
+ const TAtomicBool* Canceled_;
+ TCont* Cont_;
+ };
+
+ class TSslIOStream: public IInputStream, public IOutputStream {
+ protected:
+ TSslIOStream(TSslCtx& sslCtx, TAutoPtr<TContBIO> connection)
+ : Connection_(connection)
+ , SslCtx_(sslCtx)
+ , Ssl_(nullptr)
+ {
+ }
+
+ virtual void Handshake() = 0;
+
+ public:
+ void WaitUntilWritten() {
+ if (Connection_) {
+ Connection_->WaitUntilWritten();
+ }
+ }
+
+ int PollReadT(const TDuration& timeout) {
+ if (!Connection_) {
+ return -1;
+ }
+
+ while (true) {
+ const int rpoll = Connection_->PollT(CONT_POLL_READ, timeout);
+ if (!Ssl_ || rpoll) {
+ return rpoll;
+ }
+
+ char c = 0;
+ const int rpeek = SSL_peek(Ssl_.Get(), &c, sizeof(c));
+ if (rpeek < 0) {
+ return -1;
+ } else if (rpeek > 0) {
+ return 0;
+ } else {
+ if ((SSL_get_shutdown(Ssl_.Get()) & SSL_RECEIVED_SHUTDOWN) != 0) {
+ Shutdown(); // wait until shutdown is finished
+ return EIO;
+ }
+ }
+ }
+ }
+
+ void Shutdown() {
+ if (Ssl_ && Connection_) {
+ for (size_t i = 0; i < 2; ++i) {
+ bool rval = SSL_shutdown(Ssl_.Get());
+ if (0 == rval) {
+ continue;
+ } else if (1 == rval) {
+ break;
+ }
+ }
+ }
+ }
+
+ inline void AcquireCont(TCont* c) {
+ if (Y_UNLIKELY(!Connection_)) {
+ ythrow TSslException() << TStringBuf("no connection provided");
+ }
+
+ Connection_->AcquireCont(c);
+ }
+
+ inline void ReleaseCont() {
+ if (Connection_) {
+ Connection_->ReleaseCont();
+ }
+ }
+
+ TContIOStatus WriteVectorI(const TList<IOutputStream::TPart>& vec) {
+ for (const auto& p : vec) {
+ Write(p.buf, p.len);
+ }
+ return TContIOStatus::Success(vec.size());
+ }
+
+ SOCKET Socket() {
+ if (Y_UNLIKELY(!Connection_)) {
+ ythrow TSslException() << TStringBuf("no connection provided");
+ }
+
+ return Connection_->Socket();
+ }
+
+ private:
+ void DoWrite(const void* buf, size_t len) override {
+ if (Y_UNLIKELY(!Connection_)) {
+ ythrow TSslException() << TStringBuf("DoWrite() no connection provided");
+ }
+
+ const int rval = SSL_write(Ssl_.Get(), buf, len);
+ if (rval <= 0) {
+ ythrow TSslException(TStringBuf("SSL_write"), Ssl_.Get(), rval);
+ }
+ }
+
+ size_t DoRead(void* buf, size_t len) override {
+ if (Y_UNLIKELY(!Connection_)) {
+ ythrow TSslException() << TStringBuf("DoRead() no connection provided");
+ }
+
+ const int rval = SSL_read(Ssl_.Get(), buf, len);
+ if (rval < 0) {
+ if (SSL_RVAL_TIMEOUT == rval) {
+ ythrow TSystemError(ECANCELED) << TStringBuf(" http request canceled");
+ }
+ ythrow TSslException(TStringBuf("SSL_read"), Ssl_.Get(), rval);
+ } else if (0 == rval) {
+ if ((SSL_get_shutdown(Ssl_.Get()) & SSL_RECEIVED_SHUTDOWN) != 0) {
+ return rval;
+ } else {
+ const int err = SSL_get_error(Ssl_.Get(), rval);
+ if (SSL_ERROR_ZERO_RETURN != err) {
+ ythrow TSslException(TStringBuf("SSL_read"), Ssl_.Get(), rval);
+ }
+ }
+ }
+
+ return static_cast<size_t>(rval);
+ }
+
+ protected:
+ // just for ssl debug
+ static void InfoCB(const SSL* s, int where, int ret) {
+ TStringBuf str;
+ const int w = where & ~SSL_ST_MASK;
+ if (w & SSL_ST_CONNECT) {
+ str = TStringBuf("SSL_connect");
+ } else if (w & SSL_ST_ACCEPT) {
+ str = TStringBuf("SSL_accept");
+ } else {
+ str = TStringBuf("undefined");
+ }
+
+ if (where & SSL_CB_LOOP) {
+ Cerr << str << ':' << SSL_state_string_long(s) << Endl;
+ } else if (where & SSL_CB_ALERT) {
+ Cerr << TStringBuf("SSL3 alert ") << ((where & SSL_CB_READ) ? TStringBuf("read") : TStringBuf("write")) << ' ' << SSL_alert_type_string_long(ret) << ':' << SSL_alert_desc_string_long(ret) << Endl;
+ } else if (where & SSL_CB_EXIT) {
+ if (ret == 0) {
+ Cerr << str << TStringBuf(":failed in ") << SSL_state_string_long(s) << Endl;
+ } else if (ret < 0) {
+ Cerr << str << TStringBuf(":error in ") << SSL_state_string_long(s) << Endl;
+ }
+ }
+ }
+
+ protected:
+ THolder<TContBIO> Connection_;
+ TSslCtx& SslCtx_;
+ TSslHolder Ssl_;
+ };
+
+ class TContBIOWatcher {
+ public:
+ TContBIOWatcher(TSslIOStream& io, TCont* c) noexcept
+ : IO_(io)
+ {
+ IO_.AcquireCont(c);
+ }
+
+ ~TContBIOWatcher() noexcept {
+ IO_.ReleaseCont();
+ }
+
+ private:
+ TSslIOStream& IO_;
+ };
+
+ class TSslClientIOStream: public TSslIOStream {
+ public:
+ TSslClientIOStream(TSslCtxClient& sslCtx, const TParsedLocation& loc, SOCKET s, const TAtomicBool* canceled)
+ : TSslIOStream(sslCtx, new TContBIO(s, canceled))
+ , Location_(loc)
+ {
+ }
+
+ void Handshake() override {
+ Ssl_.Reset(SSL_new(SslCtx_));
+ if (THttpsOptions::EnableSslClientDebug) {
+ SSL_set_info_callback(Ssl_.Get(), InfoCB);
+ }
+
+ BIO_up_ref(*Connection_); // SSL_set_bio consumes only one reference if rbio and wbio are the same
+ SSL_set_bio(Ssl_.Get(), *Connection_, *Connection_);
+
+ const TString hostname(Location_.Host);
+ const int rev = SSL_set_tlsext_host_name(Ssl_.Get(), hostname.data());
+ if (Y_UNLIKELY(1 != rev)) {
+ ythrow TSslException(TStringBuf("SSL_set_tlsext_host_name(client)"), Ssl_.Get(), rev);
+ }
+
+ TString cert, pvtKey;
+ ParseUserInfo(Location_, cert, pvtKey);
+
+ if (cert && (1 != SSL_use_certificate_file(Ssl_.Get(), cert.data(), SSL_FILETYPE_PEM))) {
+ ythrow TSslException(TStringBuf("SSL_use_certificate_file(client)"));
+ }
+
+ if (pvtKey) {
+ if (1 != SSL_use_PrivateKey_file(Ssl_.Get(), pvtKey.data(), SSL_FILETYPE_PEM)) {
+ ythrow TSslException(TStringBuf("SSL_use_PrivateKey_file(client)"));
+ }
+
+ if (1 != SSL_check_private_key(Ssl_.Get())) {
+ ythrow TSslException(TStringBuf("SSL_check_private_key(client)"));
+ }
+ }
+
+ SSL_set_connect_state(Ssl_.Get());
+
+ // TODO restore session if reconnect
+ const int rval = SSL_do_handshake(Ssl_.Get());
+ if (1 != rval) {
+ if (rval == SSL_RVAL_TIMEOUT) {
+ ythrow TSystemError(ECANCELED) << TStringBuf("canceled");
+ } else {
+ ythrow TSslException(TStringBuf("BIO_do_handshake(client)"), Ssl_.Get(), rval);
+ }
+ }
+
+ if (THttpsOptions::CheckCertificateHostname) {
+ TX509Holder peerCert(SSL_get_peer_certificate(Ssl_.Get()));
+ if (!peerCert) {
+ ythrow TSslException(TStringBuf("SSL_get_peer_certificate(client)"));
+ }
+
+ if (!CheckCertHostname(peerCert.Get(), Location_.Host)) {
+ ythrow TSslException(TStringBuf("CheckCertHostname(client)"));
+ }
+ }
+ }
+
+ private:
+ const TParsedLocation Location_;
+ //TSslSessionHolder Session_;
+ };
+
+ static TConnCache* SocketCache() {
+ return Singleton<TConnCache>();
+ }
+
+ //some template magic
+ template <class T>
+ static inline TAutoPtr<T> AutoPtr(T* t) noexcept {
+ return t;
+ }
+
+ static inline TString ReadAll(THttpInput& in) {
+ TString ret;
+ ui64 clin;
+
+ if (in.GetContentLength(clin)) {
+ const size_t cl = SafeIntegerCast<size_t>(clin);
+
+ ret.ReserveAndResize(cl);
+ size_t sz = in.Load(ret.begin(), cl);
+ if (sz != cl) {
+ throw yexception() << TStringBuf("not full content: ") << sz << TStringBuf(" bytes from ") << cl;
+ }
+ } else if (in.HasContent()) {
+ TVector<char> buff(9500); //common jumbo frame size
+
+ while (size_t len = in.Read(buff.data(), buff.size())) {
+ ret.AppendNoAlias(buff.data(), len);
+ }
+ }
+
+ return ret;
+ }
+
+ template <class TRequestType>
+ class THttpsRequest: public IJob {
+ public:
+ inline THttpsRequest(TSimpleHandleRef hndl, TMessage msg)
+ : Hndl_(hndl)
+ , Msg_(std::move(msg))
+ , Loc_(Msg_.Addr)
+ , Addr_(CachedThrResolve(TResolveInfo(Loc_.Host, Loc_.GetPort())))
+ {
+ }
+
+ void DoRun(TCont* c) override {
+ THolder<THttpsRequest> This(this);
+
+ if (c->Cancelled()) {
+ Hndl_->NotifyError(new TError("canceled", TError::TType::Cancelled));
+ return;
+ }
+
+ TErrorRef error;
+ THolder<TConnCache::TConnection> s(SocketCache()->Connect(c, Msg_.Addr, Addr_, &error));
+ if (!s) {
+ Hndl_->NotifyError(error);
+ return;
+ }
+
+ TSslClientIOStream io(TSslCtxClient::Instance(), Loc_, s->Fd(), Hndl_->CanceledPtr());
+ TContBIOWatcher w(io, c);
+ TString received;
+ THttpHeaders headers;
+ TString firstLine;
+
+ try {
+ io.Handshake();
+ RequestData().SendTo(io);
+ Req_.Destroy();
+ error = ProcessRecv(io, &received, &headers, &firstLine);
+ } catch (const TSystemError& e) {
+ if (c->Cancelled() || e.Status() == ECANCELED) {
+ error = new TError("canceled", TError::TType::Cancelled);
+ } else {
+ error = new TError(CurrentExceptionMessage());
+ }
+ } catch (...) {
+ if (c->Cancelled()) {
+ error = new TError("canceled", TError::TType::Cancelled);
+ } else {
+ error = new TError(CurrentExceptionMessage());
+ }
+ }
+
+ if (error) {
+ Hndl_->NotifyError(error, received, firstLine, headers);
+ } else {
+ io.Shutdown();
+ SocketCache()->Release(*s);
+ Hndl_->NotifyResponse(received, firstLine, headers);
+ }
+ }
+
+ TErrorRef ProcessRecv(TSslClientIOStream& io, TString* data, THttpHeaders* headers, TString* firstLine) {
+ io.WaitUntilWritten();
+
+ Hndl_->SetSendComplete();
+
+ THttpInput in(&io);
+ *data = ReadAll(in);
+ *firstLine = in.FirstLine();
+ *headers = in.Headers();
+
+ i32 code = ParseHttpRetCode(in.FirstLine());
+ if (code < 200 || code > (!THttpsOptions::RedirectionNotError ? 299 : 399)) {
+ return new TError(TStringBuilder() << TStringBuf("request failed(") << in.FirstLine() << ')', TError::TType::ProtocolSpecific, code);
+ }
+
+ return nullptr;
+ }
+
+ const NHttp::TRequestData& RequestData() {
+ if (!Req_) {
+ Req_ = TRequestType::Build(Msg_, Loc_);
+ }
+ return *Req_;
+ }
+
+ private:
+ TSimpleHandleRef Hndl_;
+ const TMessage Msg_;
+ const TParsedLocation Loc_;
+ const TResolvedHost* Addr_;
+ NHttp::TRequestData::TPtr Req_;
+ };
+
+ class TServer: public IRequester, public TContListener::ICallBack {
+ class TSslServerIOStream: public TSslIOStream, public TThrRefBase {
+ public:
+ TSslServerIOStream(TSslCtxServer& sslCtx, TSocketRef s)
+ : TSslIOStream(sslCtx, new TContBIO(*s))
+ , S_(s)
+ {
+ }
+
+ void Close(bool shutdown) {
+ if (shutdown) {
+ Shutdown();
+ }
+ S_->Close();
+ }
+
+ void Handshake() override {
+ if (!Ssl_) {
+ Ssl_.Reset(SSL_new(SslCtx_));
+ if (THttpsOptions::EnableSslServerDebug) {
+ SSL_set_info_callback(Ssl_.Get(), InfoCB);
+ }
+
+ BIO_up_ref(*Connection_); // SSL_set_bio consumes only one reference if rbio and wbio are the same
+ SSL_set_bio(Ssl_.Get(), *Connection_, *Connection_);
+
+ const int rc = SSL_accept(Ssl_.Get());
+ if (1 != rc) {
+ ythrow TSslException(TStringBuf("SSL_accept"), Ssl_.Get(), rc);
+ }
+ }
+
+ if (!SSL_is_init_finished(Ssl_.Get())) {
+ const int rc = SSL_do_handshake(Ssl_.Get());
+ if (rc != 1) {
+ ythrow TSslException(TStringBuf("SSL_do_handshake"), Ssl_.Get(), rc);
+ }
+ }
+ }
+
+ private:
+ TSocketRef S_;
+ };
+
+ class TJobsQueue: public TAutoOneConsumerPipeQueue<IJob>, public TThrRefBase {
+ };
+
+ typedef TIntrusivePtr<TJobsQueue> TJobsQueueRef;
+
+ class TWrite: public IJob, public TData {
+ private:
+ template <class T>
+ static void WriteHeader(IOutputStream& os, TStringBuf name, T value) {
+ os << name << TStringBuf(": ") << value << TStringBuf("\r\n");
+ }
+
+ static void WriteHttpCode(IOutputStream& os, TMaybe<IRequest::TResponseError> error) {
+ if (!error.Defined()) {
+ os << HttpCodeStrEx(HttpCodes::HTTP_OK);
+ return;
+ }
+
+ switch (*error) {
+ case IRequest::TResponseError::BadRequest:
+ os << HttpCodeStrEx(HttpCodes::HTTP_BAD_REQUEST);
+ break;
+ case IRequest::TResponseError::Forbidden:
+ os << HttpCodeStrEx(HttpCodes::HTTP_FORBIDDEN);
+ break;
+ case IRequest::TResponseError::NotExistService:
+ os << HttpCodeStrEx(HttpCodes::HTTP_NOT_FOUND);
+ break;
+ case IRequest::TResponseError::TooManyRequests:
+ os << HttpCodeStrEx(HttpCodes::HTTP_TOO_MANY_REQUESTS);
+ break;
+ case IRequest::TResponseError::InternalError:
+ os << HttpCodeStrEx(HttpCodes::HTTP_INTERNAL_SERVER_ERROR);
+ break;
+ case IRequest::TResponseError::NotImplemented:
+ os << HttpCodeStrEx(HttpCodes::HTTP_NOT_IMPLEMENTED);
+ break;
+ case IRequest::TResponseError::BadGateway:
+ os << HttpCodeStrEx(HttpCodes::HTTP_BAD_GATEWAY);
+ break;
+ case IRequest::TResponseError::ServiceUnavailable:
+ os << HttpCodeStrEx(HttpCodes::HTTP_SERVICE_UNAVAILABLE);
+ break;
+ case IRequest::TResponseError::BandwidthLimitExceeded:
+ os << HttpCodeStrEx(HttpCodes::HTTP_BANDWIDTH_LIMIT_EXCEEDED);
+ break;
+ case IRequest::TResponseError::MaxResponseError:
+ ythrow yexception() << TStringBuf("unknown type of error");
+ }
+ }
+
+ public:
+ inline TWrite(TData& data, const TString& compressionScheme, TIntrusivePtr<TSslServerIOStream> io, TServer* server, const TString& headers, int httpCode)
+ : CompressionScheme_(compressionScheme)
+ , IO_(io)
+ , Server_(server)
+ , Error_(TMaybe<IRequest::TResponseError>())
+ , Headers_(headers)
+ , HttpCode_(httpCode)
+ {
+ swap(data);
+ }
+
+ inline TWrite(TData& data, const TString& compressionScheme, TIntrusivePtr<TSslServerIOStream> io, TServer* server, IRequest::TResponseError error, const TString& headers)
+ : CompressionScheme_(compressionScheme)
+ , IO_(io)
+ , Server_(server)
+ , Error_(error)
+ , Headers_(headers)
+ , HttpCode_(0)
+ {
+ swap(data);
+ }
+
+ void DoRun(TCont* c) override {
+ THolder<TWrite> This(this);
+
+ try {
+ TContBIOWatcher w(*IO_, c);
+
+ PrepareSocket(IO_->Socket());
+
+ char buf[128];
+ TMemoryOutput mo(buf, sizeof(buf));
+
+ mo << TStringBuf("HTTP/1.1 ");
+ if (HttpCode_) {
+ mo << HttpCodeStrEx(HttpCode_);
+ } else {
+ WriteHttpCode(mo, Error_);
+ }
+ mo << TStringBuf("\r\n");
+
+ if (!CompressionScheme_.empty()) {
+ WriteHeader(mo, TStringBuf("Content-Encoding"), TStringBuf(CompressionScheme_));
+ }
+ WriteHeader(mo, TStringBuf("Connection"), TStringBuf("Keep-Alive"));
+ WriteHeader(mo, TStringBuf("Content-Length"), size());
+
+ mo << Headers_;
+
+ mo << TStringBuf("\r\n");
+
+ IO_->Write(buf, mo.Buf() - buf);
+ if (size()) {
+ IO_->Write(data(), size());
+ }
+
+ Server_->Enqueue(new TRead(IO_, Server_));
+ } catch (...) {
+ }
+ }
+
+ private:
+ const TString CompressionScheme_;
+ TIntrusivePtr<TSslServerIOStream> IO_;
+ TServer* Server_;
+ TMaybe<IRequest::TResponseError> Error_;
+ TString Headers_;
+ int HttpCode_;
+ };
+
+ class TRequest: public IHttpRequest {
+ public:
+ inline TRequest(THttpInput& in, TIntrusivePtr<TSslServerIOStream> io, TServer* server)
+ : IO_(io)
+ , Tmp_(in.FirstLine())
+ , CompressionScheme_(in.BestCompressionScheme())
+ , RemoteHost_(PrintHostByRfc(*GetPeerAddr(IO_->Socket())))
+ , Headers_(in.Headers())
+ , H_(Tmp_)
+ , Server_(server)
+ {
+ }
+
+ ~TRequest() override {
+ if (!!IO_) {
+ try {
+ Server_->Enqueue(new TFail(IO_, Server_));
+ } catch (...) {
+ }
+ }
+ }
+
+ TStringBuf Scheme() const override {
+ return TStringBuf("https");
+ }
+
+ TString RemoteHost() const override {
+ return RemoteHost_;
+ }
+
+ const THttpHeaders& Headers() const override {
+ return Headers_;
+ }
+
+ TStringBuf Method() const override {
+ return H_.Method;
+ }
+
+ TStringBuf Cgi() const override {
+ return H_.Cgi;
+ }
+
+ TStringBuf Service() const override {
+ return TStringBuf(H_.Path).Skip(1);
+ }
+
+ TStringBuf RequestId() const override {
+ return TStringBuf();
+ }
+
+ bool Canceled() const override {
+ if (!IO_) {
+ return false;
+ }
+ return !IsNotSocketClosedByOtherSide(IO_->Socket());
+ }
+
+ void SendReply(TData& data) override {
+ SendReply(data, TString(), HttpCodes::HTTP_OK);
+ }
+
+ void SendReply(TData& data, const TString& headers, int httpCode) override {
+ const bool compressed = Compress(data);
+ Server_->Enqueue(new TWrite(data, compressed ? CompressionScheme_ : TString(), IO_, Server_, headers, httpCode));
+ Y_UNUSED(IO_.Release());
+ }
+
+ void SendError(TResponseError error, const THttpErrorDetails& details) override {
+ TData data;
+ Server_->Enqueue(new TWrite(data, TString(), IO_, Server_, error, details.Headers));
+ Y_UNUSED(IO_.Release());
+ }
+
+ private:
+ bool Compress(TData& data) const {
+ if (CompressionScheme_ == TStringBuf("gzip")) {
+ try {
+ TData gzipped(data.size());
+ TMemoryOutput out(gzipped.data(), gzipped.size());
+ TZLibCompress c(&out, ZLib::GZip);
+ c.Write(data.data(), data.size());
+ c.Finish();
+ gzipped.resize(out.Buf() - gzipped.data());
+ data.swap(gzipped);
+ return true;
+ } catch (yexception&) {
+ // gzipped data would occupy more space than the original data
+ }
+ }
+ return false;
+ }
+
+ private:
+ TIntrusivePtr<TSslServerIOStream> IO_;
+ const TString Tmp_;
+ const TString CompressionScheme_;
+ const TString RemoteHost_;
+ const THttpHeaders Headers_;
+
+ protected:
+ TParsedHttpFull H_;
+ TServer* Server_;
+ };
+
+ class TGetRequest: public TRequest {
+ public:
+ inline TGetRequest(THttpInput& in, TIntrusivePtr<TSslServerIOStream> io, TServer* server)
+ : TRequest(in, io, server)
+ {
+ }
+
+ TStringBuf Data() const override {
+ return H_.Cgi;
+ }
+
+ TStringBuf Body() const override {
+ return TStringBuf();
+ }
+ };
+
+ class TPostRequest: public TRequest {
+ public:
+ inline TPostRequest(THttpInput& in, TIntrusivePtr<TSslServerIOStream> io, TServer* server)
+ : TRequest(in, io, server)
+ , Data_(ReadAll(in))
+ {
+ }
+
+ TStringBuf Data() const override {
+ return Data_;
+ }
+
+ TStringBuf Body() const override {
+ return Data_;
+ }
+
+ private:
+ TString Data_;
+ };
+
+ class TFail: public IJob {
+ public:
+ inline TFail(TIntrusivePtr<TSslServerIOStream> io, TServer* server)
+ : IO_(io)
+ , Server_(server)
+ {
+ }
+
+ void DoRun(TCont* c) override {
+ THolder<TFail> This(this);
+ constexpr TStringBuf answer = "HTTP/1.1 503 Service unavailable\r\n"
+ "Content-Length: 0\r\n\r\n"sv;
+
+ try {
+ TContBIOWatcher w(*IO_, c);
+ IO_->Write(answer);
+ Server_->Enqueue(new TRead(IO_, Server_));
+ } catch (...) {
+ }
+ }
+
+ private:
+ TIntrusivePtr<TSslServerIOStream> IO_;
+ TServer* Server_;
+ };
+
+ class TRead: public IJob {
+ public:
+ TRead(TIntrusivePtr<TSslServerIOStream> io, TServer* server, bool selfRemove = false)
+ : IO_(io)
+ , Server_(server)
+ , SelfRemove(selfRemove)
+ {
+ }
+
+ inline void operator()(TCont* c) {
+ try {
+ TContBIOWatcher w(*IO_, c);
+
+ if (IO_->PollReadT(TDuration::Seconds(InputConnections()->UnusedConnKeepaliveTimeout()))) {
+ IO_->Close(true);
+ return;
+ }
+
+ IO_->Handshake();
+ THttpInput in(IO_.Get());
+
+ const char sym = *in.FirstLine().data();
+
+ if (sym == 'p' || sym == 'P') {
+ Server_->OnRequest(new TPostRequest(in, IO_, Server_));
+ } else {
+ Server_->OnRequest(new TGetRequest(in, IO_, Server_));
+ }
+ } catch (...) {
+ IO_->Close(false);
+ }
+
+ if (SelfRemove) {
+ delete this;
+ }
+ }
+
+ private:
+ void DoRun(TCont* c) override {
+ THolder<TRead> This(this);
+ (*this)(c);
+ }
+
+ private:
+ TIntrusivePtr<TSslServerIOStream> IO_;
+ TServer* Server_;
+ bool SelfRemove = false;
+ };
+
+ public:
+ inline TServer(IOnRequest* cb, const TParsedLocation& loc)
+ : CB_(cb)
+ , E_(RealStackSize(16000))
+ , L_(new TContListener(this, &E_, TContListener::TOptions().SetDeferAccept(true)))
+ , JQ_(new TJobsQueue())
+ , SslCtx_(loc)
+ {
+ L_->Bind(TNetworkAddress(loc.GetPort()));
+ E_.Create<TServer, &TServer::RunDispatcher>(this, "dispatcher");
+ Thrs_.push_back(Spawn<TServer, &TServer::Run>(this));
+ }
+
+ ~TServer() override {
+ JQ_->Enqueue(nullptr);
+
+ for (size_t i = 0; i < Thrs_.size(); ++i) {
+ Thrs_[i]->Join();
+ }
+ }
+
+ void Run() {
+ //SetHighestThreadPriority();
+ L_->Listen();
+ E_.Execute();
+ }
+
+ inline void OnRequest(const IRequestRef& req) {
+ CB_->OnRequest(req);
+ }
+
+ TJobsQueueRef& JobQueue() noexcept {
+ return JQ_;
+ }
+
+ void Enqueue(IJob* j) {
+ JQ_->EnqueueSafe(TAutoPtr<IJob>(j));
+ }
+
+ void RunDispatcher(TCont* c) {
+ for (;;) {
+ TAutoPtr<IJob> job(JQ_->Dequeue(c));
+
+ if (!job) {
+ break;
+ }
+
+ try {
+ c->Executor()->Create(*job, "https-job");
+ Y_UNUSED(job.Release());
+ } catch (...) {
+ }
+ }
+
+ JQ_->Enqueue(nullptr);
+ c->Executor()->Abort();
+ }
+
+ void OnAcceptFull(const TAcceptFull& a) override {
+ try {
+ TSocketRef s(new TSharedSocket(*a.S));
+
+ if (InputConnections()->ExceedHardLimit()) {
+ s->Close();
+ return;
+ }
+
+ THolder<TRead> read(new TRead(new TSslServerIOStream(SslCtx_, s), this, /* selfRemove */ true));
+ E_.Create(*read, "https-response");
+ Y_UNUSED(read.Release());
+ E_.Running()->Yield();
+ } catch (...) {
+ }
+ }
+
+ void OnError() override {
+ try {
+ throw;
+ } catch (const TSystemError& e) {
+ //crutch to prevent a 100% busy loop (simply suspend the listener/acceptor)
+ if (e.Status() == EMFILE) {
+ E_.Running()->SleepT(TDuration::MilliSeconds(500));
+ }
+ }
+ }
+
+ private:
+ IOnRequest* CB_;
+ TContExecutor E_;
+ THolder<TContListener> L_;
+ TVector<TThreadRef> Thrs_;
+ TJobsQueueRef JQ_;
+ TSslCtxServer SslCtx_;
+ };
+
+ template <class T>
+ class THttpsProtocol: public IProtocol {
+ public:
+ IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
+ return new TServer(cb, loc);
+ }
+
+ THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
+ TSimpleHandleRef ret(new TSimpleHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));
+ try {
+ TAutoPtr<THttpsRequest<T>> req(new THttpsRequest<T>(ret, msg));
+ JobQueue()->Schedule(req);
+ return ret.Get();
+ } catch (...) {
+ ret->ResetOnRecv();
+ throw;
+ }
+ }
+
+ TStringBuf Scheme() const noexcept override {
+ return T::Name();
+ }
+
+ bool SetOption(TStringBuf name, TStringBuf value) override {
+ return THttpsOptions::Set(name, value);
+ }
+ };
+
+ struct TRequestGet: public NHttp::TRequestGet {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("https");
+ }
+ };
+
+ struct TRequestFull: public NHttp::TRequestFull {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("fulls");
+ }
+ };
+
+ struct TRequestPost: public NHttp::TRequestPost {
+ static inline TStringBuf Name() noexcept {
+ return TStringBuf("posts");
+ }
+ };
+
+ }
+}
+
+namespace NNeh {
+ IProtocol* SSLGetProtocol() {
+ return Singleton<NHttps::THttpsProtocol<NNeh::NHttps::TRequestGet>>();
+ }
+
+ IProtocol* SSLPostProtocol() {
+ return Singleton<NHttps::THttpsProtocol<NNeh::NHttps::TRequestPost>>();
+ }
+
+ IProtocol* SSLFullProtocol() {
+ return Singleton<NHttps::THttpsProtocol<NNeh::NHttps::TRequestFull>>();
+ }
+
+ void SetHttpOutputConnectionsLimits(size_t softLimit, size_t hardLimit) {
+ Y_VERIFY(
+ hardLimit > softLimit,
+ "invalid output fd limits; hardLimit=%" PRISZT ", softLimit=%" PRISZT,
+ hardLimit, softLimit);
+
+ NHttps::SocketCache()->SetFdLimits(softLimit, hardLimit);
+ }
+
+ void SetHttpInputConnectionsLimits(size_t softLimit, size_t hardLimit) {
+ Y_VERIFY(
+ hardLimit > softLimit,
+ "invalid input fd limits; hardLimit=%" PRISZT ", softLimit=%" PRISZT,
+ hardLimit, softLimit);
+
+ NHttps::InputConnections()->SetFdLimits(softLimit, hardLimit);
+ }
+
+ void SetHttpInputConnectionsTimeouts(unsigned minSec, unsigned maxSec) {
+ Y_VERIFY(
+ maxSec > minSec,
+ "invalid input connection keepalive timeouts; maxSec=%u, minSec=%u",
+ maxSec, minSec);
+
+ NHttps::InputConnections()->MinUnusedConnKeepaliveTimeout.store(minSec, std::memory_order_release);
+ NHttps::InputConnections()->MaxUnusedConnKeepaliveTimeout.store(maxSec, std::memory_order_release);
+ }
+}
diff --git a/library/cpp/neh/https.h b/library/cpp/neh/https.h
new file mode 100644
index 0000000000..6dbb5370d7
--- /dev/null
+++ b/library/cpp/neh/https.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include <contrib/libs/openssl/include/openssl/ossl_typ.h>
+
+#include <util/generic/string.h>
+#include <util/generic/strbuf.h>
+
+#include <functional>
+
+namespace NNeh {
+ class IProtocol;
+ struct TParsedLocation;
+
+ IProtocol* SSLGetProtocol();
+ IProtocol* SSLPostProtocol();
+ IProtocol* SSLFullProtocol();
+
+ /// if the soft limit is exceeded, reduce the number of unused connections in the cache
+ void SetHttpOutputConnectionsLimits(size_t softLimit, size_t hardLimit);
+
+ /// if the soft limit is exceeded, reduce the keepalive time for unused connections
+ void SetHttpInputConnectionsLimits(size_t softLimit, size_t hardLimit);
+
+ /// keepalive timeouts for unused input sockets
+ /// the effective timeout is:
+ /// - max, while the soft limit has not been reached
+ /// - min, once the hard limit has been reached
+ /// - changes approximately linearly from max to min while the connection count is in the range [soft..hard]
+ void SetHttpInputConnectionsTimeouts(unsigned minSeconds, unsigned maxSeconds);
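+
+ /// illustrative example (the numeric values below are hypothetical, not library defaults):
+ /// SetHttpOutputConnectionsLimits(100, 200); // start dropping idle cached client connections above ~100, never exceed 200
+ /// SetHttpInputConnectionsLimits(1000, 2000); // shrink keepalive for idle server connections above 1000, hard cap at 2000
+ /// SetHttpInputConnectionsTimeouts(5, 120); // idle server connections live 120s below the soft limit, down to 5s near the hard limit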
+
+ struct THttpsOptions {
+ using TVerifyCallback = int (*)(int, X509_STORE_CTX*);
+ using TPasswordCallback = std::function<TString (const TParsedLocation&, const TString&, const TString&)>;
+ static TString CAFile;
+ static TString CAPath;
+ static TString ClientCertificate;
+ static TString ClientPrivateKey;
+ static TString ClientPrivateKeyPassword;
+ static bool CheckCertificateHostname;
+ static bool EnableSslServerDebug;
+ static bool EnableSslClientDebug;
+ static TVerifyCallback ClientVerifyCallback;
+ static TPasswordCallback KeyPasswdCallback;
+ static bool RedirectionNotError;
+ static bool Set(TStringBuf name, TStringBuf value);
+ };
+}
diff --git a/library/cpp/neh/inproc.cpp b/library/cpp/neh/inproc.cpp
new file mode 100644
index 0000000000..b124f38d17
--- /dev/null
+++ b/library/cpp/neh/inproc.cpp
@@ -0,0 +1,212 @@
+#include "inproc.h"
+
+#include "details.h"
+#include "neh.h"
+#include "location.h"
+#include "utils.h"
+#include "factory.h"
+
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <util/generic/singleton.h>
+#include <util/stream/output.h>
+#include <util/string/cast.h>
+
+using namespace NNeh;
+
+namespace {
+ const TString canceled = "canceled";
+
+ struct TInprocHandle: public TNotifyHandle {
+ inline TInprocHandle(const TMessage& msg, IOnRecv* r, TStatCollector* sc) noexcept
+ : TNotifyHandle(r, msg, sc)
+ , Canceled_(false)
+ , NotifyCnt_(0)
+ {
+ }
+
+ bool MessageSendedCompletely() const noexcept override {
+ return true;
+ }
+
+ void Cancel() noexcept override {
+ THandle::Cancel(); //inform stat collector
+ Canceled_ = true;
+ try {
+ if (MarkReplied()) {
+ NotifyError(new TError(canceled, TError::Cancelled));
+ }
+ } catch (...) {
+ Cdbg << "inproc canc. " << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ inline void SendReply(const TString& resp) {
+ if (MarkReplied()) {
+ NotifyResponse(resp);
+ }
+ }
+
+ inline void SendError(const TString& details) {
+ if (MarkReplied()) {
+ NotifyError(new TError{details, TError::ProtocolSpecific, 1});
+ }
+ }
+
+ void Disable() {
+ F_ = nullptr;
+ MarkReplied();
+ }
+
+ inline bool Canceled() const noexcept {
+ return Canceled_;
+ }
+
+ private:
+ //returns true only when marking the first reply
+ inline bool MarkReplied() {
+ return AtomicAdd(NotifyCnt_, 1) == 1;
+ }
+
+ private:
+ TAtomicBool Canceled_;
+ TAtomic NotifyCnt_;
+ };
+
+ typedef TIntrusivePtr<TInprocHandle> TInprocHandleRef;
+
+ class TInprocLocation: public TParsedLocation {
+ public:
+ TInprocLocation(const TStringBuf& addr)
+ : TParsedLocation(addr)
+ {
+ Service.Split('?', InprocService, InprocId);
+ }
+
+ TStringBuf InprocService;
+ TStringBuf InprocId;
+ };
+
+ class TRequest: public IRequest {
+ public:
+ TRequest(const TInprocHandleRef& hndl)
+ : Location(hndl->Message().Addr)
+ , Handle_(hndl)
+ {
+ }
+
+ TStringBuf Scheme() const override {
+ return TStringBuf("inproc");
+ }
+
+ TString RemoteHost() const override {
+ return TString();
+ }
+
+ TStringBuf Service() const override {
+ return Location.InprocService;
+ }
+
+ TStringBuf Data() const override {
+ return Handle_->Message().Data;
+ }
+
+ TStringBuf RequestId() const override {
+ return Location.InprocId;
+ }
+
+ bool Canceled() const override {
+ return Handle_->Canceled();
+ }
+
+ void SendReply(TData& data) override {
+ Handle_->SendReply(TString(data.data(), data.size()));
+ }
+
+ void SendError(TResponseError, const TString& details) override {
+ Handle_->SendError(details);
+ }
+
+ const TMessage Request;
+ const TInprocLocation Location;
+
+ private:
+ TInprocHandleRef Handle_;
+ };
+
+ class TInprocRequester: public IRequester {
+ public:
+ TInprocRequester(IOnRequest*& rqcb)
+ : RegisteredCallback_(rqcb)
+ {
+ }
+
+ ~TInprocRequester() override {
+ RegisteredCallback_ = nullptr;
+ }
+
+ private:
+ IOnRequest*& RegisteredCallback_;
+ };
+
+ class TInprocRequesterStg: public IProtocol {
+ public:
+ inline TInprocRequesterStg() {
+ V_.resize(1 + (size_t)Max<ui16>());
+ }
+
+ IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
+ IOnRequest*& rqcb = Find(loc);
+
+ if (!rqcb) {
+ rqcb = cb;
+ } else if (rqcb != cb) {
+ ythrow yexception() << "callback already registered for this location";
+ }
+
+ return new TInprocRequester(rqcb);
+ }
+
+ THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
+ TInprocHandleRef hndl(new TInprocHandle(msg, fallback, !ss ? nullptr : new TStatCollector(ss)));
+ try {
+ TAutoPtr<TRequest> req(new TRequest(hndl));
+
+ if (IOnRequest* cb = Find(req->Location)) {
+ cb->OnRequest(req.Release());
+ } else {
+ throw yexception() << TStringBuf("not found inproc location");
+ }
+ } catch (...) {
+ hndl->Disable();
+ throw;
+ }
+
+ return THandleRef(hndl.Get());
+ }
+
+ TStringBuf Scheme() const noexcept override {
+ return TStringBuf("inproc");
+ }
+
+ private:
+ static inline ui16 Id(const TParsedLocation& loc) {
+ return loc.GetPort();
+ }
+
+ inline IOnRequest*& Find(const TParsedLocation& loc) {
+ return Find(Id(loc));
+ }
+
+ inline IOnRequest*& Find(ui16 id) {
+ return V_[id];
+ }
+
+ private:
+ TVector<IOnRequest*> V_;
+ };
+}
+
+IProtocol* NNeh::InProcProtocol() {
+ return Singleton<TInprocRequesterStg>();
+}
diff --git a/library/cpp/neh/inproc.h b/library/cpp/neh/inproc.h
new file mode 100644
index 0000000000..546ef915ab
--- /dev/null
+++ b/library/cpp/neh/inproc.h
@@ -0,0 +1,7 @@
+#pragma once
+
+namespace NNeh {
+ class IProtocol;
+
+ IProtocol* InProcProtocol();
+}
diff --git a/library/cpp/neh/jobqueue.cpp b/library/cpp/neh/jobqueue.cpp
new file mode 100644
index 0000000000..8d02fd9bbf
--- /dev/null
+++ b/library/cpp/neh/jobqueue.cpp
@@ -0,0 +1,79 @@
+#include "utils.h"
+#include "lfqueue.h"
+#include "jobqueue.h"
+#include "pipequeue.h"
+
+#include <util/thread/factory.h>
+#include <util/generic/singleton.h>
+#include <util/system/thread.h>
+
+using namespace NNeh;
+
+namespace {
+ class TExecThread: public IThreadFactory::IThreadAble, public IJob {
+ public:
+ TExecThread()
+ : T_(SystemThreadFactory()->Run(this))
+ {
+ }
+
+ ~TExecThread() override {
+ Enqueue(this);
+ T_->Join();
+ }
+
+ inline void Enqueue(IJob* job) {
+ Q_.Enqueue(job);
+ }
+
+ private:
+ void DoRun(TCont* c) override {
+ c->Executor()->Abort();
+ }
+
+ void DoExecute() override {
+ SetHighestThreadPriority();
+
+ TContExecutor e(RealStackSize(20000));
+
+ e.Execute<TExecThread, &TExecThread::Dispatcher>(this);
+ }
+
+ inline void Dispatcher(TCont* c) {
+ IJob* job;
+
+ while ((job = Q_.Dequeue(c))) {
+ try {
+ c->Executor()->Create(*job, "job");
+ } catch (...) {
+ (*job)(c);
+ }
+ }
+ }
+
+ typedef TAutoPtr<IThreadFactory::IThread> IThreadRef;
+ TOneConsumerPipeQueue<IJob> Q_;
+ IThreadRef T_;
+ };
+
+ class TJobScatter: public IJobQueue {
+ public:
+ inline TJobScatter() {
+ for (size_t i = 0; i < 2; ++i) {
+ E_.push_back(new TExecThread());
+ }
+ }
+
+ void ScheduleImpl(IJob* job) override {
+ E_[TThread::CurrentThreadId() % E_.size()]->Enqueue(job);
+ }
+
+ private:
+ typedef TAutoPtr<TExecThread> TExecThreadRef;
+ TVector<TExecThreadRef> E_;
+ };
+}
+
+IJobQueue* NNeh::JobQueue() {
+ return Singleton<TJobScatter>();
+}
diff --git a/library/cpp/neh/jobqueue.h b/library/cpp/neh/jobqueue.h
new file mode 100644
index 0000000000..ec86458cf0
--- /dev/null
+++ b/library/cpp/neh/jobqueue.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#include <library/cpp/coroutine/engine/impl.h>
+
+#include <util/generic/yexception.h>
+#include <util/stream/output.h>
+
+namespace NNeh {
+ class IJob {
+ public:
+ inline void operator()(TCont* c) noexcept {
+ try {
+ DoRun(c);
+ } catch (...) {
+ Cdbg << "run " << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ virtual ~IJob() {
+ }
+
+ private:
+ virtual void DoRun(TCont* c) = 0;
+ };
+
+ class IJobQueue {
+ public:
+ template <class T>
+ inline void Schedule(T req) {
+ ScheduleImpl(req.Get());
+ Y_UNUSED(req.Release());
+ }
+
+ virtual void ScheduleImpl(IJob* job) = 0;
+
+ virtual ~IJobQueue() {
+ }
+ };
+
+ IJobQueue* JobQueue();
+}
diff --git a/library/cpp/neh/lfqueue.h b/library/cpp/neh/lfqueue.h
new file mode 100644
index 0000000000..c957047a99
--- /dev/null
+++ b/library/cpp/neh/lfqueue.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include <util/thread/lfqueue.h>
+#include <util/generic/ptr.h>
+
+namespace NNeh {
+ template <class T>
+ class TAutoLockFreeQueue {
+ struct TCounter : TAtomicCounter {
+ inline void IncCount(const T* const&) {
+ Inc();
+ }
+
+ inline void DecCount(const T* const&) {
+ Dec();
+ }
+ };
+
+ public:
+ typedef TAutoPtr<T> TRef;
+
+ inline ~TAutoLockFreeQueue() {
+ TRef tmp;
+
+ while (Dequeue(&tmp)) {
+ }
+ }
+
+ inline bool Dequeue(TRef* t) {
+ T* res = nullptr;
+
+ if (Q_.Dequeue(&res)) {
+ t->Reset(res);
+
+ return true;
+ }
+
+ return false;
+ }
+
+ inline void Enqueue(TRef& t) {
+ Q_.Enqueue(t.Get());
+ Y_UNUSED(t.Release());
+ }
+
+ inline size_t Size() {
+ return Q_.GetCounter().Val();
+ }
+
+ private:
+ TLockFreeQueue<T*, TCounter> Q_;
+ };
+}
diff --git a/library/cpp/neh/location.cpp b/library/cpp/neh/location.cpp
new file mode 100644
index 0000000000..36c0846673
--- /dev/null
+++ b/library/cpp/neh/location.cpp
@@ -0,0 +1,50 @@
+#include "location.h"
+
+#include <util/string/cast.h>
+
+using namespace NNeh;
+
+TParsedLocation::TParsedLocation(TStringBuf path) {
+ path.Split(':', Scheme, path);
+ path.Skip(2);
+
+ const size_t pos = path.find_first_of(TStringBuf("?@"));
+
+ if (TStringBuf::npos != pos && '@' == path[pos]) {
+ path.SplitAt(pos, UserInfo, path);
+ path.Skip(1);
+ }
+
+ auto checkRange = [](size_t b, size_t e){
+ return b != TStringBuf::npos && e != TStringBuf::npos && b < e;
+ };
+
+ size_t oBracket = path.find_first_of('[');
+ size_t cBracket = path.find_first_of(']');
+ size_t endEndPointPos = path.find_first_of('/');
+ if (checkRange(oBracket, cBracket)) {
+ endEndPointPos = path.find_first_of('/', cBracket);
+ }
+ EndPoint = path.SubStr(0, endEndPointPos);
+ Host = EndPoint;
+
+ size_t lastColon = EndPoint.find_last_of(':');
+ if (checkRange(cBracket, lastColon)
+ || (cBracket == TStringBuf::npos && lastColon != TStringBuf::npos))
+ {
+ Host = EndPoint.SubStr(0, lastColon);
+ Port = EndPoint.SubStr(lastColon + 1, EndPoint.size() - lastColon + 1);
+ }
+
+ if (endEndPointPos != TStringBuf::npos) {
+ Service = path.SubStr(endEndPointPos + 1, path.size() - endEndPointPos + 1);
+ }
+}
+
+ui16 TParsedLocation::GetPort() const {
+ if (!Port) {
+ return TStringBuf("https") == Scheme || TStringBuf("fulls") == Scheme || TStringBuf("posts") == Scheme ? 443 : 80;
+ }
+
+ return FromString<ui16>(Port);
+}
diff --git a/library/cpp/neh/location.h b/library/cpp/neh/location.h
new file mode 100644
index 0000000000..8a9154785f
--- /dev/null
+++ b/library/cpp/neh/location.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <util/generic/strbuf.h>
+
+namespace NNeh {
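+ /// splits a location string of the form "scheme://[userinfo@]host[:port][/service]" into its parts;
+ /// illustrative example (hypothetical address): "https://user@host.example:8443/svc/method" yields
+ /// Scheme="https", UserInfo="user", EndPoint="host.example:8443", Host="host.example", Port="8443", Service="svc/method"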
+ struct TParsedLocation {
+ TParsedLocation(TStringBuf path);
+
+ ui16 GetPort() const;
+
+ TStringBuf Scheme;
+ TStringBuf UserInfo;
+ TStringBuf EndPoint;
+ TStringBuf Host;
+ TStringBuf Port;
+ TStringBuf Service;
+ };
+}
diff --git a/library/cpp/neh/multi.cpp b/library/cpp/neh/multi.cpp
new file mode 100644
index 0000000000..4929d4efd9
--- /dev/null
+++ b/library/cpp/neh/multi.cpp
@@ -0,0 +1,35 @@
+#include "rpc.h"
+#include "multi.h"
+#include "location.h"
+#include "factory.h"
+
+#include <util/string/cast.h>
+#include <util/generic/hash.h>
+
+using namespace NNeh;
+
+namespace {
+ namespace NMulti {
+ class TRequester: public IRequester {
+ public:
+ inline TRequester(const TListenAddrs& addrs, IOnRequest* cb) {
+ for (const auto& addr : addrs) {
+ TParsedLocation loc(addr);
+ IRequesterRef& req = R_[ToString(loc.Scheme) + ToString(loc.GetPort())];
+
+ if (!req) {
+ req = ProtocolFactory()->Protocol(loc.Scheme)->CreateRequester(cb, loc);
+ }
+ }
+ }
+
+ private:
+ typedef THashMap<TString, IRequesterRef> TRequesters;
+ TRequesters R_;
+ };
+ }
+}
+
+IRequesterRef NNeh::MultiRequester(const TListenAddrs& addrs, IOnRequest* cb) {
+ return new NMulti::TRequester(addrs, cb);
+}
diff --git a/library/cpp/neh/multi.h b/library/cpp/neh/multi.h
new file mode 100644
index 0000000000..ab99dd78f9
--- /dev/null
+++ b/library/cpp/neh/multi.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include "rpc.h"
+
+#include <util/generic/vector.h>
+#include <util/generic/string.h>
+
+namespace NNeh {
+ typedef TVector<TString> TListenAddrs;
+
+ IRequesterRef MultiRequester(const TListenAddrs& addrs, IOnRequest* rq);
+}
diff --git a/library/cpp/neh/multiclient.cpp b/library/cpp/neh/multiclient.cpp
new file mode 100644
index 0000000000..cb7672755e
--- /dev/null
+++ b/library/cpp/neh/multiclient.cpp
@@ -0,0 +1,378 @@
+#include "multiclient.h"
+#include "utils.h"
+
+#include <library/cpp/containers/intrusive_rb_tree/rb_tree.h>
+
+#include <atomic>
+
+namespace {
+ using namespace NNeh;
+
+ struct TCompareDeadline {
+ template <class T>
+ static inline bool Compare(const T& l, const T& r) noexcept {
+ return l.Deadline() < r.Deadline() || (l.Deadline() == r.Deadline() && &l < &r);
+ }
+ };
+
+ class TMultiClient: public IMultiClient, public TThrRefBase {
+ class TRequestSupervisor: public TRbTreeItem<TRequestSupervisor, TCompareDeadline>, public IOnRecv, public TThrRefBase, public TNonCopyable {
+ private:
+ TRequestSupervisor() {
+ } //disable
+
+ public:
+ inline TRequestSupervisor(const TRequest& request, TMultiClient* mc) noexcept
+ : MC_(mc)
+ , Request_(request)
+ , Maked_(0)
+ , FinishOnMakeRequest_(0)
+ , Handled_(0)
+ , Dequeued_(false)
+ {
+ }
+
+ inline TInstant Deadline() const noexcept {
+ return Request_.Deadline;
+ }
+
+ //not thread safe (can be called at the same time from TMultiClient::Request() and TRequestSupervisor::OnNotify())
+ void OnMakeRequest(THandleRef h) noexcept {
+ //the request can be marked as made only once, so only the first call sets the handle
+ if (AtomicCas(&Maked_, 1, 0)) {
+ H_.Swap(h);
+ //[paranoid mode on] make sure the handle is initialized before returning
+ AtomicSet(FinishOnMakeRequest_, 1);
+ } else {
+ while (!AtomicGet(FinishOnMakeRequest_)) {
+ SpinLockPause();
+ }
+ //[paranoid mode off]
+ }
+ }
+
+ void FillEvent(TEvent& ev) noexcept {
+ ev.Hndl = H_;
+ FillEventUserData(ev);
+ }
+
+ void FillEventUserData(TEvent& ev) noexcept {
+ ev.UserData = Request_.UserData;
+ }
+
+ void ResetRequest() noexcept { //destroy the keepalive cross-reference TRequestSupervisor<->THandle
+ H_.Drop();
+ }
+
+ //OnEndProcessRequest() & OnEndProcessResponse() are executed from the Wait() context (thread)
+ void OnEndProcessRequest() {
+ Dequeued_ = true;
+ if (Y_UNLIKELY(IsHandled())) {
+ ResetRequest(); //race - the response was already handled before this request was processed from the queue
+ } else {
+ MC_->RegisterRequest(this);
+ }
+ }
+
+ void OnEndProcessResponse() {
+ if (Y_LIKELY(Dequeued_)) {
+ UnLink();
+ ResetRequest();
+ } //else the request has not been dequeued/registered yet, so we do not need to unlink it
+ //(when we later dequeue the request, OnEndProcessRequest()...IsHandled() returns true and we reset the request)
+ }
+
+ //IOnRecv interface
+ void OnNotify(THandle& h) override {
+ if (Y_LIKELY(MarkAsHandled())) {
+ THandleRef hr(&h);
+ OnMakeRequest(hr); //fix the race where the response is received before NNeh::Request() returns control
+ MC_->ScheduleResponse(this, hr);
+ }
+ }
+
+ void OnRecv(THandle&) noexcept override {
+ UnRef();
+ }
+
+ void OnEnd() noexcept override {
+ UnRef();
+ }
+ //
+
+ //the request can be handled only once, so only the first call to MarkAsHandled() returns true
+ bool MarkAsHandled() noexcept {
+ return AtomicCas(&Handled_, 1, 0);
+ }
+
+ bool IsHandled() const noexcept {
+ return AtomicGet(Handled_);
+ }
+
+ private:
+ TIntrusivePtr<TMultiClient> MC_;
+ TRequest Request_;
+ THandleRef H_;
+ TAtomic Maked_;
+ TAtomic FinishOnMakeRequest_;
+ TAtomic Handled_;
+ bool Dequeued_;
+ };
+
+ typedef TRbTree<TRequestSupervisor, TCompareDeadline> TRequestsSupervisors;
+ typedef TIntrusivePtr<TRequestSupervisor> TRequestSupervisorRef;
+
+ public:
+ TMultiClient()
+ : Interrupt_(false)
+ , NearDeadline_(TInstant::Max().GetValue())
+ , E_(::TSystemEvent::rAuto)
+ , Shutdown_(false)
+ {
+ }
+
+ struct TResetRequest {
+ inline void operator()(TRequestSupervisor& rs) const noexcept {
+ rs.ResetRequest();
+ }
+ };
+
+ void Shutdown() {
+ //reset the THandleRefs of all existing supervisors and the jobs queue (and prevent creating new ones)
+ //- this breaks the cross-reference chain THandle->TRequestSupervisor->TMultiClient, which would otherwise prevent destroying this object
+ Shutdown_ = true;
+ RS_.ForEachNoOrder(TResetRequest());
+ RS_.Clear();
+ CleanQueue();
+ }
+
+ private:
+ class IJob {
+ public:
+ virtual ~IJob() {
+ }
+ virtual bool Process(TEvent&) = 0;
+ virtual void Cancel() = 0;
+ };
+ typedef TAutoPtr<IJob> TJobPtr;
+
+ class TNewRequest: public IJob {
+ public:
+ TNewRequest(TRequestSupervisorRef& rs)
+ : RS_(rs)
+ {
+ }
+
+ private:
+ bool Process(TEvent&) override {
+ RS_->OnEndProcessRequest();
+ return false;
+ }
+
+ void Cancel() override {
+ RS_->ResetRequest();
+ }
+
+ TRequestSupervisorRef RS_;
+ };
+
+ class TNewResponse: public IJob {
+ public:
+ TNewResponse(TRequestSupervisor* rs, THandleRef& h) noexcept
+ : RS_(rs)
+ , H_(h)
+ {
+ }
+
+ private:
+ bool Process(TEvent& ev) override {
+ ev.Type = TEvent::Response;
+ ev.Hndl = H_;
+ RS_->FillEventUserData(ev);
+ RS_->OnEndProcessResponse();
+ return true;
+ }
+
+ void Cancel() override {
+ RS_->ResetRequest();
+ }
+
+ TRequestSupervisorRef RS_;
+ THandleRef H_;
+ };
+
+ public:
+ THandleRef Request(const TRequest& request) override {
+ TIntrusivePtr<TRequestSupervisor> rs(new TRequestSupervisor(request, this));
+ THandleRef h;
+ try {
+ rs->Ref();
+ h = NNeh::Request(request.Msg, rs.Get());
+ //carefully handle the race when processing the new-request event
+ //(we can already receive the response (an OnNotify call) before we schedule info about the new request here)
+ } catch (...) {
+ rs->UnRef();
+ throw;
+ }
+ rs->OnMakeRequest(h);
+ ScheduleRequest(rs, h, request.Deadline);
+ return h;
+ }
+
+ bool Wait(TEvent& ev, const TInstant deadline_ = TInstant::Max()) override {
+ while (!Interrupt_) {
+ TInstant deadline = deadline_;
+ const TInstant now = TInstant::Now();
+ if (deadline != TInstant::Max() && now >= deadline) {
+ break;
+ }
+
+ { //process jobs queue (requests/responses info)
+ TAutoPtr<IJob> j;
+ while (JQ_.Dequeue(&j)) {
+ if (j->Process(ev)) {
+ return true;
+ }
+ }
+ }
+
+ if (!RS_.Empty()) {
+ TRequestSupervisor* nearRS = &*RS_.Begin();
+ if (nearRS->Deadline() <= now) {
+ if (!nearRS->MarkAsHandled()) {
+ //race with notify - a response job for this request must already be in the queue
+ continue;
+ }
+ ev.Type = TEvent::Timeout;
+ nearRS->FillEvent(ev);
+ nearRS->ResetRequest();
+ nearRS->UnLink();
+ return true;
+ }
+ deadline = Min(nearRS->Deadline(), deadline);
+ }
+
+ if (SetNearDeadline(deadline)) {
+ continue; //the deadline was moved to a later time, so re-check the queue to avoid a race
+ }
+
+ E_.WaitD(deadline);
+ }
+ Interrupt_ = false;
+ return false;
+ }
+
+ void Interrupt() override {
+ Interrupt_ = true;
+ Signal();
+ }
+
+ size_t QueueSize() override {
+ return JQ_.Size();
+ }
+
+ private:
+ void Signal() {
+ //TODO: try to optimize - skip signaling when there are no waiters (reduce mutex usage)
+ E_.Signal();
+ }
+
+ void ScheduleRequest(TIntrusivePtr<TRequestSupervisor>& rs, const THandleRef& h, const TInstant& deadline) {
+ TJobPtr j(new TNewRequest(rs));
+ JQ_.Enqueue(j);
+ if (!h->Signalled) {
+ if (deadline.GetValue() < GetNearDeadline_()) {
+ Signal();
+ }
+ }
+ }
+
+ void ScheduleResponse(TRequestSupervisor* rs, THandleRef& h) {
+ TJobPtr j(new TNewResponse(rs, h));
+ JQ_.Enqueue(j);
+ if (Y_UNLIKELY(Shutdown_)) {
+ CleanQueue();
+ } else {
+ Signal();
+ }
+ }
+
+ //returns true if the deadline was moved to a later time
+ bool SetNearDeadline(const TInstant& deadline) {
+ bool deadlineMovedFurther = deadline.GetValue() > GetNearDeadline_();
+ SetNearDeadline_(deadline.GetValue());
+ return deadlineMovedFurther;
+ }
+
+ //used only from Wait()
+ void RegisterRequest(TRequestSupervisor* rs) {
+ if (rs->Deadline() != TInstant::Max()) {
+ RS_.Insert(rs);
+ } else {
+ rs->ResetRequest(); //prevent 'endless' requests from blocking destruction
+ }
+ }
+
+ void CleanQueue() {
+ TAutoPtr<IJob> j;
+ while (JQ_.Dequeue(&j)) {
+ j->Cancel();
+ }
+ }
+
+ private:
+ void SetNearDeadline_(const TInstant::TValue& v) noexcept {
+ TGuard<TAdaptiveLock> g(NDLock_);
+ NearDeadline_.store(v, std::memory_order_release);
+ }
+
+ TInstant::TValue GetNearDeadline_() const noexcept {
+ TGuard<TAdaptiveLock> g(NDLock_);
+ return NearDeadline_.load(std::memory_order_acquire);
+ }
+
+ NNeh::TAutoLockFreeQueue<IJob> JQ_;
+ TAtomicBool Interrupt_;
+ TRequestsSupervisors RS_;
+ TAdaptiveLock NDLock_;
+ std::atomic<TInstant::TValue> NearDeadline_;
+ ::TSystemEvent E_;
+ TAtomicBool Shutdown_;
+ };
+
+ class TMultiClientAutoShutdown: public IMultiClient {
+ public:
+ TMultiClientAutoShutdown()
+ : MC_(new TMultiClient())
+ {
+ }
+
+ ~TMultiClientAutoShutdown() override {
+ MC_->Shutdown();
+ }
+
+ size_t QueueSize() override {
+ return MC_->QueueSize();
+ }
+
+ private:
+ THandleRef Request(const TRequest& req) override {
+ return MC_->Request(req);
+ }
+
+ bool Wait(TEvent& ev, TInstant deadline = TInstant::Max()) override {
+ return MC_->Wait(ev, deadline);
+ }
+
+ void Interrupt() override {
+ return MC_->Interrupt();
+ }
+
+ private:
+ TIntrusivePtr<TMultiClient> MC_;
+ };
+}
+
+TMultiClientPtr NNeh::CreateMultiClient() {
+ return new TMultiClientAutoShutdown();
+}
diff --git a/library/cpp/neh/multiclient.h b/library/cpp/neh/multiclient.h
new file mode 100644
index 0000000000..e12b73dcd9
--- /dev/null
+++ b/library/cpp/neh/multiclient.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include "neh.h"
+
+namespace NNeh {
+ /// thread-safe dispatcher for processing multiple neh requests
+ /// (the Wait() method MUST be called from a single thread; Request() and Interrupt() are thread-safe)
+ class IMultiClient {
+ public:
+ virtual ~IMultiClient() {
+ }
+
+ struct TRequest {
+ TRequest()
+ : Deadline(TInstant::Max())
+ , UserData(nullptr)
+ {
+ }
+
+ TRequest(const TMessage& msg, TInstant deadline = TInstant::Max(), void* userData = nullptr)
+ : Msg(msg)
+ , Deadline(deadline)
+ , UserData(userData)
+ {
+ }
+
+ TMessage Msg;
+ TInstant Deadline;
+ void* UserData;
+ };
+
+ /// WARNING:
+ /// Wait(event) called from another thread can return the Event
+ /// for this request before this call returns control
+ virtual THandleRef Request(const TRequest& req) = 0;
+
+ virtual size_t QueueSize() = 0;
+
+ struct TEvent {
+ enum TType {
+ Timeout,
+ Response,
+ SizeEventType
+ };
+
+ TEvent()
+ : Type(SizeEventType)
+ , UserData(nullptr)
+ {
+ }
+
+ TEvent(TType t, void* userData)
+ : Type(t)
+ , UserData(userData)
+ {
+ }
+
+ TType Type;
+ THandleRef Hndl;
+ void* UserData;
+ };
+
+ /// returns false if interrupted
+ virtual bool Wait(TEvent&, TInstant = TInstant::Max()) = 0;
+ /// Interrupt() is guaranteed to break the execution of Wait(), but several interrupts may be handled as one
+ virtual void Interrupt() = 0;
+ };
+
+ typedef TAutoPtr<IMultiClient> TMultiClientPtr;
+
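+ /// minimal usage sketch (address and timeouts below are illustrative):
+ /// TMultiClientPtr mc = CreateMultiClient();
+ /// mc->Request(IMultiClient::TRequest(TMessage::FromString("http://localhost:80/ping"), TDuration::Seconds(1).ToDeadLine()));
+ /// IMultiClient::TEvent ev;
+ /// if (mc->Wait(ev, TDuration::Seconds(5).ToDeadLine())) {
+ /// // ev.Type is TEvent::Response or TEvent::Timeout; for Response events ev.Hndl->Get() holds the response
+ /// }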
+ TMultiClientPtr CreateMultiClient();
+}
diff --git a/library/cpp/neh/neh.cpp b/library/cpp/neh/neh.cpp
new file mode 100644
index 0000000000..2a3eef5023
--- /dev/null
+++ b/library/cpp/neh/neh.cpp
@@ -0,0 +1,146 @@
+#include "neh.h"
+
+#include "details.h"
+#include "factory.h"
+
+#include <util/generic/list.h>
+#include <util/generic/hash_set.h>
+#include <util/digest/numeric.h>
+#include <util/string/cast.h>
+
+using namespace NNeh;
+
+namespace {
+ class TMultiRequester: public IMultiRequester {
+ struct TOps {
+ template <class T>
+ inline bool operator()(const T& l, const T& r) const noexcept {
+ return l.Get() == r.Get();
+ }
+
+ template <class T>
+ inline size_t operator()(const T& t) const noexcept {
+ return NumericHash(t.Get());
+ }
+ };
+
+ struct TOnComplete {
+ TMultiRequester* Parent;
+ bool Signalled;
+
+ inline TOnComplete(TMultiRequester* parent)
+ : Parent(parent)
+ , Signalled(false)
+ {
+ }
+
+ inline void operator()(TWaitHandle* wh) {
+ THandleRef req(static_cast<THandle*>(wh));
+
+ Signalled = true;
+ Parent->OnComplete(req);
+ }
+ };
+
+ public:
+ void Add(const THandleRef& req) override {
+ Reqs_.insert(req);
+ }
+
+ void Del(const THandleRef& req) override {
+ Reqs_.erase(req);
+ }
+
+ bool Wait(THandleRef& req, TInstant deadLine) override {
+ while (Complete_.empty()) {
+ if (Reqs_.empty()) {
+ return false;
+ }
+
+ TOnComplete cb(this);
+
+ WaitForMultipleObj(Reqs_.begin(), Reqs_.end(), deadLine, cb);
+
+ if (!cb.Signalled) {
+ return false;
+ }
+ }
+
+ req = *Complete_.begin();
+ Complete_.pop_front();
+
+ return true;
+ }
+
+ bool IsEmpty() const override {
+ return Reqs_.empty() && Complete_.empty();
+ }
+
+ inline void OnComplete(const THandleRef& req) {
+ Complete_.push_back(req);
+ Reqs_.erase(req);
+ }
+
+ private:
+ typedef THashSet<THandleRef, TOps, TOps> TReqs;
+ typedef TList<THandleRef> TComplete;
+ TReqs Reqs_;
+ TComplete Complete_;
+ };
+
+ inline IProtocol* ProtocolForMessage(const TMessage& msg) {
+ return ProtocolFactory()->Protocol(TStringBuf(msg.Addr).Before(':'));
+ }
+}
+
+NNeh::TMessage NNeh::TMessage::FromString(const TStringBuf req) {
+ TStringBuf addr;
+ TStringBuf data;
+
+ req.Split('?', addr, data);
+ return TMessage(ToString(addr), ToString(data));
+}
+
+namespace {
+ const TString svcFail = "service status: failed";
+}
+
+THandleRef NNeh::Request(const TMessage& msg, IOnRecv* fallback, bool useAsyncSendRequest) {
+ TServiceStatRef ss;
+
+ if (TServiceStat::Disabled()) {
+ return ProtocolForMessage(msg)->ScheduleAsyncRequest(msg, fallback, ss, useAsyncSendRequest);
+ }
+
+ ss = GetServiceStat(msg.Addr);
+ TServiceStat::EStatus es = ss->GetStatus();
+
+ if (es == TServiceStat::Ok) {
+ return ProtocolForMessage(msg)->ScheduleAsyncRequest(msg, fallback, ss, useAsyncSendRequest);
+ }
+
+ if (es == TServiceStat::ReTry) {
+ //send an empty request to validate the service (updates TServiceStat info)
+ TMessage validator;
+
+ validator.Addr = msg.Addr;
+
+ ProtocolForMessage(msg)->ScheduleAsyncRequest(validator, nullptr, ss, useAsyncSendRequest);
+ }
+
+ TNotifyHandleRef h(new TNotifyHandle(fallback, msg));
+ h->NotifyError(new TError(svcFail));
+ return h.Get();
+}
+
+THandleRef NNeh::Request(const TString& req, IOnRecv* fallback) {
+ return Request(TMessage::FromString(req), fallback);
+}
+
+IMultiRequesterRef NNeh::CreateRequester() {
+ return new TMultiRequester();
+}
+
+bool NNeh::SetProtocolOption(TStringBuf protoOption, TStringBuf value) {
+ return ProtocolFactory()->Protocol(protoOption.Before('/'))->SetOption(protoOption.After('/'), value);
+}
diff --git a/library/cpp/neh/neh.h b/library/cpp/neh/neh.h
new file mode 100644
index 0000000000..e0211a7dff
--- /dev/null
+++ b/library/cpp/neh/neh.h
@@ -0,0 +1,320 @@
+#pragma once
+
+#include "wfmo.h"
+#include "stat.h"
+
+#include <library/cpp/http/io/headers.h>
+
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <util/datetime/base.h>
+
+#include <utility>
+
+namespace NNeh {
+ struct TMessage {
+ TMessage() = default;
+
+ inline TMessage(TString addr, TString data)
+ : Addr(std::move(addr))
+ , Data(std::move(data))
+ {
+ }
+
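+ /// splits the request string at the first '?': the part before it becomes Addr, the rest becomes Data
+ /// (illustrative: "http://localhost:80/svc?payload" -> Addr="http://localhost:80/svc", Data="payload")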
+ static TMessage FromString(TStringBuf request);
+
+ TString Addr;
+ TString Data;
+ };
+
+ using TMessageRef = TAutoPtr<TMessage>;
+
+ struct TError {
+ public:
+ enum TType {
+ UnknownType,
+ Cancelled,
+ ProtocolSpecific
+ };
+
+ TError(TString text, TType type = UnknownType, i32 code = 0, i32 systemCode = 0)
+ : Type(std::move(type))
+ , Code(code)
+ , Text(text)
+ , SystemCode(systemCode)
+ {
+ }
+
+ TType Type = UnknownType;
+ i32 Code = 0; // protocol specific code (example(http): 404)
+ TString Text;
+ i32 SystemCode = 0; // system error code
+ };
+
+ using TErrorRef = TAutoPtr<TError>;
+
+ struct TResponse;
+ using TResponseRef = TAutoPtr<TResponse>;
+
+ struct TResponse {
+ inline TResponse(TMessage req,
+ TString data,
+ const TDuration duration)
+ : TResponse(std::move(req), std::move(data), duration, {} /* firstLine */, {} /* headers */, {} /* error */)
+ {
+ }
+
+ inline TResponse(TMessage req,
+ TString data,
+ const TDuration duration,
+ TString firstLine,
+ THttpHeaders headers)
+ : TResponse(std::move(req), std::move(data), duration, std::move(firstLine), std::move(headers), {} /* error */)
+ {
+ }
+
+ inline TResponse(TMessage req,
+ TString data,
+ const TDuration duration,
+ TString firstLine,
+ THttpHeaders headers,
+ TErrorRef error)
+ : Request(std::move(req))
+ , Data(std::move(data))
+ , Duration(duration)
+ , FirstLine(std::move(firstLine))
+ , Headers(std::move(headers))
+ , Error_(std::move(error))
+ {
+ }
+
+ inline static TResponseRef FromErrorText(TMessage msg, TString error, const TDuration duration) {
+ return new TResponse(std::move(msg), {} /* data */, duration, {} /* firstLine */, {} /* headers */, new TError(std::move(error)));
+ }
+
+ inline static TResponseRef FromError(TMessage msg, TErrorRef error, const TDuration duration) {
+ return new TResponse(std::move(msg), {} /* data */, duration, {} /* firstLine */, {} /* headers */, error);
+ }
+
+ inline static TResponseRef FromError(TMessage msg, TErrorRef error, const TDuration duration,
+ TString data, TString firstLine, THttpHeaders headers)
+ {
+ return new TResponse(std::move(msg), std::move(data), duration, std::move(firstLine), std::move(headers), error);
+ }
+
+ inline static TResponseRef FromError(
+ TMessage msg,
+ TErrorRef error,
+ TString data,
+ const TDuration duration,
+ TString firstLine,
+ THttpHeaders headers)
+ {
+ return new TResponse(std::move(msg), std::move(data), duration, std::move(firstLine), std::move(headers), error);
+ }
+
+ inline bool IsError() const {
+ return Error_.Get();
+ }
+
+ inline TError::TType GetErrorType() const {
+ return Error_.Get() ? Error_->Type : TError::UnknownType;
+ }
+
+ inline i32 GetErrorCode() const {
+ return Error_.Get() ? Error_->Code : 0;
+ }
+
+ inline i32 GetSystemErrorCode() const {
+ return Error_.Get() ? Error_->SystemCode : 0;
+ }
+
+ inline TString GetErrorText() const {
+ return Error_.Get() ? Error_->Text : TString();
+ }
+
+ const TMessage Request;
+ const TString Data;
+ const TDuration Duration;
+ const TString FirstLine;
+ THttpHeaders Headers;
+
+ private:
+ THolder<TError> Error_;
+ };
+
+ class THandle;
+
+ class IOnRecv {
+ public:
+ virtual ~IOnRecv() = default;
+
+ virtual void OnNotify(THandle&) {
+ } //callback on receiving a response
+ virtual void OnEnd() {
+ } //the response was extracted by the Wait() method - OnRecv() will not be called
+ virtual void OnRecv(THandle& resp) = 0; //callback on handle destruction
+ };
+
+ class THandle: public TThrRefBase, public TWaitHandle {
+ public:
+ inline THandle(IOnRecv* f, TStatCollector* s = nullptr) noexcept
+ : F_(f)
+ , Stat_(s)
+ {
+ }
+
+ ~THandle() override {
+ if (F_) {
+ try {
+ F_->OnRecv(*this);
+ } catch (...) {
+ }
+ }
+ }
+
+ virtual bool MessageSendedCompletely() const noexcept {
+ //TODO
+ return true;
+ }
+
+ virtual void Cancel() noexcept {
+ //TODO
+ if (!!Stat_)
+ Stat_->OnCancel();
+ }
+
+ inline const TResponse* Response() const noexcept {
+ return R_.Get();
+ }
+
+ //this method MUST be called only after a successful Wait() on this handle or from the IOnRecv::OnRecv() callback,
+ //otherwise there is a chance of a memory leak (race between Get()/Notify())
+ inline TResponseRef Get() noexcept {
+ return R_;
+ }
+
+ inline bool Wait(TResponseRef& msg, const TInstant deadLine) {
+ if (WaitForOne(*this, deadLine)) {
+ if (F_) {
+ F_->OnEnd();
+ F_ = nullptr;
+ }
+ msg = Get();
+
+ return true;
+ }
+
+ return false;
+ }
+
+ inline bool Wait(TResponseRef& msg, const TDuration timeOut) {
+ return Wait(msg, timeOut.ToDeadLine());
+ }
+
+ inline bool Wait(TResponseRef& msg) {
+ return Wait(msg, TInstant::Max());
+ }
+
+ inline TResponseRef Wait(const TInstant deadLine) {
+ TResponseRef ret;
+
+ Wait(ret, deadLine);
+
+ return ret;
+ }
+
+ inline TResponseRef Wait(const TDuration timeOut) {
+ return Wait(timeOut.ToDeadLine());
+ }
+
+ inline TResponseRef Wait() {
+ return Wait(TInstant::Max());
+ }
+
+ protected:
+ inline void Notify(TResponseRef resp) {
+ if (!!Stat_) {
+ if (!resp || resp->IsError()) {
+ Stat_->OnFail();
+ } else {
+ Stat_->OnSuccess();
+ }
+ }
+ R_.Swap(resp);
+ if (F_) {
+ try {
+ F_->OnNotify(*this);
+ } catch (...) {
+ }
+ }
+ Signal();
+ }
+
+ IOnRecv* F_;
+
+ private:
+ TResponseRef R_;
+ THolder<TStatCollector> Stat_;
+ };
+
+ using THandleRef = TIntrusivePtr<THandle>;
+
+ THandleRef Request(const TMessage& msg, IOnRecv* fallback, bool useAsyncSendRequest = false);
+
+ inline THandleRef Request(const TMessage& msg) {
+ return Request(msg, nullptr);
+ }
+
+ THandleRef Request(const TString& req, IOnRecv* fallback);
+
+ inline THandleRef Request(const TString& req) {
+ return Request(req, nullptr);
+ }
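+
+ /// minimal synchronous usage sketch (the address is illustrative):
+ /// THandleRef h = Request("http://localhost:80/ping");
+ /// TResponseRef r = h->Wait(TDuration::Seconds(1));
+ /// if (!!r && !r->IsError()) {
+ /// // use r->Data
+ /// }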
+
+ class IMultiRequester {
+ public:
+ virtual ~IMultiRequester() = default;
+
+ virtual void Add(const THandleRef& req) = 0;
+ virtual void Del(const THandleRef& req) = 0;
+ virtual bool Wait(THandleRef& req, TInstant deadLine) = 0;
+ virtual bool IsEmpty() const = 0;
+
+ inline void Schedule(const TString& req) {
+ Add(Request(req));
+ }
+
+ inline bool Wait(THandleRef& req, TDuration timeOut) {
+ return Wait(req, timeOut.ToDeadLine());
+ }
+
+ inline bool Wait(THandleRef& req) {
+ return Wait(req, TInstant::Max());
+ }
+
+ inline bool Wait(TResponseRef& resp, TInstant deadLine) {
+ THandleRef req;
+
+ while (Wait(req, deadLine)) {
+ resp = req->Get();
+
+ if (!!resp) {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ inline bool Wait(TResponseRef& resp) {
+ return Wait(resp, TInstant::Max());
+ }
+ };
+
+ using IMultiRequesterRef = TAutoPtr<IMultiRequester>;
+
+ IMultiRequesterRef CreateRequester();
+
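+ /// sets a protocol-specific option; protoOption has the form "<scheme>/<option name>"
+ /// (hypothetical example: SetProtocolOption("netliba/ClientThreads", "8") - the exact scheme and option keys depend on each protocol's SetOption() implementation)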
+ bool SetProtocolOption(TStringBuf protoOption, TStringBuf value);
+}
diff --git a/library/cpp/neh/netliba.cpp b/library/cpp/neh/netliba.cpp
new file mode 100644
index 0000000000..f69906f3ba
--- /dev/null
+++ b/library/cpp/neh/netliba.cpp
@@ -0,0 +1,508 @@
+#include "details.h"
+#include "factory.h"
+#include "http_common.h"
+#include "location.h"
+#include "multi.h"
+#include "netliba.h"
+#include "netliba_udp_http.h"
+#include "lfqueue.h"
+#include "utils.h"
+
+#include <library/cpp/dns/cache.h>
+
+#include <util/generic/hash.h>
+#include <util/generic/singleton.h>
+#include <util/generic/vector.h>
+#include <util/generic/yexception.h>
+#include <util/string/cast.h>
+#include <util/system/yassert.h>
+
+#include <atomic>
+
+using namespace NDns;
+using namespace NNeh;
+namespace NNeh {
+ size_t TNetLibaOptions::ClientThreads = 4;
+ TDuration TNetLibaOptions::AckTailEffect = TDuration::Seconds(30);
+
+ bool TNetLibaOptions::Set(TStringBuf name, TStringBuf value) {
+#define NETLIBA_TRY_SET(optType, optName) \
+ if (name == TStringBuf(#optName)) { \
+ optName = FromString<optType>(value); \
+ }
+
+ NETLIBA_TRY_SET(size_t, ClientThreads)
+ else NETLIBA_TRY_SET(TDuration, AckTailEffect) else {
+ return false;
+ }
+ return true;
+ }
+}
+
+namespace {
+ namespace NNetLiba {
+ using namespace NNetliba;
+ using namespace NNehNetliba;
+
+ typedef NNehNetliba::IRequester INetLibaRequester;
+ typedef TAutoPtr<TUdpHttpRequest> TUdpHttpRequestPtr;
+ typedef TAutoPtr<TUdpHttpResponse> TUdpHttpResponsePtr;
+
+ static inline const addrinfo* FindIPBase(const TNetworkAddress* addr, int family) {
+ for (TNetworkAddress::TIterator it = addr->Begin(); it != addr->End(); ++it) {
+ if (it->ai_family == family) {
+ return &*it;
+ }
+ }
+
+ return nullptr;
+ }
+
+ static inline const sockaddr_in6& FindIP(const TNetworkAddress* addr) {
+ //prefer ipv6
+ const addrinfo* ret = FindIPBase(addr, AF_INET6);
+
+ if (!ret) {
+ ret = FindIPBase(addr, AF_INET);
+ }
+
+ if (!ret) {
+ ythrow yexception() << "ip not supported by " << *addr;
+ }
+
+ return *(const sockaddr_in6*)(ret->ai_addr);
+ }
+
+ class TLastAckTimes {
+ struct TTimeVal {
+ TTimeVal()
+ : Val(0)
+ {
+ }
+
+ std::atomic<TInstant::TValue> Val;
+ };
+
+ public:
+ TInstant::TValue Get(size_t idAddr) {
+ return Tm_.Get(idAddr).Val.load(std::memory_order_acquire);
+ }
+
+ void Set(size_t idAddr) {
+ Tm_.Get(idAddr).Val.store(TInstant::Now().GetValue(), std::memory_order_release);
+ }
+
+ static TLastAckTimes& Common() {
+ return *Singleton<TLastAckTimes>();
+ }
+
+ private:
+ NNeh::NHttp::TLockFreeSequence<TTimeVal> Tm_;
+ };
+
+ class TRequest: public TSimpleHandle {
+ public:
+ inline TRequest(TIntrusivePtr<INetLibaRequester>& r, size_t idAddr, const TMessage& msg, IOnRecv* cb, TStatCollector* s)
+ : TSimpleHandle(cb, msg, s)
+ , R_(r)
+ , IdAddr_(idAddr)
+ , Notified_(false)
+ {
+ CreateGuid(&Guid_);
+ }
+
+ void Cancel() noexcept override {
+ TSimpleHandle::Cancel();
+ R_->CancelRequest(Guid_);
+ }
+
+ inline const TString& Addr() const noexcept {
+ return Message().Addr;
+ }
+
+ inline const TGUID& Guid() const noexcept {
+ return Guid_;
+ }
+
+            //return false if already notified
+ inline bool SetNotified() noexcept {
+ bool ret = Notified_;
+ Notified_ = true;
+ return !ret;
+ }
+
+ void OnSend() {
+ if (TNetLibaOptions::AckTailEffect.GetValue() && TLastAckTimes::Common().Get(IdAddr_) + TNetLibaOptions::AckTailEffect.GetValue() > TInstant::Now().GetValue()) {
+                    //fake (predicted) send-completion detection
+ SetSendComplete();
+ }
+ }
+
+ void OnRequestAck() {
+ if (TNetLibaOptions::AckTailEffect.GetValue()) {
+ TLastAckTimes::Common().Set(IdAddr_);
+ }
+ SetSendComplete();
+ }
+
+ private:
+ TIntrusivePtr<INetLibaRequester> R_;
+ size_t IdAddr_;
+ TGUID Guid_;
+ bool Notified_;
+ };
+
+ typedef TIntrusivePtr<TRequest> TRequestRef;
+
+ class TNetLibaBus {
+ class TEventsHandler: public IEventsCollector {
+ typedef THashMap<TGUID, TRequestRef, TGUIDHash> TInFly;
+
+ public:
+ inline void OnSend(TRequestRef& req) {
+ Q_.Enqueue(req);
+ req->OnSend();
+ }
+
+ private:
+ void UpdateInFly() {
+ TRequestRef req;
+
+ while (Q_.Dequeue(&req)) {
+ if (!req) {
+ return;
+ }
+
+ InFly_[req->Guid()] = req;
+ }
+ }
+
+ void AddRequest(TUdpHttpRequest* req) override {
+ //ignore received requests in client
+ delete req;
+ }
+
+ void AddResponse(TUdpHttpResponse* resp) override {
+ TUdpHttpResponsePtr ptr(resp);
+
+ UpdateInFly();
+ TInFly::iterator it = InFly_.find(resp->ReqId);
+
+ Y_VERIFY(it != InFly_.end(), "incorrect incoming message");
+
+ TRequestRef& req = it->second;
+
+ if (req->SetNotified()) {
+ if (resp->Ok == TUdpHttpResponse::OK) {
+ req->NotifyResponse(TString(resp->Data.data(), resp->Data.size()));
+ } else {
+ if (resp->Ok == TUdpHttpResponse::CANCELED) {
+ req->NotifyError(new TError(resp->Error, TError::Cancelled));
+ } else {
+ req->NotifyError(new TError(resp->Error));
+ }
+ }
+ }
+
+ InFly_.erase(it);
+ }
+
+ void AddCancel(const TGUID& guid) override {
+ UpdateInFly();
+ TInFly::iterator it = InFly_.find(guid);
+
+ if (it != InFly_.end() && it->second->SetNotified()) {
+ it->second->NotifyError("Canceled (before ack)");
+ }
+ }
+
+ void AddRequestAck(const TGUID& guid) override {
+ UpdateInFly();
+ TInFly::iterator it = InFly_.find(guid);
+
+ Y_VERIFY(it != InFly_.end(), "incorrect complete notification");
+
+ it->second->OnRequestAck();
+ }
+
+ private:
+ TLockFreeQueue<TRequestRef> Q_;
+ TInFly InFly_;
+ };
+
+ struct TClientThread {
+ TClientThread(int physicalCpu)
+ : EH_(new TEventsHandler())
+ , R_(CreateHttpUdpRequester(0, IEventsCollectorRef(EH_.Get()), physicalCpu))
+ {
+ R_->EnableReportRequestAck();
+ }
+
+ ~TClientThread() {
+ R_->StopNoWait();
+ }
+
+ TIntrusivePtr<TEventsHandler> EH_;
+ TIntrusivePtr<INetLibaRequester> R_;
+ };
+
+ public:
+ TNetLibaBus() {
+ for (size_t i = 0; i < TNetLibaOptions::ClientThreads; ++i) {
+ Clnt_.push_back(new TClientThread(i));
+ }
+ }
+
+ inline THandleRef Schedule(const TMessage& msg, IOnRecv* cb, TServiceStatRef& ss) {
+ TParsedLocation loc(msg.Addr);
+ TUdpAddress addr;
+
+ const TResolvedHost* resHost = CachedResolve(TResolveInfo(loc.Host, loc.GetPort()));
+ GetUdpAddress(&addr, FindIP(&resHost->Addr));
+
+ TClientThread& clnt = *Clnt_[resHost->Id % Clnt_.size()];
+ TIntrusivePtr<INetLibaRequester> rr = clnt.R_;
+ TRequestRef req(new TRequest(rr, resHost->Id, msg, cb, !ss ? nullptr : new TStatCollector(ss)));
+
+ clnt.EH_->OnSend(req);
+ rr->SendRequest(addr, ToString(loc.Service), msg.Data, req->Guid());
+
+ return THandleRef(req.Get());
+ }
+
+ private:
+ TVector<TAutoPtr<TClientThread>> Clnt_;
+ };
+
+ //server
+ class TRequester: public TThrRefBase {
+ struct TSrvRequestState: public TAtomicRefCount<TSrvRequestState> {
+ TSrvRequestState()
+ : Canceled(false)
+ {
+ }
+
+ TAtomicBool Canceled;
+ };
+
+ class TRequest: public IRequest {
+ public:
+ inline TRequest(TAutoPtr<TUdpHttpRequest> req, TIntrusivePtr<TSrvRequestState> state, TRequester* parent)
+ : R_(req)
+ , S_(state)
+ , P_(parent)
+ {
+ }
+
+ ~TRequest() override {
+ if (!!P_) {
+ P_->RequestProcessed(this);
+ }
+ }
+
+ TStringBuf Scheme() const override {
+ return TStringBuf("netliba");
+ }
+
+ TString RemoteHost() const override {
+ if (!H_) {
+ TUdpAddress tmp(R_->PeerAddress);
+ tmp.Scope = 0; //discard scope from serialized addr
+
+ TString addr = GetAddressAsString(tmp);
+
+ TStringBuf host, port;
+
+ TStringBuf(addr).RSplit(':', host, port);
+ H_ = host;
+ }
+ return H_;
+ }
+
+ TStringBuf Service() const override {
+ return TStringBuf(R_->Url.c_str(), R_->Url.length());
+ }
+
+ TStringBuf Data() const override {
+ return TStringBuf((const char*)R_->Data.data(), R_->Data.size());
+ }
+
+ TStringBuf RequestId() const override {
+ const TGUID& g = R_->ReqId;
+
+ return TStringBuf((const char*)g.dw, sizeof(g.dw));
+ }
+
+ bool Canceled() const override {
+ return S_->Canceled;
+ }
+
+ void SendReply(TData& data) override {
+ TIntrusivePtr<TRequester> p;
+ p.Swap(P_);
+ if (!!p) {
+ if (!Canceled()) {
+ p->R_->SendResponse(R_->ReqId, &data);
+ }
+ p->RequestProcessed(this);
+ }
+ }
+
+ void SendError(TResponseError, const TString&) override {
+ // TODO
+ }
+
+ inline const TGUID& RequestGuid() const noexcept {
+ return R_->ReqId;
+ }
+
+ private:
+ TAutoPtr<TUdpHttpRequest> R_;
+ mutable TString H_;
+ TIntrusivePtr<TSrvRequestState> S_;
+ TIntrusivePtr<TRequester> P_;
+ };
+
+ class TEventsHandler: public IEventsCollector {
+ public:
+ TEventsHandler(TRequester* parent)
+ {
+ P_.store(parent, std::memory_order_release);
+ }
+
+ void RequestProcessed(const TRequest* r) {
+ FinishedReqs_.Enqueue(r->RequestGuid());
+ }
+
+            //thread-safe method to disable proxying callbacks to the parent (OnRequest(...))
+ void SyncStop() {
+ P_.store(nullptr, std::memory_order_release);
+ while (!RequesterPtrPotector_.TryAcquire()) {
+ Sleep(TDuration::MicroSeconds(100));
+ }
+ RequesterPtrPotector_.Release();
+ }
+
+ private:
+ typedef THashMap<TGUID, TIntrusivePtr<TSrvRequestState>, TGUIDHash> TStatesInProcessRequests;
+
+ void AddRequest(TUdpHttpRequest* req) override {
+ TUdpHttpRequestPtr ptr(req);
+
+ TSrvRequestState* state = new TSrvRequestState();
+
+ InProcess_[req->ReqId] = state;
+ try {
+ TGuard<TSpinLock> m(RequesterPtrPotector_);
+ if (TRequester* p = P_.load(std::memory_order_acquire)) {
+ p->OnRequest(ptr, state); //move req. owning to parent
+ }
+ } catch (...) {
+ Cdbg << "ignore exc.: " << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ void AddResponse(TUdpHttpResponse*) override {
+ Y_FAIL("unexpected response in neh netliba server");
+ }
+
+ void AddCancel(const TGUID& guid) override {
+ UpdateInProcess();
+ TStatesInProcessRequests::iterator ustate = InProcess_.find(guid);
+ if (ustate != InProcess_.end())
+ ustate->second->Canceled = true;
+ }
+
+ void AddRequestAck(const TGUID&) override {
+ Y_FAIL("unexpected acc in neh netliba server");
+ }
+
+ void UpdateInProcess() {
+ TGUID guid;
+
+ while (FinishedReqs_.Dequeue(&guid)) {
+ InProcess_.erase(guid);
+ }
+ }
+
+ private:
+ TLockFreeStack<TGUID> FinishedReqs_; //processed requests (responded or destroyed)
+ TStatesInProcessRequests InProcess_;
+ TSpinLock RequesterPtrPotector_;
+ std::atomic<TRequester*> P_;
+ };
+
+ public:
+ inline TRequester(IOnRequest* cb, ui16 port)
+ : CB_(cb)
+ , EH_(new TEventsHandler(this))
+ , R_(CreateHttpUdpRequester(port, EH_.Get()))
+ {
+ R_->EnableReportRequestCancel();
+ }
+
+ ~TRequester() override {
+ Shutdown();
+ }
+
+ void Shutdown() noexcept {
+ if (!Shutdown_) {
+ Shutdown_ = true;
+ R_->StopNoWait();
+ EH_->SyncStop();
+ }
+ }
+
+ void OnRequest(TUdpHttpRequestPtr req, TSrvRequestState* state) {
+ CB_->OnRequest(new TRequest(req, state, this));
+ }
+
+ void RequestProcessed(const TRequest* r) {
+ EH_->RequestProcessed(r);
+ }
+
+ private:
+ IOnRequest* CB_;
+ TIntrusivePtr<TEventsHandler> EH_;
+ TIntrusivePtr<INetLibaRequester> R_;
+ bool Shutdown_ = false;
+ };
+
+ typedef TIntrusivePtr<TRequester> TRequesterRef;
+
+ class TRequesterAutoShutdown: public NNeh::IRequester {
+ public:
+ TRequesterAutoShutdown(const TRequesterRef& r)
+ : R_(r)
+ {
+ }
+
+ ~TRequesterAutoShutdown() override {
+ R_->Shutdown();
+ }
+
+ private:
+ TRequesterRef R_;
+ };
+
+ class TProtocol: public IProtocol {
+ public:
+ THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
+ return Singleton<TNetLibaBus>()->Schedule(msg, fallback, ss);
+ }
+
+ NNeh::IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
+ TRequesterRef r(new TRequester(cb, loc.GetPort()));
+ return new TRequesterAutoShutdown(r);
+ }
+
+ TStringBuf Scheme() const noexcept override {
+ return TStringBuf("netliba");
+ }
+ };
+ }
+}
+
+IProtocol* NNeh::NetLibaProtocol() {
+ return Singleton<NNetLiba::TProtocol>();
+}
diff --git a/library/cpp/neh/netliba.h b/library/cpp/neh/netliba.h
new file mode 100644
index 0000000000..9635d29753
--- /dev/null
+++ b/library/cpp/neh/netliba.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <util/datetime/base.h>
+
+namespace NNeh {
+ //global options
+ struct TNetLibaOptions {
+ static size_t ClientThreads;
+
+        //after a request ack is received from an address, send completion is optimistically assumed (predicted) for further requests to it during this period
+ static TDuration AckTailEffect;
+
+        //set an option; returns false if the option name is not recognized
+ static bool Set(TStringBuf name, TStringBuf value);
+ };
+
+ class IProtocol;
+
+ IProtocol* NetLibaProtocol();
+}
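
A small sketch of tuning these knobs at program start; the concrete values are illustrative only, and the duration string relies on util's TDuration parsing:

    #include <library/cpp/neh/netliba.h>

    #include <util/generic/yexception.h>

    void ConfigureNetLiba() {
        // Set() returns false if the option name is not recognized.
        if (!NNeh::TNetLibaOptions::Set("ClientThreads", "2")) {
            ythrow yexception() << "unknown netliba option";
        }
        // A zero duration disables the predicted send-complete shortcut.
        NNeh::TNetLibaOptions::Set("AckTailEffect", "0s");
    }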
diff --git a/library/cpp/neh/netliba_udp_http.cpp b/library/cpp/neh/netliba_udp_http.cpp
new file mode 100644
index 0000000000..a4df426f02
--- /dev/null
+++ b/library/cpp/neh/netliba_udp_http.cpp
@@ -0,0 +1,808 @@
+#include "netliba_udp_http.h"
+#include "utils.h"
+
+#include <library/cpp/netliba/v6/cpu_affinity.h>
+#include <library/cpp/netliba/v6/stdafx.h>
+#include <library/cpp/netliba/v6/udp_client_server.h>
+#include <library/cpp/netliba/v6/udp_socket.h>
+
+#include <library/cpp/netliba/v6/block_chain.h> // depend on another headers
+
+#include <util/system/hp_timer.h>
+#include <util/system/shmat.h>
+#include <util/system/spinlock.h>
+#include <util/system/thread.h>
+#include <util/system/types.h>
+#include <util/system/yassert.h>
+#include <util/thread/lfqueue.h>
+
+#include <atomic>
+
+#if !defined(_win_)
+#include <signal.h>
+#include <pthread.h>
+#endif
+
+using namespace NNetliba;
+
+namespace {
+ const float HTTP_TIMEOUT = 15.0f;
+ const size_t MIN_SHARED_MEM_PACKET = 1000;
+ const size_t MAX_PACKET_SIZE = 0x70000000;
+
+ NNeh::TAtomicBool PanicAttack;
+ std::atomic<NHPTimer::STime> LastHeartbeat;
+ std::atomic<double> HeartbeatTimeout;
+
+ bool IsLocal(const TUdpAddress& addr) {
+ return addr.IsIPv4() ? IsLocalIPv4(addr.GetIPv4()) : IsLocalIPv6(addr.Network, addr.Interface);
+ }
+
+ void StopAllNetLibaThreads() {
+ PanicAttack = true; // AAAA!!!!
+ }
+
+ void ReadShm(TSharedMemory* shm, TVector<char>* data) {
+ Y_ASSERT(shm);
+ int dataSize = shm->GetSize();
+ data->yresize(dataSize);
+ memcpy(&(*data)[0], shm->GetPtr(), dataSize);
+ }
+
+ void ReadShm(TSharedMemory* shm, TString* data) {
+ Y_ASSERT(shm);
+ size_t dataSize = shm->GetSize();
+ data->ReserveAndResize(dataSize);
+ memcpy(data->begin(), shm->GetPtr(), dataSize);
+ }
+
+ template <class T>
+ void EraseList(TLockFreeQueue<T*>* data) {
+ T* ptr = 0;
+ while (data->Dequeue(&ptr)) {
+ delete ptr;
+ }
+ }
+
+ enum EHttpPacket {
+ PKT_REQUEST,
+ PKT_PING,
+ PKT_PING_RESPONSE,
+ PKT_RESPONSE,
+ PKT_LOCAL_REQUEST,
+ PKT_LOCAL_RESPONSE,
+ PKT_CANCEL,
+ };
+}
+
+namespace NNehNetliba {
+ TUdpHttpMessage::TUdpHttpMessage(const TGUID& reqId, const TUdpAddress& peerAddr)
+ : ReqId(reqId)
+ , PeerAddress(peerAddr)
+ {
+ }
+
+ TUdpHttpRequest::TUdpHttpRequest(TAutoPtr<TRequest>& dataHolder, const TGUID& reqId, const TUdpAddress& peerAddr)
+ : TUdpHttpMessage(reqId, peerAddr)
+ {
+ TBlockChainIterator reqData(dataHolder->Data->GetChain());
+ char pktType;
+ reqData.Read(&pktType, 1);
+ ReadArr(&reqData, &Url);
+ if (pktType == PKT_REQUEST) {
+ ReadYArr(&reqData, &Data);
+ } else if (pktType == PKT_LOCAL_REQUEST) {
+ ReadShm(dataHolder->Data->GetSharedData(), &Data);
+ } else {
+ Y_ASSERT(0);
+ }
+
+ if (reqData.HasFailed()) {
+ Y_ASSERT(0 && "wrong format, memory corruption suspected");
+ Url = "";
+ Data.clear();
+ }
+ }
+
+ TUdpHttpResponse::TUdpHttpResponse(TAutoPtr<TRequest>& dataHolder, const TGUID& reqId, const TUdpAddress& peerAddr, EResult result, const char* error)
+ : TUdpHttpMessage(reqId, peerAddr)
+ , Ok(result)
+ {
+ if (result == TUdpHttpResponse::FAILED) {
+ Error = error ? error : "request failed";
+ } else if (result == TUdpHttpResponse::CANCELED) {
+ Error = error ? error : "request cancelled";
+ } else {
+ TBlockChainIterator reqData(dataHolder->Data->GetChain());
+ if (Y_UNLIKELY(reqData.HasFailed())) {
+ Y_ASSERT(0 && "wrong format, memory corruption suspected");
+ Ok = TUdpHttpResponse::FAILED;
+ Data.clear();
+ Error = "wrong response format";
+ } else {
+ char pktType;
+ reqData.Read(&pktType, 1);
+ TGUID guid;
+ reqData.Read(&guid, sizeof(guid));
+ Y_ASSERT(ReqId == guid);
+ if (pktType == PKT_RESPONSE) {
+ ReadArr<TString>(&reqData, &Data);
+ } else if (pktType == PKT_LOCAL_RESPONSE) {
+ ReadShm(dataHolder->Data->GetSharedData(), &Data);
+ } else {
+ Y_ASSERT(0);
+ }
+ }
+ }
+ }
+
+ class TUdpHttp: public IRequester {
+ enum EDir {
+ DIR_OUT,
+ DIR_IN
+ };
+
+ struct TInRequestState {
+ enum EState {
+ S_WAITING,
+ S_RESPONSE_SENDING,
+ S_CANCELED,
+ };
+
+ TInRequestState()
+ : State(S_WAITING)
+ {
+ }
+
+ TInRequestState(const TUdpAddress& address)
+ : State(S_WAITING)
+ , Address(address)
+ {
+ }
+
+ EState State;
+ TUdpAddress Address;
+ };
+
+ struct TOutRequestState {
+ enum EState {
+ S_SENDING,
+ S_WAITING,
+ S_WAITING_PING_SENDING,
+ S_WAITING_PING_SENT,
+ S_CANCEL_AFTER_SENDING
+ };
+
+ TOutRequestState()
+ : State(S_SENDING)
+ , TimePassed(0)
+ , PingTransferId(-1)
+ {
+ }
+
+ EState State;
+ TUdpAddress Address;
+ double TimePassed;
+ int PingTransferId;
+ IEventsCollectorRef EventsCollector;
+ };
+
+ struct TTransferPurpose {
+ EDir Dir;
+ TGUID Guid;
+
+ TTransferPurpose()
+ : Dir(DIR_OUT)
+ {
+ }
+
+ TTransferPurpose(EDir dir, TGUID guid)
+ : Dir(dir)
+ , Guid(guid)
+ {
+ }
+ };
+
+ struct TSendRequest {
+ TSendRequest() = default;
+
+ TSendRequest(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket>* data, const TGUID& reqGuid, const IEventsCollectorRef& eventsCollector)
+ : Addr(addr)
+ , Data(*data)
+ , ReqGuid(reqGuid)
+ , EventsCollector(eventsCollector)
+ , Crc32(CalcChecksum(Data->GetChain()))
+ {
+ }
+
+ TUdpAddress Addr;
+ TAutoPtr<TRopeDataPacket> Data;
+ TGUID ReqGuid;
+ IEventsCollectorRef EventsCollector;
+ ui32 Crc32;
+ };
+
+ struct TSendResponse {
+ TSendResponse() = default;
+
+ TSendResponse(const TGUID& reqGuid, EPacketPriority prior, TVector<char>* data)
+ : ReqGuid(reqGuid)
+ , DataCrc32(0)
+ , Priority(prior)
+ {
+ if (data && !data->empty()) {
+ data->swap(Data);
+ DataCrc32 = TIncrementalChecksumCalcer::CalcBlockSum(&Data[0], Data.ysize());
+ }
+ }
+
+ TVector<char> Data;
+ TGUID ReqGuid;
+ ui32 DataCrc32;
+ EPacketPriority Priority;
+ };
+
+ typedef THashMap<TGUID, TOutRequestState, TGUIDHash> TOutRequestHash;
+ typedef THashMap<TGUID, TInRequestState, TGUIDHash> TInRequestHash;
+
+ public:
+ TUdpHttp(const IEventsCollectorRef& eventsCollector)
+ : MyThread_(ExecServerThread, (void*)this)
+ , AbortTransactions_(false)
+ , Port_(0)
+ , EventCollector_(eventsCollector)
+ , ReportRequestCancel_(false)
+ , ReporRequestAck_(false)
+ , PhysicalCpu_(-1)
+ {
+ }
+
+ ~TUdpHttp() override {
+ if (MyThread_.Running()) {
+ AtomicSet(KeepRunning_, 0);
+ MyThread_.Join();
+ }
+ }
+
+ bool Start(int port, int physicalCpu) {
+ Y_ASSERT(Host_.Get() == nullptr);
+ Port_ = port;
+ PhysicalCpu_ = physicalCpu;
+ MyThread_.Start();
+ HasStarted_.Wait();
+ return Host_.Get() != nullptr;
+ }
+
+ void EnableReportRequestCancel() override {
+ ReportRequestCancel_ = true;
+ }
+
+ void EnableReportRequestAck() override {
+ ReporRequestAck_ = true;
+ }
+
+ void SendRequest(const TUdpAddress& addr, const TString& url, const TString& data, const TGUID& reqId) override {
+ Y_VERIFY(
+ data.size() < MAX_PACKET_SIZE,
+ "data size is too large; data.size()=%" PRISZT ", MAX_PACKET_SIZE=%" PRISZT,
+ data.size(), MAX_PACKET_SIZE);
+
+ TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+ if (data.size() > MIN_SHARED_MEM_PACKET && IsLocal(addr)) {
+ TIntrusivePtr<TSharedMemory> shm = new TSharedMemory;
+ if (shm->Create(data.size())) {
+ ms->Write((char)PKT_LOCAL_REQUEST);
+ ms->WriteStroka(url);
+ memcpy(shm->GetPtr(), data.begin(), data.size());
+ ms->AttachSharedData(shm);
+ }
+ }
+ if (ms->GetSharedData() == nullptr) {
+ ms->Write((char)PKT_REQUEST);
+ ms->WriteStroka(url);
+ struct TStrokaStorage: public TThrRefBase, public TString {
+ TStrokaStorage(const TString& s)
+ : TString(s)
+ {
+ }
+ };
+ TStrokaStorage* ss = new TStrokaStorage(data);
+ ms->Write((int)ss->size());
+ ms->AddBlock(ss, ss->begin(), ss->size());
+ }
+
+ SendReqList_.Enqueue(new TSendRequest(addr, &ms, reqId, EventCollector_));
+ Host_->CancelWait();
+ }
+
+ void CancelRequest(const TGUID& reqId) override {
+ CancelReqList_.Enqueue(reqId);
+ Host_->CancelWait();
+ }
+
+ void SendResponse(const TGUID& reqId, TVector<char>* data) override {
+ if (data && data->size() > MAX_PACKET_SIZE) {
+ Y_FAIL(
+ "data size is too large; data->size()=%" PRISZT ", MAX_PACKET_SIZE=%" PRISZT,
+ data->size(), MAX_PACKET_SIZE);
+ }
+ SendRespList_.Enqueue(new TSendResponse(reqId, PP_NORMAL, data));
+ Host_->CancelWait();
+ }
+
+ void StopNoWait() override {
+ AbortTransactions_ = true;
+ AtomicSet(KeepRunning_, 0);
+            // cancel all outgoing requests
+ TGuard<TSpinLock> lock(Spn_);
+ while (!OutRequests_.empty()) {
+ // cancel without informing peer that we are cancelling the request
+ FinishRequest(OutRequests_.begin(), TUdpHttpResponse::CANCELED, nullptr, "request canceled: inside TUdpHttp::StopNoWait()");
+ }
+ }
+
+ private:
+ void FinishRequest(TOutRequestHash::iterator i, TUdpHttpResponse::EResult ok, TRequestPtr data, const char* error = nullptr) {
+ TOutRequestState& s = i->second;
+ s.EventsCollector->AddResponse(new TUdpHttpResponse(data, i->first, s.Address, ok, error));
+ OutRequests_.erase(i);
+ }
+
+ int SendWithHighPriority(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket> data) {
+ ui32 crc32 = CalcChecksum(data->GetChain());
+ return Host_->Send(addr, data.Release(), crc32, nullptr, PP_HIGH);
+ }
+
+ void ProcessIncomingPackets() {
+ TVector<TGUID> failedRequests;
+ for (;;) {
+ TAutoPtr<TRequest> req = Host_->GetRequest();
+ if (req.Get() == nullptr)
+ break;
+
+ TBlockChainIterator reqData(req->Data->GetChain());
+ char pktType;
+ reqData.Read(&pktType, 1);
+ switch (pktType) {
+ case PKT_REQUEST:
+ case PKT_LOCAL_REQUEST: {
+ TGUID reqId = req->Guid;
+ TInRequestHash::iterator z = InRequests_.find(reqId);
+ if (z != InRequests_.end()) {
+ // oops, this request already exists!
+                        // this can happen if the request fits into a single packet whose source IP
+                        // was corrupted in transit yet still passed the crc checks;
+                        // since we have already reported the wrong source address for this request to the user,
+                        // the safest reaction would be to stop the program to avoid further complications,
+                        // but we limit ourselves to reporting the incident to stderr
+ fprintf(stderr, "Jackpot, same request %s received twice from %s and earlier from %s\n",
+ GetGuidAsString(reqId).c_str(), GetAddressAsString(z->second.Address).c_str(),
+ GetAddressAsString(req->Address).c_str());
+ } else {
+ InRequests_[reqId] = TInRequestState(req->Address);
+ EventCollector_->AddRequest(new TUdpHttpRequest(req, reqId, req->Address));
+ }
+ } break;
+ case PKT_PING: {
+ TGUID guid;
+ reqData.Read(&guid, sizeof(guid));
+ bool ok = InRequests_.find(guid) != InRequests_.end();
+ TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+ ms->Write((char)PKT_PING_RESPONSE);
+ ms->Write(guid);
+ ms->Write(ok);
+ SendWithHighPriority(req->Address, ms.Release());
+ } break;
+ case PKT_PING_RESPONSE: {
+ TGUID guid;
+ bool ok;
+ reqData.Read(&guid, sizeof(guid));
+ reqData.Read(&ok, sizeof(ok));
+ TOutRequestHash::iterator i = OutRequests_.find(guid);
+ if (i == OutRequests_.end())
+ ; //Y_ASSERT(0); // actually possible with some packet orders
+ else {
+ if (!ok) {
+                            // we cannot delete the request at this point,
+                            // because a failed ping and the response may arrive at the same moment:
+                            // consider the sequence where the client sends a ping while the server sends the response
+                            // and answers the ping with false because the reply has already been sent;
+                            // a failed ping_response cannot arrive earlier than the response itself,
+                            // but the two can arrive simultaneously
+ failedRequests.push_back(guid);
+ } else {
+ TOutRequestState& s = i->second;
+ switch (s.State) {
+ case TOutRequestState::S_WAITING_PING_SENDING: {
+ Y_ASSERT(s.PingTransferId >= 0);
+ TTransferHash::iterator k = TransferHash_.find(s.PingTransferId);
+ if (k != TransferHash_.end())
+ TransferHash_.erase(k);
+ else
+ Y_ASSERT(0);
+ s.PingTransferId = -1;
+ s.TimePassed = 0;
+ s.State = TOutRequestState::S_WAITING;
+ } break;
+ case TOutRequestState::S_WAITING_PING_SENT:
+ s.TimePassed = 0;
+ s.State = TOutRequestState::S_WAITING;
+ break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ }
+ }
+ } break;
+ case PKT_RESPONSE:
+ case PKT_LOCAL_RESPONSE: {
+ TGUID guid;
+ reqData.Read(&guid, sizeof(guid));
+ TOutRequestHash::iterator i = OutRequests_.find(guid);
+ if (i == OutRequests_.end()) {
+ ; //Y_ASSERT(0); // does happen
+ } else {
+ FinishRequest(i, TUdpHttpResponse::OK, req);
+ }
+ } break;
+ case PKT_CANCEL: {
+ TGUID guid;
+ reqData.Read(&guid, sizeof(guid));
+ TInRequestHash::iterator i = InRequests_.find(guid);
+ if (i == InRequests_.end()) {
+ ; //Y_ASSERT(0); // may happen
+ } else {
+ TInRequestState& s = i->second;
+ if (s.State != TInRequestState::S_CANCELED && ReportRequestCancel_)
+ EventCollector_->AddCancel(guid);
+ s.State = TInRequestState::S_CANCELED;
+ }
+ } break;
+ default:
+ Y_ASSERT(0);
+ }
+ }
+ // cleanup failed requests
+ for (size_t k = 0; k < failedRequests.size(); ++k) {
+ const TGUID& guid = failedRequests[k];
+ TOutRequestHash::iterator i = OutRequests_.find(guid);
+ if (i != OutRequests_.end())
+ FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "failed udp ping");
+ }
+ }
+
+ void AnalyzeSendResults() {
+ TSendResult res;
+ while (Host_->GetSendResult(&res)) {
+ TTransferHash::iterator k = TransferHash_.find(res.TransferId);
+ if (k != TransferHash_.end()) {
+ const TTransferPurpose& tp = k->second;
+ switch (tp.Dir) {
+ case DIR_OUT: {
+ TOutRequestHash::iterator i = OutRequests_.find(tp.Guid);
+ if (i != OutRequests_.end()) {
+ const TGUID& reqId = i->first;
+ TOutRequestState& s = i->second;
+ switch (s.State) {
+ case TOutRequestState::S_SENDING:
+ if (!res.Success) {
+ FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: state S_SENDING");
+ } else {
+ if (ReporRequestAck_ && !!s.EventsCollector) {
+ s.EventsCollector->AddRequestAck(reqId);
+ }
+ s.State = TOutRequestState::S_WAITING;
+ s.TimePassed = 0;
+ }
+ break;
+ case TOutRequestState::S_CANCEL_AFTER_SENDING:
+ DoSendCancel(s.Address, reqId);
+ FinishRequest(i, TUdpHttpResponse::CANCELED, nullptr, "request failed: state S_CANCEL_AFTER_SENDING");
+ break;
+ case TOutRequestState::S_WAITING:
+ case TOutRequestState::S_WAITING_PING_SENT:
+ Y_ASSERT(0);
+ break;
+ case TOutRequestState::S_WAITING_PING_SENDING:
+ Y_ASSERT(s.PingTransferId >= 0 && s.PingTransferId == res.TransferId);
+ if (!res.Success) {
+ FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: state S_WAITING_PING_SENDING");
+ } else {
+ s.PingTransferId = -1;
+ s.State = TOutRequestState::S_WAITING_PING_SENT;
+ s.TimePassed = 0;
+ }
+ break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ }
+ } break;
+ case DIR_IN: {
+ TInRequestHash::iterator i = InRequests_.find(tp.Guid);
+ if (i != InRequests_.end()) {
+ Y_ASSERT(i->second.State == TInRequestState::S_RESPONSE_SENDING || i->second.State == TInRequestState::S_CANCELED);
+ InRequests_.erase(i);
+ }
+ } break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ TransferHash_.erase(k);
+ }
+ }
+ }
+
+ void SendPingsIfNeeded() {
+ NHPTimer::STime tChk = PingsSendT_;
+ float deltaT = (float)NHPTimer::GetTimePassed(&tChk);
+ if (deltaT < 0.05) {
+ return;
+ }
+ PingsSendT_ = tChk;
+ deltaT = ClampVal(deltaT, 0.0f, HTTP_TIMEOUT / 3);
+
+ {
+ for (TOutRequestHash::iterator i = OutRequests_.begin(); i != OutRequests_.end();) {
+ TOutRequestHash::iterator curIt = i++;
+ TOutRequestState& s = curIt->second;
+ const TGUID& guid = curIt->first;
+ switch (s.State) {
+ case TOutRequestState::S_WAITING:
+ s.TimePassed += deltaT;
+ if (s.TimePassed > HTTP_TIMEOUT) {
+ TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+ ms->Write((char)PKT_PING);
+ ms->Write(guid);
+ int transId = SendWithHighPriority(s.Address, ms.Release());
+ TransferHash_[transId] = TTransferPurpose(DIR_OUT, guid);
+ s.State = TOutRequestState::S_WAITING_PING_SENDING;
+ s.PingTransferId = transId;
+ }
+ break;
+ case TOutRequestState::S_WAITING_PING_SENT:
+ s.TimePassed += deltaT;
+ if (s.TimePassed > HTTP_TIMEOUT) {
+ FinishRequest(curIt, TUdpHttpResponse::FAILED, nullptr, "request failed: http timeout in state S_WAITING_PING_SENT");
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ }
+
+ void Step() {
+ {
+ TGuard<TSpinLock> lock(Spn_);
+ DoSends();
+ }
+ Host_->Step();
+ {
+ TGuard<TSpinLock> lock(Spn_);
+ DoSends();
+ ProcessIncomingPackets();
+ AnalyzeSendResults();
+ SendPingsIfNeeded();
+ }
+ }
+
+ void Wait() {
+ Host_->Wait(0.1f);
+ }
+
+ void DoSendCancel(const TUdpAddress& addr, const TGUID& req) {
+ TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+ ms->Write((char)PKT_CANCEL);
+ ms->Write(req);
+ SendWithHighPriority(addr, ms);
+ }
+
+ void DoSends() {
+ {
+ // cancelling requests
+ TGUID reqGuid;
+ while (CancelReqList_.Dequeue(&reqGuid)) {
+ TOutRequestHash::iterator i = OutRequests_.find(reqGuid);
+ if (i == OutRequests_.end()) {
+ AnticipateCancels_.insert(reqGuid);
+                        continue; // cancelling a request that is not (yet) known is ok
+ }
+ TOutRequestState& s = i->second;
+ if (s.State == TOutRequestState::S_SENDING) {
+                        // tricky case: the request has not been sent yet but already has to be cancelled; wait for the send to complete first
+ s.State = TOutRequestState::S_CANCEL_AFTER_SENDING;
+ s.EventsCollector->AddCancel(i->first);
+ } else {
+ DoSendCancel(s.Address, reqGuid);
+ FinishRequest(i, TUdpHttpResponse::CANCELED, nullptr, "request canceled: notify requested side");
+ }
+ }
+ }
+ {
+ // sending replies
+ for (TSendResponse* rd = nullptr; SendRespList_.Dequeue(&rd); delete rd) {
+ TInRequestHash::iterator i = InRequests_.find(rd->ReqGuid);
+ if (i == InRequests_.end()) {
+ Y_ASSERT(0);
+ continue;
+ }
+ TInRequestState& s = i->second;
+ if (s.State == TInRequestState::S_CANCELED) {
+                        // no need to send a response for a canceled request
+ InRequests_.erase(i);
+ continue;
+ }
+
+ Y_ASSERT(s.State == TInRequestState::S_WAITING);
+ s.State = TInRequestState::S_RESPONSE_SENDING;
+
+ TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+ ui32 crc32 = 0;
+ int dataSize = rd->Data.ysize();
+ if (rd->Data.size() > MIN_SHARED_MEM_PACKET && IsLocal(s.Address)) {
+ TIntrusivePtr<TSharedMemory> shm = new TSharedMemory;
+ if (shm->Create(dataSize)) {
+ ms->Write((char)PKT_LOCAL_RESPONSE);
+ ms->Write(rd->ReqGuid);
+ memcpy(shm->GetPtr(), &rd->Data[0], dataSize);
+ TVector<char> empty;
+ rd->Data.swap(empty);
+ ms->AttachSharedData(shm);
+ crc32 = CalcChecksum(ms->GetChain());
+ }
+ }
+ if (ms->GetSharedData() == nullptr) {
+ ms->Write((char)PKT_RESPONSE);
+ ms->Write(rd->ReqGuid);
+
+                        // to offload crc calculation from the inner thread, the crc of data[] is calculated outside and passed in DataCrc32;
+                        // the downside is that it is also computed when shared memory ends up being used,
+                        // which is hard to avoid since SendResponse() does not know whether shared memory will be used
+                        // (the peer address is not available there)
+ TIncrementalChecksumCalcer csCalcer;
+ AddChain(&csCalcer, ms->GetChain());
+ // here we are replicating the way WriteDestructive serializes data
+ csCalcer.AddBlock(&dataSize, sizeof(dataSize));
+ csCalcer.AddBlockSum(rd->DataCrc32, dataSize);
+ crc32 = csCalcer.CalcChecksum();
+
+ ms->WriteDestructive(&rd->Data);
+                        //ui32 chkCrc = CalcChecksum(ms->GetChain()); // cannot be used here since it is slow for large responses
+ //Y_ASSERT(chkCrc == crc32);
+ }
+
+ int transId = Host_->Send(s.Address, ms.Release(), crc32, nullptr, rd->Priority);
+ TransferHash_[transId] = TTransferPurpose(DIR_IN, rd->ReqGuid);
+ }
+ }
+ {
+ // sending requests
+ for (TSendRequest* rd = nullptr; SendReqList_.Dequeue(&rd); delete rd) {
+ Y_ASSERT(OutRequests_.find(rd->ReqGuid) == OutRequests_.end());
+
+ {
+ TOutRequestState& s = OutRequests_[rd->ReqGuid];
+ s.State = TOutRequestState::S_SENDING;
+ s.Address = rd->Addr;
+ s.EventsCollector = rd->EventsCollector;
+ }
+
+ if (AnticipateCancels_.find(rd->ReqGuid) != AnticipateCancels_.end()) {
+ FinishRequest(OutRequests_.find(rd->ReqGuid), TUdpHttpResponse::CANCELED, nullptr, "Canceled (before transmit)");
+ } else {
+ TGUID pktGuid = rd->ReqGuid; // request packet id should match request id
+ int transId = Host_->Send(rd->Addr, rd->Data.Release(), rd->Crc32, &pktGuid, PP_NORMAL);
+ TransferHash_[transId] = TTransferPurpose(DIR_OUT, rd->ReqGuid);
+ }
+ }
+ }
+ if (!AnticipateCancels_.empty()) {
+ AnticipateCancels_.clear();
+ }
+ }
+
+ void FinishOutstandingTransactions() {
+            // wait for all pending requests; any new requests are canceled
+ while ((!OutRequests_.empty() || !InRequests_.empty() || !SendRespList_.IsEmpty() || !SendReqList_.IsEmpty()) && !PanicAttack) {
+ Step();
+ sleep(0);
+ }
+ }
+
+ static void* ExecServerThread(void* param) {
+ TUdpHttp* pThis = (TUdpHttp*)param;
+ if (pThis->GetPhysicalCpu() >= 0) {
+ BindToSocket(pThis->GetPhysicalCpu());
+ }
+ SetHighestThreadPriority();
+
+ TIntrusivePtr<NNetlibaSocket::ISocket> socket = NNetlibaSocket::CreateSocket();
+ socket->Open(pThis->Port_);
+ if (socket->IsValid()) {
+ pThis->Port_ = socket->GetPort();
+ pThis->Host_ = CreateUdpHost(socket);
+ } else {
+ pThis->Host_ = nullptr;
+ }
+
+ pThis->HasStarted_.Signal();
+ if (!pThis->Host_)
+ return nullptr;
+
+ NHPTimer::GetTime(&pThis->PingsSendT_);
+ while (AtomicGet(pThis->KeepRunning_) && !PanicAttack) {
+ if (HeartbeatTimeout.load(std::memory_order_acquire) > 0) {
+ NHPTimer::STime chk = LastHeartbeat.load(std::memory_order_acquire);
+ if (NHPTimer::GetTimePassed(&chk) > HeartbeatTimeout.load(std::memory_order_acquire)) {
+ StopAllNetLibaThreads();
+#ifndef _win_
+ killpg(0, SIGKILL);
+#endif
+ abort();
+ break;
+ }
+ }
+ pThis->Step();
+ pThis->Wait();
+ }
+ if (!pThis->AbortTransactions_ && !PanicAttack) {
+ pThis->FinishOutstandingTransactions();
+ }
+ pThis->Host_ = nullptr;
+ return nullptr;
+ }
+
+ int GetPhysicalCpu() const noexcept {
+ return PhysicalCpu_;
+ }
+
+ private:
+ TThread MyThread_;
+ TAtomic KeepRunning_ = 1;
+ bool AbortTransactions_;
+ TSpinLock Spn_;
+ TSystemEvent HasStarted_;
+
+ NHPTimer::STime PingsSendT_;
+
+ TIntrusivePtr<IUdpHost> Host_;
+ int Port_;
+ TOutRequestHash OutRequests_;
+ TInRequestHash InRequests_;
+
+ typedef THashMap<int, TTransferPurpose> TTransferHash;
+ TTransferHash TransferHash_;
+
+        // kept here to avoid constructing it on every DoSends() call
+ typedef THashSet<TGUID, TGUIDHash> TAnticipateCancels;
+ TAnticipateCancels AnticipateCancels_;
+
+ TLockFreeQueue<TSendRequest*> SendReqList_;
+ TLockFreeQueue<TSendResponse*> SendRespList_;
+ TLockFreeQueue<TGUID> CancelReqList_;
+
+ TIntrusivePtr<IEventsCollector> EventCollector_;
+
+ bool ReportRequestCancel_;
+ bool ReporRequestAck_;
+ int PhysicalCpu_;
+ };
+
+ IRequesterRef CreateHttpUdpRequester(int port, const IEventsCollectorRef& ec, int physicalCpu) {
+ TUdpHttp* udpHttp = new TUdpHttp(ec);
+ IRequesterRef res(udpHttp);
+ if (!udpHttp->Start(port, physicalCpu)) {
+ if (port) {
+ ythrow yexception() << "netliba can't bind port=" << port;
+ } else {
+ ythrow yexception() << "netliba can't bind random port";
+ }
+ }
+ return res;
+ }
+}
diff --git a/library/cpp/neh/netliba_udp_http.h b/library/cpp/neh/netliba_udp_http.h
new file mode 100644
index 0000000000..324366c10c
--- /dev/null
+++ b/library/cpp/neh/netliba_udp_http.h
@@ -0,0 +1,79 @@
+#pragma once
+
+#include <library/cpp/netliba/v6/net_queue_stat.h>
+#include <library/cpp/netliba/v6/udp_address.h>
+#include <library/cpp/netliba/v6/udp_debug.h>
+
+#include <util/generic/guid.h>
+#include <util/generic/ptr.h>
+#include <util/network/init.h>
+#include <util/system/event.h>
+
+namespace NNetliba {
+ struct TRequest;
+}
+
+namespace NNehNetliba {
+ using namespace NNetliba;
+
+ typedef TAutoPtr<TRequest> TRequestPtr;
+
+ class TUdpHttpMessage {
+ public:
+ TUdpHttpMessage(const TGUID& reqId, const TUdpAddress& peerAddr);
+
+ TGUID ReqId;
+ TUdpAddress PeerAddress;
+ };
+
+ class TUdpHttpRequest: public TUdpHttpMessage {
+ public:
+ TUdpHttpRequest(TRequestPtr& dataHolder, const TGUID& reqId, const TUdpAddress& peerAddr);
+
+ TString Url;
+ TVector<char> Data;
+ };
+
+ class TUdpHttpResponse: public TUdpHttpMessage {
+ public:
+ enum EResult {
+ FAILED = 0,
+ OK = 1,
+ CANCELED = 2
+ };
+
+ TUdpHttpResponse(TRequestPtr& dataHolder, const TGUID& reqId, const TUdpAddress& peerAddr, EResult result, const char* error);
+
+ EResult Ok;
+ TString Data;
+ TString Error;
+ };
+
+ class IRequester: public TThrRefBase {
+ public:
+ virtual void EnableReportRequestCancel() = 0;
+ virtual void EnableReportRequestAck() = 0;
+
+        // the TVector<char>* data passed to SendResponse() is emptied by the call (its contents are swapped out)
+ virtual void SendRequest(const TUdpAddress&, const TString& url, const TString& data, const TGUID&) = 0;
+ virtual void CancelRequest(const TGUID&) = 0;
+ virtual void SendResponse(const TGUID&, TVector<char>* data) = 0;
+
+ virtual void StopNoWait() = 0;
+ };
+
+ class IEventsCollector: public TThrRefBase {
+ public:
+        // ownership of the request/response objects is transferred to the events collector
+ virtual void AddRequest(TUdpHttpRequest*) = 0;
+ virtual void AddResponse(TUdpHttpResponse*) = 0;
+ virtual void AddCancel(const TGUID&) = 0;
+ virtual void AddRequestAck(const TGUID&) = 0;
+ };
+
+ typedef TIntrusivePtr<IEventsCollector> IEventsCollectorRef;
+ typedef TIntrusivePtr<IRequester> IRequesterRef;
+
+    // throws an exception if the port cannot be bound
+ IRequesterRef CreateHttpUdpRequester(int port, const IEventsCollectorRef&, int physicalCpu = -1);
+}
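
A client-side sketch of the interface above. The events collector takes ownership of the objects it receives; the server address is taken as a parameter, since resolving it (the neh transport does so via DNS and GetUdpAddress) is out of scope here, and the service name and payload are illustrative:

    #include <library/cpp/neh/netliba_udp_http.h>

    #include <util/generic/guid.h>
    #include <util/stream/output.h>

    using namespace NNehNetliba;

    class TPrintingCollector: public IEventsCollector {
    public:
        void AddRequest(TUdpHttpRequest* req) override {
            delete req; // a pure client ignores incoming requests
        }

        void AddResponse(TUdpHttpResponse* resp) override {
            Cout << "response ok=" << (int)resp->Ok << " bytes=" << resp->Data.size() << Endl;
            delete resp;
        }

        void AddCancel(const TGUID&) override {
        }

        void AddRequestAck(const TGUID&) override {
        }
    };

    void SendOnePing(const TUdpAddress& serverAddr) {
        IEventsCollectorRef collector(new TPrintingCollector());
        IRequesterRef requester = CreateHttpUdpRequester(0, collector); // 0 = pick a free port

        TGUID reqId;
        CreateGuid(&reqId);
        requester->SendRequest(serverAddr, "ping", "hello", reqId);

        // ... wait for TPrintingCollector::AddResponse, then:
        requester->StopNoWait();
    }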
diff --git a/library/cpp/neh/pipequeue.cpp b/library/cpp/neh/pipequeue.cpp
new file mode 100644
index 0000000000..551e584228
--- /dev/null
+++ b/library/cpp/neh/pipequeue.cpp
@@ -0,0 +1 @@
+#include "pipequeue.h"
diff --git a/library/cpp/neh/pipequeue.h b/library/cpp/neh/pipequeue.h
new file mode 100644
index 0000000000..bed8d44bd2
--- /dev/null
+++ b/library/cpp/neh/pipequeue.h
@@ -0,0 +1,207 @@
+#pragma once
+
+#include "lfqueue.h"
+
+#include <library/cpp/coroutine/engine/impl.h>
+#include <library/cpp/coroutine/engine/network.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+#include <util/system/pipe.h>
+
+#ifdef _linux_
+#include <sys/eventfd.h>
+#endif
+
+#if defined(_bionic_) && !defined(EFD_SEMAPHORE)
+#define EFD_SEMAPHORE 1
+#endif
+
+namespace NNeh {
+#ifdef _linux_
+ class TSemaphoreEventFd {
+ public:
+ inline TSemaphoreEventFd() {
+ F_ = eventfd(0, EFD_NONBLOCK | EFD_SEMAPHORE);
+ if (F_ < 0) {
+                ythrow TFileError() << "failed to create an eventfd";
+ }
+ }
+
+ inline ~TSemaphoreEventFd() {
+ close(F_);
+ }
+
+ inline size_t Acquire(TCont* c) {
+ ui64 ev;
+ return NCoro::ReadI(c, F_, &ev, sizeof ev).Processed();
+ }
+
+ inline void Release() {
+ const static ui64 ev(1);
+ (void)write(F_, &ev, sizeof ev);
+ }
+
+ private:
+ int F_;
+ };
+#endif
+
+ class TSemaphorePipe {
+ public:
+ inline TSemaphorePipe() {
+ TPipeHandle::Pipe(S_[0], S_[1]);
+
+ SetNonBlock(S_[0]);
+ SetNonBlock(S_[1]);
+ }
+
+ inline size_t Acquire(TCont* c) {
+ char ch;
+ return NCoro::ReadI(c, S_[0], &ch, 1).Processed();
+ }
+
+ inline size_t Acquire(TCont* c, char* buff, size_t buflen) {
+ return NCoro::ReadI(c, S_[0], buff, buflen).Processed();
+ }
+
+ inline void Release() {
+ char ch = 13;
+ S_[1].Write(&ch, 1);
+ }
+
+ private:
+ TPipeHandle S_[2];
+ };
+
+ class TPipeQueueBase {
+ public:
+ inline void Enqueue(void* job) {
+ Q_.Enqueue(job);
+ S_.Release();
+ }
+
+ inline void* Dequeue(TCont* c, char* ch, size_t buflen) {
+ void* ret = nullptr;
+
+ while (!Q_.Dequeue(&ret) && S_.Acquire(c, ch, buflen)) {
+ }
+
+ return ret;
+ }
+
+ inline void* Dequeue() noexcept {
+ void* ret = nullptr;
+
+ Q_.Dequeue(&ret);
+
+ return ret;
+ }
+
+ private:
+ TLockFreeQueue<void*> Q_;
+ TSemaphorePipe S_;
+ };
+
+ template <class T, size_t buflen = 1>
+ class TPipeQueue {
+ public:
+ template <class TPtr>
+ inline void EnqueueSafe(TPtr req) {
+ Enqueue(req.Get());
+ req.Release();
+ }
+
+ inline void Enqueue(T* req) {
+ Q_.Enqueue(req);
+ }
+
+ template <class TPtr>
+ inline void DequeueSafe(TCont* c, TPtr& ret) {
+ ret.Reset(Dequeue(c));
+ }
+
+ inline T* Dequeue(TCont* c) {
+ char ch[buflen];
+
+ return (T*)Q_.Dequeue(c, ch, sizeof(ch));
+ }
+
+ protected:
+ TPipeQueueBase Q_;
+ };
+
+    //optimized to avoid unnecessary semaphore usage; uses eventfd on linux
+ template <class T>
+ struct TOneConsumerPipeQueue {
+ inline TOneConsumerPipeQueue()
+ : Signaled_(0)
+ , SkipWait_(0)
+ {
+ }
+
+ inline void Enqueue(T* job) {
+ Q_.Enqueue(job);
+
+ AtomicSet(SkipWait_, 1);
+ if (AtomicCas(&Signaled_, 1, 0)) {
+ S_.Release();
+ }
+ }
+
+ inline T* Dequeue(TCont* c) {
+ T* ret = nullptr;
+
+ while (!Q_.Dequeue(&ret)) {
+ AtomicSet(Signaled_, 0);
+ if (!AtomicCas(&SkipWait_, 0, 1)) {
+ if (!S_.Acquire(c)) {
+ break;
+ }
+ }
+ AtomicSet(Signaled_, 1);
+ }
+
+ return ret;
+ }
+
+ template <class TPtr>
+ inline void EnqueueSafe(TPtr req) {
+ Enqueue(req.Get());
+ Y_UNUSED(req.Release());
+ }
+
+ template <class TPtr>
+ inline void DequeueSafe(TCont* c, TPtr& ret) {
+ ret.Reset(Dequeue(c));
+ }
+
+ protected:
+ TLockFreeQueue<T*> Q_;
+#ifdef _linux_
+ TSemaphoreEventFd S_;
+#else
+ TSemaphorePipe S_;
+#endif
+ TAtomic Signaled_;
+ TAtomic SkipWait_;
+ };
+
+ template <class T, size_t buflen = 1>
+ struct TAutoPipeQueue: public TPipeQueue<T, buflen> {
+ ~TAutoPipeQueue() {
+ while (T* t = (T*)TPipeQueue<T, buflen>::Q_.Dequeue()) {
+ delete t;
+ }
+ }
+ };
+
+ template <class T>
+ struct TAutoOneConsumerPipeQueue: public TOneConsumerPipeQueue<T> {
+ ~TAutoOneConsumerPipeQueue() {
+ T* ret = nullptr;
+
+ while (TOneConsumerPipeQueue<T>::Q_.Dequeue(&ret)) {
+ delete ret;
+ }
+ }
+ };
+}
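
For reference, a standalone illustration of the eventfd-as-semaphore pattern that TSemaphoreEventFd relies on (Linux only, error handling omitted); with EFD_SEMAPHORE each read consumes exactly one unit:

    #include <sys/eventfd.h>
    #include <unistd.h>

    #include <cstdint>
    #include <cstdio>

    int main() {
        int fd = eventfd(0, EFD_SEMAPHORE); // blocking flavour for the demo

        uint64_t one = 1;
        (void)write(fd, &one, sizeof one); // Release(): adds 1 to the internal counter
        (void)write(fd, &one, sizeof one); // counter is now 2

        uint64_t got = 0;
        (void)read(fd, &got, sizeof got); // Acquire(): returns 1 and decrements the counter to 1
        std::printf("read %llu, one unit still pending\n", (unsigned long long)got);

        close(fd);
        return 0;
    }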
diff --git a/library/cpp/neh/rpc.cpp b/library/cpp/neh/rpc.cpp
new file mode 100644
index 0000000000..13813bea1f
--- /dev/null
+++ b/library/cpp/neh/rpc.cpp
@@ -0,0 +1,322 @@
+#include "rpc.h"
+#include "rq.h"
+#include "multi.h"
+#include "location.h"
+
+#include <library/cpp/threading/thread_local/thread_local.h>
+
+#include <util/generic/hash.h>
+#include <util/thread/factory.h>
+#include <util/system/spinlock.h>
+
+using namespace NNeh;
+
+namespace {
+ typedef std::pair<TString, IServiceRef> TServiceDescr;
+ typedef TVector<TServiceDescr> TServicesBase;
+
+ class TServices: public TServicesBase, public TThrRefBase, public IOnRequest {
+ typedef THashMap<TStringBuf, IServiceRef> TSrvs;
+
+ struct TVersionedServiceMap {
+ TSrvs Srvs;
+ i64 Version = 0;
+ };
+
+
+ struct TFunc: public IThreadFactory::IThreadAble {
+ inline TFunc(TServices* parent)
+ : Parent(parent)
+ {
+ }
+
+ void DoExecute() override {
+ TThread::SetCurrentThreadName("NehTFunc");
+ TVersionedServiceMap mp;
+ while (true) {
+ IRequestRef req = Parent->RQ_->Next();
+
+ if (!req) {
+ break;
+ }
+
+ Parent->ServeRequest(mp, req);
+ }
+
+ Parent->RQ_->Schedule(nullptr);
+ }
+
+ TServices* Parent;
+ };
+
+ public:
+ inline TServices()
+ : RQ_(CreateRequestQueue())
+ {
+ }
+
+ inline TServices(TCheck check)
+ : RQ_(CreateRequestQueue())
+ , C_(check)
+ {
+ }
+
+ inline ~TServices() override {
+ LF_.Destroy();
+ }
+
+ inline void Add(const TString& service, IServiceRef srv) {
+ TGuard<TSpinLock> guard(L_);
+
+ push_back(std::make_pair(service, srv));
+ AtomicIncrement(SelfVersion_);
+ }
+
+ inline void Listen() {
+ Y_ENSURE(!HasLoop_ || !*HasLoop_);
+ HasLoop_ = false;
+ RR_ = MultiRequester(ListenAddrs(), this);
+ }
+
+ inline void Loop(size_t threads) {
+ Y_ENSURE(!HasLoop_ || *HasLoop_);
+ HasLoop_ = true;
+
+ TIntrusivePtr<TServices> self(this);
+ IRequesterRef rr = MultiRequester(ListenAddrs(), this);
+ TFunc func(this);
+
+ typedef TAutoPtr<IThreadFactory::IThread> IThreadRef;
+ TVector<IThreadRef> thrs;
+
+ for (size_t i = 1; i < threads; ++i) {
+ thrs.push_back(SystemThreadFactory()->Run(&func));
+ }
+
+ func.Execute();
+
+ for (size_t i = 0; i < thrs.size(); ++i) {
+ thrs[i]->Join();
+ }
+ RQ_->Clear();
+ }
+
+ inline void ForkLoop(size_t threads) {
+ Y_ENSURE(!HasLoop_ || *HasLoop_);
+ HasLoop_ = true;
+ //here we can have trouble with binding port(s), so expect exceptions
+ IRequesterRef rr = MultiRequester(ListenAddrs(), this);
+ LF_.Reset(new TLoopFunc(this, threads, rr));
+ }
+
+ inline void Stop() {
+ RQ_->Schedule(nullptr);
+ }
+
+ inline void SyncStopFork() {
+ Stop();
+ if (LF_) {
+ LF_->SyncStop();
+ }
+ RQ_->Clear();
+ LF_.Destroy();
+ }
+
+ void OnRequest(IRequestRef req) override {
+ if (C_) {
+ if (auto error = C_(req)) {
+ req->SendError(*error);
+ return;
+ }
+ }
+ if (!*HasLoop_) {
+ ServeRequest(LocalMap_.GetRef(), req);
+ } else {
+ RQ_->Schedule(req);
+ }
+ }
+
+ private:
+ class TLoopFunc: public TFunc {
+ public:
+ TLoopFunc(TServices* parent, size_t threads, IRequesterRef& rr)
+ : TFunc(parent)
+ , RR_(rr)
+ {
+ T_.reserve(threads);
+
+ try {
+ for (size_t i = 0; i < threads; ++i) {
+ T_.push_back(SystemThreadFactory()->Run(this));
+ }
+ } catch (...) {
+ //paranoid mode on
+ SyncStop();
+ throw;
+ }
+ }
+
+ ~TLoopFunc() override {
+ try {
+ SyncStop();
+ } catch (...) {
+ Cdbg << TStringBuf("neh rpc ~loop_func: ") << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ void SyncStop() {
+ if (!T_) {
+ return;
+ }
+
+ Parent->Stop();
+
+ for (size_t i = 0; i < T_.size(); ++i) {
+ T_[i]->Join();
+ }
+ T_.clear();
+ }
+
+ private:
+ typedef TAutoPtr<IThreadFactory::IThread> IThreadRef;
+ TVector<IThreadRef> T_;
+ IRequesterRef RR_;
+ };
+
+ inline void ServeRequest(TVersionedServiceMap& mp, IRequestRef req) {
+ if (!req) {
+ return;
+ }
+
+ const TStringBuf name = req->Service();
+ TSrvs::const_iterator it = mp.Srvs.find(name);
+
+ if (Y_UNLIKELY(it == mp.Srvs.end())) {
+ if (UpdateServices(mp.Srvs, mp.Version)) {
+ it = mp.Srvs.find(name);
+ }
+ }
+
+ if (Y_UNLIKELY(it == mp.Srvs.end())) {
+ it = mp.Srvs.find(TStringBuf("*"));
+ }
+
+ if (Y_UNLIKELY(it == mp.Srvs.end())) {
+ req->SendError(IRequest::NotExistService);
+ } else {
+ try {
+ it->second->ServeRequest(req);
+ } catch (...) {
+ Cdbg << CurrentExceptionMessage() << Endl;
+ }
+ }
+ }
+
+ inline bool UpdateServices(TSrvs& srvs, i64& version) const {
+ if (AtomicGet(SelfVersion_) == version) {
+ return false;
+ }
+
+ srvs.clear();
+
+ TGuard<TSpinLock> guard(L_);
+
+ for (const auto& it : *this) {
+ srvs[TParsedLocation(it.first).Service] = it.second;
+ }
+ version = AtomicGet(SelfVersion_);
+
+ return true;
+ }
+
+ inline TListenAddrs ListenAddrs() const {
+ TListenAddrs addrs;
+
+ {
+ TGuard<TSpinLock> guard(L_);
+
+ for (const auto& it : *this) {
+ addrs.push_back(it.first);
+ }
+ }
+
+ return addrs;
+ }
+
+ TSpinLock L_;
+ IRequestQueueRef RQ_;
+ THolder<TLoopFunc> LF_;
+ TAtomic SelfVersion_ = 1;
+ TCheck C_;
+
+ NThreading::TThreadLocalValue<TVersionedServiceMap> LocalMap_;
+
+ IRequesterRef RR_;
+ TMaybe<bool> HasLoop_;
+ };
+
+ class TServicesFace: public IServices {
+ public:
+ inline TServicesFace()
+ : S_(new TServices())
+ {
+ }
+
+ inline TServicesFace(TCheck check)
+ : S_(new TServices(check))
+ {
+ }
+
+ void DoAdd(const TString& service, IServiceRef srv) override {
+ S_->Add(service, srv);
+ }
+
+ void Loop(size_t threads) override {
+ S_->Loop(threads);
+ }
+
+ void ForkLoop(size_t threads) override {
+ S_->ForkLoop(threads);
+ }
+
+ void SyncStopFork() override {
+ S_->SyncStopFork();
+ }
+
+ void Stop() override {
+ S_->Stop();
+ }
+
+ void Listen() override {
+ S_->Listen();
+ }
+
+ private:
+ TIntrusivePtr<TServices> S_;
+ };
+}
+
+IServiceRef NNeh::Wrap(const TServiceFunction& func) {
+ struct TWrapper: public IService {
+ inline TWrapper(const TServiceFunction& f)
+ : F(f)
+ {
+ }
+
+ void ServeRequest(const IRequestRef& request) override {
+ F(request);
+ }
+
+ TServiceFunction F;
+ };
+
+ return new TWrapper(func);
+}
+
+IServicesRef NNeh::CreateLoop() {
+ return new TServicesFace();
+}
+
+IServicesRef NNeh::CreateLoop(TCheck check) {
+ return new TServicesFace(check);
+}
diff --git a/library/cpp/neh/rpc.h b/library/cpp/neh/rpc.h
new file mode 100644
index 0000000000..482ff7ce53
--- /dev/null
+++ b/library/cpp/neh/rpc.h
@@ -0,0 +1,155 @@
+#pragma once
+
+#include <util/generic/vector.h>
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <util/generic/strbuf.h>
+#include <util/generic/maybe.h>
+#include <util/stream/output.h>
+#include <util/datetime/base.h>
+#include <functional>
+
+namespace NNeh {
+ using TData = TVector<char>;
+
+ class TDataSaver: public TData, public IOutputStream {
+ public:
+ TDataSaver() = default;
+ ~TDataSaver() override = default;
+ TDataSaver(TDataSaver&&) noexcept = default;
+ TDataSaver& operator=(TDataSaver&&) noexcept = default;
+
+ void DoWrite(const void* buf, size_t len) override {
+ insert(end(), (const char*)buf, (const char*)buf + len);
+ }
+ };
+
+ class IRequest {
+ public:
+ IRequest()
+ : ArrivalTime_(TInstant::Now())
+ {
+ }
+
+ virtual ~IRequest() = default;
+
+ virtual TStringBuf Scheme() const = 0;
+ virtual TString RemoteHost() const = 0; //IP-literal / IPv4address / reg-name()
+ virtual TStringBuf Service() const = 0;
+ virtual TStringBuf Data() const = 0;
+ virtual TStringBuf RequestId() const = 0;
+ virtual bool Canceled() const = 0;
+ virtual void SendReply(TData& data) = 0;
+ enum TResponseError {
+ BadRequest, // bad request data - http_code 400
+ Forbidden, // forbidden request - http_code 403
+ NotExistService, // not found request handler - http_code 404
+ TooManyRequests, // too many requests for the handler - http_code 429
+            InternalError, // something went wrong - http_code 500
+ NotImplemented, // not implemented - http_code 501
+ BadGateway, // remote backend not available - http_code 502
+ ServiceUnavailable, // overload - http_code 503
+ BandwidthLimitExceeded, // 5xx version of 429
+ MaxResponseError // count error types
+ };
+ virtual void SendError(TResponseError err, const TString& details = TString()) = 0;
+ virtual TInstant ArrivalTime() const {
+ return ArrivalTime_;
+ }
+
+ private:
+ TInstant ArrivalTime_;
+ };
+
+ using IRequestRef = TAutoPtr<IRequest>;
+
+ struct IOnRequest {
+ virtual void OnRequest(IRequestRef req) = 0;
+ };
+
+ class TRequestOut: public TDataSaver {
+ public:
+ inline TRequestOut(IRequest* req)
+ : Req_(req)
+ {
+ }
+
+ ~TRequestOut() override {
+ try {
+ Finish();
+ } catch (...) {
+ }
+ }
+
+ void DoFinish() override {
+ if (Req_) {
+ Req_->SendReply(*this);
+ Req_ = nullptr;
+ }
+ }
+
+ private:
+ IRequest* Req_;
+ };
+
+ struct IRequester {
+ virtual ~IRequester() = default;
+ };
+
+ using IRequesterRef = TAtomicSharedPtr<IRequester>;
+
+ struct IService: public TThrRefBase {
+ virtual void ServeRequest(const IRequestRef& request) = 0;
+ };
+
+ using IServiceRef = TIntrusivePtr<IService>;
+ using TServiceFunction = std::function<void(const IRequestRef&)>;
+
+ IServiceRef Wrap(const TServiceFunction& func);
+
+ class IServices {
+ public:
+ virtual ~IServices() = default;
+
+ /// use current thread and run #threads-1 in addition
+ virtual void Loop(size_t threads) = 0;
+ /// run #threads and return control
+ virtual void ForkLoop(size_t threads) = 0;
+        /// send a stop request and wait until all services have stopped
+ virtual void SyncStopFork() = 0;
+        /// send a stop request and return control immediately (async call)
+ virtual void Stop() = 0;
+ /// just listen, don't start any threads
+ virtual void Listen() = 0;
+
+ inline IServices& Add(const TString& service, IServiceRef srv) {
+ DoAdd(service, srv);
+
+ return *this;
+ }
+
+ inline IServices& Add(const TString& service, const TServiceFunction& func) {
+ return Add(service, Wrap(func));
+ }
+
+ template <class T>
+ inline IServices& Add(const TString& service, T& t) {
+ return this->Add(service, std::bind(&T::ServeRequest, std::ref(t), std::placeholders::_1));
+ }
+
+ template <class T, void (T::*M)(const IRequestRef&)>
+ inline IServices& Add(const TString& service, T& t) {
+ return this->Add(service, std::bind(M, std::ref(t), std::placeholders::_1));
+ }
+
+ private:
+ virtual void DoAdd(const TString& service, IServiceRef srv) = 0;
+ };
+
+ using IServicesRef = TAutoPtr<IServices>;
+ using TCheck = std::function<TMaybe<IRequest::TResponseError>(const IRequestRef&)>;
+
+ IServicesRef CreateLoop();
+    // a request that fails the check is answered with the returned error instead of being served
+ IServicesRef CreateLoop(TCheck check);
+}
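
A short serving sketch built on the interface above; the listen address format ("scheme://*:port/service") and the port are illustrative:

    #include <library/cpp/neh/rpc.h>

    static void Ping(const NNeh::IRequestRef& req) {
        NNeh::TDataSaver out;
        out << TStringBuf("pong");
        req->SendReply(out);
    }

    int main() {
        NNeh::IServicesRef loop = NNeh::CreateLoop();
        loop->Add("http://*:8080/ping", NNeh::Wrap(Ping));

        // Loop(2) serves from the current thread plus one extra worker thread and
        // blocks until Stop() is called from elsewhere.
        loop->Loop(2);
        return 0;
    }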
diff --git a/library/cpp/neh/rq.cpp b/library/cpp/neh/rq.cpp
new file mode 100644
index 0000000000..b3ae8f470d
--- /dev/null
+++ b/library/cpp/neh/rq.cpp
@@ -0,0 +1,312 @@
+#include "rq.h"
+#include "lfqueue.h"
+
+#include <library/cpp/threading/atomic/bool.h>
+
+#include <util/system/tls.h>
+#include <util/system/pipe.h>
+#include <util/system/event.h>
+#include <util/system/mutex.h>
+#include <util/system/condvar.h>
+#include <util/system/guard.h>
+#include <util/network/socket.h>
+#include <util/generic/deque.h>
+
+using namespace NNeh;
+
+namespace {
+ class TBaseLockFreeRequestQueue: public IRequestQueue {
+ public:
+ void Clear() override {
+ IRequestRef req;
+ while (Q_.Dequeue(&req)) {
+ }
+ }
+
+ protected:
+ NNeh::TAutoLockFreeQueue<IRequest> Q_;
+ };
+
+ class TFdRequestQueue: public TBaseLockFreeRequestQueue {
+ public:
+ inline TFdRequestQueue() {
+ TPipeHandle::Pipe(R_, W_);
+ SetNonBlock(W_);
+ }
+
+ void Schedule(IRequestRef req) override {
+ Q_.Enqueue(req);
+ char ch = 42;
+ W_.Write(&ch, 1);
+ }
+
+ IRequestRef Next() override {
+ IRequestRef ret;
+
+#if 0
+ for (size_t i = 0; i < 20; ++i) {
+ if (Q_.Dequeue(&ret)) {
+ return ret;
+ }
+
+ //asm volatile ("pause;");
+ }
+#endif
+
+ while (!Q_.Dequeue(&ret)) {
+ char ch;
+
+ R_.Read(&ch, 1);
+ }
+
+ return ret;
+ }
+
+ private:
+ TPipeHandle R_;
+ TPipeHandle W_;
+ };
+
+ struct TNehFdEvent {
+ inline TNehFdEvent() {
+ TPipeHandle::Pipe(R, W);
+ SetNonBlock(W);
+ }
+
+ inline void Signal() noexcept {
+ char ch = 21;
+ W.Write(&ch, 1);
+ }
+
+ inline void Wait() noexcept {
+ char buf[128];
+ R.Read(buf, sizeof(buf));
+ }
+
+ TPipeHandle R;
+ TPipeHandle W;
+ };
+
+ template <class TEvent>
+ class TEventRequestQueue: public TBaseLockFreeRequestQueue {
+ public:
+ void Schedule(IRequestRef req) override {
+ Q_.Enqueue(req);
+ E_.Signal();
+ }
+
+ IRequestRef Next() override {
+ IRequestRef ret;
+
+ while (!Q_.Dequeue(&ret)) {
+ E_.Wait();
+ }
+
+ E_.Signal();
+
+ return ret;
+ }
+
+ private:
+ TEvent E_;
+ };
+
+ template <class TEvent>
+ class TLazyEventRequestQueue: public TBaseLockFreeRequestQueue {
+ public:
+ void Schedule(IRequestRef req) override {
+ Q_.Enqueue(req);
+ if (C_.Val()) {
+ E_.Signal();
+ }
+ }
+
+ IRequestRef Next() override {
+ IRequestRef ret;
+
+ C_.Inc();
+ while (!Q_.Dequeue(&ret)) {
+ E_.Wait();
+ }
+ C_.Dec();
+
+ if (Q_.Size() && C_.Val()) {
+ E_.Signal();
+ }
+
+ return ret;
+ }
+
+ private:
+ TEvent E_;
+ TAtomicCounter C_;
+ };
+
+ class TCondVarRequestQueue: public IRequestQueue {
+ public:
+ void Clear() override {
+ TGuard<TMutex> g(M_);
+ Q_.clear();
+ }
+
+ void Schedule(IRequestRef req) override {
+ {
+ TGuard<TMutex> g(M_);
+
+ Q_.push_back(req);
+ }
+
+ C_.Signal();
+ }
+
+ IRequestRef Next() override {
+ TGuard<TMutex> g(M_);
+
+ while (Q_.empty()) {
+ C_.Wait(M_);
+ }
+
+ IRequestRef ret = *Q_.begin();
+ Q_.pop_front();
+
+ return ret;
+ }
+
+ private:
+ TDeque<IRequestRef> Q_;
+ TMutex M_;
+ TCondVar C_;
+ };
+
+ class TBusyRequestQueue: public TBaseLockFreeRequestQueue {
+ public:
+ void Schedule(IRequestRef req) override {
+ Q_.Enqueue(req);
+ }
+
+ IRequestRef Next() override {
+ IRequestRef ret;
+
+ while (!Q_.Dequeue(&ret)) {
+ }
+
+ return ret;
+ }
+ };
+
+ class TSleepRequestQueue: public TBaseLockFreeRequestQueue {
+ public:
+ void Schedule(IRequestRef req) override {
+ Q_.Enqueue(req);
+ }
+
+ IRequestRef Next() override {
+ IRequestRef ret;
+
+ while (!Q_.Dequeue(&ret)) {
+ usleep(1);
+ }
+
+ return ret;
+ }
+ };
+
+ struct TStupidEvent {
+ inline TStupidEvent()
+ : InWait(false)
+ {
+ }
+
+ inline bool Signal() noexcept {
+ const bool ret = InWait;
+ Ev.Signal();
+
+ return ret;
+ }
+
+ inline void Wait() noexcept {
+ InWait = true;
+ Ev.Wait();
+ InWait = false;
+ }
+
+ TAutoEvent Ev;
+ NAtomic::TBool InWait;
+ };
+
+ template <class TEvent>
+ class TLFRequestQueue: public TBaseLockFreeRequestQueue {
+ struct TLocalQueue: public TEvent {
+ };
+
+ public:
+ void Schedule(IRequestRef req) override {
+ Q_.Enqueue(req);
+
+ for (TLocalQueue* lq = 0; FQ_.Dequeue(&lq) && !lq->Signal();) {
+ }
+ }
+
+ IRequestRef Next() override {
+ while (true) {
+ IRequestRef ret;
+
+ if (Q_.Dequeue(&ret)) {
+ return ret;
+ }
+
+ TLocalQueue* lq = LocalQueue();
+
+ FQ_.Enqueue(lq);
+
+ if (Q_.Dequeue(&ret)) {
+ TLocalQueue* besttry;
+
+ if (FQ_.Dequeue(&besttry)) {
+ if (besttry == lq) {
+                            //hurray, we got rid of a spurious wakeup
+ } else {
+ FQ_.Enqueue(besttry);
+ }
+ }
+
+ return ret;
+ }
+
+ lq->Wait();
+ }
+ }
+
+ private:
+ static inline TLocalQueue* LocalQueue() noexcept {
+ Y_POD_STATIC_THREAD(TLocalQueue*)
+ lq((TLocalQueue*)0);
+
+ if (!lq) {
+ Y_STATIC_THREAD(TLocalQueue)
+ slq;
+
+ lq = &(TLocalQueue&)slq;
+ }
+
+ return lq;
+ }
+
+ private:
+ TLockFreeStack<TLocalQueue*> FQ_;
+ };
+}
+
+IRequestQueueRef NNeh::CreateRequestQueue() {
+//return new TCondVarRequestQueue();
+//return new TSleepRequestQueue();
+//return new TBusyRequestQueue();
+//return new TLFRequestQueue<TStupidEvent>();
+#if defined(_freebsd_)
+ return new TFdRequestQueue();
+#endif
+ //return new TFdRequestQueue();
+ return new TLazyEventRequestQueue<TAutoEvent>();
+ //return new TEventRequestQueue<TAutoEvent>();
+ //return new TEventRequestQueue<TNehFdEvent>();
+}
diff --git a/library/cpp/neh/rq.h b/library/cpp/neh/rq.h
new file mode 100644
index 0000000000..39db6c7124
--- /dev/null
+++ b/library/cpp/neh/rq.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include "rpc.h"
+
+namespace NNeh {
+ class IRequestQueue {
+ public:
+ virtual ~IRequestQueue() {
+ }
+
+ virtual void Clear() = 0;
+ virtual void Schedule(IRequestRef req) = 0;
+ virtual IRequestRef Next() = 0;
+ };
+
+ typedef TAutoPtr<IRequestQueue> IRequestQueueRef;
+ IRequestQueueRef CreateRequestQueue();
+}
diff --git a/library/cpp/neh/smart_ptr.cpp b/library/cpp/neh/smart_ptr.cpp
new file mode 100644
index 0000000000..62a540c871
--- /dev/null
+++ b/library/cpp/neh/smart_ptr.cpp
@@ -0,0 +1 @@
+#include "smart_ptr.h"
diff --git a/library/cpp/neh/smart_ptr.h b/library/cpp/neh/smart_ptr.h
new file mode 100644
index 0000000000..1ec4653304
--- /dev/null
+++ b/library/cpp/neh/smart_ptr.h
@@ -0,0 +1,332 @@
+#pragma once
+
+#include <util/generic/ptr.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+
+namespace NNeh {
+    //limited emulation of shared_ptr/weak_ptr from the boost library.
+    //the main value is the weak_ptr functionality; otherwise the types from util/generic/ptr.h are recommended
+
+ //smart pointer counter shared between shared and weak ptrs.
+ class TSPCounted: public TThrRefBase {
+ public:
+ inline TSPCounted() noexcept
+ : C_(0)
+ {
+ }
+
+ inline void Inc() noexcept {
+ AtomicIncrement(C_);
+ }
+
+        //return false if C_ is already 0, otherwise increment it and return true
+ inline bool TryInc() noexcept {
+ for (;;) {
+ intptr_t curVal(AtomicGet(C_));
+
+ if (!curVal) {
+ return false;
+ }
+
+ intptr_t newVal(curVal + 1);
+
+ if (AtomicCas(&C_, newVal, curVal)) {
+ return true;
+ }
+ }
+ }
+
+ inline intptr_t Dec() noexcept {
+ return AtomicDecrement(C_);
+ }
+
+ inline intptr_t Value() const noexcept {
+ return AtomicGet(C_);
+ }
+
+ private:
+ TAtomic C_;
+ };
+
+ typedef TIntrusivePtr<TSPCounted> TSPCountedRef;
+
+ class TWeakCount;
+
+ class TSPCount {
+ public:
+ TSPCount(TSPCounted* c = nullptr) noexcept
+ : C_(c)
+ {
+ }
+
+ inline void Swap(TSPCount& r) noexcept {
+ DoSwap(C_, r.C_);
+ }
+
+ inline size_t UseCount() const noexcept {
+ if (!C_) {
+ return 0;
+ }
+ return C_->Value();
+ }
+
+ inline bool operator!() const noexcept {
+ return !C_;
+ }
+
+ inline TSPCounted* GetCounted() const noexcept {
+ return C_.Get();
+ }
+
+ inline void Reset() noexcept {
+ if (!!C_) {
+ C_.Drop();
+ }
+ }
+
+ protected:
+ TIntrusivePtr<TSPCounted> C_;
+ };
+
+ class TSharedCount: public TSPCount {
+ public:
+ inline TSharedCount() noexcept {
+ }
+
+ /// @throws std::bad_alloc
+ inline explicit TSharedCount(const TSharedCount& r)
+ : TSPCount(r.C_.Get())
+ {
+ if (!!C_) {
+ (C_->Inc());
+ }
+ }
+
+        //'c' must exist and its refcount must already have been incremented
+ inline explicit TSharedCount(TSPCounted* c) noexcept
+ : TSPCount(c)
+ {
+ }
+
+ public:
+ /// @throws std::bad_alloc
+ inline void Inc() {
+ if (!C_) {
+ TSPCountedRef(new TSPCounted()).Swap(C_);
+ }
+ C_->Inc();
+ }
+
+ inline bool TryInc() noexcept {
+ if (!C_) {
+ return false;
+ }
+ return C_->TryInc();
+ }
+
+ inline intptr_t Dec() noexcept {
+ if (!C_) {
+ Y_ASSERT(0);
+ return 0;
+ }
+ return C_->Dec();
+ }
+
+ void Drop() noexcept {
+ C_.Drop();
+ }
+
+ protected:
+ template <class Y>
+ friend class TSharedPtrB;
+
+ // 'c' MUST BE already incremented
+ void Assign(TSPCounted* c) noexcept {
+ TSPCountedRef(c).Swap(C_);
+ }
+
+ private:
+ TSharedCount& operator=(const TSharedCount&); //disable
+ };
+
+ class TWeakCount: public TSPCount {
+ public:
+ inline TWeakCount() noexcept {
+ }
+
+ inline explicit TWeakCount(const TWeakCount& r) noexcept
+ : TSPCount(r.GetCounted())
+ {
+ }
+
+ inline explicit TWeakCount(const TSharedCount& r) noexcept
+ : TSPCount(r.GetCounted())
+ {
+ }
+
+ private:
+ TWeakCount& operator=(const TWeakCount&); //disable
+ };
+
+ template <class T>
+ class TWeakPtrB;
+
+ template <class T>
+ class TSharedPtrB {
+ public:
+ inline TSharedPtrB() noexcept
+ : T_(nullptr)
+ {
+ }
+
+ /// @throws std::bad_alloc
+ inline TSharedPtrB(T* t)
+ : T_(nullptr)
+ {
+ if (t) {
+ THolder<T> h(t);
+ C_.Inc();
+ T_ = h.Release();
+ }
+ }
+
+ inline TSharedPtrB(const TSharedPtrB<T>& r) noexcept
+ : T_(r.T_)
+ , C_(r.C_)
+ {
+ Y_ASSERT((!!T_ && !!C_.UseCount()) || (!T_ && !C_.UseCount()));
+ }
+
+ inline TSharedPtrB(const TWeakPtrB<T>& r) noexcept
+ : T_(r.T_)
+ {
+ if (T_) {
+ TSPCounted* spc = r.C_.GetCounted();
+
+ if (spc && spc->TryInc()) {
+ C_.Assign(spc);
+ } else { //obsolete ptr
+ T_ = nullptr;
+ }
+ }
+ }
+
+ inline ~TSharedPtrB() {
+ Reset();
+ }
+
+ TSharedPtrB& operator=(const TSharedPtrB<T>& r) noexcept {
+ TSharedPtrB<T>(r).Swap(*this);
+ return *this;
+ }
+
+ TSharedPtrB& operator=(const TWeakPtrB<T>& r) noexcept {
+ TSharedPtrB<T>(r).Swap(*this);
+ return *this;
+ }
+
+ void Swap(TSharedPtrB<T>& r) noexcept {
+ DoSwap(T_, r.T_);
+ DoSwap(C_, r.C_);
+ Y_ASSERT((!!T_ && !!UseCount()) || (!T_ && !UseCount()));
+ }
+
+ inline bool operator!() const noexcept {
+ return !T_;
+ }
+
+ inline T* Get() noexcept {
+ return T_;
+ }
+
+ inline T* operator->() noexcept {
+ return T_;
+ }
+
+ inline T* operator->() const noexcept {
+ return T_;
+ }
+
+ inline T& operator*() noexcept {
+ return *T_;
+ }
+
+ inline T& operator*() const noexcept {
+ return *T_;
+ }
+
+ inline void Reset() noexcept {
+ if (T_) {
+ if (C_.Dec() == 0) {
+ delete T_;
+ }
+ T_ = nullptr;
+ C_.Drop();
+ }
+ }
+
+ inline size_t UseCount() const noexcept {
+ return C_.UseCount();
+ }
+
+ protected:
+ template <class Y>
+ friend class TWeakPtrB;
+
+ T* T_;
+ TSharedCount C_;
+ };
+
+ template <class T>
+ class TWeakPtrB {
+ public:
+ inline TWeakPtrB() noexcept
+ : T_(nullptr)
+ {
+ }
+
+ inline TWeakPtrB(const TWeakPtrB<T>& r) noexcept
+ : T_(r.T_)
+ , C_(r.C_)
+ {
+ }
+
+ inline TWeakPtrB(const TSharedPtrB<T>& r) noexcept
+ : T_(r.T_)
+ , C_(r.C_)
+ {
+ }
+
+ TWeakPtrB& operator=(const TWeakPtrB<T>& r) noexcept {
+ TWeakPtrB(r).Swap(*this);
+ return *this;
+ }
+
+ TWeakPtrB& operator=(const TSharedPtrB<T>& r) noexcept {
+ TWeakPtrB(r).Swap(*this);
+ return *this;
+ }
+
+ inline void Swap(TWeakPtrB<T>& r) noexcept {
+ DoSwap(T_, r.T_);
+ DoSwap(C_, r.C_);
+ }
+
+ inline void Reset() noexcept {
+            T_ = nullptr;
+ C_.Reset();
+ }
+
+ inline size_t UseCount() const noexcept {
+ return C_.UseCount();
+ }
+
+ protected:
+ template <class Y>
+ friend class TSharedPtrB;
+
+ T* T_;
+ TWeakCount C_;
+ };
+
+}
diff --git a/library/cpp/neh/stat.cpp b/library/cpp/neh/stat.cpp
new file mode 100644
index 0000000000..ef6422fb52
--- /dev/null
+++ b/library/cpp/neh/stat.cpp
@@ -0,0 +1,114 @@
+#include "stat.h"
+
+#include <util/generic/hash.h>
+#include <util/generic/singleton.h>
+#include <util/system/spinlock.h>
+#include <util/system/tls.h>
+
+using namespace NNeh;
+
+volatile TAtomic NNeh::TServiceStat::MaxContinuousErrors_ = 0; //disabled by default
+volatile TAtomic NNeh::TServiceStat::ReSendValidatorPeriod_ = 100;
+
+NNeh::TServiceStat::EStatus NNeh::TServiceStat::GetStatus() {
+ if (!AtomicGet(MaxContinuousErrors_) || AtomicGet(LastContinuousErrors_) < AtomicGet(MaxContinuousErrors_)) {
+ return Ok;
+ }
+
+ if (RequestsInProcess_.Val() != 0)
+ return Fail;
+
+ if (AtomicIncrement(SendValidatorCounter_) != AtomicGet(ReSendValidatorPeriod_)) {
+ return Fail;
+ }
+
+    //time to refresh the service status (send a validation request)
+ AtomicSet(SendValidatorCounter_, 0);
+
+ return ReTry;
+}
+
+void NNeh::TServiceStat::DbgOut(IOutputStream& out) const {
+    out << "----------------------------------------------------" << '\n';
+ out << "RequestsInProcess: " << RequestsInProcess_.Val() << '\n';
+ out << "LastContinuousErrors: " << AtomicGet(LastContinuousErrors_) << '\n';
+ out << "SendValidatorCounter: " << AtomicGet(SendValidatorCounter_) << '\n';
+ out << "ReSendValidatorPeriod: " << AtomicGet(ReSendValidatorPeriod_) << Endl;
+}
+
+void NNeh::TServiceStat::OnBegin() {
+ RequestsInProcess_.Inc();
+}
+
+void NNeh::TServiceStat::OnSuccess() {
+ RequestsInProcess_.Dec();
+ AtomicSet(LastContinuousErrors_, 0);
+}
+
+void NNeh::TServiceStat::OnCancel() {
+ RequestsInProcess_.Dec();
+}
+
+void NNeh::TServiceStat::OnFail() {
+ RequestsInProcess_.Dec();
+ if (AtomicIncrement(LastContinuousErrors_) == AtomicGet(MaxContinuousErrors_)) {
+ AtomicSet(SendValidatorCounter_, 0);
+ }
+}
+
+namespace {
+ class TGlobalServicesStat {
+ public:
+ inline TServiceStatRef ServiceStat(const TStringBuf addr) noexcept {
+ const auto guard = Guard(Lock_);
+
+ TServiceStatRef& ss = SS_[addr];
+
+ if (!ss) {
+ TServiceStatRef tmp(new TServiceStat());
+
+ ss.Swap(tmp);
+ }
+ return ss;
+ }
+
+ protected:
+ TAdaptiveLock Lock_;
+ THashMap<TString, TServiceStatRef> SS_;
+ };
+
+ class TServicesStat {
+ public:
+ inline TServiceStatRef ServiceStat(const TStringBuf addr) noexcept {
+ TServiceStatRef& ss = SS_[addr];
+
+ if (!ss) {
+ TServiceStatRef tmp(Singleton<TGlobalServicesStat>()->ServiceStat(addr));
+
+ ss.Swap(tmp);
+ }
+ return ss;
+ }
+
+ protected:
+ THashMap<TString, TServiceStatRef> SS_;
+ };
+
+ inline TServicesStat* ThrServiceStat() {
+ Y_POD_STATIC_THREAD(TServicesStat*)
+ ss;
+
+ if (!ss) {
+ Y_STATIC_THREAD(TServicesStat)
+ tss;
+
+ ss = &(TServicesStat&)tss;
+ }
+
+ return ss;
+ }
+}
+
+TServiceStatRef NNeh::GetServiceStat(const TStringBuf addr) {
+ return ThrServiceStat()->ServiceStat(addr);
+}
diff --git a/library/cpp/neh/stat.h b/library/cpp/neh/stat.h
new file mode 100644
index 0000000000..803e8d2974
--- /dev/null
+++ b/library/cpp/neh/stat.h
@@ -0,0 +1,96 @@
+#pragma once
+
+#include <util/generic/ptr.h>
+#include <util/stream/output.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+#include <library/cpp/deprecated/atomic/atomic_ops.h>
+
+namespace NNeh {
+ class TStatCollector;
+
+    /// NEH service health (availability) statistics collector.
+ ///
+ /// Disabled by default, use `TServiceStat::ConfigureValidator` to set `maxContinuousErrors`
+ /// different from zero.
+ class TServiceStat: public TThrRefBase {
+ public:
+ static void ConfigureValidator(unsigned maxContinuousErrors, unsigned reSendValidatorPeriod) noexcept {
+ AtomicSet(MaxContinuousErrors_, maxContinuousErrors);
+ AtomicSet(ReSendValidatorPeriod_, reSendValidatorPeriod);
+ }
+ static bool Disabled() noexcept {
+ return !AtomicGet(MaxContinuousErrors_);
+ }
+
+ enum EStatus {
+ Ok,
+ Fail,
+            ReTry //time to send a validation request to the service
+ };
+
+ EStatus GetStatus();
+
+ void DbgOut(IOutputStream&) const;
+
+ protected:
+ friend class TStatCollector;
+
+ virtual void OnBegin();
+ virtual void OnSuccess();
+ virtual void OnCancel();
+ virtual void OnFail();
+
+ static TAtomic MaxContinuousErrors_;
+ static TAtomic ReSendValidatorPeriod_;
+ TAtomicCounter RequestsInProcess_;
+ TAtomic LastContinuousErrors_ = 0;
+ TAtomic SendValidatorCounter_ = 0;
+ };
+
+ using TServiceStatRef = TIntrusivePtr<TServiceStat>;
+
+    //thread-safe (race-protected) service stat updater
+ class TStatCollector {
+ public:
+ TStatCollector(TServiceStatRef& ss)
+ : SS_(ss)
+ {
+ ss->OnBegin();
+ }
+
+ ~TStatCollector() {
+ if (CanInformSS()) {
+ SS_->OnFail();
+ }
+ }
+
+ void OnCancel() noexcept {
+ if (CanInformSS()) {
+ SS_->OnCancel();
+ }
+ }
+
+ void OnFail() noexcept {
+ if (CanInformSS()) {
+ SS_->OnFail();
+ }
+ }
+
+ void OnSuccess() noexcept {
+ if (CanInformSS()) {
+ SS_->OnSuccess();
+ }
+ }
+
+ private:
+ inline bool CanInformSS() noexcept {
+ return AtomicGet(CanInformSS_) && AtomicCas(&CanInformSS_, 0, 1);
+ }
+
+ TServiceStatRef SS_;
+ TAtomic CanInformSS_ = 1;
+ };
+
+ TServiceStatRef GetServiceStat(TStringBuf addr);
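+
+    //A minimal usage sketch (illustrative only; the address string and request code are assumptions):
+    //enable the validator once at startup, then wrap each request into a TStatCollector so the
+    //per-address statistics are updated exactly once per request.
+    //
+    //  TServiceStat::ConfigureValidator(/*maxContinuousErrors=*/5, /*reSendValidatorPeriod=*/100);
+    //
+    //  TServiceStatRef ss = GetServiceStat("host:12345");
+    //  if (ss->GetStatus() != TServiceStat::Fail) {
+    //      TStatCollector collector(ss);
+    //      //... perform the request ...
+    //      collector.OnSuccess(); //or OnFail()/OnCancel(); otherwise the destructor reports a failure
+    //  }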
+
+}
diff --git a/library/cpp/neh/tcp.cpp b/library/cpp/neh/tcp.cpp
new file mode 100644
index 0000000000..80f464dac2
--- /dev/null
+++ b/library/cpp/neh/tcp.cpp
@@ -0,0 +1,676 @@
+#include "tcp.h"
+
+#include "details.h"
+#include "factory.h"
+#include "location.h"
+#include "pipequeue.h"
+#include "utils.h"
+
+#include <library/cpp/coroutine/listener/listen.h>
+#include <library/cpp/coroutine/engine/events.h>
+#include <library/cpp/coroutine/engine/sockpool.h>
+#include <library/cpp/dns/cache.h>
+
+#include <util/ysaveload.h>
+#include <util/generic/buffer.h>
+#include <util/generic/guid.h>
+#include <util/generic/hash.h>
+#include <util/generic/intrlist.h>
+#include <util/generic/ptr.h>
+#include <util/generic/vector.h>
+#include <util/system/yassert.h>
+#include <util/system/unaligned_mem.h>
+#include <util/stream/buffered.h>
+#include <util/stream/mem.h>
+
+using namespace NDns;
+using namespace NNeh;
+
+using TNehMessage = TMessage;
+
+template <>
+struct TSerializer<TGUID> {
+ static inline void Save(IOutputStream* out, const TGUID& g) {
+ out->Write(&g.dw, sizeof(g.dw));
+ }
+
+ static inline void Load(IInputStream* in, TGUID& g) {
+ in->Load(&g.dw, sizeof(g.dw));
+ }
+};
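+
+//Wire framing used by this transport, as built by TResponce/TRequest::Serialize and parsed
+//in the DoRecvCycle methods below (integers are raw host-order values):
+//  request frame:  [ui32 length][16-byte TGUID request id][ui32 service name length][service name][payload]
+//  response frame: [ui32 length][16-byte TGUID request id][payload]
+//where 'length' counts every byte that follows it in the frame.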
+
+namespace {
+ namespace NNehTCP {
+ typedef IOutputStream::TPart TPart;
+
+ static inline ui64 LocalGuid(const TGUID& g) {
+ return ReadUnaligned<ui64>(g.dw);
+ }
+
+ static inline TString LoadStroka(IInputStream& input, size_t len) {
+ TString tmp;
+
+ tmp.ReserveAndResize(len);
+ input.Load(tmp.begin(), tmp.size());
+
+ return tmp;
+ }
+
+ struct TParts: public TVector<TPart> {
+ template <class T>
+ inline void Push(const T& t) {
+ Push(TPart(t));
+ }
+
+ inline void Push(const TPart& part) {
+ if (part.len) {
+ push_back(part);
+ }
+ }
+
+ inline void Clear() noexcept {
+ clear();
+ }
+ };
+
+ template <class T>
+ struct TMessageQueue {
+ inline TMessageQueue(TContExecutor* e)
+ : Ev(e)
+ {
+ }
+
+ template <class TPtr>
+ inline void Enqueue(TPtr p) noexcept {
+ L.PushBack(p.Release());
+ Ev.Signal();
+ }
+
+ template <class TPtr>
+ inline bool Dequeue(TPtr& p) noexcept {
+ do {
+ if (TryDequeue(p)) {
+ return true;
+ }
+ } while (Ev.WaitI() != ECANCELED);
+
+ return false;
+ }
+
+ template <class TPtr>
+ inline bool TryDequeue(TPtr& p) noexcept {
+ if (L.Empty()) {
+ return false;
+ }
+
+ p.Reset(L.PopFront());
+
+ return true;
+ }
+
+ inline TContExecutor* Executor() const noexcept {
+ return Ev.Executor();
+ }
+
+ TIntrusiveListWithAutoDelete<T, TDelete> L;
+ TContSimpleEvent Ev;
+ };
+
+ template <class Q, class C>
+ inline bool Dequeue(Q& q, C& c, size_t len) {
+ typename C::value_type t;
+ size_t slen = 0;
+
+ if (q.Dequeue(t)) {
+ slen += t->Length();
+ c.push_back(t);
+
+ while (slen < len && q.TryDequeue(t)) {
+ slen += t->Length();
+ c.push_back(t);
+ }
+
+ return true;
+ }
+
+ return false;
+ }
+
+ struct TServer: public IRequester, public TContListener::ICallBack {
+ struct TLink;
+ typedef TIntrusivePtr<TLink> TLinkRef;
+
+ struct TResponce: public TIntrusiveListItem<TResponce> {
+ inline TResponce(const TLinkRef& link, TData& data, TStringBuf reqid)
+ : Link(link)
+ {
+ Data.swap(data);
+
+ TMemoryOutput out(Buf, sizeof(Buf));
+
+ ::Save(&out, (ui32)(reqid.size() + Data.size()));
+ out.Write(reqid.data(), reqid.size());
+
+ Y_ASSERT(reqid.size() == 16);
+
+ Len = out.Buf() - Buf;
+ }
+
+ inline void Serialize(TParts& parts) {
+ parts.Push(TStringBuf(Buf, Len));
+ parts.Push(TStringBuf(Data.data(), Data.size()));
+ }
+
+ inline size_t Length() const noexcept {
+ return Len + Data.size();
+ }
+
+ TLinkRef Link;
+ TData Data;
+ char Buf[32];
+ size_t Len;
+ };
+
+ typedef TAutoPtr<TResponce> TResponcePtr;
+
+ struct TRequest: public IRequest {
+ inline TRequest(const TLinkRef& link, IInputStream& in, size_t len)
+ : Link(link)
+ {
+ Buf.Proceed(len);
+ in.Load(Buf.Data(), Buf.Size());
+ if ((ServiceBegin() - Buf.Data()) + ServiceLen() > Buf.Size()) {
+ throw yexception() << "invalid request (service len)";
+ }
+ }
+
+ TStringBuf Scheme() const override {
+ return TStringBuf("tcp");
+ }
+
+ TString RemoteHost() const override {
+ return Link->RemoteHost;
+ }
+
+ TStringBuf Service() const override {
+ return TStringBuf(ServiceBegin(), ServiceLen());
+ }
+
+ TStringBuf Data() const override {
+ return TStringBuf(Service().end(), Buf.End());
+ }
+
+ TStringBuf RequestId() const override {
+ return TStringBuf(Buf.Data(), 16);
+ }
+
+ bool Canceled() const override {
+ //TODO
+ return false;
+ }
+
+ void SendReply(TData& data) override {
+ Link->P->Schedule(new TResponce(Link, data, RequestId()));
+ }
+
+ void SendError(TResponseError, const TString&) override {
+ // TODO
+ }
+
+ size_t ServiceLen() const noexcept {
+ const char* ptr = RequestId().end();
+ return *(ui32*)ptr;
+ }
+
+ const char* ServiceBegin() const noexcept {
+ return RequestId().end() + sizeof(ui32);
+ }
+
+ TBuffer Buf;
+ TLinkRef Link;
+ };
+
+ struct TLink: public TAtomicRefCount<TLink> {
+ inline TLink(TServer* parent, const TAcceptFull& a)
+ : P(parent)
+ , MQ(Executor())
+ {
+ S.Swap(*a.S);
+ SetNoDelay(S, true);
+
+ RemoteHost = PrintHostByRfc(*GetPeerAddr(S));
+
+ TLinkRef self(this);
+
+ Executor()->Create<TLink, &TLink::RecvCycle>(this, "recv");
+ Executor()->Create<TLink, &TLink::SendCycle>(this, "send");
+
+ Executor()->Running()->Yield();
+ }
+
+ inline void Enqueue(TResponcePtr res) {
+ MQ.Enqueue(res);
+ }
+
+ inline TContExecutor* Executor() const noexcept {
+ return P->E.Get();
+ }
+
+ void SendCycle(TCont* c) {
+ TLinkRef self(this);
+
+ try {
+ DoSendCycle(c);
+ } catch (...) {
+ Cdbg << "neh/tcp/1: " << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ inline void DoSendCycle(TCont* c) {
+ TVector<TResponcePtr> responses;
+ TParts parts;
+
+ while (Dequeue(MQ, responses, 7000)) {
+ for (size_t i = 0; i < responses.size(); ++i) {
+ responses[i]->Serialize(parts);
+ }
+
+ {
+ TContIOVector iovec(parts.data(), parts.size());
+ NCoro::WriteVectorI(c, S, &iovec);
+ }
+
+ parts.Clear();
+ responses.clear();
+ }
+ }
+
+ void RecvCycle(TCont* c) {
+ TLinkRef self(this);
+
+ try {
+ DoRecvCycle(c);
+ } catch (...) {
+ if (!c->Cancelled()) {
+ Cdbg << "neh/tcp/2: " << CurrentExceptionMessage() << Endl;
+ }
+ }
+ }
+
+ inline void DoRecvCycle(TCont* c) {
+ TContIO io(S, c);
+ TBufferedInput input(&io, 8192 * 4);
+
+ while (true) {
+ ui32 len;
+
+ try {
+ ::Load(&input, len);
+ } catch (TLoadEOF&) {
+ return;
+ }
+
+ P->CB->OnRequest(new TRequest(this, input, len));
+ }
+ }
+
+ TServer* P;
+ TMessageQueue<TResponce> MQ;
+ TSocketHolder S;
+ TString RemoteHost;
+ };
+
+ inline TServer(IOnRequest* cb, ui16 port)
+ : CB(cb)
+ , Addr(port)
+ {
+ Thrs.push_back(Spawn<TServer, &TServer::Run>(this));
+ }
+
+ ~TServer() override {
+ Schedule(nullptr);
+
+ for (size_t i = 0; i < Thrs.size(); ++i) {
+ Thrs[i]->Join();
+ }
+ }
+
+ void Run() {
+ E = MakeHolder<TContExecutor>(RealStackSize(32000));
+ THolder<TContListener> L(new TContListener(this, E.Get(), TContListener::TOptions().SetDeferAccept(true)));
+ //SetHighestThreadPriority();
+ L->Bind(Addr);
+ E->Create<TServer, &TServer::RunDispatcher>(this, "dispatcher");
+ L->Listen();
+ E->Execute();
+ }
+
+ void OnAcceptFull(const TAcceptFull& a) override {
+ //I love such code
+ new TLink(this, a);
+ }
+
+ void OnError() override {
+ Cerr << CurrentExceptionMessage() << Endl;
+ }
+
+ inline void Schedule(TResponcePtr res) {
+ PQ.EnqueueSafe(res);
+ }
+
+ void RunDispatcher(TCont* c) {
+ while (true) {
+ TResponcePtr res;
+
+ PQ.DequeueSafe(c, res);
+
+ if (!res) {
+ break;
+ }
+
+ TLinkRef link = res->Link;
+
+ link->Enqueue(res);
+ }
+
+ c->Executor()->Abort();
+ }
+ THolder<TContExecutor> E;
+ IOnRequest* CB;
+ TNetworkAddress Addr;
+ TOneConsumerPipeQueue<TResponce> PQ;
+ TVector<TThreadRef> Thrs;
+ };
+
+ struct TClient {
+ struct TRequest: public TIntrusiveListItem<TRequest> {
+ inline TRequest(const TSimpleHandleRef& hndl, const TNehMessage& msg)
+ : Hndl(hndl)
+ , Msg(msg)
+ , Loc(Msg.Addr)
+ , RI(CachedThrResolve(TResolveInfo(Loc.Host, Loc.GetPort())))
+ {
+ CreateGuid(&Guid);
+ }
+
+ inline void Serialize(TParts& parts) {
+ TMemoryOutput out(Buf, sizeof(Buf));
+
+ ::Save(&out, (ui32)MsgLen());
+ ::Save(&out, Guid);
+ ::Save(&out, (ui32) Loc.Service.size());
+
+ if (Loc.Service.size() > out.Avail()) {
+ parts.Push(TStringBuf(Buf, out.Buf()));
+ parts.Push(Loc.Service);
+ } else {
+ out.Write(Loc.Service.data(), Loc.Service.size());
+ parts.Push(TStringBuf(Buf, out.Buf()));
+ }
+
+ parts.Push(Msg.Data);
+ }
+
+ inline size_t Length() const noexcept {
+ return sizeof(ui32) + MsgLen();
+ }
+
+ inline size_t MsgLen() const noexcept {
+ return sizeof(Guid.dw) + sizeof(ui32) + Loc.Service.size() + Msg.Data.size();
+ }
+
+ void OnError(const TString& errText) {
+ Hndl->NotifyError(errText);
+ }
+
+ TSimpleHandleRef Hndl;
+ TNehMessage Msg;
+ TGUID Guid;
+ const TParsedLocation Loc;
+ const TResolvedHost* RI;
+ char Buf[128];
+ };
+
+ typedef TAutoPtr<TRequest> TRequestPtr;
+
+ struct TChannel {
+ struct TLink: public TIntrusiveListItem<TLink>, public TSimpleRefCount<TLink> {
+ inline TLink(TChannel* parent)
+ : P(parent)
+ {
+ Executor()->Create<TLink, &TLink::SendCycle>(this, "send");
+ }
+
+ void SendCycle(TCont* c) {
+ TIntrusivePtr<TLink> self(this);
+
+ try {
+ DoSendCycle(c);
+ OnError("shutdown");
+ } catch (...) {
+ OnError(CurrentExceptionMessage());
+ }
+
+ Unlink();
+ }
+
+ inline void DoSendCycle(TCont* c) {
+ if (int ret = NCoro::ConnectI(c, S, P->RI->Addr)) {
+ ythrow TSystemError(ret) << "can't connect";
+ }
+ SetNoDelay(S, true);
+ Executor()->Create<TLink, &TLink::RecvCycle>(this, "recv");
+
+ TVector<TRequestPtr> reqs;
+ TParts parts;
+
+ while (Dequeue(P->Q, reqs, 7000)) {
+ for (size_t i = 0; i < reqs.size(); ++i) {
+ TRequestPtr& req = reqs[i];
+
+ req->Serialize(parts);
+ InFly[LocalGuid(req->Guid)] = req;
+ }
+
+ {
+ TContIOVector vec(parts.data(), parts.size());
+ NCoro::WriteVectorI(c, S, &vec);
+ }
+
+ reqs.clear();
+ parts.Clear();
+ }
+ }
+
+ void RecvCycle(TCont* c) {
+ TIntrusivePtr<TLink> self(this);
+
+ try {
+ DoRecvCycle(c);
+ OnError("service close connection");
+ } catch (...) {
+ OnError(CurrentExceptionMessage());
+ }
+ }
+
+ inline void DoRecvCycle(TCont* c) {
+ TContIO io(S, c);
+ TBufferedInput input(&io, 8192 * 4);
+
+ while (true) {
+ ui32 len;
+ TGUID g;
+
+ try {
+ ::Load(&input, len);
+ } catch (TLoadEOF&) {
+ return;
+ }
+ ::Load(&input, g);
+ const TString data(LoadStroka(input, len - sizeof(g.dw)));
+
+ TInFly::iterator it = InFly.find(LocalGuid(g));
+
+ if (it == InFly.end()) {
+ continue;
+ }
+
+ TRequestPtr req = it->second;
+
+ InFly.erase(it);
+ req->Hndl->NotifyResponse(data);
+ }
+ }
+
+ inline TContExecutor* Executor() const noexcept {
+ return P->Q.Executor();
+ }
+
+ void OnError(const TString& errText) {
+ for (auto& it : InFly) {
+ it.second->OnError(errText);
+ }
+ InFly.clear();
+
+ TRequestPtr req;
+ while (P->Q.TryDequeue(req)) {
+ req->OnError(errText);
+ }
+ }
+
+ TChannel* P;
+ TSocketHolder S;
+ typedef THashMap<ui64, TRequestPtr> TInFly;
+ TInFly InFly;
+ };
+
+ inline TChannel(TContExecutor* e, const TResolvedHost* ri)
+ : Q(e)
+ , RI(ri)
+ {
+ }
+
+ inline void Enqueue(TRequestPtr req) {
+ Q.Enqueue(req);
+
+ if (Links.Empty()) {
+ for (size_t i = 0; i < 1; ++i) {
+ SpawnLink();
+ }
+ }
+ }
+
+ inline void SpawnLink() {
+ Links.PushBack(new TLink(this));
+ }
+
+ TMessageQueue<TRequest> Q;
+ TIntrusiveList<TLink> Links;
+ const TResolvedHost* RI;
+ };
+
+ typedef TAutoPtr<TChannel> TChannelPtr;
+
+ inline TClient() {
+ Thr = Spawn<TClient, &TClient::RunExecutor>(this);
+ }
+
+ inline ~TClient() {
+ Reqs.Enqueue(nullptr);
+ Thr->Join();
+ }
+
+ inline THandleRef Schedule(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) {
+ TSimpleHandleRef ret(new TSimpleHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));
+
+ Reqs.Enqueue(new TRequest(ret, msg));
+
+ return ret.Get();
+ }
+
+ void RunExecutor() {
+ //SetHighestThreadPriority();
+ TContExecutor e(RealStackSize(32000));
+
+ e.Create<TClient, &TClient::RunDispatcher>(this, "dispatcher");
+ e.Execute();
+ }
+
+ void RunDispatcher(TCont* c) {
+ TRequestPtr req;
+
+ while (true) {
+ Reqs.DequeueSafe(c, req);
+
+ if (!req) {
+ break;
+ }
+
+ TChannelPtr& ch = Channels.Get(req->RI->Id);
+
+ if (!ch) {
+ ch.Reset(new TChannel(c->Executor(), req->RI));
+ }
+
+ ch->Enqueue(req);
+ }
+
+ c->Executor()->Abort();
+ }
+
+ TThreadRef Thr;
+ TOneConsumerPipeQueue<TRequest> Reqs;
+ TSocketMap<TChannelPtr> Channels;
+ };
+
+ struct TMultiClient {
+ inline TMultiClient()
+ : Next(0)
+ {
+ for (size_t i = 0; i < 2; ++i) {
+ Clients.push_back(new TClient());
+ }
+ }
+
+ inline THandleRef Schedule(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) {
+ return Clients[AtomicIncrement(Next) % Clients.size()]->Schedule(msg, fallback, ss);
+ }
+
+ TVector<TAutoPtr<TClient>> Clients;
+ TAtomic Next;
+ };
+
+#if 0
+ static inline TMultiClient* Client() {
+ return Singleton<NNehTCP::TMultiClient>();
+ }
+#else
+ static inline TClient* Client() {
+ return Singleton<NNehTCP::TClient>();
+ }
+#endif
+
+ class TTcpProtocol: public IProtocol {
+ public:
+ inline TTcpProtocol() {
+ InitNetworkSubSystem();
+ }
+
+ IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
+ return new TServer(cb, loc.GetPort());
+ }
+
+ THandleRef ScheduleRequest(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
+ return Client()->Schedule(msg, fallback, ss);
+ }
+
+ TStringBuf Scheme() const noexcept override {
+ return TStringBuf("tcp");
+ }
+ };
+ }
+}
+
+IProtocol* NNeh::TcpProtocol() {
+ return Singleton<NNehTCP::TTcpProtocol>();
+}
diff --git a/library/cpp/neh/tcp.h b/library/cpp/neh/tcp.h
new file mode 100644
index 0000000000..e0d25d0bac
--- /dev/null
+++ b/library/cpp/neh/tcp.h
@@ -0,0 +1,7 @@
+#pragma once
+
+namespace NNeh {
+ class IProtocol;
+
+ IProtocol* TcpProtocol();
+}
diff --git a/library/cpp/neh/tcp2.cpp b/library/cpp/neh/tcp2.cpp
new file mode 100644
index 0000000000..3dad055af1
--- /dev/null
+++ b/library/cpp/neh/tcp2.cpp
@@ -0,0 +1,1656 @@
+#include "tcp2.h"
+
+#include "details.h"
+#include "factory.h"
+#include "http_common.h"
+#include "neh.h"
+#include "utils.h"
+
+#include <library/cpp/dns/cache.h>
+#include <library/cpp/neh/asio/executor.h>
+#include <library/cpp/threading/atomic/bool.h>
+
+#include <util/generic/buffer.h>
+#include <util/generic/hash.h>
+#include <util/generic/singleton.h>
+#include <util/network/endpoint.h>
+#include <util/network/init.h>
+#include <util/network/iovec.h>
+#include <util/network/socket.h>
+#include <util/string/cast.h>
+
+#include <atomic>
+
+//#define DEBUG_TCP2 1
+#ifdef DEBUG_TCP2
+TSpinLock OUT_LOCK;
+#define DBGOUT(args) \
+ { \
+ TGuard<TSpinLock> m(OUT_LOCK); \
+ Cout << TInstant::Now().GetValue() << " " << args << Endl; \
+ }
+#else
+#define DBGOUT(args)
+#endif
+
+using namespace std::placeholders;
+
+namespace NNeh {
+ TDuration TTcp2Options::ConnectTimeout = TDuration::MilliSeconds(300);
+ size_t TTcp2Options::InputBufferSize = 16000;
+ size_t TTcp2Options::AsioClientThreads = 4;
+ size_t TTcp2Options::AsioServerThreads = 4;
+ int TTcp2Options::Backlog = 100;
+ bool TTcp2Options::ClientUseDirectWrite = true;
+ bool TTcp2Options::ServerUseDirectWrite = true;
+ TDuration TTcp2Options::ServerInputDeadline = TDuration::Seconds(3600);
+ TDuration TTcp2Options::ServerOutputDeadline = TDuration::Seconds(10);
+
+ bool TTcp2Options::Set(TStringBuf name, TStringBuf value) {
+#define TCP2_TRY_SET(optType, optName) \
+ if (name == TStringBuf(#optName)) { \
+ optName = FromString<optType>(value); \
+ }
+
+ TCP2_TRY_SET(TDuration, ConnectTimeout)
+        else TCP2_TRY_SET(size_t, InputBufferSize)
+        else TCP2_TRY_SET(size_t, AsioClientThreads)
+        else TCP2_TRY_SET(size_t, AsioServerThreads)
+        else TCP2_TRY_SET(int, Backlog)
+        else TCP2_TRY_SET(bool, ClientUseDirectWrite)
+        else TCP2_TRY_SET(bool, ServerUseDirectWrite)
+        else TCP2_TRY_SET(TDuration, ServerInputDeadline)
+        else TCP2_TRY_SET(TDuration, ServerOutputDeadline)
+        else {
+ return false;
+ }
+ return true;
+ }
+}
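+
+//A configuration sketch (illustrative only; the values are arbitrary): the options are plain
+//static fields, normally set once at process start before any requester or handle is created,
+//either directly or through the name/value setter above.
+//
+//  NNeh::TTcp2Options::ConnectTimeout = TDuration::MilliSeconds(150);
+//  NNeh::TTcp2Options::Set("AsioClientThreads", "8");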
+
+namespace {
+ namespace NNehTcp2 {
+ using namespace NAsio;
+ using namespace NDns;
+ using namespace NNeh;
+
+ const TString canceled = "canceled";
+ const TString emptyReply = "empty reply";
+
+ inline void PrepareSocket(SOCKET s) {
+ SetNoDelay(s, true);
+ }
+
+ typedef ui64 TRequestId;
+
+#pragma pack(push, 1) //disable member alignment (these structs are mapped directly onto data transmitted over the network)
+ struct TBaseHeader {
+ enum TMessageType {
+ Request = 1,
+ Response = 2,
+ Cancel = 3,
+ MaxMessageType
+ };
+
+ TBaseHeader(TRequestId id, ui32 headerLength, ui8 version, ui8 mType)
+ : Id(id)
+ , HeaderLength(headerLength)
+ , Version(version)
+ , Type(mType)
+ {
+ }
+
+            TRequestId Id; //message id - a monotonically increasing sequence (the nil value is skipped)
+ ui32 HeaderLength;
+ ui8 Version; //current version: 1
+            ui8 Type; //<- TMessageType (may be extended in the future, e.g. ForceResponse)
+ };
+
+ struct TRequestHeader: public TBaseHeader {
+ TRequestHeader(TRequestId reqId, size_t servicePathLength, size_t dataSize)
+ : TBaseHeader(reqId, sizeof(TRequestHeader) + servicePathLength, 1, (ui8)Request)
+ , ContentLength(dataSize)
+ {
+ }
+
+ ui32 ContentLength;
+ };
+
+ struct TResponseHeader: public TBaseHeader {
+ enum TErrorCode {
+ Success = 0,
+                EmptyReply = 1, //no such service, or the service did not send a response
+ MaxErrorCode
+ };
+
+ TResponseHeader(TRequestId reqId, TErrorCode code, size_t dataSize)
+ : TBaseHeader(reqId, sizeof(TResponseHeader), 1, (ui8)Response)
+ , ErrorCode((ui16)code)
+ , ContentLength(dataSize)
+ {
+ }
+
+ TString ErrorDescription() const {
+ if (ErrorCode == (ui16)EmptyReply) {
+ return emptyReply;
+ }
+
+ TStringStream ss;
+ ss << TStringBuf("tcp2 err_code=") << ErrorCode;
+ return ss.Str();
+ }
+
+ ui16 ErrorCode;
+ ui32 ContentLength;
+ };
+
+ struct TCancelHeader: public TBaseHeader {
+ TCancelHeader(TRequestId reqId)
+ : TBaseHeader(reqId, sizeof(TCancelHeader), 1, (ui8)Cancel)
+ {
+ }
+ };
+#pragma pack(pop)
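+
+        //Frame layout implied by the packed headers above (fields are in host byte order):
+        //  request:  [TRequestHeader][service path: HeaderLength - sizeof(TRequestHeader) bytes][ContentLength bytes of payload]
+        //  response: [TResponseHeader][ContentLength bytes of payload]
+        //  cancel:   [TCancelHeader]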
+
+ static const size_t maxHeaderSize = sizeof(TResponseHeader);
+
+        //buffer for reading input data: header + message data
+ struct TTcp2Message {
+ TTcp2Message()
+ : Loader_(&TTcp2Message::LoadBaseHeader)
+ , RequireBytesForComplete_(sizeof(TBaseHeader))
+ , Header_(sizeof(TBaseHeader))
+ {
+ }
+
+ void Clear() {
+ Loader_ = &TTcp2Message::LoadBaseHeader;
+ RequireBytesForComplete_ = sizeof(TBaseHeader);
+ Header_.Clear();
+ Content_.clear();
+ }
+
+ TBuffer& Header() noexcept {
+ return Header_;
+ }
+
+ const TString& Content() const noexcept {
+ return Content_;
+ }
+
+ bool IsComplete() const noexcept {
+ return RequireBytesForComplete_ == 0;
+ }
+
+ size_t LoadFrom(const char* buf, size_t len) {
+ return (this->*Loader_)(buf, len);
+ }
+
+ const TBaseHeader& BaseHeader() const {
+ return *reinterpret_cast<const TBaseHeader*>(Header_.Data());
+ }
+
+ const TRequestHeader& RequestHeader() const {
+ return *reinterpret_cast<const TRequestHeader*>(Header_.Data());
+ }
+
+ const TResponseHeader& ResponseHeader() const {
+ return *reinterpret_cast<const TResponseHeader*>(Header_.Data());
+ }
+
+ private:
+ size_t LoadBaseHeader(const char* buf, size_t len) {
+ size_t useBytes = Min<size_t>(sizeof(TBaseHeader) - Header_.Size(), len);
+ Header_.Append(buf, useBytes);
+ if (Y_UNLIKELY(sizeof(TBaseHeader) > Header_.Size())) {
+                    //base header is not complete yet
+ return useBytes;
+ }
+ {
+ const TBaseHeader& hdr = BaseHeader();
+ if (BaseHeader().HeaderLength > 32000) { //some heuristic header size limit
+ throw yexception() << TStringBuf("to large neh/tcp2 header size: ") << BaseHeader().HeaderLength;
+ }
+ //header completed
+ Header_.Reserve(hdr.HeaderLength);
+ }
+ const TBaseHeader& hdr = BaseHeader(); //reallocation can move Header_ data to another place, so use fresh 'hdr'
+ if (Y_UNLIKELY(hdr.Version != 1)) {
+ throw yexception() << TStringBuf("unsupported protocol version: ") << static_cast<unsigned>(hdr.Version);
+ }
+ RequireBytesForComplete_ = hdr.HeaderLength - sizeof(TBaseHeader);
+ return useBytes + LoadHeader(buf + useBytes, len - useBytes);
+ }
+
+ size_t LoadHeader(const char* buf, size_t len) {
+ size_t useBytes = Min<size_t>(RequireBytesForComplete_, len);
+ Header_.Append(buf, useBytes);
+ RequireBytesForComplete_ -= useBytes;
+ if (RequireBytesForComplete_) {
+ //continue load header
+ Loader_ = &TTcp2Message::LoadHeader;
+ return useBytes;
+ }
+
+ const TBaseHeader& hdr = *reinterpret_cast<const TBaseHeader*>(Header_.Data());
+
+ if (hdr.Type == TBaseHeader::Request) {
+ if (Header_.Size() < sizeof(TRequestHeader)) {
+ throw yexception() << TStringBuf("invalid request header size");
+ }
+ InitContentLoading(RequestHeader().ContentLength);
+ } else if (hdr.Type == TBaseHeader::Response) {
+ if (Header_.Size() < sizeof(TResponseHeader)) {
+ throw yexception() << TStringBuf("invalid response header size");
+ }
+ InitContentLoading(ResponseHeader().ContentLength);
+ } else if (hdr.Type == TBaseHeader::Cancel) {
+ if (Header_.Size() < sizeof(TCancelHeader)) {
+ throw yexception() << TStringBuf("invalid cancel header size");
+ }
+ return useBytes;
+ } else {
+ throw yexception() << TStringBuf("unsupported request type: ") << static_cast<unsigned>(hdr.Type);
+ }
+ return useBytes + (this->*Loader_)(buf + useBytes, len - useBytes);
+ }
+
+ void InitContentLoading(size_t contentLength) {
+ RequireBytesForComplete_ = contentLength;
+ Content_.ReserveAndResize(contentLength);
+ Loader_ = &TTcp2Message::LoadContent;
+ }
+
+ size_t LoadContent(const char* buf, size_t len) {
+ size_t curContentSize = Content_.size() - RequireBytesForComplete_;
+ size_t useBytes = Min<size_t>(RequireBytesForComplete_, len);
+ memcpy(Content_.begin() + curContentSize, buf, useBytes);
+ RequireBytesForComplete_ -= useBytes;
+ return useBytes;
+ }
+
+ private:
+ typedef size_t (TTcp2Message::*TLoader)(const char*, size_t);
+
+ TLoader Loader_; //current loader (stages - base-header/header/content)
+ size_t RequireBytesForComplete_;
+ TBuffer Header_;
+ TString Content_;
+ };
+
+ //base storage for output data
+ class TMultiBuffers {
+ public:
+ TMultiBuffers()
+ : IOVec_(nullptr, 0)
+ , DataSize_(0)
+ , PoolBytes_(0)
+ {
+ }
+
+ void Clear() noexcept {
+ Parts_.clear();
+ DataSize_ = 0;
+ PoolBytes_ = 0;
+ }
+
+ bool HasFreeSpace() const noexcept {
+ return DataSize_ < 64000 && (PoolBytes_ < (MemPoolSize_ - maxHeaderSize));
+ }
+
+ bool HasData() const noexcept {
+ return Parts_.size();
+ }
+
+ TContIOVector* GetIOvec() noexcept {
+ return &IOVec_;
+ }
+
+ protected:
+ void AddPart(const void* buf, size_t len) {
+ Parts_.push_back(IOutputStream::TPart(buf, len));
+ DataSize_ += len;
+ }
+
+            //used to allocate a header (T MUST be a POD type)
+ template <typename T>
+ inline T* Allocate() noexcept {
+ size_t poolBytes = PoolBytes_;
+ PoolBytes_ += sizeof(T);
+ return (T*)(MemPool_ + poolBytes);
+ }
+
+            //used to allocate a header (T MUST be a POD type) plus some tail bytes
+ template <typename T>
+ inline T* AllocatePlus(size_t tailSize) noexcept {
+ Y_ASSERT(tailSize <= MemPoolReserve_);
+ size_t poolBytes = PoolBytes_;
+ PoolBytes_ += sizeof(T) + tailSize;
+ return (T*)(MemPool_ + poolBytes);
+ }
+
+ protected:
+ TContIOVector IOVec_;
+ TVector<IOutputStream::TPart> Parts_;
+ static const size_t MemPoolSize_ = maxHeaderSize * 100;
+ static const size_t MemPoolReserve_ = 32;
+ size_t DataSize_;
+ size_t PoolBytes_;
+ char MemPool_[MemPoolSize_ + MemPoolReserve_];
+ };
+
+        //guard that limits use of the tcp connection output (and its associated data) to one thread at a time
+ class TOutputLock {
+ public:
+ TOutputLock() noexcept
+ : Lock_(0)
+ {
+ }
+
+ bool TryAquire() noexcept {
+ do {
+ if (AtomicTryLock(&Lock_)) {
+ return true;
+ }
+                } while (!AtomicGet(Lock_)); //retry while the lock looks free: a single failed AtomicTryLock is not reliable enough here
+ return false;
+ }
+
+ void Release() noexcept {
+ AtomicUnlock(&Lock_);
+ }
+
+ bool IsFree() const noexcept {
+ return !AtomicGet(Lock_);
+ }
+
+ private:
+ TAtomic Lock_;
+ };
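+
+        //Typical usage in this file (see SendMessages below): a producer enqueues its data, raises a
+        //"need check" flag and calls TryAquire(); the winner drains the queues, releases the lock and
+        //re-checks the flags in a loop, so a producer that loses TryAquire() may simply return - the
+        //current owner is guaranteed to pick up its data.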
+
+ class TClient {
+ class TRequest;
+ class TConnection;
+ typedef TIntrusivePtr<TRequest> TRequestRef;
+ typedef TIntrusivePtr<TConnection> TConnectionRef;
+
+ class TRequest: public TThrRefBase, public TNonCopyable {
+ public:
+ class THandle: public TSimpleHandle {
+ public:
+ THandle(IOnRecv* f, const TMessage& msg, TStatCollector* s) noexcept
+ : TSimpleHandle(f, msg, s)
+ {
+ }
+
+ bool MessageSendedCompletely() const noexcept override {
+ if (TSimpleHandle::MessageSendedCompletely()) {
+ return true;
+ }
+
+ TRequestRef req = GetRequest();
+ if (!!req && req->RequestSendedCompletely()) {
+ const_cast<THandle*>(this)->SetSendComplete();
+ }
+
+ return TSimpleHandle::MessageSendedCompletely();
+ }
+
+ void Cancel() noexcept override {
+ if (TSimpleHandle::Canceled()) {
+ return;
+ }
+
+ TRequestRef req = GetRequest();
+ if (!!req) {
+ req->Cancel();
+ TSimpleHandle::Cancel();
+ }
+ }
+
+ void NotifyResponse(const TString& resp) {
+ TNotifyHandle::NotifyResponse(resp);
+
+ ReleaseRequest();
+ }
+
+ void NotifyError(const TString& error) {
+ TNotifyHandle::NotifyError(error);
+
+ ReleaseRequest();
+ }
+
+ void NotifyError(TErrorRef error) {
+ TNotifyHandle::NotifyError(error);
+
+ ReleaseRequest();
+ }
+
+ //not thread safe!
+ void SetRequest(const TRequestRef& r) noexcept {
+ Req_ = r;
+ }
+
+ void ReleaseRequest() noexcept {
+ TRequestRef tmp;
+ TGuard<TSpinLock> g(SP_);
+ tmp.Swap(Req_);
+ }
+
+ private:
+ TRequestRef GetRequest() const noexcept {
+ TGuard<TSpinLock> g(SP_);
+ return Req_;
+ }
+
+ mutable TSpinLock SP_;
+ TRequestRef Req_;
+ };
+
+ typedef TIntrusivePtr<THandle> THandleRef;
+
+ static void Run(THandleRef& h, const TMessage& msg, TClient& clnt) {
+ TRequestRef req(new TRequest(h, msg, clnt));
+ h->SetRequest(req);
+ req->Run(req);
+ }
+
+ ~TRequest() override {
+ DBGOUT("TClient::~TRequest()");
+ }
+
+ private:
+ TRequest(THandleRef& h, TMessage msg, TClient& clnt)
+ : Hndl_(h)
+ , Clnt_(clnt)
+ , Msg_(std::move(msg))
+ , Loc_(Msg_.Addr)
+ , Addr_(CachedResolve(TResolveInfo(Loc_.Host, Loc_.GetPort())))
+ , Canceled_(false)
+ , Id_(0)
+ {
+ DBGOUT("TClient::TRequest()");
+ }
+
+ void Run(TRequestRef& req) {
+ TDestination& dest = Clnt_.Dest_.Get(Addr_->Id);
+ dest.Run(req);
+ }
+
+ public:
+ void OnResponse(TTcp2Message& msg) {
+ DBGOUT("TRequest::OnResponse: " << msg.ResponseHeader().Id);
+ THandleRef h = ReleaseHandler();
+ if (!h) {
+ return;
+ }
+
+ const TResponseHeader& respHdr = msg.ResponseHeader();
+ if (Y_LIKELY(!respHdr.ErrorCode)) {
+ h->NotifyResponse(msg.Content());
+ } else {
+ h->NotifyError(new TError(respHdr.ErrorDescription(), TError::ProtocolSpecific, respHdr.ErrorCode));
+ }
+ ReleaseConn();
+ }
+
+ void OnError(const TString& err, const i32 systemCode = 0) {
+ DBGOUT("TRequest::OnError: " << Id_.load(std::memory_order_acquire));
+ THandleRef h = ReleaseHandler();
+ if (!h) {
+ return;
+ }
+
+ h->NotifyError(new TError(err, TError::UnknownType, 0, systemCode));
+ ReleaseConn();
+ }
+
+ void SetConnection(TConnection* conn) noexcept {
+ auto g = Guard(AL_);
+ Conn_ = conn;
+ }
+
+ bool Canceled() const noexcept {
+ return Canceled_;
+ }
+
+ const TResolvedHost* Addr() const noexcept {
+ return Addr_;
+ }
+
+ TStringBuf Service() const noexcept {
+ return Loc_.Service;
+ }
+
+ const TString& Data() const noexcept {
+ return Msg_.Data;
+ }
+
+ TClient& Client() noexcept {
+ return Clnt_;
+ }
+
+ bool RequestSendedCompletely() const noexcept {
+ if (Id_.load(std::memory_order_acquire) == 0) {
+ return false;
+ }
+
+ TConnectionRef conn = GetConn();
+ if (!conn) {
+ return false;
+ }
+
+ TRequestId lastSendedReqId = conn->LastSendedRequestId();
+ if (lastSendedReqId >= Id_.load(std::memory_order_acquire)) {
+ return true;
+ } else if (Y_UNLIKELY((Id_.load(std::memory_order_acquire) - lastSendedReqId) > (Max<TRequestId>() - Max<ui32>()))) {
+ //overflow req-id value
+ return true;
+ }
+ return false;
+ }
+
+ void Cancel() noexcept {
+ Canceled_ = true;
+ THandleRef h = ReleaseHandler();
+ if (!h) {
+ return;
+ }
+
+ TConnectionRef conn = ReleaseConn();
+ if (!!conn && Id_.load(std::memory_order_acquire)) {
+ conn->Cancel(Id_.load(std::memory_order_acquire));
+ }
+ h->NotifyError(new TError(canceled, TError::Cancelled));
+ }
+
+ void SetReqId(TRequestId reqId) noexcept {
+ auto guard = Guard(IdLock_);
+ Id_.store(reqId, std::memory_order_release);
+ }
+
+ TRequestId ReqId() const noexcept {
+ return Id_.load(std::memory_order_acquire);
+ }
+
+ private:
+ inline THandleRef ReleaseHandler() noexcept {
+ THandleRef h;
+ {
+ auto g = Guard(AL_);
+ h.Swap(Hndl_);
+ }
+ return h;
+ }
+
+ inline TConnectionRef GetConn() const noexcept {
+ auto g = Guard(AL_);
+ return Conn_;
+ }
+
+ inline TConnectionRef ReleaseConn() noexcept {
+ TConnectionRef c;
+ {
+ auto g = Guard(AL_);
+ c.Swap(Conn_);
+ }
+ return c;
+ }
+
+                mutable TAdaptiveLock AL_; //guarantees notify() is called only once (prevents a race between the asio thread and the current one)
+ THandleRef Hndl_;
+ TClient& Clnt_;
+ const TMessage Msg_;
+ const TParsedLocation Loc_;
+ const TResolvedHost* Addr_;
+ TConnectionRef Conn_;
+ NAtomic::TBool Canceled_;
+ TSpinLock IdLock_;
+ std::atomic<TRequestId> Id_;
+ };
+
+ class TConnection: public TThrRefBase {
+ enum TState {
+ Init,
+ Connecting,
+ Connected,
+ Closed,
+ MaxState
+ };
+ typedef THashMap<TRequestId, TRequestRef> TReqsInFly;
+
+ public:
+ class TOutputBuffers: public TMultiBuffers {
+ public:
+ void AddRequest(const TRequestRef& req) {
+ Requests_.push_back(req);
+ if (req->Service().size() > MemPoolReserve_) {
+ TRequestHeader* hdr = new (Allocate<TRequestHeader>()) TRequestHeader(req->ReqId(), req->Service().size(), req->Data().size());
+ AddPart(hdr, sizeof(TRequestHeader));
+ AddPart(req->Service().data(), req->Service().size());
+ } else {
+ TRequestHeader* hdr = new (AllocatePlus<TRequestHeader>(req->Service().size())) TRequestHeader(req->ReqId(), req->Service().size(), req->Data().size());
+ AddPart(hdr, sizeof(TRequestHeader) + req->Service().size());
+ memmove(++hdr, req->Service().data(), req->Service().size());
+ }
+ AddPart(req->Data().data(), req->Data().size());
+ IOVec_ = TContIOVector(Parts_.data(), Parts_.size());
+ }
+
+ void AddCancelRequest(TRequestId reqId) {
+ TCancelHeader* hdr = new (Allocate<TCancelHeader>()) TCancelHeader(reqId);
+ AddPart(hdr, sizeof(TCancelHeader));
+ IOVec_ = TContIOVector(Parts_.data(), Parts_.size());
+ }
+
+ void Clear() {
+ TMultiBuffers::Clear();
+ Requests_.clear();
+ }
+
+ private:
+ TVector<TRequestRef> Requests_;
+ };
+
+ TConnection(TIOService& srv)
+ : AS_(srv)
+ , State_(Init)
+ , BuffSize_(TTcp2Options::InputBufferSize)
+ , Buff_(new char[BuffSize_])
+ , NeedCheckReqsQueue_(0)
+ , NeedCheckCancelsQueue_(0)
+ , GenReqId_(0)
+ , LastSendedReqId_(0)
+ {
+ }
+
+ ~TConnection() override {
+ try {
+ DBGOUT("TClient::~TConnection()");
+ OnError("~");
+ } catch (...) {
+ Cdbg << "tcp2::~cln_conn: " << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ //called from client thread
+ bool Run(TRequestRef& req) {
+ if (Y_UNLIKELY(AtomicGet(State_) == Closed)) {
+ return false;
+ }
+
+ req->Ref();
+ try {
+ Reqs_.Enqueue(req.Get());
+ } catch (...) {
+ req->UnRef();
+ throw;
+ }
+
+ AtomicSet(NeedCheckReqsQueue_, 1);
+ req->SetConnection(this);
+ TAtomicBase state = AtomicGet(State_);
+ if (Y_LIKELY(state == Connected)) {
+ ProcessOutputReqsQueue();
+ return true;
+ }
+
+ if (state == Init) {
+ if (AtomicCas(&State_, Connecting, Init)) {
+ try {
+ TEndpoint addr(new NAddr::TAddrInfo(&*req->Addr()->Addr.Begin()));
+ AS_.AsyncConnect(addr, std::bind(&TConnection::OnConnect, TConnectionRef(this), _1, _2), TTcp2Options::ConnectTimeout);
+ } catch (...) {
+ AS_.GetIOService().Post(std::bind(&TConnection::OnErrorCallback, TConnectionRef(this), CurrentExceptionMessage()));
+ }
+ return true;
+ }
+ }
+ state = AtomicGet(State_);
+ if (state == Connected) {
+ ProcessOutputReqsQueue();
+ } else if (state == Closed) {
+ SafeOnError();
+ }
+ return true;
+ }
+
+ //called from client thread
+ void Cancel(TRequestId id) {
+ Cancels_.Enqueue(id);
+ AtomicSet(NeedCheckCancelsQueue_, 1);
+ if (Y_LIKELY(AtomicGet(State_) == Connected)) {
+ ProcessOutputCancelsQueue();
+ }
+ }
+
+ void ProcessOutputReqsQueue() {
+ if (OutputLock_.TryAquire()) {
+ SendMessages(false);
+ }
+ }
+
+ void ProcessOutputCancelsQueue() {
+ if (OutputLock_.TryAquire()) {
+ AS_.GetIOService().Post(std::bind(&TConnection::SendMessages, TConnectionRef(this), true));
+ return;
+ }
+ }
+
+ //must be called only from asio thread
+ void ProcessReqsInFlyQueue() {
+ if (AtomicGet(State_) == Closed) {
+ return;
+ }
+
+ TRequest* reqPtr;
+
+ while (ReqsInFlyQueue_.Dequeue(&reqPtr)) {
+ TRequestRef reqTmp(reqPtr);
+ reqPtr->UnRef();
+ ReqsInFly_[reqPtr->ReqId()].Swap(reqTmp);
+ }
+ }
+
+ //must be called only from asio thread
+ void OnConnect(const TErrorCode& ec, IHandlingContext&) {
+ DBGOUT("TConnect::OnConnect: " << ec.Value());
+ if (Y_UNLIKELY(ec)) {
+ if (ec.Value() == EIO) {
+                        //try to get more detailed error info
+ char buf[1];
+ TErrorCode errConnect;
+ AS_.ReadSome(buf, 1, errConnect);
+ OnErrorCode(errConnect.Value() ? errConnect : ec);
+ } else {
+ OnErrorCode(ec);
+ }
+ } else {
+ try {
+ PrepareSocket(AS_.Native());
+ AtomicSet(State_, Connected);
+ AS_.AsyncPollRead(std::bind(&TConnection::OnCanRead, TConnectionRef(this), _1, _2));
+ if (OutputLock_.TryAquire()) {
+ SendMessages(true);
+ return;
+ }
+ } catch (...) {
+ OnError(CurrentExceptionMessage());
+ }
+ }
+ }
+
+            //must be called only after successfully acquiring the output lock
+ void SendMessages(bool asioThread) {
+ //DBGOUT("SendMessages");
+ if (Y_UNLIKELY(AtomicGet(State_) == Closed)) {
+ if (asioThread) {
+ OnError(Error_);
+ } else {
+ SafeOnError();
+ }
+ return;
+ }
+
+ do {
+ if (asioThread) {
+ AtomicSet(NeedCheckCancelsQueue_, 0);
+ TRequestId reqId;
+
+ ProcessReqsInFlyQueue();
+ while (Cancels_.Dequeue(&reqId)) {
+ TReqsInFly::iterator it = ReqsInFly_.find(reqId);
+ if (it == ReqsInFly_.end()) {
+ continue;
+ }
+
+ ReqsInFly_.erase(it);
+ OutputBuffers_.AddCancelRequest(reqId);
+ if (Y_UNLIKELY(!OutputBuffers_.HasFreeSpace())) {
+ if (!FlushOutputBuffers(asioThread, 0)) {
+ return;
+ }
+ }
+ }
+ } else if (AtomicGet(NeedCheckCancelsQueue_)) {
+ AS_.GetIOService().Post(std::bind(&TConnection::SendMessages, TConnectionRef(this), true));
+ return;
+ }
+
+ TRequestId lastReqId = 0;
+ {
+ AtomicSet(NeedCheckReqsQueue_, 0);
+ TRequest* reqPtr;
+
+ while (Reqs_.Dequeue(&reqPtr)) {
+ TRequestRef reqTmp(reqPtr);
+ reqPtr->UnRef();
+ reqPtr->SetReqId(GenerateReqId());
+ if (reqPtr->Canceled()) {
+ continue;
+ }
+ lastReqId = reqPtr->ReqId();
+ if (asioThread) {
+ TRequestRef& req = ReqsInFly_[(TRequestId)reqPtr->ReqId()];
+ req.Swap(reqTmp);
+ OutputBuffers_.AddRequest(req);
+                        } else { //ReqsInFly_ may only be accessed from the asio thread, so enqueue the request to update ReqsInFly_ later
+ try {
+ reqTmp->Ref();
+ ReqsInFlyQueue_.Enqueue(reqPtr);
+ } catch (...) {
+ reqTmp->UnRef();
+ throw;
+ }
+ OutputBuffers_.AddRequest(reqTmp);
+ }
+ if (Y_UNLIKELY(!OutputBuffers_.HasFreeSpace())) {
+ if (!FlushOutputBuffers(asioThread, lastReqId)) {
+ return;
+ }
+ }
+ }
+ }
+
+ if (OutputBuffers_.HasData()) {
+ if (!FlushOutputBuffers(asioThread, lastReqId)) {
+ return;
+ }
+ }
+
+ OutputLock_.Release();
+
+ if (!AtomicGet(NeedCheckReqsQueue_) && !AtomicGet(NeedCheckCancelsQueue_)) {
+ DBGOUT("TClient::SendMessages(exit2)");
+ return;
+ }
+ } while (OutputLock_.TryAquire());
+ DBGOUT("TClient::SendMessages(exit1)");
+ }
+
+ TRequestId GenerateReqId() noexcept {
+ TRequestId reqId;
+ {
+ auto guard = Guard(GenReqIdLock_);
+ reqId = ++GenReqId_;
+ }
+ return Y_LIKELY(reqId) ? reqId : GenerateReqId();
+ }
+
+            //not thread-safe; called from an outside thread
+ bool FlushOutputBuffers(bool asioThread, TRequestId reqId) {
+ if (asioThread || TTcp2Options::ClientUseDirectWrite) {
+ TContIOVector& vec = *OutputBuffers_.GetIOvec();
+ TErrorCode err;
+ vec.Proceed(AS_.WriteSome(vec, err));
+
+ if (Y_UNLIKELY(err)) {
+ if (asioThread) {
+ OnErrorCode(err);
+ } else {
+ AS_.GetIOService().Post(std::bind(&TConnection::OnErrorCode, TConnectionRef(this), err));
+ }
+ return false;
+ }
+
+ if (vec.Complete()) {
+ LastSendedReqId_.store(reqId, std::memory_order_release);
+ DBGOUT("Client::FlushOutputBuffers(" << reqId << ")");
+ OutputBuffers_.Clear();
+ return true;
+ }
+ }
+
+ DBGOUT("Client::AsyncWrite(" << reqId << ")");
+ AS_.AsyncWrite(OutputBuffers_.GetIOvec(), std::bind(&TConnection::OnSend, TConnectionRef(this), reqId, _1, _2, _3), TTcp2Options::ServerOutputDeadline);
+ return false;
+ }
+
+ //must be called only from asio thread
+ void OnSend(TRequestId reqId, const TErrorCode& ec, size_t amount, IHandlingContext&) {
+ Y_UNUSED(amount);
+ if (Y_UNLIKELY(ec)) {
+ OnErrorCode(ec);
+ } else {
+ if (Y_LIKELY(reqId)) {
+ DBGOUT("Client::OnSend(" << reqId << ")");
+ LastSendedReqId_.store(reqId, std::memory_order_release);
+ }
+                    //output lock already acquired; running in the asio thread
+ OutputBuffers_.Clear();
+ SendMessages(true);
+ }
+ }
+
+ //must be called only from asio thread
+ void OnCanRead(const TErrorCode& ec, IHandlingContext& ctx) {
+ //DBGOUT("OnCanRead(" << ec.Value() << ")");
+ if (Y_UNLIKELY(ec)) {
+ OnErrorCode(ec);
+ } else {
+ TErrorCode ec2;
+ OnReadSome(ec2, AS_.ReadSome(Buff_.Get(), BuffSize_, ec2), ctx);
+ }
+ }
+
+ //must be called only from asio thread
+ void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) {
+ //DBGOUT("OnReadSome(" << ec.Value() << ", " << amount << ")");
+ if (Y_UNLIKELY(ec)) {
+ OnErrorCode(ec);
+
+ return;
+ }
+
+ while (1) {
+ if (Y_UNLIKELY(!amount)) {
+ OnError("tcp conn. closed");
+
+ return;
+ }
+
+ try {
+ const char* buff = Buff_.Get();
+ size_t leftBytes = amount;
+ do {
+ size_t useBytes = Msg_.LoadFrom(buff, leftBytes);
+ leftBytes -= useBytes;
+ buff += useBytes;
+ if (Msg_.IsComplete()) {
+ //DBGOUT("OnReceiveMessage(" << Msg_.BaseHeader().Id << "): " << leftBytes);
+ OnReceiveMessage();
+ Msg_.Clear();
+ }
+ } while (leftBytes);
+
+ if (amount == BuffSize_) {
+                        //try to reduce system calls: re-run ReadSome if the buffer was completely filled
+ TErrorCode ecR;
+ amount = AS_.ReadSome(Buff_.Get(), BuffSize_, ecR);
+ if (!ecR) {
+ continue; //process next input data
+ }
+ if (ecR.Value() == EAGAIN || ecR.Value() == EWOULDBLOCK) {
+ ctx.ContinueUseHandler();
+ } else {
+                            OnErrorCode(ecR);
+ }
+ } else {
+ ctx.ContinueUseHandler();
+ }
+ } catch (...) {
+ OnError(CurrentExceptionMessage());
+ }
+
+ return;
+ }
+ }
+
+ //must be called only from asio thread
+ void OnErrorCode(TErrorCode ec) {
+ OnError(ec.Text(), ec.Value());
+ }
+
+ //must be called only from asio thread
+ void OnErrorCallback(TString err) {
+ OnError(err);
+ }
+
+ //must be called only from asio thread
+ void OnError(const TString& err, const i32 systemCode = 0) {
+ if (AtomicGet(State_) != Closed) {
+ Error_ = err;
+ SystemCode_ = systemCode;
+ AtomicSet(State_, Closed);
+ AS_.AsyncCancel();
+ }
+ SafeOnError();
+ for (auto& it : ReqsInFly_) {
+ it.second->OnError(err);
+ }
+ ReqsInFly_.clear();
+ }
+
+ void SafeOnError() {
+ TRequest* reqPtr;
+
+ while (Reqs_.Dequeue(&reqPtr)) {
+ TRequestRef req(reqPtr);
+ reqPtr->UnRef();
+ //DBGOUT("err queue(" << AS_.Native() << "):" << size_t(reqPtr));
+ req->OnError(Error_, SystemCode_);
+ }
+
+ while (ReqsInFlyQueue_.Dequeue(&reqPtr)) {
+ TRequestRef req(reqPtr);
+ reqPtr->UnRef();
+ //DBGOUT("err fly queue(" << AS_.Native() << "):" << size_t(reqPtr));
+ req->OnError(Error_, SystemCode_);
+ }
+ }
+
+ //must be called only from asio thread
+ void OnReceiveMessage() {
+ //DBGOUT("OnReceiveMessage");
+ const TBaseHeader& hdr = Msg_.BaseHeader();
+
+ if (hdr.Type == TBaseHeader::Response) {
+ ProcessReqsInFlyQueue();
+ TReqsInFly::iterator it = ReqsInFly_.find(hdr.Id);
+ if (it == ReqsInFly_.end()) {
+ DBGOUT("ignore response: " << hdr.Id);
+ return;
+ }
+
+ it->second->OnResponse(Msg_);
+ ReqsInFly_.erase(it);
+ } else {
+ throw yexception() << TStringBuf("unsupported message type: ") << hdr.Type;
+ }
+ }
+
+ TRequestId LastSendedRequestId() const noexcept {
+ return LastSendedReqId_.load(std::memory_order_acquire);
+ }
+
+ private:
+ NAsio::TTcpSocket AS_;
+ TAtomic State_; //state machine status (TState)
+ TString Error_;
+ i32 SystemCode_ = 0;
+
+ //input
+ size_t BuffSize_;
+ TArrayHolder<char> Buff_;
+ TTcp2Message Msg_;
+
+ //output
+ TOutputLock OutputLock_;
+ TAtomic NeedCheckReqsQueue_;
+ TLockFreeQueue<TRequest*> Reqs_;
+ TAtomic NeedCheckCancelsQueue_;
+ TLockFreeQueue<TRequestId> Cancels_;
+ TAdaptiveLock GenReqIdLock_;
+ std::atomic<TRequestId> GenReqId_;
+ std::atomic<TRequestId> LastSendedReqId_;
+ TLockFreeQueue<TRequest*> ReqsInFlyQueue_;
+ TReqsInFly ReqsInFly_;
+ TOutputBuffers OutputBuffers_;
+ };
+
+ class TDestination {
+ public:
+ void Run(TRequestRef& req) {
+ while (1) {
+ TConnectionRef conn = GetConnection();
+ if (!!conn && conn->Run(req)) {
+ return;
+ }
+
+ DBGOUT("TDestination CreateConnection");
+ CreateConnection(conn, req->Client().ExecutorsPool().GetExecutor().GetIOService());
+ }
+ }
+
+ private:
+ TConnectionRef GetConnection() {
+ TGuard<TSpinLock> g(L_);
+ return Conn_;
+ }
+
+ void CreateConnection(TConnectionRef& oldConn, TIOService& srv) {
+ TConnectionRef conn(new TConnection(srv));
+ TGuard<TSpinLock> g(L_);
+ if (Conn_ == oldConn) {
+ Conn_.Swap(conn);
+ }
+ }
+
+ TSpinLock L_;
+ TConnectionRef Conn_;
+ };
+
+ //////////// TClient /////////
+
+ public:
+ TClient()
+ : EP_(TTcp2Options::AsioClientThreads)
+ {
+ }
+
+ ~TClient() {
+ EP_.SyncShutdown();
+ }
+
+ THandleRef Schedule(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) {
+                //find an existing connection or create a new one
+ TRequest::THandleRef hndl(new TRequest::THandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));
+ try {
+ TRequest::Run(hndl, msg, *this);
+ } catch (...) {
+ hndl->ResetOnRecv();
+ hndl->ReleaseRequest();
+ throw;
+ }
+ return hndl.Get();
+ }
+
+ TExecutorsPool& ExecutorsPool() {
+ return EP_;
+ }
+
+ private:
+ NNeh::NHttp::TLockFreeSequence<TDestination> Dest_;
+ TExecutorsPool EP_;
+ };
+
+ ////////// server side ////////////////////////////////////////////////////////////////////////////////////////////
+
+ class TServer: public IRequester {
+ typedef TAutoPtr<TTcpAcceptor> TTcpAcceptorPtr;
+ typedef TAtomicSharedPtr<TTcpSocket> TTcpSocketRef;
+ class TConnection;
+ typedef TIntrusivePtr<TConnection> TConnectionRef;
+
+ struct TRequest: public IRequest {
+ struct TState: public TThrRefBase {
+ TState()
+ : Canceled(false)
+ {
+ }
+
+ TAtomicBool Canceled;
+ };
+ typedef TIntrusivePtr<TState> TStateRef;
+
+ TRequest(const TConnectionRef& conn, TBuffer& buf, const TString& content);
+ ~TRequest() override;
+
+ TStringBuf Scheme() const override {
+ return TStringBuf("tcp2");
+ }
+
+ TString RemoteHost() const override;
+
+ TStringBuf Service() const override {
+ return TStringBuf(Buf.Data() + sizeof(TRequestHeader), Buf.End());
+ }
+
+ TStringBuf Data() const override {
+ return TStringBuf(Content_);
+ }
+
+ TStringBuf RequestId() const override {
+ return TStringBuf();
+ }
+
+ bool Canceled() const override {
+ return State->Canceled;
+ }
+
+ void SendReply(TData& data) override;
+
+ void SendError(TResponseError, const TString&) override {
+ // TODO
+ }
+
+ const TRequestHeader& RequestHeader() const noexcept {
+ return *reinterpret_cast<const TRequestHeader*>(Buf.Data());
+ }
+
+ private:
+ TConnectionRef Conn;
+ TBuffer Buf; //service-name + message-data
+ TString Content_;
+ TAtomic Replied_;
+
+ public:
+ TIntrusivePtr<TState> State;
+ };
+
+ class TConnection: public TThrRefBase {
+ private:
+ TConnection(TServer& srv, const TTcpSocketRef& sock)
+ : Srv_(srv)
+ , AS_(sock)
+ , Canceled_(false)
+ , RemoteHost_(NNeh::PrintHostByRfc(*AS_->RemoteEndpoint().Addr()))
+ , BuffSize_(TTcp2Options::InputBufferSize)
+ , Buff_(new char[BuffSize_])
+ , NeedCheckOutputQueue_(0)
+ {
+ DBGOUT("TServer::TConnection()");
+ }
+
+ public:
+ class TOutputBuffers: public TMultiBuffers {
+ public:
+ void AddResponse(TRequestId reqId, TData& data) {
+ TResponseHeader* hdr = new (Allocate<TResponseHeader>()) TResponseHeader(reqId, TResponseHeader::Success, data.size());
+ ResponseData_.push_back(TAutoPtr<TData>(new TData()));
+ TData& movedData = *ResponseData_.back();
+ movedData.swap(data);
+ AddPart(hdr, sizeof(TResponseHeader));
+ AddPart(movedData.data(), movedData.size());
+ IOVec_ = TContIOVector(Parts_.data(), Parts_.size());
+ }
+
+ void AddError(TRequestId reqId, TResponseHeader::TErrorCode errCode) {
+ TResponseHeader* hdr = new (Allocate<TResponseHeader>()) TResponseHeader(reqId, errCode, 0);
+ AddPart(hdr, sizeof(TResponseHeader));
+ IOVec_ = TContIOVector(Parts_.data(), Parts_.size());
+ }
+
+ void Clear() {
+ TMultiBuffers::Clear();
+ ResponseData_.clear();
+ }
+
+ private:
+ TVector<TAutoPtr<TData>> ResponseData_;
+ };
+
+ static void Create(TServer& srv, const TTcpSocketRef& sock) {
+ TConnectionRef conn(new TConnection(srv, sock));
+ conn->AS_->AsyncPollRead(std::bind(&TConnection::OnCanRead, conn, _1, _2), TTcp2Options::ServerInputDeadline);
+ }
+
+ ~TConnection() override {
+ DBGOUT("~TServer::TConnection(" << (!AS_ ? -666 : AS_->Native()) << ")");
+ }
+
+ private:
+ void OnCanRead(const TErrorCode& ec, IHandlingContext& ctx) {
+ if (ec) {
+ OnError();
+ } else {
+ TErrorCode ec2;
+ OnReadSome(ec2, AS_->ReadSome(Buff_.Get(), BuffSize_, ec2), ctx);
+ }
+ }
+
+ void OnError() {
+ DBGOUT("Srv OnError(" << (!AS_ ? -666 : AS_->Native()) << ")"
+ << " c=" << (size_t)this);
+ Canceled_ = true;
+ AS_->AsyncCancel();
+ }
+
+ void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) {
+ while (1) {
+ if (ec || !amount) {
+ OnError();
+ return;
+ }
+
+ try {
+ const char* buff = Buff_.Get();
+ size_t leftBytes = amount;
+ do {
+ size_t useBytes = Msg_.LoadFrom(buff, leftBytes);
+ leftBytes -= useBytes;
+ buff += useBytes;
+ if (Msg_.IsComplete()) {
+ OnReceiveMessage();
+ }
+ } while (leftBytes);
+
+ if (amount == BuffSize_) {
+                        //try to reduce system calls: re-run ReadSome if the buffer was completely filled
+ TErrorCode ecR;
+ amount = AS_->ReadSome(Buff_.Get(), BuffSize_, ecR);
+ if (!ecR) {
+ continue;
+ }
+ if (ecR.Value() == EAGAIN || ecR.Value() == EWOULDBLOCK) {
+ ctx.ContinueUseHandler();
+ } else {
+ OnError();
+ }
+ } else {
+ ctx.ContinueUseHandler();
+ }
+ } catch (...) {
+ DBGOUT("exc. " << CurrentExceptionMessage());
+ OnError();
+ }
+ return;
+ }
+ }
+
+ void OnReceiveMessage() {
+ DBGOUT("OnReceiveMessage()");
+ const TBaseHeader& hdr = Msg_.BaseHeader();
+
+ if (hdr.Type == TBaseHeader::Request) {
+ TRequest* reqPtr = new TRequest(TConnectionRef(this), Msg_.Header(), Msg_.Content());
+ IRequestRef req(reqPtr);
+ ReqsState_[reqPtr->RequestHeader().Id] = reqPtr->State;
+ OnRequest(req);
+ } else if (hdr.Type == TBaseHeader::Cancel) {
+ OnCancelRequest(hdr.Id);
+ } else {
+ throw yexception() << "unsupported message type: " << (ui32)hdr.Type;
+ }
+ Msg_.Clear();
+ {
+ TRequestId reqId;
+ while (FinReqs_.Dequeue(&reqId)) {
+ ReqsState_.erase(reqId);
+ }
+ }
+ }
+
+ void OnRequest(IRequestRef& r) {
+ DBGOUT("OnRequest()");
+ Srv_.OnRequest(r);
+ }
+
+ void OnCancelRequest(TRequestId reqId) {
+ THashMap<TRequestId, TRequest::TStateRef>::iterator it = ReqsState_.find(reqId);
+ if (it == ReqsState_.end()) {
+ return;
+ }
+
+ it->second->Canceled = true;
+ }
+
+ public:
+ class TOutputData {
+ public:
+ TOutputData(TRequestId reqId)
+ : ReqId(reqId)
+ {
+ }
+
+ virtual ~TOutputData() {
+ }
+
+ virtual void MoveTo(TOutputBuffers& bufs) = 0;
+
+ TRequestId ReqId;
+ };
+
+ class TResponseData: public TOutputData {
+ public:
+ TResponseData(TRequestId reqId, TData& data)
+ : TOutputData(reqId)
+ {
+ Data.swap(data);
+ }
+
+ void MoveTo(TOutputBuffers& bufs) override {
+ bufs.AddResponse(ReqId, Data);
+ }
+
+ TData Data;
+ };
+
+ class TResponseErrorData: public TOutputData {
+ public:
+ TResponseErrorData(TRequestId reqId, TResponseHeader::TErrorCode errorCode)
+ : TOutputData(reqId)
+ , ErrorCode(errorCode)
+ {
+ }
+
+ void MoveTo(TOutputBuffers& bufs) override {
+ bufs.AddError(ReqId, ErrorCode);
+ }
+
+ TResponseHeader::TErrorCode ErrorCode;
+ };
+
+        //called without external synchronization (from the client thread)
+ void SendResponse(TRequestId reqId, TData& data) {
+ DBGOUT("SendResponse: " << reqId << " " << (size_t)~data << " c=" << (size_t)this);
+ TAutoPtr<TOutputData> od(new TResponseData(reqId, data));
+ OutputData_.Enqueue(od);
+ ProcessOutputQueue();
+ }
+
+        //called without external synchronization (from an outside thread)
+ void SendError(TRequestId reqId, TResponseHeader::TErrorCode err) {
+ DBGOUT("SendResponseError: " << reqId << " c=" << (size_t)this);
+ TAutoPtr<TOutputData> od(new TResponseErrorData(reqId, err));
+ OutputData_.Enqueue(od);
+ ProcessOutputQueue();
+ }
+
+ void ProcessOutputQueue() {
+ AtomicSet(NeedCheckOutputQueue_, 1);
+ if (OutputLock_.TryAquire()) {
+ SendMessages(false);
+ return;
+ }
+ DBGOUT("ProcessOutputQueue: !AquireOutputOwnership: " << (int)OutputLock_.IsFree());
+ }
+
+        //must be called only after successfully acquiring the output lock
+ void SendMessages(bool asioThread) {
+ DBGOUT("TServer::SendMessages(enter)");
+ try {
+ do {
+ AtomicUnlock(&NeedCheckOutputQueue_);
+ TAutoPtr<TOutputData> d;
+ while (OutputData_.Dequeue(&d)) {
+ d->MoveTo(OutputBuffers_);
+ if (!OutputBuffers_.HasFreeSpace()) {
+ if (!FlushOutputBuffers(asioThread)) {
+ return;
+ }
+ }
+ }
+
+ if (OutputBuffers_.HasData()) {
+ if (!FlushOutputBuffers(asioThread)) {
+ return;
+ }
+ }
+
+ OutputLock_.Release();
+
+ if (!AtomicGet(NeedCheckOutputQueue_)) {
+ DBGOUT("Server::SendMessages(exit2): " << (int)OutputLock_.IsFree());
+ return;
+ }
+ } while (OutputLock_.TryAquire());
+ DBGOUT("Server::SendMessages(exit1)");
+ } catch (...) {
+ OnError();
+ }
+ }
+
+ bool FlushOutputBuffers(bool asioThread) {
+ DBGOUT("FlushOutputBuffers: cnt=" << OutputBuffers_.GetIOvec()->Count() << " c=" << (size_t)this);
+            //TODO: research direct write efficiency
+ if (asioThread || TTcp2Options::ServerUseDirectWrite) {
+ TContIOVector& vec = *OutputBuffers_.GetIOvec();
+
+ vec.Proceed(AS_->WriteSome(vec));
+
+ if (vec.Complete()) {
+ OutputBuffers_.Clear();
+ //DBGOUT("WriteResponse: " << " c=" << (size_t)this);
+ return true;
+ }
+ }
+
+            //socket buffer is full - use async write to send the remaining data
+ DBGOUT("AsyncWriteResponse: "
+ << " [" << OutputBuffers_.GetIOvec()->Bytes() << "]"
+ << " c=" << (size_t)this);
+ AS_->AsyncWrite(OutputBuffers_.GetIOvec(), std::bind(&TConnection::OnSend, TConnectionRef(this), _1, _2, _3), TTcp2Options::ServerOutputDeadline);
+ return false;
+ }
+
+ void OnFinishRequest(TRequestId reqId) {
+ if (Y_LIKELY(!Canceled_)) {
+ FinReqs_.Enqueue(reqId);
+ }
+ }
+
+ private:
+ void OnSend(const TErrorCode& ec, size_t amount, IHandlingContext&) {
+ Y_UNUSED(amount);
+ DBGOUT("TServer::OnSend(" << ec.Value() << ", " << amount << ")");
+ if (ec) {
+ OnError();
+ } else {
+ OutputBuffers_.Clear();
+ SendMessages(true);
+ }
+ }
+
+ public:
+ bool IsCanceled() const noexcept {
+ return Canceled_;
+ }
+
+ const TString& RemoteHost() const noexcept {
+ return RemoteHost_;
+ }
+
+ private:
+ TServer& Srv_;
+ TTcpSocketRef AS_;
+ NAtomic::TBool Canceled_;
+ TString RemoteHost_;
+
+ //input
+ size_t BuffSize_;
+ TArrayHolder<char> Buff_;
+ TTcp2Message Msg_;
+ THashMap<TRequestId, TRequest::TStateRef> ReqsState_;
+ TLockFreeQueue<TRequestId> FinReqs_;
+
+ //output
+        TOutputLock OutputLock_; //protects the socket/buffers from simultaneous access by several threads
+ TAtomic NeedCheckOutputQueue_;
+ NNeh::TAutoLockFreeQueue<TOutputData> OutputData_;
+ TOutputBuffers OutputBuffers_;
+ };
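+
+        //The output path of TConnection relies on a try-lock + re-check pattern: any thread may enqueue
+        //data, but only the thread that wins OutputLock_ drains the queue, and after releasing the lock it
+        //re-checks NeedCheckOutputQueue_ so items enqueued in the meantime are not lost.
+        //A minimal sketch of the same idea (illustrative only; the names used here are hypothetical):
+        //
+        //  void Enqueue(TItem item) {
+        //      Queue.Enqueue(std::move(item));
+        //      AtomicSet(NeedCheck, 1);
+        //      if (Lock.TryAquire()) {
+        //          Drain();
+        //      }
+        //  }
+        //
+        //  void Drain() {
+        //      do {
+        //          AtomicSet(NeedCheck, 0);
+        //          TItem item;
+        //          while (Queue.Dequeue(&item)) {
+        //              Send(item);
+        //          }
+        //          Lock.Release();
+        //      } while (AtomicGet(NeedCheck) && Lock.TryAquire());
+        //  }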
+
+ //////////// TServer /////////
+ public:
+ TServer(IOnRequest* cb, ui16 port)
+ : EP_(TTcp2Options::AsioServerThreads)
+ , CB_(cb)
+ {
+ TNetworkAddress addr(port);
+
+ for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) {
+ TEndpoint ep(new NAddr::TAddrInfo(&*it));
+ TTcpAcceptorPtr a(new TTcpAcceptor(EA_.GetIOService()));
+ //DBGOUT("bind:" << ep.IpToString() << ":" << ep.Port());
+ a->Bind(ep);
+ a->Listen(TTcp2Options::Backlog);
+ StartAccept(a.Get());
+ A_.push_back(a);
+ }
+ }
+
+ ~TServer() override {
+ EA_.SyncShutdown(); //cancel accepting connections
+ A_.clear(); //stop listening
+ EP_.SyncShutdown(); //close all exist connections
+ }
+
+ void StartAccept(TTcpAcceptor* a) {
+ const auto s = MakeAtomicShared<TTcpSocket>(EP_.Size() ? EP_.GetExecutor().GetIOService() : EA_.GetIOService());
+ a->AsyncAccept(*s, std::bind(&TServer::OnAccept, this, a, s, _1, _2));
+ }
+
+ void OnAccept(TTcpAcceptor* a, TTcpSocketRef s, const TErrorCode& ec, IHandlingContext&) {
+ if (Y_UNLIKELY(ec)) {
+ if (ec.Value() == ECANCELED) {
+ return;
+ } else if (ec.Value() == EMFILE || ec.Value() == ENFILE || ec.Value() == ENOMEM || ec.Value() == ENOBUFS) {
+                //some OS limit was reached - suspend accepting to prevent a busy loop (100% CPU usage)
+ TSimpleSharedPtr<TDeadlineTimer> dt(new TDeadlineTimer(a->GetIOService()));
+ dt->AsyncWaitExpireAt(TDuration::Seconds(30), std::bind(&TServer::OnTimeoutSuspendAccept, this, a, dt, _1, _2));
+ } else {
+ Cdbg << "acc: " << ec.Text() << Endl;
+ }
+ } else {
+ SetNonBlock(s->Native());
+ PrepareSocket(s->Native());
+ TConnection::Create(*this, s);
+ }
+ StartAccept(a); //continue accepting
+ }
+
+ void OnTimeoutSuspendAccept(TTcpAcceptor* a, TSimpleSharedPtr<TDeadlineTimer>, const TErrorCode& ec, IHandlingContext&) {
+ if (!ec) {
+ DBGOUT("resume acceptor");
+ StartAccept(a);
+ }
+ }
+
+ void OnRequest(IRequestRef& r) {
+ try {
+ CB_->OnRequest(r);
+ } catch (...) {
+ Cdbg << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ private:
+ TVector<TTcpAcceptorPtr> A_;
+        TIOServiceExecutor EA_; //thread where incoming tcp connections are accepted
+        TExecutorsPool EP_; //threads that read/write data on tcp connections (if empty, EA_ is used for r/w)
+ IOnRequest* CB_;
+ };
+
+ TServer::TRequest::TRequest(const TConnectionRef& conn, TBuffer& buf, const TString& content)
+ : Conn(conn)
+ , Content_(content)
+ , Replied_(0)
+ , State(new TState())
+ {
+ DBGOUT("TServer::TRequest()");
+ Buf.Swap(buf);
+ }
+
+ TServer::TRequest::~TRequest() {
+ DBGOUT("TServer::~TRequest()");
+ if (!AtomicGet(Replied_)) {
+ Conn->SendError(RequestHeader().Id, TResponseHeader::EmptyReply);
+ }
+ Conn->OnFinishRequest(RequestHeader().Id);
+ }
+
+ TString TServer::TRequest::RemoteHost() const {
+ return Conn->RemoteHost();
+ }
+
+ void TServer::TRequest::SendReply(TData& data) {
+ do {
+ if (AtomicCas(&Replied_, 1, 0)) {
+ Conn->SendResponse(RequestHeader().Id, data);
+ return;
+ }
+ } while (AtomicGet(Replied_) == 0);
+ }
+
+ class TProtocol: public IProtocol {
+ public:
+ inline TProtocol() {
+ InitNetworkSubSystem();
+ }
+
+ IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
+ return new TServer(cb, loc.GetPort());
+ }
+
+ THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
+ return Singleton<TClient>()->Schedule(msg, fallback, ss);
+ }
+
+ TStringBuf Scheme() const noexcept override {
+ return TStringBuf("tcp2");
+ }
+
+ bool SetOption(TStringBuf name, TStringBuf value) override {
+ return TTcp2Options::Set(name, value);
+ }
+ };
+ }
+}
+
+NNeh::IProtocol* NNeh::Tcp2Protocol() {
+ return Singleton<NNehTcp2::TProtocol>();
+}
diff --git a/library/cpp/neh/tcp2.h b/library/cpp/neh/tcp2.h
new file mode 100644
index 0000000000..bd6d8c25bd
--- /dev/null
+++ b/library/cpp/neh/tcp2.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include <util/datetime/base.h>
+#include <util/system/defaults.h>
+
+namespace NNeh {
+ //global options
+ struct TTcp2Options {
+ //connect timeout
+ static TDuration ConnectTimeout;
+
+ //input buffer size
+ static size_t InputBufferSize;
+
+ //asio client threads
+ static size_t AsioClientThreads;
+
+        //asio server threads: if == 0, the acceptor thread also reads/parses incoming requests,
+        //else one thread accepts connections and AsioServerThreads threads process established tcp connections
+ static size_t AsioServerThreads;
+
+ //listen socket queue limit
+ static int Backlog;
+
+        //try a non-blocking write to the socket from the client thread (to decrease latency)
+ static bool ClientUseDirectWrite;
+
+        //try a non-blocking write to the socket from the thread sending the response (to decrease latency)
+ static bool ServerUseDirectWrite;
+
+        //deadline for receiving request data (right after connect and while request data is being received)
+ static TDuration ServerInputDeadline;
+
+        //time limit for sending response data
+ static TDuration ServerOutputDeadline;
+
+        //set an option; returns false if the option name is not recognized
+ static bool Set(TStringBuf name, TStringBuf value);
+ };
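+
+    //Usage sketch (illustrative): the options can be tuned before the first tcp2 request is scheduled,
+    //either by assigning the public static fields directly or by name via Set(); the exact set of
+    //recognized option names is defined by TTcp2Options::Set() in tcp2.cpp, so the string below is an assumption.
+    //
+    //  NNeh::TTcp2Options::InputBufferSize = 64 * 1024;
+    //  NNeh::TTcp2Options::ClientUseDirectWrite = true;
+    //  NNeh::TTcp2Options::Set("Backlog", "200"); //returns false if the name is not recognized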
+
+ class IProtocol;
+
+ IProtocol* Tcp2Protocol();
+}
diff --git a/library/cpp/neh/udp.cpp b/library/cpp/neh/udp.cpp
new file mode 100644
index 0000000000..13250a2493
--- /dev/null
+++ b/library/cpp/neh/udp.cpp
@@ -0,0 +1,691 @@
+#include "udp.h"
+#include "details.h"
+#include "neh.h"
+#include "location.h"
+#include "utils.h"
+#include "factory.h"
+
+#include <library/cpp/dns/cache.h>
+
+#include <util/network/socket.h>
+#include <util/network/address.h>
+#include <util/generic/deque.h>
+#include <util/generic/hash.h>
+#include <util/generic/string.h>
+#include <util/generic/buffer.h>
+#include <util/generic/singleton.h>
+#include <util/digest/murmur.h>
+#include <util/random/random.h>
+#include <util/ysaveload.h>
+#include <util/system/thread.h>
+#include <util/system/pipe.h>
+#include <util/system/error.h>
+#include <util/stream/mem.h>
+#include <util/stream/buffer.h>
+#include <util/string/cast.h>
+
+using namespace NNeh;
+using namespace NDns;
+using namespace NAddr;
+
+namespace {
+ namespace NUdp {
+ enum EPacketType {
+ PT_REQUEST = 1,
+ PT_RESPONSE = 2,
+ PT_STOP = 3,
+ PT_TIMEOUT = 4
+ };
+
+ struct TUdpHandle: public TNotifyHandle {
+ inline TUdpHandle(IOnRecv* r, const TMessage& msg, TStatCollector* sc) noexcept
+ : TNotifyHandle(r, msg, sc)
+ {
+ }
+
+ void Cancel() noexcept override {
+ THandle::Cancel(); //inform stat collector
+ }
+
+ bool MessageSendedCompletely() const noexcept override {
+ //TODO
+ return true;
+ }
+ };
+
+ static inline IRemoteAddrPtr GetSendAddr(SOCKET s) {
+ IRemoteAddrPtr local = GetSockAddr(s);
+ const sockaddr* addr = local->Addr();
+
+ switch (addr->sa_family) {
+ case AF_INET: {
+ const TIpAddress a = *(const sockaddr_in*)addr;
+
+ return MakeHolder<TIPv4Addr>(TIpAddress(InetToHost(INADDR_LOOPBACK), a.Port()));
+ }
+
+ case AF_INET6: {
+ sockaddr_in6 a = *(const sockaddr_in6*)addr;
+
+ a.sin6_addr = in6addr_loopback;
+
+ return MakeHolder<TIPv6Addr>(a);
+ }
+ }
+
+ ythrow yexception() << "unsupported";
+ }
+
+ typedef ui32 TCheckSum;
+
+ static inline TString GenerateGuid() {
+ const ui64 res[2] = {
+ RandomNumber<ui64>(), RandomNumber<ui64>()};
+
+ return TString((const char*)res, sizeof(res));
+ }
+
+ static inline TCheckSum Sum(const TStringBuf& s) noexcept {
+ return HostToInet(MurmurHash<TCheckSum>(s.data(), s.size()));
+ }
+
+ struct TPacket;
+
+ template <class T>
+ static inline void Serialize(TPacket& p, const T& t);
+
+ struct TPacket {
+ inline TPacket(IRemoteAddrPtr addr)
+ : Addr(std::move(addr))
+ {
+ }
+
+ template <class T>
+ inline TPacket(const T& t, IRemoteAddrPtr addr)
+ : Addr(std::move(addr))
+ {
+ NUdp::Serialize(*this, t);
+ }
+
+ inline TPacket(TSocketHolder& s, TBuffer& tmp) {
+ TAutoPtr<TOpaqueAddr> addr(new TOpaqueAddr());
+
+ retry_on_intr : {
+ const int rv = recvfrom(s, tmp.Data(), tmp.size(), MSG_WAITALL, addr->MutableAddr(), addr->LenPtr());
+
+ if (rv < 0) {
+ int err = LastSystemError();
+ if (err == EAGAIN || err == EWOULDBLOCK) {
+ Data.Resize(sizeof(TCheckSum) + 1);
+ *(Data.data() + sizeof(TCheckSum)) = static_cast<char>(PT_TIMEOUT);
+ } else if (err == EINTR) {
+ goto retry_on_intr;
+ } else {
+ ythrow TSystemError() << "recv failed";
+ }
+ } else {
+ Data.Append(tmp.Data(), (size_t)rv);
+ Addr.Reset(addr.Release());
+ CheckSign();
+ }
+ }
+ }
+
+ inline void SendTo(TSocketHolder& s) {
+ Sign();
+
+ if (sendto(s, Data.data(), Data.size(), 0, Addr->Addr(), Addr->Len()) < 0) {
+ Cdbg << LastSystemErrorText() << Endl;
+ }
+ }
+
+ IRemoteAddrPtr Addr;
+ TBuffer Data;
+
+ inline void Sign() {
+ const TCheckSum sum = CalcSign();
+
+ memcpy(Data.Data(), &sum, sizeof(sum));
+ }
+
+ inline char Type() const noexcept {
+ return *(Data.data() + sizeof(TCheckSum));
+ }
+
+ inline void CheckSign() const {
+ if (Data.size() < 16) {
+ ythrow yexception() << "small packet";
+ }
+
+ if (StoredSign() != CalcSign()) {
+ ythrow yexception() << "bad checksum";
+ }
+ }
+
+ inline TCheckSum StoredSign() const noexcept {
+ TCheckSum sum;
+
+ memcpy(&sum, Data.Data(), sizeof(sum));
+
+ return sum;
+ }
+
+ inline TCheckSum CalcSign() const noexcept {
+ return Sum(Body());
+ }
+
+ inline TStringBuf Body() const noexcept {
+ return TStringBuf(Data.data() + sizeof(TCheckSum), Data.End());
+ }
+ };
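+
+        //Datagram layout used by TPacket (inferred from Sign/CheckSign/Body):
+        //  bytes 0..3  MurmurHash checksum of everything that follows, stored in network byte order
+        //  byte  4     packet type (EPacketType)
+        //  bytes 5...  type-specific payload serialized with Save/Load
+        //CheckSign() additionally rejects datagrams shorter than 16 bytes.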
+
+ typedef TAutoPtr<TPacket> TPacketRef;
+
+ class TPacketInput: public TMemoryInput {
+ public:
+ inline TPacketInput(const TPacket& p)
+ : TMemoryInput(p.Body().data(), p.Body().size())
+ {
+ }
+ };
+
+ class TPacketOutput: public TBufferOutput {
+ public:
+ inline TPacketOutput(TPacket& p)
+ : TBufferOutput(p.Data)
+ {
+ p.Data.Proceed(sizeof(TCheckSum));
+ }
+ };
+
+ template <class T>
+ static inline void Serialize(TPacketOutput* out, const T& t) {
+ Save(out, t.Type());
+ t.Serialize(out);
+ }
+
+ template <class T>
+ static inline void Serialize(TPacket& p, const T& t) {
+ TPacketOutput out(p);
+
+ NUdp::Serialize(&out, t);
+ }
+
+ namespace NPrivate {
+ template <class T>
+ static inline void Deserialize(TPacketInput* in, T& t) {
+ char type;
+ Load(in, type);
+
+ if (type != t.Type()) {
+ ythrow yexception() << "unsupported packet";
+ }
+
+ t.Deserialize(in);
+ }
+
+ template <class T>
+ static inline void Deserialize(const TPacket& p, T& t) {
+ TPacketInput in(p);
+
+ Deserialize(&in, t);
+ }
+ }
+
+ struct TRequestPacket {
+ TString Guid;
+ TString Service;
+ TString Data;
+
+ inline TRequestPacket(const TPacket& p) {
+ NPrivate::Deserialize(p, *this);
+ }
+
+ inline TRequestPacket(const TString& srv, const TString& data)
+ : Guid(GenerateGuid())
+ , Service(srv)
+ , Data(data)
+ {
+ }
+
+ inline char Type() const noexcept {
+ return static_cast<char>(PT_REQUEST);
+ }
+
+ inline void Serialize(TPacketOutput* out) const {
+ Save(out, Guid);
+ Save(out, Service);
+ Save(out, Data);
+ }
+
+ inline void Deserialize(TPacketInput* in) {
+ Load(in, Guid);
+ Load(in, Service);
+ Load(in, Data);
+ }
+ };
+
+ template <class TStore>
+ struct TResponsePacket {
+ TString Guid;
+ TStore Data;
+
+ inline TResponsePacket(const TString& guid, TStore& data)
+ : Guid(guid)
+ {
+ Data.swap(data);
+ }
+
+ inline TResponsePacket(const TPacket& p) {
+ NPrivate::Deserialize(p, *this);
+ }
+
+ inline char Type() const noexcept {
+ return static_cast<char>(PT_RESPONSE);
+ }
+
+ inline void Serialize(TPacketOutput* out) const {
+ Save(out, Guid);
+ Save(out, Data);
+ }
+
+ inline void Deserialize(TPacketInput* in) {
+ Load(in, Guid);
+ Load(in, Data);
+ }
+ };
+
+ struct TStopPacket {
+ inline char Type() const noexcept {
+ return static_cast<char>(PT_STOP);
+ }
+
+ inline void Serialize(TPacketOutput* out) const {
+ Save(out, TString("stop packet"));
+ }
+ };
+
+ struct TBindError: public TSystemError {
+ };
+
+ struct TSocketDescr {
+ inline TSocketDescr(TSocketHolder& s, int family)
+ : S(s.Release())
+ , Family(family)
+ {
+ }
+
+ TSocketHolder S;
+ int Family;
+ };
+
+ typedef TAutoPtr<TSocketDescr> TSocketRef;
+ typedef TVector<TSocketRef> TSockets;
+
+ static inline void CreateSocket(TSocketHolder& s, const IRemoteAddr& addr) {
+ TSocketHolder res(socket(addr.Addr()->sa_family, SOCK_DGRAM, IPPROTO_UDP));
+
+ if (!res) {
+ ythrow TSystemError() << "can not create socket";
+ }
+
+ FixIPv6ListenSocket(res);
+
+ if (bind(res, addr.Addr(), addr.Len()) != 0) {
+ ythrow TBindError() << "can not bind " << PrintHostAndPort(addr);
+ }
+
+ res.Swap(s);
+ }
+
+ static inline void CreateSockets(TSockets& s, ui16 port) {
+ TNetworkAddress addr(port);
+
+ for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) {
+ TSocketHolder res;
+
+ CreateSocket(res, TAddrInfo(&*it));
+
+ s.push_back(new TSocketDescr(res, it->ai_family));
+ }
+ }
+
+ static inline void CreateSocketsOnRandomPort(TSockets& s) {
+ while (true) {
+ try {
+ TSockets tmp;
+
+ CreateSockets(tmp, 5000 + (RandomNumber<ui16>() % 1000));
+ tmp.swap(s);
+
+ return;
+ } catch (const TBindError&) {
+ }
+ }
+ }
+
+ typedef ui64 TTimeStamp;
+
+ static inline TTimeStamp TimeStamp() noexcept {
+ return GetCycleCount() >> 31;
+ }
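+
+        //TimeStamp() is a coarse monotonic clock: one unit is 2^31 CPU cycles, i.e. roughly a second on
+        //a ~2 GHz CPU (an approximation - it is only used below to detect stale in-fly requests)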
+
+ struct TRequestDescr: public TIntrusiveListItem<TRequestDescr> {
+ inline TRequestDescr(const TString& guid, const TNotifyHandleRef& hndl, const TMessage& msg)
+ : Guid(guid)
+ , Hndl(hndl)
+ , Msg(msg)
+ , TS(TimeStamp())
+ {
+ }
+
+ TString Guid;
+ TNotifyHandleRef Hndl;
+ TMessage Msg;
+ TTimeStamp TS;
+ };
+
+ typedef TAutoPtr<TRequestDescr> TRequestDescrRef;
+
+ class TProto {
+ class TRequest: public IRequest, public TRequestPacket {
+ public:
+ inline TRequest(TPacket& p, TProto* parent)
+ : TRequestPacket(p)
+ , Addr_(std::move(p.Addr))
+ , H_(PrintHostByRfc(*Addr_))
+ , P_(parent)
+ {
+ }
+
+ TStringBuf Scheme() const override {
+ return TStringBuf("udp");
+ }
+
+ TString RemoteHost() const override {
+ return H_;
+ }
+
+ TStringBuf Service() const override {
+ return ((TRequestPacket&)(*this)).Service;
+ }
+
+ TStringBuf Data() const override {
+ return ((TRequestPacket&)(*this)).Data;
+ }
+
+ TStringBuf RequestId() const override {
+ return ((TRequestPacket&)(*this)).Guid;
+ }
+
+ bool Canceled() const override {
+ //TODO ?
+ return false;
+ }
+
+ void SendReply(TData& data) override {
+ P_->Schedule(new TPacket(TResponsePacket<TData>(Guid, data), std::move(Addr_)));
+ }
+
+ void SendError(TResponseError, const TString&) override {
+ // TODO
+ }
+
+ private:
+ IRemoteAddrPtr Addr_;
+ TString H_;
+ TProto* P_;
+ };
+
+ public:
+ inline TProto(IOnRequest* cb, TSocketHolder& s)
+ : CB_(cb)
+ , ToSendEv_(TSystemEvent::rAuto)
+ , S_(s.Release())
+ {
+ SetSocketTimeout(S_, 10);
+ Thrs_.push_back(Spawn<TProto, &TProto::ExecuteRecv>(this));
+ Thrs_.push_back(Spawn<TProto, &TProto::ExecuteSend>(this));
+ }
+
+ inline ~TProto() {
+ Schedule(new TPacket(TStopPacket(), GetSendAddr(S_)));
+
+ for (size_t i = 0; i < Thrs_.size(); ++i) {
+ Thrs_[i]->Join();
+ }
+ }
+
+ inline TPacketRef Recv() {
+ TBuffer tmp;
+
+ tmp.Resize(128 * 1024);
+
+ while (true) {
+ try {
+ return new TPacket(S_, tmp);
+ } catch (...) {
+ Cdbg << CurrentExceptionMessage() << Endl;
+
+ continue;
+ }
+ }
+ }
+
+ typedef THashMap<TString, TRequestDescrRef> TInFlyBase;
+
+ struct TInFly: public TInFlyBase, public TIntrusiveList<TRequestDescr> {
+ typedef TInFlyBase::iterator TIter;
+                typedef TInFlyBase::const_iterator TConstIter;
+
+ inline void Insert(TRequestDescrRef& d) {
+ PushBack(d.Get());
+ (*this)[d->Guid] = d;
+ }
+
+ inline void EraseStale() noexcept {
+ const TTimeStamp now = TimeStamp();
+
+ for (TIterator it = Begin(); (it != End()) && (it->TS < now) && ((now - it->TS) > 120);) {
+ it->Hndl->NotifyError("request timeout");
+ TString safe_key = (it++)->Guid;
+ erase(safe_key);
+ }
+ }
+ };
+
+ inline void ExecuteRecv() {
+ SetHighestThreadPriority();
+
+ TInFly infly;
+
+ while (true) {
+ TPacketRef p = Recv();
+
+ switch (static_cast<EPacketType>(p->Type())) {
+ case PT_REQUEST:
+ if (CB_) {
+ CB_->OnRequest(new TRequest(*p, this));
+ } else {
+                        //skip the request when running as a client (no request callback is set)
+ }
+
+ break;
+
+ case PT_RESPONSE: {
+ CancelStaleRequests(infly);
+
+ TResponsePacket<TString> rp(*p);
+
+ TInFly::TIter it = static_cast<TInFlyBase&>(infly).find(rp.Guid);
+
+ if (it == static_cast<TInFlyBase&>(infly).end()) {
+ break;
+ }
+
+ const TRequestDescrRef& d = it->second;
+ d->Hndl->NotifyResponse(rp.Data);
+
+ infly.erase(it);
+
+ break;
+ }
+
+ case PT_STOP:
+ Schedule(nullptr);
+
+ return;
+
+ case PT_TIMEOUT:
+ CancelStaleRequests(infly);
+
+ break;
+ }
+ }
+ }
+
+ inline void ExecuteSend() {
+ SetHighestThreadPriority();
+
+ while (true) {
+ TPacketRef p;
+
+ while (!ToSend_.Dequeue(&p)) {
+ ToSendEv_.Wait();
+ }
+
+ //shutdown
+ if (!p) {
+ return;
+ }
+
+ p->SendTo(S_);
+ }
+ }
+
+ inline void Schedule(TPacketRef p) {
+ ToSend_.Enqueue(p);
+ ToSendEv_.Signal();
+ }
+
+ inline void Schedule(TRequestDescrRef dsc, TPacketRef p) {
+ ScheduledReqs_.Enqueue(dsc);
+ Schedule(p);
+ }
+
+ protected:
+ void CancelStaleRequests(TInFly& infly) {
+ TRequestDescrRef d;
+
+ while (ScheduledReqs_.Dequeue(&d)) {
+ infly.Insert(d);
+ }
+
+ infly.EraseStale();
+ }
+
+ IOnRequest* CB_;
+ NNeh::TAutoLockFreeQueue<TPacket> ToSend_;
+ NNeh::TAutoLockFreeQueue<TRequestDescr> ScheduledReqs_;
+ TSystemEvent ToSendEv_;
+ TSocketHolder S_;
+ TVector<TThreadRef> Thrs_;
+ };
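+
+        //Client-side request flow through TProto (summary of the code above): Schedule(dsc, packet)
+        //enqueues the request descriptor and the signed packet; ExecuteSend() writes queued packets to
+        //the socket; ExecuteRecv() matches PT_RESPONSE packets to in-fly requests by Guid, notifies the
+        //handle and drops requests that stayed in-fly longer than ~120 time units as timed out.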
+
+ class TProtos {
+ public:
+ inline TProtos() {
+ TSockets s;
+
+ CreateSocketsOnRandomPort(s);
+ Init(nullptr, s);
+ }
+
+ inline TProtos(IOnRequest* cb, ui16 port) {
+ TSockets s;
+
+ CreateSockets(s, port);
+ Init(cb, s);
+ }
+
+ static inline TProtos* Instance() {
+ return Singleton<TProtos>();
+ }
+
+ inline void Schedule(const TMessage& msg, const TNotifyHandleRef& hndl) {
+ TParsedLocation loc(msg.Addr);
+ const TNetworkAddress* addr = &CachedThrResolve(TResolveInfo(loc.Host, loc.GetPort()))->Addr;
+
+ for (TNetworkAddress::TIterator ai = addr->Begin(); ai != addr->End(); ++ai) {
+ TProto* proto = Find(ai->ai_family);
+
+ if (proto) {
+ TRequestPacket rp(ToString(loc.Service), msg.Data);
+ TRequestDescrRef rd(new TRequestDescr(rp.Guid, hndl, msg));
+ IRemoteAddrPtr raddr(new TAddrInfo(&*ai));
+ TPacketRef p(new TPacket(rp, std::move(raddr)));
+
+ proto->Schedule(rd, p);
+
+ return;
+ }
+ }
+
+ ythrow yexception() << "unsupported protocol family";
+ }
+
+ private:
+ inline void Init(IOnRequest* cb, TSockets& s) {
+ for (auto& it : s) {
+ P_[it->Family] = new TProto(cb, it->S);
+ }
+ }
+
+ inline TProto* Find(int family) const {
+ TProtoStorage::const_iterator it = P_.find(family);
+
+ if (it == P_.end()) {
+ return nullptr;
+ }
+
+ return it->second.Get();
+ }
+
+ private:
+ typedef TAutoPtr<TProto> TProtoRef;
+ typedef THashMap<int, TProtoRef> TProtoStorage;
+ TProtoStorage P_;
+ };
+
+ class TRequester: public IRequester, public TProtos {
+ public:
+ inline TRequester(IOnRequest* cb, ui16 port)
+ : TProtos(cb, port)
+ {
+ }
+ };
+
+ class TProtocol: public IProtocol {
+ public:
+ IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
+ return new TRequester(cb, loc.GetPort());
+ }
+
+ THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
+ TNotifyHandleRef ret(new TUdpHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));
+
+ TProtos::Instance()->Schedule(msg, ret);
+
+ return ret.Get();
+ }
+
+ TStringBuf Scheme() const noexcept override {
+ return TStringBuf("udp");
+ }
+ };
+ }
+}
+
+IProtocol* NNeh::UdpProtocol() {
+ return Singleton<NUdp::TProtocol>();
+}
diff --git a/library/cpp/neh/udp.h b/library/cpp/neh/udp.h
new file mode 100644
index 0000000000..491df042f2
--- /dev/null
+++ b/library/cpp/neh/udp.h
@@ -0,0 +1,7 @@
+#pragma once
+
+namespace NNeh {
+ class IProtocol;
+
+ IProtocol* UdpProtocol();
+}
diff --git a/library/cpp/neh/utils.cpp b/library/cpp/neh/utils.cpp
new file mode 100644
index 0000000000..2f8671c581
--- /dev/null
+++ b/library/cpp/neh/utils.cpp
@@ -0,0 +1,47 @@
+#include "utils.h"
+
+#include <util/generic/utility.h>
+#include <util/stream/output.h>
+#include <util/stream/str.h>
+#include <util/system/error.h>
+
+#if defined(_unix_)
+#include <pthread.h>
+#endif
+
+#if defined(_win_)
+#include <windows.h>
+#endif
+
+using namespace NNeh;
+
+size_t NNeh::RealStackSize(size_t len) noexcept {
+#if defined(NDEBUG) && !defined(_san_enabled_)
+ return len;
+#else
+ return Max<size_t>(len, 64000);
+#endif
+}
+
+TString NNeh::PrintHostByRfc(const NAddr::IRemoteAddr& addr) {
+ TStringStream ss;
+
+ if (addr.Addr()->sa_family == AF_INET) {
+ NAddr::PrintHost(ss, addr);
+ } else if (addr.Addr()->sa_family == AF_INET6) {
+ ss << '[';
+ NAddr::PrintHost(ss, addr);
+ ss << ']';
+ }
+ return ss.Str();
+}
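+
+//Example (illustrative): for an IPv4 address this yields e.g. "127.0.0.1", while an IPv6 address is
+//bracketed as "[::1]", matching the RFC 3986 host syntax referenced in utils.h.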
+
+NAddr::IRemoteAddrPtr NNeh::GetPeerAddr(SOCKET s) {
+ TAutoPtr<NAddr::TOpaqueAddr> addr(new NAddr::TOpaqueAddr());
+
+ if (getpeername(s, addr->MutableAddr(), addr->LenPtr()) < 0) {
+ ythrow TSystemError() << "getpeername() failed";
+ }
+
+ return NAddr::IRemoteAddrPtr(addr.Release());
+}
diff --git a/library/cpp/neh/utils.h b/library/cpp/neh/utils.h
new file mode 100644
index 0000000000..ff1f63c2df
--- /dev/null
+++ b/library/cpp/neh/utils.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <library/cpp/threading/atomic/bool.h>
+
+#include <util/network/address.h>
+#include <util/system/thread.h>
+#include <util/generic/cast.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+
+namespace NNeh {
+ typedef TAutoPtr<TThread> TThreadRef;
+
+ template <typename T, void (T::*M)()>
+ static void* HelperMemberFunc(void* arg) {
+ T* obj = reinterpret_cast<T*>(arg);
+ (obj->*M)();
+ return nullptr;
+ }
+
+ template <typename T, void (T::*M)()>
+ static TThreadRef Spawn(T* t) {
+ TThreadRef thr(new TThread(HelperMemberFunc<T, M>, t));
+
+ thr->Start();
+
+ return thr;
+ }
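+
+    //Usage sketch (illustrative; TWorker and Run are hypothetical):
+    //  TWorker w;
+    //  TThreadRef thr = Spawn<TWorker, &TWorker::Run>(&w); //the thread is already started
+    //  ...
+    //  thr->Join();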
+
+ size_t RealStackSize(size_t len) noexcept;
+
+ //from rfc3986:
+ //host = IP-literal / IPv4address / reg-name
+ //IP-literal = "[" ( IPv6address / IPvFuture ) "]"
+ TString PrintHostByRfc(const NAddr::IRemoteAddr& addr);
+
+ NAddr::IRemoteAddrPtr GetPeerAddr(SOCKET s);
+
+ using TAtomicBool = NAtomic::TBool;
+}
diff --git a/library/cpp/neh/wfmo.h b/library/cpp/neh/wfmo.h
new file mode 100644
index 0000000000..11f32dda22
--- /dev/null
+++ b/library/cpp/neh/wfmo.h
@@ -0,0 +1,140 @@
+#pragma once
+
+#include "lfqueue.h"
+
+#include <library/cpp/threading/atomic/bool.h>
+
+#include <util/generic/vector.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+#include <library/cpp/deprecated/atomic/atomic_ops.h>
+#include <util/system/event.h>
+#include <util/system/spinlock.h>
+
+namespace NNeh {
+ template <class T>
+ class TBlockedQueue: public TLockFreeQueue<T>, public TSystemEvent {
+ public:
+ inline TBlockedQueue() noexcept
+ : TSystemEvent(TSystemEvent::rAuto)
+ {
+ }
+
+ inline void Notify(T t) noexcept {
+ this->Enqueue(t);
+ Signal();
+ }
+ };
+
+ class TWaitQueue {
+ public:
+ struct TWaitHandle {
+ inline TWaitHandle() noexcept
+ : Signalled(false)
+ , Parent(nullptr)
+ {
+ }
+
+ inline void Signal() noexcept {
+ TGuard<TSpinLock> lock(M_);
+
+ Signalled = true;
+
+ if (Parent) {
+ Parent->Notify(this);
+ }
+ }
+
+ inline void Register(TWaitQueue* parent) noexcept {
+ TGuard<TSpinLock> lock(M_);
+
+ Parent = parent;
+
+ if (Signalled) {
+ if (Parent) {
+ Parent->Notify(this);
+ }
+ }
+ }
+
+ NAtomic::TBool Signalled;
+ TWaitQueue* Parent;
+ TSpinLock M_;
+ };
+
+ inline ~TWaitQueue() {
+ for (size_t i = 0; i < H_.size(); ++i) {
+ H_[i]->Register(nullptr);
+ }
+ }
+
+ inline void Register(TWaitHandle& ev) {
+ H_.push_back(&ev);
+ ev.Register(this);
+ }
+
+ template <class T>
+ inline void Register(const T& ev) {
+ Register(static_cast<TWaitHandle&>(*ev));
+ }
+
+ inline bool Wait(const TInstant& deadLine) noexcept {
+ return Q_.WaitD(deadLine);
+ }
+
+ inline void Notify(TWaitHandle* wq) noexcept {
+ Q_.Notify(wq);
+ }
+
+ inline bool Dequeue(TWaitHandle** wq) noexcept {
+ return Q_.Dequeue(wq);
+ }
+
+ private:
+ TBlockedQueue<TWaitHandle*> Q_;
+ TVector<TWaitHandle*> H_;
+ };
+
+ typedef TWaitQueue::TWaitHandle TWaitHandle;
+
+ template <class It, class T>
+ static inline void WaitForMultipleObj(It b, It e, const TInstant& deadLine, T& func) {
+ TWaitQueue hndl;
+
+ while (b != e) {
+ hndl.Register(*b++);
+ }
+
+ do {
+ TWaitHandle* ret = nullptr;
+
+ if (hndl.Dequeue(&ret)) {
+ do {
+ func(ret);
+ } while (hndl.Dequeue(&ret));
+
+ return;
+ }
+ } while (hndl.Wait(deadLine));
+ }
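+
+    //Usage sketch (illustrative; objects derived from TWaitHandle are assumed):
+    //  TVector<TWaitHandle*> handles = ...;
+    //  auto onReady = [](TWaitHandle* h) { /* h was signalled */ };
+    //  WaitForMultipleObj(handles.begin(), handles.end(), TInstant::Now() + TDuration::Seconds(5), onReady);
+    //WaitForOne() below is the single-handle convenience wrapper built on the same mechanism.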
+
+ struct TSignalled {
+ inline TSignalled()
+ : Signalled(false)
+ {
+ }
+
+ inline void operator()(const TWaitHandle*) noexcept {
+ Signalled = true;
+ }
+
+ bool Signalled;
+ };
+
+ static inline bool WaitForOne(TWaitHandle& wh, const TInstant& deadLine) {
+ TSignalled func;
+
+ WaitForMultipleObj(&wh, &wh + 1, deadLine, func);
+
+ return func.Signalled;
+ }
+}