Skip to content

Commit

Permalink
Ensure hashable is a valid input for coordinate identifier
Browse files Browse the repository at this point in the history
  • Loading branch information
BSchilperoort committed Sep 3, 2024
1 parent 1d1c62f commit 2e80872
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/xarray_regrid/methods/conservative.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def conservative_regrid(
def conservative_regrid_dataset(
data: xr.Dataset,
coords: dict[Hashable, xr.DataArray],
latitude_coord: str,
latitude_coord: Hashable,
skipna: bool,
nan_threshold: float,
) -> xr.Dataset:
Expand All @@ -116,7 +116,7 @@ def conservative_regrid_dataset(
weights = utils.create_dot_dataarray(
nd_weights, str(coord), target_coords, source_coords
)
if str(coord) == latitude_coord:
if coord == latitude_coord:
weights = apply_spherical_correction(weights, latitude_coord)

for array in data_vars.keys():
Expand All @@ -125,7 +125,7 @@ def conservative_regrid_dataset(
data_vars[array], valid_fracs[array] = apply_weights(
da=data_vars[array],
weights=weights,
coord_name=str(coord),
coord=coord,
valid_frac=valid_fracs[array],
skipna=skipna,
non_grid_dims=non_grid_dims,
Expand Down Expand Up @@ -157,7 +157,7 @@ def conservative_regrid_dataset(
def apply_weights(
da: xr.DataArray,
weights: xr.DataArray,
coord_name: str,
coord: Hashable,
valid_frac: xr.DataArray,
skipna: bool,
non_grid_dims: list[Hashable],
Expand All @@ -168,7 +168,7 @@ def apply_weights(
def apply_weights(
da: xr.DataArray,
weights: xr.DataArray,
coord_name: str,
coord: Hashable,
valid_frac: None,
skipna: bool,
non_grid_dims: list[Hashable],
Expand All @@ -178,13 +178,13 @@ def apply_weights(
def apply_weights(
da: xr.DataArray,
weights: xr.DataArray,
coord_name: str,
coord: Hashable,
valid_frac: xr.DataArray | None,
skipna: bool,
non_grid_dims: list[Hashable],
) -> tuple[xr.DataArray, xr.DataArray | None]:
"""Apply the weights to convert data to the new coordinates."""
coord_map = {f"target_{coord_name}": coord_name}
coord_map = {f"target_{coord}": coord}
weights_norm = weights.copy()

if skipna:
Expand All @@ -194,16 +194,16 @@ def apply_weights(
# Renormalize the weights along this dim by the accumulated valid_frac
# along previous dimensions
if valid_frac is not None:
weights_norm = weights * valid_frac / valid_frac.mean(coord_name)
weights_norm = weights * valid_frac / valid_frac.mean(dim=[coord])

da_reduced: xr.DataArray = xr.dot(
da.fillna(0), weights_norm, dim=coord_name, optimize=True
da.fillna(0), weights_norm, dim=[coord], optimize=True
)
da_reduced = da_reduced.rename(coord_map).transpose(*da.dims)

if skipna:
weights_valid_sum: xr.DataArray = xr.dot(
weights_norm, notnull, dim=coord_name, optimize=True
weights_norm, notnull, dim=[coord], optimize=True
)
weights_valid_sum = weights_valid_sum.rename(coord_map)
da_reduced /= weights_valid_sum.clip(1e-6, None)
Expand All @@ -214,7 +214,7 @@ def apply_weights(

else:
# Update the valid points on this dimension
valid_frac = xr.dot(valid_frac, weights, dim=coord_name, optimize=True)
valid_frac = xr.dot(valid_frac, weights, dim=[coord], optimize=True)
valid_frac = valid_frac.rename(coord_map) # type: ignore
valid_frac = valid_frac.clip(0, 1)

Expand Down Expand Up @@ -247,7 +247,7 @@ def get_weights(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndar


def apply_spherical_correction(
dot_array: xr.DataArray, latitude_coord: str
dot_array: xr.DataArray, latitude_coord: Hashable
) -> xr.DataArray:
"""Apply a sperical earth correction on the prepared dot product weights."""
da = dot_array.copy()
Expand Down

0 comments on commit 2e80872

Please sign in to comment.