前置技能: LCA (Tarjan)
今年冬令营上讲的东西现在才学 惭愧惭愧。。 首先做出图G的dfs树T 其中S为起点 定义T中节点 u 的半必经点sdom[u] 为 u的祖先中能经过若干条非树边到达u的(除了sdom[u]直接引出的边 其余边必须为非树边) 深度最小的节点 则对于任意一条边 (v->u) 这里我们默认sdom idom 记录的是dfn 则有
1.dfn[v]<dfn[u] 即(v->u)为一条树边或前向边 有sdom[u]=Min(dfn[u],sdom[u]); 2.dfn[v]>dfn[u] 即(v->u)为横插边或反组边 则有sdom[u]=Min(sdom[u],sdom[x]),x为v->LCA(v,u)的路径节点中sdom最小的节点以上的式子我们可以利用tarjan的LCA 从dfn为n的节点倒推来做 对于idom我们怎么求呢?
令x为u->sdom[u]路径中sdom最小的节点
1. sdom[u]=sdom[x] idom[u]=sdom[x] 2. sdom[u]!= sdom[x] idom[u]=idom[sdom[u]]同样x的求解过程我们也可以用tarjan的lca求 至此问题完美解决
至于证明。。 sdom的求解可以感性理解一下 idom嘛。。。我也不会
HDU4694 就是很裸的Dominator Tree了。。 就当是模板练习
#include<cstdio> #include<iostream> #include<cstring> #include<cstdlib> using namespace std; char c; inline bool read(int&a) {a=0;do c=getchar();while(c!=EOF&&(c<'0'||c>'9'));if(c==EOF)return false;while(c<='9'&&c>='0')a=(a<<3)+(a<<1)+c-'0',c=getchar();return true;} const int N=500001; struct Chain { Chain*next; int u; }*Head[N],*Head2[N],*Head3[N]; inline void Add(Chain**Head,int u,int v) { Chain*tp=new Chain;tp->next=Head[u];Head[u]=tp;tp->u=v; } int mn[N],dfn[N],idom[N],sdom[N],id[N],fa[N],f[N]; int cnt; int find(int x) { if(f[x]==x)return x; int y=find(f[x]); if(sdom[mn[x]]>sdom[mn[f[x]]])mn[x]=mn[f[x]]; return f[x]=y; } void dfs(int u) { id[dfn[u]=++cnt]=u; for(Chain*tp=Head[u];tp;tp=tp->next) if(!dfn[tp->u])dfs(tp->u),fa[dfn[tp->u]]=dfn[u]; } int n,m; inline void tarjan(int s) { for(int i=1;i<=n;i++)f[i]=sdom[i]=mn[i]=fa[i]=i,dfn[i]=0; cnt=0; dfs(s); int k,x; for(int i=cnt;i>1;i--) { for(Chain*tp=Head2[id[i]];tp;tp=tp->next) if(dfn[tp->u]) find(k=dfn[tp->u]),sdom[i]= sdom[i]<sdom[mn[k]]?sdom[i]:sdom[mn[k]]; Add(Head3,sdom[i],i); for(Chain*tp=Head3[f[i]=x=fa[i]];tp;tp=tp->next) find(k=tp->u),idom[k] = sdom[mn[k]] < x?mn[k]:x; Head3[x]=NULL; } for(int i=2;i<=cnt;Add(Head3,id[idom[i]],id[i]),i++) { if(idom[i]!=sdom[i])idom[i]=idom[idom[i]]; //if(idom[i]==i)puts("WA"); } } int Ans[N]; void Mp(int u,int p) { Ans[u]=p; for(Chain*tp=Head3[u];tp;tp=tp->next) Mp(tp->u,p+tp->u); } void out(int x) {if(!x){putchar('0');return ;}if(x>9)out(x/10);putchar('0'+x%10);} int main() { while(true) { cnt=0; if(!read(n))break;read(m); memset(Head,0,sizeof(Head)); memset(Head2,0,sizeof(Head2)); memset(Head3,0,sizeof(Head3)); for(int i=1;i<=m;i++){int a,b;read(a),read(b);Add(Head,a,b),Add(Head2,b,a);} tarjan(n); Mp(n,n); for(int i=1;i<=n;i++) out(Ans[i]),putchar(i==n?'\n':' '),Ans[i]=0; } return 0; }