diff --git a/python/paddle/distribution/constraint.py b/python/paddle/distribution/constraint.py
index a339d47c9d164..e59163dfd3e91 100644
--- a/python/paddle/distribution/constraint.py
+++ b/python/paddle/distribution/constraint.py
@@ -34,7 +34,7 @@ def __call__(self, value: Tensor) -> Tensor:
 
 
 class Range(Constraint):
-    def __init__(self, lower: Tensor, upper: Tensor) -> None:
+    def __init__(self, lower: float | Tensor, upper: float | Tensor) -> None:
        self._lower = lower
        self._upper = upper
        super().__init__()
diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py
index 9bfa53a89fe65..1fbd71af5ee75 100644
--- a/python/paddle/distribution/transform.py
+++ b/python/paddle/distribution/transform.py
@@ -11,10 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 import enum
 import math
 import typing
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Sequence,
+    overload,
+)
 
 import paddle
 import paddle.nn.functional as F
@@ -25,6 +32,10 @@
     variable,
 )
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.distribution import Distribution, TransformedDistribution
+
 __all__ = [
     'Transform',
     'AbsTransform',
@@ -115,9 +126,10 @@ class Transform:
         * _inverse_shape
 
     """
+
     _type = Type.INJECTION
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
 
     @classmethod
@@ -129,7 +141,19 @@ def _is_injective(cls):
         """
         return Type.is_injective(cls._type)
 
-    def __call__(self, input):
+    @overload
+    def __call__(self, input: Tensor) -> Tensor:
+        ...
+
+    @overload
+    def __call__(self, input: Distribution) -> TransformedDistribution:
+        ...
+
+    @overload
+    def __call__(self, input: Transform) -> ChainTransform:
+        ...
+
+    def __call__(self, input) -> Any:
         """Make this instance as a callable object. The return value is
         depending on the input type.
 
@@ -154,7 +178,7 @@ def __call__(self, input):
             return ChainTransform([self, input])
         return self.forward(input)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         """Forward transformation with mapping :math:`y = f(x)`.
 
         Useful for turning one random outcome into another.
@@ -179,7 +203,7 @@ def __call__(self, input):
             )
         return self._forward(x)
 
-    def inverse(self, y):
+    def inverse(self, y: Tensor) -> Tensor:
         """Inverse transformation :math:`x = f^{-1}(y)`. It's useful for
         "reversing" a transformation to compute one probability in terms
         of another.
@@ -202,7 +226,7 @@ def inverse(self, y):
             )
         return self._inverse(y)
 
-    def forward_log_det_jacobian(self, x):
+    def forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         """The log of the absolute value of the determinant of the matrix
         of all first-order partial derivatives of the inverse function.
 
@@ -235,7 +259,7 @@ def forward_log_det_jacobian(self, x):
 
         return self._call_forward_log_det_jacobian(x)
 
-    def inverse_log_det_jacobian(self, y):
+    def inverse_log_det_jacobian(self, y: Tensor) -> Tensor:
         """Compute :math:`log|det J_{f^{-1}}(y)|`. Note that
         ``forward_log_det_jacobian`` is the negative of this function,
         evaluated at :math:`f^{-1}(y)`.
@@ -258,7 +282,7 @@ def inverse_log_det_jacobian(self, y):
             )
         return self._call_inverse_log_det_jacobian(y)
 
-    def forward_shape(self, shape):
+    def forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         """Infer the shape of forward transformation.
 
         Args:
@@ -273,7 +297,7 @@ def forward_shape(self, shape):
             )
         return self._forward_shape(shape)
 
-    def inverse_shape(self, shape):
+    def inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         """Infer the shape of inverse transformation.
 
         Args:
@@ -289,28 +313,28 @@ def inverse_shape(self, shape):
         return self._inverse_shape(shape)
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Variable:
         """The domain of this transformation"""
         return variable.real
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         """The codomain of this transformation"""
         return variable.real
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         """Inner method for public API ``forward``, subclass should
         overwrite this method for supporting forward transformation.
         """
         raise NotImplementedError('Forward not implemented')
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         """Inner method of public API ``inverse``, subclass should
         overwrite this method for supporting inverse transformation.
         """
         raise NotImplementedError('Inverse not implemented')
 
-    def _call_forward_log_det_jacobian(self, x):
+    def _call_forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         """Inner method called by ``forward_log_det_jacobian``."""
         if hasattr(self, '_forward_log_det_jacobian'):
             return self._forward_log_det_jacobian(x)
@@ -321,7 +345,7 @@ def _call_forward_log_det_jacobian(self, x):
             'is implemented. One of them is required.'
         )
 
-    def _call_inverse_log_det_jacobian(self, y):
+    def _call_inverse_log_det_jacobian(self, y: Tensor) -> Tensor:
         """Inner method called by ``inverse_log_det_jacobian``"""
         if hasattr(self, '_inverse_log_det_jacobian'):
             return self._inverse_log_det_jacobian(y)
@@ -332,14 +356,14 @@ def _call_inverse_log_det_jacobian(self, y):
             'is implemented. One of them is required'
         )
 
-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         """Inner method called by ``forward_shape``, which is used to infer
         the forward shape. Subclass should overwrite this method for
         supporting ``forward_shape``.
         """
         return shape
 
-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         """Inner method called by ``inverse_shape``, which is used to infer
         the inverse shape. Subclass should overwrite this method for
         supporting ``inverse_shape``.
@@ -398,24 +422,25 @@ class AbsTransform(Transform):
             0.))
 
     """
+
     _type = Type.SURJECTION
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.abs()
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> tuple[Tensor, Tensor]:
         return -y, y
 
-    def _inverse_log_det_jacobian(self, y):
+    def _inverse_log_det_jacobian(self, y: Tensor) -> tuple[Tensor, Tensor]:
         zero = paddle.zeros([], dtype=y.dtype)
         return zero, zero
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Real:
         return variable.real
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Positive:
         return variable.positive
 
 
@@ -446,9 +471,10 @@ class AffineTransform(Transform):
             Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
             0.)
""" + _type = Type.BIJECTION - def __init__(self, loc, scale): + def __init__(self, loc: Tensor, scale: Tensor) -> None: if not isinstance( loc, (paddle.base.framework.Variable, paddle.pir.Value) ): @@ -464,23 +490,23 @@ def __init__(self, loc, scale): super().__init__() @property - def loc(self): + def loc(self) -> Tensor: return self._loc @property - def scale(self): + def scale(self) -> Tensor: return self._scale - def _forward(self, x): + def _forward(self, x: Tensor) -> Tensor: return self._loc + self._scale * x - def _inverse(self, y): + def _inverse(self, y: Tensor) -> Tensor: return (y - self._loc) / self._scale - def _forward_log_det_jacobian(self, x): + def _forward_log_det_jacobian(self, x: Tensor) -> Tensor: return paddle.abs(self._scale).log() - def _forward_shape(self, shape): + def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]: return tuple( paddle.broadcast_shape( paddle.broadcast_shape(shape, self._loc.shape), @@ -488,7 +514,7 @@ def _forward_shape(self, shape): ) ) - def _inverse_shape(self, shape): + def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]: return tuple( paddle.broadcast_shape( paddle.broadcast_shape(shape, self._loc.shape), @@ -497,11 +523,11 @@ def _inverse_shape(self, shape): ) @property - def _domain(self): + def _domain(self) -> variable.Real: return variable.real @property - def _codomain(self): + def _codomain(self) -> variable.Real: return variable.real @@ -539,7 +565,7 @@ class ChainTransform(Transform): [ 0., -1., -2., -3.]) """ - def __init__(self, transforms): + def __init__(self, transforms: Sequence[Transform]) -> None: if not isinstance(transforms, typing.Sequence): raise TypeError( f"Expected type of 'transforms' is Sequence, but got {type(transforms)}" @@ -552,20 +578,20 @@ def __init__(self, transforms): self.transforms = transforms super().__init__() - def _is_injective(self): + def _is_injective(self) -> bool: return all(t._is_injective() for t in self.transforms) - def _forward(self, x): + def _forward(self, x: Tensor) -> Tensor: for transform in self.transforms: x = transform.forward(x) return x - def _inverse(self, y): + def _inverse(self, y: Tensor) -> Tensor: for transform in reversed(self.transforms): y = transform.inverse(y) return y - def _forward_log_det_jacobian(self, x): + def _forward_log_det_jacobian(self, x: Tensor) -> float: value = 0.0 event_rank = self._domain.event_rank for t in self.transforms: @@ -576,22 +602,22 @@ def _forward_log_det_jacobian(self, x): event_rank += t._codomain.event_rank - t._domain.event_rank return value - def _forward_shape(self, shape): + def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]: for transform in self.transforms: shape = transform.forward_shape(shape) return shape - def _inverse_shape(self, shape): + def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]: for transform in self.transforms: shape = transform.inverse_shape(shape) return shape - def _sum_rightmost(self, value, n): + def _sum_rightmost(self, value: Tensor, n: int) -> Tensor: """sum value along rightmost n dim""" return value.sum(list(range(-n, 0))) if n > 0 else value @property - def _domain(self): + def _domain(self) -> variable.Independent: domain = self.transforms[0]._domain # Compute the lower bound of input dimensions for chain transform. 
@@ -619,7 +645,7 @@ def _domain(self):
         return variable.Independent(domain, event_rank - domain.event_rank)
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Independent:
         codomain = self.transforms[-1]._codomain
         event_rank = self.transforms[0]._domain.event_rank
 
@@ -656,26 +682,27 @@ class ExpTransform(Transform):
             Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
             [ 0.        , -0.69314718, -1.09861231])
     """
+
     _type = Type.BIJECTION
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Real:
         return variable.real
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Positive:
         return variable.positive
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.exp()
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.log()
 
-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return x
 
 
@@ -723,7 +750,7 @@ class IndependentTransform(Transform):
             [6. , 15.])
     """
 
-    def __init__(self, base, reinterpreted_batch_rank):
+    def __init__(self, base: Transform, reinterpreted_batch_rank: int) -> None:
         if not isinstance(base, Transform):
             raise TypeError(
                 f"Expected 'base' is Transform type, but get {type(base)}"
@@ -737,38 +764,38 @@ def __init__(self, base, reinterpreted_batch_rank):
         self._reinterpreted_batch_rank = reinterpreted_batch_rank
         super().__init__()
 
-    def _is_injective(self):
+    def _is_injective(self) -> bool:
         return self._base._is_injective()
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         if x.dim() < self._domain.event_rank:
             raise ValueError("Input dimensions is less than event dimensions.")
         return self._base.forward(x)
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         if y.dim() < self._codomain.event_rank:
             raise ValueError("Input dimensions is less than event dimensions.")
         return self._base.inverse(y)
 
-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return self._base.forward_log_det_jacobian(x).sum(
             list(range(-self._reinterpreted_batch_rank, 0))
         )
 
-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return self._base.forward_shape(shape)
 
-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return self._base.inverse_shape(shape)
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Independent:
         return variable.Independent(
             self._base._domain, self._reinterpreted_batch_rank
         )
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Independent:
         return variable.Independent(
             self._base._codomain, self._reinterpreted_batch_rank
         )
@@ -800,9 +827,10 @@ class PowerTransform(Transform):
             Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
             [0.69314718, 1.38629436])
     """
+
     _type = Type.BIJECTION
 
-    def __init__(self, power):
+    def __init__(self, power: Tensor) -> None:
         if not isinstance(
             power, (paddle.base.framework.Variable, paddle.pir.Value)
         ):
@@ -813,30 +841,30 @@ def __init__(self, power):
         super().__init__()
 
     @property
-    def power(self):
+    def power(self) -> Tensor:
         return self._power
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Real:
         return variable.real
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Positive:
         return variable.positive
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.pow(self._power)
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.pow(1 / self._power)
 
-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return (self._power * x.pow(self._power - 1)).abs().log()
 
-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return tuple(paddle.broadcast_shape(shape, self._power.shape))
 
-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return tuple(paddle.broadcast_shape(shape, self._power.shape))
 
 
@@ -873,9 +901,12 @@ class ReshapeTransform(Transform):
             Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
             [0.])
     """
+
     _type = Type.BIJECTION
 
-    def __init__(self, in_event_shape, out_event_shape):
+    def __init__(
+        self, in_event_shape: Sequence[int], out_event_shape: Sequence[int]
+    ) -> None:
         if not isinstance(in_event_shape, typing.Sequence) or not isinstance(
             out_event_shape, typing.Sequence
         ):
@@ -901,34 +932,34 @@ def __init__(self, in_event_shape, out_event_shape):
         super().__init__()
 
     @property
-    def in_event_shape(self):
+    def in_event_shape(self) -> tuple[int, ...]:
         return self._in_event_shape
 
     @property
-    def out_event_shape(self):
+    def out_event_shape(self) -> tuple[int, ...]:
         return self._out_event_shape
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Independent:
         return variable.Independent(variable.real, len(self._in_event_shape))
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Independent:
         return variable.Independent(variable.real, len(self._out_event_shape))
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.reshape(
             tuple(x.shape)[: x.dim() - len(self._in_event_shape)]
             + self._out_event_shape
         )
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.reshape(
             tuple(y.shape)[: y.dim() - len(self._out_event_shape)]
             + self._in_event_shape
         )
 
-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if len(shape) < len(self._in_event_shape):
             raise ValueError(
                 f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}"
@@ -943,7 +974,7 @@ def _forward_shape(self, shape):
             tuple(shape[: -len(self._in_event_shape)]) + self._out_event_shape
         )
 
-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if len(shape) < len(self._out_event_shape):
             raise ValueError(
                 f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}"
@@ -958,7 +989,7 @@ def _inverse_shape(self, shape):
             tuple(shape[: -len(self._out_event_shape)]) + self._in_event_shape
         )
 
-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         shape = x.shape[: x.dim() - len(self._in_event_shape)]
         return paddle.zeros(shape, dtype=x.dtype)
 
@@ -989,20 +1020,20 @@ class SigmoidTransform(Transform):
     """
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Real:
         return variable.real
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         return variable.Variable(False, 0, constraint.Range(0.0, 1.0))
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return F.sigmoid(x)
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.log() - (-y).log1p()
 
-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return -F.softplus(-x) - F.softplus(x)
 
 
@@ -1030,31 +1061,32 @@ class SoftmaxTransform(Transform):
             [[-1.09861231, -1.09861231, -1.09861231],
              [-1.09861231, -1.09861231, -1.09861231]])
     """
+
     _type = Type.OTHER
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Independent:
         return variable.Independent(variable.real, 1)
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         return variable.Variable(False, 1, constraint.simplex)
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         x = (x - x.max(-1, keepdim=True)[0]).exp()
         return x / x.sum(-1, keepdim=True)
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.log()
 
-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if len(shape) < 1:
             raise ValueError(
                 f"Expected length of shape is grater than 1, but got {len(shape)}"
             )
         return shape
 
-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if len(shape) < 1:
             raise ValueError(
                 f"Expected length of shape is grater than 1, but got {len(shape)}"
@@ -1103,7 +1135,7 @@ class StackTransform(Transform):
             [3. , 1.79175949]])
     """
 
-    def __init__(self, transforms, axis=0):
+    def __init__(self, transforms: Sequence[Transform], axis: int = 0) -> None:
         if not transforms or not isinstance(transforms, typing.Sequence):
             raise TypeError(
                 f"Expected 'transforms' is Sequence[Transform], but got {type(transforms)}."
@@ -1118,18 +1150,18 @@ def __init__(self, transforms, axis=0):
         self._transforms = transforms
         self._axis = axis
 
-    def _is_injective(self):
+    def _is_injective(self) -> bool:
         return all(t._is_injective() for t in self._transforms)
 
     @property
-    def transforms(self):
+    def transforms(self) -> Sequence[Transform]:
         return self._transforms
 
     @property
-    def axis(self):
+    def axis(self) -> int:
         return self._axis
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         self._check_size(x)
         return paddle.stack(
             [
@@ -1139,7 +1171,7 @@ def _forward(self, x):
             self._axis,
         )
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         self._check_size(y)
         return paddle.stack(
             [
@@ -1149,7 +1181,7 @@ def _inverse(self, y):
             self._axis,
         )
 
-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         self._check_size(x)
         return paddle.stack(
             [
@@ -1159,7 +1191,7 @@ def _forward_log_det_jacobian(self, x):
             self._axis,
         )
 
-    def _check_size(self, v):
+    def _check_size(self, v: Tensor) -> None:
         if not (-v.dim() <= self._axis < v.dim()):
             raise ValueError(
                 f'Input dimensions {v.dim()} should be grater than stack '
@@ -1172,11 +1204,11 @@ def _check_size(self, v):
             )
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Stack:
         return variable.Stack([t._domain for t in self._transforms], self._axis)
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Stack:
         return variable.Stack(
             [t._codomain for t in self._transforms], self._axis
         )
@@ -1208,7 +1240,7 @@ class StickBreakingTransform(Transform):
 
     _type = Type.BIJECTION
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
         z = F.sigmoid(x - offset.log())
         z_cumprod = (1 - z).cumprod(-1)
@@ -1216,35 +1248,35 @@ def _forward(self, x):
             z_cumprod, [0] * 2 * (len(x.shape) - 1) + [1, 0], value=1
         )
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         y_crop = y[..., :-1]
         offset = y.shape[-1] - paddle.ones([y_crop.shape[-1]]).cumsum(-1)
         sf = 1 - y_crop.cumsum(-1)
         x = y_crop.log() - sf.log() + offset.log()
         return x
 
-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         y = self.forward(x)
         offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
         x = x - offset.log()
         return (-x + F.log_sigmoid(x) + y[..., :-1].log()).sum(-1)
 
-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if not shape:
             raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
-        return shape[:-1] + (shape[-1] + 1,)
+        return (*shape[:-1], shape[-1] + 1)
 
-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if not shape:
             raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
-        return shape[:-1] + (shape[-1] - 1,)
+        return (*shape[:-1], shape[-1] - 1)
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Independent:
         return variable.Independent(variable.real, 1)
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         return variable.Variable(False, 1, constraint.simplex)
 
 
@@ -1280,23 +1312,24 @@ class TanhTransform(Transform):
              [6.61441946 , 8.61399269 , 10.61451530]])
             >>> # doctest: -SKIP
     """
+
     _type = Type.BIJECTION
 
     @property
-    def _domain(self):
+    def _domain(self) -> variable.Real:
         return variable.real
 
     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         return variable.Variable(False, 0, constraint.Range(-1.0, 1.0))
 
-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.tanh()
 
-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.atanh()
 
-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         """We implicitly rely on _forward_log_det_jacobian rather than
         explicitly implement ``_inverse_log_det_jacobian`` since directly using
         ``-tf.math.log1p(-tf.square(y))`` has lower numerical precision.
diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py
index 595b0580f63e5..974593d8a862e 100644
--- a/python/paddle/distribution/transformed_distribution.py
+++ b/python/paddle/distribution/transformed_distribution.py
@@ -11,11 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 import typing
+from typing import TYPE_CHECKING, Sequence
 
 from paddle.distribution import distribution, independent, transform
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.distribution.distribution import Distribution
+    from paddle.distribution.transform import Transform
+
 
 class TransformedDistribution(distribution.Distribution):
     r"""
@@ -48,8 +55,12 @@ class TransformedDistribution(distribution.Distribution):
             -1.64333570)
             >>> # doctest: -SKIP
     """
+    base: Distribution
+    transforms: Sequence[Transform]
 
-    def __init__(self, base, transforms):
+    def __init__(
+        self, base: Distribution, transforms: Sequence[Transform]
+    ) -> None:
         if not isinstance(base, distribution.Distribution):
             raise TypeError(
                 f"Expected type of 'base' is Distribution, but got {type(base)}."
@@ -92,7 +103,7 @@ def __init__(self, base, transforms):
             ],
         )
 
-    def sample(self, shape=()):
+    def sample(self, shape: Sequence[int] = ()) -> Tensor:
         """Sample from ``TransformedDistribution``.
 
         Args:
@@ -106,7 +117,7 @@ def sample(self, shape=()):
             x = t.forward(x)
         return x
 
-    def rsample(self, shape=()):
+    def rsample(self, shape: Sequence[int] = ()) -> Tensor:
         """Reparameterized sample from ``TransformedDistribution``.
 
         Args:
@@ -120,7 +131,7 @@ def rsample(self, shape=()):
             x = t.forward(x)
         return x
 
-    def log_prob(self, value):
+    def log_prob(self, value: Tensor) -> Tensor:
         """The log probability evaluated at value.
 
         Args:
@@ -145,5 +156,5 @@ def log_prob(self, value):
         return log_prob
 
 
-def _sum_rightmost(value, n):
+def _sum_rightmost(value: Tensor, n: int) -> Tensor:
    return value.sum(list(range(-n, 0))) if n > 0 else value
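
Reviewer note, not part of the patch: the three `@overload` declarations added to `Transform.__call__` describe dispatch behavior that already exists at runtime — a `Tensor` argument is passed to `forward`, a `Distribution` is wrapped in a `TransformedDistribution`, and another `Transform` is composed into a `ChainTransform`. A minimal sketch against the public `paddle.distribution` API (the variable names and values are illustrative only):

```python
import paddle
from paddle.distribution import ExpTransform, Normal

t = ExpTransform()

# Tensor input: dispatches to t.forward, so y = exp(x).
y = t(paddle.to_tensor([0.0, 1.0]))

# Distribution input: builds a TransformedDistribution
# (a Normal pushed through exp, i.e. a log-normal).
d = t(Normal(loc=0.0, scale=1.0))

# Transform input: composes a ChainTransform, here x -> exp(exp(x)).
c = t(ExpTransform())
```

With the annotations in this patch, a static checker should resolve `y` to `Tensor`, `d` to `TransformedDistribution`, and `c` to `ChainTransform`, matching what `__call__` actually returns for each input type.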