区间第k大的数(主席树)

xiaoxiao2021-02-28  171

套主席树求区间第k小的数的模板,然后求区间[l,r]第k大的数就等于求区间[l,r]第r-l+1-k小的数(下标从1开始)

区间第K小值问题 有n个数,多次询问一个区间[L,R]中第k小的值是多少。

查询[1,n]中的第K小值 我们先对数据进行离散化,然后按值域建立线段树,线段树中维护某个值域中的元素个数。 在线段树的每个结点上用cnt记录这一个值域中的元素个数。 那么要寻找第K小值,从根结点开始处理,若左儿子中表示的元素个数大于等于K,那么我们递归的处理左儿子,寻找左儿子中第K小的数; 若左儿子中的元素个数小于K,那么第K小的数在右儿子中,我们寻找右儿子中第K-(左儿子中的元素数)小的数。

查询区间[L,R]中的第K小值 我们按照从1到n的顺序依次将数据插入可持久化的线段树中,将会得到n+1个版本的线段树(包括初始化的版本),将其编号为0~n。 可以发现所有版本的线段树都拥有相同的结构,它们同一个位置上的结点的含义都相同。 考虑第i个版本的线段树的结点P,P中储存的值表示[1,i]这个区间中,P结点的值域中所含的元素个数; 假设我们知道了[1,R]区间中P结点的值域中所含的元素个数,也知道[1,L-1]区间中P结点的值域中所包含的元素个数,显然用第一个个数减去第二个个数,就可以得到[L,R]区间中的元素个数。 因此我们对于一个查询[L,R],同步考虑两个根root[L-1]与root[R],用它们同一个位置的结点的差值就表示了区间[L,R]中的元素个数,利用这个性质,从两个根节点,向左右儿子中递归的查找第K小数即可。

常数优化的技巧 一种在常数上减小内存消耗的方法: 插入值时候先不要一次新建到底,能留住就留住,等到需要访问子节点时候再建下去。 这样理论内存复杂度依然是O(Nlg^2N),但因为实际上很多结点在查询时候根本没用到,所以内存能少用一些

#include <map> #include <set> #include <cmath> #include <ctime> #include <stack> #include <queue> #include <cstdio> #include <memory> #include <cctype> #include <bitset> #include <string> #include <vector> #include <climits> #include <cstring> #include <iostream> #include <iomanip> #include <algorithm> #include <functional> #define FIN freopen("input.txt","r",stdin); #define FOUT freopen("output.txt","w+",stdout); using namespace std; typedef long long ll; const int INF = 0x3f3f3f3f; const int mod = 1e9 + 7; const double eps=1e-8; const double Pi=acos(-1.0); const int maxn=50002; struct node { int ls,rs; int cnt;//某个值域元素的个数 }tr[maxn*20]; int cur,root[maxn]; inline void push_up(int o) { tr[o].cnt=tr[tr[o].ls].cnt+tr[tr[o].rs].cnt; } int build(int l,int r) { int k=cur++; if(l==r) { tr[k].cnt=0; return k; } int mid=(l+r)>>1; tr[k].ls=build(l,mid); tr[k].rs=build(mid+1,r); push_up(k); return k; } int update(int o,int l,int r,int pos,int val) { int k=cur++; tr[k]=tr[o]; if(l==pos&&r==pos) { tr[k].cnt+=val; return k; } int mid=(l+r)>>1; if(pos<=mid) tr[k].ls=update(tr[o].ls,l,mid,pos,val); else tr[k].rs=update(tr[o].rs,mid+1,r,pos,val); push_up(k); return k; } int query(int l,int r,int o,int v,int kth) { if(l==r) return l; int mid=(l+r)>>1; int res=tr[tr[v].ls].cnt-tr[tr[o].ls].cnt; if(kth<=res) return query(l,mid,tr[o].ls,tr[v].ls,kth); else return query(mid+1,r,tr[o].rs,tr[v].rs,kth-res); } int num[maxn]; int sortnum[maxn]; int main() { int n; while(~scanf("%d",&n)) { memset(num,0,sizeof(num)); memset(sortnum,0,sizeof(sortnum)); cur=0; for(int i=1;i<=n;i++) { scanf("%d",&num[i]); sortnum[i]=num[i]; } sort(sortnum+1,sortnum+1+n); int cnt=1; for(int i=2;i<=n;i++) { if(sortnum[i]!=sortnum[cnt]) sortnum[++cnt]=sortnum[i]; } root[0]=build(1,cnt); for(int i=1;i<=n;i++) { int p=lower_bound(sortnum+1,sortnum+cnt+1,num[i])-sortnum; root[i]=update(root[i-1],1,cnt,p,1); } int q; scanf("%d",&q); for(int i=0;i<q;i++) { int l,r,k; scanf("%d %d %d",&l,&r,&k); int idx=query(1,cnt,root[l],root[r+1],r-l+1-k+1);//模板的下标从1开始,则求的是root[l-1]与root[r]的值 printf("%d\n",sortnum[idx]); } } }
转载请注明原文地址: https://www.6miu.com/read-42893.html

最新回复(0)