#pragma once

#include "net_test.h"
#include "net_queue_stat.h"

#include <util/system/spinlock.h>

namespace NNetliba {
    // Lower bound for the RTT deviation reported by TPingTracker::GetRTTSKO().
    const float MIN_PACKET_RTT_SKO = 0.001f; // avoid drops due to small hiccups

    // RTT substituted in TCongestionControl::UpdateAlive() while its own averaged RTT is still 0 (no samples yet).
    const float CONG_CTRL_INITIAL_RTT = 0.24f; //0.01f; // taking into account Las Vegas 10ms estimate is too optimistic

    // Success(): Window grows by sqrt(Window / CONG_CTRL_CHANNEL_INFLATE) * CONG_CTRL_WINDOW_GROW.
    const float CONG_CTRL_WINDOW_GROW = 0.005f;
    // Failure(): Window loses the share (1 - CONG_CTRL_WINDOW_SHRINK) / CONG_CTRL_CHANNEL_INFLATE.
    const float CONG_CTRL_WINDOW_SHRINK = 0.9f;
    // RegisterRTT(): same as above, applied after a run of above-average RTT samples.
    const float CONG_CTRL_WINDOW_SHRINK_RTT = 0.95f;
    // RegisterRTT(): weight of the previous average when mixing a new sample into AvrgRTT.
    const float CONG_CTRL_RTT_MIX_RATE = 0.9f;
    // RegisterRTT(): number of consecutive above-average RTT samples that triggers the RTT-based shrink.
    const int CONG_CTRL_RTT_SEQ_COUNT = 8;
    // Window is never reduced below this value.
    const float CONG_CTRL_MIN_WINDOW = 0.01f;
    // TimeWindow value used when the MaxPacketRate limit is not applied.
    const float CONG_CTRL_LARGE_TIME_WINDOW = 10000.0f;
    // UpdateAlive(): TimeWindow is capped at this period times MaxPacketRate.
    const float CONG_CTRL_TIME_WINDOW_LIMIT_PERIOD = 0.4f; // in seconds
    // UpdateAlive(): Window is raised whenever the ping tracker RTT exceeds this value times Window.
    const float CONG_CTRL_MINIMAL_SEND_INTERVAL = 1;
    // Failure(): the window is adjusted only if TimeSinceLastFail exceeds this value.
    const float CONG_CTRL_MIN_FAIL_INTERVAL = 0.001f;
    // Part of Window that is not reserved by VirtualPackets, i.e. may be sent at once on transmit start.
    const float CONG_CTRL_ALLOWED_BURST_SIZE = 3;
    // UpdateAlive(): virtual packets are created during idle only if AvrgRTT is above this value.
    const float CONG_CTRL_MIN_RTT_FOR_BURST_REDUCTION = 0.002f;

    // TLameMTUDiscovery::IsTimedOut() reports true once the total stepped time exceeds this value.
    const float LAME_MTU_TIMEOUT = 0.3f;
    // TLameMTUDiscovery::Step() requests a new ping once this much time has passed since the last one.
    const float LAME_MTU_INTERVAL = 0.05f;

    // UpdateAlive() runs the port-unreachable test only while TimeSinceLastRecv lies between these two delays.
    const float START_CHECK_PORT_DELAY = 0.5;
    const float FINISH_CHECK_PORT_DELAY = 10;
    // UpdateAlive() creates a new port tester only while ActivePortTestersCount is below this limit.
    const int N_PORT_TEST_COUNT_LIMIT = 256; // or 512

    // If enabled, all acks are sent with a different TOS, so they end up in a different queue.
    // This allows us to limit the window based on the minimal RTT observed and a 1G link assumption:
    // TCongestionControl::Success() clamps Window to MaxWindow and GetMaxWindow() reports MaxWindow
    // only when this flag is set.
    extern bool UseTOSforAcks;

