|
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
| 5 | +from dpctl import SyclQueue |
5 | 6 | from dpctl.tensor import usm_ndarray |
6 | 7 | from dpnp import ndarray |
| 8 | +from numba.core import types |
7 | 9 | from numba.extending import typeof_impl |
8 | 10 | from numba.np import numpy_support |
9 | 11 |
|
10 | | -from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray |
11 | | -from numba_dpex.core.types.usm_ndarray_type import USMNdArray |
12 | 12 | from numba_dpex.utils import address_space |
13 | 13 |
|
| 14 | +from ..types.dpctl_types import sycl_queue_ty |
| 15 | +from ..types.dpnp_ndarray_type import DpnpNdArray |
| 16 | +from ..types.usm_ndarray_type import USMNdArray |
| 17 | + |
14 | 18 |
|
15 | 19 | def _typeof_helper(val, array_class_type): |
16 | 20 | """Creates a Numba type of the specified ``array_class_type`` for ``val``.""" |
@@ -90,3 +94,17 @@ def typeof_dpnp_ndarray(val, c): |
90 | 94 | Returns: The Numba type corresponding to dpnp.ndarray |
91 | 95 | """ |
92 | 96 | return _typeof_helper(val, DpnpNdArray) |
| 97 | + |
| 98 | + |
| 99 | +@typeof_impl.register(SyclQueue) |
| 100 | +def typeof_dpctl_sycl_queue(val, c): |
| 101 | + """Registers the type inference implementation function for a |
| 102 | + dpctl.SyclQueue PyObject. |
| 103 | +
|
| 104 | + Args: |
| 105 | + val : An instance of dpctl.SyclQueue. |
| 106 | + c : Unused argument used to be consistent with Numba API. |
| 107 | +
|
| 108 | + Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclQueue instance. |
| 109 | + """ |
| 110 | + return sycl_queue_ty |
0 commit comments