#include "rpc.h"
#include "rq.h"
#include "multi.h"
#include "location.h"
#include <library/cpp/threading/thread_local/thread_local.h>
#include <util/generic/hash.h>
#include <util/thread/factory.h>
#include <util/system/spinlock.h>
using namespace NNeh;
namespace {
// One registration: the string a service was added under, paired with its handler.
using TServiceDescr = std::pair<TString, IServiceRef>;
// Sequence of registrations; TServices derives from it and uses it as its storage.
using TServicesBase = TVector<TServiceDescr>;
// Holds the registered (address, service) pairs and routes incoming requests
// to them. Requests arrive through IOnRequest::OnRequest() and are either
// served on the delivering thread (Listen mode) or pushed to a request queue
// that worker threads drain (Loop / ForkLoop modes).
class TServices: public TServicesBase, public TThrRefBase, public IOnRequest {
    // Service name -> handler. Keys are non-owning TStringBuf views.
    // NOTE(review): the keys are built from TParsedLocation(it.first).Service,
    // so they presumably point into the address strings kept in the base
    // vector - confirm those views stay valid when Add() grows the vector.
    using TSrvs = THashMap<TStringBuf, IServiceRef>;

    // Per-thread lookup cache plus the registry version it was built from.
    // Version starts at 0 while SelfVersion_ starts at 1, so the first lookup
    // miss always rebuilds the cache.
    struct TVersionedServiceMap {
        TSrvs Srvs;
        i64 Version = 0;
    };

    // Worker body: drains the parent's request queue until a null request is seen.
    struct TFunc: public IThreadFactory::IThreadAble {
        inline explicit TFunc(TServices* parent)
            : Parent(parent)
        {
        }

        void DoExecute() override {
            TThread::SetCurrentThreadName("NehTFunc");
            // Each worker keeps its own cache, so lookups need no locking.
            TVersionedServiceMap mp;

            while (true) {
                IRequestRef req = Parent->RQ_->Next();

                if (!req) {
                    // A null request is the stop marker posted by Stop().
                    break;
                }

                Parent->ServeRequest(mp, req);
            }

            // Put the stop marker back; presumably so that the other workers
            // reading the same queue observe it as well.
            Parent->RQ_->Schedule(nullptr);
        }

        TServices* Parent;
    };

public:
    // Creates a registry with a fresh request queue and no request check.
    inline TServices()
        : RQ_(CreateRequestQueue())
    {
    }

    // Same as above, but every incoming request is first passed to `check`.
    inline explicit TServices(TCheck check)
        : RQ_(CreateRequestQueue())
        , C_(check)
    {
    }

    inline ~TServices() override {
        LF_.Destroy();
    }

    // Registers `srv` under `service` and bumps the registry version so that
    // cached lookup tables are rebuilt on their next miss.
    inline void Add(const TString& service, IServiceRef srv) {
        TGuard<TSpinLock> guard(L_);

        push_back(std::make_pair(service, srv));
        AtomicIncrement(SelfVersion_);
    }

    // Starts receiving requests without worker threads: OnRequest() serves
    // each request directly. Fails the Y_ENSURE if Loop()/ForkLoop() was
    // used on this object before.
    inline void Listen() {
        Y_ENSURE(!HasLoop_ || !*HasLoop_);
        HasLoop_ = false;
        RR_ = MultiRequester(ListenAddrs(), this);
    }

