Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support building regridders with masks #219

Merged
merged 30 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0b1cb1d
support building regridders with masks
stephenworsley Oct 20, 2022
6163ff8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2022
6dda5c2
handle contiguous bounds for masked data
stephenworsley Oct 20, 2022
73a30aa
improve documentation
stephenworsley Oct 21, 2022
3035301
test for discontiguities
stephenworsley Oct 21, 2022
07daffd
support mask fetching
stephenworsley Oct 21, 2022
0c76b42
clarify boolean arguments
stephenworsley Oct 21, 2022
79ba39c
extend mask support
stephenworsley Mar 1, 2023
de74618
add tests
stephenworsley Mar 1, 2023
8947903
add load/save support and improve docstrings
stephenworsley Mar 9, 2023
d3f543a
additional fixes
stephenworsley Mar 10, 2023
2b248f0
address review comments
stephenworsley Mar 13, 2023
5566ecf
address review comments
stephenworsley Mar 16, 2023
bded83b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2023
d029ae0
fix errors
stephenworsley Mar 16, 2023
748d560
fix errors
stephenworsley Mar 16, 2023
0bf03c8
parameterise test
stephenworsley Mar 17, 2023
e6a4ed8
address review comments
stephenworsley Mar 21, 2023
4287b5f
Apply suggestions from code review
stephenworsley Mar 21, 2023
4987e26
fix test
stephenworsley Mar 21, 2023
891bc80
simplify tests
stephenworsley Mar 21, 2023
62baf48
fix tests
stephenworsley Mar 21, 2023
c8db599
add test
stephenworsley Mar 22, 2023
6f700b6
parameterize test
stephenworsley Mar 22, 2023
be71828
flake8 fix
stephenworsley Mar 22, 2023
9659cd0
flake8 fix
stephenworsley Mar 22, 2023
eb1da3e
lazy mask calculations
stephenworsley Mar 22, 2023
b309b4a
Apply suggestions from code review
stephenworsley Mar 22, 2023
441e9c6
address review comments
stephenworsley Mar 22, 2023
e06c20e
add to changelog
stephenworsley Mar 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion benchmarks/benchmarks/esmf_regridder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
MeshToGridESMFRegridder,
)
from esmf_regrid.schemes import ESMFAreaWeightedRegridder
from ..generate_data import _grid_cube, _gridlike_mesh_cube
from ..generate_data import _curvilinear_cube, _grid_cube, _gridlike_mesh_cube


def _make_small_grid_args():
Expand Down Expand Up @@ -452,3 +452,38 @@ def time_save(self, _, tp, rgt):
def time_load(self, _, tp, rgt):
    """Benchmark how long it takes to load the regridder from ``self.source_file``."""
    # The loaded regridder is not needed; only the call itself is timed.
    self.load_regridder(self.source_file)


class TimeMaskedRegridding:
    """
    Benchmarks for :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`.

    Times the construction of a regridder from a source and a target cube
    that both hold masked data, where the bounds of the masked cells have
    been partly overwritten.
    """

    def setup(self):
        """ASV setup method."""
        src = _curvilinear_cube(250, 251, [-180, 180], [-90, 90])
        tgt = _curvilinear_cube(251, 250, [-180, 180], [-90, 90])

        # Make src discontiguous along index 0 of its first dimension: mask
        # those cells and overwrite the first two bound vertices of each.
        src_mask = np.zeros([250, 251], dtype=bool)
        src_mask[0, :] = True
        src.data = np.ma.array(src.data, mask=src_mask)
        src.coord("latitude").bounds[0, :, :2] = 0
        src.coord("longitude").bounds[0, :, :2] = 0

        # Make tgt discontiguous along index 0 of its second dimension: mask
        # those cells and overwrite every third bound vertex of each.
        tgt_mask = np.zeros([251, 250], dtype=bool)
        tgt_mask[:, 0] = True
        tgt.data = np.ma.array(tgt.data, mask=tgt_mask)
        tgt.coord("latitude").bounds[:, 0, ::3] = 0
        tgt.coord("longitude").bounds[:, 0, ::3] = 0

        self.regrid_class = ESMFAreaWeightedRegridder
        self.src = src
        self.tgt = tgt

    def time_prepare_with_masks(self):
        """Benchmark prepare time with discontiguities and masks."""
        try:
            _ = self.regrid_class(
                self.src, self.tgt, use_src_mask=True, use_tgt_mask=True
            )
        except TypeError:
            # NOTE(review): presumably this fallback lets the benchmark run
            # against revisions whose regridder does not accept the mask
            # keywords - confirm. Any other TypeError raised during
            # construction is also caught here.
            _ = self.regrid_class(self.src, self.tgt)
