#include "is_equal.h"
#include "traits.h"

#include <google/protobuf/descriptor.h>

#include <util/generic/yexception.h>
#include <util/string/cast.h>
#include <util/string/vector.h>

namespace NProtoBuf {
    template <bool useDefault>
    static bool IsEqualImpl(const Message& m1, const Message& m2, TVector<TString>* differentPath);

    namespace {
        // Leaf comparator used for every cpp-type except MESSAGE (which has its own
        // specialization below). Values are compared with operator==, so for
        // floating-point types a NaN never equals anything, including itself.
        // The path argument is ignored: a scalar has no inner path to report.
        template <FieldDescriptor::CppType CppType, bool useDefault>
        struct TCompareValue {
            using T = typename TCppTypeTraits<CppType>::T;

            static inline bool IsEqual(T lhs, T rhs, TVector<TString>* /*differentPath*/) {
                return lhs == rhs;
            }
        };

        template <bool useDefault>
        struct TCompareValue<FieldDescriptor::CPPTYPE_MESSAGE, useDefault> {
            static inline bool IsEqual(const Message* value1, const Message* value2, TVector<TString>* differentPath) {
                return NProtoBuf::IsEqualImpl<useDefault>(*value1, *value2, differentPath);
            }
        };

        // Compares one field (singular or repeated) of two messages through reflection,
        // delegating the per-value comparison to TCompareValue<CppType, useDefault>.
        template <FieldDescriptor::CppType CppType, bool useDefault>
        class TCompareField {
            using TTraits = TCppTypeTraits<CppType>;
            using TCompare = TCompareValue<CppType, useDefault>;

        public:
            // Returns true when `field` holds equal content in m1 and m2.
            // For a repeated field, the index of the first differing element is appended
            // to `differentPath` (when non-null); nested message comparisons may append
            // further components before it.
            static inline bool IsEqual(const Message& m1, const Message& m2, const FieldDescriptor& field, TVector<TString>* differentPath) {
                return field.is_repeated()
                           ? IsEqualRepeated(m1, m2, &field, differentPath)
                           : IsEqualSingle(m1, m2, &field, differentPath);
            }

        private:
            // Singular field: compare presence first, then the stored values.
            static bool IsEqualSingle(const Message& m1, const Message& m2, const FieldDescriptor* field, TVector<TString>* differentPath) {
                const bool present1 = m1.GetReflection()->HasField(m1, field);
                const bool present2 = m2.GetReflection()->HasField(m2, field);

                if (present1 == present2) {
                    if (!present1) {
                        return true; // unset on both sides
                    }
                } else if (!useDefault || field->is_required()) {
                    // Presence differs and substituting the unset side's value is not allowed.
                    return false;
                }

                // Both set, or (useDefault mode, optional field) one side unset: the values
                // returned by TTraits::Get are compared; presumably the default for the unset
                // side -- TODO confirm in traits.h.
                return TCompare::IsEqual(TTraits::Get(m1, field),
                                         TTraits::Get(m2, field),
                                         differentPath);
            }

            // Repeated field: sizes must match, then elements are compared pairwise in order.
            static bool IsEqualRepeated(const Message& m1, const Message& m2, const FieldDescriptor* field, TVector<TString>* differentPath) {
                const int count = m1.GetReflection()->FieldSize(m1, field);
                if (m2.GetReflection()->FieldSize(m2, field) != count) {
                    return false;
                }

                for (int index = 0; index < count; ++index) {
                    const bool same = TCompare::IsEqual(TTraits::GetRepeated(m1, field, index),
                                                        TTraits::GetRepeated(m2, field, index),
                                                        differentPath);
                    if (same) {
                        continue;
                    }
                    if (differentPath != nullptr) {
                        differentPath->push_back(ToString(index));
                    }
                    return false;
                }
                return true;
            }
        };

        template <bool useDefault>
        bool IsEqualField(const Message& m1, const Message& m2, const FieldDescriptor& field, TVector<TString>* differentPath) {
#define CASE_CPPTYPE(cpptype)                                                                                          \
    case FieldDescriptor::CPPTYPE_##cpptype: {                                                                         \
        bool r = TCompareField<FieldDescriptor::CPPTYPE_##cpptype, useDefault>::IsEqual(m1, m2, field, differentPath); \
        if (!r && !!differentPath) {                                                                                   \
            differentPath->push_back(field.name());                                                                    \
        }                                                                                                              \
        return r;                                                                                                      \
    }

            switch (field.cpp_type()) {
                CASE_CPPTYPE(INT32)
                CASE_CPPTYPE(INT64)
                CASE_CPPTYPE(UINT32)
                CASE_CPPTYPE(UINT64)
                CASE_CPPTYPE(DOUBLE)
                CASE_CPPTYPE(FLOAT)
                CASE_CPPTYPE(BOOL)
                CASE_CPPTYPE(ENUM)
                CASE_CPPTYPE(STRING)
                CASE_CPPTYPE(MESSAGE)
                default:
                    ythrow yexception() << "Unsupported cpp-type field comparison";
            }

#undef CASE_CPPTYPE
        }
    }

    template <bool useDefault>
    bool IsEqualImpl(const Message& m1, const Message& m2, TVector<TString>* differentPath) {
        const Descriptor* descr = m1.GetDescriptor();
        if (descr != m2.GetDescriptor()) {
            return false;
        }
        for (int i = 0; i < descr->field_count(); ++i)
            if (!IsEqualField<useDefault>(m1, m2, *descr->field(i), differentPath)) {
                return false;
            }
        return true;
    }

    bool IsEqual(const Message& m1, const Message& m2) {
        return IsEqualImpl<false>(m1, m2, nullptr);
    }

    bool IsEqual(const Message& m1, const Message& m2, TString* differentPath) {
        TVector<TString> differentPathVector;
        TVector<TString>* differentPathVectorPtr = !!differentPath ? &differentPathVector : nullptr;
        bool r = IsEqualImpl<false>(m1, m2, differentPathVectorPtr);
        if (!r && differentPath) {
            *differentPath = JoinStrings(differentPathVector.rbegin(), differentPathVector.rend(), "/");
        }
        return r;
    }

    bool IsEqualDefault(const Message& m1, const Message& m2) {
        return IsEqualImpl<true>(m1, m2, nullptr);
    }

    template <bool useDefault>
    static bool IsEqualFieldImpl(
        const Message& m1,
        const Message& m2,
        const FieldDescriptor& field,
        TVector<TString>* differentPath) {
        const Descriptor* descr = m1.GetDescriptor();
        if (descr != m2.GetDescriptor()) {
            return false;
        }
        return IsEqualField<useDefault>(m1, m2, field, differentPath);
    }

    bool IsEqualField(const Message& m1, const Message& m2, const FieldDescriptor& field) {
        return IsEqualFieldImpl<false>(m1, m2, field, nullptr);
    }

    bool IsEqualFieldDefault(const Message& m1, const Message& m2, const FieldDescriptor& field) {
        return IsEqualFieldImpl<true>(m1, m2, field, nullptr);
    }

}