Skip to content

Commit

Permalink
Remove RandomState type in remaining backends
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 24, 2024
1 parent 58899d1 commit c597d89
Show file tree
Hide file tree
Showing 15 changed files with 50 additions and 392 deletions.
12 changes: 1 addition & 11 deletions pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax
import numpy as np
from numpy.random import Generator, RandomState
from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
_coerce_to_uint32_array,
)
Expand Down Expand Up @@ -54,15 +54,6 @@ def assert_size_argument_jax_compatible(node):
raise NotImplementedError(SIZE_NOT_COMPATIBLE)


@jax_typify.register(RandomState)
def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state


@jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__()
Expand Down Expand Up @@ -214,7 +205,6 @@ def sample_fn(rng, size, dtype, p):
return sample_fn


@jax_sample_fn.register(ptr.RandIntRV)
@jax_sample_fn.register(ptr.IntegersRV)
@jax_sample_fn.register(ptr.UniformRV)
def jax_sample_fn_uniform(op, node):
Expand Down
4 changes: 0 additions & 4 deletions pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from pytensor.tensor import get_vector_length
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.random.type import RandomStateType
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature

Expand Down Expand Up @@ -348,9 +347,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs

[rv_node] = op.fgraph.apply_nodes
rv_op: RandomVariable = rv_node.op
rng_param = rv_op.rng_param(rv_node)
if isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `RandomStateType`s")
size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
import pytensor.tensor.random.rewriting
import pytensor.tensor.random.utils
from pytensor.tensor.random.basic import *
from pytensor.tensor.random.op import RandomState, default_rng
from pytensor.tensor.random.op import default_rng
from pytensor.tensor.random.utils import RandomStream
68 changes: 1 addition & 67 deletions pytensor/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,10 @@
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import sqrt
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import (
broadcast_params,
normalize_size_param,
)
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)


try:
Expand Down Expand Up @@ -645,7 +640,7 @@ def __call__(
@classmethod
def rng_fn_scipy(
cls,
rng: np.random.Generator | np.random.RandomState,
rng: np.random.Generator,
loc: np.ndarray | float,
scale: np.ndarray | float,
size: list[int] | int | None,
Expand Down Expand Up @@ -1880,58 +1875,6 @@ def rng_fn(cls, rng, p, size):
categorical = CategoricalRV()


class RandIntRV(RandomVariable):
r"""A discrete uniform random variable.
Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s.
"""

name = "randint"
signature = "(),()->()"
dtype = "int64"
_print_name = ("randint", "\\operatorname{randint}")

def __call__(self, low, high=None, size=None, **kwargs):
r"""Draw samples from a discrete uniform distribution.
Signature
---------
`() -> ()`
Parameters
----------
low
Lower boundary of the output interval. All values generated will
be greater than or equal to `low`, unless `high=None`, in which case
all values generated are greater than or equal to `0` and
smaller than `low` (exclusive).
high
Upper boundary of the output interval. All values generated
will be smaller than `high` (exclusive).
size
Sample shape. If the given size is `(m, n, k)`, then `m * n * k`
independent, identically distributed samples are
returned. Default is `None`, in which case a single
sample is returned.
"""
if high is None:
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)

def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None), RandomStateType | RandomStateSharedVariable
):
raise TypeError("`randint` is only available for `RandomStateType`s")
return super().make_node(rng, *args, **kwargs)


randint = RandIntRV()


