From b9c5eaee3b11c2c265209d7e9bd7c595ddb534e0 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 8 Sep 2024 14:16:29 -0500 Subject: [PATCH] Require that happens_after is not mutable --- loopy/kernel/instruction.py | 28 +++++++++++++++++++--------- pyproject.toml | 1 + 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index d564d5e36..83cdcc674 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -20,7 +20,11 @@ THE SOFTWARE. """ -from collections.abc import Mapping as MappingABC, Set as abc_Set +from collections.abc import ( + Mapping as MappingABC, + MutableMapping as MutableMappingABC, + Set as abc_Set, +) from dataclasses import dataclass from functools import cached_property from sys import intern @@ -283,6 +287,7 @@ def __init__(self, *, depends_on: Union[FrozenSet[str], str, None] = None, ) -> None: + from immutabledict import immutabledict if predicates is None: predicates = frozenset() @@ -315,27 +320,27 @@ def __init__(self, "actually specifying happens_after/depends_on") if happens_after is None: - happens_after = {} + happens_after = immutabledict() elif isinstance(happens_after, str): warn("Passing a string for happens_after/depends_on is deprecated and " "will stop working in 2025. Instead, pass a full-fledged " "happens_after data structure.", DeprecationWarning, stacklevel=2) - happens_after = { + happens_after = immutabledict({ after_id.strip(): HappensAfter( variable_name=None, instances_rel=None) for after_id in happens_after.split(",") - if after_id.strip()} + if after_id.strip()}) elif isinstance(happens_after, frozenset): - happens_after = { + happens_after = immutabledict({ after_id: HappensAfter( variable_name=None, instances_rel=None) - for after_id in happens_after} + for after_id in happens_after}) elif isinstance(happens_after, MappingABC): if isinstance(happens_after, dict): - happens_after = happens_after + happens_after = immutabledict(happens_after) else: raise TypeError("'happens_after' has unexpected type: " f"{type(happens_after)}") @@ -389,6 +394,9 @@ def __init__(self, assert isinstance(happens_after, MappingABC) or happens_after is None assert isinstance(groups, abc_Set) assert isinstance(conflicts_with_groups, abc_Set) + if isinstance(happens_after, MappingABC): + # Verify that happens_after is hashable. + assert not isinstance(happens_after, MutableMappingABC) ImmutableRecord.__init__(self, id=id, @@ -573,13 +581,15 @@ def update_persistent_hash(self, key_hash, key_builder): def __setstate__(self, val): super().__setstate__(val) + from immutabledict import immutabledict + from loopy.tools import intern_frozenset_of_ids if self.id is not None: # pylint:disable=access-member-before-definition self.id = intern(self.id) - self.happens_after = { + self.happens_after = immutabledict({ intern(after_id): ha - for after_id, ha in self.happens_after.items()} + for after_id, ha in self.happens_after.items()}) self.groups = intern_frozenset_of_ids(self.groups) self.conflicts_with_groups = ( intern_frozenset_of_ids(self.conflicts_with_groups)) diff --git a/pyproject.toml b/pyproject.toml index 4134ba24d..70672a1ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "Mako", "pyrsistent", "immutables", + "immutabledict", # for Self, TypeAlias "typing-extensions>=4; python_version<'3.12'",