    // Holds running RTT statistics and derives a packet timeout from them.
    // The statistics themselves are updated by RegisterRTT() / IncreaseRTT(),
    // which (like the constructor) are defined outside this header.
    class TPingTracker {
        float AvrgRTT;  // current RTT estimate
        float AvrgRTT2; // second RTT statistic, compared against Sqr(AvrgRTT) in GetRTTSKO()
        float RTTCount;

    public:
        TPingTracker();

        // Current RTT estimate.
        float GetRTT() const {
            return AvrgRTT;
        }

        // RTT deviation estimate; never lower than MIN_PACKET_RTT_SKO or 5% of the RTT estimate.
        float GetRTTSKO() const {
            const float floorSKO = Max(MIN_PACKET_RTT_SKO, AvrgRTT * 0.05f);
            const float measuredSKO = sqrt(fabs(Sqr(AvrgRTT) - AvrgRTT2));
            return Max(floorSKO, measuredSKO);
        }

        // Timeout for a packet: RTT estimate plus three deviations.
        float GetTimeout() const {
            return AvrgRTT + 3 * GetRTTSKO();
        }

        void RegisterRTT(float rtt);
        void IncreaseRTT();
    };

    // Defined outside this header. TCongestionControl::UpdateAlive() takes the low 10 bits of the
    // result to pick WindowFraction; presumably a pseudo-random generator - NOTE(review): confirm.
    ui32 NetAckRnd();

    // Timer-driven state for MTU discovery pings.
    // Step() accumulates time; a ping may be sent whenever the state is NEED_PING,
    // and the state returns to NEED_PING once LAME_MTU_INTERVAL has elapsed since the last ping.
    class TLameMTUDiscovery: public TThrRefBase {
        enum EState {
            NEED_PING,
            WAIT,
        };

        float TimePassed = 0;        // total time fed into Step()
        float TimeSinceLastPing = 0; // time fed into Step() since the last PingSent()
        EState State = NEED_PING;

    public:
        TLameMTUDiscovery() = default;

        // True when a ping is due.
        bool CanSend() {
            return NEED_PING == State;
        }

        // Records that a ping has just been sent and restarts the ping interval.
        void PingSent() {
            TimeSinceLastPing = 0;
            State = WAIT;
        }

        // True once the total stepped time exceeds LAME_MTU_TIMEOUT.
        bool IsTimedOut() const {
            return LAME_MTU_TIMEOUT < TimePassed;
        }

        // Advances both timers by deltaT and re-arms the ping when the interval has elapsed.
        void Step(float deltaT) {
            TimeSinceLastPing += deltaT;
            TimePassed += deltaT;
            const bool pingIsDue = TimeSinceLastPing > LAME_MTU_INTERVAL;
            if (pingIsDue) {
                State = NEED_PING;
            }
        }
    };

    // IPeerQueueStats implementation backed by a plain counter.
    // The counter is written by TCongestionControl::AttachQueueStats() and TCongestionControlPtr.
    struct TPeerQueueStats: public IPeerQueueStats {
        int Count = 0;

        TPeerQueueStats() = default;

        // Returns the last stored counter value.
        int GetPacketCount() override {
            const int current = Count;
            return current;
        }
    };

    // Pretend we have multiple channels in parallel.
    // This is not an exact approximation, since N channels should have N distinct windows.
    // In TCongestionControl the value multiplies the initial Window and divides the amounts
    // by which Window grows and shrinks.
    extern float CONG_CTRL_CHANNEL_INFLATE;

