From 7c3fdbf24add80ce12cf5254c9b73ca0b3cab8cc Mon Sep 17 00:00:00 2001 From: kc611 Date: Thu, 28 Jan 2021 20:10:27 +0530 Subject: [PATCH 1/2] Added default testvalue support for theano.shared --- RELEASE-NOTES.md | 1 + pymc3/distributions/distribution.py | 16 ++++++++-------- pymc3/tests/test_data_container.py | 22 ++++++++++++++++++++++ 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 65cde7cc6ed..5c83edc7cb1 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -8,6 +8,7 @@ ### Maintenance - `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)). +- `ScalarSharedVariable` can now be used as an input to other RVs directly.(see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)) ## PyMC3 3.11.0 (21 January 2021) diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 8178ae0d228..f692d7af132 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -148,17 +148,17 @@ def default(self): def get_test_val(self, val, defaults): if val is None: for v in defaults: - if hasattr(self, v) and np.all(np.isfinite(self.getattr_value(v))): - return self.getattr_value(v) - else: - return self.getattr_value(val) - - if val is None: + if hasattr(self, v): + attr_val = self.getattr_value(v) + if np.all(np.isfinite(attr_val)): + return attr_val raise AttributeError( - "%s has no finite default value to use, " + "%s has no finite default value to use " "checked: %s. Pass testval argument or " "adjust so value is finite." % (self, str(defaults)) ) + else: + return self.getattr_value(val) def getattr_value(self, val): if isinstance(val, string_types): @@ -167,7 +167,7 @@ def getattr_value(self, val): if isinstance(val, tt.TensorVariable): return val.tag.test_value - if isinstance(val, tt.sharedvar.TensorSharedVariable): + if isinstance(val, tt.sharedvar.SharedVariable): return val.get_value() if isinstance(val, theano_constant): diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index d3eaf2fb7ff..966ce47cd6a 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -16,6 +16,8 @@ import pandas as pd import pytest +from theano import shared + import pymc3 as pm from pymc3.tests.helpers import SeededTest @@ -156,6 +158,26 @@ def test_shared_data_as_rv_input(self): np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1) np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1) + def test_shared_scalar_as_rv_input(self): + # See https://github.com/pymc-devs/pymc3/issues/3139 + with pm.Model() as m: + shared_var = shared(5.0) + v = pm.Normal("v", mu=shared_var, shape=1) + + np.testing.assert_allclose( + v.logp({"v": [5.0]}), + -0.91893853, + rtol=1e-5, + ) + + shared_var.set_value(10.0) + + np.testing.assert_allclose( + v.logp({"v": [10.0]}), + -0.91893853, + rtol=1e-5, + ) + def test_creation_of_data_outside_model_context(self): with pytest.raises((IndexError, TypeError)) as error: pm.Data("data", [1.1, 2.2, 3.3]) From 1fb71a4a7437de5dc2876200762c4f4106e826e6 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Sun, 31 Jan 2021 10:47:02 +0100 Subject: [PATCH 2/2] Tiny reformatting --- RELEASE-NOTES.md | 2 +- pymc3/distributions/distribution.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 5c83edc7cb1..67d1c752ae4 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -8,7 +8,7 @@ ### Maintenance - `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)). -- `ScalarSharedVariable` can now be used as an input to other RVs directly.(see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)) +- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)). ## PyMC3 3.11.0 (21 January 2021) diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index f692d7af132..c24a9d9df6e 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -153,7 +153,7 @@ def get_test_val(self, val, defaults): if np.all(np.isfinite(attr_val)): return attr_val raise AttributeError( - "%s has no finite default value to use " + "%s has no finite default value to use, " "checked: %s. Pass testval argument or " "adjust so value is finite." % (self, str(defaults)) )