@@ -68,7 +68,6 @@ class SampleResultArgsType:
     sample_results_dict: SampleResultsDictType
     sampling_metadata: SamplingMetadata
     greedy_samples: Optional[torch.Tensor]
-    beam_search_logprobs: Optional[torch.Tensor]
 
 
 # Union of non-deferred (single-step scheduling)
@@ -510,74 +509,6 @@ def _random_sample(
     return results
 
 
-def _beam_search_sample(
-    selected_seq_groups: List[SequenceGroupToSample],
-    logprobs: torch.Tensor,
-) -> SampleResultType:
-    """Run beam sampling on a given samples.
-
-    Args:
-        selected_seq_groups: A list of sequence groups batched.
-        logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
-            on selected sample indices.
-    Returns:
-        Tuple of (next_token_ids, parent_ids). The length of returned list is
-        same as the length of selected_seq_groups. If the corresponding
-        seq_group has do_sample=False, tuple contains ([], [])
-    """
-    # We sample 2 * beam_width candidates to make sure that with high
-    # probability we can get `beam_width` candidates in addition to
-    # the finished sequences for the next iteration. See
-    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
-    # for details. See also HF reference:
-    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
-    #
-    # NOTE: Beam search is not vectorized, so its speed can be slower than
-    # other sampling methods.
-    sample_idx = 0
-    results: SampleResultType = []
-    for seq_group in selected_seq_groups:
-        if not seq_group.do_sample:
-            results.append(([], []))
-            continue
-
-        is_prompt = seq_group.is_prompt
-        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
-        num_parent_seqs = len(seq_ids)
-        beam_width = sampling_params.n
-        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
-        if is_prompt:
-            # Prompt phase.
-            assert num_parent_seqs == 1, (
-                "Prompt input should have only one seq.")
-            parent_ids = [0] * (2 * beam_width)
-            _, next_token_ids = torch.topk(seq_group_logprobs[0],
-                                           2 * beam_width)
-            next_token_ids = next_token_ids.tolist()
-        else:
-            # Generation phase.
-            cumulative_logprobs: List[float] = [
-                seq_group.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids
-            ]
-            cumulative_logprobs_tensor = torch.tensor(
-                cumulative_logprobs,
-                dtype=torch.float,
-                device=seq_group_logprobs.device)
-            seq_group_logprobs = (seq_group_logprobs +
-                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
-            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
-                                     2 * beam_width)
-            topk_ids = topk_ids.tolist()
-            vocab_size = seq_group_logprobs.size(-1)
-            parent_ids = [i // vocab_size for i in topk_ids]
-            next_token_ids = [i % vocab_size for i in topk_ids]
-        results.append((next_token_ids, parent_ids))
-        sample_idx += num_parent_seqs
-    assert sample_idx == logprobs.size(0)
-    return results
-
-
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead.
 # Note that we always sample with replacement.
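For reference, a minimal standalone sketch of the generation-phase expansion that the removed _beam_search_sample performed (the names here, such as expand_beams, are illustrative and not part of vLLM): each parent's cumulative logprob is added to its next-token logprobs, the combined scores are flattened, and the top 2 * beam_width flat indices are decoded back into (parent, token) pairs via // vocab_size and % vocab_size.

# Illustrative sketch only; not part of this diff or of vLLM's API.
from typing import List, Tuple

import torch


def expand_beams(logprobs: torch.Tensor,           # (num_parents, vocab_size)
                 cumulative_logprobs: List[float],  # one score per parent
                 beam_width: int) -> Tuple[List[int], List[int]]:
    # Combined score of extending parent p with token t:
    # cumulative_logprobs[p] + logprobs[p, t].
    scores = logprobs + torch.tensor(cumulative_logprobs,
                                     dtype=torch.float,
                                     device=logprobs.device).unsqueeze(1)
    # Take 2 * beam_width candidates over all (parent, token) pairs.
    _, topk_ids = torch.topk(scores.flatten(), 2 * beam_width)
    vocab_size = logprobs.size(-1)
    flat_ids = topk_ids.tolist()
    parent_ids = [i // vocab_size for i in flat_ids]
    next_token_ids = [i % vocab_size for i in flat_ids]
    return next_token_ids, parent_ids


# Example: 2 parent beams over a toy 5-token vocabulary, beam_width = 2.
next_tokens, parents = expand_beams(
    torch.log_softmax(torch.randn(2, 5), dim=-1), [-1.2, -3.4], beam_width=2)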
@@ -666,14 +597,12 @@ def get_pythonized_sample_results(
         sampling_metadata,
         greedy_samples,
         multinomial_samples,
-        beam_search_logprobs,
         sample_results_dict,
     ) = (
         sample_result_args.sample_metadata,
         sample_result_args.sampling_metadata,
         sample_result_args.greedy_samples,
         sample_result_args.multinomial_samples,
-        sample_result_args.beam_search_logprobs,
         sample_result_args.sample_results_dict,
     )
 
@@ -686,9 +615,6 @@ def get_pythonized_sample_results(
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups,
                                             multinomial_samples[sampling_type])
-        elif sampling_type == SamplingType.BEAM:
-            sample_results = _beam_search_sample(seq_groups,
-                                                 beam_search_logprobs)
         sample_results_dict.update(zip(seq_group_id, sample_results))
 
     return [
@@ -731,7 +657,6 @@ def _sample_with_torch(
     sample_metadata: SampleMetadataType = {}
     multinomial_samples: MultinomialSamplesType = {}
     greedy_samples: Optional[torch.Tensor] = None
-    beam_search_logprobs: Optional[torch.Tensor] = None
 
     # Create output tensor for sampled token ids.
     if include_gpu_probs_tensor:
@@ -800,8 +725,6 @@ def _sample_with_torch(
                 sampled_token_ids_tensor[long_sample_indices] = \
                     multinomial_samples[sampling_type].to(torch.long)
 
-        elif sampling_type == SamplingType.BEAM:
-            beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
 
@@ -812,7 +735,6 @@ def _sample_with_torch(
         sample_metadata=sample_metadata,
         multinomial_samples=multinomial_samples,
         greedy_samples=greedy_samples,
-        beam_search_logprobs=beam_search_logprobs,
         sample_results_dict=sample_results_dict)
 
     if not sampling_metadata.skip_sampler_cpu_output: