Description
Hi,
I have been doing some PEFT tuning with Mistral/Mixtral and recently observed a training slowdown since the release of version 4.40.0. I narrowed it down to the fix in 40eb6d6, where the sliding window is now passed to `_prepare_4d_causal_attention_mask_for_sdpa`.
I ran a simple training job on both releases and the training statistics show two different sets of throughputs:
| Sequence Length | release 4.39.3 (toks/s) | release 4.40.0 (toks/s) |
|---|---|---|
| 4096 | 3247 | 2483 |
| 8192 | 3083 | 1918 |
When my training sequence length is within or at the sliding window threshold (i.e. seqlen = 4096, window = 4096), the mask preparation should fall back to letting the SDPA kernel handle the causal mask itself. I also don't see the expected computation savings at sequence length 8192 from the introduction of sliding window attention, compared to computing attention across all 8192 tokens with no windowed causal mask at all.
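To illustrate the first point, here is a minimal sketch (my own illustration, not code from transformers) showing that when the sequence length does not exceed the window, the sliding-window causal mask is identical to the plain causal mask, so nothing is lost by dropping the explicit mask and letting SDPA use its `is_causal=True` fast path:

```python
import torch

seq_len, window = 4096, 4096
idx = torch.arange(seq_len)

# Plain causal mask: token i attends to all j <= i.
causal = idx[None, :] <= idx[:, None]
# Sliding-window causal mask: additionally require i - j < window.
sliding = causal & (idx[:, None] - idx[None, :] < window)

# With seq_len <= window the two masks are identical.
assert torch.equal(causal, sliding)
```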
Below is a dummy example showing that simply not passing a mask to PyTorch's SDPA function (letting the kernel handle the causal mask itself) versus passing an explicit sliding-window mask has a significant impact on the kernel's processing speed (a rough sketch of this comparison follows the screenshots).
(Screenshots: benchmark output with the causal attention mask passed to Torch SDPA vs. the causal mask handled internally by Torch SDPA.)
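In case the screenshots don't come through, here is a rough sketch of the kind of comparison they show (my own approximation, not the exact script behind the screenshots; the shapes, dtype, window size, and the `bench` helper are all assumptions):

```python
import time
import torch
import torch.nn.functional as F

device = "cuda"
dtype = torch.float16
batch, heads, seq_len, head_dim, window = 4, 32, 4096, 128, 4096

q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)

# Sliding-window causal mask: position i may attend to j only if
# j <= i and i - j < window (additive float mask for SDPA).
idx = torch.arange(seq_len, device=device)
allowed = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)
float_mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
float_mask.masked_fill_(~allowed, torch.finfo(dtype).min)

def bench(fn, iters=20):
    # Warm up, then time with CUDA synchronization.
    for _ in range(3):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

t_internal = bench(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True))
t_explicit = bench(lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask))
print(f"is_causal=True (kernel handles mask): {t_internal * 1e3:.2f} ms/iter")
print(f"explicit sliding-window mask:         {t_explicit * 1e3:.2f} ms/iter")
```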
Is this slowdown something we should expect from using the SDPA attention implementation with the current fix?
I attached a simple script below to reproduce the issue.
System Info
- `transformers` version: 4.40.0
- Platform: Linux-4.18.0-372.71.1.el8_6.x86_64-x86_64-with-glibc2.31
- Python version: 3.10.8
- Huggingface_hub version: 0.22.2
- Safetensors version: 0.4.3
- Accelerate version: 0.29.3
- Accelerate config: not found
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: True
- Using distributed or parallel set-up in script?: False
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
- Script to reproduce
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, __version__ as transformer_version
from datasets import load_dataset
from trl import SFTTrainer

print(f"transformers version: {transformer_version}")

dataset = load_dataset("yahma/alpaca-cleaned", split="train")

model_name = 'mistralai/Mistral-7B-v0.1'
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"pad_token": tokenizer.unk_token})
tokenizer.pad_token = tokenizer.unk_token

model = AutoModelForCausalLM.from_config(
    config,
    torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
    attn_implementation="sdpa",
)
print(model.model._attn_implementation)

args = {
    'batch_size': 4,
    'gradient_accumulation_steps': 1,
    'use_gradient_checkpointing': True,
    'warmup_steps': 10,
    'lr': 2e-4,
    'logging_steps': 10,
    'output_dir': './results',
    'optimizer': 'adamw_torch',
    'weight_decay': 0.0,
    'lr_scheduler': 'linear',
    'seed': 42,
    'max_steps': 100,
    'context_length': 4096,
}

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n\n"
    ),
}


def formatting_prompts_func(example):
    # Build the full prompt + response string for each example.
    if example.get("input", "") == "":
        prompt = PROMPT_DICT["prompt_no_input"].format_map(example)
    else:
        prompt = PROMPT_DICT["prompt_input"].format_map(example)
    return prompt + example["output"]


training_args = TrainingArguments(
    per_device_train_batch_size=args['batch_size'],
    gradient_accumulation_steps=args['gradient_accumulation_steps'],
    gradient_checkpointing=args['use_gradient_checkpointing'],
    warmup_steps=args['warmup_steps'],
    max_steps=args['max_steps'],
    learning_rate=args['lr'],
    logging_strategy='steps',
    logging_steps=args['logging_steps'],
    output_dir=args['output_dir'],
    optim=args['optimizer'],
    weight_decay=args['weight_decay'],
    lr_scheduler_type=args['lr_scheduler'],
    seed=args['seed'],
    include_tokens_per_second=True,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    max_seq_length=args['context_length'],
    args=training_args,
    formatting_func=formatting_prompts_func,
    packing=True,
)

stats = trainer.train()
```
Expected behavior
- Throughput should remain the same as 4.39.3 for sequence lengths at or below the window size when using SDPA.
- Throughput should be slightly faster than regular attention (no sliding window specified in the causal mask) for longer sequence lengths, thanks to the fewer computations in local attention (see the rough count below).
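As a rough back-of-the-envelope check on the second point (my own arithmetic, assuming attention cost scales with the number of query-key pairs actually computed):

```python
seq_len, window = 8192, 4096

# Full causal attention: token i attends to i + 1 keys.
full_causal_pairs = seq_len * (seq_len + 1) // 2
# Sliding-window causal attention: token i attends to at most `window` keys.
sliding_pairs = sum(min(i + 1, window) for i in range(seq_len))

# Roughly 1.33x fewer query-key pairs with the window at seq_len = 8192.
print(full_causal_pairs / sliding_pairs)
```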