diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/http/server/http.cpp | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/http/server/http.cpp')
-rw-r--r-- | library/cpp/http/server/http.cpp | 843 |
1 files changed, 843 insertions, 0 deletions
diff --git a/library/cpp/http/server/http.cpp b/library/cpp/http/server/http.cpp new file mode 100644 index 0000000000..128583bdd7 --- /dev/null +++ b/library/cpp/http/server/http.cpp @@ -0,0 +1,843 @@ +#include "http.h" +#include "http_ex.h" + +#include <library/cpp/threading/equeue/equeue.h> + +#include <util/generic/buffer.h> +#include <util/generic/cast.h> +#include <util/generic/intrlist.h> +#include <util/generic/yexception.h> +#include <util/network/address.h> +#include <util/network/socket.h> +#include <util/network/poller.h> +#include <util/system/atomic.h> +#include <util/system/compat.h> // stricmp, strnicmp, strlwr, strupr, stpcpy +#include <util/system/defaults.h> +#include <util/system/event.h> +#include <util/system/mutex.h> +#include <util/system/pipe.h> +#include <util/system/thread.h> +#include <util/thread/factory.h> + +#include <cerrno> +#include <cstring> +#include <ctime> + +#include <sys/stat.h> +#include <sys/types.h> + +using namespace NAddr; + +namespace { + class IPollAble { + public: + inline IPollAble() noexcept { + } + + virtual ~IPollAble() { + } + + virtual void OnPollEvent(TInstant now) = 0; + }; + + struct TShouldStop { + }; + + struct TWakeupPollAble: public IPollAble { + void OnPollEvent(TInstant) override { + throw TShouldStop(); + } + }; +} + +class TClientConnection: public IPollAble, public TIntrusiveListItem<TClientConnection> { +public: + TClientConnection(const TSocket& s, THttpServer::TImpl* serv, NAddr::IRemoteAddrRef listenerSockAddrRef); + ~TClientConnection() override; + + void OnPollEvent(TInstant now) override; + + inline void Activate(TInstant now) noexcept; + inline void DeActivate(); + inline void Reject(); + +public: + TSocket Socket_; + NAddr::IRemoteAddrRef ListenerSockAddrRef_; + THttpServer::TImpl* HttpServ_ = nullptr; + bool Reject_ = false; + TInstant LastUsed; + TInstant AcceptMoment; + size_t ReceivedRequests = 0; +}; + +class THttpServer::TImpl { +public: + class TConnections { + public: + inline TConnections(TSocketPoller* poller, const THttpServerOptions& options) + : Poller_(poller) + , Options(options) + { + } + + inline ~TConnections() { + } + + inline void Add(TClientConnection* c) noexcept { + TGuard<TMutex> g(Mutex_); + + Conns_.PushBack(c); + Poller_->WaitRead(c->Socket_, (void*)static_cast<const IPollAble*>(c)); + } + + inline void Erase(TClientConnection* c, TInstant now) noexcept { + TGuard<TMutex> g(Mutex_); + EraseUnsafe(c); + if (Options.ExpirationTimeout > TDuration::Zero()) { + TryRemovingUnsafe(now - Options.ExpirationTimeout); + } + } + + inline void Clear() noexcept { + TGuard<TMutex> g(Mutex_); + + Conns_.Clear(); + } + + inline bool RemoveOld(TInstant border) noexcept { + TGuard<TMutex> g(Mutex_); + return TryRemovingUnsafe(border); + } + + bool TryRemovingUnsafe(TInstant border) noexcept { + if (Conns_.Empty()) { + return false; + } + TClientConnection* c = &*(Conns_.Begin()); + if (c->LastUsed > border) { + return false; + } + EraseUnsafe(c); + delete c; + return true; + } + + void EraseUnsafe(TClientConnection* c) noexcept { + Poller_->Unwait(c->Socket_); + c->Unlink(); + } + + public: + TMutex Mutex_; + TIntrusiveListWithAutoDelete<TClientConnection, TDelete> Conns_; + TSocketPoller* Poller_ = nullptr; + const THttpServerOptions& Options; + }; + + static void* ListenSocketFunction(void* param) { + try { + ((TImpl*)param)->ListenSocket(); + } catch (...) { + + } + + return nullptr; + } + + TAutoPtr<TClientRequest> CreateRequest(TAutoPtr<TClientConnection> c) { + THolder<TClientRequest> obj(Cb_->CreateClient()); + + obj->Conn_.Reset(c.Release()); + + return obj; + } + + void AddRequestFromSocket(const TSocket& s, TInstant now, NAddr::IRemoteAddrRef listenerSockAddrRef) { + if (MaxRequestsReached()) { + Cb_->OnMaxConn(); + bool wasRemoved = Connections->RemoveOld(TInstant::Max()); + if (!wasRemoved && Options_.RejectExcessConnections) { + (new TClientConnection(s, this, listenerSockAddrRef))->Reject(); + return; + } + } + + auto connection = new TClientConnection(s, this, listenerSockAddrRef); + connection->LastUsed = now; + connection->DeActivate(); + } + + void SaveErrorCode() { + ErrorCode = WSAGetLastError(); + } + + int GetErrorCode() const { + return ErrorCode; + } + + const char* GetError() const { + return LastSystemErrorText(ErrorCode); + } + + bool Start() { + Poller.Reset(new TSocketPoller()); + Connections.Reset(new TConnections(Poller.Get(), Options_)); + + // Start the listener thread + ListenerRunningOK = false; + + // throws on error + TPipeHandle::Pipe(ListenWakeupReadFd, ListenWakeupWriteFd); + + SetNonBlock(ListenWakeupWriteFd, true); + SetNonBlock(ListenWakeupReadFd, true); + + Poller->WaitRead(ListenWakeupReadFd, &WakeupPollAble); + + ListenStartEvent.Reset(); + try { + ListenThread.Reset(new TThread(ListenSocketFunction, this)); + ListenThread->Start(); + } catch (const yexception&) { + SaveErrorCode(); + return false; + } + + // Wait until the thread has completely started and return the success indicator + ListenStartEvent.Wait(); + + return ListenerRunningOK; + } + + void Wait() { + Cb_->OnWait(); + TGuard<TMutex> g(StopMutex); + if (ListenThread) { + ListenThread->Join(); + ListenThread.Reset(nullptr); + } + } + + void Stop() { + Shutdown(); + + TGuard<TMutex> g(StopMutex); + if (ListenThread) { + ListenThread->Join(); + ListenThread.Reset(nullptr); + } + + while (ConnectionCount) { + usleep(10000); + Connections->Clear(); + } + + Connections.Destroy(); + Poller.Destroy(); + } + + void Shutdown() { + ListenWakeupWriteFd.Write("", 1); + // ignore result + } + + void AddRequest(TAutoPtr<TClientRequest> req, bool fail) { + struct TFailRequest: public THttpClientRequestEx { + inline TFailRequest(TAutoPtr<TClientRequest> parent) { + Conn_.Reset(parent->Conn_.Release()); + HttpConn_.Reset(parent->HttpConn_.Release()); + } + + bool Reply(void*) override { + if (!ProcessHeaders()) { + return true; + } + + ProcessFailRequest(0); + return true; + } + }; + + if (!fail && Requests->Add(req.Get())) { + Y_UNUSED(req.Release()); + } else { + req = new TFailRequest(req); + + if (FailRequests->Add(req.Get())) { + Y_UNUSED(req.Release()); + } else { + Cb_->OnFailRequest(-1); + } + } + } + + size_t GetRequestQueueSize() const { + return Requests->Size(); + } + + size_t GetFailQueueSize() const { + return FailRequests->Size(); + } + + const IThreadPool& GetRequestQueue() const { + return *Requests; + } + + const IThreadPool& GetFailQueue() const { + return *FailRequests; + } + + class TListenSocket: public IPollAble, public TIntrusiveListItem<TListenSocket> { + public: + inline TListenSocket(const TSocket& s, TImpl* parent) + : S_(s) + , Server_(parent) + , SockAddrRef_(GetSockAddr(S_)) + { + } + + ~TListenSocket() override { + } + + void OnPollEvent(TInstant) override { + SOCKET s = ::accept(S_, nullptr, nullptr); + + if (s == INVALID_SOCKET) { + ythrow yexception() << "accept: " << LastSystemErrorText(); + } + + Server_->AddRequestFromSocket(s, TInstant::Now(), SockAddrRef_); + } + + SOCKET GetSocket() const noexcept { + return S_; + } + + private: + TSocket S_; + TImpl* Server_ = nullptr; + NAddr::IRemoteAddrRef SockAddrRef_; + }; + + void ListenSocket() { + TThread::SetCurrentThreadName(Options_.ListenThreadName.c_str()); + + ErrorCode = 0; + TIntrusiveListWithAutoDelete<TListenSocket, TDelete> Reqs; + + std::function<void(TSocket)> callback = [&](TSocket socket) { + THolder<TListenSocket> ls(new TListenSocket(socket, this)); + Poller->WaitRead(socket, static_cast<IPollAble*>(ls.Get())); + Reqs.PushBack(ls.Release()); + }; + bool addressesBound = TryToBindAddresses(Options_, &callback); + if (!addressesBound) { + SaveErrorCode(); + ListenStartEvent.Signal(); + + return; + } + + Requests->Start(Options_.nThreads, Options_.MaxQueueSize); + FailRequests->Start(Options_.nFThreads, Options_.MaxFQueueSize); + Cb_->OnListenStart(); + ListenerRunningOK = true; + ListenStartEvent.Signal(); + + TVector<void*> events; + events.resize(1); + + TInstant now = TInstant::Now(); + for (;;) { + try { + const TInstant deadline = Options_.PollTimeout == TDuration::Zero() ? TInstant::Max() : now + Options_.PollTimeout; + const size_t ret = Poller->WaitD(events.data(), events.size(), deadline); + + now = TInstant::Now(); + for (size_t i = 0; i < ret; ++i) { + ((IPollAble*)events[i])->OnPollEvent(now); + } + + if (ret == 0 && Options_.ExpirationTimeout > TDuration::Zero()) { + Connections->RemoveOld(now - Options_.ExpirationTimeout); + } + + // When MaxConnections is limited or ExpirationTimeout is set, OnPollEvent can call + // RemoveOld and destroy other IPollAble* objects in the + // poller. Thus in this case we can safely process only one + // event from the poller at a time. + if (!Options_.MaxConnections && Options_.ExpirationTimeout == TDuration::Zero()) { + if (ret >= events.size()) { + events.resize(ret * 2); + } + } + } catch (const TShouldStop&) { + break; + } catch (...) { + Cb_->OnException(); + } + } + + while (!Reqs.Empty()) { + THolder<TListenSocket> ls(Reqs.PopFront()); + + Poller->Unwait(ls->GetSocket()); + } + + Requests->Stop(); + FailRequests->Stop(); + Cb_->OnListenStop(); + } + + void RestartRequestThreads(ui32 nTh, ui32 maxQS) { + Requests->Stop(); + Options_.nThreads = nTh; + Options_.MaxQueueSize = maxQS; + Requests->Start(Options_.nThreads, Options_.MaxQueueSize); + } + + TImpl(THttpServer* parent, ICallBack* cb, TMtpQueueRef mainWorkers, TMtpQueueRef failWorkers, const TOptions& options_) + : Requests(mainWorkers) + , FailRequests(failWorkers) + , Options_(options_) + , Cb_(cb) + , Parent_(parent) + { + } + + TImpl(THttpServer* parent, ICallBack* cb, const TOptions& options, IThreadFactory* factory) + : TImpl( + parent, + cb, + MakeThreadPool<TSimpleThreadPool>(factory, options.UseElasticQueues, cb, options.RequestsThreadName), + MakeThreadPool<TThreadPool>(factory, options.UseElasticQueues, nullptr, options.FailRequestsThreadName), + options) { + } + + ~TImpl() { + try { + Stop(); + } catch (...) { + } + } + + inline const TOptions& Options() const noexcept { + return Options_; + } + + inline void DecreaseConnections() noexcept { + AtomicDecrement(ConnectionCount); + } + + inline void IncreaseConnections() noexcept { + AtomicIncrement(ConnectionCount); + } + + inline i64 GetClientCount() const { + return AtomicGet(ConnectionCount); + } + + inline bool MaxRequestsReached() const { + return Options_.MaxConnections && ((size_t)GetClientCount() >= Options_.MaxConnections); + } + + THolder<TThread> ListenThread; + TPipeHandle ListenWakeupReadFd; + TPipeHandle ListenWakeupWriteFd; + TSystemEvent ListenStartEvent; + TMtpQueueRef Requests; + TMtpQueueRef FailRequests; + TAtomic ConnectionCount = 0; + THolder<TSocketPoller> Poller; + THolder<TConnections> Connections; + bool ListenerRunningOK = false; + int ErrorCode = 0; + TOptions Options_; + ICallBack* Cb_ = nullptr; + THttpServer* Parent_ = nullptr; + TWakeupPollAble WakeupPollAble; + TMutex StopMutex; + +private: + template <class TThreadPool_> + static THolder<IThreadPool> MakeThreadPool(IThreadFactory* factory, bool elastic, ICallBack* callback = nullptr, const TString& threadName = {}) { + if (!factory) { + factory = SystemThreadFactory(); + } + + THolder<IThreadPool> pool; + const auto params = IThreadPool::TParams().SetFactory(factory).SetThreadName(threadName); + if (callback) { + pool = MakeHolder<TThreadPoolBinder<TThreadPool_, THttpServer::ICallBack>>(callback, params); + } else { + pool = MakeHolder<TThreadPool_>(params); + } + + if (elastic) { + pool = MakeHolder<TElasticQueue>(std::move(pool)); + } + + return pool; + } +}; + +THttpServer::THttpServer(ICallBack* cb, const TOptions& options, IThreadFactory* pool) + : Impl_(new TImpl(this, cb, options, pool)) +{ +} + +THttpServer::THttpServer(ICallBack* cb, TMtpQueueRef mainWorkers, TMtpQueueRef failWorkers, const TOptions& options) + : Impl_(new TImpl(this, cb, mainWorkers, failWorkers, options)) +{ +} + +THttpServer::~THttpServer() { +} + +i64 THttpServer::GetClientCount() const { + return Impl_->GetClientCount(); +} + +bool THttpServer::Start() { + return Impl_->Start(); +} + +void THttpServer::Stop() { + Impl_->Stop(); +} + +void THttpServer::Shutdown() { + Impl_->Shutdown(); +} + +void THttpServer::Wait() { + Impl_->Wait(); +} + +int THttpServer::GetErrorCode() { + return Impl_->GetErrorCode(); +} + +const char* THttpServer::GetError() { + return Impl_->GetError(); +} + +void THttpServer::RestartRequestThreads(ui32 n, ui32 queue) { + Impl_->RestartRequestThreads(n, queue); +} + +const THttpServer::TOptions& THttpServer::Options() const noexcept { + return Impl_->Options(); +} + +size_t THttpServer::GetRequestQueueSize() const { + return Impl_->GetRequestQueueSize(); +} + +size_t THttpServer::GetFailQueueSize() const { + return Impl_->GetFailQueueSize(); +} + +const IThreadPool& THttpServer::GetRequestQueue() const { + return Impl_->GetRequestQueue(); +} + +const IThreadPool& THttpServer::GetFailQueue() const { + return Impl_->GetFailQueue(); +} + +bool THttpServer::MaxRequestsReached() const { + return Impl_->MaxRequestsReached(); +} + +TClientConnection::TClientConnection(const TSocket& s, THttpServer::TImpl* serv, NAddr::IRemoteAddrRef listenerSockAddrRef) + : Socket_(s) + , ListenerSockAddrRef_(listenerSockAddrRef) + , HttpServ_(serv) +{ + SetNoDelay(Socket_, true); + + const TDuration& clientTimeout = HttpServ_->Options().ClientTimeout; + if (clientTimeout != TDuration::Zero()) { + SetSocketTimeout(Socket_, (long)clientTimeout.Seconds(), clientTimeout.MilliSecondsOfSecond()); + } + + HttpServ_->IncreaseConnections(); +} + +TClientConnection::~TClientConnection() { + HttpServ_->DecreaseConnections(); +} + +void TClientConnection::OnPollEvent(TInstant now) { + THolder<TClientConnection> this_(this); + Activate(now); + + { + char tmp[1]; + + if (::recv(Socket_, tmp, 1, MSG_PEEK) < 1) { + /* + * We can received a FIN so our socket was moved to + * TCP_CLOSE_WAIT state. Check it before adding work + * for this socket. + */ + + return; + } + } + + THolder<TClientRequest> obj(HttpServ_->CreateRequest(this_)); + AcceptMoment = now; + + HttpServ_->AddRequest(obj, Reject_); +} + +void TClientConnection::Activate(TInstant now) noexcept { + HttpServ_->Connections->Erase(this, now); + LastUsed = now; + ++ReceivedRequests; +} + +void TClientConnection::DeActivate() { + HttpServ_->Connections->Add(this); +} + +void TClientConnection::Reject() { + Reject_ = true; + + HttpServ_->Connections->Add(this); +} + +TClientRequest::TClientRequest() { +} + +TClientRequest::~TClientRequest() { +} + +bool TClientRequest::Reply(void* /*ThreadSpecificResource*/) { + if (strnicmp(RequestString.data(), "GET ", 4)) { + Output() << "HTTP/1.0 501 Not Implemented\r\n\r\n"; + } else { + Output() << "HTTP/1.0 200 OK\r\n" + "Content-Type: text/html\r\n" + "\r\n" + "Hello World!\r\n"; + } + + return true; +} + +bool TClientRequest::IsLocal() const { + return HasLocalAddress(Socket()); +} + +bool TClientRequest::CheckLoopback() { + bool isLocal = false; + + try { + isLocal = IsLocal(); + } catch (const yexception& e) { + Output() << "HTTP/1.0 500 Oops\r\n\r\n" + << e.what() << "\r\n"; + return false; + } + + if (!isLocal) { + Output() << "HTTP/1.0 403 Permission denied\r\n" + "Content-Type: text/html; charset=windows-1251\r\n" + "Connection: close\r\n" + "\r\n" + "<html><head><title>Permission denied</title></head>" + "<body><h1>Permission denied</h1>" + "<p>This request must be sent from the localhost.</p>" + "</body></html>\r\n"; + + return false; + } + + return true; +} + +void TClientRequest::ReleaseConnection() { + if (Conn_ && HttpConn_ && HttpServ()->Options().KeepAliveEnabled && HttpConn_->CanBeKeepAlive() && (!HttpServ()->Options().RejectExcessConnections || !HttpServ()->MaxRequestsReached())) { + Output().Finish(); + Conn_->DeActivate(); + Y_UNUSED(Conn_.Release()); + } +} + +void TClientRequest::ResetConnection() { + if (HttpConn_) { + // send RST packet to client + HttpConn_->Reset(); + HttpConn_.Destroy(); + } +} + +void TClientRequest::Process(void* ThreadSpecificResource) { + THolder<TClientRequest> this_(this); + + auto* serverImpl = Conn_->HttpServ_; + + try { + if (!HttpConn_) { + const size_t outputBufferSize = HttpServ()->Options().OutputBufferSize; + if (outputBufferSize) { + HttpConn_.Reset(new THttpServerConn(Socket(), outputBufferSize)); + } else { + HttpConn_.Reset(new THttpServerConn(Socket())); + } + + auto maxRequestsPerConnection = HttpServ()->Options().MaxRequestsPerConnection; + HttpConn_->Output()->EnableKeepAlive(HttpServ()->Options().KeepAliveEnabled && (!maxRequestsPerConnection || Conn_->ReceivedRequests < maxRequestsPerConnection)); + HttpConn_->Output()->EnableCompression(HttpServ()->Options().CompressionEnabled); + } + + if (ParsedHeaders.empty()) { + RequestString = Input().FirstLine(); + + const THttpHeaders& h = Input().Headers(); + ParsedHeaders.reserve(h.Count()); + for (THttpHeaders::TConstIterator it = h.Begin(); it != h.End(); ++it) { + ParsedHeaders.emplace_back(it->Name(), it->Value()); + } + } + + if (Reply(ThreadSpecificResource)) { + ReleaseConnection(); + + /* + * *this will be destroyed... + */ + + return; + } + } catch (...) { + serverImpl->Cb_->OnException(); + + throw; + } + + Y_UNUSED(this_.Release()); +} + +void TClientRequest::ProcessFailRequest(int failstate) { + Output() << "HTTP/1.1 503 Service Unavailable\r\n" + "Content-Type: text/plain\r\n" + "Content-Length: 21\r\n" + "\r\n" + "Service Unavailable\r\n"; + + TString url; + + if (!strnicmp(RequestString.data(), "GET ", 4)) { + // Trying to extract url... + const char* str = RequestString.data(); + + // Skipping spaces before url... + size_t start = 3; + while (str[start] == ' ') { + ++start; + } + + if (str[start]) { + // Traversing url... + size_t idx = start; + + while (str[idx] != ' ' && str[idx]) { + ++idx; + } + + url = RequestString.substr(start, idx - start); + } + } + + const THttpServer::ICallBack::TFailLogData d = { + failstate, + url, + }; + + // Handling failure... + Conn_->HttpServ_->Cb_->OnFailRequestEx(d); + Output().Flush(); +} + +THttpServer* TClientRequest::HttpServ() const noexcept { + return Conn_->HttpServ_->Parent_; +} + +const TSocket& TClientRequest::Socket() const noexcept { + return Conn_->Socket_; +} + +NAddr::IRemoteAddrRef TClientRequest::GetListenerSockAddrRef() const noexcept { + return Conn_->ListenerSockAddrRef_; +} + +TInstant TClientRequest::AcceptMoment() const noexcept { + return Conn_->AcceptMoment; +} + +/* + * TRequestReplier + */ +TRequestReplier::TRequestReplier() { +} + +TRequestReplier::~TRequestReplier() { +} + +bool TRequestReplier::Reply(void* threadSpecificResource) { + const TReplyParams params = { + threadSpecificResource, Input(), Output()}; + + return DoReply(params); +} + +bool TryToBindAddresses(const THttpServerOptions& options, const std::function<void(TSocket)>* callbackOnBoundAddress) { + THttpServerOptions::TBindAddresses addrs; + try { + options.BindAddresses(addrs); + } catch (const std::exception&) { + return false; + } + + for (const auto& na : addrs) { + for (TNetworkAddress::TIterator ai = na.Begin(); ai != na.End(); ++ai) { + NAddr::TAddrInfo addr(&*ai); + + TSocket socket(::socket(addr.Addr()->sa_family, SOCK_STREAM, 0)); + + if (socket == INVALID_SOCKET) { + return false; + } + + FixIPv6ListenSocket(socket); + + if (options.ReuseAddress) { + int yes = 1; + ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (const char*)&yes, sizeof(yes)); + } + + if (options.ReusePort) { + SetReusePort(socket, true); + } + + if (::bind(socket, addr.Addr(), addr.Len()) == SOCKET_ERROR) { + return false; + } + + if (::listen(socket, options.ListenBacklog) == SOCKET_ERROR) { + return false; + } + + if (callbackOnBoundAddress != nullptr) { + (*callbackOnBoundAddress)(socket); + } + } + } + + return true; +} |