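// Note: this problem cost me a lot of time, and I still don't fully understand why the
// nonexistent edges have to be filled in. (Answer: match() steps through next[u][i] for
// every character directly, so each entry must be the real automaton transition; an
// unfilled 0 would wrongly send the walk back to the root instead of to the longest
// matching suffix.) The elegant part of this problem is that following the fail links
// visits every suffix of the current prefix, so a state can be marked as forbidden
// whenever any of its suffixes ends a pattern.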
// Computes, for each test case, the probability that a random string of length l
// (each character drawn independently with probability pos[c]) contains none of the
// k given patterns. Uses an Aho-Corasick automaton whose states are "forbidden"
// (cnt != 0) when a pattern ends there or at any suffix reachable via fail links,
// plus a memoised DP over (state, remaining length).
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#include <cstdio>
#include <string>
#include <cstring>
#include <cctype>
#include <vector>
#include <set>
#include <cmath>
#include <map>
#define LL long long
#define INF 0x3f3f3f3f
#define mod 1000000007
const int maxn = 10000+5;
using namespace std;

char p[1000];           // buffer for one pattern read from input
int n, k, l;            // number of (char, probability) pairs, pattern count, string length
double dp[maxn][105];   // dp[u][len]: memoised value of Aho::match(u, len)
bool vis[maxn][105];    // vis[u][len]: true once dp[u][len] has been computed
double pos[1000];       // pos[c]: probability of the character whose index is c

struct Aho {
    struct node {
        int next[63];   // transitions; after build() every entry is a full automaton transition
        int fail, cnt;  // fail link; cnt != 0 marks a forbidden state
    } state[maxn];
    queue<int> q;
    int size;           // number of allocated states (state 0 is the root)

    // Maps 'a'-'z' -> 0..25, 'A'-'Z' -> 26..51, anything else via ch - '0' + 52
    // (so '0'-'9' -> 52..61). The cast avoids UB in <cctype> for negative chars.
    int idx(char ch) {
        unsigned char uc = static_cast<unsigned char>(ch);
        if (islower(uc)) return ch - 'a';
        if (isupper(uc)) return ch - 'A' + 26;
        return ch - '0' + 52;
    }

    // Resets every state and leaves only the root allocated.
    void init() {
        while (!q.empty()) q.pop();
        for (int i = 0; i < maxn; i++) {
            memset(state[i].next, 0, sizeof(state[i].next));
            state[i].fail = state[i].cnt = 0;
        }
        size = 1;
    }

    // Adds pattern s to the trie and marks its terminal state as forbidden.
    void insert(char *s) {
        int len = (int)strlen(s);
        int now = 0;
        for (int i = 0; i < len; i++) {
            int c = idx(s[i]);
            if (!state[now].next[c]) state[now].next[c] = size++;
            now = state[now].next[c];
        }
        state[now].cnt = 1;
    }

    // Computes fail links in BFS order and fills in every missing transition, so
    // next[u][i] becomes the goto function of the automaton. Because BFS handles
    // shallower states first, state[state[u].fail].next[] is already complete when
    // u is processed, so no fail-chain walking is needed.
    void build() {
        state[0].fail = -1;
        q.push(0);  // 0 is the root
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            for (int i = 0; i < 62; i++) {
                int child = state[u].next[i];
                if (child) {
                    int f = (u == 0) ? 0 : state[state[u].fail].next[i];
                    state[child].fail = f;
                    // A state is forbidden if any suffix of its prefix ends a pattern.
                    state[child].cnt |= state[f].cnt;
                    q.push(child);
                } else {
                    // Missing edge: reuse the transition of the longest proper suffix,
                    // so match() never has to climb fail links.
                    state[u].next[i] = (u == 0) ? 0 : state[state[u].fail].next[i];
                }
            }
        }
    }

    // Probability that, starting in state u, the next len random characters never
    // enter a forbidden state. Results are memoised in the globals dp/vis.
    double match(int u, int len) {
        if (!len) return 1.0;
        if (vis[u][len]) return dp[u][len];
        vis[u][len] = true;
        double &ans = dp[u][len];
        ans = 0.0;
        for (int i = 0; i < 62; i++) {
            int v = state[u].next[i];
            if (state[v].cnt == 0) ans += pos[i] * match(v, len - 1);
        }
        return ans;
    }
} aho;

int main() {
    int T, kases = 1;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) {
        aho.init();
        memset(vis, false, sizeof(vis));
        memset(dp, 0, sizeof(dp));
        memset(pos, 0, sizeof(pos));

        scanf("%d", &k);
        for (int i = 0; i < k; i++) {
            scanf("%999s", p);  // bounded to the buffer size
            aho.insert(p);
        }
        aho.build();

        scanf("%d", &n);
        for (int i = 0; i < n; i++) {
            char ch;
            double ps;
            // The leading space skips all pending whitespace (newline, "\r\n",
            // extra blanks); a single getchar() only consumed one character.
            scanf(" %c %lf", &ch, &ps);
            pos[aho.idx(ch)] = ps;
        }

        scanf("%d", &l);
        printf("Case #%d: %.6lf\n", kases++, aho.match(0, l));
    }
    return 0;
}

/*
Sample input:
2
1
a
2
a 0.5
b 0.5
2
2
ab
ab
2
a 0.2
b 0.8
2

Expected output:
Case #1: 0.250000
Case #2: 0.840000
*/