[Typing][B-22] Add type annotations for python/paddle/distribution/multivariate_normal.py #65847

Merged · 5 commits · Jul 14, 2024
49 changes: 31 additions & 18 deletions python/paddle/distribution/multivariate_normal.py
@@ -11,13 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import math
from collections.abc import Sequence
from typing import TYPE_CHECKING

import paddle
from paddle.base.data_feeder import convert_dtype
from paddle.distribution import distribution

if TYPE_CHECKING:
from paddle import Tensor
from paddle._typing.dtype_like import _DTypeLiteral


class MultivariateNormal(distribution.Distribution):
r"""The Multivariate Normal distribution is a type multivariate continuous distribution defined on the real set, with parameter: `loc` and any one
@@ -40,11 +47,11 @@ class MultivariateNormal(distribution.Distribution):
Args:
loc(int|float|Tensor): The mean of Multivariate Normal distribution. If the input data type is int or float, the data type of `loc` will be
converted to a 1-D Tensor with the paddle global default dtype.
covariance_matrix(Tensor): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be converted
covariance_matrix(Tensor|None): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be converted
to be the same as the type of loc.
precision_matrix(Tensor): The inverse of the covariance matrix. The data type of `precision_matrix` will be converted to be the same as the
precision_matrix(Tensor|None): The inverse of the covariance matrix. The data type of `precision_matrix` will be converted to be the same as the
type of loc.
scale_tril(Tensor): The Cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be
scale_tril(Tensor|None): The Cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be
converted to be the same as the type of loc.

Examples:
@@ -85,18 +92,24 @@ class MultivariateNormal(distribution.Distribution):
1.55541301)
"""

loc: Tensor
covariance_matrix: Tensor | None
precision_matrix: Tensor | None
scale_tril: Tensor | None
dtype: _DTypeLiteral

def __init__(
self,
loc,
covariance_matrix=None,
precision_matrix=None,
scale_tril=None,
loc: float | Tensor,
covariance_matrix: Tensor | None = None,
precision_matrix: Tensor | None = None,
scale_tril: Tensor | None = None,
):
self.dtype = paddle.get_default_dtype()
if isinstance(loc, (float, int)):
loc = paddle.to_tensor([loc], dtype=self.dtype)
else:
self.dtype = loc.dtype
self.dtype = convert_dtype(loc.dtype)
if loc.dim() < 1:
loc = loc.reshape((1,))
self.covariance_matrix = None
@@ -172,7 +185,7 @@ def __init__(
super().__init__(batch_shape, event_shape)

@property
def mean(self):
def mean(self) -> Tensor:
"""Mean of Multivariate Normal distribution.

Returns:
@@ -181,7 +194,7 @@ def mean(self):
return self.loc

@property
def variance(self):
def variance(self) -> Tensor:
"""Variance of Multivariate Normal distribution.

Returns:
@@ -193,7 +206,7 @@ def variance(self):
.expand(self._batch_shape + self._event_shape)
)

def sample(self, shape=()):
def sample(self, shape: Sequence[int] = ()) -> Tensor:
"""Generate Multivariate Normal samples of the specified shape. The final shape would be ``sample_shape + batch_shape + event_shape``.

Args:
@@ -205,7 +218,7 @@ def sample(self, shape=()):
with paddle.no_grad():
return self.rsample(shape)

def rsample(self, shape=()):
def rsample(self, shape: Sequence[int] = ()) -> Tensor:
"""Generate Multivariate Normal samples of the specified shape. The final shape would be ``sample_shape + batch_shape + event_shape``.

Args:
@@ -222,7 +235,7 @@ def rsample(self, shape=()):
self._unbroadcasted_scale_tril, eps.unsqueeze(-1)
).squeeze(-1)

def log_prob(self, value):
def log_prob(self, value: Tensor) -> Tensor:
"""Log probability density function.

Args:
@@ -245,7 +258,7 @@ def log_prob(self, value):
- half_log_det
)

def prob(self, value):
def prob(self, value: Tensor) -> Tensor:
"""Probability density function.

Args:
@@ -256,7 +269,7 @@ def prob(self, value):
"""
return paddle.exp(self.log_prob(value))

def entropy(self):
def entropy(self) -> Tensor:
r"""Shannon entropy in nats.

The entropy is
@@ -286,7 +299,7 @@ def entropy(self):
else:
return H.expand(self._batch_shape)

def kl_divergence(self, other):
def kl_divergence(self, other: MultivariateNormal) -> Tensor:
r"""The KL-divergence between two poisson distributions with the same `batch_shape` and `event_shape`.

The probability density function (pdf) is
@@ -344,7 +357,7 @@ def kl_divergence(self, other):
)


def precision_to_scale_tril(P):
def precision_to_scale_tril(P: Tensor) -> Tensor:
"""Convert precision matrix to scale tril matrix

Args:
@@ -363,7 +376,7 @@ def precision_to_scale_tril(P):
return L


def batch_mahalanobis(bL, bx):
def batch_mahalanobis(bL: Tensor, bx: Tensor) -> Tensor:
r"""
Computes the squared Mahalanobis distance for the Multivariate Normal distribution using the Cholesky decomposition of the covariance matrix.
Accepts batches for both bL and bx.
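For readers trying out the newly annotated API, a minimal usage sketch follows; it is not part of the diff above. It assumes a Paddle build that exports `MultivariateNormal` from `paddle.distribution`, and the tensors below are illustrative only. The annotations affect static type checking; runtime behavior is unchanged. For context, `batch_mahalanobis(bL, bx)` evaluates the squared Mahalanobis form ||bL^-1 @ bx||^2, with `bL` the Cholesky factor of the covariance, which is the quadratic term `log_prob` is built on.

import paddle
from paddle.distribution import MultivariateNormal

# Illustrative 2-D example; `cov` must be symmetric positive-definite.
loc = paddle.to_tensor([0.0, 0.0])
cov = paddle.to_tensor([[2.0, 1.0], [1.0, 2.0]])

mvn = MultivariateNormal(loc, covariance_matrix=cov)

samples = mvn.sample((3,))    # Tensor of shape [3, 2]: sample_shape + event_shape
log_p = mvn.log_prob(loc)     # Tensor: log density evaluated at the mean
p = mvn.prob(loc)             # Tensor: exp(log_prob)
h = mvn.entropy()             # Tensor: Shannon entropy in nats

# The same distribution parameterized via its Cholesky factor; the KL divergence should be ~0.
other = MultivariateNormal(loc, scale_tril=paddle.linalg.cholesky(cov))
kl = mvn.kl_divergence(other)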