diff --git a/docs/tutorials/simulators.ipynb b/docs/tutorials/simulators.ipynb index 5fcf52b..e600a2d 100644 --- a/docs/tutorials/simulators.ipynb +++ b/docs/tutorials/simulators.ipynb @@ -248,17 +248,72 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Saving on disk\n", - "\n", - "If the simulator is fast or inexpensive, it is reasonable to generate pairs $(\\theta, x)$ on demand. Otherwise, the pairs have to be generated and stored on disk ahead of time. The [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) file format is commonly used for this purpose, as it was specifically designed to hold large amounts of numerical data.\n", + "## Loading in memory\n", "\n", - "The [`lampe.data`](lampe.data) module provides the [`H5Dataset`](lampe.data.H5Dataset) class to help load and store pairs $(\\theta, x)$ in HDF5 files. The [`H5Dataset.store`](lampe.data.H5Dataset.store) function takes an iterable of batched pairs $(\\theta, x)$ as input and stores them into a new HDF5 file. The iterable can be a precomputed list, a custom generator or even a `JointLoader` instance." + "If the simulator is fast or inexpensive, it is reasonable to generate pairs $(\\theta, x)$ on demand. Otherwise, the pairs have to be generated ahead of time. The [`lampe.data`](lampe.data) module provides the [`JointDataset`](lampe.data.JointDataset) class to interact with in-memory pairs $(\\theta, x)$. This is ideal when your data fits in RAM." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([ 0.6757, -0.3787, 0.9581])\n", + "tensor([0.3410, 0.7707])\n" + ] + } + ], + "source": [ + "theta = prior.sample((1024,))\n", + "x = simulator(theta)\n", + "\n", + "dataset = lampe.data.JointDataset(theta, x)\n", + "\n", + "print(*dataset[42], sep='\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1024/1024 [00:00<00:00, 282991.85it/s]\n" + ] + } + ], + "source": [ + "for theta, x in tqdm(dataset):\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`JointDataset` can be wrapped in a [`DataLoader`](torch.utils.data.DataLoader) to enable batching and shuffling." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving on disk\n", + "\n", + "If your data does not fit in RAM or you need to reuse it later, you may want to store it on disk. The [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) file format is commonly used for this purpose, as it was specifically designed to hold large amounts of numerical data. The [`lampe.data`](lampe.data) module provides the [`H5Dataset`](lampe.data.H5Dataset) class to help load and store pairs $(\\theta, x)$ in HDF5 files. The [`H5Dataset.store`](lampe.data.H5Dataset.store) function takes an iterable of batched pairs $(\\theta, x)$ as input and stores them into a new HDF5 file. The iterable can be a precomputed list, a custom generator or even a `JointLoader` instance." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, "outputs": [ { "name": "stderr", @@ -269,20 +324,20 @@ } ], "source": [ - "data = []\n", + "pairs = []\n", "\n", "for _ in range(256):\n", " theta = prior.sample((256,))\n", " x = simulator(theta)\n", "\n", - " data.append((theta, x))\n", + " pairs.append((theta, x))\n", "\n", - "lampe.data.H5Dataset.store(data, 'data_0.h5', size=2**16)" + "lampe.data.H5Dataset.store(pairs, 'data_0.h5', size=2**16)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -306,7 +361,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -332,7 +387,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -359,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -375,6 +430,22 @@ " pass" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternatively, if your data fits in memory, you can load it at once with the `to_memory` method, which returns a `JointDataset`." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = lampe.data.H5Dataset('data_0.h5').to_memory()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -386,7 +457,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -395,7 +466,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -421,7 +492,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -456,7 +527,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -483,7 +554,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -502,7 +573,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -537,7 +608,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.15" + "version": "3.9.16" }, "vscode": { "interpreter": { diff --git a/lampe/data.py b/lampe/data.py index 0ec0705..7bbe322 100644 --- a/lampe/data.py +++ b/lampe/data.py @@ -1,6 +1,6 @@ r"""Datasets and data loaders.""" -__all__ = ['JointLoader', 'H5Dataset'] +__all__ = ['JointLoader', 'JointDataset', 'H5Dataset'] import h5py import numpy as np @@ -8,7 +8,6 @@ import torch from bisect import bisect -from contextlib import ExitStack from numpy import ndarray as Array from pathlib import Path from torch import Tensor, Size @@ -95,6 +94,69 @@ def __init__( ) +class JointDataset(Dataset): + r"""Creates an in-memory dataset of pairs :math:`(\theta, x)`. + + :class:`JointDataset` supports indexing and slicing, but also implements a custom + :meth:`__iter__` method which supports batching and shuffling. + + Arguments: + theta: A tensor of parameters :math:`\theta`. + x: A tensor of observations :math:`x`. + batch_size: The size of the batches. + shuffle: Whether the pairs are shuffled or not when iterating. + + Example: + >>> dataset = JointDataset(theta, x, batch_size=256, shuffle=True) + >>> theta, x = dataset[42:69] + >>> theta.shape + torch.Size([27, 5]) + >>> for theta, x in dataset: + ... 
theta, x = theta.cuda(), x.cuda() + ... something(theta, x) + """ + + def __init__( + self, + theta: Tensor, + x: Tensor, + batch_size: int = None, + shuffle: bool = False, + ): + super().__init__() + + assert len(theta) == len(x) + + self.theta = torch.as_tensor(theta) + self.x = torch.as_tensor(x) + + self.batch_size = batch_size + self.shuffle = shuffle + + def __len__(self) -> int: + return len(self.theta) + + def __getitem__(self, i: Union[int, slice]) -> Tuple[Tensor, Tensor]: + return self.theta[i], self.x[i] + + def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]: + if self.shuffle: + order = torch.randperm(len(self)) + + if self.batch_size is None: + return (self[i] for i in order) + else: + return (self[i] for i in order.split(self.batch_size)) + else: + if self.batch_size is None: + return zip(self.theta, self.x) + else: + return zip( + self.theta.split(self.batch_size), + self.x.split(self.batch_size), + ) + + class H5Dataset(IterableDataset): r"""Creates an iterable dataset of pairs :math:`(\theta, x)` from HDF5 files. @@ -140,11 +202,9 @@ def __init__( ): super().__init__() - self.files = files - - with ExitStack() as stack: - files = map(stack.enter_context, map(h5py.File, self.files)) - self.cumsizes = np.cumsum([len(f['x']) for f in files]) + self.files = [h5py.File(f, mode='r') for f in files] + self.sizes = [f['theta'].shape[0] for f in self.files] + self.cumsizes = np.cumsum(self.sizes) self.batch_size = batch_size self.chunk_size = chunk_size @@ -160,47 +220,61 @@ def __getitem__(self, i: int) -> Tuple[Tensor, Tensor]: if j > 0: i = i - self.cumsizes[j - 1] - with h5py.File(self.files[j]) as f: - theta, x = f['theta'][i], f['x'][i] + f = self.files[j] + theta, x = f['theta'][i], f['x'][i] return torch.from_numpy(theta), torch.from_numpy(x) def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]: - with ExitStack() as stack: - files = list(map(stack.enter_context, map(h5py.File, self.files))) + chunks = torch.tensor([ + (i, j, j + self.chunk_size) + for i, size in enumerate(self.sizes) + for j in range(0, size, self.chunk_size) + ]) + + if self.shuffle: + order = torch.randperm(len(chunks)) + chunks = chunks[order] - chunks = torch.tensor([ - (i, j, j + self.chunk_size) - for i, f in enumerate(files) - for j in range(0, len(f['x']), self.chunk_size) - ]) + for slices in chunks.split(self.chunk_step): + slices = sorted(slices.tolist()) + # Load + theta = np.concatenate([self.files[i]['theta'][j:k] for i, j, k in slices]) + x = np.concatenate([self.files[i]['x'][j:k] for i, j, k in slices]) + + theta, x = torch.from_numpy(theta), torch.from_numpy(x) + + # Shuffle if self.shuffle: - order = torch.randperm(len(chunks)) - chunks = chunks[order] - - for slices in chunks.split(self.chunk_step): - slices = sorted(slices.tolist()) - - # Load - theta = np.concatenate([files[i]['theta'][j:k] for i, j, k in slices]) - x = np.concatenate([files[i]['x'][j:k] for i, j, k in slices]) - - theta, x = torch.from_numpy(theta), torch.from_numpy(x) - - # Shuffle - if self.shuffle: - order = torch.randperm(len(x)) - theta, x = theta[order], x[order] - - # Batch - if self.batch_size is None: - yield from zip(theta, x) - else: - yield from zip( - theta.split(self.batch_size), - x.split(self.batch_size), - ) + order = torch.randperm(len(theta)) + theta, x = theta[order], x[order] + + # Batch + if self.batch_size is None: + yield from zip(theta, x) + else: + yield from zip( + theta.split(self.batch_size), + x.split(self.batch_size), + ) + + def to_memory(self) -> JointDataset: + r"""Loads all 
pairs in memory and returns them as a :class:`JointDataset`. + + Example: + >>> dataset = H5Dataset('data.h5').to_memory() + """ + + theta = np.concatenate([f['theta'][:] for f in self.files]) + x = np.concatenate([f['x'][:] for f in self.files]) + + return JointDataset( + theta, + x, + batch_size=self.batch_size, + shuffle=self.shuffle, + ) @staticmethod def store( diff --git a/tests/test_data.py b/tests/test_data.py index f511fda..abc08ec 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,7 +1,6 @@ r"""Tests for the lampe.data module.""" import h5py -import math import numpy as np import pytest import torch @@ -71,6 +70,47 @@ def simulator(theta): assert theta1.dtype == x1.dtype == torch.float +def test_JointDataset(): + theta, x = torch.randn(1024, 5), torch.randn(1024, 16, 2) + + # Index + dataset = JointDataset(theta, x) + + for i in (0, 1, 2, 8, 32, 128, -1): + theta_i, x_i = dataset[i] + + assert (theta_i == theta[i]).all() + assert (x_i == x[i]).all() + + # Iter + for i, (theta_i, x_i) in enumerate(dataset): + assert (theta_i == theta[i]).all() + assert (x_i == x[i]).all() + + assert i == len(dataset) - 1 == len(theta) - 1 + + # Shuffle + dataset = JointDataset(theta, x, batch_size=256, shuffle=True) + + it = iter(dataset) + theta1, x1 = next(it) + theta2, x2 = next(it) + + assert len(theta1) == len(theta2) == 256 + assert len(x1) == len(x2) == 256 + + theta3, x3 = torch.cat((theta1, theta2)), torch.cat((x1, x2)) + + match = (theta3[:, None, :] == theta).all(dim=-1) + + assert (match.sum(dim=0) <= 1).all() + assert (match.sum(dim=-1) == 1).all() + + _, index = torch.nonzero(match, as_tuple=True) + + assert (x3 == x[index]).all() + + def test_H5Dataset(tmp_path): prior = torch.distributions.Normal(torch.zeros(3), torch.ones(3)) simulator = lambda theta: torch.repeat_interleave(theta, 2, dim=-1) @@ -83,7 +123,7 @@ def test_H5Dataset(tmp_path): H5Dataset.store(pairs, tmp_path / 'data_1.h5', size=4096) H5Dataset.store(iter(pairs), tmp_path / 'data_2.h5', size=4096) - H5Dataset.store(pairs, tmp_path / 'data_3.h5', size=256) + H5Dataset.store([(theta, x)], tmp_path / 'data_3.h5', size=256) with pytest.raises(FileExistsError): H5Dataset.store(pairs, tmp_path / 'data_1.h5', size=4096) @@ -97,29 +137,28 @@ def test_H5Dataset(tmp_path): assert len(dataset) in {256, 4096} ## __getitem__ - for i in map(lambda x: 2**x, range(int(math.log2(len(dataset))))): + for i in (0, 1, 2, 8, 32, 128): theta_i, x_i = dataset[i] - assert theta_i.shape == (3,) and x_i.shape == (6,) assert (theta_i == theta[i]).all() assert (x_i == x[i]).all() - ## __item__ + ## __iter__ for i, (theta_i, x_i) in enumerate(dataset): assert (theta_i == theta[i]).all() assert (x_i == x[i]).all() assert i == len(dataset) - 1 - ## Shuffle + # Shuffle dataset = H5Dataset(tmp_path / 'data_1.h5', batch_size=256, shuffle=True) it = iter(dataset) theta1, x1 = next(it) theta2, x2 = next(it) - assert theta1.shape == (256, 3) and x1.shape == (256, 6) - assert theta2.shape == (256, 3) and x2.shape == (256, 6) + assert len(theta1) == len(theta2) == 256 + assert len(x1) == len(x2) == 256 theta3, x3 = torch.cat((theta1, theta2)), torch.cat((x1, x2)) @@ -131,3 +170,9 @@ def test_H5Dataset(tmp_path): _, index = torch.nonzero(match, as_tuple=True) assert (x3 == x[index]).all() + + # Load in memory + new = dataset.to_memory() + + assert len(dataset) == len(new) + assert isinstance(new, JointDataset)
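Taken together, the pieces this diff introduces compose into a single workflow: build a `JointDataset` in memory, persist pairs with `H5Dataset.store`, stream them back from disk with `H5Dataset`, and optionally pull everything into RAM again with `to_memory`. The sketch below is illustrative rather than part of the patch: the toy `prior` and `simulator` mirror those in `tests/test_data.py`, and `data.h5` is a placeholder path (note that `H5Dataset.store` raises `FileExistsError` if the file already exists).

import torch

from lampe.data import JointDataset, H5Dataset

# Toy prior and simulator, mirroring those in tests/test_data.py.
prior = torch.distributions.Normal(torch.zeros(3), torch.ones(3))
simulator = lambda theta: torch.repeat_interleave(theta, 2, dim=-1)

theta = prior.sample((1024,))  # shape (1024, 3)
x = simulator(theta)           # shape (1024, 6)

# In-memory dataset with built-in batching and shuffling.
dataset = JointDataset(theta, x, batch_size=256, shuffle=True)

for theta_b, x_b in dataset:  # four shuffled batches of 256 pairs
    pass

# Persist the pairs on disk, then stream them back in shuffled batches.
H5Dataset.store([(theta, x)], 'data.h5', size=1024)
h5 = H5Dataset('data.h5', batch_size=256, shuffle=True)

# Or load the whole file back into RAM as a JointDataset.
in_memory = h5.to_memory()
assert len(in_memory) == 1024

Because `to_memory` forwards `batch_size` and `shuffle` to the returned `JointDataset`, iteration behaves the same whether the pairs are streamed from disk or held in RAM.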