From 66cbaacc08fbcfe14b9bd75c53a78e59a3ec1db2 Mon Sep 17 00:00:00 2001 From: y0ast Date: Tue, 23 Feb 2021 18:38:42 +0000 Subject: [PATCH 1/2] add test for new loss detach --- tests/ignite/metrics/test_loss.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 592824b1b02e..04fbd8e2185c 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -54,6 +54,23 @@ def test_non_averaging_loss(): loss.update((y_pred, y)) +def test_gradient_based_loss(): + x = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True) + y_pred = x.mm(torch.randn(size=(3, 1))) + + def loss_fn(y_pred, x): + gradients = torch.autograd.grad( + outputs=y_pred, inputs=x, grad_outputs=torch.ones_like(y_pred), create_graph=True, + )[0] + + gradients = gradients.flatten(start_dim=1) + + return gradients.norm(2, dim=1).mean() + + loss = Loss(loss_fn) + loss.update((y_pred, x)) + + def test_kwargs_loss(): loss = Loss(nll_loss) From 069a12e4a2dbe39353806b6e75e7f78f10d0f7a2 Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 23 Feb 2021 20:23:17 +0100 Subject: [PATCH 2/2] Update tests/ignite/metrics/test_loss.py --- tests/ignite/metrics/test_loss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 04fbd8e2185c..ae5dba483220 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -55,6 +55,7 @@ def test_non_averaging_loss(): def test_gradient_based_loss(): + # Tests https://github.com/pytorch/ignite/issues/1674 x = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True) y_pred = x.mm(torch.randn(size=(3, 1)))