Wavelet Matrix + BWT + FM-index
以前の記事で何度か触れましたが、FM-indexという方法を使えば完全一致検索を高速に処理できます。他の良い解説の紹介とコード例をあげておきます。
構築時間 | 検索時間(平均) | 検索時間(最悪) | |||
---|---|---|---|---|---|
(1) | Suffix Array | O(N) (SAIS) | O(M+logN) | O(MlogN) | |
(2) | (1)+LCP Array+RMQ | O(N) (Kasaiの方法) | O(M+logN) | O(M+logN) | |
(3) | (1)+Compression | O(N)† | O(M+logN) | O(MlogN) | (†時間は定数倍大きく、空間は定数倍小さい) |
(4) | (2)+(3) | O(N)† | O(M+logN) | O(M+logN) | |
(5) | (4)+BWT+Wavelet Matrix+FM-index | O(N)† | O(MlogC) | O(MlogC) |
Nはリファレンス文字列の長さ、Mはクエリ文字列の長さ、Cは文字の種類
Wavelet Matrix(ウェーブレット行列)
詳しくは説明しませんが、前回の記事で紹介したSuccinct Bit Vectorを応用して、文字列に対してrank操作を可能にするものです。 ウェーブレット木の効率的で簡単な実装 "The Wavelet Matrix" - EchizenBlog-Zwei によると、元論文は The Wavelet Matrix です。私が知る限り最もわかりやすい解説ブログ記事は 中学生にもわかるウェーブレット行列 - アスペ日記です。岡野原本(高速文字列解析の世界――データ圧縮・全文検索・テキストマイニング (確率と情報の科学))にはp74で説明されていますが、ウェーブレット行列の構築とrank操作の具体的な実装方法は書かれていません。
既存の高速なライブラリとしては GitHub - simongog/sdsl-lite: Succinct Data Structure Library 2.0 GitHub - herumi/cybozulib: a tiny library for C++などがあります。
以下はコード例です。
class BruteForceWaveletMatrix { private: std::vector<uint8_t>v; public: BruteForceWaveletMatrix(const std::vector<uint8_t>& input) { v = input; } int64_t rank(const int64_t index, const uint8_t character) { //リファレンス文字列の[0,index)のうち、characterが何回出現したか返す。 int64_t answer = 0; for (int i = 0; i < index; ++i)if (v[i] == character)++answer; return answer; } }; class WaveletMatrix { private: std::vector<BitVector>bv; std::vector<int64_t>acc; int64_t size; //x以上の2べき乗数のうち最小の数が2^kとしてkを返す。 static int64_t log2_ceiling(uint64_t x) { if (x <= 0)return 0; x--; for (int i = 0; i < 6; ++i)x |= x >> (uint64_t(1) << i); x = (x & 0x5555555555555555ULL) + ((x & 0xAAAAAAAAAAAAAAAAULL) >> 1); x = (x & 0x3333333333333333ULL) + ((x & 0xCCCCCCCCCCCCCCCCULL) >> 2); x = (x & 0x0F0F0F0F0F0F0F0FULL) + ((x & 0xF0F0F0F0F0F0F0F0ULL) >> 4); x *= 0x0101010101010101ULL; return x >> 56; } public: WaveletMatrix(const std::vector<uint8_t>& input, const int max_char) { //inputは入力文字列で、max_charは入力文字列が含みうる文字の最大値。 //実際に含んでいる文字の最大値ではない。これを用意する理由は高速化で、 //具体的にはrank操作においてmax_charを超える引数を受け取らないと仮定する。 assert(1 <= max_char && max_char <= 255); for (const auto x : input)assert(x <= max_char); size = input.size(); //max_charのビット数ぶんBitVectorを用意する。 const int64_t bits = log2_ceiling(max_char + 1); bv = std::vector<BitVector>(bits, BitVector()); //BitVectorの最上位は、inputの最上位ビットの列を格納する。 std::vector<uint64_t>bv_tmp((input.size() + 63) / 64, 0); const auto SetBit1 = [&](const int i) {bv_tmp[i / 64] |= uint64_t(1) << (i % 64); }; for (int i = 0; i < input.size(); ++i) { if (input[i] & (1 << (bits - 1)))SetBit1(i); } bv[bits - 1]= BitVector(bv_tmp); auto input_tmp = input; //BitVectorの最上位以外は、input_tmp配列を(bit_pos+1)番目のビットで安定ソートした後における、 //bit_pos番目のビットの列を格納する。cf.高速文字列解析の世界p74 for (int bit_pos = bits - 2; bit_pos >= 0; --bit_pos) { int count = 0; std::vector<uint8_t>input_tmp2(input.size(), 0); for (int i = 0; i < input.size(); ++i) { if ((input_tmp[i] & (1 << (bit_pos + 1))) == 0)input_tmp2[count++] = input_tmp[i]; } for (int i = 0; i < input.size(); ++i) { if ((input_tmp[i] & (1 << (bit_pos + 1))) != 0)input_tmp2[count++] = input_tmp[i]; } for (int i = 0; i < bv_tmp.size(); ++i)bv_tmp[i] = 0; for (int i = 0; i < input.size(); ++i) { if (input_tmp2[i] & (1 << bit_pos))SetBit1(i); } bv[bit_pos] = BitVector(bv_tmp); input_tmp = input_tmp2; } { int count = 0; std::vector<uint8_t>input_tmp2(input.size(), 0); for (int i = 0; i < input.size(); ++i) { if ((input_tmp[i] & 1) == 0)input_tmp2[count++] = input_tmp[i]; } for (int i = 0; i < input.size(); ++i) { if ((input_tmp[i] & 1) != 0)input_tmp2[count++] = input_tmp[i]; } //acc[i]にはinput_tmp2中でiが最初に出現した位置を格納する。出現しない場合は-1とする。 acc = std::vector<int64_t>(max_char + 1, -1); for (int i = 0; i < input_tmp2.size(); ++i)if (acc[input_tmp2[i]] == -1)acc[input_tmp2[i]] = i; } } int64_t rank(const int64_t index, const uint8_t character) { //リファレンス文字列の[0,index)のうち、characterが何回出現したか返す。 if (acc[character] == -1)return 0; //リファレンス文字列の[0,index)のうち、最上位ビットがcharacterに等しいものを数える。 uint64_t tmp_rank = (character & (1 << (bv.size() - 1))) ? bv[bv.size() - 1].rank1(index) : bv[bv.size() - 1].rank0(index); //最上位-1ビットから下位ビットに向けて処理していく。 for (int bit_pos = bv.size() - 2; bit_pos >= 0; --bit_pos) { if (character & (1 << (bit_pos + 1))) { const uint64_t start_pos = size - bv[bit_pos + 1].sum1(); tmp_rank = (character & (1 << bit_pos)) ? bv[bit_pos].rank1(start_pos + tmp_rank) : bv[bit_pos].rank0(start_pos + tmp_rank); } else { tmp_rank = (character & (1 << bit_pos)) ? bv[bit_pos].rank1(tmp_rank) : bv[bit_pos].rank0(tmp_rank); } } if (character & 1)tmp_rank += size - bv[0].sum1(); return tmp_rank - acc[character]; } };
BWT
Suffix Arrayのsuffixたちの1個前の文字を集めたものです。
std::vector<uint8_t>BWT(const std::vector<uint8_t>& input) { for (int i = 0; i < input.size() - 1; ++i)assert(input[i] != 0); assert(input.back() == 0); const auto suffix_array = SAISSuffixArrayConstruction(input); std::vector<uint8_t>bwt(suffix_array.size(), 0); for (int i = 0; i < suffix_array.size(); ++i) { if (suffix_array[i] == 0)bwt[i] = input.back(); else bwt[i] = input[suffix_array[i] - 1]; } return bwt; }
FM-index
FM-indexの解説は岡野原本の7.5.2節が鉄板なので読みましょう。ライブラリは先ほどと同じくSDSLなどが鉄板として知られています。
以下のコードは簡単のため、queryが来るたびにFM-indexを都度構築するように書かれています。実際には最初に一度だけ構築すれば十分です。
std::vector<int> FMIndexMatch( const std::vector<uint8_t>& ref, const std::vector<uint8_t>& query) { for (const auto x : ref)assert(x != 0); for (const auto x : query)assert(x != 0); auto ref_ = ref; ref_.push_back(0); const auto suffix_array = SAISSuffixArrayConstruction(ref_); const auto bwt = BWT(ref_); //BruteForceWaveletMatrix wavelet_matrix(bwt); WaveletMatrix wavelet_matrix(bwt, 255); //bwt中でのiより小さい文字の出現回数を求めてacc_num[i]に格納する。 std::vector<int64_t>num(256, 0), acc_num(256, 0); for (const auto x : bwt)++num[x]; for (int i = 1; i < 256; ++i)acc_num[i] = num[i - 1] + acc_num[i - 1]; //ここまで構築で、ここから検索 int64_t start_pos = 0, end_pos = ref_.size(); for (int i = query.size() - 1; i >= 0; --i) { const auto c = query[i]; start_pos = acc_num[c] + wavelet_matrix.rank(start_pos, c); end_pos = acc_num[c] + wavelet_matrix.rank(end_pos, c); if (start_pos >= end_pos)return std::vector<int>{}; } std::vector<int>match_start_pos; for (int64_t i = start_pos; i < end_pos; ++i)match_start_pos.push_back(suffix_array[i]); sort(match_start_pos.begin(), match_start_pos.end()); return match_start_pos; }