From b705a8d946266023ab15708e82a3a82d47d63da3 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Wed, 11 Oct 2023 10:49:19 +0800 Subject: [PATCH 01/29] add api and test --- python/paddle/distribution/__init__.py | 4 + .../distribution/continuous_bernoulli.py | 401 ++++++++++++++ .../distribution/multivariate_normal.py | 487 ++++++++++++++++++ .../test_distribution_continuous_bernoulli.py | 320 ++++++++++++ ...istribution_continuous_bernoulli_static.py | 335 ++++++++++++ .../test_distribution_multivariate_normal.py | 251 +++++++++ ...distribution_multivariate_normal_static.py | 282 ++++++++++ 7 files changed, 2080 insertions(+) create mode 100644 python/paddle/distribution/continuous_bernoulli.py create mode 100644 python/paddle/distribution/multivariate_normal.py create mode 100644 test/distribution/test_distribution_continuous_bernoulli.py create mode 100644 test/distribution/test_distribution_continuous_bernoulli_static.py create mode 100644 test/distribution/test_distribution_multivariate_normal.py create mode 100644 test/distribution/test_distribution_multivariate_normal_static.py diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index 68f4820da994d..6ae25d2012d6d 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -17,6 +17,7 @@ from paddle.distribution.beta import Beta from paddle.distribution.categorical import Categorical from paddle.distribution.cauchy import Cauchy +from paddle.distribution.continuous_bernoulli import ContinuousBernoulli from paddle.distribution.dirichlet import Dirichlet from paddle.distribution.distribution import Distribution from paddle.distribution.gumbel import Gumbel @@ -25,6 +26,7 @@ from paddle.distribution.kl import kl_divergence, register_kl from paddle.distribution.lognormal import LogNormal from paddle.distribution.multinomial import Multinomial +from paddle.distribution.multivariate_normal import MultivariateNormal from paddle.distribution.normal import Normal from paddle.distribution.transform import * # noqa: F403 from paddle.distribution.transformed_distribution import TransformedDistribution @@ -37,10 +39,12 @@ 'Beta', 'Categorical', 'Cauchy', + 'ContinuousBernoulli', 'Dirichlet', 'Distribution', 'ExponentialFamily', 'Multinomial', + 'MultivariateNormal', 'Normal', 'Uniform', 'kl_divergence', diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py new file mode 100644 index 0000000000000..2b1053b736924 --- /dev/null +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -0,0 +1,401 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Iterable + +import numpy as np + +import paddle +from paddle.distribution import distribution + + +class ContinuousBernoulli(distribution.Distribution): + r"""The Continuous Bernoulli distribution with probability parameter: `probability`. + + Mathematical details + + The probability density function (pdf) is + + .. math:: + + p(x;\lambda) = C(\lambda)\lambda^x (1-\lambda)^{1-x} + + In the above equation: + + * :math:`probability = \lambda`: is the probability. + * :math: `C(\lambda) = + \left\{ + \begin{aligned} + &2 & \text{ if $\lambda = \frac{1}{2}$} \\ + &\frac{2\tanh^{-1}(1-2\lambda)}{1 - 2\lambda} & \text{ otherwise} + \end{aligned} + \right.` + * :math:`x`: is continuous between 0 and 1 + + Args: + probability(int|float|np.ndarray|Tensor): The probability of Continuous Bernoulli distribution, which characterize the shape of the pdf. + The data type of `probability` will be convert to float32. + eps(float): Specify the bandwith of the unstable calculation region near 0.5 + + Examples: + .. code-block:: python + + >>> import paddle + >>> from paddle.distribution import ContinuousBernoulli + >>> rv = ContinuousBernoulli(paddle.to_tensor([0.2, 0.5])) + >>> print(rv.sample([2])) + Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0.09428147, 0.81438422], + [0.24624705, 0.93354583]]) + >>> print(rv.mean) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.38801414, 0.50000000]) + >>> print(rv.entropy()) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [-0.07641461, 0. ]) + >>> rv1 = ContinuousBernoulli(paddle.to_tensor([0.2, 0.8])) + >>> rv2 = ContinuousBernoulli(paddle.to_tensor([0.7, 0.5])) + >>> print(rv1.kl_divergence(rv2)) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.20103613, 0.07641447]) + """ + + def __init__(self, probability, eps=1e-4): + self.eps = paddle.to_tensor(eps) + self.dtype = 'float32' + self.probability = self._to_tensor(probability) + eps_prob = paddle.finfo(self.probability.dtype).eps + self.probability = paddle.clip( + self.probability, min=eps_prob, max=1 - eps_prob + ) + + if not self._check_constraint(self.probability): + raise ValueError( + 'Every element of input parameter `rate` should be nonnegative.' + ) + if self.probability.shape == []: + batch_shape = (1,) + else: + batch_shape = self.probability.shape + super().__init__(batch_shape) + + def _to_tensor(self, probability): + """Convert the input parameters into tensors with dtype = float32 + + Returns: + Tensor: converted probability. + """ + # convert type + if isinstance(probability, (float, int)): + probability = paddle.to_tensor([probability], dtype=self.dtype) + if isinstance(probability, np.ndarray): + probability = paddle.to_tensor(probability) + probability = paddle.cast(probability, dtype=self.dtype) + return probability + + def _check_constraint(self, value): + """Check the constraint for input parameters + + Args: + value (Tensor) + + Returns: + bool: pass or not. + """ + return (value >= 0).all() and (value <= 1).all() + + def _cut_support_region(self): + """Generate stable support region indicator (prob < 0.5 - self.eps && prob >= 0.5 + self.eps ) + + Returns: + Tensor: the element of the returned indicator tensor corresponding to stable region is True, and False otherwise + """ + return paddle.logical_or( + paddle.less_equal(self.probability, 0.5 - self.eps), + paddle.greater_than(self.probability, 0.5 + self.eps), + ) + + def _cut_probs(self): + """Cut the probability parameter with stable support region + + Returns: + Tensor: the element of the returned probability tensor corresponding to unstable region is set to be (0.5 - self.eps), and unchanged otherwise + """ + return paddle.where( + self._cut_support_region(), + self.probability, + (0.5 - self.eps) * paddle.ones_like(self.probability), + ) + + def _tanh_inverse(self, value): + """Calculate the tanh inverse of value + + Args: + value (Tensor) + + Returns: + Tensor: tanh inverse of value + """ + return 0.5 * (paddle.log1p(value) - paddle.log1p(-value)) + + def _log_constant(self): + """Calculate the logarithm of the constant factor :math:`C(lambda)` in the pdf of the Continuous Bernoulli distribution + + Returns: + Tensor: logarithm of the constant factor + """ + cut_probs = self._cut_probs() + cut_probs_below_half = paddle.where( + paddle.less_equal(cut_probs, paddle.to_tensor(0.5)), + cut_probs, + paddle.zeros_like(cut_probs), + ) + cut_probs_above_half = paddle.where( + paddle.greater_equal(cut_probs, paddle.to_tensor(0.5)), + cut_probs, + paddle.ones_like(cut_probs), + ) + log_constant_propose = paddle.log( + 2.0 * paddle.abs(self._tanh_inverse(1.0 - 2.0 * cut_probs)) + ) - paddle.where( + paddle.less_equal(cut_probs, paddle.to_tensor(0.5)), + paddle.log1p(-2.0 * cut_probs_below_half), + paddle.log(2.0 * cut_probs_above_half - 1.0), + ) + x = paddle.square(self.probability - 0.5) + taylor_expansion = ( + paddle.log(paddle.to_tensor(2.0)) + + (4.0 / 3.0 + 104.0 / 45.0 * x) * x + ) + return paddle.where( + self._cut_support_region(), log_constant_propose, taylor_expansion + ) + + @property + def mean(self): + """Mean of Continuous Bernoulli distribuion. + + Returns: + Tensor: mean value. + """ + cut_probs = self._cut_probs() + tmp = paddle.divide(cut_probs, 2.0 * cut_probs - 1.0) + propose = tmp + paddle.divide( + paddle.to_tensor(1.0, dtype=self.dtype), + 2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs), + ) + x = self.probability - 0.5 + taylor_expansion = ( + 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * paddle.square(x)) * x + ) + return paddle.where( + self._cut_support_region(), propose, taylor_expansion + ) + + @property + def variance(self): + """Variance of Continuous Bernoulli distribution. + + Returns: + Tensor: variance value. + """ + cut_probs = self._cut_probs() + tmp = paddle.divide( + paddle.square(cut_probs) - cut_probs, + paddle.square(1.0 - 2.0 * cut_probs), + ) + propose = tmp + paddle.divide( + paddle.to_tensor(1.0, dtype=self.dtype), + paddle.square(2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)), + ) + x = paddle.square(self.probability - 0.5) + taylor_expansion = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x + return paddle.where( + self._cut_support_region(), propose, taylor_expansion + ) + + def sample(self, shape=()): + """Generate Continuous Bernoulli samples of the specified shape. + + Args: + shape (Sequence[int], optional): Prepended shape of the generated samples. + + Returns: + Tensor, A tensor with prepended dimensions shape. The data type is float32. + """ + with paddle.no_grad(): + return self.rsample(shape) + + def rsample(self, shape=()): + """Generate Continuous Bernoulli samples of the specified shape. + + Args: + shape (Sequence[int], optional): Prepended shape of the generated samples. + + Returns: + Tensor, A tensor with prepended dimensions shape. The data type is float32. + """ + if not isinstance(shape, Iterable): + raise TypeError('sample shape must be Iterable object.') + shape = tuple(shape) + batch_shape = tuple(self.batch_shape) + output_shape = tuple(shape + batch_shape) + u = paddle.uniform(shape=output_shape, dtype=self.dtype, min=0, max=1) + return self.icdf(u) + + def log_prob(self, value): + """Log probability density function. + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: log probability. The data type is same with :attr:`value` . + """ + value = paddle.cast(value, dtype=self.dtype) + if not self._check_constraint(value): + raise ValueError( + 'Every element of input parameter `value` should be >= 0.0 and <= 1.0.' + ) + eps = paddle.finfo(self.probability.dtype).eps + cross_entropy = paddle.nan_to_num( + value * paddle.log(self.probability) + + (1.0 - value) * paddle.log(1 - self.probability), + neginf=-eps, + ) + return self._log_constant() + cross_entropy + + def prob(self, value): + """Probability density function. + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: probability. The data type is same with :attr:`value` . + """ + return paddle.exp(self.log_prob(value)) + + def entropy(self): + r"""Shannon entropy in nats. + + The entropy is + + .. math:: + + \mathcal{H}(X) = - \int_{x \in \Omega} p(x) \log{p(x)} dx + + In the above equation: + + * :math:\Omega: is the support of the distribution. + + Returns: + Tensor, Shannon entropy of Continuous Bernoulli distribution. The data type is float32. + """ + log_p = paddle.log(self.probability) + log_1_minus_p = paddle.log1p(-self.probability) + + return ( + -self._log_constant() + + self.mean * (log_1_minus_p - log_p) + - log_1_minus_p + ) + + def cdf(self, value): + """Cumulative distribution function + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: quantile of :attr:`value`. The data type is same with :attr:`value` . + """ + value = paddle.cast(value, dtype=self.dtype) + if not self._check_constraint(value): + raise ValueError( + 'Every element of input parameter `value` should be >= 0.0 and <= 1.0.' + ) + cut_probs = self._cut_probs() + cdfs = ( + paddle.pow(cut_probs, value) + * paddle.pow(1.0 - cut_probs, 1.0 - value) + + cut_probs + - 1.0 + ) / (2.0 * cut_probs - 1.0) + unbounded_cdfs = paddle.where(self._cut_support_region(), cdfs, value) + return paddle.where( + paddle.less_equal(value, paddle.to_tensor(0.0)), + paddle.zeros_like(value), + paddle.where( + paddle.greater_equal(value, paddle.to_tensor(1.0)), + paddle.ones_like(value), + unbounded_cdfs, + ), + ) + + def icdf(self, value): + """Inverse cumulative distribution function + + Args: + value (Tensor): The input tensor, meaning the quantile. + + Returns: + Tensor: p-value of the quantile. The data type is same with :attr:`value` . + """ + value = paddle.cast(value, dtype=self.dtype) + if not self._check_constraint(value): + raise ValueError( + 'Every element of input parameter `value` should be >= 0.0 and <= 1.0.' + ) + cut_probs = self._cut_probs() + return paddle.where( + self._cut_support_region(), + ( + paddle.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0)) + - paddle.log1p(-cut_probs) + ) + / (paddle.log(cut_probs) - paddle.log1p(-cut_probs)), + value, + ) + + def kl_divergence(self, other): + r"""The KL-divergence between two Continuous Bernoulli distributions. + + The probability density function (pdf) is + + .. math:: + + KL\_divergence(\lambda_1, \lambda_2) = \int_x p_1(x) \log{\frac{p_1(x)}{p_2(x)}} dx + + Args: + other (ContinuousBernoulli): instance of Continuous Bernoulli. + + Returns: + Tensor, kl-divergence between two Continuous Bernoulli distributions. The data type is float32. + + """ + + if self.batch_shape != other.batch_shape: + raise ValueError( + "KL divergence of two Continuous Bernoulli distributions should share the same `batch_shape`." + ) + part1 = -self.entropy() + log_q = paddle.log(other.probability) + log_1_minus_q = paddle.log1p(-other.probability) + part2 = -( + other._log_constant() + + self.mean * (log_q - log_1_minus_q) + + log_1_minus_q + ) + return part1 + part2 diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py new file mode 100644 index 0000000000000..89125f64121d1 --- /dev/null +++ b/python/paddle/distribution/multivariate_normal.py @@ -0,0 +1,487 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Iterable + +import numpy as np + +import paddle +from paddle.distribution import distribution + + +class MultivariateNormal(distribution.Distribution): + r"""The Multivariate Normal distribution with parameter: `loc` and any one of the following parameters: `covariance_matrix`, `precision_matrix`, `scale_tril`. + + Mathematical details + + The probability density function (pdf) is + + .. math:: + + p(X ;\mu, \Sigma) = \frac{1}{\sqrt{(2\pi)^k |\Sigma|}} \exp(-\frac{1}{2}(X - \mu)^{\intercal} \Sigma^{-1} (X - \mu)) + + In the above equation: + + * :math:`loc = \mu`: is the mean. + * :math:`covariance_matrix = \Sigma`: is the covariance matrix. + + Args: + loc(int|float|np.ndarray|Tensor): The mean of Multivariate Normal distribution. The data type of `loc` will be convert to float32. + covariance_matrix(Tensor): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert to float32. + precision_matrix(Tensor): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to float32. + scale_tril(Tensor): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be convert to float32. + + Examples: + .. code-block:: python + + >>> import paddle + >>> from paddle.distribution import MultivariateNormal + >>> rv = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) + >>> print(rv.sample([3, 2])) + Tensor(shape=[3, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[0.68554986, 3.85142398], + [1.88336682, 5.43841648]], + + [[5.32492065, 7.23725986], + [3.42192221, 4.83934879]], + + [[3.36775684, 4.46108866], + [4.58927441, 4.32255936]]]) + >>> print(rv.mean) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [2., 5.]) + >>> print(rv.entropy()) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 3.38718319) + >>> rv1 = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) + >>> rv2 = MultivariateNormal(loc=paddle.to_tensor([-1.,3.]), covariance_matrix=paddle.to_tensor([[3.,2.],[2.,3.]])) + >>> print(rv1.kl_divergence(rv2)) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 1.55541301) + """ + + def __init__( + self, + loc, + covariance_matrix=None, + precision_matrix=None, + scale_tril=None, + ): + self.dtype = 'float32' + if isinstance(loc, (float, int)): + loc = paddle.to_tensor([loc], dtype=self.dtype) + elif isinstance(loc, np.ndarray): + loc = paddle.to_tensor(loc, dtype=self.dtype) + if loc.dim() < 1: + loc = loc.reshape((1,)) + loc = paddle.cast(loc, dtype=self.dtype) + self.covariance_matrix = None + self.precision_matrix = None + self.scale_tril = None + if (covariance_matrix is not None) + (scale_tril is not None) + ( + precision_matrix is not None + ) != 1: + raise ValueError( + "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." + ) + + if scale_tril is not None: + if scale_tril.dim() < 2: + raise ValueError( + "scale_tril matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + scale_tril = paddle.cast(scale_tril, dtype=self.dtype) + batch_shape = paddle.broadcast_shape( + scale_tril.shape[:-2], loc.shape[:-1] + ) + self.scale_tril = scale_tril.expand( + batch_shape + [scale_tril.shape[-2], scale_tril.shape[-1]] + ) + elif covariance_matrix is not None: + if covariance_matrix.dim() < 2: + raise ValueError( + "covariance_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + covariance_matrix = paddle.cast(covariance_matrix, dtype=self.dtype) + batch_shape = paddle.broadcast_shape( + covariance_matrix.shape[:-2], loc.shape[:-1] + ) + self.covariance_matrix = covariance_matrix.expand( + batch_shape + + [covariance_matrix.shape[-2], covariance_matrix.shape[-1]] + ) + else: + if precision_matrix.dim() < 2: + raise ValueError( + "precision_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + precision_matrix = paddle.cast(precision_matrix, dtype=self.dtype) + batch_shape = paddle.broadcast_shape( + precision_matrix.shape[:-2], loc.shape[:-1] + ) + self.precision_matrix = precision_matrix.expand( + batch_shape + + [precision_matrix.shape[-2], precision_matrix.shape[-1]] + ) + self._check_constriants() + self.loc = loc.expand( + batch_shape + + [ + -1, + ] + ) + event_shape = self.loc.shape[-1:] + + if scale_tril is not None: + self._unbroadcasted_scale_tril = scale_tril + elif covariance_matrix is not None: + self._unbroadcasted_scale_tril = paddle.linalg.cholesky( + covariance_matrix + ) + else: + self._unbroadcasted_scale_tril = precision_to_scale_tril( + precision_matrix + ) + + super().__init__(batch_shape, event_shape) + + def _check_lower_triangular(self, value): + """Check whether the input is a lower triangular matrix + + Args: + value (Tensor): input matrix + + Return: + Tensor: indicator for lower triangular matrix + """ + tril = value.tril() + is_lower_triangular = paddle.cast( + (tril == value).reshape( + value.shape[:-2] + + [ + -1, + ] + ), + dtype=self.dtype, + ).min(-1, keepdim=True)[0] + is_positive_diag = paddle.cast( + (value.diagonal(axis1=-2, axis2=-1) > 0), dtype=self.dtype + ).min(-1, keepdim=True)[0] + return is_lower_triangular and is_positive_diag + + def _check_positive_definite(self, value): + """Check whether the input is a positive definite matrix + + Args: + value (Tensor): input matrix + + Return: + Tensor: indicator for positive definite matrix + """ + is_square = paddle.full( + shape=value.shape[:-2], + fill_value=(value.shape[-2] == value.shape[-1]), + dtype="bool", + ).all() + if not is_square: + raise ValueError( + "covariance_matrix or precision_matrix must be a sqaure matrix" + ) + new_perm = list(range(len(value.shape))) + new_perm[-1], new_perm[-2] = new_perm[-2], new_perm[-1] + is_symmetric = paddle.isclose( + value, value.transpose(new_perm), atol=1e-6 + ).all() + if not is_symmetric: + raise ValueError( + "covariance_matrix or precision_matrix must be a symmetric matrix" + ) + is_postive_definite = ( + paddle.cast(paddle.linalg.eigvalsh(value), dtype="float32") > 0 + ).all() + return is_postive_definite + + def _check_constriants(self): + """Check whether the matrix satisfy corresponding constriant + + Return: + Tensor: indicator for the pass of constriants check + """ + if self.scale_tril is not None: + check = self._check_lower_triangular(self.scale_tril) + if not check: + raise ValueError( + "scale_tril matrix must be a lower triangular matrix with positive diagonals" + ) + elif self.covariance_matrix is not None: + is_postive_definite = self._check_positive_definite( + self.covariance_matrix + ) + if not is_postive_definite: + raise ValueError( + "covariance_matrix must be a symmetric positive definite matrix" + ) + else: + is_postive_definite = self._check_positive_definite( + self.precision_matrix + ) + if not is_postive_definite: + raise ValueError( + "precision_matrix must be a symmetric positive definite matrix" + ) + + @property + def mean(self): + """Mean of Multivariate Normal distribuion. + + Returns: + Tensor: mean value. + """ + return self.loc + + @property + def variance(self): + """Variance of Multivariate Normal distribution. + + Returns: + Tensor: variance value. + """ + return ( + paddle.square(self._unbroadcasted_scale_tril) + .sum(-1) + .expand(self._batch_shape + self._event_shape) + ) + + def sample(self, shape=()): + """Generate Multivariate Normal samples of the specified shape. + + Args: + shape (Sequence[int], optional): Prepended shape of the generated samples. + + Returns: + Tensor, A tensor with prepended dimensions shape. The data type is float32. + """ + with paddle.no_grad(): + return self.rsample(shape) + + def rsample(self, shape=()): + if not isinstance(shape, Iterable): + raise TypeError('sample shape must be Iterable object.') + output_shape = self._extend_shape(shape) + eps = paddle.normal(shape=output_shape) + return self.loc + paddle.matmul( + self._unbroadcasted_scale_tril, eps.unsqueeze(-1) + ).squeeze(-1) + + def log_prob(self, value): + """Log probability density function. + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: log probability. The data type is same with :attr:`value` . + """ + value = paddle.cast(value, dtype=self.dtype) + + diff = value - self.loc + M = batch_mahalanobis(self._unbroadcasted_scale_tril, diff) + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(axis1=-2, axis2=-1) + .log() + .sum(-1) + ) + return ( + -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) + - half_log_det + ) + + def prob(self, value): + """Probability density function. + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: probability. The data type is same with :attr:`value` . + """ + return paddle.exp(self.log_prob(value)) + + def entropy(self): + r"""Shannon entropy in nats. + + The entropy is + + .. math:: + + \mathcal{H}(X) = - \int_{x \in \Omega} p(x) \log{p(x)} dx + + In the above equation: + + * :math:\Omega: is the support of the distribution. + + Returns: + Tensor, Shannon entropy of Multivariate Normal distribution. The data type is float32. + """ + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(axis1=-2, axis2=-1) + .log() + .sum(-1) + ) + H = ( + 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + + half_log_det + ) + if len(self._batch_shape) == 0: + return H + else: + return H.expand(self._batch_shape) + + def kl_divergence(self, other): + r"""The KL-divergence between two poisson distributions. + + The probability density function (pdf) is + + .. math:: + + KL\_divergence(\mu_1, \Sigma_1, \mu_2, \Sigma_2) = \int_x p_1(x) \log{\frac{p_1(x)}{p_2(x)}} dx + + Args: + other (MultivariateNormal): instance of Multivariate Normal. + + Returns: + Tensor, kl-divergence between two Multivariate Normal distributions. The data type is float32. + + """ + if ( + self._batch_shape != other._batch_shape + and self._event_shape != other._event_shape + ): + raise ValueError( + "KL divergence of two Multivariate Normal distributions should share the same `batch_shape` and `event_shape`." + ) + half_log_det_1 = ( + self._unbroadcasted_scale_tril.diagonal(axis1=-2, axis2=-1) + .log() + .sum(-1) + ) + half_log_det_2 = ( + other._unbroadcasted_scale_tril.diagonal(axis1=-2, axis2=-1) + .log() + .sum(-1) + ) + new_perm = list(range(len(self._unbroadcasted_scale_tril.shape))) + new_perm[-1], new_perm[-2] = new_perm[-2], new_perm[-1] + cov_mat_1 = paddle.matmul( + self._unbroadcasted_scale_tril, + self._unbroadcasted_scale_tril.transpose(new_perm), + ) + cov_mat_2 = paddle.matmul( + other._unbroadcasted_scale_tril, + other._unbroadcasted_scale_tril.transpose(new_perm), + ) + expectation = ( + paddle.linalg.solve(cov_mat_2, cov_mat_1) + .diagonal(axis1=-2, axis2=-1) + .sum(-1) + ) + expectation += batch_mahalanobis( + other._unbroadcasted_scale_tril, self.loc - other.loc + ) + return ( + half_log_det_2 + - half_log_det_1 + + 0.5 * (expectation - self._event_shape[0]) + ) + + +def precision_to_scale_tril(P): + """Convert precision matrix to scale tril matrix + + Args: + P (Tensor): input precision matrix + + Returns: + Tensor: scale tril matrix + """ + Lf = paddle.linalg.cholesky(paddle.flip(P, (-2, -1))) + tmp = paddle.flip(Lf, (-2, -1)) + new_perm = list(range(len(tmp.shape))) + new_perm[-2], new_perm[-1] = new_perm[-1], new_perm[-2] + L_inv = paddle.transpose(tmp, new_perm) + Id = paddle.eye(P.shape[-1], dtype=P.dtype) + L = paddle.linalg.triangular_solve(L_inv, Id, upper=False) + return L + + +def batch_mahalanobis(bL, bx): + r""" + Computes the squared Mahalanobis distance of the Multivariate Normal distribution with cholesky decomposition of the covatiance matrix. + Accepts batches for both bL and bx. + + Args: + bL (Tensor): scale trial matrix (batched) + bx (Tensor): difference vector(batched) + + Returns: + Tensor: squared Mahalanobis distance + """ + n = bx.shape[-1] + bx_batch_shape = bx.shape[:-1] + + # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n), + # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve + bx_batch_dims = len(bx_batch_shape) + bL_batch_dims = bL.dim() - 2 + outer_batch_dims = bx_batch_dims - bL_batch_dims + old_batch_dims = outer_batch_dims + bL_batch_dims + new_batch_dims = outer_batch_dims + 2 * bL_batch_dims + # Reshape bx with the shape (..., 1, i, j, 1, n) + bx_new_shape = bx.shape[:outer_batch_dims] + for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): + bx_new_shape += (sx // sL, sL) + bx_new_shape += (n,) + bx = bx.reshape(bx_new_shape) + # Permute bx to make it have shape (..., 1, j, i, 1, n) + permute_dims = ( + list(range(outer_batch_dims)) + + list(range(outer_batch_dims, new_batch_dims, 2)) + + list(range(outer_batch_dims + 1, new_batch_dims, 2)) + + [new_batch_dims] + ) + bx = bx.transpose(permute_dims) + + flat_L = bL.reshape((-1, n, n)) # shape = b x n x n + flat_x = bx.reshape((-1, flat_L.shape[0], n)) # shape = c x b x n + flat_x_swap = flat_x.transpose((1, 2, 0)) # shape = b x n x c + M_swap = ( + paddle.linalg.triangular_solve(flat_L, flat_x_swap, upper=False) + .pow(2) + .sum(-2) + ) # shape = b x c + M = M_swap.t() # shape = c x b + + # Now we revert the above reshape and permute operators. + permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1) + permute_inv_dims = list(range(outer_batch_dims)) + for i in range(bL_batch_dims): + permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i] + reshaped_M = permuted_M.transpose( + permute_inv_dims + ) # shape = (..., 1, i, j, 1) + return reshaped_M.reshape(bx_batch_shape) diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py new file mode 100644 index 0000000000000..4355e66a5fd56 --- /dev/null +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -0,0 +1,320 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterize +from distribution import config + +import paddle +from paddle.distribution.continuous_bernoulli import ContinuousBernoulli + + +class ContinuousBernoulli_np: + def __init__(self, probability, eps=1e-4): + self.eps = eps + self.dtype = 'float32' + eps_prob = 1.1920928955078125e-07 + self.probability = np.clip( + probability, a_min=eps_prob, a_max=1 - eps_prob + ) + + def _cut_support_region(self): + return np.logical_or( + np.less_equal(self.probability, 0.5 - self.eps), + np.greater_equal(self.probability, 0.5 + self.eps), + ) + + def _cut_probs(self): + return np.where( + self._cut_support_region(), + self.probability, + (0.5 - self.eps) * np.ones_like(self.probability), + ) + + def _tanh_inverse(self, value): + return 0.5 * (np.log1p(value) - np.log1p(-value)) + + def _log_constant(self): + cut_probs = self._cut_probs() + cut_probs_below_half = np.where( + np.less_equal(cut_probs, 0.5), cut_probs, np.zeros_like(cut_probs) + ) + cut_probs_above_half = np.where( + np.greater_equal(cut_probs, 0.5), cut_probs, np.ones_like(cut_probs) + ) + log_constant_propose = np.log( + 2.0 * np.abs(self._tanh_inverse(1.0 - 2.0 * cut_probs)) + ) - np.where( + np.less_equal(cut_probs, 0.5), + np.log1p(-2.0 * cut_probs_below_half), + np.log(2.0 * cut_probs_above_half - 1.0), + ) + x = np.square(self.probability - 0.5) + taylor_expansion = np.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x + return np.where( + self._cut_support_region(), log_constant_propose, taylor_expansion + ) + + def np_variance(self): + cut_probs = self._cut_probs() + tmp = np.divide( + np.square(cut_probs) - cut_probs, np.square(1.0 - 2.0 * cut_probs) + ) + propose = tmp + np.divide( + 1.0, np.square(2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)) + ) + x = np.square(self.probability - 0.5) + taylor_expansion = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x + return np.where(self._cut_support_region(), propose, taylor_expansion) + + def np_mean(self): + cut_probs = self._cut_probs() + tmp = cut_probs / (2.0 * cut_probs - 1.0) + propose = tmp + 1.0 / (2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)) + x = self.probability - 0.5 + taylor_expansion = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * np.square(x)) * x + return np.where(self._cut_support_region(), propose, taylor_expansion) + + def np_entropy(self): + log_p = np.log(self.probability) + log_1_minus_p = np.log1p(-self.probability) + return ( + -self._log_constant() + + self.np_mean() * (log_1_minus_p - log_p) + - log_1_minus_p + ) + + def np_prob(self, value): + return np.exp(self.np_log_prob(value)) + + def np_log_prob(self, value): + eps = 1e-8 + cross_entropy = np.nan_to_num( + value * np.log(self.probability) + + (1.0 - value) * np.log(1 - self.probability), + neginf=-eps, + ) + return self._log_constant() + cross_entropy + + def np_cdf(self, value): + cut_probs = self._cut_probs() + cdfs = ( + np.power(cut_probs, value) * np.power(1.0 - cut_probs, 1.0 - value) + + cut_probs + - 1.0 + ) / (2.0 * cut_probs - 1.0) + unbounded_cdfs = np.where(self._cut_support_region(), cdfs, value) + return np.where( + np.less_equal(value, 0.0), + np.zeros_like(value), + np.where( + np.greater_equal(value, 1.0), + np.ones_like(value), + unbounded_cdfs, + ), + ) + + def np_icdf(self, value): + cut_probs = self._cut_probs() + return np.where( + self._cut_support_region(), + ( + np.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0)) + - np.log1p(-cut_probs) + ) + / (np.log(cut_probs) - np.log1p(-cut_probs)), + value, + ) + + def np_kl_divergence(self, other): + part1 = -self.np_entropy() + log_q = np.log(other.probability) + log_1_minus_q = np.log1p(-other.probability) + part2 = -( + other._log_constant() + + self.np_mean() * (log_q - log_1_minus_q) + + log_1_minus_q + ) + return part1 + part2 + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'probability'), + [ + ('half', np.array(0.5).astype("float32")), + ( + 'one-dim', + parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), + ), + ( + 'multi-dim', + parameterize.xrand((2, 3), min=0.1, max=0.9).astype("float32"), + ), + ], +) +class TestContinuousBernoulli(unittest.TestCase): + def setUp(self): + self._dist = ContinuousBernoulli( + probability=paddle.to_tensor(self.probability) + ) + self._np_dist = ContinuousBernoulli_np(self.probability) + + def test_mean(self): + mean = self._dist.mean + self.assertEqual(mean.numpy().dtype, self.probability.dtype) + np.testing.assert_allclose( + mean, + self._np_dist.np_mean(), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), + ) + + def test_variance(self): + var = self._dist.variance + self.assertEqual(var.numpy().dtype, self.probability.dtype) + np.testing.assert_allclose( + var, + self._np_dist.np_variance(), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), + ) + + def test_entropy(self): + entropy = self._dist.entropy() + self.assertEqual(entropy.numpy().dtype, self.probability.dtype) + np.testing.assert_allclose( + entropy, + self._np_dist.np_entropy(), + rtol=0.10, + atol=0.10, + ) + + def test_sample(self): + sample_shape = () + samples = self._dist.sample(sample_shape) + self.assertEqual(samples.numpy().dtype, self.probability.dtype) + self.assertEqual( + tuple(samples.shape), + sample_shape + self._dist.batch_shape + self._dist.event_shape, + ) + + sample_shape = (5000,) + samples = self._dist.sample(sample_shape) + sample_mean = samples.mean(axis=0) + sample_variance = samples.var(axis=0) + + np.testing.assert_allclose( + sample_mean, self._dist.mean, atol=0, rtol=0.20 + ) + np.testing.assert_allclose( + sample_variance, self._dist.variance, atol=0, rtol=0.20 + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'probability', 'value'), + [ + ( + 'value-same-shape', + parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), + ), + ( + 'value-broadcast-shape', + parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2, 3), min=0.1, max=0.9).astype("float32"), + ), + ], +) +class TestContinuousBernoulliProbs(unittest.TestCase): + def setUp(self): + self._dist = ContinuousBernoulli(probability=self.probability) + self._np_dist = ContinuousBernoulli_np(self.probability) + + def test_prob(self): + np.testing.assert_allclose( + self._dist.prob(paddle.to_tensor(self.value)), + self._np_dist.np_prob(self.value), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), + ) + + def test_log_prob(self): + np.testing.assert_allclose( + self._dist.log_prob(paddle.to_tensor(self.value)), + self._np_dist.np_log_prob(self.value), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), + ) + + def test_cdf(self): + np.testing.assert_allclose( + self._dist.cdf(paddle.to_tensor(self.value)), + self._np_dist.np_cdf(self.value), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), + ) + + def test_icdf(self): + np.testing.assert_allclose( + self._dist.icdf(paddle.to_tensor(self.value)), + self._np_dist.np_icdf(self.value), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'p_1', 'p_2'), + [ + ( + 'one-dim', + parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), + ), + ( + 'multi-dim', + parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), + ), + ], +) +class TestContinuousBernoulliKL(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self._dist1 = ContinuousBernoulli( + probability=paddle.to_tensor(self.p_1) + ) + self._dist2 = ContinuousBernoulli( + probability=paddle.to_tensor(self.p_2) + ) + self._np_dist1 = ContinuousBernoulli_np(self.p_1) + self._np_dist2 = ContinuousBernoulli_np(self.p_2) + + def test_kl_divergence(self): + kl0 = self._dist1.kl_divergence(self._dist2) + kl1 = self._np_dist1.np_kl_divergence(self._np_dist2) + + self.assertEqual(tuple(kl0.shape), self._dist1.batch_shape) + self.assertEqual(tuple(kl1.shape), self._dist1.batch_shape) + np.testing.assert_allclose(kl0, kl1, rtol=0.1, atol=0.1) + + +if __name__ == '__main__': + unittest.main(argv=[''], verbosity=3, exit=False) diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py new file mode 100644 index 0000000000000..d221f9b183c10 --- /dev/null +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -0,0 +1,335 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterize +from distribution import config + +import paddle +from paddle.distribution.continuous_bernoulli import ContinuousBernoulli + + +class ContinuousBernoulli_np: + def __init__(self, probability, eps=1e-4): + self.eps = eps + self.dtype = 'float32' + eps_prob = 1.1920928955078125e-07 + self.probability = np.clip( + probability, a_min=eps_prob, a_max=1 - eps_prob + ) + + def _cut_support_region(self): + return np.logical_or( + np.less_equal(self.probability, 0.5 - self.eps), + np.greater_equal(self.probability, 0.5 + self.eps), + ) + + def _cut_probs(self): + return np.where( + self._cut_support_region(), + self.probability, + (0.5 - self.eps) * np.ones_like(self.probability), + ) + + def _tanh_inverse(self, value): + return 0.5 * (np.log1p(value) - np.log1p(-value)) + + def _log_constant(self): + cut_probs = self._cut_probs() + cut_probs_below_half = np.where( + np.less_equal(cut_probs, 0.5), cut_probs, np.zeros_like(cut_probs) + ) + cut_probs_above_half = np.where( + np.greater_equal(cut_probs, 0.5), cut_probs, np.ones_like(cut_probs) + ) + log_constant_propose = np.log( + 2.0 * np.abs(self._tanh_inverse(1.0 - 2.0 * cut_probs)) + ) - np.where( + np.less_equal(cut_probs, 0.5), + np.log1p(-2.0 * cut_probs_below_half), + np.log(2.0 * cut_probs_above_half - 1.0), + ) + x = np.square(self.probability - 0.5) + taylor_expansion = np.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x + return np.where( + self._cut_support_region(), log_constant_propose, taylor_expansion + ) + + def np_variance(self): + cut_probs = self._cut_probs() + tmp = np.divide( + np.square(cut_probs) - cut_probs, np.square(1.0 - 2.0 * cut_probs) + ) + propose = tmp + np.divide( + 1.0, np.square(2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)) + ) + x = np.square(self.probability - 0.5) + taylor_expansion = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x + return np.where(self._cut_support_region(), propose, taylor_expansion) + + def np_mean(self): + cut_probs = self._cut_probs() + tmp = cut_probs / (2.0 * cut_probs - 1.0) + propose = tmp + 1.0 / (2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)) + x = self.probability - 0.5 + taylor_expansion = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * np.square(x)) * x + return np.where(self._cut_support_region(), propose, taylor_expansion) + + def np_entropy(self): + log_p = np.log(self.probability) + log_1_minus_p = np.log1p(-self.probability) + return ( + -self._log_constant() + + self.np_mean() * (log_1_minus_p - log_p) + - log_1_minus_p + ) + + def np_prob(self, value): + return np.exp(self.np_log_prob(value)) + + def np_log_prob(self, value): + eps = 1e-8 + cross_entropy = np.nan_to_num( + value * np.log(self.probability) + + (1.0 - value) * np.log(1 - self.probability), + neginf=-eps, + ) + return self._log_constant() + cross_entropy + + def np_cdf(self, value): + cut_probs = self._cut_probs() + cdfs = ( + np.power(cut_probs, value) * np.power(1.0 - cut_probs, 1.0 - value) + + cut_probs + - 1.0 + ) / (2.0 * cut_probs - 1.0) + unbounded_cdfs = np.where(self._cut_support_region(), cdfs, value) + return np.where( + np.less_equal(value, 0.0), + np.zeros_like(value), + np.where( + np.greater_equal(value, 1.0), + np.ones_like(value), + unbounded_cdfs, + ), + ) + + def np_icdf(self, value): + cut_probs = self._cut_probs() + return np.where( + self._cut_support_region(), + ( + np.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0)) + - np.log1p(-cut_probs) + ) + / (np.log(cut_probs) - np.log1p(-cut_probs)), + value, + ) + + def np_kl_divergence(self, other): + part1 = -self.np_entropy() + log_q = np.log(other.probability) + log_1_minus_q = np.log1p(-other.probability) + part2 = -( + other._log_constant() + + self.np_mean() * (log_q - log_1_minus_q) + + log_1_minus_q + ) + return part1 + part2 + + +paddle.enable_static() + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'probability'), + [ + ( + 'multi-dim', + parameterize.xrand((2, 3), min=0.1, max=0.9).astype("float32"), + ), + ], +) +class TestContinuousBernoulli(unittest.TestCase): + def setUp(self): + self._np_dist = ContinuousBernoulli_np(self.probability) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + probability = paddle.static.data( + 'probability', self.probability.shape, self.probability.dtype + ) + dist = ContinuousBernoulli(probability) + mean = dist.mean + var = dist.variance + entropy = dist.entropy() + mini_samples = dist.sample(shape=()) + large_samples = dist.sample(shape=(1000,)) + fetch_list = [mean, var, entropy, mini_samples, large_samples] + feed = {'probability': self.probability} + + executor.run(startup_program) + [ + self.mean, + self.var, + self.entropy, + self.mini_samples, + self.large_samples, + ] = executor.run(main_program, feed=feed, fetch_list=fetch_list) + + def test_mean(self): + self.assertEqual( + str(self.mean.dtype).split('.')[-1], self.probability.dtype + ) + np.testing.assert_allclose( + self.mean, + self._np_mean(), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), + ) + + def test_variance(self): + self.assertEqual( + str(self.var.dtype).split('.')[-1], self.probability.dtype + ) + np.testing.assert_allclose( + self.var, + self._np_variance(), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), + ) + + def test_entropy(self): + self.assertEqual( + str(self.entropy.dtype).split('.')[-1], self.probability.dtype + ) + np.testing.assert_allclose( + self.entropy, + self._np_entropy(), + rtol=0.10, + atol=0.10, + ) + + def test_sample(self): + self.assertEqual( + str(self.mini_samples.dtype).split('.')[-1], self.probability.dtype + ) + sample_mean = self.large_samples.mean(axis=0) + sample_variance = self.large_samples.var(axis=0) + np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.20) + np.testing.assert_allclose(sample_variance, self.var, atol=0, rtol=0.20) + + def _np_variance(self): + return self._np_dist.np_variance() + + def _np_mean(self): + return self._np_dist.np_mean() + + def _np_entropy(self): + return self._np_dist.np_entropy() + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'probability', 'value'), + [ + ( + 'value-broadcast-shape', + parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2, 3), min=0.1, max=0.9).astype("float32"), + ), + ], +) +class TestContinuousBernoulliProbs(unittest.TestCase): + def setUp(self): + self._np_dist = ContinuousBernoulli_np(self.probability) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + + with paddle.static.program_guard(main_program, startup_program): + probability = paddle.static.data( + 'probability', self.probability.shape, self.probability.dtype + ) + value = paddle.static.data( + 'value', self.value.shape, self.value.dtype + ) + dist = ContinuousBernoulli(probability) + pmf = dist.prob(value) + feed = {'probability': self.probability, 'value': self.value} + fetch_list = [pmf] + + executor.run(startup_program) + [self.pmf] = executor.run( + main_program, feed=feed, fetch_list=fetch_list + ) + + def test_prob(self): + np.testing.assert_allclose( + self.pmf, + self._np_dist.np_prob(self.value), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'p_1', 'p_2'), + [ + ( + 'multi-dim', + parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), + ), + ], +) +class TestContinuousBernoulliKL(unittest.TestCase): + def setUp(self): + self._np_dist1 = ContinuousBernoulli_np(self.p_1) + self._np_dist2 = ContinuousBernoulli_np(self.p_2) + + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + + with paddle.static.program_guard(main_program, startup_program): + p_1 = paddle.static.data('p_1', self.p_1.shape) + p_2 = paddle.static.data('p_2', self.p_2.shape) + dist1 = ContinuousBernoulli(p_1) + dist2 = ContinuousBernoulli(p_2) + kl_dist1_dist2 = dist1.kl_divergence(dist2) + feed = {'p_1': self.p_1, 'p_2': self.p_2} + fetch_list = [kl_dist1_dist2] + + executor.run(startup_program) + [self.kl_dist1_dist2] = executor.run( + main_program, feed=feed, fetch_list=fetch_list + ) + + def test_kl_divergence(self): + kl0 = self.kl_dist1_dist2 + kl1 = self._np_dist1.np_kl_divergence(self._np_dist2) + + self.assertEqual(tuple(kl0.shape), self.p_1.shape) + self.assertEqual(tuple(kl1.shape), self.p_1.shape) + np.testing.assert_allclose(kl0, kl1, rtol=0.1, atol=0.1) + + +if __name__ == '__main__': + unittest.main(argv=[''], verbosity=3, exit=False) diff --git a/test/distribution/test_distribution_multivariate_normal.py b/test/distribution/test_distribution_multivariate_normal.py new file mode 100644 index 0000000000000..40df25e2dc5bf --- /dev/null +++ b/test/distribution/test_distribution_multivariate_normal.py @@ -0,0 +1,251 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterize +import scipy +from distribution import config + +import paddle +from paddle.distribution.multivariate_normal import MultivariateNormal + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc', 'covariance_matrix'), + [ + ( + 'one-batch', + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[2.0, 1.0], [1.0, 2.0]]), + ), + ( + 'multi-batch', + parameterize.xrand((2, 3), dtype='float32', min=-2, max=2), + np.array([[6.0, 2.5, 3.0], [2.5, 4.0, 5.0], [3.0, 5.0, 7.0]]), + ), + ], +) +class TestMVN(unittest.TestCase): + def setUp(self): + self._dist = MultivariateNormal( + loc=paddle.to_tensor(self.loc), + covariance_matrix=paddle.to_tensor(self.covariance_matrix), + ) + + def test_mean(self): + mean = self._dist.mean + self.assertEqual(mean.numpy().dtype, self.loc.dtype) + np.testing.assert_allclose( + mean, + self._np_mean(), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype)), + ) + + def test_variance(self): + var = self._dist.variance + self.assertEqual(var.numpy().dtype, self.loc.dtype) + np.testing.assert_allclose( + var, + self._np_variance(), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype)), + ) + + def test_entropy(self): + entropy = self._dist.entropy() + self.assertEqual(entropy.numpy().dtype, self.loc.dtype) + np.testing.assert_allclose( + entropy, + self._np_entropy(), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype)), + ) + + def test_sample(self): + sample_shape = () + samples = self._dist.sample(sample_shape) + self.assertEqual(samples.numpy().dtype, self.loc.dtype) + self.assertEqual( + tuple(samples.shape), + sample_shape + self._dist.batch_shape + self._dist.event_shape, + ) + + sample_shape = (10000,) + samples = self._dist.sample(sample_shape) + sample_mean = samples.mean(axis=0) + sample_variance = samples.var(axis=0) + + np.testing.assert_allclose( + sample_mean, self._dist.mean, atol=0.05, rtol=0.40 + ) + np.testing.assert_allclose( + sample_variance, self._dist.variance, atol=0.05, rtol=0.40 + ) + + def _np_variance(self): + batch_shape = np.broadcast_shapes( + self.covariance_matrix.shape[:-2], self.loc.shape[:-1] + ) + event_shape = self.loc.shape[-1:] + return np.broadcast_to( + np.diag(self.covariance_matrix), batch_shape + event_shape + ) + + def _np_mean(self): + return self.loc + + def _np_entropy(self): + if len(self.loc.shape) <= 1: + return scipy.stats.multivariate_normal.entropy( + self.loc, self.covariance_matrix + ) + else: + return np.apply_along_axis( + lambda i: scipy.stats.multivariate_normal.entropy( + i, self.covariance_matrix + ), + axis=1, + arr=self.loc, + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc', 'covariance_matrix', 'value'), + [ + ( + 'value-same-shape', + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[2.0, 1.0], [1.0, 2.0]]), + parameterize.xrand((2,), dtype='float32', min=-5, max=5), + ), + ( + 'value-broadcast-shape', + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[2.0, 1.0], [1.0, 2.0]]), + parameterize.xrand((3, 2), dtype='float32', min=-5, max=5), + ), + ], +) +class TestMVNProbs(unittest.TestCase): + def setUp(self): + self._dist = MultivariateNormal( + loc=self.loc, + covariance_matrix=paddle.to_tensor(self.covariance_matrix), + ) + + def test_prob(self): + if len(self.value.shape) <= 1: + scipy_pdf = scipy.stats.multivariate_normal.pdf( + self.value, self.loc, self.covariance_matrix + ) + else: + scipy_pdf = np.apply_along_axis( + lambda i: scipy.stats.multivariate_normal.pdf( + i, self.loc, self.covariance_matrix + ), + axis=1, + arr=self.value, + ) + np.testing.assert_allclose( + self._dist.prob(paddle.to_tensor(self.value)), + scipy_pdf, + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype)), + ) + + def test_log_prob(self): + if len(self.value.shape) <= 1: + scipy_logpdf = scipy.stats.multivariate_normal.logpdf( + self.value, self.loc, self.covariance_matrix + ) + else: + scipy_logpdf = np.apply_along_axis( + lambda i: scipy.stats.multivariate_normal.logpdf( + i, self.loc, self.covariance_matrix + ), + axis=1, + arr=self.value, + ) + np.testing.assert_allclose( + self._dist.log_prob(paddle.to_tensor(self.value)), + scipy_logpdf, + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'mu_1', 'sig_1', 'mu_2', 'sig_2'), + [ + ( + 'one-batch', + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[2.0, 1.0], [1.0, 2.0]]), + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[3.0, 2.0], [2.0, 3.0]]), + ) + ], +) +class TestMVNKL(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self._dist1 = MultivariateNormal( + loc=paddle.to_tensor(self.mu_1), + covariance_matrix=paddle.to_tensor(self.sig_1), + ) + self._dist2 = MultivariateNormal( + loc=paddle.to_tensor(self.mu_2), + covariance_matrix=paddle.to_tensor(self.sig_2), + ) + + def test_kl_divergence(self): + kl0 = self._dist1.kl_divergence(self._dist2) + kl1 = self.kl_divergence(self._dist1, self._dist2) + + self.assertEqual(tuple(kl0.shape), self._dist1.batch_shape) + self.assertEqual(tuple(kl1.shape), self._dist1.batch_shape) + np.testing.assert_allclose( + kl0, + kl1, + rtol=config.RTOL.get(str(self.mu_1.dtype)), + atol=config.ATOL.get(str(self.mu_1.dtype)), + ) + + def kl_divergence(self, dist1, dist2): + t1 = np.array(dist1._unbroadcasted_scale_tril) + t2 = np.array(dist2._unbroadcasted_scale_tril) + half_log_det_1 = np.log(t1.diagonal(axis1=-2, axis2=-1)).sum(-1) + half_log_det_2 = np.log(t2.diagonal(axis1=-2, axis2=-1)).sum(-1) + new_perm = list(range(len(t1.shape))) + new_perm[-1], new_perm[-2] = new_perm[-2], new_perm[-1] + cov_mat_1 = np.matmul(t1, t1.transpose(new_perm)) + cov_mat_2 = np.matmul(t2, t2.transpose(new_perm)) + expectation = ( + np.linalg.solve(cov_mat_2, cov_mat_1) + .diagonal(axis1=-2, axis2=-1) + .sum(-1) + ) + tmp = np.linalg.solve(t2, self.mu_1 - self.mu_2) + expectation += np.matmul(tmp.T, tmp) + return half_log_det_2 - half_log_det_1 + 0.5 * (expectation - 2.0) + + +if __name__ == '__main__': + unittest.main(argv=[''], verbosity=3, exit=False) diff --git a/test/distribution/test_distribution_multivariate_normal_static.py b/test/distribution/test_distribution_multivariate_normal_static.py new file mode 100644 index 0000000000000..230263f214b0f --- /dev/null +++ b/test/distribution/test_distribution_multivariate_normal_static.py @@ -0,0 +1,282 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterize +import scipy +from distribution import config + +import paddle +from paddle.distribution.multivariate_normal import MultivariateNormal + +paddle.enable_static() + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc', 'covariance_matrix'), + [ + ( + 'one-batch', + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[2.0, 1.0], [1.0, 2.0]]), + ), + ( + 'multi-batch', + parameterize.xrand((2, 3), dtype='float32', min=-2, max=2), + np.array([[6.0, 2.5, 3.0], [2.5, 4.0, 5.0], [3.0, 5.0, 7.0]]), + ), + ], +) +class TestMVN(unittest.TestCase): + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + covariance_matrix = paddle.static.data( + 'covariance_matrix', + self.covariance_matrix.shape, + self.covariance_matrix.dtype, + ) + dist = MultivariateNormal( + loc=loc, covariance_matrix=covariance_matrix + ) + mean = dist.mean + var = dist.variance + entropy = dist.entropy() + mini_samples = dist.sample(shape=()) + large_samples = dist.sample(shape=(10000,)) + fetch_list = [mean, var, entropy, mini_samples, large_samples] + feed = {'loc': self.loc, 'covariance_matrix': self.covariance_matrix} + + executor.run(startup_program) + [ + self.mean, + self.var, + self.entropy, + self.mini_samples, + self.large_samples, + ] = executor.run(main_program, feed=feed, fetch_list=fetch_list) + + def test_mean(self): + self.assertEqual(str(self.mean.dtype).split('.')[-1], self.loc.dtype) + np.testing.assert_allclose( + self.mean, + self._np_mean(), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype)), + ) + + def test_variance(self): + self.assertEqual(str(self.var.dtype).split('.')[-1], self.loc.dtype) + np.testing.assert_allclose( + self.var, + self._np_variance(), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype)), + ) + + def test_entropy(self): + self.assertEqual(str(self.entropy.dtype).split('.')[-1], self.loc.dtype) + np.testing.assert_allclose( + self.entropy, + self._np_entropy(), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype)), + ) + + def test_sample(self): + self.assertEqual( + str(self.mini_samples.dtype).split('.')[-1], self.loc.dtype + ) + sample_mean = self.large_samples.mean(axis=0) + sample_variance = self.large_samples.var(axis=0) + np.testing.assert_allclose(sample_mean, self.mean, atol=0.05, rtol=0.40) + np.testing.assert_allclose( + sample_variance, self.var, atol=0.05, rtol=0.40 + ) + + def _np_variance(self): + batch_shape = np.broadcast_shapes( + self.covariance_matrix.shape[:-2], self.loc.shape[:-1] + ) + event_shape = self.loc.shape[-1:] + return np.broadcast_to( + np.diag(self.covariance_matrix), batch_shape + event_shape + ) + + def _np_mean(self): + return self.loc + + def _np_entropy(self): + if len(self.loc.shape) <= 1: + return scipy.stats.multivariate_normal.entropy( + self.loc, self.covariance_matrix + ) + else: + return np.apply_along_axis( + lambda i: scipy.stats.multivariate_normal.entropy( + i, self.covariance_matrix + ), + axis=1, + arr=self.loc, + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc', 'covariance_matrix', 'value'), + [ + ( + 'value-same-shape', + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[2.0, 1.0], [1.0, 2.0]]), + parameterize.xrand((2,), dtype='float32', min=-5, max=5), + ), + ( + 'value-broadcast-shape', + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[2.0, 1.0], [1.0, 2.0]]), + parameterize.xrand((3, 2), dtype='float32', min=-5, max=5), + ), + ], +) +class TestMVNProbs(unittest.TestCase): + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + covariance_matrix = paddle.static.data( + 'covariance_matrix', + self.covariance_matrix.shape, + self.covariance_matrix.dtype, + ) + value = paddle.static.data( + 'value', self.value.shape, self.value.dtype + ) + dist = MultivariateNormal( + loc=loc, covariance_matrix=covariance_matrix + ) + pmf = dist.prob(value) + feed = { + 'loc': self.loc, + 'covariance_matrix': self.covariance_matrix, + 'value': self.value, + } + fetch_list = [pmf] + + executor.run(startup_program) + [self.pmf] = executor.run( + main_program, feed=feed, fetch_list=fetch_list + ) + + def test_prob(self): + if len(self.value.shape) <= 1: + scipy_pdf = scipy.stats.multivariate_normal.pdf( + self.value, self.loc, self.covariance_matrix + ) + else: + scipy_pdf = np.apply_along_axis( + lambda i: scipy.stats.multivariate_normal.pdf( + i, self.loc, self.covariance_matrix + ), + axis=1, + arr=self.value, + ) + np.testing.assert_allclose( + self.pmf, + scipy_pdf, + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'mu_1', 'sig_1', 'mu_2', 'sig_2'), + [ + ( + 'one-batch', + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[2.0, 1.0], [1.0, 2.0]]).astype('float32'), + parameterize.xrand((2,), dtype='float32', min=-2, max=2), + np.array([[3.0, 2.0], [2.0, 3.0]]).astype('float32'), + ) + ], +) +class TestMVNKL(unittest.TestCase): + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + + with paddle.static.program_guard(main_program, startup_program): + mu_1 = paddle.static.data('mu_1', self.mu_1.shape) + sig_1 = paddle.static.data('sig_1', self.sig_1.shape) + mu_2 = paddle.static.data('mu_2', self.mu_2.shape) + sig_2 = paddle.static.data('sig_2', self.sig_2.shape) + dist1 = MultivariateNormal(loc=mu_1, covariance_matrix=sig_1) + dist2 = MultivariateNormal(loc=mu_2, covariance_matrix=sig_2) + kl_dist1_dist2 = dist1.kl_divergence(dist2) + feed = { + 'mu_1': self.mu_1, + 'sig_1': self.sig_1, + 'mu_2': self.mu_2, + 'sig_2': self.sig_2, + } + fetch_list = [kl_dist1_dist2] + + executor.run(startup_program) + [self.kl_dist1_dist2] = executor.run( + main_program, feed=feed, fetch_list=fetch_list + ) + + def test_kl_divergence(self): + kl0 = self.kl_dist1_dist2 + kl1 = self.kl_divergence() + batch_shape = np.broadcast_shapes( + self.sig_1.shape[:-2], self.mu_1.shape[:-1] + ) + self.assertEqual(tuple(kl0.shape), batch_shape) + self.assertEqual(tuple(kl1.shape), batch_shape) + np.testing.assert_allclose(kl0, kl1, rtol=0.1, atol=0.1) + + def kl_divergence(self): + t1 = np.array(np.linalg.cholesky(self.sig_1)) + t2 = np.array(np.linalg.cholesky(self.sig_2)) + half_log_det_1 = np.log(t1.diagonal(axis1=-2, axis2=-1)).sum(-1) + half_log_det_2 = np.log(t2.diagonal(axis1=-2, axis2=-1)).sum(-1) + new_perm = list(range(len(t1.shape))) + new_perm[-1], new_perm[-2] = new_perm[-2], new_perm[-1] + cov_mat_1 = np.matmul(t1, t1.transpose(new_perm)) + cov_mat_2 = np.matmul(t2, t2.transpose(new_perm)) + expectation = ( + np.linalg.solve(cov_mat_2, cov_mat_1) + .diagonal(axis1=-2, axis2=-1) + .sum(-1) + ) + tmp = np.linalg.solve(t2, self.mu_1 - self.mu_2) + expectation += np.matmul(tmp.T, tmp) + return half_log_det_2 - half_log_det_1 + 0.5 * (expectation - 2.0) + + +if __name__ == '__main__': + unittest.main(argv=[''], verbosity=3, exit=False) From 4f3c102b37528c167be4db061de2a884edc62be8 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 24 Oct 2023 17:06:52 +0800 Subject: [PATCH 02/29] add kl-div registrition for cb and mvn --- python/paddle/distribution/kl.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index aa880ba37a5b8..27e12a4309c2e 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -19,12 +19,14 @@ from paddle.distribution.beta import Beta from paddle.distribution.categorical import Categorical from paddle.distribution.cauchy import Cauchy +from paddle.distribution.continuous_bernoulli import ContinuousBernoulli from paddle.distribution.dirichlet import Dirichlet from paddle.distribution.distribution import Distribution from paddle.distribution.exponential_family import ExponentialFamily from paddle.distribution.geometric import Geometric from paddle.distribution.laplace import Laplace from paddle.distribution.lognormal import LogNormal +from paddle.distribution.multivariate_normal import MultivariateNormal from paddle.distribution.normal import Normal from paddle.distribution.uniform import Uniform from paddle.framework import in_dynamic_mode @@ -192,11 +194,21 @@ def _kl_cauchy_cauchy(p, q): return p.kl_divergence(q) +@register_kl(ContinuousBernoulli, ContinuousBernoulli) +def _kl_continuousbernoulli_continuousbernoulli(p, q): + return p.kl_divergence(q) + + @register_kl(Normal, Normal) def _kl_normal_normal(p, q): return p.kl_divergence(q) +@register_kl(MultivariateNormal, MultivariateNormal) +def _kl_mvn_mvn(p, q): + return p.kl_divergence(q) + + @register_kl(Uniform, Uniform) def _kl_uniform_uniform(p, q): return p.kl_divergence(q) From f64796cf3d38a0816630c25bb1e7c798fc07ec0f Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 16 Nov 2023 17:09:44 +0800 Subject: [PATCH 03/29] fix docs annd test --- .../distribution/continuous_bernoulli.py | 16 +++++++++---- .../distribution/multivariate_normal.py | 24 ++++++++++++------- .../test_distribution_continuous_bernoulli.py | 8 +++---- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index 2b1053b736924..657acc00336c0 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -54,21 +54,27 @@ class ContinuousBernoulli(distribution.Distribution): >>> import paddle >>> from paddle.distribution import ContinuousBernoulli >>> rv = ContinuousBernoulli(paddle.to_tensor([0.2, 0.5])) + + >>> # doctest: +SKIP >>> print(rv.sample([2])) Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0.09428147, 0.81438422], - [0.24624705, 0.93354583]]) + [[0.09428147, 0.81438422], + [0.24624705, 0.93354583]]) + + >>> # doctest: -SKIP >>> print(rv.mean) Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - [0.38801414, 0.50000000]) + [0.38801414, 0.50000000]) + >>> print(rv.entropy()) Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - [-0.07641461, 0. ]) + [-0.07641461, 0. ]) + >>> rv1 = ContinuousBernoulli(paddle.to_tensor([0.2, 0.8])) >>> rv2 = ContinuousBernoulli(paddle.to_tensor([0.7, 0.5])) >>> print(rv1.kl_divergence(rv2)) Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - [0.20103613, 0.07641447]) + [0.20103613, 0.07641447]) """ def __init__(self, probability, eps=1e-4): diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 89125f64121d1..989a58cb4b9db 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -49,27 +49,33 @@ class MultivariateNormal(distribution.Distribution): >>> import paddle >>> from paddle.distribution import MultivariateNormal >>> rv = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) + + >>> # doctest: +SKIP >>> print(rv.sample([3, 2])) Tensor(shape=[3, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[[0.68554986, 3.85142398], - [1.88336682, 5.43841648]], + [[[0.68554986, 3.85142398], + [1.88336682, 5.43841648]], + + [[5.32492065, 7.23725986], + [3.42192221, 4.83934879]], - [[5.32492065, 7.23725986], - [3.42192221, 4.83934879]], + [[3.36775684, 4.46108866], + [4.58927441, 4.32255936]]]) - [[3.36775684, 4.46108866], - [4.58927441, 4.32255936]]]) + >>> # doctest: -SKIP >>> print(rv.mean) Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - [2., 5.]) + [2., 5.]) + >>> print(rv.entropy()) Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 3.38718319) + 3.38718319) + >>> rv1 = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) >>> rv2 = MultivariateNormal(loc=paddle.to_tensor([-1.,3.]), covariance_matrix=paddle.to_tensor([[3.,2.],[2.,3.]])) >>> print(rv1.kl_divergence(rv2)) Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 1.55541301) + 1.55541301) """ def __init__( diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py index 4355e66a5fd56..62404872a3901 100644 --- a/test/distribution/test_distribution_continuous_bernoulli.py +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -189,8 +189,8 @@ def test_variance(self): np.testing.assert_allclose( var, self._np_dist.np_variance(), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=0.00, + atol=0.20, ) def test_entropy(self): @@ -199,8 +199,8 @@ def test_entropy(self): np.testing.assert_allclose( entropy, self._np_dist.np_entropy(), - rtol=0.10, - atol=0.10, + rtol=0.00, + atol=0.20, ) def test_sample(self): From 42d279eb0257da91d071a1cae2aa3b390f112c02 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 16 Nov 2023 20:38:19 +0800 Subject: [PATCH 04/29] fix test --- .../test_distribution_continuous_bernoulli_static.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index d221f9b183c10..df67667775a8a 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -210,8 +210,8 @@ def test_variance(self): np.testing.assert_allclose( self.var, self._np_variance(), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=0.00, + atol=0.20, ) def test_entropy(self): @@ -221,8 +221,8 @@ def test_entropy(self): np.testing.assert_allclose( self.entropy, self._np_entropy(), - rtol=0.10, - atol=0.10, + rtol=0.00, + atol=0.20, ) def test_sample(self): From ca84f0dc467e835b484ae9c7ded088a516d5f85d Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 17 Nov 2023 20:42:40 +0800 Subject: [PATCH 05/29] fix test --- ...istribution_continuous_bernoulli_static.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index df67667775a8a..c6d389547c4f7 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -160,7 +160,7 @@ def np_kl_divergence(self, other): [ ( 'multi-dim', - parameterize.xrand((2, 3), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((1, 3), min=0.1, max=0.9).astype("float32"), ), ], ) @@ -178,9 +178,8 @@ def setUp(self): mean = dist.mean var = dist.variance entropy = dist.entropy() - mini_samples = dist.sample(shape=()) large_samples = dist.sample(shape=(1000,)) - fetch_list = [mean, var, entropy, mini_samples, large_samples] + fetch_list = [mean, var, entropy, large_samples] feed = {'probability': self.probability} executor.run(startup_program) @@ -188,7 +187,6 @@ def setUp(self): self.mean, self.var, self.entropy, - self.mini_samples, self.large_samples, ] = executor.run(main_program, feed=feed, fetch_list=fetch_list) @@ -210,8 +208,8 @@ def test_variance(self): np.testing.assert_allclose( self.var, self._np_variance(), - rtol=0.00, - atol=0.20, + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), ) def test_entropy(self): @@ -221,14 +219,11 @@ def test_entropy(self): np.testing.assert_allclose( self.entropy, self._np_entropy(), - rtol=0.00, + rtol=0.0, atol=0.20, ) def test_sample(self): - self.assertEqual( - str(self.mini_samples.dtype).split('.')[-1], self.probability.dtype - ) sample_mean = self.large_samples.mean(axis=0) sample_variance = self.large_samples.var(axis=0) np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.20) @@ -251,7 +246,7 @@ def _np_entropy(self): ( 'value-broadcast-shape', parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((2, 3), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2, 2), min=0.1, max=0.9).astype("float32"), ), ], ) @@ -294,8 +289,8 @@ def test_prob(self): [ ( 'multi-dim', - parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2,), min=0.1, max=0.9).astype("float32"), ), ], ) @@ -328,7 +323,7 @@ def test_kl_divergence(self): self.assertEqual(tuple(kl0.shape), self.p_1.shape) self.assertEqual(tuple(kl1.shape), self.p_1.shape) - np.testing.assert_allclose(kl0, kl1, rtol=0.1, atol=0.1) + np.testing.assert_allclose(kl0, kl1, rtol=0, atol=0.2) if __name__ == '__main__': From b27eef0dc44e7e4b3367934ce1343a0066103345 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 21 Nov 2023 11:24:48 +0800 Subject: [PATCH 06/29] fix mvn test coverage --- .../distribution/multivariate_normal.py | 2 +- .../test_distribution_multivariate_normal.py | 27 ++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 989a58cb4b9db..0968fc23cfebf 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -175,7 +175,7 @@ def _check_lower_triangular(self, value): Return: Tensor: indicator for lower triangular matrix """ - tril = value.tril() + tril = paddle.tril(value) is_lower_triangular = paddle.cast( (tril == value).reshape( value.shape[:-2] diff --git a/test/distribution/test_distribution_multivariate_normal.py b/test/distribution/test_distribution_multivariate_normal.py index 40df25e2dc5bf..a4e91f7d8da24 100644 --- a/test/distribution/test_distribution_multivariate_normal.py +++ b/test/distribution/test_distribution_multivariate_normal.py @@ -91,10 +91,10 @@ def test_sample(self): sample_variance = samples.var(axis=0) np.testing.assert_allclose( - sample_mean, self._dist.mean, atol=0.05, rtol=0.40 + sample_mean, self._dist.mean, atol=0.00, rtol=0.40 ) np.testing.assert_allclose( - sample_variance, self._dist.variance, atol=0.05, rtol=0.40 + sample_variance, self._dist.variance, atol=0.00, rtol=0.40 ) def _np_variance(self): @@ -126,7 +126,7 @@ def _np_entropy(self): @parameterize.place(config.DEVICES) @parameterize.parameterize_cls( - (parameterize.TEST_CASE_NAME, 'loc', 'covariance_matrix', 'value'), + (parameterize.TEST_CASE_NAME, 'loc', 'precision_matrix', 'value'), [ ( 'value-same-shape', @@ -146,18 +146,19 @@ class TestMVNProbs(unittest.TestCase): def setUp(self): self._dist = MultivariateNormal( loc=self.loc, - covariance_matrix=paddle.to_tensor(self.covariance_matrix), + precision_matrix=paddle.to_tensor(self.precision_matrix), ) + self.cov = np.linalg.inv(self.precision_matrix) def test_prob(self): if len(self.value.shape) <= 1: scipy_pdf = scipy.stats.multivariate_normal.pdf( - self.value, self.loc, self.covariance_matrix + self.value, self.loc, self.cov ) else: scipy_pdf = np.apply_along_axis( lambda i: scipy.stats.multivariate_normal.pdf( - i, self.loc, self.covariance_matrix + i, self.loc, self.cov ), axis=1, arr=self.value, @@ -172,12 +173,12 @@ def test_prob(self): def test_log_prob(self): if len(self.value.shape) <= 1: scipy_logpdf = scipy.stats.multivariate_normal.logpdf( - self.value, self.loc, self.covariance_matrix + self.value, self.loc, self.cov ) else: scipy_logpdf = np.apply_along_axis( lambda i: scipy.stats.multivariate_normal.logpdf( - i, self.loc, self.covariance_matrix + i, self.loc, self.cov ), axis=1, arr=self.value, @@ -192,14 +193,14 @@ def test_log_prob(self): @parameterize.place(config.DEVICES) @parameterize.parameterize_cls( - (parameterize.TEST_CASE_NAME, 'mu_1', 'sig_1', 'mu_2', 'sig_2'), + (parameterize.TEST_CASE_NAME, 'mu_1', 'tril_1', 'mu_2', 'tril_2'), [ ( 'one-batch', parameterize.xrand((2,), dtype='float32', min=-2, max=2), - np.array([[2.0, 1.0], [1.0, 2.0]]), + np.array([[2.0, 0.0], [1.0, 2.0]]), parameterize.xrand((2,), dtype='float32', min=-2, max=2), - np.array([[3.0, 2.0], [2.0, 3.0]]), + np.array([[3.0, 0.0], [2.0, 3.0]]), ) ], ) @@ -208,11 +209,11 @@ def setUp(self): paddle.disable_static() self._dist1 = MultivariateNormal( loc=paddle.to_tensor(self.mu_1), - covariance_matrix=paddle.to_tensor(self.sig_1), + scale_tril=paddle.to_tensor(self.tril_1), ) self._dist2 = MultivariateNormal( loc=paddle.to_tensor(self.mu_2), - covariance_matrix=paddle.to_tensor(self.sig_2), + scale_tril=paddle.to_tensor(self.tril_2), ) def test_kl_divergence(self): From 28e46d31dec7e12ffdf17a6a2d6ca101f7e3624d Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 21 Nov 2023 18:22:00 +0800 Subject: [PATCH 07/29] fix docs --- python/paddle/distribution/continuous_bernoulli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index 657acc00336c0..33065a847c8da 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -21,7 +21,7 @@ class ContinuousBernoulli(distribution.Distribution): - r"""The Continuous Bernoulli distribution with probability parameter: `probability`. + r"""The Continuous Bernoulli distribution with parameter: `probability` characterizing the shape of the density function. Mathematical details From e7ce6bee5f323fedd8526dee664a73721b3318c0 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Wed, 22 Nov 2023 18:03:42 +0800 Subject: [PATCH 08/29] update docs --- .../distribution/continuous_bernoulli.py | 72 +++++++++++++------ .../distribution/multivariate_normal.py | 26 ++++--- 2 files changed, 69 insertions(+), 29 deletions(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index 33065a847c8da..ee0cd6db3cf23 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -22,6 +22,9 @@ class ContinuousBernoulli(distribution.Distribution): r"""The Continuous Bernoulli distribution with parameter: `probability` characterizing the shape of the density function. + The Continuous Bernoulli distribution is defined on [0, 1], and it can be viewed as a continuous version of the Bernoulli distribution. + + [1] Loaiza-Ganem, G., & Cunningham, J. P. The continuous Bernoulli: fixing a pervasive error in variational autoencoders. 2019. Mathematical details @@ -33,20 +36,26 @@ class ContinuousBernoulli(distribution.Distribution): In the above equation: - * :math:`probability = \lambda`: is the probability. - * :math: `C(\lambda) = - \left\{ - \begin{aligned} - &2 & \text{ if $\lambda = \frac{1}{2}$} \\ - &\frac{2\tanh^{-1}(1-2\lambda)}{1 - 2\lambda} & \text{ otherwise} - \end{aligned} - \right.` * :math:`x`: is continuous between 0 and 1 + * :math:`probability = \lambda`: is the probability. + * :math:`C(\lambda)`: is the normalizing constant factor + + .. math:: + + C(\lambda) = + \left\{ + \begin{aligned} + &2 & \text{ if $\lambda = \frac{1}{2}$} \\ + &\frac{2\tanh^{-1}(1-2\lambda)}{1 - 2\lambda} & \text{ otherwise} + \end{aligned} + \right. Args: - probability(int|float|np.ndarray|Tensor): The probability of Continuous Bernoulli distribution, which characterize the shape of the pdf. - The data type of `probability` will be convert to float32. - eps(float): Specify the bandwith of the unstable calculation region near 0.5 + probability(int|float|np.ndarray|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], + which characterize the shape of the pdf. The data type of `probability` will be convert to float32. + eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region + would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The + default value is 1e-4. Examples: .. code-block:: python @@ -232,25 +241,25 @@ def variance(self): ) def sample(self, shape=()): - """Generate Continuous Bernoulli samples of the specified shape. + """Generate Continuous Bernoulli samples of the specified shape. The final shape would be ``sample_shape + batch_shape``. Args: shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, A tensor with prepended dimensions shape. The data type is float32. + Tensor, Sampled data with shape `sample_shape` + `batch_shape`. The data type is float32. """ with paddle.no_grad(): return self.rsample(shape) def rsample(self, shape=()): - """Generate Continuous Bernoulli samples of the specified shape. + """Generate Continuous Bernoulli samples of the specified shape. The final shape would be ``sample_shape + batch_shape``. Args: shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, A tensor with prepended dimensions shape. The data type is float32. + Tensor, Sampled data with shape `sample_shape` + `batch_shape`. The data type is float32. """ if not isinstance(shape, Iterable): raise TypeError('sample shape must be Iterable object.') @@ -300,7 +309,7 @@ def entropy(self): .. math:: - \mathcal{H}(X) = - \int_{x \in \Omega} p(x) \log{p(x)} dx + \mathcal{H}(X) = -\log C + \left[ \log (1 - \lambda) -\log \lambda \right] \mathbb{E}(X) - \log(1 - \lambda) In the above equation: @@ -319,7 +328,18 @@ def entropy(self): ) def cdf(self, value): - """Cumulative distribution function + r"""Cumulative distribution function + + .. math:: + + { P(X \le t; \lambda) = + F(t;\lambda) = + \left\{ + \begin{aligned} + &t & \text{ if $\lambda = \frac{1}{2}$} \\ + &\frac{\lambda^t (1 - \lambda)^{1 - t} + \lambda - 1}{2\lambda - 1} & \text{ otherwise} + \end{aligned} + \right. } Args: value (Tensor): The input tensor. @@ -351,13 +371,23 @@ def cdf(self, value): ) def icdf(self, value): - """Inverse cumulative distribution function + r"""Inverse cumulative distribution function + + .. math:: + + { F^{-1}(x;\lambda) = + \left\{ + \begin{aligned} + &x & \text{ if $\lambda = \frac{1}{2}$} \\ + &\frac{\log(1+(\frac{2\lambda - 1}{1 - \lambda})x)}{\log(\frac{\lambda}{1-\lambda})} & \text{ otherwise} + \end{aligned} + \right. } Args: value (Tensor): The input tensor, meaning the quantile. Returns: - Tensor: p-value of the quantile. The data type is same with :attr:`value` . + Tensor: the value of the r.v. corresponding to the quantile. The data type is same with :attr:`value` . """ value = paddle.cast(value, dtype=self.dtype) if not self._check_constraint(value): @@ -376,13 +406,13 @@ def icdf(self, value): ) def kl_divergence(self, other): - r"""The KL-divergence between two Continuous Bernoulli distributions. + r"""The KL-divergence between two Continuous Bernoulli distributions with the same `batch_shape`. The probability density function (pdf) is .. math:: - KL\_divergence(\lambda_1, \lambda_2) = \int_x p_1(x) \log{\frac{p_1(x)}{p_2(x)}} dx + KL\_divergence(\lambda_1, \lambda_2) = - H - \{\log C_2 + [\log \lambda_2 - \log (1-\lambda_2)] \mathbb{E}_1(X) + \log (1-\lambda_2) \} Args: other (ContinuousBernoulli): instance of Continuous Bernoulli. diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 0968fc23cfebf..c14ac5c2f15db 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -22,7 +22,8 @@ class MultivariateNormal(distribution.Distribution): - r"""The Multivariate Normal distribution with parameter: `loc` and any one of the following parameters: `covariance_matrix`, `precision_matrix`, `scale_tril`. + r"""The Multivariate Normal distribution is a type multivariate continuous distribution defined on the real set, with parameter: `loc` and any one + of the following parameters characterizing the variance: `covariance_matrix`, `precision_matrix`, `scale_tril`. Mathematical details @@ -34,8 +35,9 @@ class MultivariateNormal(distribution.Distribution): In the above equation: - * :math:`loc = \mu`: is the mean. - * :math:`covariance_matrix = \Sigma`: is the covariance matrix. + * :math:`X`: is a k-dim random vector. + * :math:`loc = \mu`: is the k-dim mean vector. + * :math:`covariance_matrix = \Sigma`: is the k-by-k covariance matrix. Args: loc(int|float|np.ndarray|Tensor): The mean of Multivariate Normal distribution. The data type of `loc` will be convert to float32. @@ -274,18 +276,26 @@ def variance(self): ) def sample(self, shape=()): - """Generate Multivariate Normal samples of the specified shape. + """Generate Multivariate Normal samples of the specified shape. The final shape would be ``sample_shape + batch_shape + event_shape``. Args: shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, A tensor with prepended dimensions shape. The data type is float32. + Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is float32. """ with paddle.no_grad(): return self.rsample(shape) def rsample(self, shape=()): + """Generate Multivariate Normal samples of the specified shape. The final shape would be ``sample_shape + batch_shape + event_shape``. + + Args: + shape (Sequence[int], optional): Prepended shape of the generated samples. + + Returns: + Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is float32. + """ if not isinstance(shape, Iterable): raise TypeError('sample shape must be Iterable object.') output_shape = self._extend_shape(shape) @@ -335,7 +345,7 @@ def entropy(self): .. math:: - \mathcal{H}(X) = - \int_{x \in \Omega} p(x) \log{p(x)} dx + \mathcal{H}(X) = \frac{n}{2} \log(2\pi) + \log {\det A} + \frac{n}{2} In the above equation: @@ -359,13 +369,13 @@ def entropy(self): return H.expand(self._batch_shape) def kl_divergence(self, other): - r"""The KL-divergence between two poisson distributions. + r"""The KL-divergence between two poisson distributions with the same `batch_shape` and `event_shape`. The probability density function (pdf) is .. math:: - KL\_divergence(\mu_1, \Sigma_1, \mu_2, \Sigma_2) = \int_x p_1(x) \log{\frac{p_1(x)}{p_2(x)}} dx + KL\_divergence(\lambda_1, \lambda_2) = \log(\det A_2) - \log(\det A_1) -\frac{n}{2} +\frac{1}{2}[tr [\Sigma_2^{-1} \Sigma_1] + (\mu_1 - \mu_2)^{\intercal} \Sigma_2^{-1} (\mu_1 - \mu_2)] Args: other (MultivariateNormal): instance of Multivariate Normal. From 2c5ce90383da1f7a9f2d00f31660c50c3c2c11d6 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 30 Nov 2023 22:41:21 +0800 Subject: [PATCH 09/29] update cb and mvn --- .../distribution/continuous_bernoulli.py | 80 ++++++++-------- .../distribution/multivariate_normal.py | 94 +++++++++---------- 2 files changed, 84 insertions(+), 90 deletions(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index ee0cd6db3cf23..45107731ace5f 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable - -import numpy as np +from collections.abc import Sequence import paddle from paddle.distribution import distribution @@ -51,8 +49,8 @@ class ContinuousBernoulli(distribution.Distribution): \right. Args: - probability(int|float|np.ndarray|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], - which characterize the shape of the pdf. The data type of `probability` will be convert to float32. + probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], + which characterize the shape of the pdf. The data type of `probability` will be convert to the global default dtype. eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The default value is 1e-4. @@ -60,36 +58,38 @@ class ContinuousBernoulli(distribution.Distribution): Examples: .. code-block:: python - >>> import paddle - >>> from paddle.distribution import ContinuousBernoulli - >>> rv = ContinuousBernoulli(paddle.to_tensor([0.2, 0.5])) - - >>> # doctest: +SKIP - >>> print(rv.sample([2])) - Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0.09428147, 0.81438422], - [0.24624705, 0.93354583]]) - - >>> # doctest: -SKIP - >>> print(rv.mean) - Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - [0.38801414, 0.50000000]) - - >>> print(rv.entropy()) - Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - [-0.07641461, 0. ]) - - >>> rv1 = ContinuousBernoulli(paddle.to_tensor([0.2, 0.8])) - >>> rv2 = ContinuousBernoulli(paddle.to_tensor([0.7, 0.5])) - >>> print(rv1.kl_divergence(rv2)) - Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - [0.20103613, 0.07641447]) + import paddle + from paddle.distribution import ContinuousBernoulli + + # init `probability` with `paddle.Tensor` + rv = ContinuousBernoulli(paddle.to_tensor([0.2, 0.5])) + + print(rv.sample([2])) + # Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + # [[0.09428147, 0.81438422], + # [0.24624705, 0.93354583]]) + + print(rv.mean) + # Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + # [0.38801414, 0.50000000]) + + print(rv.entropy()) + # Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + # [-0.07641461, 0. ]) + + rv1 = ContinuousBernoulli(paddle.to_tensor([0.2, 0.8])) + rv2 = ContinuousBernoulli(paddle.to_tensor([0.7, 0.5])) + print(rv1.kl_divergence(rv2)) + # Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + # [0.20103613, 0.07641447]) """ def __init__(self, probability, eps=1e-4): self.eps = paddle.to_tensor(eps) - self.dtype = 'float32' + self.dtype = paddle.get_default_dtype() self.probability = self._to_tensor(probability) + + # eps_prob is used to clip the input `probability` in the range of [eps_prob, 1-eps_prob] eps_prob = paddle.finfo(self.probability.dtype).eps self.probability = paddle.clip( self.probability, min=eps_prob, max=1 - eps_prob @@ -106,17 +106,15 @@ def __init__(self, probability, eps=1e-4): super().__init__(batch_shape) def _to_tensor(self, probability): - """Convert the input parameters into tensors with dtype = float32 + """Convert the input parameters into tensors with the global default dtype Returns: Tensor: converted probability. """ # convert type if isinstance(probability, (float, int)): - probability = paddle.to_tensor([probability], dtype=self.dtype) - if isinstance(probability, np.ndarray): - probability = paddle.to_tensor(probability) - probability = paddle.cast(probability, dtype=self.dtype) + probability = [probability] + probability = paddle.to_tensor(probability, dtype=self.dtype) return probability def _check_constraint(self, value): @@ -247,7 +245,7 @@ def sample(self, shape=()): shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, Sampled data with shape `sample_shape` + `batch_shape`. The data type is float32. + Tensor, Sampled data with shape `sample_shape` + `batch_shape`. The data type is the global default dtype. """ with paddle.no_grad(): return self.rsample(shape) @@ -259,10 +257,10 @@ def rsample(self, shape=()): shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, Sampled data with shape `sample_shape` + `batch_shape`. The data type is float32. + Tensor, Sampled data with shape `sample_shape` + `batch_shape`. The data type is the global default dtype. """ - if not isinstance(shape, Iterable): - raise TypeError('sample shape must be Iterable object.') + if not isinstance(shape, Sequence): + raise TypeError('sample shape must be Sequence object.') shape = tuple(shape) batch_shape = tuple(self.batch_shape) output_shape = tuple(shape + batch_shape) @@ -316,7 +314,7 @@ def entropy(self): * :math:\Omega: is the support of the distribution. Returns: - Tensor, Shannon entropy of Continuous Bernoulli distribution. The data type is float32. + Tensor, Shannon entropy of Continuous Bernoulli distribution. The data type is the global default dtype. """ log_p = paddle.log(self.probability) log_1_minus_p = paddle.log1p(-self.probability) @@ -418,7 +416,7 @@ def kl_divergence(self, other): other (ContinuousBernoulli): instance of Continuous Bernoulli. Returns: - Tensor, kl-divergence between two Continuous Bernoulli distributions. The data type is float32. + Tensor, kl-divergence between two Continuous Bernoulli distributions. The data type is the global default dtype. """ diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index c14ac5c2f15db..e7a9c9465de42 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -13,9 +13,7 @@ # limitations under the License. import math -from collections.abc import Iterable - -import numpy as np +from collections.abc import Sequence import paddle from paddle.distribution import distribution @@ -40,44 +38,44 @@ class MultivariateNormal(distribution.Distribution): * :math:`covariance_matrix = \Sigma`: is the k-by-k covariance matrix. Args: - loc(int|float|np.ndarray|Tensor): The mean of Multivariate Normal distribution. The data type of `loc` will be convert to float32. - covariance_matrix(Tensor): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert to float32. - precision_matrix(Tensor): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to float32. - scale_tril(Tensor): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be convert to float32. + loc(int|float|Tensor): The mean of Multivariate Normal distribution. The data type of `loc` will be convert to the global default dtype. + covariance_matrix(Tensor): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert to the global default dtype. + precision_matrix(Tensor): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to the global default dtype. + scale_tril(Tensor): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be convert to the global default dtype. Examples: .. code-block:: python - >>> import paddle - >>> from paddle.distribution import MultivariateNormal - >>> rv = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) - - >>> # doctest: +SKIP - >>> print(rv.sample([3, 2])) - Tensor(shape=[3, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[[0.68554986, 3.85142398], - [1.88336682, 5.43841648]], - - [[5.32492065, 7.23725986], - [3.42192221, 4.83934879]], - - [[3.36775684, 4.46108866], - [4.58927441, 4.32255936]]]) - - >>> # doctest: -SKIP - >>> print(rv.mean) - Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - [2., 5.]) - - >>> print(rv.entropy()) - Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 3.38718319) - - >>> rv1 = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) - >>> rv2 = MultivariateNormal(loc=paddle.to_tensor([-1.,3.]), covariance_matrix=paddle.to_tensor([[3.,2.],[2.,3.]])) - >>> print(rv1.kl_divergence(rv2)) - Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 1.55541301) + import paddle + from paddle.distribution import MultivariateNormal + + # init `loc` and `covariance_matrix` with `paddle.Tensor` + rv = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) + + print(rv.sample([3, 2])) + # Tensor(shape=[3, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + # [[[0.68554986, 3.85142398], + # [1.88336682, 5.43841648]], + # + # [[5.32492065, 7.23725986], + # [3.42192221, 4.83934879]], + # + # [[3.36775684, 4.46108866], + # [4.58927441, 4.32255936]]]) + + print(rv.mean) + # Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + # [2., 5.]) + + print(rv.entropy()) + # Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + # 3.38718319) + + rv1 = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) + rv2 = MultivariateNormal(loc=paddle.to_tensor([-1.,3.]), covariance_matrix=paddle.to_tensor([[3.,2.],[2.,3.]])) + print(rv1.kl_divergence(rv2)) + # Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + # 1.55541301) """ def __init__( @@ -87,14 +85,12 @@ def __init__( precision_matrix=None, scale_tril=None, ): - self.dtype = 'float32' + self.dtype = paddle.get_default_dtype() if isinstance(loc, (float, int)): - loc = paddle.to_tensor([loc], dtype=self.dtype) - elif isinstance(loc, np.ndarray): - loc = paddle.to_tensor(loc, dtype=self.dtype) + loc = [loc] + loc = paddle.to_tensor(loc, dtype=self.dtype) if loc.dim() < 1: loc = loc.reshape((1,)) - loc = paddle.cast(loc, dtype=self.dtype) self.covariance_matrix = None self.precision_matrix = None self.scale_tril = None @@ -220,7 +216,7 @@ def _check_positive_definite(self, value): "covariance_matrix or precision_matrix must be a symmetric matrix" ) is_postive_definite = ( - paddle.cast(paddle.linalg.eigvalsh(value), dtype="float32") > 0 + paddle.cast(paddle.linalg.eigvalsh(value), dtype=self.dtype) > 0 ).all() return is_postive_definite @@ -282,7 +278,7 @@ def sample(self, shape=()): shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is float32. + Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the global default dtype. """ with paddle.no_grad(): return self.rsample(shape) @@ -294,10 +290,10 @@ def rsample(self, shape=()): shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is float32. + Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the global default dtype. """ - if not isinstance(shape, Iterable): - raise TypeError('sample shape must be Iterable object.') + if not isinstance(shape, Sequence): + raise TypeError('sample shape must be Sequence object.') output_shape = self._extend_shape(shape) eps = paddle.normal(shape=output_shape) return self.loc + paddle.matmul( @@ -352,7 +348,7 @@ def entropy(self): * :math:\Omega: is the support of the distribution. Returns: - Tensor, Shannon entropy of Multivariate Normal distribution. The data type is float32. + Tensor, Shannon entropy of Multivariate Normal distribution. The data type is the global default dtype. """ half_log_det = ( self._unbroadcasted_scale_tril.diagonal(axis1=-2, axis2=-1) @@ -381,7 +377,7 @@ def kl_divergence(self, other): other (MultivariateNormal): instance of Multivariate Normal. Returns: - Tensor, kl-divergence between two Multivariate Normal distributions. The data type is float32. + Tensor, kl-divergence between two Multivariate Normal distributions. The data type is the global default dtype. """ if ( From e84391c203460dee560e96aade96bf7ee971bf26 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 1 Dec 2023 11:00:29 +0800 Subject: [PATCH 10/29] fix mvn test --- test/distribution/test_distribution_multivariate_normal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/distribution/test_distribution_multivariate_normal.py b/test/distribution/test_distribution_multivariate_normal.py index a4e91f7d8da24..601ec445abefd 100644 --- a/test/distribution/test_distribution_multivariate_normal.py +++ b/test/distribution/test_distribution_multivariate_normal.py @@ -35,7 +35,7 @@ ( 'multi-batch', parameterize.xrand((2, 3), dtype='float32', min=-2, max=2), - np.array([[6.0, 2.5, 3.0], [2.5, 4.0, 5.0], [3.0, 5.0, 7.0]]), + np.array([[4.0, 2.5, 2.0], [2.5, 3.0, 1.2], [2.0, 1.2, 4.0]]), ), ], ) @@ -85,7 +85,7 @@ def test_sample(self): sample_shape + self._dist.batch_shape + self._dist.event_shape, ) - sample_shape = (10000,) + sample_shape = (50000,) samples = self._dist.sample(sample_shape) sample_mean = samples.mean(axis=0) sample_variance = samples.var(axis=0) From 251127ed99e4e491b5e5b2751c1970d35cd300c9 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 1 Dec 2023 13:21:25 +0800 Subject: [PATCH 11/29] fix test --- test/distribution/test_distribution_continuous_bernoulli.py | 2 +- .../test_distribution_continuous_bernoulli_static.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py index 62404872a3901..f32018d6b73f6 100644 --- a/test/distribution/test_distribution_continuous_bernoulli.py +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -101,7 +101,7 @@ def np_prob(self, value): return np.exp(self.np_log_prob(value)) def np_log_prob(self, value): - eps = 1e-8 + eps = np.finfo('float32').eps cross_entropy = np.nan_to_num( value * np.log(self.probability) + (1.0 - value) * np.log(1 - self.probability), diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index c6d389547c4f7..1f5f3b065d0f2 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -101,7 +101,7 @@ def np_prob(self, value): return np.exp(self.np_log_prob(value)) def np_log_prob(self, value): - eps = 1e-8 + eps = np.finfo('float32').eps cross_entropy = np.nan_to_num( value * np.log(self.probability) + (1.0 - value) * np.log(1 - self.probability), From 8972b7bf56c258128b8329246e7897224292958d Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 1 Dec 2023 22:18:19 +0800 Subject: [PATCH 12/29] fix test --- ...istribution_continuous_bernoulli_static.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index 1f5f3b065d0f2..d50684f001801 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -25,8 +25,8 @@ class ContinuousBernoulli_np: def __init__(self, probability, eps=1e-4): self.eps = eps - self.dtype = 'float32' - eps_prob = 1.1920928955078125e-07 + self.dtype = paddle.get_default_dtype() + eps_prob = np.finfo(self.dtype).eps self.probability = np.clip( probability, a_min=eps_prob, a_max=1 - eps_prob ) @@ -101,7 +101,7 @@ def np_prob(self, value): return np.exp(self.np_log_prob(value)) def np_log_prob(self, value): - eps = np.finfo('float32').eps + eps = np.finfo(self.dtype).eps cross_entropy = np.nan_to_num( value * np.log(self.probability) + (1.0 - value) * np.log(1 - self.probability), @@ -160,7 +160,7 @@ def np_kl_divergence(self, other): [ ( 'multi-dim', - parameterize.xrand((1, 3), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((1, 3), min=0.0, max=1.0).astype("float32"), ), ], ) @@ -195,10 +195,7 @@ def test_mean(self): str(self.mean.dtype).split('.')[-1], self.probability.dtype ) np.testing.assert_allclose( - self.mean, - self._np_mean(), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + self.mean, self._np_mean(), rtol=0.20, atol=0 ) def test_variance(self): @@ -206,10 +203,7 @@ def test_variance(self): str(self.var.dtype).split('.')[-1], self.probability.dtype ) np.testing.assert_allclose( - self.var, - self._np_variance(), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + self.var, self._np_variance(), rtol=0.20, atol=0 ) def test_entropy(self): @@ -217,10 +211,7 @@ def test_entropy(self): str(self.entropy.dtype).split('.')[-1], self.probability.dtype ) np.testing.assert_allclose( - self.entropy, - self._np_entropy(), - rtol=0.0, - atol=0.20, + self.entropy, self._np_entropy(), rtol=0.20, atol=0 ) def test_sample(self): @@ -245,8 +236,8 @@ def _np_entropy(self): [ ( 'value-broadcast-shape', - parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((2, 2), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((1,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((2, 2), min=0.0, max=1.0).astype("float32"), ), ], ) @@ -289,8 +280,8 @@ def test_prob(self): [ ( 'multi-dim', - parameterize.xrand((2,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((2,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((2,), min=0.0, max=1.0).astype("float32"), ), ], ) From b24b1f9127dbdf1b550de0a2ee16064df931fc5e Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Sat, 2 Dec 2023 20:48:53 +0800 Subject: [PATCH 13/29] fix test --- .../distribution/test_distribution_multivariate_normal.py | 8 ++++---- .../test_distribution_multivariate_normal_static.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/distribution/test_distribution_multivariate_normal.py b/test/distribution/test_distribution_multivariate_normal.py index 601ec445abefd..7e24d9ceeee7b 100644 --- a/test/distribution/test_distribution_multivariate_normal.py +++ b/test/distribution/test_distribution_multivariate_normal.py @@ -29,12 +29,12 @@ [ ( 'one-batch', - parameterize.xrand((2,), dtype='float32', min=-2, max=2), + parameterize.xrand((2,), dtype='float32', min=1, max=2), np.array([[2.0, 1.0], [1.0, 2.0]]), ), ( 'multi-batch', - parameterize.xrand((2, 3), dtype='float32', min=-2, max=2), + parameterize.xrand((2, 3), dtype='float32', min=-2, max=-1), np.array([[4.0, 2.5, 2.0], [2.5, 3.0, 1.2], [2.0, 1.2, 4.0]]), ), ], @@ -91,10 +91,10 @@ def test_sample(self): sample_variance = samples.var(axis=0) np.testing.assert_allclose( - sample_mean, self._dist.mean, atol=0.00, rtol=0.40 + sample_mean, self._dist.mean, atol=0.00, rtol=0.20 ) np.testing.assert_allclose( - sample_variance, self._dist.variance, atol=0.00, rtol=0.40 + sample_variance, self._dist.variance, atol=0.00, rtol=0.20 ) def _np_variance(self): diff --git a/test/distribution/test_distribution_multivariate_normal_static.py b/test/distribution/test_distribution_multivariate_normal_static.py index 230263f214b0f..522e88256563a 100644 --- a/test/distribution/test_distribution_multivariate_normal_static.py +++ b/test/distribution/test_distribution_multivariate_normal_static.py @@ -31,12 +31,12 @@ [ ( 'one-batch', - parameterize.xrand((2,), dtype='float32', min=-2, max=2), + parameterize.xrand((2,), dtype='float32', min=1, max=2), np.array([[2.0, 1.0], [1.0, 2.0]]), ), ( 'multi-batch', - parameterize.xrand((2, 3), dtype='float32', min=-2, max=2), + parameterize.xrand((2, 3), dtype='float32', min=-2, max=-1), np.array([[6.0, 2.5, 3.0], [2.5, 4.0, 5.0], [3.0, 5.0, 7.0]]), ), ], @@ -106,9 +106,9 @@ def test_sample(self): ) sample_mean = self.large_samples.mean(axis=0) sample_variance = self.large_samples.var(axis=0) - np.testing.assert_allclose(sample_mean, self.mean, atol=0.05, rtol=0.40) + np.testing.assert_allclose(sample_mean, self.mean, atol=0.00, rtol=0.20) np.testing.assert_allclose( - sample_variance, self.var, atol=0.05, rtol=0.40 + sample_variance, self.var, atol=0.00, rtol=0.20 ) def _np_variance(self): From 559c28edaed24a9861412f573b3705caf187c30c Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Sun, 3 Dec 2023 20:36:11 +0800 Subject: [PATCH 14/29] fix test --- .../test_distribution_continuous_bernoulli.py | 4 ++-- ...distribution_continuous_bernoulli_static.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py index f32018d6b73f6..e34c835b2939a 100644 --- a/test/distribution/test_distribution_continuous_bernoulli.py +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -28,7 +28,7 @@ def __init__(self, probability, eps=1e-4): self.dtype = 'float32' eps_prob = 1.1920928955078125e-07 self.probability = np.clip( - probability, a_min=eps_prob, a_max=1 - eps_prob + probability, a_min=eps_prob, a_max=1.0 - eps_prob ) def _cut_support_region(self): @@ -101,7 +101,7 @@ def np_prob(self, value): return np.exp(self.np_log_prob(value)) def np_log_prob(self, value): - eps = np.finfo('float32').eps + eps = 1e-8 cross_entropy = np.nan_to_num( value * np.log(self.probability) + (1.0 - value) * np.log(1 - self.probability), diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index d50684f001801..7c742485422b0 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -25,10 +25,10 @@ class ContinuousBernoulli_np: def __init__(self, probability, eps=1e-4): self.eps = eps - self.dtype = paddle.get_default_dtype() - eps_prob = np.finfo(self.dtype).eps + self.dtype = 'float32' + eps_prob = 1.1920928955078125e-07 self.probability = np.clip( - probability, a_min=eps_prob, a_max=1 - eps_prob + probability, a_min=eps_prob, a_max=1.0 - eps_prob ) def _cut_support_region(self): @@ -101,7 +101,7 @@ def np_prob(self, value): return np.exp(self.np_log_prob(value)) def np_log_prob(self, value): - eps = np.finfo(self.dtype).eps + eps = 1e-8 cross_entropy = np.nan_to_num( value * np.log(self.probability) + (1.0 - value) * np.log(1 - self.probability), @@ -160,7 +160,7 @@ def np_kl_divergence(self, other): [ ( 'multi-dim', - parameterize.xrand((1, 3), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((1, 3), min=0.1, max=0.9).astype("float32"), ), ], ) @@ -236,8 +236,8 @@ def _np_entropy(self): [ ( 'value-broadcast-shape', - parameterize.xrand((1,), min=0.0, max=1.0).astype("float32"), - parameterize.xrand((2, 2), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2, 2), min=0.1, max=0.9).astype("float32"), ), ], ) @@ -280,8 +280,8 @@ def test_prob(self): [ ( 'multi-dim', - parameterize.xrand((2,), min=0.0, max=1.0).astype("float32"), - parameterize.xrand((2,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((2,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2,), min=0.1, max=0.9).astype("float32"), ), ], ) From ef8006316e2ffc0c740dc11a42b81baf1893d7ca Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 4 Dec 2023 11:56:49 +0800 Subject: [PATCH 15/29] fix unstable region calculation --- .../distribution/continuous_bernoulli.py | 4 +-- .../test_distribution_continuous_bernoulli.py | 26 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index 45107731ace5f..9b5825d00fa8f 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -84,7 +84,7 @@ class ContinuousBernoulli(distribution.Distribution): # [0.20103613, 0.07641447]) """ - def __init__(self, probability, eps=1e-4): + def __init__(self, probability, eps=0.02): self.eps = paddle.to_tensor(eps) self.dtype = paddle.get_default_dtype() self.probability = self._to_tensor(probability) @@ -225,7 +225,7 @@ def variance(self): """ cut_probs = self._cut_probs() tmp = paddle.divide( - paddle.square(cut_probs) - cut_probs, + cut_probs * (cut_probs - 1.0), paddle.square(1.0 - 2.0 * cut_probs), ) propose = tmp + paddle.divide( diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py index e34c835b2939a..fa81a7e2d98fa 100644 --- a/test/distribution/test_distribution_continuous_bernoulli.py +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -158,18 +158,18 @@ def np_kl_divergence(self, other): ('half', np.array(0.5).astype("float32")), ( 'one-dim', - parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((1,), min=0.0, max=1.0).astype("float32"), ), ( 'multi-dim', - parameterize.xrand((2, 3), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2, 3), min=0.0, max=1.0).astype("float32"), ), ], ) class TestContinuousBernoulli(unittest.TestCase): def setUp(self): self._dist = ContinuousBernoulli( - probability=paddle.to_tensor(self.probability) + probability=paddle.to_tensor(self.probability), eps=0.02 ) self._np_dist = ContinuousBernoulli_np(self.probability) @@ -179,8 +179,8 @@ def test_mean(self): np.testing.assert_allclose( mean, self._np_dist.np_mean(), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=0.00, + atol=0.20, ) def test_variance(self): @@ -231,13 +231,13 @@ def test_sample(self): [ ( 'value-same-shape', - parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((5,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((5,), min=0.0, max=1.0).astype("float32"), ), ( 'value-broadcast-shape', - parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((2, 3), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((1,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((2, 3), min=0.0, max=1.0).astype("float32"), ), ], ) @@ -285,13 +285,13 @@ def test_icdf(self): [ ( 'one-dim', - parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((1,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((1,), min=0.0, max=1.0).astype("float32"), ), ( 'multi-dim', - parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((5,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((5,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((5,), min=0.0, max=1.0).astype("float32"), ), ], ) From 7e208a6be197057b319b2ee616a0da8e6782e510 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 4 Dec 2023 14:23:11 +0800 Subject: [PATCH 16/29] fix test --- test/distribution/test_distribution_continuous_bernoulli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py index fa81a7e2d98fa..8ee7eaf8e4233 100644 --- a/test/distribution/test_distribution_continuous_bernoulli.py +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -23,7 +23,7 @@ class ContinuousBernoulli_np: - def __init__(self, probability, eps=1e-4): + def __init__(self, probability, eps=0.02): self.eps = eps self.dtype = 'float32' eps_prob = 1.1920928955078125e-07 From a4903f2e75fbe2bfb0811ad02b63ca3fb9869830 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Wed, 6 Dec 2023 14:08:25 +0800 Subject: [PATCH 17/29] update dtype convertion and tests --- .../distribution/continuous_bernoulli.py | 49 ++++---- .../distribution/multivariate_normal.py | 19 +-- .../test_distribution_continuous_bernoulli.py | 108 ++++++++++++++---- ...istribution_continuous_bernoulli_static.py | 70 +++++++++--- .../test_distribution_multivariate_normal.py | 70 +++++++++++- ...distribution_multivariate_normal_static.py | 14 +-- 6 files changed, 253 insertions(+), 77 deletions(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index 9b5825d00fa8f..e040ef5d00988 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -50,10 +50,11 @@ class ContinuousBernoulli(distribution.Distribution): Args: probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], - which characterize the shape of the pdf. The data type of `probability` will be convert to the global default dtype. + which characterize the shape of the pdf. If the input data type is int or float, the data type of + `probability` will be convert to a 1-D Tensor the paddle global default dtype. eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region - would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The - default value is 1e-4. + would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The + default value is 1e-4. Examples: .. code-block:: python @@ -85,9 +86,13 @@ class ContinuousBernoulli(distribution.Distribution): """ def __init__(self, probability, eps=0.02): - self.eps = paddle.to_tensor(eps) self.dtype = paddle.get_default_dtype() self.probability = self._to_tensor(probability) + self.eps = paddle.to_tensor(eps, dtype=self.dtype) + if not self._check_constraint(self.probability): + raise ValueError( + 'Every element of input parameter `probability` should be nonnegative.' + ) # eps_prob is used to clip the input `probability` in the range of [eps_prob, 1-eps_prob] eps_prob = paddle.finfo(self.probability.dtype).eps @@ -95,10 +100,6 @@ def __init__(self, probability, eps=0.02): self.probability, min=eps_prob, max=1 - eps_prob ) - if not self._check_constraint(self.probability): - raise ValueError( - 'Every element of input parameter `rate` should be nonnegative.' - ) if self.probability.shape == []: batch_shape = (1,) else: @@ -106,15 +107,16 @@ def __init__(self, probability, eps=0.02): super().__init__(batch_shape) def _to_tensor(self, probability): - """Convert the input parameters into tensors with the global default dtype + """Convert the input parameters into tensors Returns: Tensor: converted probability. """ # convert type if isinstance(probability, (float, int)): - probability = [probability] - probability = paddle.to_tensor(probability, dtype=self.dtype) + probability = paddle.to_tensor([probability], dtype=self.dtype) + else: + self.dtype = probability.dtype return probability def _check_constraint(self, value): @@ -169,26 +171,27 @@ def _log_constant(self): Tensor: logarithm of the constant factor """ cut_probs = self._cut_probs() + half = paddle.to_tensor(0.5, dtype=self.dtype) cut_probs_below_half = paddle.where( - paddle.less_equal(cut_probs, paddle.to_tensor(0.5)), + paddle.less_equal(cut_probs, half), cut_probs, paddle.zeros_like(cut_probs), ) cut_probs_above_half = paddle.where( - paddle.greater_equal(cut_probs, paddle.to_tensor(0.5)), + paddle.greater_equal(cut_probs, half), cut_probs, paddle.ones_like(cut_probs), ) log_constant_propose = paddle.log( 2.0 * paddle.abs(self._tanh_inverse(1.0 - 2.0 * cut_probs)) ) - paddle.where( - paddle.less_equal(cut_probs, paddle.to_tensor(0.5)), + paddle.less_equal(cut_probs, half), paddle.log1p(-2.0 * cut_probs_below_half), paddle.log(2.0 * cut_probs_above_half - 1.0), ) x = paddle.square(self.probability - 0.5) taylor_expansion = ( - paddle.log(paddle.to_tensor(2.0)) + paddle.log(paddle.to_tensor(2.0, dtype=self.dtype)) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x ) return paddle.where( @@ -230,7 +233,7 @@ def variance(self): ) propose = tmp + paddle.divide( paddle.to_tensor(1.0, dtype=self.dtype), - paddle.square(2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)), + paddle.square(paddle.log1p(-cut_probs) - paddle.log(cut_probs)), ) x = paddle.square(self.probability - 0.5) taylor_expansion = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x @@ -245,7 +248,7 @@ def sample(self, shape=()): shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, Sampled data with shape `sample_shape` + `batch_shape`. The data type is the global default dtype. + Tensor, Sampled data with shape `sample_shape` + `batch_shape`. """ with paddle.no_grad(): return self.rsample(shape) @@ -257,7 +260,7 @@ def rsample(self, shape=()): shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, Sampled data with shape `sample_shape` + `batch_shape`. The data type is the global default dtype. + Tensor, Sampled data with shape `sample_shape` + `batch_shape`. """ if not isinstance(shape, Sequence): raise TypeError('sample shape must be Sequence object.') @@ -314,7 +317,7 @@ def entropy(self): * :math:\Omega: is the support of the distribution. Returns: - Tensor, Shannon entropy of Continuous Bernoulli distribution. The data type is the global default dtype. + Tensor, Shannon entropy of Continuous Bernoulli distribution. """ log_p = paddle.log(self.probability) log_1_minus_p = paddle.log1p(-self.probability) @@ -359,10 +362,12 @@ def cdf(self, value): ) / (2.0 * cut_probs - 1.0) unbounded_cdfs = paddle.where(self._cut_support_region(), cdfs, value) return paddle.where( - paddle.less_equal(value, paddle.to_tensor(0.0)), + paddle.less_equal(value, paddle.to_tensor(0.0, dtype=self.dtype)), paddle.zeros_like(value), paddle.where( - paddle.greater_equal(value, paddle.to_tensor(1.0)), + paddle.greater_equal( + value, paddle.to_tensor(1.0, dtype=self.dtype) + ), paddle.ones_like(value), unbounded_cdfs, ), @@ -416,7 +421,7 @@ def kl_divergence(self, other): other (ContinuousBernoulli): instance of Continuous Bernoulli. Returns: - Tensor, kl-divergence between two Continuous Bernoulli distributions. The data type is the global default dtype. + Tensor, kl-divergence between two Continuous Bernoulli distributions. """ diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index e7a9c9465de42..7960efa1052c3 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -38,10 +38,14 @@ class MultivariateNormal(distribution.Distribution): * :math:`covariance_matrix = \Sigma`: is the k-by-k covariance matrix. Args: - loc(int|float|Tensor): The mean of Multivariate Normal distribution. The data type of `loc` will be convert to the global default dtype. - covariance_matrix(Tensor): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert to the global default dtype. - precision_matrix(Tensor): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to the global default dtype. - scale_tril(Tensor): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be convert to the global default dtype. + loc(int|float|Tensor): The mean of Multivariate Normal distribution. If the input data type is int or float, the data type of `loc` will be + convert to a 1-D Tensor the paddle global default dtype. + covariance_matrix(Tensor): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert + to be the same as the type of loc. + precision_matrix(Tensor): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to be the same as the + type of loc. + scale_tril(Tensor): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be + convert to be the same as the type of loc. Examples: .. code-block:: python @@ -87,8 +91,9 @@ def __init__( ): self.dtype = paddle.get_default_dtype() if isinstance(loc, (float, int)): - loc = [loc] - loc = paddle.to_tensor(loc, dtype=self.dtype) + loc = paddle.to_tensor([loc], dtype=self.dtype) + else: + self.dtype = loc.dtype if loc.dim() < 1: loc = loc.reshape((1,)) self.covariance_matrix = None @@ -295,7 +300,7 @@ def rsample(self, shape=()): if not isinstance(shape, Sequence): raise TypeError('sample shape must be Sequence object.') output_shape = self._extend_shape(shape) - eps = paddle.normal(shape=output_shape) + eps = paddle.cast(paddle.normal(shape=output_shape), dtype=self.dtype) return self.loc + paddle.matmul( self._unbroadcasted_scale_tril, eps.unsqueeze(-1) ).squeeze(-1) diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py index 8ee7eaf8e4233..d2719be7ecd91 100644 --- a/test/distribution/test_distribution_continuous_bernoulli.py +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -17,6 +17,11 @@ import numpy as np import parameterize from distribution import config +from parameterize import ( + TEST_CASE_NAME, + parameterize_cls, + parameterize_func, +) import paddle from paddle.distribution.continuous_bernoulli import ContinuousBernoulli @@ -158,7 +163,7 @@ def np_kl_divergence(self, other): ('half', np.array(0.5).astype("float32")), ( 'one-dim', - parameterize.xrand((1,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((1,), min=0.0, max=1.0).astype("float64"), ), ( 'multi-dim', @@ -171,7 +176,7 @@ def setUp(self): self._dist = ContinuousBernoulli( probability=paddle.to_tensor(self.probability), eps=0.02 ) - self._np_dist = ContinuousBernoulli_np(self.probability) + self._np_dist = ContinuousBernoulli_np(self.probability, eps=0.02) def test_mean(self): mean = self._dist.mean @@ -179,8 +184,8 @@ def test_mean(self): np.testing.assert_allclose( mean, self._np_dist.np_mean(), - rtol=0.00, - atol=0.20, + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), ) def test_variance(self): @@ -189,8 +194,8 @@ def test_variance(self): np.testing.assert_allclose( var, self._np_dist.np_variance(), - rtol=0.00, - atol=0.20, + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), ) def test_entropy(self): @@ -199,8 +204,8 @@ def test_entropy(self): np.testing.assert_allclose( entropy, self._np_dist.np_entropy(), - rtol=0.00, - atol=0.20, + rtol=0.005, + atol=0.0, ) def test_sample(self): @@ -212,16 +217,22 @@ def test_sample(self): sample_shape + self._dist.batch_shape + self._dist.event_shape, ) - sample_shape = (5000,) + sample_shape = (50000,) samples = self._dist.sample(sample_shape) sample_mean = samples.mean(axis=0) sample_variance = samples.var(axis=0) np.testing.assert_allclose( - sample_mean, self._dist.mean, atol=0, rtol=0.20 + sample_mean, + self._dist.mean, + rtol=0.02, + atol=0.0, ) np.testing.assert_allclose( - sample_variance, self._dist.variance, atol=0, rtol=0.20 + sample_variance, + self._dist.variance, + rtol=0.02, + atol=0.0, ) @@ -236,15 +247,17 @@ def test_sample(self): ), ( 'value-broadcast-shape', - parameterize.xrand((1,), min=0.0, max=1.0).astype("float32"), - parameterize.xrand((2, 3), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((1,), min=0.0, max=1.0).astype("float64"), + parameterize.xrand((2, 3), min=0.0, max=1.0).astype("float64"), ), ], ) class TestContinuousBernoulliProbs(unittest.TestCase): def setUp(self): - self._dist = ContinuousBernoulli(probability=self.probability) - self._np_dist = ContinuousBernoulli_np(self.probability) + self._dist = ContinuousBernoulli( + probability=paddle.to_tensor(self.probability), eps=0.02 + ) + self._np_dist = ContinuousBernoulli_np(self.probability, eps=0.02) def test_prob(self): np.testing.assert_allclose( @@ -290,8 +303,8 @@ def test_icdf(self): ), ( 'multi-dim', - parameterize.xrand((5,), min=0.0, max=1.0).astype("float32"), - parameterize.xrand((5,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((5,), min=0.0, max=1.0).astype("float64"), + parameterize.xrand((5,), min=0.0, max=1.0).astype("float64"), ), ], ) @@ -299,13 +312,13 @@ class TestContinuousBernoulliKL(unittest.TestCase): def setUp(self): paddle.disable_static() self._dist1 = ContinuousBernoulli( - probability=paddle.to_tensor(self.p_1) + probability=paddle.to_tensor(self.p_1), eps=0.02 ) self._dist2 = ContinuousBernoulli( - probability=paddle.to_tensor(self.p_2) + probability=paddle.to_tensor(self.p_2), eps=0.02 ) - self._np_dist1 = ContinuousBernoulli_np(self.p_1) - self._np_dist2 = ContinuousBernoulli_np(self.p_2) + self._np_dist1 = ContinuousBernoulli_np(self.p_1, eps=0.02) + self._np_dist2 = ContinuousBernoulli_np(self.p_2, eps=0.02) def test_kl_divergence(self): kl0 = self._dist1.kl_divergence(self._dist2) @@ -313,7 +326,58 @@ def test_kl_divergence(self): self.assertEqual(tuple(kl0.shape), self._dist1.batch_shape) self.assertEqual(tuple(kl1.shape), self._dist1.batch_shape) - np.testing.assert_allclose(kl0, kl1, rtol=0.1, atol=0.1) + np.testing.assert_allclose( + kl0, + kl1, + rtol=0.005, + atol=0.0, + ) + + +@parameterize.place(config.DEVICES) +@parameterize_cls([TEST_CASE_NAME], ['ContinuousBernoulliTestError']) +class ContinuousBernoulliTestError(unittest.TestCase): + def setUp(self): + paddle.disable_static(self.place) + + @parameterize_func( + [ + (-0.1, ValueError), + (1.1, ValueError), + ] + ) + def test_bad_init(self, probs, error): + with paddle.base.dygraph.guard(self.place): + self.assertRaises(error, ContinuousBernoulli, probs) + + @parameterize_func( + [ + ( + paddle.to_tensor([0.3, 0.5]), + paddle.to_tensor([-0.1, 1.2]), + ), + ] + ) + def test_bad_log_prob_value(self, probs, value): + with paddle.base.dygraph.guard(self.place): + rv = ContinuousBernoulli(probs) + self.assertRaises(ValueError, rv.cdf, value) + self.assertRaises(ValueError, rv.log_prob, value) + self.assertRaises(ValueError, rv.icdf, value) + + @parameterize_func( + [ + ( + paddle.to_tensor([0.3, 0.5]), + paddle.to_tensor([0.2, 0.8, 0.6]), + ), + ] + ) + def test_bad_kl_div(self, probs1, probs2): + with paddle.base.dygraph.guard(self.place): + rv = ContinuousBernoulli(probs1) + rv_other = ContinuousBernoulli(probs2) + self.assertRaises(ValueError, rv.kl_divergence, rv_other) if __name__ == '__main__': diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index 7c742485422b0..68f4feea5a2db 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -17,13 +17,18 @@ import numpy as np import parameterize from distribution import config +from parameterize import ( + TEST_CASE_NAME, + parameterize_cls, + parameterize_func, +) import paddle from paddle.distribution.continuous_bernoulli import ContinuousBernoulli class ContinuousBernoulli_np: - def __init__(self, probability, eps=1e-4): + def __init__(self, probability, eps=0.02): self.eps = eps self.dtype = 'float32' eps_prob = 1.1920928955078125e-07 @@ -160,7 +165,7 @@ def np_kl_divergence(self, other): [ ( 'multi-dim', - parameterize.xrand((1, 3), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((1, 3), min=0.0, max=1.0).astype("float32"), ), ], ) @@ -178,7 +183,7 @@ def setUp(self): mean = dist.mean var = dist.variance entropy = dist.entropy() - large_samples = dist.sample(shape=(1000,)) + large_samples = dist.sample(shape=(50000,)) fetch_list = [mean, var, entropy, large_samples] feed = {'probability': self.probability} @@ -195,7 +200,10 @@ def test_mean(self): str(self.mean.dtype).split('.')[-1], self.probability.dtype ) np.testing.assert_allclose( - self.mean, self._np_mean(), rtol=0.20, atol=0 + self.mean, + self._np_mean(), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), ) def test_variance(self): @@ -203,7 +211,10 @@ def test_variance(self): str(self.var.dtype).split('.')[-1], self.probability.dtype ) np.testing.assert_allclose( - self.var, self._np_variance(), rtol=0.20, atol=0 + self.var, + self._np_variance(), + rtol=config.RTOL.get(str(self.probability.dtype)), + atol=config.ATOL.get(str(self.probability.dtype)), ) def test_entropy(self): @@ -211,14 +222,14 @@ def test_entropy(self): str(self.entropy.dtype).split('.')[-1], self.probability.dtype ) np.testing.assert_allclose( - self.entropy, self._np_entropy(), rtol=0.20, atol=0 + self.entropy, self._np_entropy(), rtol=0.005, atol=0 ) def test_sample(self): sample_mean = self.large_samples.mean(axis=0) sample_variance = self.large_samples.var(axis=0) - np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.20) - np.testing.assert_allclose(sample_variance, self.var, atol=0, rtol=0.20) + np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.02) + np.testing.assert_allclose(sample_variance, self.var, atol=0, rtol=0.02) def _np_variance(self): return self._np_dist.np_variance() @@ -236,8 +247,8 @@ def _np_entropy(self): [ ( 'value-broadcast-shape', - parameterize.xrand((1,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((2, 2), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((1,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((2, 2), min=0.0, max=1.0).astype("float64"), ), ], ) @@ -280,8 +291,8 @@ def test_prob(self): [ ( 'multi-dim', - parameterize.xrand((2,), min=0.1, max=0.9).astype("float32"), - parameterize.xrand((2,), min=0.1, max=0.9).astype("float32"), + parameterize.xrand((2,), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((2,), min=0.0, max=1.0).astype("float32"), ), ], ) @@ -314,7 +325,40 @@ def test_kl_divergence(self): self.assertEqual(tuple(kl0.shape), self.p_1.shape) self.assertEqual(tuple(kl1.shape), self.p_1.shape) - np.testing.assert_allclose(kl0, kl1, rtol=0, atol=0.2) + np.testing.assert_allclose( + kl0, + kl1, + rtol=0.005, + atol=0.0, + ) + + +@parameterize.place(config.DEVICES) +@parameterize_cls([TEST_CASE_NAME], ['ContinuousBernoulliTestError']) +class ContinuousBernoulliTestError(unittest.TestCase): + def setUp(self): + self.program = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + + @parameterize_func( + [ + (100,), # int + (100.0,), # float + ] + ) + def test_bad_sample_shape_type(self, shape): + with paddle.static.program_guard(self.program): + rv = ContinuousBernoulli(0.3) + + with self.assertRaises(TypeError): + [_] = self.executor.run( + self.program, feed={}, fetch_list=[rv.sample(shape)] + ) + + with self.assertRaises(TypeError): + [_] = self.executor.run( + self.program, feed={}, fetch_list=[rv.rsample(shape)] + ) if __name__ == '__main__': diff --git a/test/distribution/test_distribution_multivariate_normal.py b/test/distribution/test_distribution_multivariate_normal.py index 7e24d9ceeee7b..e1e5fcb0d0314 100644 --- a/test/distribution/test_distribution_multivariate_normal.py +++ b/test/distribution/test_distribution_multivariate_normal.py @@ -18,6 +18,11 @@ import parameterize import scipy from distribution import config +from parameterize import ( + TEST_CASE_NAME, + parameterize_cls, + parameterize_func, +) import paddle from paddle.distribution.multivariate_normal import MultivariateNormal @@ -34,7 +39,7 @@ ), ( 'multi-batch', - parameterize.xrand((2, 3), dtype='float32', min=-2, max=-1), + parameterize.xrand((2, 3), dtype='float64', min=-2, max=-1), np.array([[4.0, 2.5, 2.0], [2.5, 3.0, 1.2], [2.0, 1.2, 4.0]]), ), ], @@ -90,11 +95,12 @@ def test_sample(self): sample_mean = samples.mean(axis=0) sample_variance = samples.var(axis=0) + # `atol` and `rtol` refer to ``test_distribution_normal`` and ``test_distribution_lognormal`` np.testing.assert_allclose( - sample_mean, self._dist.mean, atol=0.00, rtol=0.20 + sample_mean, self._dist.mean, atol=0.0, rtol=0.1 ) np.testing.assert_allclose( - sample_variance, self._dist.variance, atol=0.00, rtol=0.20 + sample_variance, self._dist.variance, atol=0.0, rtol=0.1 ) def _np_variance(self): @@ -136,16 +142,16 @@ def _np_entropy(self): ), ( 'value-broadcast-shape', - parameterize.xrand((2,), dtype='float32', min=-2, max=2), + parameterize.xrand((2,), dtype='float64', min=-2, max=2), np.array([[2.0, 1.0], [1.0, 2.0]]), - parameterize.xrand((3, 2), dtype='float32', min=-5, max=5), + parameterize.xrand((3, 2), dtype='float64', min=-5, max=5), ), ], ) class TestMVNProbs(unittest.TestCase): def setUp(self): self._dist = MultivariateNormal( - loc=self.loc, + loc=paddle.to_tensor(self.loc), precision_matrix=paddle.to_tensor(self.precision_matrix), ) self.cov = np.linalg.inv(self.precision_matrix) @@ -248,5 +254,57 @@ def kl_divergence(self, dist1, dist2): return half_log_det_2 - half_log_det_1 + 0.5 * (expectation - 2.0) +@parameterize.place(config.DEVICES) +@parameterize_cls([TEST_CASE_NAME], ['MVNTestError']) +class MVNTestError(unittest.TestCase): + def setUp(self): + paddle.disable_static(self.place) + + @parameterize_func( + [ + (5, None, ValueError), # no matrix input + ( + 5, + paddle.to_tensor([2.0, 3.0]), + ValueError, + ), # wrong input matrix dim + ( + 5, + paddle.to_tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]]), + ValueError, + ), # non-sqaure input matrix + ( + 5, + paddle.to_tensor([[2.0, 3.0], [2.0, 3.0]]), + ValueError, + ), # non-symmetric input matrix + ( + 5, + paddle.to_tensor([[-2.0, 3.0], [3.0, -1.0]]), + ValueError, + ), # non-psd input matrix + ] + ) + def test_bad_cov_matrix(self, loc, matrix, error): + # with paddle.base.dygraph.guard(self.place): + self.assertRaises(error, MultivariateNormal, loc, matrix) + + @parameterize_func( + [ + ( + 1.0, + 2.0, + paddle.to_tensor([[3.0, 2.0], [2.0, 3.0]]), + paddle.to_tensor([[2.0, 1.0], [1.0, 2.0]]), + ), + ] + ) + def test_bad_kl_div(self, loc1, loc2, matrix1, matrix2): + # with paddle.base.dygraph.guard(self.place): + rv = MultivariateNormal(loc1, covariance_matrix=matrix1) + rv_other = MultivariateNormal(loc2, covariance_matrix=matrix2) + self.assertRaises(ValueError, rv.kl_divergence, rv_other) + + if __name__ == '__main__': unittest.main(argv=[''], verbosity=3, exit=False) diff --git a/test/distribution/test_distribution_multivariate_normal_static.py b/test/distribution/test_distribution_multivariate_normal_static.py index 522e88256563a..7a9b8702c7608 100644 --- a/test/distribution/test_distribution_multivariate_normal_static.py +++ b/test/distribution/test_distribution_multivariate_normal_static.py @@ -36,7 +36,7 @@ ), ( 'multi-batch', - parameterize.xrand((2, 3), dtype='float32', min=-2, max=-1), + parameterize.xrand((2, 3), dtype='float64', min=-2, max=-1), np.array([[6.0, 2.5, 3.0], [2.5, 4.0, 5.0], [3.0, 5.0, 7.0]]), ), ], @@ -106,10 +106,10 @@ def test_sample(self): ) sample_mean = self.large_samples.mean(axis=0) sample_variance = self.large_samples.var(axis=0) - np.testing.assert_allclose(sample_mean, self.mean, atol=0.00, rtol=0.20) - np.testing.assert_allclose( - sample_variance, self.var, atol=0.00, rtol=0.20 - ) + + # `atol` and `rtol` refer to ``test_distribution_normal`` and ``test_distribution_lognormal`` + np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.1) + np.testing.assert_allclose(sample_variance, self.var, atol=0, rtol=0.1) def _np_variance(self): batch_shape = np.broadcast_shapes( @@ -150,9 +150,9 @@ def _np_entropy(self): ), ( 'value-broadcast-shape', - parameterize.xrand((2,), dtype='float32', min=-2, max=2), + parameterize.xrand((2,), dtype='float64', min=-2, max=2), np.array([[2.0, 1.0], [1.0, 2.0]]), - parameterize.xrand((3, 2), dtype='float32', min=-5, max=5), + parameterize.xrand((3, 2), dtype='float64', min=-5, max=5), ), ], ) From 52e0fb47a93e205719dbd67e924da77cb421663a Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Wed, 6 Dec 2023 18:59:27 +0800 Subject: [PATCH 18/29] fix test --- python/paddle/distribution/continuous_bernoulli.py | 14 ++++++++++---- .../test_distribution_continuous_bernoulli.py | 13 +++++++------ ...est_distribution_continuous_bernoulli_static.py | 9 +++++---- .../test_distribution_multivariate_normal.py | 12 ++++++------ 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index e040ef5d00988..7e4d88dc35b3a 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -322,10 +322,16 @@ def entropy(self): log_p = paddle.log(self.probability) log_1_minus_p = paddle.log1p(-self.probability) - return ( - -self._log_constant() - + self.mean * (log_1_minus_p - log_p) - - log_1_minus_p + return paddle.where( + paddle.equal( + self.probability, paddle.to_tensor(0.5, dtype=self.dtype) + ), + paddle.full_like(self.probability, 0.0), + ( + -self._log_constant() + + self.mean * (log_1_minus_p - log_p) + - log_1_minus_p + ), ) def cdf(self, value): diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py index d2719be7ecd91..b49b9fe6a9a5c 100644 --- a/test/distribution/test_distribution_continuous_bernoulli.py +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -94,8 +94,9 @@ def np_mean(self): return np.where(self._cut_support_region(), propose, taylor_expansion) def np_entropy(self): - log_p = np.log(self.probability) - log_1_minus_p = np.log1p(-self.probability) + cut_probs = self._cut_probs() + log_p = np.log(cut_probs) + log_1_minus_p = np.log1p(-cut_probs) return ( -self._log_constant() + self.np_mean() * (log_1_minus_p - log_p) @@ -163,11 +164,11 @@ def np_kl_divergence(self, other): ('half', np.array(0.5).astype("float32")), ( 'one-dim', - parameterize.xrand((1,), min=0.0, max=1.0).astype("float64"), + parameterize.xrand((1,), min=0.0, max=0.498).astype("float64"), ), ( 'multi-dim', - parameterize.xrand((2, 3), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((2, 3), min=0.498, max=1.0).astype("float32"), ), ], ) @@ -194,8 +195,8 @@ def test_variance(self): np.testing.assert_allclose( var, self._np_dist.np_variance(), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=0.005, + atol=0.0, ) def test_entropy(self): diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index 68f4feea5a2db..d6ea4ca4d1189 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -94,8 +94,9 @@ def np_mean(self): return np.where(self._cut_support_region(), propose, taylor_expansion) def np_entropy(self): - log_p = np.log(self.probability) - log_1_minus_p = np.log1p(-self.probability) + cut_probs = self._cut_probs() + log_p = np.log(cut_probs) + log_1_minus_p = np.log1p(-cut_probs) return ( -self._log_constant() + self.np_mean() * (log_1_minus_p - log_p) @@ -213,8 +214,8 @@ def test_variance(self): np.testing.assert_allclose( self.var, self._np_variance(), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=0.005, + atol=0.0, ) def test_entropy(self): diff --git a/test/distribution/test_distribution_multivariate_normal.py b/test/distribution/test_distribution_multivariate_normal.py index e1e5fcb0d0314..3c6059defe859 100644 --- a/test/distribution/test_distribution_multivariate_normal.py +++ b/test/distribution/test_distribution_multivariate_normal.py @@ -286,8 +286,8 @@ def setUp(self): ] ) def test_bad_cov_matrix(self, loc, matrix, error): - # with paddle.base.dygraph.guard(self.place): - self.assertRaises(error, MultivariateNormal, loc, matrix) + with paddle.base.dygraph.guard(self.place): + self.assertRaises(error, MultivariateNormal, loc, matrix) @parameterize_func( [ @@ -300,10 +300,10 @@ def test_bad_cov_matrix(self, loc, matrix, error): ] ) def test_bad_kl_div(self, loc1, loc2, matrix1, matrix2): - # with paddle.base.dygraph.guard(self.place): - rv = MultivariateNormal(loc1, covariance_matrix=matrix1) - rv_other = MultivariateNormal(loc2, covariance_matrix=matrix2) - self.assertRaises(ValueError, rv.kl_divergence, rv_other) + with paddle.base.dygraph.guard(self.place): + rv = MultivariateNormal(loc1, covariance_matrix=matrix1) + rv_other = MultivariateNormal(loc2, covariance_matrix=matrix2) + self.assertRaises(ValueError, rv.kl_divergence, rv_other) if __name__ == '__main__': From 3a52d072b267d2e05791c6d64a85478ab14afd31 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 7 Dec 2023 12:19:28 +0800 Subject: [PATCH 19/29] fix test --- .../test_distribution_continuous_bernoulli.py | 17 ++++++++++------- ..._distribution_continuous_bernoulli_static.py | 17 ++++++++++------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py index b49b9fe6a9a5c..e7927e81e806e 100644 --- a/test/distribution/test_distribution_continuous_bernoulli.py +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -94,13 +94,16 @@ def np_mean(self): return np.where(self._cut_support_region(), propose, taylor_expansion) def np_entropy(self): - cut_probs = self._cut_probs() - log_p = np.log(cut_probs) - log_1_minus_p = np.log1p(-cut_probs) - return ( - -self._log_constant() - + self.np_mean() * (log_1_minus_p - log_p) - - log_1_minus_p + log_p = np.log(self.probability) + log_1_minus_p = np.log1p(-self.probability) + return np.where( + np.equal(self.probability, 0.5), + np.full_like(self.probability, 0.0), + ( + -self._log_constant() + + self.np_mean() * (log_1_minus_p - log_p) + - log_1_minus_p + ), ) def np_prob(self, value): diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index d6ea4ca4d1189..d253a10d8f6dd 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -94,13 +94,16 @@ def np_mean(self): return np.where(self._cut_support_region(), propose, taylor_expansion) def np_entropy(self): - cut_probs = self._cut_probs() - log_p = np.log(cut_probs) - log_1_minus_p = np.log1p(-cut_probs) - return ( - -self._log_constant() - + self.np_mean() * (log_1_minus_p - log_p) - - log_1_minus_p + log_p = np.log(self.probability) + log_1_minus_p = np.log1p(-self.probability) + return np.where( + np.equal(self.probability, 0.5), + np.full_like(self.probability, 0.0), + ( + -self._log_constant() + + self.np_mean() * (log_1_minus_p - log_p) + - log_1_minus_p + ), ) def np_prob(self, value): From 04564bc85cd41cd4cdbb4bdc4ccc2e460c120af0 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 7 Dec 2023 17:28:40 +0800 Subject: [PATCH 20/29] fix test --- .../test_distribution_continuous_bernoulli_static.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index d253a10d8f6dd..ab0c9ea034d46 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -169,7 +169,7 @@ def np_kl_divergence(self, other): [ ( 'multi-dim', - parameterize.xrand((1, 3), min=0.0, max=1.0).astype("float32"), + parameterize.xrand((1, 3), min=0.0, max=0.498).astype("float32"), ), ], ) From 4532cd82967eedee1ebbaa2f08f281aa8f1e3d92 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Wed, 13 Dec 2023 15:08:59 +0800 Subject: [PATCH 21/29] refine docs --- python/paddle/distribution/continuous_bernoulli.py | 10 +++++----- python/paddle/distribution/multivariate_normal.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index 7e4d88dc35b3a..8d91e0ae5f09f 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -54,7 +54,7 @@ class ContinuousBernoulli(distribution.Distribution): `probability` will be convert to a 1-D Tensor the paddle global default dtype. eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The - default value is 1e-4. + default value is 0.02. Examples: .. code-block:: python @@ -277,7 +277,7 @@ def log_prob(self, value): value (Tensor): The input tensor. Returns: - Tensor: log probability. The data type is same with :attr:`value` . + Tensor: log probability. The data type is the same as `self.probability`. """ value = paddle.cast(value, dtype=self.dtype) if not self._check_constraint(value): @@ -299,7 +299,7 @@ def prob(self, value): value (Tensor): The input tensor. Returns: - Tensor: probability. The data type is same with :attr:`value` . + Tensor: probability. The data type is the same as `self.probability`. """ return paddle.exp(self.log_prob(value)) @@ -352,7 +352,7 @@ def cdf(self, value): value (Tensor): The input tensor. Returns: - Tensor: quantile of :attr:`value`. The data type is same with :attr:`value` . + Tensor: quantile of :attr:`value`. The data type is the same as `self.probability`. """ value = paddle.cast(value, dtype=self.dtype) if not self._check_constraint(value): @@ -396,7 +396,7 @@ def icdf(self, value): value (Tensor): The input tensor, meaning the quantile. Returns: - Tensor: the value of the r.v. corresponding to the quantile. The data type is same with :attr:`value` . + Tensor: the value of the r.v. corresponding to the quantile. The data type is the same as `self.probability`. """ value = paddle.cast(value, dtype=self.dtype) if not self._check_constraint(value): diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 7960efa1052c3..be0795c88d4f2 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -283,7 +283,7 @@ def sample(self, shape=()): shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the global default dtype. + Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the same as `self.loc`. """ with paddle.no_grad(): return self.rsample(shape) @@ -295,7 +295,7 @@ def rsample(self, shape=()): shape (Sequence[int], optional): Prepended shape of the generated samples. Returns: - Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the global default dtype. + Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the same as `self.loc`. """ if not isinstance(shape, Sequence): raise TypeError('sample shape must be Sequence object.') @@ -312,7 +312,7 @@ def log_prob(self, value): value (Tensor): The input tensor. Returns: - Tensor: log probability. The data type is same with :attr:`value` . + Tensor: log probability. The data type is the same as `self.loc`. """ value = paddle.cast(value, dtype=self.dtype) @@ -335,7 +335,7 @@ def prob(self, value): value (Tensor): The input tensor. Returns: - Tensor: probability. The data type is same with :attr:`value` . + Tensor: probability. The data type is the same as `self.loc`. """ return paddle.exp(self.log_prob(value)) @@ -353,7 +353,7 @@ def entropy(self): * :math:\Omega: is the support of the distribution. Returns: - Tensor, Shannon entropy of Multivariate Normal distribution. The data type is the global default dtype. + Tensor, Shannon entropy of Multivariate Normal distribution. The data type is the same as `self.loc`. """ half_log_det = ( self._unbroadcasted_scale_tril.diagonal(axis1=-2, axis2=-1) @@ -382,7 +382,7 @@ def kl_divergence(self, other): other (MultivariateNormal): instance of Multivariate Normal. Returns: - Tensor, kl-divergence between two Multivariate Normal distributions. The data type is the global default dtype. + Tensor, kl-divergence between two Multivariate Normal distributions. The data type is the same as `self.loc`. """ if ( From 6e5d19a6bf77f9ccb322b2bf54e97fbf7410ee9f Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Wed, 13 Dec 2023 20:15:17 +0800 Subject: [PATCH 22/29] update docs --- .../distribution/continuous_bernoulli.py | 63 +++++++++------- .../distribution/multivariate_normal.py | 71 ++++++++++--------- 2 files changed, 75 insertions(+), 59 deletions(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index 8d91e0ae5f09f..26d05d6bdf107 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -22,7 +22,7 @@ class ContinuousBernoulli(distribution.Distribution): r"""The Continuous Bernoulli distribution with parameter: `probability` characterizing the shape of the density function. The Continuous Bernoulli distribution is defined on [0, 1], and it can be viewed as a continuous version of the Bernoulli distribution. - [1] Loaiza-Ganem, G., & Cunningham, J. P. The continuous Bernoulli: fixing a pervasive error in variational autoencoders. 2019. + `The continuous Bernoulli: fixing a pervasive error in variational autoencoders. `_ Mathematical details @@ -50,39 +50,52 @@ class ContinuousBernoulli(distribution.Distribution): Args: probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], - which characterize the shape of the pdf. If the input data type is int or float, the data type of - `probability` will be convert to a 1-D Tensor the paddle global default dtype. + which characterize the shape of the pdf. If the input data type is int or float, the data type of + `probability` will be convert to a 1-D Tensor the paddle global default dtype. eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region - would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The - default value is 0.02. + would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The + default value is 0.02. Examples: .. code-block:: python - import paddle - from paddle.distribution import ContinuousBernoulli + >>> import paddle + >>> from paddle.distribution import ContinuousBernoulli + >>> paddle.set_device("cpu") + >>> paddle.seed(100) - # init `probability` with `paddle.Tensor` - rv = ContinuousBernoulli(paddle.to_tensor([0.2, 0.5])) + >>> rv = ContinuousBernoulli(paddle.to_tensor([0.2, 0.5])) - print(rv.sample([2])) - # Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - # [[0.09428147, 0.81438422], - # [0.24624705, 0.93354583]]) + >>> print(rv.sample([2])) + Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0.38694882, 0.20714243], + [0.00631948, 0.51577556]]) - print(rv.mean) - # Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - # [0.38801414, 0.50000000]) + >>> print(rv.mean) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.38801414, 0.50000000]) - print(rv.entropy()) - # Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - # [-0.07641461, 0. ]) + >>> print(rv.variance) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.07589778, 0.08333334]) - rv1 = ContinuousBernoulli(paddle.to_tensor([0.2, 0.8])) - rv2 = ContinuousBernoulli(paddle.to_tensor([0.7, 0.5])) - print(rv1.kl_divergence(rv2)) - # Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - # [0.20103613, 0.07641447]) + >>> print(rv.entropy()) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [-0.07641457, 0. ]) + + >>> print(rv.cdf(paddle.to_tensor(0.1))) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.17259926, 0.10000000]) + + >>> print(rv.icdf(paddle.to_tensor(0.1))) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.05623737, 0.10000000]) + + >>> rv1 = ContinuousBernoulli(paddle.to_tensor([0.2, 0.8])) + >>> rv2 = ContinuousBernoulli(paddle.to_tensor([0.7, 0.5])) + >>> print(rv1.kl_divergence(rv2)) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.20103608, 0.07641447]) """ def __init__(self, probability, eps=0.02): @@ -314,7 +327,7 @@ def entropy(self): In the above equation: - * :math:\Omega: is the support of the distribution. + * :math:`\Omega`: is the support of the distribution. Returns: Tensor, Shannon entropy of Continuous Bernoulli distribution. diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index be0795c88d4f2..bad58bbb0b2be 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -39,47 +39,50 @@ class MultivariateNormal(distribution.Distribution): Args: loc(int|float|Tensor): The mean of Multivariate Normal distribution. If the input data type is int or float, the data type of `loc` will be - convert to a 1-D Tensor the paddle global default dtype. + convert to a 1-D Tensor the paddle global default dtype. covariance_matrix(Tensor): The covariance matrix of Multivariate Normal distribution. The data type of `covariance_matrix` will be convert - to be the same as the type of loc. + to be the same as the type of loc. precision_matrix(Tensor): The inverse of the covariance matrix. The data type of `precision_matrix` will be convert to be the same as the - type of loc. + type of loc. scale_tril(Tensor): The cholesky decomposition (lower triangular matrix) of the covariance matrix. The data type of `scale_tril` will be - convert to be the same as the type of loc. + convert to be the same as the type of loc. Examples: .. code-block:: python - import paddle - from paddle.distribution import MultivariateNormal - - # init `loc` and `covariance_matrix` with `paddle.Tensor` - rv = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) - - print(rv.sample([3, 2])) - # Tensor(shape=[3, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - # [[[0.68554986, 3.85142398], - # [1.88336682, 5.43841648]], - # - # [[5.32492065, 7.23725986], - # [3.42192221, 4.83934879]], - # - # [[3.36775684, 4.46108866], - # [4.58927441, 4.32255936]]]) - - print(rv.mean) - # Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, - # [2., 5.]) - - print(rv.entropy()) - # Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - # 3.38718319) - - rv1 = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) - rv2 = MultivariateNormal(loc=paddle.to_tensor([-1.,3.]), covariance_matrix=paddle.to_tensor([[3.,2.],[2.,3.]])) - print(rv1.kl_divergence(rv2)) - # Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - # 1.55541301) + >>> import paddle + >>> from paddle.distribution import MultivariateNormal + >>> paddle.set_device("cpu") + >>> paddle.seed(100) + + >>> rv = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) + + >>> print(rv.sample([3, 2])) + Tensor(shape=[3, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[2.36634731, 3.44818163], + [1.57115066, 4.79757214]], + [[0.91755736, 2.81447577], + [0.12842906, 4.38841820]], + [[1.60453653, 5.57910490], + [1.28331566, 2.50838280]]]) + + >>> print(rv.mean) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [2., 5.]) + + >>> print(rv.variance) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [1.99999988, 2. ]) + + >>> print(rv.entropy()) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 3.38718319) + + >>> rv1 = MultivariateNormal(loc=paddle.to_tensor([2.,5.]), covariance_matrix=paddle.to_tensor([[2.,1.],[1.,2.]])) + >>> rv2 = MultivariateNormal(loc=paddle.to_tensor([-1.,3.]), covariance_matrix=paddle.to_tensor([[3.,2.],[2.,3.]])) + >>> print(rv1.kl_divergence(rv2)) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 1.55541301) """ def __init__( From 806c1f67d161d10c48cb5de3552f393018b5466c Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Wed, 13 Dec 2023 20:16:59 +0800 Subject: [PATCH 23/29] update docs --- python/paddle/distribution/multivariate_normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index bad58bbb0b2be..0cc9428f12c3d 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -353,7 +353,7 @@ def entropy(self): In the above equation: - * :math:\Omega: is the support of the distribution. + * :math:`\Omega`: is the support of the distribution. Returns: Tensor, Shannon entropy of Multivariate Normal distribution. The data type is the same as `self.loc`. From 1c645a891221535c1640424bb2887b79f2a16938 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 14 Dec 2023 10:13:06 +0800 Subject: [PATCH 24/29] update docs --- python/paddle/distribution/multivariate_normal.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 0cc9428f12c3d..ef726d4d1f041 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -59,12 +59,12 @@ class MultivariateNormal(distribution.Distribution): >>> print(rv.sample([3, 2])) Tensor(shape=[3, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, - [[[2.36634731, 3.44818163], - [1.57115066, 4.79757214]], - [[0.91755736, 2.81447577], - [0.12842906, 4.38841820]], - [[1.60453653, 5.57910490], - [1.28331566, 2.50838280]]]) + [[[-0.00339603, 4.31556797], + [ 2.01385283, 4.63553190]], + [[ 0.10132277, 3.11323833], + [ 2.37435842, 3.56635118]], + [[ 2.89701366, 5.10602522], + [-0.46329355, 3.14768648]]]) >>> print(rv.mean) Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, From dec2efb1492def6deb809a91634fd104d9d5f1e1 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 14 Dec 2023 13:04:44 +0800 Subject: [PATCH 25/29] update cb api --- .../distribution/continuous_bernoulli.py | 91 ++++++++-------- .../test_distribution_continuous_bernoulli.py | 101 +++++++++--------- ...istribution_continuous_bernoulli_static.py | 93 ++++++++-------- 3 files changed, 135 insertions(+), 150 deletions(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index 26d05d6bdf107..cb44b340c4c9b 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -19,7 +19,7 @@ class ContinuousBernoulli(distribution.Distribution): - r"""The Continuous Bernoulli distribution with parameter: `probability` characterizing the shape of the density function. + r"""The Continuous Bernoulli distribution with parameter: `probs` characterizing the shape of the density function. The Continuous Bernoulli distribution is defined on [0, 1], and it can be viewed as a continuous version of the Bernoulli distribution. `The continuous Bernoulli: fixing a pervasive error in variational autoencoders. `_ @@ -35,7 +35,7 @@ class ContinuousBernoulli(distribution.Distribution): In the above equation: * :math:`x`: is continuous between 0 and 1 - * :math:`probability = \lambda`: is the probability. + * :math:`probs = \lambda`: is the probability. * :math:`C(\lambda)`: is the normalizing constant factor .. math:: @@ -49,12 +49,11 @@ class ContinuousBernoulli(distribution.Distribution): \right. Args: - probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], + probs(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1], which characterize the shape of the pdf. If the input data type is int or float, the data type of - `probability` will be convert to a 1-D Tensor the paddle global default dtype. - eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region - would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The - default value is 0.02. + `probs` will be convert to a 1-D Tensor the paddle global default dtype. + lims(tuple): Specify the unstable calculation region near 0.5, where the calculation is approximated + by talyor expansion. The default value is (0.499, 0.501). Examples: .. code-block:: python @@ -98,39 +97,37 @@ class ContinuousBernoulli(distribution.Distribution): [0.20103608, 0.07641447]) """ - def __init__(self, probability, eps=0.02): + def __init__(self, probs=None, lims=(0.499, 0.501)): self.dtype = paddle.get_default_dtype() - self.probability = self._to_tensor(probability) - self.eps = paddle.to_tensor(eps, dtype=self.dtype) - if not self._check_constraint(self.probability): + self.probs = self._to_tensor(probs) + self.lims = paddle.to_tensor(lims, dtype=self.dtype) + if not self._check_constraint(self.probs): raise ValueError( - 'Every element of input parameter `probability` should be nonnegative.' + 'Every element of input parameter `probs` should be nonnegative.' ) - # eps_prob is used to clip the input `probability` in the range of [eps_prob, 1-eps_prob] - eps_prob = paddle.finfo(self.probability.dtype).eps - self.probability = paddle.clip( - self.probability, min=eps_prob, max=1 - eps_prob - ) + # eps_prob is used to clip the input `probs` in the range of [eps_prob, 1-eps_prob] + eps_prob = paddle.finfo(self.probs.dtype).eps + self.probs = paddle.clip(self.probs, min=eps_prob, max=1 - eps_prob) - if self.probability.shape == []: + if self.probs.shape == []: batch_shape = (1,) else: - batch_shape = self.probability.shape + batch_shape = self.probs.shape super().__init__(batch_shape) - def _to_tensor(self, probability): + def _to_tensor(self, probs): """Convert the input parameters into tensors Returns: Tensor: converted probability. """ # convert type - if isinstance(probability, (float, int)): - probability = paddle.to_tensor([probability], dtype=self.dtype) + if isinstance(probs, (float, int)): + probs = paddle.to_tensor([probs], dtype=self.dtype) else: - self.dtype = probability.dtype - return probability + self.dtype = probs.dtype + return probs def _check_constraint(self, value): """Check the constraint for input parameters @@ -144,26 +141,26 @@ def _check_constraint(self, value): return (value >= 0).all() and (value <= 1).all() def _cut_support_region(self): - """Generate stable support region indicator (prob < 0.5 - self.eps && prob >= 0.5 + self.eps ) + """Generate stable support region indicator (prob < self.lims[0] && prob >= self.lims[1] ) Returns: Tensor: the element of the returned indicator tensor corresponding to stable region is True, and False otherwise """ return paddle.logical_or( - paddle.less_equal(self.probability, 0.5 - self.eps), - paddle.greater_than(self.probability, 0.5 + self.eps), + paddle.less_equal(self.probs, self.lims[0]), + paddle.greater_than(self.probs, self.lims[1]), ) def _cut_probs(self): """Cut the probability parameter with stable support region Returns: - Tensor: the element of the returned probability tensor corresponding to unstable region is set to be (0.5 - self.eps), and unchanged otherwise + Tensor: the element of the returned probability tensor corresponding to unstable region is set to be self.lims[0], and unchanged otherwise """ return paddle.where( self._cut_support_region(), - self.probability, - (0.5 - self.eps) * paddle.ones_like(self.probability), + self.probs, + self.lims[0] * paddle.ones_like(self.probs), ) def _tanh_inverse(self, value): @@ -202,7 +199,7 @@ def _log_constant(self): paddle.log1p(-2.0 * cut_probs_below_half), paddle.log(2.0 * cut_probs_above_half - 1.0), ) - x = paddle.square(self.probability - 0.5) + x = paddle.square(self.probs - 0.5) taylor_expansion = ( paddle.log(paddle.to_tensor(2.0, dtype=self.dtype)) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x @@ -224,7 +221,7 @@ def mean(self): paddle.to_tensor(1.0, dtype=self.dtype), 2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs), ) - x = self.probability - 0.5 + x = self.probs - 0.5 taylor_expansion = ( 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * paddle.square(x)) * x ) @@ -248,7 +245,7 @@ def variance(self): paddle.to_tensor(1.0, dtype=self.dtype), paddle.square(paddle.log1p(-cut_probs) - paddle.log(cut_probs)), ) - x = paddle.square(self.probability - 0.5) + x = paddle.square(self.probs - 0.5) taylor_expansion = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x return paddle.where( self._cut_support_region(), propose, taylor_expansion @@ -290,17 +287,17 @@ def log_prob(self, value): value (Tensor): The input tensor. Returns: - Tensor: log probability. The data type is the same as `self.probability`. + Tensor: log probability. The data type is the same as `self.probs`. """ value = paddle.cast(value, dtype=self.dtype) if not self._check_constraint(value): raise ValueError( 'Every element of input parameter `value` should be >= 0.0 and <= 1.0.' ) - eps = paddle.finfo(self.probability.dtype).eps + eps = paddle.finfo(self.probs.dtype).eps cross_entropy = paddle.nan_to_num( - value * paddle.log(self.probability) - + (1.0 - value) * paddle.log(1 - self.probability), + value * paddle.log(self.probs) + + (1.0 - value) * paddle.log(1 - self.probs), neginf=-eps, ) return self._log_constant() + cross_entropy @@ -312,7 +309,7 @@ def prob(self, value): value (Tensor): The input tensor. Returns: - Tensor: probability. The data type is the same as `self.probability`. + Tensor: probability. The data type is the same as `self.probs`. """ return paddle.exp(self.log_prob(value)) @@ -332,14 +329,12 @@ def entropy(self): Returns: Tensor, Shannon entropy of Continuous Bernoulli distribution. """ - log_p = paddle.log(self.probability) - log_1_minus_p = paddle.log1p(-self.probability) + log_p = paddle.log(self.probs) + log_1_minus_p = paddle.log1p(-self.probs) return paddle.where( - paddle.equal( - self.probability, paddle.to_tensor(0.5, dtype=self.dtype) - ), - paddle.full_like(self.probability, 0.0), + paddle.equal(self.probs, paddle.to_tensor(0.5, dtype=self.dtype)), + paddle.full_like(self.probs, 0.0), ( -self._log_constant() + self.mean * (log_1_minus_p - log_p) @@ -365,7 +360,7 @@ def cdf(self, value): value (Tensor): The input tensor. Returns: - Tensor: quantile of :attr:`value`. The data type is the same as `self.probability`. + Tensor: quantile of :attr:`value`. The data type is the same as `self.probs`. """ value = paddle.cast(value, dtype=self.dtype) if not self._check_constraint(value): @@ -409,7 +404,7 @@ def icdf(self, value): value (Tensor): The input tensor, meaning the quantile. Returns: - Tensor: the value of the r.v. corresponding to the quantile. The data type is the same as `self.probability`. + Tensor: the value of the r.v. corresponding to the quantile. The data type is the same as `self.probs`. """ value = paddle.cast(value, dtype=self.dtype) if not self._check_constraint(value): @@ -449,8 +444,8 @@ def kl_divergence(self, other): "KL divergence of two Continuous Bernoulli distributions should share the same `batch_shape`." ) part1 = -self.entropy() - log_q = paddle.log(other.probability) - log_1_minus_q = paddle.log1p(-other.probability) + log_q = paddle.log(other.probs) + log_1_minus_q = paddle.log1p(-other.probs) part2 = -( other._log_constant() + self.mean * (log_q - log_1_minus_q) diff --git a/test/distribution/test_distribution_continuous_bernoulli.py b/test/distribution/test_distribution_continuous_bernoulli.py index e7927e81e806e..8b5c5c9c7ee52 100644 --- a/test/distribution/test_distribution_continuous_bernoulli.py +++ b/test/distribution/test_distribution_continuous_bernoulli.py @@ -28,25 +28,23 @@ class ContinuousBernoulli_np: - def __init__(self, probability, eps=0.02): - self.eps = eps - self.dtype = 'float32' + def __init__(self, probs, lims=(0.48, 0.52)): + self.lims = lims + self.dtype = probs.dtype eps_prob = 1.1920928955078125e-07 - self.probability = np.clip( - probability, a_min=eps_prob, a_max=1.0 - eps_prob - ) + self.probs = np.clip(probs, a_min=eps_prob, a_max=1.0 - eps_prob) def _cut_support_region(self): return np.logical_or( - np.less_equal(self.probability, 0.5 - self.eps), - np.greater_equal(self.probability, 0.5 + self.eps), + np.less_equal(self.probs, self.lims[0]), + np.greater_equal(self.probs, self.lims[1]), ) def _cut_probs(self): return np.where( self._cut_support_region(), - self.probability, - (0.5 - self.eps) * np.ones_like(self.probability), + self.probs, + self.lims[0] * np.ones_like(self.probs), ) def _tanh_inverse(self, value): @@ -67,7 +65,7 @@ def _log_constant(self): np.log1p(-2.0 * cut_probs_below_half), np.log(2.0 * cut_probs_above_half - 1.0), ) - x = np.square(self.probability - 0.5) + x = np.square(self.probs - 0.5) taylor_expansion = np.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x return np.where( self._cut_support_region(), log_constant_propose, taylor_expansion @@ -81,7 +79,7 @@ def np_variance(self): propose = tmp + np.divide( 1.0, np.square(2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)) ) - x = np.square(self.probability - 0.5) + x = np.square(self.probs - 0.5) taylor_expansion = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x return np.where(self._cut_support_region(), propose, taylor_expansion) @@ -89,16 +87,16 @@ def np_mean(self): cut_probs = self._cut_probs() tmp = cut_probs / (2.0 * cut_probs - 1.0) propose = tmp + 1.0 / (2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)) - x = self.probability - 0.5 + x = self.probs - 0.5 taylor_expansion = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * np.square(x)) * x return np.where(self._cut_support_region(), propose, taylor_expansion) def np_entropy(self): - log_p = np.log(self.probability) - log_1_minus_p = np.log1p(-self.probability) + log_p = np.log(self.probs) + log_1_minus_p = np.log1p(-self.probs) return np.where( - np.equal(self.probability, 0.5), - np.full_like(self.probability, 0.0), + np.equal(self.probs, 0.5), + np.full_like(self.probs, 0.0), ( -self._log_constant() + self.np_mean() * (log_1_minus_p - log_p) @@ -112,8 +110,7 @@ def np_prob(self, value): def np_log_prob(self, value): eps = 1e-8 cross_entropy = np.nan_to_num( - value * np.log(self.probability) - + (1.0 - value) * np.log(1 - self.probability), + value * np.log(self.probs) + (1.0 - value) * np.log(1 - self.probs), neginf=-eps, ) return self._log_constant() + cross_entropy @@ -150,8 +147,8 @@ def np_icdf(self, value): def np_kl_divergence(self, other): part1 = -self.np_entropy() - log_q = np.log(other.probability) - log_1_minus_q = np.log1p(-other.probability) + log_q = np.log(other.probs) + log_1_minus_q = np.log1p(-other.probs) part2 = -( other._log_constant() + self.np_mean() * (log_q - log_1_minus_q) @@ -162,60 +159,60 @@ def np_kl_divergence(self, other): @parameterize.place(config.DEVICES) @parameterize.parameterize_cls( - (parameterize.TEST_CASE_NAME, 'probability'), + (parameterize.TEST_CASE_NAME, 'probs'), [ ('half', np.array(0.5).astype("float32")), ( 'one-dim', - parameterize.xrand((1,), min=0.0, max=0.498).astype("float64"), + parameterize.xrand((1,), min=0.0, max=1.0).astype("float64"), ), ( 'multi-dim', - parameterize.xrand((2, 3), min=0.498, max=1.0).astype("float32"), + parameterize.xrand((2, 3), min=0.0, max=1.0).astype("float32"), ), ], ) class TestContinuousBernoulli(unittest.TestCase): def setUp(self): self._dist = ContinuousBernoulli( - probability=paddle.to_tensor(self.probability), eps=0.02 + probs=paddle.to_tensor(self.probs), lims=(0.48, 0.52) ) - self._np_dist = ContinuousBernoulli_np(self.probability, eps=0.02) + self._np_dist = ContinuousBernoulli_np(self.probs, lims=(0.48, 0.52)) def test_mean(self): mean = self._dist.mean - self.assertEqual(mean.numpy().dtype, self.probability.dtype) + self.assertEqual(mean.numpy().dtype, self.probs.dtype) np.testing.assert_allclose( mean, self._np_dist.np_mean(), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), ) def test_variance(self): var = self._dist.variance - self.assertEqual(var.numpy().dtype, self.probability.dtype) + self.assertEqual(var.numpy().dtype, self.probs.dtype) np.testing.assert_allclose( var, self._np_dist.np_variance(), - rtol=0.005, + rtol=0.01, atol=0.0, ) def test_entropy(self): entropy = self._dist.entropy() - self.assertEqual(entropy.numpy().dtype, self.probability.dtype) + self.assertEqual(entropy.numpy().dtype, self.probs.dtype) np.testing.assert_allclose( entropy, self._np_dist.np_entropy(), - rtol=0.005, + rtol=0.01, atol=0.0, ) def test_sample(self): sample_shape = () samples = self._dist.sample(sample_shape) - self.assertEqual(samples.numpy().dtype, self.probability.dtype) + self.assertEqual(samples.numpy().dtype, self.probs.dtype) self.assertEqual( tuple(samples.shape), sample_shape + self._dist.batch_shape + self._dist.event_shape, @@ -229,20 +226,20 @@ def test_sample(self): np.testing.assert_allclose( sample_mean, self._dist.mean, - rtol=0.02, + rtol=0.1, atol=0.0, ) np.testing.assert_allclose( sample_variance, self._dist.variance, - rtol=0.02, + rtol=0.1, atol=0.0, ) @parameterize.place(config.DEVICES) @parameterize.parameterize_cls( - (parameterize.TEST_CASE_NAME, 'probability', 'value'), + (parameterize.TEST_CASE_NAME, 'probs', 'value'), [ ( 'value-same-shape', @@ -259,40 +256,40 @@ def test_sample(self): class TestContinuousBernoulliProbs(unittest.TestCase): def setUp(self): self._dist = ContinuousBernoulli( - probability=paddle.to_tensor(self.probability), eps=0.02 + probs=paddle.to_tensor(self.probs), lims=(0.48, 0.52) ) - self._np_dist = ContinuousBernoulli_np(self.probability, eps=0.02) + self._np_dist = ContinuousBernoulli_np(self.probs, lims=(0.48, 0.52)) def test_prob(self): np.testing.assert_allclose( self._dist.prob(paddle.to_tensor(self.value)), self._np_dist.np_prob(self.value), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), ) def test_log_prob(self): np.testing.assert_allclose( self._dist.log_prob(paddle.to_tensor(self.value)), self._np_dist.np_log_prob(self.value), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), ) def test_cdf(self): np.testing.assert_allclose( self._dist.cdf(paddle.to_tensor(self.value)), self._np_dist.np_cdf(self.value), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), ) def test_icdf(self): np.testing.assert_allclose( self._dist.icdf(paddle.to_tensor(self.value)), self._np_dist.np_icdf(self.value), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), ) @@ -316,13 +313,13 @@ class TestContinuousBernoulliKL(unittest.TestCase): def setUp(self): paddle.disable_static() self._dist1 = ContinuousBernoulli( - probability=paddle.to_tensor(self.p_1), eps=0.02 + probs=paddle.to_tensor(self.p_1), lims=(0.48, 0.52) ) self._dist2 = ContinuousBernoulli( - probability=paddle.to_tensor(self.p_2), eps=0.02 + probs=paddle.to_tensor(self.p_2), lims=(0.48, 0.52) ) - self._np_dist1 = ContinuousBernoulli_np(self.p_1, eps=0.02) - self._np_dist2 = ContinuousBernoulli_np(self.p_2, eps=0.02) + self._np_dist1 = ContinuousBernoulli_np(self.p_1, lims=(0.48, 0.52)) + self._np_dist2 = ContinuousBernoulli_np(self.p_2, lims=(0.48, 0.52)) def test_kl_divergence(self): kl0 = self._dist1.kl_divergence(self._dist2) @@ -333,7 +330,7 @@ def test_kl_divergence(self): np.testing.assert_allclose( kl0, kl1, - rtol=0.005, + rtol=0.01, atol=0.0, ) diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index ab0c9ea034d46..87edbc893b89b 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -28,25 +28,23 @@ class ContinuousBernoulli_np: - def __init__(self, probability, eps=0.02): - self.eps = eps - self.dtype = 'float32' + def __init__(self, probs, lims=(0.48, 0.52)): + self.lims = lims + self.dtype = probs.dtype eps_prob = 1.1920928955078125e-07 - self.probability = np.clip( - probability, a_min=eps_prob, a_max=1.0 - eps_prob - ) + self.probs = np.clip(probs, a_min=eps_prob, a_max=1.0 - eps_prob) def _cut_support_region(self): return np.logical_or( - np.less_equal(self.probability, 0.5 - self.eps), - np.greater_equal(self.probability, 0.5 + self.eps), + np.less_equal(self.probs, self.lims[0]), + np.greater_equal(self.probs, self.lims[1]), ) def _cut_probs(self): return np.where( self._cut_support_region(), - self.probability, - (0.5 - self.eps) * np.ones_like(self.probability), + self.probs, + self.lims[0] * np.ones_like(self.probs), ) def _tanh_inverse(self, value): @@ -67,7 +65,7 @@ def _log_constant(self): np.log1p(-2.0 * cut_probs_below_half), np.log(2.0 * cut_probs_above_half - 1.0), ) - x = np.square(self.probability - 0.5) + x = np.square(self.probs - 0.5) taylor_expansion = np.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x return np.where( self._cut_support_region(), log_constant_propose, taylor_expansion @@ -81,7 +79,7 @@ def np_variance(self): propose = tmp + np.divide( 1.0, np.square(2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)) ) - x = np.square(self.probability - 0.5) + x = np.square(self.probs - 0.5) taylor_expansion = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x return np.where(self._cut_support_region(), propose, taylor_expansion) @@ -89,16 +87,16 @@ def np_mean(self): cut_probs = self._cut_probs() tmp = cut_probs / (2.0 * cut_probs - 1.0) propose = tmp + 1.0 / (2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs)) - x = self.probability - 0.5 + x = self.probs - 0.5 taylor_expansion = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * np.square(x)) * x return np.where(self._cut_support_region(), propose, taylor_expansion) def np_entropy(self): - log_p = np.log(self.probability) - log_1_minus_p = np.log1p(-self.probability) + log_p = np.log(self.probs) + log_1_minus_p = np.log1p(-self.probs) return np.where( - np.equal(self.probability, 0.5), - np.full_like(self.probability, 0.0), + np.equal(self.probs, 0.5), + np.full_like(self.probs, 0.0), ( -self._log_constant() + self.np_mean() * (log_1_minus_p - log_p) @@ -112,8 +110,7 @@ def np_prob(self, value): def np_log_prob(self, value): eps = 1e-8 cross_entropy = np.nan_to_num( - value * np.log(self.probability) - + (1.0 - value) * np.log(1 - self.probability), + value * np.log(self.probs) + (1.0 - value) * np.log(1 - self.probs), neginf=-eps, ) return self._log_constant() + cross_entropy @@ -150,8 +147,8 @@ def np_icdf(self, value): def np_kl_divergence(self, other): part1 = -self.np_entropy() - log_q = np.log(other.probability) - log_1_minus_q = np.log1p(-other.probability) + log_q = np.log(other.probs) + log_1_minus_q = np.log1p(-other.probs) part2 = -( other._log_constant() + self.np_mean() * (log_q - log_1_minus_q) @@ -165,31 +162,31 @@ def np_kl_divergence(self, other): @parameterize.place(config.DEVICES) @parameterize.parameterize_cls( - (parameterize.TEST_CASE_NAME, 'probability'), + (parameterize.TEST_CASE_NAME, 'probs'), [ ( 'multi-dim', - parameterize.xrand((1, 3), min=0.0, max=0.498).astype("float32"), + parameterize.xrand((1, 3), min=0.0, max=1.0).astype("float32"), ), ], ) class TestContinuousBernoulli(unittest.TestCase): def setUp(self): - self._np_dist = ContinuousBernoulli_np(self.probability) + self._np_dist = ContinuousBernoulli_np(self.probs) startup_program = paddle.static.Program() main_program = paddle.static.Program() executor = paddle.static.Executor(self.place) with paddle.static.program_guard(main_program, startup_program): - probability = paddle.static.data( - 'probability', self.probability.shape, self.probability.dtype + probs = paddle.static.data( + 'probs', self.probs.shape, self.probs.dtype ) - dist = ContinuousBernoulli(probability) + dist = ContinuousBernoulli(probs, lims=(0.48, 0.52)) mean = dist.mean var = dist.variance entropy = dist.entropy() large_samples = dist.sample(shape=(50000,)) fetch_list = [mean, var, entropy, large_samples] - feed = {'probability': self.probability} + feed = {'probs': self.probs} executor.run(startup_program) [ @@ -200,30 +197,26 @@ def setUp(self): ] = executor.run(main_program, feed=feed, fetch_list=fetch_list) def test_mean(self): - self.assertEqual( - str(self.mean.dtype).split('.')[-1], self.probability.dtype - ) + self.assertEqual(str(self.mean.dtype).split('.')[-1], self.probs.dtype) np.testing.assert_allclose( self.mean, self._np_mean(), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), ) def test_variance(self): - self.assertEqual( - str(self.var.dtype).split('.')[-1], self.probability.dtype - ) + self.assertEqual(str(self.var.dtype).split('.')[-1], self.probs.dtype) np.testing.assert_allclose( self.var, self._np_variance(), - rtol=0.005, - atol=0.0, + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), ) def test_entropy(self): self.assertEqual( - str(self.entropy.dtype).split('.')[-1], self.probability.dtype + str(self.entropy.dtype).split('.')[-1], self.probs.dtype ) np.testing.assert_allclose( self.entropy, self._np_entropy(), rtol=0.005, atol=0 @@ -247,7 +240,7 @@ def _np_entropy(self): @parameterize.place(config.DEVICES) @parameterize.parameterize_cls( - (parameterize.TEST_CASE_NAME, 'probability', 'value'), + (parameterize.TEST_CASE_NAME, 'probs', 'value'), [ ( 'value-broadcast-shape', @@ -258,21 +251,21 @@ def _np_entropy(self): ) class TestContinuousBernoulliProbs(unittest.TestCase): def setUp(self): - self._np_dist = ContinuousBernoulli_np(self.probability) + self._np_dist = ContinuousBernoulli_np(self.probs) startup_program = paddle.static.Program() main_program = paddle.static.Program() executor = paddle.static.Executor(self.place) with paddle.static.program_guard(main_program, startup_program): - probability = paddle.static.data( - 'probability', self.probability.shape, self.probability.dtype + probs = paddle.static.data( + 'probs', self.probs.shape, self.probs.dtype ) value = paddle.static.data( 'value', self.value.shape, self.value.dtype ) - dist = ContinuousBernoulli(probability) + dist = ContinuousBernoulli(probs, lims=(0.48, 0.52)) pmf = dist.prob(value) - feed = {'probability': self.probability, 'value': self.value} + feed = {'probs': self.probs, 'value': self.value} fetch_list = [pmf] executor.run(startup_program) @@ -284,8 +277,8 @@ def test_prob(self): np.testing.assert_allclose( self.pmf, self._np_dist.np_prob(self.value), - rtol=config.RTOL.get(str(self.probability.dtype)), - atol=config.ATOL.get(str(self.probability.dtype)), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), ) @@ -312,8 +305,8 @@ def setUp(self): with paddle.static.program_guard(main_program, startup_program): p_1 = paddle.static.data('p_1', self.p_1.shape) p_2 = paddle.static.data('p_2', self.p_2.shape) - dist1 = ContinuousBernoulli(p_1) - dist2 = ContinuousBernoulli(p_2) + dist1 = ContinuousBernoulli(p_1, lims=(0.48, 0.52)) + dist2 = ContinuousBernoulli(p_2, lims=(0.48, 0.52)) kl_dist1_dist2 = dist1.kl_divergence(dist2) feed = {'p_1': self.p_1, 'p_2': self.p_2} fetch_list = [kl_dist1_dist2] @@ -332,7 +325,7 @@ def test_kl_divergence(self): np.testing.assert_allclose( kl0, kl1, - rtol=0.005, + rtol=0.01, atol=0.0, ) From 08df9b88f720cc92174c62d9f024ab091db6fa1f Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 14 Dec 2023 14:38:37 +0800 Subject: [PATCH 26/29] increase cb static test timeout --- test/distribution/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/distribution/CMakeLists.txt b/test/distribution/CMakeLists.txt index 95739040ef4af..d428e97c7b0ea 100644 --- a/test/distribution/CMakeLists.txt +++ b/test/distribution/CMakeLists.txt @@ -7,3 +7,6 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach() + +set_tests_properties(test_distribution_continuous_bernoulli_static + PROPERTIES TIMEOUT 30) From 5e821c748367cb9cd756c1e667bf2968ccc10854 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 14 Dec 2023 14:52:09 +0800 Subject: [PATCH 27/29] fix test time --- test/distribution/CMakeLists.txt | 3 -- ...istribution_continuous_bernoulli_static.py | 33 ------------------- 2 files changed, 36 deletions(-) diff --git a/test/distribution/CMakeLists.txt b/test/distribution/CMakeLists.txt index d428e97c7b0ea..95739040ef4af 100644 --- a/test/distribution/CMakeLists.txt +++ b/test/distribution/CMakeLists.txt @@ -7,6 +7,3 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach() - -set_tests_properties(test_distribution_continuous_bernoulli_static - PROPERTIES TIMEOUT 30) diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index 87edbc893b89b..b4ab81f5da6eb 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -17,11 +17,6 @@ import numpy as np import parameterize from distribution import config -from parameterize import ( - TEST_CASE_NAME, - parameterize_cls, - parameterize_func, -) import paddle from paddle.distribution.continuous_bernoulli import ContinuousBernoulli @@ -330,33 +325,5 @@ def test_kl_divergence(self): ) -@parameterize.place(config.DEVICES) -@parameterize_cls([TEST_CASE_NAME], ['ContinuousBernoulliTestError']) -class ContinuousBernoulliTestError(unittest.TestCase): - def setUp(self): - self.program = paddle.static.Program() - self.executor = paddle.static.Executor(self.place) - - @parameterize_func( - [ - (100,), # int - (100.0,), # float - ] - ) - def test_bad_sample_shape_type(self, shape): - with paddle.static.program_guard(self.program): - rv = ContinuousBernoulli(0.3) - - with self.assertRaises(TypeError): - [_] = self.executor.run( - self.program, feed={}, fetch_list=[rv.sample(shape)] - ) - - with self.assertRaises(TypeError): - [_] = self.executor.run( - self.program, feed={}, fetch_list=[rv.rsample(shape)] - ) - - if __name__ == '__main__': unittest.main(argv=[''], verbosity=3, exit=False) From 750016078b257ebfb08a58f88fdb44e3df7bee33 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 14 Dec 2023 17:58:23 +0800 Subject: [PATCH 28/29] fix test --- ...test_distribution_continuous_bernoulli_static.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/distribution/test_distribution_continuous_bernoulli_static.py b/test/distribution/test_distribution_continuous_bernoulli_static.py index b4ab81f5da6eb..f768e38e8ed14 100644 --- a/test/distribution/test_distribution_continuous_bernoulli_static.py +++ b/test/distribution/test_distribution_continuous_bernoulli_static.py @@ -205,8 +205,8 @@ def test_variance(self): np.testing.assert_allclose( self.var, self._np_variance(), - rtol=config.RTOL.get(str(self.probs.dtype)), - atol=config.ATOL.get(str(self.probs.dtype)), + rtol=0.01, + atol=0.0, ) def test_entropy(self): @@ -214,14 +214,17 @@ def test_entropy(self): str(self.entropy.dtype).split('.')[-1], self.probs.dtype ) np.testing.assert_allclose( - self.entropy, self._np_entropy(), rtol=0.005, atol=0 + self.entropy, + self._np_entropy(), + rtol=0.01, + atol=0.0, ) def test_sample(self): sample_mean = self.large_samples.mean(axis=0) sample_variance = self.large_samples.var(axis=0) - np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.02) - np.testing.assert_allclose(sample_variance, self.var, atol=0, rtol=0.02) + np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.1) + np.testing.assert_allclose(sample_variance, self.var, atol=0, rtol=0.1) def _np_variance(self): return self._np_dist.np_variance() From 5e1816e8825d6ca8f0c237b09e1ecfd74a8a2f57 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 15 Dec 2023 11:55:33 +0800 Subject: [PATCH 29/29] update cb --- python/paddle/distribution/continuous_bernoulli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distribution/continuous_bernoulli.py b/python/paddle/distribution/continuous_bernoulli.py index cb44b340c4c9b..1df7653f0103a 100644 --- a/python/paddle/distribution/continuous_bernoulli.py +++ b/python/paddle/distribution/continuous_bernoulli.py @@ -97,7 +97,7 @@ class ContinuousBernoulli(distribution.Distribution): [0.20103608, 0.07641447]) """ - def __init__(self, probs=None, lims=(0.499, 0.501)): + def __init__(self, probs, lims=(0.499, 0.501)): self.dtype = paddle.get_default_dtype() self.probs = self._to_tensor(probs) self.lims = paddle.to_tensor(lims, dtype=self.dtype)