    // Window-based send limiter and liveness tracker.
    // - CanSend() admits a packet while VirtualPackets + PacketsInFly + WindowFraction fit into Window
    //   and the rate budget (TimeWindow) is positive.
    // - LaunchPacket() / Success() / Failure() / FailureOnSend() account for packets in flight and
    //   grow or shrink Window.
    // - RegisterRTT() feeds RTT samples; a run of above-average samples shrinks Window.
    // - UpdateAlive() advances all time-dependent state and decides whether the peer is still alive.
    class TCongestionControl: public TThrRefBase {
        // Window: current send window; PacketsInFly: launched and not yet resolved packets;
        // FailRate: smoothed share of failures (x0.99 on success, x0.99 + 0.01 on failure).
        float Window, PacketsInFly, FailRate;
        // MinRTT: smallest RTT registered so far; MaxWindow: cap derived from MinRTT and MTU.
        float MinRTT, MaxWindow;
        // FullSpeed: set by CanSend() when the window was exhausted, cleared at the end of UpdateAlive().
        // DoCountTime: set by HasTriedToSend(), cleared at the end of UpdateAlive().
        bool FullSpeed, DoCountTime;
        TPingTracker PingTracker;
        // Time accumulated by UpdateAlive() since the last MarkAlive(); 1e6 marks a killed peer.
        double TimeSinceLastRecv;
        // PortTester is read and written under PortTesterLock.
        TAdaptiveLock PortTesterLock;
        TIntrusivePtr<TPortUnreachableTester> PortTester;
        // Maintained by TCongestionControlPtr (one unit per live handle).
        int ActiveTransferCount;
        // Smoothed RTT kept by RegisterRTT(); 0 until the first sample arrives.
        float AvrgRTT;
        int HighRTTCounter;
        // WindowFraction: random share of a packet added to the window usage in CanSend();
        // FractionRecalc: time accumulated since WindowFraction was last re-rolled.
        float WindowFraction, FractionRecalc;
        // Packet budget for the MaxPacketRate limit; each launched packet consumes 1.
        float TimeWindow;
        double TimeSinceLastFail;
        // Part of the window that is reserved (not usable) to avoid a burst on transmit start.
        float VirtualPackets;
        // 0 means the MTU is not known yet.
        int MTU;
        TIntrusivePtr<TLameMTUDiscovery> MTUDiscovery;
        TIntrusivePtr<TPeerQueueStats> QueueStats;

        // Recomputes MaxWindow from MTU and MinRTT; does nothing while the MTU is unknown.
        // NOTE(review): 125000000 is presumably bytes per second of the assumed 1G link
        // (see the comment on UseTOSforAcks) - confirm. The division by MTU is an integer division.
        void CalcMaxWindow() {
            if (MTU == 0)
                return;
            MaxWindow = 125000000 / MTU * Max(0.001f, MinRTT);
        }

    public:
        static float StartWindowSize, MaxPacketRate;

