diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 5e5da295186..980f996cb6d 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -41,6 +41,10 @@ Enhancements
 Bug fixes
 ~~~~~~~~~
 
+- Fixed a bug in ``rasterio`` backend which prevented use with ``distributed``.
+  The ``rasterio`` backend now returns pickleable objects (:issue:`2021`).
+  By `Joe Hamman `_.
+
 .. _whats-new.0.10.6:
 
 v0.10.6 (31 May 2018)
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 2961838e85f..d5eccd9be52 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -8,6 +8,7 @@
 import traceback
 import warnings
 from collections import Mapping, OrderedDict
+from functools import partial
 
 import numpy as np
 
@@ -507,3 +508,61 @@ def assert_open(self):
         if not self._isopen:
             raise AssertionError('internal failure: file must be open '
                                  'if `autoclose=True` is used.')
+
+
+class PickleByReconstructionWrapper(object):
+    """Wrap a file-like handle so that it survives pickling.
+
+    The handle itself is never pickled. The wrapper stores the call that
+    creates it, ``opener(file, mode=mode, **kwargs)``, drops the handle from
+    its pickled state, and reopens it lazily after being restored.
+
+    Parameters
+    ----------
+    opener : callable
+        Called as ``opener(file, mode=mode, **kwargs)`` to create the handle.
+    file : object
+        First positional argument passed to ``opener``.
+    mode : str, optional
+        Passed to ``opener`` as the ``mode`` keyword argument.
+    **kwargs
+        Additional keyword arguments passed to ``opener``.
+    """
+
+    def __init__(self, opener, file, mode='r', **kwargs):
+        self.opener = partial(opener, file, mode=mode, **kwargs)
+        self.mode = mode
+        self._ds = None
+
+    @property
+    def value(self):
+        """The wrapped handle, opened on first access and then reused."""
+        # Calling the opener on every access would leave one unclosed handle
+        # behind per attribute lookup.
+        if self._ds is None:
+            self._ds = self.opener()
+        return self._ds
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # the open handle is not pickled; it is reopened after restoring
+        del state['_ds']
+        if self.mode == 'w':
+            # file has already been created, don't override when restoring
+            state['mode'] = 'a'
+            # ``mode`` is also bound inside the partial, so rebind it there
+            keywords = dict(self.opener.keywords, mode='a')
+            state['opener'] = partial(self.opener.func, *self.opener.args,
+                                      **keywords)
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # always start closed; the handle is reopened on first use
+        self._ds = None
+
+    def close(self):
+        """Close the wrapped handle if one is open."""
+        if self._ds is not None:
+            self._ds.close()
+            self._ds = None
diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py
index 8c0764c3ec9..0f19a1b51be 100644
--- a/xarray/backends/rasterio_.py
+++ b/xarray/backends/rasterio_.py
@@ -8,7 +8,7 @@
 from ..
import DataArray from ..core import indexing from ..core.utils import is_scalar -from .common import BackendArray +from .common import BackendArray, PickleByReconstructionWrapper try: from dask.utils import SerializableLock as Lock @@ -25,15 +25,15 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, rasterio_ds): - self.rasterio_ds = rasterio_ds - self._shape = (rasterio_ds.count, rasterio_ds.height, - rasterio_ds.width) + def __init__(self, riods): + self.riods = riods + self._shape = (riods.value.count, riods.value.height, + riods.value.width) self._ndims = len(self.shape) @property def dtype(self): - dtypes = self.rasterio_ds.dtypes + dtypes = self.riods.value.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') return np.dtype(dtypes[0]) @@ -105,7 +105,7 @@ def _get_indexer(self, key): def __getitem__(self, key): band_key, window, squeeze_axis, np_inds = self._get_indexer(key) - out = self.rasterio_ds.read(band_key, window=tuple(window)) + out = self.riods.value.read(band_key, window=tuple(window)) if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) return indexing.NumpyIndexingAdapter(out)[np_inds] @@ -194,7 +194,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, """ import rasterio - riods = rasterio.open(filename, mode='r') + + riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r') if cache is None: cache = chunks is None @@ -202,20 +203,20 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, coords = OrderedDict() # Get bands - if riods.count < 1: + if riods.value.count < 1: raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.indexes) + coords['band'] = np.asarray(riods.value.indexes) # Get coordinates if LooseVersion(rasterio.__version__) < '1.0': - transform = riods.affine + transform = riods.value.affine else: - transform = 
riods.transform + transform = riods.value.transform if transform.is_rectilinear: # 1d coordinates parse = True if parse_coordinates is None else parse_coordinates if parse: - nx, ny = riods.width, riods.height + nx, ny = riods.value.width, riods.value.height # xarray coordinates are pixel centered x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform @@ -238,41 +239,42 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # For serialization store as tuple of 6 floats, the last row being # always (0, 0, 1) per definition (see https://github.com/sgillies/affine) attrs['transform'] = tuple(transform)[:6] - if hasattr(riods, 'crs') and riods.crs: + if hasattr(riods.value, 'crs') and riods.value.crs: # CRS is a dict-like object specific to rasterio # If CRS is not None, we convert it back to a PROJ4 string using # rasterio itself - attrs['crs'] = riods.crs.to_string() - if hasattr(riods, 'res'): + attrs['crs'] = riods.value.crs.to_string() + if hasattr(riods.value, 'res'): # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.res - if hasattr(riods, 'is_tiled'): + attrs['res'] = riods.value.res + if hasattr(riods.value, 'is_tiled'): # Is the TIF tiled? 
(bool) # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.is_tiled) + attrs['is_tiled'] = np.uint8(riods.value.is_tiled) with warnings.catch_warnings(): - # casting riods.transform to a tuple makes this future proof + # casting riods.value.transform to a tuple makes this future proof warnings.simplefilter('ignore', FutureWarning) - if hasattr(riods, 'transform'): + if hasattr(riods.value, 'transform'): # Affine transformation matrix (tuple of floats) # Describes coefficients mapping pixel coordinates to CRS - attrs['transform'] = tuple(riods.transform) - if hasattr(riods, 'nodatavals'): + attrs['transform'] = tuple(riods.value.transform) + if hasattr(riods.value, 'nodatavals'): # The nodata values for the raster bands attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval - for nodataval in riods.nodatavals]) + for nodataval in riods.value.nodatavals]) # Parse extra metadata from tags, if supported parsers = {'ENVI': _parse_envi} - driver = riods.driver + driver = riods.value.driver if driver in parsers: - meta = parsers[driver](riods.tags(ns=driver)) + meta = parsers[driver](riods.value.tags(ns=driver)) for k, v in meta.items(): # Add values as coordinates if they match the band count, # as attributes otherwise - if isinstance(v, (list, np.ndarray)) and len(v) == riods.count: + if (isinstance(v, (list, np.ndarray)) and + len(v) == riods.value.count): coords[k] = ('band', np.asarray(v)) else: attrs[k] = v diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0e6151b2db5..df7ed66f4fd 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -19,7 +19,8 @@ from xarray import ( DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, save_mfdataset) -from xarray.backends.common import robust_getitem +from xarray.backends.common import (robust_getitem, + PickleByReconstructionWrapper) from xarray.backends.netCDF4_ import 
_extract_nc4_variable_encoding
 from xarray.backends.pydap_ import PydapDataStore
 from xarray.core import indexing
