diff options
author | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
---|---|---|
committer | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
commit | 22f8ae0e3f5d68b92aecccdf96c1d841a0334311 (patch) | |
tree | bffa27765faf54126ad44bcafa89fadecb7a73d7 /tools | |
parent | 332b99e2173f0425444abb759eebcb2fafaa9209 (diff) | |
download | ydb-22f8ae0e3f5d68b92aecccdf96c1d841a0334311.tar.gz |
validate canons without yatest_common
Diffstat (limited to 'tools')
-rw-r--r-- | tools/event2cpp/proto_events.cpp | 893 | ||||
-rw-r--r-- | tools/event2cpp/proto_events.h | 20 | ||||
-rw-r--r-- | tools/go_test_miner/main.go | 168 | ||||
-rw-r--r-- | tools/triecompiler/lib/main.cpp | 473 | ||||
-rw-r--r-- | tools/triecompiler/lib/main.h | 5 | ||||
-rw-r--r-- | tools/triecompiler/main.cpp | 5 |
6 files changed, 1564 insertions, 0 deletions
diff --git a/tools/event2cpp/proto_events.cpp b/tools/event2cpp/proto_events.cpp new file mode 100644 index 0000000000..66e9296d2c --- /dev/null +++ b/tools/event2cpp/proto_events.cpp @@ -0,0 +1,893 @@ +#include <google/protobuf/compiler/cpp/cpp_helpers.h> +#include <google/protobuf/io/zero_copy_stream.h> +#include <google/protobuf/io/printer.h> +#include <google/protobuf/stubs/strutil.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/descriptor.pb.h> + +#include <util/string/cast.h> +#include <util/generic/singleton.h> +#include <util/generic/yexception.h> + +#include <library/cpp/eventlog/proto/events_extension.pb.h> + +#include "proto_events.h" + +namespace NProtoBuf::NCompiler::NPlugins { + +namespace NInternal { + using namespace google::protobuf; + using namespace google::protobuf::compiler; + using namespace google::protobuf::compiler::cpp; + + typedef std::map<TProtoStringType, TProtoStringType> TVariables; + + void CheckMessageId(size_t id, const TProtoStringType& name) { + typedef std::map<size_t, TProtoStringType> TMessageIds; + TMessageIds* ids = Singleton<TMessageIds>(); + TMessageIds::const_iterator it = ids->find(id); + + if (it != ids->end()) { + throw yexception() << "Duplicate message_id = " << id + << " in messages " << name + << " and " << it->second << Endl; + } + + (*ids)[id] = name; + } + + void SetCommonFieldVariables(const FieldDescriptor* descriptor, TVariables* variables) { + (*variables)["rname"] = descriptor->name(); + (*variables)["name"] = FieldName(descriptor); + } + + TProtoStringType HeaderFileName(const FileDescriptor* file) { + TProtoStringType basename = cpp::StripProto(file->name()); + + return basename.append(".pb.h"); + } + + TProtoStringType SourceFileName(const FileDescriptor* file) { + TProtoStringType basename = cpp::StripProto(file->name()); + + return basename.append(".pb.cc"); + } + + void GeneratePrintingCycle(TVariables vars, TProtoStringType printTemplate, io::Printer* printer) { + printer->Print("\n{\n"); + printer->Indent(); + printer->Print(vars, + "NProtoBuf::$repeated_field_type$< $type$ >::const_iterator b = $name$().begin();\n" + "NProtoBuf::$repeated_field_type$< $type$ >::const_iterator e = $name$().end();\n\n"); + printer->Print("output << \"[\";\n"); + printer->Print("if (b != e) {\n"); + vars["obj"] = "(*b++)"; + printer->Print(vars, printTemplate.c_str()); + printer->Print(";\n"); + printer->Print(vars, + "for (NProtoBuf::$repeated_field_type$< $type$ >::const_iterator it = b; it != e; ++it) {\n"); + printer->Indent(); + printer->Print("output << \",\";\n"); + vars["obj"] = "(*it)"; + printer->Print(vars, printTemplate.c_str()); + printer->Print(";\n"); + printer->Outdent(); + printer->Print("}\n}\n"); + printer->Print("output << \"]\";\n"); + printer->Outdent(); + printer->Print("}\n"); + } + + class TFieldExtGenerator { + public: + TFieldExtGenerator(const FieldDescriptor* field) + : Descriptor_(field) + { + SetCommonFieldVariables(Descriptor_, &Variables_); + } + + virtual ~TFieldExtGenerator() { + } + + virtual bool NeedProtobufMessageFieldPrinter() const { + return false; + } + + virtual void GenerateCtorArgument(io::Printer* printer) = 0; + virtual void GenerateInitializer(io::Printer* printer, const TString& prefix) = 0; + virtual void GeneratePrintingCode(io::Printer* printer) = 0; + protected: + const FieldDescriptor* Descriptor_; + TVariables Variables_; + }; + + class TMessageFieldExtGenerator: public TFieldExtGenerator { + public: + TMessageFieldExtGenerator(const FieldDescriptor* field) + : TFieldExtGenerator(field) + { + Variables_["type"] = ClassName(Descriptor_->message_type(), true); + Variables_["has_print_function"] = Descriptor_->message_type()->options().HasExtension(message_id) ? "true" : "false"; + } + + bool NeedProtobufMessageFieldPrinter() const override { + return true; + } + + void GenerateCtorArgument(io::Printer* printer) override { + printer->Print(Variables_, + "const $type$& arg_$name$"); + } + + void GenerateInitializer(io::Printer* printer, const TString& prefix) override { + Variables_["prefix"] = prefix; + printer->Print(Variables_, + "$prefix$mutable_$name$()->CopyFrom(arg_$name$);\n"); + } + + void GeneratePrintingCode(io::Printer* printer) override { + printer->Print("output << \"{\";\n"); + printer->Print(Variables_, + "protobufMessageFieldPrinter.PrintProtobufMessageFieldToOutput<$type$, $has_print_function$>($name$(), escapedOutput);\n"); + printer->Print("output << \"}\";\n"); + } + }; + + class TMapFieldExtGenerator: public TFieldExtGenerator { + public: + TMapFieldExtGenerator(const FieldDescriptor* field) + : TFieldExtGenerator(field) + { + } + + void GenerateCtorArgument(io::Printer* /* printer */) override { + } + + void GenerateInitializer(io::Printer* /* printer */, const TString& /* prefix */) override { + } + + void GeneratePrintingCode(io::Printer* /* printer */) override { + } + }; + + class TRepeatedMessageFieldExtGenerator: public TFieldExtGenerator { + public: + TRepeatedMessageFieldExtGenerator(const FieldDescriptor* field) + : TFieldExtGenerator(field) + { + Variables_["type"] = ClassName(Descriptor_->message_type(), true); + Variables_["repeated_field_type"] = "RepeatedPtrField"; + Variables_["has_print_function"] = Descriptor_->message_type()->options().HasExtension(message_id) ? "true" : "false"; + } + + bool NeedProtobufMessageFieldPrinter() const override { + return true; + } + + void GenerateCtorArgument(io::Printer* printer) override { + printer->Print(Variables_, + "const $type$& arg_$name$"); + } + + void GenerateInitializer(io::Printer* printer, const TString& prefix) override { + Variables_["prefix"] = prefix; + printer->Print(Variables_, + "$prefix$add_$name$()->CopyFrom(arg_$name$);\n"); + } + void GeneratePrintingCode(io::Printer* printer) override { + GeneratePrintingCycle(Variables_, "protobufMessageFieldPrinter.PrintProtobufMessageFieldToOutput<$type$, $has_print_function$>($obj$, escapedOutput)", printer); + } + }; + + class TStringFieldExtGenerator: public TFieldExtGenerator { + public: + TStringFieldExtGenerator(const FieldDescriptor* field) + : TFieldExtGenerator(field) + { + Variables_["pointer_type"] = Descriptor_->type() == FieldDescriptor::TYPE_BYTES ? "void" : "char"; + Variables_["type"] = "TProtoStringType"; + } + + void GenerateCtorArgument(io::Printer* printer) override { + printer->Print(Variables_, + (Descriptor_->type() == FieldDescriptor::TYPE_BYTES ? + "const $pointer_type$* arg_$name$, size_t arg_$name$_size" : "const $type$& arg_$name$") + ); + } + + void GenerateInitializer(io::Printer* printer, const TString& prefix) override { + Variables_["prefix"] = prefix; + printer->Print( + Variables_, + Descriptor_->type() == FieldDescriptor::TYPE_BYTES ? + "$prefix$set_$name$(arg_$name$, arg_$name$_size);\n" : + "$prefix$set_$name$(arg_$name$);\n" + ); + } + + void GeneratePrintingCode(io::Printer* printer) override { + Repr::ReprType fmt = Repr::none; + + if (Descriptor_->options().HasExtension(repr)) { + fmt = Descriptor_->options().GetExtension(repr); + } + + switch (fmt) { + case Repr::as_base64: + printer->Print(Variables_, "NProtoBuf::PrintAsBase64($name$(), output);\n"); + break; + + case Repr::none: + /* TODO: proper error handling?*/ + default: + printer->Print(Variables_, "escapedOutput << $name$();\n"); + break; + } + } + }; + + class TRepeatedStringFieldExtGenerator: public TFieldExtGenerator { + public: + TRepeatedStringFieldExtGenerator(const FieldDescriptor* field) + : TFieldExtGenerator(field) + { + Variables_["pointer_type"] = Descriptor_->type() == FieldDescriptor::TYPE_BYTES ? "void" : "char"; + Variables_["type"] = "TProtoStringType"; + Variables_["repeated_field_type"] = "RepeatedPtrField"; + } + + void GenerateCtorArgument(io::Printer* printer) override { + printer->Print(Variables_, + (Descriptor_->type() == FieldDescriptor::TYPE_BYTES ? + "const $pointer_type$* arg_$name$, size_t arg_$name$_size": "const $type$& arg_$name$") + ); + } + + void GenerateInitializer(io::Printer* printer, const TString& prefix) override { + Variables_["prefix"] = prefix; + printer->Print( + Variables_, + Descriptor_->type() == FieldDescriptor::TYPE_BYTES ? + "$prefix$add_$name$(arg_$name$, arg_$name$_size);\n" : + "$prefix$add_$name$(arg_$name$);\n" + ); + } + void GeneratePrintingCode(io::Printer* printer) override { + GeneratePrintingCycle(Variables_, "output << \"\\\"\" << $obj$ << \"\\\"\"", printer); + } + }; + + class TEnumFieldExtGenerator: public TFieldExtGenerator { + public: + TEnumFieldExtGenerator(const FieldDescriptor* field) + : TFieldExtGenerator(field) + { + Variables_["type"] = ClassName(Descriptor_->enum_type(), true); + } + + void GenerateCtorArgument(io::Printer* printer) override { + printer->Print(Variables_, + "$type$ arg_$name$"); + } + + void GenerateInitializer(io::Printer* printer, const TString& prefix) override { + Variables_["prefix"] = prefix; + printer->Print(Variables_, + "$prefix$set_$name$(arg_$name$);\n"); + } + + void GeneratePrintingCode(io::Printer* printer) override { + printer->Print(Variables_, + "output << $type$_Name($name$());\n"); + } + }; + + class TRepeatedEnumFieldExtGenerator: public TFieldExtGenerator { + public: + TRepeatedEnumFieldExtGenerator(const FieldDescriptor* field) + : TFieldExtGenerator(field) + { + Variables_["type"] = ClassName(Descriptor_->enum_type(), true); + Variables_["repeated_field_type"] = "RepeatedField"; + } + + void GenerateCtorArgument(io::Printer* printer) override { + printer->Print(Variables_, + "$type$ arg_$name$"); + } + + void GenerateInitializer(io::Printer* printer, const TString& prefix) override { + Variables_["prefix"] = prefix; + printer->Print(Variables_, + "$prefix$add_$name$(arg_$name$);\n"); + } + + void GeneratePrintingCode(io::Printer* printer) override { + TStringStream pattern; + + TProtoStringType type = Variables_["type"]; + pattern << "output << " << type << "_Name(" << type << "($obj$))"; + Variables_["type"] = "int"; + GeneratePrintingCycle(Variables_, pattern.Str(), printer); + Variables_["type"] = type; + } + }; + + class TPrimitiveFieldExtGenerator: public TFieldExtGenerator { + public: + TPrimitiveFieldExtGenerator(const FieldDescriptor* field) + : TFieldExtGenerator(field) + { + Variables_["type"] = PrimitiveTypeName(Descriptor_->cpp_type()); + } + + void GenerateCtorArgument(io::Printer* printer) override { + printer->Print(Variables_, + "$type$ arg_$name$"); + } + + void GenerateInitializer(io::Printer* printer, const TString& prefix) override { + Variables_["prefix"] = prefix; + printer->Print(Variables_, + "$prefix$set_$name$(arg_$name$);\n"); + } + + void GeneratePrintingCode(io::Printer* printer) override { + Repr::ReprType fmt = Repr::none; + + if (Descriptor_->options().HasExtension(repr)) { + fmt = Descriptor_->options().GetExtension(repr); + } + + switch (fmt) { + case Repr::as_bytes: + printer->Print(Variables_, "NProtoBuf::PrintAsBytes($name$(), output);\n"); + break; + + case Repr::as_hex: + printer->Print(Variables_, "NProtoBuf::PrintAsHex($name$(), output);\n"); + break; + + case Repr::none: + /* TODO: proper error handling? */ + default: + printer->Print(Variables_, "output << $name$();\n"); + break; + } + } + }; + + class TRepeatedPrimitiveFieldExtGenerator: public TFieldExtGenerator { + public: + TRepeatedPrimitiveFieldExtGenerator(const FieldDescriptor* field) + : TFieldExtGenerator(field) + { + Variables_["type"] = PrimitiveTypeName(Descriptor_->cpp_type()); + Variables_["repeated_field_type"] = "RepeatedField"; + } + + void GenerateCtorArgument(io::Printer* printer) override { + printer->Print(Variables_, + "$type$ arg_$name$"); + } + + void GenerateInitializer(io::Printer* printer, const TString& prefix) override { + Variables_["prefix"] = prefix; + printer->Print(Variables_, + "$prefix$add_$name$(arg_$name$);\n"); + } + + void GeneratePrintingCode(io::Printer* printer) override { + GeneratePrintingCycle(Variables_, "output << $obj$", printer); + } + }; + + std::unique_ptr<TFieldExtGenerator> MakeGenerator(const FieldDescriptor* field) { + if (field->is_map()) { + return std::make_unique<TMapFieldExtGenerator>(field); + } else if (field->is_repeated()) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_MESSAGE: + return std::make_unique<TRepeatedMessageFieldExtGenerator>(field); + case FieldDescriptor::CPPTYPE_STRING: + switch (field->options().ctype()) { + default: // RepeatedStringFieldExtGenerator handles unknown ctypes. + case FieldOptions::STRING: + return std::make_unique<TRepeatedStringFieldExtGenerator>(field); + } + case FieldDescriptor::CPPTYPE_ENUM: + return std::make_unique<TRepeatedEnumFieldExtGenerator>(field); + default: + return std::make_unique<TRepeatedPrimitiveFieldExtGenerator>(field); + } + } else { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_MESSAGE: + return std::make_unique<TMessageFieldExtGenerator>(field); + case FieldDescriptor::CPPTYPE_STRING: + switch (field->options().ctype()) { + default: // StringFieldGenerator handles unknown ctypes. + case FieldOptions::STRING: + return std::make_unique<TStringFieldExtGenerator>(field); + } + case FieldDescriptor::CPPTYPE_ENUM: + return std::make_unique<TEnumFieldExtGenerator>(field); + default: + return std::make_unique<TPrimitiveFieldExtGenerator>(field); + } + } + } + + class TMessageExtGenerator { + public: + TMessageExtGenerator(const Descriptor* descriptor, OutputDirectory* outputDirectory) + : Descriptor_(descriptor) + , HasMessageId_(Descriptor_->options().HasExtension(message_id)) + , ClassName_(ClassName(Descriptor_, false)) + , OutputDirectory_(outputDirectory) + , HasGeneratorWithProtobufMessageFieldPrinter_(false) + , CanGenerateSpecialConstructor_(false) + { + NestedGenerators_.reserve(descriptor->nested_type_count()); + for (int i = 0; i < descriptor->nested_type_count(); i++) { + NestedGenerators_.emplace_back(descriptor->nested_type(i), OutputDirectory_); + } + + if (HasMessageId_) { + FieldGenerators_.reserve(descriptor->field_count()); + for (int i = 0; i < descriptor->field_count(); i++) { + FieldGenerators_.emplace_back(MakeGenerator(descriptor->field(i))); + HasGeneratorWithProtobufMessageFieldPrinter_ |= FieldGenerators_.back()->NeedProtobufMessageFieldPrinter(); + } + } + + { + size_t intFieldCount = 0; + size_t mapFieldCount = 0; + size_t nonMapFieldCount = 0; + for (int i = 0; i < Descriptor_->field_count(); ++i) { + const FieldDescriptor* field = Descriptor_->field(i); + if (field->is_map()) { + ++mapFieldCount; + } else { + ++nonMapFieldCount; + } + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + case FieldDescriptor::CPPTYPE_INT64: + case FieldDescriptor::CPPTYPE_UINT32: + case FieldDescriptor::CPPTYPE_UINT64: + ++intFieldCount; + break; + default: + break; + } + } + + CanGenerateSpecialConstructor_ = ( + // Certain field combinations would result in ambiguous constructor generation. + // Do not generate special contructor for such combinations. + (intFieldCount != nonMapFieldCount || nonMapFieldCount != 2) && + + // Generate special contructor only if there is at least one non-map Field. + nonMapFieldCount > 0 + ); + } + + } + + void GenerateClassDefinitionExtension() { + if (Descriptor_->options().HasExtension(realm_name) || Descriptor_->options().HasExtension(message_id)) { + GeneratePseudonim(); + } + + if (!HasMessageId_) { + return; + } + + CheckMessageId(Descriptor_->options().GetExtension(message_id), ClassName_); + + TProtoStringType fileName = HeaderFileName(Descriptor_->file()); + TProtoStringType scope = "class_scope:" + Descriptor_->full_name(); + std::unique_ptr<io::ZeroCopyOutputStream> output( + OutputDirectory_->OpenForInsert(fileName, scope)); + io::Printer printer(output.get(), '$'); + + printer.Print("//Yandex events extension.\n"); + GenerateHeaderImpl(&printer); + + for (auto& nestedGenerator: NestedGenerators_) { + nestedGenerator.GenerateClassDefinitionExtension(); + } + } + + bool GenerateClassExtension() { + TProtoStringType fileName = SourceFileName(Descriptor_->file()); + std::unique_ptr<io::ZeroCopyOutputStream> output( + OutputDirectory_->OpenForInsert(fileName, "namespace_scope")); + io::Printer printer(output.get(), '$'); + + bool hasEventExtension = GenerateSourceImpl(&printer); + + for (auto& nestedGenerator: NestedGenerators_) { + hasEventExtension |= nestedGenerator.GenerateSourceImpl(&printer); + } + + return hasEventExtension; + } + + void GenerateRegistration(io::Printer* printer) { + if (!HasMessageId_) { + return; + } + + TVariables vars; + vars["classname"] = ClassName_; + + printer->Print(vars, "NProtoBuf::TEventFactory::Instance()->RegisterEvent($classname$::descriptor()->options().GetExtension(message_id), factory->GetPrototype($classname$::descriptor()), $classname$::Print);\n"); + } + + private: + void GenerateHeaderImpl(io::Printer* printer) { + TVariables vars; + TProtoStringType mId(ToString(Descriptor_->options().GetExtension(message_id))); + vars["classname"] = ClassName_; + vars["messageid"] = mId.data(); + vars["superclass"] = SuperClassName(Descriptor_, Options{}); + + printer->Print(vars, + "enum {ID = $messageid$};\n\n"); + + { + /* + * Unconditionally generate FromFields() factory method, + * so it could be used in template code + */ + printer->Print(vars, "static $classname$ FromFields(\n"); + GenerateCtorArgs(printer); + printer->Print(");\n"); + } + + if (CanGenerateSpecialConstructor_) { + printer->Print(vars, "$classname$(\n"); + GenerateCtorArgs(printer); + printer->Print(");\n"); + } + + { + printer->Print("void Print(IOutputStream& output, EFieldOutputFlags outputFlags = {}) const;\n"); + printer->Print("static void Print(const google::protobuf::Message* ev, IOutputStream& output, EFieldOutputFlags outputFlags = {});\n"); + } + } + + void GeneratePseudonim() { + TProtoStringType fileName = HeaderFileName(Descriptor_->file()); + std::unique_ptr<io::ZeroCopyOutputStream> output( + OutputDirectory_->OpenForInsert(fileName, "namespace_scope")); + io::Printer printer(output.get(), '$'); + + std::vector<TProtoStringType> realm_parts; + + if (Descriptor_->options().HasExtension(realm_name)) { + SplitStringUsing(Descriptor_->options().GetExtension(realm_name), ".", &realm_parts); + } + + if (realm_parts.size() > 0) printer.Print("\n"); + + for (size_t i = 0; i < realm_parts.size(); ++i) { + printer.Print("namespace $part$ {\n", + "part", realm_parts[i]); + } + + printer.Print("typedef $fullclassname$ T$classname$;\n", + "fullclassname", FullClassName(Descriptor_), + "classname", ClassName_); + + for (size_t i = realm_parts.size(); i > 0; --i) { + printer.Print("} // namespace $part$\n", + "part", realm_parts[i - 1]); + } + } + + TProtoStringType FullClassName(const Descriptor* descriptor) { + TProtoStringType result; + std::vector<TProtoStringType> parts; + + SplitStringUsing(descriptor->file()->package(), ".", &parts); + for (size_t i = 0; i < parts.size(); ++i) { + result += "::" + parts[i]; + } + + result += "::" + ClassName(descriptor, false); + + return result; + } + + bool GenerateSourceImpl(io::Printer* printer) { + if (!HasMessageId_) { + return false; + } + + TVariables vars; + vars["classname"] = ClassName_; + + { + // Generate static $classname$::FromFields impl. + printer->Print(vars, "$classname$ $classname$::FromFields(\n"); + GenerateCtorArgs(printer); + printer->Print(")\n"); + + printer->Print("{\n"); + + printer->Indent(); + printer->Print(vars, "$classname$ result;\n"); + GenerateFieldInitializers(printer, /* prefix = */ "result."); + printer->Print("return result;\n"); + printer->Outdent(); + + printer->Print("}\n\n"); + } + + if (CanGenerateSpecialConstructor_) { + // Generate special constructor impl. + printer->Print(vars, "$classname$::$classname$(\n"); + GenerateCtorArgs(printer); + printer->Print(")\n"); + + printer->Print("{\n"); + + printer->Indent(); + printer->Print("SharedCtor();\n"); + GenerateFieldInitializers(printer, /* prefix = */ ""); + printer->Outdent(); + + printer->Print("}\n\n"); + } + + { + // Generate $classname$::Print impl. + const size_t fieldCount = Descriptor_->field_count(); + if (fieldCount > 0) { + printer->Print(vars, + "void $classname$::Print(IOutputStream& output, EFieldOutputFlags outputFlags) const {\n"); + printer->Indent(); + printer->Print( + "TEventFieldOutput escapedOutput{output, outputFlags};\n" + "Y_UNUSED(escapedOutput);\n"); + + if (HasGeneratorWithProtobufMessageFieldPrinter_) { + printer->Print( + "TEventProtobufMessageFieldPrinter protobufMessageFieldPrinter(EProtobufMessageFieldPrintMode::DEFAULT);\n"); + } + } else { + printer->Print(vars, + "void $classname$::Print(IOutputStream& output, EFieldOutputFlags) const {\n"); + printer->Indent(); + } + + printer->Print(vars, + "output << \"$classname$\";\n"); + + for (size_t i = 0; i < fieldCount; ++i) { + printer->Print("output << \"\\t\";\n"); + FieldGenerators_[i]->GeneratePrintingCode(printer); + } + + printer->Outdent(); + printer->Print("}\n\n"); + } + + { + // Generate static $classname$::Print impl. + printer->Print(vars, + "void $classname$::Print(const google::protobuf::Message* ev, IOutputStream& output, EFieldOutputFlags outputFlags) {\n"); + printer->Indent(); + printer->Print(vars, + "const $classname$* This(CheckedCast<const $classname$*>(ev));\n"); + printer->Print( + "This->Print(output, outputFlags);\n"); + printer->Outdent(); + printer->Print("}\n\n"); + } + + return true; + } + + void GenerateCtorArgs(io::Printer* printer) { + printer->Indent(); + const size_t fieldCount = Descriptor_->field_count(); + bool isFirst = true; + for (size_t i = 0; i < fieldCount; ++i) { + if (Descriptor_->field(i)->is_map()) { + continue; + } + const char* delimiter = isFirst ? "" : ", "; + isFirst = false; + printer->Print(delimiter); + FieldGenerators_[i]->GenerateCtorArgument(printer); + } + printer->Outdent(); + } + + void GenerateFieldInitializers(io::Printer* printer, const TString& prefix) { + for (auto& fieldGeneratorHolder: FieldGenerators_) { + fieldGeneratorHolder->GenerateInitializer(printer, prefix); + } + } + + private: + const Descriptor* Descriptor_; + const bool HasMessageId_; + TProtoStringType ClassName_; + OutputDirectory* OutputDirectory_; + bool HasGeneratorWithProtobufMessageFieldPrinter_; + bool CanGenerateSpecialConstructor_; + std::vector<std::unique_ptr<TFieldExtGenerator>> FieldGenerators_; + std::vector<TMessageExtGenerator> NestedGenerators_; + }; + + class TFileExtGenerator { + public: + TFileExtGenerator(const FileDescriptor* file, OutputDirectory* output_directory) + : OutputDirectory_(output_directory) + , File_(file) + { + MessageGenerators_.reserve(file->message_type_count()); + for (int i = 0; i < file->message_type_count(); i++) { + MessageGenerators_.emplace_back(file->message_type(i), OutputDirectory_); + } + } + + void GenerateHeaderExtensions() { + TProtoStringType fileName = HeaderFileName(File_); + + std::unique_ptr<io::ZeroCopyOutputStream> output( + OutputDirectory_->OpenForInsert(fileName, "includes")); + io::Printer printer(output.get(), '$'); + + printer.Print("#include <library/cpp/eventlog/event_field_output.h>\n"); + printer.Print("#include <library/cpp/eventlog/event_field_printer.h>\n"); + + for (auto& messageGenerator: MessageGenerators_) { + messageGenerator.GenerateClassDefinitionExtension(); + } + } + + void GenerateSourceExtensions() { + TProtoStringType fileName = SourceFileName(File_); + + std::unique_ptr<io::ZeroCopyOutputStream> output( + OutputDirectory_->OpenForInsert(fileName, "includes")); + io::Printer printer(output.get(), '$'); + printer.Print("#include <google/protobuf/io/printer.h>\n"); + printer.Print("#include <google/protobuf/io/zero_copy_stream_impl_lite.h>\n"); + printer.Print("#include <google/protobuf/stubs/strutil.h>\n"); + printer.Print("#include <library/cpp/eventlog/events_extension.h>\n"); + printer.Print("#include <util/generic/cast.h>\n"); + printer.Print("#include <util/stream/output.h>\n"); + + bool hasEventExtension = false; + + for (auto& messageGenerator: MessageGenerators_) { + hasEventExtension |= messageGenerator.GenerateClassExtension(); + } + + if (hasEventExtension) { + GenerateEventRegistrations(); + } + } + + void GenerateEventRegistrations() { + TVariables vars; + TProtoStringType fileId = FilenameIdentifier(File_->name()); + vars["regfunction"] = "regevent_" + fileId; + vars["regclassname"] = "TRegister_" + fileId; + vars["regvarname"] = "registrator_" + fileId ; + vars["filename"] = File_->name(); + + { + TProtoStringType fileName = SourceFileName(File_); + std::unique_ptr<io::ZeroCopyOutputStream> output( + OutputDirectory_->OpenForInsert(fileName, "namespace_scope")); + io::Printer printer(output.get(), '$'); + + GenerateRegistrationFunction(vars, printer); + GenerateRegistratorDefinition(vars, printer); + } + + { + + TProtoStringType fileName = HeaderFileName(File_); + std::unique_ptr<io::ZeroCopyOutputStream> output( + OutputDirectory_->OpenForInsert(fileName, "namespace_scope")); + io::Printer printer(output.get(), '$'); + GenerateRegistratorDeclaration(vars, printer); + } + } + + void GenerateRegistrationFunction(const TVariables& vars, io::Printer& printer) { + printer.Print(vars, + "void $regfunction$() {\n"); + printer.Indent(); + + printer.Print("google::protobuf::MessageFactory* factory = google::protobuf::MessageFactory::generated_factory();\n\n"); + for (auto& messageGenerator: MessageGenerators_) { + messageGenerator.GenerateRegistration(&printer); + } + printer.Outdent(); + printer.Print("}\n\n"); + } + + void GenerateRegistratorDeclaration(const TVariables& vars, io::Printer& printer) { + printer.Print(vars, "\nclass $regclassname$ {\n"); + printer.Print("public:\n"); + printer.Indent(); + printer.Print(vars, "$regclassname$();\n"); + printer.Outdent(); + printer.Print("private:\n"); + printer.Indent(); + printer.Print("static bool Registered;\n"); + printer.Outdent(); + printer.Print(vars, "};\n"); + printer.Print(vars, "static $regclassname$ $regvarname$;\n\n"); + } + + void GenerateRegistratorDefinition(const TVariables& vars, io::Printer& printer) { + printer.Print(vars, "$regclassname$::$regclassname$() {\n"); + printer.Indent(); + printer.Print("if (!Registered) {\n"); + printer.Indent(); + printer.Print(vars, "NProtoBuf::TEventFactory::Instance()->ScheduleRegistration(&$regfunction$);\n"); + printer.Print("Registered = true;\n"); + printer.Outdent(); + printer.Print("}\n"); + printer.Outdent(); + printer.Print("}\n\n"); + printer.Print(vars, "bool $regclassname$::Registered;\n\n"); + } + private: + OutputDirectory* OutputDirectory_; + const FileDescriptor* File_; + std::vector<TMessageExtGenerator> MessageGenerators_; + }; +} + + bool TProtoEventExtensionGenerator::Generate(const google::protobuf::FileDescriptor* file, + const TProtoStringType& parameter, + google::protobuf::compiler::OutputDirectory* outputDirectory, + TProtoStringType* error) const { + Y_UNUSED(parameter); + Y_UNUSED(error); + + NInternal::TFileExtGenerator fileGenerator(file, outputDirectory); + + // Generate header. + fileGenerator.GenerateHeaderExtensions(); + + // Generate cc file. + fileGenerator.GenerateSourceExtensions(); + + return true; + } + +} // namespace NProtoBuf::NCompiler::NPlugins + +int main(int argc, char* argv[]) { +#ifdef _MSC_VER + // Don't print a silly message or stick a modal dialog box in my face, + // please. + _set_abort_behavior(0u, ~0u); +#endif // !_MSC_VER + + try { + NProtoBuf::NCompiler::NPlugins::TProtoEventExtensionGenerator generator; + return google::protobuf::compiler::PluginMain(argc, argv, &generator); + } catch (yexception& e) { + Cerr << e.what() << Endl; + } catch (...) { + Cerr << "Unknown error in TProtoEventExtensionGenerator" << Endl; + } + + return 1; +} diff --git a/tools/event2cpp/proto_events.h b/tools/event2cpp/proto_events.h new file mode 100644 index 0000000000..628b4856af --- /dev/null +++ b/tools/event2cpp/proto_events.h @@ -0,0 +1,20 @@ +#pragma once + +#include <google/protobuf/compiler/plugin.h> +#include <google/protobuf/compiler/code_generator.h> +#include <google/protobuf/stubs/common.h> + +namespace NProtoBuf::NCompiler::NPlugins { + +class TProtoEventExtensionGenerator : public google::protobuf::compiler::CodeGenerator { + public: + TProtoEventExtensionGenerator() {} + ~TProtoEventExtensionGenerator() override {} + + bool Generate(const google::protobuf::FileDescriptor* file, + const TProtoStringType& parameter, + google::protobuf::compiler::OutputDirectory* output_directory, + TProtoStringType* error) const override; +}; + +} // namespace NProtoBuf::NCompiler::NPlugins diff --git a/tools/go_test_miner/main.go b/tools/go_test_miner/main.go new file mode 100644 index 0000000000..43b729572e --- /dev/null +++ b/tools/go_test_miner/main.go @@ -0,0 +1,168 @@ +package main + +import ( + "flag" + "fmt" + "go/importer" + "go/token" + "go/types" + "os" + "path/filepath" + "regexp" + "runtime" + "sort" + "strings" + "unicode" + "unicode/utf8" +) + +const ( + usageTemplate = "Usage: %s [-benchmarks] [-examples] [-tests] import-path\n" +) + +func findObjectByName(pkg *types.Package, re *regexp.Regexp, name string) types.Object { + if pkg != nil && re != nil && len(name) > 0 { + if obj := pkg.Scope().Lookup(name); obj != nil { + if re.MatchString(obj.Type().String()) { + return obj + } + } + } + return nil +} + +func isTestName(name, prefix string) bool { + ok := false + if strings.HasPrefix(name, prefix) { + if len(name) == len(prefix) { + ok = true + } else { + rune, _ := utf8.DecodeRuneInString(name[len(prefix):]) + ok = !unicode.IsLower(rune) + } + } + return ok +} + +func main() { + testsPtr := flag.Bool("tests", false, "report tests") + benchmarksPtr := flag.Bool("benchmarks", false, "report benchmarks") + examplesPtr := flag.Bool("examples", false, "report examples") + + flag.Usage = func() { + _, _ = fmt.Fprintf(flag.CommandLine.Output(), usageTemplate, filepath.Base(os.Args[0])) + flag.PrintDefaults() + } + + flag.Parse() + + // Check if the number of positional parameters matches + args := flag.Args() + argsCount := len(args) + if argsCount != 1 { + exitCode := 0 + if argsCount > 1 { + fmt.Println("Error: invalid number of parameters...") + exitCode = 1 + } + flag.Usage() + os.Exit(exitCode) + } + + importPath := args[0] + + var fset token.FileSet + imp := importer.ForCompiler(&fset, runtime.Compiler, nil) + pkg, err := imp.Import(importPath) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + if !*testsPtr && !*benchmarksPtr && !*examplesPtr { + // Nothing to do, just exit normally + os.Exit(0) + } + + // // First approach: just dump the package scope as a string + // // package "junk/snermolaev/libmath" scope 0xc0000df540 { + // // . func junk/snermolaev/libmath.Abs(a int) int + // // . func junk/snermolaev/libmath.AbsReport(s string) + // // . func junk/snermolaev/libmath.Sum(a int, b int) int + // // . func junk/snermolaev/libmath.TestAbs(t *testing.T) + // // . func junk/snermolaev/libmath.TestSum(t *testing.T) + // // . func junk/snermolaev/libmath.init() + // // } + // // and then collect all functions that match test function signature + // pkgPath := pkg.Path() + // scopeContent := strings.Split(pkg.Scope().String(), "\n") + // re := regexp.MustCompile("^\\.\\s*func\\s*" + pkgPath + "\\.(Test\\w*)\\(\\s*\\w*\\s*\\*\\s*testing\\.T\\s*\\)$") + // for _, name := range scopeContent { + // match := re.FindAllStringSubmatch(name, -1) + // if len(match) > 0 { + // fmt.Println(match[0][1]) + // } + // } + + // Second approach: look through all names defined in the pkg scope + // and collect those functions that match test function signature + // Unfortunately I failed to employ reflection mechinary for signature + // comparison for unknown reasons (this needs additional investigation + // I am going to use regexp as workaround for a while) + // testFunc := func (*testing.T) {} + // for ... + // ... + // if reflect.DeepEqual(obj.Type(), reflect.TypeOf(testFunc)) { + // // this condition doesn't work + // } + reBenchmark := regexp.MustCompile(`^func\(\w*\s*\*testing\.B\)$`) + reExample := regexp.MustCompile(`^func\(\s*\)$`) + reTest := regexp.MustCompile(`^func\(\w*\s*\*testing\.T\)$`) + reTestMain := regexp.MustCompile(`^func\(\w*\s*\*testing\.M\)$`) + + var re *regexp.Regexp + names := pkg.Scope().Names() + + var testFns []types.Object + + for _, name := range names { + if name == "TestMain" && findObjectByName(pkg, reTestMain, name) != nil { + fmt.Println("#TestMain") + continue + } + + switch { + case *benchmarksPtr && isTestName(name, "Benchmark"): + re = reBenchmark + case *examplesPtr && isTestName(name, "Example"): + re = reExample + case *testsPtr && isTestName(name, "Test"): + re = reTest + default: + continue + } + + if obj := findObjectByName(pkg, re, name); obj != nil { + testFns = append(testFns, obj) + } + } + + sort.Slice(testFns, func(i, j int) bool { + iPos := testFns[i].Pos() + jPos := testFns[j].Pos() + + if !iPos.IsValid() || !jPos.IsValid() { + return iPos < jPos + } + + iPosition := fset.PositionFor(iPos, true) + jPosition := fset.PositionFor(jPos, true) + + return iPosition.Filename < jPosition.Filename || + (iPosition.Filename == jPosition.Filename && iPosition.Line < jPosition.Line) + }) + + for _, testFn := range testFns { + fmt.Println(testFn.Name()) + } +} diff --git a/tools/triecompiler/lib/main.cpp b/tools/triecompiler/lib/main.cpp new file mode 100644 index 0000000000..320f46d48a --- /dev/null +++ b/tools/triecompiler/lib/main.cpp @@ -0,0 +1,473 @@ +#include "main.h" + +#ifndef CATBOOST_OPENSOURCE +#include <library/cpp/charset/recyr.hh> +#endif + +#include <library/cpp/containers/comptrie/comptrie.h> +#include <library/cpp/deprecated/mapped_file/mapped_file.h> +#include <library/cpp/getopt/small/last_getopt.h> + +#include <util/charset/wide.h> +#include <util/generic/buffer.h> +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/stream/buffered.h> +#include <util/stream/file.h> +#include <util/stream/input.h> +#include <util/stream/output.h> +#include <util/stream/output.h> +#include <util/string/cast.h> +#include <util/string/util.h> +#include <util/system/filemap.h> + +#include <string> + +#ifdef WIN32 +#include <crtdbg.h> +#include <windows.h> +#endif // WIN32 + +namespace { + struct TOptions { + TString Prog = {}; + TString Triefile = {}; + TString Infile = {}; + bool Minimized = false; + bool FastLayout = false; + TCompactTrieBuilderFlags Flags = CTBF_NONE; + bool Unicode = false; + bool Verify = false; + bool Wide = false; + bool NoValues = false; + bool Vector = false; + int ArraySize = 0; + TString ValueType = {}; + bool AllowEmptyKey = false; + bool UseAsIsParker = false; + }; +} // namespace + +static TOptions ParseOptions(const int argc, const char* argv[]) { + TOptions options; + auto parser = NLastGetopt::TOpts::Default(); + parser + .AddLongOption('t', "type") + .DefaultValue("ui64") + .RequiredArgument("TYPE") + .StoreResult(&options.ValueType) + .Help("type of value or array element, possible: ui16, i16, ui32, i32, ui64, i64, float, double, bool," + " TString, TUtf16String (utf-8 in input)"); + parser + .AddCharOption('0') + .NoArgument() + .SetFlag(&options.NoValues) + .Help("Do not store values to produce TCompactTrieSet compatible binary"); + parser + .AddLongOption('a', "array") + .NoArgument() + .SetFlag(&options.Vector) + .Help("Values are arrays (of some type, depending on -t flag). Input looks like" + " 'key <TAB> value1 <TAB> .. <TAB> valueN' (N may differ from line to line)"); + parser + .AddCharOption('S') + .RequiredArgument("INT") + .StoreResult(&options.ArraySize) + .Help("Values are fixed size not packed (!) arrays, may used for mapping classes/structs" + " with monotyped fields"); + parser + .AddLongOption('e', "allow-empty") + .NoArgument() + .SetFlag(&options.AllowEmptyKey) + .Help("Allow empty key"); + parser + .AddLongOption('i', "input") + .DefaultValue("-") + .RequiredArgument("FILE") + .StoreResult(&options.Infile) + .Help("Input file"); + parser + .AddLongOption('m', "minimize") + .NoArgument() + .SetFlag(&options.Minimized) + .Help("Minimize tree into a DAG"); + parser + .AddLongOption('f', "fast-layout") + .NoArgument() + .SetFlag(&options.FastLayout) + .Help("Make fast layout"); + parser + .AddLongOption('v', "verbose") + .NoArgument() + .Help("Be verbose - show progress & stats"); + parser + .AddLongOption('s', "prefix-grouped") + .NoArgument() + .Help("Assume input is prefix-grouped by key (for every prefix all keys with this prefix" + " come in one group; greatly reduces memory usage)"); + parser + .AddLongOption('q', "unique-keys") + .NoArgument() + .Help("Assume the keys are unique (will report an error otherwise)"); + parser + .AddLongOption('c', "check") + .NoArgument() + .Help("Check the compiled trie (works only with an explicit input file name)"); + parser + .AddCharOption('u') + .NoArgument() + .SetFlag(&options.Unicode) + .Help("Recode keys from UTF-8 to Yandex (deprecated)"); + parser + .AddLongOption('w', "wide") + .NoArgument() + .SetFlag(&options.Wide) + .Help("Treat input keys as UTF-8, recode to TChar (wchar16)"); + parser + .AddLongOption('P', "as-is-packer") + .NoArgument() + .SetFlag(&options.UseAsIsParker) + .Help("Use AsIsParker to pack value in trie"); + parser.AddHelpOption('h'); + parser.SetFreeArgsNum(1); + parser.SetFreeArgTitle(0, "TRIE_FILE", "Compiled trie"); + NLastGetopt::TOptsParseResult parsed{&parser, argc, argv}; + options.Triefile = parsed.GetFreeArgs().front(); + if (parsed.Has('q')) { + options.Flags |= CTBF_UNIQUE; + } + if (parsed.Has('s')) { + options.Flags |= CTBF_PREFIX_GROUPED; + } + if (parsed.Has('v')) { + options.Flags |= CTBF_VERBOSE; + } + return options; +} + +namespace { + template <class T> + struct TFromString { + T operator() (const char* start, size_t len) const { + return FromStringImpl<T>(start, len); + } + }; + + template <> + struct TFromString<TUtf16String> { + TUtf16String operator ()(const char* start, size_t len) const { + return UTF8ToWide(start, len); + } + }; + + template <class TTKey, class TTKeyChar, class TTValue, + class TKeyReader = TFromString<TTKey>, + class TValueReader = TFromString<TTValue> > + struct TRecord { + typedef TTKey TKey; + typedef TTKeyChar TKeyChar; + typedef TTValue TValue; + TKey Key; + TValue Value; + TString Tmp; + bool Load(IInputStream& in, const bool allowEmptyKey, const bool noValues) { + while (in.ReadLine(Tmp)) { + if (!Tmp) { + // there is a special case for TrieSet with empty keys allowed + if (!(noValues && allowEmptyKey)) { + continue; + } + } + + const size_t sep = Tmp.find('\t'); + if (sep != TString::npos) { + if (0 == sep && !allowEmptyKey) { + continue; + } + Key = TKeyReader()(Tmp.data(), sep); + Value = TValueReader()(Tmp.data() + sep + 1, Tmp.size() - sep - 1); + } else if (noValues) { + RemoveIfLast<TString>(Tmp, '\n'); + Key = TKeyReader()(Tmp.data(), Tmp.size()); + Value = TValue(); + } + return true; + } + return false; + } + }; + + template <class TTKey, class TTKeyChar, class T, + class TKeyReader = TFromString<TTKey>, + class TValueReader = TFromString<T> > + struct TVectorRecord { + typedef TTKey TKey; + typedef TTKeyChar TKeyChar; + typedef TVector<T> TValue; + TKey Key; + TValue Value; + TString Tmp; + + bool Load(IInputStream& in, const bool allowEmptyKey, const bool noValues) { + Y_UNUSED(noValues); + while (in.ReadLine(Tmp)) { + if (!Tmp && !allowEmptyKey) { + continue; + } + + size_t sep = Tmp.find('\t'); + if (sep == TString::npos) { + RemoveIfLast<TString>(Tmp, '\n'); + Key = TKeyReader()(Tmp.data(), Tmp.size()); + Value = TValue(); + } else { + Key = TKeyReader()(Tmp.data(), sep); + Value = TValue(); + while (sep != Tmp.size()) { + size_t sep2 = Tmp.find('\t', sep + 1); + if (sep2 == TString::npos) { + sep2 = Tmp.size(); + } + + if (sep + 1 != sep2) { + Value.push_back(TValueReader()(Tmp.data() + sep + 1, sep2 - sep - 1)); + } + sep = sep2; + } + } + return true; + } + return false; + } + }; + + template <typename TVectorType> + class TFixedArrayAsIsPacker { + public: + TFixedArrayAsIsPacker() + : ArraySize(0) + , SizeOfValue(0) + { + } + explicit TFixedArrayAsIsPacker(size_t arraySize) + : ArraySize(arraySize) + , SizeOfValue(arraySize * sizeof(typename TVectorType::value_type)) + { + } + void UnpackLeaf(const char* p, TVectorType& t) const { + const typename TVectorType::value_type* beg = reinterpret_cast<const typename TVectorType::value_type*>(p); + t.assign(beg, beg + ArraySize); + } + void PackLeaf(char* buffer, const TVectorType& data, size_t computedSize) const { + Y_ASSERT(computedSize == SizeOfValue && data.size() == ArraySize); + memcpy(buffer, data.data(), computedSize); + } + size_t MeasureLeaf(const TVectorType& data) const { + Y_UNUSED(data); + Y_ASSERT(data.size() == ArraySize); + return SizeOfValue; + } + size_t SkipLeaf(const char* ) const { + return SizeOfValue; + } + private: + size_t ArraySize; + size_t SizeOfValue; + }; + +#ifndef CATBOOST_OPENSOURCE + struct TUTF8ToYandexRecoder { + TString operator()(const char* s, size_t len) { + return Recode(CODES_UTF8, CODES_YANDEX, TString(s, len)); + } + }; +#endif + + struct TUTF8ToWideRecoder { + TUtf16String operator()(const char* s, size_t len) { + return UTF8ToWide(s, len); + } + }; +} // namespace + +template <class TRecord, class TPacker> +static int ProcessFile(IInputStream& in, const TOptions& o, const TPacker& packer) { + TFixedBufferFileOutput out(o.Triefile); + typedef typename TRecord::TKeyChar TKeyChar; + typedef typename TRecord::TValue TValue; + + THolder< TCompactTrieBuilder<TKeyChar, TValue, TPacker> > builder(new TCompactTrieBuilder<TKeyChar, TValue, TPacker>(o.Flags, packer)); + + TRecord r; + while (r.Load(in, o.AllowEmptyKey, o.NoValues)) { + builder->Add(r.Key.data(), r.Key.size(), r.Value); + } + + if (o.Flags & CTBF_VERBOSE) { + Cerr << Endl; + Cerr << "Entries: " << builder->GetEntryCount() << Endl; + Cerr << "Tree nodes: " << builder->GetNodeCount() << Endl; + } + TBufferOutput inputForFastLayout; + IOutputStream* currentOutput = &out; + if (o.FastLayout) { + currentOutput = &inputForFastLayout; + } + if (o.Minimized) { + TBufferOutput raw; + size_t datalength = builder->Save(raw); + if (o.Flags & CTBF_VERBOSE) + Cerr << "Data length (before compression): " << datalength << Endl; + builder.Destroy(); + + datalength = CompactTrieMinimize(*currentOutput, raw.Buffer().Data(), raw.Buffer().Size(), o.Flags & CTBF_VERBOSE, packer); + if (o.Flags & CTBF_VERBOSE) + Cerr << "Data length (minimized): " << datalength << Endl; + } else { + size_t datalength = builder->Save(*currentOutput); + if (o.Flags & CTBF_VERBOSE) + Cerr << "Data length: " << datalength << Endl; + } + if (o.FastLayout) { + builder.Destroy(); + size_t datalength = CompactTrieMakeFastLayout(out, inputForFastLayout.Buffer().Data(), + inputForFastLayout.Buffer().Size(), o.Flags & CTBF_VERBOSE, packer); + if (o.Flags & CTBF_VERBOSE) + Cerr << "Data length (fast layout): " << datalength << Endl; + } + + return 0; +} + +template <class TRecord, class TPacker> +static int VerifyFile(const TOptions& o, const TPacker& packer) { + TMappedFile filemap(o.Triefile); + typedef typename TRecord::TKeyChar TKeyChar; + typedef typename TRecord::TValue TValue; + TCompactTrie<TKeyChar, TValue, TPacker> trie((const char*)filemap.getData(), filemap.getSize(), packer); + + TFileInput in(o.Infile); + size_t entrycount = 0; + int retcode = 0; + TRecord r; + while (r.Load(in, o.AllowEmptyKey, o.NoValues)) { + entrycount++; + TValue trievalue; + + if (!trie.Find(r.Key.data(), r.Key.size(), &trievalue)) { + Cerr << "Trie check failed on key #" << entrycount << "\"" << r.Key << "\": no key present" << Endl; + retcode = 1; + } else if (!o.NoValues && trievalue != r.Value) { + Cerr << "Trie check failed on key #" << entrycount << "\"" << r.Key << "\": value mismatch" << Endl; + retcode = 1; + } + } + + for (typename TCompactTrie<TKeyChar, TValue, TPacker>::TConstIterator iter = trie.Begin(); iter != trie.End(); ++iter) { + entrycount--; + } + + if (entrycount) { + Cerr << "Broken iteration: entry count mismatch" << Endl; + retcode = 1; + } + + if ((o.Flags & CTBF_VERBOSE) && !retcode) { + Cerr << "Trie check successful" << Endl; + } + return retcode; +} + +template <class TRecord, class TPacker> +static int SelectInput(const TOptions& o, const TPacker& packer) { + if ("-"sv == o.Infile) { + TBufferedInput wrapper{&Cin}; + return ProcessFile<TRecord>(wrapper, o, packer); + } + + TFileInput in(o.Infile); + return ProcessFile<TRecord>(in, o, packer); +} + +template <class TRecord, class TPacker> +static int DoMain(const TOptions& o, const TPacker& packer) { + int retcode = SelectInput<TRecord>(o, packer); + if (!retcode && o.Verify && !o.Triefile.empty()) + retcode = VerifyFile<TRecord>(o, packer); + return retcode; +} + +// TRecord - nested template parameter +template<class TValue, + template<class TKey, class TKeyChar, class TValueOther, + class TKeyReader, + class TValueReader> class TRecord, class TPacker> +static int ProcessInput(const TOptions& o, const TPacker& packer) { + if (!o.Wide) { + if (!o.Unicode) { + return DoMain< TRecord< TString, char, TValue, TFromString<TString>, TFromString<TValue> > >(o, packer); + } else { + #ifndef CATBOOST_OPENSOURCE + return DoMain< TRecord< TString, char, TValue, TUTF8ToYandexRecoder, TFromString<TValue> > >(o, packer); + #else + Y_FAIL("Yandex encoding is not supported in CATBOOST_OPENSOURCE mode"); + #endif + } + } else { + return DoMain< TRecord< TUtf16String, TChar, TValue, TUTF8ToWideRecoder, TFromString<TValue> > >(o, packer); + } + } + +template <class TItemType> +static int ProcessInput(const TOptions& o) { + if (o.ArraySize > 0) { + return ProcessInput<TItemType, TVectorRecord>(o, TFixedArrayAsIsPacker<TVector<TItemType> >(o.ArraySize)); + } else if (o.Vector) { + return ProcessInput<TItemType, TVectorRecord>(o, TCompactTriePacker<TVector<TItemType> >()); + } else if (o.UseAsIsParker) { + return ProcessInput<TItemType, TRecord>(o, TAsIsPacker<TItemType>()); + } else { + return ProcessInput<TItemType, TRecord>(o, TCompactTriePacker<TItemType>()); + } +} + +static int Main(const int argc, const char* argv[]) +try { +#ifdef WIN32 + _CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF); + ::SetConsoleCP(1251); + ::SetConsoleOutputCP(1251); +#endif // WIN32 + const TOptions o = ParseOptions(argc, argv); + if (o.NoValues) { + return ProcessInput<ui64, TRecord>(o, TNullPacker<ui64>()); + } else { +#define CHECK_TYPE_AND_PROCESS(valueType) \ + if (o.ValueType == #valueType) { \ + return ProcessInput<valueType>(o); \ + } + CHECK_TYPE_AND_PROCESS(ui16) + CHECK_TYPE_AND_PROCESS(i16) + CHECK_TYPE_AND_PROCESS(ui32) + CHECK_TYPE_AND_PROCESS(i32) + CHECK_TYPE_AND_PROCESS(ui64) + CHECK_TYPE_AND_PROCESS(i64) + CHECK_TYPE_AND_PROCESS(bool) + CHECK_TYPE_AND_PROCESS(float) + CHECK_TYPE_AND_PROCESS(double) + CHECK_TYPE_AND_PROCESS(TString) + CHECK_TYPE_AND_PROCESS(TUtf16String) +#undef CHECK_TYPE_AND_PROCESS + ythrow yexception() << "unknown type for -t option: " << o.ValueType; + } +} catch (const std::exception& e) { + Cerr << "Exception: " << e.what() << Endl; + return 2; +} catch (...) { + Cerr << "Unknown exception!\n"; + return 3; +} + +int NTrieOps::MainCompile(const int argc, const char* argv[]) { + return ::Main(argc, argv); +} diff --git a/tools/triecompiler/lib/main.h b/tools/triecompiler/lib/main.h new file mode 100644 index 0000000000..34bd02be1e --- /dev/null +++ b/tools/triecompiler/lib/main.h @@ -0,0 +1,5 @@ +#pragma once + +namespace NTrieOps { + int MainCompile(const int argc, const char* argv[]); +} diff --git a/tools/triecompiler/main.cpp b/tools/triecompiler/main.cpp new file mode 100644 index 0000000000..4c22a2bd8a --- /dev/null +++ b/tools/triecompiler/main.cpp @@ -0,0 +1,5 @@ +#include <tools/triecompiler/lib/main.h> + +int main(const int argc, const char* argv[]) { + return NTrieOps::MainCompile(argc, argv); +} |