Codeforces Round #430 (Div. 2) E. Nikita and game

xiaoxiao2021-02-28  112

一条直径有两端 考虑把直径的端点分为两部分(被直径中点分开) 那么只要维护两端的直径端点就可以了

当加入一个新点的时候检查是否更新了直径

如果更新了直径那它就会成为直径一端的唯一一个点 然后看他是到左边的那些端点长还是到右边那些端点长 假如是左边那左边的点都是另一端的点 再检查一下到右边的点的距离有没有跟新直径一样的 将它加入另一端的集合

如果没更新则检查最大距离是否与直径相同 相同则加入相应一端的集合中

#include <iostream> #include <algorithm> #include <sstream> #include <string> #include <queue> #include <cstdio> #include <map> #include <set> #include <utility> #include <stack> #include <cstring> #include <cmath> #include <vector> #include <ctime> #include <bitset> using namespace std; #define pb push_back #define sd(n) scanf("%d",&n) #define sdd(n,m) scanf("%d%d",&n,&m) #define sddd(n,m,k) scanf("%d%d%d",&n,&m,&k) #define sld(n) scanf("%lld",&n) #define sldd(n,m) scanf("%lld%lld",&n,&m) #define slddd(n,m,k) scanf("%lld%lld%lld",&n,&m,&k) #define sf(n) scanf("%lf",&n) #define sff(n,m) scanf("%lf%lf",&n,&m) #define sfff(n,m,k) scanf("%lf%lf%lf",&n,&m,&k) #define ss(str) scanf("%s",str) #define ans() printf("%d",ans) #define ansn() printf("%d\n",ans) #define anss() printf("%d ",ans) #define lans() printf("%lld",ans) #define lanss() printf("%lld ",ans) #define lansn() printf("%lld\n",ans) #define fansn() printf("%.10f\n",ans) #define r0(i,n) for(int i=0;i<(n);++i) #define r1(i,e) for(int i=1;i<=e;++i) #define rn(i,e) for(int i=e;i>=1;--i) #define rsz(i,v) for(int i=0;i<(int)v.size();++i) #define szz(x) ((int)x.size()) #define mst(abc,bca) memset(abc,bca,sizeof abc) #define lowbit(a) (a&(-a)) #define all(a) a.begin(),a.end() #define pii pair<int,int> #define pli pair<ll,int> #define mp make_pair #define lrt rt<<1 #define rrt rt<<1|1 #define X first #define Y second #define PI (acos(-1.0)) #define sqr(a) ((a)*(a)) typedef long long ll; typedef unsigned long long ull; const int mod = 1000000000+7; const double eps=1e-9; const int inf=0x3f3f3f3f; const ll infl = 10000000000000000; const int maxn= 300000+10; const int maxm = 100000+10; int in(int &ret) { char c; int sgn ; if(c=getchar(),c==EOF)return -1; while(c!='-'&&(c<'0'||c>'9'))c=getchar(); sgn = (c=='-')?-1:1; ret = (c=='-')?0:(c-'0'); while(c=getchar(),c>='0'&&c<='9')ret = ret*10+(c-'0'); ret *=sgn; return 1; } int fa[maxn][20]; int dep[maxn]; int lca(int a,int b) { int res = 0; if(dep[a]<dep[b]) swap(a,b); for(int i=18; ~i; --i) if(dep[a]-(1<<i) >= dep[b]) { a = fa[a][i]; res += 1<<i; } if(a==b)return res; for(int i=18;~i;--i) { if(fa[a][i]!=fa[b][i]) { res+= 1<< (i+1); a = fa[a][i]; b = fa[b][i]; } } return res+2; } int main() { #ifdef LOCAL freopen("input.txt","r",stdin); // freopen("output.txt","w",stdout); #endif // LOCAL int n; sd(n); set<int>s1,s2; s1.insert(1); dep[1] = 1; int mx = 1; for(int v = 2; v<=n+1; ++v) { int u ; sd(u); fa[v][0] = u; dep[v] = dep[u]+1; for(int i=1; (1<<i)<dep[v]; i++)fa[v][i] = fa[ fa[v][i-1] ][i-1]; int dis1 = s1.empty()?0:lca(v,*s1.begin()); int dis2 = s2.empty()?0:lca(v,*s2.begin()); if(max(dis1,dis2)>mx) { mx = max(dis1,dis2); if(dis1==mx) { for(set<int>::iterator it = s2.begin();it!=s2.end();++it) { if(lca(*it,v)==mx)s1.insert(*it); } s2.clear(); s2.insert(v); } else { for(set<int>::iterator it = s1.begin();it!=s1.end();++it) { if(lca(*it,v)==mx)s2.insert(*it); } s1.clear(); s1.insert(v); } } else if(max(dis1,dis2)==mx) { if(dis1>=dis2)s2.insert(v); else s1.insert(v); } int ans = s1.size()+s2.size(); ansn(); } return 0; }
转载请注明原文地址: https://www.6miu.com/read-61422.html

最新回复(0)