845 lines
26 KiB
C++
845 lines
26 KiB
C++
#include "marisa/grimoire/vector/pop-count.h"
|
|
#include "marisa/grimoire/vector/bit-vector.h"
|
|
|
|
namespace marisa {
|
|
namespace grimoire {
|
|
namespace vector {
|
|
namespace {
|
|
|
|
#ifdef MARISA_USE_BMI2
|
|
std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {
|
|
#ifdef _MSC_VER
|
|
unsigned long pos;
|
|
::_BitScanForward64(&pos, _pdep_u64(1ULL << i, unit));
|
|
return bit_id + pos;
|
|
#else // _MSC_VER
|
|
return bit_id + ::__builtin_ctzll(_pdep_u64(1ULL << i, unit));
|
|
#endif // _MSC_VER
|
|
}
|
|
#else // MARISA_USE_BMI2
|
|
// Lookup table for selecting a set bit inside a single byte.
//
// SELECT_TABLE[k][b] is the position (0 = least significant bit) of the
// (k + 1)-th set bit of the byte value b; for example SELECT_TABLE[0][4] is 2
// and SELECT_TABLE[1][3] is 1. Entries for which b has fewer than k + 1 set
// bits hold the filler value 7 (e.g. SELECT_TABLE[0][0]). Callers index it as
// SELECT_TABLE[i][unit & 0xFF] after reducing i to a rank inside one byte.
const UInt8 SELECT_TABLE[8][256] = {
|
|
{
|
|
7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
7, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
6, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
5, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0,
|
|
4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
|
|
},
|
|
{
|
|
7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1,
|
|
7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
|
|
7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
|
|
5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
|
|
7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1,
|
|
6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
|
|
6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
|
|
5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
|
|
7, 7, 7, 1, 7, 2, 2, 1, 7, 3, 3, 1, 3, 2, 2, 1,
|
|
7, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
|
|
7, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
|
|
5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
|
|
7, 6, 6, 1, 6, 2, 2, 1, 6, 3, 3, 1, 3, 2, 2, 1,
|
|
6, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1,
|
|
6, 5, 5, 1, 5, 2, 2, 1, 5, 3, 3, 1, 3, 2, 2, 1,
|
|
5, 4, 4, 1, 4, 2, 2, 1, 4, 3, 3, 1, 3, 2, 2, 1
|
|
},
|
|
{
|
|
7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2,
|
|
7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2,
|
|
7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2,
|
|
7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2,
|
|
7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2,
|
|
7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2,
|
|
7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2,
|
|
6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2,
|
|
7, 7, 7, 7, 7, 7, 7, 2, 7, 7, 7, 3, 7, 3, 3, 2,
|
|
7, 7, 7, 4, 7, 4, 4, 2, 7, 4, 4, 3, 4, 3, 3, 2,
|
|
7, 7, 7, 5, 7, 5, 5, 2, 7, 5, 5, 3, 5, 3, 3, 2,
|
|
7, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2,
|
|
7, 7, 7, 6, 7, 6, 6, 2, 7, 6, 6, 3, 6, 3, 3, 2,
|
|
7, 6, 6, 4, 6, 4, 4, 2, 6, 4, 4, 3, 4, 3, 3, 2,
|
|
7, 6, 6, 5, 6, 5, 5, 2, 6, 5, 5, 3, 5, 3, 3, 2,
|
|
6, 5, 5, 4, 5, 4, 4, 2, 5, 4, 4, 3, 4, 3, 3, 2
|
|
},
|
|
{
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3,
|
|
7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3,
|
|
7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3,
|
|
7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3,
|
|
7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3,
|
|
7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3,
|
|
7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3,
|
|
7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3,
|
|
7, 7, 7, 7, 7, 7, 7, 4, 7, 7, 7, 4, 7, 4, 4, 3,
|
|
7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 3,
|
|
7, 7, 7, 5, 7, 5, 5, 4, 7, 5, 5, 4, 5, 4, 4, 3,
|
|
7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 3,
|
|
7, 7, 7, 6, 7, 6, 6, 4, 7, 6, 6, 4, 6, 4, 4, 3,
|
|
7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 3,
|
|
7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3
|
|
},
|
|
{
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
|
|
7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
|
|
7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4,
|
|
7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5,
|
|
7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 4,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
|
|
7, 7, 7, 7, 7, 7, 7, 5, 7, 7, 7, 5, 7, 5, 5, 4,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
|
|
7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 4,
|
|
7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5,
|
|
7, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4
|
|
},
|
|
{
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
|
|
7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
|
|
7, 7, 7, 7, 7, 7, 7, 6, 7, 7, 7, 6, 7, 6, 6, 5
|
|
},
|
|
{
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6
|
|
},
|
|
{
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
|
|
}
|
|
};
|
|
|
|
#if MARISA_WORD_SIZE == 64
|
|
// Byte-replicated 64-bit constants used by the table-based select_bit().
// MASK_01 has 0x01 in every byte; select_bit() multiplies by it both to turn
// per-byte counts into running totals and to broadcast (i + 1) to every byte.
const UInt64 MASK_01 = 0x0101010101010101ULL;
|
|
// MASK_0F / MASK_33 / MASK_55 are referenced only by the portable
// bit-counting path of select_bit(), i.e. when the SSSE3 path is not built.
#if !defined(MARISA_X64) || !defined(MARISA_USE_SSSE3)
|
|
const UInt64 MASK_0F = 0x0F0F0F0F0F0F0F0FULL;
|
|
const UInt64 MASK_33 = 0x3333333333333333ULL;
|
|
const UInt64 MASK_55 = 0x5555555555555555ULL;
|
|
#endif // !defined(MARISA_X64) || !defined(MARISA_USE_SSSE3)
|
|
// MASK_80 (top bit of every byte) is referenced only by the portable
// byte-search path of select_bit(), i.e. when the POPCNT path is not built.
#if !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT)
|
|
const UInt64 MASK_80 = 0x8080808080808080ULL;
|
|
#endif // !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT)
|
|
|
|
std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {
|
|
UInt64 counts;
|
|
{
|
|
#if defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
|
|
__m128i lower_nibbles = _mm_cvtsi64_si128(
|
|
static_cast<long long>(unit & 0x0F0F0F0F0F0F0F0FULL));
|
|
__m128i upper_nibbles = _mm_cvtsi64_si128(
|
|
static_cast<long long>(unit & 0xF0F0F0F0F0F0F0F0ULL));
|
|
upper_nibbles = _mm_srli_epi32(upper_nibbles, 4);
|
|
|
|
__m128i lower_counts =
|
|
_mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
|
|
lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles);
|
|
__m128i upper_counts =
|
|
_mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
|
|
upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles);
|
|
|
|
counts = static_cast<UInt64>(_mm_cvtsi128_si64(
|
|
_mm_add_epi8(lower_counts, upper_counts)));
|
|
#else // defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
|
|
counts = unit - ((unit >> 1) & MASK_55);
|
|
counts = (counts & MASK_33) + ((counts >> 2) & MASK_33);
|
|
counts = (counts + (counts >> 4)) & MASK_0F;
|
|
#endif // defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
|
|
counts *= MASK_01;
|
|
}
|
|
|
|
#if defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
|
|
UInt8 skip;
|
|
{
|
|
__m128i x = _mm_cvtsi64_si128(static_cast<long long>((i + 1) * MASK_01));
|
|
__m128i y = _mm_cvtsi64_si128(static_cast<long long>(counts));
|
|
x = _mm_cmpgt_epi8(x, y);
|
|
skip = (UInt8)PopCount::count(static_cast<UInt64>(_mm_cvtsi128_si64(x)));
|
|
}
|
|
#else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
|
|
const UInt64 x = (counts | MASK_80) - ((i + 1) * MASK_01);
|
|
#ifdef _MSC_VER
|
|
unsigned long skip;
|
|
::_BitScanForward64(&skip, (x & MASK_80) >> 7);
|
|
#else // _MSC_VER
|
|
const int skip = ::__builtin_ctzll((x & MASK_80) >> 7);
|
|
#endif // _MSC_VER
|
|
#endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
|
|
|
|
bit_id += static_cast<std::size_t>(skip);
|
|
unit >>= skip;
|
|
i -= ((counts << 8) >> skip) & 0xFF;
|
|
|
|
return bit_id + SELECT_TABLE[i][unit & 0xFF];
|
|
}
|
|
#else // MARISA_WORD_SIZE == 64
|
|
#ifdef MARISA_USE_SSE2
|
|
// POPCNT_TABLE[m] is 8 times the number of set bits in the byte value m
// (e.g. POPCNT_TABLE[1] is 8, POPCNT_TABLE[3] is 16, POPCNT_TABLE[255] is 64).
// select_bit() indexes it with a byte-compare movemask to turn "number of
// bytes to skip" directly into "number of bits to skip".
const UInt8 POPCNT_TABLE[256] = {
|
|
0, 8, 8, 16, 8, 16, 16, 24, 8, 16, 16, 24, 16, 24, 24, 32,
|
|
8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
|
|
8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
|
|
16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
|
|
8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
|
|
16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
|
|
16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
|
|
24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
|
|
8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
|
|
16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
|
|
16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
|
|
24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
|
|
16, 24, 24, 32, 24, 32, 32, 40, 24, 32, 32, 40, 32, 40, 40, 48,
|
|
24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
|
|
24, 32, 32, 40, 32, 40, 40, 48, 32, 40, 40, 48, 40, 48, 48, 56,
|
|
32, 40, 40, 48, 40, 48, 48, 56, 40, 48, 48, 56, 48, 56, 56, 64
|
|
};
|
|
|
|
std::size_t select_bit(std::size_t i, std::size_t bit_id,
|
|
UInt32 unit_lo, UInt32 unit_hi) {
|
|
__m128i unit;
|
|
{
|
|
__m128i lower_dword = _mm_cvtsi32_si128(unit_lo);
|
|
__m128i upper_dword = _mm_cvtsi32_si128(unit_hi);
|
|
upper_dword = _mm_slli_si128(upper_dword, 4);
|
|
unit = _mm_or_si128(lower_dword, upper_dword);
|
|
}
|
|
|
|
__m128i counts;
|
|
{
|
|
#ifdef MARISA_USE_SSSE3
|
|
__m128i lower_nibbles = _mm_set1_epi8(0x0F);
|
|
lower_nibbles = _mm_and_si128(lower_nibbles, unit);
|
|
__m128i upper_nibbles = _mm_set1_epi8((UInt8)0xF0);
|
|
upper_nibbles = _mm_and_si128(upper_nibbles, unit);
|
|
upper_nibbles = _mm_srli_epi32(upper_nibbles, 4);
|
|
|
|
__m128i lower_counts =
|
|
_mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
|
|
lower_counts = _mm_shuffle_epi8(lower_counts, lower_nibbles);
|
|
__m128i upper_counts =
|
|
_mm_set_epi8(4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
|
|
upper_counts = _mm_shuffle_epi8(upper_counts, upper_nibbles);
|
|
|
|
counts = _mm_add_epi8(lower_counts, upper_counts);
|
|
#else // MARISA_USE_SSSE3
|
|
__m128i x = _mm_srli_epi32(unit, 1);
|
|
x = _mm_and_si128(x, _mm_set1_epi8(0x55));
|
|
x = _mm_sub_epi8(unit, x);
|
|
|
|
__m128i y = _mm_srli_epi32(x, 2);
|
|
y = _mm_and_si128(y, _mm_set1_epi8(0x33));
|
|
x = _mm_and_si128(x, _mm_set1_epi8(0x33));
|
|
x = _mm_add_epi8(x, y);
|
|
|
|
y = _mm_srli_epi32(x, 4);
|
|
x = _mm_add_epi8(x, y);
|
|
counts = _mm_and_si128(x, _mm_set1_epi8(0x0F));
|
|
#endif // MARISA_USE_SSSE3
|
|
}
|
|
|
|
__m128i accumulated_counts;
|
|
{
|
|
__m128i x = counts;
|
|
x = _mm_slli_si128(x, 1);
|
|
__m128i y = counts;
|
|
y = _mm_add_epi32(y, x);
|
|
|
|
x = y;
|
|
y = _mm_slli_si128(y, 2);
|
|
x = _mm_add_epi32(x, y);
|
|
|
|
y = x;
|
|
x = _mm_slli_si128(x, 4);
|
|
y = _mm_add_epi32(y, x);
|
|
|
|
accumulated_counts = _mm_set_epi32(0x7F7F7F7FU, 0x7F7F7F7FU, 0, 0);
|
|
accumulated_counts = _mm_or_si128(accumulated_counts, y);
|
|
}
|
|
|
|
UInt8 skip;
|
|
{
|
|
__m128i x = _mm_set1_epi8((UInt8)(i + 1));
|
|
x = _mm_cmpgt_epi8(x, accumulated_counts);
|
|
skip = POPCNT_TABLE[_mm_movemask_epi8(x)];
|
|
}
|
|
|
|
UInt8 byte;
|
|
{
|
|
#ifdef _MSC_VER
|
|
__declspec(align(16)) UInt8 unit_bytes[16];
|
|
__declspec(align(16)) UInt8 accumulated_counts_bytes[16];
|
|
#else // _MSC_VER
|
|
UInt8 unit_bytes[16] __attribute__ ((aligned (16)));
|
|
UInt8 accumulated_counts_bytes[16] __attribute__ ((aligned (16)));
|
|
#endif // _MSC_VER
|
|
accumulated_counts = _mm_slli_si128(accumulated_counts, 1);
|
|
_mm_store_si128(reinterpret_cast<__m128i *>(unit_bytes), unit);
|
|
_mm_store_si128(reinterpret_cast<__m128i *>(accumulated_counts_bytes),
|
|
accumulated_counts);
|
|
|
|
bit_id += skip;
|
|
byte = unit_bytes[skip / 8];
|
|
i -= accumulated_counts_bytes[skip / 8];
|
|
}
|
|
|
|
return bit_id + SELECT_TABLE[i][byte];
|
|
}
|
|
#endif // MARISA_USE_SSE2
|
|
#endif // MARISA_WORD_SIZE == 64
|
|
#endif // MARISA_USE_BMI2
|
|
|
|
} // namespace
|
|
|
|
#if MARISA_WORD_SIZE == 64
|
|
|
|
std::size_t BitVector::rank1(std::size_t i) const {
|
|
MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR);
|
|
MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR);
|
|
|
|
const RankIndex &rank = ranks_[i / 512];
|
|
std::size_t offset = rank.abs();
|
|
switch ((i / 64) % 8) {
|
|
case 1: {
|
|
offset += rank.rel1();
|
|
break;
|
|
}
|
|
case 2: {
|
|
offset += rank.rel2();
|
|
break;
|
|
}
|
|
case 3: {
|
|
offset += rank.rel3();
|
|
break;
|
|
}
|
|
case 4: {
|
|
offset += rank.rel4();
|
|
break;
|
|
}
|
|
case 5: {
|
|
offset += rank.rel5();
|
|
break;
|
|
}
|
|
case 6: {
|
|
offset += rank.rel6();
|
|
break;
|
|
}
|
|
case 7: {
|
|
offset += rank.rel7();
|
|
break;
|
|
}
|
|
}
|
|
offset += PopCount::count(units_[i / 64] & ((1ULL << (i % 64)) - 1));
|
|
return offset;
|
|
}
|
|
|
|
std::size_t BitVector::select0(std::size_t i) const {
|
|
MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR);
|
|
MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR);
|
|
|
|
const std::size_t select_id = i / 512;
|
|
MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR);
|
|
if ((i % 512) == 0) {
|
|
return select0s_[select_id];
|
|
}
|
|
std::size_t begin = select0s_[select_id] / 512;
|
|
std::size_t end = (select0s_[select_id + 1] + 511) / 512;
|
|
if (begin + 10 >= end) {
|
|
while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) {
|
|
++begin;
|
|
}
|
|
} else {
|
|
while (begin + 1 < end) {
|
|
const std::size_t middle = (begin + end) / 2;
|
|
if (i < (middle * 512) - ranks_[middle].abs()) {
|
|
end = middle;
|
|
} else {
|
|
begin = middle;
|
|
}
|
|
}
|
|
}
|
|
const std::size_t rank_id = begin;
|
|
i -= (rank_id * 512) - ranks_[rank_id].abs();
|
|
|
|
const RankIndex &rank = ranks_[rank_id];
|
|
std::size_t unit_id = rank_id * 8;
|
|
if (i < (256U - rank.rel4())) {
|
|
if (i < (128U - rank.rel2())) {
|
|
if (i >= (64U - rank.rel1())) {
|
|
unit_id += 1;
|
|
i -= 64 - rank.rel1();
|
|
}
|
|
} else if (i < (192U - rank.rel3())) {
|
|
unit_id += 2;
|
|
i -= 128 - rank.rel2();
|
|
} else {
|
|
unit_id += 3;
|
|
i -= 192 - rank.rel3();
|
|
}
|
|
} else if (i < (384U - rank.rel6())) {
|
|
if (i < (320U - rank.rel5())) {
|
|
unit_id += 4;
|
|
i -= 256 - rank.rel4();
|
|
} else {
|
|
unit_id += 5;
|
|
i -= 320 - rank.rel5();
|
|
}
|
|
} else if (i < (448U - rank.rel7())) {
|
|
unit_id += 6;
|
|
i -= 384 - rank.rel6();
|
|
} else {
|
|
unit_id += 7;
|
|
i -= 448 - rank.rel7();
|
|
}
|
|
|
|
return select_bit(i, unit_id * 64, ~units_[unit_id]);
|
|
}
|
|
|
|
std::size_t BitVector::select1(std::size_t i) const {
|
|
MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR);
|
|
MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR);
|
|
|
|
const std::size_t select_id = i / 512;
|
|
MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR);
|
|
if ((i % 512) == 0) {
|
|
return select1s_[select_id];
|
|
}
|
|
std::size_t begin = select1s_[select_id] / 512;
|
|
std::size_t end = (select1s_[select_id + 1] + 511) / 512;
|
|
if (begin + 10 >= end) {
|
|
while (i >= ranks_[begin + 1].abs()) {
|
|
++begin;
|
|
}
|
|
} else {
|
|
while (begin + 1 < end) {
|
|
const std::size_t middle = (begin + end) / 2;
|
|
if (i < ranks_[middle].abs()) {
|
|
end = middle;
|
|
} else {
|
|
begin = middle;
|
|
}
|
|
}
|
|
}
|
|
const std::size_t rank_id = begin;
|
|
i -= ranks_[rank_id].abs();
|
|
|
|
const RankIndex &rank = ranks_[rank_id];
|
|
std::size_t unit_id = rank_id * 8;
|
|
if (i < rank.rel4()) {
|
|
if (i < rank.rel2()) {
|
|
if (i >= rank.rel1()) {
|
|
unit_id += 1;
|
|
i -= rank.rel1();
|
|
}
|
|
} else if (i < rank.rel3()) {
|
|
unit_id += 2;
|
|
i -= rank.rel2();
|
|
} else {
|
|
unit_id += 3;
|
|
i -= rank.rel3();
|
|
}
|
|
} else if (i < rank.rel6()) {
|
|
if (i < rank.rel5()) {
|
|
unit_id += 4;
|
|
i -= rank.rel4();
|
|
} else {
|
|
unit_id += 5;
|
|
i -= rank.rel5();
|
|
}
|
|
} else if (i < rank.rel7()) {
|
|
unit_id += 6;
|
|
i -= rank.rel6();
|
|
} else {
|
|
unit_id += 7;
|
|
i -= rank.rel7();
|
|
}
|
|
|
|
return select_bit(i, unit_id * 64, units_[unit_id]);
|
|
}
|
|
|
|
#else // MARISA_WORD_SIZE == 64
|
|
|
|
std::size_t BitVector::rank1(std::size_t i) const {
|
|
MARISA_DEBUG_IF(ranks_.empty(), MARISA_STATE_ERROR);
|
|
MARISA_DEBUG_IF(i > size_, MARISA_BOUND_ERROR);
|
|
|
|
const RankIndex &rank = ranks_[i / 512];
|
|
std::size_t offset = rank.abs();
|
|
switch ((i / 64) % 8) {
|
|
case 1: {
|
|
offset += rank.rel1();
|
|
break;
|
|
}
|
|
case 2: {
|
|
offset += rank.rel2();
|
|
break;
|
|
}
|
|
case 3: {
|
|
offset += rank.rel3();
|
|
break;
|
|
}
|
|
case 4: {
|
|
offset += rank.rel4();
|
|
break;
|
|
}
|
|
case 5: {
|
|
offset += rank.rel5();
|
|
break;
|
|
}
|
|
case 6: {
|
|
offset += rank.rel6();
|
|
break;
|
|
}
|
|
case 7: {
|
|
offset += rank.rel7();
|
|
break;
|
|
}
|
|
}
|
|
if (((i / 32) & 1) == 1) {
|
|
offset += PopCount::count(units_[(i / 32) - 1]);
|
|
}
|
|
offset += PopCount::count(units_[i / 32] & ((1U << (i % 32)) - 1));
|
|
return offset;
|
|
}
|
|
|
|
std::size_t BitVector::select0(std::size_t i) const {
|
|
MARISA_DEBUG_IF(select0s_.empty(), MARISA_STATE_ERROR);
|
|
MARISA_DEBUG_IF(i >= num_0s(), MARISA_BOUND_ERROR);
|
|
|
|
const std::size_t select_id = i / 512;
|
|
MARISA_DEBUG_IF((select_id + 1) >= select0s_.size(), MARISA_BOUND_ERROR);
|
|
if ((i % 512) == 0) {
|
|
return select0s_[select_id];
|
|
}
|
|
std::size_t begin = select0s_[select_id] / 512;
|
|
std::size_t end = (select0s_[select_id + 1] + 511) / 512;
|
|
if (begin + 10 >= end) {
|
|
while (i >= ((begin + 1) * 512) - ranks_[begin + 1].abs()) {
|
|
++begin;
|
|
}
|
|
} else {
|
|
while (begin + 1 < end) {
|
|
const std::size_t middle = (begin + end) / 2;
|
|
if (i < (middle * 512) - ranks_[middle].abs()) {
|
|
end = middle;
|
|
} else {
|
|
begin = middle;
|
|
}
|
|
}
|
|
}
|
|
const std::size_t rank_id = begin;
|
|
i -= (rank_id * 512) - ranks_[rank_id].abs();
|
|
|
|
const RankIndex &rank = ranks_[rank_id];
|
|
std::size_t unit_id = rank_id * 16;
|
|
if (i < (256U - rank.rel4())) {
|
|
if (i < (128U - rank.rel2())) {
|
|
if (i >= (64U - rank.rel1())) {
|
|
unit_id += 2;
|
|
i -= 64 - rank.rel1();
|
|
}
|
|
} else if (i < (192U - rank.rel3())) {
|
|
unit_id += 4;
|
|
i -= 128 - rank.rel2();
|
|
} else {
|
|
unit_id += 6;
|
|
i -= 192 - rank.rel3();
|
|
}
|
|
} else if (i < (384U - rank.rel6())) {
|
|
if (i < (320U - rank.rel5())) {
|
|
unit_id += 8;
|
|
i -= 256 - rank.rel4();
|
|
} else {
|
|
unit_id += 10;
|
|
i -= 320 - rank.rel5();
|
|
}
|
|
} else if (i < (448U - rank.rel7())) {
|
|
unit_id += 12;
|
|
i -= 384 - rank.rel6();
|
|
} else {
|
|
unit_id += 14;
|
|
i -= 448 - rank.rel7();
|
|
}
|
|
|
|
#ifdef MARISA_USE_SSE2
|
|
return select_bit(i, unit_id * 32, ~units_[unit_id], ~units_[unit_id + 1]);
|
|
#else // MARISA_USE_SSE2
|
|
UInt32 unit = ~units_[unit_id];
|
|
PopCount count(unit);
|
|
if (i >= count.lo32()) {
|
|
++unit_id;
|
|
i -= count.lo32();
|
|
unit = ~units_[unit_id];
|
|
count = PopCount(unit);
|
|
}
|
|
|
|
std::size_t bit_id = unit_id * 32;
|
|
if (i < count.lo16()) {
|
|
if (i >= count.lo8()) {
|
|
bit_id += 8;
|
|
unit >>= 8;
|
|
i -= count.lo8();
|
|
}
|
|
} else if (i < count.lo24()) {
|
|
bit_id += 16;
|
|
unit >>= 16;
|
|
i -= count.lo16();
|
|
} else {
|
|
bit_id += 24;
|
|
unit >>= 24;
|
|
i -= count.lo24();
|
|
}
|
|
return bit_id + SELECT_TABLE[i][unit & 0xFF];
|
|
#endif // MARISA_USE_SSE2
|
|
}
|
|
|
|
std::size_t BitVector::select1(std::size_t i) const {
|
|
MARISA_DEBUG_IF(select1s_.empty(), MARISA_STATE_ERROR);
|
|
MARISA_DEBUG_IF(i >= num_1s(), MARISA_BOUND_ERROR);
|
|
|
|
const std::size_t select_id = i / 512;
|
|
MARISA_DEBUG_IF((select_id + 1) >= select1s_.size(), MARISA_BOUND_ERROR);
|
|
if ((i % 512) == 0) {
|
|
return select1s_[select_id];
|
|
}
|
|
std::size_t begin = select1s_[select_id] / 512;
|
|
std::size_t end = (select1s_[select_id + 1] + 511) / 512;
|
|
if (begin + 10 >= end) {
|
|
while (i >= ranks_[begin + 1].abs()) {
|
|
++begin;
|
|
}
|
|
} else {
|
|
while (begin + 1 < end) {
|
|
const std::size_t middle = (begin + end) / 2;
|
|
if (i < ranks_[middle].abs()) {
|
|
end = middle;
|
|
} else {
|
|
begin = middle;
|
|
}
|
|
}
|
|
}
|
|
const std::size_t rank_id = begin;
|
|
i -= ranks_[rank_id].abs();
|
|
|
|
const RankIndex &rank = ranks_[rank_id];
|
|
std::size_t unit_id = rank_id * 16;
|
|
if (i < rank.rel4()) {
|
|
if (i < rank.rel2()) {
|
|
if (i >= rank.rel1()) {
|
|
unit_id += 2;
|
|
i -= rank.rel1();
|
|
}
|
|
} else if (i < rank.rel3()) {
|
|
unit_id += 4;
|
|
i -= rank.rel2();
|
|
} else {
|
|
unit_id += 6;
|
|
i -= rank.rel3();
|
|
}
|
|
} else if (i < rank.rel6()) {
|
|
if (i < rank.rel5()) {
|
|
unit_id += 8;
|
|
i -= rank.rel4();
|
|
} else {
|
|
unit_id += 10;
|
|
i -= rank.rel5();
|
|
}
|
|
} else if (i < rank.rel7()) {
|
|
unit_id += 12;
|
|
i -= rank.rel6();
|
|
} else {
|
|
unit_id += 14;
|
|
i -= rank.rel7();
|
|
}
|
|
|
|
#ifdef MARISA_USE_SSE2
|
|
return select_bit(i, unit_id * 32, units_[unit_id], units_[unit_id + 1]);
|
|
#else // MARISA_USE_SSE2
|
|
UInt32 unit = units_[unit_id];
|
|
PopCount count(unit);
|
|
if (i >= count.lo32()) {
|
|
++unit_id;
|
|
i -= count.lo32();
|
|
unit = units_[unit_id];
|
|
count = PopCount(unit);
|
|
}
|
|
|
|
std::size_t bit_id = unit_id * 32;
|
|
if (i < count.lo16()) {
|
|
if (i >= count.lo8()) {
|
|
bit_id += 8;
|
|
unit >>= 8;
|
|
i -= count.lo8();
|
|
}
|
|
} else if (i < count.lo24()) {
|
|
bit_id += 16;
|
|
unit >>= 16;
|
|
i -= count.lo16();
|
|
} else {
|
|
bit_id += 24;
|
|
unit >>= 24;
|
|
i -= count.lo24();
|
|
}
|
|
return bit_id + SELECT_TABLE[i][unit & 0xFF];
|
|
#endif // MARISA_USE_SSE2
|
|
}
|
|
|
|
#endif // MARISA_WORD_SIZE == 64
|
|
|
|
void BitVector::build_index(const BitVector &bv,
|
|
bool enables_select0, bool enables_select1) {
|
|
ranks_.resize((bv.size() / 512) + (((bv.size() % 512) != 0) ? 1 : 0) + 1);
|
|
|
|
std::size_t num_0s = 0;
|
|
std::size_t num_1s = 0;
|
|
|
|
for (std::size_t i = 0; i < bv.size(); ++i) {
|
|
if ((i % 64) == 0) {
|
|
const std::size_t rank_id = i / 512;
|
|
switch ((i / 64) % 8) {
|
|
case 0: {
|
|
ranks_[rank_id].set_abs(num_1s);
|
|
break;
|
|
}
|
|
case 1: {
|
|
ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs());
|
|
break;
|
|
}
|
|
case 2: {
|
|
ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs());
|
|
break;
|
|
}
|
|
case 3: {
|
|
ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs());
|
|
break;
|
|
}
|
|
case 4: {
|
|
ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs());
|
|
break;
|
|
}
|
|
case 5: {
|
|
ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs());
|
|
break;
|
|
}
|
|
case 6: {
|
|
ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs());
|
|
break;
|
|
}
|
|
case 7: {
|
|
ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs());
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (bv[i]) {
|
|
if (enables_select1 && ((num_1s % 512) == 0)) {
|
|
select1s_.push_back(static_cast<UInt32>(i));
|
|
}
|
|
++num_1s;
|
|
} else {
|
|
if (enables_select0 && ((num_0s % 512) == 0)) {
|
|
select0s_.push_back(static_cast<UInt32>(i));
|
|
}
|
|
++num_0s;
|
|
}
|
|
}
|
|
|
|
if ((bv.size() % 512) != 0) {
|
|
const std::size_t rank_id = (bv.size() - 1) / 512;
|
|
switch (((bv.size() - 1) / 64) % 8) {
|
|
case 0: {
|
|
ranks_[rank_id].set_rel1(num_1s - ranks_[rank_id].abs());
|
|
} // fall through
|
|
case 1: {
|
|
ranks_[rank_id].set_rel2(num_1s - ranks_[rank_id].abs());
|
|
} // fall through
|
|
case 2: {
|
|
ranks_[rank_id].set_rel3(num_1s - ranks_[rank_id].abs());
|
|
} // fall through
|
|
case 3: {
|
|
ranks_[rank_id].set_rel4(num_1s - ranks_[rank_id].abs());
|
|
} // fall through
|
|
case 4: {
|
|
ranks_[rank_id].set_rel5(num_1s - ranks_[rank_id].abs());
|
|
} // fall through
|
|
case 5: {
|
|
ranks_[rank_id].set_rel6(num_1s - ranks_[rank_id].abs());
|
|
} // fall through
|
|
case 6: {
|
|
ranks_[rank_id].set_rel7(num_1s - ranks_[rank_id].abs());
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
size_ = bv.size();
|
|
num_1s_ = bv.num_1s();
|
|
|
|
ranks_.back().set_abs(num_1s);
|
|
if (enables_select0) {
|
|
select0s_.push_back(static_cast<UInt32>(bv.size()));
|
|
select0s_.shrink();
|
|
}
|
|
if (enables_select1) {
|
|
select1s_.push_back(static_cast<UInt32>(bv.size()));
|
|
select1s_.shrink();
|
|
}
|
|
}
|
|
|
|
} // namespace vector
|
|
} // namespace grimoire
|
|
} // namespace marisa
|