From 762bcad56811cb03a5c22fb0e05983c87f5c70bb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 22:13:35 +0200 Subject: [PATCH] Adapt Elemwise iterator for Numba Generators Also drops support for RandomState Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> Co-authored-by: Adrian Seyboldt --- pytensor/link/jax/dispatch/random.py | 12 +- pytensor/link/numba/dispatch/elemwise.py | 11 +- pytensor/link/numba/dispatch/random.py | 549 ++++++++---------- .../link/numba/dispatch/vectorize_codegen.py | 470 ++++++++++----- pytensor/tensor/random/__init__.py | 2 +- pytensor/tensor/random/basic.py | 80 +-- pytensor/tensor/random/op.py | 17 +- pytensor/tensor/random/rewriting/jax.py | 4 +- pytensor/tensor/random/type.py | 91 --- pytensor/tensor/random/utils.py | 10 +- pytensor/tensor/random/var.py | 18 +- scripts/mypy-failing.txt | 1 - tests/link/numba/test_basic.py | 6 +- tests/link/numba/test_random.py | 375 ++++++------ 14 files changed, 752 insertions(+), 894 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index e7c1a68efb..f175d16126 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -2,7 +2,7 @@ import jax import numpy as np -from numpy.random import Generator, RandomState +from numpy.random import Generator from numpy.random.bit_generator import ( # type: ignore[attr-defined] _coerce_to_uint32_array, ) @@ -52,15 +52,6 @@ def assert_size_argument_jax_compatible(node): raise NotImplementedError(SIZE_NOT_COMPATIBLE) -@jax_typify.register(RandomState) -def jax_typify_RandomState(state, **kwargs): - state = state.get_state(legacy=False) - state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] - # XXX: Is this a reasonable approach? - state["jax_state"] = state["state"]["key"][0:2] - return state - - @jax_typify.register(Generator) def jax_typify_Generator(rng, **kwargs): state = rng.__getstate__() @@ -185,7 +176,6 @@ def sample_fn(rng, size, dtype, *parameters): return sample_fn -@jax_sample_fn.register(ptr.RandIntRV) @jax_sample_fn.register(ptr.IntegersRV) @jax_sample_fn.register(ptr.UniformRV) def jax_sample_fn_uniform(op): diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 273850b97d..4cfbb2a1f5 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -23,6 +23,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( _vectorized, encode_literals, + store_core_outputs, ) from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.scalar.basic import ( @@ -483,10 +484,15 @@ def numba_funcify_Elemwise(op, node, **kwargs): op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs ) + nin = len(node.inputs) + nout = len(node.outputs) + core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) + input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs]) output_bc_patterns = tuple([out.type.broadcastable for out in node.inputs]) output_dtypes = tuple(out.type.dtype for out in node.outputs) inplace_pattern = tuple(op.inplace_pattern.items()) + core_output_shapes = tuple(() for _ in range(nout)) # numba doesn't support nested literals right now... 
input_bc_patterns_enc = encode_literals(input_bc_patterns) @@ -496,12 +502,15 @@ def numba_funcify_Elemwise(op, node, **kwargs): def elemwise_wrapper(*inputs): return _vectorized( - scalar_op_fn, + core_op_fn, input_bc_patterns_enc, output_bc_patterns_enc, output_dtypes_enc, inplace_pattern_enc, + (), # constant_inputs inputs, + core_output_shapes, # core_shapes + None, # size ) # Pure python implementation, that will be used in tests diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index a063cf21b7..1ea52e8eff 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -1,369 +1,288 @@ from collections.abc import Callable -from textwrap import dedent, indent -from typing import Any +from copy import copy +from functools import singledispatch +from textwrap import dedent +import numba import numba.np.unsafe.ndarray as numba_ndarray import numpy as np -from numba import _helperlib, types -from numba.core import cgutils -from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox -from numpy.random import RandomState +from numba import types +from numba.core.extending import overload import pytensor.tensor.random.basic as ptr -from pytensor.graph.basic import Apply +from pytensor.graph import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify +from pytensor.link.numba.dispatch.basic import direct_cast, numba_funcify +from pytensor.link.numba.dispatch.vectorize_codegen import ( + _vectorized, + encode_literals, + store_core_outputs, +) from pytensor.link.utils import ( compile_function_src, - get_name_for_object, - unique_name_generator, ) -from pytensor.tensor.basic import get_vector_length -from pytensor.tensor.random.type import RandomStateType - - -class RandomStateNumbaType(types.Type): - def __init__(self): - super().__init__(name="RandomState") - - -random_state_numba_type = RandomStateNumbaType() - - -@typeof_impl.register(RandomState) -def typeof_index(val, c): - return random_state_numba_type - - -@register_model(RandomStateNumbaType) -class RandomStateNumbaModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - # TODO: We can add support for boxing and unboxing - # the attributes that describe a RandomState so that - # they can be accessed inside njit functions, if required. - ("state_key", types.Array(types.uint32, 1, "C")), - ] - models.StructModel.__init__(self, dmm, fe_type, members) - - -@unbox(RandomStateNumbaType) -def unbox_random_state(typ, obj, c): - """Convert a `RandomState` object to a native `RandomStateNumbaModel` structure. - - Note that this will create a 'fake' structure which will just get the - `RandomState` objects accepted in Numba functions but the actual information - of the Numba's random state is stored internally and can be accessed - anytime using ``numba._helperlib.rnd_get_np_state_ptr()``. - """ - interval = cgutils.create_struct_proxy(typ)(c.context, c.builder) - is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred()) - return NativeValue(interval._getvalue(), is_error=is_error) - - -@box(RandomStateNumbaType) -def box_random_state(typ, val, c): - """Convert a native `RandomStateNumbaModel` structure to an `RandomState` object - using Numba's internal state array. 
- - Note that `RandomStateNumbaModel` is just a placeholder structure with no - inherent information about Numba internal random state, all that information - is instead retrieved from Numba using ``_helperlib.rnd_get_state()`` and a new - `RandomState` is constructed using the Numba's current internal state. +from pytensor.tensor.random.op import RandomVariable + + +@overload(copy) +def copy_NumPyRandomGenerator(rng): + def impl(rng): + # TODO: Open issue on Numba? + with numba.objmode(new_rng=types.npy_rng): + new_rng = copy(rng) + + return new_rng + + return impl + + +@singledispatch +def numba_core_rv_funcify(op: Op, node: Apply) -> Callable: + """Return the core function for a random variable operation.""" + raise NotImplementedError() + + +@numba_core_rv_funcify.register(ptr.UniformRV) +@numba_core_rv_funcify.register(ptr.TriangularRV) +@numba_core_rv_funcify.register(ptr.BetaRV) +@numba_core_rv_funcify.register(ptr.NormalRV) +@numba_core_rv_funcify.register(ptr.LogNormalRV) +@numba_core_rv_funcify.register(ptr.GammaRV) +@numba_core_rv_funcify.register(ptr.ParetoRV) +@numba_core_rv_funcify.register(ptr.ExponentialRV) +@numba_core_rv_funcify.register(ptr.WeibullRV) +@numba_core_rv_funcify.register(ptr.LogisticRV) +@numba_core_rv_funcify.register(ptr.VonMisesRV) +@numba_core_rv_funcify.register(ptr.PoissonRV) +@numba_core_rv_funcify.register(ptr.GeometricRV) +# @numba_core_rv_funcify.register(ptr.HyperGeometricRV) # Not implemented in numba +@numba_core_rv_funcify.register(ptr.WaldRV) +@numba_core_rv_funcify.register(ptr.LaplaceRV) +@numba_core_rv_funcify.register(ptr.BinomialRV) +@numba_core_rv_funcify.register(ptr.NegativeBinomialRV) +@numba_core_rv_funcify.register(ptr.MultinomialRV) +@numba_core_rv_funcify.register(ptr.DirichletRV) +@numba_core_rv_funcify.register(ptr.ChoiceRV) # the `p` argument is not supported +@numba_core_rv_funcify.register(ptr.PermutationRV) +def numba_core_rv_default(op, node): + """Create a default RV core numba function. + + @njit + def random(rng, i0, i1, ..., in): + return rng.name(i0, i1, ..., in) """ - pos, state_list = _helperlib.rnd_get_state(_helperlib.rnd_get_np_state_ptr()) - rng = RandomState() - rng.set_state(("MT19937", state_list, pos)) - class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(rng)) - return class_obj - + name = op.name -@numba_typify.register(RandomState) -def numba_typify_RandomState(state, **kwargs): - # The numba_typify in this case is just an passthrough function - # that synchronizes Numba's internal random state with the current - # RandomState object - ints, index = state.get_state()[1:3] - ptr = _helperlib.rnd_get_np_state_ptr() - _helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints])) - return state + inputs = [f"i{i}" for i in range(len(op.ndims_params))] + input_signature = ",".join(inputs) + func_src = dedent(f""" + def {name}(rng, {input_signature}): + return rng.{name}({input_signature}) + """) -def make_numba_random_fn(node, np_random_func): - """Create Numba implementations for existing Numba-supported ``np.random`` functions. + func = compile_function_src(func_src, name, {**globals()}) + return numba_basic.numba_njit(func) - The functions generated here add parameter broadcasting and the ``size`` - argument to the Numba-supported scalar ``np.random`` functions. 
- """ - if not isinstance(node.inputs[0].type, RandomStateType): - raise TypeError("Numba does not support NumPy `Generator`s") - - tuple_size = int(get_vector_length(node.inputs[1])) - size_dims = tuple_size - max(i.ndim for i in node.inputs[3:]) - - # Make a broadcast-capable version of the Numba supported scalar sampling - # function - bcast_fn_name = f"pytensor_random_{get_name_for_object(np_random_func)}" - - sized_fn_name = "sized_random_variable" - - unique_names = unique_name_generator( - [ - bcast_fn_name, - sized_fn_name, - "np", - "np_random_func", - "numba_vectorize", - "to_fixed_tuple", - "tuple_size", - "size_dims", - "rng", - "size", - "dtype", - ], - suffix_sep="_", - ) - - bcast_fn_input_names = ", ".join( - [unique_names(i, force_unique=True) for i in node.inputs[3:]] - ) - bcast_fn_global_env = { - "np_random_func": np_random_func, - "numba_vectorize": numba_basic.numba_vectorize, - } - - bcast_fn_src = f""" -@numba_vectorize -def {bcast_fn_name}({bcast_fn_input_names}): - return np_random_func({bcast_fn_input_names}) - """ - bcast_fn = compile_function_src( - bcast_fn_src, bcast_fn_name, {**globals(), **bcast_fn_global_env} - ) - - random_fn_input_names = ", ".join( - ["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]] - ) - # Now, create a Numba JITable function that implements the `size` parameter +@numba_core_rv_funcify.register(ptr.BernoulliRV) +def numba_core_BernoulliRV(op, node): out_dtype = node.outputs[1].type.numpy_dtype - random_fn_global_env = { - bcast_fn_name: bcast_fn, - "out_dtype": out_dtype, - } - if tuple_size > 0: - random_fn_body = dedent( - f""" - size = to_fixed_tuple(size, tuple_size) - - data = np.empty(size, dtype=out_dtype) - for i in np.ndindex(size[:size_dims]): - data[i] = {bcast_fn_name}({bcast_fn_input_names}) - - """ + @numba_basic.numba_njit() + def random(rng, p): + return ( + direct_cast(0, out_dtype) + if p < rng.uniform() + else direct_cast(1, out_dtype) ) - random_fn_global_env.update( - { - "np": np, - "to_fixed_tuple": numba_ndarray.to_fixed_tuple, - "tuple_size": tuple_size, - "size_dims": size_dims, - } - ) - else: - random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})""" - - sized_fn_src = dedent( - f""" -def {sized_fn_name}({random_fn_input_names}): -{indent(random_fn_body, " " * 4)} - return (rng, data) - """ - ) - random_fn = compile_function_src( - sized_fn_src, sized_fn_name, {**globals(), **random_fn_global_env} - ) - random_fn = numba_basic.numba_njit(random_fn) - - return random_fn - -@numba_funcify.register(ptr.UniformRV) -@numba_funcify.register(ptr.TriangularRV) -@numba_funcify.register(ptr.BetaRV) -@numba_funcify.register(ptr.NormalRV) -@numba_funcify.register(ptr.LogNormalRV) -@numba_funcify.register(ptr.GammaRV) -@numba_funcify.register(ptr.ParetoRV) -@numba_funcify.register(ptr.GumbelRV) -@numba_funcify.register(ptr.ExponentialRV) -@numba_funcify.register(ptr.WeibullRV) -@numba_funcify.register(ptr.LogisticRV) -@numba_funcify.register(ptr.VonMisesRV) -@numba_funcify.register(ptr.PoissonRV) -@numba_funcify.register(ptr.GeometricRV) -@numba_funcify.register(ptr.HyperGeometricRV) -@numba_funcify.register(ptr.WaldRV) -@numba_funcify.register(ptr.LaplaceRV) -@numba_funcify.register(ptr.BinomialRV) -@numba_funcify.register(ptr.MultinomialRV) -@numba_funcify.register(ptr.RandIntRV) # only the first two arguments are supported -@numba_funcify.register(ptr.ChoiceRV) # the `p` argument is not supported -@numba_funcify.register(ptr.PermutationRV) -def numba_funcify_RandomVariable(op, node, 
**kwargs): - name = op.name - np_random_func = getattr(np.random, name) - - return make_numba_random_fn(node, np_random_func) + return random -def create_numba_random_fn( - op: Op, - node: Apply, - scalar_fn: Callable[[str], str], - global_env: dict[str, Any] | None = None, -) -> Callable: - """Create a vectorized function from a callable that generates the ``str`` function body. - - TODO: This could/should be generalized for other simple function - construction cases that need unique-ified symbol names. - """ - np_random_fn_name = f"pytensor_random_{get_name_for_object(op.name)}" - - if global_env: - np_global_env = global_env.copy() - else: - np_global_env = {} +@numba_core_rv_funcify.register(ptr.HalfNormalRV) +def numba_core_HalfNormalRV(op, node): + @numba_basic.numba_njit + def random_fn(rng, loc, scale): + return loc + scale * np.abs(rng.standard_normal()) - np_global_env["np"] = np - np_global_env["numba_vectorize"] = numba_basic.numba_vectorize + return random_fn - unique_names = unique_name_generator( - [np_random_fn_name, *np_global_env.keys(), "rng", "size", "dtype"], - suffix_sep="_", - ) - np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]] - np_input_names = ", ".join(np_names) - np_random_fn_src = f""" -@numba_vectorize -def {np_random_fn_name}({np_input_names}): -{scalar_fn(*np_names)} - """ - np_random_fn = compile_function_src( - np_random_fn_src, np_random_fn_name, {**globals(), **np_global_env} - ) +@numba_core_rv_funcify.register(ptr.CauchyRV) +def numba_core_CauchyRV(op, node): + @numba_basic.numba_njit + def random(rng, loc, scale): + return (loc + rng.standard_cauchy()) / scale - return make_numba_random_fn(node, np_random_fn) + return random -@numba_funcify.register(ptr.NegBinomialRV) -def numba_funcify_NegBinomialRV(op, node, **kwargs): - return make_numba_random_fn(node, np.random.negative_binomial) +@numba_core_rv_funcify.register(ptr.CategoricalRV) +def core_CategoricalRV(op, node): + @numba_basic.numba_njit + def random_fn(rng, p, out): + unif_sample = rng.uniform(0, 1) + # TODO: Check if LLVM can lift constant cumsum(p) out of the loop + out[...] = np.searchsorted(np.cumsum(p), unif_sample) + random_fn.handles_out = True + return random_fn -@numba_funcify.register(ptr.CauchyRV) -def numba_funcify_CauchyRV(op, node, **kwargs): - def body_fn(loc, scale): - return f" return ({loc} + np.random.standard_cauchy()) / {scale}" - return create_numba_random_fn(op, node, body_fn) +@numba_core_rv_funcify.register(ptr.MvNormalRV) +def core_MvNormalRV(op, node): + @numba_basic.numba_njit + def random_fn(rng, mean, cov, out): + chol = np.linalg.cholesky(cov) + stdnorm = rng.normal(size=cov.shape[-1]) + # np.dot(chol, stdnorm, out=out) + # out[...] += mean + out[...] 
= mean + np.dot(chol, stdnorm) + + random_fn.handles_out = True + return random_fn -@numba_funcify.register(ptr.HalfNormalRV) -def numba_funcify_HalfNormalRV(op, node, **kwargs): - def body_fn(a, b): - return f" return {a} + {b} * abs(np.random.normal(0, 1))" +@numba_core_rv_funcify.register(ptr.GumbelRV) +def core_GumbelRV(op, node): + """Code adapted from Numpy Implementation - return create_numba_random_fn(op, node, body_fn) + https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L502-L511 + """ + @numba_basic.numba_njit + def random_fn(rng, loc, scale): + U = 1.0 - rng.random() + if U < 1.0: + return loc - scale * np.log(-np.log(U)) + else: + return random_fn(rng, loc, scale) -@numba_funcify.register(ptr.BernoulliRV) -def numba_funcify_BernoulliRV(op, node, **kwargs): - out_dtype = node.outputs[1].type.numpy_dtype + return random_fn - def body_fn(a): - return f""" - if {a} < np.random.uniform(0, 1): - return direct_cast(0, out_dtype) - else: - return direct_cast(1, out_dtype) - """ - - return create_numba_random_fn( - op, - node, - body_fn, - {"out_dtype": out_dtype, "direct_cast": numba_basic.direct_cast}, - ) +@numba_core_rv_funcify.register(ptr.VonMisesRV) +def core_VonMisesRV(op, node): + """Code adapted from Numpy Implementation -@numba_funcify.register(ptr.CategoricalRV) -def numba_funcify_CategoricalRV(op, node, **kwargs): - out_dtype = node.outputs[1].type.numpy_dtype - size_len = int(get_vector_length(node.inputs[1])) - p_ndim = node.inputs[-1].ndim + https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L855-L925 + """ @numba_basic.numba_njit - def categorical_rv(rng, size, dtype, p): - if not size_len: - size_tpl = p.shape[:-1] - else: - size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) - p = np.broadcast_to(p, size_tpl + p.shape[-1:]) - - # Workaround https://github.com/numba/numba/issues/8975 - if not size_len and p_ndim == 1: - unif_samples = np.asarray(np.random.uniform(0, 1)) + def random_fn(rng, mu, kappa): + if np.isnan(kappa): + return np.nan + if kappa < 1e-8: + # Use a uniform for very small values of kappa + return np.pi * (2 * rng.random() - 1) else: - unif_samples = np.random.uniform(0, 1, size_tpl) + # with double precision rho is zero until 1.4e-8 + if kappa < 1e-5: + # second order taylor expansion around kappa = 0 + # precise until relatively large kappas as second order is 0 + s = 1.0 / kappa + kappa + else: + if kappa <= 1e6: + # Path for 1e-5 <= kappa <= 1e6 + r = 1 + np.sqrt(1 + 4 * kappa * kappa) + rho = (r - np.sqrt(2 * r)) / (2 * kappa) + s = (1 + rho * rho) / (2 * rho) + else: + # Fallback to wrapped normal distribution for kappa > 1e6 + result = mu + np.sqrt(1.0 / kappa) * rng.standard_normal() + # Ensure result is within bounds + if result < -np.pi: + result += 2 * np.pi + if result > np.pi: + result -= 2 * np.pi + return result + + while True: + U = rng.random() + Z = np.cos(np.pi * U) + W = (1 + s * Z) / (s + Z) + Y = kappa * (s - W) + V = rng.random() + # V == 0.0 is ok here since Y >= 0 always leads + # to accept, while Y < 0 always rejects + if (Y * (2 - Y) - V >= 0) or (np.log(Y / V) + 1 - Y >= 0): + break + + U = rng.random() + + result = np.arccos(W) + if U < 0.5: + result = -result + result += mu + neg = result < 0 + mod = np.abs(result) + mod = np.mod(mod + np.pi, 2 * np.pi) - np.pi + if neg: + mod *= -1 + + return mod - res = np.empty(size_tpl, dtype=out_dtype) - for idx in 
np.ndindex(*size_tpl): - res[idx] = np.searchsorted(np.cumsum(p[idx]), unif_samples[idx]) + return random_fn - return (rng, res) - return categorical_rv +@numba_funcify.register(ptr.RandomVariable) +def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs): + _, size, _, *args = node.inputs + # None sizes are represented as empty tuple for the time being + # https://github.com/pymc-devs/pytensor/issues/568 + [size_len] = size.type.shape + size_is_None = size_len == 0 + inplace = op.inplace -@numba_funcify.register(ptr.DirichletRV) -def numba_funcify_DirichletRV(op, node, **kwargs): - out_dtype = node.outputs[1].type.numpy_dtype - alphas_ndim = node.inputs[3].type.ndim - neg_ind_shape_len = -alphas_ndim + 1 - size_len = int(get_vector_length(node.inputs[1])) - - if alphas_ndim > 1: - - @numba_basic.numba_njit - def dirichlet_rv(rng, size, dtype, alphas): - if size_len > 0: - size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) - if ( - 0 < alphas.ndim - 1 <= len(size_tpl) - and size_tpl[neg_ind_shape_len:] != alphas.shape[:-1] - ): - raise ValueError("Parameters shape and size do not match.") - samples_shape = size_tpl + alphas.shape[-1:] - else: - samples_shape = alphas.shape + # TODO: Add core_shape to node.inputs + if op.ndim_supp > 0: + raise NotImplementedError("Multivariate RandomVariable not implemented yet") - res = np.empty(samples_shape, dtype=out_dtype) - alphas_bcast = np.broadcast_to(alphas, samples_shape) + core_op_fn = numba_core_rv_funcify(op, node) + if not getattr(core_op_fn, "handles_out", False): + core_op_fn = store_core_outputs(core_op_fn, nin=len(node.inputs) - 2, nout=1) - for index in np.ndindex(*samples_shape[:-1]): - res[index] = np.random.dirichlet(alphas_bcast[index]) + batch_ndim = op.batch_ndim(node) - return (rng, res) + # numba doesn't support nested literals right now... 
+ input_bc_patterns = encode_literals( + tuple( + input_var.type.broadcastable[:batch_ndim] for input_var in node.inputs[3:] + ) + ) + output_bc_patterns = encode_literals( + (node.outputs[1].type.broadcastable[:batch_ndim],) + ) + output_dtypes = encode_literals((node.default_output().type.dtype,)) + inplace_pattern = encode_literals(()) + + def random_wrapper(rng, size, dtype, *inputs): + if not inplace: + rng = copy(rng) + + draws = _vectorized( + core_op_fn, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + (rng,), + inputs, + ((),), # TODO: correct core_shapes + None + if size_is_None + else numba_ndarray.to_fixed_tuple(size, size_len), # size + ) + return rng, draws - else: + def random(rng, size, dtype, *inputs): + pass - @numba_basic.numba_njit - def dirichlet_rv(rng, size, dtype, alphas): - size = numba_ndarray.to_fixed_tuple(size, size_len) - return (rng, np.random.dirichlet(alphas, size)) + @overload(random) + def ov_random(rng, size, dtype, *inputs): + return random_wrapper - return dirichlet_rv + return random diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index cd1cd2c298..25c6ea13c9 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -2,8 +2,9 @@ import base64 import pickle -from collections.abc import Sequence -from typing import Any +from collections.abc import Callable, Sequence +from textwrap import indent +from typing import Any, cast import numba import numpy as np @@ -11,13 +12,54 @@ from numba import TypingError, types from numba.core import cgutils from numba.core.base import BaseContext +from numba.core.types.misc import NoneType from numba.np import arrayobj +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.utils import compile_function_src + def encode_literals(literals: Sequence) -> str: return base64.encodebytes(pickle.dumps(literals)).decode() +def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: + """Create a Numba function that wraps a core function and stores its vectorized outputs. + + @njit + def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): + to0, to1, ..., ton = core_op_fn(i0, i1, ..., in) + o0[...] = to0 + o1[...] = to1 + ... + on[...] = ton + + """ + inputs = [f"i{i}" for i in range(nin)] + outputs = [f"o{i}" for i in range(nout)] + inner_outputs = [f"t{output}" for output in outputs] + + inp_signature = ", ".join(inputs) + out_signature = ", ".join(outputs) + inner_out_signature = ", ".join(inner_outputs) + store_outputs = "\n".join( + [ + f"{output}[...] 
= {inner_output}" + for output, inner_output in zip(outputs, inner_outputs) + ] + ) + func_src = f""" +def store_core_outputs({inp_signature}, {out_signature}): + {inner_out_signature} = core_op_fn({inp_signature}) +{indent(store_outputs, " " * 4)} +""" + global_env = {"core_op_fn": core_op_fn} + func = compile_function_src( + func_src, "store_core_outputs", {**globals(), **global_env} + ) + return cast(Callable, numba_basic.numba_njit(func)) + + _jit_options = { "fastmath": { "arcp", # Allow Reciprocal @@ -37,7 +79,10 @@ def _vectorized( output_bc_patterns, output_dtypes, inplace_pattern, - inputs, + constant_inputs_types, + input_types, + output_core_shape_types, + size_type, ): arg_types = [ scalar_func, @@ -45,7 +90,10 @@ def _vectorized( output_bc_patterns, output_dtypes, inplace_pattern, - inputs, + constant_inputs_types, + input_types, + output_core_shape_types, + size_type, ] if not isinstance(input_bc_patterns, types.Literal): @@ -68,34 +116,82 @@ def _vectorized( inplace_pattern = inplace_pattern.literal_value inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) - n_outputs = len(output_bc_patterns) + batch_ndim = len(input_bc_patterns[0]) + nin = len(constant_inputs_types) + len(input_types) + nout = len(output_bc_patterns) + + if nin == 0: + raise TypingError("Empty argument list to vectorized op.") + + if nout == 0: + raise TypingError("Empty list of outputs for vectorized op.") - if not len(inputs) > 0: - raise TypingError("Empty argument list to elemwise op.") + if not all(isinstance(input, types.Array) for input in input_types): + raise TypingError("Vectorized inputs must be arrays.") - if not n_outputs > 0: - raise TypingError("Empty list of outputs for elemwise op.") + if not all( + len(pattern) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns + ): + raise TypingError( + "Vectorized broadcastable patterns must have the same length." 
+ ) + + core_input_types = [] + for input_type, bc_pattern in zip(input_types, input_bc_patterns): + core_ndim = input_type.ndim - len(bc_pattern) + # TODO: Reconsider this + if core_ndim == 0: + core_input_type = input_type.dtype + else: + core_input_type = types.Array( + dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout + ) + core_input_types.append(core_input_type) - if not all(isinstance(input, types.Array) for input in inputs): - raise TypingError("Inputs to elemwise must be arrays.") - ndim = inputs[0].ndim + core_out_types = [ + types.Array(numba.from_dtype(np.dtype(dtype)), len(output_core_shape), "C") + for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types) + ] - if not all(input.ndim == ndim for input in inputs): - raise TypingError("Inputs to elemwise must have the same rank.") + out_types = [ + types.Array( + numba.from_dtype(np.dtype(dtype)), batch_ndim + len(output_core_shape), "C" + ) + for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types) + ] - if not all(len(pattern) == ndim for pattern in output_bc_patterns): - raise TypingError("Invalid output broadcasting pattern.") + for output_idx, input_idx in inplace_pattern: + output_type = input_types[input_idx] + core_out_types[output_idx] = types.Array( + dtype=output_type.dtype, + ndim=output_type.ndim - batch_ndim, + layout=input_type.layout, + ) + out_types[output_idx] = output_type - scalar_signature = typingctx.resolve_function_type( - scalar_func, [in_type.dtype for in_type in inputs], {} + core_signature = typingctx.resolve_function_type( + scalar_func, + [ + *constant_inputs_types, + *core_input_types, + *core_out_types, + ], + {}, ) + ret_type = types.Tuple(out_types) + + if len(output_dtypes) == 1: + ret_type = ret_type.types[0] + sig = ret_type(*arg_types) + # So we can access the constant values in codegen... 
input_bc_patterns_val = input_bc_patterns output_bc_patterns_val = output_bc_patterns output_dtypes_val = output_dtypes inplace_pattern_val = inplace_pattern - input_types = inputs + input_types = input_types + size_is_none = isinstance(size_type, NoneType) def codegen( ctx, @@ -103,8 +199,16 @@ def codegen( sig, args, ): - [_, _, _, _, _, inputs] = args + [_, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args + + constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) inputs = cgutils.unpack_tuple(builder, inputs) + output_core_shapes = [ + cgutils.unpack_tuple(builder, shape) + for shape in cgutils.unpack_tuple(builder, output_core_shapes) + ] + size = None if size_is_none else cgutils.unpack_tuple(builder, size) + inputs = [ arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs) @@ -116,6 +220,7 @@ def codegen( builder, in_shapes, input_bc_patterns_val, + size, ) outputs, output_types = make_outputs( @@ -127,6 +232,7 @@ def codegen( inplace_pattern_val, inputs, input_types, + output_core_shapes, ) make_loop_call( @@ -134,8 +240,9 @@ def codegen( ctx, builder, scalar_func, - scalar_signature, + core_signature, iter_shape, + constant_inputs, inputs, outputs, input_bc_patterns_val, @@ -160,69 +267,94 @@ def codegen( builder, sig.return_type, [out._getvalue() for out in outputs] ) - ret_types = [ - types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") - for dtype in output_dtypes - ] - - for output_idx, input_idx in inplace_pattern: - ret_types[output_idx] = input_types[input_idx] - - ret_type = types.Tuple(ret_types) - - if len(output_dtypes) == 1: - ret_type = ret_type.types[0] - sig = ret_type(*arg_types) - return sig, codegen def compute_itershape( ctx: BaseContext, builder: ir.IRBuilder, - in_shapes: tuple[ir.Instruction, ...], + in_shapes: list[list[ir.Instruction]], broadcast_pattern: tuple[tuple[bool, ...], ...], + size: list[ir.Instruction] | None, ): one = ir.IntType(64)(1) - ndim = len(in_shapes[0]) - shape = [None] * ndim - for i in range(ndim): - for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): - length = in_shape[i] - if bc[i]: - with builder.if_then( - builder.icmp_unsigned("!=", length, one), likely=False - ): - msg = ( - f"Input {j} to elemwise is expected to have shape 1 in axis {i}" - ) - ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) - elif shape[i] is not None: - with builder.if_then( - builder.icmp_unsigned("!=", length, shape[i]), likely=False - ): - with builder.if_else(builder.icmp_unsigned("==", length, one)) as ( - then, - otherwise, + batch_ndim = len(broadcast_pattern[0]) + shape = [None] * batch_ndim + if size is not None: + shape = size + for i in range(batch_ndim): + for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): + length = in_shape[i] + if bc[i]: + with builder.if_then( + builder.icmp_unsigned("!=", length, one), likely=False + ): + msg = f"Vectorized input {j} is expected to have shape 1 in axis {i}" + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) + else: + with builder.if_then( + builder.icmp_unsigned("!=", length, shape[i]), likely=False + ): + with builder.if_else( + builder.icmp_unsigned("==", length, one) + ) as ( + then, + otherwise, + ): + with then: + msg = ( + f"Incompatible vectorized shapes for input {j} and axis {i}. " + f"Input {j} has shape 1, but is not statically " + "known to have shape 1, and thus not broadcastable." 
+ ) + ctx.call_conv.return_user_exc( + builder, ValueError, (msg,) + ) + with otherwise: + msg = f"Vectorized input {j} has an incompatible shape in axis {i}." + ctx.call_conv.return_user_exc( + builder, ValueError, (msg,) + ) + else: + # Size is implied by the broadcast pattern + for i in range(batch_ndim): + for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): + length = in_shape[i] + if bc[i]: + with builder.if_then( + builder.icmp_unsigned("!=", length, one), likely=False + ): + msg = f"Vectorized input {j} is expected to have shape 1 in axis {i}" + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) + elif shape[i] is not None: + with builder.if_then( + builder.icmp_unsigned("!=", length, shape[i]), likely=False ): - with then: - msg = ( - f"Incompatible shapes for input {j} and axis {i} of " - f"elemwise. Input {j} has shape 1, but is not statically " - "known to have shape 1, and thus not broadcastable." - ) - ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) - with otherwise: - msg = ( - f"Input {j} to elemwise has an incompatible " - f"shape in axis {i}." - ) - ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) - else: - shape[i] = length - for i in range(ndim): - if shape[i] is None: - shape[i] = one + with builder.if_else( + builder.icmp_unsigned("==", length, one) + ) as ( + then, + otherwise, + ): + with then: + msg = ( + f"Incompatible vectorized shapes for input {j} and axis {i}. " + f"Input {j} has shape 1, but is not statically " + "known to have shape 1, and thus not broadcastable." + ) + ctx.call_conv.return_user_exc( + builder, ValueError, (msg,) + ) + with otherwise: + msg = f"Vectorized input {j} has an incompatible shape in axis {i}." + ctx.call_conv.return_user_exc( + builder, ValueError, (msg,) + ) + else: + shape[i] = length + for i in range(batch_ndim): + if shape[i] is None: + shape[i] = one return shape @@ -235,27 +367,32 @@ def make_outputs( inplace: tuple[tuple[int, int], ...], inputs: tuple[Any, ...], input_types: tuple[Any, ...], -): - arrays = [] - ar_types: list[types.Array] = [] + output_core_shapes: tuple, +) -> tuple[list[ir.Value], list[types.Array]]: + output_arrays = [] + output_arry_types = [] one = ir.IntType(64)(1) inplace_dict = dict(inplace) - for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)): + for i, (core_shape, bc, dtype) in enumerate( + zip(output_core_shapes, out_bc, dtypes) + ): if i in inplace_dict: - arrays.append(inputs[inplace_dict[i]]) - ar_types.append(input_types[inplace_dict[i]]) + output_arrays.append(inputs[inplace_dict[i]]) + output_arry_types.append(input_types[inplace_dict[i]]) # We need to incref once we return the inplace objects continue dtype = numba.from_dtype(np.dtype(dtype)) - arrtype = types.Array(dtype, len(iter_shape), "C") - ar_types.append(arrtype) + output_ndim = len(iter_shape) + len(core_shape) + arrtype = types.Array(dtype, output_ndim, "C") + output_arry_types.append(arrtype) # This is actually an internal numba function, I guess we could # call `numba.nd.unsafe.ndarray` instead? - shape = [ + batch_shape = [ length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc) ] + shape = batch_shape + core_shape array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) - arrays.append(array) + output_arrays.append(array) # If there is no inplace operation, we know that all output arrays # don't alias. Informing llvm can make it easier to vectorize. 
@@ -263,7 +400,7 @@ def make_outputs( # The first argument is the output pointer arg = builder.function.args[0] arg.add_attribute("noalias") - return arrays, ar_types + return output_arrays, output_arry_types def make_loop_call( @@ -273,6 +410,7 @@ def make_loop_call( scalar_func: Any, scalar_signature: types.FunctionType, iter_shape: tuple[ir.Instruction, ...], + constant_inputs: tuple[ir.Instruction, ...], inputs: tuple[ir.Instruction, ...], outputs: tuple[ir.Instruction, ...], input_bc: tuple[tuple[bool, ...], ...], @@ -281,18 +419,8 @@ def make_loop_call( output_types: tuple[Any, ...], ): safe = (False, False) - n_outputs = len(outputs) - - # context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape) - # Extract shape and stride information from the array. - # For later use in the loop body to do the indexing - def extract_array(aryty, obj): - shape = cgutils.unpack_tuple(builder, obj.shape) - strides = cgutils.unpack_tuple(builder, obj.strides) - data = obj.data - layout = aryty.layout - return (data, shape, strides, layout) + n_outputs = len(outputs) # TODO I think this is better than the noalias attribute # for the input, but self_ref isn't supported in a released @@ -304,12 +432,6 @@ def extract_array(aryty, obj): # input_scope_set = mod.add_metadata([input_scope, output_scope]) # output_scope_set = mod.add_metadata([input_scope, output_scope]) - inputs = tuple(extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs)) - - outputs = tuple( - extract_array(aryty, ary) for aryty, ary in zip(output_types, outputs) - ) - zero = ir.Constant(ir.IntType(64), 0) # Setup loops and initialize accumulators for outputs @@ -336,69 +458,105 @@ def extract_array(aryty, obj): # Load values from input arrays input_vals = [] - for array_info, bc in zip(inputs, input_bc): - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] - ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe) - val = builder.load(ptr) - # val.set_metadata("alias.scope", input_scope_set) - # val.set_metadata("noalias", output_scope_set) + for input, input_type, bc in zip(inputs, input_types, input_bc): + core_ndim = input_type.ndim - len(bc) + + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [ + zero + ] * core_ndim + ptr = cgutils.get_item_pointer2( + context, + builder, + input.data, + cgutils.unpack_tuple(builder, input.shape), + cgutils.unpack_tuple(builder, input.strides), + input_type.layout, + idxs_bc, + *safe, + ) + if core_ndim == 0: + # Retrive scalar item at index + val = builder.load(ptr) + # val.set_metadata("alias.scope", input_scope_set) + # val.set_metadata("noalias", output_scope_set) + else: + # Retrieve array item at index + # This is a streamlined version of Numba's `GUArrayArg.load` + # TODO check layout arg! + core_arry_type = types.Array( + dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout + ) + core_array = context.make_array(core_arry_type)(context, builder) + core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:] + core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:] + itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype)) + context.populate_array( + core_array, + # TODO whey do we need to bitcast? + data=builder.bitcast(ptr, core_array.data.type), + shape=cgutils.pack_array(builder, core_shape), + strides=cgutils.pack_array(builder, core_strides), + itemsize=context.get_constant(types.intp, itemsize), + # TODO what is meminfo about? 
+ meminfo=None, + ) + val = core_array._getvalue() + input_vals.append(val) + # Create output slices to pass to inner func + output_slices = [] + for output, output_type, bc in zip(outputs, output_types, output_bc): + core_ndim = output_type.ndim - len(bc) + size_type = output.shape.type.element # type: ignore + output_shape = cgutils.unpack_tuple(builder, output.shape) # type: ignore + output_strides = cgutils.unpack_tuple(builder, output.strides) # type: ignore + + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [ + zero + ] * core_ndim + ptr = cgutils.get_item_pointer2( + context, + builder, + output.data, # type:ignore + output_shape, + output_strides, + output_type.layout, + idxs_bc, + *safe, + ) + + # Retrieve array item at index + # This is a streamlined version of Numba's `GUArrayArg.load` + core_arry_type = types.Array( + dtype=output_type.dtype, ndim=core_ndim, layout=output_type.layout + ) + core_array = context.make_array(core_arry_type)(context, builder) + core_shape = output_shape[-core_ndim:] if core_ndim > 0 else [] + core_strides = output_strides[-core_ndim:] if core_ndim > 0 else [] + itemsize = context.get_abi_sizeof(context.get_data_type(output_type.dtype)) + context.populate_array( + core_array, + # TODO whey do we need to bitcast? + data=builder.bitcast(ptr, core_array.data.type), + shape=cgutils.pack_array(builder, core_shape, ty=size_type), + strides=cgutils.pack_array(builder, core_strides, ty=size_type), + itemsize=context.get_constant(types.intp, itemsize), + # TODO what is meminfo about? + meminfo=None, + ) + val = core_array._getvalue() + output_slices.append(val) + inner_codegen = context.get_function(scalar_func, scalar_signature) if isinstance(scalar_signature.args[0], types.StarArgTuple | types.StarArgUniTuple): input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)] - output_values = inner_codegen(builder, input_vals) - if isinstance(scalar_signature.return_type, types.Tuple | types.UniTuple): - output_values = cgutils.unpack_tuple(builder, output_values) - func_output_types = scalar_signature.return_type.types - else: - output_values = [output_values] - func_output_types = [scalar_signature.return_type] - - # Update output value or accumulators respectively - for i, ((accu, _), value) in enumerate(zip(output_accumulator, output_values)): - if accu is not None: - load = builder.load(accu) - # load.set_metadata("alias.scope", output_scope_set) - # load.set_metadata("noalias", input_scope_set) - new_value = builder.fadd(load, value) - builder.store(new_value, accu) - # TODO belongs to noalias scope - # store.set_metadata("alias.scope", output_scope_set) - # store.set_metadata("noalias", input_scope_set) - else: - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, output_bc[i])] - ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc) - # store = builder.store(value, ptr) - value = context.cast( - builder, value, func_output_types[i], output_types[i].dtype - ) - arrayobj.store_item(context, builder, output_types[i], value, ptr) - # store.set_metadata("alias.scope", output_scope_set) - # store.set_metadata("noalias", input_scope_set) + inner_codegen(builder, [*constant_inputs, *input_vals, *output_slices]) - # Close the loops and write accumulator values to the output arrays + # Close the loops for depth, loop in enumerate(loop_stack[::-1]): - for output, (accu, accu_depth) in enumerate(output_accumulator): - if accu_depth == depth: - idxs_bc = [ - zero if bc else idx for idx, bc in zip(idxs, 
output_bc[output]) - ] - ptr = cgutils.get_item_pointer2( - context, builder, *outputs[output], idxs_bc - ) - load = builder.load(accu) - # load.set_metadata("alias.scope", output_scope_set) - # load.set_metadata("noalias", input_scope_set) - # store = builder.store(load, ptr) - load = context.cast( - builder, load, func_output_types[output], output_types[output].dtype - ) - arrayobj.store_item(context, builder, output_types[output], load, ptr) - # store.set_metadata("alias.scope", output_scope_set) - # store.set_metadata("noalias", input_scope_set) loop.__exit__(None, None, None) return diff --git a/pytensor/tensor/random/__init__.py b/pytensor/tensor/random/__init__.py index a1cd42f789..78994fd40c 100644 --- a/pytensor/tensor/random/__init__.py +++ b/pytensor/tensor/random/__init__.py @@ -2,5 +2,5 @@ import pytensor.tensor.random.rewriting import pytensor.tensor.random.utils from pytensor.tensor.random.basic import * -from pytensor.tensor.random.op import RandomState, default_rng +from pytensor.tensor.random.op import default_rng from pytensor.tensor.random.utils import RandomStream diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index a42a18dfed..4a174b5f2a 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -7,15 +7,10 @@ import pytensor from pytensor.tensor.basic import arange, as_tensor_variable from pytensor.tensor.random.op import RandomVariable -from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType from pytensor.tensor.random.utils import ( broadcast_params, supp_shape_from_ref_param_shape, ) -from pytensor.tensor.random.var import ( - RandomGeneratorSharedVariable, - RandomStateSharedVariable, -) try: @@ -629,7 +624,7 @@ def __call__( @classmethod def rng_fn_scipy( cls, - rng: np.random.Generator | np.random.RandomState, + rng: np.random.Generator, loc: np.ndarray | float, scale: np.ndarray | float, size: list[int] | int | None, @@ -1606,7 +1601,7 @@ def __call__(self, n, p, size=None, **kwargs): binomial = BinomialRV() -class NegBinomialRV(ScipyRandomVariable): +class NegativeBinomialRV(RandomVariable): r"""A negative binomial discrete random variable. The probability mass function for `nbinom` for the number :math:`k` of draws @@ -1619,7 +1614,7 @@ class NegBinomialRV(ScipyRandomVariable): """ - name = "nbinom" + name = "negative_binomial" ndim_supp = 0 ndims_params = [0, 0] dtype = "int64" @@ -1647,13 +1642,8 @@ def __call__(self, n, p, size=None, **kwargs): """ return super().__call__(n, p, size=size, **kwargs) - @classmethod - def rng_fn_scipy(cls, rng, n, p, size): - return stats.nbinom.rvs(n, p, size=size, random_state=rng) - -nbinom = NegBinomialRV() -negative_binomial = NegBinomialRV() +negative_binomial = nbinom = NegativeBinomialRV() class BetaBinomialRV(ScipyRandomVariable): @@ -1914,59 +1904,6 @@ def rng_fn(cls, rng, p, size): categorical = CategoricalRV() -class RandIntRV(RandomVariable): - r"""A discrete uniform random variable. - - Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s. - - """ - - name = "randint" - ndim_supp = 0 - ndims_params = [0, 0] - dtype = "int64" - _print_name = ("randint", "\\operatorname{randint}") - - def __call__(self, low, high=None, size=None, **kwargs): - r"""Draw samples from a discrete uniform distribution. - - Signature - --------- - - `() -> ()` - - Parameters - ---------- - low - Lower boundary of the output interval. 
All values generated will - be greater than or equal to `low`, unless `high=None`, in which case - all values generated are greater than or equal to `0` and - smaller than `low` (exclusive). - high - Upper boundary of the output interval. All values generated - will be smaller than `high` (exclusive). - size - Sample shape. If the given size is `(m, n, k)`, then `m * n * k` - independent, identically distributed samples are - returned. Default is `None`, in which case a single - sample is returned. - - """ - if high is None: - low, high = 0, low - return super().__call__(low, high, size=size, **kwargs) - - def make_node(self, rng, *args, **kwargs): - if not isinstance( - getattr(rng, "type", None), RandomStateType | RandomStateSharedVariable - ): - raise TypeError("`randint` is only available for `RandomStateType`s") - return super().make_node(rng, *args, **kwargs) - - -randint = RandIntRV() - - class IntegersRV(RandomVariable): r"""A discrete uniform random variable. @@ -2007,14 +1944,6 @@ def __call__(self, low, high=None, size=None, **kwargs): low, high = 0, low return super().__call__(low, high, size=size, **kwargs) - def make_node(self, rng, *args, **kwargs): - if not isinstance( - getattr(rng, "type", None), - RandomGeneratorType | RandomGeneratorSharedVariable, - ): - raise TypeError("`integers` is only available for `RandomGeneratorType`s") - return super().make_node(rng, *args, **kwargs) - integers = IntegersRV() @@ -2153,7 +2082,6 @@ def permutation(x, **kwargs): "permutation", "choice", "integers", - "randint", "categorical", "multinomial", "betabinom", diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index f4c4403bc7..4ae558aacd 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -18,7 +18,7 @@ get_vector_length, infer_static_shape, ) -from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType +from pytensor.tensor.random.type import RandomGeneratorType, RandomType from pytensor.tensor.random.utils import ( compute_batch_shape, explicit_expand_dims, @@ -254,9 +254,8 @@ def make_node(self, rng, size, dtype, *dist_params): Parameters ---------- - rng: RandomGeneratorType or RandomStateType - Existing PyTensor `Generator` or `RandomState` object to be used. Creates a - new one, if `None`. + rng: RandomGeneratorType + Existing PyTensor `Generator` object to be used. Creates a new one, if `None`. size: int or Sequence NumPy-like size parameter. 
dtype: str @@ -284,7 +283,7 @@ def make_node(self, rng, size, dtype, *dist_params): rng = pytensor.shared(np.random.default_rng()) elif not isinstance(rng.type, RandomType): raise TypeError( - "The type of rng should be an instance of either RandomGeneratorType or RandomStateType" + "The type of rng should be an instance of RandomGeneratorType " ) shape = self._infer_shape(size, dist_params) @@ -374,14 +373,6 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed) -class RandomStateConstructor(AbstractRNGConstructor): - random_type = RandomStateType() - random_constructor = "RandomState" - - -RandomState = RandomStateConstructor() - - class DefaultGeneratorMakerOp(AbstractRNGConstructor): random_type = RandomGeneratorType() random_constructor = "default_rng" diff --git a/pytensor/tensor/random/rewriting/jax.py b/pytensor/tensor/random/rewriting/jax.py index c4f3a8207a..d5c4e5ad8c 100644 --- a/pytensor/tensor/random/rewriting/jax.py +++ b/pytensor/tensor/random/rewriting/jax.py @@ -12,7 +12,7 @@ HalfNormalRV, InvGammaRV, LogNormalRV, - NegBinomialRV, + NegativeBinomialRV, WaldRV, _gamma, beta, @@ -88,7 +88,7 @@ def geometric_from_uniform(fgraph, node): return [next_rng, cast(g, dtype=node.default_output().dtype)] -@node_rewriter([NegBinomialRV]) +@node_rewriter([NegativeBinomialRV]) def negative_binomial_from_gamma_poisson(fgraph, node): rng, *other_inputs, n, p = node.inputs next_rng, g = _gamma.make_node(rng, *other_inputs, n, (1 - p) / p).outputs diff --git a/pytensor/tensor/random/type.py b/pytensor/tensor/random/type.py index 527d3f3d6b..7f2a156271 100644 --- a/pytensor/tensor/random/type.py +++ b/pytensor/tensor/random/type.py @@ -31,97 +31,6 @@ def may_share_memory(a: T, b: T): return a._bit_generator is b._bit_generator # type: ignore[attr-defined] -class RandomStateType(RandomType[np.random.RandomState]): - r"""A Type wrapper for `numpy.random.RandomState`. - - The reason this exists (and `Generic` doesn't suffice) is that - `RandomState` objects that would appear to be equal do not compare equal - with the ``==`` operator. - - This `Type` also works with a ``dict`` derived from - `RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter` - is explicitly set to ``True``. - - """ - - def __repr__(self): - return "RandomStateType" - - def filter(self, data, strict: bool = False, allow_downcast=None): - """ - XXX: This doesn't convert `data` to the same type of underlying RNG type - as `self`. It really only checks that `data` is of the appropriate type - to be a valid `RandomStateType`. - - In other words, it serves as a `Type.is_valid_value` implementation, - but, because the default `Type.is_valid_value` depends on - `Type.filter`, we need to have it here to avoid surprising circular - dependencies in sub-classes. - """ - if isinstance(data, np.random.RandomState): - return data - - if not strict and isinstance(data, dict): - gen_keys = ["bit_generator", "gauss", "has_gauss", "state"] - state_keys = ["key", "pos"] - - for key in gen_keys: - if key not in data: - raise TypeError() - - for key in state_keys: - if key not in data["state"]: - raise TypeError() - - state_key = data["state"]["key"] - if state_key.shape == (624,) and state_key.dtype == np.uint32: - # TODO: Add an option to convert to a `RandomState` instance? 
- return data - - raise TypeError() - - @staticmethod - def values_eq(a, b): - sa = a if isinstance(a, dict) else a.get_state(legacy=False) - sb = b if isinstance(b, dict) else b.get_state(legacy=False) - - def _eq(sa, sb): - for key in sa: - if isinstance(sa[key], dict): - if not _eq(sa[key], sb[key]): - return False - elif isinstance(sa[key], np.ndarray): - if not np.array_equal(sa[key], sb[key]): - return False - else: - if sa[key] != sb[key]: - return False - - return True - - return _eq(sa, sb) - - def __eq__(self, other): - return type(self) == type(other) - - def __hash__(self): - return hash(type(self)) - - -# Register `RandomStateType`'s C code for `ViewOp`. -pytensor.compile.register_view_op_c_code( - RandomStateType, - """ - Py_XDECREF(%(oname)s); - %(oname)s = %(iname)s; - Py_XINCREF(%(oname)s); - """, - 1, -) - -random_state_type = RandomStateType() - - class RandomGeneratorType(RandomType[np.random.Generator]): r"""A Type wrapper for `numpy.random.Generator`. diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index dfaf2fc2e7..704a8b30d8 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -211,9 +211,7 @@ def __init__( self, seed: int | None = None, namespace: ModuleType | None = None, - rng_ctor: Literal[ - np.random.RandomState, np.random.Generator - ] = np.random.default_rng, + rng_ctor: Literal[np.random.Generator] = np.random.default_rng, ): if namespace is None: from pytensor.tensor.random import basic # pylint: disable=import-self @@ -225,12 +223,6 @@ def __init__( self.default_instance_seed = seed self.state_updates = [] self.gen_seedgen = np.random.SeedSequence(seed) - - if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState): - # The legacy state does not accept `SeedSequence`s directly - def rng_ctor(seed): - return np.random.RandomState(np.random.MT19937(seed)) - self.rng_ctor = rng_ctor def __getattr__(self, obj): diff --git a/pytensor/tensor/random/var.py b/pytensor/tensor/random/var.py index c03b3046ab..58cc858f2d 100644 --- a/pytensor/tensor/random/var.py +++ b/pytensor/tensor/random/var.py @@ -3,12 +3,7 @@ import numpy as np from pytensor.compile.sharedvalue import SharedVariable, shared_constructor -from pytensor.tensor.random.type import random_generator_type, random_state_type - - -class RandomStateSharedVariable(SharedVariable): - def __str__(self): - return self.name or f"RandomStateSharedVariable({self.container!r})" +from pytensor.tensor.random.type import random_generator_type class RandomGeneratorSharedVariable(SharedVariable): @@ -23,11 +18,12 @@ def randomgen_constructor( ): r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`.""" if isinstance(value, np.random.RandomState): - rng_sv_type = RandomStateSharedVariable - rng_type = random_state_type - elif isinstance(value, np.random.Generator): - rng_sv_type = RandomGeneratorSharedVariable - rng_type = random_generator_type + raise TypeError( + "`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead." 
+ ) + + rng_sv_type = RandomGeneratorSharedVariable + rng_type = random_generator_type if not borrow: value = copy.deepcopy(value) diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 52fa8dc502..d73c19752b 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -8,7 +8,6 @@ pytensor/graph/rewriting/basic.py pytensor/ifelse.py pytensor/link/basic.py pytensor/link/numba/dispatch/elemwise.py -pytensor/link/numba/dispatch/random.py pytensor/link/numba/dispatch/scan.py pytensor/printing.py pytensor/raise_op.py diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 5830983518..9f773e5119 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -229,6 +229,7 @@ def compare_numba_and_py( numba_mode=numba_mode, py_mode=py_mode, updates=None, + eval_obj_mode: bool = True, ) -> tuple[Callable, Any]: """Function to compare python graph output and Numba compiled output for testing equality @@ -247,6 +248,8 @@ def compare_numba_and_py( provided uses `np.testing.assert_allclose`. updates Updates to be passed to `pytensor.function`. + eval_obj_mode : bool, default True + Whether to do an isolated call in object mode. Used for test coverage Returns ------- @@ -283,7 +286,8 @@ def assert_fn(x, y): numba_res = pytensor_numba_fn(*inputs) # Get some coverage - eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode) + if eval_obj_mode: + eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode) if len(fn_outputs) > 1: for j, p in zip(numba_res, py_res): diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index a9b00b4290..e67a1834b9 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -22,34 +22,139 @@ rng = np.random.default_rng(42849) +@pytest.mark.parametrize("mu_shape", [(), (3,), (5, 1)]) +@pytest.mark.parametrize("sigma_shape", [(), (1,), (5, 3)]) +@pytest.mark.parametrize("size_type", (None, "constant", "mutable")) +def test_random_size(mu_shape, sigma_shape, size_type): + test_value_rng = np.random.default_rng(637) + mu = test_value_rng.normal(size=mu_shape) + sigma = np.exp(test_value_rng.normal(size=sigma_shape)) + + # For testing + rng = np.random.default_rng(123) + pt_rng = shared(rng) + if size_type is None: + size = None + pt_size = None + elif size_type == "constant": + size = (5, 3) + pt_size = pt.as_tensor(size, dtype="int64") + else: + size = (5, 3) + pt_size = shared(np.array(size, dtype="int64"), shape=(2,)) + + next_rng, x = pt.random.normal(mu, sigma, rng=pt_rng, size=pt_size).owner.outputs + fn = function([], x, updates={pt_rng: next_rng}, mode="NUMBA") + + res1 = fn() + np.testing.assert_allclose( + res1, + rng.normal(mu, sigma, size=size), + ) + + res2 = fn() + np.testing.assert_allclose( + res2, + rng.normal(mu, sigma, size=size), + ) + + pt_rng.set_value(np.random.default_rng(123)) + res3 = fn() + np.testing.assert_array_equal(res1, res3) + + if size_type == "mutable" and len(mu_shape) < 2 and len(sigma_shape) < 2: + pt_size.set_value(np.array((6, 3), dtype="int64")) + res4 = fn() + assert res4.shape == (6, 3) + + +def test_rng_copy(): + rng = shared(np.random.default_rng(123)) + x = pt.random.normal(rng=rng) + + fn = function([], x, mode="NUMBA") + np.testing.assert_array_equal(fn(), fn()) + + rng.type.values_eq(rng.get_value(), np.random.default_rng(123)) + + +def test_rng_non_default_update(): + rng = shared(np.random.default_rng(1)) + rng_new = shared(np.random.default_rng(2)) + + x = pt.random.normal(size=10, 
rng=rng) + fn = function([], x, updates={rng: rng_new}, mode=numba_mode) + + ref = np.random.default_rng(1).normal(size=10) + np.testing.assert_allclose(fn(), ref) + + ref = np.random.default_rng(2).normal(size=10) + np.testing.assert_allclose(fn(), ref) + np.testing.assert_allclose(fn(), ref) + + +def test_categorical_rv(): + """This is also a smoke test for a vector input scalar output RV""" + p = np.array( + [ + [ + [1.0, 0, 0, 0], + [0.0, 1.0, 0, 0], + [0.0, 0, 1.0, 0], + ], + [ + [0, 0, 0, 1.0], + [0, 0, 0, 1.0], + [0, 0, 0, 1.0], + ], + ] + ) + x = pt.random.categorical(p=p, size=None) + updates = {x.owner.inputs[0]: x.owner.outputs[0]} + fn = function([], x, updates=updates, mode="NUMBA") + res = fn() + assert np.all(np.argmax(p, axis=-1) == res) + + # Batch size + x = pt.random.categorical(p=p, size=(3, *p.shape[:-1])) + fn = function([], x, updates=updates, mode="NUMBA") + new_res = fn() + assert new_res.shape == (3, *res.shape) + for new_res_row in new_res: + assert np.all(new_res_row == res) + + +def test_multivariate_normal(): + """This is also a smoke test for a multivariate RV""" + rng = np.random.default_rng(123) + + x = pt.random.multivariate_normal( + mean=np.zeros((3, 2)), + cov=np.eye(2), + rng=shared(rng), + ) + + fn = function([], x, mode="NUMBA") + np.testing.assert_array_equal( + fn(), + rng.multivariate_normal(np.zeros(2), np.eye(2), size=(3,)), + ) + + @pytest.mark.parametrize( "rv_op, dist_args, size", [ ( - ptr.normal, + ptr.uniform, [ - set_test_value( - pt.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), set_test_value( pt.dscalar(), np.array(1.0, dtype=np.float64), ), - ], - pt.as_tensor([3, 2]), - ), - ( - ptr.uniform, - [ set_test_value( pt.dvector(), np.array([1.0, 2.0], dtype=np.float64), ), - set_test_value( - pt.dscalar(), - np.array(1.0, dtype=np.float64), - ), ], pt.as_tensor([3, 2]), ), @@ -144,7 +249,7 @@ ], pt.as_tensor([3, 2]), ), - ( + pytest.param( ptr.hypergeometric, [ set_test_value( @@ -161,6 +266,7 @@ ), ], pt.as_tensor([3, 2]), + marks=pytest.mark.xfail, # Not implemented ), ( ptr.wald, @@ -252,57 +358,6 @@ ], None, ), - ( - ptr.randint, - [ - set_test_value( - pt.lscalar(), - np.array(0, dtype=np.int64), - ), - set_test_value( - pt.lscalar(), - np.array(5, dtype=np.int64), - ), - ], - pt.as_tensor([3, 2]), - ), - pytest.param( - ptr.multivariate_normal, - [ - set_test_value( - pt.dmatrix(), - np.array([[1, 2], [3, 4]], dtype=np.float64), - ), - set_test_value( - pt.tensor(dtype="float64", shape=(1, None, None)), - np.eye(2)[None, ...], - ), - ], - pt.as_tensor(tuple(set_test_value(pt.lscalar(), v) for v in [4, 3, 2])), - marks=pytest.mark.xfail(reason="Not implemented"), - ), - ], - ids=str, -) -def test_aligned_RandomVariable(rv_op, dist_args, size): - """Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers.""" - rng = shared(np.random.RandomState(29402)) - g = rv_op(*dist_args, size=size, rng=rng) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], - ) - - -@pytest.mark.parametrize( - "rv_op, dist_args, base_size, cdf_name, params_conv", - [ ( ptr.beta, [ @@ -316,8 +371,6 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ), ], (2,), - "beta", - lambda *args: args, ), ( ptr._gamma, @@ -332,43 +385,37 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ), ], (2,), - "gamma", - lambda a, b: (a, 0.0, b), ), ( - ptr.cauchy, + ptr.chisquare, [ set_test_value( pt.dvector(), 
np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - pt.dscalar(), - np.array(1.0, dtype=np.float64), - ), + ) ], (2,), - "cauchy", - lambda *args: args, ), ( - ptr.chisquare, + ptr.negative_binomial, [ set_test_value( - pt.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ) + pt.lvector(), + np.array([100, 200], dtype=np.int64), + ), + set_test_value( + pt.dscalar(), + np.array(0.09, dtype=np.float64), + ), ], (2,), - "chi2", - lambda *args: args, ), ( - ptr.gumbel, + ptr.vonmises, [ set_test_value( pt.dvector(), - np.array([1.0, 2.0], dtype=np.float64), + np.array([-0.5, 0.5], dtype=np.float64), ), set_test_value( pt.dscalar(), @@ -376,31 +423,52 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ), ], (2,), - "gumbel_r", - lambda *args: args, ), + ], + ids=str, +) +def test_aligned_RandomVariable(rv_op, dist_args, size): + """Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers.""" + rng = shared(np.random.default_rng(29402)) + g = rv_op(*dist_args, size=size, rng=rng) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, SharedVariable | Constant) + ], + eval_obj_mode=False, # No python impl + ) + + +@pytest.mark.parametrize( + "rv_op, dist_args, base_size, cdf_name, params_conv", + [ ( - ptr.negative_binomial, + ptr.cauchy, [ set_test_value( - pt.lvector(), - np.array([100, 200], dtype=np.int64), + pt.dvector(), + np.array([1.0, 2.0], dtype=np.float64), ), set_test_value( pt.dscalar(), - np.array(0.09, dtype=np.float64), + np.array(1.0, dtype=np.float64), ), ], (2,), - "nbinom", + "cauchy", lambda *args: args, ), - pytest.param( - ptr.vonmises, + ( + ptr.gumbel, [ set_test_value( pt.dvector(), - np.array([-0.5, 0.5], dtype=np.float64), + np.array([1.0, 2.0], dtype=np.float64), ), set_test_value( pt.dscalar(), @@ -408,20 +476,14 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ), ], (2,), - "vonmises_line", - lambda mu, kappa: (kappa, mu), - marks=pytest.mark.xfail( - reason=( - "Numba's parameterization of `vonmises` does not match NumPy's." 
- "See https://github.com/numba/numba/issues/7886" - ) - ), + "gumbel_r", + lambda *args: args, ), ], ) def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv): """Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers.""" - rng = shared(np.random.RandomState(29402)) + rng = shared(np.random.default_rng(29402)) g = rv_op(*dist_args, size=(2000, *base_size), rng=rng) g_fn = function(dist_args, g, mode=numba_mode) samples = g_fn( @@ -442,78 +504,6 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ assert test_res.pvalue > 0.1 -@pytest.mark.parametrize( - "dist_args, size, cm", - [ - pytest.param( - [ - set_test_value( - pt.dvector(), - np.array([100000, 1, 1], dtype=np.float64), - ), - ], - None, - contextlib.suppress(), - ), - pytest.param( - [ - set_test_value( - pt.dmatrix(), - np.array( - [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], - dtype=np.float64, - ), - ), - ], - (10, 3), - contextlib.suppress(), - ), - pytest.param( - [ - set_test_value( - pt.dmatrix(), - np.array( - [[100000, 1, 1]], - dtype=np.float64, - ), - ), - ], - (5, 4, 3), - contextlib.suppress(), - ), - pytest.param( - [ - set_test_value( - pt.dmatrix(), - np.array( - [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], - dtype=np.float64, - ), - ), - ], - (10, 4), - pytest.raises( - ValueError, match="objects cannot be broadcast to a single shape" - ), - ), - ], -) -def test_CategoricalRV(dist_args, size, cm): - rng = shared(np.random.RandomState(29402)) - g = ptr.categorical(*dist_args, size=size, rng=rng) - g_fg = FunctionGraph(outputs=[g]) - - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], - ) - - @pytest.mark.parametrize( "a, size, cm", [ @@ -550,7 +540,7 @@ def test_CategoricalRV(dist_args, size, cm): ], ) def test_DirichletRV(a, size, cm): - rng = shared(np.random.RandomState(29402)) + rng = shared(np.random.default_rng(29402)) g = ptr.dirichlet(a, size=size, rng=rng) g_fn = function([a], g, mode=numba_mode) @@ -568,30 +558,3 @@ def test_DirichletRV(a, size, cm): exp_res = a_val / a_val.sum(-1) res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1))) assert np.allclose(res, exp_res, atol=1e-4) - - -def test_RandomState_updates(): - rng = shared(np.random.RandomState(1)) - rng_new = shared(np.random.RandomState(2)) - - x = pt.random.normal(size=10, rng=rng) - res = function([], x, updates={rng: rng_new}, mode=numba_mode)() - - ref = np.random.RandomState(2).normal(size=10) - assert np.allclose(res, ref) - - -def test_random_Generator(): - rng = shared(np.random.default_rng(29402)) - g = ptr.normal(rng=rng) - g_fg = FunctionGraph(outputs=[g]) - - with pytest.raises(TypeError): - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, SharedVariable | Constant) - ], - )
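
Usage sketch (illustrative only, inferred from the new tests above; the variable names
and seed values are arbitrary): with `RandomState` support dropped, a graph is seeded
with a `np.random.Generator` shared variable, and the advanced rng is threaded back
through an explicit update when compiling with the Numba backend.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    # Seed the graph with a NumPy Generator; passing a RandomState now raises a TypeError.
    rng = pytensor.shared(np.random.default_rng(123))
    next_rng, x = pt.random.normal(0, 1, size=(3,), rng=rng).owner.outputs

    # Compile through the Numba backend and advance the shared rng on every call.
    fn = pytensor.function([], x, updates={rng: next_rng}, mode="NUMBA")
    draws = fn()  # three draws; calling fn() again yields new values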