Skip to content

Commit

Permalink
Adds tests for bugs changes in vecdot, tensordot
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian authored and oleksandr-pavlyk committed Mar 27, 2024
1 parent 8f000a0 commit a397caf
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions dpctl/tests/test_usm_ndarray_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,17 @@ def test_tensordot_axes_errors():
dpt.tensordot(m1, m2, axes=-1)


# tests for gh-1570
def test_tensordot_gemm_small_k_m():
get_queue_or_skip()

x1 = dpt.asarray(1, dtype="i2")
x2 = dpt.asarray([0, 1, 0, 0], dtype="i2")

res = dpt.tensordot(x1, x2, axes=0)
assert dpt.all(x2 == res)


@pytest.mark.parametrize("dtype", _numeric_types)
def test_vecdot_1d(dtype):
q = get_queue_or_skip()
Expand Down Expand Up @@ -943,3 +954,29 @@ def test_vecdot_type_promotion(dt1, dt2):
assert r.shape == tuple()
assert r.dtype == mul.dtype
assert dpt.allclose(r, dpt.sum(mul, dtype=mul.dtype))


def test_vecdot_broadcast_o1_buffer():
get_queue_or_skip()

v1 = dpt.arange(10, dtype="i2")
v2 = dpt.ones((5, 10), dtype="i4")

res1 = dpt.vecdot(v1, v2)
assert res1.shape == (5,)

res2 = dpt.vecdot(v2, v1)
assert res2.shape == (5,)


def test_vecdot_contig_small():
get_queue_or_skip()

n = 1
for dt in [dpt.int16, dpt.int32, dpt.complex64]:
v1 = dpt.zeros((10, n), dtype=dt)
v2 = dpt.ones_like(v1, dtype=dt)
v1[-1] = 1
res = dpt.vecdot(v1, v2)
assert dpt.all(res[:-1] == 0)
assert res[-1] == n

0 comments on commit a397caf

Please sign in to comment.