Skip to content

Commit

Permalink
feat: enhance common model utils with classes for implementing double…
Browse files Browse the repository at this point in the history
… buffering (#602)

Refactor and expand the model/common/utils module with simple implementation of utilities for to double-buffering like Pair. Additionally, fix the existing doctest for chainable.
  • Loading branch information
egparedes authored Nov 20, 2024
1 parent b5b0011 commit 5820a8e
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 67 deletions.
20 changes: 18 additions & 2 deletions model/common/src/icon4py/model/common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,26 @@

from __future__ import annotations

from ._common import Swapping, chainable
from . import gt4py_field_allocation
from ._common import (
DoubleBuffering,
NextStepPair,
Pair,
PreviousStepPair,
chainable,
named_property,
)


__all__ = [
# Classes
"DoubleBuffering",
"NextStepPair",
"Pair",
"PreviousStepPair",
"named_property",
# Functions
"chainable",
"Swapping",
# Modules
"gt4py_field_allocation",
]
308 changes: 243 additions & 65 deletions model/common/src/icon4py/model/common/utils/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,110 +9,288 @@

from __future__ import annotations

import abc
import copy
import functools
from collections.abc import Callable
from typing import Concatenate, Generic, ParamSpec, TypeVar
from typing import (
ClassVar,
Concatenate,
Final,
Generator,
Generic,
Literal,
ParamSpec,
Protocol,
TypeVar,
runtime_checkable,
)


__all__ = [
"chainable",
"Swapping",
]
@runtime_checkable
class DoubleBuffering(Protocol):
@abc.abstractmethod
def swap_buffers(self) -> None:
...


P = ParamSpec("P")
T = TypeVar("T")
C = TypeVar("C")


def chainable(method_fn: Callable[Concatenate[T, P], None]) -> Callable[Concatenate[T, P], T]:
class named_property(property, Generic[C, T]):
"""
Make an instance method return the actual instance so it can used in a chain of calls.
A simple extension of the built-in `property` descriptor storing
the name of the attribute it is assigned to.
Typically used for simple fluent interfaces.
The name is stored in the `name` attribute of the property instance.
Examples:
>>> class A:
... @chainable
... def set_value(self, value: int) -> None:
... self.value = value
... @named_property
... def value(self) -> int:
... return self._value
...
... @chainable
... def increment(self, value: int) -> None:
... self.value += value
...
...
... a = A()
... a.set_value(1).increment(2)
... a.value
3
... @value.setter
... def value(self, value: int) -> None:
... self._value = value
>>> a = A()
>>> a.value = 1
>>> print(A.value.name)
value
"""

@functools.wraps(method_fn)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> T:
method_fn(self, *args, **kwargs)
return self
name: str | None = None

return wrapper
def __set_name__(self, owner: C, name: str) -> None:
"""
Set the name of the attribute the property is assigned to.
Note that if the same descriptor is assigned to multiple attributes
or if it is reused in other classes, only the first assigned name
is stored.
"""
if self.name is None:
self.name = name

def getter(self: named_property[C, T], fget: Callable[[C], T]) -> named_property[C, T]:
result = super().getter(fget)
result.name = getattr(self, "name", None)
return result

def setter(self: named_property[C, T], fset: Callable[[C, T], None]) -> named_property[C, T]:
result = super().setter(fset)
result.name = getattr(self, "name", None)
return result

class Swapping(Generic[T]):
def deleter(self: named_property[C, T], fdel: Callable[[C], None]) -> named_property[C, T]:
result = super().deleter(fdel)
result.name = getattr(self, "name", None)
return result

def __copy__(self) -> named_property[C, T]:
result = type(self)(self.fget, self.fset, self.fdel, self.__doc__)
result.name = self.name
return result


class Pair(Generic[T]):
"""
Generic double container for swapping between two values.
Generic class representing a pair of values.
This is useful for double buffering in numerical algorithms.
The name of the pair attributes can be customized by defining new
descriptors in the subclasses.
Examples:
>>> a = Swapping(current=1, other=2)
Swapping(1, 2)
See the examples below.
>>> a.swap()
Swapping(current=2, other=1)
Examples:
>>> class MyPair(Pair[T]):
... a: T = Pair.first
... b: T = Pair.frozen_second
>>> pair = MyPair(1, 2)
>>> print(pair)
MyPair(a=1, b=2)
>>> a.current = 3
... a
Swapping(current=3, other=1)
>>> pair.swap()
MyPair(a=2, b=1)
>>> a != ~a
>>> pair == MyPair(2, 1)
True
>>> a == ~~a
True
>>> pair.a = -1
>>> pair[0]
-1
>>> a.current == (~a).other
True
>>> pair[0] = 100
>>> pair.a
100
>>> pair.b = 3
Traceback (most recent call last):
...
AttributeError: can't set attribute
>>> pair[1] = 3
Traceback (most recent call last):
...
AttributeError: can't set attribute
>>> b = ~a
... a.swap()
... a == b
True
"""

__slots__ = ("current", "_other", "__weakref__")
_FIRST_ACCESSOR_ID: Final = "FIRST"
_SECOND_ACCESSOR_ID: Final = "SECOND"

current: T
_other: T
__first_attr_name: ClassVar[str] = "first"
__second_attr_name: ClassVar[str] = "second"

@property
def other(self) -> T:
return self._other
def __init_subclass__(cls) -> None:
for key, value in {**cls.__dict__}.items():
if (attr_id := getattr(value, "_pair_accessor_id_", None)) is not None:
assert isinstance(value, named_property)
if key != value.name:
# If one of the original descriptors from the `Pair` class has been
# directly assigned to another class member with a different name
# (likely in a subclass) instead of creating a proper copy, it is
# copied and initialized here with the right name.
descriptor_copy = copy.copy(value)
descriptor_copy.name = key
setattr(cls, key, descriptor_copy)
if attr_id == Pair._FIRST_ACCESSOR_ID:
cls.__first_attr_name = key
elif attr_id == Pair._SECOND_ACCESSOR_ID:
cls.__second_attr_name = key
else:
raise TypeError(f"Invalid '{key}' pair accessor descriptor: {value}")

def __init__(self, current: T, other: T) -> None:
self.current = current
self._other = other
__first: T
__second: T

def __repr__(self) -> str:
return f"{self.__class__.__name__}(current={self.current!r}, other={self._other!r})"
def __init__(self, first: T, second: T, /) -> None:
self.__first = first
self.__second = second

@named_property
def first(self) -> T:
"""Property descriptor for the first element of the pair."""
return self.__first

@first.setter
def first(self, value: T) -> None:
self.__first = value

@named_property
def second(self) -> T:
"""Property descriptor for the second element of the pair."""
return self.__second

@second.setter
def second(self, value: T) -> None:
self.__second = value

@named_property
def frozen_first(self) -> T:
"""Read-only property descriptor for the first element of the pair (mainly for subclassing)."""
return self.__first

@named_property
def frozen_second(self) -> T:
"""Read-only property descriptor for the second element of the pair (mainly for subclassing)."""
return self.__second

first._pair_accessor_id_ = frozen_first._pair_accessor_id_ = _FIRST_ACCESSOR_ID
second._pair_accessor_id_ = frozen_second._pair_accessor_id_ = _SECOND_ACCESSOR_ID

def __eq__(self, other: object) -> bool:
return (
isinstance(other, Swapping)
and self.current == other.current
and self._other == other._other
return type(self) is type(other) and (
self.__first == other.__first and self.__second == other.__second
)

# `__hash__` is implicitly set to None when `__eq__` is redefined, so instances are not hashable.

def swap(self) -> None:
self.current, self._other = self._other, self.current
def __getitem__(self, index: Literal[0, 1]) -> T:
match index:
case 0:
return self.__first
case 1:
return self.__second
case _:
raise IndexError(f"Pair index out of range: {index}")

def __setitem__(self, index: Literal[0, 1], value: T) -> None:
# Go through the attribute descriptors to respect the read-only indication
match index:
case 0:
attr_name = self.__first_attr_name
case 1:
attr_name = self.__second_attr_name
case _:
raise IndexError(f"Pair index out of range: {index}")

def __invert__(self) -> Swapping[T]:
return type(self)(current=self._other, other=self.current)
setattr(self, attr_name, value)

def __iter__(self) -> Generator[T, None, None]:
yield self.__first
yield self.__second

def __repr__(self) -> str:
first_name = type(self).__first_attr_name
second_name = type(self).__second_attr_name
return f"{self.__class__.__name__}({first_name}={self.__first!r}, {second_name}={self.__second!r})"

def swap(self: Pair[T]) -> Pair[T]:
"""
Swap the values of the first and second attributes of the instance.
Returns:
The instance with swapped values (for fluent interfaces).
"""
self.__first, self.__second = self.__second, self.__first
return self

swap_buffers = swap


class NextStepPair(Pair[T]):
current: T = Pair.first
next: T = Pair.frozen_second


class PreviousStepPair(Pair[T]):
current: T = Pair.first
previous: T = Pair.frozen_second


P = ParamSpec("P")
T = TypeVar("T")


def chainable(method_fn: Callable[Concatenate[T, P], None]) -> Callable[Concatenate[T, P], T]:
"""
Make an instance method return the actual instance so it can be used in a chain of calls.
Typically used for simple fluent interfaces.
Examples:
>>> class A:
... @chainable
... def set_value(self, value: int) -> None:
... self.value = value
...
... @chainable
... def increment(self, value: int) -> None:
... self.value += value
>>> a = A()
>>> a.set_value(1).increment(2) # doctest:+ELLIPSIS
<__main__.A object at ...>
>>> a.value
3
"""

@functools.wraps(method_fn)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> T:
method_fn(self, *args, **kwargs)
return self

return wrapper

0 comments on commit 5820a8e

Please sign in to comment.