From a7a032ffaa524bc18d9e291c69285365cef0a9d7 Mon Sep 17 00:00:00 2001 From: Partho Date: Tue, 11 Oct 2022 00:24:36 +0530 Subject: [PATCH] wrap forward passes with torch.no_grad() (#19439) --- .../visual_bert/test_modeling_visual_bert.py | 68 ++++++++++--------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/tests/models/visual_bert/test_modeling_visual_bert.py b/tests/models/visual_bert/test_modeling_visual_bert.py index 99db914072ccab..92ed812fe47d1e 100644 --- a/tests/models/visual_bert/test_modeling_visual_bert.py +++ b/tests/models/visual_bert/test_modeling_visual_bert.py @@ -568,14 +568,15 @@ def test_inference_vqa_coco_pre(self): attention_mask = torch.tensor([1] * 6).reshape(1, -1) visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1) - output = model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - visual_embeds=visual_embeds, - visual_attention_mask=visual_attention_mask, - visual_token_type_ids=visual_token_type_ids, - ) + with torch.no_grad(): + output = model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + visual_embeds=visual_embeds, + visual_attention_mask=visual_attention_mask, + visual_token_type_ids=visual_token_type_ids, + ) vocab_size = 30522 @@ -606,14 +607,15 @@ def test_inference_vqa(self): attention_mask = torch.tensor([1] * 6).reshape(1, -1) visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1) - output = model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - visual_embeds=visual_embeds, - visual_attention_mask=visual_attention_mask, - visual_token_type_ids=visual_token_type_ids, - ) + with torch.no_grad(): + output = model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + visual_embeds=visual_embeds, + visual_attention_mask=visual_attention_mask, + visual_token_type_ids=visual_token_type_ids, + ) # vocab_size = 30522 @@ -637,14 +639,15 @@ def test_inference_nlvr(self): attention_mask = torch.tensor([1] * 6).reshape(1, -1) visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1) - output = model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - visual_embeds=visual_embeds, - visual_attention_mask=visual_attention_mask, - visual_token_type_ids=visual_token_type_ids, - ) + with torch.no_grad(): + output = model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + visual_embeds=visual_embeds, + visual_attention_mask=visual_attention_mask, + visual_token_type_ids=visual_token_type_ids, + ) # vocab_size = 30522 @@ -667,14 +670,15 @@ def test_inference_vcr(self): visual_token_type_ids = torch.ones(size=(1, 4, 10), dtype=torch.long) visual_attention_mask = torch.ones_like(visual_token_type_ids) - output = model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - visual_embeds=visual_embeds, - visual_attention_mask=visual_attention_mask, - visual_token_type_ids=visual_token_type_ids, - ) + with torch.no_grad(): + output = model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + visual_embeds=visual_embeds, + visual_attention_mask=visual_attention_mask, + visual_token_type_ids=visual_token_type_ids, + ) # vocab_size = 30522