
Commit 60808d7

ksmusz and mswiniarsk authored
Adding dynamic swap number and defragmenter warmup (#183)
Introduces dynamic swap buckets in the defragmenter, together with a defragmenter warmup. Previously, at most 32 blocks could be swapped in a single defragmenter iteration. This change adds a bucketing system that selects the smallest swap bucket large enough for the number of blocks that actually need to be swapped in the current iteration; bucket sizes range from 8 up to 512 swaps per defragmenter run. Because there are now several possible swap bucket sizes instead of one, a defragmenter warmup has been added, so no additional graph compilations related to the defragmenter occur during inference.

---------

Signed-off-by: Krzysztof Smusz <ksmusz@habana.ai>
Co-authored-by: Marcin Swiniarski <marcin.swiniarski@intel.com>
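A minimal sketch of the bucket-selection rule described above, assuming a standalone helper (select_swap_pad is an illustrative name, not part of the change):

# Illustrative sketch only: select_swap_pad mirrors the round-up rule from the
# commit description; it is not a function added by this commit.
SWAP_PAD_THRESHOLDS = [8, 16, 32, 64, 128, 256, 512]

def select_swap_pad(num_swaps: int) -> int:
    # Smallest bucket that fits all pending swaps, capped at the largest bucket.
    return next((t for t in SWAP_PAD_THRESHOLDS if t >= num_swaps), SWAP_PAD_THRESHOLDS[-1])

print(select_swap_pad(5))    # 8   -> small backlog uses the smallest bucket
print(select_swap_pad(100))  # 128 -> rounded up to the next bucket size
print(select_swap_pad(600))  # 512 -> capped at the largest bucket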
1 parent 50a6cb5 commit 60808d7

File tree

2 files changed: +49 -3 lines changed

vllm_gaudi/extension/defragmentation.py
vllm_gaudi/v1/worker/hpu_model_runner.py

vllm_gaudi/extension/defragmentation.py

Lines changed: 7 additions & 3 deletions
@@ -60,6 +60,7 @@ class OnlineDefragmenter:
     def __init__(self):
         config = get_config()
         self.threshold = with_default(config.VLLM_DEFRAG_THRESHOLD, 32)
+        self.to_swap_pad_thresholds = [8, 16, 32, 64, 128, 256, 512]
         self.used_blocks = {}
         self.req_blocks = {}
         self.fwd_mapping_table = []
@@ -174,14 +175,15 @@ def defragment(self):
         max_used = max(self.used_blocks.keys())
         num_used = len(self.used_blocks)
         pre_max_used = max_used
+        # Use threshold for fragmentation trigger
         if max_used - self.threshold <= num_used:
             return
         free = self.free_blocks()
         used = sorted(self.used_blocks.keys(), reverse=True)

         to_swap: list[tuple[int, int]] = []
         for used_block, free_block in zip(used, free):
-            if len(to_swap) == self.threshold or free_block > used_block:
+            if len(to_swap) == self.to_swap_pad_thresholds[-1] or free_block > used_block:
                 break
             assert used_block in self.used_blocks
             assert free_block not in self.used_blocks
@@ -195,9 +197,11 @@ def defragment(self):
             self.update_mapping(orig_free_block, used_block)

         assert self.cache_utils is not None
-        self.cache_utils.swap(to_swap, self.threshold)
+        to_swap_pad = next((x for x in self.to_swap_pad_thresholds if x >= len(to_swap)),
+                           self.to_swap_pad_thresholds[-1])
+        self.cache_utils.swap(to_swap, to_swap_pad)
         if self.debug:
             max_used = max(self.used_blocks.keys())
             num_used = len(self.used_blocks)
-            post_status = f'max_id_used={pre_max_used}->{max_used} num_used={num_used} swapped={len(to_swap)}/{self.threshold}'
+            post_status = f'max_id_used={pre_max_used}->{max_used} num_used={num_used} swapped={len(to_swap)}/{to_swap_pad}'
             self.debug(f'defragmentation done {post_status}')
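The point of padding to a fixed set of sizes, per the commit description, is that each distinct swap size can otherwise trigger its own graph compilation. A small self-contained simulation (hypothetical code, not using the real cache_utils) shows that arbitrary backlogs collapse onto at most the seven bucket sizes above:

import random

# Hypothetical simulation: round random swap backlogs up to the buckets used in
# defragment() and count how many distinct padded sizes ever appear.
SWAP_PAD_THRESHOLDS = [8, 16, 32, 64, 128, 256, 512]

def pad_size(num_swaps: int) -> int:
    return next((t for t in SWAP_PAD_THRESHOLDS if t >= num_swaps), SWAP_PAD_THRESHOLDS[-1])

random.seed(0)
seen = {pad_size(random.randint(1, 600)) for _ in range(10_000)}
print(sorted(seen))  # a subset of [8, 16, 32, 64, 128, 256, 512]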

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 42 additions & 0 deletions
@@ -3503,6 +3503,47 @@ def warmup_sampler(self):

         logger.info("Sampler warmup completed successfully")

+    def warmup_defragmenter(self):
+        """Warm up defragmentation swap graphs for different thresholds.
+
+        We execute a minimal swap (1 pair) which will be padded internally to the
+        requested threshold size. Thresholds chosen to mirror potential production
+        values: 8, 16, 32, 64, 128, 256, 512.
+        """
+        # If defragmenter is disabled or cache utils not prepared, skip.
+        if not getattr(self.defragmenter, 'enabled', False):
+            return
+        if self.defragmenter.cache_utils is None:
+            return
+
+        thresholds = self.defragmenter.to_swap_pad_thresholds
+
+        logger.info("Warming up defragmenter with thresholds: %s", thresholds)
+
+        # Use simple valid block ids present in caches (assume at least 2 blocks allocated when kv caches created)
+        # We only need distinct ids for a swap. They will be scaled by block_size inside swap.
+        # If for some reason only 1 block exists, skip warmup gracefully.
+        try:
+            k_cache = self.defragmenter.cache_utils.kv_caches[0][0]
+            num_blocks_available = k_cache.shape[0] // self.block_size
+        except Exception:
+            num_blocks_available = 0
+        if num_blocks_available < 2:
+            logger.warning("Skipping defragmenter warmup, insufficient blocks (%s)", num_blocks_available)
+            return
+
+        # Minimal pair to trigger a swap path
+        to_swap = [(1, 0)]
+
+        for th in thresholds:
+            self.defragmenter.cache_utils.swap(to_swap, th)
+
+        # If the number of swaps was odd, do one more to make it even and return to original state.
+        if len(thresholds) % 2 == 1:
+            self.defragmenter.cache_utils.swap(to_swap, thresholds[0])
+
+        logger.info("Defragmenter warmup completed successfully")
+
     def warmup_graphs(self, buckets, is_prompt, kv_caches, starting_mem=0, total_batch_seq=0.001):
         total_mem = starting_mem
         idx = 0
@@ -3782,6 +3823,7 @@ def warmup_model(self) -> None:
                                "to be called before warming up the model.")

         self.warmup_sampler()
+        self.warmup_defragmenter()

         # TODO(kzawora): align_workers
         mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
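One detail of warmup_defragmenter above: each warmup call really does swap blocks 1 and 0 in the KV cache, so the method appends one extra swap when the threshold list has odd length, leaving the cache exactly as it found it. A toy stand-in for the swap (hypothetical; the real work happens inside cache_utils.swap) illustrates the parity argument:

# Toy illustration: swapping the same pair an even number of times is a no-op.
blocks = ['block0-data', 'block1-data']

def toy_swap(pair):
    # Stand-in for cache_utils.swap, exercising only the (1, 0) pair used in warmup.
    a, b = pair
    blocks[a], blocks[b] = blocks[b], blocks[a]

thresholds = [8, 16, 32, 64, 128, 256, 512]  # 7 entries -> odd number of warmup swaps
for _ in thresholds:
    toy_swap((1, 0))
if len(thresholds) % 2 == 1:
    toy_swap((1, 0))  # one extra swap restores the original block layout
assert blocks == ['block0-data', 'block1-data']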
