From a30b4a4a9cb2bc74a0dc278bc83fec819fd729af Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 3 Apr 2024 11:15:35 -0700 Subject: [PATCH] Replace fastprogress progress bars with rich (#7233) * Replace fastprogress with rich * Bugfixes for ADVI progress bars * Bugfixes for MAP progress bars * Fixed final update to progress bar * SMC progress bar working * Fixes to MAP progress bar * Customize progress bar theme * Added progressbar_theme argument * Moved default progressbar theme to util * Convert compute_log_density to use Progress instead of track * Getting rid of mypy complaint --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-docs.yml | 2 +- conda-envs/environment-jax.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- pymc/sampling/forward.py | 48 +++--- pymc/sampling/mcmc.py | 39 +++-- pymc/sampling/parallel.py | 93 +++++++----- pymc/sampling/population.py | 66 +++++--- pymc/smc/sampling.py | 117 +++++++-------- pymc/stats/log_density.py | 21 +-- pymc/tuning/starting.py | 71 +++++---- pymc/util.py | 8 + pymc/variational/inference.py | 192 +++++++++++++----------- requirements-dev.txt | 2 +- requirements.txt | 2 +- 17 files changed, 372 insertions(+), 299 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index a2cf7c25d3a..a3f4a41a8c5 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 - numpy>=1.15.0 - pandas>=0.24.0 @@ -28,6 +27,7 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 +- rich>=13.7.1 - sphinx-copybutton - sphinx-design - sphinx-notfound-page diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 86227038372..d50328df26c 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -8,12 +8,12 @@ dependencies: - arviz>=0.13.0 - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - numpy>=1.15.0 - pandas>=0.24.0 - pip - pytensor>=2.19,<2.20 - python-graphviz +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for docs build diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 0986f43046e..34379048015 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 # Jaxlib version must not be greater than jax version! - blackjax>=1.0.0 @@ -24,6 +23,7 @@ dependencies: - pytensor>=2.19,<2.20 - python-graphviz - networkx +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for testing diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 8272cca2396..6c0c2a0b61f 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 - jax - libblas=*=*mkl @@ -20,6 +19,7 @@ dependencies: - pytensor>=2.19,<2.20 - python-graphviz - networkx +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for testing diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 25fdeb419ce..91df7bfbac5 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 - numpy>=1.15.0 - pandas>=0.24.0 @@ -17,6 +16,7 @@ dependencies: - pytensor>=2.19,<2.20 - python-graphviz - networkx +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for dev, testing and docs build diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 900e3e227e6..aaa958e985c 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 - libpython - mkl-service>=2.3.0 @@ -20,6 +19,7 @@ dependencies: - pytensor>=2.19,<2.20 - python-graphviz - networkx +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for testing diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 803d6f8ed20..1a8116ecd27 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -30,7 +30,6 @@ import xarray from arviz import InferenceData -from fastprogress.fastprogress import progress_bar from pytensor import tensor as pt from pytensor.graph.basic import ( Apply, @@ -46,6 +45,9 @@ RandomStateSharedVariable, ) from pytensor.tensor.sharedvar import SharedVariable +from rich.console import Console +from rich.progress import Progress +from rich.theme import Theme from typing_extensions import TypeAlias import pymc as pm @@ -59,6 +61,7 @@ RandomState, _get_seeds_per_chain, dataset_to_point_list, + default_progress_theme, get_default_varnames, point_wrapper, ) @@ -70,7 +73,6 @@ "sample_posterior_predictive", ) - ArrayLike: TypeAlias = Union[np.ndarray, list[float]] PointList: TypeAlias = list[PointType] @@ -442,6 +444,7 @@ def sample_posterior_predictive( sample_dims: Optional[list[str]] = None, random_seed: RandomState = None, progressbar: bool = True, + progressbar_theme: Optional[Theme] = default_progress_theme, return_inferencedata: bool = True, extend_inferencedata: bool = False, predictions: bool = False, @@ -796,10 +799,6 @@ def sample_posterior_predictive( else: vars_ = model.observed_RVs + observed_dependent_deterministics(model) - indices = np.arange(samples) - if progressbar: - indices = progress_bar(indices, total=samples, display=progressbar) - vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) if not vars_to_sample: @@ -834,25 +833,30 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - for idx in indices: - if nchain > 1: - # the trace object will either be a MultiTrace (and have _straces)... - if hasattr(_trace, "_straces"): - chain_idx, point_idx = np.divmod(idx, len_trace) - chain_idx = chain_idx % nchain - param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx) - # ... or a PointList + with Progress(console=Console(theme=progressbar_theme)) as progress: + task = progress.add_task("Sampling ...", total=samples, visible=progressbar) + for idx in np.arange(samples): + if nchain > 1: + # the trace object will either be a MultiTrace (and have _straces)... + if hasattr(_trace, "_straces"): + chain_idx, point_idx = np.divmod(idx, len_trace) + chain_idx = chain_idx % nchain + param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx) + # ... or a PointList + else: + param = cast(PointList, _trace)[idx % (len_trace * nchain)] + # there's only a single chain, but the index might hit it multiple times if + # the number of indices is greater than the length of the trace. else: - param = cast(PointList, _trace)[idx % (len_trace * nchain)] - # there's only a single chain, but the index might hit it multiple times if - # the number of indices is greater than the length of the trace. - else: - param = _trace[idx % len_trace] + param = _trace[idx % len_trace] + + values = sampler_fn(**param) + + for k, v in zip(vars_, values): + ppc_trace_t.insert(k.name, v, idx) - values = sampler_fn(**param) + progress.advance(task) - for k, v in zip(vars_, values): - ppc_trace_t.insert(k.name, v, idx) except KeyboardInterrupt: pass diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index dd97f78c884..1241b10b865 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -34,8 +34,10 @@ from arviz import InferenceData, dict_to_dataset from arviz.data.base import make_attrs -from fastprogress.fastprogress import progress_bar from pytensor.graph.basic import Variable +from rich.console import Console +from rich.progress import Progress +from rich.theme import Theme from typing_extensions import Protocol, TypeAlias import pymc as pm @@ -65,6 +67,7 @@ RandomSeed, RandomState, _get_seeds_per_chain, + default_progress_theme, drop_warning_stat, get_untransformed_name, is_transformed_name, @@ -377,6 +380,7 @@ def sample( cores: Optional[int] = None, random_seed: RandomState = None, progressbar: bool = True, + progressbar_theme: Optional[Theme] = default_progress_theme, step=None, var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -406,6 +410,7 @@ def sample( cores: Optional[int] = None, random_seed: RandomState = None, progressbar: bool = True, + progressbar_theme: Optional[Theme] = default_progress_theme, step=None, var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -435,6 +440,7 @@ def sample( cores: Optional[int] = None, random_seed: RandomState = None, progressbar: bool = True, + progressbar_theme: Optional[Theme] = default_progress_theme, step=None, var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -761,6 +767,7 @@ def sample( "tune": tune, "var_names": var_names, "progressbar": progressbar, + "progressbar_theme": progressbar_theme, "model": model, "cores": cores, "callback": callback, @@ -983,6 +990,7 @@ def _sample( trace: IBaseTrace, tune: int, model: Optional[Model] = None, + progressbar_theme: Optional[Theme] = default_progress_theme, callback=None, **kwargs, ) -> None: @@ -1010,6 +1018,8 @@ def _sample( tune : int Number of iterations to tune. model : Model (optional if in ``with`` context) + progressbar_theme : Theme + Optional custom theme for the progress bar. """ skip_first = kwargs.get("skip_first", 0) @@ -1026,19 +1036,16 @@ def _sample( ) _pbar_data = {"chain": chain, "divergences": 0} _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - if progressbar: - sampling = progress_bar(sampling_gen, total=draws, display=progressbar) - sampling.comment = _desc.format(**_pbar_data) - else: - sampling = sampling_gen - try: - for it, diverging in enumerate(sampling): - if it >= skip_first and diverging: - _pbar_data["divergences"] += 1 - if progressbar: - sampling.comment = _desc.format(**_pbar_data) - except KeyboardInterrupt: - pass + with Progress(console=Console(theme=progressbar_theme)) as progress: + try: + task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar) + for it, diverging in enumerate(sampling_gen): + if it >= skip_first and diverging: + _pbar_data["divergences"] += 1 + progress.update(task, advance=1) + progress.update(task, advance=1, completed=True) + except KeyboardInterrupt: + pass def _iter_sample( @@ -1131,6 +1138,7 @@ def _mp_sample( random_seed: Sequence[RandomSeed], start: Sequence[PointType], progressbar: bool = True, + progressbar_theme: Optional[Theme] = default_progress_theme, traces: Sequence[IBaseTrace], model: Optional[Model] = None, callback: Optional[SamplingIteratorCallback] = None, @@ -1158,6 +1166,8 @@ def _mp_sample( Dicts must contain numeric (transformed) initial values for all (transformed) free variables. progressbar : bool Whether or not to display a progress bar in the command line. + progressbar_theme : Theme + Optional custom theme for the progress bar. traces Recording backends for each chain. model : Model (optional if in ``with`` context) @@ -1182,6 +1192,7 @@ def _mp_sample( start_points=start, step_method=step, progressbar=progressbar, + progressbar_theme=progressbar_theme, mp_ctx=mp_ctx, ) try: diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 430c361cac1..29ccc1a0d0e 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -22,15 +22,18 @@ from collections import namedtuple from collections.abc import Sequence +from typing import Optional import cloudpickle import numpy as np -from fastprogress.fastprogress import progress_bar +from rich.console import Console +from rich.progress import BarColumn, Progress, TimeRemainingColumn +from rich.theme import Theme from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError -from pymc.util import RandomSeed +from pymc.util import RandomSeed, default_progress_theme logger = logging.getLogger(__name__) @@ -375,6 +378,7 @@ def __init__( start_points: Sequence[dict[str, np.ndarray]], step_method, progressbar: bool = True, + progressbar_theme: Optional[Theme] = default_progress_theme, mp_ctx=None, ): if any(len(arg) != chains for arg in [seeds, start_points]): @@ -420,14 +424,19 @@ def __init__( self._in_context = False - self._progress = None + self._progress = Progress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeRemainingColumn(), + console=Console(theme=progressbar_theme), + ) + self._show_progress = progressbar self._divergences = 0 - self._total_draws = 0 + self._completed_draws = 0 + self._total_draws = chains * (draws + tune) self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences" self._chains = chains - if progressbar: - self._progress = progress_bar(range(chains * (draws + tune)), display=progressbar) - self._progress.comment = self._desc.format(self) def _make_active(self): while self._inactive and len(self._active) < self._max_active: @@ -441,37 +450,45 @@ def __iter__(self): raise ValueError("Use ParallelSampler as context manager.") self._make_active() - if self._active and self._progress: - self._progress.update(self._total_draws) - - while self._active: - draw = ProcessAdapter.recv_draw(self._active) - proc, is_last, draw, tuning, stats = draw - self._total_draws += 1 - if not tuning and stats and stats[0].get("diverging"): - self._divergences += 1 - if self._progress: - self._progress.comment = self._desc.format(self) - if self._progress: - self._progress.update(self._total_draws) - - if is_last: - proc.join() - self._active.remove(proc) - self._finished.append(proc) - self._make_active() - - # We could also yield proc.shared_point_view directly, - # and only call proc.write_next() after the yield returns. - # This seems to be faster overally though, as the worker - # loses less time waiting. - point = {name: val.copy() for name, val in proc.shared_point_view.items()} - - # Already called for new proc in _make_active - if not is_last: - proc.write_next() - - yield Draw(proc.chain, is_last, draw, tuning, stats, point) + with self._progress as progress: + task = progress.add_task( + self._desc.format(self), + completed=self._completed_draws, + total=self._total_draws, + visible=self._show_progress, + ) + + while self._active: + draw = ProcessAdapter.recv_draw(self._active) + proc, is_last, draw, tuning, stats = draw + self._completed_draws += 1 + if not tuning and stats and stats[0].get("diverging"): + self._divergences += 1 + progress.update( + task, + completed=self._completed_draws, + total=self._total_draws, + description=self._desc.format(self), + ) + + if is_last: + proc.join() + self._active.remove(proc) + self._finished.append(proc) + self._make_active() + progress.update(task, description=self._desc.format(self), refresh=True) + + # We could also yield proc.shared_point_view directly, + # and only call proc.write_next() after the yield returns. + # This seems to be faster overally though, as the worker + # loses less time waiting. + point = {name: val.copy() for name, val in proc.shared_point_view.items()} + + # Already called for new proc in _make_active + if not is_last: + proc.write_next() + + yield Draw(proc.chain, is_last, draw, tuning, stats, point) def __enter__(self): self._in_context = True diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index c38b90599b2..2a0db2ecfa8 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -24,7 +24,7 @@ import cloudpickle import numpy as np -from fastprogress.fastprogress import progress_bar +from rich.progress import BarColumn, Progress, TimeRemainingColumn from typing_extensions import TypeAlias from pymc.backends.base import BaseTrace @@ -101,11 +101,12 @@ def _sample_population( progressbar=progressbar, ) - if progressbar: - sampling = progress_bar(sampling, total=draws, display=progressbar) + with Progress() as progress: + task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar) + + for _ in sampling: + progress.update(task, advance=1) - for i in sampling: - pass return @@ -166,6 +167,7 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): self._primary_ends = [] self._processes = [] self._steppers = steppers + self._progress = None if parallelize: try: # configure a child process for each stepper @@ -174,25 +176,34 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): ) import multiprocessing - for c, stepper in ( - enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) - ): - secondary_end, primary_end = multiprocessing.Pipe() - stepper_dumps = cloudpickle.dumps(stepper, protocol=4) - process = multiprocessing.Process( - target=self.__class__._run_secondary, - args=(c, stepper_dumps, secondary_end), - name=f"ChainWalker{c}", - ) - # we want the child process to exit if the parent is terminated - process.daemon = True - # Starting the process might fail and takes time. - # By doing it in the constructor, the sampling progress bar - # will not be confused by the process start. - process.start() - self._primary_ends.append(primary_end) - self._processes.append(process) - self.is_parallelized = True + with Progress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeRemainingColumn(), + ) as self._progress: + for c, stepper in enumerate(steppers): + # enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) + # ): + task = self._progress.add_task( + description=f"Chain {c}", visible=progressbar + ) + secondary_end, primary_end = multiprocessing.Pipe() + stepper_dumps = cloudpickle.dumps(stepper, protocol=4) + process = multiprocessing.Process( + target=self.__class__._run_secondary, + args=(c, stepper_dumps, secondary_end, task, self._progress), + name=f"ChainWalker{c}", + ) + # we want the child process to exit if the parent is terminated + process.daemon = True + # Starting the process might fail and takes time. + # By doing it in the constructor, the sampling progress bar + # will not be confused by the process start. + process.start() + self._primary_ends.append(primary_end) + self._processes.append(process) + self.is_parallelized = True except Exception: _log.info( "Population parallelization failed. " @@ -222,7 +233,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return @staticmethod - def _run_secondary(c, stepper_dumps, secondary_end): + def _run_secondary(c, stepper_dumps, secondary_end, task, progress): """The method is started on a separate process to perform stepping of a chain. Parameters @@ -233,6 +244,10 @@ def _run_secondary(c, stepper_dumps, secondary_end): a step method such as CompoundStep secondary_end : multiprocessing.connection.PipeConnection This is our connection to the main process + task : progress.Task + The progress task for this chain + progress : progress.Progress + The progress bar """ # re-seed each child process to make them unique np.random.seed(None) @@ -259,6 +274,7 @@ def _run_secondary(c, stepper_dumps, secondary_end): for popstep in population_steppers: popstep.population = population update = stepper.step(population[c]) + progress.advance(task) secondary_end.send(update) except Exception: _log.exception(f"ChainWalker{c}") diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 2ea3800acec..e5129e8fcee 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -13,19 +13,19 @@ # limitations under the License. import logging -import multiprocessing as mp +import multiprocessing import time import warnings from collections import defaultdict -from itertools import repeat +from concurrent.futures import ProcessPoolExecutor from typing import Any, Optional, Union import cloudpickle import numpy as np from arviz import InferenceData -from fastprogress.fastprogress import force_console_behavior, progress_bar +from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn import pymc @@ -209,14 +209,8 @@ def sample_smc( t1 = time.time() - if cores > 1: - results = run_chains_parallel( - chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores - ) - else: - results = run_chains_sequential( - chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs - ) + results = run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores) + ( traces, sample_stats, @@ -310,7 +304,8 @@ def _sample_smc_int( model, random_seed, chain, - progressbar=None, + progress_dict, + task_id, **kernel_kwargs, ): """Run one SMC instance.""" @@ -337,10 +332,6 @@ def _sample_smc_int( **kernel_kwargs, ) - if progressbar: - progressbar.comment = f"{getattr(progressbar, 'base_comment', '')} Stage: 0 Beta: 0" - progressbar.update_bar(getattr(progressbar, "offset", 0) + 0) - smc._initialize_kernel() smc.setup_kernel() @@ -349,11 +340,7 @@ def _sample_smc_int( while smc.beta < 1: smc.update_beta_and_weights() - if progressbar: - progressbar.comment = ( - f"{getattr(progressbar, 'base_comment', '')} Stage: {stage} Beta: {smc.beta:.3f}" - ) - progressbar.update_bar(getattr(progressbar, "offset", 0) + int(smc.beta * 100)) + progress_dict[task_id] = {"stage": stage, "beta": smc.beta} smc.resample() smc.tune() @@ -375,47 +362,47 @@ def _sample_smc_int( return results -def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores): - # fastprogress HTML progress bar does not support multiprocessing - _, progress_bar = force_console_behavior() - pbar = progress_bar((), total=100, display=progressbar) - pbar.update(0) - pbars = [pbar] + [None] * (chains - 1) - - pool = mp.Pool(cores) - - # "manually" (de)serialize params before/after multiprocessing - params = tuple(cloudpickle.dumps(p) for p in params) - kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} - results = _starmap_with_kwargs( - pool, - to_run, - [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)], - repeat(kernel_kwargs), - ) - results = tuple(cloudpickle.loads(r) for r in results) - pool.close() - pool.join() - return results - - -def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs): - results = [] - pbar = progress_bar((), total=100 * chains, display=progressbar) - pbar.update(0) - for chain in range(chains): - pbar.offset = 100 * chain - pbar.base_comment = f"Chain: {chain + 1}/{chains}" - results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs)) - return results - - -def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter): - # Helper function to allow kwargs with Pool.starmap - # Copied from https://stackoverflow.com/a/53173433/13311693 - args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) - return pool.starmap(_apply_args_and_kwargs, args_for_starmap) - - -def _apply_args_and_kwargs(fn, args, kwargs): - return fn(*args, **kwargs) +def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): + with Progress( + TextColumn("{task.description}"), + SpinnerColumn(), + TimeElapsedColumn(), + TextColumn("{task.fields[status]}"), + ) as progress: + futures = [] # keep track of the jobs + with multiprocessing.Manager() as manager: + # this is the key - we share some state between our + # main process and our worker functions + _progress = manager.dict() + + # "manually" (de)serialize params before/after multiprocessing + params = tuple(cloudpickle.dumps(p) for p in params) + kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} + + with ProcessPoolExecutor(max_workers=cores) as executor: + for c in range(chains): # iterate over the jobs we need to run + # set visible false so we don't have a lot of bars all at once: + task_id = progress.add_task( + f"Chain {c}", status="Stage: 0 Beta: 0", visible=progressbar + ) + futures.append( + executor.submit( + _sample_smc_int, + *params, + random_seed[c], + c, + _progress, + task_id, + **kernel_kwargs, + ) + ) + + # monitor the progress: + while sum([future.done() for future in futures]) < len(futures): + for task_id, update_data in _progress.items(): + stage = update_data["stage"] + beta = update_data["beta"] + # update the progress bar for this task: + progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id) + + return tuple(cloudpickle.loads(r.result()) for r in futures) diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py index e72e2445799..daf172342f4 100644 --- a/pymc/stats/log_density.py +++ b/pymc/stats/log_density.py @@ -15,14 +15,15 @@ from typing import Optional, cast from arviz import InferenceData, dict_to_dataset -from fastprogress import progress_bar +from rich.console import Console +from rich.progress import Progress import pymc from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata from pymc.model import Model, modelcontext from pymc.pytensorf import PointFunc -from pymc.util import dataset_to_point_list +from pymc.util import dataset_to_point_list, default_progress_theme __all__ = ("compute_log_likelihood", "compute_log_prior") @@ -169,14 +170,14 @@ def compute_log_density( n_pts = len(posterior_pts) logdens_dict = _DefaultTrace(n_pts) - indices = range(n_pts) - if progressbar: - indices = progress_bar(indices, total=n_pts, display=progressbar) - - for idx in indices: - logdenss_pts = elemwise_logdens_fn(posterior_pts[idx]) - for rv_name, rv_logdens in zip(var_names, logdenss_pts): - logdens_dict.insert(rv_name, rv_logdens, idx) + + with Progress(console=Console(theme=default_progress_theme)) as progress: + task = progress.add_task("Computing log density...", total=n_pts, visible=progressbar) + for idx in range(n_pts): + logdenss_pts = elemwise_logdens_fn(posterior_pts[idx]) + for rv_name, rv_logdens in zip(var_names, logdenss_pts): + logdens_dict.insert(rv_name, rv_logdens, idx) + progress.update(task, advance=1) logdens_trace = logdens_dict.trace_dict for key, array in logdens_trace.items(): diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 8db885057b8..90f56d19bee 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -27,9 +27,10 @@ import numpy as np import pytensor.gradient as tg -from fastprogress.fastprogress import ProgressBar, progress_bar from numpy import isfinite from pytensor import Variable +from rich.console import Console +from rich.progress import Progress, TextColumn from scipy.optimize import minimize import pymc as pm @@ -37,7 +38,7 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext -from pymc.util import get_default_varnames, get_value_vars_from_user_vars +from pymc.util import default_progress_theme, get_default_varnames, get_value_vars_from_user_vars from pymc.vartypes import discrete_types, typefilter __all__ = ["find_MAP"] @@ -50,6 +51,7 @@ def find_MAP( return_raw=False, include_transformed=True, progressbar=True, + progressbar_theme=default_progress_theme, maxeval=5000, model=None, *args, @@ -82,6 +84,8 @@ def find_MAP( to the constrained values progressbar: bool, optional defaults to True Whether to display a progress bar in the command line. + progressbar_theme: Theme, optional + Custom theme for the progress bar. maxeval: int, optional, defaults to 5000 The maximum number of times the posterior distribution is evaluated. model: Model (optional if in `with` context) @@ -159,26 +163,23 @@ def find_MAP( method = "Powell" if compute_gradient and method != "Powell": - cost_func = CostFuncWrapper(maxeval, progressbar, logp_func, dlogp_func) + cost_func = CostFuncWrapper(maxeval, progressbar, progressbar_theme, logp_func, dlogp_func) else: - cost_func = CostFuncWrapper(maxeval, progressbar, logp_func) + cost_func = CostFuncWrapper(maxeval, progressbar, progressbar_theme, logp_func) compute_gradient = False - try: - opt_result = minimize( - cost_func, x0.data, method=method, jac=compute_gradient, *args, **kwargs - ) - mx0 = opt_result["x"] # r -> opt_result - except (KeyboardInterrupt, StopIteration) as e: - mx0, opt_result = cost_func.previous_x, None - if isinstance(e, StopIteration): - pm._log.info(e) - finally: - last_v = cost_func.n_eval - if progressbar: - assert isinstance(cost_func.progress, ProgressBar) - cost_func.progress.total = last_v - cost_func.progress.update(last_v) + with cost_func.progress: + try: + opt_result = minimize( + cost_func, x0.data, method=method, jac=compute_gradient, *args, **kwargs + ) + mx0 = opt_result["x"] # r -> opt_result + except (KeyboardInterrupt, StopIteration) as e: + mx0, opt_result = cost_func.previous_x, None + if isinstance(e, StopIteration): + pm._log.info(e) + finally: + cost_func.progress.update(cost_func.task, completed=cost_func.n_eval) print(file=sys.stdout) mx0 = RaveledVars(mx0, x0.point_map_info) @@ -199,7 +200,14 @@ def allfinite(x): class CostFuncWrapper: - def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=None): + def __init__( + self, + maxeval=5000, + progressbar=True, + progressbar_theme=default_progress_theme, + logp_func=None, + dlogp_func=None, + ): self.n_eval = 0 self.maxeval = maxeval self.logp_func = logp_func @@ -212,11 +220,12 @@ def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=No self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}" self.previous_x = None self.progressbar = progressbar - if progressbar: - self.progress = progress_bar(range(maxeval), total=maxeval, display=progressbar) - self.progress.update(0) - else: - self.progress = range(maxeval) + self.progress = Progress( + *Progress.get_default_columns(), + TextColumn("{task.fields[loss]}"), + console=Console(theme=progressbar_theme), + ) + self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="") def __call__(self, x): neg_value = np.float64(self.logp_func(pm.floatX(x))) @@ -232,16 +241,14 @@ def __call__(self, x): grad = None if self.n_eval % 10 == 0: - self.update_progress_desc(neg_value, grad) + self.progress.update(self.task, loss=self.update_progress_desc(neg_value, grad)) if self.n_eval > self.maxeval: - self.update_progress_desc(neg_value, grad) + self.progress.update(self.task, loss=self.update_progress_desc(neg_value, grad)) raise StopIteration self.n_eval += 1 - if self.progressbar: - assert isinstance(self.progress, ProgressBar) - self.progress.update_bar(self.n_eval) + self.progress.update(self.task, completed=self.n_eval) if self.use_gradient: return value, grad @@ -251,7 +258,7 @@ def __call__(self, x): def update_progress_desc(self, neg_value: float, grad: np.float64 = None) -> None: if self.progressbar: if grad is None: - self.progress.comment = self.desc.format(neg_value) + return self.desc.format(neg_value) else: norm_grad = np.linalg.norm(grad) - self.progress.comment = self.desc.format(neg_value, norm_grad) + return self.desc.format(neg_value, norm_grad) diff --git a/pymc/util.py b/pymc/util.py index 8388a8ed49b..b72e17e0ae5 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -27,11 +27,19 @@ from pytensor import Variable from pytensor.compile import SharedVariable from pytensor.graph.utils import ValidatingScratchpad +from rich.theme import Theme from pymc.exceptions import BlockModelAccessError VarName = NewType("VarName", str) +default_progress_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + } +) + class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs.""" diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 6ee5815d145..3d9e6fd8eae 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -18,10 +18,12 @@ import numpy as np -from fastprogress.fastprogress import progress_bar +from rich.console import Console +from rich.progress import Progress, TextColumn, track import pymc as pm +from pymc.util import default_progress_theme from pymc.variational import test_functions from pymc.variational.approximations import Empirical, FullRank, MeanField from pymc.variational.operators import KL, KSD @@ -83,15 +85,22 @@ def run_profiling(self, n=1000, score=None, **kwargs): fn_kwargs = kwargs.pop("fn_kwargs", dict()) fn_kwargs["profile"] = True step_func = self.objective.step_function(score=score, fn_kwargs=fn_kwargs, **kwargs) - progress = progress_bar(range(n)) try: - for _ in progress: + for _ in track(range(n)): step_func() except KeyboardInterrupt: pass return step_func.profile - def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): + def fit( + self, + n=10000, + score=None, + callbacks=None, + progressbar=True, + progressbar_theme=default_progress_theme, + **kwargs, + ): """Perform Operator Variational Inference Parameters @@ -104,6 +113,8 @@ def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): calls provided functions after each iteration step progressbar : bool whether to show progressbar or not + progressbar_theme : Theme + Custom theme for the progress bar Other Parameters ---------------- @@ -136,14 +147,15 @@ def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): callbacks = [] score = self._maybe_score(score) step_func = self.objective.step_function(score=score, **kwargs) - if progressbar: - progress = progress_bar(range(n), display=progressbar) - else: - progress = range(n) + if score: - state = self._iterate_with_loss(0, n, step_func, progress, callbacks) + state = self._iterate_with_loss( + 0, n, step_func, progressbar, progressbar_theme, callbacks + ) else: - state = self._iterate_without_loss(0, n, step_func, progress, callbacks) + state = self._iterate_without_loss( + 0, n, step_func, progressbar, progressbar_theme, callbacks + ) # hack to allow pm.fit() access to loss hist self.approx.hist = self.hist @@ -151,43 +163,46 @@ def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): return self.approx - def _iterate_without_loss(self, s, _, step_func, progress, callbacks): + def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks): i = 0 try: - for i in progress: - step_func() - current_param = self.approx.params[0].get_value() - if np.isnan(current_param).any(): - name_slc = [] - tmp_hold = list(range(current_param.size)) - for varname, slice_info in self.approx.groups[0].ordering.items(): - slclen = len(tmp_hold[slice_info[1]]) - for j in range(slclen): - name_slc.append((varname, j)) - index = np.where(np.isnan(current_param))[0] - errmsg = ["NaN occurred in optimization. "] - suggest_solution = ( - "Try tracking this parameter: " - "http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters" - ) - try: - for ii in index: - errmsg.append( - "The current approximation of RV `{}`.ravel()[{}]" - " is NaN.".format(*name_slc[ii]) - ) - errmsg.append(suggest_solution) - except IndexError: - pass - raise FloatingPointError("\n".join(errmsg)) - for callback in callbacks: - callback(self.approx, None, i + s + 1) + with Progress(console=Console(theme=progressbar_theme)) as progress: + task = progress.add_task("Fitting", total=n, visible=progressbar) + for i in range(n): + step_func() + progress.update(task, advance=1) + current_param = self.approx.params[0].get_value() + if np.isnan(current_param).any(): + name_slc = [] + tmp_hold = list(range(current_param.size)) + for varname, slice_info in self.approx.groups[0].ordering.items(): + slclen = len(tmp_hold[slice_info[1]]) + for j in range(slclen): + name_slc.append((varname, j)) + index = np.where(np.isnan(current_param))[0] + errmsg = ["NaN occurred in optimization. "] + suggest_solution = ( + "Try tracking this parameter: " + "http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters" + ) + try: + for ii in index: + errmsg.append( + "The current approximation of RV `{}`.ravel()[{}]" + " is NaN.".format(*name_slc[ii]) + ) + errmsg.append(suggest_solution) + except IndexError: + pass + raise FloatingPointError("\n".join(errmsg)) + for callback in callbacks: + callback(self.approx, None, i + s + 1) except (KeyboardInterrupt, StopIteration) as e: if isinstance(e, StopIteration): logger.info(str(e)) return State(i + s, step=step_func, callbacks=callbacks, score=False) - def _iterate_with_loss(self, s, n, step_func, progress, callbacks): + def _iterate_with_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks): def _infmean(input_array): """Return the mean of the finite values of the array""" input_array = input_array[np.isfinite(input_array)].astype("float64") @@ -200,44 +215,49 @@ def _infmean(input_array): scores[:] = np.nan i = 0 try: - for i in progress: - e = step_func() - if np.isnan(e): - scores = scores[:i] - self.hist = np.concatenate([self.hist, scores]) - current_param = self.approx.params[0].get_value() - name_slc = [] - tmp_hold = list(range(current_param.size)) - for varname, slice_info in self.approx.groups[0].ordering.items(): - slclen = len(tmp_hold[slice_info[1]]) - for j in range(slclen): - name_slc.append((varname, j)) - index = np.where(np.isnan(current_param))[0] - errmsg = ["NaN occurred in optimization. "] - suggest_solution = ( - "Try tracking this parameter: " - "http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters" - ) - try: - for ii in index: - errmsg.append( - "The current approximation of RV `{}`.ravel()[{}]" - " is NaN.".format(*name_slc[ii]) - ) - errmsg.append(suggest_solution) - except IndexError: - pass - raise FloatingPointError("\n".join(errmsg)) - scores[i] = e - if i % 10 == 0: - avg_loss = _infmean(scores[max(0, i - 1000) : i + 1]) - if hasattr(progress, "comment"): - progress.comment = f"Average Loss = {avg_loss:,.5g}" - avg_loss = scores[max(0, i - 1000) : i + 1].mean() - if hasattr(progress, "comment"): - progress.comment = f"Average Loss = {avg_loss:,.5g}" - for callback in callbacks: - callback(self.approx, scores[: i + 1], i + s + 1) + with Progress( + *Progress.get_default_columns(), + TextColumn("{task.fields[loss]}"), + console=Console(theme=progressbar_theme), + ) as progress: + task = progress.add_task("Fitting:", total=n, visible=progressbar, loss="") + for i in range(n): + e = step_func() + progress.update(task, advance=1) + if np.isnan(e): + scores = scores[:i] + self.hist = np.concatenate([self.hist, scores]) + current_param = self.approx.params[0].get_value() + name_slc = [] + tmp_hold = list(range(current_param.size)) + for varname, slice_info in self.approx.groups[0].ordering.items(): + slclen = len(tmp_hold[slice_info[1]]) + for j in range(slclen): + name_slc.append((varname, j)) + index = np.where(np.isnan(current_param))[0] + errmsg = ["NaN occurred in optimization. "] + suggest_solution = ( + "Try tracking this parameter: " + "http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters" + ) + try: + for ii in index: + errmsg.append( + "The current approximation of RV `{}`.ravel()[{}]" + " is NaN.".format(*name_slc[ii]) + ) + errmsg.append(suggest_solution) + except IndexError: + pass + raise FloatingPointError("\n".join(errmsg)) + scores[i] = e + if i % 10 == 0: + avg_loss = _infmean(scores[max(0, i - 1000) : i + 1]) + progress.update(task, loss=f"Average Loss = {avg_loss:,.5g}") + avg_loss = scores[max(0, i - 1000) : i + 1].mean() + progress.update(task, loss=f"Average Loss = {avg_loss:,.5g}") + for callback in callbacks: + callback(self.approx, scores[: i + 1], i + s + 1) except (KeyboardInterrupt, StopIteration) as e: # pragma: no cover # do not print log on the same line scores = scores[:i] @@ -261,19 +281,17 @@ def _infmean(input_array): self.hist = np.concatenate([self.hist, scores]) return State(i + s, step=step_func, callbacks=callbacks, score=True) - def refine(self, n, progressbar=True): + def refine(self, n, progressbar=True, progressbar_theme=default_progress_theme): """Refine the solution using the last compiled step function""" if self.state is None: raise TypeError("Need to call `.fit` first") i, step, callbacks, score = self.state - if progressbar: - progress = progress_bar(range(n), display=progressbar) - else: - progress = range(n) # This is a guess at what progress_bar(n) does. if score: - state = self._iterate_with_loss(i, n, step, progress, callbacks) + state = self._iterate_with_loss(i, n, step, progressbar, progressbar_theme, callbacks) else: - state = self._iterate_without_loss(i, n, step, progress, callbacks) + state = self._iterate_without_loss( + i, n, step, progressbar, progressbar_theme, callbacks + ) self.state = state @@ -630,6 +648,7 @@ def fit( score=None, callbacks=None, progressbar=True, + progressbar_theme=default_progress_theme, obj_n_mc=500, **kwargs, ): @@ -638,6 +657,7 @@ def fit( score=score, callbacks=callbacks, progressbar=progressbar, + progressbar_theme=progressbar_theme, obj_n_mc=obj_n_mc, **kwargs, ) @@ -688,6 +708,8 @@ def fit( calls provided functions after each iteration step progressbar: bool whether to show progressbar or not + progressbar_theme: Theme + Custom theme for the progress bar obj_n_mc: `int` Number of monte carlo samples used for approximation of objective gradients tf_n_mc: `int` diff --git a/requirements-dev.txt b/requirements-dev.txt index ddcf9ded9b9..56077f3a6b1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,7 +4,6 @@ arviz>=0.13.0 cachetools>=4.2.1 cloudpickle -fastprogress>=0.2.0 git+https://github.com/pymc-devs/pymc-sphinx-theme h5py>=2.7 ipython>=7.16 @@ -21,6 +20,7 @@ pre-commit>=2.8.0 pytensor>=2.19,<2.20 pytest-cov>=2.5 pytest>=3.0 +rich>=13.7.1 scipy>=1.4.1 sphinx-copybutton sphinx-design diff --git a/requirements.txt b/requirements.txt index 0bc21049c15..370dcbd41e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ arviz>=0.13.0 cachetools>=4.2.1 cloudpickle -fastprogress>=0.2.0 numpy>=1.15.0 pandas>=0.24.0 pytensor>=2.19,<2.20 +rich>=13.7.1 scipy>=1.4.1 typing-extensions>=3.7.4