54 changes: 54 additions & 0 deletions benchmarks/benchmarks/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,60 @@ def external(*args, **kwargs):
return return_cube


def _curvilinear_cube(
    n_lons,
    n_lats,
    lon_outer_bounds,
    lat_outer_bounds,
):
    """
    Build a curvilinear cube through :func:`run_function_elsewhere`.

    The cube is written to a NetCDF file under ``BENCHMARK_DATA`` and then
    loaded back from it. An already existing file is reused when
    ``REUSE_DATA`` is set.
    """

    def external(*args, **kwargs):
        """
        Create the cube with the real _curvilinear_cube and save it to NetCDF.

        Saving to a file allows the original python executable to pick back up.

        Remember that all arguments must work as strings.

        """
        from iris import save

        from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import (
            _curvilinear_cube as original,
        )

        save_path = kwargs.pop("save_path")
        save(original(*args, **kwargs), save_path)

    cube_args = (n_lons, n_lats, lon_outer_bounds, lat_outer_bounds)

    # The file name encodes the arguments so each combination gets its own file.
    raw_name = "_".join(map(str, ("_curvilinear_cube", *cube_args)))
    # Remove 'unsafe' characters.
    safe_name = re.sub(r"\W+", "", raw_name)
    save_path = (BENCHMARK_DATA / safe_name).with_suffix(".nc")

    # Regenerate unless reuse is enabled and the file is already present.
    if not (REUSE_DATA and save_path.is_file()):
        run_function_elsewhere(external, *cube_args, save_path=str(save_path))

    return load_cube(str(save_path))


def _gridlike_mesh_cube(n_lons, n_lats):
"""Call _gridlike_mesh via :func:`run_function_elsewhere`."""

Expand Down
50 changes: 45 additions & 5 deletions esmf_regrid/_esmf_sdo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ class SDO(ABC):
objects supported by ESMPy, Grids, Meshes, and LocStreams.
"""

def __init__(self, shape, index_offset, field_kwargs):
def __init__(self, shape, index_offset, field_kwargs, mask=None):
self._shape = shape
self._index_offset = index_offset
self._field_kwargs = field_kwargs
self._mask = mask

@abstractmethod
def _make_esmf_sdo(self):
Expand All @@ -42,6 +43,11 @@ def _refined_shape(self):
"""Return shape passed to ESMF."""
return self._shape

@property
def _refined_mask(self):
    """Return the mask in the form handed to ESMF (here, the stored mask as is)."""
    return self._mask

@property
def dims(self):
"""Return number of dimensions."""
Expand All @@ -62,6 +68,11 @@ def index_offset(self):
"""Return the index offset."""
return self._index_offset

@property
def mask(self):
    """Return the mask given at construction, or ``None`` if there was none."""
    return self._mask

def _array_to_matrix(self, array):
"""
Reshape data to a form that is compatible with weight matrices.
Expand Down Expand Up @@ -111,6 +122,7 @@ def __init__(
crs=None,
circular=False,
areas=None,
mask=None,
center=False,
):
"""
Expand Down Expand Up @@ -140,6 +152,8 @@ def __init__(
Array describing the areas associated with
each face. If ``None``, then :mod:`esmpy` will use its own
calculated areas.
mask: :obj:`~numpy.typing.ArrayLike`, optional
Array describing which elements :mod:`esmpy` will ignore.
center : bool, default=False
Describes if the center points of the grid cells are used in regridding
calculations.
Expand Down Expand Up @@ -196,6 +210,7 @@ def __init__(
shape=shape,
index_offset=1,
field_kwargs={"staggerloc": esmpy.StaggerLoc.CENTER},
mask=mask,
)

def _as_esmf_info(self):
Expand Down Expand Up @@ -283,10 +298,19 @@ def _make_esmf_sdo(self):
grid_center_y = grid.get_coords(1, staggerloc=esmpy.StaggerLoc.CENTER)
grid_center_y[:] = truecenterlats

def add_get_item(grid, **kwargs):
grid.add_item(**kwargs)
return grid.get_item(**kwargs)

if self.mask is not None:
grid_mask = add_get_item(
grid, item=esmpy.GridItem.MASK, staggerloc=esmpy.StaggerLoc.CENTER
)
grid_mask[:] = self._refined_mask

if areas is not None:
grid.add_item(esmpy.GridItem.AREA, staggerloc=esmpy.StaggerLoc.CENTER)
grid_areas = grid.get_item(
esmpy.GridItem.AREA, staggerloc=esmpy.StaggerLoc.CENTER
grid_areas = add_get_item(
grid, item=esmpy.GridItem.AREA, staggerloc=esmpy.StaggerLoc.CENTER
)
grid_areas[:] = areas.T

Expand Down Expand Up @@ -314,6 +338,7 @@ def __init__(
latbounds,
resolution=3,
crs=None,
mask=None,
):
"""
Create a :class:`RefinedGridInfo` object describing the grid.
Expand Down Expand Up @@ -354,7 +379,7 @@ def __init__(
# Create dummy lat/lon values
lons = np.zeros(self.n_lons_orig)
lats = np.zeros(self.n_lats_orig)
super().__init__(lons, lats, lonbounds, latbounds, crs=crs)
super().__init__(lons, lats, lonbounds, latbounds, crs=crs, mask=mask)

if self.n_lats_orig == 1 and np.allclose(latbounds, [-90, 90]):
self._refined_latbounds = np.array([-90, 0, 90])
Expand Down Expand Up @@ -386,6 +411,21 @@ def _refined_shape(self):
self.n_lons_orig * self.lon_expansion,
)

