Generate: replace breaks by a loop condition #29662

Merged 3 commits on Mar 15, 2024
181 changes: 42 additions & 139 deletions src/transformers/generation/utils.py
@@ -1778,6 +1778,24 @@ def typeerror():

        return result

+    def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
+        """
+        Returns whether there are still unfinished sequences on the device. The existence of unfinished sequences is
+        fed through `this_peer_finished`. ZeRO stage 3-friendly.
+        """
+        if synced_gpus:
+            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+            # The following logic allows an early break if all peers finished generating their sequence
+            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
+            # send 0.0 if we finished, 1.0 otherwise
+            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+            # did all peers finish? the reduced sum will be 0.0 then
+            if this_peer_finished_flag.item() == 0.0:
+                return False
+        elif this_peer_finished:
+            return False
+        return True
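For intuition, the synced_gpus handshake inside this helper can be reproduced standalone. A minimal sketch, assuming `torch.distributed` has already been initialized with a process group (the function name is illustrative, not part of the PR):

import torch
import torch.distributed as dist

def any_peer_unfinished(this_peer_finished: bool, device: torch.device) -> bool:
    # Each rank contributes 1.0 while it still has unfinished sequences, 0.0 once done.
    flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
    # Sum the flags across all ranks: the result is 0.0 only when every rank has finished.
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    return flag.item() > 0.0

This is also why the helper is "ZeRO stage 3-friendly": under ZeRO-3 every `forward` is itself a collective operation (parameters are gathered from all ranks), so a rank that exits its generation loop early would leave the other ranks blocked. Each rank must keep stepping until the all-reduce reports that every peer is done.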

    def contrastive_search(self, *args, **kwargs):
        logger.warning_once(
            "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
@@ -1939,19 +1957,9 @@ def _contrastive_search(
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

-        this_peer_finished = False  # used by synced_gpus only
-
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        this_peer_finished = False
+
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
            # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
            if model_kwargs.get("past_key_values") is None:
@@ -2187,12 +2195,7 @@ def _contrastive_search(

            # stop when each sentence is finished
            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
Member (PR author) commented: The previous version is also data-dependent control flow, so this change is for torch.compile readiness :)
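The distinction matters for graph capture: once the stopping check lives in the `while` condition, the loop body no longer contains a data-dependent `break`, so a compiled per-step function can be traced as one graph while the host-side synchronization happens between iterations. A toy illustration of the idea, not code from this PR (`decode_step` and the tensors are made up):

import torch

@torch.compile  # the per-step work compiles cleanly: no data-dependent break inside
def decode_step(x: torch.Tensor) -> torch.Tensor:
    return x - 1

x = torch.tensor([3.0, 5.0])
this_peer_finished = False
while not this_peer_finished:  # data-dependent control flow stays in the loop condition
    x = decode_step(x)
    this_peer_finished = bool((x <= 0).all())  # host-side sync between steps, not mid-graph
print(x)  # tensor([-2., 0.]) after five iterations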


        if streamer is not None:
            streamer.end()
@@ -2395,6 +2398,7 @@ def _greedy_search(
        )

        # keep track of which sequences are already finished
+        this_peer_finished = False
        batch_size, cur_len = (
            model_kwargs["attention_mask"].shape
            if model_kwargs.get("attention_mask", None) is not None
@@ -2403,18 +2407,7 @@
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

@@ -2480,13 +2473,7 @@ def _greedy_search(
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0

        if streamer is not None:
            streamer.end()
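The assignment that replaces the break is worth a quick worked example (toy values, not from the PR): `unfinished_sequences` holds a 1 for each still-active sequence, the stopping criteria clear entries through the `& ~` update, and the peer-level flag is simply "the max is 0":

import torch

unfinished_sequences = torch.ones(3, dtype=torch.long)  # 3 sequences, all still generating
stopped_now = torch.tensor([True, False, True])         # e.g. EOS reached for sequences 0 and 2
unfinished_sequences = unfinished_sequences & ~stopped_now
print(unfinished_sequences)                             # tensor([0, 1, 0])
this_peer_finished = unfinished_sequences.max() == 0    # tensor(False): sequence 1 still going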
@@ -2699,6 +2686,7 @@ def _sample(
        )

        # keep track of which sequences are already finished
+        this_peer_finished = False
        batch_size, cur_len = (
            model_kwargs["attention_mask"].shape
            if model_kwargs.get("attention_mask", None) is not None
@@ -2707,19 +2695,7 @@
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

-        this_peer_finished = False  # used by synced_gpus only
-        # auto-regressive generation
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

@@ -2787,13 +2763,7 @@ def _sample(
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0

        if streamer is not None:
            streamer.end()
@@ -3052,20 +3022,11 @@ def _beam_search(
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False

        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # if sequential is True, split the input to batches of batch_size and run sequentially
@@ -3192,10 +3153,7 @@ def _beam_search(
            cur_len = cur_len + 1

            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
@@ -3441,20 +3399,10 @@ def _beam_sample(
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))

-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False

        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
@@ -3549,10 +3497,7 @@ def _beam_sample(
            cur_len = cur_len + 1

            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
@@ -3804,20 +3749,10 @@ def _group_beam_search(
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False

        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

@@ -3955,10 +3890,7 @@ def _group_beam_search(
            cur_len = cur_len + 1

            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
@@ -4213,20 +4145,10 @@ def _constrained_beam_search(
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False

        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
@@ -4320,10 +4242,7 @@ def _constrained_beam_search(
            cur_len = cur_len + 1

            if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True

        sequence_outputs = constrained_beam_scorer.finalize(
            input_ids,
@@ -4553,18 +4472,8 @@ def _assisted_decoding(
        # other auxiliary variables
        max_len = stopping_criteria[0].max_length

-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        this_peer_finished = False
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            cur_len = input_ids.shape[-1]

            # 1. Fetch candidate sequences from a `CandidateGenerator`
@@ -4733,13 +4642,7 @@ def _assisted_decoding(
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0

        if streamer is not None:
            streamer.end()