@@ -1334,8 +1334,8 @@ def warmup_scenario(self,
13341334 seq_len ,
13351335 is_prompt ,
13361336 kv_caches ,
1337- is_profile_run = False ,
1338- override_n_runs = None ) -> None :
1337+ is_pt_profiler_run = False ,
1338+ is_lora_profile_run = False ) -> None :
13391339 use_graphs = self ._use_graphs (batch_size , seq_len , is_prompt )
13401340 scenario_name = ("warmup_"
13411341 f"{ 'prompt' if is_prompt else 'decode' } _"
@@ -1367,10 +1367,8 @@ def warmup_scenario(self,
13671367 for idx in range (max_num_seqs )
13681368 ]
13691369 self .profiler .start ('internal' , scenario_name )
1370- times = 3 if use_graphs or is_profile_run else 1
1371- if override_n_runs is not None :
1372- times = override_n_runs
1373- if self .lora_config and not is_profile_run :
1370+ times = 3 if use_graphs or is_pt_profiler_run else 1
1371+ if self .lora_config and not is_lora_profile_run :
13741372 lora_mapping = LoRAMapping (
13751373 [0 ] * batch_size * seq_len ,
13761374 [0 ] * batch_size * seq_len ,
@@ -1401,27 +1399,19 @@ def warmup_scenario(self,
14011399 ]
14021400 torch .hpu .synchronize ()
14031401 profiler = None
1404- fwd_times = []
1405- if is_profile_run and self .is_driver_worker :
1402+ if is_pt_profiler_run and self .is_driver_worker :
14061403 profiler = setup_profiler ()
14071404 profiler .start ()
14081405 for _ in range (times ):
1409- torch .hpu .synchronize ()
1410- start = time .perf_counter ()
14111406 inputs = self .prepare_model_input (seqs )
1412- self .execute_model (inputs , kv_caches , warmup_mode = False )
1407+ self .execute_model (inputs , kv_caches , warmup_mode = True )
14131408 torch .hpu .synchronize ()
1414- end = time .perf_counter ()
1415- elapsed = end - start
1416- fwd_times .append (elapsed )
1417- print (f'[{ batch_size } x{ seq_len } x{ use_graphs } ] tput: { batch_size / elapsed :.3f} tps, time: { elapsed * 1000 :.3f} ms' )
14181409 if profiler :
14191410 profiler .step ()
14201411 if profiler :
14211412 profiler .stop ()
14221413 self .profiler .end ()
14231414 gc .collect ()
1424- return fwd_times , use_graphs
14251415
14261416 def remove_all_loras (self ):
14271417 if not self .lora_manager :
@@ -1466,13 +1456,11 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len):
14661456 f"free_mem:{ free_mem } " )
14671457 logger .info (msg )
14681458
1469- def warmup_all_buckets (self , buckets , is_prompt , kv_caches , override_n_runs = None ):
1470- bucket_times = {}
1459+ def warmup_all_buckets (self , buckets , is_prompt , kv_caches ):
14711460 for i , (batch_size , seq_len ) in enumerate (reversed (buckets )):
14721461 self .log_warmup ('Prompt' if is_prompt else 'Decode' , i ,
14731462 len (buckets ), batch_size , seq_len )
1474- bucket_times [(batch_size , seq_len )] = self .warmup_scenario (batch_size , seq_len , is_prompt , kv_caches , override_n_runs = override_n_runs )
1475- return bucket_times
1463+ self .warmup_scenario (batch_size , seq_len , is_prompt , kv_caches )
14761464
14771465 def warmup_graphs (self ,
14781466 strategy ,
@@ -1676,14 +1664,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
16761664 logger .info (msg )
16771665 self .profiler .end ()
16781666
1679- if os .environ .get ('VLLM_PROFILE_SERVER_CHARACTERISTICS' , 'false' ).lower () == 'true' :
1680- from vllm .hpu .utils import process_run_characteristics
1681- n_runs = int (os .environ .get ('VLLM_PROFILE_SERVER_CHARACTERISTICS_N' , '5' ))
1682- decode_times = self .warmup_all_buckets (self .decode_buckets , False , kv_caches , override_n_runs = n_runs )
1683- process_run_characteristics (decode_times , block_size = self .cache_config .block_size , prefill = False )
1684- prefill_times = self .warmup_all_buckets (self .prompt_buckets , True , kv_caches , override_n_runs = n_runs )
1685- process_run_characteristics (prefill_times , block_size = self .cache_config .block_size , prefill = True )
1686-
16871667 @property
16881668 def vocab_size (self ) -> int :
16891669 return self .model_config .get_vocab_size ()
0 commit comments