-
Notifications
You must be signed in to change notification settings - Fork 80
/
weighted_quick_union_union_find.py
66 lines (53 loc) · 1.76 KB
/
weighted_quick_union_union_find.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# coding: utf-8
"""
https://learning.oreilly.com/library/view/Data+Structures+and+Algorithms+in+Python/9781118290279/19_chap14.html#ch014-sec051
"""
class Node:
    """A single element of a disjoint-set forest.

    A freshly created node is the root of its own one-element tree: its
    ``parent`` refers back to the node itself and its ``size`` is 1.
    NOTE: the node instance itself is used as the identifier of the group
    it roots.
    """

    def __init__(self, element):
        # Start as a singleton tree: self-parented, containing one element.
        self.element, self.parent, self.size = element, self, 1

    def __len__(self):
        """Return ``size``, the element count of the tree rooted at this node.

        Union logic in this module only updates ``size`` on the surviving
        root, so the value is only authoritative for root nodes.
        """
        return self.size
# There is another implementation:
# https://github.com/vinta/fuck-coding-interviews/blob/master/problems/kruskal_mst_really_special_subtree.py
class WeightedQuickUnionUnionFind:
    """Disjoint-set (union-find) structure using union by size.

    Every element is wrapped in a ``Node``; each group is a tree of nodes
    and is identified by its root node. Elements that have not been seen
    before are lazily added as singleton groups the first time they reach
    ``find`` (and therefore ``union`` and ``is_connected``).

    Args:
        union_pairs: optional iterable of ``(p, q)`` pairs to union while
            constructing the structure.
    """

    def __init__(self, union_pairs=()):
        self.num_groups = 0  # number of disjoint groups currently tracked
        self.element_groups = {
            # element: node_instance
        }
        for p, q in union_pairs:
            self.union(p, q)

    def __len__(self):
        """Return the number of disjoint groups."""
        return self.num_groups

    def make_group(self, element):
        """Return the node for *element*, creating a singleton group if new.

        An already-known element is returned unchanged and does not alter
        ``num_groups``.
        """
        node = self.element_groups.get(element)
        if node is None:
            node = Node(element)
            self.element_groups[element] = node
            self.num_groups += 1
        return node

    def find(self, p):
        """Return the root node of the group containing *p*.

        If *p* is unknown it is first added as a new singleton group, in
        which case its own node is returned.
        """
        node = self.make_group(p)
        # Roots are the only nodes whose parent is themselves.
        while node.parent is not node:
            node = node.parent
        return node

    def union(self, p, q):
        """Merge the groups containing *p* and *q*.

        The root of the smaller tree is attached under the root of the
        larger one (on a tie, *q*'s root goes under *p*'s root), which keeps
        every tree's height logarithmic in its size. Unknown elements are
        first added as singleton groups. Uniting two elements that are
        already in the same group is a no-op.
        """
        p_root = self.find(p)
        q_root = self.find(q)
        if p_root is q_root:
            # Already connected: merging a root into itself would double
            # its size and wrongly decrement num_groups.
            return
        if len(p_root) < len(q_root):
            # Ensure p_root is the larger (surviving) root.
            p_root, q_root = q_root, p_root
        q_root.parent = p_root
        p_root.size += q_root.size
        self.num_groups -= 1

    def is_connected(self, p, q):
        """Return True if *p* and *q* belong to the same group.

        Like ``find``, this adds any unknown element as a singleton group.
        """
        return self.find(p) is self.find(q)