#include "event_loop.h"
 
#include "network.h"
#include "thread_extra.h" 
 
#include <util/generic/hash.h>
#include <util/network/pair.h>
#include <util/network/poller.h>
#include <util/system/event.h>
#include <util/system/mutex.h>
#include <util/system/thread.h>
#include <util/system/yassert.h>
#include <util/thread/lfqueue.h> 

#include <errno.h>
 
using namespace NEventLoop;

namespace {
    // Lifecycle of an event loop; stored in TEventLoop::TImpl::RunningState
    // and advanced with compare-and-swap in Run().
    enum ERunningState {
        EVENT_LOOP_CREATED = 0, // constructed, Run() has not been entered
        EVENT_LOOP_RUNNING = 1, // Run() is executing
        EVENT_LOOP_STOPPED = 2, // Run() has finished
    };

    // Bit flags describing which readiness events a channel is subscribed to.
    enum EOperation {
        OP_READ = 1 << 0,
        OP_WRITE = 1 << 1,
        OP_READ_WRITE = OP_READ | OP_WRITE,
    };
}

// Implementation object behind TChannel: one socket registered with an event
// loop, together with the handler that is invoked when the socket is ready.
class TChannel::TImpl {
public:
    TImpl(TEventLoop::TImpl* eventLoop, TSocket socket, TEventHandlerPtr, void* cookie);
    ~TImpl();

    // Add/remove OP_READ or OP_WRITE from the current poller subscription.
    void EnableRead();
    void DisableRead();
    void EnableWrite();
    void DisableWrite();

    // Closes the channel: stops polling, drops the handler and asks the loop
    // to forget the socket. Repeated calls do nothing.
    void Unregister();

    SOCKET GetSocket() const;
    TSocket GetSocketPtr() const;

    // Sets (enable == true) or clears the given EOperation bits and
    // re-registers the socket with the poller if the resulting mask changed.
    void Update(int pollerFlags, bool enable);
    // Called by the loop when the poller reports this channel; invokes the
    // handler unless the channel has been closed.
    void CallHandler();

    // Loop this channel was registered with (raw pointer).
    TEventLoop::TImpl* EventLoop;
    TSocket Socket;
    // Cleared by Unregister().
    TEventHandlerPtr EventHandler;
    // Opaque user value handed back to the handler in HandleEvent().
    void* Cookie;

    // Taken by Unregister(), Update() and CallHandler() around accesses to
    // CurrentFlags, Close and EventHandler.
    TMutex Mutex;

    // EOperation bits currently registered with the poller; 0 means none.
    int CurrentFlags;
    // True once Unregister() has run.
    bool Close;
};

// Implementation object behind TEventLoop: owns the poller, the registry of
// channels and the socket pair used to interrupt a blocked poll.
class TEventLoop::TImpl {
public:
    TImpl(const char* name);

    // Polls and dispatches until Stop() is requested; may be entered only
    // once (state must be EVENT_LOOP_CREATED).
    void Run();
    // Writes one byte to the wakeup socket so that a blocked poll returns.
    void Wakeup();
    // Requests termination and, if the loop is running, waits until Run()
    // has finished.
    void Stop();

    TChannelPtr Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie);
    void Unregister(SOCKET socket);

    typedef THashMap<SOCKET, TChannelPtr> TData;

    // Registers socket with the poller in one-shot mode for the given
    // EOperation mask.
    void AddToPoller(SOCKET socket, void* cookie, int flags);

    // Held while Data is modified or iterated (Register() and Run()).
    TMutex Mutex;

    // Thread name applied in Run() when non-null; the pointer is stored as
    // is, the string is not copied.
    const char* Name;

    // Holds an ERunningState value.
    TAtomic RunningState;
    // Set to 1 by Stop(); Run() loops while it is 0.
    TAtomic StopSignal;
    // Signalled at the very end of Run(); Stop() waits on it.
    TSystemEvent StoppedEvent;
    // Registered channels keyed by socket.
    TData Data;

    // Sockets queued by TChannel::TImpl::Unregister(); Run() drains this
    // queue and erases the matching entries from Data.
    TLockFreeQueue<SOCKET> SocketsToRemove;

    TSocketPoller Poller;
    // Socket pair created in the constructor: Wakeup() sends on the write
    // end, Run() polls and drains the read end.
    TSocketHolder WakeupReadSocket;
    TSocketHolder WakeupWriteSocket;
};

