Skip to content
10 changes: 7 additions & 3 deletions vllm_gaudi/extension/defragmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class OnlineDefragmenter:
def __init__(self):
config = get_config()
self.threshold = with_default(config.VLLM_DEFRAG_THRESHOLD, 32)
self.to_swap_pad_thresholds = [8, 16, 32, 64, 128, 256, 512]
self.used_blocks = {}
self.req_blocks = {}
self.fwd_mapping_table = []
Expand Down Expand Up @@ -174,14 +175,15 @@ def defragment(self):
max_used = max(self.used_blocks.keys())
num_used = len(self.used_blocks)
pre_max_used = max_used
# Use threshold for fragmentation trigger
if max_used - self.threshold <= num_used:
return
free = self.free_blocks()
used = sorted(self.used_blocks.keys(), reverse=True)

to_swap: list[tuple[int, int]] = []
for used_block, free_block in zip(used, free):
if len(to_swap) == self.threshold or free_block > used_block:
if len(to_swap) == self.to_swap_pad_thresholds[-1] or free_block > used_block:
break
assert used_block in self.used_blocks
assert free_block not in self.used_blocks
Expand All @@ -195,9 +197,11 @@ def defragment(self):
self.update_mapping(orig_free_block, used_block)

assert self.cache_utils is not None
self.cache_utils.swap(to_swap, self.threshold)
to_swap_pad = next((x for x in self.to_swap_pad_thresholds if x >= len(to_swap)),
self.to_swap_pad_thresholds[-1])
self.cache_utils.swap(to_swap, to_swap_pad)
if self.debug:
max_used = max(self.used_blocks.keys())
num_used = len(self.used_blocks)
post_status = f'max_id_used={pre_max_used}->{max_used} num_used={num_used} swapped={len(to_swap)}/{self.threshold}'
post_status = f'max_id_used={pre_max_used}->{max_used} num_used={num_used} swapped={len(to_swap)}/{to_swap_pad}'
self.debug(f'defragmentation done {post_status}')
42 changes: 42 additions & 0 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3508,6 +3508,47 @@ def warmup_sampler(self):

logger.info("Sampler warmup completed successfully")

def warmup_defragmenter(self):
"""Warm up defragmentation swap graphs for different thresholds.

We execute a minimal swap (1 pair) which will be padded internally to the
requested threshold size. Thresholds chosen to mirror potential production
values: 8, 16, 32, 64, 128, 256, 512.
"""
# If defragmenter is disabled or cache utils not prepared, skip.
if not getattr(self.defragmenter, 'enabled', False):
return
if self.defragmenter.cache_utils is None:
return

thresholds = self.defragmenter.to_swap_pad_thresholds

logger.info("Warming up defragmenter with thresholds: %s", thresholds)

# Use simple valid block ids present in caches (assume at least 2 blocks allocated when kv caches created)
# We only need distinct ids for a swap. They will be scaled by block_size inside swap.
# If for some reason only 1 block exists, skip warmup gracefully.
try:
k_cache = self.defragmenter.cache_utils.kv_caches[0][0]
num_blocks_available = k_cache.shape[0] // self.block_size
except Exception:
num_blocks_available = 0
if num_blocks_available < 2:
logger.warning("Skipping defragmenter warmup, insufficient blocks (%s)", num_blocks_available)
return

# Minimal pair to trigger a swap path
to_swap = [(1, 0)]

for th in thresholds:
self.defragmenter.cache_utils.swap(to_swap, th)

# If the number of swaps was odd, do one more to make it even and return to original state.
if len(thresholds) % 2 == 1:
self.defragmenter.cache_utils.swap(to_swap, thresholds[0])

logger.info("Defragmenter warmup completed successfully")

def warmup_graphs(self, buckets, is_prompt, kv_caches, starting_mem=0, total_batch_seq=0.001):
total_mem = starting_mem
idx = 0
Expand Down Expand Up @@ -3787,6 +3828,7 @@ def warmup_model(self) -> None:
"to be called before warming up the model.")

self.warmup_sampler()
self.warmup_defragmenter()

# TODO(kzawora): align_workers
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
Expand Down