Skip to content

Commit

Permalink
fix shortform batch prev cond tests
Browse files — Browse the repository at this point in the history
  • Loading branch information
ylacombe committed Sep 23, 2024
1 parent d1ad495 commit bd576e7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2863,7 +2863,7 @@ def test_whisper_shortform_single_batch_prev_cond(self):

torch.manual_seed(0)
result = model.generate(input_features, **gen_kwargs)
decoded = processor.batch_decode(result["sequences"], skip_special_tokens=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)

assert decoded == EXPECTED_TEXT

Expand All @@ -2878,7 +2878,7 @@ def test_whisper_shortform_single_batch_prev_cond(self):

torch.manual_seed(0)
result = model.generate(input_features, **gen_kwargs)
decoded = processor.batch_decode(result["sequences"], skip_special_tokens=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)

assert decoded == EXPECTED_TEXT1

Expand Down Expand Up @@ -3178,7 +3178,7 @@ def test_whisper_shortform_multi_batch_hard_prev_cond(self):
}

result = model.generate(**inputs, **gen_kwargs)
decoded_all = processor.batch_decode(result["sequences"], skip_special_tokens=True)
decoded_all = processor.batch_decode(result, skip_special_tokens=True)

for i in range(num_samples):
if isinstance(EXPECTED_TEXT[i], str):
Expand Down

0 comments on commit bd576e7

Please sign in to comment.