#include <algorithm>#include <cstdio>#include <cstring>#include <iostream>using namespace std;constexpr int N = 1000010;char s[N];int n, sa[N], rk[N << 1], oldrk[N << 1], id[N], cnt[N];int main() { int i, m, p, w; scanf("%s", s + 1); n = strlen(s + 1); m = 127; for (i = 1; i <= n; ++i) ++cnt[rk[i] = s[i]]; for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i; memcpy(oldrk + 1, rk + 1, n * sizeof(int)); for (p = 0, i = 1; i <= n; ++i) { if (oldrk[sa[i]] == oldrk[sa[i - 1]]) { rk[sa[i]] = p; } else { rk[sa[i]] = ++p; } } for (w = 1; w < n; w <<= 1, m = n) { // 对第二关键字:id[i] + w进行计数排序 memset(cnt, 0, sizeof(cnt)); memcpy(id + 1, sa + 1, n * sizeof(int)); // id保存一份儿sa的拷贝,实质上就相当于oldsa for (i = 1; i <= n; ++i) ++cnt[rk[id[i] + w]]; for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (i = n; i >= 1; --i) sa[cnt[rk[id[i] + w]]--] = id[i]; // 对第一关键字:id[i]进行计数排序 memset(cnt, 0, sizeof(cnt)); memcpy(id + 1, sa + 1, n * sizeof(int)); for (i = 1; i <= n; ++i) ++cnt[rk[id[i]]]; for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (i = n; i >= 1; --i) sa[cnt[rk[id[i]]]--] = id[i]; memcpy(oldrk + 1, rk + 1, n * sizeof(int)); for (p = 0, i = 1; i <= n; ++i) { if (oldrk[sa[i]] == oldrk[sa[i - 1]] && oldrk[sa[i] + w] == oldrk[sa[i - 1] + w]) { rk[sa[i]] = p; } else { rk[sa[i]] = ++p; } } } for (i = 1; i <= n; ++i) printf("%d ", sa[i]); return 0;}
思考一下第二关键字排序的实质,其实就是把超出字符串范围(即 sa[i]+w>n)的 sa[i] 放到 sa 数组头部,然后把剩下的依原顺序放入:
int cur = 0;for (int i = n - w + 1; i <= n; i++) id[++cur] = i;for (int i = 1; i <= n; i++) if (sa[i] > w) id[++cur] = sa[i] - w;
优化计数排序的值域
每次对 rk 进行更新之后,我们都计算了一个 p,这个 p 即是 rk 的值域,将值域改成它即可。
若排名都不相同可直接生成后缀数组
考虑新的 rk 数组,若其值域为 [1,n] 那么每个排名都不同,此时无需再排序。
实现
#include <algorithm>#include <cstdio>#include <cstring>#include <iostream>using namespace std;constexpr int N = 1000010;char s[N];int n;int m, p, rk[N * 2], oldrk[N], sa[N * 2], id[N], cnt[N];int main() { scanf("%s", s + 1); n = strlen(s + 1); m = 128; for (int i = 1; i <= n; i++) cnt[rk[i] = s[i]]++; for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1]; for (int i = n; i >= 1; i--) sa[cnt[rk[i]]--] = i; for (int w = 1;; w <<= 1, m = p) { // m = p 即为值域优化 int cur = 0; for (int i = n - w + 1; i <= n; i++) id[++cur] = i; for (int i = 1; i <= n; i++) if (sa[i] > w) id[++cur] = sa[i] - w; memset(cnt, 0, sizeof(cnt)); for (int i = 1; i <= n; i++) cnt[rk[i]]++; for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1]; for (int i = n; i >= 1; i--) sa[cnt[rk[id[i]]]--] = id[i]; p = 0; memcpy(oldrk, rk, sizeof(oldrk)); for (int i = 1; i <= n; i++) { if (oldrk[sa[i]] == oldrk[sa[i - 1]] && oldrk[sa[i] + w] == oldrk[sa[i - 1] + w]) rk[sa[i]] = p; else rk[sa[i]] = ++p; } if (p == n) break; // p = n 时无需再排序 } for (int i = 1; i <= n; i++) printf("%d ", sa[i]); return 0;}
任务是在线地在主串 T 中寻找模式串 S。在线的意思是,我们已经预先知道知道主串 T,但是当且仅当询问时才知道模式串 S。我们可以先构造出 T 的后缀数组,然后查找子串 S。若子串 S 在 T 中出现,它必定是 T 的一些后缀的前缀。因为我们已经将所有后缀排序了,我们可以通过在 p 数组中二分 S 来实现。比较子串 S 和当前后缀的时间复杂度为 O(∣S∣),因此找子串的时间复杂度为 O(∣S∣log∣T∣)。注意,如果该子串在 T 中出现了多次,每次出现都是在 p 数组中相邻的。因此出现次数可以通过再次二分找到,输出每次出现的位置也很轻松。
#include <cctype>#include <cstring>#include <iostream>using namespace std;constexpr int N = 1000010;char s[N];int n, sa[N], id[N], oldrk[N * 2], rk[N * 2], px[N], cnt[N];bool cmp(int x, int y, int w) { return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];}int main() { int i, w, m = 200, p, l = 1, r, tot = 0; cin >> n; r = n; for (i = 1; i <= n; ++i) while (cin >> s[i], !isalpha(s[i])); for (i = 1; i <= n; ++i) rk[i] = rk[2 * n + 2 - i] = s[i]; // 拼接正反两个字符串,中间空出一个字符 n = 2 * n + 1; // 求后缀数组 for (i = 1; i <= n; ++i) ++cnt[rk[i]]; for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i; for (w = 1; w < n; w *= 2, m = p) { // m=p 就是优化计数排序值域 for (p = 0, i = n; i > n - w; --i) id[++p] = i; for (i = 1; i <= n; ++i) if (sa[i] > w) id[++p] = sa[i] - w; memset(cnt, 0, sizeof(cnt)); for (i = 1; i <= n; ++i) ++cnt[px[i] = rk[id[i]]]; for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (i = n; i >= 1; --i) sa[cnt[px[i]]--] = id[i]; memcpy(oldrk, rk, sizeof(rk)); for (p = 0, i = 1; i <= n; ++i) rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p; } // 利用后缀数组O(1)进行判断 while (l <= r) { cout << (rk[l] < rk[n + 1 - r] ? s[l++] : s[r--]); if ((++tot) % 80 == 0) cout << '\n'; // 回车 } return 0;}
height 数组
LCP(最长公共前缀)
两个字符串 S 和 T 的 LCP 就是最大的 x(x≤min(∣S∣,∣T∣)) 使得 Si=Ti(∀1≤i≤x)。
下文中以 lcp(i,j) 表示后缀 i 和后缀 j 的最长公共前缀(的长度)。
height 数组的定义
height[i]=lcp(sa[i],sa[i−1]),即第 i 名的后缀与它前一名的后缀的最长公共前缀。
#include <cstring>#include <iostream>#include <set>using namespace std;constexpr int N = 40010;int n, k, a[N], sa[N], rk[N], oldrk[N], id[N], px[N], cnt[1000010], ht[N], ans;multiset<int> t; // multiset 是最好写的实现方式bool cmp(int x, int y, int w) { return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];}int main() { cin.tie(nullptr)->sync_with_stdio(false); int i, j, w, p, m = 1000000; cin >> n >> k; --k; for (i = 1; i <= n; ++i) cin >> a[i]; // 求后缀数组 for (i = 1; i <= n; ++i) ++cnt[rk[i] = a[i]]; for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i; for (w = 1; w < n; w <<= 1, m = p) { for (p = 0, i = n; i > n - w; --i) id[++p] = i; for (i = 1; i <= n; ++i) if (sa[i] > w) id[++p] = sa[i] - w; memset(cnt, 0, sizeof(cnt)); for (i = 1; i <= n; ++i) ++cnt[px[i] = rk[id[i]]]; for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (i = n; i >= 1; --i) sa[cnt[px[i]]--] = id[i]; memcpy(oldrk, rk, sizeof(rk)); for (p = 0, i = 1; i <= n; ++i) rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p; } for (i = 1, j = 0; i <= n; ++i) { // 求 height if (j) --j; while (a[i + j] == a[sa[rk[i] - 1] + j]) ++j; ht[rk[i]] = j; } for (i = 1; i <= n; ++i) { // 求所有最小值的最大值 t.insert(ht[i]); if (i > k) t.erase(t.find(ht[i - k])); ans = max(ans, *t.begin()); } cout << ans; return 0;}
某些题目求解时要求你将后缀数组划分成若干个连续 LCP 长度大于等于某一值的段,亦即将 h 数组划分成若干个连续最小值大于等于某一值的段并统计每一段的答案。如果有多次询问,我们可以将询问离线。观察到当给定值单调递减的时候,满足条件的区间个数总是越来越少,而新区间都是两个或多个原区间相连所得,且新区间中不包含在原区间内的部分的 h 值都为减少到的这个值。我们只需要维护一个并查集,每次合并相邻的两个区间,并维护统计信息即可。
#include <cstring>#include <iostream>#include <string>using namespace std;constexpr int N = 500010;string s;int n, sa[N], rk[N << 1], oldrk[N << 1], id[N], px[N], cnt[N], ht[N], sta[N], top, l[N];long long ans;bool cmp(int x, int y, int w) { return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w];}int main() { int i, k, w, p, m = 300; cin >> s; n = s.size(); s = " " + s; ans = 1ll * n * (n - 1) * (n + 1) / 2; // 求后缀数组 for (i = 1; i <= n; ++i) ++cnt[rk[i] = s[i]]; for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i; for (w = 1; w < n; w <<= 1, m = p) { for (p = 0, i = n; i > n - w; --i) id[++p] = i; for (i = 1; i <= n; ++i) if (sa[i] > w) id[++p] = sa[i] - w; memset(cnt, 0, sizeof(cnt)); for (i = 1; i <= n; ++i) ++cnt[px[i] = rk[id[i]]]; for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (i = n; i >= 1; --i) sa[cnt[px[i]]--] = id[i]; memcpy(oldrk, rk, sizeof(rk)); for (p = 0, i = 1; i <= n; ++i) rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p; } // 求 height for (i = 1, k = 0; i <= n; ++i) { if (k) --k; while (s[i + k] == s[sa[rk[i] - 1] + k]) ++k; ht[rk[i]] = k; } // 维护单调栈 for (i = 1; i <= n; ++i) { while (ht[sta[top]] > ht[i]) --top; // top类似于一个指针 l[i] = i - sta[top]; sta[++top] = i; } // 最后利用单调栈算 ans sta[++top] = n + 1; ht[n + 1] = -1; for (i = n; i >= 1; --i) { while (ht[sta[top]] >= ht[i]) --top; ans -= 2ll * ht[i] * l[i] * (sta[top] - i); sta[++top] = i; } cout << ans; return 0;}