Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added some logic to deal with rasterio objects in addition to filepaths #2589

Merged
merged 8 commits into from
Dec 23, 2018
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ matrix:
- env: CONDA_ENV=py36-bottleneck-dev
- env: CONDA_ENV=py36-condaforge-rc
- env: CONDA_ENV=py36-pynio-dev
- env: CONDA_ENV=py36-rasterio-0.36
- env: CONDA_ENV=py36-rasterio
- env: CONDA_ENV=py36-zarr-dev
- env: CONDA_ENV=docs
- env: CONDA_ENV=py36-hypothesis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- scipy
- seaborn
- toolz
- rasterio=0.36.0
- rasterio>=1.0
- bottleneck
- pip:
- coveralls
Expand Down
2 changes: 1 addition & 1 deletion doc/installing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ For netCDF and IO
for accessing CAMx, GEOS-Chem (bpch), NOAA ARL files, ICARTT files
(ffi1001) and many other.
- `rasterio <https://github.com/mapbox/rasterio>`__: for reading GeoTiffs and
other gridded raster datasets.
other gridded raster datasets. (version 1.0 or later)
- `iris <https://github.com/scitools/iris>`__: for conversion to and from iris'
Cube objects
- `cfgrib <https://github.com/ecmwf/cfgrib>`__: for reading GRIB files via the
Expand Down
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ v0.11.1 (unreleased)
Breaking changes
~~~~~~~~~~~~~~~~

- Minimum rasterio version increased from 0.36 to 1.0 (for ``open_rasterio``)
- Time bounds variables are now also decoded according to CF conventions
(:issue:`2565`). The previous behavior was to decode them only if they
had specific time attributes, now these attributes are copied
Expand All @@ -49,6 +50,10 @@ Enhancements
- :py:class:`CFTimeIndex` uses slicing for string indexing when possible (like
:py:class:`pandas.DatetimeIndex`), which avoids unnecessary copies.
By `Stephan Hoyer <https://github.com/shoyer>`_
- Enable passing ``rasterio.io.DatasetReader`` or ``rasterio.vrt.WarpedVRT`` to
``open_rasterio`` instead of file path string. Allows for in-memory
reprojection, see (:issue:`2588`).
By `Scott Henderson <https://github.com/scottyhq>`_.
- Like :py:class:`pandas.DatetimeIndex`, :py:class:`CFTimeIndex` now supports
"dayofyear" and "dayofweek" accessors (:issue:`2597`). By `Spencer Clark
<https://github.com/spencerkclark>`_.
Expand Down
35 changes: 25 additions & 10 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import warnings
from collections import OrderedDict
from distutils.version import LooseVersion

import numpy as np

from .. import DataArray
Expand All @@ -23,13 +22,14 @@

class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""

def __init__(self, manager):
def __init__(self, manager, vrt_params=None):
from rasterio.vrt import WarpedVRT
self.manager = manager

# cannot save riods as an attribute: this would break pickleability
riods = manager.acquire()

riods = riods if vrt_params is None else WarpedVRT(riods, **vrt_params)
self.vrt_params = vrt_params
self._shape = (riods.count, riods.height, riods.width)

dtypes = riods.dtypes
Expand Down Expand Up @@ -103,6 +103,7 @@ def _get_indexer(self, key):
return band_key, tuple(window), tuple(squeeze_axis), tuple(np_inds)

def _getitem(self, key):
from rasterio.vrt import WarpedVRT
band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

if not band_key or any(start == stop for (start, stop) in window):
Expand All @@ -112,6 +113,7 @@ def _getitem(self, key):
out = np.zeros(shape, dtype=self.dtype)
else:
riods = self.manager.acquire()
riods = riods if self.vrt_params is None else WarpedVRT(riods,**self.vrt_params)
out = riods.read(band_key, window=window)

if squeeze_axis:
Expand Down Expand Up @@ -176,8 +178,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,

