diff --git a/kd_tree.py b/kd_tree.py index 5b07fda..34b870b 100644 --- a/kd_tree.py +++ b/kd_tree.py @@ -39,7 +39,11 @@ def __init__(self, points, dim, dist_sq_func=None): def make(points, i=0): if len(points) > 1: - points.sort(key=lambda x: x[i]) + if type(points).__module__ == 'numpy' and type(points).__name__ == 'ndarray': + # Numpy-specific fix + points = points[points[:, i].argsort()] + else: + points.sort(key=lambda x: x[i]) i = (i + 1) % dim m = len(points) >> 1 return [make(points[:m], i), make(points[m + 1:], i),