Skip to content

Commit b4f99ef

Browse files
reazulhoque, Deb, Diptorup
authored and committed
(numba/dppy) Adds support for local memory allocation inside dppy.kernel, clean ups to barrier, use context manager
* Added local memory allocation and fixed bugs in barrier implementations * Added support for passing signature in dppy.kernel decorator * Added bigger example that deals with local memory and barrier * use barrier and local memory in atomic test * Change local.alloc to local.static_alloc to indicate this is static allocation * Added pass to ensure only constants can be given as shape to static local memory allocation * Change message and fix typo * Include pass * Feature/dppy backend context manager (#85) * Do not use gpu_env instead use the current_env specified in the runtime. * Update all examples to use dppy context manager. * Fix Test suites and update API to use dppy context manager * Uncomment barrier * Fix caching tests
1 parent d18b2a2 commit b4f99ef

32 files changed

+1058
-592
lines changed

compiler.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numba.typing.templates import AbstractTemplate
99
from numba import ctypes_support as ctypes
1010
from types import FunctionType
11+
from inspect import signature
1112

1213
import dppy.core as driver
1314
from . import spirv_generator
@@ -31,6 +32,21 @@ def _raise_invalid_kernel_enqueue_args():
3132
"The local size argument is optional.")
3233
raise ValueError(error_message)
3334

35+
36+
def get_ordered_arg_access_types(pyfunc, access_types):
37+
# Construct a list of access type of each arg according to their position
38+
ordered_arg_access_types = []
39+
sig = signature(pyfunc, follow_wrapped=False)
40+
for idx, arg_name in enumerate(sig.parameters):
41+
if access_types:
42+
for key in access_types:
43+
if arg_name in access_types[key]:
44+
ordered_arg_access_types.append(key)
45+
if len(ordered_arg_access_types) <= idx:
46+
ordered_arg_access_types.append(None)
47+
48+
return ordered_arg_access_types
49+
3450
class DPPyCompiler(CompilerBase):
3551
""" DPPy Compiler """
3652

