
Commit

Add support for RandomVariable with Generators in Numba backend and drop support for RandomState

Co-authored-by: Adrian Seyboldt <aseyboldt@users.noreply.github.com>
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
3 people committed May 24, 2024
1 parent 4e787be commit 58899d1
Showing 17 changed files with 650 additions and 662 deletions.
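
The headline change: in NUMBA mode, RandomVariables now carry np.random.Generator state, and RandomState is no longer supported. A minimal sketch of what compiling such a graph looks like (assuming Numba is installed; RandomStream seeds Generator-backed shared variables):

    import pytensor
    from pytensor.tensor.random.utils import RandomStream

    # RandomStream creates shared np.random.Generator states (RandomState is gone)
    srng = RandomStream(seed=123)
    x = srng.normal(0, 1, size=(3,))

    # Compile for the Numba backend; Generator-backed RNGs are now supported there
    f = pytensor.function([], x, mode="NUMBA")
    print(f())  # fresh draws on each call, advancing the shared Generator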
8 changes: 7 additions & 1 deletion pytensor/compile/builders.py
@@ -27,7 +27,6 @@
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import in2out, node_rewriter
 from pytensor.graph.utils import MissingInputError
-from pytensor.tensor.rewriting.shape import ShapeFeature


 def infer_shape(outs, inputs, input_shapes):
@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes):
     # inside. We don't use the full ShapeFeature interface, but we
     # let it initialize itself with an empty fgraph, otherwise we will
     # need to do it manually
+
+    # TODO: ShapeFeature should live elsewhere
+    from pytensor.tensor.rewriting.shape import ShapeFeature
+
     for inp, inp_shp in zip(inputs, input_shapes):
         if inp_shp is not None and len(inp_shp) != inp.type.ndim:
             assert len(inp_shp) == inp.type.ndim
@@ -307,6 +310,7 @@ def __init__(
         connection_pattern: list[list[bool]] | None = None,
         strict: bool = False,
         name: str | None = None,
+        destroy_map: dict[int, tuple[int, ...]] | None = None,
         **kwargs,
     ):
         """
@@ -464,6 +468,7 @@ def __init__(
         if name is not None:
             assert isinstance(name, str), "name must be None or string object"
         self.name = name
+        self.destroy_map = destroy_map if destroy_map is not None else {}

     def __eq__(self, other):
         # TODO: recognize a copy
@@ -862,6 +867,7 @@ def make_node(self, *inputs):
             rop_overrides=self.rop_overrides,
             connection_pattern=self._connection_pattern,
             name=self.name,
+            destroy_map=self.destroy_map,
             **self.kwargs,
         )
         new_inputs = (
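
The new destroy_map argument lets an OpFromGraph declare that an output destroys (may overwrite) one of its inputs, the same mechanism RandomVariables use to update their RNG input in place. A hedged sketch of the call shape (the inner graph here is illustrative, not from this commit):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.builders import OpFromGraph

    x = pt.dvector("x")
    # Declare that computing output 0 may destroy input 0
    ofg = OpFromGraph([x], [2 * x], destroy_map={0: (0,)})

    inp = pt.dvector("inp")
    f = pytensor.function([inp], ofg(inp))
    print(f(np.array([1.0, 2.0])))  # [2. 4.]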
2 changes: 1 addition & 1 deletion pytensor/compile/mode.py
@@ -463,7 +463,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 NUMBA = Mode(
     NumbaLinker(),
     RewriteDatabaseQuery(
-        include=["fast_run"],
+        include=["fast_run", "numba"],
         exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
     ),
 )
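
Adding "numba" to the include set means any rewrite registered under that tag now fires whenever a function is compiled in this mode; user-facing compilation is unchanged. A minimal usage sketch (assuming Numba is installed):

    import pytensor
    import pytensor.tensor as pt

    x = pt.dvector("x")
    # Rewrites tagged "fast_run" and, after this change, "numba" are applied
    f = pytensor.function([x], pt.exp(x) + 1, mode="NUMBA")
    print(f([0.0, 1.0]))  # [2.         3.71828183]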
3 changes: 2 additions & 1 deletion pytensor/link/jax/dispatch/tensor_basic.py
@@ -70,7 +70,8 @@ def jax_funcify_ARange(op, node, **kwargs):
             constant_args.append(arg.value)
         else:
             # TODO: This might be failing without need (e.g., if arg = shape(x)[-1] + 1)!
-            raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR)
+            constant_args.append(None)
+            # raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR)

     constant_start, constant_stop, constant_step = constant_args
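
Rather than raising, non-constant arange arguments are now recorded as None, presumably so the dispatcher can fall back to the values supplied at run time. A rough sketch of that constant-or-runtime pattern (illustrative only; arange_fallback is a hypothetical helper, not the actual jax_funcify_ARange body):

    import jax.numpy as jnp

    def arange_fallback(start, stop, step,
                        constant_start=None, constant_stop=None, constant_step=None):
        # Prefer the value extracted at compile time; otherwise use the runtime one
        start = start if constant_start is None else constant_start
        stop = stop if constant_stop is None else constant_stop
        step = step if constant_step is None else constant_step
        return jnp.arange(start, stop, step)

    print(arange_fallback(0, 10, 1, constant_step=2))  # [0 2 4 6 8]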
6 changes: 6 additions & 0 deletions pytensor/link/numba/dispatch/basic.py
@@ -18,6 +18,7 @@
 from numba.extending import box, overload

 from pytensor import config
+from pytensor.compile import NUMBA
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Apply
@@ -434,6 +435,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
 def numba_funcify_OpFromGraph(op, node=None, **kwargs):
     _ = kwargs.pop("storage_map", None)

+    # Apply inner rewrites
+    # TODO: Not sure this is the right place to do this, should we have a rewrite that
+    # explicitly triggers the optimization of the inner graphs of OpFromGraph?
+    # The C-code defers it to the make_thunk phase
+    NUMBA.optimizer(op.fgraph)
     fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

     if len(op.fgraph.outputs) == 1:
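
Running NUMBA.optimizer on op.fgraph means the inner graph of an OpFromGraph is rewritten with the same database query as the outer graph before Numba code generation. A small sketch of a graph that can benefit (assuming Numba is installed; whether exp(log(x)) is actually collapsed depends on the registered rewrites):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.builders import OpFromGraph

    x = pt.dvector("x")
    # An inner graph with an obvious simplification opportunity
    ofg = OpFromGraph([x], [pt.exp(pt.log(x))])

    inp = pt.dvector("inp")
    f = pytensor.function([inp], ofg(inp), mode="NUMBA")
    print(f(np.array([1.0, 2.0])))  # [1. 2.]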