题目：找出二值图像中的所有连通域。面试官给了半个小时，时间其实还是很充分的，我大概在规定时间内做出了这道题。发篇 blog 稍微做个记录，分别用 Python 和 C++ 写出这道题。
详细描述一下这道题:
输入是一个二值化的矩阵，只含有 0 和 255 两种取值。要求找出值为 255 的像素组成的连通域，连通规则和围棋一样：只有上下左右相邻才算连通，斜对角不算（即 4 连通）。这道题就像是在找围棋里白子的连通块。输出形式简单明了：对图中的每个连通域编号，按行优先扫描时首次遇到的顺序，第一个连通域的像素全部取值为 1，第二个连通域取值为 2，以此类推；值为 0 的背景像素保持 0。
比如,输入为:
[[0,   255, 0,   0, 0,   255],
 [255, 255, 0,   0, 255, 0  ],
 [0,   255, 255, 0, 255, 0  ]]
则要求的输出为：
[[0 1 0 0 0 2]
 [1 1 0 0 3 0]
 [0 1 1 0 3 0]]
我当时把这道题当作动态规划问题来做，其实更准确的说法是广度优先搜索（flood fill，洪水填充）。具体做法是：从一个值为 255 且尚未编号的像素出发，先找出与它上下左右相邻的 255 像素，再从这些像素继续向外扩展，依次找到连通的所有像素，就像一种蔓延。当时为了求快，首选的语言还是 Python：
import numpy as np
from collections import deque


def islink(a, b, arr):
    """Return the 4-connected neighbours of pixel (a, b) whose value is 255.

    Neighbours are the pixels directly left, right, above and below, which
    is the same adjacency as stones in Go (diagonals do not count).

    Positions outside the image are skipped rather than clamped. Clamping
    at an edge maps the missing neighbour back onto (a, b) itself, so a
    border pixel would report itself as its own neighbour.

    :param a: row index of the pixel.
    :param b: column index of the pixel.
    :param arr: 2-D numpy array holding the binary image.
    :return: list of ``[row, col]`` pairs, in left, right, top, bottom order.
    """
    m, n = arr.shape
    ls = []
    for r, c in ((a, b - 1), (a, b + 1), (a - 1, b), (a + 1, b)):
        if 0 <= r < m and 0 <= c < n and arr[r, c] == 255:
            ls.append([r, c])
    return ls


def sol(arr):
    """Label the 4-connected components of 255-valued pixels.

    Components are numbered 1, 2, 3, ... in the row-major order of their
    first pixel. Every 0 pixel stays 0.

    Labels are written into a new integer matrix instead of back into
    ``arr``, for two reasons:

    - Labelling in place makes component 255 indistinguishable from an
      unvisited pixel. On an integer array the old ``255.5`` workaround
      was truncated back to 255, and the flood fill then looped forever.
    - Labelling in place also overwrote the caller's image.

    :param arr: 2-D array-like containing only the values 0 and 255.
    :return: 2-D ``int`` numpy array of the same shape holding the labels.
    """
    arr = np.asarray(arr)
    m, n = arr.shape
    labels = np.zeros((m, n), dtype=int)
    pointer = 0  # label of the most recently started component
    for i in range(m):
        for j in range(n):
            if arr[i, j] != 255 or labels[i, j]:
                continue
            pointer += 1
            labels[i, j] = pointer
            # Breadth-first flood fill. Pixels are labelled when they are
            # enqueued, so each one enters the queue exactly once. The old
            # frontier lists re-added shared neighbours and grew
            # exponentially on solid regions.
            queue = deque([(i, j)])
            while queue:
                r, c = queue.popleft()
                for nr, nc in islink(r, c, arr):
                    if not labels[nr, nc]:
                        labels[nr, nc] = pointer
                        queue.append((nr, nc))
    return labels


def run():
    """Label the example image from the problem statement and print it."""
    arr = np.array([[0, 255, 0, 0, 0, 255],
                    [255, 255, 0, 0, 255, 0],
                    [0, 255, 255, 0, 255, 0]])
    print(sol(arr))


if __name__ == "__main__":
    run()

# Author's note: the code I hand-wrote in the interview ran unchanged once
# typed into an IDE. However, I had not considered the case where a
# link-area value equals 255. The interviewer pointed out this corner case.
#
# My first fix added 0.5 whenever the label reached 255 (making it 255.5)
# and changed it back to 255 before returning the matrix. On an integer
# array that assignment is truncated straight back to 255, so the corner
# case was not really solved.
#
# The interviewer suggested building a new matrix (or storing labels as
# negative values), so that a link-area value can never equal a pixel
# value. sol() above now uses a separate label matrix. There are plenty of
# other ways to solve it, too.
输出结果为:
[[0 1 0 0 0 2]
 [1 1 0 0 3 0]
 [0 1 1 0 3 0]]
C++ 版本：
#include<iostream> #include<string> #include<vector> using namespace std; vector<vector<int>> islink(int a, int b, vector<vector<int>> arr) { int m = arr.size(); int n = arr[0].size(); vector<vector<int>> ls; int left = (0 > b - 1) ? 0 : b - 1; int right = (n - 1 < b + 1) ? n - 1 : b + 1; int top = (0 > a - 1) ? 0 : a - 1; int btn = (m - 1 < a + 1) ? m - 1 : a + 1; if (arr[a][left] == 255) { vector<int> temp; temp.push_back(a); temp.push_back(left); ls.push_back(temp); } if (arr[a][right] == 255) { vector<int> temp; temp.push_back(a); temp.push_back(right); ls.push_back(temp); } if (arr[top][b] == 255) { vector<int> temp; temp.push_back(top); temp.push_back(b); ls.push_back(temp); } if (arr[btn][b] == 255) { vector<int> temp; temp.push_back(btn); temp.push_back(b); ls.push_back(temp); } return ls; } vector<vector<int>> sol(vector<vector<int>> arr) { int m = arr.size(); int n = arr[0].size(); int pointer = 0; vector<vector<int>> ls; for (int i = 0; i < m; i++) { for (int j = 0; j < n; j++) { if (arr[i][j] == 255) { pointer++; arr[i][j] = pointer; ls = islink(i, j, arr); while (1) { if (!ls.size()) break; vector<vector<int>> temp_ls; for (int k = 0; k < ls.size(); k++) { int r = ls[k][0]; int c = ls[k][1]; arr[r][c] = pointer; vector<vector<int>> tms = islink(r, c, arr); for (int z = 0; z < tms.size(); z++) { temp_ls.push_back(tms[z]); }; } ls = temp_ls; } } } } return arr; } void print_vector(vector<vector<int>> arr) { int m = arr.size(); int n = arr[0].size(); for (int i = 0; i < m; i++) { for (int j = 0; j < n; j++) { cout << arr[i][j] << "\t"; } cout << "\n"; } } int main() { vector<vector<int>> arr; arr = { { 0, 255, 0, 0, 255}, {255, 255, 0,255, 0}, {255, 0 ,255,255, 0} }; arr = sol(arr); print_vector(arr); system("pause"); return 0; }C++的代码量明显比Python长,即使是同样的逻辑。。