diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/protobuf | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/protobuf')
78 files changed, 8856 insertions, 0 deletions
diff --git a/library/cpp/protobuf/interop/cast.cpp b/library/cpp/protobuf/interop/cast.cpp new file mode 100644 index 0000000000..c4cd59b417 --- /dev/null +++ b/library/cpp/protobuf/interop/cast.cpp @@ -0,0 +1,23 @@ +#include <library/cpp/protobuf/interop/cast.h> + +#include <google/protobuf/duration.pb.h> +#include <google/protobuf/timestamp.pb.h> +#include <google/protobuf/util/time_util.h> + +namespace NProtoInterop { + google::protobuf::Duration CastToProto(TDuration duration) { + return google::protobuf::util::TimeUtil::MicrosecondsToDuration(duration.MicroSeconds()); + } + + google::protobuf::Timestamp CastToProto(TInstant instant) { + return google::protobuf::util::TimeUtil::MicrosecondsToTimestamp(instant.MicroSeconds()); + } + + TDuration CastFromProto(const google::protobuf::Duration& duration) { + return TDuration::MicroSeconds(google::protobuf::util::TimeUtil::DurationToMicroseconds(duration)); + } + + TInstant CastFromProto(const google::protobuf::Timestamp& timestamp) { + return TInstant::MicroSeconds(google::protobuf::util::TimeUtil::TimestampToMicroseconds(timestamp)); + } +} diff --git a/library/cpp/protobuf/interop/cast.h b/library/cpp/protobuf/interop/cast.h new file mode 100644 index 0000000000..b1c295236e --- /dev/null +++ b/library/cpp/protobuf/interop/cast.h @@ -0,0 +1,15 @@ +#pragma once + +#include <util/datetime/base.h> + +namespace google::protobuf { + class Duration; + class Timestamp; +} + +namespace NProtoInterop { + google::protobuf::Duration CastToProto(TDuration duration); + google::protobuf::Timestamp CastToProto(TInstant instant); + TDuration CastFromProto(const google::protobuf::Duration& message); + TInstant CastFromProto(const google::protobuf::Timestamp& message); +} diff --git a/library/cpp/protobuf/interop/ut/cast_ut.cpp b/library/cpp/protobuf/interop/ut/cast_ut.cpp new file mode 100644 index 0000000000..6ef055b651 --- /dev/null +++ b/library/cpp/protobuf/interop/ut/cast_ut.cpp @@ -0,0 +1,52 @@ +#include 
<library/cpp/protobuf/interop/cast.h> +#include <library/cpp/testing/unittest/registar.h> + +#include <google/protobuf/duration.pb.h> +#include <google/protobuf/timestamp.pb.h> + +static constexpr ui64 MicroSecondsInSecond = 1000 * 1000; +static constexpr ui64 NanoSecondsInMicroSecond = 1000; + +Y_UNIT_TEST_SUITE(TCastTest) { + Y_UNIT_TEST(TimestampFromProto) { + const ui64 now = TInstant::Now().MicroSeconds(); + + google::protobuf::Timestamp timestamp; + timestamp.set_seconds(now / MicroSecondsInSecond); + timestamp.set_nanos((now % MicroSecondsInSecond) * NanoSecondsInMicroSecond); + + const TInstant instant = NProtoInterop::CastFromProto(timestamp); + UNIT_ASSERT_EQUAL(instant.MicroSeconds(), now); + } + + Y_UNIT_TEST(DurationFromProto) { + const ui64 now = TInstant::Now().MicroSeconds(); + + google::protobuf::Duration message; + message.set_seconds(now / MicroSecondsInSecond); + message.set_nanos((now % MicroSecondsInSecond) * NanoSecondsInMicroSecond); + + const TDuration duration = NProtoInterop::CastFromProto(message); + UNIT_ASSERT_EQUAL(duration.MicroSeconds(), now); + } + + Y_UNIT_TEST(TimestampToProto) { + const TInstant instant = TInstant::Now(); + + google::protobuf::Timestamp timestamp = NProtoInterop::CastToProto(instant); + const ui64 microSeconds = timestamp.seconds() * MicroSecondsInSecond + + timestamp.nanos() / NanoSecondsInMicroSecond; + + UNIT_ASSERT_EQUAL(instant.MicroSeconds(), microSeconds); + } + + Y_UNIT_TEST(DurationToProto) { + const TDuration duration = TDuration::Seconds(TInstant::Now().Seconds() / 2); + + google::protobuf::Duration message = NProtoInterop::CastToProto(duration); + const ui64 microSeconds = message.seconds() * MicroSecondsInSecond + + message.nanos() / NanoSecondsInMicroSecond; + + UNIT_ASSERT_EQUAL(duration.MicroSeconds(), microSeconds); + } +} diff --git a/library/cpp/protobuf/interop/ut/ya.make b/library/cpp/protobuf/interop/ut/ya.make new file mode 100644 index 0000000000..b9c634cb6b --- /dev/null +++ 
b/library/cpp/protobuf/interop/ut/ya.make @@ -0,0 +1,15 @@ +UNITTEST_FOR(library/cpp/protobuf/interop) + +OWNER( + paxakor +) + +SRCS( + cast_ut.cpp +) + +PEERDIR( + library/cpp/protobuf/interop +) + +END() diff --git a/library/cpp/protobuf/interop/ya.make b/library/cpp/protobuf/interop/ya.make new file mode 100644 index 0000000000..618b553459 --- /dev/null +++ b/library/cpp/protobuf/interop/ya.make @@ -0,0 +1,15 @@ +LIBRARY() + +OWNER( + paxakor +) + +SRCS( + cast.cpp +) + +PEERDIR( + contrib/libs/protobuf +) + +END() diff --git a/library/cpp/protobuf/json/README b/library/cpp/protobuf/json/README new file mode 100644 index 0000000000..a0d1092ee2 --- /dev/null +++ b/library/cpp/protobuf/json/README @@ -0,0 +1 @@ +Protobuf to/from JSON converter. diff --git a/library/cpp/protobuf/json/config.h b/library/cpp/protobuf/json/config.h new file mode 100644 index 0000000000..dc84fb4d5d --- /dev/null +++ b/library/cpp/protobuf/json/config.h @@ -0,0 +1,164 @@ +#pragma once + +#include "string_transform.h" +#include "name_generator.h" + +#include <util/generic/vector.h> +#include <util/generic/yexception.h> + +#include <functional> + +namespace NProtobufJson { + struct TProto2JsonConfig { + using TSelf = TProto2JsonConfig; + + bool FormatOutput = false; + + enum MissingKeyMode { + // Skip missing keys + MissingKeySkip = 0, + // Fill missing keys with json null value. + MissingKeyNull, + // Use default value in any case. + // If default value is not explicitly defined, use default type value: + // i.e. 0 for integers, "" for strings + // For repeated keys, means [] + MissingKeyDefault, + // Use default value if it is explicitly specified for optional fields. + // Skip if no explicitly defined default value for optional fields. + // Throw exception if required field is empty. 
+ // For repeated keys, same as MissingKeySkip + MissingKeyExplicitDefaultThrowRequired + }; + MissingKeyMode MissingSingleKeyMode = MissingKeySkip; + MissingKeyMode MissingRepeatedKeyMode = MissingKeySkip; + + /// Add null value for missing fields (false by default). + bool AddMissingFields = false; + + enum EnumValueMode { + EnumNumber = 0, // default + EnumName, + EnumFullName, + EnumNameLowerCase, + EnumFullNameLowerCase, + }; + EnumValueMode EnumMode = EnumNumber; + + enum FldNameMode { + FieldNameOriginalCase = 0, // default + FieldNameLowerCase, + FieldNameUpperCase, + FieldNameCamelCase, + FieldNameSnakeCase, // ABC -> a_b_c, UserID -> user_i_d + FieldNameSnakeCaseDense // ABC -> abc, UserID -> user_id + }; + FldNameMode FieldNameMode = FieldNameOriginalCase; + + enum ExtFldNameMode { + ExtFldNameFull = 0, // default, field.full_name() + ExtFldNameShort // field.name() + }; + ExtFldNameMode ExtensionFieldNameMode = ExtFldNameFull; + + /// Use 'json_name' protobuf option for field name, mutually exclusive + /// with FieldNameMode. + bool UseJsonName = false; + + /// Transforms will be applied only to string values (== protobuf fields of string / bytes type). + /// yajl_encode_string will be used if no transforms are specified. + TVector<TStringTransformPtr> StringTransforms; + + /// Print map as object, otherwise print it as array of key/value objects + bool MapAsObject = false; + + /// Stringify long integers which are not exactly representable by float or double values + enum EStringifyLongNumbersMode { + StringifyLongNumbersNever = 0, // default + StringifyLongNumbersForFloat, + StringifyLongNumbersForDouble, + }; + EStringifyLongNumbersMode StringifyLongNumbers = StringifyLongNumbersNever; + + /// Custom field names generator. + TNameGenerator NameGenerator = {}; + + /// Custom enum values generator. 
+ TEnumValueGenerator EnumValueGenerator = {}; + + bool WriteNanAsString = false; + + TSelf& SetFormatOutput(bool format) { + FormatOutput = format; + return *this; + } + + TSelf& SetMissingSingleKeyMode(MissingKeyMode mode) { + MissingSingleKeyMode = mode; + return *this; + } + + TSelf& SetMissingRepeatedKeyMode(MissingKeyMode mode) { + MissingRepeatedKeyMode = mode; + return *this; + } + + TSelf& SetAddMissingFields(bool add) { + AddMissingFields = add; + return *this; + } + + TSelf& SetEnumMode(EnumValueMode mode) { + EnumMode = mode; + return *this; + } + + TSelf& SetFieldNameMode(FldNameMode mode) { + Y_ENSURE(mode == FieldNameOriginalCase || !UseJsonName, "FieldNameMode and UseJsonName are mutually exclusive"); + FieldNameMode = mode; + return *this; + } + + TSelf& SetUseJsonName(bool jsonName) { + Y_ENSURE(!jsonName || FieldNameMode == FieldNameOriginalCase, "FieldNameMode and UseJsonName are mutually exclusive"); + UseJsonName = jsonName; + return *this; + } + + TSelf& SetExtensionFieldNameMode(ExtFldNameMode mode) { + ExtensionFieldNameMode = mode; + return *this; + } + + TSelf& AddStringTransform(TStringTransformPtr transform) { + StringTransforms.push_back(transform); + return *this; + } + + TSelf& SetMapAsObject(bool value) { + MapAsObject = value; + return *this; + } + + TSelf& SetStringifyLongNumbers(EStringifyLongNumbersMode stringify) { + StringifyLongNumbers = stringify; + return *this; + } + + TSelf& SetNameGenerator(TNameGenerator callback) { + NameGenerator = callback; + return *this; + } + + TSelf& SetEnumValueGenerator(TEnumValueGenerator callback) { + EnumValueGenerator = callback; + return *this; + } + + TSelf& SetWriteNanAsString(bool value) { + WriteNanAsString = value; + return *this; + } + }; + +} diff --git a/library/cpp/protobuf/json/field_option.h b/library/cpp/protobuf/json/field_option.h new file mode 100644 index 0000000000..c8a8bfbff5 --- /dev/null +++ b/library/cpp/protobuf/json/field_option.h @@ -0,0 +1,40 @@ +#pragma once + 
+#include <google/protobuf/descriptor.h> +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/message.h> + +namespace NProtobufJson { + // Functor that defines whether given field has some option set to true + // + // Example: + // message T { + // optional stroka some_field = 1 [(some_option) = true]; + // } + // + template <typename TFieldOptionExtensionId> + class TFieldOptionFunctor { + public: + TFieldOptionFunctor(const TFieldOptionExtensionId& option, bool positive = true) + : Option(option) + , Positive(positive) + { + } + + bool operator()(const NProtoBuf::Message&, const NProtoBuf::FieldDescriptor* field) const { + const NProtoBuf::FieldOptions& opt = field->options(); + const bool val = opt.GetExtension(Option); + return Positive ? val : !val; + } + + private: + const TFieldOptionExtensionId& Option; + bool Positive; + }; + + template <typename TFieldOptionExtensionId> + TFieldOptionFunctor<TFieldOptionExtensionId> MakeFieldOptionFunctor(const TFieldOptionExtensionId& option, bool positive = true) { + return TFieldOptionFunctor<TFieldOptionExtensionId>(option, positive); + } + +} diff --git a/library/cpp/protobuf/json/filter.h b/library/cpp/protobuf/json/filter.h new file mode 100644 index 0000000000..9a3ddb54fe --- /dev/null +++ b/library/cpp/protobuf/json/filter.h @@ -0,0 +1,48 @@ +#pragma once + +#include "config.h" +#include "proto2json_printer.h" +#include "json_output_create.h" + +#include <util/generic/yexception.h> +#include <util/generic/utility.h> + +#include <functional> + +namespace NProtobufJson { + template <typename TBasePrinter = TProto2JsonPrinter> // TBasePrinter is assumed to be a TProto2JsonPrinter descendant + class TFilteringPrinter: public TBasePrinter { + public: + using TFieldPredicate = std::function<bool(const NProtoBuf::Message&, const NProtoBuf::FieldDescriptor*)>; + + template <typename... TArgs> + TFilteringPrinter(TFieldPredicate isPrinted, TArgs&&... args) + : TBasePrinter(std::forward<TArgs>(args)...) 
+ , IsPrinted(std::move(isPrinted)) + { + } + + virtual void PrintField(const NProtoBuf::Message& proto, + const NProtoBuf::FieldDescriptor& field, + IJsonOutput& json, + TStringBuf key) override { + if (key || IsPrinted(proto, &field)) + TBasePrinter::PrintField(proto, field, json, key); + } + + private: + TFieldPredicate IsPrinted; + }; + + inline void PrintWithFilter(const NProtoBuf::Message& msg, TFilteringPrinter<>::TFieldPredicate filter, IJsonOutput& output, const TProto2JsonConfig& config = TProto2JsonConfig()) { + TFilteringPrinter<> printer(std::move(filter), config); + printer.Print(msg, output); + } + + inline TString PrintWithFilter(const NProtoBuf::Message& msg, TFilteringPrinter<>::TFieldPredicate filter, const TProto2JsonConfig& config = TProto2JsonConfig()) { + TString ret; + PrintWithFilter(msg, std::move(filter), *CreateJsonMapOutput(ret, config), config); + return ret; + } + +} diff --git a/library/cpp/protobuf/json/inline.h b/library/cpp/protobuf/json/inline.h new file mode 100644 index 0000000000..e2d7bb6ef0 --- /dev/null +++ b/library/cpp/protobuf/json/inline.h @@ -0,0 +1,115 @@ +#pragma once + +// A printer from protobuf to json string, with ability to inline some string fields of given protobuf message +// into output as ready json without additional escaping. These fields should be marked using special field option. +// An example of usage: +// 1) Define a field option in your .proto to identify fields which should be inlined, e.g. 
+// +// import "google/protobuf/descriptor.proto"; +// extend google.protobuf.FieldOptions { +// optional bool this_is_json = 58253; // do not forget assign some more or less unique tag +// } +// +// 2) Mark some fields of your protobuf message with this option, e.g.: +// +// message TMyObject { +// optional string A = 1 [(this_is_json) = true]; +// } +// +// 3) In the C++ code you prepare somehow an object of TMyObject type +// +// TMyObject o; +// o.Set("{\"inner\":\"value\"}"); +// +// 4) And then serialize it to json string with inlining, e.g.: +// +// Cout << NProtobufJson::PrintInlined(o, MakeFieldOptionFunctor(this_is_json)) << Endl; +// +// 5) Alternatively you can specify a some more abstract functor for defining raw json fields +// +// which will print following json to stdout: +// {"A":{"inner":"value"}} +// instead of +// {"A":"{\"inner\":\"value\"}"} +// which would be printed with normal Proto2Json printer. +// +// See ut/inline_ut.cpp for additional examples of usage. + +#include "config.h" +#include "proto2json_printer.h" +#include "json_output_create.h" + +#include <library/cpp/protobuf/util/simple_reflection.h> + +#include <util/generic/maybe.h> +#include <util/generic/yexception.h> +#include <util/generic/utility.h> + +#include <functional> + +namespace NProtobufJson { + template <typename TBasePrinter = TProto2JsonPrinter> // TBasePrinter is assumed to be a TProto2JsonPrinter descendant + class TInliningPrinter: public TBasePrinter { + public: + using TFieldPredicate = std::function<bool(const NProtoBuf::Message&, + const NProtoBuf::FieldDescriptor*)>; + + template <typename... TArgs> + TInliningPrinter(TFieldPredicate isInlined, TArgs&&... args) + : TBasePrinter(std::forward<TArgs>(args)...) 
+ , IsInlined(std::move(isInlined)) + { + } + + virtual void PrintField(const NProtoBuf::Message& proto, + const NProtoBuf::FieldDescriptor& field, + IJsonOutput& json, + TStringBuf key) override { + const NProtoBuf::TConstField f(proto, &field); + if (!key && IsInlined(proto, &field) && ShouldPrint(f)) { + key = this->MakeKey(field); + json.WriteKey(key); + if (!field.is_repeated()) { + json.WriteRawJson(f.Get<TString>()); + } else { + json.BeginList(); + for (size_t i = 0, sz = f.Size(); i < sz; ++i) + json.WriteRawJson(f.Get<TString>(i)); + json.EndList(); + } + + } else { + TBasePrinter::PrintField(proto, field, json, key); + } + } + + private: + bool ShouldPrint(const NProtoBuf::TConstField& f) const { + if (!f.IsString()) + ythrow yexception() << "TInliningPrinter: json field " + << f.Field()->name() << " should be a string"; + + if (f.HasValue()) + return true; + + // we may want write default value for given field in case of its absence + const auto& cfg = this->GetConfig(); + return (f.Field()->is_repeated() ? 
cfg.MissingRepeatedKeyMode : cfg.MissingSingleKeyMode) == TProto2JsonConfig::MissingKeyDefault; + } + + private: + TFieldPredicate IsInlined; + }; + + inline void PrintInlined(const NProtoBuf::Message& msg, TInliningPrinter<>::TFieldPredicate isInlined, IJsonOutput& output, const TProto2JsonConfig& config = TProto2JsonConfig()) { + TInliningPrinter<> printer(std::move(isInlined), config); + printer.Print(msg, output); + } + + inline TString PrintInlined(const NProtoBuf::Message& msg, TInliningPrinter<>::TFieldPredicate isInlined, const TProto2JsonConfig& config = TProto2JsonConfig()) { + TString ret; + PrintInlined(msg, std::move(isInlined), *CreateJsonMapOutput(ret, config), config); + return ret; + } + +} diff --git a/library/cpp/protobuf/json/json2proto.cpp b/library/cpp/protobuf/json/json2proto.cpp new file mode 100644 index 0000000000..640c10f5a5 --- /dev/null +++ b/library/cpp/protobuf/json/json2proto.cpp @@ -0,0 +1,428 @@ +#include "json2proto.h" +#include "util.h" + +#include <library/cpp/json/json_value.h> + +#include <google/protobuf/message.h> +#include <google/protobuf/descriptor.h> + +#include <util/generic/hash.h> +#include <util/generic/maybe.h> +#include <util/string/ascii.h> +#include <util/string/cast.h> + +#define JSON_TO_FIELD(EProtoCppType, name, json, JsonCheckType, ProtoSet, JsonGet) \ + case FieldDescriptor::EProtoCppType: { \ + if (config.CastRobust) { \ + reflection->ProtoSet(&proto, &field, json.JsonGet##Robust()); \ + break; \ + } \ + if (!json.JsonCheckType()) { \ + if (config.CastFromString && json.IsString()) { \ + if (config.DoNotCastEmptyStrings && json.GetString().empty()) { \ + /* Empty string is same as "no value" for scalar types.*/ \ + break; \ + } \ + reflection->ProtoSet(&proto, &field, FromString(json.GetString())); \ + break; \ + } \ + ythrow yexception() << "Invalid type of JSON field " << name << ": " \ + << #JsonCheckType << "() failed while " \ + << #EProtoCppType << " is expected."; \ + } \ + 
reflection->ProtoSet(&proto, &field, json.JsonGet()); \ + break; \ + } + +static TString GetFieldName(const google::protobuf::FieldDescriptor& field, + const NProtobufJson::TJson2ProtoConfig& config) { + if (config.NameGenerator) { + return config.NameGenerator(field); + } + + if (config.UseJsonName) { + Y_ASSERT(!field.json_name().empty()); + TString name = field.json_name(); + if (!field.has_json_name() && !name.empty()) { + // FIXME: https://st.yandex-team.ru/CONTRIB-139 + name[0] = AsciiToLower(name[0]); + } + return name; + } + + TString name = field.name(); + switch (config.FieldNameMode) { + case NProtobufJson::TJson2ProtoConfig::FieldNameOriginalCase: + break; + case NProtobufJson::TJson2ProtoConfig::FieldNameLowerCase: + name.to_lower(); + break; + case NProtobufJson::TJson2ProtoConfig::FieldNameUpperCase: + name.to_upper(); + break; + case NProtobufJson::TJson2ProtoConfig::FieldNameCamelCase: + if (!name.empty()) { + name[0] = AsciiToLower(name[0]); + } + break; + case NProtobufJson::TJson2ProtoConfig::FieldNameSnakeCase: + NProtobufJson::ToSnakeCase(&name); + break; + case NProtobufJson::TJson2ProtoConfig::FieldNameSnakeCaseDense: + NProtobufJson::ToSnakeCaseDense(&name); + break; + default: + Y_VERIFY_DEBUG(false, "Unknown FieldNameMode."); + } + return name; +} + +static void +JsonString2Field(const NJson::TJsonValue& json, + google::protobuf::Message& proto, + const google::protobuf::FieldDescriptor& field, + const NProtobufJson::TJson2ProtoConfig& config) { + using namespace google::protobuf; + + const Reflection* reflection = proto.GetReflection(); + Y_ASSERT(!!reflection); + + if (!json.IsString() && !config.CastRobust) { + ythrow yexception() << "Invalid type of JSON field '" << field.name() << "': " + << "IsString() failed while " + << "CPPTYPE_STRING is expected."; + } + TString value = json.GetStringRobust(); + for (size_t i = 0, endI = config.StringTransforms.size(); i < endI; ++i) { + Y_ASSERT(!!config.StringTransforms[i]); + if 
(!!config.StringTransforms[i]) { + if (field.type() == google::protobuf::FieldDescriptor::TYPE_BYTES) { + config.StringTransforms[i]->TransformBytes(value); + } else { + config.StringTransforms[i]->Transform(value); + } + } + } + + if (field.is_repeated()) + reflection->AddString(&proto, &field, value); + else + reflection->SetString(&proto, &field, value); +} + +static const NProtoBuf::EnumValueDescriptor* +FindEnumValue(const NProtoBuf::EnumDescriptor* enumField, + TStringBuf target, bool (*equals)(TStringBuf, TStringBuf)) { + for (int i = 0; i < enumField->value_count(); i++) { + auto* valueDescriptor = enumField->value(i); + if (equals(valueDescriptor->name(), target)) { + return valueDescriptor; + } + } + return nullptr; +} + +static void +JsonEnum2Field(const NJson::TJsonValue& json, + google::protobuf::Message& proto, + const google::protobuf::FieldDescriptor& field, + const NProtobufJson::TJson2ProtoConfig& config) { + using namespace google::protobuf; + + const Reflection* reflection = proto.GetReflection(); + Y_ASSERT(!!reflection); + + const EnumDescriptor* enumField = field.enum_type(); + Y_ASSERT(!!enumField); + + /// @todo configure name/numerical value + const EnumValueDescriptor* enumFieldValue = nullptr; + + if (json.IsInteger()) { + const auto value = json.GetInteger(); + enumFieldValue = enumField->FindValueByNumber(value); + if (!enumFieldValue) { + ythrow yexception() << "Invalid integer value of JSON enum field: " << value << "."; + } + } else if (json.IsString()) { + const auto& value = json.GetString(); + if (config.EnumValueMode == NProtobufJson::TJson2ProtoConfig::EnumCaseInsensetive) { + enumFieldValue = FindEnumValue(enumField, value, AsciiEqualsIgnoreCase); + } else if (config.EnumValueMode == NProtobufJson::TJson2ProtoConfig::EnumSnakeCaseInsensitive) { + enumFieldValue = FindEnumValue(enumField, value, NProtobufJson::EqualsIgnoringCaseAndUnderscores); + } else { + enumFieldValue = enumField->FindValueByName(value); + } + if 
(!enumFieldValue) { + ythrow yexception() << "Invalid string value of JSON enum field: " << TStringBuf(value).Head(100) << "."; + } + } else { + ythrow yexception() << "Invalid type of JSON enum field: not an integer/string."; + } + + if (field.is_repeated()) { + reflection->AddEnum(&proto, &field, enumFieldValue); + } else { + reflection->SetEnum(&proto, &field, enumFieldValue); + } +} + +static void +Json2SingleField(const NJson::TJsonValue& json, + google::protobuf::Message& proto, + const google::protobuf::FieldDescriptor& field, + const NProtobufJson::TJson2ProtoConfig& config, + bool isMapValue = false) { + using namespace google::protobuf; + + const Reflection* reflection = proto.GetReflection(); + Y_ASSERT(!!reflection); + + TString name; + if (!isMapValue) { + name = GetFieldName(field, config); + if (!json.Has(name) || json[name].GetType() == NJson::JSON_UNDEFINED || json[name].GetType() == NJson::JSON_NULL) { + if (field.is_required() && !field.has_default_value() && !reflection->HasField(proto, &field) && config.CheckRequiredFields) { + ythrow yexception() << "JSON has no field for required field " + << name << "."; + } + + return; + } + } + + const NJson::TJsonValue& fieldJson = name ? 
json[name] : json; + + switch (field.cpp_type()) { + JSON_TO_FIELD(CPPTYPE_INT32, field.name(), fieldJson, IsInteger, SetInt32, GetInteger); + JSON_TO_FIELD(CPPTYPE_INT64, field.name(), fieldJson, IsInteger, SetInt64, GetInteger); + JSON_TO_FIELD(CPPTYPE_UINT32, field.name(), fieldJson, IsInteger, SetUInt32, GetInteger); + JSON_TO_FIELD(CPPTYPE_UINT64, field.name(), fieldJson, IsUInteger, SetUInt64, GetUInteger); + JSON_TO_FIELD(CPPTYPE_DOUBLE, field.name(), fieldJson, IsDouble, SetDouble, GetDouble); + JSON_TO_FIELD(CPPTYPE_FLOAT, field.name(), fieldJson, IsDouble, SetFloat, GetDouble); + JSON_TO_FIELD(CPPTYPE_BOOL, field.name(), fieldJson, IsBoolean, SetBool, GetBoolean); + + case FieldDescriptor::CPPTYPE_STRING: { + JsonString2Field(fieldJson, proto, field, config); + break; + } + + case FieldDescriptor::CPPTYPE_ENUM: { + JsonEnum2Field(fieldJson, proto, field, config); + break; + } + + case FieldDescriptor::CPPTYPE_MESSAGE: { + Message* innerProto = reflection->MutableMessage(&proto, &field); + Y_ASSERT(!!innerProto); + NProtobufJson::MergeJson2Proto(fieldJson, *innerProto, config); + + break; + } + + default: + ythrow yexception() << "Unknown protobuf field type: " + << static_cast<int>(field.cpp_type()) << "."; + } +} + +static void +SetKey(NProtoBuf::Message& proto, + const NProtoBuf::FieldDescriptor& field, + const TString& key) { + using namespace google::protobuf; + using namespace NProtobufJson; + + const Reflection* reflection = proto.GetReflection(); + TString result; + switch (field.cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + reflection->SetInt32(&proto, &field, FromString<int32>(key)); + break; + case FieldDescriptor::CPPTYPE_INT64: + reflection->SetInt64(&proto, &field, FromString<int64>(key)); + break; + case FieldDescriptor::CPPTYPE_UINT32: + reflection->SetUInt32(&proto, &field, FromString<uint32>(key)); + break; + case FieldDescriptor::CPPTYPE_UINT64: + reflection->SetUInt64(&proto, &field, FromString<uint64>(key)); + break; + case 
FieldDescriptor::CPPTYPE_BOOL: + reflection->SetBool(&proto, &field, FromString<bool>(key)); + break; + case FieldDescriptor::CPPTYPE_STRING: + reflection->SetString(&proto, &field, key); + break; + default: + ythrow yexception() << "Unsupported key type."; + } +} + +static void +Json2RepeatedFieldValue(const NJson::TJsonValue& jsonValue, + google::protobuf::Message& proto, + const google::protobuf::FieldDescriptor& field, + const NProtobufJson::TJson2ProtoConfig& config, + const google::protobuf::Reflection* reflection, + const TMaybe<TString>& key = {}) { + using namespace google::protobuf; + + switch (field.cpp_type()) { + JSON_TO_FIELD(CPPTYPE_INT32, field.name(), jsonValue, IsInteger, AddInt32, GetInteger); + JSON_TO_FIELD(CPPTYPE_INT64, field.name(), jsonValue, IsInteger, AddInt64, GetInteger); + JSON_TO_FIELD(CPPTYPE_UINT32, field.name(), jsonValue, IsInteger, AddUInt32, GetInteger); + JSON_TO_FIELD(CPPTYPE_UINT64, field.name(), jsonValue, IsUInteger, AddUInt64, GetUInteger); + JSON_TO_FIELD(CPPTYPE_DOUBLE, field.name(), jsonValue, IsDouble, AddDouble, GetDouble); + JSON_TO_FIELD(CPPTYPE_FLOAT, field.name(), jsonValue, IsDouble, AddFloat, GetDouble); + JSON_TO_FIELD(CPPTYPE_BOOL, field.name(), jsonValue, IsBoolean, AddBool, GetBoolean); + + case FieldDescriptor::CPPTYPE_STRING: { + JsonString2Field(jsonValue, proto, field, config); + break; + } + + case FieldDescriptor::CPPTYPE_ENUM: { + JsonEnum2Field(jsonValue, proto, field, config); + break; + } + + case FieldDescriptor::CPPTYPE_MESSAGE: { + Message* innerProto = reflection->AddMessage(&proto, &field); + Y_ASSERT(!!innerProto); + if (key.Defined()) { + const FieldDescriptor* keyField = innerProto->GetDescriptor()->FindFieldByName("key"); + Y_ENSURE(keyField, "Map entry key field not found: " << field.name()); + SetKey(*innerProto, *keyField, *key); + + const FieldDescriptor* valueField = innerProto->GetDescriptor()->FindFieldByName("value"); + Y_ENSURE(valueField, "Map entry value field not found."); + 
Json2SingleField(jsonValue, *innerProto, *valueField, config, /*isMapValue=*/true); + } else { + NProtobufJson::MergeJson2Proto(jsonValue, *innerProto, config); + } + + break; + } + + default: + ythrow yexception() << "Unknown protobuf field type: " + << static_cast<int>(field.cpp_type()) << "."; + } +} + +static void +Json2RepeatedField(const NJson::TJsonValue& json, + google::protobuf::Message& proto, + const google::protobuf::FieldDescriptor& field, + const NProtobufJson::TJson2ProtoConfig& config) { + using namespace google::protobuf; + + TString name = GetFieldName(field, config); + if (!json.Has(name)) + return; + + const NJson::TJsonValue& fieldJson = json[name]; + if (fieldJson.GetType() == NJson::JSON_UNDEFINED || fieldJson.GetType() == NJson::JSON_NULL) + return; + + bool isMap = fieldJson.GetType() == NJson::JSON_MAP; + if (isMap) { + if (!config.MapAsObject) { + ythrow yexception() << "Map as object representation is not allowed, field: " << field.name(); + } else if (!field.is_map() && !fieldJson.GetMap().empty()) { + ythrow yexception() << "Field " << field.name() << " is not a map."; + } + } + + if (fieldJson.GetType() != NJson::JSON_ARRAY && !config.MapAsObject && !config.VectorizeScalars && !config.ValueVectorizer) { + ythrow yexception() << "JSON field doesn't represent an array for " + << name + << "(actual type is " + << static_cast<int>(fieldJson.GetType()) << ")."; + } + + const Reflection* reflection = proto.GetReflection(); + Y_ASSERT(!!reflection); + + if (isMap) { + const THashMap<TString, NJson::TJsonValue> jsonMap = fieldJson.GetMap(); + for (const auto& x : jsonMap) { + const TString& key = x.first; + const NJson::TJsonValue& jsonValue = x.second; + Json2RepeatedFieldValue(jsonValue, proto, field, config, reflection, key); + } + } else { + if (config.ReplaceRepeatedFields) { + reflection->ClearField(&proto, &field); + } + if (fieldJson.GetType() == NJson::JSON_ARRAY) { + const NJson::TJsonValue::TArray& jsonArray = fieldJson.GetArray(); 
+ for (const NJson::TJsonValue& jsonValue : jsonArray) { + Json2RepeatedFieldValue(jsonValue, proto, field, config, reflection); + } + } else if (config.ValueVectorizer) { + for (const NJson::TJsonValue& jsonValue : config.ValueVectorizer(fieldJson)) { + Json2RepeatedFieldValue(jsonValue, proto, field, config, reflection); + } + } else if (config.VectorizeScalars) { + Json2RepeatedFieldValue(fieldJson, proto, field, config, reflection); + } + } +} + +namespace NProtobufJson { + void MergeJson2Proto(const NJson::TJsonValue& json, google::protobuf::Message& proto, const TJson2ProtoConfig& config) { + if (json.IsNull()) { + return; + } + + Y_ENSURE(json.IsMap(), "expected json map"); + + const google::protobuf::Descriptor* descriptor = proto.GetDescriptor(); + Y_ASSERT(!!descriptor); + + for (int f = 0, endF = descriptor->field_count(); f < endF; ++f) { + const google::protobuf::FieldDescriptor* field = descriptor->field(f); + Y_ASSERT(!!field); + + if (field->is_repeated()) { + Json2RepeatedField(json, proto, *field, config); + } else { + Json2SingleField(json, proto, *field, config); + } + } + + if (!config.AllowUnknownFields) { + THashMap<TString, bool> knownFields; + for (int f = 0, endF = descriptor->field_count(); f < endF; ++f) { + const google::protobuf::FieldDescriptor* field = descriptor->field(f); + knownFields[GetFieldName(*field, config)] = 1; + } + for (const auto& f : json.GetMap()) { + Y_ENSURE(knownFields.contains(f.first), "unknown field " << f.first); + } + } + } + + void MergeJson2Proto(const TStringBuf& json, google::protobuf::Message& proto, const TJson2ProtoConfig& config) { + NJson::TJsonReaderConfig jsonCfg; + jsonCfg.DontValidateUtf8 = true; + jsonCfg.AllowComments = config.AllowComments; + + NJson::TJsonValue jsonValue; + ReadJsonTree(json, &jsonCfg, &jsonValue, /* throwOnError = */ true); + + MergeJson2Proto(jsonValue, proto, config); + } + + void Json2Proto(const NJson::TJsonValue& json, google::protobuf::Message& proto, const 
TJson2ProtoConfig& config) { + proto.Clear(); + MergeJson2Proto(json, proto, config); + } + + void Json2Proto(const TStringBuf& json, google::protobuf::Message& proto, const TJson2ProtoConfig& config) { + proto.Clear(); + MergeJson2Proto(json, proto, config); + } +} diff --git a/library/cpp/protobuf/json/json2proto.h b/library/cpp/protobuf/json/json2proto.h new file mode 100644 index 0000000000..4c33498dfa --- /dev/null +++ b/library/cpp/protobuf/json/json2proto.h @@ -0,0 +1,222 @@ +#pragma once + +#include "string_transform.h" +#include "name_generator.h" + +#include <library/cpp/json/json_reader.h> +#include <library/cpp/json/json_value.h> + +#include <util/stream/input.h> +#include <util/stream/str.h> +#include <util/stream/mem.h> + +namespace google { + namespace protobuf { + class Message; + } +} + +namespace NProtobufJson { + struct TJson2ProtoConfig { + using TSelf = TJson2ProtoConfig; + using TValueVectorizer = std::function<NJson::TJsonValue::TArray(const NJson::TJsonValue& jsonValue)>; + + enum FldNameMode { + FieldNameOriginalCase = 0, // default + FieldNameLowerCase, + FieldNameUpperCase, + FieldNameCamelCase, + FieldNameSnakeCase, // ABC -> a_b_c, UserID -> user_i_d + FieldNameSnakeCaseDense // ABC -> abc, UserID -> user_id + }; + + enum EnumValueMode { + EnumCaseSensetive = 0, // default + EnumCaseInsensetive, + EnumSnakeCaseInsensitive + }; + + TSelf& SetFieldNameMode(FldNameMode mode) { + Y_ENSURE(mode == FieldNameOriginalCase || !UseJsonName, "FieldNameMode and UseJsonName are mutually exclusive"); + FieldNameMode = mode; + return *this; + } + + TSelf& SetUseJsonName(bool jsonName) { + Y_ENSURE(!jsonName || FieldNameMode == FieldNameOriginalCase, "FieldNameMode and UseJsonName are mutually exclusive"); + UseJsonName = jsonName; + return *this; + } + + TSelf& AddStringTransform(TStringTransformPtr transform) { + StringTransforms.push_back(transform); + return *this; + } + + TSelf& SetCastFromString(bool cast) { + CastFromString = cast; + return 
*this; + } + + TSelf& SetDoNotCastEmptyStrings(bool cast) { + DoNotCastEmptyStrings = cast; + return *this; + } + + TSelf& SetCastRobust(bool cast) { + CastRobust = cast; + return *this; + } + + TSelf& SetMapAsObject(bool mapAsObject) { + MapAsObject = mapAsObject; + return *this; + } + + TSelf& SetReplaceRepeatedFields(bool replaceRepeatedFields) { + ReplaceRepeatedFields = replaceRepeatedFields; + return *this; + } + + TSelf& SetNameGenerator(TNameGenerator callback) { + NameGenerator = callback; + return *this; + } + + TSelf& SetEnumValueMode(EnumValueMode enumValueMode) { + EnumValueMode = enumValueMode; + return *this; + } + + TSelf& SetVectorizeScalars(bool vectorizeScalars) { + VectorizeScalars = vectorizeScalars; + return *this; + } + + TSelf& SetAllowComments(bool value) { + AllowComments = value; + return *this; + } + + TSelf& SetAllowUnknownFields(bool value) { + AllowUnknownFields = value; + return *this; + } + + FldNameMode FieldNameMode = FieldNameOriginalCase; + bool AllowUnknownFields = true; + + /// Use 'json_name' protobuf option for field name, mutually exclusive + /// with FieldNameMode. + bool UseJsonName = false; + + /// Transforms will be applied only to string values (== protobuf fields of string / bytes type). + TVector<TStringTransformPtr> StringTransforms; + + /// Cast string json values to protobuf field type + bool CastFromString = false; + /// Skip empty strings, instead casting from string into scalar types. + /// I.e. empty string like default value for scalar types. + bool DoNotCastEmptyStrings = false; + /// Cast all json values to protobuf field types + bool CastRobust = false; + + /// Consider map to be an object, otherwise consider it to be an array of key/value objects + bool MapAsObject = false; + + /// Throw exception if there is no required fields in json object. 
+ bool CheckRequiredFields = true; + + /// Replace repeated fields content during merging + bool ReplaceRepeatedFields = false; + + /// Custom field names generator. + TNameGenerator NameGenerator = {}; + + /// Enum value parsing mode. + EnumValueMode EnumValueMode = EnumCaseSensetive; + + /// Append scalars to repeated fields + bool VectorizeScalars = false; + + /// Custom spliter non array value to repeated fields. + TValueVectorizer ValueVectorizer; + + /// Allow js-style comments (both // and /**/) + bool AllowComments = false; + }; + + /// @throw yexception + void MergeJson2Proto(const NJson::TJsonValue& json, google::protobuf::Message& proto, + const TJson2ProtoConfig& config = TJson2ProtoConfig()); + + /// @throw yexception + void MergeJson2Proto(const TStringBuf& json, google::protobuf::Message& proto, + const TJson2ProtoConfig& config = TJson2ProtoConfig()); + + /// @throw yexception + inline void MergeJson2Proto(const TString& json, google::protobuf::Message& proto, + const TJson2ProtoConfig& config = TJson2ProtoConfig()) { + MergeJson2Proto(TStringBuf(json), proto, config); + } + + /// @throw yexception + void Json2Proto(const NJson::TJsonValue& json, google::protobuf::Message& proto, + const TJson2ProtoConfig& config = TJson2ProtoConfig()); + + /// @throw yexception + void Json2Proto(const TStringBuf& json, google::protobuf::Message& proto, + const TJson2ProtoConfig& config = TJson2ProtoConfig()); + + /// @throw yexception + inline void Json2Proto(const TString& json, google::protobuf::Message& proto, + const TJson2ProtoConfig& config = TJson2ProtoConfig()) { + Json2Proto(TStringBuf(json), proto, config); + } + + /// @throw yexception + inline void Json2Proto(IInputStream& in, google::protobuf::Message& proto, + const TJson2ProtoConfig& config = TJson2ProtoConfig()) { + Json2Proto(TStringBuf(in.ReadAll()), proto, config); + } + + /// @throw yexception + template <typename T> + T Json2Proto(IInputStream& in, const NJson::TJsonReaderConfig& readerConfig, 
+ const TJson2ProtoConfig& config = TJson2ProtoConfig()) { + NJson::TJsonValue jsonValue; + NJson::ReadJsonTree(&in, &readerConfig, &jsonValue, true); + T protoValue; + Json2Proto(jsonValue, protoValue, config); + return protoValue; + } + + /// @throw yexception + template <typename T> + T Json2Proto(IInputStream& in, const TJson2ProtoConfig& config = TJson2ProtoConfig()) { + NJson::TJsonReaderConfig readerConfig; + readerConfig.DontValidateUtf8 = true; + return Json2Proto<T>(in, readerConfig, config); + } + + /// @throw yexception + template <typename T> + T Json2Proto(const TString& value, const TJson2ProtoConfig& config = TJson2ProtoConfig()) { + TStringInput in(value); + return Json2Proto<T>(in, config); + } + + /// @throw yexception + template <typename T> + T Json2Proto(const TStringBuf& value, const TJson2ProtoConfig& config = TJson2ProtoConfig()) { + TMemoryInput in(value); + return Json2Proto<T>(in, config); + } + + /// @throw yexception + template <typename T> + T Json2Proto(const char* ptr, const TJson2ProtoConfig& config = TJson2ProtoConfig()) { + return Json2Proto<T>(TStringBuf(ptr), config); + } + +} diff --git a/library/cpp/protobuf/json/json_output.h b/library/cpp/protobuf/json/json_output.h new file mode 100644 index 0000000000..df143af57a --- /dev/null +++ b/library/cpp/protobuf/json/json_output.h @@ -0,0 +1,79 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <util/generic/strbuf.h> + +namespace NProtobufJson { + class IJsonOutput { + public: + template <typename T> + IJsonOutput& Write(const T& t) { + DoWrite(t); + return *this; + } + IJsonOutput& WriteNull() { + DoWriteNull(); + return *this; + } + + IJsonOutput& BeginList() { + DoBeginList(); + return *this; + } + IJsonOutput& EndList() { + DoEndList(); + return *this; + } + + IJsonOutput& BeginObject() { + DoBeginObject(); + return *this; + } + IJsonOutput& WriteKey(const TStringBuf& key) { + DoWriteKey(key); + return *this; + } + IJsonOutput& EndObject() { + DoEndObject(); + return 
*this; + } + + IJsonOutput& WriteRawJson(const TStringBuf& str) { + DoWriteRawJson(str); + return *this; + } + + virtual ~IJsonOutput() { + } + + protected: + virtual void DoWrite(const TStringBuf& s) = 0; + virtual void DoWrite(const TString& s) = 0; + virtual void DoWrite(int i) = 0; + void DoWrite(long i) { + DoWrite(static_cast<long long>(i)); + } + virtual void DoWrite(long long i) = 0; + virtual void DoWrite(unsigned int i) = 0; + void DoWrite(unsigned long i) { + DoWrite(static_cast<unsigned long long>(i)); + } + virtual void DoWrite(unsigned long long i) = 0; + virtual void DoWrite(float f) = 0; + virtual void DoWrite(double f) = 0; + virtual void DoWrite(bool b) = 0; + virtual void DoWriteNull() = 0; + + virtual void DoBeginList() = 0; + virtual void DoEndList() = 0; + + virtual void DoBeginObject() = 0; + virtual void DoWriteKey(const TStringBuf& key) = 0; + virtual void DoEndObject() = 0; + + virtual void DoWriteRawJson(const TStringBuf& str) = 0; + }; + + using TJsonMapOutputPtr = THolder<IJsonOutput>; + +} diff --git a/library/cpp/protobuf/json/json_output_create.cpp b/library/cpp/protobuf/json/json_output_create.cpp new file mode 100644 index 0000000000..378e4ea65a --- /dev/null +++ b/library/cpp/protobuf/json/json_output_create.cpp @@ -0,0 +1,32 @@ +#include "json_output_create.h" + +#include "config.h" +#include "json_writer_output.h" +#include "json_value_output.h" + +namespace NProtobufJson { + TJsonMapOutputPtr CreateJsonMapOutput(IOutputStream& out, const NJson::TJsonWriterConfig& config) { + return MakeHolder<TJsonWriterOutput>(&out, config); + } + + TJsonMapOutputPtr CreateJsonMapOutput(NJson::TJsonWriter& writer) { + return MakeHolder<TBaseJsonWriterOutput>(writer); + } + + TJsonMapOutputPtr CreateJsonMapOutput(TString& str, const TProto2JsonConfig& config) { + return MakeHolder<TJsonStringWriterOutput>(&str, config); + } + + TJsonMapOutputPtr CreateJsonMapOutput(TStringStream& out, const TProto2JsonConfig& config) { + return 
MakeHolder<TJsonWriterOutput>(&out, config); + } + + TJsonMapOutputPtr CreateJsonMapOutput(IOutputStream& out, const TProto2JsonConfig& config) { + return MakeHolder<TJsonWriterOutput>(&out, config); + } + + TJsonMapOutputPtr CreateJsonMapOutput(NJson::TJsonValue& json) { + return MakeHolder<TJsonValueOutput>(json); + } + +} diff --git a/library/cpp/protobuf/json/json_output_create.h b/library/cpp/protobuf/json/json_output_create.h new file mode 100644 index 0000000000..ad3889f5e9 --- /dev/null +++ b/library/cpp/protobuf/json/json_output_create.h @@ -0,0 +1,22 @@ +#pragma once + +#include "config.h" +#include "json_output.h" + +namespace NJson { + class TJsonValue; + class TJsonWriter; + struct TJsonWriterConfig; +} + +class IOutputStream; +class TStringStream; + +namespace NProtobufJson { + TJsonMapOutputPtr CreateJsonMapOutput(IOutputStream& out, const NJson::TJsonWriterConfig& config); + TJsonMapOutputPtr CreateJsonMapOutput(NJson::TJsonWriter& writer); + TJsonMapOutputPtr CreateJsonMapOutput(IOutputStream& out, const TProto2JsonConfig& config = TProto2JsonConfig()); + TJsonMapOutputPtr CreateJsonMapOutput(TString& str, const TProto2JsonConfig& config = TProto2JsonConfig()); + TJsonMapOutputPtr CreateJsonMapOutput(NJson::TJsonValue& json); + +} diff --git a/library/cpp/protobuf/json/json_value_output.cpp b/library/cpp/protobuf/json/json_value_output.cpp new file mode 100644 index 0000000000..d845cc1c74 --- /dev/null +++ b/library/cpp/protobuf/json/json_value_output.cpp @@ -0,0 +1,106 @@ +#include "json_value_output.h" + +#include <library/cpp/json/json_reader.h> + +namespace NProtobufJson { + template <typename T> + void TJsonValueOutput::WriteImpl(const T& t) { + Y_ASSERT(Context.top().Type == TContext::JSON_ARRAY || Context.top().Type == TContext::JSON_AFTER_KEY); + + if (Context.top().Type == TContext::JSON_AFTER_KEY) { + Context.top().Value = t; + Context.pop(); + } else { + Context.top().Value.AppendValue(t); + } + } + + void TJsonValueOutput::DoWrite(const 
TStringBuf& s) { + WriteImpl(s); + } + + void TJsonValueOutput::DoWrite(const TString& s) { + WriteImpl(s); + } + + void TJsonValueOutput::DoWrite(int i) { + WriteImpl(i); + } + + void TJsonValueOutput::DoWrite(unsigned int i) { + WriteImpl(i); + } + + void TJsonValueOutput::DoWrite(long long i) { + WriteImpl(i); + } + + void TJsonValueOutput::DoWrite(unsigned long long i) { + WriteImpl(i); + } + + void TJsonValueOutput::DoWrite(float f) { + WriteImpl(f); + } + + void TJsonValueOutput::DoWrite(double f) { + WriteImpl(f); + } + + void TJsonValueOutput::DoWrite(bool b) { + WriteImpl(b); + } + + void TJsonValueOutput::DoWriteNull() { + WriteImpl(NJson::JSON_NULL); + } + + void TJsonValueOutput::DoBeginList() { + Y_ASSERT(Context.top().Type == TContext::JSON_ARRAY || Context.top().Type == TContext::JSON_AFTER_KEY); + + if (Context.top().Type == TContext::JSON_AFTER_KEY) { + Context.top().Type = TContext::JSON_ARRAY; + Context.top().Value.SetType(NJson::JSON_ARRAY); + } else { + Context.emplace(TContext::JSON_ARRAY, Context.top().Value.AppendValue(NJson::JSON_ARRAY)); + } + } + + void TJsonValueOutput::DoEndList() { + Y_ASSERT(Context.top().Type == TContext::JSON_ARRAY); + Context.pop(); + } + + void TJsonValueOutput::DoBeginObject() { + Y_ASSERT(Context.top().Type == TContext::JSON_ARRAY || Context.top().Type == TContext::JSON_AFTER_KEY); + + if (Context.top().Type == TContext::JSON_AFTER_KEY) { + Context.top().Type = TContext::JSON_MAP; + Context.top().Value.SetType(NJson::JSON_MAP); + } else { + Context.emplace(TContext::JSON_MAP, Context.top().Value.AppendValue(NJson::JSON_MAP)); + } + } + + void TJsonValueOutput::DoWriteKey(const TStringBuf& key) { + Y_ASSERT(Context.top().Type == TContext::JSON_MAP); + Context.emplace(TContext::JSON_AFTER_KEY, Context.top().Value[key]); + } + + void TJsonValueOutput::DoEndObject() { + Y_ASSERT(Context.top().Type == TContext::JSON_MAP); + Context.pop(); + } + + void TJsonValueOutput::DoWriteRawJson(const TStringBuf& str) { + 
Y_ASSERT(Context.top().Type == TContext::JSON_ARRAY || Context.top().Type == TContext::JSON_AFTER_KEY); + + if (Context.top().Type == TContext::JSON_AFTER_KEY) { + NJson::ReadJsonTree(str, &Context.top().Value); + Context.pop(); + } else { + NJson::ReadJsonTree(str, &Context.top().Value.AppendValue(NJson::JSON_UNDEFINED)); + } + } + +} diff --git a/library/cpp/protobuf/json/json_value_output.h b/library/cpp/protobuf/json/json_value_output.h new file mode 100644 index 0000000000..3fc6ff2ab0 --- /dev/null +++ b/library/cpp/protobuf/json/json_value_output.h @@ -0,0 +1,63 @@ +#pragma once + +#include "json_output.h" + +#include <library/cpp/json/writer/json_value.h> + +#include <util/generic/stack.h> + +namespace NProtobufJson { + class TJsonValueOutput: public IJsonOutput { + public: + TJsonValueOutput(NJson::TJsonValue& value) + : Root(value) + { + Context.emplace(TContext::JSON_AFTER_KEY, Root); + } + + void DoWrite(const TStringBuf& s) override; + void DoWrite(const TString& s) override; + void DoWrite(int i) override; + void DoWrite(unsigned int i) override; + void DoWrite(long long i) override; + void DoWrite(unsigned long long i) override; + void DoWrite(float f) override; + void DoWrite(double f) override; + void DoWrite(bool b) override; + void DoWriteNull() override; + + void DoBeginList() override; + void DoEndList() override; + + void DoBeginObject() override; + void DoWriteKey(const TStringBuf& key) override; + void DoEndObject() override; + + void DoWriteRawJson(const TStringBuf& str) override; + + private: + template <typename T> + void WriteImpl(const T& t); + + struct TContext { + enum EType { + JSON_MAP, + JSON_ARRAY, + JSON_AFTER_KEY, + }; + + TContext(EType type, NJson::TJsonValue& value) + : Type(type) + , Value(value) + { + } + + EType Type; + NJson::TJsonValue& Value; + }; + + NJson::TJsonValue& Root; + TStack<TContext, TVector<TContext>> Context; + }; + +} diff --git a/library/cpp/protobuf/json/json_writer_output.cpp 
b/library/cpp/protobuf/json/json_writer_output.cpp new file mode 100644 index 0000000000..288f645bab --- /dev/null +++ b/library/cpp/protobuf/json/json_writer_output.cpp @@ -0,0 +1,22 @@ +#include "json_writer_output.h" + +namespace NProtobufJson { + NJson::TJsonWriterConfig TJsonWriterOutput::CreateJsonWriterConfig(const TProto2JsonConfig& config) { + NJson::TJsonWriterConfig jsonConfig; + jsonConfig.FormatOutput = config.FormatOutput; + jsonConfig.SortKeys = false; + jsonConfig.ValidateUtf8 = false; + jsonConfig.DontEscapeStrings = false; + jsonConfig.WriteNanAsString = config.WriteNanAsString; + + for (size_t i = 0; i < config.StringTransforms.size(); ++i) { + Y_ASSERT(config.StringTransforms[i]); + if (config.StringTransforms[i]->GetType() == IStringTransform::EscapeTransform) { + jsonConfig.DontEscapeStrings = true; + break; + } + } + return jsonConfig; + } + +} diff --git a/library/cpp/protobuf/json/json_writer_output.h b/library/cpp/protobuf/json/json_writer_output.h new file mode 100644 index 0000000000..3d8a2daa56 --- /dev/null +++ b/library/cpp/protobuf/json/json_writer_output.h @@ -0,0 +1,103 @@ +#pragma once + +#include "json_output.h" +#include "config.h" + +#include <library/cpp/json/json_writer.h> + +#include <util/string/builder.h> +#include <util/generic/store_policy.h> + +namespace NProtobufJson { + class TBaseJsonWriterOutput: public IJsonOutput { + public: + TBaseJsonWriterOutput(NJson::TJsonWriter& writer) + : Writer(writer) + { + } + + private: + void DoWrite(int i) override { + Writer.Write(i); + } + void DoWrite(unsigned int i) override { + Writer.Write(i); + } + void DoWrite(long long i) override { + Writer.Write(i); + } + void DoWrite(unsigned long long i) override { + Writer.Write(i); + } + void DoWrite(float f) override { + Writer.Write(f); + } + void DoWrite(double f) override { + Writer.Write(f); + } + void DoWrite(bool b) override { + Writer.Write(b); + } + void DoWriteNull() override { + Writer.WriteNull(); + } + void DoWrite(const 
TStringBuf& s) override { + Writer.Write(s); + } + void DoWrite(const TString& s) override { + Writer.Write(s); + } + + void DoBeginList() override { + Writer.OpenArray(); + } + void DoEndList() override { + Writer.CloseArray(); + } + + void DoBeginObject() override { + Writer.OpenMap(); + } + void DoWriteKey(const TStringBuf& key) override { + Writer.Write(key); + } + void DoEndObject() override { + Writer.CloseMap(); + } + + void DoWriteRawJson(const TStringBuf& str) override { + Writer.UnsafeWrite(str); + } + + NJson::TJsonWriter& Writer; + }; + + class TJsonWriterOutput: public TEmbedPolicy<NJson::TJsonWriter>, public TBaseJsonWriterOutput { + public: + TJsonWriterOutput(IOutputStream* outputStream, const NJson::TJsonWriterConfig& cfg) + : TEmbedPolicy<NJson::TJsonWriter>(outputStream, cfg) + , TBaseJsonWriterOutput(*Ptr()) + { + } + + TJsonWriterOutput(IOutputStream* outputStream, const TProto2JsonConfig& cfg) + : TEmbedPolicy<NJson::TJsonWriter>(outputStream, CreateJsonWriterConfig(cfg)) + , TBaseJsonWriterOutput(*Ptr()) + { + } + + private: + static NJson::TJsonWriterConfig CreateJsonWriterConfig(const TProto2JsonConfig& cfg); + }; + + class TJsonStringWriterOutput: public TEmbedPolicy<TStringOutput>, public TJsonWriterOutput { + public: + template <typename TConfig> + TJsonStringWriterOutput(TString* str, const TConfig& cfg) + : TEmbedPolicy<TStringOutput>(*str) + , TJsonWriterOutput(TEmbedPolicy<TStringOutput>::Ptr(), cfg) + { + } + }; + +} diff --git a/library/cpp/protobuf/json/name_generator.cpp b/library/cpp/protobuf/json/name_generator.cpp new file mode 100644 index 0000000000..c1fb421175 --- /dev/null +++ b/library/cpp/protobuf/json/name_generator.cpp @@ -0,0 +1 @@ +#include "name_generator.h" diff --git a/library/cpp/protobuf/json/name_generator.h b/library/cpp/protobuf/json/name_generator.h new file mode 100644 index 0000000000..2b5361bee2 --- /dev/null +++ b/library/cpp/protobuf/json/name_generator.h @@ -0,0 +1,18 @@ +#pragma once + +#include 
<util/generic/string.h> + +#include <functional> + +namespace google { + namespace protobuf { + class FieldDescriptor; + class EnumValueDescriptor; + } +} + +namespace NProtobufJson { + using TNameGenerator = std::function<TString(const google::protobuf::FieldDescriptor&)>; + using TEnumValueGenerator = std::function<TString(const google::protobuf::EnumValueDescriptor&)>; + +} diff --git a/library/cpp/protobuf/json/proto2json.cpp b/library/cpp/protobuf/json/proto2json.cpp new file mode 100644 index 0000000000..3d76a91686 --- /dev/null +++ b/library/cpp/protobuf/json/proto2json.cpp @@ -0,0 +1,56 @@ +#include "proto2json.h" + +#include "json_output_create.h" +#include "proto2json_printer.h" + +#include <library/cpp/json/json_reader.h> +#include <library/cpp/json/json_value.h> +#include <library/cpp/json/json_writer.h> + +#include <util/generic/ptr.h> +#include <util/generic/strbuf.h> +#include <util/stream/output.h> +#include <util/stream/str.h> +#include <util/system/yassert.h> + +namespace NProtobufJson { + void Proto2Json(const NProtoBuf::Message& proto, IJsonOutput& jsonOutput, + const TProto2JsonConfig& config, bool closeMap) { + TProto2JsonPrinter printer(config); + printer.Print(proto, jsonOutput, closeMap); + } + + void Proto2Json(const NProtoBuf::Message& proto, NJson::TJsonValue& json, + const TProto2JsonConfig& config) { + Proto2Json(proto, *CreateJsonMapOutput(json), config); + } + + void Proto2Json(const NProtoBuf::Message& proto, NJson::TJsonWriter& writer, + const TProto2JsonConfig& config) { + Proto2Json(proto, *CreateJsonMapOutput(writer), config); + writer.Flush(); + } + + void Proto2Json(const NProtoBuf::Message& proto, IOutputStream& out, + const TProto2JsonConfig& config) { + Proto2Json(proto, *CreateJsonMapOutput(out, config), config); + } + + void Proto2Json(const NProtoBuf::Message& proto, TStringStream& out, + const TProto2JsonConfig& config) { + Proto2Json(proto, *CreateJsonMapOutput(out, config), config); + } + + void Proto2Json(const 
NProtoBuf::Message& proto, TString& str, + const TProto2JsonConfig& config) { + Proto2Json(proto, *CreateJsonMapOutput(str, config), config); + } + + TString Proto2Json(const ::NProtoBuf::Message& proto, + const TProto2JsonConfig& config) { + TString res; + Proto2Json(proto, res, config); + return res; + } + +} diff --git a/library/cpp/protobuf/json/proto2json.h b/library/cpp/protobuf/json/proto2json.h new file mode 100644 index 0000000000..89a1781a40 --- /dev/null +++ b/library/cpp/protobuf/json/proto2json.h @@ -0,0 +1,78 @@ +#pragma once + +#include "config.h" +#include "json_output.h" + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/message.h> + +#include <util/generic/fwd.h> +#include <util/generic/vector.h> +#include <util/generic/yexception.h> +#include <util/stream/str.h> + +#include <functional> + +namespace NJson { + class TJsonValue; + class TJsonWriter; +} + +class IOutputStream; +class TStringStream; + +namespace NProtobufJson { + void Proto2Json(const NProtoBuf::Message& proto, IJsonOutput& jsonOutput, + const TProto2JsonConfig& config = TProto2JsonConfig(), bool closeMap = true); + + void Proto2Json(const NProtoBuf::Message& proto, NJson::TJsonWriter& writer, + const TProto2JsonConfig& config = TProto2JsonConfig()); + + /// @throw yexception + void Proto2Json(const NProtoBuf::Message& proto, NJson::TJsonValue& json, + const TProto2JsonConfig& config = TProto2JsonConfig()); + + /// @throw yexception + void Proto2Json(const NProtoBuf::Message& proto, IOutputStream& out, + const TProto2JsonConfig& config); + // Generated code shortcut + template <class T> + inline void Proto2Json(const T& proto, IOutputStream& out) { + out << proto.AsJSON(); + } + + // TStringStream deserves a special overload as its operator TString() would cause ambiguity + /// @throw yexception + void Proto2Json(const NProtoBuf::Message& proto, TStringStream& out, + const TProto2JsonConfig& config); + // Generated code 
shortcut + template <class T> + inline void Proto2Json(const T& proto, TStringStream& out) { + out << proto.AsJSON(); + } + + /// @throw yexception + void Proto2Json(const NProtoBuf::Message& proto, TString& str, + const TProto2JsonConfig& config); + // Generated code shortcut + template <class T> + inline void Proto2Json(const T& proto, TString& str) { + str.clear(); + TStringOutput out(str); + out << proto.AsJSON(); + } + + /// @throw yexception + TString Proto2Json(const NProtoBuf::Message& proto, + const TProto2JsonConfig& config); + // Returns incorrect result if proto contains another NProtoBuf::Message + // Generated code shortcut + template <class T> + inline TString Proto2Json(const T& proto) { + TString result; + Proto2Json(proto, result); + return result; + } + +} diff --git a/library/cpp/protobuf/json/proto2json_printer.cpp b/library/cpp/protobuf/json/proto2json_printer.cpp new file mode 100644 index 0000000000..6123eab0f2 --- /dev/null +++ b/library/cpp/protobuf/json/proto2json_printer.cpp @@ -0,0 +1,517 @@ +#include "proto2json_printer.h" +#include "config.h" +#include "util.h" + +#include <util/generic/yexception.h> +#include <util/string/ascii.h> +#include <util/string/cast.h> + +namespace NProtobufJson { + using namespace NProtoBuf; + + class TJsonKeyBuilder { + public: + TJsonKeyBuilder(const FieldDescriptor& field, const TProto2JsonConfig& config, TString& tmpBuf) + : NewKeyStr(tmpBuf) + { + if (config.NameGenerator) { + NewKeyStr = config.NameGenerator(field); + NewKeyBuf = NewKeyStr; + return; + } + + if (config.UseJsonName) { + Y_ASSERT(!field.json_name().empty()); + NewKeyStr = field.json_name(); + if (!field.has_json_name() && !NewKeyStr.empty()) { + // FIXME: https://st.yandex-team.ru/CONTRIB-139 + NewKeyStr[0] = AsciiToLower(NewKeyStr[0]); + } + NewKeyBuf = NewKeyStr; + return; + } + + switch (config.FieldNameMode) { + case TProto2JsonConfig::FieldNameOriginalCase: { + NewKeyBuf = field.name(); + break; + } + + case 
TProto2JsonConfig::FieldNameLowerCase: { + NewKeyStr = field.name(); + NewKeyStr.to_lower(); + NewKeyBuf = NewKeyStr; + break; + } + + case TProto2JsonConfig::FieldNameUpperCase: { + NewKeyStr = field.name(); + NewKeyStr.to_upper(); + NewKeyBuf = NewKeyStr; + break; + } + + case TProto2JsonConfig::FieldNameCamelCase: { + NewKeyStr = field.name(); + if (!NewKeyStr.empty()) { + NewKeyStr[0] = AsciiToLower(NewKeyStr[0]); + } + NewKeyBuf = NewKeyStr; + break; + } + + case TProto2JsonConfig::FieldNameSnakeCase: { + NewKeyStr = field.name(); + ToSnakeCase(&NewKeyStr); + NewKeyBuf = NewKeyStr; + break; + } + + case TProto2JsonConfig::FieldNameSnakeCaseDense: { + NewKeyStr = field.name(); + ToSnakeCaseDense(&NewKeyStr); + NewKeyBuf = NewKeyStr; + break; + } + + default: + Y_VERIFY_DEBUG(false, "Unknown FieldNameMode."); + } + } + + const TStringBuf& GetKey() const { + return NewKeyBuf; + } + + private: + TStringBuf NewKeyBuf; + TString& NewKeyStr; + }; + + TProto2JsonPrinter::TProto2JsonPrinter(const TProto2JsonConfig& cfg) + : Config(cfg) + { + } + + TProto2JsonPrinter::~TProto2JsonPrinter() { + } + + TStringBuf TProto2JsonPrinter::MakeKey(const FieldDescriptor& field) { + return TJsonKeyBuilder(field, GetConfig(), TmpBuf).GetKey(); + } + + template <bool InMapContext, typename T> + std::enable_if_t<InMapContext, void> WriteWithMaybeEmptyKey(IJsonOutput& json, const TStringBuf& key, const T& value) { + json.WriteKey(key).Write(value); + } + + template <bool InMapContext, typename T> + std::enable_if_t<!InMapContext, void> WriteWithMaybeEmptyKey(IJsonOutput& array, const TStringBuf& key, const T& value) { + Y_ASSERT(!key); + array.Write(value); + } + + template <bool InMapContext> + void TProto2JsonPrinter::PrintStringValue(const FieldDescriptor& field, + const TStringBuf& key, const TString& value, + IJsonOutput& json) { + if (!GetConfig().StringTransforms.empty()) { + TString tmpBuf = value; + for (const TStringTransformPtr& stringTransform : 
GetConfig().StringTransforms) { + Y_ASSERT(stringTransform); + if (stringTransform) { + if (field.type() == FieldDescriptor::TYPE_BYTES) + stringTransform->TransformBytes(tmpBuf); + else + stringTransform->Transform(tmpBuf); + } + } + WriteWithMaybeEmptyKey<InMapContext>(json, key, tmpBuf); + } else { + WriteWithMaybeEmptyKey<InMapContext>(json, key, value); + } + } + + template <bool InMapContext> + void TProto2JsonPrinter::PrintEnumValue(const TStringBuf& key, + const EnumValueDescriptor* value, + IJsonOutput& json) { + if (Config.EnumValueGenerator) { + WriteWithMaybeEmptyKey<InMapContext>(json, key, Config.EnumValueGenerator(*value)); + return; + } + + switch (GetConfig().EnumMode) { + case TProto2JsonConfig::EnumNumber: { + WriteWithMaybeEmptyKey<InMapContext>(json, key, value->number()); + break; + } + + case TProto2JsonConfig::EnumName: { + WriteWithMaybeEmptyKey<InMapContext>(json, key, value->name()); + break; + } + + case TProto2JsonConfig::EnumFullName: { + WriteWithMaybeEmptyKey<InMapContext>(json, key, value->full_name()); + break; + } + + case TProto2JsonConfig::EnumNameLowerCase: { + TString newName = value->name(); + newName.to_lower(); + WriteWithMaybeEmptyKey<InMapContext>(json, key, newName); + break; + } + + case TProto2JsonConfig::EnumFullNameLowerCase: { + TString newName = value->full_name(); + newName.to_lower(); + WriteWithMaybeEmptyKey<InMapContext>(json, key, newName); + break; + } + + default: + Y_VERIFY_DEBUG(false, "Unknown EnumMode."); + } + } + + void TProto2JsonPrinter::PrintSingleField(const Message& proto, + const FieldDescriptor& field, + IJsonOutput& json, + TStringBuf key) { + Y_VERIFY(!field.is_repeated(), "field is repeated."); + + if (!key) { + key = MakeKey(field); + } + +#define FIELD_TO_JSON(EProtoCppType, ProtoGet) \ + case FieldDescriptor::EProtoCppType: { \ + json.WriteKey(key).Write(reflection->ProtoGet(proto, &field)); \ + break; \ + } + +#define INT_FIELD_TO_JSON(EProtoCppType, ProtoGet) \ + case 
FieldDescriptor::EProtoCppType: { \ + const auto value = reflection->ProtoGet(proto, &field); \ + if (NeedStringifyNumber(value)) { \ + json.WriteKey(key).Write(ToString(value)); \ + } else { \ + json.WriteKey(key).Write(value); \ + } \ + break; \ + } + + const Reflection* reflection = proto.GetReflection(); + + bool shouldPrintField = reflection->HasField(proto, &field); + if (!shouldPrintField && GetConfig().MissingSingleKeyMode == TProto2JsonConfig::MissingKeyExplicitDefaultThrowRequired) { + if (field.has_default_value()) { + shouldPrintField = true; + } else if (field.is_required()) { + ythrow yexception() << "Empty required protobuf field: " + << field.full_name() << "."; + } + } + shouldPrintField = shouldPrintField || GetConfig().MissingSingleKeyMode == TProto2JsonConfig::MissingKeyDefault; + + if (shouldPrintField) { + switch (field.cpp_type()) { + INT_FIELD_TO_JSON(CPPTYPE_INT32, GetInt32); + INT_FIELD_TO_JSON(CPPTYPE_INT64, GetInt64); + INT_FIELD_TO_JSON(CPPTYPE_UINT32, GetUInt32); + INT_FIELD_TO_JSON(CPPTYPE_UINT64, GetUInt64); + FIELD_TO_JSON(CPPTYPE_DOUBLE, GetDouble); + FIELD_TO_JSON(CPPTYPE_FLOAT, GetFloat); + FIELD_TO_JSON(CPPTYPE_BOOL, GetBool); + + case FieldDescriptor::CPPTYPE_MESSAGE: { + json.WriteKey(key); + Print(reflection->GetMessage(proto, &field), json); + break; + } + + case FieldDescriptor::CPPTYPE_ENUM: { + PrintEnumValue<true>(key, reflection->GetEnum(proto, &field), json); + break; + } + + case FieldDescriptor::CPPTYPE_STRING: { + TString scratch; + const TString& value = reflection->GetStringReference(proto, &field, &scratch); + PrintStringValue<true>(field, key, value, json); + break; + } + + default: + ythrow yexception() << "Unknown protobuf field type: " + << static_cast<int>(field.cpp_type()) << "."; + } + } else { + switch (GetConfig().MissingSingleKeyMode) { + case TProto2JsonConfig::MissingKeyNull: { + json.WriteKey(key).WriteNull(); + break; + } + + case TProto2JsonConfig::MissingKeySkip: + case 
TProto2JsonConfig::MissingKeyExplicitDefaultThrowRequired: + default: + break; + } + } +#undef FIELD_TO_JSON + } + + void TProto2JsonPrinter::PrintRepeatedField(const Message& proto, + const FieldDescriptor& field, + IJsonOutput& json, + TStringBuf key) { + Y_VERIFY(field.is_repeated(), "field isn't repeated."); + + const bool isMap = field.is_map() && GetConfig().MapAsObject; + if (!key) { + key = MakeKey(field); + } + +#define REPEATED_FIELD_TO_JSON(EProtoCppType, ProtoGet) \ + case FieldDescriptor::EProtoCppType: { \ + for (size_t i = 0, endI = reflection->FieldSize(proto, &field); i < endI; ++i) \ + json.Write(reflection->ProtoGet(proto, &field, i)); \ + break; \ + } + + const Reflection* reflection = proto.GetReflection(); + + if (reflection->FieldSize(proto, &field) > 0) { + json.WriteKey(key); + if (isMap) { + json.BeginObject(); + } else { + json.BeginList(); + } + + switch (field.cpp_type()) { + REPEATED_FIELD_TO_JSON(CPPTYPE_INT32, GetRepeatedInt32); + REPEATED_FIELD_TO_JSON(CPPTYPE_INT64, GetRepeatedInt64); + REPEATED_FIELD_TO_JSON(CPPTYPE_UINT32, GetRepeatedUInt32); + REPEATED_FIELD_TO_JSON(CPPTYPE_UINT64, GetRepeatedUInt64); + REPEATED_FIELD_TO_JSON(CPPTYPE_DOUBLE, GetRepeatedDouble); + REPEATED_FIELD_TO_JSON(CPPTYPE_FLOAT, GetRepeatedFloat); + REPEATED_FIELD_TO_JSON(CPPTYPE_BOOL, GetRepeatedBool); + + case FieldDescriptor::CPPTYPE_MESSAGE: { + if (isMap) { + for (size_t i = 0, endI = reflection->FieldSize(proto, &field); i < endI; ++i) { + PrintKeyValue(reflection->GetRepeatedMessage(proto, &field, i), json); + } + } else { + for (size_t i = 0, endI = reflection->FieldSize(proto, &field); i < endI; ++i) { + Print(reflection->GetRepeatedMessage(proto, &field, i), json); + } + } + break; + } + + case FieldDescriptor::CPPTYPE_ENUM: { + for (int i = 0, endI = reflection->FieldSize(proto, &field); i < endI; ++i) + PrintEnumValue<false>(TStringBuf(), reflection->GetRepeatedEnum(proto, &field, i), json); + break; + } + + case FieldDescriptor::CPPTYPE_STRING: 
{ + TString scratch; + for (int i = 0, endI = reflection->FieldSize(proto, &field); i < endI; ++i) { + const TString& value = + reflection->GetRepeatedStringReference(proto, &field, i, &scratch); + PrintStringValue<false>(field, TStringBuf(), value, json); + } + break; + } + + default: + ythrow yexception() << "Unknown protobuf field type: " + << static_cast<int>(field.cpp_type()) << "."; + } + + if (isMap) { + json.EndObject(); + } else { + json.EndList(); + } + } else { + switch (GetConfig().MissingRepeatedKeyMode) { + case TProto2JsonConfig::MissingKeyNull: { + json.WriteKey(key).WriteNull(); + break; + } + + case TProto2JsonConfig::MissingKeyDefault: { + json.WriteKey(key); + if (isMap) { + json.BeginObject().EndObject(); + } else { + json.BeginList().EndList(); + } + break; + } + + case TProto2JsonConfig::MissingKeySkip: + case TProto2JsonConfig::MissingKeyExplicitDefaultThrowRequired: + default: + break; + } + } + +#undef REPEATED_FIELD_TO_JSON + } + + void TProto2JsonPrinter::PrintKeyValue(const NProtoBuf::Message& proto, + IJsonOutput& json) { + const FieldDescriptor* keyField = proto.GetDescriptor()->FindFieldByName("key"); + Y_VERIFY(keyField, "Map entry key field not found."); + TString key = MakeKey(proto, *keyField); + const FieldDescriptor* valueField = proto.GetDescriptor()->FindFieldByName("value"); + Y_VERIFY(valueField, "Map entry value field not found."); + PrintField(proto, *valueField, json, key); + } + + TString TProto2JsonPrinter::MakeKey(const NProtoBuf::Message& proto, + const NProtoBuf::FieldDescriptor& field) { + const Reflection* reflection = proto.GetReflection(); + TString result; + switch (field.cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + result = ToString(reflection->GetInt32(proto, &field)); + break; + case FieldDescriptor::CPPTYPE_INT64: + result = ToString(reflection->GetInt64(proto, &field)); + break; + case FieldDescriptor::CPPTYPE_UINT32: + result = ToString(reflection->GetUInt32(proto, &field)); + break; + case 
FieldDescriptor::CPPTYPE_UINT64: + result = ToString(reflection->GetUInt64(proto, &field)); + break; + case FieldDescriptor::CPPTYPE_DOUBLE: + result = ToString(reflection->GetDouble(proto, &field)); + break; + case FieldDescriptor::CPPTYPE_FLOAT: + result = ToString(reflection->GetFloat(proto, &field)); + break; + case FieldDescriptor::CPPTYPE_BOOL: + result = ToString(reflection->GetBool(proto, &field)); + break; + case FieldDescriptor::CPPTYPE_ENUM: { + const EnumValueDescriptor* value = reflection->GetEnum(proto, &field); + switch (GetConfig().EnumMode) { + case TProto2JsonConfig::EnumNumber: + result = ToString(value->number()); + break; + case TProto2JsonConfig::EnumName: + result = value->name(); + break; + case TProto2JsonConfig::EnumFullName: + result = value->full_name(); + break; + case TProto2JsonConfig::EnumNameLowerCase: + result = value->name(); + result.to_lower(); + break; + case TProto2JsonConfig::EnumFullNameLowerCase: + result = value->full_name(); + result.to_lower(); + break; + default: + ythrow yexception() << "Unsupported enum mode."; + } + break; + } + case FieldDescriptor::CPPTYPE_STRING: + result = reflection->GetString(proto, &field); + break; + default: + ythrow yexception() << "Unsupported key type."; + } + + return result; + } + + void TProto2JsonPrinter::PrintField(const Message& proto, + const FieldDescriptor& field, + IJsonOutput& json, + const TStringBuf key) { + + + if (field.is_repeated()) + PrintRepeatedField(proto, field, json, key); + else + PrintSingleField(proto, field, json, key); + } + + void TProto2JsonPrinter::Print(const Message& proto, IJsonOutput& json, bool closeMap) { + const Descriptor* descriptor = proto.GetDescriptor(); + Y_ASSERT(descriptor); + + json.BeginObject(); + + // Iterate over all non-extension fields + for (int f = 0, endF = descriptor->field_count(); f < endF; ++f) { + const FieldDescriptor* field = descriptor->field(f); + Y_ASSERT(field); + PrintField(proto, *field, json); + } + + // Check 
extensions via ListFields + std::vector<const FieldDescriptor*> fields; + auto* ref = proto.GetReflection(); + ref->ListFields(proto, &fields); + + for (const FieldDescriptor* field : fields) { + Y_ASSERT(field); + if (field->is_extension()) { + switch (GetConfig().ExtensionFieldNameMode) { + case TProto2JsonConfig::ExtFldNameFull: + PrintField(proto, *field, json, field->full_name()); + break; + case TProto2JsonConfig::ExtFldNameShort: + PrintField(proto, *field, json); + break; + } + } + } + + if (closeMap) { + json.EndObject(); + } + } + + template <class T, class U> + std::enable_if_t<!std::is_unsigned<T>::value, bool> ValueInRange(T value, U range) { + return value >= -range && value <= range; + } + + template <class T, class U> + std::enable_if_t<std::is_unsigned<T>::value, bool> ValueInRange(T value, U range) { + return value <= (std::make_unsigned_t<U>)(range); + } + + template <class T> + bool TProto2JsonPrinter::NeedStringifyNumber(T value) const { + constexpr long SAFE_INTEGER_RANGE_FLOAT = 16777216; + constexpr long long SAFE_INTEGER_RANGE_DOUBLE = 9007199254740992; + + switch (GetConfig().StringifyLongNumbers) { + case TProto2JsonConfig::StringifyLongNumbersNever: + return false; + case TProto2JsonConfig::StringifyLongNumbersForFloat: + return !ValueInRange(value, SAFE_INTEGER_RANGE_FLOAT); + case TProto2JsonConfig::StringifyLongNumbersForDouble: + return !ValueInRange(value, SAFE_INTEGER_RANGE_DOUBLE); + } + + return false; + } + +} diff --git a/library/cpp/protobuf/json/proto2json_printer.h b/library/cpp/protobuf/json/proto2json_printer.h new file mode 100644 index 0000000000..9dc5aa86c6 --- /dev/null +++ b/library/cpp/protobuf/json/proto2json_printer.h @@ -0,0 +1,68 @@ +#pragma once + +#include "json_output.h" + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/message.h> + +#include <util/generic/strbuf.h> +#include <util/generic/string.h> + +namespace NProtobufJson { + struct 
TProto2JsonConfig;
+
+    /// Walks a protobuf message through reflection and writes it to an IJsonOutput.
+    ///
+    /// The printer keeps only a reference to the config it was constructed with,
+    /// so that config object has to stay alive for as long as the printer is used.
+    class TProto2JsonPrinter {
+    public:
+        TProto2JsonPrinter(const TProto2JsonConfig& config);
+        virtual ~TProto2JsonPrinter();
+
+        /// Writes `proto` as a JSON object. With closeMap == false the object is
+        /// not closed, so the caller can add further keys and close it later.
+        virtual void Print(const NProtoBuf::Message& proto, IJsonOutput& json, bool closeMap = true);
+
+        virtual const TProto2JsonConfig& GetConfig() const {
+            return Config;
+        }
+
+    protected:
+        /// Produces the JSON key for a field; used whenever a Print*Field method
+        /// is called with an empty key.
+        virtual TStringBuf MakeKey(const NProtoBuf::FieldDescriptor& field);
+
+        /// Prints one field, forwarding to PrintRepeatedField or PrintSingleField.
+        virtual void PrintField(const NProtoBuf::Message& proto,
+                                const NProtoBuf::FieldDescriptor& field,
+                                IJsonOutput& json,
+                                TStringBuf key = {});
+
+        /// Prints a repeated field as a JSON list (or as an object for maps
+        /// when the config enables MapAsObject).
+        void PrintRepeatedField(const NProtoBuf::Message& proto,
+                                const NProtoBuf::FieldDescriptor& field,
+                                IJsonOutput& json,
+                                TStringBuf key = {});
+
+        void PrintSingleField(const NProtoBuf::Message& proto,
+                              const NProtoBuf::FieldDescriptor& field,
+                              IJsonOutput& json,
+                              TStringBuf key = {});
+
+        /// Prints a map-entry message as "key": value.
+        void PrintKeyValue(const NProtoBuf::Message& proto,
+                           IJsonOutput& json);
+
+        /// Renders the value of a scalar field as a string to be used as a JSON key.
+        TString MakeKey(const NProtoBuf::Message& proto,
+                        const NProtoBuf::FieldDescriptor& field);
+
+        template <bool InMapContext>
+        void PrintEnumValue(const TStringBuf& key,
+                            const NProtoBuf::EnumValueDescriptor* value,
+                            IJsonOutput& json);
+
+        template <bool InMapContext>
+        void PrintStringValue(const NProtoBuf::FieldDescriptor& field,
+                              const TStringBuf& key, const TString& value,
+                              IJsonOutput& json);
+
+        /// True when, per the StringifyLongNumbers setting, `value` should be
+        /// written as a string instead of a JSON number.
+        template <class T>
+        bool NeedStringifyNumber(T value) const;
+
+    protected:
+        const TProto2JsonConfig& Config; // not owned
+        TString TmpBuf; // NOTE(review): presumably a scratch buffer for key building - confirm in the .cpp
+    };
+
+}
diff --git a/library/cpp/protobuf/json/string_transform.cpp b/library/cpp/protobuf/json/string_transform.cpp
new file mode 100644
index 0000000000..7c42daa677
--- /dev/null
+++ b/library/cpp/protobuf/json/string_transform.cpp
@@ -0,0 +1,64 @@
+#include "string_transform.h"
+
+#include <google/protobuf/stubs/strutil.h>
+
+#include <library/cpp/string_utils/base64/base64.h>
+
+namespace NProtobufJson {
+    void TCEscapeTransform::Transform(TString& str) const {
+        str = google::protobuf::CEscape(str);
+    }
+
void TSafeUtf8CEscapeTransform::Transform(TString& str) const { + str = google::protobuf::strings::Utf8SafeCEscape(str); + } + + void TDoubleEscapeTransform::Transform(TString& str) const { + TString escaped = google::protobuf::CEscape(str); + str = ""; + for (char* it = escaped.begin(); *it; ++it) { + if (*it == '\\' || *it == '\"') + str += "\\"; + str += *it; + } + } + + void TDoubleUnescapeTransform::Transform(TString& str) const { + str = google::protobuf::UnescapeCEscapeString(Unescape(str)); + } + + TString TDoubleUnescapeTransform::Unescape(const TString& str) const { + if (str.empty()) { + return str; + } + + TString result; + result.reserve(str.size()); + + char prev = str[0]; + bool doneOutput = true; + for (const char* it = str.c_str() + 1; *it; ++it) { + if (doneOutput && prev == '\\' && (*it == '\\' || *it == '\"')) { + doneOutput = false; + } else { + result += prev; + doneOutput = true; + } + prev = *it; + } + + if ((doneOutput && prev != '\\') || !doneOutput) { + result += prev; + } + + return result; + } + + void TBase64EncodeBytesTransform::TransformBytes(TString &str) const { + str = Base64Encode(str); + } + + void TBase64DecodeBytesTransform::TransformBytes(TString &str) const { + str = Base64Decode(str); + } +} diff --git a/library/cpp/protobuf/json/string_transform.h b/library/cpp/protobuf/json/string_transform.h new file mode 100644 index 0000000000..e4b296bc01 --- /dev/null +++ b/library/cpp/protobuf/json/string_transform.h @@ -0,0 +1,111 @@ +#pragma once + +#include <library/cpp/string_utils/relaxed_escaper/relaxed_escaper.h> +#include <util/generic/ptr.h> +#include <util/generic/refcount.h> + +namespace NProtobufJson { + class IStringTransform: public TSimpleRefCount<IStringTransform> { + public: + virtual ~IStringTransform() { + } + + /// Some transforms have special meaning. + /// For example, escape transforms cause generic JSON escaping to be turned off. 
+ enum Type { + EscapeTransform = 0x1, + }; + + virtual int GetType() const = 0; + + /// This method is called for each string field in proto + virtual void Transform(TString& str) const = 0; + + /// This method is called for each bytes field in proto + virtual void TransformBytes(TString& str) const { + // Default behaviour is to apply string transform + return Transform(str); + } + }; + + using TStringTransformPtr = TIntrusivePtr<IStringTransform>; + + template <bool quote, bool tounicode> + class TEscapeJTransform: public IStringTransform { + public: + int GetType() const override { + return EscapeTransform; + } + + void Transform(TString& str) const override { + TString newStr; + NEscJ::EscapeJ<quote, tounicode>(str, newStr); + str = newStr; + } + }; + + class TCEscapeTransform: public IStringTransform { + public: + int GetType() const override { + return EscapeTransform; + } + + void Transform(TString& str) const override; + }; + + class TSafeUtf8CEscapeTransform: public IStringTransform { + public: + int GetType() const override { + return EscapeTransform; + } + + void Transform(TString& str) const override; + }; + + class TDoubleEscapeTransform: public IStringTransform { + public: + int GetType() const override { + return EscapeTransform; + } + + void Transform(TString& str) const override; + }; + + class TDoubleUnescapeTransform: public NProtobufJson::IStringTransform { + public: + int GetType() const override { + return NProtobufJson::IStringTransform::EscapeTransform; + } + + void Transform(TString& str) const override; + + private: + TString Unescape(const TString& str) const; + }; + + class TBase64EncodeBytesTransform: public NProtobufJson::IStringTransform { + public: + int GetType() const override { + return 0; + } + + void Transform(TString&) const override { + // Do not transform strings + } + + void TransformBytes(TString &str) const override; + }; + + class TBase64DecodeBytesTransform: public NProtobufJson::IStringTransform { + public: + int 
GetType() const override { + return 0; + } + + void Transform(TString&) const override { + // Do not transform strings + } + + void TransformBytes(TString &str) const override; + }; +} diff --git a/library/cpp/protobuf/json/ut/fields.incl b/library/cpp/protobuf/json/ut/fields.incl new file mode 100644 index 0000000000..4b22985836 --- /dev/null +++ b/library/cpp/protobuf/json/ut/fields.incl @@ -0,0 +1,23 @@ +// Intentionally no #pragma once + +// (Field name == JSON key, Value) +DEFINE_FIELD(I32, Min<i32>()) +DEFINE_FIELD(I64, Min<i64>()) +DEFINE_FIELD(UI32, Max<ui32>()) +DEFINE_FIELD(UI64, Max<ui64>()) +DEFINE_FIELD(SI32, Min<i32>()) +DEFINE_FIELD(SI64, Min<i64>()) +DEFINE_FIELD(FI32, Max<ui32>()) +DEFINE_FIELD(FI64, Max<ui64>()) +DEFINE_FIELD(SFI32, Min<i32>()) +DEFINE_FIELD(SFI64, Min<i64>()) +DEFINE_FIELD(Bool, true) +DEFINE_FIELD(String, "Lorem ipsum") +DEFINE_FIELD(Bytes, "מחשב") +DEFINE_FIELD(Enum, E_1) +DEFINE_FIELD(Float, 1.123f) +DEFINE_FIELD(Double, 1.123456789012) +DEFINE_FIELD(OneString, "Lorem ipsum dolor") +DEFINE_FIELD(OneTwoString, "Lorem ipsum dolor sit") +DEFINE_FIELD(ABC, "abc") +DEFINE_FIELD(UserID, "some_id")
\ No newline at end of file
diff --git a/library/cpp/protobuf/json/ut/filter_ut.cpp b/library/cpp/protobuf/json/ut/filter_ut.cpp
new file mode 100644
index 0000000000..95c227666f
--- /dev/null
+++ b/library/cpp/protobuf/json/ut/filter_ut.cpp
@@ -0,0 +1,93 @@
+#include <library/cpp/protobuf/json/ut/filter_ut.pb.h>
+
+#include <library/cpp/protobuf/json/filter.h>
+#include <library/cpp/protobuf/json/field_option.h>
+#include <library/cpp/protobuf/json/proto2json.h>
+#include <library/cpp/testing/unittest/registar.h>
+
+using namespace NProtobufJson;
+
+// Builds the message shared by the filter tests: every field of TFilterTest
+// and of its nested TInner is populated.
+static NProtobufJsonUt::TFilterTest GetTestMsg() {
+    NProtobufJsonUt::TFilterTest msg;
+    msg.SetOptFiltered("1");
+    msg.SetNotFiltered("23");
+    msg.AddRepFiltered(45);
+    msg.AddRepFiltered(67);
+    msg.MutableInner()->AddNumber(100);
+    msg.MutableInner()->AddNumber(200);
+    msg.MutableInner()->SetInnerFiltered(235);
+    return msg;
+}
+
+Y_UNIT_TEST_SUITE(TProto2JsonFilterTest){
+    Y_UNIT_TEST(TestFilterPrinter){
+        NProtobufJsonUt::TFilterTest msg = GetTestMsg();
+// Baseline: the plain printer outputs every field.
+{
+    TString expected = R"({"OptFiltered":"1","NotFiltered":"23","RepFiltered":[45,67],)"
+                       R"("Inner":{"Number":[100,200],"InnerFiltered":235}})";
+    TString my = Proto2Json(msg);
+    UNIT_ASSERT_STRINGS_EQUAL(my, expected);
+}
+
+// Option functor built with `false`: only fields without (filter_test) = true are expected.
+{
+    TString expected = R"({"NotFiltered":"23",)"
+                       R"("Inner":{"Number":[100,200]}})";
+    TString my = PrintWithFilter(msg, MakeFieldOptionFunctor(NProtobufJsonUt::filter_test, false));
+    UNIT_ASSERT_STRINGS_EQUAL(my, expected);
+}
+
+// Option functor with the default second argument: only top-level fields
+// marked (filter_test) = true are expected.
+{
+    TString expected = R"({"OptFiltered":"1","RepFiltered":[45,67]})";
+    TString my = PrintWithFilter(msg, MakeFieldOptionFunctor(NProtobufJsonUt::filter_test));
+    UNIT_ASSERT_STRINGS_EQUAL(my, expected);
+}
+
+// Same API through an explicit IJsonOutput, selecting on (export_test).
+{
+    TString expected = R"({"OptFiltered":"1","NotFiltered":"23",)"
+                       R"("Inner":{"Number":[100,200]}})";
+    TString my;
+    PrintWithFilter(msg, MakeFieldOptionFunctor(NProtobufJsonUt::export_test), *CreateJsonMapOutput(my));
+    UNIT_ASSERT_STRINGS_EQUAL(my, expected);
+}
+
+// An arbitrary callable works as a filter too; this one selects fields by name.
+{
+    TString expected = R"({"NotFiltered":"23",)"
+                       R"("Inner":{"Number":[100,200]}})";
+    auto functor = [](const NProtoBuf::Message&, const NProtoBuf::FieldDescriptor* field) {
+        return field->name() == "NotFiltered" || field->name() == "Number" || field->name() == "Inner";
+    };
+    TString my = PrintWithFilter(msg, functor);
+    UNIT_ASSERT_STRINGS_EQUAL(my, expected);
+}
+}
+
+// Constructing a TFilteringPrinter from a temporary functor must copy the
+// functor at most once; the mock counts copy-constructor calls.
+Y_UNIT_TEST(NoUnnecessaryCopyFunctor) {
+    size_t CopyCount = 0;
+    struct TFunctorMock {
+        TFunctorMock(size_t* copyCount)
+            : CopyCount(copyCount)
+        {
+            UNIT_ASSERT(*CopyCount <= 1);
+        }
+
+        // Copying bumps the shared counter.
+        TFunctorMock(const TFunctorMock& f)
+            : CopyCount(f.CopyCount)
+        {
+            ++*CopyCount;
+        }
+
+        // Moves do not touch the counter.
+        TFunctorMock(TFunctorMock&& f) = default;
+
+        bool operator()(const NProtoBuf::Message&, const NProtoBuf::FieldDescriptor*) const {
+            return false;
+        }
+
+        size_t* CopyCount;
+    };
+
+    TProto2JsonConfig cfg;
+    TFilteringPrinter<> printer(TFunctorMock(&CopyCount), cfg);
+    UNIT_ASSERT(CopyCount <= 1);
+}
+}
+;
diff --git a/library/cpp/protobuf/json/ut/filter_ut.proto b/library/cpp/protobuf/json/ut/filter_ut.proto
new file mode 100644
index 0000000000..29d630ade4
--- /dev/null
+++ b/library/cpp/protobuf/json/ut/filter_ut.proto
@@ -0,0 +1,20 @@
+import "google/protobuf/descriptor.proto";
+
+package NProtobufJsonUt;
+
+// Custom boolean field options that the filter tests select on.
+extend google.protobuf.FieldOptions {
+    optional bool filter_test = 58255;
+    optional bool export_test = 58256;
+}
+
+message TFilterTest {
+    optional string OptFiltered = 1 [(filter_test) = true, (export_test) = true];
+    optional string NotFiltered = 2 [(export_test) = true];
+    repeated uint64 RepFiltered = 3 [(filter_test) = true];
+
+    message TInner {
+        repeated uint32 Number = 1 [(export_test) = true];
+        optional int32 InnerFiltered = 2 [(filter_test) = true];
+    }
+    optional TInner Inner = 4 [(export_test) = true];
+}
diff --git a/library/cpp/protobuf/json/ut/inline_ut.cpp b/library/cpp/protobuf/json/ut/inline_ut.cpp
new file mode 100644
index 0000000000..c29ad32e7d
--- /dev/null
+++ b/library/cpp/protobuf/json/ut/inline_ut.cpp
@@ -0,0 +1,122 @@ +#include <library/cpp/protobuf/json/ut/inline_ut.pb.h> + +#include <library/cpp/protobuf/json/inline.h> +#include <library/cpp/protobuf/json/field_option.h> +#include <library/cpp/protobuf/json/proto2json.h> +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/string.h> + +using namespace NProtobufJson; + +static NProtobufJsonUt::TInlineTest GetTestMsg() { + NProtobufJsonUt::TInlineTest msg; + msg.SetOptJson(R"({"a":1,"b":"000"})"); + msg.SetNotJson("12{}34"); + msg.AddRepJson("{}"); + msg.AddRepJson("[1,2]"); + msg.MutableInner()->AddNumber(100); + msg.MutableInner()->AddNumber(200); + msg.MutableInner()->SetInnerJson(R"({"xxx":[]})"); + return msg; +} + +Y_UNIT_TEST_SUITE(TProto2JsonInlineTest){ + Y_UNIT_TEST(TestNormalPrint){ + NProtobufJsonUt::TInlineTest msg = GetTestMsg(); +// normal print should output these fields as just string values +TString expRaw = R"({"OptJson":"{\"a\":1,\"b\":\"000\"}","NotJson":"12{}34","RepJson":["{}","[1,2]"],)" + R"("Inner":{"Number":[100,200],"InnerJson":"{\"xxx\":[]}"}})"; +TString myRaw; +Proto2Json(msg, myRaw); +UNIT_ASSERT_STRINGS_EQUAL(myRaw, expRaw); + +myRaw = PrintInlined(msg, [](const NProtoBuf::Message&, const NProtoBuf::FieldDescriptor*) { return false; }); +UNIT_ASSERT_STRINGS_EQUAL(myRaw, expRaw); // result is the same +} + +Y_UNIT_TEST(TestInliningPrinter) { + NProtobufJsonUt::TInlineTest msg = GetTestMsg(); + // inlined print should output these fields as inlined json sub-objects + TString expInlined = R"({"OptJson":{"a":1,"b":"000"},"NotJson":"12{}34","RepJson":[{},[1,2]],)" + R"("Inner":{"Number":[100,200],"InnerJson":{"xxx":[]}}})"; + + { + TString myInlined = PrintInlined(msg, MakeFieldOptionFunctor(NProtobufJsonUt::inline_test)); + UNIT_ASSERT_STRINGS_EQUAL(myInlined, expInlined); + } + { + auto functor = [](const NProtoBuf::Message&, const NProtoBuf::FieldDescriptor* field) { + return field->name() == "OptJson" || field->name() == "RepJson" || field->name() == 
"InnerJson"; + }; + TString myInlined = PrintInlined(msg, functor); + UNIT_ASSERT_STRINGS_EQUAL(myInlined, expInlined); + } +} + +Y_UNIT_TEST(TestNoValues) { + // no values - no printing + NProtobufJsonUt::TInlineTest msg; + msg.MutableInner()->AddNumber(100); + msg.MutableInner()->AddNumber(200); + + TString expInlined = R"({"Inner":{"Number":[100,200]}})"; + + TString myInlined = PrintInlined(msg, MakeFieldOptionFunctor(NProtobufJsonUt::inline_test)); + UNIT_ASSERT_STRINGS_EQUAL(myInlined, expInlined); +} + +Y_UNIT_TEST(TestMissingKeyModeNull) { + NProtobufJsonUt::TInlineTest msg; + msg.MutableInner()->AddNumber(100); + msg.MutableInner()->AddNumber(200); + + TString expInlined = R"({"OptJson":null,"NotJson":null,"RepJson":null,"Inner":{"Number":[100,200],"InnerJson":null}})"; + + TProto2JsonConfig cfg; + cfg.SetMissingSingleKeyMode(TProto2JsonConfig::MissingKeyNull).SetMissingRepeatedKeyMode(TProto2JsonConfig::MissingKeyNull); + TString myInlined = PrintInlined(msg, MakeFieldOptionFunctor(NProtobufJsonUt::inline_test), cfg); + UNIT_ASSERT_STRINGS_EQUAL(myInlined, expInlined); +} + +Y_UNIT_TEST(TestMissingKeyModeDefault) { + NProtobufJsonUt::TInlineTestDefaultValues msg; + + TString expInlined = R"({"OptJson":{"default":1},"Number":0,"RepJson":[],"Inner":{"OptJson":{"default":2}}})"; + + TProto2JsonConfig cfg; + cfg.SetMissingSingleKeyMode(TProto2JsonConfig::MissingKeyDefault).SetMissingRepeatedKeyMode(TProto2JsonConfig::MissingKeyDefault); + TString myInlined = PrintInlined(msg, MakeFieldOptionFunctor(NProtobufJsonUt::inline_test), cfg); + UNIT_ASSERT_STRINGS_EQUAL(myInlined, expInlined); +} + +Y_UNIT_TEST(NoUnnecessaryCopyFunctor) { + size_t CopyCount = 0; + struct TFunctorMock { + TFunctorMock(size_t* copyCount) + : CopyCount(copyCount) + { + UNIT_ASSERT(*CopyCount <= 1); + } + + TFunctorMock(const TFunctorMock& f) + : CopyCount(f.CopyCount) + { + ++*CopyCount; + } + + TFunctorMock(TFunctorMock&& f) = default; + + bool operator()(const NProtoBuf::Message&, 
const NProtoBuf::FieldDescriptor*) const { + return false; + } + + size_t* CopyCount; + }; + + TProto2JsonConfig cfg; + TInliningPrinter<> printer(TFunctorMock(&CopyCount), cfg); + UNIT_ASSERT(CopyCount <= 1); +} +} +; diff --git a/library/cpp/protobuf/json/ut/inline_ut.proto b/library/cpp/protobuf/json/ut/inline_ut.proto new file mode 100644 index 0000000000..76bd10232d --- /dev/null +++ b/library/cpp/protobuf/json/ut/inline_ut.proto @@ -0,0 +1,29 @@ +import "google/protobuf/descriptor.proto"; + +package NProtobufJsonUt; + +extend google.protobuf.FieldOptions { + optional bool inline_test = 58253; +} + +message TInlineTest { + optional string OptJson = 1 [(inline_test) = true]; + optional string NotJson = 2; + repeated string RepJson = 3 [(inline_test) = true]; + + message TInner { + repeated uint32 Number = 1; + optional string InnerJson = 2 [(inline_test) = true]; + } + optional TInner Inner = 4; +} + +message TInlineTestDefaultValues { + optional string OptJson = 1 [(inline_test) = true, default = "{\"default\":1}"]; + optional uint32 Number = 2; + repeated string RepJson = 3 [(inline_test) = true]; + message TInner { + optional string OptJson = 1 [(inline_test) = true, default = "{\"default\":2}"]; + } + optional TInner Inner = 4; +} diff --git a/library/cpp/protobuf/json/ut/json.h b/library/cpp/protobuf/json/ut/json.h new file mode 100644 index 0000000000..c1f108e6e4 --- /dev/null +++ b/library/cpp/protobuf/json/ut/json.h @@ -0,0 +1,69 @@ +#pragma once + +#include <library/cpp/protobuf/json/ut/test.pb.h> + +#include <library/cpp/json/json_value.h> + +#include <cstdarg> + +#include <util/generic/hash_set.h> +#include <util/generic/string.h> + +#include <util/system/defaults.h> + +namespace NProtobufJsonTest { + inline NJson::TJsonValue + CreateFlatJson(const THashSet<TString>& skippedKeys = THashSet<TString>()) { + NJson::TJsonValue json; + +#define DEFINE_FIELD(name, value) \ + if (skippedKeys.find(#name) == skippedKeys.end()) \ + json.InsertValue(#name, 
value); +#include "fields.incl" +#undef DEFINE_FIELD + + return json; + } + + inline NJson::TJsonValue + CreateRepeatedFlatJson(const THashSet<TString>& skippedKeys = THashSet<TString>()) { + NJson::TJsonValue json; + +#define DEFINE_REPEATED_FIELD(name, type, ...) \ + if (skippedKeys.find(#name) == skippedKeys.end()) { \ + type values[] = {__VA_ARGS__}; \ + NJson::TJsonValue array(NJson::JSON_ARRAY); \ + for (size_t i = 0, end = Y_ARRAY_SIZE(values); i < end; ++i) { \ + array.AppendValue(values[i]); \ + } \ + json.InsertValue(#name, array); \ + } +#include "repeated_fields.incl" +#undef DEFINE_REPEATED_FIELD + + return json; + } + + inline NJson::TJsonValue + CreateCompositeJson(const THashSet<TString>& skippedKeys = THashSet<TString>()) { + const NJson::TJsonValue& part = CreateFlatJson(skippedKeys); + NJson::TJsonValue json; + json.InsertValue("Part", part); + + return json; + } + +#define UNIT_ASSERT_JSONS_EQUAL(lhs, rhs) \ + if (lhs != rhs) { \ + UNIT_ASSERT_STRINGS_EQUAL(lhs.GetStringRobust(), rhs.GetStringRobust()); \ + } + +#define UNIT_ASSERT_JSON_STRINGS_EQUAL(lhs, rhs) \ + if (lhs != rhs) { \ + NJson::TJsonValue _lhs_json, _rhs_json; \ + UNIT_ASSERT(NJson::ReadJsonTree(lhs, &_lhs_json)); \ + UNIT_ASSERT(NJson::ReadJsonTree(rhs, &_rhs_json)); \ + UNIT_ASSERT_JSONS_EQUAL(_lhs_json, _rhs_json); \ + } + +} diff --git a/library/cpp/protobuf/json/ut/json2proto_ut.cpp b/library/cpp/protobuf/json/ut/json2proto_ut.cpp new file mode 100644 index 0000000000..0dfe57bc7a --- /dev/null +++ b/library/cpp/protobuf/json/ut/json2proto_ut.cpp @@ -0,0 +1,1147 @@ +#include "json.h" +#include "proto.h" +#include "proto2json.h" + +#include <library/cpp/protobuf/json/ut/test.pb.h> + +#include <library/cpp/json/json_value.h> +#include <library/cpp/json/json_reader.h> +#include <library/cpp/json/json_writer.h> + +#include <library/cpp/protobuf/json/json2proto.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/hash_set.h> +#include 
<util/generic/string.h> +#include <util/generic/ylimits.h> +#include <util/stream/str.h> +#include <util/string/cast.h> +#include <util/system/defaults.h> +#include <util/system/yassert.h> + +using namespace NProtobufJson; +using namespace NProtobufJsonTest; + +namespace google { + namespace protobuf { + namespace internal { + void MapTestForceDeterministic() { + google::protobuf::io::CodedOutputStream::SetDefaultSerializationDeterministic(); + } + } + } // namespace protobuf +} + +namespace { + class TInit { + public: + TInit() { + ::google::protobuf::internal::MapTestForceDeterministic(); + } + } Init; + + template <typename T> + TString ConvertToString(T value) { + return ToString(value); + } + + // default ToString<double>() implementation loses precision + TString ConvertToString(double value) { + return FloatToString(value); + } + + TString JsonValueToString(const NJson::TJsonValue& json) { + NJsonWriter::TBuf buf(NJsonWriter::HEM_UNSAFE); + return buf.WriteJsonValue(&json).Str(); + } + + void TestComplexMapAsObject(std::function<void(TComplexMapType&)>&& init, const TString& json, const TJson2ProtoConfig& config = TJson2ProtoConfig().SetMapAsObject(true)) { + TComplexMapType modelProto; + + init(modelProto); + + TString modelStr(json); + + TComplexMapType proto; + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TComplexMapType>(modelStr, config)); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); + } +} + +Y_UNIT_TEST_SUITE(TJson2ProtoTest) { + Y_UNIT_TEST(TestFlatOptional){ + {const NJson::TJsonValue& json = CreateFlatJson(); + TFlatOptional proto; + Json2Proto(json, proto); + TFlatOptional modelProto; + FillFlatProto(&modelProto); + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} + + // Try to skip each field +#define DEFINE_FIELD(name, value) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + const NJson::TJsonValue& json = CreateFlatJson(skippedField); \ + TFlatOptional proto; \ + Json2Proto(json, proto); \ + TFlatOptional 
modelProto; \ + FillFlatProto(&modelProto, skippedField); \ + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} // TestFlatOptional + +Y_UNIT_TEST(TestFlatRequired){ + {const NJson::TJsonValue& json = CreateFlatJson(); +TFlatRequired proto; +Json2Proto(json, proto); +TFlatRequired modelProto; +FillFlatProto(&modelProto); +UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} + +// Try to skip each field +#define DEFINE_FIELD(name, value) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + const NJson::TJsonValue& json = CreateFlatJson(skippedField); \ + TFlatRequired proto; \ + UNIT_ASSERT_EXCEPTION(Json2Proto(json, proto), yexception); \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} // TestFlatRequired + +Y_UNIT_TEST(TestNameGenerator) { + TJson2ProtoConfig cfg; + cfg.SetNameGenerator([](const NProtoBuf::FieldDescriptor&) { return "42"; }); + + TNameGeneratorType proto; + Json2Proto(TStringBuf(R"({"42":42})"), proto, cfg); + + TNameGeneratorType expected; + expected.SetField(42); + + UNIT_ASSERT_PROTOS_EQUAL(expected, proto); +} + +Y_UNIT_TEST(TestFlatNoCheckRequired) { + { + const NJson::TJsonValue& json = CreateFlatJson(); + TFlatRequired proto; + Json2Proto(json, proto); + TFlatRequired modelProto; + FillFlatProto(&modelProto); + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); + } + + TJson2ProtoConfig cfg; + cfg.CheckRequiredFields = false; + + // Try to skip each field +#define DEFINE_FIELD(name, value) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + const NJson::TJsonValue& json = CreateFlatJson(skippedField); \ + TFlatRequired proto; \ + UNIT_ASSERT_NO_EXCEPTION(Json2Proto(json, proto, cfg)); \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} // TestFlatNoCheckRequired + +Y_UNIT_TEST(TestFlatRepeated){ + {const NJson::TJsonValue& json = CreateRepeatedFlatJson(); 
+TFlatRepeated proto; +Json2Proto(json, proto); +TFlatRepeated modelProto; +FillRepeatedProto(&modelProto); +UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} + +// Try to skip each field +#define DEFINE_REPEATED_FIELD(name, ...) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + const NJson::TJsonValue& json = CreateRepeatedFlatJson(skippedField); \ + TFlatRepeated proto; \ + Json2Proto(json, proto); \ + TFlatRepeated modelProto; \ + FillRepeatedProto(&modelProto, skippedField); \ + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); \ + } +#include <library/cpp/protobuf/json/ut/repeated_fields.incl> +#undef DEFINE_REPEATED_FIELD +} // TestFlatRepeated + +Y_UNIT_TEST(TestCompositeOptional){ + {const NJson::TJsonValue& json = CreateCompositeJson(); +TCompositeOptional proto; +Json2Proto(json, proto); +TCompositeOptional modelProto; +FillCompositeProto(&modelProto); +UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} + +// Try to skip each field +#define DEFINE_FIELD(name, value) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + const NJson::TJsonValue& json = CreateCompositeJson(skippedField); \ + TCompositeOptional proto; \ + Json2Proto(json, proto); \ + TCompositeOptional modelProto; \ + FillCompositeProto(&modelProto, skippedField); \ + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} // TestCompositeOptional + +Y_UNIT_TEST(TestCompositeOptionalStringBuf){ + {NJson::TJsonValue json = CreateCompositeJson(); +json["Part"]["Double"] = 42.5; +TCompositeOptional proto; +Json2Proto(JsonValueToString(json), proto); +TCompositeOptional modelProto; +FillCompositeProto(&modelProto); +modelProto.MutablePart()->SetDouble(42.5); +UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} + +// Try to skip each field +#define DEFINE_FIELD(name, value) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + NJson::TJsonValue json = 
CreateCompositeJson(skippedField); \ + if (json["Part"].Has("Double")) { \ + json["Part"]["Double"] = 42.5; \ + } \ + TCompositeOptional proto; \ + Json2Proto(JsonValueToString(json), proto); \ + TCompositeOptional modelProto; \ + FillCompositeProto(&modelProto, skippedField); \ + if (modelProto.GetPart().HasDouble()) { \ + modelProto.MutablePart()->SetDouble(42.5); \ + } \ + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} // TestCompositeOptionalStringBuf + +Y_UNIT_TEST(TestCompositeRequired) { + { + const NJson::TJsonValue& json = CreateCompositeJson(); + TCompositeRequired proto; + Json2Proto(json, proto); + TCompositeRequired modelProto; + FillCompositeProto(&modelProto); + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); + } + + { + NJson::TJsonValue json; + TCompositeRequired proto; + UNIT_ASSERT_EXCEPTION(Json2Proto(json, proto), yexception); + } +} // TestCompositeRequired + +Y_UNIT_TEST(TestCompositeRepeated) { + { + NJson::TJsonValue json; + NJson::TJsonValue array; + array.AppendValue(CreateFlatJson()); + json.InsertValue("Part", array); + + TCompositeRepeated proto; + Json2Proto(json, proto); + + TFlatOptional partModelProto; + FillFlatProto(&partModelProto); + TCompositeRepeated modelProto; + modelProto.AddPart()->CopyFrom(partModelProto); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); + } + + { + // Array of messages with each field skipped + TCompositeRepeated modelProto; + NJson::TJsonValue array; + +#define DEFINE_REPEATED_FIELD(name, ...) 
\ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + TFlatOptional partModelProto; \ + FillFlatProto(&partModelProto, skippedField); \ + modelProto.AddPart()->CopyFrom(partModelProto); \ + array.AppendValue(CreateFlatJson(skippedField)); \ + } +#include <library/cpp/protobuf/json/ut/repeated_fields.incl> +#undef DEFINE_REPEATED_FIELD + + NJson::TJsonValue json; + json.InsertValue("Part", array); + + TCompositeRepeated proto; + Json2Proto(json, proto); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); + } +} // TestCompositeRepeated + +Y_UNIT_TEST(TestInvalidEnum) { + { + NJson::TJsonValue json; + json.InsertValue("Enum", "E_100"); + TFlatOptional proto; + UNIT_ASSERT_EXCEPTION(Json2Proto(json, proto), yexception); + } + + { + NJson::TJsonValue json; + json.InsertValue("Enum", 100); + TFlatOptional proto; + UNIT_ASSERT_EXCEPTION(Json2Proto(json, proto), yexception); + } +} + +Y_UNIT_TEST(TestFieldNameMode) { + // Original case 1 + { + TString modelStr(R"_({"String":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetString() == "value"); + } + + // Original case 2 + { + TString modelStr(R"_({"String":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameOriginalCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetString() == "value"); + } + + // Lowercase + { + TString modelStr(R"_({"string":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameLowerCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetString() == "value"); + } + + // Uppercase + { + TString modelStr(R"_({"STRING":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.FieldNameMode = 
TJson2ProtoConfig::FieldNameUpperCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetString() == "value"); + } + + // Camelcase + { + TString modelStr(R"_({"string":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameCamelCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetString() == "value"); + } + { + TString modelStr(R"_({"oneString":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameCamelCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetOneString() == "value"); + } + { + TString modelStr(R"_({"oneTwoString":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameCamelCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetOneTwoString() == "value"); + } + + // snake_case + { + TString modelStr(R"_({"string":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameSnakeCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetString() == "value"); + } + { + TString modelStr(R"_({"one_string":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameSnakeCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetOneString() == "value"); + } + { + TString modelStr(R"_({"one_two_string":"value"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameSnakeCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + 
UNIT_ASSERT(proto.GetOneTwoString() == "value"); + } + + // Original case, repeated + { + TString modelStr(R"_({"I32":[1,2]})_"); + + TFlatRepeated proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameOriginalCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatRepeated>(modelStr, config)); + UNIT_ASSERT(proto.I32Size() == 2); + UNIT_ASSERT(proto.GetI32(0) == 1); + UNIT_ASSERT(proto.GetI32(1) == 2); + } + + // Lower case, repeated + { + TString modelStr(R"_({"i32":[1,2]})_"); + + TFlatRepeated proto; + TJson2ProtoConfig config; + config.FieldNameMode = TJson2ProtoConfig::FieldNameLowerCase; + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatRepeated>(modelStr, config)); + UNIT_ASSERT(proto.I32Size() == 2); + UNIT_ASSERT(proto.GetI32(0) == 1); + UNIT_ASSERT(proto.GetI32(1) == 2); + } + + // UseJsonName + { + // FIXME(CONTRIB-139): since protobuf 3.1, Def_upper json name is + // "DefUpper", but until kernel/ugc/schema and yweb/yasap/pdb are + // updated, library/cpp/protobuf/json preserves compatibility with + // protobuf 3.0 by lowercasing default names, making it "defUpper". 
+ TString modelStr(R"_({"My-Upper":1,"my-lower":2,"defUpper":3,"defLower":4})_"); + + TWithJsonName proto; + TJson2ProtoConfig config; + config.SetUseJsonName(true); + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TWithJsonName>(modelStr, config)); + UNIT_ASSERT_EQUAL(proto.Getmy_upper(), 1); + UNIT_ASSERT_EQUAL(proto.GetMy_lower(), 2); + UNIT_ASSERT_EQUAL(proto.GetDef_upper(), 3); + UNIT_ASSERT_EQUAL(proto.Getdef_lower(), 4); + } + + // FieldNameMode with UseJsonName + { + TJson2ProtoConfig config; + config.SetFieldNameMode(TJson2ProtoConfig::FieldNameLowerCase); + UNIT_ASSERT_EXCEPTION_CONTAINS( + config.SetUseJsonName(true), yexception, "mutually exclusive"); + } + { + TJson2ProtoConfig config; + config.SetUseJsonName(true); + UNIT_ASSERT_EXCEPTION_CONTAINS( + config.SetFieldNameMode(TJson2ProtoConfig::FieldNameLowerCase), yexception, "mutually exclusive"); + } +} // TestFieldNameMode + +class TStringTransform: public IStringTransform { +public: + int GetType() const override { + return 0; + } + void Transform(TString& str) const override { + str = "transformed_any"; + } +}; + +class TBytesTransform: public IStringTransform { +public: + int GetType() const override { + return 0; + } + void Transform(TString&) const override { + } + void TransformBytes(TString& str) const override { + str = "transformed_bytes"; + } +}; + +Y_UNIT_TEST(TestInvalidJson) { + NJson::TJsonValue val{"bad value"}; + TFlatOptional proto; + UNIT_ASSERT_EXCEPTION(Json2Proto(val, proto), yexception); +} + +Y_UNIT_TEST(TestInvalidRepeatedFieldWithMapAsObject) { + TCompositeRepeated proto; + TJson2ProtoConfig config; + config.MapAsObject = true; + UNIT_ASSERT_EXCEPTION(Json2Proto(TStringBuf(R"({"Part":{"Boo":{}}})"), proto, config), yexception); +} + +Y_UNIT_TEST(TestStringTransforms) { + // Check that strings and bytes are transformed + { + TString modelStr(R"_({"String":"value_str", "Bytes": "value_bytes"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + 
config.AddStringTransform(new TStringTransform); + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetString() == "transformed_any"); + UNIT_ASSERT(proto.GetBytes() == "transformed_any"); + } + + // Check that bytes are transformed, strings are left intact + { + TString modelStr(R"_({"String":"value_str", "Bytes": "value_bytes"})_"); + + TFlatOptional proto; + TJson2ProtoConfig config; + config.AddStringTransform(new TBytesTransform); + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetString() == "value_str"); + UNIT_ASSERT(proto.GetBytes() == "transformed_bytes"); + } + + // Check that repeated bytes are transformed, repeated strings are left intact + { + TString modelStr(R"_({"String":["value_str", "str2"], "Bytes": ["value_bytes", "bytes2"]})_"); + + TFlatRepeated proto; + TJson2ProtoConfig config; + config.AddStringTransform(new TBytesTransform); + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TFlatRepeated>(modelStr, config)); + UNIT_ASSERT(proto.StringSize() == 2); + UNIT_ASSERT(proto.GetString(0) == "value_str"); + UNIT_ASSERT(proto.GetString(1) == "str2"); + UNIT_ASSERT(proto.BytesSize() == 2); + UNIT_ASSERT(proto.GetBytes(0) == "transformed_bytes"); + UNIT_ASSERT(proto.GetBytes(1) == "transformed_bytes"); + } + + // Check that bytes are transformed, strings are left intact in composed messages + { + TString modelStr(R"_({"Part": {"String":"value_str", "Bytes": "value_bytes"}})_"); + + TCompositeOptional proto; + TJson2ProtoConfig config; + config.AddStringTransform(new TBytesTransform); + + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TCompositeOptional>(modelStr, config)); + UNIT_ASSERT(proto.GetPart().GetString() == "value_str"); + UNIT_ASSERT(proto.GetPart().GetBytes() == "transformed_bytes"); + } +} // TestStringTransforms + +Y_UNIT_TEST(TestCastFromString) { + // single fields + { + NJson::TJsonValue json; +#define DEFINE_FIELD(name, value) \ + 
json.InsertValue(#name, ConvertToString(value)); +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD + + TFlatOptional proto; + UNIT_ASSERT_EXCEPTION_CONTAINS(Json2Proto(json, proto), yexception, "Invalid type"); + + TJson2ProtoConfig config; + config.SetCastFromString(true); + Json2Proto(json, proto, config); + + TFlatOptional modelProto; + FillFlatProto(&modelProto); + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); + } + + // repeated fields + { + NJson::TJsonValue json; +#define DEFINE_REPEATED_FIELD(name, type, ...) \ + { \ + type values[] = {__VA_ARGS__}; \ + NJson::TJsonValue array(NJson::JSON_ARRAY); \ + for (size_t i = 0, end = Y_ARRAY_SIZE(values); i < end; ++i) { \ + array.AppendValue(ConvertToString(values[i])); \ + } \ + json.InsertValue(#name, array); \ + } +#include <library/cpp/protobuf/json/ut/repeated_fields.incl> +#undef DEFINE_REPEATED_FIELD + + TFlatRepeated proto; + UNIT_ASSERT_EXCEPTION_CONTAINS(Json2Proto(json, proto), yexception, "Invalid type"); + + TJson2ProtoConfig config; + config.SetCastFromString(true); + Json2Proto(json, proto, config); + + TFlatRepeated modelProto; + FillRepeatedProto(&modelProto); + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); + } +} // TestCastFromString + +Y_UNIT_TEST(TestMap) { + TMapType modelProto; + + auto& items = *modelProto.MutableItems(); + items["key1"] = "value1"; + items["key2"] = "value2"; + items["key3"] = "value3"; + + TString modelStr(R"_({"Items":[{"key":"key3","value":"value3"},{"key":"key2","value":"value2"},{"key":"key1","value":"value1"}]})_"); + + TJson2ProtoConfig config; + TMapType proto; + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TMapType>(modelStr, config)); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} // TestMap + +Y_UNIT_TEST(TestCastRobust) { + NJson::TJsonValue json; + json["I32"] = "5"; + json["Bool"] = 1; + json["String"] = 6; + json["Double"] = 8; + TFlatOptional proto; + UNIT_ASSERT_EXCEPTION_CONTAINS(Json2Proto(json, proto), yexception, "Invalid 
type"); + + TJson2ProtoConfig config; + config.SetCastRobust(true); + Json2Proto(json, proto, config); + + TFlatOptional expected; + expected.SetI32(5); + expected.SetBool(true); + expected.SetString("6"); + expected.SetDouble(8); + UNIT_ASSERT_PROTOS_EQUAL(proto, expected); +} + +Y_UNIT_TEST(TestVectorizeScalars) { + NJson::TJsonValue json; +#define DEFINE_FIELD(name, value) \ + json.InsertValue(#name, value); +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD + + TFlatRepeated proto; + TJson2ProtoConfig config; + config.SetVectorizeScalars(true); + Json2Proto(json, proto, config); + +#define DEFINE_FIELD(name, value) \ + UNIT_ASSERT_VALUES_EQUAL(proto.Get ## name(0), value); +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} + +Y_UNIT_TEST(TestValueVectorizer) { + { + // No ValueVectorizer + NJson::TJsonValue json; + json["RepeatedString"] = "123"; + TJson2ProtoConfig config; + TSingleRepeatedString expected; + UNIT_ASSERT_EXCEPTION(Json2Proto(json, expected, config), yexception); + } + { + // ValueVectorizer replace original value by array + NJson::TJsonValue json; + json["RepeatedString"] = "123"; + TJson2ProtoConfig config; + + TSingleRepeatedString expected; + expected.AddRepeatedString("4"); + expected.AddRepeatedString("5"); + expected.AddRepeatedString("6"); + + config.ValueVectorizer = [](const NJson::TJsonValue& val) -> NJson::TJsonValue::TArray { + Y_UNUSED(val); + return {NJson::TJsonValue("4"), NJson::TJsonValue("5"), NJson::TJsonValue("6")}; + }; + TSingleRepeatedString actual; + Json2Proto(json, actual, config); + UNIT_ASSERT_PROTOS_EQUAL(expected, actual); + } + { + // ValueVectorizer replace original value by array and cast + NJson::TJsonValue json; + json["RepeatedInt"] = 123; + TJson2ProtoConfig config; + + TSingleRepeatedInt expected; + expected.AddRepeatedInt(4); + expected.AddRepeatedInt(5); + expected.AddRepeatedInt(6); + + config.ValueVectorizer = [](const NJson::TJsonValue& val) -> 
NJson::TJsonValue::TArray { + Y_UNUSED(val); + return {NJson::TJsonValue("4"), NJson::TJsonValue(5), NJson::TJsonValue("6")}; + }; + config.CastFromString = true; + + TSingleRepeatedInt actual; + Json2Proto(json, actual, config); + UNIT_ASSERT_PROTOS_EQUAL(expected, actual); + } +} + +Y_UNIT_TEST(TestMapAsObject) { + TMapType modelProto; + + auto& items = *modelProto.MutableItems(); + items["key1"] = "value1"; + items["key2"] = "value2"; + items["key3"] = "value3"; + + TString modelStr(R"_({"Items":{"key1":"value1","key2":"value2","key3":"value3"}})_"); + + TJson2ProtoConfig config; + config.MapAsObject = true; + TMapType proto; + UNIT_ASSERT_NO_EXCEPTION(proto = Json2Proto<TMapType>(modelStr, config)); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} // TestMapAsObject + +Y_UNIT_TEST(TestComplexMapAsObject_I32) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableI32(); + items[1] = 1; + items[-2] = -2; + items[3] = 3; + }, + R"_({"I32":{"1":1,"-2":-2,"3":3}})_"); +} // TestComplexMapAsObject_I32 + +Y_UNIT_TEST(TestComplexMapAsObject_I64) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableI64(); + items[2147483649L] = 2147483649L; + items[-2147483650L] = -2147483650L; + items[2147483651L] = 2147483651L; + }, + R"_({"I64":{"2147483649":2147483649,"-2147483650":-2147483650,"2147483651":2147483651}})_"); +} // TestComplexMapAsObject_I64 + +Y_UNIT_TEST(TestComplexMapAsObject_UI32) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableUI32(); + items[1073741825U] = 1073741825U; + items[1073741826U] = 1073741826U; + items[1073741827U] = 1073741827U; + }, + R"_({"UI32":{"1073741825":1073741825,"1073741826":1073741826,"1073741827":1073741827}})_"); +} // TestComplexMapAsObject_UI32 + +Y_UNIT_TEST(TestComplexMapAsObject_UI64) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableUI64(); + items[9223372036854775809UL] = 
9223372036854775809UL; + items[9223372036854775810UL] = 9223372036854775810UL; + items[9223372036854775811UL] = 9223372036854775811UL; + }, + R"_({"UI64":{"9223372036854775809":9223372036854775809,"9223372036854775810":9223372036854775810,"9223372036854775811":9223372036854775811}})_"); +} // TestComplexMapAsObject_UI64 + +Y_UNIT_TEST(TestComplexMapAsObject_SI32) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableSI32(); + items[1] = 1; + items[-2] = -2; + items[3] = 3; + }, + R"_({"SI32":{"1":1,"-2":-2,"3":3}})_"); +} // TestComplexMapAsObject_SI32 + +Y_UNIT_TEST(TestComplexMapAsObject_SI64) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableSI64(); + items[2147483649L] = 2147483649L; + items[-2147483650L] = -2147483650L; + items[2147483651L] = 2147483651L; + }, + R"_({"SI64":{"2147483649":2147483649,"-2147483650":-2147483650,"2147483651":2147483651}})_"); +} // TestComplexMapAsObject_SI64 + +Y_UNIT_TEST(TestComplexMapAsObject_FI32) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableFI32(); + items[1073741825U] = 1073741825U; + items[1073741826U] = 1073741826U; + items[1073741827U] = 1073741827U; + }, + R"_({"FI32":{"1073741825":1073741825,"1073741826":1073741826,"1073741827":1073741827}})_"); +} // TestComplexMapAsObject_FI32 + +Y_UNIT_TEST(TestComplexMapAsObject_FI64) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableFI64(); + items[9223372036854775809UL] = 9223372036854775809UL; + items[9223372036854775810UL] = 9223372036854775810UL; + items[9223372036854775811UL] = 9223372036854775811UL; + }, + R"_({"FI64":{"9223372036854775809":9223372036854775809,"9223372036854775810":9223372036854775810,"9223372036854775811":9223372036854775811}})_"); +} // TestComplexMapAsObject_FI64 + +Y_UNIT_TEST(TestComplexMapAsObject_SFI32) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = 
*proto.MutableSFI32(); + items[1] = 1; + items[-2] = -2; + items[3] = 3; + }, + R"_({"SFI32":{"1":1,"-2":-2,"3":3}})_"); +} // TestComplexMapAsObject_SFI32 + +Y_UNIT_TEST(TestComplexMapAsObject_SFI64) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableSFI64(); + items[2147483649L] = 2147483649L; + items[-2147483650L] = -2147483650L; + items[2147483651L] = 2147483651L; + }, + R"_({"SFI64":{"2147483649":2147483649,"-2147483650":-2147483650,"2147483651":2147483651}})_"); +} // TestComplexMapAsObject_SFI64 + +Y_UNIT_TEST(TestComplexMapAsObject_Bool) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableBool(); + items[true] = true; + items[false] = false; + }, + R"_({"Bool":{"true":true,"false":false}})_"); +} // TestComplexMapAsObject_Bool + +Y_UNIT_TEST(TestComplexMapAsObject_String) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableString(); + items["key1"] = "value1"; + items["key2"] = "value2"; + items["key3"] = "value3"; + items[""] = "value4"; + }, + R"_({"String":{"key1":"value1","key2":"value2","key3":"value3","":"value4"}})_"); +} // TestComplexMapAsObject_String + +Y_UNIT_TEST(TestComplexMapAsObject_Enum) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableEnum(); + items["key1"] = EEnum::E_1; + items["key2"] = EEnum::E_2; + items["key3"] = EEnum::E_3; + }, + R"_({"Enum":{"key1":1,"key2":2,"key3":3}})_"); +} // TestComplexMapAsObject_Enum + +Y_UNIT_TEST(TestComplexMapAsObject_EnumString) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableEnum(); + items["key1"] = EEnum::E_1; + items["key2"] = EEnum::E_2; + items["key3"] = EEnum::E_3; + }, + R"_({"Enum":{"key1":"E_1","key2":"E_2","key3":"E_3"}})_"); +} // TestComplexMapAsObject_EnumString + +Y_UNIT_TEST(TestComplexMapAsObject_EnumStringCaseInsensetive) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = 
*proto.MutableEnum(); + items["key1"] = EEnum::E_1; + items["key2"] = EEnum::E_2; + items["key3"] = EEnum::E_3; + }, + R"_({"Enum":{"key1":"e_1","key2":"E_2","key3":"e_3"}})_", + TJson2ProtoConfig() + .SetMapAsObject(true) + .SetEnumValueMode(NProtobufJson::TJson2ProtoConfig::EnumCaseInsensetive) + ); +} // TestComplexMapAsObject_EnumStringCaseInsensetive + +// Enum values are given as strings that differ from E_1/E_2/E_3 in letter case and underscores ("e1", "_E_2_", "e_3"); parsed with EnumSnakeCaseInsensitive. +Y_UNIT_TEST(TestComplexMapAsObject_EnumStringSnakeCaseInsensitive) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableEnum(); + items["key1"] = EEnum::E_1; + items["key2"] = EEnum::E_2; + items["key3"] = EEnum::E_3; + }, + R"_({"Enum":{"key1":"e1","key2":"_E_2_","key3":"e_3"}})_", + TJson2ProtoConfig() + .SetMapAsObject(true) + .SetEnumValueMode(NProtobufJson::TJson2ProtoConfig::EnumSnakeCaseInsensitive) + ); +} // TestComplexMapAsObject_EnumStringSnakeCaseInsensitive + +Y_UNIT_TEST(TestComplexMapAsObject_Float) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableFloat(); + items["key1"] = 0.1f; + items["key2"] = 0.2f; + items["key3"] = 0.3f; + }, + R"_({"Float":{"key1":0.1,"key2":0.2,"key3":0.3}})_"); +} // TestComplexMapAsObject_Float + +Y_UNIT_TEST(TestComplexMapAsObject_Double) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + auto& items = *proto.MutableDouble(); + // NOTE(review): the L suffix makes these long double literals; presumably the map value type is double, so they are narrowed on assignment - confirm. + items["key1"] = 0.1L; + items["key2"] = 0.2L; + items["key3"] = 0.3L; + }, + R"_({"Double":{"key1":0.1,"key2":0.2,"key3":0.3}})_"); +} // TestComplexMapAsObject_Double + +// Map whose values are themselves TComplexMapType messages; the same inner message is stored under all three keys. +Y_UNIT_TEST(TestComplexMapAsObject_Nested) { + TestComplexMapAsObject( + [](TComplexMapType& proto) { + TComplexMapType inner; + auto& innerItems = *inner.MutableString(); + innerItems["key"] = "value"; + auto& items = *proto.MutableNested(); + items["key1"] = inner; + items["key2"] = inner; + items["key3"] = inner; + }, + R"_({"Nested":{"key1":{"String":{"key":"value"}},"key2":{"String":{"key":"value"}},"key3":{"String":{"key":"value"}}}})_"); +} // TestComplexMapAsObject_Nested + 
+// With a default-constructed config (MapAsObject not enabled) a JSON object given for a map field must be rejected with the message asserted below. +Y_UNIT_TEST(TestMapAsObjectConfigNotSet) { + TString modelStr(R"_({"Items":{"key":"value"}})_"); + + TJson2ProtoConfig config; + UNIT_ASSERT_EXCEPTION_CONTAINS( + Json2Proto<TMapType>(modelStr, config), yexception, + "Map as object representation is not allowed"); +} // TestMapAsObjectConfigNotSet + +// Parses the flat JSON into TFlatOptional, merges a four-field patch over it, and expects the result to equal a model filled by FillFlatProto with the same four fields overwritten. +Y_UNIT_TEST(TestMergeFlatOptional) { + const NJson::TJsonValue& json = CreateFlatJson(); + + NJson::TJsonValue patch; + patch["I32"] = 5; + patch["Bool"] = false; + patch["String"] = "abacaba"; + patch["Double"] = 0.123; + + TFlatOptional proto; + UNIT_ASSERT_NO_EXCEPTION(Json2Proto(json, proto)); + UNIT_ASSERT_NO_EXCEPTION(MergeJson2Proto(patch, proto)); + + // The model uses the same message type as the proto under test. + TFlatOptional modelProto; + FillFlatProto(&modelProto); + modelProto.SetI32(5); + modelProto.SetBool(false); + modelProto.SetString("abacaba"); + modelProto.SetDouble(0.123); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} // TestMergeFlatOptional + +// Same scenario as TestMergeFlatOptional, but with TFlatRequired as the message type. +Y_UNIT_TEST(TestMergeFlatRequired) { + const NJson::TJsonValue& json = CreateFlatJson(); + + NJson::TJsonValue patch; + patch["I32"] = 5; + patch["Bool"] = false; + patch["String"] = "abacaba"; + patch["Double"] = 0.123; + + TFlatRequired proto; + UNIT_ASSERT_NO_EXCEPTION(Json2Proto(json, proto)); + UNIT_ASSERT_NO_EXCEPTION(MergeJson2Proto(patch, proto)); + + TFlatRequired modelProto; + FillFlatProto(&modelProto); + modelProto.SetI32(5); + modelProto.SetBool(false); + modelProto.SetString("abacaba"); + modelProto.SetDouble(0.123); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} // TestMergeFlatRequired + +Y_UNIT_TEST(TestMergeComposite) { + const NJson::TJsonValue& json = CreateCompositeJson(); + + NJson::TJsonValue patch; + patch["Part"]["I32"] = 5; + patch["Part"]["Bool"] = false; + patch["Part"]["String"] = "abacaba"; + patch["Part"]["Double"] = 0.123; + + TCompositeOptional proto; + UNIT_ASSERT_NO_EXCEPTION(Json2Proto(json, proto)); + UNIT_ASSERT_NO_EXCEPTION(MergeJson2Proto(patch, proto)); + + TCompositeOptional modelProto; + FillCompositeProto(&modelProto); + 
modelProto.MutablePart()->SetI32(5); + modelProto.MutablePart()->SetBool(false); + modelProto.MutablePart()->SetString("abacaba"); + modelProto.MutablePart()->SetDouble(0.123); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} // TestMergeComposite + +Y_UNIT_TEST(TestMergeRepeatedReplace) { + const NJson::TJsonValue& json = CreateRepeatedFlatJson(); + + NJson::TJsonValue patch; + patch["I32"].AppendValue(5); + patch["I32"].AppendValue(6); + patch["String"].AppendValue("abacaba"); + + TFlatRepeated proto; + TJson2ProtoConfig config; + config.ReplaceRepeatedFields = true; + UNIT_ASSERT_NO_EXCEPTION(Json2Proto(json, proto)); + UNIT_ASSERT_NO_EXCEPTION(MergeJson2Proto(patch, proto, config)); + + TFlatRepeated modelProto; + FillRepeatedProto(&modelProto); + modelProto.ClearI32(); + modelProto.AddI32(5); + modelProto.AddI32(6); + modelProto.ClearString(); + modelProto.AddString("abacaba"); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} // TestMergeRepeatedReplace + +Y_UNIT_TEST(TestMergeRepeatedAppend) { + const NJson::TJsonValue& json = CreateRepeatedFlatJson(); + + NJson::TJsonValue patch; + patch["I32"].AppendValue(5); + patch["I32"].AppendValue(6); + patch["String"].AppendValue("abacaba"); + + TFlatRepeated proto; + UNIT_ASSERT_NO_EXCEPTION(Json2Proto(json, proto)); + UNIT_ASSERT_NO_EXCEPTION(MergeJson2Proto(patch, proto)); + + TFlatRepeated modelProto; + FillRepeatedProto(&modelProto); + modelProto.AddI32(5); + modelProto.AddI32(6); + modelProto.AddString("abacaba"); + + UNIT_ASSERT_PROTOS_EQUAL(proto, modelProto); +} // TestMergeRepeatedAppend + +Y_UNIT_TEST(TestEmptyStringForCastFromString) { + NJson::TJsonValue json; + json["I32"] = ""; + json["Bool"] = ""; + json["OneString"] = ""; + + TJson2ProtoConfig config; + config.SetCastFromString(true); + config.SetDoNotCastEmptyStrings(true); + TFlatOptional proto; + UNIT_ASSERT_NO_EXCEPTION(Json2Proto(json, proto, config)); + UNIT_ASSERT(!proto.HasBool()); + UNIT_ASSERT(!proto.HasI32()); + 
UNIT_ASSERT(proto.HasOneString()); + UNIT_ASSERT_EQUAL("", proto.GetOneString()); +} // TestEmptyStringForCastFromString + +Y_UNIT_TEST(TestAllowComments) { + constexpr TStringBuf json = R"( +{ + "I32": 4, // comment1 +/* + comment2 + {} + qwer +*/ + "I64": 3423 +} + +)"; + + TJson2ProtoConfig config; + TFlatOptional proto; + UNIT_ASSERT_EXCEPTION_CONTAINS(Json2Proto(json, proto, config), yexception, "Error: Missing a name for object member"); + + config.SetAllowComments(true); + UNIT_ASSERT_NO_EXCEPTION(Json2Proto(json, proto, config)); + UNIT_ASSERT_VALUES_EQUAL(proto.GetI32(), 4); + UNIT_ASSERT_VALUES_EQUAL(proto.GetI64(), 3423); +} // TestAllowComments + +} // TJson2ProtoTest diff --git a/library/cpp/protobuf/json/ut/proto.h b/library/cpp/protobuf/json/ut/proto.h new file mode 100644 index 0000000000..8183bfc8e1 --- /dev/null +++ b/library/cpp/protobuf/json/ut/proto.h @@ -0,0 +1,62 @@ +#pragma once + +#include <util/generic/hash_set.h> +#include <util/generic/string.h> + +#include <util/system/defaults.h> + +namespace NProtobufJsonTest { + template <typename TProto> + inline void + FillFlatProto(TProto* proto, + const THashSet<TString>& skippedFields = THashSet<TString>()) { +#define DEFINE_FIELD(name, value) \ + if (skippedFields.find(#name) == skippedFields.end()) \ + proto->Set##name(value); +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD + } + + template <typename TRepeatedField, typename TValue> + inline void + AddValue(TRepeatedField* field, TValue value) { + field->Add(value); + } + + inline void + AddValue(google::protobuf::RepeatedPtrField<TString>* field, const TString& value) { + *(field->Add()) = value; + } + + inline void + FillRepeatedProto(TFlatRepeated* proto, + const THashSet<TString>& skippedFields = THashSet<TString>()) { +#define DEFINE_REPEATED_FIELD(name, type, ...) 
\ + if (skippedFields.find(#name) == skippedFields.end()) { \ + type values[] = {__VA_ARGS__}; \ + for (size_t i = 0, end = Y_ARRAY_SIZE(values); i < end; ++i) { \ + AddValue(proto->Mutable##name(), values[i]); \ + } \ + } +#include <library/cpp/protobuf/json/ut/repeated_fields.incl> +#undef DEFINE_REPEATED_FIELD + } + + template <typename TProto> + inline void + FillCompositeProto(TProto* proto, const THashSet<TString>& skippedFields = THashSet<TString>()) { + FillFlatProto(proto->MutablePart(), skippedFields); + } + +#define UNIT_ASSERT_PROTOS_EQUAL(lhs, rhs) \ + do { \ + if (lhs.SerializeAsString() != rhs.SerializeAsString()) { \ + Cerr << ">>>>>>>>>> lhs != rhs:" << Endl; \ + Cerr << lhs.DebugString() << Endl; \ + Cerr << rhs.DebugString() << Endl; \ + UNIT_ASSERT_STRINGS_EQUAL(lhs.DebugString(), rhs.DebugString()); \ + UNIT_ASSERT_STRINGS_EQUAL(lhs.SerializeAsString(), rhs.SerializeAsString()); \ + } \ + } while (false); + +} diff --git a/library/cpp/protobuf/json/ut/proto2json_ut.cpp b/library/cpp/protobuf/json/ut/proto2json_ut.cpp new file mode 100644 index 0000000000..07e52d7f2f --- /dev/null +++ b/library/cpp/protobuf/json/ut/proto2json_ut.cpp @@ -0,0 +1,1022 @@ +#include "json.h" +#include "proto.h" + +#include <library/cpp/protobuf/json/ut/test.pb.h> + +#include <library/cpp/json/json_value.h> +#include <library/cpp/json/json_reader.h> +#include <library/cpp/json/json_writer.h> + +#include <library/cpp/protobuf/json/proto2json.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/hash_set.h> +#include <util/generic/string.h> +#include <util/generic/ylimits.h> + +#include <util/stream/str.h> + +#include <util/system/defaults.h> +#include <util/system/yassert.h> + +#include <limits> + +using namespace NProtobufJson; +using namespace NProtobufJsonTest; + +Y_UNIT_TEST_SUITE(TProto2JsonFlatTest) { + Y_UNIT_TEST(TestFlatDefault) { + using namespace ::google::protobuf; + TFlatDefault proto; + NJson::TJsonValue json; + 
TProto2JsonConfig cfg; + cfg.SetMissingSingleKeyMode(TProto2JsonConfig::MissingKeyDefault); + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, cfg)); + /* For each field listed in fields.incl: the JSON must contain the key, and its value must equal the default declared in the field descriptor, read through the default_value_*() accessor that matches cpp_type(). */ +#define DEFINE_FIELD(name, value) \ + { \ + auto descr = proto.GetMetadata().descriptor->FindFieldByName(#name); \ + UNIT_ASSERT(descr); \ + UNIT_ASSERT(json.Has(#name)); \ + switch (descr->cpp_type()) { \ + case FieldDescriptor::CPPTYPE_INT32: \ + UNIT_ASSERT(descr->default_value_int32() == json[#name].GetIntegerRobust()); \ + break; \ + case FieldDescriptor::CPPTYPE_INT64: \ + UNIT_ASSERT(descr->default_value_int64() == json[#name].GetIntegerRobust()); \ + break; \ + case FieldDescriptor::CPPTYPE_UINT32: \ + UNIT_ASSERT(descr->default_value_uint32() == json[#name].GetUIntegerRobust()); \ + break; \ + case FieldDescriptor::CPPTYPE_UINT64: \ + UNIT_ASSERT(descr->default_value_uint64() == json[#name].GetUIntegerRobust()); \ + break; \ + case FieldDescriptor::CPPTYPE_DOUBLE: \ + UNIT_ASSERT(descr->default_value_double() == json[#name].GetDoubleRobust()); \ + break; \ + case FieldDescriptor::CPPTYPE_FLOAT: \ + UNIT_ASSERT(descr->default_value_float() == json[#name].GetDoubleRobust()); \ + break; \ + case FieldDescriptor::CPPTYPE_BOOL: \ + UNIT_ASSERT(descr->default_value_bool() == json[#name].GetBooleanRobust()); \ + break; \ + case FieldDescriptor::CPPTYPE_ENUM: \ + UNIT_ASSERT(descr->default_value_enum()->number() == json[#name].GetIntegerRobust()); \ + break; \ + case FieldDescriptor::CPPTYPE_STRING: \ + UNIT_ASSERT(descr->default_value_string() == json[#name].GetStringRobust()); \ + break; \ + default: \ + break; \ + } \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD + } + + /* A custom name generator replaces the JSON key of every field; here it always returns "42". */ + Y_UNIT_TEST(TestNameGenerator) { + TNameGeneratorType proto; + proto.SetField(42); + + TProto2JsonConfig cfg; + cfg.SetNameGenerator([](const NProtoBuf::FieldDescriptor&) { return "42"; }); + + TStringStream str; + Proto2Json(proto, str, cfg); + + UNIT_ASSERT_STRINGS_EQUAL(R"({"42":42})", str.Str()); + } + 
Y_UNIT_TEST(TestEnumValueGenerator) { + TEnumValueGeneratorType proto; + proto.SetEnum(TEnumValueGeneratorType::ENUM_42); + + TProto2JsonConfig cfg; + cfg.SetEnumValueGenerator([](const NProtoBuf::EnumValueDescriptor&) { return "42"; }); + + TStringStream str; + Proto2Json(proto, str, cfg); + + UNIT_ASSERT_STRINGS_EQUAL(R"({"Enum":"42"})", str.Str()); + } + + Y_UNIT_TEST(TestFlatOptional){ + {TFlatOptional proto; + FillFlatProto(&proto); + const NJson::TJsonValue& modelJson = CreateFlatJson(); + { + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + TStringStream jsonStream; + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } // streamed +} + + // Try to skip each field +#define DEFINE_FIELD(name, value) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + TFlatOptional proto; \ + FillFlatProto(&proto, skippedField); \ + const NJson::TJsonValue& modelJson = CreateFlatJson(skippedField); \ + NJson::TJsonValue json; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + { \ + TStringStream jsonStream; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); \ + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + } \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} // TestFlatOptional + +Y_UNIT_TEST(TestFlatRequired){ + {TFlatRequired proto; +FillFlatProto(&proto); +const NJson::TJsonValue& modelJson = CreateFlatJson(); +{ + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); +} + +{ + TStringStream jsonStream; + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); + UNIT_ASSERT(ReadJsonTree(&jsonStream, 
&json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); +} // streamed +} + +// Try to skip each field +#define DEFINE_FIELD(name, value) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + TFlatRequired proto; \ + FillFlatProto(&proto, skippedField); \ + const NJson::TJsonValue& modelJson = CreateFlatJson(skippedField); \ + NJson::TJsonValue json; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + { \ + TStringStream jsonStream; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); \ + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + } \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} // TestFlatRequired + +Y_UNIT_TEST(TestFlatRepeated) { + { + TFlatRepeated proto; + FillRepeatedProto(&proto); + const NJson::TJsonValue& modelJson = CreateRepeatedFlatJson(); + { + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + TStringStream jsonStream; + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } // streamed + } + + TProto2JsonConfig config; + config.SetMissingRepeatedKeyMode(TProto2JsonConfig::MissingKeySkip); + + // Try to skip each field +#define DEFINE_REPEATED_FIELD(name, ...) 
\ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + TFlatRepeated proto; \ + FillRepeatedProto(&proto, skippedField); \ + const NJson::TJsonValue& modelJson = CreateRepeatedFlatJson(skippedField); \ + NJson::TJsonValue json; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + { \ + TStringStream jsonStream; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream, config)); \ + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + } \ + } +#include <library/cpp/protobuf/json/ut/repeated_fields.incl> +#undef DEFINE_REPEATED_FIELD +} // TestFlatRepeated + +Y_UNIT_TEST(TestCompositeOptional){ + {TCompositeOptional proto; +FillCompositeProto(&proto); +const NJson::TJsonValue& modelJson = CreateCompositeJson(); +{ + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); +} + +{ + TStringStream jsonStream; + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); +} // streamed +} + +// Try to skip each field +#define DEFINE_FIELD(name, value) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + TCompositeOptional proto; \ + FillCompositeProto(&proto, skippedField); \ + const NJson::TJsonValue& modelJson = CreateCompositeJson(skippedField); \ + NJson::TJsonValue json; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + { \ + TStringStream jsonStream; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); \ + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + } \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} // TestCompositeOptional + +Y_UNIT_TEST(TestCompositeRequired){ + {TCompositeRequired proto; 
+FillCompositeProto(&proto); +const NJson::TJsonValue& modelJson = CreateCompositeJson(); +{ + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); +} + +{ + TStringStream jsonStream; + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); +} // streamed +} + +// Try to skip each field +#define DEFINE_FIELD(name, value) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + TCompositeRequired proto; \ + FillCompositeProto(&proto, skippedField); \ + const NJson::TJsonValue& modelJson = CreateCompositeJson(skippedField); \ + NJson::TJsonValue json; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + { \ + TStringStream jsonStream; \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); \ + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); \ + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); \ + } \ + } +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD +} // TestCompositeRequired + +Y_UNIT_TEST(TestCompositeRepeated) { + { + TFlatOptional partProto; + FillFlatProto(&partProto); + TCompositeRepeated proto; + proto.AddPart()->CopyFrom(partProto); + + NJson::TJsonValue modelJson; + NJson::TJsonValue modelArray; + modelArray.AppendValue(CreateFlatJson()); + modelJson.InsertValue("Part", modelArray); + { + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + TStringStream jsonStream; + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } // streamed + } + + { + // Array of messages with each field skipped + TCompositeRepeated proto; + NJson::TJsonValue modelArray; + +#define 
DEFINE_REPEATED_FIELD(name, ...) \ + { \ + THashSet<TString> skippedField; \ + skippedField.insert(#name); \ + TFlatOptional partProto; \ + FillFlatProto(&partProto, skippedField); \ + proto.AddPart()->CopyFrom(partProto); \ + modelArray.AppendValue(CreateFlatJson(skippedField)); \ + } +#include <library/cpp/protobuf/json/ut/repeated_fields.incl> +#undef DEFINE_REPEATED_FIELD + + NJson::TJsonValue modelJson; + modelJson.InsertValue("Part", modelArray); + + { + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + TStringStream jsonStream; + NJson::TJsonValue json; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStream)); + UNIT_ASSERT(ReadJsonTree(&jsonStream, &json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } // streamed + } +} // TestCompositeRepeated + +Y_UNIT_TEST(TestEnumConfig) { + { + TFlatOptional proto; + proto.SetEnum(E_1); + NJson::TJsonValue modelJson; + modelJson.InsertValue("Enum", 1); + NJson::TJsonValue json; + TProto2JsonConfig config; + config.EnumMode = TProto2JsonConfig::EnumNumber; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + TFlatOptional proto; + proto.SetEnum(E_1); + NJson::TJsonValue modelJson; + modelJson.InsertValue("Enum", "E_1"); + NJson::TJsonValue json; + TProto2JsonConfig config; + config.EnumMode = TProto2JsonConfig::EnumName; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + TFlatOptional proto; + proto.SetEnum(E_1); + NJson::TJsonValue modelJson; + modelJson.InsertValue("Enum", "NProtobufJsonTest.E_1"); + NJson::TJsonValue json; + TProto2JsonConfig config; + config.EnumMode = TProto2JsonConfig::EnumFullName; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + TFlatOptional proto; + proto.SetEnum(E_1); + NJson::TJsonValue modelJson; + 
modelJson.InsertValue("Enum", "e_1"); + NJson::TJsonValue json; + TProto2JsonConfig config; + config.EnumMode = TProto2JsonConfig::EnumNameLowerCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + TFlatOptional proto; + proto.SetEnum(E_1); + NJson::TJsonValue modelJson; + modelJson.InsertValue("Enum", "nprotobufjsontest.e_1"); + NJson::TJsonValue json; + TProto2JsonConfig config; + config.EnumMode = TProto2JsonConfig::EnumFullNameLowerCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } +} // TestEnumConfig + +Y_UNIT_TEST(TestMissingSingleKeyConfig) { + { + TFlatOptional proto; + NJson::TJsonValue modelJson(NJson::JSON_MAP); + NJson::TJsonValue json; + TProto2JsonConfig config; + config.MissingSingleKeyMode = TProto2JsonConfig::MissingKeySkip; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + NJson::TJsonValue modelJson; +#define DEFINE_FIELD(name, value) \ + modelJson.InsertValue(#name, NJson::TJsonValue(NJson::JSON_NULL)); +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD + + TFlatOptional proto; + NJson::TJsonValue json; + TProto2JsonConfig config; + config.MissingSingleKeyMode = TProto2JsonConfig::MissingKeyNull; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + { + // Test MissingKeyExplicitDefaultThrowRequired for non explicit default values. + TFlatOptional proto; + NJson::TJsonValue modelJson(NJson::JSON_MAP); + NJson::TJsonValue json; + TProto2JsonConfig config; + config.MissingSingleKeyMode = TProto2JsonConfig::MissingKeyExplicitDefaultThrowRequired; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + { + // Test MissingKeyExplicitDefaultThrowRequired for explicit default values. 
+ NJson::TJsonValue modelJson; + modelJson["String"] = "value"; + + TSingleDefaultString proto; + NJson::TJsonValue json; + TProto2JsonConfig config; + config.MissingSingleKeyMode = TProto2JsonConfig::MissingKeyExplicitDefaultThrowRequired; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + { + // Test MissingKeyExplicitDefaultThrowRequired for empty required values. + TFlatRequired proto; + NJson::TJsonValue json; + TProto2JsonConfig config; + config.MissingSingleKeyMode = TProto2JsonConfig::MissingKeyExplicitDefaultThrowRequired; + UNIT_ASSERT_EXCEPTION_CONTAINS(Proto2Json(proto, json, config), yexception, "Empty required protobuf field"); + } + { + // Test MissingKeyExplicitDefaultThrowRequired for required value. + TSingleRequiredString proto; + NJson::TJsonValue json; + TProto2JsonConfig config; + config.MissingSingleKeyMode = TProto2JsonConfig::MissingKeyExplicitDefaultThrowRequired; + + UNIT_ASSERT_EXCEPTION_CONTAINS(Proto2Json(proto, json, config), yexception, "Empty required protobuf field"); + + NJson::TJsonValue modelJson; + modelJson["String"] = "value"; + proto.SetString("value"); + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } +} // TestMissingSingleKeyConfig + +Y_UNIT_TEST(TestMissingRepeatedKeyNoConfig) { + { + TFlatRepeated proto; + NJson::TJsonValue modelJson(NJson::JSON_MAP); + NJson::TJsonValue json; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } +} // TestMissingRepeatedKeyNoConfig + +Y_UNIT_TEST(TestMissingRepeatedKeyConfig) { + { + TFlatRepeated proto; + NJson::TJsonValue modelJson(NJson::JSON_MAP); + NJson::TJsonValue json; + TProto2JsonConfig config; + config.MissingRepeatedKeyMode = TProto2JsonConfig::MissingKeySkip; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + + { + NJson::TJsonValue modelJson; 
+#define DEFINE_FIELD(name, value) \ + modelJson.InsertValue(#name, NJson::TJsonValue(NJson::JSON_NULL)); +#include <library/cpp/protobuf/json/ut/fields.incl> +#undef DEFINE_FIELD + + TFlatRepeated proto; + NJson::TJsonValue json; + TProto2JsonConfig config; + config.MissingRepeatedKeyMode = TProto2JsonConfig::MissingKeyNull; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } + { + TFlatRepeated proto; + NJson::TJsonValue modelJson(NJson::JSON_MAP); + NJson::TJsonValue json; + TProto2JsonConfig config; + config.MissingRepeatedKeyMode = TProto2JsonConfig::MissingKeyExplicitDefaultThrowRequired; + + // Should be the same as MissingKeySkip + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, json, config)); + UNIT_ASSERT_JSONS_EQUAL(json, modelJson); + } +} // TestMissingRepeatedKeyConfig + +Y_UNIT_TEST(TestEscaping) { + // No escape + { + TString modelStr(R"_({"String":"value\""})_"); + + TFlatOptional proto; + proto.SetString(R"_(value")_"); + TStringStream jsonStr; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // TEscapeJTransform + { + TString modelStr(R"_({"String":"value\""})_"); + + TFlatOptional proto; + proto.SetString(R"_(value")_"); + TProto2JsonConfig config; + config.StringTransforms.push_back(new TEscapeJTransform<false, true>()); + TStringStream jsonStr; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(modelStr, jsonStr.Str()); + } + + // TCEscapeTransform + { + TString modelStr(R"_({"String":"value\""})_"); + + TFlatOptional proto; + proto.SetString(R"_(value")_"); + TProto2JsonConfig config; + config.StringTransforms.push_back(new TCEscapeTransform()); + TStringStream jsonStr; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // TSafeUtf8CEscapeTransform + { + TString
modelStr(R"_({"String":"value\""})_"); + + TFlatOptional proto; + proto.SetString(R"_(value")_"); + TProto2JsonConfig config; + config.StringTransforms.push_back(new TSafeUtf8CEscapeTransform()); + TStringStream jsonStr; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } +} // TestEscaping + +class TBytesTransform: public IStringTransform { +public: + int GetType() const override { + return 0; + } + void Transform(TString&) const override { + } + void TransformBytes(TString& str) const override { + str = "bytes"; + } +}; + +Y_UNIT_TEST(TestBytesTransform) { + // Test that string field is not changed + { + TString modelStr(R"_({"String":"value"})_"); + + TFlatOptional proto; + proto.SetString(R"_(value)_"); + TProto2JsonConfig config; + config.StringTransforms.push_back(new TBytesTransform()); + TStringStream jsonStr; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // Test that bytes field is changed + { + TString modelStr(R"_({"Bytes":"bytes"})_"); + + TFlatOptional proto; + proto.SetBytes(R"_(value)_"); + TProto2JsonConfig config; + config.StringTransforms.push_back(new TBytesTransform()); + TStringStream jsonStr; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } +} + +Y_UNIT_TEST(TestFieldNameMode) { + // Original case 1 + { + TString modelStr(R"_({"String":"value"})_"); + + TFlatOptional proto; + proto.SetString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // Original case 2 + { + TString modelStr(R"_({"String":"value"})_"); + + TFlatOptional proto; + proto.SetString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = 
TProto2JsonConfig::FieldNameOriginalCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // Lowercase + { + TString modelStr(R"_({"string":"value"})_"); + + TFlatOptional proto; + proto.SetString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameLowerCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // Uppercase + { + TString modelStr(R"_({"STRING":"value"})_"); + + TFlatOptional proto; + proto.SetString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameUpperCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // Camelcase + { + TString modelStr(R"_({"string":"value"})_"); + + TFlatOptional proto; + proto.SetString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameCamelCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + { + TString modelStr(R"_({"oneString":"value"})_"); + + TFlatOptional proto; + proto.SetOneString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameCamelCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + { + TString modelStr(R"_({"oneTwoString":"value"})_"); + + TFlatOptional proto; + proto.SetOneTwoString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameCamelCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // 
snake_case + { + TString modelStr(R"_({"string":"value"})_"); + + TFlatOptional proto; + proto.SetString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameSnakeCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + { + TString modelStr(R"_({"one_string":"value"})_"); + + TFlatOptional proto; + proto.SetOneString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameSnakeCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + { + TString modelStr(R"_({"one_two_string":"value"})_"); + + TFlatOptional proto; + proto.SetOneTwoString("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameSnakeCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + { + TString modelStr(R"_({"a_b_c":"value","user_i_d":"value"})_"); + + TFlatOptional proto; + proto.SetABC("value"); + proto.SetUserID("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameSnakeCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // snake_case_dense + { + TString modelStr(R"_({"abc":"value","user_id":"value"})_"); + + TFlatOptional proto; + proto.SetABC("value"); + proto.SetUserID("value"); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameSnakeCaseDense; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // Original case, repeated + { + TString modelStr(R"_({"I32":[1,2]})_"); + + TFlatRepeated proto; + 
proto.AddI32(1); + proto.AddI32(2); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameOriginalCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // Lower case, repeated + { + TString modelStr(R"_({"i32":[1,2]})_"); + + TFlatRepeated proto; + proto.AddI32(1); + proto.AddI32(2); + TStringStream jsonStr; + TProto2JsonConfig config; + config.FieldNameMode = TProto2JsonConfig::FieldNameLowerCase; + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // UseJsonName + { + // FIXME(CONTRIB-139): see the comment about UseJsonName in json2proto_ut.cpp: + // Def_upper json name should be "DefUpper". + TString modelStr(R"_({"My-Upper":1,"my-lower":2,"defUpper":3,"defLower":4})_"); + + TWithJsonName proto; + proto.Setmy_upper(1); + proto.SetMy_lower(2); + proto.SetDef_upper(3); + proto.Setdef_lower(4); + TStringStream jsonStr; + TProto2JsonConfig config; + config.SetUseJsonName(true); + + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + UNIT_ASSERT_STRINGS_EQUAL(jsonStr.Str(), modelStr); + } + + // FieldNameMode with UseJsonName + { + TProto2JsonConfig config; + config.SetFieldNameMode(TProto2JsonConfig::FieldNameLowerCase); + UNIT_ASSERT_EXCEPTION_CONTAINS( + config.SetUseJsonName(true), yexception, "mutually exclusive"); + } + { + TProto2JsonConfig config; + config.SetUseJsonName(true); + UNIT_ASSERT_EXCEPTION_CONTAINS( + config.SetFieldNameMode(TProto2JsonConfig::FieldNameLowerCase), yexception, "mutually exclusive"); + } + + /// TODO: test missing keys +} // TestFieldNameMode + +Y_UNIT_TEST(TestNan) { + TFlatOptional proto; + proto.SetDouble(std::numeric_limits<double>::quiet_NaN()); + + UNIT_ASSERT_EXCEPTION(Proto2Json(proto, TProto2JsonConfig()), yexception); +} // TestNan + +Y_UNIT_TEST(TestInf) { + TFlatOptional proto; + 
proto.SetFloat(std::numeric_limits<float>::infinity()); + + UNIT_ASSERT_EXCEPTION(Proto2Json(proto, TProto2JsonConfig()), yexception); +} // TestInf + +Y_UNIT_TEST(TestMap) { + TMapType proto; + + auto& items = *proto.MutableItems(); + items["key1"] = "value1"; + items["key2"] = "value2"; + items["key3"] = "value3"; + + TString modelStr(R"_({"Items":[{"key":"key3","value":"value3"},{"key":"key2","value":"value2"},{"key":"key1","value":"value1"}]})_"); + + TStringStream jsonStr; + TProto2JsonConfig config; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + + NJson::TJsonValue jsonValue, modelValue; + NJson::TJsonValue::TArray jsonItems, modelItems; + UNIT_ASSERT(NJson::ReadJsonTree(jsonStr.Str(), &jsonValue)); + UNIT_ASSERT(NJson::ReadJsonTree(modelStr, &modelValue)); + UNIT_ASSERT(jsonValue.Has("Items")); + jsonValue["Items"].GetArray(&jsonItems); + modelValue["Items"].GetArray(&modelItems); + auto itemKey = [](const NJson::TJsonValue& v) { + return v["key"].GetString(); + }; + SortBy(jsonItems, itemKey); + SortBy(modelItems, itemKey); + UNIT_ASSERT_EQUAL(jsonItems, modelItems); +} // TestMap + +Y_UNIT_TEST(TestMapAsObject) { + TMapType proto; + + auto& items = *proto.MutableItems(); + items["key1"] = "value1"; + items["key2"] = "value2"; + items["key3"] = "value3"; + + TString modelStr(R"_({"Items":{"key3":"value3","key2":"value2","key1":"value1"}})_"); + + TStringStream jsonStr; + TProto2JsonConfig config; + config.MapAsObject = true; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); + + UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); +} // TestMapAsObject + +Y_UNIT_TEST(TestMapWTF) { + TMapType proto; + + auto& items = *proto.MutableItems(); + items["key1"] = "value1"; + items["key2"] = "value2"; + items["key3"] = "value3"; + + TString modelStr(R"_({"Items":{"key3":"value3","key2":"value2","key1":"value1"}})_"); + + TStringStream jsonStr; + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr)); + + 
UNIT_ASSERT_JSON_STRINGS_EQUAL(jsonStr.Str(), modelStr); +} // TestMapWTF + +Y_UNIT_TEST(TestStringifyLongNumbers) { +#define TEST_SINGLE(flag, value, expectString) \ + do { \ + TFlatOptional proto; \ + proto.SetSI64(value); \ + \ + TStringStream jsonStr; \ + TProto2JsonConfig config; \ + config.SetStringifyLongNumbers(flag); \ + UNIT_ASSERT_NO_EXCEPTION(Proto2Json(proto, jsonStr, config)); \ + if (expectString) { \ + UNIT_ASSERT_EQUAL(jsonStr.Str(), "{\"SI64\":\"" #value "\"}"); \ + } else { \ + UNIT_ASSERT_EQUAL(jsonStr.Str(), "{\"SI64\":" #value "}"); \ + } \ + } while (false) + + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersNever, 1, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersNever, 1000000000, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersNever, 10000000000000000, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersNever, -1, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersNever, -1000000000, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersNever, -10000000000000000, false); + + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForDouble, 1, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForDouble, 1000000000, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForDouble, 10000000000000000, true); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForDouble, -1, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForDouble, -1000000000, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForDouble, -10000000000000000, true); + + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForFloat, 1, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForFloat, 1000000000, true); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForFloat, 10000000000000000, true); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForFloat, -1, false); + TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForFloat, -1000000000, true); + 
TEST_SINGLE(TProto2JsonConfig::StringifyLongNumbersForFloat, -10000000000000000, true); + +#undef TEST_SINGLE +} // TestStringifyLongNumbers + +Y_UNIT_TEST(TestExtension) { // extension fields are emitted under their full name by default, and under the short name with ExtFldNameShort + TExtensionField proto; + proto.SetExtension(bar, 1); + + UNIT_ASSERT(proto.HasExtension(bar)); // UNIT_ASSERT rather than the debug-only Y_ASSERT: the check must also run in release builds and be reported as a test failure + UNIT_ASSERT_EQUAL(Proto2Json(proto, TProto2JsonConfig()), "{\"NProtobufJsonTest.bar\":1}"); + + + TProto2JsonConfig cfg; + cfg.SetExtensionFieldNameMode(TProto2JsonConfig::ExtFldNameShort); + UNIT_ASSERT_EQUAL(Proto2Json(proto, cfg), "{\"bar\":1}"); +} // TestExtension + +} // TProto2JsonTest diff --git a/library/cpp/protobuf/json/ut/repeated_fields.incl b/library/cpp/protobuf/json/ut/repeated_fields.incl new file mode 100644 index 0000000000..e9548917d8 --- /dev/null +++ b/library/cpp/protobuf/json/ut/repeated_fields.incl @@ -0,0 +1,21 @@ +// Intentionally no #pragma once + +// (Field name == JSON key, Type, Values...) +DEFINE_REPEATED_FIELD(I32, i32, Min<i32>(), -1, 0, 1, Max<i32>()) +DEFINE_REPEATED_FIELD(I64, i64, Min<i64>(), -1ll, 0ll, 1ll, Max<i64>()) +DEFINE_REPEATED_FIELD(UI32, ui32, 0ul, 1ul, Max<ui32>()) +DEFINE_REPEATED_FIELD(UI64, ui64, 0ull, 1ull, Max<ui64>()) +DEFINE_REPEATED_FIELD(SI32, i32, Min<i32>(), -1, 0, 1, Max<i32>()) +DEFINE_REPEATED_FIELD(SI64, i64, Min<i64>(), -1ll, 0ll, 1ll, Max<i64>()) +DEFINE_REPEATED_FIELD(FI32, ui32, 0, 1, Max<ui32>()) +DEFINE_REPEATED_FIELD(FI64, ui64, 0ull, 1ull, Max<ui64>()) +DEFINE_REPEATED_FIELD(SFI32, i32, Min<i32>(), -1, 0, 1, Max<i32>()) +DEFINE_REPEATED_FIELD(SFI64, i64, Min<i64>(), -1ll, 0ll, 1ll, Max<i64>()) +DEFINE_REPEATED_FIELD(Bool, bool, false, true) +DEFINE_REPEATED_FIELD(String, TString, "", "Lorem ipsum", "123123") +DEFINE_REPEATED_FIELD(Bytes, TString, "", "מחשב", "\x1") +DEFINE_REPEATED_FIELD(Enum, EEnum, E_1, E_2, E_3) +DEFINE_REPEATED_FIELD(Float, float, 0.0f, 1.0f, 1.123f) +DEFINE_REPEATED_FIELD(Double, double, 0.0, 1.0, 1.123456789012) +DEFINE_REPEATED_FIELD(OneString, TString, "", "Lorem ipsum dolor", "1231231")
+DEFINE_REPEATED_FIELD(OneTwoString, TString, "", "Lorem ipsum dolor sit", "12312312") diff --git a/library/cpp/protobuf/json/ut/string_transform_ut.cpp b/library/cpp/protobuf/json/ut/string_transform_ut.cpp new file mode 100644 index 0000000000..a31dabcb0f --- /dev/null +++ b/library/cpp/protobuf/json/ut/string_transform_ut.cpp @@ -0,0 +1,106 @@ +#include "json.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <library/cpp/protobuf/json/proto2json.h> + +Y_UNIT_TEST_SUITE(TDoubleEscapeTransform) { + Y_UNIT_TEST(TestEmptyString) { + const NProtobufJson::IStringTransform& transform = NProtobufJson::TDoubleEscapeTransform(); + TString s; + s = ""; + transform.Transform(s); + UNIT_ASSERT_EQUAL(s, ""); + } + + Y_UNIT_TEST(TestAlphabeticString) { + const NProtobufJson::IStringTransform& transform = NProtobufJson::TDoubleEscapeTransform(); + TString s; + s = "abacaba"; + transform.Transform(s); + UNIT_ASSERT_EQUAL(s, "abacaba"); + } + + Y_UNIT_TEST(TestRussianSymbols) { + const NProtobufJson::IStringTransform& transform = NProtobufJson::TDoubleEscapeTransform(); + TString s; + s = "тест"; + transform.Transform(s); + UNIT_ASSERT_EQUAL(s, "\\\\321\\\\202\\\\320\\\\265\\\\321\\\\201\\\\321\\\\202"); + } + + Y_UNIT_TEST(TestEscapeSpecialSymbols) { + const NProtobufJson::IStringTransform& transform = NProtobufJson::TDoubleEscapeTransform(); + TString s; + s = "aba\\ca\"ba"; + transform.Transform(s); + Cerr << "###" << s << Endl; + UNIT_ASSERT_EQUAL(s, "aba\\\\\\\\ca\\\\\\\"ba"); + } +} + +Y_UNIT_TEST_SUITE(TDoubleUnescapeTransform) { + Y_UNIT_TEST(TestEmptyString) { + const NProtobufJson::IStringTransform& transform = NProtobufJson::TDoubleUnescapeTransform(); + TString s; + s = ""; + transform.Transform(s); + UNIT_ASSERT_EQUAL("", s); + } + + Y_UNIT_TEST(TestAlphabeticString) { + const NProtobufJson::IStringTransform& transform = NProtobufJson::TDoubleUnescapeTransform(); + TString s; + s = "abacaba"; + transform.Transform(s); + Cerr << "###" << s << Endl; + 
UNIT_ASSERT_EQUAL("abacaba", s); + } + + Y_UNIT_TEST(TestRussianSymbols) { + const NProtobufJson::IStringTransform& transform = NProtobufJson::TDoubleUnescapeTransform(); + TString s; + s = "\\\\321\\\\202\\\\320\\\\265\\\\321\\\\201\\\\321\\\\202"; + transform.Transform(s); + UNIT_ASSERT_EQUAL("тест", s); + } + + Y_UNIT_TEST(TestEscapeSpecialSymbols) { + const NProtobufJson::IStringTransform& transform = NProtobufJson::TDoubleUnescapeTransform(); + TString s; + s = "aba\\\\\\\\ca\\\\\\\"ba"; + transform.Transform(s); + UNIT_ASSERT_EQUAL("aba\\ca\"ba", s); + } + + Y_UNIT_TEST(TestEscapeSpecialSymbolsDifficultCases) { + const NProtobufJson::IStringTransform& transform = NProtobufJson::TDoubleUnescapeTransform(); + TString s; + s = "\\\\\\\\\\\\\\\\"; + transform.Transform(s); + UNIT_ASSERT_EQUAL("\\\\", s); + + s = "\\\\\\\\\\\\\\\""; + transform.Transform(s); + UNIT_ASSERT_EQUAL("\\\"", s); + + s = "\\\\\\\"\\\\\\\\"; + transform.Transform(s); + UNIT_ASSERT_EQUAL("\"\\", s); + + s = "\\\\\\\"\\\\\\\""; + transform.Transform(s); + UNIT_ASSERT_EQUAL("\"\"", s); + + s = "\\\\\\\\\\\\\\\\\\\\\\\\"; + transform.Transform(s); + UNIT_ASSERT_EQUAL("\\\\\\", s); + + s = "\\\\\\\\\\\\\\\\\\\\\\\\abacaba\\\\"; + transform.Transform(s); + UNIT_ASSERT_EQUAL("\\\\\\abacaba", s); + + s = "\\\\\\\\\\\\\\\\\\\\\\\\abacaba\\\""; + transform.Transform(s); + UNIT_ASSERT_EQUAL("\\\\\\abacaba\"", s); + } +} diff --git a/library/cpp/protobuf/json/ut/test.proto b/library/cpp/protobuf/json/ut/test.proto new file mode 100644 index 0000000000..0fa996fd41 --- /dev/null +++ b/library/cpp/protobuf/json/ut/test.proto @@ -0,0 +1,203 @@ +package NProtobufJsonTest; + +enum EEnum { + E_0 = 0; + E_1 = 1; + E_2 = 2; + E_3 = 3; +}; + +message TFlatOptional { + optional int32 I32 = 1; + optional int64 I64 = 2; + optional uint32 UI32 = 3; + optional uint64 UI64 = 4; + optional sint32 SI32 = 5; + optional sint64 SI64 = 6; + optional fixed32 FI32 = 7; + optional fixed64 FI64 = 8; + optional sfixed32 SFI32 
= 9; + optional sfixed64 SFI64 = 10; + + optional bool Bool = 11; + + optional string String = 12; + optional bytes Bytes = 13; + + optional EEnum Enum = 14; + + optional float Float = 15; + optional double Double = 16; + + optional string OneString = 17; + optional string OneTwoString = 18; + optional string ABC = 19; + optional string UserID = 20; +}; + +message TFlatRequired { + required int32 I32 = 1; + required int64 I64 = 2; + required uint32 UI32 = 3; + required uint64 UI64 = 4; + required sint32 SI32 = 5; + required sint64 SI64 = 6; + required fixed32 FI32 = 7; + required fixed64 FI64 = 8; + required sfixed32 SFI32 = 9; + required sfixed64 SFI64 = 10; + + required bool Bool = 11; + + required string String = 12; + required bytes Bytes = 13; + + required EEnum Enum = 14; + + required float Float = 15; + required double Double = 16; + + required string OneString = 17; + required string OneTwoString = 18; + required string ABC = 19; + required string UserID = 20; +}; + +message TFlatRepeated { + repeated int32 I32 = 1; + repeated int64 I64 = 2; + repeated uint32 UI32 = 3; + repeated uint64 UI64 = 4; + repeated sint32 SI32 = 5; + repeated sint64 SI64 = 6; + repeated fixed32 FI32 = 7; + repeated fixed64 FI64 = 8; + repeated sfixed32 SFI32 = 9; + repeated sfixed64 SFI64 = 10; + + repeated bool Bool = 11; + + repeated string String = 12; + repeated bytes Bytes = 13; + + repeated EEnum Enum = 14; + + repeated float Float = 15; + repeated double Double = 16; + + repeated string OneString = 17; + repeated string OneTwoString = 18; + repeated string ABC = 19; + repeated string UserID = 20; +}; + +message TFlatDefault { + optional int32 I32 = 1 [default = 132]; + optional int64 I64 = 2 [default = 164]; + optional uint32 UI32 = 3 [default = 232]; + optional uint64 UI64 = 4 [default = 264]; + optional sint32 SI32 = 5 [default = 332]; + optional sint64 SI64 = 6 [default = 364]; + optional fixed32 FI32 = 7 [default = 432]; + optional fixed64 FI64 = 8 [default = 464]; + 
optional sfixed32 SFI32 = 9 [default = 532]; + optional sfixed64 SFI64 = 10 [default = 564]; + + optional bool Bool = 11 [default = true]; + + optional string String = 12 [default = "string"]; + optional bytes Bytes = 13 [default = "bytes"]; + + optional EEnum Enum = 14 [default = E_2]; + + optional float Float = 15 [default = 0.123]; + optional double Double = 16 [default = 0.456]; + + optional string OneString = 17 [default = "string"]; + optional string OneTwoString = 18 [default = "string"]; + optional string ABC = 19 [default = "abc"]; + optional string UserID = 20 [default = "some_id"]; +}; + +message TCompositeOptional { + optional TFlatOptional Part = 1; +}; + +message TCompositeRequired { + required TFlatRequired Part = 1; +}; + +message TCompositeRepeated { + repeated TFlatOptional Part = 1; +}; + +message TMapType { + map<string, string> Items = 1; +}; + +message TNameGeneratorType { + optional int32 Field = 1; +}; + +message TEnumValueGeneratorType { + enum EEnum { + ENUM_42 = 1; + }; + + optional EEnum Enum = 1; +}; + +message TComplexMapType { + map<int32, int32> I32 = 1; + map<int64, int64> I64 = 2; + map<uint32, uint32> UI32 = 3; + map<uint64, uint64> UI64 = 4; + map<sint32, sint32> SI32 = 5; + map<sint64, sint64> SI64 = 6; + map<fixed32, fixed32> FI32 = 7; + map<fixed64, fixed64> FI64 = 8; + map<sfixed32, sfixed32> SFI32 = 9; + map<sfixed64, sfixed64> SFI64 = 10; + + map<bool, bool> Bool = 11; + + map<string, string> String = 12; + + map<string, EEnum> Enum = 13; + + map<string, float> Float = 14; + map<string, double> Double = 15; + + map<string, TComplexMapType> Nested = 16; +}; + +message TWithJsonName { + optional int32 my_upper = 1 [json_name = "My-Upper"]; + optional int32 My_lower = 2 [json_name = "my-lower"]; + optional int32 Def_upper = 3; // json_name = "DefUpper" + optional int32 def_lower = 4; // json_name = "defLower" +} + +message TSingleRequiredString { + required string String = 1; +} + +message TSingleDefaultString { + optional 
string String = 1 [default = "value"]; +} + +message TSingleRepeatedString { + repeated string RepeatedString = 1; +} + +message TSingleRepeatedInt { + repeated int32 RepeatedInt = 1; +} + +message TExtensionField { + extensions 100 to 199; +} + +extend TExtensionField { + optional int32 bar = 123; +}
\ No newline at end of file diff --git a/library/cpp/protobuf/json/ut/util_ut.cpp b/library/cpp/protobuf/json/ut/util_ut.cpp new file mode 100644 index 0000000000..05101dca28 --- /dev/null +++ b/library/cpp/protobuf/json/ut/util_ut.cpp @@ -0,0 +1,42 @@ +#include <library/cpp/protobuf/json/util.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NProtobufJson; + +Y_UNIT_TEST_SUITE(TEqualsTest) { + Y_UNIT_TEST(TestEmpty) { + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("", "")); + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("", "_")); + UNIT_ASSERT(!EqualsIgnoringCaseAndUnderscores("f", "")); + } + + Y_UNIT_TEST(TestTrivial) { + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("f", "f")); + UNIT_ASSERT(!EqualsIgnoringCaseAndUnderscores("f", "o")); + UNIT_ASSERT(!EqualsIgnoringCaseAndUnderscores("fo", "f")); + UNIT_ASSERT(!EqualsIgnoringCaseAndUnderscores("f", "fo")); + UNIT_ASSERT(!EqualsIgnoringCaseAndUnderscores("bar", "baz")); + } + + Y_UNIT_TEST(TestUnderscores) { + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("foo_bar", "foobar")); + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("foo_bar_", "foobar")); + UNIT_ASSERT(!EqualsIgnoringCaseAndUnderscores("foo_bar_z", "foobar")); + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("foo__bar__", "foobar")); + UNIT_ASSERT(!EqualsIgnoringCaseAndUnderscores("foo__bar__z", "foobar")); + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("_foo_bar", "foobar")); + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("_foo_bar_", "foobar")); + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("_foo_bar_", "foo___bar")); + } + + Y_UNIT_TEST(TestCase) { + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("foo_bar", "FOO_BAR")); + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("foobar", "fooBar")); + } + + Y_UNIT_TEST(TestCaseAndUnderscores) { + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("fooBar", "FOO_BAR")); + UNIT_ASSERT(EqualsIgnoringCaseAndUnderscores("FOO_BAR_BAZ", "fooBar_BAZ")); + } +} diff --git 
a/library/cpp/protobuf/json/ut/ya.make b/library/cpp/protobuf/json/ut/ya.make new file mode 100644 index 0000000000..b60a6d3c17 --- /dev/null +++ b/library/cpp/protobuf/json/ut/ya.make @@ -0,0 +1,23 @@ +UNITTEST_FOR(library/cpp/protobuf/json) + +OWNER(avitella) + +SRCS( + filter_ut.cpp + json2proto_ut.cpp + proto2json_ut.cpp + inline_ut.proto + inline_ut.cpp + string_transform_ut.cpp + filter_ut.proto + test.proto + util_ut.cpp +) + +GENERATE_ENUM_SERIALIZATION(test.pb.h) + +PEERDIR( + library/cpp/protobuf/json +) + +END() diff --git a/library/cpp/protobuf/json/util.cpp b/library/cpp/protobuf/json/util.cpp new file mode 100644 index 0000000000..53a065eee2 --- /dev/null +++ b/library/cpp/protobuf/json/util.cpp @@ -0,0 +1,76 @@ +#include "util.h" + +#include <util/string/ascii.h> + +namespace { + void ToSnakeCaseImpl(TString* const name, std::function<bool(const char)> requiresUnderscore) { + bool requiresChanges = false; + size_t size = name->size(); + for (size_t i = 0; i < name->size(); i++) { + if (IsAsciiUpper(name->at(i))) { + requiresChanges = true; + if (i > 0 && requiresUnderscore(name->at(i - 1))) { + size++; + } + } + } + + if (!requiresChanges) { + return; + } + + if (size != name->size()) { + TString result; + result.reserve(size); + for (size_t i = 0; i < name->size(); i++) { + const char c = name->at(i); + if (IsAsciiUpper(c)) { + if (i > 0 && requiresUnderscore(name->at(i - 1))) { + result += '_'; + } + result += AsciiToLower(c); + } else { + result += c; + } + } + *name = std::move(result); + } else { + name->to_lower(); + } + } +} + +namespace NProtobufJson { + void ToSnakeCase(TString* const name) { + ToSnakeCaseImpl(name, [](const char prev) { return prev != '_'; }); + } + + void ToSnakeCaseDense(TString* const name) { + ToSnakeCaseImpl(name, [](const char prev) { return prev != '_' && !IsAsciiUpper(prev); }); + } + + bool EqualsIgnoringCaseAndUnderscores(TStringBuf s1, TStringBuf s2) { + size_t i1 = 0, i2 = 0; + + while (i1 < s1.size() && i2 < 
s2.size()) { + if (s1[i1] == '_') { + ++i1; + } else if (s2[i2] == '_') { + ++i2; + } else if (AsciiToUpper(s1[i1]) != AsciiToUpper(s2[i2])) { + return false; + } else { + ++i1, ++i2; + } + } + + while (i1 < s1.size() && s1[i1] == '_') { + ++i1; + } + while (i2 < s2.size() && s2[i2] == '_') { + ++i2; + } + + return (i1 == s1.size() && i2 == s2.size()); + } +} diff --git a/library/cpp/protobuf/json/util.h b/library/cpp/protobuf/json/util.h new file mode 100644 index 0000000000..d93342d3f8 --- /dev/null +++ b/library/cpp/protobuf/json/util.h @@ -0,0 +1,14 @@ +#pragma once + +#include <util/generic/string.h> + +namespace NProtobufJson { + void ToSnakeCase(TString* const name); + + void ToSnakeCaseDense(TString* const name); + + /** + * "FOO_BAR" ~ "foo_bar" ~ "fooBar" + */ + bool EqualsIgnoringCaseAndUnderscores(TStringBuf s1, TStringBuf s2); +} diff --git a/library/cpp/protobuf/json/ya.make b/library/cpp/protobuf/json/ya.make new file mode 100644 index 0000000000..2f2c75cfdb --- /dev/null +++ b/library/cpp/protobuf/json/ya.make @@ -0,0 +1,25 @@ +LIBRARY() + +OWNER(avitella) + +SRCS( + json2proto.cpp + json_output_create.cpp + json_value_output.cpp + json_writer_output.cpp + name_generator.cpp + proto2json.cpp + proto2json_printer.cpp + string_transform.cpp + util.h + util.cpp +) + +PEERDIR( + contrib/libs/protobuf + library/cpp/json + library/cpp/protobuf/util + library/cpp/string_utils/relaxed_escaper +) + +END() diff --git a/library/cpp/protobuf/util/cast.h b/library/cpp/protobuf/util/cast.h new file mode 100644 index 0000000000..83749dfcee --- /dev/null +++ b/library/cpp/protobuf/util/cast.h @@ -0,0 +1,156 @@ +#pragma once + +#include "traits.h" + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/message.h> + +#include <util/generic/cast.h> + +namespace NProtoBuf { + // C++ compatible conversions of FieldDescriptor::CppType's + + using ECppType = FieldDescriptor::CppType; + + namespace NCast { + template <ECppType src, ECppType dst> + struct 
TIsCompatibleCppType { + enum { + Result = src == dst || + (TIsNumericCppType<src>::Result && TIsNumericCppType<dst>::Result) + }; + }; + + template <ECppType src, ECppType dst> + struct TIsEnumToNumericCppType { + enum { + Result = (src == FieldDescriptor::CPPTYPE_ENUM && TIsNumericCppType<dst>::Result) + }; + }; + + template <ECppType src, ECppType dst, bool compatible> // compatible == true + struct TCompatCastBase { + static const bool IsCompatible = true; + + typedef typename TCppTypeTraits<src>::T TSrc; + typedef typename TCppTypeTraits<dst>::T TDst; + + static inline TDst Cast(TSrc value) { + return value; + } + }; + + template <ECppType src, ECppType dst> // compatible == false + struct TCompatCastBase<src, dst, false> { + static const bool IsCompatible = false; + + typedef typename TCppTypeTraits<src>::T TSrc; + typedef typename TCppTypeTraits<dst>::T TDst; + + static inline TDst Cast(TSrc) { + ythrow TBadCastException() << "Incompatible FieldDescriptor::CppType conversion: #" + << (size_t)src << " to #" << (size_t)dst; + } + }; + + template <ECppType src, ECppType dst, bool isEnumToNum> // enum -> numeric + struct TCompatCastImpl { + static const bool IsCompatible = true; + + typedef typename TCppTypeTraits<dst>::T TDst; + + static inline TDst Cast(const EnumValueDescriptor* value) { + Y_ASSERT(value != nullptr); + return value->number(); + } + }; + + template <ECppType src, ECppType dst> + struct TCompatCastImpl<src, dst, false>: public TCompatCastBase<src, dst, TIsCompatibleCppType<src, dst>::Result> { + using TCompatCastBase<src, dst, TIsCompatibleCppType<src, dst>::Result>::IsCompatible; + }; + + template <ECppType src, ECppType dst> + struct TCompatCast: public TCompatCastImpl<src, dst, TIsEnumToNumericCppType<src, dst>::Result> { + typedef TCompatCastImpl<src, dst, TIsEnumToNumericCppType<src, dst>::Result> TBase; + + typedef typename TCppTypeTraits<src>::T TSrc; + typedef typename TCppTypeTraits<dst>::T TDst; + + using TBase::Cast; + using 
TBase::IsCompatible; + + inline bool Try(TSrc value, TDst& res) { + if (IsCompatible) { + res = Cast(value); + return true; + } + return false; + } + }; + + } + + template <ECppType src, ECppType dst> + inline typename TCppTypeTraits<dst>::T CompatCast(typename TCppTypeTraits<src>::T value) { + return NCast::TCompatCast<src, dst>::Cast(value); + } + + template <ECppType src, ECppType dst> + inline bool TryCompatCast(typename TCppTypeTraits<src>::T value, typename TCppTypeTraits<dst>::T& res) { + return NCast::TCompatCast<src, dst>::Try(value, res); + } + + // Message static/dynamic checked casts + + template <typename TpMessage> + inline const TpMessage* TryCast(const Message* msg) { + if (!msg || TpMessage::descriptor() != msg->GetDescriptor()) + return NULL; + return CheckedCast<const TpMessage*>(msg); + } + + template <typename TpMessage> + inline const TpMessage* TryCast(const Message* msg, const TpMessage*& ret) { + ret = TryCast<TpMessage>(msg); + return ret; + } + + template <typename TpMessage> + inline TpMessage* TryCast(Message* msg) { + if (!msg || TpMessage::descriptor() != msg->GetDescriptor()) + return nullptr; + return CheckedCast<TpMessage*>(msg); + } + + template <typename TpMessage> + inline TpMessage* TryCast(Message* msg, TpMessage*& ret) { + ret = TryCast<TpMessage>(msg); + return ret; + } + + // specialize for Message itself + + template <> + inline const Message* TryCast<Message>(const Message* msg) { + return msg; + } + + template <> + inline Message* TryCast<Message>(Message* msg) { + return msg; + } + + // Binary serialization compatible conversion + inline bool TryBinaryCast(const Message* from, Message* to, TString* buffer = nullptr) { + TString tmpbuf; + if (!buffer) + buffer = &tmpbuf; + + if (!from->SerializeToString(buffer)) + return false; + + return to->ParseFromString(*buffer); + } + +} diff --git a/library/cpp/protobuf/util/is_equal.cpp b/library/cpp/protobuf/util/is_equal.cpp new file mode 100644 index 0000000000..227408006e --- 
/dev/null +++ b/library/cpp/protobuf/util/is_equal.cpp @@ -0,0 +1,163 @@ +#include "is_equal.h" +#include "traits.h" + +#include <google/protobuf/descriptor.h> + +#include <util/generic/yexception.h> +#include <util/string/cast.h> +#include <util/string/vector.h> + +namespace NProtoBuf { + template <bool useDefault> + static bool IsEqualImpl(const Message& m1, const Message& m2, TVector<TString>* differentPath); + + namespace { + template <FieldDescriptor::CppType CppType, bool useDefault> + struct TCompareValue { + typedef typename TCppTypeTraits<CppType>::T T; + static inline bool IsEqual(T value1, T value2, TVector<TString>*) { + return value1 == value2; + } + }; + + template <bool useDefault> + struct TCompareValue<FieldDescriptor::CPPTYPE_MESSAGE, useDefault> { + static inline bool IsEqual(const Message* value1, const Message* value2, TVector<TString>* differentPath) { + return NProtoBuf::IsEqualImpl<useDefault>(*value1, *value2, differentPath); + } + }; + + template <FieldDescriptor::CppType CppType, bool useDefault> + class TCompareField { + typedef TCppTypeTraits<CppType> TTraits; + typedef TCompareValue<CppType, useDefault> TCompare; + + public: + static inline bool IsEqual(const Message& m1, const Message& m2, const FieldDescriptor& field, TVector<TString>* differentPath) { + if (field.is_repeated()) + return IsEqualRepeated(m1, m2, &field, differentPath); + else + return IsEqualSingle(m1, m2, &field, differentPath); + } + + private: + static bool IsEqualSingle(const Message& m1, const Message& m2, const FieldDescriptor* field, TVector<TString>* differentPath) { + bool has1 = m1.GetReflection()->HasField(m1, field); + bool has2 = m2.GetReflection()->HasField(m2, field); + + if (has1 != has2) { + if (!useDefault || field->is_required()) { + return false; + } + } else if (!has1) + return true; + + return TCompare::IsEqual(TTraits::Get(m1, field), + TTraits::Get(m2, field), + differentPath); + } + + static bool IsEqualRepeated(const Message& m1, const 
Message& m2, const FieldDescriptor* field, TVector<TString>* differentPath) { + int fieldSize = m1.GetReflection()->FieldSize(m1, field); + if (fieldSize != m2.GetReflection()->FieldSize(m2, field)) + return false; + for (int i = 0; i < fieldSize; ++i) + if (!IsEqualRepeatedValue(m1, m2, field, i, differentPath)) { + if (!!differentPath) { + differentPath->push_back(ToString(i)); + } + return false; + } + return true; + } + + static inline bool IsEqualRepeatedValue(const Message& m1, const Message& m2, const FieldDescriptor* field, int index, TVector<TString>* differentPath) { + return TCompare::IsEqual(TTraits::GetRepeated(m1, field, index), + TTraits::GetRepeated(m2, field, index), + differentPath); + } + }; + + template <bool useDefault> + bool IsEqualField(const Message& m1, const Message& m2, const FieldDescriptor& field, TVector<TString>* differentPath) { +#define CASE_CPPTYPE(cpptype) \ + case FieldDescriptor::CPPTYPE_##cpptype: { \ + bool r = TCompareField<FieldDescriptor::CPPTYPE_##cpptype, useDefault>::IsEqual(m1, m2, field, differentPath); \ + if (!r && !!differentPath) { \ + differentPath->push_back(field.name()); \ + } \ + return r; \ + } + + switch (field.cpp_type()) { + CASE_CPPTYPE(INT32) + CASE_CPPTYPE(INT64) + CASE_CPPTYPE(UINT32) + CASE_CPPTYPE(UINT64) + CASE_CPPTYPE(DOUBLE) + CASE_CPPTYPE(FLOAT) + CASE_CPPTYPE(BOOL) + CASE_CPPTYPE(ENUM) + CASE_CPPTYPE(STRING) + CASE_CPPTYPE(MESSAGE) + default: + ythrow yexception() << "Unsupported cpp-type field comparison"; + } + +#undef CASE_CPPTYPE + } + } + + template <bool useDefault> + bool IsEqualImpl(const Message& m1, const Message& m2, TVector<TString>* differentPath) { + const Descriptor* descr = m1.GetDescriptor(); + if (descr != m2.GetDescriptor()) { + return false; + } + for (int i = 0; i < descr->field_count(); ++i) + if (!IsEqualField<useDefault>(m1, m2, *descr->field(i), differentPath)) { + return false; + } + return true; + } + + bool IsEqual(const Message& m1, const Message& m2) { + return 
IsEqualImpl<false>(m1, m2, nullptr); + } + + bool IsEqual(const Message& m1, const Message& m2, TString* differentPath) { + TVector<TString> differentPathVector; + TVector<TString>* differentPathVectorPtr = !!differentPath ? &differentPathVector : nullptr; + bool r = IsEqualImpl<false>(m1, m2, differentPathVectorPtr); + if (!r && differentPath) { + *differentPath = JoinStrings(differentPathVector.rbegin(), differentPathVector.rend(), "/"); + } + return r; + } + + bool IsEqualDefault(const Message& m1, const Message& m2) { + return IsEqualImpl<true>(m1, m2, nullptr); + } + + template <bool useDefault> + static bool IsEqualFieldImpl( + const Message& m1, + const Message& m2, + const FieldDescriptor& field, + TVector<TString>* differentPath) { + const Descriptor* descr = m1.GetDescriptor(); + if (descr != m2.GetDescriptor()) { + return false; + } + return IsEqualField<useDefault>(m1, m2, field, differentPath); + } + + bool IsEqualField(const Message& m1, const Message& m2, const FieldDescriptor& field) { + return IsEqualFieldImpl<false>(m1, m2, field, nullptr); + } + + bool IsEqualFieldDefault(const Message& m1, const Message& m2, const FieldDescriptor& field) { + return IsEqualFieldImpl<true>(m1, m2, field, nullptr); + } + +} diff --git a/library/cpp/protobuf/util/is_equal.h b/library/cpp/protobuf/util/is_equal.h new file mode 100644 index 0000000000..13c0aae63d --- /dev/null +++ b/library/cpp/protobuf/util/is_equal.h @@ -0,0 +1,33 @@ +#pragma once + +#include <util/generic/fwd.h> + +namespace google { + namespace protobuf { + class Message; + class FieldDescriptor; + } +} + +namespace NProtoBuf { + using ::google::protobuf::FieldDescriptor; + using ::google::protobuf::Message; +} + +namespace NProtoBuf { + // Reflection-based equality check for arbitrary protobuf messages + + // Strict comparison: optional field without value is NOT equal to + // a field with explicitly set default value. 
+ bool IsEqual(const Message& m1, const Message& m2); + bool IsEqual(const Message& m1, const Message& m2, TString* differentPath); + + bool IsEqualField(const Message& m1, const Message& m2, const FieldDescriptor& field); + + // Non-strict version: optional field without explicit value is compared + // using its default value. + bool IsEqualDefault(const Message& m1, const Message& m2); + + bool IsEqualFieldDefault(const Message& m1, const Message& m2, const FieldDescriptor& field); + +} diff --git a/library/cpp/protobuf/util/is_equal_ut.cpp b/library/cpp/protobuf/util/is_equal_ut.cpp new file mode 100644 index 0000000000..3ca4c90dd5 --- /dev/null +++ b/library/cpp/protobuf/util/is_equal_ut.cpp @@ -0,0 +1,88 @@ +#include "is_equal.h" +#include <library/cpp/protobuf/util/ut/sample_for_is_equal.pb.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include <google/protobuf/descriptor.h> + +Y_UNIT_TEST_SUITE(ProtobufIsEqual) { + const ::google::protobuf::Descriptor* Descr = TSampleForIsEqual::descriptor(); + const ::google::protobuf::FieldDescriptor* NameDescr = Descr->field(0); + const ::google::protobuf::FieldDescriptor* InnerDescr = Descr->field(1); + + Y_UNIT_TEST(CheckDescriptors) { + UNIT_ASSERT(Descr); + UNIT_ASSERT(NameDescr); + UNIT_ASSERT_VALUES_EQUAL(NameDescr->name(), "Name"); + UNIT_ASSERT_VALUES_EQUAL(InnerDescr->name(), "Inner"); + } + + Y_UNIT_TEST(IsEqual1) { + TSampleForIsEqual a; + TSampleForIsEqual b; + + a.SetName("aaa"); + b.SetName("bbb"); + + TString path; + + bool equal = NProtoBuf::IsEqual(a, b, &path); + UNIT_ASSERT(!equal); + UNIT_ASSERT_VALUES_EQUAL("Name", path); + + UNIT_ASSERT(!NProtoBuf::IsEqualField(a, b, *NameDescr)); + } + + Y_UNIT_TEST(IsEqual2) { + TSampleForIsEqual a; + TSampleForIsEqual b; + + a.MutableInner()->SetBrbrbr("aaa"); + b.MutableInner()->SetBrbrbr("bbb"); + + TString path; + + bool equal = NProtoBuf::IsEqual(a, b, &path); + UNIT_ASSERT(!equal); + UNIT_ASSERT_VALUES_EQUAL("Inner/Brbrbr", path); + + bool 
equalField = NProtoBuf::IsEqualField(a, b, *InnerDescr); + UNIT_ASSERT(!equalField); + } + + Y_UNIT_TEST(IsEqual3) { + TSampleForIsEqual a; + TSampleForIsEqual b; + + a.SetName("aaa"); + a.MutableInner()->SetBrbrbr("bbb"); + + b.SetName("aaa"); + b.MutableInner()->SetBrbrbr("bbb"); + + TString path; + + UNIT_ASSERT(NProtoBuf::IsEqual(a, b)); + UNIT_ASSERT(NProtoBuf::IsEqualField(a, b, *NameDescr)); + UNIT_ASSERT(NProtoBuf::IsEqualField(a, b, *InnerDescr)); + + b.MutableInner()->SetBrbrbr("ccc"); + UNIT_ASSERT(!NProtoBuf::IsEqual(a, b)); + UNIT_ASSERT(!NProtoBuf::IsEqualField(a, b, *InnerDescr)); + + b.SetName("ccc"); + UNIT_ASSERT(!NProtoBuf::IsEqualField(a, b, *NameDescr)); + } + + Y_UNIT_TEST(IsEqualDefault) { + TSampleForIsEqual a; + TSampleForIsEqual b; + + a.SetName(""); + UNIT_ASSERT(NProtoBuf::IsEqualDefault(a, b)); + UNIT_ASSERT(!NProtoBuf::IsEqual(a, b)); + + UNIT_ASSERT(!NProtoBuf::IsEqualField(a, b, *NameDescr)); + UNIT_ASSERT(NProtoBuf::IsEqualFieldDefault(a, b, *NameDescr)); + } +} diff --git a/library/cpp/protobuf/util/iterators.h b/library/cpp/protobuf/util/iterators.h new file mode 100644 index 0000000000..6d53ac71b1 --- /dev/null +++ b/library/cpp/protobuf/util/iterators.h @@ -0,0 +1,53 @@ +#pragma once + +#include <google/protobuf/descriptor.h> + +namespace NProtoBuf { + class TFieldsIterator { + public: + explicit TFieldsIterator(const NProtoBuf::Descriptor* descriptor, int position = 0) + : Descriptor(descriptor) + , Position(position) + { } + + TFieldsIterator& operator++() { + ++Position; + return *this; + } + + TFieldsIterator& operator++(int) { + auto& ret = *this; + ++*this; + return ret; + } + + const NProtoBuf::FieldDescriptor* operator*() const { + return Descriptor->field(Position); + } + + bool operator== (const TFieldsIterator& other) const { + return Position == other.Position && Descriptor == other.Descriptor; + } + + bool operator!= (const TFieldsIterator& other) const { + return !(*this == other); + } + + private: + const 
NProtoBuf::Descriptor* Descriptor = nullptr; + int Position = 0; + }; +} + +// Namespaces required by `range-based for` ADL: +namespace google { + namespace protobuf { + NProtoBuf::TFieldsIterator begin(const NProtoBuf::Descriptor& descriptor) { + return NProtoBuf::TFieldsIterator(&descriptor); + } + + NProtoBuf::TFieldsIterator end(const NProtoBuf::Descriptor& descriptor) { + return NProtoBuf::TFieldsIterator(&descriptor, descriptor.field_count()); + } + } +} diff --git a/library/cpp/protobuf/util/iterators_ut.cpp b/library/cpp/protobuf/util/iterators_ut.cpp new file mode 100644 index 0000000000..9ebcff2963 --- /dev/null +++ b/library/cpp/protobuf/util/iterators_ut.cpp @@ -0,0 +1,52 @@ +#include "iterators.h" +#include "simple_reflection.h" +#include <library/cpp/protobuf/util/ut/common_ut.pb.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/algorithm.h> + +using NProtoBuf::TFieldsIterator; +using NProtoBuf::TConstField; + +Y_UNIT_TEST_SUITE(Iterators) { + Y_UNIT_TEST(Count) { + const NProtobufUtilUt::TWalkTest proto; + const NProtoBuf::Descriptor* d = proto.GetDescriptor(); + TFieldsIterator dbegin(d), dend(d, d->field_count()); + size_t steps = 0; + + UNIT_ASSERT_EQUAL(dbegin, begin(*d)); + UNIT_ASSERT_EQUAL(dend, end(*d)); + + for (; dbegin != dend; ++dbegin) + ++steps; + UNIT_ASSERT_VALUES_EQUAL(steps, d->field_count()); + } + + Y_UNIT_TEST(RangeFor) { + size_t steps = 0, values = 0; + NProtobufUtilUt::TWalkTest proto; + proto.SetOptStr("yandex"); + for (const auto& field : *proto.GetDescriptor()) { + values += TConstField(proto, field).HasValue(); + ++steps; + } + UNIT_ASSERT_VALUES_EQUAL(steps, proto.GetDescriptor()->field_count()); + UNIT_ASSERT_VALUES_EQUAL(values, 1); + } + + Y_UNIT_TEST(AnyOf) { + NProtobufUtilUt::TWalkTest proto; + const NProtoBuf::Descriptor* d = proto.GetDescriptor(); + TFieldsIterator begin(d), end(d, d->field_count()); + UNIT_ASSERT(!AnyOf(begin, end, [&proto](const NProtoBuf::FieldDescriptor* f){ + 
return TConstField(proto, f).HasValue(); + })); + + proto.SetOptStr("yandex"); + UNIT_ASSERT(AnyOf(begin, end, [&proto](const NProtoBuf::FieldDescriptor* f){ + return TConstField(proto, f).HasValue(); + })); + } +} diff --git a/library/cpp/protobuf/util/merge.cpp b/library/cpp/protobuf/util/merge.cpp new file mode 100644 index 0000000000..dc2b9cc806 --- /dev/null +++ b/library/cpp/protobuf/util/merge.cpp @@ -0,0 +1,46 @@ +#include "merge.h" +#include "simple_reflection.h" + +#include <google/protobuf/message.h> + +#include <library/cpp/protobuf/util/proto/merge.pb.h> + +namespace NProtoBuf { + void RewriteMerge(const Message& src, Message& dst) { + const Descriptor* d = src.GetDescriptor(); + Y_ASSERT(d == dst.GetDescriptor()); + + for (int i = 0; i < d->field_count(); ++i) { + if (TConstField(src, d->field(i)).Has()) + TMutableField(dst, d->field(i)).Clear(); + } + + dst.MergeFrom(src); + } + + static void ClearNonMergeable(const Message& src, Message& dst) { + const Descriptor* d = src.GetDescriptor(); + if (d->options().GetExtension(DontMerge)) { + dst.Clear(); + return; + } + + for (int i = 0; i < d->field_count(); ++i) { + const FieldDescriptor* fd = d->field(i); + TConstField srcField(src, fd); + if (srcField.Has()) { + TMutableField dstField(dst, fd); + if (fd->options().GetExtension(DontMergeField)) + dstField.Clear(); + else if (!fd->is_repeated() && dstField.IsMessage() && dstField.Has()) + ClearNonMergeable(*srcField.Get<const Message*>(), *dstField.MutableMessage()); + } + } + } + + void CustomMerge(const Message& src, Message& dst) { + ClearNonMergeable(src, dst); + dst.MergeFrom(src); + } + +} diff --git a/library/cpp/protobuf/util/merge.h b/library/cpp/protobuf/util/merge.h new file mode 100644 index 0000000000..924975f141 --- /dev/null +++ b/library/cpp/protobuf/util/merge.h @@ -0,0 +1,22 @@ +#pragma once + +namespace google { + namespace protobuf { + class Message; + } +} + +namespace NProtoBuf { + using Message = ::google::protobuf::Message; +} + 
+namespace NProtoBuf {
+    // Similar to Message::MergeFrom, but overwrites existing repeated fields
+    // and embedded messages completely instead of recursive merging.
+    void RewriteMerge(const Message& src, Message& dst);
+
+    // Does standard MergeFrom() by default, except messages/fields marked with DontMerge or DontMergeField option.
+    // Such fields are merged using RewriteMerge() (i.e. destination is cleared before merging anything from source)
+    void CustomMerge(const Message& src, Message& dst);
+
+}
diff --git a/library/cpp/protobuf/util/merge_ut.cpp b/library/cpp/protobuf/util/merge_ut.cpp
new file mode 100644
index 0000000000..22217db183
--- /dev/null
+++ b/library/cpp/protobuf/util/merge_ut.cpp
@@ -0,0 +1,83 @@
+#include "merge.h"
+#include <library/cpp/protobuf/util/ut/common_ut.pb.h>
+
+#include <library/cpp/testing/unittest/registar.h>
+
+using namespace NProtoBuf;
+
+Y_UNIT_TEST_SUITE(ProtobufMerge) {
+    static void InitProto(NProtobufUtilUt::TMergeTest & p, bool isSrc) {
+        size_t start = isSrc ?
0 : 100; + + p.AddMergeInt(start + 1); + p.AddMergeInt(start + 2); + + p.AddNoMergeInt(start + 3); + p.AddNoMergeInt(start + 4); + + NProtobufUtilUt::TMergeTestMerge* m = p.MutableMergeSub(); + m->SetA(start + 5); + m->AddB(start + 6); + m->AddB(start + 7); + m->AddC(start + 14); + + if (!isSrc) { + // only for dst + NProtobufUtilUt::TMergeTestMerge* mm1 = p.AddNoMergeRepSub(); + mm1->SetA(start + 8); + mm1->AddB(start + 9); + mm1->AddB(start + 10); + } + + NProtobufUtilUt::TMergeTestNoMerge* mm3 = p.MutableNoMergeOptSub(); + mm3->SetA(start + 11); + mm3->AddB(start + 12); + mm3->AddB(start + 13); + } + + Y_UNIT_TEST(CustomMerge) { + NProtobufUtilUt::TMergeTest src, dst; + InitProto(src, true); + InitProto(dst, false); + + // Cerr << "\nsrc: " << src.ShortDebugString() << Endl; + // Cerr << "dst: " << dst.ShortDebugString() << Endl; + NProtoBuf::CustomMerge(src, dst); + // Cerr << "dst2:" << dst.ShortDebugString() << Endl; + + // repeated uint32 MergeInt = 1; + UNIT_ASSERT_EQUAL(dst.MergeIntSize(), 4); + UNIT_ASSERT_EQUAL(dst.GetMergeInt(0), 101); + UNIT_ASSERT_EQUAL(dst.GetMergeInt(1), 102); + UNIT_ASSERT_EQUAL(dst.GetMergeInt(2), 1); + UNIT_ASSERT_EQUAL(dst.GetMergeInt(3), 2); + + // repeated uint32 NoMergeInt = 2 [(DontMergeField)=true]; + UNIT_ASSERT_EQUAL(dst.NoMergeIntSize(), 2); + UNIT_ASSERT_EQUAL(dst.GetNoMergeInt(0), 3); + UNIT_ASSERT_EQUAL(dst.GetNoMergeInt(1), 4); + + // optional TMergeTestMerge MergeSub = 3; + UNIT_ASSERT_EQUAL(dst.GetMergeSub().GetA(), 5); + UNIT_ASSERT_EQUAL(dst.GetMergeSub().BSize(), 4); + UNIT_ASSERT_EQUAL(dst.GetMergeSub().GetB(0), 106); + UNIT_ASSERT_EQUAL(dst.GetMergeSub().GetB(1), 107); + UNIT_ASSERT_EQUAL(dst.GetMergeSub().GetB(2), 6); + UNIT_ASSERT_EQUAL(dst.GetMergeSub().GetB(3), 7); + UNIT_ASSERT_EQUAL(dst.GetMergeSub().CSize(), 1); + UNIT_ASSERT_EQUAL(dst.GetMergeSub().GetC(0), 14); + + // repeated TMergeTestMerge NoMergeRepSub = 4 [(DontMergeField)=true]; + UNIT_ASSERT_EQUAL(dst.NoMergeRepSubSize(), 1); + 
UNIT_ASSERT_EQUAL(dst.GetNoMergeRepSub(0).GetA(), 108); + UNIT_ASSERT_EQUAL(dst.GetNoMergeRepSub(0).BSize(), 2); + UNIT_ASSERT_EQUAL(dst.GetNoMergeRepSub(0).GetB(0), 109); + UNIT_ASSERT_EQUAL(dst.GetNoMergeRepSub(0).GetB(1), 110); + + // optional TMergeTestNoMerge NoMergeOptSub = 5; + UNIT_ASSERT_EQUAL(dst.GetNoMergeOptSub().GetA(), 11); + UNIT_ASSERT_EQUAL(dst.GetNoMergeOptSub().BSize(), 2); + UNIT_ASSERT_EQUAL(dst.GetNoMergeOptSub().GetB(0), 12); + UNIT_ASSERT_EQUAL(dst.GetNoMergeOptSub().GetB(1), 13); + } +} diff --git a/library/cpp/protobuf/util/path.cpp b/library/cpp/protobuf/util/path.cpp new file mode 100644 index 0000000000..efa2a42c8a --- /dev/null +++ b/library/cpp/protobuf/util/path.cpp @@ -0,0 +1,61 @@ +#include "path.h" + +#include <util/generic/yexception.h> + +namespace NProtoBuf { + TFieldPath::TFieldPath() { + } + + TFieldPath::TFieldPath(const Descriptor* msgType, const TStringBuf& path) { + Init(msgType, path); + } + + TFieldPath::TFieldPath(const TVector<const FieldDescriptor*>& path) + : Path(path) + { + } + + bool TFieldPath::InitUnsafe(const Descriptor* msgType, TStringBuf path) { + Path.clear(); + while (path) { + TStringBuf next; + while (!next && path) + next = path.NextTok('/'); + if (!next) + return true; + + if (!msgType) // need field but no message type + return false; + + TString nextStr(next); + const FieldDescriptor* field = msgType->FindFieldByName(nextStr); + if (!field) { + // Try to find extension field by FindAllExtensions() + const DescriptorPool* pool = msgType->file()->pool(); + Y_ASSERT(pool); // never NULL by protobuf docs + TVector<const FieldDescriptor*> extensions; + pool->FindAllExtensions(msgType, &extensions); // find all extensions of this extendee + for (const FieldDescriptor* ext : extensions) { + if (ext->full_name() == nextStr || ext->name() == nextStr) { + if (field) + return false; // ambiguity + field = ext; + } + } + } + + if (!field) + return false; + + Path.push_back(field); + msgType = field->type() == 
FieldDescriptor::TYPE_MESSAGE ? field->message_type() : nullptr; + } + return true; + } + + void TFieldPath::Init(const Descriptor* msgType, const TStringBuf& path) { + if (!InitUnsafe(msgType, path)) + ythrow yexception() << "Failed to resolve path \"" << path << "\" relative to " << msgType->full_name(); + } + +} diff --git a/library/cpp/protobuf/util/path.h b/library/cpp/protobuf/util/path.h new file mode 100644 index 0000000000..487f643a2d --- /dev/null +++ b/library/cpp/protobuf/util/path.h @@ -0,0 +1,52 @@ +#pragma once + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/message.h> + +#include <util/generic/vector.h> + +namespace NProtoBuf { + class TFieldPath { + public: + TFieldPath(); + TFieldPath(const Descriptor* msgType, const TStringBuf& path); // throws exception if path doesn't exist + TFieldPath(const TVector<const FieldDescriptor*>& path); + TFieldPath(const TFieldPath&) = default; + TFieldPath& operator=(const TFieldPath&) = default; + + bool InitUnsafe(const Descriptor* msgType, const TStringBuf path); // noexcept + void Init(const Descriptor* msgType, const TStringBuf& path); // throws + + const TVector<const FieldDescriptor*>& Fields() const { + return Path; + } + + void AddField(const FieldDescriptor* field) { + Path.push_back(field); + } + + const Descriptor* ParentType() const { + return Empty() ? nullptr : Path.front()->containing_type(); + } + + const FieldDescriptor* FieldDescr() const { + return Empty() ? 
nullptr : Path.back(); + } + + bool Empty() const { + return Path.empty(); + } + + explicit operator bool() const { + return !Empty(); + } + + bool operator!() const { + return Empty(); + } + + private: + TVector<const FieldDescriptor*> Path; + }; + +} diff --git a/library/cpp/protobuf/util/pb_io.cpp b/library/cpp/protobuf/util/pb_io.cpp new file mode 100644 index 0000000000..6270ee0624 --- /dev/null +++ b/library/cpp/protobuf/util/pb_io.cpp @@ -0,0 +1,221 @@ +#include "pb_io.h" + +#include <library/cpp/binsaver/bin_saver.h> +#include <library/cpp/string_utils/base64/base64.h> + +#include <google/protobuf/message.h> +#include <google/protobuf/messagext.h> +#include <google/protobuf/text_format.h> + +#include <util/generic/string.h> +#include <util/stream/file.h> +#include <util/stream/str.h> +#include <util/string/cast.h> + +namespace NProtoBuf { + + class TEnumIdValuePrinter : public google::protobuf::TextFormat::FastFieldValuePrinter { + public: + void PrintEnum(int32 val, const TString& /*name*/, google::protobuf::TextFormat::BaseTextGenerator* generator) const override { + generator->PrintString(ToString(val)); + } + }; + + void ParseFromBase64String(const TStringBuf dataBase64, Message& m, bool allowUneven) { + if (!m.ParseFromString(allowUneven ? 
Base64DecodeUneven(dataBase64) : Base64StrictDecode(dataBase64))) { + ythrow yexception() << "can't parse " << m.GetTypeName() << " from base64-encoded string"; + } + } + + bool TryParseFromBase64String(const TStringBuf dataBase64, Message& m, bool allowUneven) { + try { + ParseFromBase64String(dataBase64, m, allowUneven); + return true; + } catch (const std::exception&) { + return false; + } + } + + void SerializeToBase64String(const Message& m, TString& dataBase64) { + TString rawData; + if (!m.SerializeToString(&rawData)) { + ythrow yexception() << "can't serialize " << m.GetTypeName(); + } + + Base64EncodeUrl(rawData, dataBase64); + } + + TString SerializeToBase64String(const Message& m) { + TString s; + SerializeToBase64String(m, s); + return s; + } + + bool TrySerializeToBase64String(const Message& m, TString& dataBase64) { + try { + SerializeToBase64String(m, dataBase64); + return true; + } catch (const std::exception&) { + return false; + } + } + + const TString ShortUtf8DebugString(const Message& message) { + TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.SetUseUtf8StringEscaping(true); + TString result; + printer.PrintToString(message, &result); + return result; + } + + bool MergePartialFromString(NProtoBuf::Message& m, const TStringBuf serializedProtoMessage) { + google::protobuf::io::CodedInputStream input(reinterpret_cast<const ui8*>(serializedProtoMessage.data()), serializedProtoMessage.size()); + bool ok = m.MergePartialFromCodedStream(&input); + ok = ok && input.ConsumedEntireMessage(); + return ok; + } + + bool MergeFromString(NProtoBuf::Message& m, const TStringBuf serializedProtoMessage) { + return MergePartialFromString(m, serializedProtoMessage) && m.IsInitialized(); + } +} + +int operator&(NProtoBuf::Message& m, IBinSaver& f) { + TStringStream ss; + if (f.IsReading()) { + f.Add(0, &ss.Str()); + m.ParseFromArcadiaStream(&ss); + } else { + m.SerializeToArcadiaStream(&ss); + f.Add(0, &ss.Str()); + } + return 0; +} + 
+void SerializeToTextFormat(const NProtoBuf::Message& m, IOutputStream& out) { + NProtoBuf::io::TCopyingOutputStreamAdaptor adaptor(&out); + + if (!NProtoBuf::TextFormat::Print(m, &adaptor)) { + ythrow yexception() << "SerializeToTextFormat failed on Print"; + } +} + +void SerializeToTextFormat(const NProtoBuf::Message& m, const TString& fileName) { + /* TUnbufferedFileOutput is unbuffered, but TCopyingOutputStreamAdaptor adds + * a buffer on top of it. */ + TUnbufferedFileOutput stream(fileName); + SerializeToTextFormat(m, stream); +} + +void SerializeToTextFormatWithEnumId(const NProtoBuf::Message& m, IOutputStream& out) { + google::protobuf::TextFormat::Printer printer; + printer.SetDefaultFieldValuePrinter(new NProtoBuf::TEnumIdValuePrinter()); + NProtoBuf::io::TCopyingOutputStreamAdaptor adaptor(&out); + + if (!printer.Print(m, &adaptor)) { + ythrow yexception() << "SerializeToTextFormatWithEnumId failed on Print"; + } +} + +void SerializeToTextFormatPretty(const NProtoBuf::Message& m, IOutputStream& out) { + google::protobuf::TextFormat::Printer printer; + printer.SetUseUtf8StringEscaping(true); + printer.SetUseShortRepeatedPrimitives(true); + + NProtoBuf::io::TCopyingOutputStreamAdaptor adaptor(&out); + + if (!printer.Print(m, &adaptor)) { + ythrow yexception() << "SerializeToTextFormatPretty failed on Print"; + } +} + +static void ConfigureParser(const EParseFromTextFormatOptions options, + NProtoBuf::TextFormat::Parser& p) { + if (options & EParseFromTextFormatOption::AllowUnknownField) { + p.AllowUnknownField(true); + } +} + +void ParseFromTextFormat(IInputStream& in, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options) { + NProtoBuf::io::TCopyingInputStreamAdaptor adaptor(&in); + NProtoBuf::TextFormat::Parser p; + ConfigureParser(options, p); + + if (!p.Parse(&adaptor, &m)) { + // remove everything that may have been read + m.Clear(); + ythrow yexception() << "ParseFromTextFormat failed on Parse for " << m.GetTypeName(); + } +} + +void 
ParseFromTextFormat(const TString& fileName, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options) { + /* TUnbufferedFileInput is unbuffered, but TCopyingInputStreamAdaptor adds + * a buffer on top of it. */ + TUnbufferedFileInput stream(fileName); + ParseFromTextFormat(stream, m, options); +} + +bool TryParseFromTextFormat(const TString& fileName, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options) { + try { + ParseFromTextFormat(fileName, m, options); + } catch (std::exception&) { + return false; + } + + return true; +} + +bool TryParseFromTextFormat(IInputStream& in, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options) { + try { + ParseFromTextFormat(in, m, options); + } catch (std::exception&) { + return false; + } + + return true; +} + +void MergeFromTextFormat(IInputStream& in, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options) { + NProtoBuf::io::TCopyingInputStreamAdaptor adaptor(&in); + NProtoBuf::TextFormat::Parser p; + ConfigureParser(options, p); + if (!p.Merge(&adaptor, &m)) { + ythrow yexception() << "MergeFromTextFormat failed on Merge for " << m.GetTypeName(); + } +} + +void MergeFromTextFormat(const TString& fileName, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options) { + /* TUnbufferedFileInput is unbuffered, but TCopyingInputStreamAdaptor adds + * a buffer on top of it. 
*/ + TUnbufferedFileInput stream(fileName); + MergeFromTextFormat(stream, m, options); +} + +bool TryMergeFromTextFormat(const TString& fileName, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options) { + try { + MergeFromTextFormat(fileName, m, options); + } catch (std::exception&) { + return false; + } + + return true; +} + +bool TryMergeFromTextFormat(IInputStream& in, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options) { + try { + MergeFromTextFormat(in, m, options); + } catch (std::exception&) { + return false; + } + + return true; +} diff --git a/library/cpp/protobuf/util/pb_io.h b/library/cpp/protobuf/util/pb_io.h new file mode 100644 index 0000000000..493c84cb5f --- /dev/null +++ b/library/cpp/protobuf/util/pb_io.h @@ -0,0 +1,138 @@ +#pragma once + +#include <util/generic/fwd.h> +#include <util/generic/flags.h> + +struct IBinSaver; + +namespace google { + namespace protobuf { + class Message; + } +} + +namespace NProtoBuf { + using Message = ::google::protobuf::Message; +} + +class IInputStream; +class IOutputStream; + +namespace NProtoBuf { + /* Parse base64 URL encoded serialized message from string. + */ + void ParseFromBase64String(const TStringBuf dataBase64, Message& m, bool allowUneven = false); + bool TryParseFromBase64String(const TStringBuf dataBase64, Message& m, bool allowUneven = false); + template <typename T> + static T ParseFromBase64String(const TStringBuf& dataBase64, bool allowUneven = false) { + T m; + ParseFromBase64String(dataBase64, m, allowUneven); + return m; + } + + /* Serialize message into string and apply base64 URL encoding. 
+ */ + TString SerializeToBase64String(const Message& m); + void SerializeToBase64String(const Message& m, TString& dataBase64); + bool TrySerializeToBase64String(const Message& m, TString& dataBase64); + + const TString ShortUtf8DebugString(const Message& message); + + bool MergePartialFromString(NProtoBuf::Message& m, const TStringBuf serializedProtoMessage); + bool MergeFromString(NProtoBuf::Message& m, const TStringBuf serializedProtoMessage); +} + +int operator&(NProtoBuf::Message& m, IBinSaver& f); + +// Write a textual representation of the given message to the given file. +void SerializeToTextFormat(const NProtoBuf::Message& m, const TString& fileName); +void SerializeToTextFormat(const NProtoBuf::Message& m, IOutputStream& out); + +// Write a textual representation of the given message to the given output stream +// with flags UseShortRepeatedPrimitives and UseUtf8StringEscaping set to true. +void SerializeToTextFormatPretty(const NProtoBuf::Message& m, IOutputStream& out); + +// Write a textual representation of the given message to the given output stream +// use enum id instead of enum name for all enum fields. +void SerializeToTextFormatWithEnumId(const NProtoBuf::Message& m, IOutputStream& out); + +enum class EParseFromTextFormatOption : ui64 { + // Unknown fields will be ignored by the parser + AllowUnknownField = 1 +}; + +Y_DECLARE_FLAGS(EParseFromTextFormatOptions, EParseFromTextFormatOption); + +// Parse a text-format protocol message from the given file into message object. +void ParseFromTextFormat(const TString& fileName, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options = {}); +// NOTE: will read `in` till the end. +void ParseFromTextFormat(IInputStream& in, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options = {}); + +/* @return `true` if parsing was successfull and `false` otherwise. 
+ * + * @see `ParseFromTextFormat` + */ +bool TryParseFromTextFormat(const TString& fileName, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options = {}); +// NOTE: will read `in` till the end. +bool TryParseFromTextFormat(IInputStream& in, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options = {}); + +// @see `ParseFromTextFormat` +template <typename T> +static T ParseFromTextFormat(const TString& fileName, + const EParseFromTextFormatOptions options = {}) { + T message; + ParseFromTextFormat(fileName, message, options); + return message; +} + +// @see `ParseFromTextFormat` +// NOTE: will read `in` till the end. +template <typename T> +static T ParseFromTextFormat(IInputStream& in, + const EParseFromTextFormatOptions options = {}) { + T message; + ParseFromTextFormat(in, message, options); + return message; +} + +// Merge a text-format protocol message from the given file into message object. +// +// NOTE: Even when parsing failed and exception was thrown `m` may be different from its original +// value. User must implement transactional logic around `MergeFromTextFormat` by himself. +void MergeFromTextFormat(const TString& fileName, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options = {}); +// NOTE: will read `in` till the end. +void MergeFromTextFormat(IInputStream& in, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options = {}); +/* @return `true` if parsing was successfull and `false` otherwise. + * + * @see `MergeFromTextFormat` + */ +bool TryMergeFromTextFormat(const TString& fileName, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options = {}); +// NOTE: will read `in` till the end. 
+bool TryMergeFromTextFormat(IInputStream& in, NProtoBuf::Message& m, + const EParseFromTextFormatOptions options = {}); + +// @see `MergeFromTextFormat` +template <typename T> +static T MergeFromTextFormat(const TString& fileName, + const EParseFromTextFormatOptions options = {}) { + T message; + MergeFromTextFormat(fileName, message, options); + return message; +} + +// @see `MergeFromTextFormat` +// NOTE: will read `in` till the end. +template <typename T> +static T MergeFromTextFormat(IInputStream& in, + const EParseFromTextFormatOptions options = {}) { + T message; + MergeFromTextFormat(in, message, options); + return message; +} diff --git a/library/cpp/protobuf/util/pb_io_ut.cpp b/library/cpp/protobuf/util/pb_io_ut.cpp new file mode 100644 index 0000000000..875d6dc602 --- /dev/null +++ b/library/cpp/protobuf/util/pb_io_ut.cpp @@ -0,0 +1,418 @@ +#include "pb_io.h" + +#include "is_equal.h" + +#include <library/cpp/protobuf/util/ut/common_ut.pb.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/folder/path.h> +#include <util/folder/tempdir.h> +#include <util/stream/file.h> +#include <util/stream/str.h> + +static NProtobufUtilUt::TTextTest GetCorrectMessage() { + NProtobufUtilUt::TTextTest m; + m.SetFoo(42); + return m; +} + +static NProtobufUtilUt::TTextEnumTest GetCorrectEnumMessage() { + NProtobufUtilUt::TTextEnumTest m; + m.SetSlot(NProtobufUtilUt::TTextEnumTest::EET_SLOT_1); + return m; +} + +static const TString CORRECT_MESSAGE = + R"(Foo: 42 +)"; +static const TString CORRECT_ENUM_NAME_MESSAGE = + R"(Slot: EET_SLOT_1 +)"; +static const TString CORRECT_ENUM_ID_MESSAGE = + R"(Slot: 1 +)"; + +static const TString INCORRECT_MESSAGE = + R"(Bar: 1 +)"; +static const TString INCORRECT_ENUM_NAME_MESSAGE = + R"(Slot: EET_SLOT_3 +)"; +static const TString INCORRECT_ENUM_ID_MESSAGE = + R"(Slot: 3 +)"; + +static const TString CORRECT_BASE64_MESSAGE = "CCo,"; + +static const TString CORRECT_UNEVEN_BASE64_MESSAGE = "CCo"; + +static const TString 
INCORRECT_BASE64_MESSAGE = "CC"; + +Y_UNIT_TEST_SUITE(TTestProtoBufIO) { + Y_UNIT_TEST(TestBase64) { + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT(NProtoBuf::TryParseFromBase64String(CORRECT_BASE64_MESSAGE, message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT(!NProtoBuf::TryParseFromBase64String(INCORRECT_BASE64_MESSAGE, message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT(NProtoBuf::TryParseFromBase64String(CORRECT_UNEVEN_BASE64_MESSAGE , message, true)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT(!NProtoBuf::TryParseFromBase64String(CORRECT_UNEVEN_BASE64_MESSAGE , message, false)); + } + { + UNIT_ASSERT_VALUES_EQUAL(CORRECT_BASE64_MESSAGE, NProtoBuf::SerializeToBase64String(GetCorrectMessage())); + } + { + const auto m = NProtoBuf::ParseFromBase64String<NProtobufUtilUt::TTextTest>(CORRECT_BASE64_MESSAGE); + UNIT_ASSERT(NProtoBuf::IsEqual(GetCorrectMessage(), m)); + } + } + + Y_UNIT_TEST(TestParseFromTextFormat) { + TTempDir tempDir; + const TFsPath correctFileName = TFsPath{tempDir()} / "correct.pb.txt"; + const TFsPath incorrectFileName = TFsPath{tempDir()} / "incorrect.pb.txt"; + + TFileOutput{correctFileName}.Write(CORRECT_MESSAGE); + TFileOutput{incorrectFileName}.Write(INCORRECT_MESSAGE); + + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT(TryParseFromTextFormat(correctFileName, message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT(!TryParseFromTextFormat(incorrectFileName, message)); + } + { + NProtobufUtilUt::TTextTest message; + TStringInput in{CORRECT_MESSAGE}; + UNIT_ASSERT(TryParseFromTextFormat(in, message)); + } + { + NProtobufUtilUt::TTextTest message; + TStringInput in{INCORRECT_MESSAGE}; + UNIT_ASSERT(!TryParseFromTextFormat(in, message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_NO_EXCEPTION(TryParseFromTextFormat(incorrectFileName, message)); + } + { + NProtobufUtilUt::TTextTest message; + 
UNIT_ASSERT(!TryParseFromTextFormat("this_file_doesnt_exists", message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_NO_EXCEPTION(TryParseFromTextFormat("this_file_doesnt_exists", message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_EXCEPTION(ParseFromTextFormat("this_file_doesnt_exists", message), TFileError); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_NO_EXCEPTION(ParseFromTextFormat(correctFileName, message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_EXCEPTION(ParseFromTextFormat(incorrectFileName, message), yexception); + } + { + NProtobufUtilUt::TTextTest message; + TStringInput in{CORRECT_MESSAGE}; + UNIT_ASSERT_NO_EXCEPTION(ParseFromTextFormat(in, message)); + } + { + NProtobufUtilUt::TTextTest message; + TStringInput in{INCORRECT_MESSAGE}; + UNIT_ASSERT_EXCEPTION(ParseFromTextFormat(in, message), yexception); + } + { + NProtobufUtilUt::TTextTest m; + const auto f = [&correctFileName](NProtobufUtilUt::TTextTest& mm) { + mm = ParseFromTextFormat<NProtobufUtilUt::TTextTest>(correctFileName); + }; + UNIT_ASSERT_NO_EXCEPTION(f(m)); + UNIT_ASSERT(NProtoBuf::IsEqual(GetCorrectMessage(), m)); + } + { + UNIT_ASSERT_EXCEPTION(ParseFromTextFormat<NProtobufUtilUt::TTextTest>(incorrectFileName), yexception); + } + { + NProtobufUtilUt::TTextTest m; + TStringInput in{CORRECT_MESSAGE}; + const auto f = [&in](NProtobufUtilUt::TTextTest& mm) { + mm = ParseFromTextFormat<NProtobufUtilUt::TTextTest>(in); + }; + UNIT_ASSERT_NO_EXCEPTION(f(m)); + UNIT_ASSERT(NProtoBuf::IsEqual(GetCorrectMessage(), m)); + } + { + TStringInput in{INCORRECT_MESSAGE}; + UNIT_ASSERT_EXCEPTION(ParseFromTextFormat<NProtobufUtilUt::TTextTest>(in), yexception); + } + { + const TFsPath correctFileName2 = TFsPath{tempDir()} / "serialized.pb.txt"; + const auto original = GetCorrectMessage(); + UNIT_ASSERT_NO_EXCEPTION(SerializeToTextFormat(original, correctFileName2)); + const auto serializedStr = 
TUnbufferedFileInput{correctFileName2}.ReadAll(); + UNIT_ASSERT_VALUES_EQUAL(serializedStr, CORRECT_MESSAGE); + } + { + const auto original = GetCorrectMessage(); + TStringStream out; + UNIT_ASSERT_NO_EXCEPTION(SerializeToTextFormat(original, out)); + UNIT_ASSERT_VALUES_EQUAL(out.Str(), CORRECT_MESSAGE); + } + { + NProtobufUtilUt::TTextTest m; + const auto f = [&correctFileName](NProtobufUtilUt::TTextTest& mm) { + mm = ParseFromTextFormat<NProtobufUtilUt::TTextTest>( + correctFileName, + EParseFromTextFormatOption::AllowUnknownField); + }; + UNIT_ASSERT_NO_EXCEPTION(f(m)); + UNIT_ASSERT(NProtoBuf::IsEqual(GetCorrectMessage(), m)); + } + { + const NProtobufUtilUt::TTextTest empty; + NProtobufUtilUt::TTextTest m; + const auto f = [&incorrectFileName](NProtobufUtilUt::TTextTest& mm) { + mm = ParseFromTextFormat<NProtobufUtilUt::TTextTest>( + incorrectFileName, + EParseFromTextFormatOption::AllowUnknownField); + }; + UNIT_ASSERT_NO_EXCEPTION(f(m)); + UNIT_ASSERT(NProtoBuf::IsEqual(empty, m)); + } + } + + Y_UNIT_TEST(TestSerializeToTextFormatWithEnumId) { + TTempDir tempDir; + const TFsPath correctNameFileName = TFsPath{tempDir()} / "correct_name.pb.txt"; + const TFsPath incorrectNameFileName = TFsPath{tempDir()} / "incorrect_name.pb.txt"; + const TFsPath correctIdFileName = TFsPath{tempDir()} / "correct_id.pb.txt"; + const TFsPath incorrectIdFileName = TFsPath{tempDir()} / "incorrect_id.pb.txt"; + + TFileOutput{correctNameFileName}.Write(CORRECT_ENUM_NAME_MESSAGE); + TFileOutput{incorrectNameFileName}.Write(INCORRECT_ENUM_NAME_MESSAGE); + TFileOutput{correctIdFileName}.Write(CORRECT_ENUM_ID_MESSAGE); + TFileOutput{incorrectIdFileName}.Write(INCORRECT_ENUM_ID_MESSAGE); + + { + NProtobufUtilUt::TTextEnumTest message; + for (auto correct_message: {CORRECT_ENUM_ID_MESSAGE, CORRECT_ENUM_NAME_MESSAGE}) { + TStringInput in{correct_message}; + UNIT_ASSERT_NO_EXCEPTION(ParseFromTextFormat(in, message)); + } + } + { + NProtobufUtilUt::TTextEnumTest message; + for (auto 
incorrect_message: {INCORRECT_ENUM_ID_MESSAGE, INCORRECT_ENUM_NAME_MESSAGE}) { + TStringInput in{incorrect_message}; + UNIT_ASSERT_EXCEPTION(ParseFromTextFormat(in, message), yexception); + } + } + { + const auto f = [](NProtobufUtilUt::TTextEnumTest& mm, const TString fileName) { + mm = ParseFromTextFormat<NProtobufUtilUt::TTextEnumTest>(fileName); + }; + for (auto fileName: {correctIdFileName, correctNameFileName}) { + NProtobufUtilUt::TTextEnumTest m; + UNIT_ASSERT_NO_EXCEPTION(f(m, fileName)); + UNIT_ASSERT(NProtoBuf::IsEqual(GetCorrectEnumMessage(), m)); + } + } + { + UNIT_ASSERT_EXCEPTION(ParseFromTextFormat<NProtobufUtilUt::TTextEnumTest>(incorrectIdFileName), yexception); + UNIT_ASSERT_EXCEPTION(ParseFromTextFormat<NProtobufUtilUt::TTextEnumTest>(incorrectNameFileName), yexception); + } + { + const auto original = GetCorrectEnumMessage(); + TStringStream out; + UNIT_ASSERT_NO_EXCEPTION(SerializeToTextFormat(original, out)); + UNIT_ASSERT_VALUES_EQUAL(out.Str(), CORRECT_ENUM_NAME_MESSAGE); + } + { + const auto original = GetCorrectEnumMessage(); + TStringStream out; + UNIT_ASSERT_NO_EXCEPTION(SerializeToTextFormatWithEnumId(original, out)); + UNIT_ASSERT_VALUES_EQUAL(out.Str(), CORRECT_ENUM_ID_MESSAGE); + } + } + + Y_UNIT_TEST(TestMergeFromTextFormat) { + // + // Tests cases below are identical to `Parse` tests + // + TTempDir tempDir; + const TFsPath correctFileName = TFsPath{tempDir()} / "correct.pb.txt"; + const TFsPath incorrectFileName = TFsPath{tempDir()} / "incorrect.pb.txt"; + + TFileOutput{correctFileName}.Write(CORRECT_MESSAGE); + TFileOutput{incorrectFileName}.Write(INCORRECT_MESSAGE); + + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT(TryMergeFromTextFormat(correctFileName, message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT(!TryMergeFromTextFormat(incorrectFileName, message)); + } + { + NProtobufUtilUt::TTextTest message; + TStringInput in{CORRECT_MESSAGE}; + UNIT_ASSERT(TryMergeFromTextFormat(in, message)); + } + { + 
NProtobufUtilUt::TTextTest message; + TStringInput in{INCORRECT_MESSAGE}; + UNIT_ASSERT(!TryMergeFromTextFormat(in, message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_NO_EXCEPTION(TryMergeFromTextFormat(incorrectFileName, message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT(!TryMergeFromTextFormat("this_file_doesnt_exists", message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_NO_EXCEPTION(TryMergeFromTextFormat("this_file_doesnt_exists", message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_EXCEPTION(MergeFromTextFormat("this_file_doesnt_exists", message), TFileError); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_NO_EXCEPTION(MergeFromTextFormat(correctFileName, message)); + } + { + NProtobufUtilUt::TTextTest message; + UNIT_ASSERT_EXCEPTION(MergeFromTextFormat(incorrectFileName, message), yexception); + } + { + NProtobufUtilUt::TTextTest message; + TStringInput in{CORRECT_MESSAGE}; + UNIT_ASSERT_NO_EXCEPTION(MergeFromTextFormat(in, message)); + } + { + NProtobufUtilUt::TTextTest message; + TStringInput in{INCORRECT_MESSAGE}; + UNIT_ASSERT_EXCEPTION(MergeFromTextFormat(in, message), yexception); + } + { + NProtobufUtilUt::TTextTest m; + const auto f = [&correctFileName](NProtobufUtilUt::TTextTest& mm) { + mm = MergeFromTextFormat<NProtobufUtilUt::TTextTest>(correctFileName); + }; + UNIT_ASSERT_NO_EXCEPTION(f(m)); + UNIT_ASSERT(NProtoBuf::IsEqual(GetCorrectMessage(), m)); + } + { + UNIT_ASSERT_EXCEPTION(MergeFromTextFormat<NProtobufUtilUt::TTextTest>(incorrectFileName), yexception); + } + { + NProtobufUtilUt::TTextTest m; + TStringInput in{CORRECT_MESSAGE}; + const auto f = [&in](NProtobufUtilUt::TTextTest& mm) { + mm = MergeFromTextFormat<NProtobufUtilUt::TTextTest>(in); + }; + UNIT_ASSERT_NO_EXCEPTION(f(m)); + UNIT_ASSERT(NProtoBuf::IsEqual(GetCorrectMessage(), m)); + } + { + TStringInput in{INCORRECT_MESSAGE}; + 
UNIT_ASSERT_EXCEPTION(MergeFromTextFormat<NProtobufUtilUt::TTextTest>(in), yexception); + } + { + const TFsPath correctFileName2 = TFsPath{tempDir()} / "serialized.pb.txt"; + const auto original = GetCorrectMessage(); + UNIT_ASSERT_NO_EXCEPTION(SerializeToTextFormat(original, correctFileName2)); + const auto serializedStr = TUnbufferedFileInput{correctFileName2}.ReadAll(); + UNIT_ASSERT_VALUES_EQUAL(serializedStr, CORRECT_MESSAGE); + } + { + const auto original = GetCorrectMessage(); + TStringStream out; + UNIT_ASSERT_NO_EXCEPTION(SerializeToTextFormat(original, out)); + UNIT_ASSERT_VALUES_EQUAL(out.Str(), CORRECT_MESSAGE); + } + { + NProtobufUtilUt::TTextTest m; + const auto f = [&correctFileName](NProtobufUtilUt::TTextTest& mm) { + mm = MergeFromTextFormat<NProtobufUtilUt::TTextTest>( + correctFileName, + EParseFromTextFormatOption::AllowUnknownField); + }; + UNIT_ASSERT_NO_EXCEPTION(f(m)); + UNIT_ASSERT(NProtoBuf::IsEqual(GetCorrectMessage(), m)); + } + { + const NProtobufUtilUt::TTextTest empty; + NProtobufUtilUt::TTextTest m; + const auto f = [&incorrectFileName](NProtobufUtilUt::TTextTest& mm) { + mm = MergeFromTextFormat<NProtobufUtilUt::TTextTest>( + incorrectFileName, + EParseFromTextFormatOption::AllowUnknownField); + }; + UNIT_ASSERT_NO_EXCEPTION(f(m)); + UNIT_ASSERT(NProtoBuf::IsEqual(empty, m)); + } + + // + // Test cases for `Merge` + // + { + NProtobufUtilUt::TTextTest message; + message.SetFoo(100500); + TStringInput in{CORRECT_MESSAGE}; + UNIT_ASSERT(TryMergeFromTextFormat(in, message)); + UNIT_ASSERT(NProtoBuf::IsEqual(message, GetCorrectMessage())); + } + } + + Y_UNIT_TEST(TestMergeFromString) { + NProtobufUtilUt::TMergeTest message; + NProtobufUtilUt::TMergeTest messageFirstHalf; + NProtobufUtilUt::TMergeTest messageSecondHalf; + + for (ui32 v = ~0; v != 0; v >>= 1) { + message.AddMergeInt(v); + (v > 0xffff ? 
messageFirstHalf : messageSecondHalf).AddMergeInt(v); + } + + const TString full = message.SerializeAsString(); + + { + NProtobufUtilUt::TMergeTest m1; + UNIT_ASSERT(NProtoBuf::MergeFromString(m1, full)); + UNIT_ASSERT(NProtoBuf::IsEqual(message, m1)); + } + { + NProtobufUtilUt::TMergeTest m2; + TStringBuf s0 = TStringBuf(full).SubStr(0, 3); + TStringBuf s1 = TStringBuf(full).SubStr(3); + // объединение результатов двух MergePartialFromString не эквивалентно вызову MergePartialFromString от объединения строк + UNIT_ASSERT(!(NProtoBuf::MergePartialFromString(m2, s0) && NProtoBuf::MergePartialFromString(m2, s1))); + } + { + NProtobufUtilUt::TMergeTest m3; + UNIT_ASSERT(NProtoBuf::MergePartialFromString(m3, messageFirstHalf.SerializeAsString())); + UNIT_ASSERT(NProtoBuf::MergeFromString(m3, messageSecondHalf.SerializeAsString())); + UNIT_ASSERT(NProtoBuf::IsEqual(message, m3)); + } + } +} diff --git a/library/cpp/protobuf/util/pb_utils.h b/library/cpp/protobuf/util/pb_utils.h new file mode 100644 index 0000000000..9e9a110b48 --- /dev/null +++ b/library/cpp/protobuf/util/pb_utils.h @@ -0,0 +1,11 @@ +#pragma once + +#define UPDATE_PB_FIELD_MAX(PBMESS, FIELD, VAL) \ + if ((VAL) > (PBMESS).Get##FIELD()) { \ + (PBMESS).Set##FIELD(VAL); \ + } + +#define UPDATE_OPT_PB_FIELD_MAX(PBMESS, FIELD, VAL) \ + if (!(PBMESS).Has##FIELD() || ((VAL) > (PBMESS).Get##FIELD())) { \ + (PBMESS).Set##FIELD(VAL); \ + } diff --git a/library/cpp/protobuf/util/proto/merge.proto b/library/cpp/protobuf/util/proto/merge.proto new file mode 100644 index 0000000000..a937041c07 --- /dev/null +++ b/library/cpp/protobuf/util/proto/merge.proto @@ -0,0 +1,11 @@ +import "google/protobuf/descriptor.proto"; + +// These meta-options are used for selecting proper merging method, see merge.h + +extend google.protobuf.MessageOptions { + optional bool DontMerge = 54287; +} + +extend google.protobuf.FieldOptions { + optional bool DontMergeField = 54288; +} diff --git a/library/cpp/protobuf/util/proto/ya.make 
b/library/cpp/protobuf/util/proto/ya.make new file mode 100644 index 0000000000..4d68047d8b --- /dev/null +++ b/library/cpp/protobuf/util/proto/ya.make @@ -0,0 +1,11 @@ +PROTO_LIBRARY() + +OWNER(mowgli) + +SRCS( + merge.proto +) + +EXCLUDE_TAGS(GO_PROTO) + +END() diff --git a/library/cpp/protobuf/util/repeated_field_utils.h b/library/cpp/protobuf/util/repeated_field_utils.h new file mode 100644 index 0000000000..c07bd84647 --- /dev/null +++ b/library/cpp/protobuf/util/repeated_field_utils.h @@ -0,0 +1,96 @@ +#pragma once + +#include <google/protobuf/repeated_field.h> +#include <util/generic/vector.h> + +template <typename T> +void RemoveRepeatedPtrFieldElement(google::protobuf::RepeatedPtrField<T>* repeated, unsigned index) { + google::protobuf::RepeatedPtrField<T> r; + Y_ASSERT(index < (unsigned)repeated->size()); + for (unsigned i = 0; i < (unsigned)repeated->size(); ++i) { + if (i == index) { + continue; + } + r.Add()->Swap(repeated->Mutable(i)); + } + r.Swap(repeated); +} + +namespace NProtoBuf { + /// Move item to specified position + template <typename TRepeated> + static void MoveRepeatedFieldItem(TRepeated* field, size_t indexFrom, size_t indexTo) { + if (!field->size() || indexFrom >= static_cast<size_t>(field->size()) || indexFrom == indexTo) + return; + if (indexTo >= static_cast<size_t>(field->size())) + indexTo = field->size() - 1; + if (indexFrom > indexTo) { + for (size_t i = indexFrom; i > indexTo; --i) + field->SwapElements(i, i - 1); + } else { + for (size_t i = indexFrom; i < indexTo; ++i) + field->SwapElements(i, i + 1); + } + } + + template <typename T> + static T* InsertRepeatedFieldItem(NProtoBuf::RepeatedPtrField<T>* field, size_t index) { + T* ret = field->Add(); + MoveRepeatedFieldItem(field, field->size() - 1, index); + return ret; + } + + template <typename TRepeated> // suitable both for RepeatedField and RepeatedPtrField + static void RemoveRepeatedFieldItem(TRepeated* field, size_t index) { + if ((int)index >= field->size()) + return; 
+ + for (int i = index + 1; i < field->size(); ++i) + field->SwapElements(i - 1, i); + + field->RemoveLast(); + } + + template <typename TRepeated, typename TPred> // suitable both for RepeatedField and RepeatedPtrField + static void RemoveRepeatedFieldItemIf(TRepeated* repeated, TPred p) { + auto last = std::remove_if(repeated->begin(), repeated->end(), p); + if (last != repeated->end()) { + size_t countToRemove = repeated->end() - last; + while (countToRemove--) + repeated->RemoveLast(); + } + } + + namespace NImpl { + template <typename TRepeated> + static void ShiftLeft(TRepeated* field, int begIndex, int endIndex, size_t shiftSize) { + Y_ASSERT(begIndex <= field->size()); + Y_ASSERT(endIndex <= field->size()); + size_t shiftIndex = (int)shiftSize < begIndex ? begIndex - shiftSize : 0; + for (int i = begIndex; i < endIndex; ++i, ++shiftIndex) + field->SwapElements(shiftIndex, i); + } + } + + // Remove several items at once, could be more efficient compared to calling RemoveRepeatedFieldItem several times + template <typename TRepeated> + static void RemoveRepeatedFieldItems(TRepeated* field, const TVector<size_t>& sortedIndices) { + if (sortedIndices.empty()) + return; + + size_t shift = 1; + for (size_t i = 1; i < sortedIndices.size(); ++i, ++shift) + NImpl::ShiftLeft(field, sortedIndices[i - 1] + 1, sortedIndices[i], shift); + NImpl::ShiftLeft(field, sortedIndices.back() + 1, field->size(), shift); + + for (; shift > 0; --shift) + field->RemoveLast(); + } + + template <typename TRepeated> + static void ReverseRepeatedFieldItems(TRepeated* field) { + for (int i1 = 0, i2 = field->size() - 1; i1 < i2; ++i1, --i2) + field->SwapElements(i1, i2); + } + +} diff --git a/library/cpp/protobuf/util/repeated_field_utils_ut.cpp b/library/cpp/protobuf/util/repeated_field_utils_ut.cpp new file mode 100644 index 0000000000..58aaaa9e12 --- /dev/null +++ b/library/cpp/protobuf/util/repeated_field_utils_ut.cpp @@ -0,0 +1,46 @@ +#include "repeated_field_utils.h" +#include 
<library/cpp/protobuf/util/ut/common_ut.pb.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NProtoBuf; + +Y_UNIT_TEST_SUITE(RepeatedFieldUtils) { + Y_UNIT_TEST(RemoveIf) { + { + NProtobufUtilUt::TWalkTest msg; + msg.AddRepInt(0); + msg.AddRepInt(1); + msg.AddRepInt(2); + msg.AddRepInt(3); + msg.AddRepInt(4); + msg.AddRepInt(5); + auto cond = [](ui32 val) { + return val % 2 == 0; + }; + RemoveRepeatedFieldItemIf(msg.MutableRepInt(), cond); + UNIT_ASSERT_VALUES_EQUAL(3, msg.RepIntSize()); + UNIT_ASSERT_VALUES_EQUAL(1, msg.GetRepInt(0)); + UNIT_ASSERT_VALUES_EQUAL(3, msg.GetRepInt(1)); + UNIT_ASSERT_VALUES_EQUAL(5, msg.GetRepInt(2)); + } + + { + NProtobufUtilUt::TWalkTest msg; + msg.AddRepSub()->SetOptInt(0); + msg.AddRepSub()->SetOptInt(1); + msg.AddRepSub()->SetOptInt(2); + msg.AddRepSub()->SetOptInt(3); + msg.AddRepSub()->SetOptInt(4); + msg.AddRepSub()->SetOptInt(5); + auto cond = [](const NProtobufUtilUt::TWalkTest& val) { + return val.GetOptInt() % 2 == 0; + }; + RemoveRepeatedFieldItemIf(msg.MutableRepSub(), cond); + UNIT_ASSERT_VALUES_EQUAL(3, msg.RepSubSize()); + UNIT_ASSERT_VALUES_EQUAL(1, msg.GetRepSub(0).GetOptInt()); + UNIT_ASSERT_VALUES_EQUAL(3, msg.GetRepSub(1).GetOptInt()); + UNIT_ASSERT_VALUES_EQUAL(5, msg.GetRepSub(2).GetOptInt()); + } + } +} diff --git a/library/cpp/protobuf/util/simple_reflection.cpp b/library/cpp/protobuf/util/simple_reflection.cpp new file mode 100644 index 0000000000..d842e9ee44 --- /dev/null +++ b/library/cpp/protobuf/util/simple_reflection.cpp @@ -0,0 +1,70 @@ +#include "simple_reflection.h" + +namespace NProtoBuf { + const Message* GetMessageHelper(const TConstField& curField, bool) { + return curField.HasValue() && curField.IsMessage() ? curField.Get<Message>() : nullptr; + } + + Message* GetMessageHelper(TMutableField& curField, bool createPath) { + if (curField.IsMessage()) { + if (!curField.HasValue()) { + if (createPath) + return curField.Field()->is_repeated() ? 
curField.AddMessage() : curField.MutableMessage(); + } else { + return curField.MutableMessage(); + } + } + return nullptr; + } + + template <class TField, class TMsg> + TMaybe<TField> ByPathImpl(TMsg& msg, const TVector<const FieldDescriptor*>& fieldsPath, bool createPath) { + if (fieldsPath.empty()) + return TMaybe<TField>(); + TMsg* curParent = &msg; + for (size_t i = 0, size = fieldsPath.size(); i < size; ++i) { + const FieldDescriptor* field = fieldsPath[i]; + if (!curParent) + return TMaybe<TField>(); + TField curField(*curParent, field); + if (size - i == 1) // last element in path + return curField; + curParent = GetMessageHelper(curField, createPath); + } + if (curParent) + return TField(*curParent, fieldsPath.back()); + else + return TMaybe<TField>(); + } + + TMaybe<TConstField> TConstField::ByPath(const Message& msg, const TVector<const FieldDescriptor*>& fieldsPath) { + return ByPathImpl<TConstField, const Message>(msg, fieldsPath, false); + } + + TMaybe<TConstField> TConstField::ByPath(const Message& msg, const TStringBuf& path) { + TFieldPath fieldPath; + if (!fieldPath.InitUnsafe(msg.GetDescriptor(), path)) + return TMaybe<TConstField>(); + return ByPathImpl<TConstField, const Message>(msg, fieldPath.Fields(), false); + } + + TMaybe<TConstField> TConstField::ByPath(const Message& msg, const TFieldPath& path) { + return ByPathImpl<TConstField, const Message>(msg, path.Fields(), false); + } + + TMaybe<TMutableField> TMutableField::ByPath(Message& msg, const TVector<const FieldDescriptor*>& fieldsPath, bool createPath) { + return ByPathImpl<TMutableField, Message>(msg, fieldsPath, createPath); + } + + TMaybe<TMutableField> TMutableField::ByPath(Message& msg, const TStringBuf& path, bool createPath) { + TFieldPath fieldPath; + if (!fieldPath.InitUnsafe(msg.GetDescriptor(), path)) + return TMaybe<TMutableField>(); + return ByPathImpl<TMutableField, Message>(msg, fieldPath.Fields(), createPath); + } + + TMaybe<TMutableField> TMutableField::ByPath(Message& 
msg, const TFieldPath& path, bool createPath) { + return ByPathImpl<TMutableField, Message>(msg, path.Fields(), createPath); + } + +} diff --git a/library/cpp/protobuf/util/simple_reflection.h b/library/cpp/protobuf/util/simple_reflection.h new file mode 100644 index 0000000000..61e877a787 --- /dev/null +++ b/library/cpp/protobuf/util/simple_reflection.h @@ -0,0 +1,289 @@ +#pragma once + +#include "cast.h" +#include "path.h" +#include "traits.h" + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/message.h> + +#include <util/generic/maybe.h> +#include <util/generic/typetraits.h> +#include <util/generic/vector.h> +#include <util/system/defaults.h> + +namespace NProtoBuf { + class TConstField { + public: + TConstField(const Message& msg, const FieldDescriptor* fd) + : Msg(msg) + , Fd(fd) + { + Y_ASSERT(Fd && Fd->containing_type() == Msg.GetDescriptor()); + } + + static TMaybe<TConstField> ByPath(const Message& msg, const TStringBuf& path); + static TMaybe<TConstField> ByPath(const Message& msg, const TVector<const FieldDescriptor*>& fieldsPath); + static TMaybe<TConstField> ByPath(const Message& msg, const TFieldPath& fieldsPath); + + const Message& Parent() const { + return Msg; + } + + const FieldDescriptor* Field() const { + return Fd; + } + + bool HasValue() const { + return IsRepeated() ? Refl().FieldSize(Msg, Fd) > 0 + : Refl().HasField(Msg, Fd); + } + + // deprecated, use HasValue() instead + bool Has() const { + return HasValue(); + } + + size_t Size() const { + return IsRepeated() ? Refl().FieldSize(Msg, Fd) + : (Refl().HasField(Msg, Fd) ? 1 : 0); + } + + template <typename T> + inline typename TSelectCppType<T>::T Get(size_t index = 0) const; + + template <typename TMsg> + inline const TMsg* GetAs(size_t index = 0) const { + // casting version of Get + return IsMessageInstance<TMsg>() ? 
CheckedCast<const TMsg*>(Get<const Message*>(index)) : nullptr; + } + + template <typename T> + bool IsInstance() const { + return CppType() == TSelectCppType<T>::Result; + } + + template <typename TMsg> + bool IsMessageInstance() const { + return IsMessage() && Fd->message_type() == TMsg::descriptor(); + } + + template <typename TMsg> + bool IsInstance(std::enable_if_t<std::is_base_of<Message, TMsg>::value && !std::is_same<Message, TMsg>::value, void>* = NULL) const { // template will be selected when specifying Message children types + return IsMessage() && Fd->message_type() == TMsg::descriptor(); + } + + bool IsString() const { + return CppType() == FieldDescriptor::CPPTYPE_STRING; + } + + bool IsMessage() const { + return CppType() == FieldDescriptor::CPPTYPE_MESSAGE; + } + + bool HasSameType(const TConstField& other) const { + if (CppType() != other.CppType()) + return false; + if (IsMessage() && Field()->message_type() != other.Field()->message_type()) + return false; + if (CppType() == FieldDescriptor::CPPTYPE_ENUM && Field()->enum_type() != other.Field()->enum_type()) + return false; + return true; + } + + protected: + bool IsRepeated() const { + return Fd->is_repeated(); + } + + FieldDescriptor::CppType CppType() const { + return Fd->cpp_type(); + } + + const Reflection& Refl() const { + return *Msg.GetReflection(); + } + + [[noreturn]] void RaiseUnknown() const { + ythrow yexception() << "Unknown field cpp-type: " << (size_t)CppType(); + } + + bool IsSameField(const TConstField& other) const { + return &Parent() == &other.Parent() && Field() == other.Field(); + } + + protected: + const Message& Msg; + const FieldDescriptor* Fd; + }; + + class TMutableField: public TConstField { + public: + TMutableField(Message& msg, const FieldDescriptor* fd) + : TConstField(msg, fd) + { + } + + static TMaybe<TMutableField> ByPath(Message& msg, const TStringBuf& path, bool createPath = false); + static TMaybe<TMutableField> ByPath(Message& msg, const TVector<const 
FieldDescriptor*>& fieldsPath, bool createPath = false); + static TMaybe<TMutableField> ByPath(Message& msg, const TFieldPath& fieldsPath, bool createPath = false); + + Message* MutableParent() { + return Mut(); + } + + template <typename T> + inline void Set(T value, size_t index = 0); + + template <typename T> + inline void Add(T value); + + inline void MergeFrom(const TConstField& src); + + inline void Clear() { + Refl().ClearField(Mut(), Fd); + } + /* + void Swap(TMutableField& f) { + Y_ASSERT(Field() == f.Field()); + + // not implemented yet, TODO: implement when Reflection::Mutable(Ptr)RepeatedField + // is ported into arcadia protobuf library from up-stream. + } +*/ + inline void RemoveLast() { + Y_ASSERT(HasValue()); + if (IsRepeated()) + Refl().RemoveLast(Mut(), Fd); + else + Clear(); + } + + inline void SwapElements(size_t index1, size_t index2) { + Y_ASSERT(IsRepeated()); + Y_ASSERT(index1 < Size()); + Y_ASSERT(index2 < Size()); + if (index1 == index2) + return; + Refl().SwapElements(Mut(), Fd, index1, index2); + } + + inline void Remove(size_t index) { + if (index >= Size()) + return; + + // Move to the end + for (size_t i = index, size = Size(); i < size - 1; ++i) + SwapElements(i, i + 1); + RemoveLast(); + } + + Message* MutableMessage(size_t index = 0) { + Y_ASSERT(IsMessage()); + if (IsRepeated()) { + Y_ASSERT(index < Size()); + return Refl().MutableRepeatedMessage(Mut(), Fd, index); + } else { + Y_ASSERT(index == 0); + return Refl().MutableMessage(Mut(), Fd); + } + } + + template <typename TMsg> + inline TMsg* AddMessage() { + return CheckedCast<TMsg*>(AddMessage()); + } + + inline Message* AddMessage() { + Y_ASSERT(IsMessage() && IsRepeated()); + return Refl().AddMessage(Mut(), Fd); + } + + private: + Message* Mut() { + return const_cast<Message*>(&Msg); + } + + template <typename T> + inline void MergeValue(T srcValue); + }; + + // template implementations + + template <typename T> + inline typename TSelectCppType<T>::T TConstField::Get(size_t 
index) const { + Y_ASSERT(index < Size() || !Fd->is_repeated() && index == 0); // Get for single fields is always allowed because of default values +#define TMP_MACRO_FOR_CPPTYPE(CPPTYPE) \ + case CPPTYPE: \ + return CompatCast<CPPTYPE, TSelectCppType<T>::Result>(TSimpleFieldTraits<CPPTYPE>::Get(Msg, Fd, index)); + switch (CppType()) { + APPLY_TMP_MACRO_FOR_ALL_CPPTYPES() + default: + RaiseUnknown(); + } +#undef TMP_MACRO_FOR_CPPTYPE + } + + template <typename T> + inline void TMutableField::Set(T value, size_t index) { + Y_ASSERT(!IsRepeated() && index == 0 || index < Size()); +#define TMP_MACRO_FOR_CPPTYPE(CPPTYPE) \ + case CPPTYPE: \ + TSimpleFieldTraits<CPPTYPE>::Set(*Mut(), Fd, CompatCast<TSelectCppType<T>::Result, CPPTYPE>(value), index); \ + break; + switch (CppType()) { + APPLY_TMP_MACRO_FOR_ALL_CPPTYPES() + default: + RaiseUnknown(); + } +#undef TMP_MACRO_FOR_CPPTYPE + } + + template <typename T> + inline void TMutableField::Add(T value) { +#define TMP_MACRO_FOR_CPPTYPE(CPPTYPE) \ + case CPPTYPE: \ + TSimpleFieldTraits<CPPTYPE>::Add(*Mut(), Fd, CompatCast<TSelectCppType<T>::Result, CPPTYPE>(value)); \ + break; + switch (CppType()) { + APPLY_TMP_MACRO_FOR_ALL_CPPTYPES() + default: + RaiseUnknown(); + } +#undef TMP_MACRO_FOR_CPPTYPE + } + + template <typename T> + inline void TMutableField::MergeValue(T srcValue) { + Add(srcValue); + } + + template <> + inline void TMutableField::MergeValue<const Message*>(const Message* srcValue) { + if (IsRepeated()) { + Add(srcValue); + } else { + MutableMessage()->MergeFrom(*srcValue); + } + } + + inline void TMutableField::MergeFrom(const TConstField& src) { + Y_ASSERT(HasSameType(src)); + if (IsSameField(src)) + return; +#define TMP_MACRO_FOR_CPPTYPE(CPPTYPE) \ + case CPPTYPE: { \ + for (size_t itemIdx = 0; itemIdx < src.Size(); ++itemIdx) { \ + MergeValue(TSimpleFieldTraits<CPPTYPE>::Get(src.Parent(), src.Field(), itemIdx)); \ + } \ + break; \ + } + switch (CppType()) { + APPLY_TMP_MACRO_FOR_ALL_CPPTYPES() + default: 
+ RaiseUnknown(); + } +#undef TMP_MACRO_FOR_CPPTYPE + } + +} diff --git a/library/cpp/protobuf/util/simple_reflection_ut.cpp b/library/cpp/protobuf/util/simple_reflection_ut.cpp new file mode 100644 index 0000000000..169d4703c9 --- /dev/null +++ b/library/cpp/protobuf/util/simple_reflection_ut.cpp @@ -0,0 +1,359 @@ +#include "simple_reflection.h" +#include <library/cpp/protobuf/util/ut/sample_for_simple_reflection.pb.h> +#include <library/cpp/protobuf/util/ut/extensions.pb.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NProtoBuf; + +Y_UNIT_TEST_SUITE(ProtobufSimpleReflection) { + static TSample GenSampleForMergeFrom() { + TSample smf; + smf.SetOneStr("one str"); + smf.MutableOneMsg()->AddRepInt(1); + smf.AddRepMsg()->AddRepInt(2); + smf.AddRepMsg()->AddRepInt(3); + smf.AddRepStr("one rep str"); + smf.AddRepStr("two rep str"); + smf.SetAnotherOneStr("another one str"); + return smf; + } + + Y_UNIT_TEST(MergeFromGeneric) { + const TSample src(GenSampleForMergeFrom()); + TSample dst; + const Descriptor* descr = dst.GetDescriptor(); + + { + TMutableField dstOneStr(dst, descr->FindFieldByName("OneStr")); + TConstField srcOneStr(src, descr->FindFieldByName("OneStr")); + dstOneStr.MergeFrom(srcOneStr); + UNIT_ASSERT_VALUES_EQUAL(dst.GetOneStr(), src.GetOneStr()); + } + + { // MergeFrom for single message fields acts like a Message::MergeFrom + TMutableField dstOneMsg(dst, descr->FindFieldByName("OneMsg")); + dstOneMsg.MergeFrom(TConstField(src, descr->FindFieldByName("OneMsg"))); + UNIT_ASSERT_VALUES_EQUAL(dst.GetOneMsg().RepIntSize(), src.GetOneMsg().RepIntSize()); + dstOneMsg.MergeFrom(TConstField(src, descr->FindFieldByName("OneMsg"))); + UNIT_ASSERT_VALUES_EQUAL(dst.GetOneMsg().RepIntSize(), src.GetOneMsg().RepIntSize() * 2); + } + + { // MergeFrom for repeated fields acts like append + TMutableField dstRepMsg(dst, descr->FindFieldByName("RepMsg")); + dstRepMsg.MergeFrom(TConstField(src, descr->FindFieldByName("RepMsg"))); + 
UNIT_ASSERT_VALUES_EQUAL(dst.RepMsgSize(), src.RepMsgSize()); + dstRepMsg.MergeFrom(TConstField(src, descr->FindFieldByName("RepMsg"))); + UNIT_ASSERT_VALUES_EQUAL(dst.RepMsgSize(), src.RepMsgSize() * 2); + for (size_t repMsgIndex = 0; repMsgIndex < dst.RepMsgSize(); ++repMsgIndex) { + UNIT_ASSERT_VALUES_EQUAL(dst.GetRepMsg(repMsgIndex).RepIntSize(), src.GetRepMsg(0).RepIntSize()); + } + } + } + + Y_UNIT_TEST(MergeFromSelf) { + const TSample sample(GenSampleForMergeFrom()); + TSample msg(sample); + const Descriptor* descr = msg.GetDescriptor(); + + TMutableField oneStr(msg, descr->FindFieldByName("OneStr")); + oneStr.MergeFrom(oneStr); + UNIT_ASSERT_VALUES_EQUAL(msg.GetOneStr(), sample.GetOneStr()); + + TMutableField oneMsg(msg, descr->FindFieldByName("OneMsg")); + oneMsg.MergeFrom(oneMsg); // nothing should change + UNIT_ASSERT_VALUES_EQUAL(msg.GetOneMsg().RepIntSize(), sample.GetOneMsg().RepIntSize()); + } + + Y_UNIT_TEST(MergeFromAnotherFD) { + const TSample sample(GenSampleForMergeFrom()); + TSample msg(GenSampleForMergeFrom()); + const Descriptor* descr = msg.GetDescriptor(); + + { // string + TMutableField oneStr(msg, descr->FindFieldByName("OneStr")); + TMutableField repStr(msg, descr->FindFieldByName("RepStr")); + TMutableField anotherOneStr(msg, descr->FindFieldByName("AnotherOneStr")); + oneStr.MergeFrom(anotherOneStr); + UNIT_ASSERT_VALUES_EQUAL(msg.GetOneStr(), sample.GetAnotherOneStr()); + oneStr.MergeFrom(repStr); + const size_t sampleRepStrSize = sample.RepStrSize(); + UNIT_ASSERT_VALUES_EQUAL(msg.GetOneStr(), sample.GetRepStr(sampleRepStrSize - 1)); + repStr.MergeFrom(anotherOneStr); + UNIT_ASSERT_VALUES_EQUAL(msg.RepStrSize(), sampleRepStrSize + 1); + UNIT_ASSERT_VALUES_EQUAL(msg.GetRepStr(sampleRepStrSize), msg.GetAnotherOneStr()); + } + + { // Message + TMutableField oneMsg(msg, descr->FindFieldByName("OneMsg")); + TMutableField repMsg(msg, descr->FindFieldByName("RepMsg")); + oneMsg.MergeFrom(repMsg); + const size_t oneMsgRepIntSize = 
sample.GetOneMsg().RepIntSize(); + const size_t sizeOfAllRepIntsInRepMsg = sample.RepMsgSize(); + UNIT_ASSERT_VALUES_EQUAL(msg.GetOneMsg().RepIntSize(), oneMsgRepIntSize + sizeOfAllRepIntsInRepMsg); + repMsg.MergeFrom(oneMsg); + UNIT_ASSERT_VALUES_EQUAL(msg.RepMsgSize(), sample.RepMsgSize() + 1); + } + } + + Y_UNIT_TEST(RemoveByIndex) { + TSample msg; + + const Descriptor* descr = msg.GetDescriptor(); + { + TMutableField fld(msg, descr->FindFieldByName("RepMsg")); + msg.AddRepMsg()->AddRepInt(1); + msg.AddRepMsg()->AddRepInt(2); + msg.AddRepMsg()->AddRepInt(3); + + UNIT_ASSERT_VALUES_EQUAL(3, msg.RepMsgSize()); // 1, 2, 3 + fld.Remove(1); // from middle + UNIT_ASSERT_VALUES_EQUAL(2, msg.RepMsgSize()); + UNIT_ASSERT_VALUES_EQUAL(1, msg.GetRepMsg(0).GetRepInt(0)); + UNIT_ASSERT_VALUES_EQUAL(3, msg.GetRepMsg(1).GetRepInt(0)); + + msg.AddRepMsg()->AddRepInt(5); + UNIT_ASSERT_VALUES_EQUAL(3, msg.RepMsgSize()); // 1, 3, 5 + fld.Remove(2); // from end + UNIT_ASSERT_VALUES_EQUAL(2, msg.RepMsgSize()); + UNIT_ASSERT_VALUES_EQUAL(1, msg.GetRepMsg(0).GetRepInt(0)); + UNIT_ASSERT_VALUES_EQUAL(3, msg.GetRepMsg(1).GetRepInt(0)); + msg.ClearRepMsg(); + } + + { + TMutableField fld(msg, descr->FindFieldByName("RepStr")); + msg.AddRepStr("1"); + msg.AddRepStr("2"); + msg.AddRepStr("3"); + UNIT_ASSERT_VALUES_EQUAL(3, msg.RepStrSize()); // "1", "2", "3" + fld.Remove(0); // from begin + UNIT_ASSERT_VALUES_EQUAL(2, msg.RepStrSize()); + UNIT_ASSERT_VALUES_EQUAL("2", msg.GetRepStr(0)); + UNIT_ASSERT_VALUES_EQUAL("3", msg.GetRepStr(1)); + } + + { + TMutableField fld(msg, descr->FindFieldByName("OneStr")); + msg.SetOneStr("1"); + UNIT_ASSERT(msg.HasOneStr()); + fld.Remove(0); // not repeated + UNIT_ASSERT(!msg.HasOneStr()); + } + } + + Y_UNIT_TEST(GetFieldByPath) { + // Simple get by path + { + TSample msg; + msg.SetOneStr("1"); + msg.MutableOneMsg()->AddRepInt(2); + msg.MutableOneMsg()->AddRepInt(3); + msg.AddRepMsg()->AddRepInt(4); + msg.MutableRepMsg(0)->AddRepInt(5); + 
msg.AddRepMsg()->AddRepInt(6); + + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "OneStr"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL("1", (field->Get<TString>())); + } + + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "OneMsg"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT(field->IsMessageInstance<TInnerSample>()); + } + + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "/OneMsg/RepInt"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL(2, field->Size()); + UNIT_ASSERT_VALUES_EQUAL(2, field->Get<int>(0)); + UNIT_ASSERT_VALUES_EQUAL(3, field->Get<int>(1)); + } + + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "RepMsg/RepInt"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL(2, field->Size()); + UNIT_ASSERT_VALUES_EQUAL(4, field->Get<int>(0)); + UNIT_ASSERT_VALUES_EQUAL(5, field->Get<int>(1)); + } + } + + // get of unset fields + { + TSample msg; + msg.MutableOneMsg(); + + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "OneStr"); + UNIT_ASSERT(field); + UNIT_ASSERT(!field->HasValue()); + } + + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "OneMsg/RepInt"); + UNIT_ASSERT(field); + UNIT_ASSERT(!field->HasValue()); + } + + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "RepMsg/RepInt"); + UNIT_ASSERT(!field); + } + } + + // mutable + { + TSample msg; + msg.MutableOneMsg(); + + { + TMaybe<TMutableField> field = TMutableField::ByPath(msg, "OneStr"); + UNIT_ASSERT(field); + UNIT_ASSERT(!field->HasValue()); + field->Set(TString("zz")); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL("zz", msg.GetOneStr()); + } + + { + TMaybe<TMutableField> field = TMutableField::ByPath(msg, "OneStr"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + field->Set(TString("dd")); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL("dd", msg.GetOneStr()); + } 
+ + { + TMaybe<TMutableField> field = TMutableField::ByPath(msg, "OneMsg/RepInt"); + UNIT_ASSERT(field); + UNIT_ASSERT(!field->HasValue()); + field->Add(10); + UNIT_ASSERT_VALUES_EQUAL(10, msg.GetOneMsg().GetRepInt(0)); + } + + { + TMaybe<TMutableField> field = TMutableField::ByPath(msg, "RepMsg/RepInt"); + UNIT_ASSERT(!field); + } + } + + // mutable with path creation + { + TSample msg; + + { + TMaybe<TMutableField> field = TMutableField::ByPath(msg, "OneStr", true); + UNIT_ASSERT(field); + UNIT_ASSERT(!field->HasValue()); + } + + { + TMaybe<TMutableField> field = TMutableField::ByPath(msg, "OneMsg/RepInt", true); + UNIT_ASSERT(field); + UNIT_ASSERT(!field->HasValue()); + UNIT_ASSERT(msg.HasOneMsg()); + field->Add(10); + UNIT_ASSERT_VALUES_EQUAL(10, msg.GetOneMsg().GetRepInt(0)); + } + + { + TMaybe<TMutableField> field = TMutableField::ByPath(msg, "RepMsg/RepInt", true); + TMaybe<TMutableField> fieldCopy = TMutableField::ByPath(msg, "RepMsg/RepInt", true); + Y_UNUSED(fieldCopy); + UNIT_ASSERT(field); + UNIT_ASSERT(!field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL(1, msg.RepMsgSize()); + field->Add(12); + UNIT_ASSERT_VALUES_EQUAL(12, field->Get<int>()); + } + } + + // error + { + {TSample msg; + UNIT_ASSERT(!TConstField::ByPath(msg, "SomeField")); + } + + { + TSample msg; + UNIT_ASSERT(!TMutableField::ByPath(msg, "SomeField/FieldSome")); + } + + { + TSample msg; + UNIT_ASSERT(!TMutableField::ByPath(msg, "SomeField/FieldSome", true)); + } +} + +// extension +{ + TSample msg; + msg.SetExtension(NExt::TTestExt::ExtField, "ext"); + msg.SetExtension(NExt::ExtField, 2); + msg.AddExtension(NExt::Ext2Field, 33); + TInnerSample* subMsg = msg.MutableExtension(NExt::SubMsgExt); + subMsg->AddRepInt(20); + subMsg->SetExtension(NExt::Ext3Field, 54); + + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "NExt.TTestExt.ExtField"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL("ext", field->Get<TString>()); + } + { + TMaybe<TConstField> 
field = TConstField::ByPath(msg, "NExt.ExtField"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL(2, field->Get<int>()); + } + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "ExtField"); // ambiguity + UNIT_ASSERT(!field); + } + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "NExt.Ext2Field"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL(33, field->Get<int>()); + } + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "Ext2Field"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL(33, field->Get<int>()); + } + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "SubMsgExt"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + const TInnerSample* subMsg2 = field->GetAs<TInnerSample>(); + UNIT_ASSERT(subMsg2); + UNIT_ASSERT_VALUES_EQUAL(1, subMsg2->RepIntSize()); + UNIT_ASSERT_VALUES_EQUAL(20, subMsg2->GetRepInt(0)); + UNIT_ASSERT_VALUES_EQUAL(54, subMsg2->GetExtension(NExt::Ext3Field)); + } + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "SubMsgExt/Ext3Field"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL(54, field->Get<int>()); + } + { + TMaybe<TConstField> field = TConstField::ByPath(msg, "SubMsgExt/RepInt"); + UNIT_ASSERT(field); + UNIT_ASSERT(field->HasValue()); + UNIT_ASSERT_VALUES_EQUAL(20, field->Get<int>()); + } +} +} +} diff --git a/library/cpp/protobuf/util/sort.h b/library/cpp/protobuf/util/sort.h new file mode 100644 index 0000000000..985ba6f689 --- /dev/null +++ b/library/cpp/protobuf/util/sort.h @@ -0,0 +1,28 @@ +#pragma once + +#include <google/protobuf/message.h> + +#include <util/generic/vector.h> +#include <util/generic/algorithm.h> + +namespace NProtoBuf { + // TComparePtr is something like: + // typedef bool (*TComparePtr)(const Message* msg1, const Message* msg2); + // typedef bool (*TComparePtr)(const TProto* msg1, const TProto* msg2); + + template 
<typename TProto, typename TComparePtr> + void SortMessages(RepeatedPtrField<TProto>& msgs, TComparePtr cmp) { + TVector<TProto*> ptrs; + ptrs.reserve(msgs.size()); + while (msgs.size()) { + ptrs.push_back(msgs.ReleaseLast()); + } + + ::StableSort(ptrs.begin(), ptrs.end(), cmp); + + for (size_t i = 0; i < ptrs.size(); ++i) { + msgs.AddAllocated(ptrs[i]); + } + } + +} diff --git a/library/cpp/protobuf/util/traits.h b/library/cpp/protobuf/util/traits.h new file mode 100644 index 0000000000..50f036d0ea --- /dev/null +++ b/library/cpp/protobuf/util/traits.h @@ -0,0 +1,320 @@ +#pragma once + +#include <util/generic/typetraits.h> + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/message.h> + +namespace NProtoBuf { +// this nasty windows.h macro interfers with protobuf::Reflection::GetMessage() +#if defined(GetMessage) +#undef GetMessage +#endif + + struct TCppTypeTraitsBase { + static inline bool Has(const Message& msg, const FieldDescriptor* field) { // non-repeated + return msg.GetReflection()->HasField(msg, field); + } + static inline size_t Size(const Message& msg, const FieldDescriptor* field) { // repeated + return msg.GetReflection()->FieldSize(msg, field); + } + + static inline void Clear(Message& msg, const FieldDescriptor* field) { + msg.GetReflection()->ClearField(&msg, field); + } + + static inline void RemoveLast(Message& msg, const FieldDescriptor* field) { + msg.GetReflection()->RemoveLast(&msg, field); + } + + static inline void SwapElements(Message& msg, const FieldDescriptor* field, int index1, int index2) { + msg.GetReflection()->SwapElements(&msg, field, index1, index2); + } + }; + + // default value accessor + template <FieldDescriptor::CppType cpptype> + struct TCppTypeTraitsDefault; + +#define DECLARE_CPPTYPE_DEFAULT(cpptype, method) \ + template <> \ + struct TCppTypeTraitsDefault<cpptype> { \ + static auto GetDefault(const FieldDescriptor* fd) \ + -> decltype(fd->default_value_##method()) { \ + Y_ASSERT(fd); \ + return 
fd->default_value_##method(); \ + } \ + }; + + DECLARE_CPPTYPE_DEFAULT(FieldDescriptor::CppType::CPPTYPE_INT32, int32); + DECLARE_CPPTYPE_DEFAULT(FieldDescriptor::CppType::CPPTYPE_INT64, int64); + DECLARE_CPPTYPE_DEFAULT(FieldDescriptor::CppType::CPPTYPE_UINT32, uint32); + DECLARE_CPPTYPE_DEFAULT(FieldDescriptor::CppType::CPPTYPE_UINT64, uint64); + DECLARE_CPPTYPE_DEFAULT(FieldDescriptor::CppType::CPPTYPE_FLOAT, float); + DECLARE_CPPTYPE_DEFAULT(FieldDescriptor::CppType::CPPTYPE_DOUBLE, double); + DECLARE_CPPTYPE_DEFAULT(FieldDescriptor::CppType::CPPTYPE_BOOL, bool); + DECLARE_CPPTYPE_DEFAULT(FieldDescriptor::CppType::CPPTYPE_ENUM, enum); + DECLARE_CPPTYPE_DEFAULT(FieldDescriptor::CppType::CPPTYPE_STRING, string); + +#undef DECLARE_CPPTYPE_DEFAULT + + // getters/setters of field with specified CppType + template <FieldDescriptor::CppType cpptype> + struct TCppTypeTraits : TCppTypeTraitsBase { + static const FieldDescriptor::CppType CppType = cpptype; + + struct T {}; + static T Get(const Message& msg, const FieldDescriptor* field); + static T GetRepeated(const Message& msg, const FieldDescriptor* field, int index); + static T GetDefault(const FieldDescriptor* field); + + static void Set(Message& msg, const FieldDescriptor* field, T value); + static void AddRepeated(Message& msg, const FieldDescriptor* field, T value); + static void SetRepeated(Message& msg, const FieldDescriptor* field, int index, T value); + }; + + // any type T -> CppType + template <typename T> + struct TSelectCppType { + //static const FieldDescriptor::CppType Result = FieldDescriptor::MAX_CPPTYPE; + }; + +#define DECLARE_CPPTYPE_TRAITS(cpptype, type, method) \ + template <> \ + struct TCppTypeTraits<cpptype>: public TCppTypeTraitsBase { \ + typedef type T; \ + static const FieldDescriptor::CppType CppType = cpptype; \ + \ + static inline T Get(const Message& msg, const FieldDescriptor* field) { \ + return msg.GetReflection()->Get##method(msg, field); \ + } \ + static inline T GetRepeated(const 
Message& msg, const FieldDescriptor* field, int index) { \ + return msg.GetReflection()->GetRepeated##method(msg, field, index); \ + } \ + static inline T GetDefault(const FieldDescriptor* field) { \ + return TCppTypeTraitsDefault<cpptype>::GetDefault(field); \ + } \ + static inline void Set(Message& msg, const FieldDescriptor* field, T value) { \ + msg.GetReflection()->Set##method(&msg, field, value); \ + } \ + static inline void AddRepeated(Message& msg, const FieldDescriptor* field, T value) { \ + msg.GetReflection()->Add##method(&msg, field, value); \ + } \ + static inline void SetRepeated(Message& msg, const FieldDescriptor* field, int index, T value) { \ + msg.GetReflection()->SetRepeated##method(&msg, field, index, value); \ + } \ + }; \ + template <> \ + struct TSelectCppType<type> { \ + static const FieldDescriptor::CppType Result = cpptype; \ + typedef type T; \ + }; + + DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_INT32, i32, Int32); + DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_INT64, i64, Int64); + DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_UINT32, ui32, UInt32); + DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_UINT64, ui64, UInt64); + DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_DOUBLE, double, Double); + DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_FLOAT, float, Float); + DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_BOOL, bool, Bool); + DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_ENUM, const EnumValueDescriptor*, Enum); + DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_STRING, TString, String); + //DECLARE_CPPTYPE_TRAITS(FieldDescriptor::CPPTYPE_MESSAGE, const Message&, Message); + +#undef DECLARE_CPPTYPE_TRAITS + + // specialization for message pointer + template <> + struct TCppTypeTraits<FieldDescriptor::CPPTYPE_MESSAGE>: public TCppTypeTraitsBase { + typedef const Message* T; + static const FieldDescriptor::CppType CppType = FieldDescriptor::CPPTYPE_MESSAGE; + + static inline T Get(const Message& msg, const 
FieldDescriptor* field) { + return &(msg.GetReflection()->GetMessage(msg, field)); + } + static inline T GetRepeated(const Message& msg, const FieldDescriptor* field, int index) { + return &(msg.GetReflection()->GetRepeatedMessage(msg, field, index)); + } + static inline Message* Set(Message& msg, const FieldDescriptor* field, const Message* value) { + Message* ret = msg.GetReflection()->MutableMessage(&msg, field); + ret->CopyFrom(*value); + return ret; + } + static inline Message* AddRepeated(Message& msg, const FieldDescriptor* field, const Message* value) { + Message* ret = msg.GetReflection()->AddMessage(&msg, field); + ret->CopyFrom(*value); + return ret; + } + static inline Message* SetRepeated(Message& msg, const FieldDescriptor* field, int index, const Message* value) { + Message* ret = msg.GetReflection()->MutableRepeatedMessage(&msg, field, index); + ret->CopyFrom(*value); + return ret; + } + }; + + template <> + struct TSelectCppType<const Message*> { + static const FieldDescriptor::CppType Result = FieldDescriptor::CPPTYPE_MESSAGE; + typedef const Message* T; + }; + + template <> + struct TSelectCppType<Message> { + static const FieldDescriptor::CppType Result = FieldDescriptor::CPPTYPE_MESSAGE; + typedef const Message* T; + }; + + template <FieldDescriptor::CppType CppType, bool Repeated> + struct TFieldTraits { + typedef TCppTypeTraits<CppType> TBaseTraits; + typedef typename TBaseTraits::T T; + + static inline T Get(const Message& msg, const FieldDescriptor* field, size_t index = 0) { + Y_ASSERT(index == 0); + return TBaseTraits::Get(msg, field); + } + + static inline T GetDefault(const FieldDescriptor* field) { + return TBaseTraits::GetDefault(field); + } + + static inline bool Has(const Message& msg, const FieldDescriptor* field) { + return TBaseTraits::Has(msg, field); + } + + static inline size_t Size(const Message& msg, const FieldDescriptor* field) { + return Has(msg, field); + } + + static inline void Set(Message& msg, const FieldDescriptor* 
field, T value, size_t index = 0) { + Y_ASSERT(index == 0); + TBaseTraits::Set(msg, field, value); + } + + static inline void Add(Message& msg, const FieldDescriptor* field, T value) { + TBaseTraits::Set(msg, field, value); + } + }; + + template <FieldDescriptor::CppType CppType> + struct TFieldTraits<CppType, true> { + typedef TCppTypeTraits<CppType> TBaseTraits; + typedef typename TBaseTraits::T T; + + static inline T Get(const Message& msg, const FieldDescriptor* field, size_t index = 0) { + return TBaseTraits::GetRepeated(msg, field, index); + } + + static inline T GetDefault(const FieldDescriptor* field) { + return TBaseTraits::GetDefault(field); + } + + static inline size_t Size(const Message& msg, const FieldDescriptor* field) { + return TBaseTraits::Size(msg, field); + } + + static inline bool Has(const Message& msg, const FieldDescriptor* field) { + return Size(msg, field) > 0; + } + + static inline void Set(Message& msg, const FieldDescriptor* field, T value, size_t index = 0) { + TBaseTraits::SetRepeated(msg, field, index, value); + } + + static inline void Add(Message& msg, const FieldDescriptor* field, T value) { + TBaseTraits::AddRepeated(msg, field, value); + } + }; + + // Simpler interface at the cost of checking is_repeated() on each call + template <FieldDescriptor::CppType CppType> + struct TSimpleFieldTraits { + typedef TFieldTraits<CppType, true> TRepeated; + typedef TFieldTraits<CppType, false> TSingle; + typedef typename TRepeated::T T; + + static inline size_t Size(const Message& msg, const FieldDescriptor* field) { + if (field->is_repeated()) + return TRepeated::Size(msg, field); + else + return TSingle::Size(msg, field); + } + + static inline bool Has(const Message& msg, const FieldDescriptor* field) { + if (field->is_repeated()) + return TRepeated::Has(msg, field); + else + return TSingle::Has(msg, field); + } + + static inline T Get(const Message& msg, const FieldDescriptor* field, size_t index = 0) { + Y_ASSERT(index < Size(msg, field) 
|| !field->is_repeated() && index == 0); // Get for single fields is always allowed because of default values + if (field->is_repeated()) + return TRepeated::Get(msg, field, index); + else + return TSingle::Get(msg, field, index); + } + + static inline T GetDefault(const FieldDescriptor* field) { + return TSingle::GetDefault(field); + } + + static inline void Set(Message& msg, const FieldDescriptor* field, T value, size_t index = 0) { + Y_ASSERT(!field->is_repeated() && index == 0 || index < Size(msg, field)); + if (field->is_repeated()) + TRepeated::Set(msg, field, value, index); + else + TSingle::Set(msg, field, value, index); + } + + static inline void Add(Message& msg, const FieldDescriptor* field, T value) { + if (field->is_repeated()) + TRepeated::Add(msg, field, value); + else + TSingle::Add(msg, field, value); + } + }; + + // some cpp-type groups + + template <FieldDescriptor::CppType CppType> + struct TIsIntegerCppType { + enum { + Result = CppType == FieldDescriptor::CPPTYPE_INT32 || + CppType == FieldDescriptor::CPPTYPE_INT64 || + CppType == FieldDescriptor::CPPTYPE_UINT32 || + CppType == FieldDescriptor::CPPTYPE_UINT64 + }; + }; + + template <FieldDescriptor::CppType CppType> + struct TIsFloatCppType { + enum { + Result = CppType == FieldDescriptor::CPPTYPE_FLOAT || + CppType == FieldDescriptor::CPPTYPE_DOUBLE + }; + }; + + template <FieldDescriptor::CppType CppType> + struct TIsNumericCppType { + enum { + Result = CppType == FieldDescriptor::CPPTYPE_BOOL || + TIsIntegerCppType<CppType>::Result || + TIsFloatCppType<CppType>::Result + }; + }; + + // a helper macro for splitting flow by cpp-type (e.g. 
in a switch) + +#define APPLY_TMP_MACRO_FOR_ALL_CPPTYPES() \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_INT32) \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_INT64) \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_UINT32) \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_UINT64) \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_DOUBLE) \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_FLOAT) \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_BOOL) \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_ENUM) \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_STRING) \ + TMP_MACRO_FOR_CPPTYPE(NProtoBuf::FieldDescriptor::CPPTYPE_MESSAGE) +} diff --git a/library/cpp/protobuf/util/ut/common_ut.proto b/library/cpp/protobuf/util/ut/common_ut.proto new file mode 100644 index 0000000000..9cf803ffbf --- /dev/null +++ b/library/cpp/protobuf/util/ut/common_ut.proto @@ -0,0 +1,72 @@ +import "google/protobuf/descriptor.proto"; +import "library/cpp/protobuf/util/proto/merge.proto"; + +package NProtobufUtilUt; + +extend google.protobuf.FieldOptions { + optional bool XXX = 53772; +} + +message TWalkTest { + optional uint32 OptInt = 1 [(XXX)=true]; + repeated uint32 RepInt = 2; + + optional string OptStr = 3; + repeated string RepStr = 4 [(XXX)=true]; + + optional TWalkTest OptSub = 5 [(XXX)=true]; + repeated TWalkTest RepSub = 6; +} + +message TWalkTestCyclic { + optional TNested OptNested = 1; + repeated uint64 OptInt64 = 2; + optional TWalkTestCyclic OptSub = 3; + optional TEnum OptEnum = 4; + + message TNested { + optional uint32 OptInt32 = 1; + optional TWalkTestCyclic OptSubNested = 2; + repeated string RepStr = 3; + optional TNested OptNested = 4; + } + enum TEnum { + A = 0; + B = 1; + C = 2; + } +} + +message TMergeTestNoMerge { + option (DontMerge) = true; + + optional uint32 A = 1; + repeated uint32 B = 2; +} + +message TMergeTestMerge { + optional uint32 A = 1; + 
repeated uint32 B = 2; + repeated uint32 C = 3 [(DontMergeField)=true]; +} + +message TMergeTest { + repeated uint32 MergeInt = 1; + repeated uint32 NoMergeInt = 2 [(DontMergeField)=true]; + + optional TMergeTestMerge MergeSub = 3; + repeated TMergeTestMerge NoMergeRepSub = 4 [(DontMergeField)=true]; + optional TMergeTestNoMerge NoMergeOptSub = 5; +} + +message TTextTest { + optional uint32 Foo = 1; +} + +message TTextEnumTest { + enum EnumTest { + EET_SLOT_1 = 1; + EET_SLOT_2 = 2; + } + optional EnumTest Slot = 1; +} diff --git a/library/cpp/protobuf/util/ut/extensions.proto b/library/cpp/protobuf/util/ut/extensions.proto new file mode 100644 index 0000000000..4944f0f5ca --- /dev/null +++ b/library/cpp/protobuf/util/ut/extensions.proto @@ -0,0 +1,22 @@ +package NExt; + +import "library/cpp/protobuf/util/ut/sample_for_simple_reflection.proto"; + +message TTestExt { + extend TSample { + optional string ExtField = 100; + } +} + +extend TSample { + optional uint64 ExtField = 150; // the same name, but another full name +} + +extend TSample { + repeated uint64 Ext2Field = 105; + optional TInnerSample SubMsgExt = 111; +} + +extend TInnerSample { + optional uint64 Ext3Field = 100; +} diff --git a/library/cpp/protobuf/util/ut/sample_for_is_equal.proto b/library/cpp/protobuf/util/ut/sample_for_is_equal.proto new file mode 100644 index 0000000000..a91c16deaa --- /dev/null +++ b/library/cpp/protobuf/util/ut/sample_for_is_equal.proto @@ -0,0 +1,8 @@ +message TInner { + optional string Brbrbr = 3; +} + +message TSampleForIsEqual { + optional string Name = 1; + optional TInner Inner = 5; +} diff --git a/library/cpp/protobuf/util/ut/sample_for_simple_reflection.proto b/library/cpp/protobuf/util/ut/sample_for_simple_reflection.proto new file mode 100644 index 0000000000..cca1dd869a --- /dev/null +++ b/library/cpp/protobuf/util/ut/sample_for_simple_reflection.proto @@ -0,0 +1,25 @@ +message TInnerSample { + repeated int32 RepInt = 1; + + extensions 100 to 199; +} + +message 
TSample { + optional string OneStr = 1; + optional TInnerSample OneMsg = 2; + repeated TInnerSample RepMsg = 3; + repeated string RepStr = 4; + optional string AnotherOneStr = 5; + + optional int32 OneInt = 6; + repeated int32 RepInt = 7; + + enum EEnum { + V1 = 1; + V2 = 2; + } + optional EEnum OneEnum = 8; + repeated EEnum RepEnum = 9; + + extensions 100 to 199; +} diff --git a/library/cpp/protobuf/util/ut/ya.make b/library/cpp/protobuf/util/ut/ya.make new file mode 100644 index 0000000000..701ba9a8c8 --- /dev/null +++ b/library/cpp/protobuf/util/ut/ya.make @@ -0,0 +1,19 @@ +OWNER(nga) + +UNITTEST_FOR(library/cpp/protobuf/util) + +SRCS( + extensions.proto + sample_for_is_equal.proto + sample_for_simple_reflection.proto + common_ut.proto + pb_io_ut.cpp + is_equal_ut.cpp + iterators_ut.cpp + simple_reflection_ut.cpp + repeated_field_utils_ut.cpp + walk_ut.cpp + merge_ut.cpp +) + +END() diff --git a/library/cpp/protobuf/util/walk.cpp b/library/cpp/protobuf/util/walk.cpp new file mode 100644 index 0000000000..b65ec03e04 --- /dev/null +++ b/library/cpp/protobuf/util/walk.cpp @@ -0,0 +1,72 @@ +#include "walk.h" + +#include <util/generic/hash_set.h> + +namespace { + using namespace NProtoBuf; + + template <typename TMessage, typename TOnField> + void DoWalkReflection(TMessage& msg, TOnField& onField) { + const Descriptor* descr = msg.GetDescriptor(); + for (int i1 = 0; i1 < descr->field_count(); ++i1) { + const FieldDescriptor* fd = descr->field(i1); + if (!onField(msg, fd)) { + continue; + } + + std::conditional_t<std::is_const_v<TMessage>, TConstField, TMutableField> ff(msg, fd); + if (ff.IsMessage()) { + for (size_t i2 = 0; i2 < ff.Size(); ++i2) { + if constexpr (std::is_const_v<TMessage>) { + WalkReflection(*ff.template Get<Message>(i2), onField); + } else { + WalkReflection(*ff.MutableMessage(i2), onField); + } + } + } + } + } + + void DoWalkSchema(const Descriptor* descriptor, + std::function<bool(const FieldDescriptor*)>& onField, + THashSet<const Descriptor*>& 
visited) + { + if (!visited.emplace(descriptor).second) { + return; + } + for (int i1 = 0; i1 < descriptor->field_count(); ++i1) { + const FieldDescriptor* fd = descriptor->field(i1); + if (!onField(fd)) { + continue; + } + + if (fd->type() == FieldDescriptor::Type::TYPE_MESSAGE) { + DoWalkSchema(fd->message_type(), onField, visited); + } + } + visited.erase(descriptor); + } + +} + +namespace NProtoBuf { + void WalkReflection(Message& msg, + std::function<bool(Message&, const FieldDescriptor*)> onField) + { + DoWalkReflection(msg, onField); + } + + void WalkReflection(const Message& msg, + std::function<bool(const Message&, const FieldDescriptor*)> onField) + { + DoWalkReflection(msg, onField); + } + + void WalkSchema(const Descriptor* descriptor, + std::function<bool(const FieldDescriptor*)> onField) + { + THashSet<const Descriptor*> visited; + DoWalkSchema(descriptor, onField, visited); + } + +} // namespace NProtoBuf diff --git a/library/cpp/protobuf/util/walk.h b/library/cpp/protobuf/util/walk.h new file mode 100644 index 0000000000..d15d76562d --- /dev/null +++ b/library/cpp/protobuf/util/walk.h @@ -0,0 +1,33 @@ +#pragma once + +#include "simple_reflection.h" + +#include <google/protobuf/message.h> +#include <google/protobuf/descriptor.h> + +#include <functional> + +namespace NProtoBuf { + // Apply @onField processor to each field in @msg (even empty) + // Do not walk deeper the field if the field is an empty message + // Returned bool defines if we should walk down deeper to current node children (true), or not (false) + void WalkReflection(Message& msg, + std::function<bool(Message&, const FieldDescriptor*)> onField); + void WalkReflection(const Message& msg, + std::function<bool(const Message&, const FieldDescriptor*)> onField); + + template <typename TOnField> + inline void WalkReflection(Message& msg, TOnField& onField) { // is used when TOnField is a callable class instance + WalkReflection(msg, std::function<bool(Message&, const 
FieldDescriptor*)>(std::ref(onField))); + } + template <typename TOnField> + inline void WalkReflection(const Message& msg, TOnField& onField) { + WalkReflection(msg, std::function<bool(const Message&, const FieldDescriptor*)>(std::ref(onField))); + } + + // Apply @onField processor to each descriptor of a field + // Walk every field including nested messages. Avoid cyclic fields pointing to themselves + // Returned bool defines if we should walk down deeper to current node children (true), or not (false) + void WalkSchema(const Descriptor* descriptor, + std::function<bool(const FieldDescriptor*)> onField); +} diff --git a/library/cpp/protobuf/util/walk_ut.cpp b/library/cpp/protobuf/util/walk_ut.cpp new file mode 100644 index 0000000000..2ea6071b17 --- /dev/null +++ b/library/cpp/protobuf/util/walk_ut.cpp @@ -0,0 +1,158 @@ +#include "walk.h" +#include "simple_reflection.h" +#include <library/cpp/protobuf/util/ut/common_ut.pb.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NProtoBuf; + +Y_UNIT_TEST_SUITE(ProtobufWalk) { + static void InitProto(NProtobufUtilUt::TWalkTest & p, int level = 0) { + p.SetOptInt(1); + p.AddRepInt(2); + p.AddRepInt(3); + + p.SetOptStr("123"); + p.AddRepStr("*"); + p.AddRepStr("abcdef"); + p.AddRepStr("1234"); + + if (level == 0) { + InitProto(*p.MutableOptSub(), 1); + InitProto(*p.AddRepSub(), 1); + InitProto(*p.AddRepSub(), 1); + } + } + + static bool IncreaseInts(Message & msg, const FieldDescriptor* fd) { + TMutableField f(msg, fd); + if (f.IsInstance<ui32>()) { + for (size_t i = 0; i < f.Size(); ++i) + f.Set(f.Get<ui64>(i) + 1, i); // ui64 should be ok! 
+ } + return true; + } + + static bool RepeatString1(Message & msg, const FieldDescriptor* fd) { + TMutableField f(msg, fd); + if (f.IsString()) { + for (size_t i = 0; i < f.Size(); ++i) + if (f.Get<TString>(i).StartsWith('1')) + f.Set(f.Get<TString>(i) + f.Get<TString>(i), i); + } + return true; + } + + static bool ClearXXX(Message & msg, const FieldDescriptor* fd) { + const FieldOptions& opt = fd->options(); + if (opt.HasExtension(NProtobufUtilUt::XXX) && opt.GetExtension(NProtobufUtilUt::XXX)) + TMutableField(msg, fd).Clear(); + + return true; + } + + struct TestStruct { + bool Ok = false; + + TestStruct() = default; + bool operator()(Message&, const FieldDescriptor*) { + Ok = true; + return false; + } + }; + + Y_UNIT_TEST(TestWalkRefl) { + NProtobufUtilUt::TWalkTest p; + InitProto(p); + + { + UNIT_ASSERT_EQUAL(p.GetOptInt(), 1); + UNIT_ASSERT_EQUAL(p.RepIntSize(), 2); + UNIT_ASSERT_EQUAL(p.GetRepInt(0), 2); + UNIT_ASSERT_EQUAL(p.GetRepInt(1), 3); + + WalkReflection(p, IncreaseInts); + + UNIT_ASSERT_EQUAL(p.GetOptInt(), 2); + UNIT_ASSERT_EQUAL(p.RepIntSize(), 2); + UNIT_ASSERT_EQUAL(p.GetRepInt(0), 3); + UNIT_ASSERT_EQUAL(p.GetRepInt(1), 4); + + UNIT_ASSERT_EQUAL(p.GetOptSub().GetOptInt(), 2); + UNIT_ASSERT_EQUAL(p.GetOptSub().RepIntSize(), 2); + UNIT_ASSERT_EQUAL(p.GetOptSub().GetRepInt(0), 3); + UNIT_ASSERT_EQUAL(p.GetOptSub().GetRepInt(1), 4); + + UNIT_ASSERT_EQUAL(p.RepSubSize(), 2); + UNIT_ASSERT_EQUAL(p.GetRepSub(1).GetOptInt(), 2); + UNIT_ASSERT_EQUAL(p.GetRepSub(1).RepIntSize(), 2); + UNIT_ASSERT_EQUAL(p.GetRepSub(1).GetRepInt(0), 3); + UNIT_ASSERT_EQUAL(p.GetRepSub(1).GetRepInt(1), 4); + } + { + UNIT_ASSERT_EQUAL(p.GetOptStr(), "123"); + UNIT_ASSERT_EQUAL(p.GetRepStr(2), "1234"); + + WalkReflection(p, RepeatString1); + + UNIT_ASSERT_EQUAL(p.GetOptStr(), "123123"); + UNIT_ASSERT_EQUAL(p.RepStrSize(), 3); + UNIT_ASSERT_EQUAL(p.GetRepStr(0), "*"); + UNIT_ASSERT_EQUAL(p.GetRepStr(1), "abcdef"); + UNIT_ASSERT_EQUAL(p.GetRepStr(2), "12341234"); + + 
UNIT_ASSERT_EQUAL(p.RepSubSize(), 2); + UNIT_ASSERT_EQUAL(p.GetRepSub(0).GetOptStr(), "123123"); + UNIT_ASSERT_EQUAL(p.GetRepSub(0).RepStrSize(), 3); + UNIT_ASSERT_EQUAL(p.GetRepSub(0).GetRepStr(0), "*"); + UNIT_ASSERT_EQUAL(p.GetRepSub(0).GetRepStr(1), "abcdef"); + UNIT_ASSERT_EQUAL(p.GetRepSub(0).GetRepStr(2), "12341234"); + } + { + UNIT_ASSERT(p.HasOptInt()); + UNIT_ASSERT(p.RepStrSize() == 3); + UNIT_ASSERT(p.HasOptSub()); + + WalkReflection(p, ClearXXX); + + UNIT_ASSERT(!p.HasOptInt()); + UNIT_ASSERT(p.RepIntSize() == 2); + UNIT_ASSERT(p.HasOptStr()); + UNIT_ASSERT(p.RepStrSize() == 0); + UNIT_ASSERT(!p.HasOptSub()); + UNIT_ASSERT(p.RepSubSize() == 2); + } + } + + Y_UNIT_TEST(TestMutableCallable) { + TestStruct testStruct; + NProtobufUtilUt::TWalkTest p; + InitProto(p); + + WalkReflection(p, testStruct); + UNIT_ASSERT(testStruct.Ok); + } + + Y_UNIT_TEST(TestWalkDescr) { + NProtobufUtilUt::TWalkTestCyclic p; + + TStringBuilder printedSchema; + auto func = [&](const FieldDescriptor* desc) mutable { + printedSchema << desc->DebugString(); + return true; + }; + WalkSchema(p.GetDescriptor(), func); + + TString schema = + "optional .NProtobufUtilUt.TWalkTestCyclic.TNested OptNested = 1;\n" + "optional uint32 OptInt32 = 1;\n" + "optional .NProtobufUtilUt.TWalkTestCyclic OptSubNested = 2;\n" + "repeated string RepStr = 3;\n" + "optional .NProtobufUtilUt.TWalkTestCyclic.TNested OptNested = 4;\n" + "repeated uint64 OptInt64 = 2;\n" + "optional .NProtobufUtilUt.TWalkTestCyclic OptSub = 3;\n" + "optional .NProtobufUtilUt.TWalkTestCyclic.TEnum OptEnum = 4;\n"; + + UNIT_ASSERT_STRINGS_EQUAL(printedSchema, schema); + } +} diff --git a/library/cpp/protobuf/util/ya.make b/library/cpp/protobuf/util/ya.make new file mode 100644 index 0000000000..b62028af58 --- /dev/null +++ b/library/cpp/protobuf/util/ya.make @@ -0,0 +1,26 @@ +LIBRARY() + +OWNER(mowgli) + +PEERDIR( + contrib/libs/protobuf + library/cpp/binsaver + library/cpp/protobuf/util/proto + library/cpp/string_utils/base64 
+) + +SRCS( + is_equal.cpp + iterators.h + merge.cpp + path.cpp + pb_io.cpp + pb_utils.h + repeated_field_utils.h + simple_reflection.cpp + walk.cpp +) + +END() + +RECURSE_FOR_TESTS(ut) diff --git a/library/cpp/protobuf/ya.make b/library/cpp/protobuf/ya.make new file mode 100644 index 0000000000..618b542b4f --- /dev/null +++ b/library/cpp/protobuf/ya.make @@ -0,0 +1,19 @@ +RECURSE( + dynamic_prototype + from_xml + from_xml/ut + interop + interop/ut + json + json/ut + parser + parser/ut + protofile + protofile/ut + util + util/proto + yql + yql/ut + yandex_patches_ut + yt +) |