From 91fe92c6f0f0cb8cef208116bb88de036b1a681c Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Thu, 13 Jun 2024 16:35:13 +0100 Subject: [PATCH 1/9] add MeshIndexSet --- lib/iris/experimental/ugrid/mesh.py | 230 ++++++++++++++++++++++++++++ 1 file changed, 230 insertions(+) diff --git a/lib/iris/experimental/ugrid/mesh.py b/lib/iris/experimental/ugrid/mesh.py index a798f7af77..db9db75ca5 100644 --- a/lib/iris/experimental/ugrid/mesh.py +++ b/lib/iris/experimental/ugrid/mesh.py @@ -1962,6 +1962,236 @@ def topology_dimension(self): return self._metadata_manager.topology_dimension +class MeshIndexSet(Mesh): + def __init__(self, mesh, location, indices): + self.super_mesh = mesh + self.location = location + self.indices = indices + + self._metadata_manager = metadata_manager_factory(MeshMetadata) + + # topology_dimension is read-only, so assign directly to the metadata manager + self._metadata_manager.topology_dimension = mesh.topology_dimension + + self.node_dimension = mesh.node_dimension + self.edge_dimension = mesh.edge_dimension + self.face_dimension = mesh.face_dimension + + # assign the metadata to the metadata manager + self.standard_name = mesh.standard_name + self.long_name = mesh.long_name + self.var_name = mesh.var_name + self.units = mesh.units + self.attributes = mesh.attributes + + self._coord_manager = _MeshIndexCoordinateManager(mesh, location, indices) + self._connectivity_manager = _MeshIndexConnectivityManager( + mesh, location, indices + ) + + +class _MeshIndexManager: + def __init__(self, mesh, location, indices): + self.mesh = mesh + self.location = location + self.indices = indices + + self.face_indices = self._calculate_face_indices() + self.edge_indices = self._calculate_edge_indices() + self.node_indices = self._calculate_node_indices() + self.node_index_dict = { + old_index: new_index + for new_index, old_index in enumerate(self.node_indices) + } + + def _calculate_node_indices(self): + if self.location == "node": + return self.indices + elif self.location == "edge": + connectivity = self.mesh.edge_node_connectivity + node_set = list(set(connectivity.indices.compressed())) + node_set.sort() + return node_set + elif self.location == "face": + connectivity = self.mesh.face_node_connectivity + node_set = list(set(connectivity.indices.compressed())) + node_set.sort() + return node_set + + def _calculate_edge_indices(self): + if self.location != "edge": + return None + return self.indices + + def _calculate_face_indices(self): + if self.location != "face": + return None + return self.indices + + +class _MeshIndexCoordinateManager(_MeshIndexManager): + + REQUIRED = ( + "node_x", + "node_y", + ) + OPTIONAL = ( + "edge_x", + "edge_y", + "face_x", + "face_y", + ) + + def __init__(self, mesh, location, indices): + super().__init__(mesh, location, indices) + self.ALL = self.REQUIRED + self.OPTIONAL + self._members = {member: getattr(self, member) for member in self.ALL} + + @property + def node_x(self): + return self.mesh._coord_manager.node_x[self.node_indices] + + @property + def node_y(self): + return self.mesh._coord_manager.node_y[self.node_indices] + + @property + def edge_x(self): + return self.mesh._coord_manager.edge_x[self.edge_indices] + + @property + def edge_y(self): + return self.mesh._coord_manager.edge_y[self.edge_indices] + + @property + def face_x(self): + return self.mesh._coord_manager.face_x[self.face_indices] + + @property + def face_y(self): + return self.mesh._coord_manager.face_y[self.face_indices] + + @property + def node_coords(self): + return MeshNodeCoords(node_x=self.node_x, node_y=self.node_y) + + @property + def edge_coords(self): + return MeshEdgeCoords(edge_x=self.edge_x, edge_y=self.edge_y) + + @property + def face_coords(self): + return MeshFaceCoords(face_x=self.face_x, face_y=self.face_y) + + def filters( + self, + item=None, + standard_name=None, + long_name=None, + var_name=None, + attributes=None, + axis=None, + include_nodes=None, + include_edges=None, + include_faces=None, + ): + # TBD: support coord_systems? + + # Preserve original argument before modifying. + face_requested = include_faces + + # Rationalise the tri-state behaviour. + args = [include_nodes, include_edges, include_faces] + state = not any(set(filter(lambda arg: arg is not None, args))) + include_nodes, include_edges, include_faces = map( + lambda arg: arg if arg is not None else state, args + ) + + def populated_coords(coords_tuple): + return list(filter(None, list(coords_tuple))) + + members = [] + if include_nodes: + members += populated_coords(self.node_coords) + if include_edges: + members += populated_coords(self.edge_coords) + if hasattr(self, "face_coords"): + if include_faces: + members += populated_coords(self.face_coords) + elif face_requested: + dmsg = "Ignoring request to filter non-existent 'face_coords'" + logger.debug(dmsg, extra=dict(cls=self.__class__.__name__)) + + result = metadata_filter( + members, + item=item, + standard_name=standard_name, + long_name=long_name, + var_name=var_name, + attributes=attributes, + axis=axis, + ) + + # Use the results to filter the _members dict for returning. + result_ids = [id(r) for r in result] + result_dict = {k: v for k, v in self._members.items() if id(v) in result_ids} + return result_dict + + def filter(self, **kwargs): + # TODO: rationalise commonality with MeshConnectivityManager.filter and Cube.coord. + result = self.filters(**kwargs) + + if len(result) > 1: + names = ", ".join(f"{member}={coord!r}" for member, coord in result.items()) + emsg = ( + f"Expected to find exactly 1 coordinate, but found {len(result)}. " + f"They were: {names}." + ) + raise CoordinateNotFoundError(emsg) + + if len(result) == 0: + item = kwargs["item"] + if item is not None: + if not isinstance(item, str): + item = item.name() + name = ( + item + or kwargs["standard_name"] + or kwargs["long_name"] + or kwargs["var_name"] + or None + ) + name = "" if name is None else f"{name!r} " + emsg = f"Expected to find exactly 1 {name}coordinate, but found none." + raise CoordinateNotFoundError(emsg) + + return result + + +class _MeshIndexConnectivityManager(_MeshIndexManager): + @property + def edge_node(self): + if self.edge_indices is None: + return None + else: + connectivity = self.mesh.edge_node_connectivity[self.edge_indices] + connectivity.indices = np.vectorize(self.node_index_dict.get)( + connectivity.indices + ) + return connectivity + + @property + def face_node(self): + if self.face_indices is None: + return None + else: + connectivity = self.mesh.face_node_connectivity[self.face_indices] + connectivity.indices = np.vectorize(self.node_index_dict.get)( + connectivity.indices + ) + return connectivity + + class _Mesh1DCoordinateManager: """TBD: require clarity on coord_systems validation. From b2ba7ff2ec29c65f36992bb2fa28a0c5233ad8ae Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Wed, 19 Jun 2024 15:16:11 +0100 Subject: [PATCH 2/9] improve MeshIndexSet --- lib/iris/experimental/ugrid/mesh.py | 97 +++++++++++++++++++++++------ 1 file changed, 77 insertions(+), 20 deletions(-) diff --git a/lib/iris/experimental/ugrid/mesh.py b/lib/iris/experimental/ugrid/mesh.py index db9db75ca5..2b2f8b34a0 100644 --- a/lib/iris/experimental/ugrid/mesh.py +++ b/lib/iris/experimental/ugrid/mesh.py @@ -1989,6 +1989,14 @@ def __init__(self, mesh, location, indices): mesh, location, indices ) + def __eq__(self, other): + # TBD: this is a minimalist implementation and requires to be revisited + return id(self) == id(other) + + def __ne__(self, other): + # TBD: this is a minimalist implementation and requires to be revisited + return id(self) != id(other) + class _MeshIndexManager: def __init__(self, mesh, location, indices): @@ -2008,12 +2016,12 @@ def _calculate_node_indices(self): if self.location == "node": return self.indices elif self.location == "edge": - connectivity = self.mesh.edge_node_connectivity + connectivity = self.mesh.edge_node_connectivity[self.indices] node_set = list(set(connectivity.indices.compressed())) node_set.sort() return node_set elif self.location == "face": - connectivity = self.mesh.face_node_connectivity + connectivity = self.mesh.face_node_connectivity[self.indices] node_set = list(set(connectivity.indices.compressed())) node_set.sort() return node_set @@ -2045,31 +2053,58 @@ class _MeshIndexCoordinateManager(_MeshIndexManager): def __init__(self, mesh, location, indices): super().__init__(mesh, location, indices) self.ALL = self.REQUIRED + self.OPTIONAL + self._members = {} self._members = {member: getattr(self, member) for member in self.ALL} + def __eq__(self, other): + # TBD: this is a minimalist implementation and requires to be revisited + return id(self) == id(other) + + def __ne__(self, other): + # TBD: this is a minimalist implementation and requires to be revisited + return id(self) != id(other) + @property def node_x(self): - return self.mesh._coord_manager.node_x[self.node_indices] + if "node_x" in self._members: + return self._members["node_x"] + else: + return self.mesh._coord_manager.node_x[self.node_indices] @property def node_y(self): - return self.mesh._coord_manager.node_y[self.node_indices] + if "node_y" in self._members: + return self._members["node_y"] + else: + return self.mesh._coord_manager.node_y[self.node_indices] @property def edge_x(self): - return self.mesh._coord_manager.edge_x[self.edge_indices] + if "edge_x" in self._members: + return self._members["edge_x"] + else: + return self.mesh._coord_manager.edge_x[self.edge_indices] @property def edge_y(self): - return self.mesh._coord_manager.edge_y[self.edge_indices] + if "edge_y" in self._members: + return self._members["edge_y"] + else: + return self.mesh._coord_manager.edge_y[self.edge_indices] @property def face_x(self): - return self.mesh._coord_manager.face_x[self.face_indices] + if "face_x" in self._members: + return self._members["face_x"] + else: + return self.mesh._coord_manager.face_x[self.face_indices] @property def face_y(self): - return self.mesh._coord_manager.face_y[self.face_indices] + if "face_y" in self._members: + return self._members["face_y"] + else: + return self.mesh._coord_manager.face_y[self.face_indices] @property def node_coords(self): @@ -2084,16 +2119,16 @@ def face_coords(self): return MeshFaceCoords(face_x=self.face_x, face_y=self.face_y) def filters( - self, - item=None, - standard_name=None, - long_name=None, - var_name=None, - attributes=None, - axis=None, - include_nodes=None, - include_edges=None, - include_faces=None, + self, + item=None, + standard_name=None, + long_name=None, + var_name=None, + attributes=None, + axis=None, + include_nodes=None, + include_edges=None, + include_faces=None, ): # TBD: support coord_systems? @@ -2175,9 +2210,20 @@ def edge_node(self): return None else: connectivity = self.mesh.edge_node_connectivity[self.edge_indices] - connectivity.indices = np.vectorize(self.node_index_dict.get)( + connectivity_indices = np.vectorize(self.node_index_dict.get)( connectivity.indices ) + connectivity = Connectivity( + connectivity_indices, + connectivity.cf_role, + standard_name=connectivity.standard_name, + long_name=connectivity.long_name, + var_name=connectivity.var_name, + units=connectivity.units, + attributes=connectivity.attributes, + start_index=connectivity.start_index, + location_axis=connectivity.location_axis, + ) return connectivity @property @@ -2186,9 +2232,20 @@ def face_node(self): return None else: connectivity = self.mesh.face_node_connectivity[self.face_indices] - connectivity.indices = np.vectorize(self.node_index_dict.get)( + connectivity_indices = np.vectorize(self.node_index_dict.get)( connectivity.indices ) + connectivity = Connectivity( + connectivity_indices, + connectivity.cf_role, + standard_name=connectivity.standard_name, + long_name=connectivity.long_name, + var_name=connectivity.var_name, + units=connectivity.units, + attributes=connectivity.attributes, + start_index=connectivity.start_index, + location_axis=connectivity.location_axis, + ) return connectivity From 7c05019b549c4a6ac91a27530de42ee14bd1fbe2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jun 2024 14:56:39 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lib/iris/experimental/ugrid/mesh.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/iris/experimental/ugrid/mesh.py b/lib/iris/experimental/ugrid/mesh.py index 2b2f8b34a0..115833d3bf 100644 --- a/lib/iris/experimental/ugrid/mesh.py +++ b/lib/iris/experimental/ugrid/mesh.py @@ -2038,7 +2038,6 @@ def _calculate_face_indices(self): class _MeshIndexCoordinateManager(_MeshIndexManager): - REQUIRED = ( "node_x", "node_y", From 800ad2d08f6a6461a2872f3d6b292ce3ef57dff6 Mon Sep 17 00:00:00 2001 From: Martin Yeo Date: Fri, 9 Aug 2024 13:29:51 +0100 Subject: [PATCH 4/9] Simpler MeshMetadata generation. --- lib/iris/experimental/ugrid/mesh.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/lib/iris/experimental/ugrid/mesh.py b/lib/iris/experimental/ugrid/mesh.py index 115833d3bf..c594b4824d 100644 --- a/lib/iris/experimental/ugrid/mesh.py +++ b/lib/iris/experimental/ugrid/mesh.py @@ -1968,21 +1968,7 @@ def __init__(self, mesh, location, indices): self.location = location self.indices = indices - self._metadata_manager = metadata_manager_factory(MeshMetadata) - - # topology_dimension is read-only, so assign directly to the metadata manager - self._metadata_manager.topology_dimension = mesh.topology_dimension - - self.node_dimension = mesh.node_dimension - self.edge_dimension = mesh.edge_dimension - self.face_dimension = mesh.face_dimension - - # assign the metadata to the metadata manager - self.standard_name = mesh.standard_name - self.long_name = mesh.long_name - self.var_name = mesh.var_name - self.units = mesh.units - self.attributes = mesh.attributes + self._metadata_manager = MeshMetadata.from_metadata(mesh.metadata) self._coord_manager = _MeshIndexCoordinateManager(mesh, location, indices) self._connectivity_manager = _MeshIndexConnectivityManager( From ac48bd6df52c54c9db982b5b7273f509fff69f8b Mon Sep 17 00:00:00 2001 From: Martin Yeo Date: Fri, 9 Aug 2024 13:38:30 +0100 Subject: [PATCH 5/9] Common handling of face and edge connectivities. --- lib/iris/experimental/ugrid/mesh.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/lib/iris/experimental/ugrid/mesh.py b/lib/iris/experimental/ugrid/mesh.py index c594b4824d..556567071b 100644 --- a/lib/iris/experimental/ugrid/mesh.py +++ b/lib/iris/experimental/ugrid/mesh.py @@ -2001,16 +2001,24 @@ def __init__(self, mesh, location, indices): def _calculate_node_indices(self): if self.location == "node": return self.indices - elif self.location == "edge": - connectivity = self.mesh.edge_node_connectivity[self.indices] - node_set = list(set(connectivity.indices.compressed())) - node_set.sort() - return node_set - elif self.location == "face": - connectivity = self.mesh.face_node_connectivity[self.indices] - node_set = list(set(connectivity.indices.compressed())) + elif self.location in ["edge", "face"]: + (connectivity,) = [ + c + for c in self.mesh.all_connectivities + if c.location == self.location and c.connected == "node" + ] + conn_indices = connectivity.indices_by_location()[self.indices] + node_set = list(set(conn_indices.compressed())) node_set.sort() return node_set + else: + # TODO: should this be validated earlier? + # Maybe even with an Enum? + message = ( + f"Expected location to be one of `node`, `edge` or `face`, " + f"got `{self.location}`" + ) + raise NotImplementedError(message) def _calculate_edge_indices(self): if self.location != "edge": From ab93e7a8646f09c81ac93fffeb499b8e61fae21c Mon Sep 17 00:00:00 2001 From: Martin Yeo Date: Fri, 9 Aug 2024 13:43:10 +0100 Subject: [PATCH 6/9] Consistent return pattern with rest of Iris. --- lib/iris/experimental/ugrid/mesh.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/lib/iris/experimental/ugrid/mesh.py b/lib/iris/experimental/ugrid/mesh.py index 556567071b..01456ea44d 100644 --- a/lib/iris/experimental/ugrid/mesh.py +++ b/lib/iris/experimental/ugrid/mesh.py @@ -2000,7 +2000,7 @@ def __init__(self, mesh, location, indices): def _calculate_node_indices(self): if self.location == "node": - return self.indices + result = self.indices elif self.location in ["edge", "face"]: (connectivity,) = [ c @@ -2010,8 +2010,9 @@ def _calculate_node_indices(self): conn_indices = connectivity.indices_by_location()[self.indices] node_set = list(set(conn_indices.compressed())) node_set.sort() - return node_set + result = node_set else: + result = None # TODO: should this be validated earlier? # Maybe even with an Enum? message = ( @@ -2020,15 +2021,21 @@ def _calculate_node_indices(self): ) raise NotImplementedError(message) + return result + def _calculate_edge_indices(self): - if self.location != "edge": - return None - return self.indices + if self.location == "edge": + result = self.indices + else: + result = None + return result def _calculate_face_indices(self): - if self.location != "face": - return None - return self.indices + if self.location == "face": + result = self.indices + else: + result = None + return result class _MeshIndexCoordinateManager(_MeshIndexManager): From a1d38c4a2adc4336c29d24666f5ad663d6c1279d Mon Sep 17 00:00:00 2001 From: Martin Yeo Date: Mon, 12 Aug 2024 12:23:33 +0100 Subject: [PATCH 7/9] Alternative for managing indexed meshes. --- lib/iris/experimental/ugrid/mesh.py | 517 ++++++++++++---------------- 1 file changed, 228 insertions(+), 289 deletions(-) diff --git a/lib/iris/experimental/ugrid/mesh.py b/lib/iris/experimental/ugrid/mesh.py index 01456ea44d..b465f3660d 100644 --- a/lib/iris/experimental/ugrid/mesh.py +++ b/lib/iris/experimental/ugrid/mesh.py @@ -12,6 +12,8 @@ from abc import ABC, abstractmethod from collections import namedtuple from collections.abc import Container +from contextlib import contextmanager +from copy import copy from typing import Iterable from cf_units import Unit @@ -603,9 +605,9 @@ class Mesh(CFVariableMixin): def __init__( self, - topology_dimension, - node_coords_and_axes, - connectivities, + topology_dimension=None, + node_coords_and_axes=None, + connectivities=None, edge_coords_and_axes=None, face_coords_and_axes=None, standard_name=None, @@ -616,6 +618,7 @@ def __init__( node_dimension=None, edge_dimension=None, face_dimension=None, + _copy_mesh=None, ): """Mesh initialise. @@ -631,66 +634,94 @@ def __init__( # TODO: support volumes. # TODO: support (coord, "z") - self._metadata_manager = metadata_manager_factory(MeshMetadata) + copy_mode = ( + _copy_mesh is not None and getattr(_copy_mesh, "cf_role") == "mesh_topology" + ) - # topology_dimension is read-only, so assign directly to the metadata manager - if topology_dimension not in self.TOPOLOGY_DIMENSIONS: - emsg = f"Expected 'topology_dimension' in range {self.TOPOLOGY_DIMENSIONS!r}, got {topology_dimension!r}." - raise ValueError(emsg) - self._metadata_manager.topology_dimension = topology_dimension - - self.node_dimension = node_dimension - self.edge_dimension = edge_dimension - self.face_dimension = face_dimension - - # assign the metadata to the metadata manager - self.standard_name = standard_name - self.long_name = long_name - self.var_name = var_name - self.units = units - self.attributes = attributes - - # based on the topology_dimension, create the appropriate coordinate manager - def normalise(element, axis): - result = str(axis).lower() - if result not in self.AXES: - emsg = f"Invalid axis specified for {element} coordinate {coord.name()!r}, got {axis!r}." + if copy_mode: + self._metadata_manager = _copy_mesh.metadata + self._coord_man = _copy_mesh._coord_manager + self._connectivity_man = _copy_mesh._connectivity_manager + else: + mandatory = [topology_dimension, node_coords_and_axes, connectivities] + if any([m is None for m in mandatory]): + message = ( + "1 or more of mandatory arguments missing: " + "topology_dimension, node_coords_and_axes, connectivities." + ) + raise ValueError(message) + + self._metadata_manager = metadata_manager_factory(MeshMetadata) + + # topology_dimension is read-only, so assign directly to the metadata manager + if topology_dimension not in self.TOPOLOGY_DIMENSIONS: + emsg = ( + f"Expected 'topology_dimension' in range " + f"{self.TOPOLOGY_DIMENSIONS!r}, got " + f"{topology_dimension!r}." + ) + raise ValueError(emsg) + self._metadata_manager.topology_dimension = topology_dimension + + self.node_dimension = node_dimension + self.edge_dimension = edge_dimension + self.face_dimension = face_dimension + + # assign the metadata to the metadata manager + self.standard_name = standard_name + self.long_name = long_name + self.var_name = var_name + self.units = units + self.attributes = attributes + + # based on the topology_dimension, create the appropriate coordinate manager + def normalise(element, axis): + result = str(axis).lower() + if result not in self.AXES: + emsg = ( + f"Invalid axis specified for {element} coordinate " + f"{coord.name()!r}, got {axis!r}." + ) + raise ValueError(emsg) + return f"{element}_{result}" + + if not isinstance(node_coords_and_axes, Iterable): + node_coords_and_axes = [node_coords_and_axes] + + if not isinstance(connectivities, Iterable): + connectivities = [connectivities] + + kwargs = {} + for coord, axis in node_coords_and_axes: + kwargs[normalise("node", axis)] = coord + if edge_coords_and_axes is not None: + for coord, axis in edge_coords_and_axes: + kwargs[normalise("edge", axis)] = coord + if face_coords_and_axes is not None: + for coord, axis in face_coords_and_axes: + kwargs[normalise("face", axis)] = coord + + # check the UGRID minimum requirement for coordinates + if "node_x" not in kwargs: + emsg = ( + "Require a node coordinate that is x-axis like to be " "provided." + ) + raise ValueError(emsg) + if "node_y" not in kwargs: + emsg = ( + "Require a node coordinate that is y-axis like to be " "provided." + ) raise ValueError(emsg) - return f"{element}_{result}" - - if not isinstance(node_coords_and_axes, Iterable): - node_coords_and_axes = [node_coords_and_axes] - - if not isinstance(connectivities, Iterable): - connectivities = [connectivities] - - kwargs = {} - for coord, axis in node_coords_and_axes: - kwargs[normalise("node", axis)] = coord - if edge_coords_and_axes is not None: - for coord, axis in edge_coords_and_axes: - kwargs[normalise("edge", axis)] = coord - if face_coords_and_axes is not None: - for coord, axis in face_coords_and_axes: - kwargs[normalise("face", axis)] = coord - - # check the UGRID minimum requirement for coordinates - if "node_x" not in kwargs: - emsg = "Require a node coordinate that is x-axis like to be provided." - raise ValueError(emsg) - if "node_y" not in kwargs: - emsg = "Require a node coordinate that is y-axis like to be provided." - raise ValueError(emsg) - if self.topology_dimension == 1: - self._coord_manager = _Mesh1DCoordinateManager(**kwargs) - self._connectivity_manager = _Mesh1DConnectivityManager(*connectivities) - elif self.topology_dimension == 2: - self._coord_manager = _Mesh2DCoordinateManager(**kwargs) - self._connectivity_manager = _Mesh2DConnectivityManager(*connectivities) - else: - emsg = f"Unsupported 'topology_dimension', got {topology_dimension!r}." - raise NotImplementedError(emsg) + if self.topology_dimension == 1: + self._coord_man = _Mesh1DCoordinateManager(**kwargs) + self._connectivity_man = _Mesh1DConnectivityManager(*connectivities) + elif self.topology_dimension == 2: + self._coord_man = _Mesh2DCoordinateManager(**kwargs) + self._connectivity_man = _Mesh2DConnectivityManager(*connectivities) + else: + emsg = f"Unsupported 'topology_dimension', got {topology_dimension!r}." + raise NotImplementedError(emsg) @classmethod def from_coords(cls, *coords): @@ -1092,6 +1123,16 @@ def _set_dimension_names(self, node, edge, face, reset=False): return result + @property + def _connectivity_manager(self): + # Delivered as a property to help with overriding. + return self._connectivity_man + + @property + def _coord_manager(self): + # Delivered as a property to help with overriding. + return self._coord_man + @property def all_connectivities(self): """All the :class:`~iris.experimental.ugrid.mesh.Connectivity` instances of the :class:`Mesh`.""" @@ -1963,18 +2004,34 @@ def topology_dimension(self): class MeshIndexSet(Mesh): + class _IndexedMembers(dict): + def _readonly(self, *args, **kwargs): + message = ( + "Modification of MeshIndexSet is forbidden - this is only " + f"a view onto an original Mesh: id={self.mesh_id}." + ) + raise RuntimeError(message) + + __setitem__ = _readonly + __delitem__ = _readonly + pop = _readonly + popitem = _readonly + clear = _readonly + update = _readonly + setdefault = _readonly + + mesh_id: int + + def __init__(self, seq, **kwargs): + self.mesh_id = kwargs.pop("mesh_id") + super().__init__(seq, **kwargs) + def __init__(self, mesh, location, indices): + super().__init__(_copy_mesh=mesh) self.super_mesh = mesh self.location = location self.indices = indices - self._metadata_manager = MeshMetadata.from_metadata(mesh.metadata) - - self._coord_manager = _MeshIndexCoordinateManager(mesh, location, indices) - self._connectivity_manager = _MeshIndexConnectivityManager( - mesh, location, indices - ) - def __eq__(self, other): # TBD: this is a minimalist implementation and requires to be revisited return id(self) == id(other) @@ -1983,29 +2040,18 @@ def __ne__(self, other): # TBD: this is a minimalist implementation and requires to be revisited return id(self) != id(other) - -class _MeshIndexManager: - def __init__(self, mesh, location, indices): - self.mesh = mesh - self.location = location - self.indices = indices - - self.face_indices = self._calculate_face_indices() - self.edge_indices = self._calculate_edge_indices() - self.node_indices = self._calculate_node_indices() - self.node_index_dict = { - old_index: new_index - for new_index, old_index in enumerate(self.node_indices) - } - def _calculate_node_indices(self): if self.location == "node": result = self.indices elif self.location in ["edge", "face"]: (connectivity,) = [ c - for c in self.mesh.all_connectivities - if c.location == self.location and c.connected == "node" + for c in self.super_mesh.all_connectivities + if ( + c is not None + and c.location == self.location + and c.connected == "node" + ) ] conn_indices = connectivity.indices_by_location()[self.indices] node_set = list(set(conn_indices.compressed())) @@ -2037,216 +2083,32 @@ def _calculate_face_indices(self): result = None return result - -class _MeshIndexCoordinateManager(_MeshIndexManager): - REQUIRED = ( - "node_x", - "node_y", - ) - OPTIONAL = ( - "edge_x", - "edge_y", - "face_x", - "face_y", - ) - - def __init__(self, mesh, location, indices): - super().__init__(mesh, location, indices) - self.ALL = self.REQUIRED + self.OPTIONAL - self._members = {} - self._members = {member: getattr(self, member) for member in self.ALL} - - def __eq__(self, other): - # TBD: this is a minimalist implementation and requires to be revisited - return id(self) == id(other) - - def __ne__(self, other): - # TBD: this is a minimalist implementation and requires to be revisited - return id(self) != id(other) - - @property - def node_x(self): - if "node_x" in self._members: - return self._members["node_x"] - else: - return self.mesh._coord_manager.node_x[self.node_indices] - - @property - def node_y(self): - if "node_y" in self._members: - return self._members["node_y"] - else: - return self.mesh._coord_manager.node_y[self.node_indices] - - @property - def edge_x(self): - if "edge_x" in self._members: - return self._members["edge_x"] - else: - return self.mesh._coord_manager.edge_x[self.edge_indices] - @property - def edge_y(self): - if "edge_y" in self._members: - return self._members["edge_y"] - else: - return self.mesh._coord_manager.edge_y[self.edge_indices] - - @property - def face_x(self): - if "face_x" in self._members: - return self._members["face_x"] - else: - return self.mesh._coord_manager.face_x[self.face_indices] - - @property - def face_y(self): - if "face_y" in self._members: - return self._members["face_y"] - else: - return self.mesh._coord_manager.face_y[self.face_indices] - - @property - def node_coords(self): - return MeshNodeCoords(node_x=self.node_x, node_y=self.node_y) - - @property - def edge_coords(self): - return MeshEdgeCoords(edge_x=self.edge_x, edge_y=self.edge_y) - - @property - def face_coords(self): - return MeshFaceCoords(face_x=self.face_x, face_y=self.face_y) - - def filters( - self, - item=None, - standard_name=None, - long_name=None, - var_name=None, - attributes=None, - axis=None, - include_nodes=None, - include_edges=None, - include_faces=None, - ): - # TBD: support coord_systems? - - # Preserve original argument before modifying. - face_requested = include_faces - - # Rationalise the tri-state behaviour. - args = [include_nodes, include_edges, include_faces] - state = not any(set(filter(lambda arg: arg is not None, args))) - include_nodes, include_edges, include_faces = map( - lambda arg: arg if arg is not None else state, args - ) - - def populated_coords(coords_tuple): - return list(filter(None, list(coords_tuple))) - - members = [] - if include_nodes: - members += populated_coords(self.node_coords) - if include_edges: - members += populated_coords(self.edge_coords) - if hasattr(self, "face_coords"): - if include_faces: - members += populated_coords(self.face_coords) - elif face_requested: - dmsg = "Ignoring request to filter non-existent 'face_coords'" - logger.debug(dmsg, extra=dict(cls=self.__class__.__name__)) - - result = metadata_filter( - members, - item=item, - standard_name=standard_name, - long_name=long_name, - var_name=var_name, - attributes=attributes, - axis=axis, - ) - - # Use the results to filter the _members dict for returning. - result_ids = [id(r) for r in result] - result_dict = {k: v for k, v in self._members.items() if id(v) in result_ids} - return result_dict - - def filter(self, **kwargs): - # TODO: rationalise commonality with MeshConnectivityManager.filter and Cube.coord. - result = self.filters(**kwargs) - - if len(result) > 1: - names = ", ".join(f"{member}={coord!r}" for member, coord in result.items()) - emsg = ( - f"Expected to find exactly 1 coordinate, but found {len(result)}. " - f"They were: {names}." - ) - raise CoordinateNotFoundError(emsg) - - if len(result) == 0: - item = kwargs["item"] - if item is not None: - if not isinstance(item, str): - item = item.name() - name = ( - item - or kwargs["standard_name"] - or kwargs["long_name"] - or kwargs["var_name"] - or None - ) - name = "" if name is None else f"{name!r} " - emsg = f"Expected to find exactly 1 {name}coordinate, but found none." - raise CoordinateNotFoundError(emsg) + def _coord_manager(self): + # Intended to be a 'view' on the original, and Meshes are mutable, so + # must re-index every time it is accessed. + with self.super_mesh._coord_man.indexed( + self._calculate_node_indices(), + self._calculate_edge_indices(), + self._calculate_face_indices(), + mesh_id=id(self.super_mesh), + ) as indexed_coord_man: + result = copy(indexed_coord_man) return result - -class _MeshIndexConnectivityManager(_MeshIndexManager): - @property - def edge_node(self): - if self.edge_indices is None: - return None - else: - connectivity = self.mesh.edge_node_connectivity[self.edge_indices] - connectivity_indices = np.vectorize(self.node_index_dict.get)( - connectivity.indices - ) - connectivity = Connectivity( - connectivity_indices, - connectivity.cf_role, - standard_name=connectivity.standard_name, - long_name=connectivity.long_name, - var_name=connectivity.var_name, - units=connectivity.units, - attributes=connectivity.attributes, - start_index=connectivity.start_index, - location_axis=connectivity.location_axis, - ) - return connectivity - @property - def face_node(self): - if self.face_indices is None: - return None - else: - connectivity = self.mesh.face_node_connectivity[self.face_indices] - connectivity_indices = np.vectorize(self.node_index_dict.get)( - connectivity.indices - ) - connectivity = Connectivity( - connectivity_indices, - connectivity.cf_role, - standard_name=connectivity.standard_name, - long_name=connectivity.long_name, - var_name=connectivity.var_name, - units=connectivity.units, - attributes=connectivity.attributes, - start_index=connectivity.start_index, - location_axis=connectivity.location_axis, - ) - return connectivity + def _connectivity_manager(self): + # Intended to be a 'view' on the original, and Meshes are mutable, so + # must re-index every time it is accessed. + with self.super_mesh._connectivity_man.indexed( + self._calculate_node_indices(), + self._calculate_edge_indices(), + self._calculate_face_indices(), + mesh_id=id(self.super_mesh), + ) as indexed_connectivity_man: + result = copy(indexed_connectivity_man) + return result class _Mesh1DCoordinateManager: @@ -2306,6 +2168,41 @@ def __str__(self): args = [f"{member}" for member, coord in self if coord is not None] return f"{self.__class__.__name__}({', '.join(args)})" + @contextmanager + def indexed( + self, + node_indices: np.typing.ArrayLike, + edge_indices: np.typing.ArrayLike, + face_indices: np.typing.ArrayLike, + mesh_id: int, + ): + # TODO: make members_indexed read-only. + members_original = copy(self._members) + + indices_dict = { + "node": node_indices, + "edge": edge_indices, + "face": face_indices, + } + members_indexed = {} + for key, coord in self._members.items(): + indexing = None + indexed = None + if coord is not None: + indexing = indices_dict[key.split("_")[0]] + if indexing is not None: + indexed = coord[indexing] + members_indexed[key] = indexed + + try: + self._members = MeshIndexSet._IndexedMembers( + members_indexed, + mesh_id=mesh_id, + ) + yield self + finally: + self._members = members_original + def _remove(self, **kwargs): result = {} members = self.filters(**kwargs) @@ -2834,6 +2731,48 @@ def element_filter(instances, loc_arg, loc_name): result_dict = {k: v for k, v in self._members.items() if id(v) in result_ids} return result_dict + @contextmanager + def indexed( + self, + node_indices: np.typing.ArrayLike, + edge_indices: np.typing.ArrayLike, + face_indices: np.typing.ArrayLike, + mesh_id: int, + ): + # TODO: make members_indexed read-only. + members_original = copy(self._members) + + indices_dict = { + "node": node_indices, + "edge": edge_indices, + "face": face_indices, + } + node_index_mapping = { + old_index: new_index for new_index, old_index in enumerate(node_indices) + } + members_indexed = {} + for key, connectivity in self._members.items(): + indexing = None + indexed = None + if connectivity is not None: + indexing = indices_dict[connectivity.location] + if indexing is not None: + new_values = connectivity.indices_by_location()[indexing] + new_values = np.vectorize(node_index_mapping.get)(new_values) + if connectivity.location_axis == 1: + new_values = new_values.T + indexed = connectivity.copy(new_values) + members_indexed[key] = indexed + + try: + self._members = MeshIndexSet._IndexedMembers( + members_indexed, + mesh_id=mesh_id, + ) + yield self + finally: + self._members = members_original + def remove( self, item=None, From 6b2448937acf41952e039e4bb46bca1fb55d8e53 Mon Sep 17 00:00:00 2001 From: Martin Yeo Date: Mon, 12 Aug 2024 12:25:16 +0100 Subject: [PATCH 8/9] Don't need indices_by_location. --- lib/iris/experimental/ugrid/mesh.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/iris/experimental/ugrid/mesh.py b/lib/iris/experimental/ugrid/mesh.py index b465f3660d..6c75e65811 100644 --- a/lib/iris/experimental/ugrid/mesh.py +++ b/lib/iris/experimental/ugrid/mesh.py @@ -2053,7 +2053,8 @@ def _calculate_node_indices(self): and c.connected == "node" ) ] - conn_indices = connectivity.indices_by_location()[self.indices] + # Doesn't matter if connectivity is transposed or not in this case. + conn_indices = connectivity.indices[self.indices] node_set = list(set(conn_indices.compressed())) node_set.sort() result = node_set From 63702e52770e117bd70b18b9b314ab473e7efa46 Mon Sep 17 00:00:00 2001 From: Martin Yeo Date: Wed, 14 Aug 2024 16:15:00 +0100 Subject: [PATCH 9/9] Use MeshIndexSet.from_mesh method. --- lib/iris/experimental/ugrid/mesh.py | 186 ++++++++++++++-------------- 1 file changed, 90 insertions(+), 96 deletions(-) diff --git a/lib/iris/experimental/ugrid/mesh.py b/lib/iris/experimental/ugrid/mesh.py index 6c75e65811..342ff5c735 100644 --- a/lib/iris/experimental/ugrid/mesh.py +++ b/lib/iris/experimental/ugrid/mesh.py @@ -19,6 +19,7 @@ from cf_units import Unit from dask import array as da import numpy as np +from numpy.typing import ArrayLike from ... import _lazy_data as _lazy from ...common import CFVariableMixin, metadata_filter, metadata_manager_factory @@ -605,9 +606,9 @@ class Mesh(CFVariableMixin): def __init__( self, - topology_dimension=None, - node_coords_and_axes=None, - connectivities=None, + topology_dimension, + node_coords_and_axes, + connectivities, edge_coords_and_axes=None, face_coords_and_axes=None, standard_name=None, @@ -618,7 +619,6 @@ def __init__( node_dimension=None, edge_dimension=None, face_dimension=None, - _copy_mesh=None, ): """Mesh initialise. @@ -634,94 +634,66 @@ def __init__( # TODO: support volumes. # TODO: support (coord, "z") - copy_mode = ( - _copy_mesh is not None and getattr(_copy_mesh, "cf_role") == "mesh_topology" - ) - - if copy_mode: - self._metadata_manager = _copy_mesh.metadata - self._coord_man = _copy_mesh._coord_manager - self._connectivity_man = _copy_mesh._connectivity_manager - else: - mandatory = [topology_dimension, node_coords_and_axes, connectivities] - if any([m is None for m in mandatory]): - message = ( - "1 or more of mandatory arguments missing: " - "topology_dimension, node_coords_and_axes, connectivities." - ) - raise ValueError(message) - - self._metadata_manager = metadata_manager_factory(MeshMetadata) + self._metadata_manager = metadata_manager_factory(MeshMetadata) - # topology_dimension is read-only, so assign directly to the metadata manager - if topology_dimension not in self.TOPOLOGY_DIMENSIONS: - emsg = ( - f"Expected 'topology_dimension' in range " - f"{self.TOPOLOGY_DIMENSIONS!r}, got " - f"{topology_dimension!r}." - ) - raise ValueError(emsg) - self._metadata_manager.topology_dimension = topology_dimension - - self.node_dimension = node_dimension - self.edge_dimension = edge_dimension - self.face_dimension = face_dimension - - # assign the metadata to the metadata manager - self.standard_name = standard_name - self.long_name = long_name - self.var_name = var_name - self.units = units - self.attributes = attributes - - # based on the topology_dimension, create the appropriate coordinate manager - def normalise(element, axis): - result = str(axis).lower() - if result not in self.AXES: - emsg = ( - f"Invalid axis specified for {element} coordinate " - f"{coord.name()!r}, got {axis!r}." - ) - raise ValueError(emsg) - return f"{element}_{result}" - - if not isinstance(node_coords_and_axes, Iterable): - node_coords_and_axes = [node_coords_and_axes] - - if not isinstance(connectivities, Iterable): - connectivities = [connectivities] - - kwargs = {} - for coord, axis in node_coords_and_axes: - kwargs[normalise("node", axis)] = coord - if edge_coords_and_axes is not None: - for coord, axis in edge_coords_and_axes: - kwargs[normalise("edge", axis)] = coord - if face_coords_and_axes is not None: - for coord, axis in face_coords_and_axes: - kwargs[normalise("face", axis)] = coord - - # check the UGRID minimum requirement for coordinates - if "node_x" not in kwargs: - emsg = ( - "Require a node coordinate that is x-axis like to be " "provided." - ) - raise ValueError(emsg) - if "node_y" not in kwargs: - emsg = ( - "Require a node coordinate that is y-axis like to be " "provided." - ) + # topology_dimension is read-only, so assign directly to the metadata manager + if topology_dimension not in self.TOPOLOGY_DIMENSIONS: + emsg = f"Expected 'topology_dimension' in range {self.TOPOLOGY_DIMENSIONS!r}, got {topology_dimension!r}." + raise ValueError(emsg) + self._metadata_manager.topology_dimension = topology_dimension + + self.node_dimension = node_dimension + self.edge_dimension = edge_dimension + self.face_dimension = face_dimension + + # assign the metadata to the metadata manager + self.standard_name = standard_name + self.long_name = long_name + self.var_name = var_name + self.units = units + self.attributes = attributes + + # based on the topology_dimension, create the appropriate coordinate manager + def normalise(element, axis): + result = str(axis).lower() + if result not in self.AXES: + emsg = f"Invalid axis specified for {element} coordinate {coord.name()!r}, got {axis!r}." raise ValueError(emsg) + return f"{element}_{result}" + + if not isinstance(node_coords_and_axes, Iterable): + node_coords_and_axes = [node_coords_and_axes] + + if not isinstance(connectivities, Iterable): + connectivities = [connectivities] + + kwargs = {} + for coord, axis in node_coords_and_axes: + kwargs[normalise("node", axis)] = coord + if edge_coords_and_axes is not None: + for coord, axis in edge_coords_and_axes: + kwargs[normalise("edge", axis)] = coord + if face_coords_and_axes is not None: + for coord, axis in face_coords_and_axes: + kwargs[normalise("face", axis)] = coord + + # check the UGRID minimum requirement for coordinates + if "node_x" not in kwargs: + emsg = "Require a node coordinate that is x-axis like to be provided." + raise ValueError(emsg) + if "node_y" not in kwargs: + emsg = "Require a node coordinate that is y-axis like to be provided." + raise ValueError(emsg) - if self.topology_dimension == 1: - self._coord_man = _Mesh1DCoordinateManager(**kwargs) - self._connectivity_man = _Mesh1DConnectivityManager(*connectivities) - elif self.topology_dimension == 2: - self._coord_man = _Mesh2DCoordinateManager(**kwargs) - self._connectivity_man = _Mesh2DConnectivityManager(*connectivities) - else: - emsg = f"Unsupported 'topology_dimension', got {topology_dimension!r}." - raise NotImplementedError(emsg) + if self.topology_dimension == 1: + self._coord_man = _Mesh1DCoordinateManager(**kwargs) + self._connectivity_man = _Mesh1DConnectivityManager(*connectivities) + elif self.topology_dimension == 2: + self._coord_man = _Mesh2DCoordinateManager(**kwargs) + self._connectivity_man = _Mesh2DConnectivityManager(*connectivities) + else: + emsg = f"Unsupported 'topology_dimension', got {topology_dimension!r}." + raise NotImplementedError(emsg) @classmethod def from_coords(cls, *coords): @@ -1091,8 +1063,8 @@ def line(text, i_indent=0): def __setstate__(self, state): metadata_manager, coord_manager, connectivity_manager = state self._metadata_manager = metadata_manager - self._coord_manager = coord_manager - self._connectivity_manager = connectivity_manager + self._coord_man = coord_manager + self._connectivity_man = connectivity_manager def _set_dimension_names(self, node, edge, face, reset=False): args = (node, edge, face) @@ -2004,6 +1976,10 @@ def topology_dimension(self): class MeshIndexSet(Mesh): + super_mesh: Mesh = None + location: str = None + indices: ArrayLike = None + class _IndexedMembers(dict): def _readonly(self, *args, **kwargs): message = ( @@ -2026,11 +2002,26 @@ def __init__(self, seq, **kwargs): self.mesh_id = kwargs.pop("mesh_id") super().__init__(seq, **kwargs) - def __init__(self, mesh, location, indices): - super().__init__(_copy_mesh=mesh) - self.super_mesh = mesh - self.location = location - self.indices = indices + @classmethod + def from_mesh(cls, mesh, location, indices): + instance = copy(mesh) + instance.__class__ = cls + instance.super_mesh = mesh + instance.location = location + instance.indices = indices + return instance + + def _validate(self) -> None: + # Instances are expected to be created via the from_mesh() method. + # If __init__() is used instead, the class will not be fully set up. + # Actually disabling __init__() is considered an anti-pattern. + if any(v is None for v in [self.super_mesh, self.location, self.indices]): + message = ( + "MeshIndexSet instance not fully set up. Note that " + "MeshIndexSet is designed to be created via the from_mesh() " + "method, not via __init__()." + ) + raise RuntimeError(message) def __eq__(self, other): # TBD: this is a minimalist implementation and requires to be revisited @@ -2041,6 +2032,7 @@ def __ne__(self, other): return id(self) != id(other) def _calculate_node_indices(self): + self._validate() if self.location == "node": result = self.indices elif self.location in ["edge", "face"]: @@ -2071,6 +2063,7 @@ def _calculate_node_indices(self): return result def _calculate_edge_indices(self): + self._validate() if self.location == "edge": result = self.indices else: @@ -2078,6 +2071,7 @@ def _calculate_edge_indices(self): return result def _calculate_face_indices(self): + self._validate() if self.location == "face": result = self.indices else: