From 034d843e8cbbe518bd6c3a04493f6550a6287beb Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Mon, 6 May 2024 15:22:35 -0700
Subject: [PATCH] jax.numpy: better docs for matmul-like functions

---
 jax/_src/numpy/lax_numpy.py | 328 +++++++++++++++++++++++++++++++++---
 jax/_src/numpy/linalg.py    |  22 +--
 2 files changed, 315 insertions(+), 35 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 4d006f52e83f..5cf307953f1c 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -3811,20 +3811,71 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike,
 
 ### Tensor contraction operations
 
-
-_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """
-preferred_element_type : dtype, optional
-  If specified, accumulate results and return a result of the given data type.
-  If not specified, the accumulation dtype is determined from the type promotion
-  rules of the input array dtypes.
-"""
-
-@util.implements(np.dot, lax_description=_PRECISION_DOC,
-                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
 @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
 def dot(a: ArrayLike, b: ArrayLike, *,
         precision: PrecisionLike = None,
         preferred_element_type: DTypeLike | None = None) -> Array:
+  """Compute the dot product of two arrays.
+
+  JAX implementation of :func:`numpy.dot`.
+
+  This differs from :func:`jax.numpy.matmul` in two respects:
+
+  - if either ``a`` or ``b`` is a scalar, the result of ``dot`` is equivalent
+    to :func:`jax.numpy.multiply`, while ``matmul`` raises an error.
+  - if ``a`` and ``b`` have more than 2 dimensions, the batch indices are
+    stacked rather than broadcast.
+
+  Args:
+    a: first input array, of shape ``(..., N)``.
+    b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
+      In the multi-dimensional case, leading dimensions must be broadcast-compatible
+      with the leading dimensions of ``a``.
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``a`` and ``b``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.
+
+  Returns:
+    array containing the dot product of the inputs, with batch dimensions of
+    ``a`` and ``b`` stacked rather than broadcast.
+
+  See also:
+    - :func:`jax.numpy.matmul`: broadcasted batched matmul.
+    - :func:`jax.lax.dot_general`: general batched matrix multiplication.
+
+  Examples:
+    For scalar inputs, ``dot`` computes the element-wise product:
+
+    >>> x = jnp.array([1, 2, 3])
+    >>> jnp.dot(x, 2)
+    Array([2, 4, 6], dtype=int32)
+
+    For vector or matrix inputs, ``dot`` computes the vector or matrix product:
+
+    >>> M = jnp.array([[2, 3, 4],
+    ...                [5, 6, 7],
+    ...                [8, 9, 0]])
+    >>> jnp.dot(M, x)
+    Array([20, 38, 26], dtype=int32)
+    >>> jnp.dot(M, M)
+    Array([[ 51,  60,  29],
+           [ 96, 114,  62],
+           [ 61,  78,  95]], dtype=int32)
+
+    For higher-dimensional matrix products, batch dimensions are stacked, whereas
+    in :func:`~jax.numpy.matmul` they are broadcast. For example:
+
+    >>> a = jnp.zeros((3, 2, 4))
+    >>> b = jnp.zeros((3, 4, 1))
+    >>> jnp.dot(a, b).shape
+    (3, 2, 3, 1)
+    >>> jnp.matmul(a, b).shape
+    (3, 2, 1)
+  """
   util.check_arraylike("dot", a, b)
   dtypes.check_user_dtype_supported(preferred_element_type, "dot")
   a, b = asarray(a), asarray(b)
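A note on ``preferred_element_type``, which the docstring above describes only abstractly: requesting a wider accumulator than type promotion would choose changes the result dtype. A minimal sketch with illustrative dtypes (by default the accumulation dtype follows the promotion rules of the inputs):

    import jax.numpy as jnp

    x = jnp.arange(10, dtype=jnp.int8)
    y = jnp.ones(10, dtype=jnp.int8)

    # Default: accumulation dtype follows promotion of the int8 inputs.
    print(jnp.dot(x, y).dtype)  # int8

    # Request a wider accumulator for the reduction.
    print(jnp.dot(x, y, preferred_element_type=jnp.int32).dtype)  # int32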
@@ -3852,14 +3903,64 @@ def dot(a: ArrayLike, b: ArrayLike, *,
   return lax_internal._convert_element_type(result, preferred_element_type,
                                             output_weak_type)
 
-@util.implements(np.matmul, module='numpy', lax_description=_PRECISION_DOC,
-                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
 @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
 def matmul(a: ArrayLike, b: ArrayLike, *,
            precision: PrecisionLike = None,
            preferred_element_type: DTypeLike | None = None,
            ) -> Array:
-  """Matrix Multiply."""
+  """Perform a matrix multiplication.
+
+  JAX implementation of :func:`numpy.matmul`.
+
+  Args:
+    a: first input array, of shape ``(..., N)``.
+    b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
+      In the multi-dimensional case, leading dimensions must be broadcast-compatible
+      with the leading dimensions of ``a``.
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``a`` and ``b``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.
+
+  Returns:
+    array containing the matrix product of the inputs. Shape is ``a.shape[:-1]``
+    if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading
+    dimensions of ``a`` and ``b`` are broadcast together.
+
+  See Also:
+    - :func:`jax.numpy.linalg.vecdot`: batched vector product.
+    - :func:`jax.numpy.linalg.tensordot`: batched tensor product.
+    - :func:`jax.lax.dot_general`: general N-dimensional batched dot product.
+
+  Examples:
+    Vector dot products:
+
+    >>> a = jnp.array([1, 2, 3])
+    >>> b = jnp.array([4, 5, 6])
+    >>> jnp.matmul(a, b)
+    Array(32, dtype=int32)
+
+    Matrix dot product:
+
+    >>> a = jnp.array([[1, 2, 3],
+    ...                [4, 5, 6]])
+    >>> b = jnp.array([[1, 2],
+    ...                [3, 4],
+    ...                [5, 6]])
+    >>> jnp.matmul(a, b)
+    Array([[22, 28],
+           [49, 64]], dtype=int32)
+
+    For convenience, in all cases you can do the same computation using
+    the ``@`` operator:
+
+    >>> a @ b
+    Array([[22, 28],
+           [49, 64]], dtype=int32)
+  """
   util.check_arraylike("matmul", a, b)
   dtypes.check_user_dtype_supported(preferred_element_type, "matmul")
   a, b = asarray(a), asarray(b)
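The ``precision`` argument documented above accepts either one value for both operands or a per-operand pair. A short sketch of the two spellings; the numerical effect is backend-dependent (for example, TPU matmuls may use lower-precision passes by default):

    import jax
    import jax.numpy as jnp

    a = jnp.ones((128, 128), dtype=jnp.float32)
    b = jnp.ones((128, 128), dtype=jnp.float32)

    # One value applies to both operands.
    c = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

    # A (lhs, rhs) pair sets each operand's precision separately.
    c = jnp.matmul(a, b, precision=(jax.lax.Precision.HIGH,
                                    jax.lax.Precision.DEFAULT))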
@@ -3925,14 +4026,47 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
   return lax_internal._convert_element_type(result, preferred_element_type,
                                             output_weak_type)
 
-@util.implements(np.vdot, lax_description=_PRECISION_DOC,
-                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
 @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
 def vdot(
     a: ArrayLike, b: ArrayLike, *,
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
 ) -> Array:
+  """Perform a conjugate multiplication of two 1D vectors.
+
+  JAX implementation of :func:`numpy.vdot`.
+
+  Args:
+    a: first input array, if not 1D it will be flattened.
+    b: second input array, if not 1D it will be flattened. Must have ``a.size == b.size``.
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``a`` and ``b``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.
+
+  Returns:
+    Scalar array (shape ``()``) containing the conjugate vector product of the
+    inputs.
+
+  See Also:
+    - :func:`jax.numpy.vecdot`: batched vector product.
+    - :func:`jax.numpy.matmul`: general matrix multiplication.
+    - :func:`jax.lax.dot_general`: general N-dimensional batched dot product.
+
+  Examples:
+    >>> x = jnp.array([1j, 2j, 3j])
+    >>> y = jnp.array([1., 2., 3.])
+    >>> jnp.vdot(x, y)
+    Array(0.-14.j, dtype=complex64)
+
+    Note the difference between this and :func:`~jax.numpy.dot`, which does not
+    conjugate the first input when complex:
+
+    >>> jnp.dot(x, y)
+    Array(0.+14.j, dtype=complex64)
+  """
   util.check_arraylike("vdot", a, b)
   if issubdtype(_dtype(a), complexfloating):
     a = ufuncs.conj(a)
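As the Args section above notes, ``vdot`` flattens non-1D inputs before contracting. A small sketch making that flattening explicit:

    import jax.numpy as jnp

    a = jnp.array([[1., 2.],
                   [3., 4.]])
    b = jnp.array([[5., 6.],
                   [7., 8.]])

    # Both calls contract all eight elements pairwise:
    # 1*5 + 2*6 + 3*7 + 4*8 = 70
    print(jnp.vdot(a, b))                  # 70.0
    print(jnp.vdot(a.ravel(), b.ravel()))  # 70.0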
@@ -3940,14 +4074,51 @@ def vdot(
                preferred_element_type=preferred_element_type)
 
 
-@util.implements(
-    getattr(np, "vecdot", None), lax_description=_PRECISION_DOC,
-    extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION,
-    # TODO(phawkins): numpy.vecdot doesn't have a __module__ attribute.
-    module="numpy")
 def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
            precision: PrecisionLike = None,
            preferred_element_type: DTypeLike | None = None) -> Array:
+  """Perform a conjugate multiplication of two batched vectors.
+
+  JAX implementation of :func:`numpy.vecdot`.
+
+  Args:
+    x1: left-hand side array.
+    x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``,
+      and remaining dimensions must be broadcast-compatible.
+    axis: axis along which to compute the dot product (default: -1)
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``x1`` and ``x2``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.
+
+  Returns:
+    array containing the conjugate dot product of ``x1`` and ``x2`` along ``axis``.
+    The non-contracted dimensions are broadcast together.
+
+  See Also:
+    - :func:`jax.numpy.vdot`: flattened vector product.
+    - :func:`jax.numpy.matmul`: general matrix multiplication.
+    - :func:`jax.lax.dot_general`: general N-dimensional batched dot product.
+
+  Examples:
+    Vector conjugate-dot product of two 1D arrays:
+
+    >>> x1 = jnp.array([1j, 2j, 3j])
+    >>> x2 = jnp.array([4., 5., 6.])
+    >>> jnp.vecdot(x1, x2)
+    Array(0.-32.j, dtype=complex64)
+
+    Batched vector dot product of two 2D arrays:
+
+    >>> x1 = jnp.array([[1, 2, 3],
+    ...                 [4, 5, 6]])
+    >>> x2 = jnp.array([[2, 3, 4]])
+    >>> jnp.vecdot(x1, x2, axis=-1)
+    Array([20, 47], dtype=int32)
+  """
   util.check_arraylike("jnp.vecdot", x1, x2)
   x1_arr, x2_arr = asarray(x1), asarray(x2)
   if x1_arr.shape[axis] != x2_arr.shape[axis]:
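The ``axis`` argument above selects which dimension is contracted, with the remaining dimensions treated as batch dimensions. A sketch of contracting along the leading axis rather than the default trailing one (shapes chosen for illustration):

    import jax.numpy as jnp

    x1 = jnp.array([[1., 2.],
                    [3., 4.],
                    [5., 6.]])       # shape (3, 2)
    x2 = jnp.array([1., 10., 100.])  # shape (3,)

    # Contract the length-3 leading axis; the length-2 axis is batched:
    # [1*1 + 3*10 + 5*100, 2*1 + 4*10 + 6*100]
    print(jnp.vecdot(x1, x2, axis=0))  # [531. 642.]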
@@ -3958,12 +4129,81 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
                    signature="(n),(n)->()")(x1_arr, x2_arr)
 
 
-@util.implements(np.tensordot, lax_description=_PRECISION_DOC,
-                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
 def tensordot(a: ArrayLike, b: ArrayLike,
               axes: int | Sequence[int] | Sequence[Sequence[int]] = 2,
               *, precision: PrecisionLike = None,
               preferred_element_type: DTypeLike | None = None) -> Array:
+  """Compute the tensor dot product of two N-dimensional arrays.
+
+  JAX implementation of :func:`numpy.tensordot`.
+
+  Args:
+    a: N-dimensional array
+    b: M-dimensional array
+    axes: integer or tuple of sequences of integers. If an integer ``k``, then
+      sum over the last ``k`` axes of ``a`` and the first ``k`` axes of ``b``,
+      in order. If a tuple, then ``axes[0]`` specifies the axes of ``a`` and
+      ``axes[1]`` specifies the axes of ``b``.
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``a`` and ``b``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.
+
+  Returns:
+    array containing the tensor dot product of the inputs
+
+  See also:
+    - :func:`jax.numpy.einsum`: NumPy API for more general tensor contractions.
+    - :func:`jax.lax.dot_general`: XLA API for more general tensor contractions.
+
+  Examples:
+    >>> x1 = jnp.arange(24.).reshape(2, 3, 4)
+    >>> x2 = jnp.ones((3, 4, 5))
+    >>> jnp.tensordot(x1, x2)
+    Array([[ 66.,  66.,  66.,  66.,  66.],
+           [210., 210., 210., 210., 210.]], dtype=float32)
+
+    Equivalent result when specifying the axes as explicit sequences:
+
+    >>> jnp.tensordot(x1, x2, axes=([1, 2], [0, 1]))
+    Array([[ 66.,  66.,  66.,  66.,  66.],
+           [210., 210., 210., 210., 210.]], dtype=float32)
+
+    Equivalent result via :func:`~jax.numpy.einsum`:
+
+    >>> jnp.einsum('ijk,jkm->im', x1, x2)
+    Array([[ 66.,  66.,  66.,  66.,  66.],
+           [210., 210., 210., 210., 210.]], dtype=float32)
+
+    Setting ``axes=1`` for two-dimensional inputs is equivalent to a matrix
+    multiplication:
+
+    >>> x1 = jnp.array([[1, 2],
+    ...                 [3, 4]])
+    >>> x2 = jnp.array([[1, 2, 3],
+    ...                 [4, 5, 6]])
+    >>> jnp.tensordot(x1, x2, axes=1)
+    Array([[ 9, 12, 15],
+           [19, 26, 33]], dtype=int32)
+    >>> x1 @ x2
+    Array([[ 9, 12, 15],
+           [19, 26, 33]], dtype=int32)
+
+    Setting ``axes=0`` for one-dimensional inputs is equivalent to
+    :func:`~jax.numpy.outer`:
+
+    >>> x1 = jnp.array([1, 2])
+    >>> x2 = jnp.array([1, 2, 3])
+    >>> jnp.tensordot(x1, x2, axes=0)
+    Array([[1, 2, 3],
+           [2, 4, 6]], dtype=int32)
+    >>> jnp.outer(x1, x2)
+    Array([[1, 2, 3],
+           [2, 4, 6]], dtype=int32)
+  """
   util.check_arraylike("tensordot", a, b)
   dtypes.check_user_dtype_supported(preferred_element_type, "tensordot")
   a, b = asarray(a), asarray(b)
@@ -4247,13 +4487,53 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
   return lax_internal._convert_element_type(operands[0], preferred_element_type,
                                             output_weak_type)
 
-@util.implements(np.inner, lax_description=_PRECISION_DOC,
-                 extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
 @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
 def inner(
     a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None,
     preferred_element_type: DType | None = None,
 ) -> Array:
+  """Compute the inner product of two arrays.
+
+  JAX implementation of :func:`numpy.inner`.
+
+  Unlike :func:`jax.numpy.matmul` or :func:`jax.numpy.dot`, this always performs
+  a contraction along the last dimension of each input.
+
+  Args:
+    a: array of shape ``(..., N)``
+    b: array of shape ``(..., N)``
+    precision: either ``None`` (default), which means the default precision for
+      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
+      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
+      such values indicating precision of ``a`` and ``b``.
+    preferred_element_type: either ``None`` (default), which means the default
+      accumulation type for the input types, or a datatype, indicating to
+      accumulate results to and return a result with that datatype.
+
+  Returns:
+    array of shape ``(*a.shape[:-1], *b.shape[:-1])`` containing the batched vector
+    product of the inputs.
+
+  See also:
+    - :func:`jax.numpy.vecdot`: conjugate multiplication along a specified axis.
+    - :func:`jax.numpy.tensordot`: general tensor multiplication.
+    - :func:`jax.numpy.matmul`: general batched matrix & vector multiplication.
+
+  Examples:
+    For 1D inputs, this implements standard (non-conjugate) vector multiplication:
+
+    >>> a = jnp.array([1j, 3j, 4j])
+    >>> b = jnp.array([4., 2., 5.])
+    >>> jnp.inner(a, b)
+    Array(0.+30.j, dtype=complex64)
+
+    For multi-dimensional inputs, batch dimensions are stacked rather than
+    broadcast:
+
+    >>> a = jnp.ones((2, 3))
+    >>> b = jnp.ones((5, 3))
+    >>> jnp.inner(a, b).shape
+    (2, 5)
+  """
   util.check_arraylike("inner", a, b)
   if ndim(a) == 0 or ndim(b) == 0:
     a = asarray(a, dtype=preferred_element_type)
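Since ``inner`` always contracts the last axis of each operand, it coincides with a single-axis ``tensordot`` over those axes. A quick sketch of the equivalence:

    import jax.numpy as jnp

    a = jnp.arange(6.).reshape(2, 3)
    b = jnp.arange(15.).reshape(5, 3)

    r1 = jnp.inner(a, b)                       # shape (2, 5)
    r2 = jnp.tensordot(a, b, axes=([1], [1]))  # contract axis 1 with axis 1
    print(jnp.allclose(r1, r2))                # True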
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index e8ff77242d9f..4019190b97ad 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -1611,18 +1611,19 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
 
 
 def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
-  """Compute the (batched) vector dot product of two arrays.
+  """Compute the (batched) vector conjugate dot product of two arrays.
 
   JAX implementation of :func:`numpy.linalg.vecdot`.
 
   Args:
     x1: left-hand side array.
-    x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``.
+    x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``,
+      and remaining dimensions must be broadcast-compatible.
     axis: axis along which to compute the dot product (default: -1)
 
   Returns:
-    array containing the dot product of ``x1`` and ``x2`` along ``axis``. The
-    non-contracted dimensions are broadcast together.
+    array containing the conjugate dot product of ``x1`` and ``x2`` along ``axis``.
+    The non-contracted dimensions are broadcast together.
 
   See also:
     - :func:`jax.numpy.vecdot`: similar API in the ``jax.numpy`` namespace.
@@ -1637,7 +1638,6 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
     >>> jnp.linalg.vecdot(x1, x2)
     Array(32, dtype=int32)
 
-
     Batched vector dot product of two 2D arrays:
 
     >>> x1 = jnp.array([[1, 2, 3],
@@ -1657,7 +1657,7 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array:
 
   Args:
     x1: first input array, of shape ``(..., N)``.
-    x2: second input array. Must have shape ``(N,)`` or ``(..., M, N)``.
+    x2: second input array. Must have shape ``(N,)`` or ``(..., N, M)``.
       In the multi-dimensional case, leading dimensions must be broadcast-compatible
       with the leading dimensions of ``x1``.
 
@@ -1718,14 +1718,14 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
     array containing the tensor dot product of the inputs
 
   See also:
-    :func:`jax.numpy.tensordot`: equivalent API in the :mod:`jax.numpy` namespace.
-    :func:`jax.numpy.einsum`: NumPy API for more general tensor contractions.
-    :func:`jax.lax.dot_general`: XLA API for more general tensor contractions.
+    - :func:`jax.numpy.tensordot`: equivalent API in the :mod:`jax.numpy` namespace.
+    - :func:`jax.numpy.einsum`: NumPy API for more general tensor contractions.
+    - :func:`jax.lax.dot_general`: XLA API for more general tensor contractions.
 
   Examples:
     >>> x1 = jnp.arange(24.).reshape(2, 3, 4)
     >>> x2 = jnp.ones((3, 4, 5))
-    >>> jnp.tensordot(x1, x2)
+    >>> jnp.linalg.tensordot(x1, x2)
     Array([[ 66.,  66.,  66.,  66.,  66.],
            [210., 210., 210., 210., 210.]], dtype=float32)
 
@@ -1756,7 +1756,7 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
            [19, 26, 33]], dtype=int32)
 
     Setting ``axes=0`` for one-dimensional inputs is equivalent to
-    ``jnp.linalg.outer``:
+    :func:`jax.numpy.linalg.outer`:
 
     >>> x1 = jnp.array([1, 2])
     >>> x2 = jnp.array([1, 2, 3])
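Taken together, the functions documented in this patch differ mainly in how they batch and whether they conjugate. A compact sketch of the conjugation behavior on a pair of complex vectors (printed values noted approximately in comments):

    import jax.numpy as jnp

    x = jnp.array([1 + 1j, 2 - 2j])
    y = jnp.array([3 + 0j, 1 + 1j])

    # Non-conjugating contractions all agree:
    print(jnp.dot(x, y))     # (7+3j)
    print(jnp.inner(x, y))   # (7+3j)
    print(jnp.matmul(x, y))  # (7+3j)

    # Conjugating contractions conjugate the first argument:
    print(jnp.vdot(x, y))    # (3+1j)
    print(jnp.vecdot(x, y))  # (3+1j)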