统计最短路
描述
给出n个点,m条带权无向边,问你从1号点到n号点的最短路中有多少种走法?
输入
第一行两个数n,m分别表示点的个数和边的个数。(2≤n≤5000,1≤m≤100000) 接下来m行,每行3个数u,v,w表示u号点到v号点有一条距离为w的边。(1≤u,v≤n,0≤w≤5000) 数据保证1号点能够到达n号点,点和边都可以被走多次。
输出
如果有无穷种走法,输出-1。否则输出走法的方案数mod 1000000009。
样例输入
4 4 1 2 1 1 3 1 2 4 1 3 4 1
样例输出
2
分析
这道题和上一道路径统计其实很相似,就只是多了一个无穷种走法(毒瘤啊……) 而仔细分析一下数据和题意,就会发现,其实无穷种走法,就是指在1~n的最短路中出现过0边权的边,这样他就可以反复地走来来回回地走这同一条边了 那我们需要做的就是判断在1~n的最短路径中是否出现过0边权的路 而判断一条边是否在最短路上,是不能够只简单地满足
d
i
s
[
v
]
=
=
d
i
s
[
u
]
+
w
[
e
]
dis[v]==dis[u]+w[e]
dis[v]==dis[u]+w[e]这个条件,因为这条边可能不存在在1~n这个最短路径上,只是在松弛的时候松弛到了这个地方 所以我们需要跑两遍最短路,正着跑一遍(从S出发),反着跑一遍(从T出发)最后判断这个条件:
d
i
s
[
S
]
[
u
]
+
w
[
e
]
+
d
i
s
[
v
]
[
T
]
=
=
d
i
s
[
S
]
[
T
]
dis[S][u]+w[e]+dis[v][T]==dis[S][T]
dis[S][u]+w[e]+dis[v][T]==dis[S][T] 即可
对了,想起一点,最短路计数的话最好就用dijkstra,用spfa的话,还有些细节要注意
代码
#include<bits/stdc++.h>
#define in read()
#define P 1000000009
#define re register
#define N 5005
#define M 200009
#define ll long long
using namespace std
;
inline int read(){
char ch
;int f
=1,res
=0;
while((ch
=getchar())<'0'||ch
>'9') if(ch
=='-') f
=-1;
while(ch
>='0'&&ch
<='9'){
res
=(res
<<3)+(res
<<1)+ch
-'0';
ch
=getchar();
}
return f
==1?res
:-res
;
}
int n
,m
;
bool vis
[N
];
int nxt
[M
],to
[M
],w
[M
],head
[N
],cnt
=0;
int g
[N
][N
],d
[2][N
],num
[N
];
ll ans
[N
];
void add(int x
,int y
,int z
){nxt
[++cnt
]=head
[x
];head
[x
]=cnt
;to
[cnt
]=y
;w
[cnt
]=z
;}
void dij(int S
){
priority_queue
<pair
<int,int> > q
;
int tp
=(S
==n
)?1:0;
q
.push(make_pair(0,S
));d
[tp
][S
]=0;ans
[S
]=1;
while(!q
.empty()){
int u
=q
.top().second
;
q
.pop();if(vis
[u
]) continue;
vis
[u
]=1;
for(int e
=head
[u
];e
;e
=nxt
[e
]){
int v
=to
[e
];
if(d
[tp
][v
]==d
[tp
][u
]+w
[e
]) ans
[v
]=(ans
[v
]+ans
[u
])%P
;
else if(d
[tp
][v
]>d
[tp
][u
]+w
[e
]){
d
[tp
][v
]=d
[tp
][u
]+w
[e
];
ans
[v
]=ans
[u
];
q
.push(make_pair(-d
[tp
][v
],v
));
}
}
}
}
int main(){
n
=in
;m
=in
;
int i
,j
,x
,y
,u
,v
,z
;
for(i
=1;i
<=m
;++i
){
u
=in
;v
=in
;z
=in
;
add(u
,v
,z
);add(v
,u
,z
);
}
memset(d
,127,sizeof(d
));
dij(1);
memset(vis
,0,sizeof(vis
));memset(ans
,0,sizeof(ans
));
dij(n
);
for(i
=1;i
<=n
;++i
)
{
for(int e
=head
[i
];e
;e
=nxt
[e
])if(!w
[e
]){
j
=to
[e
];
if(d
[0][i
]+d
[1][j
]==d
[0][n
]) return printf("-1"),0;
}
}
printf("%d",ans
[1]);
return 0;
}