@@ -41,8 +41,6 @@ def forward(
         logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
-        # Apply temperature.
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)

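Note on this hunk: temperature scaling moves out of forward() and into sample() (see the second hunk), so greedy-only batches never pay for it. The reasoning, sketched below under the assumption that random requests always carry temperature > 0 (the tensors here are illustrative, not part of the diff): dividing logits by a positive temperature never changes the argmax, so greedy picks can be taken from the raw logits, which also avoids dividing by a temperature of 0 on greedy requests.

import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])
temperature = 0.7  # any positive value

# argmax is invariant under division by a positive scalar, so greedy
# sampling does not need temperature-scaled logits.
assert torch.argmax(logits, dim=-1) == torch.argmax(logits / temperature, dim=-1)
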
@@ -82,9 +80,21 @@ def sample(
     ) -> torch.Tensor:
         assert not (sampling_metadata.all_greedy
                     and sampling_metadata.all_random)
-        if sampling_metadata.all_greedy:
-            return self.greedy_sample(logits)
+        if sampling_metadata.all_random:
+            greedy_sampled = None
+        else:
+            greedy_sampled = self.greedy_sample(logits)
+            if sampling_metadata.all_greedy:
+                return greedy_sampled

+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+
+        # Apply min_p.
+        if not sampling_metadata.no_min_p:
+            logits = self.apply_min_p(logits, sampling_metadata.min_p)
+
+        # Apply top_k and/or top_p.
         random_sampled = self.topk_topp_sampler(
             logits,
             sampling_metadata.generators,
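This also reorders min-p relative to top-k/top-p: previously apply_min_p ran only after topk_topp_sampler had already drawn the random tokens (see the next hunk), so its filtered logits appear not to have influenced the draw; now it filters the logits first. For reference, a common formulation of the min-p filter (a sketch; the diff names apply_min_p but does not show its body):

import torch

def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # Keep tokens whose probability is at least min_p times that of the
    # most likely token; mask everything else to -inf.
    probs = torch.softmax(logits, dim=-1)
    top = probs.max(dim=-1, keepdim=True).values
    return logits.masked_fill(probs < min_p.unsqueeze(-1) * top, float("-inf"))
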
@@ -94,13 +104,9 @@ def sample(
             sampling_metadata.top_p,
         )

-        if not sampling_metadata.no_min_p:
-            logits = self.apply_min_p(logits, sampling_metadata.min_p)
-
-        if sampling_metadata.all_random:
+        if greedy_sampled is None:
             return random_sampled

-        greedy_sampled = self.greedy_sample(logits)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
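
Taken together, the reworked sample() takes the greedy picks from the raw logits up front, runs temperature, min-p, and top-k/top-p only when some request actually samples randomly, and merges the two paths with torch.where keyed on a per-request temperature threshold. A self-contained sketch of that control flow, assuming plain tensors in place of SamplingMetadata and a placeholder value for _SAMPLING_EPS (the diff shows only the name):

import torch

_SAMPLING_EPS = 1e-5  # assumed value

def sample(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # Greedy picks come from the raw logits: argmax ignores the scaling
    # below, and greedy requests may carry a temperature of 0.
    all_random = bool((temperature >= _SAMPLING_EPS).all())
    greedy_sampled = None if all_random else logits.argmax(dim=-1)
    if greedy_sampled is not None and bool((temperature < _SAMPLING_EPS).all()):
        return greedy_sampled  # all-greedy batch: skip the sampling kernels

    # Random path: clamp so greedy rows don't divide by zero; their
    # results are discarded by torch.where below anyway.
    scaled = logits / temperature.clamp(min=_SAMPLING_EPS).unsqueeze(-1)
    probs = torch.softmax(scaled, dim=-1)
    random_sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

    if greedy_sampled is None:
        return random_sampled  # all-random batch

    # Mixed batch: greedy where temperature ~ 0, random elsewhere.
    return torch.where(temperature < _SAMPLING_EPS, greedy_sampled,
                       random_sampled)

As in the diff, greedy_sampled doubling as the "any greedy requests?" flag means the mixed-batch merge costs one torch.where rather than a second greedy pass over the scaled logits.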