From 5b1d4623c0611beb927d2a96d1f6cbb0e11ce71e Mon Sep 17 00:00:00 2001
From: Partho
Date: Tue, 11 Oct 2022 00:33:46 +0530
Subject: [PATCH] wrap forward passes with torch.no_grad() (#19413)

---
 tests/models/fnet/test_modeling_fnet.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/models/fnet/test_modeling_fnet.py b/tests/models/fnet/test_modeling_fnet.py
index 974d7c2d4e5d63..5d975b061f75f3 100644
--- a/tests/models/fnet/test_modeling_fnet.py
+++ b/tests/models/fnet/test_modeling_fnet.py
@@ -493,7 +493,8 @@ def test_inference_for_masked_lm(self):
         model.to(torch_device)

         input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
-        output = model(input_ids)[0]
+        with torch.no_grad():
+            output = model(input_ids)[0]

         vocab_size = 32000

@@ -536,7 +537,8 @@ def test_inference_for_next_sentence_prediction(self):
         model.to(torch_device)

         input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
-        output = model(input_ids)[0]
+        with torch.no_grad():
+            output = model(input_ids)[0]

         expected_shape = torch.Size((1, 2))
         self.assertEqual(output.shape, expected_shape)
@@ -551,7 +553,8 @@ def test_inference_model(self):
         model.to(torch_device)

         input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
-        output = model(input_ids)[0]
+        with torch.no_grad():
+            output = model(input_ids)[0]

         expected_shape = torch.Size((1, 6, model.config.hidden_size))
         self.assertEqual(output.shape, expected_shape)
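
For context, below is a minimal standalone sketch of the pattern this patch applies: running an inference-only forward pass inside torch.no_grad() so autograd does not record the computation graph, which lowers memory use in integration tests. It uses a stand-in nn.Linear module rather than the FNet models exercised in test_modeling_fnet.py.

import torch
from torch import nn

# Stand-in model for illustration only; the patch itself targets FNet test models.
model = nn.Linear(6, 4)
model.eval()

inputs = torch.randn(1, 6)

# Forward pass wrapped in torch.no_grad(): no graph is built, no gradients tracked.
with torch.no_grad():
    output = model(inputs)

# The output carries no grad requirement, confirming autograd was disabled.
assert output.requires_grad is False
print(output.shape)  # torch.Size([1, 4])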