[Feature] CompositeDistribution #517

Merged
merged 8 commits into from
Sep 1, 2023
Changes from 4 commits
18 changes: 9 additions & 9 deletions .github/workflows/nightly_build.yml
@@ -34,7 +34,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
- python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
+ python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"]]
cuda_support: [["", "cpu", "cpu"]]
container: pytorch/manylinux-cuda116
steps:
@@ -73,7 +73,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
+ python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -106,7 +106,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
+ python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -155,7 +155,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
- python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
+ python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"]]
cuda_support: [["", "cpu", "cpu"]]
container: pytorch/manylinux-${{ matrix.cuda_support[2] }}
steps:
@@ -186,7 +186,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
+ python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]]
steps:
- name: Checkout tensordict
uses: actions/checkout@v2
@@ -214,7 +214,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
- python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
+ python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"]]
cuda_support: [["", "cpu", "cpu"]]
steps:
- name: Setup Python
@@ -273,7 +273,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
+ python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -306,7 +306,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
+ python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -360,7 +360,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
+ python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]]
steps:
- name: Checkout tensordict
uses: actions/checkout@v2
10 changes: 5 additions & 5 deletions .github/workflows/wheels.yml
@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
- python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"]]
+ python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"]]
cuda_support: [["", "--extra-index-url https://download.pytorch.org/whl/cpu", "\"['cpu', '11.3', '11.6']\"", "cpu"]]
container: pytorch/manylinux-${{ matrix.cuda_support[3] }}
steps:
@@ -56,7 +56,7 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
+ python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -88,7 +88,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"]]
+ python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -122,7 +122,7 @@ jobs:
strategy:
matrix:
os: [["linux", "ubuntu-20.04"], ["mac", "macos-latest"]]
- python_version: ["3.8", "3.9", "3.10" ]
+ python_version: ["3.8", "3.9", "3.10", "3.11" ]
runs-on: ${{ matrix.os[1] }}
steps:
- name: Setup Python
@@ -169,7 +169,7 @@ jobs:
needs: build-wheel-windows
strategy:
matrix:
- python_version: ["3.8", "3.9", "3.10" ]
+ python_version: ["3.8", "3.9", "3.10", "3.11" ]
runs-on: windows-latest
steps:
- name: Setup Python
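
Note on the Linux matrix entries: each `python_version` item pairs a version string with a `cpXY-cpXY` tag, which follows the CPython compatibility-tag convention (Python tag plus ABI tag) for 3.8 and later. The sketch below shows how such pairs could be derived instead of typed by hand. `manylinux_tag` and `linux_matrix` are hypothetical helper names, not part of the workflow, and the workflow itself lists the pairs explicitly.

```python
def manylinux_tag(version: str) -> str:
    """Return the '<python tag>-<abi tag>' pair for a CPython 'major.minor' string.

    Assumes CPython 3.8 or later, where the ABI tag no longer carries an 'm' suffix.
    """
    major, minor = version.split(".")[:2]
    tag = f"cp{major}{minor}"
    return f"{tag}-{tag}"


def linux_matrix(versions):
    """Build [version, tag] pairs in the same shape as the workflow matrix."""
    return [[v, manylinux_tag(v)] for v in versions]


matrix = linux_matrix(["3.8", "3.9", "3.10", "3.11"])
print(matrix[-1])  # ['3.11', 'cp311-cp311']
```

Adding a new interpreter then amounts to appending one version string, which is the same change this diff makes by hand in each job.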
15 changes: 15 additions & 0 deletions docs/source/reference/nn.rst
@@ -317,6 +317,21 @@ first traced using :func:`~.symbolic_trace`.

symbolic_trace

Distributions
-------------

.. py:currentmodule:: tensordict.nn.distributions

.. autosummary::
    :toctree: generated/
    :template: rl_template_noinherit.rst

    NormalParamExtractor
    CompositeDistribution
    Delta
    OneHotCategorical
    TruncatedNormal


Utils
-----
2 changes: 1 addition & 1 deletion tensordict/nn/__init__.py
@@ -10,7 +10,7 @@
TensorDictModuleBase,
TensorDictModuleWrapper,
)
- from tensordict.nn.distributions import NormalParamExtractor
+ from tensordict.nn.distributions import CompositeDistribution, NormalParamExtractor
from tensordict.nn.ensemble import EnsembleModule
from tensordict.nn.functional_modules import (
get_functional,
11 changes: 9 additions & 2 deletions tensordict/nn/distributions/__init__.py
@@ -4,8 +4,15 @@
# LICENSE file in the root directory of this source tree.

from tensordict.nn.distributions import continuous, discrete
- from tensordict.nn.distributions.continuous import *
- from tensordict.nn.distributions.discrete import *

+ from tensordict.nn.distributions.composite import CompositeDistribution
+ from tensordict.nn.distributions.continuous import (
+     Delta,
+     NormalParamExtractor,
+     NormalParamWrapper,
+ )
+ from tensordict.nn.distributions.discrete import OneHotCategorical, rand_one_hot
+ from tensordict.nn.distributions.truncated_normal import TruncatedNormal

distributions_maps = {
distribution_class.lower(): eval(distribution_class)
132 changes: 132 additions & 0 deletions tensordict/nn/distributions/composite.py
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict._tensordict import unravel_keys
from tensordict.utils import NestedKey
from torch import distributions as d


class CompositeDistribution(d.Distribution):
    """A composition of distributions.

    Groups distributions together with the TensorDict interface.

    Args:
        params (TensorDictBase): a nested key-tensor map where the root entries
            point to the sample names, and the leaves are the distribution parameters.
            Entry names must match those of ``distribution_map``.

        distribution_map (Dict[NestedKey, Type[torch.distributions.Distribution]]):
            indicates the distribution types to be used. The names of the distributions
            will match the names of the samples in the tensordict.

    Keyword Arguments:
        extra_kwargs (Dict[NestedKey, Dict]): a possibly incomplete dictionary of
            extra keyword arguments for the distributions to be built.

    Examples:
        >>> params = TensorDict({
        ...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
        ...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
        ... }, [3])
        >>> dist = CompositeDistribution(params,
        ...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
        >>> sample = dist.sample((4,))
        >>> sample = dist.log_prob(sample)
        >>> print(sample)
        TensorDict(
            fields={
                cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                nested: TensorDict(
                    fields={
                        disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                        disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([4]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)
    """

    def __init__(self, params: TensorDictBase, distribution_map, *, extra_kwargs=None):
        self._batch_shape = params.shape
        if extra_kwargs is None:
            extra_kwargs = {}
        dists = {}
        for name, dist_class in distribution_map.items():
            dist_params = params.get(name, None)
            kwargs = extra_kwargs.get(name, {})
            if dist_params is None:
                raise KeyError
            dist = dist_class(**dist_params, **kwargs)
            dists[name] = dist
        self.dists = dists

    def sample(self, shape=None):
        if shape is None:
            shape = torch.Size([])
        samples = {name: dist.sample(shape) for name, dist in self.dists.items()}
        return TensorDict(
            samples,
            shape + self.batch_shape,
        )

    def rsample(self, shape=None):
        if shape is None:
            shape = torch.Size([])
        return TensorDict(
            {name: dist.rsample(shape) for name, dist in self.dists.items()},
            shape + self.batch_shape,
        )

    def log_prob(self, sample: TensorDictBase):
        d = {
            _add_suffix(name, "_log_prob"): dist.log_prob(sample.get(name))

@matteobettini (Contributor) commented on Aug 31, 2023:

Don't we want these keys to be parametric?

            for name, dist in self.dists.items()
        }
        sample.update(d)
        return sample

    def cdf(self, sample: TensorDictBase):
        cdfs = {
            _add_suffix(name, "_cdf"): dist.cdf(sample.get(name))
            for name, dist in self.dists.items()
        }
        sample.update(cdfs)
        return sample

    def icdf(self, sample: TensorDictBase):
        """Computes the inverse CDF.

        Requires the input tensordict to have either a `<sample_name>_cdf` entry
        or a `<sample_name>` entry.

        Args:
            sample (TensorDictBase): a tensordict containing `<sample>_cdf` or `<sample>`,
                where `<sample>` is the name of the sample provided during construction.
        """
        for name, dist in self.dists.items():
            prob = sample.get(_add_suffix(name, "_cdf"), None)
            if prob is None:
                try:
                    # Fall back to computing the CDF from the raw sample entry.
                    prob = dist.cdf(sample.get(name))
                except KeyError:
                    raise KeyError(
                        f"Neither {name} nor {_add_suffix(name, '_cdf')} could be found in the sampled tensordict. Make sure one of these is available to icdf."
                    )
            icdf = dist.icdf(prob)
            sample.set(_add_suffix(name, "_icdf"), icdf)
        return sample
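
Note: the pattern used by the class above (one distribution per named entry, results written back under suffixed names, and `icdf` falling back from `<name>_cdf` to `<name>`) can be tried without torch or tensordict. The following is only a toy sketch under stated assumptions: `DictComposite` is a hypothetical name, plain dicts stand in for tensordicts, keys are flat strings only, and the standard library's `statistics.NormalDist` stands in for a torch distribution. It is not the implementation from this PR.

```python
import math
from statistics import NormalDist


class DictComposite:
    """Toy stand-in for the composite idea (hypothetical, flat string keys only)."""

    def __init__(self, params, distribution_map):
        self.dists = {}
        for name, dist_class in distribution_map.items():
            if name not in params:
                raise KeyError(f"no parameters found for {name!r}")
            # Parameters are passed as keyword arguments, one dict per name.
            self.dists[name] = dist_class(**params[name])

    def log_prob(self, sample):
        # Write one "<name>_log_prob" entry per distribution, in place.
        for name, dist in self.dists.items():
            sample[name + "_log_prob"] = math.log(dist.pdf(sample[name]))
        return sample

    def cdf(self, sample):
        for name, dist in self.dists.items():
            sample[name + "_cdf"] = dist.cdf(sample[name])
        return sample

    def icdf(self, sample):
        for name, dist in self.dists.items():
            prob = sample.get(name + "_cdf")
            if prob is None:
                if name not in sample:
                    raise KeyError(f"neither {name} nor {name}_cdf was found")
                # Fall back to computing the CDF from the raw sample.
                prob = dist.cdf(sample[name])
            sample[name + "_icdf"] = dist.inv_cdf(prob)
        return sample


dist = DictComposite(
    {"a": {"mu": 0.0, "sigma": 1.0}, "b": {"mu": 2.0, "sigma": 0.5}},
    {"a": NormalDist, "b": NormalDist},
)
out = dist.icdf({"a": 0.3, "b": 2.4})
print(sorted(out))  # ['a', 'a_icdf', 'b', 'b_icdf']
```

Since `icdf` inverts `cdf`, `out["a_icdf"]` should come back close to the original `0.3`, which is a quick sanity check on the fallback path.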


def _add_suffix(key: NestedKey, suffix: str):
    key = unravel_keys(key)
    if isinstance(key, str):
        return key + suffix
    return key[:-1] + (key[-1] + suffix,)
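
Note: the suffixing rule is easiest to see on concrete keys. The sketch below mirrors `_add_suffix` in plain Python. `flatten_key` is a hypothetical stand-in for `unravel_keys`, written on the assumption that `unravel_keys` leaves a string as-is and flattens nested tuples of strings into one flat tuple; the real helper may differ in edge cases.

```python
def flatten_key(key):
    """Hypothetical stand-in for unravel_keys: flatten nested tuples of strings."""
    if isinstance(key, str):
        return key
    flat = []
    for part in key:
        sub = flatten_key(part)
        if isinstance(sub, str):
            flat.append(sub)
        else:
            flat.extend(sub)
    return tuple(flat)


def add_suffix(key, suffix):
    """Append the suffix to a string key, or to the last element of a tuple key."""
    key = flatten_key(key)
    if isinstance(key, str):
        return key + suffix
    return key[:-1] + (key[-1] + suffix,)


print(add_suffix("cont", "_log_prob"))              # cont_log_prob
print(add_suffix(("nested", "disc"), "_log_prob"))  # ('nested', 'disc_log_prob')
print(add_suffix(("a", ("b", "c")), "_cdf"))        # ('a', 'b', 'c_cdf')
```

Only the leaf name changes, so the derived entry lands next to the sample it was computed from, as in the `nested` sub-tensordict of the docstring example.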