Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add xarray/dask bilinear resampling #519

Merged
merged 7 commits into from
Dec 11, 2018
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 84 additions & 49 deletions satpy/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@
import dask
import dask.array as da

from pyresample.bilinear import get_bil_info, get_sample_from_bil_info
from pyresample.ewa import fornav, ll2cr
from pyresample.geometry import SwathDefinition, AreaDefinition
from pyresample.kd_tree import XArrayResamplerNN
from pyresample.bilinear.xarr import XArrayResamplerBilinear
from satpy import CHUNK_SIZE
from satpy.config import config_search_paths, get_config_path

Expand Down Expand Up @@ -587,76 +587,86 @@ def compute(self, data, cache_id=None, fill_value=0, weight_count=10000,


class BilinearResampler(BaseResampler):

"""Resample using bilinear."""

def precompute(self, mask=None, radius_of_influence=50000,
cache_dir=None, **kwargs):
def __init__(self, source_geo_def, target_geo_def):
super(BilinearResampler, self).__init__(source_geo_def, target_geo_def)
self.resampler = None

def precompute(self, mask=None, radius_of_influence=50000, epsilon=0,
reduce_data=True, nprocs=1,
cache_dir=False, **kwargs):
"""Create bilinear coefficients and store them for later use.

Note: The `mask` keyword should be provided if geolocation may be valid
where data points are invalid. This defaults to the `mask` attribute of
the `data` numpy masked array passed to the `resample` method.
"""

raise NotImplementedError("Bilinear interpolation has not been "
"converted to XArray/Dask yet.")

del kwargs
source_geo_def = self.source_geo_def

bil_hash = self.get_hash(source_geo_def=source_geo_def,
radius_of_influence=radius_of_influence,
mode="bilinear")
source_geo_def = mask_source_lonlats(self.source_geo_def, mask)

filename = self._create_cache_filename(cache_dir, bil_hash)
self._read_params_from_cache(cache_dir, bil_hash, filename)
if self.resampler is None:
kwargs = dict(source_geo_def=source_geo_def,
target_geo_def=self.target_geo_def,
radius_of_influence=radius_of_influence,
neighbours=32,
epsilon=epsilon,
reduce_data=reduce_data)

if self.cache is not None:
LOG.debug("Loaded bilinear parameters")
return self.cache
else:
LOG.debug("Computing bilinear parameters")
self.resampler = XArrayResamplerBilinear(**kwargs)

try:
self.load_bil_info(cache_dir, **kwargs)
LOG.debug("Loaded bilinear parameters")
except IOError:
LOG.debug("Computing bilinear parameters")
self.resampler.get_bil_info()
self.save_bil_info(cache_dir, **kwargs)

def load_bil_info(self, cache_dir, **kwargs):

bilinear_t, bilinear_s, input_idxs, idx_arr = get_bil_info(source_geo_def, self.target_geo_def,
radius_of_influence, neighbours=32,
masked=False)
self.cache = {'bilinear_s': bilinear_s,
'bilinear_t': bilinear_t,
'input_idxs': input_idxs,
'idx_arr': idx_arr}
if cache_dir:
filename = self._create_cache_filename(cache_dir,
prefix='resample_lut_bil_',
**kwargs)
cache = np.load(filename)
for elt in ['bilinear_s', 'bilinear_t', 'valid_input_index',
'index_array']:
if isinstance(cache[elt], tuple):
setattr(self.resampler, elt, cache[elt][0])
else:
setattr(self.resampler, elt, cache[elt])
cache.close()
else:
raise IOError

self._update_caches(bil_hash, cache_dir, filename)
def save_bil_info(self, cache_dir, **kwargs):
if cache_dir:
filename = self._create_cache_filename(cache_dir,
prefix='resample_lut_bil_',
**kwargs)
LOG.info('Saving kd_tree neighbour info to %s', filename)
cache = {'bilinear_s': self.resampler.bilinear_s,
'bilinear_t': self.resampler.bilinear_t,
'valid_input_index': self.resampler.valid_input_index,
'index_array': self.resampler.index_array}

return self.cache
np.savez(filename, **cache)

def compute(self, data, fill_value=None, **kwargs):
"""Resample the given data using bilinear interpolation"""
del kwargs

if fill_value is None:
fill_value = data.attrs.get('_FillValue')
target_shape = self.target_geo_def.shape
if data.ndim == 3:
output_shape = list(target_shape)
output_shape.append(data.shape[-1])
res = np.zeros(output_shape, dtype=data.dtype)
for i in range(data.shape[-1]):
res[:, :, i] = get_sample_from_bil_info(data[:, :, i].ravel(),
self.cache[
'bilinear_t'],
self.cache[
'bilinear_s'],
self.cache[
'input_idxs'],
self.cache['idx_arr'],
output_shape=target_shape)

else:
res = get_sample_from_bil_info(data.ravel(),
self.cache['bilinear_t'],
self.cache['bilinear_s'],
self.cache['input_idxs'],
self.cache['idx_arr'],
output_shape=target_shape)
res = np.ma.masked_invalid(res)
res = self.resampler.get_sample_from_bil_info(data,
fill_value=fill_value,
output_shape=target_shape)

return res

Expand Down Expand Up @@ -796,7 +806,7 @@ def compute(self, data, expand=True, **kwargs):
RESAMPLERS = {"kd_tree": KDTreeResampler,
"nearest": KDTreeResampler,
"ewa": EWAResampler,
# "bilinear": BilinearResampler,
"bilinear": BilinearResampler,
"native": NativeResampler,
}

Expand Down Expand Up @@ -890,3 +900,28 @@ def resample_dataset(dataset, destination_area, **kwargs):
new_data.attrs.update(area=destination_area)

return new_data


def mask_source_lonlats(source_def, mask):
    """Mask source longitudes and latitudes to match a data mask.

    Args:
        source_def: Source geometry definition. Only ``SwathDefinition``
            instances are masked; anything else is returned unchanged.
        mask: ``None``/``False`` to disable masking, a boolean array in
            which ``True`` marks points to drop, or a non-boolean array
            in which NaN marks points to drop.

    Returns:
        ``source_def`` itself when no masking applies, otherwise a new
        ``SwathDefinition`` built from ``lons.where(valid)`` and
        ``lats.where(valid)``.

    Raises:
        ValueError: If a boolean mask does not have the same number of
            dimensions as the source longitudes.
    """
    # The data may have additional masked pixels; without a usable mask,
    # or for a geometry that is not a swath, the same definition is reused.
    if (mask is None or mask is False or
            not isinstance(source_def, SwathDefinition)):
        return source_def

    # np.issubdtype/np.bool_ replace np.issubsctype/np.bool, which are not
    # available across current NumPy releases.
    if np.issubdtype(mask.dtype, np.bool_):
        # copy the source area and use it for the rest of the calculations
        LOG.debug("Copying source area to mask invalid dataset points")
        if mask.ndim != source_def.lons.ndim:
            raise ValueError("Can't mask area, mask has different number "
                             "of dimensions.")
        valid = ~mask
    else:
        # Non-boolean mask: NaN entries mark the invalid points.
        # np.isnan is used directly so the deprecated xarray.ufuncs module
        # is no longer needed.
        valid = ~np.isnan(mask)

    # assume lons and lats mask are the same
    return SwathDefinition(source_def.lons.where(valid),
                           source_def.lats.where(valid))