diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 0d123b99..e16ba054 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -67,7 +67,7 @@ def eye( def full( shape: Union[int, Tuple[int, ...]], - fill_value: complex, + fill_value: bool | int | float | complex, xp: Namespace, *, dtype: Optional[DType] = None, @@ -80,7 +80,7 @@ def full( def full_like( x: Array, /, - fill_value: complex, + fill_value: bool | int | float | complex, *, xp: Namespace, dtype: Optional[DType] = None, diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index ebc7ccd9..423fd10a 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -68,7 +68,10 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e737cebd..e6eff359 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -136,7 +136,10 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 6536d9a8..1d084b2b 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -77,7 +77,10 @@ def _supports_buffer_protocol(obj): # rather than trying to combine everything into one function in common/ def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 87d32d85..982500b0 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -116,7 +116,9 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: +def result_type( + *arrays_and_dtypes: Array | DType | bool | int | float | complex +) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -550,10 +552,16 @@ def count_nonzero( return result -def where(condition: Array, x1: Array, x2: Array, /) -> Array: +def where( + condition: Array, + x1: Array | bool | int | float | complex, + x2: Array | bool | int | float | complex, + /, +) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) + # torch.reshape doesn't have the copy keyword def reshape(x: Array, /, @@ -622,7 +630,7 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def full(shape: Union[int, Tuple[int, ...]], - fill_value: complex, + fill_value: bool | int | float | complex, *, dtype: Optional[DType] = None, device: Optional[Device] = None, diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 7b59a670..1ff7319d 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -85,7 +85,7 @@ def vector_norm( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, # float stands for inf | -inf, which are not valid for Literal - ord: Union[int, float, float] = 2, + ord: Union[int, float] = 2, **kwargs, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None