aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/scheme/scimpl_protobuf.cpp
blob: 0c99122c69ac0413c40d2205e7cf6ed56d5de624 (plain) (tree)
1
2
3
4
5
6
7
8



                                    

                                       


                                 
                                                             
                                                  

                                                     
 
                                    
                                                                           

                                               
                                                       
                                                      








                                                                                     



                                                             
             
                 
     

                                                                                
 
                                    






























                                                                                                                             
         
 
                 
     

                                                                                                   
 
                                    






























                                                                                                                             
         
 
                 
     
                                                                 











                                                                            
                                  
                                             
                                              
                                                  
                    
                                          


             
                                                                                                                              
                                            








                                                                                        























                                                                           
                                                    
                                                  
                                                                         



                                                                     

         
                                                                                                    
                                            
                                              
     
                                                                                                        

                                                              
                                                                                                         
                              



                                                                                             









                                                             
                                                                                                            
                                            





                                             



                                                                                                           


                                                           
                                                           
                                        























                                                                               
                                                        






                                                                         


             
                                                                                                       







                                              
                                              


                                   



                                                                                                           




                                                                                         
                                                              



                                                                                    
                                                                 

                                     
 
#include "scheme.h"

#include <util/generic/vector.h>
#include <util/generic/yexception.h>

#include <google/protobuf/descriptor.h>
#include <google/protobuf/message.h>
#include <google/protobuf/reflection.h>

#include <limits>

using namespace google::protobuf;

namespace NSc {
    // Converts a protobuf message into a dict TValue keyed by field name.
    // Only fields reported by Reflection::ListFields() are converted:
    //  - singular fields go through FromField();
    //  - repeated fields become arrays of FromRepeatedField() results;
    //  - map fields, when `mapAsDict` is true, become a dict whose keys are the map
    //    keys rendered with ForceString(). Otherwise they are arrays of key/value
    //    entry dicts, like any other repeated message field.
    // A field whose conversion throws is skipped (best effort). If no field is
    // converted, the default-constructed TValue is returned.
    //
    // NOTE(review): nested messages are converted through FromField() /
    // FromRepeatedField(), which call From() without `mapAsDict`, so the flag only
    // applies to the top-level message.
    TValue TValue::From(const Message& msg, bool mapAsDict) {
        TValue v;
        const Reflection* r = msg.GetReflection();
        TVector<const FieldDescriptor*> fields;

        r->ListFields(msg, &fields);
        for (const FieldDescriptor* field : fields) {
            try {
                if (field->is_repeated()) {
                    if (field->is_map() && mapAsDict) {
                        auto& elem = v[field->name()];
                        // Read key (field 1) and value (field 2) straight from each
                        // entry message. Converting the entry with From() would lose
                        // entry fields that ListFields() omits because they hold their
                        // default value (e.g. proto3 {"a": 0} would become {"a": null}).
                        const Descriptor* entryDesc = field->message_type();
                        const FieldDescriptor* keyField = entryDesc->FindFieldByNumber(1);
                        const FieldDescriptor* valueField = entryDesc->FindFieldByNumber(2);
                        for (int i2 = 0, size = r->FieldSize(msg, field); i2 < size; ++i2) {
                            const Message& entry = r->GetRepeatedMessage(msg, field, i2);
                            // ForceString() renders non-string keys (e.g. map<int32, ...>)
                            // as text, so every entry gets a distinct dict key.
                            const TString key = FromField(entry, keyField).ForceString();
                            elem[TStringBuf(key)] = FromField(entry, valueField);
                        }
                    } else {
                        for (int i2 = 0, size = r->FieldSize(msg, field); i2 < size; ++i2) {
                            v[field->name()][i2] = FromRepeatedField(msg, field, i2);
                        }
                    }
                } else {
                    v[field->name()] = FromField(msg, field);
                }
            } catch (...) {
                /* conversion failed, skip this field */
            }
        }

        return v;
    }

