Runtime type checking via typeguard causes TypeError due to arrays having type DeviceArray. #33

Open
jaymody opened this issue Sep 25, 2022 · 3 comments
Labels
question User queries

Comments

@jaymody

jaymody commented Sep 25, 2022

I'm trying to use jaxtyping with runtime type checking via typeguard as described here. Here's my code:

import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from typeguard import typechecked as typechecker


@jaxtyped
@typechecker
def foo(
    x: Float[Array, "n"],
    y: Float[Array, "n"],
) -> Float[Array, "n"]:
    return x + y

print(foo(jnp.arange(10), jnp.arange(10)))

However when I run the above script, I get the following error:

Traceback (most recent call last):
  File "/Users/jay/playground/myscript.py", line 14, in <module>
    print(foo(jnp.arange(10), jnp.arange(10)))
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/jaxtyping/decorator.py", line 41, in __call__
    return self.fn(*args, **kwargs)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "x" must be jaxtyping.Float[ndarray, 'n']; got jaxlib.xla_extension.DeviceArray instead

Steps to reproduce my Python environment (note: I'm running this on an M1 MacBook Pro with macOS Monterey 12.2 (21D49)):

$ python -V
Python 3.9.10

$ python -m venv .venv

$ source .venv/bin/activate

$ python -m pip install --upgrade pip

$ python -m pip install "jax[cpu]==0.3.17" "jaxtyping==0.2.5"
@jaymody
Author

jaymody commented Sep 25, 2022

Ah, I'm realizing it's because jnp.arange returns an integer array by default, and Float[Array, "n"] only accepts floating-point arrays. If I change the call to print(foo(jnp.arange(10)*1.0, jnp.arange(10)*1.0)) I no longer get an error. Could the error message be more descriptive, or is this quirk documented somewhere? As it stands, the message is a bit misleading, since it points at the array type (DeviceArray) rather than the dtype.
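To make the cause concrete, here is a minimal sketch, assuming JAX's default configuration (64-bit mode off). It shows that jnp.arange with integer arguments yields an integer dtype, and lists a few ways to get a floating-point array that a Float[Array, "n"] annotation should accept, including the `* 1.0` workaround above.

```python
import jax.numpy as jnp

# Integer arguments give an integer array; Float[...] annotations reject it.
ints = jnp.arange(10)

# Workaround from the comment above: multiplying by a Python float promotes
# the array to the default floating-point dtype.
floats_mul = jnp.arange(10) * 1.0

# Ask for a floating-point dtype explicitly.
floats_dtype = jnp.arange(10, dtype=jnp.float32)

# Cast an existing integer array.
floats_cast = jnp.arange(10).astype(jnp.float32)

# A float start value also produces a floating-point array.
floats_start = jnp.arange(10.0)

for name, arr in [("ints", ints), ("floats_mul", floats_mul),
                  ("floats_dtype", floats_dtype), ("floats_cast", floats_cast),
                  ("floats_start", floats_start)]:
    print(name, arr.dtype)
```

Any of the floating-point variants should pass the `Float[Array, "n"]` check in the original `foo`, whereas `ints` reproduces the TypeError.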

@patrick-kidger
Owner

Right; something similar came up in #6. Indeed it would be great if the error message could include more information, but it's the typechecker that's raising the error (in this case typeguard) -- not jaxtyping. (All jaxtyping does is provide the types themselves.)

FWIW my usual approach to debugging this is to rerun under the debugger, so that I can check for myself what types were passed. This can be done with either of:

python -m pdb -c continue your_script.py
ipython your_script.py --pdb

(I'd like a better solution to this too.)
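As a hedged sketch of one alternative to the debugger, the hypothetical decorator `report_arg_types` below (not part of jaxtyping or typeguard) catches the TypeError raised by the typechecker and re-raises it with each argument's type, dtype and shape appended. It assumes the typechecker raises TypeError, as typeguard does in the traceback above. Other typeguard versions may raise a different exception class, in which case the `except` clause would need adjusting.

```python
import functools
from typing import Any, Callable, TypeVar, cast

F = TypeVar("F", bound=Callable[..., Any])


def describe(value: Any) -> str:
    """Summarise a value as 'TypeName(dtype=..., shape=...)' if array-like,
    otherwise just its type name."""
    name = type(value).__name__
    dtype = getattr(value, "dtype", None)
    shape = getattr(value, "shape", None)
    if dtype is None and shape is None:
        return name
    return f"{name}(dtype={dtype}, shape={shape})"


def report_arg_types(fn: F) -> F:
    """Hypothetical helper: on TypeError, append a summary of every
    argument's type/dtype/shape to the message and re-raise."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return fn(*args, **kwargs)
        except TypeError as exc:
            parts = [describe(a) for a in args]
            parts += [f"{k}={describe(v)}" for k, v in kwargs.items()]
            # Keep the original error chained so the full traceback survives.
            raise TypeError(f"{exc}\n  called with: {', '.join(parts)}") from exc

    return cast(F, wrapper)
```

Applied outermost (for example, stacked above `@jaxtyped` on `foo` in the original script), an integer input would then be reported together with its integer dtype and shape instead of only as a DeviceArray. Note that it also annotates TypeErrors raised inside the function body, not just those from the typechecker.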

@jaymody
Author

jaymody commented Sep 25, 2022

Yeah, that's the workaround I'm using as well to check shapes and types when an error comes up. Maybe it's worth documenting this in API.md? I missed #6 when searching for a solution (which is on me, tbh), but it might be useful for the next person who inevitably runs into this without thoroughly checking the issues on GitHub.

@patrick-kidger patrick-kidger added the question User queries label Aug 21, 2023