// Defined in this translation unit, where TChannel::TImpl is a complete type.
TChannel::~TChannel() = default;

// Forwards to TImpl::EnableRead(), which adds OP_READ to the channel's
// poller subscription.
void TChannel::EnableRead() {
    Impl->EnableRead();
}

// Forwards to TImpl::DisableRead(), which removes OP_READ from the channel's
// poller subscription.
void TChannel::DisableRead() {
    Impl->DisableRead();
}

// Forwards to TImpl::EnableWrite(), which adds OP_WRITE to the channel's
// poller subscription.
void TChannel::EnableWrite() {
    Impl->EnableWrite();
}

// Forwards to TImpl::DisableWrite(), which removes OP_WRITE from the
// channel's poller subscription.
void TChannel::DisableWrite() {
    Impl->DisableWrite();
}

// Forwards to TImpl::Unregister(): the channel is marked closed, its socket
// is no longer polled and the loop is asked to drop it from its registry.
// Calling this again on a closed channel does nothing.
void TChannel::Unregister() {
    Impl->Unregister();
}

// Returns the raw socket handle of this channel.
SOCKET TChannel::GetSocket() const {
    return Impl->GetSocket();
}

// Returns a copy of the TSocket object this channel was registered with.
TSocket TChannel::GetSocketPtr() const {
    return Impl->GetSocketPtr();
}
 
// Wraps an already constructed implementation object.
// NOTE(review): Impl presumably takes ownership of `impl` (its type is
// declared in event_loop.h) - confirm.
TChannel::TChannel(TImpl* impl)
    : Impl(impl)
{
}

// Creates the implementation object. `name` is passed through to TImpl, which
// uses it as the thread name in Run() when it is non-null.
TEventLoop::TEventLoop(const char* name)
    : Impl(new TImpl(name))
{
}

// Defined in this translation unit, where TEventLoop::TImpl is a complete type.
TEventLoop::~TEventLoop() = default;

// Runs the loop on the calling thread; returns only after Stop() has been
// requested. Forwards to TImpl::Run().
void TEventLoop::Run() {
    Impl->Run();
}

// Asks the loop to terminate; if it is currently running, blocks until Run()
// has finished. Forwards to TImpl::Stop().
void TEventLoop::Stop() {
    Impl->Stop();
}

// Tells whether Run() is currently executing: true only while the loop state
// is EVENT_LOOP_RUNNING, i.e. neither before Run() nor after it has finished.
bool TEventLoop::IsRunning() {
    const auto state = AtomicGet(Impl->RunningState);
    return EVENT_LOOP_RUNNING == state;
}
 
// Creates a channel for `socket` and records it in the loop's registry.
// The new channel has no poller subscription yet; use the channel's
// EnableRead()/EnableWrite() to start receiving events. `cookie` is handed
// back to `eventHandler` on every event. Forwards to TImpl::Register().
TChannelPtr TEventLoop::Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie) {
    return Impl->Register(socket, eventHandler, cookie);
}

// Stores the loop, socket, handler and cookie. A new channel starts open
// (Close == false) and with no poller subscription (CurrentFlags == 0).
TChannel::TImpl::TImpl(TEventLoop::TImpl* eventLoop, TSocket socket, TEventHandlerPtr eventHandler, void* cookie)
    : EventLoop(eventLoop)
    , Socket(socket)
    , EventHandler(eventHandler)
    , Cookie(cookie)
    , CurrentFlags(0)
    , Close(false)
{
}

// Asserts (via Y_ASSERT) that Unregister() was called before the channel is
// destroyed.
TChannel::TImpl::~TImpl() {
    Y_ASSERT(Close);
}
 
