POJ 1741 Tree 树分治

xiaoxiao2021-02-28  57

Tree    Time Limit: 1000MS    Memory Limit: 30000K    Total Submissions: 27073    Accepted: 9003

Description

Given a tree with n vertices, where each edge has a length (a positive integer less than 1001). Define dist(u,v) as the minimum distance between nodes u and v. Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k. Write a program that counts how many valid pairs a given tree has.

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v. The last test case is followed by two zeros.

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

Source

LouTiancheng@POJ

给定一棵树,每条边有距离。

求节点间距离不超过k(即 dist(u,v) ≤ k)的无序点对数量。

点分治(基于树的重心的分治)。对于每个分治点,分别求经过这个点的、满足条件的路径数量;每次选取当前连通块的重心作为分治点,可保证递归深度为 O(log n)。

具体而言,对于一个固定的节点(当前连通块的重心),先 dfs 求出所有未访问节点到当前节点的距离并排序,再用双指针在 O(n) 时间内(排序本身为 O(n log n))求出所有距离和 <= k 的点对数量。此时多算了起点和终点都在同一棵子树内的点对——它们之间的真实路径并不经过当前节点。为了消除这一部分,对当前节点的每个子节点,以连接它的边长作为初始距离再做一次 dfs,用同样的方法统计该子树内距离和 <= k 的点对数量,并从答案中减去。

// POJ 1741 "Tree": for each test case, count the unordered vertex pairs (u, v)
// whose tree distance is at most k, using centroid (point) decomposition.
#include <cstdio>
#include <cstring>
#include <algorithm>

const int maxn = 200005;
const int inf = 0x3f3f3f3f;

// One directed half of an undirected edge, stored in a singly linked
// adjacency list: head[u] is the most recently added edge out of u (-1 if
// none) and `pre` links to the edge added before it.
struct Edge {
    int from, to, pre, dist;
};

Edge edge[maxn * 2];
int head[maxn];
// Subtree sizes from the last getroot() pass. Deliberately not called `size`:
// a global `size` combined with `using namespace std` is ambiguous with
// std::size in C++17 and fails to compile.
int sz[maxn];
int ms[maxn];      // largest piece left if the node were removed (getroot)
int d[maxn];       // distances gathered by dfs(), stored in d[1..cnt]
bool visit[maxn];  // true once a node has been used as a centroid
int num = 0;       // number of Edge slots used so far
int root = -1;     // centroid found by the latest getroot() call
int rs = inf;      // ms[root]: the smallest "largest piece" seen by getroot()
int sum;           // node count of the component getroot() is searching
int k;             // distance limit of the current test case
int ans;           // answer accumulator for the current test case
int cnt = 0;       // number of entries written to d[] by dfs()

// Adds the undirected edge (from, to) of length `dist` as two directed halves.
void addedge(int from, int to, int dist) {
    edge[num] = Edge{from, to, head[from], dist};
    head[from] = num++;
    edge[num] = Edge{to, from, head[to], dist};
    head[to] = num++;
}

// Computes subtree sizes of the unvisited component containing `now` (rooted
// at `now`; `fa` is its parent, or 0 for none) and records in root/rs the node
// whose largest remaining piece is smallest, i.e. a centroid.
// Preconditions: `sum` holds the component's node count and the caller has
// reset root = -1, rs = inf.
void getroot(int now, int fa) {
    sz[now] = ms[now] = 0;
    for (int i = head[now]; i != -1; i = edge[i].pre) {
        int to = edge[i].to;
        if (!visit[to] && to != fa) {
            getroot(to, now);
            sz[now] += sz[to];
            ms[now] = std::max(ms[now], sz[to]);
        }
    }
    sz[now]++;
    // Removing `now` also leaves the part of the component "above" it.
    ms[now] = std::max(ms[now], sum - sz[now]);
    if (ms[now] < rs) {
        root = now;
        rs = ms[now];
    }
}

// Returns the number of unvisited nodes reachable from `now` without passing
// through `fa` (0 for none).
int compsize(int now, int fa) {
    int total = 1;
    for (int i = head[now]; i != -1; i = edge[i].pre) {
        int to = edge[i].to;
        if (!visit[to] && to != fa) total += compsize(to, now);
    }
    return total;
}

// Appends `dis` for `now` to d[], then recurses into unvisited neighbours,
// adding each edge's length to the running distance.
void dfs(int now, int fa, int dis) {
    d[++cnt] = dis;
    for (int i = head[now]; i != -1; i = edge[i].pre) {
        int to = edge[i].to;
        if (!visit[to] && to != fa) dfs(to, now, dis + edge[i].dist);
    }
}

// Collects the distances of all unvisited nodes reachable from `now` (with
// `now` itself at distance `dis`) and returns how many unordered pairs of them
// have a distance sum <= k. After sorting, two pointers close in from both
// ends: whenever d[l] + d[r] <= k, d[l] pairs with every d[l+1..r].
int cal(int now, int dis) {
    cnt = 0;
    dfs(now, 0, dis);
    std::sort(d + 1, d + cnt + 1);
    int pairs = 0;
    int l = 1, r = cnt;
    while (l < r) {
        if (d[l] + d[r] <= k) {
            pairs += r - l;
            l++;
        } else {
            r--;
        }
    }
    return pairs;
}

// Adds to `ans` every valid pair whose path passes through the centroid `now`,
// then marks `now` as removed and recurses into each remaining sub-component.
void solve(int now) {
    visit[now] = 1;
    ans += cal(now, 0);
    for (int i = head[now]; i != -1; i = edge[i].pre) {
        int to = edge[i].to;
        // Pairs inside one child subtree were counted above, but their real
        // path does not go through `now`; subtract them.
        if (!visit[to]) ans -= cal(to, edge[i].dist);
    }
    for (int i = head[now]; i != -1; i = edge[i].pre) {
        int to = edge[i].to;
        if (!visit[to]) {
            // sz[to] from the previous getroot() pass is wrong when `to` was
            // the parent of `now` there, so count the component explicitly.
            sum = compsize(to, 0);
            root = -1;
            rs = inf;
            getroot(to, 0);
            solve(root);
        }
    }
}

// Reads test cases "n k" followed by n-1 edges "u v l" until "0 0" (or end of
// input) and prints the number of valid pairs for each test case.
int main() {
    int n, x, y, z;
    while (scanf("%d%d", &n, &k) == 2 && (n || k)) {
        num = ans = 0;
        memset(head, -1, sizeof(head));
        memset(visit, 0, sizeof(visit));
        for (int i = 1; i < n; i++) {
            if (scanf("%d%d%d", &x, &y, &z) != 3) return 0;
            addedge(x, y, z);
        }
        // Reset the centroid search: stale root/rs from the previous test
        // case could otherwise leave `root` at a node outside this tree.
        sum = n;
        root = -1;
        rs = inf;
        getroot(1, 0);
        solve(root);
        printf("%d\n", ans);
    }
    return 0;
}

转载请注明原文地址: https://www.6miu.com/read-2619662.html

最新回复(0)