
LMHead is processing redundant tokens in prefill #38977

@null-pointer-access

Description


While using GPT2LMHeadModel.generate() and comparing its performance with vLLM, I noticed a significant inefficiency in the forward() implementation of many Hugging Face models. For example, GPT2LMHeadModel.forward() applies self.lm_head to the hidden states of every token, even when it is called from generate(), where only the logits of the last token are needed for next-token prediction. During the prefill step the whole prompt is processed at once, so this computes a logits tensor of shape (batch_size, seq_len, vocab_size) when only the final position is used, which can introduce significant compute and memory overhead.

# src/transformers/models/gpt2/modeling_gpt2.py, line 1233
lm_logits = self.lm_head(hidden_states)
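A back-of-the-envelope count makes the prefill overhead concrete. The sketch below assumes GPT-2 small (n_embd = 768, vocab_size = 50257), batch size 1 and an illustrative 1024-token prompt; `lm_head_macs` is a hypothetical helper, and the count covers only the lm_head matmul, not the rest of the model.

```python
def lm_head_macs(batch_size: int, num_positions: int, hidden_size: int, vocab_size: int) -> int:
    """Multiply-accumulate count for a bias-free Linear(hidden_size, vocab_size)
    applied at num_positions positions in each of batch_size sequences."""
    return batch_size * num_positions * hidden_size * vocab_size


# GPT-2 small dimensions; the prompt length is an illustrative assumption.
HIDDEN, VOCAB = 768, 50257
PROMPT_LEN = 1024

all_positions = lm_head_macs(1, PROMPT_LEN, HIDDEN, VOCAB)  # what forward() does today in prefill
last_only = lm_head_macs(1, 1, HIDDEN, VOCAB)               # what next-token prediction needs

# Size of the fp32 logits tensor (4 bytes per element).
logits_bytes_all = 1 * PROMPT_LEN * VOCAB * 4
logits_bytes_last = 1 * 1 * VOCAB * 4

print(f"lm_head MACs, all positions: {all_positions:,}")          # 39,523,713,024
print(f"lm_head MACs, last position: {last_only:,}")              # 38,597,376
print(f"fp32 logits, all positions: {logits_bytes_all / 2**20:.1f} MiB")  # 196.3 MiB
```

Projecting only the last position cuts the lm_head work during prefill by a factor equal to the prompt length, 1024 in this example.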

Suggested fix: add a conditional branch in forward() that, when called during generation, slices the hidden states down to the last position before lm_head computes the logits.
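A minimal PyTorch sketch of the suggested fix, assuming a hypothetical `logits_to_keep` argument; the argument name and the `TinyLMHeadSketch` module are illustrative, not GPT-2's actual code. The hidden states are sliced to the last `logits_to_keep` positions before the projection, so the logits for the kept positions are unchanged while the others are never computed.

```python
import torch
from torch import nn


class TinyLMHeadSketch(nn.Module):
    """Hypothetical stand-in for the final projection in a causal LM's forward()."""

    def __init__(self, hidden_size: int, vocab_size: int) -> None:
        super().__init__()
        # Bias-free projection to the vocabulary, the same layer type GPT-2 uses for lm_head.
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, logits_to_keep: int = 0) -> torch.Tensor:
        # hidden_states: (batch_size, seq_len, hidden_size)
        # logits_to_keep == 0 -> project every position (e.g. training with labels).
        # logits_to_keep == k -> project only the last k positions (k = 1 for next-token prediction).
        if logits_to_keep > 0:
            hidden_states = hidden_states[:, -logits_to_keep:, :]
        return self.lm_head(hidden_states)


torch.manual_seed(0)
head = TinyLMHeadSketch(hidden_size=16, vocab_size=100)
prompt_hidden = torch.randn(2, 10, 16)  # simulated prefill hidden states

full_logits = head(prompt_hidden)                    # shape (2, 10, 100)
last_logits = head(prompt_hidden, logits_to_keep=1)  # shape (2, 1, 100)

# The kept position's logits match the corresponding slice of the full computation.
assert torch.allclose(full_logits[:, -1:, :], last_logits, atol=1e-6)
```

In this sketch, generate() would pass logits_to_keep=1 during prefill, while calls that compute the training loss must keep the default, because the shifted cross-entropy loss in GPT2LMHeadModel needs logits at every position.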
