879 lines
26 KiB
C++
879 lines
26 KiB
C++
#include <algorithm>
|
|
#include <functional>
|
|
#include <queue>
|
|
|
|
#include "marisa/grimoire/algorithm.h"
|
|
#include "marisa/grimoire/trie/header.h"
|
|
#include "marisa/grimoire/trie/range.h"
|
|
#include "marisa/grimoire/trie/state.h"
|
|
#include "marisa/grimoire/trie/louds-trie.h"
|
|
|
|
namespace marisa {
|
|
namespace grimoire {
|
|
namespace trie {
|
|
|
|
// Creates an empty trie: every component is default-constructed and the two
// scalar members (cache_mask_, num_l1_nodes_) start at zero.
LoudsTrie::LoudsTrie()
    : louds_(),
      terminal_flags_(),
      link_flags_(),
      bases_(),
      extras_(),
      tail_(),
      next_trie_(),
      cache_(),
      cache_mask_(0),
      num_l1_nodes_(0),
      config_(),
      mapper_() {}
|
|
|
|
// The destructor body is intentionally empty; it performs no explicit work.
LoudsTrie::~LoudsTrie() {
}
|
|
|
|
// Builds the trie from `keyset` using the options encoded in `flags`.
// The construction happens in a scratch trie; *this is only touched by the
// final swap, after build_() has returned.
void LoudsTrie::build(Keyset &keyset, int flags) {
  Config parsed_config;
  parsed_config.parse(flags);

  LoudsTrie scratch;
  scratch.build_(keyset, parsed_config);
  this->swap(scratch);
}
|
|
|
|
// Maps a trie from `mapper`: first the header, then the body into a scratch
// trie. The scratch trie takes over the mapper's contents (via swap) before
// being swapped into *this.
void LoudsTrie::map(Mapper &mapper) {
  Header().map(mapper);

  LoudsTrie scratch;
  scratch.map_(mapper);
  scratch.mapper_.swap(mapper);
  this->swap(scratch);
}
|
|
|
|
// Reads a trie from `reader`: the header first, then the body into a scratch
// trie that is swapped into *this once read_() has returned.
void LoudsTrie::read(Reader &reader) {
  Header().read(reader);

  LoudsTrie scratch;
  scratch.read_(reader);
  this->swap(scratch);
}
|
|
|
|
// Writes the trie to `writer`: a header followed by the body emitted by
// write_().
void LoudsTrie::write(Writer &writer) const {
  Header().write(writer);
  this->write_(writer);
}
|
|
|
|
// Exact-match lookup of the agent's query.
// Descends with find_child() until the whole query has been consumed. Returns
// false if a step fails or if the node reached is not flagged in
// terminal_flags_. On success the agent's key is set to the query bytes and
// to the id obtained from terminal_flags_.rank1() of the final node.
bool LoudsTrie::lookup(Agent &agent) const {
  MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);

  State &state = agent.state();
  state.lookup_init();

  // Consume the query one step at a time; any failed step means "not found".
  while (state.query_pos() < agent.query().length()) {
    if (!find_child(agent)) {
      return false;
    }
  }

  if (terminal_flags_[state.node_id()]) {
    agent.set_key(agent.query().ptr(), agent.query().length());
    agent.set_key(terminal_flags_.rank1(state.node_id()));
    return true;
  }
  return false;
}
|
|
|
|
// Restores the key whose id is agent.query().id() into the agent.
// Throws MARISA_BOUND_ERROR (via MARISA_THROW_IF) when the id is >= size().
// The starting node is terminal_flags_.select1(id); from there the loop moves
// to the parent node each iteration, appending to key_buf, and the buffer is
// reversed once the walk stops at a node with id <= num_l1_nodes_.
void LoudsTrie::reverse_lookup(Agent &agent) const {
  MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(agent.query().id() >= size(), MARISA_BOUND_ERROR);

  State &state = agent.state();
  state.reverse_lookup_init();

  state.set_node_id(terminal_flags_.select1(agent.query().id()));

  // Node 0 needs no walk: the key buffer is reported exactly as
  // reverse_lookup_init() left it.
  if (state.node_id() != 0) {
    for ( ; ; ) {
      if (!link_flags_[state.node_id()]) {
        state.key_buf().push_back((char)bases_[state.node_id()]);
      } else {
        // Reverse just the restored segment so that the whole-buffer reversal
        // below puts it back in the order restore() produced.
        const std::size_t segment_begin = state.key_buf().size();
        restore(agent, get_link(state.node_id()));
        std::reverse(state.key_buf().begin() + segment_begin,
            state.key_buf().end());
      }

      if (state.node_id() <= num_l1_nodes_) {
        break;
      }
      state.set_node_id(
          louds_.select1(state.node_id()) - state.node_id() - 1);
    }
    std::reverse(state.key_buf().begin(), state.key_buf().end());
  }

  agent.set_key(state.key_buf().begin(), state.key_buf().size());
  agent.set_key(agent.query().id());
}
|
|
|
|
// Reports, one per call, the keys that are prefixes of the agent's query.
// Returns true after storing a match in the agent (the first query_pos()
// bytes of the query plus the id from terminal_flags_.rank1()), or false once
// the search is exhausted. Exhaustion is recorded in the state's status code
// so that later calls return false immediately.
bool LoudsTrie::common_prefix_search(Agent &agent) const {
  MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);

