Skip to content

Commit

Permalink
Replace fastprogress progress bars with rich (pymc-devs#7233)
Browse files Browse the repository at this point in the history
* Replace fastprogress with rich

* Bugfixes for ADVI progress bars

* Bugfixes for MAP progress bars

* Fixed final update to progress bar

* SMC progress bar working

* Fixes to MAP progress bar

* Customize progress bar theme

* Added progressbar_theme argument

* Moved default progressbar theme to util

* Convert compute_log_density to use Progress instead of track

* Getting rid of mypy complaint
  • Loading branch information
fonnesbeck authored and mkusnetsov committed Oct 26, 2024
1 parent ac1be86 commit a30b4a4
Show file tree
Hide file tree
Showing 17 changed files with 372 additions and 299 deletions.
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- numpy>=1.15.0
- pandas>=0.24.0
Expand All @@ -28,6 +27,7 @@ dependencies:
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- rich>=13.7.1
- sphinx-copybutton
- sphinx-design
- sphinx-notfound-page
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ dependencies:
- arviz>=0.13.0
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.19,<2.20
- python-graphviz
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for docs build
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
# Jaxlib version must not be greater than jax version!
- blackjax>=1.0.0
Expand All @@ -24,6 +23,7 @@ dependencies:
- pytensor>=2.19,<2.20
- python-graphviz
- networkx
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- jax
- libblas=*=*mkl
Expand All @@ -20,6 +19,7 @@ dependencies:
- pytensor>=2.19,<2.20
- python-graphviz
- networkx
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.19,<2.20
- python-graphviz
- networkx
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for dev, testing and docs build
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- libpython
- mkl-service>=2.3.0
Expand All @@ -20,6 +19,7 @@ dependencies:
- pytensor>=2.19,<2.20
- python-graphviz
- networkx
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
Expand Down
48 changes: 26 additions & 22 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import xarray

from arviz import InferenceData
from fastprogress.fastprogress import progress_bar
from pytensor import tensor as pt
from pytensor.graph.basic import (
Apply,
Expand All @@ -46,6 +45,9 @@
RandomStateSharedVariable,
)
from pytensor.tensor.sharedvar import SharedVariable
from rich.console import Console
from rich.progress import Progress
from rich.theme import Theme
from typing_extensions import TypeAlias

import pymc as pm
Expand All @@ -59,6 +61,7 @@
RandomState,
_get_seeds_per_chain,
dataset_to_point_list,
default_progress_theme,
get_default_varnames,
point_wrapper,
)
Expand All @@ -70,7 +73,6 @@
"sample_posterior_predictive",
)


ArrayLike: TypeAlias = Union[np.ndarray, list[float]]
PointList: TypeAlias = list[PointType]

Expand Down Expand Up @@ -442,6 +444,7 @@ def sample_posterior_predictive(
sample_dims: Optional[list[str]] = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
return_inferencedata: bool = True,
extend_inferencedata: bool = False,
predictions: bool = False,
Expand Down Expand Up @@ -796,10 +799,6 @@ def sample_posterior_predictive(
else:
vars_ = model.observed_RVs + observed_dependent_deterministics(model)

indices = np.arange(samples)
if progressbar:
indices = progress_bar(indices, total=samples, display=progressbar)

vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))

if not vars_to_sample:
Expand Down Expand Up @@ -834,25 +833,30 @@ def sample_posterior_predictive(
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)
try:
for idx in indices:
if nchain > 1:
# the trace object will either be a MultiTrace (and have _straces)...
if hasattr(_trace, "_straces"):
chain_idx, point_idx = np.divmod(idx, len_trace)
chain_idx = chain_idx % nchain
param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
# ... or a PointList
with Progress(console=Console(theme=progressbar_theme)) as progress:
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
for idx in np.arange(samples):
if nchain > 1:
# the trace object will either be a MultiTrace (and have _straces)...
if hasattr(_trace, "_straces"):
chain_idx, point_idx = np.divmod(idx, len_trace)
chain_idx = chain_idx % nchain
param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
# ... or a PointList
else:
param = cast(PointList, _trace)[idx % (len_trace * nchain)]
# there's only a single chain, but the index might hit it multiple times if
# the number of indices is greater than the length of the trace.
else:
param = cast(PointList, _trace)[idx % (len_trace * nchain)]
# there's only a single chain, but the index might hit it multiple times if
# the number of indices is greater than the length of the trace.
else:
param = _trace[idx % len_trace]
param = _trace[idx % len_trace]