@@ -2724,7 +2725,8 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3,
     # yields a temporary geotiff file and a corresponding expected DataArray
     import rasterio
     from rasterio.transform import from_origin
-    with create_tmp_file(suffix='.tif') as tmp_file:
+    with create_tmp_file(suffix='.tif',
+                         allow_cleanup_failure=ON_WINDOWS) as tmp_file:
         # allow 2d or 3d shapes
         if nz == 1:
             data_shape = ny, nx
@@ -2996,6 +2998,14 @@ def test_chunks(self):
                 ex = expected.sel(band=1).mean(dim='x')
                 assert_allclose(ac, ex)
 
+    def test_pickle_rasterio(self):
+        # regression test for https://github.com/pydata/xarray/issues/2121
+        with create_tmp_geotiff() as (tmp_file, expected):
+            with xr.open_rasterio(tmp_file) as rioda:
+                temp = pickle.dumps(rioda)
+                with pickle.loads(temp) as actual:
+                    assert_equal(actual, rioda)
+
     def test_ENVI_tags(self):
         rasterio = pytest.importorskip('rasterio', minversion='1.0a')
         from rasterio.transform import from_origin
@@ -3260,3 +3270,33 @@ def test_dataarray_to_netcdf_no_name_pathlib(self):
 
             with open_dataarray(tmp) as loaded_da:
                 assert_identical(original_da, loaded_da)
+
+
+def test_pickle_reconstructor():
+    """The wrapper reopens its file after every pickle roundtrip."""
+
+    lines = ['foo bar spam eggs']
+
+    with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp:
+        with open(tmp, 'w') as f:
+            f.writelines(lines)
+
+        obj = PickleByReconstructionWrapper(open, tmp)
+
+        assert obj.value.readlines() == lines
+
+        p_obj = pickle.dumps(obj)
+        # close through the wrapper so that the handle just read from is
+        # the one released (Windows cannot delete a file that is still open)
+        obj.close()
+        obj2 = pickle.loads(p_obj)
+
+        assert obj2.value.readlines() == lines
+
+        # roundtrip again to make sure we can fully restore the state
+        p_obj2 = pickle.dumps(obj2)
+        obj2.close()  # for windows
+        obj3 = pickle.loads(p_obj2)
+
+        assert obj3.value.readlines() == lines
+        obj3.close()  # for windows
diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py
index 0ac03327494..8679e892be4 100644
--- a/xarray/tests/test_distributed.py
+++ b/xarray/tests/test_distributed.py
@@ -17,13 +17,14 @@
 from distributed.client import futures_of
 
 import xarray as xr
-from xarray.tests.test_backends import ON_WINDOWS, create_tmp_file
+from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file,
+                                        create_tmp_geotiff)
 from xarray.tests.test_dataset import create_test_data
 from xarray.backends.common import HDF5_LOCK, CombinedLock
 
 from . import (
-    assert_allclose, has_h5netcdf, has_netCDF4, has_scipy, requires_zarr,
-    raises_regex)
+    assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy,
+    requires_zarr, raises_regex)
 
 # this is to stop isort throwing errors. May have been easier to just use
 # `isort:skip` in retrospect
@@ -136,6 +137,17 @@ def test_dask_distributed_zarr_integration_test(loop):
                     assert_allclose(original, computed)
 
 
+@requires_rasterio
+def test_dask_distributed_rasterio_integration_test(loop):
+    with create_tmp_geotiff() as (tmp_file, expected):
+        with cluster() as (s, [a, b]):
+            with Client(s['address'], loop=loop) as c:
+                da_tiff = xr.open_rasterio(tmp_file, chunks={'band': 1})
+                assert isinstance(da_tiff.data, da.Array)
+                actual = da_tiff.compute()
+                assert_allclose(actual, expected)
+
+
 @pytest.mark.skipif(distributed.__version__ <= '1.19.3',
                     reason='Need recent distributed version to clean up get')
 @gen_cluster(client=True, timeout=None)