Skip to content

Commit 76775a4

Browse files
committed
feat: update ._contrastive(streamer)
1 parent 50536b7 commit 76775a4

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

src/transformers/generation/utils.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -2818,8 +2818,6 @@ def _contrastive_search(
28182818

28192819
# Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
28202820
if self.config.is_encoder_decoder:
2821-
next_step_cross_attentions = ()
2822-
next_step_decoder_attentions = ()
28232821
if output_attentions:
28242822
for layer in outputs.cross_attentions:
28252823
layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
@@ -2856,7 +2854,21 @@ def _contrastive_search(
28562854
# update generated ids, model inputs, and length for next step
28572855
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
28582856
if streamer is not None:
2859-
streamer.put(next_tokens.cpu())
2857+
output_stub = self._prepare_output(
2858+
return_dict_in_generate=return_dict_in_generate,
2859+
sequences=next_tokens,
2860+
scores=(processed_logit_for_next_step,), # (scores,),
2861+
logits=(processed_logit_for_next_step,),
2862+
# NOTE: the contrastive search implementation currently appears to return the same values for `logits` as for `scores`, which looks like a bug.
# Alternatives tried for `logits`: `(logits[selected_idx, :],)` -- shapes don't match; `(logit_for_next_step,)` -- values don't match.
2863+
encoder_attentions=None, # probably doesn't make sense to stream this
2864+
encoder_hidden_states=None, # probably doesn't make sense to stream this
2865+
decoder_attentions=(next_step_decoder_attentions,),
2866+
# NOTE: very concerning -- the tests still pass if this is replaced with the dummy value `([0],)`.
2867+
cross_attentions=(next_step_cross_attentions,),
2868+
decoder_hidden_states=(next_decoder_hidden_states,),
2869+
past_key_values=None, # probably doesn't make sense to stream this
2870+
)
2871+
streamer.put(output_stub)
28602872
model_kwargs = self._update_model_kwargs_for_generation(
28612873
outputs,
28622874
model_kwargs,

0 commit comments

Comments
 (0)