diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index f30cfe19476504..1d6639fa44f4b0 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -122,7 +122,9 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att
     return None
 
 
-def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None):
+def _pad_to_max_length(
+    current_segments, pad_token_id, device, padding="right", bos_token_tensor=None, cut_off_length=None
+):
     max_total_length = 0
     sequences = []
     if padding not in ["right", "left"]:
@@ -143,7 +145,7 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_toke
         elif bos_token_tensor is not None:
             sequences.append(bos_token_tensor)
         else:
-            sequences.append(torch.tensor([]))
+            sequences.append(torch.tensor([], device=device))
 
     for i in range(len(current_segments)):
         pad_length = max_total_length - len(sequences[i])
@@ -733,7 +735,9 @@ def generate(
             if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
             else current_segments
         )
-        sequences = _pad_to_max_length(final_segments, generation_config.pad_token_id, padding="right")
+        sequences = _pad_to_max_length(
+            final_segments, generation_config.pad_token_id, device=self.device, padding="right"
+        )
 
         # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
         if return_segments:
@@ -1506,6 +1510,7 @@ def _prepare_decoder_input_ids(
             prev_tokens = _pad_to_max_length(
                 active_segments,
                 generation_config.pad_token_id,
+                device=device,
                 padding="left",
                 bos_token_tensor=prev_ids,
                 cut_off_length=cut_off_length,
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 70e37d76492db1..18b1eb36ccf442 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -35,6 +35,7 @@
     require_torch,
     require_torch_fp16,
     require_torch_gpu,
+    require_torch_multi_gpu,
     require_torchaudio,
     slow,
     torch_device,
@@ -2866,6 +2867,81 @@ def test_whisper_longform_no_speech_detection(self):
         for i in range(num_samples):
             assert decoded_all[i] == EXPECTED_TEXT[i]
 
+    @require_torch_gpu
+    @slow
+    def test_whisper_empty_longform(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model = model.to(torch_device)
+
+        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
+        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+        num_samples = 8
+
+        audio = ds[:num_samples]["audio"]
+        audios = [x["array"] for x in audio]
+        audios[0][:] = np.zeros(audios[0].shape)
+
+        inputs = processor(
+            audios,
+            return_tensors="pt",
+            truncation=False,
+            padding="longest",
+            return_attention_mask=True,
+            sampling_rate=16_000,
+        )
+        inputs = inputs.to(device=torch_device)
+
+        gen_kwargs = {
+            "no_speech_threshold": 0.2,
+            "temperature": (0.0,),
+            "logprob_threshold": 0.0,  # Ignore logprob, use only no-speech prob
+            "num_beams": 5,
+            "language": "fr",
+            "task": "transcribe",
+        }
+
+        torch.manual_seed(0)
+        model.generate(**inputs, **gen_kwargs)
+
+    @require_torch_multi_gpu
+    @slow
+    def test_whisper_empty_longform_multi_gpu(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", device_map="auto")
+
+        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
+        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+        num_samples = 8
+
+        audio = ds[:num_samples]["audio"]
+        audios = [x["array"] for x in audio]
+        audios[0][:] = np.zeros(audios[0].shape)
+
+        inputs = processor(
+            audios,
+            return_tensors="pt",
+            truncation=False,
+            padding="longest",
+            return_attention_mask=True,
+            sampling_rate=16_000,
+        )
+        inputs = inputs.to(device=model.device)
+
+        gen_kwargs = {
+            "no_speech_threshold": 0.2,
+            "temperature": (0.0,),
+            "logprob_threshold": 0.0,  # Ignore logprob, use only no-speech prob
+            "num_beams": 5,
+            "language": "fr",
+            "task": "transcribe",
+        }
+
+        torch.manual_seed(0)
+        model.generate(**inputs, **gen_kwargs)
+
 
 def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
     if head_mask is None:
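
Note (not part of the patch): a minimal standalone sketch of the failure mode the diff addresses. torch.tensor([]) allocates on the CPU by default, so when an all-silence segment left a batch entry with no decoded tokens, the empty placeholder later hit torch.cat together with GPU-resident segment tensors and raised a device-mismatch RuntimeError; threading `device` through _pad_to_max_length keeps the placeholder on the generation device. Variable names below are illustrative, not taken from generation_whisper.py, and the repro assumes a CUDA device is available (on CPU-only machines both tensors land on the CPU and no mismatch occurs).

import torch

# Device that generation runs on (CUDA when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for the concatenated tokens of a segment that did produce output.
segment_tokens = torch.ones(5, device=device)

# Old behavior: the empty placeholder for a segment-less entry defaults to CPU.
placeholder = torch.tensor([])

if placeholder.device != segment_tokens.device:
    # On a CUDA machine, torch.cat([placeholder, segment_tokens]) would raise
    # "Expected all tensors to be on the same device" at this point.
    # Patched behavior: allocate the placeholder on the generation device.
    placeholder = torch.tensor([], device=device)

# With both tensors on the same device, padding/concatenation succeeds.
padded = torch.cat([placeholder, segment_tokens])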