tr[u].son[c]:有两种理解方式。我们可以简单理解为字典树上的一条边,即 trie(u,c);也可以理解为从状态(结点)u 后加一个字符 c 到达的状态(结点),即一个状态转移函数 trans(u,c)。为了方便,下文中我们将用第二种理解方式。
队列 q:用于 BFS 遍历字典树。
tr[u].fail:结点 u 的 fail 指针。
实现
[list2tab]
C++
void build() { queue<int> q; for (int i = 0; i < 26; i++) if (tr[0].son[i]) q.push(tr[0].son[i]); while (!q.empty()) { int u = q.front(); q.pop(); for (int i = 0; i < 26; i++) { if (tr[u].son[i]) { tr[tr[u].son[i]].fail = tr[tr[u].fail].son[i]; q.push(tr[u].son[i]); } else tr[u].son[i] = tr[tr[u].fail].son[i]; } }}
Python
def build(): for i in range(0, 26): if tr[0][i] != 0: q.append(tr[0][i]) while q: u = q.pop(0) for i in range(0, 26): if tr[u][i] != 0: fail[tr[u][i]] = tr[fail[u]][i] q.append(tr[u][i]) else: tr[u][i] = tr[fail[u]][i]
而 trans(S,c) 相当于是在 S 后添加一个字符 c 变成另一个状态 S′。如果 S′ 存在,说明存在一个模式串的前缀是 S′,否则我们让 trans(S,c) 指向 trans(fail(S),c)。由于 fail(S) 对应的字符串是 S 的后缀,因此 trans(fail(S),c) 对应的字符串也是 S′ 的后缀。
换言之在 Trie 上跳转的时侯,我们只会从 S 跳转到 S′,相当于匹配了一个 S′;但在 AC 自动机上跳转的时侯,我们会从 S 跳转到 S′ 的后缀,也就是说我们匹配一个字符 c,然后舍弃 S 的部分前缀。舍弃前缀显然是能匹配的。同时如果文本串能匹配 S,显然它也能匹配 S 的后缀,所以 fail 指针同样在舍弃前缀。所谓的 fail 指针其实就是 S 的一个后缀集合。
Trie 的结点的孩子数组 son 还有另一种比较简单的理解方式:如果在位置 u 失配,我们会跳转到 fail(u) 的位置。注意这会导致我们可能沿着 fail 数组跳转多次才能来到下一个能匹配的位置。所以我们可以用 son 直接记录记录下一个能匹配的位置,这样保证了程序的时间复杂度。
int query(const char t[]) { int u = 0, res = 0; for (int i = 1; t[i]; i++) { u = tr[u].son[t[i] - 'a']; for (int j = u; j && tr[j].cnt != -1; j = tr[j].fail) { res += tr[j].cnt, tr[j].cnt = -1; } } return res;}
Python
def query(t: str) -> int: u, res = 0, 0 for c in t: u = tr[u][c - ord("a")] j = u while j and e[j] != -1: res += e[j] e[j] = -1 j = fail[j] return res
解释
这里 u 作为字典树上当前匹配到的结点,res 即返回的答案。循环遍历匹配串,u 在字典树上跟踪当前字符。利用 fail 指针找出所有匹配的模式串,并累加到答案中。然后将匹配到的串的出现次数清零,这样就不会重复统计同一个串。在上文中我们分析过,字典树的结构其实就是一个 trans 函数,而构建好这个函数后,在匹配字符串的过程中,我们会舍弃部分前缀达到最低限度的匹配。fail 指针则指向了更多的匹配状态。最后上一份图。对于刚才的自动机:
因为我们的 AC 自动机中,每次匹配,会一直向 fail 边跳来找到所有的匹配,但是这样的效率较低,在某些题目中会超时。
那么需要如何优化呢?首先需要了解到 fail 指针的一个性质:一个 AC 自动机中,如果只保留 fail 边,那么剩余的图一定是一棵树。
这是显然的,因为 fail 不会成环,且深度一定比现在低,所以得证。
这样 AC 自动机的匹配就可以转化为在 fail 树上的链求和问题,只需要优化一下该部分就可以了。
这里提供两种思路。
拓扑排序优化
观察到时间主要浪费在在每次都要跳 fail。如果我们可以预先记录,最后一并求和,那么效率就会优化。
于是我们按照 fail 树,做一次内向树上的拓扑排序,就能一次性求出所有模式串的出现次数。
build 函数在原先的基础上,增加了入度统计一部分,为拓扑排序做准备。
构建
void build() { queue<int> q; for (int i = 0; i < 26; i++) if (tr[0].son[i]) q.push(tr[0].son[i]); while (!q.empty()) { int u = q.front(); q.pop(); for (int i = 0; i < 26; i++) { if (tr[u].son[i]) { tr[tr[u].son[i]].fail = tr[tr[u].fail].son[i]; tr[tr[tr[u].fail].son[i]].du++; // 入度计数 q.push(tr[u].son[i]); } else tr[u].son[i] = tr[tr[u].fail].son[i]; } }}
然后我们在查询的时候就可以只为找到结点的 ans 打上标记,在最后再用拓扑排序求出答案。
查询
void query(const char t[]) { int u = 0; for (int i = 1; t[i]; i++) { u = tr[u].son[t[i] - 'a']; tr[u].ans++; }}void topu() { queue<int> q; for (int i = 0; i <= tot; i++) if (tr[i].du == 0) q.push(i); while (!q.empty()) { int u = q.front(); q.pop(); ans[tr[u].idx] = tr[u].ans; int v = tr[u].fail; tr[v].ans += tr[u].ans; if (!--tr[v].du) q.push(v); }}
最后是主函数:
主函数
int main() { // do_something(); AC::build(); scanf("%s", s + 1); AC::query(s); AC::topu(); for (int i = 1; i <= n; i++) printf("%d\n", AC::ans[idx[i]]); // do_another_thing();}
不难想到一个朴素的思路:建立 AC 自动机,在 AC 自动机上对于所有 fail 指针的子串转移,最后取最大值得到答案。
主要代码如下。若不熟悉代码中的类型定义,可以先看末尾的完整代码:
查询部分主要代码
int query(const char t[]) { int u = 0, len = strlen(t + 1); for (int i = 1; i <= len; i++) dp[i] = 0; for (int i = 1; i <= len; i++) { u = tr[u].son[t[i] - 'a']; for (int j = u; j; j = tr[j].fail) { if (tr[j].idx && (dp[i - tr[j].depth] || i - tr[j].depth == 0)) { dp[i] = dp[i - tr[j].depth] + tr[j].depth; } } } int ans = 0; for (int i = 1; i <= len; i++) ans = std::max(ans, dp[i]); return ans;}
void build() { queue<int> q; for (int i = 0; i < 26; i++) if (tr[0].son[i]) { q.push(tr[0].son[i]); tr[tr[0].son[i]].depth = 1; } while (!q.empty()) { int u = q.front(); q.pop(); int v = tr[u].fail; // 对状态的更新在这里 tr[u].stat = tr[v].stat; if (tr[u].idx) tr[u].stat |= 1 << tr[u].depth; for (int i = 0; i < 26; i++) { if (tr[u].son[i]) { tr[tr[u].son[i]].fail = tr[tr[u].fail].son[i]; tr[tr[u].son[i]].depth = tr[u].depth + 1; // 记录深度 q.push(tr[u].son[i]); } else tr[u].son[i] = tr[tr[u].fail].son[i]; } }}
然后查询时就可以去掉跳 fail 的循环,将代码简化如下:
查询
int query(const char t[]) { int u = 0, mx = 0; unsigned st = 1; for (int i = 1; t[i]; i++) { u = tr[u].son[t[i] - 'a']; st <<= 1; // 往下跳了一位每一位的长度都+1 if (tr[u].stat & st) st |= 1, mx = i; } return mx;}
我们的 tr[u].stat 维护的是从结点 u 开始,整条 fail 链上的长度集(因为长度集小于 32 所以不影响),而 st 则维护的是查询字符串走到现在,前 32 位(因为状态压缩自然溢出)的长度集。
#include <cstdio>#include <cstring>#include <queue>using namespace std;constexpr int N = 20 + 6, M = 50 + 6;constexpr int LEN = 2e6 + 6;constexpr int SIZE = 450 + 6;int n, m;namespace AC {struct Node { int son[26]; int fail; int idx; int depth; unsigned stat; void init() { memset(son, 0, sizeof(son)); fail = idx = depth = 0; }} tr[SIZE];int tot;void init() { tot = 0; tr[0].init();}void insert(char s[], int idx) { int u = 0; for (int i = 1; s[i]; i++) { int &son = tr[u].son[s[i] - 'a']; if (!son) son = ++tot, tr[son].init(); u = son; } tr[u].idx = idx;}void build() { queue<int> q; for (int i = 0; i < 26; i++) if (tr[0].son[i]) { q.push(tr[0].son[i]); tr[tr[0].son[i]].depth = 1; } while (!q.empty()) { int u = q.front(); q.pop(); int v = tr[u].fail; // 对状态的更新在这里 tr[u].stat = tr[v].stat; if (tr[u].idx) tr[u].stat |= 1 << tr[u].depth; for (int i = 0; i < 26; i++) { if (tr[u].son[i]) { tr[tr[u].son[i]].fail = tr[tr[u].fail].son[i]; tr[tr[u].son[i]].depth = tr[u].depth + 1; // 记录深度 q.push(tr[u].son[i]); } else tr[u].son[i] = tr[tr[u].fail].son[i]; } }}int query(char t[]) { int u = 0, mx = 0; unsigned st = 1; for (int i = 1; t[i]; i++) { u = tr[u].son[t[i] - 'a']; st <<= 1; if (tr[u].stat & st) st |= 1, mx = i; } return mx;}} // namespace ACchar s[LEN];int main() { AC::init(); scanf("%d%d", &n, &m); for (int i = 1; i <= n; i++) { scanf("%s", s + 1); AC::insert(s, i); } AC::build(); for (int i = 1; i <= m; i++) { scanf("%s", s + 1); printf("%d\n", AC::query(s)); } return 0;}
总结
时间复杂度:定义 ∣si∣ 是模板串的长度,∣S∣ 是文本串的长度,∣Σ∣ 是字符集的大小(常数,一般为 26)。如果连了 trie 图,时间复杂度就是 O(∑∣si∣+n∣Σ∣+∣S∣),其中 n 是 AC 自动机中结点的数目,并且最大可以达到 O(∑∣si∣)。如果不连 trie 图,并且在构建 fail 指针的时候避免遍历到空儿子,时间复杂度就是 O(∑∣si∣+∣S∣)。
#include <cstdio>#include <cstring>#include <queue>using namespace std;constexpr int N = 150 + 6;constexpr int LEN = 1e6 + 6;constexpr int SIZE = N * 70 + 6;int n;namespace AC {struct Node { int son[26]; int fail; int idx; void init() { memset(son, 0, sizeof(son)); idx = fail = 0; }} tr[SIZE];int tot;void init() { tot = 0; tr[0].init();}void insert(char s[], int idx) { // 将第 idx 个字符串 s 插入 int u = 0; for (int i = 1; s[i]; i++) { int &son = tr[u].son[s[i] - 'a']; if (!son) son = ++tot, tr[son].init(); u = son; } tr[u].idx = idx;}void build() { queue<int> q; for (int i = 0; i < 26; i++) if (tr[0].son[i]) q.push(tr[0].son[i]); while (!q.empty()) { int u = q.front(); q.pop(); for (int i = 0; i < 26; i++) { if (tr[u].son[i]) { tr[tr[u].son[i]].fail = tr[tr[u].fail].son[i]; q.push(tr[u].son[i]); } else tr[u].son[i] = tr[tr[u].fail].son[i]; } }}int query(char t[], int cnt[]) { int u = 0, res = 0; for (int i = 1; t[i]; i++) { u = tr[u].son[t[i] - 'a']; for (int j = u; j; j = tr[j].fail) ++cnt[tr[j].idx]; // 统计每个字符串出现的次数 } for (int i = 0; i <= tot; ++i) if (tr[i].idx) res = max(res, cnt[tr[i].idx]); return res;}} // namespace ACchar s[N][75], t[LEN];int cnt[N]; // 每一个字符串出现的次数int main() { while (scanf("%d", &n) != EOF && n != 0) { AC::init(); for (int i = 1; i <= n; i++) { scanf("%s", s[i] + 1); AC::insert(s[i], i); cnt[i] = 0; } AC::build(); scanf("%s", t + 1); int x = AC::query(t, cnt); printf("%d\n", x); for (int i = 1; i <= n; i++) if (cnt[i] == x) printf("%s\n", s[i] + 1); } return 0;}