@@ -552,6 +552,7 @@ def beam_search(
552552        prompts : list [Union [TokensPrompt , TextPrompt ]],
553553        params : BeamSearchParams ,
554554        lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] =  None ,
555+         use_tqdm : bool  =  False ,
555556    ) ->  list [BeamSearchOutput ]:
556557        """ 
557558        Generate sequences using beam search. 
@@ -561,6 +562,7 @@ def beam_search(
561562                of token IDs. 
562563            params: The beam search parameters. 
563564            lora_request: LoRA request to use for generation, if any. 
565+             use_tqdm: Whether to use tqdm to display the progress bar. 
564566        """ 
565567        # TODO: how does beam search work together with length penalty, 
566568        # frequency penalty, and stopping criteria, etc.? 
@@ -623,7 +625,18 @@ def create_tokens_prompt_from_beam(
623625                    ** mm_kwargs ,
624626                ), )
625627
626-         for  _  in  range (max_tokens ):
628+         token_iter  =  range (max_tokens )
629+         if  use_tqdm :
630+             token_iter  =  tqdm (token_iter ,
631+                               desc = "Beam search" ,
632+                               unit = "token" ,
633+                               unit_scale = False )
634+             logger .warning (
635+                 "The progress bar shows the upper bound on token steps and " 
636+                 "may finish early due to stopping conditions. It does not " 
637+                 "reflect instance-level progress." )
638+ 
639+         for  _  in  token_iter :
627640            all_beams : list [BeamSearchSequence ] =  list (
628641                sum ((instance .beams  for  instance  in  instances ), []))
629642            pos  =  [0 ] +  list (
0 commit comments