aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/marisa-trie/marisa/grimoire/trie
diff options
context:
space:
mode:
authorrobot-piglet <robot-piglet@yandex-team.com>2023-12-02 01:45:21 +0300
committerrobot-piglet <robot-piglet@yandex-team.com>2023-12-02 02:42:50 +0300
commit9c43d58f75cf086b744cf4fe2ae180e8f37e4a0c (patch)
tree9f88a486917d371d099cd712efd91b4c122d209d /contrib/python/marisa-trie/marisa/grimoire/trie
parent32fb6dda1feb24f9ab69ece5df0cb9ec238ca5e6 (diff)
downloadydb-9c43d58f75cf086b744cf4fe2ae180e8f37e4a0c.tar.gz
Intermediate changes
Diffstat (limited to 'contrib/python/marisa-trie/marisa/grimoire/trie')
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/cache.h82
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/config.h156
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/entry.h83
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/header.h62
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/history.h66
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/key.h227
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc877
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h135
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/range.h116
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/state.h118
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc218
-rw-r--r--contrib/python/marisa-trie/marisa/grimoire/trie/tail.h73
12 files changed, 2213 insertions, 0 deletions
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/cache.h b/contrib/python/marisa-trie/marisa/grimoire/trie/cache.h
new file mode 100644
index 0000000000..f9da360869
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/cache.h
@@ -0,0 +1,82 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_CACHE_H_
+#define MARISA_GRIMOIRE_TRIE_CACHE_H_
+
+#include <cfloat>
+
+#include "../../base.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+class Cache {  // One cache slot: parent/child node ids plus a 32-bit field read either as a packed link or as a float weight.
+ public:
+ Cache() : parent_(0), child_(0), union_() {
+ union_.weight = FLT_MIN;  // Initial weight is the smallest positive normalized float, not zero.
+ }
+ Cache(const Cache &cache)
+ : parent_(cache.parent_), child_(cache.child_), union_(cache.union_) {}
+
+ Cache &operator=(const Cache &cache) {
+ parent_ = cache.parent_;
+ child_ = cache.child_;
+ union_ = cache.union_;
+ return *this;
+ }
+
+ void set_parent(std::size_t parent) {  // Stores parent narrowed to 32 bits.
+ MARISA_DEBUG_IF(parent > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ parent_ = (UInt32)parent;
+ }
+ void set_child(std::size_t child) {  // Stores child narrowed to 32 bits.
+ MARISA_DEBUG_IF(child > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ child_ = (UInt32)child;
+ }
+ void set_base(UInt8 base) {
+ union_.link = (union_.link & ~0xFFU) | base;  // Replaces the low 8 bits of link; the upper 24 bits are kept.
+ }
+ void set_extra(std::size_t extra) {
+ MARISA_DEBUG_IF(extra > (MARISA_UINT32_MAX >> 8), MARISA_SIZE_ERROR);
+ union_.link = (UInt32)((union_.link & 0xFFU) | (extra << 8));  // Replaces the upper 24 bits of link; the low 8 bits (base) are kept.
+ }
+ void set_weight(float weight) {
+ union_.weight = weight;  // Writes the same storage that link uses.
+ }
+
+ std::size_t parent() const {
+ return parent_;
+ }
+ std::size_t child() const {
+ return child_;
+ }
+ UInt8 base() const {
+ return (UInt8)(union_.link & 0xFFU);  // Low 8 bits of link.
+ }
+ std::size_t extra() const {
+ return union_.link >> 8;  // Upper 24 bits of link.
+ }
+ char label() const {
+ return (char)base();  // Same byte as base(), viewed as a char.
+ }
+ std::size_t link() const {
+ return union_.link;  // The whole packed 32-bit value (extra << 8 | base).
+ }
+ float weight() const {
+ return union_.weight;
+ }
+
+ private:
+ UInt32 parent_;
+ UInt32 child_;
+ union Union {  // link and weight share storage, so writing one replaces the other.
+ UInt32 link;
+ float weight;
+ } union_;
+};
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_CACHE_H_
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/config.h b/contrib/python/marisa-trie/marisa/grimoire/trie/config.h
new file mode 100644
index 0000000000..9b307de3e1
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/config.h
@@ -0,0 +1,156 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_CONFIG_H_
+#define MARISA_GRIMOIRE_TRIE_CONFIG_H_
+
+#include "../../base.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+class Config {  // Holds the build settings decoded from an integer flag word: number of tries, cache level, tail mode and node order.
+ public:
+ Config()
+ : num_tries_(MARISA_DEFAULT_NUM_TRIES),
+ cache_level_(MARISA_DEFAULT_CACHE),
+ tail_mode_(MARISA_DEFAULT_TAIL),
+ node_order_(MARISA_DEFAULT_ORDER) {}
+
+ void parse(int config_flags) {  // Decodes config_flags into a scratch Config and swaps it in, so *this changes only after parse_() has returned.
+ Config temp;
+ temp.parse_(config_flags);
+ swap(temp);
+ }
+
+ int flags() const {
+ return (int)num_tries_ | tail_mode_ | node_order_;  // Note: cache_level_ is not part of the returned value.
+ }
+
+ std::size_t num_tries() const {
+ return num_tries_;
+ }
+ CacheLevel cache_level() const {
+ return cache_level_;
+ }
+ TailMode tail_mode() const {
+ return tail_mode_;
+ }
+ NodeOrder node_order() const {
+ return node_order_;
+ }
+
+ void clear() {
+ Config().swap(*this);  // Resets every field to the defaults set by the constructor.
+ }
+ void swap(Config &rhs) {
+ marisa::swap(num_tries_, rhs.num_tries_);
+ marisa::swap(cache_level_, rhs.cache_level_);
+ marisa::swap(tail_mode_, rhs.tail_mode_);
+ marisa::swap(node_order_, rhs.node_order_);
+ }
+
+ private:
+ std::size_t num_tries_;
+ CacheLevel cache_level_;
+ TailMode tail_mode_;
+ NodeOrder node_order_;
+
+ void parse_(int config_flags) {
+ MARISA_THROW_IF((config_flags & ~MARISA_CONFIG_MASK) != 0,  // Rejects any bit outside MARISA_CONFIG_MASK.
+ MARISA_CODE_ERROR);
+
+ parse_num_tries(config_flags);
+ parse_cache_level(config_flags);
+ parse_tail_mode(config_flags);
+ parse_node_order(config_flags);
+ }
+
+ void parse_num_tries(int config_flags) {
+ const int num_tries = config_flags & MARISA_NUM_TRIES_MASK;
+ if (num_tries != 0) {  // A zero field leaves num_tries_ at its current (default) value.
+ num_tries_ = num_tries;
+ }
+ }
+
+ void parse_cache_level(int config_flags) {
+ switch (config_flags & MARISA_CACHE_LEVEL_MASK) {
+ case 0: {  // No cache-level bits set: use the default.
+ cache_level_ = MARISA_DEFAULT_CACHE;
+ break;
+ }
+ case MARISA_HUGE_CACHE: {
+ cache_level_ = MARISA_HUGE_CACHE;
+ break;
+ }
+ case MARISA_LARGE_CACHE: {
+ cache_level_ = MARISA_LARGE_CACHE;
+ break;
+ }
+ case MARISA_NORMAL_CACHE: {
+ cache_level_ = MARISA_NORMAL_CACHE;
+ break;
+ }
+ case MARISA_SMALL_CACHE: {
+ cache_level_ = MARISA_SMALL_CACHE;
+ break;
+ }
+ case MARISA_TINY_CACHE: {
+ cache_level_ = MARISA_TINY_CACHE;
+ break;
+ }
+ default: {  // Any other bit pattern under the mask is reported as an error.
+ MARISA_THROW(MARISA_CODE_ERROR, "undefined cache level");
+ }
+ }
+ }
+
+ void parse_tail_mode(int config_flags) {
+ switch (config_flags & MARISA_TAIL_MODE_MASK) {
+ case 0: {  // No tail-mode bits set: use the default.
+ tail_mode_ = MARISA_DEFAULT_TAIL;
+ break;
+ }
+ case MARISA_TEXT_TAIL: {
+ tail_mode_ = MARISA_TEXT_TAIL;
+ break;
+ }
+ case MARISA_BINARY_TAIL: {
+ tail_mode_ = MARISA_BINARY_TAIL;
+ break;
+ }
+ default: {
+ MARISA_THROW(MARISA_CODE_ERROR, "undefined tail mode");
+ }
+ }
+ }
+
+ void parse_node_order(int config_flags) {
+ switch (config_flags & MARISA_NODE_ORDER_MASK) {
+ case 0: {  // No node-order bits set: use the default.
+ node_order_ = MARISA_DEFAULT_ORDER;
+ break;
+ }
+ case MARISA_LABEL_ORDER: {
+ node_order_ = MARISA_LABEL_ORDER;
+ break;
+ }
+ case MARISA_WEIGHT_ORDER: {
+ node_order_ = MARISA_WEIGHT_ORDER;
+ break;
+ }
+ default: {
+ MARISA_THROW(MARISA_CODE_ERROR, "undefined node order");
+ }
+ }
+ }
+
+ // Disallows copy and assignment.
+ Config(const Config &);  // Declared private and not defined in this class body.
+ Config &operator=(const Config &);
+};
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_CONFIG_H_
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/entry.h b/contrib/python/marisa-trie/marisa/grimoire/trie/entry.h
new file mode 100644
index 0000000000..834ab95e1e
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/entry.h
@@ -0,0 +1,83 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_ENTRY_H_
+#define MARISA_GRIMOIRE_TRIE_ENTRY_H_
+
+#include "../../base.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+class Entry {  // Non-owning view of a string that is indexed from its last character backwards, plus a 32-bit id.
+ public:
+ Entry()
+ : ptr_(reinterpret_cast<const char *>(-1)), length_(0), id_(0) {}  // Sentinel -1 with length 0: ptr() then computes ptr_ - 0 + 1 (presumably meant to yield a null pointer).
+ Entry(const Entry &entry)
+ : ptr_(entry.ptr_), length_(entry.length_), id_(entry.id_) {}
+
+ Entry &operator=(const Entry &entry) {
+ ptr_ = entry.ptr_;
+ length_ = entry.length_;
+ id_ = entry.id_;
+ return *this;
+ }
+
+ char operator[](std::size_t i) const {
+ MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR);
+ return *(ptr_ - i);  // Index 0 is the last character; larger i moves toward the front.
+ }
+
+ void set_str(const char *ptr, std::size_t length) {  // Points the view at [ptr, ptr + length); the bytes are not copied.
+ MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR);
+ MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ ptr_ = ptr + length - 1;  // ptr_ holds the address of the last character, not the first.
+ length_ = (UInt32)length;
+ }
+ void set_id(std::size_t id) {
+ MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ id_ = (UInt32)id;
+ }
+
+ const char *ptr() const {
+ return ptr_ - length_ + 1;  // Recovers the address of the first character.
+ }
+ std::size_t length() const {
+ return length_;
+ }
+ std::size_t id() const {
+ return id_;
+ }
+
+ class StringComparer {  // Orders entries in descending order of their back-to-front byte sequence.
+ public:
+ bool operator()(const Entry &lhs, const Entry &rhs) const {
+ for (std::size_t i = 0; i < lhs.length(); ++i) {
+ if (i == rhs.length()) {
+ return true;  // rhs is exhausted first, so the longer lhs sorts before it.
+ }
+ if (lhs[i] != rhs[i]) {
+ return (UInt8)lhs[i] > (UInt8)rhs[i];  // Bytes are compared as unsigned values.
+ }
+ }
+ return lhs.length() > rhs.length();  // Reached only when lhs is exhausted; lhs is then never the longer one, so this is false.
+ }
+ };
+
+ class IDComparer {  // Orders entries by ascending id.
+ public:
+ bool operator()(const Entry &lhs, const Entry &rhs) const {
+ return lhs.id_ < rhs.id_;
+ }
+ };
+
+ private:
+ const char *ptr_;
+ UInt32 length_;
+ UInt32 id_;
+};
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_ENTRY_H_
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/header.h b/contrib/python/marisa-trie/marisa/grimoire/trie/header.h
new file mode 100644
index 0000000000..04839f67e1
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/header.h
@@ -0,0 +1,62 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_HEADER_H_
+#define MARISA_GRIMOIRE_TRIE_HEADER_H_
+
+#include "../io.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+// Reads, writes and validates the fixed HEADER_SIZE-byte signature handled by
+// map(), read() and write().
+class Header {
+ public:
+  enum {
+    HEADER_SIZE = 16
+  };
+
+  Header() {}
+
+  // Obtains HEADER_SIZE bytes through `mapper` and checks them against the
+  // expected signature; a mismatch is reported as MARISA_FORMAT_ERROR.
+  void map(Mapper &mapper) {
+    const char *ptr;
+    mapper.map(&ptr, HEADER_SIZE);
+    MARISA_THROW_IF(!test_header(ptr), MARISA_FORMAT_ERROR);
+  }
+
+  // Reads HEADER_SIZE bytes from `reader` into a local buffer and checks them
+  // against the expected signature.
+  void read(Reader &reader) {
+    char buf[HEADER_SIZE];
+    reader.read(buf, HEADER_SIZE);
+    MARISA_THROW_IF(!test_header(buf), MARISA_FORMAT_ERROR);
+  }
+
+  // Emits the signature bytes through `writer`.
+  void write(Writer &writer) const {
+    writer.write(get_header(), HEADER_SIZE);
+  }
+
+  // Number of bytes the header occupies when serialized.
+  std::size_t io_size() const {
+    return HEADER_SIZE;
+  }
+
+ private:
+  // The signature: 15 visible characters plus the terminating NUL.
+  static const char *get_header() {
+    static const char buf[HEADER_SIZE] = "We love Marisa.";
+    return buf;
+  }
+
+  // True when the first HEADER_SIZE bytes at `ptr` equal the signature.
+  static bool test_header(const char *ptr) {
+    const char *const expected = get_header();
+    std::size_t matched = 0;
+    while ((matched < HEADER_SIZE) && (ptr[matched] == expected[matched])) {
+      ++matched;
+    }
+    return matched == HEADER_SIZE;
+  }
+
+  // Disallows copy and assignment.
+  Header(const Header &);
+  Header &operator=(const Header &);
+};
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_HEADER_H_
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/history.h b/contrib/python/marisa-trie/marisa/grimoire/trie/history.h
new file mode 100644
index 0000000000..9a3d272260
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/history.h
@@ -0,0 +1,66 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_STATE_HISTORY_H_
+#define MARISA_GRIMOIRE_TRIE_STATE_HISTORY_H_
+
+#include "../../base.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+// Record of one search step: node id, LOUDS position, key position, link id
+// and key id, each kept as a 32-bit value.
+class History {
+ public:
+  // The link and key ids start at their "invalid" sentinels; the rest at 0.
+  History()
+      : node_id_(0),
+        louds_pos_(0),
+        key_pos_(0),
+        link_id_(MARISA_INVALID_LINK_ID),
+        key_id_(MARISA_INVALID_KEY_ID) {}
+
+  // Each setter takes a std::size_t and narrows it to UInt32; the
+  // MARISA_DEBUG_IF line guards against values above MARISA_UINT32_MAX.
+  void set_node_id(std::size_t node_id) {
+    MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    node_id_ = static_cast<UInt32>(node_id);
+  }
+  void set_louds_pos(std::size_t louds_pos) {
+    MARISA_DEBUG_IF(louds_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    louds_pos_ = static_cast<UInt32>(louds_pos);
+  }
+  void set_key_pos(std::size_t key_pos) {
+    MARISA_DEBUG_IF(key_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    key_pos_ = static_cast<UInt32>(key_pos);
+  }
+  void set_link_id(std::size_t link_id) {
+    MARISA_DEBUG_IF(link_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    link_id_ = static_cast<UInt32>(link_id);
+  }
+  void set_key_id(std::size_t key_id) {
+    MARISA_DEBUG_IF(key_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    key_id_ = static_cast<UInt32>(key_id);
+  }
+
+  // Getters widen the stored 32-bit values back to std::size_t.
+  std::size_t node_id() const {
+    return static_cast<std::size_t>(node_id_);
+  }
+  std::size_t louds_pos() const {
+    return static_cast<std::size_t>(louds_pos_);
+  }
+  std::size_t key_pos() const {
+    return static_cast<std::size_t>(key_pos_);
+  }
+  std::size_t link_id() const {
+    return static_cast<std::size_t>(link_id_);
+  }
+  std::size_t key_id() const {
+    return static_cast<std::size_t>(key_id_);
+  }
+
+ private:
+  UInt32 node_id_;
+  UInt32 louds_pos_;
+  UInt32 key_pos_;
+  UInt32 link_id_;
+  UInt32 key_id_;
+};
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_STATE_HISTORY_H_
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/key.h b/contrib/python/marisa-trie/marisa/grimoire/trie/key.h
new file mode 100644
index 0000000000..c09ea86cf8
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/key.h
@@ -0,0 +1,227 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_KEY_H_
+#define MARISA_GRIMOIRE_TRIE_KEY_H_
+
+#include "../../base.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+class Key {  // Non-owning forward view of a byte string with a length, an id, and one 32-bit slot used as either a weight or a terminal.
+ public:
+ Key() : ptr_(NULL), length_(0), union_(), id_(0) {
+ union_.terminal = 0;
+ }
+ Key(const Key &entry)
+ : ptr_(entry.ptr_), length_(entry.length_),
+ union_(entry.union_), id_(entry.id_) {}
+
+ Key &operator=(const Key &entry) {
+ ptr_ = entry.ptr_;
+ length_ = entry.length_;
+ union_ = entry.union_;
+ id_ = entry.id_;
+ return *this;
+ }
+
+ char operator[](std::size_t i) const {
+ MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR);
+ return ptr_[i];  // Index 0 is the first character.
+ }
+
+ void substr(std::size_t pos, std::size_t length) {  // Narrows the view in place to `length` characters starting at offset `pos`.
+ MARISA_DEBUG_IF(pos > length_, MARISA_BOUND_ERROR);
+ MARISA_DEBUG_IF(length > length_, MARISA_BOUND_ERROR);
+ MARISA_DEBUG_IF(pos > (length_ - length), MARISA_BOUND_ERROR);
+ ptr_ += pos;
+ length_ = (UInt32)length;
+ }
+
+ void set_str(const char *ptr, std::size_t length) {  // Points the view at [ptr, ptr + length); the bytes are not copied.
+ MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR);
+ MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ ptr_ = ptr;
+ length_ = (UInt32)length;
+ }
+ void set_weight(float weight) {
+ union_.weight = weight;  // Overwrites the storage shared with terminal.
+ }
+ void set_terminal(std::size_t terminal) {
+ MARISA_DEBUG_IF(terminal > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ union_.terminal = (UInt32)terminal;  // Overwrites the storage shared with weight.
+ }
+ void set_id(std::size_t id) {
+ MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ id_ = (UInt32)id;
+ }
+
+ const char *ptr() const {
+ return ptr_;
+ }
+ std::size_t length() const {
+ return length_;
+ }
+ float weight() const {
+ return union_.weight;
+ }
+ std::size_t terminal() const {
+ return union_.terminal;
+ }
+ std::size_t id() const {
+ return id_;
+ }
+
+ private:
+ const char *ptr_;
+ UInt32 length_;
+ union Union {  // weight and terminal share storage, so only the one written last is meaningful.
+ float weight;
+ UInt32 terminal;
+ } union_;
+ UInt32 id_;
+};
+
+// Two keys are equal when they have the same length and the same bytes.
+inline bool operator==(const Key &lhs, const Key &rhs) {
+  const std::size_t length = lhs.length();
+  if (length != rhs.length()) {
+    return false;
+  }
+  // Advance past the matching prefix; equality means it spans the whole key.
+  std::size_t pos = 0;
+  while ((pos < length) && (lhs[pos] == rhs[pos])) {
+    ++pos;
+  }
+  return pos == length;
+}
+
+// Negation of operator== for Key.
+inline bool operator!=(const Key &lhs, const Key &rhs) {
+  if (lhs == rhs) {
+    return false;
+  }
+  return true;
+}
+
+// Lexicographic "less than" on unsigned byte values; a proper prefix sorts
+// before the longer key.
+inline bool operator<(const Key &lhs, const Key &rhs) {
+  const std::size_t lhs_length = lhs.length();
+  const std::size_t rhs_length = rhs.length();
+  const std::size_t common =
+      (lhs_length < rhs_length) ? lhs_length : rhs_length;
+  for (std::size_t pos = 0; pos < common; ++pos) {
+    const UInt8 lhs_byte = (UInt8)lhs[pos];
+    const UInt8 rhs_byte = (UInt8)rhs[pos];
+    if (lhs_byte != rhs_byte) {
+      return lhs_byte < rhs_byte;
+    }
+  }
+  // All shared positions match: the shorter key is the smaller one.
+  return lhs_length < rhs_length;
+}
+
+// "Greater than" for Key, expressed through operator< with swapped operands.
+inline bool operator>(const Key &lhs, const Key &rhs) {
+  const bool rhs_is_less = (rhs < lhs);
+  return rhs_is_less;
+}
+
+class ReverseKey {  // Same fields as Key, but the view is anchored one past the end of the string and indexed backwards.
+ public:
+ ReverseKey() : ptr_(NULL), length_(0), union_(), id_(0) {
+ union_.terminal = 0;
+ }
+ ReverseKey(const ReverseKey &entry)
+ : ptr_(entry.ptr_), length_(entry.length_),
+ union_(entry.union_), id_(entry.id_) {}
+
+ ReverseKey &operator=(const ReverseKey &entry) {
+ ptr_ = entry.ptr_;
+ length_ = entry.length_;
+ union_ = entry.union_;
+ id_ = entry.id_;
+ return *this;
+ }
+
+ char operator[](std::size_t i) const {
+ MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR);
+ return *(ptr_ - i - 1);  // Index 0 is the character just before ptr_, i.e. the last one of the viewed range.
+ }
+
+ void substr(std::size_t pos, std::size_t length) {  // Narrows the view in place: skips `pos` characters from the end, then keeps `length` characters.
+ MARISA_DEBUG_IF(pos > length_, MARISA_BOUND_ERROR);
+ MARISA_DEBUG_IF(length > length_, MARISA_BOUND_ERROR);
+ MARISA_DEBUG_IF(pos > (length_ - length), MARISA_BOUND_ERROR);
+ ptr_ -= pos;
+ length_ = (UInt32)length;
+ }
+
+ void set_str(const char *ptr, std::size_t length) {  // Points the view at [ptr, ptr + length); the bytes are not copied.
+ MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR);
+ MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ ptr_ = ptr + length;  // ptr_ holds the one-past-the-end address.
+ length_ = (UInt32)length;
+ }
+ void set_weight(float weight) {
+ union_.weight = weight;  // Overwrites the storage shared with terminal.
+ }
+ void set_terminal(std::size_t terminal) {
+ MARISA_DEBUG_IF(terminal > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ union_.terminal = (UInt32)terminal;  // Overwrites the storage shared with weight.
+ }
+ void set_id(std::size_t id) {
+ MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ id_ = (UInt32)id;
+ }
+
+ const char *ptr() const {
+ return ptr_ - length_;  // Lowest address of the viewed range.
+ }
+ std::size_t length() const {
+ return length_;
+ }
+ float weight() const {
+ return union_.weight;
+ }
+ std::size_t terminal() const {
+ return union_.terminal;
+ }
+ std::size_t id() const {
+ return id_;
+ }
+
+ private:
+ const char *ptr_;
+ UInt32 length_;
+ union Union {  // weight and terminal share storage, so only the one written last is meaningful.
+ float weight;
+ UInt32 terminal;
+ } union_;
+ UInt32 id_;
+};
+
+// Two reverse keys are equal when they have the same length and the same
+// bytes at every index.
+inline bool operator==(const ReverseKey &lhs, const ReverseKey &rhs) {
+  const std::size_t length = lhs.length();
+  if (length != rhs.length()) {
+    return false;
+  }
+  // Advance past the matching run; equality means it spans the whole key.
+  std::size_t pos = 0;
+  while ((pos < length) && (lhs[pos] == rhs[pos])) {
+    ++pos;
+  }
+  return pos == length;
+}
+
+// Negation of operator== for ReverseKey.
+inline bool operator!=(const ReverseKey &lhs, const ReverseKey &rhs) {
+  if (lhs == rhs) {
+    return false;
+  }
+  return true;
+}
+
+// Lexicographic "less than" over ReverseKey indices (index 0 first), comparing
+// bytes as unsigned values; when one key runs out first, the shorter is less.
+inline bool operator<(const ReverseKey &lhs, const ReverseKey &rhs) {
+  const std::size_t lhs_length = lhs.length();
+  const std::size_t rhs_length = rhs.length();
+  const std::size_t common =
+      (lhs_length < rhs_length) ? lhs_length : rhs_length;
+  for (std::size_t pos = 0; pos < common; ++pos) {
+    const UInt8 lhs_byte = (UInt8)lhs[pos];
+    const UInt8 rhs_byte = (UInt8)rhs[pos];
+    if (lhs_byte != rhs_byte) {
+      return lhs_byte < rhs_byte;
+    }
+  }
+  // All shared positions match: the shorter key is the smaller one.
+  return lhs_length < rhs_length;
+}
+
+// "Greater than" for ReverseKey, expressed through operator< with swapped
+// operands.
+inline bool operator>(const ReverseKey &lhs, const ReverseKey &rhs) {
+  const bool rhs_is_less = (rhs < lhs);
+  return rhs_is_less;
+}
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_KEY_H_
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc
new file mode 100644
index 0000000000..ed168539bc
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc
@@ -0,0 +1,877 @@
+#include <algorithm>
+#include <functional>
+#include <queue>
+
+#include "../algorithm.h"
+#include "header.h"
+#include "range.h"
+#include "state.h"
+#include "louds-trie.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+// Default-constructs every member; the cache mask and the level-1 node count
+// start at zero.
+LoudsTrie::LoudsTrie()
+    : louds_(),
+      terminal_flags_(),
+      link_flags_(),
+      bases_(),
+      extras_(),
+      tail_(),
+      next_trie_(),
+      cache_(),
+      cache_mask_(0),
+      num_l1_nodes_(0),
+      config_(),
+      mapper_() {}
+
+LoudsTrie::~LoudsTrie() {}  // Empty body: cleanup is left to the members' own destructors.
+
+// Builds a trie from `keyset` using the settings encoded in `flags`.
+// Construction happens in a scratch object that is swapped into *this only
+// after build_() has returned.
+void LoudsTrie::build(Keyset &keyset, int flags) {
+  Config parsed_config;
+  parsed_config.parse(flags);
+
+  LoudsTrie built;
+  built.build_(keyset, parsed_config);
+  this->swap(built);
+}
+
+// Maps a serialized trie through `mapper`. The header is checked first; the
+// trie is then mapped into a scratch object, which takes over the mapper's
+// state before being swapped into *this.
+void LoudsTrie::map(Mapper &mapper) {
+  Header header;
+  header.map(mapper);
+
+  LoudsTrie mapped;
+  mapped.map_(mapper);
+  mapped.mapper_.swap(mapper);
+  this->swap(mapped);
+}
+
+// Reads a serialized trie from `reader`: header first, then the body into a
+// scratch object that is swapped into *this once read_() has returned.
+void LoudsTrie::read(Reader &reader) {
+  Header header;
+  header.read(reader);
+
+  LoudsTrie loaded;
+  loaded.read_(reader);
+  this->swap(loaded);
+}
+
+// Serializes the trie to `writer`: the header comes first, followed by the
+// body emitted by write_().
+void LoudsTrie::write(Writer &writer) const {
+  const Header header;
+  header.write(writer);
+
+  this->write_(writer);
+}
+
+bool LoudsTrie::lookup(Agent &agent) const {  // Exact-match search for agent.query(); on success stores the key string and key id in the agent and returns true.
+ MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
+
+ State &state = agent.state();
+ state.lookup_init();
+ while (state.query_pos() < agent.query().length()) {  // Descend until the whole query has been consumed.
+ if (!find_child(agent)) {
+ return false;  // No child continues the query from the current node.
+ }
+ }
+ if (!terminal_flags_[state.node_id()]) {
+ return false;  // The query ends at a node that is not flagged as terminal.
+ }
+ agent.set_key(agent.query().ptr(), agent.query().length());  // The matched key is the query itself.
+ agent.set_key(terminal_flags_.rank1(state.node_id()));  // Key id = rank of this node among the terminal flags.
+ return true;
+}
+
+void LoudsTrie::reverse_lookup(Agent &agent) const {  // Restores the key string for the id in agent.query().id() and stores string and id in the agent.
+ MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
+ MARISA_THROW_IF(agent.query().id() >= size(), MARISA_BOUND_ERROR);  // Ids at or beyond size() are rejected.
+
+ State &state = agent.state();
+ state.reverse_lookup_init();
+
+ state.set_node_id(terminal_flags_.select1(agent.query().id()));  // Locate the node via select1 on the terminal flags.
+ if (state.node_id() == 0) {  // Node 0: the key is whatever key_buf holds right after reverse_lookup_init().
+ agent.set_key(state.key_buf().begin(), state.key_buf().size());
+ agent.set_key(agent.query().id());
+ return;
+ }
+ for ( ; ; ) {  // Walk upward, appending each node's label; the buffer is reversed once at the end.
+ if (link_flags_[state.node_id()]) {
+ const std::size_t prev_key_pos = state.key_buf().size();
+ restore(agent, get_link(state.node_id()));
+ std::reverse(state.key_buf().begin() + prev_key_pos,  // Pre-reverse the restored segment so the final whole-buffer reverse leaves it in its restored order.
+ state.key_buf().end());
+ } else {
+ state.key_buf().push_back((char)bases_[state.node_id()]);  // Unlinked node: its label is the single byte in bases_.
+ }
+
+ if (state.node_id() <= num_l1_nodes_) {  // Stop once the node id falls within the first num_l1_nodes_ ids.
+ std::reverse(state.key_buf().begin(), state.key_buf().end());
+ agent.set_key(state.key_buf().begin(), state.key_buf().size());
+ agent.set_key(agent.query().id());
+ return;
+ }
+ state.set_node_id(louds_.select1(state.node_id()) - state.node_id() - 1);  // Step to the next node up, computed from the LOUDS bit vector.
+ }
+}
+
+bool LoudsTrie::common_prefix_search(Agent &agent) const {  // Returns true for each terminal node reached along the query path, one per call; the status code in State carries progress between calls.
+ MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
+
+ State &state = agent.state();
+ if (state.status_code() == MARISA_END_OF_COMMON_PREFIX_SEARCH) {
+ return false;  // A previous call already exhausted the search.
+ }
+
+ if (state.status_code() != MARISA_READY_TO_COMMON_PREFIX_SEARCH) {  // First call: initialise, then test the starting node itself.
+ state.common_prefix_search_init();
+ if (terminal_flags_[state.node_id()]) {
+ agent.set_key(agent.query().ptr(), state.query_pos());
+ agent.set_key(terminal_flags_.rank1(state.node_id()));
+ return true;
+ }
+ }
+
+ while (state.query_pos() < agent.query().length()) {  // Continue the descent from where the previous call stopped.
+ if (!find_child(agent)) {
+ state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH);
+ return false;
+ } else if (terminal_flags_[state.node_id()]) {
+ agent.set_key(agent.query().ptr(), state.query_pos());  // The reported key is the query prefix consumed so far.
+ agent.set_key(terminal_flags_.rank1(state.node_id()));
+ return true;
+ }
+ }
+ state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH);  // Whole query consumed with no further terminal.
+ return false;
+}
+
+bool LoudsTrie::predictive_search(Agent &agent) const {  // Returns true once per terminal node found in the subtree below the query; traversal progress is kept in state.history().
+ MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
+
+ State &state = agent.state();
+ if (state.status_code() == MARISA_END_OF_PREDICTIVE_SEARCH) {
+ return false;  // A previous call already exhausted the search.
+ }
+
+ if (state.status_code() != MARISA_READY_TO_PREDICTIVE_SEARCH) {  // First call: consume the whole query before enumerating.
+ state.predictive_search_init();
+ while (state.query_pos() < agent.query().length()) {
+ if (!predictive_find_child(agent)) {
+ state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH);
+ return false;
+ }
+ }
+
+ History history;  // First history entry: the node reached by the query and the key length at that point.
+ history.set_node_id(state.node_id());
+ history.set_key_pos(state.key_buf().size());
+ state.history().push_back(history);
+ state.set_history_pos(1);
+
+ if (terminal_flags_[state.node_id()]) {  // The node reached by the query is itself a key.
+ agent.set_key(state.key_buf().begin(), state.key_buf().size());
+ agent.set_key(terminal_flags_.rank1(state.node_id()));
+ return true;
+ }
+ }
+
+ for ( ; ; ) {
+ if (state.history_pos() == state.history().size()) {  // No entry exists for this depth yet: create one from the last entry.
+ const History &current = state.history().back();
+ History next;
+ next.set_louds_pos(louds_.select0(current.node_id()) + 1);
+ next.set_node_id(next.louds_pos() - current.node_id() - 1);
+ state.history().push_back(next);
+ }
+
+ History &next = state.history()[state.history_pos()];
+ const bool link_flag = louds_[next.louds_pos()];  // Despite its name this is the LOUDS bit at the cursor, not a link_flags_ entry.
+ next.set_louds_pos(next.louds_pos() + 1);
+ if (link_flag) {  // Bit set: visit next.node_id() and go one level deeper.
+ state.set_history_pos(state.history_pos() + 1);
+ if (link_flags_[next.node_id()]) {
+ next.set_link_id(update_link_id(next.link_id(), next.node_id()));
+ restore(agent, get_link(next.node_id(), next.link_id()));
+ } else {
+ state.key_buf().push_back((char)bases_[next.node_id()]);
+ }
+ next.set_key_pos(state.key_buf().size());
+
+ if (terminal_flags_[next.node_id()]) {
+ if (next.key_id() == MARISA_INVALID_KEY_ID) {  // First terminal at this depth: compute the id with rank1; later ones just increment it.
+ next.set_key_id(terminal_flags_.rank1(next.node_id()));
+ } else {
+ next.set_key_id(next.key_id() + 1);
+ }
+ agent.set_key(state.key_buf().begin(), state.key_buf().size());
+ agent.set_key(next.key_id());
+ return true;
+ }
+ } else if (state.history_pos() != 1) {  // Bit clear above the starting depth: back up one level.
+ History &current = state.history()[state.history_pos() - 1];
+ current.set_node_id(current.node_id() + 1);  // Advance the entry one level up to the next node id.
+ const History &prev =
+ state.history()[state.history_pos() - 2];
+ state.key_buf().resize(prev.key_pos());  // Truncate the key to the length recorded two levels up.
+ state.set_history_pos(state.history_pos() - 1);
+ } else {  // Bit clear at the starting depth: nothing left to enumerate.
+ state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH);
+ return false;
+ }
+ }
+}
+
+// Sum of total_size() over every component of this trie, plus the next trie
+// when one exists.
+std::size_t LoudsTrie::total_size() const {
+  std::size_t total = louds_.total_size();
+  total += terminal_flags_.total_size();
+  total += link_flags_.total_size();
+  total += bases_.total_size();
+  total += extras_.total_size();
+  total += tail_.total_size();
+  if (next_trie_.get() != NULL) {
+    total += next_trie_->total_size();
+  }
+  total += cache_.total_size();
+  return total;
+}
+
+// Serialized size: one header, every component's io_size(), the next trie's
+// io_size() without its own header, and two UInt32 values.
+std::size_t LoudsTrie::io_size() const {
+  const std::size_t header_size = Header().io_size();
+
+  std::size_t total = header_size;
+  total += louds_.io_size();
+  total += terminal_flags_.io_size();
+  total += link_flags_.io_size();
+  total += bases_.io_size();
+  total += extras_.io_size();
+  total += tail_.io_size();
+  if (next_trie_.get() != NULL) {
+    // The nested trie's io_size() includes a header, so it is subtracted.
+    total += next_trie_->io_size() - header_size;
+  }
+  total += cache_.io_size();
+  total += sizeof(UInt32) * 2;
+  return total;
+}
+
+// Resets *this by exchanging its contents with a default-constructed trie;
+// the previous contents are destroyed together with that temporary object.
+void LoudsTrie::clear() {
+  LoudsTrie empty;
+  empty.swap(*this);
+}
+
+void LoudsTrie::swap(LoudsTrie &rhs) {  // Exchanges every member of *this with the corresponding member of rhs.
+ louds_.swap(rhs.louds_);
+ terminal_flags_.swap(rhs.terminal_flags_);
+ link_flags_.swap(rhs.link_flags_);
+ bases_.swap(rhs.bases_);
+ extras_.swap(rhs.extras_);
+ tail_.swap(rhs.tail_);
+ next_trie_.swap(rhs.next_trie_);
+ cache_.swap(rhs.cache_);
+ marisa::swap(cache_mask_, rhs.cache_mask_);  // These two members use the free marisa::swap rather than a member swap().
+ marisa::swap(num_l1_nodes_, rhs.num_l1_nodes_);
+ config_.swap(rhs.config_);
+ mapper_.swap(rhs.mapper_);
+}
+
+void LoudsTrie::build_(Keyset &keyset, const Config &config) {  // Builds the trie from keyset, constructs terminal_flags_, and writes the resulting key id back into each keyset entry.
+ Vector<Key> keys;
+ keys.resize(keyset.size());
+ for (std::size_t i = 0; i < keyset.size(); ++i) {  // Wrap each input key as a Key view carrying its weight.
+ keys[i].set_str(keyset[i].ptr(), keyset[i].length());
+ keys[i].set_weight(keyset[i].weight());
+ }
+
+ Vector<UInt32> terminals;
+ build_trie(keys, &terminals, config, 1);  // terminals[i] receives the terminal node id for input key i.
+
+ typedef std::pair<UInt32, UInt32> TerminalIdPair;  // (terminal node id, original key index)
+
+ Vector<TerminalIdPair> pairs;
+ pairs.resize(terminals.size());
+ for (std::size_t i = 0; i < pairs.size(); ++i) {
+ pairs[i].first = terminals[i];
+ pairs[i].second = (UInt32)i;
+ }
+ terminals.clear();
+ std::sort(pairs.begin(), pairs.end());  // Ascending by terminal node id, then by key index.
+
+ std::size_t node_id = 0;
+ for (std::size_t i = 0; i < pairs.size(); ++i) {  // Emit one flag per node id: true where a key terminates, false in the gaps.
+ while (node_id < pairs[i].first) {
+ terminal_flags_.push_back(false);
+ ++node_id;
+ }
+ if (node_id == pairs[i].first) {  // False when this terminal was already flagged by an earlier pair with the same node id.
+ terminal_flags_.push_back(true);
+ ++node_id;
+ }
+ }
+ while (node_id < bases_.size()) {  // Pad with false up to one flag per bases_ entry.
+ terminal_flags_.push_back(false);
+ ++node_id;
+ }
+ terminal_flags_.push_back(false);  // One extra trailing false flag.
+ terminal_flags_.build(false, true);
+
+ for (std::size_t i = 0; i < keyset.size(); ++i) {
+ keyset[pairs[i].second].set_id(terminal_flags_.rank1(pairs[i].first));  // Key id = rank of the key's terminal node among the set flags.
+ }
+}
+
+template <typename T>
+void LoudsTrie::build_trie(Vector<T> &keys,  // Builds this trie level from keys, delegates the labels left in keys to the next trie or the tail, then stores their links.
+ Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
+ build_current_trie(keys, terminals, config, trie_id);  // On return, keys holds the labels queued for the next level.
+
+ Vector<UInt32> next_terminals;
+ if (!keys.empty()) {
+ build_next_trie(keys, &next_terminals, config, trie_id);
+ }
+
+ if (next_trie_.get() != NULL) {  // Record the configuration: one more trie than the nested trie reports, with its tail mode and node order.
+ config_.parse(static_cast<int>((next_trie_->num_tries() + 1)) |
+ next_trie_->tail_mode() | next_trie_->node_order());
+ } else {  // No nested trie: a single trie with this level's tail mode, node order and cache level.
+ config_.parse(1 | tail_.mode() | config.node_order() |
+ config.cache_level());
+ }
+
+ link_flags_.build(false, false);
+ std::size_t node_id = 0;
+ for (std::size_t i = 0; i < next_terminals.size(); ++i) {  // Assign the i-th link value to the i-th node whose link flag is set.
+ while (!link_flags_[node_id]) {
+ ++node_id;
+ }
+ bases_[node_id] = (UInt8)(next_terminals[i] % 256);  // Low 8 bits of the link value go into bases_.
+ next_terminals[i] /= 256;  // The remaining high bits are what extras_ is built from below.
+ ++node_id;
+ }
+ extras_.build(next_terminals);
+ fill_cache();
+}
+
+template <typename T>
+void LoudsTrie::build_current_trie(Vector<T> &keys,  // Builds louds_, bases_ and link_flags_ for one trie level from sorted keys, processing nodes in queue (breadth-first) order.
+ Vector<UInt32> *terminals, const Config &config,
+ std::size_t trie_id) try {  // Function-try-block: the handler at the end reports std::bad_alloc as MARISA_MEMORY_ERROR.
+ for (std::size_t i = 0; i < keys.size(); ++i) {
+ keys[i].set_id(i);  // Remember each key's original index before sorting.
+ }
+ const std::size_t num_keys = Algorithm().sort(keys.begin(), keys.end());  // The value returned by sort() is used only to size the cache.
+ reserve_cache(config, trie_id, num_keys);
+
+ louds_.push_back(true);  // Initial LOUDS bits and the entry for node 0.
+ louds_.push_back(false);
+ bases_.push_back('\0');
+ link_flags_.push_back(false);
+
+ Vector<T> next_keys;
+ std::queue<Range> queue;
+ Vector<WeightedRange> w_ranges;
+
+ queue.push(make_range(0, keys.size(), 0));  // Start with all keys at key position 0.
+ while (!queue.empty()) {
+ const std::size_t node_id = link_flags_.size() - queue.size();  // Nodes created so far minus ranges still queued = id of the node being expanded.
+
+ Range range = queue.front();
+ queue.pop();
+
+ while ((range.begin() < range.end()) &&  // Keys that end exactly at this node get it as their terminal and leave the range.
+ (keys[range.begin()].length() == range.key_pos())) {
+ keys[range.begin()].set_terminal(node_id);
+ range.set_begin(range.begin() + 1);
+ }
+
+ if (range.begin() == range.end()) {  // No key continues past this node: emit a lone 0 bit.
+ louds_.push_back(false);
+ continue;
+ }
+
+ w_ranges.clear();
+ double weight = keys[range.begin()].weight();
+ for (std::size_t i = range.begin() + 1; i < range.end(); ++i) {  // Split the range into runs sharing the character at key_pos, summing weights per run.
+ if (keys[i - 1][range.key_pos()] != keys[i][range.key_pos()]) {
+ w_ranges.push_back(make_weighted_range(
+ range.begin(), i, range.key_pos(), (float)weight));
+ range.set_begin(i);
+ weight = 0.0;
+ }
+ weight += keys[i].weight();
+ }
+ w_ranges.push_back(make_weighted_range(  // Final run.
+ range.begin(), range.end(), range.key_pos(), (float)weight));
+ if (config.node_order() == MARISA_WEIGHT_ORDER) {  // Weight order: sort runs with std::greater; stable_sort keeps the existing order among equal runs.
+ std::stable_sort(w_ranges.begin(), w_ranges.end(),
+ std::greater<WeightedRange>());
+ }
+
+ if (node_id == 0) {
+ num_l1_nodes_ = w_ranges.size();  // Number of children of node 0.
+ }
+
+ for (std::size_t i = 0; i < w_ranges.size(); ++i) {  // One child node per run.
+ WeightedRange &w_range = w_ranges[i];
+ std::size_t key_pos = w_range.key_pos() + 1;
+ while (key_pos < keys[w_range.begin()].length()) {  // Extend the label while every key in the run agrees on the next character.
+ std::size_t j;
+ for (j = w_range.begin() + 1; j < w_range.end(); ++j) {
+ if (keys[j - 1][key_pos] != keys[j][key_pos]) {
+ break;
+ }
+ }
+ if (j < w_range.end()) {
+ break;
+ }
+ ++key_pos;
+ }
+ cache<T>(node_id, bases_.size(), w_range.weight(),  // bases_.size() is the id the child is about to receive.
+ keys[w_range.begin()][w_range.key_pos()]);
+
+ if (key_pos == w_range.key_pos() + 1) {  // Single-character label: stored directly in bases_.
+ bases_.push_back(keys[w_range.begin()][w_range.key_pos()]);
+ link_flags_.push_back(false);
+ } else {  // Multi-character label: flag the node as linked and queue the substring in next_keys.
+ bases_.push_back('\0');
+ link_flags_.push_back(true);
+ T next_key;
+ next_key.set_str(keys[w_range.begin()].ptr(),
+ keys[w_range.begin()].length());
+ next_key.substr(w_range.key_pos(), key_pos - w_range.key_pos());
+ next_key.set_weight(w_range.weight());
+ next_keys.push_back(next_key);
+ }
+ w_range.set_key_pos(key_pos);
+ queue.push(w_range.range());  // The child is expanded later, after everything already queued.
+ louds_.push_back(true);  // One 1 bit per child.
+ }
+ louds_.push_back(false);  // 0 bit terminates this node's child list.
+ }
+
+ louds_.push_back(false);
+ louds_.build(trie_id == 1, true);
+ bases_.shrink();
+
+ build_terminals(keys, terminals);  // terminals is filled in original key order (indexed by the ids set above).
+ keys.swap(next_keys);  // Hand the queued multi-character labels back to the caller through keys.
+} catch (const std::bad_alloc &) {
+ MARISA_THROW(MARISA_MEMORY_ERROR, "std::bad_alloc");
+}
+
+// Key specialization: stores the labels left over from this level either in
+// a nested trie (built from ReverseKey views of the same bytes) or, when
+// `trie_id` has reached config.num_tries(), in the tail.
+template <>
+void LoudsTrie::build_next_trie(Vector<Key> &keys,
+    Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
+  const std::size_t num_keys = keys.size();
+
+  if (trie_id != config.num_tries()) {
+    // Another trie level is allowed: hand the labels over as ReverseKeys.
+    Vector<ReverseKey> reverse_keys;
+    reverse_keys.resize(num_keys);
+    for (std::size_t i = 0; i < num_keys; ++i) {
+      const Key &key = keys[i];
+      reverse_keys[i].set_str(key.ptr(), key.length());
+      reverse_keys[i].set_weight(key.weight());
+    }
+    // The Key views are no longer needed once the ReverseKeys exist.
+    keys.clear();
+    next_trie_.reset(new (std::nothrow) LoudsTrie);
+    MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
+    next_trie_->build_trie(reverse_keys, terminals, config, trie_id + 1);
+    return;
+  }
+
+  // Last allowed level: the labels go into the tail.
+  Vector<Entry> entries;
+  entries.resize(num_keys);
+  for (std::size_t i = 0; i < num_keys; ++i) {
+    const Key &key = keys[i];
+    entries[i].set_str(key.ptr(), key.length());
+  }
+  tail_.build(entries, terminals, config.tail_mode());
+}
+
+// ReverseKey specialization: stores the labels left over from this level
+// either in a nested trie (reusing the same ReverseKey vector) or, when
+// `trie_id` has reached config.num_tries(), in the tail.
+template <>
+void LoudsTrie::build_next_trie(Vector<ReverseKey> &keys,
+    Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
+  if (trie_id != config.num_tries()) {
+    // Another trie level is allowed: build it directly from `keys`.
+    next_trie_.reset(new (std::nothrow) LoudsTrie);
+    MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
+    next_trie_->build_trie(keys, terminals, config, trie_id + 1);
+    return;
+  }
+
+  // Last allowed level: the labels go into the tail.
+  const std::size_t num_keys = keys.size();
+  Vector<Entry> entries;
+  entries.resize(num_keys);
+  for (std::size_t i = 0; i < num_keys; ++i) {
+    const ReverseKey &key = keys[i];
+    entries[i].set_str(key.ptr(), key.length());
+  }
+  tail_.build(entries, terminals, config.tail_mode());
+}
+
+// Produces a table indexed by key id whose entries are the terminal values
+// of the corresponding keys, and swaps it into *terminals.
+template <typename T>
+void LoudsTrie::build_terminals(const Vector<T> &keys,
+    Vector<UInt32> *terminals) const {
+  const std::size_t num_keys = keys.size();
+
+  Vector<UInt32> terminal_by_id;
+  terminal_by_id.resize(num_keys);
+  for (std::size_t i = 0; i < num_keys; ++i) {
+    const T &key = keys[i];
+    terminal_by_id[key.id()] = (UInt32)key.terminal();
+  }
+  terminals->swap(terminal_by_id);
+}
+
+// Key specialization: offers the edge parent -> child for the cache slot
+// selected by get_cache_id(parent, label). The slot is overwritten only when
+// `weight` is strictly greater than the weight already stored there.
+template <>
+void LoudsTrie::cache<Key>(std::size_t parent, std::size_t child,
+    float weight, char label) {
+  MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR);
+
+  Cache &slot = cache_[get_cache_id(parent, label)];
+  if (!(weight > slot.weight())) {
+    return;
+  }
+  slot.set_parent(parent);
+  slot.set_child(child);
+  slot.set_weight(weight);
+}
+
+// Sizes the cache: starting from 256 slots for the first trie (1 for the
+// others), the size is doubled until it reaches
+// num_keys / config.cache_level(). cache_mask_ is set to size - 1.
+void LoudsTrie::reserve_cache(const Config &config, std::size_t trie_id,
+    std::size_t num_keys) {
+  const std::size_t min_size = num_keys / config.cache_level();
+  std::size_t cache_size = 1;
+  if (trie_id == 1) {
+    cache_size = 256;
+  }
+  while (cache_size < min_size) {
+    cache_size *= 2;
+  }
+  cache_.resize(cache_size);
+  cache_mask_ = cache_size - 1;
+}
+
+// ReverseKey specialization: offers the edge parent -> child for the cache
+// slot selected by get_cache_id(child); the label argument is unused. The
+// slot is overwritten only when `weight` is strictly greater than the weight
+// already stored there.
+template <>
+void LoudsTrie::cache<ReverseKey>(std::size_t parent, std::size_t child,
+    float weight, char) {
+  MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR);
+
+  Cache &slot = cache_[get_cache_id(child)];
+  if (!(weight > slot.weight())) {
+    return;
+  }
+  slot.set_parent(parent);
+  slot.set_child(child);
+  slot.set_weight(weight);
+}
+
+// Finalizes every cache slot. A slot whose child is 0 gets parent and child
+// set to MARISA_UINT32_MAX. Any other slot receives the child's base byte
+// and either its extra value (when the child has a link) or
+// MARISA_INVALID_EXTRA.
+void LoudsTrie::fill_cache() {
+  const std::size_t num_slots = cache_.size();
+  for (std::size_t slot_id = 0; slot_id < num_slots; ++slot_id) {
+    Cache &slot = cache_[slot_id];
+    const std::size_t node_id = slot.child();
+    if (node_id == 0) {
+      slot.set_parent(MARISA_UINT32_MAX);
+      slot.set_child(MARISA_UINT32_MAX);
+      continue;
+    }
+
+    slot.set_base(bases_[node_id]);
+    if (link_flags_[node_id]) {
+      slot.set_extra(extras_[link_flags_.rank1(node_id)]);
+    } else {
+      slot.set_extra(MARISA_INVALID_EXTRA);
+    }
+  }
+}
+
+// Maps all components from `mapper` in the same order in which write_()
+// emits them. A nested trie is mapped only when links exist but the tail is
+// empty.
+void LoudsTrie::map_(Mapper &mapper) {
+  louds_.map(mapper);
+  terminal_flags_.map(mapper);
+  link_flags_.map(mapper);
+  bases_.map(mapper);
+  extras_.map(mapper);
+  tail_.map(mapper);
+
+  const bool has_next_trie = (link_flags_.num_1s() != 0) && tail_.empty();
+  if (has_next_trie) {
+    next_trie_.reset(new (std::nothrow) LoudsTrie);
+    MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
+    next_trie_->map_(mapper);
+  }
+
+  cache_.map(mapper);
+  cache_mask_ = cache_.size() - 1;
+
+  // Both trailing fields are stored as 32-bit values.
+  UInt32 stored_num_l1_nodes;
+  mapper.map(&stored_num_l1_nodes);
+  num_l1_nodes_ = stored_num_l1_nodes;
+
+  UInt32 stored_config_flags;
+  mapper.map(&stored_config_flags);
+  config_.parse((int)stored_config_flags);
+}
+
+// Reads all components from `reader` in the same order in which write_()
+// emits them. A nested trie is read only when links exist but the tail is
+// empty.
+void LoudsTrie::read_(Reader &reader) {
+  louds_.read(reader);
+  terminal_flags_.read(reader);
+  link_flags_.read(reader);
+  bases_.read(reader);
+  extras_.read(reader);
+  tail_.read(reader);
+
+  const bool has_next_trie = (link_flags_.num_1s() != 0) && tail_.empty();
+  if (has_next_trie) {
+    next_trie_.reset(new (std::nothrow) LoudsTrie);
+    MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
+    next_trie_->read_(reader);
+  }
+
+  cache_.read(reader);
+  cache_mask_ = cache_.size() - 1;
+
+  // Both trailing fields are stored as 32-bit values.
+  UInt32 stored_num_l1_nodes;
+  reader.read(&stored_num_l1_nodes);
+  num_l1_nodes_ = stored_num_l1_nodes;
+
+  UInt32 stored_config_flags;
+  reader.read(&stored_config_flags);
+  config_.parse((int)stored_config_flags);
+}
+
+// Writes all components to `writer`: the bit vectors, bases, extras and
+// tail, then the nested trie (if any), then the cache, and finally the
+// number of level-1 nodes and the config flags as 32-bit values.
+void LoudsTrie::write_(Writer &writer) const {
+  louds_.write(writer);
+  terminal_flags_.write(writer);
+  link_flags_.write(writer);
+  bases_.write(writer);
+  extras_.write(writer);
+  tail_.write(writer);
+
+  const LoudsTrie * const nested = next_trie_.get();
+  if (nested != NULL) {
+    nested->write_(writer);
+  }
+
+  cache_.write(writer);
+  writer.write(static_cast<UInt32>(num_l1_nodes_));
+  writer.write(static_cast<UInt32>(config_.flags()));
+}
+
+// Tries to move `agent` from its current node to a child that matches the
+// query at state.query_pos(). On success the state's node_id and query_pos
+// are advanced and true is returned. On failure false is returned; node_id
+// (and, after a partially matched link, query_pos) may already have been
+// modified by then.
+bool LoudsTrie::find_child(Agent &agent) const {
+ MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
+ MARISA_BOUND_ERROR);
+
+ State &state = agent.state();
+ // Fast path: probe the cache slot hashed from (node_id, next query byte).
+ const std::size_t cache_id = get_cache_id(state.node_id(),
+ agent.query()[state.query_pos()]);
+ if (state.node_id() == cache_[cache_id].parent()) {
+ // NOTE(review): the label is not compared on a hit; presumably a slot
+ // whose parent matches can only hold the hashed label -- confirm.
+ if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
+ // The cached child has a link; the linked string must match the query.
+ if (!match(agent, cache_[cache_id].link())) {
+ return false;
+ }
+ } else {
+ // Single-byte label: consume one query byte.
+ state.set_query_pos(state.query_pos() + 1);
+ }
+ state.set_node_id(cache_[cache_id].child());
+ return true;
+ }
+
+ // Slow path: scan the children in the LOUDS bit sequence. A 0 bit right
+ // after the node's select0 position means the node has no children.
+ std::size_t louds_pos = louds_.select0(state.node_id()) + 1;
+ if (!louds_[louds_pos]) {
+ return false;
+ }
+ // Node ID of the first child.
+ state.set_node_id(louds_pos - state.node_id() - 1);
+ std::size_t link_id = MARISA_INVALID_LINK_ID;
+ do {
+ if (link_flags_[state.node_id()]) {
+ // rank1() is computed for the first linked sibling only; later linked
+ // siblings just increment the previous link ID.
+ link_id = update_link_id(link_id, state.node_id());
+ const std::size_t prev_query_pos = state.query_pos();
+ if (match(agent, get_link(state.node_id(), link_id))) {
+ return true;
+ } else if (state.query_pos() != prev_query_pos) {
+ // The link consumed part of the query before failing; give up instead
+ // of trying further siblings.
+ return false;
+ }
+ } else if (bases_[state.node_id()] ==
+ (UInt8)agent.query()[state.query_pos()]) {
+ state.set_query_pos(state.query_pos() + 1);
+ return true;
+ }
+ // Advance to the next sibling.
+ state.set_node_id(state.node_id() + 1);
+ ++louds_pos;
+ } while (louds_[louds_pos]);
+ return false;
+}
+
+// Same traversal as find_child(), but every matched label byte is also
+// appended to state.key_buf() and links are matched with prefix_match().
+// Returns true when a matching child was found; on failure node_id (and,
+// after a partially matched link, query_pos) may already have been modified.
+bool LoudsTrie::predictive_find_child(Agent &agent) const {
+ MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
+ MARISA_BOUND_ERROR);
+
+ State &state = agent.state();
+ // Fast path: probe the cache slot hashed from (node_id, next query byte).
+ const std::size_t cache_id = get_cache_id(state.node_id(),
+ agent.query()[state.query_pos()]);
+ if (state.node_id() == cache_[cache_id].parent()) {
+ // NOTE(review): the label is not compared on a hit; presumably a slot
+ // whose parent matches can only hold the hashed label -- confirm.
+ if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
+ if (!prefix_match(agent, cache_[cache_id].link())) {
+ return false;
+ }
+ } else {
+ // Single-byte label: record it and consume one query byte.
+ state.key_buf().push_back(cache_[cache_id].label());
+ state.set_query_pos(state.query_pos() + 1);
+ }
+ state.set_node_id(cache_[cache_id].child());
+ return true;
+ }
+
+ // Slow path: scan the children in the LOUDS bit sequence. A 0 bit right
+ // after the node's select0 position means the node has no children.
+ std::size_t louds_pos = louds_.select0(state.node_id()) + 1;
+ if (!louds_[louds_pos]) {
+ return false;
+ }
+ // Node ID of the first child.
+ state.set_node_id(louds_pos - state.node_id() - 1);
+ std::size_t link_id = MARISA_INVALID_LINK_ID;
+ do {
+ if (link_flags_[state.node_id()]) {
+ // rank1() is computed for the first linked sibling only; later linked
+ // siblings just increment the previous link ID.
+ link_id = update_link_id(link_id, state.node_id());
+ const std::size_t prev_query_pos = state.query_pos();
+ if (prefix_match(agent, get_link(state.node_id(), link_id))) {
+ return true;
+ } else if (state.query_pos() != prev_query_pos) {
+ // The link consumed part of the query before failing; give up instead
+ // of trying further siblings.
+ return false;
+ }
+ } else if (bases_[state.node_id()] ==
+ (UInt8)agent.query()[state.query_pos()]) {
+ state.key_buf().push_back((char)bases_[state.node_id()]);
+ state.set_query_pos(state.query_pos() + 1);
+ return true;
+ }
+ // Advance to the next sibling.
+ state.set_node_id(state.node_id() + 1);
+ ++louds_pos;
+ } while (louds_[louds_pos]);
+ return false;
+}
+
+// Appends the string referenced by `link` to the agent's key buffer, using
+// the nested trie when one exists and the tail otherwise.
+void LoudsTrie::restore(Agent &agent, std::size_t link) const {
+  if (next_trie_.get() == NULL) {
+    tail_.restore(agent, link);
+    return;
+  }
+  next_trie_->restore_(agent, link);
+}
+
+// Matches the string referenced by `link` against the query, using the
+// nested trie when one exists and the tail otherwise.
+bool LoudsTrie::match(Agent &agent, std::size_t link) const {
+  if (next_trie_.get() == NULL) {
+    return tail_.match(agent, link);
+  }
+  return next_trie_->match_(agent, link);
+}
+
+// Prefix-matches the string referenced by `link` against the query, using
+// the nested trie when one exists and the tail otherwise.
+bool LoudsTrie::prefix_match(Agent &agent, std::size_t link) const {
+  if (next_trie_.get() == NULL) {
+    return tail_.prefix_match(agent, link);
+  }
+  return next_trie_->prefix_match_(agent, link);
+}
+
+// Walks from `node_id` towards the root and appends each visited node's
+// label (a single byte or a linked string) to the agent's key buffer. Cached
+// nodes take their parent from the cache; other nodes compute it with
+// select1(). The walk ends at a cached parent of 0 or at a node whose ID is
+// not greater than num_l1_nodes_.
+void LoudsTrie::restore_(Agent &agent, std::size_t node_id) const {
+  MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);
+
+  State &state = agent.state();
+  for ( ; ; ) {
+    const Cache &entry = cache_[get_cache_id(node_id)];
+    if (node_id == entry.child()) {
+      if (entry.extra() == MARISA_INVALID_EXTRA) {
+        state.key_buf().push_back(entry.label());
+      } else {
+        restore(agent, entry.link());
+      }
+
+      node_id = entry.parent();
+      if (node_id == 0) {
+        return;
+      }
+    } else {
+      if (link_flags_[node_id]) {
+        restore(agent, get_link(node_id));
+      } else {
+        state.key_buf().push_back((char)bases_[node_id]);
+      }
+
+      if (node_id <= num_l1_nodes_) {
+        return;
+      }
+      node_id = louds_.select1(node_id) - node_id - 1;
+    }
+  }
+}
+
+// Walks from `node_id` towards the root and compares each visited node's
+// label (a single byte or a linked string) with the query, advancing
+// state.query_pos(). Returns true when the walk reaches its end (cached
+// parent 0, or a node ID not greater than num_l1_nodes_) with every label
+// matched, and false on a mismatch or when the query is exhausted first.
+bool LoudsTrie::match_(Agent &agent, std::size_t node_id) const {
+  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
+      MARISA_BOUND_ERROR);
+  MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);
+
+  State &state = agent.state();
+  for ( ; ; ) {
+    const Cache &entry = cache_[get_cache_id(node_id)];
+    if (node_id == entry.child()) {
+      if (entry.extra() == MARISA_INVALID_EXTRA) {
+        if (!(entry.label() == agent.query()[state.query_pos()])) {
+          return false;
+        }
+        state.set_query_pos(state.query_pos() + 1);
+      } else if (!match(agent, entry.link())) {
+        return false;
+      }
+
+      node_id = entry.parent();
+      if (node_id == 0) {
+        return true;
+      }
+      if (state.query_pos() >= agent.query().length()) {
+        return false;
+      }
+    } else {
+      if (!link_flags_[node_id]) {
+        if (!(bases_[node_id] == (UInt8)agent.query()[state.query_pos()])) {
+          return false;
+        }
+        state.set_query_pos(state.query_pos() + 1);
+      } else if (!match(agent, get_link(node_id))) {
+        // match() picks the nested trie or the tail itself.
+        return false;
+      }
+
+      if (node_id <= num_l1_nodes_) {
+        return true;
+      }
+      if (state.query_pos() >= agent.query().length()) {
+        return false;
+      }
+      node_id = louds_.select1(node_id) - node_id - 1;
+    }
+  }
+}
+
+// Like match_(), but every matched byte is also appended to the agent's key
+// buffer, and running out of query is not a failure: the labels still to be
+// visited are appended with restore_() and true is returned.
+bool LoudsTrie::prefix_match_(Agent &agent, std::size_t node_id) const {
+  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
+      MARISA_BOUND_ERROR);
+  MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);
+
+  State &state = agent.state();
+  for ( ; ; ) {
+    const Cache &entry = cache_[get_cache_id(node_id)];
+    if (node_id != entry.child()) {
+      if (!link_flags_[node_id]) {
+        if (!(bases_[node_id] == (UInt8)agent.query()[state.query_pos()])) {
+          return false;
+        }
+        state.key_buf().push_back((char)bases_[node_id]);
+        state.set_query_pos(state.query_pos() + 1);
+      } else if (!prefix_match(agent, get_link(node_id))) {
+        return false;
+      }
+
+      if (node_id <= num_l1_nodes_) {
+        return true;
+      }
+      node_id = louds_.select1(node_id) - node_id - 1;
+    } else {
+      if (entry.extra() == MARISA_INVALID_EXTRA) {
+        if (!(entry.label() == agent.query()[state.query_pos()])) {
+          return false;
+        }
+        state.key_buf().push_back(entry.label());
+        state.set_query_pos(state.query_pos() + 1);
+      } else if (!prefix_match(agent, entry.link())) {
+        return false;
+      }
+
+      node_id = entry.parent();
+      if (node_id == 0) {
+        return true;
+      }
+    }
+
+    // The query is used up but the walk is not finished.
+    if (state.query_pos() >= agent.query().length()) {
+      restore_(agent, node_id);
+      return true;
+    }
+  }
+}
+
+// Cache slot for the edge leaving `node_id` with `label`: the node ID is
+// mixed with itself shifted left by 5 and with the label byte, then masked
+// with cache_mask_.
+std::size_t LoudsTrie::get_cache_id(std::size_t node_id, char label) const {
+  const std::size_t mixed = node_id ^ (node_id << 5) ^ (UInt8)label;
+  return mixed & cache_mask_;
+}
+
+// Cache slot for `node_id` itself: the node ID masked with cache_mask_.
+std::size_t LoudsTrie::get_cache_id(std::size_t node_id) const {
+  return cache_mask_ & node_id;
+}
+
+// Link value of `node_id`: the base byte supplies the low 8 bits and the
+// extra value, looked up through rank1(node_id), supplies the bits above.
+std::size_t LoudsTrie::get_link(std::size_t node_id) const {
+  return (extras_[link_flags_.rank1(node_id)] * 256) | bases_[node_id];
+}
+
+// Link value of `node_id` when the index into extras_ is already known: the
+// base byte supplies the low 8 bits and extras_[link_id] the bits above.
+std::size_t LoudsTrie::get_link(std::size_t node_id,
+    std::size_t link_id) const {
+  return (extras_[link_id] * 256) | bases_[node_id];
+}
+
+// Returns the link ID to use for `node_id`: rank1(node_id) when no link ID
+// is known yet (MARISA_INVALID_LINK_ID), otherwise the previous ID plus one.
+std::size_t LoudsTrie::update_link_id(std::size_t link_id,
+    std::size_t node_id) const {
+  if (link_id == MARISA_INVALID_LINK_ID) {
+    return link_flags_.rank1(node_id);
+  }
+  return link_id + 1;
+}
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h
new file mode 100644
index 0000000000..5a757ac8fc
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h
@@ -0,0 +1,135 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_LOUDS_TRIE_H_
+#define MARISA_GRIMOIRE_TRIE_LOUDS_TRIE_H_
+
+#include "../../keyset.h"
+#include "../../agent.h"
+#include "../vector.h"
+#include "config.h"
+#include "key.h"
+#include "tail.h"
+#include "cache.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+// A trie stored as a LOUDS bit sequence plus per-node label data. Labels that
+// are not a single byte are kept as links into either a nested LoudsTrie
+// (next_trie_) or a Tail. A small cache of edges speeds up traversal.
+class LoudsTrie {
+ public:
+ LoudsTrie();
+ ~LoudsTrie();
+
+ void build(Keyset &keyset, int flags);
+
+ // Serialization entry points.
+ void map(Mapper &mapper);
+ void read(Reader &reader);
+ void write(Writer &writer) const;
+
+ // Search operations; inputs and results travel through `agent`.
+ bool lookup(Agent &agent) const;
+ void reverse_lookup(Agent &agent) const;
+ bool common_prefix_search(Agent &agent) const;
+ bool predictive_search(Agent &agent) const;
+
+ std::size_t num_tries() const {
+ return config_.num_tries();
+ }
+ std::size_t num_keys() const {
+ return size();
+ }
+ // Derived from the LOUDS sequence length: (size / 2) - 1.
+ std::size_t num_nodes() const {
+ return (louds_.size() / 2) - 1;
+ }
+
+ CacheLevel cache_level() const {
+ return config_.cache_level();
+ }
+ TailMode tail_mode() const {
+ return config_.tail_mode();
+ }
+ NodeOrder node_order() const {
+ return config_.node_order();
+ }
+
+ bool empty() const {
+ return size() == 0;
+ }
+ // Number of set bits in terminal_flags_.
+ std::size_t size() const {
+ return terminal_flags_.num_1s();
+ }
+ std::size_t total_size() const;
+ std::size_t io_size() const;
+
+ void clear();
+ void swap(LoudsTrie &rhs);
+
+ private:
+ BitVector louds_; // LOUDS bit sequence (queried with select0/select1).
+ BitVector terminal_flags_; // Its number of 1s is reported as size().
+ BitVector link_flags_; // Set for nodes whose label is a link.
+ Vector<UInt8> bases_; // Label byte, or low 8 bits of a link.
+ FlatVector extras_; // Bits of a link above the low 8.
+ Tail tail_;
+ scoped_ptr<LoudsTrie> next_trie_; // Nested trie; NULL when absent.
+ Vector<Cache> cache_;
+ std::size_t cache_mask_; // cache_.size() - 1.
+ std::size_t num_l1_nodes_; // Upward walks stop at node IDs <= this.
+ Config config_;
+ Mapper mapper_;
+
+ void build_(Keyset &keyset, const Config &config);
+
+ // Build helpers; T is Key or ReverseKey.
+ template <typename T>
+ void build_trie(Vector<T> &keys,
+ Vector<UInt32> *terminals, const Config &config, std::size_t trie_id);
+ template <typename T>
+ void build_current_trie(Vector<T> &keys,
+ Vector<UInt32> *terminals, const Config &config, std::size_t trie_id);
+ template <typename T>
+ void build_next_trie(Vector<T> &keys,
+ Vector<UInt32> *terminals, const Config &config, std::size_t trie_id);
+ template <typename T>
+ void build_terminals(const Vector<T> &keys,
+ Vector<UInt32> *terminals) const;
+
+ // Cache construction.
+ void reserve_cache(const Config &config, std::size_t trie_id,
+ std::size_t num_keys);
+ template <typename T>
+ void cache(std::size_t parent, std::size_t child,
+ float weight, char label);
+ void fill_cache();
+
+ // Serialization of this trie and, recursively, of next_trie_.
+ void map_(Mapper &mapper);
+ void read_(Reader &reader);
+ void write_(Writer &writer) const;
+
+ // Downward traversal: move the agent from its node to a matching child.
+ inline bool find_child(Agent &agent) const;
+ inline bool predictive_find_child(Agent &agent) const;
+
+ // Resolve a link through next_trie_ if present, else through tail_.
+ inline void restore(Agent &agent, std::size_t node_id) const;
+ inline bool match(Agent &agent, std::size_t node_id) const;
+ inline bool prefix_match(Agent &agent, std::size_t node_id) const;
+
+ // Upward walks starting at node_id within this trie.
+ void restore_(Agent &agent, std::size_t node_id) const;
+ bool match_(Agent &agent, std::size_t node_id) const;
+ bool prefix_match_(Agent &agent, std::size_t node_id) const;
+
+ inline std::size_t get_cache_id(std::size_t node_id, char label) const;
+ inline std::size_t get_cache_id(std::size_t node_id) const;
+
+ inline std::size_t get_link(std::size_t node_id) const;
+ inline std::size_t get_link(std::size_t node_id,
+ std::size_t link_id) const;
+
+ inline std::size_t update_link_id(std::size_t link_id,
+ std::size_t node_id) const;
+
+ // Disallows copy and assignment.
+ LoudsTrie(const LoudsTrie &);
+ LoudsTrie &operator=(const LoudsTrie &);
+};
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_LOUDS_TRIE_H_
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/range.h b/contrib/python/marisa-trie/marisa/grimoire/trie/range.h
new file mode 100644
index 0000000000..4ca39a9c37
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/range.h
@@ -0,0 +1,116 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_RANGE_H_
+#define MARISA_GRIMOIRE_TRIE_RANGE_H_
+
+#include "../../base.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+// Holds three positions (begin, end and key_pos) as 32-bit values. The
+// setters take std::size_t and, in debug builds, reject values above
+// MARISA_UINT32_MAX.
+class Range {
+ public:
+  Range() : begin_(0), end_(0), key_pos_(0) {}
+
+  std::size_t begin() const {
+    return begin_;
+  }
+  void set_begin(std::size_t begin) {
+    MARISA_DEBUG_IF(begin > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    begin_ = (UInt32)begin;
+  }
+
+  std::size_t end() const {
+    return end_;
+  }
+  void set_end(std::size_t end) {
+    MARISA_DEBUG_IF(end > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    end_ = (UInt32)end;
+  }
+
+  std::size_t key_pos() const {
+    return key_pos_;
+  }
+  void set_key_pos(std::size_t key_pos) {
+    MARISA_DEBUG_IF(key_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    key_pos_ = (UInt32)key_pos;
+  }
+
+ private:
+  UInt32 begin_;
+  UInt32 end_;
+  UInt32 key_pos_;
+};
+
+// Convenience factory: returns a Range with begin, end and key_pos set.
+inline Range make_range(std::size_t begin, std::size_t end,
+    std::size_t key_pos) {
+  Range result;
+  result.set_begin(begin);
+  result.set_end(end);
+  result.set_key_pos(key_pos);
+  return result;
+}
+
+// A Range paired with a float weight. Position accessors forward to the
+// embedded Range.
+class WeightedRange {
+ public:
+  WeightedRange() : range_(), weight_(0.0F) {}
+
+  const Range &range() const {
+    return range_;
+  }
+  void set_range(const Range &range) {
+    range_ = range;
+  }
+
+  std::size_t begin() const {
+    return range_.begin();
+  }
+  void set_begin(std::size_t begin) {
+    range_.set_begin(begin);
+  }
+
+  std::size_t end() const {
+    return range_.end();
+  }
+  void set_end(std::size_t end) {
+    range_.set_end(end);
+  }
+
+  std::size_t key_pos() const {
+    return range_.key_pos();
+  }
+  void set_key_pos(std::size_t key_pos) {
+    range_.set_key_pos(key_pos);
+  }
+
+  float weight() const {
+    return weight_;
+  }
+  void set_weight(float weight) {
+    weight_ = weight;
+  }
+
+ private:
+  Range range_;
+  float weight_;
+};
+
+// Orders weighted ranges by weight only.
+inline bool operator<(const WeightedRange &lhs, const WeightedRange &rhs) {
+  return rhs.weight() > lhs.weight();
+}
+
+// Orders weighted ranges by weight only.
+inline bool operator>(const WeightedRange &lhs, const WeightedRange &rhs) {
+  return rhs.weight() < lhs.weight();
+}
+
+// Convenience factory: returns a WeightedRange with begin, end, key_pos and
+// weight set.
+inline WeightedRange make_weighted_range(std::size_t begin, std::size_t end,
+    std::size_t key_pos, float weight) {
+  WeightedRange result;
+  result.set_range(make_range(begin, end, key_pos));
+  result.set_weight(weight);
+  return result;
+}
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_RANGE_H_
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/state.h b/contrib/python/marisa-trie/marisa/grimoire/trie/state.h
new file mode 100644
index 0000000000..219bf9e03a
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/state.h
@@ -0,0 +1,118 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_STATE_H_
+#define MARISA_GRIMOIRE_TRIE_STATE_H_
+
+#include "../vector.h"
+#include "history.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+// Status codes stored in a search agent's State. State's *_init() functions
+// set the READY codes; no code in this header sets the END codes.
+typedef enum StatusCode {
+ MARISA_READY_TO_ALL,
+ MARISA_READY_TO_COMMON_PREFIX_SEARCH,
+ MARISA_READY_TO_PREDICTIVE_SEARCH,
+ MARISA_END_OF_COMMON_PREFIX_SEARCH,
+ MARISA_END_OF_PREDICTIVE_SEARCH,
+} StatusCode;
+
+// Per-agent search state: a key buffer, a history vector, the current node,
+// query and history positions (each stored as 32 bits) and a status code.
+class State {
+ public:
+  State()
+      : key_buf_(), history_(), node_id_(0), query_pos_(0),
+        history_pos_(0), status_code_(MARISA_READY_TO_ALL) {}
+
+  std::size_t node_id() const {
+    return node_id_;
+  }
+  void set_node_id(std::size_t node_id) {
+    MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    node_id_ = static_cast<UInt32>(node_id);
+  }
+
+  std::size_t query_pos() const {
+    return query_pos_;
+  }
+  void set_query_pos(std::size_t query_pos) {
+    MARISA_DEBUG_IF(query_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    query_pos_ = static_cast<UInt32>(query_pos);
+  }
+
+  std::size_t history_pos() const {
+    return history_pos_;
+  }
+  void set_history_pos(std::size_t history_pos) {
+    MARISA_DEBUG_IF(history_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+    history_pos_ = static_cast<UInt32>(history_pos);
+  }
+
+  StatusCode status_code() const {
+    return status_code_;
+  }
+  void set_status_code(StatusCode status_code) {
+    status_code_ = status_code;
+  }
+
+  Vector<char> &key_buf() {
+    return key_buf_;
+  }
+  const Vector<char> &key_buf() const {
+    return key_buf_;
+  }
+
+  Vector<History> &history() {
+    return history_;
+  }
+  const Vector<History> &history() const {
+    return history_;
+  }
+
+  // Only the status code is reset; buffers and positions are left as is.
+  void reset() {
+    status_code_ = MARISA_READY_TO_ALL;
+  }
+
+  // Each *_init() below prepares exactly the fields it assigns.
+  void lookup_init() {
+    node_id_ = 0;
+    query_pos_ = 0;
+    status_code_ = MARISA_READY_TO_ALL;
+  }
+  void reverse_lookup_init() {
+    key_buf_.resize(0);
+    key_buf_.reserve(32);
+    status_code_ = MARISA_READY_TO_ALL;
+  }
+  void common_prefix_search_init() {
+    node_id_ = 0;
+    query_pos_ = 0;
+    status_code_ = MARISA_READY_TO_COMMON_PREFIX_SEARCH;
+  }
+  void predictive_search_init() {
+    key_buf_.resize(0);
+    key_buf_.reserve(64);
+    history_.resize(0);
+    history_.reserve(4);
+    node_id_ = 0;
+    query_pos_ = 0;
+    history_pos_ = 0;
+    status_code_ = MARISA_READY_TO_PREDICTIVE_SEARCH;
+  }
+
+ private:
+  Vector<char> key_buf_;
+  Vector<History> history_;
+  UInt32 node_id_;
+  UInt32 query_pos_;
+  UInt32 history_pos_;
+  StatusCode status_code_;
+
+  // Disallows copy and assignment.
+  State(const State &);
+  State &operator=(const State &);
+};
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_STATE_H_
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc
new file mode 100644
index 0000000000..6ec3652e1c
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc
@@ -0,0 +1,218 @@
+#include "../algorithm.h"
+#include "state.h"
+#include "tail.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+// Default-constructs both members (buf_ and end_flags_).
+Tail::Tail() : buf_(), end_flags_() {}
+
+// Builds the tail from `entries` and stores each entry's offset in
+// `*offsets`. A request for MARISA_TEXT_TAIL is switched to
+// MARISA_BINARY_TAIL when any entry contains a '\0' byte. Throws
+// MARISA_NULL_ERROR for a NULL `offsets` and MARISA_CODE_ERROR for an
+// unknown mode. The result replaces *this only after build_() has returned.
+void Tail::build(Vector<Entry> &entries, Vector<UInt32> *offsets,
+    TailMode mode) {
+  MARISA_THROW_IF(offsets == NULL, MARISA_NULL_ERROR);
+
+  if (mode == MARISA_TEXT_TAIL) {
+    bool has_nul = false;
+    for (std::size_t i = 0; !has_nul && (i < entries.size()); ++i) {
+      const char * const ptr = entries[i].ptr();
+      const std::size_t length = entries[i].length();
+      for (std::size_t j = 0; j < length; ++j) {
+        if (ptr[j] == '\0') {
+          has_nul = true;
+          break;
+        }
+      }
+    }
+    if (has_nul) {
+      mode = MARISA_BINARY_TAIL;
+    }
+  } else if (mode != MARISA_BINARY_TAIL) {
+    MARISA_THROW(MARISA_CODE_ERROR, "undefined tail mode");
+  }
+
+  Tail built;
+  built.build_(entries, offsets, mode);
+  swap(built);
+}
+
+// Maps into a temporary first, so *this changes only if map_() returns.
+void Tail::map(Mapper &mapper) {
+  Tail mapped;
+  mapped.map_(mapper);
+  mapped.swap(*this);
+}
+
+// Reads into a temporary first, so *this changes only if read_() returns.
+void Tail::read(Reader &reader) {
+  Tail loaded;
+  loaded.read_(reader);
+  loaded.swap(*this);
+}
+
+// Public entry point for serialization; forwards to write_().
+void Tail::write(Writer &writer) const {
+  this->write_(writer);
+}
+
+// Appends the string stored at `offset` to the agent's key buffer. When
+// end_flags_ is empty the string ends at the first '\0'; otherwise it ends
+// at the first position whose end flag is set (that byte is included).
+void Tail::restore(Agent &agent, std::size_t offset) const {
+  MARISA_DEBUG_IF(buf_.empty(), MARISA_STATE_ERROR);
+
+  State &state = agent.state();
+  if (!end_flags_.empty()) {
+    for ( ; ; ) {
+      state.key_buf().push_back(buf_[offset]);
+      if (end_flags_[offset++]) {
+        return;
+      }
+    }
+  }
+
+  const char *ptr = &buf_[offset];
+  while (*ptr != '\0') {
+    state.key_buf().push_back(*ptr);
+    ++ptr;
+  }
+}
+
+// Compares the string stored at `offset` with the query, starting at
+// state.query_pos() and advancing it for every matched byte. Returns true
+// when the whole stored string was matched, and false on a mismatch or when
+// the query ends first.
+//
+// When end_flags_ is empty the stored string ends at a '\0'; otherwise it
+// ends at the first position whose end flag is set.
+bool Tail::match(Agent &agent, std::size_t offset) const {
+  MARISA_DEBUG_IF(buf_.empty(), MARISA_STATE_ERROR);
+  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
+      MARISA_BOUND_ERROR);
+
+  State &state = agent.state();
+  if (end_flags_.empty()) {
+    // Walk the buffer with a cursor that advances together with query_pos.
+    // A pointer biased by -query_pos could point before the start of the
+    // buffer, which is undefined behaviour.
+    const char *ptr = &buf_[offset];
+    do {
+      if (*ptr != agent.query()[state.query_pos()]) {
+        return false;
+      }
+      state.set_query_pos(state.query_pos() + 1);
+      ++ptr;
+      if (*ptr == '\0') {
+        return true;
+      }
+    } while (state.query_pos() < agent.query().length());
+    return false;
+  } else {
+    do {
+      if (buf_[offset] != agent.query()[state.query_pos()]) {
+        return false;
+      }
+      state.set_query_pos(state.query_pos() + 1);
+      if (end_flags_[offset++]) {
+        return true;
+      }
+    } while (state.query_pos() < agent.query().length());
+    return false;
+  }
+}
+
+// Checks that the remaining query is a prefix of (or equal to) the string
+// stored at `offset`. Matched bytes are appended to the agent's key buffer
+// and state.query_pos() is advanced. If the query ends before the stored
+// string does, the rest of the stored string is appended as well. Returns
+// false on the first mismatching byte, true otherwise.
+//
+// Precondition (checked in debug builds, as in Tail::match):
+// state.query_pos() < agent.query().length().
+bool Tail::prefix_match(Agent &agent, std::size_t offset) const {
+  MARISA_DEBUG_IF(buf_.empty(), MARISA_STATE_ERROR);
+  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
+      MARISA_BOUND_ERROR);
+
+  State &state = agent.state();
+  if (end_flags_.empty()) {
+    // Walk the buffer with a cursor that advances together with query_pos.
+    // A pointer biased by -query_pos could point before the start of the
+    // buffer, which is undefined behaviour.
+    const char *ptr = &buf_[offset];
+    do {
+      if (*ptr != agent.query()[state.query_pos()]) {
+        return false;
+      }
+      state.key_buf().push_back(*ptr);
+      state.set_query_pos(state.query_pos() + 1);
+      ++ptr;
+      if (*ptr == '\0') {
+        return true;
+      }
+    } while (state.query_pos() < agent.query().length());
+    // Query exhausted: *ptr is not '\0' here, so copy the remainder.
+    do {
+      state.key_buf().push_back(*ptr);
+    } while (*++ptr != '\0');
+    return true;
+  } else {
+    do {
+      if (buf_[offset] != agent.query()[state.query_pos()]) {
+        return false;
+      }
+      state.key_buf().push_back(buf_[offset]);
+      state.set_query_pos(state.query_pos() + 1);
+      if (end_flags_[offset++]) {
+        return true;
+      }
+    } while (state.query_pos() < agent.query().length());
+    // Query exhausted: copy the remainder up to and including the end flag.
+    do {
+      state.key_buf().push_back(buf_[offset]);
+    } while (!end_flags_[offset++]);
+    return true;
+  }
+}
+
+// Replaces the contents with those of a default-constructed Tail.
+void Tail::clear() {
+  Tail empty_tail;
+  empty_tail.swap(*this);
+}
+
+// Exchanges both members with `rhs`.
+void Tail::swap(Tail &rhs) {
+  end_flags_.swap(rhs.end_flags_);
+  buf_.swap(rhs.buf_);
+}
+
+// Fills buf_ (and, in binary mode, end_flags_) from `entries` and writes the
+// buffer offset of every entry, indexed by its original position, into
+// `*offsets`. Entries are sorted first so that an entry which is fully
+// matched by its successor can reuse the successor's bytes instead of being
+// stored again. Throws MARISA_RANGE_ERROR for an empty entry and
+// MARISA_SIZE_ERROR when the buffer outgrows 32-bit offsets.
+void Tail::build_(Vector<Entry> &entries, Vector<UInt32> *offsets,
+ TailMode mode) {
+ // Remember each entry's original position before sorting.
+ for (std::size_t i = 0; i < entries.size(); ++i) {
+ entries[i].set_id(i);
+ }
+ Algorithm().sort(entries.begin(), entries.end());
+
+ Vector<UInt32> temp_offsets;
+ temp_offsets.resize(entries.size(), 0);
+
+ // `last` is the entry handled in the previous iteration; it starts as an
+ // empty dummy so that the first entry is always stored.
+ const Entry dummy;
+ const Entry *last = &dummy;
+ // Process the sorted entries from the back.
+ for (std::size_t i = entries.size(); i > 0; --i) {
+ const Entry &current = entries[i - 1];
+ MARISA_THROW_IF(current.length() == 0, MARISA_RANGE_ERROR);
+ // Count how many leading positions (in Entry indexing) agree with `last`.
+ std::size_t match = 0;
+ while ((match < current.length()) && (match < last->length()) &&
+ ((*last)[match] == current[match])) {
+ ++match;
+ }
+ if ((match == current.length()) && (last->length() != 0)) {
+ // `current` is entirely covered by `last`: point into last's bytes,
+ // skipping the part of `last` that `current` does not share.
+ temp_offsets[current.id()] = (UInt32)(
+ temp_offsets[last->id()] + (last->length() - match));
+ } else {
+ temp_offsets[current.id()] = (UInt32)buf_.size();
+ // Bytes are appended from Entry index length-1 down to index 0.
+ for (std::size_t j = 1; j <= current.length(); ++j) {
+ buf_.push_back(current[current.length() - j]);
+ }
+ if (mode == MARISA_TEXT_TAIL) {
+ // Text mode: a '\0' terminates each stored string.
+ buf_.push_back('\0');
+ } else {
+ // Binary mode: one flag per byte, set only on the last byte.
+ for (std::size_t j = 1; j < current.length(); ++j) {
+ end_flags_.push_back(false);
+ }
+ end_flags_.push_back(true);
+ }
+ MARISA_THROW_IF(buf_.size() > MARISA_UINT32_MAX, MARISA_SIZE_ERROR);
+ }
+ last = &current;
+ }
+ buf_.shrink();
+
+ offsets->swap(temp_offsets);
+}
+
+// Maps buf_ and then end_flags_, the same order in which write_() emits them.
+void Tail::map_(Mapper &mapper) {
+ buf_.map(mapper);
+ end_flags_.map(mapper);
+}
+
+// Reads buf_ and then end_flags_, the same order in which write_() emits them.
+void Tail::read_(Reader &reader) {
+ buf_.read(reader);
+ end_flags_.read(reader);
+}
+
+// Writes buf_ followed by end_flags_; map_() and read_() rely on this order.
+void Tail::write_(Writer &writer) const {
+ buf_.write(writer);
+ end_flags_.write(writer);
+}
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/tail.h b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.h
new file mode 100644
index 0000000000..7e5ca1d3e7
--- /dev/null
+++ b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.h
@@ -0,0 +1,73 @@
+#pragma once
+#ifndef MARISA_GRIMOIRE_TRIE_TAIL_H_
+#define MARISA_GRIMOIRE_TRIE_TAIL_H_
+
+#include "../../agent.h"
+#include "../vector.h"
+#include "entry.h"
+
+namespace marisa {
+namespace grimoire {
+namespace trie {
+
+// Byte buffer holding the strings that trie links point to, addressed by
+// offset. In text mode (end_flags_ empty) each string ends with '\0'; in
+// binary mode end_flags_ marks the last byte of each string.
+class Tail {
+ public:
+ Tail();
+
+ // Builds the tail from `entries` and reports each entry's offset.
+ void build(Vector<Entry> &entries, Vector<UInt32> *offsets,
+ TailMode mode);
+
+ void map(Mapper &mapper);
+ void read(Reader &reader);
+ void write(Writer &writer) const;
+
+ // Operations on the string stored at `offset`; results go through `agent`.
+ void restore(Agent &agent, std::size_t offset) const;
+ bool match(Agent &agent, std::size_t offset) const;
+ bool prefix_match(Agent &agent, std::size_t offset) const;
+
+ // Raw byte access; bounds are checked in debug builds only.
+ const char &operator[](std::size_t offset) const {
+ MARISA_DEBUG_IF(offset >= buf_.size(), MARISA_BOUND_ERROR);
+ return buf_[offset];
+ }
+
+ // The mode is inferred from whether end_flags_ is empty.
+ TailMode mode() const {
+ return end_flags_.empty() ? MARISA_TEXT_TAIL : MARISA_BINARY_TAIL;
+ }
+
+ bool empty() const {
+ return buf_.empty();
+ }
+ std::size_t size() const {
+ return buf_.size();
+ }
+ std::size_t total_size() const {
+ return buf_.total_size() + end_flags_.total_size();
+ }
+ std::size_t io_size() const {
+ return buf_.io_size() + end_flags_.io_size();
+ }
+
+ void clear();
+ void swap(Tail &rhs);
+
+ private:
+ Vector<char> buf_; // Stored string bytes.
+ BitVector end_flags_; // Empty in text mode; end-of-string marks otherwise.
+
+ void build_(Vector<Entry> &entries, Vector<UInt32> *offsets,
+ TailMode mode);
+
+ void map_(Mapper &mapper);
+ void read_(Reader &reader);
+ void write_(Writer &writer) const;
+
+ // Disallows copy and assignment.
+ Tail(const Tail &);
+ Tail &operator=(const Tail &);
+};
+
+} // namespace trie
+} // namespace grimoire
+} // namespace marisa
+
+#endif // MARISA_GRIMOIRE_TRIE_TAIL_H_