     format_bytes, setup_profiler)
 from vllm_gaudi.extension.runtime import finalize_config, get_config
 from vllm_gaudi.extension.unified import (create_unified_batch)
-from vllm_gaudi.extension.utils import pad_list
+from vllm_gaudi.extension.utils import pad_list, with_default
 from vllm_gaudi.extension.debug import init_debug_logger
 
 from vllm.attention.backends.abstract import AttentionType
@@ -807,8 +807,7 @@ def __init__(
         self.use_hpu_graph = not self.model_config.enforce_eager
         self.max_batch_size = self.scheduler_config.max_num_seqs
         self.max_num_seqs = self.scheduler_config.max_num_seqs
-        # TODO(kzawora): add knob for that
-        self.max_prefill_batch_size = 1
+        self.max_prefill_batch_size = with_default(get_config().VLLM_PROMPT_BS_BUCKET_MAX, 1)
         self.seen_configs: set = set()
         self.max_num_batched_tokens = \
             self.scheduler_config.max_num_batched_tokens
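
The hunk above swaps the hard-coded prefill batch size of 1 for the VLLM_PROMPT_BS_BUCKET_MAX knob, via the newly imported with_default helper. A minimal sketch of the intended fallback behavior, assuming with_default simply returns the configured value unless it is unset (the real helper lives in vllm_gaudi.extension.utils; _CfgUnset and _CfgSet below are hypothetical stand-ins for get_config()):

# Sketch only: stand-in for vllm_gaudi.extension.utils.with_default, assuming it
# returns the configured value when present and the default when the knob is None.
def with_default(value, default):
    return value if value is not None else default

# Hypothetical config objects: one without the knob set, one with it set.
class _CfgUnset:
    VLLM_PROMPT_BS_BUCKET_MAX = None

class _CfgSet:
    VLLM_PROMPT_BS_BUCKET_MAX = 8

assert with_default(_CfgUnset.VLLM_PROMPT_BS_BUCKET_MAX, 1) == 1  # old behavior preserved
assert with_default(_CfgSet.VLLM_PROMPT_BS_BUCKET_MAX, 1) == 8    # knob takes effect
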
@@ -1684,6 +1683,10 @@ def _form_prefill_batch(self, contents):
         # for the valid tokens before padding.
         # This would require getting multimodal input embeddings here as well
         token_ids = self._align_and_pad(contents.token_ids, (target_bs, target_seq), itertools.repeat(-1))
+        # Update query_lens and context_lens after padding
+        query_lens.extend([0] * (target_bs - len(query_lens)))
+        context_lens.extend([0] * (target_bs - len(context_lens)))
+
         # If the model uses M-RoPE, we need to fill
         # and pad the M-RoPE positions for the scheduled prefill tokens
         if self.uses_mrope:
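
The two new extend calls keep per-request metadata aligned with the padded batch: once token_ids has been padded out to target_bs rows, query_lens and context_lens must also reach length target_bs, with zeros for the padding rows. A small self-contained illustration with made-up values:

# Illustrative values only (not taken from the runner): pad the per-request
# length lists with zeros so they match the padded batch size.
target_bs = 4
query_lens = [17, 5]     # real requests in this prefill batch
context_lens = [32, 0]

query_lens.extend([0] * (target_bs - len(query_lens)))
context_lens.extend([0] * (target_bs - len(context_lens)))

assert query_lens == [17, 5, 0, 0]
assert context_lens == [32, 0, 0, 0]
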
@@ -3819,13 +3822,14 @@ def _prepare_dummy_scenario(self, prompt_cfg, decode_cfg):
             prompt_total_tokens, prompt_num_context_blocks = \
                 self.get_merged_prefill_seq_lens(prompt_query_len,
                                                  prompt_num_blocks)
-            for tokens, context_len in zip(prompt_total_tokens, prompt_num_context_blocks):
-                self._add_dummy_request(requests,
-                                        scheduled_tokens,
-                                        num_computed_tokens=(context_len * self.block_size),
-                                        total_tokens=tokens,
-                                        scheduled_tokens=prompt_query_len,
-                                        is_prompt=True)
+            for _ in range(prompt_bs):
+                for tokens, context_len in zip(prompt_total_tokens, prompt_num_context_blocks):
+                    self._add_dummy_request(requests,
+                                            scheduled_tokens,
+                                            num_computed_tokens=(context_len * self.block_size),
+                                            total_tokens=tokens,
+                                            scheduled_tokens=prompt_query_len,
+                                            is_prompt=True)
         if decode_cfg:
             decode_bs, decode_query_len, decode_num_blocks = decode_cfg
             if self.use_contiguous_pa:
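
Wrapping the dummy-request loop in for _ in range(prompt_bs) makes the warmup scenario populate every prompt batch slot instead of a single one, so a prompt config with batch size prompt_bs now yields prompt_bs dummy requests per merged-sequence layout. A rough sketch of the resulting request count, using stand-in values and a simplified record in place of self._add_dummy_request:

# Stand-in values and a simplified request record; the real runner builds
# full dummy requests via self._add_dummy_request.
prompt_bs = 3
prompt_total_tokens = [128, 256]
prompt_num_context_blocks = [0, 1]
block_size = 128

requests = []
for _ in range(prompt_bs):
    for tokens, context_len in zip(prompt_total_tokens, prompt_num_context_blocks):
        requests.append({
            "num_computed_tokens": context_len * block_size,
            "total_tokens": tokens,
        })

# One dummy request per batch slot and per merged-sequence layout.
assert len(requests) == prompt_bs * len(prompt_total_tokens)
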