Skip to content

Commit 7fe8d9b

Browse files
author
Diptorup Deb
committed
Add basic support for dpctl.SyclQueue as a Numba type.
- Adds a Numba type to represent dpctl.SyclQueue and infer it as an opaque pointer inside the compiler.
1 parent c30e1bc commit 7fe8d9b

File tree

5 files changed

+96
-4
lines changed

5 files changed

+96
-4
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
from numba.core import datamodel, types
66
from numba.core.datamodel.models import ArrayModel as DpnpNdArrayModel
7-
from numba.core.datamodel.models import PrimitiveModel, StructModel
7+
from numba.core.datamodel.models import OpaqueModel, PrimitiveModel, StructModel
88
from numba.core.extending import register_model
99

10-
from numba_dpex.core.types import Array, DpnpNdArray, USMNdArray
1110
from numba_dpex.utils import address_space
1211

12+
from ..types import Array, DpctlSyclQueue, DpnpNdArray, USMNdArray
13+
1314

1415
class GenericPointerModel(PrimitiveModel):
1516
def __init__(self, dmm, fe_type):
@@ -81,3 +82,7 @@ def _init_data_model_manager():
8182
# Register the DpnpNdArray type with the Numba ArrayModel
8283
register_model(DpnpNdArray)(DpnpNdArrayModel)
8384
dpex_data_model_manager.register(DpnpNdArray, DpnpNdArrayModel)
85+
86+
# Register the DpctlSyclQueue type with Numba's OpaqueModel
87+
register_model(DpctlSyclQueue)(OpaqueModel)
88+
dpex_data_model_manager.register(DpctlSyclQueue, OpaqueModel)

numba_dpex/core/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from .array_type import Array
6+
from .dpctl_types import DpctlSyclQueue
67
from .dpnp_ndarray_type import DpnpNdArray
78
from .numba_types_short_names import (
89
b1,
@@ -31,6 +32,7 @@
3132

3233
__all__ = [
3334
"Array",
35+
"DpctlSyclQueue",
3436
"DpnpNdArray",
3537
"USMNdArray",
3638
"none",
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from dpctl import SyclQueue
6+
from numba import types
7+
from numba.extending import NativeValue, box, type_callable, unbox
8+
9+
10+
class DpctlSyclQueue(types.Type):
11+
"""A Numba type to represent a dpctl.SyclQueue PyObject.
12+
13+
For now, a dpctl.SyclQueue is represented as a Numba opaque type that allows
14+
passing in and using a SyclQueue object as an opaque pointer type inside
15+
Numba.
16+
"""
17+
18+
def __init__(self):
19+
super().__init__(name="DpctlSyclQueueType")
20+
21+
22+
sycl_queue_ty = DpctlSyclQueue()
23+
24+
25+
@type_callable(SyclQueue)
26+
def type_interval(context):
27+
def typer():
28+
return sycl_queue_ty
29+
30+
return typer
31+
32+
33+
@unbox(DpctlSyclQueue)
34+
def unbox_sycl_queue(typ, obj, c):
35+
return NativeValue(obj)
36+
37+
38+
@box(DpctlSyclQueue)
39+
def box_pyobject(typ, val, c):
40+
return val

numba_dpex/core/typing/typeof.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from dpctl import SyclQueue
56
from dpctl.tensor import usm_ndarray
67
from dpnp import ndarray
8+
from numba.core import types
79
from numba.extending import typeof_impl
810
from numba.np import numpy_support
911

10-
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray
11-
from numba_dpex.core.types.usm_ndarray_type import USMNdArray
1212
from numba_dpex.utils import address_space
1313

14+
from ..types.dpctl_types import sycl_queue_ty
15+
from ..types.dpnp_ndarray_type import DpnpNdArray
16+
from ..types.usm_ndarray_type import USMNdArray
17+
1418

1519
def _typeof_helper(val, array_class_type):
1620
"""Creates a Numba type of the specified ``array_class_type`` for ``val``."""
@@ -90,3 +94,17 @@ def typeof_dpnp_ndarray(val, c):
9094
Returns: The Numba type corresponding to dpnp.ndarray
9195
"""
9296
return _typeof_helper(val, DpnpNdArray)
97+
98+
99+
@typeof_impl.register(SyclQueue)
100+
def typeof_dpctl_sycl_queue(val, c):
101+
"""Registers the type inference implementation function for a
102+
dpctl.SyclQueue PyObject.
103+
104+
Args:
105+
val : An instance of dpctl.SyclQueue.
106+
c : Unused argument used to be consistent with Numba API.
107+
108+
Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclQueue instance.
109+
"""
110+
return sycl_queue_ty
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Tests for boxing and unboxing of types supported inside dpjit
7+
"""
8+
9+
import dpctl
10+
import pytest
11+
12+
from numba_dpex import dpjit
13+
14+
15+
@pytest.mark.parametrize(
16+
"obj",
17+
[
18+
pytest.param(dpctl.SyclQueue()),
19+
],
20+
)
21+
def test_boxing_unboxing(obj):
22+
@dpjit
23+
def func(a):
24+
return a
25+
26+
o = func(obj)
27+
assert id(o) == id(obj)

0 commit comments

Comments
 (0)