Skip to content

Commit

Permalink
Adapt Elemwise iterator for Numba Generators
Browse files Browse the repository at this point in the history
Also drops support for RandomState

Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
Co-authored-by: Adrian Seyboldt <aseyboldt@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 23, 2024
1 parent bae694d commit 762bcad
Show file tree
Hide file tree
Showing 14 changed files with 752 additions and 894 deletions.
12 changes: 1 addition & 11 deletions pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax
import numpy as np
from numpy.random import Generator, RandomState
from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
_coerce_to_uint32_array,
)
Expand Down Expand Up @@ -52,15 +52,6 @@ def assert_size_argument_jax_compatible(node):
raise NotImplementedError(SIZE_NOT_COMPATIBLE)


@jax_typify.register(RandomState)
def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state


@jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__()
Expand Down Expand Up @@ -185,7 +176,6 @@ def sample_fn(rng, size, dtype, *parameters):
return sample_fn


@jax_sample_fn.register(ptr.RandIntRV)
@jax_sample_fn.register(ptr.IntegersRV)
@jax_sample_fn.register(ptr.UniformRV)
def jax_sample_fn_uniform(op):
Expand Down
11 changes: 10 additions & 1 deletion pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytensor.link.numba.dispatch.vectorize_codegen import (
_vectorized,
encode_literals,
store_core_outputs,
)
from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.scalar.basic import (
Expand Down Expand Up @@ -483,10 +484,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
)

nin = len(node.inputs)
nout = len(node.outputs)
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)

input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
output_bc_patterns = tuple([out.type.broadcastable for out in node.inputs])
output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items())
core_output_shapes = tuple(() for _ in range(nout))

# numba doesn't support nested literals right now...
input_bc_patterns_enc = encode_literals(input_bc_patterns)
Expand All @@ -496,12 +502,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):

def elemwise_wrapper(*inputs):
return _vectorized(
scalar_op_fn,
core_op_fn,
input_bc_patterns_enc,
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
(), # constant_inputs
inputs,
core_output_shapes, # core_shapes
None, # size
)

# Pure python implementation, that will be used in tests
Expand Down
Loading

0 comments on commit 762bcad

Please sign in to comment.