Description Input 第一行一整数m,之后输入序列A和B(m<=19,0<=A[i],B[i]<998244353) Output 输出答案 Sample Input 2 1 2 3 4 5 6 7 8 Sample Output 568535691 Solution Code
#include<cstdio> #include<cstring> using namespace std; typedef long long ll; typedef unsigned long long ull; const int mod=998244353,inv2=499122177,maxn=1<<20; int m,n,A[maxn],B[maxn],C[maxn],bit[maxn],a[maxn][21],b[maxn][21],c[maxn][21]; void read(int &x) { x=0; char p=getchar(); while(!(p<='9'&&p>='0'))p=getchar(); while(p<='9'&&p>='0')x*=10,x+=p-48,p=getchar(); } int mod_pow(int a,int b) { int ans=1; while(b) { if(b&1)ans=(ll)ans*a%mod; a=(ll)a*a%mod; b>>=1; } return ans; } void FWT(int a[maxn][21],int n,int sta) { for(int d=1;d<n;d<<=1) for(int i=0;i<n;i+=(d<<1)) for(int j=0;j<d;j++) for(int k=0;k<=m;k++) { int x=a[i+j][k],y=a[i+j+d][k]; a[i+j][k]=(x+y)%mod,a[i+j+d][k]=(x-y+mod)%mod; //xor:a[i+j]=x+y,a[i+j+d]=(x-y+mod)%mod; //and:a[i+j]=x+y; //or:a[i+j+d]=x+y; } if(sta==1) { int inv=mod_pow(inv2,m); for(int i=0;i<n;i++) for(int j=0;j<=m;j++) a[i][j]=(ll)a[i][j]*inv%mod; } } ull temp[21]; int main() { read(m); n=1<<m; for(int i=0;i<n;i++)read(A[i]); for(int i=0;i<n;i++)read(B[i]); for(int i=0;i<n;i++)bit[i]=bit[i>>1]+(i&1); for(int i=0;i<n;i++)A[i]=(ll)A[i]*(1<<bit[i])%mod; for(int i=0;i<n;i++)a[i][bit[i]]=A[i],b[i][bit[i]]=B[i]; FWT(a,n,0),FWT(b,n,0); for(int i=0;i<n;i++) { memset(temp,0,sizeof(temp)); for(int j=0;j<=m;j++) for(int k=0;k<=j;k++) { //c[i][j-k]=(c[i][j-k]+(ll)a[i][k]*b[i][j])%mod; temp[j-k]+=(ll)a[i][k]*b[i][j]; if(temp[j-k]>=(1ll<<63))temp[j-k]%=mod; } for(int j=0;j<=m;j++)c[i][j]=temp[j]%mod; } FWT(c,n,1); for(int i=0;i<n;i++)C[i]=c[i][bit[i]]; int ans=0,p=1; for(int i=0;i<n;i++) { ans=(ans+(ll)C[i]*p)%mod; p=1526ll*p%mod; } printf("%d\n",ans); return 0; }