@@ -51,9 +51,9 @@ class RejectionSampler(nn.Module):
5151 def __init__ (self , sampler : Sampler ):
5252 super ().__init__ ()
5353 self .sampler = sampler
54- self . return_processed_logprobs = self .sampler .logprobs_mode . startswith (
55- "processed"
56- )
54+ logprobs_mode = self .sampler .logprobs_mode
55+ self . is_processed_logprobs_mode = logprobs_mode . startswith ( "processed" )
56+ self . is_logits_logprobs_mode = logprobs_mode . endswith ( "logits" )
5757
5858 def forward (
5959 self ,
@@ -107,7 +107,7 @@ def forward(
107107 # Override the logprobs mode to return logits because they are
108108 # needed later to compute the accepted token logprobs.
109109 logprobs_mode_override = "processed_logits"
110- if self .return_processed_logprobs
110+ if self .is_processed_logprobs_mode
111111 else "raw_logits" ,
112112 )
113113 bonus_logits = bonus_sampler_output .logprobs_tensors .logprobs
@@ -150,7 +150,7 @@ def forward(
150150 sampling_metadata ,
151151 metadata ,
152152 logits ,
153- target_logits if self .return_processed_logprobs else raw_target_logits ,
153+ target_logits if self .is_processed_logprobs_mode else raw_target_logits ,
154154 bonus_logits ,
155155 output_token_ids ,
156156 ),
@@ -190,7 +190,7 @@ def _get_logprobs_tensors(
190190 accepted_logits = final_logits [accepted_logit_indices ]
191191 accepted_logprobs = (
192192 accepted_logits
193- if self .logprobs_mode . endswith ( "logits" )
193+ if self .is_logits_logprobs_mode
194194 else self .sampler .compute_logprobs (accepted_logits )
195195 )
196196 accepted_tokens = sampled_token_ids [accepted_mask ]
0 commit comments