题目分析
有人问起我学会的第一个平衡树是什么。
我说是spaly。
在HNOI2017的考场上学会的。
俗话说的好,双旋的splay,单旋的spaly,不旋的saply,O(1)的asply,那么我们就来用splay做一做这道题。
首先我们手模一发单旋最小值操作。会发现,假如最小值节点是x,那么这个操作就是把x放到根,x的右子树给他原来的父亲当左子树,把原来的根节点给它做右子树。
思考思考就会发现,x的右子树的dfs序应该是连续的,准确的说,以x为根的子树应该是spaly的dfs序从左边开始的一段连续的区间,且这个区间里的所有节点的深度都要大于等于x的深度。
现在我们用一棵splay来维护,splay中节点的顺序就是按照权值排序,然后维护一下每个节点的dep值(深度)和其子树里的最小深度,然后每种操作的方法如下:
1.插入:我们寻找x的前驱和后继,发现要么前驱是后继的父亲,要么后继是前驱的父亲(因为前驱和后继的dfs序一定相邻,所以这两个节点一定相邻),所以新加入节点的深度就是max(dep(前驱),dep(后继))+1。除此之外,就简单地将新节点插入splay中即可。
2.单旋最小/大值:找到x的右/左子树代表的区间长度,首先将所有节点的dep +1,然后将x的右/左子树的节点 dep -1,然后再将x的dep单点赋值成1.
3.删除:删除x节点,将所有节点的深度都-1
这道题最难的地方,果然还是细节处理。我调试了两个半小时。由于每个人的写法不同,不予赘述我错了哪些细节,附赠一个丑丑的数据生成器,加油对拍吧。
代码
#include<bits/stdc++.h>
using namespace std
;
int read() {
int q
=0;char ch
=' ';
while(ch
<'0'||ch
>'9') ch
=getchar();
while(ch
>='0'&&ch
<='9') q
=q
*10+ch
-'0',ch
=getchar();
return q
;
}
const int N
=100005,inf
=0x3f3f3f3f;
int m
,rt
,n
;
int s
[N
][2],f
[N
],dep
[N
],v
[N
],laz
[N
],mn
[N
],sz
[N
];
void up(int x
) {
sz
[x
]=sz
[s
[x
][0]]+sz
[s
[x
][1]]+1;
mn
[x
]=min(min(mn
[s
[x
][0]],mn
[s
[x
][1]]),dep
[x
]);
}
void pd(int x
) {
if(!laz
[x
]) return;
int ls
=s
[x
][0],rs
=s
[x
][1],t
=laz
[x
];
if(ls
) dep
[ls
]+=t
,mn
[ls
]+=t
,laz
[ls
]+=t
;
if(rs
) dep
[rs
]+=t
,mn
[rs
]+=t
,laz
[rs
]+=t
;
laz
[x
]=0;
}
int is(int x
) {return s
[f
[x
]][1]==x
;}
void spin(int x
,int &mb
) {
int fa
=f
[x
],g
=f
[fa
],t
=is(x
);
if(f
[x
]==mb
) mb
=x
;
else s
[g
][is(fa
)]=x
;
f
[x
]=g
,f
[fa
]=x
,f
[s
[x
][t
^1]]=fa
;
s
[fa
][t
]=s
[x
][t
^1],s
[x
][t
^1]=fa
;
up(fa
),up(x
);
}
void splay(int x
,int &mb
) {
while(x
!=mb
) {
if(f
[x
]!=mb
) {
if(is(x
)^is(f
[x
])) spin(x
,mb
);
else spin(f
[x
],mb
);
}
spin(x
,mb
);
}
}
int find(int x
,int num
) {
pd(x
);
if(sz
[s
[x
][0]]+1==num
) return x
;
if(sz
[s
[x
][0]]>=num
) return find(s
[x
][0],num
);
else return find(s
[x
][1],num
-sz
[s
[x
][0]]-1);
}
void add(int l
,int r
,int num
) {
int x
=find(rt
,l
-1),y
=find(rt
,r
+1);
splay(x
,rt
),splay(y
,s
[x
][1]);
laz
[s
[y
][0]]+=num
,dep
[s
[y
][0]]+=num
,mn
[s
[y
][0]]+=num
;
up(y
),up(x
);
}
int pre(int x
,int num
) {
if(!x
) return 0;
pd(x
);
if(v
[x
]<num
) {int kl
=pre(s
[x
][1],num
);return kl
?kl
:x
;}
else return pre(s
[x
][0],num
);
}
int nxt(int x
,int num
) {
if(!x
) return 0;
pd(x
);
if(v
[x
]>num
) {int kl
=nxt(s
[x
][0],num
);return kl
?kl
:x
;}
else return nxt(s
[x
][1],num
);
}
void ins(int &x
,int num
,int d
,int las
) {
if(!x
) {x
=++n
,f
[x
]=las
,v
[x
]=num
,dep
[x
]=mn
[x
]=d
,sz
[x
]=1;return;}
pd(x
);
if(num
<v
[x
]) ins(s
[x
][0],num
,d
,x
);
else ins(s
[x
][1],num
,d
,x
);
up(x
);
}
int getl(int x
,int num
) {
if(!x
) return 0;
pd(x
);
if(dep
[x
]>=num
&&mn
[s
[x
][0]]>=num
) return sz
[s
[x
][0]]+1+getl(s
[x
][1],num
);
else return getl(s
[x
][0],num
);
}
int getr(int x
,int num
) {
if(!x
) return 0;
pd(x
);
if(dep
[x
]>=num
&&mn
[s
[x
][1]]>=num
) return sz
[s
[x
][1]]+1+getr(s
[x
][0],num
);
else return getr(s
[x
][1],num
);
}
void chan(int x
,int num
) {
pd(x
);
if(v
[x
]==num
) {mn
[x
]=dep
[x
]=1;return;}
if(num
<v
[x
]) chan(s
[x
][0],num
);
else chan(s
[x
][1],num
);
up(x
);
}
void del(int x
) {
splay(x
,rt
);
if(s
[x
][0]*s
[x
][1]==0) rt
=s
[x
][0]+s
[x
][1],f
[rt
]=0;
else {
int y
=s
[x
][1];
while(s
[y
][0]) pd(y
),y
=s
[y
][0];
s
[y
][0]=s
[x
][0],f
[s
[x
][0]]=y
,rt
=s
[x
][1],f
[rt
]=0;
while(y
) up(y
),y
=f
[y
];
}
}
int main()
{
int x
,y
;
m
=read();
mn
[0]=inf
,ins(rt
,-inf
,inf
,0),ins(rt
,inf
,inf
,0);
while(m
--) {
int bj
=read();
if(bj
==1) {
x
=read();int a
=pre(rt
,x
),b
=nxt(rt
,x
);
a
=((a
==1||a
==2)?0:dep
[a
]),b
=((b
==1||b
==2)?0:dep
[b
]);
printf("%d\n",max(a
,b
)+1);
ins(rt
,x
,max(a
,b
)+1,0),splay(n
,rt
);
}
if(bj
==2||bj
==4) {
x
=find(rt
,2),printf("%d\n",dep
[x
]);
y
=min(getl(rt
,dep
[x
]),sz
[rt
]-1);
add(2,sz
[rt
]-1,1),add(2,y
,-1);
chan(rt
,v
[x
]);
}
if(bj
==3||bj
==5) {
x
=find(rt
,sz
[rt
]-1),printf("%d\n",dep
[x
]);
y
=min(getr(rt
,dep
[x
]),sz
[rt
]-1);
add(2,sz
[rt
]-1,1),add(sz
[rt
]-y
+1,sz
[rt
]-1,-1);
chan(rt
,v
[x
]);
}
if(bj
==4||bj
==5) del(x
),add(2,sz
[rt
]-1,-1);
}
return 0;
}
数据生成器
#include<bits/stdc++.h>
using namespace std
;
int a
[100005],js
,n
;
void ins() {
++js
;
int x
=rand()%20+1;
while(a
[x
]) x
=rand()%20+1;
a
[x
]=1;printf("1 %d\n",x
);
}
int main()
{
srand(time(NULL));
n
=rand()%10+1,printf("%d\n",n
);
while(n
--) {
if(!js
) ins();
else {
int bj
=rand()%5+1;
if(bj
==1) ins();
else printf("%d\n",bj
);
if(bj
==4||bj
==5) --js
;
}
}
return 0;
}