|
19 | 19 | import numpy as np |
20 | 20 | import pytest |
21 | 21 |
|
| 22 | +import dpctl |
22 | 23 | import dpctl.tensor as dpt |
23 | 24 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported |
24 | 25 |
|
@@ -81,6 +82,26 @@ def test_matmul_simple(dtype): |
81 | 82 | assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) |
82 | 83 |
|
83 | 84 |
|
@pytest.mark.parametrize("dtype", _numeric_types)
def test_matmul_simple2(dtype):
    """Check ``dpt.matmul`` on all-ones operands for several output sizes.

    The operands are ``(m, n)`` and ``(n, m)`` matrices of ones, sliced to
    ``(k, n)`` and ``(n, k)``. Every element of the ``(k, k)`` product is
    compared against ``n``.

    On a CPU device the test tries to run on a sub-device rather than the
    full device.
    """
    q = get_queue_or_skip()
    skip_if_dtype_not_supported(dtype, q)
    dev = q.sycl_device
    if dev.is_cpu:
        cpu_count = dev.max_compute_units
        # cpu_count // 2 is 0 for a device reporting a single compute unit;
        # a zero partition size is invalid, so keep the original queue then.
        partition_size = min(4, cpu_count // 2)
        if partition_size > 0:
            sub_devs = dev.create_sub_devices(partition=partition_size)
            # The sub-device needs its own context; the queue is rebuilt on it.
            ctx = dpctl.SyclContext(sub_devs[0])
            q = dpctl.SyclQueue(ctx, sub_devs[0])

    n, m = 235, 17
    m1 = dpt.ones((m, n), dtype=dtype, sycl_queue=q)
    m2 = dpt.ones((n, m), dtype=dtype, sycl_queue=q)

    # Sizes sit on and around 1, 2, 4, 8 and 16, up to the full extent m.
    for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]:
        r = dpt.matmul(m1[:k, :], m2[:, :k])
        assert dpt.all(r == dpt.full((k, k), n, dtype=dtype, sycl_queue=q))
| 103 | + |
| 104 | + |
84 | 105 | @pytest.mark.parametrize("dtype", _numeric_types) |
85 | 106 | def test_matmul_nilpotent1(dtype): |
86 | 107 | q = get_queue_or_skip() |
|
0 commit comments