@@ -3,7 +3,6 @@
 from typing import Optional
 
 import torch
-import torch_xla.core.xla_model as xm
 
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
@@ -24,15 +23,15 @@ class TPUSupportedSamplingMetadata:
     # This class exposes a more xla-friendly interface than SamplingMetadata
     # on TPU, in particular all arguments should be traceable and no optionals
     # are allowed, to avoid graph recompilation on Nones.
-    temperature: torch.Tensor
+    temperature: torch.Tensor = None
 
-    min_p: torch.Tensor
+    min_p: torch.Tensor = None
     # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None
 
     # Greedy sampling flag for compiling single xla graph.
-    all_greedy: torch.Tensor = None
+    all_greedy: bool = True
 
     # Generator not supported by xla
     generators: dict[int,
@@ -57,64 +56,58 @@ class TPUSupportedSamplingMetadata:
 
     allowed_token_ids_mask = None
     bad_words_token_ids = None
-    indices_do_sample: torch.Tensor = None
 
     @classmethod
     def from_input_batch(
-            cls, input_batch: InputBatch,
-            indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
+            cls,
+            input_batch: InputBatch,
+            padded_num_reqs: int,
+            xla_device: torch.device,
+            generate_params_if_all_greedy: bool = False
+    ) -> "TPUSupportedSamplingMetadata":
         """
         Copy sampling tensors slices from `input_batch` to on device tensors.
 
         `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
         slices dynamic shapes on device tensors. This impl moves the dynamic
-        ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
-        also reuses the on-device persistent tensors managed in `input_batch`
-        to reduce waste.
-
-        `indices_do_sample` contains the indices to be fed to the Sampler,
-        normally one per request, here padded to the closest pre-compiled shape
-        We expect sampling params tensors to be padded to the same fixed shape.
-
-        Eg. 3 requests, tensors padded to 4
-        temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
-        sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
+        ops to CPU and produces tensors of fixed `padded_num_reqs` size.
+
+        Args:
+            input_batch: The input batch containing sampling parameters.
+            padded_num_reqs: The padded number of requests.
+            xla_device: The XLA device.
+            generate_params_if_all_greedy: If True, generate sampling parameters
+                even if all requests are greedy. This is useful for cases where
+                we want to pre-compile a graph with sampling parameters, even if
+                they are not strictly needed for greedy decoding.
         """
+        # Early return to avoid an unnecessary CPU-to-TPU copy.
+        if (input_batch.all_greedy is True
+                and generate_params_if_all_greedy is False):
+            return cls(all_greedy=True)
+
         num_reqs = input_batch.num_reqs
-        padded_num_reqs = len(indices_do_sample)
 
-        def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
-                       fill_val) -> torch.Tensor:
-            # Copy slice from CPU to corresponding TPU pre-allocated tensor.
+        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             # Pad value is the default one.
             cpu_tensor[num_reqs:padded_num_reqs] = fill_val
-            # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
-            tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
 
-        # NOTE NickLucche The sync CPU-TPU graph we produce here must be
-        # consistent. We can't have flags to skip copies or we'll end up
-        # recompiling.
-        copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
+        fill_slice(input_batch.temperature_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["temperature"])
         # TODO Temporarily disabled until sampling options are enabled
-        # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
-        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
-        copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
+        # fill_slice(input_batch.top_p_cpu_tensor)
+        # fill_slice(input_batch.top_k_cpu_tensor)
+        fill_slice(input_batch.min_p_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["min_p"])
 
-        xm.mark_step()
-        xm.wait_device_ops()
-
         # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
-            temperature=input_batch.temperature[:padded_num_reqs],
-            # Scalar tensor for xla-friendly tracing.
-            all_greedy=torch.tensor(input_batch.all_greedy,
-                                    dtype=torch.bool,
-                                    device=input_batch.device),
+            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
+            to(xla_device),
+            all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
             top_k=None,  # input_batch.top_k[:padded_num_reqs],
-            min_p=input_batch.min_p[:padded_num_reqs],
-            generators=input_batch.generators,
-            indices_do_sample=indices_do_sample)
+            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
+            generators=input_batch.generators)
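
A note on the `all_greedy` change above: it goes from a traced `torch.Tensor` to a plain Python `bool`, so the greedy/non-greedy branch is resolved at trace time and each value simply selects its own pre-compiled graph. A minimal CPU-runnable sketch of that idea; the function, temperature value, and shapes here are illustrative only, not the patch's sampler:

```python
import torch

def sample(logits: torch.Tensor, all_greedy: bool) -> torch.Tensor:
    # `all_greedy` is a Python bool: the `if` is evaluated while tracing,
    # so each value yields its own static graph instead of a traced,
    # tensor-dependent control-flow op that could force recompilation.
    if all_greedy:
        return logits.argmax(dim=-1)
    # Illustrative temperature sampling; real params come from the metadata.
    probs = torch.softmax(logits / 0.7, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

logits = torch.randn(4, 16)
print(sample(logits, all_greedy=True))   # greedy graph
print(sample(logits, all_greedy=False))  # sampling graph
```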
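For intuition on the new `fill_slice` helper, here is a small CPU-only sketch of the padding scheme it implements: the tail of the persistent CPU tensor is filled with the default value up to the pre-compiled size, and a fixed-shape slice is then moved to the device in a single copy. Sizes and values below are made up for illustration:

```python
import torch

num_reqs, padded_num_reqs = 3, 4  # 3 live requests, padded to 4
DEFAULT_TEMPERATURE = 0.0         # assumed default fill value

# Persistent CPU tensor sized for the maximum batch (8 slots here).
temperature_cpu = torch.full((8,), -1.0)
temperature_cpu[:num_reqs] = torch.tensor([0.7, 0.2, 0.9])

# Pad the tail of the active slice with the default value on CPU...
temperature_cpu[num_reqs:padded_num_reqs] = DEFAULT_TEMPERATURE

# ...then move a fixed-shape slice to the device in one copy, so the
# traced graph always sees the same `padded_num_reqs`-sized input.
xla_device = torch.device("cpu")  # stand-in for the real XLA device
temperature = temperature_cpu[:padded_num_reqs].to(xla_device)
print(temperature)  # tensor([0.7000, 0.2000, 0.9000, 0.0000])
```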
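Finally, a hedged sketch of how a caller might use the new signature; `input_batch` is assumed to be an already-populated `InputBatch`, and the padded size is illustrative:

```python
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()

# Fast path: if every request is greedy this returns immediately with
# `all_greedy=True` and issues no CPU-to-TPU copies.
meta = TPUSupportedSamplingMetadata.from_input_batch(
    input_batch, padded_num_reqs=4, xla_device=xla_device)

# Warm-up path: force the sampling tensors to be materialized so the
# sampling graph can be pre-compiled even for an all-greedy batch.
meta_warmup = TPUSupportedSamplingMetadata.from_input_batch(
    input_batch, padded_num_reqs=4, xla_device=xla_device,
    generate_params_if_all_greedy=True)
```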