[Leetcode] Union Find

Introduction

Video: https://www.youtube.com/watch?v=0jNmHPfA_yE&t=59s
Chinese Video: https://www.youtube.com/watch?v=VJnUwsE4fWA

Find(x): find the root (cluster) of x. To find which cluster an element belongs to, follow its parent pointers until you reach a node that points to itself (a self-loop); that node is the root of the cluster.

Union(x, y): to unify two elements, find the root of each element's component. If the roots are different, make one root the parent of the other.

In this data structure, we don’t “un-union” elements. In general, this would be very inefficient to do since we would have to update all the children of a node.

The number of components is equal to the number of roots remaining.

  1. Map the elements to an array.

    [E, F, I, D, C, A, J, L, G, K, B, H]
    index:  0 1 2 3 4 5 6 7 8 9 10 11
    union(C, K) # make K’s root C
    parent: 0 1 2 3 4 5 6 7 8 4 10 11   (K, at index 9, now points to C, at index 4)

Time complexity (no optimizations): O(n) worst case to find the root, because a parent chain can be as long as the number of elements.

  1. Initialize self-pointing root array
  2. Union by rank

Optimization: Path Compression. While finding the root, make every vertex on the path point directly to the root.

Video: https://www.youtube.com/watch?v=VHRhJWacxis

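Time complexity of lookup with path compression and union by rank: amortized O(α(n)) per operation, where α is the inverse Ackermann function. This is effectively O(1) in practice.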

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class UnionFind:
    """Disjoint-set (union-find) over the integers 0..n.

    Uses union by rank plus path halving (a form of path compression).
    The extra slot lets callers use 1-indexed node labels 1..n directly.
    """

    def __init__(self, n):
        self._parents = list(range(n + 1))  # every node starts as its own root
        self._rank = [0] * (n + 1)  # upper bound on tree height; union by rank

    def find(self, u):
        """Return the root of u's set, compressing the path as we walk it."""
        parents = self._parents
        while parents[u] != u:  # a self-loop marks the root
            parents[u] = parents[parents[u]]  # path halving: jump to grandparent
            u = parents[u]
        return u

    def union(self, u, v):
        """Merge the sets containing u and v.

        Returns:
            False if u and v were already in the same set (nothing merged),
            True otherwise.
        """
        uroot = self.find(u)
        vroot = self.find(v)
        if uroot == vroot:
            return False  # can't union nodes already in the same group
        # Hang the lower-rank tree under the higher-rank one (ties: v under u).
        if self._rank[uroot] < self._rank[vroot]:
            uroot, vroot = vroot, uroot
        self._parents[vroot] = uroot
        if self._rank[uroot] == self._rank[vroot]:
            self._rank[uroot] += 1  # equal heights: merged tree grows by one
        return True

684. Redundant Connection

[Solution] Union the endpoints of each edge in input order. Return the last edge whose endpoints are already in the same group: adding it would create a cycle, so it cannot be part of the spanning tree.

1
2
3
4
5
6
7
8
# Variant built on a UnionFind class whose union() returns False when both
# endpoints already share a root.
def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
    """Return the last edge whose endpoints were already connected.

    The union-find is sized len(edges) (it allocates one extra slot), so node
    labels 1..len(edges) can be used as-is. Raises IndexError if no edge
    closes a cycle.
    """
    dsu = UnionFind(len(edges))
    rejected = []
    for edge in edges:
        joined = dsu.union(edge[0], edge[1])
        if not joined:
            # Both endpoints already share a root, so this edge closes a cycle.
            rejected.append(edge)
    return rejected[-1]

Shorter version of union find:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def findRedundantConnection(self, edges: 'List[List[int]]') -> 'List[int]':
N = len(edges)
parent = [-1] * (N+1)

def find_root(parent, i):
if parent[i] == -1 or parent[i] == i:
parent[i] = i
return i
else:
return find_root(parent, parent[i])

for edge in edges:
x, y = edge
rootx = find_root(parent, x)
rooty = find_root(parent, y)
if rootx != -1 and rootx == rooty:
break
parent[rooty] = rootx
return [x,y]

200. Number of Islands

Leetcode: https://leetcode.com/problems/number-of-islands/

1
2


547. Friend Circles

Leetcode: https://leetcode.com/problems/friend-circles/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def findCircleNum(self, M: List[List[int]]) -> int:
    """Count friend circles (connected components) in adjacency matrix M.

    Students i and j are unioned whenever M[i][j] == 1. The result is the
    number of distinct roots.
    """
    n = len(M)
    parents = list(range(n))  # every student starts as its own root

    def find_parent(u):
        # Iterative find with path halving; a recursive, uncompressed find can
        # build an O(n)-deep chain here and raise RecursionError for large n.
        while parents[u] != u:
            parents[u] = parents[parents[u]]
            u = parents[u]
        return u

    # Union: group every pair of direct friends.
    for i in range(n):
        for j in range(len(M[i])):
            if M[i][j] == 1:
                parents[find_parent(i)] = find_parent(j)

    return len({find_parent(i) for i in range(n)})

721. Accounts Merge

Leetcode: https://leetcode.com/problems/accounts-merge/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class DSU(object):
    """Disjoint-set union over integer ids 0..n-1."""

    def __init__(self, n=10001):
        # list(...) is required: in Python 3 a bare range is immutable, so
        # assigning into it raises TypeError. The default size keeps the
        # original hard-coded 10001 (1000 * 10 + 1).
        self.parents = list(range(n))

    def find(self, u):
        """Return the root of u, halving the path as we walk it."""
        parents = self.parents
        while parents[u] != u:
            parents[u] = parents[parents[u]]
            u = parents[u]
        return u

    def union(self, u, v):
        """Make u's root a child of v's root (no-op if u == v)."""
        if u == v:
            return
        self.parents[self.find(u)] = self.find(v)


class Solution(object):
    def accountsMerge(self, accounts):
        """
        Merge accounts that share at least one email.

        Each merged account is [name] followed by its sorted unique emails.

        :type accounts: List[List[str]]
        :rtype: List[List[str]]
        """
        # Size the DSU by the number of email slots instead of a fixed 10001,
        # so inputs with more emails don't index past the parent array.
        dsu = DSU(sum(max(len(acc) - 1, 0) for acc in accounts))
        email_index = {}  # email -> integer id used by the DSU
        email_name = {}   # email -> owner's name
        for acc in accounts:
            name = acc[0]
            for email in acc[1:]:
                email_name[email] = name
                if email not in email_index:
                    email_index[email] = len(email_index)  # transform email to int
                # Union every email of this account with its first email.
                dsu.union(email_index[email], email_index[acc[1]])
        groups = {}
        for email in email_name:  # every unique email exactly once
            groups.setdefault(dsu.find(email_index[email]), []).append(email)
        return [[email_name[v[0]]] + sorted(v) for v in groups.values()]

737. Sentence Similarity II

Leetcode: https://leetcode.com/problems/sentence-similarity-ii/

Note: beware of the case where the two sentences are identical word-for-word but none of those words appears in `pairs`. Identical words are always similar, so compare the words for equality before looking them up in the pair index.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def areSentencesSimilarTwo(self, words1, words2, pairs):
    """
    Return True if the sentences are similar word-by-word.

    Similarity is the transitive closure of `pairs`. Identical words are
    always similar, even if they never appear in `pairs`.

    :type words1: List[str]
    :type words2: List[str]
    :type pairs: List[List[str]]
    :rtype: bool
    """
    if len(words1) != len(words2):
        return False
    if not words1:
        return True

    # Give each word that appears in a pair a unique integer id.
    word_index = {}
    for pair in pairs:
        for word in (pair[0], pair[1]):
            if word not in word_index:
                word_index[word] = len(word_index)

    # list(...) is required: a bare range is immutable in Python 3.
    parents = list(range(len(word_index)))

    def find(u):
        # Iterative find with path halving (no recursion-depth limit).
        while parents[u] != u:
            parents[u] = parents[parents[u]]
            u = parents[u]
        return u

    for pair in pairs:
        ru = find(word_index[pair[0]])
        rv = find(word_index[pair[1]])
        if ru != rv:
            parents[rv] = ru  # merge v's group into u's

    for w1, w2 in zip(words1, words2):
        if w1 == w2:
            continue
        if w1 not in word_index or w2 not in word_index:
            return False
        if find(word_index[w1]) != find(word_index[w2]):
            return False
    return True

130. Surrounded Regions

Leetcode: https://leetcode.com/problems/surrounded-regions/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def solve(self, board):
    """
    Flip every "O" region not connected to the border to "X", in place.

    :type board: List[List[str]]
    :rtype: None Do not return anything, modify board in-place instead.
    """
    if not board or not board[0]:
        return
    rows, cols = len(board), len(board[0])
    parents = list(range(rows * cols))  # cell (x, y) has id x * cols + y
    onBorder = [0] * (rows * cols)      # only meaningful at component roots

    def find(u):
        # Iterative with path halving; a 2-D grid can produce very long chains.
        while parents[u] != u:
            parents[u] = parents[parents[u]]
            u = parents[u]
        return u

    def union(u, v):
        ru, rv = find(u), find(v)
        if ru == rv:
            return
        # Merge v's component into u's and carry its border flag over. Both
        # roots are captured before relinking; reading the flag afterwards
        # would look at u's root twice and lose v's flag.
        parents[rv] = ru
        onBorder[ru] = onBorder[ru] or onBorder[rv]

    # Mark "O" cells that lie on the border.
    for i in range(rows * cols):
        x, y = divmod(i, cols)
        if board[x][y] == "O" and (x == 0 or x == rows - 1 or y == 0 or y == cols - 1):
            onBorder[i] = 1

    # Union each "O" with its "O" neighbours above and to the right.
    for i in range(rows * cols):
        x, y = divmod(i, cols)
        if board[x][y] != "O":
            continue
        if x > 0 and board[x - 1][y] == "O":
            union(i, i - cols)
        if y + 1 < cols and board[x][y + 1] == "O":
            union(i, i + 1)

    # Capture regions whose root never touched the border.
    for i in range(rows * cols):
        x, y = divmod(i, cols)
        if board[x][y] == "O" and not onBorder[find(i)]:
            board[x][y] = "X"

261. Graph Valid Tree

Leetcode: https://leetcode.com/problems/graph-valid-tree/

Note:

  • Keep unioning the nodes on each edge
  • If two nodes are already unioned, which means the current edge is an additional edge, and there’s a cycle
  • Once no cycle is found, use len(edges) == n - 1 to detect isolated islands: an acyclic graph on n nodes is connected exactly when it has n - 1 edges
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def validTree(self, n, edges):
    """Return True if the undirected graph on nodes 0..n-1 is a tree.

    A tree has exactly n - 1 edges and no cycle; together these imply the
    graph is connected (no isolated islands).
    """
    # Wrong edge count means either a cycle or a disconnected graph: fail fast.
    if len(edges) != n - 1:
        return False
    roots = list(range(n))  # list(): a bare range is immutable in Python 3

    def find(u):
        # Iterative find with path halving (no recursion-depth limit).
        while roots[u] != u:
            roots[u] = roots[roots[u]]
            u = roots[u]
        return u

    for edge in edges:
        x = find(edge[0])
        y = find(edge[1])
        if x == y:
            # Additional edge: x and y are already in the same set -> cycle.
            return False
        roots[y] = x
    return True

Minimal Spanning Tree

Video (Huahua): https://www.youtube.com/watch?v=wmW8G8SrXDs&ab_channel=HuaHua

Kruskal’s Method

Video: https://www.youtube.com/watch?v=JZBQLXgSGfs

Given a weighted graph G = (V, E), we want to find a Minimum Spanning Tree (it may not be unique). A minimum spanning tree is a subset of the edges that connects all vertices without any cycle and has the minimum total edge weight. O(E log V)

  1. Sort edges by ascending edge weight
  2. Walk through the sorted edges and look at the two nodes each edge connects. If the nodes are already unified, skip the edge; otherwise include it and unify the nodes
  3. The algorithm terminates when every edge has been processed or all the vertices have been unified

Prim’s Method

Video: https://www.youtube.com/watch?v=4ZlRH0eK-qQ&ab_channel=AbdulBari

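Grow the tree from a starting point (a vertex, or the smallest edge). Repeatedly add the minimum-weight edge that connects a vertex already in the tree to a vertex outside it. O(E log V) with a binary heap.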

1584. Min Cost to Connect All Points

Leetcode: https://leetcode.com/problems/min-cost-to-connect-all-points/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import heapq

class UnionFind:
    """Disjoint-set keyed by (x, y) point tuples, with union by size."""

    def __init__(self, points):
        # Every point starts as the root of its own singleton set.
        keys = [(x, y) for x, y in points]
        self.roots = {key: key for key in keys}
        self.size = {key: 1 for key in keys}

    def find(self, u):
        """Return the root of u, halving the path on the way up."""
        parent = self.roots
        node = u
        while parent[node] != node:
            grand = parent[parent[node]]
            parent[node] = grand
            node = grand
        return parent[node]

    def union(self, u, v):
        """Attach the smaller set under the larger (ties: u's root under v's)."""
        a, b = self.find(u), self.find(v)
        if a == b:
            return
        if self.size[a] > self.size[b]:
            big, small = a, b
        else:
            big, small = b, a
        self.roots[small] = big
        self.size[big] += self.size[small]

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        """Return the minimum total Manhattan distance connecting all points.

        Kruskal's algorithm over the complete graph: pop edges cheapest first
        and keep each one that joins two different components.
        """
        n = len(points)
        keys = [tuple(p) for p in points]
        edges = [
            (abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1]),
             (keys[i], keys[j]))
            for i in range(n - 1)
            for j in range(i + 1, n)
        ]
        # heapify is O(E), versus O(E log E) for E individual pushes. Popping
        # lazily lets us stop as soon as the tree has n - 1 edges.
        heapq.heapify(edges)
        uf = UnionFind(points)
        res = 0
        unionCount = 0
        while edges and unionCount < n - 1:
            dist, (a, b) = heapq.heappop(edges)
            if uf.find(a) != uf.find(b):
                uf.union(a, b)
                unionCount += 1
                res += dist
        return res

778. Swim in Rising Water

Leetcode: https://leetcode.com/problems/swim-in-rising-water/

Prim’s minimum spanning tree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def swimInWater(self, grid: List[List[int]]) -> int:
    """Return the least time T at which (n-1, n-1) is reachable from (0, 0).

    Prim-style best-first expansion: always enter the lowest-elevation cell
    on the frontier. The answer is the highest elevation entered before
    reaching the bottom-right cell.
    """
    rows, cols = len(grid), len(grid[0])
    heap = [(grid[0][0], 0, 0)]
    # A set literal is needed here: set((0, 0)) would be {0}, not {(0, 0)}.
    visited = {(0, 0)}
    total = 0
    while heap:
        cost, i, j = heapq.heappop(heap)  # qualified: bare heappop is not imported
        total = max(total, cost)
        if i == rows - 1 and j == cols - 1:
            return total
        for delta_i, delta_j in ((0, 1), (1, 0), (0, -1), (-1, 0)):
            new_i, new_j = i + delta_i, j + delta_j
            if 0 <= new_i < rows and 0 <= new_j < cols and (new_i, new_j) not in visited:
                # Mark on push: a cell's priority is its own elevation, so a
                # later push of the same cell could never be cheaper.
                visited.add((new_i, new_j))
                heapq.heappush(heap, (grid[new_i][new_j], new_i, new_j))
    return -1

684. Redundant Connection

Leetcode: https://leetcode.com/problems/redundant-connection/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class UnionFind:
    """Disjoint-set over the ids 0..n-1, using union by size and path halving."""

    def __init__(self, n):
        self.roots = list(range(n))  # each id is initially its own root
        self.size = [1] * n          # component sizes, valid at roots

    def find(self, u):
        """Return the root of u, halving the path while walking up."""
        roots = self.roots
        node = u
        while roots[node] != node:
            roots[node] = roots[roots[node]]
            node = roots[node]
        return roots[node]

    def union(self, u, v):
        """Merge the sets of u and v.

        Returns False if they were already in one set, True otherwise.
        """
        ru, rv = self.find(u), self.find(v)
        if ru == rv:
            return False
        # Larger set absorbs the smaller; on a tie, u's root goes under v's.
        big, small = (ru, rv) if self.size[ru] > self.size[rv] else (rv, ru)
        self.roots[small] = big
        self.size[big] += self.size[small]
        return True

class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        """Return the first edge (in input order) whose endpoints are already
        connected, or [] if no such edge exists.

        Node labels are 1-based; they are shifted to 0-based ids for the DSU.
        """
        top = max([0] + [max(a, b) for a, b in edges])
        dsu = UnionFind(top)
        for a, b in edges:
            merged = dsu.union(a - 1, b - 1)
            if not merged:
                return [a, b]
        return []