Skip to content

Commit

Permalink
Refactor out duplicated gufunc args
Browse files Browse the repository at this point in the history
  • Loading branch information
aazuspan committed Jun 26, 2024
1 parent 61a7f59 commit e4f5c54
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions src/sknnr_spatial/image/_dask_backed.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,11 @@ def predict(
estimator_type = getattr(estimator, "_estimator_type", "")
output_dtype = ESTIMATOR_OUTPUT_DTYPES.get(estimator_type, np.float64)

y_pred = da.apply_gufunc(
# If the wrapped estimator was fit with a dataframe, it will warn about
# missing feature names because this passes unnamed arrays. Suppress that
# and let the wrapper handle feature name checks.
suppress_feature_name_warnings(estimator._wrapped.predict),
signature,
self.preprocessor.flat,
axis=self.preprocessor.flat_band_dim,
output_dtypes=[output_dtype],
y_pred = self._apply_gufunc(
estimator._wrapped.predict,
signature=signature,
output_sizes=output_sizes,
allow_rechunk=True,
output_dtypes=[output_dtype],
)

# Reshape from (n_samples,) to (n_samples, 1)
Expand All @@ -88,17 +82,11 @@ def kneighbors(
signature = "(x)->(k)" if not return_distance else "(x)->(k),(k)"
output_dtypes: list[type] = [int] if not return_distance else [float, int]

result = da.apply_gufunc(
# If the wrapped estimator was fit with a dataframe, it will warn about
# missing feature names because this passes unnamed arrays. Suppress that
# and let the wrapper handle feature name checks.
suppress_feature_name_warnings(estimator._wrapped.kneighbors),
signature,
self.preprocessor.flat,
result = self._apply_gufunc(
estimator._wrapped.kneighbors,
signature=signature,
output_sizes={"k": k},
output_dtypes=output_dtypes,
axis=self.preprocessor.flat_band_dim,
allow_rechunk=True,
n_neighbors=n_neighbors,
return_distance=return_distance,
**kneighbors_kwargs,
Expand All @@ -113,3 +101,21 @@ def kneighbors(
return dist, nn

return self.preprocessor.unflatten(result, var_names=var_names)

def _apply_gufunc(self, func, *, signature, output_sizes, output_dtypes, **kwargs):
"""Apply a gufunc to the image across bands."""
# sklearn estimator methods like `predict` may warn about missing feature
# names because this passes unnamed arrays. We can suppress those and let
# the wrapper handle feature name checks.
suppressed_func = suppress_feature_name_warnings(func)

return da.apply_gufunc(
suppressed_func,
signature,
self.preprocessor.flat,
output_sizes=output_sizes,
output_dtypes=output_dtypes,
axis=self.preprocessor.flat_band_dim,
allow_rechunk=True,
**kwargs,
)

0 comments on commit e4f5c54

Please sign in to comment.