Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deterministic find_distributed_partition (non-set) #529

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
3364a4f
working pass 1
matthiasdiener Jul 25, 2024
ef7ea0b
cleanups
matthiasdiener Jul 25, 2024
817b255
enable determinism test
matthiasdiener Jul 25, 2024
f3f3c7d
eliminate _OrderedSets
matthiasdiener Jul 25, 2024
8bf2daf
misc improvements
matthiasdiener Jul 25, 2024
5d906b5
revert change to SubsetDependencyMapper
matthiasdiener Jul 25, 2024
142c8e6
some mypy fixes
matthiasdiener Jul 25, 2024
4e0e174
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Aug 12, 2024
bd70620
ruff
matthiasdiener Aug 12, 2024
076a76e
replace orderedsets with unique tuples in DirectPredecessorsGetter
matthiasdiener Aug 13, 2024
ea1462c
mypy fixes
matthiasdiener Aug 14, 2024
168ef53
remove unnecesary cast
matthiasdiener Aug 14, 2024
d711989
adjust comment
matthiasdiener Aug 14, 2024
c976c23
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Sep 4, 2024
fde8f77
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Sep 12, 2024
7e03b35
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Sep 27, 2024
5847800
performance fix
matthiasdiener Sep 27, 2024
7dd83bb
switch to dicts
matthiasdiener Sep 27, 2024
1ea962c
more dict usage
matthiasdiener Sep 27, 2024
77f1bbd
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Oct 11, 2024
e8b5806
fix materialized_arrays perf
matthiasdiener Oct 14, 2024
4a9f9f5
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Oct 24, 2024
12e9449
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Nov 5, 2024
7050bc8
Merge branch 'main' into deterministic-fdp-nonset
inducer Nov 14, 2024
94658af
Fix imports
inducer Nov 14, 2024
535539e
Merge branch 'main' into deterministic-fdp-nonset
inducer Nov 14, 2024
3cac91f
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Nov 18, 2024
8929bed
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Nov 19, 2024
3a30e88
ruff
matthiasdiener Nov 19, 2024
0124792
Merge branch 'main' into deterministic-fdp-nonset
inducer Dec 4, 2024
25ce81e
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Dec 4, 2024
3530b53
use operator.or_ in reduction
matthiasdiener Dec 4, 2024
fbcbcef
use a FakeOrderedFrozenSet type
matthiasdiener Dec 4, 2024
bc87f3a
extend FakeOrderedSet typing
matthiasdiener Dec 4, 2024
fa1f6f1
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Dec 10, 2024
4424a2e
ruff
matthiasdiener Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING, Any
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar

from immutabledict import immutabledict

from loopy.tools import LoopyKeyBuilder
from pymbolic.mapper.optimize import optimize_mapper
Expand Down Expand Up @@ -74,8 +77,14 @@
"""


T = TypeVar("T")

FakeOrderedFrozenSet: TypeAlias = immutabledict[T, None]
FakeOrderedSet: TypeAlias = dict[T, None]

# {{{ NUserCollector


class NUserCollector(Mapper[None, []]):
"""
A :class:`pytato.transform.CachedWalkMapper` that records the number of
Expand Down Expand Up @@ -328,37 +337,37 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]):