    // Converts the singular field `field` of `msg` into a TValue.
    // Integer and floating-point fields become numbers, bools are stored with
    // SetBool(), enums by their value name, and strings are copied. Nested
    // messages are converted recursively with From() (called with its default
    // mapAsDict). Any other cpp_type raises TSchemeException.
    TValue TValue::FromField(const Message& msg, const FieldDescriptor* field) {
        const Reflection* const reflection = msg.GetReflection();
        const FieldDescriptor::CppType type = field->cpp_type();
        TValue result;

        switch (type) {
            case FieldDescriptor::CPPTYPE_INT32:
                result = reflection->GetInt32(msg, field);
                return result;
            case FieldDescriptor::CPPTYPE_UINT32:
                result = reflection->GetUInt32(msg, field);
                return result;
            case FieldDescriptor::CPPTYPE_INT64:
                result = reflection->GetInt64(msg, field);
                return result;
            case FieldDescriptor::CPPTYPE_UINT64:
                result = reflection->GetUInt64(msg, field);
                return result;
            case FieldDescriptor::CPPTYPE_FLOAT:
                result = reflection->GetFloat(msg, field);
                return result;
            case FieldDescriptor::CPPTYPE_DOUBLE:
                result = reflection->GetDouble(msg, field);
                return result;
            case FieldDescriptor::CPPTYPE_BOOL:
                result.SetBool(reflection->GetBool(msg, field));
                return result;
            case FieldDescriptor::CPPTYPE_STRING:
                result = reflection->GetString(msg, field);
                return result;
            case FieldDescriptor::CPPTYPE_ENUM:
                result = reflection->GetEnum(msg, field)->name();
                return result;
            case FieldDescriptor::CPPTYPE_MESSAGE:
                result = From(reflection->GetMessage(msg, field));
                return result;
            default:
                break;
        }

        ythrow TSchemeException() << "field " << field->full_name() << " unexpected type " << static_cast<int>(type);
    }

    // Converts element `index` of the repeated field `field` of `msg` into a TValue.
    // The mapping per cpp_type matches FromField(): numbers, SetBool() for bools,
    // enum value names, copied strings, and recursive From() for messages.
    // Any other cpp_type raises TSchemeException.
    TValue TValue::FromRepeatedField(const Message& msg, const FieldDescriptor* field, int index) {
        const Reflection* const reflection = msg.GetReflection();
        const FieldDescriptor::CppType type = field->cpp_type();
        TValue item;

        switch (type) {
            case FieldDescriptor::CPPTYPE_INT32:
                item = reflection->GetRepeatedInt32(msg, field, index);
                return item;
            case FieldDescriptor::CPPTYPE_UINT32:
                item = reflection->GetRepeatedUInt32(msg, field, index);
                return item;
            case FieldDescriptor::CPPTYPE_INT64:
                item = reflection->GetRepeatedInt64(msg, field, index);
                return item;
            case FieldDescriptor::CPPTYPE_UINT64:
                item = reflection->GetRepeatedUInt64(msg, field, index);
                return item;
            case FieldDescriptor::CPPTYPE_FLOAT:
                item = reflection->GetRepeatedFloat(msg, field, index);
                return item;
            case FieldDescriptor::CPPTYPE_DOUBLE:
                item = reflection->GetRepeatedDouble(msg, field, index);
                return item;
            case FieldDescriptor::CPPTYPE_BOOL:
                item.SetBool(reflection->GetRepeatedBool(msg, field, index));
                return item;
            case FieldDescriptor::CPPTYPE_STRING:
                item = reflection->GetRepeatedString(msg, field, index);
                return item;
            case FieldDescriptor::CPPTYPE_ENUM:
                item = reflection->GetRepeatedEnum(msg, field, index)->name();
                return item;
            case FieldDescriptor::CPPTYPE_MESSAGE:
                item = From(reflection->GetRepeatedMessage(msg, field, index));
                return item;
            default:
                break;
        }

        ythrow TSchemeException() << "field " << field->full_name() << " unexpected type " << static_cast<int>(type);
    }

    // Fills `msg` from this value. The message is cleared first, and a null value
    // leaves it empty. Any other non-dict value raises TSchemeException.
    // Every field declared in the message descriptor is then filled by the
    // map, repeated or singular converter, checked in that order.
    void TValue::To(Message& msg, const TProtoOpts& opts) const {
        msg.Clear();

        if (IsNull()) {
            return;
        }
        if (!IsDict()) {
            ythrow TSchemeException() << "expected dictionary";
        }

        const Descriptor* const descriptor = msg.GetDescriptor();
        const int fieldCount = descriptor->field_count();
        for (int idx = 0; idx < fieldCount; ++idx) {
            const FieldDescriptor* const fd = descriptor->field(idx);

            if (fd->is_map()) {
                ToMapField(msg, fd, opts);
                continue;
            }
            if (fd->is_repeated()) {
                ToRepeatedField(msg, fd, opts);
                continue;
            }
            ToField(msg, fd, opts);
        }
    }

