【题目链接】 https://www.lydsy.com/JudgeOnline/problem.php?id=4543 【题解】 枚举中点的方式行不通了,需要换一种思路。 想办法dp一下: 记 f[i][j] f [ i ] [ j ] 表示以 i i 为根的子树,到ii距离为 j j 的点的数目。 g[i][j]g[i][j]表示以 i i 为根的子树,在其中有多少对点可以与在子树ii外,且到 i i 的距离为jj的点组成满足题意的三元组的数目。 dp时:每次加入 i i 的一个儿kk子后,先更新答案: ans=ans+∑nj=0f[i][j]∗g[k][j+1]+∑nj=0g[i][j]∗f[k][j+1] a n s = a n s + ∑ j = 0 n f [ i ] [ j ] ∗ g [ k ] [ j + 1 ] + ∑ j = 0 n g [ i ] [ j ] ∗ f [ k ] [ j + 1 ] 再更新 f f 和gg: g[i][j]=g[i][j]+f[i][j]∗f[k][j−1] g [ i ] [ j ] = g [ i ] [ j ] + f [ i ] [ j ] ∗ f [ k ] [ j − 1 ] f[i][j]=f[i][j]+f[k][j−1] f [ i ] [ j ] = f [ i ] [ j ] + f [ k ] [ j − 1 ] g[i][j]=g[i][j]+f[k][j+1] g [ i ] [ j ] = g [ i ] [ j ] + f [ k ] [ j + 1 ] *注意顺序。 然而这样复杂度还是 O(N2) O ( N 2 ) 的,需要优化转移。 考虑 i i 与son[i]son[i]若 i i 只有一个儿子,那么f[i][j]=f[i][j−1],g[i][j]=g[i][j 1]f[i][j]=f[i][j−1],g[i][j]=g[i][j 1]。 所以可以通过指针移动 O(1) O ( 1 ) 解决。 所以我们想到了一种优化方式:每个节点通过指针移动继承可以延伸最长的的儿子的答案,其他儿子暴力计算。 这个算法的复杂度是: ∑ni=1dep[i]−∑ni=1dep[i]−1=O(N) ∑ i = 1 n d e p [ i ] − ∑ i = 1 n d e p [ i ] − 1 = O ( N ) 每个儿子往上转移的复杂度是 dep d e p 的,由于一个节点的深度一定是长链所指向的儿子的深度+1,所以可以省下 dep−1 d e p − 1 次转移。
# include <bits/stdc++.h> # define N 1000100 # define ll long long using namespace std; int read(){ int tmp=0, fh=1; char ch=getchar(); while (ch<'0'||ch>'9'){if (ch=='-') fh=-1; ch=getchar();} while (ch>='0'&&ch<='9'){tmp=tmp*10+ch-'0'; ch=getchar();} return tmp*fh; } struct Edge{ int data,next; }e[N*2]; int dep[N],per[N],head[N],place,n; ll space[N*10]; ll *f[N],*g[N],ans,*now=space+N; void build(int u, int v){ e[++place].data=v; e[place].next=head[u]; head[u]=place; } void create(int id){ f[id]=now; now=now+dep[id]*2+1; g[id]=now; now=now+dep[id]*2+1; } void dep_cnt(int x, int fa){ for (int ed=head[x]; ed!=0; ed=e[ed].next) if (e[ed].data!=fa){ dep_cnt(e[ed].data,x); if (dep[e[ed].data]>dep[per[x]]) per[x]=e[ed].data; } dep[x]=dep[per[x]]+1; } void solve(int x, int fa){ f[x][0]=1; if (per[x]!=0){ f[per[x]]=f[x]+1; g[per[x]]=g[x]-1; solve(per[x],x); ans=ans+g[per[x]][1]; } for (int ed=head[x]; ed!=0; ed=e[ed].next) if (e[ed].data!=fa&&e[ed].data!=per[x]){ create(e[ed].data); solve(e[ed].data,x); int y=e[ed].data; for(int j=dep[y];j>=0;j--){ if(j) ans+=f[x][j-1]*g[y][j]; ans+=g[x][j+1]*f[y][j]; g[x][j+1]+=f[x][j+1]*f[y][j]; } for(int j=0;j<=dep[y];j++){ if(j) g[x][j-1]+=g[y][j]; f[x][j+1]+=f[y][j]; } } } int main(){ n=read(); for (int i=1; i<n; i++){ int u=read(), v=read(); build(u,v); build(v,u); } dep_cnt(1,0); create(1); solve(1,0); printf("%lld\n",ans); return 0; }