diff --git a/src/xarray_regrid/methods/conservative.py b/src/xarray_regrid/methods/conservative.py index 0b31f54..5aae059 100644 --- a/src/xarray_regrid/methods/conservative.py +++ b/src/xarray_regrid/methods/conservative.py @@ -91,7 +91,7 @@ def conservative_regrid( def conservative_regrid_dataset( data: xr.Dataset, coords: dict[Hashable, xr.DataArray], - latitude_coord: str, + latitude_coord: Hashable, skipna: bool, nan_threshold: float, ) -> xr.Dataset: @@ -116,7 +116,7 @@ def conservative_regrid_dataset( weights = utils.create_dot_dataarray( nd_weights, str(coord), target_coords, source_coords ) - if str(coord) == latitude_coord: + if coord == latitude_coord: weights = apply_spherical_correction(weights, latitude_coord) for array in data_vars.keys(): @@ -125,7 +125,7 @@ def conservative_regrid_dataset( data_vars[array], valid_fracs[array] = apply_weights( da=data_vars[array], weights=weights, - coord_name=str(coord), + coord=coord, valid_frac=valid_fracs[array], skipna=skipna, non_grid_dims=non_grid_dims, @@ -157,7 +157,7 @@ def conservative_regrid_dataset( def apply_weights( da: xr.DataArray, weights: xr.DataArray, - coord_name: str, + coord: Hashable, valid_frac: xr.DataArray, skipna: bool, non_grid_dims: list[Hashable], @@ -168,7 +168,7 @@ def apply_weights( def apply_weights( da: xr.DataArray, weights: xr.DataArray, - coord_name: str, + coord: Hashable, valid_frac: None, skipna: bool, non_grid_dims: list[Hashable], @@ -178,13 +178,13 @@ def apply_weights( def apply_weights( da: xr.DataArray, weights: xr.DataArray, - coord_name: str, + coord: Hashable, valid_frac: xr.DataArray | None, skipna: bool, non_grid_dims: list[Hashable], ) -> tuple[xr.DataArray, xr.DataArray | None]: """Apply the weights to convert data to the new coordinates.""" - coord_map = {f"target_{coord_name}": coord_name} + coord_map = {f"target_{coord}": coord} weights_norm = weights.copy() if skipna: @@ -194,16 +194,16 @@ def apply_weights( # Renormalize the weights along this dim by the accumulated valid_frac # along previous dimensions if valid_frac is not None: - weights_norm = weights * valid_frac / valid_frac.mean(coord_name) + weights_norm = weights * valid_frac / valid_frac.mean(dim=[coord]) da_reduced: xr.DataArray = xr.dot( - da.fillna(0), weights_norm, dim=coord_name, optimize=True + da.fillna(0), weights_norm, dim=[coord], optimize=True ) da_reduced = da_reduced.rename(coord_map).transpose(*da.dims) if skipna: weights_valid_sum: xr.DataArray = xr.dot( - weights_norm, notnull, dim=coord_name, optimize=True + weights_norm, notnull, dim=[coord], optimize=True ) weights_valid_sum = weights_valid_sum.rename(coord_map) da_reduced /= weights_valid_sum.clip(1e-6, None) @@ -214,7 +214,7 @@ def apply_weights( else: # Update the valid points on this dimension - valid_frac = xr.dot(valid_frac, weights, dim=coord_name, optimize=True) + valid_frac = xr.dot(valid_frac, weights, dim=[coord], optimize=True) valid_frac = valid_frac.rename(coord_map) # type: ignore valid_frac = valid_frac.clip(0, 1) @@ -247,7 +247,7 @@ def get_weights(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndar def apply_spherical_correction( - dot_array: xr.DataArray, latitude_coord: str + dot_array: xr.DataArray, latitude_coord: Hashable ) -> xr.DataArray: """Apply a sperical earth correction on the prepared dot product weights.""" da = dot_array.copy()