AC自动机+模板

xiaoxiao2025-10-06  12

AC自动机是一种多模匹配算法,

所谓多模就是就是给你一些单词,再给你一段字符串,问有多少个单词在字符串中出现,而KMP就是单模。

学习AC自动机首先要用到字典树的知识和KMP中求next数组的思想。

一般来说有三个步骤

第一、构建一颗字典树

没有学过字典树的可以先去看看,这里没有用二维数组模拟,用的指针,因为有一个fail指针,下面会提到它的用法。

void insert(char *s) { node *p=root; int i=0; int index; while(s[i]){ index=str[i]-'a'; if(p->nexte[index]==NULL){ newnode=(struct node *)malloc(sizeof(struct node)); for(int j=0;j<26;j++) newnode->nexte[j]=0; newnode->count=0; newnode->fail=0; p->nexte[index]=newnode; } p=p->nexte[index]; i++; } p->count++; }

二、fail指针的建立

fail指针与KMP中next数组类似,当字符匹配失败时,需利用当前点的fail指针指向在字典树中根节点到当前点所构成的字符串与其他字符串的最大后缀的节点。

自己都觉得很拗口。。。看个图

图中虚线就是各个点的fail指针,根节点的fail指针当然指向根节点root,再看最下面的叶子节点e,在当前点所表示的字符串为she,那我们看看其他哪个字符串与she的公共后缀最大,很显然是he,那么当前e的fail指针即指向he中的e节点。

懂了fail指针究竟为何物后,看看怎么来求这个指针

这里利用BFS来求,直接与根节点相连的节点的fail指针直接指向根节点即可,其他节点:设当前点为father,它的孩子节点为child,求child的fail,那么就找father节点的fail指针指向的点a,再看点a的孩子节点是否有与child节点表示的字符一样的,如果有那么child节点的fail指针就指向a,如果没有就继续上面的过程,直到指向根节点。

是不是感觉跟KMP中next数组的求法很像呢?

void buildfail() { root->fail=NULL; head=0; tail=1; node *p=NULL; q[head]=root; while(head<tail){ node *temp=q[head++]; for(int i=0;i<26;i++){ //代表这个节点下的nexte数组的26个值(a,b,c...z),依次遍历查找 if(temp->nexte[i]){ if(temp==root) //如果是根节点,那么他的孩子的fail指针指向根节点 temp->nexte[i]->fail=root; else { p=temp->fail; //否则找它父节点的fail指针 while(p){ if(p->nexte[i]){ temp->nexte[i]->fail=p->nexte[i]; break; } p=p->fail; } if(p==NULL) temp->nexte[i]->fail=root; } q[tail++]=temp->nexte[i]; } } } }

第三、就是查询操作了

如果节点匹配,就一直进行下去,每次都加每个节点的count,并标记为-1,但只有代表单词结尾的字符count才是1,其他的为0.

如果不匹配,那么就找当前节点的fail指针指向的节点,直到指向root此次循环结束。

一直重复上面两个过程直到模式串走完

void query(char *str) { int i=0,index,len=strlen(str); node *p=root; while(str[i]){ index=str[i]-'a'; while(p->nexte[index]==NULL&&p!=root) p=p->fail; p=p->nexte[index]; p=(p==NULL)?root:p; //由于上面的while循环,如果为空,那么p一定是root节点 node *temp=p; while(temp!=root&&temp->count!=-1){ ans+=temp->count; temp->count=-1; temp=temp->fail; } i++; } }

如果基本明白了AC自动机,那么就来看一下hdu 2222的一个例题吧

http://acm.hdu.edu.cn/showproblem.php?pid=2222

附AC代码

#include<bits/stdc++.h> #define exp 1e-8 #define mian main #define pii pair<int,int> #define pll pair<ll,ll> #define ll long long #define pb push_back #define PI acos(-1.0) #define inf 0x3f3f3f3f #define w(x) while(x--) #define int_max 2147483647 #define lowbit(x) (x)&(-x) #define gcd(a,b) __gcd(a,b) #define pq(x) priority_queue<x> #define ull unsigned long long #define scn(x) scanf("%d",&x) #define scl(x) scanf("%lld",&x) #define pl(a,n) next_permutation(a,a+n) #define ios ios::sync_with_stdio(false) #define met(a,x) memset((a),(x),sizeof((a))) using namespace std; const int N = 1e6+10; const int maxn=1e7+10; struct node { node *fail; node *nexte[30]; int count; node() { fail=NULL; count=0; met(nexte,0); } }*q[N]; char key[100]; char str[N]; //模式串 int head,tail,ans; node *root; //根节点 node *newnode; void insert(char *s) { node *p=root; int index; int i=0; while(s[i]){ index=s[i]-'a'; if(p->nexte[index]==NULL){ newnode=(struct node *)malloc(sizeof(struct node)); for(int j=0;j<26;j++) newnode->nexte[j]=0; newnode->count=0; newnode->fail=0; p->nexte[index]=newnode; } p=p->nexte[index]; i++; } p->count++; } void buildfail() { root->fail=NULL; head=0; tail=1; node *p=NULL; q[head]=root; while(head<tail){ node *temp=q[head++]; for(int i=0;i<26;i++){ if(temp->nexte[i]){ if(temp==root) //如果是根节点,那么他的孩子的fail指针指向根节点 temp->nexte[i]->fail=root; else { p=temp->fail; //否则找它父节点的fail指针 while(p){ if(p->nexte[i]){ temp->nexte[i]->fail=p->nexte[i]; break; } p=p->fail; } if(p==NULL) temp->nexte[i]->fail=root; } q[tail++]=temp->nexte[i]; } } } } void query(char *str) { int i=0,index,len=strlen(str); node *p=root; while(str[i]){ index=str[i]-'a'; while(p->nexte[index]==NULL&&p!=root) p=p->fail; p=p->nexte[index]; p=(p==NULL)?root:p; node *temp=p; while(temp!=root&&temp->count!=-1){ ans+=temp->count; temp->count=-1; temp=temp->fail; } i++; } } int main() { int t; scanf("%d",&t); while(t--){ ans=0; root=(struct node *)malloc(sizeof(struct node)); for(int j=0;j<26;j++) root->nexte[j]=0; root->fail=0; root->count=0; int x; scanf("%d",&x); getchar(); for(int i=1;i<=x;i++){ gets(key); insert(key); } buildfail(); gets(str); query(str); printf("%d\n",ans); } }

 

转载请注明原文地址: https://www.6miu.com/read-5037445.html

最新回复(0)