aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/codecs/greedy_dict/gd_builder.cpp
diff options
context:
space:
mode:
author Devtools Arcadia <arcadia-devtools@yandex-team.ru> 2022-02-07 18:08:42 +0300
committer Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> 2022-02-07 18:08:42 +0300
commit 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
tree e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/codecs/greedy_dict/gd_builder.cpp
download ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/codecs/greedy_dict/gd_builder.cpp')
-rw-r--r-- library/cpp/codecs/greedy_dict/gd_builder.cpp 142
1 files changed, 142 insertions, 0 deletions
diff --git a/library/cpp/codecs/greedy_dict/gd_builder.cpp b/library/cpp/codecs/greedy_dict/gd_builder.cpp
new file mode 100644
index 0000000000..561bfbca01
--- /dev/null
+++ b/library/cpp/codecs/greedy_dict/gd_builder.cpp
@@ -0,0 +1,142 @@
+#include "gd_builder.h"
+
+#include <library/cpp/string_utils/relaxed_escaper/relaxed_escaper.h>
+#include <util/generic/algorithm.h>
+
+#include <util/random/shuffle.h>
+#include <util/stream/output.h>
+#include <util/string/printf.h>
+#include <util/system/rusage.h>
+
+namespace NGreedyDict {
+ // Recomputes per-entry usage counts (and, unless `final` is set, counts of
+ // adjacent entry pairs) by splitting every input string into dictionary
+ // entries with TEntrySet::FindPrefix.
+ //
+ // maxcand - soft cap on the number of distinct adjacent pairs tracked; it is
+ //           checked once per input string, and the scan over Input stops as
+ //           soon as the pair table has grown past it. Ignored when `final`.
+ // final   - when true no pair table is built and every input string is scanned.
+ void TDictBuilder::RebuildCounts(ui32 maxcand, bool final) {
+     // Lazily create the starting entry set on first use.
+     if (!Current) {
+         Current = MakeHolder<TEntrySet>();
+         Current->InitWithAlpha();
+     }
+
+     TEntrySet& entries = *Current;
+     const bool trackPairs = !final;
+
+     for (auto& entry : entries) {
+         entry.Count = 0;
+     }
+
+     // Discard pair statistics left over from the previous pass.
+     CompoundCounts = nullptr;
+     CompoundCountsPool.Clear();
+
+     if (trackPairs) {
+         CompoundCounts = MakeHolder<TCompoundCounts>(&CompoundCountsPool);
+         CompoundCounts->reserve(maxcand);
+     }
+
+     // Randomize the order so that an early stop on `maxcand` sees a random
+     // subset of the inputs rather than always the same leading ones.
+     Shuffle(Input.begin(), Input.end(), Rng);
+
+     // Each element is taken by value on purpose: the local copy is handed to
+     // FindPrefix below while Input itself stays untouched.
+     for (auto text : Input) {
+         if (trackPairs && CompoundCounts->size() > maxcand) {
+             break;
+         }
+
+         // Number of the previously matched entry in this string; -1 = none yet.
+         i32 prevNumber = -1;
+
+         // NOTE(review): this loop only terminates if FindPrefix consumes the
+         // matched prefix from `text` and never returns null for non-empty
+         // input - confirm against TEntrySet::FindPrefix.
+         while (!!text) {
+             TEntry* const matched = entries.FindPrefix(text);
+             ui32 number = matched->Number;
+
+             matched->Count += 1;
+             if (trackPairs && prevNumber >= 0) {
+                 (*CompoundCounts)[Compose(prevNumber, number)] += 1;
+             }
+
+             prevNumber = number;
+             ++entries.TotalCount;
+         }
+     }
+
+     Current->SetModelP();
+ }
+
+ // Builds the next dictionary generation from the statistics collected by the
+ // last RebuildCounts() pass and installs it as Current.
+ //
+ // Candidates are (a) existing entries that have a prefix and (b) adjacent
+ // entry pairs recorded in CompoundCounts; both must occur more than
+ // Settings.MinAbsCount times and pass the Settings.StatTest threshold
+ // Settings.MinPValue. The best-scoring candidates are added on top of the
+ // initial entries produced by InitWithAlpha().
+ //
+ // maxent - upper bound on the size of the new set, including the initial
+ //          entries. If it is smaller than the number of initial entries, no
+ //          candidates are added.
+ //
+ // Returns additions + deletions, where additions is the number of pair
+ // candidates added and deletions is
+ // old size - (new size - additions).
+ ui32 TDictBuilder::BuildNextGeneration(ui32 maxent) {
+     TAutoPtr<TEntrySet> newset = new TEntrySet;
+     newset->InitWithAlpha();
+
+     // The initial entries already consume part of the budget. Clamp at zero:
+     // a plain unsigned subtraction would wrap around when maxent is smaller
+     // than the initial set and silently turn the cap into a huge number.
+     const auto initialSize = newset->size();
+     if (maxent > initialSize) {
+         maxent -= initialSize;
+     } else {
+         maxent = 0;
+     }
+
+     ui32 additions = 0;
+     ui32 deletions = 0;
+
+     {
+         const TEntrySet& set = *Current;
+
+         Candidates.clear();
+         const ui32 total = set.TotalCount;
+         const float minpval = Settings.MinPValue;
+         const EEntryStatTest test = Settings.StatTest;
+         const EEntryScore score = Settings.Score;
+         const ui32 mincnt = Settings.MinAbsCount;
+
+         // Existing entries compete for their place again. Entries without a
+         // prefix are skipped here.
+         // Scores are negated so that the Sort below puts the best first
+         // (assumes TCandidate orders by its first field - TODO confirm).
+         for (const auto& e : set) {
+             const float modelp = e.ModelP;
+             const ui32 cnt = e.Count;
+
+             if (e.HasPrefix() && cnt > mincnt && StatTest(test, modelp, cnt, total) > minpval) {
+                 Candidates.push_back(TCandidate(-Score(score, e.Len(), modelp, cnt, total), e.Number));
+             }
+         }
+
+         // Adjacent pairs are candidates for new, longer entries. The pair
+         // table is absent after a final RebuildCounts() pass.
+         if (!!CompoundCounts) {
+             for (const auto& pairCount : *CompoundCounts) {
+                 const TEntry& prev = set.Get(Prev(pairCount.first));
+                 const TEntry& next = set.Get(Next(pairCount.first));
+                 const float modelp = ModelP(prev.Count, next.Count, total);
+                 const ui32 cnt = pairCount.second;
+
+                 if (cnt > mincnt && StatTest(test, modelp, cnt, total) > minpval) {
+                     Candidates.push_back(TCandidate(-Score(score, prev.Len() + next.Len(), modelp, cnt, total), pairCount.first));
+                 }
+             }
+         }
+
+         Sort(Candidates.begin(), Candidates.end());
+
+         // Keep only as many candidates as the remaining budget allows.
+         if (Candidates.size() > maxent) {
+             Candidates.resize(maxent);
+         }
+
+         for (const auto& candidate : Candidates) {
+             if (IsCompound(candidate.second)) {
+                 ++additions;
+                 newset->Add(set.Get(Prev(candidate.second)).Str, set.Get(Next(candidate.second)).Str);
+             } else {
+                 newset->Add(set.Get(candidate.second).Str);
+             }
+         }
+
+         // new size - additions = initial entries + old entries kept, so this
+         // is presumably the number of old entries that were dropped.
+         deletions = set.size() - (newset->size() - additions);
+     }
+
+     Current = newset;
+     Current->BuildHierarchy();
+     return deletions + additions;
+ }
+
+ // Iteratively grows the dictionary: each iteration recounts entry and pair
+ // usage over Input and then builds the next generation of entries.
+ //
+ // maxentries - target dictionary size; also scales the pair-candidate cap
+ //              (maxentries * Settings.GrowLimit) passed to RebuildCounts.
+ // maxiters   - maximum number of iterations to run.
+ // mindiff    - iteration stops early once the dictionary has exactly
+ //              maxentries entries and the last generation changed by fewer
+ //              than mindiff entries (additions + deletions).
+ //
+ // After the loop the counts are rebuilt once more in final mode and entry
+ // scores are assigned according to Settings.Score.
+ //
+ // Returns the number of iterations from maxiters that were left unused.
+ ui32 TDictBuilder::Build(ui32 maxentries, ui32 maxiters, ui32 mindiff) {
+     while (maxiters) {
+         maxiters--;
+
+         RebuildCounts(maxentries * Settings.GrowLimit, false);
+
+         if (Settings.Verbose) {
+             // "iter" is the number of iterations still remaining, not an
+             // ascending index.
+             TString mess = Sprintf("iter:%" PRIu32 " sz:%" PRIu32 " pend:%" PRIu32, maxiters, (ui32)Current->size(), (ui32)CompoundCounts->size());
+             Clog << Sprintf("%-110s RSS=%" PRIu32 "M", mess.data(), (ui32)(TRusage::Get().MaxRss >> 20)) << Endl;
+         }
+
+         ui32 diff = BuildNextGeneration(maxentries);
+
+         if (Current->size() == maxentries && diff < mindiff)
+             break;
+     }
+
+     RebuildCounts(0, true);
+     Current->SetScores(Settings.Score);
+     return maxiters;
+ }
+
+}