Skip to content

Commit

Permalink
add is_shortform conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
kamilakesbi committed May 24, 2024
1 parent 7519bac commit 956cfb4
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att
return None


def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_id=None, cut_off_length=None):

def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None):
max_total_length = 0
sequences = []
if padding not in ["right", "left"]:
Expand All @@ -136,14 +135,12 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_toke
if cut_off_length is not None:
sequence = sequence[-cut_off_length:]

if bos_token_id is not None:
bos_token_tensor = torch.tensor([bos_token_id]).to(sequence.device)
if bos_token_tensor is not None:
sequence = torch.cat([bos_token_tensor, sequence])

sequences.append(sequence)
max_total_length = max(max_total_length, len(sequences[-1]))
elif bos_token_id is not None:
bos_token_tensor = torch.tensor([bos_token_id]).to(sequence.device)
elif bos_token_tensor is not None:
sequences.append(bos_token_tensor)
else:
sequences.append(torch.tensor([]))
Expand Down Expand Up @@ -611,7 +608,11 @@ def generate(
condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
)

timestamp_begin = generation_config.no_timestamps_token_id + 1
if not is_shortform:
timestamp_begin = generation_config.no_timestamps_token_id + 1
else:
timestamp_begin = None

temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
temperature = temperatures[0]
batch_size = input_features.shape[0]
Expand Down Expand Up @@ -658,9 +659,13 @@ def generate(
)

# 6.5 prepare decoder input ids
suppress_tokens = _get_attr_from_logit_processors(
logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
)
if not is_shortform:
suppress_tokens = _get_attr_from_logit_processors(
logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
)
else:
suppress_tokens = None

decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
cur_bsz=cur_bsz,
init_tokens=init_tokens,
Expand All @@ -683,9 +688,10 @@ def generate(
)

# 6.7 Set current `begin_index` for all logit processors
for proc in logits_processor:
if hasattr(proc, "set_begin_index"):
proc.set_begin_index(decoder_input_ids.shape[-1])
if not is_shortform:
for proc in logits_processor:
if hasattr(proc, "set_begin_index"):
proc.set_begin_index(decoder_input_ids.shape[-1])

# 6.8 Run generate with fallback
seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
Expand Down Expand Up @@ -749,7 +755,7 @@ def generate(
# add decoder_input_ids tokens:
sequences = torch.cat([decoder_input_ids, sequences], dim=-1)
# add eos token:
sequences = torch.cat([sequences, torch.full((2,1,), generation_config.eos_token_id).to(sequences.device)], dim=-1)
sequences = torch.cat([sequences, torch.full((sequences.shape[0],1,), generation_config.eos_token_id).to(sequences.device)], dim=-1)
if return_token_timestamps:
outputs = {}
outputs['sequences'] = sequences
Expand Down Expand Up @@ -882,7 +888,7 @@ def generate_with_fallback(
# if no sequence needs to be run with temperature fallback, we're finished
if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1:
seek_sequences = seek_sequence_list
seek_outputs = seek_outputs_list
seek_outputs = seek_outputs_list
break

# if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors
Expand Down Expand Up @@ -1617,7 +1623,11 @@ def _retrieve_segment(
):
# find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
if timestamp_begin is not None:
timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
else:
timestamp_tokens: torch.Tensor = torch.full((seek_sequence.shape[0],), False, dtype=torch.bool).to(seek_sequence.device)

single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1)
Expand Down

0 comments on commit 956cfb4

Please sign in to comment.