Skip to content

Commit

Permalink
extend mask support
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenworsley committed Mar 1, 2023
1 parent f4dd0b9 commit 2f8132f
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 9 deletions.
27 changes: 24 additions & 3 deletions esmf_regrid/_esmf_sdo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def _refined_shape(self):
"""Return shape passed to ESMF."""
return self._shape

@property
def _refined_mask(self):
"""Return mask passed to ESMF."""
return self._mask

@property
def dims(self):
"""Return number of dimensions."""
Expand All @@ -64,7 +69,7 @@ def index_offset(self):

@property
def mask(self):
"""Return the index offset."""
"""Return the mask."""
return self._mask

def _array_to_matrix(self, array):
Expand Down Expand Up @@ -297,7 +302,7 @@ def _make_esmf_sdo(self):
grid_mask = grid.get_item(
ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER
)
grid_mask[:] = self.mask
grid_mask[:] = self._refined_mask

if areas is not None:
grid.add_item(ESMF.GridItem.AREA, staggerloc=ESMF.StaggerLoc.CENTER)
Expand Down Expand Up @@ -330,6 +335,7 @@ def __init__(
latbounds,
resolution=3,
crs=None,
mask=None,
):
"""
Create a :class:`RefinedGridInfo` object describing the grid.
Expand Down Expand Up @@ -370,7 +376,7 @@ def __init__(
# Create dummy lat/lon values
lons = np.zeros(self.n_lons_orig)
lats = np.zeros(self.n_lats_orig)
super().__init__(lons, lats, lonbounds, latbounds, crs=crs)
super().__init__(lons, lats, lonbounds, latbounds, crs=crs, mask=mask)

if self.n_lats_orig == 1 and np.allclose(latbounds, [-90, 90]):
self._refined_latbounds = np.array([-90, 0, 90])
Expand Down Expand Up @@ -402,6 +408,21 @@ def _refined_shape(self):
self.n_lons_orig * self.lon_expansion,
)

@property
def _refined_mask(self):
"""Return mask passed to ESMF."""
new_mask = np.broadcast_to(
self.mask[:, np.newaxis, :, np.newaxis],
[
self.n_lats_orig,
self.lat_expansion,
self.n_lons_orig,
self.lon_expansion,
],
)
new_mask = new_mask.reshape(self._refined_shape)
return new_mask

def _collapse_weights(self, is_tgt):
"""
Return a matrix to collapse the weight matrix.
Expand Down
55 changes: 49 additions & 6 deletions esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from esmf_regrid.esmf_regridder import Regridder
from esmf_regrid.experimental.unstructured_regrid import MeshInfo
from esmf_regrid.schemes import _create_cube, _cube_to_GridInfo, _get_coord
from esmf_regrid.schemes import _create_cube, _cube_to_GridInfo, _get_coord, _get_mask


def _map_complete_blocks(src, func, dims, out_sizes):
Expand Down Expand Up @@ -106,7 +106,7 @@ def _map_complete_blocks(src, func, dims, out_sizes):
)


def _mesh_to_MeshInfo(mesh, location):
def _mesh_to_MeshInfo(mesh, location, mask=None):
# Returns a MeshInfo object describing the mesh of the cube.
assert mesh.topology_dimension == 2
if None in mesh.face_coords:
Expand All @@ -119,6 +119,7 @@ def _mesh_to_MeshInfo(mesh, location):
mesh.face_node_connectivity.start_index,
elem_coords=elem_coords,
location=location,
mask=mask,
)
return meshinfo

