Added Numba cache, vectorize_target, and fastmath config options
kc611 authored and brandonwillard committed Jan 18, 2022
1 parent be918f5 commit 240827c
Showing 12 changed files with 261 additions and 146 deletions.
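
These options only matter for graphs compiled through Aesara's Numba backend. As a minimal sketch of the setting they apply to (this assumes a standard install where the NUMBA compilation mode is registered):

import aesara
import aesara.tensor as at

x = at.vector("x")
# Compile through the Numba linker; the numba__* flags added in this
# commit control how the generated kernels are jitted.
fn = aesara.function([x], 2 * x, mode="NUMBA")
fn([1.0, 2.0, 3.0])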
2 changes: 2 additions & 0 deletions .gitignore
@@ -17,6 +17,8 @@ __pycache__
 *.snm
 *.toc
 *.vrb
+*.nbc
+*.nbi
 .noseids
 *.DS_Store
 *.bak
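
(.nbi and .nbc are the index and data files that Numba's file-based cache writes alongside compiled modules; the new numba__cache option enables that cache by default, hence the new ignore entries.)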
22 changes: 22 additions & 0 deletions aesara/configdefaults.py
@@ -1452,6 +1452,27 @@ def add_scan_configvars():
     )
 
 
+def add_numba_configvars():
+    config.add(
+        "numba__vectorize_target",
+        ("Default target for numba.vectorize."),
+        EnumStr("cpu", ["parallel", "cuda"], mutable=True),
+        in_c_key=False,
+    )
+    config.add(
+        "numba__fastmath",
+        ("If True, use Numba's fastmath mode."),
+        BoolParam(True),
+        in_c_key=False,
+    )
+    config.add(
+        "numba__cache",
+        ("If True, use Numba's file based caching."),
+        BoolParam(True),
+        in_c_key=False,
+    )
+
+
 def _get_default_gpuarray__cache_path():
     return os.path.join(config.compiledir, "gpuarray_kernels")
 
@@ -1683,6 +1704,7 @@ def add_caching_dir_configvars():
 add_metaopt_configvars()
 add_vm_configvars()
 add_deprecated_configvars()
+add_numba_configvars()
 
 # TODO: `gcc_version_str` is used by other modules.. Should it become an immutable config var?
 try:
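
For reference, a short sketch of how the new flags can be adjusted (this assumes Aesara's usual config mechanics, where options are settable in code, in .aesararc, or via the AESARA_FLAGS environment variable):

import aesara

# Defaults added by this commit:
#   numba__vectorize_target = "cpu"   (one of "cpu", "parallel", "cuda")
#   numba__fastmath = True
#   numba__cache = True
print(aesara.config.numba__vectorize_target)

# All three are declared with in_c_key=False, so changing them does not
# invalidate the C compilation cache; e.g., disable on-disk caching:
aesara.config.numba__cache = False

# Use Numba's parallel target for numba.vectorize-compiled kernels instead:
aesara.config.numba__vectorize_target = "parallel"

Equivalently, from the shell: AESARA_FLAGS='numba__fastmath=False,numba__cache=False' python script.py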
74 changes: 46 additions & 28 deletions aesara/link/numba/dispatch/basic.py
@@ -12,6 +12,7 @@
 from numba.core.errors import TypingError
 from numba.extending import box
 
+from aesara import config
 from aesara.compile.ops import DeepCopyOp
 from aesara.graph.basic import Apply
 from aesara.graph.fg import FunctionGraph
@@ -40,6 +41,21 @@
 from aesara.tensor.type_other import MakeSlice
 
 
+def numba_njit(*args, **kwargs):
+
+    if len(args) > 0 and callable(args[0]):
+        return numba.njit(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
+
+    return numba.njit(*args, cache=config.numba__cache, **kwargs)
+
+
+def numba_vectorize(*args, **kwargs):
+    if len(args) > 0 and callable(args[0]):
+        return numba.vectorize(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
+
+    return numba.vectorize(*args, cache=config.numba__cache, **kwargs)
+
+
 def get_numba_type(
     aesara_type: Type, layout: str = "A", force_scalar: bool = False
 ) -> numba.types.Type:
@@ -222,19 +238,19 @@ def create_tuple_creator(f, n):
     """
     assert n > 0
 
-    f = numba.njit(f)
+    f = numba_njit(f)
 
-    @numba.njit
+    @numba_njit
     def creator(args):
         return (f(0, *args),)
 
     for i in range(1, n):
 
-        @numba.njit
+        @numba_njit
         def creator(args, creator=creator, i=i):
             return creator(args) + (f(i, *args),)
 
-    return numba.njit(lambda *args: creator(args))
+    return numba_njit(lambda *args: creator(args))
 
 
 def create_tuple_string(x):
@@ -268,7 +284,7 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
     else:
         ret_sig = get_numba_type(node.outputs[0].type)
 
-    @numba.njit
+    @numba_njit
     def perform(*inputs):
         with numba.objmode(ret=ret_sig):
             outputs = [[None] for i in range(n_outputs)]
@@ -402,9 +418,11 @@ def numba_funcify_Subtensor(op, node, **kwargs):
 
     global_env = {"np": np, "objmode": numba.objmode}
 
-    subtensor_fn = compile_function_src(subtensor_def_src, "subtensor", global_env)
+    subtensor_fn = compile_function_src(
+        subtensor_def_src, "subtensor", {**globals(), **global_env}
+    )
 
-    return numba.njit(subtensor_fn)
+    return numba_njit(subtensor_fn)
 
 
 @numba_funcify.register(IncSubtensor)
@@ -419,10 +437,10 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
     global_env = {"np": np, "objmode": numba.objmode}
 
     incsubtensor_fn = compile_function_src(
-        incsubtensor_def_src, "incsubtensor", global_env
+        incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
     )
 
-    return numba.njit(incsubtensor_fn)
+    return numba_njit(incsubtensor_fn)
 
 
 @numba_funcify.register(DeepCopyOp)
@@ -434,13 +452,13 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
     # The type can also be RandomType with no ndims
    if not hasattr(node.outputs[0].type, "ndim") or node.outputs[0].type.ndim == 0:
        # TODO: Do we really need to compile a pass-through function like this?
-        @numba.njit(inline="always")
+        @numba_njit(inline="always")
         def deepcopyop(x):
             return x
 
     else:
 
-        @numba.njit(inline="always")
+        @numba_njit(inline="always")
         def deepcopyop(x):
             return x.copy()
 
@@ -449,7 +467,7 @@ def deepcopyop(x):
 
 @numba_funcify.register(MakeSlice)
 def numba_funcify_MakeSlice(op, **kwargs):
-    @numba.njit
+    @numba_njit
     def makeslice(*x):
         return slice(*x)
 
@@ -458,7 +476,7 @@ def makeslice(*x):
 
 @numba_funcify.register(Shape)
 def numba_funcify_Shape(op, **kwargs):
-    @numba.njit(inline="always")
+    @numba_njit(inline="always")
     def shape(x):
         return np.asarray(np.shape(x))
 
@@ -469,7 +487,7 @@ def shape(x):
 def numba_funcify_Shape_i(op, **kwargs):
     i = op.i
 
-    @numba.njit(inline="always")
+    @numba_njit(inline="always")
     def shape_i(x):
         return np.shape(x)[i]
 
@@ -502,13 +520,13 @@ def numba_funcify_Reshape(op, **kwargs):
 
     if ndim == 0:
 
-        @numba.njit(inline="always")
+        @numba_njit(inline="always")
         def reshape(x, shape):
             return x.item()
 
     else:
 
-        @numba.njit(inline="always")
+        @numba_njit(inline="always")
         def reshape(x, shape):
             # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
             return np.reshape(
@@ -521,7 +539,7 @@ def reshape(x, shape):
 
 @numba_funcify.register(SpecifyShape)
 def numba_funcify_SpecifyShape(op, **kwargs):
-    @numba.njit
+    @numba_njit
     def specifyshape(x, shape):
         assert np.array_equal(x.shape, shape)
         return x
@@ -536,15 +554,15 @@ def int_to_float_fn(inputs, out_dtype):
 
         args_dtype = np.dtype(f"f{out_dtype.itemsize}")
 
-        @numba.njit(inline="always")
+        @numba_njit(inline="always")
         def inputs_cast(x):
             return x.astype(args_dtype)
 
     else:
         args_dtype_sz = max([_arg.type.numpy_dtype.itemsize for _arg in inputs])
         args_dtype = np.dtype(f"f{args_dtype_sz}")
 
-        @numba.njit(inline="always")
+        @numba_njit(inline="always")
         def inputs_cast(x):
             return x.astype(args_dtype)
 
@@ -559,7 +577,7 @@ def numba_funcify_Dot(op, node, **kwargs):
     out_dtype = node.outputs[0].type.numpy_dtype
     inputs_cast = int_to_float_fn(node.inputs, out_dtype)
 
-    @numba.njit(inline="always")
+    @numba_njit(inline="always")
     def dot(x, y):
         return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
 
@@ -571,7 +589,7 @@ def numba_funcify_Softplus(op, node, **kwargs):
 
     x_dtype = np.dtype(node.inputs[0].dtype)
 
-    @numba.njit
+    @numba_njit
     def softplus(x):
         if x < -37.0:
             return direct_cast(np.exp(x), x_dtype)
@@ -595,7 +613,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
 
         inputs_cast = int_to_float_fn(node.inputs, out_dtype)
 
-        @numba.njit(inline="always")
+        @numba_njit(inline="always")
         def cholesky(a):
             return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
 
@@ -612,7 +630,7 @@ def cholesky(a):
 
         ret_sig = get_numba_type(node.outputs[0].type)
 
-        @numba.njit
+        @numba_njit
         def cholesky(a):
             with numba.objmode(ret=ret_sig):
                 ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
@@ -641,7 +659,7 @@ def numba_funcify_Solve(op, node, **kwargs):
 
         ret_sig = get_numba_type(node.outputs[0].type)
 
-        @numba.njit
+        @numba_njit
         def solve(a, b):
             with numba.objmode(ret=ret_sig):
                 ret = scipy.linalg.solve_triangular(
@@ -656,7 +674,7 @@ def solve(a, b):
         out_dtype = node.outputs[0].type.numpy_dtype
         inputs_cast = int_to_float_fn(node.inputs, out_dtype)
 
-        @numba.njit(inline="always")
+        @numba_njit(inline="always")
         def solve(a, b):
             return np.linalg.solve(
                 inputs_cast(a),
@@ -672,7 +690,7 @@ def solve(a, b):
 def numba_funcify_BatchedDot(op, node, **kwargs):
     dtype = node.outputs[0].type.numpy_dtype
 
-    @numba.njit
+    @numba_njit
     def batched_dot(x, y):
         shape = x.shape[:-1] + y.shape[2:]
         z0 = np.empty(shape, dtype=dtype)
@@ -695,7 +713,7 @@ def numba_funcify_IfElse(op, **kwargs):
 
     if n_outs > 1:
 
-        @numba.njit
+        @numba_njit
         def ifelse(cond, *args):
             if cond:
                 res = args[:n_outs]
@@ -706,7 +724,7 @@ def ifelse(cond, *args):
 
     else:
 
-        @numba.njit
+        @numba_njit
         def ifelse(cond, *args):
             if cond:
                 res = args[:n_outs]
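
The numba_njit/numba_vectorize helpers exist so every jitted function picks up the cache setting in one place instead of repeating it at each call site, while still supporting both bare and parameterized decorator use. A standalone sketch of the same pattern (NUMBA_CACHE here is a hypothetical stand-in for config.numba__cache):

import numba

NUMBA_CACHE = True  # stand-in for config.numba__cache


def numba_njit(*args, **kwargs):
    # Bare use (`@numba_njit`): the decorated function arrives as the first
    # positional argument, so jit-compile it immediately with the cache default.
    if len(args) > 0 and callable(args[0]):
        return numba.njit(*args[1:], cache=NUMBA_CACHE, **kwargs)(args[0])
    # Parameterized use (`@numba_njit(inline="always")`): return a decorator.
    return numba.njit(*args, cache=NUMBA_CACHE, **kwargs)


@numba_njit
def add(x, y):
    return x + y


@numba_njit(fastmath=True)
def double(x):
    return 2.0 * x


print(add(2, 3), double(4.0))  # -> 5 8.0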

(Diffs for the remaining 9 changed files are not shown.)