-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added some logic to deal with rasterio objects in addition to filepaths #2589
Changes from 1 commit
35c5afb
35f9947
ce56d35
cd6fcb9
a113b50
de21485
dde9428
ea068fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
…h strings
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,8 @@ | |
import warnings | ||
from collections import OrderedDict | ||
from distutils.version import LooseVersion | ||
|
||
import rasterio | ||
from rasterio.vrt import WarpedVRT | ||
import numpy as np | ||
|
||
from .. import DataArray | ||
|
@@ -24,11 +25,13 @@ | |
class RasterioArrayWrapper(BackendArray): | ||
"""A wrapper around rasterio dataset objects""" | ||
|
||
def __init__(self, manager): | ||
def __init__(self, manager, vrt=None): | ||
self.manager = manager | ||
|
||
# cannot save riods as an attribute: this would break pickleability | ||
riods = manager.acquire() | ||
if vrt: | ||
riods = vrt | ||
|
||
self._shape = (riods.count, riods.height, riods.width) | ||
|
||
|
@@ -123,6 +126,39 @@ def __getitem__(self, key): | |
key, self.shape, indexing.IndexingSupport.OUTER, self._getitem) | ||
|
||
|
||
class RasterioVRTWrapper(RasterioArrayWrapper): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than adding a subclass with a lot of duplicated logic, could you add this into the base Something like: def __init__(self, manager, vrt_params=None):
...
riods = riods if vrt_params is None else WarpedVRT(riods, **vrt_params) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, just moved the changes to RasterioArrayWrapper |
||
"""A wrapper around rasterio WarpedVRT objects""" | ||
def __init__(self, manager, vrt_params): | ||
#print('Using VRT Wrapper') | ||
self.manager = manager | ||
self.vrt_params = vrt_params | ||
# cannot save riods as an attribute: this would break pickleability | ||
riods = manager.acquire() | ||
vrt = WarpedVRT(riods, **vrt_params) | ||
self._shape = (vrt.count, vrt.height, vrt.width) | ||
|
||
dtypes = vrt.dtypes | ||
if not np.all(np.asarray(dtypes) == dtypes[0]): | ||
raise ValueError('All bands should have the same dtype') | ||
self._dtype = np.dtype(dtypes[0]) | ||
|
||
def _getitem(self, key): | ||
band_key, window, squeeze_axis, np_inds = self._get_indexer(key) | ||
|
||
if not band_key or any(start == stop for (start, stop) in window): | ||
# no need to do IO | ||
shape = (len(band_key),) + tuple( | ||
stop - start for (start, stop) in window) | ||
out = np.zeros(shape, dtype=self.dtype) | ||
else: | ||
riods = self.manager.acquire() | ||
vrt = WarpedVRT(riods, **self.vrt_params) | ||
out = vrt.read(band_key, window=window) | ||
|
||
if squeeze_axis: | ||
out = np.squeeze(out, axis=squeeze_axis) | ||
return out[np_inds] | ||
|
||
def _parse_envi(meta): | ||
"""Parse ENVI metadata into Python data structures. | ||
|
||
|
@@ -176,8 +212,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, | |
|
||
Parameters | ||
---------- | ||
filename : str | ||
Path to the file to open. | ||
filename : str, rasterio.DatasetReader, or rasterio.WarpedVRT | ||
Path to the file to open. Or already open rasterio dataset. | ||
parse_coordinates : bool, optional | ||
Whether to parse the x and y coordinates out of the file's | ||
``transform`` attribute or not. The default is to automatically | ||
|
@@ -204,12 +240,26 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, | |
data : DataArray | ||
The newly created DataArray. | ||
""" | ||
|
||
import rasterio | ||
vrt_params = None | ||
if isinstance(filename, rasterio.io.DatasetReader): | ||
filename = filename.name | ||
elif isinstance(filename, rasterio.vrt.WarpedVRT): | ||
vrt = filename | ||
filename = vrt.src_dataset.name | ||
#crs = vrt.crs.to_string() | ||
vrt_params = dict(crs=vrt.crs.to_string(), | ||
resampling=vrt.resampling, | ||
src_nodata=vrt.src_nodata, | ||
dst_nodata=vrt.dst_nodata, | ||
tolerance=vrt.tolerance, | ||
warp_extras=vrt.warp_extras) | ||
|
||
manager = CachingFileManager(rasterio.open, filename, mode='r') | ||
riods = manager.acquire() | ||
|
||
if vrt_params: | ||
riods = WarpedVRT(riods, **vrt_params) | ||
|
||
if cache is None: | ||
cache = chunks is None | ||
|
||
|
@@ -288,7 +338,10 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, | |
else: | ||
attrs[k] = v | ||
|
||
data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) | ||
if vrt_params: | ||
data = indexing.LazilyOuterIndexedArray(RasterioVRTWrapper(manager, vrt_params)) | ||
else: | ||
data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) | ||
|
||
# this lets you write arrays loaded with rasterio | ||
data = indexing.CopyOnWriteArray(data) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3057,6 +3057,42 @@ def test_http_url(self): | |
import dask.array as da | ||
assert isinstance(actual.data, da.Array) | ||
|
||
def test_rasterio_environment(self): | ||
with create_tmp_geotiff() as (tmp_file, expected): | ||
# Should fail with error since suffix not allowed | ||
with pytest.raises(Exception): | ||
with rasterio.Env(CPL_VSIL_CURL_ALLOWED_EXTENSIONS='H5'): | ||
with xr.open_rasterio(tmp_file) as actual: | ||
assert_allclose(actual, expected) | ||
|
||
def test_rasterio_vrt(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For network tests we have a special decorator ( |
||
url = 'https://storage.googleapis.com/\ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't even know this form of line-wrapping for strings was possible in Python :) |
||
gcp-public-data-landsat/LC08/01/047/027/\ | ||
LC08_L1TP_047027_20130421_20170310_01_T1/\ | ||
LC08_L1TP_047027_20130421_20170310_01_T1_B4.TIF' | ||
env = rasterio.Env(GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR', | ||
CPL_VSIL_CURL_USE_HEAD=False, | ||
CPL_VSIL_CURL_ALLOWED_EXTENSIONS='TIF') | ||
with env: | ||
with rasterio.open(url) as src: | ||
with WarpedVRT(src, crs='epsg:4326') as vrt: | ||
expected_shape = (vrt.width, vrt.height) | ||
expected_crs = vrt.crs | ||
expected_res = vrt.res | ||
# Value of single pixel in center of image | ||
lon,lat = vrt.xy(vrt.width // 2, vrt.height // 2) | ||
expected_val = next(vrt.sample([(lon, lat)])) | ||
with xr.open_rasterio(vrt) as da: | ||
actual_shape = (da.sizes['x'], da.sizes['y']) | ||
actual_crs = da.crs | ||
actual_res = da.res | ||
actual_val = da.sel(dict(x=lon, y=lat), | ||
method='nearest').data | ||
|
||
assert_equal(actual_shape, expected_shape) | ||
assert_equal(actual_crs, expected_crs) | ||
assert_equal(actual_res, expected_res) | ||
assert_equal(expected_val, actual_val) | ||
|
||
class TestEncodingInvalid(object): | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Imports of optional dependencies can unfortunately not happen at the module top level - import them only when needed in the function scope