Skip to content

Commit

Permalink
feat: cache init signatures (#454)
Browse files Browse the repository at this point in the history
* feat: cache init signatures

* chore: `black .`

* chore: update comments

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
BobTheBuidler and github-actions[bot] authored Nov 28, 2024
1 parent 00f6605 commit 39f0497
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 60 deletions.
105 changes: 60 additions & 45 deletions a_sync/a_sync/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import functools
import inspect
from contextlib import suppress
from logging import DEBUG, getLogger
from libc.stdint cimport uintptr_t

from a_sync import exceptions
from a_sync._typing import *
Expand Down Expand Up @@ -57,10 +58,43 @@ class ASyncGenericBase(ASyncABC):
seamless usage in different contexts without changing the underlying implementation.
"""

@classmethod # type: ignore [misc]
def __a_sync_default_mode__(cls) -> bint: # type: ignore [override]
cdef object flag
cdef bint flag_value
if not c_logger.isEnabledFor(DEBUG):
# we can optimize this if we dont need to log `flag` and the return value
try:
flag = _get_a_sync_flag_name_from_signature(cls, False)
flag_value = _a_sync_flag_default_value_from_signature(cls)
except exceptions.NoFlagsFound:
flag = _get_a_sync_flag_name_from_class_def(cls)
flag_value = _get_a_sync_flag_value_from_class_def(cls, flag)
return validate_and_negate_if_necessary(flag, flag_value)

# we need an extra var so we can log it
cdef bint sync

try:
flag = _get_a_sync_flag_name_from_signature(cls, True)
flag_value = _a_sync_flag_default_value_from_signature(cls)
except exceptions.NoFlagsFound:
flag = _get_a_sync_flag_name_from_class_def(cls)
flag_value = _get_a_sync_flag_value_from_class_def(cls, flag)

sync = validate_and_negate_if_necessary(flag, flag_value)
c_logger._log(
DEBUG,
"`%s.%s` indicates default mode is %ssynchronous",
(cls, flag, "a" if sync is False else ""),
)
return sync

def __init__(self):
if type(self) is ASyncGenericBase:
raise NotImplementedError(
f"You should not create instances of `ASyncGenericBase` directly, you should subclass `ASyncGenericBase` instead."
"You should not create instances of `ASyncGenericBase` directly, "
"you should subclass `ASyncGenericBase` instead."
)
ASyncABC.__init__(self)

Expand All @@ -71,7 +105,7 @@ class ASyncGenericBase(ASyncABC):
if debug_logs := c_logger.isEnabledFor(DEBUG):
c_logger._log(DEBUG, "checking a_sync flag for %s", (self, ))
try:
flag = _get_a_sync_flag_name_from_signature(type(self))
flag = _get_a_sync_flag_name_from_signature(type(self), debug_logs)
except exceptions.ASyncFlagException:
# We can't get the flag name from the __init__ signature,
# but maybe the implementation sets the flag somewhere else.
Expand Down Expand Up @@ -101,38 +135,6 @@ class ASyncGenericBase(ASyncABC):
c_logger.debug("`%s.%s` is currently %s", self, flag, flag_value)
return validate_flag_value(flag, flag_value)

@classmethod # type: ignore [misc]
def __a_sync_default_mode__(cls) -> bint: # type: ignore [override]
cdef object flag
cdef bint flag_value
if not c_logger.isEnabledFor(DEBUG):
# we can optimize this if we dont need to log `flag` and the return value
try:
flag = _get_a_sync_flag_name_from_signature(cls)
flag_value = _a_sync_flag_default_value_from_signature(cls)
except exceptions.NoFlagsFound:
flag = _get_a_sync_flag_name_from_class_def(cls)
flag_value = _get_a_sync_flag_value_from_class_def(cls, flag)
return validate_and_negate_if_necessary(flag, flag_value)

# we need an extra var so we can log it
cdef bint sync

try:
flag = _get_a_sync_flag_name_from_signature(cls)
flag_value = _a_sync_flag_default_value_from_signature(cls)
except exceptions.NoFlagsFound:
flag = _get_a_sync_flag_name_from_class_def(cls)
flag_value = _get_a_sync_flag_value_from_class_def(cls, flag)

sync = validate_and_negate_if_necessary(flag, flag_value)
c_logger._log(
DEBUG,
"`%s.%s` indicates default mode is %ssynchronous",
(cls, flag, "a" if sync is False else ""),
)
return sync



cdef str _get_a_sync_flag_name_from_class_def(object cls):
Expand All @@ -149,15 +151,15 @@ cdef str _get_a_sync_flag_name_from_class_def(object cls):


cdef bint _a_sync_flag_default_value_from_signature(object cls):
cdef object signature = inspect.signature(cls.__init__)
cdef object signature = _get_init_signature(cls)
if not c_logger.isEnabledFor(DEBUG):
# we can optimize this much better
return signature.parameters[_get_a_sync_flag_name_from_signature(cls)].default
return signature.parameters[_get_a_sync_flag_name_from_signature(cls, False)].default

c_logger._log(
DEBUG, "checking `__init__` signature for default %s a_sync flag value", (cls, )
)
cdef str flag = _get_a_sync_flag_name_from_signature(cls)
cdef str flag = _get_a_sync_flag_name_from_signature(cls, True)
cdef object flag_value = signature.parameters[flag].default
if flag_value is inspect._empty: # type: ignore [attr-defined]
raise NotImplementedError(
Expand All @@ -167,20 +169,21 @@ cdef bint _a_sync_flag_default_value_from_signature(object cls):
return flag_value


cdef str _get_a_sync_flag_name_from_signature(object cls):
cdef str _get_a_sync_flag_name_from_signature(object cls, bint debug_logs):
if cls.__name__ == "ASyncGenericBase":
c_logger.debug(
"There are no flags defined on the base class, this is expected. Skipping."
)
return None
if debug_logs:
c_logger._log(
DEBUG, "There are no flags defined on the base class, this is expected. Skipping.", ()
)
return ""

# if we fail this one there's no need to check again
if not c_logger.isEnabledFor(DEBUG):
if not debug_logs:
# we can also skip assigning params to a var
return _parse_flag_name_from_list(cls, inspect.signature(cls.__init__).parameters)
return _parse_flag_name_from_list(cls, _get_init_signature(cls).parameters)

c_logger._log(DEBUG, "Searching for flags defined on %s.__init__", (cls, ))
cdef object parameters = inspect.signature(cls.__init__).parameters
cdef object parameters = _get_init_signature(cls).parameters
c_logger._log(DEBUG, "parameters: %s", (parameters, ))
return _parse_flag_name_from_list(cls, parameters)

Expand All @@ -207,3 +210,15 @@ cdef inline bint _get_a_sync_flag_value_from_class_def(object cls, str flag):
if flag in spec.__dict__:
return spec.__dict__[flag]
raise exceptions.FlagNotDefined(cls, flag)


cdef dict[uintptr_t, object] _init_signature_cache = {}


cdef _get_init_signature(object cls):
cdef uintptr_t cls_init_id = id(cls.__init__)
signature = _init_signature_cache.get(cls_init_id)
if signature is None:
signature = inspect.signature(cls.__init__)
_init_signature_cache[cls_init_id] = signature
return signature
18 changes: 14 additions & 4 deletions a_sync/a_sync/property.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ class ASyncCachedPropertyDescriptorSyncDefault(cached_property[I, T]):
This class is used for cached properties that are primarily intended to be
accessed synchronously but can also be used asynchronously if needed.
Note:
You should never create these yourself. They are automatically generated by ez-a-sync internally.
"""

default: Literal["sync"]
Expand Down Expand Up @@ -514,6 +517,9 @@ class ASyncCachedPropertyDescriptorAsyncDefault(cached_property[I, T]):
This class is used for cached properties that are primarily intended to be
accessed asynchronously but can also be used synchronously if needed.
Note:
You should never create these yourself. They are automatically generated by ez-a-sync internally.
"""

default: Literal["async"]
Expand Down Expand Up @@ -642,8 +648,10 @@ def a_sync_cached_property( # type: ignore [misc]
class HiddenMethod(ASyncBoundMethodAsyncDefault[I, Tuple[()], T]):
"""Represents a hidden method for asynchronous properties.
This class is used internally to manage hidden methods associated with
asynchronous properties.
This class is used internally to manage hidden getter methods associated with a/sync properties.
Note:
You should never create these yourself. They are automatically generated by ez-a-sync internally.
"""

def __init__(
Expand Down Expand Up @@ -689,8 +697,10 @@ class HiddenMethod(ASyncBoundMethodAsyncDefault[I, Tuple[()], T]):
class HiddenMethodDescriptor(ASyncMethodDescriptorAsyncDefault[I, Tuple[()], T]):
"""Descriptor for hidden methods associated with asynchronous properties.
This class is used internally to manage hidden methods associated with
asynchronous properties.
This class is used internally to manage hidden getter methods associated with a/sync properties.
Note:
You should never create these yourself. They are automatically generated by ez-a-sync internally.
"""

def __init__(
Expand Down
8 changes: 4 additions & 4 deletions a_sync/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ async def _debug_daemon(self, fut: asyncio.Future, fn, *args, **kwargs) -> None:

class AsyncProcessPoolExecutor(_AsyncExecutorMixin, concurrent.futures.ProcessPoolExecutor):
"""
A :class:`concurrent.futures.ProcessPoolExecutor' subclass providing asynchronous run and submit methods that support kwargs,
with support for synchronous mode
A :class:`concurrent.futures.ProcessPoolExecutor' subclass providing asynchronous
run and submit methods that support kwargs, with support for synchronous mode
Examples:
>>> executor = AsyncProcessPoolExecutor(max_workers=4)
Expand Down Expand Up @@ -231,8 +231,8 @@ def __init__(

class AsyncThreadPoolExecutor(_AsyncExecutorMixin, concurrent.futures.ThreadPoolExecutor):
"""
A :class:`concurrent.futures.ThreadPoolExecutor' subclass providing asynchronous run and submit methods that support kwargs,
with support for synchronous mode
A :class:`concurrent.futures.ThreadPoolExecutor' subclass providing asynchronous
run and submit methods that support kwargs, with support for synchronous mode
Examples:
>>> executor = AsyncThreadPoolExecutor(max_workers=10, thread_name_prefix="MyThread")
Expand Down
11 changes: 4 additions & 7 deletions a_sync/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,6 @@ def clear(self, cancel: bool = False) -> None:
"""# TODO write docs for this"""
if cancel and self._init_loader and not self._init_loader.done():
logger.debug("cancelling %s", self._init_loader)
# temporary, remove later
try:
raise Exception
except Exception as e:
logger.exception(e)
self._init_loader.cancel()
if keys := tuple(self.keys()):
logger.debug("popping remaining %s tasks", self)
Expand All @@ -515,9 +510,11 @@ def clear(self, cancel: bool = False) -> None:
def _init_loader(self) -> Optional["asyncio.Task[None]"]:
if self.__init_loader_coro:
logger.debug("starting %s init loader", self)
name = f"{type(self).__name__} init loader loading {self.__iterables__} for {self}"
try:
task = create_task(coro=self.__init_loader_coro, name=name)
task = create_task(
coro=self.__init_loader_coro,
name=f"{type(self).__name__} init loader loading {self.__iterables__} for {self}",
)
except RuntimeError as e:
raise _NoRunningLoop if str(e) == "no running event loop" else e
task.add_done_callback(self.__cleanup)
Expand Down

0 comments on commit 39f0497

Please sign in to comment.