#include "pb_io.h"
#include <library/cpp/binsaver/bin_saver.h>
#include <library/cpp/string_utils/base64/base64.h>
#include <google/protobuf/io/tokenizer.h>
#include <google/protobuf/message.h>
#include <google/protobuf/messagext.h>
#include <google/protobuf/text_format.h>
#include <util/generic/string.h>
#include <util/stream/file.h>
#include <util/stream/str.h>
#include <util/string/cast.h>
namespace NProtoBuf {
class TEnumIdValuePrinter : public google::protobuf::TextFormat::FastFieldValuePrinter {
public:
void PrintEnum(int32 val, const TString& /*name*/, google::protobuf::TextFormat::BaseTextGenerator* generator) const override {
generator->PrintString(ToString(val));
}
};
void ParseFromBase64String(const TStringBuf dataBase64, Message& m, bool allowUneven) {
if (!m.ParseFromString(allowUneven ? Base64DecodeUneven(dataBase64) : Base64StrictDecode(dataBase64))) {
ythrow yexception() << "can't parse " << m.GetTypeName() << " from base64-encoded string";
}
}
bool TryParseFromBase64String(const TStringBuf dataBase64, Message& m, bool allowUneven) {
try {
ParseFromBase64String(dataBase64, m, allowUneven);
return true;
} catch (const std::exception&) {
return false;
}
}
void SerializeToBase64String(const Message& m, TString& dataBase64) {
TString rawData;
if (!m.SerializeToString(&rawData)) {
ythrow yexception() << "can't serialize " << m.GetTypeName();
}
Base64EncodeUrl(rawData, dataBase64);
}
TString SerializeToBase64String(const Message& m) {
TString s;
SerializeToBase64String(m, s);
return s;
}
bool TrySerializeToBase64String(const Message& m, TString& dataBase64) {
try {
SerializeToBase64String(m, dataBase64);
return true;
} catch (const std::exception&) {
return false;
}
}
const TString ShortUtf8DebugString(const Message& message) {
TextFormat::Printer printer;
printer.SetSingleLineMode(true);
printer.SetUseUtf8StringEscaping(true);
TString result;
printer.PrintToString(message, &result);
return result;
}
bool MergePartialFromString(NProtoBuf::Message& m, const TStringBuf serializedProtoMessage) {
google::protobuf::io::CodedInputStream input(reinterpret_cast<const ui8*>(serializedProtoMessage.data()), serializedProtoMessage.size());
bool ok = m.MergePartialFromCodedStream(&input);
ok = ok && input.ConsumedEntireMessage();
return ok;
}
bool MergeFromString(NProtoBuf::Message& m, const TStringBuf serializedProtoMessage) {
return MergePartialFromString(m, serializedProtoMessage) && m.IsInitialized();
}
} // end of namespace NProtoBuf
namespace {
class TErrorCollector: public NProtoBuf::io::ErrorCollector {
public:
TErrorCollector(const NProtoBuf::Message& m, IOutputStream* errorOut, IOutputStream* warningOut)
: TypeName_(m.GetTypeName())
{
ErrorOut_ = errorOut ? errorOut : &Cerr;
WarningOut_ = warningOut ? warningOut : &Cerr;
}
void AddError(int line, int column, const TProtoStringType& message) override {
PrintErrorMessage(ErrorOut_, "Error", line, column, message);
}
void AddWarning(int line, int column, const TProtoStringType& message) override {
PrintErrorMessage(WarningOut_, "Warning", line, column, message);
}
private:
void PrintErrorMessage(IOutputStream* out, TStringBuf errorLevel, int line, int column, const TProtoStringType& message) {
(*out) << errorLevel << " parsing text-format ";
if (line >= 0) {
(*out) << TypeName_ << ": " << (line + 1) << ":" << (column + 1) << ": " << message;
} else {
(*out) << TypeName_ << ": " << message;
}
out->Flush();
}
private:
const TProtoStringType TypeName_;
IOutputStream* ErrorOut_;
IOutputStream* WarningOut_;
};
} // end of anonymous namespace
int operator&(NProtoBuf::Message& m, IBinSaver& f) {
TStringStream ss;
if (f.IsReading()) {
f.Add(0, &ss.Str());
m.ParseFromArcadiaStream(&ss);
} else {
m.SerializeToArcadiaStream(&ss);
f.Add(0, &ss.Str());
}
return 0;
}
void SerializeToTextFormat(const NProtoBuf::Message& m, IOutputStream& out) {
NProtoBuf::io::TCopyingOutputStreamAdaptor adaptor(&out);
if (!NProtoBuf::TextFormat::Print(m, &adaptor)) {
ythrow yexception() << "SerializeToTextFormat failed on Print";
}
}
void SerializeToTextFormat(const NProtoBuf::Message& m, const TString& fileName) {
/* TUnbufferedFileOutput is unbuffered, but TCopyingOutputStreamAdaptor adds
* a buffer on top of it. */
TUnbufferedFileOutput stream(fileName);
SerializeToTextFormat(m, stream);
}
void SerializeToTextFormatWithEnumId(const NProtoBuf::Message& m, IOutputStream& out) {
google::protobuf::TextFormat::Printer printer;
printer.SetDefaultFieldValuePrinter(new NProtoBuf::TEnumIdValuePrinter());
NProtoBuf::io::TCopyingOutputStreamAdaptor adaptor(&out);
if (!printer.Print(m, &adaptor)) {
ythrow yexception() << "SerializeToTextFormatWithEnumId failed on Print";
}
}
void SerializeToTextFormatPretty(const NProtoBuf::Message& m, IOutputStream& out) {
google::protobuf::TextFormat::Printer printer;
printer.SetUseUtf8StringEscaping(true);
printer.SetUseShortRepeatedPrimitives(true);
NProtoBuf::io::TCopyingOutputStreamAdaptor adaptor(&out);
if (!printer.Print(m, &adaptor)) {
ythrow yexception() << "SerializeToTextFormatPretty failed on Print";
}
}
static void ConfigureParser(const EParseFromTextFormatOptions options,
NProtoBuf::TextFormat::Parser& p) {
if (options & EParseFromTextFormatOption::AllowUnknownField) {
p.AllowUnknownField(true);
}
}
void ParseFromTextFormat(IInputStream& in, NProtoBuf::Message& m,
const EParseFromTextFormatOptions options, IOutputStream* warningStream) {
NProtoBuf::io::TCopyingInputStreamAdaptor adaptor(&in);
NProtoBuf::TextFormat::Parser p;
ConfigureParser(options, p);
TStringStream errorLog;
THolder<TErrorCollector> errorCollector;
errorCollector = MakeHolder<TErrorCollector>(m, &errorLog, warningStream);
p.RecordErrorsTo(errorCollector.Get());
if (!p.Parse(&adaptor, &m)) {
// remove everything that may have been read
m.Clear();
ythrow yexception() << errorLog.Str();
}
}
void ParseFromTextFormat(const TString& fileName, NProtoBuf::Message& m,
const EParseFromTextFormatOptions options, IOutputStream* warningStream) {
/* TUnbufferedFileInput is unbuffered, but TCopyingInputStreamAdaptor adds
* a buffer on top of it. */
TUnbufferedFileInput stream(fileName);
ParseFromTextFormat(stream, m, options, warningStream);
}
bool TryParseFromTextFormat(const TString& fileName, NProtoBuf::Message& m,
const EParseFromTextFormatOptions options, IOutputStream* warningStream) {
try {
ParseFromTextFormat(fileName, m, options, warningStream);
} catch (std::exception&) {
return false;
}
return true;
}
bool TryParseFromTextFormat(IInputStream& in, NProtoBuf::Message& m,
const EParseFromTextFormatOptions options, IOutputStream* warningStream) {
try {
ParseFromTextFormat(in, m, options, warningStream);
} catch (std::exception&) {
return false;
}
return true;
}
void MergeFromTextFormat(IInputStream& in, NProtoBuf::Message& m,
const EParseFromTextFormatOptions options) {
NProtoBuf::io::TCopyingInputStreamAdaptor adaptor(&in);
NProtoBuf::TextFormat::Parser p;
ConfigureParser(options, p);
if (!p.Merge(&adaptor, &m)) {
ythrow yexception() << "MergeFromTextFormat failed on Merge for " << m.GetTypeName();
}
}
void MergeFromTextFormat(const TString& fileName, NProtoBuf::Message& m,
const EParseFromTextFormatOptions options) {
/* TUnbufferedFileInput is unbuffered, but TCopyingInputStreamAdaptor adds
* a buffer on top of it. */
TUnbufferedFileInput stream(fileName);
MergeFromTextFormat(stream, m, options);
}
bool TryMergeFromTextFormat(const TString& fileName, NProtoBuf::Message& m,
const EParseFromTextFormatOptions options) {
try {
MergeFromTextFormat(fileName, m, options);
} catch (std::exception&) {
return false;
}
return true;
}
bool TryMergeFromTextFormat(IInputStream& in, NProtoBuf::Message& m,
const EParseFromTextFormatOptions options) {
try {
MergeFromTextFormat(in, m, options);
} catch (std::exception&) {
return false;
}
return true;
}