diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index b4786320..4b727f1c 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import wraps as _wraps +from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any from ..common import _aliases @@ -124,25 +124,43 @@ def _fix_promotion(x1, x2, only_scalar=True): def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: - if len(arrays_and_dtypes) == 0: - raise TypeError("At least one array or dtype must be provided") - if len(arrays_and_dtypes) == 1: + num = len(arrays_and_dtypes) + + if num == 0: + raise ValueError("At least one array or dtype must be provided") + + elif num == 1: x = arrays_and_dtypes[0] if isinstance(x, torch.dtype): return x return x.dtype - if len(arrays_and_dtypes) > 2: - return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) - x, y = arrays_and_dtypes - if isinstance(x, _py_scalars) or isinstance(y, _py_scalars): - return torch.result_type(x, y) + if num == 2: + x, y = arrays_and_dtypes + return _result_type(x, y) + + else: + # sort scalars so that they are treated last + scalars, others = [], [] + for x in arrays_and_dtypes: + if isinstance(x, _py_scalars): + scalars.append(x) + else: + others.append(x) + if not others: + raise ValueError("At least one array or dtype must be provided") + + # combine left-to-right + return _reduce(_result_type, others + scalars) - xdt = x.dtype if not isinstance(x, torch.dtype) else x - ydt = y.dtype if not isinstance(y, torch.dtype) else y - if (xdt, ydt) in _promotion_table: - return _promotion_table[xdt, ydt] +def _result_type(x, y): + if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): + xdt = x.dtype if not isinstance(x, torch.dtype) else x + ydt = y.dtype if not isinstance(y, torch.dtype) else y + + if (xdt, ydt) in _promotion_table: + return _promotion_table[xdt, ydt] # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow @@ -151,6 +169,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y return torch.result_type(x, y) + def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype diff --git a/tests/test_all.py b/tests/test_all.py index 10a2a95d..d2e9b768 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -16,6 +16,7 @@ import pytest +@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): if library == "common": diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 00000000..75b3a136 --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,98 @@ +"""Test "unspecified" behavior which we cannot easily test in the Array API test suite. +""" +import itertools + +import pytest +import torch + +from array_api_compat import torch as xp + + +class TestResultType: + def test_empty(self): + with pytest.raises(ValueError): + xp.result_type() + + def test_one_arg(self): + for x in [1, 1.0, 1j, '...', None]: + with pytest.raises((ValueError, AttributeError)): + xp.result_type(x) + + for x in [xp.float32, xp.int64, torch.complex64]: + assert xp.result_type(x) == x + + for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]: + assert xp.result_type(x) == x.dtype + + def test_two_args(self): + # Only include here things "unspecified" in the spec + + # scalar, tensor or tensor,tensor + for x, y in [ + (1., 1j), + (1j, xp.arange(3)), + (True, xp.asarray(3.)), + (xp.ones(3) == 1, 1j*xp.ones(3)), + ]: + assert xp.result_type(x, y) == torch.result_type(x, y) + + # dtype, scalar + for x, y in [ + (1j, xp.int64), + (True, xp.float64), + ]: + assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y)) + + # dtype, dtype + for x, y in [ + (xp.bool, xp.complex64) + ]: + xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y) + assert xp.result_type(x, y) == torch.result_type(xt, yt) + + def test_multi_arg(self): + torch.set_default_dtype(torch.float32) + + args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.] + assert xp.result_type(*args) == torch.float16 + + args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6] + assert xp.result_type(*args) == xp.complex64 + + args = [1, 2, 3j, xp.float64, 4, 5, 6] + assert xp.result_type(*args) == xp.complex128 + + args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False] + assert xp.result_type(*args) == xp.complex128 + + i64 = xp.ones(1, dtype=xp.int64) + f16 = xp.ones(1, dtype=xp.float16) + for i in itertools.permutations([i64, f16, 1.0, 1.0]): + assert xp.result_type(*i) == xp.float16, f"{i}" + + with pytest.raises(ValueError): + xp.result_type(1, 2, 3, 4) + + + @pytest.mark.parametrize("default_dt", ['float32', 'float64']) + @pytest.mark.parametrize("dtype_a", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + @pytest.mark.parametrize("dtype_b", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + def test_gh_273(self, default_dt, dtype_a, dtype_b): + # Regression test for https://github.com/data-apis/array-api-compat/issues/273 + + try: + prev_default = torch.get_default_dtype() + default_dtype = getattr(torch, default_dt) + torch.set_default_dtype(default_dtype) + + a = xp.asarray([2, 1], dtype=dtype_a) + b = xp.asarray([1, -1], dtype=dtype_b) + dtype_1 = xp.result_type(a, b, 1.0) + dtype_2 = xp.result_type(b, a, 1.0) + assert dtype_1 == dtype_2 + finally: + torch.set_default_dtype(prev_default)