    public:
        // Starts with a window of StartWindowSize * CONG_CTRL_CHANNEL_INFLATE, nothing in flight
        // and an unknown MTU.
        TCongestionControl()
            : Window(StartWindowSize * CONG_CTRL_CHANNEL_INFLATE)
            , PacketsInFly(0)
            , FailRate(0)
            , MinRTT(10)
            , MaxWindow(10000)
            , FullSpeed(false)
            , DoCountTime(false)
            , TimeSinceLastRecv(0)
            , ActiveTransferCount(0)
            , AvrgRTT(0)
            , HighRTTCounter(0)
            , WindowFraction(0)
            , FractionRecalc(0)
            , TimeWindow(CONG_CTRL_LARGE_TIME_WINDOW)
            , TimeSinceLastFail(0)
            , MTU(0)
        {
            // reserve everything except the allowed burst, so a new transfer does not start with a full window
            VirtualPackets = Max(Window - CONG_CTRL_ALLOWED_BURST_SIZE, 0.f);
        }
        // Returns true if one more packet may be launched now.
        // Side effect: sets FullSpeed when the window (not the rate budget) is what blocks sending.
        bool CanSend() {
            bool res = VirtualPackets + PacketsInFly + WindowFraction <= Window;
            FullSpeed |= !res;
            res &= TimeWindow > 0;
            return res;
        }
        // Accounts for one packet that has just been sent.
        void LaunchPacket() {
            PacketsInFly += 1.0f;
            TimeWindow -= 1.0f;
        }
        // Registers an RTT sample. Negative samples are ignored, others are clamped to [0.0001, 1].
        // Updates MinRTT (and MaxWindow when the MTU is known), the ping tracker and AvrgRTT, and
        // shrinks Window after CONG_CTRL_RTT_SEQ_COUNT above-average samples while at full speed.
        void RegisterRTT(float RTT) {
            if (RTT < 0)
                return;
            RTT = ClampVal(RTT, 0.0001f, 1.0f);
            if (RTT < MinRTT && MTU != 0) {
                MinRTT = RTT;
                CalcMaxWindow();
            }
            // also track the minimum while the MTU is unknown; SetMTU() recomputes MaxWindow later
            MinRTT = Min(MinRTT, RTT);

            PingTracker.RegisterRTT(RTT);
            // the first sample seeds the average
            if (AvrgRTT == 0)
                AvrgRTT = RTT;
            if (RTT > AvrgRTT) {
                ++HighRTTCounter;
                if (HighRTTCounter >= CONG_CTRL_RTT_SEQ_COUNT) {
                    //printf("Too many high RTT in a row\n");
                    if (FullSpeed) {
                        float windowSubtract = Window * ((1 - CONG_CTRL_WINDOW_SHRINK_RTT) / CONG_CTRL_CHANNEL_INFLATE);
                        Window = Max(CONG_CTRL_MIN_WINDOW, Window - windowSubtract);
                        VirtualPackets = Max(0.f, VirtualPackets - windowSubtract);
                        //printf("reducing window by RTT , new window %g\n", Window);
                    }
                    // reduce no more than twice per RTT:
                    // for large windows the counter restarts below zero, delaying the next reduction
                    HighRTTCounter = Min(0, CONG_CTRL_RTT_SEQ_COUNT - (int)(Window * 0.5));
                }
            } else {
                // a non-high sample breaks the run, but keeps a negative (delaying) counter as is
                HighRTTCounter = Min(0, HighRTTCounter);
            }

            float rttMixRate = CONG_CTRL_RTT_MIX_RATE;
            AvrgRTT = AvrgRTT * rttMixRate + RTT * (1 - rttMixRate);
        }
        // Accounts for one successfully delivered packet and grows Window when sending at full speed.
        void Success() {
            PacketsInFly -= 1;
            Y_ASSERT(PacketsInFly >= 0);
            // FullSpeed should be correct at this point
            // we assume that after UpdateAlive() we send all packets first then we listen for acks and call Success()
            // FullSpeed is set in CanSend() during send if we are using full window
            // do not increase window while send rate is limited by virtual packets (ie start of transfer)
            if (FullSpeed && VirtualPackets == 0) {
                // there are 2 requirements for window growth
                // 1) growth should be proportional to window size to ensure constant FailRate
                // 2) growth should be constant to ensure fairness among different flows
                // so lets make it square root :)
                Window += sqrt(Window / CONG_CTRL_CHANNEL_INFLATE) * CONG_CTRL_WINDOW_GROW;
                if (UseTOSforAcks) {
                    Window = Min(Window, MaxWindow);
                }
            }
            FailRate *= 0.99f;
        }
        // Accounts for a packet that could not be sent at all; Window and FailRate are left untouched.
        void FailureOnSend() {
            //printf("Failure on send\n");
            PacketsInFly -= 1;
            Y_ASSERT(PacketsInFly >= 0);
            // not a congestion event, do not modify Window
            // do not set FullSpeed since we are not using full Window
        }
        // Accounts for a lost packet (congestion event): shrinks Window, at most once per
        // CONG_CTRL_MIN_FAIL_INTERVAL, and raises FailRate.
        void Failure() {
            //printf("Congestion failure\n");
            PacketsInFly -= 1;
            Y_ASSERT(PacketsInFly >= 0);
            // account limited number of fails per segment
            if (TimeSinceLastFail > CONG_CTRL_MIN_FAIL_INTERVAL) {
                TimeSinceLastFail = 0;
                if (Window <= CONG_CTRL_MIN_WINDOW) {
                    // ping dead hosts less frequently
                    if (PingTracker.GetRTT() / CONG_CTRL_MIN_WINDOW < CONG_CTRL_MINIMAL_SEND_INTERVAL)
                        PingTracker.IncreaseRTT();
                    Window = CONG_CTRL_MIN_WINDOW;
                    VirtualPackets = 0;
                } else {
                    float windowSubtract = Window * ((1 - CONG_CTRL_WINDOW_SHRINK) / CONG_CTRL_CHANNEL_INFLATE);
                    Window = Max(CONG_CTRL_MIN_WINDOW, Window - windowSubtract);
                    VirtualPackets = Max(0.f, VirtualPackets - windowSubtract);
                }
            }
            FailRate = FailRate * 0.99f + 0.01f;
        }
        // True while at least one launched packet is unresolved.
        bool HasPacketsInFly() const {
            return PacketsInFly > 0;
        }
        // Packet timeout as computed by the ping tracker.
        float GetTimeout() const {
            return PingTracker.GetTimeout();
        }
        float GetWindow() const {
            return Window;
        }
        // RTT estimate of the ping tracker (not the internal AvrgRTT).
        float GetRTT() const {
            return PingTracker.GetRTT();
        }
        float GetFailRate() const {
            return FailRate;
        }
        // Note: the stored value is a double, it is narrowed to float here.
        float GetTimeSinceLastRecv() const {
            return TimeSinceLastRecv;
        }
        int GetTransferCount() const {
            return ActiveTransferCount;
        }
        // MaxWindow when UseTOSforAcks is set, -1 otherwise (the cap is not applied then).
        float GetMaxWindow() const {
            return UseTOSforAcks ? MaxWindow : -1;
        }
        // Resets the receive timer and drops the port tester, if any.
        void MarkAlive() {
            TimeSinceLastRecv = 0;

            with_lock (PortTesterLock) {
                PortTester = nullptr;
            }

        }
        // Makes the next UpdateAlive() count time even if nothing is in flight.
        void HasTriedToSend() {
            DoCountTime = true;
        }
        // False once Kill() was called (directly or from UpdateAlive()) and MarkAlive() has not reset the timer.
        bool IsAlive() const {
            return TimeSinceLastRecv < 1e6f;
        }
        // Marks the peer as dead by setting the receive timer to the 1e6 sentinel checked in IsAlive().
        void Kill() {
            TimeSinceLastRecv = 1e6f;
        }
        // Advances time-dependent state by deltaT and checks liveness.
        //   toAddress      - address the port-unreachable tester connects to
        //   deltaT         - time elapsed since the previous call (presumably seconds - TODO confirm)
        //   timeout        - the peer is killed once TimeSinceLastRecv exceeds this value
        //   resMaxWaitTime - in/out; lowered to 0.001 while virtual packets are being drained
        // Returns false if the peer has been killed during this call, true otherwise.
        bool UpdateAlive(const TUdpAddress& toAddress, float deltaT, float timeout, float* resMaxWaitTime) {
            if (!FullSpeed) {
                // create virtual packets during idle to avoid burst on transmit start
                if (AvrgRTT > CONG_CTRL_MIN_RTT_FOR_BURST_REDUCTION) {
                    VirtualPackets = Max(VirtualPackets, Window - PacketsInFly - CONG_CTRL_ALLOWED_BURST_SIZE);
                }
            } else {
                if (VirtualPackets > 0) {
                    if (Window <= CONG_CTRL_ALLOWED_BURST_SIZE) {
                        VirtualPackets = 0;
                    }
                    // drain virtual packets at the rate of one window per RTT
                    float xRTT = AvrgRTT == 0 ? CONG_CTRL_INITIAL_RTT : AvrgRTT;
                    float virtualPktsPerSecond = Window / xRTT;
                    VirtualPackets = Max(0.f, VirtualPackets - deltaT * virtualPktsPerSecond);
                    *resMaxWaitTime = Min(*resMaxWaitTime, 0.001f); // need to update virtual packets counter regularly
                }
            }
            // re-roll the random window fraction once the accumulated time exceeds the ping tracker RTT
            float currentRTT = GetRTT();
            FractionRecalc += deltaT;
            if (FractionRecalc > currentRTT) {
                int cycleCount = (int)(FractionRecalc / currentRTT);
                FractionRecalc -= currentRTT * cycleCount;
                WindowFraction = (NetAckRnd() & 1023) * (1 / 1023.0f) / cycleCount;
            }

            // refill the rate budget; the limit applies only when MaxPacketRate is set and an RTT is known
            if (MaxPacketRate > 0 && AvrgRTT > 0) {
                float maxTimeWindow = CONG_CTRL_TIME_WINDOW_LIMIT_PERIOD * MaxPacketRate;
                TimeWindow = Min(maxTimeWindow, TimeWindow + MaxPacketRate * deltaT);
            } else
                TimeWindow = CONG_CTRL_LARGE_TIME_WINDOW;

            // guarantee minimal send rate
            if (currentRTT > CONG_CTRL_MINIMAL_SEND_INTERVAL * Window) {
                Window = Max(CONG_CTRL_MIN_WINDOW, currentRTT / CONG_CTRL_MINIMAL_SEND_INTERVAL);
                VirtualPackets = 0;
            }

            TimeSinceLastFail += deltaT;

            //static int n;
            //if ((++n & 127) == 0)
            //    printf("window = %g, fly = %g, VirtualPkts = %g, deltaT = %g, FailRate = %g FullSpeed = %d AvrgRTT = %g\n",
            //        Window, PacketsInFly, VirtualPackets, deltaT * 1000, FailRate, (int)FullSpeed, AvrgRTT * 1000);

            if (PacketsInFly > 0 || FullSpeed || DoCountTime) {
                // count time only while there are packets in flight (or a send was attempted)
                TimeSinceLastRecv += deltaT;
                if (TimeSinceLastRecv > START_CHECK_PORT_DELAY) {
                    if (TimeSinceLastRecv < FINISH_CHECK_PORT_DELAY) {
                        // work on a local copy so the lock is not held while the tester is used
                        TIntrusivePtr<TPortUnreachableTester> portTester;
                        with_lock (PortTesterLock) {
                            portTester = PortTester;
                        }

                        if (!portTester && AtomicGet(ActivePortTestersCount) < N_PORT_TEST_COUNT_LIMIT) {
                            portTester = new TPortUnreachableTester();
                            with_lock (PortTesterLock) {
                                PortTester = portTester;
                            }

                            if (portTester->IsValid()) {
                                portTester->Connect(toAddress);
                            } else {
                                // the stored tester is dropped; the local copy is still tested below
                                with_lock (PortTesterLock) {
                                    PortTester = nullptr;
                                }
                            }
                        }
                        if (portTester && !portTester->Test(deltaT)) {
                            Kill();
                            return false;
                        }
                    } else {
                        // the port check period is over, release the tester
                        with_lock (PortTesterLock) {
                            PortTester = nullptr;
                        }
                    }
                }
                if (TimeSinceLastRecv > timeout) {
                    Kill();
                    return false;
                }
            }

            // both flags are per-step: they are raised again by CanSend() / HasTriedToSend()
            FullSpeed = false;
            DoCountTime = false;

            if (MTUDiscovery.Get())
                MTUDiscovery->Step(deltaT);

            return true;
        }
        bool IsKnownMTU() const {
            return MTU != 0;
        }
        int GetMTU() const {
            return MTU;
        }
        // Returns the MTU discovery state, creating it on first use.
        TLameMTUDiscovery* GetMTUDiscovery() {
            if (MTUDiscovery.Get() == nullptr)
                MTUDiscovery = new TLameMTUDiscovery;
            return MTUDiscovery.Get();
        }
        // Stores the MTU, drops the discovery state and recomputes MaxWindow.
        void SetMTU(int sz) {
            MTU = sz;
            MTUDiscovery = nullptr;
            CalcMaxWindow();
        }
        // Attaches the queue statistics object (expected to happen once, see the assert) and
        // seeds it with the current transfer count. A null pointer is accepted.
        void AttachQueueStats(TIntrusivePtr<TPeerQueueStats> s) {
            if (s.Get()) {
                s->Count = ActiveTransferCount;
            }
            Y_ASSERT(QueueStats.Get() == nullptr);
            QueueStats = s;
        }
        // TCongestionControlPtr maintains ActiveTransferCount and QueueStats->Count
        friend class TCongestionControlPtr;
    };

