Suffix Array + LCP Array + RMQ

前回の記事で、リファレンス文字列に加えてSuffix Arrayが与えられれば完全一致検索クエリを高速に処理できることを紹介しました。そのコードの最悪時間計算量はO(MlogN)(ただしNはリファレンス文字列の長さ、Mはクエリ文字列の長さ)でした。実はSuffix Arrayに加えてLongest Common Prefix(LCP) Arrayというデータを追加で持っておくことで、完全一致検索クエリを最悪時間計算量O(M+logN)で処理できます。この記事ではLCP Arrayを使って完全一致検索する具体的な方法を紹介します。ただしちなみに、前回の記事の方法でもリファレンス・クエリ文字列がともにランダムならば平均時間計算量はO(M+logN)です。実際のデータの性質によっては計算時間も前回の記事の方法のほうが速い場合があります。また、実際にはFM-indexなどのより新しい方法がよく使われます。

LCP arrayとは

短く言うと、Suffix Arrayで隣接するSuffix間の共通するPrefixの長さです。詳しい説明はネット上にいっぱいあるので省きます。"banana$"を例に取ると、下記の通り{0,1,3,0,0,2}となります。

開始位置 suffix 次のsuffixと共通するprefixの最大長
6 $ 0
5 a$ 1
3 ana$ 3
1 anana$ 0
0 banana$ 0
4 na$ 2
2 nana$

LCP Arrayは愚直に求めようとすると最悪時間計算量O(N^{2})かかります。具体的には例えば同じ文字がひたすら並んでいるケースでそうなります。Kasaiの方法を使うと最悪時間計算量O(N)で構築できます。詳しい説明は省きますがコード例は下記の通りです。

std::vector<int>BruteForceLCPArrayConstruction(
	const std::vector<uint8_t>& input,
	const std::vector<int>& suffix_array) {

	for (int i = 0; i < input.size() - 1; ++i)assert(input[i] != 0);
	assert(input.back() == 0);
	assert(input.size() == suffix_array.size());

	std::vector<int>lcp_array(input.size() - 1, 0);
	for (int i = 0; i < input.size() - 1; ++i) {
		while (
			input[suffix_array[i] + lcp_array[i]] ==
			input[suffix_array[i + 1] + lcp_array[i]]) {
			lcp_array[i]++;
		}
	}

	return lcp_array;
}

std::vector<int>KasaiLCPArrayConstruction(
	const std::vector<uint8_t>& input,
	const std::vector<int>& suffix_array) {

	for (int i = 0; i < input.size() - 1; ++i)assert(input[i] != 0);
	assert(input.back() == 0);
	assert(input.size() == suffix_array.size());

	std::vector<int>inverse_suffix_array(suffix_array.size(), 0);
	for (int i = 0; i < inverse_suffix_array.size(); ++i) {
		inverse_suffix_array[suffix_array[i]] = i;
	}

	std::vector<int>lcp_array(input.size() - 1, 0);
	int lcp = 0;
	for (int i = 0; i < input.size() - 1; ++i) {
		if (inverse_suffix_array[i] == input.size() - 1)continue;
		const int pos1 = suffix_array[inverse_suffix_array[i]];
		const int pos2 = suffix_array[inverse_suffix_array[i] + 1];
		while (input[pos1 + lcp] == input[pos2 + lcp])lcp++;
		lcp_array[inverse_suffix_array[i]] = lcp;
		if (lcp > 0)--lcp;
	}
	return lcp_array;
}
RMQについて

LCPを利用して最悪時間計算量O(M+logN)で完全一致検索を行うためには、任意のSuffix間のLCPをO(1)で求める必要があります。ここで任意の(i,j)について、辞書順でi番目とj番目のSuffix間のLCPはmin_{x \in [i,j)}lcp\_array(x)であることを利用します。するともし数列の任意の区間内の最小値をO(1)で求められるのならば、LCP Arrayの区間[i,j)内の最小値をその方法で求めれば、その値が辞書順でi番目とj番目のSuffix間のLCPとなります。実は、数列の任意の区間内の最小値を求めるクエリはRange Minimum Query(RMQ)という名前で知られており、O(1)で処理できるデータ構造が存在します。

