aboutsummaryrefslogtreecommitdiffstats
path: root/yql/essentials/minikql/protobuf_udf/module.h
blob: 348a13678e6bd32e85c649db797b7165532a1d4a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#pragma once

#include "type_builder.h"
#include "value_builder.h"

#include <yql/essentials/public/udf/udf_value.h>
#include <yql/essentials/public/udf/udf_registrator.h>

#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>

#include <limits>

namespace NYql {
namespace NUdf {

/// Base class for UDF modules built around a single protobuf message type.
///
/// The IUdfModule entry points are only declared here; their definitions live
/// out of line (not visible in this header). The message-type-specific pieces
/// are delegated to the pure virtual hooks in the protected section, which
/// TProtobufModule<T> implements for a concrete message type T.
class TProtobufBase : public IUdfModule {
public:
    void CleanupOnTerminate() const override;

    void GetAllFunctions(IFunctionsSink& sink) const override;

    void BuildFunctionTypeInfo(
            const TStringRef& name,
            TType* userType,
            const TStringRef& typeConfig,
            ui32 flags,
            IFunctionTypeInfoBuilder& builder) const override;

protected:
    /// Returns the protobuf descriptor of the message type this module handles.
    virtual const NProtoBuf::Descriptor* GetDescriptor() const = 0;

    /// Creates an object that parses input into the module's message type.
    /// @param info    protobuf metadata forwarded to the TProtobufValue base.
    /// @param asText  in TProtobufModule, selects text format (true) vs
    ///                binary wire format (false).
    /// @return a heap-allocated object; NOTE(review): presumably the caller
    ///         takes ownership of the raw pointer — confirm in the .cpp.
    virtual TProtobufValue* CreateValue(const TProtoInfo& info, bool asText) const = 0;

    /// Creates an object that serializes messages of the module's type.
    /// @param info    protobuf metadata forwarded to the TProtobufSerialize base.
    /// @param asText  in TProtobufModule, selects text format (true) vs
    ///                binary wire format (false).
    /// @return a heap-allocated object; NOTE(review): presumably the caller
    ///         takes ownership of the raw pointer — confirm in the .cpp.
    virtual TProtobufSerialize* CreateSerialize(const TProtoInfo& info, bool asText) const = 0;
};


/// UDF module bound to the concrete protobuf message type T.
///
/// Implements the type-specific hooks of TProtobufBase:
///   - GetDescriptor() returns T::descriptor();
///   - CreateValue()/CreateSerialize() build objects that parse/serialize T,
///     either in protobuf text format (asText == true) or in binary wire
///     format (asText == false).
template <typename T>
class TProtobufModule : public TProtobufBase {

    /// Parses raw bytes into a freshly allocated T.
    class TValue final : public TProtobufValue {
    public:
        TValue(const TProtoInfo& info, bool asText)
            : TProtobufValue(info)
            , AsText_(asText)
        {
        }

        /// Parses `data` into a new T.
        /// @throws yexception if the input exceeds protobuf's int-sized length
        ///         limit or cannot be parsed in the selected format.
        TAutoPtr<NProtoBuf::Message> Parse(const TStringBuf& data) const override {
            // Both ArrayInputStream and ParseFromArray take an `int` length.
            // Passing a larger size_t would truncate it or make it negative,
            // so reject oversized input explicitly instead.
            if (data.size() > static_cast<size_t>(std::numeric_limits<int>::max())) {
                ythrow yexception() << "protobuf input is too large";
            }
            const int size = static_cast<int>(data.size());

            TAutoPtr<T> proto(new T);
            if (AsText_) {
                NProtoBuf::io::ArrayInputStream si(data.data(), size);
                if (!NProtoBuf::TextFormat::Parse(&si, proto.Get())) {
                    ythrow yexception() << "can't parse text protobuf";
                }
            } else {
                if (!proto->ParseFromArray(data.data(), size)) {
                    ythrow yexception() << "can't parse binary protobuf";
                }
            }
            return proto.Release();
        }

    private:
        const bool AsText_;
    };

    /// Serializes messages of type T and creates empty instances of T.
    class TSerialize final : public TProtobufSerialize {
    public:
        TSerialize(const TProtoInfo& info, bool asText)
            : TProtobufSerialize(info)
            , AsText_(asText)
        {
        }

        /// Serializes `proto` in the selected format.
        /// @throws yexception if serialization fails or, in binary mode, if
        ///         the encoded size exceeds protobuf's int-sized limit.
        TMaybe<TString> Serialize(const NProtoBuf::Message& proto) const override {
            TString result;
            if (AsText_) {
                if (!NProtoBuf::TextFormat::PrintToString(proto, &result)) {
                    ythrow yexception() << "can't serialize prototext message";
                }
            } else {
                // ByteSize() is deprecated and returns int, which overflows on
                // large messages; ByteSizeLong() reports the size as size_t.
                const size_t byteSize = proto.ByteSizeLong();
                // SerializeToArray() takes an `int` buffer size.
                if (byteSize > static_cast<size_t>(std::numeric_limits<int>::max())) {
                    ythrow yexception() << "can't serialize protobin message: message is too large";
                }
                result.ReserveAndResize(byteSize);
                if (!proto.SerializeToArray(result.begin(), static_cast<int>(byteSize))) {
                    ythrow yexception() << "can't serialize protobin message";
                }
            }
            return result;
        }

        /// Returns a new, default-constructed T.
        TAutoPtr<NProtoBuf::Message> MakeProto() const override {
            return TAutoPtr<NProtoBuf::Message>(new T);
        }
    private:
        const bool AsText_;
    };

private:
    /// Descriptor of the bound message type T.
    const NProtoBuf::Descriptor* GetDescriptor() const override {
        return T::descriptor();
    }

    /// Returns a new heap-allocated TValue parser for T.
    TProtobufValue* CreateValue(const TProtoInfo& info, bool asText) const override {
        return new TValue(info, asText);
    }

    /// Returns a new heap-allocated TSerialize serializer for T.
    TProtobufSerialize* CreateSerialize(const TProtoInfo& info, bool asText) const override {
        return new TSerialize(info, asText);
    }
};

} // namespace NUdf
} // namespace NYql