Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/xoak/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from importlib.metadata import version

from .accessor import XoakAccessor
from .index import IndexAdapter, IndexRegistry
from xoak.accessor import XoakAccessor
from xoak.index import IndexAdapter, IndexRegistry
from xoak.tree_adapters import (
S2PointTreeAdapter,
SklearnBallTreeAdapter,
SklearnGeoBallTreeAdapter,
SklearnKDTreeAdapter,
)

__all__ = [
"XoakAccessor",
"IndexAdapter",
"IndexRegistry",
"SklearnBallTreeAdapter",
"SklearnGeoBallTreeAdapter",
"SklearnKDTreeAdapter",
"S2PointTreeAdapter",
"XoakAccessor",
]

__version__ = version("xoak")
104 changes: 104 additions & 0 deletions src/xoak/tree_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

import numpy as np

try:
from xarray.indexes.nd_point_index import TreeAdapter # type: ignore
except ImportError:

class TreeAdapter: ...


if TYPE_CHECKING:
import pys2index
import sklearn.neighbors


class S2PointTreeAdapter(TreeAdapter):
""":py:class:`pys2index.S2PointIndex` adapter for :py:class:`~xarray.indexes.NDPointIndex`."""

_s2point_index: pys2index.S2PointIndex

def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
from pys2index import S2PointIndex

self._s2point_index = S2PointIndex(points)

def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return self._s2point_index.query(points)

def equals(self, other: S2PointTreeAdapter) -> bool:
return np.array_equal(
self._s2point_index.get_cell_ids(), other._s2point_index.get_cell_ids()
)


class SklearnKDTreeAdapter(TreeAdapter):
""":py:class:`sklearn.neighbors.KDTree` adapter for :py:class:`~xarray.indexes.NDPointIndex`."""

_kdtree: sklearn.neighbors.KDTree

def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
from sklearn.neighbors import KDTree

self._kdtree = KDTree(points, **options)

def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return self._kdtree.query(points)

def equals(self, other: SklearnKDTreeAdapter) -> bool:
return np.array_equal(self._kdtree.data, other._kdtree.data)


class SklearnBallTreeAdapter(TreeAdapter):
""":py:class:`sklearn.neighbors.BallTree` adapter for
:py:class:`~xarray.indexes.NDPointIndex`.

"""

_balltree: sklearn.neighbors.BallTree

def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
from sklearn.neighbors import BallTree

self._balltree = BallTree(points, **options)

def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return self._balltree.query(points)

def equals(self, other: SklearnBallTreeAdapter) -> bool:
return np.array_equal(self._balltree.data, other._balltree.data)


class SklearnGeoBallTreeAdapter(TreeAdapter):
""":py:class:`sklearn.neighbors.BallTree` adapter for
:py:class:`~xarray.indexes.NDPointIndex`, using the 'haversine' metric.

It can be used for indexing a set of latitude / longitude points.

When building the index, the coordinates must be given in the latitude,
longitude order.

Latitude and longitude values must be given in degrees for both index and
query points (those values are converted in radians by this adapter).

"""

_balltree: sklearn.neighbors.BallTree

def __init__(self, points: np.ndarray, options: Mapping[str, Any]):
from sklearn.neighbors import BallTree

opts = dict(options)
opts.update({"metric": "haversine"})

self._balltree = BallTree(np.deg2rad(points), **options)

def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return self._balltree.query(np.deg2rad(points))

def equals(self, other: SklearnGeoBallTreeAdapter) -> bool:
return np.array_equal(self._balltree.data, other._balltree.data)