From ce16f746d2fa3fbe607f869796a7ca43406ee34a Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 6 Sep 2024 14:56:48 +0000 Subject: [PATCH] initial pass at spherical padding, faster tests, full nan tracking --- src/xarray_regrid/methods/conservative.py | 7 +- src/xarray_regrid/regrid.py | 21 ++- src/xarray_regrid/utils.py | 128 +++++++++++++++- tests/test_format.py | 178 ++++++++++++++++++++++ tests/test_regrid.py | 159 +++++++++---------- 5 files changed, 388 insertions(+), 105 deletions(-) create mode 100644 tests/test_format.py diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index eb96acf..deec777 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -66,7 +66,7 @@ def conservative_regrid( # Attempt to infer the latitude coordinate if latitude_coord is None: for coord in data.coords: - if str(coord).lower() in ["lat", "latitude"]: + if str(coord).lower().startswith("lat"): latitude_coord = coord break @@ -122,7 +122,6 @@ def conservative_regrid_dataset( weights = apply_spherical_correction(weights, latitude_coord) for array in data_vars.keys(): - non_grid_dims = [d for d in data_vars[array].dims if d not in coords] if coord in data_vars[array].dims: data_vars[array], valid_fracs[array] = apply_weights( da=data_vars[array], @@ -130,7 +129,6 @@ def conservative_regrid_dataset( coord=coord, valid_frac=valid_fracs[array], skipna=skipna, - non_grid_dims=non_grid_dims, ) # Mask out any regridded points outside the original domain data_vars[array] = data_vars[array].where(covered_grid) @@ -161,7 +159,6 @@ def apply_weights( coord: Hashable, valid_frac: xr.DataArray, skipna: bool, - non_grid_dims: list[Hashable], ) -> tuple[xr.DataArray, xr.DataArray]: """Apply the weights to convert data to the new coordinates.""" coord_map = {f"target_{coord}": coord} @@ -169,8 +166,6 @@ def apply_weights( if skipna: notnull = da.notnull() - if non_grid_dims: - notnull = notnull.any(non_grid_dims) # Renormalize the weights along this dim by the accumulated valid_frac # along previous dimensions if valid_frac.name != EMPTY_DA_NAME: diff --git a/src/xarray_regrid/regrid.py b/src/xarray_regrid/regrid.py index e91348a..fe0f214 100644 --- a/src/xarray_regrid/regrid.py +++ b/src/xarray_regrid/regrid.py @@ -1,6 +1,7 @@ import xarray as xr from xarray_regrid.methods import conservative, interp, most_common +from xarray_regrid.utils import format_for_regrid @xr.register_dataarray_accessor("regrid") @@ -34,7 +35,8 @@ def linear( Data regridded to the target dataset coordinates. """ ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - return interp.interp_regrid(self._obj, ds_target_grid, "linear") + ds_formatted = format_for_regrid(self._obj, ds_target_grid) + return interp.interp_regrid(ds_formatted, ds_target_grid, "linear") def nearest( self, @@ -51,14 +53,14 @@ def nearest( Data regridded to the target dataset coordinates. """ ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) - return interp.interp_regrid(self._obj, ds_target_grid, "nearest") + ds_formatted = format_for_regrid(self._obj, ds_target_grid) + return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest") def cubic( self, ds_target_grid: xr.Dataset, time_dim: str = "time", ) -> xr.DataArray | xr.Dataset: - ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) """Regrid to the coords of the target dataset with cubic interpolation. Args: @@ -68,7 +70,9 @@ def cubic( Returns: Data regridded to the target dataset coordinates. """ - return interp.interp_regrid(self._obj, ds_target_grid, "cubic") + ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) + ds_formatted = format_for_regrid(self._obj, ds_target_grid) + return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic") def conservative( self, @@ -88,6 +92,9 @@ def conservative( time_dim: The name of the time dimension/coordinate. skipna: If True, enable handling for NaN values. This adds some overhead, so can be disabled for optimal performance on data without any NaNs. + With `skipna=True, chunking is recommended in the non-grid dimensions, + otherwise the intermediate arrays that track the fraction of valid data + can become very large and consume excessive memory. Warning: with `skipna=False`, isolated NaNs will propagate throughout the dataset due to the sequential regridding scheme over each dimension. nan_threshold: Threshold value that will retain any output points @@ -104,8 +111,9 @@ def conservative( raise ValueError(msg) ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) + ds_formatted = format_for_regrid(self._obj, ds_target_grid) return conservative.conservative_regrid( - self._obj, ds_target_grid, latitude_coord, skipna, nan_threshold + ds_formatted, ds_target_grid, latitude_coord, skipna, nan_threshold ) def most_common( @@ -134,8 +142,9 @@ def most_common( Regridded data. """ ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim) + ds_formatted = format_for_regrid(self._obj, ds_target_grid) return most_common.most_common_wrapper( - self._obj, ds_target_grid, time_dim, max_mem + ds_formatted, ds_target_grid, time_dim, max_mem ) diff --git a/src/xarray_regrid/utils.py b/src/xarray_regrid/utils.py index ce07f88..acf89cf 100644 --- a/src/xarray_regrid/utils.py +++ b/src/xarray_regrid/utils.py @@ -1,4 +1,4 @@ -from collections.abc import Callable +from collections.abc import Callable, Hashable from dataclasses import dataclass from typing import Any, overload @@ -75,7 +75,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]: grid.south, grid.north + grid.resolution_lat, grid.resolution_lat ) - if np.remainder((grid.north - grid.south), grid.resolution_lat) > 0: + if np.remainder((grid.east - grid.west), grid.resolution_lat) > 0: lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon) else: lon_coords = np.arange( @@ -235,3 +235,127 @@ def call_on_dataset( return next(iter(result.data_vars.values())).rename(obj.name) return result + + +def format_for_regrid( + obj: xr.DataArray | xr.Dataset, target: xr.Dataset +) -> xr.DataArray | xr.Dataset: + """Apply any pre-formatting to the input dataset to prepare for regridding. + Currently handles padding of spherical geometry if appropriate coordinate names + can be inferred containing 'lat' and 'lon'. + """ + lat_coord = None + lon_coord = None + + for coord in obj.coords.keys(): + if str(coord).lower().startswith("lat"): + lat_coord = coord + elif str(coord).lower().startswith("lon"): + lon_coord = coord + + if lon_coord is not None or lat_coord is not None: + obj = format_spherical(obj, target, lat_coord, lon_coord) + + return obj + + +def format_spherical( + obj: xr.DataArray | xr.Dataset, + target: xr.Dataset, + lat_coord: Hashable, + lon_coord: Hashable, +) -> xr.DataArray | xr.Dataset: + """Infer whether a lat/lon source grid represents a global domain and + automatically apply spherical padding to improve edge effects. + + For longitude, shift the coordinate to line up with the target values, then + add a single wraparound padding column if the domain is global and the east + or west edges of the target lie outside the source grid centers. + + For latitude, add a single value at each pole computed as the mean of the last + row for global source grids where the first or last point lie equatorward of 90. + """ + + orig_chunksizes = obj.chunksizes + + # If the source coord fully covers the target, don't modify them + if lat_coord and not coord_is_covered(obj, target, lat_coord): + obj = obj.sortby(lat_coord) + target = target.sortby(lat_coord) + + # Only pad if global but don't have edge values directly at poles + polar_lat = 90 + dy = obj[lat_coord].diff(lat_coord).max().values + + # South pole + if dy - polar_lat >= obj[lat_coord][0] > -polar_lat: + south_pole = obj.isel({lat_coord: 0}) + # This should match the Pole="all" option of ESMF + if lon_coord is not None: + south_pole = south_pole.mean(lon_coord) + obj = xr.concat([south_pole, obj], dim=lat_coord) + obj[lat_coord].values[0] = -polar_lat + + # North pole + if polar_lat - dy <= obj[lat_coord][-1] < polar_lat: + north_pole = obj.isel({lat_coord: -1}) + if lon_coord is not None: + north_pole = north_pole.mean(lon_coord) + obj = xr.concat([obj, north_pole], dim=lat_coord) + obj[lat_coord].values[-1] = polar_lat + + # Coerce back to a single chunk if that's what was passed + if len(orig_chunksizes.get(lat_coord, [])) == 1: + obj = obj.chunk({lat_coord: -1}) + + if lon_coord and not coord_is_covered(obj, target, lon_coord): + obj = obj.sortby(lon_coord) + target = target.sortby(lon_coord) + + target_lon = target[lon_coord].values + # Find a wrap point outside of the left and right bounds of the target + # This ensures we have coverage on the target and handles global > regional + wrap_point = (target_lon[-1] + target_lon[0] + 360) / 2 + lon = obj[lon_coord].values + lon = np.where(lon < wrap_point - 360, lon + 360, lon) + lon = np.where(lon > wrap_point, lon - 360, lon) + obj[lon_coord].values[:] = lon + + # Shift operations can produce duplicates + # Simplest solution is to drop them and add back when padding + obj = obj.sortby(lon_coord).drop_duplicates(lon_coord) + + # Only pad if domain is global in lon + dx_s = obj[lon_coord].diff(lon_coord).max().values + dx_t = target[lon_coord].diff(lon_coord).max().values + is_global_lon = obj[lon_coord].max() - obj[lon_coord].min() >= 360 - dx_s + + if is_global_lon: + left_pad = (obj[lon_coord][0] - target[lon_coord][0] + dx_t / 2) / dx_s + right_pad = (target[lon_coord][-1] - obj[lon_coord][-1] + dx_t / 2) / dx_s + left_pad = int(np.ceil(np.max([left_pad, 0]))) + right_pad = int(np.ceil(np.max([right_pad, 0]))) + lon = obj[lon_coord].values + obj = obj.pad( + {lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True + ) + if left_pad: + obj[lon_coord].values[:left_pad] = lon[-left_pad:] - 360 + if right_pad: + obj[lon_coord].values[-right_pad:] = lon[:right_pad] + 360 + + # Coerce back to a single chunk if that's what was passed + if len(orig_chunksizes.get(lon_coord, [])) == 1: + obj = obj.chunk({lon_coord: -1}) + + return obj + + +def coord_is_covered( + obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable +) -> bool: + """Check if the source coord fully covers the target coord.""" + pad = target[coord].diff(coord).max().values + left_covered = obj[coord].min() <= target[coord].min() - pad + right_covered = obj[coord].max() >= target[coord].max() + pad + return bool(left_covered.item() and right_covered.item()) diff --git a/tests/test_format.py b/tests/test_format.py new file mode 100644 index 0000000..4fabde9 --- /dev/null +++ b/tests/test_format.py @@ -0,0 +1,178 @@ +import xarray as xr + +import xarray_regrid +from xarray_regrid.utils import format_for_regrid + + +def test_covered(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=80, + east=350, + south=-80, + west=10, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Formatting utils shouldn't modify this one at all + xr.testing.assert_equal(source, formatted) + + +def test_no_edges(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90 - dx_source / 2, + east=360 - dx_source / 2, + south=-90 + dx_source / 2, + west=0 + dx_source / 2, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should add wraparound and polar padding rows/columns + assert formatted.latitude[0] == -90 + assert formatted.latitude[-1] == 90 + assert formatted.longitude[0] == -1 + assert formatted.longitude[-1] == 361 + assert (formatted.longitude.diff("longitude") == 2).all() + + +def test_360_to_180(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=180, + south=-90, + west=-180, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should produce a shift to target plus wraparound padding + assert formatted.longitude[0] == -182 + assert formatted.longitude[-1] == 182 + assert (formatted.longitude.diff("longitude") == 2).all() + + +def test_180_to_360(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=180, + south=-90, + west=-180, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should produce a shift to target plus wraparound padding + assert formatted.longitude[0] == -2 + assert formatted.longitude[-1] == 362 + assert (formatted.longitude.diff("longitude") == 2).all() + + +def test_0_to_360(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=0, + south=-90, + west=-360, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=360, + south=-90, + west=0, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should produce a shift to target plus wraparound padding + assert formatted.longitude[0] == -2 + assert formatted.longitude[-1] == 362 + assert (formatted.longitude.diff("longitude") == 2).all() + + +def test_global_to_local_shift(): + dx_source = 2 + source = xarray_regrid.Grid( + north=90, + east=180, + south=-90, + west=-180, + resolution_lat=dx_source, + resolution_lon=dx_source, + ).create_regridding_dataset() + + dx_target = 1 + target = xarray_regrid.Grid( + north=90, + east=300, + south=-90, + west=270, + resolution_lat=dx_target, + resolution_lon=dx_target, + ).create_regridding_dataset() + + formatted = format_for_regrid(source, target) + + # Should produce a shift to cover the target range + assert formatted.longitude.min() <= 270 + assert formatted.longitude.max() >= 300 + assert (formatted.longitude.diff("longitude") == 2).all() diff --git a/tests/test_regrid.py b/tests/test_regrid.py index bf5f59b..0296f6a 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -1,4 +1,3 @@ -from copy import deepcopy from pathlib import Path import numpy as np @@ -6,29 +5,46 @@ import xarray as xr from numpy.testing import assert_array_equal +try: + import xesmf +except ImportError: + xesmf = None + import xarray_regrid DATA_PATH = Path(__file__).parent.parent / "docs" / "notebooks" / "benchmarks" / "data" -CDO_DATA = { +CDO_FILES = { "linear": DATA_PATH / "cdo_bilinear_64b.nc", "nearest": DATA_PATH / "cdo_nearest_64b.nc", "conservative": DATA_PATH / "cdo_conservative_64b.nc", } +# Sample files contain 12 monthly timestamps but subset to one for speed +N_TIMESTAMPS = 1 + @pytest.fixture(scope="session") -def load_input_data() -> xr.Dataset: +def sample_input_data() -> xr.Dataset: ds = xr.open_dataset(DATA_PATH / "era5_2m_dewpoint_temperature_2000_monthly.nc") - return ds.compute() + return ds.isel(time=slice(0, N_TIMESTAMPS)).persist() -@pytest.fixture -def sample_input_data(load_input_data) -> xr.Dataset: - return deepcopy(load_input_data) +@pytest.fixture(scope="session") +def conservative_input_data() -> xr.Dataset: + ds = xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc") + return ds.isel(time=slice(0, N_TIMESTAMPS)).persist() -@pytest.fixture +@pytest.fixture(scope="session") +def cdo_comparison_data() -> dict[str, xr.Dataset]: + data = {} + for method, path in CDO_FILES.items(): + data[method] = xr.open_dataset(path).isel(time=slice(0, N_TIMESTAMPS)).persist() + return data + + +@pytest.fixture(scope="session") def sample_grid_ds(): grid = xarray_regrid.Grid( north=90, @@ -42,48 +58,7 @@ def sample_grid_ds(): return xarray_regrid.create_regridding_dataset(grid) -@pytest.mark.parametrize( - "method, cdo_file", - [ - ("linear", CDO_DATA["linear"]), - ("nearest", CDO_DATA["nearest"]), - ], -) -def test_basic_regridders_ds(sample_input_data, sample_grid_ds, method, cdo_file): - """Test the dataset regridders (except conservative).""" - regridder = getattr(sample_input_data.regrid, method) - ds_regrid = regridder(sample_grid_ds) - ds_cdo = xr.open_dataset(cdo_file) - xr.testing.assert_allclose(ds_regrid.compute(), ds_cdo.compute()) - - -@pytest.mark.parametrize( - "method, cdo_file", - [ - ("linear", CDO_DATA["linear"]), - ("nearest", CDO_DATA["nearest"]), - ], -) -def test_basic_regridders_da(sample_input_data, sample_grid_ds, method, cdo_file): - """Test the dataarray regridders (except conservative).""" - regridder = getattr(sample_input_data["d2m"].regrid, method) - da_regrid = regridder(sample_grid_ds) - ds_cdo = xr.open_dataset(cdo_file) - xr.testing.assert_allclose(da_regrid.compute(), ds_cdo["d2m"].compute()) - - @pytest.fixture(scope="session") -def load_conservative_input_data() -> xr.Dataset: - ds = xr.open_dataset(DATA_PATH / "era5_total_precipitation_2020_monthly.nc") - return ds.compute() - - -@pytest.fixture -def conservative_input_data(load_conservative_input_data) -> xr.Dataset: - return deepcopy(load_conservative_input_data) - - -@pytest.fixture def conservative_sample_grid(): grid = xarray_regrid.Grid( north=90, @@ -97,46 +72,60 @@ def conservative_sample_grid(): return xarray_regrid.create_regridding_dataset(grid) -def test_conservative_regridder(conservative_input_data, conservative_sample_grid): +@pytest.mark.parametrize("method", ["linear", "nearest"]) +def test_basic_regridders_ds( + sample_input_data, sample_grid_ds, cdo_comparison_data, method +): + """Test the dataset regridders (except conservative).""" + regridder = getattr(sample_input_data.regrid, method) + ds_regrid = regridder(sample_grid_ds) + ds_cdo = cdo_comparison_data[method] + xr.testing.assert_allclose(ds_regrid, ds_cdo, rtol=0.002, atol=2e-5) + + +@pytest.mark.parametrize("method", ["linear", "nearest"]) +def test_basic_regridders_da( + sample_input_data, sample_grid_ds, cdo_comparison_data, method +): + """Test the dataarray regridders (except conservative).""" + regridder = getattr(sample_input_data["d2m"].regrid, method) + da_regrid = regridder(sample_grid_ds) + da_cdo = cdo_comparison_data[method]["d2m"] + xr.testing.assert_allclose(da_regrid, da_cdo, rtol=0.002, atol=2e-5) + + +def test_conservative_regridder( + conservative_input_data, conservative_sample_grid, cdo_comparison_data +): ds_regrid = conservative_input_data.regrid.conservative( conservative_sample_grid, latitude_coord="latitude" ) - ds_cdo = xr.open_dataset(CDO_DATA["conservative"]) - - # Cut of the edges: edge performance to be improved later (hopefully) - no_edges = {"latitude": slice(-85, 85), "longitude": slice(5, 355)} + ds_cdo = cdo_comparison_data["conservative"] xr.testing.assert_allclose( - ds_regrid["tp"] - .sel(no_edges) - .compute() - .transpose("time", "latitude", "longitude"), - ds_cdo["tp"].sel(no_edges).compute(), + ds_regrid["tp"], + ds_cdo["tp"], rtol=0.002, - atol=2e-6, + atol=2e-5, ) -def test_conservative_nans(conservative_input_data, conservative_sample_grid): +def test_conservative_nans( + conservative_input_data, conservative_sample_grid, cdo_comparison_data +): ds = conservative_input_data ds["tp"] = ds["tp"].where(ds.latitude >= 0).where(ds.longitude < 180) ds_regrid = ds.regrid.conservative( conservative_sample_grid, latitude_coord="latitude" ) - ds_cdo = xr.open_dataset(CDO_DATA["conservative"]) + ds_cdo = cdo_comparison_data["conservative"] - # Cut of the edges: edge performance to be improved later (hopefully) - no_edges = {"latitude": slice(-85, 85), "longitude": slice(5, 355)} - no_nans = {"latitude": slice(1, 90), "longitude": slice(None, 179)} + no_nans = {"latitude": slice(1, 90), "longitude": slice(1, 179)} xr.testing.assert_allclose( - ds_regrid["tp"] - .sel(no_edges) - .sel(no_nans) - .compute() - .transpose("time", "latitude", "longitude"), - ds_cdo["tp"].sel(no_edges).sel(no_nans).compute(), + ds_regrid["tp"].sel(no_nans), + ds_cdo["tp"].sel(no_nans), rtol=0.002, - atol=2e-6, + atol=2e-5, ) @@ -186,7 +175,7 @@ def test_conservative_nan_aggregation_over_dims(): target = xr.Dataset(coords={"x": [0], "y": [0]}) result = data.regrid.conservative(target, skipna=True, nan_threshold=1) - assert np.allclose(result[0].mean().item(), data[0].mean().item()) + np.testing.assert_allclose(result[0].mean().item(), data[0].mean().item()) @pytest.mark.parametrize("nan_threshold", [0, 1]) @@ -212,19 +201,9 @@ def test_conservative_nan_thresholds_against_coarsen(nan_threshold): xr.testing.assert_allclose(da_coarsen, da_regrid) -def xesmf_available() -> bool: - try: - import xesmf # noqa: F401 - except ImportError: - return False - return True - - -@pytest.mark.skipif(not xesmf_available(), reason="xesmf required") +@pytest.mark.skipif(xesmf is None, reason="xesmf required") def test_conservative_nan_thresholds_against_xesmf(): - import xesmf as xe - - ds = xr.tutorial.open_dataset("ersstv5").sst.compute() + ds = xr.tutorial.open_dataset("ersstv5").sst.isel(time=[0]).persist() ds = ds.rename(lon="longitude", lat="latitude") new_grid = xarray_regrid.Grid( north=90, @@ -235,16 +214,14 @@ def test_conservative_nan_thresholds_against_xesmf(): resolution_lon=2, ) target_dataset = xarray_regrid.create_regridding_dataset(new_grid) - regridder = xe.Regridder(ds, target_dataset, "conservative") + regridder = xesmf.Regridder(ds, target_dataset, "conservative") for nan_threshold in [0.0, 0.25, 0.5, 0.75, 1.0]: - data_regrid = ds.copy().regrid.conservative( + data_regrid = ds.regrid.conservative( target_dataset, skipna=True, nan_threshold=nan_threshold ) - data_esmf = regridder( - ds.copy(), keep_attrs=True, na_thres=nan_threshold, skipna=True - ) - assert (data_regrid.isnull() == data_esmf.isnull()).mean().values > 0.995 + data_esmf = regridder(ds, keep_attrs=True, na_thres=nan_threshold, skipna=True) + xr.testing.assert_equal(data_regrid.isnull(), data_esmf.isnull()) class TestCoordOrder: