diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 5269114d609..616b47483a0 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -2099,7 +2099,7 @@ def dpnp_matmul( if numpy.prod(result_shape) == 0: res_shape = result_shape elif x1_shape[-1] == 1: - call_flag = "kron" + call_flag = "multiply" elif x1_is_1D and x2_is_1D: call_flag = "dot" x1 = dpnp.reshape(x1, x1_shape[-1]) @@ -2148,8 +2148,8 @@ def dpnp_matmul( call_flag = "gemm_batch" res_shape = result_shape - if call_flag == "kron": - res = dpnp.kron(x1, x2) + if call_flag == "multiply": + res = dpnp.multiply(x1, x2) res_shape = res.shape elif call_flag == "dot": if out is not None and out.shape != (): diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 935cc9e1f43..6dc5cb01688 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -2180,6 +2180,11 @@ def setup_method(self): ((10, 1, 1, 3), (2, 3, 3)), ((10, 1, 1, 3), (10, 2, 3, 3)), ((10, 2, 1, 3), (10, 2, 3, 3)), + ((3, 3, 1), (3, 1, 2)), + ((3, 3, 1), (1, 1, 2)), + ((1, 3, 1), (3, 1, 2)), + ((4, 1, 3, 1), (1, 3, 1, 2)), + ((1, 3, 3, 1), (4, 1, 1, 2)), ], ) def test_matmul(self, order_pair, shape_pair):