Skip to content

Proof of concept: MeshIndexSet #6014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
233 changes: 227 additions & 6 deletions lib/iris/experimental/ugrid/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
from abc import ABC, abstractmethod
from collections import namedtuple
from collections.abc import Container
from contextlib import contextmanager
from copy import copy
from typing import Iterable

from cf_units import Unit
from dask import array as da
import numpy as np
from numpy.typing import ArrayLike

from ... import _lazy_data as _lazy
from ...common import CFVariableMixin, metadata_filter, metadata_manager_factory
Expand Down Expand Up @@ -683,11 +686,11 @@ def normalise(element, axis):
raise ValueError(emsg)

if self.topology_dimension == 1:
self._coord_manager = _Mesh1DCoordinateManager(**kwargs)
self._connectivity_manager = _Mesh1DConnectivityManager(*connectivities)
self._coord_man = _Mesh1DCoordinateManager(**kwargs)
self._connectivity_man = _Mesh1DConnectivityManager(*connectivities)
elif self.topology_dimension == 2:
self._coord_manager = _Mesh2DCoordinateManager(**kwargs)
self._connectivity_manager = _Mesh2DConnectivityManager(*connectivities)
self._coord_man = _Mesh2DCoordinateManager(**kwargs)
self._connectivity_man = _Mesh2DConnectivityManager(*connectivities)
else:
emsg = f"Unsupported 'topology_dimension', got {topology_dimension!r}."
raise NotImplementedError(emsg)
Expand Down Expand Up @@ -1060,8 +1063,8 @@ def line(text, i_indent=0):
def __setstate__(self, state):
metadata_manager, coord_manager, connectivity_manager = state
self._metadata_manager = metadata_manager
self._coord_manager = coord_manager
self._connectivity_manager = connectivity_manager
self._coord_man = coord_manager
self._connectivity_man = connectivity_manager

def _set_dimension_names(self, node, edge, face, reset=False):
args = (node, edge, face)
Expand Down Expand Up @@ -1092,6 +1095,16 @@ def _set_dimension_names(self, node, edge, face, reset=False):

return result

@property
def _connectivity_manager(self):
# Delivered as a property to help with overriding.
return self._connectivity_man

@property
def _coord_manager(self):
# Delivered as a property to help with overriding.
return self._coord_man

@property
def all_connectivities(self):
"""All the :class:`~iris.experimental.ugrid.mesh.Connectivity` instances of the :class:`Mesh`."""
Expand Down Expand Up @@ -1962,6 +1975,137 @@ def topology_dimension(self):
return self._metadata_manager.topology_dimension


class MeshIndexSet(Mesh):
super_mesh: Mesh = None
location: str = None
indices: ArrayLike = None

class _IndexedMembers(dict):
def _readonly(self, *args, **kwargs):
message = (
"Modification of MeshIndexSet is forbidden - this is only "
f"a view onto an original Mesh: id={self.mesh_id}."
)
raise RuntimeError(message)

__setitem__ = _readonly
__delitem__ = _readonly
pop = _readonly
popitem = _readonly
clear = _readonly
update = _readonly
setdefault = _readonly

mesh_id: int

def __init__(self, seq, **kwargs):
self.mesh_id = kwargs.pop("mesh_id")
super().__init__(seq, **kwargs)

@classmethod
def from_mesh(cls, mesh, location, indices):
instance = copy(mesh)
instance.__class__ = cls
instance.super_mesh = mesh
instance.location = location
instance.indices = indices
return instance

def _validate(self) -> None:
# Instances are expected to be created via the from_mesh() method.
# If __init__() is used instead, the class will not be fully set up.
# Actually disabling __init__() is considered an anti-pattern.
if any(v is None for v in [self.super_mesh, self.location, self.indices]):
message = (
"MeshIndexSet instance not fully set up. Note that "
"MeshIndexSet is designed to be created via the from_mesh() "
"method, not via __init__()."
)
raise RuntimeError(message)

def __eq__(self, other):
# TBD: this is a minimalist implementation and requires to be revisited
return id(self) == id(other)

def __ne__(self, other):
# TBD: this is a minimalist implementation and requires to be revisited
return id(self) != id(other)

