Skip to content

Commit cd53326

Browse files
Ramy-Badr-Ahmedpre-commit-ci[bot]cclauss
committed
kd tree data structure implementation (TheAlgorithms#11532)
* Implemented KD-Tree Data Structure * Implemented KD-Tree Data Structure. updated DIRECTORY.md. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Create __init__.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replaced legacy `np.random.rand` call with `np.random.Generator` in kd_tree/example_usage.py * Replaced legacy `np.random.rand` call with `np.random.Generator` in kd_tree/hypercube_points.py * added typehints and docstrings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docstring for search() * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added tests. Updated docstrings/typehints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated tests and used | for type annotations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * E501 for build_kdtree.py, hypercube_points.py, nearest_neighbour_search.py * I001 for example_usage.py and test_kdtree.py * I001 for example_usage.py and test_kdtree.py * Update data_structures/kd_tree/build_kdtree.py Co-authored-by: Christian Clauss <cclauss@me.com> * Update data_structures/kd_tree/example/hypercube_points.py Co-authored-by: Christian Clauss <cclauss@me.com> * Update data_structures/kd_tree/example/hypercube_points.py Co-authored-by: Christian Clauss <cclauss@me.com> * Added new test cases requested in Review. Refactored the test_build_kdtree() to include various checks. 
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Considered ruff errors * Considered ruff errors * Apply suggestions from code review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update kd_node.py * imported annotations from __future__ * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Christian Clauss <cclauss@me.com>
1 parent 7987d53 commit cd53326

10 files changed

+301
-0
lines changed

DIRECTORY.md

+6
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,12 @@
285285
* Trie
286286
* [Radix Tree](data_structures/trie/radix_tree.py)
287287
* [Trie](data_structures/trie/trie.py)
288+
* KD Tree
289+
* [KD Tree Node](data_structures/kd_tree/kd_node.py)
290+
* [Build KD Tree](data_structures/kd_tree/build_kdtree.py)
291+
* [Nearest Neighbour Search](data_structures/kd_tree/nearest_neighbour_search.py)
292+
* [Hypercube Points](data_structures/kd_tree/example/hypercube_points.py)
293+
* [Example Usage](data_structures/kd_tree/example/example_usage.py)
288294

289295
## Digital Image Processing
290296
* [Change Brightness](digital_image_processing/change_brightness.py)

data_structures/kd_tree/__init__.py

Whitespace-only changes.
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from data_structures.kd_tree.kd_node import KDNode
2+
3+
4+
def build_kdtree(points: list[list[float]], depth: int = 0) -> KDNode | None:
5+
"""
6+
Builds a KD-Tree from a list of points.
7+
8+
Args:
9+
points: The list of points to build the KD-Tree from.
10+
depth: The current depth in the tree
11+
(used to determine axis for splitting).
12+
13+
Returns:
14+
The root node of the KD-Tree,
15+
or None if no points are provided.
16+
"""
17+
if not points:
18+
return None
19+
20+
k = len(points[0]) # Dimensionality of the points
21+
axis = depth % k
22+
23+
# Sort point list and choose median as pivot element
24+
points.sort(key=lambda point: point[axis])
25+
median_idx = len(points) // 2
26+
27+
# Create node and construct subtrees
28+
left_points = points[:median_idx]
29+
right_points = points[median_idx + 1 :]
30+
31+
return KDNode(
32+
point=points[median_idx],
33+
left=build_kdtree(left_points, depth + 1),
34+
right=build_kdtree(right_points, depth + 1),
35+
)

data_structures/kd_tree/example/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import numpy as np
2+
3+
from data_structures.kd_tree.build_kdtree import build_kdtree
4+
from data_structures.kd_tree.example.hypercube_points import hypercube_points
5+
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
6+
7+
8+
def main() -> None:
    """
    Demonstrates the use of KD-Tree by building it from random points
    in a 10-dimensional hypercube and performing a nearest neighbor search.
    """
    num_points: int = 5000
    cube_size: float = 10.0  # Size of the hypercube (edge length)
    num_dimensions: int = 10

    # Generate random points within the hypercube
    points: np.ndarray = hypercube_points(num_points, cube_size, num_dimensions)
    hypercube_kdtree = build_kdtree(points.tolist())

    # Generate a random query point spanning the same hypercube as the data
    # (rng.random alone only covers the unit cube, so scale by the edge length)
    rng = np.random.default_rng()
    query_point: list[float] = (cube_size * rng.random(num_dimensions)).tolist()

    # Perform nearest neighbor search; the search reports a *squared* distance
    nearest_point, nearest_sq_dist, nodes_visited = nearest_neighbour_search(
        hypercube_kdtree, query_point
    )

    # Print the results (square root converts to the Euclidean distance)
    print(f"Query point: {query_point}")
    print(f"Nearest point: {nearest_point}")
    print(f"Distance: {nearest_sq_dist**0.5:.4f}")
    print(f"Nodes visited: {nodes_visited}")


if __name__ == "__main__":
    main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
3+
4+
def hypercube_points(
    num_points: int, hypercube_size: float, num_dimensions: int
) -> np.ndarray:
    """
    Draws random points from an n-dimensional hypercube.

    Samples are taken from a freshly created NumPy ``Generator`` via
    ``random`` and then multiplied by the edge length of the hypercube.

    Args:
        num_points: How many points to draw.
        hypercube_size: Edge length used to scale the unit samples.
        num_dimensions: Number of coordinates per point.

    Returns:
        An array of shape (num_points, num_dimensions)
        holding the scaled samples.
    """
    generator = np.random.default_rng()
    unit_samples = generator.random(size=(num_points, num_dimensions))
    return unit_samples * hypercube_size

data_structures/kd_tree/kd_node.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from __future__ import annotations
2+
3+
4+
class KDNode:
    """
    A single node of a KD-Tree.

    Attributes:
        point: The coordinates held by this node.
        left: The child node on the left, or None.
        right: The child node on the right, or None.
    """

    def __init__(
        self,
        point: list[float],
        left: KDNode | None = None,
        right: KDNode | None = None,
    ) -> None:
        """
        Stores the point and the two optional children on the new node.

        Args:
            point: The coordinates held by this node.
            left: The child node on the left (defaults to None).
            right: The child node on the right (defaults to None).
        """
        self.point, self.left, self.right = point, left, right
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from data_structures.kd_tree.kd_node import KDNode
2+
3+
4+
def nearest_neighbour_search(
5+
root: KDNode | None, query_point: list[float]
6+
) -> tuple[list[float] | None, float, int]:
7+
"""
8+
Performs a nearest neighbor search in a KD-Tree for a given query point.
9+
10+
Args:
11+
root (KDNode | None): The root node of the KD-Tree.
12+
query_point (list[float]): The point for which the nearest neighbor
13+
is being searched.
14+
15+
Returns:
16+
tuple[list[float] | None, float, int]:
17+
- The nearest point found in the KD-Tree to the query point,
18+
or None if no point is found.
19+
- The squared distance to the nearest point.
20+
- The number of nodes visited during the search.
21+
"""
22+
nearest_point: list[float] | None = None
23+
nearest_dist: float = float("inf")
24+
nodes_visited: int = 0
25+
26+
def search(node: KDNode | None, depth: int = 0) -> None:
27+
"""
28+
Recursively searches for the nearest neighbor in the KD-Tree.
29+
30+
Args:
31+
node: The current node in the KD-Tree.
32+
depth: The current depth in the KD-Tree.
33+
"""
34+
nonlocal nearest_point, nearest_dist, nodes_visited
35+
if node is None:
36+
return
37+
38+
nodes_visited += 1
39+
40+
# Calculate the current distance (squared distance)
41+
current_point = node.point
42+
current_dist = sum(
43+
(query_coord - point_coord) ** 2
44+
for query_coord, point_coord in zip(query_point, current_point)
45+
)
46+
47+
# Update nearest point if the current node is closer
48+
if nearest_point is None or current_dist < nearest_dist:
49+
nearest_point = current_point
50+
nearest_dist = current_dist
51+
52+
# Determine which subtree to search first (based on axis and query point)
53+
k = len(query_point) # Dimensionality of points
54+
axis = depth % k
55+
56+
if query_point[axis] <= current_point[axis]:
57+
nearer_subtree = node.left
58+
further_subtree = node.right
59+
else:
60+
nearer_subtree = node.right
61+
further_subtree = node.left
62+
63+
# Search the nearer subtree first
64+
search(nearer_subtree, depth + 1)
65+
66+
# If the further subtree has a closer point
67+
if (query_point[axis] - current_point[axis]) ** 2 < nearest_dist:
68+
search(further_subtree, depth + 1)
69+
70+
search(root, 0)
71+
return nearest_point, nearest_dist, nodes_visited

data_structures/kd_tree/tests/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import numpy as np
2+
import pytest
3+
4+
from data_structures.kd_tree.build_kdtree import build_kdtree
5+
from data_structures.kd_tree.example.hypercube_points import hypercube_points
6+
from data_structures.kd_tree.kd_node import KDNode
7+
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
8+
9+
10+
@pytest.mark.parametrize(
    ("num_points", "cube_size", "num_dimensions", "depth", "expected_result"),
    [
        (0, 10.0, 2, 0, None),  # Empty points list
        (10, 10.0, 2, 2, KDNode),  # Depth = 2, 2D points
        (10, 10.0, 3, -2, KDNode),  # Depth = -2, 3D points
    ],
)
def test_build_kdtree(num_points, cube_size, num_dimensions, depth, expected_result):
    """
    Test that KD-Tree is built correctly.

    Cases:
        - Empty points list.
        - Positive depth value.
        - Negative depth value.

    For non-empty inputs the root must be of the expected type, hold a point
    of the requested dimensionality, and the tree must contain exactly one
    node per input point.
    """
    points = (
        hypercube_points(num_points, cube_size, num_dimensions).tolist()
        if num_points > 0
        else []
    )

    kdtree = build_kdtree(points, depth=depth)

    if expected_result is None:
        # Empty points list case
        assert kdtree is None, f"Expected None for empty points list, got {kdtree}"
        return

    # Check the type first so the attribute accesses below are meaningful
    assert isinstance(
        kdtree, expected_result
    ), f"Expected KDNode instance, got {type(kdtree)}"

    # Check if root has correct dimensions
    assert (
        len(kdtree.point) == num_dimensions
    ), f"Expected point dimension {num_dimensions}, got {len(kdtree.point)}"

    def count_nodes(node):
        """Count the nodes reachable from ``node`` (0 for an empty subtree)."""
        if node is None:
            return 0
        return 1 + count_nodes(node.left) + count_nodes(node.right)

    # Every input point must end up in exactly one node of the tree
    node_count = count_nodes(kdtree)
    assert (
        node_count == num_points
    ), f"Expected {num_points} nodes in the tree, got {node_count}"
51+
52+
53+
def test_nearest_neighbour_search():
    """
    Test the nearest neighbor search function against a brute-force scan.
    """
    num_points = 10
    cube_size = 10.0
    num_dimensions = 2
    points = hypercube_points(num_points, cube_size, num_dimensions).tolist()
    kdtree = build_kdtree(points)

    # Draw the query from the same hypercube as the data points
    rng = np.random.default_rng()
    query_point = (cube_size * rng.random(num_dimensions)).tolist()

    nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
        kdtree, query_point
    )

    # Reference answer: smallest squared distance over all points
    expected_dist = min(
        sum((q - p) ** 2 for q, p in zip(query_point, point)) for point in points
    )

    # The reported point must be one of the inputs
    assert nearest_point in points

    # The reported (squared) distance must match the brute-force minimum
    assert nearest_dist == pytest.approx(expected_dist)

    # A non-empty tree visits at least the root and never more than every node
    assert 1 <= nodes_visited <= num_points
78+
79+
80+
def test_edge_cases():
    """
    Searching an empty KD-Tree must report no point, an infinite distance
    and zero visited nodes.
    """
    # Building from no points gives an empty tree; query with a 2D origin
    found, squared_dist, visited = nearest_neighbour_search(
        build_kdtree([]), [0.0, 0.0]
    )

    assert found is None
    assert squared_dist == float("inf")
    assert visited == 0
95+
96+
97+
if __name__ == "__main__":
    # pytest is already imported at the top of this module. Pass this file
    # explicitly so that running it directly executes only these tests
    # instead of whatever pytest would collect by default.
    pytest.main([__file__])

0 commit comments

Comments
 (0)