Interpolation always returns floats #4770

Open
@Illviljan

Description

What happened:
When interpolating a dataset, integer arrays are forced to floats.

What you expected to happen:
The arrays to retain their original dtype after interpolation.

Minimal Complete Verifiable Example:

import numpy as np
import dask.array as da
import xarray as xr

a = np.arange(0, 2)
b = np.core.defchararray.add("long_variable_name", a.astype(str))
coords = dict(time=da.array([0, 1]))
data_vars = dict()
for v in b:
    data_vars[v] = xr.DataArray(
        name=v,
        data=da.array([0, 1], dtype=int),
        dims=["time"],
        coords=coords,
    )
ds1 = xr.Dataset(data_vars)

print(ds1)
Out[35]: 
<xarray.Dataset>
<xarray.Dataset>
Dimensions:              (time: 2)
Coordinates:
  * time                 (time) int32 0 1
Data variables:
    long_variable_name0  (time) int32 dask.array<chunksize=(2,), meta=np.ndarray>
    long_variable_name1  (time) int32 dask.array<chunksize=(2,), meta=np.ndarray>

# Interpolate:
ds1 = ds1.interp(
    time=da.array([0, 0.5, 1, 2]),
    assume_sorted=True,
    method="linear",
    kwargs=dict(fill_value="extrapolate"),
)

# The lazy dask array still reports an integer dtype:
print(ds1.long_variable_name0)
Out[55]: 
<xarray.DataArray 'long_variable_name0' (time: 4)>
dask.array<dask_aware_interpnd, shape=(4,), dtype=int32, chunksize=(4,), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) float64 0.0 0.5 1.0 2.0

# But once computed it turns out to be a float array:
print(ds1.long_variable_name0.compute())
Out[38]: 
<xarray.DataArray 'long_variable_name0' (time: 4)>
array([0. , 0.5, 1. , 2. ])
Coordinates:
  * time     (time) float64 0.0 0.5 1.0 2.0
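The float result comes from the SciPy interpolator that xarray calls under the hood. A minimal standalone check (assuming SciPy is available) shows that `scipy.interpolate.interp1d` promotes integer `y` values to float regardless of the input dtype:

```python
import numpy as np
from scipy.interpolate import interp1d

x = np.arange(2)
y = np.array([0, 1], dtype=np.int32)  # integer data, as in the example above

f = interp1d(x, y, kind="linear", fill_value="extrapolate")
out = f(np.array([0.0, 0.5, 1.0, 2.0]))

# interp1d always promotes integer y to float64 internally
print(out.dtype)  # float64
```

So the lazy dask graph advertising `int32` while the computed result is `float64` is a metadata mismatch on the xarray/dask side, not something dask itself decides at compute time.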
 

Anything else we need to know?:
An easy first step is to also force np.float_ as the output dtype in the da.blockwise call in missing.interp_func, so that the lazy dask array at least reports the dtype that compute() will actually return.

The more difficult fix is to somehow cast the data arrays back to their original dtype without hurting performance. In a quick test, simply adding .astype() to the value returned by missing.interp doubled the calculation time.
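For reference, the cast-back approach can be sketched with plain NumPy (np.interp standing in for xarray's interpolation path; this is an illustration, not the actual missing.interp code). The .astype() call is a second full pass over the result plus a copy, which is consistent with the roughly doubled runtime observed:

```python
import numpy as np

x = np.arange(1_000_000)
y = np.arange(1_000_000, dtype=np.int32)  # integer data
new_x = x + 0.5                           # interpolation targets

out = np.interp(new_x, x, y)              # result is float64
restored = out.astype(np.int32)           # extra pass + copy to recover the dtype
```

Note that the cast is also lossy whenever the interpolated values are not exact integers (astype truncates rather than rounds), so silently casting back is arguably worse than returning floats.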

I was thinking the conversion to floats inside scipy could be avoided altogether by adding a (non-)public option that skips the dtype checks and lets the user handle "unsafe" interpolations themselves.

Related:
scipy/scipy#11093

Environment:

Output of xr.show_versions()

INSTALLED VERSIONS

commit: None
python: 3.8.5 (default, Sep 3 2020, 21:29:08) [MSC v.1916 64 bit (AMD64)]
python-bits: 64
OS: Windows
libhdf5: 1.10.4
libnetcdf: None

xarray: 0.16.2
pandas: 1.1.5
numpy: 1.17.5
scipy: 1.4.1
netCDF4: None
pydap: None
h5netcdf: None
h5py: 2.10.0
Nio: None
zarr: None
cftime: None
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: 1.3.2
dask: 2020.12.0
distributed: 2020.12.0
matplotlib: 3.3.2
cartopy: None
seaborn: 0.11.1
numbagg: None
pint: None
setuptools: 51.0.0.post20201207
pip: 20.3.3
conda: 4.9.2
pytest: 6.2.1
IPython: 7.19.0
sphinx: 3.4.0