@property
def _refined_mask(self):
    """
    Return mask passed to ESMF.

    Each element of the original mask is repeated ``lat_expansion`` times
    along the latitude axis and ``lon_expansion`` times along the longitude
    axis, so that the result has shape ``self._refined_shape``.

    Returns ``None`` when no mask was supplied.
    """
    if self.mask is None:
        # Nothing to refine; without this guard indexing ``None`` below
        # would raise a TypeError.
        return None
    expanded_shape = [
        self.n_lats_orig,
        self.lat_expansion,
        self.n_lons_orig,
        self.lon_expansion,
    ]
    # Insert a length-1 axis after each original axis, then broadcast those
    # new axes up to the expansion factors.
    new_mask = np.broadcast_to(
        self.mask[:, np.newaxis, :, np.newaxis],
        expanded_shape,
    )
    # Merge each (original, expansion) axis pair into one refined axis.
    return new_mask.reshape(self._refined_shape)

def _collapse_weights(self, is_tgt):
"""
Return a matrix to collapse the weight matrix.
Expand Down
4 changes: 4 additions & 0 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@


def _get_regrid_weights_dict(src_field, tgt_field, regrid_method):
# The value, in array form, that ESMF should treat as an affirmative mask.
expected_mask = np.array([True])
regridder = esmpy.Regrid(
src_field,
tgt_field,
Expand All @@ -24,6 +26,8 @@ def _get_regrid_weights_dict(src_field, tgt_field, regrid_method):
# Choosing the norm_type DSTAREA allows for mdtol type operations
# to be performed using the weights information later on.
norm_type=esmpy.NormType.DSTAREA,
src_mask_values=expected_mask,
dst_mask_values=expected_mask,
factors=True,
)
# Without specifying deep_copy=true, the information in weights_dict
Expand Down
25 changes: 25 additions & 0 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
]
REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS}
SOURCE_NAME = "regridder_source_field"
SOURCE_MASK_NAME = "regridder_source_mask"
TARGET_NAME = "regridder_target_field"
TARGET_MASK_NAME = "regridder_target_mask"
WEIGHTS_NAME = "regridder_weights"
WEIGHTS_SHAPE_NAME = "weights_shape"
WEIGHTS_ROW_NAME = "weight_matrix_rows"
Expand All @@ -33,6 +35,13 @@
RESOLUTION = "resolution"


def _add_mask_to_cube(mask, cube, name):
    """
    Attach ``mask`` to ``cube`` as an auxiliary coordinate called ``name``.

    Nothing is added unless ``mask`` is a :class:`numpy.ndarray`. The mask is
    stored as integers and mapped over every dimension of the cube.
    """
    if not isinstance(mask, np.ndarray):
        return
    mask_coord = AuxCoord(mask.astype(int), var_name=name, long_name=name)
    cube.add_aux_coord(mask_coord, list(range(cube.ndim)))


def save_regridder(rg, filename):
"""
Save a regridder scheme instance.
Expand Down Expand Up @@ -69,6 +78,7 @@ def _standard_grid_cube(grid, name):
if regridder_type == "GridToMeshESMFRegridder":
src_grid = (rg.grid_y, rg.grid_x)
src_cube = _standard_grid_cube(src_grid, SOURCE_NAME)
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_mesh = rg.mesh
tgt_location = rg.location
Expand All @@ -77,6 +87,8 @@ def _standard_grid_cube(grid, name):
tgt_cube = Cube(tgt_data, var_name=TARGET_NAME, long_name=TARGET_NAME)
for coord in tgt_mesh_coords:
tgt_cube.add_aux_coord(coord, 0)
_add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)

