aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/on_disk/aho_corasick/reader.h
blob: e5db58685bc14369528bec1c2ffb496a9c35190f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
#pragma once

#include <util/generic/deque.h>
#include <util/generic/strbuf.h>
#include <util/generic/yexception.h>
#include <util/memory/blob.h>
#include <util/stream/buffer.h>
#include <util/stream/mem.h>
#include <util/system/unaligned_mem.h>
#include <utility>

#include <library/cpp/on_disk/chunks/chunked_helpers.h>

#include "common.h"

/// Search result container: a deque of (ui32, O) pairs.
/// The output containers in this file append one pair per reported value, storing the
/// position passed to FillAnswer in the first component and the loaded value in the second.
template <class O>
class TAhoSearchResult: public TDeque<std::pair<ui32, O>> {
};

/*
 * Mapped-declaration
 */

template <class O>
class TMappedDefaultOutputContainer {
private:
    TGeneralVector<O> List_;

public:
    TMappedDefaultOutputContainer(const char* data)
        : List_(TBlob::NoCopy(data, (size_t)-1))
    {
    }

    bool IsEmpty() const {
        return List_.GetSize() == 0;
    }

    void FillAnswer(TAhoSearchResult<O>& answer, ui32 pos) const {
        for (ui32 i = 0; i < List_.GetSize(); ++i) {
            answer.push_back(std::make_pair(pos, O()));
            List_.Get(i, answer.back().second);
        }
    }

    size_t CheckData() const {
        return List_.RealSize();
    }
};

template <class O>
class TMappedSingleOutputContainer {
    const ui32* Data;

    ui32 Size() const {
        return ReadUnaligned<ui32>(Data);
    }

public:
    TMappedSingleOutputContainer(const char* data)
        : Data((const ui32*)data)
    {
    }

    bool IsEmpty() const {
        return Size() == 0;
    }

    void FillAnswer(TAhoSearchResult<O>& answer, ui32 pos) const {
        if (!IsEmpty()) {
            answer.push_back(std::make_pair(pos, O()));
            TMemoryInput input(Data + 1, Size());
            TSaveLoadVectorNonPodElement<O>::Load(&input, answer.back().second, Size());
        }
    }

    size_t CheckData() const {
        return sizeof(ui32) + ReadUnaligned<ui32>(Data);
    }
};

// Forward declaration: TMappedAhoVertex names this class as a friend before it is defined.
template <class TStringType, class O, class C>
class TMappedAhoCorasick;

template <typename TKey, typename TValue>
class TEmptyMapData : TNonCopyable {
private:
    TBufferStream Stream;

public:
    const char* P;

    TEmptyMapData() {
        TPlainHashWriter<TKey, TValue> hash;
        hash.Save(Stream);
        P = Stream.Buffer().Begin();
    }
};

/*
 * Every vertex has its own ui32 number.
 * Data block of a vertex:
 * ui32, ui32, ui32, ui32, degree*char, container data
 * fail, suff, degree, leftmost child, lexicographically ordered list of outgoing edge labels.
 * If the degree is zero, the block contains only 3 ints.
 *
 * NOTE(review): the accessors below read three leading ui32 values (fail, suffix, degree)
 * followed by a TPlainHash goto map when the degree is non-zero; the "leftmost child" and
 * label-list wording above may be outdated - confirm against the writer.
 */
template <class TStringType, class O, class C>
class TMappedAhoVertex {
public:
    using TCharType = typename TStringType::value_type;
    friend class TMappedAhoCorasick<TStringType, O, C>;

private:
    // Data must stay declared before GotoMap: the GotoMap initializer calls Power(), which reads Data.
    const char* Data;
    using TGotoMap = TPlainHash<TCharType, ui32>;
    TGotoMap GotoMap;
    static const TEmptyMapData<TCharType, ui32> EmptyData;

    // Size of the fixed header: fail, suffix and degree.
    static const size_t GENERAL_SHIFT = 3 * sizeof(ui32);

private:
    /// The vertex block viewed as an array of ui32 header fields.
    const ui32* DataAsInt() const {
        return reinterpret_cast<const ui32*>(Data);
    }

    /// Third header field: the degree of the vertex (zero means there is no goto map).
    ui32 Power() const {
        return ReadUnaligned<ui32>(DataAsInt() + 2);
    }

protected:
    /// Output container of the vertex; it starts after the goto map, or right after
    /// the header when the degree is zero.
    const C Output() const {
        if (Power()) {
            return C(GotoMap.ByteEnd());
        }
        return C(Data + GENERAL_SHIFT);
    }

    /// First header field: id of the fail vertex.
    ui32 Fail() const {
        return ReadUnaligned<ui32>(DataAsInt());
    }

    /// Second header field: id of the suffix vertex ((ui32)-1 is treated by callers as "none").
    ui32 Suffix() const {
        return ReadUnaligned<ui32>(DataAsInt() + 1);
    }

    /// Looks up the transition by c; returns false when the degree is zero or the map lookup fails.
    bool GotoFunction(const TCharType c, ui32* result) const {
        return 0 != Power() && GotoMap.Find(c, result);
    }