Parameters
----------
filename : str
Path to the file to open.
filename : str, rasterio.DatasetReader, or rasterio.WarpedVRT
Path to the file to open. Or already open rasterio dataset.
parse_coordinates : bool, optional
Whether to parse the x and y coordinates out of the file's
``transform`` attribute or not. The default is to automatically
Expand All @@ -204,11 +206,24 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
data : DataArray
The newly created DataArray.
"""

import rasterio
from rasterio.vrt import WarpedVRT
vrt_params = None
if isinstance(filename, rasterio.io.DatasetReader):
filename = filename.name
elif isinstance(filename, rasterio.vrt.WarpedVRT):
vrt = filename
filename = vrt.src_dataset.name
vrt_params = dict(crs=vrt.crs.to_string(),
resampling=vrt.resampling,
src_nodata=vrt.src_nodata,
dst_nodata=vrt.dst_nodata,
tolerance=vrt.tolerance,
warp_extras=vrt.warp_extras)

manager = CachingFileManager(rasterio.open, filename, mode='r')
riods = manager.acquire()
riods = riods if vrt_params is None else WarpedVRT(riods, **vrt_params)

if cache is None:
cache = chunks is None
Expand Down Expand Up @@ -282,13 +297,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
for k, v in meta.items():
# Add values as coordinates if they match the band count,
# as attributes otherwise
if (isinstance(v, (list, np.ndarray)) and
len(v) == riods.count):
if (isinstance(v, (list, np.ndarray))
and len(v) == riods.count):
coords[k] = ('band', np.asarray(v))
else:
attrs[k] = v

data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager))
data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager, vrt_params))

# this lets you write arrays loaded with rasterio
data = indexing.CopyOnWriteArray(data)
Expand Down
113 changes: 90 additions & 23 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def check_dtypes_roundtripped(self, expected, actual):
actual_dtype = actual.variables[k].dtype
# TODO: check expected behavior for string dtypes more carefully
string_kinds = {'O', 'S', 'U'}
assert (expected_dtype == actual_dtype or
(expected_dtype.kind in string_kinds and
actual_dtype.kind in string_kinds))
assert (expected_dtype == actual_dtype
or (expected_dtype.kind in string_kinds and
actual_dtype.kind in string_kinds))

def test_roundtrip_test_data(self):
expected = create_test_data()
Expand Down Expand Up @@ -410,17 +410,17 @@ def test_roundtrip_cftime_datetime_data(self):
with self.roundtrip(expected, save_kwargs=kwds) as actual:
abs_diff = abs(actual.t.values - expected_decoded_t)
assert (abs_diff <= np.timedelta64(1, 's')).all()
assert (actual.t.encoding['units'] ==
'days since 0001-01-01 00:00:00.000000')
assert (actual.t.encoding['calendar'] ==
expected_calendar)
assert (actual.t.encoding['units']
== 'days since 0001-01-01 00:00:00.000000')
assert (actual.t.encoding['calendar']
== expected_calendar)

abs_diff = abs(actual.t0.values - expected_decoded_t0)
assert (abs_diff <= np.timedelta64(1, 's')).all()
assert (actual.t0.encoding['units'] ==
'days since 0001-01-01')
assert (actual.t.encoding['calendar'] ==
expected_calendar)
assert (actual.t0.encoding['units']
== 'days since 0001-01-01')
assert (actual.t.encoding['calendar']
== expected_calendar)

def test_roundtrip_timedelta_data(self):
time_deltas = pd.to_timedelta(['1h', '2h', 'NaT'])
Expand Down Expand Up @@ -668,24 +668,24 @@ def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn):

with self.roundtrip(decoded) as actual:
for k in decoded.variables:
assert (decoded.variables[k].dtype ==
actual.variables[k].dtype)
assert (decoded.variables[k].dtype
== actual.variables[k].dtype)
assert_allclose(decoded, actual, decode_bytes=False)

with self.roundtrip(decoded,
open_kwargs=dict(decode_cf=False)) as actual:
# TODO: this assumes that all roundtrips will first
# encode. Is that something we want to test for?
for k in encoded.variables:
assert (encoded.variables[k].dtype ==
actual.variables[k].dtype)
assert (encoded.variables[k].dtype
== actual.variables[k].dtype)
assert_allclose(encoded, actual, decode_bytes=False)

with self.roundtrip(encoded,
open_kwargs=dict(decode_cf=False)) as actual:
for k in encoded.variables:
assert (encoded.variables[k].dtype ==
actual.variables[k].dtype)
assert (encoded.variables[k].dtype
== actual.variables[k].dtype)
assert_allclose(encoded, actual, decode_bytes=False)

# make sure roundtrip encoding didn't change the
Expand Down Expand Up @@ -2621,8 +2621,8 @@ def myatts(**attrs):
'ULOD_FLAG': '-7777', 'ULOD_VALUE': 'N/A',
'LLOD_FLAG': '-8888',
'LLOD_VALUE': ('N/A, N/A, N/A, N/A, 0.025'),
'OTHER_COMMENTS': ('www-air.larc.nasa.gov/missions/etc/' +
'IcarttDataFormat.htm'),
'OTHER_COMMENTS': ('www-air.larc.nasa.gov/missions/etc/'
+ 'IcarttDataFormat.htm'),
'REVISION': 'R0',
'R0': 'No comments for this revision.',
'TFLAG': 'Start_UTC'
Expand Down Expand Up @@ -2711,8 +2711,8 @@ def test_uamiv_format_read(self):
expected = xr.Variable(('TSTEP',), data,
dict(bounds='time_bounds',
long_name=('synthesized time coordinate ' +
'from SDATE, STIME, STEP ' +
'global attributes')))
'from SDATE, STIME, STEP '
+ 'global attributes')))
actual = camxfile.variables['time']
assert_allclose(expected, actual)
camxfile.close()
Expand Down Expand Up @@ -2741,8 +2741,8 @@ def test_uamiv_format_mfread(self):
data = np.concatenate([data1] * 2, axis=0)
attrs = dict(bounds='time_bounds',
long_name=('synthesized time coordinate ' +
'from SDATE, STIME, STEP ' +
'global attributes'))
'from SDATE, STIME, STEP '
+ 'global attributes'))
expected = xr.Variable(('TSTEP',), data, attrs)
actual = camxfile.variables['time']
assert_allclose(expected, actual)
Expand Down Expand Up @@ -3158,6 +3158,73 @@ def test_http_url(self):
import dask.array as da
assert isinstance(actual.data, da.Array)

def test_rasterio_environment(self):
import rasterio
with create_tmp_geotiff() as (tmp_file, expected):
# Should fail with error since suffix not allowed
with pytest.raises(Exception):
with rasterio.Env(GDAL_SKIP='GTiff'):
with xr.open_rasterio(tmp_file) as actual:
assert_allclose(actual, expected)

def test_rasterio_vrt(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For network tests we have a special decorator (@network), see https://github.com/pydata/xarray/blob/55f21deff4c2b42bd6ead4dbe26a1b123337913a/xarray/tests/test_tutorial.py (although that's the only use of it as it seems?)

import rasterio
# tmp_file default crs is UTM: CRS({'init': 'epsg:32618'}
with create_tmp_geotiff() as (tmp_file, expected):
with rasterio.open(tmp_file) as src:
with rasterio.vrt.WarpedVRT(src, crs='epsg:4326') as vrt:
expected_shape = (vrt.width, vrt.height)
expected_crs = vrt.crs
print(expected_crs)
expected_res = vrt.res
# Value of single pixel in center of image
lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2)
expected_val = next(vrt.sample([(lon, lat)]))
with xr.open_rasterio(vrt) as da:
actual_shape = (da.sizes['x'], da.sizes['y'])
actual_crs = da.crs
print(actual_crs)
actual_res = da.res
actual_val = da.sel(dict(x=lon, y=lat),
method='nearest').data

assert actual_crs == expected_crs
assert actual_res == expected_res
assert actual_shape == expected_shape
assert expected_val.all() == actual_val.all()

@network
def test_rasterio_vrt_network(self):
import rasterio

url = 'https://storage.googleapis.com/\
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't even know this form of line-wrapping for strings was possible in Python :)

gcp-public-data-landsat/LC08/01/047/027/\
LC08_L1TP_047027_20130421_20170310_01_T1/\
LC08_L1TP_047027_20130421_20170310_01_T1_B4.TIF'
env = rasterio.Env(GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR',
CPL_VSIL_CURL_USE_HEAD=False,
CPL_VSIL_CURL_ALLOWED_EXTENSIONS='TIF')
with env:
with rasterio.open(url) as src:
with rasterio.vrt.WarpedVRT(src, crs='epsg:4326') as vrt:
expected_shape = (vrt.width, vrt.height)
expected_crs = vrt.crs
expected_res = vrt.res
# Value of single pixel in center of image
lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2)
expected_val = next(vrt.sample([(lon, lat)]))
with xr.open_rasterio(vrt) as da:
actual_shape = (da.sizes['x'], da.sizes['y'])
actual_crs = da.crs
actual_res = da.res
actual_val = da.sel(dict(x=lon, y=lat),
method='nearest').data

assert_equal(actual_shape, expected_shape)
assert_equal(actual_crs, expected_crs)
assert_equal(actual_res, expected_res)
assert_equal(expected_val, actual_val)


class TestEncodingInvalid(object):

Expand Down