#pragma once

#include "ref_counted.h"

#include <util/generic/hash.h>
#include <util/generic/utility.h>

#include <functional>
#include <type_traits>
#include <utility>

namespace NYT {

////////////////////////////////////////////////////////////////////////////////

//! A smart pointer that shares ownership of an object through |Ref| / |Unref| calls.
/*!
 *  Taking ownership invokes |Ref| on the raw pointer and dropping it invokes |Unref|;
 *  a null pointer is never passed to either of them.
 */
template <class T>
class TIntrusivePtr
{
public:
    using TUnderlying = T;

    //! Constructs a null pointer.
    constexpr TIntrusivePtr() noexcept
    { }

    //! Constructs a null pointer from a |nullptr| literal.
    constexpr TIntrusivePtr(std::nullptr_t) noexcept
    { }

    //! Constructor from an unqualified reference.
    /*!
     *  A reference is added only when |obj| is non-null and |addReference| is true;
     *  pass |false| to adopt a reference that has already been taken.
     *
     *  Note that this constructor could be racy due to unsynchronized operations
     *  on the object and on the counter.
     *
     *  Note that it is notoriously hard to make this constructor explicit
     *  given the current amount of code written.
     */
    TIntrusivePtr(T* obj, bool addReference = true) noexcept
        : T_(obj)
    {
        if (addReference && T_) {
            Ref(T_);
        }
    }

    //! Copy constructor: shares the pointee and adds a reference when it is non-null.
    TIntrusivePtr(const TIntrusivePtr& other) noexcept
        : T_(other.T_)
    {
        if (T_) {
            Ref(T_);
        }
    }

    //! Copy constructor with an upcast.
    /*!
     *  Participates in overload resolution only if |U*| converts to |T*|.
     */
    template <class U, class = std::enable_if_t<std::is_convertible_v<U*, T*>>>
    TIntrusivePtr(const TIntrusivePtr<U>& other) noexcept
        : T_(other.T_)
    {
        static_assert(
            std::is_base_of_v<TRefCountedBase, T>,
            "Cast allowed only for types derived from TRefCountedBase");
        if (T_) {
            Ref(T_);
        }
    }

    //! Move constructor: steals the pointer and leaves |other| null; no counter is touched.
    TIntrusivePtr(TIntrusivePtr&& other) noexcept
        : T_(std::exchange(other.T_, nullptr))
    { }

    //! Move constructor with an upcast.
    /*!
     *  Participates in overload resolution only if |U*| converts to |T*|.
     *  |other| is left null.
     */
    template <class U, class = std::enable_if_t<std::is_convertible_v<U*, T*>>>
    TIntrusivePtr(TIntrusivePtr<U>&& other) noexcept
        : T_(std::exchange(other.T_, nullptr))
    {
        static_assert(
            std::is_base_of_v<TRefCountedBase, T>,
            "Cast allowed only for types derived from TRefCountedBase");
    }

    //! Destructor: drops the held reference, if any.
    ~TIntrusivePtr()
    {
        if (T_) {
            Unref(T_);
        }
    }

    //! Copy assignment operator (copy-and-swap).
    TIntrusivePtr& operator=(const TIntrusivePtr& other) noexcept
    {
        // The copy takes the new reference; after the swap it owns the old pointee
        // and releases it when it goes out of scope.
        TIntrusivePtr copy(other);
        Swap(copy);
        return *this;
    }

    //! Copy assignment operator with an upcast.
    template <class U>
    TIntrusivePtr& operator=(const TIntrusivePtr<U>& other) noexcept
    {
        static_assert(
            std::is_convertible_v<U*, T*>,
            "U* must be convertible to T*");
        static_assert(
            std::is_base_of_v<TRefCountedBase, T>,
            "Cast allowed only for types derived from TRefCountedBase");
        TIntrusivePtr copy(other);
        Swap(copy);
        return *this;
    }

    //! Move assignment operator (move-and-swap); |other| is left null.
    TIntrusivePtr& operator=(TIntrusivePtr&& other) noexcept
    {
        TIntrusivePtr stolen(std::move(other));
        Swap(stolen);
        return *this;
    }

    //! Move assignment operator with an upcast; |other| is left null.
    template <class U>
    TIntrusivePtr& operator=(TIntrusivePtr<U>&& other) noexcept
    {
        static_assert(
            std::is_convertible_v<U*, T*>,
            "U* must be convertible to T*");
        static_assert(
            std::is_base_of_v<TRefCountedBase, T>,
            "Cast allowed only for types derived from TRefCountedBase");
        TIntrusivePtr stolen(std::move(other));
        Swap(stolen);
        return *this;
    }

    //! Drop the pointer.
    void Reset() // noexcept
    {
        TIntrusivePtr empty;
        Swap(empty);
    }

    //! Replace the pointer with a specified one (a reference to |p| is added if it is non-null).
    void Reset(T* p) // noexcept
    {
        TIntrusivePtr replacement(p);
        Swap(replacement);
    }

