Description
System Info
Transformers 4.37.2, 4.38.2, & 4.39.3.
Python 3.11.
PyTorch 2.2.2 w/ CUDA 12.1.
Who can help?
@ArthurZucker and @younesbelkada
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
QLoRA fine-tuning of Llama-7B with SDPA attention uses less memory on a 24 GB card when run with 4.37.2 than with 4.38.2 or 4.39.3. You can reproduce with this script:
```python
import argparse
from random import randint
from tqdm import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import Dataset
from peft import LoraConfig, get_peft_model

torch.set_float32_matmul_precision("high")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_sequence_length", type=int, default=1280)
    parser.add_argument("--dataset_size", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--flash_attn", action="store_true")
    parser.add_argument("--torch_compile", action="store_true")
    parser.add_argument("--profile_memory", action="store_true")
    return parser.parse_args()


def get_dataset(dataset_size, vocab_size, sequence_length):
    # Random token ids are fine here since we only care about memory usage.
    # randint's upper bound is inclusive, so subtract one to stay in-vocab.
    dataset = Dataset.from_dict(
        {"input_ids": [[randint(0, vocab_size - 1) for _ in range(sequence_length)] for _ in range(dataset_size)]}
    )
    return dataset


def data_collator(batch):
    batch = torch.stack([b["input_ids"] for b in batch])
    return {"input_ids": batch, "labels": batch}


def append_stats(batch_idx, stats):
    # Only record reserved memory for the first batch.
    if batch_idx == 0:
        stats.append(torch.cuda.memory_reserved(0))


def main():
    args = parse_args()
    print(args)
    print(f"Transformers version: {transformers.__version__}")

    stats = []

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    peft_config = LoraConfig(
        r=16,
        lora_alpha=8,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        max_position_embeddings=args.max_sequence_length,
        attn_implementation="flash_attention_2" if args.flash_attn else "sdpa",
        torch_dtype="auto",
        use_cache=False,
        quantization_config=bnb_config,
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    optimizer = optim.SGD(model.parameters(), lr=0.001)

    dataset = get_dataset(
        args.dataset_size * args.batch_size,
        tokenizer.vocab_size,
        args.max_sequence_length,
    ).with_format("torch")

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=data_collator,
    )

    if args.torch_compile:
        print("Compiling model")
        model = torch.compile(model)

    torch.cuda.reset_peak_memory_stats(0)

    for i, batch in enumerate(tqdm(dataloader)):
        input_ids = batch["input_ids"].to("cuda")
        labels = batch["labels"].to("cuda")

        if args.profile_memory and i == 0:
            torch.cuda.memory._record_memory_history()

        append_stats(i, stats)
        output = model(input_ids=input_ids, labels=labels, attention_mask=None)
        loss = output.loss
        append_stats(i, stats)
        loss.backward()
        append_stats(i, stats)
        optimizer.step()
        optimizer.zero_grad()

        if args.profile_memory and i == 0:
            torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
            torch.cuda.memory._record_memory_history(enabled=None)

    for label, stat in zip(["forward", "backward", "optimizer"], stats):
        print(f"Before {label}: {stat/2**30:.2f} GiB")
    print(f"Max Reserved: {torch.cuda.max_memory_reserved(0)/2**30:.2f} GiB")


if __name__ == "__main__":
    main()
```
```bash
# 4.37.2
python train.py --batch_size 1 --max_sequence_length 1280
python train.py --batch_size 1 --max_sequence_length 1536

# 4.39.3
python train.py --batch_size 1 --max_sequence_length 1280
# this one errors out on a 24 GB card, likely due to OOM
python train.py --batch_size 1 --max_sequence_length 1536
```
Expected behavior
While spot-checking our FSDP+QLoRA script on Transformers 4.39.3, I noticed that the maximum batch size we could fine-tune on two 24 GB cards at a sequence length of 2048 dropped from 12 (on 4.37.2) to 5. This is because Llama 2 in 4.38 and 4.39 uses significantly more memory than in 4.37.
This Llama memory issue persists post #29753, which resolved some but not all of the Llama rewrite memory issues mentioned in #29484 and other issues.
I also reproduced the issue without FSDP on one 24GB card using the script above.
| Version | Seq Len | Before Forward | Before Backward | Before Opt | Max Reserved |
|---|---|---|---|---|---|
| 4.37.2 | 1280 | 3.90 GiB | 15.54 GiB | 15.89 GiB | 18.79 GiB |
| 4.37.2 | 1536 | 3.90 GiB | 18.21 GiB | 18.60 GiB | 22.08 GiB |
| 4.39.3 | 1280 | 3.99 GiB | 15.92 GiB | 17.06 GiB | 19.61 GiB |
| 4.39.3 | 1536 | - | - | - | OOM |
4.39.3 uses almost 1 GiB more memory at a sequence length of 1280 and errors out at 1536.

Using `torch.cuda.memory._record_memory_history` shows the expected peak in memory usage in 4.37.2 at the start of the backward pass (use the `--profile_memory` flag in the above script):
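For reference, the profiling path behind `--profile_memory` boils down to the minimal sketch below (the same private `torch.cuda.memory` calls used in the script above); the resulting pickle can be inspected with PyTorch's memory visualizer at https://pytorch.org/memory_viz.

```python
# Minimal sketch of the --profile_memory path from the script above.
import torch

torch.cuda.memory._record_memory_history()                  # start recording CUDA allocations
# ... run one forward/backward/optimizer step here ...
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")   # write the trace to disk
torch.cuda.memory._record_memory_history(enabled=None)       # stop recording
```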

Switching to 4.39.3 shows an unexpected result: consistently large memory spikes during the backward pass of the SDPA kernel:

When sharding Llama across two cards with FSDP and gradient checkpointing, the memory spikes become quite visible as ~10 GB outliers:

There are no spikes in 4.37 with FSDP.
The culprit appears to be the switch from the SDPA Flash Attention kernel in 4.37.2 to the SDPA efficient attention kernel in 4.38 & 4.39.
```python
# Transformers 4.37.2 LlamaSdpaAttention.forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
    is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
```

```python
# Transformers 4.39.3 LlamaSdpaAttention.forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
)
```
By removing `is_causal` and passing a custom `causal_mask` instead, `scaled_dot_product_attention` now dispatches to the more memory-hungry efficient attention kernel rather than the Flash Attention kernel, which uses less memory. You can spot-check this by switching to `LlamaFlashAttention2` in the reproduction script with the `--flash_attn` flag.
| Version | Seq Len | Before Forward | Before Backward | Before Opt | Max Reserved |
|---|---|---|---|---|---|
| 4.37.2 | 1536 | 3.90 GiB | 18.21 GiB | 18.60 GiB | 22.08 GiB |
| 4.39.3 | 1536 | 3.99 GiB | 18.36 GiB | 18.75 GiB | 22.24 GiB |
With `LlamaFlashAttention2`, 4.39 uses moderately more memory than 4.37, although the gap may grow at longer context lengths.
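Independently of Transformers, the dispatch behaviour can be confirmed directly against PyTorch. The snippet below is my own minimal sketch (not part of the reproduction script), assuming PyTorch 2.2's `torch.backends.cuda.sdp_kernel` context manager and a GPU that supports the Flash backend: restricting SDPA to the Flash kernel works for the no-mask, `is_causal=True` call pattern, but raises once an explicit `attn_mask` is supplied, because the Flash kernel does not accept arbitrary masks.

```python
# Sketch: check which SDPA backend can serve each call pattern (assumes a CUDA GPU
# with Flash Attention support and PyTorch 2.2; sdp_kernel is deprecated in later releases).
import torch
from torch.backends.cuda import sdp_kernel

q = torch.randn(1, 32, 1280, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
causal_mask = torch.ones(1280, 1280, device="cuda", dtype=torch.bool).tril()

with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    # 4.37.2-style call: no mask, is_causal=True -> served by the Flash kernel.
    torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    try:
        # 4.38/4.39-style call: explicit attn_mask -> the Flash kernel is ruled out.
        torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
    except RuntimeError as err:
        print(f"Flash backend unavailable with an explicit mask: {err}")
```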
When training, `LlamaSdpaAttention` should pass `is_causal=True` to `scaled_dot_product_attention` instead of constructing a `causal_mask` when no attention mask is passed to Llama.
The switch to a custom `causal_mask` also causes `torch.compile` errors in 4.39 that are not present in 4.37, due to data type mismatches, particularly when training in pure bfloat16 instead of mixed precision.