diff --git a/pytato/distributed/__init__.py b/pytato/distributed/__init__.py index 04b368afb..4354b2f0f 100644 --- a/pytato/distributed/__init__.py +++ b/pytato/distributed/__init__.py @@ -22,7 +22,8 @@ .. class:: CommTagType - A type representing a communication tag. + A type representing a communication tag. Communication tags must be + hashable and totally ordered (and hence comparable). .. class:: ShapeType diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 9e3bde8d0..41ae3273c 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -31,7 +31,7 @@ """ -from typing import TYPE_CHECKING, Tuple, FrozenSet, Any +from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar from pytato.distributed.partition import DistributedGraphPartition @@ -40,6 +40,9 @@ import mpi4py.MPI +T = TypeVar("T") + + # {{{ construct tag numbering def number_distributed_tags( @@ -59,6 +62,10 @@ def number_distributed_tags( This is a potentially heavyweight MPI-collective operation on *mpi_communicator*. + + .. note:: + + This function requires that symbolic tags are comparable. """ tags = frozenset({ recv.comm_tag @@ -73,8 +80,8 @@ def number_distributed_tags( from mpi4py import MPI def set_union( - set_a: FrozenSet[Any], set_b: FrozenSet[Any], - mpi_data_type: MPI.Datatype) -> FrozenSet[str]: + set_a: FrozenSet[T], set_b: FrozenSet[T], + mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]: assert mpi_data_type is None assert isinstance(set_a, frozenset) assert isinstance(set_b, frozenset) @@ -99,7 +106,7 @@ def set_union( next_tag = base_tag assert isinstance(all_tags, frozenset) - for sym_tag in all_tags: + for sym_tag in sorted(all_tags): sym_tag_to_int_tag[sym_tag] = next_tag next_tag += 1