
Commit

Fix typing in updated conservative routines
BSchilperoort committed Sep 2, 2024
1 parent d9afb96 commit 68b166f
Showing 1 changed file with 45 additions and 19 deletions.
src/xarray_regrid/methods/conservative.py (64 changes: 45 additions & 19 deletions)
@@ -34,7 +34,7 @@ def conservative_regrid(
 def conservative_regrid(
     data: xr.DataArray | xr.Dataset,
     target_ds: xr.Dataset,
-    latitude_coord: str | None,
+    latitude_coord: str | Hashable | None,
     skipna: bool = True,
     nan_threshold: float = 1.0,
 ) -> xr.DataArray | xr.Dataset:
@@ -66,7 +66,7 @@ def conservative_regrid(
     # Attempt to infer the latitude coordinate
     if latitude_coord is None:
         for coord in data.coords:
-            if coord.lower().startswith("lat"):
+            if str(coord).lower().startswith("lat"):
                 latitude_coord = coord
                 break
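A note on the str() call above: xarray coordinate names are typed as Hashable rather than str, so calling .lower() directly on them does not type-check even though it usually works at runtime. A minimal, self-contained sketch of the same lookup, with illustrative names only (not library code):

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((2, 2)),
    coords={"latitude": [-45.0, 45.0], "longitude": [0.0, 90.0]},
    dims=("latitude", "longitude"),
)
# Coordinate keys are Hashable; wrapping them in str() makes the string-method
# call safe for the type checker without changing runtime behaviour.
lat_name = next(c for c in da.coords if str(c).lower().startswith("lat"))
print(lat_name)  # latitude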

@@ -100,7 +100,7 @@ def conservative_regrid_dataset(
     """Dataset implementation of the conservative regridding method."""
     data_vars = dict(data.data_vars)
     data_coords = dict(data.coords)
-    valid_fracs = {v: None for v in data_vars}
+    valid_fracs = {v: xr.DataArray() for v in data_vars}
     data_attrs = {v: data_vars[v].attrs for v in data_vars}
     coord_attrs = {c: data_coords[c].attrs for c in data_coords}
     ds_attrs = data.attrs
@@ -112,11 +112,11 @@
 
         target_coords = coords[coord].to_numpy()
         source_coords = data[coord].to_numpy()
-        weights = get_weights(source_coords, target_coords)
+        nd_weights = get_weights(source_coords, target_coords)
 
         # Modify weights to correct for latitude distortion
         weights = utils.create_dot_dataarray(
-            weights, str(coord), target_coords, source_coords
+            nd_weights, str(coord), target_coords, source_coords
         )
         if str(coord) == latitude_coord:
             weights = apply_spherical_correction(weights, latitude_coord)
@@ -125,12 +125,12 @@
             non_grid_dims = [d for d in data_vars[array].dims if d not in coords]
             if coord in data_vars[array].dims:
                 data_vars[array], valid_fracs[array] = apply_weights(
-                    data_vars[array],
-                    weights,
-                    coord,
-                    valid_fracs[array],
-                    skipna,
-                    non_grid_dims,
+                    da=data_vars[array],
+                    weights=weights,
+                    coord_name=str(coord),
+                    valid_frac=valid_fracs[array],
+                    skipna=skipna,
+                    non_grid_dims=non_grid_dims,
                 )
                 # Mask out any regridded points outside the original domain
                 data_vars[array] = data_vars[array].where(covered_grid)
@@ -155,14 +155,36 @@ def conservative_regrid_dataset(
     return ds_regridded
 
 
+@overload
+def apply_weights(
+    da: xr.DataArray,
+    weights: xr.DataArray,
+    coord_name: str,
+    valid_frac: xr.DataArray,
+    skipna: bool,
+    non_grid_dims: list[Hashable],
+) -> tuple[xr.DataArray, xr.DataArray]: ...
+
+
+@overload
 def apply_weights(
     da: xr.DataArray,
-    weights: np.ndarray,
-    coord_name: Hashable,
+    weights: xr.DataArray,
+    coord_name: str,
+    valid_frac: None,
+    skipna: bool,
+    non_grid_dims: list[Hashable],
+) -> tuple[xr.DataArray, None]: ...
+
+
+def apply_weights(
+    da: xr.DataArray,
+    weights: xr.DataArray,
+    coord_name: str,
     valid_frac: xr.DataArray | None,
     skipna: bool,
     non_grid_dims: list[Hashable],
-) -> tuple[xr.DataArray, xr.DataArray]:
+) -> tuple[xr.DataArray, xr.DataArray | None]:
     """Apply the weights to convert data to the new coordinates."""
     coord_map = {f"target_{coord_name}": coord_name}
     weights_norm = weights.copy()
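The pair of @overload stubs added above ties the return type of apply_weights to the type of valid_frac for the type checker: a DataArray in gives a DataArray valid fraction back, None in gives None back. A minimal sketch of the same pattern, using hypothetical names (scale and factor are not part of the library):

from __future__ import annotations

from typing import overload


@overload
def scale(value: float, factor: float) -> float: ...
@overload
def scale(value: float, factor: None) -> None: ...


def scale(value: float, factor: float | None) -> float | None:
    # Only this implementation runs; the overload stubs exist purely for the type checker.
    return None if factor is None else value * factor


x = scale(2.0, 3.0)   # a type checker infers: float
y = scale(2.0, None)  # a type checker infers: None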
@@ -176,13 +198,16 @@ def apply_weights(
     if valid_frac is not None:
         weights_norm = weights * valid_frac / valid_frac.mean(coord_name)
 
-    da_reduced = xr.dot(da.fillna(0), weights_norm, dim=coord_name, optimize=True)
+    da_reduced: xr.DataArray = xr.dot(
+        da.fillna(0), weights_norm, dim=coord_name, optimize=True
+    )
     da_reduced = da_reduced.rename(coord_map).transpose(*da.dims)
 
     if skipna:
-        weights_valid_sum = xr.dot(
+        weights_valid_sum: xr.DataArray = xr.dot(
             weights_norm, notnull, dim=coord_name, optimize=True
-        ).rename(coord_map)
+        )
+        weights_valid_sum = weights_valid_sum.rename(coord_map)
         da_reduced /= weights_valid_sum.clip(1e-6, None)
 
     if valid_frac is None:
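For reference, the xr.dot calls above contract the source dimension against a weights array whose other dimension is the target grid, and the result is then renamed back to the original coordinate name. A small self-contained sketch of that pattern; the "lon"/"target_lon" names are only illustrative, and depending on the installed xarray version the keyword may be dims rather than dim:

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4.0), dims="lon", coords={"lon": [0.0, 1.0, 2.0, 3.0]})
# Two target cells, each averaging two source cells with equal weight.
weights = xr.DataArray(
    [[0.5, 0.0], [0.5, 0.0], [0.0, 0.5], [0.0, 0.5]],
    dims=("lon", "target_lon"),
)
out = xr.dot(da, weights, dim="lon").rename({"target_lon": "lon"})
print(out.values)  # [0.5 2.5]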
@@ -193,7 +218,8 @@ def apply_weights(
         # Update the valid points on this dimension
         valid_frac = xr.dot(
             valid_frac, weights, dim=coord_name, optimize=True
-        ).rename(coord_map)
+        )
+        valid_frac = valid_frac.rename(coord_map)  # type: ignore
         valid_frac = valid_frac.clip(0, 1)
 
     return da_reduced, valid_frac
@@ -203,7 +229,7 @@ def get_valid_threshold(nan_threshold: float) -> float:
     """Invert the nan_threshold and coerce it to just above zero and below
     one to handle numerical precision limitations in the weight sum."""
     # This matches xesmf where na_thresh=0 keeps points with any valid data
-    valid_threshold = 1 - np.clip(nan_threshold, 1e-6, 1.0 - 1e-6)
+    valid_threshold: float = 1 - np.clip(nan_threshold, 1e-6, 1.0 - 1e-6)
     return valid_threshold
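As a quick numeric check of the inversion and clipping in get_valid_threshold, nothing beyond NumPy is assumed here:

import numpy as np

for nan_threshold in (0.0, 0.5, 1.0):
    valid_threshold = 1 - np.clip(nan_threshold, 1e-6, 1.0 - 1e-6)
    print(nan_threshold, float(valid_threshold))
# Prints approximately: 0.0 0.999999, 0.5 0.5, 1.0 1e-06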


