From 3364a4fda9cc11d46c2ebe8e32fee0b161da27cf Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:19:06 -0500 Subject: [PATCH 01/22] working pass 1 --- pytato/analysis/__init__.py | 50 +++++++++---------- pytato/distributed/partition.py | 86 ++++++++++++++++++--------------- pytato/transform/__init__.py | 5 +- 3 files changed, 74 insertions(+), 67 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 38ed276fe..1a4359e4b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -310,48 +310,48 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter +from orderedsets import FrozenOrderedSet + class DirectPredecessorsGetter(Mapper): """ Mapper to get the `direct predecessors `__ of a node. - .. note:: - We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]: - return frozenset({dim for dim in shape if isinstance(dim, Array)}) + def _get_preds_from_shape(self, shape: ShapeType) -> abc_Set[Array]: + return FrozenOrderedSet([dim for dim in shape if isinstance(dim, Array)]) - def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]: - return (frozenset(expr.bindings.values()) + def map_index_lambda(self, expr: IndexLambda) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> frozenset[Array]: - return (frozenset(expr.arrays) + def map_stack(self, expr: Stack) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_concatenate(self, expr: Concatenate) -> frozenset[Array]: - return (frozenset(expr.arrays) + def map_concatenate(self, expr: Concatenate) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> frozenset[Array]: - return (frozenset(expr.args) + def map_einsum(self, expr: Einsum) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.args) | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]: - from pytato.loopy import LoopyCall, LoopyCallResult + def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: + from pytato.loopy import LoopyCallResult, LoopyCall assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (frozenset(ary + return (FrozenOrderedSet(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: - return (frozenset([expr.array]) - | frozenset(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> abc_Set[Array]: + return (FrozenOrderedSet([expr.array]) + | FrozenOrderedSet(idx for idx in expr.indices if isinstance(idx, Array)) | self._get_preds_from_shape(expr.shape)) @@ -360,29 +360,29 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> frozenset[Array]: - return frozenset([expr.array]) + ) -> abc_Set[Array]: + return FrozenOrderedSet([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]: + def _map_input_base(self, expr: InputArgumentBase) -> abc_Set[Array]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]: + def map_distributed_recv(self, expr: DistributedRecv) -> abc_Set[Array]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[Array]: - return frozenset([expr.passthrough_data]) + ) -> abc_Set[Array]: + return FrozenOrderedSet([expr.passthrough_data]) - def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]: + def map_named_call_result(self, expr: NamedCallResult) -> abc_Set[Array]: raise NotImplementedError( "DirectPredecessorsGetter does not yet support expressions containing " "functions.") diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 5865ec491..7c9b510e7 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -476,9 +476,10 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: frozenset[CommunicationOpIdentifier] - ) -> frozenset[CommunicationOpIdentifier]: - return reduce(frozenset.union, args, frozenset()) + self, *args: Tuple[CommunicationOpIdentifier] + ) -> Tuple[CommunicationOpIdentifier]: + from pytools import unique + return reduce(lambda x, y: tuple(unique(x+y)), args, tuple()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder @@ -496,8 +497,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: - return frozenset() + def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: + return tuple() map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -505,21 +506,21 @@ def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> frozenset[CommunicationOpIdentifier]: + ) -> Tuple[CommunicationOpIdentifier]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = frozenset() + self.local_comm_ids_to_needed_comm_ids[recv_id] = tuple() self.local_recv_id_to_recv_node[recv_id] = expr - return frozenset({recv_id}) + return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> frozenset[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> Tuple[CommunicationOpIdentifier]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -557,10 +558,10 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[set[TaskType]] = [set() for _ in range(nlevels)] + task_batches: Sequence[List[TaskType]] = [list() for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): - task_batches[dep_level].add(task_id) + task_batches[dep_level].append(task_id) return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -623,7 +624,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: _OrderedSet[Array] = _OrderedSet() + self.materialized_arrays: List[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -633,15 +634,15 @@ def post_visit(self, expr: Any) -> None: from pytato.tags import ImplStored if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): - self.materialized_arrays.add(expr) + self.materialized_arrays.append(expr) if isinstance(expr, LoopyCallResult): - self.materialized_arrays.add(expr) + self.materialized_arrays.append(expr) from pytato.loopy import LoopyCall assert isinstance(expr._container, LoopyCall) for _, subexpr in sorted(expr._container.bindings.items()): if isinstance(subexpr, Array): - self.materialized_arrays.add(subexpr) + self.materialized_arrays.append(subexpr) else: assert isinstance(subexpr, SCALAR_CLASSES) @@ -651,13 +652,13 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, frozenset[_ValueT]], - dict_b: Mapping[_KeyT, frozenset[_ValueT]], - mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, frozenset[_ValueT]]: + dict_a: Mapping[_KeyT, Sequence[_ValueT]], + dict_b: Mapping[_KeyT, Sequence[_ValueT]], + mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, frozenset()) | values + result[key] = result.get(key, tuple()) + values return result # }}} @@ -782,6 +783,8 @@ def find_distributed_partition( - Gather sent arrays into assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. """ + from pytools import unique + import mpi4py.MPI as MPI from pytato.transform import SubsetDependencyMapper @@ -833,12 +836,13 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: list[_PartCommIDs] = [] + + part_comm_ids: List[_PartCommIDs] = [] if comm_batches: - recv_ids: frozenset[CommunicationOpIdentifier] = frozenset() + recv_ids: Tuple[CommunicationOpIdentifier] = tuple() for batch in comm_batches: - send_ids = frozenset( - comm_id for comm_id in batch + send_ids = tuple( + comm_id for comm_id in unique(batch) if comm_id.src_rank == local_rank) if recv_ids or send_ids: part_comm_ids.append( @@ -846,19 +850,19 @@ def find_distributed_partition( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = frozenset( - comm_id for comm_id in batch + recv_ids = tuple( + comm_id for comm_id in unique(batch) if comm_id.dest_rank == local_rank) if recv_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=frozenset())) + send_ids=tuple())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=frozenset(), - send_ids=frozenset())) + recv_ids=tuple(), + send_ids=tuple())) nparts = len(part_comm_ids) @@ -876,7 +880,7 @@ def find_distributed_partition( comm_id_to_part_id = { comm_id: ipart for ipart, comm_ids in enumerate(part_comm_ids) - for comm_id in comm_ids.send_ids | comm_ids.recv_ids} + for comm_id in unique(comm_ids.send_ids + comm_ids.recv_ids)} # }}} @@ -888,10 +892,10 @@ def find_distributed_partition( # The sets of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = _OrderedSet( + sent_arrays = tuple( send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = _OrderedSet(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays = tuple(lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -899,14 +903,16 @@ def find_distributed_partition( # We could allow sent *arrays* to be included here because they are distinct # from send *nodes*, but we choose to exclude them in order to simplify the # processing below. - materialized_arrays = ( - materialized_arrays_collector.materialized_arrays - - received_arrays - - sent_arrays) + materialized_arrays_set = set(materialized_arrays_collector.materialized_arrays) \ + - set(received_arrays) \ + - set(sent_arrays) + + from pytools import unique + materialized_arrays = tuple(a for a in materialized_arrays_collector.materialized_arrays if a in materialized_arrays_set) # "mso" for "materialized/sent/output" - output_arrays = _OrderedSet(outputs._data.values()) - mso_arrays = materialized_arrays | sent_arrays | output_arrays + output_arrays = tuple(outputs._data.values()) + mso_arrays = materialized_arrays + sent_arrays + output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. @@ -970,7 +976,7 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = _OrderedSet(stored_ary_to_part_id) + stored_arrays = tuple(unique(stored_ary_to_part_id)) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) @@ -986,13 +992,13 @@ def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: materialized_preds |= get_materialized_predecessors(pred) return materialized_preds - stored_arrays_promoted_to_part_outputs = { + stored_arrays_promoted_to_part_outputs = tuple(unique( stored_pred for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - } + )) # }}} diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b78c24301..56b2a53d6 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,9 +926,10 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce - return reduce(lambda acc, arg: acc | (arg & self.universe), + from pytools import unique + return reduce(lambda acc, arg: unique(tuple(acc) + tuple(set(arg) & self.universe)), args, - frozenset()) + tuple()) # }}} From ef7ea0bb74e1543b317558017f03f7be6352e5c1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:33:06 -0500 Subject: [PATCH 02/22] cleanups --- pytato/analysis/__init__.py | 6 ++++- pytato/distributed/partition.py | 43 ++++++++++++++++++--------------- pytato/transform/__init__.py | 6 +++-- setup.py | 1 + 4 files changed, 33 insertions(+), 23 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 1a4359e4b..d7a8e3353 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -310,14 +310,18 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter +from collections.abc import Set as abc_Set + from orderedsets import FrozenOrderedSet + class DirectPredecessorsGetter(Mapper): """ Mapper to get the `direct predecessors `__ of a node. + .. note:: We only consider the predecessors of a nodes in a data-flow sense. """ @@ -341,7 +345,7 @@ def map_einsum(self, expr: Einsum) -> abc_Set[Array]: | self._get_preds_from_shape(expr.shape)) def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: - from pytato.loopy import LoopyCallResult, LoopyCall + from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) return (FrozenOrderedSet(ary diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 7c9b510e7..2d4b1c93a 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -476,10 +476,10 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: Tuple[CommunicationOpIdentifier] - ) -> Tuple[CommunicationOpIdentifier]: + self, *args: tuple[CommunicationOpIdentifier] + ) -> tuple[CommunicationOpIdentifier]: from pytools import unique - return reduce(lambda x, y: tuple(unique(x+y)), args, tuple()) + return reduce(lambda x, y: tuple(unique(x+y)), args, ()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder @@ -497,8 +497,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: - return tuple() + def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier]: + return () map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -506,21 +506,21 @@ def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> Tuple[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = tuple() + self.local_comm_ids_to_needed_comm_ids[recv_id] = () self.local_recv_id_to_recv_node[recv_id] = expr return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> Tuple[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -558,7 +558,7 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[List[TaskType]] = [list() for _ in range(nlevels)] + task_batches: Sequence[list[TaskType]] = [[] for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): task_batches[dep_level].append(task_id) @@ -624,7 +624,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: List[Array] = [] + self.materialized_arrays: list[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -658,7 +658,7 @@ def _set_dict_union_mpi( assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, tuple()) + values + result[key] = result.get(key, ()) + values return result # }}} @@ -783,10 +783,10 @@ def find_distributed_partition( - Gather sent arrays into assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. """ - from pytools import unique - import mpi4py.MPI as MPI + from pytools import unique + from pytato.transform import SubsetDependencyMapper local_rank = mpi_communicator.rank @@ -837,9 +837,9 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: List[_PartCommIDs] = [] + part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: Tuple[CommunicationOpIdentifier] = tuple() + recv_ids: tuple[CommunicationOpIdentifier] = () for batch in comm_batches: send_ids = tuple( comm_id for comm_id in unique(batch) @@ -857,12 +857,12 @@ def find_distributed_partition( part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=tuple())) + send_ids=())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=tuple(), - send_ids=tuple())) + recv_ids=(), + send_ids=())) nparts = len(part_comm_ids) @@ -908,7 +908,9 @@ def find_distributed_partition( - set(sent_arrays) from pytools import unique - materialized_arrays = tuple(a for a in materialized_arrays_collector.materialized_arrays if a in materialized_arrays_set) + materialized_arrays = tuple( + a for a in materialized_arrays_collector.materialized_arrays + if a in materialized_arrays_set) # "mso" for "materialized/sent/output" output_arrays = tuple(outputs._data.values()) @@ -927,7 +929,8 @@ def find_distributed_partition( comm_id_to_part_id[send_id]) if __debug__: - recvd_array_dep_mapper = SubsetDependencyMapper(frozenset(received_arrays)) + recvd_array_dep_mapper = SubsetDependencyMapper(frozenset + (received_arrays)) mso_ary_to_last_dep_recv_part_id: dict[Array, int] = { ary: max( diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 56b2a53d6..642f52839 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,10 +926,12 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce + from pytools import unique - return reduce(lambda acc, arg: unique(tuple(acc) + tuple(set(arg) & self.universe)), + return reduce(lambda acc, arg: + unique(tuple(acc) + tuple(set(arg) & self.universe)), args, - tuple()) + ()) # }}} diff --git a/setup.py b/setup.py index ba0bd1b4d..9fe0df6b1 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ "immutabledict", "attrs", "bidict", + "orderedsets", ], package_data={"pytato": ["py.typed"]}, author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", From 817b255ca54f3232bc01fbd9c11ece34b54b7cac Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:44:33 -0500 Subject: [PATCH 03/22] enable determinism test --- test/test_distributed.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index ac7ca1389..1554a024b 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -899,13 +899,11 @@ def test_number_symbolic_tags_bare_classes(ctx_factory): outputs = pt.make_dict_of_named_arrays({"out": res}) partition = pt.find_distributed_partition(comm, outputs) - (_distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) + (distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) assert next_tag == 4244 - # FIXME: For the next assertion, find_distributed_partition needs to be - # deterministic too (https://github.com/inducer/pytato/pull/465). - # assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501 + assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # }}} From f3f3c7df968088a05a2792e189d43a70c14e34c5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:54:43 -0500 Subject: [PATCH 04/22] eliminate _OrderedSets --- pytato/distributed/partition.py | 70 +++------------------------------ 1 file changed, 6 insertions(+), 64 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 2d4b1c93a..6d3adb319 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -62,7 +62,6 @@ THE SOFTWARE. """ -import collections from functools import reduce from typing import ( TYPE_CHECKING, @@ -70,8 +69,6 @@ Any, FrozenSet, Hashable, - Iterable, - Iterator, Mapping, Sequence, TypeVar, @@ -131,61 +128,6 @@ class CommunicationOpIdentifier: _ValueT = TypeVar("_ValueT") -# {{{ crude ordered set - - -class _OrderedSet(collections.abc.MutableSet[_ValueT]): - def __init__(self, items: Iterable[_ValueT] | None = None): - # Could probably also use a valueless dictionary; not sure if it matters - self._items: set[_ValueT] = set() - self._items_ordered: list[_ValueT] = [] - if items is not None: - for item in items: - self.add(item) - - def add(self, item: _ValueT) -> None: - if item not in self._items: - self._items.add(item) - self._items_ordered.append(item) - - def discard(self, item: _ValueT) -> None: - # Not currently needed - raise NotImplementedError - - def __len__(self) -> int: - return len(self._items) - - def __iter__(self) -> Iterator[_ValueT]: - return iter(self._items_ordered) - - def __contains__(self, item: Any) -> bool: - return item in self._items - - def __and__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet() - for item in self._items_ordered: - if item in other: - result.add(item) - return result - - # Must be "Any" instead of "_ValueT", otherwise it violates Liskov substitution - # according to mypy. *shrug* - def __or__(self, other: AbstractSet[Any]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet(self._items_ordered) - for item in other: - result.add(item) - return result - - def __sub__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet() - for item in self._items_ordered: - if item not in other: - result.add(item) - return result - -# }}} - - # {{{ distributed graph part PartId = Hashable @@ -836,7 +778,6 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: list[_PartCommIDs] = [] if comm_batches: recv_ids: tuple[CommunicationOpIdentifier] = () @@ -986,14 +927,15 @@ def find_distributed_partition( direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: - materialized_preds: _OrderedSet[Array] = _OrderedSet() + def get_materialized_predecessors(ary: Array) -> tuple[Array]: + materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): if pred in materialized_arrays: - materialized_preds.add(pred) + materialized_preds[pred] = None else: - materialized_preds |= get_materialized_predecessors(pred) - return materialized_preds + for p in get_materialized_predecessors(pred): + materialized_preds[p] = None + return tuple(materialized_preds.keys()) stored_arrays_promoted_to_part_outputs = tuple(unique( stored_pred From 8bf2daf69260673b03ed8ef6fd2c03d01483eb04 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 17:29:07 -0500 Subject: [PATCH 05/22] misc improvements --- pytato/distributed/partition.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 6d3adb319..81a493c9a 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -503,7 +503,8 @@ def _schedule_task_batches_counted( task_batches: Sequence[list[TaskType]] = [[] for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): - task_batches[dep_level].append(task_id) + if task_id not in task_batches[dep_level]: + task_batches[dep_level].append(task_id) return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -566,7 +567,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: list[Array] = [] + self.materialized_arrays: dict[Array, None] = {} def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -576,15 +577,15 @@ def post_visit(self, expr: Any) -> None: from pytato.tags import ImplStored if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): - self.materialized_arrays.append(expr) + self.materialized_arrays[expr] = None if isinstance(expr, LoopyCallResult): - self.materialized_arrays.append(expr) + self.materialized_arrays[expr] = None from pytato.loopy import LoopyCall assert isinstance(expr._container, LoopyCall) for _, subexpr in sorted(expr._container.bindings.items()): if isinstance(subexpr, Array): - self.materialized_arrays.append(subexpr) + self.materialized_arrays[subexpr] = None else: assert isinstance(subexpr, SCALAR_CLASSES) @@ -596,11 +597,12 @@ def post_visit(self, expr: Any) -> None: def _set_dict_union_mpi( dict_a: Mapping[_KeyT, Sequence[_ValueT]], dict_b: Mapping[_KeyT, Sequence[_ValueT]], - mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, Sequence[_ValueT]]: + mpi_data_type: mpi4py.MPI.Datatype | None) -> Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None + from pytools import unique result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, ()) + values + result[key] = tuple(unique(result.get(key, ()) + values)) return result # }}} @@ -833,10 +835,10 @@ def find_distributed_partition( # The sets of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = tuple( - send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) + sent_arrays = tuple(unique( + send_node.data for send_node in lsrdg.local_send_id_to_send_node.values())) - received_arrays = tuple(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays = tuple(unique(lsrdg.local_recv_id_to_recv_node.values())) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -849,13 +851,13 @@ def find_distributed_partition( - set(sent_arrays) from pytools import unique - materialized_arrays = tuple( + materialized_arrays = tuple(unique( a for a in materialized_arrays_collector.materialized_arrays - if a in materialized_arrays_set) + if a in materialized_arrays_set)) # "mso" for "materialized/sent/output" - output_arrays = tuple(outputs._data.values()) - mso_arrays = materialized_arrays + sent_arrays + output_arrays + output_arrays = tuple(unique(outputs._data.values())) + mso_arrays = tuple(unique(materialized_arrays + sent_arrays + output_arrays)) # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. @@ -870,8 +872,7 @@ def find_distributed_partition( comm_id_to_part_id[send_id]) if __debug__: - recvd_array_dep_mapper = SubsetDependencyMapper(frozenset - (received_arrays)) + recvd_array_dep_mapper = SubsetDependencyMapper(frozenset(received_arrays)) mso_ary_to_last_dep_recv_part_id: dict[Array, int] = { ary: max( From 5d906b5edb98425f48bcc3aada14725f0c9223b1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 17:41:48 -0500 Subject: [PATCH 06/22] revert change to SubsetDependencyMapper --- pytato/transform/__init__.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 642f52839..b78c24301 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,12 +926,9 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce - - from pytools import unique - return reduce(lambda acc, arg: - unique(tuple(acc) + tuple(set(arg) & self.universe)), + return reduce(lambda acc, arg: acc | (arg & self.universe), args, - ()) + frozenset()) # }}} From 142c8e63cf4ca9d5c570fb4e99a3d86594770bc8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 17:50:03 -0500 Subject: [PATCH 07/22] some mypy fixes --- pytato/distributed/partition.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 81a493c9a..e8f4b1fb2 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -67,7 +67,6 @@ TYPE_CHECKING, AbstractSet, Any, - FrozenSet, Hashable, Mapping, Sequence, @@ -316,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. """ - recv_ids: frozenset[CommunicationOpIdentifier] - send_ids: frozenset[CommunicationOpIdentifier] + recv_ids: tuple[CommunicationOpIdentifier] + send_ids: tuple[CommunicationOpIdentifier] # {{{ _make_distributed_partition @@ -403,12 +402,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[FrozenSet[CommunicationOpIdentifier]]): + CombineMapper[tuple[CommunicationOpIdentifier]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - frozenset[CommunicationOpIdentifier]] = {} + tuple[CommunicationOpIdentifier]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -425,7 +424,7 @@ def combine( def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -476,7 +475,7 @@ def map_named_call_result( def _schedule_task_batches( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> Sequence[AbstractSet[TaskType]]: + -> Sequence[list[TaskType]]: """For each :type:`TaskType`, determine the 'round'/'batch' during which it will be performed. A 'batch' of tasks consists of tasks which do not depend on each other. @@ -491,7 +490,7 @@ def _schedule_task_batches( def _schedule_task_batches_counted( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> tuple[Sequence[AbstractSet[TaskType]], int]: + -> tuple[Sequence[list[TaskType]], int]: """ Static type checkers need the functions to return the same type regardless of the input. The testing code needs to know about the number of tasks visited @@ -773,7 +772,7 @@ def find_distributed_partition( raise comm_batches_or_exc comm_batches = cast( - Sequence[AbstractSet[CommunicationOpIdentifier]], + Sequence[list[CommunicationOpIdentifier]], comm_batches_or_exc) # }}} @@ -928,7 +927,7 @@ def find_distributed_partition( direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> tuple[Array]: + def get_materialized_predecessors(ary: Array) -> tuple[Array, ...]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): if pred in materialized_arrays: From bd7062071e283e7b063da74ca3f3d15c51986854 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 12 Aug 2024 16:12:04 -0700 Subject: [PATCH 08/22] ruff --- pytato/reductions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index 999ef1af4..b0f2b7fb2 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -178,9 +178,9 @@ def _normalize_reduction_axes( raise ValueError(f"{axis} is out of bounds for array of dimension" f" {len(shape)}.") - new_shape = tuple([axis_len + new_shape = tuple(axis_len for i, axis_len in enumerate(shape) - if i not in reduction_axes]) + if i not in reduction_axes) return new_shape, reduction_axes From 076a76ebe152f8d82c47de31a05f25586b981e4f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 13 Aug 2024 13:51:12 -0500 Subject: [PATCH 09/22] replace orderedsets with unique tuples in DirectPredecessorsGetter --- pytato/analysis/__init__.py | 70 +++++++++++++++++-------------------- setup.py | 1 - 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fa8ac31e7..883030a43 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method +from pytools import memoize_method, unique from pytato.array import ( Array, @@ -314,11 +314,6 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -from collections.abc import Set as abc_Set - -from orderedsets import FrozenOrderedSet - - class DirectPredecessorsGetter(Mapper): """ Mapper to get the @@ -327,74 +322,75 @@ class DirectPredecessorsGetter(Mapper): of a node. .. note:: + We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([dim for dim in shape if isinstance(dim, Array)]) + def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames]: + return tuple(unique(dim for dim in shape if isinstance(dim, Array))) - def map_index_lambda(self, expr: IndexLambda) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.bindings.values()) - | self._get_preds_from_shape(expr.shape)) + def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames]: + return tuple(unique(tuple(expr.bindings.values()) + + self._get_preds_from_shape(expr.shape))) - def map_stack(self, expr: Stack) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.arrays) - | self._get_preds_from_shape(expr.shape)) + def map_stack(self, expr: Stack) -> tuple[ArrayOrNames]: + return tuple(unique(tuple(expr.arrays) + + self._get_preds_from_shape(expr.shape))) - def map_concatenate(self, expr: Concatenate) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.arrays) - | self._get_preds_from_shape(expr.shape)) + map_concatenate = map_stack - def map_einsum(self, expr: Einsum) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.args) - | self._get_preds_from_shape(expr.shape)) + def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames]: + return tuple(unique(tuple(expr.args) + + self._get_preds_from_shape(expr.shape))) - def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: + def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (FrozenOrderedSet(ary + return tuple(unique(tuple(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) - | self._get_preds_from_shape(expr.shape)) + + self._get_preds_from_shape(expr.shape))) - def _map_index_base(self, expr: IndexBase) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet([expr.array]) - | FrozenOrderedSet(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames]: + return tuple(unique((expr.array,) # noqa: RUF005 + + tuple(idx for idx in expr.indices if isinstance(idx, Array)) - | self._get_preds_from_shape(expr.shape)) + + self._get_preds_from_shape(expr.shape))) map_basic_index = _map_index_base map_contiguous_advanced_index = _map_index_base map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([expr.array]) + ) -> tuple[ArrayOrNames]: + return (expr.array,) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> abc_Set[ArrayOrNames]: + def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> abc_Set[ArrayOrNames]: + def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([expr.passthrough_data]) + ) -> tuple[ArrayOrNames]: + return (expr.passthrough_data,) + + def map_call(self, expr: Call) -> tuple[ArrayOrNames]: + return tuple(unique(expr.bindings.values())) - def map_call(self, expr: Call) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet(expr.bindings.values()) + def map_named_call_result( + self, expr: NamedCallResult) -> tuple[ArrayOrNames]: + return (expr._container,) - def map_named_call_result(self, expr: NamedCallResult) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([expr._container]) # }}} diff --git a/setup.py b/setup.py index 9fe0df6b1..ba0bd1b4d 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,6 @@ "immutabledict", "attrs", "bidict", - "orderedsets", ], package_data={"pytato": ["py.typed"]}, author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", From ea1462c0bd71ea8bf2a6d80e26e83fdbc255fc57 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Aug 2024 11:37:53 -0500 Subject: [PATCH 10/22] mypy fixes --- pytato/analysis/__init__.py | 18 +++++++++--------- pytato/distributed/partition.py | 26 +++++++++++++------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 883030a43..c568c8f9c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -325,24 +325,24 @@ class DirectPredecessorsGetter(Mapper): We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames]: + def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames, ...]: return tuple(unique(dim for dim in shape if isinstance(dim, Array))) - def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames]: + def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames, ...]: return tuple(unique(tuple(expr.bindings.values()) + self._get_preds_from_shape(expr.shape))) - def map_stack(self, expr: Stack) -> tuple[ArrayOrNames]: + def map_stack(self, expr: Stack) -> tuple[ArrayOrNames, ...]: return tuple(unique(tuple(expr.arrays) + self._get_preds_from_shape(expr.shape))) map_concatenate = map_stack - def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames]: + def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames, ...]: return tuple(unique(tuple(expr.args) + self._get_preds_from_shape(expr.shape))) - def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]: + def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames, ...]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) @@ -351,7 +351,7 @@ def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]: if isinstance(ary, Array)) + self._get_preds_from_shape(expr.shape))) - def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames]: + def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames, ...]: return tuple(unique((expr.array,) # noqa: RUF005 + tuple(idx for idx in expr.indices if isinstance(idx, Array)) @@ -369,14 +369,14 @@ def _map_index_remapping_base(self, expr: IndexRemappingBase map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames]: + def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames, ...]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames]: + def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames, ...]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, @@ -384,7 +384,7 @@ def map_distributed_send_ref_holder(self, ) -> tuple[ArrayOrNames]: return (expr.passthrough_data,) - def map_call(self, expr: Call) -> tuple[ArrayOrNames]: + def map_call(self, expr: Call) -> tuple[ArrayOrNames, ...]: return tuple(unique(expr.bindings.values())) def map_named_call_result( diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index e8f4b1fb2..68a924c8a 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -315,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. """ - recv_ids: tuple[CommunicationOpIdentifier] - send_ids: tuple[CommunicationOpIdentifier] + recv_ids: tuple[CommunicationOpIdentifier, ...] + send_ids: tuple[CommunicationOpIdentifier, ...] # {{{ _make_distributed_partition @@ -402,12 +402,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[tuple[CommunicationOpIdentifier]]): + CombineMapper[tuple[CommunicationOpIdentifier, ...]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - tuple[CommunicationOpIdentifier]] = {} + tuple[CommunicationOpIdentifier, ...]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -417,14 +417,14 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: tuple[CommunicationOpIdentifier] - ) -> tuple[CommunicationOpIdentifier]: + self, *args: tuple[CommunicationOpIdentifier, ...] + ) -> tuple[CommunicationOpIdentifier, ...]: from pytools import unique return reduce(lambda x, y: tuple(unique(x+y)), args, ()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> tuple[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier, ...]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -438,7 +438,7 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier]: + def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier, ...]: return () map_placeholder = _map_input_base @@ -447,7 +447,7 @@ def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> tuple[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier, ...]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: @@ -461,7 +461,7 @@ def map_distributed_recv( return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier, ...]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -594,8 +594,8 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, Sequence[_ValueT]], - dict_b: Mapping[_KeyT, Sequence[_ValueT]], + dict_a: Mapping[_KeyT, tuple[_ValueT, ...]], + dict_b: Mapping[_KeyT, tuple[_ValueT, ...]], mpi_data_type: mpi4py.MPI.Datatype | None) -> Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None from pytools import unique @@ -781,7 +781,7 @@ def find_distributed_partition( part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: tuple[CommunicationOpIdentifier] = () + recv_ids: tuple[CommunicationOpIdentifier, ...] = () for batch in comm_batches: send_ids = tuple( comm_id for comm_id in unique(batch) From 168ef532057be4e81649acb9353510a28ed62d84 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Aug 2024 11:48:48 -0500 Subject: [PATCH 11/22] remove unnecesary cast --- pytato/distributed/partition.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 68a924c8a..38f26c8d5 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -771,9 +771,7 @@ def find_distributed_partition( if isinstance(comm_batches_or_exc, Exception): raise comm_batches_or_exc - comm_batches = cast( - Sequence[list[CommunicationOpIdentifier]], - comm_batches_or_exc) + comm_batches = comm_batches_or_exc # }}} From d711989c8cc0198cafdbbd94d294f384667b7c25 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Aug 2024 14:43:43 -0500 Subject: [PATCH 12/22] adjust comment --- pytato/distributed/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 38f26c8d5..dea81e925 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -829,7 +829,7 @@ def find_distributed_partition( materialized_arrays_collector = _MaterializedArrayCollector() materialized_arrays_collector(outputs) - # The sets of arrays below must have a deterministic order in order to ensure + # The collections of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic sent_arrays = tuple(unique( From 58478004b85604d9ac0b238ab55d601326ef3b94 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Sep 2024 13:01:19 -0500 Subject: [PATCH 13/22] performance fix --- pytato/distributed/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index dea81e925..03691e6ae 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -928,7 +928,7 @@ def find_distributed_partition( def get_materialized_predecessors(ary: Array) -> tuple[Array, ...]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): - if pred in materialized_arrays: + if pred in materialized_arrays_set: materialized_preds[pred] = None else: for p in get_materialized_predecessors(pred): From 7dd83bb7ab25e701f3a9b5d05477d991d0fbb1fa Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Sep 2024 14:22:21 -0500 Subject: [PATCH 14/22] switch to dicts --- pytato/analysis/__init__.py | 62 ++++++++++++++++---------------- pytato/distributed/partition.py | 63 +++++++++++++++------------------ 2 files changed, 61 insertions(+), 64 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index c568c8f9c..e721a7a8a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method, unique +from pytools import memoize_method from pytato.array import ( Array, @@ -325,71 +325,73 @@ class DirectPredecessorsGetter(Mapper): We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames, ...]: - return tuple(unique(dim for dim in shape if isinstance(dim, Array))) + def _get_preds_from_shape(self, shape: ShapeType) -> dict[Array, None]: + return dict.fromkeys(dim for dim in shape if isinstance(dim, Array)) - def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames, ...]: - return tuple(unique(tuple(expr.bindings.values()) - + self._get_preds_from_shape(expr.shape))) + def map_index_lambda(self, expr: IndexLambda) -> dict[Array, None]: + return (dict.fromkeys(expr.bindings.values()) + | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> tuple[ArrayOrNames, ...]: - return tuple(unique(tuple(expr.arrays) - + self._get_preds_from_shape(expr.shape))) + def map_stack(self, expr: Stack) -> dict[Array, None]: + return (dict.fromkeys(expr.arrays) + | self._get_preds_from_shape(expr.shape)) - map_concatenate = map_stack + def map_concatenate(self, expr: Concatenate) -> dict[Array, None]: + return (dict.fromkeys(expr.arrays) + | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames, ...]: - return tuple(unique(tuple(expr.args) - + self._get_preds_from_shape(expr.shape))) + def map_einsum(self, expr: Einsum) -> dict[Array, None]: + return (dict.fromkeys(expr.args) + | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames, ...]: + def map_loopy_call_result(self, expr: NamedArray) -> dict[Array, None]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return tuple(unique(tuple(ary + return (dict.fromkeys(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) - + self._get_preds_from_shape(expr.shape))) + | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames, ...]: - return tuple(unique((expr.array,) # noqa: RUF005 - + tuple(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> dict[Array, None]: + return (dict.fromkeys([expr.array]) + | dict.fromkeys(idx for idx in expr.indices if isinstance(idx, Array)) - + self._get_preds_from_shape(expr.shape))) + | self._get_preds_from_shape(expr.shape)) map_basic_index = _map_index_base map_contiguous_advanced_index = _map_index_base map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> tuple[ArrayOrNames]: - return (expr.array,) + ) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames, ...]: + def _map_input_base(self, expr: InputArgumentBase) -> dict[Array, None]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames, ...]: + def map_distributed_recv(self, expr: DistributedRecv) -> dict[Array, None]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> tuple[ArrayOrNames]: - return (expr.passthrough_data,) + ) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr.passthrough_data]) - def map_call(self, expr: Call) -> tuple[ArrayOrNames, ...]: - return tuple(unique(expr.bindings.values())) + def map_call(self, expr: Call) -> dict[ArrayOrNames, None]: + return dict.fromkeys(expr.bindings.values()) def map_named_call_result( - self, expr: NamedCallResult) -> tuple[ArrayOrNames]: - return (expr._container,) + self, expr: NamedCallResult) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr._container]) # }}} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 03691e6ae..9e9f47913 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -315,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. """ - recv_ids: tuple[CommunicationOpIdentifier, ...] - send_ids: tuple[CommunicationOpIdentifier, ...] + recv_ids: immutabledict[CommunicationOpIdentifier, None] + send_ids: immutabledict[CommunicationOpIdentifier, None] # {{{ _make_distributed_partition @@ -727,8 +727,7 @@ def find_distributed_partition( assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. """ import mpi4py.MPI as MPI - - from pytools import unique + from immutabledict import immutabledict from pytato.transform import SubsetDependencyMapper @@ -779,30 +778,31 @@ def find_distributed_partition( part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: tuple[CommunicationOpIdentifier, ...] = () + recv_ids: immutabledict[CommunicationOpIdentifier, None] = immutabledict() for batch in comm_batches: - send_ids = tuple( - comm_id for comm_id in unique(batch) - if comm_id.src_rank == local_rank) + send_ids: immutabledict[CommunicationOpIdentifier, None] \ + = immutabledict.fromkeys( + comm_id for comm_id in batch + if comm_id.src_rank == local_rank) if recv_ids or send_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = tuple( - comm_id for comm_id in unique(batch) + recv_ids = immutabledict.fromkeys( + comm_id for comm_id in batch if comm_id.dest_rank == local_rank) if recv_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=())) + send_ids=immutabledict())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=(), - send_ids=())) + recv_ids=immutabledict(), + send_ids=immutabledict())) nparts = len(part_comm_ids) @@ -820,7 +820,7 @@ def find_distributed_partition( comm_id_to_part_id = { comm_id: ipart for ipart, comm_ids in enumerate(part_comm_ids) - for comm_id in unique(comm_ids.send_ids + comm_ids.recv_ids)} + for comm_id in comm_ids.send_ids | comm_ids.recv_ids} # }}} @@ -832,10 +832,10 @@ def find_distributed_partition( # The collections of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = tuple(unique( - send_node.data for send_node in lsrdg.local_send_id_to_send_node.values())) + sent_arrays = dict.fromkeys( + send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = tuple(unique(lsrdg.local_recv_id_to_recv_node.values())) + received_arrays = dict.fromkeys(lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -843,18 +843,13 @@ def find_distributed_partition( # We could allow sent *arrays* to be included here because they are distinct # from send *nodes*, but we choose to exclude them in order to simplify the # processing below. - materialized_arrays_set = set(materialized_arrays_collector.materialized_arrays) \ - - set(received_arrays) \ - - set(sent_arrays) - - from pytools import unique - materialized_arrays = tuple(unique( - a for a in materialized_arrays_collector.materialized_arrays - if a in materialized_arrays_set)) + materialized_arrays = {a: None + for a in materialized_arrays_collector.materialized_arrays + if a not in received_arrays | sent_arrays} # "mso" for "materialized/sent/output" - output_arrays = tuple(unique(outputs._data.values())) - mso_arrays = tuple(unique(materialized_arrays + sent_arrays + output_arrays)) + output_arrays = dict.fromkeys(outputs._data.values()) + mso_arrays = materialized_arrays | sent_arrays | output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. @@ -918,30 +913,30 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = tuple(unique(stored_ary_to_part_id)) + stored_arrays = dict.fromkeys(stored_ary_to_part_id) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> tuple[Array, ...]: + def get_materialized_predecessors(ary: Array) -> dict[Array, None]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): - if pred in materialized_arrays_set: + if pred in materialized_arrays: materialized_preds[pred] = None else: for p in get_materialized_predecessors(pred): materialized_preds[p] = None - return tuple(materialized_preds.keys()) + return materialized_preds - stored_arrays_promoted_to_part_outputs = tuple(unique( - stored_pred + stored_arrays_promoted_to_part_outputs = { + stored_pred: None for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - )) + } # }}} From 1ea962cc9460b6407e453dd5bebfc28cc02830ae Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Sep 2024 14:40:51 -0500 Subject: [PATCH 15/22] more dict usage --- pytato/distributed/partition.py | 43 ++++++++++++++++----------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 9e9f47913..111f07d2e 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -402,12 +402,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[tuple[CommunicationOpIdentifier, ...]]): + CombineMapper[dict[CommunicationOpIdentifier, None]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - tuple[CommunicationOpIdentifier, ...]] = {} + dict[CommunicationOpIdentifier, None]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -417,14 +417,13 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: tuple[CommunicationOpIdentifier, ...] - ) -> tuple[CommunicationOpIdentifier, ...]: - from pytools import unique - return reduce(lambda x, y: tuple(unique(x+y)), args, ()) + self, *args: dict[CommunicationOpIdentifier, None] + ) -> dict[CommunicationOpIdentifier, None]: + return reduce(lambda x, y: x | y, args, {}) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> tuple[CommunicationOpIdentifier, ...]: + ) -> dict[CommunicationOpIdentifier, None]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -438,8 +437,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier, ...]: - return () + def _map_input_base(self, expr: Array) -> dict[CommunicationOpIdentifier, None]: + return {} map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -447,21 +446,21 @@ def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier, ...]: def map_distributed_recv( self, expr: DistributedRecv - ) -> tuple[CommunicationOpIdentifier, ...]: + ) -> dict[CommunicationOpIdentifier, None]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = () + self.local_comm_ids_to_needed_comm_ids[recv_id] = {} self.local_recv_id_to_recv_node[recv_id] = expr - return (recv_id,) + return {recv_id: None} def map_named_call_result( - self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier, ...]: + self, expr: NamedCallResult) -> dict[CommunicationOpIdentifier, None]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -475,7 +474,7 @@ def map_named_call_result( def _schedule_task_batches( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> Sequence[list[TaskType]]: + -> Sequence[dict[TaskType, None]]: """For each :type:`TaskType`, determine the 'round'/'batch' during which it will be performed. A 'batch' of tasks consists of tasks which do not depend on each other. @@ -490,7 +489,7 @@ def _schedule_task_batches( def _schedule_task_batches_counted( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> tuple[Sequence[list[TaskType]], int]: + -> tuple[Sequence[dict[TaskType, None]], int]: """ Static type checkers need the functions to return the same type regardless of the input. The testing code needs to know about the number of tasks visited @@ -499,11 +498,11 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[list[TaskType]] = [[] for _ in range(nlevels)] + task_batches: Sequence[dict[TaskType, None]] = [{} for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): if task_id not in task_batches[dep_level]: - task_batches[dep_level].append(task_id) + task_batches[dep_level][task_id] = None return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -594,14 +593,14 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, tuple[_ValueT, ...]], - dict_b: Mapping[_KeyT, tuple[_ValueT, ...]], - mpi_data_type: mpi4py.MPI.Datatype | None) -> Mapping[_KeyT, Sequence[_ValueT]]: + dict_a: Mapping[_KeyT, dict[_ValueT, None]], + dict_b: Mapping[_KeyT, dict[_ValueT, None]], + mpi_data_type: mpi4py.MPI.Datatype | None) \ + -> Mapping[_KeyT, dict[_ValueT, None]]: assert mpi_data_type is None - from pytools import unique result = dict(dict_a) for key, values in dict_b.items(): - result[key] = tuple(unique(result.get(key, ()) + values)) + result[key] = result.get(key, {}) | values return result # }}} From e8b5806732bdbf3f415f53df71a20b109cda10cd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 14 Oct 2024 11:46:22 -0500 Subject: [PATCH 16/22] fix materialized_arrays perf --- pytato/distributed/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 111f07d2e..cf75e6ed7 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -844,7 +844,7 @@ def find_distributed_partition( # processing below. materialized_arrays = {a: None for a in materialized_arrays_collector.materialized_arrays - if a not in received_arrays | sent_arrays} + if a not in received_arrays and a not in sent_arrays} # "mso" for "materialized/sent/output" output_arrays = dict.fromkeys(outputs._data.values()) From 94658af13b967796dd8759acd4a26b1d21b8a20f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 14 Nov 2024 10:28:23 -0600 Subject: [PATCH 17/22] Fix imports --- pytato/distributed/partition.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 31e7252f1..f27a5e10e 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -62,22 +62,11 @@ THE SOFTWARE. """ -import collections -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set +from collections.abc import Hashable, Mapping, Sequence, Set from functools import reduce from typing import ( TYPE_CHECKING, Any, - Hashable, - Mapping, - Sequence, - FrozenSet, - Hashable, - Iterable, - Iterator, - Mapping, - Sequence, - Generic, TypeVar, cast, ) From 3a30e88f0003505bce82f32bb772150d823427a9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 19 Nov 2024 17:48:10 -0600 Subject: [PATCH 18/22] ruff --- pytato/distributed/partition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 2924e5213..733b84ad4 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -62,9 +62,8 @@ THE SOFTWARE. """ -import collections import dataclasses -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set +from collections.abc import Hashable, Mapping, Sequence, Set from functools import reduce from typing import ( TYPE_CHECKING, From 3530b53a68c6dbb01aa5588659d8ea0cfa09a5a9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Dec 2024 14:32:14 -0600 Subject: [PATCH 19/22] use operator.or_ in reduction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andreas Klöckner --- pytato/distributed/partition.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 733b84ad4..3209f62c0 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -414,7 +414,8 @@ def __init__(self, local_rank: int) -> None: def combine( self, *args: dict[CommunicationOpIdentifier, None] ) -> dict[CommunicationOpIdentifier, None]: - return reduce(lambda x, y: x | y, args, {}) + import operator + return reduce(operator.or_, args, {}) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder From fbcbcef7a718690ea6ac118b0d0dd1ac2a5c9fb4 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Dec 2024 14:59:19 -0600 Subject: [PATCH 20/22] use a FakeOrderedFrozenSet type --- pytato/analysis/__init__.py | 60 +++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index b5c8cbb0f..fb1f78164 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -27,7 +27,9 @@ """ from collections.abc import Mapping -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar + +from immutabledict import immutabledict from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper @@ -73,8 +75,13 @@ """ +T = TypeVar("T") + +FakeOrderedFrozenSet: TypeAlias = immutabledict[T, None] + # {{{ NUserCollector + class NUserCollector(Mapper[None, []]): """ A :class:`pytato.transform.CachedWalkMapper` that records the number of @@ -327,37 +334,37 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]): We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> dict[Array, None]: - return dict.fromkeys(dim for dim in shape if isinstance(dim, Array)) + def _get_preds_from_shape(self, shape: ShapeType) -> FakeOrderedFrozenSet[Array]: + return immutabledict.fromkeys(dim for dim in shape if isinstance(dim, Array)) - def map_index_lambda(self, expr: IndexLambda) -> dict[Array, None]: - return (dict.fromkeys(expr.bindings.values()) + def map_index_lambda(self, expr: IndexLambda) -> FakeOrderedFrozenSet[Array]: + return (immutabledict.fromkeys(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> dict[Array, None]: - return (dict.fromkeys(expr.arrays) + def map_stack(self, expr: Stack) -> FakeOrderedFrozenSet[Array]: + return (immutabledict.fromkeys(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_concatenate(self, expr: Concatenate) -> dict[Array, None]: - return (dict.fromkeys(expr.arrays) + def map_concatenate(self, expr: Concatenate) -> FakeOrderedFrozenSet[Array]: + return (immutabledict.fromkeys(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> dict[Array, None]: - return (dict.fromkeys(expr.args) + def map_einsum(self, expr: Einsum) -> FakeOrderedFrozenSet[Array]: + return (immutabledict.fromkeys(expr.args) | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> dict[Array, None]: + def map_loopy_call_result(self, expr: NamedArray) -> FakeOrderedFrozenSet[Array]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (dict.fromkeys(ary + return (immutabledict.fromkeys(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> dict[Array, None]: - return (dict.fromkeys([expr.array]) - | dict.fromkeys(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> FakeOrderedFrozenSet[Array]: + return (immutabledict.fromkeys([expr.array]) + | immutabledict.fromkeys(idx for idx in expr.indices if isinstance(idx, Array)) | self._get_preds_from_shape(expr.shape)) @@ -366,34 +373,35 @@ def _map_index_base(self, expr: IndexBase) -> dict[Array, None]: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> dict[ArrayOrNames, None]: - return dict.fromkeys([expr.array]) + ) -> FakeOrderedFrozenSet[ArrayOrNames]: + return immutabledict.fromkeys([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> dict[Array, None]: + def _map_input_base(self, expr: InputArgumentBase) -> FakeOrderedFrozenSet[Array]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> dict[Array, None]: + def map_distributed_recv(self, + expr: DistributedRecv) -> FakeOrderedFrozenSet[Array]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> dict[ArrayOrNames, None]: - return dict.fromkeys([expr.passthrough_data]) + ) -> FakeOrderedFrozenSet[ArrayOrNames]: + return immutabledict.fromkeys([expr.passthrough_data]) - def map_call(self, expr: Call) -> dict[ArrayOrNames, None]: - return dict.fromkeys(expr.bindings.values()) + def map_call(self, expr: Call) -> FakeOrderedFrozenSet[ArrayOrNames]: + return immutabledict.fromkeys(expr.bindings.values()) def map_named_call_result( - self, expr: NamedCallResult) -> dict[ArrayOrNames, None]: - return dict.fromkeys([expr._container]) + self, expr: NamedCallResult) -> FakeOrderedFrozenSet[ArrayOrNames]: + return immutabledict.fromkeys([expr._container]) # }}} From bc87f3a185bcbf71f377317731ab70b95ed5a594 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Dec 2024 16:31:47 -0600 Subject: [PATCH 21/22] extend FakeOrderedSet typing --- pytato/analysis/__init__.py | 1 + pytato/distributed/partition.py | 74 +++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fb1f78164..1602d8f7a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -78,6 +78,7 @@ T = TypeVar("T") FakeOrderedFrozenSet: TypeAlias = immutabledict[T, None] +FakeOrderedSet: TypeAlias = dict[T, None] # {{{ NUserCollector diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 3209f62c0..0f9e68789 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -78,7 +78,11 @@ from pytools import UniqueNameGenerator, memoize_method from pytools.graph import CycleError -from pytato.analysis import DirectPredecessorsGetter +from pytato.analysis import ( + DirectPredecessorsGetter, + FakeOrderedFrozenSet, + FakeOrderedSet, +) from pytato.array import Array, DictOfNamedArrays, Placeholder, make_placeholder from pytato.distributed.nodes import ( CommTagType, @@ -311,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. """ - recv_ids: immutabledict[CommunicationOpIdentifier, None] - send_ids: immutabledict[CommunicationOpIdentifier, None] + recv_ids: FakeOrderedFrozenSet[CommunicationOpIdentifier] + send_ids: FakeOrderedFrozenSet[CommunicationOpIdentifier] # {{{ _make_distributed_partition @@ -397,12 +401,13 @@ def _recv_to_comm_id( comm_tag=recv.comm_tag) -class _LocalSendRecvDepGatherer(CombineMapper[dict[CommunicationOpIdentifier, None]]): +class _LocalSendRecvDepGatherer(CombineMapper[ + FakeOrderedFrozenSet[CommunicationOpIdentifier]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - dict[CommunicationOpIdentifier, None]] = {} + FakeOrderedFrozenSet[CommunicationOpIdentifier]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -412,14 +417,13 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: dict[CommunicationOpIdentifier, None] - ) -> dict[CommunicationOpIdentifier, None]: + self, *args: FakeOrderedFrozenSet[CommunicationOpIdentifier] + ) -> FakeOrderedFrozenSet[CommunicationOpIdentifier]: import operator - return reduce(operator.or_, args, {}) + return reduce(operator.or_, args, immutabledict()) - def map_distributed_send_ref_holder(self, - expr: DistributedSendRefHolder - ) -> dict[CommunicationOpIdentifier, None]: + def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder) \ + -> FakeOrderedFrozenSet[CommunicationOpIdentifier]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -433,8 +437,9 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> dict[CommunicationOpIdentifier, None]: - return {} + def _map_input_base(self, expr: Array) \ + -> FakeOrderedFrozenSet[CommunicationOpIdentifier]: + return immutabledict() map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -442,21 +447,22 @@ def _map_input_base(self, expr: Array) -> dict[CommunicationOpIdentifier, None]: def map_distributed_recv( self, expr: DistributedRecv - ) -> dict[CommunicationOpIdentifier, None]: + ) -> FakeOrderedFrozenSet[CommunicationOpIdentifier]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = {} + self.local_comm_ids_to_needed_comm_ids[recv_id] = immutabledict() self.local_recv_id_to_recv_node[recv_id] = expr - return {recv_id: None} + return immutabledict({recv_id: None}) def map_named_call_result( - self, expr: NamedCallResult) -> dict[CommunicationOpIdentifier, None]: + self, expr: NamedCallResult) \ + -> FakeOrderedFrozenSet[CommunicationOpIdentifier]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -470,7 +476,7 @@ def map_named_call_result( def _schedule_task_batches( task_ids_to_needed_task_ids: Mapping[TaskType, Set[TaskType]] - ) -> Sequence[dict[TaskType, None]]: + ) -> Sequence[FakeOrderedSet[TaskType]]: """For each :type:`TaskType`, determine the 'round'/'batch' during which it will be performed. A 'batch' of tasks consists of tasks which do not depend on each other. @@ -485,7 +491,7 @@ def _schedule_task_batches( def _schedule_task_batches_counted( task_ids_to_needed_task_ids: Mapping[TaskType, Set[TaskType]]) \ - -> tuple[Sequence[dict[TaskType, None]], int]: + -> tuple[Sequence[FakeOrderedSet[TaskType]], int]: """ Static type checkers need the functions to return the same type regardless of the input. The testing code needs to know about the number of tasks visited @@ -494,7 +500,7 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[dict[TaskType, None]] = [{} for _ in range(nlevels)] + task_batches: Sequence[FakeOrderedSet[TaskType]] = [{} for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): if task_id not in task_batches[dep_level]: @@ -561,7 +567,7 @@ class _MaterializedArrayCollector(CachedWalkMapper[[]]): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: dict[Array, None] = {} + self.materialized_arrays: FakeOrderedSet[Array] = {} def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -592,7 +598,7 @@ def _set_dict_union_mpi( dict_a: Mapping[_KeyT, dict[_ValueT, None]], dict_b: Mapping[_KeyT, dict[_ValueT, None]], mpi_data_type: mpi4py.MPI.Datatype | None) \ - -> Mapping[_KeyT, dict[_ValueT, None]]: + -> Mapping[_KeyT, FakeOrderedSet[_ValueT]]: assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): @@ -773,9 +779,9 @@ def find_distributed_partition( part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: immutabledict[CommunicationOpIdentifier, None] = immutabledict() + recv_ids: FakeOrderedFrozenSet[CommunicationOpIdentifier] = immutabledict() for batch in comm_batches: - send_ids: immutabledict[CommunicationOpIdentifier, None] \ + send_ids: FakeOrderedFrozenSet[CommunicationOpIdentifier] \ = immutabledict.fromkeys( comm_id for comm_id in batch if comm_id.src_rank == local_rank) @@ -827,10 +833,11 @@ def find_distributed_partition( # The collections of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = dict.fromkeys( + sent_arrays: FakeOrderedFrozenSet[Array] = immutabledict.fromkeys( send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = dict.fromkeys(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays: FakeOrderedFrozenSet[Array] = immutabledict.fromkeys( + lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -843,7 +850,8 @@ def find_distributed_partition( if a not in received_arrays and a not in sent_arrays} # "mso" for "materialized/sent/output" - output_arrays = dict.fromkeys(outputs._data.values()) + output_arrays: FakeOrderedFrozenSet[Array] = immutabledict.fromkeys( + outputs._data.values()) mso_arrays = materialized_arrays | sent_arrays | output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to @@ -893,7 +901,7 @@ def find_distributed_partition( recvd_ary_to_part_id: dict[Array, int] = { recvd_ary: ( comm_id_to_part_id[ - _recv_to_comm_id(local_rank, recvd_ary)]) + _recv_to_comm_id(local_rank, cast(DistributedRecv, recvd_ary))]) for recvd_ary in received_arrays} # "Materialized" arrays are arrays that are tagged with ImplStored, @@ -908,14 +916,15 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = dict.fromkeys(stored_ary_to_part_id) + stored_arrays: FakeOrderedFrozenSet[Array] = immutabledict.fromkeys( + stored_ary_to_part_id) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> dict[Array, None]: + def get_materialized_predecessors(ary: Array) -> FakeOrderedSet[Array]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): assert isinstance(pred, Array) @@ -926,13 +935,14 @@ def get_materialized_predecessors(ary: Array) -> dict[Array, None]: materialized_preds[p] = None return materialized_preds - stored_arrays_promoted_to_part_outputs = { + stored_arrays_promoted_to_part_outputs: FakeOrderedFrozenSet[Array] \ + = immutabledict({ stored_pred: None for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - } + }) # }}} From 4424a2e4efbfef5480882a5c55a11b6a908f4fba Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 13 Dec 2024 14:17:18 -0600 Subject: [PATCH 22/22] ruff --- pytato/distributed/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index a66da244f..ede9c7d9f 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -902,7 +902,7 @@ def find_distributed_partition( recvd_ary_to_part_id: dict[Array, int] = { recvd_ary: ( comm_id_to_part_id[ - _recv_to_comm_id(local_rank, cast(DistributedRecv, recvd_ary))]) + _recv_to_comm_id(local_rank, cast("DistributedRecv", recvd_ary))]) for recvd_ary in received_arrays} # "Materialized" arrays are arrays that are tagged with ImplStored,