【一句话题意】给你一个图源点s和汇点t,先求得最短路,再在最短路上,求在s和t同时出发,不相遇不重复的路径方案数。对1e9+7取模。 【分析】考虑容斥原理,不考虑相遇,则方案数为{最短路条数}2。再排除在边上相遇的情况{f[u]*g[v]}2和在点上相遇的情况{f[u]*g[u]}2。 f、g数组为从s(或t)走最短路出发到点i的方案数。 细节上处理的关键在于从整张图中抽离出最短路再进行dp。 【code】
#include<queue> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> using namespace std; typedef long long LL; const int mod=1e9+7; const int maxn=1e5+1000; struct Edge{ int v,nxt; LL c; }edge[maxn<<2]; int head[maxn],tot; int n,m,s,t; inline void read(int &x){ x=0;int fl=1;char tmp=getchar(); while(tmp<'0'||tmp>'9'){if(fl=='-')fl=-fl;tmp=getchar();} while(tmp>='0'&&tmp<='9') x=(x<<1)+(x<<3)+tmp-'0',tmp=getchar(); } inline void add_edge(int x,int y,int c){ edge[tot].v=y; edge[tot].c=c; edge[tot].nxt=head[x]; head[x]=tot++; } int p[maxn]; LL dist[maxn],f[maxn],g[maxn],ans; bool inq[maxn]; inline void spfa(){ memset(dist,0x3f,sizeof(dist)); queue<int> q; q.push(s);dist[s]=0; while(!q.empty()){ int u=q.front();q.pop();inq[u]=0; for(int i=head[u];i!=-1;i=edge[i].nxt){ int v=edge[i].v; if(dist[v]>dist[u]+edge[i].c){ dist[v]=dist[u]+edge[i].c; if(!inq[v])inq[v]=1,q.push(v); } } } } bool cmp(int x,int y){ return dist[x]<dist[y]; } int main(){ memset(head,-1,sizeof(head)); cin>>n>>m>>s>>t; for(int i=1;i<=m;i++){ int u,v,c;read(u),read(v),read(c); add_edge(u,v,c); add_edge(v,u,c); } spfa(); for(int i=1;i<=n;i++) p[i]=i; sort(p+1,p+n+1,cmp); f[s]=1; for(int i=1;i<=n;i++){ int x=p[i]; for(int j=head[x];j!=-1;j=edge[j].nxt) if(dist[edge[j].v]==dist[x]+edge[j].c) f[edge[j].v]=(f[edge[j].v]+f[x])%mod; } g[t]=1; for(int i=1;i<=n;i++){ int x=p[n-i+1]; for(int j=head[x];j!=-1;j=edge[j].nxt) if(dist[edge[j].v]==dist[x]-edge[j].c) g[edge[j].v]=(g[edge[j].v]+g[x])%mod; } ans=f[t]*f[t]%mod; for(int i=1;i<=n;i++){ int x=i; for(int j=head[x];j!=-1;j=edge[j].nxt) if(dist[edge[j].v]==dist[x]+edge[j].c){ if((dist[edge[j].v]<<1)>dist[t]&&(dist[x]<<1)<dist[t]){ ans=((ans-f[x]*g[edge[j].v]%mod*f[x]%mod*g[edge[j].v])%mod+mod)%mod; } } } for(int i=1;i<=n;i++) if((dist[i]<<1)==dist[t]) ans=((ans-f[i]*f[i]%mod*g[i]%mod*g[i])%mod+mod)%mod; printf("%lld\n",ans); return 0; }