aboutsummaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
author qrort <qrort@yandex-team.com> 2022-11-30 23:47:12 +0300
committer qrort <qrort@yandex-team.com> 2022-11-30 23:47:12 +0300
commit 22f8ae0e3f5d68b92aecccdf96c1d841a0334311 (patch)
tree bffa27765faf54126ad44bcafa89fadecb7a73d7 /tools
parent 332b99e2173f0425444abb759eebcb2fafaa9209 (diff)
download ydb-22f8ae0e3f5d68b92aecccdf96c1d841a0334311.tar.gz
validate canons without yatest_common
Diffstat (limited to 'tools')
-rw-r--r-- tools/event2cpp/proto_events.cpp 893
-rw-r--r-- tools/event2cpp/proto_events.h 20
-rw-r--r-- tools/go_test_miner/main.go 168
-rw-r--r-- tools/triecompiler/lib/main.cpp 473
-rw-r--r-- tools/triecompiler/lib/main.h 5
-rw-r--r-- tools/triecompiler/main.cpp 5
6 files changed, 1564 insertions, 0 deletions
diff --git a/tools/event2cpp/proto_events.cpp b/tools/event2cpp/proto_events.cpp
new file mode 100644
index 0000000000..66e9296d2c
--- /dev/null
+++ b/tools/event2cpp/proto_events.cpp
@@ -0,0 +1,893 @@
+#include <google/protobuf/compiler/cpp/cpp_helpers.h>
+#include <google/protobuf/io/zero_copy_stream.h>
+#include <google/protobuf/io/printer.h>
+#include <google/protobuf/stubs/strutil.h>
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/descriptor.pb.h>
+
+#include <util/string/cast.h>
+#include <util/generic/singleton.h>
+#include <util/generic/yexception.h>
+
+#include <library/cpp/eventlog/proto/events_extension.pb.h>
+
+#include <exception>
+
+#include "proto_events.h"
+
+namespace NProtoBuf::NCompiler::NPlugins {
+
+namespace NInternal {
+ using namespace google::protobuf;
+ using namespace google::protobuf::compiler;
+ using namespace google::protobuf::compiler::cpp;
+
+ // Substitution table handed to io::Printer::Print: placeholder name -> replacement text.
+ using TVariables = std::map<TProtoStringType, TProtoStringType>;
+
+ // Records that message `name` uses event id `id` and rejects duplicates.
+ //
+ // The id -> name registry is obtained through Singleton<>(), so every call
+ // made during this plugin run checks against the same table.
+ //
+ // Throws yexception naming both messages when `id` is already registered.
+ // The text carries no trailing newline: main() ends the line itself when it
+ // prints what().
+ void CheckMessageId(size_t id, const TProtoStringType& name) {
+     using TMessageIds = std::map<size_t, TProtoStringType>;
+     TMessageIds* ids = Singleton<TMessageIds>();
+
+     // A single lookup: emplace either registers the id or returns the entry
+     // that already owns it.
+     const auto [it, inserted] = ids->emplace(id, name);
+     if (!inserted) {
+         throw yexception() << "Duplicate message_id = " << id
+                            << " in messages " << name
+                            << " and " << it->second;
+     }
+ }
+
+ // Seeds `variables` with the two entries every field generator relies on:
+ //   "rname" - descriptor->name();
+ //   "name"  - FieldName(descriptor).
+ void SetCommonFieldVariables(const FieldDescriptor* descriptor, TVariables* variables) {
+     TVariables& table = *variables;
+     table["name"] = FieldName(descriptor);
+     table["rname"] = descriptor->name();
+ }
+
+ // Name of the generated header for `file`: cpp::StripProto(file->name())
+ // followed by ".pb.h".
+ TProtoStringType HeaderFileName(const FileDescriptor* file) {
+     TProtoStringType result = cpp::StripProto(file->name());
+     result.append(".pb.h");
+     return result;
+ }
+
+ // Name of the generated source for `file`: cpp::StripProto(file->name())
+ // followed by ".pb.cc".
+ TProtoStringType SourceFileName(const FileDescriptor* file) {
+     TProtoStringType result = cpp::StripProto(file->name());
+     result.append(".pb.cc");
+     return result;
+ }
+
+ // Emits C++ that prints a repeated field as "[item,item,...]".
+ //
+ // vars          - substitutions; must define $repeated_field_type$, $type$
+ //                 and $name$. Taken by value because "obj" is set locally.
+ // printTemplate - statement (without the trailing ';') printing one element,
+ //                 which it refers to as $obj$.
+ // printer       - receives the generated code.
+ void GeneratePrintingCycle(TVariables vars, TProtoStringType printTemplate, io::Printer* printer) {
+     const char* const itemStatement = printTemplate.c_str();
+
+     // The generated code lives in its own scope so iterators `b` and `e`
+     // do not clash between fields.
+     printer->Print("\n{\n");
+     printer->Indent();
+     printer->Print(vars,
+         "NProtoBuf::$repeated_field_type$< $type$ >::const_iterator b = $name$().begin();\n"
+         "NProtoBuf::$repeated_field_type$< $type$ >::const_iterator e = $name$().end();\n\n");
+     printer->Print("output << \"[\";\n");
+
+     // First element: printed without a leading comma, advancing `b`.
+     printer->Print("if (b != e) {\n");
+     vars["obj"] = "(*b++)";
+     printer->Print(vars, itemStatement);
+     printer->Print(";\n");
+
+     // Remaining elements: each one is preceded by a comma.
+     printer->Print(vars,
+         "for (NProtoBuf::$repeated_field_type$< $type$ >::const_iterator it = b; it != e; ++it) {\n");
+     printer->Indent();
+     printer->Print("output << \",\";\n");
+     vars["obj"] = "(*it)";
+     printer->Print(vars, itemStatement);
+     printer->Print(";\n");
+     printer->Outdent();
+     printer->Print("}\n}\n");
+
+     printer->Print("output << \"]\";\n");
+     printer->Outdent();
+     printer->Print("}\n");
+ }
+
+ // Base class of the per-field generators. For one proto field a subclass
+ // emits a constructor argument, the statement storing that argument, and the
+ // code printing the field.
+ class TFieldExtGenerator {
+ public:
+     TFieldExtGenerator(const FieldDescriptor* field)
+         : Descriptor_(field)
+     {
+         // Every generator starts with the "name"/"rname" substitutions.
+         SetCommonFieldVariables(Descriptor_, &Variables_);
+     }
+
+     virtual ~TFieldExtGenerator() = default;
+
+     // True when the generated printing code refers to
+     // `protobufMessageFieldPrinter`; false unless a subclass overrides it.
+     virtual bool NeedProtobufMessageFieldPrinter() const {
+         return false;
+     }
+
+     // Emits the declaration of the argument that carries this field.
+     virtual void GenerateCtorArgument(io::Printer* printer) = 0;
+     // Emits the statement storing the argument; `prefix` is prepended to the
+     // setter call (e.g. "result." or "").
+     virtual void GenerateInitializer(io::Printer* printer, const TString& prefix) = 0;
+     // Emits the statements that print the field.
+     virtual void GeneratePrintingCode(io::Printer* printer) = 0;
+
+ protected:
+     const FieldDescriptor* Descriptor_;
+     TVariables Variables_;
+ };
+
+ // Generator for a non-repeated, message-typed field.
+ class TMessageFieldExtGenerator: public TFieldExtGenerator {
+ public:
+     TMessageFieldExtGenerator(const FieldDescriptor* field)
+         : TFieldExtGenerator(field)
+     {
+         const auto* messageType = Descriptor_->message_type();
+         const bool hasMessageId = messageType->options().HasExtension(message_id);
+
+         Variables_["type"] = ClassName(messageType, true);
+         // Passed on as a template argument of PrintProtobufMessageFieldToOutput.
+         Variables_["has_print_function"] = hasMessageId ? "true" : "false";
+     }
+
+     bool NeedProtobufMessageFieldPrinter() const override {
+         return true;
+     }
+
+     void GenerateCtorArgument(io::Printer* printer) override {
+         printer->Print(Variables_, "const $type$& arg_$name$");
+     }
+
+     void GenerateInitializer(io::Printer* printer, const TString& prefix) override {
+         Variables_["prefix"] = prefix;
+         printer->Print(Variables_, "$prefix$mutable_$name$()->CopyFrom(arg_$name$);\n");
+     }
+
+     // The nested message is printed between braces.
+     void GeneratePrintingCode(io::Printer* printer) override {
+         printer->Print("output << \"{\";\n");
+         printer->Print(Variables_,
+             "protobufMessageFieldPrinter.PrintProtobufMessageFieldToOutput<$type$, $has_print_function$>($name$(), escapedOutput);\n");
+         printer->Print("output << \"}\";\n");
+     }
+ };
+
+ // Generator for map fields. Every hook is a no-op: map fields get no
+ // constructor argument, no initializer and no printing code.
+ class TMapFieldExtGenerator: public TFieldExtGenerator {
+ public:
+     TMapFieldExtGenerator(const FieldDescriptor* field)
+         : TFieldExtGenerator(field)
+     {
+     }
+
+     void GenerateCtorArgument(io::Printer* /* printer */) override {}
+
+     void GenerateInitializer(io::Printer* /* printer */, const TString& /* prefix */) override {}
+
+     void GeneratePrintingCode(io::Printer* /* printer */) override {}
+ };
+
+ // Generator for a repeated, message-typed field. The generated argument is
+ // a single message that is appended to the field.
+ class TRepeatedMessageFieldExtGenerator: public TFieldExtGenerator {
+ public:
+     TRepeatedMessageFieldExtGenerator(const FieldDescriptor* field)
+         : TFieldExtGenerator(field)
+     {
+         const auto* messageType = Descriptor_->message_type();
+         const bool hasMessageId = messageType->options().HasExtension(message_id);
+
+         Variables_["type"] = ClassName(messageType, true);
+         Variables_["repeated_field_type"] = "RepeatedPtrField";
+         Variables_["has_print_function"] = hasMessageId ? "true" : "false";
+     }
+
+     bool NeedProtobufMessageFieldPrinter() const override {
+         return true;
+     }
+
+     void GenerateCtorArgument(io::Printer* printer) override {
+         printer->Print(Variables_, "const $type$& arg_$name$");
+     }
+
+     void GenerateInitializer(io::Printer* printer, const TString& prefix) override {
+         Variables_["prefix"] = prefix;
+         printer->Print(Variables_, "$prefix$add_$name$()->CopyFrom(arg_$name$);\n");
+     }
+
+     void GeneratePrintingCode(io::Printer* printer) override {
+         GeneratePrintingCycle(
+             Variables_,
+             "protobufMessageFieldPrinter.PrintProtobufMessageFieldToOutput<$type$, $has_print_function$>($obj$, escapedOutput)",
+             printer);
+     }
+ };
+
+ // Generator for a non-repeated string or bytes field. Bytes fields are
+ // passed as a (pointer, size) pair, strings as a const reference.
+ class TStringFieldExtGenerator: public TFieldExtGenerator {
+ public:
+     TStringFieldExtGenerator(const FieldDescriptor* field)
+         : TFieldExtGenerator(field)
+     {
+         Variables_["pointer_type"] = IsBytes() ? "void" : "char";
+         Variables_["type"] = "TProtoStringType";
+     }
+
+     void GenerateCtorArgument(io::Printer* printer) override {
+         const char* const argument = IsBytes()
+             ? "const $pointer_type$* arg_$name$, size_t arg_$name$_size"
+             : "const $type$& arg_$name$";
+         printer->Print(Variables_, argument);
+     }
+
+     void GenerateInitializer(io::Printer* printer, const TString& prefix) override {
+         Variables_["prefix"] = prefix;
+         const char* const statement = IsBytes()
+             ? "$prefix$set_$name$(arg_$name$, arg_$name$_size);\n"
+             : "$prefix$set_$name$(arg_$name$);\n";
+         printer->Print(Variables_, statement);
+     }
+
+     // Honours the `repr` field option: as_base64 selects PrintAsBase64, any
+     // other value (or no option) writes the value through escapedOutput.
+     void GeneratePrintingCode(io::Printer* printer) override {
+         const auto& options = Descriptor_->options();
+         const Repr::ReprType fmt = options.HasExtension(repr) ? options.GetExtension(repr) : Repr::none;
+
+         if (fmt == Repr::as_base64) {
+             printer->Print(Variables_, "NProtoBuf::PrintAsBase64($name$(), output);\n");
+         } else {
+             /* TODO: proper error handling for unsupported repr values? */
+             printer->Print(Variables_, "escapedOutput << $name$();\n");
+         }
+     }
+
+ private:
+     bool IsBytes() const {
+         return Descriptor_->type() == FieldDescriptor::TYPE_BYTES;
+     }
+ };
+
+ // Generator for a repeated string or bytes field. The generated argument is
+ // one element that is appended to the field.
+ class TRepeatedStringFieldExtGenerator: public TFieldExtGenerator {
+ public:
+     TRepeatedStringFieldExtGenerator(const FieldDescriptor* field)
+         : TFieldExtGenerator(field)
+     {
+         Variables_["pointer_type"] = IsBytes() ? "void" : "char";
+         Variables_["type"] = "TProtoStringType";
+         Variables_["repeated_field_type"] = "RepeatedPtrField";
+     }
+
+     void GenerateCtorArgument(io::Printer* printer) override {
+         const char* const argument = IsBytes()
+             ? "const $pointer_type$* arg_$name$, size_t arg_$name$_size"
+             : "const $type$& arg_$name$";
+         printer->Print(Variables_, argument);
+     }
+
+     void GenerateInitializer(io::Printer* printer, const TString& prefix) override {
+         Variables_["prefix"] = prefix;
+         const char* const statement = IsBytes()
+             ? "$prefix$add_$name$(arg_$name$, arg_$name$_size);\n"
+             : "$prefix$add_$name$(arg_$name$);\n";
+         printer->Print(Variables_, statement);
+     }
+
+     // Each element is written between double quotes, directly to `output`.
+     void GeneratePrintingCode(io::Printer* printer) override {
+         GeneratePrintingCycle(Variables_, "output << \"\\\"\" << $obj$ << \"\\\"\"", printer);
+     }
+
+ private:
+     bool IsBytes() const {
+         return Descriptor_->type() == FieldDescriptor::TYPE_BYTES;
+     }
+ };
+
+ // Generator for a non-repeated enum field; the value is printed through
+ // the generated <EnumType>_Name() function.
+ class TEnumFieldExtGenerator: public TFieldExtGenerator {
+ public:
+     TEnumFieldExtGenerator(const FieldDescriptor* field)
+         : TFieldExtGenerator(field)
+     {
+         Variables_["type"] = ClassName(Descriptor_->enum_type(), true);
+     }
+
+     void GenerateCtorArgument(io::Printer* printer) override {
+         printer->Print(Variables_, "$type$ arg_$name$");
+     }
+
+     void GenerateInitializer(io::Printer* printer, const TString& prefix) override {
+         Variables_["prefix"] = prefix;
+         printer->Print(Variables_, "$prefix$set_$name$(arg_$name$);\n");
+     }
+
+     void GeneratePrintingCode(io::Printer* printer) override {
+         printer->Print(Variables_, "output << $type$_Name($name$());\n");
+     }
+ };
+
+ // Generator for a repeated enum field. The generated argument is one enum
+ // value that is appended to the field.
+ class TRepeatedEnumFieldExtGenerator: public TFieldExtGenerator {
+ public:
+     TRepeatedEnumFieldExtGenerator(const FieldDescriptor* field)
+         : TFieldExtGenerator(field)
+     {
+         Variables_["type"] = ClassName(Descriptor_->enum_type(), true);
+         Variables_["repeated_field_type"] = "RepeatedField";
+     }
+
+     void GenerateCtorArgument(io::Printer* printer) override {
+         printer->Print(Variables_, "$type$ arg_$name$");
+     }
+
+     void GenerateInitializer(io::Printer* printer, const TString& prefix) override {
+         Variables_["prefix"] = prefix;
+         printer->Print(Variables_, "$prefix$add_$name$(arg_$name$);\n");
+     }
+
+     // The loop iterates a RepeatedField< int >, so the element statement
+     // casts each value back to the enum type before calling <EnumType>_Name().
+     void GeneratePrintingCode(io::Printer* printer) override {
+         const TProtoStringType enumType = Variables_["type"];
+
+         TStringStream itemStatement;
+         itemStatement << "output << " << enumType << "_Name(" << enumType << "($obj$))";
+
+         // Work on a copy so the member table keeps the enum type name.
+         TVariables loopVariables = Variables_;
+         loopVariables["type"] = "int";
+         GeneratePrintingCycle(loopVariables, itemStatement.Str(), printer);
+     }
+ };
+
+ // Generator for a non-repeated field of a primitive C++ type.
+ class TPrimitiveFieldExtGenerator: public TFieldExtGenerator {
+ public:
+     TPrimitiveFieldExtGenerator(const FieldDescriptor* field)
+         : TFieldExtGenerator(field)
+     {
+         Variables_["type"] = PrimitiveTypeName(Descriptor_->cpp_type());
+     }
+
+     void GenerateCtorArgument(io::Printer* printer) override {
+         printer->Print(Variables_, "$type$ arg_$name$");
+     }
+
+     void GenerateInitializer(io::Printer* printer, const TString& prefix) override {
+         Variables_["prefix"] = prefix;
+         printer->Print(Variables_, "$prefix$set_$name$(arg_$name$);\n");
+     }
+
+     // Honours the `repr` field option: as_bytes and as_hex select the
+     // matching NProtoBuf helper, anything else streams the raw value.
+     void GeneratePrintingCode(io::Printer* printer) override {
+         const auto& options = Descriptor_->options();
+         const Repr::ReprType fmt = options.HasExtension(repr) ? options.GetExtension(repr) : Repr::none;
+
+         /* TODO: proper error handling for unsupported repr values? */
+         const char* statement = "output << $name$();\n";
+         if (fmt == Repr::as_bytes) {
+             statement = "NProtoBuf::PrintAsBytes($name$(), output);\n";
+         } else if (fmt == Repr::as_hex) {
+             statement = "NProtoBuf::PrintAsHex($name$(), output);\n";
+         }
+         printer->Print(Variables_, statement);
+     }
+ };
+
+ // Generator for a repeated field of a primitive C++ type. The generated
+ // argument is one value that is appended to the field.
+ class TRepeatedPrimitiveFieldExtGenerator: public TFieldExtGenerator {
+ public:
+     TRepeatedPrimitiveFieldExtGenerator(const FieldDescriptor* field)
+         : TFieldExtGenerator(field)
+     {
+         Variables_["type"] = PrimitiveTypeName(Descriptor_->cpp_type());
+         Variables_["repeated_field_type"] = "RepeatedField";
+     }
+
+     void GenerateCtorArgument(io::Printer* printer) override {
+         printer->Print(Variables_, "$type$ arg_$name$");
+     }
+
+     void GenerateInitializer(io::Printer* printer, const TString& prefix) override {
+         Variables_["prefix"] = prefix;
+         printer->Print(Variables_, "$prefix$add_$name$(arg_$name$);\n");
+     }
+
+     void GeneratePrintingCode(io::Printer* printer) override {
+         GeneratePrintingCycle(Variables_, "output << $obj$", printer);
+     }
+ };
+
+ // Picks the generator matching `field`. Map fields are checked first, then
+ // the choice depends on whether the field is repeated and on its C++ type;
+ // anything that is not a message, string or enum is treated as primitive.
+ // String fields use the string generators whatever their ctype option is.
+ std::unique_ptr<TFieldExtGenerator> MakeGenerator(const FieldDescriptor* field) {
+     if (field->is_map()) {
+         return std::make_unique<TMapFieldExtGenerator>(field);
+     }
+
+     const FieldDescriptor::CppType cppType = field->cpp_type();
+
+     if (field->is_repeated()) {
+         if (cppType == FieldDescriptor::CPPTYPE_MESSAGE) {
+             return std::make_unique<TRepeatedMessageFieldExtGenerator>(field);
+         }
+         if (cppType == FieldDescriptor::CPPTYPE_STRING) {
+             return std::make_unique<TRepeatedStringFieldExtGenerator>(field);
+         }
+         if (cppType == FieldDescriptor::CPPTYPE_ENUM) {
+             return std::make_unique<TRepeatedEnumFieldExtGenerator>(field);
+         }
+         return std::make_unique<TRepeatedPrimitiveFieldExtGenerator>(field);
+     }
+
+     if (cppType == FieldDescriptor::CPPTYPE_MESSAGE) {
+         return std::make_unique<TMessageFieldExtGenerator>(field);
+     }
+     if (cppType == FieldDescriptor::CPPTYPE_STRING) {
+         return std::make_unique<TStringFieldExtGenerator>(field);
+     }
+     if (cppType == FieldDescriptor::CPPTYPE_ENUM) {
+         return std::make_unique<TEnumFieldExtGenerator>(field);
+     }
+     return std::make_unique<TPrimitiveFieldExtGenerator>(field);
+ }
+
+ class TMessageExtGenerator {
+ public:
+ // Builds generators for the nested messages of `descriptor` and, only when
+ // the message carries the message_id option, for each of its fields. Also
+ // decides whether the field-wise ("special") constructor may be generated.
+ TMessageExtGenerator(const Descriptor* descriptor, OutputDirectory* outputDirectory)
+     : Descriptor_(descriptor)
+     , HasMessageId_(Descriptor_->options().HasExtension(message_id))
+     , ClassName_(ClassName(Descriptor_, false))
+     , OutputDirectory_(outputDirectory)
+     , HasGeneratorWithProtobufMessageFieldPrinter_(false)
+     , CanGenerateSpecialConstructor_(false)
+ {
+     const int nestedCount = descriptor->nested_type_count();
+     NestedGenerators_.reserve(nestedCount);
+     for (int n = 0; n < nestedCount; ++n) {
+         NestedGenerators_.emplace_back(descriptor->nested_type(n), OutputDirectory_);
+     }
+
+     const int fieldCount = descriptor->field_count();
+
+     if (HasMessageId_) {
+         FieldGenerators_.reserve(fieldCount);
+         for (int f = 0; f < fieldCount; ++f) {
+             FieldGenerators_.emplace_back(MakeGenerator(descriptor->field(f)));
+             if (FieldGenerators_.back()->NeedProtobufMessageFieldPrinter()) {
+                 HasGeneratorWithProtobufMessageFieldPrinter_ = true;
+             }
+         }
+     }
+
+     // Classify the fields: how many are not maps, and how many have an
+     // integer C++ type.
+     size_t intFields = 0;
+     size_t nonMapFields = 0;
+     for (int f = 0; f < fieldCount; ++f) {
+         const FieldDescriptor* field = Descriptor_->field(f);
+         if (!field->is_map()) {
+             ++nonMapFields;
+         }
+
+         const FieldDescriptor::CppType type = field->cpp_type();
+         const bool isInteger =
+             type == FieldDescriptor::CPPTYPE_INT32 ||
+             type == FieldDescriptor::CPPTYPE_INT64 ||
+             type == FieldDescriptor::CPPTYPE_UINT32 ||
+             type == FieldDescriptor::CPPTYPE_UINT64;
+         if (isInteger) {
+             ++intFields;
+         }
+     }
+
+     // Certain field combinations would result in ambiguous constructor
+     // generation: exactly two non-map fields, both of integer type. Do not
+     // generate the special constructor for them.
+     const bool exactlyTwoIntegerFields = (intFields == nonMapFields) && (nonMapFields == 2);
+
+     // Besides that, the special constructor needs at least one non-map field.
+     CanGenerateSpecialConstructor_ = !exactlyTwoIntegerFields && nonMapFields > 0;
+ }
+
+ // Inserts the header-side extensions for this message and, recursively,
+ // for all of its nested messages.
+ //
+ //  * A T<ClassName> typedef is emitted when the message has either the
+ //    realm_name or the message_id option.
+ //  * The in-class declarations (ID, FromFields, Print, ...) are inserted
+ //    at the "class_scope:<full name>" insertion point only for messages
+ //    with message_id; the id is checked for uniqueness first, so this may
+ //    throw yexception (see CheckMessageId).
+ //
+ // Nested messages are visited even when this message has no message_id.
+ // GenerateClassExtension() emits source definitions for nested events
+ // regardless of the parent, so skipping them here would leave those
+ // definitions without matching declarations in the header.
+ void GenerateClassDefinitionExtension() {
+     if (Descriptor_->options().HasExtension(realm_name) || Descriptor_->options().HasExtension(message_id)) {
+         GeneratePseudonim();
+     }
+
+     if (HasMessageId_) {
+         CheckMessageId(Descriptor_->options().GetExtension(message_id), ClassName_);
+
+         TProtoStringType fileName = HeaderFileName(Descriptor_->file());
+         TProtoStringType scope = "class_scope:" + Descriptor_->full_name();
+         std::unique_ptr<io::ZeroCopyOutputStream> output(
+             OutputDirectory_->OpenForInsert(fileName, scope));
+         io::Printer printer(output.get(), '$');
+
+         printer.Print("//Yandex events extension.\n");
+         GenerateHeaderImpl(&printer);
+     }
+
+     for (auto& nestedGenerator: NestedGenerators_) {
+         nestedGenerator.GenerateClassDefinitionExtension();
+     }
+ }
+
+ // Inserts the source-side definitions for this message and for its direct
+ // nested messages at the "namespace_scope" insertion point of the .pb.cc.
+ // Returns true when at least one of them produced code, i.e. has message_id.
+ bool GenerateClassExtension() {
+     const TProtoStringType sourceName = SourceFileName(Descriptor_->file());
+     std::unique_ptr<io::ZeroCopyOutputStream> stream(
+         OutputDirectory_->OpenForInsert(sourceName, "namespace_scope"));
+     io::Printer printer(stream.get(), '$');
+
+     bool generated = GenerateSourceImpl(&printer);
+
+     // Every nested generator runs; the result only accumulates.
+     for (auto& nested: NestedGenerators_) {
+         if (nested.GenerateSourceImpl(&printer)) {
+             generated = true;
+         }
+     }
+
+     return generated;
+ }
+
+ // Emits the RegisterEvent(...) call for this message. Messages without
+ // message_id emit nothing. The generated statement expects a local
+ // variable named `factory` to be in scope.
+ void GenerateRegistration(io::Printer* printer) {
+     if (HasMessageId_) {
+         TVariables substitutions;
+         substitutions["classname"] = ClassName_;
+
+         printer->Print(substitutions, "NProtoBuf::TEventFactory::Instance()->RegisterEvent($classname$::descriptor()->options().GetExtension(message_id), factory->GetPrototype($classname$::descriptor()), $classname$::Print);\n");
+     }
+ }
+
+ private:
+ // Emits the declarations inserted into the generated class body: the ID
+ // enum, the FromFields() factory, optionally the field-wise constructor,
+ // and both Print() overloads.
+ void GenerateHeaderImpl(io::Printer* printer) {
+     const TProtoStringType messageId(ToString(Descriptor_->options().GetExtension(message_id)));
+
+     TVariables vars;
+     vars["classname"] = ClassName_;
+     vars["messageid"] = messageId.data();
+     vars["superclass"] = SuperClassName(Descriptor_, Options{});
+
+     printer->Print(vars, "enum {ID = $messageid$};\n\n");
+
+     // FromFields() is generated unconditionally so that template code can
+     // always rely on it.
+     printer->Print(vars, "static $classname$ FromFields(\n");
+     GenerateCtorArgs(printer);
+     printer->Print(");\n");
+
+     if (CanGenerateSpecialConstructor_) {
+         printer->Print(vars, "$classname$(\n");
+         GenerateCtorArgs(printer);
+         printer->Print(");\n");
+     }
+
+     // Member Print() and the static overload that takes a generic Message.
+     printer->Print("void Print(IOutputStream& output, EFieldOutputFlags outputFlags = {}) const;\n");
+     printer->Print("static void Print(const google::protobuf::Message* ev, IOutputStream& output, EFieldOutputFlags outputFlags = {});\n");
+ }
+
+ // Inserts `typedef <full class name> T<ClassName>;` at the header's
+ // "namespace_scope" insertion point. When the realm_name option is set,
+ // the typedef is wrapped into one namespace per dot-separated realm part.
+ void GeneratePseudonim() {
+     const TProtoStringType headerName = HeaderFileName(Descriptor_->file());
+     std::unique_ptr<io::ZeroCopyOutputStream> stream(
+         OutputDirectory_->OpenForInsert(headerName, "namespace_scope"));
+     io::Printer printer(stream.get(), '$');
+
+     std::vector<TProtoStringType> realmParts;
+     const auto& options = Descriptor_->options();
+     if (options.HasExtension(realm_name)) {
+         SplitStringUsing(options.GetExtension(realm_name), ".", &realmParts);
+     }
+
+     if (!realmParts.empty()) {
+         printer.Print("\n");
+     }
+
+     // Open the realm namespaces, outermost first.
+     for (const auto& part: realmParts) {
+         printer.Print("namespace $part$ {\n",
+             "part", part);
+     }
+
+     printer.Print("typedef $fullclassname$ T$classname$;\n",
+         "fullclassname", FullClassName(Descriptor_),
+         "classname", ClassName_);
+
+     // Close them again, innermost first.
+     for (auto part = realmParts.rbegin(); part != realmParts.rend(); ++part) {
+         printer.Print("} // namespace $part$\n",
+             "part", *part);
+     }
+ }
+
+ // Returns the globally qualified C++ name of `descriptor`: "::" plus each
+ // dot-separated part of the file's package, then "::" and the class name.
+ TProtoStringType FullClassName(const Descriptor* descriptor) {
+     std::vector<TProtoStringType> packageParts;
+     SplitStringUsing(descriptor->file()->package(), ".", &packageParts);
+
+     TProtoStringType qualified;
+     for (const auto& part: packageParts) {
+         qualified += "::" + part;
+     }
+     qualified += "::" + ClassName(descriptor, false);
+
+     return qualified;
+ }
+
+ // Emits the out-of-class definitions for this message: FromFields(), the
+ // optional field-wise constructor, the member Print() and the static
+ // Print(). Returns false, emitting nothing, for messages without message_id.
+ bool GenerateSourceImpl(io::Printer* printer) {
+     if (!HasMessageId_) {
+         return false;
+     }
+
+     TVariables vars;
+     vars["classname"] = ClassName_;
+
+     // --- static $classname$::FromFields ---------------------------------
+     printer->Print(vars, "$classname$ $classname$::FromFields(\n");
+     GenerateCtorArgs(printer);
+     printer->Print(")\n");
+     printer->Print("{\n");
+     printer->Indent();
+     printer->Print(vars, "$classname$ result;\n");
+     GenerateFieldInitializers(printer, /* prefix = */ "result.");
+     printer->Print("return result;\n");
+     printer->Outdent();
+     printer->Print("}\n\n");
+
+     // --- field-wise constructor -----------------------------------------
+     if (CanGenerateSpecialConstructor_) {
+         printer->Print(vars, "$classname$::$classname$(\n");
+         GenerateCtorArgs(printer);
+         printer->Print(")\n");
+         printer->Print("{\n");
+         printer->Indent();
+         printer->Print("SharedCtor();\n");
+         GenerateFieldInitializers(printer, /* prefix = */ "");
+         printer->Outdent();
+         printer->Print("}\n\n");
+     }
+
+     // --- $classname$::Print (member) ------------------------------------
+     const bool hasFields = Descriptor_->field_count() > 0;
+     if (hasFields) {
+         printer->Print(vars,
+             "void $classname$::Print(IOutputStream& output, EFieldOutputFlags outputFlags) const {\n");
+         printer->Indent();
+         printer->Print(
+             "TEventFieldOutput escapedOutput{output, outputFlags};\n"
+             "Y_UNUSED(escapedOutput);\n");
+
+         if (HasGeneratorWithProtobufMessageFieldPrinter_) {
+             printer->Print(
+                 "TEventProtobufMessageFieldPrinter protobufMessageFieldPrinter(EProtobufMessageFieldPrintMode::DEFAULT);\n");
+         }
+     } else {
+         // Without fields the flags parameter stays unnamed.
+         printer->Print(vars,
+             "void $classname$::Print(IOutputStream& output, EFieldOutputFlags) const {\n");
+         printer->Indent();
+     }
+
+     printer->Print(vars, "output << \"$classname$\";\n");
+
+     // One generator exists per field here (message_id is set), and each
+     // field is preceded by a tab.
+     for (auto& fieldGenerator: FieldGenerators_) {
+         printer->Print("output << \"\\t\";\n");
+         fieldGenerator->GeneratePrintingCode(printer);
+     }
+
+     printer->Outdent();
+     printer->Print("}\n\n");
+
+     // --- static $classname$::Print --------------------------------------
+     printer->Print(vars,
+         "void $classname$::Print(const google::protobuf::Message* ev, IOutputStream& output, EFieldOutputFlags outputFlags) {\n");
+     printer->Indent();
+     printer->Print(vars,
+         "const $classname$* This(CheckedCast<const $classname$*>(ev));\n");
+     printer->Print(
+         "This->Print(output, outputFlags);\n");
+     printer->Outdent();
+     printer->Print("}\n\n");
+
+     return true;
+ }
+
+ // Emits the comma-separated argument list shared by FromFields() and the
+ // field-wise constructor. Map fields contribute no argument.
+ void GenerateCtorArgs(io::Printer* printer) {
+     printer->Indent();
+
+     const char* separator = "";
+     const size_t total = Descriptor_->field_count();
+     for (size_t index = 0; index < total; ++index) {
+         if (Descriptor_->field(index)->is_map()) {
+             continue;
+         }
+         printer->Print(separator);
+         FieldGenerators_[index]->GenerateCtorArgument(printer);
+         // Every argument after the first is preceded by ", ".
+         separator = ", ";
+     }
+
+     printer->Outdent();
+ }
+
+ // Emits one initializer statement per field generator, in field order;
+ // `prefix` is forwarded to each generator.
+ void GenerateFieldInitializers(io::Printer* printer, const TString& prefix) {
+     for (size_t index = 0; index < FieldGenerators_.size(); ++index) {
+         FieldGenerators_[index]->GenerateInitializer(printer, prefix);
+     }
+ }
+
+ private:
+ const Descriptor* Descriptor_;
+ const bool HasMessageId_;
+ TProtoStringType ClassName_;
+ OutputDirectory* OutputDirectory_;
+ bool HasGeneratorWithProtobufMessageFieldPrinter_;
+ bool CanGenerateSpecialConstructor_;
+ std::vector<std::unique_ptr<TFieldExtGenerator>> FieldGenerators_;
+ std::vector<TMessageExtGenerator> NestedGenerators_;
+ };
+
+ class TFileExtGenerator {
+ public:
+ // Creates one message generator per top-level message of `file`.
+ TFileExtGenerator(const FileDescriptor* file, OutputDirectory* output_directory)
+     : OutputDirectory_(output_directory)
+     , File_(file)
+ {
+     const int messageCount = File_->message_type_count();
+     MessageGenerators_.reserve(messageCount);
+     for (int index = 0; index != messageCount; ++index) {
+         MessageGenerators_.emplace_back(File_->message_type(index), OutputDirectory_);
+     }
+ }
+
+ // Adds the eventlog includes at the header's "includes" insertion point,
+ // then lets every message generator insert its header-side extensions.
+ void GenerateHeaderExtensions() {
+     const TProtoStringType headerName = HeaderFileName(File_);
+
+     std::unique_ptr<io::ZeroCopyOutputStream> stream(
+         OutputDirectory_->OpenForInsert(headerName, "includes"));
+     io::Printer printer(stream.get(), '$');
+
+     static const char* const includeLines[] = {
+         "#include <library/cpp/eventlog/event_field_output.h>\n",
+         "#include <library/cpp/eventlog/event_field_printer.h>\n",
+     };
+     for (const char* line: includeLines) {
+         printer.Print(line);
+     }
+
+     for (auto& generator: MessageGenerators_) {
+         generator.GenerateClassDefinitionExtension();
+     }
+ }
+
+ // Adds the includes needed by the generated code at the source file's
+ // "includes" insertion point, lets every message generator emit its
+ // definitions, and generates the registration code when at least one
+ // message produced any.
+ void GenerateSourceExtensions() {
+     const TProtoStringType sourceName = SourceFileName(File_);
+
+     std::unique_ptr<io::ZeroCopyOutputStream> stream(
+         OutputDirectory_->OpenForInsert(sourceName, "includes"));
+     io::Printer printer(stream.get(), '$');
+
+     static const char* const includeLines[] = {
+         "#include <google/protobuf/io/printer.h>\n",
+         "#include <google/protobuf/io/zero_copy_stream_impl_lite.h>\n",
+         "#include <google/protobuf/stubs/strutil.h>\n",
+         "#include <library/cpp/eventlog/events_extension.h>\n",
+         "#include <util/generic/cast.h>\n",
+         "#include <util/stream/output.h>\n",
+     };
+     for (const char* line: includeLines) {
+         printer.Print(line);
+     }
+
+     // Every generator runs; the flag only records whether any emitted code.
+     bool anyEvent = false;
+     for (auto& generator: MessageGenerators_) {
+         if (generator.GenerateClassExtension()) {
+             anyEvent = true;
+         }
+     }
+
+     if (anyEvent) {
+         GenerateEventRegistrations();
+     }
+ }
+
+ // Generates the per-file registration machinery: the registration function
+ // and the registrator definition go into the source file, the registrator
+ // declaration into the header (both at "namespace_scope"). All generated
+ // names are derived from FilenameIdentifier(File_->name()).
+ void GenerateEventRegistrations() {
+     const TProtoStringType fileId = FilenameIdentifier(File_->name());
+
+     TVariables vars;
+     vars["regfunction"] = "regevent_" + fileId;
+     vars["regclassname"] = "TRegister_" + fileId;
+     vars["regvarname"] = "registrator_" + fileId;
+     vars["filename"] = File_->name();
+
+     {
+         // Source part.
+         const TProtoStringType sourceName = SourceFileName(File_);
+         std::unique_ptr<io::ZeroCopyOutputStream> stream(
+             OutputDirectory_->OpenForInsert(sourceName, "namespace_scope"));
+         io::Printer printer(stream.get(), '$');
+
+         GenerateRegistrationFunction(vars, printer);
+         GenerateRegistratorDefinition(vars, printer);
+     }
+
+     {
+         // Header part.
+         const TProtoStringType headerName = HeaderFileName(File_);
+         std::unique_ptr<io::ZeroCopyOutputStream> stream(
+             OutputDirectory_->OpenForInsert(headerName, "namespace_scope"));
+         io::Printer printer(stream.get(), '$');
+
+         GenerateRegistratorDeclaration(vars, printer);
+     }
+ }
+
+ // Emits `void $regfunction$()`: it obtains the generated message factory
+ // and contains the registration statement of every top-level message
+ // generator (those without message_id contribute nothing).
+ void GenerateRegistrationFunction(const TVariables& vars, io::Printer& printer) {
+     printer.Print(vars, "void $regfunction$() {\n");
+     printer.Indent();
+
+     // The registration statements refer to this local as `factory`.
+     printer.Print("google::protobuf::MessageFactory* factory = google::protobuf::MessageFactory::generated_factory();\n\n");
+     for (auto& generator: MessageGenerators_) {
+         generator.GenerateRegistration(&printer);
+     }
+
+     printer.Outdent();
+     printer.Print("}\n\n");
+ }
+
+ // Emits the declaration of the registrator class - a public constructor
+ // and a private static `Registered` flag - followed by a static instance
+ // of it named $regvarname$.
+ void GenerateRegistratorDeclaration(const TVariables& vars, io::Printer& printer) {
+     printer.Print(vars, "\nclass $regclassname$ {\n");
+
+     // public section
+     printer.Print("public:\n");
+     printer.Indent();
+     printer.Print(vars, "$regclassname$();\n");
+     printer.Outdent();
+
+     // private section
+     printer.Print("private:\n");
+     printer.Indent();
+     printer.Print("static bool Registered;\n");
+     printer.Outdent();
+
+     printer.Print(vars, "};\n");
+     printer.Print(vars, "static $regclassname$ $regvarname$;\n\n");
+ }
+
+ // Emits the registrator constructor and the definition of its static flag.
+ // The generated constructor schedules $regfunction$ with TEventFactory the
+ // first time it runs and sets `Registered` so later runs do nothing.
+ void GenerateRegistratorDefinition(const TVariables& vars, io::Printer& printer) {
+     printer.Print(vars, "$regclassname$::$regclassname$() {\n");
+     printer.Indent();
+
+     printer.Print("if (!Registered) {\n");
+     printer.Indent();
+     printer.Print(vars, "NProtoBuf::TEventFactory::Instance()->ScheduleRegistration(&$regfunction$);\n");
+     printer.Print("Registered = true;\n");
+     printer.Outdent();
+     printer.Print("}\n");
+
+     printer.Outdent();
+     printer.Print("}\n\n");
+
+     printer.Print(vars, "bool $regclassname$::Registered;\n\n");
+ }
+ private:
+ OutputDirectory* OutputDirectory_;
+ const FileDescriptor* File_;
+ std::vector<TMessageExtGenerator> MessageGenerators_;
+ };
+}
+
+ // CodeGenerator entry point: inserts the event extensions for `file` into
+ // the generated header first and into the generated source second.
+ // `parameter` and `error` are not used and the function always returns
+ // true; problems such as duplicate message ids surface as exceptions.
+ bool TProtoEventExtensionGenerator::Generate(const google::protobuf::FileDescriptor* file,
+                                              const TProtoStringType& parameter,
+                                              google::protobuf::compiler::OutputDirectory* outputDirectory,
+                                              TProtoStringType* error) const {
+     Y_UNUSED(error);
+     Y_UNUSED(parameter);
+
+     NInternal::TFileExtGenerator extensions(file, outputDirectory);
+     extensions.GenerateHeaderExtensions(); // *.pb.h
+     extensions.GenerateSourceExtensions(); // *.pb.cc
+
+     return true;
+ }
+
+} // namespace NProtoBuf::NCompiler::NPlugins
+
+// Plugin entry point: hands control to protobuf's PluginMain with the event
+// extension generator.
+//
+// Returns whatever PluginMain returns, or 1 when an exception escapes it.
+// The exception text is written to Cerr; exceptions that do not derive from
+// yexception or std::exception are reported with a generic message.
+int main(int argc, char* argv[]) {
+#ifdef _MSC_VER
+    // Don't print a silly message or stick a modal dialog box in my face,
+    // please.
+    _set_abort_behavior(0u, ~0u);
+#endif // _MSC_VER
+
+    try {
+        NProtoBuf::NCompiler::NPlugins::TProtoEventExtensionGenerator generator;
+        return google::protobuf::compiler::PluginMain(argc, argv, &generator);
+    } catch (const yexception& e) {
+        Cerr << e.what() << Endl;
+    } catch (const std::exception& e) {
+        // Keep the message of standard-library failures instead of
+        // collapsing them into "Unknown error".
+        Cerr << e.what() << Endl;
+    } catch (...) {
+        Cerr << "Unknown error in TProtoEventExtensionGenerator" << Endl;
+    }
+
+    return 1;
+}
diff --git a/tools/event2cpp/proto_events.h b/tools/event2cpp/proto_events.h
new file mode 100644
index 0000000000..628b4856af
--- /dev/null
+++ b/tools/event2cpp/proto_events.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <google/protobuf/compiler/plugin.h>
+#include <google/protobuf/compiler/code_generator.h>
+#include <google/protobuf/stubs/common.h>
+
+namespace NProtoBuf::NCompiler::NPlugins {
+
+// protoc code generator plugged into PluginMain; the work happens in
+// Generate(), which is defined out of line.
+class TProtoEventExtensionGenerator : public google::protobuf::compiler::CodeGenerator {
+public:
+    TProtoEventExtensionGenerator() = default;
+    ~TProtoEventExtensionGenerator() override = default;
+
+    // Overrides CodeGenerator::Generate for one .proto `file`; generated
+    // content is written through `output_directory`.
+    bool Generate(const google::protobuf::FileDescriptor* file,
+                  const TProtoStringType& parameter,
+                  google::protobuf::compiler::OutputDirectory* output_directory,
+                  TProtoStringType* error) const override;
+};
+
+} // namespace NProtoBuf::NCompiler::NPlugins
diff --git a/tools/go_test_miner/main.go b/tools/go_test_miner/main.go
new file mode 100644
index 0000000000..43b729572e
--- /dev/null
+++ b/tools/go_test_miner/main.go
@@ -0,0 +1,168 @@
+package main
+
+import (
+ "flag"
+ "fmt"
+ "go/importer"
+ "go/token"
+ "go/types"
+ "os"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "sort"
+ "strings"
+ "unicode"
+ "unicode/utf8"
+)
+
+const (
+ usageTemplate = "Usage: %s [-benchmarks] [-examples] [-tests] import-path\n"
+)
+
+// findObjectByName looks `name` up in the package scope and returns the
+// object only when the string form of its type matches `re`. It returns nil
+// when any argument is nil/empty, the name is not found, or the type does
+// not match.
+func findObjectByName(pkg *types.Package, re *regexp.Regexp, name string) types.Object {
+	if pkg == nil || re == nil || name == "" {
+		return nil
+	}
+
+	obj := pkg.Scope().Lookup(name)
+	if obj == nil {
+		return nil
+	}
+	if !re.MatchString(obj.Type().String()) {
+		return nil
+	}
+	return obj
+}
+
+// isTestName reports whether `name` is `prefix` itself or `prefix` followed
+// by a character that is not a lower-case letter (so "TestFoo" and "Test_x"
+// qualify for prefix "Test", "Testify" does not).
+func isTestName(name, prefix string) bool {
+	if !strings.HasPrefix(name, prefix) {
+		return false
+	}
+
+	rest := name[len(prefix):]
+	if rest == "" {
+		return true
+	}
+
+	first, _ := utf8.DecodeRuneInString(rest)
+	return !unicode.IsLower(first)
+}
+
+// main imports the package given as the single positional argument and
+// prints, one per line, the names of its tests, benchmarks and/or examples,
+// depending on the -tests, -benchmarks and -examples flags. A matching
+// TestMain function is reported as "#TestMain".
+//
+// Exit status: 1 when more than one positional argument is given or the
+// import fails; 0 otherwise, including the "no argument, show usage" case.
+func main() {
+	testsPtr := flag.Bool("tests", false, "report tests")
+	benchmarksPtr := flag.Bool("benchmarks", false, "report benchmarks")
+	examplesPtr := flag.Bool("examples", false, "report examples")
+
+	flag.Usage = func() {
+		_, _ = fmt.Fprintf(flag.CommandLine.Output(), usageTemplate, filepath.Base(os.Args[0]))
+		flag.PrintDefaults()
+	}
+
+	flag.Parse()
+
+	// Exactly one positional parameter is expected.
+	args := flag.Args()
+	argsCount := len(args)
+	if argsCount != 1 {
+		exitCode := 0
+		if argsCount > 1 {
+			fmt.Println("Error: invalid number of parameters...")
+			exitCode = 1
+		}
+		flag.Usage()
+		os.Exit(exitCode)
+	}
+
+	importPath := args[0]
+
+	// Use the constructor rather than a zero-value FileSet: NewFileSet
+	// starts position numbering at 1, keeping 0 reserved for token.NoPos.
+	fset := token.NewFileSet()
+	imp := importer.ForCompiler(fset, runtime.Compiler, nil)
+	pkg, err := imp.Import(importPath)
+	if err != nil {
+		fmt.Printf("Error: %v\n", err)
+		os.Exit(1)
+	}
+
+	if !*testsPtr && !*benchmarksPtr && !*examplesPtr {
+		// Nothing to do, just exit normally
+		os.Exit(0)
+	}
+
+	// History of the approach:
+	//
+	// 1. Dump the package scope as text (pkg.Scope().String()), e.g.
+	//      . func junk/snermolaev/libmath.TestAbs(t *testing.T)
+	//    and pick the lines matching a test-function pattern.
+	//
+	// 2. (current) Walk all names defined in the package scope and keep
+	//    the functions whose signature matches. Comparing obj.Type() with
+	//    reflect.TypeOf(func(*testing.T) {}) through reflect.DeepEqual did
+	//    not work for reasons that still need investigation, so the
+	//    signature is matched with regular expressions as a workaround.
+	reBenchmark := regexp.MustCompile(`^func\(\w*\s*\*testing\.B\)$`)
+	reExample := regexp.MustCompile(`^func\(\s*\)$`)
+	reTest := regexp.MustCompile(`^func\(\w*\s*\*testing\.T\)$`)
+	reTestMain := regexp.MustCompile(`^func\(\w*\s*\*testing\.M\)$`)
+
+	var testFns []types.Object
+
+	for _, name := range pkg.Scope().Names() {
+		if name == "TestMain" && findObjectByName(pkg, reTestMain, name) != nil {
+			fmt.Println("#TestMain")
+			continue
+		}
+
+		// Pick the signature pattern for the kind of function the name
+		// suggests, provided that kind was requested.
+		var re *regexp.Regexp
+		switch {
+		case *benchmarksPtr && isTestName(name, "Benchmark"):
+			re = reBenchmark
+		case *examplesPtr && isTestName(name, "Example"):
+			re = reExample
+		case *testsPtr && isTestName(name, "Test"):
+			re = reTest
+		default:
+			continue
+		}
+
+		if obj := findObjectByName(pkg, re, name); obj != nil {
+			testFns = append(testFns, obj)
+		}
+	}
+
+	// Report in source order: by file name, then by line. Objects without
+	// a valid position are ordered by their raw position value.
+	sort.Slice(testFns, func(i, j int) bool {
+		iPos := testFns[i].Pos()
+		jPos := testFns[j].Pos()
+
+		if !iPos.IsValid() || !jPos.IsValid() {
+			return iPos < jPos
+		}
+
+		iPosition := fset.PositionFor(iPos, true)
+		jPosition := fset.PositionFor(jPos, true)
+
+		return iPosition.Filename < jPosition.Filename ||
+			(iPosition.Filename == jPosition.Filename && iPosition.Line < jPosition.Line)
+	})
+
+	for _, testFn := range testFns {
+		fmt.Println(testFn.Name())
+	}
+}
diff --git a/tools/triecompiler/lib/main.cpp b/tools/triecompiler/lib/main.cpp
new file mode 100644
index 0000000000..320f46d48a
--- /dev/null
+++ b/tools/triecompiler/lib/main.cpp
@@ -0,0 +1,473 @@
+#include "main.h"
+
+#ifndef CATBOOST_OPENSOURCE
+#include <library/cpp/charset/recyr.hh>
+#endif
+
+#include <library/cpp/containers/comptrie/comptrie.h>
+#include <library/cpp/deprecated/mapped_file/mapped_file.h>
+#include <library/cpp/getopt/small/last_getopt.h>
+
+#include <util/charset/wide.h>
+#include <util/generic/buffer.h>
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <util/stream/buffered.h>
+#include <util/stream/file.h>
+#include <util/stream/input.h>
+#include <util/stream/output.h>
+#include <util/stream/output.h>
+#include <util/string/cast.h>
+#include <util/string/util.h>
+#include <util/system/filemap.h>
+
+#include <string>
+
+#ifdef WIN32
+#include <crtdbg.h>
+#include <windows.h>
+#endif // WIN32
+
+namespace {
+    // Settings of one trie compiler run; an instance is produced by ParseOptions().
+    struct TOptions {
+        TString Prog{};
+        TString Triefile{};  // output file: the compiled trie
+        TString Infile{};    // input file; "-" selects standard input (see SelectInput)
+        bool Minimized{false};   // minimize the trie after building (see ProcessFile)
+        bool FastLayout{false};  // post-process the trie with CompactTrieMakeFastLayout
+        TCompactTrieBuilderFlags Flags{CTBF_NONE};  // flags handed to TCompactTrieBuilder
+        bool Unicode{false};     // recode keys from UTF-8 to the Yandex encoding
+        bool Verify{false};      // run VerifyFile() after a successful compilation (see DoMain)
+        bool Wide{false};        // keys are UTF-8 in the input and stored as wide characters
+        bool NoValues{false};    // keys only, no values are stored
+        bool Vector{false};      // every value is a variable-length array
+        int ArraySize{0};        // > 0: every value is a fixed-size array of this length
+        TString ValueType{};     // name of the value (or array element) type
+        bool AllowEmptyKey{false};
+        bool UseAsIsParker{false};  // pack values with TAsIsPacker
+    };
+} // namespace
+
+// Parses the command line into a TOptions value.
+//
+// Exactly one free argument (the output trie file) is required. Options that
+// map to builder flags (-q, -s, -v) are folded into TOptions::Flags after
+// parsing; all other options are stored directly into their TOptions field.
+static TOptions ParseOptions(const int argc, const char* argv[]) {
+    TOptions options;
+    auto parser = NLastGetopt::TOpts::Default();
+    parser
+        .AddLongOption('t', "type")
+        .DefaultValue("ui64")
+        .RequiredArgument("TYPE")
+        .StoreResult(&options.ValueType)
+        .Help("type of value or array element, possible: ui16, i16, ui32, i32, ui64, i64, float, double, bool,"
+              " TString, TUtf16String (utf-8 in input)");
+    parser
+        .AddCharOption('0')
+        .NoArgument()
+        .SetFlag(&options.NoValues)
+        .Help("Do not store values to produce TCompactTrieSet compatible binary");
+    parser
+        .AddLongOption('a', "array")
+        .NoArgument()
+        .SetFlag(&options.Vector)
+        .Help("Values are arrays (of some type, depending on -t flag). Input looks like"
+              " 'key <TAB> value1 <TAB> .. <TAB> valueN' (N may differ from line to line)");
+    parser
+        .AddCharOption('S')
+        .RequiredArgument("INT")
+        .StoreResult(&options.ArraySize)
+        .Help("Values are fixed size not packed (!) arrays, may used for mapping classes/structs"
+              " with monotyped fields");
+    parser
+        .AddLongOption('e', "allow-empty")
+        .NoArgument()
+        .SetFlag(&options.AllowEmptyKey)
+        .Help("Allow empty key");
+    parser
+        .AddLongOption('i', "input")
+        .DefaultValue("-")
+        .RequiredArgument("FILE")
+        .StoreResult(&options.Infile)
+        .Help("Input file");
+    parser
+        .AddLongOption('m', "minimize")
+        .NoArgument()
+        .SetFlag(&options.Minimized)
+        .Help("Minimize tree into a DAG");
+    parser
+        .AddLongOption('f', "fast-layout")
+        .NoArgument()
+        .SetFlag(&options.FastLayout)
+        .Help("Make fast layout");
+    parser
+        .AddLongOption('v', "verbose")
+        .NoArgument()
+        .Help("Be verbose - show progress & stats");
+    parser
+        .AddLongOption('s', "prefix-grouped")
+        .NoArgument()
+        .Help("Assume input is prefix-grouped by key (for every prefix all keys with this prefix"
+              " come in one group; greatly reduces memory usage)");
+    parser
+        .AddLongOption('q', "unique-keys")
+        .NoArgument()
+        .Help("Assume the keys are unique (will report an error otherwise)");
+    parser
+        .AddLongOption('c', "check")
+        .NoArgument()
+        // Without this binding the option was accepted but had no effect:
+        // nothing else in this function ever sets options.Verify.
+        .SetFlag(&options.Verify)
+        .Help("Check the compiled trie (works only with an explicit input file name)");
+    parser
+        .AddCharOption('u')
+        .NoArgument()
+        .SetFlag(&options.Unicode)
+        .Help("Recode keys from UTF-8 to Yandex (deprecated)");
+    parser
+        .AddLongOption('w', "wide")
+        .NoArgument()
+        .SetFlag(&options.Wide)
+        .Help("Treat input keys as UTF-8, recode to TChar (wchar16)");
+    parser
+        .AddLongOption('P', "as-is-packer")
+        .NoArgument()
+        .SetFlag(&options.UseAsIsParker)
+        .Help("Use AsIsParker to pack value in trie");
+    parser.AddHelpOption('h');
+    parser.SetFreeArgsNum(1);
+    parser.SetFreeArgTitle(0, "TRIE_FILE", "Compiled trie");
+    NLastGetopt::TOptsParseResult parsed{&parser, argc, argv};
+    options.Triefile = parsed.GetFreeArgs().front();
+    // These three options have no dedicated field; they become builder flags.
+    if (parsed.Has('q')) {
+        options.Flags |= CTBF_UNIQUE;
+    }
+    if (parsed.Has('s')) {
+        options.Flags |= CTBF_PREFIX_GROUPED;
+    }
+    if (parsed.Has('v')) {
+        options.Flags |= CTBF_VERBOSE;
+    }
+    return options;
+}
+
+namespace {
+ // Default reader: turns the character range [data, data + size) into a T
+ // by delegating to FromStringImpl<T>.
+ template <class T>
+ struct TFromString {
+     T operator()(const char* data, size_t size) const {
+         return FromStringImpl<T>(data, size);
+     }
+ };
+
+ // Reader for TUtf16String: the range is decoded with UTF8ToWide.
+ template <>
+ struct TFromString<TUtf16String> {
+     TUtf16String operator()(const char* data, size_t size) const {
+         return UTF8ToWide(data, size);
+     }
+ };
+
+ // One "key <TAB> value" input record.
+ //
+ // TKeyReader / TValueReader are functors called as reader(ptr, len) that
+ // convert the raw key and value text into TKey and TValue.
+ template <class TTKey, class TTKeyChar, class TTValue,
+           class TKeyReader = TFromString<TTKey>,
+           class TValueReader = TFromString<TTValue> >
+ struct TRecord {
+     typedef TTKey TKey;
+     typedef TTKeyChar TKeyChar;
+     typedef TTValue TValue;
+     TKey Key;
+     TValue Value;
+     TString Tmp; // line buffer reused between Load() calls
+
+     // Reads lines from `in` until one yields a record, then fills Key and Value.
+     // Returns true when a record was loaded and false when the input is exhausted.
+     //
+     // Skipped lines:
+     //  - empty lines, unless both noValues and allowEmptyKey are set;
+     //  - lines starting with TAB (empty key), unless allowEmptyKey is set;
+     //  - lines without a TAB when values are expected (noValues is false).
+     bool Load(IInputStream& in, const bool allowEmptyKey, const bool noValues) {
+         while (in.ReadLine(Tmp)) {
+             if (!Tmp) {
+                 // there is a special case for TrieSet with empty keys allowed
+                 if (!(noValues && allowEmptyKey)) {
+                     continue;
+                 }
+             }
+
+             const size_t sep = Tmp.find('\t');
+             if (sep != TString::npos) {
+                 if (0 == sep && !allowEmptyKey) {
+                     continue;
+                 }
+                 Key = TKeyReader()(Tmp.data(), sep);
+                 Value = TValueReader()(Tmp.data() + sep + 1, Tmp.size() - sep - 1);
+             } else if (noValues) {
+                 RemoveIfLast<TString>(Tmp, '\n');
+                 Key = TKeyReader()(Tmp.data(), Tmp.size());
+                 Value = TValue();
+             } else {
+                 // No separator although a value is required: the line holds
+                 // no record. Returning true here would hand the caller the
+                 // stale Key/Value of the previous line once more.
+                 continue;
+             }
+             return true;
+         }
+         return false;
+     }
+ };
+
+ // One "key <TAB> value1 <TAB> ... <TAB> valueN" input record; the values are
+ // collected into a TVector<T>.
+ template <class TTKey, class TTKeyChar, class T,
+           class TKeyReader = TFromString<TTKey>,
+           class TValueReader = TFromString<T> >
+ struct TVectorRecord {
+     using TKey = TTKey;
+     using TKeyChar = TTKeyChar;
+     using TValue = TVector<T>;
+     TKey Key;
+     TValue Value;
+     TString Tmp;
+
+     // Reads the next usable line from `in` into Key and Value. Returns false
+     // once the input is exhausted. Empty lines are skipped unless
+     // allowEmptyKey is set; noValues is ignored by this record type.
+     bool Load(IInputStream& in, const bool allowEmptyKey, const bool noValues) {
+         Y_UNUSED(noValues);
+         while (in.ReadLine(Tmp)) {
+             if (Tmp.empty() && !allowEmptyKey) {
+                 continue;
+             }
+
+             const size_t firstTab = Tmp.find('\t');
+             if (firstTab == TString::npos) {
+                 // Key only: the record gets an empty value list.
+                 RemoveIfLast<TString>(Tmp, '\n');
+                 Key = TKeyReader()(Tmp.data(), Tmp.size());
+                 Value = TValue();
+                 return true;
+             }
+
+             Key = TKeyReader()(Tmp.data(), firstTab);
+             Value = TValue();
+             // Walk the TAB-separated fields after the key; `tab` is the
+             // position of the separator that precedes the current field.
+             for (size_t tab = firstTab; tab != Tmp.size();) {
+                 size_t fieldEnd = Tmp.find('\t', tab + 1);
+                 if (fieldEnd == TString::npos) {
+                     fieldEnd = Tmp.size();
+                 }
+
+                 const size_t fieldBegin = tab + 1;
+                 if (fieldBegin != fieldEnd) { // empty fields are dropped
+                     Value.push_back(TValueReader()(Tmp.data() + fieldBegin, fieldEnd - fieldBegin));
+                 }
+                 tab = fieldEnd;
+             }
+             return true;
+         }
+         return false;
+     }
+ };
+
+ // Packer that stores a vector of a fixed, externally supplied length as the
+ // raw bytes of its elements, with no length prefix.
+ template <typename TVectorType>
+ class TFixedArrayAsIsPacker {
+     using TItem = typename TVectorType::value_type;
+
+ public:
+     TFixedArrayAsIsPacker()
+         : TFixedArrayAsIsPacker(0)
+     {
+     }
+     explicit TFixedArrayAsIsPacker(size_t arraySize)
+         : ArraySize(arraySize)
+         , SizeOfValue(arraySize * sizeof(TItem))
+     {
+     }
+     // Rebuilds `t` from ArraySize elements read at `p`.
+     void UnpackLeaf(const char* p, TVectorType& t) const {
+         const TItem* const first = reinterpret_cast<const TItem*>(p);
+         const TItem* const last = first + ArraySize;
+         t.assign(first, last);
+     }
+     // Copies the elements of `data` into `buffer`.
+     void PackLeaf(char* buffer, const TVectorType& data, size_t computedSize) const {
+         Y_ASSERT(computedSize == SizeOfValue && data.size() == ArraySize);
+         memcpy(buffer, data.data(), computedSize);
+     }
+     // The packed size does not depend on the content of `data`.
+     size_t MeasureLeaf(const TVectorType& data) const {
+         Y_UNUSED(data);
+         Y_ASSERT(data.size() == ArraySize);
+         return SizeOfValue;
+     }
+     size_t SkipLeaf(const char*) const {
+         return SizeOfValue;
+     }
+
+ private:
+     size_t ArraySize;   // number of elements in every value
+     size_t SizeOfValue; // ArraySize * sizeof(element), in bytes
+ };
+
+#ifndef CATBOOST_OPENSOURCE
+ // Key reader that recodes the given bytes from CODES_UTF8 to CODES_YANDEX.
+ struct TUTF8ToYandexRecoder {
+     TString operator()(const char* s, size_t len) {
+         const TString utf8(s, len);
+         return Recode(CODES_UTF8, CODES_YANDEX, utf8);
+     }
+ };
+#endif
+
+ // Key reader that decodes the given UTF-8 bytes into a TUtf16String.
+ struct TUTF8ToWideRecoder {
+     TUtf16String operator()(const char* s, size_t len) {
+         TUtf16String wide = UTF8ToWide(s, len);
+         return wide;
+     }
+ };
+} // namespace
+
+// Builds a trie from the records read from `in` and writes it to o.Triefile.
+//
+// Depending on the options the freshly built trie is additionally minimized
+// and/or converted to the fast layout before it reaches the file. With
+// CTBF_VERBOSE set, statistics are printed to Cerr. Always returns 0.
+template <class TRecord, class TPacker>
+static int ProcessFile(IInputStream& in, const TOptions& o, const TPacker& packer) {
+    // The output file is opened before any input is consumed.
+    TFixedBufferFileOutput out(o.Triefile);
+    using TKeyChar = typename TRecord::TKeyChar;
+    using TValue = typename TRecord::TValue;
+    using TBuilder = TCompactTrieBuilder<TKeyChar, TValue, TPacker>;
+
+    const bool verbose = o.Flags & CTBF_VERBOSE;
+
+    THolder<TBuilder> builder(new TBuilder(o.Flags, packer));
+
+    for (TRecord record; record.Load(in, o.AllowEmptyKey, o.NoValues);) {
+        builder->Add(record.Key.data(), record.Key.size(), record.Value);
+    }
+
+    if (verbose) {
+        Cerr << Endl;
+        Cerr << "Entries: " << builder->GetEntryCount() << Endl;
+        Cerr << "Tree nodes: " << builder->GetNodeCount() << Endl;
+    }
+
+    // When a fast layout is requested the trie first goes into a memory
+    // buffer, which is then re-laid-out into the file; otherwise it is
+    // written to the file directly.
+    TBufferOutput layoutInput;
+    IOutputStream* const sink = o.FastLayout ? static_cast<IOutputStream*>(&layoutInput) : &out;
+
+    if (!o.Minimized) {
+        const size_t savedLength = builder->Save(*sink);
+        if (verbose) {
+            Cerr << "Data length: " << savedLength << Endl;
+        }
+    } else {
+        TBufferOutput unminimized;
+        const size_t rawLength = builder->Save(unminimized);
+        if (verbose) {
+            Cerr << "Data length (before compression): " << rawLength << Endl;
+        }
+        // The builder is released before the minimizer runs.
+        builder.Destroy();
+
+        const size_t minimizedLength = CompactTrieMinimize(*sink, unminimized.Buffer().Data(),
+                                                           unminimized.Buffer().Size(), verbose, packer);
+        if (verbose) {
+            Cerr << "Data length (minimized): " << minimizedLength << Endl;
+        }
+    }
+
+    if (o.FastLayout) {
+        builder.Destroy();
+        const size_t layoutLength = CompactTrieMakeFastLayout(out, layoutInput.Buffer().Data(),
+                                                              layoutInput.Buffer().Size(), verbose, packer);
+        if (verbose) {
+            Cerr << "Data length (fast layout): " << layoutLength << Endl;
+        }
+    }
+
+    return 0;
+}
+
+// Checks the trie stored in o.Triefile against the records of o.Infile.
+//
+// Every input record must be found in the trie (and, unless o.NoValues is
+// set, carry the same value), and iterating the trie must visit exactly as
+// many entries as there were input records. Problems are reported on Cerr.
+// Returns 0 when the check passes and 1 otherwise.
+template <class TRecord, class TPacker>
+static int VerifyFile(const TOptions& o, const TPacker& packer) {
+    using TKeyChar = typename TRecord::TKeyChar;
+    using TValue = typename TRecord::TValue;
+    using TTrie = TCompactTrie<TKeyChar, TValue, TPacker>;
+
+    TMappedFile mapping(o.Triefile);
+    TTrie trie((const char*)mapping.getData(), mapping.getSize(), packer);
+
+    TFileInput source(o.Infile);
+    size_t pending = 0; // input records seen minus trie entries visited
+    int status = 0;
+    for (TRecord record; record.Load(source, o.AllowEmptyKey, o.NoValues);) {
+        ++pending;
+        TValue stored;
+
+        const bool found = trie.Find(record.Key.data(), record.Key.size(), &stored);
+        if (!found) {
+            Cerr << "Trie check failed on key #" << pending << "\"" << record.Key << "\": no key present" << Endl;
+            status = 1;
+        } else if (!o.NoValues && stored != record.Value) {
+            Cerr << "Trie check failed on key #" << pending << "\"" << record.Key << "\": value mismatch" << Endl;
+            status = 1;
+        }
+    }
+
+    // Count the trie entries back down; anything left over is a mismatch.
+    for (auto entry = trie.Begin(); entry != trie.End(); ++entry) {
+        --pending;
+    }
+
+    if (pending != 0) {
+        Cerr << "Broken iteration: entry count mismatch" << Endl;
+        status = 1;
+    }
+
+    if (status == 0 && (o.Flags & CTBF_VERBOSE)) {
+        Cerr << "Trie check successful" << Endl;
+    }
+    return status;
+}
+
+// Opens the input named by o.Infile ("-" means standard input, read through
+// a buffering wrapper) and compiles it with ProcessFile().
+template <class TRecord, class TPacker>
+static int SelectInput(const TOptions& o, const TPacker& packer) {
+    const bool fromStdin = ("-"sv == o.Infile);
+    if (!fromStdin) {
+        TFileInput file(o.Infile);
+        return ProcessFile<TRecord>(file, o, packer);
+    }
+
+    TBufferedInput bufferedStdin{&Cin};
+    return ProcessFile<TRecord>(bufferedStdin, o, packer);
+}
+
+// Compiles the trie and, when requested with o.Verify, checks the result.
+//
+// Returns the code of the compilation step if it failed, otherwise the code
+// of the verification step (or 0 when no verification ran).
+template <class TRecord, class TPacker>
+static int DoMain(const TOptions& o, const TPacker& packer) {
+    const int retcode = SelectInput<TRecord>(o, packer);
+    if (retcode != 0 || !o.Verify || o.Triefile.empty()) {
+        return retcode;
+    }
+
+    // Verification re-reads the input by file name. Standard input has been
+    // consumed by the compilation step already, and "-" is not a real file
+    // name, so the check can only run with an explicit input file.
+    if ("-"sv == o.Infile) {
+        Cerr << "Trie check skipped: it works only with an explicit input file name" << Endl;
+        return retcode;
+    }
+
+    return VerifyFile<TRecord>(o, packer);
+}
+
+// TRecord - nested template parameter
+// Selects the key type and key reader from the -w / -u options and runs
+// DoMain() with the matching TRecord instantiation:
+//   o.Wide    - TUtf16String keys decoded from UTF-8;
+//   o.Unicode - TString keys recoded from UTF-8 to the Yandex encoding;
+//   otherwise - TString keys taken as they are.
+template <class TValue,
+          template <class TKey, class TKeyChar, class TValueOther,
+                    class TKeyReader,
+                    class TValueReader> class TRecord,
+          class TPacker>
+static int ProcessInput(const TOptions& o, const TPacker& packer) {
+    using TValueParser = TFromString<TValue>;
+
+    if (o.Wide) {
+        return DoMain<TRecord<TUtf16String, TChar, TValue, TUTF8ToWideRecoder, TValueParser>>(o, packer);
+    }
+    if (!o.Unicode) {
+        return DoMain<TRecord<TString, char, TValue, TFromString<TString>, TValueParser>>(o, packer);
+    }
+#ifndef CATBOOST_OPENSOURCE
+    return DoMain<TRecord<TString, char, TValue, TUTF8ToYandexRecoder, TValueParser>>(o, packer);
+#else
+    Y_FAIL("Yandex encoding is not supported in CATBOOST_OPENSOURCE mode");
+#endif
+}
+
+// Selects the record type and the packer for values of TItemType.
+// Precedence: fixed-size arrays (-S), then variable-length arrays (-a), then
+// the as-is packer (-P), then the default compact trie packer.
+template <class TItemType>
+static int ProcessInput(const TOptions& o) {
+    using TItems = TVector<TItemType>;
+
+    if (o.ArraySize > 0) {
+        return ProcessInput<TItemType, TVectorRecord>(o, TFixedArrayAsIsPacker<TItems>(o.ArraySize));
+    }
+    if (o.Vector) {
+        return ProcessInput<TItemType, TVectorRecord>(o, TCompactTriePacker<TItems>());
+    }
+    if (o.UseAsIsParker) {
+        return ProcessInput<TItemType, TRecord>(o, TAsIsPacker<TItemType>());
+    }
+    return ProcessInput<TItemType, TRecord>(o, TCompactTriePacker<TItemType>());
+}
+
+// Tool entry point: parses the options, dispatches on the value type named
+// by -t and runs the compilation.
+//
+// Returns the code produced by the compilation/verification steps, 2 when a
+// std::exception escapes (its message is printed to Cerr) and 3 for any
+// other exception.
+static int Main(const int argc, const char* argv[])
+try {
+#ifdef WIN32
+    _CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF);
+    ::SetConsoleCP(1251);
+    ::SetConsoleOutputCP(1251);
+#endif // WIN32
+    const TOptions options = ParseOptions(argc, argv);
+
+    // Key-only tries ignore -t: values are ui64 placeholders packed by TNullPacker.
+    if (options.NoValues) {
+        return ProcessInput<ui64, TRecord>(options, TNullPacker<ui64>());
+    }
+
+// Runs the compilation for `valueType` when its spelling matches the -t argument.
+#define DISPATCH_ON_VALUE_TYPE(valueType)         \
+    if (options.ValueType == #valueType) {        \
+        return ProcessInput<valueType>(options);  \
+    }
+    DISPATCH_ON_VALUE_TYPE(ui16)
+    DISPATCH_ON_VALUE_TYPE(i16)
+    DISPATCH_ON_VALUE_TYPE(ui32)
+    DISPATCH_ON_VALUE_TYPE(i32)
+    DISPATCH_ON_VALUE_TYPE(ui64)
+    DISPATCH_ON_VALUE_TYPE(i64)
+    DISPATCH_ON_VALUE_TYPE(bool)
+    DISPATCH_ON_VALUE_TYPE(float)
+    DISPATCH_ON_VALUE_TYPE(double)
+    DISPATCH_ON_VALUE_TYPE(TString)
+    DISPATCH_ON_VALUE_TYPE(TUtf16String)
+#undef DISPATCH_ON_VALUE_TYPE
+
+    ythrow yexception() << "unknown type for -t option: " << options.ValueType;
+} catch (const std::exception& e) {
+    Cerr << "Exception: " << e.what() << Endl;
+    return 2;
+} catch (...) {
+    Cerr << "Unknown exception!\n";
+    return 3;
+}
+
+// Public entry point of the library: forwards to the file-local Main().
+int NTrieOps::MainCompile(const int argc, const char* argv[]) {
+    const int exitCode = ::Main(argc, argv);
+    return exitCode;
+}
diff --git a/tools/triecompiler/lib/main.h b/tools/triecompiler/lib/main.h
new file mode 100644
index 0000000000..34bd02be1e
--- /dev/null
+++ b/tools/triecompiler/lib/main.h
@@ -0,0 +1,5 @@
+#pragma once
+
+// Entry points of the trie compiler, exposed as a library.
+namespace NTrieOps {
+ // Runs the trie compiler with the given command-line arguments.
+ // The returned value is what tools/triecompiler/main.cpp returns from main().
+ int MainCompile(const int argc, const char* argv[]);
+}
diff --git a/tools/triecompiler/main.cpp b/tools/triecompiler/main.cpp
new file mode 100644
index 0000000000..4c22a2bd8a
--- /dev/null
+++ b/tools/triecompiler/main.cpp
@@ -0,0 +1,5 @@
+#include <tools/triecompiler/lib/main.h>
+
+// Executable wrapper around the trie compiler library.
+int main(const int argc, const char* argv[]) {
+    const int exitCode = NTrieOps::MainCompile(argc, argv);
+    return exitCode;
+}