We only consider the predecessors of a nodes in a data-flow sense.
"""
def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[ArrayOrNames]:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

T = TypeVar("T")

class FakeOrderedFrozenSet(immutabledict[T, None]):
    pass

?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't able to get this to work with mypy, but what do you think of fbcbcef?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That works for me. I was under the impression that type aliases could not be generic, but apparently I'm wrong? I tried for a bit to back up my assumption, but I wasn't able to. (I also wasn't able to back up the opposite.) Definitive info would be most welcome! 🙂

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return frozenset({dim for dim in shape if isinstance(dim, Array)})
def _get_preds_from_shape(self, shape: ShapeType) -> FakeOrderedFrozenSet[Array]:
return immutabledict.fromkeys(dim for dim in shape if isinstance(dim, Array))

def map_index_lambda(self, expr: IndexLambda) -> frozenset[ArrayOrNames]:
return (frozenset(expr.bindings.values())
def map_index_lambda(self, expr: IndexLambda) -> FakeOrderedFrozenSet[Array]:
return (immutabledict.fromkeys(expr.bindings.values())
| self._get_preds_from_shape(expr.shape))

def map_stack(self, expr: Stack) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
def map_stack(self, expr: Stack) -> FakeOrderedFrozenSet[Array]:
return (immutabledict.fromkeys(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_concatenate(self, expr: Concatenate) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
def map_concatenate(self, expr: Concatenate) -> FakeOrderedFrozenSet[Array]:
return (immutabledict.fromkeys(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_einsum(self, expr: Einsum) -> frozenset[ArrayOrNames]:
return (frozenset(expr.args)
def map_einsum(self, expr: Einsum) -> FakeOrderedFrozenSet[Array]:
return (immutabledict.fromkeys(expr.args)
| self._get_preds_from_shape(expr.shape))

def map_loopy_call_result(self, expr: NamedArray) -> frozenset[ArrayOrNames]:
def map_loopy_call_result(self, expr: NamedArray) -> FakeOrderedFrozenSet[Array]:
from pytato.loopy import LoopyCall, LoopyCallResult
assert isinstance(expr, LoopyCallResult)
assert isinstance(expr._container, LoopyCall)
return (frozenset(ary
return (immutabledict.fromkeys(ary
for ary in expr._container.bindings.values()
if isinstance(ary, Array))
| self._get_preds_from_shape(expr.shape))

def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]:
return (frozenset([expr.array])
| frozenset(idx for idx in expr.indices
def _map_index_base(self, expr: IndexBase) -> FakeOrderedFrozenSet[Array]:
return (immutabledict.fromkeys([expr.array])
| immutabledict.fromkeys(idx for idx in expr.indices
if isinstance(idx, Array))
| self._get_preds_from_shape(expr.shape))

Expand All @@ -367,34 +376,35 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]:
map_non_contiguous_advanced_index = _map_index_base

def _map_index_remapping_base(self, expr: IndexRemappingBase
) -> frozenset[ArrayOrNames]:
return frozenset([expr.array])
) -> FakeOrderedFrozenSet[ArrayOrNames]:
return immutabledict.fromkeys([expr.array])

map_roll = _map_index_remapping_base
map_axis_permutation = _map_index_remapping_base
map_reshape = _map_index_remapping_base

def _map_input_base(self, expr: InputArgumentBase) -> frozenset[ArrayOrNames]:
def _map_input_base(self, expr: InputArgumentBase) -> FakeOrderedFrozenSet[Array]:
return self._get_preds_from_shape(expr.shape)

map_placeholder = _map_input_base
map_data_wrapper = _map_input_base
map_size_param = _map_input_base

def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[ArrayOrNames]:
def map_distributed_recv(self,
expr: DistributedRecv) -> FakeOrderedFrozenSet[Array]:
return self._get_preds_from_shape(expr.shape)

def map_distributed_send_ref_holder(self,
expr: DistributedSendRefHolder
) -> frozenset[ArrayOrNames]:
return frozenset([expr.passthrough_data])
) -> FakeOrderedFrozenSet[ArrayOrNames]:
return immutabledict.fromkeys([expr.passthrough_data])

def map_call(self, expr: Call) -> frozenset[ArrayOrNames]:
return frozenset(expr.bindings.values())
def map_call(self, expr: Call) -> FakeOrderedFrozenSet[ArrayOrNames]:
return immutabledict.fromkeys(expr.bindings.values())

def map_named_call_result(
self, expr: NamedCallResult) -> frozenset[ArrayOrNames]:
return frozenset([expr._container])
self, expr: NamedCallResult) -> FakeOrderedFrozenSet[ArrayOrNames]:
return immutabledict.fromkeys([expr._container])


# }}}
Expand Down
Loading
Loading