Kanade’s trio Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 524288/524288 K (Java/Others) Total Submission(s): 890 Accepted Submission(s): 322
Problem Description Give you an array A[1..n] ,you need to calculate how many tuples (i,j,k)satisfy that (i《j《k) and ((A[i] xor A[j])<(A[j] xor A[k]))
There are T test cases.
1≤T≤20
1≤∑n≤5∗10^5
0≤A[i]<2^30
Input There is only one integer T on first line.
For each test case , the first line consists of one integer n ,and the second line consists of nintegers which means the array A[1..n]
Output For each test case , output an integer , which means the answer.
Sample Input 1
5
1 2 3 4 5
Sample Output 6
题意: 求有多少组元组,使(i《j《k) and ((A[i] xor A[j])<(A[j] xor A[k])) 条件成立
解析: 用字典树构造,将一个数用二进制表示,每一个节点有两个孩子节点(0和1) 根据再a数组中的位置从小到大进行插入 每一次插入,同时计算在插入的这一点作为Ak时对答案的贡献,分为两种情况(根据异或运算符的性质,就是找Ak与Ai最高位的不同的值,然后找在该位与Ai相同的Aj):
num[i][0/1]表示第i位时0/1的数的个数 trie[i].igj 表示在这个数被选为ai时,有多少编号小于i且前i位与Ai不同的数 (trie[p].igj+=num[i][d]-trie[p].cnt,表示在第i位为d的数的数量-第i位为d且前i位与Ai相同的个数)
(1)Ai与Ak前t-1位相同,第t位不同,Aj与Ai前t-1位相同,且第t位相同(就是在n个里面选2个,C(2,n)) 式子:ans+=(ll)trie[tmp].cnt*(trie[tmp].cnt-1)/2; (2)Ai与Ak前t-1位相同,第t位不同,Aj与Ai前t-1位不同,且第t位相同 (3)在第二步的计算时候,会把(i>j)的情况算进去,所以就要把这一部分减去(trie[i].igj表示) 式子:ans+=(ll)trie[tmp].cnt(//Ai)*(num[i][d^1]-trie[tmp].cnt)(//Aj)-trie[tmp].igj(i>j的情况);
#include<stdio.h> #include<string.h> typedef long long ll; const int BIT = 30; const int MAXN = 500010; typedef struct node { int cnt; int igj; //每一个ai中i>j的aj的个数和 int nxt[2]; }node; node trie[MAXN*BIT]; int sz; int num[BIT][2]; int digt[BIT]; ll ans; void insert() { int p=0; for(int i=0;i<BIT;i++) { int d=digt[i]; //p表示d的父节点 //num[i][d]++; if(!trie[p].nxt[d]) //若该节点i还没有点到达过 trie[p].nxt[d]=++sz; if(trie[p].nxt[d^1]) //若该节点i的兄弟节点到达过,或者说有值 { int tmp=trie[p].nxt[d^1]; ans+=(ll)trie[tmp].cnt*(trie[tmp].cnt-1)/2; ans+=(ll)trie[tmp].cnt*(num[i][d^1]-trie[tmp].cnt)-trie[tmp].igj; //ai*aj-(i>j) } p=trie[p].nxt[d]; trie[p].cnt++; trie[p].igj+=num[i][d]-trie[p].cnt; //num[i][d]表示所有a的第i位(从左到右)是d的数-ai } } int main() { int t,n,a; scanf("%d",&t); while(t--) { scanf("%d",&n); memset(num,0,sizeof(num)); memset(trie,0,sizeof(trie)); ans=sz=0; for(int i=0;i<n;i++) { scanf("%d",&a); int tmp=BIT-1; while(tmp!=-1) { num[tmp][a&1]++; digt[tmp--]=a&1; a=a>>1; } insert(); } printf("%lld\n",ans); } return 0; }