Skip to content

Commit

Permalink
Extract RV creation and resize_shape determination code
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed Jan 8, 2022
1 parent 6f78b53 commit ab5dd7e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 38 deletions.
83 changes: 46 additions & 37 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

from abc import ABCMeta
from functools import singledispatch
from typing import Callable, Iterable, Optional, Sequence
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union

import aesara
import numpy as np

from aeppl.logprob import _logcdf, _logprob
from aesara import tensor as at
from aesara.graph.basic import Variable
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
Expand All @@ -36,6 +38,8 @@
Dims,
Shape,
Size,
StrongShape,
WeakDims,
convert_dims,
convert_shape,
convert_size,
Expand Down Expand Up @@ -133,6 +137,37 @@ def fn(*args, **kwargs):
return fn


def _make_rv_and_resize_shape(
*,
cls,
dims: Optional[Dims],
model,
observed,
args,
**kwargs,
) -> Tuple[Variable, Optional[WeakDims], Optional[Union[np.ndarray, Variable]], StrongShape]:
"""Creates the RV and processes dims or observed to determine a resize shape."""
# Create the RV without dims information, because that's not something tracked at the Aesara level.
# If necessary we'll later replicate to a different size implied by already known dims.
rv_out = cls.dist(*args, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

# # `dims` are only available with this API, because `.dist()` can be used
# # without a modelcontext and dims are not tracked at the Aesara level.
dims = convert_dims(dims)
dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None
if dims is not None:
if dims_can_resize:
resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif Ellipsis in dims:
# Replace ... with None entries to match the actual dimensionality.
dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
elif observed is not None:
resize_shape, observed = resize_from_observed(observed, ndim_actual)
return rv_out, dims, observed, resize_shape


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""

Expand Down Expand Up @@ -213,24 +248,11 @@ def __new__(
if rng is None:
rng = model.next_rng()

# Create the RV without dims information, because that's not something tracked at the Aesara level.
# If necessary we'll later replicate to a different size implied by already known dims.
rv_out = cls.dist(*args, rng=rng, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

# `dims` are only available with this API, because `.dist()` can be used
# without a modelcontext and dims are not tracked at the Aesara level.
dims = convert_dims(dims)
dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None
if dims is not None:
if dims_can_resize:
resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif Ellipsis in dims:
# Replace ... with None entries to match the actual dimensionality.
dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
elif observed is not None:
resize_shape, observed = resize_from_observed(observed, ndim_actual)
# Create the RV and process dims and observed to determine
# a shape by which the created RV may need to be resized.
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
cls=cls, dims=dims, model=model, observed=observed, args=args, rng=rng, **kwargs
)

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
Expand Down Expand Up @@ -459,24 +481,11 @@ def __new__(
elif not isinstance(rngs, (list, tuple)):
rngs = [rngs]

# Create the RV without dims information, because that's not something tracked at the Aesara level.
# If necessary we'll later replicate to a different size implied by already known dims.
rv_out = cls.dist(*args, rngs=rngs, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

# # `dims` are only available with this API, because `.dist()` can be used
# # without a modelcontext and dims are not tracked at the Aesara level.
dims = convert_dims(dims)
dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None
if dims is not None:
if dims_can_resize:
resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif Ellipsis in dims:
# Replace ... with None entries to match the actual dimensionality.
dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
elif observed is not None:
resize_shape, observed = resize_from_observed(observed, ndim_actual)
# Create the RV and process dims and observed to determine
# a shape by which the created RV may need to be resized.
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
cls=cls, dims=dims, model=model, observed=observed, args=args, rngs=rngs, **kwargs
)

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
Expand Down
2 changes: 1 addition & 1 deletion pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
StrongSize = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]


def convert_dims(dims: Dims) -> Optional[WeakDims]:
def convert_dims(dims: Optional[Dims]) -> Optional[WeakDims]:
"""Process a user-provided dims variable into None or a valid dims tuple."""
if dims is None:
return None
Expand Down

0 comments on commit ab5dd7e

Please sign in to comment.