Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complex number support to linalg.qr #548

Merged
merged 2 commits into from
Dec 13, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions spec/API_specification/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,16 +315,42 @@ def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
"""

def qr(x: array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> Tuple[array, array]:
"""
Returns the qr decomposition x = QR of a full column rank matrix (or a stack of matrices), where ``Q`` is an orthonormal matrix (or a stack of matrices) and ``R`` is an upper-triangular matrix (or a stack of matrices).
r"""
Returns the QR decomposition of a full column rank matrix (or a stack of matrices).

If ``x`` is real-valued, let :math:`\mathbb{K}` be the set of real numbers :math:`\mathbb{R}`, and, if ``x`` is complex-valued, let :math:`\mathbb{K}` be the set of complex numbers :math:`\mathbb{C}`.

The **complete QR decomposition** of a matrix :math:`x \in\ \mathbb{K}^{n \times n}` is defined as

.. math::
x = QR

where :math:`Q \in\ \mathbb{K}^{m \times m}` is orthogonal when ``x`` is real-valued and unitary when ``x`` is complex-valued and where :math:`R \in\ \mathbb{K}^{m \times n}` is an upper triangular matrix with real diagonal (even when ``x`` is complex-valued).

When :math:`m \gt n` (tall matrix), as :math:`R` is upper triangular, the last :math:`m - n` rows are zero. In this case, the last :math:`m - n` columns of :math:`Q` can be dropped to form the **reduced QR decomposition**.

.. math::
x = QR

where :math:`Q \in\ \mathbb{K}^{m \times n}` and :math:`R \in\ \mathbb{K}^{n \times n}`.

The reduced QR decomposition equals with the complete QR decomposition when :math:`n \qeq m` (wide matrix).

When ``x`` is a stack of matrices, the function must compute the QR decomposition for each matrix in the stack.

.. note::
Whether an array library explicitly checks whether an input array is a full column rank matrix (or a stack of full column rank matrices) is implementation-defined.

.. warning::
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a note instead of a warning, in order to look less scary. That's a minor thing though, and perhaps a question of taste. So I'll leave it alone for now, just noting it in case someone revisits this PR in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Followed PyTorch in using a warning, instead of a note.

The elements in the diagonal of :math:`R` are not necessarily positive. Accordingly, the returned QR decomposition is only unique up to the sign of the diagonal of :math:`R`, and different libraries or inputs on different devices may produce different valid decompositions.

.. warning::
The QR decomposition is only well-defined if the first ``k = min(m,n)`` columns of every matrix in ``x`` are linearly independent.

Parameters
----------
x: array
input array having shape ``(..., M, N)`` and whose innermost two dimensions form ``MxN`` matrices of rank ``N``. Should have a real-valued floating-point data type.
input array having shape ``(..., M, N)`` and whose innermost two dimensions form ``MxN`` matrices of rank ``N``. Should have a floating-point data type.
mode: Literal['reduced', 'complete']
decomposition mode. Should be one of the following modes:

Expand All @@ -341,7 +367,7 @@ def qr(x: array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> Tupl
- first element must have the field name ``Q`` and must be an array whose shape depends on the value of ``mode`` and contain matrices with orthonormal columns. If ``mode`` is ``'complete'``, the array must have shape ``(..., M, M)``. If ``mode`` is ``'reduced'``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same size as those of the input array ``x``.
- second element must have the field name ``R`` and must be an array whose shape depends on the value of ``mode`` and contain upper-triangular matrices. If ``mode`` is ``'complete'``, the array must have shape ``(..., M, N)``. If ``mode`` is ``'reduced'``, the array must have shape ``(..., K, N)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same size as those of the input ``x``.

Each returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`.
Each returned array must have a floating-point data type determined by :ref:`type-promotion`.
"""

def slogdet(x: array, /) -> Tuple[array, array]:
Expand Down