Skip to content

Commit

Permalink
Update level_zero
Browse files Browse the repository at this point in the history
  • Loading branch information
quintinwang5 committed Jan 16, 2024
1 parent 6a378c7 commit 6979d7e
Showing 1 changed file with 40 additions and 19 deletions.
59 changes: 40 additions & 19 deletions third_party/xpu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,13 @@ def _generate_src():
bool update(sycl::queue sycl_queue) {
// Get l0-context
auto sycl_context = sycl_queue.get_context();
ze_context_handle_t hCtxt = get_native<sycl::backend::level_zero>(sycl_context);
ze_context_handle_t hCtxt = get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
// Get l0-device
std::vector<sycl::device> sycl_devices = sycl_context.get_devices();
ze_device_handle_t hDev = get_native<sycl::backend::level_zero>(sycl_devices[0]);
ze_device_handle_t hDev = get_native<sycl::backend::ext_oneapi_level_zero>(sycl_devices[0]);
// Get l0-queue
bool immediate_cmd_list = false;
std::variant<ze_command_queue_handle_t, ze_command_list_handle_t> queue_var = get_native<sycl::backend::level_zero>(sycl_queue);
std::variant<ze_command_queue_handle_t, ze_command_list_handle_t> queue_var = get_native<sycl::backend::ext_oneapi_level_zero>(sycl_queue);
auto l0_queue = std::get_if<ze_command_queue_handle_t>(&queue_var);
if (l0_queue == nullptr) {
auto imm_cmd_list = std::get_if<ze_command_list_handle_t>(&queue_var);
Expand All @@ -218,15 +218,18 @@ def _generate_src():
context = sycl_queue_map[sycl_queue].context;
uint32_t deviceCount = std::min(sycl_devices.size(), devices.size());
for (uint32_t i = 0; i < deviceCount; ++i) {
devices[i] = sycl::get_native<sycl::backend::level_zero>(sycl_devices[i]);
devices[i] = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_devices[i]);
}
return true;
}
static PyObject* initContext(PyObject* self, PyObject* args) {
void* queue;
if(!PyArg_ParseTuple(args, "K", &queue))
PyObject *cap;
void* queue = NULL;
if(!PyArg_ParseTuple(args, "O", &cap))
return NULL;
if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap))))
return NULL;
sycl::queue* sycl_queue = static_cast<sycl::queue*>(queue);
if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) {
Expand All @@ -251,8 +254,11 @@ def _generate_src():
}
static PyObject* initDevices(PyObject* self, PyObject *args) {
void* queue;
if(!PyArg_ParseTuple(args, "K", &queue))
PyObject *cap;
void* queue = NULL;
if(!PyArg_ParseTuple(args, "O", &cap))
return NULL;
if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap))))
return NULL;
sycl::queue* sycl_queue = static_cast<sycl::queue*>(queue);
Expand All @@ -264,7 +270,7 @@ def _generate_src():
// Retrieve devices
uint32_t deviceCount = sycl_devices.size();
for (uint32_t i = 0; i < deviceCount; ++i) {
devices.push_back(sycl::get_native<sycl::backend::level_zero>(sycl_devices[i]));
devices.push_back(sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_devices[i]));
}
// npy_intp dims[1];
Expand All @@ -280,8 +286,11 @@ def _generate_src():
}
static PyObject* getL0ImmCommandList(PyObject* self, PyObject* args) {
void* queue;
if(!PyArg_ParseTuple(args, "K", &queue))
PyObject *cap;
void* queue = NULL;
if(!PyArg_ParseTuple(args, "O", &cap))
return NULL;
if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap))))
return NULL;
sycl::queue* sycl_queue = static_cast<sycl::queue*>(queue);
Expand All @@ -291,8 +300,11 @@ def _generate_src():
return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].cmd_list));
}
static PyObject* getL0Queue(PyObject* self, PyObject* args) {
void* queue;
if(!PyArg_ParseTuple(args, "K", &queue))
PyObject *cap;
void* queue = NULL;
if(!PyArg_ParseTuple(args, "O", &cap))
return NULL;
if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap))))
return NULL;
sycl::queue* sycl_queue = static_cast<sycl::queue*>(queue);
if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) {
Expand All @@ -301,8 +313,11 @@ def _generate_src():
return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].queue));
}
static PyObject* getL0DevPtr(PyObject* self, PyObject* args) {
void* queue;
if(!PyArg_ParseTuple(args, "K", &queue))
PyObject *cap;
void* queue = NULL;
if(!PyArg_ParseTuple(args, "O", &cap))
return NULL;
if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap))))
return NULL;
sycl::queue* sycl_queue = static_cast<sycl::queue*>(queue);
if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) {
Expand All @@ -311,8 +326,11 @@ def _generate_src():
return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].device));
}
static PyObject* getL0CtxtPtr(PyObject* self, PyObject* args) {
void* queue;
if(!PyArg_ParseTuple(args, "K", &queue))
PyObject *cap;
void* queue = NULL;
if(!PyArg_ParseTuple(args, "O", &cap))
return NULL;
if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap))))
return NULL;
sycl::queue* sycl_queue = static_cast<sycl::queue*>(queue);
if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) {
Expand All @@ -321,8 +339,11 @@ def _generate_src():
return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].context));
}
static PyObject* isUsingICL(PyObject* self, PyObject* args) {
void* queue;
if(!PyArg_ParseTuple(args, "K", &queue))
PyObject *cap;
void* queue = NULL;
if(!PyArg_ParseTuple(args, "O", &cap))
return NULL;
if(!(queue = PyCapsule_GetPointer(cap, PyCapsule_GetName(cap))))
return NULL;
sycl::queue* sycl_queue = static_cast<sycl::queue*>(queue);
if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) {
Expand Down

0 comments on commit 6979d7e

Please sign in to comment.