Skip to content

Commit

Permalink
Remove two reference cycles and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ojw28 authored and patrick-kidger committed Oct 22, 2024
1 parent c5d01ba commit d22a0a8
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 2 deletions.
9 changes: 8 additions & 1 deletion jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import itertools as it
import sys
import warnings
import weakref
from collections.abc import Callable
from typing import (
Any,
Expand Down Expand Up @@ -519,13 +520,16 @@ def wrapped_fn_impl(args, kwargs, bound, memos):

return out

wrapped_fn_holder = [] # Avoids introducing a reference cycle.

@ft.wraps(fn)
def wrapped_fn(*args, **kwargs):
__tracebackhide__ = True

if (
config.jaxtyping_disable
or getattr(fn, "__no_type_check__", False)
or getattr(wrapped_fn, "__no_type_check__", False)
or getattr(wrapped_fn_holder[0](), "__no_type_check__", False)
):
return fn(*args, **kwargs)

Expand All @@ -542,6 +546,8 @@ def wrapped_fn(*args, **kwargs):
finally:
pop_shape_memo()

wrapped_fn_holder.append(weakref.ref(wrapped_fn))

return wrapped_fn


Expand Down Expand Up @@ -723,6 +729,7 @@ def _make_fn_with_signature(
fnstr = f"def {name}({argstr}){retstr}:\n {outstr}"
exec(fnstr, scope)
fn = scope[name]
del scope[name] # Avoids introducing a reference cycle.
fn.__module__ = module
fn.__qualname__ = qualname
assert fn is not None
Expand Down
35 changes: 35 additions & 0 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import contextlib
import gc
from collections.abc import Callable
from typing import Any

import equinox as eqx
import typeguard

Expand Down Expand Up @@ -48,3 +53,33 @@
@eqx.filter_jit
def make_mlp(key):
return eqx.nn.MLP(2, 2, 2, 2, key=key)


@contextlib.contextmanager
def assert_no_garbage(
allowed_garbage_predicate: Callable[[Any], bool] = lambda _: False,
):
try:
gc.disable()
gc.collect()
# It's unclear why, but a second GC is necessary to fully collect
# existing garbage.
gc.collect()
gc.garbage.clear()

yield

# Do a GC collection, saving collected objects in gc.garbage.
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()

disallowed_garbage = [
obj for obj in gc.garbage if not allowed_garbage_predicate(obj)
]
assert not disallowed_garbage
finally:
# Reset the GC back to normal.
gc.set_debug(0)
gc.garbage.clear()
gc.collect()
gc.enable()
30 changes: 29 additions & 1 deletion test/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import abc
import dataclasses
from typing import no_type_check

import jax.numpy as jnp
import jax.random as jr
import pytest
import typeguard

from jaxtyping import Array, Float, jaxtyped, print_bindings

from .helpers import ParamError, ReturnError
from .helpers import assert_no_garbage, ParamError, ReturnError


class M(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -208,3 +210,29 @@ def g(x: Float[Array, "foo bar"]):

f("not an array")
g("not an array")


def test_no_garbage(typecheck):
if typecheck is typeguard.typechecked:
# Currently fails due to reference cycles in typeguard.
pytest.skip()

with assert_no_garbage():

@jaxtyped(typechecker=typecheck)
@dataclasses.dataclass
class _Obj:
x: int

_Obj(x=5)


def test_no_garbage_identity_typecheck():
with assert_no_garbage():

@jaxtyped(typechecker=lambda x: x)
@dataclasses.dataclass
class _Obj:
x: int

_Obj(x=5)

0 comments on commit d22a0a8

Please sign in to comment.