Skip to content

Commit

Permalink
✨ Implement balanced NRE (BNRE) loss (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Aug 30, 2022
1 parent 7bf8430 commit 4650263
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
57 changes: 57 additions & 0 deletions lampe/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,63 @@ def forward(self, theta: Tensor, x: Tensor) -> Tensor:
return l1 + l0


class BNRELoss(nn.Module):
r"""Creates a module that calculates the loss :math:`l` of a balanced NRE (BNRE)
classifier :math:`d_\phi`. Given a batch of :math:`N` pairs
:math:`\{ (\theta_i, x_i) \}`, the module returns
.. math::
\begin{align}
l & = \frac{1}{N} \sum_{i = 1}^N
\ell(d_\phi(\theta_i, x_i)) + \ell(1 - d_\phi(\theta_{i+1}, x_i)) \\
& + \lambda \left(1 - \frac{1}{N} \sum_{i = 1}^N
d_\phi(\theta_i, x_i) + d_\phi(\theta_{i+1}, x_i)
\right)^2
\end{align}
where :math:`\ell(p) = - \log p` is the negative log-likelihood.
References:
Towards Reliable Simulation-Based Inference with Balanced Neural Ratio Estimation
(Delaunoy et al., 2022)
https://arxiv.org/abs/2208.13624
Arguments:
estimator: A classifier network :math:`d_\phi(\theta, x)`.
lmbda: The weight :math:`\lambda` controlling the strength of the balancing
condition.
"""

def __init__(self, estimator: nn.Module, lmbda: float = 100.0):
super().__init__()

self.estimator = estimator
self.lmbda = lmbda

def forward(self, theta: Tensor, x: Tensor) -> Tensor:
r"""
Arguments:
theta: The parameters :math:`\theta`, with shape :math:`(N, D)`.
x: The observation :math:`x`, with shape :math:`(N, L)`.
Returns:
The scalar loss :math:`l`.
"""

theta_prime = torch.roll(theta, 1, dims=0)

log_r, log_r_prime = self.estimator(
torch.stack((theta, theta_prime)),
x,
)

l1 = -F.logsigmoid(log_r).mean()
l0 = -F.logsigmoid(-log_r_prime).mean()
lb = (torch.sigmoid(log_r) + torch.sigmoid(log_r_prime) - 1).mean().square()

return l1 + l0 + self.lmbda * lb


class AMNRE(NRE):
r"""Creates an arbitrary marginal neural ratio estimation (AMNRE) classifier
network.
Expand Down
12 changes: 12 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def test_NRELoss():
assert l.requires_grad


def test_BNRELoss():
estimator = NRE(3, 5)
loss = BNRELoss(estimator)

theta, x = randn(256, 3), randn(256, 5)

l = loss(theta, x)

assert l.shape == ()
assert l.requires_grad


def test_AMNRE():
estimator = AMNRE(3, 5)

Expand Down

0 comments on commit 4650263

Please sign in to comment.