Skip to content

Commit 22d8cde

Browse files
author
Diptorup Deb
committed
Port test_dump_kernel_llvm to experimental
1 parent 94c7106 commit 22d8cde

File tree

1 file changed

+29
-47
lines changed

1 file changed

+29
-47
lines changed

numba_dpex/tests/kernel_tests/test_dump_kernel_llvm.py

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,38 @@
77
import hashlib
88
import os
99

10-
import numba_dpex as dpex
10+
from numba.core import types
11+
12+
import numba_dpex.experimental as dpex
1113
from numba_dpex import float32, usm_ndarray
1214
from numba_dpex.core import config
1315
from numba_dpex.core.descriptor import dpex_kernel_target
16+
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
1417

1518
f32arrty = usm_ndarray(ndim=1, dtype=float32, layout="C")
19+
itemty = ItemType(ndim=1)
1620

1721

18-
def _get_kernel_llvm(fn, sig, debug=False):
19-
kernel = dpex.core.kernel_interface.spirv_kernel.SpirvKernel(
20-
fn, fn.__name__
21-
)
22-
kernel.compile(
23-
args=sig,
24-
target_ctx=dpex_kernel_target.target_context,
25-
typing_ctx=dpex_kernel_target.typing_context,
26-
debug=debug,
27-
compile_flags=None,
22+
def _get_kernel_llvm():
23+
def data_parallel_sum(item, var_a, var_b, var_c):
24+
i = item.get_id(0)
25+
var_c[i] = var_a[i] + var_b[i]
26+
27+
sig = (itemty, f32arrty, f32arrty, f32arrty)
28+
29+
disp = dpex.kernel(sig)(data_parallel_sum)
30+
kcres = disp.get_compile_result(
31+
types.void(itemty, f32arrty, f32arrty, f32arrty)
2832
)
29-
return kernel.module_name, kernel.llvm_module
33+
llvm_module_str = kcres.library.get_llvm_str()
34+
name = kcres.fndesc.llvm_func_name
35+
if len(name) > 200:
36+
sha256 = hashlib.sha256(name.encode("utf-8")).hexdigest()
37+
name = name[:150] + "_" + sha256
38+
39+
dump_file_name = name + ".ll"
40+
41+
return dump_file_name, llvm_module_str
3042

3143

3244
def test_dump_file_on_dump_kernel_llvm_flag_on():
@@ -37,50 +49,20 @@ def test_dump_file_on_dump_kernel_llvm_flag_on():
3749
and compare with llvm source stored in SprivKernel.
3850
"""
3951

40-
def data_parallel_sum(var_a, var_b, var_c):
41-
i = dpex.get_global_id(0)
42-
var_c[i] = var_a[i] + var_b[i]
43-
44-
sig = (f32arrty, f32arrty, f32arrty)
45-
4652
config.DUMP_KERNEL_LLVM = True
47-
48-
llvm_module_name, llvm_module_str = _get_kernel_llvm(data_parallel_sum, sig)
49-
50-
dump_file_name = (
51-
"llvm_kernel_"
52-
+ hashlib.sha256(llvm_module_name.encode()).hexdigest()
53-
+ ".ll"
54-
)
55-
53+
dump_file_name, llvm_dumped_str = _get_kernel_llvm()
5654
with open(dump_file_name, "r") as f:
57-
llvm_dump = f.read()
58-
59-
assert llvm_module_str == llvm_dump
60-
55+
ondisk_llvm_dump_str = f.read()
56+
assert llvm_dumped_str == ondisk_llvm_dump_str
6157
os.remove(dump_file_name)
58+
config.DUMP_KERNEL_LLVM = False
6259

6360

6461
def test_no_dump_file_on_dump_kernel_llvm_flag_off():
6562
"""
6663
Test functionality of DUMP_KERNEL_LLVM config variable.
6764
Check llvm source is not dumped in .ll file in current directory.
6865
"""
69-
70-
def data_parallel_sum(var_a, var_b, var_c):
71-
i = dpex.get_global_id(0)
72-
var_c[i] = var_a[i] + var_b[i]
73-
74-
sig = (f32arrty, f32arrty, f32arrty)
75-
7666
config.DUMP_KERNEL_LLVM = False
77-
78-
llvm_module_name, llvm_module_str = _get_kernel_llvm(data_parallel_sum, sig)
79-
80-
dump_file_name = (
81-
"llvm_kernel_"
82-
+ hashlib.sha256(llvm_module_name.encode()).hexdigest()
83-
+ ".ll"
84-
)
85-
67+
dump_file_name, llvm_dumped_str = _get_kernel_llvm()
8668
assert not os.path.isfile(dump_file_name)

0 commit comments

Comments
 (0)