From ce16f746d2fa3fbe607f869796a7ca43406ee34a Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 6 Sep 2024 14:56:48 +0000 Subject: [PATCH 1/8] initial pass at spherical padding, faster tests, full nan tracking --- src/xarray_regrid/methods/conservative.py | 7 +- src/xarray_regrid/regrid.py | 21 ++- src/xarray_regrid/utils.py | 128 +++++++++++++++- tests/test_format.py | 178 ++++++++++++++++++++++ tests/test_regrid.py | 159 +++++++++---------- 5 files changed, 388 insertions(+), 105 deletions(-) create mode 100644 tests/test_format.py diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index eb96acf..deec777 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -66,7 +66,7 @@ def conservative_regrid( # Attempt to infer the latitude coordinate if latitude_coord is None: for coord in data.coords: - if str(coord).lower() in ["lat", "latitude"]: + if str(coord).lower().startswith("lat"): latitude_coord = coord break @@ -122,7 +122,6 @@ def conservative_regrid_dataset( weights = apply_spherical_correction(weights, latitude_coord) for array in data_vars.keys(): - non_grid_dims = [d for d in data_vars[array].dims if d not in coords] if coord in data_vars[array].dims: data_vars[array], valid_fracs[array] = apply_weights( da=data_vars[array], @@ -130,7 +129,6 @@ def conservative_regrid_dataset( coord=coord, valid_frac=valid_fracs[array], skipna=skipna, - non_grid_dims=non_grid_dims, ) # Mask out any regridded points outside the original domain data_vars[array] = data_vars[array].where(covered_grid) @@ -161,7 +159,6 @@ def apply_weights( coord: Hashable, valid_frac: xr.DataArray, skipna: bool, - non_grid_dims: list[Hashable], ) -> tuple[xr.DataArray, xr.DataArray]: """Apply the weights to convert data to the new coordinates.""" coord_map = {f"target_{coord}": coord} @@ -169,8 +166,6 @@ def apply_weights( if skipna: notnull = da.notnull() - if non_grid_dims: - notnull = notnull.any(non_grid_dims) # Renormalize the weights along this dim by the accumulated valid_frac # along previous dimensions if valid_frac.name != EMPTY_DA_NAME: diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index e91348a..fe0f214 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -1,6 +1,7 @@ import xarray as xr from xarray_regrid.methods import conservative, interp, most_common +from xarray_regrid.utils import format_for_regrid @xr.register_dataarray_accessor("regrid") @@ -34,7 +35,8 @@ def linear( Data regridded to the target dataset coordinates. """ ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - return interp.interp_regrid(self._obj, ds_target_grid, "linear") + ds_formatted = format_for_regrid(self._obj, ds_target_grid) + return interp.interp_regrid(ds_formatted, ds_target_grid, "linear") def nearest( self, @@ -51,14 +53,14 @@ def nearest( Data regridded to the target dataset coordinates. """ ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - return interp.interp_regrid(self._obj, ds_target_grid, "nearest") + ds_formatted = format_for_regrid(self._obj, ds_target_grid) + return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest") def cubic( self, ds_target_grid: xr.Dataset, time_dim: str = "time", ) -> xr.DataArray | xr.Dataset: - ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) """Regrid to the coords of the target dataset with cubic interpolation. 
         Args:
@@ -68,7 +70,9 @@ def cubic(
         Returns:
             Data regridded to the target dataset coordinates.
         """
-        return interp.interp_regrid(self._obj, ds_target_grid, "cubic")
+        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
+        return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic")
 
     def conservative(
         self,
@@ -88,6 +92,9 @@ def conservative(
             time_dim: The name of the time dimension/coordinate.
             skipna: If True, enable handling for NaN values. This adds some overhead,
                 so can be disabled for optimal performance on data without any NaNs.
+                With `skipna=True`, chunking is recommended in the non-grid dimensions;
+                otherwise the intermediate arrays that track the fraction of valid data
+                can become very large and consume excessive memory.
             Warning: with `skipna=False`, isolated NaNs will propagate throughout
                 the dataset due to the sequential regridding scheme over each
                 dimension.
             nan_threshold: Threshold value that will retain any output points
@@ -104,8 +111,9 @@ def conservative(
             raise ValueError(msg)
 
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
         return conservative.conservative_regrid(
-            self._obj, ds_target_grid, latitude_coord, skipna, nan_threshold
+            ds_formatted, ds_target_grid, latitude_coord, skipna, nan_threshold
         )
 
     def most_common(
@@ -134,8 +142,9 @@ def most_common(
             Regridded data.
         """
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
         return most_common.most_common_wrapper(
-            self._obj, ds_target_grid, time_dim, max_mem
+            ds_formatted, ds_target_grid, time_dim, max_mem
         )
 
diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py
index ce07f88..acf89cf 100644
--- a/src/xarray_regrid/utils.py
+++ b/src/xarray_regrid/utils.py
@@ -1,4 +1,4 @@
-from collections.abc import Callable
+from collections.abc import Callable, Hashable
 from dataclasses import dataclass
 from typing import Any, overload
 
@@ -75,7 +75,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]:
             grid.south, grid.north + grid.resolution_lat, grid.resolution_lat
         )
 
-    if np.remainder((grid.north - grid.south), grid.resolution_lat) > 0:
+    if np.remainder((grid.east - grid.west), grid.resolution_lon) > 0:
         lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon)
     else:
         lon_coords = np.arange(
@@ -235,3 +235,127 @@ def call_on_dataset(
         return next(iter(result.data_vars.values())).rename(obj.name)
 
     return result
+
+
+def format_for_regrid(
+    obj: xr.DataArray | xr.Dataset, target: xr.Dataset
+) -> xr.DataArray | xr.Dataset:
+    """Apply any pre-formatting to the input dataset to prepare for regridding.
+    Currently handles padding of spherical geometry if appropriate coordinate names
+    can be inferred containing 'lat' and 'lon'.
+    """
+    lat_coord = None
+    lon_coord = None
+
+    for coord in obj.coords.keys():
+        if str(coord).lower().startswith("lat"):
+            lat_coord = coord
+        elif str(coord).lower().startswith("lon"):
+            lon_coord = coord
+
+    if lon_coord is not None or lat_coord is not None:
+        obj = format_spherical(obj, target, lat_coord, lon_coord)
+
+    return obj
+
+
+def format_spherical(
+    obj: xr.DataArray | xr.Dataset,
+    target: xr.Dataset,
+    lat_coord: Hashable,
+    lon_coord: Hashable,
+) -> xr.DataArray | xr.Dataset:
+    """Infer whether a lat/lon source grid represents a global domain and
+    automatically apply spherical padding to improve edge effects.
+ + For longitude, shift the coordinate to line up with the target values, then + add a single wraparound padding column if the domain is global and the east + or west edges of the target lie outside the source grid centers. + + For latitude, add a single value at each pole computed as the mean of the last + row for global source grids where the first or last point lie equatorward of 90. + """ + + orig_chunksizes = obj.chunksizes + + # If the source coord fully covers the target, don't modify them + if lat_coord and not coord_is_covered(obj, target, lat_coord): + obj = obj.sortby(lat_coord) + target = target.sortby(lat_coord) + + # Only pad if global but don't have edge values directly at poles + polar_lat = 90 + dy = obj[lat_coord].diff(lat_coord).max().values + + # South pole + if dy - polar_lat >= obj[lat_coord][0] > -polar_lat: + south_pole = obj.isel({lat_coord: 0}) + # This should match the Pole="all" option of ESMF + if lon_coord is not None: + south_pole = south_pole.mean(lon_coord) + obj = xr.concat([south_pole, obj], dim=lat_coord) + obj[lat_coord].values[0] = -polar_lat + + # North pole + if polar_lat - dy <= obj[lat_coord][-1] < polar_lat: + north_pole = obj.isel({lat_coord: -1}) + if lon_coord is not None: + north_pole = north_pole.mean(lon_coord) + obj = xr.concat([obj, north_pole], dim=lat_coord) + obj[lat_coord].values[-1] = polar_lat + + # Coerce back to a single chunk if that's what was passed + if len(orig_chunksizes.get(lat_coord, [])) == 1: + obj = obj.chunk({lat_coord: -1}) + + if lon_coord and not coord_is_covered(obj, target, lon_coord): + obj = obj.sortby(lon_coord) + target = target.sortby(lon_coord) + + target_lon = target[lon_coord].values + # Find a wrap point outside of the left and right bounds of the target + # This ensures we have coverage on the target and handles global > regional + wrap_point = (target_lon[-1] + target_lon[0] + 360) / 2 + lon = obj[lon_coord].values + lon = np.where(lon < wrap_point - 360, lon + 360, lon) + lon = np.where(lon > wrap_point, lon - 360, lon) + obj[lon_coord].values[:] = lon + + # Shift operations can produce duplicates + # Simplest solution is to drop them and add back when padding + obj = obj.sortby(lon_coord).drop_duplicates(lon_coord) + + # Only pad if domain is global in lon + dx_s = obj[lon_coord].diff(lon_coord).max().values + dx_t = target[lon_coord].diff(lon_coord).max().values + is_global_lon = obj[lon_coord].max() - obj[lon_coord].min() >= 360 - dx_s + + if is_global_lon: + left_pad = (obj[lon_coord][0] - target[lon_coord][0] + dx_t / 2) / dx_s + right_pad = (target[lon_coord][-1] - obj[lon_coord][-1] + dx_t / 2) / dx_s + left_pad = int(np.ceil(np.max([left_pad, 0]))) + right_pad = int(np.ceil(np.max([right_pad, 0]))) + lon = obj[lon_coord].values + obj = obj.pad( + {lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True + ) + if left_pad: + obj[lon_coord].values[:left_pad] = lon[-left_pad:] - 360 + if right_pad: + obj[lon_coord].values[-right_pad:] = lon[:right_pad] + 360 + + # Coerce back to a single chunk if that's what was passed + if len(orig_chunksizes.get(lon_coord, [])) == 1: + obj = obj.chunk({lon_coord: -1}) + + return obj + + +def coord_is_covered( + obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable +) -> bool: + """Check if the source coord fully covers the target coord.""" + pad = target[coord].diff(coord).max().values + left_covered = obj[coord].min() <= target[coord].min() - pad + right_covered = obj[coord].max() >= target[coord].max() + pad + return 
bool(left_covered.item() and right_covered.item()) diff --git a/tests/test_format.py b/tests/test_format.py new file mode 100644 index 0000000..4fabde9 --- /dev/null +++ b/tests/test_format.py @@ -0,0 +1,178 @@ +import xarray as xr + +import xarray_regrid +from xarray_regrid.utils import format_for_regrid + + +def test_covered(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=80, + east=350, + south=-80, + west=10, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Formatting utils shouldn't modify this one at all + xr.testing.assert_equal(source, formatted) + + +def test_no_edges(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90 - dx_source / 2, + east=360 - dx_source / 2, + south=-90 + dx_source / 2, + west=0 + dx_source / 2, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should add wraparound and polar padding rows/columns + assert formatted.latitude[0] == -90 + assert formatted.latitude[-1] == 90 + assert formatted.longitude[0] == -1 + assert formatted.longitude[-1] == 361 + assert (formatted.longitude.diff("longitude") == 2).all() + + +def test_360_to_180(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=180, + south=-90, + west=-180, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should produce a shift to target plus wraparound padding + assert formatted.longitude[0] == -182 + assert formatted.longitude[-1] == 182 + assert (formatted.longitude.diff("longitude") == 2).all() + + +def test_180_to_360(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=180, + south=-90, + west=-180, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should produce a shift to target plus wraparound padding + assert formatted.longitude[0] == -2 + assert formatted.longitude[-1] == 362 + assert (formatted.longitude.diff("longitude") == 2).all() + + +def test_0_to_360(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=0, + south=-90, + west=-360, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should produce a shift to target plus wraparound padding + assert formatted.longitude[0] == -2 + assert formatted.longitude[-1] == 362 + assert 
(formatted.longitude.diff("longitude") == 2).all() + + +def test_global_to_local_shift(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=180, + south=-90, + west=-180, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=300, + south=-90, + west=270, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should produce a shift to cover the target range + assert formatted.longitude.min() <= 270 + assert formatted.longitude.max() >= 300 + assert (formatted.longitude.diff("longitude") == 2).all() diff --git a/tests/test_regrid.py b/tests/test_regrid.py index bf5f59b..0296f6a 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -1,4 +1,3 @@ -from copy import deepcopy from pathlib import Path import numpy as np @@ -6,29 +5,46 @@ import xarray as xr from numpy.testing import assert_array_equal +try: + import xesmf +except ImportError: + xesmf = None + import xarray_regrid DATA_PATH = Path(__file__).parent.parent / "docs" / "notebooks" / "benchmarks" / "data" -CDO_DATA = { +CDO_FILES = { "linear": DATA_PATH / "cdo_bilinear_64b.nc", "nearest": DATA_PATH / "cdo_nearest_64b.nc", "conservative": DATA_PATH / "cdo_conservative_64b.nc", } +# Sample files contain 12 monthly timestamps but subset to one for speed +N_TIMESTAMPS = 1 + @pytest.fixture(scope="session") -def load_input_data() -> xr.Dataset: +def sample_input_data() -> xr.Dataset: ds = xr.open_dataset(DATA_PATH / "era5_2m_dewpoint_temperature_2000_monthly.nc") - return ds.compute() + return ds.isel(time=slice(0, N_TIMESTAMPS)).persist() -@pytest.fixture -def sample_input_data(load_input_data) -> xr.Dataset: - return deepcopy(load_input_data) +@pytest.fixture(scope="session") +def conservative_input_data() -> xr.Dataset: + ds = xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc") + return ds.isel(time=slice(0, N_TIMESTAMPS)).persist() -@pytest.fixture +@pytest.fixture(scope="session") +def cdo_comparison_data() -> dict[str, xr.Dataset]: + data = {} + for method, path in CDO_FILES.items(): + data[method] = xr.open_dataset(path).isel(time=slice(0, N_TIMESTAMPS)).persist() + return data + + +@pytest.fixture(scope="session") def sample_grid_ds(): grid = xarray_regrid.Grid( north=90, @@ -42,48 +58,7 @@ def sample_grid_ds(): return xarray_regrid.create_regridding_dataset(grid) -@pytest.mark.parametrize( - "method, cdo_file", - [ - ("linear", CDO_DATA["linear"]), - ("nearest", CDO_DATA["nearest"]), - ], -) -def test_basic_regridders_ds(sample_input_data, sample_grid_ds, method, cdo_file): - """Test the dataset regridders (except conservative).""" - regridder = getattr(sample_input_data.regrid, method) - ds_regrid = regridder(sample_grid_ds) - ds_cdo = xr.open_dataset(cdo_file) - xr.testing.assert_allclose(ds_regrid.compute(), ds_cdo.compute()) - - -@pytest.mark.parametrize( - "method, cdo_file", - [ - ("linear", CDO_DATA["linear"]), - ("nearest", CDO_DATA["nearest"]), - ], -) -def test_basic_regridders_da(sample_input_data, sample_grid_ds, method, cdo_file): - """Test the dataarray regridders (except conservative).""" - regridder = getattr(sample_input_data["d2m"].regrid, method) - da_regrid = regridder(sample_grid_ds) - ds_cdo = xr.open_dataset(cdo_file) - xr.testing.assert_allclose(da_regrid.compute(), ds_cdo["d2m"].compute()) - - @pytest.fixture(scope="session") -def load_conservative_input_data() -> xr.Dataset: 
- ds = xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc") - return ds.compute() - - -@pytest.fixture -def conservative_input_data(load_conservative_input_data) -> xr.Dataset: - return deepcopy(load_conservative_input_data) - - -@pytest.fixture def conservative_sample_grid(): grid = xarray_regrid.Grid( north=90, @@ -97,46 +72,60 @@ def conservative_sample_grid(): return xarray_regrid.create_regridding_dataset(grid) -def test_conservative_regridder(conservative_input_data, conservative_sample_grid): +@pytest.mark.parametrize("method", ["linear", "nearest"]) +def test_basic_regridders_ds( + sample_input_data, sample_grid_ds, cdo_comparison_data, method +): + """Test the dataset regridders (except conservative).""" + regridder = getattr(sample_input_data.regrid, method) + ds_regrid = regridder(sample_grid_ds) + ds_cdo = cdo_comparison_data[method] + xr.testing.assert_allclose(ds_regrid, ds_cdo, rtol=0.002, atol=2e-5) + + +@pytest.mark.parametrize("method", ["linear", "nearest"]) +def test_basic_regridders_da( + sample_input_data, sample_grid_ds, cdo_comparison_data, method +): + """Test the dataarray regridders (except conservative).""" + regridder = getattr(sample_input_data["d2m"].regrid, method) + da_regrid = regridder(sample_grid_ds) + da_cdo = cdo_comparison_data[method]["d2m"] + xr.testing.assert_allclose(da_regrid, da_cdo, rtol=0.002, atol=2e-5) + + +def test_conservative_regridder( + conservative_input_data, conservative_sample_grid, cdo_comparison_data +): ds_regrid = conservative_input_data.regrid.conservative( conservative_sample_grid, latitude_coord="latitude" ) - ds_cdo = xr.open_dataset(CDO_DATA["conservative"]) - - # Cut of the edges: edge performance to be improved later (hopefully) - no_edges = {"latitude": slice(-85, 85), "longitude": slice(5, 355)} + ds_cdo = cdo_comparison_data["conservative"] xr.testing.assert_allclose( - ds_regrid["tp"] - .sel(no_edges) - .compute() - .transpose("time", "latitude", "longitude"), - ds_cdo["tp"].sel(no_edges).compute(), + ds_regrid["tp"], + ds_cdo["tp"], rtol=0.002, - atol=2e-6, + atol=2e-5, ) -def test_conservative_nans(conservative_input_data, conservative_sample_grid): +def test_conservative_nans( + conservative_input_data, conservative_sample_grid, cdo_comparison_data +): ds = conservative_input_data ds["tp"] = ds["tp"].where(ds.latitude >= 0).where(ds.longitude < 180) ds_regrid = ds.regrid.conservative( conservative_sample_grid, latitude_coord="latitude" ) - ds_cdo = xr.open_dataset(CDO_DATA["conservative"]) + ds_cdo = cdo_comparison_data["conservative"] - # Cut of the edges: edge performance to be improved later (hopefully) - no_edges = {"latitude": slice(-85, 85), "longitude": slice(5, 355)} - no_nans = {"latitude": slice(1, 90), "longitude": slice(None, 179)} + no_nans = {"latitude": slice(1, 90), "longitude": slice(1, 179)} xr.testing.assert_allclose( - ds_regrid["tp"] - .sel(no_edges) - .sel(no_nans) - .compute() - .transpose("time", "latitude", "longitude"), - ds_cdo["tp"].sel(no_edges).sel(no_nans).compute(), + ds_regrid["tp"].sel(no_nans), + ds_cdo["tp"].sel(no_nans), rtol=0.002, - atol=2e-6, + atol=2e-5, ) @@ -186,7 +175,7 @@ def test_conservative_nan_aggregation_over_dims(): target = xr.Dataset(coords={"x": [0], "y": [0]}) result = data.regrid.conservative(target, skipna=True, nan_threshold=1) - assert np.allclose(result[0].mean().item(), data[0].mean().item()) + np.testing.assert_allclose(result[0].mean().item(), data[0].mean().item()) @pytest.mark.parametrize("nan_threshold", [0, 1]) @@ -212,19 +201,9 
@@ def test_conservative_nan_thresholds_against_coarsen(nan_threshold): xr.testing.assert_allclose(da_coarsen, da_regrid) -def xesmf_available() -> bool: - try: - import xesmf # noqa: F401 - except ImportError: - return False - return True - - -@pytest.mark.skipif(not xesmf_available(), reason="xesmf required") +@pytest.mark.skipif(xesmf is None, reason="xesmf required") def test_conservative_nan_thresholds_against_xesmf(): - import xesmf as xe - - ds = xr.tutorial.open_dataset("ersstv5").sst.compute() + ds = xr.tutorial.open_dataset("ersstv5").sst.isel(time=[0]).persist() ds = ds.rename(lon="longitude", lat="latitude") new_grid = xarray_regrid.Grid( north=90, @@ -235,16 +214,14 @@ def test_conservative_nan_thresholds_against_xesmf(): resolution_lon=2, ) target_dataset = xarray_regrid.create_regridding_dataset(new_grid) - regridder = xe.Regridder(ds, target_dataset, "conservative") + regridder = xesmf.Regridder(ds, target_dataset, "conservative") for nan_threshold in [0.0, 0.25, 0.5, 0.75, 1.0]: - data_regrid = ds.copy().regrid.conservative( + data_regrid = ds.regrid.conservative( target_dataset, skipna=True, nan_threshold=nan_threshold ) - data_esmf = regridder( - ds.copy(), keep_attrs=True, na_thres=nan_threshold, skipna=True - ) - assert (data_regrid.isnull() == data_esmf.isnull()).mean().values > 0.995 + data_esmf = regridder(ds, keep_attrs=True, na_thres=nan_threshold, skipna=True) + xr.testing.assert_equal(data_regrid.isnull(), data_esmf.isnull()) class TestCoordOrder: From ba27714a583b5a107c9b7fd164910542f97f7454 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 6 Sep 2024 16:33:39 +0000 Subject: [PATCH 2/8] fix netcdf bug --- tests/test_regrid.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 0296f6a..1ba2f48 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -27,20 +27,22 @@ @pytest.fixture(scope="session") def sample_input_data() -> xr.Dataset: ds = xr.open_dataset(DATA_PATH / "era5_2m_dewpoint_temperature_2000_monthly.nc") - return ds.isel(time=slice(0, N_TIMESTAMPS)).persist() + # slice after compute due to current xarray bug: https://github.com/pydata/xarray/issues/8909 + # then convert to dask so regridding is lazy on attribute tests + return ds.compute().isel(time=slice(0, N_TIMESTAMPS)).chunk() @pytest.fixture(scope="session") def conservative_input_data() -> xr.Dataset: ds = xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc") - return ds.isel(time=slice(0, N_TIMESTAMPS)).persist() + return ds.compute().isel(time=slice(0, N_TIMESTAMPS)).chunk() @pytest.fixture(scope="session") def cdo_comparison_data() -> dict[str, xr.Dataset]: data = {} for method, path in CDO_FILES.items(): - data[method] = xr.open_dataset(path).isel(time=slice(0, N_TIMESTAMPS)).persist() + data[method] = xr.open_dataset(path).compute().isel(time=slice(0, N_TIMESTAMPS)) return data @@ -203,7 +205,7 @@ def test_conservative_nan_thresholds_against_coarsen(nan_threshold): @pytest.mark.skipif(xesmf is None, reason="xesmf required") def test_conservative_nan_thresholds_against_xesmf(): - ds = xr.tutorial.open_dataset("ersstv5").sst.isel(time=[0]).persist() + ds = xr.tutorial.open_dataset("ersstv5").sst.compute().isel(time=[0]) ds = ds.rename(lon="longitude", lat="latitude") new_grid = xarray_regrid.Grid( north=90, From 644604822e5bbed343010e1a4d4f79e6dcc15818 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 6 Sep 2024 16:47:14 +0000 Subject: [PATCH 3/8] revert to keeping all test slices --- 
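Note: this keeps the compute-then-chunk idiom from the previous patch, which
loads the data into memory but leaves it dask-backed so the regrid step itself
stays lazy. A minimal sketch of the idiom (the file name is illustrative only):

    import xarray as xr

    ds = xr.open_dataset("some_input.nc")  # hypothetical sample file
    ds = ds.compute().chunk()  # eager read, then wrap the arrays in dask
    # Anything built on ds (e.g. ds.regrid.linear(target)) now returns a
    # lazy dask-backed result, so attribute-only assertions stay cheap.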
tests/test_regrid.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 1ba2f48..ccab1f4 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -20,29 +20,25 @@ "conservative": DATA_PATH / "cdo_conservative_64b.nc", } -# Sample files contain 12 monthly timestamps but subset to one for speed -N_TIMESTAMPS = 1 - @pytest.fixture(scope="session") def sample_input_data() -> xr.Dataset: ds = xr.open_dataset(DATA_PATH / "era5_2m_dewpoint_temperature_2000_monthly.nc") - # slice after compute due to current xarray bug: https://github.com/pydata/xarray/issues/8909 - # then convert to dask so regridding is lazy on attribute tests - return ds.compute().isel(time=slice(0, N_TIMESTAMPS)).chunk() + # Convert to dask so regridding is lazy for attr-only tests + return ds.compute().chunk() @pytest.fixture(scope="session") def conservative_input_data() -> xr.Dataset: ds = xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc") - return ds.compute().isel(time=slice(0, N_TIMESTAMPS)).chunk() + return ds.compute().chunk() @pytest.fixture(scope="session") def cdo_comparison_data() -> dict[str, xr.Dataset]: data = {} for method, path in CDO_FILES.items(): - data[method] = xr.open_dataset(path).compute().isel(time=slice(0, N_TIMESTAMPS)) + data[method] = xr.open_dataset(path).compute() return data @@ -108,7 +104,7 @@ def test_conservative_regridder( ds_regrid["tp"], ds_cdo["tp"], rtol=0.002, - atol=2e-5, + atol=4e-5, ) @@ -127,7 +123,7 @@ def test_conservative_nans( ds_regrid["tp"].sel(no_nans), ds_cdo["tp"].sel(no_nans), rtol=0.002, - atol=2e-5, + atol=4e-5, ) From 283c3102d48f7f2f32c34d539c29b8cf78998010 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 13 Sep 2024 14:17:51 -0400 Subject: [PATCH 4/8] refactor to separate coord handling functions, appease mypy --- src/xarray_regrid/methods/conservative.py | 2 +- src/xarray_regrid/utils.py | 198 ++++++++++++---------- 2 files changed, 106 insertions(+), 94 deletions(-) diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index deec777..54946a8 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -66,7 +66,7 @@ def conservative_regrid( # Attempt to infer the latitude coordinate if latitude_coord is None: for coord in data.coords: - if str(coord).lower().startswith("lat"): + if str(coord).lower() in ["lat", "latitude"]: latitude_coord = coord break diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index acf89cf..9c5a4b7 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Hashable from dataclasses import dataclass -from typing import Any, overload +from typing import Any, TypedDict, overload import numpy as np import pandas as pd @@ -10,6 +10,11 @@ class InvalidBoundsError(Exception): ... +class CoordHandler(TypedDict): + names: list[str] + func: Callable + + @dataclass class Grid: """Object storing grid information.""" @@ -241,112 +246,119 @@ def format_for_regrid( obj: xr.DataArray | xr.Dataset, target: xr.Dataset ) -> xr.DataArray | xr.Dataset: """Apply any pre-formatting to the input dataset to prepare for regridding. - Currently handles padding of spherical geometry if appropriate coordinate names - can be inferred containing 'lat' and 'lon'. 
+ Currently handles padding of spherical geometry if lat/lon coordinates can + be inferred and the domain size requires boundary padding. """ - lat_coord = None - lon_coord = None - - for coord in obj.coords.keys(): - if str(coord).lower().startswith("lat"): - lat_coord = coord - elif str(coord).lower().startswith("lon"): - lon_coord = coord + orig_chunksizes = obj.chunksizes - if lon_coord is not None or lat_coord is not None: - obj = format_spherical(obj, target, lat_coord, lon_coord) + # Special-cased coordinates with accepted names and formatting function + coord_handlers: dict[str, CoordHandler] = { + "lat": {"names": ["lat", "latitude"], "func": format_lat}, + "lon": {"names": ["lon", "longitude"], "func": format_lon}, + } + # Identify coordinates that need to be formatted + formatted_coords = {} + for coord_type, handler in coord_handlers.items(): + for coord in obj.coords.keys(): + if str(coord).lower() in handler["names"]: + formatted_coords[coord_type] = str(coord) + + # Apply formatting + for coord_type, coord in formatted_coords.items(): + # Make sure formatted coords are sorted + obj = obj.sortby(coord) + target = target.sortby(coord) + obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords) + # Coerce back to a single chunk if that's what was passed + if len(orig_chunksizes.get(coord, [])) == 1: + obj = obj.chunk({coord: -1}) return obj -def format_spherical( +def format_lat( obj: xr.DataArray | xr.Dataset, - target: xr.Dataset, - lat_coord: Hashable, - lon_coord: Hashable, + target: xr.Dataset, # noqa ARG001 + formatted_coords: dict[str, str], ) -> xr.DataArray | xr.Dataset: - """Infer whether a lat/lon source grid represents a global domain and - automatically apply spherical padding to improve edge effects. - - For longitude, shift the coordinate to line up with the target values, then - add a single wraparound padding column if the domain is global and the east - or west edges of the target lie outside the source grid centers. - - For latitude, add a single value at each pole computed as the mean of the last + """For latitude, add a single value at each pole computed as the mean of the last row for global source grids where the first or last point lie equatorward of 90. """ + lat_coord = formatted_coords["lat"] + lon_coord = formatted_coords.get("lon") + + # Concat a padded value representing the mean of the first/last lat bands + # This should match the Pole="all" option of ESMF + # TODO: with cos(90) = 0 weighting, these weights might be 0? 
+ + polar_lat = 90 + dy = obj.coords[lat_coord].diff(lat_coord).max().values.item() + + # Only pad if global but don't have edge values directly at poles + # South pole + if dy - polar_lat >= obj.coords[lat_coord].values[0] > -polar_lat: + south_pole = obj.isel({lat_coord: 0}) + if lon_coord is not None: + south_pole = south_pole.mean(lon_coord) + obj = xr.concat([south_pole, obj], dim=lat_coord) # type: ignore + obj.coords[lat_coord].values[0] = -polar_lat + + # North pole + if polar_lat - dy <= obj.coords[lat_coord].values[-1] < polar_lat: + north_pole = obj.isel({lat_coord: -1}) + if lon_coord is not None: + north_pole = north_pole.mean(lon_coord) + obj = xr.concat([obj, north_pole], dim=lat_coord) # type: ignore + obj.coords[lat_coord].values[-1] = polar_lat - orig_chunksizes = obj.chunksizes + return obj - # If the source coord fully covers the target, don't modify them - if lat_coord and not coord_is_covered(obj, target, lat_coord): - obj = obj.sortby(lat_coord) - target = target.sortby(lat_coord) - - # Only pad if global but don't have edge values directly at poles - polar_lat = 90 - dy = obj[lat_coord].diff(lat_coord).max().values - - # South pole - if dy - polar_lat >= obj[lat_coord][0] > -polar_lat: - south_pole = obj.isel({lat_coord: 0}) - # This should match the Pole="all" option of ESMF - if lon_coord is not None: - south_pole = south_pole.mean(lon_coord) - obj = xr.concat([south_pole, obj], dim=lat_coord) - obj[lat_coord].values[0] = -polar_lat - - # North pole - if polar_lat - dy <= obj[lat_coord][-1] < polar_lat: - north_pole = obj.isel({lat_coord: -1}) - if lon_coord is not None: - north_pole = north_pole.mean(lon_coord) - obj = xr.concat([obj, north_pole], dim=lat_coord) - obj[lat_coord].values[-1] = polar_lat - # Coerce back to a single chunk if that's what was passed - if len(orig_chunksizes.get(lat_coord, [])) == 1: - obj = obj.chunk({lat_coord: -1}) - - if lon_coord and not coord_is_covered(obj, target, lon_coord): - obj = obj.sortby(lon_coord) - target = target.sortby(lon_coord) - - target_lon = target[lon_coord].values - # Find a wrap point outside of the left and right bounds of the target - # This ensures we have coverage on the target and handles global > regional - wrap_point = (target_lon[-1] + target_lon[0] + 360) / 2 - lon = obj[lon_coord].values - lon = np.where(lon < wrap_point - 360, lon + 360, lon) - lon = np.where(lon > wrap_point, lon - 360, lon) - obj[lon_coord].values[:] = lon - - # Shift operations can produce duplicates - # Simplest solution is to drop them and add back when padding - obj = obj.sortby(lon_coord).drop_duplicates(lon_coord) - - # Only pad if domain is global in lon - dx_s = obj[lon_coord].diff(lon_coord).max().values - dx_t = target[lon_coord].diff(lon_coord).max().values - is_global_lon = obj[lon_coord].max() - obj[lon_coord].min() >= 360 - dx_s - - if is_global_lon: - left_pad = (obj[lon_coord][0] - target[lon_coord][0] + dx_t / 2) / dx_s - right_pad = (target[lon_coord][-1] - obj[lon_coord][-1] + dx_t / 2) / dx_s - left_pad = int(np.ceil(np.max([left_pad, 0]))) - right_pad = int(np.ceil(np.max([right_pad, 0]))) - lon = obj[lon_coord].values - obj = obj.pad( - {lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True +def format_lon( + obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str] +) -> xr.DataArray | xr.Dataset: + """For longitude, shift the coordinate to line up with the target values, then + add a single wraparound padding column if the domain is global and the east + or west edges of 
the target lie outside the source grid centers. + """ + lon_coord = formatted_coords["lon"] + + # Find a wrap point outside of the left and right bounds of the target + # This ensures we have coverage on the target and handles global > regional + source_vals = obj.coords[lon_coord].values + target_vals = target.coords[lon_coord].values + wrap_point = (target_vals[-1] + target_vals[0] + 360) / 2 + source_vals = np.where( + source_vals < wrap_point - 360, source_vals + 360, source_vals + ) + source_vals = np.where(source_vals > wrap_point, source_vals - 360, source_vals) + obj.coords[lon_coord].values[:] = source_vals + + # Shift operations can produce duplicates + # Simplest solution is to drop them and add back when padding + obj = obj.sortby(lon_coord).drop_duplicates(lon_coord) + + # Only pad if domain is global in lon + source_lon = obj.coords[lon_coord] + target_lon = target.coords[lon_coord] + dx_s = source_lon.diff(lon_coord).max().values.item() + dx_t = target_lon.diff(lon_coord).max().values.item() + is_global_lon = source_lon.max().values - source_lon.min().values >= 360 - dx_s + + if is_global_lon: + left_pad = (source_lon.values[0] - target_lon.values[0] + dx_t / 2) / dx_s + right_pad = (target_lon.values[-1] - source_lon.values[-1] + dx_t / 2) / dx_s + left_pad = int(np.ceil(np.max([left_pad, 0]))) + right_pad = int(np.ceil(np.max([right_pad, 0]))) + obj = obj.pad({lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True) + if left_pad: + obj.coords[lon_coord].values[:left_pad] = ( + source_lon.values[-left_pad:] - 360 + ) + if right_pad: + obj.coords[lon_coord].values[-right_pad:] = ( + source_lon.values[:right_pad] + 360 ) - if left_pad: - obj[lon_coord].values[:left_pad] = lon[-left_pad:] - 360 - if right_pad: - obj[lon_coord].values[-right_pad:] = lon[:right_pad] + 360 - - # Coerce back to a single chunk if that's what was passed - if len(orig_chunksizes.get(lon_coord, [])) == 1: - obj = obj.chunk({lon_coord: -1}) return obj From 7b44358e527e942753af2c4546d16d0f3c0518e1 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 13 Sep 2024 15:09:06 -0400 Subject: [PATCH 5/8] add examples to docstrings --- src/xarray_regrid/utils.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index 9c5a4b7..ccdb0af 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -281,8 +281,17 @@ def format_lat( target: xr.Dataset, # noqa ARG001 formatted_coords: dict[str, str], ) -> xr.DataArray | xr.Dataset: - """For latitude, add a single value at each pole computed as the mean of the last - row for global source grids where the first or last point lie equatorward of 90. + """If the latitude coordinate is inferred to be global, defined as having + a value within one grid spacing of the poles, and the grid does not natively + have values at -90 and 90, add a single value at each pole computed as the + mean of the first and last latitude bands. This should be roughly equivalent + to the `Pole="all"` option in `ESMF`. + + For example, with a grid spacing of 1 degree, and a source grid ranging from + -89.5 to 89.5, the poles would be padded with values at -90 and 90. A grid ranging + from -88 to 88 would not be padded because coverage does not extend all the way + to the poles. A grid ranging from -90 to 90 would also not be padded because the + poles will already be covered in the regridding weights. 
""" lat_coord = formatted_coords["lat"] lon_coord = formatted_coords.get("lon") @@ -295,6 +304,8 @@ def format_lat( dy = obj.coords[lat_coord].diff(lat_coord).max().values.item() # Only pad if global but don't have edge values directly at poles + # NOTE: could use xr.pad here instead of xr.concat, but none of the + # modes are an exact fit for this scheme # South pole if dy - polar_lat >= obj.coords[lat_coord].values[0] > -polar_lat: south_pole = obj.isel({lat_coord: 0}) @@ -317,9 +328,15 @@ def format_lat( def format_lon( obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str] ) -> xr.DataArray | xr.Dataset: - """For longitude, shift the coordinate to line up with the target values, then - add a single wraparound padding column if the domain is global and the east - or west edges of the target lie outside the source grid centers. + """Format the longitude coordinate by shifting the source grid to line up with + the target anywhere in the range of -360 to 360, and then add a single wraparound + padding column if the domain is inferred to be global and the east or west edges + of the target lie outside the source grid centers. + + For example, with a source grid ranging from 0.5 to 359.5 and a target grid ranging + from -180 to 180, the source grid would be shifted to -179.5 to 179.5 and then + padded on both the left and right with wraparound values at -180.5 and 180.5 to + provide full coverage for the target edge cells at -180 and 180. """ lon_coord = formatted_coords["lon"] From a5788bf4de190a338bed8cd43b9127f2d6d5b7b2 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 19 Sep 2024 13:02:47 +0000 Subject: [PATCH 6/8] fix modifying coordinates --- src/xarray_regrid/methods/conservative.py | 6 +- src/xarray_regrid/utils.py | 94 +++++++++++------------ 2 files changed, 51 insertions(+), 49 deletions(-) diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 54946a8..c620058 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -72,8 +72,10 @@ def conservative_regrid( # Make sure the regridding coordinates are sorted coord_names = [coord for coord in target_ds.coords if coord in data.coords] - target_ds_sorted = target_ds.sortby(coord_names) - data = data.sortby(list(coord_names)) + target_ds_sorted = target_ds.copy() + for coord_name in coord_names: + target_ds_sorted = utils.ensure_monotonic(target_ds_sorted, coord_name) + data = utils.ensure_monotonic(data, coord_name) coords = {name: target_ds_sorted[name] for name in coord_names} regridded_data = utils.call_on_dataset( diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index ccdb0af..24f9173 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -1,11 +1,13 @@ from collections.abc import Callable, Hashable from dataclasses import dataclass -from typing import Any, TypedDict, overload +from typing import Any, TypedDict, TypeVar import numpy as np import pandas as pd import xarray as xr +XarrayData = TypeVar("XarrayData", xr.DataArray, xr.Dataset) + class InvalidBoundsError(Exception): ... 
@@ -187,8 +189,8 @@ def create_dot_dataarray( def common_coords( - data1: xr.DataArray | xr.Dataset, - data2: xr.DataArray | xr.Dataset, + data1: XarrayData, + data2: XarrayData, remove_coord: str | None = None, ) -> list[str]: """Return a set of coords which two dataset/arrays have in common.""" @@ -198,30 +200,12 @@ def common_coords( return sorted([str(coord) for coord in coords]) -@overload -def call_on_dataset( - func: Callable[..., xr.Dataset], - obj: xr.DataArray, - *args: Any, - **kwargs: Any, -) -> xr.DataArray: ... - - -@overload -def call_on_dataset( - func: Callable[..., xr.Dataset], - obj: xr.Dataset, - *args: Any, - **kwargs: Any, -) -> xr.Dataset: ... - - def call_on_dataset( func: Callable[..., xr.Dataset], - obj: xr.DataArray | xr.Dataset, + obj: XarrayData, *args: Any, **kwargs: Any, -) -> xr.DataArray | xr.Dataset: +) -> XarrayData: """Use to call a function that expects a Dataset on either a Dataset or DataArray, round-tripping to a temporary dataset.""" placeholder_name = "_UNNAMED_ARRAY" @@ -242,9 +226,7 @@ def call_on_dataset( return result -def format_for_regrid( - obj: xr.DataArray | xr.Dataset, target: xr.Dataset -) -> xr.DataArray | xr.Dataset: +def format_for_regrid(obj: XarrayData, target: xr.Dataset) -> XarrayData: """Apply any pre-formatting to the input dataset to prepare for regridding. Currently handles padding of spherical geometry if lat/lon coordinates can be inferred and the domain size requires boundary padding. @@ -266,8 +248,8 @@ def format_for_regrid( # Apply formatting for coord_type, coord in formatted_coords.items(): # Make sure formatted coords are sorted - obj = obj.sortby(coord) - target = target.sortby(coord) + obj = ensure_monotonic(obj, coord) + target = ensure_monotonic(target, coord) obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords) # Coerce back to a single chunk if that's what was passed if len(orig_chunksizes.get(coord, [])) == 1: @@ -277,10 +259,10 @@ def format_for_regrid( def format_lat( - obj: xr.DataArray | xr.Dataset, + obj: XarrayData, target: xr.Dataset, # noqa ARG001 formatted_coords: dict[str, str], -) -> xr.DataArray | xr.Dataset: +) -> XarrayData: """If the latitude coordinate is inferred to be global, defined as having a value within one grid spacing of the poles, and the grid does not natively have values at -90 and 90, add a single value at each pole computed as the @@ -306,13 +288,14 @@ def format_lat( # Only pad if global but don't have edge values directly at poles # NOTE: could use xr.pad here instead of xr.concat, but none of the # modes are an exact fit for this scheme + lat_vals = obj.coords[lat_coord].values # South pole if dy - polar_lat >= obj.coords[lat_coord].values[0] > -polar_lat: south_pole = obj.isel({lat_coord: 0}) if lon_coord is not None: south_pole = south_pole.mean(lon_coord) obj = xr.concat([south_pole, obj], dim=lat_coord) # type: ignore - obj.coords[lat_coord].values[0] = -polar_lat + lat_vals = np.concatenate([[-polar_lat], lat_vals]) # North pole if polar_lat - dy <= obj.coords[lat_coord].values[-1] < polar_lat: @@ -320,14 +303,16 @@ def format_lat( if lon_coord is not None: north_pole = north_pole.mean(lon_coord) obj = xr.concat([obj, north_pole], dim=lat_coord) # type: ignore - obj.coords[lat_coord].values[-1] = polar_lat + lat_vals = np.concatenate([lat_vals, [polar_lat]]) + + obj = update_coord(obj, lat_coord, lat_vals) return obj def format_lon( - obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str] -) -> xr.DataArray | xr.Dataset: + 
obj: XarrayData, target: xr.Dataset, formatted_coords: dict[str, str] +) -> XarrayData: """Format the longitude coordinate by shifting the source grid to line up with the target anywhere in the range of -360 to 360, and then add a single wraparound padding column if the domain is inferred to be global and the east or west edges @@ -349,11 +334,9 @@ def format_lon( source_vals < wrap_point - 360, source_vals + 360, source_vals ) source_vals = np.where(source_vals > wrap_point, source_vals - 360, source_vals) - obj.coords[lon_coord].values[:] = source_vals + obj = update_coord(obj, lon_coord, source_vals) - # Shift operations can produce duplicates - # Simplest solution is to drop them and add back when padding - obj = obj.sortby(lon_coord).drop_duplicates(lon_coord) + obj = ensure_monotonic(obj, lon_coord) # Only pad if domain is global in lon source_lon = obj.coords[lon_coord] @@ -368,23 +351,40 @@ def format_lon( left_pad = int(np.ceil(np.max([left_pad, 0]))) right_pad = int(np.ceil(np.max([right_pad, 0]))) obj = obj.pad({lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True) + lon_vals = obj.coords[lon_coord].values if left_pad: - obj.coords[lon_coord].values[:left_pad] = ( - source_lon.values[-left_pad:] - 360 - ) + lon_vals[:left_pad] = source_lon.values[-left_pad:] - 360 if right_pad: - obj.coords[lon_coord].values[-right_pad:] = ( - source_lon.values[:right_pad] + 360 - ) + lon_vals[-right_pad:] = source_lon.values[:right_pad] + 360 + obj = update_coord(obj, lon_coord, lon_vals) return obj -def coord_is_covered( - obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable -) -> bool: +def coord_is_covered(obj: XarrayData, target: xr.Dataset, coord: Hashable) -> bool: """Check if the source coord fully covers the target coord.""" pad = target[coord].diff(coord).max().values left_covered = obj[coord].min() <= target[coord].min() - pad right_covered = obj[coord].max() >= target[coord].max() + pad return bool(left_covered.item() and right_covered.item()) + + +def ensure_monotonic(obj: XarrayData, coord: Hashable) -> XarrayData: + """Ensure that an object has monotonically increasing indexes for a + given coordinate. 
Only sort and drop duplicates if needed because this + requires reindexing which can be expensive.""" + is_sorted = (obj[coord].diff(coord) >= 0).all().compute().item() + if not is_sorted: + obj = obj.sortby(coord) + has_duplicates = np.unique(obj[coord].values).size < obj[coord].values.size + if has_duplicates: + obj = obj.drop_duplicates(coord) + return obj + + +def update_coord(obj: XarrayData, coord: Hashable, coord_vals: np.array) -> XarrayData: + """Update the values of a coordinate, ensuring indexes stay in sync.""" + attrs = obj.coords[coord].attrs + obj = obj.assign_coords({coord: coord_vals}) + obj.coords[coord].attrs = attrs + return obj From 52e71c5d3ad427dba01a6601421df4565fed85c6 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 19 Sep 2024 14:06:38 +0000 Subject: [PATCH 7/8] fix typing --- src/xarray_regrid/utils.py | 62 +++++++++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index 24f9173..f41c981 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -1,13 +1,11 @@ from collections.abc import Callable, Hashable from dataclasses import dataclass -from typing import Any, TypedDict, TypeVar +from typing import Any, TypedDict, overload import numpy as np import pandas as pd import xarray as xr -XarrayData = TypeVar("XarrayData", xr.DataArray, xr.Dataset) - class InvalidBoundsError(Exception): ... @@ -189,8 +187,8 @@ def create_dot_dataarray( def common_coords( - data1: XarrayData, - data2: XarrayData, + data1: xr.DataArray | xr.Dataset, + data2: xr.DataArray | xr.Dataset, remove_coord: str | None = None, ) -> list[str]: """Return a set of coords which two dataset/arrays have in common.""" @@ -202,10 +200,10 @@ def common_coords( def call_on_dataset( func: Callable[..., xr.Dataset], - obj: XarrayData, + obj: xr.DataArray | xr.Dataset, *args: Any, **kwargs: Any, -) -> XarrayData: +) -> xr.DataArray | xr.Dataset: """Use to call a function that expects a Dataset on either a Dataset or DataArray, round-tripping to a temporary dataset.""" placeholder_name = "_UNNAMED_ARRAY" @@ -226,7 +224,9 @@ def call_on_dataset( return result -def format_for_regrid(obj: XarrayData, target: xr.Dataset) -> XarrayData: +def format_for_regrid( + obj: xr.DataArray | xr.Dataset, target: xr.Dataset +) -> xr.DataArray | xr.Dataset: """Apply any pre-formatting to the input dataset to prepare for regridding. Currently handles padding of spherical geometry if lat/lon coordinates can be inferred and the domain size requires boundary padding. 
@@ -259,10 +259,10 @@ def format_for_regrid(obj: XarrayData, target: xr.Dataset) -> XarrayData: def format_lat( - obj: XarrayData, + obj: xr.DataArray | xr.Dataset, target: xr.Dataset, # noqa ARG001 formatted_coords: dict[str, str], -) -> XarrayData: +) -> xr.DataArray | xr.Dataset: """If the latitude coordinate is inferred to be global, defined as having a value within one grid spacing of the poles, and the grid does not natively have values at -90 and 90, add a single value at each pole computed as the @@ -311,8 +311,8 @@ def format_lat( def format_lon( - obj: XarrayData, target: xr.Dataset, formatted_coords: dict[str, str] -) -> XarrayData: + obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str] +) -> xr.DataArray | xr.Dataset: """Format the longitude coordinate by shifting the source grid to line up with the target anywhere in the range of -360 to 360, and then add a single wraparound padding column if the domain is inferred to be global and the east or west edges @@ -361,7 +361,9 @@ def format_lon( return obj -def coord_is_covered(obj: XarrayData, target: xr.Dataset, coord: Hashable) -> bool: +def coord_is_covered( + obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable +) -> bool: """Check if the source coord fully covers the target coord.""" pad = target[coord].diff(coord).max().values left_covered = obj[coord].min() <= target[coord].min() - pad @@ -369,20 +371,46 @@ def coord_is_covered(obj: XarrayData, target: xr.Dataset, coord: Hashable) -> bo return bool(left_covered.item() and right_covered.item()) -def ensure_monotonic(obj: XarrayData, coord: Hashable) -> XarrayData: +@overload +def ensure_monotonic(obj: xr.DataArray, coord: Hashable) -> xr.DataArray: ... + + +@overload +def ensure_monotonic(obj: xr.Dataset, coord: Hashable) -> xr.Dataset: ... + + +def ensure_monotonic( + obj: xr.DataArray | xr.Dataset, coord: Hashable +) -> xr.DataArray | xr.Dataset: """Ensure that an object has monotonically increasing indexes for a given coordinate. Only sort and drop duplicates if needed because this requires reindexing which can be expensive.""" - is_sorted = (obj[coord].diff(coord) >= 0).all().compute().item() + is_sorted = (obj.coords[coord].diff(coord) >= 0).all().compute().item() if not is_sorted: obj = obj.sortby(coord) - has_duplicates = np.unique(obj[coord].values).size < obj[coord].values.size + has_duplicates = ( + np.unique(obj.coords[coord].values).size < obj.coords[coord].values.size + ) if has_duplicates: obj = obj.drop_duplicates(coord) return obj -def update_coord(obj: XarrayData, coord: Hashable, coord_vals: np.array) -> XarrayData: +@overload +def update_coord( + obj: xr.DataArray, coord: Hashable, coord_vals: np.ndarray +) -> xr.DataArray: ... + + +@overload +def update_coord( + obj: xr.Dataset, coord: Hashable, coord_vals: np.ndarray +) -> xr.Dataset: ... 
+ + +def update_coord( + obj: xr.DataArray | xr.Dataset, coord: Hashable, coord_vals: np.ndarray +) -> xr.DataArray | xr.Dataset: """Update the values of a coordinate, ensuring indexes stay in sync.""" attrs = obj.coords[coord].attrs obj = obj.assign_coords({coord: coord_vals}) From 1f2e99956302f15841d65fb9a556d759cd9c501b Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 20 Sep 2024 12:20:11 +0000 Subject: [PATCH 8/8] review suggestions --- src/xarray_regrid/methods/conservative.py | 2 +- src/xarray_regrid/utils.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index c620058..e3f421a 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -72,7 +72,7 @@ def conservative_regrid( # Make sure the regridding coordinates are sorted coord_names = [coord for coord in target_ds.coords if coord in data.coords] - target_ds_sorted = target_ds.copy() + target_ds_sorted = xr.Dataset(coords=target_ds.coords) for coord_name in coord_names: target_ds_sorted = utils.ensure_monotonic(target_ds_sorted, coord_name) data = utils.ensure_monotonic(data, coord_name) diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index f41c981..b507310 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -385,13 +385,9 @@ def ensure_monotonic( """Ensure that an object has monotonically increasing indexes for a given coordinate. Only sort and drop duplicates if needed because this requires reindexing which can be expensive.""" - is_sorted = (obj.coords[coord].diff(coord) >= 0).all().compute().item() - if not is_sorted: + if not obj.indexes[coord].is_monotonic_increasing: obj = obj.sortby(coord) - has_duplicates = ( - np.unique(obj.coords[coord].values).size < obj.coords[coord].values.size - ) - if has_duplicates: + if not obj.indexes[coord].is_unique: obj = obj.drop_duplicates(coord) return obj
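
After this series, the padding pipeline can be exercised directly through
format_for_regrid, mirroring tests/test_format.py. A minimal sketch, assuming a
2-degree global source on 0..360 longitude and a 1-degree target on -180..180
(the asserted values follow test_360_to_180 above):

    import xarray_regrid
    from xarray_regrid.utils import format_for_regrid

    source = xarray_regrid.Grid(
        north=90, east=360, south=-90, west=0,
        resolution_lat=2, resolution_lon=2,
    ).create_regridding_dataset()
    target = xarray_regrid.Grid(
        north=90, east=180, south=-90, west=-180,
        resolution_lat=1, resolution_lon=1,
    ).create_regridding_dataset()

    formatted = format_for_regrid(source, target)
    # Source longitudes are shifted into the target frame, then wrap-padded
    # one column past each target edge:
    assert formatted.longitude[0] == -182
    assert formatted.longitude[-1] == 182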
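
The longitude shift in format_lon picks a wrap point half a revolution past the
target's span and folds the source values into that window. Traced by hand with
numpy for the 0.5..359.5 example given in the docstring:

    import numpy as np

    target_lon = np.arange(-180.0, 181.0)  # target centers: -180..180
    source_lon = np.arange(0.5, 360.0)     # source centers: 0.5..359.5

    wrap_point = (target_lon[-1] + target_lon[0] + 360) / 2  # -> 180.0
    lon = np.where(source_lon < wrap_point - 360, source_lon + 360, source_lon)
    lon = np.where(lon > wrap_point, lon - 360, lon)
    # Sorted, the shifted centers run -179.5..179.5, so a single wraparound
    # column on each side covers the target edge cells at -180 and 180.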
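
Likewise, the south-pole padding condition in format_lat reduces to
dy - 90 >= lat[0] > -90. Evaluating it for the three cases described in the
format_lat docstring, with an assumed 1-degree spacing:

    dy = 1.0  # max grid spacing in latitude
    for lat0 in (-89.5, -88.0, -90.0):
        print(lat0, dy - 90 >= lat0 > -90)
    # -89.5 True   (within one spacing of the pole: pad a row at -90)
    # -88.0 False  (stops short of the pole by more than one spacing)
    # -90.0 False  (pole value already present, nothing to pad)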
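
Finally, ensure_monotonic in its form after the last patch only reindexes when
the coordinate actually needs it. A small usage sketch with made-up values:

    import numpy as np
    import xarray as xr
    from xarray_regrid.utils import ensure_monotonic

    da = xr.DataArray(
        np.arange(4),
        coords={"longitude": [10.0, 0.0, 0.0, 5.0]},
        dims="longitude",
    )
    da = ensure_monotonic(da, "longitude")
    # longitude is now sorted and duplicate labels dropped: [0.0, 5.0, 10.0];
    # an already-monotonic, unique index passes through untouched.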