Skip to content

Commit

Permalink
Now supports using Any and TypeVars as the array type.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jul 2, 2024
1 parent 6236167 commit 91cfc9b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
16 changes: 13 additions & 3 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import types
import typing
from dataclasses import dataclass
from typing import Any, Literal, NoReturn, Optional, Union
from typing import Any, Literal, NoReturn, Optional, TypeVar, Union


# Bit of a hack, but jaxtyping provides nicer error messages than typeguard. This means
Expand Down Expand Up @@ -173,8 +173,12 @@ def __instancecheck__(cls, obj: Any) -> bool:
def __instancecheck_str__(cls, obj: Any) -> str:
if cls._skip_instancecheck:
return ""
if not isinstance(obj, cls.array_type):
return f"this value is not an instance of the underlying array type {cls.array_type}" # noqa: E501
if cls.array_type is Any:
if not (hasattr(obj, "shape") and hasattr(obj, "dtype")):
return "this value does not have both `shape` and `dtype` attributes."
else:
if not isinstance(obj, cls.array_type):
return f"this value is not an instance of the underlying array type {cls.array_type}" # noqa: E501
if get_treeflatten_memo():
return ""

Expand Down Expand Up @@ -622,6 +626,12 @@ def __getitem__(cls, item: tuple[Any, str]):
)
array_type, dim_str = item
dim_str = dim_str.strip()
if isinstance(array_type, TypeVar):
bound = array_type.__bound__
if bound is None:
array_type = Any
else:
array_type = bound
del item
if typing.get_origin(array_type) in _union_types:
out = [
Expand Down
66 changes: 65 additions & 1 deletion test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import dataclasses as dc
import sys
from typing import get_args, get_origin, Union
from typing import Any, get_args, get_origin, TypeVar, Union

import jax.numpy as jnp
import jax.random as jr
Expand Down Expand Up @@ -766,3 +766,67 @@ def h(x: FooDtype[MyArray1, "3"], y: FooDtype[MyArray3, "4"]):

with pytest.raises(ParamError):
g(MyArray1(), MyArray3())


@pytest.mark.parametrize(
"array_type", [Any, TypeVar("T"), TypeVar("T", bound=ArrayLike)]
)
def test_any(array_type, jaxtyp, typecheck):
class DuckArray1:
@property
def shape(self):
return 3, 4

@property
def dtype(self):
return np.array([], dtype=np.float32).dtype

class DuckArray2:
@property
def shape(self):
return 3, 4, 5

@property
def dtype(self):
return np.array([], dtype=np.float32).dtype

class DuckArray3:
@property
def shape(self):
return 3, 4

@property
def dtype(self):
return np.array([], dtype=np.int32).dtype

@jaxtyp(typecheck)
def f(x: Float[array_type, "foo bar"]):
del x

f(np.arange(12.0).reshape(3, 4))
f(jnp.arange(12.0).reshape(3, 4))
if isinstance(array_type, TypeVar) and array_type.__bound__ is ArrayLike:
with pytest.raises(ParamError):
f(DuckArray1())
else:
f(DuckArray1())

# Wrong shape
with pytest.raises(ParamError):
f(np.arange(12.0).reshape(3, 2, 2))
with pytest.raises(ParamError):
f(jnp.arange(12.0).reshape(3, 2, 2))
with pytest.raises(ParamError):
f(DuckArray2())

# Wrong dtype
with pytest.raises(ParamError):
f(np.arange(12).reshape(3, 4))
with pytest.raises(ParamError):
f(jnp.arange(12).reshape(3, 4))
with pytest.raises(ParamError):
f(DuckArray3())

# Not an array
with pytest.raises(ParamError):
f(1)

0 comments on commit 91cfc9b

Please sign in to comment.