From 8d992190daad4f0ad57a998b29e4edc9524849d1 Mon Sep 17 00:00:00 2001
From: dasen
Date: Thu, 10 Nov 2022 10:01:34 +0800
Subject: [PATCH 01/18] dropout2d; test=document_fix

---
 python/paddle/distribution/kl.py              |   4 +
 .../distribution/multivariate_normal.py       | 113 ++++++++++++++++++
 2 files changed, 117 insertions(+)
 create mode 100644 python/paddle/distribution/multivariate_normal.py

diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py
index 6dae2f64fb733..28dc98469e972 100644
--- a/python/paddle/distribution/kl.py
+++ b/python/paddle/distribution/kl.py
@@ -23,6 +23,7 @@
 from paddle.distribution.normal import Normal
 from paddle.distribution.uniform import Uniform
 from paddle.distribution.laplace import Laplace
+from paddle.distribution.multivariate_normal import MultivariateNormal
 from paddle.fluid.framework import _non_static_mode, in_dygraph_mode

 __all__ = ["register_kl", "kl_divergence"]
@@ -163,6 +164,9 @@ def _kl_categorical_categorical(p, q):
 def _kl_normal_normal(p, q):
     return p.kl_divergence(q)

+@register_kl(MultivariateNormal,MultivariateNormal)
+def _kl_multnormal_multnormal(p, q):
+    return p.kl_divergence(q)

 @register_kl(Uniform, Uniform)
 def _kl_uniform_uniform(p, q):
diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py
new file mode 100644
index 0000000000000..188453d6fe001
--- /dev/null
+++ b/python/paddle/distribution/multivariate_normal.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from paddle.distribution import distribution
+
+class MultivariateNormal(distribution.Distribution):
+    r"""
+    (MultivariateNormal Introduce)
+
+    Args:
+
+    Examples:
+
+    """
+
+    def __init__(self):
+        pass
+
+    @property
+    def mean(self):
+        """mean of multivariate_normal distribuion.
+
+        Returns:
+            Tensor: mean value.
+        """
+        pass
+
+    @property
+    def variance(self):
+        """variance of multivariate_normal distribution.
+
+        Returns:
+            Tensor: variance value.
+        """
+        pass
+
+    @property
+    def stddev(self):
+        """standard deviation of multivariate_normal distribution.
+
+        Returns:
+            Tensor: variance value.
+        """
+        pass
+
+    def prob(self, value):
+        """probability mass function evaluated at value.
+
+        Args:
+            value (Tensor): value to be evaluated.
+
+        Returns:
+            Tensor: probability of value.
+        """
+        pass
+
+    def log_prob(self, value):
+        """probability mass function evaluated of logarithm at value
+
+        Args:
+            value (Tensor): value to be evaluated.
+
+        Returns:
+            Tensor: probability of value.
+        """
+        pass
+
+    def entropy(self):
+        """entropy of multivariate_normal distribution
+
+        Returns:
+            Tensor: entropy value
+        """
+        pass
+
+    def sample(self, shape=()):
+        """draw sample data from multivariate_normal distribution
+
+        Args:
+            shape (tuple, optional): [description]. Defaults to ().
+        """
+        pass
+
+    def rsample(self, shape=()):
+        """draw sample data from multivariate_normal distribution
+
+        Args:
+            shape (tuple, optional): [description].
Defaults to (). + """ + pass + + def kl_divergence(self, other): + """calculate the KL divergence KL(self || other) with two MultivariateNormal instances. + + Args: + other (MultivariateNormal): An instance of MultivariateNormal. + + Returns: + Tensor: The kl-divergence between two multivariate_normal distributions. + """ + pass From 23e7ebdba085f5ef428a241ba5e40c16788de559 Mon Sep 17 00:00:00 2001 From: dasen Date: Thu, 10 Nov 2022 10:19:02 +0800 Subject: [PATCH 02/18] init multivariate_normal api --- python/paddle/distribution/multivariate_normal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 188453d6fe001..5b37903b32cf0 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -15,6 +15,7 @@ from paddle.distribution import distribution + class MultivariateNormal(distribution.Distribution): r""" (MultivariateNormal Introduce) From be5d83f7a7b586e0ece6d9c5b579cb7d36394e30 Mon Sep 17 00:00:00 2001 From: dasen Date: Mon, 14 Nov 2022 15:46:18 +0800 Subject: [PATCH 03/18] resolve conflicts --- python/paddle/distribution/kl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 28dc98469e972..1b4ab1f2e5a84 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -24,7 +24,7 @@ from paddle.distribution.uniform import Uniform from paddle.distribution.laplace import Laplace from paddle.distribution.multivariate_normal import MultivariateNormal -from paddle.fluid.framework import _non_static_mode, in_dygraph_mode +from paddle.fluid.framework import _non_static_mode __all__ = ["register_kl", "kl_divergence"] From 5a90ff00bf074006c6f44cdd209e9a3f153a99c2 Mon Sep 17 00:00:00 2001 From: dasen Date: Mon, 14 Nov 2022 17:07:22 +0800 Subject: [PATCH 04/18] rollback kl.py --- python/paddle/distribution/kl.py | 101 +++++++++++++++++-------------- 1 file changed, 56 insertions(+), 45 deletions(-) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 22a5c54528d8f..6146d6a51afc2 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -35,59 +35,45 @@ def kl_divergence(p, q): r""" Kullback-Leibler divergence between distribution p and q. - .. math:: - KL(p||q) = \int p(x)log\frac{p(x)}{q(x)} \mathrm{d}x - Args: p (Distribution): ``Distribution`` object. Inherits from the Distribution Base class. q (Distribution): ``Distribution`` object. Inherits from the Distribution Base class. - Returns: Tensor, Batchwise KL-divergence between distribution p and q. - Examples: - .. code-block:: python - import paddle - p = paddle.distribution.Beta(alpha=0.5, beta=0.5) q = paddle.distribution.Beta(alpha=0.3, beta=0.7) - print(paddle.distribution.kl_divergence(p, q)) # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True, # [0.21193528]) - """ return _dispatch(type(p), type(q))(p, q) def register_kl(cls_p, cls_q): """Decorator for register a KL divergence implemention function. - The ``kl_divergence(p, q)`` function will search concrete implemention functions registered by ``register_kl``, according to multi-dispatch pattern. If an implemention function is found, it will return the result, otherwise, it will raise ``NotImplementError`` exception. Users can register implemention funciton by the decorator. 
- Args: cls_p (Distribution): The Distribution type of Instance p. Subclass derived from ``Distribution``. cls_q (Distribution): The Distribution type of Instance q. Subclass derived from ``Distribution``. - Examples: .. code-block:: python - import paddle - @paddle.distribution.register_kl(paddle.distribution.Beta, paddle.distribution.Beta) def kl_beta_beta(): pass # insert implementation here """ - if (not issubclass(cls_p, Distribution) - or not issubclass(cls_q, Distribution)): + if not issubclass(cls_p, Distribution) or not issubclass( + cls_q, Distribution + ): raise TypeError('cls_p and cls_q must be subclass of Distribution') def decorator(f): @@ -98,11 +84,14 @@ def decorator(f): def _dispatch(cls_p, cls_q): - """Multiple dispatch into concrete implement function""" + """Multiple dispatch into concrete implement function.""" # find all matched super class pair of p and q - matchs = [(super_p, super_q) for super_p, super_q in _REGISTER_TABLE - if issubclass(cls_p, super_p) and issubclass(cls_q, super_q)] + matchs = [ + (super_p, super_q) + for super_p, super_q in _REGISTER_TABLE + if issubclass(cls_p, super_p) and issubclass(cls_q, super_q) + ] if not matchs: raise NotImplementedError @@ -111,16 +100,20 @@ def _dispatch(cls_p, cls_q): if _REGISTER_TABLE[left_p, left_q] is not _REGISTER_TABLE[right_p, right_q]: warnings.warn( - 'Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'. - format(cls_p.__name__, cls_q.__name__, left_p.__name__, - right_q.__name__), RuntimeWarning) + 'Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'.format( + cls_p.__name__, + cls_q.__name__, + left_p.__name__, + right_q.__name__, + ), + RuntimeWarning, + ) return _REGISTER_TABLE[left_p, left_q] @functools.total_ordering -class _Compare(object): - +class _Compare: def __init__(self, *classes): self.classes = classes @@ -138,22 +131,33 @@ def __le__(self, other): @register_kl(Beta, Beta) def _kl_beta_beta(p, q): - return ((q.alpha.lgamma() + q.beta.lgamma() + (p.alpha + p.beta).lgamma()) - - (p.alpha.lgamma() + p.beta.lgamma() + (q.alpha + q.beta).lgamma()) + - ((p.alpha - q.alpha) * p.alpha.digamma()) + - ((p.beta - q.beta) * p.beta.digamma()) + - (((q.alpha + q.beta) - (p.alpha + p.beta)) * - (p.alpha + p.beta).digamma())) + return ( + (q.alpha.lgamma() + q.beta.lgamma() + (p.alpha + p.beta).lgamma()) + - (p.alpha.lgamma() + p.beta.lgamma() + (q.alpha + q.beta).lgamma()) + + ((p.alpha - q.alpha) * p.alpha.digamma()) + + ((p.beta - q.beta) * p.beta.digamma()) + + ( + ((q.alpha + q.beta) - (p.alpha + p.beta)) + * (p.alpha + p.beta).digamma() + ) + ) @register_kl(Dirichlet, Dirichlet) def _kl_dirichlet_dirichlet(p, q): return ( - (p.concentration.sum(-1).lgamma() - q.concentration.sum(-1).lgamma()) - - ((p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)) + - (((p.concentration - q.concentration) * - (p.concentration.digamma() - - p.concentration.sum(-1).digamma().unsqueeze(-1))).sum(-1))) + (p.concentration.sum(-1).lgamma() - q.concentration.sum(-1).lgamma()) + - ((p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)) + + ( + ( + (p.concentration - q.concentration) + * ( + p.concentration.digamma() + - p.concentration.sum(-1).digamma().unsqueeze(-1) + ) + ).sum(-1) + ) + ) @register_kl(Categorical, Categorical) @@ -178,8 +182,7 @@ def _kl_laplace_laplace(p, q): @register_kl(ExponentialFamily, ExponentialFamily) def _kl_expfamily_expfamily(p, q): - """Compute kl-divergence using `Bregman divergences `_ - """ + """Compute kl-divergence using `Bregman divergences `_""" if 
not type(p) == type(q): raise NotImplementedError @@ -195,24 +198,32 @@ def _kl_expfamily_expfamily(p, q): try: if _non_static_mode(): - p_grads = paddle.grad(p_log_norm, - p_natural_params, - create_graph=True) + p_grads = paddle.grad( + p_log_norm, p_natural_params, create_graph=True + ) else: p_grads = paddle.static.gradients(p_log_norm, p_natural_params) except RuntimeError as e: raise TypeError( - "Cann't compute kl_divergence({cls_p}, {cls_q}) use bregman divergence. Please register_kl({cls_p}, {cls_q})." - .format(cls_p=type(p).__name__, cls_q=type(q).__name__)) from e + "Cann't compute kl_divergence({cls_p}, {cls_q}) use bregman divergence. Please register_kl({cls_p}, {cls_q}).".format( + cls_p=type(p).__name__, cls_q=type(q).__name__ + ) + ) from e kl = q._log_normalizer(*q_natural_params) - p_log_norm - for p_param, q_param, p_grad in zip(p_natural_params, q_natural_params, - p_grads): + for p_param, q_param, p_grad in zip( + p_natural_params, q_natural_params, p_grads + ): term = (q_param - p_param) * p_grad kl -= _sum_rightmost(term, len(q.event_shape)) return kl +@register_kl(LogNormal, LogNormal) +def _kl_lognormal_lognormal(p, q): + return p._base.kl_divergence(q._base) + + def _sum_rightmost(value, n): return value.sum(list(range(-n, 0))) if n > 0 else value From 29b05784a39fd202a85ce0feabff02593189da70 Mon Sep 17 00:00:00 2001 From: dasen Date: Fri, 25 Nov 2022 10:45:10 +0800 Subject: [PATCH 05/18] update: init / mean / variance / stddev --- .../distribution/multivariate_normal.py | 74 ++++++++++++++++--- 1 file changed, 62 insertions(+), 12 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 5b37903b32cf0..a68b0fb69efb3 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import paddle +import math from paddle.distribution import distribution @@ -26,8 +27,34 @@ class MultivariateNormal(distribution.Distribution): """ - def __init__(self): - pass + def __init__(self, loc, covariance_matrix=None): + if loc.dim() < 1: + raise ValueError("loc must be at least one-dimensional.") + if (covariance_matrix is not None) != 1: + raise ValueError("Exactly covariance_matrix may be specified.") + + if covariance_matrix is not None: + if covariance_matrix.dim() < 2: + raise ValueError("covariance_matrix must be at least two-dimensional, " + "with optional leading batch dimensions") + if(covariance_matrix.shape[:-2] == [] or loc.shape[:-1] == []): + batch_shape = [] + else: + batch_shape = paddle.broadcast_shape(covariance_matrix.shape[:-2], loc.shape[:-1]) + self.covariance_matrix = covariance_matrix.expand(batch_shape + [-1, -1]) + self.loc = loc.expand(batch_shape + [-1]) + + event_shape = self.loc.shape[-1:] + super(MultivariateNormal, self).__init__(batch_shape, event_shape) + + if covariance_matrix is not None: + self._unbroadcasted_scale_tril = paddle.linalg.cholesky(covariance_matrix) + + def covariance_matrix(self): + res1 = paddle.matmul(self._unbroadcasted_scale_tril, + self._unbroadcasted_scale_tril.T) + + return res1.expand(res1, self._batch_shape + self._event_shape + self._event_shape) @property def mean(self): @@ -36,7 +63,7 @@ def mean(self): Returns: Tensor: mean value. """ - pass + return self.loc @property def variance(self): @@ -45,7 +72,8 @@ def variance(self): Returns: Tensor: variance value. 
""" - pass + matrix_decompos = paddle.linalg.cholesky(self.covariance_matrix).pow(2).sum(-1) + return paddle.broadcast_to(matrix_decompos, self._batch_shape + self._event_shape) @property def stddev(self): @@ -54,7 +82,7 @@ def stddev(self): Returns: Tensor: variance value. """ - pass + return paddle.sqrt(self.variance) def prob(self, value): """probability mass function evaluated at value. @@ -65,7 +93,12 @@ def prob(self, value): Returns: Tensor: probability of value. """ - pass + x = paddle.pow(2 * math.pi, -value.shape.pop(1) * 0.5) * paddle.pow(paddle.linalg.det(self.covariance_matrix), + -0.5) + y = paddle.exp( + -0.5 * paddle.t(value - self.loc) * paddle.inverse(self.covariance_matrix) * (value - self.loc)) + + return x * y def log_prob(self, value): """probability mass function evaluated of logarithm at value @@ -76,7 +109,7 @@ def log_prob(self, value): Returns: Tensor: probability of value. """ - pass + return paddle.log(self.prob(value)) def entropy(self): """entropy of multivariate_normal distribution @@ -84,7 +117,8 @@ def entropy(self): Returns: Tensor: entropy value """ - pass + sigma = paddle.linalg.det(self.covariance_matrix) + return 0.5 * paddle.log(paddle.pow(2 * math.pi * math.e, self.loc.dim()) * sigma) def sample(self, shape=()): """draw sample data from multivariate_normal distribution @@ -92,7 +126,8 @@ def sample(self, shape=()): Args: shape (tuple, optional): [description]. Defaults to (). """ - pass + with paddle.no_grad: + self.rsample(shape) def rsample(self, shape=()): """draw sample data from multivariate_normal distribution @@ -100,7 +135,12 @@ def rsample(self, shape=()): Args: shape (tuple, optional): [description]. Defaults to (). """ - pass + shape = self._extend_shape(shape) + eps = paddle.standard_normal(shape, dtype=None, name=None) + unbroadcasted_scale_tril = paddle.linalg.cholesky(self.covariance_matrix) + + return self.loc + self._batch_mv(unbroadcasted_scale_tril, eps) + def kl_divergence(self, other): """calculate the KL divergence KL(self || other) with two MultivariateNormal instances. @@ -111,4 +151,14 @@ def kl_divergence(self, other): Returns: Tensor: The kl-divergence between two multivariate_normal distributions. 
""" - pass + sector_1 = paddle.t(self.loc - other.loc) * paddle.inverse(other.covariance_matrix) * (self.loc - other.loc) + sector_2 = paddle.log(paddle.linalg.det(paddle.inverse(other.covariance_matrix) * self.covariance_matrix)) + sector_3 = paddle.trace(paddle.inverse(other.covariance_matrix) * self.covariance_matrix) + n = self.loc.shape.pop(1) + return 0.5 * (sector_1 - sector_2 + sector_3 - n) + + + def _batch_mv(self,bmat, bvec): + bvec_unsqueeze = paddle.unsqueeze(bvec, 1) + bvec = paddle.squeeze(bvec_unsqueeze) + return paddle.matmul(bmat, bvec) From 9ffdadbdafdf64fe79b5b70033ad11eb158ca89e Mon Sep 17 00:00:00 2001 From: dasen Date: Sun, 4 Dec 2022 16:41:24 +0800 Subject: [PATCH 06/18] fix: deal conflict --- python/paddle/distribution/kl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 6146d6a51afc2..51e95f00fd33e 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -24,7 +24,6 @@ from paddle.distribution.lognormal import LogNormal from paddle.distribution.uniform import Uniform from paddle.distribution.laplace import Laplace -from paddle.distribution.multivariate_normal import MultivariateNormal from paddle.fluid.framework import _non_static_mode __all__ = ["register_kl", "kl_divergence"] From 653793d88eef5e2f11a4fccd68154efc9598c782 Mon Sep 17 00:00:00 2001 From: dasen Date: Sun, 18 Dec 2022 15:12:59 +0800 Subject: [PATCH 07/18] update prob / log_prob --- .../distribution/multivariate_normal.py | 76 +++++++++++++++---- 1 file changed, 60 insertions(+), 16 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index a68b0fb69efb3..3f8cb465788da 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -73,7 +73,7 @@ def variance(self): Tensor: variance value. """ matrix_decompos = paddle.linalg.cholesky(self.covariance_matrix).pow(2).sum(-1) - return paddle.broadcast_to(matrix_decompos, self._batch_shape + self._event_shape) + return paddle.expand(matrix_decompos, self._batch_shape + self._event_shape) @property def stddev(self): @@ -93,23 +93,17 @@ def prob(self, value): Returns: Tensor: probability of value. """ - x = paddle.pow(2 * math.pi, -value.shape.pop(1) * 0.5) * paddle.pow(paddle.linalg.det(self.covariance_matrix), - -0.5) - y = paddle.exp( - -0.5 * paddle.t(value - self.loc) * paddle.inverse(self.covariance_matrix) * (value - self.loc)) - - return x * y + return paddle.exp(self.log_prob(value)) def log_prob(self, value): - """probability mass function evaluated of logarithm at value + # if self._validate_args: + # self._validate_sample(value) + diff = value - self.loc + M = self._batch_mahalanobis(self._unbroadcasted_scale_tril, diff) - Args: - value (Tensor): value to be evaluated. + half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril,axis1=-2, axis2=-1).log().sum(-1) - Returns: - Tensor: probability of value. 
- """ - return paddle.log(self.prob(value)) + return -0.5 * (self.event_shape[0] * math.log(2 * math.pi) + M) - half_log_det def entropy(self): """entropy of multivariate_normal distribution @@ -117,8 +111,15 @@ def entropy(self): Returns: Tensor: entropy value """ - sigma = paddle.linalg.det(self.covariance_matrix) - return 0.5 * paddle.log(paddle.pow(2 * math.pi * math.e, self.loc.dim()) * sigma) + # sigma = paddle.linalg.det(self.covariance_matrix) + # return 0.5 * paddle.log(paddle.pow(paddle.to_tensor([2 * math.pi * math.e],dtype=paddle.float32), self.loc.dim()) * sigma) + + half_log_det = self._unbroadcasted_scale_tril.diagonal(axois=-2, dim2=-1).log().sum(-1) + H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det + if len(self._batch_shape) == 0: + return H + else: + return H.expand(self._batch_shape) def sample(self, shape=()): """draw sample data from multivariate_normal distribution @@ -162,3 +163,46 @@ def _batch_mv(self,bmat, bvec): bvec_unsqueeze = paddle.unsqueeze(bvec, 1) bvec = paddle.squeeze(bvec_unsqueeze) return paddle.matmul(bmat, bvec) + + def _batch_mahalanobis(self, bL, bx): + n = bx.shape[-1] + bx_batch_shape = bx.shape[:-1] + bx_batch_dims = len(bx_batch_shape) + bL_batch_dims = bL.ndim - 2 + + outer_batch_dims = bx_batch_dims - bL_batch_dims + old_batch_dims = outer_batch_dims + bL_batch_dims + new_batch_dims = outer_batch_dims + 2 * bL_batch_dims + bx_new_shape = bx.shape[:outer_batch_dims] + + for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): + bx_new_shape += (sx // sL, sL) + + bx_new_shape += (n,) + bx = paddle.reshape(bx, bx_new_shape) + + permute_dims = (list(range(outer_batch_dims)) + + list(range(outer_batch_dims, new_batch_dims, 2)) + + list(range(outer_batch_dims + 1, new_batch_dims, 2)) + + [new_batch_dims]) + + bx = paddle.transpose(bx, perm=permute_dims) + # shape = [b, n, n] + flat_L = paddle.reshape(bL, [1, n, n]) + # shape = [c, b, n] + flat_x = paddle.reshape(bx, [n, flat_L.shape[0], n]) + + # shape = [b, n, c] + flat_x_swap = paddle.transpose(flat_x, perm=[1, 2, 0]) + # shape = [b, c] + M_swap = paddle.linalg.triangular_solve(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) + M = M_swap.t() + # shape = [..., 1, j, i, 1] + permuted_M = paddle.reshape(M, bx.shape[:-1]) + permute_inv_dims = list(range(outer_batch_dims)) + + for i in range(bL_batch_dims): + permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i] + # shape = [..., 1, i, j, 1] + reshaped_M = paddle.transpose(permuted_M, perm=permute_inv_dims) + return paddle.reshape(reshaped_M, bx_batch_shape) From 89168bf0589dd8ac9c0efa65ee4e8c29b15154b6 Mon Sep 17 00:00:00 2001 From: dasen Date: Sun, 18 Dec 2022 17:50:22 +0800 Subject: [PATCH 08/18] update entrop / sample / rsample --- .../distribution/multivariate_normal.py | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 3f8cb465788da..641cbe714738d 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -114,27 +114,27 @@ def entropy(self): # sigma = paddle.linalg.det(self.covariance_matrix) # return 0.5 * paddle.log(paddle.pow(paddle.to_tensor([2 * math.pi * math.e],dtype=paddle.float32), self.loc.dim()) * sigma) - half_log_det = self._unbroadcasted_scale_tril.diagonal(axois=-2, dim2=-1).log().sum(-1) + half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril,axis1=-2, 
axis2=-1).log().sum(-1) H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det if len(self._batch_shape) == 0: return H else: return H.expand(self._batch_shape) - def sample(self, shape=()): + def sample(self, shape=[]): """draw sample data from multivariate_normal distribution Args: - shape (tuple, optional): [description]. Defaults to (). + shape (list, optional): [description]. Defaults to []. """ - with paddle.no_grad: - self.rsample(shape) + with paddle.no_grad(): + return self.rsample(shape) - def rsample(self, shape=()): + def rsample(self, shape=[]): """draw sample data from multivariate_normal distribution Args: - shape (tuple, optional): [description]. Defaults to (). + shape (list, optional): [description]. Defaults to []. """ shape = self._extend_shape(shape) eps = paddle.standard_normal(shape, dtype=None, name=None) @@ -187,15 +187,17 @@ def _batch_mahalanobis(self, bL, bx): [new_batch_dims]) bx = paddle.transpose(bx, perm=permute_dims) - # shape = [b, n, n] + # shape = [a, n, n] flat_L = paddle.reshape(bL, [1, n, n]) - # shape = [c, b, n] + # shape = [b, a, n] flat_x = paddle.reshape(bx, [n, flat_L.shape[0], n]) - # shape = [b, n, c] + # shape = [a, n, b] flat_x_swap = paddle.transpose(flat_x, perm=[1, 2, 0]) - # shape = [b, c] + + # shape = [a, b] M_swap = paddle.linalg.triangular_solve(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) + # shape = [b, a] M = M_swap.t() # shape = [..., 1, j, i, 1] permuted_M = paddle.reshape(M, bx.shape[:-1]) @@ -203,6 +205,18 @@ def _batch_mahalanobis(self, bL, bx): for i in range(bL_batch_dims): permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i] + # shape = [..., 1, i, j, 1] reshaped_M = paddle.transpose(permuted_M, perm=permute_inv_dims) return paddle.reshape(reshaped_M, bx_batch_shape) + + def _extend_shape(self, sample_shape): + """compute shape of the sample + + Args: + sample_shape (Tensor): sample shape + + Returns: + Tensor: generated sample data shape + """ + return sample_shape + list(self.batch_shape) + list(self.event_shape) From 17213008651fd7fdd47a829f86c22831d2704359 Mon Sep 17 00:00:00 2001 From: dasen Date: Wed, 21 Dec 2022 21:43:37 +0800 Subject: [PATCH 09/18] update kl_divergence --- .../distribution/multivariate_normal.py | 85 +++++++++++-------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 641cbe714738d..266d257f2b975 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -111,9 +111,6 @@ def entropy(self): Returns: Tensor: entropy value """ - # sigma = paddle.linalg.det(self.covariance_matrix) - # return 0.5 * paddle.log(paddle.pow(paddle.to_tensor([2 * math.pi * math.e],dtype=paddle.float32), self.loc.dim()) * sigma) - half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril,axis1=-2, axis2=-1).log().sum(-1) H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det if len(self._batch_shape) == 0: @@ -144,27 +141,47 @@ def rsample(self, shape=[]): def kl_divergence(self, other): - """calculate the KL divergence KL(self || other) with two MultivariateNormal instances. - - Args: - other (MultivariateNormal): An instance of MultivariateNormal. + """ - Returns: - Tensor: The kl-divergence between two multivariate_normal distributions. 
""" - sector_1 = paddle.t(self.loc - other.loc) * paddle.inverse(other.covariance_matrix) * (self.loc - other.loc) - sector_2 = paddle.log(paddle.linalg.det(paddle.inverse(other.covariance_matrix) * self.covariance_matrix)) - sector_3 = paddle.trace(paddle.inverse(other.covariance_matrix) * self.covariance_matrix) - n = self.loc.shape.pop(1) - return 0.5 * (sector_1 - sector_2 + sector_3 - n) + if self.event_shape != other.event_shape: + raise ValueError("KL-divergence between two Multivariate Normals with\ + different event shapes cannot be computed") + + half_term1 = (paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1) - + paddle.diagonal(other._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1)) + combined_batch_shape = [] + n = self.event_shape[0] + self_scale_tril = self._unbroadcasted_scale_tril.expand(combined_batch_shape + [n, n]) + other_scale_tril = other._unbroadcasted_scale_tril.expand(combined_batch_shape + [n, n]) + term2 = self._batch_trace_XXT(paddle.linalg.triangular_solve(self_scale_tril, other_scale_tril, upper=False)) + term3 = self._batch_mahalanobis(self._unbroadcasted_scale_tril, (self.loc - other.loc)) + return half_term1 + 0.5 * (term2 + term3 - n) + + def _batch_trace_XXT(self, bmat): + """ + """ + n = bmat.shape[-1] + m = bmat.shape[-2] + flat_trace = paddle.reshape(bmat, [1, m * n]).pow(2).sum(-1) + if( bmat.shape[:-2] == []): + return flat_trace + else: + return paddle.reshape(flat_trace, bmat.shape[:-2]) def _batch_mv(self,bmat, bvec): + """ + + """ bvec_unsqueeze = paddle.unsqueeze(bvec, 1) bvec = paddle.squeeze(bvec_unsqueeze) return paddle.matmul(bmat, bvec) def _batch_mahalanobis(self, bL, bx): + """ + + """ n = bx.shape[-1] bx_batch_shape = bx.shape[:-1] bx_batch_dims = len(bx_batch_shape) @@ -177,8 +194,7 @@ def _batch_mahalanobis(self, bL, bx): for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): bx_new_shape += (sx // sL, sL) - - bx_new_shape += (n,) + bx_new_shape += [n] bx = paddle.reshape(bx, bx_new_shape) permute_dims = (list(range(outer_batch_dims)) + @@ -187,28 +203,29 @@ def _batch_mahalanobis(self, bL, bx): [new_batch_dims]) bx = paddle.transpose(bx, perm=permute_dims) - # shape = [a, n, n] - flat_L = paddle.reshape(bL, [1, n, n]) - # shape = [b, a, n] - flat_x = paddle.reshape(bx, [n, flat_L.shape[0], n]) - - # shape = [a, n, b] + # shape = [b, n, n] + flat_L = paddle.reshape(bL, [-1, n, n]) + # shape = [c, b, n] + flat_x = paddle.reshape(bx, [-1, flat_L.shape[0], n]) + # shape = [b, n, c] flat_x_swap = paddle.transpose(flat_x, perm=[1, 2, 0]) - - # shape = [a, b] + # shape = [b, c] M_swap = paddle.linalg.triangular_solve(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) - # shape = [b, a] + # shape = [c, b] M = M_swap.t() - # shape = [..., 1, j, i, 1] - permuted_M = paddle.reshape(M, bx.shape[:-1]) - permute_inv_dims = list(range(outer_batch_dims)) - for i in range(bL_batch_dims): - permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i] - - # shape = [..., 1, i, j, 1] - reshaped_M = paddle.transpose(permuted_M, perm=permute_inv_dims) - return paddle.reshape(reshaped_M, bx_batch_shape) + if bx.shape[:-1] == []: + return M.sum() + else: + # shape = [..., 1, j, i, 1] + permuted_M = paddle.reshape(M, bx.shape[:-1]) + permute_inv_dims = list(range(outer_batch_dims)) + + for i in range(bL_batch_dims): + permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i] + # shape = [..., 1, i, j, 1] + reshaped_M = paddle.transpose(permuted_M, perm=permute_inv_dims) + return 
paddle.reshape(reshaped_M, bx_batch_shape) def _extend_shape(self, sample_shape): """compute shape of the sample From 535e377067049c3d6ba55fc9eac0166dc1376142 Mon Sep 17 00:00:00 2001 From: dasen Date: Wed, 21 Dec 2022 21:46:19 +0800 Subject: [PATCH 10/18] regist kl --- python/paddle/distribution/kl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 57f50b6a79b68..47779df79f88f 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -24,6 +24,7 @@ from paddle.distribution.lognormal import LogNormal from paddle.distribution.normal import Normal from paddle.distribution.uniform import Uniform +from paddle.distribution.multivariate_normal import MultivariateNormal from paddle.fluid.framework import _non_static_mode __all__ = ["register_kl", "kl_divergence"] @@ -178,6 +179,9 @@ def _kl_uniform_uniform(p, q): def _kl_laplace_laplace(p, q): return p.kl_divergence(q) +@register_kl(MultivariateNormal,MultivariateNormal) +def _kl_multnormal_multnormal(p,q): + return p.kl_divergence(q) @register_kl(ExponentialFamily, ExponentialFamily) def _kl_expfamily_expfamily(p, q): From 1d4a8a121c84cefefc61c878fc81db6bade974b1 Mon Sep 17 00:00:00 2001 From: dasen Date: Sat, 31 Dec 2022 09:02:52 +0800 Subject: [PATCH 11/18] complete en api docs --- .../distribution/multivariate_normal.py | 314 ++++++++++++++---- 1 file changed, 244 insertions(+), 70 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 266d257f2b975..5fa46ca07b0d7 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -19,15 +19,57 @@ class MultivariateNormal(distribution.Distribution): r""" - (MultivariateNormal Introduce) - - Args: - - Examples: - - """ - - def __init__(self, loc, covariance_matrix=None): + MultivariateNormal distribution parameterized by :attr:`loc` and :attr:`covariance_matrix`. + + The probability mass function (PMF) for multivariate_normal is + + .. math:: + + f_\boldsymbol{X}(x_1,...,x_k) = \frac{exp(-\frac{1}{2}$\mathbf{(\boldsymbol{x - \mu})}^\top$\boldsymbol{\Sigma}^{-1}(\boldsymbol{x - \mu}))}{\sqrt{(2\pi)^k\left| \boldsymbol{\Sigma} \right|}} + + In the above equation: + + * :math:`loc = \mu`: is the location parameter. + * :math:`covariance\_matrix = \Sigma`: is the multivariate normal distribution covariance matrix is established when the covariance matrix is a positive semi-definite matrix. + + Args: + loc(tensor): MultivariateNormal distribution location parameter. The data type is Tensor. + covariance\_matrix(tensor): MultivariateNormal distribution covariance matrix parameter. The data type is Tensor, and the parameter must be a positive semi-definite matrix. + + Examples: + .. 
code-block:: python

            import paddle
            from paddle.distribution.multivariate_normal import MultivariateNormal

            # MultivariateNormal distribution with
            # loc=paddle.to_tensor([0, 1], dtype=paddle.float32) and
            # covariance_matrix=paddle.to_tensor([[2, 1], [1, 2]], dtype=paddle.float32)
            dist = MultivariateNormal(paddle.to_tensor([0, 1], dtype=paddle.float32),
                                      paddle.to_tensor([[2, 1], [1, 2]], dtype=paddle.float32))
            dist.sample([2, 2])
            # Tensor(shape=[2, 2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[-1.24544513,  0.24218500],
            #          [-0.26033771,  0.36445701]],
            #
            #         [[ 0.41002670,  1.30887973],
            #          [-0.39297765,  1.32064724]]])

            value = paddle.to_tensor([[2, 1], [1, 2]], dtype=paddle.float32)
            dist.prob(value)
            # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.02422146, 0.06584076])
            dist.log_prob(value)
            # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, [-3.72051620, -2.72051620])
            dist.entropy()
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [3.38718319])
            dist.rsample([2, 2])
            # Tensor(shape=[2, 2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[-2.64208245,  2.58928585],
            #          [-2.26590896,  2.81269646]],
            #
            #         [[ 1.51346231,  1.07011509],
            #          [ 2.11932302,  0.55175352]]])

            dist_kl = MultivariateNormal(paddle.to_tensor([1, 2], dtype=paddle.float32),
                                         paddle.to_tensor([[4, 2], [2, 4]], dtype=paddle.float32))
            dist.kl_divergence(dist_kl)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.64018595])
    """

    def __init__(self, loc, covariance_matrix):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        if (covariance_matrix is not None) != 1:
            raise ValueError("Exactly covariance_matrix may be specified.")

        if covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError(
                    "covariance_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions")
            if covariance_matrix.shape[:-2] == [] or loc.shape[:-1] == []:
                batch_shape = []
            else:
                batch_shape = paddle.broadcast_shape(covariance_matrix.shape[:-2], loc.shape[:-1])
            self.covariance_matrix = covariance_matrix.expand(batch_shape + [-1, -1])

        self.loc = loc.expand(batch_shape + [-1])

        event_shape = self.loc.shape[-1:]
        super(MultivariateNormal, self).__init__(batch_shape, event_shape)

        if covariance_matrix is not None:
            self._unbroadcasted_scale_tril = paddle.linalg.cholesky(covariance_matrix)

    @property
    def mean(self):
        r"""Mean of distribution

        The mean is

        .. math::

            mean = \mu

        In the above equation:

        * :math:`loc = \mu`: is the location parameter.

        Returns:
            Tensor: mean value.

        """
        return self.loc

    @property
    def variance(self):
        r"""Variance of distribution.

        The variance is

        .. math::

            variance = \boldsymbol{\sigma^2}

        In the above equation:

        * :math:`scale = \sigma`: is the scale vector obtained after matrix decomposition of the multivariate normal distribution covariance matrix.

        Returns:
            Tensor: The variance value.
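+
+        Note: with the Cholesky factor :math:`L` of :math:`\Sigma = LL^\top`
+        computed below, the row-wise squared sums of :math:`L` equal
+        :math:`\mathrm{diag}(\Sigma)`, i.e. the marginal variance of each
+        component.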
+ """ matrix_decompos = paddle.linalg.cholesky(self.covariance_matrix).pow(2).sum(-1) return paddle.expand(matrix_decompos, self._batch_shape + self._event_shape) @property def stddev(self): - """standard deviation of multivariate_normal distribution. + r"""Standard deviation of distribution + + The standard deviation is + + .. math:: + + stddev = \boldsymbol{\sigma} + + In the above equation: + * :math:`scale = \sigma`: is scale vector obtained after matrix decomposition of multivariate normal distribution covariance matrix. Returns: - Tensor: variance value. + Tensor: std value """ return paddle.sqrt(self.variance) def prob(self, value): - """probability mass function evaluated at value. + r"""Probability density/mass function + + The probability density is + + .. math:: + + prob(value) = \frac{exp(-\frac{1}{2}$\mathbf{(\boldsymbol{value - \mu})}^\top$\boldsymbol{\Sigma}^{-1}(\boldsymbol{value- \mu}))}{\sqrt{(2\pi)^k\left| \boldsymbol{\Sigma} \right|}} + + In the above equation: + + * :math:`loc = \mu`: is the location parameter. + * :math:`covariance\_matrix = \Sigma`: is the multivariate normal distribution covariance matrix is established when the covariance matrix is a positive semi-definite matrix. Args: - value (Tensor): value to be evaluated. + value (Tensor): The input tensor. Returns: - Tensor: probability of value. + Tensor: probability.The data type is same with value. """ + if not isinstance(value, type(self.loc)): + raise TypeError( + f"Expected type of value is {type(loc)}, but got {type(value)}" + ) + return paddle.exp(self.log_prob(value)) def log_prob(self, value): - # if self._validate_args: - # self._validate_sample(value) + r"""Log probability density/mass function. + + The log probability density is + + .. math:: + + log\_prob(value) = log(\frac{exp(-\frac{1}{2}$\mathbf{(\boldsymbol{value - \mu})}^\top$\boldsymbol{\Sigma}^{-1}(\boldsymbol{value- \mu}))}{\sqrt{(2\pi)^k\left| \boldsymbol{\Sigma} \right|}}) + + In the above equation: + + * :math:`loc = \mu`: is the location parameter. + * :math:`covariance\_matrix = \Sigma`: is the multivariate normal distribution covariance matrix is established when the covariance matrix is a positive semi-definite matrix. + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: log probability.The data type is same with value. + """ + if not isinstance(value, type(self.loc)): + raise TypeError( + f"Expected type of value is {type(loc)}, but got {type(value)}" + ) + diff = value - self.loc M = self._batch_mahalanobis(self._unbroadcasted_scale_tril, diff) - half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril,axis1=-2, axis2=-1).log().sum(-1) + half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1) return -0.5 * (self.event_shape[0] * math.log(2 * math.pi) + M) - half_log_det def entropy(self): - """entropy of multivariate_normal distribution + r"""Entropy of multivariate_normal distribution + + The entropy is + + .. math:: + + entropy() = \frac{k}{2}(\ln 2\pi + 1) + \frac{1}{2}\ln \left| \boldsymbol{\Sigma} \right| + + In the above equation: + + * :math:`k`: The dimension of the multivariate normal distribution vector, such as one-dimensional vector k=1, two-dimensional vector (matrix) k=2. + * :math:`covariance\_matrix = \Sigma`: is the multivariate normal distribution covariance matrix is established when the covariance matrix is a positive semi-definite matrix. 
Returns: Tensor: entropy value """ - half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril,axis1=-2, axis2=-1).log().sum(-1) + half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1) H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det if len(self._batch_shape) == 0: return H @@ -119,94 +235,152 @@ def entropy(self): return H.expand(self._batch_shape) def sample(self, shape=[]): - """draw sample data from multivariate_normal distribution + """Draw sample data from multinomial distribution Args: - shape (list, optional): [description]. Defaults to []. + shape (Sequence[int], optional): Shape of the generated samples. Defaults to []. """ with paddle.no_grad(): return self.rsample(shape) def rsample(self, shape=[]): - """draw sample data from multivariate_normal distribution + """Generate reparameterized samples of the specified shape. Args: - shape (list, optional): [description]. Defaults to []. + shape (Sequence[int], optional): Shape of the generated samples. Defaults to []. + + Returns: + Tensor: A tensor with prepended dimensions shape.The data type is float32. + """ shape = self._extend_shape(shape) eps = paddle.standard_normal(shape, dtype=None, name=None) unbroadcasted_scale_tril = paddle.linalg.cholesky(self.covariance_matrix) - return self.loc + self._batch_mv(unbroadcasted_scale_tril, eps) - + return self.loc + self._batch_product_mv(unbroadcasted_scale_tril, eps) def kl_divergence(self, other): - """ + r"""Calculate the KL divergence KL(self || other) with two MultivariateNormal instances. + + The kl_divergence between two MultivariateNormal distribution is + + .. math:: + KL\_divergence(\boldsymbol{\mu_1}, \boldsymbol{\Sigma_1}; \boldsymbol{\mu_2}, \boldsymbol{\Sigma_2}) = + \frac{1}{2}\Big \{\log ratio -n + tr(\boldsymbol{\Sigma_2}^{-1}\boldsymbol{\Sigma_1}) + + $\mathbf{(diff)}^\top$\boldsymbol{\Sigma_2}^{-1}\boldsymbol{(diff)} \Big \} + + .. math:: + ratio = \frac{\left| \boldsymbol{\Sigma_2} \right|}{\left| \boldsymbol{\Sigma_1} \right|} + + .. math:: + \boldsymbol{diff} = \boldsymbol{\mu_2} - \boldsymbol{\mu_1} + + In the above equation: + + * :math:`loc = \mu_1`: is the location parameter of self. + * :math:`covariance_matrix = \Sigma_1`: is the covariance_matrix parameter of self. + * :math:`loc = \mu_2`: is the location parameter of the reference MultivariateNormal distribution. + * :math:`covariance_matrix = \Sigma_2`: is the covariance_matrix parameter of the reference MultivariateNormal distribution. + * :math:`ratio`: is the ratio of the determinant values of the two covariance matrices. + * :math:`diff`: is the difference between the two distribution. + * :math:`n`: is dimension. + * :math:`tr`: is matrix trace. + + Args: + other (MultivariateNormal): instance of MultivariateNormal. + + Returns: + Tensor: kl-divergence between two multivariate_normal distributions. 
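+
+        Examples (sketch reusing the tensors from the class-level example
+        above; the printed value is copied from that example):
+
+        .. code-block:: python
+
+            import paddle
+            from paddle.distribution.multivariate_normal import MultivariateNormal
+
+            p = MultivariateNormal(paddle.to_tensor([0, 1], dtype=paddle.float32),
+                                   paddle.to_tensor([[2, 1], [1, 2]], dtype=paddle.float32))
+            q = MultivariateNormal(paddle.to_tensor([1, 2], dtype=paddle.float32),
+                                   paddle.to_tensor([[4, 2], [2, 4]], dtype=paddle.float32))
+            p.kl_divergence(q)
+            # Tensor(shape=[1], dtype=float32, stop_gradient=True, [0.64018595])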
""" if self.event_shape != other.event_shape: raise ValueError("KL-divergence between two Multivariate Normals with\ different event shapes cannot be computed") - half_term1 = (paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1) - - paddle.diagonal(other._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1)) - combined_batch_shape = [] + sector1 = (paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1) - + paddle.diagonal(other._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1)) + if list(self.batch_shape) == [] and list(other.batch_shape) == []: + combined_batch_shape = [] + else: + combined_batch_shape = [self.batch_shape, other.batch_shape] n = self.event_shape[0] self_scale_tril = self._unbroadcasted_scale_tril.expand(combined_batch_shape + [n, n]) other_scale_tril = other._unbroadcasted_scale_tril.expand(combined_batch_shape + [n, n]) - term2 = self._batch_trace_XXT(paddle.linalg.triangular_solve(self_scale_tril, other_scale_tril, upper=False)) - term3 = self._batch_mahalanobis(self._unbroadcasted_scale_tril, (self.loc - other.loc)) - return half_term1 + 0.5 * (term2 + term3 - n) + sector2 = self._batch_trace_XXT(paddle.linalg.triangular_solve(self_scale_tril, other_scale_tril, upper=False)) + sector3 = self._batch_mahalanobis(self._unbroadcasted_scale_tril, (self.loc - other.loc)) + return sector1 + 0.5 * (sector2 + sector3 - n) - def _batch_trace_XXT(self, bmat): - """ + def _batch_trace_XXT(self, batch_matrix): + """Calculate the trace of XX^{T} with X having arbitrary trailing batch dimensions. + + Args: + batch_matrix (Tensor): a tensor with arbitrary trailing batch dimensions + Returns: + Tensor: generated the trace of XX^{T} with X """ - n = bmat.shape[-1] - m = bmat.shape[-2] - flat_trace = paddle.reshape(bmat, [1, m * n]).pow(2).sum(-1) - if( bmat.shape[:-2] == []): + n = batch_matrix.shape[-1] + m = batch_matrix.shape[-2] + flat_trace = paddle.reshape(batch_matrix, [1, m * n]).pow(2).sum(-1) + if (batch_matrix.shape[:-2] == []): return flat_trace else: - return paddle.reshape(flat_trace, bmat.shape[:-2]) + return paddle.reshape(flat_trace, batch_matrix.shape[:-2]) - def _batch_mv(self,bmat, bvec): - """ + def _batch_product_mv(self, batch_matrix, batch_vector): + """Performs a batched matrix-vector product, with compatible but different batch shapes. + Both `batch_matrix` and `batch_vector` may have any number of leading dimensions, which + correspond to a batch shape. They are not necessarily assumed to have the same batch + shape,just ones which can be broadcasted. - """ - bvec_unsqueeze = paddle.unsqueeze(bvec, 1) - bvec = paddle.squeeze(bvec_unsqueeze) - return paddle.matmul(bmat, bvec) + Args: + batch_matrix (Tensor): batch matrix tensor with any number of leading dimensions + batch_vector (Tensor): batch vector tensor with any number of leading dimensions - def _batch_mahalanobis(self, bL, bx): + Returns: + Tensor: a batched matrix-vector product """ + batch_vector_unsqueeze = paddle.unsqueeze(batch_vector, 1) + batch_vector = paddle.squeeze(batch_vector_unsqueeze) + return paddle.matmul(batch_matrix, batch_vector) + def _batch_mahalanobis(self, batch_L, batch_x): + """Computes the squared Mahalanobis distance which assess the similarity between data. + Accepts batches for both batch_L and batch_x. They are not necessarily assumed to have + the same batch shape, but `batch_L` one should be able to broadcasted to `batch_x` one. 
+ + Args: + batch_L (Tensor): tensor after matrix factorization + batch_x (Tensor): difference between two tensors + + Returns: + Tensor: the squared Mahalanobis distance """ - n = bx.shape[-1] - bx_batch_shape = bx.shape[:-1] + n = batch_x.shape[-1] + bx_batch_shape = batch_x.shape[:-1] bx_batch_dims = len(bx_batch_shape) - bL_batch_dims = bL.ndim - 2 + bL_batch_dims = batch_L.ndim - 2 outer_batch_dims = bx_batch_dims - bL_batch_dims old_batch_dims = outer_batch_dims + bL_batch_dims new_batch_dims = outer_batch_dims + 2 * bL_batch_dims - bx_new_shape = bx.shape[:outer_batch_dims] + bx_new_shape = batch_x.shape[:outer_batch_dims] - for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): + for (sL, sx) in zip(batch_L.shape[:-2], batch_x.shape[outer_batch_dims:-1]): bx_new_shape += (sx // sL, sL) bx_new_shape += [n] - bx = paddle.reshape(bx, bx_new_shape) + batch_x = paddle.reshape(batch_x, bx_new_shape) permute_dims = (list(range(outer_batch_dims)) + list(range(outer_batch_dims, new_batch_dims, 2)) + list(range(outer_batch_dims + 1, new_batch_dims, 2)) + [new_batch_dims]) - bx = paddle.transpose(bx, perm=permute_dims) + batch_x = paddle.transpose(batch_x, perm=permute_dims) # shape = [b, n, n] - flat_L = paddle.reshape(bL, [-1, n, n]) + flat_L = paddle.reshape(batch_L, [-1, n, n]) # shape = [c, b, n] - flat_x = paddle.reshape(bx, [-1, flat_L.shape[0], n]) + flat_x = paddle.reshape(batch_x, [-1, flat_L.shape[0], n]) # shape = [b, n, c] flat_x_swap = paddle.transpose(flat_x, perm=[1, 2, 0]) # shape = [b, c] @@ -214,21 +388,21 @@ def _batch_mahalanobis(self, bL, bx): # shape = [c, b] M = M_swap.t() - if bx.shape[:-1] == []: + if batch_x.shape[:-1] == []: return M.sum() else: # shape = [..., 1, j, i, 1] - permuted_M = paddle.reshape(M, bx.shape[:-1]) + permuted_M = paddle.reshape(M, batch_x.shape[:-1]) permute_inv_dims = list(range(outer_batch_dims)) for i in range(bL_batch_dims): permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i] - # shape = [..., 1, i, j, 1] + # shape = [..., 1, i, j, 1] reshaped_M = paddle.transpose(permuted_M, perm=permute_inv_dims) return paddle.reshape(reshaped_M, bx_batch_shape) def _extend_shape(self, sample_shape): - """compute shape of the sample + """Compute shape of the sample Args: sample_shape (Tensor): sample shape From 308dc15eddefcfdab6269706f56a406bddb829da Mon Sep 17 00:00:00 2001 From: dasen Date: Sun, 1 Jan 2023 13:13:27 +0800 Subject: [PATCH 12/18] fix code style --- python/paddle/distribution/kl.py | 7 +- .../distribution/multivariate_normal.py | 100 +++++++++++++----- 2 files changed, 77 insertions(+), 30 deletions(-) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 47779df79f88f..69983f721106f 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -22,9 +22,9 @@ from paddle.distribution.exponential_family import ExponentialFamily from paddle.distribution.laplace import Laplace from paddle.distribution.lognormal import LogNormal +from paddle.distribution.multivariate_normal import MultivariateNormal from paddle.distribution.normal import Normal from paddle.distribution.uniform import Uniform -from paddle.distribution.multivariate_normal import MultivariateNormal from paddle.fluid.framework import _non_static_mode __all__ = ["register_kl", "kl_divergence"] @@ -179,10 +179,11 @@ def _kl_uniform_uniform(p, q): def _kl_laplace_laplace(p, q): return p.kl_divergence(q) -@register_kl(MultivariateNormal,MultivariateNormal) -def _kl_multnormal_multnormal(p,q): 
+@register_kl(MultivariateNormal, MultivariateNormal) +def _kl_multnormal_multnormal(p, q): return p.kl_divergence(q) + @register_kl(ExponentialFamily, ExponentialFamily) def _kl_expfamily_expfamily(p, q): """Compute kl-divergence using `Bregman divergences `_""" diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 5fa46ca07b0d7..813db046863f7 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle import math + +import paddle from paddle.distribution import distribution @@ -77,20 +78,28 @@ def __init__(self, loc, covariance_matrix): if covariance_matrix is not None: if covariance_matrix.dim() < 2: - raise ValueError("covariance_matrix must be at least two-dimensional, " - "with optional leading batch dimensions") - if (covariance_matrix.shape[:-2] == [] or loc.shape[:-1] == []): + raise ValueError( + "covariance_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + if covariance_matrix.shape[:-2] == [] or loc.shape[:-1] == []: batch_shape = [] else: - batch_shape = paddle.broadcast_shape(covariance_matrix.shape[:-2], loc.shape[:-1]) - self.covariance_matrix = covariance_matrix.expand(batch_shape + [-1, -1]) + batch_shape = paddle.broadcast_shape( + covariance_matrix.shape[:-2], loc.shape[:-1] + ) + self.covariance_matrix = covariance_matrix.expand( + batch_shape + [-1, -1] + ) self.loc = loc.expand(batch_shape + [-1]) event_shape = self.loc.shape[-1:] super(MultivariateNormal, self).__init__(batch_shape, event_shape) if covariance_matrix is not None: - self._unbroadcasted_scale_tril = paddle.linalg.cholesky(covariance_matrix) + self._unbroadcasted_scale_tril = paddle.linalg.cholesky( + covariance_matrix + ) @property def mean(self): @@ -208,8 +217,10 @@ def log_prob(self, value): half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1) - return -0.5 * (self.event_shape[0] * math.log(2 * math.pi) + M) - half_log_det - + return ( + -0.5 * (self.event_shape[0] * math.log(2 * math.pi) + M) + - half_log_det + ) def entropy(self): r"""Entropy of multivariate_normal distribution @@ -227,8 +238,15 @@ def entropy(self): Returns: Tensor: entropy value """ - half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1) - H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det + half_log_det = ( + paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1) + .log() + .sum(-1) + ) + H = ( + 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + + half_log_det + ) if len(self._batch_shape) == 0: return H else: @@ -255,7 +273,9 @@ def rsample(self, shape=[]): """ shape = self._extend_shape(shape) eps = paddle.standard_normal(shape, dtype=None, name=None) - unbroadcasted_scale_tril = paddle.linalg.cholesky(self.covariance_matrix) + unbroadcasted_scale_tril = paddle.linalg.cholesky( + self.covariance_matrix + ) return self.loc + self._batch_product_mv(unbroadcasted_scale_tril, eps) @@ -294,20 +314,37 @@ def kl_divergence(self, other): """ if self.event_shape != other.event_shape: - raise ValueError("KL-divergence between two Multivariate Normals with\ - different event shapes cannot be computed") + raise ValueError( + "KL-divergence between two Multivariate Normals with\ + different event shapes 
cannot be computed" + ) - sector1 = (paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1) - - paddle.diagonal(other._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1)) + sector1 = paddle.diagonal( + self._unbroadcasted_scale_tril, axis1=-2, axis2=-1 + ).log().sum(-1) - paddle.diagonal( + other._unbroadcasted_scale_tril, axis1=-2, axis2=-1 + ).log().sum( + -1 + ) if list(self.batch_shape) == [] and list(other.batch_shape) == []: combined_batch_shape = [] else: combined_batch_shape = [self.batch_shape, other.batch_shape] n = self.event_shape[0] - self_scale_tril = self._unbroadcasted_scale_tril.expand(combined_batch_shape + [n, n]) - other_scale_tril = other._unbroadcasted_scale_tril.expand(combined_batch_shape + [n, n]) - sector2 = self._batch_trace_XXT(paddle.linalg.triangular_solve(self_scale_tril, other_scale_tril, upper=False)) - sector3 = self._batch_mahalanobis(self._unbroadcasted_scale_tril, (self.loc - other.loc)) + self_scale_tril = self._unbroadcasted_scale_tril.expand( + combined_batch_shape + [n, n] + ) + other_scale_tril = other._unbroadcasted_scale_tril.expand( + combined_batch_shape + [n, n] + ) + sector2 = self._batch_trace_XXT( + paddle.linalg.triangular_solve( + self_scale_tril, other_scale_tril, upper=False + ) + ) + sector3 = self._batch_mahalanobis( + self._unbroadcasted_scale_tril, (self.loc - other.loc) + ) return sector1 + 0.5 * (sector2 + sector3 - n) def _batch_trace_XXT(self, batch_matrix): @@ -322,7 +359,7 @@ def _batch_trace_XXT(self, batch_matrix): n = batch_matrix.shape[-1] m = batch_matrix.shape[-2] flat_trace = paddle.reshape(batch_matrix, [1, m * n]).pow(2).sum(-1) - if (batch_matrix.shape[:-2] == []): + if batch_matrix.shape[:-2] == []: return flat_trace else: return paddle.reshape(flat_trace, batch_matrix.shape[:-2]) @@ -366,15 +403,19 @@ def _batch_mahalanobis(self, batch_L, batch_x): new_batch_dims = outer_batch_dims + 2 * bL_batch_dims bx_new_shape = batch_x.shape[:outer_batch_dims] - for (sL, sx) in zip(batch_L.shape[:-2], batch_x.shape[outer_batch_dims:-1]): + for (sL, sx) in zip( + batch_L.shape[:-2], batch_x.shape[outer_batch_dims:-1] + ): bx_new_shape += (sx // sL, sL) bx_new_shape += [n] batch_x = paddle.reshape(batch_x, bx_new_shape) - permute_dims = (list(range(outer_batch_dims)) + - list(range(outer_batch_dims, new_batch_dims, 2)) + - list(range(outer_batch_dims + 1, new_batch_dims, 2)) + - [new_batch_dims]) + permute_dims = ( + list(range(outer_batch_dims)) + + list(range(outer_batch_dims, new_batch_dims, 2)) + + list(range(outer_batch_dims + 1, new_batch_dims, 2)) + + [new_batch_dims] + ) batch_x = paddle.transpose(batch_x, perm=permute_dims) # shape = [b, n, n] @@ -384,7 +425,12 @@ def _batch_mahalanobis(self, batch_L, batch_x): # shape = [b, n, c] flat_x_swap = paddle.transpose(flat_x, perm=[1, 2, 0]) # shape = [b, c] - M_swap = paddle.linalg.triangular_solve(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) + M_swap = ( + paddle.linalg.triangular_solve( + flat_L, flat_x_swap, upper=False) + .pow(2) + .sum(-2) + ) # shape = [c, b] M = M_swap.t() From 3bb35268a427ad149a2664e908e0daf09a3cbcd8 Mon Sep 17 00:00:00 2001 From: dasen Date: Sun, 1 Jan 2023 19:57:16 +0800 Subject: [PATCH 13/18] fix code-style --- python/paddle/distribution/kl.py | 1 + .../distribution/multivariate_normal.py | 24 ++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 69983f721106f..a7cbaa8d2e21a 100644 --- 
a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -179,6 +179,7 @@ def _kl_uniform_uniform(p, q): def _kl_laplace_laplace(p, q): return p.kl_divergence(q) + @register_kl(MultivariateNormal, MultivariateNormal) def _kl_multnormal_multnormal(p, q): return p.kl_divergence(q) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 813db046863f7..7c3bd26f1b068 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -139,8 +139,12 @@ def variance(self): Tensor: The variance value. """ - matrix_decompos = paddle.linalg.cholesky(self.covariance_matrix).pow(2).sum(-1) - return paddle.expand(matrix_decompos, self._batch_shape + self._event_shape) + matrix_decompos = ( + paddle.linalg.cholesky(self.covariance_matrix).pow(2).sum(-1) + ) + return paddle.expand( + matrix_decompos, self._batch_shape + self._event_shape + ) @property def stddev(self): @@ -209,18 +213,23 @@ def log_prob(self, value): """ if not isinstance(value, type(self.loc)): raise TypeError( - f"Expected type of value is {type(loc)}, but got {type(value)}" + f"Expected type of value is {type(self.loc)}, but got {type(value)}" ) diff = value - self.loc M = self._batch_mahalanobis(self._unbroadcasted_scale_tril, diff) - half_log_det = paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1).log().sum(-1) + half_log_det = ( + paddle.diagonal(self._unbroadcasted_scale_tril, axis1=-2, axis2=-1) + .log() + .sum(-1) + ) return ( -0.5 * (self.event_shape[0] * math.log(2 * math.pi) + M) - half_log_det ) + def entropy(self): r"""Entropy of multivariate_normal distribution @@ -426,10 +435,9 @@ def _batch_mahalanobis(self, batch_L, batch_x): flat_x_swap = paddle.transpose(flat_x, perm=[1, 2, 0]) # shape = [b, c] M_swap = ( - paddle.linalg.triangular_solve( - flat_L, flat_x_swap, upper=False) - .pow(2) - .sum(-2) + paddle.linalg.triangular_solve(flat_L, flat_x_swap, upper=False) + .pow(2) + .sum(-2) ) # shape = [c, b] M = M_swap.t() From de92b38b9134e4cbe4eaf374ba0e5f14db98f8bd Mon Sep 17 00:00:00 2001 From: dasen Date: Sun, 1 Jan 2023 23:02:00 +0800 Subject: [PATCH 14/18] fix code-style --- python/paddle/distribution/multivariate_normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 7c3bd26f1b068..a74fd62d282e4 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -186,7 +186,7 @@ def prob(self, value): """ if not isinstance(value, type(self.loc)): raise TypeError( - f"Expected type of value is {type(loc)}, but got {type(value)}" + f"Expected type of value is {type(self.loc)}, but got {type(value)}" ) return paddle.exp(self.log_prob(value)) From 06bbd2a06509f9d597b34f3fa6baa884a4b4ec80 Mon Sep 17 00:00:00 2001 From: dasen Date: Sat, 14 Jan 2023 01:09:27 +0800 Subject: [PATCH 15/18] fix dim to len --- python/paddle/distribution/multivariate_normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index a74fd62d282e4..183a6b90644cd 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -71,7 +71,7 @@ class MultivariateNormal(distribution.Distribution): """ def __init__(self, loc, covariance_matrix): - if 
loc.dim() < 1: + if len(loc.shape) < 1: raise ValueError("loc must be at least one-dimensional.") if (covariance_matrix is not None) != 1: raise ValueError("Exactly covariance_matrix may be specified.") From 683c85db70a141307f50b04206b6423aac12165e Mon Sep 17 00:00:00 2001 From: dasen Date: Tue, 17 Jan 2023 16:41:42 +0800 Subject: [PATCH 16/18] delete extend_shape method --- .../distribution/multivariate_normal.py | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 183a6b90644cd..e6365f2fb38ec 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -92,14 +92,14 @@ def __init__(self, loc, covariance_matrix): batch_shape + [-1, -1] ) + self._unbroadcasted_scale_tril = paddle.linalg.cholesky( + covariance_matrix + ) + self.loc = loc.expand(batch_shape + [-1]) event_shape = self.loc.shape[-1:] super(MultivariateNormal, self).__init__(batch_shape, event_shape) - if covariance_matrix is not None: - self._unbroadcasted_scale_tril = paddle.linalg.cholesky( - covariance_matrix - ) @property def mean(self): @@ -261,7 +261,7 @@ def entropy(self): else: return H.expand(self._batch_shape) - def sample(self, shape=[]): + def sample(self, shape=()): """Draw sample data from multinomial distribution Args: @@ -270,7 +270,7 @@ def sample(self, shape=[]): with paddle.no_grad(): return self.rsample(shape) - def rsample(self, shape=[]): + def rsample(self, shape=()): """Generate reparameterized samples of the specified shape. Args: @@ -454,14 +454,3 @@ def _batch_mahalanobis(self, batch_L, batch_x): # shape = [..., 1, i, j, 1] reshaped_M = paddle.transpose(permuted_M, perm=permute_inv_dims) return paddle.reshape(reshaped_M, bx_batch_shape) - - def _extend_shape(self, sample_shape): - """Compute shape of the sample - - Args: - sample_shape (Tensor): sample shape - - Returns: - Tensor: generated sample data shape - """ - return sample_shape + list(self.batch_shape) + list(self.event_shape) From 3809fb4981f34666f572e60c239cac7326f8c48c Mon Sep 17 00:00:00 2001 From: dasen Date: Wed, 18 Jan 2023 00:15:34 +0800 Subject: [PATCH 17/18] fix code style --- python/paddle/distribution/multivariate_normal.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index e6365f2fb38ec..34aed070b811d 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -100,7 +100,6 @@ def __init__(self, loc, covariance_matrix): event_shape = self.loc.shape[-1:] super(MultivariateNormal, self).__init__(batch_shape, event_shape) - @property def mean(self): r"""Mean of distribution From 04e12f55136036aeb3e46259f2782ba9b45a68d0 Mon Sep 17 00:00:00 2001 From: dasen Date: Wed, 18 Jan 2023 00:41:53 +0800 Subject: [PATCH 18/18] fix dim --- python/paddle/distribution/multivariate_normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distribution/multivariate_normal.py b/python/paddle/distribution/multivariate_normal.py index 34aed070b811d..6d12a0e7aa2d1 100644 --- a/python/paddle/distribution/multivariate_normal.py +++ b/python/paddle/distribution/multivariate_normal.py @@ -77,7 +77,7 @@ def __init__(self, loc, covariance_matrix): raise ValueError("Exactly covariance_matrix may be specified.") if covariance_matrix is not None: - if 
covariance_matrix.dim() < 2:
+        # Note: check the tensor's rank, not its first dimension size.
+        if len(covariance_matrix.shape) < 2:
             raise ValueError(
                 "covariance_matrix must be at least two-dimensional, "
                 "with optional leading batch dimensions"
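
A minimal end-to-end sketch of the API as this series leaves it (illustrative
only: it mirrors the class docstring example added in PATCH 11, the printed
values are copied from that docstring, and the tensor `place` shown there
depends on the build):

.. code-block:: python

    import paddle
    from paddle.distribution.multivariate_normal import MultivariateNormal

    loc = paddle.to_tensor([0, 1], dtype=paddle.float32)
    cov = paddle.to_tensor([[2, 1], [1, 2]], dtype=paddle.float32)
    dist = MultivariateNormal(loc, cov)

    value = paddle.to_tensor([[2, 1], [1, 2]], dtype=paddle.float32)
    dist.prob(value)       # [0.02422146, 0.06584076]
    dist.log_prob(value)   # [-3.72051620, -2.72051620]
    dist.entropy()         # [3.38718319]
    dist.sample([2, 2])    # random sample with shape [2, 2, 2]

    # KL is registered in kl.py, so the functional form also works:
    other = MultivariateNormal(
        paddle.to_tensor([1, 2], dtype=paddle.float32),
        paddle.to_tensor([[4, 2], [2, 4]], dtype=paddle.float32),
    )
    paddle.distribution.kl_divergence(dist, other)  # [0.64018595]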