这是一道点分治模板题,由于写题要用到,我就来学了学点分治,然鹅这段时间
POJ
在维护,所以这里就不给出题目的链接了。╰( ̄▽ ̄)╭
Description
给定一棵
n
个节点的树,每条边都带有一个权值,问两点之间距离小于等于k的点对有多少个。
(n≤3∗104,k<108,w≤103)
Solution
看了一眼题目,立马先想暴力,枚举每个出发点
dfs
一遍,然后默默地看了一眼数据范围,
O(n2)
的暴力在
3∗104
的数据规模下无能为力啊。(╯°Д°)╯
于是就要用到树的点分治了。点分治,顾名思义就是把树上的点分开来处理咯。处理当前子树时, 路径有两种:1、穿过根的;2、没穿过根的。我们只要把子树内所有节点按
dep
排个序,然后用两个指针
i和j
分别从两头往中间移,直到
dep[i]+dep[j]≤k
,这样比
j
小的点也都能与i组成合法的点对,又要防止重复,我们只要将
ans
加上
j−i
就行了,即使得
dep[i]≤dep[j]
。然而这样会把树内所有的点对都算上,包括路径没有穿过根的点对,也就是说这些点在子树中还会再被算一遍,我们为了只算路径穿过根的点对,要把子树中的点对先删去,于是计算的问题就搞定了。
下面的问题就是如何分治了,为了防止复杂度到达极端的
O(n2)
,我们每次分治到下一层不能随意地找根,我们要时分开后的树尽可能小,这样就要用到树的重心了,因为以树的重心为根时最大的子树的
size
最小。对于求树的重心的方法,只要
dfs
一遍存下每棵子树的
size
,然后求出最大子树(还有子树外的点)最小的点即是树的重心了。
于是整个算法的流程就是:
1.找到子树的重心;2.以重心为根处理整棵子树;3.找到子树的子树的重心……以此类推。
Code
#include<vector>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<iostream>
#define N 30005
using namespace std;
template <
class T>
inline void Rd(T &res){
char c;res=
0;
int k=
1;
while(c=getchar(),c<
48&&c!=
'-');
if(c==
'-'){k=-
1;c=
'0';}
do{
res=(res<<
3)+(res<<
1)+(c^
48);
}
while(c=getchar(),c>=
48);
res*=k;
}
template <
class T>
inline void Pt(T res){
if(res<
0){
putchar(
'-');
res=-res;
}
if(res>=
10)Pt(res/
10);
putchar(res%
10+
48);
}
struct edge{
int v,nxt,w;
}e[N<<
1];
int n,k;
int head[N],edgecnt;
struct P1{
vector<int>G;
int ans;
int sz[N],dis[N],siz,root,SZ;
bool vis[N];
void find(
int x,
int t){
int tmp=
0;
sz[x]=
1;
for(
int i=head[x];~i;i=e[i].nxt){
int v=e[i].v;
if(v!=t&&!vis[v]){
find(v,x);
sz[x]+=sz[v];
if(tmp<sz[v])tmp=sz[v];
}
}
if(tmp<SZ-sz[x])tmp=SZ-sz[x];
if(tmp<siz){
siz=tmp;
root=x;
}
}
void dfs(
int x,
int t){
G.push_back(dis[x]);
for(
int i=head[x];~i;i=e[i].nxt){
int v=e[i].v;
if(v!=t&&!vis[v]){
dis[v]=dis[x]+e[i].w;
dfs(v,x);
}
}
}
int calc(
int x,
int d){
G.clear();dis[x]=d;
dfs(x,-
1);
sort(G.begin(),G.end());
int res=
0;
for(
int l=
0,r=G.size()-
1;l<r;)
if(G[l]+G[r]<=k)res+=r-l++;
else r--;
return res;
}
void work(
int x){
ans+=calc(x,
0);
vis[x]=
1;
for(
int i=head[x];~i;i=e[i].nxt){
int v=e[i].v;
if(!vis[v]){
ans-=calc(v,e[i].w);
SZ=sz[v];
siz=
1e9;
find(v,-
1);
work(root);
}
}
}
void solve(){
memset(vis,
0,
sizeof(vis));
ans=
0;
SZ=n;
siz=
1e9;
find(
1,-
1);
work(root);
Pt(ans);
putchar(
'\n');
}
}P1;
void add_edge(
int u,
int v,
int w){
e[edgecnt]=(edge){v,head[u],w};head[u]=edgecnt++;
}
int main(){
memset(head,-
1,
sizeof(head));
int a,b,c;
Rd(n);Rd(k);
for(
int i=
1;i<n;i++){
Rd(a);Rd(b);Rd(c);
add_edge(a,b,c);
add_edge(b,a,c);
}
P1.solve();
return 0;
}