From 7c9a8387459ec3ecb4b1061570e1640f508e9708 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sun, 21 Apr 2024 17:04:28 -0500 Subject: [PATCH 1/9] Add time remaining column to progress bars --- pymc/sampling/parallel.py | 4 +++- pymc/sampling/population.py | 4 +++- pymc/smc/sampling.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index cc6908647ec..9505daef70c 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -27,7 +27,7 @@ import numpy as np from rich.console import Console -from rich.progress import BarColumn, Progress, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn, TimeElapsedColumn from rich.theme import Theme from pymc.blocking import DictToArrayBijection @@ -428,6 +428,8 @@ def __init__( BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), + TextColumn("/"), + TimeElapsedColumn() console=Console(theme=progressbar_theme), ) self._show_progress = progressbar diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 1627bb8de77..660b9e60e9d 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -24,7 +24,7 @@ import cloudpickle import numpy as np -from rich.progress import BarColumn, Progress, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn, TimeElaspedColumn from pymc.backends.base import BaseTrace from pymc.initial_point import PointType @@ -180,6 +180,8 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), + TextColumn("/"), + TimeElaspedColumn(), ) as self._progress: for c, stepper in enumerate(steppers): # enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index d9b76f211ce..2b83cd2cfac 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -25,7 +25,7 @@ import numpy as np from arviz import InferenceData -from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn +from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn import pymc @@ -367,6 +367,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): TextColumn("{task.description}"), SpinnerColumn(), TimeElapsedColumn(), + TextColumn("/"), + TimeRemainingColumn(), TextColumn("{task.fields[status]}"), ) as progress: futures = [] # keep track of the jobs From e5d84ba66c2c61c54da530fd15827d9af2863503 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sun, 21 Apr 2024 17:10:33 -0500 Subject: [PATCH 2/9] Consistent order remaining/elapsed --- pymc/smc/sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 2b83cd2cfac..6166d0d22e4 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -366,9 +366,9 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): with Progress( TextColumn("{task.description}"), SpinnerColumn(), - TimeElapsedColumn(), - TextColumn("/"), TimeRemainingColumn(), + TextColumn("/"), + TimeElapsedColumn(), TextColumn("{task.fields[status]}"), ) as progress: futures = [] # keep track of the jobs From b04ebdb3459bbde2dea526ecdd40ff599319ee11 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 23 Apr 2024 08:45:35 -0500 Subject: [PATCH 3/9] Disable sample_posterior_predictive taskbar when progressbar=False --- pymc/sampling/forward.py | 2 +- pymc/sampling/parallel.py | 2 +- pymc/sampling/population.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index fe0f2085bb0..dc89ebc7035 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -829,7 +829,7 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - with Progress(console=Console(theme=progressbar_theme)) as progress: + with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress: task = progress.add_task("Sampling ...", total=samples, visible=progressbar) for idx in np.arange(samples): if nchain > 1: diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 9505daef70c..c32ed23ef35 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -429,7 +429,7 @@ def __init__( "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), TextColumn("/"), - TimeElapsedColumn() + TimeElapsedColumn(), console=Console(theme=progressbar_theme), ) self._show_progress = progressbar diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 660b9e60e9d..69cb4793726 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -24,7 +24,7 @@ import cloudpickle import numpy as np -from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn, TimeElaspedColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn, TimeElapsedColumn from pymc.backends.base import BaseTrace from pymc.initial_point import PointType @@ -181,7 +181,7 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), TextColumn("/"), - TimeElaspedColumn(), + TimeElapsedColumn(), ) as self._progress: for c, stepper in enumerate(steppers): # enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) From 81c4136310badfc7fbbbe0b60bc73e5fc142bcce Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 23 Apr 2024 08:47:50 -0500 Subject: [PATCH 4/9] Formatting --- pymc/smc/sampling.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 6166d0d22e4..33c8f937c71 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -25,7 +25,13 @@ import numpy as np from arviz import InferenceData -from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn +from rich.progress import ( + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) import pymc From 0d62304d04250457ea2910f12c2a5ff05584df1f Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 23 Apr 2024 08:54:41 -0500 Subject: [PATCH 5/9] More formatting --- pymc/sampling/parallel.py | 2 +- pymc/sampling/population.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index c32ed23ef35..97057e424aa 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -27,7 +27,7 @@ import numpy as np from rich.console import Console -from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn, TimeElapsedColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme from pymc.blocking import DictToArrayBijection diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 69cb4793726..339f3d419f6 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -24,7 +24,7 @@ import cloudpickle import numpy as np -from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn, TimeElapsedColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn from pymc.backends.base import BaseTrace from pymc.initial_point import PointType From bc1c489cb471c13a9c12b99c67b5fb17c45a0e8e Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 23 Apr 2024 08:59:27 -0500 Subject: [PATCH 6/9] More formatting (why doesnt pre-commit fix this?) --- pymc/sampling/forward.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index dc89ebc7035..cc1dcd52f6e 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -829,7 +829,9 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress: + with Progress( + console=Console(theme=progressbar_theme), disable=not progressbar + ) as progress: task = progress.add_task("Sampling ...", total=samples, visible=progressbar) for idx in np.arange(samples): if nchain > 1: From 6af84cf8275abbdab8e54dbddaffe46525ecec31 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 24 Apr 2024 10:16:45 -0500 Subject: [PATCH 7/9] Disable progress bar when progress=False --- pymc/backends/arviz.py | 2 +- pymc/sampling/parallel.py | 1 + pymc/tuning/starting.py | 1 + pymc/variational/inference.py | 4 +++- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 3af240725fd..5ead7ac1bde 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -659,7 +659,7 @@ def apply_function_over_dataset( out_dict = _DefaultTrace(n_pts) indices = range(n_pts) - with Progress(console=Console(theme=progressbar_theme)) as progress: + with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress: task = progress.add_task("Computinng ...", total=n_pts, visible=progressbar) for idx in indices: out = fn(posterior_pts[idx]) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 97057e424aa..05b34b4f0d0 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -431,6 +431,7 @@ def __init__( TextColumn("/"), TimeElapsedColumn(), console=Console(theme=progressbar_theme), + disable=not progressbar, ) self._show_progress = progressbar self._divergences = 0 diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 129d6f89730..44d77810f64 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -223,6 +223,7 @@ def __init__( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), console=Console(theme=progressbar_theme), + disable=not progressbar, ) self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="") diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 3d9e6fd8eae..3a9a69add72 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -166,7 +166,9 @@ def fit( def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks): i = 0 try: - with Progress(console=Console(theme=progressbar_theme)) as progress: + with Progress( + console=Console(theme=progressbar_theme), disable=not progressbar + ) as progress: task = progress.add_task("Fitting", total=n, visible=progressbar) for i in range(n): step_func() From 0ee2396728683ed8d32e14c1c610c68abbc61297 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Thu, 25 Apr 2024 11:46:07 -0500 Subject: [PATCH 8/9] Set refresh flag in progress bar updates --- pymc/sampling/mcmc.py | 4 ++-- pymc/sampling/parallel.py | 1 + pymc/sampling/population.py | 2 +- pymc/smc/sampling.py | 4 +++- pymc/tuning/starting.py | 2 +- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 7f750090f81..96190e5ff03 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1041,8 +1041,8 @@ def _sample( for it, diverging in enumerate(sampling_gen): if it >= skip_first and diverging: _pbar_data["divergences"] += 1 - progress.update(task, advance=1) - progress.update(task, advance=1, completed=True) + progress.update(task, refresh=True, advance=1) + progress.update(task, refresh=True, advance=1, completed=True) except KeyboardInterrupt: pass diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 05b34b4f0d0..c2f9791de5c 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -468,6 +468,7 @@ def __iter__(self): self._divergences += 1 progress.update( task, + refresh=True, completed=self._completed_draws, total=self._total_draws, description=self._desc.format(self), diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 339f3d419f6..2b0aad2b32a 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -104,7 +104,7 @@ def _sample_population( task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar) for _ in sampling: - progress.update(task, advance=1) + progress.update(task, advance=1, refresh=True) return diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 33c8f937c71..db4044a4fe9 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -411,6 +411,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): stage = update_data["stage"] beta = update_data["beta"] # update the progress bar for this task: - progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id) + progress.update( + status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id, refresh=True + ) return tuple(cloudpickle.loads(r.result()) for r in futures) diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 44d77810f64..09b787c506d 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -178,7 +178,7 @@ def find_MAP( if isinstance(e, StopIteration): pm._log.info(e) finally: - cost_func.progress.update(cost_func.task, completed=cost_func.n_eval) + cost_func.progress.update(cost_func.task, completed=cost_func.n_eval, refresh=True) print(file=sys.stdout) mx0 = RaveledVars(mx0, x0.point_map_info) From 1964c8ab96ee31a8e25a3fa8df0c19330d5ff9c6 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Fri, 26 Apr 2024 07:51:10 -0500 Subject: [PATCH 9/9] Typo --- pymc/backends/arviz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 5ead7ac1bde..40ab7d3ae0e 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -660,7 +660,7 @@ def apply_function_over_dataset( indices = range(n_pts) with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress: - task = progress.add_task("Computinng ...", total=n_pts, visible=progressbar) + task = progress.add_task("Computing ...", total=n_pts, visible=progressbar) for idx in indices: out = fn(posterior_pts[idx]) fn.f.trust_input = True # If we arrive here the dtypes are valid