From 8409c3cbcf0cab5d6758f150540b2069e250939d Mon Sep 17 00:00:00 2001 From: Frank Colson Date: Tue, 16 Jul 2024 00:25:31 -0600 Subject: [PATCH] Fix decorator --- docprompt/_decorators.py | 194 +++++++++++++++++++++++++-------- docprompt/utils/async_utils.py | 27 ----- tests/test_decorators.py | 105 ++++++++++++++++++ 3 files changed, 253 insertions(+), 73 deletions(-) delete mode 100644 docprompt/utils/async_utils.py diff --git a/docprompt/_decorators.py b/docprompt/_decorators.py index b5e9815..6921024 100644 --- a/docprompt/_decorators.py +++ b/docprompt/_decorators.py @@ -1,67 +1,169 @@ import asyncio -from functools import wraps -from typing import Callable, Tuple, Type +import sys +from functools import update_wrapper, wraps +from typing import Callable, Optional, Set, Tuple, Type -from docprompt.utils.async_utils import to_thread +if sys.version_info >= (3, 9): + to_thread = asyncio.to_thread +else: + def to_thread(func, /, *args, **kwargs): + @wraps(func) + async def wrapper(): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # If there's no running event loop, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return await loop.run_in_executor(None, func, *args, **kwargs) -def flexible_methods(*method_groups: Tuple[str, str]): - def decorator(cls: Type): - def get_method(cls: Type, name: str) -> Callable: - return cls.__dict__.get(name) + return wrapper() - def validate_method(name: str, method: Callable, expected_async: bool): - if method is None: - return - is_async = asyncio.iscoroutinefunction(method) - if is_async != expected_async: - return f"Method '{name}' in {cls.__name__} should be {'async' if expected_async else 'sync'}, but it's {'async' if is_async else 'sync'}" - return None - def apply_flexible_methods(cls: Type): - errors = [] +def get_closest_attr(cls: Type, attr_name: str) -> Tuple[Type, Optional[Callable], int]: + closest_cls = cls + attr = getattr(cls.__dict__, attr_name, None) + depth = 0 - is_abstract = getattr(getattr(cls, "Meta", None), "abstract", False) + if attr and hasattr(attr, "_original"): + attr = None + elif attr: + return (cls, attr, 0) - for group in method_groups: - if len(group) != 2: - errors.append( - f"Invalid method group {group}. Each group must be a tuple of exactly two method names." - ) - continue + for idx, base in enumerate(cls.__mro__, start=1): + if not attr and attr_name in base.__dict__: + if not hasattr(base.__dict__[attr_name], "_original"): + closest_cls = base + attr = base.__dict__[attr_name] + depth = idx - sync_name, async_name = group - sync_method = cls.__dict__.get(sync_name) - async_method = cls.__dict__.get(async_name) + if attr: + break - sync_error = validate_method(sync_name, sync_method, False) - if sync_error: - errors.append(sync_error) + return (closest_cls, attr, depth) - async_error = validate_method(async_name, async_method, True) - if async_error: - errors.append(async_error) - if not sync_method and not async_method and not is_abstract: - errors.append( - f"{cls.__name__} must implement at least one of these methods: {sync_name}, {async_name}" - ) +def validate_method(cls, name: str, method: Callable, expected_async: bool): + if method is None: + return None + is_async = asyncio.iscoroutinefunction(method) + if is_async != expected_async: + return f"Method '{name}' in {cls.__name__} should be {'async' if expected_async else 'sync'}, but it's {'async' if is_async else 'sync'}" + + return None + + +def apply_dual_methods_to_cls(cls: Type, method_group: Tuple[str, str]): + errors = [] + + sync_name, async_name = method_group + + sync_trace = get_closest_attr(cls, sync_name) + async_trace = get_closest_attr(cls, async_name) + + sync_cls, sync_method, sync_depth = sync_trace + async_cls, async_method, async_depth = async_trace + + if sync_method: + sync_error = validate_method(cls, sync_name, sync_method, False) + if sync_error: + errors.append(sync_error) + + if async_method: + async_error = validate_method(cls, async_name, async_method, True) + if async_error: + errors.append(async_error) + + if ( + sync_method is None + and async_method is None + and not getattr(getattr(cls, "Meta", None), "abstract", False) + ): + return [ + f"{cls.__name__} must implement at least one of these methods: {sync_name}, {async_name}" + ] + + if sync_cls is cls and async_cls is cls and sync_method and async_method: + return errors # Both methods are already in the same class + + if async_cls is cls and async_method: + + def sync_wrapper(*args, **kwargs): + return asyncio.run(async_method(*args, **kwargs)) + + update_wrapper(sync_wrapper, async_method) + + sync_wrapper._original = async_method + + setattr(cls, sync_name, sync_wrapper) + elif sync_cls is cls and sync_method: - if sync_method and not async_method: + async def async_wrapper(*args, **kwargs): + if hasattr(sync_method, "__func__"): + return await to_thread(sync_method, *args, **kwargs) + return await to_thread(sync_method, *args, **kwargs) - @wraps(sync_method) - async def async_wrapper(*args, **kwargs): - return await to_thread(sync_method, *args, **kwargs) + update_wrapper(async_wrapper, sync_method) - setattr(cls, async_name, async_wrapper) + async_wrapper._original = sync_method - elif async_method and not sync_method: + setattr(cls, async_name, async_wrapper) + else: + if async_depth < sync_depth: - @wraps(async_method) - def sync_wrapper(*args, **kwargs): - return asyncio.run(async_method(*args, **kwargs)) + def sync_wrapper(*args, **kwargs): + return asyncio.run(async_method(*args, **kwargs)) + + update_wrapper(sync_wrapper, async_method) + + sync_wrapper._original = async_method + + setattr(cls, sync_name, sync_wrapper) + else: + + async def async_wrapper(*args, **kwargs): + return await to_thread(sync_method, *args, **kwargs) + + update_wrapper(async_wrapper, sync_method) + + async_wrapper._original = sync_method + + setattr(cls, async_name, async_wrapper) + + return errors + + +def get_flexible_method_configs(cls: Type) -> Set[Tuple[str, str]]: + all = set() + for base in cls.__mro__: + all.update(getattr(base, "__flexible_methods__", set())) + + return all + + +def flexible_methods(*method_groups: Tuple[str, str]): + def decorator(cls: Type): + if not hasattr(cls, "__flexible_methods__"): + setattr(cls, "__flexible_methods__", set()) + + for base in cls.__bases__: + if hasattr(base, "__flexible_methods__"): + cls.__flexible_methods__.update(base.__flexible_methods__) + + cls.__flexible_methods__.update(method_groups) + + def apply_flexible_methods(cls: Type): + errors = [] + + for group in get_flexible_method_configs(cls): + if len(group) != 2: + errors.append( + f"Invalid method group {group}. Each group must be a tuple of exactly two method names." + ) + continue - setattr(cls, sync_name, sync_wrapper) + errors.extend(apply_dual_methods_to_cls(cls, group)) if errors: raise TypeError("\n".join(errors)) diff --git a/docprompt/utils/async_utils.py b/docprompt/utils/async_utils.py deleted file mode 100644 index 0427189..0000000 --- a/docprompt/utils/async_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -import asyncio -import sys -from functools import wraps - - -def get_to_thread(): - if sys.version_info >= (3, 9): - return asyncio.to_thread - else: - - def to_thread(func, /, *args, **kwargs): - @wraps(func) - async def wrapper(): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - # If there's no running event loop, create a new one - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return await loop.run_in_executor(None, func, *args, **kwargs) - - return wrapper() - - return to_thread - - -to_thread = get_to_thread() diff --git a/tests/test_decorators.py b/tests/test_decorators.py index e1bf947..09eaf4f 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -236,3 +236,108 @@ def method(self): assert child2.method() == "child2_sync" assert run_async(child2.method_async()) == "child2_sync" + + +def test_abstract_base_class(): + from abc import ABC, abstractmethod + + @flexible_methods(("abstract_method", "abstract_method_async")) + class AbstractBase(ABC): + @abstractmethod + def abstract_method(self): + pass + + class ConcreteSync(AbstractBase): + def abstract_method(self): + return "concrete_sync" + + class ConcreteAsync(AbstractBase): + async def abstract_method_async(self): + return "concrete_async" + + with pytest.raises(TypeError): + AbstractBase() + + sync_instance = ConcreteSync() + assert sync_instance.abstract_method() == "concrete_sync" + assert asyncio.run(sync_instance.abstract_method_async()) == "concrete_sync" + + async_instance = ConcreteAsync() + assert async_instance.abstract_method() == "concrete_async" + assert asyncio.run(async_instance.abstract_method_async()) == "concrete_async" + + +def test_multiple_inheritance(): + @flexible_methods(("method1", "method1_async")) + class Base1: + def method1(self): + return "base1" + + @flexible_methods(("method2", "method2_async")) + class Base2: + async def method2_async(self): + return "base2" + + class Child(Base1, Base2): + async def method1_async(self): + return "child1" + + def method2(self): + return "child2" + + child = Child() + assert child.method1() == "child1" + assert asyncio.run(child.method1_async()) == "child1" + assert child.method2() == "child2" + assert asyncio.run(child.method2_async()) == "child2" + + # Test that Base1 and Base2 methods are not affected + base1 = Base1() + base2 = Base2() + assert base1.method1() == "base1" + assert asyncio.run(base1.method1_async()) == "base1" + assert asyncio.run(base2.method2_async()) == "base2" + assert base2.method2() == "base2" + + +def test_preserve_signature_and_docstring(run_async): + @flexible_methods(("method", "method_async")) + class PreserveMetadata: + def method(self, arg1: int, arg2: str = "default") -> str: + """This is a test method.""" + return f"{arg1} {arg2}" + + instance = PreserveMetadata() + assert instance.method.__doc__ == "This is a test method." + assert instance.method.__annotations__ == {"arg1": int, "arg2": str, "return": str} + assert instance.method_async.__doc__ == "This is a test method." + assert instance.method_async.__annotations__ == { + "arg1": int, + "arg2": str, + "return": str, + } + + assert instance.method(1, "test") == "1 test" + assert run_async(instance.method_async(2, "async")) == "2 async" + + +@pytest.mark.skip(reason="Not implemented yet") +def test_static_methods(): + @flexible_methods( + ("class_method", "class_method_async"), ("static_method", "static_method_async") + ) + class MethodTypes: + @classmethod + def class_method(cls): + return f"class {cls.__name__}" + + @staticmethod + def static_method(): + return "static" + + assert MethodTypes.static_method() == "static" + assert asyncio.run(MethodTypes.static_method_async()) == "static" + + instance = MethodTypes() + assert instance.static_method() == "static" + assert asyncio.run(instance.static_method_async()) == "static"