Skip to content

Commit 5d3a92c

Browse files
authored
Merge pull request #35 from asmeurer/torch-linalg2
More fixes for torch linalg extension
2 parents b32a5b3 + 05204b3 commit 5d3a92c

File tree

5 files changed

+34
-5
lines changed

5 files changed

+34
-5
lines changed

.github/workflows/array-api-tests-numpy-1-21.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Array API Tests (NumPy 1.21)
33
on: [push, pull_request]
44

55
jobs:
6-
array-api-tests-numpy:
6+
array-api-tests-numpy-1-21:
77
uses: ./.github/workflows/array-api-tests.yml
88
with:
99
package-name: numpy

.github/workflows/array-api-tests-numpy.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Array API Tests (NumPy Latest)
33
on: [push, pull_request]
44

55
jobs:
6-
array-api-tests-numpy-1-21:
6+
array-api-tests-numpy-latest:
77
uses: ./.github/workflows/array-api-tests.yml
88
with:
99
package-name: numpy

array_api_compat/torch/linalg.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,34 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
2222
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
2323
return torch_linalg.cross(x1, x2, dim=axis)
2424

25-
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']
25+
def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
    """Return the dot product of ``x1`` and ``x2`` along ``axis``.

    Thin wrapper around ``torch.linalg.vecdot`` that also supports integer
    dtypes, which the torch function rejects; for those a broadcast +
    batched-matmul fallback is used instead.
    """
    from ._aliases import isdtype

    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

    # torch.linalg.vecdot doesn't support integer dtypes
    has_int = isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral')
    if not has_int:
        return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

    # Fallback path: emulate vecdot for integral inputs.
    if kwargs:
        raise RuntimeError("vecdot kwargs not supported for integral dtypes")

    # Pad the shorter shape with leading 1s so `axis` indexes both shapes
    # the same way, then validate the contracted dimension matches.
    ndim = max(x1.ndim, x2.ndim)
    shape1 = (1,) * (ndim - x1.ndim) + tuple(x1.shape)
    shape2 = (1,) * (ndim - x2.ndim) + tuple(x2.shape)
    if shape1[axis] != shape2[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    a, b = torch.broadcast_tensors(x1, x2)
    a = torch.moveaxis(a, axis, -1)
    b = torch.moveaxis(b, axis, -1)

    # Row-vector @ column-vector batched matmul computes the dot product.
    prod = a[..., None, :] @ b[..., None]
    return prod[..., 0, 0]
47+
48+
def solve(x1: array, x2: array, /, **kwargs) -> array:
    """Solve the linear system ``x1 @ X = x2`` for ``X``.

    Delegates to ``torch.linalg.solve`` after promoting the operands to a
    common dtype (torch does not do cross-kind promotion on its own).
    """
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    result = torch.linalg.solve(x1, x2, **kwargs)
    return result
51+
52+
# Re-export everything torch.linalg provides, plus the wrappers defined here.
__all__ = linalg_all + [
    'outer', 'trace', 'matrix_transpose', 'tensordot', 'vecdot', 'solve',
]
2654

2755
del linalg_all

numpy-1-21-xfails.txt

+1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[_
118118
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
119119
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
120120
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
121+
array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
121122
array_api_tests/test_operators_and_elementwise_functions.py::test_logaddexp
122123
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
123124
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)]

torch-xfails.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__im
5757
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
5858

5959

60-
# Mac-only bug (overflow near float max)
61-
# array_api_tests/test_operators_and_elementwise_functions.py::test_log1p
60+
# overflow near float max
61+
array_api_tests/test_operators_and_elementwise_functions.py::test_log1p
6262

6363
# torch doesn't handle shifting by more than the bit size correctly
6464
# https://github.com/pytorch/pytorch/issues/70904

0 commit comments

Comments
 (0)