Skip to content

Commit 84766b2

Browse files
committed
Adding proper typing for dpctl.SyclQueue
1 parent cd66cfd commit 84766b2

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

numba_dpex/core/types/dpctl_types.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@
44

55
from dpctl import SyclQueue
66
from numba import types
7-
from numba.extending import NativeValue, box, type_callable, unbox
7+
from numba.extending import (
8+
NativeValue,
9+
as_numba_type,
10+
box,
11+
type_callable,
12+
typeof_impl,
13+
unbox,
14+
)
815

916

10-
class DpctlSyclQueue(types.Type):
17+
class SyclQueueType(types.Type):
1118
"""A Numba type to represent a dpctl.SyclQueue PyObject.
1219
1320
For now, a dpctl.SyclQueue is represented as a Numba opaque type that allows
@@ -16,25 +23,33 @@ class DpctlSyclQueue(types.Type):
1623
"""
1724

1825
def __init__(self):
19-
super().__init__(name="DpctlSyclQueueType")
26+
super().__init__(name="SyclQueue")
2027

2128

22-
sycl_queue_ty = DpctlSyclQueue()
29+
sycl_queue_type = SyclQueueType()
30+
31+
32+
@typeof_impl.register(SyclQueue)
33+
def typeof_index(val, c):
34+
return sycl_queue_type
35+
36+
37+
as_numba_type.register(SyclQueue, sycl_queue_type)
2338

2439

2540
@type_callable(SyclQueue)
2641
def type_interval(context):
2742
def typer():
28-
return sycl_queue_ty
43+
return sycl_queue_type
2944

3045
return typer
3146

3247

33-
@unbox(DpctlSyclQueue)
48+
@unbox(SyclQueue)
3449
def unbox_sycl_queue(typ, obj, c):
3550
return NativeValue(obj)
3651

3752

38-
@box(DpctlSyclQueue)
53+
@box(SyclQueue)
3954
def box_pyobject(typ, val, c):
4055
return val

0 commit comments

Comments
 (0)