diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 6dd41493914..accb18981a5 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -1824,6 +1824,7 @@ def dpnp_solve(a, b): return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) if a.ndim > 2: + is_cpu_device = exec_q.sycl_device.has_aspect_cpu reshape = False orig_shape_b = b_shape if a.ndim > 3: @@ -1850,22 +1851,27 @@ def dpnp_solve(a, b): for i in range(batch_size): # oneMKL LAPACK assumes fortran-like array as input, so # allocate a memory with 'F' order for dpnp array of coefficient matrix - # and multiple dependent variables array coeff_vecs[i] = dpnp.empty_like( a[i], order="F", dtype=res_type, usm_type=res_usm_type ) - val_vecs[i] = dpnp.empty_like( - b[i], order="F", dtype=res_type, usm_type=res_usm_type - ) # use DPCTL tensor function to fill the coefficient matrix array - # and the array of multiple dependent variables with content - # from the input arrays + # with content from the input array a_ht_copy_ev[i], a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=a_usm_arr[i], dst=coeff_vecs[i].get_array(), sycl_queue=a.sycl_queue, ) + + # oneMKL LAPACK assumes fortran-like array as input, so + # allocate a memory with 'F' order for dpnp array of multiple + # dependent variables array + val_vecs[i] = dpnp.empty_like( + b[i], order="F", dtype=res_type, usm_type=res_usm_type + ) + + # use DPCTL tensor function to fill the array of multiple dependent + # variables with content from the input arrays b_ht_copy_ev[i], b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=b_usm_arr[i], dst=val_vecs[i].get_array(), @@ -1882,6 +1888,15 @@ def dpnp_solve(a, b): depends=[a_copy_ev, b_copy_ev], ) + # TODO: Remove this w/a when MKLD-17201 is solved. + # Waiting for a host task executing an OneMKL LAPACK gesv call + # on CPU causes deadlock due to serialization of all host tasks + # in the queue. + # We need to wait for each host tasks before calling _gesv to avoid deadlock. + if is_cpu_device: + ht_lapack_ev[i].wait() + b_ht_copy_ev[i].wait() + for i in range(batch_size): ht_lapack_ev[i].wait() b_ht_copy_ev[i].wait() diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index f95c0413053..b74e927f66d 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -899,9 +899,6 @@ def test_eigenvalue(func, shape, usm_type): ) def test_solve(matrix, vector, usm_type_matrix, usm_type_vector): x = dp.array(matrix, usm_type=usm_type_matrix) - if x.ndim > 2 and x.device.sycl_device.is_cpu: - pytest.skip("SAT-6842: reported hanging in public CI") - y = dp.array(vector, usm_type=usm_type_vector) z = dp.linalg.solve(x, y) diff --git a/tests/third_party/cupy/linalg_tests/test_solve.py b/tests/third_party/cupy/linalg_tests/test_solve.py index 182e9ca15f0..7e7c2377f0b 100644 --- a/tests/third_party/cupy/linalg_tests/test_solve.py +++ b/tests/third_party/cupy/linalg_tests/test_solve.py @@ -47,7 +47,6 @@ def check_x(self, a_shape, b_shape, xp, dtype): testing.assert_array_equal(b_copy, b) return result - @pytest.mark.skipif(is_cpu_device(), reason="SAT-6842") def test_solve(self): self.check_x((4, 4), (4,)) self.check_x((5, 5), (5, 2))