Skip to content

Commit 7d894a5

Browse files
author
Diptorup Deb
committed
WIP unit tests...
1 parent 1ade40d commit 7d894a5

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

numba_dpex/experimental/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from numba.core.imputils import Registry
1010

11-
from .decorators import kernel
11+
from .decorators import device_func, kernel
1212
from .kernel_dispatcher import KernelDispatcher
1313
from .launcher import call_kernel, call_kernel_async
1414
from .literal_intenum_type import IntEnumLiteral
@@ -28,6 +28,7 @@ def dpex_dispatcher_const(context):
2828

2929

3030
__all__ = [
31+
"device_func",
3132
"kernel",
3233
"call_kernel",
3334
"call_kernel_async",

numba_dpex/experimental/target.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext):
8989
they are stable enough to be migrated to DpexKernelTargetContext.
9090
"""
9191

92+
allow_dynamic_globals = True
93+
9294
def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME):
9395
super().__init__(typingctx, target)
9496
self.data_model_manager = exp_dmm

numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import dpctl
56
import dpnp
7+
from numba.core import types
68

79
import numba_dpex.experimental as exp_dpex
8-
from numba_dpex import Range
10+
from numba_dpex import DpctlSyclQueue, DpnpNdArray, Range, int64
911
from numba_dpex.experimental.flag_enum import FlagEnum
1012

1113

1214
class MockFlags(FlagEnum):
13-
FLAG1 = 100
14-
FLAG2 = 200
15+
FLAG1 = 1
16+
FLAG2 = 2
1517

1618

1719
@exp_dpex.kernel(
@@ -34,3 +36,24 @@ def test_compilation_of_flag_enum():
3436
assert a[1] == MockFlags.FLAG2
3537
for idx in range(2, 9):
3638
assert a[idx] == 1
39+
40+
41+
def test_compilation_as_literal_constant():
42+
@exp_dpex.device_func
43+
def bitwise_or_flags(flag1, flag2):
44+
return flag1 | flag2
45+
46+
def pass_flags_to_func(a):
47+
f1 = MockFlags.FLAG1
48+
f2 = MockFlags.FLAG2
49+
a[0] = bitwise_or_flags(f1, f2)
50+
51+
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
52+
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
53+
kernel_sig = types.void(i64arr_ty)
54+
55+
disp = exp_dpex.kernel(pass_flags_to_func)
56+
disp.compile(kernel_sig)
57+
kcres = disp.overloads[kernel_sig.args]
58+
llvm_ir_mod = kcres.library._final_module
59+
print(llvm_ir_mod)

0 commit comments

Comments
 (0)