    //! Returns the pointer.
    T* Get() const noexcept
    {
        return T_;
    }

    //! Returns the pointer and releases the ownership.
    /*!
     *  No |Unref| call is made; this instance becomes null.
     */
    T* Release() noexcept
    {
        return std::exchange(T_, nullptr);
    }

    //! Dereferences the pointer; it must be non-null (checked via |YT_ASSERT|).
    T& operator*() const noexcept
    {
        YT_ASSERT(T_);
        return *T_;
    }

    //! Member access; the pointer must be non-null (checked via |YT_ASSERT|).
    T* operator->() const noexcept
    {
        YT_ASSERT(T_);
        return T_;
    }

    //! Returns |true| iff the pointer is non-null.
    explicit operator bool() const noexcept
    {
        return Get() != nullptr;
    }

    //! Swap the pointer with the other one.
    void Swap(TIntrusivePtr& r) noexcept
    {
        DoSwap(T_, r.T_);
    }

private:
    // Gives differently-typed instantiations access to |T_| for the upcasting moves.
    template <class U>
    friend class TIntrusivePtr;

    T* T_ = nullptr;
};

////////////////////////////////////////////////////////////////////////////////

//! Wraps a raw pointer into a strong (owning) pointer.
//! Unlike spelling out |TIntrusivePtr<T>| explicitly, |T| is deduced from the argument.
template <class T>
TIntrusivePtr<T> MakeStrong(T* p)
{
    return {p};
}

//! Tries to obtain an intrusive pointer for an object that may have
//! already lost all of its references and, thus, is about to be deleted.
/*!
 * You may call this method at any time provided that you have a valid
 * raw pointer to an object. The call either returns an intrusive pointer
 * for the object (thus ensuring that the object won't be destroyed until
 * you're holding this pointer) or NULL indicating that the last reference
 * had already been lost and the object is on its way to heavens.
 * All these steps happen atomically.
 *
 * Under all circumstances it is caller's responsibility to make sure that
 * the object is not destroyed during the call to #DangerousGetPtr.
 * Typically this is achieved by keeping a (lock-protected) collection of
 * raw pointers, taking a lock in object's destructor, and unregistering
 * its raw pointer from the collection there.
 */

template <class T>
Y_FORCE_INLINE TIntrusivePtr<T> DangerousGetPtr(T* object)
{
    // A failed |TryRef| yields a null pointer.
    if (!object->TryRef()) {
        return TIntrusivePtr<T>();
    }
    // |TryRef| succeeded; presumably it has already taken the reference,
    // hence the wrapper is told not to add another one.
    return TIntrusivePtr<T>(object, false);
}

////////////////////////////////////////////////////////////////////////////////

//! Applies |static_cast<T*>| to the pointee of |ptr| and returns a new owning
//! pointer to it; |ptr| itself is left untouched.
template <class T, class U>
TIntrusivePtr<T> StaticPointerCast(const TIntrusivePtr<U>& ptr)
{
    T* const casted = static_cast<T*>(ptr.Get());
    return TIntrusivePtr<T>(casted);
}

//! Applies |static_cast<T*>| to the pointee of |ptr| and transfers ownership:
//! |ptr| is released and the result adopts its reference without adding a new one.
template <class T, class U>
TIntrusivePtr<T> StaticPointerCast(TIntrusivePtr<U>&& ptr)
{
    T* const casted = static_cast<T*>(ptr.Release());
    return TIntrusivePtr<T>(casted, /*addReference*/ false);
}

//! Applies |const_cast<T*>| to the pointee of |ptr| and returns a new owning
//! pointer to it; |ptr| itself is left untouched.
template <class T, class U>
TIntrusivePtr<T> ConstPointerCast(const TIntrusivePtr<U>& ptr)
{
    T* const casted = const_cast<T*>(ptr.Get());
    return TIntrusivePtr<T>(casted);
}

//! Applies |const_cast<T*>| to the pointee of |ptr| and transfers ownership:
//! |ptr| is released and the result adopts its reference without adding a new one.
template <class T, class U>
TIntrusivePtr<T> ConstPointerCast(TIntrusivePtr<U>&& ptr)
{
    T* const casted = const_cast<T*>(ptr.Release());
    return TIntrusivePtr<T>(casted, /*addReference*/ false);
}

//! Applies |dynamic_cast<T*>| to the pointee of |ptr| and wraps the result
//! into a new owning pointer; |ptr| itself is left untouched.
template <class T, class U>
TIntrusivePtr<T> DynamicPointerCast(const TIntrusivePtr<U>& ptr)
{
    T* const casted = dynamic_cast<T*>(ptr.Get());
    return TIntrusivePtr<T>(casted);
}

////////////////////////////////////////////////////////////////////////////////

//! Orders two pointers by the addresses of their pointees.
/*!
 *  The comparison goes through |std::less<T*>| rather than the built-in |<|:
 *  the latter yields an unspecified result for pointers to unrelated objects,
 *  while |std::less| guarantees a strict total order, which is what ordered
 *  containers and sorting algorithms require.
 */
template <class T>
bool operator<(const TIntrusivePtr<T>& lhs, const TIntrusivePtr<T>& rhs)
{
    return std::less<T*>()(lhs.Get(), rhs.Get());
}

//! Reverse address ordering: |lhs > rhs| iff the pointee of |rhs| is ordered before that of |lhs|.
/*!
 *  The comparison goes through |std::less<T*>| (with swapped arguments) rather
 *  than the built-in |>|: the latter yields an unspecified result for pointers
 *  to unrelated objects, while |std::less| guarantees a strict total order.
 */
template <class T>
bool operator>(const TIntrusivePtr<T>& lhs, const TIntrusivePtr<T>& rhs)
{
    return std::less<T*>()(rhs.Get(), lhs.Get());
}

//! Two intrusive pointers are equal iff they hold the same raw pointer.
//! Only compiles when |U*| is convertible to |T*|.
template <class T, class U>
bool operator==(const TIntrusivePtr<T>& lhs, const TIntrusivePtr<U>& rhs)
{
    static_assert(
        std::is_convertible_v<U*, T*>,
        "U* must be convertible to T*");
    T* const left = lhs.Get();
    U* const right = rhs.Get();
    return left == right;
}

//! Two intrusive pointers differ iff they hold different raw pointers.
//! Only compiles when |U*| is convertible to |T*|.
template <class T, class U>
bool operator!=(const TIntrusivePtr<T>& lhs, const TIntrusivePtr<U>& rhs)
{
    static_assert(
        std::is_convertible_v<U*, T*>,
        "U* must be convertible to T*");
    T* const left = lhs.Get();
    U* const right = rhs.Get();
    return !(left == right);
}

//! Compares the pointer held by |lhs| with the raw pointer |rhs|.
//! NB: unlike the matching |operator!=|, no convertibility check between
//! |U*| and |T*| is imposed here.
template <class T, class U>
bool operator==(const TIntrusivePtr<T>& lhs, U* rhs)
{
    T* const held = lhs.Get();
    return held == rhs;
}

//! Returns |true| iff the pointer held by |lhs| differs from the raw pointer |rhs|.
//! Only compiles when |U*| is convertible to |T*|.
template <class T, class U>
bool operator!=(const TIntrusivePtr<T>& lhs, U* rhs)
{
    static_assert(
        std::is_convertible_v<U*, T*>,
        "U* must be convertible to T*");
    T* const held = lhs.Get();
    return !(held == rhs);
}

//! Compares the raw pointer |lhs| with the pointer held by |rhs|.
//! NB: unlike the matching |operator!=|, no convertibility check between
//! |U*| and |T*| is imposed here.
template <class T, class U>
bool operator==(T* lhs, const TIntrusivePtr<U>& rhs)
{
    U* const held = rhs.Get();
    return lhs == held;
}

//! Returns |true| iff the raw pointer |lhs| differs from the pointer held by |rhs|.
//! Only compiles when |U*| is convertible to |T*|.
template <class T, class U>
bool operator!=(T* lhs, const TIntrusivePtr<U>& rhs)
{
    static_assert(
        std::is_convertible_v<U*, T*>,
        "U* must be convertible to T*");
    U* const held = rhs.Get();
    return !(lhs == held);
}

//! |nullptr == ptr|: true iff |rhs| holds no object.
template <class T>
bool operator==(std::nullptr_t, const TIntrusivePtr<T>& rhs)
{
    return rhs.Get() == nullptr;
}

//! |nullptr != ptr|: true iff |rhs| holds an object.
template <class T>
bool operator!=(std::nullptr_t, const TIntrusivePtr<T>& rhs)
{
    return rhs.Get() != nullptr;
}

//! |ptr == nullptr|: true iff |lhs| holds no object.
template <class T>
bool operator==(const TIntrusivePtr<T>& lhs, std::nullptr_t)
{
    return lhs.Get() == nullptr;
}

//! |ptr != nullptr|: true iff |lhs| holds an object.
template <class T>
bool operator!=(const TIntrusivePtr<T>& lhs, std::nullptr_t)
{
    return lhs.Get() != nullptr;
}

////////////////////////////////////////////////////////////////////////////////

} //namespace NYT

//! Hashes a TIntrusivePtr by delegating to the hasher of the underlying raw pointer,
//! so the hash of a smart pointer matches |THash<T*>| applied to |ptr.Get()|.
template <class T>
struct THash<NYT::TIntrusivePtr<T>>
{
    Y_FORCE_INLINE size_t operator () (const NYT::TIntrusivePtr<T>& ptr) const
    {
        T* const raw = ptr.Get();
        return THash<T*>()(raw);
    }
};