#include "event_loop.h"
#include "network.h"
#include "thread_extra.h"
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/generic/hash.h>
#include <util/network/pair.h>
#include <util/network/poller.h>
#include <util/system/event.h>
#include <util/system/mutex.h>
#include <util/system/thread.h>
#include <util/system/yassert.h>
#include <util/thread/lfqueue.h>
#include <errno.h>
using namespace NEventLoop;
namespace {
enum ERunningState {
EVENT_LOOP_CREATED,
EVENT_LOOP_RUNNING,
EVENT_LOOP_STOPPED,
};
enum EOperation {
OP_READ = 1,
OP_WRITE = 2,
OP_READ_WRITE = OP_READ | OP_WRITE,
};
}
class TChannel::TImpl {
public:
TImpl(TEventLoop::TImpl* eventLoop, TSocket socket, TEventHandlerPtr, void* cookie);
~TImpl();
void EnableRead();
void DisableRead();
void EnableWrite();
void DisableWrite();
void Unregister();
SOCKET GetSocket() const;
TSocket GetSocketPtr() const;
void Update(int pollerFlags, bool enable);
void CallHandler();
TEventLoop::TImpl* EventLoop;
TSocket Socket;
TEventHandlerPtr EventHandler;
void* Cookie;
TMutex Mutex;
int CurrentFlags;
bool Close;
};
class TEventLoop::TImpl {
public:
TImpl(const char* name);
void Run();
void Wakeup();
void Stop();
TChannelPtr Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie);
void Unregister(SOCKET socket);
typedef THashMap<SOCKET, TChannelPtr> TData;
void AddToPoller(SOCKET socket, void* cookie, int flags);
TMutex Mutex;
const char* Name;
TAtomic RunningState;
TAtomic StopSignal;
TSystemEvent StoppedEvent;
TData Data;
TLockFreeQueue<SOCKET> SocketsToRemove;
TSocketPoller Poller;
TSocketHolder WakeupReadSocket;
TSocketHolder WakeupWriteSocket;
};
TChannel::~TChannel() {
}
void TChannel::EnableRead() {
Impl->EnableRead();
}
void TChannel::DisableRead() {
Impl->DisableRead();
}
void TChannel::EnableWrite() {
Impl->EnableWrite();
}
void TChannel::DisableWrite() {
Impl->DisableWrite();
}
void TChannel::Unregister() {
Impl->Unregister();
}
SOCKET TChannel::GetSocket() const {
return Impl->GetSocket();
}
TSocket TChannel::GetSocketPtr() const {
return Impl->GetSocketPtr();
}
TChannel::TChannel(TImpl* impl)
: Impl(impl)
{
}
TEventLoop::TEventLoop(const char* name)
: Impl(new TImpl(name))
{
}
TEventLoop::~TEventLoop() {
}
void TEventLoop::Run() {
Impl->Run();
}
void TEventLoop::Stop() {
Impl->Stop();
}
bool TEventLoop::IsRunning() {
return AtomicGet(Impl->RunningState) == EVENT_LOOP_RUNNING;
}
TChannelPtr TEventLoop::Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie) {
return Impl->Register(socket, eventHandler, cookie);
}
TChannel::TImpl::TImpl(TEventLoop::TImpl* eventLoop, TSocket socket, TEventHandlerPtr eventHandler, void* cookie)
: EventLoop(eventLoop)
, Socket(socket)
, EventHandler(eventHandler)
, Cookie(cookie)
, CurrentFlags(0)
, Close(false)
{
}
TChannel::TImpl::~TImpl() {
Y_ASSERT(Close);
}
void TChannel::TImpl::EnableRead() {
Update(OP_READ, true);
}
void TChannel::TImpl::DisableRead() {
Update(OP_READ, false);
}
void TChannel::TImpl::EnableWrite() {
Update(OP_WRITE, true);
}
void TChannel::TImpl::DisableWrite() {
Update(OP_WRITE, false);
}
void TChannel::TImpl::Unregister() {
TGuard<TMutex> guard(Mutex);
if (Close) {
return;
}
Close = true;
if (CurrentFlags != 0) {
EventLoop->Poller.Unwait(Socket);
CurrentFlags = 0;
}
EventHandler.Drop();
EventLoop->SocketsToRemove.Enqueue(Socket);
EventLoop->Wakeup();
}
void TChannel::TImpl::Update(int flags, bool enable) {
TGuard<TMutex> guard(Mutex);
if (Close) {
return;
}
int newFlags = enable ? (CurrentFlags | flags) : (CurrentFlags & ~flags);
if (CurrentFlags == newFlags) {
return;
}
if (!newFlags) {
EventLoop->Poller.Unwait(Socket);
} else {
void* cookie = reinterpret_cast<void*>(this);
EventLoop->AddToPoller(Socket, cookie, newFlags);
}
CurrentFlags = newFlags;
}
SOCKET TChannel::TImpl::GetSocket() const {
return Socket;
}
TSocket TChannel::TImpl::GetSocketPtr() const {
return Socket;
}
void TChannel::TImpl::CallHandler() {
TEventHandlerPtr handler;
{
TGuard<TMutex> guard(Mutex);
// other thread may have re-added socket to epoll
// so even if CurrentFlags is 0, epoll may fire again
// so please use non-blocking operations
CurrentFlags = 0;
if (Close) {
return;
}
handler = EventHandler;
}
if (!!handler) {
handler->HandleEvent(Socket, Cookie);
}
}
TEventLoop::TImpl::TImpl(const char* name)
: Name(name)
, RunningState(EVENT_LOOP_CREATED)
, StopSignal(0)
{
SOCKET wakeupSockets[2];
if (SocketPair(wakeupSockets) < 0) {
Y_ABORT("failed to create socket pair for wakeup sockets: %s", LastSystemErrorText());
}
TSocketHolder wakeupReadSocket(wakeupSockets[0]);
TSocketHolder wakeupWriteSocket(wakeupSockets[1]);
WakeupReadSocket.Swap(wakeupReadSocket);
WakeupWriteSocket.Swap(wakeupWriteSocket);
SetNonBlock(WakeupWriteSocket, true);
SetNonBlock(WakeupReadSocket, true);
Poller.WaitRead(WakeupReadSocket,
reinterpret_cast<void*>(this));
}
void TEventLoop::TImpl::Run() {
bool res = AtomicCas(&RunningState, EVENT_LOOP_RUNNING, EVENT_LOOP_CREATED);
Y_ABORT_UNLESS(res, "Invalid mbus event loop state");
if (!!Name) {
SetCurrentThreadName(Name);
}
while (AtomicGet(StopSignal) == 0) {
void* cookies[1024];
const size_t count = Poller.WaitI(cookies, Y_ARRAY_SIZE(cookies));
void** end = cookies + count;
for (void** c = cookies; c != end; ++c) {
TChannel::TImpl* s = reinterpret_cast<TChannel::TImpl*>(*c);
if (*c == this) {
char buf[0x1000];
if (NBus::NPrivate::SocketRecv(WakeupReadSocket, buf) < 0) {
Y_ABORT("failed to recv from wakeup socket: %s", LastSystemErrorText());
}
continue;
}
s->CallHandler();
}
SOCKET socket = -1;
while (SocketsToRemove.Dequeue(&socket)) {
TGuard<TMutex> guard(Mutex);
Y_ABORT_UNLESS(Data.erase(socket) == 1, "must be removed once");
}
}
{
TGuard<TMutex> guard(Mutex);
for (auto& it : Data) {
it.second->Unregister();
}
// release file descriptors
Data.clear();
}
res = AtomicCas(&RunningState, EVENT_LOOP_STOPPED, EVENT_LOOP_RUNNING);
Y_ABORT_UNLESS(res);
StoppedEvent.Signal();
}
void TEventLoop::TImpl::Stop() {
AtomicSet(StopSignal, 1);
if (AtomicGet(RunningState) == EVENT_LOOP_RUNNING) {
Wakeup();
StoppedEvent.WaitI();
}
}
TChannelPtr TEventLoop::TImpl::Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie) {
Y_ABORT_UNLESS(socket != INVALID_SOCKET, "must be a valid socket");
TChannelPtr channel = new TChannel(new TChannel::TImpl(this, socket, eventHandler, cookie));
TGuard<TMutex> guard(Mutex);
Y_ABORT_UNLESS(Data.insert(std::make_pair(socket, channel)).second, "must not be already inserted");
return channel;
}
void TEventLoop::TImpl::Wakeup() {
if (NBus::NPrivate::SocketSend(WakeupWriteSocket, TArrayRef<const char>("", 1)) < 0) {
if (LastSystemError() != EAGAIN) {
Y_ABORT("failed to send to wakeup socket: %s", LastSystemErrorText());
}
}
}
void TEventLoop::TImpl::AddToPoller(SOCKET socket, void* cookie, int flags) {
if (flags == OP_READ) {
Poller.WaitReadOneShot(socket, cookie);
} else if (flags == OP_WRITE) {
Poller.WaitWriteOneShot(socket, cookie);
} else if (flags == OP_READ_WRITE) {
Poller.WaitReadWriteOneShot(socket, cookie);
} else {
Y_ABORT("Wrong flags: %d", int(flags));
}
}