因为叶子权值各不相同,我们考虑线段树合并,在merge时维护每个权值的概率。我们记gr为x的右子树中比我大的权值的概率和,我们合并时先做右子树,再做左子树。那么对于一个权值x的出现概率就是另一棵子树中比x大的权值的出现概率g*(1-p)+(1-g)*p。 复杂度 O(nlogn) O ( n l o g n )
#include <bits/stdc++.h> using namespace std; #define ll long long #define inf 0x3f3f3f3f #define N 300010 #define mod 998244353 inline char gc(){ static char buf[1<<16],*S,*T; if(S==T){T=(S=buf)+fread(buf,1,1<<16,stdin);if(T==S) return EOF;} return *S++; } inline int read(){ int x=0,f=1;char ch=gc(); while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=gc();} while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=gc(); return x*f; } int n,c[N][2],rt[N],w[N],m=0,owo=0,gr,gl,ans=0; struct Icefox{ int x,id; Icefox(){} Icefox(int _x,int _id){x=_x;id=_id;} friend bool operator<(Icefox a,Icefox b){return a.x<b.x;} }aa[N]; struct node{ int mul,sum,lc,rc; }tr[N*20]; inline int ksm(int x,int k){ int res=1;for(;k;k>>=1,x=(ll)x*x%mod) if(k&1) res=(ll)res*x%mod;return res; } inline void pushup(int p){ tr[p].sum=(tr[tr[p].lc].sum+tr[tr[p].rc].sum)%mod; } inline void inc(int &x,int y){x+=y;x%=mod;} inline void domul(int p,int val){ if(!p) return;val%=mod;if(val<0) val+=mod; tr[p].mul=(ll)tr[p].mul*val%mod; tr[p].sum=(ll)tr[p].sum*val%mod; } inline void pushdown(int p){ if(tr[p].mul==1) return; domul(tr[p].lc,tr[p].mul);domul(tr[p].rc,tr[p].mul);tr[p].mul=1; } inline void ins(int &p,int l,int r,int x){ p=++owo;tr[p].sum++;tr[p].mul=1;if(l==r) return; int mid=l+r>>1; if(x<=mid) ins(tr[p].lc,l,mid,x); else ins(tr[p].rc,mid+1,r,x); } inline int merge(int p1,int p2,int p){ if(!p1&&!p2) return 0; if(!p1){ inc(gr,tr[p2].sum); domul(p2,(ll)gl*(1-p)%mod+(ll)(1-gl)*p%mod);return p2; }if(!p2){ inc(gl,tr[p1].sum); domul(p1,(ll)gr*(1-p)%mod+(ll)(1-gr)*p%mod);return p1; }pushdown(p1);pushdown(p2); tr[p1].rc=merge(tr[p1].rc,tr[p2].rc,p); tr[p1].lc=merge(tr[p1].lc,tr[p2].lc,p); pushup(p1);return p1; } void dfs(int x){ if(!c[x][0]) return; if(!c[x][1]){dfs(c[x][0]);rt[x]=rt[c[x][0]];return;} dfs(c[x][0]);dfs(c[x][1]);gr=0;gl=0; rt[x]=merge(rt[c[x][0]],rt[c[x][1]],w[x]); } inline int sqr(int x){return (ll)x*x%mod;} void dfs1(int p,int l,int r){ if(l==r){inc(ans,(ll)l*aa[l].x%mod*sqr(tr[p].sum)%mod);return;} int mid=l+r>>1;pushdown(p); dfs1(tr[p].lc,l,mid);dfs1(tr[p].rc,mid+1,r); } int main(){ // freopen("a.in","r",stdin); n=read(); for(int i=1;i<=n;++i){ int x=read(); if(c[x][0]) c[x][1]=i; else c[x][0]=i; }int inv=ksm(10000,mod-2); for(int i=1;i<=n;++i){ int x=read();if(!c[i][0]){aa[++m]=Icefox(x,i);continue;} w[i]=(ll)x*inv%mod; }sort(aa+1,aa+m+1);for(int i=1;i<=m;++i) ins(rt[aa[i].id],1,m,i); dfs(1);dfs1(rt[1],1,m); printf("%d\n",ans); return 0; }