diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 55e6a3cc7f325..9f5f27931d1e8 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -11,13 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import math from collections.abc import Sequence +from typing import TYPE_CHECKING import paddle +from paddle.base.data_feeder import convert_dtype from paddle.distribution import distribution +if TYPE_CHECKING: + from paddle import Tensor + from paddle._typing.dtype_like import _DTypeLiteral + class MultivariateNormal(distribution.Distribution): r"""The Multivariate Normal distribution is a type multivariate continuous distribution defined on the real set, with parameter: `loc` and any one @@ -40,11 +47,11 @@ class MultivariateNormal(distribution.Distribution): Args: loc(int|float|Tensor): The mean of Multivariate Normal distribution. If the input data type is int or float, the data type of `loc` will be convert to a 1-D Tensor the paddle global default dtype. - covariance_matrix(Tensor): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert + covariance_matrix(Tensor|None): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert to be the same as the type of loc. - precision_matrix(Tensor): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to be the same as the + precision_matrix(Tensor|None): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to be the same as the type of loc. - scale_tril(Tensor): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be + scale_tril(Tensor|None): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be convert to be the same as the type of loc. Examples: @@ -85,18 +92,24 @@ class MultivariateNormal(distribution.Distribution): 1.55541301) """ + loc: Tensor + covariance_matrix: Tensor | None + precision_matrix: Tensor | None + scale_tril: Tensor | None + dtype: _DTypeLiteral + def __init__( self, - loc, - covariance_matrix=None, - precision_matrix=None, - scale_tril=None, + loc: float | Tensor, + covariance_matrix: Tensor | None = None, + precision_matrix: Tensor | None = None, + scale_tril: Tensor | None = None, ): self.dtype = paddle.get_default_dtype() if isinstance(loc, (float, int)): loc = paddle.to_tensor([loc], dtype=self.dtype) else: - self.dtype = loc.dtype + self.dtype = convert_dtype(loc.dtype) if loc.dim() < 1: loc = loc.reshape((1,)) self.covariance_matrix = None @@ -172,7 +185,7 @@ def __init__( super().__init__(batch_shape, event_shape) @property - def mean(self): + def mean(self) -> Tensor: """Mean of Multivariate Normal distribution. Returns: @@ -181,7 +194,7 @@ def mean(self): return self.loc @property - def variance(self): + def variance(self) -> Tensor: """Variance of Multivariate Normal distribution. Returns: @@ -193,7 +206,7 @@ def variance(self): .expand(self._batch_shape + self._event_shape) ) - def sample(self, shape=()): + def sample(self, shape: Sequence[int] = ()) -> Tensor: """Generate Multivariate Normal samples of the specified shape. The final shape would be ``sample_shape + batch_shape + event_shape``. Args: @@ -205,7 +218,7 @@ def sample(self, shape=()): with paddle.no_grad(): return self.rsample(shape) - def rsample(self, shape=()): + def rsample(self, shape: Sequence[int] = ()) -> Tensor: """Generate Multivariate Normal samples of the specified shape. The final shape would be ``sample_shape + batch_shape + event_shape``. Args: @@ -222,7 +235,7 @@ def rsample(self, shape=()): self._unbroadcasted_scale_tril, eps.unsqueeze(-1) ).squeeze(-1) - def log_prob(self, value): + def log_prob(self, value: Tensor) -> Tensor: """Log probability density function. Args: @@ -245,7 +258,7 @@ def log_prob(self, value): - half_log_det ) - def prob(self, value): + def prob(self, value: Tensor) -> Tensor: """Probability density function. Args: @@ -256,7 +269,7 @@ def prob(self, value): """ return paddle.exp(self.log_prob(value)) - def entropy(self): + def entropy(self) -> Tensor: r"""Shannon entropy in nats. The entropy is @@ -286,7 +299,7 @@ def entropy(self): else: return H.expand(self._batch_shape) - def kl_divergence(self, other): + def kl_divergence(self, other: MultivariateNormal) -> Tensor: r"""The KL-divergence between two poisson distributions with the same `batch_shape` and `event_shape`. The probability density function (pdf) is @@ -344,7 +357,7 @@ def kl_divergence(self, other): ) -def precision_to_scale_tril(P): +def precision_to_scale_tril(P: Tensor) -> Tensor: """Convert precision matrix to scale tril matrix Args: @@ -363,7 +376,7 @@ def precision_to_scale_tril(P): return L -def batch_mahalanobis(bL, bx): +def batch_mahalanobis(bL: Tensor, bx: Tensor) -> Tensor: r""" Computes the squared Mahalanobis distance of the Multivariate Normal distribution with cholesky decomposition of the covariance matrix. Accepts batches for both bL and bx.