values = sampler_fn(**param)

for k, v in zip(vars_, values):
ppc_trace_t.insert(k.name, v, idx)

values = sampler_fn(**param)
progress.advance(task)

for k, v in zip(vars_, values):
ppc_trace_t.insert(k.name, v, idx)
except KeyboardInterrupt:
pass

Expand Down
39 changes: 25 additions & 14 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@

from arviz import InferenceData, dict_to_dataset
from arviz.data.base import make_attrs
from fastprogress.fastprogress import progress_bar
from pytensor.graph.basic import Variable
from rich.console import Console
from rich.progress import Progress
from rich.theme import Theme
from typing_extensions import Protocol, TypeAlias

import pymc as pm
Expand Down Expand Up @@ -65,6 +67,7 @@
RandomSeed,
RandomState,
_get_seeds_per_chain,
default_progress_theme,
drop_warning_stat,
get_untransformed_name,
is_transformed_name,
Expand Down Expand Up @@ -377,6 +380,7 @@ def sample(
cores: Optional[int] = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
step=None,
var_names: Optional[Sequence[str]] = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
Expand Down Expand Up @@ -406,6 +410,7 @@ def sample(
cores: Optional[int] = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
step=None,
var_names: Optional[Sequence[str]] = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
Expand Down Expand Up @@ -435,6 +440,7 @@ def sample(
cores: Optional[int] = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
step=None,
var_names: Optional[Sequence[str]] = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
Expand Down Expand Up @@ -761,6 +767,7 @@ def sample(
"tune": tune,
"var_names": var_names,
"progressbar": progressbar,
"progressbar_theme": progressbar_theme,
"model": model,
"cores": cores,
"callback": callback,
Expand Down Expand Up @@ -983,6 +990,7 @@ def _sample(
trace: IBaseTrace,
tune: int,
model: Optional[Model] = None,
progressbar_theme: Optional[Theme] = default_progress_theme,
callback=None,
**kwargs,
) -> None:
Expand Down Expand Up @@ -1010,6 +1018,8 @@ def _sample(
tune : int
Number of iterations to tune.
model : Model (optional if in ``with`` context)
progressbar_theme : Theme
Optional custom theme for the progress bar.
"""
skip_first = kwargs.get("skip_first", 0)

Expand All @@ -1026,19 +1036,16 @@ def _sample(
)
_pbar_data = {"chain": chain, "divergences": 0}
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
if progressbar:
sampling = progress_bar(sampling_gen, total=draws, display=progressbar)
sampling.comment = _desc.format(**_pbar_data)
else:
sampling = sampling_gen
try:
for it, diverging in enumerate(sampling):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
if progressbar:
sampling.comment = _desc.format(**_pbar_data)
except KeyboardInterrupt:
pass
with Progress(console=Console(theme=progressbar_theme)) as progress:
try:
task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar)
for it, diverging in enumerate(sampling_gen):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
progress.update(task, advance=1)
progress.update(task, advance=1, completed=True)
except KeyboardInterrupt:
pass


def _iter_sample(
Expand Down Expand Up @@ -1131,6 +1138,7 @@ def _mp_sample(
random_seed: Sequence[RandomSeed],
start: Sequence[PointType],
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
traces: Sequence[IBaseTrace],
model: Optional[Model] = None,
callback: Optional[SamplingIteratorCallback] = None,
Expand Down Expand Up @@ -1158,6 +1166,8 @@ def _mp_sample(
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
progressbar : bool
Whether or not to display a progress bar in the command line.
progressbar_theme : Theme
Optional custom theme for the progress bar.
traces
Recording backends for each chain.
model : Model (optional if in ``with`` context)
Expand All @@ -1182,6 +1192,7 @@ def _mp_sample(
start_points=start,
step_method=step,
progressbar=progressbar,
progressbar_theme=progressbar_theme,
mp_ctx=mp_ctx,
)
try:
Expand Down
Loading

0 comments on commit a30b4a4

Please sign in to comment.