
Commit 8c20fce

ksmusz and xuechendi authored
Enable modification of prompt BS (#258)
Enable modification of the prefill batch size via the `VLLM_PROMPT_BS_BUCKET_MAX` environment variable. The default prefill batch size is still 1, the same as before this change.

cherry-pick: #224

Signed-off-by: Krzysztof Smusz <ksmusz@habana.ai>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
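A minimal usage sketch, assuming `VLLM_PROMPT_BS_BUCKET_MAX` is read from the process environment at engine startup; the value and model name below are placeholders:

import os

# Hypothetical example: allow prefill batches of up to 4 sequences.
# If the variable is unset, the runner keeps the previous default of 1.
os.environ["VLLM_PROMPT_BS_BUCKET_MAX"] = "4"

from vllm import LLM  # imported after setting the variable so the config can pick it up

llm = LLM(model="placeholder/model")        # placeholder model name
outputs = llm.generate(["Hello"] * 4)       # prompts that could now share one prefill batch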
1 parent 5aea2f6 commit 8c20fce

File tree

1 file changed: +14 additions, -10 deletions


vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 14 additions & 10 deletions
@@ -23,7 +23,7 @@
                                             format_bytes, setup_profiler)
 from vllm_gaudi.extension.runtime import finalize_config, get_config
 from vllm_gaudi.extension.unified import (create_unified_batch)
-from vllm_gaudi.extension.utils import pad_list
+from vllm_gaudi.extension.utils import pad_list, with_default
 from vllm_gaudi.extension.debug import init_debug_logger
 
 from vllm.attention.backends.abstract import AttentionType
@@ -807,8 +807,7 @@ def __init__(
         self.use_hpu_graph = not self.model_config.enforce_eager
         self.max_batch_size = self.scheduler_config.max_num_seqs
         self.max_num_seqs = self.scheduler_config.max_num_seqs
-        # TODO(kzawora): add knob for that
-        self.max_prefill_batch_size = 1
+        self.max_prefill_batch_size = with_default(get_config().VLLM_PROMPT_BS_BUCKET_MAX, 1)
         self.seen_configs: set = set()
         self.max_num_batched_tokens = \
             self.scheduler_config.max_num_batched_tokens
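The new default resolution goes through `with_default` from `vllm_gaudi.extension.utils`, whose implementation is not shown in this diff; a minimal sketch of its assumed semantics (prefer the configured value when it is set, otherwise fall back to the default):

def with_default(value, default):
    # Assumed behavior: return the configured value if present, else the fallback.
    return value if value is not None else default

# With the change above, max_prefill_batch_size would resolve to
# VLLM_PROMPT_BS_BUCKET_MAX when the variable is set, and to 1 otherwise.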
@@ -1684,6 +1683,10 @@ def _form_prefill_batch(self, contents):
         # for the valid tokens before padding.
         # This would require getting multimodal input embeddings here as well
         token_ids = self._align_and_pad(contents.token_ids, (target_bs, target_seq), itertools.repeat(-1))
+        # Update query_lens and context_lens after padding
+        query_lens.extend([0] * (target_bs - len(query_lens)))
+        context_lens.extend([0] * (target_bs - len(context_lens)))
+
         # If the model uses M-RoPE, we need to fill
         # and pad the M-RoPE positions for the scheduled prefill tokens
         if self.uses_mrope:
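The added padding keeps query_lens and context_lens aligned with the padded (bucketed) batch size; a small illustrative example with made-up values:

query_lens = [5, 7]          # lengths for the two real requests
target_bs = 4                # padded batch size after bucketing
query_lens.extend([0] * (target_bs - len(query_lens)))
print(query_lens)            # [5, 7, 0, 0]; padded entries contribute zero tokens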
@@ -3819,13 +3822,14 @@ def _prepare_dummy_scenario(self, prompt_cfg, decode_cfg):
             prompt_total_tokens, prompt_num_context_blocks = \
                 self.get_merged_prefill_seq_lens(prompt_query_len,
                                                  prompt_num_blocks)
-            for tokens, context_len in zip(prompt_total_tokens, prompt_num_context_blocks):
-                self._add_dummy_request(requests,
-                                        scheduled_tokens,
-                                        num_computed_tokens=(context_len * self.block_size),
-                                        total_tokens=tokens,
-                                        scheduled_tokens=prompt_query_len,
-                                        is_prompt=True)
+            for _ in range(prompt_bs):
+                for tokens, context_len in zip(prompt_total_tokens, prompt_num_context_blocks):
+                    self._add_dummy_request(requests,
+                                            scheduled_tokens,
+                                            num_computed_tokens=(context_len * self.block_size),
+                                            total_tokens=tokens,
+                                            scheduled_tokens=prompt_query_len,
+                                            is_prompt=True)
         if decode_cfg:
             decode_bs, decode_query_len, decode_num_blocks = decode_cfg
             if self.use_contiguous_pa:
