Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added some logic to deal with rasterio objects in addition to filepaths #2589

Merged
merged 8 commits into from
Dec 23, 2018
Next Next commit
added some logic to deal with rasterio objects in addition to filepat…
…h strings
scottyhq committed Dec 4, 2018

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 35c5afbc46a84c1e12031e2cc9a21445a7bab9a5
67 changes: 60 additions & 7 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,8 @@
import warnings
from collections import OrderedDict
from distutils.version import LooseVersion

import rasterio
from rasterio.vrt import WarpedVRT
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, optional dependencies cannot be imported at the module top level; import them inside the function scope, only where they are needed.

import numpy as np

from .. import DataArray
@@ -24,11 +25,13 @@
class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""

def __init__(self, manager):
def __init__(self, manager, vrt=None):
self.manager = manager

# cannot save riods as an attribute: this would break pickleability
riods = manager.acquire()
if vrt:
riods = vrt

self._shape = (riods.count, riods.height, riods.width)

@@ -123,6 +126,39 @@ def __getitem__(self, key):
key, self.shape, indexing.IndexingSupport.OUTER, self._getitem)


class RasterioVRTWrapper(RasterioArrayWrapper):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than adding a subclass with a lot of duplicated logic, could you add this into the base RasterioArrayWrapper class?

Something like:

def __init__(self, manager, vrt_params=None):
    ...
    riods = riods if vrt_params is None else WarpedVRT(riods, **vrt_params)

Copy link
Contributor Author

@scottyhq scottyhq Dec 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, just moved the changes to RasterioArrayWrapper

"""A wrapper around rasterio WarpedVRT objects"""
def __init__(self, manager, vrt_params):
    """Record the file manager and VRT options, then cache shape and dtype.

    Parameters
    ----------
    manager : object
        File manager whose ``acquire()`` returns the open rasterio dataset.
    vrt_params : mapping
        Keyword arguments unpacked into ``WarpedVRT`` together with the
        acquired dataset.

    Raises
    ------
    ValueError
        If the bands reported by the VRT do not all share one dtype.
    """
    self.manager = manager
    self.vrt_params = vrt_params

    # The dataset handle is kept local on purpose: storing it as an
    # attribute would break pickleability.
    source = manager.acquire()
    warped = WarpedVRT(source, **vrt_params)
    self._shape = (warped.count, warped.height, warped.width)

    band_dtypes = warped.dtypes
    first_dtype = band_dtypes[0]
    all_same = np.all(np.asarray(band_dtypes) == first_dtype)
    if not all_same:
        raise ValueError('All bands should have the same dtype')
    self._dtype = np.dtype(first_dtype)

def _getitem(self, key):
    """Return the data selected by ``key``, reading through a WarpedVRT.

    ``key`` is decomposed by ``self._get_indexer`` (defined outside this
    class body -- presumably on the parent wrapper) into a band selection,
    a sequence of ``(start, stop)`` window bounds, the axes to squeeze and
    a final in-memory indexer.
    """
    band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

    # IO is only needed when at least one band is requested and every
    # window dimension is non-empty.
    needs_io = bool(band_key) and not any(
        lower == upper for (lower, upper) in window)

    if needs_io:
        # A new VRT is built from the stored parameters on every read.
        source = self.manager.acquire()
        warped = WarpedVRT(source, **self.vrt_params)
        out = warped.read(band_key, window=window)
    else:
        # Nothing to read: build a zero-filled array of the requested shape.
        extents = tuple(upper - lower for (lower, upper) in window)
        out = np.zeros((len(band_key),) + extents, dtype=self.dtype)

    if squeeze_axis:
        out = np.squeeze(out, axis=squeeze_axis)
    return out[np_inds]

