Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enhance common model utils with classes for implementing double buffering #602

Merged
merged 9 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions model/common/src/icon4py/model/common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,26 @@

from __future__ import annotations

from ._common import Swapping, chainable
from . import gt4py_field_allocation
from ._common import (
DoubleBuffering,
NextStepPair,
Pair,
PreviousStepPair,
chainable,
named_property,
)


__all__ = [
# Classes
"DoubleBuffering",
"NextStepPair",
"Pair",
"PreviousStepPair",
"named_property",
# Functions
"chainable",
"Swapping",
# Modules
"gt4py_field_allocation",
]
308 changes: 243 additions & 65 deletions model/common/src/icon4py/model/common/utils/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,110 +9,288 @@

from __future__ import annotations

import abc
import copy
import functools
from collections.abc import Callable
from typing import Concatenate, Generic, ParamSpec, TypeVar
from typing import (
ClassVar,
Concatenate,
Final,
Generator,
Generic,
Literal,
ParamSpec,
Protocol,
TypeVar,
runtime_checkable,
)


__all__ = [
"chainable",
"Swapping",
]
@runtime_checkable
class DoubleBuffering(Protocol):
@abc.abstractmethod
def swap_buffers(self) -> None:
...


P = ParamSpec("P")
T = TypeVar("T")
C = TypeVar("C")


def chainable(method_fn: Callable[Concatenate[T, P], None]) -> Callable[Concatenate[T, P], T]:
class named_property(property, Generic[C, T]):
"""
Make an instance method return the actual instance so it can used in a chain of calls.
A simple extension of the built-in `property` descriptor storing
the name of the attribute it is assigned to.

Typically used for simple fluent interfaces.
The name is stored in the `name` attribute of the property instance.

Examples:
>>> class A:
... @chainable
... def set_value(self, value: int) -> None:
... self.value = value
... @named_property
... def value(self) -> int:
... return self._value
...
... @chainable
... def increment(self, value: int) -> None:
... self.value += value
...
...
... a = A()
... a.set_value(1).increment(2)
... a.value
3
... @value.setter
... def value(self, value: int) -> None:
... self._value = value
>>> a = A()
>>> a.value = 1
>>> print(A.value.name)
value
"""

@functools.wraps(method_fn)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> T:
method_fn(self, *args, **kwargs)
return self
name: str | None = None

return wrapper
def __set_name__(self, owner: C, name: str) -> None:
"""
Set the name of the attribute the property is assigned to.

Note that if the same descriptor is assigned to multiple attributes
or if it is reused in other classes, only the first assigned name
is stored.
"""
if self.name is None:
self.name = name

def getter(self: named_property[C, T], fget: Callable[[C], T]) -> named_property[C, T]:
result = super().getter(fget)
result.name = getattr(self, "name", None)
return result

def setter(self: named_property[C, T], fset: Callable[[C, T], None]) -> named_property[C, T]:
result = super().setter(fset)
result.name = getattr(self, "name", None)
return result

class Swapping(Generic[T]):
def deleter(self: named_property[C, T], fdel: Callable[[C], None]) -> named_property[C, T]:
result = super().deleter(fdel)
result.name = getattr(self, "name", None)
return result

def __copy__(self) -> named_property[C, T]:
result = type(self)(self.fget, self.fset, self.fdel, self.__doc__)
result.name = self.name
return result


class Pair(Generic[T]):
"""
Generic double container for swapping between two values.
Generic class representing a pair of values.

This is useful for double buffering in numerical algorithms.
The name of the pair attributes can be customized by defining new
descriptors in the subclasses.

Examples:
>>> a = Swapping(current=1, other=2)
Swapping(1, 2)
See the examples below.

>>> a.swap()
Swapping(current=2, other=1)
Examples:
>>> class MyPair(Pair[T]):
... a: T = Pair.first
... b: T = Pair.frozen_second
>>> pair = MyPair(1, 2)
>>> print(pair)
MyPair(a=1, b=2)

>>> a.current = 3
... a
Swapping(current=3, other=1)
>>> pair.swap()
MyPair(a=2, b=1)

>>> a != ~a
>>> pair == MyPair(2, 1)
True

>>> a == ~~a
True
>>> pair.a = -1
>>> pair[0]
-1

>>> a.current == (~a).other
True
>>> pair[0] = 100
>>> pair.a
100

>>> pair.b = 3
Traceback (most recent call last):
...
AttributeError: can't set attribute

>>> pair[1] = 3
Traceback (most recent call last):
...
AttributeError: can't set attribute

>>> b = ~a
... a.swap()
... a == b
True
"""

__slots__ = ("current", "_other", "__weakref__")
_FIRST_ACCESSOR_ID: Final = "FIRST"
_SECOND_ACCESSOR_ID: Final = "SECOND"

current: T
_other: T
__first_attr_name: ClassVar[str] = "first"
__second_attr_name: ClassVar[str] = "second"

