Skip to content

Commit b5fe8e8

Browse files
author
Diptorup Deb
committed
Various fixes.
1 parent 5b176d0 commit b5fe8e8

File tree

6 files changed

+30
-117
lines changed

6 files changed

+30
-117
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
from numba.core import datamodel, types
66
from numba.core.datamodel.models import ArrayModel as DpnpNdArrayModel
7-
from numba.core.datamodel.models import OpaqueModel, PrimitiveModel, StructModel
7+
from numba.core.datamodel.models import PrimitiveModel, StructModel
88
from numba.core.extending import register_model
99

1010
from numba_dpex.utils import address_space
1111

12-
from ..types import Array, DpnpNdArray, SyclQueueType, USMNdArray
12+
from ..types import Array, DpctlSyclQueue, DpnpNdArray, USMNdArray
1313

1414

1515
class GenericPointerModel(PrimitiveModel):
@@ -57,13 +57,10 @@ def __init__(self, dmm, fe_type):
5757
class SyclQueueModel(StructModel):
5858
def __init__(self, dmm, fe_type):
5959
members = [
60-
("parent", types.CPointer),
61-
("queue_ref", types.PyObject),
62-
("context", types.PyObject),
63-
("device", types.PyObject),
60+
("parent", types.CPointer(types.voidptr)),
61+
("queue_ref", types.CPointer(types.voidptr)),
6462
]
65-
# super(StructModel, self).__init__(dmm, fe_type, members)
66-
StructModel.__init__(self, dmm, fe_type, members)
63+
super(SyclQueueModel, self).__init__(dmm, fe_type, members)
6764

6865

6966
def _init_data_model_manager():
@@ -95,10 +92,6 @@ def _init_data_model_manager():
9592
register_model(DpnpNdArray)(DpnpNdArrayModel)
9693
dpex_data_model_manager.register(DpnpNdArray, DpnpNdArrayModel)
9794

98-
# Register the DpctlSyclQueue type with Numba's OpaqueModel
99-
# register_model(DpctlSyclQueue)(OpaqueModel)
100-
# dpex_data_model_manager.register(DpctlSyclQueue, OpaqueModel)
101-
102-
# Register the DpctlSyclQueue type with Numba's OpaqueModel
103-
register_model(SyclQueueType)(SyclQueueModel)
104-
dpex_data_model_manager.register(SyclQueueType, SyclQueueModel)
95+
# Register the DpctlSyclQueue type
96+
register_model(DpctlSyclQueue)(SyclQueueModel)
97+
dpex_data_model_manager.register(DpctlSyclQueue, SyclQueueModel)

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
6767
int ndim,
6868
int writeable,
6969
PyArray_Descr *descr);
70-
static struct PySyclQueueObject *to_py_syclqobject(PyObject *obj);
7170
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
7271
queuestruct_t *queue_struct);
7372

@@ -649,21 +648,6 @@ static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj)
649648
return pyusmarrayobj;
650649
}
651650

652-
static struct PySyclQueueObject *to_py_syclqobject(PyObject *obj)
653-
{
654-
if (!obj)
655-
return NULL;
656-
if (!PyObject_TypeCheck(obj, &PySyclQueueType))
657-
return NULL;
658-
659-
struct PySyclQueueObject *pysyclqobj = (struct PySyclQueueObject *)(obj);
660-
// struct Py_SyclQueueObject py_syclqobj = pysyclqobj->__pyx_base;
661-
662-
// return &py_syclqobj;
663-
664-
return pysyclqobj;
665-
}
666-
667651
/*!
668652
* @brief Returns the product of the elements in an array of a given
669653
* length.
@@ -809,38 +793,33 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj,
809793
{
810794

811795
struct PySyclQueueObject *queue_obj = NULL;
812-
// DPCTLSyclQueueRef queue_ref = NULL;
796+
DPCTLSyclQueueRef queue_ref = NULL;
813797
PyGILState_STATE gstate;
814798

815799
// Increment the ref count on obj to prevent CPython from garbage
816800
// collecting the array.
817801
Py_IncRef(obj);
818802

803+
// We are unconditionally casting obj to a struct PySyclQueueObject*. If
804+
// the obj is not a struct PySyclQueueObject* then the SyclQueue_GetQueueRef
805+
// will error out.
806+
queue_obj = (struct PySyclQueueObject *)obj;
807+
819808
DPEXRT_DEBUG(
820809
nrt_debug_print("DPEXRT-DEBUG: In DPEXRT_sycl_queue_from_python.\n"));
821810

822-
// Check if the PyObject obj has an _array_obj attribute that is of
823-
// dpctl.tensor.usm_ndarray type.
824-
if (!(queue_obj = to_py_syclqobject(obj))) {
811+
if (!(queue_ref = SyclQueue_GetQueueRef(queue_obj))) {
825812
DPEXRT_DEBUG(nrt_debug_print(
826-
"DPEXRT-ERROR: to_py_syclqobject() check failed %d\n", __FILE__,
827-
__LINE__));
813+
"DPEXRT-ERROR: SyclQueue_GetQueueRef returned NULL at "
814+
"%s, line %d.\n",
815+
__FILE__, __LINE__));
828816
goto error;
829817
}
830818

831-
// if (!(queue_ref = SyclQueue_GetQueueRef(queue_obj))) {
832-
// DPEXRT_DEBUG(nrt_debug_print(
833-
// "DPEXRT-ERROR: SyclQueue_GetQueueRef returned NULL at "
834-
// "%s, line %d.\n",
835-
// __FILE__, __LINE__));
836-
// goto error;
837-
// }
838-
839819
queue_struct->parent = obj;
840-
// queue_struct->queue_ref = queue_ref;
841-
queue_struct->queue_ref = (PyObject *)queue_obj->__pyx_base._queue_ref;
842-
queue_struct->cotext = (PyObject *)queue_obj->__pyx_base._context;
843-
queue_struct->device = (PyObject *)queue_obj->__pyx_base._device;
820+
queue_struct->queue_ref = queue_ref;
821+
822+
return 0;
844823

845824
error:
846825
// If the check failed then decrement the refcount and return an error

numba_dpex/core/runtime/_queuestruct.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,12 @@
66
* for the ArrayTemplate class).
77
*/
88

