forked from AXERA-TECH/pyaxengine
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
在axengine基础上增加axcl的支持 (AXERA-TECH#13)
* rebase 57a6028 * example适配双backend * 适配旧版python * assert backend in ['ax', 'axcl'] * update readme * 自动判断是否有axclrt lib以及device no(默认0)是否大于等于0,如果是则用axcl,否则ax * device_id>=0但是没有axcl设备时用ax * 自动判断平台 * 调整import * 调整import * update readme * add check for inputs on axcl (AXERA-TECH#14)
- Loading branch information
Showing
11 changed files
with
1,096 additions
and
400 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
# | ||
|
||
from . import _types | ||
from ._capi import E as _lib | ||
from ._ax_capi import E as _lib | ||
|
||
__all__: ["T"] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved. | ||
# | ||
# This source file is the property of Axera Semiconductor Co., Ltd. and | ||
# may not be copied or distributed in any isomorphic form without the prior | ||
# written consent of Axera Semiconductor Co., Ltd. | ||
# | ||
# modified by zylo117 | ||
|
||
import ctypes.util | ||
import platform | ||
|
||
from cffi import FFI | ||
|
||
__all__: ["R", "O"] | ||
|
||
O = FFI() | ||
|
||
# axcl_base.h | ||
O.cdef( | ||
""" | ||
#define AXCL_MAX_DEVICE_COUNT 256 | ||
typedef int32_t axclError; | ||
""" | ||
) | ||
|
||
# axcl_rt_type.h | ||
O.cdef( | ||
""" | ||
typedef struct axclrtDeviceList { | ||
uint32_t num; | ||
int32_t devices[AXCL_MAX_DEVICE_COUNT]; | ||
} axclrtDeviceList; | ||
typedef enum axclrtMemMallocPolicy { | ||
AXCL_MEM_MALLOC_HUGE_FIRST, | ||
AXCL_MEM_MALLOC_HUGE_ONLY, | ||
AXCL_MEM_MALLOC_NORMAL_ONLY | ||
} axclrtMemMallocPolicy; | ||
typedef enum axclrtMemcpyKind { | ||
AXCL_MEMCPY_HOST_TO_HOST, | ||
AXCL_MEMCPY_HOST_TO_DEVICE, //!< host vir -> device phy | ||
AXCL_MEMCPY_DEVICE_TO_HOST, //!< host vir <- device phy | ||
AXCL_MEMCPY_DEVICE_TO_DEVICE, | ||
AXCL_MEMCPY_HOST_PHY_TO_DEVICE, //!< host phy -> device phy | ||
AXCL_MEMCPY_DEVICE_TO_HOST_PHY, //!< host phy <- device phy | ||
} axclrtMemcpyKind; | ||
""" | ||
) | ||
|
||
# axcl_rt_engine_type.h | ||
O.cdef( | ||
""" | ||
#define AXCLRT_ENGINE_MAX_DIM_CNT 32 | ||
typedef void* axclrtEngineIOInfo; | ||
typedef void* axclrtEngineIO; | ||
typedef enum axclrtEngineVNpuKind { | ||
AXCL_VNPU_DISABLE = 0, | ||
AXCL_VNPU_ENABLE = 1, | ||
AXCL_VNPU_BIG_LITTLE = 2, | ||
AXCL_VNPU_LITTLE_BIG = 3, | ||
} axclrtEngineVNpuKind; | ||
typedef struct axclrtEngineIODims { | ||
int32_t dimCount; | ||
int32_t dims[AXCLRT_ENGINE_MAX_DIM_CNT]; | ||
} axclrtEngineIODims; | ||
""" | ||
) | ||
|
||
# ax_model_runner_axcl.cpp | ||
O.cdef( | ||
""" | ||
typedef enum | ||
{ | ||
AX_ENGINE_ABST_DEFAULT = 0, | ||
AX_ENGINE_ABST_CACHED = 1, | ||
} AX_ENGINE_ALLOC_BUFFER_STRATEGY_T; | ||
typedef struct | ||
{ | ||
int nIndex; | ||
int nSize; | ||
void *pBuf; | ||
void *pVirAddr; | ||
const char *Name; | ||
axclrtEngineIODims dims; | ||
} AXCL_IO_BUF_T; | ||
typedef struct | ||
{ | ||
uint32_t nInputSize; | ||
uint32_t nOutputSize; | ||
AXCL_IO_BUF_T *pInputs; | ||
AXCL_IO_BUF_T *pOutputs; | ||
} AXCL_IO_DATA_T; | ||
""" | ||
) | ||
|
||
# ax_model_runner.hpp | ||
O.cdef( | ||
""" | ||
typedef struct | ||
{ | ||
const char * sName; | ||
unsigned int nIdx; | ||
unsigned int vShape[AXCLRT_ENGINE_MAX_DIM_CNT]; | ||
unsigned int vShapeSize; | ||
int nSize; | ||
unsigned long long phyAddr; | ||
void *pVirAddr; | ||
} ax_runner_tensor_t; | ||
""" | ||
) | ||
|
||
# stdlib.h/string.h | ||
O.cdef( | ||
""" | ||
void free (void *__ptr); | ||
void *malloc(size_t size); | ||
void *memset (void *__s, int __c, size_t __n); | ||
void *memcpy (void * __dest, const void * __src, size_t __n); | ||
""" | ||
) | ||
|
||
|
||
|
||
# axcl.h | ||
O.cdef( | ||
""" | ||
axclError axclInit(const char *config); | ||
axclError axclFinalize(); | ||
""" | ||
) | ||
|
||
# axcl_rt.h | ||
O.cdef( | ||
""" | ||
axclError axclrtGetVersion(int32_t *major, int32_t *minor, int32_t *patch); | ||
const char *axclrtGetSocName(); | ||
""" | ||
) | ||
|
||
# axcl_rt_device.h | ||
O.cdef( | ||
""" | ||
axclError axclrtGetDeviceList(axclrtDeviceList *deviceList); | ||
axclError axclrtSetDevice(int32_t deviceId); | ||
""" | ||
) | ||
|
||
# axcl_rt_engine.h | ||
O.cdef( | ||
""" | ||
axclError axclrtEngineInit(axclrtEngineVNpuKind npuKind); | ||
axclError axclrtEngineLoadFromMem(const void *model, uint64_t modelSize, uint64_t *modelId); | ||
axclError axclrtEngineCreateContext(uint64_t modelId, uint64_t *contextId); | ||
axclError axclrtEngineGetVNpuKind(axclrtEngineVNpuKind *npuKind); | ||
const char* axclrtEngineGetModelCompilerVersion(uint64_t modelId); | ||
axclError axclrtEngineGetIOInfo(uint64_t modelId, axclrtEngineIOInfo *ioInfo); | ||
axclError axclrtEngineGetShapeGroupsCount(axclrtEngineIOInfo ioInfo, int32_t *count); | ||
axclError axclrtEngineCreateIO(axclrtEngineIOInfo ioInfo, axclrtEngineIO *io); | ||
uint32_t axclrtEngineGetNumInputs(axclrtEngineIOInfo ioInfo); | ||
uint32_t axclrtEngineGetNumOutputs(axclrtEngineIOInfo ioInfo); | ||
uint64_t axclrtEngineGetInputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index); | ||
axclError axclrtEngineGetInputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims); | ||
const char *axclrtEngineGetInputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index); | ||
axclError axclrtEngineSetInputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size); | ||
uint64_t axclrtEngineGetOutputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index); | ||
axclError axclrtEngineGetOutputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims); | ||
const char *axclrtEngineGetOutputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index); | ||
axclError axclrtEngineSetOutputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size); | ||
axclError axclrtEngineExecute(uint64_t modelId, uint64_t contextId, uint32_t group, axclrtEngineIO io); | ||
axclError axclrtEngineDestroyIO(axclrtEngineIO io); | ||
axclError axclrtEngineUnload(uint64_t modelId); | ||
""" | ||
) | ||
|
||
# axcl_rt_memory.h | ||
O.cdef( | ||
""" | ||
axclError axclrtMalloc(void **devPtr, size_t size, axclrtMemMallocPolicy policy); | ||
axclError axclrtMallocCached(void **devPtr, size_t size, axclrtMemMallocPolicy policy); | ||
axclError axclrtMemcpy(void *dstPtr, const void *srcPtr, size_t count, axclrtMemcpyKind kind); | ||
axclError axclrtFree(void *devPtr); | ||
axclError axclrtMemFlush(void *devPtr, size_t size); | ||
""" | ||
) | ||
|
||
rt_name = "axcl_rt" | ||
rt_path = ctypes.util.find_library(rt_name) | ||
assert ( | ||
rt_path is not None | ||
), f"Failed to find library {rt_name}. Please ensure it is installed and in the library path." | ||
|
||
R = O.dlopen(rt_path) | ||
assert R is not None, f"Failed to load library {rt_path}. Please ensure it is installed and in the library path." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from ._node import NodeArg | ||
from ._types import VNPUType | ||
|
||
import numpy as np | ||
|
||
class BaseInferenceSession: | ||
def __init__(self, *args, **kwargs) -> None: | ||
self._shape_count = 0 | ||
self._inputs = [] | ||
self._outputs = [] | ||
|
||
def __del__(self): | ||
self._final() | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
self._final() | ||
|
||
def _init(self, *args, **kwargs): | ||
return | ||
|
||
def _final(self): | ||
return | ||
|
||
def _get_version(self) -> str: | ||
return '' | ||
|
||
def _get_vnpu_type(self) -> VNPUType: | ||
return VNPUType(0) | ||
|
||
def _get_model_tool_version(self) -> str: | ||
return '' | ||
|
||
def _load(self) -> 0: | ||
return 0 | ||
|
||
def _get_shape_count(self) -> int: | ||
return 0 | ||
|
||
def _unload(self): | ||
return | ||
|
||
def get_inputs(self, shape_group=0) -> list[NodeArg]: | ||
if shape_group > self._shape_count: | ||
raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.") | ||
selected_info = self._inputs[shape_group] | ||
return selected_info | ||
|
||
def get_outputs(self, shape_group=0) -> list[NodeArg]: | ||
if shape_group > self._shape_count: | ||
raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.") | ||
selected_info = self._outputs[shape_group] | ||
return selected_info | ||
|
||
# copy from onnxruntime | ||
def _validate_input(self, feed_input_names): | ||
missing_input_names = [] | ||
for i in self.get_inputs(): | ||
if i.name not in feed_input_names: | ||
missing_input_names.append(i.name) | ||
if missing_input_names: | ||
raise ValueError( | ||
f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names})." | ||
) | ||
|
||
def _validate_output(self, output_names): | ||
if output_names is not None: | ||
for name in output_names: | ||
if name not in [o.name for o in self.get_outputs()]: | ||
raise ValueError(f"Output name '{name}' is not registered.") | ||
|
||
def run(self, output_names, input_feed, run_options=None) -> list[np.ndarray]: | ||
return [] |
Oops, something went wrong.