Commit e8c042b

Separate interface and dispatch of numba_funcify
1 parent 561be04 commit e8c042b

11 files changed: +169 −121 lines
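
This commit splits `numba_funcify` into a thin public interface with a stable signature and a private `functools.singledispatch` function, `_numba_funcify`, that holds the per-type conversion logic; all existing registrations move onto the private dispatcher. A minimal standalone sketch of the pattern, using hypothetical names (`convert`, `_convert`, `MyOp`) rather than the actual Aesara code:

# A sketch of the interface/dispatch split; `convert`, `_convert`, and `MyOp`
# are hypothetical names, not the Aesara implementation.
from functools import singledispatch
from typing import Callable


def convert(obj, node=None, storage_map=None, **kwargs) -> Callable:
    """Public interface: one stable, documented entry point."""
    return _convert(obj, node=node, storage_map=storage_map, **kwargs)


@singledispatch
def _convert(obj, node=None, storage_map=None, **kwargs) -> Callable:
    """Private dispatcher: implementations register here by type."""
    raise NotImplementedError(f"No conversion registered for {type(obj)}")


class MyOp:
    """Hypothetical stand-in for an Aesara `Op` subclass."""


@_convert.register(MyOp)
def _convert_myop(obj, node=None, storage_map=None, **kwargs) -> Callable:
    # The real dispatcher returns a Numba-JITable callable; a plain function
    # is enough here to illustrate the shape of the contract.
    return lambda *inputs: inputs


fn = convert(MyOp())  # callers only ever touch the public `convert`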

aesara/link/numba/dispatch/__init__.py

+5 −1
@@ -1,5 +1,9 @@
 # isort: off
-from aesara.link.numba.dispatch.basic import numba_funcify, numba_const_convert
+from aesara.link.numba.dispatch.basic import (
+    numba_funcify,
+    numba_const_convert,
+    numba_njit,
+)

 # Load dispatch specializations
 import aesara.link.numba.dispatch.scalar

aesara/link/numba/dispatch/basic.py

+66 −28
@@ -3,7 +3,7 @@
 from contextlib import contextmanager
 from functools import singledispatch
 from textwrap import dedent
-from typing import Union
+from typing import TYPE_CHECKING, Callable, Optional, Union, cast

 import numba
 import numba.np.unsafe.ndarray as numba_ndarray
@@ -22,6 +22,7 @@
 from aesara.compile.ops import DeepCopyOp
 from aesara.graph.basic import Apply, NoParams
 from aesara.graph.fg import FunctionGraph
+from aesara.graph.op import Op
 from aesara.graph.type import Type
 from aesara.ifelse import IfElse
 from aesara.link.utils import (
@@ -48,6 +49,10 @@
 from aesara.tensor.type_other import MakeSlice, NoneConst


+if TYPE_CHECKING:
+    from aesara.graph.op import StorageMapType
+
+
 def numba_njit(*args, **kwargs):

     if len(args) > 0 and callable(args[0]):
@@ -335,9 +340,42 @@ def numba_const_convert(data, dtype=None, **kwargs):
     return data


+def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
+    """Convert `obj` to a Numba-JITable object."""
+    return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
+
+
 @singledispatch
-def numba_funcify(op, node=None, storage_map=None, **kwargs):
-    """Create a Numba compatible function from an Aesara `Op`."""
+def _numba_funcify(
+    obj,
+    node: Optional[Apply] = None,
+    storage_map: Optional["StorageMapType"] = None,
+    **kwargs,
+) -> Callable:
+    r"""Dispatch on Aesara object types to perform Numba conversions.
+
+    Arguments
+    ---------
+    obj
+        The object used to determine the appropriate conversion function based
+        on its type. This is generally an `Op` instance, but `FunctionGraph`\s
+        are also supported.
+    node
+        When `obj` is an `Op`, this value should be the corresponding `Apply` node.
+    storage_map
+        A storage map with, for example, the constant and `SharedVariable` values
+        of the graph being converted.
+
+    Returns
+    -------
+    A `Callable` that can be JIT-compiled in Numba using `numba.jit`.
+
+    """
+
+
+@_numba_funcify.register(Op)
+def numba_funcify_perform(op, node, storage_map=None, **kwargs) -> Callable:
+    """Create a Numba compatible function from an Aesara `Op.perform`."""

     warnings.warn(
         f"Numba will use object mode to run {op}'s perform method",
@@ -388,10 +426,10 @@ def perform(*inputs):
         ret = py_perform_return(inputs)
         return ret

-    return perform
+    return cast(Callable, perform)


-@numba_funcify.register(OpFromGraph)
+@_numba_funcify.register(OpFromGraph)
 def numba_funcify_OpFromGraph(op, node=None, **kwargs):

     _ = kwargs.pop("storage_map", None)
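
The generic `Op` registration works by running the op's Python `perform` method through Numba's object mode, which is why it emits the warning above. A self-contained sketch of that escape hatch using only public Numba API; the function names here are illustrative, not Aesara code:

import numba
import numpy as np


def python_only_perform(x):
    # Stand-in for an `Op.perform` implementation Numba cannot compile directly.
    return np.sort(x, kind="stable")


@numba.njit
def wrapped(x):
    # `numba.objmode` drops back to the Python interpreter; the type of every
    # variable assigned inside the block must be declared up front.
    with numba.objmode(ret="float64[:]"):
        ret = python_only_perform(x)
    return ret


print(wrapped(np.array([3.0, 1.0, 2.0])))  # -> [1. 2. 3.]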
@@ -413,7 +451,7 @@ def opfromgraph(*inputs):
     return opfromgraph


-@numba_funcify.register(FunctionGraph)
+@_numba_funcify.register(FunctionGraph)
 def numba_funcify_FunctionGraph(
     fgraph,
     node=None,
@@ -521,9 +559,9 @@ def {fn_name}({", ".join(input_names)}):
     return subtensor_def_src


-@numba_funcify.register(Subtensor)
-@numba_funcify.register(AdvancedSubtensor)
-@numba_funcify.register(AdvancedSubtensor1)
+@_numba_funcify.register(Subtensor)
+@_numba_funcify.register(AdvancedSubtensor)
+@_numba_funcify.register(AdvancedSubtensor1)
 def numba_funcify_Subtensor(op, node, **kwargs):

     subtensor_def_src = create_index_func(
@@ -539,8 +577,8 @@ def numba_funcify_Subtensor(op, node, **kwargs):
     return numba_njit(subtensor_fn)


-@numba_funcify.register(IncSubtensor)
-@numba_funcify.register(AdvancedIncSubtensor)
+@_numba_funcify.register(IncSubtensor)
+@_numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_IncSubtensor(op, node, **kwargs):

     incsubtensor_def_src = create_index_func(
@@ -556,7 +594,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
     return numba_njit(incsubtensor_fn)


-@numba_funcify.register(AdvancedIncSubtensor1)
+@_numba_funcify.register(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
     set_instead_of_inc = op.set_instead_of_inc
@@ -589,7 +627,7 @@ def advancedincsubtensor1(x, vals, idxs):
     return advancedincsubtensor1


-@numba_funcify.register(DeepCopyOp)
+@_numba_funcify.register(DeepCopyOp)
 def numba_funcify_DeepCopyOp(op, node, **kwargs):

     # Scalars are apparently returned as actual Python scalar types and not
@@ -611,26 +649,26 @@ def deepcopyop(x):
     return deepcopyop


-@numba_funcify.register(MakeSlice)
-def numba_funcify_MakeSlice(op, **kwargs):
+@_numba_funcify.register(MakeSlice)
+def numba_funcify_MakeSlice(op, node, **kwargs):
     @numba_njit
     def makeslice(*x):
         return slice(*x)

     return makeslice


-@numba_funcify.register(Shape)
-def numba_funcify_Shape(op, **kwargs):
+@_numba_funcify.register(Shape)
+def numba_funcify_Shape(op, node, **kwargs):
     @numba_njit(inline="always")
     def shape(x):
         return np.asarray(np.shape(x))

     return shape


-@numba_funcify.register(Shape_i)
-def numba_funcify_Shape_i(op, **kwargs):
+@_numba_funcify.register(Shape_i)
+def numba_funcify_Shape_i(op, node, **kwargs):
     i = op.i

     @numba_njit(inline="always")
@@ -660,8 +698,8 @@ def codegen(context, builder, signature, args):
     return sig, codegen


-@numba_funcify.register(Reshape)
-def numba_funcify_Reshape(op, **kwargs):
+@_numba_funcify.register(Reshape)
+def numba_funcify_Reshape(op, node, **kwargs):
     ndim = op.ndim

     if ndim == 0:
@@ -683,7 +721,7 @@ def reshape(x, shape):
     return reshape


-@numba_funcify.register(SpecifyShape)
+@_numba_funcify.register(SpecifyShape)
 def numba_funcify_SpecifyShape(op, node, **kwargs):
     shape_inputs = node.inputs[1:]
     shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
@@ -730,7 +768,7 @@ def inputs_cast(x):
     return inputs_cast


-@numba_funcify.register(Dot)
+@_numba_funcify.register(Dot)
 def numba_funcify_Dot(op, node, **kwargs):
     # Numba's `np.dot` does not support integer dtypes, so we need to cast to
     # float.
@@ -745,7 +783,7 @@ def dot(x, y):
     return dot


-@numba_funcify.register(Softplus)
+@_numba_funcify.register(Softplus)
 def numba_funcify_Softplus(op, node, **kwargs):

     x_dtype = np.dtype(node.inputs[0].dtype)
@@ -764,7 +802,7 @@ def softplus(x):
     return softplus


-@numba_funcify.register(Cholesky)
+@_numba_funcify.register(Cholesky)
 def numba_funcify_Cholesky(op, node, **kwargs):
     lower = op.lower

@@ -800,7 +838,7 @@ def cholesky(a):
     return cholesky


-@numba_funcify.register(Solve)
+@_numba_funcify.register(Solve)
 def numba_funcify_Solve(op, node, **kwargs):

     assume_a = op.assume_a
@@ -847,7 +885,7 @@ def solve(a, b):
     return solve


-@numba_funcify.register(BatchedDot)
+@_numba_funcify.register(BatchedDot)
 def numba_funcify_BatchedDot(op, node, **kwargs):
     dtype = node.outputs[0].type.numpy_dtype

@@ -868,7 +906,7 @@ def batched_dot(x, y):
     # optimizations are apparently already performed by Numba


-@numba_funcify.register(IfElse)
+@_numba_funcify.register(IfElse)
 def numba_funcify_IfElse(op, **kwargs):
     n_outs = op.n_outs

aesara/link/numba/dispatch/elemwise.py

+8 −7
@@ -12,6 +12,7 @@
 from aesara.graph.op import Op
 from aesara.link.numba.dispatch import basic as numba_basic
 from aesara.link.numba.dispatch.basic import (
+    _numba_funcify,
     create_numba_signature,
     create_tuple_creator,
     numba_funcify,
@@ -422,7 +423,7 @@ def axis_apply_fn(x):
     return axis_apply_fn


-@numba_funcify.register(Elemwise)
+@_numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):

     scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
@@ -474,7 +475,7 @@ def {inplace_elemwise_fn_name}({input_signature_str}):
     return elemwise_fn


-@numba_funcify.register(CAReduce)
+@_numba_funcify.register(CAReduce)
 def numba_funcify_CAReduce(op, node, **kwargs):
     axes = op.axis
     if axes is None:
@@ -512,7 +513,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
     return careduce_fn


-@numba_funcify.register(DimShuffle)
+@_numba_funcify.register(DimShuffle)
 def numba_funcify_DimShuffle(op, **kwargs):
     shuffle = tuple(op.shuffle)
     transposition = tuple(op.transposition)
@@ -590,7 +591,7 @@ def dimshuffle(x):
     return dimshuffle


-@numba_funcify.register(Softmax)
+@_numba_funcify.register(Softmax)
 def numba_funcify_Softmax(op, node, **kwargs):

     x_at = node.inputs[0]
@@ -627,7 +628,7 @@ def softmax_py_fn(x):
     return softmax


-@numba_funcify.register(SoftmaxGrad)
+@_numba_funcify.register(SoftmaxGrad)
 def numba_funcify_SoftmaxGrad(op, node, **kwargs):

     sm_at = node.inputs[1]
@@ -658,7 +659,7 @@ def softmax_grad_py_fn(dy, sm):
     return softmax_grad


-@numba_funcify.register(LogSoftmax)
+@_numba_funcify.register(LogSoftmax)
 def numba_funcify_LogSoftmax(op, node, **kwargs):

     x_at = node.inputs[0]
@@ -692,7 +693,7 @@ def log_softmax_py_fn(x):
     return log_softmax


-@numba_funcify.register(MaxAndArgmax)
+@_numba_funcify.register(MaxAndArgmax)
 def numba_funcify_MaxAndArgmax(op, node, **kwargs):
     axis = op.axis
     x_at = node.inputs[0]
