Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultivariateNormal distribution API #47825

Closed
wants to merge 39 commits into from
Closed
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
8d99219
dropout2d; test=document_fix
dasenCoding Nov 10, 2022
23e7ebd
init multivariate_normal api
dasenCoding Nov 10, 2022
be5d83f
resolve conflicts
dasenCoding Nov 14, 2022
1b156d7
resolve conflicts
dasenCoding Nov 14, 2022
5a90ff0
rollback kl.py
dasenCoding Nov 14, 2022
3de0a18
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Nov 14, 2022
006d08a
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Nov 25, 2022
29b0578
update: init / mean / variance / stddev
dasenCoding Nov 25, 2022
30c925e
Merge branch 'multnormal_api' of https://github.com/dasenCoding/Paddl…
dasenCoding Nov 25, 2022
9ffdadb
fix: deal conflict
dasenCoding Dec 4, 2022
622c335
Merge branch 'multnormal_api' of https://github.com/dasenCoding/Paddl…
dasenCoding Dec 4, 2022
972e19c
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Dec 18, 2022
653793d
update prob / log_prob
dasenCoding Dec 18, 2022
89168bf
update entrop / sample / rsample
dasenCoding Dec 18, 2022
1721300
update kl_divergence
dasenCoding Dec 21, 2022
dce35da
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Dec 21, 2022
535e377
regist kl
dasenCoding Dec 21, 2022
29d82fd
Merge branch 'multnormal_api' of https://github.com/dasenCoding/Paddl…
dasenCoding Dec 21, 2022
3579509
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Dec 23, 2022
7e98af0
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Dec 30, 2022
5c8a6ef
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Dec 31, 2022
1d4a8a1
complete en api docs
dasenCoding Dec 31, 2022
d60c386
Merge branch 'multnormal_api' of https://github.com/dasenCoding/Paddl…
dasenCoding Dec 31, 2022
0a0df73
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Jan 1, 2023
308dc15
fix code style
dasenCoding Jan 1, 2023
091c0dd
Merge branch 'multnormal_api' of https://github.com/dasenCoding/Paddl…
dasenCoding Jan 1, 2023
3bb3526
fix code-style
dasenCoding Jan 1, 2023
de92b38
fix code-style
dasenCoding Jan 1, 2023
1082433
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Jan 9, 2023
d75b351
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Jan 13, 2023
06bbd2a
fix dim to len
dasenCoding Jan 13, 2023
95935db
Merge branch 'multnormal_api' of https://github.com/dasenCoding/Paddl…
dasenCoding Jan 13, 2023
4360f90
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Jan 17, 2023
683c85d
delete extend_shape method
dasenCoding Jan 17, 2023
6c3d291
Merge branch 'multnormal_api' of https://github.com/dasenCoding/Paddl…
dasenCoding Jan 17, 2023
3809fb4
fix code style
dasenCoding Jan 17, 2023
04e12f5
fix dim
dasenCoding Jan 17, 2023
465b6c0
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Jan 27, 2023
215e48e
Merge branch 'PaddlePaddle:develop' into multnormal_api
dasenCoding Feb 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions python/paddle/distribution/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.laplace import Laplace
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.multivariate_normal import MultivariateNormal
from paddle.distribution.normal import Normal
from paddle.distribution.uniform import Uniform
from paddle.fluid.framework import _non_static_mode
Expand All @@ -34,53 +35,38 @@
def kl_divergence(p, q):
r"""
Kullback-Leibler divergence between distribution p and q.

.. math::

KL(p||q) = \int p(x)log\frac{p(x)}{q(x)} \mathrm{d}x

Args:
p (Distribution): ``Distribution`` object. Inherits from the Distribution Base class.
q (Distribution): ``Distribution`` object. Inherits from the Distribution Base class.

Returns:
Tensor, Batchwise KL-divergence between distribution p and q.

Examples:

.. code-block:: python

import paddle

p = paddle.distribution.Beta(alpha=0.5, beta=0.5)
q = paddle.distribution.Beta(alpha=0.3, beta=0.7)

print(paddle.distribution.kl_divergence(p, q))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.21193528])

"""
return _dispatch(type(p), type(q))(p, q)


def register_kl(cls_p, cls_q):
"""Decorator for register a KL divergence implemention function.

The ``kl_divergence(p, q)`` function will search concrete implemention
functions registered by ``register_kl``, according to multi-dispatch pattern.
If an implemention function is found, it will return the result, otherwise,
it will raise ``NotImplementError`` exception. Users can register
implemention funciton by the decorator.

Args:
cls_p (Distribution): The Distribution type of Instance p. Subclass derived from ``Distribution``.
cls_q (Distribution): The Distribution type of Instance q. Subclass derived from ``Distribution``.

Examples:
.. code-block:: python

import paddle

@paddle.distribution.register_kl(paddle.distribution.Beta, paddle.distribution.Beta)
def kl_beta_beta():
pass # insert implementation here
Expand Down Expand Up @@ -194,6 +180,11 @@ def _kl_laplace_laplace(p, q):
return p.kl_divergence(q)


@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multnormal_multnormal(p, q):
return p.kl_divergence(q)


@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
"""Compute kl-divergence using `Bregman divergences <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_"""
Expand Down
Loading