From 0009ff321240de0d8df883ad681f3c78a16a32dd Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 14 Dec 2022 08:12:19 +0100 Subject: [PATCH] Remove global RandomStream --- pymc/data.py | 10 ++++--- pymc/distributions/simulator.py | 2 +- pymc/pytensorf.py | 46 --------------------------------- pymc/sampling/parallel.py | 2 -- pymc/tests/conftest.py | 3 --- pymc/tests/helpers.py | 8 +----- pymc/tests/test_data.py | 2 +- pymc/variational/inference.py | 8 ------ 8 files changed, 10 insertions(+), 71 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index 35a99607f4d..85d86f087bc 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -29,7 +29,6 @@ from pytensor.raise_op import Assert from pytensor.scalar import Cast from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.random import RandomStream from pytensor.tensor.random.basic import IntegersRV from pytensor.tensor.subtensor import AdvancedSubtensor from pytensor.tensor.type import TensorType @@ -132,6 +131,12 @@ def __hash__(self): class MinibatchIndexRV(IntegersRV): _print_name = ("minibatch_index", r"\operatorname{minibatch\_index}") + # Work-around for https://github.com/pymc-devs/pytensor/issues/97 + def make_node(self, rng, *args, **kwargs): + if rng is None: + rng = pytensor.shared(np.random.default_rng()) + return super().make_node(rng, *args, **kwargs) + minibatch_index = MinibatchIndexRV() @@ -184,10 +189,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: >>> mdata1, mdata2 = Minibatch(data1, data2, batch_size=10) """ - rng = RandomStream() tensor, *tensors = tuple(map(at.as_tensor, (variable, *variables))) upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)]) - slc = rng.gen(minibatch_index, 0, upper, size=batch_size) + slc = minibatch_index(0, upper, size=batch_size) for i, v in enumerate((tensor, *tensors)): if not valid_for_minibatch(v): raise ValueError( diff --git a/pymc/distributions/simulator.py b/pymc/distributions/simulator.py index a15bddce0fd..faeb21a0456 100644 --- a/pymc/distributions/simulator.py +++ b/pymc/distributions/simulator.py @@ -76,7 +76,7 @@ class Simulator(Distribution): ---------- fn : callable Python random simulator function. Should expect the following signature - ``(rng, arg1, arg2, ... argn, size)``, where rng is a ``numpy.random.RandomStream()`` + ``(rng, arg1, arg2, ... argn, size)``, where rng is a ``numpy.random.Generator`` and ``size`` defines the size of the desired sample. *unnamed_params : list of TensorVariable Parameters used by the Simulator random function. Each parameter can be passed diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 211c4a99a68..cbe7288d086 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -50,7 +50,6 @@ from pytensor.scalar.basic import Cast from pytensor.tensor.basic import _as_tensor_variable from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.random import RandomStream from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.var import ( RandomGeneratorSharedVariable, @@ -84,8 +83,6 @@ "join_nonshared_inputs", "make_shared_replacements", "generator", - "set_at_rng", - "at_rng", "convert_observed_data", "compile_pymc", "constant_fold", @@ -891,49 +888,6 @@ def generator(gen, default=None): return GeneratorOp(gen, default)() -_at_rng = RandomStream() - - -def at_rng(random_seed=None): - """ - Get the package-level random number generator or new with specified seed. - - Parameters - ---------- - random_seed: int - If not None - returns *new* pytensor random generator without replacing package global one - - Returns - ------- - `pytensor.tensor.random.utils.RandomStream` instance - `pytensor.tensor.random.utils.RandomStream` - instance passed to the most recent call of `set_at_rng` - """ - if random_seed is None: - return _at_rng - else: - ret = RandomStream(random_seed) - return ret - - -def set_at_rng(new_rng): - """ - Set the package-level random number generator. - - Parameters - ---------- - new_rng: `pytensor.tensor.random.utils.RandomStream` instance - The random number generator to use. - """ - # pylint: disable=global-statement - global _at_rng - # pylint: enable=global-statement - if isinstance(new_rng, int): - new_rng = RandomStream(new_rng) - _at_rng = new_rng - - def floatX_array(x): return floatX(np.array(x)) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 16ff777f1d4..49f586aa72c 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -28,7 +28,6 @@ from fastprogress.fastprogress import progress_bar -from pymc import pytensorf from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.util import RandomSeed @@ -155,7 +154,6 @@ def _recv_msg(self): def _start_loop(self): np.random.seed(self._seed) - pytensorf.set_at_rng(self._at_seed) draw = 0 tuning = True diff --git a/pymc/tests/conftest.py b/pymc/tests/conftest.py index 6bc987af21a..cff8b37b405 100644 --- a/pymc/tests/conftest.py +++ b/pymc/tests/conftest.py @@ -16,8 +16,6 @@ import pytensor import pytest -import pymc as pm - @pytest.fixture(scope="function", autouse=True) def pytensor_config(): @@ -47,4 +45,3 @@ def strict_float32(): def seeded_test(): # TODO: use this instead of SeededTest np.random.seed(42) - pm.set_at_rng(42) diff --git a/pymc/tests/helpers.py b/pymc/tests/helpers.py index ab0d7c7a4bc..663827f836c 100644 --- a/pymc/tests/helpers.py +++ b/pymc/tests/helpers.py @@ -26,12 +26,11 @@ from pytensor.gradient import verify_grad as at_verify_grad from pytensor.graph import ancestors from pytensor.graph.rewriting.basic import in2out -from pytensor.tensor.random import RandomStream from pytensor.tensor.random.op import RandomVariable import pymc as pm -from pymc.pytensorf import at_rng, local_check_parameter_to_ninf_switch, set_at_rng +from pymc.pytensorf import local_check_parameter_to_ninf_switch from pymc.tests.checks import close_to from pymc.tests.models import mv_simple, mv_simple_coarse @@ -46,11 +45,6 @@ def setup_class(cls): def setup_method(self): nr.seed(self.random_seed) - self.old_at_rng = at_rng() - set_at_rng(RandomStream(self.random_seed)) - - def teardown_method(self): - set_at_rng(self.old_at_rng) def get_random_state(self, reset=False): if self.random_state is None or reset: diff --git a/pymc/tests/test_data.py b/pymc/tests/test_data.py index 538e1e1baa7..58b17609c68 100644 --- a/pymc/tests/test_data.py +++ b/pymc/tests/test_data.py @@ -553,7 +553,7 @@ def test_pickling(self, datagen): def test_gen_cloning_with_shape_change(self, datagen): gen = pm.generator(datagen) - gen_r = pm.at_rng().normal(size=gen.shape).T + gen_r = at.random.normal(size=gen.shape).T X = gen.dot(gen_r) res, _ = pytensor.scan(lambda x: x.sum(), X, n_steps=X.shape[0]) assert res.eval().shape == (50,) diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 90bee1623c0..514d05228ab 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -433,8 +433,6 @@ class ADVI(KLqp): model: :class:`pymc.Model` PyMC model for inference random_seed: None or int - leave None to use package global RandomStream or other - valid value to create instance specific one start: `dict[str, np.ndarray]` or `StartDict` starting point for inference start_sigma: `dict[str, np.ndarray]` @@ -466,8 +464,6 @@ class FullRankADVI(KLqp): model: :class:`pymc.Model` PyMC model for inference random_seed: None or int - leave None to use package global RandomStream or other - valid value to create instance specific one start: `dict[str, np.ndarray]` or `StartDict` starting point for inference @@ -539,8 +535,6 @@ class SVGD(ImplicitGradient): start: `dict[str, np.ndarray]` or `StartDict` initial point for inference random_seed: None or int - leave None to use package global RandomStream or other - valid value to create instance specific one kwargs: other keyword arguments passed to estimator References @@ -685,8 +679,6 @@ def fit( model: :class:`Model` PyMC model for inference random_seed: None or int - leave None to use package global RandomStream or other - valid value to create instance specific one inf_kwargs: dict additional kwargs passed to :class:`Inference` start: `dict[str, np.ndarray]` or `StartDict`