aboutsummaryrefslogtreecommitdiffstats
path: root/yt/cpp/mapreduce/io/proto_helpers.cpp
blob: 392ed593e8d23602cba53a66f31112d09deed6e4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include "proto_helpers.h"

#include <yt/yt/core/misc/protobuf_helpers.h>

#include <yt/cpp/mapreduce/interface/io.h>
#include <yt/cpp/mapreduce/interface/fluent.h>

#include <yt/yt_proto/yt/formats/extension.pb.h>

#include <google/protobuf/descriptor.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/io/coded_stream.h>

#include <util/stream/str.h>
#include <util/stream/file.h>
#include <util/folder/path.h>

#include <limits>

namespace NYT {

using ::google::protobuf::Message;
using ::google::protobuf::Descriptor;
using ::google::protobuf::DescriptorPool;

using ::google::protobuf::io::CodedInputStream;

////////////////////////////////////////////////////////////////////////////////

namespace {

// Reads `fileName` line by line, treating each line as a fully qualified
// protobuf message type name, and resolves it against the generated pool.
// The result has one entry per line, in file order; a name that the generated
// pool does not know resolves to nullptr and is stored as such.
// Throws TIOException if `fileName` does not exist.
TVector<const Descriptor*> GetJobDescriptors(const TString& fileName)
{
    if (!TFsPath(fileName).Exists()) {
        ythrow TIOException() <<
            "Cannot load '" << fileName << "' file";
    }

    // The generated pool is the same object on every iteration, so look it up once.
    const DescriptorPool* generatedPool = DescriptorPool::generated_pool();

    TVector<const Descriptor*> result;
    TIFStream nameStream(fileName);
    for (TString typeName; nameStream.ReadLine(typeName); ) {
        result.push_back(generatedPool->FindMessageTypeByName(typeName));
    }
    return result;
}

} // namespace

////////////////////////////////////////////////////////////////////////////////

// Returns the message descriptors listed in the "proto_input" file.
// The path is relative, so it is resolved against the current working directory.
TVector<const Descriptor*> GetJobInputDescriptors()
{
    static constexpr const char* InputDescriptorsFile = "proto_input";
    return GetJobDescriptors(InputDescriptorsFile);
}

// Returns the message descriptors listed in the "proto_output" file.
// The path is relative, so it is resolved against the current working directory.
TVector<const Descriptor*> GetJobOutputDescriptors()
{
    static constexpr const char* OutputDescriptorsFile = "proto_output";
    return GetJobDescriptors(OutputDescriptorsFile);
}

void ValidateProtoDescriptor(
    const Message& row,
    size_t tableIndex,
    const TVector<const Descriptor*>& descriptors,
    bool isRead)
{
    const char* direction = isRead ? "input" : "output";

    if (tableIndex >= descriptors.size()) {
        ythrow TIOException() <<
            "Table index " << tableIndex <<
            " is out of range [0, " << descriptors.size() <<
            ") in " << direction;
    }

    if (row.GetDescriptor() != descriptors[tableIndex]) {
        ythrow TIOException() <<
            "Invalid row of type " << row.GetDescriptor()->full_name() <<
            " at index " << tableIndex <<
            ", row of type " << descriptors[tableIndex]->full_name() <<
            " expected in " << direction;
    }
}

// Parses a protobuf message that occupies exactly `length` bytes of `stream`
// into `row`. Reads at most `length` bytes from `stream`.
// Throws (via Y_ENSURE) if parsing fails or if the parser stops before
// consuming all `length` bytes.
void ParseFromArcadiaStream(IInputStream* stream, Message& row, ui32 length)
{
    TLengthLimitedInput input(stream, length);
    TProtobufInputStreamAdaptor adaptor(&input);
    CodedInputStream codedStream(&adaptor);

    // SetTotalBytesLimit takes an int. Computing `length + 1` directly in ui32
    // wraps to 0 for the maximum ui32 value, and values above INT_MAX turn
    // negative when narrowed to int, so clamp instead.
    constexpr int MaxBytesLimit = std::numeric_limits<int>::max();
    const int bytesLimit = length < static_cast<ui32>(MaxBytesLimit)
        ? static_cast<int>(length + 1)
        : MaxBytesLimit;
    codedStream.SetTotalBytesLimit(bytesLimit);

    bool parsedOk = row.ParseFromCodedStream(&codedStream);
    Y_ENSURE(parsedOk, "Failed to parse protobuf message");

    Y_ENSURE(
        input.Left() == 0,
        "Protobuf message did not consume the whole frame: " << input.Left() <<
        " of " << length << " bytes left unread");
}

////////////////////////////////////////////////////////////////////////////////

} // namespace NYT