 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheSpec, SlidingWindowSpec)
+from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
+                                        KVCacheConfig, KVCacheSpec,
+                                        SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
@@ -148,6 +149,7 @@ def __init__(
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()
+        self.vocab_size = model_config.get_vocab_size()

         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
@@ -178,7 +180,7 @@ def __init__(
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
             pin_memory=self.pin_memory,
-            vocab_size=model_config.get_vocab_size(),
+            vocab_size=self.vocab_size,
         )

         # Cached torch/numpy tensor
@@ -221,6 +223,20 @@ def __init__(
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

+        # tensors for structured decoding
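+        # Each int32 in grammar_bitmask_cpu packs the allow/deny bits for 32
+        # consecutive vocab ids (hence cdiv(self.vocab_size, 32) columns);
+        # structured_decode_arange below holds the bit positions 0..31 used
+        # to unpack those words on device.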
+        self.grammar_bitmask_cpu = torch.zeros(
+            (self.max_num_reqs, cdiv(self.vocab_size, 32)),
+            dtype=torch.int32,
+            device="cpu",
+            pin_memory=self.pin_memory)
+        self.require_structured_out_cpu = torch.zeros(
+            (self.max_num_reqs, 1),
+            dtype=torch.bool,
+            device="cpu",
+            pin_memory=self.pin_memory)
+        self.structured_decode_arange = torch.arange(
+            0, 32, device="cpu", pin_memory=self.pin_memory)
+
         # Get maximum number of mm items per modality (batch size).
         self.max_num_mm_items_by_modality = dict()
         if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
@@ -762,9 +778,16 @@ def execute_model(
         )
         hidden_states = self.select_hidden_states(hidden_states,
                                                   logits_indices)
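+        # Logits are computed as a separate XLA graph from sampling so that
+        # the structured-output grammar bitmask can be applied to them before
+        # token selection.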
+        logits = self.compute_logits(hidden_states)
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
             from_input_batch(self.input_batch, padded_num_reqs, self.device)
-        selected_token_ids = self.sample_from_hidden(hidden_states,
+        if scheduler_output.grammar_bitmask is not None:
+            require_struct_decoding, grammar_bitmask_padded, arange = \
+                self.prepare_structured_decoding_input(logits, scheduler_output)
+            logits = self.structured_decode(require_struct_decoding,
+                                            grammar_bitmask_padded, logits,
+                                            arange)
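+        # At this point the logits of grammar-disallowed tokens are -inf for
+        # the requests that require structured output, so sampling below can
+        # only pick tokens allowed by the grammar.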
+        selected_token_ids = self.sample_from_logits(logits,
                                                      tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
@@ -997,7 +1020,7 @@ def _precompile_backbone(self) -> None:
             self._dummy_run(num_tokens)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("model backbone")

     def _precompile_select_hidden_states(self) -> None:
@@ -1026,19 +1049,59 @@ def _precompile_select_hidden_states(self) -> None:
                 break
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("select_hidden_states")

-    def _precompile_sample_from_hidden(self) -> None:
-        logger.info("Compiling sampling with different num_reqs.")
+    def _precompile_compute_logits(self) -> None:
+        logger.info("Compiling compute_logits with different input shapes.")
         start = time.perf_counter()
         hsize = self.model_config.get_hidden_size()
         for num_reqs in self.num_reqs_paddings:
             dummy_hidden = torch.zeros((num_reqs, hsize),
                                        device=self.device,
                                        dtype=self._hidden_states_dtype)
-            # The first dimension of dummy_hidden cannot be mark_dynamic because
-            # some operations in the sampler require it to be static.
+            torch._dynamo.mark_dynamic(dummy_hidden, 0)
+            self.compute_logits(dummy_hidden)
+            logger.info("  -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("compute_logits")
+
+    def _precompile_structured_decoding(self) -> None:
+        logger.info(
+            "Compiling structured_decoding with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_require_struct_decoding = \
+                self.require_structured_out_cpu[:num_reqs].to(self.device)
+            dummy_grammar_bitmask = \
+                self.grammar_bitmask_cpu[:num_reqs].to(self.device)
+            # The first dimension of the above 3 dummy tensors cannot be
+            # mark_dynamic because some operations in structured_decode require
+            # them to be static.
+            arange = self.structured_decode_arange.to(self.device)
+            self.structured_decode(dummy_require_struct_decoding,
+                                   dummy_grammar_bitmask, dummy_logits, arange)
+            logger.info("  -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("structured_decoding")
+
+    def _precompile_sample_from_logits(self) -> None:
+        logger.info(
+            "Compiling sample_from_logits with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            # The first dimension of dummy_logits cannot be mark_dynamic
+            # because some operations in the sampler require it to be static.
             for all_greedy in [False, True]:
                 generate_params_if_all_greedy = not all_greedy
                 sampling_metadata = (
@@ -1049,12 +1112,12 @@ def _precompile_sample_from_hidden(self) -> None:
                         generate_params_if_all_greedy,
                     ))
                 sampling_metadata.all_greedy = all_greedy
-                self.sample_from_hidden(dummy_hidden, sampling_metadata)
+                self.sample_from_logits(dummy_logits, sampling_metadata)
             logger.info("  -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
-        self._update_num_xla_graphs("sampling")
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("sample_from_logits")

     def capture_model(self) -> None:
         """
@@ -1063,7 +1126,9 @@ def capture_model(self) -> None:
         self._precompile_mm_encoder()
         self._precompile_backbone()
         self._precompile_select_hidden_states()
-        self._precompile_sample_from_hidden()
+        self._precompile_compute_logits()
+        self._precompile_structured_decoding()
+        self._precompile_sample_from_logits()

     def profile_run(
         self,
@@ -1144,7 +1209,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             tensor_config = kv_cache_config.tensors[layer_name]
             assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
             num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
-            if isinstance(kv_cache_spec, FullAttentionSpec):
+            if isinstance(kv_cache_spec, AttentionSpec):
                 kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                     num_blocks, kv_cache_spec.block_size,
                     kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
@@ -1179,29 +1244,86 @@ def select_hidden_states(self, hidden_states, indices_do_sample):
         return hidden_states[indices_do_sample]

     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
-    def sample_from_hidden(
-        self,
-        sample_hidden_states: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata,
-    ) -> torch.Tensor:
-        """
-        Sample with xla-friendly function. This function is to be traced
-        separately from `forward` for lighter compilation overhead.
-        """
-        logits = self.model.compute_logits(sample_hidden_states, None)
+    def compute_logits(self,
+                       sample_hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.model.compute_logits(sample_hidden_states, None)
+
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def sample_from_logits(
+            self, logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
             out_tokens = self.sampler(logits,
                                       sampling_metadata).sampled_token_ids
         return out_tokens

+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def structured_decode(self, require_struct_decoding: torch.Tensor,
+                          grammar_bitmask: torch.Tensor, logits: torch.Tensor,
+                          arange: torch.Tensor) -> torch.Tensor:
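+        # require_struct_decoding has shape (num_reqs, 1) and broadcasts over
+        # the vocab dimension, so the masked logits are used only for rows
+        # that actually require structured output.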
+        return torch.where(
+            require_struct_decoding,
+            self.apply_grammar_bitmask(logits, grammar_bitmask, arange),
+            logits)
+
+    def apply_grammar_bitmask(self, logits: torch.Tensor,
+                              grammar_bitmask: torch.Tensor,
+                              arange: torch.Tensor):
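+        # grammar_bitmask is packed: bit j of grammar_bitmask[i][k] covers
+        # vocab id k * 32 + j, and a 0 bit marks that token as disallowed.
+        # Unpack each row with right-shifts by arange (0..31) and fill the
+        # disallowed logits with -inf.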
+        assert (logits.shape[0] == grammar_bitmask.shape[0])
+        logits_cloned = logits.clone()
+        for i in range(logits.shape[0]):
+            unpacked_bitmask = (torch.bitwise_right_shift(
+                grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0
+            unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size]
+            logits_cloned[i] = logits_cloned[i].masked_fill(
+                unpacked_bitmask, -float("inf"))
+        return logits_cloned
+
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)

     def get_input_embeddings(self, *args, **kwargs):
         return self.model.get_input_embeddings(*args, **kwargs)

+    def prepare_structured_decoding_input(
+            self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        grammar_bitmask = scheduler_output.grammar_bitmask
+        assert grammar_bitmask is not None
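+        # logits has already been padded to the precompiled batch size, so
+        # num_reqs below is the padded request count.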
+        num_reqs, _ = logits.shape
+
+        # Reset pre-allocated tensors
+        self.grammar_bitmask_cpu.zero_()
+        self.require_structured_out_cpu.zero_()
+
+        # We receive the structured output bitmask from the scheduler, but the
+        # indices of the requests in the batch may not match the indices of
+        # the bitmask since the scheduler doesn't know how the tpu runner is
+        # ordering the requests in the batch. We need to match the order of
+        # the bitmask with the order of the requests.
+        struct_out_indices: list[int] = []
+        mask_indices: list[int] = []
+        for req_id in self.input_batch.req_ids:
+            mask_index = scheduler_output.structured_output_request_ids.get(
+                req_id)
+            if mask_index is None:
+                continue
+            batch_index = self.input_batch.req_id_to_index[req_id]
+            struct_out_indices.append(batch_index)
+            mask_indices.append(mask_index)
+        self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
+            grammar_bitmask[mask_indices])
+        # It's not guaranteed that all requests in this batch require
+        # structured output, so create a bool tensor to represent
+        # the requests that need structured output.
+        struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
+        self.require_structured_out_cpu[struct_out_indices] = True
+        return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
+            self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
+            self.structured_decode_arange.to(logits.device)
+
     def _get_mm_dummy_batch(self, modality: str,
                             batch_size: int) -> BatchedTensorInputs:
         # Dummy data for pre-compiling multimodal models.