
[RFC] Long context fine tuning in torchtune #1244

@felipemello1


Edit: With memory optimization flags alone, we can reach 128k context length, depending on your hardware. Work to support even longer contexts has been deprioritized for now.

Goal


  • Many new LLMs support long context. For example, Llama 3.1 and Mistral 2 support 128k;
  • The trend is upward: Gemini supports 1M to 10M tokens, and Claude supports 200k;
  • torchtune supports ~24k for Llama 3 8B with bsz=1;

This RFC aims to define which type of solution to prioritize so we can bridge this gap.

What number do we hope to achieve, with which model and which GPU?

I propose three ways to think about the target:

  • Prioritize techniques that enable scaling regardless of model or sequence length, since the goal (model X at context length Y) may shift;
  • Get Llama 3.1 8B to 128k on 8xA100: 8B is the most used SKU and the smallest Llama, it supports 128k, and none of the long-context papers used only one GPU;
  • Same as above, but aim for 64k and rely on RoPE scaling to extend the context further;
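For the third option, one common RoPE scaling method is linear position interpolation: positions are divided by a scale factor, so a longer sequence is mapped back into the position range the model saw during training. The sketch below is only an illustration of that idea. It assumes the interleaved-pair RoPE convention and a base of 10,000 (implementations differ, and Llama 3 uses a larger base); Llama 3.1 itself ships a different, frequency-dependent scaling scheme. The function names `rope_angles` and `apply_rope` are hypothetical.

```python
import torch

def rope_angles(positions, head_dim, base=10_000.0, scale=1.0):
    """Rotation angles of shape (len(positions), head_dim // 2).

    scale > 1 applies linear position interpolation: positions are divided by
    `scale`, so a sequence `scale` times longer than the training length reuses
    the position range seen in training.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    return torch.outer(positions.to(torch.float32) / scale, inv_freq)

def apply_rope(x, angles):
    """Rotate consecutive pairs (x[..., 2i], x[..., 2i + 1]) by angles[..., i]."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Example: a model trained on 8 positions, run on 16 positions with scale=2.
angles_long = rope_angles(torch.arange(16), head_dim=8, scale=2.0)
```

With `scale=2.0`, position 15 is rotated as if it were position 7.5, so every angle stays within the range covered by the original 8 training positions.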

Sanity check - bsz, num_steps, memory required

Using llama2-7b (~half of the memory llama3-8b needs):

LongLora (https://arxiv.org/pdf/2309.12307) fine-tunes for 1000 steps (bsz=64) on 8xA100 with DDP.


The YaRN paper (https://arxiv.org/pdf/2309.00071) fine-tunes for 400 steps (bsz=64, seq_len=64k), plus 200 steps at seq_len=128k.
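A quick sanity check on these budgets: the number of tokens seen during fine-tuning is steps × bsz × seq_len. The sketch below assumes every sequence is packed to full length and that 64k and 128k mean 65,536 and 131,072 tokens; the paper's exact setup may differ, and `tokens_seen` is a hypothetical helper.

```python
def tokens_seen(steps, batch_size, seq_len):
    # Assumes every sample is packed to exactly seq_len tokens.
    return steps * batch_size * seq_len

yarn_64k = tokens_seen(400, 64, 65_536)    # first stage at 64k
yarn_128k = tokens_seen(200, 64, 131_072)  # extra stage at 128k
total = yarn_64k + yarn_128k
print(f"YaRN: {yarn_64k / 1e9:.2f}B + {yarn_128k / 1e9:.2f}B = {total / 1e9:.2f}B tokens")
```

Under these assumptions, both stages see the same number of tokens (about 1.68B each, roughly 3.4B in total), because halving the steps exactly offsets doubling the sequence length.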

Ring attention and Memory Efficient Attention (https://arxiv.org/pdf/2310.01889)


Why


Figure source (r/LocalLLaMA thread on whether context lengths of 32k and above are useful): https://www.reddit.com/r/LocalLLaMA/comments/18t63ub/is_context_length_32k_actually_useful_to_you/

Below are some use cases. Beyond these, one can argue that supporting long context will enable new use cases that don't exist today.

Possible data flavors:

  • Code: torchtune has 60k lines of code;
  • Multimodality:
    • Image: 512x512, patch_size=14 → ~1,300 tokens;
    • Video: 1 min at 1 fps → ~78k tokens;
    • Audio:
      • 1 s of audio (~3 words) → ~25 tokens;
      • 1 minute → 1,500 tokens;
      • 60 minutes → 90k tokens;
  • Long texts (books, financial documents):
    • Albania Wikipedia page: 43k tokens;
    • Second page of “Attention is all you need”: 837 tokens, i.e. 100 pages ≈ 83k tokens;
  • Raw HTML:
    • Apple’s website, copy-pasted text: 997 tokens;
    • Apple’s website, raw HTML: 92,091 tokens.
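The multimodal estimates above follow from simple arithmetic. The sketch below is a minimal reproduction under the list's own assumptions: one token per non-overlapping 14×14 patch (leftover pixels dropped), each video frame costing as much as one 512x512 image, and an audio tokenizer producing 25 tokens per second. The helper names are hypothetical.

```python
def image_tokens(height, width, patch_size):
    # One token per non-overlapping patch; leftover pixels at the border are dropped.
    return (height // patch_size) * (width // patch_size)

def video_tokens(seconds, fps, tokens_per_frame):
    return seconds * fps * tokens_per_frame

def audio_tokens(seconds, tokens_per_second=25):
    return seconds * tokens_per_second

img = image_tokens(512, 512, 14)   # 36 x 36 patches
vid = video_tokens(60, 1, img)     # 1 minute at 1 fps
aud_1min, aud_1h = audio_tokens(60), audio_tokens(3600)
print(img, vid, aud_1min, aud_1h)
```

This prints `1296 77760 1500 90000`, matching the rounded figures in the list.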

Possible use cases:

  • Coding;
  • Multimodal understanding, e.g. textbooks with images, video+audio+captions, call transcripts, user web navigation;
  • Creative writing;
  • Summarization of multiple long, complex documents, e.g. financial documents;
  • Understanding complex HTML for clients that lack APIs;
  • In-context learning, e.g. many examples in the input;
  • Personalized search, e.g. over a user's entire history;

What about RAG?

RAG works in some scenarios and is cheap, but:

  • Adds complexity (e.g. parsers, hierarchical retrieval, embeddings, etc.);
  • Retrieval can fail, leading to wrong answers or hallucinations;
  • Sometimes you need the whole picture;
  • "The bitter lesson": Gemini can answer difficult questions about a 170k-token document: https://www.llamaindex.ai/blog/towards-long-context-rag.

However, long context and RAG are complementary rather than competing: with long context, retrieval can return whole documents instead of short paragraphs. In other words, long context enables new RAG capabilities.

Our current snapshot:


Currently we support at most ~24k tokens with Llama 3 8B on an A100-80GB, for LoRA, QLoRA, and full fine-tuning (FFT), with or without FSDP. Further memory optimizations could probably raise this a little, but likely not to 32k.


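You may ask: shouldn't QLoRA support much longer sequences than FFT? Not necessarily: at long context lengths, activations consume most of the memory, and QLoRA only shrinks the memory used by the weights, gradients, and optimizer states.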
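To make this concrete, here is a rough back-of-the-envelope memory model. It is illustrative only and rests on several assumptions: Llama 3 8B-like shapes (8B parameters, hidden size 4096, 32 layers, vocabulary 128,256), full activation checkpointing (only one bf16 hidden state kept per layer boundary), and fp32 logits. It ignores gradients, optimizer states, the recomputation peak inside a layer, and tricks such as chunked cross-entropy.

```python
# Rough, illustrative memory model; all constants are assumptions, not measurements.
N_PARAMS, HIDDEN, N_LAYERS, VOCAB = 8.0e9, 4096, 32, 128_256

def weights_gb(bytes_per_param):
    # bf16 ~ 2 bytes/param; 4-bit (e.g. NF4) ~ 0.5 bytes/param, ignoring quantization constants.
    return N_PARAMS * bytes_per_param / 1e9

def activations_gb(seq_len):
    # Full activation checkpointing: one bf16 (2-byte) hidden state per layer boundary,
    # plus fp32 (4-byte) logits over the vocabulary. Both grow linearly with seq_len.
    checkpoints = seq_len * HIDDEN * N_LAYERS * 2
    logits = seq_len * VOCAB * 4
    return (checkpoints + logits) / 1e9

for seq_len in (8_192, 24_576, 131_072):
    print(f"{seq_len:>7} tokens | bf16 weights {weights_gb(2):5.1f} GB | "
          f"4-bit weights {weights_gb(0.5):4.1f} GB | activations {activations_gb(seq_len):6.1f} GB")
```

Under these assumptions, the sequence-dependent memory at 24k tokens (~19 GB) already exceeds the bf16 weights (16 GB), and it keeps growing with sequence length while the weights stay fixed. Quantizing the weights therefore buys relatively little extra context.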

Our levers


  1. Inference solutions;
  2. Adding memory/recurrence modules to the transformer;
  3. Memory-efficient methods, such as:
    • Parameter-efficient methods (LoRA; GaLore, which claims to require 30% less memory than LoRA (https://arxiv.org/pdf/2407.08296));
    • Quantization (QLoRA, Q-GaLore);
    • Optimizers that require less memory (Adam-mini, Adafactor, 8-bit Adam);
  4. Custom efficient Triton kernels that reduce memory;
  5. Attention approximation methods;
  6. Parallelism.

(1) is not relevant, since our focus is fine-tuning. (2) is not widely adopted, and it is risky to add new modules that change the model's inference behavior. (3) We already support QLoRA and 8-bit optimizers. The issue is the growth of activation memory with sequence length (quadratic for naive attention), and such methods don't tackle it. (4) We rely on PyTorch for this. This leaves us with (5) and (6).

LongLora:

  • LoRA, but also fine-tunes the embeddings and normalization layers; otherwise results are not good;
  • Shifted sparse attention (S²-Attn): split the sequence into groups and compute attention only within each group. In half of the attention heads, the groups are shifted by half a group size, so information flows between neighboring groups. Other grouping strategies, e.g. stride and dilation, were tested and didn’t perform as well.


Pros:

  • Works with flash attention;
  • Paper shows promising perplexity and retrieval, comparable with full finetuning;
  • Inference can run with FULL attention (that's so nice);
  • Also works on a single GPU;

Cons:

  • It is an approximation;
  • It doesn’t scale: the upper bound is num_heads, and each device needs to gather the whole sequence. Their implementation uses DDP, and their maximum was 100k with Llama 2 7B, which requires less memory than Llama 3;
  • Unclear if it works with block causal masking;

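The shifted sparse attention described above can be sketched in PyTorch as follows. This is a minimal sketch modeled on the description in the LongLoRA paper, not their implementation. It assumes q, k, v are already projected to shape (batch, seq_len, num_heads, head_dim), that seq_len is divisible by the group size, and that the number of heads is even. The function name `s2_attention` is hypothetical. Like the paper's pseudocode, it does not special-case the wrap-around group created by the roll.

```python
import torch
import torch.nn.functional as F

def s2_attention(q, k, v, group_size):
    """Shifted sparse attention sketch. q, k, v: (batch, seq_len, num_heads, head_dim)."""
    B, N, H, D = q.shape
    G = group_size
    assert N % G == 0 and H % 2 == 0

    def shift(x, s):
        # Keep the first half of the heads as-is; roll the second half by s tokens.
        a, b = x.chunk(2, dim=2)
        return torch.cat([a, b.roll(s, dims=1)], dim=2)

    q, k, v = (shift(t, -G // 2) for t in (q, k, v))

    def to_groups(x):
        # Fold groups of G consecutive tokens into the batch dim: (B * N / G, H, G, D).
        return x.reshape(B * N // G, G, H, D).transpose(1, 2)

    # Causal attention computed independently inside each group.
    out = F.scaled_dot_product_attention(to_groups(q), to_groups(k), to_groups(v), is_causal=True)
    out = out.transpose(1, 2).reshape(B, N, H, D)
    # Shift the second half of the heads back to the original token order.
    return shift(out, G // 2)
```

Because only the grouping changes, the same projection weights can be used with regular full attention at inference time, which is what makes the approach attractive.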

Ring Attention

  • Ring Attention is a form of sequence (context) parallelism, where the tokens of a sequence are split across GPUs;
  • Every operation except attention processes each token independently, so only attention needs communication between devices;
  • Full attention requires a softmax over all tokens, but we want to avoid materializing the full quadratic attention matrix. How can this be done? Answer: compute the softmax incrementally by processing the attention in blocks/chunks, and then “undo” each block's normalization with its stored x.exp().sum(). Ring attention parallelizes this across devices: each device keeps its query block while key/value blocks are passed around a ring.

Pseudocode

import torch

def naive_softmax(x):
    return x.exp() / x.exp().sum()

x = torch.randn(8)

# done in vanilla attention
target = naive_softmax(x)

# softmax over chunks
x1, x2 = torch.chunk(x, 2)
softmax1 = naive_softmax(x1)
softmax2 = naive_softmax(x2)

# incremental softmax: rescale each chunk by its share of the global sum of exponentials
sum_exp_1 = x1.exp().sum()
sum_exp_2 = x2.exp().sum()

softmax1_corrected = softmax1 * sum_exp_1 / (sum_exp_1 + sum_exp_2)
softmax2_corrected = softmax2 * sum_exp_2 / (sum_exp_1 + sum_exp_2)

softmax_combined = torch.cat([softmax1_corrected, softmax2_corrected])
assert torch.allclose(target, softmax_combined)

Figure source: https://www.youtube.com/watch?v=ws7angQYIxI
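The chunked softmax extends to attention. Each device keeps its query block, receives key/value blocks one at a time around the ring, and folds each block into a running max, a running sum of exponentials, and an unnormalized output. The following single-process simulation is a minimal sketch under simplifying assumptions: one head, no causal mask, and no real communication (the ring is simulated with a loop over block indices). It also adds a running max for numerical stability, which the naive version omits. `ring_attention_sim` is a hypothetical name, not PyTorch's API.

```python
import torch

def ring_attention_sim(q, k, v, world_size):
    """Single-process simulation of (non-causal) ring attention for one head.

    q, k, v: (seq_len, head_dim). Simulated device `rank` owns chunk `rank` of q, k, v.
    """
    scale = q.shape[-1] ** -0.5
    qs, ks, vs = q.chunk(world_size), k.chunk(world_size), v.chunk(world_size)
    outputs = []
    for rank in range(world_size):
        q_blk = qs[rank]
        m = torch.full((q_blk.shape[0], 1), float("-inf"))  # running row max
        l = torch.zeros(q_blk.shape[0], 1)                   # running sum of exp
        acc = torch.zeros(q_blk.shape[0], v.shape[-1])       # unnormalized output
        for step in range(world_size):
            # At ring step `step`, this device holds the K/V block that started on rank - step.
            src = (rank - step) % world_size
            scores = (q_blk @ ks[src].T) * scale
            m_new = torch.maximum(m, scores.amax(dim=-1, keepdim=True))
            p = torch.exp(scores - m_new)
            correction = torch.exp(m - m_new)  # rescales the previous partial results
            l = l * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + p @ vs[src]
            m = m_new
        outputs.append(acc / l)  # final normalization, as in the chunked softmax
    return torch.cat(outputs)
```

Real implementations overlap sending the next key/value block with computing attention on the current one, and they also need to handle causal masks; this sketch does neither.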

Pros:

  • Scales linearly with the number of devices. E.g. if one device fits batch_size=1, seq_length=24k, then with context parallelism 8 devices fit batch_size=1, seq_length=8 × 24k;
  • The maximum activation size per device is 6bch (b: batch size, c: block size, h: hidden dimension), independent of the total sequence length. Other types of parallelism, such as tensor parallelism, would be overwhelmed by the activations of long sequences;
  • Is (probably) future-proof, as long as we use attention;
  • No approximation! It is equivalent to full attention;
  • No overhead, as long as sending a key/value block takes less time than computing attention on it, since communication overlaps with computation;
  • PyTorch is working to support it, and it integrates with FSDP, so according to our PoC in PyTorch, it should also handle quantization and torchao well;

Cons:

  • Block causal masks, or other mask types, make it more complicated and slower, so initially we would support only the full causal mask;
  • Still experimental in PyTorch;

Code pointers:
Pytorch: pytorch/pytorch#129515
Torchtitan: pytorch/torchtitan#433

Current constraints of the PyTorch implementation, which they are actively working on:

  • Doesn't work with compile yet;
  • Doesn't support arbitrary masking yet (e.g. block causal);
  • Doesn't support striped attention yet (makes ring attention faster https://arxiv.org/abs/2311.09431);
  • UI may change;
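Why striped attention helps under a causal mask: with contiguous sharding, the device holding the last tokens has far more unmasked query-key pairs than the device holding the first ones. Striped attention instead assigns tokens round-robin. The sketch below is a simplified view that only counts total causal work per device. It assumes seq_len is divisible by world_size, and `causal_work_per_device` is a hypothetical helper.

```python
def causal_work_per_device(seq_len, world_size, striped):
    """Count unmasked (query, key) pairs per device under a causal mask.

    Query token i attends to keys 0..i, i.e. i + 1 pairs.
    """
    chunk = seq_len // world_size
    work = [0] * world_size
    for i in range(seq_len):
        # contiguous: device r owns tokens [r * chunk, (r + 1) * chunk)
        # striped:    token i goes to device i % world_size
        rank = i % world_size if striped else i // chunk
        work[rank] += i + 1
    return work

print(causal_work_per_device(16, 4, striped=False))
print(causal_work_per_device(16, 4, striped=True))
```

For 16 tokens on 4 devices, contiguous sharding gives `[10, 26, 42, 58]` pairs per device, while striping gives `[28, 32, 36, 40]`. The total is the same, but the slowest device does much less work with striping.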

Suggestion

My suggestion is to focus on PyTorch's experimental ring attention implementation, since it scales well and will be supported by the PyTorch team. If their context parallel (CP) implementation scales ideally, 8xA100 would get us to 8 × 24k = 192k context length. LongLora alone would most likely not be enough to reach 64k with Llama 3.

Timeline


According to our PyTorch PoC, the code to implement context parallelism is short and can be written in a week. Therefore, I propose the following:

Week July 29th - Formalize a script that, given a model, finds max_seq_length for different settings, so we can test at 16GB, 24GB, 40GB, and 80GB
Week Aug 5th - Understand pytorch CP implementation
Week Aug 11th - Proof-of-concept for llama3.1 8b, stress test max seq len
Week Aug 19th - Propose and discuss implementation in torchtune
Week Aug 26th - Evaluation on key long context datasets
Week Sept 2nd - Optimize + unit test (maybe do it before eval?)
Week Sept 9th - Implement for the rest of the models
Week Sept 16th - Tutorials + best practices
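The first timeline item (a script that finds max_seq_length) could be sketched as a binary search over sequence lengths. This is a minimal sketch with a hypothetical `fits(seq_len)` predicate. In practice, `fits` might run one forward and backward step at that length and return False on `torch.cuda.OutOfMemoryError`, freeing memory before the next attempt. The search assumes monotonicity: if a length fits, every shorter length fits too.

```python
from typing import Callable

def find_max_seq_len(fits: Callable[[int], bool], lo: int = 1024,
                     hi: int = 262_144, step: int = 1024) -> int:
    """Largest multiple of `step` in [lo, hi] for which fits(seq_len) is True.

    Assumes `lo` is a multiple of `step` and that fits() is monotonic.
    Returns 0 if even `lo` does not fit.
    """
    if not fits(lo):
        return 0
    best = lo
    lo_k, hi_k = lo // step + 1, hi // step  # search over multiples k * step
    while lo_k <= hi_k:
        mid = (lo_k + hi_k) // 2
        if fits(mid * step):
            best, lo_k = mid * step, mid + 1
        else:
            hi_k = mid - 1
    return best
```

A binary search needs only about log2(256) ≈ 8 trial steps for this range, which matters when each trial is a full training step on a GPU.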

The timeline should look similar for LongLora. There is a risk that focusing too much on Llama 3.1 leaves us with something that needs refactoring to work for other models. We should minimize this risk by keeping the implementation model-independent, but prioritizing one model family will let us iterate faster, learn, and adjust later.

Evaluation


  • Multimodality is relevant, but not supported by torchtune yet. Possible dataset with 1T tokens: https://arxiv.org/pdf/2406.11271;
  • There is a lack of benchmarks, even for <32k, but it is an active area;
  • Most models report passkey retrieval results, which don't necessarily correlate with downstream task performance;
  • Furthermore, long-context training may hurt short-context performance. An interesting strategy is to evaluate across different context lengths and tasks;

Figure source: https://arxiv.org/pdf/2404.02060

Therefore, for evaluating our trained models I propose three strategies (text only):

  • Passkey retrieval, as a sanity check;
  • Measure improvement in long-context performance;
  • Check that short-context performance is retained;
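For the passkey sanity check, a synthetic example hides a random number at a chosen depth inside repeated filler text and asks the model to repeat it. The sketch below is modeled on the commonly used passkey format rather than copied from any specific benchmark. `make_passkey_prompt` is hypothetical, and context length is controlled only through the number of filler sentences, not exact token counts.

```python
import random

FILLER = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "

def make_passkey_prompt(n_filler, depth, seed=0):
    """Build one passkey-retrieval example.

    n_filler: number of filler blocks (controls context length).
    depth: relative position of the passkey, in [0, 1].
    """
    rng = random.Random(seed)
    passkey = str(rng.randint(10_000, 99_999))
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
    insert_at = round(depth * n_filler)
    body = FILLER * insert_at + needle + FILLER * (n_filler - insert_at)
    prompt = (
        "There is an important info hidden inside a lot of irrelevant text. "
        "Find it and memorize it. I will quiz you about the important information there. "
        + body
        + "What is the pass key? The pass key is"
    )
    return prompt, passkey
```

Sweeping `n_filler` and `depth` gives a grid over context lengths and needle positions. An answer can be scored as correct if the model's continuation starts with the passkey.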

As a starter, we can use LongAlpaca (https://huggingface.co/datasets/Yukang/LongAlpaca-12k), since it was also used by LongLora. It contains 9k long examples and 3k short examples sampled from the original Alpaca dataset, ranging from 35 to 191k characters.

This strategy can be improved during the evaluation focus time.
