While using GPT2LMHeadModel.generate() and comparing its performance with vLLM, I noticed a significant inefficiency in the forward() implementation of many Hugging Face models. For example, in GPT2LMHeadModel.forward(), self.lm_head is applied to the hidden states of all tokens, even when called from the generate() method, where only the last token's logits are needed for next-token prediction. This computes logits over the entire sequence and can introduce significant overhead.
# src/transformers/models/gpt2/modeling_gpt2.py, line 1233
lm_logits = self.lm_head(hidden_states)

Suggested fix: add a conditional branch in forward() that slices the hidden states down to the last position before computing logits when the call is a generation step, as sketched below.
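A minimal standalone sketch of the idea (using GPT-2 small's dimensions, not a patch against modeling_gpt2.py): slicing the hidden states to the last position before the lm_head projection yields the same next-token logits while skipping the full-sequence matmul. How forward() would detect that it is a generation step is left open here; a new keyword argument passed by generate() is one option, but that argument would be an assumption, not an existing API.

```python
import torch
import torch.nn as nn

# GPT-2 small dimensions, chosen only for illustration.
batch_size, seq_len, hidden_size, vocab_size = 2, 1024, 768, 50257
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Current behavior: project every position, then keep only the last logits.
all_logits = lm_head(hidden_states)            # (2, 1024, 50257)
next_token_logits = all_logits[:, -1, :]

# Suggested behavior: slice first, project only the last position.
last_hidden = hidden_states[:, -1:, :]         # (2, 1, 768)
next_token_logits_sliced = lm_head(last_hidden)[:, 0, :]

# Same result, but the second path avoids a (seq_len x vocab_size) matmul
# over positions whose logits are discarded anyway.
torch.testing.assert_close(next_token_logits, next_token_logits_sliced)
```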