    /// Two vertices are equal when they refer to the same data block.
    bool operator==(const TMappedAhoVertex& rhs) const {
        return rhs.Data == Data;
    }

    size_t CheckData(ui32 totalVertices) const; /// throws yexception in case of bad data

public:
    /// Binds the vertex to its data block; a zero-degree vertex uses the shared empty map data.
    TMappedAhoVertex(const char* data)
        : Data(data)
        , GotoMap(0 != Power() ? (Data + GENERAL_SHIFT) : EmptyData.P)
    {
    }
};

/*
 * блок данных для бора:
 *   количество вершин N, ui32
 *   суммарный размер блоков для вершин, ui32
 *   блоки данных для каждой вершины
 *   отображение id->offset для блока вершины id, N*ui32
 */
template <class TStringType, class O, class C = TMappedDefaultOutputContainer<O>>
class TMappedAhoCorasick : TNonCopyable {
public:
    typedef TAhoSearchResult<O> TSearchResult;
    typedef TMappedAhoVertex<TStringType, O, C> TAhoVertexType;
    typedef typename TStringType::value_type TCharType;
    typedef TBasicStringBuf<TCharType> TSample;

private:
    const TBlob Blob;
    const char* const AhoVertexes;
    const ui32 VertexAmount;
    const ui32* const Id2Offset;
    const TAhoVertexType Root;

private:
    bool ValidVertex(ui32 id) const {
        return id < VertexAmount;
    }

    TAhoVertexType GetVertexAt(ui32 id) const {
        if (!ValidVertex(id))
            ythrow yexception() << "TMappedAhoCorasick fatal error: invalid id " << id;
        return TAhoVertexType(AhoVertexes + Id2Offset[id]);
    }

public:
    TMappedAhoCorasick(const TBlob& blob)
        : Blob(blob)
        , AhoVertexes(GetBlock(blob, 1).AsCharPtr())
        , VertexAmount(TSingleValue<ui32>(GetBlock(blob, 2)).Get())
        , Id2Offset((const ui32*)(GetBlock(Blob, 3).AsCharPtr()))
        , Root(GetVertexAt(0))
    {
        {
            const ui32 version = TSingleValue<ui32>(GetBlock(blob, 0)).Get();
            if (version != TAhoCorasickCommon::GetVersion())
                ythrow yexception() << "Unknown version " << version << " instead of " << TAhoCorasickCommon::GetVersion();
        }
        {
            TChunkedDataReader reader(blob);
            if (reader.GetBlocksCount() != TAhoCorasickCommon::GetBlockCount())
                ythrow yexception() << "wrong block count " << reader.GetBlocksCount();
        }
    }

    bool AhoContains(const TSample& str) const;
    TSearchResult AhoSearch(const TSample& str) const;
    void AhoSearch(const TSample& str, TSearchResult* result) const;
    size_t CheckData() const; /// throws yexception in case of bad data
};

/// Automaton over TString with ui32 outputs read through TMappedSingleOutputContainer.
using TSimpleMappedAhoCorasick = TMappedAhoCorasick<TString, ui32, TMappedSingleOutputContainer<ui32>>;
/// Automaton over TString with ui32 outputs and the default output container (TMappedDefaultOutputContainer).
using TDefaultMappedAhoCorasick = TMappedAhoCorasick<TString, ui32>;

/*
 * Mapped-implementation
 */
/// Returns true as soon as some position of str reaches a vertex (directly or through
/// the chain of suffix links) whose output container is not empty.
template <class TStringType, class O, class C>
bool TMappedAhoCorasick<TStringType, O, C>::AhoContains(const TSample& str) const {
    TAhoVertexType state = Root;
    for (size_t pos = 0, total = str.size(); pos < total; ++pos) {
        const TCharType symbol = str[pos];
        ui32 nextId = 0;

        // Follow fail links until a transition by symbol exists or the root is reached.
        bool moved = state.GotoFunction(symbol, &nextId);
        while (!moved && !(state == Root)) {
            state = GetVertexAt(state.Fail());
            moved = state.GotoFunction(symbol, &nextId);
        }
        if (!moved) {
            /// nowhere to go
            continue;
        }
        state = GetVertexAt(nextId);

        // Inspect the reached vertex and every vertex on its suffix chain.
        for (TAhoVertexType probe = state;; probe = GetVertexAt(probe.Suffix())) {
            if (!probe.Output().IsEmpty()) {
                return true;
            }
            if (probe.Suffix() == (ui32)-1) {
                break;
            }
        }
    }
    return false;
}