  State &state = agent.state();
  if (state.status_code() == MARISA_END_OF_COMMON_PREFIX_SEARCH) {
    return false;
  }

  // Any status other than "ready" triggers (re)initialisation; the node the
  // state points at right after initialisation is itself a candidate.
  if (state.status_code() != MARISA_READY_TO_COMMON_PREFIX_SEARCH) {
    state.common_prefix_search_init();
    if (terminal_flags_[state.node_id()]) {
      agent.set_key(agent.query().ptr(), state.query_pos());
      agent.set_key(terminal_flags_.rank1(state.node_id()));
      return true;
    }
  }

  while (state.query_pos() < agent.query().length()) {
    if (!find_child(agent)) {
      break;
    }
    if (terminal_flags_[state.node_id()]) {
      agent.set_key(agent.query().ptr(), state.query_pos());
      agent.set_key(terminal_flags_.rank1(state.node_id()));
      return true;
    }
  }

  // Either a step failed or the query is used up: nothing more to report.
  state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH);
  return false;
}
|
|
|
|
// Reports, one per call, keys reachable from the node matched by the agent's
// query. Returns true after storing a key in the agent, or false once the
// search is exhausted (the state's status code then makes later calls return
// false immediately).
//
// First call (status not "ready"): the query is consumed with
// predictive_find_child(); the node reached is pushed as history entry 0 and
// reported right away if it is flagged in terminal_flags_.
//
// Every call then continues a walk driven by the state's history stack:
// a set LOUDS bit descends into that node, a clear bit backtracks one level,
// and a clear bit at history position 1 ends the search.
bool LoudsTrie::predictive_search(Agent &agent) const {
  MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);

  State &state = agent.state();
  if (state.status_code() == MARISA_END_OF_PREDICTIVE_SEARCH) {
    return false;
  }

  if (state.status_code() != MARISA_READY_TO_PREDICTIVE_SEARCH) {
    state.predictive_search_init();
    while (state.query_pos() < agent.query().length()) {
      if (!predictive_find_child(agent)) {
        state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH);
        return false;
      }
    }

    // Entry 0 of the history remembers where the query ended and how long the
    // key buffer was at that point.
    History history;
    history.set_node_id(state.node_id());
    history.set_key_pos(state.key_buf().size());
    state.history().push_back(history);
    state.set_history_pos(1);

    if (terminal_flags_[state.node_id()]) {
      agent.set_key(state.key_buf().begin(), state.key_buf().size());
      agent.set_key(terminal_flags_.rank1(state.node_id()));
      return true;
    }
  }

  for ( ; ; ) {
    // No entry yet for this depth: create one positioned at the first LOUDS
    // bit that follows the current deepest node's select0 position.
    if (state.history_pos() == state.history().size()) {
      const History &current = state.history().back();
      History next;
      next.set_louds_pos(louds_.select0(current.node_id()) + 1);
      next.set_node_id(next.louds_pos() - current.node_id() - 1);
      state.history().push_back(next);
    }

    History &next = state.history()[state.history_pos()];
    // Read the LOUDS bit at the entry's position, then advance the position.
    const bool link_flag = louds_[next.louds_pos()];
    next.set_louds_pos(next.louds_pos() + 1);
    if (link_flag) {
      // Descend: extend the key buffer with this node's label or linked text.
      state.set_history_pos(state.history_pos() + 1);
      if (link_flags_[next.node_id()]) {
        next.set_link_id(update_link_id(next.link_id(), next.node_id()));
        restore(agent, get_link(next.node_id(), next.link_id()));
      } else {
        state.key_buf().push_back((char)bases_[next.node_id()]);
      }
      next.set_key_pos(state.key_buf().size());

      if (terminal_flags_[next.node_id()]) {
        // The first terminal seen through this entry gets its id from rank1;
        // later ones just increment the stored id.
        if (next.key_id() == MARISA_INVALID_KEY_ID) {
          next.set_key_id(terminal_flags_.rank1(next.node_id()));
        } else {
          next.set_key_id(next.key_id() + 1);
        }
        agent.set_key(state.key_buf().begin(), state.key_buf().size());
        agent.set_key(next.key_id());
        return true;
      }
    } else if (state.history_pos() != 1) {
      // Backtrack: advance the entry one level up to its next node and cut
      // the key buffer back to the length recorded two levels up.
      History &current = state.history()[state.history_pos() - 1];
      current.set_node_id(current.node_id() + 1);
      const History &prev =
          state.history()[state.history_pos() - 2];
      state.key_buf().resize(prev.key_pos());
      state.set_history_pos(state.history_pos() - 1);
    } else {
      state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH);
      return false;
    }
  }
}
|
|
|
|
std::size_t LoudsTrie::total_size() const {
|
|
return louds_.total_size() + terminal_flags_.total_size()
|
|
+ link_flags_.total_size() + bases_.total_size()
|
|
+ extras_.total_size() + tail_.total_size()
|
|
+ ((next_trie_.get() != NULL) ? next_trie_->total_size() : 0)
|
|
+ cache_.total_size();
|
|
}
|
|
|
|
std::size_t LoudsTrie::io_size() const {
|
|
return Header().io_size() + louds_.io_size()
|
|
+ terminal_flags_.io_size() + link_flags_.io_size()
|
|
+ bases_.io_size() + extras_.io_size() + tail_.io_size()
|
|
+ ((next_trie_.get() != NULL) ?
|
|
(next_trie_->io_size() - Header().io_size()) : 0)
|
|
+ cache_.io_size() + (sizeof(UInt32) * 2);
|
|
}
|
|
|
|
void LoudsTrie::clear() {
|
|
LoudsTrie().swap(*this);
|
|
}
|
|
|
|
// Exchanges every member of *this with the corresponding member of `rhs`.
void LoudsTrie::swap(LoudsTrie &rhs) {
  // Scalar members.
  marisa::swap(cache_mask_, rhs.cache_mask_);
  marisa::swap(num_l1_nodes_, rhs.num_l1_nodes_);

  // Bit vectors and per-node arrays.
  louds_.swap(rhs.louds_);
  terminal_flags_.swap(rhs.terminal_flags_);
  link_flags_.swap(rhs.link_flags_);
  bases_.swap(rhs.bases_);
  extras_.swap(rhs.extras_);

  // Tail, nested trie and cache.
  tail_.swap(rhs.tail_);
  next_trie_.swap(rhs.next_trie_);
  cache_.swap(rhs.cache_);

  // Configuration and mapper.
  config_.swap(rhs.config_);
  mapper_.swap(rhs.mapper_);
}
|
|
|
|
// Builds all trie levels from `keyset` and assigns an id to every keyset
// entry.
// 1. Copies string and weight of each entry into a working key vector.
// 2. Calls build_trie() (trie_id 1), which yields one terminal node id per
//    input key.
// 3. Sorts (terminal node, input index) pairs and fills terminal_flags_ with
//    one bit per node, set for the nodes that appear as a terminal.
// 4. Stores terminal_flags_.rank1(terminal node) as the id of each entry.
void LoudsTrie::build_(Keyset &keyset, const Config &config) {
  Vector<Key> keys;
  keys.resize(keyset.size());
  for (std::size_t k = 0; k < keyset.size(); ++k) {
    keys[k].set_str(keyset[k].ptr(), keyset[k].length());
    keys[k].set_weight(keyset[k].weight());
  }

  Vector<UInt32> terminals;
  build_trie(keys, &terminals, config, 1);

  typedef std::pair<UInt32, UInt32> TerminalIdPair;

  Vector<TerminalIdPair> pairs;
  pairs.resize(terminals.size());
  for (std::size_t k = 0; k < pairs.size(); ++k) {
    pairs[k].first = terminals[k];
    pairs[k].second = static_cast<UInt32>(k);
  }
  terminals.clear();
  std::sort(pairs.begin(), pairs.end());

  // Walk the sorted terminals, emitting 0-bits for the gaps and a single
  // 1-bit per distinct terminal node (repeated node ids add no extra bit).
  std::size_t next_node = 0;
  for (std::size_t k = 0; k < pairs.size(); ++k) {
    for ( ; next_node < pairs[k].first; ++next_node) {
      terminal_flags_.push_back(false);
    }
    if (next_node == pairs[k].first) {
      terminal_flags_.push_back(true);
      ++next_node;
    }
  }
  for ( ; next_node < bases_.size(); ++next_node) {
    terminal_flags_.push_back(false);
  }
  // One extra trailing 0-bit beyond the last node.
  terminal_flags_.push_back(false);
  terminal_flags_.build(false, true);

  for (std::size_t k = 0; k < keyset.size(); ++k) {
    keyset[pairs[k].second].set_id(terminal_flags_.rank1(pairs[k].first));
  }
}
|
|
|
|
template <typename T>
|
|
void LoudsTrie::build_trie(Vector<T> &keys,
|
|
Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
|
|
build_current_trie(keys, terminals, config, trie_id);
|
|
|
|
Vector<UInt32> next_terminals;
|
|
if (!keys.empty()) {
|
|
build_next_trie(keys, &next_terminals, config, trie_id);
|
|
}
|
|
|
|
if (next_trie_.get() != NULL) {
|
|
config_.parse(static_cast<int>((next_trie_->num_tries() + 1)) |
|
|
next_trie_->tail_mode() | next_trie_->node_order());
|
|
} else {
|
|
config_.parse(1 | tail_.mode() | config.node_order() |
|
|
config.cache_level());
|
|
}
|
|
|
|
link_flags_.build(false, false);
|
|
std::size_t node_id = 0;
|
|
for (std::size_t i = 0; i < next_terminals.size(); ++i) {
|
|
while (!link_flags_[node_id]) {
|
|
++node_id;
|
|
}
|
|
bases_[node_id] = (UInt8)(next_terminals[i] % 256);
|
|
next_terminals[i] /= 256;
|
|
++node_id;
|
|
}
|
|
extras_.build(next_terminals);
|
|
fill_cache();
|
|
}
|
|
|
|
template <typename T>
|
|
void LoudsTrie::build_current_trie(Vector<T> &keys,
|
|
Vector<UInt32> *terminals, const Config &config,
|
|
std::size_t trie_id) try {
|
|
for (std::size_t i = 0; i < keys.size(); ++i) {
|
|
keys[i].set_id(i);
|
|
}
|
|
const std::size_t num_keys = Algorithm().sort(keys.begin(), keys.end());
|
|
reserve_cache(config, trie_id, num_keys);
|
|
|
|
louds_.push_back(true);
|
|
louds_.push_back(false);
|
|
bases_.push_back('\0');
|
|
link_flags_.push_back(false);
|
|
|
|
Vector<T> next_keys;
|
|
std::queue<Range> queue;
|
|
Vector<WeightedRange> w_ranges;
|
|
|
|
queue.push(make_range(0, keys.size(), 0));
|
|
while (!queue.empty()) {
|
|
const std::size_t node_id = link_flags_.size() - queue.size();
|
|
|
|
Range range = queue.front();
|
|
queue.pop();
|
|
|
|
while ((range.begin() < range.end()) &&
|
|
(keys[range.begin()].length() == range.key_pos())) {
|
|
keys[range.begin()].set_terminal(node_id);
|
|
range.set_begin(range.begin() + 1);
|
|
}
|
|
|
|
if (range.begin() == range.end()) {
|
|
louds_.push_back(false);
|
|
continue;
|
|
}
|
|
|
|
w_ranges.clear();
|
|
double weight = keys[range.begin()].weight();
|
|
for (std::size_t i = range.begin() + 1; i < range.end(); ++i) {
|
|
if (keys[i - 1][range.key_pos()] != keys[i][range.key_pos()]) {
|
|
w_ranges.push_back(make_weighted_range(
|
|
range.begin(), i, range.key_pos(), (float)weight));
|
|
range.set_begin(i);
|
|
weight = 0.0;
|
|
}
|
|
weight += keys[i].weight();
|
|
}
|
|
w_ranges.push_back(make_weighted_range(
|
|
range.begin(), range.end(), range.key_pos(), (float)weight));
|
|
if (config.node_order() == MARISA_WEIGHT_ORDER) {
|
|
std::stable_sort(w_ranges.begin(), w_ranges.end(),
|
|
std::greater<WeightedRange>());
|
|
}
|
|
|
|
if (node_id == 0) {
|
|
num_l1_nodes_ = w_ranges.size();
|
|
}
|
|
|
|
for (std::size_t i = 0; i < w_ranges.size(); ++i) {
|
|
WeightedRange &w_range = w_ranges[i];
|
|
std::size_t key_pos = w_range.key_pos() + 1;
|
|
while (key_pos < keys[w_range.begin()].length()) {
|
|
std::size_t j;
|
|
for (j = w_range.begin() + 1; j < w_range.end(); ++j) {
|
|
if (keys[j - 1][key_pos] != keys[j][key_pos]) {
|
|
break;
|
|
}
|
|
}
|
|
if (j < w_range.end()) {
|
|
break;
|
|
}
|
|
++key_pos;
|
|
}
|
|
cache<T>(node_id, bases_.size(), w_range.weight(),
|
|
keys[w_range.begin()][w_range.key_pos()]);
|
|
|
|
if (key_pos == w_range.key_pos() + 1) {
|
|
bases_.push_back(static_cast<unsigned char>(
|
|
keys[w_range.begin()][w_range.key_pos()]));
|
|
link_flags_.push_back(false);
|
|
} else {
|
|
bases_.push_back('\0');
|
|
link_flags_.push_back(true);
|
|
T next_key;
|
|
next_key.set_str(keys[w_range.begin()].ptr(),
|
|
keys[w_range.begin()].length());
|
|
next_key.substr(w_range.key_pos(), key_pos - w_range.key_pos());
|
|
next_key.set_weight(w_range.weight());
|
|
next_keys.push_back(next_key);
|
|
}
|
|
w_range.set_key_pos(key_pos);
|
|
queue.push(w_range.range());
|
|
louds_.push_back(true);
|
|
}
|
|
louds_.push_back(false);
|
|
}
|
|
|
|
louds_.push_back(false);
|
|
louds_.build(trie_id == 1, true);
|
|
bases_.shrink();
|
|
|
|
build_terminals(keys, terminals);
|
|
keys.swap(next_keys);
|
|
} catch (const std::bad_alloc &) {
|
|
MARISA_THROW(MARISA_MEMORY_ERROR, "std::bad_alloc");
|
|
}
|
|
|
|
template <>
|
|
void LoudsTrie::build_next_trie(Vector<Key> &keys,
|
|
Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
|
|
if (trie_id == config.num_tries()) {
|
|
Vector<Entry> entries;
|
|
entries.resize(keys.size());
|
|
for (std::size_t i = 0; i < keys.size(); ++i) {
|
|
entries[i].set_str(keys[i].ptr(), keys[i].length());
|
|
}
|
|
tail_.build(entries, terminals, config.tail_mode());
|
|
return;
|
|
}
|
|
Vector<ReverseKey> reverse_keys;
|
|
reverse_keys.resize(keys.size());
|
|
for (std::size_t i = 0; i < keys.size(); ++i) {
|
|
reverse_keys[i].set_str(keys[i].ptr(), keys[i].length());
|
|
reverse_keys[i].set_weight(keys[i].weight());
|
|
}
|
|
keys.clear();
|
|
next_trie_.reset(new (std::nothrow) LoudsTrie);
|
|
MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
|
|
next_trie_->build_trie(reverse_keys, terminals, config, trie_id + 1);
|
|
}
|
|
|
|
template <>
|
|
void LoudsTrie::build_next_trie(Vector<ReverseKey> &keys,
|
|
Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
|
|
if (trie_id == config.num_tries()) {
|
|
Vector<Entry> entries;
|
|
entries.resize(keys.size());
|
|
for (std::size_t i = 0; i < keys.size(); ++i) {
|
|
entries[i].set_str(keys[i].ptr(), keys[i].length());
|
|
}
|
|
tail_.build(entries, terminals, config.tail_mode());
|
|
return;
|
|
}
|
|
next_trie_.reset(new (std::nothrow) LoudsTrie);
|
|
MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
|
|
next_trie_->build_trie(keys, terminals, config, trie_id + 1);
|
|
}
|
|
|
|
template <typename T>
|
|
void LoudsTrie::build_terminals(const Vector<T> &keys,
|
|
Vector<UInt32> *terminals) const {
|
|
Vector<UInt32> temp;
|
|
temp.resize(keys.size());
|
|
for (std::size_t i = 0; i < keys.size(); ++i) {
|
|
temp[keys[i].id()] = (UInt32)keys[i].terminal();
|
|
}
|
|
terminals->swap(temp);
|
|
}
|
|
|
|
template <>
|
|
void LoudsTrie::cache<Key>(std::size_t parent, std::size_t child,
|
|
float weight, char label) {
|
|
MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR);
|
|
|
|
const std::size_t cache_id = get_cache_id(parent, label);
|
|
if (weight > cache_[cache_id].weight()) {
|
|
cache_[cache_id].set_parent(parent);
|
|
cache_[cache_id].set_child(child);
|
|
cache_[cache_id].set_weight(weight);
|
|
}
|
|
}
|
|
|
|
void LoudsTrie::reserve_cache(const Config &config, std::size_t trie_id,
|
|
std::size_t num_keys) {
|
|
std::size_t cache_size = (trie_id == 1) ? 256 : 1;
|
|
while (cache_size < (num_keys / config.cache_level())) {
|
|
cache_size *= 2;
|
|
}
|
|
cache_.resize(cache_size);
|
|
cache_mask_ = cache_size - 1;
|
|
}
|
|
|
|
template <>
|
|
void LoudsTrie::cache<ReverseKey>(std::size_t parent, std::size_t child,
|
|
float weight, char) {
|
|
MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR);
|
|
|
|
const std::size_t cache_id = get_cache_id(child);
|
|
if (weight > cache_[cache_id].weight()) {
|
|
cache_[cache_id].set_parent(parent);
|
|
cache_[cache_id].set_child(child);
|
|
cache_[cache_id].set_weight(weight);
|
|
}
|
|
}
|
|
|
|
void LoudsTrie::fill_cache() {
|
|
for (std::size_t i = 0; i < cache_.size(); ++i) {
|
|
const std::size_t node_id = cache_[i].child();
|
|
if (node_id != 0) {
|
|
cache_[i].set_base(bases_[node_id]);
|
|
cache_[i].set_extra(!link_flags_[node_id] ?
|
|
MARISA_INVALID_EXTRA : extras_[link_flags_.rank1(node_id)]);
|
|
} else {
|
|
cache_[i].set_parent(MARISA_UINT32_MAX);
|
|
cache_[i].set_child(MARISA_UINT32_MAX);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Maps the trie body from `mapper`, component by component, in the order
// louds, terminal flags, link flags, bases, extras, tail, [nested trie],
// cache, num_l1_nodes, config flags.
// A nested trie is expected exactly when there are link nodes but the tail is
// empty. Throws MARISA_MEMORY_ERROR if the nested trie cannot be allocated.
void LoudsTrie::map_(Mapper &mapper) {
  louds_.map(mapper);
  terminal_flags_.map(mapper);
  link_flags_.map(mapper);
  bases_.map(mapper);
  extras_.map(mapper);
  tail_.map(mapper);

  const bool has_next_trie = (link_flags_.num_1s() != 0) && tail_.empty();
  if (has_next_trie) {
    next_trie_.reset(new (std::nothrow) LoudsTrie);
    MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
    next_trie_->map_(mapper);
  }

  cache_.map(mapper);
  cache_mask_ = cache_.size() - 1;

  // Both trailer fields are stored as UInt32.
  UInt32 stored_num_l1_nodes;
  mapper.map(&stored_num_l1_nodes);
  num_l1_nodes_ = stored_num_l1_nodes;

  UInt32 stored_config_flags;
  mapper.map(&stored_config_flags);
  config_.parse(static_cast<int>(stored_config_flags));
}
|
|
|
|
// Reads the trie body from `reader`, component by component, in the order
// louds, terminal flags, link flags, bases, extras, tail, [nested trie],
// cache, num_l1_nodes, config flags.
// A nested trie is expected exactly when there are link nodes but the tail is
// empty. Throws MARISA_MEMORY_ERROR if the nested trie cannot be allocated.
void LoudsTrie::read_(Reader &reader) {
  louds_.read(reader);
  terminal_flags_.read(reader);
  link_flags_.read(reader);
  bases_.read(reader);
  extras_.read(reader);
  tail_.read(reader);

  const bool has_next_trie = (link_flags_.num_1s() != 0) && tail_.empty();
  if (has_next_trie) {
    next_trie_.reset(new (std::nothrow) LoudsTrie);
    MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
    next_trie_->read_(reader);
  }

  cache_.read(reader);
  cache_mask_ = cache_.size() - 1;

  // Both trailer fields are stored as UInt32.
  UInt32 stored_num_l1_nodes;
  reader.read(&stored_num_l1_nodes);
  num_l1_nodes_ = stored_num_l1_nodes;

  UInt32 stored_config_flags;
  reader.read(&stored_config_flags);
  config_.parse(static_cast<int>(stored_config_flags));
}
|
|
|
|
// Writes the trie body to `writer` in the order louds, terminal flags, link
// flags, bases, extras, tail, [nested trie], cache, followed by
// num_l1_nodes_ and the config flags, each converted to UInt32.
void LoudsTrie::write_(Writer &writer) const {
  louds_.write(writer);
  terminal_flags_.write(writer);
  link_flags_.write(writer);
  bases_.write(writer);
  extras_.write(writer);
  tail_.write(writer);

  if (next_trie_.get() != NULL) {
    next_trie_->write_(writer);
  }

  cache_.write(writer);
  writer.write(static_cast<UInt32>(num_l1_nodes_));
  writer.write(static_cast<UInt32>(config_.flags()));
}
|
|
|
|
// Advances the agent's state from its current node to the child that matches
// the query at query_pos(). Returns true (with node id and query position
// updated) on success, false when no child matches.
// The cache is consulted first; on a miss the children are scanned in LOUDS
// order starting at select0(node) + 1.
bool LoudsTrie::find_child(Agent &agent) const {
  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
      MARISA_BOUND_ERROR);

  State &state = agent.state();

  // Fast path: the cache slot for (node, next byte) belongs to this node.
  const std::size_t slot = get_cache_id(state.node_id(),
      agent.query()[state.query_pos()]);
  if (state.node_id() == cache_[slot].parent()) {
    if (cache_[slot].extra() == MARISA_INVALID_EXTRA) {
      state.set_query_pos(state.query_pos() + 1);
    } else if (!match(agent, cache_[slot].link())) {
      return false;
    }
    state.set_node_id(cache_[slot].child());
    return true;
  }

  // Slow path: a clear bit right after select0 means there are no children.
  std::size_t louds_pos = louds_.select0(state.node_id()) + 1;
  if (!louds_[louds_pos]) {
    return false;
  }
  state.set_node_id(louds_pos - state.node_id() - 1);

  std::size_t link_id = MARISA_INVALID_LINK_ID;
  for ( ; ; ) {
    if (link_flags_[state.node_id()]) {
      link_id = update_link_id(link_id, state.node_id());
      const std::size_t query_pos_before = state.query_pos();
      if (match(agent, get_link(state.node_id(), link_id))) {
        return true;
      }
      // A failed match that still moved the query position is final.
      if (state.query_pos() != query_pos_before) {
        return false;
      }
    } else if (bases_[state.node_id()] ==
        static_cast<UInt8>(agent.query()[state.query_pos()])) {
      state.set_query_pos(state.query_pos() + 1);
      return true;
    }

    // Move on to the next sibling, if there is one.
    state.set_node_id(state.node_id() + 1);
    ++louds_pos;
    if (!louds_[louds_pos]) {
      return false;
    }
  }
}
|
|
|
|
// Like find_child(), but for predictive search: matched bytes are also
// appended to the state's key buffer, and link nodes are compared with
// prefix_match() instead of match().
// Returns true (with node id and query position updated) on success, false
// when no child matches.
bool LoudsTrie::predictive_find_child(Agent &agent) const {
  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
      MARISA_BOUND_ERROR);

  State &state = agent.state();

  // Fast path: the cache slot for (node, next byte) belongs to this node.
  const std::size_t slot = get_cache_id(state.node_id(),
      agent.query()[state.query_pos()]);
  if (state.node_id() == cache_[slot].parent()) {
    if (cache_[slot].extra() == MARISA_INVALID_EXTRA) {
      state.key_buf().push_back(cache_[slot].label());
      state.set_query_pos(state.query_pos() + 1);
    } else if (!prefix_match(agent, cache_[slot].link())) {
      return false;
    }
    state.set_node_id(cache_[slot].child());
    return true;
  }

  // Slow path: a clear bit right after select0 means there are no children.
  std::size_t louds_pos = louds_.select0(state.node_id()) + 1;
  if (!louds_[louds_pos]) {
    return false;
  }
  state.set_node_id(louds_pos - state.node_id() - 1);

  std::size_t link_id = MARISA_INVALID_LINK_ID;
  for ( ; ; ) {
    if (link_flags_[state.node_id()]) {
      link_id = update_link_id(link_id, state.node_id());
      const std::size_t query_pos_before = state.query_pos();
      if (prefix_match(agent, get_link(state.node_id(), link_id))) {
        return true;
      }
      // A failed match that still moved the query position is final.
      if (state.query_pos() != query_pos_before) {
        return false;
      }
    } else if (bases_[state.node_id()] ==
        static_cast<UInt8>(agent.query()[state.query_pos()])) {
      state.key_buf().push_back((char)bases_[state.node_id()]);
      state.set_query_pos(state.query_pos() + 1);
      return true;
    }

    // Move on to the next sibling, if there is one.
    state.set_node_id(state.node_id() + 1);
    ++louds_pos;
    if (!louds_[louds_pos]) {
      return false;
    }
  }
}
|
|
|
|
// Restores the text referenced by `link`, delegating to the nested trie when
// one exists and to the tail otherwise.
void LoudsTrie::restore(Agent &agent, std::size_t link) const {
  if (next_trie_.get() == NULL) {
    tail_.restore(agent, link);
    return;
  }
  next_trie_->restore_(agent, link);
}
|
|
|
|
// Matches the query against the text referenced by `link`, delegating to the
// nested trie when one exists and to the tail otherwise. Returns the
// delegate's result.
bool LoudsTrie::match(Agent &agent, std::size_t link) const {
  if (next_trie_.get() == NULL) {
    return tail_.match(agent, link);
  }
  return next_trie_->match_(agent, link);
}
|
|
|
|
// Prefix-matches the query against the text referenced by `link`, delegating
// to the nested trie when one exists and to the tail otherwise. Returns the
// delegate's result.
bool LoudsTrie::prefix_match(Agent &agent, std::size_t link) const {
  if (next_trie_.get() == NULL) {
    return tail_.prefix_match(agent, link);
  }
  return next_trie_->prefix_match_(agent, link);
}
|
|
|
|
// Appends to the state's key buffer the text spelled by the path from
// `node_id` up towards the root of this (nested) trie.
// Each iteration emits the current node's label, or restores its linked
// text, and then moves to the parent. The cache supplies label/link and
// parent when its slot holds this node; otherwise they come from
// bases_/link_flags_ and louds_.select1(). The walk stops when the cached
// parent is 0 or when a node with id <= num_l1_nodes_ has been emitted.
void LoudsTrie::restore_(Agent &agent, std::size_t node_id) const {
  MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);

  State &state = agent.state();
  for ( ; ; ) {
    const std::size_t slot = get_cache_id(node_id);
    if (node_id == cache_[slot].child()) {
      // Cache hit.
      if (cache_[slot].extra() == MARISA_INVALID_EXTRA) {
        state.key_buf().push_back(cache_[slot].label());
      } else {
        restore(agent, cache_[slot].link());
      }

      node_id = cache_[slot].parent();
      if (node_id == 0) {
        return;
      }
    } else {
      // Cache miss.
      if (link_flags_[node_id]) {
        restore(agent, get_link(node_id));
      } else {
        state.key_buf().push_back((char)bases_[node_id]);
      }

      if (node_id <= num_l1_nodes_) {
        return;
      }
      node_id = louds_.select1(node_id) - node_id - 1;
    }
  }
}
|
|
|
|
// Matches the remaining query against the text spelled by the path from
// `node_id` up towards the root of this (nested) trie.
// Returns true when the whole path matched (the cached parent is 0, or a
// node with id <= num_l1_nodes_ has been matched). Returns false on a
// mismatch, or when the query runs out before the path does.
bool LoudsTrie::match_(Agent &agent, std::size_t node_id) const {
  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
      MARISA_BOUND_ERROR);
  MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);

  State &state = agent.state();
  for ( ; ; ) {
    const std::size_t slot = get_cache_id(node_id);
    if (node_id == cache_[slot].child()) {
      // Cache hit: label/link and parent come from the cache slot.
      if (cache_[slot].extra() != MARISA_INVALID_EXTRA) {
        if (!match(agent, cache_[slot].link())) {
          return false;
        }
      } else if (cache_[slot].label() ==
          agent.query()[state.query_pos()]) {
        state.set_query_pos(state.query_pos() + 1);
      } else {
        return false;
      }

      node_id = cache_[slot].parent();
      if (node_id == 0) {
        return true;
      }
      if (state.query_pos() >= agent.query().length()) {
        return false;
      }
    } else {
      // Cache miss: consult the node arrays directly. match() performs the
      // nested-trie / tail dispatch for link nodes.
      if (link_flags_[node_id]) {
        if (!match(agent, get_link(node_id))) {
          return false;
        }
      } else if (bases_[node_id] ==
          static_cast<UInt8>(agent.query()[state.query_pos()])) {
        state.set_query_pos(state.query_pos() + 1);
      } else {
        return false;
      }

      if (node_id <= num_l1_nodes_) {
        return true;
      }
      if (state.query_pos() >= agent.query().length()) {
        return false;
      }
      node_id = louds_.select1(node_id) - node_id - 1;
    }
  }
}
|
|
|
|
// Prefix variant of match_(): matches the remaining query against the text
// spelled by the path from `node_id` up towards the root, appending every
// matched byte to the state's key buffer.
// Returns false on a mismatch. Returns true when the whole path matched, or
// when the query runs out first; in that case the rest of the path is
// appended with restore_().
bool LoudsTrie::prefix_match_(Agent &agent, std::size_t node_id) const {
  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
      MARISA_BOUND_ERROR);
  MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);

  State &state = agent.state();
  for ( ; ; ) {
    const std::size_t slot = get_cache_id(node_id);
    if (node_id != cache_[slot].child()) {
      // Cache miss: consult the node arrays directly.
      if (link_flags_[node_id]) {
        if (!prefix_match(agent, get_link(node_id))) {
          return false;
        }
      } else if (bases_[node_id] ==
          static_cast<UInt8>(agent.query()[state.query_pos()])) {
        state.key_buf().push_back((char)bases_[node_id]);
        state.set_query_pos(state.query_pos() + 1);
      } else {
        return false;
      }

      if (node_id <= num_l1_nodes_) {
        return true;
      }
      node_id = louds_.select1(node_id) - node_id - 1;
    } else {
      // Cache hit: label/link and parent come from the cache slot.
      if (cache_[slot].extra() != MARISA_INVALID_EXTRA) {
        if (!prefix_match(agent, cache_[slot].link())) {
          return false;
        }
      } else if (cache_[slot].label() ==
          agent.query()[state.query_pos()]) {
        state.key_buf().push_back(cache_[slot].label());
        state.set_query_pos(state.query_pos() + 1);
      } else {
        return false;
      }

      node_id = cache_[slot].parent();
      if (node_id == 0) {
        return true;
      }
    }

    // Query exhausted with path left over: append the remainder unmatched.
    if (state.query_pos() >= agent.query().length()) {
      restore_(agent, node_id);
      return true;
    }
  }
}
|
|
|
|
std::size_t LoudsTrie::get_cache_id(std::size_t node_id, char label) const {
|
|
return (node_id ^ (node_id << 5) ^ (UInt8)label) & cache_mask_;
|
|
}
|
|
|
|
std::size_t LoudsTrie::get_cache_id(std::size_t node_id) const {
|
|
return node_id & cache_mask_;
|
|
}
|
|
|
|
std::size_t LoudsTrie::get_link(std::size_t node_id) const {
|
|
return bases_[node_id] | (extras_[link_flags_.rank1(node_id)] * 256);
|
|
}
|
|
|
|
std::size_t LoudsTrie::get_link(std::size_t node_id,
|
|
std::size_t link_id) const {
|
|
return bases_[node_id] | (extras_[link_id] * 256);
|
|
}
|
|
|
|
std::size_t LoudsTrie::update_link_id(std::size_t link_id,
|
|
std::size_t node_id) const {
|
|
return (link_id == MARISA_INVALID_LINK_ID) ?
|
|
link_flags_.rank1(node_id) : (link_id + 1);
|
|
}
|
|
|
|
} // namespace trie
|
|
} // namespace grimoire
|
|
} // namespace marisa
|