#include "is_equal.h" #include "traits.h" #include <google/protobuf/descriptor.h> #include <util/generic/yexception.h> #include <util/string/cast.h> #include <util/string/vector.h> namespace NProtoBuf { template <bool useDefault> static bool IsEqualImpl(const Message& m1, const Message& m2, TVector<TString>* differentPath); namespace { template <FieldDescriptor::CppType CppType, bool useDefault> struct TCompareValue { typedef typename TCppTypeTraits<CppType>::T T; static inline bool IsEqual(T value1, T value2, TVector<TString>*) { return value1 == value2; } }; template <bool useDefault> struct TCompareValue<FieldDescriptor::CPPTYPE_MESSAGE, useDefault> { static inline bool IsEqual(const Message* value1, const Message* value2, TVector<TString>* differentPath) { return NProtoBuf::IsEqualImpl<useDefault>(*value1, *value2, differentPath); } }; template <FieldDescriptor::CppType CppType, bool useDefault> class TCompareField { typedef TCppTypeTraits<CppType> TTraits; typedef TCompareValue<CppType, useDefault> TCompare; public: static inline bool IsEqual(const Message& m1, const Message& m2, const FieldDescriptor& field, TVector<TString>* differentPath) { if (field.is_repeated()) return IsEqualRepeated(m1, m2, &field, differentPath); else return IsEqualSingle(m1, m2, &field, differentPath); } private: static bool IsEqualSingle(const Message& m1, const Message& m2, const FieldDescriptor* field, TVector<TString>* differentPath) { bool has1 = m1.GetReflection()->HasField(m1, field); bool has2 = m2.GetReflection()->HasField(m2, field); if (has1 != has2) { if (!useDefault || field->is_required()) { return false; } } else if (!has1) return true; return TCompare::IsEqual(TTraits::Get(m1, field), TTraits::Get(m2, field), differentPath); } static bool IsEqualRepeated(const Message& m1, const Message& m2, const FieldDescriptor* field, TVector<TString>* differentPath) { int fieldSize = m1.GetReflection()->FieldSize(m1, field); if (fieldSize != m2.GetReflection()->FieldSize(m2, field)) return false; for (int i = 0; i < fieldSize; ++i) if (!IsEqualRepeatedValue(m1, m2, field, i, differentPath)) { if (!!differentPath) { differentPath->push_back(ToString(i)); } return false; } return true; } static inline bool IsEqualRepeatedValue(const Message& m1, const Message& m2, const FieldDescriptor* field, int index, TVector<TString>* differentPath) { return TCompare::IsEqual(TTraits::GetRepeated(m1, field, index), TTraits::GetRepeated(m2, field, index), differentPath); } }; template <bool useDefault> bool IsEqualField(const Message& m1, const Message& m2, const FieldDescriptor& field, TVector<TString>* differentPath) { #define CASE_CPPTYPE(cpptype) \ case FieldDescriptor::CPPTYPE_##cpptype: { \ bool r = TCompareField<FieldDescriptor::CPPTYPE_##cpptype, useDefault>::IsEqual(m1, m2, field, differentPath); \ if (!r && !!differentPath) { \ differentPath->push_back(field.name()); \ } \ return r; \ } switch (field.cpp_type()) { CASE_CPPTYPE(INT32) CASE_CPPTYPE(INT64) CASE_CPPTYPE(UINT32) CASE_CPPTYPE(UINT64) CASE_CPPTYPE(DOUBLE) CASE_CPPTYPE(FLOAT) CASE_CPPTYPE(BOOL) CASE_CPPTYPE(ENUM) CASE_CPPTYPE(STRING) CASE_CPPTYPE(MESSAGE) default: ythrow yexception() << "Unsupported cpp-type field comparison"; } #undef CASE_CPPTYPE } } template <bool useDefault> bool IsEqualImpl(const Message& m1, const Message& m2, TVector<TString>* differentPath) { const Descriptor* descr = m1.GetDescriptor(); if (descr != m2.GetDescriptor()) { return false; } for (int i = 0; i < descr->field_count(); ++i) if (!IsEqualField<useDefault>(m1, m2, *descr->field(i), differentPath)) { return false; } return true; } bool IsEqual(const Message& m1, const Message& m2) { return IsEqualImpl<false>(m1, m2, nullptr); } bool IsEqual(const Message& m1, const Message& m2, TString* differentPath) { TVector<TString> differentPathVector; TVector<TString>* differentPathVectorPtr = !!differentPath ? &differentPathVector : nullptr; bool r = IsEqualImpl<false>(m1, m2, differentPathVectorPtr); if (!r && differentPath) { *differentPath = JoinStrings(differentPathVector.rbegin(), differentPathVector.rend(), "/"); } return r; } bool IsEqualDefault(const Message& m1, const Message& m2) { return IsEqualImpl<true>(m1, m2, nullptr); } template <bool useDefault> static bool IsEqualFieldImpl( const Message& m1, const Message& m2, const FieldDescriptor& field, TVector<TString>* differentPath) { const Descriptor* descr = m1.GetDescriptor(); if (descr != m2.GetDescriptor()) { return false; } return IsEqualField<useDefault>(m1, m2, field, differentPath); } bool IsEqualField(const Message& m1, const Message& m2, const FieldDescriptor& field) { return IsEqualFieldImpl<false>(m1, m2, field, nullptr); } bool IsEqualFieldDefault(const Message& m1, const Message& m2, const FieldDescriptor& field) { return IsEqualFieldImpl<true>(m1, m2, field, nullptr); } }