bzoj4825洛谷P3721 单旋 splay

xiaoxiao2021-02-28  58

题目分析

有人问起我学会的第一个平衡树是什么。

我说是spaly。

在HNOI2017的考场上学会的。

俗话说的好,双旋的splay,单旋的spaly,不旋的saply,O(1)的asply,那么我们就来用splay做一做这道题。

首先我们手模一发单旋最小值操作。会发现,假如最小值节点是x,那么这个操作就是把x放到根,x的右子树给他原来的父亲当左子树,把原来的根节点给它做右子树。

思考思考就会发现,x的右子树的dfs序应该是连续的,准确的说,以x为根的子树应该是spaly的dfs序从左边开始的一段连续的区间,且这个区间里的所有节点的深度都要大于等于x的深度。

现在我们用一棵splay来维护,splay中节点的顺序就是按照权值排序,然后维护一下每个节点的dep值(深度)和其子树里的最小深度,然后每种操作的方法如下:

1.插入:我们寻找x的前驱和后继,发现要么前驱是后继的父亲,要么后继是前驱的父亲(因为前驱和后继的dfs序一定相邻,所以这两个节点一定相邻),所以新加入节点的深度就是max(dep(前驱),dep(后继))+1。除此之外,就简单地将新节点插入splay中即可。

2.单旋最小/大值:找到x的右/左子树代表的区间长度,首先将所有节点的dep +1,然后将x的右/左子树的节点 dep -1,然后再将x的dep单点赋值成1.

3.删除:删除x节点,将所有节点的深度都-1

这道题最难的地方,果然还是细节处理。我调试了两个半小时。由于每个人的写法不同,不予赘述我错了哪些细节,附赠一个丑丑的数据生成器,加油对拍吧。

代码

#include<bits/stdc++.h> using namespace std; int read() { int q=0;char ch=' '; while(ch<'0'||ch>'9') ch=getchar(); while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar(); return q; } const int N=100005,inf=0x3f3f3f3f; int m,rt,n; int s[N][2],f[N],dep[N],v[N],laz[N],mn[N],sz[N]; void up(int x) { sz[x]=sz[s[x][0]]+sz[s[x][1]]+1; mn[x]=min(min(mn[s[x][0]],mn[s[x][1]]),dep[x]); } void pd(int x) { if(!laz[x]) return; int ls=s[x][0],rs=s[x][1],t=laz[x]; if(ls) dep[ls]+=t,mn[ls]+=t,laz[ls]+=t; if(rs) dep[rs]+=t,mn[rs]+=t,laz[rs]+=t; laz[x]=0; } int is(int x) {return s[f[x]][1]==x;} void spin(int x,int &mb) { int fa=f[x],g=f[fa],t=is(x); if(f[x]==mb) mb=x; else s[g][is(fa)]=x; f[x]=g,f[fa]=x,f[s[x][t^1]]=fa; s[fa][t]=s[x][t^1],s[x][t^1]=fa; up(fa),up(x); } void splay(int x,int &mb) { while(x!=mb) { if(f[x]!=mb) { if(is(x)^is(f[x])) spin(x,mb); else spin(f[x],mb); } spin(x,mb); } } int find(int x,int num) {//寻找dfs序第num的节点 pd(x); if(sz[s[x][0]]+1==num) return x; if(sz[s[x][0]]>=num) return find(s[x][0],num); else return find(s[x][1],num-sz[s[x][0]]-1); } void add(int l,int r,int num) {//区间加 int x=find(rt,l-1),y=find(rt,r+1); splay(x,rt),splay(y,s[x][1]); laz[s[y][0]]+=num,dep[s[y][0]]+=num,mn[s[y][0]]+=num; up(y),up(x);//注意pushup } int pre(int x,int num) {//前驱 if(!x) return 0; pd(x); if(v[x]<num) {int kl=pre(s[x][1],num);return kl?kl:x;} else return pre(s[x][0],num); } int nxt(int x,int num) {//后继 if(!x) return 0; pd(x); if(v[x]>num) {int kl=nxt(s[x][0],num);return kl?kl:x;} else return nxt(s[x][1],num); } void ins(int &x,int num,int d,int las) {//插入 if(!x) {x=++n,f[x]=las,v[x]=num,dep[x]=mn[x]=d,sz[x]=1;return;} pd(x); if(num<v[x]) ins(s[x][0],num,d,x); else ins(s[x][1],num,d,x); up(x); } int getl(int x,int num) {//获得从左边开始的连续的dep[x]>=num的区间长度 if(!x) return 0; pd(x); if(dep[x]>=num&&mn[s[x][0]]>=num) return sz[s[x][0]]+1+getl(s[x][1],num); else return getl(s[x][0],num); } int getr(int x,int num) {//获得从右边开始的连续的dep[x]>=num的区间长度 if(!x) return 0; pd(x); if(dep[x]>=num&&mn[s[x][1]]>=num) return sz[s[x][1]]+1+getr(s[x][0],num); else return getr(s[x][1],num); } void chan(int x,int num) {//单点修改 pd(x); if(v[x]==num) {mn[x]=dep[x]=1;return;} if(num<v[x]) chan(s[x][0],num); else chan(s[x][1],num); up(x); } void del(int x) {//删除 splay(x,rt); if(s[x][0]*s[x][1]==0) rt=s[x][0]+s[x][1],f[rt]=0; else { int y=s[x][1]; while(s[y][0]) pd(y),y=s[y][0]; s[y][0]=s[x][0],f[s[x][0]]=y,rt=s[x][1],f[rt]=0; while(y) up(y),y=f[y];//记得pushup } } int main() { int x,y; m=read(); mn[0]=inf,ins(rt,-inf,inf,0),ins(rt,inf,inf,0); while(m--) { int bj=read(); if(bj==1) { x=read();int a=pre(rt,x),b=nxt(rt,x); a=((a==1||a==2)?0:dep[a]),b=((b==1||b==2)?0:dep[b]); printf("%d\n",max(a,b)+1); ins(rt,x,max(a,b)+1,0),splay(n,rt);//这个splay用于维护平衡 } if(bj==2||bj==4) { x=find(rt,2),printf("%d\n",dep[x]); y=min(getl(rt,dep[x]),sz[rt]-1); add(2,sz[rt]-1,1),add(2,y,-1); chan(rt,v[x]); } if(bj==3||bj==5) { x=find(rt,sz[rt]-1),printf("%d\n",dep[x]); y=min(getr(rt,dep[x]),sz[rt]-1); add(2,sz[rt]-1,1),add(sz[rt]-y+1,sz[rt]-1,-1); chan(rt,v[x]); } if(bj==4||bj==5) del(x),add(2,sz[rt]-1,-1); } return 0; }

数据生成器

#include<bits/stdc++.h> using namespace std; int a[100005],js,n; void ins() { ++js; int x=rand()%20+1; while(a[x]) x=rand()%20+1; a[x]=1;printf("1 %d\n",x); } int main() { srand(time(NULL)); n=rand()%10+1,printf("%d\n",n); while(n--) { if(!js) ins(); else { int bj=rand()%5+1; if(bj==1) ins(); else printf("%d\n",bj); if(bj==4||bj==5) --js; } } return 0; }
转载请注明原文地址: https://www.6miu.com/read-2629920.html

最新回复(0)