Skip to content

Commit

Permalink
[Typing][B-18] Add type annotations for `python/paddle/distribution/l…
Browse files Browse the repository at this point in the history
…aplace.py` (PaddlePaddle#65784)
  • Loading branch information
ooooo-create authored and lixcli committed Jul 22, 2024
1 parent c2a4880 commit f56e2d3
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions python/paddle/distribution/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import numbers
from typing import TYPE_CHECKING, Sequence

import numpy as np

import paddle
from paddle.base import framework
from paddle.distribution import distribution

if TYPE_CHECKING:
from paddle import Tensor


class Laplace(distribution.Distribution):
r"""
Expand Down Expand Up @@ -52,8 +57,10 @@ class Laplace(distribution.Distribution):
1.31554604)
"""
loc: Tensor
scale: Tensor

def __init__(self, loc, scale):
def __init__(self, loc: float | Tensor, scale: float | Tensor) -> None:
if not isinstance(
loc, (numbers.Real, framework.Variable, paddle.pir.Value)
):
Expand Down Expand Up @@ -84,7 +91,7 @@ def __init__(self, loc, scale):
super().__init__(self.loc.shape)

@property
def mean(self):
def mean(self) -> Tensor:
"""Mean of distribution.
Returns:
Expand All @@ -93,7 +100,7 @@ def mean(self):
return self.loc

@property
def stddev(self):
def stddev(self) -> Tensor:
r"""Standard deviation.
The stddev is
Expand All @@ -111,7 +118,7 @@ def stddev(self):
return (2**0.5) * self.scale

@property
def variance(self):
def variance(self) -> Tensor:
r"""Variance of distribution.
The variance is
Expand All @@ -128,7 +135,9 @@ def variance(self):
"""
return self.stddev.pow(2)

def _validate_value(self, value):
def _validate_value(
self, value: float | Tensor
) -> tuple[Tensor, Tensor, Tensor]:
"""Argument dimension check for distribution methods such as `log_prob`,
`cdf` and `icdf`.
Expand All @@ -155,7 +164,7 @@ def _validate_value(self, value):

return loc, scale, value

def log_prob(self, value):
def log_prob(self, value: float | Tensor) -> Tensor:
r"""Log probability density/mass function.
The log_prob is
Expand Down Expand Up @@ -191,7 +200,7 @@ def log_prob(self, value):

return log_scale - paddle.abs(value - loc) / scale

def entropy(self):
def entropy(self) -> Tensor:
r"""Entropy of Laplace distribution.
The entropy is:
Expand All @@ -218,7 +227,7 @@ def entropy(self):
"""
return 1 + paddle.log(2 * self.scale)

def cdf(self, value):
def cdf(self, value: float | Tensor) -> Tensor:
r"""Cumulative distribution function.
The cdf is
Expand Down Expand Up @@ -257,7 +266,7 @@ def cdf(self, value):

return 0.5 - iterm

def icdf(self, value):
def icdf(self, value: float | Tensor) -> Tensor:
r"""Inverse Cumulative distribution function.
The icdf is
Expand Down Expand Up @@ -291,7 +300,7 @@ def icdf(self, value):

return loc - scale * (term).sign() * paddle.log1p(-2 * term.abs())

def sample(self, shape=()):
def sample(self, shape: Sequence[int] = ()) -> Tensor:
r"""Generate samples of the specified shape.
Args:
Expand All @@ -314,7 +323,7 @@ def sample(self, shape=()):
with paddle.no_grad():
return self.rsample(shape)

def rsample(self, shape):
def rsample(self, shape: Sequence[int]) -> Tensor:
r"""Reparameterized sample.
Args:
Expand Down Expand Up @@ -346,7 +355,7 @@ def rsample(self, shape):
-uniform.abs()
)

def _get_eps(self):
def _get_eps(self) -> float:
"""
Get the eps of certain data type.
Expand All @@ -366,7 +375,7 @@ def _get_eps(self):

return eps

def kl_divergence(self, other):
def kl_divergence(self, other: Laplace) -> Tensor:
r"""Calculate the KL divergence KL(self || other) with two Laplace instances.
The kl_divergence between two Laplace distribution is
Expand Down

0 comments on commit f56e2d3

Please sign in to comment.