Skip to content

Commit

Permalink
Remove RandomState type
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 9, 2024
1 parent 4629abd commit d46f085
Show file tree
Hide file tree
Showing 15 changed files with 76 additions and 415 deletions.
12 changes: 1 addition & 11 deletions pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax
import numpy as np
from numpy.random import Generator, RandomState
from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
_coerce_to_uint32_array,
)
Expand Down Expand Up @@ -52,15 +52,6 @@ def assert_size_argument_jax_compatible(node):
raise NotImplementedError(SIZE_NOT_COMPATIBLE)


@jax_typify.register(RandomState)
def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state


@jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__()
Expand Down Expand Up @@ -184,7 +175,6 @@ def sample_fn(rng, size, dtype, *parameters):
return sample_fn


@jax_sample_fn.register(ptr.RandIntRV)
@jax_sample_fn.register(ptr.IntegersRV)
@jax_sample_fn.register(ptr.UniformRV)
def jax_sample_fn_uniform(op):
Expand Down
4 changes: 0 additions & 4 deletions pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
)
from pytensor.tensor import get_vector_length
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.random.type import RandomStateType
from pytensor.tensor.type_other import NoneTypeT


Expand Down Expand Up @@ -265,9 +264,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs

[rv_node] = op.fgraph.apply_nodes
rv_op: RandomVariable = rv_node.op
rng_param = rv_op.rng_param(rv_node)
if isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `RandomStateType`s")
size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
import pytensor.tensor.random.rewriting
import pytensor.tensor.random.utils
from pytensor.tensor.random.basic import *
from pytensor.tensor.random.op import RandomState, default_rng
from pytensor.tensor.random.op import default_rng
from pytensor.tensor.random.utils import RandomStream
100 changes: 25 additions & 75 deletions pytensor/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,9 @@
import pytensor
from pytensor.tensor.basic import arange, as_tensor_variable, constant
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import (
broadcast_params,
)
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)


try:
Expand Down Expand Up @@ -605,7 +600,7 @@ def __call__(
@classmethod
def rng_fn_scipy(
cls,
rng: np.random.Generator | np.random.RandomState,
rng: np.random.Generator,
loc: np.ndarray | float,
scale: np.ndarray | float,
size: list[int] | int | None,
Expand Down Expand Up @@ -1548,7 +1543,7 @@ def __call__(self, n, p, size=None, **kwargs):
binomial = BinomialRV()


class NegBinomialRV(ScipyRandomVariable):
class NegBinomialRV(RandomVariable):
r"""A negative binomial discrete random variable.
The probability mass function for `nbinom` for the number :math:`k` of draws
Expand Down Expand Up @@ -1588,13 +1583,8 @@ def __call__(self, n, p, size=None, **kwargs):
"""
return super().__call__(n, p, size=size, **kwargs)

@classmethod
def rng_fn_scipy(cls, rng, n, p, size):
return stats.nbinom.rvs(n, p, size=size, random_state=rng)


nbinom = NegBinomialRV()
negative_binomial = NegBinomialRV()
negative_binomial = nbinom = NegBinomialRV()


class BetaBinomialRV(ScipyRandomVariable):
Expand Down Expand Up @@ -1842,58 +1832,6 @@ def rng_fn(cls, rng, p, size):
categorical = CategoricalRV()


class RandIntRV(RandomVariable):
r"""A discrete uniform random variable.
Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s.
"""

name = "randint"
signature = "(),()->()"
dtype = "int64"
_print_name = ("randint", "\\operatorname{randint}")

def __call__(self, low, high=None, size=None, **kwargs):
r"""Draw samples from a discrete uniform distribution.
Signature
---------
`() -> ()`
Parameters
----------
low
Lower boundary of the output interval. All values generated will
be greater than or equal to `low`, unless `high=None`, in which case
all values generated are greater than or equal to `0` and
smaller than `low` (exclusive).
high
Upper boundary of the output interval. All values generated
will be smaller than `high` (exclusive).
size
Sample shape. If the given size is `(m, n, k)`, then `m * n * k`
independent, identically distributed samples are
returned. Default is `None`, in which case a single
sample is returned.
"""
if high is None:
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)

def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None), RandomStateType | RandomStateSharedVariable
):
raise TypeError("`randint` is only available for `RandomStateType`s")
return super().make_node(rng, *args, **kwargs)


randint = RandIntRV()


