Remove global RandomStream #6396

Merged · 1 commit · Dec 14, 2022
10 changes: 7 additions & 3 deletions pymc/data.py
@@ -29,7 +29,6 @@
from pytensor.raise_op import Assert
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.random import RandomStream
from pytensor.tensor.random.basic import IntegersRV
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import TensorType
@@ -132,6 +131,12 @@ def __hash__(self):
class MinibatchIndexRV(IntegersRV):
_print_name = ("minibatch_index", r"\operatorname{minibatch\_index}")

+# Work-around for https://github.com/pymc-devs/pytensor/issues/97
+def make_node(self, rng, *args, **kwargs):
+    if rng is None:
+        rng = pytensor.shared(np.random.default_rng())
+    return super().make_node(rng, *args, **kwargs)
Comment on lines +135 to +138

Member (author): Can be removed once pymc-devs/pytensor#97 is solved

Member: sure thing

minibatch_index = MinibatchIndexRV()

@@ -184,10 +189,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
>>> mdata1, mdata2 = Minibatch(data1, data2, batch_size=10)
"""

-rng = RandomStream()
tensor, *tensors = tuple(map(at.as_tensor, (variable, *variables)))
upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)])
-slc = rng.gen(minibatch_index, 0, upper, size=batch_size)
+slc = minibatch_index(0, upper, size=batch_size)
for i, v in enumerate((tensor, *tensors)):
if not valid_for_minibatch(v):
raise ValueError(
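A minimal usage sketch of the new code path (not from the PR; the arrays and variable names are illustrative): with this change, `Minibatch` no longer draws from a package-global `RandomStream`; the `minibatch_index` op creates its own shared `numpy.random.Generator` when no `rng` is supplied.

```python
import numpy as np
import pymc as pm

# Two aligned datasets; Minibatch draws one common random index into both.
data1 = np.random.rand(100, 3)
data2 = np.random.rand(100, 2)

# The slice is produced by `minibatch_index`, which seeds itself from a fresh
# shared `np.random.default_rng()` instead of a global RandomStream.
mdata1, mdata2 = pm.Minibatch(data1, data2, batch_size=10)
print(mdata1.eval().shape)  # expected (10, 3)
print(mdata2.eval().shape)  # expected (10, 2)
```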
2 changes: 1 addition & 1 deletion pymc/distributions/simulator.py
@@ -76,7 +76,7 @@ class Simulator(Distribution):
----------
fn : callable
Python random simulator function. Should expect the following signature
-``(rng, arg1, arg2, ... argn, size)``, where rng is a ``numpy.random.RandomStream()``
+``(rng, arg1, arg2, ... argn, size)``, where rng is a ``numpy.random.Generator``
and ``size`` defines the size of the desired sample.
*unnamed_params : list of TensorVariable
Parameters used by the Simulator random function. Each parameter can be passed
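To illustrate the corrected docstring (the model and function names below are made up, not part of the PR): a simulator function receives a `numpy.random.Generator` as its first argument.

```python
import numpy as np
import pymc as pm

def normal_simulator(rng: np.random.Generator, mu, sigma, size=None):
    # `rng` is a numpy.random.Generator supplied by PyMC, not a RandomStream.
    return rng.normal(mu, sigma, size=size)

observed = np.random.default_rng(42).normal(0.0, 1.0, size=1000)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    # Parameters are passed positionally as *unnamed_params.
    pm.Simulator("y", normal_simulator, mu, sigma, epsilon=0.1, observed=observed)
```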
46 changes: 0 additions & 46 deletions pymc/pytensorf.py
@@ -50,7 +50,6 @@
from pytensor.scalar.basic import Cast
from pytensor.tensor.basic import _as_tensor_variable
from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.random import RandomStream
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
@@ -84,8 +83,6 @@
"join_nonshared_inputs",
"make_shared_replacements",
"generator",
-"set_at_rng",
-"at_rng",
"convert_observed_data",
"compile_pymc",
"constant_fold",
@@ -891,49 +888,6 @@ def generator(gen, default=None):
return GeneratorOp(gen, default)()


-_at_rng = RandomStream()
-
-
-def at_rng(random_seed=None):
-    """
-    Get the package-level random number generator or new with specified seed.
-
-    Parameters
-    ----------
-    random_seed: int
-        If not None
-        returns *new* pytensor random generator without replacing package global one
-
-    Returns
-    -------
-    `pytensor.tensor.random.utils.RandomStream` instance
-        `pytensor.tensor.random.utils.RandomStream`
-        instance passed to the most recent call of `set_at_rng`
-    """
-    if random_seed is None:
-        return _at_rng
-    else:
-        ret = RandomStream(random_seed)
-        return ret
-
-
-def set_at_rng(new_rng):
-    """
-    Set the package-level random number generator.
-
-    Parameters
-    ----------
-    new_rng: `pytensor.tensor.random.utils.RandomStream` instance
-        The random number generator to use.
-    """
-    # pylint: disable=global-statement
-    global _at_rng
-    # pylint: enable=global-statement
-    if isinstance(new_rng, int):
-        new_rng = RandomStream(new_rng)
-    _at_rng = new_rng


def floatX_array(x):
return floatX(np.array(x))

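With `at_rng`/`set_at_rng` gone there is no package-level generator to set; code that wants an explicit graph-level RNG can create its own shared `Generator` and hand it to a random variable. A minimal sketch (not from the PR; names are illustrative):

```python
import numpy as np
import pytensor
import pytensor.tensor as at

# An explicit shared Generator replaces the removed package-global RandomStream.
rng = pytensor.shared(np.random.default_rng(42), name="rng")

# Pass it directly to the RandomVariable instead of drawing it from at_rng().
x = at.random.normal(0, 1, size=(3,), rng=rng)
print(x.eval())
```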
2 changes: 0 additions & 2 deletions pymc/sampling/parallel.py
@@ -28,7 +28,6 @@

from fastprogress.fastprogress import progress_bar

-from pymc import pytensorf
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.util import RandomSeed
@@ -155,7 +154,6 @@ def _recv_msg(self):

def _start_loop(self):
np.random.seed(self._seed)
-pytensorf.set_at_rng(self._at_seed)

draw = 0
tuning = True
3 changes: 0 additions & 3 deletions pymc/tests/conftest.py
@@ -16,8 +16,6 @@
import pytensor
import pytest

-import pymc as pm
-

@pytest.fixture(scope="function", autouse=True)
def pytensor_config():
@@ -47,4 +45,3 @@ def strict_float32():
def seeded_test():
# TODO: use this instead of SeededTest
np.random.seed(42)
-pm.set_at_rng(42)
8 changes: 1 addition & 7 deletions pymc/tests/helpers.py
@@ -26,12 +26,11 @@
from pytensor.gradient import verify_grad as at_verify_grad
from pytensor.graph import ancestors
from pytensor.graph.rewriting.basic import in2out
-from pytensor.tensor.random import RandomStream
from pytensor.tensor.random.op import RandomVariable

import pymc as pm

-from pymc.pytensorf import at_rng, local_check_parameter_to_ninf_switch, set_at_rng
+from pymc.pytensorf import local_check_parameter_to_ninf_switch
from pymc.tests.checks import close_to
from pymc.tests.models import mv_simple, mv_simple_coarse

@@ -46,11 +45,6 @@ def setup_class(cls):

def setup_method(self):
nr.seed(self.random_seed)
-self.old_at_rng = at_rng()
-set_at_rng(RandomStream(self.random_seed))
-
-def teardown_method(self):
-    set_at_rng(self.old_at_rng)

def get_random_state(self, reset=False):
if self.random_state is None or reset:
2 changes: 1 addition & 1 deletion pymc/tests/test_data.py
@@ -553,7 +553,7 @@ def test_pickling(self, datagen):

def test_gen_cloning_with_shape_change(self, datagen):
gen = pm.generator(datagen)
-gen_r = pm.at_rng().normal(size=gen.shape).T
+gen_r = at.random.normal(size=gen.shape).T
X = gen.dot(gen_r)
res, _ = pytensor.scan(lambda x: x.sum(), X, n_steps=X.shape[0])
assert res.eval().shape == (50,)
8 changes: 0 additions & 8 deletions pymc/variational/inference.py
@@ -433,8 +433,6 @@ class ADVI(KLqp):
model: :class:`pymc.Model`
PyMC model for inference
random_seed: None or int
-leave None to use package global RandomStream or other
-valid value to create instance specific one
start: `dict[str, np.ndarray]` or `StartDict`
starting point for inference
start_sigma: `dict[str, np.ndarray]`
@@ -466,8 +464,6 @@ class FullRankADVI(KLqp):
model: :class:`pymc.Model`
PyMC model for inference
random_seed: None or int
-leave None to use package global RandomStream or other
-valid value to create instance specific one
start: `dict[str, np.ndarray]` or `StartDict`
starting point for inference

@@ -539,8 +535,6 @@ class SVGD(ImplicitGradient):
start: `dict[str, np.ndarray]` or `StartDict`
initial point for inference
random_seed: None or int
-leave None to use package global RandomStream or other
-valid value to create instance specific one
kwargs: other keyword arguments passed to estimator

References
@@ -685,8 +679,6 @@ def fit(
model: :class:`Model`
PyMC model for inference
random_seed: None or int
-leave None to use package global RandomStream or other
-valid value to create instance specific one
inf_kwargs: dict
additional kwargs passed to :class:`Inference`
start: `dict[str, np.ndarray]` or `StartDict`
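With the "package global RandomStream" wording removed from these docstrings, reproducibility in variational inference comes from the explicit `random_seed` argument alone. A hedged usage sketch (the model below is made up, not part of the PR):

```python
import numpy as np
import pymc as pm

data = np.random.default_rng(0).normal(0.0, 1.0, size=200)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    pm.Normal("obs", mu, 1.0, observed=data)
    # There is no global stream to set; the seed is passed explicitly.
    approx = pm.fit(n=5000, method="advi", random_seed=42)
```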