initial pass at spherical padding, faster tests, full nan tracking
slevang committed Sep 6, 2024
1 parent 8bdc636 commit ce16f74
Showing 5 changed files with 388 additions and 105 deletions.
7 changes: 1 addition & 6 deletions src/xarray_regrid/methods/conservative.py
@@ -66,7 +66,7 @@ def conservative_regrid(
# Attempt to infer the latitude coordinate
if latitude_coord is None:
for coord in data.coords:
if str(coord).lower() in ["lat", "latitude"]:
if str(coord).lower().startswith("lat"):
latitude_coord = coord
break
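For reference, a minimal sketch (outside the diff, names illustrative) of what the broadened prefix match accepts compared to the old exact list:

# Previously only "lat"/"latitude" matched; the prefix test also accepts
# common variants such as these.
for name in ["lat", "latitude", "Latitude", "lat_0"]:
    assert name.lower().startswith("lat")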

@@ -122,15 +122,13 @@ def conservative_regrid_dataset(
weights = apply_spherical_correction(weights, latitude_coord)

for array in data_vars.keys():
non_grid_dims = [d for d in data_vars[array].dims if d not in coords]
if coord in data_vars[array].dims:
data_vars[array], valid_fracs[array] = apply_weights(
da=data_vars[array],
weights=weights,
coord=coord,
valid_frac=valid_fracs[array],
skipna=skipna,
non_grid_dims=non_grid_dims,
)
# Mask out any regridded points outside the original domain
data_vars[array] = data_vars[array].where(covered_grid)
@@ -161,16 +159,13 @@ def apply_weights(
coord: Hashable,
valid_frac: xr.DataArray,
skipna: bool,
non_grid_dims: list[Hashable],
) -> tuple[xr.DataArray, xr.DataArray]:
"""Apply the weights to convert data to the new coordinates."""
coord_map = {f"target_{coord}": coord}
weights_norm = weights.copy()

if skipna:
notnull = da.notnull()
if non_grid_dims:
notnull = notnull.any(non_grid_dims)
# Renormalize the weights along this dim by the accumulated valid_frac
# along previous dimensions
if valid_frac.name != EMPTY_DA_NAME:
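The renormalization above is easier to see with numbers. A minimal sketch of the idea, assuming simple 1-D weights and synthetic values (not the library's internals):

import numpy as np

weights = np.array([0.25, 0.25, 0.25, 0.25])
data = np.array([1.0, 2.0, np.nan, 4.0])

valid = ~np.isnan(data)
valid_frac = (weights * valid).sum()             # 0.75 of the weight is on valid cells
result = np.nansum(weights * data) / valid_frac  # renormalized mean: 2.333...

Without the renormalization, the NaN cell would silently drag the conservative mean toward zero.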
21 changes: 15 additions & 6 deletions src/xarray_regrid/regrid.py
@@ -1,6 +1,7 @@
import xarray as xr

from xarray_regrid.methods import conservative, interp, most_common
from xarray_regrid.utils import format_for_regrid


@xr.register_dataarray_accessor("regrid")
@@ -34,7 +35,8 @@ def linear(
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
return interp.interp_regrid(self._obj, ds_target_grid, "linear")
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "linear")

def nearest(
self,
@@ -51,14 +53,14 @@ def nearest(
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
return interp.interp_regrid(self._obj, ds_target_grid, "nearest")
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest")

def cubic(
self,
ds_target_grid: xr.Dataset,
time_dim: str = "time",
) -> xr.DataArray | xr.Dataset:
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
"""Regrid to the coords of the target dataset with cubic interpolation.
Args:
@@ -68,7 +70,9 @@
Returns:
Data regridded to the target dataset coordinates.
"""
return interp.interp_regrid(self._obj, ds_target_grid, "cubic")
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic")
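All three interpolation accessors now run the shared pre-formatting step first. A hedged, self-contained usage sketch (data, coordinate values, and variable names are illustrative):

import numpy as np
import xarray as xr
import xarray_regrid  # noqa: F401  (importing registers the .regrid accessor)

ds = xr.Dataset(
    {"tas": (("lat", "lon"), np.random.rand(10, 20))},
    coords={"lat": np.linspace(-85.5, 85.5, 10), "lon": np.arange(0.0, 360.0, 18.0)},
)
ds_target = xr.Dataset(
    coords={"lat": np.linspace(-89.0, 89.0, 90), "lon": np.arange(0.0, 360.0, 4.0)}
)

out_linear = ds.regrid.linear(ds_target)
out_nearest = ds.regrid.nearest(ds_target)
out_cubic = ds.regrid.cubic(ds_target)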

def conservative(
self,
@@ -88,6 +92,9 @@
time_dim: The name of the time dimension/coordinate.
skipna: If True, enable handling for NaN values. This adds some overhead,
so can be disabled for optimal performance on data without any NaNs.
With `skipna=True`, chunking is recommended in the non-grid dimensions,
otherwise the intermediate arrays that track the fraction of valid data
can become very large and consume excessive memory.
Warning: with `skipna=False`, isolated NaNs will propagate throughout
the dataset due to the sequential regridding scheme over each dimension.
nan_threshold: Threshold value that will retain any output points
@@ -104,8 +111,9 @@
raise ValueError(msg)

ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return conservative.conservative_regrid(
self._obj, ds_target_grid, latitude_coord, skipna, nan_threshold
ds_formatted, ds_target_grid, latitude_coord, skipna, nan_threshold
)
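A hedged sketch of the chunking recommendation above, reusing the illustrative ds/ds_target from the earlier snippet and adding a non-grid "time" dimension (sizes and threshold are arbitrary; requires dask):

# Chunking along "time" keeps the intermediate valid-fraction arrays small.
ds_time = ds.expand_dims(time=np.arange(12)).chunk({"time": 4})
out = ds_time.regrid.conservative(ds_target, skipna=True, nan_threshold=0.5)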

def most_common(
@@ -134,8 +142,9 @@ def most_common(
Regridded data.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return most_common.most_common_wrapper(
self._obj, ds_target_grid, time_dim, max_mem
ds_formatted, ds_target_grid, time_dim, max_mem
)
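Hypothetical usage, continuing the same sketch (the max_mem value is arbitrary and bounds the memory used by the intermediate blocks; most_common is aimed at categorical data, so the float variable here is only for illustration):

out = ds.regrid.most_common(ds_target, max_mem=1_000_000_000)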


128 changes: 126 additions & 2 deletions src/xarray_regrid/utils.py
@@ -1,4 +1,4 @@
from collections.abc import Callable
from collections.abc import Callable, Hashable
from dataclasses import dataclass
from typing import Any, overload

@@ -75,7 +75,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]:
grid.south, grid.north + grid.resolution_lat, grid.resolution_lat
)

if np.remainder((grid.north - grid.south), grid.resolution_lat) > 0:
if np.remainder((grid.east - grid.west), grid.resolution_lon) > 0:
lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon)
else:
lon_coords = np.arange(
@@ -235,3 +235,127 @@ def call_on_dataset(
return next(iter(result.data_vars.values())).rename(obj.name)

return result


def format_for_regrid(
obj: xr.DataArray | xr.Dataset, target: xr.Dataset
) -> xr.DataArray | xr.Dataset:
"""Apply any pre-formatting to the input dataset to prepare for regridding.
Currently handles padding of spherical geometry when latitude and longitude
coordinates can be inferred from names starting with 'lat' and 'lon'.
"""
lat_coord = None
lon_coord = None

for coord in obj.coords.keys():
if str(coord).lower().startswith("lat"):
lat_coord = coord
elif str(coord).lower().startswith("lon"):
lon_coord = coord

if lon_coord is not None or lat_coord is not None:
obj = format_spherical(obj, target, lat_coord, lon_coord)

return obj
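A small illustration of the inference (coordinate and variable names are hypothetical):

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"t2m": (("latitude", "longitude"), np.zeros((3, 4)))},
    coords={"latitude": [-60.0, 0.0, 60.0], "longitude": [0.0, 90.0, 180.0, 270.0]},
)
# The prefix match infers lat_coord="latitude" and lon_coord="longitude",
# so format_spherical is applied before regridding.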


def format_spherical(
obj: xr.DataArray | xr.Dataset,
target: xr.Dataset,
lat_coord: Hashable,
lon_coord: Hashable,
) -> xr.DataArray | xr.Dataset:
"""Infer whether a lat/lon source grid represents a global domain and
automatically apply spherical padding to improve edge effects.
For longitude, shift the coordinate to line up with the target values, then
add a single wraparound padding column if the domain is global and the east
or west edges of the target lie outside the source grid centers.
For latitude, add a single value at each pole computed as the mean of the last
row for global source grids where the first or last point lie equatorward of 90.
"""

orig_chunksizes = obj.chunksizes

# If the source coord already fully covers the target, leave it unchanged
if lat_coord and not coord_is_covered(obj, target, lat_coord):
obj = obj.sortby(lat_coord)
target = target.sortby(lat_coord)

# Only pad if the grid is global but lacks edge values directly at the poles
polar_lat = 90
dy = obj[lat_coord].diff(lat_coord).max().values

# South pole
if dy - polar_lat >= obj[lat_coord][0] > -polar_lat:
south_pole = obj.isel({lat_coord: 0})
# This should match the Pole="all" option of ESMF
if lon_coord is not None:
south_pole = south_pole.mean(lon_coord)
obj = xr.concat([south_pole, obj], dim=lat_coord)
obj[lat_coord].values[0] = -polar_lat
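# Worked example (illustrative): a global 1-degree grid has centers at
# -89.5..89.5 and dy = 1.0, so dy - 90 = -89.0 >= -89.5 > -90 holds and a
# synthetic row at exactly -90 (the zonal mean of the first row) is added;
# the north pole case below is symmetric.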

# North pole
if polar_lat - dy <= obj[lat_coord][-1] < polar_lat:
north_pole = obj.isel({lat_coord: -1})
if lon_coord is not None:
north_pole = north_pole.mean(lon_coord)
obj = xr.concat([obj, north_pole], dim=lat_coord)
obj[lat_coord].values[-1] = polar_lat

# Coerce back to a single chunk if that's what was passed
if len(orig_chunksizes.get(lat_coord, [])) == 1:
obj = obj.chunk({lat_coord: -1})

if lon_coord and not coord_is_covered(obj, target, lon_coord):
obj = obj.sortby(lon_coord)
target = target.sortby(lon_coord)

target_lon = target[lon_coord].values
# Find a wrap point outside of the left and right bounds of the target.
# This ensures we have coverage on the target and handles regridding from
# global to regional domains.
wrap_point = (target_lon[-1] + target_lon[0] + 360) / 2
lon = obj[lon_coord].values
lon = np.where(lon < wrap_point - 360, lon + 360, lon)
lon = np.where(lon > wrap_point, lon - 360, lon)
obj[lon_coord].values[:] = lon
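# Worked example (illustrative): for 1-degree target centers at
# -179.5..179.5, wrap_point = (179.5 + -179.5 + 360) / 2 = 180, so source
# longitudes in 181..359 are shifted down to -179..-1 while 0..180 are
# left unchanged, lining the source up with the target's convention.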

# Shift operations can produce duplicate coordinate values; the simplest
# solution is to drop them here and add them back when padding
obj = obj.sortby(lon_coord).drop_duplicates(lon_coord)

# Only pad if domain is global in lon
dx_s = obj[lon_coord].diff(lon_coord).max().values
dx_t = target[lon_coord].diff(lon_coord).max().values
is_global_lon = obj[lon_coord].max() - obj[lon_coord].min() >= 360 - dx_s

if is_global_lon:
left_pad = (obj[lon_coord][0] - target[lon_coord][0] + dx_t / 2) / dx_s
right_pad = (target[lon_coord][-1] - obj[lon_coord][-1] + dx_t / 2) / dx_s
left_pad = int(np.ceil(np.max([left_pad, 0])))
right_pad = int(np.ceil(np.max([right_pad, 0])))
lon = obj[lon_coord].values
obj = obj.pad(
{lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True
)
if left_pad:
obj[lon_coord].values[:left_pad] = lon[-left_pad:] - 360
if right_pad:
obj[lon_coord].values[-right_pad:] = lon[:right_pad] + 360
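# Worked example (illustrative): continuing the shift example above, the
# source now has centers -179..180, so against target centers
# -179.5..179.5 with dx_s = dx_t = 1 this gives
# left_pad = ceil((-179 - -179.5 + 0.5) / 1) = 1 and right_pad = 0:
# a single wraparound column is prepended and relabeled 180 -> -180.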

# Coerce back to a single chunk if that's what was passed
if len(orig_chunksizes.get(lon_coord, [])) == 1:
obj = obj.chunk({lon_coord: -1})

return obj


def coord_is_covered(
obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable
) -> bool:
"""Check if the source coord fully covers the target coord."""
pad = target[coord].diff(coord).max().values
left_covered = obj[coord].min() <= target[coord].min() - pad
right_covered = obj[coord].max() >= target[coord].max() + pad
return bool(left_covered.item() and right_covered.item())
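A hedged check of the coverage test (assuming coord_is_covered can be imported from xarray_regrid.utils at this commit; grids are illustrative):

import numpy as np
import xarray as xr
from xarray_regrid.utils import coord_is_covered

src = xr.Dataset(coords={"lon": np.arange(-180.0, 181.0)})
tgt = xr.Dataset(coords={"lon": np.arange(-175.0, 176.0, 5.0)})
# pad = 5.0; src spans [-180, 180], reaching tgt.min() - pad = -180 and
# tgt.max() + pad = 180, so the coverage check returns True.
assert coord_is_covered(src, tgt, "lon")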