構築 更新 RMQ
ナイーブ O(N) O(1) O(N)
Segment Tree O(N) O(logN) O(logN)
Sparse Table O(NlogN) O(N) O(1)

Segment TreeとSparse Tableは蟻本に載っていて、ネット上にも色々情報があります。ちなみに、O(N)で構築できるSparse Tableもあります。
qiita.com

コード例は以下の通りです。

template<typename T>class BruteForceMin {
private:
	std::vector<T>v;

public:

	BruteForceMin(const std::vector<T>& input) {
		v = input;
	}
	BruteForceMin(const int64_t N) {
		v = std::vector<T>(N, std::numeric_limits<T>::max());
	}
	BruteForceMin() {
		BruteForceMin(1);
	}

	//v[index]=numberとする。
	void update(const int64_t index, const T number) {
		v[index] = number;
	}

	//[index]の値を返す。
	T getnum(const int64_t index) {
		return v[index];
	}

	//[L,R)の最小値を求める。
	int getmin(const int64_t L, const int64_t R) {
		T answer = std::numeric_limits<T>::max();
		for (int64_t i = L; i < R; ++i)if (v[i] < answer)answer = v[i];
		return answer;
	}
};
template<typename T>class SegTreeMin {
private:

	int64_t SIZE;
	std::vector<T>v;

	T getmin_inner(
		const int64_t queryL, const int64_t queryR, const int64_t index,
		const int64_t segL, const int64_t segR) {
		if (queryR <= segL || segR <= queryL)return std::numeric_limits<T>::max();
		if (queryL <= segL && segR <= queryR)return v[index];
		return std::min(
			getmin_inner(queryL, queryR, index * 2,
				segL, (segL + segR) / 2),
			getmin_inner(queryL, queryR, index * 2 + 1,
				(segL + segR) / 2, segR));
	}

public:

	SegTreeMin(const std::vector<T>& input) {
		SIZE = roundup_pow2(input.size());
		v = std::vector<T>(SIZE * 2, std::numeric_limits<T>::max());
		for (int64_t i = 0; i < input.size(); ++i)v[i + SIZE] = input[i];
		for (int64_t i = SIZE - 1; i; --i)v[i] = std::min(v[i * 2], v[i * 2 + 1]);
	}

	SegTreeMin(const int64_t N) {
		SegTreeMin(std::vector<T>(N, std::numeric_limits<T>::max()));
	}

	SegTreeMin() {
		SegTreeMin(1);
	}

	//v[index]=numberとする。
	void update(const int64_t index, const T number) {
		v[index + SIZE] = number;
		for (int64_t i = (index + SIZE) / 2; i; i = i / 2) {
			v[i] = std::min(v[i * 2], v[i * 2 + 1]);
		}
	}

	//[index]の値を返す。
	T getnum(const int64_t index) {
		return v[index + SIZE];
	}

	//[L,R)の最小値を返す。
	T getmin(const int64_t L, const int64_t R) {
		return getmin_inner(L, R, 1, 0, SIZE);
	}

	//x以上であるような2のべき乗数のうち最小のものを返す。
	static uint64_t roundup_pow2(uint64_t x) {
		x--;
		for(int i = 0; i < 6; ++i)x |= x >> (1 << i);
		return x + 1;
	}
};
template<typename T>class SparseTableMin{
private:

	int64_t rank;
	std::vector<std::vector<T>>v;

public:

	SparseTableMin(const std::vector<T>& input) {
		rank = std::max(int64_t(1), log2_ceiling(input.size()));
		v = std::vector<std::vector<T>>(rank,
			std::vector<T>(input.size(), std::numeric_limits<T>::max()));
		for (int64_t i = 0; i < input.size(); ++i)v[0][i] = input[i];
		for (int64_t x = 1; x < rank; ++x) {
			for (int64_t i = 0; i < input.size(); ++i) {
				const int64_t pos = i + (1 << (x - 1));
				if (pos >= input.size())v[x][i] = v[x - 1][i];
				else v[x][i] = std::min(v[x - 1][i], v[x - 1][pos]);
			}
		}
	}

	SparseTableMin(const int64_t N) {
		SparseTableMin(std::vector<T>(N, std::numeric_limits<T>::max()));
	}
	SparseTableMin() {
		SparseTableMin(1);
	}

	//v[index]=numberとする。O(N)
	void update(const int64_t index, const T number) {
		v[0][index] = number;
		for (int64_t x = 1; x < rank; ++x) {
			for (int64_t i = std::max(int64_t(0), index - (1 << x) + 1);
				i <= index; ++i) {
				const int64_t pos = i + (1 << (x - 1));
				if (pos >= v[0].size())v[x][i] = v[x - 1][i];
				else v[x][i] = std::min(v[x - 1][i], v[x - 1][pos]);
			}
		}
	}

	T getnum(const int64_t index) {
		return v[0][index];
	}

	//[L,R)の最小値を返す。log2_ceilingがO(1)だとみなすとO(1)
	T getmin(const int64_t L, const int64_t R) {
		if (L + 1 == R)return v[0][L];
		if (R <= L)return std::numeric_limits<T>::max();
		const int64_t query_length = R - L;
		const int64_t index_rank = log2_ceiling(query_length) - 1;
 		return std::min(v[index_rank][L], v[index_rank][R - (1 << index_rank)]);
	}

	//x以上の2べき乗数のうち最小の数が2^kとしてkを返す。
	static int64_t log2_ceiling(uint64_t x) {

		if (x <= 0)return 0;
		x--;

		//msbの位置を求めるために、ビットを下位に伝播させてpopcntする。

		for (int i = 0; i < 6; ++i)x |= x >> (uint64_t(1) << i);
		x = (x & 0x5555555555555555ULL) + ((x & 0xAAAAAAAAAAAAAAAAULL) >> 1);
		x = (x & 0x3333333333333333ULL) + ((x & 0xCCCCCCCCCCCCCCCCULL) >> 2);
		x = (x & 0x0F0F0F0F0F0F0F0FULL) + ((x & 0xF0F0F0F0F0F0F0F0ULL) >> 4);
		x *= 0x0101010101010101ULL;
		return x >> 56;
	}
};
template<typename T>class AlstrupSparceTableMin {
private:

	int64_t block_size;
	SparseTableMin<T> macro_array;
	std::vector<T>v;
	std::vector<std::vector<uint64_t>>blocks;

	T getmin_micro(const int64_t index, const int64_t L, const int64_t R) {
		const uint64_t w = blocks[index][R] & ~((uint64_t(1) << L) - uint64_t(1));
		const int64_t lsb =
			SparseTableMin<T>::log2_ceiling(w & uint64_t(-int64_t(w)));
		return v[index * block_size + ((w == 0) ? R : lsb)];
	}

public:

	AlstrupSparceTableMin(const std::vector<T>& input) {
		v = input;
		block_size = std::max(int64_t(1),
			SparseTableMin<T>::log2_ceiling(input.size()) / 2);
		assert(block_size <= 63);
		const int64_t block_num = (input.size() + block_size - 1) / block_size;
		{
			std::vector<int64_t>M(block_num, std::numeric_limits<T>::max());
			for (int64_t i = 0; i < input.size(); ++i) {
				const int64_t index = i / block_size;
				M[index] = std::min(M[index], int64_t(input[i]));
			}
			macro_array = SparseTableMin<T>(M);
		}
		blocks = std::vector<std::vector<uint64_t>>(block_num,
			std::vector<uint64_t>(block_size, 0));
		for (int64_t x = 0; x < blocks.size(); ++x) {
			std::vector<T>B(block_size, std::numeric_limits<T>::max());
			const int64_t offset = x * block_size;
			for (int64_t i = 0;
				offset + i < input.size() && i < block_size; ++i) {
				B[i] = input[offset + i];
			}
			std::vector<int64_t>g(block_size, -1);
			std::stack<int64_t>g_stack;
			for (int64_t i = 0; i < block_size; ++i) {
				while (!g_stack.empty() && B[i] <= B[g_stack.top()]) {
					g_stack.pop();
				}
				g[i] = g_stack.empty() ? -1 : g_stack.top();
				g_stack.push(i);
			}
			for (int64_t i = 1; i < block_size; ++i) {
				blocks[x][i] = g[i] == -1 ? 0 :
					(blocks[x][g[i]] |
					(uint64_t(1) << (uint64_t(g[i]))));
			}
		}
	}

	AlstrupSparceTableMin(const int64_t N) {
		AlstrupSparceTableMin(std::vector<T>(N, std::numeric_limits<T>::max()));
	}

	AlstrupSparceTableMin() {
		AlstrupSparceTableMin(1);
	}

	//v[index]=numberとする。O(N)
	void update(const int64_t index, const T number) {

		v[index] = number;

		const int64_t block_index = index / block_size;
		int64_t macro_min = std::numeric_limits<T>::max();
		const int64_t offset = block_index * block_size;
		for (int64_t i = 0; i < block_size && offset + i < v.size(); ++i) {
			macro_min = std::min(macro_min, v[offset + i]);
		}
		macro_array.update(block_index, macro_min);

		std::vector<T>B(block_size, std::numeric_limits<T>::max());
		for (int64_t i = 0; offset + i < v.size() && i < block_size; ++i) {
			B[i] = v[offset + i];
		}
		std::vector<int64_t>g(block_size, -1);
		std::stack<int64_t>g_stack;
		for (int64_t i = 0; i < block_size; ++i) {
			while (!g_stack.empty() && B[i] <= B[g_stack.top()])g_stack.pop();
			g[i] = g_stack.empty() ? -1 : g_stack.top();
			g_stack.push(i);
		}

		for (int64_t i = 1; i < block_size; ++i) {
			blocks[block_index][i] = g[i] == -1 ? 0 :
				(blocks[block_index][g[i]] |
				(uint64_t(1) << (uint64_t(g[i]))));
		}
	}

	T getnum(const int64_t index) {
		return v[index];
	}

	//[L,R)の最小値を返す。log2_ceilingがO(1)だとみなすとO(1)
	int getmin(const int64_t L, const int64_t R) {
		if (R <= L)return std::numeric_limits<T>::max();
		const int64_t L_index = L / block_size;
		const int64_t R_index = (R - 1) / block_size;
		if (L_index == R_index) {
			return getmin_micro(L_index, L % block_size, (R - 1) % block_size);
		}
		T answer = macro_array.getmin(L_index + 1, R_index);
		answer = std::min(
			answer, getmin_micro(L_index, L % block_size, block_size - 1));
		answer = std::min(
			answer, getmin_micro(R_index, 0, (R - 1) % block_size));
		return answer;
	}
};

