Description
🚀 The feature, motivation and pitch
Currently we pad for cudagraphs:
vllm/vllm/v1/worker/gpu_model_runner.py, line 1497 at 853c371:

```python
num_input_tokens = self.vllm_config.pad_for_cudagraph(
```
after constructing the metadata:
vllm/vllm/v1/worker/gpu_model_runner.py, lines 820 to 909 at 853c371:

```python
for kv_cache_group_id, kv_cache_group_spec in enumerate(
        self.kv_cache_config.kv_cache_groups):
    if isinstance(kv_cache_group_spec.kv_cache_spec,
                  EncoderOnlyAttentionSpec):
        # Encoder-only layers do not have KV cache, so we need to
        # create a dummy block table and slot mapping for them.
        blk_table_tensor = torch.zeros(
            (num_reqs, 1),
            dtype=torch.int32,
            pin_memory=self.pin_memory,
            device="cpu").to(self.device, non_blocking=True)
        slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                   dtype=torch.int32,
                                   pin_memory=self.pin_memory,
                                   device="cpu").to(self.device,
                                                    non_blocking=True)
        num_common_prefix_blocks = 0
    else:
        blk_table = self.input_batch.block_table[kv_cache_group_id]
        blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
        slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
        # Fill unused with -1. Needed for reshape_and_cache in full cuda
        # graph mode.
        blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
        num_common_prefix_blocks = (
            scheduler_output.num_common_prefix_blocks[kv_cache_group_id])

    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_scheduled_tokens,
        max_query_len=max_num_scheduled_tokens,
        max_seq_len=max_seq_len,
        block_table_tensor=blk_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )

    if self.speculative_config and \
            spec_decode_common_attn_metadata is None:
        spec_decode_common_attn_metadata = common_attn_metadata

    for attn_group in self.attn_groups[kv_cache_group_id]:
        # Prepare for cascade attention if enabled & beneficial.
        common_prefix_len = 0
        builder = attn_group.metadata_builder
        if self.cascade_attn_enabled:
            common_prefix_len = self._compute_cascade_attn_prefix_len(
                num_scheduled_tokens,
                num_common_prefix_blocks,
                kv_cache_group_spec.kv_cache_spec,
                builder,
            )

        attn_metadata_i = (builder.build(
            common_prefix_len=common_prefix_len,
            common_attn_metadata=common_attn_metadata,
        ))

        fast_prefill_metadata = attn_metadata_i
        if (self.cache_config.kv_sharing_fast_prefill
                and self.kv_sharing_fast_prefill_eligible_layers):
            # Dynamically create a dataclass type that inherits
            # from attention metadata type but includes additional
            # fields logits_indices_padded and num_logits_indices
            # which are required for prefill truncation
            fast_prefill_metadata_type = (
                make_kv_sharing_fast_prefill_attention_metadata(
                    metadata_cls=type(attn_metadata_i), ))
            fast_prefill_metadata = fast_prefill_metadata_type(
                **dataclasses.asdict(attn_metadata_i),
                logits_indices_padded=logits_indices_padded,
                num_logits_indices=logits_indices.size(0),
            )

        for layer_name in attn_group.layer_names:
            if (self.cache_config.kv_sharing_fast_prefill
                    and layer_name
                    in self.kv_sharing_fast_prefill_eligible_layers):
                attn_metadata[layer_name] = fast_prefill_metadata
                continue
            attn_metadata[layer_name] = attn_metadata_i
```
This creates headaches for making attention backends with ahead-of-time schedulers cudagraph compatible, e.g. FlashMLA and FlashAttn, because the scheduler might be called with a batch size smaller than what the graph actually runs with; we recently had to work around this in FlashMLA (vllm-project/FlashMLA#3). Moving the padding calculation before metadata construction and passing a padded CommonAttentionMetadata is a much better long-term solution, since we won't have to work around this in future attention kernels.
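To make the pitch concrete, here is a minimal, self-contained sketch of the proposed ordering. It is illustrative only: `pad_for_cudagraph`, `capture_sizes`, `build_metadata`, and a `num_input_tokens` field on a `PaddedCommonAttentionMetadata` dataclass are placeholder names for this sketch, not the actual vLLM APIs or the final field layout of `CommonAttentionMetadata`.

```python
# Illustrative sketch only (not the actual vLLM code): compute the cudagraph
# padding *before* building the attention metadata, and carry the padded token
# count on the metadata so backends with ahead-of-time schedulers plan for the
# size the graph will actually replay.
import bisect
from dataclasses import dataclass


@dataclass
class PaddedCommonAttentionMetadata:
    num_actual_tokens: int  # tokens actually scheduled this step
    num_input_tokens: int   # padded count the cudagraph will run with (hypothetical field)


def pad_for_cudagraph(num_tokens: int, capture_sizes: list[int]) -> int:
    """Round up to the nearest captured cudagraph size (illustrative stand-in)."""
    idx = bisect.bisect_left(capture_sizes, num_tokens)
    return capture_sizes[idx] if idx < len(capture_sizes) else num_tokens


def build_metadata(total_num_scheduled_tokens: int,
                   capture_sizes: list[int]) -> PaddedCommonAttentionMetadata:
    # 1. Do the padding lookup first...
    num_input_tokens = pad_for_cudagraph(total_num_scheduled_tokens,
                                         capture_sizes)
    # 2. ...then build the metadata with both the actual and padded sizes, so
    #    per-backend builders never see a smaller batch than the graph runs with.
    return PaddedCommonAttentionMetadata(
        num_actual_tokens=total_num_scheduled_tokens,
        num_input_tokens=num_input_tokens,
    )


if __name__ == "__main__":
    meta = build_metadata(total_num_scheduled_tokens=37,
                          capture_sizes=[1, 2, 4, 8, 16, 32, 64, 128])
    print(meta)  # num_actual_tokens=37, num_input_tokens=64
```

With the padded count available on the metadata itself, backends that plan their work ahead of time (e.g. FlashMLA, FlashAttn) could size their plans against the replayed graph rather than the raw scheduled batch, instead of each backend patching around the mismatch individually.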
Alternatives
No response
Additional context
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.