Skip to content

Commit

Permalink
module_update_wrapper now looks at __wrapped__, rather than requiring…
Browse files Browse the repository at this point in the history
… an extra argument. It no longer mutates its input.
  • Loading branch information
patrick-kidger committed Sep 11, 2023
1 parent d5de94a commit 30c411b
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 25 deletions.
6 changes: 3 additions & 3 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class _GradWrapper(Module):

@property
def __wrapped__(self):
return self._fun_value_and_grad
return self._fun_value_and_grad.__wrapped__ # pyright: ignore

def __call__(self, /, *args, **kwargs):
value, grad = self._fun_value_and_grad(*args, **kwargs)
Expand Down Expand Up @@ -172,7 +172,7 @@ def filter_value_and_grad(
"as the first argument."
)

return module_update_wrapper(_ValueAndGradWrapper(fun, has_aux, gradkwargs), fun)
return module_update_wrapper(_ValueAndGradWrapper(fun, has_aux, gradkwargs))


@overload
Expand Down Expand Up @@ -262,7 +262,7 @@ def grad_func(x__y):

fun_value_and_grad = filter_value_and_grad(fun, has_aux=has_aux, **gradkwargs)
fun_value_and_grad = cast(_ValueAndGradWrapper, fun_value_and_grad)
return module_update_wrapper(_GradWrapper(fun_value_and_grad, has_aux), fun)
return module_update_wrapper(_GradWrapper(fun_value_and_grad, has_aux))


def _is_none(x):
Expand Down
2 changes: 1 addition & 1 deletion equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,4 @@ def f(x, y): # both args traced if arrays, static if non-arrays
_cached=cached,
_filter_warning=filter_warning,
)
return module_update_wrapper(jit_wrapper, fun)
return module_update_wrapper(jit_wrapper)
2 changes: 1 addition & 1 deletion equinox/_make_jaxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,4 @@ def filter_make_jaxpr(
`int`, `float`, `complex`) are treated as static inputs; wrap them in JAX/NumPy
arrays if you would like them to be traced.
"""
return module_update_wrapper(_MakeJaxpr(fun), fun)
return module_update_wrapper(_MakeJaxpr(fun))
32 changes: 25 additions & 7 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import types
import weakref
from collections.abc import Callable
from typing import Any, cast, TYPE_CHECKING, TypeVar, Union
from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import dataclass_transform, ParamSpec

import jax.tree_util as jtu
Expand Down Expand Up @@ -131,9 +131,7 @@ def bar(self):
"""
)
else:
_method = module_update_wrapper(
BoundMethod(self.method, instance), self.method
)
_method = module_update_wrapper(BoundMethod(self.method, instance))
return _method


Expand Down Expand Up @@ -476,10 +474,11 @@ def __repr__(self):

# Modifies in-place, just like functools.update_wrapper
def module_update_wrapper(
wrapper: Module, wrapped: Callable[_P, _T]
wrapper: Module, wrapped: Optional[Callable[_P, _T]] = None
) -> Callable[_P, _T]:
"""Like `functools.update_wrapper` (or its better-known cousin, `functools.wraps`),
but can be used on [`equinox.Module`][]s. (Which are normally immutable.)
but acts on [`equinox.Module`][]s, and does not modify its input (it returns the
updated module instead).
!!! Example
Expand All @@ -495,7 +494,7 @@ def __wrapped__(self):
return self.fn
def make_wrapper(fn):
return eqx.module_update_wrapper(Wrapper(fn), fn)
return eqx.module_update_wrapper(Wrapper(fn))
```
For example, [`equinox.filter_jit`][] returns a module representing the JIT'd
Expand All @@ -504,10 +503,29 @@ def make_wrapper(fn):
Note that as in the above example, the wrapper class must supply a `__wrapped__`
property, which redirects to the wrapped object.
**Arguments:**
- `wrapper`: the instance of the wrapper.
- `wrapped`: optional, the callable that is being wrapped. If omitted then
`wrapper.__wrapped__` will be used.
**Returns:**
A copy of `wrapper`, with the attributes `__module__`, `__name__`, `__qualname__`,
`__doc__`, and `__annotations__` copied over from the wrapped function.
"""
cls = wrapper.__class__
if not isinstance(getattr(cls, "__wrapped__", None), property):
raise ValueError("Wrapper module must supply `__wrapped__` as a property.")

if wrapped is None:
wrapped = wrapper.__wrapped__ # pyright: ignore

# Make a clone, to avoid mutating the original input.
leaves, treedef = jtu.tree_flatten(wrapper)
wrapper = jtu.tree_unflatten(treedef, leaves)

initable_cls = _make_initable(cls, wraps=True)
object.__setattr__(wrapper, "__class__", initable_cls)
try:
Expand Down
4 changes: 2 additions & 2 deletions equinox/_vmap_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def evaluate_per_ensemble(model, x):
_axis_size=axis_size,
_vmapkwargs=vmapkwargs,
)
return module_update_wrapper(vmap_wrapper, fun)
return module_update_wrapper(vmap_wrapper)


@compile_cache
Expand Down Expand Up @@ -722,4 +722,4 @@ def g(x, y):
_filter_warning=filter_warning,
_pmapkwargs=pmapkwargs,
)
return module_update_wrapper(pmap_wrapper, fun)
return module_update_wrapper(pmap_wrapper)
2 changes: 1 addition & 1 deletion equinox/internal/_noinline.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,4 +506,4 @@ def abstract_fn(__dynamic_fn, *args, **kwargs):
_index_to_fn.append(static_fn)
dynamic_index = jnp.array(dynamic_index)
noinline_fn = _NoInlineWrapper(dynamic_index, abstract_fn, dynamic_fn)
return module_update_wrapper(noinline_fn, abstract_fn)
return module_update_wrapper(noinline_fn)
44 changes: 34 additions & 10 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

import equinox as eqx
import equinox.internal as eqxi

from .helpers import shaped_allclose

Expand Down Expand Up @@ -271,17 +272,40 @@ class AnotherModule(MyModule):


def test_wrapper_attributes():
def f(x):
pass

fjit = eqx.filter_jit(f)
# Gets __name__ attribute from module_update_wrapper

@eqx.filter_jit # Flattens and unflattens
def g(k):
k.__name__
def f(x: int) -> str:
"""some doc"""
return "hi"

def h(x: int) -> str:
"""some other doc"""
return "bye"

noinline = lambda x: eqxi.noinline(h, abstract_fn=x)

for wrapper in (
eqx.filter_jit,
eqx.filter_grad,
eqx.filter_value_and_grad,
eqx.filter_vmap,
eqx.filter_pmap,
noinline,
):
f_wrap = wrapper(f)
# Gets __name__ attribute from module_update_wrapper

called = False

@eqx.filter_jit # Flattens and unflattens
def g(k):
nonlocal called
called = True
assert k.__name__ == "f"
assert k.__doc__ == "some doc"
assert k.__qualname__ == "test_wrapper_attributes.<locals>.f"
assert k.__annotations__ == {"x": int, "return": str}

g(fjit)
g(f_wrap)
assert called


# https://github.com/patrick-kidger/equinox/issues/337
Expand Down

0 comments on commit 30c411b

Please sign in to comment.