From 968e499205ecb7d93dcaa3778ca7ee3ab910ca3e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 27 Mar 2023 19:43:42 -0600 Subject: [PATCH 1/6] Add implementation for torch.linalg.vecdot for integer dtypes --- array_api_compat/torch/linalg.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index c803228a..8f95fe9e 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -22,6 +22,27 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch_linalg.cross(x1, x2, dim=axis) -__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot'] +def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: + from ._aliases import isdtype + + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + + # torch.linalg.vecdot doesn't support integer dtypes + if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = torch.broadcast_tensors(x1, x2) + x1_ = torch.moveaxis(x1_, axis, -1) + x2_ = torch.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return res[..., 0, 0] + return torch.linalg.vecdot(x1, x2, axis=axis) + +__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', 'vecdot'] del linalg_all From dfc12755e19040dc00935007cdbff8de74509f82 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 27 Mar 2023 19:50:00 -0600 Subject: [PATCH 2/6] Pass kwargs through to torch.linalg.vecdot --- array_api_compat/torch/linalg.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 8f95fe9e..535cef4d 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -22,13 +22,15 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch_linalg.cross(x1, x2, dim=axis) -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: +def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # torch.linalg.vecdot doesn't support integer dtypes if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): + if kwargs: + raise RuntimeError("vecdot kwargs not supported for integral dtypes") ndim = max(x1.ndim, x2.ndim) x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) @@ -41,7 +43,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: res = x1_[..., None, :] @ x2_[..., None] return res[..., 0, 0] - return torch.linalg.vecdot(x1, x2, axis=axis) + return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) __all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', 'vecdot'] From d6bf189a46a2ce3d2fd0349dc13540ce0a25d03b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 27 Mar 2023 19:50:50 -0600 Subject: [PATCH 3/6] Add wrapper for torch.linalg.solve that does correct type promotion --- array_api_compat/torch/linalg.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 535cef4d..afef7504 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -45,6 +45,11 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', 'vecdot'] +def solve(x1: array, x2: array, /, **kwargs) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.linalg.solve(x1, x2, **kwargs) + +__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', + 'vecdot', 'solve'] del linalg_all From e6aa969ec511c6ea6c30b636baa1d0de512990bb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 28 Mar 2023 18:39:03 -0600 Subject: [PATCH 4/6] The torch log1p test can apparently fail on CI too --- torch-xfails.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch-xfails.txt b/torch-xfails.txt index d6cf1670..ad5d25bc 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -57,8 +57,8 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__im array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] -# Mac-only bug (overflow near float max) -# array_api_tests/test_operators_and_elementwise_functions.py::test_log1p +# overflow near float max +array_api_tests/test_operators_and_elementwise_functions.py::test_log1p # torch doesn't handle shifting by more than the bit size correctly # https://github.com/pytorch/pytorch/issues/70904 From 32fc8e77ff9c403c2a463737a3bd24077a2aaa30 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 28 Mar 2023 19:28:36 -0600 Subject: [PATCH 5/6] Add missing numpy 1.21 xfail --- numpy-1-21-xfails.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index d9b76859..505cf7d2 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -118,6 +118,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[_ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] From 05204b3b0c6815e8b71b7717b598539edc2db889 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 28 Mar 2023 19:29:49 -0600 Subject: [PATCH 6/6] Fix the naming of the numpy jobs on CI --- .github/workflows/array-api-tests-numpy-1-21.yml | 2 +- .github/workflows/array-api-tests-numpy.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests-numpy-1-21.yml b/.github/workflows/array-api-tests-numpy-1-21.yml index 5c64f16a..2d81c3cd 100644 --- a/.github/workflows/array-api-tests-numpy-1-21.yml +++ b/.github/workflows/array-api-tests-numpy-1-21.yml @@ -3,7 +3,7 @@ name: Array API Tests (NumPy 1.21) on: [push, pull_request] jobs: - array-api-tests-numpy: + array-api-tests-numpy-1-21: uses: ./.github/workflows/array-api-tests.yml with: package-name: numpy diff --git a/.github/workflows/array-api-tests-numpy.yml b/.github/workflows/array-api-tests-numpy.yml index b97786df..36984345 100644 --- a/.github/workflows/array-api-tests-numpy.yml +++ b/.github/workflows/array-api-tests-numpy.yml @@ -3,7 +3,7 @@ name: Array API Tests (NumPy Latest) on: [push, pull_request] jobs: - array-api-tests-numpy-1-21: + array-api-tests-numpy-latest: uses: ./.github/workflows/array-api-tests.yml with: package-name: numpy