HDU 6064 RXD and numbers(生成树计数+行列式)

xiaoxiao2021-02-28  57

Description 一个长度为n的序列A满足 1.1<=A[i]<=m 2.A[1]=A[n]=1 3.对任意1<=x<=m,至少存在一个1<=i<=n使得A[i]=x 4.对任意1<=x,y<=m,满足A[i]=x,A[i+1]=y的i的数量是D[x][y] 给出D,求满足条件的A序列数量 Input 多组用例,每组用例首先输入一整数m,之后输入一m*m矩阵D,以文件尾结束输入(0<=D[i][j] < 500,1<=m<=400,2<=n=sum{D[i][j]}+1) Output 对于每组用例,输出满足条件的A序列数量,结果模998244353 Sample Input 2 1 2 2 1 4 1 0 0 2 0 3 0 1 2 1 0 0 0 0 3 1 4 0 1 0 0 1 0 0 0 0 0 0 1 0 0 1 0 Sample Output Case #1: 6 Case #2: 18 Case #3: 0 Solution 把1~m看作m个点,D矩阵表示这m个点的邻接矩阵,则问题转化为求该图的欧拉回路个数,但是因为求出欧拉回路有重复(在计数时重边看作不同的边)且对于每一条欧拉回路,要选取一个1点把回路切开成为A序列,对BEST’s THEOREM稍作改动有 Trees用基尔霍夫矩阵计算即可,时间复杂度O(m^3) Code

#include<cstdio> #include<algorithm> using namespace std; typedef long long ll; const int maxn=401,maxm=200001,mod=998244353; int T,n,D[maxn][maxn],a[maxn][maxn]; int fact[maxm],inv[maxn]; void init() { fact[0]=1; for(int i=1;i<maxm;i++)fact[i]=(ll)i*fact[i-1]%mod;//求i! inv[0]=inv[1]=1; for(int i=2;i<maxn;i++)inv[i]=mod-(int)(mod/i*(ll)inv[mod%i]%mod);//线性求1~n逆元 for(int i=1;i<maxn;i++)inv[i]=(ll)inv[i-1]*inv[i]%mod;//预处理i!的逆元 } int inc(int x,int y) { return x+y>=mod?x+y-mod:x+y; } int dec(int x,int y) { return x-y<0?x-y+mod:x-y; } int mod_pow(int a,int b) { int ans=1; while(b) { if(b&1)ans=(ll)ans*a%mod; a=(ll)a*a%mod; b>>=1; } return ans; } int determinant(int n) { int ans=1; for(int k=1;k<=n;k++) { int pos=-1; for(int i=k;i<=n;i++) if(a[i][k]) { pos=i; break; } if(pos==-1)return 0; if(pos!=k) for(int j=k;j<=n;j++)swap(a[pos][j],a[k][j]); int Inv=mod_pow(a[k][k],mod-2); for(int i=k+1;i<=n;i++) if(a[i][k]) { ans=(ll)ans*Inv%mod; for(int j=k+1;j<=n;j++) a[i][j]=dec(((ll)a[i][j]*a[k][k]%mod),((ll)a[k][j]*a[i][k]%mod)); a[i][k]=0; } } for(int i=1;i<=n;i++)ans=(ll)ans*a[i][i]%mod; return ans; } int Solve() { int ans=1; for(int i=1;i<=n;i++) { int in=0,out=0; for(int j=1;j<=n;j++)in+=D[j][i],out+=D[i][j]; if(in!=out)return 0; if(i==1)ans=(ll)ans*fact[in]%mod; else ans=(ll)ans*fact[in-1]%mod; a[i][i]=in; for(int j=1;j<=n;j++)a[i][j]=dec(a[i][j],D[i][j]),ans=(ll)ans*inv[D[i][j]]%mod; } ans=(ll)ans*determinant(n-1)%mod; return ans; } int main() { init(); int Case=1; while(~scanf("%d",&n)) { for(int i=1;i<=n;i++) for(int j=1;j<=n;j++) scanf("%d",&D[i][j]); printf("Case #%d: %d\n",Case++,Solve()); } return 0; }
转载请注明原文地址: https://www.6miu.com/read-55520.html

最新回复(0)