Skip to content

Commit

Permalink
stop using clever getattr hack
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Hamman committed May 15, 2018
1 parent 6669035 commit 097e264
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
13 changes: 6 additions & 7 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,15 +509,17 @@ def assert_open(self):
class PickleByReconstructionWrapper(object):

def __init__(self, opener, file, mode='r', **kwargs):

self.opener = partial(opener, file, mode=mode, **kwargs)
self.mode = mode
self._ds = None

@property
def value(self):
return self.opener()
self._ds = self.opener()
return self._ds

def __getstate__(self):
del self._ds
state = self.__dict__.copy()
if self.mode == 'w':
# file has already been created, don't override when restoring
Expand All @@ -527,8 +529,5 @@ def __getstate__(self):
def __setstate__(self, state):
self.__dict__.update(state)

def __getitem__(self, key):
return self.value[key]

def __getattr__(self, name):
return getattr(self.value, name)
def close(self):
self._ds.close()
51 changes: 26 additions & 25 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""

def __init__(self, rasterio_ds):
self.rasterio_ds = rasterio_ds
self._shape = (rasterio_ds.count, rasterio_ds.height,
rasterio_ds.width)
def __init__(self, riods):
self.riods = riods
self._shape = (riods.value.count, riods.value.height,
riods.value.width)
self._ndims = len(self.shape)

@property
def dtype(self):
dtypes = self.rasterio_ds.dtypes
dtypes = self.riods.value.dtypes
if not np.all(np.asarray(dtypes) == dtypes[0]):
raise ValueError('All bands should have the same dtype')
return np.dtype(dtypes[0])
Expand Down Expand Up @@ -105,7 +105,7 @@ def _get_indexer(self, key):
def __getitem__(self, key):
band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

out = self.rasterio_ds.read(band_key, window=tuple(window))
out = self.riods.value.read(band_key, window=tuple(window))
if squeeze_axis:
out = np.squeeze(out, axis=squeeze_axis)
return indexing.NumpyIndexingAdapter(out)[np_inds]
Expand Down Expand Up @@ -203,20 +203,20 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
coords = OrderedDict()

# Get bands
if riods.count < 1:
if riods.value.count < 1:
raise ValueError('Unknown dims')
coords['band'] = np.asarray(riods.indexes)
coords['band'] = np.asarray(riods.value.indexes)

# Get coordinates
if LooseVersion(rasterio.__version__) < '1.0':
transform = riods.affine
transform = riods.value.affine
else:
transform = riods.transform
transform = riods.value.transform
if transform.is_rectilinear:
# 1d coordinates
parse = True if parse_coordinates is None else parse_coordinates
if parse:
nx, ny = riods.width, riods.height
nx, ny = riods.value.width, riods.value.height
# xarray coordinates are pixel centered
x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform
_, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform
Expand All @@ -239,41 +239,42 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
# For serialization store as tuple of 6 floats, the last row being
# always (0, 0, 1) per definition (see https://github.com/sgillies/affine)
attrs['transform'] = tuple(transform)[:6]
if hasattr(riods, 'crs') and riods.crs:
if hasattr(riods.value, 'crs') and riods.value.crs:
# CRS is a dict-like object specific to rasterio
# If CRS is not None, we convert it back to a PROJ4 string using
# rasterio itself
attrs['crs'] = riods.crs.to_string()
if hasattr(riods, 'res'):
attrs['crs'] = riods.value.crs.to_string()
if hasattr(riods.value, 'res'):
# (width, height) tuple of pixels in units of CRS
attrs['res'] = riods.res
if hasattr(riods, 'is_tiled'):
attrs['res'] = riods.value.res
if hasattr(riods.value, 'is_tiled'):
# Is the TIF tiled? (bool)
# We cast it to an int for netCDF compatibility
attrs['is_tiled'] = np.uint8(riods.is_tiled)
attrs['is_tiled'] = np.uint8(riods.value.is_tiled)
with warnings.catch_warnings():
# casting riods.transform to a tuple makes this future proof
# casting riods.value.transform to a tuple makes this future proof
warnings.simplefilter('ignore', FutureWarning)
if hasattr(riods, 'transform'):
if hasattr(riods.value, 'transform'):
# Affine transformation matrix (tuple of floats)
# Describes coefficients mapping pixel coordinates to CRS
attrs['transform'] = tuple(riods.transform)
if hasattr(riods, 'nodatavals'):
attrs['transform'] = tuple(riods.value.transform)
if hasattr(riods.value, 'nodatavals'):
# The nodata values for the raster bands
attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval
for nodataval in riods.nodatavals])
for nodataval in riods.value.nodatavals])

# Parse extra metadata from tags, if supported
parsers = {'ENVI': _parse_envi}

driver = riods.driver
driver = riods.value.driver
if driver in parsers:
meta = parsers[driver](riods.tags(ns=driver))
meta = parsers[driver](riods.value.tags(ns=driver))

for k, v in meta.items():
# Add values as coordinates if they match the band count,
# as attributes otherwise
if isinstance(v, (list, np.ndarray)) and len(v) == riods.count:
if (isinstance(v, (list, np.ndarray)) and
len(v) == riods.value.count):
coords[k] = ('band', np.asarray(v))
else:
attrs[k] = v
Expand Down

0 comments on commit 097e264

Please sign in to comment.