diff --git a/nitransforms/base.py b/nitransforms/base.py index 25fd88e0..68b97f75 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -259,6 +259,15 @@ def apply( slightly blurred if *order > 1*, unless the input is prefiltered, i.e. it is the result of calling the spline filter on the original input. + output_dtype: dtype specifier, optional + The dtype of the returned array or image, if specified. + If ``None``, the default behavior is to use the effective dtype of + the input image. If slope and/or intercept are defined, the effective + dtype is float64, otherwise it is equivalent to the input image's + ``get_data_dtype()`` (on-disk type). + If ``reference`` is defined, then the return value is an image, with + a data array of the effective dtype but with the on-disk dtype set to + the input image's on-disk dtype. Returns ------- @@ -279,11 +288,7 @@ def apply( if isinstance(spatialimage, (str, Path)): spatialimage = _nbload(str(spatialimage)) - data = np.asanyarray( - spatialimage.dataobj, - dtype=spatialimage.get_data_dtype() - ) - output_dtype = output_dtype or data.dtype + data = np.asanyarray(spatialimage.dataobj) targets = ImageGrid(spatialimage).index( # data should be an image _as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim) ) @@ -302,9 +307,9 @@ def apply( hdr = None if _ref.header is not None: hdr = _ref.header.copy() - hdr.set_data_dtype(output_dtype) + hdr.set_data_dtype(output_dtype or spatialimage.get_data_dtype()) moved = spatialimage.__class__( - resampled.reshape(_ref.shape).astype(output_dtype), + resampled.reshape(_ref.shape), _ref.affine, hdr, )