1+ # SPDX-License-Identifier: Apache-2.0 
2+ from  dataclasses  import  dataclass , field 
3+ from  typing  import  Optional 
4+ 
5+ import  torch 
6+ import  torch_xla .core .xla_model  as  xm 
7+ 
8+ from  vllm .v1 .sample .metadata  import  SamplingMetadata 
9+ 
10+ 
@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.

    # Per-request sampling temperature, fixed shape (num_samples,).
    temperature: torch.Tensor

    min_p: torch.Tensor
    # Still too slow on forward_native!
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None

    # XLA-unfriendly control flow in Sampler
    all_greedy: bool = False
    all_random: bool = False
    # Greedy sampling flag for compiling single xla graph.
    # Filled in __post_init__ when left as None.
    do_argmax: Optional[torch.Tensor] = None

    # NOTE: the bare (un-annotated) assignments below are deliberate class
    # variables, NOT dataclass fields — annotating them would add them to the
    # generated __init__ and change the constructor signature.

    # speculation not supported
    spec_token_ids = None

    # Generator not supported by xla
    generators: dict[int,
                     torch.Generator] = field(default_factory=dict)

    # unsupported, you need to return an extra tensor of static size BxV
    max_num_logprobs = None

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=list)

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(
        default_factory=list)

    allowed_token_ids_mask = None
    bad_words_token_ids = None
    # Indices of the sequences to sample from, padded to a pre-compiled
    # shape. Filled in __post_init__ when left as None.
    indices_do_sample: Optional[torch.Tensor] = None

    def __post_init__(self):
        """Materialize tensor defaults whose size/device depend on
        `temperature`, so that no field is None after construction."""
        temp = self.temperature
        if self.indices_do_sample is None:
            self.indices_do_sample = torch.zeros(temp.shape[0],
                                                 device=temp.device,
                                                 dtype=torch.int32)
        if self.do_argmax is None:
            self.do_argmax = torch.tensor(0,
                                          dtype=torch.bool,
                                          device=temp.device)

    @classmethod
    def from_sampling_metadata(
            cls, metadata: SamplingMetadata,
            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
            device: torch.device) -> "TPUSupportedSamplingMetadata":
        """
        Create an XLA-friendly SamplingMetadata structure. Do so by first
        instantiating an object with fixed-sized tensors and then writing the
        values in input `metadata`. Do that only for non-None values so that
        recompilation is not triggered for optional values (None/torch.Tensor).

        In order to handle different sizes for the params that range from 1 up
        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
        Same thing for `padded_do_sample_indices`, which contains the indices
        to be fed to the Sampler, padded to the closest pre-compiled shape.

        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
            do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
        """
        metadata = cls._validate_sampling_metadata(metadata)
        # NOTE we have to initialize default tensor-based params first and
        # skip None values altogether to produce the same xla graph.
        num_samples = len(padded_do_sample_indices)
        do_argmax = torch.tensor(metadata.all_greedy,
                                 dtype=torch.bool,
                                 device=device)
        new_metadata = cls.get_default_sampling_params(
            num_samples,
            device,
            indices_do_sample=padded_do_sample_indices,
            do_argmax=do_argmax)
        supported_params = cls._get_default_params_values()
        # Copy input non-None values into `new_metadata` fixed-sized tensors.
        # The write is in place so the pre-allocated tensors (and therefore
        # the traced graph) are reused across calls.
        for p_name in supported_params:
            old_val = getattr(metadata, p_name)
            if isinstance(old_val, torch.Tensor):
                getattr(new_metadata, p_name)[:num_do_sample] = old_val

        xm.mark_step()
        xm.wait_device_ops()
        return new_metadata

    @classmethod
    def get_default_sampling_params(
            cls,
            num_samples: int,
            device: torch.device,
            indices_do_sample: Optional[torch.Tensor] = None,
            do_argmax: Optional[torch.Tensor] = None
    ) -> "TPUSupportedSamplingMetadata":
        """Build an instance whose supported params are all set to their
        no-op ("disabled") value, as fixed-size tensors of `num_samples`."""
        # As sampling happens on a single traced graph, options
        # are "disabled" by having them evaluate to an Identity op.
        # Note that initialization is dependent on num_samples.
        sampling_metadata_disable_value = cls._get_default_params_values()
        init_kwargs = {}
        for p_name, (default_val,
                     dtype) in sampling_metadata_disable_value.items():
            default_tensor = torch.full((num_samples, ),
                                        default_val,
                                        dtype=dtype,
                                        device=device)
            init_kwargs[p_name] = default_tensor

        return cls(**init_kwargs,
                   indices_do_sample=indices_do_sample,
                   do_argmax=do_argmax)

    @staticmethod
    def _validate_sampling_metadata(
            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
        """Sanity-check `sampling_metadata` invariants and return it."""
        if sampling_metadata.all_greedy:
            # Set to None since #13587. Make sure default isn't overruled.
            assert sampling_metadata.temperature is None
        return sampling_metadata

    @staticmethod
    def _get_default_params_values() -> dict[str, tuple[float, torch.dtype]]:
        """Map each supported param name to its (no-op value, dtype) pair."""
        return dict(
            # Since #13587 greedy sampling requires branching off which leads
            # to separate graphs. We set temp to noop and handle argmax here.
            temperature=(1.0, torch.float32),
            min_p=(0.0, torch.float32),
            # strictly disabled for now
            # top_k=(-1, torch.int32),
            # top_p=(0.0, torch.float32),
            # frequency_penalties=(0.0, torch.float32),
            # presence_penalties=(0.0, torch.float32),
            # repetition_penalties=(0.0, torch.float32),
        )