From 51dc358d739403069e36ad935fae82eb2b0fe381 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Nov 2024 05:17:17 +0100 Subject: [PATCH 1/7] Refactor common model utils --- .../icon4py/model/common/utils/__init__.py | 20 +- .../src/icon4py/model/common/utils/_common.py | 307 ++++++++++++++---- 2 files changed, 260 insertions(+), 67 deletions(-) diff --git a/model/common/src/icon4py/model/common/utils/__init__.py b/model/common/src/icon4py/model/common/utils/__init__.py index af9faae75e..39101e02fc 100644 --- a/model/common/src/icon4py/model/common/utils/__init__.py +++ b/model/common/src/icon4py/model/common/utils/__init__.py @@ -8,10 +8,26 @@ from __future__ import annotations -from ._common import Swapping, chainable +from . import gt4py_field_allocation +from ._common import ( + DoubleBuffering, + NextStepPair, + Pair, + PreviousStepPair, + chainable, +) __all__ = [ + # Classes + "DoubleBuffering", + "DoubleBufferingDataClassTrait", + "NextStepPair", + "Pair", + "PreviousStepPair", + "namedproperty", + # Functions "chainable", - "Swapping", + # Modules + "gt4py_field_allocation", ] diff --git a/model/common/src/icon4py/model/common/utils/_common.py b/model/common/src/icon4py/model/common/utils/_common.py index 361c5b3fd7..5f1d31f4d1 100644 --- a/model/common/src/icon4py/model/common/utils/_common.py +++ b/model/common/src/icon4py/model/common/utils/_common.py @@ -9,110 +9,287 @@ from __future__ import annotations +import abc +import copy import functools from collections.abc import Callable -from typing import Concatenate, Generic, ParamSpec, TypeVar +from typing import ( + ClassVar, + Concatenate, + Final, + Generator, + Generic, + Literal, + ParamSpec, + Protocol, + TypeVar, + runtime_checkable, +) -__all__ = [ - "chainable", - "Swapping", -] +@runtime_checkable +class DoubleBuffering(Protocol): + @abc.abstractmethod + def swap_buffers(self) -> None: + ... + -P = ParamSpec("P") T = TypeVar("T") +C = TypeVar("C") -def chainable(method_fn: Callable[Concatenate[T, P], None]) -> Callable[Concatenate[T, P], T]: +class named_property(property, Generic[C, T]): """ - Make an instance method return the actual instance so it can used in a chain of calls. + A simple extension of the built-in `property` descriptor storing + the name of the attribute it is assigned to. - Typically used for simple fluent interfaces. + The name is stored in the `name` attribute of the property instance. Examples: >>> class A: - ... @chainable - ... def set_value(self, value: int) -> None: - ... self.value = value - ... - ... @chainable - ... def increment(self, value: int) -> None: - ... self.value += value - ... + ... @named_property + ... def value(self) -> int: + ... return self._value ... - ... a = A() - ... a.set_value(1).increment(2) - ... a.value - 3 + ... @value.setter + ... def value(self, value: int) -> None: + ... self._value = value + >>> a = A() + >>> a.value = 1 + >>> print(A.value.name) + value """ - @functools.wraps(method_fn) - def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> T: - method_fn(self, *args, **kwargs) - return self + name: str | None = None - return wrapper + def __set_name__(self, owner: C, name: str) -> None: + """ + Set the name of the attribute the property is assigned to. + + Note that if the same descriptor is assigned to multiple attributes + or if it is reused in other classes, only the first assigned name + is stored. + """ + if self.name is None: + self.name = name + + def getter(self: named_property[C, T], fget: Callable[[C], T]) -> named_property[C, T]: + result = super().getter(fget) + result.name = getattr(self, "name", None) + return result + def setter(self: named_property[C, T], fset: Callable[[C, T], None]) -> named_property[C, T]: + result = super().setter(fset) + result.name = getattr(self, "name", None) + return result -class Swapping(Generic[T]): + def deleter(self: named_property[C, T], fdel: Callable[[C], None]) -> named_property[C, T]: + result = super().deleter(fdel) + result.name = getattr(self, "name", None) + return result + + def __copy__(self) -> named_property[C, T]: + result = type(self)(self.fget, self.fset, self.fdel, self.__doc__) + result.name = self.name + return result + + +class Pair(Generic[T]): """ - Generic double container for swapping between two values. + A generic class representing a pair of values. - This is useful for double buffering in numerical algorithms. + The name of the pair attributes can be customized by defining new + descriptors in the subclasses. - Examples: - >>> a = Swapping(current=1, other=2) - Swapping(1, 2) + See the examples below. - >>> a.swap() - Swapping(current=2, other=1) + Examples: + >>> class MyPair(Pair[T]): + ... a: T = Pair.first + ... b: T = Pair.frozen_second + >>> pair = MyPair(1, 2) + >>> print(pair) + MyPair(a=1, b=2) - >>> a.current = 3 - ... a - Swapping(current=3, other=1) + >>> pair.swap() + MyPair(a=2, b=1) - >>> a != ~a + >>> pair == MyPair(2, 1) True - >>> a == ~~a - True + >>> pair.a = -1 + >>> pair[0] + -1 - >>> a.current == (~a).other - True + >>> pair[0] = 100 + >>> pair.a + 100 + + >>> pair.b = 3 + Traceback (most recent call last): + ... + AttributeError: can't set attribute + + >>> pair[1] = 3 + Traceback (most recent call last): + ... + AttributeError: can't set attribute - >>> b = ~a - ... a.swap() - ... a == b - True """ - __slots__ = ("current", "_other", "__weakref__") + _FIRST_ACCESSOR_ID: Final = "FIRST" + _SECOND_ACCESSOR_ID: Final = "SECOND" - current: T - _other: T + __first_attr_name: ClassVar[str] = "first" + __second_attr_name: ClassVar[str] = "second" - @property - def other(self) -> T: - return self._other + def __init_subclass__(cls) -> None: + for key, value in {**cls.__dict__}.items(): + if (attr_id := getattr(value, "_pair_accessor_id_", None)) is not None: + assert isinstance(value, named_property) + if key != value.name: + # When the original descriptor from the `Pair` class has been + # directly assigned to another class member with a different name + # (likely in a subclass) instead of creating a proper copy, it is + # copied and initialized here with the right name. + descriptor_copy = copy.copy(value) + descriptor_copy.name = key + setattr(cls, key, descriptor_copy) + if attr_id == Pair._FIRST_ACCESSOR_ID: + cls.__first_attr_name = key + elif attr_id == Pair._SECOND_ACCESSOR_ID: + cls.__second_attr_name = key + else: + raise TypeError(f"Invalid '{key}' pair accessor descriptor: {value}") - def __init__(self, current: T, other: T) -> None: - self.current = current - self._other = other + __first: T + __second: T - def __repr__(self) -> str: - return f"{self.__class__.__name__}(current={self.current!r}, other={self._other!r})" + def __init__(self, first: T, second: T, /) -> None: + self.__first = first + self.__second = second + + @named_property + def first(self) -> T: + """Property descriptor for the first element of the pair.""" + return self.__first + + @first.setter + def first(self, value: T) -> None: + self.__first = value + + @named_property + def second(self) -> T: + """Property descriptor for the second element of the pair.""" + return self.__second + + @second.setter + def second(self, value: T) -> None: + self.__second = value + + @named_property + def frozen_first(self) -> T: + """Read-only property descriptor for the first element of the pair (mainly for subclassing).""" + return self.__first + + @named_property + def frozen_second(self) -> T: + """Read-only property descriptor for the second element of the pair (mainly for subclassing).""" + return self.__second + + first._pair_accessor_id_ = frozen_first._pair_accessor_id_ = _FIRST_ACCESSOR_ID + second._pair_accessor_id_ = frozen_second._pair_accessor_id_ = _SECOND_ACCESSOR_ID def __eq__(self, other: object) -> bool: - return ( - isinstance(other, Swapping) - and self.current == other.current - and self._other == other._other + return type(self) is type(other) and ( + self.__first == other.__first and self.__second == other.__second ) # `__hash__` is implicitly set to None when `__eq__` is redefined, so instances are not hashable. - def swap(self) -> None: - self.current, self._other = self._other, self.current + def __getitem__(self, index: Literal[0, 1]) -> T: + match index: + case 0: + return self.__first + case 1: + return self.__second + case _: + raise IndexError(f"Pair index out of range: {index}") + + def __setitem__(self, index: Literal[0, 1], value: T) -> None: + match index: + case 0: + attr_name = self.__first_attr_name + case 1: + attr_name = self.__second_attr_name + case _: + raise IndexError(f"Pair index out of range: {index}") - def __invert__(self) -> Swapping[T]: - return type(self)(current=self._other, other=self.current) + setattr(self, attr_name, value) + + def __iter__(self) -> Generator[T, None, None]: + yield self.__first + yield self.__second + + def __repr__(self) -> str: + first_name = type(self).__first_attr_name + second_name = type(self).__second_attr_name + return f"{self.__class__.__name__}({first_name}={self.__first!r}, {second_name}={self.__second!r})" + + def swap(self: Pair[T]) -> Pair[T]: + """ + Swap the values of the first and second attributes of the instance. + + Returns: + The instance with swapped values (for fluent interfaces). + """ + self.__first, self.__second = self.__second, self.__first + return self + + swap_buffers = swap + + +class NextStepPair(Pair[T]): + current: T = Pair.first + next: T = Pair.frozen_second + + +class PreviousStepPair(Pair[T]): + current: T = Pair.first + previous: T = Pair.frozen_second + + +P = ParamSpec("P") +T = TypeVar("T") + + +def chainable(method_fn: Callable[Concatenate[T, P], None]) -> Callable[Concatenate[T, P], T]: + """ + Make an instance method return the actual instance so it can used in a chain of calls. + + Typically used for simple fluent interfaces. + + Examples: + >>> class A: + ... @chainable + ... def set_value(self, value: int) -> None: + ... self.value = value + ... + ... @chainable + ... def increment(self, value: int) -> None: + ... self.value += value + + >>> a = A() + >>> a.set_value(1).increment(2) # doctest:+ELLIPSIS + <__main__.A object at ...> + + >>> a.value + 3 + """ + + @functools.wraps(method_fn) + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> T: + method_fn(self, *args, **kwargs) + return self + + return wrapper From 94a6e847943c629944479c2fe43babdc4e09a7db Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Nov 2024 05:20:52 +0100 Subject: [PATCH 2/7] Fix typo --- model/common/src/icon4py/model/common/utils/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/utils/_common.py b/model/common/src/icon4py/model/common/utils/_common.py index 5f1d31f4d1..30d5aecc4e 100644 --- a/model/common/src/icon4py/model/common/utils/_common.py +++ b/model/common/src/icon4py/model/common/utils/_common.py @@ -265,7 +265,7 @@ class PreviousStepPair(Pair[T]): def chainable(method_fn: Callable[Concatenate[T, P], None]) -> Callable[Concatenate[T, P], T]: """ - Make an instance method return the actual instance so it can used in a chain of calls. + Make an instance method return the actual instance so it can be used in a chain of calls. Typically used for simple fluent interfaces. From a82cac5af0f6b46d7d6c3f95b9e206f4d7da9420 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Nov 2024 05:24:35 +0100 Subject: [PATCH 3/7] Fix typo --- model/common/src/icon4py/model/common/utils/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/utils/_common.py b/model/common/src/icon4py/model/common/utils/_common.py index 30d5aecc4e..91fd1fefb9 100644 --- a/model/common/src/icon4py/model/common/utils/_common.py +++ b/model/common/src/icon4py/model/common/utils/_common.py @@ -96,7 +96,7 @@ def __copy__(self) -> named_property[C, T]: class Pair(Generic[T]): """ - A generic class representing a pair of values. + Generic class representing a pair of values. The name of the pair attributes can be customized by defining new descriptors in the subclasses. From aa164bf1cd89b8f478608efde17d5edd5e932f09 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Nov 2024 05:28:28 +0100 Subject: [PATCH 4/7] Fix typo --- model/common/src/icon4py/model/common/utils/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/utils/_common.py b/model/common/src/icon4py/model/common/utils/_common.py index 91fd1fefb9..d5ac5009e2 100644 --- a/model/common/src/icon4py/model/common/utils/_common.py +++ b/model/common/src/icon4py/model/common/utils/_common.py @@ -148,7 +148,7 @@ def __init_subclass__(cls) -> None: if (attr_id := getattr(value, "_pair_accessor_id_", None)) is not None: assert isinstance(value, named_property) if key != value.name: - # When the original descriptor from the `Pair` class has been + # If one of the original descriptors from the `Pair` class has been # directly assigned to another class member with a different name # (likely in a subclass) instead of creating a proper copy, it is # copied and initialized here with the right name. From ce32fd94f0c61cb03836e89b96a0792ba7437e97 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Nov 2024 05:32:01 +0100 Subject: [PATCH 5/7] Add comment --- model/common/src/icon4py/model/common/utils/_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/utils/_common.py b/model/common/src/icon4py/model/common/utils/_common.py index d5ac5009e2..b6e49de4f7 100644 --- a/model/common/src/icon4py/model/common/utils/_common.py +++ b/model/common/src/icon4py/model/common/utils/_common.py @@ -217,7 +217,8 @@ def __getitem__(self, index: Literal[0, 1]) -> T: raise IndexError(f"Pair index out of range: {index}") def __setitem__(self, index: Literal[0, 1], value: T) -> None: - match index: + # Go through the attribute descriptors to respect the read-only indication + `match index: case 0: attr_name = self.__first_attr_name case 1: From ce27722a3d615a7c724ee0051888b0e70b2fd797 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Nov 2024 05:39:56 +0100 Subject: [PATCH 6/7] Fix typo --- model/common/src/icon4py/model/common/utils/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/utils/_common.py b/model/common/src/icon4py/model/common/utils/_common.py index b6e49de4f7..13d2dd7bbf 100644 --- a/model/common/src/icon4py/model/common/utils/_common.py +++ b/model/common/src/icon4py/model/common/utils/_common.py @@ -218,7 +218,7 @@ def __getitem__(self, index: Literal[0, 1]) -> T: def __setitem__(self, index: Literal[0, 1], value: T) -> None: # Go through the attribute descriptors to respect the read-only indication - `match index: + match index: case 0: attr_name = self.__first_attr_name case 1: From 5d6ad43a12ec27ada45762b507f9467cb2d57a5f Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Wed, 20 Nov 2024 15:38:16 +0100 Subject: [PATCH 7/7] Fix `__init__.py` in common.utils --- model/common/src/icon4py/model/common/utils/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/common/src/icon4py/model/common/utils/__init__.py b/model/common/src/icon4py/model/common/utils/__init__.py index 39101e02fc..315b415edd 100644 --- a/model/common/src/icon4py/model/common/utils/__init__.py +++ b/model/common/src/icon4py/model/common/utils/__init__.py @@ -15,17 +15,17 @@ Pair, PreviousStepPair, chainable, + named_property, ) __all__ = [ # Classes "DoubleBuffering", - "DoubleBufferingDataClassTrait", "NextStepPair", "Pair", "PreviousStepPair", - "namedproperty", + "named_property", # Functions "chainable", # Modules