@property
def other(self) -> T:
return self._other
def __init_subclass__(cls) -> None:
for key, value in {**cls.__dict__}.items():
if (attr_id := getattr(value, "_pair_accessor_id_", None)) is not None:
assert isinstance(value, named_property)
if key != value.name:
# If one of the original descriptors from the `Pair` class has been
# directly assigned to another class member with a different name
# (likely in a subclass) instead of creating a proper copy, it is
# copied and initialized here with the right name.
descriptor_copy = copy.copy(value)
descriptor_copy.name = key
setattr(cls, key, descriptor_copy)
if attr_id == Pair._FIRST_ACCESSOR_ID:
cls.__first_attr_name = key
elif attr_id == Pair._SECOND_ACCESSOR_ID:
cls.__second_attr_name = key
else:
raise TypeError(f"Invalid '{key}' pair accessor descriptor: {value}")

def __init__(self, current: T, other: T) -> None:
self.current = current
self._other = other
__first: T
__second: T

def __repr__(self) -> str:
return f"{self.__class__.__name__}(current={self.current!r}, other={self._other!r})"
def __init__(self, first: T, second: T, /) -> None:
self.__first = first
self.__second = second

@named_property
def first(self) -> T:
"""Property descriptor for the first element of the pair."""
return self.__first

@first.setter
def first(self, value: T) -> None:
self.__first = value

@named_property
def second(self) -> T:
"""Property descriptor for the second element of the pair."""
return self.__second

@second.setter
def second(self, value: T) -> None:
self.__second = value

@named_property
def frozen_first(self) -> T:
"""Read-only property descriptor for the first element of the pair (mainly for subclassing)."""
return self.__first

@named_property
def frozen_second(self) -> T:
"""Read-only property descriptor for the second element of the pair (mainly for subclassing)."""
return self.__second

first._pair_accessor_id_ = frozen_first._pair_accessor_id_ = _FIRST_ACCESSOR_ID
second._pair_accessor_id_ = frozen_second._pair_accessor_id_ = _SECOND_ACCESSOR_ID

def __eq__(self, other: object) -> bool:
return (
isinstance(other, Swapping)
and self.current == other.current
and self._other == other._other
return type(self) is type(other) and (
self.__first == other.__first and self.__second == other.__second
)

# `__hash__` is implicitly set to None when `__eq__` is redefined, so instances are not hashable.

def swap(self) -> None:
self.current, self._other = self._other, self.current
def __getitem__(self, index: Literal[0, 1]) -> T:
match index:
case 0:
return self.__first
case 1:
return self.__second
case _:
raise IndexError(f"Pair index out of range: {index}")

def __setitem__(self, index: Literal[0, 1], value: T) -> None:
# Go through the attribute descriptors to respect the read-only indication
match index:
case 0:
attr_name = self.__first_attr_name
case 1:
attr_name = self.__second_attr_name
case _:
raise IndexError(f"Pair index out of range: {index}")

def __invert__(self) -> Swapping[T]:
return type(self)(current=self._other, other=self.current)
setattr(self, attr_name, value)

def __iter__(self) -> Generator[T, None, None]:
yield self.__first
yield self.__second

def __repr__(self) -> str:
first_name = type(self).__first_attr_name
second_name = type(self).__second_attr_name
return f"{self.__class__.__name__}({first_name}={self.__first!r}, {second_name}={self.__second!r})"

def swap(self: Pair[T]) -> Pair[T]:
"""
Swap the values of the first and second attributes of the instance.

Returns:
The instance with swapped values (for fluent interfaces).
"""
self.__first, self.__second = self.__second, self.__first
return self

swap_buffers = swap


class NextStepPair(Pair[T]):
current: T = Pair.first
halungge marked this conversation as resolved.
Show resolved Hide resolved
next: T = Pair.frozen_second


class PreviousStepPair(Pair[T]):
current: T = Pair.first
previous: T = Pair.frozen_second

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So where for example in the SolveNonhydro would you use those: in that if pattern where the do no swapping I assume you would use any of those. But does it need both? Or you are adding them only as an example.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NextStepPair will be used for the double buffering of the prognostic state and PreviousStepPair is just an example and thus it could be deleted.


P = ParamSpec("P")
T = TypeVar("T")


def chainable(method_fn: Callable[Concatenate[T, P], None]) -> Callable[Concatenate[T, P], T]:
"""
Make an instance method return the actual instance so it can be used in a chain of calls.

Typically used for simple fluent interfaces.

Examples:
>>> class A:
... @chainable
... def set_value(self, value: int) -> None:
... self.value = value
...
... @chainable
... def increment(self, value: int) -> None:
... self.value += value

>>> a = A()
>>> a.set_value(1).increment(2) # doctest:+ELLIPSIS
<__main__.A object at ...>

>>> a.value
3
"""

@functools.wraps(method_fn)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> T:
method_fn(self, *args, **kwargs)
return self

return wrapper
Loading