Skip to content

Commit

Permalink
Fix bug whereby partial traces have fewer draws than would be available (#4318)
Browse files Browse the repository at this point in the history

* add test for _choose_chains

* fix bug - choose overall maximum

* update release notes

* 📝

* 🎨

* minimise diff

* minimise diff

* Update pymc3/sampling.py
  • Loading branch information
MarcoGorelli authored Dec 12, 2020
1 parent 2a38198 commit 6f15cbb
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 13 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## PyMC3 3.10.1 (on deck)

### Maintenance
- Fixed bug whereby partial traces returned after a keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
- Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)).

## PyMC3 3.10.0 (7 December 2020)
Expand Down
25 changes: 12 additions & 13 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,14 @@ def _mp_sample(


def _choose_chains(traces, tune):
"""
Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.
We get here after a ``KeyboardInterrupt``, and so the different
traces have different lengths. We therefore pick the number of
traces such that (number of traces) * (length of shortest trace)
is maximised.
"""
if tune is None:
tune = 0

Expand All @@ -1518,22 +1526,13 @@ def _choose_chains(traces, tune):
if not sum(lengths):
raise ValueError("Not enough samples to build a trace.")

idxs = np.argsort(lengths)[::-1]
idxs = np.argsort(lengths)
l_sort = np.array(lengths)[idxs]

final_length = l_sort[0]
last_total = 0
for i, length in enumerate(l_sort):
total = (i + 1) * length
if total < last_total:
use_until = i
break
last_total = total
final_length = length
else:
use_until = len(lengths)
use_until = np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1])
final_length = l_sort[use_until]

return [traces[idx] for idx in idxs[:use_until]], final_length + tune
return [traces[idx] for idx in idxs[use_until:]], final_length + tune


def stop_tuning(step):
Expand Down
28 changes: 28 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import pymc3 as pm

from pymc3.backends.ndarray import NDArray
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
from pymc3.tests.helpers import SeededTest
from pymc3.tests.models import simple_init
Expand Down Expand Up @@ -299,6 +300,33 @@ def test_partial_trace_sample():
trace = pm.sample(trace=[a])


@pytest.mark.parametrize(
    "n_points, tune, expected_length, expected_n_traces",
    [
        ((5, 2, 2), 0, 2, 3),
        ((6, 1, 1), 1, 6, 1),
    ],
)
def test_choose_chains(n_points, tune, expected_length, expected_n_traces):
    """Check the traces and length picked by ``_choose_chains``.

    Three ``NDArray`` traces are built for a one-variable model, with
    ``n_points[i]`` recorded draws in the i-th trace.  They are handed to
    ``pm.sampling._choose_chains`` together with ``tune``, and the returned
    length and number of kept traces are compared with the expected values.
    """
    with pm.Model() as model:
        pm.Normal("a", mu=0, sigma=1)

        chains = []
        for n_draws in n_points:
            # Each trace is sized for, and filled with, exactly n_draws points.
            chain = NDArray(model)
            chain.setup(n_draws, 1)
            for _ in range(n_draws):
                chain.record({"a": 0})
            chains.append(chain)

        kept, length = pm.sampling._choose_chains(chains, tune=tune)

    assert length == expected_length
    assert expected_n_traces == len(kept)


@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
class TestNamedSampling(SeededTest):
def test_shared_named(self):
Expand Down

0 comments on commit 6f15cbb

Please sign in to comment.