Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/rasterio #1070

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/requirements-py35.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- pandas
- seaborn
- scipy
- rasterio
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless rasterio only supports Python 3, should add it to other test suites, too (especially Python 2.7).

Copy link
Author

@NicWayand NicWayand Oct 31, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rasterio supports 2.7 and 3.3-3.5 (https://mapbox.github.io/rasterio/)

I added this requirement to all other test suites, was that correct? Version 1.0 is needed as many function names changed in rasterio at V1.0.

- pip:
- coveralls
- pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[wheel]
universal = 1

[pytest]
[tool:pytest]
python_files=test_*.py
1 change: 1 addition & 0 deletions xarray/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .pynio_ import NioDataStore
from .scipy_ import ScipyDataStore
from .h5netcdf_ import H5NetCDFStore
from .rasterio_ import RasterioDataStore
6 changes: 4 additions & 2 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
import gzip
import os.path
import threading
Expand All @@ -18,6 +17,7 @@
DATAARRAY_NAME = '__xarray_dataarray_name__'
DATAARRAY_VARIABLE = '__xarray_dataarray_variable__'


def _get_default_engine(path, allow_remote=False):
if allow_remote and is_remote_uri(path): # pragma: no cover
try:
Expand Down Expand Up @@ -154,7 +154,7 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True,
decode_coords : bool, optional
If True, decode the 'coordinates' attribute to identify coordinates in
the resulting dataset.
engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio'}, optional
engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'rasterio'}, optional
Engine to use when reading files. If not provided, the default engine
is chosen based on available dependencies, with a preference for
'netcdf4'.
Expand Down Expand Up @@ -252,6 +252,8 @@ def maybe_decode_store(store, lock=False):
store = backends.H5NetCDFStore(filename_or_obj, group=group)
elif engine == 'pynio':
store = backends.NioDataStore(filename_or_obj)
elif engine == 'rasterio':
store = backends.RasterioDataStore(filename_or_obj)
else:
raise ValueError('unrecognized engine for open_dataset: %r'
% engine)
Expand Down
139 changes: 139 additions & 0 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import numpy as np

try:
import rasterio
except ImportError:
rasterio = False

from .. import Variable, DataArray
from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin
from ..core import indexing
from ..core.pycompat import OrderedDict

from .common import AbstractDataStore

__rio_varname__ = 'raster'


class RasterioArrayWrapper(NDArrayMixin):
def __init__(self, ds):
self._ds = ds
self.array = ds.read()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless the behavior of rasterio changed with v1.0, this loads the data in memory. I might be wrong, but I think the call to ds.read() should happen in __getitem__, ideally with a window kwarg.

I also wonder if the call to read shouldn't happen in a with rasterio.Env(): environment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shoyer, do we need to maintain the array attribute here? Would it make more sense to just populate the _ds attribute and set the array in __getitem__?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to set array -- that's merely for the benefit of NDArrayMixin in core/utils.py. You might even skip NDArrayMixin altogether and just use NdimSizeLenMixin, though you'll want to define properties/methods for each of those defined by NDArrayMixin.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicWayand -

I think you can basically drop this in replacing the current RasterioArrayWrapper:

class RasterioArrayWrapper(NdimSizeLenMixin):
    """Mixin class RasterIO datasets for making wrappers of N-dimensional 
    arrays that conform to the ndarray interface required for the data
    argument to Variable objects.
    A subclass should set the `array` property and override one or more of
    `dtype`, `shape` and `__getitem__`.
    """
    def __init__(self, ds):
        self.ds = ds

    @property
    def dtype(self):
        if len(set(self.ds.dtypes[0])) != 1:
            raise ValueError(
                'Can only handle Rastio dataset with all bands having the same type')
        return np.dtype(self.ds.dtypes[0])

    @property
    def shape(self):
        return self.ds.shape

    def __array__(self, dtype=None):
        '''Not sure if this will work as is'''
        return np.asarray(self[...], dtype=dtype)

    def __getitem__(self, key):
        band = range(self.shape[0])[key[0]]
        window = []
        for win in key[1:]:
            if instance(win, slice):
                window.append((win.start, win.stop))
            elif isinstance(win, int):
                window.append((win, win + 1))
            else:  # integer ndarray
                window.append((win.min(), win.max()))
        raw_data = self.ds.read(band, window=window)
        # now, fix up raw_data to conform to numpy indexing conventions
        # - drop axes for integer band/windows
        # - stride if windows are slices with win.step != 1
        # - subset if window is an integer ndarray

    def __repr__(self):
        return '%s(array=%r)' % (type(self).__name__, self.ds)

