3
3
from contextlib import contextmanager
4
4
from functools import singledispatch
5
5
from textwrap import dedent
6
- from typing import Union
6
+ from typing import TYPE_CHECKING , Callable , Optional , Union , cast
7
7
8
8
import numba
9
9
import numba .np .unsafe .ndarray as numba_ndarray
22
22
from aesara .compile .ops import DeepCopyOp
23
23
from aesara .graph .basic import Apply , NoParams
24
24
from aesara .graph .fg import FunctionGraph
25
+ from aesara .graph .op import Op
25
26
from aesara .graph .type import Type
26
27
from aesara .ifelse import IfElse
27
28
from aesara .link .utils import (
48
49
from aesara .tensor .type_other import MakeSlice , NoneConst
49
50
50
51
52
+ if TYPE_CHECKING :
53
+ from aesara .graph .op import StorageMapType
54
+
55
+
51
56
def numba_njit (* args , ** kwargs ):
52
57
53
58
if len (args ) > 0 and callable (args [0 ]):
@@ -335,9 +340,42 @@ def numba_const_convert(data, dtype=None, **kwargs):
335
340
return data
336
341
337
342
343
+ def numba_funcify (obj , node = None , storage_map = None , ** kwargs ) -> Callable :
344
+ """Convert `obj` to a Numba-JITable object."""
345
+ return _numba_funcify (obj , node = node , storage_map = storage_map , ** kwargs )
346
+
347
+
338
348
@singledispatch
339
- def numba_funcify (op , node = None , storage_map = None , ** kwargs ):
340
- """Create a Numba compatible function from an Aesara `Op`."""
349
+ def _numba_funcify (
350
+ obj ,
351
+ node : Optional [Apply ] = None ,
352
+ storage_map : Optional ["StorageMapType" ] = None ,
353
+ ** kwargs ,
354
+ ) -> Callable :
355
+ r"""Dispatch on Aesara object types to perform Numba conversions.
356
+
357
+ Arguments
358
+ ---------
359
+ obj
360
+ The object used to determine the appropriate conversion function based
361
+ on its type. This is generally an `Op` instance, but `FunctionGraph`\s
362
+ are also supported.
363
+ node
364
+ When `obj` is an `Op`, this value should be the corresponding `Apply` node.
365
+ storage_map
366
+ A storage map with, for example, the constant and `SharedVariable` values
367
+ of the graph being converted.
368
+
369
+ Returns
370
+ -------
371
+ A `Callable` that can be JIT-compiled in Numba using `numba.jit`.
372
+
373
+ """
374
+
375
+
376
+ @_numba_funcify .register (Op )
377
+ def numba_funcify_perform (op , node , storage_map = None , ** kwargs ) -> Callable :
378
+ """Create a Numba compatible function from an Aesara `Op.perform`."""
341
379
342
380
warnings .warn (
343
381
f"Numba will use object mode to run { op } 's perform method" ,
@@ -388,10 +426,10 @@ def perform(*inputs):
388
426
ret = py_perform_return (inputs )
389
427
return ret
390
428
391
- return perform
429
+ return cast ( Callable , perform )
392
430
393
431
394
- @numba_funcify .register (OpFromGraph )
432
+ @_numba_funcify .register (OpFromGraph )
395
433
def numba_funcify_OpFromGraph (op , node = None , ** kwargs ):
396
434
397
435
_ = kwargs .pop ("storage_map" , None )
@@ -413,7 +451,7 @@ def opfromgraph(*inputs):
413
451
return opfromgraph
414
452
415
453
416
- @numba_funcify .register (FunctionGraph )
454
+ @_numba_funcify .register (FunctionGraph )
417
455
def numba_funcify_FunctionGraph (
418
456
fgraph ,
419
457
node = None ,
@@ -521,9 +559,9 @@ def {fn_name}({", ".join(input_names)}):
521
559
return subtensor_def_src
522
560
523
561
524
- @numba_funcify .register (Subtensor )
525
- @numba_funcify .register (AdvancedSubtensor )
526
- @numba_funcify .register (AdvancedSubtensor1 )
562
+ @_numba_funcify .register (Subtensor )
563
+ @_numba_funcify .register (AdvancedSubtensor )
564
+ @_numba_funcify .register (AdvancedSubtensor1 )
527
565
def numba_funcify_Subtensor (op , node , ** kwargs ):
528
566
529
567
subtensor_def_src = create_index_func (
@@ -539,8 +577,8 @@ def numba_funcify_Subtensor(op, node, **kwargs):
539
577
return numba_njit (subtensor_fn )
540
578
541
579
542
- @numba_funcify .register (IncSubtensor )
543
- @numba_funcify .register (AdvancedIncSubtensor )
580
+ @_numba_funcify .register (IncSubtensor )
581
+ @_numba_funcify .register (AdvancedIncSubtensor )
544
582
def numba_funcify_IncSubtensor (op , node , ** kwargs ):
545
583
546
584
incsubtensor_def_src = create_index_func (
@@ -556,7 +594,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
556
594
return numba_njit (incsubtensor_fn )
557
595
558
596
559
- @numba_funcify .register (AdvancedIncSubtensor1 )
597
+ @_numba_funcify .register (AdvancedIncSubtensor1 )
560
598
def numba_funcify_AdvancedIncSubtensor1 (op , node , ** kwargs ):
561
599
inplace = op .inplace
562
600
set_instead_of_inc = op .set_instead_of_inc
@@ -589,7 +627,7 @@ def advancedincsubtensor1(x, vals, idxs):
589
627
return advancedincsubtensor1
590
628
591
629
592
- @numba_funcify .register (DeepCopyOp )
630
+ @_numba_funcify .register (DeepCopyOp )
593
631
def numba_funcify_DeepCopyOp (op , node , ** kwargs ):
594
632
595
633
# Scalars are apparently returned as actual Python scalar types and not
@@ -611,26 +649,26 @@ def deepcopyop(x):
611
649
return deepcopyop
612
650
613
651
614
- @numba_funcify .register (MakeSlice )
615
- def numba_funcify_MakeSlice (op , ** kwargs ):
652
+ @_numba_funcify .register (MakeSlice )
653
+ def numba_funcify_MakeSlice (op , node , ** kwargs ):
616
654
@numba_njit
617
655
def makeslice (* x ):
618
656
return slice (* x )
619
657
620
658
return makeslice
621
659
622
660
623
- @numba_funcify .register (Shape )
624
- def numba_funcify_Shape (op , ** kwargs ):
661
+ @_numba_funcify .register (Shape )
662
+ def numba_funcify_Shape (op , node , ** kwargs ):
625
663
@numba_njit (inline = "always" )
626
664
def shape (x ):
627
665
return np .asarray (np .shape (x ))
628
666
629
667
return shape
630
668
631
669
632
- @numba_funcify .register (Shape_i )
633
- def numba_funcify_Shape_i (op , ** kwargs ):
670
+ @_numba_funcify .register (Shape_i )
671
+ def numba_funcify_Shape_i (op , node , ** kwargs ):
634
672
i = op .i
635
673
636
674
@numba_njit (inline = "always" )
@@ -660,8 +698,8 @@ def codegen(context, builder, signature, args):
660
698
return sig , codegen
661
699
662
700
663
- @numba_funcify .register (Reshape )
664
- def numba_funcify_Reshape (op , ** kwargs ):
701
+ @_numba_funcify .register (Reshape )
702
+ def numba_funcify_Reshape (op , node , ** kwargs ):
665
703
ndim = op .ndim
666
704
667
705
if ndim == 0 :
@@ -683,7 +721,7 @@ def reshape(x, shape):
683
721
return reshape
684
722
685
723
686
- @numba_funcify .register (SpecifyShape )
724
+ @_numba_funcify .register (SpecifyShape )
687
725
def numba_funcify_SpecifyShape (op , node , ** kwargs ):
688
726
shape_inputs = node .inputs [1 :]
689
727
shape_input_names = ["shape_" + str (i ) for i in range (len (shape_inputs ))]
@@ -730,7 +768,7 @@ def inputs_cast(x):
730
768
return inputs_cast
731
769
732
770
733
- @numba_funcify .register (Dot )
771
+ @_numba_funcify .register (Dot )
734
772
def numba_funcify_Dot (op , node , ** kwargs ):
735
773
# Numba's `np.dot` does not support integer dtypes, so we need to cast to
736
774
# float.
@@ -745,7 +783,7 @@ def dot(x, y):
745
783
return dot
746
784
747
785
748
- @numba_funcify .register (Softplus )
786
+ @_numba_funcify .register (Softplus )
749
787
def numba_funcify_Softplus (op , node , ** kwargs ):
750
788
751
789
x_dtype = np .dtype (node .inputs [0 ].dtype )
@@ -764,7 +802,7 @@ def softplus(x):
764
802
return softplus
765
803
766
804
767
- @numba_funcify .register (Cholesky )
805
+ @_numba_funcify .register (Cholesky )
768
806
def numba_funcify_Cholesky (op , node , ** kwargs ):
769
807
lower = op .lower
770
808
@@ -800,7 +838,7 @@ def cholesky(a):
800
838
return cholesky
801
839
802
840
803
- @numba_funcify .register (Solve )
841
+ @_numba_funcify .register (Solve )
804
842
def numba_funcify_Solve (op , node , ** kwargs ):
805
843
806
844
assume_a = op .assume_a
@@ -847,7 +885,7 @@ def solve(a, b):
847
885
return solve
848
886
849
887
850
- @numba_funcify .register (BatchedDot )
888
+ @_numba_funcify .register (BatchedDot )
851
889
def numba_funcify_BatchedDot (op , node , ** kwargs ):
852
890
dtype = node .outputs [0 ].type .numpy_dtype
853
891
@@ -868,7 +906,7 @@ def batched_dot(x, y):
868
906
# optimizations are apparently already performed by Numba
869
907
870
908
871
- @numba_funcify .register (IfElse )
909
+ @_numba_funcify .register (IfElse )
872
910
def numba_funcify_IfElse (op , ** kwargs ):
873
911
n_outs = op .n_outs
874
912
0 commit comments