msbとpopcountはAMD64に専用命令がありますが、今回は簡単のため専用命令を使わずに実装しました。(log2_ceiling)

LCP Arrayを利用した完全一致検索

完全一致検索を実際に行うコード例はネット上にありそうであまり無いので以下に記しておきます。以下のコードでは簡単のためSuffix ArrayとLCP Arrayを関数内で都度構築するように書かれていますが、実際に運用したい人はそれらを引数として受け取るように適宜書き換えて下さい。とはいえ完全一致検索を実際に必要とする人はFM-indexなどを使うほうが良い場合が多いと思います。

std::vector<int> SuffixArrayLCPMatch(
	const std::vector<uint8_t>& ref,
	const std::vector<uint8_t>& query) {

	for (const auto x : ref)assert(x != 0);
	for (const auto x : query)assert(x != 0);

	auto ref_ = ref;
	ref_.push_back(0);
	const auto suffix_array = SAISSuffixArrayConstruction(ref_);
	SparseTableMin<int>lcp_array(KasaiLCPArrayConstruction(ref_, suffix_array));

	const auto compare = [&](const int ref_index, const int lcp_lower_bound) {
		//queryと、ref_[k]以降を比較する。
		//ref_[k]以降がquery全体をprefixとして含む場合、queryのほうが若いと判定する。
		//queryには終端文字が付いていないが、あたかも終端文字よりも若い"超終端文字"が付いているかのように判定すると言ってもよい。

		//比較の処理において、先頭の lcp_lower_bound 文字は一致すると仮定する。
		//返り値は、実際に一致した文字数とどちらが先か。

		int i = lcp_lower_bound;

		//ref_.back()は終端文字であり、queryに終端文字は無いので、2つ目の条件は不要である。
		while (i < query.size() &&
			/*ref_index + i < ref_.size() &&*/
			query[i] == ref_[ref_index + i])++i;
		const bool query_is_front = (i == query.size()) || query[i] < ref_[ref_index + i];

		return std::make_pair(i, query_is_front);
	};

	//suffix arrayの最後よりもクエリ文字列のほうが辞書順で後なら、完全一致は存在しない。
	//終端文字はref_にのみ存在し、他のどの文字よりも若いのでそれが言える。
	if (compare(suffix_array.back(), 0).second == false) {
		return std::vector<int>{};
	}

	//suffix array[0]は終端文字列のみの文字列であるから、クエリ文字列よりも辞書順で先である。
	//ゆえに、この時点で、クエリ文字列は辞書順でsuffix arrayのどこかの間に挿入されるはずである。
	//その挿入位置を二分探索で特定する。
	int L = 0, R = suffix_array.size() - 1;
	//これ以降、Lよりも後ろかつRよりも前に挿入されることが判明している。

	//まずはじめに、クエリ文字列と真ん中のsuffixを普通に辞書順比較して、LCPの値と前後関係を求める。
	const int first_mid = (L + R) / 2;
	const auto first_comparison = compare(suffix_array[first_mid], 0);
	int query_ref_lcp = first_comparison.first;
	bool lcp_is_R = first_comparison.second;

	//クエリ文字列が先なら挿入位置は[L,mid]のどこかに絞られる。さもなくば[mid,R]に絞られる。
	if (first_comparison.second)R = first_mid;
	else L = first_mid;

	//以降、「クエリ文字列と片端のsuffixとのLCP」が求まっている状態を保ちつつ二分探索する。
	while (L + 1 < R) {
		const int mid = (L + R) / 2;
		if (lcp_is_R) {
			//この時点で、クエリ文字列とR番目のsuffixとのLCPが求まっていて、query_ref_lcpに格納されている。
			
			//mid番目のsuffixとR番目のsuffixとのLCPを求める。
			//このRMQの実装では半開区間[L,R)の最小値を返す。lcp_array[i]はiとi+1とのLCPなので、第二引数はR+1ではなくRでよい。
			const int mid_R_lcp = lcp_array.getmin(mid, R);

			if (mid_R_lcp < query_ref_lcp) {
				//クエリ文字列とR番目のsuffixが長く一致していて、mid番目のsuffixとR番目のsuffixとが短くしか一致しないのなら、
				//挿入位置は[mid,R]に絞られる。
				//また、[L,mid]内の全suffixに対してクエリは絶対に完全一致しない。
				//なぜなら、LCPがRMQで求まる性質より、任意の正数dに対してLCP(mid,R)>=LCP(mid-d,R)だからである。
				L = mid;
			}
			else if (mid_R_lcp > query_ref_lcp) {
				//クエリ文字列とR番目のsuffixとが短くしか一致せず、mid番目のsuffixとR番目のsuffixとが長く一致するのなら、
				//挿入位置は[L,mid]に絞られる。
				//また、[mid,R]内の全suffixに対してクエリは絶対に完全一致しない。
				//なぜなら、[mid,R]内の全suffixは、query_ref_lcp番目で不一致するという点で共通するからである。
				R = mid;
			}
			else {
				//クエリ文字列とR番目のsuffixとの一致長と、mid番目のsuffixとR番目のsuffixとの一致長が同じなら、
				//クエリ文字列とmid番目のsuffixとを普通に辞書順比較して、LCPの値と前後関係を求める。
				//このとき求まるLCPの値は、R番目のsuffixとのLCPの値以上であることが確定している。
				//ゆえに文字比較を省略でき、最悪時間計算量が削減される。
				const auto comparison = compare(suffix_array[mid], query_ref_lcp);
				query_ref_lcp = comparison.first;
				lcp_is_R = comparison.second;
				if (comparison.second)R = mid;
				else L = mid;
			}
		}
		else {
			//この時点で、クエリ文字列とL番目のsuffixとのLCPが求まっていて、query_ref_lcpに格納されている。
			
			//L番目のsuffixとmid番目のsuffixとのLCPを求める。
			//このRMQの実装では半開区間[L,R)の最小値を返す。lcp_array[i]はiとi+1とのLCPなので、第二引数はmid+1ではなくmidでよい。
			const int L_mid_lcp = lcp_array.getmin(L, mid);

			if (L_mid_lcp < query_ref_lcp) {
				//クエリ文字列とL番目のsuffixとが長く一致していて、L番目のsuffixとmid番目のsuffixとが短くしか一致しないのなら、
				//挿入位置は[L,mid]に絞られる。
				//また、[mid,R]内の全suffixに対してクエリは絶対に完全一致しない。
				//なぜなら、LCPがRMQで求まる性質より、任意の正数dに対してLCP(L,mid)>=LCP(L,mid+d)だからである。
				R = mid;
			}
			else if (L_mid_lcp > query_ref_lcp) {
				//クエリ文字列とL番目のsuffixとが短くしか一致せず、L番目のsuffixとmid番目のsuffixとが長く一致するのなら、
				//挿入位置は[mid,R]に絞られる。
				//また、[L,mid]内の全suffixに対してクエリは絶対に完全一致しない。
				//なぜなら、[L,mid]内の全suffixは、query_ref_lcp番目で不一致するという点で共通するからである。
				L = mid;
			}
			else {
				//クエリ文字列とL番目のsuffixとの一致長と、L番目のsuffixとmid番目のsuffixとの一致長が同じなら、
				//クエリ文字列とmid番目のsuffixとを普通に辞書順比較して、LCPの値と前後関係を求める。
				//このとき求まるLCPの値は、L番目のsuffixとのLCPの値以上であることが確定している。
				//ゆえに文字比較を省略でき、最悪時間計算量が削減される。
				const auto comparison = compare(suffix_array[mid], query_ref_lcp);
				query_ref_lcp = comparison.first;
				lcp_is_R = comparison.second;
				if (comparison.second)R = mid;
				else L = mid;
			}
		}
	}

	//この時点でL+1==Rであり、クエリ文字列は辞書順でR番目のsuffixの直前に位置する。
	//終端文字はref_にのみ存在し、他のどの文字よりも若いので、完全一致するとすればR番目以降である。
	const auto final_comparison = compare(suffix_array[R], 0);

	//クエリ文字列がR番目のsuffixと完全一致しないなら、他のどのsuffixともしないので終了。
	if(final_comparison.first < query.size())return std::vector<int>{};

	std::vector<int>match_start_pos{suffix_array[R]};

	//pos番目のsuffixと完全一致するとき、
	//pos+1番目のsuffixとも完全一致する⇔pos番目のsuffixとpos+1番目のsuffixとのLCPがクエリ文字列長以上である。
	for (int pos = R; pos < suffix_array.size() - 1; pos++) {
		if (lcp_array.getnum(pos) < query.size())break;
		match_start_pos.push_back(suffix_array[pos + 1]);
	}

	sort(match_start_pos.begin(), match_start_pos.end());
	return match_start_pos;
}