/// Clears *answer and appends the outputs of every vertex reached while scanning str,
/// including the vertices on the suffix chains; each output is tagged with the index
/// of the character being processed (cast to ui32).
template <class TStringType, class O, class C>
void TMappedAhoCorasick<TStringType, O, C>::AhoSearch(const TSample& str, TSearchResult* answer) const {
    answer->clear();
    TAhoVertexType state = Root;
    for (size_t pos = 0, total = str.length(); pos < total; ++pos) {
        const TCharType symbol = str[pos];
        ui32 nextId = 0;

        // Follow fail links until a transition by symbol exists or the root is reached.
        bool moved = state.GotoFunction(symbol, &nextId);
        while (!moved && !(state == Root)) {
            state = GetVertexAt(state.Fail());
            moved = state.GotoFunction(symbol, &nextId);
        }
        if (!moved) {
            /// nowhere to go
            continue;
        }
        state = GetVertexAt(nextId);

        // Collect outputs of the reached vertex and of every vertex on its suffix chain.
        for (TAhoVertexType probe = state;; probe = GetVertexAt(probe.Suffix())) {
            probe.Output().FillAnswer(*answer, (ui32)pos);
            if (probe.Suffix() == (ui32)-1) {
                break;
            }
        }
    }
}

/// Convenience overload: runs the out-parameter AhoSearch and returns the collected matches by value.
template <class TStringType, class O, class C>
typename TMappedAhoCorasick<TStringType, O, C>::TSearchResult TMappedAhoCorasick<TStringType, O, C>::AhoSearch(const TSample& str) const {
    TSearchResult matches;
    AhoSearch(str, &matches);
    return matches;
}

/*
 * implementation of CheckData in Mapped-classes
 */

/// Throws yexception unless id is strictly less than strictUpperBound.
static inline void CheckRange(ui32 id, ui32 strictUpperBound) {
    if (id < strictUpperBound) {
        return;
    }
    throw yexception() << id << " of " << strictUpperBound << " - index is invalid";
}

// Out-of-class definition of the static EmptyData member. Its P pointer is what the
// TMappedAhoVertex constructor passes to GotoMap when Power() is zero.
template <class TStringType, class O, class C>
const TEmptyMapData<typename TStringType::value_type, ui32> TMappedAhoVertex<TStringType, O, C>::EmptyData;

/// Validates every vertex id referenced by this vertex against totalVertices and
/// returns the number of bytes the vertex occupies (header + goto map + output container).
/// Throws yexception (via CheckRange) if an id is out of range.
template <class TStringType, class O, class C>
size_t TMappedAhoVertex<TStringType, O, C>::CheckData(ui32 totalVertices) const {
    CheckRange(Fail(), totalVertices);

    // (ui32)-1 is the "no suffix link" marker and is therefore not range-checked.
    const ui32 suffixId = Suffix();
    if ((ui32)(-1) != suffixId) {
        CheckRange(suffixId, totalVertices);
    }

    size_t gotoBytes = 0;
    if (Power()) {
        typename TGotoMap::TConstIterator edge = GotoMap.Begin();
        while (edge != GotoMap.End()) {
            CheckRange(edge->Second(), totalVertices);
            ++edge;
        }
        gotoBytes = GotoMap.ByteSize();
    }

    const size_t outputBytes = Output().CheckData();
    return GENERAL_SHIFT + gotoBytes + outputBytes;
}

/// Validates the mapped automaton and returns the total number of bytes it needs
/// (all vertex blocks plus the id->offset table).
///
/// Checks, in order:
///  - the id->offset block is large enough to hold VertexAmount entries;
///  - every vertex starts exactly where the previous one ended;
///  - the fixed header of every vertex lies inside the vertex block;
///  - every vertex passes its own CheckData;
///  - the sizes of the vertex block and the offset block add up to the computed total.
///
/// Throws yexception with the "Bad data: " prefix on any failure.
template <class TStringType, class O, class C>
size_t TMappedAhoCorasick<TStringType, O, C>::CheckData() const {
    try {
        const size_t vertexBlockSize = GetBlock(Blob, 1).Size();
        const size_t offsetBlockSize = GetBlock(Blob, 3).Size();
        const size_t offsetBytesNeeded = VertexAmount * sizeof(ui32);

        // The loop below reads Id2Offset[0 .. VertexAmount): make sure those entries
        // exist before touching them, since this method is meant for untrusted data.
        if (offsetBlockSize < offsetBytesNeeded) {
            ythrow yexception() << "offset table is too short: " << offsetBlockSize << " of " << offsetBytesNeeded;
        }

        const size_t headerSize = TAhoVertexType::GENERAL_SHIFT;
        size_t bytesNeeded = 0;
        for (ui32 id = 0; id < VertexAmount; ++id) {
            if (Id2Offset[id] != bytesNeeded) {
                ythrow yexception() << "wrong offset[" << id << "]: " << Id2Offset[id];
            }
            // Constructing a vertex reads its fixed header, so the header has to fit
            // into the vertex block.
            if (bytesNeeded > vertexBlockSize || vertexBlockSize - bytesNeeded < headerSize) {
                ythrow yexception() << "vertex " << id << " does not fit into the vertex block";
            }
            bytesNeeded += GetVertexAt(id).CheckData(VertexAmount);
        }
        bytesNeeded += offsetBytesNeeded;
        const size_t realsize = vertexBlockSize + offsetBlockSize;
        if (realsize != bytesNeeded) {
            ythrow yexception() << "extra information " << bytesNeeded << " " << realsize;
        }
        return bytesNeeded;
    } catch (const yexception& e) {
        ythrow yexception() << "Bad data: " << e.what();
    }
}