HDU6058A Kanade's sum

xiaoxiao2021-02-28  105

题目链接

题意

​ 给定一个长度为 n 的数组 A,用 A[1...n] 表示, A[1...n] 1n 的数一种排列组合。存在一个函数 f(l,r,k) 表示 A[l...r] 中第k大数的值,同时 f(l,r,k)=0 rl+1<k 。给定 k 求解 nl=1nr=lf(l,r,k)

分析

​ 考虑每个数对答案的贡献,显然 A[1...n] 中第 1 到第 k1 大的数对答案的贡献均为0。然后思考剩下的数。对于数 ai ,想要使 ai 为选定 A[l...r] 中第 k 大数,必须要使 A[l...r] 中恰好存在 k1 个大于 ai 的数。如果我们已知大于 ai 的每个数的具体位置,那么就从 ai 的左边选取 t 个比它大的数,再从 ai 的右边选取 kt1 个比它大的数,就能组合成一种可行解。实际选取这样 k1 个数不一定仅有一种解,对于左边,第 t 个数到第 t 1 个数之间的所有位置均是合法的选择,右边类似。设左边选 t 个数的合法区间长度为 lef[t] ,右边选 k1t 个数的合法区间长度为 rig[k1t] ,则总共的可行解数为 lef[t]×rig[k1t] 。最后只要枚举一下t的值就能快速求解了。

​ 新的问题是如何快速确定大于 ai 的每个数的位置,或者说快速逐个搜索 ai 左侧和右侧比 ai 大的数的位置。由于不需要利用到比 ai 小的数,不妨从大到小枚举 1n 中的每一个数,处理完后将对应数的位置放入某个集合中,这样每次查询这个集合时,必然都是比 ai 大的数。然而集合中快速搜索最近位置的点的复杂度为 O(log(n)) ,加上枚举的复杂度,总复杂度达到了 O(nklog(n)) 。赛时尝试了一发,不出意料的T掉了。然后考虑到优化搜索过程,不难发现搜索时都是从 ai 的位置向左或者向右逐个查询,于是想到用链表的方式保存左边和右边最近点位置。通过链表的指针就能在枚举的过程中快速搜索下一个点的位置了。插入新位置时采用二分的方式找到最近的点的位置,再利用链表关系更新即可。最后的总复杂度为 O(nk+nlog(n))

代码

#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #include<vector> #include<map> using namespace std; #define LL long long #define MAXN 500500 const int mod=1e9+7; struct Node{ int l,r; }nxt[MAXN]; int bin[MAXN]; int pos[MAXN]; int lef[MAXN]; int rig[MAXN]; int n; int lowbit(int x){ return x&-x; } void add(int x){ while(x<=n){ bin[x]++; x+=lowbit(x); } } int sum(int x){ int ret=0; while(x){ ret+=bin[x]; x-=lowbit(x); } return ret; } int query(int l,int r){ return sum(r)-sum(l-1); } void updata(int x){ int l=1,r=x-1,mid; while(l<=r){ mid=(l+r)>>1; if(query(mid,x-1)<1) r=mid-1; else l=mid+1; } nxt[x].l=r; if(r>0) nxt[r].r=x; l=x+1,r=n; while(l<=r){ mid=(l+r)>>1; if(query(x+1,mid)<1) l=mid+1; else r=mid-1; } nxt[x].r=l; if(l<=n) nxt[l].l=x; } int main(){ int T,k,a; cin>>T; while(T--){ scanf("%d %d",&n,&k); memset(bin,0,sizeof(bin)); nxt[0].l=nxt[n+1].l=0; nxt[0].r=nxt[n+1].r=n+1; for(int i=1;i<=n;++i){ scanf("%d",&a); pos[a]=i; nxt[i].l=0; nxt[i].r=n+1; } for(int i=n;i>n-k+1;i--){ updata(pos[i]); add(pos[i]); } LL ans=0; for(int i=n-k+1;i;i--){ updata(pos[i]); add(pos[i]); for(int j=0,curl=pos[i],curr=pos[i];j<=k;++j){ lef[j]=curl-nxt[curl].l; rig[j]=nxt[curr].r-curr; curl=nxt[curl].l; curr=nxt[curr].r; } for(int j=0;j<k;++j) ans+=i*1ll*lef[j]*rig[k-j-1]; } printf("%I64d\n",ans); } }
转载请注明原文地址: https://www.6miu.com/read-50413.html

最新回复(0)