From 5078079ddedcfba36b8d373ab3c9a0d7cab4dd10 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 9 Jul 2024 11:46:31 +0800 Subject: [PATCH 1/4] fix --- .../distribution/multivariate_normal.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 55e6a3cc7f325..a8ecaedf9c298 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import math from collections.abc import Sequence +from typing import TYPE_CHECKING import paddle from paddle.distribution import distribution +if TYPE_CHECKING: + from paddle import Tensor + class MultivariateNormal(distribution.Distribution): r"""The Multivariate Normal distribution is a type multivariate continuous distribution defined on the real set, with parameter: `loc` and any one @@ -40,11 +45,11 @@ class MultivariateNormal(distribution.Distribution): Args: loc(int|float|Tensor): The mean of Multivariate Normal distribution. If the input data type is int or float, the data type of `loc` will be convert to a 1-D Tensor the paddle global default dtype. - covariance_matrix(Tensor): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert + covariance_matrix(Tensor|None): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert to be the same as the type of loc. - precision_matrix(Tensor): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to be the same as the + precision_matrix(Tensor|None): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to be the same as the type of loc. - scale_tril(Tensor): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be + scale_tril(Tensor|None): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be convert to be the same as the type of loc. Examples: @@ -87,10 +92,10 @@ class MultivariateNormal(distribution.Distribution): def __init__( self, - loc, - covariance_matrix=None, - precision_matrix=None, - scale_tril=None, + loc: float | Tensor, + covariance_matrix: Tensor | None = None, + precision_matrix: Tensor | None = None, + scale_tril: Tensor | None = None, ): self.dtype = paddle.get_default_dtype() if isinstance(loc, (float, int)): @@ -172,7 +177,7 @@ def __init__( super().__init__(batch_shape, event_shape) @property - def mean(self): + def mean(self) -> Tensor: """Mean of Multivariate Normal distribution. Returns: @@ -181,7 +186,7 @@ def mean(self): return self.loc @property - def variance(self): + def variance(self) -> Tensor: """Variance of Multivariate Normal distribution. Returns: @@ -193,7 +198,7 @@ def variance(self): .expand(self._batch_shape + self._event_shape) ) - def sample(self, shape=()): + def sample(self, shape: Sequence[int] = ()) -> Tensor: """Generate Multivariate Normal samples of the specified shape. The final shape would be ``sample_shape + batch_shape + event_shape``. Args: @@ -205,7 +210,7 @@ def sample(self, shape=()): with paddle.no_grad(): return self.rsample(shape) - def rsample(self, shape=()): + def rsample(self, shape: Sequence[int] = ()) -> Tensor: """Generate Multivariate Normal samples of the specified shape. The final shape would be ``sample_shape + batch_shape + event_shape``. Args: @@ -222,7 +227,7 @@ def rsample(self, shape=()): self._unbroadcasted_scale_tril, eps.unsqueeze(-1) ).squeeze(-1) - def log_prob(self, value): + def log_prob(self, value: Tensor) -> Tensor: """Log probability density function. Args: @@ -245,7 +250,7 @@ def log_prob(self, value): - half_log_det ) - def prob(self, value): + def prob(self, value: Tensor) -> Tensor: """Probability density function. Args: @@ -256,7 +261,7 @@ def prob(self, value): """ return paddle.exp(self.log_prob(value)) - def entropy(self): + def entropy(self) -> Tensor: r"""Shannon entropy in nats. The entropy is @@ -286,7 +291,7 @@ def entropy(self): else: return H.expand(self._batch_shape) - def kl_divergence(self, other): + def kl_divergence(self, other: MultivariateNormal) -> Tensor: r"""The KL-divergence between two poisson distributions with the same `batch_shape` and `event_shape`. The probability density function (pdf) is @@ -344,7 +349,7 @@ def kl_divergence(self, other): ) -def precision_to_scale_tril(P): +def precision_to_scale_tril(P: Tensor) -> Tensor: """Convert precision matrix to scale tril matrix Args: @@ -363,7 +368,7 @@ def precision_to_scale_tril(P): return L -def batch_mahalanobis(bL, bx): +def batch_mahalanobis(bL: Tensor, bx: Tensor) -> Tensor: r""" Computes the squared Mahalanobis distance of the Multivariate Normal distribution with cholesky decomposition of the covariance matrix. Accepts batches for both bL and bx. From ea0fe94a9f3c4f8e2f34aabc079a7feae1152018 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 9 Jul 2024 12:55:46 +0800 Subject: [PATCH 2/4] fix --- python/paddle/distribution/multivariate_normal.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index a8ecaedf9c298..140bae2fd54e4 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -89,6 +89,10 @@ class MultivariateNormal(distribution.Distribution): Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 1.55541301) """ + loc: Tensor + covariance_matrix: Tensor | None + precision_matrix: Tensor | None + scale_tril: Tensor | None def __init__( self, From 10cbead99e1023645befbff17d1e711db1dc7eab Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 9 Jul 2024 13:01:32 +0800 Subject: [PATCH 3/4] fix --- python/paddle/distribution/multivariate_normal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 140bae2fd54e4..d522e6a988d88 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -21,7 +21,7 @@ from paddle.distribution import distribution if TYPE_CHECKING: - from paddle import Tensor + from paddle import Tensor, dtype class MultivariateNormal(distribution.Distribution): @@ -93,6 +93,7 @@ class MultivariateNormal(distribution.Distribution): covariance_matrix: Tensor | None precision_matrix: Tensor | None scale_tril: Tensor | None + dtype: dtype def __init__( self, From 7e9f9ce161dfd62bbe2c2edc1e400c0647724231 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Sun, 14 Jul 2024 01:57:04 +0800 Subject: [PATCH 4/4] use `_DTypeLiteral` --- python/paddle/distribution/multivariate_normal.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index d522e6a988d88..9f5f27931d1e8 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -18,10 +18,12 @@ from typing import TYPE_CHECKING import paddle +from paddle.base.data_feeder import convert_dtype from paddle.distribution import distribution if TYPE_CHECKING: - from paddle import Tensor, dtype + from paddle import Tensor + from paddle._typing.dtype_like import _DTypeLiteral class MultivariateNormal(distribution.Distribution): @@ -89,11 +91,12 @@ class MultivariateNormal(distribution.Distribution): Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 1.55541301) """ + loc: Tensor covariance_matrix: Tensor | None precision_matrix: Tensor | None scale_tril: Tensor | None - dtype: dtype + dtype: _DTypeLiteral def __init__( self, @@ -106,7 +109,7 @@ def __init__( if isinstance(loc, (float, int)): loc = paddle.to_tensor([loc], dtype=self.dtype) else: - self.dtype = loc.dtype + self.dtype = convert_dtype(loc.dtype) if loc.dim() < 1: loc = loc.reshape((1,)) self.covariance_matrix = None