Skip to content

Commit c823e04

Browse files
author
Diptorup Deb
committed
UPdated unit tests
1 parent 7d894a5 commit c823e04

File tree

2 files changed

+45
-24
lines changed

2 files changed

+45
-24
lines changed

numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import dpctl
65
import dpnp
7-
from numba.core import types
86

97
import numba_dpex.experimental as exp_dpex
10-
from numba_dpex import DpctlSyclQueue, DpnpNdArray, Range, int64
8+
from numba_dpex import Range
119
from numba_dpex.experimental.flag_enum import FlagEnum
1210

1311

@@ -36,24 +34,3 @@ def test_compilation_of_flag_enum():
3634
assert a[1] == MockFlags.FLAG2
3735
for idx in range(2, 9):
3836
assert a[idx] == 1
39-
40-
41-
def test_compilation_as_literal_constant():
42-
@exp_dpex.device_func
43-
def bitwise_or_flags(flag1, flag2):
44-
return flag1 | flag2
45-
46-
def pass_flags_to_func(a):
47-
f1 = MockFlags.FLAG1
48-
f2 = MockFlags.FLAG2
49-
a[0] = bitwise_or_flags(f1, f2)
50-
51-
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
52-
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
53-
kernel_sig = types.void(i64arr_ty)
54-
55-
disp = exp_dpex.kernel(pass_flags_to_func)
56-
disp.compile(kernel_sig)
57-
kcres = disp.overloads[kernel_sig.args]
58-
llvm_ir_mod = kcres.library._final_module
59-
print(llvm_ir_mod)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import re
6+
7+
import dpctl
8+
from numba.core import types
9+
10+
import numba_dpex.experimental as exp_dpex
11+
from numba_dpex import DpctlSyclQueue, DpnpNdArray, int64
12+
from numba_dpex.experimental.flag_enum import FlagEnum
13+
14+
15+
class MockFlags(FlagEnum):
16+
FLAG1 = 1
17+
FLAG2 = 2
18+
19+
20+
def test_compilation_as_literal_constant():
21+
@exp_dpex.device_func
22+
def bitwise_or_flags(flag1, flag2):
23+
return flag1 | flag2
24+
25+
def pass_flags_to_func(a):
26+
f1 = MockFlags.FLAG1
27+
f2 = MockFlags.FLAG2
28+
a[0] = bitwise_or_flags(f1, f2)
29+
30+
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
31+
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
32+
kernel_sig = types.void(i64arr_ty)
33+
34+
disp = exp_dpex.kernel(pass_flags_to_func)
35+
disp.compile(kernel_sig)
36+
kcres = disp.overloads[kernel_sig.args]
37+
llvm_ir_mod = kcres.library._final_module.__str__()
38+
39+
pattern = re.compile(
40+
r"call spir_func i32 @\_Z.*bitwise\_or"
41+
r"\_flags.*\(i64\* nonnull %.*, i64 1, i64 2\)"
42+
)
43+
44+
assert re.search(pattern, llvm_ir_mod) is not None

0 commit comments

Comments
 (0)