
Implement generate with the mixin from HF  #217

@tscholak

Description

🎯 Goal (What & Why)

Implement generate() for HuggingfaceGPTModelForCausalLM using GenerationMixin, supporting greedy decoding only.
This makes the Fast-LLM model behave like a HuggingFace model for generation. The goal is to enable validation-time text generation directly from the sharded Fast-LLM model in memory, without converting to HF format (which would require extra memory and lead to model eviction).
We use batched greedy decoding and support FlashAttention via left-padding and attention masking.
No beam search, sampling, or KV caching is needed.
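
For reference, a minimal sketch of the left-padding and attention-mask setup using a HF tokenizer (the checkpoint name is the one used in the integration test below; everything else is illustrative, not prescribed by this issue):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
    tokenizer.padding_side = "left"  # left-pad so the last position of every row is a real token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    batch = tokenizer(
        ["Hello, my name is", "The capital of France is"],
        return_tensors="pt",
        padding=True,  # pad to the longest prompt in the batch
    )
    # batch["input_ids"]:      left-padded token ids, shape (batch, seq)
    # batch["attention_mask"]: 1 for real tokens, 0 for padding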

🚀 Execution Plan

Develop a minimal, batched, greedy generation loop using Fast-LLM's .forward() and GenerationMixin integration.

Step 1: What is the smallest working version?

  • Implement the GenerationMixin interface in HuggingfaceGPTModelForCausalLM, i.e. these two methods:
    from typing import Optional, Union

    import torch
    from transformers import GenerationConfig
    from transformers.generation.utils import GenerateOutput

    class GenerationMixin:
        def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            **kwargs,
        ):
            ...

        @torch.no_grad()
        def generate(
            self,
            inputs: Optional[torch.Tensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            **kwargs,
        ) -> Union[GenerateOutput, torch.LongTensor]:
            ...
  • Implement generate() to (a minimal sketch of this loop follows the list):
    • Accept inputs, and take max_new_tokens, eos_token_id, and pad_token_id via GenerationConfig or kwargs.
    • Support only batched greedy decoding (no sampling, no beam search).
    • Left-pad input sequences for FlashAttention.
    • Use an attention mask to exclude padding tokens.
    • Track completion using eos_token_id and max_new_tokens.
    • Reuse .forward() on the Fast-LLM model in each step.
  • Implement prepare_inputs_for_generation() minimally to satisfy HF's expectations.
  • Use the tokenizer to:
    • Tokenize prompts to input_ids.
    • Detokenize generated tokens to strings.
  • Add a slow integration test that:
    • Loads HuggingFaceTB/SmolLM2-135M-Instruct in both HF and Fast-LLM formats.
    • Runs .generate() on both and compares the outputs.
    • Uses a fixed prompt, a fixed seed, and no sampling to make the comparison deterministic.
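
A minimal sketch of the batched greedy loop described above (the helper name is hypothetical; the real implementation lives in HuggingfaceGPTModelForCausalLM.generate() and reuses Fast-LLM's .forward()):

    import torch

    @torch.no_grad()
    def greedy_generate(model, input_ids, attention_mask, max_new_tokens, eos_token_id, pad_token_id):
        """Batched greedy decoding: no sampling, no beam search, no KV cache."""
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
        for _ in range(max_new_tokens):
            # Re-run the full (left-padded) sequence each step; assumes a CausalLM-style
            # output object with a .logits field of shape (batch, seq, vocab).
            logits = model.forward(input_ids=input_ids, attention_mask=attention_mask).logits
            next_tokens = logits[:, -1, :].argmax(dim=-1)
            # Rows that already emitted EOS keep receiving pad_token_id and a zero mask entry.
            next_tokens = torch.where(finished, torch.full_like(next_tokens, pad_token_id), next_tokens)
            new_mask = (~finished).long().unsqueeze(-1)
            finished |= next_tokens == eos_token_id
            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
            attention_mask = torch.cat([attention_mask, new_mask], dim=-1)
            if finished.all():
                break
        return input_ids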

Step 2: What additional optimizations are possible (but optional)?

  • Add past key/value caching for faster autoregressive generation.
  • Fuse decode-time batching for speed-up (e.g., with CUDA graphs).

📌 Acceptance Criteria (Must-Haves for Completion)

  • Must implement the required methods from HF's GenerationMixin.
  • Must support batched greedy decoding with:
    • FlashAttention-compatible padding and attention masking.
    • EOS handling and padding with pad_token_id.
  • Must include:
    • Integration test comparing Fast-LLM output to HuggingFace output using a shared checkpoint (see the test sketch after this list).
    • Benchmark test comparing generation speed (Fast-LLM vs. HF without KV cache) on a single GPU/shard.
  • Implementation must be documented:
    • Explain how padding and masks are used.
    • List all unsupported features and error behavior.
  • Must not refactor unrelated code.
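
A rough sketch of the integration test, assuming a hypothetical load_fast_llm_causal_lm() helper that returns the Fast-LLM HuggingfaceGPTModelForCausalLM wrapper (the loading API is not specified in this issue); the HF side uses the standard AutoModelForCausalLM API:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def test_generate_matches_hf():  # mark as a slow test in the suite
        checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        hf_model = AutoModelForCausalLM.from_pretrained(checkpoint)
        fast_llm_model = load_fast_llm_causal_lm(checkpoint)  # hypothetical loader

        prompts = ["The capital of France is", "Water boils at"]
        batch = tokenizer(prompts, return_tensors="pt", padding=True)

        torch.manual_seed(0)  # greedy decoding is deterministic; the seed is belt-and-braces
        hf_out = hf_model.generate(**batch, max_new_tokens=32, do_sample=False)
        fast_out = fast_llm_model.generate(**batch, max_new_tokens=32, do_sample=False)

        assert tokenizer.batch_decode(hf_out, skip_special_tokens=True) == \
            tokenizer.batch_decode(fast_out, skip_special_tokens=True)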

📎 Relevant Links

🛠️ Project Management

  • Assign the project to the Fast-LLM project.
  • Set the Estimate field (in days) in the GitHub project.
  • Use the Size field to categorize the PR size (Small/Medium/Large).
  • Assign an owner when opening the issue.

Labels: enhancement (New feature or request)
