BZOJ4598
求树上满足某些条件的点对,首先就可以想到点分治。 然后又与什么字符串匹配有关。
KMP,AC自动机……
之类的好像不太好用。。那就哈希吧! 添加答案的时候有两种情况:
那么就分别维护从上到下的链和从下到上的链。不是所有链都存的,仅当“从该点到当前根的一段是若干个模式串的前缀或者后缀”时才存。 发现当长度为
a
时,不仅m−a可以更新答案,长度为
km−a
的也可以。那这样岂不是每次更新都是
O(n/m)
?存入答案的时候,其实长度为
x
和x km是等价的。取个模再存就好了。
看起来好像不是很难。。但是蒟蒻表示:好多优化啊!!
QAQ
。。还是要把细节都想清白了再打。。
【代码】
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const ull base=
31;
ll
read()
{
ll
x=
0,f=
1;char ch=getchar();
while(!isdigit(ch)){
if(ch==
'-') f=-
1;ch=getchar();}
while(isdigit(ch)){
x=(
x<<
1)+(
x<<
3)+ch-
'0';ch=getchar();}
return x*f;
}
int sum,T,n,
m,cnt,rt,s1,s2;
ll ans;
int b[N<<
1],p[N],nextedge[N<<
1];
int sz[N],f[N],d[N];
int Cnt[N],ccnt[N],st[N],sst[N];
char
s[N];bool Flag[N];
ull a[N],ha[N],Ha[N],P[N],hash[N],Hash[N],deep[N],Deep[N];
void Pre()
{
P[
0]=
1;
for(
int i=
1;i<N;i++) P[i]=P[i-
1]
*base;
}
void Add(
int x,
int y){
cnt++;
b[cnt]=
y;
nextedge[cnt]=p[
x];
p[
x]=cnt;
}
void Anode(
int x,
int y){
Add(
x,
y);Add(
y,
x);
}
void Input_Init()
{
n=
read(),
m=
read();rt=ans=
0,f[
0]=INF;sum=n;
cnt=
0;
for(
int i=
1;i<=n;i++) p[i]=Flag[i]=deep[i]=
0;
scanf(
"%s",
s+
1);
for(
int i=
1;i<=n;i++) a[i]=
s[i]-
'A'+
1;
for(
int i=
1;i<n;i++)
{
static
int x,
y;
x=
read(),
y=
read();
Anode(
x,
y);
}
scanf(
"%s",
s+
1);
for(
int i=
1;i<=
m;i++) ha[i]=
s[i]-
'A'+
1,Ha[i]=
s[
m-i+
1]-
'A'+
1;
for(
int i=
1;i<=n;i++) hash[i]=hash[i-
1]+ha[(i-
1)
%m+
1]
*P[i-
1],Hash[i]=Hash[i-
1]+P[i-
1]
*Ha[(i-
1)
%m+
1];
}
void Get_Root(
int x,
int fa)
{
f[
x]=
0;sz[
x]=
1;
for(
int i=p[
x];i;i=nextedge[i])
{
int v=b[i];
if(v==fa||Flag[v])
continue;
Get_Root(v,
x);
sz[
x]+=sz[v];
f[
x]=max(f[
x],sz[v]);
}
f[
x]=max(f[
x],sum-sz[
x]);
rt=f[rt]>f[
x]?
x:rt;
}
void Get_deep(
int x,
int fa)
{
if(Hash[d[
x]]==deep[
x]&&a[
x]==Ha[
1]) st[++s1]=
x;
if(hash[d[
x]]==deep[
x]&&a[
x]==ha[
1]) sst[++s2]=
x;
for(
int i=p[
x];i;i=nextedge[i])
{
int v=b[i];
if(v==fa||Flag[v])
continue;
deep[v]=deep[
x]
*base+a[v];
d[v]=d[
x]+
1;
Get_deep(v,
x);
}
}
void Calc(
int x)
{
for(
int i=
0;i<=
m;i++) Cnt[i]=ccnt[i]=
0;
if(Ha[
1]==a[
x]) ccnt[
1]=
1;
if(Ha[
m]==a[
x]) Cnt[
1]=
1;
for(
int i=p[
x];i;i=nextedge[i])
{
int v=b[i];
if(Flag[v])
continue;
s1=s2=
0;d[v]=
1;deep[v]=a[v];
Get_deep(v,
x);
for(
int j=
1;j<=s1;j++){
int t=st[j],
pos=
m-d[t]
%m;
if(
pos==
0)
pos=
m;
ans+=Cnt[
pos];
}
for(
int j=
1;j<=s2;j++){
int t=sst[j],
pos=
m-d[t]
%m;
if(
pos==
0)
pos=
m;
ans+=ccnt[
pos];
}
for(
int j=
1;j<=s1;j++) {
int t=st[j];
int pos=d[t]
%m+
1;
if(a[
x]==Ha[
pos]) ccnt[
pos]++;
}
for(
int j=
1;j<=s2;j++) {
int t=sst[j];
int pos=d[t]
%m+
1;
if(a[
x]==ha[
pos]) Cnt[
pos]++;
}
}
}
void Work(
int x)
{
Calc(
x);Flag[
x]=
1;
for(
int i=p[
x];i;i=nextedge[i])
{
int v=b[i];
if(Flag[v])
continue;
sum=sz[v];rt=
0;
if(sum<
m)
continue;
Get_Root(v,
0);
Work(rt);
}
}
int main()
{
T=
read();
Pre();
while(T--)
{
Input_Init();
Get_Root(
1,
0);
Work(rt);
printf(
"%lld\n",ans);
}
return 0;
}