elif regridder_type == "MeshToGridESMFRegridder":
src_mesh = rg.mesh
src_location = rg.location
Expand All @@ -85,9 +97,11 @@ def _standard_grid_cube(grid, name):
src_cube = Cube(src_data, var_name=SOURCE_NAME, long_name=SOURCE_NAME)
for coord in src_mesh_coords:
src_cube.add_aux_coord(coord, 0)
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_grid = (rg.grid_y, rg.grid_x)
tgt_cube = _standard_grid_cube(tgt_grid, TARGET_NAME)
_add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)
else:
msg = (
f"Expected a regridder of type `GridToMeshESMFRegridder` or "
Expand Down Expand Up @@ -205,13 +219,24 @@ def load_regridder(filename):

mdtol = weights_cube.attributes[MDTOL]

if src_cube.coords(SOURCE_MASK_NAME):
use_src_mask = src_cube.coord(SOURCE_MASK_NAME).points
else:
use_src_mask = False
if tgt_cube.coords(TARGET_MASK_NAME):
use_tgt_mask = tgt_cube.coord(TARGET_MASK_NAME).points
else:
use_tgt_mask = False

regridder = scheme(
src_cube,
tgt_cube,
mdtol=mdtol,
method=method,
precomputed_weights=weight_matrix,
resolution=resolution,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
)

esmf_version = weights_cube.attributes[VERSION_ESMF]
Expand Down
7 changes: 7 additions & 0 deletions esmf_regrid/experimental/unstructured_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
node_start_index,
elem_start_index=0,
areas=None,
mask=None,
elem_coords=None,
location="face",
):
Expand Down Expand Up @@ -56,6 +57,8 @@ def __init__(
areas: :obj:`~numpy.typing.ArrayLike`, optional
Array describing the areas associated with
each face. If ``None``, then :mod:`esmpy` will use its own calculated areas.
mask: :obj:`~numpy.typing.ArrayLike`, optional
Array describing which elements :mod:`esmpy` will ignore.
elem_coords : :obj:`~numpy.typing.ArrayLike`, optional
An ``Nx2`` array describing the location of the face centers of the mesh.
``elem_coords[:,0]`` describes the longitudes in degrees and
Expand Down Expand Up @@ -84,6 +87,7 @@ def __init__(
shape=shape,
index_offset=self.esi,
field_kwargs=field_kwargs,
mask=mask,
)

def _as_esmf_info(self):
Expand All @@ -108,6 +112,7 @@ def _as_esmf_info(self):
elemId,
elemType,
elemConn,
self.mask,
self.areas,
elemCoord,
)
Expand All @@ -124,6 +129,7 @@ def _make_esmf_sdo(self):
elemId,
elemType,
elemConn,
mask,
areas,
elemCoord,
) = info
Expand All @@ -139,6 +145,7 @@ def _make_esmf_sdo(self):
elemId,
elemType,
elemConn,
element_mask=mask,
element_area=areas,
element_coords=elemCoord,
)
Expand Down
Loading