Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Oct 28, 2024
1 parent dd22aae commit 0f53876
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 44 deletions.
18 changes: 10 additions & 8 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ def dot(a, b, out=None):
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b, out=out)

# numpy.dot does not allow casting even if it is safe
# casting="no" is used in the following
if a_ndim == 1 and b_ndim == 1:
return dpnp_dot(a, b, out=out)
return dpnp_dot(a, b, out=out, casting="no")

# NumPy does not allow casting even if it is safe
# casting="no" is used in the following
if a_ndim == 2 and b_ndim == 2:
return dpnp.matmul(a, b, out=out, casting="no")

Expand Down Expand Up @@ -753,6 +753,7 @@ def matmul(
Type to use in computing the matrix product. By default, the returned
array will have data type that is determined by considering
Promotion Type Rule and device capabilities.
Default: ``None``.
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
Controls what kind of data casting may occur.
Default: ``"same_kind"``.
Expand Down Expand Up @@ -1203,7 +1204,7 @@ def vecdot(
.. math::
\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i
where the sum is over the last dimension (unless axis is specified) and
where the sum is over the last dimension (unless `axis` is specified) and
where :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i`
is complex and the identity otherwise.
Expand All @@ -1221,16 +1222,17 @@ def vecdot(
removed. If not provided or ``None``, a freshly-allocated array is
used.
Default: ``None``.
dtype : {None, dtype}, optional
Type to use in computing the vector dot product. By default, the
returned array will have data type that is determined by considering
Promotion Type Rule and device capabilities.
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
Controls what kind of data casting may occur.
Default: ``"same_kind"``.
order : {"C", "F", "A", "K", None}, optional
Memory layout of the newly created output array, if parameter `out` is ``None``.
Default: ``"K"``.
dtype : {None, dtype}, optional
Type to use in computing the vector dot product. By default, the
returned array will have data type that is determined by considering
Promotion Type Rule and device capabilities.
Default: ``None``.
axes : {None, list of tuples}, optional
A list of tuples with indices of axes the matrix product should operate
on. For instance, for the signature of ``(i),(i)->()``, the base
Expand Down
65 changes: 30 additions & 35 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _define_contig_flag(x):

def _define_dim_flags(x, axis):
"""
Define useful flags for the main calculation in dpnp_matmul.
Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot.
x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one
except for one of them), for instance, if x.shape = (1, 1, 1, 2),
then x_is_1D = True
Expand Down Expand Up @@ -220,7 +220,7 @@ def _define_dim_flags(x, axis):
return x_is_2D, x_is_1D, x_base_is_1D


def _get_result_shape(x1, x2, out, func, np_flag):
def _get_result_shape(x1, x2, out, _get_result_shape_fn, np_flag):
"""
Three tasks are completed in this function:
- Get the shape of the result array.
Expand All @@ -239,15 +239,7 @@ def _get_result_shape(x1, x2, out, func, np_flag):
"The second input array does not have enough dimensions (has 0, but requires at least 1)"
)

if func == "matmul":
x1, x2, result_shape = _get_result_shape_matmul(
x1, x2, x1_ndim, x2_ndim
)
else: # func == "vecdot"
assert func == "vecdot"
x1, x2, result_shape = _get_result_shape_vecdot(
x1, x2, x1_ndim, x2_ndim
)
x1, x2, result_shape = _get_result_shape_fn(x1, x2, x1_ndim, x2_ndim)

if out is not None:
out_shape = out.shape
Expand Down Expand Up @@ -474,7 +466,7 @@ def _shape_error(shape1, shape2, func, err_msg):
elif func == "vecdot":
signature = "(n?,),(n?,)->()"
else:
# applicable when err_msg == 3
# applicable when err_msg == 2
assert func is None

if err_msg == 0:
Expand Down Expand Up @@ -655,7 +647,7 @@ def dpnp_cross(a, b, cp):
return cp


def dpnp_dot(a, b, /, out=None, *, conjugate=False):
def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
"""
Return the dot product of two arrays.
Expand Down Expand Up @@ -717,8 +709,7 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
if dot_dtype != res_dtype:
result = result.astype(res_dtype, copy=False)

# numpy.dot does not allow casting even if it is safe
return dpnp.get_result_array(result, out, casting="no")
return dpnp.get_result_array(result, out, casting=casting)


def dpnp_kron(a, b, a_ndim, b_ndim):
Expand Down Expand Up @@ -773,8 +764,10 @@ def dpnp_matmul(
order = "F"
else:
order = "C"

if order in "kK":
elif order in "kK":
# For order="K", we return order="C" to align with NumPy behavior
# This differs from the logic used in dpnp_vecdot because NumPy
# behaves differently for matmul and vecdot
order = "C"

x1_ndim = x1.ndim
Expand Down Expand Up @@ -806,7 +799,7 @@ def dpnp_matmul(
)

x1, x2, result_shape = _get_result_shape(
x1, x2, out, "matmul", NumPy_special_behavior
x1, x2, out, _get_result_shape_matmul, NumPy_special_behavior
)

# Determine the appropriate data types
Expand Down Expand Up @@ -1000,6 +993,9 @@ def dpnp_vecdot(
_validate_out_array(out, exec_q)

if order in "aAkK":
# This logic is also used for order="K" to align with NumPy behavior.
# This differs from the logic used in dpnp_matmul because NumPy
# behaves differently for matmul and vecdot
if x1.flags.fnc and x2.flags.fnc:
order = "F"
else:
Expand Down Expand Up @@ -1035,7 +1031,7 @@ def dpnp_vecdot(
)

x1, x2, result_shape = _get_result_shape(
x1, x2, out, "vecdot", NumPy_special_behavior
x1, x2, out, _get_result_shape_vecdot, NumPy_special_behavior
)

# Determine the appropriate data types
Expand All @@ -1047,21 +1043,7 @@ def dpnp_vecdot(
_, x2_is_1D, _ = _define_dim_flags(x2, axis=-1)

if x1.size == 0 or x2.size == 0:
order = "C" if order in "kK" else order
result = _create_result_array(
x1,
x2,
out,
shape=result_shape,
dtype=res_dtype,
usm_type=res_usm_type,
sycl_queue=exec_q,
order=order,
)
if numpy.prod(result_shape) == 0:
return result
result.fill(0)
return result
call_flag = "trivial"
elif x1_is_1D and x2_is_1D:
call_flag = "dot"
# arrays are inherently 1D, make them 1D
Expand All @@ -1072,7 +1054,20 @@ def dpnp_vecdot(
call_flag = "vecdot"

# dispatch to proper function call
if call_flag == "dot":
if call_flag == "trivial":
result = _create_result_array(
x1,
x2,
out,
shape=result_shape,
dtype=res_dtype,
usm_type=res_usm_type,
sycl_queue=exec_q,
order=order,
)
if numpy.prod(result_shape) != 0:
result.fill(0)
elif call_flag == "dot":
if out is not None and out.shape != ():
result = dpnp_dot(x1, x2, out=None, conjugate=True)
else:
Expand Down
51 changes: 50 additions & 1 deletion tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,27 @@ def test_order(self, order1, order2, order, shape):
assert result.flags.f_contiguous == expected.flags.f_contiguous
assert_dtype_allclose(result, expected)

# Check the memory layout of dpnp.vecdot results against numpy.vecdot for
# every `order` value when the inputs contain a zero-length dimension
# (both parametrized shapes include a 0).
@pytest.mark.parametrize("order", ["C", "F", "K", "A"])
@pytest.mark.parametrize(
"shape", [(2, 4, 0), (4, 0, 5)], ids=["(2, 4, 0)", "(4, 0, 5)"]
)
def test_order_trivial(self, order, shape):
# input is both c-contiguous and f-contiguous
a = numpy.ones(shape)
a_dp = dpnp.asarray(a)

# the same array is used for both operands, on the dpnp and NumPy side
result = dpnp.vecdot(a_dp, a_dp, order=order)
expected = numpy.vecdot(a, a, order=order)
if shape == (2, 4, 0) and order == "A":
# NumPy does not behave correctly for this case, for order="A",
# if input is both c- and f-contiguous, output is c-contiguous
# The NumPy flags are therefore not used as the reference in this
# branch; the dpnp result is checked to be C-contiguous only.
assert result.flags.c_contiguous
assert not result.flags.f_contiguous
else:
# otherwise the contiguity flags must match the NumPy result
assert result.flags.c_contiguous == expected.flags.c_contiguous
assert result.flags.f_contiguous == expected.flags.f_contiguous
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize(
"order1, order2, out_order",
[
Expand All @@ -1538,7 +1559,7 @@ def test_order(self, order1, order2, order, shape):
("F", "F", "C"),
],
)
def test_out(self, order1, order2, out_order):
def test_out_order(self, order1, order2, out_order):
a1 = numpy.arange(20).reshape(5, 4, order=order1)
a2 = numpy.arange(20).reshape(5, 4, order=order2)

Expand All @@ -1555,6 +1576,34 @@ def test_out(self, order1, order2, out_order):
assert result.flags.f_contiguous == expected.flags.f_contiguous
assert_dtype_allclose(result, expected)

# Check dpnp.vecdot when writing into an `out` array whose dtype (dtype2) may
# differ from the input dtype (dtype1), over all pairs from get_all_dtypes.
# Each `shape_pair` is (input shape, shape of the `out` array).
@pytest.mark.parametrize("dtype1", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("dtype2", get_all_dtypes(no_none=True))
@pytest.mark.parametrize(
"shape_pair",
[
((4,), ()),
((1, 1, 4), (1, 1)),
((6, 7, 4, 3), (6, 7, 4)),
((2, 0), (2,)), # zero-size inputs, 1D output
((3, 0, 4), (3, 0)), # zero-size output
],
)
def test_out_dtype(self, dtype1, dtype2, shape_pair):
shape1, shape2 = shape_pair
# the same array is used for both operands of vecdot
a = numpy.ones(shape1, dtype=dtype1)
b = dpnp.asarray(a)

# matching `out` buffers for the NumPy reference and the dpnp call
out_np = numpy.empty(shape2, dtype=dtype2)
out_dp = dpnp.asarray(out_np)

# vecdot is called with its default casting here, so the expected outcome
# is decided by whether dtype1 -> dtype2 is a "same_kind" cast
if dpnp.can_cast(dtype1, dtype2, casting="same_kind"):
result = dpnp.vecdot(b, b, out=out_dp)
expected = numpy.vecdot(a, a, out=out_np)
assert_dtype_allclose(result, expected)
else:
# a cast that is not allowed must be rejected with TypeError
with pytest.raises(TypeError):
dpnp.vecdot(b, b, out=out_dp)

@pytest.mark.parametrize(
"out_shape",
[
Expand Down

0 comments on commit 0f53876

Please sign in to comment.