Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit 8ccfd36

Browse files
committed
Initial try at calling a c++ function while lowering a function
1 parent 2add24b commit 8ccfd36

File tree

5 files changed

+158
-4
lines changed

5 files changed

+158
-4
lines changed

numba/dppl/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def get_ordered_arg_access_types(pyfunc, access_types):
4747

4848
return ordered_arg_access_types
4949

50+
5051
class DPPLCompiler(CompilerBase):
5152
""" DPPL Compiler """
5253

numba/dppl/dppl_passes.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from .dppl_lowerer import DPPLLower
2727

28-
from numba.parfors.parfor import PreParforPass as _parfor_PreParforPass
28+
from numba.parfors.parfor import PreParforPass as _parfor_PreParforPass, replace_functions_map
2929
from numba.parfors.parfor import ParforPass as _parfor_ParforPass
3030
from numba.parfors.parfor import Parfor
3131

@@ -119,13 +119,16 @@ def run_pass(self, state):
119119
"""
120120
# Ensure we have an IR and type information.
121121
assert state.func_ir
122+
functions_map = replace_functions_map.copy()
123+
functions_map.pop(('dot', 'numpy'), None)
122124

123125
preparfor_pass = _parfor_PreParforPass(
124126
state.func_ir,
125127
state.type_annotation.typemap,
126128
state.type_annotation.calltypes, state.typingctx,
127129
state.flags.auto_parallel,
128-
state.parfor_diagnostics.replaced_fns
130+
state.parfor_diagnostics.replaced_fns,
131+
replace_functions_map=functions_map
129132
)
130133

131134
preparfor_pass.run()
@@ -216,7 +219,19 @@ def run_pass(self, state):
216219
# be later serialized.
217220
state.library.enable_object_caching()
218221

222+
219223
targetctx = state.targetctx
224+
225+
# This should not happen here, after we have the notion of context in Numba
226+
# we should have specialized dispatcher for dppl context and that dispatcher
227+
# should be a cpu dispatcher that will overload the lowering functions for
228+
# linalg for dppl.cpu_dispatcher and the dppl.gpu_dipatcher should be the
229+
# current target context we have to launch kernels.
230+
# This is broken as this essentially adds the new lowering in a list which
231+
# means it does not get replaced with the new lowering_buitins
232+
from . import experimental_linalg_lowering_overload
233+
targetctx.refresh()
234+
220235
library = state.library
221236
interp = state.func_ir # why is it called this?!
222237
typemap = state.typemap
@@ -273,6 +288,7 @@ def run_pass(self, state):
273288
"""
274289
Back-end: Generate LLVM IR from Numba IR, compile to machine code
275290
"""
291+
276292
lowered = state['cr']
277293
signature = typing.signature(state.return_type, *state.args)
278294

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import numpy as np
2+
from numba.core import types, cgutils
3+
from numba.core.imputils import (lower_builtin)
4+
from numba.core.typing import signature
5+
from numba.np.arrayobj import make_array, _empty_nd_impl, array_copy
6+
from numba.core import itanium_mangler
7+
from llvmlite import ir
8+
import contextlib
9+
10+
from numba import int32, int64, uint32, uint64, float32, float64
11+
12+
13+
@contextlib.contextmanager
14+
def make_contiguous(context, builder, sig, args):
15+
"""
16+
Ensure that all array arguments are contiguous, if necessary by
17+
copying them.
18+
A new (sig, args) tuple is yielded.
19+
"""
20+
newtys = []
21+
newargs = []
22+
copies = []
23+
for ty, val in zip(sig.args, args):
24+
if not isinstance(ty, types.Array) or ty.layout in 'CF':
25+
newty, newval = ty, val
26+
else:
27+
newty = ty.copy(layout='C')
28+
copysig = signature(newty, ty)
29+
newval = array_copy(context, builder, copysig, (val,))
30+
copies.append((newty, newval))
31+
newtys.append(newty)
32+
newargs.append(newval)
33+
yield signature(sig.return_type, *newtys), tuple(newargs)
34+
for ty, val in copies:
35+
context.nrt.decref(builder, ty, val)
36+
37+
def check_c_int(context, builder, n):
38+
"""
39+
Check whether *n* fits in a C `int`.
40+
"""
41+
_maxint = 2**31 - 1
42+
43+
def impl(n):
44+
if n > _maxint:
45+
raise OverflowError("array size too large to fit in C int")
46+
47+
context.compile_internal(builder, impl,
48+
signature(types.none, types.intp), (n,))
49+
50+
51+
ll_char = ir.IntType(8)
52+
ll_char_p = ll_char.as_pointer()
53+
ll_void_p = ll_char_p
54+
ll_intc = ir.IntType(32)
55+
ll_intc_p = ll_intc.as_pointer()
56+
intp_t = cgutils.intp_t
57+
ll_intp_p = intp_t.as_pointer()
58+
59+
def call_experimental_dot(context, builder, conjugate, dtype,
60+
n, a_data, b_data, out_data):
61+
62+
fnty = ir.FunctionType(ir.IntType(32),
63+
[ll_void_p, ll_void_p, ll_void_p, ir.IntType(64)])
64+
65+
#fn = builder.module.get_or_insert_function(fnty, name="inumpy_dot")
66+
#name = itanium_mangler.mangle("inumpy_dot", [int64, dtype])
67+
#print(name)
68+
fn = builder.module.get_or_insert_function(fnty, name="_Z10inumpy_dotIfEiPvS0_S0_m")
69+
70+
res = builder.call(fn, (builder.bitcast(a_data, ll_void_p),
71+
builder.bitcast(b_data, ll_void_p),
72+
builder.bitcast(out_data, ll_void_p),
73+
n))
74+
75+
def dot_2_vv(context, builder, sig, args, conjugate=False):
76+
"""
77+
np.dot(vector, vector)
78+
np.vdot(vector, vector)
79+
"""
80+
import llvmlite.binding as ll
81+
ll.load_library_permanently('libinumpy.so')
82+
83+
aty, bty = sig.args
84+
dtype = sig.return_type
85+
a = make_array(aty)(context, builder, args[0])
86+
b = make_array(bty)(context, builder, args[1])
87+
n, = cgutils.unpack_tuple(builder, a.shape)
88+
89+
def check_args(a, b):
90+
m, = a.shape
91+
n, = b.shape
92+
if m != n:
93+
raise ValueError("incompatible array sizes for np.dot(a, b) "
94+
"(vector * vector)")
95+
96+
context.compile_internal(builder, check_args,
97+
signature(types.none, *sig.args), args)
98+
check_c_int(context, builder, n)
99+
100+
out = cgutils.alloca_once(builder, context.get_value_type(dtype))
101+
call_experimental_dot(context, builder, conjugate, dtype, n, a.data, b.data, out)
102+
return builder.load(out)
103+
104+
105+
@lower_builtin(np.dot, types.Array, types.Array)
106+
def dot_dppl(context, builder, sig, args):
107+
"""
108+
np.dot(a, b)
109+
a @ b
110+
"""
111+
import dppl.ocldrv as driver
112+
device = driver.runtime.get_current_device()
113+
114+
# the device env should come from the context but the current context
115+
# is a cpu context and not a dppl_gpu_context
116+
117+
with make_contiguous(context, builder, sig, args) as (sig, args):
118+
ndims = [x.ndim for x in sig.args[:2]]
119+
if ndims == [2, 2]:
120+
print("gemm")
121+
#return dot_2_mm(context, builder, sig, args)
122+
elif ndims == [2, 1]:
123+
print("gemv")
124+
#return dot_2_mv(context, builder, sig, args)
125+
elif ndims == [1, 2]:
126+
print("gemv")
127+
#return dot_2_vm(context, builder, sig, args)
128+
elif ndims == [1, 1]:
129+
print("dot")
130+
return dot_2_vv(context, builder, sig, args)
131+
else:
132+
assert 0
133+
134+
135+
raise ImportError("scipy 0.16+ is required for linear algebra")

numba/parfors/parfor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,14 +1350,15 @@ class PreParforPass(object):
13501350
implementations of numpy functions if available.
13511351
"""
13521352
def __init__(self, func_ir, typemap, calltypes, typingctx, options,
1353-
swapped={}):
1353+
swapped={}, replace_functions_map=replace_functions_map):
13541354
self.func_ir = func_ir
13551355
self.typemap = typemap
13561356
self.calltypes = calltypes
13571357
self.typingctx = typingctx
13581358
self.options = options
13591359
# diagnostics
13601360
self.swapped = swapped
1361+
self.replace_functions_map = replace_functions_map
13611362
self.stats = {
13621363
'replaced_func': 0,
13631364
'replaced_dtype': 0,
@@ -1394,7 +1395,7 @@ def _replace_parallel_functions(self, blocks):
13941395
def replace_func():
13951396
func_def = get_definition(self.func_ir, expr.func)
13961397
callname = find_callname(self.func_ir, expr)
1397-
repl_func = replace_functions_map.get(callname, None)
1398+
repl_func = self.replace_functions_map.get(callname, None)
13981399
# Handle method on array type
13991400
if (repl_func is None and
14001401
len(callname) == 2 and

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def check_file_at_path(path2file):
213213
else:
214214
omplinkflags = ['-fopenmp']
215215

216+
tbb_root = False
216217
if tbb_root:
217218
print("Using Intel TBB from:", tbb_root)
218219
ext_np_ufunc_tbb_backend = Extension(

0 commit comments

Comments
 (0)