From 4650263390b9a2ddaf6b254e54e665061f5b44d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?=
Date: Tue, 30 Aug 2022 13:37:10 +0200
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Implement=20balanced=20NRE=20(BNRE)?=
 =?UTF-8?q?=20loss=20(#3)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lampe/inference.py      | 57 +++++++++++++++++++++++++++++++++++++++++
 tests/test_inference.py | 12 +++++++++
 2 files changed, 69 insertions(+)

diff --git a/lampe/inference.py b/lampe/inference.py
index d567c98..96b3bae 100644
--- a/lampe/inference.py
+++ b/lampe/inference.py
@@ -139,6 +139,63 @@ def forward(self, theta: Tensor, x: Tensor) -> Tensor:
         return l1 + l0
 
 
+class BNRELoss(nn.Module):
+    r"""Creates a module that calculates the loss :math:`l` of a balanced NRE
+    (BNRE) classifier :math:`d_\phi`. Given a batch of :math:`N` pairs
+    :math:`\{ (\theta_i, x_i) \}`, the module returns
+
+    .. math::
+        \begin{align}
+            l & = \frac{1}{N} \sum_{i = 1}^N
+            \ell(d_\phi(\theta_i, x_i)) + \ell(1 - d_\phi(\theta_{i+1}, x_i)) \\
+            & + \lambda \left(1 - \frac{1}{N} \sum_{i = 1}^N
+            d_\phi(\theta_i, x_i) + d_\phi(\theta_{i+1}, x_i)
+            \right)^2
+        \end{align}
+
+    where :math:`\ell(p) = - \log p` is the negative log-likelihood.
+
+    References:
+        Towards Reliable Simulation-Based Inference with Balanced Neural Ratio
+        Estimation (Delaunoy et al., 2022)
+        https://arxiv.org/abs/2208.13624
+
+    Arguments:
+        estimator: A classifier network :math:`d_\phi(\theta, x)`.
+        lmbda: The weight :math:`\lambda` controlling the strength of the
+            balancing condition.
+    """
+
+    def __init__(self, estimator: nn.Module, lmbda: float = 100.0):
+        super().__init__()
+
+        self.estimator = estimator
+        self.lmbda = lmbda
+
+    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
+        r"""
+        Arguments:
+            theta: The parameters :math:`\theta`, with shape :math:`(N, D)`.
+            x: The observation :math:`x`, with shape :math:`(N, L)`.
+
+        Returns:
+            The scalar loss :math:`l`.
+        """
+
+        theta_prime = torch.roll(theta, 1, dims=0)
+
+        log_r, log_r_prime = self.estimator(
+            torch.stack((theta, theta_prime)),
+            x,
+        )
+
+        l1 = -F.logsigmoid(log_r).mean()
+        l0 = -F.logsigmoid(-log_r_prime).mean()
+        lb = (torch.sigmoid(log_r) + torch.sigmoid(log_r_prime) - 1).mean().square()
+
+        return l1 + l0 + self.lmbda * lb
+
+
 class AMNRE(NRE):
     r"""Creates an arbitrary marginal neural ratio estimation (AMNRE)
     classifier network.
diff --git a/tests/test_inference.py b/tests/test_inference.py
index c6034f0..e9fd1e5 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -43,6 +43,18 @@ def test_NRELoss():
     assert l.requires_grad
 
 
+def test_BNRELoss():
+    estimator = NRE(3, 5)
+    loss = BNRELoss(estimator)
+
+    theta, x = randn(256, 3), randn(256, 5)
+
+    l = loss(theta, x)
+
+    assert l.shape == ()
+    assert l.requires_grad
+
+
 def test_AMNRE():
     estimator = AMNRE(3, 5)
 
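
Usage note (reviewer aid, not part of the patch): the sketch below shows how
BNRELoss could plug into a training loop. The NRE(3, 5) dimensions mirror
test_BNRELoss above; the Adam optimizer, learning rate, step count, and the
random stand-in data are illustrative assumptions, not anything this patch
prescribes. In practice, (theta, x) pairs would come from a joint simulator.

    import torch

    from lampe.inference import NRE, BNRELoss

    estimator = NRE(3, 5)       # same toy dimensions as test_BNRELoss
    loss = BNRELoss(estimator)  # lmbda defaults to 100.0
    optimizer = torch.optim.Adam(estimator.parameters(), lr=1e-3)  # assumed optimizer/lr

    for _ in range(1024):  # assumed number of steps
        # Random tensors stand in for simulator samples to keep the sketch
        # self-contained; real training would draw (theta, x) jointly.
        theta, x = torch.randn(256, 3), torch.randn(256, 5)

        optimizer.zero_grad()
        l = loss(theta, x)  # scalar l1 + l0 + lmbda * lb
        l.backward()
        optimizer.step()

The balancing term lb drives the batch average of d(theta, x) + d(theta', x)
towards 1, which Delaunoy et al. (2022) argue yields more conservative
posterior approximations than plain NRE.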