Skip to content

Commit

Permalink
feat(annots): support marking classes as singletons
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Sep 11, 2024
1 parent 9fc7aee commit 16ffcf8
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 36 deletions.
93 changes: 64 additions & 29 deletions koerce/annots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Mapping, Sequence
from types import FunctionType, MethodType
from typing import Any, ClassVar, Optional
from weakref import WeakValueDictionary

import cython

Expand Down Expand Up @@ -541,7 +542,7 @@ def varkwargs(pattern=_any, typehint=EMPTY):
return Parameter(kind=_VAR_KEYWORD, pattern=pattern, typehint=typehint)


__create__ = cython.declare(object, type.__call__)
__type_call__ = cython.declare(object, type.__call__)
if cython.compiled:
from cython.cimports.cpython.object import PyObject_GenericSetAttr as __setattr__
else:
Expand All @@ -555,6 +556,7 @@ class AnnotableSpec:
initable = cython.declare(cython.bint, visibility="readonly")
hashable = cython.declare(cython.bint, visibility="readonly")
immutable = cython.declare(cython.bint, visibility="readonly")
singleton = cython.declare(cython.bint, visibility="readonly")
signature = cython.declare(Signature, visibility="readonly")
attributes = cython.declare(dict[str, Attribute], visibility="readonly")
hasattribs = cython.declare(cython.bint, visibility="readonly")
Expand All @@ -564,44 +566,66 @@ def __init__(
initable: bool,
hashable: bool,
immutable: bool,
singleton: bool,
signature: Signature,
attributes: dict[str, Attribute],
):
self.initable = initable
self.hashable = hashable
self.immutable = immutable
self.singleton = singleton
self.signature = signature
self.attributes = attributes
self.hasattribs = bool(attributes)

@cython.cfunc
@cython.inline
def new(self, cls: type, args: tuple[Any, ...], kwargs: dict[str, Any]):
ctx: dict[str, Any] = {}
bound: dict[str, Any]
param: Parameter

if not args and len(kwargs) == self.signature.length:
bound = kwargs
else:
bound = self.signature.bind(args, kwargs)

if self.initable:
# slow initialization calling __init__
for name, param in self.signature.parameters.items():
bound[name] = param.pattern.match(bound[name], ctx)
return __create__(cls, **bound)
if self.singleton or self.initable:
return self.new_slow(cls, bound)
else:
# fast initialization directly setting the arguments
this = cls.__new__(cls)
for name, param in self.signature.parameters.items():
__setattr__(this, name, param.pattern.match(bound[name], ctx))
# TODO(kszucs): test order ot precomputes and attributes calculations
if self.hashable:
self.init_precomputes(this)
if self.hasattribs:
self.init_attributes(this)
return this
return self.new_fast(cls, bound)

@cython.cfunc
@cython.inline
def new_slow(self, cls: type, bound: dict[str, Any]):
# slow initialization calling __init__
ctx: dict[str, Any] = {}
param: Parameter
for name, param in self.signature.parameters.items():
bound[name] = param.pattern.match(bound[name], ctx)

if self.singleton:
key = (cls, *bound.items())
try:
return cls.__instances__[key]
except KeyError:
this = __type_call__(cls, **bound)
cls.__instances__[key] = this
return this

return __type_call__(cls, **bound)

@cython.cfunc
@cython.inline
def new_fast(self, cls: type, bound: dict[str, Any]):
# fast initialization directly setting the arguments
ctx: dict[str, Any] = {}
param: Parameter
this = cls.__new__(cls)
for name, param in self.signature.parameters.items():
__setattr__(this, name, param.pattern.match(bound[name], ctx))
if self.hashable:
self.init_precomputes(this)
if self.hasattribs:
self.init_attributes(this)
return this

