Skip to content

Commit

Permalink
利用axclrtSetDevice完成context的get/set,从而支持跨线程推理
Browse files Browse the repository at this point in the history
  • Loading branch information
zylo117 committed Dec 26, 2024
1 parent 96998c2 commit a65644d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
13 changes: 13 additions & 0 deletions axengine/_axcl_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""
#define AXCL_MAX_DEVICE_COUNT 256
typedef int32_t axclError;
typedef void *axclrtContext;
"""
)

Expand Down Expand Up @@ -149,6 +150,18 @@
"""
axclError axclrtGetDeviceList(axclrtDeviceList *deviceList);
axclError axclrtSetDevice(int32_t deviceId);
axclError axclrtResetDevice(int32_t deviceId);
"""
)

# axcl_rt_context.h
O.cdef(
"""
axclError axclrtCreateContext(axclrtContext *context, int32_t deviceId);
axclError axclrtDestroyContext(axclrtContext context);
axclError axclrtSetCurrentContext(axclrtContext context);
axclError axclrtGetCurrentContext(axclrtContext *context);
axclError axclrtGetDefaultContext(axclrtContext *context, int32_t deviceId);
"""
)

Expand Down
18 changes: 16 additions & 2 deletions axengine/axcl_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(

super(BaseInferenceSession).__init__()

self.device_id = device_id

# load shared library
self._rt_lib = _capi.R
self._rt_ffi = _capi.O
Expand All @@ -34,12 +36,18 @@ def __init__(
print(f"[INFO] SOC Name: {self.soc_name}")

# init axcl
self.axcl_device_id = -1 # axcl_device_id != device_id, device_id is just the index of the list of axcl_device_ids
ret = self._init(device_id)
if 0 != ret:
raise RuntimeError("Failed to initialize axclrt.")
print(f"[INFO] Runtime version: {self._get_version()}")

# handle, context, info, io
self._thread_context = self._rt_ffi.new("axclrtContext *")
ret = self._rt_lib.axclrtGetCurrentContext(self._thread_context)
if ret != 0:
raise RuntimeError("axclrtGetCurrentContext failed")

# model handle, context, info, io
self._handle = self._rt_ffi.new("uint64_t *")
self._context = self._rt_ffi.new("uint64_t *")
self.io_info = self._rt_ffi.new("axclrtEngineIOInfo *")
Expand Down Expand Up @@ -256,7 +264,8 @@ def _init(self, device_id=0, vnpu=VNPUType.DISABLED): # vnpu type, the default
if ret != 0 or lst.num == 0:
raise RuntimeError(f"Get AXCL device failed 0x{ret:08x}, find total {lst.num} device.")

ret = self._rt_lib.axclrtSetDevice(lst.devices[device_id])
self.axcl_device_id = lst.devices[device_id]
ret = self._rt_lib.axclrtSetDevice(self.axcl_device_id)
if ret != 0 or lst.num == 0:
raise RuntimeError(f"Set AXCL device failed 0x{ret:08x}.")

Expand All @@ -269,6 +278,7 @@ def _init(self, device_id=0, vnpu=VNPUType.DISABLED): # vnpu type, the default
def _final(self):
if self._handle[0] is not None:
self._unload()
self._rt_lib.axclrtResetDevice(self.axcl_device_id)
self._rt_lib.axclFinalize()
return

Expand Down Expand Up @@ -331,6 +341,10 @@ def run(self, output_names, input_feed, run_options=None):
self._validate_input(list(input_feed.keys()))
self._validate_output(output_names)

ret = self._rt_lib.axclrtSetCurrentContext(self._thread_context[0])
if ret != 0:
raise RuntimeError("axclrtSetCurrentContext failed")

if None is output_names:
output_names = [o.name for o in self.get_outputs()]

Expand Down

0 comments on commit a65644d

Please sign in to comment.