diff options
author | vitalyisaev <vitalyisaev@ydb.tech> | 2023-11-30 13:26:22 +0300 |
---|---|---|
committer | vitalyisaev <vitalyisaev@ydb.tech> | 2023-11-30 15:44:45 +0300 |
commit | 0a98fece5a9b54f16afeb3a94b3eb3105e9c3962 (patch) | |
tree | 291d72dbd7e9865399f668c84d11ed86fb190bbf /contrib/python/marisa-trie | |
parent | cb2c8d75065e5b3c47094067cb4aa407d4813298 (diff) | |
download | ydb-0a98fece5a9b54f16afeb3a94b3eb3105e9c3962.tar.gz |
YQ Connector:Use docker-compose in integrational tests
Diffstat (limited to 'contrib/python/marisa-trie')
54 files changed, 7225 insertions, 0 deletions
diff --git a/contrib/python/marisa-trie/agent.pxd b/contrib/python/marisa-trie/agent.pxd new file mode 100644 index 0000000000..bf019673c2 --- /dev/null +++ b/contrib/python/marisa-trie/agent.pxd @@ -0,0 +1,22 @@ +cimport query, key + +cdef extern from "<marisa/agent.h>" namespace "marisa" nogil: + cdef cppclass Agent: + Agent() except + + + query.Query &query() + key.Key &key() + + void set_query(char *str) + void set_query(char *ptr, int length) + void set_query(int key_id) + + void set_key(char *str) + void set_key(char *ptr, int length) + void set_key(int id) + + void clear() + + void init_state() + + void swap(Agent &rhs) diff --git a/contrib/python/marisa-trie/base.pxd b/contrib/python/marisa-trie/base.pxd new file mode 100644 index 0000000000..c434e82122 --- /dev/null +++ b/contrib/python/marisa-trie/base.pxd @@ -0,0 +1,63 @@ +cdef extern from "<marisa/base.h>": + + # A dictionary consists of 3 tries in default. Usually more tries make a + # dictionary space-efficient but time-inefficient. + ctypedef enum marisa_num_tries: + MARISA_MIN_NUM_TRIES + MARISA_MAX_NUM_TRIES + MARISA_DEFAULT_NUM_TRIES + + + # This library uses a cache technique to accelerate search functions. The + # following enumerated type marisa_cache_level gives a list of available cache + # size options. A larger cache enables faster search but takes a more space. + ctypedef enum marisa_cache_level: + MARISA_HUGE_CACHE + MARISA_LARGE_CACHE + MARISA_NORMAL_CACHE + MARISA_SMALL_CACHE + MARISA_TINY_CACHE + MARISA_DEFAULT_CACHE + + # This library provides 2 kinds of TAIL implementations. + ctypedef enum marisa_tail_mode: + # MARISA_TEXT_TAIL merges last labels as zero-terminated strings. So, it is + # available if and only if the last labels do not contain a NULL character. + # If MARISA_TEXT_TAIL is specified and a NULL character exists in the last + # labels, the setting is automatically switched to MARISA_BINARY_TAIL. 
+ MARISA_TEXT_TAIL + + # MARISA_BINARY_TAIL also merges last labels but as byte sequences. It uses + # a bit vector to detect the end of a sequence, instead of NULL characters. + # So, MARISA_BINARY_TAIL requires a larger space if the average length of + # labels is greater than 8. + MARISA_BINARY_TAIL + + MARISA_DEFAULT_TAIL + + # The arrangement of nodes affects the time cost of matching and the order of + # predictive search. + ctypedef enum marisa_node_order: + # MARISA_LABEL_ORDER arranges nodes in ascending label order. + # MARISA_LABEL_ORDER is useful if an application needs to predict keys in + # label order. + MARISA_LABEL_ORDER + + # MARISA_WEIGHT_ORDER arranges nodes in descending weight order. + # MARISA_WEIGHT_ORDER is generally a better choice because it enables faster + # matching. + MARISA_WEIGHT_ORDER + MARISA_DEFAULT_ORDER + + ctypedef enum marisa_config_mask: + MARISA_NUM_TRIES_MASK + MARISA_CACHE_LEVEL_MASK + MARISA_TAIL_MODE_MASK + MARISA_NODE_ORDER_MASK + MARISA_CONFIG_MASK + + +cdef extern from "<marisa/base.h>" namespace "marisa": + ctypedef marisa_cache_level CacheLevel + ctypedef marisa_tail_mode TailMode + ctypedef marisa_node_order NodeOrder diff --git a/contrib/python/marisa-trie/iostream.pxd b/contrib/python/marisa-trie/iostream.pxd new file mode 100644 index 0000000000..435ee85bb0 --- /dev/null +++ b/contrib/python/marisa-trie/iostream.pxd @@ -0,0 +1,7 @@ +from std_iostream cimport istream, ostream +from trie cimport Trie + +cdef extern from "<marisa/iostream.h>" namespace "marisa" nogil: + + istream &read(istream &stream, Trie *trie) + ostream &write(ostream &stream, Trie &trie) diff --git a/contrib/python/marisa-trie/key.pxd b/contrib/python/marisa-trie/key.pxd new file mode 100644 index 0000000000..d99dee5e04 --- /dev/null +++ b/contrib/python/marisa-trie/key.pxd @@ -0,0 +1,22 @@ +cdef extern from "<marisa/key.h>" namespace "marisa" nogil: + + cdef cppclass Key: + Key() + Key(Key &query) + + #Key &operator=(Key &query) + + char 
operator[](int i) + + void set_str(char *str) + void set_str(char *ptr, int length) + void set_id(int id) + void set_weight(float weight) + + char *ptr() + int length() + int id() + float weight() + + void clear() + void swap(Key &rhs) diff --git a/contrib/python/marisa-trie/keyset.pxd b/contrib/python/marisa-trie/keyset.pxd new file mode 100644 index 0000000000..1fb99a40c5 --- /dev/null +++ b/contrib/python/marisa-trie/keyset.pxd @@ -0,0 +1,30 @@ +cimport key + +cdef extern from "<marisa/keyset.h>" namespace "marisa" nogil: + cdef cppclass Keyset: + +# cdef enum constants: +# BASE_BLOCK_SIZE = 4096 +# EXTRA_BLOCK_SIZE = 1024 +# KEY_BLOCK_SIZE = 256 + + Keyset() + + void push_back(key.Key &key) + void push_back(key.Key &key, char end_marker) + + void push_back(char *str) + void push_back(char *ptr, int length) + void push_back(char *ptr, int length, float weight) + + key.Key &operator[](int i) + + int num_keys() + bint empty() + + int size() + int total_length() + + void reset() + void clear() + void swap(Keyset &rhs) diff --git a/contrib/python/marisa-trie/marisa/agent.cc b/contrib/python/marisa-trie/marisa/agent.cc new file mode 100644 index 0000000000..7f7f49f1bc --- /dev/null +++ b/contrib/python/marisa-trie/marisa/agent.cc @@ -0,0 +1,51 @@ +#include <new> + +#include "agent.h" +#include "grimoire/trie.h" + +namespace marisa { + +Agent::Agent() : query_(), key_(), state_() {} + +Agent::~Agent() {} + +void Agent::set_query(const char *str) { + MARISA_THROW_IF(str == NULL, MARISA_NULL_ERROR); + if (state_.get() != NULL) { + state_->reset(); + } + query_.set_str(str); +} + +void Agent::set_query(const char *ptr, std::size_t length) { + MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + if (state_.get() != NULL) { + state_->reset(); + } + query_.set_str(ptr, length); +} + +void Agent::set_query(std::size_t key_id) { + if (state_.get() != NULL) { + state_->reset(); + } + query_.set_id(key_id); +} + +void Agent::init_state() { + 
MARISA_THROW_IF(state_.get() != NULL, MARISA_STATE_ERROR); + state_.reset(new (std::nothrow) grimoire::State); + MARISA_THROW_IF(state_.get() == NULL, MARISA_MEMORY_ERROR); +} + +void Agent::clear() { + Agent().swap(*this); +} + +void Agent::swap(Agent &rhs) { + query_.swap(rhs.query_); + key_.swap(rhs.key_); + state_.swap(rhs.state_); +} + +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/agent.h b/contrib/python/marisa-trie/marisa/agent.h new file mode 100644 index 0000000000..0f89f7df0f --- /dev/null +++ b/contrib/python/marisa-trie/marisa/agent.h @@ -0,0 +1,75 @@ +#pragma once + +#ifndef MARISA_AGENT_H_ +#define MARISA_AGENT_H_ + +#include "key.h" +#include "query.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class State; + +} // namespace trie +} // namespace grimoire + +class Agent { + public: + Agent(); + ~Agent(); + + const Query &query() const { + return query_; + } + const Key &key() const { + return key_; + } + + void set_query(const char *str); + void set_query(const char *ptr, std::size_t length); + void set_query(std::size_t key_id); + + const grimoire::trie::State &state() const { + return *state_; + } + grimoire::trie::State &state() { + return *state_; + } + + void set_key(const char *str) { + MARISA_DEBUG_IF(str == NULL, MARISA_NULL_ERROR); + key_.set_str(str); + } + void set_key(const char *ptr, std::size_t length) { + MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + key_.set_str(ptr, length); + } + void set_key(std::size_t id) { + MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + key_.set_id(id); + } + + bool has_state() const { + return state_.get() != NULL; + } + void init_state(); + + void clear(); + void swap(Agent &rhs); + + private: + Query query_; + Key key_; + scoped_ptr<grimoire::trie::State> state_; + + // Disallows copy and assignment. 
+ Agent(const Agent &); + Agent &operator=(const Agent &); +}; + +} // namespace marisa + +#endif // MARISA_AGENT_H_ diff --git a/contrib/python/marisa-trie/marisa/base.h b/contrib/python/marisa-trie/marisa/base.h new file mode 100644 index 0000000000..5c595dcd2b --- /dev/null +++ b/contrib/python/marisa-trie/marisa/base.h @@ -0,0 +1,196 @@ +#pragma once + +#ifndef MARISA_BASE_H_ +#define MARISA_BASE_H_ + +// Old Visual C++ does not provide stdint.h. +#ifndef _MSC_VER + #include <stdint.h> +#endif // _MSC_VER + +#ifdef __cplusplus + #include <cstddef> +#else // __cplusplus + #include <stddef.h> +#endif // __cplusplus + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#ifdef _MSC_VER +typedef unsigned __int8 marisa_uint8; +typedef unsigned __int16 marisa_uint16; +typedef unsigned __int32 marisa_uint32; +typedef unsigned __int64 marisa_uint64; +#else // _MSC_VER +typedef uint8_t marisa_uint8; +typedef uint16_t marisa_uint16; +typedef uint32_t marisa_uint32; +typedef uint64_t marisa_uint64; +#endif // _MSC_VER + +#if defined(_WIN64) || defined(__amd64__) || defined(__x86_64__) || \ + defined(__ia64__) || defined(__ppc64__) || defined(__powerpc64__) || \ + defined(__sparc64__) || defined(__mips64__) || defined(__aarch64__) || \ + defined(__s390x__) + #define MARISA_WORD_SIZE 64 +#else // defined(_WIN64), etc. + #define MARISA_WORD_SIZE 32 +#endif // defined(_WIN64), etc. + +//#define MARISA_WORD_SIZE (sizeof(void *) * 8) + +#define MARISA_UINT8_MAX ((marisa_uint8)~(marisa_uint8)0) +#define MARISA_UINT16_MAX ((marisa_uint16)~(marisa_uint16)0) +#define MARISA_UINT32_MAX ((marisa_uint32)~(marisa_uint32)0) +#define MARISA_UINT64_MAX ((marisa_uint64)~(marisa_uint64)0) +#define MARISA_SIZE_MAX ((size_t)~(size_t)0) + +#define MARISA_INVALID_LINK_ID MARISA_UINT32_MAX +#define MARISA_INVALID_KEY_ID MARISA_UINT32_MAX +#define MARISA_INVALID_EXTRA (MARISA_UINT32_MAX >> 8) + +// Error codes are defined as members of marisa_error_code. 
This library throws +// an exception with one of the error codes when an error occurs. +typedef enum marisa_error_code_ { + // MARISA_OK means that a requested operation has succeeded. In practice, an + // exception never has MARISA_OK because it is not an error. + MARISA_OK = 0, + + // MARISA_STATE_ERROR means that an object was not ready for a requested + // operation. For example, an operation to modify a fixed vector throws an + // exception with MARISA_STATE_ERROR. + MARISA_STATE_ERROR = 1, + + // MARISA_NULL_ERROR means that an invalid NULL pointer has been given. + MARISA_NULL_ERROR = 2, + + // MARISA_BOUND_ERROR means that an operation has tried to access an out of + // range address. + MARISA_BOUND_ERROR = 3, + + // MARISA_RANGE_ERROR means that an out of range value has appeared in + // operation. + MARISA_RANGE_ERROR = 4, + + // MARISA_CODE_ERROR means that an undefined code has appeared in operation. + MARISA_CODE_ERROR = 5, + + // MARISA_RESET_ERROR means that a smart pointer has tried to reset itself. + MARISA_RESET_ERROR = 6, + + // MARISA_SIZE_ERROR means that a size has exceeded a library limitation. + MARISA_SIZE_ERROR = 7, + + // MARISA_MEMORY_ERROR means that a memory allocation has failed. + MARISA_MEMORY_ERROR = 8, + + // MARISA_IO_ERROR means that an I/O operation has failed. + MARISA_IO_ERROR = 9, + + // MARISA_FORMAT_ERROR means that input was in invalid format. + MARISA_FORMAT_ERROR = 10, +} marisa_error_code; + +// Min/max values, flags and masks for dictionary settings are defined below. +// Please note that unspecified settings will be replaced with the default +// settings. For example, 0 is equivalent to (MARISA_DEFAULT_NUM_TRIES | +// MARISA_DEFAULT_CACHE | MARISA_DEFAULT_TAIL | MARISA_DEFAULT_ORDER). + +// A dictionary consists of 3 tries by default. Usually more tries make a +// dictionary space-efficient but time-inefficient. 
+typedef enum marisa_num_tries_ { + MARISA_MIN_NUM_TRIES = 0x00001, + MARISA_MAX_NUM_TRIES = 0x0007F, + MARISA_DEFAULT_NUM_TRIES = 0x00003, +} marisa_num_tries; + +// This library uses a cache technique to accelerate search functions. The +// following enumerated type marisa_cache_level gives a list of available cache +// size options. A larger cache enables faster search but takes more space. +typedef enum marisa_cache_level_ { + MARISA_HUGE_CACHE = 0x00080, + MARISA_LARGE_CACHE = 0x00100, + MARISA_NORMAL_CACHE = 0x00200, + MARISA_SMALL_CACHE = 0x00400, + MARISA_TINY_CACHE = 0x00800, + MARISA_DEFAULT_CACHE = MARISA_NORMAL_CACHE +} marisa_cache_level; + +// This library provides 2 kinds of TAIL implementations. +typedef enum marisa_tail_mode_ { + // MARISA_TEXT_TAIL merges last labels as zero-terminated strings. So, it is + // available if and only if the last labels do not contain a NULL character. + // If MARISA_TEXT_TAIL is specified and a NULL character exists in the last + // labels, the setting is automatically switched to MARISA_BINARY_TAIL. + MARISA_TEXT_TAIL = 0x01000, + + // MARISA_BINARY_TAIL also merges last labels but as byte sequences. It uses + // a bit vector to detect the end of a sequence, instead of NULL characters. + // So, MARISA_BINARY_TAIL requires a larger space if the average length of + // labels is greater than 8. + MARISA_BINARY_TAIL = 0x02000, + + MARISA_DEFAULT_TAIL = MARISA_TEXT_TAIL, +} marisa_tail_mode; + +// The arrangement of nodes affects the time cost of matching and the order of +// predictive search. +typedef enum marisa_node_order_ { + // MARISA_LABEL_ORDER arranges nodes in ascending label order. + // MARISA_LABEL_ORDER is useful if an application needs to predict keys in + // label order. + MARISA_LABEL_ORDER = 0x10000, + + // MARISA_WEIGHT_ORDER arranges nodes in descending weight order. + // MARISA_WEIGHT_ORDER is generally a better choice because it enables faster + // matching. 
+ MARISA_WEIGHT_ORDER = 0x20000, + + MARISA_DEFAULT_ORDER = MARISA_WEIGHT_ORDER, +} marisa_node_order; + +typedef enum marisa_config_mask_ { + MARISA_NUM_TRIES_MASK = 0x0007F, + MARISA_CACHE_LEVEL_MASK = 0x00F80, + MARISA_TAIL_MODE_MASK = 0x0F000, + MARISA_NODE_ORDER_MASK = 0xF0000, + MARISA_CONFIG_MASK = 0xFFFFF +} marisa_config_mask; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#ifdef __cplusplus +namespace marisa { + +typedef ::marisa_uint8 UInt8; +typedef ::marisa_uint16 UInt16; +typedef ::marisa_uint32 UInt32; +typedef ::marisa_uint64 UInt64; + +typedef ::marisa_error_code ErrorCode; + +typedef ::marisa_cache_level CacheLevel; +typedef ::marisa_tail_mode TailMode; +typedef ::marisa_node_order NodeOrder; + +template <typename T> +inline void swap(T &lhs, T &rhs) { + T temp = lhs; + lhs = rhs; + rhs = temp; +} + +} // namespace marisa +#endif // __cplusplus + +#ifdef __cplusplus + #include "exception.h" + #include "scoped-ptr.h" + #include "scoped-array.h" +#endif // __cplusplus + +#endif // MARISA_BASE_H_ diff --git a/contrib/python/marisa-trie/marisa/exception.h b/contrib/python/marisa-trie/marisa/exception.h new file mode 100644 index 0000000000..630936b23b --- /dev/null +++ b/contrib/python/marisa-trie/marisa/exception.h @@ -0,0 +1,84 @@ +#pragma once + +#ifndef MARISA_EXCEPTION_H_ +#define MARISA_EXCEPTION_H_ + +#include <exception> + +#include "base.h" + +namespace marisa { + +// An exception object keeps a filename, a line number, an error code and an +// error message. 
The message format is as follows: +// "__FILE__:__LINE__: error_code: error_message" +class Exception : public std::exception { + public: + Exception(const char *filename, int line, + ErrorCode error_code, const char *error_message) + : std::exception(), filename_(filename), line_(line), + error_code_(error_code), error_message_(error_message) {} + Exception(const Exception &ex) + : std::exception(), filename_(ex.filename_), line_(ex.line_), + error_code_(ex.error_code_), error_message_(ex.error_message_) {} + virtual ~Exception() {} + + Exception &operator=(const Exception &rhs) { + filename_ = rhs.filename_; + line_ = rhs.line_; + error_code_ = rhs.error_code_; + error_message_ = rhs.error_message_; + return *this; + } + + const char *filename() const { + return filename_; + } + int line() const { + return line_; + } + ErrorCode error_code() const { + return error_code_; + } + const char *error_message() const { + return error_message_; + } + + virtual const char *what() const noexcept { + return error_message_; + } + + private: + const char *filename_; + int line_; + ErrorCode error_code_; + const char *error_message_; +}; + +// These macros are used to convert a line number to a string constant. +#define MARISA_INT_TO_STR(value) #value +#define MARISA_LINE_TO_STR(line) MARISA_INT_TO_STR(line) +#define MARISA_LINE_STR MARISA_LINE_TO_STR(__LINE__) + +// MARISA_THROW throws an exception with a filename, a line number, an error +// code and an error message. The message format is as follows: +// "__FILE__:__LINE__: error_code: error_message" +#define MARISA_THROW(error_code, error_message) \ + (throw marisa::Exception(__FILE__, __LINE__, error_code, \ + __FILE__ ":" MARISA_LINE_STR ": " #error_code ": " error_message)) + +// MARISA_THROW_IF throws an exception if `condition' is true. +#define MARISA_THROW_IF(condition, error_code) \ + (void)((!(condition)) || (MARISA_THROW(error_code, #condition), 0)) + +// MARISA_DEBUG_IF is ignored if _DEBUG is undefined. 
So, it is useful for +// debugging time-critical codes. +#ifdef _DEBUG + #define MARISA_DEBUG_IF(cond, error_code) MARISA_THROW_IF(cond, error_code) +#else + #define MARISA_DEBUG_IF(cond, error_code) +#endif + +} // namespace marisa + +#endif // MARISA_EXCEPTION_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/algorithm.h b/contrib/python/marisa-trie/marisa/grimoire/algorithm.h new file mode 100644 index 0000000000..71baec34ac --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/algorithm.h @@ -0,0 +1,27 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_ALGORITHM_H_ +#define MARISA_GRIMOIRE_ALGORITHM_H_ + +#include "algorithm/sort.h" + +namespace marisa { +namespace grimoire { + +class Algorithm { + public: + Algorithm() {} + + template <typename Iterator> + std::size_t sort(Iterator begin, Iterator end) const { + return algorithm::sort(begin, end); + } + + private: + Algorithm(const Algorithm &); + Algorithm &operator=(const Algorithm &); +}; + +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_ALGORITHM_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/algorithm/sort.h b/contrib/python/marisa-trie/marisa/grimoire/algorithm/sort.h new file mode 100644 index 0000000000..9090336ce6 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/algorithm/sort.h @@ -0,0 +1,197 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_ALGORITHM_SORT_H_ +#define MARISA_GRIMOIRE_ALGORITHM_SORT_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace algorithm { +namespace details { + +enum { + MARISA_INSERTION_SORT_THRESHOLD = 10 +}; + +template <typename T> +int get_label(const T &unit, std::size_t depth) { + MARISA_DEBUG_IF(depth > unit.length(), MARISA_BOUND_ERROR); + + return (depth < unit.length()) ? 
(int)(UInt8)unit[depth] : -1; +} + +template <typename T> +int median(const T &a, const T &b, const T &c, std::size_t depth) { + const int x = get_label(a, depth); + const int y = get_label(b, depth); + const int z = get_label(c, depth); + if (x < y) { + if (y < z) { + return y; + } else if (x < z) { + return z; + } + return x; + } else if (x < z) { + return x; + } else if (y < z) { + return z; + } + return y; +} + +template <typename T> +int compare(const T &lhs, const T &rhs, std::size_t depth) { + for (std::size_t i = depth; i < lhs.length(); ++i) { + if (i == rhs.length()) { + return 1; + } + if (lhs[i] != rhs[i]) { + return (UInt8)lhs[i] - (UInt8)rhs[i]; + } + } + if (lhs.length() == rhs.length()) { + return 0; + } + return (lhs.length() < rhs.length()) ? -1 : 1; +} + +template <typename Iterator> +std::size_t insertion_sort(Iterator l, Iterator r, std::size_t depth) { + MARISA_DEBUG_IF(l > r, MARISA_BOUND_ERROR); + + std::size_t count = 1; + for (Iterator i = l + 1; i < r; ++i) { + int result = 0; + for (Iterator j = i; j > l; --j) { + result = compare(*(j - 1), *j, depth); + if (result <= 0) { + break; + } + marisa::swap(*(j - 1), *j); + } + if (result != 0) { + ++count; + } + } + return count; +} + +template <typename Iterator> +std::size_t sort(Iterator l, Iterator r, std::size_t depth) { + MARISA_DEBUG_IF(l > r, MARISA_BOUND_ERROR); + + std::size_t count = 0; + while ((r - l) > MARISA_INSERTION_SORT_THRESHOLD) { + Iterator pl = l; + Iterator pr = r; + Iterator pivot_l = l; + Iterator pivot_r = r; + + const int pivot = median(*l, *(l + (r - l) / 2), *(r - 1), depth); + for ( ; ; ) { + while (pl < pr) { + const int label = get_label(*pl, depth); + if (label > pivot) { + break; + } else if (label == pivot) { + marisa::swap(*pl, *pivot_l); + ++pivot_l; + } + ++pl; + } + while (pl < pr) { + const int label = get_label(*--pr, depth); + if (label < pivot) { + break; + } else if (label == pivot) { + marisa::swap(*pr, *--pivot_r); + } + } + if (pl >= pr) { + 
break; + } + marisa::swap(*pl, *pr); + ++pl; + } + while (pivot_l > l) { + marisa::swap(*--pivot_l, *--pl); + } + while (pivot_r < r) { + marisa::swap(*pivot_r, *pr); + ++pivot_r; + ++pr; + } + + if (((pl - l) > (pr - pl)) || ((r - pr) > (pr - pl))) { + if ((pr - pl) == 1) { + ++count; + } else if ((pr - pl) > 1) { + if (pivot == -1) { + ++count; + } else { + count += sort(pl, pr, depth + 1); + } + } + + if ((pl - l) < (r - pr)) { + if ((pl - l) == 1) { + ++count; + } else if ((pl - l) > 1) { + count += sort(l, pl, depth); + } + l = pr; + } else { + if ((r - pr) == 1) { + ++count; + } else if ((r - pr) > 1) { + count += sort(pr, r, depth); + } + r = pl; + } + } else { + if ((pl - l) == 1) { + ++count; + } else if ((pl - l) > 1) { + count += sort(l, pl, depth); + } + + if ((r - pr) == 1) { + ++count; + } else if ((r - pr) > 1) { + count += sort(pr, r, depth); + } + + l = pl, r = pr; + if ((pr - pl) == 1) { + ++count; + } else if ((pr - pl) > 1) { + if (pivot == -1) { + l = r; + ++count; + } else { + ++depth; + } + } + } + } + + if ((r - l) > 1) { + count += insertion_sort(l, r, depth); + } + return count; +} + +} // namespace details + +template <typename Iterator> +std::size_t sort(Iterator begin, Iterator end) { + MARISA_DEBUG_IF(begin > end, MARISA_BOUND_ERROR); + return details::sort(begin, end, 0); +}; + +} // namespace algorithm +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_ALGORITHM_SORT_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/intrin.h b/contrib/python/marisa-trie/marisa/grimoire/intrin.h new file mode 100644 index 0000000000..16843b353c --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/intrin.h @@ -0,0 +1,116 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_INTRIN_H_ +#define MARISA_GRIMOIRE_INTRIN_H_ + +#include "../base.h" + +#if defined(__x86_64__) || defined(_M_X64) + #define MARISA_X64 +#elif defined(__i386__) || defined(_M_IX86) + #define MARISA_X86 +#else // defined(__i386__) || defined(_M_IX86) 
+ #ifdef MARISA_USE_POPCNT + #undef MARISA_USE_POPCNT + #endif // MARISA_USE_POPCNT + #ifdef MARISA_USE_SSE4A + #undef MARISA_USE_SSE4A + #endif // MARISA_USE_SSE4A + #ifdef MARISA_USE_SSE4 + #undef MARISA_USE_SSE4 + #endif // MARISA_USE_SSE4 + #ifdef MARISA_USE_SSE4_2 + #undef MARISA_USE_SSE4_2 + #endif // MARISA_USE_SSE4_2 + #ifdef MARISA_USE_SSE4_1 + #undef MARISA_USE_SSE4_1 + #endif // MARISA_USE_SSE4_1 + #ifdef MARISA_USE_SSSE3 + #undef MARISA_USE_SSSE3 + #endif // MARISA_USE_SSSE3 + #ifdef MARISA_USE_SSE3 + #undef MARISA_USE_SSE3 + #endif // MARISA_USE_SSE3 + #ifdef MARISA_USE_SSE2 + #undef MARISA_USE_SSE2 + #endif // MARISA_USE_SSE2 +#endif // defined(__i386__) || defined(_M_IX86) + +#ifdef MARISA_USE_POPCNT + #ifndef MARISA_USE_SSE3 + #define MARISA_USE_SSE3 + #endif // MARISA_USE_SSE3 + #ifdef _MSC_VER + #include <intrin.h> + #else // _MSC_VER + #include <popcntintrin.h> + #endif // _MSC_VER +#endif // MARISA_USE_POPCNT + +#ifdef MARISA_USE_SSE4A + #ifndef MARISA_USE_SSE3 + #define MARISA_USE_SSE3 + #endif // MARISA_USE_SSE3 + #ifndef MARISA_USE_POPCNT + #define MARISA_USE_POPCNT + #endif // MARISA_USE_POPCNT +#endif // MARISA_USE_SSE4A + +#ifdef MARISA_USE_SSE4 + #ifndef MARISA_USE_SSE4_2 + #define MARISA_USE_SSE4_2 + #endif // MARISA_USE_SSE4_2 +#endif // MARISA_USE_SSE4 + +#ifdef MARISA_USE_SSE4_2 + #ifndef MARISA_USE_SSE4_1 + #define MARISA_USE_SSE4_1 + #endif // MARISA_USE_SSE4_1 + #ifndef MARISA_USE_POPCNT + #define MARISA_USE_POPCNT + #endif // MARISA_USE_POPCNT +#endif // MARISA_USE_SSE4_2 + +#ifdef MARISA_USE_SSE4_1 + #ifndef MARISA_USE_SSSE3 + #define MARISA_USE_SSSE3 + #endif // MARISA_USE_SSSE3 +#endif // MARISA_USE_SSE4_1 + +#ifdef MARISA_USE_SSSE3 + #ifndef MARISA_USE_SSE3 + #define MARISA_USE_SSE3 + #endif // MARISA_USE_SSE3 + #ifdef MARISA_X64 + #define MARISA_X64_SSSE3 + #else // MARISA_X64 + #define MARISA_X86_SSSE3 + #endif // MARISA_X64 + #include <tmmintrin.h> +#endif // MARISA_USE_SSSE3 + +#ifdef MARISA_USE_SSE3 + #ifndef 
MARISA_USE_SSE2 + #define MARISA_USE_SSE2 + #endif // MARISA_USE_SSE2 +#endif // MARISA_USE_SSE3 + +#ifdef MARISA_USE_SSE2 + #ifdef MARISA_X64 + #define MARISA_X64_SSE2 + #else // MARISA_X64 + #define MARISA_X86_SSE2 + #endif // MARISA_X64 + #include <emmintrin.h> +#endif // MARISA_USE_SSE2 + +#ifdef _MSC_VER + #if MARISA_WORD_SIZE == 64 + #include <intrin.h> + #pragma intrinsic(_BitScanForward64) + #else // MARISA_WORD_SIZE == 64 + #include <intrin.h> + #pragma intrinsic(_BitScanForward) + #endif // MARISA_WORD_SIZE == 64 +#endif // _MSC_VER + +#endif // MARISA_GRIMOIRE_INTRIN_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/io.h b/contrib/python/marisa-trie/marisa/grimoire/io.h new file mode 100644 index 0000000000..4de0110dbb --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/io.h @@ -0,0 +1,19 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_IO_H_ +#define MARISA_GRIMOIRE_IO_H_ + +#include "io/mapper.h" +#include "io/reader.h" +#include "io/writer.h" + +namespace marisa { +namespace grimoire { + +using io::Mapper; +using io::Reader; +using io::Writer; + +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_IO_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/io/mapper.cc b/contrib/python/marisa-trie/marisa/grimoire/io/mapper.cc new file mode 100644 index 0000000000..9ed6ffc755 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/io/mapper.cc @@ -0,0 +1,163 @@ +#if (defined _WIN32) || (defined _WIN64) + #include <sys/types.h> + #include <sys/stat.h> + #include <windows.h> +#else // (defined _WIN32) || (defined _WIN64) + #include <sys/mman.h> + #include <sys/stat.h> + #include <sys/types.h> + #include <fcntl.h> + #include <unistd.h> +#endif // (defined _WIN32) || (defined _WIN64) + +#include "mapper.h" + +namespace marisa { +namespace grimoire { +namespace io { + +#if (defined _WIN32) || (defined _WIN64) +Mapper::Mapper() + : ptr_(NULL), origin_(NULL), avail_(0), size_(0), + file_(NULL), map_(NULL) {} 
+#else // (defined _WIN32) || (defined _WIN64) +Mapper::Mapper() + : ptr_(NULL), origin_(MAP_FAILED), avail_(0), size_(0), fd_(-1) {} +#endif // (defined _WIN32) || (defined _WIN64) + +#if (defined _WIN32) || (defined _WIN64) +Mapper::~Mapper() { + if (origin_ != NULL) { + ::UnmapViewOfFile(origin_); + } + + if (map_ != NULL) { + ::CloseHandle(map_); + } + + if (file_ != NULL) { + ::CloseHandle(file_); + } +} +#else // (defined _WIN32) || (defined _WIN64) +Mapper::~Mapper() { + if (origin_ != MAP_FAILED) { + ::munmap(origin_, size_); + } + + if (fd_ != -1) { + ::close(fd_); + } +} +#endif // (defined _WIN32) || (defined _WIN64) + +void Mapper::open(const char *filename) { + MARISA_THROW_IF(filename == NULL, MARISA_NULL_ERROR); + + Mapper temp; + temp.open_(filename); + swap(temp); +} + +void Mapper::open(const void *ptr, std::size_t size) { + MARISA_THROW_IF((ptr == NULL) && (size != 0), MARISA_NULL_ERROR); + + Mapper temp; + temp.open_(ptr, size); + swap(temp); +} + +void Mapper::seek(std::size_t size) { + MARISA_THROW_IF(!is_open(), MARISA_STATE_ERROR); + MARISA_THROW_IF(size > avail_, MARISA_IO_ERROR); + + map_data(size); +} + +bool Mapper::is_open() const { + return ptr_ != NULL; +} + +void Mapper::clear() { + Mapper().swap(*this); +} + +void Mapper::swap(Mapper &rhs) { + marisa::swap(ptr_, rhs.ptr_); + marisa::swap(avail_, rhs.avail_); + marisa::swap(origin_, rhs.origin_); + marisa::swap(size_, rhs.size_); +#if (defined _WIN32) || (defined _WIN64) + marisa::swap(file_, rhs.file_); + marisa::swap(map_, rhs.map_); +#else // (defined _WIN32) || (defined _WIN64) + marisa::swap(fd_, rhs.fd_); +#endif // (defined _WIN32) || (defined _WIN64) +} + +const void *Mapper::map_data(std::size_t size) { + MARISA_THROW_IF(!is_open(), MARISA_STATE_ERROR); + MARISA_THROW_IF(size > avail_, MARISA_IO_ERROR); + + const char * const data = static_cast<const char *>(ptr_); + ptr_ = data + size; + avail_ -= size; + return data; +} + +#if (defined _WIN32) || (defined _WIN64) + #ifdef 
__MSVCRT_VERSION__ + #if __MSVCRT_VERSION__ >= 0x0601 + #define MARISA_HAS_STAT64 + #endif // __MSVCRT_VERSION__ >= 0x0601 + #endif // __MSVCRT_VERSION__ +void Mapper::open_(const char *filename) { + #ifdef MARISA_HAS_STAT64 + struct __stat64 st; + MARISA_THROW_IF(::_stat64(filename, &st) != 0, MARISA_IO_ERROR); + #else // MARISA_HAS_STAT64 + struct _stat st; + MARISA_THROW_IF(::_stat(filename, &st) != 0, MARISA_IO_ERROR); + #endif // MARISA_HAS_STAT64 + MARISA_THROW_IF((UInt64)st.st_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + size_ = (std::size_t)st.st_size; + + file_ = ::CreateFileA(filename, GENERIC_READ, FILE_SHARE_READ, + NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); + MARISA_THROW_IF(file_ == INVALID_HANDLE_VALUE, MARISA_IO_ERROR); + + map_ = ::CreateFileMapping(file_, NULL, PAGE_READONLY, 0, 0, NULL); + MARISA_THROW_IF(map_ == NULL, MARISA_IO_ERROR); + + origin_ = ::MapViewOfFile(map_, FILE_MAP_READ, 0, 0, 0); + MARISA_THROW_IF(origin_ == NULL, MARISA_IO_ERROR); + + ptr_ = static_cast<const char *>(origin_); + avail_ = size_; +} +#else // (defined _WIN32) || (defined _WIN64) +void Mapper::open_(const char *filename) { + struct stat st; + MARISA_THROW_IF(::stat(filename, &st) != 0, MARISA_IO_ERROR); + MARISA_THROW_IF((UInt64)st.st_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + size_ = (std::size_t)st.st_size; + + fd_ = ::open(filename, O_RDONLY); + MARISA_THROW_IF(fd_ == -1, MARISA_IO_ERROR); + + origin_ = ::mmap(NULL, size_, PROT_READ, MAP_SHARED, fd_, 0); + MARISA_THROW_IF(origin_ == MAP_FAILED, MARISA_IO_ERROR); + + ptr_ = static_cast<const char *>(origin_); + avail_ = size_; +} +#endif // (defined _WIN32) || (defined _WIN64) + +void Mapper::open_(const void *ptr, std::size_t size) { + ptr_ = ptr; + avail_ = size; +} + +} // namespace io +} // namespace grimoire +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/grimoire/io/mapper.h b/contrib/python/marisa-trie/marisa/grimoire/io/mapper.h new file mode 100644 index 
0000000000..e06072501d --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/io/mapper.h @@ -0,0 +1,68 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_IO_MAPPER_H_ +#define MARISA_GRIMOIRE_IO_MAPPER_H_ + +#include <cstdio> + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace io { + +class Mapper { + public: + Mapper(); + ~Mapper(); + + void open(const char *filename); + void open(const void *ptr, std::size_t size); + + template <typename T> + void map(T *obj) { + MARISA_THROW_IF(obj == NULL, MARISA_NULL_ERROR); + *obj = *static_cast<const T *>(map_data(sizeof(T))); + } + + template <typename T> + void map(const T **objs, std::size_t num_objs) { + MARISA_THROW_IF((objs == NULL) && (num_objs != 0), MARISA_NULL_ERROR); + MARISA_THROW_IF(num_objs > (MARISA_SIZE_MAX / sizeof(T)), + MARISA_SIZE_ERROR); + *objs = static_cast<const T *>(map_data(sizeof(T) * num_objs)); + } + + void seek(std::size_t size); + + bool is_open() const; + + void clear(); + void swap(Mapper &rhs); + + private: + const void *ptr_; + void *origin_; + std::size_t avail_; + std::size_t size_; +#if (defined _WIN32) || (defined _WIN64) + void *file_; + void *map_; +#else // (defined _WIN32) || (defined _WIN64) + int fd_; +#endif // (defined _WIN32) || (defined _WIN64) + + void open_(const char *filename); + void open_(const void *ptr, std::size_t size); + + const void *map_data(std::size_t size); + + // Disallows copy and assignment. 
+ Mapper(const Mapper &); + Mapper &operator=(const Mapper &); +}; + +} // namespace io +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_IO_MAPPER_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/io/reader.cc b/contrib/python/marisa-trie/marisa/grimoire/io/reader.cc new file mode 100644 index 0000000000..cb22fcbd4a --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/io/reader.cc @@ -0,0 +1,147 @@ +#include <stdio.h> + +#ifdef _WIN32 + #include <io.h> +#else // _WIN32 + #include <unistd.h> +#endif // _WIN32 + +#include <limits> + +#include "reader.h" + +namespace marisa { +namespace grimoire { +namespace io { + +Reader::Reader() + : file_(NULL), fd_(-1), stream_(NULL), needs_fclose_(false) {} + +Reader::~Reader() { + if (needs_fclose_) { + ::fclose(file_); + } +} + +void Reader::open(const char *filename) { + MARISA_THROW_IF(filename == NULL, MARISA_NULL_ERROR); + + Reader temp; + temp.open_(filename); + swap(temp); +} + +void Reader::open(std::FILE *file) { + MARISA_THROW_IF(file == NULL, MARISA_NULL_ERROR); + + Reader temp; + temp.open_(file); + swap(temp); +} + +void Reader::open(int fd) { + MARISA_THROW_IF(fd == -1, MARISA_CODE_ERROR); + + Reader temp; + temp.open_(fd); + swap(temp); +} + +void Reader::open(std::istream &stream) { + Reader temp; + temp.open_(stream); + swap(temp); +} + +void Reader::clear() { + Reader().swap(*this); +} + +void Reader::swap(Reader &rhs) { + marisa::swap(file_, rhs.file_); + marisa::swap(fd_, rhs.fd_); + marisa::swap(stream_, rhs.stream_); + marisa::swap(needs_fclose_, rhs.needs_fclose_); +} + +void Reader::seek(std::size_t size) { + MARISA_THROW_IF(!is_open(), MARISA_STATE_ERROR); + if (size == 0) { + return; + } else if (size <= 16) { + char buf[16]; + read_data(buf, size); + } else { + char buf[1024]; + while (size != 0) { + const std::size_t count = (size < sizeof(buf)) ? 
size : sizeof(buf); + read_data(buf, count); + size -= count; + } + } +} + +bool Reader::is_open() const { + return (file_ != NULL) || (fd_ != -1) || (stream_ != NULL); +} + +void Reader::open_(const char *filename) { + std::FILE *file = NULL; +#ifdef _MSC_VER + MARISA_THROW_IF(::fopen_s(&file, filename, "rb") != 0, MARISA_IO_ERROR); +#else // _MSC_VER + file = ::fopen(filename, "rb"); + MARISA_THROW_IF(file == NULL, MARISA_IO_ERROR); +#endif // _MSC_VER + file_ = file; + needs_fclose_ = true; +} + +void Reader::open_(std::FILE *file) { + file_ = file; +} + +void Reader::open_(int fd) { + fd_ = fd; +} + +void Reader::open_(std::istream &stream) { + stream_ = &stream; +} + +void Reader::read_data(void *buf, std::size_t size) { + MARISA_THROW_IF(!is_open(), MARISA_STATE_ERROR); + if (size == 0) { + return; + } else if (fd_ != -1) { + while (size != 0) { +#ifdef _WIN32 + static const std::size_t CHUNK_SIZE = + std::numeric_limits<int>::max(); + const unsigned int count = (size < CHUNK_SIZE) ? size : CHUNK_SIZE; + const int size_read = ::_read(fd_, buf, count); +#else // _WIN32 + static const std::size_t CHUNK_SIZE = + std::numeric_limits< ::ssize_t>::max(); + const ::size_t count = (size < CHUNK_SIZE) ? 
size : CHUNK_SIZE; + const ::ssize_t size_read = ::read(fd_, buf, count); +#endif // _WIN32 + MARISA_THROW_IF(size_read <= 0, MARISA_IO_ERROR); + buf = static_cast<char *>(buf) + size_read; + size -= size_read; + } + } else if (file_ != NULL) { + MARISA_THROW_IF(::fread(buf, 1, size, file_) != size, MARISA_IO_ERROR); + } else if (stream_ != NULL) { + try { + MARISA_THROW_IF(!stream_->read(static_cast<char *>(buf), size), + MARISA_IO_ERROR); + } catch (const std::ios_base::failure &) { + MARISA_THROW(MARISA_IO_ERROR, "std::ios_base::failure"); + } + } +} + +} // namespace io +} // namespace grimoire +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/grimoire/io/reader.h b/contrib/python/marisa-trie/marisa/grimoire/io/reader.h new file mode 100644 index 0000000000..fc1ba5eea7 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/io/reader.h @@ -0,0 +1,67 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_IO_READER_H_ +#define MARISA_GRIMOIRE_IO_READER_H_ + +#include <cstdio> +#include <iostream> + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace io { + +class Reader { + public: + Reader(); + ~Reader(); + + void open(const char *filename); + void open(std::FILE *file); + void open(int fd); + void open(std::istream &stream); + + template <typename T> + void read(T *obj) { + MARISA_THROW_IF(obj == NULL, MARISA_NULL_ERROR); + read_data(obj, sizeof(T)); + } + + template <typename T> + void read(T *objs, std::size_t num_objs) { + MARISA_THROW_IF((objs == NULL) && (num_objs != 0), MARISA_NULL_ERROR); + MARISA_THROW_IF(num_objs > (MARISA_SIZE_MAX / sizeof(T)), + MARISA_SIZE_ERROR); + read_data(objs, sizeof(T) * num_objs); + } + + void seek(std::size_t size); + + bool is_open() const; + + void clear(); + void swap(Reader &rhs); + + private: + std::FILE *file_; + int fd_; + std::istream *stream_; + bool needs_fclose_; + + void open_(const char *filename); + void open_(std::FILE *file); + void open_(int fd); + void 
open_(std::istream &stream); + + void read_data(void *buf, std::size_t size); + + // Disallows copy and assignment. + Reader(const Reader &); + Reader &operator=(const Reader &); +}; + +} // namespace io +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_IO_READER_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/io/writer.cc b/contrib/python/marisa-trie/marisa/grimoire/io/writer.cc new file mode 100644 index 0000000000..1f079d8ce6 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/io/writer.cc @@ -0,0 +1,148 @@ +#include <stdio.h> + +#ifdef _WIN32 + #include <io.h> +#else // _WIN32 + #include <unistd.h> +#endif // _WIN32 + +#include <limits> + +#include "writer.h" + +namespace marisa { +namespace grimoire { +namespace io { + +Writer::Writer() + : file_(NULL), fd_(-1), stream_(NULL), needs_fclose_(false) {} + +Writer::~Writer() { + if (needs_fclose_) { + ::fclose(file_); + } +} + +void Writer::open(const char *filename) { + MARISA_THROW_IF(filename == NULL, MARISA_NULL_ERROR); + + Writer temp; + temp.open_(filename); + swap(temp); +} + +void Writer::open(std::FILE *file) { + MARISA_THROW_IF(file == NULL, MARISA_NULL_ERROR); + + Writer temp; + temp.open_(file); + swap(temp); +} + +void Writer::open(int fd) { + MARISA_THROW_IF(fd == -1, MARISA_CODE_ERROR); + + Writer temp; + temp.open_(fd); + swap(temp); +} + +void Writer::open(std::ostream &stream) { + Writer temp; + temp.open_(stream); + swap(temp); +} + +void Writer::clear() { + Writer().swap(*this); +} + +void Writer::swap(Writer &rhs) { + marisa::swap(file_, rhs.file_); + marisa::swap(fd_, rhs.fd_); + marisa::swap(stream_, rhs.stream_); + marisa::swap(needs_fclose_, rhs.needs_fclose_); +} + +void Writer::seek(std::size_t size) { + MARISA_THROW_IF(!is_open(), MARISA_STATE_ERROR); + if (size == 0) { + return; + } else if (size <= 16) { + const char buf[16] = {}; + write_data(buf, size); + } else { + const char buf[1024] = {}; + do { + const std::size_t count = (size < 
sizeof(buf)) ? size : sizeof(buf); + write_data(buf, count); + size -= count; + } while (size != 0); + } +} + +bool Writer::is_open() const { + return (file_ != NULL) || (fd_ != -1) || (stream_ != NULL); +} + +void Writer::open_(const char *filename) { + std::FILE *file = NULL; +#ifdef _MSC_VER + MARISA_THROW_IF(::fopen_s(&file, filename, "wb") != 0, MARISA_IO_ERROR); +#else // _MSC_VER + file = ::fopen(filename, "wb"); + MARISA_THROW_IF(file == NULL, MARISA_IO_ERROR); +#endif // _MSC_VER + file_ = file; + needs_fclose_ = true; +} + +void Writer::open_(std::FILE *file) { + file_ = file; +} + +void Writer::open_(int fd) { + fd_ = fd; +} + +void Writer::open_(std::ostream &stream) { + stream_ = &stream; +} + +void Writer::write_data(const void *data, std::size_t size) { + MARISA_THROW_IF(!is_open(), MARISA_STATE_ERROR); + if (size == 0) { + return; + } else if (fd_ != -1) { + while (size != 0) { +#ifdef _WIN32 + static const std::size_t CHUNK_SIZE = + std::numeric_limits<int>::max(); + const unsigned int count = (size < CHUNK_SIZE) ? size : CHUNK_SIZE; + const int size_written = ::_write(fd_, data, count); +#else // _WIN32 + static const std::size_t CHUNK_SIZE = + std::numeric_limits< ::ssize_t>::max(); + const ::size_t count = (size < CHUNK_SIZE) ? 
size : CHUNK_SIZE; + const ::ssize_t size_written = ::write(fd_, data, count); +#endif // _WIN32 + MARISA_THROW_IF(size_written <= 0, MARISA_IO_ERROR); + data = static_cast<const char *>(data) + size_written; + size -= size_written; + } + } else if (file_ != NULL) { + MARISA_THROW_IF(::fwrite(data, 1, size, file_) != size, MARISA_IO_ERROR); + MARISA_THROW_IF(::fflush(file_) != 0, MARISA_IO_ERROR); + } else if (stream_ != NULL) { + try { + MARISA_THROW_IF(!stream_->write(static_cast<const char *>(data), size), + MARISA_IO_ERROR); + } catch (const std::ios_base::failure &) { + MARISA_THROW(MARISA_IO_ERROR, "std::ios_base::failure"); + } + } +} + +} // namespace io +} // namespace grimoire +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/grimoire/io/writer.h b/contrib/python/marisa-trie/marisa/grimoire/io/writer.h new file mode 100644 index 0000000000..1707b23de2 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/io/writer.h @@ -0,0 +1,66 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_IO_WRITER_H_ +#define MARISA_GRIMOIRE_IO_WRITER_H_ + +#include <cstdio> +#include <iostream> + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace io { + +class Writer { + public: + Writer(); + ~Writer(); + + void open(const char *filename); + void open(std::FILE *file); + void open(int fd); + void open(std::ostream &stream); + + template <typename T> + void write(const T &obj) { + write_data(&obj, sizeof(T)); + } + + template <typename T> + void write(const T *objs, std::size_t num_objs) { + MARISA_THROW_IF((objs == NULL) && (num_objs != 0), MARISA_NULL_ERROR); + MARISA_THROW_IF(num_objs > (MARISA_SIZE_MAX / sizeof(T)), + MARISA_SIZE_ERROR); + write_data(objs, sizeof(T) * num_objs); + } + + void seek(std::size_t size); + + bool is_open() const; + + void clear(); + void swap(Writer &rhs); + + private: + std::FILE *file_; + int fd_; + std::ostream *stream_; + bool needs_fclose_; + + void open_(const char *filename); + void 
open_(std::FILE *file); + void open_(int fd); + void open_(std::ostream &stream); + + void write_data(const void *data, std::size_t size); + + // Disallows copy and assignment. + Writer(const Writer &); + Writer &operator=(const Writer &); +}; + +} // namespace io +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_IO_WRITER_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie.h b/contrib/python/marisa-trie/marisa/grimoire/trie.h new file mode 100644 index 0000000000..d23852a4fd --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie.h @@ -0,0 +1,17 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_H_ +#define MARISA_GRIMOIRE_TRIE_H_ + +#include "trie/state.h" +#include "trie/louds-trie.h" + +namespace marisa { +namespace grimoire { + +using trie::State; +using trie::LoudsTrie; + +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/cache.h b/contrib/python/marisa-trie/marisa/grimoire/trie/cache.h new file mode 100644 index 0000000000..f9da360869 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/cache.h @@ -0,0 +1,82 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_CACHE_H_ +#define MARISA_GRIMOIRE_TRIE_CACHE_H_ + +#include <cfloat> + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Cache { + public: + Cache() : parent_(0), child_(0), union_() { + union_.weight = FLT_MIN; + } + Cache(const Cache &cache) + : parent_(cache.parent_), child_(cache.child_), union_(cache.union_) {} + + Cache &operator=(const Cache &cache) { + parent_ = cache.parent_; + child_ = cache.child_; + union_ = cache.union_; + return *this; + } + + void set_parent(std::size_t parent) { + MARISA_DEBUG_IF(parent > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + parent_ = (UInt32)parent; + } + void set_child(std::size_t child) { + MARISA_DEBUG_IF(child > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + child_ = 
(UInt32)child; + } + void set_base(UInt8 base) { + union_.link = (union_.link & ~0xFFU) | base; + } + void set_extra(std::size_t extra) { + MARISA_DEBUG_IF(extra > (MARISA_UINT32_MAX >> 8), MARISA_SIZE_ERROR); + union_.link = (UInt32)((union_.link & 0xFFU) | (extra << 8)); + } + void set_weight(float weight) { + union_.weight = weight; + } + + std::size_t parent() const { + return parent_; + } + std::size_t child() const { + return child_; + } + UInt8 base() const { + return (UInt8)(union_.link & 0xFFU); + } + std::size_t extra() const { + return union_.link >> 8; + } + char label() const { + return (char)base(); + } + std::size_t link() const { + return union_.link; + } + float weight() const { + return union_.weight; + } + + private: + UInt32 parent_; + UInt32 child_; + union Union { + UInt32 link; + float weight; + } union_; +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_CACHE_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/config.h b/contrib/python/marisa-trie/marisa/grimoire/trie/config.h new file mode 100644 index 0000000000..9b307de3e1 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/config.h @@ -0,0 +1,156 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_CONFIG_H_ +#define MARISA_GRIMOIRE_TRIE_CONFIG_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Config { + public: + Config() + : num_tries_(MARISA_DEFAULT_NUM_TRIES), + cache_level_(MARISA_DEFAULT_CACHE), + tail_mode_(MARISA_DEFAULT_TAIL), + node_order_(MARISA_DEFAULT_ORDER) {} + + void parse(int config_flags) { + Config temp; + temp.parse_(config_flags); + swap(temp); + } + + int flags() const { + return (int)num_tries_ | tail_mode_ | node_order_; + } + + std::size_t num_tries() const { + return num_tries_; + } + CacheLevel cache_level() const { + return cache_level_; + } + TailMode tail_mode() const { + return tail_mode_; + } + NodeOrder node_order() const { + 
return node_order_; + } + + void clear() { + Config().swap(*this); + } + void swap(Config &rhs) { + marisa::swap(num_tries_, rhs.num_tries_); + marisa::swap(cache_level_, rhs.cache_level_); + marisa::swap(tail_mode_, rhs.tail_mode_); + marisa::swap(node_order_, rhs.node_order_); + } + + private: + std::size_t num_tries_; + CacheLevel cache_level_; + TailMode tail_mode_; + NodeOrder node_order_; + + void parse_(int config_flags) { + MARISA_THROW_IF((config_flags & ~MARISA_CONFIG_MASK) != 0, + MARISA_CODE_ERROR); + + parse_num_tries(config_flags); + parse_cache_level(config_flags); + parse_tail_mode(config_flags); + parse_node_order(config_flags); + } + + void parse_num_tries(int config_flags) { + const int num_tries = config_flags & MARISA_NUM_TRIES_MASK; + if (num_tries != 0) { + num_tries_ = num_tries; + } + } + + void parse_cache_level(int config_flags) { + switch (config_flags & MARISA_CACHE_LEVEL_MASK) { + case 0: { + cache_level_ = MARISA_DEFAULT_CACHE; + break; + } + case MARISA_HUGE_CACHE: { + cache_level_ = MARISA_HUGE_CACHE; + break; + } + case MARISA_LARGE_CACHE: { + cache_level_ = MARISA_LARGE_CACHE; + break; + } + case MARISA_NORMAL_CACHE: { + cache_level_ = MARISA_NORMAL_CACHE; + break; + } + case MARISA_SMALL_CACHE: { + cache_level_ = MARISA_SMALL_CACHE; + break; + } + case MARISA_TINY_CACHE: { + cache_level_ = MARISA_TINY_CACHE; + break; + } + default: { + MARISA_THROW(MARISA_CODE_ERROR, "undefined cache level"); + } + } + } + + void parse_tail_mode(int config_flags) { + switch (config_flags & MARISA_TAIL_MODE_MASK) { + case 0: { + tail_mode_ = MARISA_DEFAULT_TAIL; + break; + } + case MARISA_TEXT_TAIL: { + tail_mode_ = MARISA_TEXT_TAIL; + break; + } + case MARISA_BINARY_TAIL: { + tail_mode_ = MARISA_BINARY_TAIL; + break; + } + default: { + MARISA_THROW(MARISA_CODE_ERROR, "undefined tail mode"); + } + } + } + + void parse_node_order(int config_flags) { + switch (config_flags & MARISA_NODE_ORDER_MASK) { + case 0: { + node_order_ = MARISA_DEFAULT_ORDER; 
+ break; + } + case MARISA_LABEL_ORDER: { + node_order_ = MARISA_LABEL_ORDER; + break; + } + case MARISA_WEIGHT_ORDER: { + node_order_ = MARISA_WEIGHT_ORDER; + break; + } + default: { + MARISA_THROW(MARISA_CODE_ERROR, "undefined node order"); + } + } + } + + // Disallows copy and assignment. + Config(const Config &); + Config &operator=(const Config &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_CONFIG_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/entry.h b/contrib/python/marisa-trie/marisa/grimoire/trie/entry.h new file mode 100644 index 0000000000..834ab95e1e --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/entry.h @@ -0,0 +1,83 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_ENTRY_H_ +#define MARISA_GRIMOIRE_TRIE_ENTRY_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Entry { + public: + Entry() + : ptr_(reinterpret_cast<const char *>(-1)), length_(0), id_(0) {} + Entry(const Entry &entry) + : ptr_(entry.ptr_), length_(entry.length_), id_(entry.id_) {} + + Entry &operator=(const Entry &entry) { + ptr_ = entry.ptr_; + length_ = entry.length_; + id_ = entry.id_; + return *this; + } + + char operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); + return *(ptr_ - i); + } + + void set_str(const char *ptr, std::size_t length) { + MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + ptr_ = ptr + length - 1; + length_ = (UInt32)length; + } + void set_id(std::size_t id) { + MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + id_ = (UInt32)id; + } + + const char *ptr() const { + return ptr_ - length_ + 1; + } + std::size_t length() const { + return length_; + } + std::size_t id() const { + return id_; + } + + class StringComparer { + public: + bool operator()(const Entry &lhs, const Entry &rhs) 
const { + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (i == rhs.length()) { + return true; + } + if (lhs[i] != rhs[i]) { + return (UInt8)lhs[i] > (UInt8)rhs[i]; + } + } + return lhs.length() > rhs.length(); + } + }; + + class IDComparer { + public: + bool operator()(const Entry &lhs, const Entry &rhs) const { + return lhs.id_ < rhs.id_; + } + }; + + private: + const char *ptr_; + UInt32 length_; + UInt32 id_; +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_ENTRY_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/header.h b/contrib/python/marisa-trie/marisa/grimoire/trie/header.h new file mode 100644 index 0000000000..04839f67e1 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/header.h @@ -0,0 +1,62 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_HEADER_H_ +#define MARISA_GRIMOIRE_TRIE_HEADER_H_ + +#include "../io.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Header { + public: + enum { + HEADER_SIZE = 16 + }; + + Header() {} + + void map(Mapper &mapper) { + const char *ptr; + mapper.map(&ptr, HEADER_SIZE); + MARISA_THROW_IF(!test_header(ptr), MARISA_FORMAT_ERROR); + } + void read(Reader &reader) { + char buf[HEADER_SIZE]; + reader.read(buf, HEADER_SIZE); + MARISA_THROW_IF(!test_header(buf), MARISA_FORMAT_ERROR); + } + void write(Writer &writer) const { + writer.write(get_header(), HEADER_SIZE); + } + + std::size_t io_size() const { + return HEADER_SIZE; + } + + private: + + static const char *get_header() { + static const char buf[HEADER_SIZE] = "We love Marisa."; + return buf; + } + + static bool test_header(const char *ptr) { + for (std::size_t i = 0; i < HEADER_SIZE; ++i) { + if (ptr[i] != get_header()[i]) { + return false; + } + } + return true; + } + + // Disallows copy and assignment. 
+ Header(const Header &); + Header &operator=(const Header &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_HEADER_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/history.h b/contrib/python/marisa-trie/marisa/grimoire/trie/history.h new file mode 100644 index 0000000000..9a3d272260 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/history.h @@ -0,0 +1,66 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_STATE_HISTORY_H_ +#define MARISA_GRIMOIRE_TRIE_STATE_HISTORY_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class History { + public: + History() + : node_id_(0), louds_pos_(0), key_pos_(0), + link_id_(MARISA_INVALID_LINK_ID), key_id_(MARISA_INVALID_KEY_ID) {} + + void set_node_id(std::size_t node_id) { + MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + node_id_ = (UInt32)node_id; + } + void set_louds_pos(std::size_t louds_pos) { + MARISA_DEBUG_IF(louds_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + louds_pos_ = (UInt32)louds_pos; + } + void set_key_pos(std::size_t key_pos) { + MARISA_DEBUG_IF(key_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + key_pos_ = (UInt32)key_pos; + } + void set_link_id(std::size_t link_id) { + MARISA_DEBUG_IF(link_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + link_id_ = (UInt32)link_id; + } + void set_key_id(std::size_t key_id) { + MARISA_DEBUG_IF(key_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + key_id_ = (UInt32)key_id; + } + + std::size_t node_id() const { + return node_id_; + } + std::size_t louds_pos() const { + return louds_pos_; + } + std::size_t key_pos() const { + return key_pos_; + } + std::size_t link_id() const { + return link_id_; + } + std::size_t key_id() const { + return key_id_; + } + + private: + UInt32 node_id_; + UInt32 louds_pos_; + UInt32 key_pos_; + UInt32 link_id_; + UInt32 key_id_; +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + 
+#endif // MARISA_GRIMOIRE_TRIE_STATE_HISTORY_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/key.h b/contrib/python/marisa-trie/marisa/grimoire/trie/key.h new file mode 100644 index 0000000000..c09ea86cf8 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/key.h @@ -0,0 +1,227 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_KEY_H_ +#define MARISA_GRIMOIRE_TRIE_KEY_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Key { + public: + Key() : ptr_(NULL), length_(0), union_(), id_(0) { + union_.terminal = 0; + } + Key(const Key &entry) + : ptr_(entry.ptr_), length_(entry.length_), + union_(entry.union_), id_(entry.id_) {} + + Key &operator=(const Key &entry) { + ptr_ = entry.ptr_; + length_ = entry.length_; + union_ = entry.union_; + id_ = entry.id_; + return *this; + } + + char operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); + return ptr_[i]; + } + + void substr(std::size_t pos, std::size_t length) { + MARISA_DEBUG_IF(pos > length_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(length > length_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(pos > (length_ - length), MARISA_BOUND_ERROR); + ptr_ += pos; + length_ = (UInt32)length; + } + + void set_str(const char *ptr, std::size_t length) { + MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + ptr_ = ptr; + length_ = (UInt32)length; + } + void set_weight(float weight) { + union_.weight = weight; + } + void set_terminal(std::size_t terminal) { + MARISA_DEBUG_IF(terminal > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + union_.terminal = (UInt32)terminal; + } + void set_id(std::size_t id) { + MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + id_ = (UInt32)id; + } + + const char *ptr() const { + return ptr_; + } + std::size_t length() const { + return length_; + } + float weight() const { + return union_.weight; + } + std::size_t 
terminal() const { + return union_.terminal; + } + std::size_t id() const { + return id_; + } + + private: + const char *ptr_; + UInt32 length_; + union Union { + float weight; + UInt32 terminal; + } union_; + UInt32 id_; +}; + +inline bool operator==(const Key &lhs, const Key &rhs) { + if (lhs.length() != rhs.length()) { + return false; + } + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; +} + +inline bool operator!=(const Key &lhs, const Key &rhs) { + return !(lhs == rhs); +} + +inline bool operator<(const Key &lhs, const Key &rhs) { + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (i == rhs.length()) { + return false; + } + if (lhs[i] != rhs[i]) { + return (UInt8)lhs[i] < (UInt8)rhs[i]; + } + } + return lhs.length() < rhs.length(); +} + +inline bool operator>(const Key &lhs, const Key &rhs) { + return rhs < lhs; +} + +class ReverseKey { + public: + ReverseKey() : ptr_(NULL), length_(0), union_(), id_(0) { + union_.terminal = 0; + } + ReverseKey(const ReverseKey &entry) + : ptr_(entry.ptr_), length_(entry.length_), + union_(entry.union_), id_(entry.id_) {} + + ReverseKey &operator=(const ReverseKey &entry) { + ptr_ = entry.ptr_; + length_ = entry.length_; + union_ = entry.union_; + id_ = entry.id_; + return *this; + } + + char operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); + return *(ptr_ - i - 1); + } + + void substr(std::size_t pos, std::size_t length) { + MARISA_DEBUG_IF(pos > length_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(length > length_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(pos > (length_ - length), MARISA_BOUND_ERROR); + ptr_ -= pos; + length_ = (UInt32)length; + } + + void set_str(const char *ptr, std::size_t length) { + MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + ptr_ = ptr + length; + length_ = (UInt32)length; + } + void set_weight(float weight) { 
+ union_.weight = weight; + } + void set_terminal(std::size_t terminal) { + MARISA_DEBUG_IF(terminal > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + union_.terminal = (UInt32)terminal; + } + void set_id(std::size_t id) { + MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + id_ = (UInt32)id; + } + + const char *ptr() const { + return ptr_ - length_; + } + std::size_t length() const { + return length_; + } + float weight() const { + return union_.weight; + } + std::size_t terminal() const { + return union_.terminal; + } + std::size_t id() const { + return id_; + } + + private: + const char *ptr_; + UInt32 length_; + union Union { + float weight; + UInt32 terminal; + } union_; + UInt32 id_; +}; + +inline bool operator==(const ReverseKey &lhs, const ReverseKey &rhs) { + if (lhs.length() != rhs.length()) { + return false; + } + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; +} + +inline bool operator!=(const ReverseKey &lhs, const ReverseKey &rhs) { + return !(lhs == rhs); +} + +inline bool operator<(const ReverseKey &lhs, const ReverseKey &rhs) { + for (std::size_t i = 0; i < lhs.length(); ++i) { + if (i == rhs.length()) { + return false; + } + if (lhs[i] != rhs[i]) { + return (UInt8)lhs[i] < (UInt8)rhs[i]; + } + } + return lhs.length() < rhs.length(); +} + +inline bool operator>(const ReverseKey &lhs, const ReverseKey &rhs) { + return rhs < lhs; +} + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_KEY_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc new file mode 100644 index 0000000000..ed168539bc --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.cc @@ -0,0 +1,877 @@ +#include <algorithm> +#include <functional> +#include <queue> + +#include "../algorithm.h" +#include "header.h" +#include "range.h" +#include "state.h" 
+#include "louds-trie.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +LoudsTrie::LoudsTrie() + : louds_(), terminal_flags_(), link_flags_(), bases_(), extras_(), + tail_(), next_trie_(), cache_(), cache_mask_(0), num_l1_nodes_(0), + config_(), mapper_() {} + +LoudsTrie::~LoudsTrie() {} + +void LoudsTrie::build(Keyset &keyset, int flags) { + Config config; + config.parse(flags); + + LoudsTrie temp; + temp.build_(keyset, config); + swap(temp); +} + +void LoudsTrie::map(Mapper &mapper) { + Header().map(mapper); + + LoudsTrie temp; + temp.map_(mapper); + temp.mapper_.swap(mapper); + swap(temp); +} + +void LoudsTrie::read(Reader &reader) { + Header().read(reader); + + LoudsTrie temp; + temp.read_(reader); + swap(temp); +} + +void LoudsTrie::write(Writer &writer) const { + Header().write(writer); + + write_(writer); +} + +bool LoudsTrie::lookup(Agent &agent) const { + MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR); + + State &state = agent.state(); + state.lookup_init(); + while (state.query_pos() < agent.query().length()) { + if (!find_child(agent)) { + return false; + } + } + if (!terminal_flags_[state.node_id()]) { + return false; + } + agent.set_key(agent.query().ptr(), agent.query().length()); + agent.set_key(terminal_flags_.rank1(state.node_id())); + return true; +} + +void LoudsTrie::reverse_lookup(Agent &agent) const { + MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR); + MARISA_THROW_IF(agent.query().id() >= size(), MARISA_BOUND_ERROR); + + State &state = agent.state(); + state.reverse_lookup_init(); + + state.set_node_id(terminal_flags_.select1(agent.query().id())); + if (state.node_id() == 0) { + agent.set_key(state.key_buf().begin(), state.key_buf().size()); + agent.set_key(agent.query().id()); + return; + } + for ( ; ; ) { + if (link_flags_[state.node_id()]) { + const std::size_t prev_key_pos = state.key_buf().size(); + restore(agent, get_link(state.node_id())); + std::reverse(state.key_buf().begin() + prev_key_pos, + 
state.key_buf().end()); + } else { + state.key_buf().push_back((char)bases_[state.node_id()]); + } + + if (state.node_id() <= num_l1_nodes_) { + std::reverse(state.key_buf().begin(), state.key_buf().end()); + agent.set_key(state.key_buf().begin(), state.key_buf().size()); + agent.set_key(agent.query().id()); + return; + } + state.set_node_id(louds_.select1(state.node_id()) - state.node_id() - 1); + } +} + +bool LoudsTrie::common_prefix_search(Agent &agent) const { + MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR); + + State &state = agent.state(); + if (state.status_code() == MARISA_END_OF_COMMON_PREFIX_SEARCH) { + return false; + } + + if (state.status_code() != MARISA_READY_TO_COMMON_PREFIX_SEARCH) { + state.common_prefix_search_init(); + if (terminal_flags_[state.node_id()]) { + agent.set_key(agent.query().ptr(), state.query_pos()); + agent.set_key(terminal_flags_.rank1(state.node_id())); + return true; + } + } + + while (state.query_pos() < agent.query().length()) { + if (!find_child(agent)) { + state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH); + return false; + } else if (terminal_flags_[state.node_id()]) { + agent.set_key(agent.query().ptr(), state.query_pos()); + agent.set_key(terminal_flags_.rank1(state.node_id())); + return true; + } + } + state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH); + return false; +} + +bool LoudsTrie::predictive_search(Agent &agent) const { + MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR); + + State &state = agent.state(); + if (state.status_code() == MARISA_END_OF_PREDICTIVE_SEARCH) { + return false; + } + + if (state.status_code() != MARISA_READY_TO_PREDICTIVE_SEARCH) { + state.predictive_search_init(); + while (state.query_pos() < agent.query().length()) { + if (!predictive_find_child(agent)) { + state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH); + return false; + } + } + + History history; + history.set_node_id(state.node_id()); + history.set_key_pos(state.key_buf().size()); + 
state.history().push_back(history); + state.set_history_pos(1); + + if (terminal_flags_[state.node_id()]) { + agent.set_key(state.key_buf().begin(), state.key_buf().size()); + agent.set_key(terminal_flags_.rank1(state.node_id())); + return true; + } + } + + for ( ; ; ) { + if (state.history_pos() == state.history().size()) { + const History &current = state.history().back(); + History next; + next.set_louds_pos(louds_.select0(current.node_id()) + 1); + next.set_node_id(next.louds_pos() - current.node_id() - 1); + state.history().push_back(next); + } + + History &next = state.history()[state.history_pos()]; + const bool link_flag = louds_[next.louds_pos()]; + next.set_louds_pos(next.louds_pos() + 1); + if (link_flag) { + state.set_history_pos(state.history_pos() + 1); + if (link_flags_[next.node_id()]) { + next.set_link_id(update_link_id(next.link_id(), next.node_id())); + restore(agent, get_link(next.node_id(), next.link_id())); + } else { + state.key_buf().push_back((char)bases_[next.node_id()]); + } + next.set_key_pos(state.key_buf().size()); + + if (terminal_flags_[next.node_id()]) { + if (next.key_id() == MARISA_INVALID_KEY_ID) { + next.set_key_id(terminal_flags_.rank1(next.node_id())); + } else { + next.set_key_id(next.key_id() + 1); + } + agent.set_key(state.key_buf().begin(), state.key_buf().size()); + agent.set_key(next.key_id()); + return true; + } + } else if (state.history_pos() != 1) { + History &current = state.history()[state.history_pos() - 1]; + current.set_node_id(current.node_id() + 1); + const History &prev = + state.history()[state.history_pos() - 2]; + state.key_buf().resize(prev.key_pos()); + state.set_history_pos(state.history_pos() - 1); + } else { + state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH); + return false; + } + } +} + +std::size_t LoudsTrie::total_size() const { + return louds_.total_size() + terminal_flags_.total_size() + + link_flags_.total_size() + bases_.total_size() + + extras_.total_size() + tail_.total_size() + +
((next_trie_.get() != NULL) ? next_trie_->total_size() : 0) + + cache_.total_size(); +} + +std::size_t LoudsTrie::io_size() const { + return Header().io_size() + louds_.io_size() + + terminal_flags_.io_size() + link_flags_.io_size() + + bases_.io_size() + extras_.io_size() + tail_.io_size() + + ((next_trie_.get() != NULL) ? + (next_trie_->io_size() - Header().io_size()) : 0) + + cache_.io_size() + (sizeof(UInt32) * 2); +} + +void LoudsTrie::clear() { + LoudsTrie().swap(*this); +} + +void LoudsTrie::swap(LoudsTrie &rhs) { + louds_.swap(rhs.louds_); + terminal_flags_.swap(rhs.terminal_flags_); + link_flags_.swap(rhs.link_flags_); + bases_.swap(rhs.bases_); + extras_.swap(rhs.extras_); + tail_.swap(rhs.tail_); + next_trie_.swap(rhs.next_trie_); + cache_.swap(rhs.cache_); + marisa::swap(cache_mask_, rhs.cache_mask_); + marisa::swap(num_l1_nodes_, rhs.num_l1_nodes_); + config_.swap(rhs.config_); + mapper_.swap(rhs.mapper_); +} + +void LoudsTrie::build_(Keyset &keyset, const Config &config) { + Vector<Key> keys; + keys.resize(keyset.size()); + for (std::size_t i = 0; i < keyset.size(); ++i) { + keys[i].set_str(keyset[i].ptr(), keyset[i].length()); + keys[i].set_weight(keyset[i].weight()); + } + + Vector<UInt32> terminals; + build_trie(keys, &terminals, config, 1); + + typedef std::pair<UInt32, UInt32> TerminalIdPair; + + Vector<TerminalIdPair> pairs; + pairs.resize(terminals.size()); + for (std::size_t i = 0; i < pairs.size(); ++i) { + pairs[i].first = terminals[i]; + pairs[i].second = (UInt32)i; + } + terminals.clear(); + std::sort(pairs.begin(), pairs.end()); + + std::size_t node_id = 0; + for (std::size_t i = 0; i < pairs.size(); ++i) { + while (node_id < pairs[i].first) { + terminal_flags_.push_back(false); + ++node_id; + } + if (node_id == pairs[i].first) { + terminal_flags_.push_back(true); + ++node_id; + } + } + while (node_id < bases_.size()) { + terminal_flags_.push_back(false); + ++node_id; + } + terminal_flags_.push_back(false); + terminal_flags_.build(false, 
true); + + for (std::size_t i = 0; i < keyset.size(); ++i) { + keyset[pairs[i].second].set_id(terminal_flags_.rank1(pairs[i].first)); + } +} + +template <typename T> +void LoudsTrie::build_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) { + build_current_trie(keys, terminals, config, trie_id); + + Vector<UInt32> next_terminals; + if (!keys.empty()) { + build_next_trie(keys, &next_terminals, config, trie_id); + } + + if (next_trie_.get() != NULL) { + config_.parse(static_cast<int>((next_trie_->num_tries() + 1)) | + next_trie_->tail_mode() | next_trie_->node_order()); + } else { + config_.parse(1 | tail_.mode() | config.node_order() | + config.cache_level()); + } + + link_flags_.build(false, false); + std::size_t node_id = 0; + for (std::size_t i = 0; i < next_terminals.size(); ++i) { + while (!link_flags_[node_id]) { + ++node_id; + } + bases_[node_id] = (UInt8)(next_terminals[i] % 256); + next_terminals[i] /= 256; + ++node_id; + } + extras_.build(next_terminals); + fill_cache(); +} + +template <typename T> +void LoudsTrie::build_current_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, + std::size_t trie_id) try { + for (std::size_t i = 0; i < keys.size(); ++i) { + keys[i].set_id(i); + } + const std::size_t num_keys = Algorithm().sort(keys.begin(), keys.end()); + reserve_cache(config, trie_id, num_keys); + + louds_.push_back(true); + louds_.push_back(false); + bases_.push_back('\0'); + link_flags_.push_back(false); + + Vector<T> next_keys; + std::queue<Range> queue; + Vector<WeightedRange> w_ranges; + + queue.push(make_range(0, keys.size(), 0)); + while (!queue.empty()) { + const std::size_t node_id = link_flags_.size() - queue.size(); + + Range range = queue.front(); + queue.pop(); + + while ((range.begin() < range.end()) && + (keys[range.begin()].length() == range.key_pos())) { + keys[range.begin()].set_terminal(node_id); + range.set_begin(range.begin() + 1); + } + + if (range.begin() == 
range.end()) { + louds_.push_back(false); + continue; + } + + w_ranges.clear(); + double weight = keys[range.begin()].weight(); + for (std::size_t i = range.begin() + 1; i < range.end(); ++i) { + if (keys[i - 1][range.key_pos()] != keys[i][range.key_pos()]) { + w_ranges.push_back(make_weighted_range( + range.begin(), i, range.key_pos(), (float)weight)); + range.set_begin(i); + weight = 0.0; + } + weight += keys[i].weight(); + } + w_ranges.push_back(make_weighted_range( + range.begin(), range.end(), range.key_pos(), (float)weight)); + if (config.node_order() == MARISA_WEIGHT_ORDER) { + std::stable_sort(w_ranges.begin(), w_ranges.end(), + std::greater<WeightedRange>()); + } + + if (node_id == 0) { + num_l1_nodes_ = w_ranges.size(); + } + + for (std::size_t i = 0; i < w_ranges.size(); ++i) { + WeightedRange &w_range = w_ranges[i]; + std::size_t key_pos = w_range.key_pos() + 1; + while (key_pos < keys[w_range.begin()].length()) { + std::size_t j; + for (j = w_range.begin() + 1; j < w_range.end(); ++j) { + if (keys[j - 1][key_pos] != keys[j][key_pos]) { + break; + } + } + if (j < w_range.end()) { + break; + } + ++key_pos; + } + cache<T>(node_id, bases_.size(), w_range.weight(), + keys[w_range.begin()][w_range.key_pos()]); + + if (key_pos == w_range.key_pos() + 1) { + bases_.push_back(keys[w_range.begin()][w_range.key_pos()]); + link_flags_.push_back(false); + } else { + bases_.push_back('\0'); + link_flags_.push_back(true); + T next_key; + next_key.set_str(keys[w_range.begin()].ptr(), + keys[w_range.begin()].length()); + next_key.substr(w_range.key_pos(), key_pos - w_range.key_pos()); + next_key.set_weight(w_range.weight()); + next_keys.push_back(next_key); + } + w_range.set_key_pos(key_pos); + queue.push(w_range.range()); + louds_.push_back(true); + } + louds_.push_back(false); + } + + louds_.push_back(false); + louds_.build(trie_id == 1, true); + bases_.shrink(); + + build_terminals(keys, terminals); + keys.swap(next_keys); +} catch (const std::bad_alloc &) { + 
MARISA_THROW(MARISA_MEMORY_ERROR, "std::bad_alloc"); +} + +template <> +void LoudsTrie::build_next_trie(Vector<Key> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) { + if (trie_id == config.num_tries()) { + Vector<Entry> entries; + entries.resize(keys.size()); + for (std::size_t i = 0; i < keys.size(); ++i) { + entries[i].set_str(keys[i].ptr(), keys[i].length()); + } + tail_.build(entries, terminals, config.tail_mode()); + return; + } + Vector<ReverseKey> reverse_keys; + reverse_keys.resize(keys.size()); + for (std::size_t i = 0; i < keys.size(); ++i) { + reverse_keys[i].set_str(keys[i].ptr(), keys[i].length()); + reverse_keys[i].set_weight(keys[i].weight()); + } + keys.clear(); + next_trie_.reset(new (std::nothrow) LoudsTrie); + MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR); + next_trie_->build_trie(reverse_keys, terminals, config, trie_id + 1); +} + +template <> +void LoudsTrie::build_next_trie(Vector<ReverseKey> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) { + if (trie_id == config.num_tries()) { + Vector<Entry> entries; + entries.resize(keys.size()); + for (std::size_t i = 0; i < keys.size(); ++i) { + entries[i].set_str(keys[i].ptr(), keys[i].length()); + } + tail_.build(entries, terminals, config.tail_mode()); + return; + } + next_trie_.reset(new (std::nothrow) LoudsTrie); + MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR); + next_trie_->build_trie(keys, terminals, config, trie_id + 1); +} + +template <typename T> +void LoudsTrie::build_terminals(const Vector<T> &keys, + Vector<UInt32> *terminals) const { + Vector<UInt32> temp; + temp.resize(keys.size()); + for (std::size_t i = 0; i < keys.size(); ++i) { + temp[keys[i].id()] = (UInt32)keys[i].terminal(); + } + terminals->swap(temp); +} + +template <> +void LoudsTrie::cache<Key>(std::size_t parent, std::size_t child, + float weight, char label) { + MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR); + + const 
std::size_t cache_id = get_cache_id(parent, label); + if (weight > cache_[cache_id].weight()) { + cache_[cache_id].set_parent(parent); + cache_[cache_id].set_child(child); + cache_[cache_id].set_weight(weight); + } +} + +void LoudsTrie::reserve_cache(const Config &config, std::size_t trie_id, + std::size_t num_keys) { + std::size_t cache_size = (trie_id == 1) ? 256 : 1; + while (cache_size < (num_keys / config.cache_level())) { + cache_size *= 2; + } + cache_.resize(cache_size); + cache_mask_ = cache_size - 1; +} + +template <> +void LoudsTrie::cache<ReverseKey>(std::size_t parent, std::size_t child, + float weight, char) { + MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR); + + const std::size_t cache_id = get_cache_id(child); + if (weight > cache_[cache_id].weight()) { + cache_[cache_id].set_parent(parent); + cache_[cache_id].set_child(child); + cache_[cache_id].set_weight(weight); + } +} + +void LoudsTrie::fill_cache() { + for (std::size_t i = 0; i < cache_.size(); ++i) { + const std::size_t node_id = cache_[i].child(); + if (node_id != 0) { + cache_[i].set_base(bases_[node_id]); + cache_[i].set_extra(!link_flags_[node_id] ? 
+ MARISA_INVALID_EXTRA : extras_[link_flags_.rank1(node_id)]); + } else { + cache_[i].set_parent(MARISA_UINT32_MAX); + cache_[i].set_child(MARISA_UINT32_MAX); + } + } +} + +void LoudsTrie::map_(Mapper &mapper) { + louds_.map(mapper); + terminal_flags_.map(mapper); + link_flags_.map(mapper); + bases_.map(mapper); + extras_.map(mapper); + tail_.map(mapper); + if ((link_flags_.num_1s() != 0) && tail_.empty()) { + next_trie_.reset(new (std::nothrow) LoudsTrie); + MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR); + next_trie_->map_(mapper); + } + cache_.map(mapper); + cache_mask_ = cache_.size() - 1; + { + UInt32 temp_num_l1_nodes; + mapper.map(&temp_num_l1_nodes); + num_l1_nodes_ = temp_num_l1_nodes; + } + { + UInt32 temp_config_flags; + mapper.map(&temp_config_flags); + config_.parse((int)temp_config_flags); + } +} + +void LoudsTrie::read_(Reader &reader) { + louds_.read(reader); + terminal_flags_.read(reader); + link_flags_.read(reader); + bases_.read(reader); + extras_.read(reader); + tail_.read(reader); + if ((link_flags_.num_1s() != 0) && tail_.empty()) { + next_trie_.reset(new (std::nothrow) LoudsTrie); + MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR); + next_trie_->read_(reader); + } + cache_.read(reader); + cache_mask_ = cache_.size() - 1; + { + UInt32 temp_num_l1_nodes; + reader.read(&temp_num_l1_nodes); + num_l1_nodes_ = temp_num_l1_nodes; + } + { + UInt32 temp_config_flags; + reader.read(&temp_config_flags); + config_.parse((int)temp_config_flags); + } +} + +void LoudsTrie::write_(Writer &writer) const { + louds_.write(writer); + terminal_flags_.write(writer); + link_flags_.write(writer); + bases_.write(writer); + extras_.write(writer); + tail_.write(writer); + if (next_trie_.get() != NULL) { + next_trie_->write_(writer); + } + cache_.write(writer); + writer.write((UInt32)num_l1_nodes_); + writer.write((UInt32)config_.flags()); +} + +bool LoudsTrie::find_child(Agent &agent) const { + MARISA_DEBUG_IF(agent.state().query_pos() >= 
agent.query().length(), + MARISA_BOUND_ERROR); + + State &state = agent.state(); + const std::size_t cache_id = get_cache_id(state.node_id(), + agent.query()[state.query_pos()]); + if (state.node_id() == cache_[cache_id].parent()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + if (!match(agent, cache_[cache_id].link())) { + return false; + } + } else { + state.set_query_pos(state.query_pos() + 1); + } + state.set_node_id(cache_[cache_id].child()); + return true; + } + + std::size_t louds_pos = louds_.select0(state.node_id()) + 1; + if (!louds_[louds_pos]) { + return false; + } + state.set_node_id(louds_pos - state.node_id() - 1); + std::size_t link_id = MARISA_INVALID_LINK_ID; + do { + if (link_flags_[state.node_id()]) { + link_id = update_link_id(link_id, state.node_id()); + const std::size_t prev_query_pos = state.query_pos(); + if (match(agent, get_link(state.node_id(), link_id))) { + return true; + } else if (state.query_pos() != prev_query_pos) { + return false; + } + } else if (bases_[state.node_id()] == + (UInt8)agent.query()[state.query_pos()]) { + state.set_query_pos(state.query_pos() + 1); + return true; + } + state.set_node_id(state.node_id() + 1); + ++louds_pos; + } while (louds_[louds_pos]); + return false; +} + +bool LoudsTrie::predictive_find_child(Agent &agent) const { + MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(), + MARISA_BOUND_ERROR); + + State &state = agent.state(); + const std::size_t cache_id = get_cache_id(state.node_id(), + agent.query()[state.query_pos()]); + if (state.node_id() == cache_[cache_id].parent()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + if (!prefix_match(agent, cache_[cache_id].link())) { + return false; + } + } else { + state.key_buf().push_back(cache_[cache_id].label()); + state.set_query_pos(state.query_pos() + 1); + } + state.set_node_id(cache_[cache_id].child()); + return true; + } + + std::size_t louds_pos = louds_.select0(state.node_id()) + 1; + if 
(!louds_[louds_pos]) { + return false; + } + state.set_node_id(louds_pos - state.node_id() - 1); + std::size_t link_id = MARISA_INVALID_LINK_ID; + do { + if (link_flags_[state.node_id()]) { + link_id = update_link_id(link_id, state.node_id()); + const std::size_t prev_query_pos = state.query_pos(); + if (prefix_match(agent, get_link(state.node_id(), link_id))) { + return true; + } else if (state.query_pos() != prev_query_pos) { + return false; + } + } else if (bases_[state.node_id()] == + (UInt8)agent.query()[state.query_pos()]) { + state.key_buf().push_back((char)bases_[state.node_id()]); + state.set_query_pos(state.query_pos() + 1); + return true; + } + state.set_node_id(state.node_id() + 1); + ++louds_pos; + } while (louds_[louds_pos]); + return false; +} + +void LoudsTrie::restore(Agent &agent, std::size_t link) const { + if (next_trie_.get() != NULL) { + next_trie_->restore_(agent, link); + } else { + tail_.restore(agent, link); + } +} + +bool LoudsTrie::match(Agent &agent, std::size_t link) const { + if (next_trie_.get() != NULL) { + return next_trie_->match_(agent, link); + } else { + return tail_.match(agent, link); + } +} + +bool LoudsTrie::prefix_match(Agent &agent, std::size_t link) const { + if (next_trie_.get() != NULL) { + return next_trie_->prefix_match_(agent, link); + } else { + return tail_.prefix_match(agent, link); + } +} + +void LoudsTrie::restore_(Agent &agent, std::size_t node_id) const { + MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR); + + State &state = agent.state(); + for ( ; ; ) { + const std::size_t cache_id = get_cache_id(node_id); + if (node_id == cache_[cache_id].child()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + restore(agent, cache_[cache_id].link()); + } else { + state.key_buf().push_back(cache_[cache_id].label()); + } + + node_id = cache_[cache_id].parent(); + if (node_id == 0) { + return; + } + continue; + } + + if (link_flags_[node_id]) { + restore(agent, get_link(node_id)); + } else { + 
state.key_buf().push_back((char)bases_[node_id]); + } + + if (node_id <= num_l1_nodes_) { + return; + } + node_id = louds_.select1(node_id) - node_id - 1; + } +} + +bool LoudsTrie::match_(Agent &agent, std::size_t node_id) const { + MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(), + MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR); + + State &state = agent.state(); + for ( ; ; ) { + const std::size_t cache_id = get_cache_id(node_id); + if (node_id == cache_[cache_id].child()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + if (!match(agent, cache_[cache_id].link())) { + return false; + } + } else if (cache_[cache_id].label() == + agent.query()[state.query_pos()]) { + state.set_query_pos(state.query_pos() + 1); + } else { + return false; + } + + node_id = cache_[cache_id].parent(); + if (node_id == 0) { + return true; + } else if (state.query_pos() >= agent.query().length()) { + return false; + } + continue; + } + + if (link_flags_[node_id]) { + if (next_trie_.get() != NULL) { + if (!match(agent, get_link(node_id))) { + return false; + } + } else if (!tail_.match(agent, get_link(node_id))) { + return false; + } + } else if (bases_[node_id] == (UInt8)agent.query()[state.query_pos()]) { + state.set_query_pos(state.query_pos() + 1); + } else { + return false; + } + + if (node_id <= num_l1_nodes_) { + return true; + } else if (state.query_pos() >= agent.query().length()) { + return false; + } + node_id = louds_.select1(node_id) - node_id - 1; + } +} + +bool LoudsTrie::prefix_match_(Agent &agent, std::size_t node_id) const { + MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(), + MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR); + + State &state = agent.state(); + for ( ; ; ) { + const std::size_t cache_id = get_cache_id(node_id); + if (node_id == cache_[cache_id].child()) { + if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) { + if (!prefix_match(agent, 
cache_[cache_id].link())) { + return false; + } + } else if (cache_[cache_id].label() == + agent.query()[state.query_pos()]) { + state.key_buf().push_back(cache_[cache_id].label()); + state.set_query_pos(state.query_pos() + 1); + } else { + return false; + } + + node_id = cache_[cache_id].parent(); + if (node_id == 0) { + return true; + } + } else { + if (link_flags_[node_id]) { + if (!prefix_match(agent, get_link(node_id))) { + return false; + } + } else if (bases_[node_id] == (UInt8)agent.query()[state.query_pos()]) { + state.key_buf().push_back((char)bases_[node_id]); + state.set_query_pos(state.query_pos() + 1); + } else { + return false; + } + + if (node_id <= num_l1_nodes_) { + return true; + } + node_id = louds_.select1(node_id) - node_id - 1; + } + + if (state.query_pos() >= agent.query().length()) { + restore_(agent, node_id); + return true; + } + } +} + +std::size_t LoudsTrie::get_cache_id(std::size_t node_id, char label) const { + return (node_id ^ (node_id << 5) ^ (UInt8)label) & cache_mask_; +} + +std::size_t LoudsTrie::get_cache_id(std::size_t node_id) const { + return node_id & cache_mask_; +} + +std::size_t LoudsTrie::get_link(std::size_t node_id) const { + return bases_[node_id] | (extras_[link_flags_.rank1(node_id)] * 256); +} + +std::size_t LoudsTrie::get_link(std::size_t node_id, + std::size_t link_id) const { + return bases_[node_id] | (extras_[link_id] * 256); +} + +std::size_t LoudsTrie::update_link_id(std::size_t link_id, + std::size_t node_id) const { + return (link_id == MARISA_INVALID_LINK_ID) ? 
+ link_flags_.rank1(node_id) : (link_id + 1); +} + +} // namespace trie +} // namespace grimoire +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h new file mode 100644 index 0000000000..5a757ac8fc --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/louds-trie.h @@ -0,0 +1,135 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_LOUDS_TRIE_H_ +#define MARISA_GRIMOIRE_TRIE_LOUDS_TRIE_H_ + +#include "../../keyset.h" +#include "../../agent.h" +#include "../vector.h" +#include "config.h" +#include "key.h" +#include "tail.h" +#include "cache.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class LoudsTrie { + public: + LoudsTrie(); + ~LoudsTrie(); + + void build(Keyset &keyset, int flags); + + void map(Mapper &mapper); + void read(Reader &reader); + void write(Writer &writer) const; + + bool lookup(Agent &agent) const; + void reverse_lookup(Agent &agent) const; + bool common_prefix_search(Agent &agent) const; + bool predictive_search(Agent &agent) const; + + std::size_t num_tries() const { + return config_.num_tries(); + } + std::size_t num_keys() const { + return size(); + } + std::size_t num_nodes() const { + return (louds_.size() / 2) - 1; + } + + CacheLevel cache_level() const { + return config_.cache_level(); + } + TailMode tail_mode() const { + return config_.tail_mode(); + } + NodeOrder node_order() const { + return config_.node_order(); + } + + bool empty() const { + return size() == 0; + } + std::size_t size() const { + return terminal_flags_.num_1s(); + } + std::size_t total_size() const; + std::size_t io_size() const; + + void clear(); + void swap(LoudsTrie &rhs); + + private: + BitVector louds_; + BitVector terminal_flags_; + BitVector link_flags_; + Vector<UInt8> bases_; + FlatVector extras_; + Tail tail_; + scoped_ptr<LoudsTrie> next_trie_; + Vector<Cache> cache_; + std::size_t cache_mask_; + std::size_t num_l1_nodes_; + 
Config config_; + Mapper mapper_; + + void build_(Keyset &keyset, const Config &config); + + template <typename T> + void build_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id); + template <typename T> + void build_current_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id); + template <typename T> + void build_next_trie(Vector<T> &keys, + Vector<UInt32> *terminals, const Config &config, std::size_t trie_id); + template <typename T> + void build_terminals(const Vector<T> &keys, + Vector<UInt32> *terminals) const; + + void reserve_cache(const Config &config, std::size_t trie_id, + std::size_t num_keys); + template <typename T> + void cache(std::size_t parent, std::size_t child, + float weight, char label); + void fill_cache(); + + void map_(Mapper &mapper); + void read_(Reader &reader); + void write_(Writer &writer) const; + + inline bool find_child(Agent &agent) const; + inline bool predictive_find_child(Agent &agent) const; + + inline void restore(Agent &agent, std::size_t node_id) const; + inline bool match(Agent &agent, std::size_t node_id) const; + inline bool prefix_match(Agent &agent, std::size_t node_id) const; + + void restore_(Agent &agent, std::size_t node_id) const; + bool match_(Agent &agent, std::size_t node_id) const; + bool prefix_match_(Agent &agent, std::size_t node_id) const; + + inline std::size_t get_cache_id(std::size_t node_id, char label) const; + inline std::size_t get_cache_id(std::size_t node_id) const; + + inline std::size_t get_link(std::size_t node_id) const; + inline std::size_t get_link(std::size_t node_id, + std::size_t link_id) const; + + inline std::size_t update_link_id(std::size_t link_id, + std::size_t node_id) const; + + // Disallows copy and assignment. 
+ LoudsTrie(const LoudsTrie &); + LoudsTrie &operator=(const LoudsTrie &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_LOUDS_TRIE_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/range.h b/contrib/python/marisa-trie/marisa/grimoire/trie/range.h new file mode 100644 index 0000000000..4ca39a9c37 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/range.h @@ -0,0 +1,116 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_RANGE_H_ +#define MARISA_GRIMOIRE_TRIE_RANGE_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Range { + public: + Range() : begin_(0), end_(0), key_pos_(0) {} + + void set_begin(std::size_t begin) { + MARISA_DEBUG_IF(begin > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + begin_ = static_cast<UInt32>(begin); + } + void set_end(std::size_t end) { + MARISA_DEBUG_IF(end > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + end_ = static_cast<UInt32>(end); + } + void set_key_pos(std::size_t key_pos) { + MARISA_DEBUG_IF(key_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + key_pos_ = static_cast<UInt32>(key_pos); + } + + std::size_t begin() const { + return begin_; + } + std::size_t end() const { + return end_; + } + std::size_t key_pos() const { + return key_pos_; + } + + private: + UInt32 begin_; + UInt32 end_; + UInt32 key_pos_; +}; + +inline Range make_range(std::size_t begin, std::size_t end, + std::size_t key_pos) { + Range range; + range.set_begin(begin); + range.set_end(end); + range.set_key_pos(key_pos); + return range; +} + +class WeightedRange { + public: + WeightedRange() : range_(), weight_(0.0F) {} + + void set_range(const Range &range) { + range_ = range; + } + void set_begin(std::size_t begin) { + range_.set_begin(begin); + } + void set_end(std::size_t end) { + range_.set_end(end); + } + void set_key_pos(std::size_t key_pos) { + range_.set_key_pos(key_pos); + } + void set_weight(float weight) { + weight_ = weight; + } + 
+ const Range &range() const { + return range_; + } + std::size_t begin() const { + return range_.begin(); + } + std::size_t end() const { + return range_.end(); + } + std::size_t key_pos() const { + return range_.key_pos(); + } + float weight() const { + return weight_; + } + + private: + Range range_; + float weight_; +}; + +inline bool operator<(const WeightedRange &lhs, const WeightedRange &rhs) { + return lhs.weight() < rhs.weight(); +} + +inline bool operator>(const WeightedRange &lhs, const WeightedRange &rhs) { + return lhs.weight() > rhs.weight(); +} + +inline WeightedRange make_weighted_range(std::size_t begin, std::size_t end, + std::size_t key_pos, float weight) { + WeightedRange range; + range.set_begin(begin); + range.set_end(end); + range.set_key_pos(key_pos); + range.set_weight(weight); + return range; +} + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_RANGE_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/state.h b/contrib/python/marisa-trie/marisa/grimoire/trie/state.h new file mode 100644 index 0000000000..219bf9e03a --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/state.h @@ -0,0 +1,118 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_STATE_H_ +#define MARISA_GRIMOIRE_TRIE_STATE_H_ + +#include "../vector.h" +#include "history.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +// A search agent has its internal state and the status codes are defined +// below. 
+typedef enum StatusCode { + MARISA_READY_TO_ALL, + MARISA_READY_TO_COMMON_PREFIX_SEARCH, + MARISA_READY_TO_PREDICTIVE_SEARCH, + MARISA_END_OF_COMMON_PREFIX_SEARCH, + MARISA_END_OF_PREDICTIVE_SEARCH, +} StatusCode; + +class State { + public: + State() + : key_buf_(), history_(), node_id_(0), query_pos_(0), + history_pos_(0), status_code_(MARISA_READY_TO_ALL) {} + + void set_node_id(std::size_t node_id) { + MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + node_id_ = (UInt32)node_id; + } + void set_query_pos(std::size_t query_pos) { + MARISA_DEBUG_IF(query_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + query_pos_ = (UInt32)query_pos; + } + void set_history_pos(std::size_t history_pos) { + MARISA_DEBUG_IF(history_pos > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + history_pos_ = (UInt32)history_pos; + } + void set_status_code(StatusCode status_code) { + status_code_ = status_code; + } + + std::size_t node_id() const { + return node_id_; + } + std::size_t query_pos() const { + return query_pos_; + } + std::size_t history_pos() const { + return history_pos_; + } + StatusCode status_code() const { + return status_code_; + } + + const Vector<char> &key_buf() const { + return key_buf_; + } + const Vector<History> &history() const { + return history_; + } + + Vector<char> &key_buf() { + return key_buf_; + } + Vector<History> &history() { + return history_; + } + + void reset() { + status_code_ = MARISA_READY_TO_ALL; + } + + void lookup_init() { + node_id_ = 0; + query_pos_ = 0; + status_code_ = MARISA_READY_TO_ALL; + } + void reverse_lookup_init() { + key_buf_.resize(0); + key_buf_.reserve(32); + status_code_ = MARISA_READY_TO_ALL; + } + void common_prefix_search_init() { + node_id_ = 0; + query_pos_ = 0; + status_code_ = MARISA_READY_TO_COMMON_PREFIX_SEARCH; + } + void predictive_search_init() { + key_buf_.resize(0); + key_buf_.reserve(64); + history_.resize(0); + history_.reserve(4); + node_id_ = 0; + query_pos_ = 0; + history_pos_ = 0; + status_code_ = 
MARISA_READY_TO_PREDICTIVE_SEARCH; + } + + private: + Vector<char> key_buf_; + Vector<History> history_; + UInt32 node_id_; + UInt32 query_pos_; + UInt32 history_pos_; + StatusCode status_code_; + + // Disallows copy and assignment. + State(const State &); + State &operator=(const State &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_STATE_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc new file mode 100644 index 0000000000..6ec3652e1c --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.cc @@ -0,0 +1,218 @@ +#include "../algorithm.h" +#include "state.h" +#include "tail.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +Tail::Tail() : buf_(), end_flags_() {} + +void Tail::build(Vector<Entry> &entries, Vector<UInt32> *offsets, + TailMode mode) { + MARISA_THROW_IF(offsets == NULL, MARISA_NULL_ERROR); + + switch (mode) { + case MARISA_TEXT_TAIL: { + for (std::size_t i = 0; i < entries.size(); ++i) { + const char * const ptr = entries[i].ptr(); + const std::size_t length = entries[i].length(); + for (std::size_t j = 0; j < length; ++j) { + if (ptr[j] == '\0') { + mode = MARISA_BINARY_TAIL; + break; + } + } + if (mode == MARISA_BINARY_TAIL) { + break; + } + } + break; + } + case MARISA_BINARY_TAIL: { + break; + } + default: { + MARISA_THROW(MARISA_CODE_ERROR, "undefined tail mode"); + } + } + + Tail temp; + temp.build_(entries, offsets, mode); + swap(temp); +} + +void Tail::map(Mapper &mapper) { + Tail temp; + temp.map_(mapper); + swap(temp); +} + +void Tail::read(Reader &reader) { + Tail temp; + temp.read_(reader); + swap(temp); +} + +void Tail::write(Writer &writer) const { + write_(writer); +} + +void Tail::restore(Agent &agent, std::size_t offset) const { + MARISA_DEBUG_IF(buf_.empty(), MARISA_STATE_ERROR); + + State &state = agent.state(); + if (end_flags_.empty()) { + for (const char 
*ptr = &buf_[offset]; *ptr != '\0'; ++ptr) { + state.key_buf().push_back(*ptr); + } + } else { + do { + state.key_buf().push_back(buf_[offset]); + } while (!end_flags_[offset++]); + } +} + +bool Tail::match(Agent &agent, std::size_t offset) const { + MARISA_DEBUG_IF(buf_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(), + MARISA_BOUND_ERROR); + + State &state = agent.state(); + if (end_flags_.empty()) { + const char * const ptr = &buf_[offset] - state.query_pos(); + do { + if (ptr[state.query_pos()] != agent.query()[state.query_pos()]) { + return false; + } + state.set_query_pos(state.query_pos() + 1); + if (ptr[state.query_pos()] == '\0') { + return true; + } + } while (state.query_pos() < agent.query().length()); + return false; + } else { + do { + if (buf_[offset] != agent.query()[state.query_pos()]) { + return false; + } + state.set_query_pos(state.query_pos() + 1); + if (end_flags_[offset++]) { + return true; + } + } while (state.query_pos() < agent.query().length()); + return false; + } +} + +bool Tail::prefix_match(Agent &agent, std::size_t offset) const { + MARISA_DEBUG_IF(buf_.empty(), MARISA_STATE_ERROR); + + State &state = agent.state(); + if (end_flags_.empty()) { + const char *ptr = &buf_[offset] - state.query_pos(); + do { + if (ptr[state.query_pos()] != agent.query()[state.query_pos()]) { + return false; + } + state.key_buf().push_back(ptr[state.query_pos()]); + state.set_query_pos(state.query_pos() + 1); + if (ptr[state.query_pos()] == '\0') { + return true; + } + } while (state.query_pos() < agent.query().length()); + ptr += state.query_pos(); + do { + state.key_buf().push_back(*ptr); + } while (*++ptr != '\0'); + return true; + } else { + do { + if (buf_[offset] != agent.query()[state.query_pos()]) { + return false; + } + state.key_buf().push_back(buf_[offset]); + state.set_query_pos(state.query_pos() + 1); + if (end_flags_[offset++]) { + return true; + } + } while (state.query_pos() < 
agent.query().length()); + do { + state.key_buf().push_back(buf_[offset]); + } while (!end_flags_[offset++]); + return true; + } +} + +void Tail::clear() { + Tail().swap(*this); +} + +void Tail::swap(Tail &rhs) { + buf_.swap(rhs.buf_); + end_flags_.swap(rhs.end_flags_); +} + +void Tail::build_(Vector<Entry> &entries, Vector<UInt32> *offsets, + TailMode mode) { + for (std::size_t i = 0; i < entries.size(); ++i) { + entries[i].set_id(i); + } + Algorithm().sort(entries.begin(), entries.end()); + + Vector<UInt32> temp_offsets; + temp_offsets.resize(entries.size(), 0); + + const Entry dummy; + const Entry *last = &dummy; + for (std::size_t i = entries.size(); i > 0; --i) { + const Entry &current = entries[i - 1]; + MARISA_THROW_IF(current.length() == 0, MARISA_RANGE_ERROR); + std::size_t match = 0; + while ((match < current.length()) && (match < last->length()) && + ((*last)[match] == current[match])) { + ++match; + } + if ((match == current.length()) && (last->length() != 0)) { + temp_offsets[current.id()] = (UInt32)( + temp_offsets[last->id()] + (last->length() - match)); + } else { + temp_offsets[current.id()] = (UInt32)buf_.size(); + for (std::size_t j = 1; j <= current.length(); ++j) { + buf_.push_back(current[current.length() - j]); + } + if (mode == MARISA_TEXT_TAIL) { + buf_.push_back('\0'); + } else { + for (std::size_t j = 1; j < current.length(); ++j) { + end_flags_.push_back(false); + } + end_flags_.push_back(true); + } + MARISA_THROW_IF(buf_.size() > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + } + last = &current; + } + buf_.shrink(); + + offsets->swap(temp_offsets); +} + +void Tail::map_(Mapper &mapper) { + buf_.map(mapper); + end_flags_.map(mapper); +} + +void Tail::read_(Reader &reader) { + buf_.read(reader); + end_flags_.read(reader); +} + +void Tail::write_(Writer &writer) const { + buf_.write(writer); + end_flags_.write(writer); +} + +} // namespace trie +} // namespace grimoire +} // namespace marisa diff --git
a/contrib/python/marisa-trie/marisa/grimoire/trie/tail.h b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.h new file mode 100644 index 0000000000..7e5ca1d3e7 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/trie/tail.h @@ -0,0 +1,73 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_TRIE_TAIL_H_ +#define MARISA_GRIMOIRE_TRIE_TAIL_H_ + +#include "../../agent.h" +#include "../vector.h" +#include "entry.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class Tail { + public: + Tail(); + + void build(Vector<Entry> &entries, Vector<UInt32> *offsets, + TailMode mode); + + void map(Mapper &mapper); + void read(Reader &reader); + void write(Writer &writer) const; + + void restore(Agent &agent, std::size_t offset) const; + bool match(Agent &agent, std::size_t offset) const; + bool prefix_match(Agent &agent, std::size_t offset) const; + + const char &operator[](std::size_t offset) const { + MARISA_DEBUG_IF(offset >= buf_.size(), MARISA_BOUND_ERROR); + return buf_[offset]; + } + + TailMode mode() const { + return end_flags_.empty() ? MARISA_TEXT_TAIL : MARISA_BINARY_TAIL; + } + + bool empty() const { + return buf_.empty(); + } + std::size_t size() const { + return buf_.size(); + } + std::size_t total_size() const { + return buf_.total_size() + end_flags_.total_size(); + } + std::size_t io_size() const { + return buf_.io_size() + end_flags_.io_size(); + } + + void clear(); + void swap(Tail &rhs); + + private: + Vector<char> buf_; + BitVector end_flags_; + + void build_(Vector<Entry> &entries, Vector<UInt32> *offsets, + TailMode mode); + + void map_(Mapper &mapper); + void read_(Reader &reader); + void write_(Writer &writer) const; + + // Disallows copy and assignment. 
+ Tail(const Tail &); + Tail &operator=(const Tail &); +}; + +} // namespace trie +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_TRIE_TAIL_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector.h b/contrib/python/marisa-trie/marisa/grimoire/vector.h new file mode 100644 index 0000000000..d942a7f279 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector.h @@ -0,0 +1,19 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_H_ +#define MARISA_GRIMOIRE_VECTOR_H_ + +#include "vector/vector.h" +#include "vector/flat-vector.h" +#include "vector/bit-vector.h" + +namespace marisa { +namespace grimoire { + +using vector::Vector; +typedef vector::FlatVector FlatVector; +typedef vector::BitVector BitVector; + +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.cc b/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.cc new file mode 100644 index 0000000000..a5abc69319 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.cc @@ -0,0 +1,825 @@ +#include "pop-count.h" +#include "bit-vector.h" + +namespace marisa { +namespace grimoire { +namespace vector { +namespace { + +const UInt8 SELECT_TABLE[8][256] = { + { + 7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 
2, 0, 1, 0, + 5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, + 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0 + }, + { + 7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1, + 7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, + 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1, + 6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, + 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1, + 7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, + 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1, + 6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1, + 6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1, + 5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1 + }, + { + 7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2, + 7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2, + 7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2, + 7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2, + 7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2, + 6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2, + 7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2, + 7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2, + 7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2, + 7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2, + 7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2, + 6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3, + 7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3, + 7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3, + 7, 7, 7, 6, 
7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3, + 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3, + 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3, + 7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3, + 7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3, + 7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3, + 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3, + 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, + 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5, + 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, + 7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5, + 7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 
7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6 + }, + { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 + } +}; + +#if MARISA_WORD_SIZE == 64 +const UInt64 MASK_55 = 0x5555555555555555ULL; +const UInt64 MASK_33 = 0x3333333333333333ULL; +const UInt64 MASK_0F = 0x0F0F0F0F0F0F0F0FULL; +const UInt64 MASK_01 = 0x0101010101010101ULL; 
+const UInt64 MASK_80 = 0x8080808080808080ULL; + +std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) { + UInt64 counts; + { + #if defined(MARISA_X64) && defined(MARISA_USE_SSSE3) + __m128i lower_nibbles = _mm_cvtsi64_si128(unit & 0x0F0F0F0F0F0F0F0FULL); + __m128i upper_nibbles = _mm_cvtsi64_si128(unit & 0xF0F0F0F0F0F0F0F0ULL); + upper_nibbles = _mm_srli_epi32(upper_nibbles, 4); + + __m128i lower_counts = + _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0); + lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles); + __m128i upper_counts = + _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0); + upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles); + + counts = _mm_cvtsi128_si64(_mm_add_epi8(lower_counts, upper_counts)); + #else // defined(MARISA_X64) && defined(MARISA_USE_SSSE3) + counts = unit - ((unit >> 1) & MASK_55); + counts = (counts & MASK_33) + ((counts >> 2) & MASK_33); + counts = (counts + (counts >> 4)) & MASK_0F; + #endif // defined(MARISA_X64) && defined(MARISA_USE_SSSE3) + counts *= MASK_01; + } + + #if defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + UInt8 skip; + { + __m128i x = _mm_cvtsi64_si128((i + 1) * MASK_01); + __m128i y = _mm_cvtsi64_si128(counts); + x = _mm_cmpgt_epi8(x, y); + skip = (UInt8)PopCount::count(_mm_cvtsi128_si64(x)); + } + #else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + const UInt64 x = (counts | MASK_80) - ((i + 1) * MASK_01); + #ifdef _MSC_VER + unsigned long skip; + ::_BitScanForward64(&skip, (x & MASK_80) >> 7); + #else // _MSC_VER + const int skip = ::__builtin_ctzll((x & MASK_80) >> 7); + #endif // _MSC_VER + #endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + + bit_id += skip; + unit >>= skip; + i -= ((counts << 8) >> skip) & 0xFF; + + return bit_id + SELECT_TABLE[i][unit & 0xFF]; +} +#else // MARISA_WORD_SIZE == 64 + #ifdef MARISA_USE_SSE2 +const UInt8 POPCNT_TABLE[256] = { + 0, 8, 8, 16, 8, 16, 16, 24, 8, 16, 16, 24, 16, 24, 24, 32, + 8, 16, 
16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, + 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, + 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, + 16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48, + 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, + 24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56, + 32, 40, 40, 48, 40, 48, 48, 56, 40, 48, 48, 56, 48, 56, 56, 64 +}; + +std::size_t select_bit(std::size_t i, std::size_t bit_id, + UInt32 unit_lo, UInt32 unit_hi) { + __m128i unit; + { + __m128i lower_dword = _mm_cvtsi32_si128(unit_lo); + __m128i upper_dword = _mm_cvtsi32_si128(unit_hi); + upper_dword = _mm_slli_si128(upper_dword, 4); + unit = _mm_or_si128(lower_dword, upper_dword); + } + + __m128i counts; + { + #ifdef MARISA_USE_SSSE3 + __m128i lower_nibbles = _mm_set1_epi8(0x0F); + lower_nibbles = _mm_and_si128(lower_nibbles, unit); + __m128i upper_nibbles = _mm_set1_epi8((UInt8)0xF0); + upper_nibbles = _mm_and_si128(upper_nibbles, unit); + upper_nibbles = _mm_srli_epi32(upper_nibbles, 4); + + __m128i lower_counts = + _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0); + lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles); + __m128i upper_counts = + _mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0); + upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles); + + counts = _mm_add_epi8(lower_counts, upper_counts); + #else // 
MARISA_USE_SSSE3 + __m128i x = _mm_srli_epi32(unit, 1); + x = _mm_and_si128(x, _mm_set1_epi8(0x55)); + x = _mm_sub_epi8(unit, x); + + __m128i y = _mm_srli_epi32(x, 2); + y = _mm_and_si128(y, _mm_set1_epi8(0x33)); + x = _mm_and_si128(x, _mm_set1_epi8(0x33)); + x = _mm_add_epi8(x, y); + + y = _mm_srli_epi32(x, 4); + x = _mm_add_epi8(x, y); + counts = _mm_and_si128(x, _mm_set1_epi8(0x0F)); + #endif // MARISA_USE_SSSE3 + } + + __m128i accumulated_counts; + { + __m128i x = counts; + x = _mm_slli_si128(x, 1); + __m128i y = counts; + y = _mm_add_epi32(y, x); + + x = y; + y = _mm_slli_si128(y, 2); + x = _mm_add_epi32(x, y); + + y = x; + x = _mm_slli_si128(x, 4); + y = _mm_add_epi32(y, x); + + accumulated_counts = _mm_set_epi32(0x7F7F7F7FU, 0x7F7F7F7FU, 0, 0); + accumulated_counts = _mm_or_si128(accumulated_counts, y); + } + + UInt8 skip; + { + __m128i x = _mm_set1_epi8((UInt8)(i + 1)); + x = _mm_cmpgt_epi8(x, accumulated_counts); + skip = POPCNT_TABLE[_mm_movemask_epi8(x)]; + } + + UInt8 byte; + { + #ifdef _MSC_VER + __declspec(align(16)) UInt8 unit_bytes[16]; + __declspec(align(16)) UInt8 accumulated_counts_bytes[16]; + #else // _MSC_VER + UInt8 unit_bytes[16] __attribute__ ((aligned (16))); + UInt8 accumulated_counts_bytes[16] __attribute__ ((aligned (16))); + #endif // _MSC_VER + accumulated_counts = _mm_slli_si128(accumulated_counts, 1); + _mm_store_si128(reinterpret_cast<__m128i *>(unit_bytes), unit); + _mm_store_si128(reinterpret_cast<__m128i *>(accumulated_counts_bytes), + accumulated_counts); + + bit_id += skip; + byte = unit_bytes[skip / 8]; + i -= accumulated_counts_bytes[skip / 8]; + } + + return bit_id + SELECT_TABLE[i][byte]; +} + #endif // MARISA_USE_SSE2 +#endif // MARISA_WORD_SIZE == 64 + +} // namespace + +#if MARISA_WORD_SIZE == 64 + +std::size_t BitVector::rank1(std::size_t i) const { + MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR); + + const RankIndex &rank = ranks_[i / 512]; + std::size_t offset = 
rank.abs(); + switch ((i / 64) % 8) { + case 1: { + offset += rank.rel1(); + break; + } + case 2: { + offset += rank.rel2(); + break; + } + case 3: { + offset += rank.rel3(); + break; + } + case 4: { + offset += rank.rel4(); + break; + } + case 5: { + offset += rank.rel5(); + break; + } + case 6: { + offset += rank.rel6(); + break; + } + case 7: { + offset += rank.rel7(); + break; + } + } + offset += PopCount::count(units_[i / 64] & ((1ULL << (i % 64)) - 1)); + return offset; +} + +std::size_t BitVector::select0(std::size_t i) const { + MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR); + + const std::size_t select_id = i / 512; + MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR); + if ((i % 512) == 0) { + return select0s_[select_id]; + } + std::size_t begin = select0s_[select_id] / 512; + std::size_t end = (select0s_[select_id + 1] + 511) / 512; + if (begin + 10 >= end) { + while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) { + ++begin; + } + } else { + while (begin + 1 < end) { + const std::size_t middle = (begin + end) / 2; + if (i < (middle * 512) - ranks_[middle].abs()) { + end = middle; + } else { + begin = middle; + } + } + } + const std::size_t rank_id = begin; + i -= (rank_id * 512) - ranks_[rank_id].abs(); + + const RankIndex &rank = ranks_[rank_id]; + std::size_t unit_id = rank_id * 8; + if (i < (256U - rank.rel4())) { + if (i < (128U - rank.rel2())) { + if (i >= (64U - rank.rel1())) { + unit_id += 1; + i -= 64 - rank.rel1(); + } + } else if (i < (192U - rank.rel3())) { + unit_id += 2; + i -= 128 - rank.rel2(); + } else { + unit_id += 3; + i -= 192 - rank.rel3(); + } + } else if (i < (384U - rank.rel6())) { + if (i < (320U - rank.rel5())) { + unit_id += 4; + i -= 256 - rank.rel4(); + } else { + unit_id += 5; + i -= 320 - rank.rel5(); + } + } else if (i < (448U - rank.rel7())) { + unit_id += 6; + i -= 384 - rank.rel6(); + } else { + unit_id += 7; + i -= 448 - 
rank.rel7(); + } + + return select_bit(i, unit_id * 64, ~units_[unit_id]); +} + +std::size_t BitVector::select1(std::size_t i) const { + MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR); + + const std::size_t select_id = i / 512; + MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR); + if ((i % 512) == 0) { + return select1s_[select_id]; + } + std::size_t begin = select1s_[select_id] / 512; + std::size_t end = (select1s_[select_id + 1] + 511) / 512; + if (begin + 10 >= end) { + while (i >= ranks_[begin + 1].abs()) { + ++begin; + } + } else { + while (begin + 1 < end) { + const std::size_t middle = (begin + end) / 2; + if (i < ranks_[middle].abs()) { + end = middle; + } else { + begin = middle; + } + } + } + const std::size_t rank_id = begin; + i -= ranks_[rank_id].abs(); + + const RankIndex &rank = ranks_[rank_id]; + std::size_t unit_id = rank_id * 8; + if (i < rank.rel4()) { + if (i < rank.rel2()) { + if (i >= rank.rel1()) { + unit_id += 1; + i -= rank.rel1(); + } + } else if (i < rank.rel3()) { + unit_id += 2; + i -= rank.rel2(); + } else { + unit_id += 3; + i -= rank.rel3(); + } + } else if (i < rank.rel6()) { + if (i < rank.rel5()) { + unit_id += 4; + i -= rank.rel4(); + } else { + unit_id += 5; + i -= rank.rel5(); + } + } else if (i < rank.rel7()) { + unit_id += 6; + i -= rank.rel6(); + } else { + unit_id += 7; + i -= rank.rel7(); + } + + return select_bit(i, unit_id * 64, units_[unit_id]); +} + +#else // MARISA_WORD_SIZE == 64 + +std::size_t BitVector::rank1(std::size_t i) const { + MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR); + + const RankIndex &rank = ranks_[i / 512]; + std::size_t offset = rank.abs(); + switch ((i / 64) % 8) { + case 1: { + offset += rank.rel1(); + break; + } + case 2: { + offset += rank.rel2(); + break; + } + case 3: { + offset += rank.rel3(); + break; + } + case 4: { + offset += rank.rel4(); + break; 
+ } + case 5: { + offset += rank.rel5(); + break; + } + case 6: { + offset += rank.rel6(); + break; + } + case 7: { + offset += rank.rel7(); + break; + } + } + if (((i / 32) & 1) == 1) { + offset += PopCount::count(units_[(i / 32) - 1]); + } + offset += PopCount::count(units_[i / 32] & ((1U << (i % 32)) - 1)); + return offset; +} + +std::size_t BitVector::select0(std::size_t i) const { + MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR); + + const std::size_t select_id = i / 512; + MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR); + if ((i % 512) == 0) { + return select0s_[select_id]; + } + std::size_t begin = select0s_[select_id] / 512; + std::size_t end = (select0s_[select_id + 1] + 511) / 512; + if (begin + 10 >= end) { + while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) { + ++begin; + } + } else { + while (begin + 1 < end) { + const std::size_t middle = (begin + end) / 2; + if (i < (middle * 512) - ranks_[middle].abs()) { + end = middle; + } else { + begin = middle; + } + } + } + const std::size_t rank_id = begin; + i -= (rank_id * 512) - ranks_[rank_id].abs(); + + const RankIndex &rank = ranks_[rank_id]; + std::size_t unit_id = rank_id * 16; + if (i < (256U - rank.rel4())) { + if (i < (128U - rank.rel2())) { + if (i >= (64U - rank.rel1())) { + unit_id += 2; + i -= 64 - rank.rel1(); + } + } else if (i < (192U - rank.rel3())) { + unit_id += 4; + i -= 128 - rank.rel2(); + } else { + unit_id += 6; + i -= 192 - rank.rel3(); + } + } else if (i < (384U - rank.rel6())) { + if (i < (320U - rank.rel5())) { + unit_id += 8; + i -= 256 - rank.rel4(); + } else { + unit_id += 10; + i -= 320 - rank.rel5(); + } + } else if (i < (448U - rank.rel7())) { + unit_id += 12; + i -= 384 - rank.rel6(); + } else { + unit_id += 14; + i -= 448 - rank.rel7(); + } + +#ifdef MARISA_USE_SSE2 + return select_bit(i, unit_id * 32, ~units_[unit_id], ~units_[unit_id + 1]); +#else // MARISA_USE_SSE2 + UInt32 
unit = ~units_[unit_id]; + PopCount count(unit); + if (i >= count.lo32()) { + ++unit_id; + i -= count.lo32(); + unit = ~units_[unit_id]; + count = PopCount(unit); + } + + std::size_t bit_id = unit_id * 32; + if (i < count.lo16()) { + if (i >= count.lo8()) { + bit_id += 8; + unit >>= 8; + i -= count.lo8(); + } + } else if (i < count.lo24()) { + bit_id += 16; + unit >>= 16; + i -= count.lo16(); + } else { + bit_id += 24; + unit >>= 24; + i -= count.lo24(); + } + return bit_id + SELECT_TABLE[i][unit & 0xFF]; +#endif // MARISA_USE_SSE2 +} + +std::size_t BitVector::select1(std::size_t i) const { + MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR); + + const std::size_t select_id = i / 512; + MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR); + if ((i % 512) == 0) { + return select1s_[select_id]; + } + std::size_t begin = select1s_[select_id] / 512; + std::size_t end = (select1s_[select_id + 1] + 511) / 512; + if (begin + 10 >= end) { + while (i >= ranks_[begin + 1].abs()) { + ++begin; + } + } else { + while (begin + 1 < end) { + const std::size_t middle = (begin + end) / 2; + if (i < ranks_[middle].abs()) { + end = middle; + } else { + begin = middle; + } + } + } + const std::size_t rank_id = begin; + i -= ranks_[rank_id].abs(); + + const RankIndex &rank = ranks_[rank_id]; + std::size_t unit_id = rank_id * 16; + if (i < rank.rel4()) { + if (i < rank.rel2()) { + if (i >= rank.rel1()) { + unit_id += 2; + i -= rank.rel1(); + } + } else if (i < rank.rel3()) { + unit_id += 4; + i -= rank.rel2(); + } else { + unit_id += 6; + i -= rank.rel3(); + } + } else if (i < rank.rel6()) { + if (i < rank.rel5()) { + unit_id += 8; + i -= rank.rel4(); + } else { + unit_id += 10; + i -= rank.rel5(); + } + } else if (i < rank.rel7()) { + unit_id += 12; + i -= rank.rel6(); + } else { + unit_id += 14; + i -= rank.rel7(); + } + +#ifdef MARISA_USE_SSE2 + return select_bit(i, unit_id * 32, units_[unit_id], 
units_[unit_id + 1]); +#else // MARISA_USE_SSE2 + UInt32 unit = units_[unit_id]; + PopCount count(unit); + if (i >= count.lo32()) { + ++unit_id; + i -= count.lo32(); + unit = units_[unit_id]; + count = PopCount(unit); + } + + std::size_t bit_id = unit_id * 32; + if (i < count.lo16()) { + if (i >= count.lo8()) { + bit_id += 8; + unit >>= 8; + i -= count.lo8(); + } + } else if (i < count.lo24()) { + bit_id += 16; + unit >>= 16; + i -= count.lo16(); + } else { + bit_id += 24; + unit >>= 24; + i -= count.lo24(); + } + return bit_id + SELECT_TABLE[i][unit & 0xFF]; +#endif // MARISA_USE_SSE2 +} + +#endif // MARISA_WORD_SIZE == 64 + +void BitVector::build_index(const BitVector &bv, + bool enables_select0, bool enables_select1) { + ranks_.resize((bv.size() / 512) + (((bv.size() % 512) != 0) ? 1 : 0) + 1); + + std::size_t num_0s = 0; + std::size_t num_1s = 0; + + for (std::size_t i = 0; i < bv.size(); ++i) { + if ((i % 64) == 0) { + const std::size_t rank_id = i / 512; + switch ((i / 64) % 8) { + case 0: { + ranks_[rank_id].set_abs(num_1s); + break; + } + case 1: { + ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs()); + break; + } + case 2: { + ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs()); + break; + } + case 3: { + ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs()); + break; + } + case 4: { + ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs()); + break; + } + case 5: { + ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs()); + break; + } + case 6: { + ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs()); + break; + } + case 7: { + ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs()); + break; + } + } + } + + if (bv[i]) { + if (enables_select1 && ((num_1s % 512) == 0)) { + select1s_.push_back(static_cast<UInt32>(i)); + } + ++num_1s; + } else { + if (enables_select0 && ((num_0s % 512) == 0)) { + select0s_.push_back(static_cast<UInt32>(i)); + } + ++num_0s; + } + } + + if ((bv.size() % 512) != 0) { + const std::size_t rank_id = 
(bv.size() - 1) / 512; + switch (((bv.size() - 1) / 64) % 8) { + case 0: { + ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs()); + } + case 1: { + ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs()); + } + case 2: { + ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs()); + } + case 3: { + ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs()); + } + case 4: { + ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs()); + } + case 5: { + ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs()); + } + case 6: { + ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs()); + break; + } + } + } + + size_ = bv.size(); + num_1s_ = bv.num_1s(); + + ranks_.back().set_abs(num_1s); + if (enables_select0) { + select0s_.push_back(static_cast<UInt32>(bv.size())); + select0s_.shrink(); + } + if (enables_select1) { + select1s_.push_back(static_cast<UInt32>(bv.size())); + select1s_.shrink(); + } +} + +} // namespace vector +} // namespace grimoire +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.h b/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.h new file mode 100644 index 0000000000..56f99ed699 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/bit-vector.h @@ -0,0 +1,180 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_BIT_VECTOR_H_ +#define MARISA_GRIMOIRE_VECTOR_BIT_VECTOR_H_ + +#include "rank-index.h" +#include "vector.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +class BitVector { + public: +#if MARISA_WORD_SIZE == 64 + typedef UInt64 Unit; +#else // MARISA_WORD_SIZE == 64 + typedef UInt32 Unit; +#endif // MARISA_WORD_SIZE == 64 + + BitVector() + : units_(), size_(0), num_1s_(0), ranks_(), select0s_(), select1s_() {} + + void build(bool enables_select0, bool enables_select1) { + BitVector temp; + temp.build_index(*this, enables_select0, enables_select1); + units_.shrink(); + temp.units_.swap(units_); + swap(temp); + } + + void map(Mapper &mapper) { + 
BitVector temp; + temp.map_(mapper); + swap(temp); + } + void read(Reader &reader) { + BitVector temp; + temp.read_(reader); + swap(temp); + } + void write(Writer &writer) const { + write_(writer); + } + + void disable_select0() { + select0s_.clear(); + } + void disable_select1() { + select1s_.clear(); + } + + void push_back(bool bit) { + MARISA_THROW_IF(size_ == MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + if (size_ == (MARISA_WORD_SIZE * units_.size())) { + units_.resize(units_.size() + (64 / MARISA_WORD_SIZE), 0); + } + if (bit) { + units_[size_ / MARISA_WORD_SIZE] |= + (Unit)1 << (size_ % MARISA_WORD_SIZE); + ++num_1s_; + } + ++size_; + } + + bool operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + return (units_[i / MARISA_WORD_SIZE] + & ((Unit)1 << (i % MARISA_WORD_SIZE))) != 0; + } + + std::size_t rank0(std::size_t i) const { + MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR); + return i - rank1(i); + } + std::size_t rank1(std::size_t i) const; + + std::size_t select0(std::size_t i) const; + std::size_t select1(std::size_t i) const; + + std::size_t num_0s() const { + return size_ - num_1s_; + } + std::size_t num_1s() const { + return num_1s_; + } + + bool empty() const { + return size_ == 0; + } + std::size_t size() const { + return size_; + } + std::size_t total_size() const { + return units_.total_size() + ranks_.total_size() + + select0s_.total_size() + select1s_.total_size(); + } + std::size_t io_size() const { + return units_.io_size() + (sizeof(UInt32) * 2) + ranks_.io_size() + + select0s_.io_size() + select1s_.io_size(); + } + + void clear() { + BitVector().swap(*this); + } + void swap(BitVector &rhs) { + units_.swap(rhs.units_); + marisa::swap(size_, rhs.size_); + marisa::swap(num_1s_, rhs.num_1s_); + ranks_.swap(rhs.ranks_); + select0s_.swap(rhs.select0s_); + select1s_.swap(rhs.select1s_); + } + + private: + Vector<Unit> units_; + std::size_t size_; + std::size_t 
num_1s_; + Vector<RankIndex> ranks_; + Vector<UInt32> select0s_; + Vector<UInt32> select1s_; + + void build_index(const BitVector &bv, + bool enables_select0, bool enables_select1); + + void map_(Mapper &mapper) { + units_.map(mapper); + { + UInt32 temp_size; + mapper.map(&temp_size); + size_ = temp_size; + } + { + UInt32 temp_num_1s; + mapper.map(&temp_num_1s); + MARISA_THROW_IF(temp_num_1s > size_, MARISA_FORMAT_ERROR); + num_1s_ = temp_num_1s; + } + ranks_.map(mapper); + select0s_.map(mapper); + select1s_.map(mapper); + } + + void read_(Reader &reader) { + units_.read(reader); + { + UInt32 temp_size; + reader.read(&temp_size); + size_ = temp_size; + } + { + UInt32 temp_num_1s; + reader.read(&temp_num_1s); + MARISA_THROW_IF(temp_num_1s > size_, MARISA_FORMAT_ERROR); + num_1s_ = temp_num_1s; + } + ranks_.read(reader); + select0s_.read(reader); + select1s_.read(reader); + } + + void write_(Writer &writer) const { + units_.write(writer); + writer.write((UInt32)size_); + writer.write((UInt32)num_1s_); + ranks_.write(writer); + select0s_.write(writer); + select1s_.write(writer); + } + + // Disallows copy and assignment. 
+ BitVector(const BitVector &); + BitVector &operator=(const BitVector &); +}; + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_BIT_VECTOR_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/flat-vector.h b/contrib/python/marisa-trie/marisa/grimoire/vector/flat-vector.h new file mode 100644 index 0000000000..14b25d7d82 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/flat-vector.h @@ -0,0 +1,206 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_FLAT_VECTOR_H_ +#define MARISA_GRIMOIRE_VECTOR_FLAT_VECTOR_H_ + +#include "vector.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +class FlatVector { + public: +#if MARISA_WORD_SIZE == 64 + typedef UInt64 Unit; +#else // MARISA_WORD_SIZE == 64 + typedef UInt32 Unit; +#endif // MARISA_WORD_SIZE == 64 + + FlatVector() : units_(), value_size_(0), mask_(0), size_(0) {} + + void build(const Vector<UInt32> &values) { + FlatVector temp; + temp.build_(values); + swap(temp); + } + + void map(Mapper &mapper) { + FlatVector temp; + temp.map_(mapper); + swap(temp); + } + void read(Reader &reader) { + FlatVector temp; + temp.read_(reader); + swap(temp); + } + void write(Writer &writer) const { + write_(writer); + } + + UInt32 operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + + const std::size_t pos = i * value_size_; + const std::size_t unit_id = pos / MARISA_WORD_SIZE; + const std::size_t unit_offset = pos % MARISA_WORD_SIZE; + + if ((unit_offset + value_size_) <= MARISA_WORD_SIZE) { + return (UInt32)(units_[unit_id] >> unit_offset) & mask_; + } else { + return (UInt32)((units_[unit_id] >> unit_offset) + | (units_[unit_id + 1] << (MARISA_WORD_SIZE - unit_offset))) & mask_; + } + } + + std::size_t value_size() const { + return value_size_; + } + UInt32 mask() const { + return mask_; + } + + bool empty() const { + return size_ == 0; + } + std::size_t size() const { + return size_; + } + 
std::size_t total_size() const { + return units_.total_size(); + } + std::size_t io_size() const { + return units_.io_size() + (sizeof(UInt32) * 2) + sizeof(UInt64); + } + + void clear() { + FlatVector().swap(*this); + } + void swap(FlatVector &rhs) { + units_.swap(rhs.units_); + marisa::swap(value_size_, rhs.value_size_); + marisa::swap(mask_, rhs.mask_); + marisa::swap(size_, rhs.size_); + } + + private: + Vector<Unit> units_; + std::size_t value_size_; + UInt32 mask_; + std::size_t size_; + + void build_(const Vector<UInt32> &values) { + UInt32 max_value = 0; + for (std::size_t i = 0; i < values.size(); ++i) { + if (values[i] > max_value) { + max_value = values[i]; + } + } + + std::size_t value_size = 0; + while (max_value != 0) { + ++value_size; + max_value >>= 1; + } + + std::size_t num_units = values.empty() ? 0 : (64 / MARISA_WORD_SIZE); + if (value_size != 0) { + num_units = (std::size_t)( + (((UInt64)value_size * values.size()) + (MARISA_WORD_SIZE - 1)) + / MARISA_WORD_SIZE); + num_units += num_units % (64 / MARISA_WORD_SIZE); + } + + units_.resize(num_units); + if (num_units > 0) { + units_.back() = 0; + } + + value_size_ = value_size; + if (value_size != 0) { + mask_ = MARISA_UINT32_MAX >> (32 - value_size); + } + size_ = values.size(); + + for (std::size_t i = 0; i < values.size(); ++i) { + set(i, values[i]); + } + } + + void map_(Mapper &mapper) { + units_.map(mapper); + { + UInt32 temp_value_size; + mapper.map(&temp_value_size); + MARISA_THROW_IF(temp_value_size > 32, MARISA_FORMAT_ERROR); + value_size_ = temp_value_size; + } + { + UInt32 temp_mask; + mapper.map(&temp_mask); + mask_ = temp_mask; + } + { + UInt64 temp_size; + mapper.map(&temp_size); + MARISA_THROW_IF(temp_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + size_ = (std::size_t)temp_size; + } + } + + void read_(Reader &reader) { + units_.read(reader); + { + UInt32 temp_value_size; + reader.read(&temp_value_size); + MARISA_THROW_IF(temp_value_size > 32, MARISA_FORMAT_ERROR); + value_size_ = 
temp_value_size; + } + { + UInt32 temp_mask; + reader.read(&temp_mask); + mask_ = temp_mask; + } + { + UInt64 temp_size; + reader.read(&temp_size); + MARISA_THROW_IF(temp_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + size_ = (std::size_t)temp_size; + } + } + + void write_(Writer &writer) const { + units_.write(writer); + writer.write((UInt32)value_size_); + writer.write((UInt32)mask_); + writer.write((UInt64)size_); + } + + void set(std::size_t i, UInt32 value) { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + MARISA_DEBUG_IF(value > mask_, MARISA_RANGE_ERROR); + + const std::size_t pos = i * value_size_; + const std::size_t unit_id = pos / MARISA_WORD_SIZE; + const std::size_t unit_offset = pos % MARISA_WORD_SIZE; + + units_[unit_id] &= ~((Unit)mask_ << unit_offset); + units_[unit_id] |= (Unit)(value & mask_) << unit_offset; + if ((unit_offset + value_size_) > MARISA_WORD_SIZE) { + units_[unit_id + 1] &= + ~((Unit)mask_ >> (MARISA_WORD_SIZE - unit_offset)); + units_[unit_id + 1] |= + (Unit)(value & mask_) >> (MARISA_WORD_SIZE - unit_offset); + } + } + + // Disallows copy and assignment. 
+ FlatVector(const FlatVector &); + FlatVector &operator=(const FlatVector &); +}; + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_FLAT_VECTOR_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/pop-count.h b/contrib/python/marisa-trie/marisa/grimoire/vector/pop-count.h new file mode 100644 index 0000000000..6d04cf831d --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/pop-count.h @@ -0,0 +1,111 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_POP_COUNT_H_ +#define MARISA_GRIMOIRE_VECTOR_POP_COUNT_H_ + +#include "../intrin.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +#if MARISA_WORD_SIZE == 64 + +class PopCount { + public: + explicit PopCount(UInt64 x) : value_() { + x = (x & 0x5555555555555555ULL) + ((x & 0xAAAAAAAAAAAAAAAAULL) >> 1); + x = (x & 0x3333333333333333ULL) + ((x & 0xCCCCCCCCCCCCCCCCULL) >> 2); + x = (x & 0x0F0F0F0F0F0F0F0FULL) + ((x & 0xF0F0F0F0F0F0F0F0ULL) >> 4); + x *= 0x0101010101010101ULL; + value_ = x; + } + + std::size_t lo8() const { + return (std::size_t)(value_ & 0xFFU); + } + std::size_t lo16() const { + return (std::size_t)((value_ >> 8) & 0xFFU); + } + std::size_t lo24() const { + return (std::size_t)((value_ >> 16) & 0xFFU); + } + std::size_t lo32() const { + return (std::size_t)((value_ >> 24) & 0xFFU); + } + std::size_t lo40() const { + return (std::size_t)((value_ >> 32) & 0xFFU); + } + std::size_t lo48() const { + return (std::size_t)((value_ >> 40) & 0xFFU); + } + std::size_t lo56() const { + return (std::size_t)((value_ >> 48) & 0xFFU); + } + std::size_t lo64() const { + return (std::size_t)((value_ >> 56) & 0xFFU); + } + + static std::size_t count(UInt64 x) { +#if defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + #ifdef _MSC_VER + return __popcnt64(x); + #else // _MSC_VER + return _mm_popcnt_u64(x); + #endif // _MSC_VER +#else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + return PopCount(x).lo64(); 
+#endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) + } + + private: + UInt64 value_; +}; + +#else // MARISA_WORD_SIZE == 64 + +class PopCount { + public: + explicit PopCount(UInt32 x) : value_() { + x = (x & 0x55555555U) + ((x & 0xAAAAAAAAU) >> 1); + x = (x & 0x33333333U) + ((x & 0xCCCCCCCCU) >> 2); + x = (x & 0x0F0F0F0FU) + ((x & 0xF0F0F0F0U) >> 4); + x *= 0x01010101U; + value_ = x; + } + + std::size_t lo8() const { + return value_ & 0xFFU; + } + std::size_t lo16() const { + return (value_ >> 8) & 0xFFU; + } + std::size_t lo24() const { + return (value_ >> 16) & 0xFFU; + } + std::size_t lo32() const { + return (value_ >> 24) & 0xFFU; + } + + static std::size_t count(UInt32 x) { +#ifdef MARISA_USE_POPCNT + #ifdef _MSC_VER + return __popcnt(x); + #else // _MSC_VER + return _mm_popcnt_u32(x); + #endif // _MSC_VER +#else // MARISA_USE_POPCNT + return PopCount(x).lo32(); +#endif // MARISA_USE_POPCNT + } + + private: + UInt32 value_; +}; + +#endif // MARISA_WORD_SIZE == 64 + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_POP_COUNT_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/rank-index.h b/contrib/python/marisa-trie/marisa/grimoire/vector/rank-index.h new file mode 100644 index 0000000000..2403709954 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/rank-index.h @@ -0,0 +1,83 @@ +#pragma once +#ifndef MARISA_GRIMOIRE_VECTOR_RANK_INDEX_H_ +#define MARISA_GRIMOIRE_VECTOR_RANK_INDEX_H_ + +#include "../../base.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +class RankIndex { + public: + RankIndex() : abs_(0), rel_lo_(0), rel_hi_(0) {} + + void set_abs(std::size_t value) { + MARISA_DEBUG_IF(value > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + abs_ = (UInt32)value; + } + void set_rel1(std::size_t value) { + MARISA_DEBUG_IF(value > 64, MARISA_RANGE_ERROR); + rel_lo_ = (UInt32)((rel_lo_ & ~0x7FU) | (value & 0x7FU)); + } + void set_rel2(std::size_t 
value) { + MARISA_DEBUG_IF(value > 128, MARISA_RANGE_ERROR); + rel_lo_ = (UInt32)((rel_lo_ & ~(0xFFU << 7)) | ((value & 0xFFU) << 7)); + } + void set_rel3(std::size_t value) { + MARISA_DEBUG_IF(value > 192, MARISA_RANGE_ERROR); + rel_lo_ = (UInt32)((rel_lo_ & ~(0xFFU << 15)) | ((value & 0xFFU) << 15)); + } + void set_rel4(std::size_t value) { + MARISA_DEBUG_IF(value > 256, MARISA_RANGE_ERROR); + rel_lo_ = (UInt32)((rel_lo_ & ~(0x1FFU << 23)) | ((value & 0x1FFU) << 23)); + } + void set_rel5(std::size_t value) { + MARISA_DEBUG_IF(value > 320, MARISA_RANGE_ERROR); + rel_hi_ = (UInt32)((rel_hi_ & ~0x1FFU) | (value & 0x1FFU)); + } + void set_rel6(std::size_t value) { + MARISA_DEBUG_IF(value > 384, MARISA_RANGE_ERROR); + rel_hi_ = (UInt32)((rel_hi_ & ~(0x1FFU << 9)) | ((value & 0x1FFU) << 9)); + } + void set_rel7(std::size_t value) { + MARISA_DEBUG_IF(value > 448, MARISA_RANGE_ERROR); + rel_hi_ = (UInt32)((rel_hi_ & ~(0x1FFU << 18)) | ((value & 0x1FFU) << 18)); + } + + std::size_t abs() const { + return abs_; + } + std::size_t rel1() const { + return rel_lo_ & 0x7FU; + } + std::size_t rel2() const { + return (rel_lo_ >> 7) & 0xFFU; + } + std::size_t rel3() const { + return (rel_lo_ >> 15) & 0xFFU; + } + std::size_t rel4() const { + return (rel_lo_ >> 23) & 0x1FFU; + } + std::size_t rel5() const { + return rel_hi_ & 0x1FFU; + } + std::size_t rel6() const { + return (rel_hi_ >> 9) & 0x1FFU; + } + std::size_t rel7() const { + return (rel_hi_ >> 18) & 0x1FFU; + } + + private: + UInt32 abs_; + UInt32 rel_lo_; + UInt32 rel_hi_; +}; + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_RANK_INDEX_H_ diff --git a/contrib/python/marisa-trie/marisa/grimoire/vector/vector.h b/contrib/python/marisa-trie/marisa/grimoire/vector/vector.h new file mode 100644 index 0000000000..148cc8b491 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/grimoire/vector/vector.h @@ -0,0 +1,257 @@ +#pragma once +#ifndef 
MARISA_GRIMOIRE_VECTOR_VECTOR_H_ +#define MARISA_GRIMOIRE_VECTOR_VECTOR_H_ + +#include <new> + +#include "../io.h" + +namespace marisa { +namespace grimoire { +namespace vector { + +template <typename T> +class Vector { + public: + Vector() + : buf_(), objs_(NULL), const_objs_(NULL), + size_(0), capacity_(0), fixed_(false) {} + ~Vector() { + if (objs_ != NULL) { + for (std::size_t i = 0; i < size_; ++i) { + objs_[i].~T(); + } + } + } + + void map(Mapper &mapper) { + Vector temp; + temp.map_(mapper); + swap(temp); + } + + void read(Reader &reader) { + Vector temp; + temp.read_(reader); + swap(temp); + } + + void write(Writer &writer) const { + write_(writer); + } + + void push_back(const T &x) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(size_ == max_size(), MARISA_SIZE_ERROR); + reserve(size_ + 1); + new (&objs_[size_]) T(x); + ++size_; + } + + void pop_back() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + objs_[--size_].~T(); + } + + // resize() assumes that T's placement new does not throw an exception. + void resize(std::size_t size) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + reserve(size); + for (std::size_t i = size_; i < size; ++i) { + new (&objs_[i]) T; + } + for (std::size_t i = size; i < size_; ++i) { + objs_[i].~T(); + } + size_ = size; + } + + // resize() assumes that T's placement new does not throw an exception. 
+ void resize(std::size_t size, const T &x) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + reserve(size); + for (std::size_t i = size_; i < size; ++i) { + new (&objs_[i]) T(x); + } + for (std::size_t i = size; i < size_; ++i) { + objs_[i].~T(); + } + size_ = size; + } + + void reserve(std::size_t capacity) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + if (capacity <= capacity_) { + return; + } + MARISA_DEBUG_IF(capacity > max_size(), MARISA_SIZE_ERROR); + std::size_t new_capacity = capacity; + if (capacity_ > (capacity / 2)) { + if (capacity_ > (max_size() / 2)) { + new_capacity = max_size(); + } else { + new_capacity = capacity_ * 2; + } + } + realloc(new_capacity); + } + + void shrink() { + MARISA_THROW_IF(fixed_, MARISA_STATE_ERROR); + if (size_ != capacity_) { + realloc(size_); + } + } + + void fix() { + MARISA_THROW_IF(fixed_, MARISA_STATE_ERROR); + fixed_ = true; + } + + const T *begin() const { + return const_objs_; + } + const T *end() const { + return const_objs_ + size_; + } + const T &operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + return const_objs_[i]; + } + const T &front() const { + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + return const_objs_[0]; + } + const T &back() const { + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + return const_objs_[size_ - 1]; + } + + T *begin() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + return objs_; + } + T *end() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + return objs_ + size_; + } + T &operator[](std::size_t i) { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + return objs_[i]; + } + T &front() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + return objs_[0]; + } + T &back() { + MARISA_DEBUG_IF(fixed_, MARISA_STATE_ERROR); + MARISA_DEBUG_IF(size_ == 0, MARISA_STATE_ERROR); + return objs_[size_ - 1]; + } + + std::size_t size() const { + return 
size_; + } + std::size_t capacity() const { + return capacity_; + } + bool fixed() const { + return fixed_; + } + + bool empty() const { + return size_ == 0; + } + std::size_t total_size() const { + return sizeof(T) * size_; + } + std::size_t io_size() const { + return sizeof(UInt64) + ((total_size() + 7) & ~(std::size_t)0x07); + } + + void clear() { + Vector().swap(*this); + } + void swap(Vector &rhs) { + buf_.swap(rhs.buf_); + marisa::swap(objs_, rhs.objs_); + marisa::swap(const_objs_, rhs.const_objs_); + marisa::swap(size_, rhs.size_); + marisa::swap(capacity_, rhs.capacity_); + marisa::swap(fixed_, rhs.fixed_); + } + + static std::size_t max_size() { + return MARISA_SIZE_MAX / sizeof(T); + } + + private: + scoped_array<char> buf_; + T *objs_; + const T *const_objs_; + std::size_t size_; + std::size_t capacity_; + bool fixed_; + + void map_(Mapper &mapper) { + UInt64 total_size; + mapper.map(&total_size); + MARISA_THROW_IF(total_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + MARISA_THROW_IF((total_size % sizeof(T)) != 0, MARISA_FORMAT_ERROR); + const std::size_t size = (std::size_t)(total_size / sizeof(T)); + mapper.map(&const_objs_, size); + mapper.seek((std::size_t)((8 - (total_size % 8)) % 8)); + size_ = size; + fix(); + } + void read_(Reader &reader) { + UInt64 total_size; + reader.read(&total_size); + MARISA_THROW_IF(total_size > MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + MARISA_THROW_IF((total_size % sizeof(T)) != 0, MARISA_FORMAT_ERROR); + const std::size_t size = (std::size_t)(total_size / sizeof(T)); + resize(size); + reader.read(objs_, size); + reader.seek((std::size_t)((8 - (total_size % 8)) % 8)); + } + void write_(Writer &writer) const { + writer.write((UInt64)total_size()); + writer.write(const_objs_, size_); + writer.seek((8 - (total_size() % 8)) % 8); + } + + // realloc() assumes that T's placement new does not throw an exception. 
+ void realloc(std::size_t new_capacity) { + MARISA_DEBUG_IF(new_capacity > max_size(), MARISA_SIZE_ERROR); + + scoped_array<char> new_buf( + new (std::nothrow) char[sizeof(T) * new_capacity]); + MARISA_DEBUG_IF(new_buf.get() == NULL, MARISA_MEMORY_ERROR); + T *new_objs = reinterpret_cast<T *>(new_buf.get()); + + for (std::size_t i = 0; i < size_; ++i) { + new (&new_objs[i]) T(objs_[i]); + } + for (std::size_t i = 0; i < size_; ++i) { + objs_[i].~T(); + } + + buf_.swap(new_buf); + objs_ = new_objs; + const_objs_ = new_objs; + capacity_ = new_capacity; + } + + // Disallows copy and assignment. + Vector(const Vector &); + Vector &operator=(const Vector &); +}; + +} // namespace vector +} // namespace grimoire +} // namespace marisa + +#endif // MARISA_GRIMOIRE_VECTOR_VECTOR_H_ diff --git a/contrib/python/marisa-trie/marisa/iostream.h b/contrib/python/marisa-trie/marisa/iostream.h new file mode 100644 index 0000000000..da5ec77a6c --- /dev/null +++ b/contrib/python/marisa-trie/marisa/iostream.h @@ -0,0 +1,19 @@ +#pragma once +#ifndef MARISA_IOSTREAM_H_ +#define MARISA_IOSTREAM_H_ + +#include <iosfwd> + +namespace marisa { + +class Trie; + +std::istream &read(std::istream &stream, Trie *trie); +std::ostream &write(std::ostream &stream, const Trie &trie); + +std::istream &operator>>(std::istream &stream, Trie &trie); +std::ostream &operator<<(std::ostream &stream, const Trie &trie); + +} // namespace marisa + +#endif // MARISA_IOSTREAM_H_ diff --git a/contrib/python/marisa-trie/marisa/key.h b/contrib/python/marisa-trie/marisa/key.h new file mode 100644 index 0000000000..48e03226c4 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/key.h @@ -0,0 +1,86 @@ +#pragma once +#ifndef MARISA_KEY_H_ +#define MARISA_KEY_H_ + +#include "base.h" + +namespace marisa { + +class Key { + public: + Key() : ptr_(NULL), length_(0), union_() { + union_.id = 0; + } + Key(const Key &key) + : ptr_(key.ptr_), length_(key.length_), union_(key.union_) {} + + Key &operator=(const Key &key) { + 
ptr_ = key.ptr_; + length_ = key.length_; + union_ = key.union_; + return *this; + } + + char operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); + return ptr_[i]; + } + + void set_str(const char *str) { + MARISA_DEBUG_IF(str == NULL, MARISA_NULL_ERROR); + std::size_t length = 0; + while (str[length] != '\0') { + ++length; + } + MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + ptr_ = str; + length_ = (UInt32)length; + } + void set_str(const char *ptr, std::size_t length) { + MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + ptr_ = ptr; + length_ = (UInt32)length; + } + void set_id(std::size_t id) { + MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + union_.id = (UInt32)id; + } + void set_weight(float weight) { + union_.weight = weight; + } + + const char *ptr() const { + return ptr_; + } + std::size_t length() const { + return length_; + } + std::size_t id() const { + return union_.id; + } + float weight() const { + return union_.weight; + } + + void clear() { + Key().swap(*this); + } + void swap(Key &rhs) { + marisa::swap(ptr_, rhs.ptr_); + marisa::swap(length_, rhs.length_); + marisa::swap(union_.id, rhs.union_.id); + } + + private: + const char *ptr_; + UInt32 length_; + union Union { + UInt32 id; + float weight; + } union_; +}; + +} // namespace marisa + +#endif // MARISA_KEY_H_ diff --git a/contrib/python/marisa-trie/marisa/keyset.cc b/contrib/python/marisa-trie/marisa/keyset.cc new file mode 100644 index 0000000000..adb82b31fe --- /dev/null +++ b/contrib/python/marisa-trie/marisa/keyset.cc @@ -0,0 +1,181 @@ +#include <new> + +#include "keyset.h" + +namespace marisa { + +Keyset::Keyset() + : base_blocks_(), base_blocks_size_(0), base_blocks_capacity_(0), + extra_blocks_(), extra_blocks_size_(0), extra_blocks_capacity_(0), + key_blocks_(), key_blocks_size_(0), key_blocks_capacity_(0), + ptr_(NULL), avail_(0), 
size_(0), total_length_(0) {} + +void Keyset::push_back(const Key &key) { + MARISA_DEBUG_IF(size_ == MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + + char * const key_ptr = reserve(key.length()); + for (std::size_t i = 0; i < key.length(); ++i) { + key_ptr[i] = key[i]; + } + + Key &new_key = key_blocks_[size_ / KEY_BLOCK_SIZE][size_ % KEY_BLOCK_SIZE]; + new_key.set_str(key_ptr, key.length()); + new_key.set_id(key.id()); + ++size_; + total_length_ += new_key.length(); +} + +void Keyset::push_back(const Key &key, char end_marker) { + MARISA_DEBUG_IF(size_ == MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + + if ((size_ / KEY_BLOCK_SIZE) == key_blocks_size_) { + append_key_block(); + } + + char * const key_ptr = reserve(key.length() + 1); + for (std::size_t i = 0; i < key.length(); ++i) { + key_ptr[i] = key[i]; + } + key_ptr[key.length()] = end_marker; + + Key &new_key = key_blocks_[size_ / KEY_BLOCK_SIZE][size_ % KEY_BLOCK_SIZE]; + new_key.set_str(key_ptr, key.length()); + new_key.set_id(key.id()); + ++size_; + total_length_ += new_key.length(); +} + +void Keyset::push_back(const char *str) { + MARISA_DEBUG_IF(size_ == MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + MARISA_THROW_IF(str == NULL, MARISA_NULL_ERROR); + + std::size_t length = 0; + while (str[length] != '\0') { + ++length; + } + push_back(str, length); +} + +void Keyset::push_back(const char *ptr, std::size_t length, float weight) { + MARISA_DEBUG_IF(size_ == MARISA_SIZE_MAX, MARISA_SIZE_ERROR); + MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + MARISA_THROW_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); + + char * const key_ptr = reserve(length); + for (std::size_t i = 0; i < length; ++i) { + key_ptr[i] = ptr[i]; + } + + Key &key = key_blocks_[size_ / KEY_BLOCK_SIZE][size_ % KEY_BLOCK_SIZE]; + key.set_str(key_ptr, length); + key.set_weight(weight); + ++size_; + total_length_ += length; +} + +void Keyset::reset() { + base_blocks_size_ = 0; + extra_blocks_size_ = 0; + ptr_ = NULL; + avail_ = 0; + size_ = 0; 
+ total_length_ = 0; +} + +void Keyset::clear() { + Keyset().swap(*this); +} + +void Keyset::swap(Keyset &rhs) { + base_blocks_.swap(rhs.base_blocks_); + marisa::swap(base_blocks_size_, rhs.base_blocks_size_); + marisa::swap(base_blocks_capacity_, rhs.base_blocks_capacity_); + extra_blocks_.swap(rhs.extra_blocks_); + marisa::swap(extra_blocks_size_, rhs.extra_blocks_size_); + marisa::swap(extra_blocks_capacity_, rhs.extra_blocks_capacity_); + key_blocks_.swap(rhs.key_blocks_); + marisa::swap(key_blocks_size_, rhs.key_blocks_size_); + marisa::swap(key_blocks_capacity_, rhs.key_blocks_capacity_); + marisa::swap(ptr_, rhs.ptr_); + marisa::swap(avail_, rhs.avail_); + marisa::swap(size_, rhs.size_); + marisa::swap(total_length_, rhs.total_length_); +} + +char *Keyset::reserve(std::size_t size) { + if ((size_ / KEY_BLOCK_SIZE) == key_blocks_size_) { + append_key_block(); + } + + if (size > EXTRA_BLOCK_SIZE) { + append_extra_block(size); + return extra_blocks_[extra_blocks_size_ - 1].get(); + } else { + if (size > avail_) { + append_base_block(); + } + ptr_ += size; + avail_ -= size; + return ptr_ - size; + } +} + +void Keyset::append_base_block() { + if (base_blocks_size_ == base_blocks_capacity_) { + const std::size_t new_capacity = + (base_blocks_size_ != 0) ? 
(base_blocks_size_ * 2) : 1; + scoped_array<scoped_array<char> > new_blocks( + new (std::nothrow) scoped_array<char>[new_capacity]); + MARISA_THROW_IF(new_blocks.get() == NULL, MARISA_MEMORY_ERROR); + for (std::size_t i = 0; i < base_blocks_size_; ++i) { + base_blocks_[i].swap(new_blocks[i]); + } + base_blocks_.swap(new_blocks); + base_blocks_capacity_ = new_capacity; + } + if (base_blocks_[base_blocks_size_].get() == NULL) { + scoped_array<char> new_block(new (std::nothrow) char[BASE_BLOCK_SIZE]); + MARISA_THROW_IF(new_block.get() == NULL, MARISA_MEMORY_ERROR); + base_blocks_[base_blocks_size_].swap(new_block); + } + ptr_ = base_blocks_[base_blocks_size_++].get(); + avail_ = BASE_BLOCK_SIZE; +} + +void Keyset::append_extra_block(std::size_t size) { + if (extra_blocks_size_ == extra_blocks_capacity_) { + const std::size_t new_capacity = + (extra_blocks_size_ != 0) ? (extra_blocks_size_ * 2) : 1; + scoped_array<scoped_array<char> > new_blocks( + new (std::nothrow) scoped_array<char>[new_capacity]); + MARISA_THROW_IF(new_blocks.get() == NULL, MARISA_MEMORY_ERROR); + for (std::size_t i = 0; i < extra_blocks_size_; ++i) { + extra_blocks_[i].swap(new_blocks[i]); + } + extra_blocks_.swap(new_blocks); + extra_blocks_capacity_ = new_capacity; + } + scoped_array<char> new_block(new (std::nothrow) char[size]); + MARISA_THROW_IF(new_block.get() == NULL, MARISA_MEMORY_ERROR); + extra_blocks_[extra_blocks_size_++].swap(new_block); +} + +void Keyset::append_key_block() { + if (key_blocks_size_ == key_blocks_capacity_) { + const std::size_t new_capacity = + (key_blocks_size_ != 0) ? 
(key_blocks_size_ * 2) : 1; + scoped_array<scoped_array<Key> > new_blocks( + new (std::nothrow) scoped_array<Key>[new_capacity]); + MARISA_THROW_IF(new_blocks.get() == NULL, MARISA_MEMORY_ERROR); + for (std::size_t i = 0; i < key_blocks_size_; ++i) { + key_blocks_[i].swap(new_blocks[i]); + } + key_blocks_.swap(new_blocks); + key_blocks_capacity_ = new_capacity; + } + scoped_array<Key> new_block(new (std::nothrow) Key[KEY_BLOCK_SIZE]); + MARISA_THROW_IF(new_block.get() == NULL, MARISA_MEMORY_ERROR); + key_blocks_[key_blocks_size_++].swap(new_block); +} + +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/keyset.h b/contrib/python/marisa-trie/marisa/keyset.h new file mode 100644 index 0000000000..86762dba47 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/keyset.h @@ -0,0 +1,81 @@ +#pragma once +#ifndef MARISA_KEYSET_H_ +#define MARISA_KEYSET_H_ + +#include "key.h" + +namespace marisa { + +class Keyset { + public: + enum { + BASE_BLOCK_SIZE = 4096, + EXTRA_BLOCK_SIZE = 1024, + KEY_BLOCK_SIZE = 256 + }; + + Keyset(); + + void push_back(const Key &key); + void push_back(const Key &key, char end_marker); + + void push_back(const char *str); + void push_back(const char *ptr, std::size_t length, float weight = 1.0); + + const Key &operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + return key_blocks_[i / KEY_BLOCK_SIZE][i % KEY_BLOCK_SIZE]; + } + Key &operator[](std::size_t i) { + MARISA_DEBUG_IF(i >= size_, MARISA_BOUND_ERROR); + return key_blocks_[i / KEY_BLOCK_SIZE][i % KEY_BLOCK_SIZE]; + } + + std::size_t num_keys() const { + return size_; + } + + bool empty() const { + return size_ == 0; + } + std::size_t size() const { + return size_; + } + std::size_t total_length() const { + return total_length_; + } + + void reset(); + + void clear(); + void swap(Keyset &rhs); + + private: + scoped_array<scoped_array<char> > base_blocks_; + std::size_t base_blocks_size_; + std::size_t base_blocks_capacity_; + 
scoped_array<scoped_array<char> > extra_blocks_; + std::size_t extra_blocks_size_; + std::size_t extra_blocks_capacity_; + scoped_array<scoped_array<Key> > key_blocks_; + std::size_t key_blocks_size_; + std::size_t key_blocks_capacity_; + char *ptr_; + std::size_t avail_; + std::size_t size_; + std::size_t total_length_; + + char *reserve(std::size_t size); + + void append_base_block(); + void append_extra_block(std::size_t size); + void append_key_block(); + + // Disallows copy and assignment. + Keyset(const Keyset &); + Keyset &operator=(const Keyset &); +}; + +} // namespace marisa + +#endif // MARISA_KEYSET_H_ diff --git a/contrib/python/marisa-trie/marisa/query.h b/contrib/python/marisa-trie/marisa/query.h new file mode 100644 index 0000000000..e08f8f72dc --- /dev/null +++ b/contrib/python/marisa-trie/marisa/query.h @@ -0,0 +1,72 @@ +#pragma once +#ifndef MARISA_QUERY_H_ +#define MARISA_QUERY_H_ + +#include "base.h" + +namespace marisa { + +class Query { + public: + Query() : ptr_(NULL), length_(0), id_(0) {} + Query(const Query &query) + : ptr_(query.ptr_), length_(query.length_), id_(query.id_) {} + + Query &operator=(const Query &query) { + ptr_ = query.ptr_; + length_ = query.length_; + id_ = query.id_; + return *this; + } + + char operator[](std::size_t i) const { + MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); + return ptr_[i]; + } + + void set_str(const char *str) { + MARISA_DEBUG_IF(str == NULL, MARISA_NULL_ERROR); + std::size_t length = 0; + while (str[length] != '\0') { + ++length; + } + ptr_ = str; + length_ = length; + } + void set_str(const char *ptr, std::size_t length) { + MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); + ptr_ = ptr; + length_ = length; + } + void set_id(std::size_t id) { + id_ = id; + } + + const char *ptr() const { + return ptr_; + } + std::size_t length() const { + return length_; + } + std::size_t id() const { + return id_; + } + + void clear() { + Query().swap(*this); + } + void swap(Query &rhs) { + 
marisa::swap(ptr_, rhs.ptr_); + marisa::swap(length_, rhs.length_); + marisa::swap(id_, rhs.id_); + } + + private: + const char *ptr_; + std::size_t length_; + std::size_t id_; +}; + +} // namespace marisa + +#endif // MARISA_QUERY_H_ diff --git a/contrib/python/marisa-trie/marisa/scoped-array.h b/contrib/python/marisa-trie/marisa/scoped-array.h new file mode 100644 index 0000000000..210cb908a7 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/scoped-array.h @@ -0,0 +1,49 @@ +#pragma once +#ifndef MARISA_SCOPED_ARRAY_H_ +#define MARISA_SCOPED_ARRAY_H_ + +#include "base.h" + +namespace marisa { + +template <typename T> +class scoped_array { + public: + scoped_array() : array_(NULL) {} + explicit scoped_array(T *array) : array_(array) {} + + ~scoped_array() { + delete [] array_; + } + + void reset(T *array = NULL) { + MARISA_THROW_IF((array != NULL) && (array == array_), MARISA_RESET_ERROR); + scoped_array(array).swap(*this); + } + + T &operator[](std::size_t i) const { + MARISA_DEBUG_IF(array_ == NULL, MARISA_STATE_ERROR); + return array_[i]; + } + T *get() const { + return array_; + } + + void clear() { + scoped_array().swap(*this); + } + void swap(scoped_array &rhs) { + marisa::swap(array_, rhs.array_); + } + + private: + T *array_; + + // Disallows copy and assignment. 
+ scoped_array(const scoped_array &); + scoped_array &operator=(const scoped_array &); +}; + +} // namespace marisa + +#endif // MARISA_SCOPED_ARRAY_H_ diff --git a/contrib/python/marisa-trie/marisa/scoped-ptr.h b/contrib/python/marisa-trie/marisa/scoped-ptr.h new file mode 100644 index 0000000000..9a9c447353 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/scoped-ptr.h @@ -0,0 +1,53 @@ +#pragma once +#ifndef MARISA_SCOPED_PTR_H_ +#define MARISA_SCOPED_PTR_H_ + +#include "base.h" + +namespace marisa { + +template <typename T> +class scoped_ptr { + public: + scoped_ptr() : ptr_(NULL) {} + explicit scoped_ptr(T *ptr) : ptr_(ptr) {} + + ~scoped_ptr() { + delete ptr_; + } + + void reset(T *ptr = NULL) { + MARISA_THROW_IF((ptr != NULL) && (ptr == ptr_), MARISA_RESET_ERROR); + scoped_ptr(ptr).swap(*this); + } + + T &operator*() const { + MARISA_DEBUG_IF(ptr_ == NULL, MARISA_STATE_ERROR); + return *ptr_; + } + T *operator->() const { + MARISA_DEBUG_IF(ptr_ == NULL, MARISA_STATE_ERROR); + return ptr_; + } + T *get() const { + return ptr_; + } + + void clear() { + scoped_ptr().swap(*this); + } + void swap(scoped_ptr &rhs) { + marisa::swap(ptr_, rhs.ptr_); + } + + private: + T *ptr_; + + // Disallows copy and assignment. 
+ scoped_ptr(const scoped_ptr &); + scoped_ptr &operator=(const scoped_ptr &); +}; + +} // namespace marisa + +#endif // MARISA_SCOPED_PTR_H_ diff --git a/contrib/python/marisa-trie/marisa/stdio.h b/contrib/python/marisa-trie/marisa/stdio.h new file mode 100644 index 0000000000..334ce56816 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/stdio.h @@ -0,0 +1,16 @@ +#pragma once +#ifndef MARISA_MYSTDIO_H_ +#define MARISA_MYSTDIO_H_ + +#include <cstdio> + +namespace marisa { + +class Trie; + +void fread(std::FILE *file, Trie *trie); +void fwrite(std::FILE *file, const Trie &trie); + +} // namespace marisa + +#endif // MARISA_MYSTDIO_H_ diff --git a/contrib/python/marisa-trie/marisa/trie.cc b/contrib/python/marisa-trie/marisa/trie.cc new file mode 100644 index 0000000000..5baaf9b288 --- /dev/null +++ b/contrib/python/marisa-trie/marisa/trie.cc @@ -0,0 +1,249 @@ +#include "stdio.h" +#include "iostream.h" +#include "trie.h" +#include "grimoire/trie.h" + +namespace marisa { + +Trie::Trie() : trie_() {} + +Trie::~Trie() {} + +void Trie::build(Keyset &keyset, int config_flags) { + scoped_ptr<grimoire::LoudsTrie> temp(new (std::nothrow) grimoire::LoudsTrie); + MARISA_THROW_IF(temp.get() == NULL, MARISA_MEMORY_ERROR); + + temp->build(keyset, config_flags); + trie_.swap(temp); +} + +void Trie::mmap(const char *filename) { + MARISA_THROW_IF(filename == NULL, MARISA_NULL_ERROR); + + scoped_ptr<grimoire::LoudsTrie> temp(new (std::nothrow) grimoire::LoudsTrie); + MARISA_THROW_IF(temp.get() == NULL, MARISA_MEMORY_ERROR); + + grimoire::Mapper mapper; + mapper.open(filename); + temp->map(mapper); + trie_.swap(temp); +} + +void Trie::map(const void *ptr, std::size_t size) { + MARISA_THROW_IF((ptr == NULL) && (size != 0), MARISA_NULL_ERROR); + + scoped_ptr<grimoire::LoudsTrie> temp(new (std::nothrow) grimoire::LoudsTrie); + MARISA_THROW_IF(temp.get() == NULL, MARISA_MEMORY_ERROR); + + grimoire::Mapper mapper; + mapper.open(ptr, size); + temp->map(mapper); + trie_.swap(temp); +} + 
+void Trie::load(const char *filename) { + MARISA_THROW_IF(filename == NULL, MARISA_NULL_ERROR); + + scoped_ptr<grimoire::LoudsTrie> temp(new (std::nothrow) grimoire::LoudsTrie); + MARISA_THROW_IF(temp.get() == NULL, MARISA_MEMORY_ERROR); + + grimoire::Reader reader; + reader.open(filename); + temp->read(reader); + trie_.swap(temp); +} + +void Trie::read(int fd) { + MARISA_THROW_IF(fd == -1, MARISA_CODE_ERROR); + + scoped_ptr<grimoire::LoudsTrie> temp(new (std::nothrow) grimoire::LoudsTrie); + MARISA_THROW_IF(temp.get() == NULL, MARISA_MEMORY_ERROR); + + grimoire::Reader reader; + reader.open(fd); + temp->read(reader); + trie_.swap(temp); +} + +void Trie::save(const char *filename) const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(filename == NULL, MARISA_NULL_ERROR); + + grimoire::Writer writer; + writer.open(filename); + trie_->write(writer); +} + +void Trie::write(int fd) const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + MARISA_THROW_IF(fd == -1, MARISA_CODE_ERROR); + + grimoire::Writer writer; + writer.open(fd); + trie_->write(writer); +} + +bool Trie::lookup(Agent &agent) const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + if (!agent.has_state()) { + agent.init_state(); + } + return trie_->lookup(agent); +} + +void Trie::reverse_lookup(Agent &agent) const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + if (!agent.has_state()) { + agent.init_state(); + } + trie_->reverse_lookup(agent); +} + +bool Trie::common_prefix_search(Agent &agent) const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + if (!agent.has_state()) { + agent.init_state(); + } + return trie_->common_prefix_search(agent); +} + +bool Trie::predictive_search(Agent &agent) const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + if (!agent.has_state()) { + agent.init_state(); + } + return trie_->predictive_search(agent); +} + +std::size_t Trie::num_tries() const { + 
MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + return trie_->num_tries(); +} + +std::size_t Trie::num_keys() const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + return trie_->num_keys(); +} + +std::size_t Trie::num_nodes() const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + return trie_->num_nodes(); +} + +TailMode Trie::tail_mode() const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + return trie_->tail_mode(); +} + +NodeOrder Trie::node_order() const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + return trie_->node_order(); +} + +bool Trie::empty() const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + return trie_->empty(); +} + +std::size_t Trie::size() const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + return trie_->size(); +} + +std::size_t Trie::total_size() const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + return trie_->total_size(); +} + +std::size_t Trie::io_size() const { + MARISA_THROW_IF(trie_.get() == NULL, MARISA_STATE_ERROR); + return trie_->io_size(); +} + +void Trie::clear() { + Trie().swap(*this); +} + +void Trie::swap(Trie &rhs) { + trie_.swap(rhs.trie_); +} + +} // namespace marisa + +#include <iostream> + +namespace marisa { + +class TrieIO { + public: + static void fread(std::FILE *file, Trie *trie) { + MARISA_THROW_IF(trie == NULL, MARISA_NULL_ERROR); + + scoped_ptr<grimoire::LoudsTrie> temp( + new (std::nothrow) grimoire::LoudsTrie); + MARISA_THROW_IF(temp.get() == NULL, MARISA_MEMORY_ERROR); + + grimoire::Reader reader; + reader.open(file); + temp->read(reader); + trie->trie_.swap(temp); + } + static void fwrite(std::FILE *file, const Trie &trie) { + MARISA_THROW_IF(file == NULL, MARISA_NULL_ERROR); + MARISA_THROW_IF(trie.trie_.get() == NULL, MARISA_STATE_ERROR); + grimoire::Writer writer; + writer.open(file); + trie.trie_->write(writer); + } + + static std::istream &read(std::istream &stream, Trie *trie) 
{ + MARISA_THROW_IF(trie == NULL, MARISA_NULL_ERROR); + + scoped_ptr<grimoire::LoudsTrie> temp( + new (std::nothrow) grimoire::LoudsTrie); + MARISA_THROW_IF(temp.get() == NULL, MARISA_MEMORY_ERROR); + + grimoire::Reader reader; + reader.open(stream); + temp->read(reader); + trie->trie_.swap(temp); + return stream; + } + static std::ostream &write(std::ostream &stream, const Trie &trie) { + MARISA_THROW_IF(trie.trie_.get() == NULL, MARISA_STATE_ERROR); + grimoire::Writer writer; + writer.open(stream); + trie.trie_->write(writer); + return stream; + } +}; + +void fread(std::FILE *file, Trie *trie) { + MARISA_THROW_IF(file == NULL, MARISA_NULL_ERROR); + MARISA_THROW_IF(trie == NULL, MARISA_NULL_ERROR); + TrieIO::fread(file, trie); +} + +void fwrite(std::FILE *file, const Trie &trie) { + MARISA_THROW_IF(file == NULL, MARISA_NULL_ERROR); + TrieIO::fwrite(file, trie); +} + +std::istream &read(std::istream &stream, Trie *trie) { + MARISA_THROW_IF(trie == NULL, MARISA_NULL_ERROR); + return TrieIO::read(stream, trie); +} + +std::ostream &write(std::ostream &stream, const Trie &trie) { + return TrieIO::write(stream, trie); +} + +std::istream &operator>>(std::istream &stream, Trie &trie) { + return read(stream, &trie); +} + +std::ostream &operator<<(std::ostream &stream, const Trie &trie) { + return write(stream, trie); +} + +} // namespace marisa diff --git a/contrib/python/marisa-trie/marisa/trie.h b/contrib/python/marisa-trie/marisa/trie.h new file mode 100644 index 0000000000..df85bd86ba --- /dev/null +++ b/contrib/python/marisa-trie/marisa/trie.h @@ -0,0 +1,65 @@ +#pragma once +#ifndef MARISA_TRIE_H_ +#define MARISA_TRIE_H_ + +#include "keyset.h" +#include "agent.h" + +namespace marisa { +namespace grimoire { +namespace trie { + +class LoudsTrie; + +} // namespace trie +} // namespace grimoire + +class Trie { + friend class TrieIO; + + public: + Trie(); + ~Trie(); + + void build(Keyset &keyset, int config_flags = 0); + + void mmap(const char *filename); + void map(const 
void *ptr, std::size_t size); + + void load(const char *filename); + void read(int fd); + + void save(const char *filename) const; + void write(int fd) const; + + bool lookup(Agent &agent) const; + void reverse_lookup(Agent &agent) const; + bool common_prefix_search(Agent &agent) const; + bool predictive_search(Agent &agent) const; + + std::size_t num_tries() const; + std::size_t num_keys() const; + std::size_t num_nodes() const; + + TailMode tail_mode() const; + NodeOrder node_order() const; + + bool empty() const; + std::size_t size() const; + std::size_t total_size() const; + std::size_t io_size() const; + + void clear(); + void swap(Trie &rhs); + + private: + scoped_ptr<grimoire::trie::LoudsTrie> trie_; + + // Disallows copy and assignment. + Trie(const Trie &); + Trie &operator=(const Trie &); +}; + +} // namespace marisa + +#endif // MARISA_TRIE_H_ diff --git a/contrib/python/marisa-trie/marisa_trie.pyx b/contrib/python/marisa-trie/marisa_trie.pyx new file mode 100644 index 0000000000..f9fe6f331b --- /dev/null +++ b/contrib/python/marisa-trie/marisa_trie.pyx @@ -0,0 +1,763 @@ +# cython: profile=False, embedsignature=True + +from __future__ import unicode_literals + +from std_iostream cimport stringstream, istream, ostream +from libc.string cimport strncmp +cimport keyset +cimport key +cimport agent +cimport trie +cimport iostream +cimport base + +import itertools +import struct +import warnings + +try: + from itertools import izip +except ImportError: + izip = zip + + +DEFAULT_CACHE = base.MARISA_DEFAULT_CACHE +HUGE_CACHE = base.MARISA_HUGE_CACHE +LARGE_CACHE = base.MARISA_LARGE_CACHE +NORMAL_CACHE = base.MARISA_NORMAL_CACHE +SMALL_CACHE = base.MARISA_SMALL_CACHE +TINY_CACHE = base.MARISA_TINY_CACHE + +MIN_NUM_TRIES = base.MARISA_MIN_NUM_TRIES +MAX_NUM_TRIES = base.MARISA_MAX_NUM_TRIES +DEFAULT_NUM_TRIES = base.MARISA_DEFAULT_NUM_TRIES + +# MARISA_TEXT_TAIL merges last labels as zero-terminated strings. 
So, it is +# available if and only if the last labels do not contain a NULL character. +# If MARISA_TEXT_TAIL is specified and a NULL character exists in the last +# labels, the setting is automatically switched to MARISA_BINARY_TAIL. +TEXT_TAIL = base.MARISA_TEXT_TAIL + +# MARISA_BINARY_TAIL also merges last labels but as byte sequences. It uses +# a bit vector to detect the end of a sequence, instead of NULL characters. +# So, MARISA_BINARY_TAIL requires a larger space if the average length of +# labels is greater than 8. +BINARY_TAIL = base.MARISA_BINARY_TAIL +DEFAULT_TAIL = base.MARISA_DEFAULT_TAIL + + +# MARISA_LABEL_ORDER arranges nodes in ascending label order. +# MARISA_LABEL_ORDER is useful if an application needs to predict keys in +# label order. +LABEL_ORDER = base.MARISA_LABEL_ORDER + +# MARISA_WEIGHT_ORDER arranges nodes in descending weight order. +# MARISA_WEIGHT_ORDER is generally a better choice because it enables faster +# matching. +WEIGHT_ORDER = base.MARISA_WEIGHT_ORDER +DEFAULT_ORDER = base.MARISA_DEFAULT_ORDER + + +cdef class _Trie: + cdef trie.Trie* _trie + + cdef bytes _encode_key(self, key): + return key + + cdef _get_key(self, agent.Agent& ag): + return ag.key().ptr()[:ag.key().length()] + + def __init__(self, arg=None, num_tries=DEFAULT_NUM_TRIES, binary=False, + cache_size=DEFAULT_CACHE, order=DEFAULT_ORDER, weights=None): + """ + ``arg`` can be one of the following: + + * an iterable with bytes keys; + * None (if you're going to load a trie later). + + Pass a ``weights`` iterable with expected lookup frequencies + to optimize lookup and prefix search speed. 
+ """ + + if self._trie: + return + self._trie = new trie.Trie() + + byte_keys = (self._encode_key(key) for key in (arg or [])) + + self._build( + byte_keys, + weights, + num_tries=num_tries, + binary=binary, + cache_size=cache_size, + order=order + ) + + def __dealloc__(self): + if self._trie: + del self._trie + + def _config_flags(self, num_tries=DEFAULT_NUM_TRIES, binary=False, + cache_size=DEFAULT_CACHE, order=DEFAULT_ORDER): + if not MIN_NUM_TRIES <= num_tries <= MAX_NUM_TRIES: + raise ValueError( + "num_tries (which is %d) must be between between %d and %d" % + (num_tries, MIN_NUM_TRIES, MAX_NUM_TRIES)) + + binary_flag = BINARY_TAIL if binary else TEXT_TAIL + return num_tries | binary_flag | cache_size | order + + def _build(self, byte_keys, weights=None, **options): + if weights is None: + weights = itertools.repeat(1.0) + + cdef char* data + cdef float weight + cdef keyset.Keyset *ks = new keyset.Keyset() + + try: + for key, weight in izip(byte_keys, weights): + ks.push_back(<char *>key, len(key), weight) + self._trie.build(ks[0], self._config_flags(**options)) + finally: + del ks + + def __richcmp__(self, other, int op): + if op == 2: # == + if other is self: + return True + elif not isinstance(other, _Trie): + return False + + return (<_Trie>self)._equals(other) + elif op == 3: # != + return not (self == other) + + raise TypeError("unorderable types: {0} and {1}".format( + self.__class__, other.__class__)) + + cdef bint _equals(self, _Trie other) nogil: + cdef int num_keys = self._trie.num_keys() + cdef base.NodeOrder node_order = self._trie.node_order() + if (other._trie.num_keys() != num_keys or + other._trie.node_order() != node_order): + return False + + cdef agent.Agent ag1, ag2 + ag1.set_query(b"") + ag2.set_query(b"") + cdef int i + cdef key.Key key1, key2 + for i in range(num_keys): + self._trie.predictive_search(ag1) + other._trie.predictive_search(ag2) + key1 = ag1.key() + key2 = ag2.key() + if (key1.length() != key2.length() or + 
strncmp(key1.ptr(), key2.ptr(), key1.length()) != 0): + return False + return True + + def __iter__(self): + return self.iterkeys() + + def __len__(self): + return self._trie.num_keys() + + def __contains__(self, key): + cdef bytes _key = self._encode_key(key) + return self._contains(_key) + + cdef bint _contains(self, bytes key): + cdef agent.Agent ag + ag.set_query(key, len(key)) + return self._trie.lookup(ag) + + def read(self, f): + """Read a trie from an open file. + + :param file f: a "real" on-disk file object. Passing a *file-like* + object would result in an error. + + .. deprecated:: 0.7.3 + + The method will be removed in version 0.8.0. Please use + :meth:`load` instead. + """ + warnings.warn("Trie.save is deprecated and will " + "be removed in marisa_trie 0.8.0. Please use " + "Trie.load instead.", DeprecationWarning) + self._trie.read(f.fileno()) + return self + + def write(self, f): + """Write a trie to an open file. + + :param file f: a "real" on-disk file object. Passing a *file-like* + object would result in an error. + + .. deprecated:: 0.7.3 + + The method will be removed in version 0.8.0. Please use + :meth:`save` instead. + """ + warnings.warn("Trie.write is deprecated and will " + "be removed in marisa_trie 0.8.0. 
Please use " + "Trie.save instead.", DeprecationWarning) + self._trie.write(f.fileno()) + + def save(self, path): + """Save a trie to a specified path.""" + with open(path, 'w') as f: + self._trie.write(f.fileno()) + + def load(self, path): + """Load a trie from a specified path.""" + with open(path, 'r') as f: + self._trie.read(f.fileno()) + return self + + cpdef bytes tobytes(self) except +: + """Return raw trie content as bytes.""" + cdef stringstream stream + iostream.write((<ostream *> &stream)[0], self._trie[0]) + cdef bytes res = stream.str() + return res + + cpdef frombytes(self, bytes data) except +: + """Load a trie from raw bytes generated by :meth:`tobytes`.""" + cdef stringstream* stream = new stringstream(data) + try: + iostream.read((<istream *> stream)[0], self._trie) + finally: + del stream + return self + + def __reduce__(self): + return self.__class__, (), self.tobytes() + + __setstate__ = frombytes + + def mmap(self, path): + """Memory map the content of a trie stored in a file. + + This allows to query trie without loading it fully in memory. + """ + import sys + str_path = path.encode(sys.getfilesystemencoding()) + cdef char* c_path = str_path + self._trie.mmap(c_path) + return self + + def iterkeys(self, prefix=None): + """ + Return an iterator over trie keys starting with a given ``prefix``. 
+ """ + cdef agent.Agent ag + cdef bytes b_prefix = b'' + if prefix is not None: + b_prefix = self._encode_key(prefix) + ag.set_query(b_prefix, len(b_prefix)) + + while self._trie.predictive_search(ag): + yield self._get_key(ag) + + cpdef list keys(self, prefix=None): + """Return a list of trie keys starting with a given ``prefix``.""" + # non-generator inlined version of iterkeys() + cdef list res = [] + cdef bytes b_prefix = b'' + if prefix is not None: + b_prefix = self._encode_key(prefix) + cdef agent.Agent ag + ag.set_query(b_prefix, len(b_prefix)) + + while self._trie.predictive_search(ag): + res.append(self._get_key(ag)) + + return res + + def has_keys_with_prefix(self, prefix=""): + """ + Return ``True`` if any key in the trie begins with ``prefix``. + + .. deprecated:: 0.7.3 + + The method will be removed in version 0.8.0. Please use + :meth:`iterkeys` instead. + """ + warnings.warn("Trie.has_keys_with_prefix is deprecated and will " + "be removed in marisa_trie 0.8.0. Please use " + "Trie.iterkeys instead.", DeprecationWarning) + + cdef agent.Agent ag + cdef bytes b_prefix = self._encode_key(prefix) + ag.set_query(b_prefix, len(b_prefix)) + return self._trie.predictive_search(ag) + + +cdef class BinaryTrie(_Trie): + """A trie mapping bytes keys to auto-generated unique IDs.""" + + # key_id method is not in _Trie because it won't work for BytesTrie + cpdef int key_id(self, bytes key) except -1: + """Return an ID generated for a given ``key``. + + :raises KeyError: if key is not present in this trie. 
+ """ + cdef int res = self._key_id(key, len(key)) + if res == -1: + raise KeyError(key) + return res + + cdef int _key_id(self, char* key, int len): + cdef bint res + cdef agent.Agent ag + ag.set_query(key, len) + res = self._trie.lookup(ag) + if not res: + return -1 + return ag.key().id() + + cpdef restore_key(self, int index): + """Return a key corresponding to a given ID.""" + cdef agent.Agent ag + ag.set_query(index) + try: + self._trie.reverse_lookup(ag) + except KeyError: + raise KeyError(index) + return self._get_key(ag) + + def __getitem__(self, bytes key): + return self.key_id(key) + + def get(self, bytes key, default=None): + """ + Return an ID for a given ``key`` or ``default`` if ``key`` is + not present in this trie. + """ + cdef int res + + res = self._key_id(key, len(key)) + if res == -1: + return default + return res + + def iter_prefixes(self, bytes key): + """ + Return an iterator of all prefixes of a given key. + """ + cdef agent.Agent ag + ag.set_query(key, len(key)) + + while self._trie.common_prefix_search(ag): + yield self._get_key(ag) + + def prefixes(self, bytes key): + """ + Return a list with all prefixes of a given key. + """ + # this an inlined version of ``list(self.iter_prefixes(key))`` + + cdef list res = [] + cdef agent.Agent ag + ag.set_query(key, len(key)) + + while self._trie.common_prefix_search(ag): + res.append(self._get_key(ag)) + return res + + def items(self, bytes prefix=b""): + # inlined for speed + cdef list res = [] + cdef agent.Agent ag + ag.set_query(prefix, len(prefix)) + + while self._trie.predictive_search(ag): + res.append((self._get_key(ag), ag.key().id())) + + return res + + def iteritems(self, bytes prefix=b""): + """ + Return an iterator over items that have a prefix ``prefix``. 
+ """ + cdef agent.Agent ag + ag.set_query(prefix, len(prefix)) + + while self._trie.predictive_search(ag): + yield self._get_key(ag), ag.key().id() + + +cdef class _UnicodeKeyedTrie(_Trie): + """ + MARISA-trie wrapper for unicode keys. + """ + cdef bytes _encode_key(self, key): + return key.encode('utf8') + + cdef _get_key(self, agent.Agent& ag): + return <unicode>_Trie._get_key(self, ag).decode('utf8') + + +cdef class Trie(_UnicodeKeyedTrie): + """A trie mapping unicode keys to auto-generated unique IDs.""" + + # key_id method is not in _Trie because it won't work for BytesTrie + cpdef int key_id(self, unicode key) except -1: + """Return an ID generated for a given ``key``. + + :raises KeyError: if key is not present in this trie. + """ + cdef bytes _key = <bytes>key.encode('utf8') + cdef int res = self._key_id(_key) + if res == -1: + raise KeyError(key) + return res + + def __getitem__(self, unicode key): + return self.key_id(key) + + def get(self, key, default=None): + """ + Return an ID for a given ``key`` or ``default`` if ``key`` is + not present in this trie. + """ + cdef bytes b_key + cdef int res + + if isinstance(key, unicode): + b_key = <bytes>(<unicode>key).encode('utf8') + else: + b_key = key + + res = self._key_id(b_key) + if res == -1: + return default + return res + + cpdef restore_key(self, int index): + """Return a key corresponding to a given ID.""" + cdef agent.Agent ag + ag.set_query(index) + try: + self._trie.reverse_lookup(ag) + except KeyError: + raise KeyError(index) + return self._get_key(ag) + + cdef int _key_id(self, char* key): + cdef bint res + cdef agent.Agent ag + ag.set_query(key) + res = self._trie.lookup(ag) + if not res: + return -1 + return ag.key().id() + + def iter_prefixes(self, unicode key): + """ + Return an iterator of all prefixes of a given key. 
+ """ + cdef bytes b_key = <bytes>key.encode('utf8') + cdef agent.Agent ag + ag.set_query(b_key) + + while self._trie.common_prefix_search(ag): + yield self._get_key(ag) + + def prefixes(self, unicode key): + """ + Return a list with all prefixes of a given key. + """ + # this an inlined version of ``list(self.iter_prefixes(key))`` + + cdef list res = [] + cdef bytes b_key = <bytes>key.encode('utf8') + cdef agent.Agent ag + ag.set_query(b_key) + + while self._trie.common_prefix_search(ag): + res.append(self._get_key(ag)) + return res + + def iteritems(self, unicode prefix=""): + """ + Return an iterator over items that have a prefix ``prefix``. + """ + cdef bytes b_prefix = <bytes>prefix.encode('utf8') + cdef agent.Agent ag + ag.set_query(b_prefix) + + while self._trie.predictive_search(ag): + yield self._get_key(ag), ag.key().id() + + def items(self, unicode prefix=""): + # inlined for speed + cdef list res = [] + cdef bytes b_prefix = <bytes>prefix.encode('utf8') + cdef agent.Agent ag + ag.set_query(b_prefix) + + while self._trie.predictive_search(ag): + res.append((self._get_key(ag), ag.key().id())) + + return res + + +# This symbol is not allowed in utf8 so it is safe to use +# as a separator between utf8-encoded string and binary payload. +# XXX: b'\xff' value changes sort order for BytesTrie and RecordTrie. +# See https://github.com/kmike/DAWG docs for a description of a similar issue. +cdef bytes _VALUE_SEPARATOR = b'\xff' + + +cdef class BytesTrie(_UnicodeKeyedTrie): + """A trie mapping unicode keys to lists of bytes objects. + + The mapping is implemented by appending binary values to UTF8-encoded + and storing the result in MARISA-trie. + """ + cdef bytes _b_value_separator + cdef unsigned char _c_value_separator + + def __init__(self, arg=None, bytes value_separator=_VALUE_SEPARATOR, + **options): + """ + ``arg`` must be an iterable of tuples (unicode_key, bytes_payload). 
+ """ + super(BytesTrie, self).__init__() + + self._b_value_separator = value_separator + self._c_value_separator = <unsigned char>ord(value_separator) + + byte_keys = (self._raw_key(d[0], d[1]) for d in (arg or [])) + self._build(byte_keys, **options) + + cpdef bytes _raw_key(self, unicode key, bytes payload): + return key.encode('utf8') + self._b_value_separator + payload + + cdef bint _contains(self, bytes key): + cdef agent.Agent ag + cdef bytes _key = key + self._b_value_separator + ag.set_query(_key) + return self._trie.predictive_search(ag) + + cpdef list prefixes(self, unicode key): + """ + Return a list with all prefixes of a given key. + """ + + # XXX: is there a char-walking API in libmarisa? + # This implementation is suboptimal. + + cdef agent.Agent ag + cdef list res = [] + cdef int key_len = len(key) + cdef unicode prefix + cdef bytes b_prefix + cdef int ind = 1 + + while ind <= key_len: + prefix = key[:ind] + b_prefix = <bytes>(prefix.encode('utf8') + self._b_value_separator) + ag.set_query(b_prefix) + if self._trie.predictive_search(ag): + res.append(prefix) + + ind += 1 + + return res + + def __getitem__(self, key): + cdef list res = self.get(key) + if res is None: + raise KeyError(key) + return res + + cpdef get(self, key, default=None): + """ + Return a list of payloads (as byte objects) for a given key + or ``default`` if the key is not found. + """ + cdef list res + + if isinstance(key, unicode): + res = self.get_value(<unicode>key) + else: + res = self.b_get_value(key) + + if not res: + return default + return res + + cpdef list get_value(self, unicode key): + """ + Return a list of payloads (as byte objects) for a given unicode key. + """ + cdef bytes b_key = <bytes>key.encode('utf8') + return self.b_get_value(b_key) + + cpdef list b_get_value(self, bytes key): + """ + Return a list of payloads (as byte objects) for a given utf8-encoded key. 
+ """ + cdef list res = [] + cdef bytes value + cdef bytes b_prefix = key + self._b_value_separator + cdef int prefix_len = len(b_prefix) + + cdef agent.Agent ag + ag.set_query(b_prefix) + + while self._trie.predictive_search(ag): + value = ag.key().ptr()[prefix_len:ag.key().length()] + res.append(value) + + return res + + cpdef list items(self, unicode prefix=""): + # copied from iteritems for speed + cdef bytes b_prefix = <bytes>prefix.encode('utf8') + cdef bytes value + cdef unicode key + cdef unsigned char* raw_key + cdef list res = [] + cdef int i, value_len + + cdef agent.Agent ag + ag.set_query(b_prefix) + + while self._trie.predictive_search(ag): + raw_key = <unsigned char*>ag.key().ptr() + + for i in range(0, ag.key().length()): + if raw_key[i] == self._c_value_separator: + break + + key = raw_key[:i].decode('utf8') + value = raw_key[i+1:ag.key().length()] + + res.append( + (key, value) + ) + return res + + def iteritems(self, unicode prefix=""): + cdef bytes b_prefix = <bytes>prefix.encode('utf8') + cdef bytes value + cdef unicode key + cdef unsigned char* raw_key + cdef int i, value_len + + cdef agent.Agent ag + ag.set_query(b_prefix) + + while self._trie.predictive_search(ag): + raw_key = <unsigned char*>ag.key().ptr() + + for i in range(0, ag.key().length()): + if raw_key[i] == self._c_value_separator: + break + + key = raw_key[:i].decode('utf8') + value = raw_key[i+1:ag.key().length()] + + yield key, value + + cpdef list keys(self, prefix=""): + # copied from iterkeys for speed + cdef bytes b_prefix = <bytes>prefix.encode('utf8') + cdef unicode key + cdef unsigned char* raw_key + cdef list res = [] + cdef int i + + cdef agent.Agent ag + ag.set_query(b_prefix) + + while self._trie.predictive_search(ag): + raw_key = <unsigned char*>ag.key().ptr() + + for i in range(0, ag.key().length()): + if raw_key[i] == self._c_value_separator: + key = raw_key[:i].decode('utf8') + res.append(key) + break + return res + + def iterkeys(self, unicode prefix=""): + cdef 
bytes b_prefix = <bytes>prefix.encode('utf8') + cdef unicode key + cdef unsigned char* raw_key + cdef int i + + cdef agent.Agent ag + ag.set_query(b_prefix) + + while self._trie.predictive_search(ag): + raw_key = <unsigned char*>ag.key().ptr() + + for i in range(0, ag.key().length()): + if raw_key[i] == self._c_value_separator: + yield raw_key[:i].decode('utf8') + break + + +cdef class _UnpackTrie(BytesTrie): + + def __init__(self, arg=None, **options): + keys = ((d[0], self._pack(d[1])) for d in (arg or [])) + super(_UnpackTrie, self).__init__(keys, **options) + + cdef _unpack(self, bytes value): + return value + + cdef bytes _pack(self, value): + return value + + cpdef list b_get_value(self, bytes key): + cdef list values = BytesTrie.b_get_value(self, key) + return [self._unpack(val) for val in values] + + cpdef list items(self, unicode prefix=""): + cdef list items = BytesTrie.items(self, prefix) + return [(key, self._unpack(val)) for (key, val) in items] + + def iteritems(self, unicode prefix=""): + return ((key, self._unpack(val)) for key, val in BytesTrie.iteritems(self, prefix)) + + +cdef class RecordTrie(_UnpackTrie): + """A trie mapping unicode keys to lists of data tuples. + + The data is packed using :mod:`struct` module, therefore all + tuples must be of the same format. See :mod:`struct` documentation + for available format strings. + + The mapping is implemented by appending binary values to UTF8-encoded + and storing the result in MARISA-trie. + """ + cdef _struct + cdef _fmt + + def __init__(self, fmt, arg=None, **options): + """ + ``arg`` must be an iterable of tuples (unicode_key, data_tuple). + Data tuples will be converted to bytes with + ``struct.pack(fmt, *data_tuple)``. 
+ """ + self._fmt = fmt + self._struct = struct.Struct(str(fmt)) + super(RecordTrie, self).__init__(arg, **options) + + cdef _unpack(self, bytes value): + return self._struct.unpack(value) + + cdef bytes _pack(self, value): + return self._struct.pack(*value) + + def __reduce__(self): + return self.__class__, (self._fmt, ), self.tobytes() diff --git a/contrib/python/marisa-trie/query.pxd b/contrib/python/marisa-trie/query.pxd new file mode 100644 index 0000000000..a650bb8965 --- /dev/null +++ b/contrib/python/marisa-trie/query.pxd @@ -0,0 +1,20 @@ +cdef extern from "<marisa/query.h>" namespace "marisa" nogil: + + cdef cppclass Query: + Query() + Query(Query &query) + + #Query &operator=(Query &query) + + char operator[](int i) + + void set_str(char *str) + void set_str(char *ptr, int length) + void set_id(int id) + + char *ptr() + int length() + int id() + + void clear() + void swap(Query &rhs) diff --git a/contrib/python/marisa-trie/std_iostream.pxd b/contrib/python/marisa-trie/std_iostream.pxd new file mode 100644 index 0000000000..bf7d0e89aa --- /dev/null +++ b/contrib/python/marisa-trie/std_iostream.pxd @@ -0,0 +1,18 @@ +from libcpp.string cimport string + +cdef extern from "<istream>" namespace "std" nogil: + cdef cppclass istream: + istream() except + + istream& read (char* s, int n) except + + + cdef cppclass ostream: + ostream() except + + ostream& write (char* s, int n) except + + +cdef extern from "<sstream>" namespace "std" nogil: + + cdef cppclass stringstream: + stringstream() + stringstream(string s) + string str () + diff --git a/contrib/python/marisa-trie/trie.pxd b/contrib/python/marisa-trie/trie.pxd new file mode 100644 index 0000000000..f525caf8ad --- /dev/null +++ b/contrib/python/marisa-trie/trie.pxd @@ -0,0 +1,41 @@ +cimport agent +cimport base +cimport keyset + + +cdef extern from "<marisa/trie.h>" namespace "marisa" nogil: + + cdef cppclass Trie: + Trie() + + void build(keyset.Keyset &keyset, int config_flags) except + + void 
build(keyset.Keyset &keyset) except + + + void mmap(char *filename) except + + void map(void *ptr, int size) except + + + void load(char *filename) except + + void read(int fd) except + + + void save(char *filename) except + + void write(int fd) except + + + bint lookup(agent.Agent &agent) except + + void reverse_lookup(agent.Agent &agent) except +KeyError + bint common_prefix_search(agent.Agent &agent) except + + bint predictive_search(agent.Agent &agent) except + + + int num_tries() except + + int num_keys() except + + int num_nodes() except + + + base.TailMode tail_mode() + base.NodeOrder node_order() + + bint empty() except + + int size() except + + int total_size() except + + int io_size() except + + + void clear() except + + void swap(Trie &rhs) except + diff --git a/contrib/python/marisa-trie/ya.make b/contrib/python/marisa-trie/ya.make new file mode 100644 index 0000000000..490eef9afa --- /dev/null +++ b/contrib/python/marisa-trie/ya.make @@ -0,0 +1,33 @@ +PY23_LIBRARY() + +LICENSE(MIT) + +VERSION(0.7.5) + +NO_COMPILER_WARNINGS() + +ADDINCL( + contrib/python/marisa-trie +) + +SRCS( + marisa/agent.cc + marisa/keyset.cc + marisa/trie.cc + + marisa/grimoire/io/mapper.cc + marisa/grimoire/io/reader.cc + marisa/grimoire/io/writer.cc + marisa/grimoire/trie/louds-trie.cc + marisa/grimoire/trie/tail.cc + marisa/grimoire/vector/bit-vector.cc +) + +PY_SRCS( + TOP_LEVEL + marisa_trie.pyx +) + +NO_LINT() + +END() |