From 915d0747b6fc8d2f78a0891252ffc3d2a9971649 Mon Sep 17 00:00:00 2001
From: Hao Li
Date: Thu, 3 Dec 2020 15:38:58 +0800
Subject: [PATCH 1/3] fix bug: edge case of avl delete
---
data_structures/binary_tree/avl_tree.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py
index 3362610b9303..3239d7e092a7 100644
--- a/data_structures/binary_tree/avl_tree.py
+++ b/data_structures/binary_tree/avl_tree.py
@@ -204,14 +204,14 @@ def del_node(root, data):
if root is None:
return root
if get_height(root.get_right()) - get_height(root.get_left()) == 2:
- if get_height(root.get_right().get_right()) > get_height(
+ if get_height(root.get_right().get_right()) >= get_height(
root.get_right().get_left()
):
root = left_rotation(root)
else:
root = rl_rotation(root)
elif get_height(root.get_right()) - get_height(root.get_left()) == -2:
- if get_height(root.get_left().get_left()) > get_height(
+ if get_height(root.get_left().get_left()) >= get_height(
root.get_left().get_right()
):
root = right_rotation(root)
From 6750484b2eddd7e90b2fe0e5e6448d1a8ba47678 Mon Sep 17 00:00:00 2001
From: Hao Li
Date: Fri, 11 Dec 2020 13:09:51 +0800
Subject: [PATCH 2/3] add type hints
---
data_structures/binary_tree/avl_tree.py | 61 +++++++++++++------------
1 file changed, 32 insertions(+), 29 deletions(-)
diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py
index 3239d7e092a7..3b173b9194e5 100644
--- a/data_structures/binary_tree/avl_tree.py
+++ b/data_structures/binary_tree/avl_tree.py
@@ -8,6 +8,7 @@
import math
import random
+from typing import Any
class my_queue:
@@ -16,76 +17,76 @@ def __init__(self):
self.head = 0
self.tail = 0
- def is_empty(self):
+ def is_empty(self) -> bool:
return self.head == self.tail
- def push(self, data):
+ def push(self, data: Any) -> None:
self.data.append(data)
self.tail = self.tail + 1
- def pop(self):
+ def pop(self) -> Any:
ret = self.data[self.head]
self.head = self.head + 1
return ret
- def count(self):
+ def count(self) -> int:
return self.tail - self.head
- def print(self):
+ def print(self) -> None:
print(self.data)
print("**************")
print(self.data[self.head : self.tail])
class my_node:
- def __init__(self, data):
+ def __init__(self, data: Any):
self.data = data
self.left = None
self.right = None
self.height = 1
- def get_data(self):
+ def get_data(self) -> Any:
return self.data
- def get_left(self):
+ def get_left(self) -> "my_node":
return self.left
- def get_right(self):
+ def get_right(self) -> "my_node":
return self.right
- def get_height(self):
+ def get_height(self) -> int:
return self.height
- def set_data(self, data):
+ def set_data(self, data: Any) -> None:
self.data = data
return
- def set_left(self, node):
+ def set_left(self, node: "my_node") -> None:
self.left = node
return
- def set_right(self, node):
+ def set_right(self, node: "my_node") -> None:
self.right = node
return
- def set_height(self, height):
+ def set_height(self, height) -> None:
self.height = height
return
-def get_height(node):
+def get_height(node: "my_node") -> int:
if node is None:
return 0
return node.get_height()
-def my_max(a, b):
+def my_max(a: Any, b: Any) -> Any:
if a > b:
return a
return b
-def right_rotation(node):
+def right_rotation(node: "my_node") -> "my_node":
r"""
A B
/ \ / \
@@ -107,7 +108,7 @@ def right_rotation(node):
return ret
-def left_rotation(node):
+def left_rotation(node: "my_node") -> "my_node":
"""
a mirror symmetry rotation of the left_rotation
"""
@@ -122,7 +123,7 @@ def left_rotation(node):
return ret
-def lr_rotation(node):
+def lr_rotation(node: "my_node") -> "my_node":
r"""
A A Br
/ \ / \ / \
@@ -137,12 +138,12 @@ def lr_rotation(node):
return right_rotation(node)
-def rl_rotation(node):
+def rl_rotation(node: "my_node") -> "my_node":
node.set_right(right_rotation(node.get_right()))
return left_rotation(node)
-def insert_node(node, data):
+def insert_node(node: "my_node", data: Any) -> "my_node":
if node is None:
return my_node(data)
if data < node.get_data():
@@ -168,19 +169,19 @@ def insert_node(node, data):
return node
-def get_rightMost(root):
+def get_rightMost(root: "my_node") -> "my_node":
while root.get_right() is not None:
root = root.get_right()
return root.get_data()
-def get_leftMost(root):
+def get_leftMost(root: "my_node") -> "my_node":
while root.get_left() is not None:
root = root.get_left()
return root.get_data()
-def del_node(root, data):
+def del_node(root: "my_node", data: Any) -> "my_node":
if root.get_data() == data:
if root.get_left() is not None and root.get_right() is not None:
temp_data = get_leftMost(root.get_right())
@@ -259,22 +260,24 @@ class AVLtree:
def __init__(self):
self.root = None
- def get_height(self):
- # print("yyy")
+ def get_height(self) -> int:
return get_height(self.root)
- def insert(self, data):
+ def insert(self, data: Any) -> None:
print("insert:" + str(data))
self.root = insert_node(self.root, data)
- def del_node(self, data):
+ def del_node(self, data: Any) -> None:
print("delete:" + str(data))
if self.root is None:
print("Tree is empty!")
return
self.root = del_node(self.root, data)
- def __str__(self): # a level traversale, gives a more intuitive look on the tree
+ def __str__(self):
+ """
+ A level traversale, gives a more intuitive look on the tree
+ """
output = ""
q = my_queue()
q.push(self.root)
From 76f9128542eecbb814399d01442563f02daf49f5 Mon Sep 17 00:00:00 2001
From: Hao Li
Date: Fri, 11 Dec 2020 17:16:26 +0800
Subject: [PATCH 3/3] add test cases for AVL
---
data_structures/binary_tree/avl_tree.py | 71 ++++++++++++++++++-------
1 file changed, 51 insertions(+), 20 deletions(-)
diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py
index 3b173b9194e5..6227ba876957 100644
--- a/data_structures/binary_tree/avl_tree.py
+++ b/data_structures/binary_tree/avl_tree.py
@@ -8,11 +8,12 @@
import math
import random
+import unittest
from typing import Any
class my_queue:
- def __init__(self):
+ def __init__(self) -> None:
self.data = []
self.head = 0
self.tail = 0
@@ -39,7 +40,7 @@ def print(self) -> None:
class my_node:
- def __init__(self, data: Any):
+ def __init__(self, data: Any) -> None:
self.data = data
self.left = None
self.right = None
@@ -69,7 +70,7 @@ def set_right(self, node: "my_node") -> None:
self.right = node
return
- def set_height(self, height) -> None:
+ def set_height(self, height: int) -> None:
self.height = height
return
@@ -257,7 +258,7 @@ class AVLtree:
*************************************
"""
- def __init__(self):
+ def __init__(self) -> None:
self.root = None
def get_height(self) -> int:
@@ -274,7 +275,7 @@ def del_node(self, data: Any) -> None:
return
self.root = del_node(self.root, data)
- def __str__(self):
+ def __str__(self) -> str:
"""
A level traversale, gives a more intuitive look on the tree
"""
@@ -311,21 +312,51 @@ def __str__(self):
return output
-def _test():
- import doctest
-
- doctest.testmod()
+class Test(unittest.TestCase):
+ def _is_balance(self, avl: AVLtree):
+ if avl.root is None:
+ return True
+ dfs = [avl.root]
+ while dfs:
+ now = dfs.pop()
+ if now.left:
+ left_height = now.left.height
+ dfs.append(now.left)
+ else:
+ left_height = 0
+ if now.right:
+ right_height = now.right.height
+ dfs.append(now.right)
+ else:
+ right_height = 0
+ if abs(left_height - right_height) > 1:
+ return False
+ return True
+
+ def test_delete(self):
+ avl = AVLtree()
+ for i in [8, 7, 4, 3, 9, 10, 11, 13, 6, 0, 2, 12, 1, 14, 5]:
+ avl.insert(i)
+ self.assertTrue(self._is_balance(avl))
+
+ for v in [8, 7, 4, 3, 9, 10, 11, 13, 6, 0, 2, 12, 1, 14, 5]:
+ avl.del_node(v)
+ print(avl)
+ self.assertTrue(self._is_balance(avl))
+
+ def test_delete_random(self):
+ avl = AVLtree()
+ random.seed(0)
+ values = list(range(1000))
+ random.shuffle(values)
+ for i in values:
+ avl.insert(i)
+ self.assertTrue(self._is_balance(avl))
+ random.shuffle(values)
+ for i in values:
+ avl.del_node(i)
+ self.assertTrue(self._is_balance(avl))
if __name__ == "__main__":
- _test()
- t = AVLtree()
- lst = list(range(10))
- random.shuffle(lst)
- for i in lst:
- t.insert(i)
- print(str(t))
- random.shuffle(lst)
- for i in lst:
- t.del_node(i)
- print(str(t))
+ unittest.main()