I'm not sure the __array__ will work as it is until we get the __get_item__ fully functional.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also worth consider how/if you want to handle automatically masking missing values -- it looks like ds.read() can optionally return a MaskedArray. The standard xarray approach would be to automatically promote the dtype and fill these with NaNs.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, the above code is waiting for the decision from rasterio group?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest getting something working here, with an eye toward keeping it self-contained. Later, we can try to port it upstream to rasterio.


@property
def dtype(self):
return np.dtype(self._ds.dtypes[0])

def __getitem__(self, key):
if key == () and self.ndim == 0:
return self.array.get_value()
return self.array[key]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on what @fmaussion said above, I think this should be something like:

    def __getitem__(self, key):
        if key == () and self.ndim == 0:
            return self._ds.read()
        return self._ds.read(band, window=window)

Where band and window describe the slice of data to be read. Off the top of my head, I'm not exactly sure how to parse the key here though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the time indexers get here, they should have already passed through canonicalize_indexer, which means you only need to handle a key that is a tuple of the appropriate number of dimensions (i.e., 3) composed of integers, slices and integer ndarrays.

Based on the docstring for read, we want something like:

def __getitem__(self, key):
    band = range(self.shape[0])[key[0]]
    window = []
    for win in key[1:]:
        if instance(win, slice):
            window.append((win.start, win.stop))
        elif isinstance(win, int):
            window.append((win, win + 1))
        else:  # integer ndarray
            window.append((win.min(), win.max()))
    raw_data = self._ds.read(band, window=window)
    # now, fix up raw_data to conform to numpy indexing conventions
    # - drop axes for integer band/windows
    # - stride if windows are slices with win.step != 1
    # - subset if window is an integer ndarray

Honestly, this logic should probably live in rasterio if possible. I'm a little surprised that they have never implemented a __getitem__ method.

Copy link
Member

@shoyer shoyer Oct 31, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also: https://gist.github.com/lpinner/bd57b54a5c6903e4a6a2 (can't reuse this directly, though, because it doesn't have a license).

Anyways, I would definitely see if the rasterio folks are up for implementing a __getitem__ method. There's no reason why this is xarray specific -- you would need this for just using dask.array with rasterio, as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also rasterio/rasterio#920 for related discussion.



class RasterioDataStore(AbstractDataStore):
"""Store for accessing datasets via Rasterio
"""
def __init__(self, filename, mode='r'):

with rasterio.Env():
self.ds = rasterio.open(filename, mode=mode, )

# Get coords
nx, ny = self.ds.width, self.ds.height
x0, y0 = self.ds.bounds.left, self.ds.bounds.top
dx, dy = self.ds.res[0], -self.ds.res[1]

self.coords = {'y': np.arange(start=y0, stop=(y0 + ny * dy), step=dy),
'x': np.arange(start=x0, stop=(x0 + nx * dx), step=dx)}

# Get dims
if self.ds.count >= 1:
self.dims = ('band', 'y', 'x')
self.coords['band'] = self.ds.indexes
else:
raise ValueError('unknown dims')

self._attrs = OrderedDict()
for attr_name in ['crs', 'transform', 'proj']:
try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use with pycompat.suppress(AttributeError):

self._attrs[attr_name] = getattr(self.ds, attr_name)
except AttributeError:
pass

# def get_vardata(self, var_id=1):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be removed

# """Read the geotiff band.
# Parameters
# ----------
# var_id: the variable name (here the band number)
# """
# # wx = (self.sub_x[0], self.sub_x[1] + 1)
# # wy = (self.sub_y[0], self.sub_y[1] + 1)
# with rasterio.Env():
# band = self.ds.read() # var_id, window=(wy, wx))
# return band

def open_store_variable(self, var):
if var != __rio_varname__:
raise ValueError(
'Rasterio variables are all named %s' % __rio_varname__)
data = indexing.LazilyIndexedArray(
RasterioArrayWrapper(self.ds))
return Variable(self.dims, data, self._attrs)

def get_variables(self):
# Get lat lon coordinates
coords = _try_to_get_latlon_coords(self.coords, self._attrs)
vars = {__rio_varname__: self.open_store_variable(__rio_varname__)}
vars.update(coords)
return FrozenOrderedDict(vars)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should use another name other than vars since it is a python builtin.


def get_attrs(self):
return Frozen(self._attrs)

def get_dimensions(self):
return Frozen(self.ds.dims)

