From 097e2643ecfbd1ea69464e772136fe6a171a501f Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Tue, 15 May 2018 15:18:25 -0700 Subject: [PATCH] stop using clever getattr hack --- xarray/backends/common.py | 13 +++++---- xarray/backends/rasterio_.py | 51 ++++++++++++++++++------------------ 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index e9d9806df27..c9d022ddc61 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -509,15 +509,17 @@ def assert_open(self): class PickleByReconstructionWrapper(object): def __init__(self, opener, file, mode='r', **kwargs): - self.opener = partial(opener, file, mode=mode, **kwargs) self.mode = mode + self._ds = None @property def value(self): - return self.opener() + self._ds = self.opener() + return self._ds def __getstate__(self): + del self._ds state = self.__dict__.copy() if self.mode == 'w': # file has already been created, don't override when restoring @@ -527,8 +529,5 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) - def __getitem__(self, key): - return self.value[key] - - def __getattr__(self, name): - return getattr(self.value, name) + def close(self): + self._ds.close() diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 482ae420679..0f19a1b51be 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -25,15 +25,15 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, rasterio_ds): - self.rasterio_ds = rasterio_ds - self._shape = (rasterio_ds.count, rasterio_ds.height, - rasterio_ds.width) + def __init__(self, riods): + self.riods = riods + self._shape = (riods.value.count, riods.value.height, + riods.value.width) self._ndims = len(self.shape) @property def dtype(self): - dtypes = self.rasterio_ds.dtypes + dtypes = self.riods.value.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') return np.dtype(dtypes[0]) @@ -105,7 +105,7 @@ def _get_indexer(self, key): def __getitem__(self, key): band_key, window, squeeze_axis, np_inds = self._get_indexer(key) - out = self.rasterio_ds.read(band_key, window=tuple(window)) + out = self.riods.value.read(band_key, window=tuple(window)) if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) return indexing.NumpyIndexingAdapter(out)[np_inds] @@ -203,20 +203,20 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, coords = OrderedDict() # Get bands - if riods.count < 1: + if riods.value.count < 1: raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.indexes) + coords['band'] = np.asarray(riods.value.indexes) # Get coordinates if LooseVersion(rasterio.__version__) < '1.0': - transform = riods.affine + transform = riods.value.affine else: - transform = riods.transform + transform = riods.value.transform if transform.is_rectilinear: # 1d coordinates parse = True if parse_coordinates is None else parse_coordinates if parse: - nx, ny = riods.width, riods.height + nx, ny = riods.value.width, riods.value.height # xarray coordinates are pixel centered x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform @@ -239,41 +239,42 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # For serialization store as tuple of 6 floats, the last row being # always (0, 0, 1) per definition (see https://github.com/sgillies/affine) attrs['transform'] = tuple(transform)[:6] - if hasattr(riods, 'crs') and riods.crs: + if hasattr(riods.value, 'crs') and riods.value.crs: # CRS is a dict-like object specific to rasterio # If CRS is not None, we convert it back to a PROJ4 string using # rasterio itself - attrs['crs'] = riods.crs.to_string() - if hasattr(riods, 'res'): + attrs['crs'] = riods.value.crs.to_string() + if hasattr(riods.value, 'res'): # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.res - if hasattr(riods, 'is_tiled'): + attrs['res'] = riods.value.res + if hasattr(riods.value, 'is_tiled'): # Is the TIF tiled? (bool) # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.is_tiled) + attrs['is_tiled'] = np.uint8(riods.value.is_tiled) with warnings.catch_warnings(): - # casting riods.transform to a tuple makes this future proof + # casting riods.value.transform to a tuple makes this future proof warnings.simplefilter('ignore', FutureWarning) - if hasattr(riods, 'transform'): + if hasattr(riods.value, 'transform'): # Affine transformation matrix (tuple of floats) # Describes coefficients mapping pixel coordinates to CRS - attrs['transform'] = tuple(riods.transform) - if hasattr(riods, 'nodatavals'): + attrs['transform'] = tuple(riods.value.transform) + if hasattr(riods.value, 'nodatavals'): # The nodata values for the raster bands attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval - for nodataval in riods.nodatavals]) + for nodataval in riods.value.nodatavals]) # Parse extra metadata from tags, if supported parsers = {'ENVI': _parse_envi} - driver = riods.driver + driver = riods.value.driver if driver in parsers: - meta = parsers[driver](riods.tags(ns=driver)) + meta = parsers[driver](riods.value.tags(ns=driver)) for k, v in meta.items(): # Add values as coordinates if they match the band count, # as attributes otherwise - if isinstance(v, (list, np.ndarray)) and len(v) == riods.count: + if (isinstance(v, (list, np.ndarray)) and + len(v) == riods.value.count): coords[k] = ('band', np.asarray(v)) else: attrs[k] = v