Drop support for custom chain numbering start
michaelosthege authored and ricardoV94 committed Oct 7, 2022
1 parent 9abf4e0 commit 91dbfd2
Showing 4 changed files with 29 additions and 64 deletions.
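For context, a minimal sketch of the user-facing change (the model and numbers are hypothetical): `pm.sample` no longer accepts `chain_idx`, so chains are always numbered starting at 0.

```python
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu")
    # Before this commit, callers could offset chain numbering:
    #   idata = pm.sample(draws=500, chain_idx=1)
    # After it, chains are always numbered 0..chains-1:
    idata = pm.sample(draws=500, chains=2)
```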
5 changes: 2 additions & 3 deletions pymc/parallel_sampling.py
@@ -389,14 +389,14 @@ def terminate_all(processes, patience=2):
class ParallelSampler:
def __init__(
self,
*,
draws: int,
tune: int,
chains: int,
cores: int,
seeds: Sequence["RandomSeed"],
start_points: Sequence[Dict[str, np.ndarray]],
step_method,
start_chain_num: int = 0,
progressbar: bool = True,
mp_ctx=None,
):
@@ -420,7 +420,7 @@ def __init__(
tune,
step_method,
step_method_pickled,
chain + start_chain_num,
chain,
seed,
start,
mp_ctx,
@@ -434,7 +434,6 @@ def __init__(
self._max_active = cores

self._in_context = False
self._start_chain_num = start_chain_num

self._progress = None
self._divergences = 0
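With the `*` marker added above, `ParallelSampler` now takes keyword-only arguments, so dropping `start_chain_num` fails loudly for any positional caller. A sketch of the new construction (values are illustrative and mirror the updated test further below; `step` and `start` are assumed to come from an existing model setup):

```python
import pymc.parallel_sampling as ps

sampler = ps.ParallelSampler(
    draws=10,
    tune=10,
    chains=3,
    cores=2,
    seeds=[2, 3, 4],
    start_points=[start] * 3,
    step_method=step,
    progressbar=False,
)
```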
62 changes: 17 additions & 45 deletions pymc/sampling.py
@@ -311,7 +311,6 @@ def sample(
n_init: int = 200_000,
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
trace: Optional[Union[BaseTrace, List[str]]] = None,
chain_idx: int = 0,
chains: Optional[int] = None,
cores: Optional[int] = None,
tune: int = 1000,
@@ -353,9 +352,6 @@ def sample(
trace : backend or list
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain_idx : int
Chain number used to store sample in backend. If ``chains`` is greater than one, chain
numbers will start here.
chains : int
The number of chains to sample. Running independent chains is important for some
convergence statistics and can also reveal multiple modes in the posterior. If ``None``,
@@ -569,7 +565,6 @@ def sample(
"step": step,
"start": initial_points,
"trace": trace,
"chain": chain_idx,
"chains": chains,
"tune": tune,
"progressbar": progressbar,
@@ -658,7 +653,7 @@ def sample(
# count the number of tune/draw iterations that happened
# ideally via the "tune" statistic, but not all samplers record it!
if "tune" in mtrace.stat_names:
stat = mtrace.get_sampler_stats("tune", chains=chain_idx)
stat = mtrace.get_sampler_stats("tune", chains=0)
# when CompoundStep is used, the stat is 2 dimensional!
if len(stat.shape) == 2:
stat = stat[:, 0]
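Since chain numbering now always starts at 0, the tune count can be read from chain 0 unconditionally. A rough sketch of reading it, assuming `mtrace` is the `MultiTrace` produced above:

```python
stat = mtrace.get_sampler_stats("tune", chains=0)
if len(stat.shape) == 2:
    # CompoundStep records the stat once per step method; any column works.
    stat = stat[:, 0]
n_tune = stat.sum()  # number of draws flagged as tuning iterations
```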
@@ -734,7 +729,6 @@ def _check_start_shape(model, start: PointType):

def _sample_many(
draws: int,
chain: int,
chains: int,
start: Sequence[PointType],
random_seed: Optional[Sequence[RandomSeed]],
@@ -748,8 +742,6 @@ def _sample_many(
----------
draws: int
The number of samples to draw
chain: int
Number of the first chain in the sequence.
chains: int
Total number of chains to sample.
start: list
@@ -768,7 +760,7 @@ def _sample_many(
for i in range(chains):
trace = _sample(
draws=draws,
chain=chain + i,
chain=i,
start=start[i],
step=step,
random_seed=None if random_seed is None else random_seed[i],
@@ -791,7 +783,6 @@ def _sample_many(

def _sample_population(
draws: int,
chain: int,
chains: int,
start: Sequence[PointType],
random_seed: RandomSeed,
@@ -808,8 +799,6 @@ def _sample_population(
----------
draws : int
The number of samples to draw
chain : int
The number of the first chain in the population
chains : int
The total number of chains in the population
start : list
@@ -832,7 +821,6 @@ def _sample_population(
"""
sampling = _prepare_iter_population(
draws,
[chain + c for c in range(chains)],
step,
start,
parallelize,
@@ -952,8 +940,7 @@ def iter_sample(
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain : int, optional
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
will start here.
Chain number used to store sample in backend.
tune : int, optional
Number of iterations to tune (defaults to 0).
model : Model (optional if in ``with`` context)
@@ -1008,8 +995,7 @@ def _iter_sample(
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain : int, optional
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
will start here.
Chain number used to store sample in backend.
tune : int, optional
Number of iterations to tune (defaults to 0).
model : Model (optional if in ``with`` context)
@@ -1247,7 +1233,6 @@ def step(self, tune_stop: bool, population):

def _prepare_iter_population(
draws: int,
chains: list,
step,
start: Sequence[PointType],
parallelize: bool,
@@ -1262,8 +1247,6 @@ def _prepare_iter_population(
----------
draws : int
The number of samples to draw
chains : list
The chain numbers in the population
step : function
Step function (should be or contain a population step method)
start : list
@@ -1282,8 +1265,7 @@ def _prepare_iter_population(
_iter_population : generator
Yields traces of all chains at the same time
"""
# chains contains the chain numbers, but for indexing we need indices...
nchains = len(chains)
nchains = len(start)
model = modelcontext(model)
draws = int(draws)

@@ -1327,7 +1309,7 @@ def _prepare_iter_population(
trace=None,
model=model,
)
for c in chains
for c in range(nchains)
]

# 4. configure the PopulationStepper (expensive call)
@@ -1457,7 +1439,6 @@ def _mp_sample(
step,
chains: int,
cores: int,
chain: int,
random_seed: Sequence[RandomSeed],
start: Sequence[PointType],
progressbar: bool = True,
@@ -1482,8 +1463,6 @@ def _mp_sample(
The number of chains to sample.
cores : int
The number of chains to run in parallel.
chain : int
Number of the first chain.
random_seed : list of random seeds
Random seeds for each chain.
start : list
@@ -1520,26 +1499,25 @@ def _mp_sample(
trace=trace,
model=model,
)
for chain_number in range(chain, chain + chains)
for chain_number in range(chains)
]

sampler = ps.ParallelSampler(
draws,
tune,
chains,
cores,
random_seed,
start,
step,
chain,
progressbar,
draws=draws,
tune=tune,
chains=chains,
cores=cores,
seeds=random_seed,
start_points=start,
step_method=step,
progressbar=progressbar,
mp_ctx=mp_ctx,
)
try:
try:
with sampler:
for draw in sampler:
strace = traces[draw.chain - chain]
strace = traces[draw.chain]
if strace.supports_sampler_stats and draw.stats is not None:
strace.record(draw.point, draw.stats)
else:
@@ -1553,7 +1531,7 @@ def _mp_sample(
callback(trace=trace, draw=draw)

except ps.ParallelSamplingError as error:
strace = traces[error._chain - chain]
strace = traces[error._chain]
strace._add_warnings(error._warnings)
for strace in traces:
strace.close()
@@ -1998,18 +1976,12 @@ def sample_posterior_predictive(
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)
try:
if isinstance(_trace, MultiTrace):
# trace dict is unordered, but we want to return ppc samples in
# a predictable ordering, so sort the chain indices
chain_idx_mapping = sorted(_trace._straces.keys())
for idx in indices:
if nchain > 1:
# the trace object will either be a MultiTrace (and have _straces)...
if hasattr(_trace, "_straces"):
chain_idx, point_idx = np.divmod(idx, len_trace)
chain_idx = chain_idx % nchain
# chain indices might not always start at 0, convert to proper index
chain_idx = chain_idx_mapping[chain_idx]
param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
# ... or a PointList
else:
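The dropped `chain_idx_mapping` was only needed because chain keys could start at a nonzero offset; with zero-based numbering, a flat sample index maps straight onto a (chain, draw) pair. A worked sketch with hypothetical sizes:

```python
import numpy as np

len_trace, nchain = 150, 4   # draws per chain, number of chains
idx = 473                    # flat index into the pooled samples
chain_idx, point_idx = np.divmod(idx, len_trace)  # -> (3, 23)
chain_idx = chain_idx % nchain  # no-op here; guards indices past the last chain
```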
11 changes: 10 additions & 1 deletion pymc/tests/test_parallel_sampling.py
@@ -183,7 +183,16 @@ def test_iterator():
step = pm.CompoundStep([step1, step2])

start = {"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))}
sampler = ps.ParallelSampler(10, 10, 3, 2, [2, 3, 4], [start] * 3, step, 0, False)
sampler = ps.ParallelSampler(
draws=10,
tune=10,
chains=3,
cores=2,
seeds=[2, 3, 4],
start_points=[start] * 3,
step_method=step,
progressbar=False,
)
with sampler:
for draw in sampler:
pass
15 changes: 0 additions & 15 deletions pymc/tests/test_sampling.py
@@ -505,21 +505,6 @@ def test_partial_trace_sample():
assert "b" not in idata.posterior


def test_chain_idx():
# see https://github.com/pymc-devs/pymc/issues/4469
with pm.Model():
mu = pm.Normal("mu")
x = pm.Normal("x", mu=mu, sigma=1, observed=np.asarray(3))
# note draws-tune must be >100 AND we need an observed RV for this to properly
# trigger convergence checks, which is one particular case in which this failed
# before
idata = pm.sample(draws=150, tune=10, chain_idx=1)

ppc = pm.sample_posterior_predictive(idata)
# TODO FIXME: Assert something.
ppc = pm.sample_posterior_predictive(idata, keep_size=True)


@pytest.mark.parametrize(
"n_points, tune, expected_length, expected_n_traces",
[