Skip to content

Commit

Permalink
Refactor RandomState/Generator reseeding logic in initial_point into …
Browse files Browse the repository at this point in the history
…aesaraf
  • Loading branch information
ricardoV94 committed May 19, 2022
1 parent 00d4372 commit 367afef
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
28 changes: 28 additions & 0 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,34 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
)


def find_rng_nodes(variables: Iterable[Optional[TensorVariable]]):
"""Return RNG variables in a graph"""
return [
node
for node in graph_inputs(variables)
if isinstance(
node,
(
at.random.var.RandomStateSharedVariable,
at.random.var.RandomGeneratorSharedVariable,
),
)
]


def reseed_rngs(rngs: Iterable[SharedVariable], seed: Optional[int]) -> None:
"""Create a new set of RandomState/Generator for each rng based on a seed"""
bit_generators = [
np.random.PCG64(sub_seed) for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
]
for rng, bit_generator in zip(rngs, bit_generators):
if isinstance(rng, at.random.var.RandomStateSharedVariable):
new_rng = np.random.RandomState(bit_generator)
else:
new_rng = np.random.Generator(bit_generator)
rng.set_value(new_rng, borrow=True)


def compile_pymc(
inputs, outputs, mode=None, **kwargs
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
Expand Down
28 changes: 3 additions & 25 deletions pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import aesara.tensor as at
import numpy as np

from aesara.graph.basic import Variable, graph_inputs
from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph
from aesara.tensor.var import TensorVariable

from pymc.aesaraf import compile_pymc
from pymc.aesaraf import compile_pymc, find_rng_nodes, reseed_rngs
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name

StartDict = Dict[Union[Variable, str], Union[np.ndarray, Variable, str]]
Expand Down Expand Up @@ -150,19 +150,6 @@ def make_initial_point_fn(
If `True` the returned variables will correspond to transformed initial values.
"""

def find_rng_nodes(variables):
return [
node
for node in graph_inputs(variables)
if isinstance(
node,
(
at.random.var.RandomStateSharedVariable,
at.random.var.RandomGeneratorSharedVariable,
),
)
]

sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
initval_strats = {
**model.initial_values,
Expand Down Expand Up @@ -208,16 +195,7 @@ def make_seeded_function(func):

@functools.wraps(func)
def inner(seed, *args, **kwargs):
seeds = [
np.random.PCG64(sub_seed)
for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
]
for rng, seed in zip(rngs, seeds):
if isinstance(rng, at.random.var.RandomStateSharedVariable):
new_rng = np.random.RandomState(seed)
else:
new_rng = np.random.Generator(seed)
rng.set_value(new_rng, True)
reseed_rngs(rngs, seed)
values = func(*args, **kwargs)
return dict(zip(varnames, values))

Expand Down

0 comments on commit 367afef

Please sign in to comment.