Skip to content

Commit

Permalink
Fix decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
PSU3D0 committed Jul 16, 2024
1 parent b396f1d commit 8409c3c
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 73 deletions.
194 changes: 148 additions & 46 deletions docprompt/_decorators.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,169 @@
import asyncio
from functools import wraps
from typing import Callable, Tuple, Type
import sys
from functools import update_wrapper, wraps
from typing import Callable, Optional, Set, Tuple, Type

from docprompt.utils.async_utils import to_thread
if sys.version_info >= (3, 9):
to_thread = asyncio.to_thread
else:

def to_thread(func, /, *args, **kwargs):
@wraps(func)
async def wrapper():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# If there's no running event loop, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return await loop.run_in_executor(None, func, *args, **kwargs)

def flexible_methods(*method_groups: Tuple[str, str]):
def decorator(cls: Type):
def get_method(cls: Type, name: str) -> Callable:
return cls.__dict__.get(name)
return wrapper()

def validate_method(name: str, method: Callable, expected_async: bool):
if method is None:
return
is_async = asyncio.iscoroutinefunction(method)
if is_async != expected_async:
return f"Method '{name}' in {cls.__name__} should be {'async' if expected_async else 'sync'}, but it's {'async' if is_async else 'sync'}"
return None

def apply_flexible_methods(cls: Type):
errors = []
def get_closest_attr(cls: Type, attr_name: str) -> Tuple[Type, Optional[Callable], int]:
closest_cls = cls
attr = getattr(cls.__dict__, attr_name, None)
depth = 0

is_abstract = getattr(getattr(cls, "Meta", None), "abstract", False)
if attr and hasattr(attr, "_original"):
attr = None
elif attr:
return (cls, attr, 0)

for group in method_groups:
if len(group) != 2:
errors.append(
f"Invalid method group {group}. Each group must be a tuple of exactly two method names."
)
continue
for idx, base in enumerate(cls.__mro__, start=1):
if not attr and attr_name in base.__dict__:
if not hasattr(base.__dict__[attr_name], "_original"):
closest_cls = base
attr = base.__dict__[attr_name]
depth = idx

sync_name, async_name = group
sync_method = cls.__dict__.get(sync_name)
async_method = cls.__dict__.get(async_name)
if attr:
break

sync_error = validate_method(sync_name, sync_method, False)
if sync_error:
errors.append(sync_error)
return (closest_cls, attr, depth)

async_error = validate_method(async_name, async_method, True)
if async_error:
errors.append(async_error)

if not sync_method and not async_method and not is_abstract:
errors.append(
f"{cls.__name__} must implement at least one of these methods: {sync_name}, {async_name}"
)
def validate_method(cls, name: str, method: Callable, expected_async: bool):
if method is None:
return None
is_async = asyncio.iscoroutinefunction(method)
if is_async != expected_async:
return f"Method '{name}' in {cls.__name__} should be {'async' if expected_async else 'sync'}, but it's {'async' if is_async else 'sync'}"

return None


def apply_dual_methods_to_cls(cls: Type, method_group: Tuple[str, str]):
errors = []

sync_name, async_name = method_group

sync_trace = get_closest_attr(cls, sync_name)
async_trace = get_closest_attr(cls, async_name)

sync_cls, sync_method, sync_depth = sync_trace
async_cls, async_method, async_depth = async_trace

if sync_method:
sync_error = validate_method(cls, sync_name, sync_method, False)
if sync_error:
errors.append(sync_error)

if async_method:
async_error = validate_method(cls, async_name, async_method, True)
if async_error:
errors.append(async_error)

if (
sync_method is None
and async_method is None
and not getattr(getattr(cls, "Meta", None), "abstract", False)
):
return [
f"{cls.__name__} must implement at least one of these methods: {sync_name}, {async_name}"
]

if sync_cls is cls and async_cls is cls and sync_method and async_method:
return errors # Both methods are already in the same class

if async_cls is cls and async_method:

def sync_wrapper(*args, **kwargs):
return asyncio.run(async_method(*args, **kwargs))

update_wrapper(sync_wrapper, async_method)

sync_wrapper._original = async_method

setattr(cls, sync_name, sync_wrapper)
elif sync_cls is cls and sync_method:

if sync_method and not async_method:
async def async_wrapper(*args, **kwargs):
if hasattr(sync_method, "__func__"):
return await to_thread(sync_method, *args, **kwargs)
return await to_thread(sync_method, *args, **kwargs)

@wraps(sync_method)
async def async_wrapper(*args, **kwargs):
return await to_thread(sync_method, *args, **kwargs)
update_wrapper(async_wrapper, sync_method)

setattr(cls, async_name, async_wrapper)
async_wrapper._original = sync_method

elif async_method and not sync_method:
setattr(cls, async_name, async_wrapper)
else:
if async_depth < sync_depth:

