Skip to content

Commit

Permalink
Merge pull request #3841 from rpgoldman/iss3840
Browse files Browse the repository at this point in the history
Fix computation of samples argument in sample_posterior_predictive
Solves #3840
  • Loading branch information
rpgoldman authored Mar 19, 2020
2 parents 363afc8 + 839206b commit 74b7788
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
8 changes: 7 additions & 1 deletion pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,7 +1568,13 @@ def sample_posterior_predictive(
raise IncorrectArgumentsError("Should not specify both keep_size and size argukments")

if samples is None:
samples = sum(len(v) for v in trace._straces.values())
if isinstance(trace, MultiTrace):
samples = sum(len(v) for v in trace._straces.values())
elif isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
# this is a list of points
samples = len(trace)
else:
raise ValueError("Do not know how to compute number of samples for trace argument of type %s"%type(trace))

if samples < len_trace * nchain:
warnings.warn(
Expand Down
29 changes: 29 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from itertools import combinations
from typing import Tuple
import numpy as np

try:
Expand Down Expand Up @@ -693,6 +694,16 @@ def test_exec_nuts_init(method):
assert "a" in start[0] and "b_log__" in start[0]


@pytest.fixture(scope="class")
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
with pm.Model() as pmodel:
n = pm.Normal('n')
trace = pm.sample()

with pmodel:
d = pm.Deterministic('d', n * 4)
return pmodel, trace

class TestSamplePriorPredictive(SeededTest):
def test_ignores_observed(self):
observed = np.random.normal(10, 1, size=200)
Expand Down Expand Up @@ -851,3 +862,21 @@ def test_bounded_dist(self):
with model:
prior_trace = pm.sample_prior_predictive(5)
assert prior_trace["x"].shape == (5, 3, 1)

class TestSamplePosteriorPredictive:
def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture):
pmodel, trace = point_list_arg_bug_fixture
with pmodel:
pp = pm.fast_sample_posterior_predictive(
[trace[15]],
var_names=['d']
)

def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
pmodel, trace = point_list_arg_bug_fixture
with pmodel:
pp = pm.sample_posterior_predictive(
[trace[15]],
var_names=['d']
)

0 comments on commit 74b7788

Please sign in to comment.