Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][B-25] Add type annotations for python/paddle/distribution/student_t.py #65853

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions python/paddle/distribution/student_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import math
from collections.abc import Sequence
from typing import TYPE_CHECKING

import paddle
from paddle.base.data_feeder import check_type, convert_dtype
from paddle.base.framework import Variable
from paddle.distribution import Gamma, distribution
from paddle.framework import in_dynamic_mode

if TYPE_CHECKING:
from paddle import Tensor, dtype


class StudentT(distribution.Distribution):
r"""
Expand Down Expand Up @@ -49,7 +54,7 @@ class StudentT(distribution.Distribution):
1-D Tensor with paddle global default dtype. Supported dtype: float32, float64.
scale (float|Tensor): The scale of the distribution, which should be non-negative. If the input data type is float, the data type
of `scale` will be converted to a 1-D Tensor with paddle global default dtype. Supported dtype: float32, float64.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
name(str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Examples:
.. code-block:: python
Expand Down Expand Up @@ -83,8 +88,19 @@ class StudentT(distribution.Distribution):
[1.52126312, 2.32064891])

"""

def __init__(self, df, loc, scale, name=None):
df: Tensor
loc: Tensor
scale: Tensor
name: str
dtype: dtype

def __init__(
self,
df: float | Tensor,
loc: float | Tensor,
scale: float | Tensor,
name: str | None = None,
) -> None:
if not in_dynamic_mode():
check_type(
df,
Expand Down Expand Up @@ -148,7 +164,7 @@ def __init__(self, df, loc, scale, name=None):
super().__init__(batch_shape)
self._chi2 = Gamma(0.5 * self.df, paddle.full_like(self.df, 0.5))

def _check_nonnegative(self, value):
def _check_nonnegative(self, value: Tensor) -> bool:
"""Check the non-negative constraint for input parameters

Args:
Expand All @@ -160,7 +176,7 @@ def _check_nonnegative(self, value):
return (value >= 0.0).all()

@property
def mean(self):
def mean(self) -> Tensor:
"""Mean of StudentT distribution.

Returns:
Expand All @@ -173,7 +189,7 @@ def mean(self):
)

@property
def variance(self):
def variance(self) -> Tensor:
"""Variance of StudentT distribution.

Returns:
Expand All @@ -192,7 +208,7 @@ def variance(self):
)
return var

def sample(self, shape=()):
def sample(self, shape: Sequence[int] = ()) -> Tensor:
"""Generate StudentT samples of the specified shape. The final shape would be ``shape+batch_shape`` .

Args:
Expand All @@ -210,7 +226,7 @@ def sample(self, shape=()):
x = z * paddle.rsqrt(chi2 / self.df)
return self.loc + self.scale * x

def entropy(self):
def entropy(self) -> Tensor:
r"""Shannon entropy in nats.

The entropy is
Expand Down Expand Up @@ -245,7 +261,7 @@ def entropy(self):
+ lbeta
)

def log_prob(self, value):
def log_prob(self, value: Tensor) -> Tensor:
"""Log probability density function.

Args:
Expand All @@ -265,7 +281,7 @@ def log_prob(self, value):
)
return -0.5 * (self.df + 1.0) * paddle.log1p(y**2.0 / self.df) - Z

def prob(self, value):
def prob(self, value: Tensor) -> Tensor:
"""Probability density function.

Args:
Expand Down