
Commit de6833c

I messed up and now I've fixed it
1 parent 0bd8366 commit de6833c

File tree

1 file changed: +8 -28 lines changed


vllm/worker/habana_model_runner.py

Lines changed: 8 additions & 28 deletions
@@ -1334,8 +1334,8 @@ def warmup_scenario(self,
                         seq_len,
                         is_prompt,
                         kv_caches,
-                        is_profile_run=False,
-                        override_n_runs=None) -> None:
+                        is_pt_profiler_run=False,
+                        is_lora_profile_run=False) -> None:
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
         scenario_name = ("warmup_"
                         f"{'prompt' if is_prompt else 'decode'}_"
@@ -1367,10 +1367,8 @@ def warmup_scenario(self,
             for idx in range(max_num_seqs)
         ]
         self.profiler.start('internal', scenario_name)
-        times = 3 if use_graphs or is_profile_run else 1
-        if override_n_runs is not None:
-            times = override_n_runs
-        if self.lora_config and not is_profile_run:
+        times = 3 if use_graphs or is_pt_profiler_run else 1
+        if self.lora_config and not is_lora_profile_run:
             lora_mapping = LoRAMapping(
                 [0] * batch_size * seq_len,
                 [0] * batch_size * seq_len,
@@ -1401,27 +1399,19 @@ def warmup_scenario(self,
         ]
         torch.hpu.synchronize()
         profiler = None
-        fwd_times = []
-        if is_profile_run and self.is_driver_worker:
+        if is_pt_profiler_run and self.is_driver_worker:
             profiler = setup_profiler()
             profiler.start()
         for _ in range(times):
-            torch.hpu.synchronize()
-            start = time.perf_counter()
             inputs = self.prepare_model_input(seqs)
-            self.execute_model(inputs, kv_caches, warmup_mode=False)
+            self.execute_model(inputs, kv_caches, warmup_mode=True)
             torch.hpu.synchronize()
-            end = time.perf_counter()
-            elapsed = end - start
-            fwd_times.append(elapsed)
-            print(f'[{batch_size}x{seq_len}x{use_graphs}] tput: {batch_size/elapsed:.3f} tps, time: {elapsed*1000:.3f} ms')
             if profiler:
                 profiler.step()
         if profiler:
             profiler.stop()
         self.profiler.end()
         gc.collect()
-        return fwd_times, use_graphs

     def remove_all_loras(self):
         if not self.lora_manager:
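
The start()/step()/stop() calls around the loop follow torch.profiler's schedule-based API. The commit does not show the body of setup_profiler, so the following is only a sketch of a helper consistent with that call pattern; the activities, schedule values, and trace destination are assumptions:

    import torch

    def setup_profiler():
        # Assumed shape of the helper: a schedule-driven torch.profiler
        # instance whose start()/step()/stop() calls line up with the
        # warmup loop above. All parameter values here are guesses.
        schedule = torch.profiler.schedule(wait=0, warmup=2, active=1,
                                           repeat=1)
        return torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            schedule=schedule,
            on_trace_ready=torch.profiler.tensorboard_trace_handler('.'),
            record_shapes=True,
            with_stack=True)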
@@ -1466,13 +1456,11 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len):
                f"free_mem:{free_mem}")
         logger.info(msg)

-    def warmup_all_buckets(self, buckets, is_prompt, kv_caches, override_n_runs=None):
-        bucket_times = {}
+    def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
         for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
             self.log_warmup('Prompt' if is_prompt else 'Decode', i,
                             len(buckets), batch_size, seq_len)
-            bucket_times[(batch_size, seq_len)] = self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches, override_n_runs=override_n_runs)
-        return bucket_times
+            self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)

     def warmup_graphs(self,
                       strategy,
@@ -1676,14 +1664,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             logger.info(msg)
         self.profiler.end()

-        if os.environ.get('VLLM_PROFILE_SERVER_CHARACTERISTICS', 'false').lower() == 'true':
-            from vllm.hpu.utils import process_run_characteristics
-            n_runs = int(os.environ.get('VLLM_PROFILE_SERVER_CHARACTERISTICS_N', '5'))
-            decode_times = self.warmup_all_buckets(self.decode_buckets, False, kv_caches, override_n_runs=n_runs)
-            process_run_characteristics(decode_times, block_size=self.cache_config.block_size, prefill=False)
-            prefill_times = self.warmup_all_buckets(self.prompt_buckets, True, kv_caches, override_n_runs=n_runs)
-            process_run_characteristics(prefill_times, block_size=self.cache_config.block_size, prefill=True)
-
     @property
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
