diff --git a/retroactive/basic/__init__.py b/retroactive/basic/__init__.py index 0b5b977..46ee547 100644 --- a/retroactive/basic/__init__.py +++ b/retroactive/basic/__init__.py @@ -1,3 +1,7 @@ -from bst import BSTNode from names import * +from bst import BSTNode from queue import Queue +from link_cut_tree import LinkCutTree +from splay_tree import SplayNode + + diff --git a/retroactive/basic/link_cut_tree.py b/retroactive/basic/link_cut_tree.py index 802ac1f..65969d1 100644 --- a/retroactive/basic/link_cut_tree.py +++ b/retroactive/basic/link_cut_tree.py @@ -1 +1,143 @@ -## TODO +from splay_tree import SplayNode + + +class LinkCutTree(object): + + def __init__(self): + self.nodes = {} + + # access(n) makes n the root of the virtual tree, and its aux tree contains the path from root to v only. + # First we cut off the right child of the node we need to access (this is done by setting the value of prev to + # None initially). We then splay n to the root of its aux tree. After that we attach n to its path parent (the node + # it points to outside of its own aux tree), we then splay the parent, and repeat the process until we reach the + # root of the virtual tree. After this process we have a chain of right children all the way down to n. + # Splaying n now makes it the root of the virtual tree. It has no right child since we cut it off initially. + # The sub-tree to its left contains all its path parents in-order because the splay and right link process preserves + # order, and the link function (defined elsewhere) always adds new children on the right. + # Note: we return the last aux root so we can have an efficient lca algorithm + def access(self, node): + next = node + prev = None + while next is not None: + next.splay() + next.right = prev + prev = next + next = next.parent + + node.splay() + return prev + + # + # Mutate + # + + # cut(n) accesses n, meaning it is the deepest node on the current preferred path and thus only has a left child + # (no parent since it is the root of the virtual tree). When n is accessed it becomes the head of the virtual tree. + # All nodes in its left subtree are its ancestors in the preferred path. + # All n's children in the represented tree currently have path parent pointers to n. + # So removing n's left child severs the subtree rooted at n from the rest of the tree. + def cut(self, node): + self.access(node) + + # if the node is already a root return + if node.left is None: + return + + # virtual tree updates + node.left.parent = None + node.left = None + + # represented tree updates + node.represented_parent = None + node.parent_edge_weight = None + + # link(p,c,e) Connects two represented trees by connecting the auxiliary trees containing the nodes p and c. + # p represents the node that will get a new child path in the represented tree, and c is the new sub-tree. + # c needs to be the root of its represented tree to ensure that it does not have multiple parents. + # Since c can only have one unique parent in the new represented tree, p (the leaf node in the path that c + # was attached to), we store information about the represented edge between c and p in c. + # Following a nodes represented parents up a tree is equivalent to doing a reverse in order traversal. + def link(self, parent_node, child_root, edge_weight = None): + # make sure the two nodes are from different trees + if self.getRoot(parent_node) == self.getRoot(child_root): + return + + self.access(parent_node) + self.access(child_root) + # assert: child_root.isRoot() == True and child.left == None, because if not we will have two paths leading + # to the child node after linking (child_root.left and child_root.parent) violating the tree property + if child_root.left is not None: + raise Exception('Trying to link a child tree from an internal node') + + #virtual tree + parent_node.right = child_root + child_root.parent = parent_node + + #represented tree + child_root.parent_edge_weight = edge_weight + child_root.represented_parent = parent_node + + def makeTree(self, data): + if data in self.nodes: + raise Exception("makeTree: Creating duplicate nodes") + else: + new_node = SplayNode(data) + self.nodes[data] = new_node + return new_node + + # makeRoot(node) flips the path from node to its root, inverting the path and making it a child path of node. + # All other connections are preserved. + def makeRoot(self, node): + self.access(node) + flipped = None + to_flip = node + while to_flip is not None: + next_edge_weight = to_flip.parent_edge_weight + next_to_flip = to_flip.represented_parent or None + next_edge_weight = to_flip.parent_edge_weight + self.cut(to_flip) + if flipped is not None: + self.link(flipped, to_flip, edge_weight) + flipped = to_flip + to_flip = next_to_flip + edge_weight = next_edge_weight + + + + # + # Query + # + + def getNode(self, data): + if data in self.nodes: + return self.nodes[data] + else: + return None + + # get root(n) accesses n, putting it in the preferred path from the root. Since the in order traversal of the aux + # tree represents the path, the leftmost node will be the root of the represented tree. + def getRoot(self, node): + self.access(node) + while node.left is not None: + node = node.left + # splay the node (now root) so the cost of sequential requests for the same root is amortized O(lg n) + node.splay() + return node + + # path aggregate follows the path in the represented tree from root to the chosen node + # (under the hood this means traversing the aux tree containing the chosen node in order) + def pathAggregate(self, node, fn): + self.access(node) + node.inOrder(fn) + + # lca(a,b) returns the least common ancestor of a and b. This works because after accessing a the last aux tree + # before b's tree becomes the root tree is the point at which the path from root to a and b diverges. This is + # because each aux jump basically follows a path parent pointer. So if we access a and make it the root aux tree, + # the access to b will eventually have to jump into that aux tree. The path pointer it uses to do that will point + # to the node at which a and b diverge. + def lca(self, a, b): + if self.getRoot(a) != self.getRoot(b): + return None + self.access(a) + return self.access(b) + diff --git a/retroactive/basic/splay_tree.py b/retroactive/basic/splay_tree.py new file mode 100644 index 0000000..e14435e --- /dev/null +++ b/retroactive/basic/splay_tree.py @@ -0,0 +1,133 @@ +class SplayNode(object): + """Splay tree used by the link-cut tree.""" + + def __init__(self, data): + self.data = data + self.left = None + self.right = None + self.parent = None + # there can only be a represented edge between two nodes if one is the in order successor of the other + self.represented_parent = None + # stores the edge weight to the parent, since this is a tree this is enough to represent all edges uniquely + self.parent_edge_weight = float("inf") + + + def inOrder(self, fn): + if self.left is not None: + self.left.inOrder(fn) + + fn(self.data) + + if self.right is not None: + self.right.inOrder(fn) + + + def addLeft (self, other): + self.left = other + other.parent = self + + def addRight (self, other): + self.right = other + other.parent = self + + def isRoot(self): + return (self.parent == None or (self.parent.left != self and self.parent.right != self)) + + # copied from http://stevekrenzel.com/articles/printing-trees + def __str__(self, depth=0): + ret = "" + + # Print right branch + if self.right != None: + ret += self.right.__str__(depth + 1) + + # Print own value + ret += "\n" + (" "*depth) + str(self.data) + + # Print left branch + if self.left != None: + ret += self.left.__str__(depth + 1) + + return ret + + def rotateRight(self): + if self.isRoot(): + raise Exception("Trying to right rotate a root") + old_parent = self.parent + + old_parent.left = self.right + if old_parent.left is not None: + old_parent.left.parent = old_parent + + self.right = old_parent + self.parent = old_parent.parent + old_parent.parent = self + + if self.parent is not None: + if self.parent.left == old_parent: + self.parent.left = self + elif self.parent.right == old_parent: + self.parent.right = self + + + def rotateLeft(self): + if self.isRoot(): + raise Exception("Trying to right rotate a root") + + old_parent = self.parent + + old_parent.right = self.left + if old_parent.right is not None: + old_parent.right.parent = old_parent + + self.left = old_parent + self.parent = old_parent.parent + old_parent.parent = self + + if self.parent is not None: + # if neither of these cases is triggered it means self is the root of its splay tree and the parent pointer points to a path parent + if self.parent.left == old_parent: + self.parent.left = self + elif self.parent.right == old_parent: + self.parent.right = self + + + def splay(self): + while not self.isRoot(): + if self.parent.isRoot(): + if self.parent.left == self: + self.rotateRight() + elif self.parent.right == self: + self.rotateLeft() + else: + # this should never happen because an unacknowledged child is treated as a root, violating the loop condition + raise Exception("Splay: Attempting to rotate an unacknowledged (is not a left or right child) node ") + else: + # assert: grandparent != null because !parent.isRoot() + grandparent = self.parent.parent + if grandparent.left == self.parent: + + if self.parent.left == self: + #zig-zig + self.parent.rotateRight() + self.rotateRight() + else: + #zig-zag + self.rotateLeft() + self.rotateRight() + + elif grandparent.right == self.parent: + # assert: grandparent.right == self + if self.parent.right == self: + #zig-zig + self.parent.rotateLeft() + self.rotateLeft() + else: + #zig-zag + self.rotateRight() + self.rotateLeft() + + else: + # this should never be thrown since a node without a grandparent should be caught + # in the first if statement in the loop (self.parent.isRoot()) + raise Exception("Splay: grandparent is not attached to parent") diff --git a/retroactive/dispatcher.py b/retroactive/dispatcher.py index 4acee5e..1c10ef3 100644 --- a/retroactive/dispatcher.py +++ b/retroactive/dispatcher.py @@ -1,6 +1,8 @@ from basic import PriorityQueue, Deque, Queue, UnionFind, Stack, SDPS from partial import GeneralPartiallyRetroactive, PartiallyRetroactiveQueue, PartiallyRetroactiveSDPS from full_retroactivity import GeneralFullyRetroactive +from retroactive.full import RetroactiveUnionFind + def PartiallyRetroactive(initstate): """ @@ -44,7 +46,6 @@ def FullyRetroactive(initstate): elif isinstance(initstate, PriorityQueue): return GeneralFullyRetroactive(PartiallyRetroactive(initstate)) elif isinstance(initstate, UnionFind): - ##TODO: update once FR UnionFind is implemented - return GeneralFullyRetroactive(PartiallyRetroactive(initstate)) + return RetroactiveUnionFind() else: return GeneralFullyRetroactive(PartiallyRetroactive(initstate)) diff --git a/retroactive/examples.py b/retroactive/examples.py index d730308..44663ef 100644 --- a/retroactive/examples.py +++ b/retroactive/examples.py @@ -71,7 +71,28 @@ def testPartiallyRetroactiveSDPS(): print x.state assert x.state == [1,2,3,4,5] +def testFullyRetroActiveUnionFind(): + x = RetroactiveUnionFind() + # union a and b at the current time + x.unionAgo('a','b') + assert x.sameSetAgo('a', 'b', -2) == False + + # union a and b two steps earlier + x.unionAgo('a', 'b' ,-2) + assert x.sameSetAgo('a', 'b', -3) == True + + x.unionAgo('c','d') + assert x.sameSetAgo('b', 'd') == False + + # union a and c before all other unions + x.unionAgo('a','c',-10) + assert x.sameSetAgo('b', 'd',-9) == False + assert x.sameSetAgo('b','d', 0) == True + def all_tests(): testPartiallyRetroactiveSDPS() testPartiallyRetroactiveQueue() testGeneralPartiallyRetroactive() + testFullyRetroActiveUnionFind() + +all_tests() \ No newline at end of file diff --git a/retroactive/full/union_find.py b/retroactive/full/union_find.py index 3d0c6d8..cc6a82d 100644 --- a/retroactive/full/union_find.py +++ b/retroactive/full/union_find.py @@ -1,8 +1,101 @@ +from retroactive.basic import LinkCutTree class RetroactiveUnionFind(object): - ## Requires an implementation of link-cut trees. - ## TODO. - def __init__(self, initstate): - """ - initstate :: UnionFind. - """ - raise NotImplementedError() + """Fully retroactive union find implemented using link-cut trees to represented disjoint forests""" + def __init__(self): + self.forest = LinkCutTree() + self.time = 0 + + # unionAgo(a,b) links nodes a and b in the LinkCutTree if they weren't already linked. + # If they were it cuts the latest edge on the path between a and b (if it is later than the new link time). + # The new subtree will contain exactly one of a and b. We make that node the root of the subtree and link it to + # the other node still in the main tree. + # This preserves old unions because any edge below the cut edge will now follow a path through the a-b edge up to + # the lca, which is guaranteed not to have an edge value greater than the cut edge. Any edge above the cut \ + # edge will not be affected. + def unionAgo(self, a_data, b_data, tdelta = 0): + # if the sets are already connected at the specified time return + if self.sameSetAgo(a_data,b_data,tdelta): + return + + #get node objects to work with + a = self.forest.getNode(a_data) + b = self.forest.getNode(b_data) + union_time = self.time + tdelta + if a is None: + a = self.forest.makeTree(a_data) + + if b is None: + b = self.forest.makeTree(b_data) + + + # if the nodes are not connected at all, connect them. If they are connected at a later time, + # cut the oldest edge on the path between the two nodes, make the union'ed node the root of that tree + # and attach it to the other union'ed node. + if self.forest.getRoot(a) != self.forest.getRoot(b): + self.forest.makeRoot(b) + self.forest.link(a,b,union_time) + else: + lca = self.forest.lca(a,b) + max_time = float("-inf") + max_time_node = None + for next in [a,b]: + while next is not lca: + if next.parent_edge_weight > max_time: + max_time = next.parent_edge_weight + max_time_node = next + next = next.represented_parent + + self.forest.cut(max_time_node) + if self.forest.getRoot(a) == next: + self.forest.makeRoot(a) + self.forest.link(b,a,union_time) + else: + self.forest.makeRoot(b) + self.forest.link(a,b,union_time) + self.time += 1 + + # sameSetAgo(a,b,t) will find the lca of a and b and traverse the path from both to the lca, + # finding the largest edge on the path between a and b. If any edge is larger than time + tdelta then a, and b + # were not connected at (time + tdelta) + def sameSetAgo(self, a_data, b_data, tdelta = 0): + # sameset is reflexive + if a_data == b_data: + return True + + a = self.forest.getNode(a_data) + b = self.forest.getNode(b_data) + query_time = self.time + tdelta + + if a is None or b is None: + return False + + lca = self.forest.lca(a,b) + if lca is None: + return False + + for next in [a,b]: + while next is not lca: + if next.parent_edge_weight > query_time: + return False + next = next.represented_parent + + return True + + # sameSetWhen(a,b) traverses the path between a and b and return the largest edge, + # which is the time at which a and b were connected. + def sameSetWhen(self, a, b): + lca = self.forest.lca(a,b) + if lca is None: + return float("-inf") + + max_time = float("-inf") + for next in [a,b]: + while next is not lca: + if next.parent_edge_weight > max_time: + max_time = next.parent_edge_weight + next = next.represented_parent + + return max_time + + + diff --git a/tests/test_link_cut.py b/tests/test_link_cut.py new file mode 100644 index 0000000..9e4fc11 --- /dev/null +++ b/tests/test_link_cut.py @@ -0,0 +1,288 @@ +import unittest +from retroactive.basic import LinkCutTree + +class test_link_cut(unittest.TestCase): + + def test_getNode_nodeDoesntExist_returnsNone(self): + tree = LinkCutTree() + self.assertIsNone(tree.getNode('a1')) + + def test_getNode_nodeExists_returnsNode(self): + tree = LinkCutTree() + a1 = tree.makeTree('a1') + self.assertEqual(a1,tree.getNode('a1')) + + def test_link_connectsTwoTrees(self): + #arrange + tree = LinkCutTree() + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + + #act + tree.link(a1,a2) + + #assert + self.assertEqual(a1, tree.getRoot(a2)) + + def test_getRoot__returnsProperRootWhenRootHasMultipleChildren(self): + #arrange + tree = LinkCutTree() + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + b1 = tree.makeTree('b1') + tree.link(a1,a2) + tree.link(a1,b1) + #act / assert + self.assertEqual(a1, tree.getRoot(a2)) + self.assertEqual(a1, tree.getRoot(b1)) + + def test_cut__leaves_bothTreesIntact(self): + # arrange + tree = LinkCutTree() + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + a3 = tree.makeTree('a3') + a4 = tree.makeTree('a4') + a5 = tree.makeTree('a5') + a6 = tree.makeTree('a6') + tree.link(a1,a2) + tree.link(a2,a3) + tree.link(a3,a4) + tree.link(a4,a5) + tree.link(a5,a6) + + c1 = tree.makeTree('c1') + c2 = tree.makeTree('c2') + c3 = tree.makeTree('c3') + tree.link(c1,c2) + tree.link(c2,c3) + + tree.link(a6,c1); + + #act + tree.cut(c2) + + #assert + self.assertEqual(a1, tree.getRoot(c1)) + self.assertEqual(c2, tree.getRoot(c3)) + + def test_link__preservesRepresentedParent(self): + ''' after linking all these tree's the represented_parent nodes should represent the paths in the represented tree + ''' + #Arrange + tree = LinkCutTree() + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + a3 = tree.makeTree('a3') + a4 = tree.makeTree('a4') + a5 = tree.makeTree('a5') + a6 = tree.makeTree('a6') + tree.link(a1,a2) + tree.link(a2,a3) + tree.link(a3,a4) + tree.link(a4,a5) + tree.link(a5,a6) + tree.access(a3) + + c1 = tree.makeTree('c1') + c2 = tree.makeTree('c2') + c3 = tree.makeTree('c3') + tree.link(c1,c2) + tree.link(c2,c3) + + tree.link(a4,c1); + + # act/assert + self.assertEqual(a1,a2.represented_parent) + self.assertEqual(a2,a3.represented_parent) + self.assertEqual(a3,a4.represented_parent) + self.assertEqual(a4,a5.represented_parent) + self.assertEqual(a5,a6.represented_parent) + + self.assertEqual(c1,c2.represented_parent) + self.assertEqual(c2,c3.represented_parent) + + self.assertEqual(a4,c1.represented_parent) + + def test_link_threeTrees_getRootShouldBeSameForAll(self): + tree = LinkCutTree() + #Arrange + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + a3 = tree.makeTree('a3') + a4 = tree.makeTree('a4') + a5 = tree.makeTree('a5') + a6 = tree.makeTree('a6') + tree.link(a1,a2) + tree.link(a2,a3) + tree.link(a3,a4) + tree.link(a4,a5) + tree.link(a5,a6) + + b1 = tree.makeTree('b1') + b2 = tree.makeTree('b2') + b3 = tree.makeTree('b3') + + tree.link(b1,b2) + tree.link(b2,b3) + + + tree.link(a3,b1) + + c1 = tree.makeTree('c1') + c2 = tree.makeTree('c2') + c3 = tree.makeTree('c3') + tree.link(c1,c2) + tree.link(c2,c3) + + tree.link(a5,c1) + + # act/assert + self.assertEqual(a1, tree.getRoot(c1)) + self.assertEqual(a1, tree.getRoot(b3)) + + def test_lca_balancedTree_ShouldReturnRoot(self): + #Arrange + tree = LinkCutTree() + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + a3 = tree.makeTree('a3') + tree.link(a1,a2) + tree.link(a1,a3) + + self.assertEqual(tree.lca(a2,a3),a1) + + + def test_lca_path_shouldReturnOlderNode(self): + #Arrange + tree = LinkCutTree() + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + a3 = tree.makeTree('a3') + tree.link(a1,a2) + tree.link(a2,a3) + + self.assertEqual(tree.lca(a2,a3),a2) + + def test_lca_query_order_doesnt_matter(self): + #Arrange + tree = LinkCutTree() + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + a3 = tree.makeTree('a3') + tree.link(a1,a2) + tree.link(a2,a3) + + self.assertEqual(tree.lca(a2,a3),tree.lca(a3,a2)) + + def test_lca_multipleLinks_shouldFindLCA(self): + tree = LinkCutTree() + #Arrange + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + a3 = tree.makeTree('a3') + a4 = tree.makeTree('a4') + a5 = tree.makeTree('a5') + a6 = tree.makeTree('a6') + tree.link(a1,a2) + tree.link(a2,a3) + tree.link(a3,a4) + tree.link(a4,a5) + tree.link(a5,a6) + + b1 = tree.makeTree('b1') + b2 = tree.makeTree('b2') + b3 = tree.makeTree('b3') + + tree.link(b1,b2) + tree.link(b2,b3) + + + tree.link(a3,b1) + + c1 = tree.makeTree('c1') + c2 = tree.makeTree('c2') + c3 = tree.makeTree('c3') + tree.link(c1,c2) + tree.link(c1,c3) + + tree.link(a5,c1) + + + self.assertEqual(tree.lca(c3,b3),a3) + + def test_makeRoot_path_shouldFlipPath(self): + + tree = LinkCutTree() + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + a3 = tree.makeTree('a3') + a4 = tree.makeTree('a4') + a5 = tree.makeTree('a5') + a6 = tree.makeTree('a6') + tree.link(a1,a2,1) + tree.link(a2,a3,2) + tree.link(a3,a4,3) + tree.link(a4,a5,4) + tree.link(a5,a6,5) + + #act + tree.makeRoot(a6) + + #assert + self.assertEqual(a6, tree.getRoot(a1)) + self.assertEqual(a6, a5.represented_parent) + self.assertEqual(a5, a4.represented_parent) + self.assertEqual(a4, a3.represented_parent) + self.assertEqual(a3, a2.represented_parent) + self.assertEqual(a2, a1.represented_parent) + + + + def test_makeRoot_flip_tree(self): + + tree = LinkCutTree() + a1 = tree.makeTree('a1') + a2 = tree.makeTree('a2') + a3 = tree.makeTree('a3') + a4 = tree.makeTree('a4') + a5 = tree.makeTree('a5') + a6 = tree.makeTree('a6') + tree.link(a1,a2,1) + tree.link(a2,a3,2) + tree.link(a3,a4,3) + tree.link(a4,a5,4) + tree.link(a5,a6,5) + + + b1 = tree.makeTree('b1') + b2 = tree.makeTree('b2') + b3 = tree.makeTree('b3') + + tree.link(b1,b2) + tree.link(b2,b3) + + + tree.link(a3,b1) + #act + tree.makeRoot(b3) + + #assert + + self.assertEqual(b3, tree.getRoot(a6)) + + # nodes on the flipped path still share root + self.assertEqual(tree.getRoot(b1), tree.getRoot(a2)) + + # node in flipped path and node out of path share root + self.assertEqual(tree.getRoot(a6), tree.getRoot(a2)) + + self.assertEqual(b2,b1.represented_parent) + self.assertEqual(b3,b2.represented_parent) + self.assertEqual(b1,a3.represented_parent) + self.assertEqual(a3,a2.represented_parent) + self.assertEqual(a2,a1.represented_parent) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_node.py b/tests/test_node.py new file mode 100644 index 0000000..225a70d --- /dev/null +++ b/tests/test_node.py @@ -0,0 +1,106 @@ +import unittest +from retroactive.basic import SplayNode + +class NodeTest(unittest.TestCase): + def test_left_zigzig_splay(self): + #arrange + a1 = SplayNode('a1') + a2 = SplayNode('a2') + a3 = SplayNode('a3') + a1.addRight(a2) + a2.addRight(a3) + #act + a3.splay() + #assert + self.assertEqual(a3.left,a2) + self.assertEqual(a3.right,None) + self.assertEqual(a3.parent, None) + + self.assertEqual(a2.left,a1) + self.assertEqual(a2.right,None) + self.assertEqual(a2.parent,a3) + + self.assertEqual(a1.left,None) + self.assertEqual(a1.right,None) + self.assertEqual(a1.parent,a2) + + + + + def test_left_zigzag_splay(self): + #arrange + a1 = SplayNode('a1') + a2 = SplayNode('a2') + a3 = SplayNode('a3') + a3L = SplayNode('a3L') + a3R = SplayNode('a3R') + a1.addRight(a2) + a2.addLeft(a3) + a3.addLeft(a3L) + a3.addRight(a3R) + #act + a3.splay() + #assert + self.assertEqual(a3.left,a1) + self.assertEqual(a3.right,a2) + + self.assertEqual(a1.right,a3L) + self.assertEqual(a1.left,None) + self.assertEqual(a1.parent,a3) + + self.assertEqual(a2.left,a3R) + self.assertEqual(a2.right,None) + self.assertEqual(a2.parent,a3) + + + def test_right_zigzig_splay(self): + #arrange + a1 = SplayNode('a1') + a2 = SplayNode('a2') + a3 = SplayNode('a3') + a1.addLeft(a2) + a2.addLeft(a3) + #act + a3.splay() + #assert + self.assertEqual(a3.right,a2) + self.assertEqual(a3.left,None) + self.assertEqual(a3.parent, None) + + self.assertEqual(a2.right,a1) + self.assertEqual(a2.left,None) + self.assertEqual(a2.parent,a3) + + self.assertEqual(a1.right,None) + self.assertEqual(a1.left,None) + self.assertEqual(a1.parent,a2) + + + + def test_right_zigzag_splay(self): + #arrange + a1 = SplayNode('a1') + a2 = SplayNode('a2') + a3 = SplayNode('a3') + a3L = SplayNode('a3L') + a3R = SplayNode('a3R') + a1.addLeft(a2) + a2.addRight(a3) + a3.addLeft(a3L) + a3.addRight(a3R) + #act + a3.splay() + #assert + self.assertEqual(a3.right,a1) + self.assertEqual(a3.left,a2) + + self.assertEqual(a1.left,a3R) + self.assertEqual(a1.right,None) + self.assertEqual(a1.parent,a3) + + self.assertEqual(a2.right,a3L) + self.assertEqual(a2.left,None) + self.assertEqual(a2.parent,a3) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_union_find.py b/tests/test_union_find.py new file mode 100644 index 0000000..3863583 --- /dev/null +++ b/tests/test_union_find.py @@ -0,0 +1,162 @@ +import unittest +from retroactive.full import RetroactiveUnionFind + +class Test_union_find(unittest.TestCase): + ''' + for these tests let x denote the edge that was cut + x = edge that was cut + b = the branch that contains x + preX = ancestors of x in b + postX = children of x in b + p = "parent" the part of the tree above the lca of the nodes to be cut + c = the path from the lca to the the other unioned node that does not contain a max edge + ''' + def test_unionAgo_unionAtSameTime_shouldHaveSameRoot(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2') + a1 = set.forest.getNode('a1') + a2 = set.forest.getNode('a2') + + self.assertEqual(set.forest.getRoot(a1),set.forest.getRoot(a2)) + + + def test_unionAgo_unionAtTime_samesetBeforeShouldFail(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2', tdelta = 5) + a1 = set.forest.getNode('a1') + a2 = set.forest.getNode('a2') + + self.assertEqual(set.sameSetAgo('a1','a2',2), False) + + def test_unionAgo_unionAtTime_samesetAfterShouldSucceed(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2', tdelta = 5) + a1 = set.forest.getNode('a1') + a2 = set.forest.getNode('a2') + + self.assertEqual(set.sameSetAgo('a1','a2',7), True) + + def test_unionAgo_unionTwice_unionTimeShouldChange(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2', tdelta = 5) + set.unionAgo('a1','a2', tdelta = 2) + + self.assertEqual(set.sameSetAgo('a1','a2',2), True) + + + def test_unionAgo_unionTwiceSubTree_unionTimeShouldChange(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2', tdelta = 5) + set.unionAgo('a2','a3', tdelta = 14) + set.unionAgo('a3','a4', tdelta = 7) + + set.unionAgo('a2','b1', tdelta = 7) + set.unionAgo('b1','b2', tdelta = 7) + set.unionAgo('b2','b3', tdelta = 7) + + self.assertEqual(set.sameSetAgo('b3','a4',6), False) + set.unionAgo('b3','a4',6) + self.assertEqual(set.sameSetAgo('b3','a4',6), True) + + def test_unionAgo_unionTwice_pToPreXIntact(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2', tdelta = 5) + set.unionAgo('a2','a3', tdelta = 3) + set.unionAgo('a3','a4', tdelta = 7) + set.unionAgo('a4','a5', tdelta = 14) + set.unionAgo('a5','a6', tdelta = 7) + + set.unionAgo('a3','b1', tdelta = 7) + set.unionAgo('b1','b2', tdelta = 7) + set.unionAgo('b2','b3', tdelta = 7) + + set.unionAgo('b3','a6',6) + + # p to preX + self.assertEqual(set.sameSetAgo('a1','a4',7), True) + # check the other branch for safety + self.assertEqual(set.sameSetAgo('a1','b2',7), True) + + + def test_unionAgo_unionTwice_pToPostXIntact(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2', tdelta = 5) + set.unionAgo('a2','a3', tdelta = 3) + set.unionAgo('a3','a4', tdelta = 7) + set.unionAgo('a4','a5', tdelta = 14) + set.unionAgo('a5','a6', tdelta = 7) + set.unionAgo('a6','a7', tdelta = 7) + set.unionAgo('a7','a8', tdelta = 7) + + + set.unionAgo('a3','b1', tdelta = 7) + set.unionAgo('b1','b2', tdelta = 7) + set.unionAgo('b2','b3', tdelta = 7) + + set.unionAgo('b3','a8',6) + + # c to postX + self.assertEqual(set.sameSetAgo('a1','a7',14), True) + + def test_unionAgo_unionTwice_insidePostXIntact(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2', tdelta = 5) + set.unionAgo('a2','a3', tdelta = 3) + set.unionAgo('a3','a4', tdelta = 7) + set.unionAgo('a4','a5', tdelta = 14) + set.unionAgo('a5','a6', tdelta = 7) + set.unionAgo('a6','a7', tdelta = 7) + set.unionAgo('a7','a8', tdelta = 7) + + + set.unionAgo('a3','b1', tdelta = 7) + set.unionAgo('b1','b2', tdelta = 7) + set.unionAgo('b2','b3', tdelta = 7) + + set.unionAgo('b3','a8',6) + + # c to postX + self.assertEqual(set.sameSetAgo('a5','a7',7), True) + + def test_unionAgo_unionTwice_cToPostXIntact(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2', tdelta = 5) + set.unionAgo('a2','a3', tdelta = 3) + set.unionAgo('a3','a4', tdelta = 7) + set.unionAgo('a4','a5', tdelta = 14) + set.unionAgo('a5','a6', tdelta = 7) + + set.unionAgo('a3','b1', tdelta = 7) + set.unionAgo('b1','b2', tdelta = 7) + set.unionAgo('b2','b3', tdelta = 7) + + set.unionAgo('b3','a6',6) + + # c to postX + self.assertEqual(set.sameSetAgo('b2','a5',14), True) + + + def test_unionAgo_unionTwice_cToPreXIntact(self): + set = RetroactiveUnionFind() + set.unionAgo('a1','a2', tdelta = 5) + set.unionAgo('a2','a3', tdelta = 3) + set.unionAgo('a3','a4', tdelta = 7) + set.unionAgo('a4','a5', tdelta = 14) + set.unionAgo('a5','a6', tdelta = 7) + + set.unionAgo('a3','b1', tdelta = 7) + set.unionAgo('b1','b2', tdelta = 7) + set.unionAgo('b2','b3', tdelta = 7) + + + + set.unionAgo('b3','a6',6) + + + # c to postX + self.assertEqual(set.sameSetAgo('b3','a4',7), True) + + + +if __name__ == '__main__': + unittest.main()