测试地址:动态逆序对 做法:本人这几天学习了CDQ分治思想,感觉还是比较难懂,于是找到了比较好理解的经典应用——三维偏序问题来加深理解。 这题首先需要把问题转化为三维偏序问题,然后再使用CDQ分治解决。 首先这个题目是将元素一个一个删除,在每次删除之前询问逆序对数,从这个方面来看好像无法下手,那么我们不如反过来,看成是将元素一个一个插入,在每次插入之后询问逆序对数。那么每个元素我们就可以使用一个三维坐标 (xi,yi,zi) 来表示,其中 xi 指元素的插入时间(以插入先后顺序标号为1~M,一开始就在的标号为0), yi 指元素在排列中的位置, zi 指元素的值。那么对于一个点 (xi,yi,zi) ,如果存在 newi 个点 (xj,yj,zj) 使得 xi≤xj , yi≤yj 且 zi≥zj 或 xi≤xj , yi≥yj 且 zi≤zj ,那么在第 xi 次插入之后逆序对数就会增加 newi 个(想一想,为什么?)。于是我们就得到了一个变形的三维偏序问题,我们需要想办法求出所有的 newi 。 由于N达到100000,所以 O(N2) 的暴力是绝对炸的。网上有人讲解三维偏序问题时说了一句精辟的话:一维排序,二维分治,三维数据结构。按照这个思路,我们首先把所有点按 x 从小到大排序,重新标号为1~N,然后分治。这里使用的分治方法是CDQ分治,CDQ分治是一种思想,包含递归处理左半、处理左半对右半的影响、递归处理右半三个步骤。以下只考虑怎么处理左半对右半的影响。 假设我们在处理一个区间[l,r],这个区间的中点为 mid ,那么首先分别对于区间 [l,mid] 和 [mid+1,r] 按 y 从小到大排序,因为x已经有序了我们就不管 x ,我们对于右半区间的点一个一个处理影响。因为两边的y都是有序的,那么我们就只需要在左边指一个只会往右的指针,设这个指针当前指到 i ,而右边我们正在处理的点为j,如果 yi≤yj ,那么就在计数数组里的 zi 位置增加1,然后 i 自增1,一直到i>mid或者 yi>yj 为止。然后我们再求计数数组中 ≥zj 的所有位置之和,就可以得到左半区间对点 j 做出的贡献,将其加入newj即可。我们注意到对于计数数组的修改涉及单点修改和区间求和,这个我们可以用代码量小的树状数组解决。以上我们就处理完了一种情况,而另一种情况是类似的,这里就不再赘述了。 经过证明,以上方法的时间复杂度为 O(Nlog2N) ,可以通过全部数据。注意每次处理完后清空树状数组时不要鲁莽地使用memset,会TLE,应该按照原来的顺序再把加上的东西都给减掉。除此之外,要注意排序和处理的顺序,因为有时排序会破坏掉原来的顺序。 以下是本人代码:
#include <cstdio> #include <cstdlib> #include <cstring> #include <iostream> #include <algorithm> #define ll long long using namespace std; int n,m,pos[100010]={0}; ll ans[50010]={0},bit[100010]={0}; struct point3D { int x,y,z,id; }p[100010]; bool cmpx(point3D a,point3D b) {return a.x<b.x;} bool cmpy1(point3D a,point3D b) {return a.y<b.y;} bool cmpy2(point3D a,point3D b) {return a.y>b.y;} bool cmpid(point3D a,point3D b) {return a.id<b.id;} int lowbit(int x) { return x&(-x); } void add(int x,ll d) { for(int i=x;i<=n;i+=lowbit(i)) bit[i]+=d; } ll query(int x) { ll s=0; while(x) { s+=bit[x]; x-=lowbit(x); } return s; } ll sum(int l,int r) { return query(r)-query(l-1); } void solve(int l,int r) { int mid=(l+r)>>1; if (l==r) return; solve(l,mid); int h; sort(p+l,p+mid+1,cmpy1); sort(p+mid+1,p+r+1,cmpy1); h=l; for(int i=mid+1;i<=r;i++) { while(h<=mid&&p[h].y<=p[i].y) add(p[h].z,1),h++; ans[m-p[i].x+1]+=sum(p[i].z,n); } for(int i=l;i<h;i++) add(p[i].z,-1); sort(p+l,p+mid+1,cmpy2); sort(p+mid+1,p+r+1,cmpy2); h=l; for(int i=mid+1;i<=r;i++) { while(h<=mid&&p[h].y>=p[i].y) add(p[h].z,1),h++; ans[m-p[i].x+1]+=sum(1,p[i].z); } for(int i=l;i<h;i++) add(p[i].z,-1); sort(p+l+1,p+r+1,cmpid); solve(mid+1,r); } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) { p[i].y=i; scanf("%d",&p[i].z); } for(int i=1;i<=m;i++) { int a; scanf("%d",&a); pos[a]=m-i+1; } for(int i=1;i<=n;i++) p[i].x=pos[p[i].z]; sort(p+1,p+n+1,cmpx); for(int i=1;i<=n;i++) p[i].id=i; solve(1,n); for(int i=m;i>=1;i--) ans[i]+=ans[i+1]; for(int i=1;i<=m;i++) printf("%lld\n",ans[i]); return 0; }