diff --git a/equinox/_ad.py b/equinox/_ad.py index b18c19d0..6e14e514 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -76,7 +76,7 @@ class _GradWrapper(Module): @property def __wrapped__(self): - return self._fun_value_and_grad + return self._fun_value_and_grad.__wrapped__ # pyright: ignore def __call__(self, /, *args, **kwargs): value, grad = self._fun_value_and_grad(*args, **kwargs) @@ -172,7 +172,7 @@ def filter_value_and_grad( "as the first argument." ) - return module_update_wrapper(_ValueAndGradWrapper(fun, has_aux, gradkwargs), fun) + return module_update_wrapper(_ValueAndGradWrapper(fun, has_aux, gradkwargs)) @overload @@ -262,7 +262,7 @@ def grad_func(x__y): fun_value_and_grad = filter_value_and_grad(fun, has_aux=has_aux, **gradkwargs) fun_value_and_grad = cast(_ValueAndGradWrapper, fun_value_and_grad) - return module_update_wrapper(_GradWrapper(fun_value_and_grad, has_aux), fun) + return module_update_wrapper(_GradWrapper(fun_value_and_grad, has_aux)) def _is_none(x): diff --git a/equinox/_jit.py b/equinox/_jit.py index 997c6f2e..b347285c 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -294,4 +294,4 @@ def f(x, y): # both args traced if arrays, static if non-arrays _cached=cached, _filter_warning=filter_warning, ) - return module_update_wrapper(jit_wrapper, fun) + return module_update_wrapper(jit_wrapper) diff --git a/equinox/_make_jaxpr.py b/equinox/_make_jaxpr.py index ce6f9e55..0cbf2ae3 100644 --- a/equinox/_make_jaxpr.py +++ b/equinox/_make_jaxpr.py @@ -76,4 +76,4 @@ def filter_make_jaxpr( `int`, `float`, `complex`) are treated as static inputs; wrap them in JAX/NumPy arrays if you would like them to be traced. """ - return module_update_wrapper(_MakeJaxpr(fun), fun) + return module_update_wrapper(_MakeJaxpr(fun)) diff --git a/equinox/_module.py b/equinox/_module.py index 5bc01607..234f5435 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -4,7 +4,7 @@ import types import weakref from collections.abc import Callable -from typing import Any, cast, TYPE_CHECKING, TypeVar, Union +from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import dataclass_transform, ParamSpec import jax.tree_util as jtu @@ -131,9 +131,7 @@ def bar(self): """ ) else: - _method = module_update_wrapper( - BoundMethod(self.method, instance), self.method - ) + _method = module_update_wrapper(BoundMethod(self.method, instance)) return _method @@ -476,10 +474,11 @@ def __repr__(self): # Modifies in-place, just like functools.update_wrapper def module_update_wrapper( - wrapper: Module, wrapped: Callable[_P, _T] + wrapper: Module, wrapped: Optional[Callable[_P, _T]] = None ) -> Callable[_P, _T]: """Like `functools.update_wrapper` (or its better-known cousin, `functools.wraps`), - but can be used on [`equinox.Module`][]s. (Which are normally immutable.) + but acts on [`equinox.Module`][]s, and does not modify its input (it returns the + updated module instead). !!! Example @@ -495,7 +494,7 @@ def __wrapped__(self): return self.fn def make_wrapper(fn): - return eqx.module_update_wrapper(Wrapper(fn), fn) + return eqx.module_update_wrapper(Wrapper(fn)) ``` For example, [`equinox.filter_jit`][] returns a module representing the JIT'd @@ -504,10 +503,29 @@ def make_wrapper(fn): Note that as in the above example, the wrapper class must supply a `__wrapped__` property, which redirects to the wrapped object. + + **Arguments:** + + - `wrapper`: the instance of the wrapper. + - `wrapped`: optional, the callable that is being wrapped. If omitted then + `wrapper.__wrapped__` will be used. + + **Returns:** + + A copy of `wrapper`, with the attributes `__module__`, `__name__`, `__qualname__`, + `__doc__`, and `__annotations__` copied over from the wrapped function. """ cls = wrapper.__class__ if not isinstance(getattr(cls, "__wrapped__", None), property): raise ValueError("Wrapper module must supply `__wrapped__` as a property.") + + if wrapped is None: + wrapped = wrapper.__wrapped__ # pyright: ignore + + # Make a clone, to avoid mutating the original input. + leaves, treedef = jtu.tree_flatten(wrapper) + wrapper = jtu.tree_unflatten(treedef, leaves) + initable_cls = _make_initable(cls, wraps=True) object.__setattr__(wrapper, "__class__", initable_cls) try: diff --git a/equinox/_vmap_pmap.py b/equinox/_vmap_pmap.py index 9653fb0e..d487ce82 100644 --- a/equinox/_vmap_pmap.py +++ b/equinox/_vmap_pmap.py @@ -406,7 +406,7 @@ def evaluate_per_ensemble(model, x): _axis_size=axis_size, _vmapkwargs=vmapkwargs, ) - return module_update_wrapper(vmap_wrapper, fun) + return module_update_wrapper(vmap_wrapper) @compile_cache @@ -722,4 +722,4 @@ def g(x, y): _filter_warning=filter_warning, _pmapkwargs=pmapkwargs, ) - return module_update_wrapper(pmap_wrapper, fun) + return module_update_wrapper(pmap_wrapper) diff --git a/equinox/internal/_noinline.py b/equinox/internal/_noinline.py index 940b11b1..0514394d 100644 --- a/equinox/internal/_noinline.py +++ b/equinox/internal/_noinline.py @@ -506,4 +506,4 @@ def abstract_fn(__dynamic_fn, *args, **kwargs): _index_to_fn.append(static_fn) dynamic_index = jnp.array(dynamic_index) noinline_fn = _NoInlineWrapper(dynamic_index, abstract_fn, dynamic_fn) - return module_update_wrapper(noinline_fn, abstract_fn) + return module_update_wrapper(noinline_fn) diff --git a/tests/test_module.py b/tests/test_module.py index 71d7a219..8ab5f8a2 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -9,6 +9,7 @@ import pytest import equinox as eqx +import equinox.internal as eqxi from .helpers import shaped_allclose @@ -271,17 +272,40 @@ class AnotherModule(MyModule): def test_wrapper_attributes(): - def f(x): - pass - - fjit = eqx.filter_jit(f) - # Gets __name__ attribute from module_update_wrapper - - @eqx.filter_jit # Flattens and unflattens - def g(k): - k.__name__ + def f(x: int) -> str: + """some doc""" + return "hi" + + def h(x: int) -> str: + """some other doc""" + return "bye" + + noinline = lambda x: eqxi.noinline(h, abstract_fn=x) + + for wrapper in ( + eqx.filter_jit, + eqx.filter_grad, + eqx.filter_value_and_grad, + eqx.filter_vmap, + eqx.filter_pmap, + noinline, + ): + f_wrap = wrapper(f) + # Gets __name__ attribute from module_update_wrapper + + called = False + + @eqx.filter_jit # Flattens and unflattens + def g(k): + nonlocal called + called = True + assert k.__name__ == "f" + assert k.__doc__ == "some doc" + assert k.__qualname__ == "test_wrapper_attributes..f" + assert k.__annotations__ == {"x": int, "return": str} - g(fjit) + g(f_wrap) + assert called # https://github.com/patrick-kidger/equinox/issues/337