From 1483f842e191d06aa7c81daf5781d7b2a54b72fe Mon Sep 17 00:00:00 2001 From: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:42:22 +0100 Subject: [PATCH] Support generating with fallback for short form audio in Whisper (#30984) * remove is_shortform * adapt _retrieve_max_frames_and_seek for short_form * return bos token in short and long form * add decoder_input_ids to short form audios * add eos token for short form * handle short form token_timestamps * no need to return scores * add is_shortform conditions * handle when max_new_tokens is None - short form * handle assistant decoding * fix * handle return_dict_in_generate * handle split_by_batch for encoder_attentions attribute * handle num_beams>1 * handle num_return_sequences>1 in generate_with_fallback * handle num_return_sequences>1 with return_dict_in_generate=True * raise error if max_new_tokens + decoder_inputs_ids > max_target_pos * fix * apply review suggestions * fix * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * fix * logits for both short form and long form * handle if logits_processor is None * test * apply review changes to num_return_sequences * add _expand_variables_for_generation * remove short form commented section * update comments * uncomment num_beams line in generate_with_fallback * update assistant decoding * handle return_segment with short form generation * up * fix output format is_shortform * overwrite beam_sample test * update _set_return_timestamps * apply review suggestions * apply review suggestions * remove seek_outputs_short_form * fix _stack_split_outputs * fix stack dim in _stack_split_outputs * update tests * fix past_key_values + beam tests * fix * clean _expand_variables_for_generation * make style * fix slow tests * make style * max_length condition * make style * add slow tests for shortform fallback * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * apply review changes * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * up * fix slow tests * apply review suggestions * update test * make style * small fix * fix * fix test_new_cache_format * fix past_key_values * fix * make style * fix slow tests * fix --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- .../models/whisper/generation_whisper.py | 352 ++++++++++++------ tests/models/whisper/test_modeling_whisper.py | 344 +++++++++++++++++ 2 files changed, 575 insertions(+), 121 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index fc572c7389c822..0467362ea2c7ec 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -23,6 +23,8 @@ import torch.nn.functional as F from torch import nn +from transformers.cache_utils import EncoderDecoderCache + from ...generation.configuration_utils import GenerationConfig from ...generation.logits_process import ( LogitsProcessorList, @@ -116,9 +118,10 @@ def _dynamic_time_warping(matrix: np.ndarray): def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): - logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) - if logit_processor: - return getattr(logit_processor, attribute_name, None) + if logits_processor is not None: + logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) + if logit_processor: + return getattr(logit_processor, attribute_name, None) return None @@ -493,27 +496,15 @@ def generate( ) is_shortform = total_input_frames <= num_segment_frames - if is_shortform: - # warn user of ignored inputs - self._maybe_warn_unused_inputs( - condition_on_prev_tokens=condition_on_prev_tokens, - temperature=temperature, - compression_ratio_threshold=compression_ratio_threshold, - logprob_threshold=logprob_threshold, - no_speech_threshold=no_speech_threshold, - total_input_frames=total_input_frames, - ) - # 3. Make sure generation config is correctly set # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not self._set_return_outputs( return_dict_in_generate=return_dict_in_generate, return_token_timestamps=return_token_timestamps, - is_shortform=is_shortform, logprob_threshold=logprob_threshold, generation_config=generation_config, ) - self._set_return_timestamps( + timestamp_begin = self._set_return_timestamps( return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config ) self._set_language_and_task( @@ -554,85 +545,54 @@ def generate( generation_config=generation_config, logits_processor=logits_processor, begin_index=begin_index, # begin index is index of first generated decoder token - is_shortform=is_shortform, num_beams=kwargs.get("num_beams", 1), device=device, ) - # 5. If we're in shortform mode, simple generate the whole input at once and return the output - if is_shortform: - if temperature is not None: - generation_config.temperature = temperature - - decoder_input_ids = kwargs.pop("decoder_input_ids", None) - if decoder_input_ids is None: - decoder_input_ids = init_tokens - - if prompt_ids is not None: - decoder_input_ids = torch.cat( - [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1 - ) - - max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0 - if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions: - raise ValueError( - f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` " - f"is {max_new_tokens}. Thus, the combined length of " - f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the " - f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. " - "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " - f"so that their combined length is less than {self.config.max_target_positions}." - ) - - outputs = super().generate( - input_features, - generation_config=generation_config, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - synced_gpus=synced_gpus, - decoder_input_ids=decoder_input_ids, - **kwargs, - ) - - if generation_config.return_token_timestamps and hasattr(generation_config, "alignment_heads"): - outputs["token_timestamps"] = self._extract_token_timestamps( - outputs, generation_config.alignment_heads, num_frames=generation_config.num_frames - ) - - return outputs - - # 6. Else we're in longform mode which is more complex. - # We need to chunk the audio input depending on when the model generates timestamp tokens - - # 6.1 Set and retrieve global longform generation variables + # 4 Set and retrieve global generation variables self._set_condition_on_prev_tokens( condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config ) - timestamp_begin = generation_config.no_timestamps_token_id + 1 temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature temperature = temperatures[0] - batch_size = input_features.shape[0] max_frames, seek = self._retrieve_max_frames_and_seek( - batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames + batch_size=batch_size, + attention_mask=attention_mask, + total_input_frames=total_input_frames, + is_shortform=is_shortform, ) - # 6.2 Preppare running variables, list for generation - cur_bsz = batch_size - current_segments = self._prepare_segments( - prompt_ids=prompt_ids, + # 5 Prepare running variables, list for generation + num_return_sequences = generation_config.num_return_sequences + ( + batch_idx_map, + cur_bsz, + input_features, + seek, + max_frames, + init_tokens, + do_condition_on_prev_tokens, + ) = self._expand_variables_for_generation( + input_features=input_features, + seek=seek, + max_frames=max_frames, + init_tokens=init_tokens, batch_size=batch_size, + condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config, ) - batch_idx_map = list(range(batch_size)) - do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)] + current_segments = self._prepare_segments( + prompt_ids=prompt_ids, + batch_size=cur_bsz, + generation_config=generation_config, + ) - # 6.2 Transcribe audio until we reach the end of all input audios + # 6 Transcribe audio until we reach the end of all input audios while (seek < max_frames).any(): - # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop + # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order # to know which original audio is being decoded # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk @@ -646,7 +606,7 @@ def generate( time_offset = seek * time_precision / input_stride seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) - # 6.4 cut out next 30s segment from input features + # 6.2 cut out next 30s segment from input features segment_input = self._get_input_segment( input_features=input_features, seek=seek, @@ -656,10 +616,11 @@ def generate( batch_idx_map=batch_idx_map, ) - # 6.5 prepare decoder input ids + # 6.3 prepare decoder input ids suppress_tokens = _get_attr_from_logit_processors( logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens" ) + decoder_input_ids, kwargs = self._prepare_decoder_input_ids( cur_bsz=cur_bsz, init_tokens=init_tokens, @@ -669,25 +630,32 @@ def generate( prompt_ids=prompt_ids, generation_config=generation_config, config=self.config, - device=segment_input.device, + device=init_tokens.device, suppress_tokens=suppress_tokens, kwargs=kwargs, ) - # 6.6 set max new tokens or max length + # 6.4 set max new tokens or max length self._set_max_new_tokens_and_length( config=self.config, decoder_input_ids=decoder_input_ids, generation_config=generation_config, ) - # 6.7 Set current `begin_index` for all logit processors - for proc in logits_processor: - if hasattr(proc, "set_begin_index"): - proc.set_begin_index(decoder_input_ids.shape[-1]) - - # 6.8 Run generate with fallback - seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback( + # 6.5 Set current `begin_index` for all logit processors + if logits_processor is not None: + for proc in logits_processor: + if hasattr(proc, "set_begin_index"): + proc.set_begin_index(decoder_input_ids.shape[-1]) + + # 6.6 Run generate with fallback + ( + seek_sequences, + seek_outputs, + should_skip, + do_condition_on_prev_tokens, + model_output_type, + ) = self.generate_with_fallback( segment_input=segment_input, decoder_input_ids=decoder_input_ids, cur_bsz=cur_bsz, @@ -703,10 +671,11 @@ def generate( synced_gpus=synced_gpus, return_token_timestamps=return_token_timestamps, do_condition_on_prev_tokens=do_condition_on_prev_tokens, + is_shortform=is_shortform, kwargs=kwargs, ) - # 6.9 In every generated sequence, split by timestamp tokens and extract segments + # 6.7 In every generated sequence, split by timestamp tokens and extract segments for i, seek_sequence in enumerate(seek_sequences): prev_i = batch_idx_map[i] @@ -728,7 +697,11 @@ def generate( ) current_segments[prev_i] += segments - seek[prev_i] += segment_offset + + if is_shortform: + seek[prev_i] += max_frames[i] + else: + seek[prev_i] += segment_offset # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output @@ -737,6 +710,7 @@ def generate( if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment") else current_segments ) + sequences = _pad_to_max_length( final_segments, generation_config.pad_token_id, device=self.device, padding="right" ) @@ -745,6 +719,42 @@ def generate( if return_segments: return {"sequences": sequences, "segments": final_segments} + if is_shortform: + # add eos token: + if generation_config.max_new_tokens is None and generation_config.max_length is None: + eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id) + sequences = torch.cat([sequences, eos_tokens], dim=-1) + + if return_token_timestamps: + outputs = {} + outputs["sequences"] = sequences + outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0) + else: + outputs = sequences + + if generation_config.return_dict_in_generate: + dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs) + + if num_return_sequences > 1: + if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None: + dict_outputs.encoder_attentions = tuple( + dict_outputs.encoder_attentions[i][::num_return_sequences] + for i in range(len(dict_outputs.encoder_attentions)) + ) + if ( + hasattr(dict_outputs, "encoder_hidden_states") + and dict_outputs.encoder_hidden_states is not None + ): + dict_outputs.encoder_hidden_states = tuple( + dict_outputs.encoder_hidden_states[i][::num_return_sequences] + for i in range(len(dict_outputs.encoder_hidden_states)) + ) + if return_token_timestamps: + dict_outputs["token_timestamps"] = outputs["token_timestamps"] + return dict_outputs + + return outputs + return sequences def generate_with_fallback( @@ -764,6 +774,7 @@ def generate_with_fallback( synced_gpus, return_token_timestamps, do_condition_on_prev_tokens, + is_shortform, kwargs, ): kwargs = copy.copy(kwargs) @@ -774,7 +785,6 @@ def generate_with_fallback( needs_fallback = [False for _ in range(cur_bsz)] should_skip = [False for _ in range(cur_bsz)] fallback_index_map = list(range(cur_bsz)) - if generation_config.no_speech_threshold is not None: self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs) @@ -799,12 +809,15 @@ def generate_with_fallback( **generate_kwargs, ) + model_output_type = type(seek_outputs) + # post-process sequence tokens and outputs to be in list form seek_sequences, seek_outputs = self._postprocess_outputs( seek_outputs=seek_outputs, decoder_input_ids=decoder_input_ids, return_token_timestamps=return_token_timestamps, generation_config=generation_config, + is_shortform=is_shortform, ) # 6.7 Extract cut sequences from every sequence and check if fallback should be applied @@ -822,14 +835,14 @@ def generate_with_fallback( # remove eos token id if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: seek_sequence = seek_sequence[:-1] - if return_token_timestamps: + if return_token_timestamps and not is_shortform: seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1] # remove all padding tokens if seek_sequence[-1] == generation_config.pad_token_id: num_paddings = (seek_sequence == generation_config.pad_token_id).sum() seek_sequence = seek_sequence[:-num_paddings] - if return_token_timestamps: + if return_token_timestamps and not is_shortform: seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings] # check which sequences in batch need fallback & which should be skipped @@ -871,7 +884,7 @@ def generate_with_fallback( if "decoder_attention_mask" in kwargs: kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask) - return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens + return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type @staticmethod def _prepare_segments(prompt_ids, batch_size, generation_config): @@ -884,10 +897,14 @@ def _prepare_segments(prompt_ids, batch_size, generation_config): return current_segments - def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config): + def _postprocess_outputs( + self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform + ): # remove all previously passed decoder input ids + start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0) + if isinstance(seek_outputs, torch.Tensor): - seek_outputs = seek_outputs[:, decoder_input_ids.shape[-1] :] + seek_outputs = seek_outputs[:, start_idx:] return seek_outputs, seek_outputs if return_token_timestamps and hasattr(generation_config, "alignment_heads"): @@ -895,28 +912,72 @@ def _postprocess_outputs(self, seek_outputs, decoder_input_ids, return_token_tim seek_outputs["token_timestamps"] = self._extract_token_timestamps( seek_outputs, generation_config.alignment_heads, num_frames=num_frames ) - seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1] :] + seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:] - seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :] + seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:] - def split_by_batch_index(values, key, batch_idx): - if key == "scores": + def split_by_batch_index(values, key, batch_idx, is_shortform): + if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]: return [v[batch_idx].cpu() for v in values] - elif key == "past_key_values": - # we don't save `past_key_values` as this is too costly - return None - elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]): + if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]: return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values) + elif key == "past_key_values": + if not is_shortform: + # we don't save `past_key_values` as this is too costly for longform + return None + else: + return tuple(tuple(w[batch_idx][None].cpu() for w in values[v]) for v in range(len(values))) + return values[batch_idx].cpu() sequence_tokens = seek_outputs["sequences"] + + if hasattr(seek_outputs, "past_key_values") and seek_outputs.past_key_values is not None: + if isinstance(seek_outputs["past_key_values"], EncoderDecoderCache): + seek_outputs.past_key_values = seek_outputs.past_key_values.to_legacy_cache() + seek_outputs = [ - {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} + {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()} for i in range(sequence_tokens.shape[0]) ] return sequence_tokens, seek_outputs + def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs): + # Stack back seek_outputs tensors after splitting them with the split_by_batch_index method + outputs = {} + for key in seek_outputs[0].keys(): + if key == "sequences": + outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device) + if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]: + outputs[key] = tuple( + torch.stack([v[key][i] for v in seek_outputs]).to(device) for i in range(len(seek_outputs[0][key])) + ) + if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]: + outputs[key] = tuple( + tuple( + torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device) + for j in range(len(seek_outputs[0][key][0])) + ) + for i in range(len(seek_outputs[0][key])) + ) + if key == "past_key_values": + past_key_value_type = kwargs.get("past_key_values") + if seek_outputs[0][key] is not None: + outputs[key] = tuple( + tuple( + torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device) + for j in range(len(seek_outputs[0][key][0])) + ) + for i in range(len(seek_outputs[0][key])) + ) + if past_key_value_type is not None and isinstance(past_key_value_type, EncoderDecoderCache): + outputs[key] = past_key_value_type.from_legacy_cache(outputs[key]) + else: + outputs[key] = None + + return model_output_type(**outputs) + def _need_fallback( self, seek_sequence, @@ -936,7 +997,7 @@ def _need_fallback( needs_fallback = True if generation_config.logprob_threshold is not None: - if "sequences_scores" in seek_outputs[0]: + if hasattr(seek_outputs[0], "sequences_scores"): logprobs = [s["sequences_scores"] for s in seek_outputs][index] else: scores = seek_outputs[index]["scores"] @@ -961,6 +1022,33 @@ def _need_fallback( return needs_fallback, should_skip + def _expand_variables_for_generation( + self, input_features, seek, max_frames, init_tokens, batch_size, condition_on_prev_tokens, generation_config + ): + if generation_config.num_return_sequences is not None and generation_config.num_return_sequences > 1: + batch_idx_map = list(range(batch_size * generation_config.num_return_sequences)) + cur_bsz = len(batch_idx_map) + do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(len(batch_idx_map))] + input_features = input_features.repeat_interleave(generation_config.num_return_sequences, dim=0) + seek = seek.repeat_interleave(generation_config.num_return_sequences, dim=0) + max_frames = max_frames.repeat_interleave(generation_config.num_return_sequences, dim=0) + init_tokens = init_tokens.repeat_interleave(generation_config.num_return_sequences, dim=0) + generation_config.num_return_sequences = 1 + else: + cur_bsz = batch_size + batch_idx_map = list(range(cur_bsz)) + do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(cur_bsz)] + + return ( + batch_idx_map, + cur_bsz, + input_features, + seek, + max_frames, + init_tokens, + do_condition_on_prev_tokens, + ) + @staticmethod def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs): set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs") @@ -1018,9 +1106,7 @@ def _maybe_warn_unused_inputs( ) @staticmethod - def _set_return_outputs( - return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config - ): + def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config): if return_dict_in_generate is None: return_dict_in_generate = generation_config.return_dict_in_generate @@ -1030,14 +1116,13 @@ def _set_return_outputs( generation_config.output_attentions = True generation_config.output_scores = True - if not is_shortform and logprob_threshold is not None: + if logprob_threshold is not None: return_dict_in_generate = True generation_config.output_scores = True generation_config.return_dict_in_generate = return_dict_in_generate - @staticmethod - def _set_return_timestamps(return_timestamps, is_shortform, generation_config): + def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config): if not is_shortform: if return_timestamps is False: raise ValueError( @@ -1057,6 +1142,15 @@ def _set_return_timestamps(return_timestamps, is_shortform, generation_config): generation_config.return_timestamps = return_timestamps + if hasattr(generation_config, "no_timestamps_token_id"): + timestamp_begin = generation_config.no_timestamps_token_id + 1 + else: + # BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps + # We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop + timestamp_begin = self.config.vocab_size + 1 + + return timestamp_begin + @staticmethod def _set_language_and_task(language, task, is_multilingual, generation_config): if is_multilingual is not None: @@ -1388,23 +1482,21 @@ def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config): generation_config.condition_on_prev_tokens = condition_on_prev_tokens @staticmethod - def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames): - if batch_size > 1 and attention_mask is None: + def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames, is_shortform): + if batch_size > 1 and not is_shortform and attention_mask is None: raise ValueError( "When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " ) - elif batch_size > 1: + elif batch_size > 1 and not is_shortform: max_frames = attention_mask.sum(-1).cpu().to(torch.long) seek = torch.zeros((batch_size,), dtype=torch.long) else: - max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames - seek = torch.zeros((1,), dtype=torch.long) + max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames + seek = torch.zeros((batch_size,), dtype=torch.long) return max_frames, seek - def _retrieve_logit_processors( - self, generation_config, logits_processor, begin_index, is_shortform, num_beams, device - ): + def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device): if generation_config.return_timestamps is True: timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) logits_processor = ( @@ -1431,7 +1523,7 @@ def _retrieve_logit_processors( ) generation_config.begin_suppress_tokens = None - if generation_config.no_speech_threshold is not None and not is_shortform: + if generation_config.no_speech_threshold is not None: no_speech_detector = WhisperNoSpeechDetection( no_speech_token=generation_config.no_timestamps_token_id - 1, begin_index=begin_index, @@ -1462,6 +1554,9 @@ def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map @staticmethod def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map): + if input_features is None: + return None + segment_input = [] for i in range(cur_bsz): prev_i = batch_idx_map[i] @@ -1493,6 +1588,11 @@ def _prepare_decoder_input_ids( suppress_tokens, kwargs, ): + if "decoder_input_ids" in kwargs: + decoder_input_ids = kwargs.pop("decoder_input_ids") + + return decoder_input_ids, kwargs + cut_off_length = config.max_target_positions // 2 - 1 decoder_input_ids = init_tokens[batch_idx_map] @@ -1533,8 +1633,18 @@ def _prepare_decoder_input_ids( return decoder_input_ids, kwargs - @staticmethod - def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config): + def _set_max_new_tokens_and_length(self, config, decoder_input_ids, generation_config): + max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0 + if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions: + raise ValueError( + f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` " + f"is {max_new_tokens}. Thus, the combined length of " + f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the " + f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. " + "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " + f"so that their combined length is less than {self.config.max_target_positions}." + ) + num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1) # Make sure we don't get larger than `max_length` diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index faf6c567ca82d4..a11097fe7dc391 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -65,6 +65,15 @@ WhisperProcessor, set_seed, ) + from transformers.generation import ( + BeamSampleDecoderOnlyOutput, + BeamSampleEncoderDecoderOutput, + BeamSearchDecoderOnlyOutput, + BeamSearchEncoderDecoderOutput, + GenerateBeamDecoderOnlyOutput, + GenerateBeamEncoderDecoderOutput, + PhrasalConstraint, + ) from transformers.generation.logits_process import LogitsProcessor from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids @@ -1539,6 +1548,241 @@ def test_longform_generate_multi_batch(self): def test_longform_generate_multi_batch_cond_prev(self): self._check_longform_generate_multi_batch(condition_on_prev_tokens=True) + def test_beam_sample_generate_dict_output(self): + # We overwrite test_beam_sample_generate_dict_output in test_utils as + # we can only perform beam search if the temperature is set to 0 in Whisper. + config, input_ids, attention_mask = self._get_input_ids_and_config() + + # disable cache + config.use_cache = False + + model = WhisperForConditionalGeneration(config).to(torch_device).eval() + _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) + beam_kwargs = self._get_beam_kwargs() + + # With Whisper, we can only perform a beam search if the temperature is set to 0. + logits_warper_kwargs["temperature"] = 0 + # We will return num_beams sequences per input only if num_return_sequences == num_beams: + beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] + + output_generate = self._beam_sample_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + beam_kwargs=beam_kwargs, + logits_warper_kwargs=logits_warper_kwargs, + output_scores=True, + output_logits=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) + + self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]) + + def test_beam_search_generate_dict_output(self): + # We overwrite test_beam_search_generate_dict_output in test_utils as + # we can only perform beam search if the temperature is set to 0 in Whisper. + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + + # disable cache + config.use_cache = False + + model = model_class(config).to(torch_device).eval() + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( + input_ids.shape[-1], + config.forced_bos_token_id, + config.forced_eos_token_id, + ) + beam_kwargs = self._get_beam_kwargs() + + # With Whisper, we can only perform a beam search if the temperature is set to 0. + logits_process_kwargs["temperature"] = 0 + # We will return num_beams sequences per input only if num_return_sequences == num_beams: + beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] + + output_generate = self._beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + output_scores=True, + output_logits=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) + + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] + ) + + def test_beam_search_generate_dict_outputs_use_cache(self): + # We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as + # we can only perform beam search if the temperature is set to 0 in Whisper. + for model_class in self.all_generative_model_classes: + # enable cache + config, input_ids, attention_mask = self._get_input_ids_and_config() + + if not hasattr(config, "use_cache"): + self.skipTest("This model doesn't support caching") + + model = model_class(config).to(torch_device).eval() + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( + input_ids.shape[-1], + config.forced_bos_token_id, + config.forced_eos_token_id, + ) + + beam_kwargs = self._get_beam_kwargs() + + # We will return num_beams sequences per input only if num_return_sequences == num_beams: + beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + output_generate = self._beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + output_scores=True, + output_logits=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self._check_outputs( + output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"] + ) + + def test_group_beam_search_generate_dict_output(self): + # We overwrite test_group_beam_search_generate_dict_output in test_utils as + # we can only perform beam search if the temperature is set to 0 in Whisper. + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.use_cache = False + + model = model_class(config).to(torch_device).eval() + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( + input_ids.shape[-1], + config.forced_bos_token_id, + config.forced_eos_token_id, + ) + + beam_kwargs = self._get_diverse_beam_kwargs() + + # We will return num_beams sequences per input only if num_return_sequences == num_beams: + beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] + + output_generate = self._group_beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + output_scores=True, + output_logits=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) + + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] + ) + + def test_constrained_beam_search_generate_dict_output(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + + # disable cache + config.use_cache = False + + model = model_class(config).to(torch_device).eval() + logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( + input_ids.shape[-1], + config.forced_bos_token_id, + config.forced_eos_token_id, + ) + + # Sample constraints + min_id = 3 + max_id = model.config.vocab_size + force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] + constraints = [ + PhrasalConstraint(force_tokens), + ] + + beam_kwargs = self._get_constrained_beam_kwargs() + output_generate = self._constrained_beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + constraints=constraints, + beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + output_scores=True, + output_logits=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) + else: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) + # Retrocompatibility check + self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) + + self._check_outputs( + output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"] + ) + def test_custom_4d_attention_mask(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32) @@ -2680,6 +2924,55 @@ def test_whisper_longform_single_batch_prev_cond(self): assert decoded == EXPECTED_TEXT + @slow + def test_whisper_shortform_single_batch_prev_cond(self): + # fmt: off + EXPECTED_TEXT = [" Folks, I spend a lot of time right over there, night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing and the most topical antilock breaks and power steering pain, Stakingly stitching, leather seating so soft, it would make JD power and her associate blush. If you were to create the luxury sedan that is my nightly model, but sometimes— you're sometimes, folks— I lurched the consciousness and the back of an abandoned school bus"] + EXPECTED_TEXT1 = [" Folks, I spend a lot of time right over there night after night after, actually. Carefully selecting for you the day's noisiest, most aerodynamic headlines, stress testing, and the most topical, anti-lock breaks and power steering, painstakingly stitching, leather seating, so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school"] + # fmt: on + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + model = model.to(torch_device) + + ds = load_dataset("distil-whisper/meanwhile", "default")["test"] + dataset = ds.cast_column("audio", Audio(sampling_rate=16000)) + + one_audio = dataset[1]["audio"]["array"] + + input_features = processor(one_audio, return_tensors="pt", sampling_rate=16_000)["input_features"] + input_features = input_features.to(device=torch_device) + + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, + } + + torch.manual_seed(0) + result = model.generate(input_features, **gen_kwargs) + decoded = processor.batch_decode(result.sequences, skip_special_tokens=True) + + assert decoded == EXPECTED_TEXT + + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.3, + "temperature": (0.0, 0.2), + "compression_ratio_threshold": 1, + "condition_on_prev_tokens": False, + "logprob_threshold": -1.0, + } + + torch.manual_seed(0) + result = model.generate(input_features, **gen_kwargs) + decoded = processor.batch_decode(result.sequences, skip_special_tokens=True) + + assert decoded == EXPECTED_TEXT1 + @slow def test_whisper_longform_single_batch_beam(self): # fmt: off @@ -2931,6 +3224,57 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): elif isinstance(EXPECTED_TEXT[i], tuple): assert decoded_all[i] in EXPECTED_TEXT[i] + @slow + def test_whisper_shortform_multi_batch_hard_prev_cond(self): + # Without this set here, this test may fail if it is run with other tests (say, `test_tiny_*`). It's unclear + # why other tests may affect this tests: it seems some random operations are beyond the scene. + set_seed(0) + # fmt: off + EXPECTED_TEXT = [ + ' Mr. Kfilter is the apostle of the Middle Classes and we are glad to welcome his gospel.', + " Nor is Mr. Qilter's manner less interesting than his matter.", + ' He tells us that at this festive season of the year, with Christmas and roce beef, looming before us, similarly drawn from eating and its results occur most readily to the mind.', + ' He has grabbed those with her surfered trigger late and his work is really a great after all, and can discover it in it but little of Rocky Ithaka.', + " L'Neile's pictures are a sort of upguards and add-um paintings, and Maessin's exquisite Itals are a national as a jingo poem. Mr. Birkett Foster's landscapes smiled at one much in the same way that Mr. Carcher used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slapper in the back, before he says,", + ' It is obviously unnecessary for us, to point out how luminous these criticisms are, how delicate and expression.', + ' On the general principles of art and Mr. Kriltor rights with equal lucidity.', + ' Painting, he tells us is of a different quality to mathematics and finish in art is adding more effect.', + ] + # fmt: on + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model = model.to(torch_device) + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + num_samples = 8 + + audio = ds[:num_samples]["audio"] + audios = [x["array"] for x in audio] + + inputs = processor( + audios, + return_tensors="pt", + sampling_rate=16_000, + ) + inputs = inputs.to(device=torch_device) + + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, + } + + result = model.generate(**inputs, **gen_kwargs) + decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True) + + for i in range(num_samples): + if isinstance(EXPECTED_TEXT[i], str): + assert decoded_all[i] == EXPECTED_TEXT[i] + @slow def test_whisper_longform_no_speech_detection(self): # fmt: off