diff --git a/.travis.yml b/.travis.yml index e887ec60467..2d7629a7e51 100644 --- a/.travis.yml +++ b/.travis.yml @@ -69,7 +69,7 @@ install: - python setup.py install script: - - py.test xarray --cov=xarray --cov-report term-missing + - py.test xarray --cov=xarray --cov-report term-missing -v after_success: - coveralls diff --git a/ci/requirements-py34.yml b/ci/requirements-py34.yml index a49611751ca..3ad45660b80 100644 --- a/ci/requirements-py34.yml +++ b/ci/requirements-py34.yml @@ -1,9 +1,13 @@ name: test_env +channels: + - conda-forge dependencies: - python=3.4 - bottleneck - pytest - pandas + - rasterio + - scipy - pip: - coveralls - pytest-cov diff --git a/doc/_static/dataset-diagram-build.sh b/doc/_static/dataset-diagram-build.sh old mode 100755 new mode 100644 diff --git a/setup.cfg b/setup.cfg index 6770e9c807f..44b0d881cc2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [wheel] universal = 1 -[pytest] +[tool:pytest] python_files=test_*.py diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index a082bd53e5e..192a3c57db2 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -10,3 +10,4 @@ from .pynio_ import NioDataStore from .scipy_ import ScipyDataStore from .h5netcdf_ import H5NetCDFStore +from .rasterio_ import RasterioDataStore diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2ba87db7e90..8797be8b25a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1,4 +1,3 @@ -import sys import gzip import os.path import threading @@ -18,6 +17,7 @@ DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' + def _get_default_engine(path, allow_remote=False): if allow_remote and is_remote_uri(path): # pragma: no cover try: @@ -154,7 +154,7 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, decode_coords : bool, optional If True, decode the 'coordinates' attribute to identify coordinates in the resulting dataset. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio'}, optional + engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'rasterio'}, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. @@ -252,6 +252,8 @@ def maybe_decode_store(store, lock=False): store = backends.H5NetCDFStore(filename_or_obj, group=group) elif engine == 'pynio': store = backends.NioDataStore(filename_or_obj) + elif engine == 'rasterio': + store = backends.RasterioDataStore(filename_or_obj) else: raise ValueError('unrecognized engine for open_dataset: %r' % engine) diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py new file mode 100644 index 00000000000..e19a7c27984 --- /dev/null +++ b/xarray/backends/rasterio_.py @@ -0,0 +1,120 @@ +import numpy as np + +try: + import rasterio +except ImportError: + rasterio = False + +from .. import Variable, DataArray +from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin +from ..core import indexing +from ..core.pycompat import OrderedDict, suppress + +from .common import AbstractDataStore + +__rio_varname__ = 'raster' + + +class RasterioArrayWrapper(NDArrayMixin): + def __init__(self, ds): + self._ds = ds + self.array = ds.read() + + @property + def dtype(self): + return np.dtype(self._ds.dtypes[0]) + + def __getitem__(self, key): + if key == () and self.ndim == 0: + return self.array.get_value() + return self.array[key] + + +class RasterioDataStore(AbstractDataStore): + """Store for accessing datasets via Rasterio + """ + def __init__(self, filename, mode='r'): + + with rasterio.Env(): + self.ds = rasterio.open(filename, mode=mode, ) + + # Get coords + nx, ny = self.ds.width, self.ds.height + x0, y0 = self.ds.bounds.left, self.ds.bounds.top + dx, dy = self.ds.res[0], -self.ds.res[1] + + self.coords = {'y': np.linspace(start=y0, num=ny, stop=(y0 + (ny-1) * dy)), + 'x': np.linspace(start=x0, num=nx, stop=(x0 + (nx-1) * dx))} + + # Get dims + if self.ds.count >= 1: + self.dims = ('band', 'y', 'x') + self.coords['band'] = self.ds.indexes + else: + raise ValueError('unknown dims') + + self._attrs = OrderedDict() + with suppress(AttributeError): + for attr_name in ['crs', 'transform', 'proj']: + self._attrs[attr_name] = getattr(self.ds, attr_name) + + def open_store_variable(self, var): + if var != __rio_varname__: + raise ValueError( + 'Rasterio variables are all named %s' % __rio_varname__) + data = indexing.LazilyIndexedArray( + RasterioArrayWrapper(self.ds)) + return Variable(self.dims, data, self._attrs) + + def get_variables(self): + # Get lat lon coordinates + coords = _try_to_get_latlon_coords(self.coords, self._attrs) + rio_vars = {__rio_varname__: self.open_store_variable(__rio_varname__)} + rio_vars.update(coords) + return FrozenOrderedDict(rio_vars) + + def get_attrs(self): + return Frozen(self._attrs) + + def get_dimensions(self): + return Frozen(self.ds.dims) + + def close(self): + self.ds.close() + + +def _transform_proj(p1, p2, x, y, nocopy=False): + """Wrapper around the pyproj transform. + When two projections are equal, this function avoids quite a bunch of + useless calculations. See https://github.com/jswhit/pyproj/issues/15 + """ + import pyproj + import copy + + if p1.srs == p2.srs: + if nocopy: + return x, y + else: + return copy.deepcopy(x), copy.deepcopy(y) + + return pyproj.transform(p1, p2, x, y) + + +def _try_to_get_latlon_coords(coords, attrs): + coords_out = {} + try: + import pyproj + except ImportError: + pyproj = False + if 'crs' in attrs and pyproj: + proj = pyproj.Proj(attrs['crs']) + x, y = np.meshgrid(coords['x'], coords['y']) + proj_out = pyproj.Proj("+init=EPSG:4326", preserve_units=True) + xc, yc = _transform_proj(proj, proj_out, x, y) + dims = ('y', 'x') + + coords_out['lat'] = Variable(dims,yc,attrs={'units': 'degrees_north', 'long_name': 'latitude', + 'standard_name': 'latitude'}) + coords_out['lon'] = Variable(dims,xc,attrs={'units': 'degrees_east', 'long_name': 'longitude', + 'standard_name': 'longitude'}) + return coords_out diff --git a/xarray/test/__init__.py b/xarray/test/__init__.py index 047a247bbb7..2cff7ab495f 100644 --- a/xarray/test/__init__.py +++ b/xarray/test/__init__.py @@ -47,6 +47,13 @@ has_pynio = False +try: + import rasterio + has_rasterio = True +except ImportError: + has_rasterio = False + + try: import dask.array import dask @@ -90,6 +97,10 @@ def requires_pynio(test): return test if has_pynio else unittest.skip('requires pynio')(test) +def requires_rasterio(test): + return test if has_rasterio else unittest.skip('requires rasterio')(test) + + def requires_scipy_or_netCDF4(test): return (test if has_scipy or has_netCDF4 else unittest.skip('requires scipy or netCDF4')(test)) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index eeb5561579b..480bf645f38 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -21,7 +21,7 @@ from . import (TestCase, requires_scipy, requires_netCDF4, requires_pydap, requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf, - requires_pynio, has_netCDF4, has_scipy) + requires_pynio, requires_rasterio, has_netCDF4, has_scipy) from .test_dataset import create_test_data try: @@ -1063,6 +1063,35 @@ def test_weakrefs(self): self.assertDatasetIdentical(actual, expected) +@requires_rasterio +class TestRasterIO(CFEncodedDataTest, Only32BitTypes, TestCase): + def test_write_store(self): + # rasterio is read-only for now + pass + + def test_orthogonal_indexing(self): + # rasterio also does not support list-like indexing + pass + + @contextlib.contextmanager + def roundtrip(self, data, save_kwargs={}, open_kwargs={}): + with create_tmp_file() as tmp_file: + data.to_netcdf(tmp_file, engine='scipy', **save_kwargs) + with open_dataset(tmp_file, engine='rasterio', **open_kwargs) as ds: + yield ds + + def test_weakrefs(self): + example = Dataset({'foo': ('x', np.arange(5.0))}) + expected = example.rename({'foo': 'bar', 'x': 'y'}) + + with create_tmp_file() as tmp_file: + example.to_netcdf(tmp_file, engine='scipy') + on_disk = open_dataset(tmp_file, engine='rasterio') + actual = on_disk.rename({'foo': 'bar', 'x': 'y'}) + del on_disk # trigger garbage collection + self.assertDatasetIdentical(actual, expected) + + class TestEncodingInvalid(TestCase): def test_extract_nc4_encoding(self):