def _calculate_node_indices(self):
self._validate()
if self.location == "node":
result = self.indices
elif self.location in ["edge", "face"]:
(connectivity,) = [
c
for c in self.super_mesh.all_connectivities
if (
c is not None
and c.location == self.location
and c.connected == "node"
)
]
# Doesn't matter if connectivity is transposed or not in this case.
conn_indices = connectivity.indices[self.indices]
node_set = list(set(conn_indices.compressed()))
node_set.sort()
result = node_set
else:
result = None
# TODO: should this be validated earlier?
# Maybe even with an Enum?
message = (
f"Expected location to be one of `node`, `edge` or `face`, "
f"got `{self.location}`"
)
raise NotImplementedError(message)

return result

def _calculate_edge_indices(self):
self._validate()
if self.location == "edge":
result = self.indices
else:
result = None
return result

def _calculate_face_indices(self):
self._validate()
if self.location == "face":
result = self.indices
else:
result = None
return result

@property
def _coord_manager(self):
# Intended to be a 'view' on the original, and Meshes are mutable, so
# must re-index every time it is accessed.
with self.super_mesh._coord_man.indexed(
self._calculate_node_indices(),
self._calculate_edge_indices(),
self._calculate_face_indices(),
mesh_id=id(self.super_mesh),
) as indexed_coord_man:
result = copy(indexed_coord_man)

return result

@property
def _connectivity_manager(self):
# Intended to be a 'view' on the original, and Meshes are mutable, so
# must re-index every time it is accessed.
with self.super_mesh._connectivity_man.indexed(
self._calculate_node_indices(),
self._calculate_edge_indices(),
self._calculate_face_indices(),
mesh_id=id(self.super_mesh),
) as indexed_connectivity_man:
result = copy(indexed_connectivity_man)
return result


class _Mesh1DCoordinateManager:
"""TBD: require clarity on coord_systems validation.

Expand Down Expand Up @@ -2019,6 +2163,41 @@ def __str__(self):
args = [f"{member}" for member, coord in self if coord is not None]
return f"{self.__class__.__name__}({', '.join(args)})"

@contextmanager
def indexed(
self,
node_indices: np.typing.ArrayLike,
edge_indices: np.typing.ArrayLike,
face_indices: np.typing.ArrayLike,
mesh_id: int,
):
# TODO: make members_indexed read-only.
members_original = copy(self._members)

indices_dict = {
"node": node_indices,
"edge": edge_indices,
"face": face_indices,
}
members_indexed = {}
for key, coord in self._members.items():
indexing = None
indexed = None
if coord is not None:
indexing = indices_dict[key.split("_")[0]]
if indexing is not None:
indexed = coord[indexing]
members_indexed[key] = indexed

try:
self._members = MeshIndexSet._IndexedMembers(
members_indexed,
mesh_id=mesh_id,
)
yield self
finally:
self._members = members_original

def _remove(self, **kwargs):
result = {}
members = self.filters(**kwargs)
Expand Down Expand Up @@ -2547,6 +2726,48 @@ def element_filter(instances, loc_arg, loc_name):
result_dict = {k: v for k, v in self._members.items() if id(v) in result_ids}
return result_dict

@contextmanager
def indexed(
self,
node_indices: np.typing.ArrayLike,
edge_indices: np.typing.ArrayLike,
face_indices: np.typing.ArrayLike,
mesh_id: int,
):
# TODO: make members_indexed read-only.
members_original = copy(self._members)

indices_dict = {
"node": node_indices,
"edge": edge_indices,
"face": face_indices,
}
node_index_mapping = {
old_index: new_index for new_index, old_index in enumerate(node_indices)
}
members_indexed = {}
for key, connectivity in self._members.items():
indexing = None
indexed = None
if connectivity is not None:
indexing = indices_dict[connectivity.location]
if indexing is not None:
new_values = connectivity.indices_by_location()[indexing]
new_values = np.vectorize(node_index_mapping.get)(new_values)
if connectivity.location_axis == 1:
new_values = new_values.T
indexed = connectivity.copy(new_values)
members_indexed[key] = indexed

try:
self._members = MeshIndexSet._IndexedMembers(
members_indexed,
mesh_id=mesh_id,
)
yield self
finally:
self._members = members_original

def remove(
self,
item=None,
Expand Down