    void TValue::ValueToField(const TValue& value, Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
        const TString& name = field->name();
        if (value.IsNull()) {
            if (field->is_required() && !field->has_default_value()) {
                ythrow TSchemeException() << "has no value for required field " << name;
            }
            return;
        }

        const Reflection* reflection = msg.GetReflection();

        switch (field->cpp_type()) {
            case FieldDescriptor::CPPTYPE_INT32:
                reflection->SetInt32(&msg, field, value.ForceIntNumber());
                break;
            case FieldDescriptor::CPPTYPE_INT64:
                reflection->SetInt64(&msg, field, value.ForceIntNumber());
                break;
            case FieldDescriptor::CPPTYPE_UINT32:
                reflection->SetUInt32(&msg, field, value.ForceIntNumber());
                break;
            case FieldDescriptor::CPPTYPE_UINT64:
                reflection->SetUInt64(&msg, field, value.ForceIntNumber());
                break;
            case FieldDescriptor::CPPTYPE_DOUBLE:
                reflection->SetDouble(&msg, field, value.ForceNumber());
                break;
            case FieldDescriptor::CPPTYPE_FLOAT:
                reflection->SetFloat(&msg, field, value.ForceNumber());
                break;
            case FieldDescriptor::CPPTYPE_BOOL:
                reflection->SetBool(&msg, field, value.IsTrue());
                break;
            case FieldDescriptor::CPPTYPE_STRING:
                reflection->SetString(&msg, field, value.ForceString());
                break;
            case FieldDescriptor::CPPTYPE_ENUM:
                value.ToEnumField(msg, field, opts);
                break;
            case FieldDescriptor::CPPTYPE_MESSAGE:
                value.To(*reflection->MutableMessage(&msg, field), opts);
                break;
            default:
                ythrow TSchemeException()
                    << "field " << field->full_name()
                    << " unexpected type " << (int)field->cpp_type();
        }
    }

    // Looks up the entry named after `field` in this dict (via Get()) and writes it
    // into the singular field `field` of `msg` with ValueToField().
    void TValue::ToField(Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
        ValueToField(Get(field->name()), msg, field, opts);
    }

    // Stores this value into the enum field `field` of `msg`.
    // A string value is looked up by enum value name. Any other value is looked up
    // by number (ForceIntNumber()); numbers outside the `int` range never match.
    // If nothing matches, the field's default enum value is used when
    // opts.UnknownEnumValueIsDefault is set; otherwise TSchemeException is thrown.
    // Repeated fields get the value appended (AddEnum), singular ones set (SetEnum).
    void TValue::ToEnumField(Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
        const EnumDescriptor* enumField = field->enum_type();

        const EnumValueDescriptor* enumFieldValue = nullptr;
        if (IsString()) {
            enumFieldValue = enumField->FindValueByName(ForceString());
        } else {
            // FindValueByNumber() takes an int. Reject numbers that do not fit instead
            // of letting the narrowing conversion wrap onto an unrelated enum value.
            const auto number = ForceIntNumber();
            if (number >= std::numeric_limits<int>::min() && number <= std::numeric_limits<int>::max()) {
                enumFieldValue = enumField->FindValueByNumber(static_cast<int>(number));
            }
        }

        if (!enumFieldValue) {
            if (opts.UnknownEnumValueIsDefault) {
                enumFieldValue = field->default_value_enum();
            } else {
                ythrow TSchemeException() << "invalid value of enum field " << field->name();
            }
        }

        const Reflection* reflection = msg.GetReflection();

        if (field->is_repeated()) {
            reflection->AddEnum(&msg, field, enumFieldValue);
        } else {
            reflection->SetEnum(&msg, field, enumFieldValue);
        }
    }

