diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 807a0928f3a..301ff657e87 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,6 +42,7 @@ jobs: pymc/tests/distributions/test_logprob.py pymc/tests/test_aesaraf.py pymc/tests/test_math.py + pymc/tests/backends/test_base.py pymc/tests/backends/test_ndarray.py pymc/tests/step_methods/hmc/test_hmc.py pymc/tests/test_func_utils.py @@ -60,6 +61,7 @@ jobs: pymc/tests/distributions/test_simulator.py pymc/tests/distributions/test_truncated.py pymc/tests/sampling/test_forward.py + pymc/tests/sampling/test_population.py pymc/tests/stats/test_convergence.py - | diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index fbc9f971d31..d13a6dc836b 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -60,7 +60,34 @@ Saved backends can be loaded using `arviz.from_netcdf` """ +from copy import copy +from typing import Dict, List, Optional + from pymc.backends.arviz import predictions_to_inference_data, to_inference_data +from pymc.backends.base import BaseTrace from pymc.backends.ndarray import NDArray, point_list_to_multitrace __all__ = ["to_inference_data", "predictions_to_inference_data"] + + +def _init_trace( + *, + expected_length: int, + chain_number: int, + stats_dtypes: List[Dict[str, type]], + trace: Optional[BaseTrace], + model, +) -> BaseTrace: + """Initializes a trace backend for a chain.""" + strace: BaseTrace + if trace is None: + strace = NDArray(model=model) + elif isinstance(trace, BaseTrace): + if len(trace) > 0: + raise ValueError("Continuation of traces is no longer supported.") + strace = copy(trace) + else: + raise NotImplementedError(f"Unsupported `trace`: {trace}") + + strace.setup(expected_length, chain_number, stats_dtypes) + return strace diff --git a/pymc/backends/base.py b/pymc/backends/base.py index ef6ee1cd642..812dbc2465c 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -21,6 +21,7 @@ import warnings from abc import ABC +from typing import List, Sequence, Tuple, cast import aesara.tensor as at import numpy as np @@ -561,3 +562,30 @@ def _squeeze_cat(results, combine, squeeze): if squeeze and len(results) == 1: results = results[0] return results + + +def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTrace], int]: + """ + Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized. + + We get here after a ``KeyboardInterrupt``, and so the different + traces have different lengths. We therefore pick the number of + traces such that (number of traces) * (length of shortest trace) + is maximised. + """ + if not traces: + raise ValueError("No traces to slice.") + + lengths = [max(0, len(trace) - tune) for trace in traces] + if not sum(lengths): + raise ValueError("Not enough samples to build a trace.") + + idxs = np.argsort(lengths) + l_sort = np.array(lengths)[idxs] + + use_until = cast(int, np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1])) + final_length = l_sort[use_until] + + take_idx = cast(Sequence[int], idxs[use_until:]) + sliced_traces = [traces[idx] for idx in take_idx] + return sliced_traces, final_length + tune diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4a497ff4681..ac04f0f4654 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -22,10 +22,9 @@ from collections import defaultdict from copy import copy -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast +from typing import Iterator, List, Optional, Sequence, Tuple, Union import aesara.gradient as tg -import cloudpickle import numpy as np from arviz import InferenceData @@ -34,8 +33,8 @@ import pymc as pm -from pymc.backends.base import BaseTrace, MultiTrace -from pymc.backends.ndarray import NDArray +from pymc.backends import _init_trace +from pymc.backends.base import BaseTrace, MultiTrace, _choose_chains from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.initial_point import ( @@ -46,7 +45,8 @@ ) from pymc.model import Model, modelcontext from pymc.sampling.parallel import Draw, _cpu_count -from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks +from pymc.sampling.population import _sample_population +from pymc.stats.convergence import log_warning_stats, run_convergence_checks from pymc.step_methods import NUTS, CompoundStep, DEMetropolis from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential @@ -226,7 +226,7 @@ def sample( init: str = "auto", n_init: int = 200_000, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, - trace: Optional[Union[BaseTrace, List[str]]] = None, + trace: Optional[BaseTrace] = None, chains: Optional[int] = None, cores: Optional[int] = None, tune: int = 1000, @@ -266,9 +266,9 @@ def sample( Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`. The keys should be names of transformed random variables. Initialization methods for NUTS (see ``init`` keyword) can overwrite the default. - trace : backend or list - This should be a backend instance, or a list of variables to track. - If None or a list of variables, the NDArray backend is used. + trace : backend, optional + A backend instance or None. + If None, the NDArray backend is used. chains : int The number of chains to sample. Running independent chains is important for some convergence statistics and can also reveal multiple modes in the posterior. If ``None``, @@ -401,6 +401,11 @@ def sample( kwargs["nuts"]["target_accept"] = kwargs.pop("target_accept") else: kwargs = {"nuts": {"target_accept": kwargs.pop("target_accept")}} + if isinstance(trace, list): + raise DeprecationWarning( + "We have removed support for partial traces because it simplified things." + " Please open an issue if & why this is a problem for you." + ) model = modelcontext(model) if not model.free_RVs: @@ -710,64 +715,6 @@ def _sample_many( return MultiTrace(traces) -def _sample_population( - draws: int, - chains: int, - start: Sequence[PointType], - random_seed: RandomSeed, - step, - tune: int, - model, - progressbar: bool = True, - parallelize: bool = False, - **kwargs, -) -> MultiTrace: - """Performs sampling of a population of chains using the ``PopulationStepper``. - - Parameters - ---------- - draws : int - The number of samples to draw - chains : int - The total number of chains in the population - start : list - Start points for each chain - random_seed : single random seed, optional - step : function - Step function (should be or contain a population step method) - tune : int - Number of iterations to tune. - model : Model (optional if in ``with`` context) - progressbar : bool - Show progress bars? (defaults to True) - parallelize : bool - Setting for multiprocess parallelization - - Returns - ------- - trace : MultiTrace - Contains samples of all chains - """ - sampling = _prepare_iter_population( - draws, - step, - start, - parallelize, - tune=tune, - model=model, - random_seed=random_seed, - progressbar=progressbar, - ) - - if progressbar: - sampling = progress_bar(sampling, total=draws, display=progressbar) - - latest_traces = None - for it, traces in enumerate(sampling): - latest_traces = traces - return MultiTrace(latest_traces) - - def _sample( *, chain: int, @@ -776,7 +723,7 @@ def _sample( start: PointType, draws: int, step=None, - trace: Optional[Union[BaseTrace, List[str]]] = None, + trace: Optional[BaseTrace] = None, tune: int, model: Optional[Model] = None, callback=None, @@ -801,9 +748,9 @@ def _sample( The number of samples to draw step : function Step function - trace : backend or list - This should be a backend instance, or a list of variables to track. - If None or a list of variables, the NDArray backend is used. + trace : backend, optional + A backend instance or None. + If None, the NDArray backend is used. tune : int Number of iterations to tune. model : Model (optional if in ``with`` context) @@ -902,7 +849,7 @@ def _iter_sample( draws: int, step, start: PointType, - trace: Optional[Union[BaseTrace, List[str]]] = None, + trace: Optional[BaseTrace] = None, chain: int = 0, tune: int = 0, model=None, @@ -920,9 +867,9 @@ def _iter_sample( start : dict Starting point in parameter space (or partial point). Must contain numeric (transformed) initial values for all (transformed) free variables. - trace : backend or list - This should be a backend instance, or a list of variables to track. - If None or a list of variables, the NDArray backend is used. + trace : backend, optional + A backend instance or None. + If None, the NDArray backend is used. chain : int, optional Chain number used to store sample in backend. tune : int, optional @@ -955,7 +902,7 @@ def _iter_sample( strace: BaseTrace = _init_trace( expected_length=draws + tune, - step=step, + stats_dtypes=step.stats_dtypes, chain_number=chain, trace=trace, model=model, @@ -972,7 +919,7 @@ def _iter_sample( if i == 0 and hasattr(step, "iter_count"): step.iter_count = 0 if i == tune: - step = stop_tuning(step) + step.stop_tuning() if step.generates_stats: point, stats = step.step(point) strace.record(point, stats) @@ -980,7 +927,7 @@ def _iter_sample( diverging = i > tune and stats and stats[0].get("diverging") else: point = step.step(point) - strace.record(point) + strace.record(point, []) if callback is not None: callback( trace=strace, @@ -998,359 +945,6 @@ def _iter_sample( strace.close() -class PopulationStepper: - """Wraps population of step methods to step them in parallel with single or multiprocessing.""" - - def __init__(self, steppers, parallelize: bool, progressbar: bool = True): - """Use multiprocessing to parallelize chains. - - Falls back to sequential evaluation if multiprocessing fails. - - In the multiprocessing mode of operation, a new process is started for each - chain/stepper and Pipes are used to communicate with the main process. - - Parameters - ---------- - steppers : list - A collection of independent step methods, one for each chain. - parallelize : bool - Indicates if parallelization via multiprocessing is desired. - progressbar : bool - Should we display a progress bar showing relative progress? - """ - self.nchains = len(steppers) - self.is_parallelized = False - self._primary_ends = [] - self._processes = [] - self._steppers = steppers - if parallelize: - try: - # configure a child process for each stepper - _log.info( - "Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`." - ) - import multiprocessing - - for c, stepper in ( - enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) - ): - secondary_end, primary_end = multiprocessing.Pipe() - stepper_dumps = cloudpickle.dumps(stepper, protocol=4) - process = multiprocessing.Process( - target=self.__class__._run_secondary, - args=(c, stepper_dumps, secondary_end), - name=f"ChainWalker{c}", - ) - # we want the child process to exit if the parent is terminated - process.daemon = True - # Starting the process might fail and takes time. - # By doing it in the constructor, the sampling progress bar - # will not be confused by the process start. - process.start() - self._primary_ends.append(primary_end) - self._processes.append(process) - self.is_parallelized = True - except Exception: - _log.info( - "Population parallelization failed. " - "Falling back to sequential stepping of chains." - ) - _log.debug("Error was: ", exc_info=True) - else: - _log.info( - "Chains are not parallelized. You can enable this by passing " - "`pm.sample(cores=n)`, where n > 1." - ) - return super().__init__() - - def __enter__(self): - """Do nothing: processes are already started in ``__init__``.""" - return - - def __exit__(self, exc_type, exc_val, exc_tb): - if len(self._processes) > 0: - try: - for primary_end in self._primary_ends: - primary_end.send(None) - for process in self._processes: - process.join(timeout=3) - except Exception: - _log.warning("Termination failed.") - return - - @staticmethod - def _run_secondary(c, stepper_dumps, secondary_end): - """This method is started on a separate process to perform stepping of a chain. - - Parameters - ---------- - c : int - number of this chain - stepper : BlockedStep - a step method such as CompoundStep - secondary_end : multiprocessing.connection.PipeConnection - This is our connection to the main process - """ - # re-seed each child process to make them unique - np.random.seed(None) - try: - stepper = cloudpickle.loads(stepper_dumps) - # the stepper is not necessarily a PopulationArraySharedStep itself, - # but rather a CompoundStep. PopulationArrayStepShared.population - # has to be updated, therefore we identify the substeppers first. - population_steppers = [] - for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]: - if isinstance(sm, PopulationArrayStepShared): - population_steppers.append(sm) - while True: - incoming = secondary_end.recv() - # receiving a None is the signal to exit - if incoming is None: - break - tune_stop, population = incoming - if tune_stop: - stop_tuning(stepper) - # forward the population to the PopulationArrayStepShared objects - # This is necessary because due to the process fork, the population - # object is no longer shared between the steppers. - for popstep in population_steppers: - popstep.population = population - update = stepper.step(population[c]) - secondary_end.send(update) - except Exception: - _log.exception(f"ChainWalker{c}") - return - - def step(self, tune_stop: bool, population): - """Step the entire population of chains. - - Parameters - ---------- - tune_stop : bool - Indicates if the condition (i == tune) is fulfilled - population : list - Current Points of all chains - - Returns - ------- - update : list - List of (Point, stats) tuples for all chains - """ - updates = [None] * self.nchains - if self.is_parallelized: - for c in range(self.nchains): - self._primary_ends[c].send((tune_stop, population)) - # Blockingly get the step outcomes - for c in range(self.nchains): - updates[c] = self._primary_ends[c].recv() - else: - for c in range(self.nchains): - if tune_stop: - self._steppers[c] = stop_tuning(self._steppers[c]) - updates[c] = self._steppers[c].step(population[c]) - return updates - - -def _prepare_iter_population( - draws: int, - step, - start: Sequence[PointType], - parallelize: bool, - tune: int, - model=None, - random_seed: RandomSeed = None, - progressbar=True, -) -> Iterator[Sequence[BaseTrace]]: - """Prepare a PopulationStepper and traces for population sampling. - - Parameters - ---------- - draws : int - The number of samples to draw - step : function - Step function (should be or contain a population step method) - start : list - Start points for each chain - parallelize : bool - Setting for multiprocess parallelization - tune : int - Number of iterations to tune. - model : Model (optional if in ``with`` context) - random_seed : single random seed, optional - progressbar : bool - ``progressbar`` argument for the ``PopulationStepper``, (defaults to True) - - Returns - ------- - _iter_population : generator - Yields traces of all chains at the same time - """ - nchains = len(start) - model = modelcontext(model) - draws = int(draws) - - if draws < 1: - raise ValueError("Argument `draws` should be above 0.") - - if random_seed is not None: - np.random.seed(random_seed) - - # The initialization of traces, samplers and points must happen in the right order: - # 1. population of points is created - # 2. steppers are initialized and linked to the points object - # 3. traces are initialized - # 4. a PopulationStepper is configured for parallelized stepping - - # 1. create a population (points) that tracks each chain - # it is updated as the chains are advanced - population = [start[c] for c in range(nchains)] - - # 2. Set up the steppers - steppers: List[Step] = [] - for c in range(nchains): - # need indepenent samplers for each chain - # it is important to copy the actual steppers (but not the delta_logp) - if isinstance(step, CompoundStep): - chainstep = CompoundStep([copy(m) for m in step.methods]) - else: - chainstep = copy(step) - # link population samplers to the shared population state - for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]: - if isinstance(sm, PopulationArrayStepShared): - sm.link_population(population, c) - steppers.append(chainstep) - - # 3. Initialize a BaseTrace for each chain - traces: List[BaseTrace] = [ - _init_trace( - expected_length=draws + tune, - step=steppers[c], - chain_number=c, - trace=None, - model=model, - ) - for c in range(nchains) - ] - - # 4. configure the PopulationStepper (expensive call) - popstep = PopulationStepper(steppers, parallelize, progressbar=progressbar) - - # Because the preparations above are expensive, the actual iterator is - # in another method. This way the progbar will not be disturbed. - return _iter_population(draws, tune, popstep, steppers, traces, population) - - -def _iter_population( - draws: int, tune: int, popstep: PopulationStepper, steppers, traces: Sequence[BaseTrace], points -) -> Iterator[Sequence[BaseTrace]]: - """Iterate a ``PopulationStepper``. - - Parameters - ---------- - draws : int - number of draws per chain - tune : int - number of tuning steps - popstep : PopulationStepper - the helper object for (parallelized) stepping of chains - steppers : list - The step methods for each chain - traces : list - Traces for each chain - points : list - population of chain states - - Yields - ------ - traces : list - List of trace objects of the individual chains - """ - try: - with popstep: - # iterate draws of all chains - for i in range(draws): - # this call steps all chains and returns a list of (point, stats) - # the `popstep` may interact with subprocesses internally - updates = popstep.step(i == tune, points) - - # apply the update to the points and record to the traces - for c, strace in enumerate(traces): - if steppers[c].generates_stats: - points[c], stats = updates[c] - strace.record(points[c], stats) - log_warning_stats(stats) - else: - points[c] = updates[c] - strace.record(points[c]) - # yield the state of all chains in parallel - yield traces - except KeyboardInterrupt: - for c, strace in enumerate(traces): - strace.close() - if hasattr(steppers[c], "report"): - steppers[c].report._finalize(strace) - raise - except BaseException: - for c, strace in enumerate(traces): - strace.close() - raise - else: - for c, strace in enumerate(traces): - strace.close() - if hasattr(steppers[c], "report"): - steppers[c].report._finalize(strace) - - -def _choose_backend(trace: Optional[Union[BaseTrace, List[str]]], **kwds) -> BaseTrace: - """Selects or creates a NDArray trace backend for a particular chain. - - Parameters - ---------- - trace : BaseTrace, list, or None - This should be a BaseTrace, or list of variables to track. - If None or a list of variables, the NDArray backend is used. - **kwds : - keyword arguments to forward to the backend creation - - Returns - ------- - trace : BaseTrace - The incoming, or a brand new trace object. - """ - if isinstance(trace, BaseTrace) and len(trace) > 0: - raise ValueError("Continuation of traces is no longer supported.") - if isinstance(trace, MultiTrace): - raise ValueError("Starting from existing MultiTrace objects is no longer supported.") - - if isinstance(trace, BaseTrace): - return trace - if trace is None: - return NDArray(**kwds) - - return NDArray(vars=trace, **kwds) - - -def _init_trace( - *, - expected_length: int, - step: Step, - chain_number: int, - trace: Optional[Union[BaseTrace, List[str]]], - model, -) -> BaseTrace: - """Extracted helper function to create trace backends for each chain.""" - if trace is not None: - strace = _choose_backend(copy(trace), model=model) - else: - strace = _choose_backend(None, model=model) - - if step.generates_stats: - strace.setup(expected_length, chain_number, step.stats_dtypes) - else: - strace.setup(expected_length, chain_number) - return strace - - def _mp_sample( draws: int, tune: int, @@ -1360,7 +954,7 @@ def _mp_sample( random_seed: Sequence[RandomSeed], start: Sequence[PointType], progressbar: bool = True, - trace: Optional[Union[BaseTrace, List[str]]] = None, + trace: Optional[BaseTrace] = None, model=None, callback=None, discard_tuned_samples: bool = True, @@ -1388,9 +982,9 @@ def _mp_sample( Dicts must contain numeric (transformed) initial values for all (transformed) free variables. progressbar : bool Whether or not to display a progress bar in the command line. - trace : BaseTrace, list, or None - This should be a backend instance, or a list of variables to track - If None or a list of variables, the NDArray backend is used. + trace : BaseTrace, optional + A backend instance, or None. + If None, the NDArray backend is used. model : Model (optional if in ``with`` context) callback : Callable A function which gets called for every sample from the trace of a chain. The function is @@ -1412,7 +1006,7 @@ def _mp_sample( traces = [ _init_trace( expected_length=draws + tune, - step=step, + stats_dtypes=step.stats_dtypes, chain_number=chain_number, trace=trace, model=model, @@ -1464,55 +1058,6 @@ def _mp_sample( strace.close() -def log_warning_stats(stats: Sequence[Dict[str, Any]]): - """Logs 'warning' stats if present.""" - if stats is None: - return - - for sts in stats: - warn = sts.get("warning", None) - if warn is None: - continue - if isinstance(warn, SamplerWarning): - log_warning(warn) - else: - _log.warning(warn) - return - - -def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTrace], int]: - """ - Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized. - - We get here after a ``KeyboardInterrupt``, and so the different - traces have different lengths. We therefore pick the number of - traces such that (number of traces) * (length of shortest trace) - is maximised. - """ - if not traces: - raise ValueError("No traces to slice.") - - lengths = [max(0, len(trace) - tune) for trace in traces] - if not sum(lengths): - raise ValueError("Not enough samples to build a trace.") - - idxs = np.argsort(lengths) - l_sort = np.array(lengths)[idxs] - - use_until = cast(int, np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1])) - final_length = l_sort[use_until] - - take_idx = cast(Sequence[int], idxs[use_until:]) - sliced_traces = [traces[idx] for idx in take_idx] - return sliced_traces, final_length + tune - - -def stop_tuning(step): - """Stop tuning the current step method.""" - step.stop_tuning() - return step - - def _init_jitter( model: Model, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py new file mode 100644 index 00000000000..96ed3cf6e25 --- /dev/null +++ b/pymc/sampling/population.py @@ -0,0 +1,404 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module specializes on running MCMCs with population step methods.""" + +import logging + +from copy import copy +from typing import Iterator, List, Sequence, Union + +import cloudpickle +import numpy as np + +from fastprogress.fastprogress import progress_bar +from typing_extensions import TypeAlias + +from pymc.backends import _init_trace +from pymc.backends.base import BaseTrace, MultiTrace +from pymc.initial_point import PointType +from pymc.model import modelcontext +from pymc.stats.convergence import log_warning_stats +from pymc.step_methods import CompoundStep +from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared +from pymc.util import RandomSeed + +__all__ = () + + +Step: TypeAlias = Union[BlockedStep, CompoundStep] + + +_log = logging.getLogger("pymc") + + +def _sample_population( + draws: int, + chains: int, + start: Sequence[PointType], + random_seed: RandomSeed, + step, + tune: int, + model, + progressbar: bool = True, + parallelize: bool = False, + **kwargs, +) -> MultiTrace: + """Performs sampling of a population of chains using the ``PopulationStepper``. + + Parameters + ---------- + draws : int + The number of samples to draw + chains : int + The total number of chains in the population + start : list + Start points for each chain + random_seed : single random seed, optional + step : function + Step function (should be or contain a population step method) + tune : int + Number of iterations to tune. + model : Model (optional if in ``with`` context) + progressbar : bool + Show progress bars? (defaults to True) + parallelize : bool + Setting for multiprocess parallelization + + Returns + ------- + trace : MultiTrace + Contains samples of all chains + """ + sampling = _prepare_iter_population( + draws, + step, + start, + parallelize, + tune=tune, + model=model, + random_seed=random_seed, + progressbar=progressbar, + ) + + if progressbar: + sampling = progress_bar(sampling, total=draws, display=progressbar) + + latest_traces = None + for it, traces in enumerate(sampling): + latest_traces = traces + return MultiTrace(latest_traces) + + +class PopulationStepper: + """Wraps population of step methods to step them in parallel with single or multiprocessing.""" + + def __init__(self, steppers, parallelize: bool, progressbar: bool = True): + """Use multiprocessing to parallelize chains. + + Falls back to sequential evaluation if multiprocessing fails. + + In the multiprocessing mode of operation, a new process is started for each + chain/stepper and Pipes are used to communicate with the main process. + + Parameters + ---------- + steppers : list + A collection of independent step methods, one for each chain. + parallelize : bool + Indicates if parallelization via multiprocessing is desired. + progressbar : bool + Should we display a progress bar showing relative progress? + """ + self.nchains = len(steppers) + self.is_parallelized = False + self._primary_ends = [] + self._processes = [] + self._steppers = steppers + if parallelize: + try: + # configure a child process for each stepper + _log.info( + "Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`." + ) + import multiprocessing + + for c, stepper in ( + enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) + ): + secondary_end, primary_end = multiprocessing.Pipe() + stepper_dumps = cloudpickle.dumps(stepper, protocol=4) + process = multiprocessing.Process( + target=self.__class__._run_secondary, + args=(c, stepper_dumps, secondary_end), + name=f"ChainWalker{c}", + ) + # we want the child process to exit if the parent is terminated + process.daemon = True + # Starting the process might fail and takes time. + # By doing it in the constructor, the sampling progress bar + # will not be confused by the process start. + process.start() + self._primary_ends.append(primary_end) + self._processes.append(process) + self.is_parallelized = True + except Exception: + _log.info( + "Population parallelization failed. " + "Falling back to sequential stepping of chains." + ) + _log.debug("Error was: ", exc_info=True) + else: + _log.info( + "Chains are not parallelized. You can enable this by passing " + "`pm.sample(cores=n)`, where n > 1." + ) + return super().__init__() + + def __enter__(self): + """Do nothing: processes are already started in ``__init__``.""" + return + + def __exit__(self, exc_type, exc_val, exc_tb): + if len(self._processes) > 0: + try: + for primary_end in self._primary_ends: + primary_end.send(None) + for process in self._processes: + process.join(timeout=3) + except Exception: + _log.warning("Termination failed.") + return + + @staticmethod + def _run_secondary(c, stepper_dumps, secondary_end): + """This method is started on a separate process to perform stepping of a chain. + + Parameters + ---------- + c : int + number of this chain + stepper : BlockedStep + a step method such as CompoundStep + secondary_end : multiprocessing.connection.PipeConnection + This is our connection to the main process + """ + # re-seed each child process to make them unique + np.random.seed(None) + try: + stepper = cloudpickle.loads(stepper_dumps) + # the stepper is not necessarily a PopulationArraySharedStep itself, + # but rather a CompoundStep. PopulationArrayStepShared.population + # has to be updated, therefore we identify the substeppers first. + population_steppers = [] + for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]: + if isinstance(sm, PopulationArrayStepShared): + population_steppers.append(sm) + while True: + incoming = secondary_end.recv() + # receiving a None is the signal to exit + if incoming is None: + break + tune_stop, population = incoming + if tune_stop: + stepper.stop_tuning() + # forward the population to the PopulationArrayStepShared objects + # This is necessary because due to the process fork, the population + # object is no longer shared between the steppers. + for popstep in population_steppers: + popstep.population = population + update = stepper.step(population[c]) + secondary_end.send(update) + except Exception: + _log.exception(f"ChainWalker{c}") + return + + def step(self, tune_stop: bool, population): + """Step the entire population of chains. + + Parameters + ---------- + tune_stop : bool + Indicates if the condition (i == tune) is fulfilled + population : list + Current Points of all chains + + Returns + ------- + update : list + List of (Point, stats) tuples for all chains + """ + updates = [None] * self.nchains + if self.is_parallelized: + for c in range(self.nchains): + self._primary_ends[c].send((tune_stop, population)) + # Blockingly get the step outcomes + for c in range(self.nchains): + updates[c] = self._primary_ends[c].recv() + else: + for c in range(self.nchains): + if tune_stop: + self._steppers[c].stop_tuning() + updates[c] = self._steppers[c].step(population[c]) + return updates + + +def _prepare_iter_population( + draws: int, + step, + start: Sequence[PointType], + parallelize: bool, + tune: int, + model=None, + random_seed: RandomSeed = None, + progressbar=True, +) -> Iterator[Sequence[BaseTrace]]: + """Prepare a PopulationStepper and traces for population sampling. + + Parameters + ---------- + draws : int + The number of samples to draw + step : function + Step function (should be or contain a population step method) + start : list + Start points for each chain + parallelize : bool + Setting for multiprocess parallelization + tune : int + Number of iterations to tune. + model : Model (optional if in ``with`` context) + random_seed : single random seed, optional + progressbar : bool + ``progressbar`` argument for the ``PopulationStepper``, (defaults to True) + + Returns + ------- + _iter_population : generator + Yields traces of all chains at the same time + """ + nchains = len(start) + model = modelcontext(model) + draws = int(draws) + + if draws < 1: + raise ValueError("Argument `draws` should be above 0.") + + if random_seed is not None: + np.random.seed(random_seed) + + # The initialization of traces, samplers and points must happen in the right order: + # 1. population of points is created + # 2. steppers are initialized and linked to the points object + # 3. traces are initialized + # 4. a PopulationStepper is configured for parallelized stepping + + # 1. create a population (points) that tracks each chain + # it is updated as the chains are advanced + population = [start[c] for c in range(nchains)] + + # 2. Set up the steppers + steppers: List[Step] = [] + for c in range(nchains): + # need indepenent samplers for each chain + # it is important to copy the actual steppers (but not the delta_logp) + if isinstance(step, CompoundStep): + chainstep = CompoundStep([copy(m) for m in step.methods]) + else: + chainstep = copy(step) + # link population samplers to the shared population state + for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]: + if isinstance(sm, PopulationArrayStepShared): + sm.link_population(population, c) + steppers.append(chainstep) + + # 3. Initialize a BaseTrace for each chain + traces: List[BaseTrace] = [ + _init_trace( + expected_length=draws + tune, + stats_dtypes=steppers[c].stats_dtypes, + chain_number=c, + trace=None, + model=model, + ) + for c in range(nchains) + ] + + # 4. configure the PopulationStepper (expensive call) + popstep = PopulationStepper(steppers, parallelize, progressbar=progressbar) + + # Because the preparations above are expensive, the actual iterator is + # in another method. This way the progbar will not be disturbed. + return _iter_population(draws, tune, popstep, steppers, traces, population) + + +def _iter_population( + draws: int, tune: int, popstep: PopulationStepper, steppers, traces: Sequence[BaseTrace], points +) -> Iterator[Sequence[BaseTrace]]: + """Iterate a ``PopulationStepper``. + + Parameters + ---------- + draws : int + number of draws per chain + tune : int + number of tuning steps + popstep : PopulationStepper + the helper object for (parallelized) stepping of chains + steppers : list + The step methods for each chain + traces : list + Traces for each chain + points : list + population of chain states + + Yields + ------ + traces : list + List of trace objects of the individual chains + """ + try: + with popstep: + # iterate draws of all chains + for i in range(draws): + # this call steps all chains and returns a list of (point, stats) + # the `popstep` may interact with subprocesses internally + updates = popstep.step(i == tune, points) + + # apply the update to the points and record to the traces + for c, strace in enumerate(traces): + if steppers[c].generates_stats: + points[c], stats = updates[c] + strace.record(points[c], stats) + log_warning_stats(stats) + else: + points[c] = updates[c] + strace.record(points[c]) + # yield the state of all chains in parallel + yield traces + except KeyboardInterrupt: + for c, strace in enumerate(traces): + strace.close() + if hasattr(steppers[c], "report"): + steppers[c].report._finalize(strace) + raise + except BaseException: + for c, strace in enumerate(traces): + strace.close() + raise + else: + for c, strace in enumerate(traces): + strace.close() + if hasattr(steppers[c], "report"): + steppers[c].report._finalize(strace) diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py index e39beff573f..7d585e9248e 100644 --- a/pymc/stats/convergence.py +++ b/pymc/stats/convergence.py @@ -2,7 +2,7 @@ import enum import logging -from typing import Any, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence import arviz @@ -164,3 +164,19 @@ def log_warning(warn: SamplerWarning): def log_warnings(warnings: Sequence[SamplerWarning]): for warn in warnings: log_warning(warn) + + +def log_warning_stats(stats: Sequence[Dict[str, Any]]): + """Logs 'warning' stats if present.""" + if stats is None: + return + + for sts in stats: + warn = sts.get("warning", None) + if warn is None: + continue + if isinstance(warn, SamplerWarning): + log_warning(warn) + else: + logger.warning(warn) + return diff --git a/pymc/tests/backends/test_base.py b/pymc/tests/backends/test_base.py new file mode 100644 index 00000000000..0e8cb95a027 --- /dev/null +++ b/pymc/tests/backends/test_base.py @@ -0,0 +1,55 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +import pymc as pm + +from pymc.backends import _init_trace +from pymc.backends.base import _choose_chains + + +@pytest.mark.parametrize( + "n_points, tune, expected_length, expected_n_traces", + [ + ((5, 2, 2), 0, 2, 3), + ((6, 1, 1), 1, 6, 1), + ], +) +def test_choose_chains(n_points, tune, expected_length, expected_n_traces): + trace_0 = np.arange(n_points[0]) + trace_1 = np.arange(n_points[1]) + trace_2 = np.arange(n_points[2]) + traces, length = _choose_chains([trace_0, trace_1, trace_2], tune=tune) + assert length == expected_length + assert expected_n_traces == len(traces) + + +class TestInitTrace: + def test_init_trace_continuation_unsupported(self): + with pm.Model() as pmodel: + A = pm.Normal("A") + B = pm.Uniform("B") + strace = pm.backends.ndarray.NDArray(vars=[A, B]) + strace.setup(10, 0) + strace.record({"A": 2, "B_interval__": 0.1}) + assert len(strace) == 1 + with pytest.raises(ValueError, match="Continuation of traces"): + _init_trace( + expected_length=20, + stats_dtypes=pm.Metropolis().stats_dtypes, + chain_number=0, + trace=strace, + model=pmodel, + ) diff --git a/pymc/tests/sampling/test_mcmc.py b/pymc/tests/sampling/test_mcmc.py index 540850bd237..ba8eb287177 100644 --- a/pymc/tests/sampling/test_mcmc.py +++ b/pymc/tests/sampling/test_mcmc.py @@ -31,7 +31,6 @@ import pymc as pm -from pymc.backends.base import MultiTrace from pymc.backends.ndarray import NDArray from pymc.distributions import transforms from pymc.exceptions import SamplingError @@ -484,40 +483,12 @@ def test_empty_model(): error.match("any free variables") -def test_partial_trace_sample(): +def test_partial_trace_unsupported(): with pm.Model() as model: a = pm.Normal("a", mu=0, sigma=1) b = pm.Normal("b", mu=0, sigma=1) - idata = pm.sample(trace=[a]) - assert "a" in idata.posterior - assert "b" not in idata.posterior - - -@pytest.mark.parametrize( - "n_points, tune, expected_length, expected_n_traces", - [ - ((5, 2, 2), 0, 2, 3), - ((6, 1, 1), 1, 6, 1), - ], -) -def test_choose_chains(n_points, tune, expected_length, expected_n_traces): - with pm.Model() as model: - a = pm.Normal("a", mu=0, sigma=1) - trace_0 = NDArray(model) - trace_1 = NDArray(model) - trace_2 = NDArray(model) - trace_0.setup(n_points[0], 1) - trace_1.setup(n_points[1], 1) - trace_2.setup(n_points[2], 1) - for _ in range(n_points[0]): - trace_0.record({"a": 0}) - for _ in range(n_points[1]): - trace_1.record({"a": 0}) - for _ in range(n_points[2]): - trace_2.record({"a": 0}) - traces, length = pm.sampling.mcmc._choose_chains([trace_0, trace_1, trace_2], tune=tune) - assert length == expected_length - assert expected_n_traces == len(traces) + with pytest.raises(DeprecationWarning, match="removed support"): + pm.sample(trace=[a]) @pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32") @@ -573,33 +544,6 @@ def test_constant_named(self): assert np.isclose(res, 0.0) -class TestChooseBackend: - def test_choose_backend_none(self): - with mock.patch("pymc.sampling.mcmc.NDArray") as nd: - pm.sampling.mcmc._choose_backend(None) - assert nd.called - - def test_choose_backend_list_of_variables(self): - with mock.patch("pymc.sampling.mcmc.NDArray") as nd: - pm.sampling.mcmc._choose_backend(["var1", "var2"]) - nd.assert_called_with(vars=["var1", "var2"]) - - def test_errors_and_warnings(self): - with pm.Model(): - A = pm.Normal("A") - B = pm.Uniform("B") - strace = pm.backends.ndarray.NDArray(vars=[A, B]) - strace.setup(10, 0) - - with pytest.raises(ValueError, match="from existing MultiTrace"): - pm.sampling.mcmc._choose_backend(trace=MultiTrace([strace])) - - strace.record({"A": 2, "B_interval__": 0.1}) - assert len(strace) == 1 - with pytest.raises(ValueError, match="Continuation of traces"): - pm.sampling.mcmc._choose_backend(trace=strace) - - def check_exec_nuts_init(method): with pm.Model() as model: pm.Normal("a", mu=0, sigma=1, size=2) @@ -698,29 +642,6 @@ def test_step_args(): npt.assert_allclose(idata1.sample_stats.scaling, 0) -def test_log_warning_stats(caplog): - s1 = dict(warning="Temperature too low!") - s2 = dict(warning="Temperature too high!") - stats = [s1, s2] - - with caplog.at_level(logging.WARNING): - pm.sampling.mcmc.log_warning_stats(stats) - - # We have a list of stats dicts, because there might be several samplers involved. - assert "too low" in caplog.records[0].message - assert "too high" in caplog.records[1].message - - -def test_log_warning_stats_knows_SamplerWarning(caplog): - """Checks that SamplerWarning "warning" stats get special treatment.""" - stats = [dict(warning=SamplerWarning(WarningType.BAD_ENERGY, "Not that interesting", "debug"))] - - with caplog.at_level(logging.DEBUG, logger="pymc"): - pm.sampling.mcmc.log_warning_stats(stats) - - assert "Not that interesting" in caplog.records[0].message - - class ApolypticMetropolis(pm.Metropolis): """A stepper that warns in every iteration.""" diff --git a/pymc/tests/sampling/test_population.py b/pymc/tests/sampling/test_population.py new file mode 100644 index 00000000000..58689858569 --- /dev/null +++ b/pymc/tests/sampling/test_population.py @@ -0,0 +1,91 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import pymc as pm + +from pymc.step_methods.metropolis import DEMetropolis + + +class TestPopulationSamplers: + + steppers = [DEMetropolis] + + def test_checks_population_size(self): + """Test that population samplers check the population size.""" + with pm.Model() as model: + n = pm.Normal("n", mu=0, sigma=1) + for stepper in TestPopulationSamplers.steppers: + step = stepper() + with pytest.raises(ValueError, match="requires at least 3 chains"): + pm.sample(draws=10, tune=10, chains=1, cores=1, step=step) + # don't parallelize to make test faster + pm.sample( + draws=10, + tune=10, + chains=4, + cores=1, + step=step, + compute_convergence_checks=False, + ) + + def test_demcmc_warning_on_small_populations(self): + """Test that a warning is raised when n_chains <= n_dims""" + with pm.Model() as model: + pm.Normal("n", mu=0, sigma=1, size=(2, 3)) + with pytest.warns(UserWarning, match="more chains than dimensions"): + pm.sample( + draws=5, + tune=5, + chains=6, + step=DEMetropolis(), + # make tests faster by not parallelizing; disable convergence warning + cores=1, + compute_convergence_checks=False, + ) + + def test_nonparallelized_chains_are_random(self): + with pm.Model() as model: + x = pm.Normal("x", 0, 1) + for stepper in TestPopulationSamplers.steppers: + step = stepper() + idata = pm.sample( + chains=4, + cores=1, + draws=20, + tune=0, + step=DEMetropolis(), + compute_convergence_checks=False, + ) + samples = idata.posterior["x"].values[:, 5] + + assert len(set(samples)) == 4, f"Parallelized {stepper} chains are identical." + + def test_parallelized_chains_are_random(self): + with pm.Model() as model: + x = pm.Normal("x", 0, 1) + for stepper in TestPopulationSamplers.steppers: + step = stepper() + idata = pm.sample( + chains=4, + cores=4, + draws=20, + tune=0, + step=DEMetropolis(), + compute_convergence_checks=False, + ) + samples = idata.posterior["x"].values[:, 5] + + assert len(set(samples)) == 4, f"Parallelized {stepper} chains are identical." diff --git a/pymc/tests/stats/test_convergence.py b/pymc/tests/stats/test_convergence.py index 796731953a3..d3241d6e01e 100644 --- a/pymc/tests/stats/test_convergence.py +++ b/pymc/tests/stats/test_convergence.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + import arviz import numpy as np @@ -27,3 +29,31 @@ def test_warn_divergences(): warns = convergence.warn_divergences(idata) assert len(warns) == 1 assert "2 divergences after tuning" in warns[0].message + + +def test_log_warning_stats(caplog): + s1 = dict(warning="Temperature too low!") + s2 = dict(warning="Temperature too high!") + stats = [s1, s2] + + with caplog.at_level(logging.WARNING): + convergence.log_warning_stats(stats) + + # We have a list of stats dicts, because there might be several samplers involved. + assert "too low" in caplog.records[0].message + assert "too high" in caplog.records[1].message + + +def test_log_warning_stats_knows_SamplerWarning(caplog): + """Checks that SamplerWarning "warning" stats get special treatment.""" + warn = convergence.SamplerWarning( + convergence.WarningType.BAD_ENERGY, + "Not that interesting", + "debug", + ) + stats = [dict(warning=warn)] + + with caplog.at_level(logging.DEBUG, logger="pymc"): + convergence.log_warning_stats(stats) + + assert "Not that interesting" in caplog.records[0].message diff --git a/pymc/tests/step_methods/test_metropolis.py b/pymc/tests/step_methods/test_metropolis.py index 119ca5c9a67..d4ed99f62a9 100644 --- a/pymc/tests/step_methods/test_metropolis.py +++ b/pymc/tests/step_methods/test_metropolis.py @@ -131,40 +131,7 @@ def test_multinomial_no_elemwise_update(self): assert not step.elemwise_update -class TestPopulationSamplers: - - steppers = [DEMetropolis] - - def test_checks_population_size(self): - """Test that population samplers check the population size.""" - with pm.Model() as model: - n = pm.Normal("n", mu=0, sigma=1) - for stepper in TestPopulationSamplers.steppers: - step = stepper() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - with pytest.raises(ValueError): - pm.sample(draws=10, tune=10, chains=1, cores=1, step=step) - # don't parallelize to make test faster - pm.sample(draws=10, tune=10, chains=4, cores=1, step=step) - - def test_demcmc_warning_on_small_populations(self): - """Test that a warning is raised when n_chains <= n_dims""" - with pm.Model() as model: - pm.Normal("n", mu=0, sigma=1, size=(2, 3)) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - with pytest.warns(UserWarning) as record: - pm.sample( - draws=5, - tune=5, - chains=6, - step=DEMetropolis(), - # make tests faster by not parallelizing; disable convergence warning - cores=1, - compute_convergence_checks=False, - ) - +class TestDEMetropolis: def test_demcmc_tune_parameter(self): """Tests that validity of the tune setting is checked""" with pm.Model() as model: @@ -182,30 +149,6 @@ def test_demcmc_tune_parameter(self): with pytest.raises(ValueError): DEMetropolis(tune="foo") - def test_nonparallelized_chains_are_random(self): - with pm.Model() as model: - x = pm.Normal("x", 0, 1) - for stepper in TestPopulationSamplers.steppers: - step = stepper() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - idata = pm.sample(chains=4, cores=1, draws=20, tune=0, step=DEMetropolis()) - samples = idata.posterior["x"].values[:, 5] - - assert len(set(samples)) == 4, f"Parallelized {stepper} chains are identical." - - def test_parallelized_chains_are_random(self): - with pm.Model() as model: - x = pm.Normal("x", 0, 1) - for stepper in TestPopulationSamplers.steppers: - step = stepper() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - idata = pm.sample(chains=4, cores=4, draws=20, tune=0, step=DEMetropolis()) - samples = idata.posterior["x"].values[:, 5] - - assert len(set(samples)) == 4, f"Parallelized {stepper} chains are identical." - class TestDEMetropolisZ: def test_tuning_lambda_sequential(self): diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 310f26fbdc9..5bd538c7932 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -55,6 +55,7 @@ pymc/sampling/forward.py pymc/sampling/mcmc.py pymc/sampling/parallel.py +pymc/sampling/population.py pymc/smc/__init__.py pymc/smc/sampling.py pymc/smc/kernels.py