diff --git a/python/paddle/distribution/beta.py b/python/paddle/distribution/beta.py index 8045be4b1fa923..f5bf8a13275f01 100644 --- a/python/paddle/distribution/beta.py +++ b/python/paddle/distribution/beta.py @@ -150,7 +150,7 @@ def sample(self, shape: Sequence[int] = ()) -> Tensor: shape (Sequence[int], optional): Sample shape. Returns: - Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. + Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. """ shape = shape if isinstance(shape, tuple) else tuple(shape) return paddle.squeeze(self._dirichlet.sample(shape)[..., 0], axis=-1) diff --git a/python/paddle/distribution/lkj_cholesky.py b/python/paddle/distribution/lkj_cholesky.py index fc2cb5c6c33e6f..b2dd807e709b26 100644 --- a/python/paddle/distribution/lkj_cholesky.py +++ b/python/paddle/distribution/lkj_cholesky.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import math import operator from collections.abc import Sequence from functools import reduce +from typing import TYPE_CHECKING, Literal import paddle from paddle.base.data_feeder import check_type, convert_dtype @@ -24,6 +26,11 @@ from paddle.distribution.beta import Beta from paddle.framework import in_dynamic_mode +if TYPE_CHECKING: + from paddle import Tensor + from paddle._typing import DTypeLike + + __all__ = ["LKJCholesky"] @@ -95,7 +102,7 @@ def vec_to_tril_matrix( return matrix -def tril_matrix_to_vec(mat: paddle.Tensor, diag: int = 0) -> paddle.Tensor: +def tril_matrix_to_vec(mat: Tensor, diag: int = 0) -> Tensor: r""" Convert a `D x D` matrix or a batch of matrices into a (batched) vector which comprises of lower triangular elements from the matrix in row order. @@ -139,7 +146,17 @@ class LKJCholesky(distribution.Distribution): [3, 3] """ - def __init__(self, dim=2, concentration=1.0, sample_method="onion"): + concentration: Tensor + dtype: DTypeLike + dim: int + sample_method: Literal["onion", "cvine"] + + def __init__( + self, + dim: int = 2, + concentration: float = 1.0, + sample_method: Literal["onion", "cvine"] = "onion", + ) -> None: if not in_dynamic_mode(): check_type( dim, @@ -209,14 +226,14 @@ def __init__(self, dim=2, concentration=1.0, sample_method="onion"): raise ValueError("`method` should be one of 'cvine' or 'onion'.") super().__init__(batch_shape, event_shape) - def _onion(self, sample_shape): + def _onion(self, sample_shape: Sequence[int]) -> Tensor: """Generate a sample using the "onion" method. Args: sample_shape (tuple): The shape of the samples to be generated. Returns: - w (paddle.Tensor): The Cholesky factor of the sampled correlation matrix. + w (Tensor): The Cholesky factor of the sampled correlation matrix. """ # Sample y from the Beta distribution y = self._beta.sample(sample_shape).unsqueeze(-1) @@ -249,14 +266,14 @@ def _onion(self, sample_shape): w += paddle.diag_embed(diag_elems) return w - def _cvine(self, sample_shape): + def _cvine(self, sample_shape: Sequence[int]) -> Tensor: """Generate a sample using the "cvine" method. Args: sample_shape (tuple): The shape of the samples to be generated. Returns: - r (paddle.Tensor): The Cholesky factor of the sampled correlation matrix. + r (Tensor): The Cholesky factor of the sampled correlation matrix. """ # Sample beta and calculate partial correlations @@ -308,7 +325,7 @@ def _cvine(self, sample_shape): r = r.reshape((flatten_shape // last_dim, self.dim, self.dim)) return r - def sample(self, sample_shape=()): + def sample(self, sample_shape: Sequence[int] = ()) -> Tensor: """Generate a sample using the specified sampling method.""" if not isinstance(sample_shape, Sequence): raise TypeError('sample shape must be Sequence object.') @@ -334,14 +351,14 @@ def sample(self, sample_shape=()): return res.reshape(output_shape) - def log_prob(self, value): + def log_prob(self, value: Tensor) -> Tensor: r"""Compute the log probability density of the given Cholesky factor under the LKJ distribution. Args: - value (paddle.Tensor): The Cholesky factor of the correlation matrix for which the log probability density is to be computed. + value (Tensor): The Cholesky factor of the correlation matrix for which the log probability density is to be computed. Returns: - log_prob (paddle.Tensor): The log probability density of the given Cholesky factor under the LKJ distribution. + log_prob (Tensor): The log probability density of the given Cholesky factor under the LKJ distribution. """ # 1.Compute the order vector. diag_elems = paddle.diagonal(value, offset=0, axis1=-1, axis2=-2)[