diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index a0c58749c33743..b99297c3e8d8b8 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1778,6 +1778,24 @@ def typeerror():
 
         return result
 
+    def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
+        """
+        Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
+        fed through `this_peer_finished`. ZeRO stage 3-friendly.
+        """
+        if synced_gpus:
+            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+            # The following logic allows an early break if all peers finished generating their sequence
+            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
+            # send 0.0 if we finished, 1.0 otherwise
+            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+            # did all peers finish? the reduced sum will be 0.0 then
+            if this_peer_finished_flag.item() == 0.0:
+                return False
+        elif this_peer_finished:
+            return False
+        return True
+
     def contrastive_search(self, *args, **kwargs):
         logger.warning_once(
             "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
@@ -1939,19 +1957,9 @@ def _contrastive_search(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        this_peer_finished = False
 
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
             # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
             if model_kwargs.get("past_key_values") is None:
@@ -2187,12 +2195,7 @@ def _contrastive_search(
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -2395,6 +2398,7 @@ def _greedy_search(
         )
 
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2403,18 +2407,7 @@ def _greedy_search(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -2480,13 +2473,7 @@ def _greedy_search(
                 )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -2699,6 +2686,7 @@ def _sample(
         )
 
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2707,19 +2695,7 @@ def _sample(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-        # auto-regressive generation
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -2787,13 +2763,7 @@ def _sample(
                 )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -3052,20 +3022,11 @@ def _beam_search(
 
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # if sequential is True, split the input to batches of batch_size and run sequentially
@@ -3192,10 +3153,7 @@ def _beam_search(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
             input_ids,
@@ -3441,20 +3399,10 @@ def _beam_sample(
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
@@ -3549,10 +3497,7 @@ def _beam_sample(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
             input_ids,
@@ -3804,20 +3749,10 @@ def _group_beam_search(
         beam_scores[:, ::num_sub_beams] = 0
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # predicted tokens in cur_len step
             current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
 
@@ -3955,10 +3890,7 @@ def _group_beam_search(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
         sequence_outputs = beam_scorer.finalize(
@@ -4213,20 +4145,10 @@ def _constrained_beam_search(
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
@@ -4320,10 +4242,7 @@ def _constrained_beam_search(
             cur_len = cur_len + 1
 
             if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = constrained_beam_scorer.finalize(
             input_ids,
@@ -4553,18 +4472,8 @@ def _assisted_decoding(
         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
 
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        this_peer_finished = False
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
 
             # 1. Fetch candidate sequences from a `CandidateGenerator`
@@ -4733,13 +4642,7 @@ def _assisted_decoding(
                 )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
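
Note on the new helper: under `synced_gpus` (e.g. DeepSpeed ZeRO stage 3), every rank must keep calling `forward` in lockstep, because each forward pass is a collective operation and a rank that stopped early would hang its peers. `_has_unfinished_sequences` centralizes the stop condition that the loops above previously duplicated. Below is a minimal standalone sketch of that contract, runnable on CPU with the `gloo` backend; the module-level `has_unfinished_sequences` copy and the `worker` harness are illustrative stand-ins, not part of the patch:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def has_unfinished_sequences(this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
    # Mirrors the helper added in the diff.
    if synced_gpus:
        # Every rank joins the all-reduce, so the loop only stops once the
        # summed "still running" flags reach 0.0 on every rank.
        flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        if flag.item() == 0.0:
            return False
    elif this_peer_finished:
        return False
    return True


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    finish_at = 2 if rank == 0 else 5  # rank 0 "finishes generating" earlier
    this_peer_finished, steps = False, 0
    while has_unfinished_sequences(this_peer_finished, synced_gpus=True, device=torch.device("cpu")):
        steps += 1  # stand-in for one decoding step, i.e. one `forward` call
        this_peer_finished = steps >= finish_at

    print(f"rank {rank} ran {steps} steps")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

Both ranks print `ran 5 steps`: the rank that finished at step 2 keeps stepping in lockstep until the slowest rank is done, and only then does the all-reduced flag sum reach 0.0.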
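
On the per-rank side, each loop now ends an iteration with a single boolean assignment instead of the old `if unfinished_sequences.max() == 0` / `break` ladder. A toy illustration of that bookkeeping, with made-up values (`stopping_criteria_hit` stands in for the per-sequence bool tensor returned by `stopping_criteria(input_ids, scores)`):

```python
import torch

unfinished_sequences = torch.ones(3, dtype=torch.long)     # all 3 sequences still generating
stopping_criteria_hit = torch.tensor([True, False, True])  # sequences 0 and 2 just hit a stop condition

# Same update the loops use: clear the flag of every sequence that must stop...
unfinished_sequences = unfinished_sequences & ~stopping_criteria_hit
print(unfinished_sequences)  # tensor([0, 1, 0])

# ...and this peer is finished only once no sequence is left running.
this_peer_finished = unfinished_sequences.max() == 0
print(bool(this_peer_finished))  # False -> keep looping
```

The `while self._has_unfinished_sequences(...)` condition then consumes this flag at the top of the next iteration, so the per-rank early exit and the synced-GPUs case share a single code path.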