Skip to content

Commit

Permalink
[Typing][B-19] Add type annotations for `python/paddle/distribution/l…
Browse files Browse the repository at this point in the history
…kj_cholesky.py` (PaddlePaddle#65785)
  • Loading branch information
ooooo-create authored and lixcli committed Jul 22, 2024
1 parent b2754d2 commit a4ca1a5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python/paddle/distribution/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def sample(self, shape: Sequence[int] = ()) -> Tensor:
shape (Sequence[int], optional): Sample shape.
Returns:
Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
"""
shape = shape if isinstance(shape, tuple) else tuple(shape)
return paddle.squeeze(self._dirichlet.sample(shape)[..., 0], axis=-1)
Expand Down
37 changes: 27 additions & 10 deletions python/paddle/distribution/lkj_cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import math
import operator
from collections.abc import Sequence
from functools import reduce
from typing import TYPE_CHECKING, Literal

import paddle
from paddle.base.data_feeder import check_type, convert_dtype
Expand All @@ -24,6 +26,11 @@
from paddle.distribution.beta import Beta
from paddle.framework import in_dynamic_mode

if TYPE_CHECKING:
from paddle import Tensor
from paddle._typing import DTypeLike


__all__ = ["LKJCholesky"]


Expand Down Expand Up @@ -95,7 +102,7 @@ def vec_to_tril_matrix(
return matrix


def tril_matrix_to_vec(mat: paddle.Tensor, diag: int = 0) -> paddle.Tensor:
def tril_matrix_to_vec(mat: Tensor, diag: int = 0) -> Tensor:
r"""
Convert a `D x D` matrix or a batch of matrices into a (batched) vector
which comprises of lower triangular elements from the matrix in row order.
Expand Down Expand Up @@ -139,7 +146,17 @@ class LKJCholesky(distribution.Distribution):
[3, 3]
"""

def __init__(self, dim=2, concentration=1.0, sample_method="onion"):
concentration: Tensor
dtype: DTypeLike
dim: int
sample_method: Literal["onion", "cvine"]

def __init__(
self,
dim: int = 2,
concentration: float = 1.0,
sample_method: Literal["onion", "cvine"] = "onion",
) -> None:
if not in_dynamic_mode():
check_type(
dim,
Expand Down Expand Up @@ -209,14 +226,14 @@ def __init__(self, dim=2, concentration=1.0, sample_method="onion"):
raise ValueError("`method` should be one of 'cvine' or 'onion'.")
super().__init__(batch_shape, event_shape)

def _onion(self, sample_shape):
def _onion(self, sample_shape: Sequence[int]) -> Tensor:
"""Generate a sample using the "onion" method.
Args:
sample_shape (tuple): The shape of the samples to be generated.
Returns:
w (paddle.Tensor): The Cholesky factor of the sampled correlation matrix.
w (Tensor): The Cholesky factor of the sampled correlation matrix.
"""
# Sample y from the Beta distribution
y = self._beta.sample(sample_shape).unsqueeze(-1)
Expand Down Expand Up @@ -249,14 +266,14 @@ def _onion(self, sample_shape):
w += paddle.diag_embed(diag_elems)
return w

def _cvine(self, sample_shape):
def _cvine(self, sample_shape: Sequence[int]) -> Tensor:
"""Generate a sample using the "cvine" method.
Args:
sample_shape (tuple): The shape of the samples to be generated.
Returns:
r (paddle.Tensor): The Cholesky factor of the sampled correlation matrix.
r (Tensor): The Cholesky factor of the sampled correlation matrix.
"""

# Sample beta and calculate partial correlations
Expand Down Expand Up @@ -308,7 +325,7 @@ def _cvine(self, sample_shape):
r = r.reshape((flatten_shape // last_dim, self.dim, self.dim))
return r

def sample(self, sample_shape=()):
def sample(self, sample_shape: Sequence[int] = ()) -> Tensor:
"""Generate a sample using the specified sampling method."""
if not isinstance(sample_shape, Sequence):
raise TypeError('sample shape must be Sequence object.')
Expand All @@ -334,14 +351,14 @@ def sample(self, sample_shape=()):

return res.reshape(output_shape)

def log_prob(self, value):
def log_prob(self, value: Tensor) -> Tensor:
r"""Compute the log probability density of the given Cholesky factor under the LKJ distribution.
Args:
value (paddle.Tensor): The Cholesky factor of the correlation matrix for which the log probability density is to be computed.
value (Tensor): The Cholesky factor of the correlation matrix for which the log probability density is to be computed.
Returns:
log_prob (paddle.Tensor): The log probability density of the given Cholesky factor under the LKJ distribution.
log_prob (Tensor): The log probability density of the given Cholesky factor under the LKJ distribution.
"""
# 1.Compute the order vector.
diag_elems = paddle.diagonal(value, offset=0, axis1=-1, axis2=-2)[
Expand Down

0 comments on commit a4ca1a5

Please sign in to comment.