diff --git a/numba_dpex/dpctl_iface/_helpers.py b/numba_dpex/dpctl_iface/_helpers.py index 2814eb1aff..be8354b459 100644 --- a/numba_dpex/dpctl_iface/_helpers.py +++ b/numba_dpex/dpctl_iface/_helpers.py @@ -4,6 +4,8 @@ from numba.core import types +from numba_dpex import dpctl_sem_version + def numba_type_to_dpctl_typenum(context, ty): """ @@ -11,32 +13,52 @@ def numba_type_to_dpctl_typenum(context, ty): ``DPCTLKernelArgType``. """ - val = None - if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): - # DPCTL_LONG_LONG - val = context.get_constant(types.int32, 9) - elif ty == types.uint32: - # DPCTL_UNSIGNED_LONG_LONG - val = context.get_constant(types.int32, 10) - elif ty == types.boolean: - # DPCTL_UNSIGNED_INT - val = context.get_constant(types.int32, 5) - elif ty == types.int64: - # DPCTL_LONG_LONG - val = context.get_constant(types.int32, 9) - elif ty == types.uint64: - # DPCTL_SIZE_T - val = context.get_constant(types.int32, 11) - elif ty == types.float32: - # DPCTL_FLOAT - val = context.get_constant(types.int32, 12) - elif ty == types.float64: - # DPCTL_DOUBLE - val = context.get_constant(types.int32, 13) - elif ty == types.voidptr or isinstance(ty, types.CPointer): - # DPCTL_VOID_PTR - val = context.get_constant(types.int32, 15) - else: - raise NotImplementedError + if dpctl_sem_version >= (0, 17, 0): + # FIXME change to imports from a dpctl enum/class rather than + # hard coding these numbers. - return val + if ty == types.boolean: + return context.get_constant(types.int32, 1) + elif ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): + return context.get_constant(types.int32, 4) + elif ty == types.uint32: + return context.get_constant(types.int32, 5) + elif ty == types.int64: + return context.get_constant(types.int32, 6) + elif ty == types.uint64: + return context.get_constant(types.int32, 7) + elif ty == types.float32: + return context.get_constant(types.int32, 8) + elif ty == types.float64: + return context.get_constant(types.int32, 9) + elif ty == types.voidptr or isinstance(ty, types.CPointer): + return context.get_constant(types.int32, 10) + else: + raise NotImplementedError + else: + if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral): + # DPCTL_LONG_LONG + return context.get_constant(types.int32, 9) + elif ty == types.uint32: + # DPCTL_UNSIGNED_LONG_LONG + return context.get_constant(types.int32, 10) + elif ty == types.boolean: + # DPCTL_UNSIGNED_INT + return context.get_constant(types.int32, 5) + elif ty == types.int64: + # DPCTL_LONG_LONG + return context.get_constant(types.int32, 9) + elif ty == types.uint64: + # DPCTL_SIZE_T + return context.get_constant(types.int32, 11) + elif ty == types.float32: + # DPCTL_FLOAT + return context.get_constant(types.int32, 12) + elif ty == types.float64: + # DPCTL_DOUBLE + return context.get_constant(types.int32, 13) + elif ty == types.voidptr or isinstance(ty, types.CPointer): + # DPCTL_VOID_PTR + return context.get_constant(types.int32, 15) + else: + raise NotImplementedError