From 915d0747b6fc8d2f78a0891252ffc3d2a9971649 Mon Sep 17 00:00:00 2001 From: Hao Li Date: Thu, 3 Dec 2020 15:38:58 +0800 Subject: [PATCH 1/3] fix bug: edge case of avl delete --- data_structures/binary_tree/avl_tree.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py index 3362610b9303..3239d7e092a7 100644 --- a/data_structures/binary_tree/avl_tree.py +++ b/data_structures/binary_tree/avl_tree.py @@ -204,14 +204,14 @@ def del_node(root, data): if root is None: return root if get_height(root.get_right()) - get_height(root.get_left()) == 2: - if get_height(root.get_right().get_right()) > get_height( + if get_height(root.get_right().get_right()) >= get_height( root.get_right().get_left() ): root = left_rotation(root) else: root = rl_rotation(root) elif get_height(root.get_right()) - get_height(root.get_left()) == -2: - if get_height(root.get_left().get_left()) > get_height( + if get_height(root.get_left().get_left()) >= get_height( root.get_left().get_right() ): root = right_rotation(root) From 6750484b2eddd7e90b2fe0e5e6448d1a8ba47678 Mon Sep 17 00:00:00 2001 From: Hao Li Date: Fri, 11 Dec 2020 13:09:51 +0800 Subject: [PATCH 2/3] add type hints --- data_structures/binary_tree/avl_tree.py | 61 +++++++++++++------------ 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py index 3239d7e092a7..3b173b9194e5 100644 --- a/data_structures/binary_tree/avl_tree.py +++ b/data_structures/binary_tree/avl_tree.py @@ -8,6 +8,7 @@ import math import random +from typing import Any class my_queue: @@ -16,76 +17,76 @@ def __init__(self): self.head = 0 self.tail = 0 - def is_empty(self): + def is_empty(self) -> bool: return self.head == self.tail - def push(self, data): + def push(self, data: Any) -> None: self.data.append(data) self.tail = self.tail + 1 - def pop(self): + def pop(self) -> Any: ret = self.data[self.head] self.head = self.head + 1 return ret - def count(self): + def count(self) -> int: return self.tail - self.head - def print(self): + def print(self) -> None: print(self.data) print("**************") print(self.data[self.head : self.tail]) class my_node: - def __init__(self, data): + def __init__(self, data: Any): self.data = data self.left = None self.right = None self.height = 1 - def get_data(self): + def get_data(self) -> Any: return self.data - def get_left(self): + def get_left(self) -> "my_node": return self.left - def get_right(self): + def get_right(self) -> "my_node": return self.right - def get_height(self): + def get_height(self) -> int: return self.height - def set_data(self, data): + def set_data(self, data: Any) -> None: self.data = data return - def set_left(self, node): + def set_left(self, node: "my_node") -> None: self.left = node return - def set_right(self, node): + def set_right(self, node: "my_node") -> None: self.right = node return - def set_height(self, height): + def set_height(self, height) -> None: self.height = height return -def get_height(node): +def get_height(node: "my_node") -> int: if node is None: return 0 return node.get_height() -def my_max(a, b): +def my_max(a: Any, b: Any) -> Any: if a > b: return a return b -def right_rotation(node): +def right_rotation(node: "my_node") -> "my_node": r""" A B / \ / \ @@ -107,7 +108,7 @@ def right_rotation(node): return ret -def left_rotation(node): +def left_rotation(node: "my_node") -> "my_node": """ a mirror symmetry rotation of the left_rotation """ @@ -122,7 +123,7 @@ def left_rotation(node): return ret -def lr_rotation(node): +def lr_rotation(node: "my_node") -> "my_node": r""" A A Br / \ / \ / \ @@ -137,12 +138,12 @@ def lr_rotation(node): return right_rotation(node) -def rl_rotation(node): +def rl_rotation(node: "my_node") -> "my_node": node.set_right(right_rotation(node.get_right())) return left_rotation(node) -def insert_node(node, data): +def insert_node(node: "my_node", data: Any) -> "my_node": if node is None: return my_node(data) if data < node.get_data(): @@ -168,19 +169,19 @@ def insert_node(node, data): return node -def get_rightMost(root): +def get_rightMost(root: "my_node") -> "my_node": while root.get_right() is not None: root = root.get_right() return root.get_data() -def get_leftMost(root): +def get_leftMost(root: "my_node") -> "my_node": while root.get_left() is not None: root = root.get_left() return root.get_data() -def del_node(root, data): +def del_node(root: "my_node", data: Any) -> "my_node": if root.get_data() == data: if root.get_left() is not None and root.get_right() is not None: temp_data = get_leftMost(root.get_right()) @@ -259,22 +260,24 @@ class AVLtree: def __init__(self): self.root = None - def get_height(self): - # print("yyy") + def get_height(self) -> int: return get_height(self.root) - def insert(self, data): + def insert(self, data: Any) -> None: print("insert:" + str(data)) self.root = insert_node(self.root, data) - def del_node(self, data): + def del_node(self, data: Any) -> None: print("delete:" + str(data)) if self.root is None: print("Tree is empty!") return self.root = del_node(self.root, data) - def __str__(self): # a level traversale, gives a more intuitive look on the tree + def __str__(self): + """ + A level traversale, gives a more intuitive look on the tree + """ output = "" q = my_queue() q.push(self.root) From 76f9128542eecbb814399d01442563f02daf49f5 Mon Sep 17 00:00:00 2001 From: Hao Li Date: Fri, 11 Dec 2020 17:16:26 +0800 Subject: [PATCH 3/3] add test cases for AVL --- data_structures/binary_tree/avl_tree.py | 71 ++++++++++++++++++------- 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py index 3b173b9194e5..6227ba876957 100644 --- a/data_structures/binary_tree/avl_tree.py +++ b/data_structures/binary_tree/avl_tree.py @@ -8,11 +8,12 @@ import math import random +import unittest from typing import Any class my_queue: - def __init__(self): + def __init__(self) -> None: self.data = [] self.head = 0 self.tail = 0 @@ -39,7 +40,7 @@ def print(self) -> None: class my_node: - def __init__(self, data: Any): + def __init__(self, data: Any) -> None: self.data = data self.left = None self.right = None @@ -69,7 +70,7 @@ def set_right(self, node: "my_node") -> None: self.right = node return - def set_height(self, height) -> None: + def set_height(self, height: int) -> None: self.height = height return @@ -257,7 +258,7 @@ class AVLtree: ************************************* """ - def __init__(self): + def __init__(self) -> None: self.root = None def get_height(self) -> int: @@ -274,7 +275,7 @@ def del_node(self, data: Any) -> None: return self.root = del_node(self.root, data) - def __str__(self): + def __str__(self) -> str: """ A level traversale, gives a more intuitive look on the tree """ @@ -311,21 +312,51 @@ def __str__(self): return output -def _test(): - import doctest - - doctest.testmod() +class Test(unittest.TestCase): + def _is_balance(self, avl: AVLtree): + if avl.root is None: + return True + dfs = [avl.root] + while dfs: + now = dfs.pop() + if now.left: + left_height = now.left.height + dfs.append(now.left) + else: + left_height = 0 + if now.right: + right_height = now.right.height + dfs.append(now.right) + else: + right_height = 0 + if abs(left_height - right_height) > 1: + return False + return True + + def test_delete(self): + avl = AVLtree() + for i in [8, 7, 4, 3, 9, 10, 11, 13, 6, 0, 2, 12, 1, 14, 5]: + avl.insert(i) + self.assertTrue(self._is_balance(avl)) + + for v in [8, 7, 4, 3, 9, 10, 11, 13, 6, 0, 2, 12, 1, 14, 5]: + avl.del_node(v) + print(avl) + self.assertTrue(self._is_balance(avl)) + + def test_delete_random(self): + avl = AVLtree() + random.seed(0) + values = list(range(1000)) + random.shuffle(values) + for i in values: + avl.insert(i) + self.assertTrue(self._is_balance(avl)) + random.shuffle(values) + for i in values: + avl.del_node(i) + self.assertTrue(self._is_balance(avl)) if __name__ == "__main__": - _test() - t = AVLtree() - lst = list(range(10)) - random.shuffle(lst) - for i in lst: - t.insert(i) - print(str(t)) - random.shuffle(lst) - for i in lst: - t.del_node(i) - print(str(t)) + unittest.main()