Vectorized lazy indexing (#1899)
* Start working

* First support of lazy vectorized indexing.

* Some optimization.

* Use unique to decompose vectorized indexing.

* Consolidate vectorizedIndexing

* Support vectorized_indexing in h5py

* Refactoring backend array. Added indexing.decompose_indexers. Drop unwrap_explicit_indexers.

* typo

* bugfix and typo

* Fix based on @WeatherGod comments.

* Use enum-like object to indicate indexing-support types.

* Update test_decompose_indexers.

* Bugfix and benchmarks.

* fix: support outer/basic indexer in LazilyVectorizedIndexedArray

* More comments.

* Fixing style errors.

* Remove unintended duplicate

* Combine indexers for in-memory np.ndarray.

* fix whats new

* fix pydap

* Update comments.

* Support VectorizedIndexing for rasterio. Some bugfix.

* flake8

* More tests

* Use LazilyIndexedArray for scalar array instead of loading.

* Support negative step slice in rasterio.

* Make slice-step always positive

* Bugfix in slice-slice

* Add pydap support.

* Rename LazilyIndexedArray -> LazilyOuterIndexedArray. Remove duplicate in zarr.py

* flake8

* Added transpose to LazilyOuterIndexedArray
fujiisoup authored and shoyer committed Mar 6, 2018
1 parent 55128aa commit 54468e1
Showing 15 changed files with 859 additions and 269 deletions.
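In user-facing terms, the change means that isel with array indexers on a lazily opened dataset no longer forces the whole backend array into memory. A minimal sketch of the intended usage, assuming an existing netCDF file example.nc with dimensions time, lon and lat (the file name, dimension names and index values are illustrative, not taken from this commit):

import numpy as np
import xarray as xr

# Opening is lazy: each variable is wrapped in a lazy indexing adapter
# (LazilyOuterIndexedArray after this commit) and nothing is read yet.
ds = xr.open_dataset('example.nc')

# Vectorized (pointwise) indexing: both indexers share the new dimension
# 'points', so they are paired element-wise instead of forming an outer
# product. With this commit the selection itself also stays lazy.
points = ds.isel(
    time=xr.DataArray(np.array([0, 5, 5, 2]), dims='points'),
    lon=xr.DataArray(np.array([3, 1, 4, 1]), dims='points'),
)

# Only this call reads the selected values from the file.
points.load()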
47 changes: 46 additions & 1 deletion asv_bench/benchmarks/dataset_io.py
@@ -5,7 +5,7 @@

import xarray as xr

from . import randn, requires_dask
from . import randn, randint, requires_dask

try:
import dask
@@ -71,6 +71,15 @@ def make_ds(self):

self.ds.attrs = {'history': 'created for xarray benchmarking'}

self.oinds = {'time': randint(0, self.nt, 120),
'lon': randint(0, self.nx, 20),
'lat': randint(0, self.ny, 10)}
self.vinds = {'time': xr.DataArray(randint(0, self.nt, 120),
dims='x'),
'lon': xr.DataArray(randint(0, self.nx, 120),
dims='x'),
'lat': slice(3, 20)}


class IOWriteSingleNetCDF3(IOSingleNetCDF):
def setup(self):
@@ -98,6 +107,14 @@ def setup(self):
def time_load_dataset_netcdf4(self):
xr.open_dataset(self.filepath, engine='netcdf4').load()

def time_orthogonal_indexing(self):
ds = xr.open_dataset(self.filepath, engine='netcdf4')
ds = ds.isel(**self.oinds).load()

def time_vectorized_indexing(self):
ds = xr.open_dataset(self.filepath, engine='netcdf4')
ds = ds.isel(**self.vinds).load()


class IOReadSingleNetCDF3(IOReadSingleNetCDF4):
def setup(self):
@@ -111,6 +128,14 @@ def setup(self):
def time_load_dataset_scipy(self):
xr.open_dataset(self.filepath, engine='scipy').load()

def time_orthogonal_indexing(self):
ds = xr.open_dataset(self.filepath, engine='scipy')
ds = ds.isel(**self.oinds).load()

def time_vectorized_indexing(self):
ds = xr.open_dataset(self.filepath, engine='scipy')
ds = ds.isel(**self.vinds).load()


class IOReadSingleNetCDF4Dask(IOSingleNetCDF):
def setup(self):
Expand All @@ -127,6 +152,16 @@ def time_load_dataset_netcdf4_with_block_chunks(self):
xr.open_dataset(self.filepath, engine='netcdf4',
chunks=self.block_chunks).load()

def time_load_dataset_netcdf4_with_block_chunks_oindexing(self):
ds = xr.open_dataset(self.filepath, engine='netcdf4',
chunks=self.block_chunks)
ds = ds.isel(**self.oinds).load()

def time_load_dataset_netcdf4_with_block_chunks_vindexing(self):
ds = xr.open_dataset(self.filepath, engine='netcdf4',
chunks=self.block_chunks)
ds = ds.isel(**self.vinds).load()

def time_load_dataset_netcdf4_with_block_chunks_multiprocessing(self):
with dask.set_options(get=dask.multiprocessing.get):
xr.open_dataset(self.filepath, engine='netcdf4',
@@ -158,6 +193,16 @@ def time_load_dataset_scipy_with_block_chunks(self):
xr.open_dataset(self.filepath, engine='scipy',
chunks=self.block_chunks).load()

def time_load_dataset_scipy_with_block_chunks_oindexing(self):
ds = xr.open_dataset(self.filepath, engine='scipy',
chunks=self.block_chunks)
ds = ds.isel(**self.oinds).load()

def time_load_dataset_scipy_with_block_chunks_vindexing(self):
ds = xr.open_dataset(self.filepath, engine='scipy',
chunks=self.block_chunks)
ds = ds.isel(**self.vinds).load()

def time_load_dataset_scipy_with_time_chunks(self):
with dask.set_options(get=dask.multiprocessing.get):
xr.open_dataset(self.filepath, engine='scipy',
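For readers comparing the two new benchmark cases: oinds indexes each dimension with an independent integer array (orthogonal indexing), while in vinds the time and lon indexers share the dimension 'x' and are therefore paired up pointwise (vectorized indexing). A small shape sketch, assuming a variable with dimensions (time, lon, lat); the sizes are illustrative, not the benchmark's exact values:

import numpy as np
import xarray as xr

nt, nx, ny = 30, 90, 45  # illustrative sizes
var = xr.DataArray(np.random.rand(nt, nx, ny), dims=('time', 'lon', 'lat'))

# Orthogonal: each indexer selects along its own dimension.
oinds = {'time': np.random.randint(0, nt, 120),
         'lon': np.random.randint(0, nx, 20),
         'lat': np.random.randint(0, ny, 10)}
print(var.isel(**oinds).shape)  # (120, 20, 10)

# Vectorized: 'time' and 'lon' share dim 'x', so they are zipped together
# and the result carries the indexer dimension 'x' plus the sliced 'lat'.
vinds = {'time': xr.DataArray(np.random.randint(0, nt, 120), dims='x'),
         'lon': xr.DataArray(np.random.randint(0, nx, 120), dims='x'),
         'lat': slice(3, 20)}
print(var.isel(**vinds).shape)  # (120, 17)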
5 changes: 5 additions & 0 deletions doc/whats-new.rst
@@ -38,6 +38,11 @@ Documentation
Enhancements
~~~~~~~~~~~~

- Support lazy vectorized indexing. After this change, flexible indexing such
  as orthogonal and vectorized indexing becomes possible for all the backend
  arrays. Lazy ``transpose`` is also supported. (:issue:`1897`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- Improve :py:func:`~xarray.DataArray.rolling` logic.
:py:func:`~xarray.DataArrayRolling` object now supports
:py:func:`~xarray.DataArrayRolling.construct` method that returns a view
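The lazy transpose mentioned in the whats-new entry above means that reordering dimensions no longer forces the backend array to be read; the transposition is recorded by the lazy wrapper and applied when the data is finally loaded. A minimal sketch, assuming a lazily opened file and a variable named 'temperature' (both hypothetical):

import xarray as xr

ds = xr.open_dataset('example.nc')   # hypothetical file; nothing read yet
da = ds['temperature'].transpose()   # with lazy transpose this should stay lazy

values = da.values                   # the file is only read at this point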
19 changes: 12 additions & 7 deletions xarray/backends/h5netcdf_.py
@@ -16,15 +16,20 @@

class H5NetCDFArrayWrapper(BaseNetCDF4Array):
def __getitem__(self, key):
key = indexing.unwrap_explicit_indexer(
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))
key, np_inds = indexing.decompose_indexer(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR)

# h5py requires using lists for fancy indexing:
# https://github.com/h5py/h5py/issues/992
# OuterIndexer only holds 1D integer ndarrays, so it's safe to convert
# them to lists.
key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key)
key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in
key.tuple)
with self.datastore.ensure_open(autoclose=True):
return self.get_array()[key]
array = self.get_array()[key]

if len(np_inds.tuple) > 0:
array = indexing.NumpyIndexingAdapter(array)[np_inds]

return array


def maybe_decode_bytes(txt):
@@ -85,7 +90,7 @@ def __init__(self, filename, mode='r', format=None, group=None,
def open_store_variable(self, name, var):
with self.ensure_open(autoclose=False):
dimensions = var.dimensions
data = indexing.LazilyIndexedArray(
data = indexing.LazilyOuterIndexedArray(
H5NetCDFArrayWrapper(name, self))
attrs = _read_attributes(var)

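The recurring pattern in the backend wrappers is visible above: indexing.decompose_indexer splits the requested indexer into a part the storage layer can handle (key) and a residual numpy indexer (np_inds) that is applied after the data is in memory. The idea behind the decomposition (see the "Use unique to decompose vectorized indexing" commit above) can be shown in plain numpy; decompose_1d below is a hypothetical sketch, not the actual decompose_indexer implementation:

import numpy as np

def decompose_1d(indexer):
    """Split a repeated/unsorted integer indexer so that
    data[indexer] == data[backend_indexer][in_memory_indexer]."""
    # The backend is only asked for the sorted unique positions; the
    # duplicates and the original order are restored in memory.
    backend_indexer, in_memory_indexer = np.unique(indexer, return_inverse=True)
    return backend_indexer, in_memory_indexer

data = np.arange(10) * 10
requested = np.array([7, 3, 3, 7, 1])

backend_ix, mem_ix = decompose_1d(requested)
loaded = data[backend_ix]   # what a netCDF/h5py backend would read: [10, 30, 70]
result = loaded[mem_ix]     # finished in memory: [70, 30, 30, 70, 10]
assert np.array_equal(result, data[requested])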
15 changes: 9 additions & 6 deletions xarray/backends/netCDF4_.py
@@ -48,17 +48,16 @@ def get_array(self):

class NetCDF4ArrayWrapper(BaseNetCDF4Array):
def __getitem__(self, key):
key = indexing.unwrap_explicit_indexer(
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))

key, np_inds = indexing.decompose_indexer(
key, self.shape, indexing.IndexingSupport.OUTER)
if self.datastore.is_remote: # pragma: no cover
getitem = functools.partial(robust_getitem, catch=RuntimeError)
else:
getitem = operator.getitem

with self.datastore.ensure_open(autoclose=True):
try:
data = getitem(self.get_array(), key)
array = getitem(self.get_array(), key.tuple)
except IndexError:
# Catch IndexError in netCDF4 and return a more informative
# error message. This is most often called when an unsorted
@@ -71,7 +70,10 @@ def __getitem__(self, key):
msg += '\n\nOriginal traceback:\n' + traceback.format_exc()
raise IndexError(msg)

return data
if len(np_inds.tuple) > 0:
array = indexing.NumpyIndexingAdapter(array)[np_inds]

return array


def _encode_nc4_variable(var):
@@ -277,7 +279,8 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None,
def open_store_variable(self, name, var):
with self.ensure_open(autoclose=False):
dimensions = var.dimensions
data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
data = indexing.LazilyOuterIndexedArray(
NetCDF4ArrayWrapper(name, self))
attributes = OrderedDict((k, var.getncattr(k))
for k in var.ncattrs())
_ensure_fill_value_valid(data, attributes)
14 changes: 9 additions & 5 deletions xarray/backends/pydap_.py
@@ -22,18 +22,22 @@ def dtype(self):
return self.array.dtype

def __getitem__(self, key):
key = indexing.unwrap_explicit_indexer(
key, target=self, allow=indexing.BasicIndexer)
key, np_inds = indexing.decompose_indexer(
key, self.shape, indexing.IndexingSupport.BASIC)

# pull the data from the array attribute if possible, to avoid
# downloading coordinate data twice
array = getattr(self.array, 'array', self.array)
result = robust_getitem(array, key, catch=ValueError)
result = robust_getitem(array, key.tuple, catch=ValueError)
# pydap doesn't squeeze axes automatically like numpy
axis = tuple(n for n, k in enumerate(key)
axis = tuple(n for n, k in enumerate(key.tuple)
if isinstance(k, integer_types))
if len(axis) > 0:
result = np.squeeze(result, axis)

if len(np_inds.tuple) > 0:
result = indexing.NumpyIndexingAdapter(np.asarray(result))[np_inds]

return result


@@ -74,7 +78,7 @@ def open(cls, url, session=None):
return cls(ds)

def open_store_variable(self, var):
data = indexing.LazilyIndexedArray(PydapArrayWrapper(var))
data = indexing.LazilyOuterIndexedArray(PydapArrayWrapper(var))
return Variable(var.dimensions, data,
_fix_attributes(var.attributes))

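A small detail in the pydap path: pydap keeps scalar-indexed axes as length-1 axes instead of dropping them the way numpy does, which is why the wrapper collects the scalar positions and calls np.squeeze. A plain numpy illustration (pydap's behaviour is emulated here with a length-1 list index):

import numpy as np

arr = np.arange(12).reshape(3, 4)
print(arr[0, :].shape)        # numpy drops the scalar axis -> (4,)

result = arr[[0], :]          # pydap-like result keeps it -> (1, 4)
axis = (0,)                   # positions that were indexed with a scalar
print(np.squeeze(result, axis).shape)  # (4,), matching numpy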
15 changes: 10 additions & 5 deletions xarray/backends/pynio_.py
@@ -24,14 +24,19 @@ def get_array(self):
return self.datastore.ds.variables[self.variable_name]

def __getitem__(self, key):
key = indexing.unwrap_explicit_indexer(
key, target=self, allow=indexing.BasicIndexer)
key, np_inds = indexing.decompose_indexer(
key, self.shape, indexing.IndexingSupport.BASIC)

with self.datastore.ensure_open(autoclose=True):
array = self.get_array()
if key == () and self.ndim == 0:
if key.tuple == () and self.ndim == 0:
return array.get_value()
return array[key]

array = array[key.tuple]
if len(np_inds.tuple) > 0:
array = indexing.NumpyIndexingAdapter(array)[np_inds]

return array


class NioDataStore(AbstractDataStore, DataStorePickleMixin):
@@ -51,7 +56,7 @@ def __init__(self, filename, mode='r', autoclose=False):
self._mode = mode

def open_store_variable(self, name, var):
data = indexing.LazilyIndexedArray(NioArrayWrapper(name, self))
data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self))
return Variable(var.dimensions, data, var.attributes)

def get_variables(self):
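The IndexingSupport flag passed to decompose_indexer declares how much fancy indexing each storage layer can perform natively; anything beyond that level ends up in the residual numpy indexer. A short summary of the three levels visible in this commit's diffs (internal API; the descriptions are paraphrased, not quoted from the implementation):

from xarray.core import indexing

#   BASIC         - integers and slices only          (pydap, pynio)
#   OUTER_1VECTOR - outer indexing with at most one   (h5netcdf / h5py)
#                   integer-array indexer
#   OUTER         - full outer/orthogonal indexing    (netCDF4, rasterio)
print(indexing.IndexingSupport.BASIC,
      indexing.IndexingSupport.OUTER_1VECTOR,
      indexing.IndexingSupport.OUTER)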
67 changes: 46 additions & 21 deletions xarray/backends/rasterio_.py
@@ -42,48 +42,73 @@
def shape(self):
return self._shape

def __getitem__(self, key):
key = indexing.unwrap_explicit_indexer(
key, self, allow=(indexing.BasicIndexer, indexing.OuterIndexer))
def _get_indexer(self, key):
""" Get indexer for rasterio array.
Parameter
---------
key: ExplicitIndexer
Returns
-------
band_key: an indexer for the 1st dimension
window: two tuples. Each consists of (start, stop).
squeeze_axis: axes to be squeezed
np_ind: indexer for loaded numpy array
See also
--------
indexing.decompose_indexer
"""
key, np_inds = indexing.decompose_indexer(
key, self.shape, indexing.IndexingSupport.OUTER)

# bands cannot be windowed but they can be listed
band_key = key[0]
n_bands = self.shape[0]
band_key = key.tuple[0]
new_shape = []
np_inds2 = []
# bands (axis=0) cannot be windowed but they can be listed
if isinstance(band_key, slice):
start, stop, step = band_key.indices(n_bands)
if step is not None and step != 1:
raise IndexError(_ERROR_MSG)
band_key = np.arange(start, stop)
start, stop, step = band_key.indices(self.shape[0])
band_key = np.arange(start, stop, step)
# be sure we give out a list
band_key = (np.asarray(band_key) + 1).tolist()
if isinstance(band_key, list): # if band_key is not a scalar
new_shape.append(len(band_key))
np_inds2.append(slice(None))

# but other dims can only be windowed
window = []
squeeze_axis = []
for i, (k, n) in enumerate(zip(key[1:], self.shape[1:])):
for i, (k, n) in enumerate(zip(key.tuple[1:], self.shape[1:])):
if isinstance(k, slice):
# step is always positive. see indexing.decompose_indexer
start, stop, step = k.indices(n)
if step is not None and step != 1:
raise IndexError(_ERROR_MSG)
np_inds2.append(slice(None, None, step))
new_shape.append(stop - start)
elif is_scalar(k):
# windowed operations will always return an array
# we will have to squeeze it later
squeeze_axis.append(i + 1)
squeeze_axis.append(- (2 - i))
start = k
stop = k + 1
else:
k = np.asarray(k)
start = k[0]
stop = k[-1] + 1
ids = np.arange(start, stop)
if not ((k.shape == ids.shape) and np.all(k == ids)):
raise IndexError(_ERROR_MSG)
start, stop = np.min(k), np.max(k) + 1
np_inds2.append(k - start)
new_shape.append(stop - start)
window.append((start, stop))

np_inds = indexing._combine_indexers(
indexing.OuterIndexer(tuple(np_inds2)), new_shape, np_inds)
return band_key, window, tuple(squeeze_axis), np_inds

def __getitem__(self, key):
band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

out = self.rasterio_ds.read(band_key, window=tuple(window))
if squeeze_axis:
out = np.squeeze(out, axis=squeeze_axis)
return out
return indexing.NumpyIndexingAdapter(out)[np_inds]


def _parse_envi(meta):
@@ -249,7 +274,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
else:
attrs[k] = v

data = indexing.LazilyIndexedArray(RasterioArrayWrapper(riods))
data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(riods))

# this lets you write arrays loaded with rasterio
data = indexing.CopyOnWriteArray(data)
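rasterio itself can only read contiguous windows, one (start, stop) range per non-band dimension, so _get_indexer widens each integer-array indexer to its bounding window and leaves reordering and subsetting to the residual numpy indexer (start, stop = np.min(k), np.max(k) + 1 and k - start in the diff above). A standalone numpy sketch of that widening step; widen_to_window is illustrative, not the wrapper's code:

import numpy as np

def widen_to_window(k, data_1d):
    """Read a bounding window covering k, then pick out k in memory."""
    k = np.asarray(k)
    start, stop = k.min(), k.max() + 1   # the window rasterio would be asked for
    window = data_1d[start:stop]         # contiguous read
    return window[k - start]             # residual in-memory indexer

data = np.arange(100) * 2
k = [7, 3, 5]
assert np.array_equal(widen_to_window(k, data), data[np.asarray(k)])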