#pragma once

#include <util/generic/map.h>
#include <util/system/yassert.h>

#include <cinttypes>
#include <cstdint>
#include <type_traits>
#include <utility>

template <class T>
class TDisjointIntervalTree {
private:
    static_assert(std::is_integral<T>::value, "expect std::is_integral<T>::value");

    using TTree = TMap<T, T>; // [key, value)
    using TIterator = typename TTree::iterator;
    using TConstIterator = typename TTree::const_iterator;
    using TReverseIterator = typename TTree::reverse_iterator;
    using TThis = TDisjointIntervalTree<T>;

    TTree Tree;
    size_t NumElements;

public:
    TDisjointIntervalTree()
        : NumElements()
    {
    }

    void Insert(const T t) {
        InsertInterval(t, t + 1);
    }

    // we assume that none of elements from [begin, end) belong to tree.
    void InsertInterval(const T begin, const T end) {
        InsertIntervalImpl(begin, end);
        NumElements += (size_t)(end - begin);
    }

    bool Has(const T t) const {
        return const_cast<TThis*>(this)->FindContaining(t) != Tree.end();
    }

    bool Intersects(const T begin, const T end) {
        if (Empty()) {
            return false;
        }

        TIterator l = Tree.lower_bound(begin);
        if (l != Tree.end()) {
            if (l->first < end) {
                return true;
            } else if (l != Tree.begin()) {
                --l;
                return l->second > begin;
            } else {
                return false;
            }
        } else {
            auto last = Tree.rbegin();
            return begin < last->second;
        }
    }

    TConstIterator FindContaining(const T t) const {
        return const_cast<TThis*>(this)->FindContaining(t);
    }

    // Erase element. Returns true when element has been deleted, otherwise false.
    bool Erase(const T t) {
        TIterator n = FindContaining(t);
        if (n == Tree.end()) {
            return false;
        }

        --NumElements;

        T& begin = const_cast<T&>(n->first);
        T& end = const_cast<T&>(n->second);

        // Optimization hack.
        if (t == begin) {
            if (++begin == end) { // OK to change key since intervals do not intersect.
                Tree.erase(n);
                return true;
            }

        } else if (t == end - 1) {
            --end;

        } else {
            const T e = end;
            end = t;
            InsertIntervalImpl(t + 1, e);
        }

        Y_ASSERT(begin < end);
        return true;
    }

    // Erase interval. Returns number of elements removed from set.
    size_t EraseInterval(const T begin, const T end) {
        Y_ASSERT(begin < end);

        if (Empty()) {
            return 0;
        }

        size_t elementsRemoved = 0;

        TIterator completelyRemoveBegin = Tree.lower_bound(begin);
        if ((completelyRemoveBegin != Tree.end() && completelyRemoveBegin->first > begin && completelyRemoveBegin != Tree.begin())
            || completelyRemoveBegin == Tree.end()) {
            // Look at the interval. It could contain [begin, end).
            TIterator containingBegin = completelyRemoveBegin;
            --containingBegin;
            if (containingBegin->first < begin && begin < containingBegin->second) { // Contains begin.
                if (containingBegin->second > end) { // Contains end.
                    const T prevEnd = containingBegin->second;
                    Y_ASSERT(containingBegin->second - begin <= NumElements);

                    Y_ASSERT(containingBegin->second - containingBegin->first > end - begin);
                    containingBegin->second = begin;
                    InsertIntervalImpl(end, prevEnd);

                    elementsRemoved = end - begin;
                    NumElements -= elementsRemoved;
                    return elementsRemoved;
                } else {
                    elementsRemoved += containingBegin->second - begin;
                    containingBegin->second = begin;
                }
            }
        }

        TIterator completelyRemoveEnd = completelyRemoveBegin != Tree.end() ? Tree.lower_bound(end) : Tree.end();
        if (completelyRemoveEnd != Tree.end() && completelyRemoveEnd != Tree.begin() && completelyRemoveEnd->first != end) {
            TIterator containingEnd = completelyRemoveEnd;
            --containingEnd;
            if (containingEnd->second > end) {
                T& leftBorder = const_cast<T&>(containingEnd->first);

                Y_ASSERT(leftBorder < end);

                --completelyRemoveEnd; // Don't remove the whole interval.

                // Optimization hack.
                elementsRemoved += end - leftBorder;
                leftBorder = end; // OK to change key since intervals do not intersect.
            }
        }

        for (TIterator i = completelyRemoveBegin; i != completelyRemoveEnd; ++i) {
            elementsRemoved += i->second - i->first;
        }

        Tree.erase(completelyRemoveBegin, completelyRemoveEnd);

        Y_ASSERT(elementsRemoved <= NumElements);
        NumElements -= elementsRemoved;

        return elementsRemoved;
    }

