// path: root/library/cpp/codecs/greedy_dict/gd_builder.cpp
// blob: 1c25299b11a57103bc1d7e4c6c36d92e85ff0a05
#include "gd_builder.h"

#include <library/cpp/string_utils/relaxed_escaper/relaxed_escaper.h>
#include <util/generic/algorithm.h>

#include <util/random/shuffle.h>
#include <util/stream/output.h> 
#include <util/string/printf.h>
#include <util/system/rusage.h>

namespace NGreedyDict {
    // Recomputes usage counts of the current entry set over the (reshuffled) Input.
    //
    // maxcand - soft cap on the number of distinct adjacent-entry pairs collected;
    //           once the pair table grows past it, remaining inputs are skipped.
    //           Only consulted when `final` is false.
    // final   - when true, only per-entry counts are gathered and no pair table
    //           is built (CompoundCounts stays null).
    //
    // Creates an alphabet-only entry set first if there is no current one.
    void TDictBuilder::RebuildCounts(ui32 maxcand, bool final) {
        if (!Current) {
            Current = MakeHolder<TEntrySet>();
            Current->InitWithAlpha();
        }

        TEntrySet& entries = *Current;
        const bool collectPairs = !final;

        for (auto& entry : entries) {
            entry.Count = 0;
        }
        // NOTE(review): entries.TotalCount is not reset here, only incremented
        // below - presumably the set is always fresh at this point; confirm.

        // Drop the previous pair table before releasing the pool behind it.
        CompoundCounts = nullptr;
        CompoundCountsPool.Clear();

        if (collectPairs) {
            CompoundCounts = MakeHolder<TCompoundCounts>(&CompoundCountsPool);
            CompoundCounts->reserve(maxcand);
        }

        Shuffle(Input.begin(), Input.end(), Rng);

        // Each input is taken by value on purpose: the copy is handed to FindPrefix
        // and re-tested for emptiness (presumably FindPrefix consumes the matched
        // prefix - confirm against TEntrySet).
        for (auto rest : Input) {
            if (collectPairs && CompoundCounts->size() > maxcand) {
                break;
            }

            // Number of the entry matched just before the current one; -1 at line start.
            i32 lastNumber = -1;

            while (!!rest) {
                TEntry* matched = entries.FindPrefix(rest);
                ui32 number = matched->Number;

                ++matched->Count;

                if (collectPairs && lastNumber >= 0) {
                    ++(*CompoundCounts)[Compose(lastNumber, number)];
                }

                lastNumber = number;
                ++entries.TotalCount;
            }
        }

        entries.SetModelP();
    }

    // Builds the next generation of the dictionary from the counts gathered for
    // the current one and installs it as Current.
    //
    // maxent - upper bound on the total number of entries in the new set,
    //          including the alphabet entries added by InitWithAlpha().
    //
    // Candidates come from two sources:
    //   * existing entries that have a prefix, occur more than
    //     Settings.MinAbsCount times and pass the statistical test;
    //   * adjacent-entry pairs from CompoundCounts (if present) that satisfy the
    //     same count / test thresholds.
    //
    // Returns deletions + additions, where additions is the number of compound
    // candidates accepted and deletions is the number of old entries that did not
    // make it into the new set.
    ui32 TDictBuilder::BuildNextGeneration(ui32 maxent) {
        TAutoPtr<TEntrySet> newset = new TEntrySet;
        newset->InitWithAlpha();

        // The alphabet always occupies part of the budget. Clamp at zero: a plain
        // unsigned subtraction would wrap around when maxent is smaller than the
        // alphabet and effectively remove the limit on accepted candidates.
        const size_t alphaSize = newset->size();
        maxent = maxent > alphaSize ? static_cast<ui32>(maxent - alphaSize) : 0;

        ui32 additions = 0;
        ui32 deletions = 0;

        {
            const TEntrySet& set = *Current;

            Candidates.clear();
            const ui32 total = set.TotalCount;
            const float minpval = Settings.MinPValue;
            const EEntryStatTest test = Settings.StatTest;
            const EEntryScore score = Settings.Score;
            const ui32 mincnt = Settings.MinAbsCount;

            // Scores are stored negated so that the ascending Sort below followed by
            // truncation keeps the best-scoring candidates (assuming TCandidate
            // orders by its first field - confirm).
            for (const auto& e : set) {
                const float modelp = e.ModelP;
                const ui32 cnt = e.Count;

                if (e.HasPrefix() && cnt > mincnt && StatTest(test, modelp, cnt, total) > minpval)
                    Candidates.push_back(TCandidate(-Score(score, e.Len(), modelp, cnt, total), e.Number));
            }

            if (!!CompoundCounts) {
                for (const auto& compound : *CompoundCounts) {
                    const TEntry& prev = set.Get(Prev(compound.first));
                    const TEntry& next = set.Get(Next(compound.first));
                    const float modelp = ModelP(prev.Count, next.Count, total);
                    const ui32 cnt = compound.second;

                    if (cnt > mincnt && StatTest(test, modelp, cnt, total) > minpval)
                        Candidates.push_back(TCandidate(-Score(score, prev.Len() + next.Len(), modelp, cnt, total), compound.first));
                }
            }

            Sort(Candidates.begin(), Candidates.end());

            if (Candidates.size() > maxent)
                Candidates.resize(maxent);

            for (const auto& candidate : Candidates) {
                if (IsCompound(candidate.second)) {
                    // A pair of old entries becomes one new, longer entry.
                    additions++;
                    newset->Add(set.Get(Prev(candidate.second)).Str, set.Get(Next(candidate.second)).Str);
                } else {
                    newset->Add(set.Get(candidate.second).Str);
                }
            }

            // Entries carried over = new size minus the freshly added compounds;
            // everything else from the old set has been dropped.
            deletions = set.size() - (newset->size() - additions);
        }

        Current = newset;
        Current->BuildHierarchy();
        return deletions + additions;
    }

    // Runs the iterative dictionary construction.
    //
    // maxentries - target size of the entry set; also scaled by Settings.GrowLimit
    //              to cap the number of pair candidates collected per iteration.
    // maxiters   - maximum number of count/rebuild iterations.
    // mindiff    - iteration stops early once the set has exactly maxentries
    //              entries and the last generation changed by fewer than mindiff
    //              entries (additions + deletions).
    //
    // After the loop the counts are rebuilt once more in final mode and the
    // entry scores are set according to Settings.Score.
    //
    // Returns the number of iterations left unused out of maxiters.
    ui32 TDictBuilder::Build(ui32 maxentries, ui32 maxiters, ui32 mindiff) {
        while (maxiters) {
            maxiters--;

            RebuildCounts(maxentries * Settings.GrowLimit, false);

            if (Settings.Verbose) {
                // "iter" is the number of iterations remaining, not the index of the current one.
                TString mess = Sprintf("iter:%" PRIu32 " sz:%" PRIu32 " pend:%" PRIu32, maxiters, (ui32)Current->size(), (ui32)CompoundCounts->size());
                Clog << Sprintf("%-110s RSS=%" PRIu32 "M", mess.data(), (ui32)(TRusage::Get().MaxRss >> 20)) << Endl;
            }

            ui32 diff = BuildNextGeneration(maxentries);

            if (Current->size() == maxentries && diff < mindiff)
                break;
        }

        RebuildCounts(0, true);
        Current->SetScores(Settings.Score);
        return maxiters;
    }

}