题目链接
题意
给定一个长度为 n 的数组 A,用
A[1...n]
表示,
A[1...n]
是
1−n
的数一种排列组合。存在一个函数
f(l,r,k)
表示
A[l...r]
中第k大数的值,同时
f(l,r,k)=0
当
r−l+1<k
。给定 k 求解
∑nl=1∑nr=lf(l,r,k)
。
分析
考虑每个数对答案的贡献,显然
A[1...n]
中第
1
到第 k−1 大的数对答案的贡献均为0。然后思考剩下的数。对于数
ai
,想要使
ai
为选定
A[l...r]
中第
k
大数,必须要使 A[l...r] 中恰好存在
k−1
个大于
ai
的数。如果我们已知大于
ai
的每个数的具体位置,那么就从
ai
的左边选取
t
个比它大的数,再从 ai 的右边选取
k−t−1
个比它大的数,就能组合成一种可行解。实际选取这样
k−1
个数不一定仅有一种解,对于左边,第
t
个数到第 t 1 个数之间的所有位置均是合法的选择,右边类似。设左边选
t
个数的合法区间长度为 lef[t] ,右边选
k−1−t
个数的合法区间长度为
rig[k−1−t]
,则总共的可行解数为
lef[t]×rig[k−1−t]
。最后只要枚举一下t的值就能快速求解了。
新的问题是如何快速确定大于
ai
的每个数的位置,或者说快速逐个搜索
ai
左侧和右侧比
ai
大的数的位置。由于不需要利用到比
ai
小的数,不妨从大到小枚举
1−n
中的每一个数,处理完后将对应数的位置放入某个集合中,这样每次查询这个集合时,必然都是比
ai
大的数。然而集合中快速搜索最近位置的点的复杂度为
O(log(n))
,加上枚举的复杂度,总复杂度达到了
O(nklog(n))
。赛时尝试了一发,不出意料的T掉了。然后考虑到优化搜索过程,不难发现搜索时都是从
ai
的位置向左或者向右逐个查询,于是想到用链表的方式保存左边和右边最近点位置。通过链表的指针就能在枚举的过程中快速搜索下一个点的位置了。插入新位置时采用二分的方式找到最近的点的位置,再利用链表关系更新即可。最后的总复杂度为
O(nk+nlog(n))
。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<map>
using namespace std;
#define LL long long
#define MAXN 500500
const int mod=
1e9+
7;
struct Node{
int l,r;
}nxt[MAXN];
int bin[MAXN];
int pos[MAXN];
int lef[MAXN];
int rig[MAXN];
int n;
int lowbit(
int x){
return x&-x;
}
void add(
int x){
while(x<=n){
bin[x]++;
x+=lowbit(x);
}
}
int sum(
int x){
int ret=
0;
while(x){
ret+=bin[x];
x-=lowbit(x);
}
return ret;
}
int query(
int l,
int r){
return sum(r)-sum(l-
1);
}
void updata(
int x){
int l=
1,r=x-
1,mid;
while(l<=r){
mid=(l+r)>>
1;
if(query(mid,x-
1)<
1)
r=mid-
1;
else
l=mid+
1;
}
nxt[x].l=r;
if(r>
0)
nxt[r].r=x;
l=x+
1,r=n;
while(l<=r){
mid=(l+r)>>
1;
if(query(x+
1,mid)<
1)
l=mid+
1;
else
r=mid-
1;
}
nxt[x].r=l;
if(l<=n)
nxt[l].l=x;
}
int main(){
int T,k,a;
cin>>T;
while(T--){
scanf(
"%d %d",&n,&k);
memset(bin,
0,
sizeof(bin));
nxt[
0].l=nxt[n+
1].l=
0;
nxt[
0].r=nxt[n+
1].r=n+
1;
for(
int i=
1;i<=n;++i){
scanf(
"%d",&a);
pos[a]=i;
nxt[i].l=
0;
nxt[i].r=n+
1;
}
for(
int i=n;i>n-k+
1;i--){
updata(pos[i]);
add(pos[i]);
}
LL ans=
0;
for(
int i=n-k+
1;i;i--){
updata(pos[i]);
add(pos[i]);
for(
int j=
0,curl=pos[i],curr=pos[i];j<=k;++j){
lef[j]=curl-nxt[curl].l;
rig[j]=nxt[curr].r-curr;
curl=nxt[curl].l;
curr=nxt[curr].r;
}
for(
int j=
0;j<k;++j)
ans+=i*
1ll*lef[j]*rig[k-j-
1];
}
printf(
"%I64d\n",ans);
}
}