Skip to content

Commit ec3e962

Browse files
author
Diptorup Deb
committed
Unit test checking if FlagEnum values are lowered as constants in LLVM IR.
1 parent 9ce9927 commit ec3e962

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import re
6+
7+
import dpctl
8+
from numba.core import types
9+
10+
import numba_dpex.experimental as exp_dpex
11+
from numba_dpex import DpctlSyclQueue, DpnpNdArray, int64
12+
from numba_dpex.experimental.flag_enum import FlagEnum
13+
14+
15+
class MockFlags(FlagEnum):
16+
FLAG1 = 1
17+
FLAG2 = 2
18+
19+
20+
def test_compilation_as_literal_constant():
21+
"""Tests if FlagEnum objects are treaded as scalar constants inside
22+
numba-dpex generated code.
23+
24+
The test case compiles the kernel `pass_flags_to_func` that includes a
25+
call to the device_func `bitwise_or_flags`. The `bitwise_or_flags` function
26+
is passed two FlagEnum arguments. The test case evaluates the generated
27+
LLVM IR for `pass_flags_to_func` to see if the call to `bitwise_or_flags`
28+
has the scalar arguments `i64 1` and `i64 2`.
29+
"""
30+
31+
@exp_dpex.device_func
32+
def bitwise_or_flags(flag1, flag2):
33+
return flag1 | flag2
34+
35+
def pass_flags_to_func(a):
36+
f1 = MockFlags.FLAG1
37+
f2 = MockFlags.FLAG2
38+
a[0] = bitwise_or_flags(f1, f2)
39+
40+
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
41+
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
42+
kernel_sig = types.void(i64arr_ty)
43+
44+
disp = exp_dpex.kernel(pass_flags_to_func)
45+
disp.compile(kernel_sig)
46+
kcres = disp.overloads[kernel_sig.args]
47+
llvm_ir_mod = kcres.library._final_module.__str__()
48+
49+
pattern = re.compile(
50+
r"call spir_func i32 @\_Z.*bitwise\_or"
51+
r"\_flags.*\(i64\* nonnull %.*, i64 1, i64 2\)"
52+
)
53+
54+
assert re.search(pattern, llvm_ir_mod) is not None

0 commit comments

Comments
 (0)