[Core] Support full cuda graph in v1 #16072
The PR adds a new test file (97 lines) that compares full-cudagraph generation against the default piecewise compilation:

```python
# SPDX-License-Identifier: Apache-2.0
import contextlib
import os

import pytest

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

MODEL = "Qwen/Qwen2-1.5B-Instruct"


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Temporarily set environment variables and restore them afterward.
    We use this instead of monkeypatch because monkeypatch doesn't work
    with "module"-scoped fixtures.
    """
    original_env = {k: os.environ.get(k) for k in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for k, v in original_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v


@pytest.fixture(scope="module")
def full_cudagraph_llm():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.2,
                   compilation_config=CompilationConfig(full_cuda_graph=True))


@pytest.fixture(scope="module")
def piecewise_llm():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.5,
                   compilation_config=CompilationConfig())


def generate_text(llm: LLM, batch_size: int, max_tokens: int):
    prompts = ["Hi my name is"] * batch_size
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=max_tokens,
                                     top_p=0.95)

    return llm.generate(prompts, sampling_params)


@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                        (16, 10), (25, 10),
                                                        (32, 10), (45, 10),
                                                        (64, 10), (8, 5),
                                                        (8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
                        piecewise_llm):
    """
    Load the full-cudagraph model and the piecewise model once each, and
    reuse them across the test cases.

    Test various batch sizes and max_tokens to ensure that the full cudagraph
    compilation works for padded cases too.
    """
    piecewise_responses = generate_text(piecewise_llm,
                                        batch_size=batch_size,
                                        max_tokens=max_tokens)
    full_cudagraph_responses = generate_text(full_cudagraph_llm,
                                             batch_size=batch_size,
                                             max_tokens=max_tokens)

    # Check that all responses are the same
    for i in range(len(piecewise_responses)):
        assert piecewise_responses[i].outputs[
            0].text == full_cudagraph_responses[i].outputs[0].text


def test_full_cudagraph_with_invalid_backend():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "2"  # FA2 not supported with full_cuda_graph
    }), pytest.raises(RuntimeError):
        LLM(model=MODEL,
            compilation_config=CompilationConfig(full_cuda_graph=True))
```
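For reference, a minimal offline-inference sketch of the feature these tests exercise. The prompt and sampling settings are illustrative; the `full_cuda_graph` flag, the V1 engine requirement, and the FA3 requirement all come from this PR:

```python
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Full CUDA graph currently requires the V1 engine and FlashAttention 3
# (see the backend check added to the model runner below).
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",  # same small model the tests use
    compilation_config=CompilationConfig(full_cuda_graph=True),
)
outputs = llm.generate(["Hi my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```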
                              
      
Changes to the V1 GPU model runner:

```diff
@@ -12,6 +12,7 @@

 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
```
```diff
@@ -139,6 +140,16 @@ def __init__(
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 GPUModelRunner.")

+        if self.vllm_config.compilation_config.full_cuda_graph:
+            attn_backend_name = self.attn_backend.__name__
+            flash_attn_version = get_flash_attn_version()
+            if attn_backend_name != "FlashAttentionBackend" or \
+                    flash_attn_version != 3:
+                raise ValueError(
+                    f"full_cuda_graph is only supported with "
+                    f"FA3. Current attention backend is {attn_backend_name}, "
+                    f"FlashAttention version is {flash_attn_version}.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
```
```diff
@@ -219,6 +230,16 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)

         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
```
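These persistent device tensors are what makes it possible to capture attention inside the graph: CUDA graph replay reuses the exact memory addresses recorded at capture time, so per-step inputs must be copied into fixed buffers rather than freshly allocated. A minimal PyTorch sketch of that general pattern (generic illustration, not vLLM code; requires a CUDA device):

```python
import torch

# Pre-allocate a static input buffer; the captured graph records this address.
static_x = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = static_x * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = static_x * 2  # recorded against static_x's fixed address

# To run on new data: copy into the captured buffer (never reassign it),
# then replay the graph.
static_x.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
g.replay()
print(static_y)  # reflects the updated contents of static_x
```

This is why `_prepare_inputs` below switches from building fresh device tensors each step to `copy_`-ing into these buffers.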
```diff
@@ -271,7 +292,7 @@ def __init__(
                                              pin_memory=self.pin_memory)
         self.positions_np = self.positions_cpu.numpy()
         self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
+                                            dtype=torch.int64,
                                             device="cpu",
                                             pin_memory=self.pin_memory)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
```
```diff
@@ -589,10 +610,22 @@ def _prepare_inputs(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)

-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+
+        query_start_loc = self.query_start_loc[:num_reqs + 1]
+        seq_lens = self.seq_lens[:num_reqs]

         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
```
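The `-1` fill matters because, under a full CUDA graph, the captured kernels always run over the padded token count, so KV-cache writes for the padding must become no-ops; a negative slot index acts as the "skip this token" sentinel for `reshape_and_cache`. A toy illustration of the idea in plain PyTorch (hypothetical shapes and slot numbers, not the actual kernel):

```python
import torch

num_padded_tokens, num_actual_tokens, head_dim = 8, 5, 4

kv_cache = torch.zeros(16, head_dim)                # 16 cache slots
new_kv = torch.randn(num_padded_tokens, head_dim)   # padded batch of new K/V rows

# Real tokens get cache slots; the padded tail keeps the -1 sentinel.
slot_mapping = torch.full((num_padded_tokens,), -1, dtype=torch.int64)
slot_mapping[:num_actual_tokens] = torch.tensor([3, 7, 2, 9, 11])

# Write only the valid rows; padded rows are masked out instead of writing
# through whatever stale index was left in the buffer.
valid = slot_mapping >= 0
kv_cache[slot_mapping[valid]] = new_kv[valid]
```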
```diff
@@ -1478,6 +1511,7 @@ def _get_prompt_logprobs_dict(
     def _dummy_run(
         self,
         num_tokens: int,
+        skip_attn: bool = True,
     ) -> torch.Tensor:

         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
```
```diff
@@ -1494,6 +1528,23 @@ def _dummy_run(
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)

+        if skip_attn:
+            attn_metadata = None
+        else:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+            seq_lens = self.seq_lens[:num_reqs]
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=query_start_loc, seq_lens=seq_lens)
+
+            attn_metadata = self.attn_metadata_builder.build(
+                num_reqs=num_tokens,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
```
```diff
@@ -1522,7 +1573,7 @@ def _dummy_run(
                         for k, v in self.intermediate_tensors.items()
                     })

-            with set_forward_context(None,
+            with set_forward_context(attn_metadata,
                                      self.vllm_config,
                                      num_tokens=num_tokens):
                 outputs = model(
```

A reviewer asked on the `set_forward_context` change:

> Considering that […], full CUDA graph should therefore be incompatible with the kvconnector?
```diff
@@ -1708,11 +1759,12 @@ def capture_model(self) -> None:
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
+            skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                self._dummy_run(num_tokens, skip_attn=skip_attn)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
```
Review discussion on the `slot_mapping_cpu` dtype change (int32 to int64):

> @WoosukKwon it's because we call `.long()` here. We might want to still call it here, to keep the dtypes consistent in the model runner.

> @tlrmchlsmth what do you think about just making the CPU tensor int64 too? (That's the route I went with in the latest update on this PR.)

> Had to check - that takes the `slot_mapping` CPU-to-GPU transfer from 32 KB to 64 KB (by default serving on an H100). That seems fine to me, since now we don't do that copy in every layer.