@@ -1238,43 +1238,6 @@ def _dummy_run(
             )
         return hidden_states
 
-    @torch.inference_mode()
-    def _dummy_sampler_run(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
-        logits = self.model.compute_logits(hidden_states, None)
-        num_reqs = logits.size(0)
-
-        dummy_tensors = lambda v: torch.full(
-            (num_reqs, ), v, device=self.device)
-
-        dummy_metadata = SamplingMetadata(
-            temperature=dummy_tensors(0.5),
-            all_greedy=False,
-            all_random=False,
-            top_p=dummy_tensors(0.9),
-            top_k=dummy_tensors(logits.size(1) - 1),
-            min_p=None,
-            generators={},
-            max_num_logprobs=None,
-            no_penalties=True,
-            prompt_token_ids=None,
-            frequency_penalties=dummy_tensors(0.1),
-            presence_penalties=dummy_tensors(0.1),
-            repetition_penalties=dummy_tensors(0.1),
-            output_token_ids=[[] for _ in range(num_reqs)],
-            min_tokens={},
-            logit_bias=[None for _ in range(num_reqs)],
-            allowed_token_ids_mask=None,
-            bad_words_token_ids={},
-        )
-        sampler_output = self.model.sample(logits=logits,
-                                           sampling_metadata=dummy_metadata)
-
-        return sampler_output
-
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
         # TODO: handle encoder-decoder models once we support them.
@@ -1390,11 +1353,38 @@ def profile_run(self) -> None:
         hidden_states = self._dummy_run(self.max_num_tokens)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            sampler_output = self._dummy_sampler_run(hidden_states)
+            logits = self.model.compute_logits(hidden_states, None)
+            dummy_tensors = lambda v: torch.full(
+                (num_reqs, ), v, device=self.device)
+            dummy_metadata = SamplingMetadata(
+                temperature=dummy_tensors(0.5),
+                all_greedy=False,
+                all_random=False,
+                top_p=dummy_tensors(0.9),
+                top_k=dummy_tensors(logits.size(1) - 1),
+                min_p=None,
+                generators={},
+                max_num_logprobs=None,
+                no_penalties=True,
+                prompt_token_ids=torch.ones_like(logits,
+                                                 dtype=torch.int64),
+                frequency_penalties=dummy_tensors(0.1),
+                presence_penalties=dummy_tensors(0.1),
+                repetition_penalties=dummy_tensors(0.1),
+                output_token_ids=[[] for _ in range(num_reqs)],
+                min_tokens={},
+                logit_bias=[None for _ in range(num_reqs)],
+                allowed_token_ids_mask=None,
+                bad_words_token_ids={},
+            )
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
         else:
+            logits = None
             sampler_output = None
+            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, sampler_output
+        del hidden_states, logits, sampler_output, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()
 
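The inlined block above makes `profile_run` exercise the sampling path with worst-case dummy parameters (near-full top-k, all penalty tensors populated, a dense `prompt_token_ids` tensor), so that transient memory allocated by sampling is included in the peak-memory measurement. Below is a minimal, self-contained sketch of that idea in plain PyTorch rather than vLLM's `SamplingMetadata`; the names (`apply_top_k`) and sizes are illustrative assumptions, not vLLM's API:

```python
import torch

# Sketch (assumed names/sizes, not vLLM's API) of why a profiling pass
# runs the sampler once: sampling allocates transient CUDA memory that
# scales with num_reqs * vocab_size, and a worst-case dummy pass makes
# that allocation show up in the peak-memory measurement.


def apply_top_k(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # torch.sort allocates O(num_reqs * vocab_size) scratch memory --
    # exactly the kind of transient allocation profiling must capture.
    sorted_logits, _ = logits.sort(dim=-1, descending=True)
    thresholds = sorted_logits.gather(-1, (k - 1).unsqueeze(-1))
    return logits.masked_fill(logits < thresholds, float("-inf"))


num_reqs, vocab_size = 256, 32_000  # illustrative sizes
device = "cuda" if torch.cuda.is_available() else "cpu"

logits = torch.randn(num_reqs, vocab_size, device=device)
# Near-worst-case top-k, mirroring top_k=dummy_tensors(logits.size(1) - 1).
k = torch.full((num_reqs,), vocab_size - 1, device=device, dtype=torch.long)

if device == "cuda":
    torch.cuda.reset_peak_memory_stats()

# Temperature 0.5 mirrors temperature=dummy_tensors(0.5) in the diff.
probs = torch.softmax(apply_top_k(logits, k) / 0.5, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)

if device == "cuda":
    torch.cuda.synchronize()
    peak_mb = torch.cuda.max_memory_allocated() / 1e6
    print(f"peak memory during dummy sampling: {peak_mb:.1f} MB")

# Free the dummy tensors afterwards, as the diff does with
# `del hidden_states, logits, sampler_output, dummy_metadata`.
del logits, probs, next_tokens
```

Freeing everything and forcing collection at the end mirrors the diff's `del ...` / `gc.collect()` cleanup, so the dummy pass leaves no residue behind when the real memory budget is computed.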