Skip to content

Commit 23b7e73

Browse files
authored
fix test_compare_unprocessed_logit_scores (#39053)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 58c7689 commit 23b7e73

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/generation/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3807,7 +3807,7 @@ def test_compare_unprocessed_logit_scores(self):
38073807
logits_gen = outputs.logits[0][0]
38083808

38093809
# assert that unprocessed logits from generate() are same as those from modal eval()
3810-
self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist())
3810+
torch.testing.assert_allclose(logits_fwd.tolist(), logits_gen.tolist())
38113811

38123812
def test_return_unprocessed_logit_scores(self):
38133813
# tell model to generate text and return unprocessed/unwarped logit scores

0 commit comments

Comments
 (0)