diff --git a/pymc/parallel_sampling.py b/pymc/parallel_sampling.py
index bb03b89ac0d..0312b3e09bb 100644
--- a/pymc/parallel_sampling.py
+++ b/pymc/parallel_sampling.py
@@ -389,6 +389,7 @@ def terminate_all(processes, patience=2):
 class ParallelSampler:
     def __init__(
         self,
+        *,
         draws: int,
         tune: int,
         chains: int,
@@ -396,7 +397,6 @@ def __init__(
         seeds: Sequence["RandomSeed"],
         start_points: Sequence[Dict[str, np.ndarray]],
         step_method,
-        start_chain_num: int = 0,
         progressbar: bool = True,
         mp_ctx=None,
     ):
@@ -420,7 +420,7 @@ def __init__(
                 tune,
                 step_method,
                 step_method_pickled,
-                chain + start_chain_num,
+                chain,
                 seed,
                 start,
                 mp_ctx,
@@ -434,7 +434,6 @@ def __init__(
         self._max_active = cores
 
         self._in_context = False
-        self._start_chain_num = start_chain_num
 
         self._progress = None
         self._divergences = 0
diff --git a/pymc/sampling.py b/pymc/sampling.py
index a3f4dcb0fbc..ad26b309b94 100644
--- a/pymc/sampling.py
+++ b/pymc/sampling.py
@@ -311,7 +311,6 @@ def sample(
     n_init: int = 200_000,
     initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
     trace: Optional[Union[BaseTrace, List[str]]] = None,
-    chain_idx: int = 0,
     chains: Optional[int] = None,
     cores: Optional[int] = None,
     tune: int = 1000,
@@ -353,9 +352,6 @@ def sample(
     trace : backend or list
         This should be a backend instance, or a list of variables to track.
         If None or a list of variables, the NDArray backend is used.
-    chain_idx : int
-        Chain number used to store sample in backend. If ``chains`` is greater than one, chain
-        numbers will start here.
     chains : int
         The number of chains to sample. Running independent chains is important for some
         convergence statistics and can also reveal multiple modes in the posterior. If ``None``,
@@ -569,7 +565,6 @@ def sample(
         "step": step,
         "start": initial_points,
         "trace": trace,
-        "chain": chain_idx,
         "chains": chains,
         "tune": tune,
         "progressbar": progressbar,
@@ -658,7 +653,7 @@ def sample(
     # count the number of tune/draw iterations that happened
     # ideally via the "tune" statistic, but not all samplers record it!
     if "tune" in mtrace.stat_names:
-        stat = mtrace.get_sampler_stats("tune", chains=chain_idx)
+        stat = mtrace.get_sampler_stats("tune", chains=0)
         # when CompoundStep is used, the stat is 2 dimensional!
         if len(stat.shape) == 2:
             stat = stat[:, 0]
@@ -734,7 +729,6 @@ def _check_start_shape(model, start: PointType):
 
 def _sample_many(
     draws: int,
-    chain: int,
     chains: int,
     start: Sequence[PointType],
     random_seed: Optional[Sequence[RandomSeed]],
@@ -748,8 +742,6 @@ def _sample_many(
     ----------
     draws: int
         The number of samples to draw
-    chain: int
-        Number of the first chain in the sequence.
     chains: int
         Total number of chains to sample.
     start: list
@@ -768,7 +760,7 @@ def _sample_many(
     for i in range(chains):
         trace = _sample(
             draws=draws,
-            chain=chain + i,
+            chain=i,
             start=start[i],
             step=step,
             random_seed=None if random_seed is None else random_seed[i],
@@ -791,7 +783,6 @@ def _sample_many(
 
 def _sample_population(
     draws: int,
-    chain: int,
     chains: int,
     start: Sequence[PointType],
     random_seed: RandomSeed,
@@ -808,8 +799,6 @@ def _sample_population(
     ----------
     draws : int
         The number of samples to draw
-    chain : int
-        The number of the first chain in the population
     chains : int
         The total number of chains in the population
     start : list
@@ -832,7 +821,6 @@ def _sample_population(
     """
     sampling = _prepare_iter_population(
         draws,
-        [chain + c for c in range(chains)],
         step,
         start,
         parallelize,
@@ -952,8 +940,7 @@ def iter_sample(
         This should be a backend instance, or a list of variables to track.
         If None or a list of variables, the NDArray backend is used.
     chain : int, optional
-        Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
-        will start here.
+        Chain number used to store sample in backend.
     tune : int, optional
         Number of iterations to tune (defaults to 0).
     model : Model (optional if in ``with`` context)
@@ -1008,8 +995,7 @@ def _iter_sample(
         This should be a backend instance, or a list of variables to track.
         If None or a list of variables, the NDArray backend is used.
     chain : int, optional
-        Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
-        will start here.
+        Chain number used to store sample in backend.
     tune : int, optional
         Number of iterations to tune (defaults to 0).
     model : Model (optional if in ``with`` context)
@@ -1247,7 +1233,6 @@ def step(self, tune_stop: bool, population):
 
 def _prepare_iter_population(
     draws: int,
-    chains: list,
     step,
     start: Sequence[PointType],
     parallelize: bool,
@@ -1262,8 +1247,6 @@ def _prepare_iter_population(
     ----------
     draws : int
         The number of samples to draw
-    chains : list
-        The chain numbers in the population
     step : function
         Step function (should be or contain a population step method)
     start : list
@@ -1282,8 +1265,7 @@ def _prepare_iter_population(
     _iter_population : generator
         Yields traces of all chains at the same time
     """
-    # chains contains the chain numbers, but for indexing we need indices...
-    nchains = len(chains)
+    nchains = len(start)
     model = modelcontext(model)
     draws = int(draws)
 
@@ -1327,7 +1309,7 @@ def _prepare_iter_population(
             trace=None,
             model=model,
         )
-        for c in chains
+        for c in range(nchains)
     ]
 
     # 4. configure the PopulationStepper (expensive call)
@@ -1457,7 +1439,6 @@ def _mp_sample(
     step,
     chains: int,
     cores: int,
-    chain: int,
     random_seed: Sequence[RandomSeed],
     start: Sequence[PointType],
     progressbar: bool = True,
@@ -1482,8 +1463,6 @@ def _mp_sample(
         The number of chains to sample.
     cores : int
         The number of chains to run in parallel.
-    chain : int
-        Number of the first chain.
     random_seed : list of random seeds
         Random seeds for each chain.
     start : list
@@ -1520,26 +1499,25 @@ def _mp_sample(
             trace=trace,
             model=model,
         )
-        for chain_number in range(chain, chain + chains)
+        for chain_number in range(chains)
     ]
 
     sampler = ps.ParallelSampler(
-        draws,
-        tune,
-        chains,
-        cores,
-        random_seed,
-        start,
-        step,
-        chain,
-        progressbar,
+        draws=draws,
+        tune=tune,
+        chains=chains,
+        cores=cores,
+        seeds=random_seed,
+        start_points=start,
+        step_method=step,
+        progressbar=progressbar,
         mp_ctx=mp_ctx,
     )
     try:
         try:
             with sampler:
                 for draw in sampler:
-                    strace = traces[draw.chain - chain]
+                    strace = traces[draw.chain]
                     if strace.supports_sampler_stats and draw.stats is not None:
                         strace.record(draw.point, draw.stats)
                     else:
@@ -1553,7 +1531,7 @@ def _mp_sample(
                         callback(trace=trace, draw=draw)
 
         except ps.ParallelSamplingError as error:
-            strace = traces[error._chain - chain]
+            strace = traces[error._chain]
             strace._add_warnings(error._warnings)
             for strace in traces:
                 strace.close()
@@ -1998,18 +1976,12 @@ def sample_posterior_predictive(
     _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}")  # type: ignore
     ppc_trace_t = _DefaultTrace(samples)
     try:
-        if isinstance(_trace, MultiTrace):
-            # trace dict is unordered, but we want to return ppc samples in
-            # a predictable ordering, so sort the chain indices
-            chain_idx_mapping = sorted(_trace._straces.keys())
         for idx in indices:
             if nchain > 1:
                 # the trace object will either be a MultiTrace (and have _straces)...
                 if hasattr(_trace, "_straces"):
                     chain_idx, point_idx = np.divmod(idx, len_trace)
                     chain_idx = chain_idx % nchain
-                    # chain indices might not always start at 0, convert to proper index
-                    chain_idx = chain_idx_mapping[chain_idx]
                     param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
                 # ... or a PointList
                 else:
diff --git a/pymc/tests/test_parallel_sampling.py b/pymc/tests/test_parallel_sampling.py
index 80cb49c0411..2032e7e0719 100644
--- a/pymc/tests/test_parallel_sampling.py
+++ b/pymc/tests/test_parallel_sampling.py
@@ -183,7 +183,16 @@ def test_iterator():
     step = pm.CompoundStep([step1, step2])
     start = {"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))}
 
-    sampler = ps.ParallelSampler(10, 10, 3, 2, [2, 3, 4], [start] * 3, step, 0, False)
+    sampler = ps.ParallelSampler(
+        draws=10,
+        tune=10,
+        chains=3,
+        cores=2,
+        seeds=[2, 3, 4],
+        start_points=[start] * 3,
+        step_method=step,
+        progressbar=False,
+    )
     with sampler:
         for draw in sampler:
             pass
diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py
index 8e7f4c28e84..b4365d0c519 100644
--- a/pymc/tests/test_sampling.py
+++ b/pymc/tests/test_sampling.py
@@ -505,21 +505,6 @@ def test_partial_trace_sample():
     assert "b" not in idata.posterior
 
 
-def test_chain_idx():
-    # see https://github.com/pymc-devs/pymc/issues/4469
-    with pm.Model():
-        mu = pm.Normal("mu")
-        x = pm.Normal("x", mu=mu, sigma=1, observed=np.asarray(3))
-        # note draws-tune must be >100 AND we need an observed RV for this to properly
-        # trigger convergence checks, which is one particular case in which this failed
-        # before
-        idata = pm.sample(draws=150, tune=10, chain_idx=1)
-
-        ppc = pm.sample_posterior_predictive(idata)
-        # TODO FIXME: Assert something.
-        ppc = pm.sample_posterior_predictive(idata, keep_size=True)
-
-
 @pytest.mark.parametrize(
     "n_points, tune, expected_length, expected_n_traces",
     [
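
Note (not part of the patch): a minimal sketch of the user-facing behaviour after this change, using only the public `pm.sample` API touched above. With `chain_idx` removed, the returned trace always numbers its chains from 0; the toy model below is a stand-in modelled on the deleted `test_chain_idx`.

import numpy as np
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu")
    pm.Normal("x", mu=mu, sigma=1, observed=np.asarray(3))
    # chain_idx no longer exists; chains are always stored as 0..chains-1
    idata = pm.sample(draws=150, tune=10, chains=2, cores=1)

# chain coordinates start at 0 regardless of the backend
assert list(idata.posterior.chain.values) == [0, 1]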