def close(self):
self.ds.close()


def _transform_proj(p1, p2, x, y, nocopy=False):
"""Wrapper around the pyproj transform.
When two projections are equal, this function avoids quite a bunch of
useless calculations. See https://github.com/jswhit/pyproj/issues/15
"""
import pyproj
import copy

if p1.srs == p2.srs:
if nocopy:
return x, y
else:
return copy.deepcopy(x), copy.deepcopy(y)

return pyproj.transform(p1, p2, x, y)


def _try_to_get_latlon_coords(coords, attrs):
coords_out = {}
try:
import pyproj
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little surprised rasterio doesn't have projections built in.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They do: https://mapbox.github.io/rasterio/topics/reproject.html

I will take a stab at doing it the rasterio way.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a reprojection, which does seem useful, but it seems separate from giving you lat and lon coordinates

except ImportError:
pyproj = False
if 'crs' in attrs and pyproj:
proj = pyproj.Proj(attrs['crs'])
x, y = np.meshgrid(coords['x'], coords['y'])
proj_out = pyproj.Proj("+init=EPSG:4326", preserve_units=True)
xc, yc = _transform_proj(proj, proj_out, x, y)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might try to do this calculation lazily, e.g., by making a utils.NDArrayMixin subclass in conventions.py and wrapping with indexing.LazilyIndexedArray.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I don't understand how to do this. Is there an example I can go off of?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have some examples of making NDArrayMixin subclasses in conventions.py, e.g.,
https://github.com/pydata/xarray/blob/master/xarray/conventions.py#L311

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option would be to use dask.array to build the coordinate arrays, though the NDArrayMixin would avoid the need for dask.

coords = dict(y=coords['y'], x=coords['x'])
dims = ('y', 'x')

coords_out['lat'] = DataArray(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be just xarray.Variable objects, not DataArray objects

data=yc, coords=coords, dims=dims, name='lat',
attrs={'units': 'degrees_north', 'long_name': 'latitude',
'standard_name': 'latitude'})
coords_out['lon'] = DataArray(
data=xc, coords=coords, dims=dims, name='lon',
attrs={'units': 'degrees_east', 'long_name': 'longitude',
'standard_name': 'longitude'})
return coords_out
11 changes: 11 additions & 0 deletions xarray/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@
has_pynio = False


try:
import rasterio
has_rasterio = True
except ImportError:
has_rasterio = False


try:
import dask.array
import dask
Expand Down Expand Up @@ -90,6 +97,10 @@ def requires_pynio(test):
return test if has_pynio else unittest.skip('requires pynio')(test)


def requires_rasterio(test):
return test if has_rasterio else unittest.skip('requires rasterio')(test)


def requires_scipy_or_netCDF4(test):
return (test if has_scipy or has_netCDF4
else unittest.skip('requires scipy or netCDF4')(test))
Expand Down
31 changes: 30 additions & 1 deletion xarray/test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from . import (TestCase, requires_scipy, requires_netCDF4, requires_pydap,
requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf,
requires_pynio, has_netCDF4, has_scipy)
requires_pynio, requires_rasterio, has_netCDF4, has_scipy)
from .test_dataset import create_test_data

try:
Expand Down Expand Up @@ -1063,6 +1063,35 @@ def test_weakrefs(self):
self.assertDatasetIdentical(actual, expected)


@requires_rasterio
class TestRasterIO(CFEncodedDataTest, Only32BitTypes, TestCase):
def test_write_store(self):
# rasterio is read-only for now
pass

def test_orthogonal_indexing(self):
# rasterio also does not support list-like indexing
pass

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={}):
with create_tmp_file() as tmp_file:
data.to_netcdf(tmp_file, engine='scipy', **save_kwargs)
with open_dataset(tmp_file, engine='rasterio', **open_kwargs) as ds:
yield ds

def test_weakrefs(self):
example = Dataset({'foo': ('x', np.arange(5.0))})
expected = example.rename({'foo': 'bar', 'x': 'y'})

with create_tmp_file() as tmp_file:
example.to_netcdf(tmp_file, engine='scipy')
on_disk = open_dataset(tmp_file, engine='rasterio')
actual = on_disk.rename({'foo': 'bar', 'x': 'y'})
del on_disk # trigger garbage collection
self.assertDatasetIdentical(actual, expected)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicWayand - take a look at the PyNio tests. They are probably the closest analog for what we need to test here.



class TestEncodingInvalid(TestCase):

def test_extract_nc4_encoding(self):
Expand Down