forked from TheAlgorithms/Python
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
kd tree data structure implementation (TheAlgorithms#11532)
* Implemented KD-Tree Data Structure * Implemented KD-Tree Data Structure. updated DIRECTORY.md. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Create __init__.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replaced legacy `np.random.rand` call with `np.random.Generator` in kd_tree/example_usage.py * Replaced legacy `np.random.rand` call with `np.random.Generator` in kd_tree/hypercube_points.py * added typehints and docstrings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docstring for search() * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added tests. Updated docstrings/typehints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated tests and used | for type annotations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * E501 for build_kdtree.py, hypercube_points.py, nearest_neighbour_search.py * I001 for example_usage.py and test_kdtree.py * I001 for example_usage.py and test_kdtree.py * Update data_structures/kd_tree/build_kdtree.py Co-authored-by: Christian Clauss <cclauss@me.com> * Update data_structures/kd_tree/example/hypercube_points.py Co-authored-by: Christian Clauss <cclauss@me.com> * Update data_structures/kd_tree/example/hypercube_points.py Co-authored-by: Christian Clauss <cclauss@me.com> * Added new test cases requested in Review. Refactored the test_build_kdtree() to include various checks. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Considered ruff errors * Considered ruff errors * Apply suggestions from code review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update kd_node.py * imported annotations from __future__ * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Christian Clauss <cclauss@me.com>
- Loading branch information
1 parent
cda92e1
commit 883afdf
Showing
10 changed files
with
301 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from data_structures.kd_tree.kd_node import KDNode | ||
|
||
|
||
def build_kdtree(points: list[list[float]], depth: int = 0) -> KDNode | None: | ||
""" | ||
Builds a KD-Tree from a list of points. | ||
Args: | ||
points: The list of points to build the KD-Tree from. | ||
depth: The current depth in the tree | ||
(used to determine axis for splitting). | ||
Returns: | ||
The root node of the KD-Tree, | ||
or None if no points are provided. | ||
""" | ||
if not points: | ||
return None | ||
|
||
k = len(points[0]) # Dimensionality of the points | ||
axis = depth % k | ||
|
||
# Sort point list and choose median as pivot element | ||
points.sort(key=lambda point: point[axis]) | ||
median_idx = len(points) // 2 | ||
|
||
# Create node and construct subtrees | ||
left_points = points[:median_idx] | ||
right_points = points[median_idx + 1 :] | ||
|
||
return KDNode( | ||
point=points[median_idx], | ||
left=build_kdtree(left_points, depth + 1), | ||
right=build_kdtree(right_points, depth + 1), | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import numpy as np | ||
|
||
from data_structures.kd_tree.build_kdtree import build_kdtree | ||
from data_structures.kd_tree.example.hypercube_points import hypercube_points | ||
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search | ||
|
||
|
||
def main() -> None: | ||
""" | ||
Demonstrates the use of KD-Tree by building it from random points | ||
in a 10-dimensional hypercube and performing a nearest neighbor search. | ||
""" | ||
num_points: int = 5000 | ||
cube_size: float = 10.0 # Size of the hypercube (edge length) | ||
num_dimensions: int = 10 | ||
|
||
# Generate random points within the hypercube | ||
points: np.ndarray = hypercube_points(num_points, cube_size, num_dimensions) | ||
hypercube_kdtree = build_kdtree(points.tolist()) | ||
|
||
# Generate a random query point within the same space | ||
rng = np.random.default_rng() | ||
query_point: list[float] = rng.random(num_dimensions).tolist() | ||
|
||
# Perform nearest neighbor search | ||
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search( | ||
hypercube_kdtree, query_point | ||
) | ||
|
||
# Print the results | ||
print(f"Query point: {query_point}") | ||
print(f"Nearest point: {nearest_point}") | ||
print(f"Distance: {nearest_dist:.4f}") | ||
print(f"Nodes visited: {nodes_visited}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import numpy as np | ||
|
||
|
||
def hypercube_points( | ||
num_points: int, hypercube_size: float, num_dimensions: int | ||
) -> np.ndarray: | ||
""" | ||
Generates random points uniformly distributed within an n-dimensional hypercube. | ||
Args: | ||
num_points: Number of points to generate. | ||
hypercube_size: Size of the hypercube. | ||
num_dimensions: Number of dimensions of the hypercube. | ||
Returns: | ||
An array of shape (num_points, num_dimensions) | ||
with generated points. | ||
""" | ||
rng = np.random.default_rng() | ||
shape = (num_points, num_dimensions) | ||
return hypercube_size * rng.random(shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from __future__ import annotations | ||
|
||
|
||
class KDNode: | ||
""" | ||
Represents a node in a KD-Tree. | ||
Attributes: | ||
point: The point stored in this node. | ||
left: The left child node. | ||
right: The right child node. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
point: list[float], | ||
left: KDNode | None = None, | ||
right: KDNode | None = None, | ||
) -> None: | ||
""" | ||
Initializes a KDNode with the given point and child nodes. | ||
Args: | ||
point (list[float]): The point stored in this node. | ||
left (Optional[KDNode]): The left child node. | ||
right (Optional[KDNode]): The right child node. | ||
""" | ||
self.point = point | ||
self.left = left | ||
self.right = right |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from data_structures.kd_tree.kd_node import KDNode | ||
|
||
|
||
def nearest_neighbour_search( | ||
root: KDNode | None, query_point: list[float] | ||
) -> tuple[list[float] | None, float, int]: | ||
""" | ||
Performs a nearest neighbor search in a KD-Tree for a given query point. | ||
Args: | ||
root (KDNode | None): The root node of the KD-Tree. | ||
query_point (list[float]): The point for which the nearest neighbor | ||
is being searched. | ||
Returns: | ||
tuple[list[float] | None, float, int]: | ||
- The nearest point found in the KD-Tree to the query point, | ||
or None if no point is found. | ||
- The squared distance to the nearest point. | ||
- The number of nodes visited during the search. | ||
""" | ||
nearest_point: list[float] | None = None | ||
nearest_dist: float = float("inf") | ||
nodes_visited: int = 0 | ||
|
||
def search(node: KDNode | None, depth: int = 0) -> None: | ||
""" | ||
Recursively searches for the nearest neighbor in the KD-Tree. | ||
Args: | ||
node: The current node in the KD-Tree. | ||
depth: The current depth in the KD-Tree. | ||
""" | ||
nonlocal nearest_point, nearest_dist, nodes_visited | ||
if node is None: | ||
return | ||
|
||
nodes_visited += 1 | ||
|
||
# Calculate the current distance (squared distance) | ||
current_point = node.point | ||
current_dist = sum( | ||
(query_coord - point_coord) ** 2 | ||
for query_coord, point_coord in zip(query_point, current_point) | ||
) | ||
|
||
# Update nearest point if the current node is closer | ||
if nearest_point is None or current_dist < nearest_dist: | ||
nearest_point = current_point | ||
nearest_dist = current_dist | ||
|
||
# Determine which subtree to search first (based on axis and query point) | ||
k = len(query_point) # Dimensionality of points | ||
axis = depth % k | ||
|
||
if query_point[axis] <= current_point[axis]: | ||
nearer_subtree = node.left | ||
further_subtree = node.right | ||
else: | ||
nearer_subtree = node.right | ||
further_subtree = node.left | ||
|
||
# Search the nearer subtree first | ||
search(nearer_subtree, depth + 1) | ||
|
||
# If the further subtree has a closer point | ||
if (query_point[axis] - current_point[axis]) ** 2 < nearest_dist: | ||
search(further_subtree, depth + 1) | ||
|
||
search(root, 0) | ||
return nearest_point, nearest_dist, nodes_visited |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from data_structures.kd_tree.build_kdtree import build_kdtree | ||
from data_structures.kd_tree.example.hypercube_points import hypercube_points | ||
from data_structures.kd_tree.kd_node import KDNode | ||
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("num_points", "cube_size", "num_dimensions", "depth", "expected_result"), | ||
[ | ||
(0, 10.0, 2, 0, None), # Empty points list | ||
(10, 10.0, 2, 2, KDNode), # Depth = 2, 2D points | ||
(10, 10.0, 3, -2, KDNode), # Depth = -2, 3D points | ||
], | ||
) | ||
def test_build_kdtree(num_points, cube_size, num_dimensions, depth, expected_result): | ||
""" | ||
Test that KD-Tree is built correctly. | ||
Cases: | ||
- Empty points list. | ||
- Positive depth value. | ||
- Negative depth value. | ||
""" | ||
points = ( | ||
hypercube_points(num_points, cube_size, num_dimensions).tolist() | ||
if num_points > 0 | ||
else [] | ||
) | ||
|
||
kdtree = build_kdtree(points, depth=depth) | ||
|
||
if expected_result is None: | ||
# Empty points list case | ||
assert kdtree is None, f"Expected None for empty points list, got {kdtree}" | ||
else: | ||
# Check if root node is not None | ||
assert kdtree is not None, "Expected a KDNode, got None" | ||
|
||
# Check if root has correct dimensions | ||
assert ( | ||
len(kdtree.point) == num_dimensions | ||
), f"Expected point dimension {num_dimensions}, got {len(kdtree.point)}" | ||
|
||
# Check that the tree is balanced to some extent (simplistic check) | ||
assert isinstance( | ||
kdtree, KDNode | ||
), f"Expected KDNode instance, got {type(kdtree)}" | ||
|
||
|
||
def test_nearest_neighbour_search(): | ||
""" | ||
Test the nearest neighbor search function. | ||
""" | ||
num_points = 10 | ||
cube_size = 10.0 | ||
num_dimensions = 2 | ||
points = hypercube_points(num_points, cube_size, num_dimensions) | ||
kdtree = build_kdtree(points.tolist()) | ||
|
||
rng = np.random.default_rng() | ||
query_point = rng.random(num_dimensions).tolist() | ||
|
||
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search( | ||
kdtree, query_point | ||
) | ||
|
||
# Check that nearest point is not None | ||
assert nearest_point is not None | ||
|
||
# Check that distance is a non-negative number | ||
assert nearest_dist >= 0 | ||
|
||
# Check that nodes visited is a non-negative integer | ||
assert nodes_visited >= 0 | ||
|
||
|
||
def test_edge_cases(): | ||
""" | ||
Test edge cases such as an empty KD-Tree. | ||
""" | ||
empty_kdtree = build_kdtree([]) | ||
query_point = [0.0] * 2 # Using a default 2D query point | ||
|
||
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search( | ||
empty_kdtree, query_point | ||
) | ||
|
||
# With an empty KD-Tree, nearest_point should be None | ||
assert nearest_point is None | ||
assert nearest_dist == float("inf") | ||
assert nodes_visited == 0 | ||
|
||
|
||
if __name__ == "__main__": | ||
import pytest | ||
|
||
pytest.main() |