[V1][TPU] Support V1 Sampler for ragged attention #14227
Merged: robertgshaw2-redhat merged 22 commits into vllm-project:main from NickLucche:tpu-sampler-ragged on Mar 20, 2025.
Commits (22, showing changes from all commits), all authored by NickLucche:
        
        Select commit
          Hold shift + click to select a range
      
- 4baf255 xla friendly minp+topk
- 0f50a2a fix platform check
- 4be47d6 tracing sampler
- d53f7c8 forward_tpu + revert topk selection
- e9e48f1 refactor to avoid recompiling None values and disable topk/topp
- b84baa9 tests + updated torch dev version
- 25dead0 wip: adapt to ragged attn kernel
- deda02a adapt to ragged attn kernel and multimodal
- 317714e add tests
- be7bcec break up model|sample graph to speed up compilation
- 07d9a1f minor check on optional temp
- e65c55d fix greedy sampling
- ade6054 move tpu sampling params in own file
- 178f104 address review
- 76460c6 rebase cruft
- d1f79a5 max_num_tokens stopping condition when compiling
- 4596951 rebase changes
- 10e1a04 fix capture_graph loop
- 50ef555 fix recompilation issue on sampling graph; add new tpu sampler
- 92d23cd newline conflict(?)
- c2b5760 Merge branch 'main' into tpu-sampler-ragged
- 4d6d30c revert gpu sampler change
  
    
    
  
  
    
New file (+94 lines):
# SPDX-License-Identifier: Apache-2.0
import tempfile
from time import time

import pytest

from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams

if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )


@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch):
    """
    Check that no recompilation happens despite changing sampling parameters.
    We can't read XLA metrics from the engine process, hence we measure time.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
        # Compiling model init may still take some time; enforce_eager to skip.
        llm = LLM(model_name,
                  enforce_eager=True,
                  max_num_seqs=16,
                  max_model_len=1024,
                  gpu_memory_utilization=0.5)
        prompts = [
            "A robot may not injure a human being",
            "It is only with the heart that one can see rightly;",
        ]
        # First inference should be slow
        sampling_params = SamplingParams(
            temperature=0.7,
            # top_p=0.6, # TODO too slow!
            # top_k=10,
            min_p=0.2,
            max_tokens=16)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run1 = time() - s

        # Second request with different params, which we already
        # compiled for in the previous eager iteration.
        sampling_params = SamplingParams(temperature=0.1,
                                         min_p=0.8,
                                         max_tokens=24)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run2 = time() - s
        # Much faster after compiling
        assert run1 * 0.1 > run2
        print("TIMES", run1, run2)

        # Third request with min_p set to "None". It will not trigger
        # recompilation as a default 0 value will be used.
        sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run3 = time() - s
        assert run1 * 0.1 > run3
        print("TIMES", run1, run3)


@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_different(model_name: str):
    """
    Test significantly different sampling params to assert the model produces
    different results.
    """
    llm = LLM(
        model_name,
        enforce_eager=True,
        max_num_seqs=1,
        max_model_len=64,
        # TODO: setting to 0.5 or it will go OOM
        gpu_memory_utilization=0.5)
    prompts = [
        "Write a short story about a robot that dreams for the first time."
    ]
    sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
    output = llm.generate(prompts, sampling_params)

    sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
    output2 = llm.generate(prompts, sampling_params)
    assert output[0].outputs[0].text != output2[0].outputs[0].text
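
The timing assertions above (e.g. `run1 * 0.1 > run2`, i.e. later runs must be at least 10x faster) hold because the sampling graph is traced once over fixed-shape tensors, and later requests only change tensor values. As a rough illustration of why substituting a neutral default (such as `min_p=0.0`) for an unset parameter keeps the same graph, here is a minimal, CPU-runnable sketch; it is not code from this PR, and `apply_min_p` is a hypothetical helper:

```python
import torch


def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    """Branchless min-p filtering: drop tokens whose probability is below
    min_p * max_prob. With min_p == 0 every token passes, so the same
    computation doubles as an identity op (illustrative sketch only)."""
    probs = torch.softmax(logits, dim=-1)
    max_prob = probs.max(dim=-1, keepdim=True).values
    threshold = min_p.unsqueeze(-1) * max_prob
    return logits.masked_fill(probs < threshold, float("-inf"))


logits = torch.randn(2, 8)
min_p_on = torch.tensor([0.2, 0.2])   # filtering active
min_p_off = torch.tensor([0.0, 0.0])  # default value: acts as a no-op
assert torch.equal(apply_min_p(logits, min_p_off), logits)
print(apply_min_p(logits, min_p_on))
```

Because no Python-level branch depends on the parameter value, the traced graph stays identical whether or not the user actually set `min_p`.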
  
    
    
  
  
    
              
              Empty file.
          
    
  
    
    
  
  
    
New file (+159 lines):
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch_xla.core.xla_model as xm

from vllm.v1.sample.metadata import SamplingMetadata


@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.
    temperature: torch.Tensor

    min_p: torch.Tensor
    # Still too slow on forward_native!
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

    # XLA-unfriendly control flow in Sampler
    all_greedy: bool = False
    all_random: bool = False
    # Greedy sampling flag for compiling single xla graph.
    do_argmax: torch.Tensor = None

    # speculation not supported
    spec_token_ids = None

    # Generator not supported by xla
    generators: dict[int,
                     torch.Generator] = field(default_factory=lambda: dict())

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=lambda: list())

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(
        default_factory=lambda: list())

    allowed_token_ids_mask = None
    bad_words_token_ids = None
    indices_do_sample: torch.Tensor = None

    def __post_init__(self):
        temp = self.temperature
        if self.indices_do_sample is None:
            self.indices_do_sample = torch.zeros(temp.shape[0],
                                                 device=temp.device,
                                                 dtype=torch.int32)
        if self.do_argmax is None:
            self.do_argmax = torch.tensor(0,
                                          dtype=torch.bool,
                                          device=temp.device)

    @classmethod
    def from_sampling_metadata(
            cls, metadata: SamplingMetadata,
            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
            device: torch.device) -> "TPUSupportedSamplingMetadata":
        """
        Create an XLA-friendly SamplingMetadata structure. Do so by first
        instantiating an object with fixed-sized tensors and then writing the
        values in input `metadata`. Do that only for non-None values so that
        recompilation is not triggered for optional values (None/torch.Tensor).

        In order to handle different sizes for the params that range from 1 up
        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
        Same thing for `padded_do_sample_indices`, which contains the indices
        to be fed to the Sampler, padded to the closest pre-compiled shape.

        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
        do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
        """
        metadata = cls._validate_sampling_metadata(metadata)
        # NOTE we have to initialize default tensor-based params first and
        # skip None values altogether to produce the same xla graph.
        num_samples = len(padded_do_sample_indices)
        do_argmax = torch.tensor(metadata.all_greedy,
                                 dtype=torch.bool,
                                 device=device)
        new_metadata = cls.get_default_sampling_params(num_samples, device,
                                                       indices_do_sample=\
                                                       padded_do_sample_indices,
                                                       do_argmax=do_argmax
                                                       )
        supported_params = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        # Copy input non-None values into `new_metadata` fixed-sized tensors.
        for p_name in supported_params:
            old_val = getattr(metadata, p_name)
            new_val = getattr(new_metadata, p_name)
            if isinstance(old_val, torch.Tensor):
                new_val[:num_do_sample] = old_val
                setattr(new_metadata, p_name, new_val)

        xm.mark_step()
        xm.wait_device_ops()
        return new_metadata

    @classmethod
    def get_default_sampling_params(
            cls,
            num_samples: int,
            device: torch.device,
            indices_do_sample=None,
            do_argmax=None) -> "TPUSupportedSamplingMetadata":
        # As sampling happens on a single traced graph, options
        # are "disabled" by having them evaluate to an Identity op.
        # Note that initialization is dependent on num_samples.
        sampling_metadata_disable_value = \
            TPUSupportedSamplingMetadata._get_default_params_values()
        init_kwargs = dict()
        for p_name, (default_val,
                     dtype) in sampling_metadata_disable_value.items():
            default_tensor = torch.full((num_samples, ),
                                        default_val,
                                        dtype=dtype,
                                        device=device)
            init_kwargs[p_name] = default_tensor

        return cls(**init_kwargs,
                   indices_do_sample=indices_do_sample,
                   do_argmax=do_argmax)

    @staticmethod
    def _validate_sampling_metadata(
            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
        if sampling_metadata.all_greedy:
            # Set to None since #13587. Make sure default isn't overruled.
            assert sampling_metadata.temperature is None
        return sampling_metadata

    @staticmethod
    def _get_default_params_values():
        return dict(
            # Since #13587 greedy sampling requires branching off which leads
            # to separate graphs. We set temp to noop and handle argmax here.
            temperature=(1.0, torch.float32),
            min_p=(0.0, torch.float32),
            # strictly disabled for now
            # top_k=(-1, torch.int32),
            # top_p=(0.0, torch.float32),
            # frequency_penalties=(0.0, torch.float32),
            # presence_penalties=(0.0, torch.float32),
            # repetition_penalties=(0.0, torch.float32),
        )
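
The `do_argmax` field above carries the greedy/random decision as a tensor flag rather than Python control flow, so a single XLA graph can serve both cases. Below is a minimal, CPU-runnable sketch of that pattern; it assumes the flag is consumed with a `torch.where`-style select downstream, which is an illustration of the idea rather than code from this PR:

```python
import torch


def sample(logits: torch.Tensor, temperature: torch.Tensor,
           do_argmax: torch.Tensor) -> torch.Tensor:
    """Single-graph sampling sketch: compute both the greedy and the random
    outcome, then select with a tensor flag so no Python branch depends on
    the flag's value (illustrative only)."""
    probs = torch.softmax(logits / temperature.unsqueeze(-1), dim=-1)
    random_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
    greedy_ids = logits.argmax(dim=-1)
    return torch.where(do_argmax, greedy_ids, random_ids)


logits = torch.randn(4, 16)
temperature = torch.full((4, ), 0.7)
print(sample(logits, temperature, do_argmax=torch.tensor(True)))   # greedy
print(sample(logits, temperature, do_argmax=torch.tensor(False)))  # random
```

Both outcomes are always computed and only the selected result changes, which is what keeps the traced graph independent of the flag's value.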
      