diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
new file mode 100644
index 000000000000..76b8ddb92b78
--- /dev/null
+++ b/tests/v1/tpu/test_sampler.py
@@ -0,0 +1,94 @@
+# SPDX-License-Identifier: Apache-2.0
+import tempfile
+from time import time
+
+import pytest
+
+from vllm import LLM, envs
+from vllm.platforms import current_platform
+from vllm.sampling_params import SamplingParams
+
+if not envs.VLLM_USE_V1:
+    pytest.skip(
+        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
+        allow_module_level=True,
+    )
+
+
+@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This test needs a TPU")
+def test_sampler_compilation(model_name: str, monkeypatch):
+    """
+    Check that no recompilation happens despite changing sampling parameters.
+    We can't read XLA metrics from the engine process, hence we measure time.
+    """
+    with tempfile.TemporaryDirectory() as temp_dir:
+        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
+        # Compiling the model at init may still take some time; enforce_eager skips it.
+        llm = LLM(model_name,
+                  enforce_eager=True,
+                  max_num_seqs=16,
+                  max_model_len=1024,
+                  gpu_memory_utilization=0.5)
+        prompts = [
+            "A robot may not injure a human being",
+            "It is only with the heart that one can see rightly;",
+        ]
+        # First inference should be slow
+        sampling_params = SamplingParams(
+            temperature=0.7,
+            # top_p=0.6, # TODO too slow!
+            # top_k=10,
+            min_p=0.2,
+            max_tokens=16)
+        s = time()
+        _ = llm.generate(prompts, sampling_params)
+        run1 = time() - s
+
+        # Second request with different params, for which we already
+        # compiled in the previous eager iteration.
+        sampling_params = SamplingParams(temperature=0.1,
+                                         min_p=0.8,
+                                         max_tokens=24)
+        s = time()
+        _ = llm.generate(prompts, sampling_params)
+        run2 = time() - s
+        # Much faster after compiling
+        assert run1 * 0.1 > run2
+        print("TIMES", run1, run2)
+
+        # Third request with min_p left unset. It will not trigger
+        # recompilation, as a default value of 0 will be used.
+        sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
+        s = time()
+        _ = llm.generate(prompts, sampling_params)
+        run3 = time() - s
+        assert run1 * 0.1 > run3
+        print("TIMES", run1, run3)
+
+
+@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This test needs a TPU")
+def test_sampler_different(model_name: str):
+    """
+    Test significantly different sampling params to assert the model produces
+    different results.
+    """
+    llm = LLM(
+        model_name,
+        enforce_eager=True,
+        max_num_seqs=1,
+        max_model_len=64,
+        # TODO: keep this at 0.5 or it will go OOM
+        gpu_memory_utilization=0.5)
+    prompts = [
+        "Write a short story about a robot that dreams for the first time."
+    ]
+    sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
+    output = llm.generate(prompts, sampling_params)
+
+    sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
+    output2 = llm.generate(prompts, sampling_params)
+    assert output[0].outputs[0].text != output2[0].outputs[0].text
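
Note on the test above: it infers "no recompilation" from wall-clock time because XLA metrics cannot be read from the engine process. When everything runs in a single process, a more direct check is possible; the sketch below is illustrative only (not part of this patch), `compile_count` is a hypothetical helper, and it assumes `torch_xla.debug.metrics` is importable on the host.

# Hedged sketch: count XLA compilations directly instead of timing runs.
import torch_xla.debug.metrics as met

def compile_count() -> int:
    # "CompileTime" is recorded once per graph compilation; metric_data is
    # expected to return (count, total, samples), or None if never recorded.
    data = met.metric_data("CompileTime")
    return 0 if data is None else data[0]

# before = compile_count()
# llm.generate(prompts, sampling_params)
# assert compile_count() == before, "unexpected recompilation"
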
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index d461a8098933..e1a3e92de493 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -65,6 +65,8 @@ def __init__(self):
                     "native implementation of top-p & top-k sampling. For the "
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
+        elif current_platform.is_tpu():
+            self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
 
@@ -96,6 +98,18 @@ def forward_cuda(
             return random_sample(probs, generators)
         return flashinfer_sample(probs, k, p, generators)
 
+    def forward_tpu(
+        self,
+        logits: torch.Tensor,
+        generators: dict[int, torch.Generator],
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # TODO: placeholder until a TPU-optimized top-k/top-p kernel is added.
+        # logits = apply_top_k_top_p(logits, k, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
+
 
 def apply_top_k_top_p(
     logits: torch.Tensor,
@@ -112,7 +126,7 @@ def apply_top_k_top_p(
 
     if k is not None:
         # Apply top-k.
-        top_k_mask = logits_sort.size(1) - k.to(torch.long)
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
         # Get all the top_k values.
         top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
         top_k_mask = logits_sort < top_k_mask
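
The `forward_tpu` added above intentionally ignores `k` and `p` for now. As a minimal sketch, and purely as an assumption about a possible follow-up (not a committed design), the masking path could simply reuse this module's own `apply_top_k_top_p` once it is fast enough on TPU:

    def forward_tpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Hypothetical follow-up: mask the logits before sampling.
        logits = apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)
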
diff --git a/vllm/v1/sample/tpu/__init__.py b/vllm/v1/sample/tpu/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
new file mode 100644
index 000000000000..b4f7c19a8d3d
--- /dev/null
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -0,0 +1,159 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+import torch_xla.core.xla_model as xm
+
+from vllm.v1.sample.metadata import SamplingMetadata
+
+
+@dataclass
+class TPUSupportedSamplingMetadata:
+    # This class exposes a more xla-friendly interface than SamplingMetadata
+    # on TPU, in particular all arguments should be traceable and no optionals
+    # are allowed, to avoid graph recompilation on Nones.
+    temperature: torch.Tensor
+
+    min_p: torch.Tensor
+    # Still too slow on forward_native!
+    top_k: torch.Tensor = None
+    top_p: torch.Tensor = None
+
+    # XLA-unfriendly control flow in Sampler
+    all_greedy: bool = False
+    all_random: bool = False
+    # Greedy sampling flag for compiling a single xla graph.
+    do_argmax: torch.Tensor = None
+
+    # speculation not supported
+    spec_token_ids = None
+
+    # Generator not supported by xla
+    generators: dict[int,
+                     torch.Generator] = field(default_factory=lambda: dict())
+
+    # unsupported, you need to return an extra tensor of static size BxV
+    max_num_logprobs = None
+
+    # TODO No penalties for now
+    no_penalties: bool = True
+    prompt_token_ids = None
+    frequency_penalties = None
+    presence_penalties = None
+    repetition_penalties = None
+    # should use tensor
+    output_token_ids: list[list[int]] = field(default_factory=lambda: list())
+
+    min_tokens = None  # impl is not vectorized
+
+    logit_bias: list[Optional[dict[int, float]]] = field(
+        default_factory=lambda: list())
+
+    allowed_token_ids_mask = None
+    bad_words_token_ids = None
+    indices_do_sample: torch.Tensor = None
+
+    def __post_init__(self):
+        temp = self.temperature
+        if self.indices_do_sample is None:
+            self.indices_do_sample = torch.zeros(temp.shape[0],
+                                                 device=temp.device,
+                                                 dtype=torch.int32)
+        if self.do_argmax is None:
+            self.do_argmax = torch.tensor(0,
+                                          dtype=torch.bool,
+                                          device=temp.device)
+
+    @classmethod
+    def from_sampling_metadata(
+            cls, metadata: SamplingMetadata,
+            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
+            device: torch.device) -> "TPUSupportedSamplingMetadata":
+        """
+        Create an XLA-friendly SamplingMetadata structure. Do so by first
+        instantiating an object with fixed-sized tensors and then writing the
+        values in input `metadata`. Do that only for non-None values so that
+        recompilation is not triggered for optional values (None/torch.Tensor).
+
+        In order to handle different sizes for the params that range from 1 up
+        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
+        Same thing for `padded_do_sample_indices`, which contains the indices
+        to be fed to the Sampler, padded to the closest pre-compiled shape.
+
+        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
+            do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
+        """
+        metadata = cls._validate_sampling_metadata(metadata)
+        # NOTE we have to initialize default tensor-based params first and
+        # skip None values altogether to produce the same xla graph.
+        num_samples = len(padded_do_sample_indices)
+        do_argmax = torch.tensor(metadata.all_greedy,
+                                 dtype=torch.bool,
+                                 device=device)
+        new_metadata = cls.get_default_sampling_params(num_samples, device,
+                                                       indices_do_sample=\
+                                                       padded_do_sample_indices,
+                                                       do_argmax=do_argmax
+                                                       )
+        supported_params = \
+            TPUSupportedSamplingMetadata._get_default_params_values()
+        # Copy input non-None values into `new_metadata` fixed-sized tensors.
+        for p_name in supported_params:
+            old_val = getattr(metadata, p_name)
+            new_val = getattr(new_metadata, p_name)
+            if isinstance(old_val, torch.Tensor):
+                new_val[:num_do_sample] = old_val
+                setattr(new_metadata, p_name, new_val)
+
+        xm.mark_step()
+        xm.wait_device_ops()
+        return new_metadata
+
+    @classmethod
+    def get_default_sampling_params(
+            cls,
+            num_samples: int,
+            device: torch.device,
+            indices_do_sample=None,
+            do_argmax=None) -> "TPUSupportedSamplingMetadata":
+        # As sampling happens on a single traced graph, options
+        # are "disabled" by having them evaluate to an Identity op.
+        # Note that initialization is dependent on num_samples.
+        sampling_metadata_disable_value = \
+            TPUSupportedSamplingMetadata._get_default_params_values()
+        init_kwargs = dict()
+        for p_name, (default_val,
+                     dtype) in sampling_metadata_disable_value.items():
+            default_tensor = torch.full((num_samples, ),
+                                        default_val,
+                                        dtype=dtype,
+                                        device=device)
+            init_kwargs[p_name] = default_tensor
+
+        return cls(**init_kwargs,
+                   indices_do_sample=indices_do_sample,
+                   do_argmax=do_argmax)
+
+    @staticmethod
+    def _validate_sampling_metadata(
+            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
+        if sampling_metadata.all_greedy:
+            # Set to None since #13587. Make sure default isn't overruled.
+            assert sampling_metadata.temperature is None
+        return sampling_metadata
+
+    @staticmethod
+    def _get_default_params_values():
+        return dict(
+            # Since #13587 greedy sampling requires branching off which leads
+            # to separate graphs. We set temp to noop and handle argmax here.
+            temperature=(1.0, torch.float32),
+            min_p=(0.0, torch.float32),
+            # strictly disabled for now
+            # top_k=(-1, torch.int32),
+            # top_p=(0.0, torch.float32),
+            # frequency_penalties=(0.0, torch.float32),
+            # presence_penalties=(0.0, torch.float32),
+            # repetition_penalties=(0.0, torch.float32),
+        )
\ No newline at end of file
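
A hedged usage sketch of `TPUSupportedSamplingMetadata.from_sampling_metadata`, illustrating the padding described in its docstring. All tensor values are made up, and `metadata` and `device` are assumed to already be in scope (a `SamplingMetadata` instance and a TPU device).

# Two real requests padded to a pre-compiled batch of four.
padded_indices = torch.tensor([4, 10, 0, 0], dtype=torch.int32, device=device)
tpu_meta = TPUSupportedSamplingMetadata.from_sampling_metadata(
    metadata, padded_indices, num_do_sample=2, device=device)
# Real slots keep the values from `metadata`; padded slots keep the
# "identity" defaults from _get_default_params_values(), e.g.:
#   tpu_meta.temperature -> [0.7, 0.2, 1.0, 1.0]
#   tpu_meta.min_p       -> [0.2, 0.0, 0.0, 0.0]
# so the same traced graph handles any request count up to the bucket size.
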
diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py
new file mode 100644
index 000000000000..33526c003a24
--- /dev/null
+++ b/vllm/v1/sample/tpu/sampler.py
@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Sampler layer implementing TPU supported operations."""
+
+import torch
+import torch.nn as nn
+
+from vllm.v1.outputs import LogprobsTensors, SamplerOutput
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
+from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
+
+_SAMPLING_EPS = 1e-5
+
+
+class Sampler(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.topk_topp_sampler = TopKTopPSampler()
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: TPUSupportedSamplingMetadata,
+    ) -> SamplerOutput:
+        # NOTE(woosuk): Use the original logits (before any penalties or
+        # temperature scaling) for the top-k logprobs.
+        # This is different from the V0 sampler, which uses the logits that
+        # are used for sampling (after penalties and temperature scaling).
+
+        # Use float32 for the logits.
+        logits = logits.to(torch.float32)
+        # Sample the next token.
+        sampled = self.sample(logits, sampling_metadata)
+
+        # Use int32 to reduce the tensor size.
+        sampled = sampled.to(torch.int32)
+
+        # These are device tensors.
+        sampler_output = SamplerOutput(
+            # The sampled tokens are expanded to 2D tensor with shape
+            # [num_requests, 1], where each row represents one generated
+            # token per request.
+            sampled_token_ids=sampled.unsqueeze(-1),
+            logprobs_tensors=None,
+        )
+        return sampler_output
+
+    def apply_temperature(
+        self,
+        logits: torch.Tensor,
+        temp: torch.Tensor,
+    ) -> torch.Tensor:
+        # Use in-place division to avoid creating a new tensor.
+        return logits.div_(temp.unsqueeze(dim=1))
+
+    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
+        return logits.argmax(dim=-1).view(-1)
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: TPUSupportedSamplingMetadata,
+    ) -> torch.Tensor:
+        greedy_sampled = self.greedy_sample(logits)
+
+        assert sampling_metadata.temperature is not None
+
+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+
+        # Apply min_p.
+        if sampling_metadata.min_p is not None:
+            logits = self.apply_min_p(logits, sampling_metadata.min_p)
+
+        # Apply top_k and/or top_p.
+        random_sampled = self.topk_topp_sampler(
+            logits,
+            sampling_metadata.generators,
+            sampling_metadata.top_k,
+            sampling_metadata.top_p,
+        )
+
+        sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
+                              greedy_sampled, random_sampled)
+        return sampled
+
+    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
+        return logits.log_softmax(dim=-1, dtype=torch.float32)
+
+    def gather_logprobs(
+        self,
+        logprobs: torch.Tensor,
+        num_logprobs: int,
+        token_ids: torch.Tensor,
+    ) -> LogprobsTensors:
+        """
+        Gather logprobs for topk and sampled/prompt token.
+
+        Args:
+          logprobs: (num tokens) x (vocab) tensor
+          num_logprobs: minimum number of logprobs to
+                        retain per token
+          token_ids: prompt tokens (if prompt logprobs)
+                     or sampled tokens (if sampled
+                     logprobs); 1D token ID tensor
+                     with (num tokens) elements
+
+        Returns:
+          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
+          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
+          Sampled token rank tensor, (num tokens)
+        """
+        # Find the topK values.
+        topk_logprobs, topk_indices = torch.topk(logprobs,
+                                                 num_logprobs,
+                                                 dim=-1)
+
+        # Get the logprob of the prompt or sampled token.
+        token_ids = token_ids.unsqueeze(-1)
+        token_logprobs = logprobs.gather(-1, token_ids)
+
+        # Compute the ranks of the actual token.
+        token_ranks = (logprobs >= token_logprobs).sum(-1)
+
+        # Concatenate together with the topk.
+        indices = torch.cat((token_ids, topk_indices), dim=1)
+        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
+
+        # Use int32 to reduce the tensor size.
+        indices = indices.to(torch.int32)
+
+        return LogprobsTensors(indices, logprobs, token_ranks)
+
+    def apply_min_p(
+        self,
+        logits: torch.Tensor,
+        min_p: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Filters logits using adaptive probability thresholding.
+        """
+        # Convert logits to probability distribution
+        probability_values = torch.nn.functional.softmax(logits, dim=-1)
+        # Calculate maximum probabilities per sequence
+        max_probabilities = torch.amax(probability_values,
+                                       dim=-1,
+                                       keepdim=True)
+        # Reshape min_p for broadcasting
+        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+        # Identify valid tokens using threshold comparison
+        valid_token_mask = probability_values >= adjusted_min_p
+        # Mask out tokens below the threshold in place (xla friendly)
+        logits.masked_fill_(~valid_token_mask, -float("inf"))
+        return logits
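
A small numeric walk-through of `apply_min_p` above, with made-up values: with min_p = 0.5, any token whose probability is below half of the per-row maximum is masked to -inf.

import torch

logits = torch.log(torch.tensor([[0.5, 0.3, 0.2]]))  # probs are 0.5/0.3/0.2
min_p = torch.tensor([0.5])
probs = torch.softmax(logits, dim=-1)
threshold = min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)  # 0.25
masked = logits.masked_fill(probs < threshold, -float("inf"))
# Token 2 (p=0.2 < 0.25) is masked out; tokens 0 and 1 stay sampleable.
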
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 00869467be34..ae697888301a 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -23,13 +23,16 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
+                                               PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
-                             ModelRunnerOutput)
+                             ModelRunnerOutput, SamplerOutput)
+from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
+from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
@@ -42,6 +45,8 @@
 # FIXME(woosuk): Find a more reliable way to prevent possible bugs.
 _PAD_SLOT_ID = 1_000_000_000
 INVALID_TOKEN_ID = -1
+# Smallest padded number of sequences the sampler graphs are compiled for.
+MIN_NUM_SEQS = 8
 
 
 class TPUModelRunner:
@@ -138,8 +143,10 @@ def __init__(
                                             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
+        padded_max_num_blocks_per_req = _get_padded_number(
+            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, self.max_num_blocks_per_req),
+            (self.max_num_tokens, padded_max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
 
@@ -267,6 +274,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 req_data.num_computed_tokens)
             self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                     req_index)
+        # Check if the batch has changed. If not, we can skip copying the
+        # sampling metadata from CPU to the device.
+        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
@@ -284,6 +294,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         # Condense the batched states if there are empty indices.
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
+
+        # TODO This slices tensors to copy to device, triggering recompilation.
+        if batch_changed:
+            self.input_batch.refresh_sampling_metadata()
         return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
 
     def get_model(self) -> nn.Module:
@@ -444,6 +458,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # TODO: Support prompt logprobs.
         padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs)
+        # Indices at which we sample (positions of last token in the sequence).
+        # Padded to avoid recompiling when `num_reqs` varies.
         logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
         logits_indices = logits_indices.to(self.device)
         return attn_metadata, logits_indices
@@ -573,7 +589,14 @@ def execute_model(
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids
             inputs_embeds = None
-
+        sampling_metadata = self.input_batch.sampling_metadata
+        num_reqs = self.input_batch.num_reqs
+        # NOTE (NickLucche) here we sync with the TPU: if there's any shape
+        # mismatch in pre-processing, it will trigger a small recompilation
+        # of the code thus far. The forward graph remains untouched.
+        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
+            from_sampling_metadata(sampling_metadata, logits_indices,
+                                   num_reqs, self.device)
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
@@ -582,12 +605,13 @@ def execute_model(
                 kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
-        num_reqs = self.input_batch.num_reqs
-        selected_token_ids = self.model.compute_logits(hidden_states,
-                                                       logits_indices, None)
+        selected_token_ids = self.model.sample_from_hidden(
+            hidden_states, tpu_sampling_metadata)
+        # Remove padding on the CPU and keep the dynamic op outside the xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
-        # Then, let's update the cache state.
+        # Update the cache state concurrently. The code above will not block
+        # until we use `selected_token_ids`; add a mark_step if post-processing changes.
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
@@ -604,7 +628,6 @@ def execute_model(
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
 
-        # num_reqs entries should be non-None
         assert all(
             req_id is not None for req_id in
             self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
@@ -617,6 +640,7 @@ def execute_model(
         max_gen_len = selected_token_ids.shape[-1]
         if max_gen_len == 1:
             valid_sampled_token_ids = selected_token_ids.tolist()
+
             for i, req_state, seq_len in request_seq_lens:
                 token_id = valid_sampled_token_ids[i][0]
                 self.input_batch.token_ids_cpu[i, seq_len] = token_id
@@ -673,11 +697,8 @@ def load_model(self) -> None:
                 fullgraph=True,
                 dynamic=False)
 
-    def _dummy_run(
-        self,
-        kv_caches,
-        num_tokens: int,
-    ) -> None:
+    @torch.no_grad()
+    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -726,32 +747,10 @@ def _dummy_run(
 
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
 
         with set_forward_context(attn_metadata, self.vllm_config, 0):
-            assert self.model is not None
-            hidden_states = self.model(
-                input_ids=input_ids,
-                positions=position_ids,
-                kv_caches=kv_caches,
-                inputs_embeds=inputs_embeds,
-            )
-            num_reqs = _get_padded_num_reqs_with_upper_limit(
-                64, self.max_num_reqs)
-            # NOTE(chengjiyao): In total, the compute_logits function utilizes a
-            # compilation cache size of token_bucket_num multiplied by
-            # req_bucket_num. This is acceptable, given the graph's relatively
-            # small size.
-            while True:
-                logits_indices = torch.zeros(
-                    num_reqs,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
-                torch._dynamo.mark_dynamic(hidden_states, 0)
-                torch._dynamo.mark_dynamic(logits_indices, 0)
-                self.model.compute_logits(hidden_states, logits_indices, None)
-                if num_reqs >= self.max_num_reqs:
-                    break
-                num_reqs = _get_padded_num_reqs_with_upper_limit(
-                    num_reqs + 1, self.max_num_reqs)
+            self.model(input_ids=input_ids,
+                       positions=position_ids,
+                       kv_caches=kv_caches,
+                       inputs_embeds=inputs_embeds)
 
     def capture_model(self) -> None:
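
For reference, a tiny illustration of the `logits_indices` computed in `_prepare_inputs` above (`query_start_loc[1:padded_num_reqs + 1] - 1` picks each sequence's last-token position). The numbers are invented, and the value assumed for padded slots is illustrative only; whatever the padding holds is discarded on the host via `selected_token_ids.cpu()[:num_reqs]`.

import torch

# Three scheduled requests with 3, 4 and 5 tokens, padded to 4 requests.
query_start_loc = torch.tensor([0, 3, 7, 12, 12], dtype=torch.int32)
padded_num_reqs = 4
logits_indices = query_start_loc[1:padded_num_reqs + 1] - 1
# tensor([ 2,  6, 11, 11]) -> last-token rows to sample from; the 4th entry
# is padding (assumed here to repeat the last offset) and is dropped later.
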
@@ -761,13 +760,51 @@ def capture_model(self) -> None:
         start = time.perf_counter()
         num_tokens = 16
         while True:
-            self._dummy_run(self.kv_caches, num_tokens)
             logger.info("  -- num_tokens: %d", num_tokens)
+            self._dummy_run(self.kv_caches, num_tokens)
             xm.mark_step()
-            xm.wait_device_ops()
             if num_tokens >= self.max_num_tokens:
                 break
             num_tokens *= 2
+        xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
+
+        logger.info("Compiling sampling with different input shapes.")
+        start = time.perf_counter()
+        num_tokens = 16
+        hsize = self.model_config.get_hidden_size()
+        device = self.device
+        # Compile the sampling step for the model+sampler outputs in bucketed
+        # num_tokens x num_reqs shapes. The graph is really small, so this is fine.
+        while True:
+            num_reqs_to_sample = MIN_NUM_SEQS
+            dummy_hidden = torch.randn((num_tokens, hsize),
+                                       device=device,
+                                       dtype=torch.bfloat16)
+            while True:
+                # Default metadata is an all_greedy setup. But since the
+                # `do_argmax` flag is a tensor, we still compile the full graph.
+                meta = self.input_batch.sampling_metadata
+                indices = torch.zeros(
+                    num_reqs_to_sample,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                sampling_meta = TPUSupportedSamplingMetadata.\
+                    from_sampling_metadata(meta, indices,
+                                           num_reqs_to_sample, device)
+                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
+                            num_reqs_to_sample)
+                self.model.sample_from_hidden(dummy_hidden, sampling_meta)
+                xm.mark_step()
+                if num_reqs_to_sample >= self.max_num_reqs:
+                    break
+                num_reqs_to_sample *= 2
+            if num_tokens >= self.max_num_tokens:
+                break
+            num_tokens *= 2
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
 
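
To make the new warm-up above concrete, here is a sketch of the (num_tokens, num_reqs) grid that the nested loops pre-compile, under hypothetical limits max_num_tokens=256 and max_num_reqs=32 (the real values come from the engine configuration):

MIN_NUM_SEQS = 8
MAX_NUM_TOKENS, MAX_NUM_REQS = 256, 32   # hypothetical limits

shapes = []
num_tokens = 16
while True:
    num_reqs_to_sample = MIN_NUM_SEQS
    while True:
        shapes.append((num_tokens, num_reqs_to_sample))
        if num_reqs_to_sample >= MAX_NUM_REQS:
            break
        num_reqs_to_sample *= 2
    if num_tokens >= MAX_NUM_TOKENS:
        break
    num_tokens *= 2
# shapes == [(16, 8), (16, 16), (16, 32), (32, 8), ..., (256, 32)]:
# 5 token buckets x 3 request buckets = 15 small sampling graphs.
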
""" - assert self.model is not None hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -843,17 +886,33 @@ def forward( return hidden_states - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def compute_logits( + def sample_from_hidden( self, hidden_states: torch.Tensor, - logits_indices: torch.Tensor, - sampling_metadata, - ) -> Optional[torch.Tensor]: - hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(hidden_states, sampling_metadata) - selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - return selected_token_ids + sampling_metadata: TPUSupportedSamplingMetadata, + ) -> torch.Tensor: + """ + Sample with xla-friendly function. This function is to be traced + separately from `forward` for lighter compilation overhead. + """ + # Tensor `sample_hidden_states` is of fixed pre-compiled size. + sample_hidden_states = \ + hidden_states[sampling_metadata.indices_do_sample] + logits = self.compute_logits(sample_hidden_states) + # Greedy sampling can't be run without branching the graph on Sampler. + # Therefore do_argmax/all_greedy is checked here in a xla-friendly way. + # NOTE do_argmax is a scalar, this is just an optimized if/else. + out_tokens = torch.where(sampling_metadata.do_argmax, + torch.argmax(logits, dim=-1, keepdim=True), + self.sample(logits, sampling_metadata)\ + .sampled_token_ids) + return out_tokens + + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + # SamplingMetadata here for pruning output in LogitsProcessor, disabled + logits = self.model.compute_logits(hidden_states, None) + return logits def get_multimodal_embeddings(self, *args, **kwargs): return self.model.get_multimodal_embeddings(*args, **kwargs) @@ -873,5 +932,5 @@ def _get_padded_token_len(x: int) -> int: def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int: - res = 64 if x <= 64 else 1 << (x - 1).bit_length() + res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length() return min(res, upper_limit)