[Feature] CompositeDistribution #517

Merged: 8 commits, Sep 1, 2023
16 changes: 16 additions & 0 deletions docs/source/reference/nn.rst
@@ -317,6 +317,22 @@ first traced using :func:`~.symbolic_trace`.

    symbolic_trace

Distributions
-------------

.. py:currentmodule:: tensordict.nn.distributions

.. autosummary::
    :toctree: generated/
    :template: rl_template_noinherit.rst

    NormalParamExtractor
    AddStateIndependentNormalScale
    CompositeDistribution
    Delta
    OneHotCategorical
    TruncatedNormal


Utils
-----
9 changes: 8 additions & 1 deletion tensordict/nn/__init__.py
@@ -10,7 +10,14 @@
    TensorDictModuleBase,
    TensorDictModuleWrapper,
)
from tensordict.nn.distributions import (
    AddStateIndependentNormalScale,
    CompositeDistribution,
    NormalParamExtractor,
    OneHotCategorical,
    rand_one_hot,
    TruncatedNormal,
)
from tensordict.nn.ensemble import EnsembleModule
from tensordict.nn.functional_modules import (
    get_functional,
12 changes: 10 additions & 2 deletions tensordict/nn/distributions/__init__.py
@@ -4,8 +4,16 @@
# LICENSE file in the root directory of this source tree.

from tensordict.nn.distributions import continuous, discrete

from tensordict.nn.distributions.composite import CompositeDistribution
from tensordict.nn.distributions.continuous import (
    AddStateIndependentNormalScale,
    Delta,
    NormalParamExtractor,
    NormalParamWrapper,
)
from tensordict.nn.distributions.discrete import OneHotCategorical, rand_one_hot
from tensordict.nn.distributions.truncated_normal import TruncatedNormal

distributions_maps = {
    distribution_class.lower(): eval(distribution_class)
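For context, the ``distributions_maps`` comprehension above (its source iterable is cut off in this view) builds a lowercase-name-to-class lookup. A minimal, hypothetical usage sketch, assuming the map is keyed by the lowercase names of the classes exported above:

import torch
from tensordict.nn.distributions import distributions_maps

# Resolve a distribution class from a plain string, e.g. a config entry.
dist_class = distributions_maps["delta"]   # -> Delta
dist = dist_class(param=torch.zeros(3))    # point mass centered on zeros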
160 changes: 160 additions & 0 deletions tensordict/nn/distributions/composite.py
@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict._tensordict import unravel_keys
from tensordict.utils import NestedKey
from torch import distributions as d


class CompositeDistribution(d.Distribution):
"""A composition of distributions.

Groups distributions together with the TensorDict interface. All methods
(``log_prob``, ``cdf``, ``icdf``, ``rsample``, ``sample`` etc.) will return a
tensordict, possibly modified in-place if the input was a tensordict.

Args:
params (TensorDictBase): a nested key-tensor map where the root entries
point to the sample names, and the leaves are the distribution parameters.
Entry names must match those of ``distribution_map``.

distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]):
indicated the distribution types to be used. The names of the distributions
will match the names of the samples in the tensordict.

Keyword Arguments:
extra_kwargs (Dict[NestedKey, Dict]): a possibly incomplete dictionary of
extra keyword arguments for the distributions to be built.

.. note:: In this distribution class, the batch-size of the input tensordict containing the params
(``params``) is indicative of the batch_shape of the distribution. For instance,
the ``"sample_log_prob"`` entry resulting from a call to ``log_prob``
will be of the shape of the params (+ any supplementary batch dimension).

Examples:
>>> params = TensorDict({
... "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
... ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
... distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> sample = dist.log_prob(sample)
>>> print(sample)
TensorDict(
fields={
cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([4]),
device=None,
is_shared=False)},
batch_size=torch.Size([4]),
device=None,
is_shared=False)
"""

    def __init__(self, params: TensorDictBase, distribution_map, *, extra_kwargs=None):
        self._batch_shape = params.shape
        if extra_kwargs is None:
            extra_kwargs = {}
        dists = {}
        for name, dist_class in distribution_map.items():
            dist_params = params.get(name, None)
            kwargs = extra_kwargs.get(name, {})
            if dist_params is None:
                raise KeyError(
                    f"no parameter entry found for {name} in the params tensordict."
                )
            dist = dist_class(**dist_params, **kwargs)
            dists[name] = dist
        self.dists = dists

    def sample(self, shape=None) -> TensorDictBase:
        if shape is None:
            shape = torch.Size([])
        samples = {name: dist.sample(shape) for name, dist in self.dists.items()}
        return TensorDict(
            samples,
            shape + self.batch_shape,
        )

    @property
    def mode(self) -> TensorDictBase:
        samples = {name: dist.mode for name, dist in self.dists.items()}
        return TensorDict(
            samples,
            self.batch_shape,
        )

    @property
    def mean(self) -> TensorDictBase:
        samples = {name: dist.mean for name, dist in self.dists.items()}
        return TensorDict(
            samples,
            self.batch_shape,
        )

    def rsample(self, shape=None) -> TensorDictBase:
        if shape is None:
            shape = torch.Size([])
        return TensorDict(
            {name: dist.rsample(shape) for name, dist in self.dists.items()},
            shape + self.batch_shape,
        )

    def log_prob(self, sample: TensorDictBase) -> TensorDictBase:
        """Writes a ``<sample>_log_prob`` entry for each sample in the input tensordict, along with a ``"sample_log_prob"`` entry with the summed log-prob."""
        slp = 0.0
        d = {}
        for name, dist in self.dists.items():
            d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
            while lp.ndim > sample.ndim:
                lp = lp.sum(-1)
            slp = slp + lp
        d["sample_log_prob"] = slp
        sample.update(d)
        return sample

    def cdf(self, sample: TensorDictBase) -> TensorDictBase:
        cdfs = {
            _add_suffix(name, "_cdf"): dist.cdf(sample.get(name))
            for name, dist in self.dists.items()
        }
        sample.update(cdfs)
        return sample

    def icdf(self, sample: TensorDictBase) -> TensorDictBase:
        """Computes the inverse CDF.

        Requires the input tensordict to have either a ``<sample_name>_cdf`` entry
        or a ``<sample_name>`` entry for each distribution.

        Args:
            sample (TensorDictBase): a tensordict containing ``<sample_name>_cdf``
                (or ``<sample_name>``) entries, where ``<sample_name>`` is the name
                of a sample provided during construction.
        """
        for name, dist in self.dists.items():
            prob = sample.get(_add_suffix(name, "_cdf"), None)
            if prob is None:
                try:
                    prob = dist.cdf(sample.get(name))
                except KeyError:
                    raise KeyError(
                        f"Neither {name} nor {name + '_cdf'} could be found in the sampled tensordict. Make sure one of these is available to icdf."
                    )
            icdf = dist.icdf(prob)
            sample.set(_add_suffix(name, "_icdf"), icdf)
        return sample


def _add_suffix(key: NestedKey, suffix: str):
    key = unravel_keys(key)
    if isinstance(key, str):
        return key + suffix
    return key[:-1] + (key[-1] + suffix,)
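To make the semantics above concrete, here is a minimal usage sketch adapted from the class docstring. The cdf/icdf round trip is restricted to the Normal head, since ``torch.distributions.Categorical`` does not implement ``cdf``:

import torch
from torch import distributions as d
from tensordict import TensorDict
from tensordict.nn.distributions import CompositeDistribution

# Parameters for two sample groups, batch size 3 (as in the docstring example).
params = TensorDict({
    "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
    ("nested", "disc"): {"logits": torch.randn(3, 10)},
}, [3])
dist = CompositeDistribution(
    params, distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical}
)
sample = dist.sample((4,))      # TensorDict with batch size [4, 3]
sample = dist.log_prob(sample)  # writes per-name log-probs plus their sum
print(sample["sample_log_prob"].shape)  # torch.Size([4, 3])

# cdf/icdf round trip on the continuous head only.
cont_only = CompositeDistribution(
    params.select("cont"), distribution_map={"cont": d.Normal}
)
s = cont_only.cdf(cont_only.sample())  # writes "cont_cdf"
s = cont_only.icdf(s)                  # writes "cont_icdf", recovering "cont" up to numerics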
46 changes: 36 additions & 10 deletions tensordict/nn/probabilistic.py
@@ -13,6 +13,7 @@
from warnings import warn

from tensordict._contextlib import _DecoratorContextManager
from tensordict.nn import CompositeDistribution

from tensordict.nn.common import dispatch, TensorDictModule, TensorDictModuleBase
from tensordict.nn.distributions import Delta, distributions_maps
@@ -224,7 +225,7 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase):
        >>> td_module = ProbabilisticTensorDictSequential(
        ...     module, normal_params, prob_module
        ... )
        >>> params = make_functional(td_module, funs_to_decorate=["forward", "get_dist", "log_prob"])
        >>> _ = td_module(td, params=params)
        >>> print(td)
        TensorDict(
@@ -318,10 +319,11 @@ def __init__(
        self._dist = None
        self.cache_dist = cache_dist if hasattr(distribution_class, "update") else False
        self.return_log_prob = return_log_prob
        if self.return_log_prob and self.log_prob_key not in self.out_keys:
            self.out_keys.append(self.log_prob_key)

    def get_dist(self, tensordict: TensorDictBase) -> D.Distribution:
        """Creates a :class:`torch.distributions.Distribution` instance with the parameters provided in the input tensordict."""
        try:
            dist_kwargs = {}
            for dist_key, td_key in zip(self.dist_keys, self.in_keys):
@@ -343,6 +345,15 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution:
            raise err
        return dist

    def log_prob(self, tensordict):
        """Computes the log-probability of the distribution sample contained in the input tensordict."""
        dist = self.get_dist(tensordict)
        if isinstance(dist, CompositeDistribution):
            tensordict = dist.log_prob(tensordict)
            return tensordict.get("sample_log_prob")
        else:
            return dist.log_prob(tensordict.get(self.out_keys[0]))

    @property
    def SAMPLE_LOG_PROB_KEY(self):
        warnings.warn(
@@ -366,14 +377,19 @@ def forward(
        dist = self.get_dist(tensordict)
        if _requires_sample:
            out_tensors = self._dist_sample(dist, interaction_type=interaction_type())
            if isinstance(out_tensors, TensorDictBase):
                tensordict_out.update(out_tensors)
                if self.return_log_prob:
                    tensordict_out = dist.log_prob(tensordict_out)
            else:
                if isinstance(out_tensors, Tensor):
                    out_tensors = (out_tensors,)
                tensordict_out.update(
                    {key: value for key, value in zip(self.out_keys, out_tensors)}
                )
                if self.return_log_prob:
                    log_prob = dist.log_prob(*out_tensors)
                    tensordict_out.set(self.log_prob_key, log_prob)
        elif self.return_log_prob:
            out_tensors = [
                tensordict.get(key) for key in self.out_keys if key != self.log_prob_key
@@ -508,6 +524,16 @@ def get_dist(
        tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs)
        return self.build_dist_from_params(tensordict_out)

    def log_prob(
        self, tensordict, tensordict_out: TensorDictBase | None = None, **kwargs
    ):
        """Computes the distribution parameters, then delegates to the last module's ``log_prob``."""
        tensordict_out = self.get_dist_params(
            tensordict,
            tensordict_out,
            **kwargs,
        )
        return self.module[-1].log_prob(tensordict_out)

    def build_dist_from_params(self, tensordict: TensorDictBase) -> D.Distribution:
        """Construct a distribution from the input parameters. Other modules in the sequence are not evaluated."""
        return self.module[-1].get_dist(tensordict)
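Putting the new pieces together, a minimal end-to-end sketch. The wiring below is an assumption on top of this diff: ``in_keys=["params"]`` relies on ``get_dist`` forwarding that entry as the ``params`` argument of ``CompositeDistribution``, which matches the constructor shown earlier:

import torch
from torch import distributions as d
from tensordict import TensorDict
from tensordict.nn import ProbabilisticTensorDictModule
from tensordict.nn.distributions import CompositeDistribution

params = TensorDict({
    "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
    ("nested", "disc"): {"logits": torch.randn(3, 10)},
}, [3])
td = TensorDict({"params": params}, [3])
module = ProbabilisticTensorDictModule(
    in_keys=["params"],
    out_keys=["cont", ("nested", "disc")],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "distribution_map": {"cont": d.Normal, ("nested", "disc"): d.Categorical}
    },
)
td = module(td)           # samples "cont" and ("nested", "disc") in-place
lp = module.log_prob(td)  # composite path: returns td.get("sample_log_prob")
print(lp.shape)           # torch.Size([3])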