Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement dpnp.vecdot and dpnp.linalg.vecdot #2112

Merged
merged 4 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions doc/reference/binary.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Bit-wise operations
=================
===================

.. https://numpy.org/doc/stable/reference/routines.bitwise.html

Expand All @@ -22,7 +22,6 @@ Element-wise bit operations
dpnp.bitwise_right_shift
dpnp.bitwise_count


Bit packing
-----------

Expand All @@ -33,7 +32,6 @@ Bit packing
dpnp.packbits
dpnp.unpackbits


Output formatting
-----------------

Expand Down
2 changes: 2 additions & 0 deletions doc/reference/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ Other matrix operations
-----------------------
.. autosummary::
:toctree: generated/
:nosignatures:

dpnp.diagonal
dpnp.linalg.diagonal (Array API compatible)
Expand All @@ -96,5 +97,6 @@ Exceptions
----------
.. autosummary::
:toctree: generated/
:nosignatures:

dpnp.linalg.linAlgError
144 changes: 131 additions & 13 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
dpnp_dot,
dpnp_kron,
dpnp_matmul,
dpnp_vecdot,
)

__all__ = [
Expand All @@ -60,6 +61,7 @@
"outer",
"tensordot",
"vdot",
"vecdot",
]


Expand Down Expand Up @@ -145,11 +147,11 @@ def dot(a, b, out=None):
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b, out=out)

# numpy.dot does not allow casting even if it is safe
# casting="no" is used in the following
if a_ndim == 1 and b_ndim == 1:
return dpnp_dot(a, b, out=out)
return dpnp_dot(a, b, out=out, casting="no")

# NumPy does not allow casting even if it is safe
# casting="no" is used in the following
if a_ndim == 2 and b_ndim == 2:
return dpnp.matmul(a, b, out=out, casting="no")

Expand Down Expand Up @@ -728,7 +730,6 @@ def matmul(
dtype=None,
subok=True,
signature=None,
extobj=None,
axes=None,
axis=None,
):
Expand All @@ -752,18 +753,19 @@ def matmul(
Type to use in computing the matrix product. By default, the returned
array will have data type that is determined by considering
Promotion Type Rule and device capabilities.
Default: ``None``.
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
Controls what kind of data casting may occur.
Default: ``"same_kind"``.
order : {"C", "F", "A", "K", None}, optional
Memory layout of the newly output array, if parameter `out` is ``None``.
Default: ``"K"``.
axes : {list of tuples}, optional
axes : {None, list of tuples}, optional
A list of tuples with indices of axes the matrix product should operate
on. For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base
elements are 2d matrices and these are taken to be stored in the two
last axes of each argument. The corresponding axes keyword would be
[(-2, -1), (-2, -1), (-2, -1)].
``[(-2, -1), (-2, -1), (-2, -1)]``.
Default: ``None``.

Returns
Expand All @@ -774,8 +776,8 @@ def matmul(

Limitations
-----------
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
only supported with their default value.
Keyword arguments `subok`, `signature`, and `axis` are only supported with
their default values.
Otherwise ``NotImplementedError`` exception will be raised.

See Also
Expand Down Expand Up @@ -834,18 +836,14 @@ def matmul(

"""

if subok is False:
if not subok:
raise NotImplementedError(
"subok keyword argument is only supported by its default value."
)
if signature is not None:
raise NotImplementedError(
"signature keyword argument is only supported by its default value."
)
if extobj is not None:
raise NotImplementedError(
"extobj keyword argument is only supported by its default value."
)
if axis is not None:
raise NotImplementedError(
"axis keyword argument is only supported by its default value."
Expand Down Expand Up @@ -1135,6 +1133,9 @@ def vdot(a, b):
--------
:obj:`dpnp.dot` : Returns the dot product.
:obj:`dpnp.matmul` : Returns the matrix product.
:obj:`dpnp.vecdot` : Vector dot product of two arrays.
:obj:`dpnp.linalg.vecdot` : Array API compatible version of
:obj:`dpnp.vecdot`.

Examples
--------
Expand Down Expand Up @@ -1178,3 +1179,120 @@ def vdot(a, b):

# dot product of flatten arrays
return dpnp_dot(dpnp.ravel(a), dpnp.ravel(b), out=None, conjugate=True)


def vecdot(
x1,
x2,
/,
out=None,
*,
casting="same_kind",
order="K",
dtype=None,
subok=True,
signature=None,
axes=None,
axis=None,
):
r"""
Computes the vector dot product.

Let :math:`\mathbf{a}` be a vector in `x1` and :math:`\mathbf{b}` be
a corresponding vector in `x2`. The dot product is defined as:

.. math::
\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i

where the sum is over the last dimension (unless `axis` is specified) and
where :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i`
is complex and the identity otherwise.

For full documentation refer to :obj:`numpy.vecdot`.

Parameters
----------
x1 : {dpnp.ndarray, usm_ndarray}
First input array.
x2 : {dpnp.ndarray, usm_ndarray}
Second input array.
out : {None, dpnp.ndarray, usm_ndarray}, optional
A location into which the result is stored. If provided, it must have
a shape that the broadcasted shape of `x1` and `x2` with the last axis
removed. If not provided or ``None``, a freshly-allocated array is
used.
Default: ``None``.
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
Controls what kind of data casting may occur.
Default: ``"same_kind"``.
order : {"C", "F", "A", "K", None}, optional
Memory layout of the newly output array, if parameter `out` is ``None``.
Default: ``"K"``.
dtype : {None, dtype}, optional
Type to use in computing the vector dot product. By default, the
returned array will have data type that is determined by considering
Promotion Type Rule and device capabilities.
Default: ``None``.
axes : {None, list of tuples}, optional
A list of tuples with indices of axes the matrix product should operate
on. For instance, for the signature of ``(i),(i)->()``, the base
elements are vectors and these are taken to be stored in the last axes
of each argument. The corresponding axes keyword would be
``[(-1,), (-1), ()]``.
Default: ``None``.
axis : {None, int}, optional
Axis over which to compute the dot product. This is a short-cut for
passing in axes with entries of ``(axis,)`` for each
single-core-dimension argument and ``()`` for all others. For instance,
for a signature ``(i),(i)->()``, it is equivalent to passing in
``axes=[(axis,), (axis,), ()]``.
Default: ``None``.

Returns
-------
out : dpnp.ndarray
The vector dot product of the inputs.
This is a 0-d array only when both `x1`, `x2` are 1-d vectors.

Limitations
-----------
Keyword arguments `subok`, and `signature` are only supported with their
default values. Otherwise ``NotImplementedError`` exception will be raised.

See Also
--------
:obj:`dpnp.linalg.vecdot` : Array API compatible version.
:obj:`dpnp.vdot` : Complex-conjugating dot product.
:obj:`dpnp.einsum` : Einstein summation convention.

Examples
--------
Get the projected size along a given normal for an array of vectors.

>>> import dpnp as np
>>> v = np.array([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.]])
>>> n = np.array([0., 0.6, 0.8])
>>> np.vecdot(v, n)
array([ 3., 8., 10.])

"""

if not subok:
raise NotImplementedError(
"subok keyword argument is only supported by its default value."
)
if signature is not None:
raise NotImplementedError(
"signature keyword argument is only supported by its default value."
)

return dpnp_vecdot(
x1,
x2,
out=out,
casting=casting,
order=order,
dtype=dtype,
axes=axes,
axis=axis,
)
Loading
Loading