def _parse_envi(meta):
"""Parse ENVI metadata into Python data structures.

@@ -176,8 +212,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,

Parameters
----------
filename : str
Path to the file to open.
filename : str, rasterio.DatasetReader, or rasterio.WarpedVRT
Path to the file to open. Or already open rasterio dataset.
parse_coordinates : bool, optional
Whether to parse the x and y coordinates out of the file's
``transform`` attribute or not. The default is to automatically
@@ -204,12 +240,26 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
data : DataArray
The newly created DataArray.
"""

import rasterio
vrt_params = None
if isinstance(filename, rasterio.io.DatasetReader):
filename = filename.name
elif isinstance(filename, rasterio.vrt.WarpedVRT):
vrt = filename
filename = vrt.src_dataset.name
#crs = vrt.crs.to_string()
vrt_params = dict(crs=vrt.crs.to_string(),
resampling=vrt.resampling,
src_nodata=vrt.src_nodata,
dst_nodata=vrt.dst_nodata,
tolerance=vrt.tolerance,
warp_extras=vrt.warp_extras)

manager = CachingFileManager(rasterio.open, filename, mode='r')
riods = manager.acquire()

if vrt_params:
riods = WarpedVRT(riods, **vrt_params)

if cache is None:
cache = chunks is None

@@ -288,7 +338,10 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
else:
attrs[k] = v

data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager))
if vrt_params:
data = indexing.LazilyOuterIndexedArray(RasterioVRTWrapper(manager, vrt_params))
else:
data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager))

# this lets you write arrays loaded with rasterio
data = indexing.CopyOnWriteArray(data)
36 changes: 36 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -3057,6 +3057,42 @@ def test_http_url(self):
import dask.array as da
assert isinstance(actual.data, da.Array)

def test_rasterio_environment(self):
    """Opening a temporary GeoTIFF inside a restrictive ``rasterio.Env``.

    The environment sets ``CPL_VSIL_CURL_ALLOWED_EXTENSIONS='H5'``; the
    test passes only if an exception is raised somewhere inside it.
    """
    with create_tmp_geotiff() as (tmp_file, expected):
        # Should fail with an error since the file suffix is not allowed.
        # The Env is created after pytest.raises is entered, so a failure
        # while setting up the environment is captured as well.
        with pytest.raises(Exception), \
                rasterio.Env(CPL_VSIL_CURL_ALLOWED_EXTENSIONS='H5'):
            with xr.open_rasterio(tmp_file) as actual:
                assert_allclose(actual, expected)

def test_rasterio_vrt(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For network tests we have a special decorator (@network); see https://github.com/pydata/xarray/blob/55f21deff4c2b42bd6ead4dbe26a1b123337913a/xarray/tests/test_tutorial.py (although that seems to be its only use?)

url = 'https://storage.googleapis.com/\
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't even know this form of line-wrapping for strings was possible in Python :)

gcp-public-data-landsat/LC08/01/047/027/\
LC08_L1TP_047027_20130421_20170310_01_T1/\
LC08_L1TP_047027_20130421_20170310_01_T1_B4.TIF'
env = rasterio.Env(GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR',
CPL_VSIL_CURL_USE_HEAD=False,
CPL_VSIL_CURL_ALLOWED_EXTENSIONS='TIF')
with env:
with rasterio.open(url) as src:
with WarpedVRT(src, crs='epsg:4326') as vrt:
expected_shape = (vrt.width, vrt.height)
expected_crs = vrt.crs
expected_res = vrt.res
# Value of single pixel in center of image
lon,lat = vrt.xy(vrt.width // 2, vrt.height // 2)
expected_val = next(vrt.sample([(lon, lat)]))
with xr.open_rasterio(vrt) as da:
actual_shape = (da.sizes['x'], da.sizes['y'])
actual_crs = da.crs
actual_res = da.res
actual_val = da.sel(dict(x=lon, y=lat),
method='nearest').data

assert_equal(actual_shape, expected_shape)
assert_equal(actual_crs, expected_crs)
assert_equal(actual_res, expected_res)
assert_equal(expected_val, actual_val)

class TestEncodingInvalid(object):