diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 86a1c1e7fbd..953b0c0b1c6 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -232,9 +232,9 @@ def __new__( # `dims` are only available with this API, because `.dist()` can be used # without a modelcontext and dims are not tracked at the Aesara level. if dims is not None: - ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model) + resize_shape, dims = resize_from_dims(dims, ndim_actual, model) elif observed is not None: - ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual) + resize_shape, observed = resize_from_observed(observed, ndim_actual) if resize_shape: # A batch size was specified through `dims`, or implied by `observed`. @@ -482,9 +482,9 @@ def __new__( # # `dims` are only available with this API, because `.dist()` can be used # # without a modelcontext and dims are not tracked at the Aesara level. if dims is not None: - ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model) + resize_shape, dims = resize_from_dims(dims, ndim_actual, model) elif observed is not None: - ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual) + resize_shape, observed = resize_from_observed(observed, ndim_actual) if resize_shape: # A batch size was specified through `dims`, or implied by `observed`. diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 28117c3353b..ec73c118a2e 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -487,9 +487,7 @@ def convert_size(size: Size) -> Optional[StrongSize]: return size -def resize_from_dims( - dims: WeakDims, ndim_implied: int, model -) -> Tuple[int, StrongSize, StrongDims]: +def resize_from_dims(dims: WeakDims, ndim_implied: int, model) -> Tuple[StrongSize, StrongDims]: """Determines a potential resize shape from a `dims` tuple. Parameters @@ -503,10 +501,10 @@ def resize_from_dims( Returns ------- - ndim_resize : int - Number of dimensions that should be added through resizing. resize_shape : array-like - The shape of the new dimensions. + Shape of new dimensions that should be prepended. + dims : tuple of (str or None) + Names or None for all dimensions after resizing. """ if Ellipsis in dims: # Auto-complete the dims tuple to the full length. @@ -525,12 +523,12 @@ def resize_from_dims( # The numeric/symbolic resize tuple can be created using model.RV_dim_lengths resize_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_resize]) - return ndim_resize, resize_shape, dims + return resize_shape, dims def resize_from_observed( observed, ndim_implied: int -) -> Tuple[int, StrongSize, Union[np.ndarray, Variable]]: +) -> Tuple[StrongSize, Union[np.ndarray, Variable]]: """Determines a potential resize shape from observations. Parameters @@ -542,10 +540,8 @@ def resize_from_observed( Returns ------- - ndim_resize : int - Number of dimensions that should be added through resizing. resize_shape : array-like - The shape of the new dimensions. + Shape of new dimensions that should be prepended. observed : scalar, array-like Observations as numpy array or `Variable`. """ @@ -553,7 +549,7 @@ def resize_from_observed( observed = pandas_to_array(observed) ndim_resize = observed.ndim - ndim_implied resize_shape = tuple(observed.shape[d] for d in range(ndim_resize)) - return ndim_resize, resize_shape, observed + return resize_shape, observed def find_size(shape=None, size=None, ndim_supp=None):