diff --git a/ci/environment.yml b/ci/environment.yml index f690fa59..98da6643 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -13,6 +13,9 @@ dependencies: - hypothesis - ruff - typing-extensions + - geoarrow-pyarrow + - lonboard - pip - pip: + - arro3-core - h3ronpy diff --git a/pyproject.toml b/pyproject.toml index a0e898ac..b8ae7a70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,14 @@ dependencies = [ "typing-extensions", ] +[project.optional-dependencies] +explore = [ + "lonboard>=0.9.3", + "pyproj>=3.3", + "matplotlib", + "arro3-core>=0.4.0" +] + [project.urls] # Home = "https://xdggs.readthedocs.io" Repository = "https://github.com/xarray-contrib/xdggs" diff --git a/xdggs/accessor.py b/xdggs/accessor.py index 8dbf9e64..df6e1860 100644 --- a/xdggs/accessor.py +++ b/xdggs/accessor.py @@ -3,6 +3,7 @@ from xdggs.grid import DGGSInfo from xdggs.index import DGGSIndex +from xdggs.plotting import explore @xr.register_dataset_accessor("dggs") @@ -115,3 +116,38 @@ def cell_boundaries(self): return xr.DataArray( boundaries, coords={self._name: self.cell_ids}, dims=self.cell_ids.dims ) + + def explore(self, *, cmap="viridis", center=None, alpha=None): + """interactively explore the data using `lonboard` + + Requires `lonboard`, `matplotlib`, and `arro3.core` to be installed. + + Parameters + ---------- + cmap : str + The name of the color map to use + center : int or float, optional + If set, will use this as the center value of a diverging color map. + alpha : float, optional + If set, controls the transparency of the polygons. + + Returns + ------- + map : lonboard.Map + The rendered map. + + Notes + ----- + Plotting currently is restricted to 1D `DataArray` objects. + """ + if isinstance(self._obj, xr.Dataset): + raise ValueError("does not work with Dataset objects, yet") + + cell_dim = self._obj[self._name].dims[0] + return explore( + self._obj, + cell_dim=cell_dim, + cmap=cmap, + center=center, + alpha=alpha, + ) diff --git a/xdggs/grid.py b/xdggs/grid.py index 853beaae..6582c838 100644 --- a/xdggs/grid.py +++ b/xdggs/grid.py @@ -34,5 +34,5 @@ def cell_ids2geographic(self, cell_ids): def geographic2cell_ids(self, lon, lat): raise NotImplementedError() - def cell_boundaries(self, cell_ids): + def cell_boundaries(self, cell_ids, backend="shapely"): raise NotImplementedError() diff --git a/xdggs/h3.py b/xdggs/h3.py index 30063996..40bf2896 100644 --- a/xdggs/h3.py +++ b/xdggs/h3.py @@ -1,3 +1,4 @@ +import json from collections.abc import Mapping from dataclasses import dataclass from typing import Any, ClassVar @@ -8,7 +9,6 @@ from typing_extensions import Self import numpy as np -import shapely import xarray as xr from h3ronpy.arrow.vector import ( cells_to_coordinates, @@ -22,6 +22,42 @@ from xdggs.utils import _extract_cell_id_variable, register_dggs +def polygons_shapely(wkb): + import shapely + + return shapely.from_wkb(wkb) + + +def polygons_geoarrow(wkb): + import pyproj + import shapely + from arro3.core import list_array + + polygons = shapely.from_wkb(wkb) + crs = pyproj.CRS.from_epsg(4326) + + geometry_type, coords, (ring_offsets, geom_offsets) = shapely.to_ragged_array( + polygons + ) + + if geometry_type != shapely.GeometryType.POLYGON: + raise ValueError(f"unsupported geometry type found: {geometry_type}") + + polygon_array = list_array( + geom_offsets.astype("int32"), list_array(ring_offsets.astype("int32"), coords) + ) + polygon_array_with_geo_meta = polygon_array.cast( + polygon_array.field.with_metadata( + { + "ARROW:extension:name": "geoarrow.polygon", + "ARROW:extension:metadata": json.dumps({"crs": crs.to_json_dict()}), + } + ) + ) + + return polygon_array_with_geo_meta + + @dataclass(frozen=True) class H3Info(DGGSInfo): resolution: int @@ -50,10 +86,18 @@ def cell_ids2geographic( def geographic2cell_ids(self, lon, lat): return coordinates_to_cells(lat, lon, self.resolution, radians=False) - def cell_boundaries(self, cell_ids): + def cell_boundaries(self, cell_ids, backend="shapely"): + # TODO: convert cell ids directly to geoarrow once h3ronpy supports it wkb = cells_to_wkb_polygons(cell_ids, radians=False, link_cells=False) - return shapely.from_wkb(wkb) + backends = { + "shapely": polygons_shapely, + "geoarrow": polygons_geoarrow, + } + backend_func = backends.get(backend) + if backend_func is None: + raise ValueError("invalid backend: {backend!r}") + return backend_func(wkb) @register_dggs("h3") diff --git a/xdggs/healpix.py b/xdggs/healpix.py index 111fae4f..5819b4ce 100644 --- a/xdggs/healpix.py +++ b/xdggs/healpix.py @@ -1,3 +1,4 @@ +import json import operator from collections.abc import Mapping from dataclasses import dataclass, field @@ -26,6 +27,72 @@ from exceptiongroup import ExceptionGroup +def polygons_shapely(vertices): + import shapely + + return shapely.polygons(vertices) + + +def polygons_geoarrow(vertices): + import pyproj + from arro3.core import list_array + + polygon_vertices = np.concatenate([vertices, vertices[:, :1, :]], axis=1) + crs = pyproj.CRS.from_epsg(4326) + + # construct geoarrow arrays + coords = np.reshape(polygon_vertices, (-1, 2)) + coords_per_pixel = polygon_vertices.shape[1] + geom_offsets = np.arange(vertices.shape[0] + 1, dtype="int32") + ring_offsets = geom_offsets * coords_per_pixel + + polygon_array = list_array(geom_offsets, list_array(ring_offsets, coords)) + + # We need to tag the array with extension metadata (`geoarrow.polygon`) so that Lonboard knows that this is a geospatial column. + polygon_array_with_geo_meta = polygon_array.cast( + polygon_array.field.with_metadata( + { + "ARROW:extension:name": "geoarrow.polygon", + "ARROW:extension:metadata": json.dumps( + {"crs": crs.to_json_dict(), "edges": "spherical"} + ), + } + ) + ) + return polygon_array_with_geo_meta + + +def center_around_prime_meridian(lon, lat): + # three tasks: + # - center around the prime meridian (map to a range of [-180, 180]) + # - replace the longitude of points at the poles with the median + # of longitude of the other vertices + # - cells that cross the dateline should have longitudes around 180 + + # center around prime meridian + recentered = (lon + 180) % 360 - 180 + + # replace lon of pole with the median of the remaining vertices + contains_poles = np.isin(lat, np.array([-90, 90])) + pole_cells = np.any(contains_poles, axis=-1) + recentered[contains_poles] = np.median( + np.reshape( + recentered[pole_cells[:, None] & np.logical_not(contains_poles)], (-1, 3) + ), + axis=-1, + ) + + # keep cells that cross the dateline centered around 180 + polygons_to_fix = np.any(recentered < -100, axis=-1) & np.any( + recentered > 100, axis=-1 + ) + result = np.where( + polygons_to_fix[:, None] & (recentered < 0), recentered + 360, recentered + ) + + return result + + @dataclass(frozen=True) class HealpixInfo(DGGSInfo): resolution: int @@ -135,23 +202,29 @@ def cell_ids2geographic(self, cell_ids): def geographic2cell_ids(self, lon, lat): return healpy.ang2pix(self.nside, lon, lat, lonlat=True, nest=self.nest) - def cell_boundaries(self, cell_ids: Any) -> np.ndarray: - import shapely - + def cell_boundaries(self, cell_ids: Any, backend="shapely") -> np.ndarray: boundary_vectors = healpy.boundaries( self.nside, cell_ids, step=1, nest=self.nest ) lon, lat = healpy.vec2ang(np.moveaxis(boundary_vectors, 1, -1), lonlat=True) - boundaries = np.reshape(np.stack((lon, lat), axis=-1), (-1, 4, 2)) + lon_reshaped = np.reshape(lon, (-1, 4)) + lat_reshaped = np.reshape(lat, (-1, 4)) + + lon_ = center_around_prime_meridian(lon_reshaped, lat_reshaped) + + vertices = np.stack((lon_, lat_reshaped), axis=-1) + + backends = { + "shapely": polygons_shapely, + "geoarrow": polygons_geoarrow, + } - # fix the dateline / prime meridian issue - lon_ = boundaries[..., 0] - to_fix = abs(np.max(lon_, axis=-1) - np.min(lon_, axis=-1)) > 300 - fixed_lon = (lon_[to_fix, :] + 180) % 360 - 180 - boundaries[to_fix, :, 0] = fixed_lon + backend_func = backends.get(backend) + if backend_func is None: + raise ValueError("invalid backend: {backend!r}") - return shapely.polygons(boundaries) + return backend_func(vertices) @register_dggs("healpix") diff --git a/xdggs/plotting.py b/xdggs/plotting.py new file mode 100644 index 00000000..9c23fd34 --- /dev/null +++ b/xdggs/plotting.py @@ -0,0 +1,72 @@ +import numpy as np + + +def create_arrow_table(polygons, arr, coords=None): + from arro3.core import Array, ChunkedArray, Schema, Table + + if coords is None: + coords = ["latitude", "longitude"] + + array = Array.from_arrow(polygons) + name = arr.name or "data" + arrow_arrays = { + "geometry": array, + "cell_ids": ChunkedArray([Array.from_numpy(arr.coords["cell_ids"])]), + name: ChunkedArray([Array.from_numpy(arr.data)]), + } | { + coord: ChunkedArray([Array.from_numpy(arr.coords[coord].data)]) + for coord in coords + if coord in arr.coords + } + + fields = [array.field.with_name(name) for name, array in arrow_arrays.items()] + schema = Schema(fields) + + return Table.from_arrays(list(arrow_arrays.values()), schema=schema) + + +def normalize(var, center=None): + from matplotlib.colors import CenteredNorm, Normalize + + if center is None: + vmin = var.min(skipna=True) + vmax = var.max(skipna=True) + normalizer = Normalize(vmin=vmin, vmax=vmax) + else: + halfrange = np.abs(var - center).max(skipna=True) + normalizer = CenteredNorm(vcenter=center, halfrange=halfrange) + + return normalizer(var.data) + + +def explore( + arr, + cell_dim="cells", + cmap="viridis", + center=None, + alpha=None, +): + import lonboard + from lonboard import SolidPolygonLayer + from lonboard.colormap import apply_continuous_cmap + from matplotlib import colormaps + + if len(arr.dims) != 1 or cell_dim not in arr.dims: + raise ValueError( + f"exploration only works with a single dimension ('{cell_dim}')" + ) + + cell_ids = arr.dggs.coord.data + grid_info = arr.dggs.grid_info + + polygons = grid_info.cell_boundaries(cell_ids, backend="geoarrow") + + normalized_data = normalize(arr.variable, center=center) + + colormap = colormaps[cmap] + colors = apply_continuous_cmap(normalized_data, colormap, alpha=alpha) + + table = create_arrow_table(polygons, arr) + layer = SolidPolygonLayer(table=table, filled=True, get_fill_color=colors) + + return lonboard.Map(layer) diff --git a/xdggs/tests/__init__.py b/xdggs/tests/__init__.py index aa71ded3..1a193487 100644 --- a/xdggs/tests/__init__.py +++ b/xdggs/tests/__init__.py @@ -1,5 +1,12 @@ +import geoarrow.pyarrow as ga +import shapely + from xdggs.tests.matchers import ( # noqa: F401 Match, MatchResult, assert_exceptions_equal, ) + + +def geoarrow_to_shapely(arr): + return shapely.from_wkb(ga.as_wkb(arr)) diff --git a/xdggs/tests/test_h3.py b/xdggs/tests/test_h3.py index 2a1ed40b..3ab102c0 100644 --- a/xdggs/tests/test_h3.py +++ b/xdggs/tests/test_h3.py @@ -8,6 +8,7 @@ from xarray.core.indexes import PandasIndex from xdggs import h3 +from xdggs.tests import geoarrow_to_shapely # from the h3 gallery, at resolution 3 cell_ids = [ @@ -202,14 +203,18 @@ def test_geographic2cell_ids(self, cell_centers, cell_ids): ), ), ) - def test_cell_boundaries(self, resolution, cell_ids, expected_coords): + @pytest.mark.parametrize("backend", ["shapely", "geoarrow"]) + def test_cell_boundaries(self, resolution, cell_ids, backend, expected_coords): expected = shapely.polygons(expected_coords) grid = h3.H3Info(resolution=resolution) - actual = grid.cell_boundaries(cell_ids) + backends = {"shapely": lambda arr: arr, "geoarrow": geoarrow_to_shapely} + converter = backends[backend] - shapely.testing.assert_geometries_equal(actual, expected) + actual = grid.cell_boundaries(cell_ids, backend=backend) + + shapely.testing.assert_geometries_equal(converter(actual), expected) @pytest.mark.parametrize("resolution", resolutions) diff --git a/xdggs/tests/test_healpix.py b/xdggs/tests/test_healpix.py index 94a1c31d..942d5a06 100644 --- a/xdggs/tests/test_healpix.py +++ b/xdggs/tests/test_healpix.py @@ -12,7 +12,7 @@ from xarray.core.indexes import PandasIndex from xdggs import healpix -from xdggs.tests import assert_exceptions_equal +from xdggs.tests import assert_exceptions_equal, geoarrow_to_shapely try: ExceptionGroup @@ -246,10 +246,10 @@ def test_roundtrip(self, resolution, indexing_scheme, rotation): np.array([2]), np.array( [ - [0.0, 90.0], - [180.0, 41.8103149], - [225.0, 0.0], - [270.0, 41.8103149], + [-135.0, 90.0], + [-180.0, 41.8103149], + [-135.0, 0.0], + [-90.0, 41.8103149], ] ), ), @@ -265,10 +265,10 @@ def test_roundtrip(self, resolution, indexing_scheme, rotation): [30.0, 54.3409123], ], [ - [315.0, 41.8103149], - [303.75, 30.0], - [315.0, 19.47122063], - [326.25, 30.0], + [-45.0, 41.8103149], + [-56.25, 30.0], + [-45.0, 19.47122063], + [-33.75, 30.0], ], ] ), @@ -307,13 +307,20 @@ def test_roundtrip(self, resolution, indexing_scheme, rotation): ), ), ) - def test_cell_boundaries(self, params, cell_ids, expected_coords): + @pytest.mark.parametrize("backend", ["shapely", "geoarrow"]) + def test_cell_boundaries(self, params, cell_ids, backend, expected_coords): grid = healpix.HealpixInfo.from_dict(params) - actual = grid.cell_boundaries(cell_ids) + actual = grid.cell_boundaries(cell_ids, backend=backend) + + backends = { + "shapely": lambda arr: arr, + "geoarrow": geoarrow_to_shapely, + } + converter = backends[backend] expected = shapely.polygons(expected_coords) - shapely.testing.assert_geometries_equal(actual, expected) + shapely.testing.assert_geometries_equal(converter(actual), expected) @given( *strategies.grid_and_cell_ids( diff --git a/xdggs/tests/test_matchers.py b/xdggs/tests/test_matchers.py index eed8d7ef..4d4e3591 100644 --- a/xdggs/tests/test_matchers.py +++ b/xdggs/tests/test_matchers.py @@ -4,7 +4,7 @@ try: ExceptionGroup -except NameError: +except NameError: # pragma: no cover from exceptiongroup import ExceptionGroup diff --git a/xdggs/tests/test_plotting.py b/xdggs/tests/test_plotting.py new file mode 100644 index 00000000..6070248e --- /dev/null +++ b/xdggs/tests/test_plotting.py @@ -0,0 +1,123 @@ +import numpy as np +import pytest +import xarray as xr +from arro3.core import Array, Table + +from xdggs import plotting + + +@pytest.mark.parametrize( + ["polygons", "arr", "coords", "expected"], + ( + pytest.param( + Array.from_numpy(np.array([1, 2])), + xr.DataArray( + [-1, 1], + coords={ + "cell_ids": ("cells", [0, 1]), + "latitude": ("cells", [-5, 10]), + "longitude": ("cells", [-60, -50]), + }, + dims="cells", + ), + None, + Table.from_pydict( + { + "geometry": Array.from_numpy(np.array([1, 2])), + "cell_ids": Array.from_numpy(np.array([0, 1])), + "data": Array.from_numpy(np.array([-1, 1])), + "latitude": Array.from_numpy(np.array([-5, 10])), + "longitude": Array.from_numpy(np.array([-60, -50])), + } + ), + ), + pytest.param( + Array.from_numpy(np.array([1, 2])), + xr.DataArray( + [-1, 1], + coords={ + "cell_ids": ("cells", [1, 2]), + "latitude": ("cells", [-5, 10]), + "longitude": ("cells", [-60, -50]), + }, + dims="cells", + ), + ["latitude"], + Table.from_pydict( + { + "geometry": Array.from_numpy(np.array([1, 2])), + "cell_ids": Array.from_numpy(np.array([1, 2])), + "data": Array.from_numpy(np.array([-1, 1])), + "latitude": Array.from_numpy(np.array([-5, 10])), + } + ), + ), + pytest.param( + Array.from_numpy(np.array([1, 3])), + xr.DataArray( + [-1, 1], + coords={ + "cell_ids": ("cells", [0, 1]), + "latitude": ("cells", [-5, 10]), + "longitude": ("cells", [-60, -50]), + }, + dims="cells", + name="new_data", + ), + ["longitude"], + Table.from_pydict( + { + "geometry": Array.from_numpy(np.array([1, 3])), + "cell_ids": Array.from_numpy(np.array([0, 1])), + "new_data": Array.from_numpy(np.array([-1, 1])), + "longitude": Array.from_numpy(np.array([-60, -50])), + } + ), + ), + ), +) +def test_create_arrow_table(polygons, arr, coords, expected): + actual = plotting.create_arrow_table(polygons, arr, coords=coords) + + assert actual == expected + + +@pytest.mark.parametrize( + ["var", "center", "expected"], + ( + pytest.param( + xr.Variable("cells", np.array([-5, np.nan, -2, 1])), + None, + np.array([0, np.nan, 0.5, 1]), + id="linear-missing_values", + ), + pytest.param( + xr.Variable("cells", np.arange(-5, 2, dtype="float")), + None, + np.linspace(0, 1, 7), + id="linear-manual", + ), + pytest.param( + xr.Variable("cells", np.linspace(0, 10, 5)), + None, + np.linspace(0, 1, 5), + id="linear-linspace", + ), + pytest.param( + xr.Variable("cells", np.linspace(-5, 5, 10)), + 0, + np.linspace(0, 1, 10), + id="centered-0", + ), + pytest.param( + xr.Variable("cells", np.linspace(0, 10, 10)), + 5, + np.linspace(0, 1, 10), + id="centered-2", + ), + ), +) +def test_normalize(var, center, expected): + actual = plotting.normalize(var, center=center) + + np.testing.assert_allclose(actual, expected)