diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py index dcd0115..8cb3dc0 100644 --- a/jaxtyping/_array_types.py +++ b/jaxtyping/_array_types.py @@ -25,7 +25,16 @@ import types import typing from dataclasses import dataclass -from typing import Any, Literal, NoReturn, Optional, TypeVar, Union +from typing import ( + Any, + get_args, + get_origin, + Literal, + NoReturn, + Optional, + TypeVar, + Union, +) # Bit of a hack, but jaxtyping provides nicer error messages than typeguard. This means @@ -358,7 +367,7 @@ class for `Float32[Array, "foo"]`. _not_made = object() -_union_types = [typing.Union] +_union_types = [Union] if sys.version_info >= (3, 10): _union_types.append(types.UnionType) @@ -517,6 +526,9 @@ def _make_array_cached(array_type, dim_str, dtypes, name): # Allow Python built-in numeric types. # TODO: do something more generic than this? Should we _make all types # that have `shape` and `dtype` attributes or something? + array_origin = get_origin(array_type) + if array_origin is not None: + array_type = array_origin if array_type is bool: if _check_scalar("bool", dtypes, dims): return array_type @@ -547,7 +559,7 @@ def _make_array_cached(array_type, dim_str, dtypes, name): return array_type else: return _not_made - if issubclass(array_type, AbstractArray): + if array_type is not Any and issubclass(array_type, AbstractArray): if dtypes is _any_dtype: dtypes = array_type.dtypes elif array_type.dtypes is not _any_dtype: @@ -588,11 +600,15 @@ def _make_array(*args, **kwargs): if type(out) is tuple: array_type, name, dtypes, dims, index_variadic, dim_str = out - metaclass = _make_metaclass(type(array_type)) + metaclass = ( + _make_metaclass(type) + if array_type is Any + else _make_metaclass(type(array_type)) + ) out = metaclass( name, - (array_type, AbstractArray), + (AbstractArray,) if array_type is Any else (array_type, AbstractArray), dict( array_type=array_type, dtypes=dtypes, @@ -629,14 +645,18 @@ def __getitem__(cls, item: tuple[Any, str]): if isinstance(array_type, TypeVar): bound = array_type.__bound__ if bound is None: - array_type = Any + constraints = array_type.__constraints__ + if constraints == (): + array_type = Any + else: + array_type = Union[constraints] else: array_type = bound del item - if typing.get_origin(array_type) in _union_types: + if get_origin(array_type) in _union_types: out = [ _make_array(x, dim_str, cls.dtypes, cls.__name__) - for x in typing.get_args(array_type) + for x in get_args(array_type) ] out = tuple(x for x in out if x is not _not_made) if len(out) == 0: diff --git a/pyproject.toml b/pyproject.toml index 746b48f..db44eac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "jaxtyping" -version = "0.2.32" +version = "0.2.33" description = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees." readme = "README.md" requires-python ="~=3.9"