diff --git a/third_party/xpu/backend/driver.py b/third_party/xpu/backend/driver.py index 3ad923e9e3..236b4f0adb 100644 --- a/third_party/xpu/backend/driver.py +++ b/third_party/xpu/backend/driver.py @@ -194,13 +194,13 @@ def _generate_src(): bool update(sycl::queue sycl_queue) { // Get l0-context auto sycl_context = sycl_queue.get_context(); - ze_context_handle_t hCtxt = get_native(sycl_context); + ze_context_handle_t hCtxt = get_native(sycl_context); // Get l0-device std::vector sycl_devices = sycl_context.get_devices(); - ze_device_handle_t hDev = get_native(sycl_devices[0]); + ze_device_handle_t hDev = get_native(sycl_devices[0]); // Get l0-queue bool immediate_cmd_list = false; - std::variant queue_var = get_native(sycl_queue); + std::variant queue_var = get_native(sycl_queue); auto l0_queue = std::get_if(&queue_var); if (l0_queue == nullptr) { auto imm_cmd_list = std::get_if(&queue_var); @@ -218,15 +218,18 @@ def _generate_src(): context = sycl_queue_map[sycl_queue].context; uint32_t deviceCount = std::min(sycl_devices.size(), devices.size()); for (uint32_t i = 0; i < deviceCount; ++i) { - devices[i] = sycl::get_native(sycl_devices[i]); + devices[i] = sycl::get_native(sycl_devices[i]); } return true; } static PyObject* initContext(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) + PyObject *cap; + void* queue = NULL; + if(!PyArg_ParseTuple(args, "O", &cap)) + return NULL; + if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap)))) return NULL; sycl::queue* sycl_queue = static_cast(queue); if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { @@ -251,8 +254,11 @@ def _generate_src(): } static PyObject* initDevices(PyObject* self, PyObject *args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) + PyObject *cap; + void* queue = NULL; + if(!PyArg_ParseTuple(args, "O", &cap)) + return NULL; + if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap)))) return NULL; sycl::queue* sycl_queue = static_cast(queue); @@ -264,7 +270,7 @@ def _generate_src(): // Retrieve devices uint32_t deviceCount = sycl_devices.size(); for (uint32_t i = 0; i < deviceCount; ++i) { - devices.push_back(sycl::get_native(sycl_devices[i])); + devices.push_back(sycl::get_native(sycl_devices[i])); } // npy_intp dims[1]; @@ -280,8 +286,11 @@ def _generate_src(): } static PyObject* getL0ImmCommandList(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) + PyObject *cap; + void* queue = NULL; + if(!PyArg_ParseTuple(args, "O", &cap)) + return NULL; + if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap)))) return NULL; sycl::queue* sycl_queue = static_cast(queue); @@ -291,8 +300,11 @@ def _generate_src(): return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].cmd_list)); } static PyObject* getL0Queue(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) + PyObject *cap; + void* queue = NULL; + if(!PyArg_ParseTuple(args, "O", &cap)) + return NULL; + if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap)))) return NULL; sycl::queue* sycl_queue = static_cast(queue); if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { @@ -301,8 +313,11 @@ def _generate_src(): return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].queue)); } static PyObject* getL0DevPtr(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) + PyObject *cap; + void* queue = NULL; + if(!PyArg_ParseTuple(args, "O", &cap)) + return NULL; + if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap)))) return NULL; sycl::queue* sycl_queue = static_cast(queue); if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { @@ -311,8 +326,11 @@ def _generate_src(): return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].device)); } static PyObject* getL0CtxtPtr(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) + PyObject *cap; + void* queue = NULL; + if(!PyArg_ParseTuple(args, "O", &cap)) + return NULL; + if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap)))) return NULL; sycl::queue* sycl_queue = static_cast(queue); if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { @@ -321,8 +339,11 @@ def _generate_src(): return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].context)); } static PyObject* isUsingICL(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) + PyObject *cap; + void* queue = NULL; + if(!PyArg_ParseTuple(args, "O", &cap)) + return NULL; + if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap)))) return NULL; sycl::queue* sycl_queue = static_cast(queue); if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) {