Llama uses significantly more memory in 4.38 & 4.39 than 4.37 with identical code #30010

Closed
@warner-benjamin

Description

System Info

Transformers 4.37.2, 4.38.2, & 4.39.3.
Python 3.11.
PyTorch 2.2.2 with CUDA 12.1

Who can help?

@ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

QLoRA fine-tuning of Llama-7B with SDPA attention uses less memory on a 24GB card when run with 4.37.2 than with 4.38.2 or 4.39.3. You can reproduce this with the following script:

import argparse
from random import randint

from tqdm import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import Dataset
from peft import LoraConfig, get_peft_model

torch.set_float32_matmul_precision("high")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_sequence_length", type=int, default=1280)
    parser.add_argument("--dataset_size", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--flash_attn", action="store_true")
    parser.add_argument("--torch_compile", action="store_true")
    parser.add_argument("--profile_memory", action="store_true")

    return parser.parse_args()


def get_dataset(dataset_size, vocab_size, sequence_length):
    # Random token ids in [0, vocab_size - 1]; randint's upper bound is inclusive
    dataset = Dataset.from_dict(
        {"input_ids": [[randint(0, vocab_size - 1) for _ in range(sequence_length)] for _ in range(dataset_size)]}
    )
    return dataset


def data_collator(batch):
    batch = torch.stack([b["input_ids"] for b in batch])
    return {"input_ids": batch, "labels": batch}


def append_stats(batch, stats):
    # Record reserved CUDA memory on the first batch only
    if batch == 0:
        stats.append(torch.cuda.memory_reserved(0))


def main():
    args = parse_args()
    print(args)
    print(f"Transformers version: {transformers.__version__}")
    stats = []

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    peft_config = LoraConfig(
        r=16,
        lora_alpha=8,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        max_position_embeddings=args.max_sequence_length,
        attn_implementation="flash_attention_2" if args.flash_attn else "sdpa",
        torch_dtype="auto",
        use_cache=False,
        quantization_config=bnb_config,
    )

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    optimizer = optim.SGD(model.parameters(), lr=0.001)

    dataset = get_dataset(
        args.dataset_size * args.batch_size,
        tokenizer.vocab_size,
        args.max_sequence_length,
    ).with_format("torch")

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=data_collator,
    )

    if args.torch_compile:
        print("Compiling model")
        model = torch.compile(model)

    torch.cuda.reset_peak_memory_stats(0)
    for i, batch in enumerate(tqdm(dataloader)):
        input_ids = batch["input_ids"].to("cuda")
        labels = batch["labels"].to("cuda")

        if args.profile_memory and i == 0:
            torch.cuda.memory._record_memory_history()

        append_stats(i, stats)
        output = model(input_ids=input_ids, labels=labels, attention_mask=None)
        loss = output.loss

        append_stats(i, stats)
        loss.backward()

        append_stats(i, stats)
        optimizer.step()
        optimizer.zero_grad()

        if args.profile_memory and i == 0:
            torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
            torch.cuda.memory._record_memory_history(enabled=None)

    for label, stat in zip(["forward", "backward", "optimizer"], stats):
        print(f"Before {label}: {stat/2**30:.2f} GiB")
    print(f"Max Reserved: {torch.cuda.max_memory_reserved(0)/2**30:.2f} GiB")


if __name__ == "__main__":
    main()
# 4.37.2
python train.py --batch_size 1 --max_sequence_length 1280
python train.py --batch_size 1 --max_sequence_length 1536
# 4.39.3
python train.py --batch_size 1 --max_sequence_length 1280
# this one errors out on a 24GB card, likely due to OOM
python train.py --batch_size 1 --max_sequence_length 1536

Expected behavior

While spot-checking our FSDP+QLoRA script on Transformers 4.39.3, I noticed that the maximum batch size we could fine-tune with on two 24 GB cards at a sequence length of 2048 dropped from 12 to 5 compared to 4.37.2. This is because Llama 2 uses significantly more memory in 4.38 and 4.39 than in 4.37.

This Llama memory issue persists post #29753, which resolved some but not all of the Llama rewrite memory issues mentioned in #29484 and other issues.

I also reproduced the issue without FSDP on one 24GB card using the script above.

Version | Seq Len | Before Forward | Before Backward | Before Opt | Max Reserved
4.37.2  | 1280    | 3.90 GiB       | 15.54 GiB       | 15.89 GiB  | 18.79 GiB
4.37.2  | 1536    | 3.90 GiB       | 18.21 GiB       | 18.60 GiB  | 22.08 GiB
4.39.3  | 1280    | 3.99 GiB       | 15.92 GiB       | 17.06 GiB  | 19.61 GiB
4.39.3  | 1536    | -              | -               | -          | OOM

4.39.3 uses almost a GB more memory at a sequence length of 1280 and errors out at 1536.

Profiling with torch.cuda.memory._record_memory_history (the --profile_memory flag in the script above) shows the expected single peak in memory usage in 4.37.2 at the start of the backward pass:

[image: memory timeline for 4.37.2 showing a single peak at the start of the backward pass]

Switching to 4.39.3 shows an unexpected result: consistently large memory spikes during the backward pass of the SDPA kernel:

[image: memory timeline for 4.39.3 showing repeated spikes during the SDPA backward pass]

When sharding Llama across two cards with FSDP and gradient checkpointing, the memory spikes become quite visible as ~10GB outliers:

[image: memory timeline for 4.39.3 with FSDP and gradient checkpointing showing ~10GB spikes]

There are no spikes in 4.37 with FSDP.
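
For reference, the snapshots above come from PyTorch's allocator-history API (the --profile_memory path in the script). A minimal standalone capture looks roughly like the sketch below, where run_step is a stand-in for whatever forward/backward step you want to profile; the resulting pickle can be opened at https://pytorch.org/memory_viz.

# Minimal standalone memory-history capture (a sketch, not from the issue's
# script). Assumes PyTorch 2.1+ with a CUDA device.
import torch

def profile_one_step(run_step, out_path="memory_snapshot.pickle"):
    torch.cuda.memory._record_memory_history()  # start recording allocator events
    run_step()  # e.g. one forward + backward pass
    torch.cuda.synchronize()
    torch.cuda.memory._dump_snapshot(out_path)  # open at https://pytorch.org/memory_viz
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording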

The culprit appears to be the switch from the SDPA Flash Attention kernel in 4.37.2 to the SDPA Efficient kernel in 4.38 & 4.39.

# Transformers 4.37.2 LlamaSdpaAttention.forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
    is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
# Transformers 4.39.3 LlamaSdpaAttention.forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
)

By dropping is_causal and passing a custom causal_mask instead, scaled_dot_product_attention now dispatches to the more memory-hungry Efficient kernel rather than the Flash Attention 2 kernel it previously selected.
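
One way to see the dispatch difference directly is to restrict SDPA to the flash backend and try both call patterns. The snippet below is my own quick check, not part of the reproduction, assuming PyTorch 2.2 on an Ampere-or-newer GPU; torch.backends.cuda.sdp_kernel is the pre-2.3 context manager for backend selection.

# Sketch: restrict SDPA to the flash backend and compare the 4.37-style call
# (is_causal=True, no mask) with the 4.38/4.39-style call (explicit causal mask).
import torch
from torch.backends.cuda import sdp_kernel

q = torch.randn(1, 32, 1280, 128, device="cuda", dtype=torch.bfloat16)
k, v = q.clone(), q.clone()

with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    # No mask + is_causal=True: the flash kernel accepts this call.
    torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    try:
        # Explicit causal mask (analogous to the causal_mask built in 4.38/4.39):
        # the flash kernel cannot take an arbitrary attn_mask, so with the other
        # backends disabled SDPA has no kernel to dispatch to.
        mask = torch.ones(1280, 1280, device="cuda", dtype=torch.bool).tril()
        torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as err:
        print("flash backend rejected the masked call:", err)

With only the flash backend enabled, the masked call fails because flash does not accept an arbitrary attn_mask, which is why the 4.38/4.39 call pattern falls back to the Efficient kernel in normal runs.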

You can spot-check this by switching the reproduction script to LlamaFlashAttention2 with the --flash_attn flag.

Version | Seq Len | Before Forward | Before Backward | Before Opt | Max Reserved
4.37.2  | 1536    | 3.90 GiB       | 18.21 GiB       | 18.60 GiB  | 22.08 GiB
4.39.3  | 1536    | 3.99 GiB       | 18.36 GiB       | 18.75 GiB  | 22.24 GiB

With LlamaFlashAttention2, 4.39 uses moderately more memory than 4.37, though the gap may grow to a significant amount at longer context lengths.

When training without an attention mask passed to Llama, LlamaSdpaAttention should use is_causal=True instead of creating a causal_mask.
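
Concretely, a dispatch along these lines in LlamaSdpaAttention.forward would keep the flash-eligible path when no mask is supplied. This is a rough sketch of the suggested behavior using the variable names from the 4.39.3 snippet above, not a tested patch:

# Sketch only: use is_causal=True when the caller passed no attention mask so
# SDPA can still pick the flash backend; otherwise pass the prepared causal_mask.
is_causal = attention_mask is None and q_len > 1
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=None if is_causal else causal_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
    is_causal=is_causal,
)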

The switch to a custom causal_mask also causes torch.compile errors in 4.39 that are not present in 4.37, due to data type mismatches, particularly when training in pure bfloat16 instead of mixed precision.
