Commit

📝 Create Sphinx documentation
francois-rozet committed May 5, 2022
1 parent 2059a40 commit af9bf9a
Showing 46 changed files with 3,081 additions and 2,382 deletions.
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,3 @@
# Contributing guidelines

TODO
24 changes: 23 additions & 1 deletion README.md
@@ -1 +1,23 @@
# Likelihood-free AMortized Posterior Estimation
<p align="center"><img src="https://raw.githubusercontent.com/francois-rozet/lampe/master/sphinx/static/banner.svg" width="100%"></p>

# LAMPE

`lampe` is a simulation-based inference (SBI) package that focuses on amortized estimation of posterior distributions, without relying on explicit likelihood functions; hence the name *Likelihood-free AMortized Posterior Estimation* (LAMPE). The package provides [PyTorch](https://pytorch.org) implementations of modern amortized simulation-based inference algorithms, such as neural ratio estimation (NRE) and neural posterior estimation (NPE). Like PyTorch, LAMPE avoids obfuscation and exposes all components, from network architecture to optimizer, so that users are free to modify or replace anything they like.
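
As an illustration, here is a minimal training sketch built from the names the package exports (`JointLoader`, `NRE`, `NRELoss`, `BoxUniform`); the constructor signatures, the toy prior and the simulator below are assumptions for the sake of the example, not documented API.

```
import torch
import lampe

# Hypothetical problem: 3 parameters, 5-dimensional observations.
prior = lampe.BoxUniform(-torch.ones(3), torch.ones(3))  # assumed signature

def simulator(theta):
    # Toy stochastic simulator, standing in for a real one.
    return torch.cat((theta, theta[..., :2] ** 2), dim=-1) + 1e-2 * torch.randn(5)

loader = lampe.JointLoader(prior, simulator, batch_size=256)

estimator = lampe.NRE(3, 5)  # assumed signature: (theta dim, x dim)
loss = lampe.NRELoss(estimator)  # assumed to wrap the estimator
optimizer = torch.optim.Adam(estimator.parameters(), lr=1e-3)

for step, (theta, x) in enumerate(loader):
    optimizer.zero_grad()
    loss(theta, x).backward()
    optimizer.step()

    if step == 1024:  # the loader is infinite, stop somewhere
        break
```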

## Installation

The `lampe` package is available on [PyPI](https://pypi.org/project/lampe), which means it is installable via `pip`.

```
pip install lampe
```

Alternatively, if you need the latest features, you can install it from the repository.

```
pip install git+https://github.com/francois-rozet/lampe
```

## Documentation

The documentation is made with [Sphinx](https://www.sphinx-doc.org) and [Furo](https://github.com/pradyunsg/furo) and is hosted at [francois-rozet.github.io/lampe](https://francois-rozet.github.io/lampe).
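
For completeness, a hedged sketch of building the documentation locally. The `sphinx/` source directory is inferred from the banner path above; the `furo` dependency follows from the theme, and the output directory is an arbitrary choice rather than a documented command.

```
pip install sphinx furo
sphinx-build -b html sphinx build/html
```
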
11 changes: 6 additions & 5 deletions lampe/__init__.py
@@ -1,6 +1,7 @@
r"""Likelihood-free AMortized Posterior Estimation"""
r"""Likelihood-free AMortized Posterior Estimation (LAMPE)"""

from .data import SimulatorLoader, H5Loader, h5save
from .mcmc import MetropolisHastings, InferenceSampler
from .nn import NRE, NPE, NREPipe, NPEPipe
from .priors import JointNormal, JointUniform
from . import patch
from .data import JointLoader, H5Dataset
from .nn import NRE, NPE
from .nn.losses import NRELoss, NPELoss
from .priors import BoxUniform, DiagNormal
254 changes: 159 additions & 95 deletions lampe/data.py
@@ -1,21 +1,25 @@
r"""Datasets and data loaders"""
r"""Datasets and data loaders."""

import h5py
import numpy as np
import random
import torch
import torch.utils.data as data

from numpy.typing import ArrayLike
from bisect import bisect
from numpy import ndarray as Array
from pathlib import Path
from torch import Tensor
from torch.distributions import Distribution
from torch.utils.data import DataLoader, Dataset, IterableDataset
from tqdm import tqdm
from typing import *


class IterableSimulatorDataset(data.IterableDataset):
r"""Iterable dataset of (theta, x) batches"""
__all__ = ['JointLoader', 'H5Dataset']


class IterableJointDataset(IterableDataset):
r"""Creates an iterable dataset of batched pairs :math:`(\theta, x)`."""

def __init__(
self,
@@ -44,47 +48,91 @@ def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
yield theta, x


class SimulatorLoader(data.DataLoader):
r"""Iterable data loader of (theta, x) batches"""

def __init__(
self,
prior: Distribution,
simulator: Callable,
batch_size: int = 2**10, # 1024
batched: bool = False,
numpy: bool = False,
rng: torch.Generator = None,
**kwargs,
):
dataset = IterableSimulatorDataset(
def JointLoader(
prior: Distribution,
simulator: Callable,
batch_size: int = 2**10, # 1024
batched: bool = False,
numpy: bool = False,
**kwargs,
) -> DataLoader:
r"""Creates a data loader of batched pairs :math:`(\theta, x)` generated by
a prior distribution :math:`p(\theta)` and a simulator.
The simlator is a stochastic function taking a set of parameters :math:`\theta`
as input and returning an observation :math:`x` as output, which (implicitely)
defines a likelihood function :math:`p(x | \theta)`. Together with the prior,
they form a joint distribution :math:`p(\theta, x) = p(\theta) p(x | \theta)`
from which the pairs :math:`(\theta, x)` are independently drawn.
Arguments:
prior: A prior distribution :math:`p(\theta)`.
simulator: A callable simulator.
batch_size: The batch size of the generated pairs.
batched: Whether the simulator accepts batched inputs or not.
numpy: Whether the simulator requires NumPy or PyTorch inputs.
kwargs: Keyword arguments passed to :class:`torch.utils.data.DataLoader`.
Returns:
An infinite data loader of batched pairs :math:`(\theta, x)`.
Example:
>>> loader = joint_loader(prior, simulator, numpy=True, num_workers=4)
>>> for theta, x in loader:
... theta, x = theta.cuda(), x.cuda()
... something(theta, x)
"""

return DataLoader(
IterableJointDataset(
prior,
simulator,
batch_shape=(batch_size,) if batched else (),
numpy=numpy,
)

super().__init__(
dataset,
batch_size=None if batched else batch_size,
worker_init_fn=self.worker_init,
generator=rng,
**kwargs,
)

@staticmethod
def worker_init(*args) -> None:
seed = torch.initial_seed() % 2**32
np.random.seed(seed)
random.seed(seed)


class H5Loader(data.Dataset):
r"""Data loader of (theta, x) pairs saved in HDF5 files"""
),
batch_size=None if batched else batch_size,
**kwargs,
)
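
To make the prior/simulator contract concrete, here is a hypothetical toy setup; the `BoxUniform` constructor signature is assumed from its name, and any prior distribution with batched sampling would do.

```
import torch

from lampe import BoxUniform, JointLoader

prior = BoxUniform(-torch.ones(2), torch.ones(2))  # assumed signature

def simulator(theta):
    # Stochastic mapping theta -> x, implicitly defining p(x | theta).
    return torch.stack((theta.sum(), theta.prod())) + 1e-2 * torch.randn(2)

loader = JointLoader(prior, simulator, batch_size=64)
theta, x = next(iter(loader))  # theta and x both have shape (64, 2)
```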


class H5Dataset(object):
r"""Creates a dataset of pairs :math:`(\theta, x)` from HDF5 files.
As a :class:`torch.utils.data.Dataset`, :class:`H5Dataset` implements the methods
:meth:`__len__` and :meth:`__getitem__`. However, as it can be slow to load pairs
from disk one by one when iterating over the dataset, it also implements a custom
:meth:`__iter__` method. This method loads several contiguous chunks of pairs at
once, concatenates them, shuffles the result and, finally, splits it into batches.
This "weak shuffling" procedure greatly improves loading performances, but the
resulting batch elements are not perfectly independent from each others.
Important:
To take advantage of the custom :meth:`__iter__` method, :class:`H5Dataset`
instances should not be wrapped inside a :class:`torch.utils.data.DataLoader`
when iterating over the dataset.
Arguments:
files: HDF5 files containing pairs :math:`(\theta, x)`.
batch_size: The size of the batches.
chunk_size: The size of the contiguous chunks.
group_size: The number of chunks loaded at once.
pin_memory: Whether the batches reside in CUDA pinned memory or not.
shuffle: Whether the pairs are shuffled when iterating.
seed: A seed to initialize the internal RNG used for shuffling.
Example:
>>> dataset = H5Dataset('data.h5', batch_size=256, shuffle=True)
>>> theta, x = dataset[0]
>>> theta
tensor([-0.1215, -1.3641, 0.7233, -1.2150, -1.9263])
>>> for theta, x in dataset:
... theta, x = theta.cuda(), x.cuda()
... something(theta, x)
"""

def __init__(
self,
*filenames,
*files: Union[str, Path],
batch_size: int = 2**10, # 1024
chunk_size: int = 2**12, # 4096
group_size: int = 2**4, # 16
@@ -94,7 +142,8 @@ def __init__(
):
super().__init__()

self.fs = list(map(h5py.File, filenames))
self.files = list(map(h5py.File, files))
self.stops = np.cumsum([len(f['x']) for f in self.files])

self.batch_size = batch_size
self.chunk_size = chunk_size
@@ -106,33 +155,33 @@ def __init__(
self.rng = np.random.default_rng(seed)

def __del__(self) -> None:
for f in self.fs:
for f in self.files:
f.close()

def __len__(self) -> int:
return sum(len(f['x']) for f in self.fs)
return self.stops[-1]

def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
idx = idx % len(self)
def __getitem__(self, i: int) -> Tuple[Tensor, Tensor]:
i = i % len(self)
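# Locate, via the cumulative sizes, the file that holds the i-th pair.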
j = bisect(self.stops, i)
if j > 0:
i = i - self.stops[j - 1]

for f in self.fs:
if idx < len(f['x']):
break
idx = idx - len(f['x'])
f = self.files[j]

if 'theta' in f:
theta = torch.from_numpy(f['theta'][idx])
theta = torch.from_numpy(f['theta'][i])
else:
theta = None

x = torch.from_numpy(f['x'][idx])
x = torch.from_numpy(f['x'][i])

return theta, x

def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
chunks = [
(i, j, j + self.chunk_size)
for i, f in enumerate(self.fs)
for i, f in enumerate(self.files)
for j in range(0, len(f['x']), self.chunk_size)
]

@@ -143,8 +192,8 @@ def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
slices = sorted(chunks[l:l+self.group_size])

# Load
theta = np.concatenate([self.fs[i]['theta'][j:k] for i, j, k in slices])
x = np.concatenate([self.fs[i]['x'][j:k] for i, j, k in slices])
theta = np.concatenate([self.files[i]['theta'][j:k] for i, j, k in slices])
x = np.concatenate([self.files[i]['x'][j:k] for i, j, k in slices])

# Shuffle
if self.shuffle:
@@ -163,45 +212,60 @@ def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
x.split(self.batch_size),
)

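
In isolation, the "weak shuffling" procedure described in the docstring amounts to the following simplified sketch over a generic array (chunk, group and batch sizes are shrunk for readability; the real implementation additionally sorts each group's slices so that disk reads stay sequential).

```
import numpy as np

rng = np.random.default_rng(0)
data = np.arange(32)  # stand-in for pairs stored contiguously on disk

# Split the index range into contiguous chunks and shuffle their order.
chunks = [data[i:i + 8] for i in range(0, 32, 8)]
rng.shuffle(chunks)

# Load groups of 2 chunks at once, shuffle within each group, yield batches of 4.
for k in range(0, len(chunks), 2):
    group = np.concatenate(chunks[k:k + 2])
    rng.shuffle(group)
    for batch in np.split(group, len(group) // 4):
        print(batch)
```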

def h5save(
iterable: Iterable[Tuple[ArrayLike, ArrayLike]],
filename: str,
size: int,
dtype: type = np.float32,
**kwargs,
) -> None:
r"""Saves (theta, x) batches to an HDF5 file"""

# File
filename = Path(filename)
filename.parent.mkdir(parents=True, exist_ok=True)

with h5py.File(filename, 'w') as f:
## Attributes
for k, v in kwargs.items():
f.attrs[k] = v

## Datasets
theta, x = map(np.asarray, next(iter(iterable)))
theta, x = theta[0], x[0]

f.create_dataset('theta', (size,) + theta.shape, dtype=dtype)
f.create_dataset('x', (size,) + x.shape, dtype=dtype)

## Samples
with tqdm(total=size, unit='sample') as tq:
i = 0

for theta, x in iterable:
j = min(i + theta.shape[0], size)

f['theta'][i:j] = np.asarray(theta)[:j-i]
f['x'][i:j] = np.asarray(x)[:j-i]

tq.update(j - i)

if j < size:
i = j
else:
break
@staticmethod
def store(
pairs: Iterable[Tuple[Array, Array]],
file: Union[str, Path],
size: int,
dtype: np.dtype = np.float32,
**meta,
) -> None:
r"""Creates an HDF5 file containing pairs :math:`(\theta, x)`.
The sets of parameters :math:`\theta` are stored in a collection named
:py:`'theta'` and the observations in a collection named :py:`'x'`.
Arguments:
pairs: An iterable over batched pairs :math:`(\theta, x)`.
file: An HDF5 filename to store pairs in.
size: The number of pairs to store.
dtype: The data type to store pairs in.
meta: Metadata to store in the file.
Example:
>>> loader = JointLoader(prior, simulator, batch_size=16)
>>> H5Dataset.store(loader, 'sim.h5', 4096)
100%|██████████| 4096/4096 [01:35<00:00, 42.69sample/s]
"""

# File
file = Path(file)
file.parent.mkdir(parents=True, exist_ok=True)

with h5py.File(file, 'w-') as f:
## Attributes
f.attrs.update(meta)

## Datasets
theta, x = map(np.asarray, next(iter(pairs)))
theta, x = theta[0], x[0]

f.create_dataset('theta', (size,) + theta.shape, dtype=dtype)
f.create_dataset('x', (size,) + x.shape, dtype=dtype)

## Store
with tqdm(total=size, unit='sample') as tq:
i = 0

for theta, x in pairs:
j = min(i + theta.shape[0], size)

f['theta'][i:j] = np.asarray(theta)[:j-i]
f['x'][i:j] = np.asarray(x)[:j-i]

tq.update(j - i)

if j < size:
i = j
else:
break
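
Putting the new pieces together, a hedged end-to-end sketch of the data pipeline; `prior` and `simulator` are placeholders, and everything else follows the docstrings above.

```
from lampe import JointLoader, H5Dataset

# Simulate pairs on the fly and store 4096 of them on disk.
loader = JointLoader(prior, simulator, batch_size=16)
H5Dataset.store(loader, 'sim.h5', size=4096)

# Iterate with weak shuffling; no DataLoader wrapper, as the docstring advises.
dataset = H5Dataset('sim.h5', batch_size=256, shuffle=True)

for theta, x in dataset:
    pass  # e.g., run a training step on the batch
```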