Commit 4ecc12a: make fix-copies
ylacombe committed Sep 23, 2024
1 parent bd576e7 commit 4ecc12a
Showing 2 changed files with 13 additions and 6 deletions.
src/transformers/models/whisper/generation_whisper.py (17 changes: 12 additions & 5 deletions)
@@ -173,7 +173,9 @@ def _pad_to_max_length(


 class WhisperGenerationMixin:
-    def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None, num_input_ids=None):
+    def _extract_token_timestamps(
+        self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None, num_input_ids=None
+    ):
         """
         Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
         map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
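The docstring above describes the technique: cross-attention weights from the alignment heads form a token-by-frame matrix, and dynamic time-warping finds a monotonic path through it that assigns each token an audio frame, which `time_precision` converts to seconds. Below is a minimal, self-contained sketch of that idea only, not the library's implementation; the helper name and the random toy matrix are assumptions for illustration.

import numpy as np


def dtw_token_timestamps(attention: np.ndarray, time_precision: float = 0.02) -> np.ndarray:
    """attention: (num_tokens, num_frames) cross-attention weights averaged over the alignment heads."""
    num_tokens, num_frames = attention.shape
    cost = -attention  # stronger attention is cheaper to pass through
    acc = np.full((num_tokens + 1, num_frames + 1), np.inf)
    acc[0, 0] = 0.0
    for i in range(1, num_tokens + 1):
        for j in range(1, num_frames + 1):
            acc[i, j] = cost[i - 1, j - 1] + min(acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1])

    # backtrack the cheapest monotonic path from the bottom-right corner
    i, j, path = num_tokens, num_frames, []
    while i > 0 and j > 0:
        path.append((i - 1, j - 1))
        step = int(np.argmin([acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1]]))
        if step == 0:
            i, j = i - 1, j - 1
        elif step == 1:
            i -= 1
        else:
            j -= 1
    path.reverse()

    # each token's timestamp is the first audio frame the path assigns to it, in seconds
    timestamps = np.full(num_tokens, -1.0)
    for token_idx, frame_idx in path:
        if timestamps[token_idx] < 0:
            timestamps[token_idx] = frame_idx * time_precision
    return timestamps


# toy example: 4 generated tokens attending over 50 audio frames
rng = np.random.default_rng(0)
print(dtw_token_timestamps(rng.random((4, 50))))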
@@ -200,8 +202,8 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
             # since the beam search strategy chooses the most probable sequences at the end of the search.
             # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
             weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
-            # beam search takes `decoder_input_ids` into account in the `beam_indices` length

+            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
             # but forgot to shift the beam_indices by the number of `decoder_input_ids`
             weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
             weights = weights[:, :, :weight_length]
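The comments in this hunk describe bookkeeping: `beam_indices` is padded with -1 once a sequence has ended and does not cover the prompt tokens in `decoder_input_ids`, so the usable cross-attention length is the longest non-padded run plus `num_input_ids`. A toy illustration of that arithmetic, with random tensors standing in for the real `generate()` outputs:

import torch

# (num_return_sequences, generated_steps); -1 marks steps past the end of a sequence
beam_indices = torch.tensor([[0, 1, 1, -1, -1],
                             [2, 0, 2, 1, -1]])
num_input_ids = 3  # length of the decoder prompt, which beam_indices does not cover

weight_length = (beam_indices != -1).sum(-1).max()  # longest generated run: 4
weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids  # 4 + 3 = 7

# stand-in cross-attention weights: (batch, heads, output_length, audio_frames)
weights = torch.rand(2, 6, 10, 1500)
weights = weights[:, :, :weight_length]  # keep only the steps that correspond to real tokens
print(weight_length.item(), weights.shape)  # 7 torch.Size([2, 6, 7, 1500])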
@@ -223,7 +225,9 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
         # make sure timestamps are as long as weights
         input_length = weight_length or cross_attentions[0].shape[2]
         batch_size = generate_outputs.sequences.shape[0]
-        timestamps = torch.zeros((batch_size, input_length + 1), dtype=torch.float32, device=generate_outputs.sequences.device)
+        timestamps = torch.zeros(
+            (batch_size, input_length + 1), dtype=torch.float32, device=generate_outputs.sequences.device
+        )

         if num_frames is not None:
             # two cases:
@@ -953,7 +957,10 @@ def _postprocess_outputs(
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
             num_frames = getattr(generation_config, "num_frames", None)
             seek_outputs["token_timestamps"] = self._extract_token_timestamps(
-                seek_outputs, generation_config.alignment_heads, num_frames=num_frames, num_input_ids=decoder_input_ids.shape[-1]
+                seek_outputs,
+                generation_config.alignment_heads,
+                num_frames=num_frames,
+                num_input_ids=decoder_input_ids.shape[-1],
             )
             seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
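For context, the call threaded through `_postprocess_outputs` above is what a user reaches with `return_token_timestamps=True`. A hedged end-to-end sketch follows; the `openai/whisper-tiny` checkpoint (whose generation config ships `alignment_heads`), the dummy silent audio, and the exact output access are assumptions for illustration.

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# two seconds of silence at 16 kHz stands in for real speech
inputs = processor(torch.zeros(32000).numpy(), sampling_rate=16000, return_tensors="pt")

outputs = model.generate(
    input_features=inputs.input_features,
    return_token_timestamps=True,  # routes through the _extract_token_timestamps call patched above
)
print(outputs.token_timestamps)  # one timestamp in seconds per generated token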

src/transformers/models/whisper/modeling_whisper.py (2 changes: 1 addition & 1 deletion)
@@ -475,7 +475,7 @@ def forward(

         causal_mask = attention_mask
         if attention_mask is not None: # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :key_states.shape[-2]]
+            causal_mask = attention_mask[:, : key_states.shape[-2]]

         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
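The only functional line in this hunk, `attention_mask[:, : key_states.shape[-2]]`, trims the prepared mask to the number of key positions actually present (the diff itself only changes slice spacing). A toy illustration with made-up shapes rather than Whisper's real tensors:

import torch

batch, prepared_len = 2, 16
attention_mask = torch.ones(batch, prepared_len)  # mask built for the full prepared length
key_states = torch.rand(batch, 7, 64)  # (batch, kv_seq_len, head_dim): only 7 key positions exist

# no matter how long the prepared mask is, keep just the columns that match the keys
causal_mask = attention_mask[:, : key_states.shape[-2]]
print(causal_mask.shape)  # torch.Size([2, 7])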