class IntegersRV(RandomVariable):
r"""A discrete uniform random variable.
Expand Down Expand Up @@ -1933,14 +1871,6 @@ def __call__(self, low, high=None, size=None, **kwargs):
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)

def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None),
RandomGeneratorType | RandomGeneratorSharedVariable,
):
raise TypeError("`integers` is only available for `RandomGeneratorType`s")
return super().make_node(rng, *args, **kwargs)


integers = IntegersRV()

Expand Down Expand Up @@ -1974,7 +1904,28 @@ def rng_fn(self, *params):
p = None
else:
rng, a, p, replace, size = params
return rng.choice(a, size, replace, p)

batch_ndim = a.ndim - self.ndims_params[0]

if size is not None:
a = np.broadcast_to(a, size + a.shape[-self.ndims_params[0] :])
if p is not None:
p = np.broadcast_to(p, size + p.shape[-1:])
elif p is not None:
a, p = broadcast_params([a, p], self.ndims_params)

if batch_ndim:
# rng.choice does not have a concept of batch dimensionn
batch_shape = a.shape[:batch_ndim]
core_shape = a.shape[batch_ndim:-1]
out = np.empty(batch_shape + core_shape, dtype=a.dtype)
for idx in np.ndindex(batch_shape):
out[idx] = rng.choice(
a[idx], size=None, replace=replace, p=None if p is None else p[idx]
)
return out
else:
return rng.choice(a, size=size, replace=replace, p=p)


def choice(a, size=None, replace=True, p=None, rng=None):
Expand Down Expand Up @@ -2079,7 +2030,6 @@ def permutation(x, **kwargs):
"permutation",
"choice",
"integers",
"randint",
"categorical",
"multinomial",
"betabinom",
Expand Down
17 changes: 4 additions & 13 deletions pytensor/tensor/random/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
infer_static_shape,
)
from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import (
compute_batch_shape,
explicit_expand_dims,
Expand Down Expand Up @@ -326,9 +326,8 @@ def make_node(self, rng, size, *dist_params):
Parameters
----------
rng: RandomGeneratorType or RandomStateType
Existing PyTensor `Generator` or `RandomState` object to be used. Creates a
new one, if `None`.
rng: RandomGeneratorType
Existing PyTensor `Generator` object to be used. Creates a new one, if `None`.
size: int or Sequence
NumPy-like size parameter.
dtype: str
Expand Down Expand Up @@ -356,7 +355,7 @@ def make_node(self, rng, size, *dist_params):
rng = pytensor.shared(np.random.default_rng())
elif not isinstance(rng.type, RandomType):
raise TypeError(
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
"The type of rng should be an instance of RandomGeneratorType "
)

inferred_shape = self._infer_shape(size, dist_params)
Expand Down Expand Up @@ -435,14 +434,6 @@ def perform(self, node, inputs, output_storage):
output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed)


class RandomStateConstructor(AbstractRNGConstructor):
random_type = RandomStateType()
random_constructor = "RandomState"


RandomState = RandomStateConstructor()


class DefaultGeneratorMakerOp(AbstractRNGConstructor):
random_type = RandomGeneratorType()
random_constructor = "default_rng"
Expand Down
91 changes: 0 additions & 91 deletions pytensor/tensor/random/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,97 +31,6 @@ def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]


class RandomStateType(RandomType[np.random.RandomState]):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""

def __repr__(self):
return "RandomStateType"

def filter(self, data, strict: bool = False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomStateType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.RandomState):
return data

if not strict and isinstance(data, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]

for key in gen_keys:
if key not in data:
raise TypeError()

for key in state_keys:
if key not in data["state"]:
raise TypeError()

state_key = data["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
# TODO: Add an option to convert to a `RandomState` instance?
return data

raise TypeError()

@staticmethod
def values_eq(a, b):
sa = a if isinstance(a, dict) else a.get_state(legacy=False)
sb = b if isinstance(b, dict) else b.get_state(legacy=False)

def _eq(sa, sb):
for key in sa:
if isinstance(sa[key], dict):
if not _eq(sa[key], sb[key]):
return False
elif isinstance(sa[key], np.ndarray):
if not np.array_equal(sa[key], sb[key]):
return False
else:
if sa[key] != sb[key]:
return False

return True

return _eq(sa, sb)

def __eq__(self, other):
return type(self) == type(other)

def __hash__(self):
return hash(type(self))


# Register `RandomStateType`'s C code for `ViewOp`.
pytensor.compile.register_view_op_c_code(
RandomStateType,
"""
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""",
1,
)

random_state_type = RandomStateType()


class RandomGeneratorType(RandomType[np.random.Generator]):
r"""A Type wrapper for `numpy.random.Generator`.
Expand Down
10 changes: 1 addition & 9 deletions pytensor/tensor/random/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,7 @@ def __init__(
self,
seed: int | None = None,
namespace: ModuleType | None = None,
rng_ctor: Literal[
np.random.RandomState, np.random.Generator
] = np.random.default_rng,
rng_ctor: Literal[np.random.Generator] = np.random.default_rng,
):
if namespace is None:
from pytensor.tensor.random import basic # pylint: disable=import-self
Expand All @@ -222,12 +220,6 @@ def __init__(
self.default_instance_seed = seed
self.state_updates = []
self.gen_seedgen = np.random.SeedSequence(seed)

if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState):
# The legacy state does not accept `SeedSequence`s directly
def rng_ctor(seed):
return np.random.RandomState(np.random.MT19937(seed))

self.rng_ctor = rng_ctor

def __getattr__(self, obj):
Expand Down
Loading

0 comments on commit d46f085

Please sign in to comment.