Implemented KD Tree Data Structure #11532
Changes from all commits
data_structures/kd_tree/build_kdtree.py (new file, +35 lines)

```python
from data_structures.kd_tree.kd_node import KDNode


def build_kdtree(points: list[list[float]], depth: int = 0) -> KDNode | None:
    """
    Builds a KD-Tree from a list of points.

    Args:
        points: The list of points to build the KD-Tree from.
        depth: The current depth in the tree
               (used to determine the axis for splitting).

    Returns:
        The root node of the KD-Tree,
        or None if no points are provided.
    """
    if not points:
        return None

    k = len(points[0])  # Dimensionality of the points
    axis = depth % k

    # Sort point list and choose median as pivot element
    points.sort(key=lambda point: point[axis])
    median_idx = len(points) // 2

    # Create node and construct subtrees
    left_points = points[:median_idx]
    right_points = points[median_idx + 1 :]

    return KDNode(
        point=points[median_idx],
        left=build_kdtree(left_points, depth + 1),
        right=build_kdtree(right_points, depth + 1),
    )
```
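Below is a minimal usage sketch (not part of the PR) showing how `build_kdtree` alternates splitting axes on a tiny 2D point set; it assumes the repository root is on the import path so the `data_structures.kd_tree` modules resolve, and the points are made up for illustration.

```python
from data_structures.kd_tree.build_kdtree import build_kdtree

# Five illustrative 2D points; depth 0 splits on x, depth 1 on y.
points = [[7.0, 2.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0], [8.0, 1.0]]
root = build_kdtree(points)

print(root.point)        # [7.0, 2.0] -- median of all points by x
print(root.left.point)   # [4.0, 7.0] -- median of the smaller-x half by y
print(root.right.point)  # [9.0, 6.0] -- median of the larger-x half by y
```

Note that `build_kdtree` sorts the list it is given in place at each level, so callers who need the original ordering should pass in a copy.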
Example usage (new file, +38 lines)

```python
import numpy as np

from data_structures.kd_tree.build_kdtree import build_kdtree
from data_structures.kd_tree.example.hypercube_points import hypercube_points
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search


def main() -> None:
    """
    Demonstrates the use of KD-Tree by building it from random points
    in a 10-dimensional hypercube and performing a nearest neighbor search.
    """
    num_points: int = 5000
    cube_size: float = 10.0  # Size of the hypercube (edge length)
    num_dimensions: int = 10

    # Generate random points within the hypercube
    points: np.ndarray = hypercube_points(num_points, cube_size, num_dimensions)
    hypercube_kdtree = build_kdtree(points.tolist())

    # Generate a random query point of the same dimensionality (values in [0, 1))
    rng = np.random.default_rng()
    query_point: list[float] = rng.random(num_dimensions).tolist()

    # Perform nearest neighbor search
    nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
        hypercube_kdtree, query_point
    )

    # Print the results
    print(f"Query point: {query_point}")
    print(f"Nearest point: {nearest_point}")
    print(f"Distance: {nearest_dist:.4f}")
    print(f"Nodes visited: {nodes_visited}")


if __name__ == "__main__":
    main()
```
data_structures/kd_tree/example/hypercube_points.py (new file, +21 lines)

```python
import numpy as np


def hypercube_points(
    num_points: int, hypercube_size: float, num_dimensions: int
) -> np.ndarray:
    """
    Generates random points uniformly distributed within an n-dimensional hypercube.

    Args:
        num_points: Number of points to generate.
        hypercube_size: Size of the hypercube (edge length).
        num_dimensions: Number of dimensions of the hypercube.

    Returns:
        An array of shape (num_points, num_dimensions)
        with the generated points.
    """
    rng = np.random.default_rng()
    shape = (num_points, num_dimensions)
    return hypercube_size * rng.random(shape)
```
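As a quick sanity check (not part of the PR), the generator can be verified to produce the documented shape and value range; this sketch assumes the same import-path setup and uses arbitrary sizes.

```python
import numpy as np

from data_structures.kd_tree.example.hypercube_points import hypercube_points

pts = hypercube_points(num_points=4, hypercube_size=10.0, num_dimensions=3)

assert pts.shape == (4, 3)                  # (num_points, num_dimensions)
assert np.all((pts >= 0.0) & (pts < 10.0))  # uniform in [0, hypercube_size)
print(pts)
```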
data_structures/kd_tree/kd_node.py (new file, +30 lines)

```python
from __future__ import annotations


class KDNode:
    """
    Represents a node in a KD-Tree.

    Attributes:
        point: The point stored in this node.
        left: The left child node.
        right: The right child node.
    """

    def __init__(
        self,
        point: list[float],
        left: KDNode | None = None,
        right: KDNode | None = None,
    ) -> None:
        """
        Initializes a KDNode with the given point and child nodes.

        Args:
            point (list[float]): The point stored in this node.
            left (KDNode | None): The left child node.
            right (KDNode | None): The right child node.
        """
        self.point = point
        self.left = left
        self.right = right
```
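Because `KDNode` is a plain container with no behaviour of its own, small trees can also be wired up by hand, for example in tests; a minimal sketch with made-up 2D points, again assuming the PR's modules are importable:

```python
from data_structures.kd_tree.kd_node import KDNode

# A tiny hand-built tree: the root splits on x, the children split on y.
leaf_left = KDNode(point=[2.0, 3.0])
leaf_right = KDNode(point=[9.0, 1.0])
root = KDNode(point=[5.0, 4.0], left=leaf_left, right=leaf_right)

print(root.point, root.left.point, root.right.point)
```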
data_structures/kd_tree/nearest_neighbour_search.py (new file, +71 lines)

```python
from data_structures.kd_tree.kd_node import KDNode


def nearest_neighbour_search(
    root: KDNode | None, query_point: list[float]
) -> tuple[list[float] | None, float, int]:
    """
    Performs a nearest neighbor search in a KD-Tree for a given query point.

    Args:
        root (KDNode | None): The root node of the KD-Tree.
        query_point (list[float]): The point for which the nearest neighbor
            is being searched.

    Returns:
        tuple[list[float] | None, float, int]:
            - The nearest point found in the KD-Tree to the query point,
              or None if no point is found.
            - The squared distance to the nearest point.
            - The number of nodes visited during the search.
    """
    nearest_point: list[float] | None = None
    nearest_dist: float = float("inf")
    nodes_visited: int = 0

    def search(node: KDNode | None, depth: int = 0) -> None:
        """
        Recursively searches for the nearest neighbor in the KD-Tree.

        Args:
            node: The current node in the KD-Tree.
            depth: The current depth in the KD-Tree.
        """
        nonlocal nearest_point, nearest_dist, nodes_visited
        if node is None:
            return

        nodes_visited += 1

        # Calculate the current (squared) distance to the query point
        current_point = node.point
        current_dist = sum(
            (query_coord - point_coord) ** 2
            for query_coord, point_coord in zip(query_point, current_point)
        )

        # Update nearest point if the current node is closer
        if nearest_point is None or current_dist < nearest_dist:
            nearest_point = current_point
            nearest_dist = current_dist

        # Determine which subtree to search first (based on axis and query point)
        k = len(query_point)  # Dimensionality of points
        axis = depth % k
        if query_point[axis] <= current_point[axis]:
            nearer_subtree = node.left
            further_subtree = node.right
        else:
            nearer_subtree = node.right
            further_subtree = node.left

        # Search the nearer subtree first
        search(nearer_subtree, depth + 1)

        # Only search the further subtree if it could contain a closer point
        if (query_point[axis] - current_point[axis]) ** 2 < nearest_dist:
            search(further_subtree, depth + 1)

    search(root, 0)
    return nearest_point, nearest_dist, nodes_visited
```
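One subtlety worth noting: the second element of the returned tuple is the squared Euclidean distance, so callers who want the actual distance must take a square root themselves. A small sketch (not part of the PR, same import-path assumption, points chosen for illustration):

```python
import math

from data_structures.kd_tree.build_kdtree import build_kdtree
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search

points = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
tree = build_kdtree(points)

nearest, squared_dist, visited = nearest_neighbour_search(tree, [3.1, 4.2])

print(nearest)                  # expected: [3.0, 4.0]
print(math.sqrt(squared_dist))  # Euclidean distance, ~0.2236
print(visited)                  # number of KD-Tree nodes examined
```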
Tests (new file, +100 lines)

```python
import numpy as np
import pytest

from data_structures.kd_tree.build_kdtree import build_kdtree
from data_structures.kd_tree.example.hypercube_points import hypercube_points
from data_structures.kd_tree.kd_node import KDNode
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search


@pytest.mark.parametrize(
    ("num_points", "cube_size", "num_dimensions", "depth", "expected_result"),
    [
        (0, 10.0, 2, 0, None),  # Empty points list
        (10, 10.0, 2, 2, KDNode),  # Depth = 2, 2D points
        (10, 10.0, 3, -2, KDNode),  # Depth = -2, 3D points
    ],
)
def test_build_kdtree(num_points, cube_size, num_dimensions, depth, expected_result):
    """
    Test that the KD-Tree is built correctly.

    Cases:
    - Empty points list.
    - Positive depth value.
    - Negative depth value.
    """
    points = (
        hypercube_points(num_points, cube_size, num_dimensions).tolist()
        if num_points > 0
        else []
    )
    kdtree = build_kdtree(points, depth=depth)

    if expected_result is None:
        # Empty points list case
        assert kdtree is None, f"Expected None for empty points list, got {kdtree}"
    else:
        # Check that the root node is not None
        assert kdtree is not None, "Expected a KDNode, got None"
        # Check that the root point has the expected number of dimensions
        assert (
            len(kdtree.point) == num_dimensions
        ), f"Expected point dimension {num_dimensions}, got {len(kdtree.point)}"
        # Check that the root is a KDNode instance
        assert isinstance(
            kdtree, KDNode
        ), f"Expected KDNode instance, got {type(kdtree)}"


def test_nearest_neighbour_search():
    """
    Test the nearest neighbor search function.
    """
    num_points = 10
    cube_size = 10.0
    num_dimensions = 2
    points = hypercube_points(num_points, cube_size, num_dimensions)
    kdtree = build_kdtree(points.tolist())

    rng = np.random.default_rng()
    query_point = rng.random(num_dimensions).tolist()

    nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
        kdtree, query_point
    )

    # Check that the nearest point is not None
    assert nearest_point is not None
    # Check that the distance is non-negative
    assert nearest_dist >= 0
    # Check that the number of nodes visited is non-negative
    assert nodes_visited >= 0


def test_edge_cases():
    """
    Test edge cases such as an empty KD-Tree.
    """
    empty_kdtree = build_kdtree([])
    query_point = [0.0] * 2  # Using a default 2D query point

    nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
        empty_kdtree, query_point
    )

    # With an empty KD-Tree, nearest_point should be None
    assert nearest_point is None
    assert nearest_dist == float("inf")
    assert nodes_visited == 0


if __name__ == "__main__":
    pytest.main()
```