diff --git a/.gitignore b/.gitignore index 4ea251c4..6be4482a 100644 --- a/.gitignore +++ b/.gitignore @@ -44,4 +44,4 @@ pytestdebug.log # Codespaces pythonenv* -env/ \ No newline at end of file +env/ diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 25af7644..1c76343d 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -33,13 +33,21 @@ def coords(): return dict(test=range(3), simplex=range(4)) -@pytest.fixture -def user_param_cfg(): - return ("t",), dict( - a="d", - b=dict(transform=transforms.log, dims=("test",)), - c=dict(transform=transforms.simplex, dims=("simplex",)), - ) +@pytest.fixture( + params=[ + [ + ("t",), + dict( + a="d", + b=dict(transform=transforms.log, dims=("test",)), + c=dict(transform=transforms.simplex, dims=("simplex",)), + ), + ], + [("t",), dict()], + ] +) +def user_param_cfg(request): + return request.param @pytest.fixture diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 204de306..b5a0e0de 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -94,6 +94,7 @@ def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo: def _mean_chol(flat_array: np.ndarray): mean = flat_array.mean(0) cov = np.cov(flat_array, rowvar=False) + cov = np.atleast_2d(cov) chol = np.linalg.cholesky(cov) return mean, chol