    // Serves requests with `threads` workers in total: threads - 1 spawned
    // threads plus the calling thread. Returns once all workers have left
    // their loop (see Stop()). Fails the Y_ENSURE if Listen() was used before.
    inline void Loop(size_t threads) {
        Y_ENSURE(!HasLoop_ || *HasLoop_);
        HasLoop_ = true;

        // Hold a reference to ourselves for the duration of the loop.
        TIntrusivePtr<TServices> self(this);
        IRequesterRef rr = MultiRequester(ListenAddrs(), this);
        TFunc func(this);

        using IThreadRef = TAutoPtr<IThreadFactory::IThread>;
        TVector<IThreadRef> thrs;

        try {
            // Start at 1: the calling thread is the remaining worker.
            for (size_t i = 1; i < threads; ++i) {
                thrs.push_back(SystemThreadFactory()->Run(&func));
            }

            func.Execute();
        } catch (...) {
            // The spawned workers hold a pointer to the stack-local `func`.
            // If spawning (or the inline worker) throws, stop and join them
            // before `func` is unwound; this mirrors what TLoopFunc does when
            // its own thread creation fails.
            Stop();

            for (const auto& thr : thrs) {
                thr->Join();
            }
            RQ_->Clear();

            throw;
        }

        for (const auto& thr : thrs) {
            thr->Join();
        }
        RQ_->Clear();
    }

    // Starts `threads` background workers and returns; the workers and the
    // requester are owned by LF_ until SyncStopFork() or destruction.
    inline void ForkLoop(size_t threads) {
        Y_ENSURE(!HasLoop_ || *HasLoop_);
        HasLoop_ = true;
        // Binding the port(s) may fail here, so expect exceptions.
        IRequesterRef rr = MultiRequester(ListenAddrs(), this);
        LF_.Reset(new TLoopFunc(this, threads, rr));
    }

    // Posts the null stop marker to the request queue; does not wait.
    inline void Stop() {
        RQ_->Schedule(nullptr);
    }

    // Stops the background workers started by ForkLoop(), waits for them,
    // then clears the queue and releases the loop object.
    inline void SyncStopFork() {
        Stop();
        if (LF_) {
            LF_->SyncStop();
        }
        RQ_->Clear();
        LF_.Destroy();
    }

    // Entry point for incoming requests.
    void OnRequest(IRequestRef req) override {
        if (C_) {
            // The check rejects a request by yielding an error to send back.
            if (auto error = C_(req)) {
                req->SendError(*error);
                return;
            }
        }

        if (!*HasLoop_) {
            // Listen mode: serve right here using this thread's cached table.
            ServeRequest(LocalMap_.GetRef(), req);
        } else {
            // Loop modes: hand the request over to the workers.
            RQ_->Schedule(req);
        }
    }

private:
    // Background variant of TFunc: owns its worker threads and keeps the
    // requester referenced for as long as it exists.
    class TLoopFunc: public TFunc {
    public:
        TLoopFunc(TServices* parent, size_t threads, IRequesterRef& rr)
            : TFunc(parent)
            , RR_(rr)
        {
            T_.reserve(threads);

            try {
                for (size_t i = 0; i < threads; ++i) {
                    T_.push_back(SystemThreadFactory()->Run(this));
                }
            } catch (...) {
                // Paranoid mode: stop whatever was already started.
                SyncStop();

                throw;
            }
        }

        ~TLoopFunc() override {
            try {
                SyncStop();
            } catch (...) {
                // Never let an exception escape the destructor; just log it.
                Cdbg << TStringBuf("neh rpc ~loop_func: ") << CurrentExceptionMessage() << Endl;
            }
        }

        // Asks the workers to stop and joins them. No-op when there are no
        // threads left to join, so calling it repeatedly is harmless.
        void SyncStop() {
            if (!T_) {
                return;
            }

            Parent->Stop();

            for (const auto& thr : T_) {
                thr->Join();
            }
            T_.clear();
        }

    private:
        using IThreadRef = TAutoPtr<IThreadFactory::IThread>;
        TVector<IThreadRef> T_;
        IRequesterRef RR_;
    };

