@@ -5,7 +5,18 @@
 import torch
 import torch_xla.core.xla_model as xm

-from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+DEFAULT_SAMPLING_PARAMS = dict(
+    temperature=-1.0,
+    min_p=0.0,
+    # strictly disabled for now
+    # top_k=-1,
+    # top_p=0.0,
+    # frequency_penalties=0.0,
+    # presence_penalties=0.0,
+    # repetition_penalties=0.0,
+)


 @dataclass
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None

-    # XLA-unfriendly control flow in Sampler
-    all_greedy: bool = False
-    all_random: bool = False
     # Greedy sampling flag for compiling single xla graph.
-    do_argmax: torch.Tensor = None
-
-    # speculation not supported
-    spec_token_ids = None
+    all_greedy: torch.Tensor = None

     # Generator not supported by xla
     generators: dict[int,
@@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
     bad_words_token_ids = None
     indices_do_sample: torch.Tensor = None

-    def __post_init__(self):
-        temp = self.temperature
-        if self.indices_do_sample is None:
-            self.indices_do_sample = torch.zeros(temp.shape[0],
-                                                 device=temp.device,
-                                                 dtype=torch.int32)
-        if self.do_argmax is None:
-            self.do_argmax = torch.tensor(0,
-                                          dtype=torch.bool,
-                                          device=temp.device)
-
     @classmethod
-    def from_sampling_metadata(
-            cls, metadata: SamplingMetadata,
-            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
-            device: torch.device) -> "TPUSupportedSamplingMetadata":
+    def from_input_batch(
+            cls, input_batch: InputBatch,
+            indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
         """
-        Create an XLA-friendly SamplingMetadata structure. Do so by first
-        instantiating an object with fixed-sized tensors and then writing the
-        values in input `metadata`. Do that only for non-None values so that
-        recompilation is not triggered for optional values (None/torch.Tensor).
-
-        In order to handle different sizes for the params that range from 1 up
-        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
-        Same thing for `padded_do_sample_indices`, which contains the indices
-        to be fed to the Sampler, padded to the closest pre-compiled shape.
-
-        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
-        do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
+        Copy sampling tensor slices from `input_batch` to on-device tensors.
+
+        `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
+        slices on-device tensors to dynamic shapes. This impl moves the dynamic
+        ops to the CPU and produces tensors of a fixed `padded_num_reqs` size.
+        It also reuses the on-device persistent tensors managed in
+        `input_batch` to reduce waste.
+
+        `indices_do_sample` contains the indices to be fed to the Sampler,
+        normally one per request, here padded to the closest pre-compiled
+        shape. Sampling params tensors are expected to be padded the same way.
+
+        Eg. 3 requests, tensors padded to 4:
+        temperature: [0.7, 0.2, 0.9] => [0.7, 0.2, 0.9, 0.0]
+        sample indices: [4, 10, 11] => indices_do_sample: [4, 10, 11, 0]
         """
-        metadata = cls._validate_sampling_metadata(metadata)
-        # NOTE we have to initialize default tensor-based params first and
-        # skip None values altogether to produce the same xla graph.
-        num_samples = len(padded_do_sample_indices)
-        do_argmax = torch.tensor(metadata.all_greedy,
-                                 dtype=torch.bool,
-                                 device=device)
-        new_metadata = cls.get_default_sampling_params(num_samples, device,
-                                                       indices_do_sample=\
-                                                       padded_do_sample_indices,
-                                                       do_argmax=do_argmax
-                                                       )
-        supported_params = \
-            TPUSupportedSamplingMetadata._get_default_params_values()
-        # Copy input non-None values into `new_metadata` fixed-sized tensors.
-        for p_name in supported_params:
-            old_val = getattr(metadata, p_name)
-            new_val = getattr(new_metadata, p_name)
-            if isinstance(old_val, torch.Tensor):
-                new_val[:num_do_sample] = old_val
-                setattr(new_metadata, p_name, new_val)
+        num_reqs = input_batch.num_reqs
+        padded_num_reqs = len(indices_do_sample)
+
+        def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
+                       fill_val) -> None:
+            # Copy a slice from the CPU tensor to the corresponding
+            # pre-allocated TPU tensor. The pad value is the param's default.
+            cpu_tensor[num_reqs:padded_num_reqs] = fill_val
+            tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
+
+        # NOTE NickLucche The sync CPU-TPU graph we produce here must be
+        # consistent. We can't have flags to skip copies or we'll end up
+        # recompiling.
+        copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
+                   DEFAULT_SAMPLING_PARAMS["temperature"])
+        # TODO Temporarily disabled until sampling options are enabled
+        # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
+        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
+        copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
+                   DEFAULT_SAMPLING_PARAMS["min_p"])
+
+        # copy_slice(input_batch.frequency_penalties_cpu_tensor,
+        #            input_batch.frequency_penalties)
+        # copy_slice(input_batch.presence_penalties_cpu_tensor,
+        #            input_batch.presence_penalties)
+        # copy_slice(input_batch.repetition_penalties_cpu_tensor,
+        #            input_batch.repetition_penalties)

         xm.mark_step()
         xm.wait_device_ops()
-        return new_metadata

-    @classmethod
-    def get_default_sampling_params(
-            cls,
-            num_samples: int,
-            device: torch.device,
-            indices_do_sample=None,
-            do_argmax=None) -> "TPUSupportedSamplingMetadata":
-        # As sampling happens on a single traced graph, options
-        # are "disabled" by having them evaluate to an Identity op.
-        # Note that initialization is dependent on num_samples.
-        sampling_metadata_disable_value = \
-            TPUSupportedSamplingMetadata._get_default_params_values()
-        init_kwargs = dict()
-        for p_name, (default_val,
-                     dtype) in sampling_metadata_disable_value.items():
-            default_tensor = torch.full((num_samples, ),
-                                        default_val,
-                                        dtype=dtype,
-                                        device=device)
-            init_kwargs[p_name] = default_tensor
-
-        return cls(**init_kwargs,
-                   indices_do_sample=indices_do_sample,
-                   do_argmax=do_argmax)
-
-    @staticmethod
-    def _validate_sampling_metadata(
-            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
-        if sampling_metadata.all_greedy:
-            # Set to None since #13587. Make sure default isn't overruled.
-            assert sampling_metadata.temperature is None
-        return sampling_metadata
-
-    @staticmethod
-    def _get_default_params_values():
-        return dict(
-            # Since #13587 greedy sampling requires branching off which leads
-            # to separate graphs. We set temp to noop and handle argmax here.
-            temperature=(1.0, torch.float32),
-            min_p=(0.0, torch.float32),
-            # strictly disabled for now
-            # top_k=(-1, torch.int32),
-            # top_p=(0.0, torch.float32),
-            # frequency_penalties=(0.0, torch.float32),
-            # presence_penalties=(0.0, torch.float32),
-            # repetition_penalties=(0.0, torch.float32),
-        )
+        # Slice persistent device tensors to a fixed pre-compiled padded shape.
+        return cls(
+            temperature=input_batch.temperature[:padded_num_reqs],
+            # Scalar tensor for xla-friendly tracing.
+            all_greedy=torch.tensor(input_batch.all_greedy,
+                                    dtype=torch.bool,
+                                    device=input_batch.device),
+            # TODO enable more params and avoid returning None values
+            top_p=None,  # input_batch.top_p[:padded_num_reqs],
+            top_k=None,  # input_batch.top_k[:padded_num_reqs],
+            min_p=input_batch.min_p[:padded_num_reqs],
+            generators=input_batch.generators,
+            indices_do_sample=indices_do_sample)
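
To make the padding scheme in `from_input_batch` concrete, here is a minimal, self-contained sketch of the same CPU-staging pattern. It uses plain `torch` tensors in place of the TPU-resident ones, and the capacities, names, and values (`MAX_NUM_REQS`, `PADDED_NUM_REQS`, the request count of 3) are illustrative assumptions, not part of this PR:

```python
import torch

MAX_NUM_REQS = 8            # capacity of the persistent tensors (assumption)
PADDED_NUM_REQS = 4         # closest pre-compiled shape for this step
DEFAULT_TEMPERATURE = -1.0  # pad/"disable" value, as in DEFAULT_SAMPLING_PARAMS

cpu_temperature = torch.zeros(MAX_NUM_REQS)     # CPU staging tensor
device_temperature = torch.zeros(MAX_NUM_REQS)  # stands in for the TPU tensor


def copy_slice(cpu_tensor, device_tensor, num_reqs, padded_num_reqs, fill_val):
    # Pad the tail of the active slice with the default value on the CPU,
    # then copy a fixed-size slice so the device-side op always sees the
    # same shape and the traced graph never recompiles.
    cpu_tensor[num_reqs:padded_num_reqs] = fill_val
    device_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]


# 3 live requests, padded to 4: [0.7, 0.2, 0.9] => [0.7, 0.2, 0.9, -1.0]
cpu_temperature[:3] = torch.tensor([0.7, 0.2, 0.9])
copy_slice(cpu_temperature, device_temperature, 3, PADDED_NUM_REQS,
           DEFAULT_TEMPERATURE)
print(device_temperature[:PADDED_NUM_REQS])
# tensor([ 0.7000,  0.2000,  0.9000, -1.0000])
```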
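The point of turning `all_greedy` into a scalar `torch.bool` tensor ("Greedy sampling flag for compiling single xla graph") is that the flag becomes data rather than Python control flow, so one traced graph can cover both the greedy and the random path. A sketch of that idea, assuming a `torch.where`-style selection; `pick_tokens` is a hypothetical helper, not the vLLM Sampler:

```python
import torch


def pick_tokens(logits: torch.Tensor, all_greedy: torch.Tensor,
                sampled: torch.Tensor) -> torch.Tensor:
    # Both branches are computed inside one graph; the scalar bool tensor
    # selects between them. A Python `if all_greedy:` here would instead
    # trace two different XLA graphs and trigger recompilation.
    greedy = logits.argmax(dim=-1)
    return torch.where(all_greedy, greedy, sampled)


logits = torch.randn(4, 16)
sampled = torch.randint(0, 16, (4,))
flag = torch.tensor(True)
print(pick_tokens(logits, flag, sampled))  # equals logits.argmax(dim=-1)
```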