From 9e8a5738186f167f396da5787865c205cb5ce4c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Mon, 22 Aug 2022 11:52:10 +0200 Subject: [PATCH 1/3] =?UTF-8?q?=E2=9C=A8=20Add=20option=20to=20overwrite?= =?UTF-8?q?=20HDF5=20files?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lampe/data.py | 5 ++++- tests/test_data.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/lampe/data.py b/lampe/data.py index 9d48cca..ba7df06 100644 --- a/lampe/data.py +++ b/lampe/data.py @@ -208,6 +208,7 @@ def store( pairs: Iterable[Tuple[Array, Array]], file: Union[str, Path], size: int, + overwrite: bool = False, dtype: np.dtype = np.float32, **meta, ) -> None: @@ -220,6 +221,8 @@ def store( pairs: An iterable over batched pairs :math:`(\theta, x)`. file: An HDF5 filename to store pairs in. size: The number of pairs to store. + overwrite: Whether to overwrite existing files or not. If :py:`False` + and the file already exists, the function raises an error. dtype: The data type to store pairs in. meta: Metadata to store in the file. @@ -236,7 +239,7 @@ def store( file = Path(file) file.parent.mkdir(parents=True, exist_ok=True) - with h5py.File(file, 'w-') as f: + with h5py.File(file, 'w' if overwrite else 'w-') as f: ## Attributes f.attrs.update(meta) diff --git a/tests/test_data.py b/tests/test_data.py index c5fb568..f511fda 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -88,6 +88,8 @@ def test_H5Dataset(tmp_path): with pytest.raises(FileExistsError): H5Dataset.store(pairs, tmp_path / 'data_1.h5', size=4096) + H5Dataset.store(pairs, tmp_path / 'data_1.h5', overwrite=True, size=4096) + # Load for file in tmp_path.glob('data_*.h5'): dataset = H5Dataset(file) From a73e24e091d502c36191e7320225b8796a49ef23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Mon, 22 Aug 2022 15:21:01 +0200 Subject: [PATCH 2/3] =?UTF-8?q?=F0=9F=94=96=20Bump=20version=20to=200.5.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2d7bd22..7870315 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setuptools.setup( name='lampe', - version='0.5.0', + version='0.5.1', packages=setuptools.find_packages(), description='Likelihood-free AMortized Posterior Estimation with PyTorch', keywords='parameter inference bayes posterior amortized likelihood ratio mcmc torch', From ade4890abf3d87ccb34d8b4d255fab5b4090271c Mon Sep 17 00:00:00 2001 From: Arnaud Delaunoy Date: Mon, 22 Aug 2022 17:39:41 +0200 Subject: [PATCH 3/3] Add BNRELoss --- lampe/inference.py | 54 +++++++++++++++++++++++++++++++++++++++++ tests/test_inference.py | 12 +++++++++ 2 files changed, 66 insertions(+) diff --git a/lampe/inference.py b/lampe/inference.py index d567c98..bd98d62 100644 --- a/lampe/inference.py +++ b/lampe/inference.py @@ -139,6 +139,60 @@ def forward(self, theta: Tensor, x: Tensor) -> Tensor: return l1 + l0 +class BNRELoss(nn.Module): + r"""Creates a module that calculates the loss :math:`l` of a balanced NRE (BNRE) + classifier :math:`d_\phi`. Given a batch of :math:`N` pairs + :math:`\{ (\theta_i, x_i) \}`, the module returns + + .. math:: + \begin{align} + l & = \frac{1}{N} \sum_{i = 1}^N + \ell(d_\phi(\theta_i, x_i)) + \ell(1 - d_\phi(\theta_{i+1}, x_i)) \\ + & + \gamma \left(1 - \frac{1}{N} \sum_{i = 1}^N + d_\phi(\theta_i, x_i) + d_\phi(\theta_{i+1}, x_i) + \right)^2 + \end{align} + + where :math:`\ell(p) = - \log p` is the negative log-likelihood. + + References: + Towards Reliable Simulation-Based Inference with Balanced Neural Ratio Estimation + (Delaunoy et al., 2022) + + Arguments: + estimator: A classifier network :math:`d_\phi(\theta, x)`. + """ + + def __init__(self, estimator: nn.Module, gamma: float = 42.0): + super().__init__() + + self.estimator = estimator + self.gamma = gamma + + def forward(self, theta: Tensor, x: Tensor) -> Tensor: + r""" + Arguments: + theta: The parameters :math:`\theta`, with shape :math:`(N, D)`. + x: The observation :math:`x`, with shape :math:`(N, L)`. + + Returns: + The scalar loss :math:`l`. + """ + + theta_prime = torch.roll(theta, 1, dims=0) + + log_r, log_r_prime = self.estimator( + torch.stack((theta, theta_prime)), + x, + ) + + l1 = -F.logsigmoid(log_r).mean() + l0 = -F.logsigmoid(-log_r_prime).mean() + lb = (1 - torch.sigmoid(log_r) + torch.sigmoid(log_r_prime)).mean().square() + + return l1 + l0 + self.gamma * lb + + class AMNRE(NRE): r"""Creates an arbitrary marginal neural ratio estimation (AMNRE) classifier network. diff --git a/tests/test_inference.py b/tests/test_inference.py index c6034f0..e9fd1e5 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -43,6 +43,18 @@ def test_NRELoss(): assert l.requires_grad +def test_BNRELoss(): + estimator = NRE(3, 5) + loss = BNRELoss(estimator) + + theta, x = randn(256, 3), randn(256, 5) + + l = loss(theta, x) + + assert l.shape == () + assert l.requires_grad + + def test_AMNRE(): estimator = AMNRE(3, 5)