File tree Expand file tree Collapse file tree 1 file changed +6
-5
lines changed Expand file tree Collapse file tree 1 file changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -1158,11 +1158,12 @@ def profile_run(self) -> None:
11581158 # Trigger compilation for general shape.
11591159 hidden_states = self ._dummy_run (self .max_num_tokens ,
11601160 dummy_kv_caches )
1161- if not get_pp_group ().is_last_rank :
1162- return hidden_states
1163- hidden_states = hidden_states [logit_indices ]
1164- logits = self .model .compute_logits (hidden_states , None )
1165- # TODO(woosuk): Consider the memory usage of the sampler.
1161+ if get_pp_group ().is_last_rank :
1162+ hidden_states = hidden_states [logit_indices ]
1163+ logits = self .model .compute_logits (hidden_states , None )
1164+ # TODO(woosuk): Consider the memory usage of the sampler.
1165+ else :
1166+ logits = None
11661167 torch .cuda .synchronize ()
11671168 del hidden_states , logits
11681169 self .encoder_cache .clear ()
You can’t perform that action at this time.
0 commit comments