Detach step methods from numpy global random state
lucianopaz committed Oct 7, 2024
1 parent b810389 commit 4ea1406
Showing 18 changed files with 379 additions and 165 deletions.
4 changes: 2 additions & 2 deletions pymc/math.py
@@ -292,10 +292,10 @@ def logdiffexp_numpy(a, b):
invlogit = sigmoid


def logbern(log_p):
def logbern(log_p, rng=None):
if np.isnan(log_p):
raise FloatingPointError("log_p can't be nan.")
return np.log(np.random.uniform()) < log_p
return np.log((rng or np.random).uniform()) < log_p


def logit(p):
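A minimal usage sketch of the new optional ``rng`` argument (assuming this commit is installed): passing a Generator makes the draw reproducible, while omitting it falls back to numpy's global state exactly as before.

```python
import numpy as np

from pymc.math import logbern

rng = np.random.default_rng(123)
print(logbern(np.log(0.5), rng=rng))  # driven by the local Generator
print(logbern(np.log(0.5)))           # legacy path: numpy's global state
```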
71 changes: 42 additions & 29 deletions pymc/sampling/mcmc.py
@@ -71,6 +71,7 @@
_get_seeds_per_chain,
default_progress_theme,
drop_warning_stat,
get_random_generator,
get_untransformed_name,
is_transformed_name,
)
@@ -489,10 +490,15 @@ def sample(
cores : int
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
system, but at most 4.
random_seed : int, array-like of int, RandomState or Generator, optional
Random seed(s) used by the sampling steps. If a list, tuple or array of ints
is passed, each entry will be used to seed each chain. A ValueError will be
raised if the length does not match the number of chains.
random_seed : int, array-like of int, or Generator, optional
Random seed(s) used by the sampling steps. Each step will create its own
:py:class:`~numpy.random.Generator` object to make its random draws in a way that is
independent of all other steppers and all other chains. If a list, tuple or array of ints
is passed, each entry will be used to seed the creation of ``Generator`` objects.
A ``ValueError`` will be raised if the length does not match the number of chains.
A ``TypeError`` will be raised if a :py:class:`~numpy.random.RandomState` object is passed.
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
easy spawning of new independent random streams that are needed by the step methods.
progressbar : bool, optional default=True
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
@@ -686,7 +692,8 @@ def joined_blas_limiter():

if random_seed == -1:
random_seed = None
random_seed_list = _get_seeds_per_chain(random_seed, chains)
rngs = get_random_generator(random_seed).spawn(chains)
random_seed_list = [rng.integers(2**30) for rng in rngs]

if not discard_tuned_samples and not return_inferencedata:
warnings.warn(
@@ -834,11 +841,11 @@ def joined_blas_limiter():
if parallel:
# For parallel sampling we can pass the list of random seeds directly, as
# global seeding will only be called inside each process
sample_args["random_seed"] = random_seed_list
sample_args["rngs"] = rngs
else:
# We pass None if the original random seed was None. The single core sampler
# methods will only set a global seed when it is not None.
sample_args["random_seed"] = random_seed if random_seed is None else random_seed_list
sample_args["rngs"] = rngs

t_start = time.time()
if parallel:
@@ -989,7 +996,7 @@ def _sample_many(
chains: int,
traces: Sequence[IBaseTrace],
start: Sequence[PointType],
random_seed: Sequence[RandomSeed] | None,
rngs: Sequence[np.random.Generator],
step: Step,
callback: SamplingIteratorCallback | None = None,
**kwargs,
@@ -1004,8 +1011,8 @@
Total number of chains to sample.
start: list
Starting points for each chain
random_seed: list of random seeds, optional
A list of seeds, one for each chain
rngs: list of random Generators
A list of :py:class:`~numpy.random.Generator` objects, one for each chain
step: function
Step function
"""
@@ -1016,7 +1023,7 @@
start=start[i],
step=step,
trace=traces[i],
random_seed=None if random_seed is None else random_seed[i],
rng=rngs[i],
callback=callback,
**kwargs,
)
@@ -1027,7 +1034,7 @@ def _sample(
*,
chain: int,
progressbar: bool,
random_seed: RandomSeed,
rng: np.random.Generator,
start: PointType,
draws: int,
step: Step,
@@ -1075,7 +1082,7 @@ def _sample(
chain=chain,
tune=tune,
model=model,
random_seed=random_seed,
rng=rng,
callback=callback,
)
_pbar_data = {"chain": chain, "divergences": 0}
@@ -1114,8 +1121,8 @@ def _iter_sample(
trace: IBaseTrace,
chain: int = 0,
tune: int = 0,
rng: np.random.Generator,
model: Model | None = None,
random_seed: RandomSeed = None,
callback: SamplingIteratorCallback | None = None,
) -> Iterator[bool]:
"""Generator for sampling one chain. (Used in singleprocess sampling.)
@@ -1149,8 +1156,7 @@ def _iter_sample(
if draws < 1:
raise ValueError("Argument `draws` must be greater than 0.")

if random_seed is not None:
np.random.seed(random_seed)
step.set_rng(rng)

point = start

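This hunk is the heart of the change: instead of reseeding numpy's global state once per chain, the chain's Generator is handed to the step method through ``set_rng``. A toy illustration of the pattern follows; ``ToyStep`` is hypothetical and only mimics the idea, it is not PyMC's step-method API.

```python
import numpy as np


class ToyStep:
    """Hypothetical stand-in for a step method that owns its own rng."""

    def __init__(self):
        self.rng = np.random.default_rng()

    def set_rng(self, rng: np.random.Generator) -> None:
        # Detach from np.random's global state: every draw goes through self.rng.
        self.rng = rng

    def step(self, point: float) -> float:
        return point + self.rng.normal()


step = ToyStep()
step.set_rng(np.random.default_rng(7))      # what _iter_sample now does per chain
print([step.step(0.0) for _ in range(3)])   # reproducible without np.random.seed
```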
@@ -1193,7 +1199,7 @@ def _mp_sample(
step,
chains: int,
cores: int,
random_seed: Sequence[RandomSeed],
rngs: Sequence[np.random.Generator],
start: Sequence[PointType],
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
@@ -1218,8 +1224,8 @@
The number of chains to sample.
cores : int
The number of chains to run in parallel.
random_seed : list of random seeds
Random seeds for each chain.
rngs: list of random Generators
A list of :py:class:`~numpy.random.Generator` objects, one for each chain
start : list
Starting points for each chain.
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
@@ -1247,7 +1253,7 @@
tune=tune,
chains=chains,
cores=cores,
seeds=random_seed,
rngs=rngs,
start_points=start,
step_method=step,
progressbar=progressbar,
@@ -1446,12 +1452,12 @@ def init_nuts(
mean = np.mean(apoints_data, axis=0)
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0])
elif init == "jitter+adapt_diag":
mean = np.mean(apoints_data, axis=0)
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10, rng=random_seed_list[0])
elif init == "jitter+adapt_diag_grad":
mean = np.mean(apoints_data, axis=0)
var = np.ones_like(mean)
@@ -1468,6 +1474,7 @@
alpha=0.02,
use_grads=True,
stop_adaptation=stop_adaptation,
rng=random_seed_list[0],
)
elif init == "advi+adapt_diag":
approx = pm.fit(
@@ -1488,7 +1495,9 @@
mean = approx.mean.get_value()
weight = 50
n = len(cov)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, cov, weight)
potential = quadpotential.QuadPotentialDiagAdapt(
n, mean, cov, weight, rng=random_seed_list[0]
)
elif init == "advi":
approx = pm.fit(
random_seed=random_seed_list[0],
@@ -1504,7 +1513,7 @@
)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0])
elif init == "advi_map":
start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
approx = pm.MeanField(model=model, start=start)
@@ -1521,28 +1530,32 @@
)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
potential = quadpotential.QuadPotentialDiag(cov, rng=random_seed_list[0])
elif init == "map":
start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
cov = -pm.find_hessian(point=start, negate_output=False)
initial_points = [start] * chains
potential = quadpotential.QuadPotentialFull(cov)
potential = quadpotential.QuadPotentialFull(cov, rng=random_seed_list[0])
elif init == "adapt_full":
mean = np.mean(apoints_data * chains, axis=0)
initial_point = initial_points[0]
initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars)
cov = np.eye(initial_point_model_size)
potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10)
potential = quadpotential.QuadPotentialFullAdapt(
initial_point_model_size, mean, cov, 10, rng=random_seed_list[0]
)
elif init == "jitter+adapt_full":
mean = np.mean(apoints_data, axis=0)
initial_point = initial_points[0]
initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars)
cov = np.eye(initial_point_model_size)
potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10)
potential = quadpotential.QuadPotentialFullAdapt(
initial_point_model_size, mean, cov, 10, rng=random_seed_list[0]
)
else:
raise ValueError(f"Unknown initializer: {init}.")

step = pm.NUTS(potential=potential, model=model, **kwargs)
step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs)

# Filter deterministics from initial_points
value_var_names = [var.name for var in model.value_vars]
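Putting the ``sample`` docstring and ``init_nuts`` changes together, the public entry point can now be seeded with a single ``Generator`` (or a list of ints, one per chain), while a ``RandomState`` raises ``TypeError``. A hedged usage sketch with an arbitrary toy model:

```python
import numpy as np
import pymc as pm

with pm.Model():
    pm.Normal("x", 0.0, 1.0)
    # One Generator seeds the whole run; independent streams are spawned per chain.
    idata = pm.sample(draws=100, tune=100, chains=2,
                      random_seed=np.random.default_rng(2024))
    # A list of ints (length == chains) is still accepted:
    # idata = pm.sample(draws=100, tune=100, chains=2, random_seed=[1, 2])
```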
28 changes: 16 additions & 12 deletions pymc/sampling/parallel.py
@@ -33,7 +33,7 @@

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.util import CustomProgress, RandomSeed, default_progress_theme
from pymc.util import CustomProgress, default_progress_theme

logger = logging.getLogger(__name__)

@@ -93,15 +93,18 @@ def __init__(
shared_point,
draws: int,
tune: int,
seed,
rng: np.random.Generator,
seed_seq: np.random.SeedSequence,
blas_cores,
):
# For reasons that aren't fully clear, the spawn start method doesn't carry the
# rng's seed sequence across to the worker, so we rebuild the Generator from scratch
rng = np.random.Generator(type(rng.bit_generator)(seed_seq))
self._msg_pipe = msg_pipe
self._step_method = step_method
self._step_method_is_pickled = step_method_is_pickled
self._shared_point = shared_point
self._seed = seed
self._at_seed = seed + 1
self._rng = rng
self._draws = draws
self._tune = tune
self._blas_cores = blas_cores
@@ -159,7 +162,7 @@ def _recv_msg(self):
return self._msg_pipe.recv()

def _start_loop(self):
np.random.seed(self._seed)
self._step_method.set_rng(self._rng)

draw = 0
tuning = True
@@ -210,7 +213,7 @@ def __init__(
step_method,
step_method_pickled,
chain: int,
seed,
rng: np.random.Generator,
start: dict[str, np.ndarray],
blas_cores,
mp_ctx,
@@ -260,7 +263,8 @@ def __init__(
self._shared_point,
draws,
tune,
seed,
rng,
rng.bit_generator.seed_seq,
blas_cores,
),
)
@@ -379,16 +383,16 @@ def __init__(
tune: int,
chains: int,
cores: int,
seeds: Sequence["RandomSeed"],
rngs: Sequence[np.random.Generator],
start_points: Sequence[dict[str, np.ndarray]],
step_method,
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
blas_cores: int | None = None,
mp_ctx=None,
):
if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError(f"Number of seeds and start_points must be {chains}.")
if any(len(arg) != chains for arg in [rngs, start_points]):
raise ValueError(f"Number of rngs and start_points must be {chains}.")

if mp_ctx is None or isinstance(mp_ctx, str):
# Closes issue https://github.com/pymc-devs/pymc/issues/3849
@@ -416,12 +420,12 @@ def __init__(
step_method,
step_method_pickled,
chain,
seed,
rng,
start,
blas_cores,
mp_ctx,
)
for chain, seed, start in zip(range(chains), seeds, start_points)
for chain, rng, start in zip(range(chains), rngs, start_points)
]

self._inactive = self._samplers.copy()
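The worker code above ships the bit generator type and its ``SeedSequence`` to the child process and rebuilds the Generator there, because the spawn start method loses the seed sequence. A standalone numpy sketch of that reconstruction, with no PyMC involved:

```python
import numpy as np

rng = np.random.default_rng(99)

# What the parent hands to ProcessAdapter: the rng plus its bit generator's seed sequence.
bit_generator_type = type(rng.bit_generator)   # e.g. np.random.PCG64
seed_seq = rng.bit_generator.seed_seq          # np.random.SeedSequence

# What _Process does on the other side: rebuild an equivalent, freshly seeded Generator.
rebuilt = np.random.Generator(bit_generator_type(seed_seq))

# Both start from the same state, so the streams match draw for draw.
print(rng.uniform(), rebuilt.uniform())
```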