Expand All @@ -144,6 +145,8 @@ def _regrid_unstructured_to_rectilinear__prepare(
method,
precomputed_weights=None,
resolution=None,
src_mask=None,
tgt_mask=None,
):
"""
First (setup) part of 'regrid_unstructured_to_rectilinear'.
Expand Down Expand Up @@ -185,8 +188,10 @@ def _regrid_unstructured_to_rectilinear__prepare(
# mesh belongs to.
mesh_dim = src_mesh_cube.mesh_dim()

meshinfo = _mesh_to_MeshInfo(mesh, location)
gridinfo = _cube_to_GridInfo(target_grid_cube, center=center, resolution=resolution)
meshinfo = _mesh_to_MeshInfo(mesh, location, mask=src_mask)
gridinfo = _cube_to_GridInfo(
target_grid_cube, center=center, resolution=resolution, mask=tgt_mask
)

regridder = Regridder(
meshinfo, gridinfo, method=method, precomputed_weights=precomputed_weights
Expand Down Expand Up @@ -247,6 +252,8 @@ def regrid_unstructured_to_rectilinear(
mdtol=0,
method="conservative",
resolution=None,
src_mask=None,
tgt_mask=None,
):
r"""
Regrid unstructured :class:`~iris.cube.Cube` onto rectilinear grid.
Expand Down Expand Up @@ -303,6 +310,8 @@ def regrid_unstructured_to_rectilinear(
grid_cube,
method=method,
resolution=resolution,
src_mask=src_mask,
tgt_mask=tgt_mask,
)
result = _regrid_unstructured_to_rectilinear__perform(src_cube, regrid_info, mdtol)
return result
Expand All @@ -319,6 +328,8 @@ def __init__(
method="conservative",
precomputed_weights=None,
resolution=None,
src_mask=False,
tgt_mask=False,
):
"""
Create regridder for conversions between source mesh and target grid.
Expand Down Expand Up @@ -382,12 +393,23 @@ def __init__(
)
self.resolution = resolution

if src_mask is True:
src_mask = _get_mask(src_mesh_cube)
elif src_mask is False:
src_mask = None
if tgt_mask is True:
tgt_mask = _get_mask(target_grid_cube)
elif tgt_mask is False:
tgt_mask = None

partial_regrid_info = _regrid_unstructured_to_rectilinear__prepare(
src_mesh_cube,
target_grid_cube,
method=self.method,
precomputed_weights=precomputed_weights,
resolution=resolution,
src_mask=src_mask,
tgt_mask=tgt_mask,
)

# Record source mesh.
Expand Down Expand Up @@ -472,6 +494,8 @@ def _regrid_rectilinear_to_unstructured__prepare(
method,
precomputed_weights=None,
resolution=None,
src_mask=None,
tgt_mask=None,
):
"""
First (setup) part of 'regrid_rectilinear_to_unstructured'.
Expand Down Expand Up @@ -516,8 +540,10 @@ def _regrid_rectilinear_to_unstructured__prepare(
else:
grid_y_dim, grid_x_dim = src_grid_cube.coord_dims(grid_x)

meshinfo = _mesh_to_MeshInfo(mesh, location)
gridinfo = _cube_to_GridInfo(src_grid_cube, center=center, resolution=resolution)
meshinfo = _mesh_to_MeshInfo(mesh, location, mask=tgt_mask)
gridinfo = _cube_to_GridInfo(
src_grid_cube, center=center, resolution=resolution, mask=src_mask
)

regridder = Regridder(
gridinfo, meshinfo, method=method, precomputed_weights=precomputed_weights
Expand Down Expand Up @@ -577,6 +603,8 @@ def regrid_rectilinear_to_unstructured(
mdtol=0,
method="conservative",
resolution=None,
src_mask=None,
tgt_mask=None,
):
r"""
Regrid rectilinear :class:`~iris.cube.Cube` onto unstructured mesh.
Expand Down Expand Up @@ -637,6 +665,8 @@ def regrid_rectilinear_to_unstructured(
mesh_cube,
method=method,
resolution=resolution,
src_mask=src_mask,
tgt_mask=tgt_mask,
)
result = _regrid_rectilinear_to_unstructured__perform(src_cube, regrid_info, mdtol)
return result
Expand All @@ -653,6 +683,8 @@ def __init__(
method="conservative",
precomputed_weights=None,
resolution=None,
src_mask=False,
tgt_mask=False,
):
"""
Create regridder for conversions between source grid and target mesh.
Expand Down Expand Up @@ -711,12 +743,23 @@ def __init__(
self.method = method
self.resolution = resolution

if src_mask is True:
src_mask = _get_mask(src_grid_cube)
elif src_mask is False:
src_mask = None
if tgt_mask is True:
tgt_mask = _get_mask(target_mesh_cube)
elif tgt_mask is False:
tgt_mask = None

partial_regrid_info = _regrid_rectilinear_to_unstructured__prepare(
src_grid_cube,
target_mesh_cube,
method=self.method,
precomputed_weights=precomputed_weights,
resolution=self.resolution,
src_mask=src_mask,
tgt_mask=tgt_mask,
)

# Store regrid info.
Expand Down
1 change: 1 addition & 0 deletions esmf_regrid/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def _cube_to_GridInfo(cube, center=False, resolution=None, mask=None):
lat_bound_array,
crs=crs,
resolution=resolution,
mask=mask,
)
return grid_info

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,42 @@ def test_curvilinear():
# Check metadata and scalar coords.
result.data = expected_data
assert expected_cube == result


def test_masks():
"""
Test initialisation of :func:`esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`.
Checks that the `src_mask` and `tgt_mask` keywords work properly.
"""
src = _curvilinear_cube(7, 6, [-180, 180], [-90, 90])
tgt = _gridlike_mesh_cube(6, 7)

# Make src and tgt discontiguous at (0, 0)
src_mask = np.zeros([6, 7], dtype=bool)
src_mask[0, 0] = True
src.data = np.ma.array(src.data, mask=src_mask)
src_discontiguous = src.copy()
src_discontiguous.coord("latitude").bounds[0, 0] = 0
src_discontiguous.coord("longitude").bounds[0, 0] = 0

tgt_mask = np.zeros([7 * 6], dtype=bool)
tgt_mask[0] = True
tgt.data = np.ma.array(tgt.data, mask=tgt_mask)

rg_src_masked = GridToMeshESMFRegridder(src_discontiguous, tgt, src_mask=True)
rg_tgt_masked = GridToMeshESMFRegridder(src, tgt, tgt_mask=True)
rg_unmasked = GridToMeshESMFRegridder(src, tgt)

weights_src_masked = rg_src_masked.regridder.weight_matrix
weights_tgt_masked = rg_tgt_masked.regridder.weight_matrix
weights_unmasked = rg_unmasked.regridder.weight_matrix

# Check there are no weights associated with the masked point.
assert weights_src_masked[:, 0].nnz == 0
assert weights_tgt_masked[0].nnz == 0

# Check all other weights are correct.
assert np.allclose(
weights_src_masked[:, 1:].todense(), weights_unmasked[:, 1:].todense()
)
assert np.allclose(weights_tgt_masked[1:].todense(), weights_unmasked[1:].todense())
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,42 @@ def test_curvilinear():
# Check metadata and scalar coords.
result.data = expected_data
assert expected_cube == result


def test_masks():
"""
Test initialisation of :func:`esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`.
Checks that the `src_mask` and `tgt_mask` keywords work properly.
"""
src = _gridlike_mesh_cube(7, 6)
tgt = _curvilinear_cube(6, 7, [-180, 180], [-90, 90])

# Make src and tgt discontiguous at (0, 0)
src_mask = np.zeros([6 * 7], dtype=bool)
src_mask[0] = True
src.data = np.ma.array(src.data, mask=src_mask)

tgt_mask = np.zeros([7, 6], dtype=bool)
tgt_mask[0, 0] = True
tgt.data = np.ma.array(tgt.data, mask=tgt_mask)
tgt_discontiguous = tgt.copy()
tgt_discontiguous.coord("latitude").bounds[0, 0] = 0
tgt_discontiguous.coord("longitude").bounds[0, 0] = 0

rg_src_masked = MeshToGridESMFRegridder(src, tgt, src_mask=True)
rg_tgt_masked = MeshToGridESMFRegridder(src, tgt_discontiguous, tgt_mask=True)
rg_unmasked = MeshToGridESMFRegridder(src, tgt)

weights_src_masked = rg_src_masked.regridder.weight_matrix
weights_tgt_masked = rg_tgt_masked.regridder.weight_matrix
weights_unmasked = rg_unmasked.regridder.weight_matrix

# Check there are no weights associated with the masked point.
assert weights_src_masked[:, 0].nnz == 0
assert weights_tgt_masked[0].nnz == 0

# Check all other weights are correct.
assert np.allclose(
weights_src_masked[:, 1:].todense(), weights_unmasked[:, 1:].todense()
)
assert np.allclose(weights_tgt_masked[1:].todense(), weights_unmasked[1:].todense())

0 comments on commit 2f8132f

Please sign in to comment.