diff --git a/vllm_gaudi/extension/defragmentation.py b/vllm_gaudi/extension/defragmentation.py
index c93cd14a..c431dc02 100644
--- a/vllm_gaudi/extension/defragmentation.py
+++ b/vllm_gaudi/extension/defragmentation.py
@@ -60,6 +60,7 @@ class OnlineDefragmenter:
     def __init__(self):
         config = get_config()
         self.threshold = with_default(config.VLLM_DEFRAG_THRESHOLD, 32)
+        self.to_swap_pad_thresholds = [8, 16, 32, 64, 128, 256, 512]
         self.used_blocks = {}
         self.req_blocks = {}
         self.fwd_mapping_table = []
@@ -174,6 +175,7 @@ def defragment(self):
         max_used = max(self.used_blocks.keys())
         num_used = len(self.used_blocks)
         pre_max_used = max_used
+        # Defragment only when the highest used block id exceeds the number of used blocks by more than the threshold.
         if max_used - self.threshold <= num_used:
             return
         free = self.free_blocks()
@@ -181,7 +183,7 @@
         to_swap: list[tuple[int, int]] = []
         for used_block, free_block in zip(used, free):
-            if len(to_swap) == self.threshold or free_block > used_block:
+            if len(to_swap) == self.to_swap_pad_thresholds[-1] or free_block > used_block:
                 break
             assert used_block in self.used_blocks
             assert free_block not in self.used_blocks
@@ -195,9 +197,11 @@
             self.update_mapping(orig_free_block, used_block)
         assert self.cache_utils is not None
-        self.cache_utils.swap(to_swap, self.threshold)
+        to_swap_pad = next((x for x in self.to_swap_pad_thresholds if x >= len(to_swap)),
+                           self.to_swap_pad_thresholds[-1])
+        self.cache_utils.swap(to_swap, to_swap_pad)

         if self.debug:
             max_used = max(self.used_blocks.keys())
             num_used = len(self.used_blocks)
-            post_status = f'max_id_used={pre_max_used}->{max_used} num_used={num_used} swapped={len(to_swap)}/{self.threshold}'
+            post_status = f'max_id_used={pre_max_used}->{max_used} num_used={num_used} swapped={len(to_swap)}/{to_swap_pad}'
             self.debug(f'defragmentation done {post_status}')
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 137abe69..a58e19af 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -3508,6 +3508,47 @@ def warmup_sampler(self):
         logger.info("Sampler warmup completed successfully")

+    def warmup_defragmenter(self):
+        """Warm up defragmentation swap graphs for each padding threshold.
+
+        Executes a minimal swap (a single pair) per threshold; the swap is
+        padded internally to the requested threshold size. The thresholds
+        mirror the padding buckets used in production: 8, 16, 32, 64, 128, 256, 512.
+        """
+        # Skip if the defragmenter is disabled or its cache utils are not prepared.
+        if not getattr(self.defragmenter, 'enabled', False):
+            return
+        if self.defragmenter.cache_utils is None:
+            return
+
+        thresholds = self.defragmenter.to_swap_pad_thresholds
+
+        logger.info("Warming up defragmenter with thresholds: %s", thresholds)
+
+        # Use simple block ids that are valid in the caches; at least 2 blocks are
+        # normally allocated when the KV caches are created. A swap only needs two
+        # distinct ids, scaled by block_size inside swap(); with fewer, skip gracefully.
+        try:
+            k_cache = self.defragmenter.cache_utils.kv_caches[0][0]
+            num_blocks_available = k_cache.shape[0] // self.block_size
+        except Exception:
+            num_blocks_available = 0
+        if num_blocks_available < 2:
+            logger.warning("Skipping defragmenter warmup, insufficient blocks (%s)", num_blocks_available)
+            return
+
+        # A single pair is enough to exercise the swap path.
+        to_swap = [(1, 0)]
+
+        for th in thresholds:
+            self.defragmenter.cache_utils.swap(to_swap, th)
+
+        # If the number of swaps was odd, perform one more so the cache contents return to their original state.
+        if len(thresholds) % 2 == 1:
+            self.defragmenter.cache_utils.swap(to_swap, thresholds[0])
+
+        logger.info("Defragmenter warmup completed successfully")
+
     def warmup_graphs(self, buckets, is_prompt, kv_caches, starting_mem=0, total_batch_seq=0.001):
         total_mem = starting_mem
         idx = 0
@@ -3787,6 +3828,7 @@ def warmup_model(self) -> None:
                                  "to be called before warming up the model.")
         self.warmup_sampler()
+        self.warmup_defragmenter()

         # TODO(kzawora): align_workers
         mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
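
Note on the padding scheme above: picking to_swap_pad from a fixed bucket list
bounds the number of distinct swap shapes (and thus compiled HPU swap graphs)
to the seven listed values, instead of one shape per observed swap length. A
minimal sketch of the selection logic, with pad_to_bucket as a hypothetical
helper name not present in the patch:

    def pad_to_bucket(n: int, buckets: list[int]) -> int:
        # Smallest bucket that fits n; the loop in defragment() already caps
        # len(to_swap) at buckets[-1], so the fallback only guards oversize n.
        return next((b for b in buckets if b >= n), buckets[-1])

    buckets = [8, 16, 32, 64, 128, 256, 512]
    assert pad_to_bucket(1, buckets) == 8
    assert pad_to_bucket(9, buckets) == 16
    assert pad_to_bucket(512, buckets) == 512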
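
The restore step at the end of warmup_defragmenter relies on a swap of a fixed
pair being its own inverse: an even total swap count leaves the KV cache
unchanged. A toy model under that assumption (the real cache_utils.swap moves
block-sized tensor slices, not list entries):

    def swap(cache: list[str], pairs: list[tuple[int, int]]) -> None:
        for a, b in pairs:
            cache[a], cache[b] = cache[b], cache[a]

    cache = ['block0', 'block1']
    thresholds = [8, 16, 32, 64, 128, 256, 512]
    for _ in thresholds:          # one warmup swap per pad size (7 total)
        swap(cache, [(1, 0)])
    if len(thresholds) % 2 == 1:  # odd count: swap once more to restore
        swap(cache, [(1, 0)])
    assert cache == ['block0', 'block1']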