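// 思路:对树上每个节点 u,用主席树(可持久化权值线段树)记录从根节点(默认为 1)到 u 的路径上每种边权出现的次数。这样任意两个节点 u、v 之间路径上各边权的出现次数为 cnt[u] + cnt[v] - 2 * cnt[LCA(u, v)](cnt[x] 表示节点 x 对应版本的主席树;LCA 以上的边在 u、v 两侧各被统计一次,因此要减去两次)。每次询问时在这三个版本上同时二分求第 k 小,即可得到路径边权的中位数:路径边数 p 为奇数时取第 p/2+1 小,为偶数时取第 p/2 小与第 p/2+1 小的平均值。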
// 代码来源:http://www.cnblogs.com/fightfordream/p/7159024.html
// Path-median queries on a weighted tree, answered with a persistent
// ("chairman") segment tree built over edge-weight values.
//
// Input, as read below: a test-case count; per case n, then n-1 lines
// "u v w" (undirected edge with int weight w), then q, then q lines "a b".
// For every query the median of the edge weights on the a-b path is printed
// with one decimal place (average of the two middle values when the path has
// an even number of edges).
//
// root[x] is the tree version holding the multiset of weights on the path
// from node 1 to x, so the a-b path multiset is
// root[a] + root[b] - 2 * root[LCA(a, b)].
#include<cstdio>
#include<vector>
#include<string.h>
#include<algorithm>

const int Maxn = 100007;
const int LOG = 20;  // 2^20 > Maxn: enough binary-lifting levels for any depth

struct Edge { int to, nex, w; } edge[Maxn << 1];
// Persistent segment-tree pool; T[0] is never written and acts as the
// shared empty tree (all-zero counts, children pointing back to 0).
struct Node { int Left, Right, sum; } T[Maxn * 40];

int n, cas, cnt, tol;
int root[Maxn], dp[Maxn][LOG + 1], dep[Maxn], fa[Maxn], head[Maxn];
int bfsOrder[Maxn];
int lo_w, hi_w;  // value domain [min weight, max weight] of the current case

// Floor midpoint of [l, r] (l <= r) computed without int overflow, which
// (l + r) >> 1 risks when weights are large or span negative and positive.
static inline int midOf(int l, int r) {
    return l + static_cast<int>((static_cast<long long>(r) - l) / 2);
}

// Adds the undirected edge u-v with weight w as two directed adjacency-list
// entries. Fields are assigned individually because (Edge){...} compound
// literals are a GNU extension, not ISO C++.
void Add(int u, int v, int w) {
    edge[tol].to = v; edge[tol].nex = head[u]; edge[tol].w = w; head[u] = tol++;
    edge[tol].to = u; edge[tol].nex = head[v]; edge[tol].w = w; head[v] = tol++;
}

// Creates a new version equal to version `pre` plus one occurrence of value
// `pos` within [Left, Right]; the new version's root index is stored in `rt`.
// Recursion depth is bounded by log2 of the domain width (at most 33 levels).
void update(int pre, int &rt, int Left, int Right, int pos) {
    T[++cnt] = T[pre];
    T[cnt].sum++;
    rt = cnt;
    if (Left == Right) return;
    int Mid = midOf(Left, Right);
    if (pos <= Mid) update(T[pre].Left, T[rt].Left, Left, Mid, pos);
    else update(T[pre].Right, T[rt].Right, Mid + 1, Right, pos);
}

// Returns the k-th smallest (1-based) value of the multiset
// A + B - 2 * F, where A, B, F are tree versions over the domain [x, y].
// Callers must ensure 1 <= k <= size of that multiset.
int query(int A, int B, int F, int x, int y, int k) {
    while (x < y) {
        int Mid = midOf(x, y);
        int inLeft = T[T[A].Left].sum + T[T[B].Left].sum - 2 * T[T[F].Left].sum;
        if (k <= inLeft) {
            A = T[A].Left; B = T[B].Left; F = T[F].Left;
            y = Mid;
        } else {
            k -= inLeft;
            A = T[A].Right; B = T[B].Right; F = T[F].Right;
            x = Mid + 1;
        }
    }
    return x;
}

// Breadth-first traversal from `rt` that fills, for every reached node,
// its depth, parent, binary-lifting row and persistent-tree version.
// Parents are always processed before children, which is all the table and
// version construction need. It is iterative so that path-shaped trees with
// ~1e5 nodes cannot overflow the call stack, as the recursive dfs could.
void build(int rt) {
    int qh = 0, qt = 0;
    fa[rt] = 0;      // node 0 does not exist, so no neighbour of rt is skipped
    dep[rt] = 0;
    root[rt] = 0;    // empty multiset: no edges above the root
    for (int i = 0; i <= LOG; i++) dp[rt][i] = rt;  // root's ancestor is itself
    bfsOrder[qt++] = rt;
    while (qh < qt) {
        int u = bfsOrder[qh++];
        for (int e = head[u]; e != -1; e = edge[e].nex) {
            int v = edge[e].to;
            if (v == fa[u]) continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            update(root[u], root[v], lo_w, hi_w, edge[e].w);
            dp[v][0] = u;
            for (int i = 1; i <= LOG; i++) dp[v][i] = dp[dp[v][i - 1]][i - 1];
            bfsOrder[qt++] = v;
        }
    }
}

// Lowest common ancestor of x and y via binary lifting over dp/dep.
int LCA(int x, int y) {
    if (dep[x] < dep[y]) std::swap(x, y);
    for (int i = LOG; i >= 0; i--)
        if (dep[dp[x][i]] >= dep[y]) x = dp[x][i];
    if (x == y) return x;
    for (int i = LOG; i >= 0; i--)
        if (dp[x][i] != dp[y][i]) { x = dp[x][i]; y = dp[y][i]; }
    return dp[x][0];
}

int main() {
    if (scanf("%d", &cas) != 1) return 0;
    while (cas--) {
        if (scanf("%d", &n) != 1) break;
        tol = 0;
        cnt = 0;
        memset(head, -1, sizeof(head));
        // dp/dep/fa/root are fully rewritten by build() for every tree node,
        // so the per-case memsets of those tables are unnecessary.
        lo_w = hi_w = 0;
        for (int i = 1; i < n; i++) {
            int u, v, w;
            if (scanf("%d%d%d", &u, &v, &w) != 3) return 0;
            Add(u, v, w);
            // The value domain is [min weight, max weight] instead of a fixed
            // [1, max], so weights <= 0 are counted correctly.
            if (i == 1) { lo_w = w; hi_w = w; }
            else { lo_w = std::min(lo_w, w); hi_w = std::max(hi_w, w); }
        }
        build(1);
        int q = 0;
        if (scanf("%d", &q) != 1) return 0;
        while (q--) {
            int a, b;
            if (scanf("%d%d", &a, &b) != 2) return 0;
            int f = LCA(a, b);
            int p = dep[a] + dep[b] - 2 * dep[f];  // number of edges on the path
            double ans;
            if (p == 0) {
                // NOTE(review): a == b gives an empty path whose median is
                // undefined; the original printed garbage or looped forever
                // (n == 1). 0.0 is a placeholder -- confirm the expected output.
                ans = 0.0;
            } else if (p % 2) {
                ans = static_cast<double>(
                    query(root[a], root[b], root[f], lo_w, hi_w, p / 2 + 1));
            } else {
                double lower = query(root[a], root[b], root[f], lo_w, hi_w, p / 2);
                double upper = query(root[a], root[b], root[f], lo_w, hi_w, p / 2 + 1);
                ans = (lower + upper) / 2.0;
            }
            printf("%.1f\n", ans);
        }
    }
    return 0;
}