aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/codecs/greedy_dict/gd_entry.cpp
blob: 0603a9fca85adfc8beedeb04cc22079df153d39d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include "gd_entry.h"
#include "gd_stats.h"

#include <util/generic/algorithm.h>
#include <util/generic/singleton.h>

namespace NGreedyDict {
    // Table of all 256 single-byte strings; InitWithAlpha() uses it to seed an
    // entry set with the complete byte alphabet.
    //
    // Every view stored in Alphas points into Memory of the same object, so a
    // copy of the object would carry views that still refer to the source
    // object's buffer. Copying is disabled to rule that out.
    class TAlphas {
        // Two bytes per symbol: the byte value itself followed by a zero byte.
        char Memory[512];

    public:
        // One single-character view per byte value, in increasing byte order.
        TStringBufs Alphas;

        TAlphas() {
            for (ui32 i = 0; i < 256; ++i) {
                Memory[2 * i] = static_cast<char>(i);
                Memory[2 * i + 1] = 0;

                Alphas.push_back(TStringBuf(&Memory[2 * i], 1));
            }
        }

        // Views in Alphas are tied to this object's Memory; see class comment.
        TAlphas(const TAlphas&) = delete;
        TAlphas& operator=(const TAlphas&) = delete;
    };

    void TEntrySet::InitWithAlpha() { 
        Pool.ClearKeepFirstChunk(); 
        const TStringBufs& a = Singleton<TAlphas>()->Alphas; 
        for (auto it : a) { 
            Add(it); 
        }
        BuildHierarchy(); 
    }

    void TEntrySet::BuildHierarchy() { 
        Sort(begin(), end(), TEntry::StrLess); 

        TCompactTrieBuilder<char, ui32, TAsIsPacker<ui32>> builder(CTBF_PREFIX_GROUPED); 

        for (iterator it = begin(); it != end(); ++it) { 
            it->Number = (it - begin()); 
            TStringBuf suff = it->Str; 
            size_t len = 0; 
            ui32 val = 0; 

            if (builder.FindLongestPrefix(suff.data(), suff.size(), &len, &val) && len) {
                it->NearestPrefix = val; 
            } 

            builder.Add(suff.data(), suff.size(), it->Number);
        }

        TBufferOutput bout; 
        builder.Save(bout); 
        Trie.Init(TBlob::FromBuffer(bout.Buffer())); 
    }

    // Finds the longest key in Trie that is a prefix of str.
    // On success, removes the matched prefix from the front of str and returns
    // the entry stored under the found number; otherwise leaves str untouched
    // and returns nullptr.
    TEntry* TEntrySet::FindPrefix(TStringBuf& str) {
        size_t matchedLen = 0;
        ui32 number = 0;

        if (Trie.FindLongestPrefix(str, &matchedLen, &number)) {
            str.Skip(matchedLen);
            return &Get(number);
        }

        return nullptr;
    }

    // Computes ModelP for every entry.
    //
    // Entries without a prefix get ModelP = 0. For the others the string is
    // split into its nearest prefix followed by greedily matched longest
    // prefixes of the remainder (via FindPrefix), and ModelP is the product of
    // (piece.Count + entry.Count) / TotalCount over all pieces.
    //
    // If the remainder cannot be split any further - FindPrefix finds nothing,
    // or matches without consuming any characters - ModelP is set to 0. This
    // replaces a null-pointer dereference and an endless loop respectively.
    void TEntrySet::SetModelP() {
        for (iterator it = begin(); it != end(); ++it) {
            TEntry& e = *it;

            if (!e.HasPrefix()) {
                e.ModelP = 0;
                continue;
            }

            TStringBuf suff = e.Str;
            const TEntry& p = Get(e.NearestPrefix);
            suff.Skip(p.Len());

            float modelp = float(p.Count + e.Count) / TotalCount;

            while (!suff.empty()) {
                const size_t before = suff.size();
                const TEntry* pp = FindPrefix(suff);

                // No matching piece, or a match that made no progress:
                // the remainder is not representable by the current entries.
                if (pp == nullptr || suff.size() == before) {
                    modelp = 0;
                    break;
                }

                modelp *= float(pp->Count + e.Count) / TotalCount;
            }

            e.ModelP = modelp;
        }
    }

    // Recomputes Score of every entry with the scoring mode s, from the entry's
    // length, ModelP and Count together with the set-wide TotalCount.
    void TEntrySet::SetScores(EEntryScore s) {
        for (iterator cur = begin(); cur != end(); ++cur) {
            TEntry& entry = *cur;
            entry.Score = Score(s, entry.Len(), entry.ModelP, entry.Count, TotalCount);
        }
    }

}