Skip to content

Commit

Permalink
Now respecting typing.no_type_check decorator.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jun 13, 2024
1 parent 619f22e commit 6ae28a4
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
6 changes: 5 additions & 1 deletion jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,11 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore

@ft.wraps(fn)
def wrapped_fn(*args, **kwargs):
if config.jaxtyping_disable:
if (
config.jaxtyping_disable
or getattr(fn, "__no_type_check__", False)
or getattr(wrapped_fn, "__no_type_check__", False)
):
return fn(*args, **kwargs)

# Raise bind-time errors before we do any shape analysis. (I.e. skip
Expand Down
9 changes: 9 additions & 0 deletions test/import_hook_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import dataclasses
from typing import no_type_check

import equinox as eqx
import jax.numpy as jnp
Expand Down Expand Up @@ -244,6 +245,14 @@ def isinstance_test(x):
isinstance_test(jnp.array(1))


@no_type_check
def f(_: Float32[jnp.ndarray, "foo bar"]):
pass


f("not an array")


# Record that we've finished our checks successfully

jaxtyping._test_import_hook_counter += 1
16 changes: 16 additions & 0 deletions test/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
from typing import no_type_check

import jax.numpy as jnp
import jax.random as jr
Expand Down Expand Up @@ -192,3 +193,18 @@ def f(x: Float[Array, "foo bar"]):
"The current values for each jaxtyping axis annotation are as follows."
"\nfoo=3\nbar=4\n"
)


def test_no_type_check(typecheck):
@jaxtyped(typechecker=typecheck)
@no_type_check
def f(x: Float[Array, "foo bar"]):
pass

@no_type_check
@jaxtyped(typechecker=typecheck)
def g(x: Float[Array, "foo bar"]):
pass

f("not an array")
g("not an array")

0 comments on commit 6ae28a4

Please sign in to comment.