aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/codecs/greedy_dict/gd_entry.h
blob: 0362fd9f99bda01cf083b9076ef2e071b1e5a591 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#pragma once 
 
#include "gd_stats.h" 
 
#include <library/cpp/containers/comptrie/comptrie.h>
 
#include <util/generic/ptr.h> 
#include <util/generic/strbuf.h> 
#include <util/generic/vector.h> 
 
#include <util/memory/pool.h> 
 
namespace NGreedyDict { 
    // Convenience alias: a vector of string views.
    using TStringBufs = TVector<TStringBuf>;
 
    /// A single dictionary entry: a string view together with the numeric
    /// attributes attached to it.
    struct TEntry {
        /// Sentinel value of NearestPrefix meaning "this entry has no prefix entry".
        static const i32 NoPrefix = -1;

        /// Text of the entry. TStringBuf is a view, so the bytes live elsewhere.
        TStringBuf Str;

        /// Index of this entry's prefix entry (TEntrySet::FirstPrefix passes it
        /// to TEntrySet::Get), or NoPrefix when there is none.
        i32 NearestPrefix = NoPrefix;
        /// Counter supplied through the constructor; zero by default.
        /// NOTE(review): presumably an occurrence count -- confirm against callers.
        ui32 Count = 0;
        /// Ordinal used by NumberLess for ordering.
        ui32 Number = 0;
        /// NOTE(review): presumably filled in by TEntrySet::SetModelP -- confirm.
        float ModelP = 0.0f;
        /// Value used by ScoreMore for ordering.
        /// NOTE(review): presumably filled in by TEntrySet::SetScores -- confirm.
        float Score = 0.0f;

        /// Builds an entry over `b` with counter `cnt`; all other fields keep
        /// their in-class defaults. Intentionally implicit, so a TStringBuf
        /// converts to a TEntry.
        TEntry(TStringBuf b = TStringBuf(), ui32 cnt = 0)
            : Str(b)
            , Count(cnt)
        {
        }

        /// True when NearestPrefix holds something other than the NoPrefix sentinel.
        bool HasPrefix() const {
            return NoPrefix != NearestPrefix;
        }

        /// Length of Str, converted to ui32.
        ui32 Len() const {
            return static_cast<ui32>(Str.size());
        }

        /// Ordering predicate: ascending by Str.
        static bool StrLess(const TEntry& a, const TEntry& b) {
            return b.Str > a.Str;
        }

        /// Ordering predicate: ascending by Number.
        static bool NumberLess(const TEntry& a, const TEntry& b) {
            return b.Number > a.Number;
        }

        /// Ordering predicate: descending by Score.
        static bool ScoreMore(const TEntry& a, const TEntry& b) {
            return b.Score < a.Score;
        }
    };
 
    /// A vector of TEntry whose entry strings are backed by an internal memory pool.
    ///
    /// Add() stores string bytes in Pool and appends an entry viewing those
    /// bytes, so entries do not reference the caller's buffers after Add(a, b)
    /// (and presumably after Add(a) as well -- see the note there).
    /// The set is non-copyable via TNonCopyable.
    class TEntrySet: public TVector<TEntry>, TNonCopyable {
        // Storage for entry strings. NOTE(review): 8112 is presumably the
        // pool's initial chunk size -- confirm in util/memory/pool.h.
        TMemoryPool Pool{8112};
        // Not touched by any inline member in this header.
        // NOTE(review): presumably built by BuildHierarchy() and queried by
        // FindPrefix() -- confirm in the .cpp.
        TCompactTrie<char, ui32, TAsIsPacker<ui32>> Trie;

    public:
        /// Publicly writable counter; no inline member in this header modifies it.
        ui32 TotalCount = 0;

        /// Defined out of line.
        void InitWithAlpha();

        /// Appends a new entry whose string is the pointer returned by
        /// Pool.Append(a.data(), a.size()) with length a.size().
        /// NOTE(review): Append presumably copies the bytes into the pool -- confirm.
        void Add(TStringBuf a) {
            push_back(TStringBuf(Pool.Append(a.data(), a.size()), a.size()));
        }

        /// Appends a new entry whose string is the concatenation of `a` and `b`,
        /// copied into freshly allocated pool memory.
        void Add(TStringBuf a, TStringBuf b) {
            const size_t sz = a.size() + b.size();
            char* const p = static_cast<char*>(Pool.Allocate(sz));
            // Skip zero-length copies: an empty buffer may carry a null data()
            // pointer, and passing null to memcpy is formally undefined even
            // when the length is zero.
            if (a.size() != 0) {
                memcpy(p, a.data(), a.size());
            }
            if (b.size() != 0) {
                memcpy(p + a.size(), b.data(), b.size());
            }
            push_back(TStringBuf(p, sz));
        }

        /// Returns the entry at position `idx` (plain operator[] access).
        TEntry& Get(ui32 idx) {
            return (*this)[idx];
        }

        /// Const overload of Get().
        const TEntry& Get(ui32 idx) const {
            return (*this)[idx];
        }

        /// Defined out of line.
        void BuildHierarchy();

        /// Longest-prefix lookup; defined out of line.
        /// NOTE(review): `str` is taken by non-const reference, so the
        /// implementation presumably modifies it (e.g. strips the matched
        /// prefix) -- confirm in the .cpp.
        TEntry* FindPrefix(TStringBuf& str);

        /// Const overload: forwards to the non-const FindPrefix().
        /// NOTE(review): this is only sound if the non-const overload does not
        /// mutate the set -- confirm in the .cpp.
        const TEntry* FindPrefix(TStringBuf& str) const {
            return const_cast<TEntrySet*>(this)->FindPrefix(str);
        }

        /// Returns the entry referenced by e.NearestPrefix, or nullptr when `e`
        /// has no prefix. On success `suff` is set to e.Str with the prefix
        /// entry's length skipped from the front; when nullptr is returned
        /// `suff` is left untouched.
        const TEntry* FirstPrefix(const TEntry& e, TStringBuf& suff) const {
            if (!e.HasPrefix())
                return nullptr;

            const TEntry& p = Get(e.NearestPrefix);
            suff = e.Str;
            suff.Skip(p.Str.size());
            return &p;
        }

        /// Defined out of line.
        void SetModelP();
        /// Defined out of line.
        void SetScores(EEntryScore);
    };
 
}