Skip to content

Commit

Permalink
Merge pull request #741 from stan-dev/feat/738-pathfinder-threads
Browse files Browse the repository at this point in the history
Add a num_threads helper argument to pathfinder()
  • Loading branch information
WardBrian authored Mar 25, 2024
2 parents 742b409 + 71d22e0 commit de2e73c
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 1 deletion.
1 change: 1 addition & 0 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ def validate(self) -> None:
if not (
isinstance(self.method_args, SamplerArgs)
and self.method_args.num_chains > 1
or isinstance(self.method_args, PathfinderArgs)
):
if not os.path.exists(self.inits):
raise ValueError('no such file {}'.format(self.inits))
Expand Down
16 changes: 16 additions & 0 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1587,6 +1587,7 @@ def pathfinder(
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
num_threads: Optional[int] = None,
) -> CmdStanPathfinder:
"""
Run CmdStan's Pathfinder variational inference algorithm.
Expand Down Expand Up @@ -1689,6 +1690,10 @@ def pathfinder(
:param timeout: Duration at which Pathfinder times
out in seconds. Defaults to None.
:param num_threads: Number of threads to request for parallel execution.
A number other than ``1`` requires the model to have been compiled
with STAN_THREADS=True.
:return: A :class:`CmdStanPathfinder` object
References
Expand All @@ -1715,6 +1720,17 @@ def pathfinder(
"available for CmdStan versions 2.34 and later"
)

if num_threads is not None:
if (
num_threads != 1
and exe_info.get('STAN_THREADS', '').lower() != 'true'
):
raise ValueError(
"Model must be compiled with 'STAN_THREADS=true' to use"
" 'num_threads' argument"
)
os.environ['STAN_NUM_THREADS'] = str(num_threads)

if num_paths == 1:
if num_single_draws is None:
num_single_draws = draws
Expand Down
39 changes: 39 additions & 0 deletions test/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Tests for the Pathfinder method.
"""

import contextlib
from io import StringIO
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -129,6 +131,26 @@ def test_pathfinder_init_sampling():
assert fit.draws().shape == (1000, 4, 9)


def test_inits_for_pathfinder():
stan = DATAFILES_PATH / 'bernoulli.stan'
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
jdata = str(DATAFILES_PATH / 'bernoulli.data.json')
bern_model.pathfinder(
jdata, inits=[{"theta": 0.1}, {"theta": 0.9}], num_paths=2
)

# second path is initialized too large!
with contextlib.redirect_stdout(StringIO()) as captured:
bern_model.pathfinder(
jdata,
inits=[{"theta": 0.1}, {"theta": 1.1}],
num_paths=2,
show_console=True,
)

assert "Bounded variable is 1.1" in captured.getvalue()


def test_pathfinder_no_psis():
stan = DATAFILES_PATH / 'bernoulli.stan'
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
Expand All @@ -152,3 +174,20 @@ def test_pathfinder_no_lp_calc():
n_lp_nan = np.sum(np.isnan(pathfinder.method_variables()['lp__']))
assert n_lp_nan < 4000 # some lp still calculated during pathfinder
assert n_lp_nan > 3000 # but most are not


def test_pathfinder_threads():
stan = DATAFILES_PATH / 'bernoulli.stan'
bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
jdata = str(DATAFILES_PATH / 'bernoulli.data.json')

bern_model.pathfinder(data=jdata, num_threads=1)

with pytest.raises(ValueError, match="STAN_THREADS"):
bern_model.pathfinder(data=jdata, num_threads=4)

bern_model = cmdstanpy.CmdStanModel(
stan_file=stan, cpp_options={'STAN_THREADS': True}, force_compile=True
)
pathfinder = bern_model.pathfinder(data=jdata, num_threads=4)
assert pathfinder.draws().shape == (1000, 3)
4 changes: 3 additions & 1 deletion test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
)
def test_bernoulli_good(stanfile: str):
stan = os.path.join(DATAFILES_PATH, stanfile)
bern_model = CmdStanModel(stan_file=stan)
bern_model = CmdStanModel(stan_file=stan, force_compile=True)

jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
bern_fit = bern_model.sample(
Expand All @@ -74,6 +74,8 @@ def test_bernoulli_good(stanfile: str):

for i in range(bern_fit.runset.chains):
csv_file = bern_fit.runset.csv_files[i]
# NB: This will fail if STAN_THREADS is enabled
# due to sampling only producing 1 stdout file in that case
stdout_file = bern_fit.runset.stdout_files[i]
assert os.path.exists(csv_file)
assert os.path.exists(stdout_file)
Expand Down

0 comments on commit de2e73c

Please sign in to comment.