@@ -1179,6 +1179,43 @@ def _dummy_run(
11791179 )
11801180 return hidden_states
11811181
@torch.inference_mode()
def _dummy_sampler_run(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Compute logits and run the sampler once with dummy sampling metadata.

    Args:
        hidden_states: Tensor forwarded to ``self.model.compute_logits``.
            The first dimension of the resulting logits is used as the
            number of dummy requests.

    Returns:
        The value returned by ``self.model.sample`` for the dummy batch.
        NOTE(review): the signature advertises ``torch.Tensor``, but the
        result is passed through unchanged from ``self.model.sample`` --
        confirm the actual type against the sampler.
    """
    logits = self.model.compute_logits(hidden_states, None)
    num_reqs = logits.size(0)

    def dummy_tensors(value: float) -> torch.Tensor:
        # One entry per dummy request, allocated on the runner's device.
        return torch.full((num_reqs, ), value, device=self.device)

    dummy_metadata = SamplingMetadata(
        temperature=dummy_tensors(0.5),
        # Neither all-greedy nor all-random, so both flags are off.
        all_greedy=False,
        all_random=False,
        spec_token_ids=None,
        top_p=dummy_tensors(0.9),
        # One less than logits.size(1) (presumably the vocab size --
        # TODO confirm).
        top_k=dummy_tensors(logits.size(1) - 1),
        min_p=None,
        generators={},
        max_num_logprobs=None,
        # Penalties are flagged as disabled; the penalty tensors below are
        # still populated with placeholder values.
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=dummy_tensors(0.1),
        presence_penalties=dummy_tensors(0.1),
        repetition_penalties=dummy_tensors(0.1),
        output_token_ids=[[] for _ in range(num_reqs)],
        min_tokens={},
        logit_bias=[None for _ in range(num_reqs)],
        allowed_token_ids_mask=None,
    )
    sampler_output = self.model.sample(logits=logits,
                                       sampling_metadata=dummy_metadata)

    return sampler_output
1218+
11821219 def profile_run (self ) -> None :
11831220 # use an empty tensor instead of `None` to force Dynamo to pass
11841221 # it by reference, rather than specializing on the value `None`.
@@ -1306,38 +1343,11 @@ def profile_run(self) -> None:
13061343 dummy_kv_caches )
13071344 if get_pp_group ().is_last_rank :
13081345 hidden_states = hidden_states [logit_indices ]
1309- logits = self .model .compute_logits (hidden_states , None )
1310- dummy_tensors = lambda v : torch .full (
1311- (num_reqs , ), v , device = self .device )
1312- dummy_metadata = SamplingMetadata (
1313- temperature = dummy_tensors (0.5 ),
1314- all_greedy = False ,
1315- all_random = False ,
1316- spec_token_ids = None ,
1317- top_p = dummy_tensors (0.9 ),
1318- top_k = dummy_tensors (logits .size (1 ) - 1 ),
1319- min_p = None ,
1320- generators = {},
1321- max_num_logprobs = None ,
1322- no_penalties = True ,
1323- prompt_token_ids = torch .ones_like (logits ,
1324- dtype = torch .int64 ),
1325- frequency_penalties = dummy_tensors (0.1 ),
1326- presence_penalties = dummy_tensors (0.1 ),
1327- repetition_penalties = dummy_tensors (0.1 ),
1328- output_token_ids = [[] for _ in range (num_reqs )],
1329- min_tokens = {},
1330- logit_bias = [None for _ in range (num_reqs )],
1331- allowed_token_ids_mask = None ,
1332- )
1333- sampler_output = self .model .sample (
1334- logits = logits , sampling_metadata = dummy_metadata )
1346+ sampler_output = self ._dummy_sampler_run (hidden_states )
13351347 else :
1336- logits = None
13371348 sampler_output = None
1338- dummy_metadata = None
13391349 torch .cuda .synchronize ()
1340- del hidden_states , logits , sampler_output , dummy_metadata
1350+ del hidden_states , sampler_output
13411351 self .encoder_cache .clear ()
13421352 gc .collect ()
13431353
0 commit comments