前言

最近集训时将难一些的字符串题时发现自己似乎已经忘了,或者就根本没学懂过 AC 自动机,于是重学一遍。
——2025.6.11

功能

AC 自动机可以实现比 KMP 和 Trie 更多的字符串匹配方面的功能。比如求模式串是否在文本串中出现过,出现了多少次等。不同于 KMP,AC 自动机支持多模式串的匹配。可以认为,AC 自动机就是在 Trie 上进行 KMP,同样要求最长公共前后缀,只不过前缀可以从任意模式串中截取而非 KMP 的单一模式串。类似于 KMP 的 $nxt$ 数组,AC 自动机中称这个指针为 $fail$。
我们举个例子:若干模式串组成 Trie
image
为了防止过多 $fail$ 弄得图片过乱,我们只取一个例子:$9$ 节点处,我们发现有 $0$ 节点到 $2$ 节点和 $7$ 节点到 $9$ 节点的最长公共前后缀 he。所以,类似 KMP,$9$ 处的 $fail$ 应指向 $2$。

建树

那么,$fail$ 究竟应该如何构建呢?我们使用 BFS 来遍历 Trie,在失配时不断跳 $fail$。

build AC automaton
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
inline void build(){
queue<int> q;
for(int i=0;i<26;i++)
if(nxt[0][i]) q.push(nxt[0][i]);
while(!q.empty()){
int u=q.front();
q.pop();
for(int i=0;i<26;i++){
if(nxt[u][i]){
int to=fail[u];
while(to&&!nxt[to][i])
to=fail[to];
fail[nxt[u][i]]=nxt[to][i];
q.push(nxt[u][i]);
}
}
}
}

然而,这样一直跳 $fail$ 效率太低了。我们可以在一开始就预处理出不存在的边的 $fail$,将查找 $fail$ 优化至 $O(1)$。此时的 Trie 由树变为了图。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
inline void build(){
queue<int> q;
for(int i=0;i<26;i++)
if(nxt[0][i]) q.push(nxt[0][i]);
while(!q.empty()){
int u=q.front();
q.pop();
for(int i=0;i<26;i++){
if(nxt[u][i]){
fail[nxt[u][i]]=nxt[fail[u]][i];
q.push(nxt[u][i]);
}
else nxt[u][i]=nxt[fail[u]][i];
}
}
}

多模式串匹配

我们以本题的匹配为例,只需每次跳 $fail$ 并计数,打标记防止记重就好了。

1
2
3
4
5
6
7
8
9
10
inline int query(string s){
int u=0,res=0;
for(int i=0;i<(int)s.length();i++){
int c=s[i]-'a';
u=nxt[u][c];
for(int j=u;j&&ed[j]!=-1;j=fail[j])
res+=ed[j],ed[j]=-1;
}
return res;
}

效率优化

我们发现,匹配时一直在跳 $fail$,这个操作事实上是可以优化的。
显然,一个 AC 自动机上的 $fail$ 边应当会构成一棵内向树。因此可以进行拓扑排序优化。
按照拓扑序处理节点,累加出现次数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include<iostream>
#include<queue>
#include<cstring>
using namespace std;
constexpr int N=2e5+10;
int nxt[N][26],tot,fail[N],n,ed[N],ans[N],indegree[N],mp[N];
inline void insert(string s,int id){
int u=0;
for(int i=0;i<(int)s.length();i++){
int c=s[i]-'a';
if(!nxt[u][c]) nxt[u][c]=++tot;
u=nxt[u][c];
}
++ed[u];
mp[id]=u;
}
inline void build(){
queue<int> q;
for(int i=0;i<26;i++)
if(nxt[0][i]) q.push(nxt[0][i]);
while(!q.empty()){
int u=q.front();
q.pop();
for(int i=0;i<26;i++){
if(nxt[u][i]){
fail[nxt[u][i]]=nxt[fail[u]][i];
++indegree[nxt[fail[u]][i]];
q.push(nxt[u][i]);
}
else nxt[u][i]=nxt[fail[u]][i];
}
}
}
inline void topo(){
queue<int> q;
for(int i=0;i<=tot;i++) if(!indegree[i]) q.push(i);
while(!q.empty()){
int u=q.front();
q.pop();
int v=fail[u];
ans[v]+=ans[u];
if(!--indegree[v]) q.push(v);
}
}
inline void query(string s){
int u=0;
for(int i=0;i<(int)s.length();i++){
int c=s[i]-'a';
u=nxt[u][c];
ans[u]++;
}
}
string s,t;
int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n;
for(int i=1;i<=n;i++)
cin>>s,insert(s,i);
build();
cin>>t;
query(t);
topo();
for(int i=1;i<=n;i++)
cout<<ans[mp[i]]<<'\n';
return 0;
}