    // Appends the elements of this dict's entry named after `field` to the repeated
    // field `field` of `msg`.
    //  - A null entry adds nothing.
    //  - A non-array entry adds nothing when opts.SkipTypeMismatch is set;
    //    otherwise TSchemeException is thrown.
    //  - Each element is coerced like ValueToField() coerces singular values
    //    (ForceIntNumber / ForceNumber / IsTrue / ForceString / ToEnumField), and
    //    message elements are converted recursively with the same `opts`.
    void TValue::ToRepeatedField(Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
        const TString& name = field->name();

        const TValue& fieldValue = Get(name);
        if (fieldValue.IsNull()) {
            return;
        }

        if (!fieldValue.IsArray()) {
            if (opts.SkipTypeMismatch) {
                return; // leave repeated field empty
            } else {
                ythrow TSchemeException() << "invalid type of repeated field " << name << ": not an array";
            }
        }

        const Reflection* reflection = msg.GetReflection();

        for (const TValue& value : fieldValue.GetArray()) {
            switch (field->cpp_type()) {
                case FieldDescriptor::CPPTYPE_INT32:
                    reflection->AddInt32(&msg, field, value.ForceIntNumber());
                    break;
                case FieldDescriptor::CPPTYPE_INT64:
                    reflection->AddInt64(&msg, field, value.ForceIntNumber());
                    break;
                case FieldDescriptor::CPPTYPE_UINT32:
                    reflection->AddUInt32(&msg, field, value.ForceIntNumber());
                    break;
                case FieldDescriptor::CPPTYPE_UINT64:
                    reflection->AddUInt64(&msg, field, value.ForceIntNumber());
                    break;
                case FieldDescriptor::CPPTYPE_DOUBLE:
                    reflection->AddDouble(&msg, field, value.ForceNumber());
                    break;
                case FieldDescriptor::CPPTYPE_FLOAT:
                    reflection->AddFloat(&msg, field, value.ForceNumber());
                    break;
                case FieldDescriptor::CPPTYPE_BOOL:
                    reflection->AddBool(&msg, field, value.IsTrue());
                    break;
                case FieldDescriptor::CPPTYPE_STRING:
                    reflection->AddString(&msg, field, value.ForceString());
                    break;
                case FieldDescriptor::CPPTYPE_ENUM:
                    value.ToEnumField(msg, field, opts);
                    break;
                case FieldDescriptor::CPPTYPE_MESSAGE:
                    // Pass `opts` down so nested messages honour the same
                    // UnknownEnumValueIsDefault / SkipTypeMismatch settings as
                    // singular message fields converted by ValueToField().
                    value.To(*reflection->AddMessage(&msg, field), opts);
                    break;
                default:
                    ythrow TSchemeException()
                        << "field " << field->full_name()
                        << " unexpected type " << (int)field->cpp_type();
            }
        }
    }

    // Fills the map field `field` of `msg` from this dict's entry of the same name.
    //  - A null entry adds nothing.
    //  - An array entry is treated as a list of key/value entry dicts and handled
    //    by ToRepeatedField().
    //  - A dict entry produces one map entry per item. The dict key is written to
    //    the map key field (entry field 1) and the item value to the map value
    //    field (entry field 2). Both go through ValueToField(), so the key is
    //    coerced to the key field's type.
    //  - Any other entry adds nothing when opts.SkipTypeMismatch is set; otherwise
    //    TSchemeException is thrown.
    void TValue::ToMapField(Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
        const TString& name = field->name();

        const TValue& fieldValue = Get(name);
        if (fieldValue.IsNull()) {
            return;
        }

        if (fieldValue.IsArray()) {
            // read dict from key, value array
            ToRepeatedField(msg, field, opts);
            return;
        }

        if (!fieldValue.IsDict()) {
            if (opts.SkipTypeMismatch) {
                return; // leave map field empty
            } else {
                ythrow TSchemeException() << "invalid type of map field " << name << ": not dict or array";
            }
        }

        // Map fields are repeated fields of a synthesized entry message whose
        // field 1 is the key and field 2 is the value; the layout is the same for
        // every entry, so look the descriptors up once.
        const Descriptor* entryDesc = field->message_type();
        const FieldDescriptor* keyField = entryDesc->FindFieldByNumber(1);
        const FieldDescriptor* valueField = entryDesc->FindFieldByNumber(2);

        const Reflection* reflection = msg.GetReflection();

        auto mutableField = reflection->GetMutableRepeatedFieldRef<Message>(&msg, field);
        for (const auto& item : fieldValue.GetDict()) {
            THolder<Message> entry(mutableField.NewMessage());

            // Dict keys are text. Writing them through ValueToField() coerces them to
            // the key field's cpp_type (e.g. ForceIntNumber() for map<int32, ...>)
            // instead of calling SetString() on a non-string key field.
            TValue key;
            key.SetString(item.first);
            ValueToField(key, *entry, keyField, opts);

            ValueToField(item.second, *entry, valueField, opts);
            mutableField.Add(*entry);
        }
    }
}