题目大意:给你n个字符串,然后有q个询问,每个询问都给一个字符串的前驱和后缀,要求输出所给出的这个前驱和后缀可能在多少个字符串中出现。 解题思路:因为题目上给出的是前缀和后缀,所以我们在建立字典树时的顺序为str0,strn-1,str2,strn-2··· ···strn-1,str0;设前缀为s 后缀为t,则在查询时查询的字符串也为s0, tn-1,s1, tn-2,前后缀长度不等用*补齐。(如ac f,构造后的字符串为 afc*). 字典树需要维护3个值,这个节点的后继节点child,经过这个节点的字符串有多少个cnt,经过这个节点的字符串的长度len(用动态数组来存)。 在查询时因为要防止 字符串为 aaa 前缀为 aa 后缀为 aa 这种情况出现,所以我们需要对每个节点中vector的数值排序,去掉长度小于前缀+后缀的值,最后cnt的值就可能是当前这个前缀和后缀可能组成字符串的个数了。 这里还有一个要注意的地方就是当出现 * 的时候 * 后边所有的节点都可能是我们要匹配的答案,所以所有节点都要记录下来留作处理。这里我用连个队列来处理我们的节点。一个队列用来存放父亲节点,一个队列用来存放孩子节点,两个队列交换使用,实现对所有可能节点的计数。
注意:这个题的内存给的比较大,而时间比较少,如果用链表的方式来建树可能会超时,这里我给出链表建树和数组建树两种方法。链表式的方法只是用来助理解的,正确的还是要用数组来存。 链表建树:
#include <cstdio> #include <iostream> #include <string.h> #include <stdlib.h> #include <algorithm> #include <queue> #include <vector> #define N 100000+10 using namespace std; struct Trie { vector<int> len; int cnt; Trie *child[26]; Trie()///构造函数,完成每个节点的初始化 { cnt = 0; len.clear(); for(int i=0; i<26; i++) child[i]=NULL; } }; Trie *root,*current,*temp; void insert(char s[],int lens) ///建树时要维护cnt(表示经过这个节点的字符串有多少个)和动态数组len(保存经过这个节点的字符串的长度len) { current = root; for(int i=0; i<strlen(s); i++) { if(current->child[s[i]-'a']==NULL) { temp = new Trie; current->child[s[i]-'a'] = temp; current = current->child[s[i]-'a']; current->len.push_back(lens); current->cnt = 1; } else { current = current->child[s[i]-'a']; current->cnt++; current->len.push_back(lens); } } } void dfs(Trie *root) { for(int i=0; i<26; i++) if(root->child[i]!=NULL) dfs(root->child[i]); sort(root->len.begin(),root->len.end()); } int search(char s[],int lenn,int ret) { queue<Trie*> Q[2]; Q[0].push(root); int ans = 0,tmp = 0; for(int i=0;i < lenn;i++){ tmp = 1-tmp; while(!Q[1-tmp].empty()){ ///每次改变队列,实现当前节点在一个队列,当前节点的子节点在一个队列,不漏掉任何一个可能的节点 current = Q[1-tmp].front(); Q[1-tmp].pop(); if(s[i] == '*'){ for(int j = 0; j < 26; j++) { if(current->child[j]!=NULL) Q[tmp].push(current->child[j]); } } else{ if(current->child[s[i]-'a']!=NULL) Q[tmp].push(current->child[s[i]-'a']); } } } while(!Q[tmp].empty()){ ///因为可能出现*后边有很多个节点的情况,所以可能有多个满足条件的节点,每个可能满足条件的节点都要计算 Trie *t= new Trie; t = Q[tmp].front(); Q[tmp].pop(); int tt = lower_bound(t->len.begin(),t->len.end(),ret) - t->len.begin(); ///lower_bound返回的是大于或等于ret的第一个元素的位置,再减去元素首位置,就是这个数组中有多少个元素比ret小 ans -= tt; ans += t->cnt; } return ans; } void del(Trie *root) { for(int i=0;i<26;i++) if(root->child[i]!=NULL) del(root->child[i]); delete(root); } int main() { int T; cin>>T; while(T--) { root = new Trie; int n,q; char str[N],st[2*N],s[N],t[N]; scanf("%d %d",&n,&q); for(int i=0; i<n; i++) { scanf("%s",str); int L = strlen(str); for(int i=0; i<L; i++)///st用来保存处理后的字符串即str0,strn-1,str1,strn-2······ { st[i*2] = str[i]; st[i*2+1] = str[L-1-i]; } st[2*L] = '\0'; insert(st,L);///构造字典树 } dfs(root); for(int i=0; i<q; i++) { scanf("%s %s",s,t); int len1 = strlen(s); int len2 = strlen(t); int ret = len1+len2; int len = max(len1,len2); for(int j= 0; j < len; j ++)///将前后缀串按前边的约定来构造 { if(j < len1) st[j*2] = s[j]; else st[j*2] = '*'; if(j < len2) st[j*2+1] = t[len2-1-j]; else st[j*2+1] = '*'; } len*=2; st[len] = '\0'; printf("%d\n",search(st,len,ret)); } del(root);//防止内存超 } return 0; }数组建树:
#include<bits/stdc++.h> using namespace std; const int N = 5e5+100; char s[N],t[N],st[N*2]; int n,q; struct Tire{ int L,root,net[N*2][26],ed[2*N]; vector<int> G[2*N]; int newnode(){ for(int i= 0;i < 26;i ++) net[L][i] = -1; ed[L] = 0; return L++; } void init(){ L = 0; for(int i= 0;i < 2*N;i ++) G[i].clear(); root = newnode(); } void build(char *s,int len){ int lens = strlen(s); int now = root; for(int i=0 ;i < lens;i ++){ int id = s[i]-'a'; if(net[now][id] == -1){ net[now][id] = newnode(); } now = net[now][id]; G[now].push_back(len); ed[now]++; } } void dfs(int x){ sort(G[x].begin(),G[x].end()); for(int i = 0;i < 26;i ++){ if(net[x][i] != -1) dfs(net[x][i]); } } }tire; int main(){ int T; cin >> T; while(T--){ tire.init(); scanf("%d %d",&n,&q); for(int i= 1;i <= n;i ++){ scanf("%s",&s); int len = strlen(s); for(int i= 0;i < len;i ++){ st[i*2] = s[i]; st[i*2+1] = s[len-1-i]; } st[len*2] = '\0'; tire.build(st,len); } tire.dfs(0); for(int i= 1;i <= q;i ++){ scanf("%s %s",s,t); int len1 = strlen(s); int len2 = strlen(t); int ret = len1+len2; int len = max(len1,len2); for(int j= 0;j < len;j ++){ if(j < len1) st[j*2] = s[j]; else st[j*2] = '*'; if(j < len2) st[j*2+1] = t[len2-1-j]; else st[j*2+1] = '*'; } len*=2; st[len] = '\0'; queue<int> q[2]; int tmp = 0; q[0].push(0); int ans = 0; for(int j= 0;j < len;j ++){ tmp = 1-tmp; int id = st[j]-'a'; while(!q[1-tmp].empty()){ int now = q[1-tmp].front(); q[1-tmp].pop(); if(st[j] == '*'){ for(int k = 0;k < 26;k ++){ if(tire.net[now][k] != -1){ q[tmp].push(tire.net[now][k]); } } } else{ if(tire.net[now][id] != -1){ q[tmp].push(tire.net[now][id]); } } } } while(!q[tmp].empty()){ int now = q[tmp].front(); q[tmp].pop(); int cnt = lower_bound(tire.G[now].begin(),tire.G[now].end(),ret) - tire.G[now].begin(); ans -= cnt; ans += tire.ed[now]; } printf("%d\n",ans); } } return 0; }