Generate: replace breaks by a loop condition (#29662)
* replace breaks by a loop condition

* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent 0ac9d58 commit 1a30d8e
Showing 1 changed file with 42 additions and 139 deletions.
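
The refactor is mechanical: every decoding method previously open-coded a `while True` loop with two `break` sites (one for the `synced_gpus` handshake, one for normal termination), and now derives its loop condition from a single helper. The following is a minimal, self-contained sketch of the refactored control flow for the single-process case (`synced_gpus=False`); `has_unfinished_sequences` here is a simplified stand-in for the new `_has_unfinished_sequences` method with its distributed branch omitted, and the loop body is a dummy in place of the real forward pass.

    import torch

    def has_unfinished_sequences(this_peer_finished: bool, synced_gpus: bool) -> bool:
        # Simplified stand-in: the real helper also all_reduces a "finished"
        # flag across peers when synced_gpus is True (see the diff below).
        return not this_peer_finished

    batch_size = 3
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long)
    this_peer_finished = False
    step = 0

    while has_unfinished_sequences(this_peer_finished, synced_gpus=False):
        # stand-in for one decoding step: pretend sequence number `step` just finished
        unfinished_sequences[step] = 0
        step += 1
        # the exit decision is now computed in exactly one place, at the end of the body
        this_peer_finished = unfinished_sequences.max() == 0

    print(step)  # 3: the loop ends once every sequence in the batch is finished

The old pattern needed an `if this_peer_finished and not synced_gpus: break` tail in every method; the new one reduces that to a single boolean assignment per iteration.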
181 changes: 42 additions & 139 deletions src/transformers/generation/utils.py
@@ -1778,6 +1778,24 @@ def typeerror():
 
         return result
 
+    def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
+        """
+        Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
+        fed through `this_peer_finished`. ZeRO stage 3-friendly.
+        """
+        if synced_gpus:
+            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+            # The following logic allows an early break if all peers finished generating their sequence
+            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
+            # send 0.0 if we finished, 1.0 otherwise
+            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+            # did all peers finish? the reduced sum will be 0.0 then
+            if this_peer_finished_flag.item() == 0.0:
+                return False
+        elif this_peer_finished:
+            return False
+        return True
+
     def contrastive_search(self, *args, **kwargs):
         logger.warning_once(
             "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
@@ -1939,19 +1957,9 @@ def _contrastive_search(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        this_peer_finished = False
+
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
             # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
             if model_kwargs.get("past_key_values") is None:
@@ -2187,12 +2195,7 @@ def _contrastive_search(
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
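
The `synced_gpus` branch of the helper is what makes this "ZeRO stage 3-friendly": with sharded parameters, every rank must keep entering `forward` until all ranks are done, or the collectives deadlock. Below is a hypothetical, self-contained toy (gloo backend on CPU, illustrative names, hard-coded rendezvous address and port) showing why a reduced sum of 0.0 is the all-peers-finished signal:

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def has_unfinished(this_peer_finished: bool, device: torch.device) -> bool:
        # Vote 1.0 if this peer is still generating, 0.0 if it is done.
        flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        # The sum is 0.0 only when every peer voted "done".
        return flag.item() != 0.0

    def worker(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"  # assumed free; pick another if taken
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        steps_needed = 2 + 3 * rank  # rank 1 "generates" three steps longer
        step, finished = 0, False
        while has_unfinished(finished, torch.device("cpu")):
            step += 1  # stands in for one forward call / decoding step
            finished = step >= steps_needed
        print(f"rank {rank} exited after {step} steps")  # both ranks print 5
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(worker, args=(2,), nprocs=2)

Rank 0 finishes its own work after 2 steps but keeps looping until rank 1 finishes at step 5; both ranks then see the reduced flag reach 0.0 on the same iteration and exit together, which is the lockstep behavior the helper preserves for the finished peers' discarded `forward` calls.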
@@ -2395,6 +2398,7 @@ def _greedy_search(
         )
 
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2403,18 +2407,7 @@
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

@@ -2480,13 +2473,7 @@ def _greedy_search(
             )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -2699,6 +2686,7 @@ def _sample(
         )
 
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2707,19 +2695,7 @@
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-        # auto-regressive generation
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

@@ -2787,13 +2763,7 @@ def _sample(
             )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -3052,20 +3022,11 @@ def _beam_search(
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # if sequential is True, split the input to batches of batch_size and run sequentially
@@ -3192,10 +3153,7 @@ def _beam_search(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
             input_ids,
@@ -3441,20 +3399,10 @@ def _beam_sample(
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
@@ -3549,10 +3497,7 @@ def _beam_sample(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
             input_ids,
@@ -3804,20 +3749,10 @@ def _group_beam_search(
         beam_scores[:, ::num_sub_beams] = 0
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # predicted tokens in cur_len step
             current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

@@ -3955,10 +3890,7 @@ def _group_beam_search(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
         sequence_outputs = beam_scorer.finalize(
@@ -4213,20 +4145,10 @@ def _constrained_beam_search(
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
@@ -4320,10 +4242,7 @@ def _constrained_beam_search(
             cur_len = cur_len + 1
 
             if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = constrained_beam_scorer.finalize(
             input_ids,
@@ -4553,18 +4472,8 @@ def _assisted_decoding(
         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
 
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        this_peer_finished = False
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
 
             # 1. Fetch candidate sequences from a `CandidateGenerator`
@@ -4733,13 +4642,7 @@ def _assisted_decoding(
             )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
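
A recurring detail in the sampling-style hunks above: `this_peer_finished = unfinished_sequences.max() == 0` collapses the old flag-setting `if` into one expression over the per-sequence mask. A tiny illustration with hypothetical values:

    import torch

    # 0/1 per-sequence "still generating" mask for a hypothetical batch of 3
    unfinished_sequences = torch.tensor([1, 0, 1], dtype=torch.long)
    # stand-in for the boolean output of stopping_criteria(input_ids, scores)
    criteria_hit = torch.tensor([True, False, True])

    unfinished_sequences = unfinished_sequences & ~criteria_hit  # tensor([0, 0, 0])
    this_peer_finished = unfinished_sequences.max() == 0         # tensor(True): all done

When any entry of the mask is still 1, `max()` is nonzero and the loop simply continues.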