DNA Sequence POJ - 2778 AC自动机 矩阵快速幂

xiaoxiao2025-12-08  3

题解

给m个长度10以内的病毒串 问长度为n的主串且不匹配任意一个病毒串的有多少个

m最大10所以节点数不超过100 利用AC自动机建图 建立邻接矩阵表示从节点i到节点j能转移的字符数量 除去字符结束节点和fail指针路径上是结束节点 通过N个邻接矩阵相乘即可得到i到j走N步的方案数 将0到i求和即为答案 因为N过大需要用矩阵快速幂求解

AC代码

#include <stdio.h> #include <iostream> #include <queue> using namespace std; typedef long long ll; const int INF = 0x3f3f3f3f; const int MOD = 1e5; const int MAXN = 102; const int MAXC = 10; int nxt[MAXN][MAXC], sed[MAXN], fal[MAXN], idx; int vis[MAXN]; char s[MAXN]; struct Matix { ll m[MAXN][MAXN]; Matix() { memset(m, 0, sizeof(m)); } Matix operator*(const Matix &m2){ Matix t; for (int i = 0; i <= idx; ++i) //不能写小于MAXN for (int j = 0; j <= idx; ++j) for (int k = 0; k <= idx; ++k) t.m[i][j] = (t.m[i][j] + m[i][k] * m2.m[k][j]) % MOD; return t; }; }g, u; void Insert(char *s, int n) //插入一个字符串s长度为n的模式串 { int x = 0; for (int i = 0; i < n; i++) { int c = s[i]; if (!nxt[x][c]) nxt[x][c] = ++idx; x = nxt[x][c]; } sed[x]++; } void Build() //建立失配指针信息 { queue<int> q; //需要先给每个节点的父节点建立失配信息 类似广搜 for (int i = 0; i < MAXC; i++) if (nxt[0][i]) //先将根节点连接的有效节点入队 不能从根节点出发 q.push(nxt[0][i]); //初始每个节点的失配节点都是根 while (!q.empty()) { int f = q.front(); q.pop(); for (int i = 0; i < MAXC; i++) if (nxt[f][i]) //存在子节点 fal[nxt[f][i]] = nxt[fal[f]][i], q.push(nxt[f][i]); //子节点失配尝试匹配一次父节点失配指针的子节点 else //如果不存在 nxt[f][i] = nxt[fal[f]][i]; //则直接将这个节点设定为父节点失配节点的子节点 } } int Match(char *s, int n) //查询字符串s能够匹配多少模式串 { int x = 0, res = 0; //当前节点 查询结果 for (int i = 0; i < n; i++) { int c = s[i] - 'A' + 1; x = nxt[x][c]; //转移到当前字符 如果失配会自动到失配指针 for (int p = x; p; p = fal[p])//已经被处理过了的节点不在继续 res += sed[p]; //将以当前节点为结尾的所有子串全部加上 -1标记为已访问 } return res; } void DFS(int x) //建立图 { vis[x] = 1; for (int i = 1; i <= 4; i++) { int flag = 1; for (int p = nxt[x][i]; p; p = fal[p]) //将结束节点和失配连上有结束节点的除去 if (sed[p]) { flag = 0; break; } if (flag) { g.m[x][nxt[x][i]]++; if (!vis[nxt[x][i]]) DFS(nxt[x][i]); } } } int main() { #ifdef LOCAL freopen("C:/input.txt", "r", stdin); #endif int M, N; cin >> M >> N; for (int i = 0; i < M; i++) { scanf("%s", s); int l = strlen(s); for (int i = 0; i < l; i++) if (s[i] == 'A') s[i] = 1; else if (s[i] == 'C') s[i] = 2; else if (s[i] == 'G') s[i] = 3; else if (s[i] == 'T') s[i] = 4; Insert(s, l); } Build(); DFS(0); for (int i = 0; i < MAXN; i++) u.m[i][i] = 1; while (N) { if (N & 1) u = u * g; g = g * g; N >>= 1; } ll ans = 0; for (int i = 0; i < MAXN; i++) ans = (ans + u.m[0][i]) % MOD; cout << ans << endl; return 0; }
转载请注明原文地址: https://www.6miu.com/read-5040564.html

最新回复(0)