diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index a0f7f0ca..8dfa47fe 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -7,6 +7,8 @@ This library adheres to `Semantic Versioning 2.0 FunctionType | str: return "__module__ attribute is not set" elif f.__code__.co_filename == "": return "cannot instrument functions defined in a REPL" + elif hasattr(f, "__wrapped__"): + return ( + "@typechecked only supports instrumenting functions wrapped with " + "@classmethod, @staticmethod or @property" + ) target_path = [item for item in f.__qualname__.split(".") if item != ""] module_source = inspect.getsource(sys.modules[f.__module__]) @@ -162,29 +167,15 @@ def typechecked(target: T_CallableOrType | None = None) -> Any: return target # Find either the first Python wrapper or the actual function - func: FunctionType wrapper_class: type[classmethod[Any]] | type[staticmethod[Any]] | None = None if isinstance(target, (classmethod, staticmethod)): wrapper_class = target.__class__ target = target.__func__ - if hasattr(target, "__wrapped__"): - warn( - f"Cannot instrument {function_name(target)} -- @typechecked only supports " - f"instrumenting functions wrapped with @classmethod, @staticmethod or " - f"@property", - InstrumentationWarning, - ) - return target - elif isfunction(target): - func = target - else: - raise TypeError("target is not a function or a supported wrapper") - - retval = instrument(func) + retval = instrument(target) if isinstance(retval, str): warn( - f"{retval} -- not typechecking {function_name(func)}", + f"{retval} -- not typechecking {function_name(target)}", InstrumentationWarning, ) return target diff --git a/tests/test_typechecked.py b/tests/test_typechecked.py index fc1c82e3..6862f923 100644 --- a/tests/test_typechecked.py +++ b/tests/test_typechecked.py @@ -1,5 +1,6 @@ import asyncio import sys +from contextlib import contextmanager from textwrap import dedent from typing import ( Any, @@ -580,3 +581,14 @@ def foo() -> Internal: return Internal() assert isinstance(foo(), Internal) + + +def test_existing_method_decorator(): + @typechecked + class Foo: + @contextmanager + def method(self, x: int) -> None: + yield x + 1 + + with Foo().method(6) as value: + assert value == 7