88
99from functools import cached_property
1010
11+ from llvmlite import ir as llvmir
12+ from numba .core import types
1113from numba .core .descriptors import TargetDescriptor
1214from numba .core .target_extension import GPU , target_registry
15+ from numba .core .types .scalars import IntEnumClass
1316
1417from numba_dpex .core .descriptor import DpexTargetOptions
1518from numba_dpex .core .targets .kernel_target import (
1821)
1922from numba_dpex .experimental .models import exp_dmm
2023
24+ from .flag_enum import FlagEnum
25+ from .literal_intenum_type import IntEnumLiteral
26+
2127
2228# pylint: disable=R0903
2329class SyclDeviceExp (GPU ):
@@ -39,6 +45,37 @@ class DpexExpKernelTypingContext(DpexKernelTypingContext):
3945 are stable enough to be migrated to DpexKernelTypingContext.
4046 """
4147
48+ def resolve_value_type (self , val ):
49+ """
50+ Return the numba type of a Python value that is being used
51+ as a runtime constant.
52+ ValueError is raised for unsupported types.
53+ """
54+
55+ ty = super ().resolve_value_type (val )
56+
57+ if isinstance (ty , IntEnumClass ) and issubclass (val , FlagEnum ):
58+ ty = IntEnumLiteral (val )
59+
60+ return ty
61+
62+ def resolve_getattr (self , typ , attr ):
63+ """
64+ Resolve getting the attribute *attr* (a string) on the Numba type.
65+ The attribute's type is returned, or None if resolution failed.
66+ """
67+ ty = None
68+
69+ if isinstance (typ , IntEnumLiteral ):
70+ try :
71+ attrval = getattr (typ .literal_value , attr ).value
72+ ty = types .IntegerLiteral (attrval )
73+ except ValueError :
74+ pass
75+ else :
76+ ty = super ().resolve_getattr (typ , attr )
77+ return ty
78+
4279
4380# pylint: disable=W0223
4481# FIXME: Remove the pylint disablement once we add an override for
@@ -52,10 +89,28 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext):
5289 they are stable enough to be migrated to DpexKernelTargetContext.
5390 """
5491
92+ allow_dynamic_globals = True
93+
5594 def __init__ (self , typingctx , target = DPEX_KERNEL_EXP_TARGET_NAME ):
5695 super ().__init__ (typingctx , target )
5796 self .data_model_manager = exp_dmm
5897
98+ def get_getattr (self , typ , attr ):
99+ """
100+ Overrides the get_getattr function to provide an implementation for
101+ getattr call on an IntegerEnumLiteral type.
102+ """
103+
104+ if isinstance (typ , IntEnumLiteral ):
105+ # pylint: disable=W0613
106+ def enum_literal_getattr_imp (context , builder , typ , val , attr ):
107+ enum_attr_value = getattr (typ .literal_value , attr ).value
108+ return llvmir .Constant (llvmir .IntType (64 ), enum_attr_value )
109+
110+ return enum_literal_getattr_imp
111+
112+ return super ().get_getattr (typ , attr )
113+
59114
60115class DpexExpKernelTarget (TargetDescriptor ):
61116 """
0 commit comments