Skip to content

Commit

Permalink
Distinguish between "not set" and None initvals
Browse files Browse the repository at this point in the history
This changes the initval default on Distribution.__new__ and Distribution.dist to UNSET.
It allows for implementing distribution-specific initial values similar to how it was done in pymc3 <4.

Related to pymc-devs#4771.
  • Loading branch information
michaelosthege committed Jul 18, 2021
1 parent b1e43b0 commit 0ccd7be
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 11 deletions.
8 changes: 4 additions & 4 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ class Flat(Continuous):
rv_op = flat

@classmethod
def dist(cls, *, size=None, initval=None, **kwargs):
if initval is None:
def dist(cls, *, size=None, initval=UNSET, **kwargs):
if initval is UNSET or initval is None:
initval = np.full(size, floatX(0.0))
res = super().dist([], size=size, **kwargs)
res.tag.test_value = initval
Expand Down Expand Up @@ -425,8 +425,8 @@ class HalfFlat(PositiveContinuous):
rv_op = halfflat

@classmethod
def dist(cls, *, size=None, initval=None, **kwargs):
if initval is None:
def dist(cls, *, size=None, initval=UNSET, **kwargs):
if initval is UNSET or initval is None:
initval = np.full(size, floatX(1.0))
res = super().dist([], size=size, **kwargs)
res.tag.test_value = initval
Expand Down
33 changes: 27 additions & 6 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __new__(
*args,
rng=None,
dims: Optional[Dims] = None,
initval=None,
initval=UNSET,
observed=None,
total_size=None,
transform=UNSET,
Expand All @@ -149,6 +149,9 @@ def __new__(
initval : optional
Test value to be attached to the output RV.
Must match its shape exactly.
If set to `None`, an initial value will be drawn randomly.
With a value of `UNSET`, or not passing `initval` the `cls` distribution
may provide a default initial value at its own discretion.
observed : optional
Observed data to be passed when registering the random variable in the model.
See ``Model.register_rv``.
Expand Down Expand Up @@ -202,9 +205,14 @@ def __new__(
)
dims = convert_dims(dims)

# Create the RV without specifying initval, because the initval may have a shape
# Create the RV without a user-provided initval, because it may have a shape
# that only matches after replicating with a size implied by dims (see below).
rv_out = cls.dist(*args, rng=rng, initval=None, **kwargs)
# Only the following initvals are forwarded, because they don't create shape problems:
# → None: No explicit init value. The model will default to a random draw.
# → UNSET: The RV Op may pass a default (e.g. the mode), otherwise no explicit init value.
if initval is None or initval is UNSET:
kwargs["initval"] = initval
rv_out = cls.dist(*args, rng=rng, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

Expand All @@ -219,9 +227,15 @@ def __new__(
# A batch size was specified through `dims`, or implied by `observed`.
rv_out = change_rv_size(rv_var=rv_out, new_size=resize_shape, expand=True)

if initval is not None:
if initval is not UNSET and initval is not None:
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
rv_out.tag.test_value = initval
elif initval is None and hasattr(rv_out.tag, "test_value"):
warnings.warn(
f"An initval=None was specified, but the {cls.__name__} distribution assigned a test value anyways."
f" This unexpected behaviour and should be investigated!",
RuntimeWarning,
)

rv_out = model.register_rv(
rv_out, name, observed, total_size, dims=dims, transform=transform
Expand All @@ -242,7 +256,7 @@ def dist(
*,
shape: Optional[Shape] = None,
size: Optional[Size] = None,
initval=None,
initval=UNSET,
**kwargs,
) -> RandomVariable:
"""Creates a RandomVariable corresponding to the `cls` distribution.
Expand Down Expand Up @@ -293,6 +307,13 @@ def dist(
# Create the RV with a `size` right away.
# This is not necessarily the final result.
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
if initval is None and hasattr(rv_out.tag, "test_value"):
warnings.warn(
f"An initval=None was specified, but the {cls.rv_op.__class__.__name__} Op assigned a test value anyways."
f" This unexpected behaviour and should be investigated!",
RuntimeWarning,
)

rv_out = maybe_resize(
rv_out,
cls.rv_op,
Expand Down Expand Up @@ -322,7 +343,7 @@ def dist(
rv_out.update = (rng, new_rng)
rng.default_update = new_rng

if initval is not None:
if initval is not UNSET and initval is not None:
rv_out.tag.test_value = initval

return rv_out
Expand Down
170 changes: 169 additions & 1 deletion pymc3/tests/test_initvals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import aesara
import aesara.tensor as at
import aesara.tensor.random.basic as atr
import numpy as np
import pytest

import pymc3 as pm

from pymc3.util import UNSET, select_initval


Expand All @@ -23,7 +28,7 @@ def test_util_select_initval():
assert select_initval(None, default=3) is None
with pytest.warns(None):
# No warning, because defaults are often symbolic (e.g. the mode of a distribution).
assert select_initval(UNSET, default=at.scalar()) == UNSET
assert select_initval(UNSET, default=at.scalar()) is UNSET

# The default is preferred if the candidate is UNSET or invalid
assert select_initval(UNSET, default=3) == 3
Expand All @@ -36,3 +41,166 @@ def test_util_select_initval():
with pytest.warns(UserWarning, match="incompatible `initial` value"):
assert select_initval(at.scalar(), default=at.scalar()) == None
pass


class NormalWithoutInitval(pm.Distribution):
"""A distribution that does not specify a default initial value."""

rv_op = atr.normal

@classmethod
def dist(cls, mu=0, sigma=None, **kwargs):
mu = at.as_tensor(pm.floatX(mu))
sigma = at.as_tensor(pm.floatX(sigma))
return super().dist([mu, sigma], **kwargs)


class UniformWithInitval(pm.distributions.continuous.BoundedContinuous):
"""
A distribution that defaults the initial value.
"""

rv_op = atr.uniform
bound_args_indices = (0, 1) # Lower, Upper

@classmethod
def dist(cls, lower=0, upper=1, initval=UNSET, **kwargs):
initval = select_initval(initval, (upper + lower) / 2)
lower = at.as_tensor_variable(pm.floatX(lower))
upper = at.as_tensor_variable(pm.floatX(upper))
return super().dist([lower, upper], initval=initval, **kwargs)


class AlwaysInitvalRV(atr.NormalRV):
"""
A hypothetical RV Op that always assigns a test_value.
This behavior would render initval=None ineffective, potentially
resulting in the loss of control over the init value selection.
"""

def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
output = super().__call__(loc, scale, size=size, **kwargs)
output.tag.test_value = 0
return output


class AlwaysInitvalDist_ByOp(pm.Distribution):
"""A PyMC3 distribution corresponding to the AlwaysInitvalRV Op."""

rv_op = AlwaysInitvalRV()

@classmethod
def dist(cls, mu=0, sigma=None, **kwargs):
mu = at.as_tensor(mu)
sigma = at.as_tensor(sigma)
return super().dist([mu, sigma], **kwargs)


class AlwaysInitvalDist_ByDist(pm.Distribution):
"""A distribution that illegally overwrites any initial value setting."""

rv_op = atr.normal

@classmethod
def dist(cls, mu=0, sigma=None, **kwargs):
mu = at.as_tensor(mu)
sigma = at.as_tensor(sigma)
kwargs["initval"] = 1
return super().dist([mu, sigma], **kwargs)


def transform_fwd(rv, expected_untransformed):
return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval()


class TestInitvalAssignment:
def test_dist_initval_behaviors(self):
"""
No test values are set on the RV unless specified by either the user or the RV Op.
"""
rv = NormalWithoutInitval.dist(1, 2)
assert not hasattr(rv.tag, "test_value")

rv = NormalWithoutInitval.dist(1, 2, initval=None)
assert not hasattr(rv.tag, "test_value")

rv = NormalWithoutInitval.dist(1, 2, initval=0.5)
assert rv.tag.test_value == 0.5

# A distribution may provide a default initval in its .dist() implementation:
rv = UniformWithInitval.dist(1, 2)
assert rv.tag.test_value == 1.5
pass

def test_new_initval_behaviors(self):
"""
No test values are set on the RV unless specified by either the user or the RV Op.
But initial values are always determined and managed by the Model object.
"""
with pm.Model() as pmodel:
rv1 = NormalWithoutInitval("default to random draw", 1, 2)
rv2 = NormalWithoutInitval("default to random draw the second", 1, 2)
assert pmodel.initial_values[rv1.tag.value_var] != 1
assert pmodel.initial_values[rv2.tag.value_var] != 1
assert (
pmodel.initial_values[rv1.tag.value_var] != pmodel.initial_values[rv2.tag.value_var]
)
# Randomly drawn initvals are not attached to the rv:
assert not hasattr(rv1.tag, "test_value")
assert not hasattr(rv2.tag, "test_value")

rv = NormalWithoutInitval("user provided", 1, 2, initval=-0.2)
assert pmodel.initial_values[rv.tag.value_var] == np.array(
-0.2, dtype=aesara.config.floatX
)
assert rv.tag.test_value == np.array(-0.2, dtype=aesara.config.floatX)

rv = UniformWithInitval("RVOp default", 1.5, 2)
assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 1.75)
assert rv.tag.test_value == np.array(1.75, dtype=aesara.config.floatX)

rv = UniformWithInitval("user can override RVOp default", 1.5, 2, initval=1.8)
assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 1.8)
assert rv.tag.test_value == np.array(1.8, dtype=aesara.config.floatX)

rv = UniformWithInitval("user can revert to random draw", 1.5, 2, initval=None)
assert pmodel.initial_values[rv.tag.value_var] != transform_fwd(rv, 1.75)
assert not hasattr(rv.tag, "test_value")
pass

def test_unexpected_test_value_assignment_warning(self):
with pytest.warns(RuntimeWarning, match="the AlwaysInitvalRV Op assigned a test value"):
rv = AlwaysInitvalDist_ByOp.dist(1, 2, initval=None)
assert rv.tag.test_value == 0

# Note that `AlwaysInitvalDist_ByDist.dist` overrides the `initval` passed to it,
# but there's no way to warn about that when `AlwaysInitvalDist_ByDist.dist` is called directly.

with pm.Model() as pmodel:
with pytest.warns(RuntimeWarning, match="the AlwaysInitvalRV Op assigned a test value"):
rv = AlwaysInitvalDist_ByOp("noncompliant Op", 1, 2, initval=None)
# In violation of initval=None, the initial value is now fixed:
assert pmodel.initial_values[rv.tag.value_var] == 0
assert rv.tag.test_value == 0

with pytest.warns(
RuntimeWarning,
match="the AlwaysInitvalDist_ByDist distribution assigned a test value",
):
rv = AlwaysInitvalDist_ByDist("noncompliant Dist", 1, 2, initval=None)
# In violation of initval=None, the initial value is now fixed:
assert pmodel.initial_values[rv.tag.value_var] == 1
assert rv.tag.test_value == 1
pass

def test_test_value_deprecation_warning(self):
with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"):
rv = pm.Exponential.dist(lam=1, testval=0.5)
assert rv.tag.test_value == 0.5

with pm.Model() as pmodel:
with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"):
rv = pm.Uniform("u", 0, 1, testval=0.75)
assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 0.75)
assert rv.tag.test_value == np.array(0.75, dtype=aesara.config.floatX)
pass

0 comments on commit 0ccd7be

Please sign in to comment.