aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/codecs/greedy_dict/gd_entry.h
blob: 42482552a34bd04b4cb8e470c40a0b4306a59757 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#pragma once

#include "gd_stats.h"

#include <library/cpp/containers/comptrie/comptrie.h> 

#include <util/generic/ptr.h>
#include <util/generic/strbuf.h>
#include <util/generic/vector.h>

#include <util/memory/pool.h>

namespace NGreedyDict {
    // Convenience alias: a vector of non-owning string views (TStringBuf).
    // The views do not own their characters; whoever fills the vector must keep
    // the underlying storage alive.
    using TStringBufs = TVector<TStringBuf>;

    struct TEntry {
        static const i32 NoPrefix = -1;

        TStringBuf Str;

        i32 NearestPrefix = NoPrefix;
        ui32 Count = 0;
        ui32 Number = 0;
        float ModelP = 0;
        float Score = 0;

        TEntry(TStringBuf b = TStringBuf(), ui32 cnt = 0)
            : Str(b)
            , Count(cnt)
        {
        }

        bool HasPrefix() const {
            return NearestPrefix != NoPrefix;
        }
        ui32 Len() const {
            return Str.size();
        }

        static bool StrLess(const TEntry& a, const TEntry& b) {
            return a.Str < b.Str;
        }
        static bool NumberLess(const TEntry& a, const TEntry& b) {
            return a.Number < b.Number;
        }
        static bool ScoreMore(const TEntry& a, const TEntry& b) {
            return a.Score > b.Score;
        }
    };

    /// A vector of TEntry whose strings live in an internal memory pool.
    ///
    /// Add() copies the caller's characters into Pool, so each added entry's
    /// Str stays valid while this set is alive, even if the vector reallocates.
    /// The set is non-copyable (TNonCopyable); the entries' views point into
    /// this set's own pool.
    class TEntrySet: public TVector<TEntry>, TNonCopyable {
        // Backing storage for the characters of every entry added via Add().
        TMemoryPool Pool{8112};
        // NOTE(review): not used by the inline members below; presumably built
        // and queried by BuildHierarchy()/FindPrefix() (defined elsewhere) — confirm.
        TCompactTrie<char, ui32, TAsIsPacker<ui32>> Trie;

    public:
        // NOTE(review): not touched by the inline members here; presumably the
        // total of the entries' counts, maintained by out-of-line code — confirm.
        ui32 TotalCount = 0;

        // Defined out of line; behavior not visible from this header.
        void InitWithAlpha();

        /// Appends an entry whose Str is a pool-owned copy of `a`.
        void Add(TStringBuf a) {
            push_back(TStringBuf(Pool.Append(a.data(), a.size()), a.size()));
        }

        /// Appends an entry whose Str is a pool-owned copy of `a` immediately
        /// followed by `b` (one contiguous allocation).
        void Add(TStringBuf a, TStringBuf b) {
            const size_t sz = a.size() + b.size();
            char* const p = static_cast<char*>(Pool.Allocate(sz));
            // memcpy is undefined for a null source pointer even when copying
            // zero bytes, and an empty TStringBuf may carry data() == nullptr.
            if (!a.empty()) {
                memcpy(p, a.data(), a.size());
            }
            if (!b.empty()) {
                memcpy(p + a.size(), b.data(), b.size());
            }
            push_back(TStringBuf(p, sz));
        }

        /// Unchecked access to the entry at `idx`.
        TEntry& Get(ui32 idx) {
            return (*this)[idx];
        }

        /// Unchecked const access to the entry at `idx`.
        const TEntry& Get(ui32 idx) const {
            return (*this)[idx];
        }

        // Defined out of line; presumably fills each entry's NearestPrefix — confirm.
        void BuildHierarchy();

        /// Finds the entry that is the longest prefix of `str` (defined out of line).
        /// NOTE(review): `str` is taken by non-const reference, so the
        /// definition may modify it (e.g. consume the matched prefix) — confirm.
        TEntry* FindPrefix(TStringBuf& str);

        /// Const overload; delegates to the non-const FindPrefix.
        /// NOTE(review): safe only as long as the non-const overload does not
        /// modify the set itself.
        const TEntry* FindPrefix(TStringBuf& str) const {
            return const_cast<TEntrySet*>(this)->FindPrefix(str);
        }

        /// If `e` has a nearest prefix entry `p`, sets `suff` to e.Str with its
        /// first p.Str.size() characters skipped and returns &p. Otherwise it
        /// returns nullptr and leaves `suff` unchanged.
        const TEntry* FirstPrefix(const TEntry& e, TStringBuf& suff) const {
            if (!e.HasPrefix()) {
                return nullptr;
            }

            // HasPrefix() excludes the NoPrefix sentinel, so this is a real index.
            const TEntry& p = Get(static_cast<ui32>(e.NearestPrefix));
            suff = e.Str;
            suff.Skip(p.Str.size());
            return &p;
        }

        // Defined out of line; presumably set each entry's ModelP / Score — confirm.
        void SetModelP();
        void SetScores(EEntryScore);
    };

}