    class TCongestionControlPtr {
        TIntrusivePtr<TCongestionControl> Ptr;

        void Inc() {
            if (Ptr.Get()) {
                ++Ptr->ActiveTransferCount;
                if (Ptr->QueueStats.Get()) {
                    Ptr->QueueStats->Count = Ptr->ActiveTransferCount;
                }
            }
        }
        void Dec() {
            if (Ptr.Get()) {
                --Ptr->ActiveTransferCount;
                if (Ptr->QueueStats.Get()) {
                    Ptr->QueueStats->Count = Ptr->ActiveTransferCount;
                }
            }
        }

    public:
        TCongestionControlPtr() {
        }
        ~TCongestionControlPtr() {
            Dec();
        }
        TCongestionControlPtr(TCongestionControl* p)
            : Ptr(p)
        {
            Inc();
        }
        TCongestionControlPtr& operator=(const TCongestionControlPtr& a) {
            Dec();
            Ptr = a.Ptr;
            Inc();
            return *this;
        }
        TCongestionControlPtr& operator=(TCongestionControl* a) {
            Dec();
            Ptr = a;
            Inc();
            return *this;
        }
        operator TCongestionControl*() const {
            return Ptr.Get();
        }
        TCongestionControl* operator->() const {
            return Ptr.Get();
        }
        TIntrusivePtr<TCongestionControl> Get() const {
            return Ptr;
        }
    };

