题意:
给定一个树(结点数<=5e4),问在树上任意选两点,它们之间的点距为素数的概率。
思路:
很容易想到点分治统计答案,然后只需要统计
①经过根结点的
②根节点到子树某一结点的
第②种一遍dfs就可以在o(n)的时间内完成
第①种很容易想到两两儿子之间进行fft,很明显会超时,那么考虑逆向算,先统计所有子树结点之间,然后减去每个子树和自己子树之间的,最后除以2(因为起点终点交换算同一种)
复杂度为nlog(n) + sqrt(n)*(sqrt(n)log(sqrt(n))) =O(nlog(n))
那么我们就能在O(nlog(n))内统计经过根节点的各个路径长度的计数了
注意每个子树和自己的fft长度需要根据子树大小重新算
最后统计长度为素数的计数就行了
那么总的复杂度为O(nlog(n)log(n))
整个思路都比较容易想到,不太好写
代码:
#include<bits/stdc++.h> const double PI = acos(-1.0); using namespace std; struct comple{ double r,i; comple(double rr=0,double ii=0){ r = rr ; i = ii; }comple operator + ( const comple& a){ return comple(r+a.r,i+a.i); }comple operator - ( const comple& a){ return comple(r-a.r,i-a.i); }comple operator * ( const comple& a){ return comple(r*a.r-i*a.i,r*a.i+i*a.r); } }; int getLen(int x){ int res = 1; while(res<x)res<<=1; return res; } void brc(comple*a,int l){ for(int i=1,j=l/2;i<l-1;i++){ if(i<j)swap(a[i],a[j]); int k = l/2; while(j>=k){ j -= k; k >>= 1; }if(j<k)j+=k; } } void fft(comple *y,int l,int on){ brc(y,l); comple u,t; for(int h=2;h<=l;h<<=1){ comple wn(cos(on*2*PI/h),sin(on*2*PI/h)); for(int j=0;j<l;j+=h){ comple w(1,0); for(int k=j;k<j+h/2;k++){ u = y[k]; t = w*y[k+h/2]; y[k] = u+t; y[k+h/2] = u-t; w = w * wn; } } }if(on<0){ for(int i=0;i<l;i++){ y[i].r/=l; } } } const int MAXN = 50005; bool Np[MAXN]; vector<int>pr; void init(){ for(int i=2;i<MAXN;i++){ if(!Np[i])pr.push_back(i); for(int j=0;j<pr.size();j++){ int t = pr[j] * i; if(t>=MAXN)break; Np[t] = true; if(i%pr[j]==0){ break; } } } } int n; vector<int>G[50005]; int root , ban ; int son[50005]; bool used[50005]; int getSZ(int x,int fa=0){ int res = 1; for(int i=0;i<G[x].size();i++){ int t = G[x][i]; if(!used[t]&&t!=fa){ res += getSZ(t,x); } }return res; } void findR(int x,int fa,int sz){ int cnt = 0; son[x] = 1; for(int i=0;i<G[x].size();i++){ int t = G[x][i]; if(t!=fa&&!used[t]){ findR(t,x,sz); cnt = max(cnt,son[t]); son[x] += son[t]; } }cnt = max(cnt,sz-son[x]); if(cnt<ban){ ban = cnt; root = x; } } const int MAXL = (1<<16)+10; comple F[MAXL]; int num[2][MAXL]; long long ans[MAXL]; long long tmp[MAXL]; void dfs(int x,int fa,int step){ num[0][step]++; num[1][step]++; son[x] = 1; ans[step]++; for(int i=0;i<G[x].size();i++){ int t = G[x][i]; if(!used[t]&&t!=fa){ dfs(t,x,step+1); son[x] += son[t]; } } } void oper(int x,int l){ for(int i=0;i<G[x].size();i++){ int t = G[x][i]; if(!used[t]){ dfs(t,x,1); int ll = getLen((son[t]<<1)+1); for(int i=0;i<ll;i++){ F[i].i = 0; F[i].r = num[1][i]; num[1][i] = 0; }fft(F,ll,1); for(int i=0;i<ll;i++){ F[i] = F[i] * F[i]; }fft(F,ll,-1); for(int i=0;i<ll;i++){ tmp[i] -= (long long)(F[i].r+0.5); } } }for(int i=0;i<l;i++){ F[i].i = 0; F[i].r = num[0][i]; num[0][i] = 0; }fft(F,l,1); for(int i=0;i<l;i++){ F[i] = F[i] * F[i]; }fft(F,l,-1); for(int i=0;i<l;i++){ tmp[i] += (long long)(F[i].r+0.5); ans[i] += tmp[i]/2; tmp[i] = 0; } } void work(int x,int sz){ ban = n ; root = x; findR(x,0,sz); x = root; if(sz<=2)return ; int l = getLen(sz); oper(x,l); used[x] = true; for(int i=0;i<G[x].size();i++){ int t = G[x][i]; if(!used[t]){ work(t,son[t]); } } } int main() { init(); while(scanf("%d",&n)==1){ for(int i=1;i<=n;i++){ G[i].clear(); ans[i] = used[i] = 0; } for(int i=1,s,e;i<n;i++){ scanf("%d%d",&s,&e); G[s].push_back(e); G[e].push_back(s); }work(1,n); long long sum = 0; long long y = 1LL * n * (n-1) / 2; for(int i=0;i<pr.size()&&pr[i]<n;i++){ sum += ans[pr[i]]; }double ANS = 1.0 * sum / y; printf("%.10f\n",ANS); }return 0; } /** 15 1 2 2 3 3 4 4 5 5 6 6 7 7 8 8 9 9 10 10 11 11 12 12 13 13 14 14 15 0.46666667 **/