    // Dispatches `req` to the service matching its name, using (and, on a
    // miss, refreshing) the caller-supplied cache `mp`. Null requests are
    // ignored.
    inline void ServeRequest(TVersionedServiceMap& mp, IRequestRef req) {
        if (!req) {
            return;
        }

        const TStringBuf name = req->Service();
        TSrvs::const_iterator it = mp.Srvs.find(name);

        if (Y_UNLIKELY(it == mp.Srvs.end())) {
            // The cache may be stale: rebuild it if the registry has changed.
            if (UpdateServices(mp.Srvs, mp.Version)) {
                it = mp.Srvs.find(name);
            }
        }

        if (Y_UNLIKELY(it == mp.Srvs.end())) {
            // Fall back to the catch-all service, if one is registered.
            it = mp.Srvs.find(TStringBuf("*"));
        }

        if (Y_UNLIKELY(it == mp.Srvs.end())) {
            req->SendError(IRequest::NotExistService);
        } else {
            try {
                it->second->ServeRequest(req);
            } catch (...) {
                // Handler failures are only logged; no error is sent from here.
                Cdbg << CurrentExceptionMessage() << Endl;
            }
        }
    }

    // Rebuilds `srvs` from the registry when `version` is out of date.
    // Returns true if a rebuild happened, false if the cache was current.
    inline bool UpdateServices(TSrvs& srvs, i64& version) const {
        if (AtomicGet(SelfVersion_) == version) {
            return false;
        }

        srvs.clear();

        TGuard<TSpinLock> guard(L_);

        // For duplicate service names the later registration wins.
        for (const auto& it : *this) {
            srvs[TParsedLocation(it.first).Service] = it.second;
        }
        // Read under the lock so the version matches what was just copied.
        version = AtomicGet(SelfVersion_);

        return true;
    }

    // Returns a snapshot of all registered addresses, taken under the lock.
    inline TListenAddrs ListenAddrs() const {
        TListenAddrs addrs;

        {
            TGuard<TSpinLock> guard(L_);

            for (const auto& it : *this) {
                addrs.push_back(it.first);
            }
        }

        return addrs;
    }

    TSpinLock L_;                 // guards the base vector of registrations
    IRequestQueueRef RQ_;         // queue between OnRequest() and the workers
    THolder<TLoopFunc> LF_;       // background loop created by ForkLoop()
    TAtomic SelfVersion_ = 1;     // incremented on every Add()
    TCheck C_;                    // optional per-request check, may be empty
    NThreading::TThreadLocalValue<TVersionedServiceMap> LocalMap_; // caches for Listen mode
    IRequesterRef RR_;            // requester kept alive in Listen mode
    TMaybe<bool> HasLoop_;        // unset until a mode is chosen; false = Listen, true = Loop/ForkLoop
};
// IServices implementation that forwards every call to a reference-counted
// TServices instance.
class TServicesFace: public IServices {
public:
    // Facade over a TServices built without a request check.
    TServicesFace()
        : S_(new TServices())
    {
    }

    // Facade over a TServices built with `check`.
    TServicesFace(TCheck check)
        : S_(new TServices(check))
    {
    }

    // --- registration ---

    void DoAdd(const TString& service, IServiceRef srv) override {
        S_->Add(service, srv);
    }

    // --- starting ---

    void Listen() override {
        S_->Listen();
    }

    void Loop(size_t threads) override {
        S_->Loop(threads);
    }

    void ForkLoop(size_t threads) override {
        S_->ForkLoop(threads);
    }

    // --- stopping ---

    void Stop() override {
        S_->Stop();
    }

    void SyncStopFork() override {
        S_->SyncStopFork();
    }

private:
    // The object that does the actual work.
    TIntrusivePtr<TServices> S_;
};
}
// Wraps a plain service function into an IService object so it can be
// registered like any other service.
IServiceRef NNeh::Wrap(const TServiceFunction& func) {
    // Adapter: stores a copy of the function and invokes it for each request.
    struct TFunctionService: public IService {
        inline TFunctionService(const TServiceFunction& callback)
            : Callback(callback)
        {
        }

        void ServeRequest(const IRequestRef& request) override {
            Callback(request);
        }

        TServiceFunction Callback;
    };

    return new TFunctionService(func);
}
// Creates the services container with no request check installed.
IServicesRef NNeh::CreateLoop() {
    return IServicesRef(new TServicesFace());
}
// Creates the services container; `check` is handed to the underlying
// implementation, which consults it for every incoming request.
IServicesRef NNeh::CreateLoop(TCheck check) {
    return IServicesRef(new TServicesFace(check));
}