// Adds OP_READ to the poller subscription.
void TChannel::TImpl::EnableRead() {
    Update(OP_READ, true);
}

// Removes OP_READ from the poller subscription.
void TChannel::TImpl::DisableRead() {
    Update(OP_READ, false);
}

// Adds OP_WRITE to the poller subscription.
void TChannel::TImpl::EnableWrite() {
    Update(OP_WRITE, true);
}

// Removes OP_WRITE from the poller subscription.
void TChannel::TImpl::DisableWrite() {
    Update(OP_WRITE, false);
}

void TChannel::TImpl::Unregister() {
    TGuard<TMutex> guard(Mutex); 
 
    if (Close) { 
        return; 
    } 
 
    Close = true; 
    if (CurrentFlags != 0) { 
        EventLoop->Poller.Unwait(Socket); 
        CurrentFlags = 0; 
    } 
    EventHandler.Drop(); 
 
    EventLoop->SocketsToRemove.Enqueue(Socket); 
    EventLoop->Wakeup(); 
}

void TChannel::TImpl::Update(int flags, bool enable) { 
    TGuard<TMutex> guard(Mutex); 
 
    if (Close) { 
        return; 
    } 
 
    int newFlags = enable ? (CurrentFlags | flags) : (CurrentFlags & ~flags); 
 
    if (CurrentFlags == newFlags) { 
        return; 
    } 
 
    if (!newFlags) { 
        EventLoop->Poller.Unwait(Socket); 
    } else { 
        void* cookie = reinterpret_cast<void*>(this); 
        EventLoop->AddToPoller(Socket, cookie, newFlags); 
    } 
 
    CurrentFlags = newFlags; 
} 
 
// Returns the channel's socket as a raw SOCKET handle.
SOCKET TChannel::TImpl::GetSocket() const {
    return Socket;
}

// Returns a copy of the TSocket object held by the channel.
TSocket TChannel::TImpl::GetSocketPtr() const {
    return Socket;
}
 
void TChannel::TImpl::CallHandler() { 
    TEventHandlerPtr handler; 
 
    { 
        TGuard<TMutex> guard(Mutex); 
 
        // other thread may have re-added socket to epoll 
        // so even if CurrentFlags is 0, epoll may fire again 
        // so please use non-blocking operations 
        CurrentFlags = 0; 
 
        if (Close) { 
            return; 
        } 
 
        handler = EventHandler; 
    } 
 
    if (!!handler) { 
        handler->HandleEvent(Socket, Cookie); 
    } 
} 
 
// Builds a loop in the EVENT_LOOP_CREATED state and sets up the wakeup
// channel: a non-blocking socket pair whose read end is registered with the
// poller, using `this` as the cookie so Run() can recognise it.
// Aborts via Y_FAIL if the socket pair cannot be created.
TEventLoop::TImpl::TImpl(const char* name)
    : Name(name)
    , RunningState(EVENT_LOOP_CREATED)
    , StopSignal(0)
{
    SOCKET pairFds[2];

    if (SocketPair(pairFds) < 0) {
        Y_FAIL("failed to create socket pair for wakeup sockets: %s", LastSystemErrorText());
    }

    {
        // Wrap both descriptors in holders first, then move them into the
        // members.
        TSocketHolder readEnd(pairFds[0]);
        TSocketHolder writeEnd(pairFds[1]);

        WakeupReadSocket.Swap(readEnd);
        WakeupWriteSocket.Swap(writeEnd);
    }

    SetNonBlock(WakeupWriteSocket, true);
    SetNonBlock(WakeupReadSocket, true);

    Poller.WaitRead(WakeupReadSocket, static_cast<void*>(this));
}