@wraps(async_method)
def sync_wrapper(*args, **kwargs):
return asyncio.run(async_method(*args, **kwargs))
def sync_wrapper(*args, **kwargs):
return asyncio.run(async_method(*args, **kwargs))

update_wrapper(sync_wrapper, async_method)

sync_wrapper._original = async_method

setattr(cls, sync_name, sync_wrapper)
else:

async def async_wrapper(*args, **kwargs):
return await to_thread(sync_method, *args, **kwargs)

update_wrapper(async_wrapper, sync_method)

async_wrapper._original = sync_method

setattr(cls, async_name, async_wrapper)

return errors


def get_flexible_method_configs(cls: Type) -> Set[Tuple[str, str]]:
all = set()
for base in cls.__mro__:
all.update(getattr(base, "__flexible_methods__", set()))

return all


def flexible_methods(*method_groups: Tuple[str, str]):
def decorator(cls: Type):
if not hasattr(cls, "__flexible_methods__"):
setattr(cls, "__flexible_methods__", set())

for base in cls.__bases__:
if hasattr(base, "__flexible_methods__"):
cls.__flexible_methods__.update(base.__flexible_methods__)

cls.__flexible_methods__.update(method_groups)

def apply_flexible_methods(cls: Type):
errors = []

for group in get_flexible_method_configs(cls):
if len(group) != 2:
errors.append(
f"Invalid method group {group}. Each group must be a tuple of exactly two method names."
)
continue

setattr(cls, sync_name, sync_wrapper)
errors.extend(apply_dual_methods_to_cls(cls, group))

if errors:
raise TypeError("\n".join(errors))
Expand Down
27 changes: 0 additions & 27 deletions docprompt/utils/async_utils.py

This file was deleted.

105 changes: 105 additions & 0 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,108 @@ def method(self):

assert child2.method() == "child2_sync"
assert run_async(child2.method_async()) == "child2_sync"


def test_abstract_base_class():
from abc import ABC, abstractmethod

@flexible_methods(("abstract_method", "abstract_method_async"))
class AbstractBase(ABC):
@abstractmethod
def abstract_method(self):
pass

class ConcreteSync(AbstractBase):
def abstract_method(self):
return "concrete_sync"

class ConcreteAsync(AbstractBase):
async def abstract_method_async(self):
return "concrete_async"

with pytest.raises(TypeError):
AbstractBase()

sync_instance = ConcreteSync()
assert sync_instance.abstract_method() == "concrete_sync"
assert asyncio.run(sync_instance.abstract_method_async()) == "concrete_sync"

async_instance = ConcreteAsync()
assert async_instance.abstract_method() == "concrete_async"
assert asyncio.run(async_instance.abstract_method_async()) == "concrete_async"


def test_multiple_inheritance():
@flexible_methods(("method1", "method1_async"))
class Base1:
def method1(self):
return "base1"

@flexible_methods(("method2", "method2_async"))
class Base2:
async def method2_async(self):
return "base2"

class Child(Base1, Base2):
async def method1_async(self):
return "child1"

def method2(self):
return "child2"

child = Child()
assert child.method1() == "child1"
assert asyncio.run(child.method1_async()) == "child1"
assert child.method2() == "child2"
assert asyncio.run(child.method2_async()) == "child2"

# Test that Base1 and Base2 methods are not affected
base1 = Base1()
base2 = Base2()
assert base1.method1() == "base1"
assert asyncio.run(base1.method1_async()) == "base1"
assert asyncio.run(base2.method2_async()) == "base2"
assert base2.method2() == "base2"


def test_preserve_signature_and_docstring(run_async):
@flexible_methods(("method", "method_async"))
class PreserveMetadata:
def method(self, arg1: int, arg2: str = "default") -> str:
"""This is a test method."""
return f"{arg1} {arg2}"

instance = PreserveMetadata()
assert instance.method.__doc__ == "This is a test method."
assert instance.method.__annotations__ == {"arg1": int, "arg2": str, "return": str}
assert instance.method_async.__doc__ == "This is a test method."
assert instance.method_async.__annotations__ == {
"arg1": int,
"arg2": str,
"return": str,
}

assert instance.method(1, "test") == "1 test"
assert run_async(instance.method_async(2, "async")) == "2 async"


@pytest.mark.skip(reason="Not implemented yet")
def test_static_methods():
@flexible_methods(
("class_method", "class_method_async"), ("static_method", "static_method_async")
)
class MethodTypes:
@classmethod
def class_method(cls):
return f"class {cls.__name__}"

@staticmethod
def static_method():
return "static"

assert MethodTypes.static_method() == "static"
assert asyncio.run(MethodTypes.static_method_async()) == "static"

instance = MethodTypes()
assert instance.static_method() == "static"
assert asyncio.run(instance.static_method_async()) == "static"

0 comments on commit 8409c3c

Please sign in to comment.