9-
#include "numpy/npy_common.h"
109
#include <Python.h>
1110

1211
typedef struct
1312
{
1413
PyObject *parent;
15-
PyObject *queue_ref;
16-
PyObject *cotext;
17-
PyObject *device;
14+
void *queue_ref;
1815
} queuestruct_t;
1916

2017
#endif /* NUMBA_DPEX_QUEUESTRUCT_H_ */

numba_dpex/core/types/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from .array_type import Array
6-
from .dpctl_types import SyclQueueType
6+
from .dpctl_types import DpctlSyclQueue
77
from .dpnp_ndarray_type import DpnpNdArray
88
from .numba_types_short_names import (
99
b1,
@@ -32,7 +32,7 @@
3232

3333
__all__ = [
3434
"Array",
35-
"SyclQueueType",
35+
"DpctlSyclQueue",
3636
"DpnpNdArray",
3737
"USMNdArray",
3838
"none",

numba_dpex/core/types/dpctl_types.py

Lines changed: 4 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,13 @@
55
from dpctl import SyclQueue
66
from numba import types
77
from numba.core import cgutils
8-
from numba.extending import (
9-
NativeValue,
10-
as_numba_type,
11-
box,
12-
type_callable,
13-
typeof_impl,
14-
unbox,
15-
)
8+
from numba.extending import NativeValue, box, unbox
169

1710
from numba_dpex.core.exceptions import UnreachableError
1811
from numba_dpex.core.runtime import context as dpxrtc
1912

2013

21-
class SyclQueueType(types.Type):
14+
class DpctlSyclQueue(types.Type):
2215
"""A Numba type to represent a dpctl.SyclQueue PyObject.
2316
2417
For now, a dpctl.SyclQueue is represented as a Numba opaque type that allows
@@ -27,56 +20,10 @@ class SyclQueueType(types.Type):
2720
"""
2821

2922
def __init__(self):
30-
super(SyclQueueType, self).__init__(name="SyclQueue")
23+
super(DpctlSyclQueue, self).__init__(name="SyclQueue")
3124

3225

33-
# sycl_queue_type = SyclQueueType()
34-
35-
36-
# @typeof_impl.register(SyclQueue)
37-
# def typeof_index(val, c):
38-
# return sycl_queue_type
39-
40-
41-
# as_numba_type.register(SyclQueue, sycl_queue_type)
42-
43-
44-
@type_callable(SyclQueue)
45-
def type_sycl_queue(context):
46-
def typer(args):
47-
if isinstance(args, types.Tuple):
48-
if len(args) > 0:
49-
if (
50-
isinstance(args[0], types.PyObject)
51-
and isinstance(args[1], types.StringLiteral)
52-
and isinstance(args[2], types.PyObject)
53-
):
54-
return SyclQueueType()
55-
else:
56-
return SyclQueueType()
57-
elif isinstance(args, types.NoneType):
58-
return SyclQueueType()
59-
else:
60-
raise ValueError("Couldn't do type inference for 'SycleQueue'.")
61-
62-
return typer
63-
64-
65-
# @lower_builtin(SyclQueue, types.PyObject, types.StringLiteral, types.PyObject)
66-
# def impl_interval(context, builder, sig, args):
67-
# typ = sig.return_type
68-
# if len(args) > 0:
69-
# ctx, dev, property = args
70-
# sycl_queue = cgutils.create_struct_proxy(typ)(context, builder)
71-
# sycl_queue.ctx = ctx
72-
# sycl_queue.dev = dev
73-
# sycl_queue.property = property
74-
# else:
75-
# sycl_queue = cgutils.create_struct_proxy(typ)(context, builder)
76-
# return sycl_queue._getvalue()
77-
78-
79-
@unbox(SyclQueueType)
26+
@unbox(DpctlSyclQueue)
8027
def unbox_sycl_queue(typ, obj, c):
8128
"""
8229
Convert a SyclQueue object to a native structure.
@@ -89,7 +36,6 @@ def unbox_sycl_queue(typ, obj, c):
8936
errcode = dpexrtCtx.queuestruct_from_python(c.pyapi, obj, ptr)
9037
else:
9138
raise UnreachableError
92-
9339
is_error = cgutils.is_not_null(c.builder, errcode)
9440
# Handle error
9541
with c.builder.if_then(is_error, likely=False):

numba_dpex/core/typing/typeof.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from numba_dpex.utils import address_space
1313

14-
from ..types.dpctl_types import SyclQueueType
14+
from ..types.dpctl_types import DpctlSyclQueue
1515
from ..types.dpnp_ndarray_type import DpnpNdArray
1616
from ..types.usm_ndarray_type import USMNdArray
1717

@@ -107,6 +107,4 @@ def typeof_dpctl_sycl_queue(val, c):
107107
108108
Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclQueue instance.
109109
"""
110-
# return sycl_queue_type
111-
# return _typeof_helper(val, SyclQueueType)
112-
return SyclQueueType()
110+
return DpctlSyclQueue()

0 commit comments

Comments
 (0)