void TEventLoop::TImpl::Run() {
    bool res = AtomicCas(&RunningState, EVENT_LOOP_RUNNING, EVENT_LOOP_CREATED);
    Y_VERIFY(res, "Invalid mbus event loop state");
 
    if (!!Name) { 
        SetCurrentThreadName(Name);
    } 
 
    while (AtomicGet(StopSignal) == 0) {
        void* cookies[1024];
        const size_t count = Poller.WaitI(cookies, Y_ARRAY_SIZE(cookies));

        void** end = cookies + count;
        for (void** c = cookies; c != end; ++c) {
            TChannel::TImpl* s = reinterpret_cast<TChannel::TImpl*>(*c); 

            if (*c == this) { 
                char buf[0x1000];
                if (NBus::NPrivate::SocketRecv(WakeupReadSocket, buf) < 0) { 
                    Y_FAIL("failed to recv from wakeup socket: %s", LastSystemErrorText());
                } 
                continue;
            }

            s->CallHandler(); 
        }
 
        SOCKET socket = -1;
        while (SocketsToRemove.Dequeue(&socket)) {
            TGuard<TMutex> guard(Mutex); 
            Y_VERIFY(Data.erase(socket) == 1, "must be removed once");
        } 
    }
 
    { 
        TGuard<TMutex> guard(Mutex); 
        for (auto& it : Data) {
            it.second->Unregister();
        } 
 
        // release file descriptors 
        Data.clear(); 
    } 
 
    res = AtomicCas(&RunningState, EVENT_LOOP_STOPPED, EVENT_LOOP_RUNNING);
 
    Y_VERIFY(res);

    StoppedEvent.Signal(); 
}

// Raises the stop flag. If the loop is running at that moment, wakes it up
// and blocks until Run() signals that it has finished; otherwise returns
// immediately.
void TEventLoop::TImpl::Stop() {
    AtomicSet(StopSignal, 1);

    if (AtomicGet(RunningState) != EVENT_LOOP_RUNNING) {
        return;
    }

    // Interrupt a poll that may be blocked, then wait for Run() to exit.
    Wakeup();
    StoppedEvent.WaitI();
}

// Creates a channel for `socket` and stores it in the registry under the
// socket handle. Aborts if the socket is invalid or already registered.
// The channel is not added to the poller here.
TChannelPtr TEventLoop::TImpl::Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie) {
    Y_VERIFY(socket != INVALID_SOCKET, "must be a valid socket");

    TChannel::TImpl* const impl = new TChannel::TImpl(this, socket, eventHandler, cookie);
    TChannelPtr channel = new TChannel(impl);

    {
        TGuard<TMutex> lock(Mutex);

        Y_VERIFY(Data.insert(std::make_pair(socket, channel)).second, "must not be already inserted");
    }

    return channel;
}

// Interrupts a poll blocked in Run() by sending a single byte to the wakeup
// socket pair. A "would block" failure is not an error: the write end is
// non-blocking, so it only means wakeup bytes are already queued and the loop
// will wake up anyway. Any other send failure aborts via Y_FAIL.
void TEventLoop::TImpl::Wakeup() {
    if (NBus::NPrivate::SocketSend(WakeupWriteSocket, TArrayRef<const char>("", 1)) >= 0) {
        return;
    }

    const int err = LastSystemError();

    bool wouldBlock = (err == EAGAIN);
#if defined(EWOULDBLOCK) && (EWOULDBLOCK != EAGAIN)
    // POSIX allows a non-blocking send() to report either code; where they
    // are distinct values both must be accepted.
    wouldBlock = wouldBlock || (err == EWOULDBLOCK);
#endif

    if (!wouldBlock) {
        Y_FAIL("failed to send to wakeup socket: %s", LastSystemErrorText());
    }
}

// Registers `socket` with the poller in one-shot mode for the events named by
// `flags`, which must be exactly OP_READ, OP_WRITE or OP_READ_WRITE. Any
// other value aborts via Y_FAIL.
void TEventLoop::TImpl::AddToPoller(SOCKET socket, void* cookie, int flags) {
    switch (flags) {
        case OP_READ:
            Poller.WaitReadOneShot(socket, cookie);
            break;
        case OP_WRITE:
            Poller.WaitWriteOneShot(socket, cookie);
            break;
        case OP_READ_WRITE:
            Poller.WaitReadWriteOneShot(socket, cookie);
            break;
        default:
            Y_FAIL("Wrong flags: %d", int(flags));
    }
}