Commit 3364a4f
working pass 1
matthiasdiener committed Jul 25, 2024
1 parent 550920a commit 3364a4f
Showing 3 changed files with 74 additions and 67 deletions.
50 changes: 25 additions & 25 deletions pytato/analysis/__init__.py
@@ -310,48 +310,48 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
 
 # {{{ DirectPredecessorsGetter
 
+from orderedsets import FrozenOrderedSet
+
 class DirectPredecessorsGetter(Mapper):
     """
     Mapper to get the
     `direct predecessors
     <https://en.wikipedia.org/wiki/Glossary_of_graph_theory#direct_predecessor>`__
     of a node.
 
     .. note::
 
         We only consider the predecessors of a nodes in a data-flow sense.
     """
-    def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]:
-        return frozenset({dim for dim in shape if isinstance(dim, Array)})
+    def _get_preds_from_shape(self, shape: ShapeType) -> abc_Set[Array]:
+        return FrozenOrderedSet([dim for dim in shape if isinstance(dim, Array)])
 
-    def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]:
-        return (frozenset(expr.bindings.values())
+    def map_index_lambda(self, expr: IndexLambda) -> abc_Set[Array]:
+        return (FrozenOrderedSet(expr.bindings.values())
                 | self._get_preds_from_shape(expr.shape))
 
-    def map_stack(self, expr: Stack) -> frozenset[Array]:
-        return (frozenset(expr.arrays)
+    def map_stack(self, expr: Stack) -> abc_Set[Array]:
+        return (FrozenOrderedSet(expr.arrays)
                 | self._get_preds_from_shape(expr.shape))
 
-    def map_concatenate(self, expr: Concatenate) -> frozenset[Array]:
-        return (frozenset(expr.arrays)
+    def map_concatenate(self, expr: Concatenate) -> abc_Set[Array]:
+        return (FrozenOrderedSet(expr.arrays)
                 | self._get_preds_from_shape(expr.shape))
 
-    def map_einsum(self, expr: Einsum) -> frozenset[Array]:
-        return (frozenset(expr.args)
+    def map_einsum(self, expr: Einsum) -> abc_Set[Array]:
+        return (FrozenOrderedSet(expr.args)
                 | self._get_preds_from_shape(expr.shape))
 
-    def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]:
-        from pytato.loopy import LoopyCall, LoopyCallResult
+    def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]:
+        from pytato.loopy import LoopyCallResult, LoopyCall
         assert isinstance(expr, LoopyCallResult)
         assert isinstance(expr._container, LoopyCall)
-        return (frozenset(ary
+        return (FrozenOrderedSet(ary
                 for ary in expr._container.bindings.values()
                 if isinstance(ary, Array))
                 | self._get_preds_from_shape(expr.shape))
 
-    def _map_index_base(self, expr: IndexBase) -> frozenset[Array]:
-        return (frozenset([expr.array])
-                | frozenset(idx for idx in expr.indices
+    def _map_index_base(self, expr: IndexBase) -> abc_Set[Array]:
+        return (FrozenOrderedSet([expr.array])
+                | FrozenOrderedSet(idx for idx in expr.indices
                             if isinstance(idx, Array))
                 | self._get_preds_from_shape(expr.shape))
 
@@ -360,29 +360,29 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[Array]:
     map_non_contiguous_advanced_index = _map_index_base
 
     def _map_index_remapping_base(self, expr: IndexRemappingBase
-                                  ) -> frozenset[Array]:
-        return frozenset([expr.array])
+                                  ) -> abc_Set[Array]:
+        return FrozenOrderedSet([expr.array])
 
     map_roll = _map_index_remapping_base
     map_axis_permutation = _map_index_remapping_base
     map_reshape = _map_index_remapping_base
 
-    def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]:
+    def _map_input_base(self, expr: InputArgumentBase) -> abc_Set[Array]:
        return self._get_preds_from_shape(expr.shape)
 
     map_placeholder = _map_input_base
     map_data_wrapper = _map_input_base
     map_size_param = _map_input_base
 
-    def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]:
+    def map_distributed_recv(self, expr: DistributedRecv) -> abc_Set[Array]:
        return self._get_preds_from_shape(expr.shape)
 
     def map_distributed_send_ref_holder(self,
                                         expr: DistributedSendRefHolder
-                                        ) -> frozenset[Array]:
-        return frozenset([expr.passthrough_data])
+                                        ) -> abc_Set[Array]:
+        return FrozenOrderedSet([expr.passthrough_data])
 
-    def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]:
+    def map_named_call_result(self, expr: NamedCallResult) -> abc_Set[Array]:
        raise NotImplementedError(
            "DirectPredecessorsGetter does not yet support expressions containing "
            "functions.")
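The FrozenOrderedSet used above comes from the third-party orderedsets package and acts as a drop-in for frozenset whose iteration follows insertion order, which is what makes the predecessor collections deterministic. A minimal sketch of the behavior this relies on, assuming orderedsets is installed:

    from orderedsets import FrozenOrderedSet

    a = FrozenOrderedSet(["x", "z", "y"])
    b = FrozenOrderedSet(["y", "w"])

    # Set semantics (hashable, deduplicated, supports |), but iteration keeps
    # insertion order, so repeated runs produce the same sequence.
    print(list(a | b))                       # expected ['x', 'z', 'y', 'w']
    print(list(frozenset(["x", "z", "y"])))  # order not guaranteed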
86 changes: 46 additions & 40 deletions pytato/distributed/partition.py
@@ -476,9 +476,10 @@ def __init__(self, local_rank: int) -> None:
         self.local_rank = local_rank
 
     def combine(
-            self, *args: frozenset[CommunicationOpIdentifier]
-            ) -> frozenset[CommunicationOpIdentifier]:
-        return reduce(frozenset.union, args, frozenset())
+            self, *args: Tuple[CommunicationOpIdentifier]
+            ) -> Tuple[CommunicationOpIdentifier]:
+        from pytools import unique
+        return reduce(lambda x, y: tuple(unique(x+y)), args, tuple())
 
     def map_distributed_send_ref_holder(self,
                                         expr: DistributedSendRefHolder
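For reference, the reworked combine above replaces frozenset.union with tuple concatenation followed by order-preserving deduplication. A self-contained sketch of that pattern, using dict.fromkeys as a stand-in for pytools.unique (which likewise keeps first-seen order):

    from functools import reduce


    def ordered_union(*args: tuple[str, ...]) -> tuple[str, ...]:
        # Concatenate, then deduplicate while preserving first-seen order.
        # dict.fromkeys keeps insertion order (Python 3.7+), so the result is
        # deterministic across runs, unlike iterating a set.
        return reduce(lambda x, y: tuple(dict.fromkeys(x + y)), args, ())


    print(ordered_union(("a", "b"), ("b", "c"), ("a", "d")))  # ('a', 'b', 'c', 'd')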
@@ -496,30 +497,30 @@ def map_distributed_send_ref_holder(self,
 
         return self.rec(expr.passthrough_data)
 
-    def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]:
-        return frozenset()
+    def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]:
+        return tuple()
 
     map_placeholder = _map_input_base
     map_data_wrapper = _map_input_base
     map_size_param = _map_input_base
 
     def map_distributed_recv(
             self, expr: DistributedRecv
-            ) -> frozenset[CommunicationOpIdentifier]:
+            ) -> Tuple[CommunicationOpIdentifier]:
         recv_id = _recv_to_comm_id(self.local_rank, expr)
 
         if recv_id in self.local_recv_id_to_recv_node:
             from pytato.distributed.verify import DuplicateRecvError
             raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'")
 
-        self.local_comm_ids_to_needed_comm_ids[recv_id] = frozenset()
+        self.local_comm_ids_to_needed_comm_ids[recv_id] = tuple()
 
         self.local_recv_id_to_recv_node[recv_id] = expr
 
-        return frozenset({recv_id})
+        return (recv_id,)
 
     def map_named_call_result(
-            self, expr: NamedCallResult) -> frozenset[CommunicationOpIdentifier]:
+            self, expr: NamedCallResult) -> Tuple[CommunicationOpIdentifier]:
         raise NotImplementedError(
             "LocalSendRecvDepGatherer does not support functions.")

@@ -557,10 +558,10 @@ def _schedule_task_batches_counted(
     task_to_dep_level, visits_in_depend = \
         _calculate_dependency_levels(task_ids_to_needed_task_ids)
     nlevels = 1 + max(task_to_dep_level.values(), default=-1)
-    task_batches: Sequence[set[TaskType]] = [set() for _ in range(nlevels)]
+    task_batches: Sequence[List[TaskType]] = [list() for _ in range(nlevels)]
 
     for task_id, dep_level in task_to_dep_level.items():
-        task_batches[dep_level].add(task_id)
+        task_batches[dep_level].append(task_id)
 
     return task_batches, visits_in_depend + len(task_to_dep_level.keys())
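To see the effect of switching the batches from sets to lists: list.append keeps tasks in the order their dependency level was assigned, so iterating a batch is reproducible across runs and ranks. A small sketch with made-up task ids (illustration only, not the real TaskType):

    # Hypothetical task ids and dependency levels, for illustration only.
    task_to_dep_level = {"recv_x": 0, "compute_y": 1, "send_y": 1, "compute_z": 2}

    nlevels = 1 + max(task_to_dep_level.values(), default=-1)
    task_batches = [[] for _ in range(nlevels)]

    for task_id, dep_level in task_to_dep_level.items():
        # append() preserves insertion order; set.add() would not guarantee
        # a stable iteration order across runs.
        task_batches[dep_level].append(task_id)

    print(task_batches)  # [['recv_x'], ['compute_y', 'send_y'], ['compute_z']]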

@@ -623,7 +624,7 @@ class _MaterializedArrayCollector(CachedWalkMapper):
     """
    def __init__(self) -> None:
        super().__init__()
-        self.materialized_arrays: _OrderedSet[Array] = _OrderedSet()
+        self.materialized_arrays: List[Array] = []
 
    def get_cache_key(self, expr: ArrayOrNames) -> int:
        return id(expr)
@@ -633,15 +634,15 @@ def post_visit(self, expr: Any) -> None:
        from pytato.tags import ImplStored
 
        if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)):
-            self.materialized_arrays.add(expr)
+            self.materialized_arrays.append(expr)
 
        if isinstance(expr, LoopyCallResult):
-            self.materialized_arrays.add(expr)
+            self.materialized_arrays.append(expr)
            from pytato.loopy import LoopyCall
            assert isinstance(expr._container, LoopyCall)
            for _, subexpr in sorted(expr._container.bindings.items()):
                if isinstance(subexpr, Array):
-                    self.materialized_arrays.add(subexpr)
+                    self.materialized_arrays.append(subexpr)
                else:
                    assert isinstance(subexpr, SCALAR_CLASSES)

@@ -651,13 +652,13 @@ def post_visit(self, expr: Any) -> None:
 # {{{ _set_dict_union_mpi
 
 def _set_dict_union_mpi(
-        dict_a: Mapping[_KeyT, frozenset[_ValueT]],
-        dict_b: Mapping[_KeyT, frozenset[_ValueT]],
-        mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, frozenset[_ValueT]]:
+        dict_a: Mapping[_KeyT, Sequence[_ValueT]],
+        dict_b: Mapping[_KeyT, Sequence[_ValueT]],
+        mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, Sequence[_ValueT]]:
     assert mpi_data_type is None
     result = dict(dict_a)
     for key, values in dict_b.items():
-        result[key] = result.get(key, frozenset()) | values
+        result[key] = result.get(key, tuple()) + values
     return result
 
 # }}}
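A small usage sketch of the merge performed by _set_dict_union_mpi after this change: values are now tuples, so merging concatenates them in a fixed order instead of taking a set union with unspecified iteration order. The MPI datatype argument, which the real function asserts to be None, is omitted here, and the function name is made up for the sketch:

    def set_dict_union(dict_a, dict_b):
        # Merge two mappings of key -> tuple, concatenating the tuples.
        # Tuple concatenation keeps the contributions in a fixed order,
        # whereas frozenset union gave an unspecified iteration order.
        result = dict(dict_a)
        for key, values in dict_b.items():
            result[key] = result.get(key, ()) + values
        return result


    print(set_dict_union({"recv0": ("send0",)}, {"recv0": ("send1",), "recv1": ()}))
    # {'recv0': ('send0', 'send1'), 'recv1': ()}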
@@ -782,6 +783,8 @@ def find_distributed_partition(
     - Gather sent arrays into
       assigned in :attr:`DistributedGraphPart.name_to_send_nodes`.
     """
+    from pytools import unique
+
     import mpi4py.MPI as MPI
 
     from pytato.transform import SubsetDependencyMapper
@@ -833,32 +836,33 @@ def find_distributed_partition(
 
     # {{{ create (local) parts out of batch ids
 
-    part_comm_ids: list[_PartCommIDs] = []
+
+    part_comm_ids: List[_PartCommIDs] = []
     if comm_batches:
-        recv_ids: frozenset[CommunicationOpIdentifier] = frozenset()
+        recv_ids: Tuple[CommunicationOpIdentifier] = tuple()
         for batch in comm_batches:
-            send_ids = frozenset(
-                comm_id for comm_id in batch
+            send_ids = tuple(
+                comm_id for comm_id in unique(batch)
                 if comm_id.src_rank == local_rank)
             if recv_ids or send_ids:
                 part_comm_ids.append(
                     _PartCommIDs(
                         recv_ids=recv_ids,
                         send_ids=send_ids))
             # These go into the next part
-            recv_ids = frozenset(
-                comm_id for comm_id in batch
+            recv_ids = tuple(
+                comm_id for comm_id in unique(batch)
                 if comm_id.dest_rank == local_rank)
         if recv_ids:
             part_comm_ids.append(
                 _PartCommIDs(
                     recv_ids=recv_ids,
-                    send_ids=frozenset()))
+                    send_ids=tuple()))
     else:
         part_comm_ids.append(
             _PartCommIDs(
-                recv_ids=frozenset(),
-                send_ids=frozenset()))
+                recv_ids=tuple(),
+                send_ids=tuple()))
 
     nparts = len(part_comm_ids)
 
@@ -876,7 +880,7 @@ def find_distributed_partition(
     comm_id_to_part_id = {
         comm_id: ipart
         for ipart, comm_ids in enumerate(part_comm_ids)
-        for comm_id in comm_ids.send_ids | comm_ids.recv_ids}
+        for comm_id in unique(comm_ids.send_ids + comm_ids.recv_ids)}
 
     # }}}
 
@@ -888,25 +892,27 @@ def find_distributed_partition(
     # The sets of arrays below must have a deterministic order in order to ensure
     # that the resulting partition is also deterministic
 
-    sent_arrays = _OrderedSet(
+    sent_arrays = tuple(
         send_node.data for send_node in lsrdg.local_send_id_to_send_node.values())
 
-    received_arrays = _OrderedSet(lsrdg.local_recv_id_to_recv_node.values())
+    received_arrays = tuple(lsrdg.local_recv_id_to_recv_node.values())
 
     # While receive nodes may be marked as materialized, we shouldn't be
     # including them here because we're using them (along with the send nodes)
     # as anchors to place *other* materialized data into the batches.
     # We could allow sent *arrays* to be included here because they are distinct
     # from send *nodes*, but we choose to exclude them in order to simplify the
     # processing below.
-    materialized_arrays = (
-        materialized_arrays_collector.materialized_arrays
-        - received_arrays
-        - sent_arrays)
+    materialized_arrays_set = set(materialized_arrays_collector.materialized_arrays) \
+        - set(received_arrays) \
+        - set(sent_arrays)
+
+    from pytools import unique
+    materialized_arrays = tuple(a for a in materialized_arrays_collector.materialized_arrays if a in materialized_arrays_set)
 
     # "mso" for "materialized/sent/output"
-    output_arrays = _OrderedSet(outputs._data.values())
-    mso_arrays = materialized_arrays | sent_arrays | output_arrays
+    output_arrays = tuple(outputs._data.values())
+    mso_arrays = materialized_arrays + sent_arrays + output_arrays
 
     # FIXME: This gathers up materialized_arrays recursively, leading to
     # result sizes potentially quadratic in the number of materialized arrays.
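The materialized_arrays computation above shows a pattern used throughout this commit: a set is built only for fast membership tests, while the ordered sequence from the collector drives iteration, so the result stays deterministic. A standalone sketch with plain strings standing in for Array objects:

    collected = ["a", "b", "c", "d"]   # collector output, in visit order
    received = ("c",)
    sent = ("d",)

    # The set only decides membership; the order comes from `collected`.
    keep = set(collected) - set(received) - set(sent)
    materialized = tuple(a for a in collected if a in keep)

    print(materialized)  # ('a', 'b')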
@@ -970,7 +976,7 @@ def find_distributed_partition(
     assert all(0 <= part_id < nparts
                for part_id in stored_ary_to_part_id.values())
 
-    stored_arrays = _OrderedSet(stored_ary_to_part_id)
+    stored_arrays = tuple(unique(stored_ary_to_part_id))
 
     # {{{ find which stored arrays should become part outputs
     # (because they are used in not just their local part, but also others)
@@ -986,13 +992,13 @@ def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]:
             materialized_preds |= get_materialized_predecessors(pred)
         return materialized_preds
 
-    stored_arrays_promoted_to_part_outputs = {
+    stored_arrays_promoted_to_part_outputs = tuple(unique(
         stored_pred
         for stored_ary in stored_arrays
         for stored_pred in get_materialized_predecessors(stored_ary)
         if (stored_ary_to_part_id[stored_ary]
             != stored_ary_to_part_id[stored_pred])
-        }
+        ))
 
     # }}}
 
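stored_arrays is now built by iterating the stored_ary_to_part_id dict, whose keys come back in insertion order (Python 3.7+), so the resulting tuple is deterministic. A toy illustration, with dict.fromkeys standing in for pytools.unique and strings in place of arrays:

    # Iterating a dict yields its keys in insertion order, unlike iterating a
    # set of the same keys, so this sequence is stable across runs.
    stored_ary_to_part_id = {"ary_c": 0, "ary_a": 1, "ary_b": 0}

    stored_arrays = tuple(dict.fromkeys(stored_ary_to_part_id))

    print(stored_arrays)  # ('ary_c', 'ary_a', 'ary_b')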
5 changes: 3 additions & 2 deletions pytato/transform/__init__.py
@@ -926,9 +926,10 @@ def __init__(self, universe: frozenset[Array]):
 
     def combine(self, *args: frozenset[Array]) -> frozenset[Array]:
         from functools import reduce
-        return reduce(lambda acc, arg: acc | (arg & self.universe),
+        from pytools import unique
+        return reduce(lambda acc, arg: unique(tuple(acc) + tuple(set(arg) & self.universe)),
                       args,
-                      frozenset())
+                      tuple())
 
 # }}}
 
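The reworked combine in SubsetDependencyMapper folds the arguments into a tuple, filtering against the universe and deduplicating as it goes. Below is an order-preserving sketch of the same idea with plain strings and dict.fromkeys in place of pytools.unique; note that the actual change intersects each argument with the universe as a set before deduplicating:

    from functools import reduce

    universe = ("u", "v", "w")   # allowed elements, in a fixed order


    def combine(*args):
        # Fold the argument collections together, keeping only elements of the
        # universe and preserving the order in which they are first seen.
        return reduce(
            lambda acc, arg: tuple(
                dict.fromkeys(acc + tuple(a for a in arg if a in universe))),
            args,
            ())


    print(combine(("x", "v"), ("w", "v"), ("u",)))  # ('v', 'w', 'u')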
