diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 592824b1b02e..ae5dba483220 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -54,6 +54,24 @@ def test_non_averaging_loss(): loss.update((y_pred, y)) +def test_gradient_based_loss(): + # Tests https://github.com/pytorch/ignite/issues/1674 + x = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True) + y_pred = x.mm(torch.randn(size=(3, 1))) + + def loss_fn(y_pred, x): + gradients = torch.autograd.grad( + outputs=y_pred, inputs=x, grad_outputs=torch.ones_like(y_pred), create_graph=True, + )[0] + + gradients = gradients.flatten(start_dim=1) + + return gradients.norm(2, dim=1).mean() + + loss = Loss(loss_fn) + loss.update((y_pred, x)) + + def test_kwargs_loss(): loss = Loss(nll_loss)