习题

[POI 2000] 病毒

建立 AC 自动机时判断,若某串的最长前缀是病毒,则它本身一定不合法。之后 dfs 判环就行,注意使用标记数组保证 dfs 的复杂度正确。代码有点丑陋,为了卡常写了循环展开。

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include<iostream>
#include<queue>
using namespace std;
constexpr int N=3e4+10;
int nxt[N][2],tot,n,fail[N];
bool ed[N];
inline void insert(string s){
int p=0;
for(char c:s){
if(!nxt[p][c^48]) nxt[p][c^48]=++tot;
p=nxt[p][c^48];
}
ed[p]=1;
}
inline void build(){
queue<int> q;
if(nxt[0][0]) q.push(nxt[0][0]);
if(nxt[0][1]) q.push(nxt[0][1]);
while(!q.empty()){
int u=q.front();
q.pop();
if(nxt[u][0]){
fail[nxt[u][0]]=nxt[fail[u]][0];
q.push(nxt[u][0]);
}
else nxt[u][0]=nxt[fail[u]][0];
if(ed[fail[nxt[u][0]]]) ed[nxt[u][0]]=1;
if(nxt[u][1]){
fail[nxt[u][1]]=nxt[fail[u]][1];
q.push(nxt[u][1]);
}
else nxt[u][1]=nxt[fail[u]][1];
if(ed[fail[nxt[u][1]]]) ed[nxt[u][1]]=1;
}
}
bool vis[N],used[N];
void dfs(int u){
vis[u]=1;
if(vis[nxt[u][0]]||vis[nxt[u][1]]){
cout<<"TAK\n";
exit(0);
}
int p=nxt[u][0];
if(!used[p]&&!ed[p]) used[p]=1,dfs(p);
p=nxt[u][1];
if(!used[p]&&!ed[p]) used[p]=1,dfs(p);
vis[u]=0;
}
int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n;
for(int i=1;i<=n;i++){
string s;
cin>>s;
insert(s);
}
build();
dfs(0);
cout<<"NIE\n";
return 0;
}

[NOI2011] 阿狸的打字机

code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include<iostream>
#include<queue>
#include<vector>
#include<algorithm>
#include<cstring>
using namespace std;
constexpr int N=1e5+10;
int n,m,tot,nxt[N][26],fail[N],fa[N],endpos[N],nxt1[N][26];
inline void build(){
queue<int> q;
for(int i=0;i<26;i++)
if(nxt[0][i]) q.push(nxt[0][i]);
while(!q.empty()){
int u=q.front();
q.pop();
for(int i=0;i<26;i++){
if(nxt[u][i]){
fail[nxt[u][i]]=nxt[fail[u]][i];
q.push(nxt[u][i]);
}
else nxt[u][i]=nxt[fail[u]][i];
}
}
}
struct query{
int x,id;
query(int x=0,int id=0):x(x),id(id){}
};
vector<query> q[N];
vector<int> e[N];
int dfnl[N],dfnr[N],dfncnt;
void dfsfail(int u){
dfnl[u]=++dfncnt;
for(int v:e[u]) dfsfail(v);
dfnr[u]=dfncnt;
}
#define lowbit(x) (x&-x)
int tree[N];
inline void modify(int x,int k){
if(!x) return;
for(;x<=dfncnt;x+=lowbit(x)) tree[x]+=k;
}
inline int ask(int x){
int res=0;
for(;x;x^=lowbit(x)) res+=tree[x];
return res;
}
int ans[N];
void dfstrie(int u){
modify(dfnl[u],1);
for(auto temp:q[u]){
int id=temp.id,x=temp.x;
ans[id]=ask(dfnr[x])-ask(dfnl[x]-1);
}
for(int i=0;i<26;i++){
int v=nxt1[u][i];
if(v==0) continue;
dfstrie(v);
}
modify(dfnl[u],-1);
}
int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
string s;
cin>>s;
int p=0;
for(char c:s){
if(c>='a'&&c<='z'){
if(!nxt[p][c-'a']) nxt[p][c-'a']=++tot;
fa[nxt[p][c-'a']]=p;
p=nxt[p][c-'a'];
}
else if(c=='P') endpos[++n]=p;
else p=fa[p];
}
memcpy(nxt1,nxt,sizeof(nxt));
cin>>m;
for(int i=1,x,y;i<=m;i++){
cin>>x>>y;
q[endpos[y]].push_back(query(endpos[x],i));
}
build();
for(int i=1;i<=tot;i++)
e[fail[i]].push_back(i);
dfsfail(0);
dfstrie(0);
for(int i=1;i<=m;i++)
cout<<ans[i]<<'\n';
return 0;
}

图片来源:

https://oi-wiki.org/string/ac-automaton/