Fix some UserWarnings #6407

Merged 4 commits on Jan 20, 2023
Changes from all commits
pymc/backends/arviz.py (6 changes: 4 additions & 2 deletions)

@@ -297,10 +297,12 @@ def sample_stats_to_xarray(self):
                 continue
             if self.warmup_trace:
                 data_warmup[name] = np.array(
-                    self.warmup_trace.get_sampler_stats(stat, combine=False)
+                    self.warmup_trace.get_sampler_stats(stat, combine=False, squeeze=False)
                 )
             if self.posterior_trace:
-                data[name] = np.array(self.posterior_trace.get_sampler_stats(stat, combine=False))
+                data[name] = np.array(
+                    self.posterior_trace.get_sampler_stats(stat, combine=False, squeeze=False)
+                )
 
         return (
             dict_to_dataset(
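Editor's note on the change above: with the old default of `squeeze=True`, a single-chain trace returns a bare per-draw array with the chain axis dropped, and downstream conversion then misreads draws as chains; this is the source of the `UserWarning: More chains (500) than draws (1)` mentioned in the tests below. A minimal sketch of the shapes involved (illustrative only, not PyMC internals):

    import numpy as np

    draws = 500
    stats = [np.zeros(draws)]        # stats of one chain, as returned with combine=False

    # squeeze=True collapses the one-element list to a bare (500,) array,
    # which conversion code then reads as 500 chains of 1 draw each
    print(np.array(stats[0]).shape)  # (500,)

    # squeeze=False keeps the chain axis: (chains, draws) as expected
    print(np.array(stats).shape)     # (1, 500)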
pymc/distributions/logprob.py (23 changes: 21 additions & 2 deletions)

@@ -11,8 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+from copy import copy
 from typing import Dict, List, Sequence, Union
 
 import numpy as np

@@ -210,3 +209,23 @@ def ignore_logprob(rv: TensorVariable) -> TensorVariable:
         return rv
     new_node = assign_custom_measurable_outputs(node, type_prefix=prefix)
     return new_node.outputs[node.outputs.index(rv)]
+
+
+def reconsider_logprob(rv: TensorVariable) -> TensorVariable:
+    """Return a duplicated variable that is considered when creating logprob graphs
+
+    This undoes the effect of `ignore_logprob`.
+
+    If a variable was not ignored, it is returned directly.
+    """
+    prefix = "Unmeasurable"
+    node = rv.owner
+    op_type = type(node.op)
+    if not op_type.__name__.startswith(prefix):
+        return rv
+
+    new_node = node.clone()
+    original_op_type = new_node.op.original_op_type
+    new_node.op = copy(new_node.op)
+    new_node.op.__class__ = original_op_type
+    return new_node.outputs[node.outputs.index(rv)]
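For readers unfamiliar with these helpers, here is a minimal round trip, mirroring the `test_ignore_reconsider_logprob_basic` test added later in this PR:

    from pymc.distributions import Normal
    from pymc.distributions.logprob import ignore_logprob, reconsider_logprob
    from pymc.logprob.abstract import get_measurable_outputs

    x = Normal.dist()
    hidden = ignore_logprob(x)  # clone whose op type becomes UnmeasurableNormalRV
    assert get_measurable_outputs(hidden.owner.op, hidden.owner) == []

    restored = reconsider_logprob(hidden)  # clone restored to a plain NormalRV
    assert type(restored.owner.op).__name__ == "NormalRV"
    assert reconsider_logprob(x) is x      # measurable variables pass through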
pymc/distributions/timeseries.py (19 changes: 15 additions & 4 deletions)

@@ -21,7 +21,7 @@
 import pytensor
 import pytensor.tensor as at
 
-from pytensor.graph.basic import Node
+from pytensor.graph.basic import Node, ancestors
 from pytensor.graph.replace import clone_replace
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable

@@ -33,7 +33,7 @@
     _moment,
     moment,
 )
-from pymc.distributions.logprob import ignore_logprob, logp
+from pymc.distributions.logprob import ignore_logprob, logp, reconsider_logprob
 from pymc.distributions.multivariate import MvNormal, MvStudentT
 from pymc.distributions.shape_utils import (
     _change_dist_size,

@@ -106,6 +106,15 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVariable:
                 "init_dist and innovation_dist must have the same support dimensionality"
             )
 
+        # We need to check this, because we clone the variables when we ignore their logprob next
+        if init_dist in ancestors([innovation_dist]) or innovation_dist in ancestors([init_dist]):
+            raise ValueError("init_dist and innovation_dist must be completely independent")
+
+        # PyMC should not be concerned that these variables don't have values, as they will be
+        # accounted for in the logp of RandomWalk
+        init_dist = ignore_logprob(init_dist)
+        innovation_dist = ignore_logprob(innovation_dist)
+
         steps = cls.get_steps(
             innovation_dist=innovation_dist,
             steps=steps,

@@ -225,12 +234,14 @@ def random_walk_moment(op, rv, init_dist, innovation_dist, steps):
 
 
 @_logprob.register(RandomWalkRV)
-def random_walk_logp(op, values, *inputs, **kwargs):
+def random_walk_logp(op, values, init_dist, innovation_dist, steps, **kwargs):
     # Although we can derive the logprob of random walks, it does not collapse
     # what we consider the core dimension of steps. We do it manually here.
     (value,) = values
     # Recreate RV and obtain inner graph
-    rv_node = op.make_node(*inputs)
+    rv_node = op.make_node(
+        reconsider_logprob(init_dist), reconsider_logprob(innovation_dist), steps
+    )
     rv = clone_replace(
         op.inner_outputs, replace={u: v for u, v in zip(op.inner_inputs, rv_node.inputs)}
     )[op.default_output]
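The inline comment in `random_walk_logp` alludes to the fact that a random walk's logp factors into the init term plus the innovation terms evaluated on first differences, collapsed over the steps dimension. A conceptual check of that factorization (an illustration under that assumption, not the PR's implementation):

    import numpy as np
    import pymc as pm

    value = np.array([0.0, 0.5, 1.5])
    grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, init_dist=pm.Normal.dist(0, 1), steps=2)

    # logp of the walk == logp of the start point + logp of each increment
    lhs = pm.logp(grw, value).eval()
    rhs = (
        pm.logp(pm.Normal.dist(0, 1), value[0]).eval()
        + pm.logp(pm.Normal.dist(0, 1), np.diff(value)).eval().sum()
    )
    np.testing.assert_allclose(lhs, rhs)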
pymc/logprob/abstract.py (3 changes: 2 additions & 1 deletion)

@@ -174,7 +174,7 @@ def _get_measurable_outputs_RandomVariable(op, node):
 
 
 def noop_measurable_outputs_fn(*args, **kwargs):
-    return None
+    return []
 
 
 def assign_custom_measurable_outputs(

@@ -220,6 +220,7 @@ def assign_custom_measurable_outputs(
 
     new_op_dict = op_type.__dict__.copy()
    new_op_dict["id_obj"] = (new_node.op, measurable_outputs_fn)
+    new_op_dict.setdefault("original_op_type", op_type)
 
     new_op_type = type(
         f"{type_prefix}{op_type.__name__}", (op_type, UnmeasurableVariable), new_op_dict
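A note on the `setdefault` above: the dynamically created `Unmeasurable*` op type records the type it was built from, which is what `reconsider_logprob` reads back. Using `setdefault` rather than plain assignment means wrapping an already-unmeasurable op cannot overwrite the true original. A self-contained sketch with toy classes (not PyMC's actual op machinery):

    class NormalRV:  # stand-in for a measurable op type
        pass

    ns = {"original_op_type": NormalRV}  # first wrap records the true original
    UnmeasurableNormalRV = type("UnmeasurableNormalRV", (NormalRV,), ns)

    ns2 = dict(UnmeasurableNormalRV.__dict__)  # a hypothetical second wrap
    ns2.setdefault("original_op_type", UnmeasurableNormalRV)  # no-op: key already set

    assert ns2["original_op_type"] is NormalRV  # restoration target is preserved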
pymc/tests/backends/test_arviz.py (18 changes: 17 additions & 1 deletion)

@@ -21,6 +21,9 @@
 )
 from pymc.exceptions import ImputationWarning
 
+# Turn all warnings into errors for this module
+pytestmark = pytest.mark.filterwarnings("error")
+
 
 @pytest.fixture(scope="module")
 def eight_schools_params():

@@ -635,7 +638,9 @@ def test_include_transformed(self):
             pm.Uniform("p", 0, 1)
 
         # First check that the default is to exclude the transformed variables
-        sample_kwargs = dict(tune=5, draws=7, chains=2, cores=1)
+        sample_kwargs = dict(
+            tune=5, draws=7, chains=2, cores=1, compute_convergence_checks=False
+        )
         inference_data = pm.sample(**sample_kwargs, step=pm.Metropolis())
         assert "p_interval__" not in inference_data.posterior
 

@@ -647,6 +652,17 @@
         )
         assert "p_interval__" in inference_data.posterior
 
+    @pytest.mark.parametrize("chains", (1, 2))
+    def test_single_chain(self, chains):
+        # Test that no UserWarning is raised when sampling with NUTS defaults
+
+        # When this test was added, a `UserWarning: More chains (500) than draws (1)` used to be issued
+        # when sampling with a single chain
+        warnings.simplefilter("error")
+        with pm.Model():
+            pm.Normal("x")
+            pm.sample(chains=chains, return_inferencedata=True)
+
 
 class TestPyMCWarmupHandling:
     @pytest.mark.parametrize("save_warmup", [False, True])
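The module-level `pytestmark` added at the top of this file is what enforces the PR's goal: any stray warning now fails the suite. A minimal illustration of the mechanism (hypothetical test, not from this PR):

    import warnings

    import pytest

    pytestmark = pytest.mark.filterwarnings("error")

    def test_any_warning_fails():
        # With the mark above, warnings.warn raises instead of just warning,
        # so an unexpected UserWarning surfaces as a test failure.
        with pytest.raises(UserWarning):
            warnings.warn("unexpected", UserWarning)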
pymc/tests/distributions/test_logprob.py (41 changes: 29 additions & 12 deletions)

@@ -47,6 +47,7 @@
     ignore_logprob,
     logcdf,
     logp,
+    reconsider_logprob,
 )
 from pymc.logprob.abstract import get_measurable_outputs
 from pymc.model import Model, Potential

@@ -315,7 +316,7 @@ def test_unexpected_rvs():
         model.logp()
 
 
-def test_ignore_logprob_basic():
+def test_ignore_reconsider_logprob_basic():
     x = Normal.dist()
     (measurable_x_out,) = get_measurable_outputs(x.owner.op, x.owner)
     assert measurable_x_out is x.owner.outputs[1]

@@ -325,21 +326,37 @@ def test_ignore_logprob_basic():
     assert isinstance(new_x.owner.op, Normal)
     assert type(new_x.owner.op).__name__ == "UnmeasurableNormalRV"
     # Confirm that it does not have measurable output
-    assert get_measurable_outputs(new_x.owner.op, new_x.owner) is None
+    assert get_measurable_outputs(new_x.owner.op, new_x.owner) == []
 
     # Test that it will not clone a variable that is already unmeasurable
-    new_new_x = ignore_logprob(new_x)
-    assert new_new_x is new_x
-
-
-def test_ignore_logprob_model():
-    # logp that does not depend on input
-    def logp(value, x):
-        return value
+    assert ignore_logprob(new_x) is new_x
+
+    orig_x = reconsider_logprob(new_x)
+    assert orig_x is not new_x
+    assert isinstance(orig_x.owner.op, Normal)
+    assert type(orig_x.owner.op).__name__ == "NormalRV"
+    # Confirm that it has measurable outputs again
+    assert get_measurable_outputs(orig_x.owner.op, orig_x.owner) == [orig_x.owner.outputs[1]]
+
+    # Test that will not clone a variable that is already measurable
+    assert reconsider_logprob(x) is x
+    assert reconsider_logprob(orig_x) is orig_x
+
+
+def test_ignore_reconsider_logprob_model():
+    def custom_logp(value, x):
+        # custom_logp is just the logp of x at value
+        x = reconsider_logprob(x)
+        return _joint_logp(
+            [x],
+            rvs_to_values={x: value},
+            rvs_to_transforms={},
+            rvs_to_total_sizes={},
+        )
 
     with Model() as m:
         x = Normal.dist()
-        y = CustomDist("y", x, logp=logp)
+        y = CustomDist("y", x, logp=custom_logp)
     with pytest.warns(
         UserWarning,
         match="Found a random variable that was neither among the observations "

@@ -355,7 +372,7 @@ def logp(value, x):
     # The above warning should go away with ignore_logprob.
     with Model() as m:
         x = ignore_logprob(Normal.dist())
-        y = CustomDist("y", x, logp=logp)
+        y = CustomDist("y", x, logp=custom_logp)
     with warnings.catch_warnings():
         warnings.simplefilter("error")
         assert _joint_logp(
pymc/tests/distributions/test_timeseries.py (38 changes: 30 additions & 8 deletions)

@@ -47,6 +47,10 @@
 from pymc.tests.distributions.util import assert_moment_is_expected
 from pymc.tests.helpers import select_by_precision
 
+# Turn all warnings into errors for this module
+# Ignoring NumPy deprecation warning tracked in https://github.com/pymc-devs/pytensor/issues/146
+pytestmark = pytest.mark.filterwarnings("error", "ignore: NumPy will stop allowing conversion")
+
 
 class TestRandomWalk:
     def test_dists_types(self):

@@ -92,6 +96,14 @@ def test_dists_not_registered_check(self):
         ):
             RandomWalk("rw", init_dist=init_dist, innovation_dist=innovation, steps=5)
 
+    def test_dists_independent_check(self):
+        init_dist = Normal.dist()
+        innovation_dist = Normal.dist(init_dist)
+        with pytest.raises(
+            ValueError, match="init_dist and innovation_dist must be completely independent"
+        ):
+            RandomWalk.dist(init_dist=init_dist, innovation_dist=innovation_dist)
+
     @pytest.mark.parametrize(
         "init_dist, innovation_dist, steps, size, shape, "
         "init_dist_size, innovation_dist_size, rw_shape",

@@ -423,15 +435,18 @@ def test_mvgaussian_with_chol_cov_rv(self, param):
             "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
         )
         # pylint: enable=unpacking-non-sequence
-        if param == "chol":
-            mv = MvGaussianRandomWalk("mv", mu, chol=chol, shape=(10, 7, 3))
-        elif param == "cov":
-            mv = MvGaussianRandomWalk("mv", mu, cov=pm.math.dot(chol, chol.T), shape=(10, 7, 3))
-        else:
-            raise ValueError
+        with pytest.warns(UserWarning, match="Initial distribution not specified"):
+            if param == "chol":
+                mv = MvGaussianRandomWalk("mv", mu, chol=chol, shape=(10, 7, 3))
+            elif param == "cov":
+                mv = MvGaussianRandomWalk(
+                    "mv", mu, cov=pm.math.dot(chol, chol.T), shape=(10, 7, 3)
+                )
+            else:
+                raise ValueError
         assert draw(mv, draws=5).shape == (5, 10, 7, 3)
 
-    @pytest.mark.parametrize("param", ["cov", "chol", "tau"])
+    @pytest.mark.parametrize("param", ["scale", "chol", "tau"])
     def test_mvstudentt(self, param):
         x = MvStudentTRandomWalk.dist(
             nu=4,

@@ -853,7 +868,13 @@ def sde_fn(x, k, d, s):
     with Model() as t0:
         init_dist = pm.Normal.dist(0, 10, shape=(batch_size,))
         y = EulerMaruyama(
-            "y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, init_dist=init_dist, **kwargs
+            "y",
+            dt=0.02,
+            sde_fn=sde_fn,
+            sde_pars=sde_pars,
+            init_dist=init_dist,
+            initval="prior",
+            **kwargs,
         )
 
     y_eval = draw(y, draws=2, random_seed=numpy_rng)

@@ -873,6 +894,7 @@ def sde_fn(x, k, d, s):
             sde_fn=sde_fn,
             sde_pars=sde_pars_slice,
             init_dist=init_dist,
+            initval="prior",
             **kwargs,
         )
 