Commit 1a31dc6

Add exponential bucketing integration (#642)

Requires HabanaAI/vllm-hpu-extension#61

Co-authored-by: Iryna Boiko <iboiko@habana.ai>

1 parent 9a06a89

File tree

6 files changed
+19 -52 lines changed

.github/workflows/add_label_automerge.yml

Lines changed: 0 additions & 21 deletions
This file was deleted.

README_GAUDI.md

Lines changed: 2 additions & 1 deletion

@@ -343,7 +343,8 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi
 - `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory dedicated for prompt graphs, `0.3` by default.
 - `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt graph capture, `min_tokens` or `max_bs`, `min_tokens` by default.
 - `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode graph capture, `min_tokens` or `max_bs`, `max_bs` by default.
-- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism.
+- `VLLM_EXPONENTIAL_BUCKETING`, if `true`, enables exponential bucket spacing instead of linear (experimental).
+- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism (linear bucketing only).
 - `{phase}` is either `PROMPT` or `DECODE`
 - `{dim}` is either `BS`, `SEQ` or `BLOCK`
 - `{param}` is either `MIN`, `STEP` or `MAX`
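To make the new option concrete: linear bucketing enumerates arithmetic steps between MIN and MAX, while exponential bucketing spaces buckets geometrically, giving fewer buckets overall and finer granularity near the small end. A minimal illustrative sketch follows; it is not the vllm-hpu-extension implementation, whose bucket generation is more involved:

# Illustrative only -- a toy comparison of linear vs. exponential bucket
# spacing; the real bucket generation lives in vllm-hpu-extension.


def linear_buckets(bmin: int, bstep: int, bmax: int) -> list[int]:
    # Arithmetic progression: MIN, MIN+STEP, ..., up to MAX.
    return list(range(bmin, bmax + 1, bstep))


def exponential_buckets(bmin: int, bmax: int, points: int) -> list[int]:
    # Geometric progression: `points` values spaced exponentially between
    # MIN and MAX (rounded to ints and de-duplicated).
    ratio = (bmax / bmin) ** (1 / (points - 1))
    return sorted({round(bmin * ratio**i) for i in range(points)})


if __name__ == "__main__":
    print(linear_buckets(128, 128, 4096))       # 32 evenly spaced buckets
    print(exponential_buckets(128, 4096, 10))   # 10 buckets, denser near 128

The likely trade-off is fewer buckets to warm up and capture, at the cost of coarser padding at the large end, which is presumably why the flag is marked experimental.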

requirements-hpu.txt

Lines changed: 1 addition & 1 deletion

@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@3e0fb39
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@bcfa409

vllm/core/scheduler.py

Lines changed: 6 additions & 19 deletions

@@ -132,25 +132,12 @@ def _generic_padding_fn(self, batch_size, max_seq_len) -> int:
         return batch_size * max_seq_len
 
     def _hpu_padding_fn(self, batch_size, max_seq_len):
-        from vllm_hpu_extension.bucketing import (HPUBucketingGlobalState,
-                                                  find_bucket)
-        padded_bs = batch_size
-        padded_seq = max_seq_len
-
-        hpu_bucketing_global_state = HPUBucketingGlobalState()
-
-        bs_cfg = hpu_bucketing_global_state.prompt_bs_bucket_cfg
-        if bs_cfg is not None:
-            padded_bs = find_bucket(batch_size, bs_cfg)
-        else:
-            logger.warning(
-                "prompt_bs_bucket_cfg was not set! Using unpadded batch size.")
-        seq_cfg = hpu_bucketing_global_state.prompt_seq_bucket_cfg
-        if seq_cfg is not None:
-            padded_seq = find_bucket(max_seq_len, seq_cfg)
-        else:
-            logger.warning("prompt_seq_bucket_cfg was not set! "
-                           "Using unpadded sequence length.")
+        from vllm_hpu_extension.bucketing.common import get_bucketing_context
+        hpu_bucketing_context = get_bucketing_context().get_instance()
+        padded_bs = hpu_bucketing_context.get_padded_prompt_batch_size(
+            batch_size)
+        padded_seq = hpu_bucketing_context.get_padded_prompt_seq_len(
+            max_seq_len)
         return padded_bs * padded_seq
 
     def _padding_fn_selector(self):
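The replacement keeps the same contract: _hpu_padding_fn still returns the token footprint of the batch after batch size and sequence length are padded up to their bucket boundaries, it just delegates the lookups to the extension's bucketing context instead of reading global bucket configs directly. A standalone toy illustration of that computation (pad_to_bucket below is a hypothetical stand-in for the extension's get_padded_prompt_batch_size / get_padded_prompt_seq_len helpers):

# Toy illustration of what _hpu_padding_fn computes: the token footprint of a
# prompt batch once batch size and sequence length are padded up to bucket
# boundaries. The real lookups go through the extension's bucketing context;
# pad_to_bucket here is a hypothetical stand-in.


def pad_to_bucket(value: int, buckets: list[int]) -> int:
    # Round `value` up to the smallest bucket that can hold it.
    return min(b for b in buckets if b >= value)


def hpu_padded_tokens(batch_size: int, max_seq_len: int,
                      bs_buckets: list[int], seq_buckets: list[int]) -> int:
    padded_bs = pad_to_bucket(batch_size, bs_buckets)
    padded_seq = pad_to_bucket(max_seq_len, seq_buckets)
    return padded_bs * padded_seq


# 3 prompts, longest 500 tokens, power-of-two buckets: the scheduler budgets
# for 4 x 512 = 2048 padded tokens rather than the raw 3 x 500 = 1500.
assert hpu_padded_tokens(3, 500, [1, 2, 4, 8], [128, 256, 512, 1024]) == 2048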

vllm/v1/worker/hpu_model_runner.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,6 @@
 import torch
 import torch.distributed
 import vllm_hpu_extension.environment as environment
-from vllm_hpu_extension.bucketing import HPUBucketingContext
 from vllm_hpu_extension.flags import enabled_flags
 from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
 
@@ -45,6 +44,7 @@
 
 if TYPE_CHECKING:
     from vllm.v1.core.scheduler import SchedulerOutput
+from vllm_hpu_extension.bucketing.common import get_bucketing_context
 
 logger = init_logger(__name__)
 
@@ -705,6 +705,7 @@ def __init__(
         self.seen_configs: set = set()
         if self.enable_bucketing:
             logger.info("Bucketing is ON.")
+            HPUBucketingContext = get_bucketing_context()
             self.bucketing_ctx = HPUBucketingContext(
                 self.max_num_seqs, self.max_prefill_batch_size,
                 self.block_size, self.scheduler_config.max_num_batched_tokens,
@@ -1917,7 +1918,6 @@ def warmup_model(self) -> None:
             logger.info("Skipping warmup...")
             return
         max_blocks = kv_caches[0][0].size(0)
-        self.bucketing_ctx.generate_prompt_buckets()
         self.bucketing_ctx.generate_decode_buckets(max_blocks)
 
         if not htorch.utils.internal.is_lazy(
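The same pattern appears in this v1 runner and in the legacy worker runner below: HPUBucketingContext is no longer imported at module scope; the class is resolved at construction time via get_bucketing_context(), which leaves the choice of bucketing implementation to the extension. A hypothetical sketch of such a factory, assuming it keys off VLLM_EXPONENTIAL_BUCKETING; the class names, default value, and selection logic here are guesses, not the extension's actual code:

# Hypothetical sketch of a get_bucketing_context()-style factory. The real
# factory lives in vllm_hpu_extension.bucketing.common and may differ.
import os


class LinearBucketingContext:          # placeholder for the linear flavour
    def __init__(self, *args, **kwargs):
        self.args = args


class ExponentialBucketingContext:     # placeholder for the exponential flavour
    def __init__(self, *args, **kwargs):
        self.args = args


def get_bucketing_context():
    # Return the *class*; the caller instantiates it with its own arguments.
    # Defaulting to linear here is an assumption made for the sketch.
    use_exp = os.environ.get("VLLM_EXPONENTIAL_BUCKETING",
                             "false").lower() == "true"
    return ExponentialBucketingContext if use_exp else LinearBucketingContext


# Caller side, mirroring the model runners:
HPUBucketingContext = get_bucketing_context()
bucketing_ctx = HPUBucketingContext(128, 16, 128, 8192)   # toy arguments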

vllm/worker/hpu_model_runner.py

Lines changed: 8 additions & 8 deletions

@@ -22,7 +22,7 @@
 import habana_frameworks.torch.internal.bridge_config as bc
 import torch
 import vllm_hpu_extension.environment as environment
-from vllm_hpu_extension.bucketing import HPUBucketingContext
+from vllm_hpu_extension.bucketing.common import get_bucketing_context
 from vllm_hpu_extension.flags import enabled_flags
 from vllm_hpu_extension.ops import LoraMask as LoraMask
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
@@ -690,11 +690,13 @@ def __init__(
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
         self.seen_configs: set = set()
         self._mem_margin: Optional[int] = None
+        HPUBucketingContext = get_bucketing_context()
         self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
                                                  self.max_num_prefill_seqs,
                                                  self.block_size,
                                                  self.max_num_batched_tokens,
-                                                 self.use_merged_prefill)
+                                                 self.use_merged_prefill,
+                                                 self.max_model_len)
         self.graphed_buckets: Set[Any] = set()
 
         self._set_gc_threshold()
@@ -1958,7 +1960,6 @@ def profile_run(self) -> None:
         _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
         max_batch_size = min(self.max_num_seqs,
                              self.max_num_batched_tokens // max_seq_len)
-
         self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                              False, True)
         return
@@ -2188,6 +2189,10 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
 
     @torch.inference_mode()
     def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
+        if not self.is_pooler:
+            max_blocks = kv_caches[0][0].size(0)
+            self.bucketing_ctx.generate_decode_buckets(max_blocks)
+
         if profile := os.environ.get('VLLM_PT_PROFILE', None):
             phase, bs, seq_len, graph = profile.split('_')
             is_prompt = phase == 'prompt'
@@ -2197,11 +2202,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
                                  True)
             raise AssertionError("Finished profiling")
-        if not self.is_pooler:
-            max_blocks = kv_caches[0][0].size(0)
-            self.bucketing_ctx.generate_prompt_buckets()
-        if not self.is_pooler:
-            self.bucketing_ctx.generate_decode_buckets(max_blocks)
         if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
             multiplier = 3 if os.getenv('VLLM_REGIONAL_COMPILATION',
                                         'true').lower() == 'true' else 1
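Two behavioural notes on this hunk: the explicit generate_prompt_buckets() call disappears from warmup (as it did in the v1 runner above), presumably because the new bucketing context handles prompt-bucket generation itself, and generate_decode_buckets(max_blocks) now runs at the top of warmup_model, before the VLLM_PT_PROFILE early exit, so decode buckets exist even for a single profiled scenario. Decode buckets depend on max_blocks because the block dimension cannot exceed the number of KV-cache blocks actually allocated. A toy sketch of that dependency (not the extension's implementation; the candidate list and capping rule are invented for illustration):

import torch

# Toy sketch only: decode block buckets are capped by the number of KV-cache
# blocks that were actually allocated, which is why they are generated inside
# warmup_model, after the cache exists. Not vllm-hpu-extension's code.


def generate_decode_block_buckets(max_blocks: int,
                                  candidates: list[int]) -> list[int]:
    buckets = [b for b in candidates if b <= max_blocks]
    if not buckets or buckets[-1] != max_blocks:
        buckets.append(max_blocks)      # always cover the true maximum
    return buckets


# kv_caches[0][0] is a per-layer cache tensor; dim 0 is the block count.
kv_cache_key = torch.empty(1000, 4, 4)              # toy shape, 1000 blocks
max_blocks = kv_cache_key.size(0)
print(generate_decode_block_buckets(max_blocks, [128, 256, 512, 1024, 2048]))
# -> [128, 256, 512, 1000]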