@cython.cfunc
@cython.inline
Expand All @@ -627,8 +651,7 @@ def init_precomputes(self, this) -> cython.void:
class AbstractMeta(type):
"""Base metaclass for many of the ibis core classes.
Enforce the subclasses to define a `__slots__` attribute and provide a
`__create__` classmethod to change the instantiation behavior of the class.
Enforce the subclasses to define a `__slots__` attribute.
Support abstract methods without extending `abc.ABCMeta`. While it provides
a reduced feature set compared to `abc.ABCMeta` (no way to register virtual
Expand All @@ -639,8 +662,8 @@ class AbstractMeta(type):
__slots__ = ()

def __new__(metacls, clsname, bases, dct, **kwargs):
# # enforce slot definitions
# dct.setdefault("__slots__", ())
# enforce slot definitions
dct.setdefault("__slots__", ())

# construct the class object
cls = super().__new__(metacls, clsname, bases, dct, **kwargs)
Expand All @@ -663,6 +686,10 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
return cls


class Abstract(metaclass=AbstractMeta):
"""Base class for many of the ibis core classes, see `AbstractMeta`."""


class AnnotableMeta(AbstractMeta):
def __new__(
metacls,
Expand All @@ -672,6 +699,7 @@ def __new__(
initable=None,
hashable=None,
immutable=None,
singleton=False,
allow_coercion=True,
**kwargs,
):
Expand All @@ -682,6 +710,7 @@ def __new__(
is_initable: cython.bint
is_hashable: cython.bint = hashable is True
is_immutable: cython.bint = immutable is True
is_singleton: cython.bint = singleton is True
if initable is None:
is_initable = "__init__" in dct or "__new__" in dct
else:
Expand Down Expand Up @@ -713,6 +742,8 @@ def __new__(
traits.append(Hashable)
if immutable:
traits.append(Immutable)
if singleton:
traits.append(Singleton)

# collect type annotations and convert them to patterns
slots: list[str] = list(dct.pop("__slots__", []))
Expand Down Expand Up @@ -757,6 +788,7 @@ def __new__(
spec = AnnotableSpec(
initable=is_initable,
hashable=is_hashable,
singleton=is_singleton,
immutable=is_immutable,
signature=signature,
attributes=attributes,
Expand All @@ -778,9 +810,14 @@ def __call__(cls, *args, **kwargs):
return spec.new(cython.cast(type, cls), args, kwargs)


class Immutable:
__slots__ = ()
class Singleton(Abstract):
"""Cache instances of the class based on instantiation arguments."""

__instances__: Mapping[Any, Self] = WeakValueDictionary()
__slots__ = ("__weakref__",)


class Immutable(Abstract):
def __copy__(self):
return self

Expand All @@ -794,7 +831,7 @@ def __setattr__(self, name: str, _: Any) -> None:
)


class Hashable:
class Hashable(Abstract):
__slots__ = ("__args__", "__precomputed_hash__")

def __hash__(self) -> int:
Expand All @@ -809,13 +846,11 @@ def __eq__(self, other) -> bool:
)


class Annotable(metaclass=AnnotableMeta, initable=False):
class Annotable(Abstract, metaclass=AnnotableMeta, initable=False):
__argnames__: ClassVar[tuple[str, ...]]
__match_args__: ClassVar[tuple[str, ...]]
__signature__: ClassVar[Signature]

__slots__ = ("__weakref__",)

def __init__(self, **kwargs):
spec: AnnotableSpec = self.__spec__
for name, value in kwargs.items():
Expand Down
85 changes: 78 additions & 7 deletions koerce/tests/test_annots.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import pytest
from typing_extensions import Self

from koerce._internal import (
from koerce import (
EMPTY,
Abstract,
AbstractMeta,
Annotable,
AnnotableMeta,
Expand Down Expand Up @@ -1269,7 +1270,7 @@ class AnnImm(Annotable, immutable=True):
lower = optional(is_int, default=0)
upper = optional(is_int, default=None)

assert AnnImm.__mro__ == (AnnImm, Immutable, Annotable, object)
assert AnnImm.__mro__ == (AnnImm, Immutable, Annotable, Abstract, object)

obj = AnnImm(3, lower=0, upper=4)
with pytest.raises(AttributeError):
Expand Down Expand Up @@ -1851,6 +1852,7 @@ def test_hashable():
Hashable,
Immutable,
Annotable,
Abstract,
object,
)

Expand Down Expand Up @@ -1878,10 +1880,6 @@ def test_hashable():
# hashable
assert {obj: 1}.get(obj) == 1

# weakrefable
ref = weakref.ref(obj)
assert ref() == obj

# serializable
assert pickle.loads(pickle.dumps(obj)) == obj

Expand Down Expand Up @@ -1954,7 +1952,7 @@ class Example(Annotable):


def test_abstract_meta():
class Foo(metaclass=AbstractMeta):
class Foo(Abstract):
@abstractmethod
def foo(self): ...

Expand Down Expand Up @@ -2154,3 +2152,76 @@ class User(Annotable):
assert User.__spec__.initable is False
assert User.__spec__.immutable is False
assert User.__spec__.hashable is False


def test_arg_and_hash_precomputed_before_attributes():
class Frozen(Annotable, immutable=True, hashable=True):
arg: int

@attribute
def a(self):
assert self.__args__ == (1,)
assert isinstance(self.__precomputed_hash__, int)
return "ok"

assert Frozen(1).a == "ok"


class OneAndOnly(Annotable, singleton=True):
__instances__ = weakref.WeakValueDictionary()


class DataType(Annotable, singleton=True):
__instances__ = weakref.WeakValueDictionary()
nullable: bool = True


def test_singleton_basics():
one = OneAndOnly()
only = OneAndOnly()
assert one is only

assert len(OneAndOnly.__instances__) == 1
key = (OneAndOnly,)
assert OneAndOnly.__instances__[key] is one


def test_singleton_lifetime() -> None:
one = OneAndOnly()
assert len(OneAndOnly.__instances__) == 1

del one
assert len(OneAndOnly.__instances__) == 0


def test_singleton_with_argument() -> None:
dt1 = DataType(nullable=True)
dt2 = DataType(nullable=False)
dt3 = DataType(nullable=True)

assert dt1 is dt3
assert dt1 is not dt2
assert len(DataType.__instances__) == 2

del dt3
assert len(DataType.__instances__) == 2
del dt1
assert len(DataType.__instances__) == 1
del dt2
assert len(DataType.__instances__) == 0


def test_singleton_looked_after_validation() -> None:
class Single(Annotable, singleton=True):
value: As[int]

# arguments looked up after validation
obj1 = Single("1")
obj2 = Single(2)
assert Single("1") is obj1
assert Single(1) is obj1
assert Single(1.0) is obj1
assert Single(2) is obj2
assert Single("2") is obj2
assert obj2 is not obj1
assert Single("3") is Single(3.0)
3 changes: 3 additions & 0 deletions koerce/tests/test_y.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ def test_msgspec(benchmark):


def test_annotated(benchmark):
assert KUser.__spec__.initable is False
assert KUser.__spec__.singleton is False

r2 = benchmark.pedantic(
KUser,
args=(),
Expand Down
13 changes: 13 additions & 0 deletions koerce/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,16 @@ def rewind(self):
def checkpoint(self):
"""Create a checkpoint of the current iterator state."""
self._iterator, self._checkpoint = itertools.tee(self._iterator)


# def format_typehint(typ: Any) -> str:
# if isinstance(typ, type):
# return typ.__name__
# elif isinstance(typ, TypeVar):
# if typ.__bound__ is None:
# return str(typ)
# else:
# return format_typehint(typ.__bound__)
# else:
# # remove the module name from the typehint, including generics
# return re.sub(r"(\w+\.)+", "", str(typ))

0 comments on commit 16ffcf8

Please sign in to comment.