|
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | 5 | import dpctl |
6 | | -import numpy as np |
| 6 | +import dpnp as np |
7 | 7 | import pytest |
8 | 8 |
|
9 | 9 | import numba_dpex as dpex |
@@ -38,8 +38,11 @@ def fdtype(request): |
38 | 38 |
|
39 | 39 | @pytest.fixture(params=list_of_i_dtypes + list_of_f_dtypes) |
40 | 40 | def input_arrays(request): |
41 | | - a = np.array([0], request.param) |
42 | | - return a, request.param |
| 41 | + def _inpute_arrays(filter_str): |
| 42 | + a = np.array([0], request.param, device=filter_str) |
| 43 | + return a, request.param |
| 44 | + |
| 45 | + return _inpute_arrays |
43 | 46 |
|
44 | 47 |
|
45 | 48 | list_of_op = [ |
@@ -72,11 +75,9 @@ def f(a): |
72 | 75 | @pytest.mark.parametrize("filter_str", filter_strings) |
73 | 76 | @skip_no_atomic_support |
74 | 77 | def test_kernel_atomic_simple(filter_str, input_arrays, kernel_result_pair): |
75 | | - a, dtype = input_arrays |
| 78 | + a, dtype = input_arrays(filter_str) |
76 | 79 | kernel, expected = kernel_result_pair |
77 | | - device = dpctl.SyclDevice(filter_str) |
78 | | - with dpctl.device_context(device): |
79 | | - kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a) |
| 80 | + kernel[dpex.Range(global_size)](a) |
80 | 81 | assert a[0] == expected |
81 | 82 |
|
82 | 83 |
|
@@ -114,15 +115,11 @@ def f(a): |
114 | 115 | @pytest.mark.parametrize("filter_str", filter_strings) |
115 | 116 | @skip_no_atomic_support |
116 | 117 | def test_kernel_atomic_local(filter_str, input_arrays, return_list_of_op): |
117 | | - a, dtype = input_arrays |
| 118 | + a, dtype = input_arrays(filter_str) |
118 | 119 | op_type, expected = return_list_of_op |
119 | 120 | f = get_func_local(op_type, dtype) |
120 | 121 | kernel = dpex.kernel(f) |
121 | | - device = dpctl.SyclDevice(filter_str) |
122 | | - with dpctl.device_context(device): |
123 | | - gs = (N,) |
124 | | - ls = (N,) |
125 | | - kernel[gs, ls](a) |
| 122 | + kernel[dpex.Range(N), dpex.Range(N)](a) |
126 | 123 | assert a[0] == expected |
127 | 124 |
|
128 | 125 |
|
@@ -161,10 +158,8 @@ def test_kernel_atomic_multi_dim( |
161 | 158 | op_type, expected = return_list_of_op |
162 | 159 | dim = return_list_of_dim |
163 | 160 | kernel = get_kernel_multi_dim(op_type, len(dim)) |
164 | | - a = np.zeros(dim, return_dtype) |
165 | | - device = dpctl.SyclDevice(filter_str) |
166 | | - with dpctl.device_context(device): |
167 | | - kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a) |
| 161 | + a = np.zeros(dim, dtype=return_dtype, device=filter_str) |
| 162 | + kernel[dpex.Range(global_size)](a) |
168 | 163 | assert a[0] == expected |
169 | 164 |
|
170 | 165 |
|
|
0 commit comments