Skip to content

Commit

Permalink
Require that happens_after is not mutable
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Sep 8, 2024
1 parent 0f78426 commit b9c5eae
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
28 changes: 19 additions & 9 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
THE SOFTWARE.
"""

from collections.abc import Mapping as MappingABC, Set as abc_Set
from collections.abc import (
Mapping as MappingABC,
MutableMapping as MutableMappingABC,
Set as abc_Set,
)
from dataclasses import dataclass
from functools import cached_property
from sys import intern
Expand Down Expand Up @@ -283,6 +287,7 @@ def __init__(self,
*,
depends_on: Union[FrozenSet[str], str, None] = None,
) -> None:
from immutabledict import immutabledict

if predicates is None:
predicates = frozenset()
Expand Down Expand Up @@ -315,27 +320,27 @@ def __init__(self,
"actually specifying happens_after/depends_on")

if happens_after is None:
happens_after = {}
happens_after = immutabledict()
elif isinstance(happens_after, str):
warn("Passing a string for happens_after/depends_on is deprecated and "
"will stop working in 2025. Instead, pass a full-fledged "
"happens_after data structure.", DeprecationWarning, stacklevel=2)

happens_after = {
happens_after = immutabledict({
after_id.strip(): HappensAfter(
variable_name=None,
instances_rel=None)
for after_id in happens_after.split(",")
if after_id.strip()}
if after_id.strip()})
elif isinstance(happens_after, frozenset):
happens_after = {
happens_after = immutabledict({
after_id: HappensAfter(
variable_name=None,
instances_rel=None)
for after_id in happens_after}
for after_id in happens_after})
elif isinstance(happens_after, MappingABC):
if isinstance(happens_after, dict):
happens_after = happens_after
happens_after = immutabledict(happens_after)
else:
raise TypeError("'happens_after' has unexpected type: "
f"{type(happens_after)}")
Expand Down Expand Up @@ -389,6 +394,9 @@ def __init__(self,
assert isinstance(happens_after, MappingABC) or happens_after is None
assert isinstance(groups, abc_Set)
assert isinstance(conflicts_with_groups, abc_Set)
if isinstance(happens_after, MappingABC):
# Verify that happens_after is hashable.
assert not isinstance(happens_after, MutableMappingABC)

ImmutableRecord.__init__(self,
id=id,
Expand Down Expand Up @@ -573,13 +581,15 @@ def update_persistent_hash(self, key_hash, key_builder):
def __setstate__(self, val):
super().__setstate__(val)

from immutabledict import immutabledict

from loopy.tools import intern_frozenset_of_ids

if self.id is not None: # pylint:disable=access-member-before-definition
self.id = intern(self.id)
self.happens_after = {
self.happens_after = immutabledict({
intern(after_id): ha
for after_id, ha in self.happens_after.items()}
for after_id, ha in self.happens_after.items()})
self.groups = intern_frozenset_of_ids(self.groups)
self.conflicts_with_groups = (
intern_frozenset_of_ids(self.conflicts_with_groups))
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"Mako",
"pyrsistent",
"immutables",
"immutabledict",

# for Self, TypeAlias
"typing-extensions>=4; python_version<'3.12'",
Expand Down

0 comments on commit b9c5eae

Please sign in to comment.