diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2023-12-02 01:45:21 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2023-12-02 02:42:50 +0300 |
commit | 9c43d58f75cf086b744cf4fe2ae180e8f37e4a0c (patch) | |
tree | 9f88a486917d371d099cd712efd91b4c122d209d /contrib/python/marisa-trie/marisa/grimoire/trie | |
parent | 32fb6dda1feb24f9ab69ece5df0cb9ec238ca5e6 (diff) | |
download | ydb-9c43d58f75cf086b744cf4fe2ae180e8f37e4a0c.tar.gz |
Intermediate changes
Diffstat (limited to 'contrib/python/marisa-trie/marisa/grimoire/trie')
12 files changed, 2213 insertions, 0 deletions
diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/cache.h b/contrib/python/marisa-trie/marisa/grimoire/trie/cache.h new file mode 100644 index 0000000000..f9da360869 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/cache.h @@ -0,0 +1,82 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_CACHE_H_ +#define MARISA_GRIMOIRE_TRIE_CACHE_H_ + +#include <cfloat> + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Cache { + public: + Cache() : parent_(0), child_(0), union_() { + union_.weight = FLT_MIN; + } + Cache(const Cache &cache) + : parent_(cache.parent_), child_(cache.child_), union_(cache.union_) {} + + Cache &operator=(const Cache &cache) { + parent_ = cache.parent_; + child_ = cache.child_; + union_ = cache.union_; + return *this; + } + + void set_parent(std::size_t parent) { + MARISA_DEBUG_IF(parent > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + parent_ = (UInt32)parent; + } + void set_child(std::size_t child) { + MARISA_DEBUG_IF(child > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + child_ = (UInt32)child; + } + void set_base(UInt8 base) { + union_.link = (union_.link & ~0xFFU) | base; + } + void set_extra(std::size_t extra) { + MARISA_DEBUG_IF(extra > (MARISA_UINT32_MAX >> 8), MARISA_SIZE_ERROR); + union_.link = (UInt32)((union_.link & 0xFFU) | (extra << 8)); + } + void set_weight(float weight) { + union_.weight = weight; + } + + std::size_t parent() const { + return parent_; + } + std::size_t child() const { + return child_; + } + UInt8 base() const { + return (UInt8)(union_.link & 0xFFU); + } + std::size_t extra() const { + return union_.link >> 8; + } + char label() const { + return (char)base(); + } + std::size_t link() const { + return union_.link; + } + float weight() const { + return union_.weight; + } + + private: + UInt32 parent_; + UInt32 child_; + union Union { + UInt32 link; + float weight; + } union_; +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // 
MARISA_GRIMOIRE_TRIE_CACHE_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/config.h b/contrib/python/marisa-trie/marisa/grimoire/trie/config.h new file mode 100644 index 0000000000..9b307de3e1 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/config.h @@ -0,0 +1,156 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_CONFIG_H_ +#define MARISA_GRIMOIRE_TRIE_CONFIG_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Config { + public: + Config() + : num_tries_(MARISA_DEFAULT_NUM_TRIES), + cache_level_(MARISA_DEFAULT_CACHE), + tail_mode_(MARISA_DEFAULT_TAIL), + node_order_(MARISA_DEFAULT_ORDER) {} + + void parse(int config_flags) { + Config temp; + temp.parse_(config_flags); + swap(temp); + } + + int flags() const { + return (int)num_tries_ | tail_mode_ | node_order_; + } + + std::size_t num_tries() const { + return num_tries_; + } + CacheLevel cache_level() const { + return cache_level_; + } + TailMode tail_mode() const { + return tail_mode_; + } + NodeOrder node_order() const { + return node_order_; + } + + void clear() { + Config().swap(*this); + } + void swap(Config &rhs) { + marisa::swap(num_tries_, rhs.num_tries_); + marisa::swap(cache_level_, rhs.cache_level_); + marisa::swap(tail_mode_, rhs.tail_mode_); + marisa::swap(node_order_, rhs.node_order_); + } + + private: + std::size_t num_tries_; + CacheLevel cache_level_; + TailMode tail_mode_; + NodeOrder node_order_; + + void parse_(int config_flags) { + MARISA_THROW_IF((config_flags & ~MARISA_CONFIG_MASK) != 0, + MARISA_CODE_ERROR); + + parse_num_tries(config_flags); + parse_cache_level(config_flags); + parse_tail_mode(config_flags); + parse_node_order(config_flags); + } + + void parse_num_tries(int config_flags) { + const int num_tries = config_flags & MARISA_NUM_TRIES_MASK; + if (num_tries != 0) { + num_tries_ = num_tries; + } + } + + void parse_cache_level(int config_flags) { + switch (config_flags & MARISA_CACHE_LEVEL_MASK) { + 
case 0: { + cache_level_ = MARISA_DEFAULT_CACHE; + break; + } + case MARISA_HUGE_CACHE: { + cache_level_ = MARISA_HUGE_CACHE; + break; + } + case MARISA_LARGE_CACHE: { + cache_level_ = MARISA_LARGE_CACHE; + break; + } + case MARISA_NORMAL_CACHE: { + cache_level_ = MARISA_NORMAL_CACHE; + break; + } + case MARISA_SMALL_CACHE: { + cache_level_ = MARISA_SMALL_CACHE; + break; + } + case MARISA_TINY_CACHE: { + cache_level_ = MARISA_TINY_CACHE; + break; + } + default: { + MARISA_THROW(MARISA_CODE_ERROR, "undefined cache level"); + } + } + } + + void parse_tail_mode(int config_flags) { + switch (config_flags & MARISA_TAIL_MODE_MASK) { + case 0: { + tail_mode_ = MARISA_DEFAULT_TAIL; + break; + } + case MARISA_TEXT_TAIL: { + tail_mode_ = MARISA_TEXT_TAIL; + break; + } + case MARISA_BINARY_TAIL: { + tail_mode_ = MARISA_BINARY_TAIL; + break; + } + default: { + MARISA_THROW(MARISA_CODE_ERROR, "undefined tail mode"); + } + } + } + + void parse_node_order(int config_flags) { + switch (config_flags & MARISA_NODE_ORDER_MASK) { + case 0: { + node_order_ = MARISA_DEFAULT_ORDER; + break; + } + case MARISA_LABEL_ORDER: { + node_order_ = MARISA_LABEL_ORDER; + break; + } + case MARISA_WEIGHT_ORDER: { + node_order_ = MARISA_WEIGHT_ORDER; + break; + } + default: { + MARISA_THROW(MARISA_CODE_ERROR, "undefined node order"); + } + } + } + + // Disallows copy and assignment. 
+ Config(const Config &); + Config &operator=(const Config &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_CONFIG_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/entry.h b/contrib/python/marisa-trie/marisa/grimoire/trie/entry.h new file mode 100644 index 0000000000..834ab95e1e --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/entry.h @@ -0,0 +1,83 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_ENTRY_H_ +#define MARISA_GRIMOIRE_TRIE_ENTRY_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Entry { + public: + Entry() + : ptr_(reinterpret_cast<const char *>(-1)), length_(0), id_(0) {} + Entry(const Entry &entry) + : ptr_(entry.ptr_), length_(entry.length_), id_(entry.id_) {} + + Entry &operator=(const Entry &entry) { + ptr_ = entry.ptr_; + length_ = entry.length_; + id_ = entry.id_; + return *this; + } + + char operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); + return *(ptr_ - i); + } + + void set_str(const char *ptr, std::size_t length) { + MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + ptr_ = ptr + length - 1; + length_ = (UInt32)length; + } + void set_id(std::size_t id) { + MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + id_ = (UInt32)id; + } + + const char *ptr() const { + return ptr_ - length_ + 1; + } + std::size_t length() const { + return length_; + } + std::size_t id() const { + return id_; + } + + class StringComparer { + public: + bool operator()(const Entry &lhs, const Entry &rhs) const { + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (i == rhs.length()) { + return true; + } + if (lhs[i] != rhs[i]) { + return (UInt8)lhs[i] > (UInt8)rhs[i]; + } + } + return lhs.length() > rhs.length(); + } + }; + + class IDComparer { + public: + bool operator()(const Entry &lhs, 
const Entry &rhs) const { + return lhs.id_ < rhs.id_; + } + }; + + private: + const char *ptr_; + UInt32 length_; + UInt32 id_; +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_ENTRY_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/header.h b/contrib/python/marisa-trie/marisa/grimoire/trie/header.h new file mode 100644 index 0000000000..04839f67e1 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/header.h @@ -0,0 +1,62 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_HEADER_H_ +#define MARISA_GRIMOIRE_TRIE_HEADER_H_ + +#include "../io.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Header { + public: + enum { + HEADER_SIZE = 16 + }; + + Header() {} + + void map(Mapper &mapper) { + const char *ptr; + mapper.map(&ptr, HEADER_SIZE); + MARISA_THROW_IF(!test_header(ptr), MARISA_FORMAT_ERROR); + } + void read(Reader &reader) { + char buf[HEADER_SIZE]; + reader.read(buf, HEADER_SIZE); + MARISA_THROW_IF(!test_header(buf), MARISA_FORMAT_ERROR); + } + void write(Writer &writer) const { + writer.write(get_header(), HEADER_SIZE); + } + + std::size_t io_size() const { + return HEADER_SIZE; + } + + private: + + static const char *get_header() { + static const char buf[HEADER_SIZE] = "We love Marisa."; + return buf; + } + + static bool test_header(const char *ptr) { + for (std::size_t i = 0; i < HEADER_SIZE; ++i) { + if (ptr[i] != get_header()[i]) { + return false; + } + } + return true; + } + + // Disallows copy and assignment. 
+ Header(const Header &); + Header &operator=(const Header &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_HEADER_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/history.h b/contrib/python/marisa-trie/marisa/grimoire/trie/history.h new file mode 100644 index 0000000000..9a3d272260 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/history.h @@ -0,0 +1,66 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_STATE_HISTORY_H_ +#define MARISA_GRIMOIRE_TRIE_STATE_HISTORY_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class History { + public: + History() + : node_id_(0), louds_pos_(0), key_pos_(0), + link_id_(MARISA_INVALID_LINK_ID), key_id_(MARISA_INVALID_KEY_ID) {} + + void set_node_id(std::size_t node_id) { + MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + node_id_ = (UInt32)node_id; + } + void set_louds_pos(std::size_t louds_pos) { + MARISA_DEBUG_IF(louds_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + louds_pos_ = (UInt32)louds_pos; + } + void set_key_pos(std::size_t key_pos) { + MARISA_DEBUG_IF(key_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + key_pos_ = (UInt32)key_pos; + } + void set_link_id(std::size_t link_id) { + MARISA_DEBUG_IF(link_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + link_id_ = (UInt32)link_id; + } + void set_key_id(std::size_t key_id) { + MARISA_DEBUG_IF(key_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + key_id_ = (UInt32)key_id; + } + + std::size_t node_id() const { + return node_id_; + } + std::size_t louds_pos() const { + return louds_pos_; + } + std::size_t key_pos() const { + return key_pos_; + } + std::size_t link_id() const { + return link_id_; + } + std::size_t key_id() const { + return key_id_; + } + + private: + UInt32 node_id_; + UInt32 louds_pos_; + UInt32 key_pos_; + UInt32 link_id_; + UInt32 key_id_; +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + 
+#endif // MARISA_GRIMOIRE_TRIE_STATE_HISTORY_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/key.h b/contrib/python/marisa-trie/marisa/grimoire/trie/key.h new file mode 100644 index 0000000000..c09ea86cf8 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/key.h @@ -0,0 +1,227 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_KEY_H_ +#define MARISA_GRIMOIRE_TRIE_KEY_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Key { + public: + Key() : ptr_(NULL), length_(0), union_(), id_(0) { + union_.terminal = 0; + } + Key(const Key &entry) + : ptr_(entry.ptr_), length_(entry.length_), + union_(entry.union_), id_(entry.id_) {} + + Key &operator=(const Key &entry) { + ptr_ = entry.ptr_; + length_ = entry.length_; + union_ = entry.union_; + id_ = entry.id_; + return *this; + } + + char operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); + return ptr_[i]; + } + + void substr(std::size_t pos, std::size_t length) { + MARISA_DEBUG_IF(pos > length_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(length > length_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(pos > (length_ - length), MARISA_BOUND_ERROR); + ptr_ += pos; + length_ = (UInt32)length; + } + + void set_str(const char *ptr, std::size_t length) { + MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + ptr_ = ptr; + length_ = (UInt32)length; + } + void set_weight(float weight) { + union_.weight = weight; + } + void set_terminal(std::size_t terminal) { + MARISA_DEBUG_IF(terminal > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + union_.terminal = (UInt32)terminal; + } + void set_id(std::size_t id) { + MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + id_ = (UInt32)id; + } + + const char *ptr() const { + return ptr_; + } + std::size_t length() const { + return length_; + } + float weight() const { + return union_.weight; + } + std::size_t 
terminal() const { + return union_.terminal; + } + std::size_t id() const { + return id_; + } + + private: + const char *ptr_; + UInt32 length_; + union Union { + float weight; + UInt32 terminal; + } union_; + UInt32 id_; +}; + +inline bool operator==(const Key &lhs, const Key &rhs) { + if (lhs.length() != rhs.length()) { + return false; + } + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; +} + +inline bool operator!=(const Key &lhs, const Key &rhs) { + return !(lhs == rhs); +} + +inline bool operator<(const Key &lhs, const Key &rhs) { + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (i == rhs.length()) { + return false; + } + if (lhs[i] != rhs[i]) { + return (UInt8)lhs[i] < (UInt8)rhs[i]; + } + } + return lhs.length() < rhs.length(); +} + +inline bool operator>(const Key &lhs, const Key &rhs) { + return rhs < lhs; +} + +class ReverseKey { + public: + ReverseKey() : ptr_(NULL), length_(0), union_(), id_(0) { + union_.terminal = 0; + } + ReverseKey(const ReverseKey &entry) + : ptr_(entry.ptr_), length_(entry.length_), + union_(entry.union_), id_(entry.id_) {} + + ReverseKey &operator=(const ReverseKey &entry) { + ptr_ = entry.ptr_; + length_ = entry.length_; + union_ = entry.union_; + id_ = entry.id_; + return *this; + } + + char operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); + return *(ptr_ - i - 1); + } + + void substr(std::size_t pos, std::size_t length) { + MARISA_DEBUG_IF(pos > length_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(length > length_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(pos > (length_ - length), MARISA_BOUND_ERROR); + ptr_ -= pos; + length_ = (UInt32)length; + } + + void set_str(const char *ptr, std::size_t length) { + MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + ptr_ = ptr + length; + length_ = (UInt32)length; + } + void set_weight(float weight) { 
+ union_.weight = weight; + } + void set_terminal(std::size_t terminal) { + MARISA_DEBUG_IF(terminal > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + union_.terminal = (UInt32)terminal; + } + void set_id(std::size_t id) { + MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + id_ = (UInt32)id; + } + + const char *ptr() const { + return ptr_ - length_; + } + std::size_t length() const { + return length_; + } + float weight() const { + return union_.weight; + } + std::size_t terminal() const { + return union_.terminal; + } + std::size_t id() const { + return id_; + } + + private: + const char *ptr_; + UInt32 length_; + union Union { + float weight; + UInt32 terminal; + } union_; + UInt32 id_; +}; + +inline bool operator==(const ReverseKey &lhs, const ReverseKey &rhs) { + if (lhs.length() != rhs.length()) { + return false; + } + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; +} + +inline bool operator!=(const ReverseKey &lhs, const ReverseKey &rhs) { + return !(lhs == rhs); +} + +inline bool operator<(const ReverseKey &lhs, const ReverseKey &rhs) { + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (i == rhs.length()) { + return false; + } + if (lhs[i] != rhs[i]) { + return (UInt8)lhs[i] < (UInt8)rhs[i]; + } + } + return lhs.length() < rhs.length(); +} + +inline bool operator>(const ReverseKey &lhs, const ReverseKey &rhs) { + return rhs < lhs; +} + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_KEY_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc new file mode 100644 index 0000000000..ed168539bc --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc @@ -0,0 +1,877 @@ +#include <algorithm> +#include <functional> +#include <queue> + +#include "../algorithm.h" +#include "header.h" +#include "range.h" +#include "state.h" 
+#include "louds-trie.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +LoudsTrie::LoudsTrie() + : louds_(), terminal_flags_(), link_flags_(), bases_(), extras_(), + tail_(), next_trie_(), cache_(), cache_mask_(0), num_l1_nodes_(0), + config_(), mapper_() {} + +LoudsTrie::~LoudsTrie() {} + +void LoudsTrie::build(Keyset &keyset, int flags) { + Config config; + config.parse(flags); + + LoudsTrie temp; + temp.build_(keyset, config); + swap(temp); +} + +void LoudsTrie::map(Mapper &mapper) { + Header().map(mapper); + + LoudsTrie temp; + temp.map_(mapper); + temp.mapper_.swap(mapper); + swap(temp); +} + +void LoudsTrie::read(Reader &reader) { + Header().read(reader); + + LoudsTrie temp; + temp.read_(reader); + swap(temp); +} + +void LoudsTrie::write(Writer &writer) const { + Header().write(writer); + + write_(writer); +} + +bool LoudsTrie::lookup(Agent &agent) const { + MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR); + + State &state = agent.state(); + state.lookup_init(); + while (state.query_pos() < agent.query().length()) { + if (!find_child(agent)) { + return false; + } + } + if (!terminal_flags_[state.node_id()]) { + return false; + } + agent.set_key(agent.query().ptr(), agent.query().length()); + agent.set_key(terminal_flags_.rank1(state.node_id())); + return true; +} + +void LoudsTrie::reverse_lookup(Agent &agent) const { + MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR); + MARISA_THROW_IF(agent.query().id() >= size(), MARISA_BOUND_ERROR); + + State &state = agent.state(); + state.reverse_lookup_init(); + + state.set_node_id(terminal_flags_.select1(agent.query().id())); + if (state.node_id() == 0) { + agent.set_key(state.key_buf().begin(), state.key_buf().size()); + agent.set_key(agent.query().id()); + return; + } + for ( ; ; ) { + if (link_flags_[state.node_id()]) { + const std::size_t prev_key_pos = state.key_buf().size(); + restore(agent, get_link(state.node_id())); + std::reverse(state.key_buf().begin() + prev_key_pos, + 
state.key_buf().end()); + } else { + state.key_buf().push_back((char)bases_[state.node_id()]); + } + + if (state.node_id() <= num_l1_nodes_) { + std::reverse(state.key_buf().begin(), state.key_buf().end()); + agent.set_key(state.key_buf().begin(), state.key_buf().size()); + agent.set_key(agent.query().id()); + return; + } + state.set_node_id(louds_.select1(state.node_id()) - state.node_id() - 1); + } +} + +bool LoudsTrie::common_prefix_search(Agent &agent) const { + MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR); + + State &state = agent.state(); + if (state.status_code() == MARISA_END_OF_COMMON_PREFIX_SEARCH) { + return false; + } + + if (state.status_code() != MARISA_READY_TO_COMMON_PREFIX_SEARCH) { + state.common_prefix_search_init(); + if (terminal_flags_[state.node_id()]) { + agent.set_key(agent.query().ptr(), state.query_pos()); + agent.set_key(terminal_flags_.rank1(state.node_id())); + return true; + } + } + + while (state.query_pos() < agent.query().length()) { + if (!find_child(agent)) { + state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH); + return false; + } else if (terminal_flags_[state.node_id()]) { + agent.set_key(agent.query().ptr(), state.query_pos()); + agent.set_key(terminal_flags_.rank1(state.node_id())); + return true; + } + } + state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH); + return false; +} + +bool LoudsTrie::predictive_search(Agent &agent) const { + MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR); + + State &state = agent.state(); + if (state.status_code() == MARISA_END_OF_PREDICTIVE_SEARCH) { + return false; + } + + if (state.status_code() != MARISA_READY_TO_PREDICTIVE_SEARCH) { + state.predictive_search_init(); + while (state.query_pos() < agent.query().length()) { + if (!predictive_find_child(agent)) { + state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH); + return false; + } + } + + History history; + history.set_node_id(state.node_id()); + history.set_key_pos(state.key_buf().size()); + 
state.history().push_back(history); + state.set_history_pos(1); + + if (terminal_flags_[state.node_id()]) { + agent.set_key(state.key_buf().begin(), state.key_buf().size()); + agent.set_key(terminal_flags_.rank1(state.node_id())); + return true; + } + } + + for ( ; ; ) { + if (state.history_pos() == state.history().size()) { + const History &current = state.history().back(); + History next; + next.set_louds_pos(louds_.select0(current.node_id()) + 1); + next.set_node_id(next.louds_pos() - current.node_id() - 1); + state.history().push_back(next); + } + + History &next = state.history()[state.history_pos()]; + const bool link_flag = louds_[next.louds_pos()]; + next.set_louds_pos(next.louds_pos() + 1); + if (link_flag) { + state.set_history_pos(state.history_pos() + 1); + if (link_flags_[next.node_id()]) { + next.set_link_id(update_link_id(next.link_id(), next.node_id())); + restore(agent, get_link(next.node_id(), next.link_id())); + } else { + state.key_buf().push_back((char)bases_[next.node_id()]); + } + next.set_key_pos(state.key_buf().size()); + + if (terminal_flags_[next.node_id()]) { + if (next.key_id() == MARISA_INVALID_KEY_ID) { + next.set_key_id(terminal_flags_.rank1(next.node_id())); + } else { + next.set_key_id(next.key_id() + 1); + } + agent.set_key(state.key_buf().begin(), state.key_buf().size()); + agent.set_key(next.key_id()); + return true; + } + } else if (state.history_pos() != 1) { + History &current = state.history()[state.history_pos() - 1]; + current.set_node_id(current.node_id() + 1); + const History &prev = + state.history()[state.history_pos() - 2]; + state.key_buf().resize(prev.key_pos()); + state.set_history_pos(state.history_pos() - 1); + } else { + state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH); + return false; + } + } +} + +std::size_t LoudsTrie::total_size() const { + return louds_.total_size() + terminal_flags_.total_size() + + link_flags_.total_size() + bases_.total_size() + + extras_.total_size() + tail_.total_size() + 
((next_trie_.get() != NULL) ? next_trie_->total_size() : 0) + + cache_.total_size(); +} + +std::size_t LoudsTrie::io_size() const { + return Header().io_size() + louds_.io_size() + + terminal_flags_.io_size() + link_flags_.io_size() + + bases_.io_size() + extras_.io_size() + tail_.io_size() + + ((next_trie_.get() != NULL) ? + (next_trie_->io_size() - Header().io_size()) : 0) + + cache_.io_size() + (sizeof(UInt32) * 2); +} + +void LoudsTrie::clear() { + LoudsTrie().swap(*this); +} + +void LoudsTrie::swap(LoudsTrie &rhs) { + louds_.swap(rhs.louds_); + terminal_flags_.swap(rhs.terminal_flags_); + link_flags_.swap(rhs.link_flags_); + bases_.swap(rhs.bases_); + extras_.swap(rhs.extras_); + tail_.swap(rhs.tail_); + next_trie_.swap(rhs.next_trie_); + cache_.swap(rhs.cache_); + marisa::swap(cache_mask_, rhs.cache_mask_); + marisa::swap(num_l1_nodes_, rhs.num_l1_nodes_); + config_.swap(rhs.config_); + mapper_.swap(rhs.mapper_); +} + +void LoudsTrie::build_(Keyset &keyset, const Config &config) { + Vector<Key> keys; + keys.resize(keyset.size()); + for (std::size_t i = 0; i < keyset.size(); ++i) { + keys[i].set_str(keyset[i].ptr(), keyset[i].length()); + keys[i].set_weight(keyset[i].weight()); + } + + Vector<UInt32> terminals; + build_trie(keys, &terminals, config, 1); + + typedef std::pair<UInt32, UInt32> TerminalIdPair; + + Vector<TerminalIdPair> pairs; + pairs.resize(terminals.size()); + for (std::size_t i = 0; i < pairs.size(); ++i) { + pairs[i].first = terminals[i]; + pairs[i].second = (UInt32)i; + } + terminals.clear(); + std::sort(pairs.begin(), pairs.end()); + + std::size_t node_id = 0; + for (std::size_t i = 0; i < pairs.size(); ++i) { + while (node_id < pairs[i].first) { + terminal_flags_.push_back(false); + ++node_id; + } + if (node_id == pairs[i].first) { + terminal_flags_.push_back(true); + ++node_id; + } + } + while (node_id < bases_.size()) { + terminal_flags_.push_back(false); + ++node_id; + } + terminal_flags_.push_back(false); + terminal_flags_.build(false, 
true); + + for (std::size_t i = 0; i < keyset.size(); ++i) { + keyset[pairs[i].second].set_id(terminal_flags_.rank1(pairs[i].first)); + } +} + +template <typename T> +void LoudsTrie::build_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) { + build_current_trie(keys, terminals, config, trie_id); + + Vector<UInt32> next_terminals; + if (!keys.empty()) { + build_next_trie(keys, &next_terminals, config, trie_id); + } + + if (next_trie_.get() != NULL) { + config_.parse(static_cast<int>((next_trie_->num_tries() + 1)) | + next_trie_->tail_mode() | next_trie_->node_order()); + } else { + config_.parse(1 | tail_.mode() | config.node_order() | + config.cache_level()); + } + + link_flags_.build(false, false); + std::size_t node_id = 0; + for (std::size_t i = 0; i < next_terminals.size(); ++i) { + while (!link_flags_[node_id]) { + ++node_id; + } + bases_[node_id] = (UInt8)(next_terminals[i] % 256); + next_terminals[i] /= 256; + ++node_id; + } + extras_.build(next_terminals); + fill_cache(); +} + +template <typename T> +void LoudsTrie::build_current_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, + std::size_t trie_id) try { + for (std::size_t i = 0; i < keys.size(); ++i) { + keys[i].set_id(i); + } + const std::size_t num_keys = Algorithm().sort(keys.begin(), keys.end()); + reserve_cache(config, trie_id, num_keys); + + louds_.push_back(true); + louds_.push_back(false); + bases_.push_back('\0'); + link_flags_.push_back(false); + + Vector<T> next_keys; + std::queue<Range> queue; + Vector<WeightedRange> w_ranges; + + queue.push(make_range(0, keys.size(), 0)); + while (!queue.empty()) { + const std::size_t node_id = link_flags_.size() - queue.size(); + + Range range = queue.front(); + queue.pop(); + + while ((range.begin() < range.end()) && + (keys[range.begin()].length() == range.key_pos())) { + keys[range.begin()].set_terminal(node_id); + range.set_begin(range.begin() + 1); + } + + if (range.begin() == 
range.end()) { + louds_.push_back(false); + continue; + } + + w_ranges.clear(); + double weight = keys[range.begin()].weight(); + for (std::size_t i = range.begin() + 1; i < range.end(); ++i) { + if (keys[i - 1][range.key_pos()] != keys[i][range.key_pos()]) { + w_ranges.push_back(make_weighted_range( + range.begin(), i, range.key_pos(), (float)weight)); + range.set_begin(i); + weight = 0.0; + } + weight += keys[i].weight(); + } + w_ranges.push_back(make_weighted_range( + range.begin(), range.end(), range.key_pos(), (float)weight)); + if (config.node_order() == MARISA_WEIGHT_ORDER) { + std::stable_sort(w_ranges.begin(), w_ranges.end(), + std::greater<WeightedRange>()); + } + + if (node_id == 0) { + num_l1_nodes_ = w_ranges.size(); + } + + for (std::size_t i = 0; i < w_ranges.size(); ++i) { + WeightedRange &w_range = w_ranges[i]; + std::size_t key_pos = w_range.key_pos() + 1; + while (key_pos < keys[w_range.begin()].length()) { + std::size_t j; + for (j = w_range.begin() + 1; j < w_range.end(); ++j) { + if (keys[j - 1][key_pos] != keys[j][key_pos]) { + break; + } + } + if (j < w_range.end()) { + break; + } + ++key_pos; + } + cache<T>(node_id, bases_.size(), w_range.weight(), + keys[w_range.begin()][w_range.key_pos()]); + + if (key_pos == w_range.key_pos() + 1) { + bases_.push_back(keys[w_range.begin()][w_range.key_pos()]); + link_flags_.push_back(false); + } else { + bases_.push_back('\0'); + link_flags_.push_back(true); + T next_key; + next_key.set_str(keys[w_range.begin()].ptr(), + keys[w_range.begin()].length()); + next_key.substr(w_range.key_pos(), key_pos - w_range.key_pos()); + next_key.set_weight(w_range.weight()); + next_keys.push_back(next_key); + } + w_range.set_key_pos(key_pos); + queue.push(w_range.range()); + louds_.push_back(true); + } + louds_.push_back(false); + } + + louds_.push_back(false); + louds_.build(trie_id == 1, true); + bases_.shrink(); + + build_terminals(keys, terminals); + keys.swap(next_keys); +} catch (const std::bad_alloc &) { + 
MARISA_THROW(MARISA_MEMORY_ERROR, "std::bad_alloc"); +} + +template <> +void LoudsTrie::build_next_trie(Vector<Key> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) { + if (trie_id == config.num_tries()) { + Vector<Entry> entries; + entries.resize(keys.size()); + for (std::size_t i = 0; i < keys.size(); ++i) { + entries[i].set_str(keys[i].ptr(), keys[i].length()); + } + tail_.build(entries, terminals, config.tail_mode()); + return; + } + Vector<ReverseKey> reverse_keys; + reverse_keys.resize(keys.size()); + for (std::size_t i = 0; i < keys.size(); ++i) { + reverse_keys[i].set_str(keys[i].ptr(), keys[i].length()); + reverse_keys[i].set_weight(keys[i].weight()); + } + keys.clear(); + next_trie_.reset(new (std::nothrow) LoudsTrie); + MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR); + next_trie_->build_trie(reverse_keys, terminals, config, trie_id + 1); +} + +template <> +void LoudsTrie::build_next_trie(Vector<ReverseKey> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) { + if (trie_id == config.num_tries()) { + Vector<Entry> entries; + entries.resize(keys.size()); + for (std::size_t i = 0; i < keys.size(); ++i) { + entries[i].set_str(keys[i].ptr(), keys[i].length()); + } + tail_.build(entries, terminals, config.tail_mode()); + return; + } + next_trie_.reset(new (std::nothrow) LoudsTrie); + MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR); + next_trie_->build_trie(keys, terminals, config, trie_id + 1); +} + +template <typename T> +void LoudsTrie::build_terminals(const Vector<T> &keys, + Vector<UInt32> *terminals) const { + Vector<UInt32> temp; + temp.resize(keys.size()); + for (std::size_t i = 0; i < keys.size(); ++i) { + temp[keys[i].id()] = (UInt32)keys[i].terminal(); + } + terminals->swap(temp); +} + +template <> +void LoudsTrie::cache<Key>(std::size_t parent, std::size_t child, + float weight, char label) { + MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR); + + const 
std::size_t cache_id = get_cache_id(parent, label); + if (weight > cache_[cache_id].weight()) { + cache_[cache_id].set_parent(parent); + cache_[cache_id].set_child(child); + cache_[cache_id].set_weight(weight); + } +} + +void LoudsTrie::reserve_cache(const Config &config, std::size_t trie_id, + std::size_t num_keys) { + std::size_t cache_size = (trie_id == 1) ? 256 : 1; + while (cache_size < (num_keys / config.cache_level())) { + cache_size *= 2; + } + cache_.resize(cache_size); + cache_mask_ = cache_size - 1; +} + +template <> +void LoudsTrie::cache<ReverseKey>(std::size_t parent, std::size_t child, + float weight, char) { + MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR); + + const std::size_t cache_id = get_cache_id(child); + if (weight > cache_[cache_id].weight()) { + cache_[cache_id].set_parent(parent); + cache_[cache_id].set_child(child); + cache_[cache_id].set_weight(weight); + } +} + +void LoudsTrie::fill_cache() { + for (std::size_t i = 0; i < cache_.size(); ++i) { + const std::size_t node_id = cache_[i].child(); + if (node_id != 0) { + cache_[i].set_base(bases_[node_id]); + cache_[i].set_extra(!link_flags_[node_id] ? 
+ MARISA_INVALID_EXTRA : extras_[link_flags_.rank1(node_id)]); + } else { + cache_[i].set_parent(MARISA_UINT32_MAX); + cache_[i].set_child(MARISA_UINT32_MAX); + } + } +} + +void LoudsTrie::map_(Mapper &mapper) { + louds_.map(mapper); + terminal_flags_.map(mapper); + link_flags_.map(mapper); + bases_.map(mapper); + extras_.map(mapper); + tail_.map(mapper); + if ((link_flags_.num_1s() != 0) && tail_.empty()) { + next_trie_.reset(new (std::nothrow) LoudsTrie); + MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR); + next_trie_->map_(mapper); + } + cache_.map(mapper); + cache_mask_ = cache_.size() - 1; + { + UInt32 temp_num_l1_nodes; + mapper.map(&temp_num_l1_nodes); + num_l1_nodes_ = temp_num_l1_nodes; + } + { + UInt32 temp_config_flags; + mapper.map(&temp_config_flags); + config_.parse((int)temp_config_flags); + } +} + +void LoudsTrie::read_(Reader &reader) { + louds_.read(reader); + terminal_flags_.read(reader); + link_flags_.read(reader); + bases_.read(reader); + extras_.read(reader); + tail_.read(reader); + if ((link_flags_.num_1s() != 0) && tail_.empty()) { + next_trie_.reset(new (std::nothrow) LoudsTrie); + MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR); + next_trie_->read_(reader); + } + cache_.read(reader); + cache_mask_ = cache_.size() - 1; + { + UInt32 temp_num_l1_nodes; + reader.read(&temp_num_l1_nodes); + num_l1_nodes_ = temp_num_l1_nodes; + } + { + UInt32 temp_config_flags; + reader.read(&temp_config_flags); + config_.parse((int)temp_config_flags); + } +} + +void LoudsTrie::write_(Writer &writer) const { + louds_.write(writer); + terminal_flags_.write(writer); + link_flags_.write(writer); + bases_.write(writer); + extras_.write(writer); + tail_.write(writer); + if (next_trie_.get() != NULL) { + next_trie_->write_(writer); + } + cache_.write(writer); + writer.write((UInt32)num_l1_nodes_); + writer.write((UInt32)config_.flags()); +} + +bool LoudsTrie::find_child(Agent &agent) const { + MARISA_DEBUG_IF(agent.state().query_pos() >= 
agent.query().length(), + MARISA_BOUND_ERROR); + + State &state = agent.state(); + const std::size_t cache_id = get_cache_id(state.node_id(), + agent.query()[state.query_pos()]); + if (state.node_id() == cache_[cache_id].parent()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + if (!match(agent, cache_[cache_id].link())) { + return false; + } + } else { + state.set_query_pos(state.query_pos() + 1); + } + state.set_node_id(cache_[cache_id].child()); + return true; + } + + std::size_t louds_pos = louds_.select0(state.node_id()) + 1; + if (!louds_[louds_pos]) { + return false; + } + state.set_node_id(louds_pos - state.node_id() - 1); + std::size_t link_id = MARISA_INVALID_LINK_ID; + do { + if (link_flags_[state.node_id()]) { + link_id = update_link_id(link_id, state.node_id()); + const std::size_t prev_query_pos = state.query_pos(); + if (match(agent, get_link(state.node_id(), link_id))) { + return true; + } else if (state.query_pos() != prev_query_pos) { + return false; + } + } else if (bases_[state.node_id()] == + (UInt8)agent.query()[state.query_pos()]) { + state.set_query_pos(state.query_pos() + 1); + return true; + } + state.set_node_id(state.node_id() + 1); + ++louds_pos; + } while (louds_[louds_pos]); + return false; +} + +bool LoudsTrie::predictive_find_child(Agent &agent) const { + MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(), + MARISA_BOUND_ERROR); + + State &state = agent.state(); + const std::size_t cache_id = get_cache_id(state.node_id(), + agent.query()[state.query_pos()]); + if (state.node_id() == cache_[cache_id].parent()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + if (!prefix_match(agent, cache_[cache_id].link())) { + return false; + } + } else { + state.key_buf().push_back(cache_[cache_id].label()); + state.set_query_pos(state.query_pos() + 1); + } + state.set_node_id(cache_[cache_id].child()); + return true; + } + + std::size_t louds_pos = louds_.select0(state.node_id()) + 1; + if 
(!louds_[louds_pos]) { + return false; + } + state.set_node_id(louds_pos - state.node_id() - 1); + std::size_t link_id = MARISA_INVALID_LINK_ID; + do { + if (link_flags_[state.node_id()]) { + link_id = update_link_id(link_id, state.node_id()); + const std::size_t prev_query_pos = state.query_pos(); + if (prefix_match(agent, get_link(state.node_id(), link_id))) { + return true; + } else if (state.query_pos() != prev_query_pos) { + return false; + } + } else if (bases_[state.node_id()] == + (UInt8)agent.query()[state.query_pos()]) { + state.key_buf().push_back((char)bases_[state.node_id()]); + state.set_query_pos(state.query_pos() + 1); + return true; + } + state.set_node_id(state.node_id() + 1); + ++louds_pos; + } while (louds_[louds_pos]); + return false; +} + +void LoudsTrie::restore(Agent &agent, std::size_t link) const { + if (next_trie_.get() != NULL) { + next_trie_->restore_(agent, link); + } else { + tail_.restore(agent, link); + } +} + +bool LoudsTrie::match(Agent &agent, std::size_t link) const { + if (next_trie_.get() != NULL) { + return next_trie_->match_(agent, link); + } else { + return tail_.match(agent, link); + } +} + +bool LoudsTrie::prefix_match(Agent &agent, std::size_t link) const { + if (next_trie_.get() != NULL) { + return next_trie_->prefix_match_(agent, link); + } else { + return tail_.prefix_match(agent, link); + } +} + +void LoudsTrie::restore_(Agent &agent, std::size_t node_id) const { + MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR); + + State &state = agent.state(); + for ( ; ; ) { + const std::size_t cache_id = get_cache_id(node_id); + if (node_id == cache_[cache_id].child()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + restore(agent, cache_[cache_id].link()); + } else { + state.key_buf().push_back(cache_[cache_id].label()); + } + + node_id = cache_[cache_id].parent(); + if (node_id == 0) { + return; + } + continue; + } + + if (link_flags_[node_id]) { + restore(agent, get_link(node_id)); + } else { + 
state.key_buf().push_back((char)bases_[node_id]); + } + + if (node_id <= num_l1_nodes_) { + return; + } + node_id = louds_.select1(node_id) - node_id - 1; + } +} + +bool LoudsTrie::match_(Agent &agent, std::size_t node_id) const { + MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(), + MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR); + + State &state = agent.state(); + for ( ; ; ) { + const std::size_t cache_id = get_cache_id(node_id); + if (node_id == cache_[cache_id].child()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + if (!match(agent, cache_[cache_id].link())) { + return false; + } + } else if (cache_[cache_id].label() == + agent.query()[state.query_pos()]) { + state.set_query_pos(state.query_pos() + 1); + } else { + return false; + } + + node_id = cache_[cache_id].parent(); + if (node_id == 0) { + return true; + } else if (state.query_pos() >= agent.query().length()) { + return false; + } + continue; + } + + if (link_flags_[node_id]) { + if (next_trie_.get() != NULL) { + if (!match(agent, get_link(node_id))) { + return false; + } + } else if (!tail_.match(agent, get_link(node_id))) { + return false; + } + } else if (bases_[node_id] == (UInt8)agent.query()[state.query_pos()]) { + state.set_query_pos(state.query_pos() + 1); + } else { + return false; + } + + if (node_id <= num_l1_nodes_) { + return true; + } else if (state.query_pos() >= agent.query().length()) { + return false; + } + node_id = louds_.select1(node_id) - node_id - 1; + } +} + +bool LoudsTrie::prefix_match_(Agent &agent, std::size_t node_id) const { + MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(), + MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR); + + State &state = agent.state(); + for ( ; ; ) { + const std::size_t cache_id = get_cache_id(node_id); + if (node_id == cache_[cache_id].child()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + if (!prefix_match(agent, 
cache_[cache_id].link())) { + return false; + } + } else if (cache_[cache_id].label() == + agent.query()[state.query_pos()]) { + state.key_buf().push_back(cache_[cache_id].label()); + state.set_query_pos(state.query_pos() + 1); + } else { + return false; + } + + node_id = cache_[cache_id].parent(); + if (node_id == 0) { + return true; + } + } else { + if (link_flags_[node_id]) { + if (!prefix_match(agent, get_link(node_id))) { + return false; + } + } else if (bases_[node_id] == (UInt8)agent.query()[state.query_pos()]) { + state.key_buf().push_back((char)bases_[node_id]); + state.set_query_pos(state.query_pos() + 1); + } else { + return false; + } + + if (node_id <= num_l1_nodes_) { + return true; + } + node_id = louds_.select1(node_id) - node_id - 1; + } + + if (state.query_pos() >= agent.query().length()) { + restore_(agent, node_id); + return true; + } + } +} + +std::size_t LoudsTrie::get_cache_id(std::size_t node_id, char label) const { + return (node_id ^ (node_id << 5) ^ (UInt8)label) & cache_mask_; +} + +std::size_t LoudsTrie::get_cache_id(std::size_t node_id) const { + return node_id & cache_mask_; +} + +std::size_t LoudsTrie::get_link(std::size_t node_id) const { + return bases_[node_id] | (extras_[link_flags_.rank1(node_id)] * 256); +} + +std::size_t LoudsTrie::get_link(std::size_t node_id, + std::size_t link_id) const { + return bases_[node_id] | (extras_[link_id] * 256); +} + +std::size_t LoudsTrie::update_link_id(std::size_t link_id, + std::size_t node_id) const { + return (link_id == MARISA_INVALID_LINK_ID) ? 
+ link_flags_.rank1(node_id) : (link_id + 1); +} + +} // namespace trie +} // namespace grimoire +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h new file mode 100644 index 0000000000..5a757ac8fc --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h @@ -0,0 +1,135 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_LOUDS_TRIE_H_ +#define MARISA_GRIMOIRE_TRIE_LOUDS_TRIE_H_ + +#include "../../keyset.h" +#include "../../agent.h" +#include "../vector.h" +#include "config.h" +#include "key.h" +#include "tail.h" +#include "cache.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class LoudsTrie { + public: + LoudsTrie(); + ~LoudsTrie(); + + void build(Keyset &keyset, int flags); + + void map(Mapper &mapper); + void read(Reader &reader); + void write(Writer &writer) const; + + bool lookup(Agent &agent) const; + void reverse_lookup(Agent &agent) const; + bool common_prefix_search(Agent &agent) const; + bool predictive_search(Agent &agent) const; + + std::size_t num_tries() const { + return config_.num_tries(); + } + std::size_t num_keys() const { + return size(); + } + std::size_t num_nodes() const { + return (louds_.size() / 2) - 1; + } + + CacheLevel cache_level() const { + return config_.cache_level(); + } + TailMode tail_mode() const { + return config_.tail_mode(); + } + NodeOrder node_order() const { + return config_.node_order(); + } + + bool empty() const { + return size() == 0; + } + std::size_t size() const { + return terminal_flags_.num_1s(); + } + std::size_t total_size() const; + std::size_t io_size() const; + + void clear(); + void swap(LoudsTrie &rhs); + + private: + BitVector louds_; + BitVector terminal_flags_; + BitVector link_flags_; + Vector<UInt8> bases_; + FlatVector extras_; + Tail tail_; + scoped_ptr<LoudsTrie> next_trie_; + Vector<Cache> cache_; + std::size_t cache_mask_; + std::size_t num_l1_nodes_; + 
Config config_; + Mapper mapper_; + + void build_(Keyset &keyset, const Config &config); + + template <typename T> + void build_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id); + template <typename T> + void build_current_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id); + template <typename T> + void build_next_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id); + template <typename T> + void build_terminals(const Vector<T> &keys, + Vector<UInt32> *terminals) const; + + void reserve_cache(const Config &config, std::size_t trie_id, + std::size_t num_keys); + template <typename T> + void cache(std::size_t parent, std::size_t child, + float weight, char label); + void fill_cache(); + + void map_(Mapper &mapper); + void read_(Reader &reader); + void write_(Writer &writer) const; + + inline bool find_child(Agent &agent) const; + inline bool predictive_find_child(Agent &agent) const; + + inline void restore(Agent &agent, std::size_t node_id) const; + inline bool match(Agent &agent, std::size_t node_id) const; + inline bool prefix_match(Agent &agent, std::size_t node_id) const; + + void restore_(Agent &agent, std::size_t node_id) const; + bool match_(Agent &agent, std::size_t node_id) const; + bool prefix_match_(Agent &agent, std::size_t node_id) const; + + inline std::size_t get_cache_id(std::size_t node_id, char label) const; + inline std::size_t get_cache_id(std::size_t node_id) const; + + inline std::size_t get_link(std::size_t node_id) const; + inline std::size_t get_link(std::size_t node_id, + std::size_t link_id) const; + + inline std::size_t update_link_id(std::size_t link_id, + std::size_t node_id) const; + + // Disallows copy and assignment. 
+ LoudsTrie(const LoudsTrie &); + LoudsTrie &operator=(const LoudsTrie &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_LOUDS_TRIE_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/range.h b/contrib/python/marisa-trie/marisa/grimoire/trie/range.h new file mode 100644 index 0000000000..4ca39a9c37 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/range.h @@ -0,0 +1,116 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_RANGE_H_ +#define MARISA_GRIMOIRE_TRIE_RANGE_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Range { + public: + Range() : begin_(0), end_(0), key_pos_(0) {} + + void set_begin(std::size_t begin) { + MARISA_DEBUG_IF(begin > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + begin_ = static_cast<UInt32>(begin); + } + void set_end(std::size_t end) { + MARISA_DEBUG_IF(end > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + end_ = static_cast<UInt32>(end); + } + void set_key_pos(std::size_t key_pos) { + MARISA_DEBUG_IF(key_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + key_pos_ = static_cast<UInt32>(key_pos); + } + + std::size_t begin() const { + return begin_; + } + std::size_t end() const { + return end_; + } + std::size_t key_pos() const { + return key_pos_; + } + + private: + UInt32 begin_; + UInt32 end_; + UInt32 key_pos_; +}; + +inline Range make_range(std::size_t begin, std::size_t end, + std::size_t key_pos) { + Range range; + range.set_begin(begin); + range.set_end(end); + range.set_key_pos(key_pos); + return range; +} + +class WeightedRange { + public: + WeightedRange() : range_(), weight_(0.0F) {} + + void set_range(const Range &range) { + range_ = range; + } + void set_begin(std::size_t begin) { + range_.set_begin(begin); + } + void set_end(std::size_t end) { + range_.set_end(end); + } + void set_key_pos(std::size_t key_pos) { + range_.set_key_pos(key_pos); + } + void set_weight(float weight) { + weight_ = weight; + } + 
+ const Range &range() const { + return range_; + } + std::size_t begin() const { + return range_.begin(); + } + std::size_t end() const { + return range_.end(); + } + std::size_t key_pos() const { + return range_.key_pos(); + } + float weight() const { + return weight_; + } + + private: + Range range_; + float weight_; +}; + +inline bool operator<(const WeightedRange &lhs, const WeightedRange &rhs) { + return lhs.weight() < rhs.weight(); +} + +inline bool operator>(const WeightedRange &lhs, const WeightedRange &rhs) { + return lhs.weight() > rhs.weight(); +} + +inline WeightedRange make_weighted_range(std::size_t begin, std::size_t end, + std::size_t key_pos, float weight) { + WeightedRange range; + range.set_begin(begin); + range.set_end(end); + range.set_key_pos(key_pos); + range.set_weight(weight); + return range; +} + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_RANGE_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/state.h b/contrib/python/marisa-trie/marisa/grimoire/trie/state.h new file mode 100644 index 0000000000..219bf9e03a --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/state.h @@ -0,0 +1,118 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_STATE_H_ +#define MARISA_GRIMOIRE_TRIE_STATE_H_ + +#include "../vector.h" +#include "history.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +// A search agent has its internal state and the status codes are defined +// below. 
+typedef enum StatusCode { + MARISA_READY_TO_ALL, + MARISA_READY_TO_COMMON_PREFIX_SEARCH, + MARISA_READY_TO_PREDICTIVE_SEARCH, + MARISA_END_OF_COMMON_PREFIX_SEARCH, + MARISA_END_OF_PREDICTIVE_SEARCH, +} StatusCode; + +class State { + public: + State() + : key_buf_(), history_(), node_id_(0), query_pos_(0), + history_pos_(0), status_code_(MARISA_READY_TO_ALL) {} + + void set_node_id(std::size_t node_id) { + MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + node_id_ = (UInt32)node_id; + } + void set_query_pos(std::size_t query_pos) { + MARISA_DEBUG_IF(query_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + query_pos_ = (UInt32)query_pos; + } + void set_history_pos(std::size_t history_pos) { + MARISA_DEBUG_IF(history_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + history_pos_ = (UInt32)history_pos; + } + void set_status_code(StatusCode status_code) { + status_code_ = status_code; + } + + std::size_t node_id() const { + return node_id_; + } + std::size_t query_pos() const { + return query_pos_; + } + std::size_t history_pos() const { + return history_pos_; + } + StatusCode status_code() const { + return status_code_; + } + + const Vector<char> &key_buf() const { + return key_buf_; + } + const Vector<History> &history() const { + return history_; + } + + Vector<char> &key_buf() { + return key_buf_; + } + Vector<History> &history() { + return history_; + } + + void reset() { + status_code_ = MARISA_READY_TO_ALL; + } + + void lookup_init() { + node_id_ = 0; + query_pos_ = 0; + status_code_ = MARISA_READY_TO_ALL; + } + void reverse_lookup_init() { + key_buf_.resize(0); + key_buf_.reserve(32); + status_code_ = MARISA_READY_TO_ALL; + } + void common_prefix_search_init() { + node_id_ = 0; + query_pos_ = 0; + status_code_ = MARISA_READY_TO_COMMON_PREFIX_SEARCH; + } + void predictive_search_init() { + key_buf_.resize(0); + key_buf_.reserve(64); + history_.resize(0); + history_.reserve(4); + node_id_ = 0; + query_pos_ = 0; + history_pos_ = 0; + status_code_ = 
MARISA_READY_TO_PREDICTIVE_SEARCH; + } + + private: + Vector<char> key_buf_; + Vector<History> history_; + UInt32 node_id_; + UInt32 query_pos_; + UInt32 history_pos_; + StatusCode status_code_; + + // Disallows copy and assignment. + State(const State &); + State &operator=(const State &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_STATE_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc new file mode 100644 index 0000000000..6ec3652e1c --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc @@ -0,0 +1,218 @@ +#include "../algorithm.h" +#include "state.h" +#include "tail.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +Tail::Tail() : buf_(), end_flags_() {} + +void Tail::build(Vector<Entry> &entries, Vector<UInt32> *offsets, + TailMode mode) { + MARISA_THROW_IF(offsets == NULL, MARISA_NULL_ERROR); + + switch (mode) { + case MARISA_TEXT_TAIL: { + for (std::size_t i = 0; i < entries.size(); ++i) { + const char * const ptr = entries[i].ptr(); + const std::size_t length = entries[i].length(); + for (std::size_t j = 0; j < length; ++j) { + if (ptr[j] == '\0') { + mode = MARISA_BINARY_TAIL; + break; + } + } + if (mode == MARISA_BINARY_TAIL) { + break; + } + } + break; + } + case MARISA_BINARY_TAIL: { + break; + } + default: { + MARISA_THROW(MARISA_CODE_ERROR, "undefined tail mode"); + } + } + + Tail temp; + temp.build_(entries, offsets, mode); + swap(temp); +} + +void Tail::map(Mapper &mapper) { + Tail temp; + temp.map_(mapper); + swap(temp); +} + +void Tail::read(Reader &reader) { + Tail temp; + temp.read_(reader); + swap(temp); +} + +void Tail::write(Writer &writer) const { + write_(writer); +} + +void Tail::restore(Agent &agent, std::size_t offset) const { + MARISA_DEBUG_IF(buf_.empty(), MARISA_STATE_ERROR); + + State &state = agent.state(); + if (end_flags_.empty()) { + for (const char 
*ptr = &buf_[offset]; *ptr != '\0'; ++ptr) { + state.key_buf().push_back(*ptr); + } + } else { + do { + state.key_buf().push_back(buf_[offset]); + } while (!end_flags_[offset++]); + } +} + +bool Tail::match(Agent &agent, std::size_t offset) const { + MARISA_DEBUG_IF(buf_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(), + MARISA_BOUND_ERROR); + + State &state = agent.state(); + if (end_flags_.empty()) { + const char * const ptr = &buf_[offset] - state.query_pos(); + do { + if (ptr[state.query_pos()] != agent.query()[state.query_pos()]) { + return false; + } + state.set_query_pos(state.query_pos() + 1); + if (ptr[state.query_pos()] == '\0') { + return true; + } + } while (state.query_pos() < agent.query().length()); + return false; + } else { + do { + if (buf_[offset] != agent.query()[state.query_pos()]) { + return false; + } + state.set_query_pos(state.query_pos() + 1); + if (end_flags_[offset++]) { + return true; + } + } while (state.query_pos() < agent.query().length()); + return false; + } +} + +bool Tail::prefix_match(Agent &agent, std::size_t offset) const { + MARISA_DEBUG_IF(buf_.empty(), MARISA_STATE_ERROR); + + State &state = agent.state(); + if (end_flags_.empty()) { + const char *ptr = &buf_[offset] - state.query_pos(); + do { + if (ptr[state.query_pos()] != agent.query()[state.query_pos()]) { + return false; + } + state.key_buf().push_back(ptr[state.query_pos()]); + state.set_query_pos(state.query_pos() + 1); + if (ptr[state.query_pos()] == '\0') { + return true; + } + } while (state.query_pos() < agent.query().length()); + ptr += state.query_pos(); + do { + state.key_buf().push_back(*ptr); + } while (*++ptr != '\0'); + return true; + } else { + do { + if (buf_[offset] != agent.query()[state.query_pos()]) { + return false; + } + state.key_buf().push_back(buf_[offset]); + state.set_query_pos(state.query_pos() + 1); + if (end_flags_[offset++]) { + return true; + } + } while (state.query_pos() < 
agent.query().length()); + do { + state.key_buf().push_back(buf_[offset]); + } while (!end_flags_[offset++]); + return true; + } +} + +void Tail::clear() { + Tail().swap(*this); +} + +void Tail::swap(Tail &rhs) { + buf_.swap(rhs.buf_); + end_flags_.swap(rhs.end_flags_); +} + +void Tail::build_(Vector<Entry> &entries, Vector<UInt32> *offsets, + TailMode mode) { + for (std::size_t i = 0; i < entries.size(); ++i) { + entries[i].set_id(i); + } + Algorithm().sort(entries.begin(), entries.end()); + + Vector<UInt32> temp_offsets; + temp_offsets.resize(entries.size(), 0); + + const Entry dummy; + const Entry *last = &dummy; + for (std::size_t i = entries.size(); i > 0; --i) { + const Entry &current = entries[i - 1]; + MARISA_THROW_IF(current.length() == 0, MARISA_RANGE_ERROR); + std::size_t match = 0; + while ((match < current.length()) && (match < last->length()) && + ((*last)[match] == current[match])) { + ++match; + } + if ((match == current.length()) && (last->length() != 0)) { + temp_offsets[current.id()] = (UInt32)( + temp_offsets[last->id()] + (last->length() - match)); + } else { + temp_offsets[current.id()] = (UInt32)buf_.size(); + for (std::size_t j = 1; j <= current.length(); ++j) { + buf_.push_back(current[current.length() - j]); + } + if (mode == MARISA_TEXT_TAIL) { + buf_.push_back('\0'); + } else { + for (std::size_t j = 1; j < current.length(); ++j) { + end_flags_.push_back(false); + } + end_flags_.push_back(true); + } + MARISA_THROW_IF(buf_.size() > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + } + last = &current; + } + buf_.shrink(); + + offsets->swap(temp_offsets); +} + +void Tail::map_(Mapper &mapper) { + buf_.map(mapper); + end_flags_.map(mapper); +} + +void Tail::read_(Reader &reader) { + buf_.read(reader); + end_flags_.read(reader); +} + +void Tail::write_(Writer &writer) const { + buf_.write(writer); + end_flags_.write(writer); +} + +} // namespace trie +} // namespace grimoire +} // namespace marisa diff --git 
a/contrib/python/marisa-trie/marisa/grimoire/trie/tail.h b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.h new file mode 100644 index 0000000000..7e5ca1d3e7 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.h @@ -0,0 +1,73 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_TAIL_H_ +#define MARISA_GRIMOIRE_TRIE_TAIL_H_ + +#include "../../agent.h" +#include "../vector.h" +#include "entry.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Tail { + public: + Tail(); + + void build(Vector<Entry> &entries, Vector<UInt32> *offsets, + TailMode mode); + + void map(Mapper &mapper); + void read(Reader &reader); + void write(Writer &writer) const; + + void restore(Agent &agent, std::size_t offset) const; + bool match(Agent &agent, std::size_t offset) const; + bool prefix_match(Agent &agent, std::size_t offset) const; + + const char &operator[](std::size_t offset) const { + MARISA_DEBUG_IF(offset >= buf_.size(), MARISA_BOUND_ERROR); + return buf_[offset]; + } + + TailMode mode() const { + return end_flags_.empty() ? MARISA_TEXT_TAIL : MARISA_BINARY_TAIL; + } + + bool empty() const { + return buf_.empty(); + } + std::size_t size() const { + return buf_.size(); + } + std::size_t total_size() const { + return buf_.total_size() + end_flags_.total_size(); + } + std::size_t io_size() const { + return buf_.io_size() + end_flags_.io_size(); + } + + void clear(); + void swap(Tail &rhs); + + private: + Vector<char> buf_; + BitVector end_flags_; + + void build_(Vector<Entry> &entries, Vector<UInt32> *offsets, + TailMode mode); + + void map_(Mapper &mapper); + void read_(Reader &reader); + void write_(Writer &writer) const; + + // Disallows copy and assignment. + Tail(const Tail &); + Tail &operator=(const Tail &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_TAIL_H_ |