88
99from functools import cached_property
1010
11+ from llvmlite import ir as llvmir
12+ from numba .core import types
1113from numba .core .descriptors import TargetDescriptor
1214from numba .core .target_extension import GPU , target_registry
15+ from numba .core .types .scalars import IntEnumClass
1316
1417from numba_dpex .core .descriptor import DpexTargetOptions
1518from numba_dpex .core .targets .kernel_target import (
1821)
1922from numba_dpex .experimental .models import exp_dmm
2023
24+ from .flag_enum import FlagEnum
25+ from .literal_intenum_type import IntEnumLiteral
26+
2127
2228# pylint: disable=R0903
2329class SyclDeviceExp (GPU ):
@@ -39,6 +45,37 @@ class DpexExpKernelTypingContext(DpexKernelTypingContext):
3945 are stable enough to be migrated to DpexKernelTypingContext.
4046 """
4147
48+ def resolve_value_type (self , val ):
49+ """
50+ Return the numba type of a Python value that is being used
51+ as a runtime constant.
52+ ValueError is raised for unsupported types.
53+ """
54+
55+ ty = super ().resolve_value_type (val )
56+
57+ if isinstance (ty , IntEnumClass ) and issubclass (val , FlagEnum ):
58+ ty = IntEnumLiteral (val )
59+
60+ return ty
61+
62+ def resolve_getattr (self , typ , attr ):
63+ """
64+ Resolve getting the attribute *attr* (a string) on the Numba type.
65+ The attribute's type is returned, or None if resolution failed.
66+ """
67+ ty = None
68+
69+ if isinstance (typ , IntEnumLiteral ):
70+ try :
71+ attrval = getattr (typ .literal_value , attr ).value
72+ ty = types .IntegerLiteral (attrval )
73+ except ValueError :
74+ pass
75+ else :
76+ ty = super ().resolve_getattr (typ , attr )
77+ return ty
78+
4279
4380# pylint: disable=W0223
4481# FIXME: Remove the pylint disablement once we add an override for
@@ -56,6 +93,22 @@ def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
5693 super ().__init__ (typingctx , target )
5794 self .data_model_manager = exp_dmm
5895
96+ def get_getattr (self , typ , attr ):
97+ """
98+ Overrides the get_getattr function to provide an implementation for
99+ getattr call on an IntegerEnumLiteral type.
100+ """
101+
102+ if isinstance (typ , IntEnumLiteral ):
103+ # pylint: disable=W0613
104+ def enum_literal_getattr_imp (context , builder , typ , val , attr ):
105+ enum_attr_value = getattr (typ .literal_value , attr ).value
106+ return llvmir .Constant (llvmir .IntType (64 ), enum_attr_value )
107+
108+ return enum_literal_getattr_imp
109+
110+ return super ().get_getattr (typ , attr )
111+
59112
60113class DpexExpKernelTarget (TargetDescriptor ):
61114 """
0 commit comments