77import hashlib
88import os
99
10- import numba_dpex as dpex
10+ from numba .core import types
11+
12+ import numba_dpex .experimental as dpex
1113from numba_dpex import float32 , usm_ndarray
1214from numba_dpex .core import config
1315from numba_dpex .core .descriptor import dpex_kernel_target
16+ from numba_dpex .core .types .kernel_api .index_space_ids import ItemType
1417
1518f32arrty = usm_ndarray (ndim = 1 , dtype = float32 , layout = "C" )
19+ itemty = ItemType (ndim = 1 )
1620
1721
18- def _get_kernel_llvm (fn , sig , debug = False ):
19- kernel = dpex . core . kernel_interface . spirv_kernel . SpirvKernel (
20- fn , fn . __name__
21- )
22- kernel . compile (
23- args = sig ,
24- target_ctx = dpex_kernel_target . target_context ,
25- typing_ctx = dpex_kernel_target . typing_context ,
26- debug = debug ,
27- compile_flags = None ,
22+ def _get_kernel_llvm ():
23+ def data_parallel_sum ( item , var_a , var_b , var_c ):
24+ i = item . get_id ( 0 )
25+ var_c [ i ] = var_a [ i ] + var_b [ i ]
26+
27+ sig = ( itemty , f32arrty , f32arrty , f32arrty )
28+
29+ disp = dpex . kernel ( sig )( data_parallel_sum )
30+ kcres = disp . get_compile_result (
31+ types . void ( itemty , f32arrty , f32arrty , f32arrty )
2832 )
29- return kernel .module_name , kernel .llvm_module
33+ llvm_module_str = kcres .library .get_llvm_str ()
34+ name = kcres .fndesc .llvm_func_name
35+ if len (name ) > 200 :
36+ sha256 = hashlib .sha256 (name .encode ("utf-8" )).hexdigest ()
37+ name = name [:150 ] + "_" + sha256
38+
39+ dump_file_name = name + ".ll"
40+
41+ return dump_file_name , llvm_module_str
3042
3143
3244def test_dump_file_on_dump_kernel_llvm_flag_on ():
@@ -37,50 +49,20 @@ def test_dump_file_on_dump_kernel_llvm_flag_on():
3749 and compare with llvm source stored in SprivKernel.
3850 """
3951
40- def data_parallel_sum (var_a , var_b , var_c ):
41- i = dpex .get_global_id (0 )
42- var_c [i ] = var_a [i ] + var_b [i ]
43-
44- sig = (f32arrty , f32arrty , f32arrty )
45-
4652 config .DUMP_KERNEL_LLVM = True
47-
48- llvm_module_name , llvm_module_str = _get_kernel_llvm (data_parallel_sum , sig )
49-
50- dump_file_name = (
51- "llvm_kernel_"
52- + hashlib .sha256 (llvm_module_name .encode ()).hexdigest ()
53- + ".ll"
54- )
55-
53+ dump_file_name , llvm_dumped_str = _get_kernel_llvm ()
5654 with open (dump_file_name , "r" ) as f :
57- llvm_dump = f .read ()
58-
59- assert llvm_module_str == llvm_dump
60-
55+ ondisk_llvm_dump_str = f .read ()
56+ assert llvm_dumped_str == ondisk_llvm_dump_str
6157 os .remove (dump_file_name )
58+ config .DUMP_KERNEL_LLVM = False
6259
6360
6461def test_no_dump_file_on_dump_kernel_llvm_flag_off ():
6562 """
6663 Test functionality of DUMP_KERNEL_LLVM config variable.
6764 Check llvm source is not dumped in .ll file in current directory.
6865 """
69-
70- def data_parallel_sum (var_a , var_b , var_c ):
71- i = dpex .get_global_id (0 )
72- var_c [i ] = var_a [i ] + var_b [i ]
73-
74- sig = (f32arrty , f32arrty , f32arrty )
75-
7666 config .DUMP_KERNEL_LLVM = False
77-
78- llvm_module_name , llvm_module_str = _get_kernel_llvm (data_parallel_sum , sig )
79-
80- dump_file_name = (
81- "llvm_kernel_"
82- + hashlib .sha256 (llvm_module_name .encode ()).hexdigest ()
83- + ".ll"
84- )
85-
67+ dump_file_name , llvm_dumped_str = _get_kernel_llvm ()
8668 assert not os .path .isfile (dump_file_name )
0 commit comments