summary refs log tree commit diff stats
path: root/yql/essentials/minikql/protobuf_udf/proto_builder.cpp
diff options
context:
space:
mode:
author vvvv <[email protected]> 2024-11-07 04:19:26 +0300
committer vvvv <[email protected]> 2024-11-07 04:29:50 +0300
commit 2661be00f3bc47590fda9218bf0386d6355c8c88 (patch)
tree 3d316c07519191283d31c5f537efc6aabb42a2f0 /yql/essentials/minikql/protobuf_udf/proto_builder.cpp
parent cf2a23963ac10add28c50cc114fbf48953eca5aa (diff)
Moved yql/minikql YQL-19206
init [nodiff:caesar] commit_hash:d1182ef7d430ccf7e4d37ed933c7126d7bd5d6e4
Diffstat (limited to 'yql/essentials/minikql/protobuf_udf/proto_builder.cpp')
-rw-r--r-- yql/essentials/minikql/protobuf_udf/proto_builder.cpp 275
1 files changed, 275 insertions, 0 deletions
diff --git a/yql/essentials/minikql/protobuf_udf/proto_builder.cpp b/yql/essentials/minikql/protobuf_udf/proto_builder.cpp
new file mode 100644
index 00000000000..ed1b154f841
--- /dev/null
+++ b/yql/essentials/minikql/protobuf_udf/proto_builder.cpp
@@ -0,0 +1,275 @@
+#include "proto_builder.h"
+
+#include <yql/essentials/public/udf/udf_value_builder.h>
+
+#include <util/generic/singleton.h>
+
+using namespace google::protobuf;
+
+namespace {
+ using namespace NYql::NUdf;
+
+ const EnumValueDescriptor& GetEnumValue(const TUnboxedValuePod& source, const FieldDescriptor& field,
+ const TProtoInfo& info, TFlags<EFieldFlag> flags) {
+ const auto* enumDescriptor = field.enum_type();
+ Y_ENSURE(enumDescriptor);
+ if (flags.HasFlags(EFieldFlag::EnumInt)) {
+ const auto number = source.Get<i64>();
+ const auto* result = enumDescriptor->FindValueByNumber(number);
+ if (!result) {
+ ythrow yexception() << "unknown value " << number
+ << " for enum type " << enumDescriptor->full_name()
+ << ", field " << field.full_name();
+ }
+ return *result;
+ } else if (flags.HasFlags(EFieldFlag::EnumString)) {
+ const TStringBuf name = source.AsStringRef();
+ for (int i = 0; i < enumDescriptor->value_count(); ++i) {
+ const auto& value = *enumDescriptor->value(i);
+ if (value.name() == name) {
+ return value;
+ }
+ }
+ ythrow yexception() << "unknown value " << name
+ << " for enum type " << enumDescriptor->full_name()
+ << ", field " << field.full_name();
+ }
+ if (info.EnumFormat == EEnumFormat::Number) {
+ const auto number = source.Get<i32>();
+ const auto* result = enumDescriptor->FindValueByNumber(number);
+ if (!result) {
+ ythrow yexception() << "unknown value " << number
+ << " for enum type " << enumDescriptor->full_name()
+ << ", field " << field.full_name();
+ }
+ return *result;
+ }
+ const TStringBuf name = source.AsStringRef();
+ for (int i = 0; i < enumDescriptor->value_count(); ++i) {
+ const auto& value = *enumDescriptor->value(i);
+ const auto& valueName = info.EnumFormat == EEnumFormat::Name ? value.name() : value.full_name();
+ if (valueName == name) {
+ return value;
+ }
+ }
+ ythrow yexception() << "unknown value " << name
+ << " for enum type " << enumDescriptor->full_name()
+ << ", field " << field.full_name();
+ }
+
+    // Replaces the contents of the repeated `field` of `target` with the items of the
+    // list held in `source`. Each item is converted according to the protobuf type of
+    // the field; nested messages are either parsed from serialized bytes (flag Binary)
+    // or filled recursively via FillProtoFromValue.
+    // Throws yexception for protobuf types that have no conversion here.
+    void FillRepeatedField(const TUnboxedValuePod& source, Message& target,
+                           const FieldDescriptor& field, const TProtoInfo& info, TFlags<EFieldFlag> flags) {
+        const auto& refl = *target.GetReflection();
+        const auto listIter = source.GetListIterator();
+        // Drop whatever the field held before appending the new items.
+        refl.ClearField(&target, &field);
+
+        TUnboxedValue element;
+        while (listIter.Next(element)) {
+            switch (field.type()) {
+                case FieldDescriptor::TYPE_BOOL:
+                    refl.AddBool(&target, &field, element.Get<bool>());
+                    break;
+
+                case FieldDescriptor::TYPE_INT32:
+                case FieldDescriptor::TYPE_SINT32:
+                case FieldDescriptor::TYPE_SFIXED32:
+                    refl.AddInt32(&target, &field, element.Get<i32>());
+                    break;
+
+                case FieldDescriptor::TYPE_UINT32:
+                case FieldDescriptor::TYPE_FIXED32:
+                    refl.AddUInt32(&target, &field, element.Get<ui32>());
+                    break;
+
+                case FieldDescriptor::TYPE_INT64:
+                case FieldDescriptor::TYPE_SINT64:
+                case FieldDescriptor::TYPE_SFIXED64:
+                    refl.AddInt64(&target, &field, element.Get<i64>());
+                    break;
+
+                case FieldDescriptor::TYPE_UINT64:
+                case FieldDescriptor::TYPE_FIXED64:
+                    refl.AddUInt64(&target, &field, element.Get<ui64>());
+                    break;
+
+                case FieldDescriptor::TYPE_FLOAT: {
+                    // With info.YtMode set the item is read as a double and narrowed to float.
+                    const float number = info.YtMode
+                        ? static_cast<float>(element.Get<double>())
+                        : element.Get<float>();
+                    refl.AddFloat(&target, &field, number);
+                    break;
+                }
+
+                case FieldDescriptor::TYPE_DOUBLE:
+                    refl.AddDouble(&target, &field, element.Get<double>());
+                    break;
+
+                case FieldDescriptor::TYPE_ENUM:
+                    refl.AddEnum(&target, &field, &GetEnumValue(element, field, info, flags));
+                    break;
+
+                // Strings and bytes are both appended through AddString.
+                case FieldDescriptor::TYPE_STRING:
+                case FieldDescriptor::TYPE_BYTES:
+                    refl.AddString(&target, &field, TString(element.AsStringRef()));
+                    break;
+
+                case FieldDescriptor::TYPE_MESSAGE: {
+                    auto* nestedMessage = refl.AddMessage(&target, &field);
+                    if (!flags.HasFlags(EFieldFlag::Binary)) {
+                        FillProtoFromValue(element, *nestedMessage, info);
+                        break;
+                    }
+                    // Binary flag: the item carries an already serialized message.
+                    const TStringBuf bytes(element.AsStringRef());
+                    Y_ENSURE(nestedMessage->ParseFromArray(bytes.data(), bytes.size()));
+                    break;
+                }
+
+                default:
+                    ythrow yexception() << "Unsupported protobuf type: "
+                                        << field.type_name() << ", field: " << field.name();
+            }
+        }
+    }
+
+    // Stores the unboxed value `source` into the non-repeated `field` of `target`,
+    // converting it according to the protobuf type of the field. Nested messages are
+    // either parsed from serialized bytes (flag Binary) or filled recursively via
+    // FillProtoFromValue.
+    // Throws yexception for protobuf types that have no conversion here.
+    void FillSingleField(const TUnboxedValuePod& source, Message& target,
+                         const FieldDescriptor& field, const TProtoInfo& info, TFlags<EFieldFlag> flags) {
+        const auto& refl = *target.GetReflection();
+        switch (field.type()) {
+            case FieldDescriptor::TYPE_BOOL:
+                refl.SetBool(&target, &field, source.Get<bool>());
+                return;
+
+            case FieldDescriptor::TYPE_INT32:
+            case FieldDescriptor::TYPE_SINT32:
+            case FieldDescriptor::TYPE_SFIXED32:
+                refl.SetInt32(&target, &field, source.Get<i32>());
+                return;
+
+            case FieldDescriptor::TYPE_UINT32:
+            case FieldDescriptor::TYPE_FIXED32:
+                refl.SetUInt32(&target, &field, source.Get<ui32>());
+                return;
+
+            case FieldDescriptor::TYPE_INT64:
+            case FieldDescriptor::TYPE_SINT64:
+            case FieldDescriptor::TYPE_SFIXED64:
+                refl.SetInt64(&target, &field, source.Get<i64>());
+                return;
+
+            case FieldDescriptor::TYPE_UINT64:
+            case FieldDescriptor::TYPE_FIXED64:
+                refl.SetUInt64(&target, &field, source.Get<ui64>());
+                return;
+
+            case FieldDescriptor::TYPE_FLOAT: {
+                // With info.YtMode set the value is read as a double and narrowed to float.
+                const float number = info.YtMode
+                    ? static_cast<float>(source.Get<double>())
+                    : source.Get<float>();
+                refl.SetFloat(&target, &field, number);
+                return;
+            }
+
+            case FieldDescriptor::TYPE_DOUBLE:
+                refl.SetDouble(&target, &field, source.Get<double>());
+                return;
+
+            case FieldDescriptor::TYPE_ENUM:
+                refl.SetEnum(&target, &field, &GetEnumValue(source, field, info, flags));
+                return;
+
+            // Strings and bytes are both stored through SetString.
+            case FieldDescriptor::TYPE_STRING:
+            case FieldDescriptor::TYPE_BYTES:
+                refl.SetString(&target, &field, TString(source.AsStringRef()));
+                return;
+
+            case FieldDescriptor::TYPE_MESSAGE: {
+                auto* nestedMessage = refl.MutableMessage(&target, &field);
+                if (!flags.HasFlags(EFieldFlag::Binary)) {
+                    FillProtoFromValue(source, *nestedMessage, info);
+                    return;
+                }
+                // Binary flag: the value carries an already serialized message.
+                const TStringBuf bytes(source.AsStringRef());
+                Y_ENSURE(nestedMessage->ParseFromArray(bytes.data(), bytes.size()));
+                return;
+            }
+
+            default:
+                ythrow yexception() << "Unsupported protobuf type: "
+                                    << field.type_name() << ", field: " << field.name();
+        }
+    }
+
+    // Rebuilds the map `field` of `target` from the dictionary held in `source`.
+    // The field is always cleared first; an empty `source` leaves it cleared.
+    // Every dictionary pair becomes one map-entry message: the key is written with the
+    // Binary flag removed, the value is written with the flags unchanged.
+    void FillMapField(const TUnboxedValuePod& source, Message& target, const FieldDescriptor& field, const TProtoInfo& info, TFlags<EFieldFlag> flags) {
+        const auto& refl = *target.GetReflection();
+        refl.ClearField(&target, &field);
+        if (!source) {
+            return;
+        }
+
+        TFlags<EFieldFlag> keyFlags = flags;
+        keyFlags.RemoveFlags(EFieldFlag::Binary);
+
+        const auto dictIter = source.GetDictIterator();
+        TUnboxedValue key;
+        TUnboxedValue value;
+        while (dictIter.NextPair(key, value)) {
+            auto* entry = refl.AddMessage(&target, &field);
+            const auto& entryDescriptor = *entry->GetDescriptor();
+            FillSingleField(key, *entry, *entryDescriptor.map_key(), info, keyFlags);
+            FillSingleField(value, *entry, *entryDescriptor.map_value(), info, flags);
+        }
+    }
+}
+
+namespace NYql::NUdf {
+
+// Fills the protobuf message `target` from the unboxed value `source`.
+//
+// The layout is taken from `info`: the message is looked up by its full name in
+// info.Messages, and every protobuf field is looked up by its number to obtain the
+// element position inside `source` and the field flags.
+//
+// Per field:
+//   - a oneof member flagged Variant is filled only when the variant stored in `source`
+//     selects this field, otherwise it is cleared;
+//   - a field flagged Void is cleared;
+//   - an empty value clears the field, or throws if the field is required;
+//   - otherwise the value is written as a map, a repeated or a single field.
+//
+// Throws yexception when the message or a field is not described by `info`,
+// or when a required field has no value.
+void FillProtoFromValue(const TUnboxedValuePod& source, Message& target, const TProtoInfo& info) {
+    const auto& descriptor = *target.GetDescriptor();
+
+    const auto messageIt = info.Messages.find(descriptor.full_name());
+    if (messageIt == info.Messages.end()) {
+        ythrow yexception() << "unknown message " << descriptor.full_name();
+    }
+    TMessageInfo* const messageInfo = messageIt->second.get();
+
+    const auto& reflection = *target.GetReflection();
+    const int fieldCount = descriptor.field_count();
+    for (int index = 0; index < fieldCount; ++index) {
+        const auto& field = *descriptor.field(index);
+
+        const auto it = messageInfo->Fields.find(field.number());
+        Y_ENSURE(it != messageInfo->Fields.end());
+        const auto& fieldInfo = it->second;
+        const TFlags<EFieldFlag> flags = fieldInfo.Flags;
+        auto fieldValue = source.GetElement(fieldInfo.Pos);
+
+        if (field.containing_oneof() && flags.HasFlags(EFieldFlag::Variant)) {
+            const ui32* varIndex = messageInfo->VariantIndicies.FindPtr(field.number());
+            Y_ENSURE(varIndex);
+            const bool selected = fieldValue && fieldValue.GetVariantIndex() == *varIndex;
+            if (!selected) {
+                // The variant is empty or holds another alternative of the oneof.
+                reflection.ClearField(&target, &field);
+                continue;
+            }
+            fieldValue = fieldValue.GetVariantItem();
+        }
+
+        const bool isVoid = flags.HasFlags(EFieldFlag::Void);
+        // A Void field is cleared without the required-field check.
+        if (!isVoid && !fieldValue && field.is_required()) {
+            ythrow yexception() << "required field " << field.name() << " has no value";
+        }
+        if (isVoid || !fieldValue) {
+            reflection.ClearField(&target, &field);
+            continue;
+        }
+
+        if (field.is_map() && flags.HasFlags(EFieldFlag::Dict)) {
+            FillMapField(fieldValue, target, field, info, flags);
+        } else if (field.is_repeated()) {
+            FillRepeatedField(fieldValue, target, field, info, flags);
+        } else {
+            FillSingleField(fieldValue, target, field, info, flags);
+        }
+    }
+}
+
+} // namespace NYql::NUdf
+