From df895a9c95cff854bfc65fd2773a37bcd776a522 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 7 Jun 2023 11:56:57 +0100 Subject: [PATCH] Generate: increase left-padding test atol (#23448) increase atol --- tests/generation/test_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5f835917ea02db..4e09f21898fd3f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1608,7 +1608,6 @@ def test_generate_with_head_masking(self): attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) - @slow # TODO (Joao): fix GPTBigCode def test_left_padding_compatibility(self): # The check done in this test is fairly difficult -- depending on the model architecture, passing the right # position index for the position embeddings can still result in a different output, due to numerical masking. @@ -1648,7 +1647,7 @@ def test_left_padding_compatibility(self): position_ids.masked_fill_(padded_attention_mask == 0, 1) model_kwargs["position_ids"] = position_ids next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] - if not torch.allclose(next_logits_wo_padding, next_logits_with_padding): + if not torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-7): no_failures = False break