Merge pull request #45 from xarray-contrib/boundary-padding
Spherical padding and faster tests
slevang authored Sep 20, 2024
2 parents 8bdc636 + 1f2e999 commit 9d71962
Showing 5 changed files with 461 additions and 125 deletions.
11 changes: 4 additions & 7 deletions src/xarray_regrid/methods/conservative.py
@@ -72,8 +72,10 @@ def conservative_regrid(

# Make sure the regridding coordinates are sorted
coord_names = [coord for coord in target_ds.coords if coord in data.coords]
target_ds_sorted = target_ds.sortby(coord_names)
data = data.sortby(list(coord_names))
target_ds_sorted = xr.Dataset(coords=target_ds.coords)
for coord_name in coord_names:
target_ds_sorted = utils.ensure_monotonic(target_ds_sorted, coord_name)
data = utils.ensure_monotonic(data, coord_name)
coords = {name: target_ds_sorted[name] for name in coord_names}

regridded_data = utils.call_on_dataset(
@@ -122,15 +124,13 @@ def conservative_regrid_dataset(
weights = apply_spherical_correction(weights, latitude_coord)

for array in data_vars.keys():
non_grid_dims = [d for d in data_vars[array].dims if d not in coords]
if coord in data_vars[array].dims:
data_vars[array], valid_fracs[array] = apply_weights(
da=data_vars[array],
weights=weights,
coord=coord,
valid_frac=valid_fracs[array],
skipna=skipna,
non_grid_dims=non_grid_dims,
)
# Mask out any regridded points outside the original domain
data_vars[array] = data_vars[array].where(covered_grid)
@@ -161,16 +161,13 @@ def apply_weights(
coord: Hashable,
valid_frac: xr.DataArray,
skipna: bool,
non_grid_dims: list[Hashable],
) -> tuple[xr.DataArray, xr.DataArray]:
"""Apply the weights to convert data to the new coordinates."""
coord_map = {f"target_{coord}": coord}
weights_norm = weights.copy()

if skipna:
notnull = da.notnull()
if non_grid_dims:
notnull = notnull.any(non_grid_dims)
# Renormalize the weights along this dim by the accumulated valid_frac
# along previous dimensions
if valid_frac.name != EMPTY_DA_NAME:
21 changes: 15 additions & 6 deletions src/xarray_regrid/regrid.py
@@ -1,6 +1,7 @@
import xarray as xr

from xarray_regrid.methods import conservative, interp, most_common
from xarray_regrid.utils import format_for_regrid


@xr.register_dataarray_accessor("regrid")
@@ -34,7 +35,8 @@ def linear(
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
return interp.interp_regrid(self._obj, ds_target_grid, "linear")
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "linear")

def nearest(
self,
@@ -51,14 +53,14 @@ def nearest(
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
return interp.interp_regrid(self._obj, ds_target_grid, "nearest")
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest")

def cubic(
self,
ds_target_grid: xr.Dataset,
time_dim: str = "time",
) -> xr.DataArray | xr.Dataset:
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
"""Regrid to the coords of the target dataset with cubic interpolation.
Args:
@@ -68,7 +70,9 @@
Returns:
Data regridded to the target dataset coordinates.
"""
return interp.interp_regrid(self._obj, ds_target_grid, "cubic")
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic")

def conservative(
self,
@@ -88,6 +92,9 @@
time_dim: The name of the time dimension/coordinate.
skipna: If True, enable handling for NaN values. This adds some overhead,
so can be disabled for optimal performance on data without any NaNs.
With `skipna=True`, chunking is recommended in the non-grid dimensions,
otherwise the intermediate arrays that track the fraction of valid data
can become very large and consume excessive memory.
Warning: with `skipna=False`, isolated NaNs will propagate throughout
the dataset due to the sequential regridding scheme over each dimension.
nan_threshold: Threshold value that will retain any output points
@@ -104,8 +111,9 @@
raise ValueError(msg)

ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return conservative.conservative_regrid(
self._obj, ds_target_grid, latitude_coord, skipna, nan_threshold
ds_formatted, ds_target_grid, latitude_coord, skipna, nan_threshold
)

def most_common(
@@ -134,8 +142,9 @@
Regridded data.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return most_common.most_common_wrapper(
self._obj, ds_target_grid, time_dim, max_mem
ds_formatted, ds_target_grid, time_dim, max_mem
)


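The skipna note in the conservative docstring above recommends chunking the non-grid dimensions. A minimal usage sketch of the accessor (the variable names, grid values, and latitude_coord argument here are hypothetical illustrations, not part of this diff):

import numpy as np
import xarray as xr
import xarray_regrid  # noqa: F401, importing registers the .regrid accessor

# Hypothetical 1-degree global source field with a time dimension
da = xr.DataArray(
    np.random.rand(12, 180, 360),
    dims=("time", "lat", "lon"),
    coords={
        "time": np.arange(12),
        "lat": np.arange(-89.5, 90, 1.0),
        "lon": np.arange(0.5, 360, 1.0),
    },
    name="tas",
).chunk({"time": 1})  # chunk the non-grid dim, per the skipna note

# Hypothetical 2.5-degree target grid
target = xr.Dataset(
    coords={"lat": np.arange(-88.75, 90, 2.5), "lon": np.arange(1.25, 360, 2.5)}
)

result = da.regrid.conservative(target, latitude_coord="lat", skipna=True)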
219 changes: 198 additions & 21 deletions src/xarray_regrid/utils.py
@@ -1,6 +1,6 @@
from collections.abc import Callable
from collections.abc import Callable, Hashable
from dataclasses import dataclass
from typing import Any, overload
from typing import Any, TypedDict, overload

import numpy as np
import pandas as pd
@@ -10,6 +10,11 @@
class InvalidBoundsError(Exception): ...


class CoordHandler(TypedDict):
names: list[str]
func: Callable


@dataclass
class Grid:
"""Object storing grid information."""
@@ -75,7 +80,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]:
grid.south, grid.north + grid.resolution_lat, grid.resolution_lat
)

if np.remainder((grid.north - grid.south), grid.resolution_lat) > 0:
if np.remainder((grid.east - grid.west), grid.resolution_lat) > 0:
lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon)
else:
lon_coords = np.arange(
@@ -193,24 +198,6 @@ def common_coords(
return sorted([str(coord) for coord in coords])


@overload
def call_on_dataset(
func: Callable[..., xr.Dataset],
obj: xr.DataArray,
*args: Any,
**kwargs: Any,
) -> xr.DataArray: ...


@overload
def call_on_dataset(
func: Callable[..., xr.Dataset],
obj: xr.Dataset,
*args: Any,
**kwargs: Any,
) -> xr.Dataset: ...


def call_on_dataset(
func: Callable[..., xr.Dataset],
obj: xr.DataArray | xr.Dataset,
@@ -235,3 +222,193 @@ def call_on_dataset(
return next(iter(result.data_vars.values())).rename(obj.name)

return result


def format_for_regrid(
obj: xr.DataArray | xr.Dataset, target: xr.Dataset
) -> xr.DataArray | xr.Dataset:
"""Apply any pre-formatting to the input dataset to prepare for regridding.
Currently handles padding of spherical geometry if lat/lon coordinates can
be inferred and the domain size requires boundary padding.
"""
orig_chunksizes = obj.chunksizes

# Special-cased coordinates with accepted names and formatting function
coord_handlers: dict[str, CoordHandler] = {
"lat": {"names": ["lat", "latitude"], "func": format_lat},
"lon": {"names": ["lon", "longitude"], "func": format_lon},
}
# Identify coordinates that need to be formatted
formatted_coords = {}
for coord_type, handler in coord_handlers.items():
for coord in obj.coords.keys():
if str(coord).lower() in handler["names"]:
formatted_coords[coord_type] = str(coord)

# Apply formatting
for coord_type, coord in formatted_coords.items():
# Make sure formatted coords are sorted
obj = ensure_monotonic(obj, coord)
target = ensure_monotonic(target, coord)
obj = coord_handlers[coord_type]["func"](obj, target, formatted_coords)
# Coerce back to a single chunk if that's what was passed
if len(orig_chunksizes.get(coord, [])) == 1:
obj = obj.chunk({coord: -1})

return obj
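A sketch of what this pre-formatting amounts to end to end (the mixed-case coordinate names are hypothetical, chosen to show the case-insensitive matching; not part of this diff):

import numpy as np
import xarray as xr
from xarray_regrid.utils import format_for_regrid

# Global 1-degree source: descending latitudes, 0..360 longitudes
src = xr.Dataset(
    {"t2m": (("Latitude", "longitude"), np.zeros((180, 360)))},
    coords={
        "Latitude": np.arange(89.5, -90, -1.0),  # descending, will be sorted
        "longitude": np.arange(0.5, 360, 1.0),
    },
)
# Target on -180..180 longitudes
target = xr.Dataset(
    coords={
        "Latitude": np.arange(-89.5, 90, 1.0),
        "longitude": np.arange(-179.5, 180, 1.0),
    }
)
formatted = format_for_regrid(src, target)
# "Latitude" comes back ascending with padded rows at -90 and 90, and
# "longitude" is shifted to -179.5..179.5 with wraparound padding columns.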


def format_lat(
obj: xr.DataArray | xr.Dataset,
target: xr.Dataset, # noqa ARG001
formatted_coords: dict[str, str],
) -> xr.DataArray | xr.Dataset:
"""If the latitude coordinate is inferred to be global, defined as having
a value within one grid spacing of the poles, and the grid does not natively
have values at -90 and 90, add a single value at each pole, computed as the
zonal mean of the first (south pole) or last (north pole) latitude band. This should be roughly equivalent
to the `Pole="all"` option in `ESMF`.
For example, with a grid spacing of 1 degree, and a source grid ranging from
-89.5 to 89.5, the poles would be padded with values at -90 and 90. A grid ranging
from -88 to 88 would not be padded because coverage does not extend all the way
to the poles. A grid ranging from -90 to 90 would also not be padded because the
poles will already be covered in the regridding weights.
"""
lat_coord = formatted_coords["lat"]
lon_coord = formatted_coords.get("lon")

# Concat a padded value representing the mean of the first/last lat bands
# This should match the Pole="all" option of ESMF
# TODO: with cos(90) = 0 weighting, these weights might be 0?

polar_lat = 90
dy = obj.coords[lat_coord].diff(lat_coord).max().values.item()

# Only pad if global but don't have edge values directly at poles
# NOTE: could use xr.pad here instead of xr.concat, but none of the
# modes are an exact fit for this scheme
lat_vals = obj.coords[lat_coord].values
# South pole
if dy - polar_lat >= obj.coords[lat_coord].values[0] > -polar_lat:
south_pole = obj.isel({lat_coord: 0})
if lon_coord is not None:
south_pole = south_pole.mean(lon_coord)
obj = xr.concat([south_pole, obj], dim=lat_coord) # type: ignore
lat_vals = np.concatenate([[-polar_lat], lat_vals])

# North pole
if polar_lat - dy <= obj.coords[lat_coord].values[-1] < polar_lat:
north_pole = obj.isel({lat_coord: -1})
if lon_coord is not None:
north_pole = north_pole.mean(lon_coord)
obj = xr.concat([obj, north_pole], dim=lat_coord) # type: ignore
lat_vals = np.concatenate([lat_vals, [polar_lat]])

obj = update_coord(obj, lat_coord, lat_vals)

return obj
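The padding rule above can be traced numerically; a standalone sketch (not part of this diff) reproducing the docstring's three cases:

import numpy as np

def needs_pole_padding(lat_vals: np.ndarray) -> tuple[bool, bool]:
    """Return (pad_south, pad_north) for a sorted latitude coordinate."""
    dy = np.diff(lat_vals).max()
    pad_south = dy - 90 >= lat_vals[0] > -90
    pad_north = 90 - dy <= lat_vals[-1] < 90
    return bool(pad_south), bool(pad_north)

print(needs_pole_padding(np.arange(-89.5, 90, 1.0)))    # (True, True): global, no pole values
print(needs_pole_padding(np.arange(-88.0, 89.0, 1.0)))  # (False, False): not global
print(needs_pole_padding(np.arange(-90.0, 90.5, 1.0)))  # (False, False): poles already covered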


def format_lon(
obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str]
) -> xr.DataArray | xr.Dataset:
"""Format the longitude coordinate by shifting the source grid to line up with
the target anywhere in the range of -360 to 360, and then adding a single wraparound
padding column if the domain is inferred to be global and the east or west edges
of the target lie outside the source grid centers.
For example, with a source grid ranging from 0.5 to 359.5 and a target grid ranging
from -180 to 180, the source grid would be shifted to -179.5 to 179.5 and then
padded on both the left and right with wraparound values at -180.5 and 180.5 to
provide full coverage for the target edge cells at -180 and 180.
"""
lon_coord = formatted_coords["lon"]

# Find a wrap point outside of the left and right bounds of the target
# This ensures we have coverage on the target and handles global > regional
source_vals = obj.coords[lon_coord].values
target_vals = target.coords[lon_coord].values
wrap_point = (target_vals[-1] + target_vals[0] + 360) / 2
source_vals = np.where(
source_vals < wrap_point - 360, source_vals + 360, source_vals
)
source_vals = np.where(source_vals > wrap_point, source_vals - 360, source_vals)
obj = update_coord(obj, lon_coord, source_vals)

obj = ensure_monotonic(obj, lon_coord)

# Only pad if domain is global in lon
source_lon = obj.coords[lon_coord]
target_lon = target.coords[lon_coord]
dx_s = source_lon.diff(lon_coord).max().values.item()
dx_t = target_lon.diff(lon_coord).max().values.item()
is_global_lon = source_lon.max().values - source_lon.min().values >= 360 - dx_s

if is_global_lon:
left_pad = (source_lon.values[0] - target_lon.values[0] + dx_t / 2) / dx_s
right_pad = (target_lon.values[-1] - source_lon.values[-1] + dx_t / 2) / dx_s
left_pad = int(np.ceil(np.max([left_pad, 0])))
right_pad = int(np.ceil(np.max([right_pad, 0])))
obj = obj.pad({lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True)
lon_vals = obj.coords[lon_coord].values
if left_pad:
lon_vals[:left_pad] = source_lon.values[-left_pad:] - 360
if right_pad:
lon_vals[-right_pad:] = source_lon.values[:right_pad] + 360
obj = update_coord(obj, lon_coord, lon_vals)

return obj
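Tracing the wrap-point arithmetic above with the docstring's numbers (a standalone sketch with hypothetical 1-degree grids, not part of this diff):

import numpy as np

source = np.arange(0.5, 360, 1.0)     # grid centers 0.5 .. 359.5
target = np.arange(-180.0, 181, 1.0)  # grid centers -180 .. 180

wrap_point = (target[-1] + target[0] + 360) / 2  # = 180.0
shifted = np.where(source < wrap_point - 360, source + 360, source)
shifted = np.where(shifted > wrap_point, shifted - 360, shifted)
shifted = np.sort(shifted)
print(shifted[0], shifted[-1])  # -179.5 179.5, now aligned with the target
# The global check then pads one wraparound column per side, at -180.5 and
# 180.5, covering the target edge cells at -180 and 180.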


def coord_is_covered(
obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable
) -> bool:
"""Check if the source coord fully covers the target coord."""
pad = target[coord].diff(coord).max().values
left_covered = obj[coord].min() <= target[coord].min() - pad
right_covered = obj[coord].max() >= target[coord].max() + pad
return bool(left_covered.item() and right_covered.item())


@overload
def ensure_monotonic(obj: xr.DataArray, coord: Hashable) -> xr.DataArray: ...


@overload
def ensure_monotonic(obj: xr.Dataset, coord: Hashable) -> xr.Dataset: ...


def ensure_monotonic(
obj: xr.DataArray | xr.Dataset, coord: Hashable
) -> xr.DataArray | xr.Dataset:
"""Ensure that an object has monotonically increasing indexes for a
given coordinate. Only sort and drop duplicates if needed because this
requires reindexing, which can be expensive."""
if not obj.indexes[coord].is_monotonic_increasing:
obj = obj.sortby(coord)
if not obj.indexes[coord].is_unique:
obj = obj.drop_duplicates(coord)
return obj
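A quick usage sketch (hypothetical values, not part of this diff): a descending index with a duplicate comes back sorted and unique, while a clean index passes through without the cost of a reindex:

import xarray as xr
from xarray_regrid.utils import ensure_monotonic

da = xr.DataArray([3, 2, 1, 0], dims="lat", coords={"lat": [10.0, 5.0, 5.0, 0.0]})
print(ensure_monotonic(da, "lat").lat.values)  # [ 0.  5. 10.]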


@overload
def update_coord(
obj: xr.DataArray, coord: Hashable, coord_vals: np.ndarray
) -> xr.DataArray: ...


@overload
def update_coord(
obj: xr.Dataset, coord: Hashable, coord_vals: np.ndarray
) -> xr.Dataset: ...


def update_coord(
obj: xr.DataArray | xr.Dataset, coord: Hashable, coord_vals: np.ndarray
) -> xr.DataArray | xr.Dataset:
"""Update the values of a coordinate, ensuring indexes stay in sync."""
attrs = obj.coords[coord].attrs
obj = obj.assign_coords({coord: coord_vals})
obj.coords[coord].attrs = attrs
return obj
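The attrs round-trip here matters because a bare assign_coords drops the coordinate's attributes; a quick sketch (hypothetical data, not part of this diff):

import numpy as np
import xarray as xr
from xarray_regrid.utils import update_coord

da = xr.DataArray(np.zeros(3), dims="lon", coords={"lon": [0.0, 1.0, 2.0]})
da["lon"].attrs["units"] = "degrees_east"

print(da.assign_coords(lon=[10.0, 11.0, 12.0])["lon"].attrs)  # {}, attrs lost
print(update_coord(da, "lon", np.array([10.0, 11.0, 12.0]))["lon"].attrs)
# {'units': 'degrees_east'}, attrs preserved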