Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit 20bb2ea

Browse files
committed
Made sure parallel flag is only activated when provided or when target is dppy
1 parent f6ee24b commit 20bb2ea

File tree

3 files changed

+353
-2
lines changed

3 files changed

+353
-2
lines changed

numba/dppy/dppy_lowerer.py

Lines changed: 346 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
from numba.dppy.target import SPIR_GENERIC_ADDRSPACE
4242

43+
multi_tile = False
4344

4445
def replace_var_with_array_in_block(vars, block, typemap, calltypes):
4546
new_block = []
@@ -381,7 +382,7 @@ def _create_gufunc_for_parfor_body(
381382
parfor_dim = len(parfor.loop_nests)
382383
loop_indices = [l.index_variable.name for l in parfor.loop_nests]
383384

384-
use_sched = False
385+
use_sched = True if (not target=='spirv' or multi_tile) else False
385386

386387
# Get all the parfor params.
387388
parfor_params = parfor.params
@@ -1047,7 +1048,7 @@ def _lower_parfor_dppy(lowerer, parfor):
10471048

10481049
# get the shape signature
10491050
get_shape_classes = parfor.get_shape_classes
1050-
use_sched = False
1051+
use_sched = True if (not target=='spirv' or multi_tile) else False
10511052
if use_sched:
10521053
func_args = ['sched'] + func_args
10531054
num_reductions = len(parfor_redvars)
@@ -1338,6 +1339,349 @@ def bump_alpha(c, class_map):
13381339
return (gu_sin, gu_sout)
13391340

13401341

def call_parallel_gufunc(lowerer, cres, gu_signature, outer_sig, expr_args, expr_arg_types,
                         loop_ranges, redvars, reddict, redarrdict, init_block, index_var_typ, races):
    '''
    Adds the call to the gufunc function from the main function.

    Emits LLVM IR (via ``lowerer.builder``) that:
      1. builds the gufunc wrapper for the compiled parfor body (``cres``),
      2. calls the native ``do_scheduling_{signed,unsigned}`` routine to
         partition the iteration space across threads into a ``sched`` array,
      3. packs pointers/shapes/steps for every argument (including reduction
         arrays and race-prone scalars) into the gufunc ABI layout, and
      4. calls the wrapper and copies race-variable results back out.

    Parameters (project types — not constructible here):
      lowerer         : Numba parfor lowerer; supplies ``context``/``builder``
                        and variable load/store helpers.
      cres            : compile result for the gufunc body.
      gu_signature    : (sin, sout) symbolic gufunc dimension signature.
      outer_sig       : Numba typing signature of the outer call.
      expr_args       : argument variable names; NOTE: mutated in place
                        (``pop(0)`` removes the sched arg) — callers must not
                        reuse the list afterwards.
      expr_arg_types  : Numba types matching expr_args.
      loop_ranges     : list of (start, stop, step); NOTE: also mutated in
                        place (entries replaced with loaded LLVM values).
      redvars/reddict/redarrdict : reduction variables and their arrays.
      init_block      : unused in this function body.
      index_var_typ   : loop index type; ``.signed`` selects intp vs uintp
                        scheduling.
      races           : set of variables with cross-iteration races; these are
                        passed as single-element allocas instead of arrays.
    '''
    context = lowerer.context
    builder = lowerer.builder

    # Local import: reuse Numba's CPU parallel-backend machinery.
    from numba.npyufunc.parallel import (build_gufunc_wrapper,
                                         get_thread_count,
                                         _launch_threads)

    if config.DEBUG_ARRAY_OPT:
        print("make_parallel_loop")
        print("args = ", expr_args)
        print("outer_sig = ", outer_sig.args, outer_sig.return_type,
              outer_sig.recvr, outer_sig.pysig)
        print("loop_ranges = ", loop_ranges)
        print("expr_args", expr_args)
        print("expr_arg_types", expr_arg_types)
        print("gu_signature", gu_signature)
        print("cres", cres, type(cres))
        print("cres.library", cres.library, type(cres.library))
        print("cres.fndesc", cres.fndesc, type(cres.fndesc))

    # Build the wrapper for GUFunc
    args, return_type = sigutils.normalize_signature(outer_sig)
    llvm_func = cres.library.get_function(cres.fndesc.llvm_func_name)

    if config.DEBUG_ARRAY_OPT:
        print("llvm_func", llvm_func, type(llvm_func))
    sin, sout = gu_signature

    if config.DEBUG_ARRAY_OPT:
        print("sin", sin)
        print("sout", sout)

    # These are necessary for build_gufunc_wrapper to find external symbols
    _launch_threads()

    info = build_gufunc_wrapper(llvm_func, cres, sin, sout,
                                cache=False, is_parfors=True)
    wrapper_name = info.name
    cres.library._ensure_finalized()

    if config.DEBUG_ARRAY_OPT:
        print("parallel function = ", wrapper_name, cres)

    # loadvars for loop_ranges
    def load_range(v):
        # A range bound is either a Numba IR variable (load it) or a Python
        # constant (materialize as an unsigned intp constant).
        if isinstance(v, ir.Var):
            return lowerer.loadvar(v.name)
        else:
            return context.get_constant(types.uintp, v)

    num_dim = len(loop_ranges)
    for i in range(num_dim):
        start, stop, step = loop_ranges[i]
        start = load_range(start)
        stop = load_range(stop)
        assert(step == 1)  # We do not support loop steps other than 1
        step = load_range(step)
        # Replace the IR-level triple with loaded LLVM values (in-place).
        loop_ranges[i] = (start, stop, step)

        if config.DEBUG_ARRAY_OPT:
            print("call_parallel_gufunc loop_ranges[{}] = ".format(i), start,
                  stop, step)
            cgutils.printf(builder, "loop range[{}]: %d %d (%d)\n".format(i),
                           start, stop, step)

    # Commonly used LLVM types and constants
    byte_t = lc.Type.int(8)
    byte_ptr_t = lc.Type.pointer(byte_t)
    byte_ptr_ptr_t = lc.Type.pointer(byte_ptr_t)
    intp_t = context.get_value_type(types.intp)
    uintp_t = context.get_value_type(types.uintp)
    intp_ptr_t = lc.Type.pointer(intp_t)
    uintp_ptr_t = lc.Type.pointer(uintp_t)
    zero = context.get_constant(types.uintp, 0)
    one = context.get_constant(types.uintp, 1)
    one_type = one.type
    sizeof_intp = context.get_abi_sizeof(intp_t)

    # Prepare sched, first pop it out of expr_args, outer_sig, and gu_signature
    # NOTE(review): this mutates the caller's expr_args and sin lists.
    expr_args.pop(0)
    sched_sig = sin.pop(0)

    if config.DEBUG_ARRAY_OPT:
        print("Parfor has potentially negative start", index_var_typ.signed)

    # Signed index => schedule with signed intp; otherwise unsigned uintp.
    if index_var_typ.signed:
        sched_type = intp_t
        sched_ptr_type = intp_ptr_t
    else:
        sched_type = uintp_t
        sched_ptr_type = uintp_ptr_t

    # Call do_scheduling with appropriate arguments
    dim_starts = cgutils.alloca_once(
        builder, sched_type, size=context.get_constant(
            types.uintp, num_dim), name="dims")
    dim_stops = cgutils.alloca_once(
        builder, sched_type, size=context.get_constant(
            types.uintp, num_dim), name="dims")
    for i in range(num_dim):
        start, stop, step = loop_ranges[i]
        # Widen bounds to the scheduling width if the loaded value is narrower.
        if start.type != one_type:
            start = builder.sext(start, one_type)
        if stop.type != one_type:
            stop = builder.sext(stop, one_type)
        if step.type != one_type:
            step = builder.sext(step, one_type)
        # subtract 1 because do-scheduling takes inclusive ranges
        stop = builder.sub(stop, one)
        builder.store(
            start, builder.gep(
                dim_starts, [
                    context.get_constant(
                        types.uintp, i)]))
        builder.store(stop, builder.gep(dim_stops,
                                        [context.get_constant(types.uintp, i)]))

    # sched holds one (start, end) pair per dimension per thread.
    sched_size = get_thread_count() * num_dim * 2
    sched = cgutils.alloca_once(
        builder, sched_type, size=context.get_constant(
            types.uintp, sched_size), name="sched")
    debug_flag = 1 if config.DEBUG_ARRAY_OPT else 0
    scheduling_fnty = lc.Type.function(
        intp_ptr_t, [uintp_t, sched_ptr_type, sched_ptr_type, uintp_t, sched_ptr_type, intp_t])
    # The C scheduling routine comes in signed/unsigned variants.
    if index_var_typ.signed:
        do_scheduling = builder.module.get_or_insert_function(scheduling_fnty,
                                                              name="do_scheduling_signed")
    else:
        do_scheduling = builder.module.get_or_insert_function(scheduling_fnty,
                                                              name="do_scheduling_unsigned")

    builder.call(
        do_scheduling, [
            context.get_constant(
                types.uintp, num_dim), dim_starts, dim_stops, context.get_constant(
                types.uintp, get_thread_count()), sched, context.get_constant(
                types.intp, debug_flag)])

    # Get the LLVM vars for the Numba IR reduction array vars.
    redarrs = [lowerer.loadvar(redarrdict[x].name) for x in redvars]

    nredvars = len(redvars)
    # Non-reduction args are the leading entries of expr_args.
    ninouts = len(expr_args) - nredvars

    if config.DEBUG_ARRAY_OPT:
        for i in range(get_thread_count()):
            cgutils.printf(builder, "sched[" + str(i) + "] = ")
            for j in range(num_dim * 2):
                cgutils.printf(
                    builder, "%d ", builder.load(
                        builder.gep(
                            sched, [
                                context.get_constant(
                                    types.intp, i * num_dim * 2 + j)])))
            cgutils.printf(builder, "\n")

    # ----------------------------------------------------------------------------
    # Prepare arguments: args, shapes, steps, data
    all_args = [lowerer.loadvar(x) for x in expr_args[:ninouts]] + redarrs
    num_args = len(all_args)
    # +1 accounts for the sched input popped out of sin above.
    num_inps = len(sin) + 1
    args = cgutils.alloca_once(
        builder,
        byte_ptr_t,
        size=context.get_constant(
            types.intp,
            1 + num_args),
        name="pargs")
    array_strides = []
    # sched goes first
    builder.store(builder.bitcast(sched, byte_ptr_t), args)
    array_strides.append(context.get_constant(types.intp, sizeof_intp))
    red_shapes = {}
    rv_to_arg_dict = {}
    # followed by other arguments
    for i in range(num_args):
        arg = all_args[i]
        var = expr_args[i]
        aty = expr_arg_types[i]
        dst = builder.gep(args, [context.get_constant(types.intp, i + 1)])
        if i >= ninouts:  # reduction variables
            ary = context.make_array(aty)(context, builder, arg)
            strides = cgutils.unpack_tuple(builder, ary.strides, aty.ndim)
            ary_shapes = cgutils.unpack_tuple(builder, ary.shape, aty.ndim)
            # Start from 1 because we skip the first dimension of length num_threads just like sched.
            for j in range(1, len(strides)):
                array_strides.append(strides[j])
            red_shapes[i] = ary_shapes[1:]
            builder.store(builder.bitcast(ary.data, byte_ptr_t), dst)
        elif isinstance(aty, types.ArrayCompatible):
            if var in races:
                # Race-prone variable: pass a private single-element alloca
                # instead of the shared array. Booleans are stored as i1.
                typ = context.get_data_type(
                    aty.dtype) if aty.dtype != types.boolean else lc.Type.int(1)

                rv_arg = cgutils.alloca_once(builder, typ)
                builder.store(arg, rv_arg)
                builder.store(builder.bitcast(rv_arg, byte_ptr_t), dst)
                # Remember the alloca so the result can be copied back after
                # the kernel call.
                rv_to_arg_dict[var] = (arg, rv_arg)

                array_strides.append(context.get_constant(types.intp, context.get_abi_sizeof(typ)))
            else:
                ary = context.make_array(aty)(context, builder, arg)
                strides = cgutils.unpack_tuple(builder, ary.strides, aty.ndim)
                for j in range(len(strides)):
                    array_strides.append(strides[j])
                builder.store(builder.bitcast(ary.data, byte_ptr_t), dst)
        else:
            if i < num_inps:
                # Scalar input, need to store the value in an array of size 1
                typ = context.get_data_type(
                    aty) if aty != types.boolean else lc.Type.int(1)
                ptr = cgutils.alloca_once(builder, typ)
                builder.store(arg, ptr)
            else:
                # Scalar output, must allocate
                typ = context.get_data_type(
                    aty) if aty != types.boolean else lc.Type.int(1)
                ptr = cgutils.alloca_once(builder, typ)
            # Both scalar branches pass the alloca through the args array.
            builder.store(builder.bitcast(ptr, byte_ptr_t), dst)

    # ----------------------------------------------------------------------------
    # Next, we prepare the individual dimension info recorded in gu_signature
    sig_dim_dict = {}
    occurances = []
    # NOTE(review): the empty-list assignment above is immediately overwritten
    # and is dead code.
    occurances = [sched_sig[0]]
    sig_dim_dict[sched_sig[0]] = context.get_constant(types.intp, 2 * num_dim)
    assert len(expr_args) == len(all_args)
    assert len(expr_args) == len(expr_arg_types)
    assert len(expr_args) == len(sin + sout)
    assert len(expr_args) == len(outer_sig.args[1:])
    for var, arg, aty, gu_sig in zip(expr_args, all_args,
                                     expr_arg_types, sin + sout):
        # i indexes the array dimension that corresponds to the first symbol
        # of gu_sig (trailing dims of the array match the signature).
        if isinstance(aty, types.npytypes.Array):
            i = aty.ndim - len(gu_sig)
        else:
            i = 0
        if config.DEBUG_ARRAY_OPT:
            print("var =", var, "gu_sig =", gu_sig, "type =", aty, "i =", i)

        for dim_sym in gu_sig:
            if config.DEBUG_ARRAY_OPT:
                print("var = ", var, " type = ", aty)
            if var in races:
                # Race variables were passed as single elements (see above),
                # so their symbolic dimension is 1.
                sig_dim_dict[dim_sym] = context.get_constant(types.intp, 1)
            else:
                ary = context.make_array(aty)(context, builder, arg)
                shapes = cgutils.unpack_tuple(builder, ary.shape, aty.ndim)
                sig_dim_dict[dim_sym] = shapes[i]

            # Track first occurrence so each symbol is emitted once below.
            if not (dim_sym in occurances):
                if config.DEBUG_ARRAY_OPT:
                    print("dim_sym = ", dim_sym, ", i = ", i)
                    cgutils.printf(builder, dim_sym + " = %d\n", sig_dim_dict[dim_sym])
                occurances.append(dim_sym)
            i = i + 1

    # ----------------------------------------------------------------------------
    # Prepare shapes, which is a single number (outer loop size), followed by
    # the size of individual shape variables.
    nshapes = len(sig_dim_dict) + 1
    shapes = cgutils.alloca_once(builder, intp_t, size=nshapes, name="pshape")
    # For now, outer loop size is the same as number of threads
    builder.store(context.get_constant(types.intp, get_thread_count()), shapes)
    # Individual shape variables go next
    i = 1
    for dim_sym in occurances:
        if config.DEBUG_ARRAY_OPT:
            cgutils.printf(builder, dim_sym + " = %d\n", sig_dim_dict[dim_sym])
        builder.store(
            sig_dim_dict[dim_sym], builder.gep(
                shapes, [
                    context.get_constant(
                        types.intp, i)]))
        i = i + 1

    # ----------------------------------------------------------------------------
    # Prepare steps for each argument. Note that all steps are counted in
    # bytes.
    num_steps = num_args + 1 + len(array_strides)
    steps = cgutils.alloca_once(
        builder, intp_t, size=context.get_constant(
            types.intp, num_steps), name="psteps")
    # First goes the step size for sched, which is 2 * num_dim
    builder.store(context.get_constant(types.intp, 2 * num_dim * sizeof_intp),
                  steps)
    # The steps for all others are 0, except for reduction results.
    for i in range(num_args):
        if i >= ninouts:  # steps for reduction vars are abi_sizeof(typ)
            j = i - ninouts
            # Get the base dtype of the reduction array.
            redtyp = lowerer.fndesc.typemap[redvars[j]]
            red_stride = None
            if isinstance(redtyp, types.npytypes.Array):
                redtyp = redtyp.dtype
                red_stride = red_shapes[i]
            typ = context.get_value_type(redtyp)
            sizeof = context.get_abi_sizeof(typ)
            # Set stepsize to the size of that dtype.
            stepsize = context.get_constant(types.intp, sizeof)
            if red_stride is not None:
                # Multi-dim reduction arrays: step over an entire inner slab.
                for rs in red_stride:
                    stepsize = builder.mul(stepsize, rs)
        else:
            # steps are strides
            stepsize = zero
        dst = builder.gep(steps, [context.get_constant(types.intp, 1 + i)])
        builder.store(stepsize, dst)
    # Trailing entries carry the per-array strides collected above.
    for j in range(len(array_strides)):
        dst = builder.gep(
            steps, [
                context.get_constant(
                    types.intp, 1 + num_args + j)])
        builder.store(array_strides[j], dst)

    # ----------------------------------------------------------------------------
    # prepare data
    data = cgutils.get_null_value(byte_ptr_t)

    # Standard gufunc wrapper ABI: (args, shapes, steps, data) -> void.
    fnty = lc.Type.function(lc.Type.void(), [byte_ptr_ptr_t, intp_ptr_t,
                                             intp_ptr_t, byte_ptr_t])

    fn = builder.module.get_or_insert_function(fnty, name=wrapper_name)
    context.active_code_library.add_linking_library(info.library)

    if config.DEBUG_ARRAY_OPT:
        cgutils.printf(builder, "before calling kernel %p\n", fn)
    builder.call(fn, [args, shapes, steps, data])
    if config.DEBUG_ARRAY_OPT:
        cgutils.printf(builder, "after calling kernel %p\n", fn)

    # Copy race-variable results from their private allocas back into the
    # corresponding Numba IR variables.
    for k, v in rv_to_arg_dict.items():
        arg, rv_arg = v
        only_elem_ptr = builder.gep(rv_arg, [context.get_constant(types.intp, 0)])
        builder.store(builder.load(only_elem_ptr), lowerer.getvar(k))

    context.active_code_library.add_linking_library(cres.library)
13411685

13421686

13431687
# Keep all the dppy kernels and programs created alive indefinitely.

numba/dppy/dppy_passes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,15 @@ def run_pass(self, state):
5757
"""
5858
# Ensure we have an IR and type information.
5959
assert state.func_ir
60+
'''
6061
state.flags.auto_parallel.stencil = True
6162
state.flags.auto_parallel.setitem = True
6263
state.flags.auto_parallel.numpy = True
6364
state.flags.auto_parallel.reduction = True
6465
state.flags.auto_parallel.prange = True
6566
state.flags.auto_parallel.fusion = True
67+
'''
68+
print(state.flags.auto_parallel.numpy)
6669

6770
preparfor_pass = _parfor_PreParforPass(
6871
state.func_ir,

0 commit comments

Comments
 (0)