    void Swap(TDisjointIntervalTree& rhv) {
        Tree.swap(rhv.Tree);
        std::swap(NumElements, rhv.NumElements);
    }

    void Clear() {
        Tree.clear();
        NumElements = 0;
    }

    bool Empty() const {
        return Tree.empty();
    }

    size_t GetNumElements() const {
        return NumElements;
    }

    size_t GetNumIntervals() const {
        return Tree.size();
    }

    T Min() const {
        Y_ASSERT(!Empty());
        return Tree.begin()->first;
    }

    T Max() const {
        Y_ASSERT(!Empty());
        return Tree.rbegin()->second;
    }

    TConstIterator begin() const {
        return Tree.begin();
    }

    TConstIterator end() const {
        return Tree.end();
    }

private:
    void InsertIntervalImpl(const T begin, const T end) {
        Y_ASSERT(begin < end);
        Y_ASSERT(!Intersects(begin, end));

        TIterator l = Tree.lower_bound(begin);
        TIterator p = Tree.end();
        if (l != Tree.begin()) {
            p = l;
            --p;
        }

#ifndef NDEBUG
        TIterator u = Tree.upper_bound(begin);
        Y_VERIFY_DEBUG(u == Tree.end() || u->first >= end, "Trying to add [%" PRIu64 ", %" PRIu64 ") which intersects with existing [%" PRIu64 ", %" PRIu64 ")", begin, end, u->first, u->second);
        Y_VERIFY_DEBUG(l == Tree.end() || l == u, "Trying to add [%" PRIu64 ", %" PRIu64 ") which intersects with existing [%" PRIu64 ", %" PRIu64 ")", begin, end, l->first, l->second);
        Y_VERIFY_DEBUG(p == Tree.end() || p->second <= begin, "Trying to add [%" PRIu64 ", %" PRIu64 ") which intersects with existing [%" PRIu64 ", %" PRIu64 ")", begin, end, p->first, p->second);
#endif

        // try to extend interval
        if (p != Tree.end() && p->second == begin) {
            p->second = end;
            //Try to merge 2 intervals - p and next one if possible
            auto next = p;
            // Next is not Tree.end() here.
            ++next;
            if (next != Tree.end() && next->first == end) {
                p->second = next->second;
                Tree.erase(next);
            }
        // Maybe new interval extends right interval
        } else if (l != Tree.end() && end == l->first) {
            T& leftBorder = const_cast<T&>(l->first);
            // Optimization hack.
            leftBorder = begin; // OK to change key since intervals do not intersect.
        } else {
            Tree.insert(std::make_pair(begin, end));
        }
    }

    TIterator FindContaining(const T t) {
        TIterator l = Tree.lower_bound(t);
        if (l != Tree.end()) {
            if (l->first == t) {
                return l;
            }
            Y_ASSERT(l->first > t);

            if (l == Tree.begin()) {
                return Tree.end();
            }

            --l;
            Y_ASSERT(l->first != t);

            if (l->first < t && t < l->second) {
                return l;
            }

        } else if (!Tree.empty()) { // l is larger than Begin of any interval, but maybe it belongs to last interval?
            TReverseIterator last = Tree.rbegin();
            Y_ASSERT(last->first != t);

            if (last->first < t && t < last->second) {
                return (++last).base();
            }
        }
        return Tree.end();
    }
};