    class TAckTracker {
        struct TFlyingPacket {
            float T;
            int PktId;
            TFlyingPacket()
                : T(0)
                , PktId(-1)
            {
            }
            TFlyingPacket(float t, int pktId)
                : T(t)
                , PktId(pktId)
            {
            }
        };
        int PacketCount, CurrentPacket;
        typedef THashMap<int, float> TPacketHash;
        TPacketHash PacketsInFly, DroppedPackets;
        TVector<int> ResendQueue;
        TCongestionControlPtr Congestion;
        TVector<bool> AckReceived;
        float TimeToNextPacketTimeout;

        int SelectPacket();

    public:
        TAckTracker()
            : PacketCount(0)
            , CurrentPacket(0)
            , TimeToNextPacketTimeout(1000)
        {
        }
        ~TAckTracker();
        void AttachCongestionControl(TCongestionControl* p) {
            Congestion = p;
        }
        TIntrusivePtr<TCongestionControl> GetCongestionControl() const {
            return Congestion.Get();
        }
        void SetPacketCount(int n) {
            Y_ASSERT(PacketCount == 0);
            PacketCount = n;
            AckReceived.resize(n, false);
        }
        void Resend();
        bool IsInitialized() {
            return PacketCount != 0;
        }
        int GetPacketToSend(float deltaT);
        void AddToResend(int pkt); // called when failed to send packet
        void Ack(int pkt, float deltaT, bool updateRTT);
        void AckAll();
        void MarkAlive() {
            Congestion->MarkAlive();
        }
        bool IsAlive() const {
            return Congestion->IsAlive();
        }
        void Step(float deltaT);
        bool CanSend() const {
            return Congestion->CanSend();
        }
        float GetTimeToNextPacketTimeout() const {
            return TimeToNextPacketTimeout;
        }
    };
}