Wavelet Matrix + BWT + FM-index

以前の記事で何度か触れましたが、FM-indexという方法を使えば完全一致検索を高速に処理できます。他の良い解説の紹介とコード例をあげておきます。

構築時間 検索時間(平均) 検索時間(最悪)
(1) Suffix Array O(N) (SAIS) O(M+logN) O(MlogN)
(2) (1)+LCP Array+RMQ O(N) (Kasaiの方法) O(M+logN) O(M+logN)
(3) (1)+Compression O(N)† O(M+logN) O(MlogN)  (†時間は定数倍大きく、空間は定数倍小さい)
(4) (2)+(3) O(N)† O(M+logN) O(M+logN)
(5) (4)+BWT+Wavelet Matrix+FM-index O(N)† O(MlogC) O(MlogC)

Nはリファレンス文字列の長さ、Mはクエリ文字列の長さ、Cは文字の種類

Wavelet Matrix(ウェーブレット行列)

詳しくは説明しませんが、前回の記事で紹介したSuccinct Bit Vectorを応用して、文字列に対してrank操作を可能にするものです。 ウェーブレット木の効率的で簡単な実装 "The Wavelet Matrix" - EchizenBlog-Zwei によると、元論文は The Wavelet Matrix です。私が知る限り最もわかりやすい解説ブログ記事は 中学生にもわかるウェーブレット行列 - アスペ日記です。岡野原本(高速文字列解析の世界――データ圧縮・全文検索・テキストマイニング (確率と情報の科学))にはp74で説明されていますが、ウェーブレット行列の構築とrank操作の具体的な実装方法は書かれていません。

既存の高速なライブラリとしては GitHub - simongog/sdsl-lite: Succinct Data Structure Library 2.0 GitHub - herumi/cybozulib: a tiny library for C++などがあります。

以下はコード例です。

class BruteForceWaveletMatrix {
private:

	std::vector<uint8_t>v;

public:

	BruteForceWaveletMatrix(const std::vector<uint8_t>& input) {
		v = input;
	}

	int64_t rank(const int64_t index, const uint8_t character) {
		//リファレンス文字列の[0,index)のうち、characterが何回出現したか返す。
		int64_t answer = 0;
		for (int i = 0; i < index; ++i)if (v[i] == character)++answer;
		return answer;
	}
};
class WaveletMatrix {
private:

	std::vector<BitVector>bv;
	std::vector<int64_t>acc;
	int64_t size;

	//x以上の2べき乗数のうち最小の数が2^kとしてkを返す。
	static int64_t log2_ceiling(uint64_t x) {
		if (x <= 0)return 0;
		x--;
		for (int i = 0; i < 6; ++i)x |= x >> (uint64_t(1) << i);
		x = (x & 0x5555555555555555ULL) + ((x & 0xAAAAAAAAAAAAAAAAULL) >> 1);
		x = (x & 0x3333333333333333ULL) + ((x & 0xCCCCCCCCCCCCCCCCULL) >> 2);
		x = (x & 0x0F0F0F0F0F0F0F0FULL) + ((x & 0xF0F0F0F0F0F0F0F0ULL) >> 4);
		x *= 0x0101010101010101ULL;
		return x >> 56;
	}

public:

	WaveletMatrix(const std::vector<uint8_t>& input, const int max_char) {
		//inputは入力文字列で、max_charは入力文字列が含みうる文字の最大値。
		//実際に含んでいる文字の最大値ではない。これを用意する理由は高速化で、
		//具体的にはrank操作においてmax_charを超える引数を受け取らないと仮定する。

		assert(1 <= max_char && max_char <= 255);
		for (const auto x : input)assert(x <= max_char);

		size = input.size();

		//max_charのビット数ぶんBitVectorを用意する。
		const int64_t bits = log2_ceiling(max_char + 1);
		bv = std::vector<BitVector>(bits, BitVector());

		//BitVectorの最上位は、inputの最上位ビットの列を格納する。
		std::vector<uint64_t>bv_tmp((input.size() + 63) / 64, 0);
		const auto SetBit1 = [&](const int i) {bv_tmp[i / 64] |= uint64_t(1) << (i % 64); };
		for (int i = 0; i < input.size(); ++i) {
			if (input[i] & (1 << (bits - 1)))SetBit1(i);
		}
		bv[bits - 1]= BitVector(bv_tmp);

		auto input_tmp = input;
		//BitVectorの最上位以外は、input_tmp配列を(bit_pos+1)番目のビットで安定ソートした後における、
		//bit_pos番目のビットの列を格納する。cf.高速文字列解析の世界p74
		for (int bit_pos = bits - 2; bit_pos >= 0; --bit_pos) {
			int count = 0;
			std::vector<uint8_t>input_tmp2(input.size(), 0);
			for (int i = 0; i < input.size(); ++i) {
				if ((input_tmp[i] & (1 << (bit_pos + 1))) == 0)input_tmp2[count++] = input_tmp[i];
			}
			for (int i = 0; i < input.size(); ++i) {
				if ((input_tmp[i] & (1 << (bit_pos + 1))) != 0)input_tmp2[count++] = input_tmp[i];
			}
			for (int i = 0; i < bv_tmp.size(); ++i)bv_tmp[i] = 0;
			for (int i = 0; i < input.size(); ++i) {
				if (input_tmp2[i] & (1 << bit_pos))SetBit1(i);
			}
			bv[bit_pos] = BitVector(bv_tmp);
			input_tmp = input_tmp2;
		}

		{
			int count = 0;
			std::vector<uint8_t>input_tmp2(input.size(), 0);
			for (int i = 0; i < input.size(); ++i) {
				if ((input_tmp[i] & 1) == 0)input_tmp2[count++] = input_tmp[i];
			}
			for (int i = 0; i < input.size(); ++i) {
				if ((input_tmp[i] & 1) != 0)input_tmp2[count++] = input_tmp[i];
			}
			//acc[i]にはinput_tmp2中でiが最初に出現した位置を格納する。出現しない場合は-1とする。
			acc = std::vector<int64_t>(max_char + 1, -1);
			for (int i = 0; i < input_tmp2.size(); ++i)if (acc[input_tmp2[i]] == -1)acc[input_tmp2[i]] = i;
		}
	}