class IntegersRV(RandomVariable):
r"""A discrete uniform random variable.
Expand Down Expand Up @@ -1971,14 +1914,6 @@ def __call__(self, low, high=None, size=None, **kwargs):
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)

def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None),
RandomGeneratorType | RandomGeneratorSharedVariable,
):
raise TypeError("`integers` is only available for `RandomGeneratorType`s")
return super().make_node(rng, *args, **kwargs)


integers = IntegersRV()

Expand Down Expand Up @@ -2201,7 +2136,6 @@ def permutation(x, **kwargs):
"permutation",
"choice",
"integers",
"randint",
"categorical",
"multinomial",
"betabinom",
Expand Down
17 changes: 4 additions & 13 deletions pytensor/tensor/random/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
infer_static_shape,
)
from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import (
compute_batch_shape,
explicit_expand_dims,
Expand Down Expand Up @@ -324,9 +324,8 @@ def make_node(self, rng, size, *dist_params):
Parameters
----------
rng: RandomGeneratorType or RandomStateType
Existing PyTensor `Generator` or `RandomState` object to be used. Creates a
new one, if `None`.
rng: RandomGeneratorType
Existing PyTensor `Generator` object to be used. Creates a new one, if `None`.
size: int or Sequence
NumPy-like size parameter.
dtype: str
Expand Down Expand Up @@ -354,7 +353,7 @@ def make_node(self, rng, size, *dist_params):
rng = pytensor.shared(np.random.default_rng())
elif not isinstance(rng.type, RandomType):
raise TypeError(
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
"The type of rng should be an instance of RandomGeneratorType "
)

inferred_shape = self._infer_shape(size, dist_params)
Expand Down Expand Up @@ -436,14 +435,6 @@ def perform(self, node, inputs, output_storage):
output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed)


class RandomStateConstructor(AbstractRNGConstructor):
random_type = RandomStateType()
random_constructor = "RandomState"


RandomState = RandomStateConstructor()


class DefaultGeneratorMakerOp(AbstractRNGConstructor):
random_type = RandomGeneratorType()
random_constructor = "default_rng"
Expand Down
91 changes: 0 additions & 91 deletions pytensor/tensor/random/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,97 +31,6 @@ def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]


class RandomStateType(RandomType[np.random.RandomState]):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""

def __repr__(self):
return "RandomStateType"

def filter(self, data, strict: bool = False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomStateType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.RandomState):
return data

if not strict and isinstance(data, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]

for key in gen_keys:
if key not in data:
raise TypeError()

for key in state_keys:
if key not in data["state"]:
raise TypeError()

state_key = data["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
# TODO: Add an option to convert to a `RandomState` instance?
return data

raise TypeError()

@staticmethod
def values_eq(a, b):
sa = a if isinstance(a, dict) else a.get_state(legacy=False)
sb = b if isinstance(b, dict) else b.get_state(legacy=False)

def _eq(sa, sb):
for key in sa:
if isinstance(sa[key], dict):
if not _eq(sa[key], sb[key]):
return False
elif isinstance(sa[key], np.ndarray):
if not np.array_equal(sa[key], sb[key]):
return False
else:
if sa[key] != sb[key]:
return False

return True

return _eq(sa, sb)

def __eq__(self, other):
return type(self) == type(other)

def __hash__(self):
return hash(type(self))


# Register `RandomStateType`'s C code for `ViewOp`.
pytensor.compile.register_view_op_c_code(
RandomStateType,
"""
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""",
1,
)

random_state_type = RandomStateType()


class RandomGeneratorType(RandomType[np.random.Generator]):
r"""A Type wrapper for `numpy.random.Generator`.
Expand Down
10 changes: 1 addition & 9 deletions pytensor/tensor/random/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,7 @@ def __init__(
self,
seed: int | None = None,
namespace: ModuleType | None = None,
rng_ctor: Literal[
np.random.RandomState, np.random.Generator
] = np.random.default_rng,
rng_ctor: Literal[np.random.Generator] = np.random.default_rng,
):
if namespace is None:
from pytensor.tensor.random import basic # pylint: disable=import-self
Expand All @@ -223,12 +221,6 @@ def __init__(
self.default_instance_seed = seed
self.state_updates = []
self.gen_seedgen = np.random.SeedSequence(seed)

if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState):
# The legacy state does not accept `SeedSequence`s directly
def rng_ctor(seed):
return np.random.RandomState(np.random.MT19937(seed))

self.rng_ctor = rng_ctor

def __getattr__(self, obj):
Expand Down
20 changes: 8 additions & 12 deletions pytensor/tensor/random/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,12 @@
import numpy as np

from pytensor.compile.sharedvalue import SharedVariable, shared_constructor
from pytensor.tensor.random.type import random_generator_type, random_state_type


class RandomStateSharedVariable(SharedVariable):
def __str__(self):
return self.name or f"RandomStateSharedVariable({self.container!r})"
from pytensor.tensor.random.type import random_generator_type


class RandomGeneratorSharedVariable(SharedVariable):
def __str__(self):
return self.name or f"RandomGeneratorSharedVariable({self.container!r})"
return self.name or f"RNG({self.container!r})"


@shared_constructor.register(np.random.RandomState)
Expand All @@ -23,11 +18,12 @@ def randomgen_constructor(
):
r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
if isinstance(value, np.random.RandomState):
rng_sv_type = RandomStateSharedVariable
rng_type = random_state_type
elif isinstance(value, np.random.Generator):
rng_sv_type = RandomGeneratorSharedVariable
rng_type = random_generator_type
raise TypeError(
"`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead."
)

rng_sv_type = RandomGeneratorSharedVariable
rng_type = random_generator_type

if not borrow:
value = copy.deepcopy(value)
Expand Down
Loading

0 comments on commit c597d89

Please sign in to comment.