很久以前就学过树分治,但是掌握不熟练(其实是弃坑了) 所以现在重新拾起这个算法,终于填坑完成…… 发现还是挺简单的
树分治,是用于统计树上路径的算法
POJ1741就是一个很好的例子 下面会以此题为例,详细讲解树分治
树分治分为两种:点分治与边分治 点分治:每次找一个点,分治所有以它的儿子为根的子树 边分治:每次找一条边,分治它连接的两个点为根的子树
由于边分治容易被特殊数据卡,所以一般使用点分治 下面也是用点分治实现
前面提到点分治的效率比较稳定,是因为它每次用来分治的点是有讲究的:选择树的“重心”来分治 树的重心:若将某点设为根后,最大子树最小,那么此点就是树的重心
用反证法易得:重心的最大子树不大于整棵树的一半 也就是说,每次选择重心分治,树的大小至少减去一半 这样点分治的复杂度最坏就是 O(NlogN) 就保证了效率的稳定性
根据重心的定义,可以很容易用一次DFS找到重心 代码实现:
void get_hvy(int x,int fa){ siz[x]=1;MAX[x]=0; for (int j=lnk[x];j;j=nxt[j]) if (son[j]!=fa&&!vis[son[j]]){ get_hvy(son[j],x); siz[x]+=siz[son[j]]; MAX[x]=max(MAX[x],siz[son[j]]); } MAX[x]=max(MAX[x],S-siz[x]); if (!hvy||MAX[x]<MAX[hvy]) hvy=x; }(hvy表示当前的重心,S表示当前树的大小,都是全局变量)
再来看树分治的核心部分。
对于当前处理的树,统计经过树根的路径 但是有不符合题意的路径(路径的两个端点在同一子树,不是简单路径) 所以再以每个儿子为根,计算一次,减去即可(注意要带上权值w) 然后对子树找重心,递归处理 代码实现:
void get_ans(int x){ vis[x]=1;ans+=get_sum(x,0); for (int j=lnk[x];j;j=nxt[j]) if (!vis[son[j]]){ ans-=get_sum(son[j],w[j]); hvy=0;S=siz[son[j]];get_hvy(son[j],0); get_ans(hvy); } }具体如何统计经过树根的路径呢 先获得整棵树的深度,记为dep[] 那么对于任意点对(i,j)只要满足 dep[i]+dep[j]≤k 即可被统计 那么排序后用两个指针 O(N) 推一下就好了 代码实现:
void get_dep(int x,int fa){ now[++now[0]]=dep[x]; for (int j=lnk[x];j;j=nxt[j]) if (fa!=son[j]&&!vis[son[j]]){ dep[son[j]]=dep[x]+w[j]; get_dep(son[j],x); } } int get_sum(int x,int dst){ now[0]=0;int res=0; dep[x]=dst;get_dep(x,0); sort(now+1,now+1+now[0]); for (int i=1,j=now[0];i<=now[0];i++){ while (j>i&&now[i]+now[j]>k) j--; res+=max(0,j-i); } return res; }完整代码:
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int maxn=10005,maxe=2*maxn; int n,k,hvy,S,ans; int tot,lnk[maxn],nxt[maxe],son[maxe],w[maxe]; int siz[maxn],now[maxn],dep[maxn],MAX[maxn]; bool vis[maxn]; inline char nc(){ static char buf[100000],*p1=buf,*p2=buf; return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; } inline int red(){ int res=0,f=1;char ch=nc(); while (ch<'0'||'9'<ch) {if (ch=='-') f=-f;ch=nc();} while ('0'<=ch&&ch<='9') res=res*10+ch-48,ch=nc(); return res*f; } void add(int x,int y,int z){ son[++tot]=y;nxt[tot]=lnk[x];lnk[x]=tot;w[tot]=z; } void get_hvy(int x,int fa){ siz[x]=1;MAX[x]=0; for (int j=lnk[x];j;j=nxt[j]) if (son[j]!=fa&&!vis[son[j]]){ get_hvy(son[j],x); siz[x]+=siz[son[j]]; MAX[x]=max(MAX[x],siz[son[j]]); } MAX[x]=max(MAX[x],S-siz[x]); if (!hvy||MAX[x]<MAX[hvy]) hvy=x; } void get_dep(int x,int fa){ now[++now[0]]=dep[x]; for (int j=lnk[x];j;j=nxt[j]) if (fa!=son[j]&&!vis[son[j]]){ dep[son[j]]=dep[x]+w[j]; get_dep(son[j],x); } } int get_sum(int x,int dst){ now[0]=0;int res=0; dep[x]=dst;get_dep(x,0); sort(now+1,now+1+now[0]); for (int i=1,j=now[0];i<=now[0];i++){ while (j>i&&now[i]+now[j]>k) j--; res+=max(0,j-i); } return res; } void get_ans(int x){ vis[x]=1;ans+=get_sum(x,0); for (int j=lnk[x];j;j=nxt[j]) if (!vis[son[j]]){ ans-=get_sum(son[j],w[j]); hvy=0;S=siz[son[j]];get_hvy(son[j],0); get_ans(hvy); } } int main(){ for (n=red(),k=red();n||k;n=red(),k=red()){ memset(vis,0,sizeof(vis)); memset(lnk,0,sizeof(lnk));tot=0; for (int i=1,x,y,z;i<n;i++) x=red(),y=red(),z=red(),add(x,y,z),add(y,x,z); hvy=0;S=n;get_hvy(1,0); ans=0;get_ans(hvy); printf("%d\n",ans); } return 0; }附:点分治相关题目BZOJ2152
