diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 3395bc5026..7d7ba15a50 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -783,6 +783,17 @@ def test_tensordot_axes_errors(): dpt.tensordot(m1, m2, axes=-1) +# tests for gh-1570 +def test_tensordot_gemm_small_k_m(): + get_queue_or_skip() + + x1 = dpt.asarray(1, dtype="i2") + x2 = dpt.asarray([0, 1, 0, 0], dtype="i2") + + res = dpt.tensordot(x1, x2, axes=0) + assert dpt.all(x2 == res) + + @pytest.mark.parametrize("dtype", _numeric_types) def test_vecdot_1d(dtype): q = get_queue_or_skip() @@ -943,3 +954,29 @@ def test_vecdot_type_promotion(dt1, dt2): assert r.shape == tuple() assert r.dtype == mul.dtype assert dpt.allclose(r, dpt.sum(mul, dtype=mul.dtype)) + + +def test_vecdot_broadcast_o1_buffer(): + get_queue_or_skip() + + v1 = dpt.arange(10, dtype="i2") + v2 = dpt.ones((5, 10), dtype="i4") + + res1 = dpt.vecdot(v1, v2) + assert res1.shape == (5,) + + res2 = dpt.vecdot(v2, v1) + assert res2.shape == (5,) + + +def test_vecdot_contig_small(): + get_queue_or_skip() + + n = 1 + for dt in [dpt.int16, dpt.int32, dpt.complex64]: + v1 = dpt.zeros((10, n), dtype=dt) + v2 = dpt.ones_like(v1, dtype=dt) + v1[-1] = 1 + res = dpt.vecdot(v1, v2) + assert dpt.all(res[:-1] == 0) + assert res[-1] == n