forked from pydata/xarray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rasterio_.py
120 lines (93 loc) · 3.59 KB
/
rasterio_.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
try:
import rasterio
except ImportError:
rasterio = False
from .. import Variable, DataArray
from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin
from ..core import indexing
from ..core.pycompat import OrderedDict, suppress
from .common import AbstractDataStore
# Name of the single data variable exposed by RasterioDataStore; the store's
# open_store_variable rejects any other variable name.
__rio_varname__ = 'raster'
class RasterioArrayWrapper(NDArrayMixin):
    """Array wrapper exposing the pixel data of a rasterio dataset.

    The full raster is read into memory once, at construction time, via
    ``ds.read()``; all indexing then operates on that in-memory array.
    """

    def __init__(self, ds):
        # ds: an open rasterio dataset.
        self._ds = ds
        # NOTE(review): this reads every band eagerly, so no I/O is deferred
        # even when the wrapper is later wrapped in a lazy indexing adapter.
        self.array = ds.read()

    @property
    def dtype(self):
        """NumPy dtype of the first band, as reported by rasterio."""
        return np.dtype(self._ds.dtypes[0])

    def __getitem__(self, key):
        # presumably ``ndim`` is provided by NDArrayMixin from ``self.array``.
        if key == () and self.ndim == 0:
            # numpy.ndarray has no ``get_value`` method; indexing a 0-d array
            # with the empty tuple returns its single element.
            return self.array[()]
        return self.array[key]
class RasterioDataStore(AbstractDataStore):
    """Store for accessing datasets via Rasterio.

    The raster is exposed as a single variable (named by ``__rio_varname__``)
    with dimensions ``('band', 'y', 'x')``.
    """

    def __init__(self, filename, mode='r'):
        with rasterio.Env():
            self.ds = rasterio.open(filename, mode=mode)

        # Build 1-D coordinates from the dataset's bounds and resolution:
        # x starts at bounds.left and steps by +res[0]; y starts at
        # bounds.top and steps by -res[1].
        # NOTE(review): this assumes a north-up grid, and bounds.left/top are
        # presumably the outer raster edges, so these look like pixel-corner
        # positions rather than pixel centres -- confirm intended convention.
        nx, ny = self.ds.width, self.ds.height
        x0, y0 = self.ds.bounds.left, self.ds.bounds.top
        dx, dy = self.ds.res[0], -self.ds.res[1]
        self.coords = {
            'y': np.linspace(start=y0, num=ny, stop=(y0 + (ny - 1) * dy)),
            'x': np.linspace(start=x0, num=nx, stop=(x0 + (nx - 1) * dx)),
        }

        if self.ds.count >= 1:
            self.dims = ('band', 'y', 'x')
            self.coords['band'] = self.ds.indexes
        else:
            # Do not leak the open dataset when construction fails.
            self.ds.close()
            raise ValueError('unknown dims')

        # Copy georeferencing metadata. Each attribute is optional on its
        # own, so a missing one must not prevent the others from being read.
        self._attrs = OrderedDict()
        for attr_name in ['crs', 'transform', 'proj']:
            with suppress(AttributeError):
                self._attrs[attr_name] = getattr(self.ds, attr_name)

    def open_store_variable(self, var):
        """Return the raster as a lazily indexed Variable.

        Raises
        ------
        ValueError
            If ``var`` is not ``__rio_varname__``.
        """
        if var != __rio_varname__:
            raise ValueError(
                'Rasterio variables are all named %s' % __rio_varname__)
        data = indexing.LazilyIndexedArray(
            RasterioArrayWrapper(self.ds))
        return Variable(self.dims, data, self._attrs)

    def get_variables(self):
        """Return the raster variable plus lat/lon coordinates, if computable."""
        # NOTE(review): the 1-D 'x', 'y' and 'band' entries of self.coords
        # are not returned here -- confirm whether that is intended.
        coords = _try_to_get_latlon_coords(self.coords, self._attrs)
        rio_vars = {__rio_varname__: self.open_store_variable(__rio_varname__)}
        rio_vars.update(coords)
        return FrozenOrderedDict(rio_vars)

    def get_attrs(self):
        """Return the georeferencing attributes collected at open time."""
        return Frozen(self._attrs)

    def get_dimensions(self):
        """Return a read-only mapping of dimension name to size."""
        # rasterio datasets have no ``dims`` attribute; derive the sizes
        # from the band count and the raster's height and width instead.
        return Frozen(OrderedDict([('band', self.ds.count),
                                   ('y', self.ds.height),
                                   ('x', self.ds.width)]))

    def close(self):
        """Close the underlying rasterio dataset."""
        self.ds.close()
def _transform_proj(p1, p2, x, y, nocopy=False):
    """Wrapper around the pyproj transform.

    When two projections are equal, this function avoids quite a bunch of
    useless calculations. See https://github.com/jswhit/pyproj/issues/15

    Parameters
    ----------
    p1, p2 : objects with an ``srs`` attribute (pyproj.Proj instances)
        Source and destination projections.
    x, y : array-like
        Coordinates expressed in ``p1``.
    nocopy : bool, optional
        When the projections are identical, return ``x`` and ``y``
        themselves instead of deep copies.

    Returns
    -------
    tuple
        The ``(x, y)`` coordinates expressed in ``p2``.
    """
    import copy

    if p1.srs == p2.srs:
        # Identity transform: pyproj is not needed at all on this path.
        if nocopy:
            return x, y
        return copy.deepcopy(x), copy.deepcopy(y)

    # Imported only when an actual reprojection is required.
    import pyproj
    return pyproj.transform(p1, p2, x, y)
def _try_to_get_latlon_coords(coords, attrs):
    """Compute 2-D latitude/longitude coordinates for the raster grid.

    Parameters
    ----------
    coords : dict
        Must provide the 1-D ``'x'`` and ``'y'`` coordinate arrays when a
        CRS is present.
    attrs : mapping
        Dataset attributes; a non-empty ``'crs'`` entry enables the transform.

    Returns
    -------
    dict
        ``{'lat': Variable, 'lon': Variable}`` on dims ``('y', 'x')``, or an
        empty dict when there is no usable CRS or pyproj is not installed.
    """
    coords_out = {}

    # An absent or empty CRS (e.g. a raster without georeferencing) cannot
    # be turned into a pyproj.Proj, so skip the transform instead of failing.
    crs = attrs.get('crs')
    if not crs:
        return coords_out

    try:
        import pyproj
    except ImportError:
        return coords_out

    proj = pyproj.Proj(crs)
    x, y = np.meshgrid(coords['x'], coords['y'])
    proj_out = pyproj.Proj("+init=EPSG:4326", preserve_units=True)
    xc, yc = _transform_proj(proj, proj_out, x, y)
    dims = ('y', 'x')
    coords_out['lat'] = Variable(dims, yc,
                                 attrs={'units': 'degrees_north',
                                        'long_name': 'latitude',
                                        'standard_name': 'latitude'})
    coords_out['lon'] = Variable(dims, xc,
                                 attrs={'units': 'degrees_east',
                                        'long_name': 'longitude',
                                        'standard_name': 'longitude'})
    return coords_out