@@ -62,7 +78,8 @@ def compile_with_dppy(pyfunc, return_type, args, debug):
6278
# Do not compile (generate native code), just lower (to LLVM)
6379
flags.set('no_compile')
6480
flags.set('no_cpython_wrapper')
65-
flags.set('nrt')
81+
#flags.set('nrt')
82+
6683
# Run compilation pipeline
6784
if isinstance(pyfunc, FunctionType):
6885
cres = compiler.compile_extra(typingctx=typingctx,
@@ -95,6 +112,9 @@ def compile_with_dppy(pyfunc, return_type, args, debug):
95112
def compile_kernel(device, pyfunc, args, access_types, debug=False):
96113
if DEBUG:
97114
print("compile_kernel", args)
115+
if not device:
116+
device = driver.runtime.get_current_device()
117+
98118
cres = compile_with_dppy(pyfunc, None, args, debug=debug)
99119
func = cres.library.get_function(cres.fndesc.llvm_func_name)
100120
kernel = cres.target_context.prepare_ocl_kernel(func, cres.signature.args)
@@ -241,7 +261,8 @@ def _ensure_valid_work_group_size(val, work_item_grid):
241261
val = [val]
242262

243263
if len(val) != len(work_item_grid):
244-
error_message = ("Unsupported number of work item dimensions ")
264+
error_message = ("Unsupported number of work item dimensions, " +
265+
"dimensions of global and local work items has to be the same ")
245266
raise ValueError(error_message)
246267

247268
return list(val)
@@ -283,19 +304,19 @@ def forall(self, nelem, local_size=64, queue=None):
283304
def __getitem__(self, args):
284305
"""Mimick CUDA python's square-bracket notation for configuration.
285306
This assumes the argument to be:
286-
`device_env, global size, local size`
307+
`global size, local size`
287308
"""
288309
ls = None
289310
nargs = len(args)
290311
# Check if the kernel enquing arguments are sane
291-
if nargs < 2 or nargs > 3:
312+
if nargs < 1 or nargs > 2:
292313
_raise_invalid_kernel_enqueue_args
293314

294-
device_env = args[0]
295-
gs = _ensure_valid_work_item_grid(args[1], device_env)
315+
device_env = driver.runtime.get_current_device()
316+
gs = _ensure_valid_work_item_grid(args[0], device_env)
296317
# If the optional local size argument is provided
297-
if nargs == 3:
298-
ls = _ensure_valid_work_group_size(args[2], gs)
318+
if nargs == 2 and args[1] != []:
319+
ls = _ensure_valid_work_group_size(args[1], gs)
299320

300321
return self.configure(device_env, gs, ls)
301322

@@ -400,7 +421,7 @@ def _unpack_argument(self, ty, val, device_env, retr, kernelargs,
400421
dArr = device_env.copy_array_to_device(val)
401422
elif self.valid_access_types[access_type] == _NUMBA_PVC_WRITE_ONLY:
402423
# write_only case, we do not copy the host data
403-
dArr = driver.DeviceArray(device_env.get_env_ptr(), val)
424+
dArr = device_env.create_device_array(val)
404425

405426
assert (dArr != None), "Problem in allocating device buffer"
406427
device_arrs[-1] = dArr
@@ -466,7 +487,7 @@ def __call__(self, *args, **kwargs):
466487
assert not kwargs, "Keyword Arguments are not supported"
467488
if self.device_env is None:
468489
try:
469-
self.device_env = driver.runtime.get_gpu_device()
490+
self.device_env = driver.runtime.get_current_device()
470491
except:
471492
_raise_no_device_found_error()
472493

decorators.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import print_function, absolute_import, division
22
from numba import sigutils, types
3-
from .compiler import (compile_kernel, JitDPPyKernel,
4-
compile_dppy_func_template, compile_dppy_func)
5-
from inspect import signature
3+
from .compiler import (compile_kernel, JitDPPyKernel, compile_dppy_func_template,
4+
compile_dppy_func, get_ordered_arg_access_types)
65

76

87
def kernel(signature=None, access_types=None, debug=False):
@@ -17,34 +16,25 @@ def kernel(signature=None, access_types=None, debug=False):
1716
func = signature
1817
return autojit(debug=False, access_types=access_types)(func)
1918
else:
20-
return _kernel_jit(signature, debug)
19+
return _kernel_jit(signature, debug, access_types)
2120

2221

2322
def autojit(debug=False, access_types=None):
2423
def _kernel_autojit(pyfunc):
25-
# Construct a list of access type of each arg according to their position
26-
ordered_arg_access_types = []
27-
sig = signature(pyfunc, follow_wrapped=False)
28-
for idx, arg_name in enumerate(sig.parameters):
29-
if access_types:
30-
for key in access_types:
31-
if arg_name in access_types[key]:
32-
ordered_arg_access_types.append(key)
33-
if len(ordered_arg_access_types) <= idx:
34-
ordered_arg_access_types.append(None)
35-
24+
ordered_arg_access_types = get_ordered_arg_access_types(pyfunc, access_types)
3625
return JitDPPyKernel(pyfunc, ordered_arg_access_types)
3726
return _kernel_autojit
3827

3928

40-
def _kernel_jit(signature, debug):
29+
def _kernel_jit(signature, debug, access_types):
4130
argtypes, restype = sigutils.normalize_signature(signature)
4231
if restype is not None and restype != types.void:
4332
msg = ("DPPy kernel must have void return type but got {restype}")
4433
raise TypeError(msg.format(restype=restype))
4534

4635
def _wrapped(pyfunc):
47-
return compile_kernel(pyfunc, argtypes, debug)
36+
ordered_arg_access_types = get_ordered_arg_access_types(pyfunc, access_types)
37+
return compile_kernel(None, pyfunc, argtypes, ordered_arg_access_types, debug)
4838

4939
return _wrapped
5040

device_init.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111
get_num_groups,
1212
barrier,
1313
mem_fence,
14-
# shared,
1514
sub_group_barrier,
1615
atomic,
16+
local,
17+
CLK_LOCAL_MEM_FENCE,
18+
CLK_GLOBAL_MEM_FENCE,
1719
)
1820

21+
DEFAULT_LOCAL_SIZE = []
22+
1923
from . import initialize
2024

2125
from .decorators import kernel, func, autojit

dppy_host_fn_call_gen.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ def __init__(self, lowerer, cres, num_inputs):
1515
self.context = self.lowerer.context
1616
self.builder = self.lowerer.builder
1717

18-
self.gpu_device = driver.runtime.get_gpu_device()
19-
self.gpu_device_env = self.gpu_device.get_env_ptr()
20-
self.gpu_device_int = int(driver.ffi.cast("uintptr_t",
21-
self.gpu_device_env))
18+
self.current_device = driver.runtime.get_current_device()
19+
self.current_device_env = self.current_device.get_env_ptr()
20+
self.current_device_int = int(driver.ffi.cast("uintptr_t",
21+
self.current_device_env))
2222

2323
self.kernel_t_obj = cres.kernel._kernel_t_obj[0]
2424
self.kernel_int = int(driver.ffi.cast("uintptr_t",
@@ -65,8 +65,9 @@ def _init_llvm_types_and_constants(self):
6565
self.void_ptr_t = self.context.get_value_type(types.voidptr)
6666
self.void_ptr_ptr_t = lc.Type.pointer(self.void_ptr_t)
6767
self.sizeof_void_ptr = self.context.get_abi_sizeof(self.intp_t)
68-
self.gpu_device_int_const = self.context.get_constant(
69-
types.uintp, self.gpu_device_int)
68+
self.current_device_int_const = self.context.get_constant(
69+
types.uintp,
70+
self.current_device_int)
7071

7172
def _declare_functions(self):
7273
create_dppy_kernel_arg_fnty = lc.Type.function(
@@ -173,7 +174,7 @@ def process_kernel_arg(self, var, llvm_arg, arg_type, gu_sig, val_type, index, m
173174
size=self.one, name=buffer_name)
174175

175176
# env, buffer_size, buffer_ptr
176-
args = [self.builder.inttoptr(self.gpu_device_int_const, self.void_ptr_t),
177+
args = [self.builder.inttoptr(self.current_device_int_const, self.void_ptr_t),
177178
self.builder.load(total_size),
178179
buffer_ptr]
179180
self.builder.call(self.create_dppy_rw_mem_buffer, args)
@@ -186,7 +187,7 @@ def process_kernel_arg(self, var, llvm_arg, arg_type, gu_sig, val_type, index, m
186187

187188
# We really need to detect when an array needs to be copied over
188189
if index < self.num_inputs:
189-
args = [self.builder.inttoptr(self.gpu_device_int_const, self.void_ptr_t),
190+
args = [self.builder.inttoptr(self.current_device_int_const, self.void_ptr_t),
190191
self.builder.load(buffer_ptr),
191192
self.one,
192193
self.zero,
@@ -263,7 +264,7 @@ def enqueue_kernel_and_read_back(self, loop_ranges):
263264
self.builder.store(stop,
264265
self.builder.gep(dim_stops, [self.context.get_constant(types.uintp, i)]))
265266

266-
args = [self.builder.inttoptr(self.gpu_device_int_const, self.void_ptr_t),
267+
args = [self.builder.inttoptr(self.current_device_int_const, self.void_ptr_t),
267268
self.builder.inttoptr(self.context.get_constant(types.uintp, self.kernel_int), self.void_ptr_t),
268269
self.context.get_constant(types.uintp, self.total_kernel_args),
269270
self.kernel_arg_array,
@@ -276,7 +277,7 @@ def enqueue_kernel_and_read_back(self, loop_ranges):
276277
# read buffers back to host
277278
for read_buf in self.read_bufs_after_enqueue:
278279
buffer_ptr, array_size_member, data_member = read_buf
279-
args = [self.builder.inttoptr(self.gpu_device_int_const, self.void_ptr_t),
280+
args = [self.builder.inttoptr(self.current_device_int_const, self.void_ptr_t),
280281
self.builder.load(buffer_ptr),
281282
self.one,
282283
self.zero,

dppy_lowerer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def print_arg_with_addrspaces(args):
618618
# FIXME : We should not always use gpu device, instead select the default
619619
# device as configured in dppy.
620620
kernel_func = numba.dppy.compiler.compile_kernel_parfor(
621-
driver.runtime.get_gpu_device(),
621+
driver.runtime.get_current_device(),
622622
gufunc_ir,
623623
gufunc_param_types,
624624
param_types_addrspaces)

dppy_passbuilder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
InlineOverloads)
1818

1919
from .dppy_passes import (
20+
DPPyConstantSizeStaticLocalMemoryPass,
2021
DPPyPreParforPass,
2122
DPPyParforPass,
2223
SpirvFriendlyLowering,
@@ -40,6 +41,11 @@ def default_numba_nopython_pipeline(state, pm):
4041
pm.add_pass(IRProcessing, "processing IR")
4142
pm.add_pass(WithLifting, "Handle with contexts")
4243

44+
# Add pass to ensure when users are allocating static
45+
# constant memory the size is a constant and can not
46+
# come from a closure variable
47+
pm.add_pass(DPPyConstantSizeStaticLocalMemoryPass, "dppy constant size for static local memory")
48+
4349
# pre typing
4450
if not state.flags.no_rewrites:
4551
pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")

dppy_passes.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from contextlib import contextmanager
33
import warnings
44

5+
from numba import ir
56
import weakref
67
from collections import namedtuple, deque
78
import operator
@@ -32,7 +33,81 @@
3233
from numba.parfor import PreParforPass as _parfor_PreParforPass
3334
from numba.parfor import ParforPass as _parfor_ParforPass
3435
from numba.parfor import Parfor
35-
#from numba.npyufunc.dufunc import DUFunc
36+
37+
38+
@register_pass(mutates_CFG=True, analysis_only=False)
39+
class DPPyConstantSizeStaticLocalMemoryPass(FunctionPass):
40+
41+
_name = "dppy_constant_size_static_local_memory_pass"
42+
43+
def __init__(self):
44+
FunctionPass.__init__(self)
45+
46+
def run_pass(self, state):
47+
"""
48+
Preprocessing for data-parallel computations.
49+
"""
50+
# Ensure we have an IR and type information.
51+
assert state.func_ir
52+
func_ir = state.func_ir
53+
54+
_DEBUG = False
55+
56+
if _DEBUG:
57+
print('Checks if size of OpenCL local address space alloca is a compile-time constant.'.center(80, '-'))
58+
print(func_ir.dump())
59+
60+
work_list = list(func_ir.blocks.items())
61+
while work_list:
62+
label, block = work_list.pop()
63+
for i, instr in enumerate(block.body):
64+
if isinstance(instr, ir.Assign):
65+
expr = instr.value
66+
if isinstance(expr, ir.Expr):
67+
if expr.op == 'call':
68+
call_node = block.find_variable_assignment(expr.func.name).value
69+
if isinstance(call_node, ir.Expr) and call_node.attr == "static_alloc":
70+
arg = None
71+
# at first look in keyword arguments to get the shape, which has to be
72+
# constant
73+
if expr.kws:
74+
for _arg in expr.kws:
75+
if _arg[0] == "shape":
76+
arg = _arg[1]
77+
78+
if not arg:
79+
arg = expr.args[0]
80+
81+
error = False
82+
# arg can be one constant or a tuple of constant items
83+
arg_type = func_ir.get_definition(arg.name)
84+
if isinstance(arg_type, ir.Expr):
85+
# we have a tuple
86+
for item in arg_type.items:
87+
if not isinstance(func_ir.get_definition(item.name), ir.Const):
88+
error = True
89+
break
90+
91+
else:
92+
if not isinstance(func_ir.get_definition(arg.name), ir.Const):
93+
error = True
94+
break
95+
96+
if error:
97+
warnings.warn_explicit("The size of the Local memory has to be constant",
98+
errors.NumbaError,
99+
state.func_id.filename,
100+
state.func_id.firstlineno)
101+
raise
102+
103+
104+
105+
if config.DEBUG or config.DUMP_IR:
106+
name = state.func_ir.func_id.func_qualname
107+
print(("IR DUMP: %s" % name).center(80, "-"))
108+
state.func_ir.dump()
109+
110+
return True
36111

37112

38113
@register_pass(mutates_CFG=True, analysis_only=False)

0 commit comments

Comments
 (0)