diff --git a/cubed/array_api/linalg.py b/cubed/array_api/linalg.py index 91b25394..8ca57877 100644 --- a/cubed/array_api/linalg.py +++ b/cubed/array_api/linalg.py @@ -3,6 +3,7 @@ from cubed.array_api.array_object import Array # These functions are in both the main and linalg namespaces +from cubed.array_api.data_type_functions import result_type from cubed.array_api.linear_algebra_functions import ( # noqa: F401 matmul, matrix_transpose, @@ -15,7 +16,9 @@ def outer(x1, x2, /): - return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype) + return blockwise( + nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=result_type(x1, x2) + ) class QRResult(NamedTuple):