	int64_t rank(const int64_t index, const uint8_t character) {
		//リファレンス文字列の[0,index)のうち、characterが何回出現したか返す。

		if (acc[character] == -1)return 0;

		//リファレンス文字列の[0,index)のうち、最上位ビットがcharacterに等しいものを数える。
		uint64_t tmp_rank =
			(character & (1 << (bv.size() - 1))) ?
			bv[bv.size() - 1].rank1(index) :
			bv[bv.size() - 1].rank0(index);

		//最上位-1ビットから下位ビットに向けて処理していく。
		for (int bit_pos = bv.size() - 2; bit_pos >= 0; --bit_pos) {
			if (character & (1 << (bit_pos + 1))) {
				const uint64_t start_pos = size - bv[bit_pos + 1].sum1();
				tmp_rank = (character & (1 << bit_pos)) ?
					bv[bit_pos].rank1(start_pos + tmp_rank) :
					bv[bit_pos].rank0(start_pos + tmp_rank);
			}
			else {
				tmp_rank = (character & (1 << bit_pos)) ?
					bv[bit_pos].rank1(tmp_rank) :
					bv[bit_pos].rank0(tmp_rank);
			}
		}
		if (character & 1)tmp_rank += size - bv[0].sum1();
		return tmp_rank - acc[character];
	}
};
BWT

Suffix Arrayのsuffixたちの1個前の文字を集めたものです。

std::vector<uint8_t>BWT(const std::vector<uint8_t>& input) {

	for (int i = 0; i < input.size() - 1; ++i)assert(input[i] != 0);
	assert(input.back() == 0);

	const auto suffix_array = SAISSuffixArrayConstruction(input);

	std::vector<uint8_t>bwt(suffix_array.size(), 0);
	for (int i = 0; i < suffix_array.size(); ++i) {
		if (suffix_array[i] == 0)bwt[i] = input.back();
		else bwt[i] = input[suffix_array[i] - 1];
	}
	return bwt;
}
FM-index

FM-indexの解説は岡野原本の7.5.2節が鉄板なので読みましょう。ライブラリは先ほどと同じくSDSLなどが鉄板として知られています。

以下のコードは簡単のため、queryが来るたびにFM-indexを都度構築するように書かれています。実際には最初に一度だけ構築すれば十分です。

std::vector<int> FMIndexMatch(
	const std::vector<uint8_t>& ref,
	const std::vector<uint8_t>& query) {

	for (const auto x : ref)assert(x != 0);
	for (const auto x : query)assert(x != 0);

	auto ref_ = ref;
	ref_.push_back(0);
	const auto suffix_array = SAISSuffixArrayConstruction(ref_);
	const auto bwt = BWT(ref_);
	//BruteForceWaveletMatrix wavelet_matrix(bwt);
	WaveletMatrix wavelet_matrix(bwt, 255);

	//bwt中でのiより小さい文字の出現回数を求めてacc_num[i]に格納する。
	std::vector<int64_t>num(256, 0), acc_num(256, 0);
	for (const auto x : bwt)++num[x];
	for (int i = 1; i < 256; ++i)acc_num[i] = num[i - 1] + acc_num[i - 1];

	//ここまで構築で、ここから検索

	int64_t start_pos = 0, end_pos = ref_.size();
	for (int i = query.size() - 1; i >= 0; --i) {
		const auto c = query[i];
		start_pos = acc_num[c] + wavelet_matrix.rank(start_pos, c);
		end_pos = acc_num[c] + wavelet_matrix.rank(end_pos, c);
		if (start_pos >= end_pos)return std::vector<int>{};
	}

	std::vector<int>match_start_pos;
	for (int64_t i = start_pos; i < end_pos; ++i)match_start_pos.push_back(suffix_array[i]);
	sort(match_start_pos.begin(), match_start_pos.end());
	return match_start_pos;
}