并查集简单来说就是数据分类,怎么分呢,初始把数组 pre[i] = i 设定,表示自己归属于自己,如果A是B的老大,那么pre[A] = A, pre[B] = A, 那么后来查询B的时候,我们就可以用一个函数查到他的老大是谁(函数后面介绍)
int find(int x) { return pre[x] == x ? x : find(pre[x]); } void join(int x, int y) { int tx = find(x); int ty = find(y); if(tx != ty) { pre[y] = x; } }初始化pre[i] = i join(1,2), 因为find(1) = 1,find(2) = 2 所以pre[1] = 1,pre[2] = 1 变成find(1) = 1,find(2) = 1
join(2,3) 因为find(2) = 1, find(3) = 3 所以pre[2] = 1, pre[3] = 2 变成find(2) = 1, find(3) = 2
以此类推
输入数据 n, m ,,, n为节点数,m为边的条数 9 7 1 2 2 3 4 5 5 6 6 7 7 8 8 9 x,y表示输入的点,tx,ty表示find后的数值
pre[]numpre[1]1pre[2]1pre[3]2pre[4]4pre[5]4pre[6]5pre[7]6pre[8]7pre[9]8· 从表中能看出,2指向1, 3指向2,每一个元素指向自己上一个节点,直到指到”老大”,根节点。 用find函数去查询的时候,就会按照数组的值查询下去,直到根节点,就是pre[i] = i 的节点。 另外还有一种优化版的find函数,像上面一样查找会很麻烦,一个一个的向根节点查找,能不能直接指向跟节点呢,当然可以
int find(int x) { return pre[x] == x ? x : pre[x] = find(pre[x]); }当find函数这样写的时候,pre数组每次都会更新,如果继续输入上面的例子就会变成
pre[]numpre[1]1pre[2]1pre[3]2pre[4]4pre[5]4pre[6]4pre[7]4pre[8]4pre[9]8除去最后一个数字外, 其他节点都直接指向了根节点,就是因为每次find的时候都会更新pre数组,当然,最后一个数是因为find函数的时候就是 pre[9] = 9,等还有下一条边的时候, 9 、10时,9才会更新。 find函数还有另外的写法,功能是一样的
int find(int x) { int r=x; while(r!=pre[r]) r=pre[r]; int i=x,j; while(pre[i]!=r) { j=pre[i]; pre[i]=r; i=j; } return r; }这个跟优化版的find函数是一样的,好理解,但是写的时候并不快。
一个小例题 这里就是利用在传入并查集的时候进行判定,详细题解请点击这里 AC 代码
#include<iostream> #include<cstring> #define N 1010 using namespace std; char cmd; bool use[N]; int n, d, st, se; struct node { int x, y, pre; }p[N]; int find(int x) { return x == p[x].pre ? x : find(p[x].pre); } void join(const node p1, const node p2) { int root1 = find(p1.pre); int root2 = find(p2.pre); if(root1 != root2) { if((p1.x-p2.x)*(p1.x-p2.x) + (p1.y-p2.y)*(p1.y-p2.y) <= d*d) { p[root1].pre = root2; } } } int main () { scanf("%d %d", &n, &d); for(int i = 1; i <= n; i++) { p[i].pre = i; } memset(use, false, sizeof(use)); for(int i = 1; i <= n; i++) { scanf("%d %d", &p[i].x, &p[i].y); } int x; while(~scanf("\n%c", &cmd)) { if(cmd == 'O') { scanf("%d", &x); use[x] = true; for(int i = 1; i <= n; i++) { if(use[i] && i != x) { join(p[i], p[x]); } } } else { scanf("%d %d", &st, &se); if(find(st) == find(se)) { printf("SUCCESS\n"); } else printf("FAIL\n"); } } return 0; }最小生成树有与最短路一样的算法,还有跟并查集一样的算法,分别为prime和kruskal算法
先说prime 如果还不懂最短路的话请点这里 跟最短路基本一样的代码,先贴代码
int prime() { int ans = 0; for(int i = 1; i <= n; i++) { dis[i] = map[1][i]; } memset(use, false, sizeof(use)); use[1] = true; for(int i = 1; i <= n; i++) { int minn = inf, u; for(int j = 1; j <= n; j++) { if(!use[j] && dis[j] < minn) { minn = dis[u = j]; } } if(minn == inf) break; ans += minn; use[u] = true; for(int j = 1; j <= n; j++) { if(!use[j] && dis[j] > map[u][j]) dis[j] = map[u][j]; } } return ans; }我把最短路代码贴出来对比一下
void dijstra(int st) { dis[st] = 0; memset(vis, 0, sizeof(vis)); for(int cnt = 1; cnt < n; cnt++) { int minn = inf, u; for(int i = 0; i < n; i++) if(!vis[i] && dis[i] < minn) minn = dis[u = i]; if(minn == inf) break; vis[u] = 1; for(int i = 0; i < n; i++) { if(!vis[i] && dis[i] > dis[u] + map[u][i]) dis[i] = dis[u] + map[u][i]; } } }先说一下根本的区别,最短路是得出两地最短距离的算法,二最小生成树是把所有的地全部连接起来,包括间接连接,花费的长度最短的算法。 相比而言,最短路需要只到起点,但是最小生成树不需要起点,它的起点就是你定义时候是从1开始还是从0开始的,那个就是你的起点,因为他要把所有的点都算进来,所有没有起点终点这一说; 不同的是,每次选出最小的之后就加入ans 更新的数值变为 dis[j] = map[u][j]; 因为连接并不重要,重要的是连接点,而且用的是最小的变
kruskal算法 这个算法真的很简单,特别好用、 这个算法适合知道 两个点和距离,如果是图的话,还需要转换才行。 首先我们需要把这些数据存入一个结构体,一个三个数字的结构体,然后按照距离来排序,这样的话先取出的边一定是最短的,然后我们用并查集的办法,查询这两个点是否已经连通,如果没有,那么让它加入并查集,那么以后加入的都是没有加进来的,直到全部判断完,一定是个用最短路径连接的一个路线图
struct node { int x, y, cost; }cnt[5050]; bool cmp(node a, node b) { return a.cost < b.cost; }这是结构体,下面的函数是一个bool类型,是一个判定,在sort里用到
sort(cnt+1, cnt+1+t, cmp);这一条语句的意义就是按照结构体中的C的大小来排序,用结构体的话好处是,一旦利用C来排序,x y 的值是一起排序的,虽然不是按照大小排序的,是跟着C来移动的
int find(int x) { return x == pre[x] ? x : find(pre[x]); } int kruskal() { int ans = 0; for(int i = 1; i <= n; i++) pre[i] = i; sort(cnt+1, cnt+1+t, cmp); for(int i = 1; i <= t; i++) { int x = find(cnt[i].x); int y = find(cnt[i].y); if(x != y) { ans += cnt[i].cost; pre[y] = x; } } return ans; }首先初始化pre数组,t为数据的组数,在加入并查集的时候我们把它的花费记录下来,直至全部遍历完。是不是很简单?
一个小例题 一个完完全全的最小生成树例题,详细题解请点击这 AC代码
#include<iostream> #include<algorithm> #include<cstring> #define maxn 11000 using namespace std; int n, m; int pre[maxn]; struct node { int a, b, cost; }cnt[maxn]; int find(int x) { return pre[x] == x ? x : find(pre[x]); } bool cmp(node a, node b) { return a.cost < b.cost; } int kruskal() { int ans = 0; for(int i = 1; i <= n; i++) pre[i] = i; sort(cnt+1, cnt+1+m, cmp); for(int i = 1; i <= m; i++) { int x = find(cnt[i].a); int y = find(cnt[i].b); if(x != y) { pre[y] = x; ans += cnt[i].cost; } } return ans; } int main () { while(scanf("%d", &n) && n) { scanf("%d", &m); memset(cnt, 0, sizeof(cnt)); for(int i = 1; i <= m; i++) { scanf("%d %d %d", &cnt[i].a, &cnt[i].b, &cnt[i].cost); } printf("%d\n", kruskal()); } }