
Slowdown in Training Speed Due to SDPA Mask Fix in Version 4.40.0  #30461

Closed
@achew010

Description

System Info

Hi,

I have been doing some PEFT tuning with Mistral/Mixtral, and recently I observed a slowdown in training since the release of version 4.40.0. I narrowed it down to the fix in 40eb6d6, where the sliding window is now passed to `_prepare_4d_causal_attention_mask_for_sdpa`.

I ran a simple training job, and the training statistics show two different sets of throughputs:

Sequence Length | release 4.39.3 (toks/s) | release 4.40.0 (toks/s)
--------------- | ----------------------- | -----------------------
4096            | 3247                    | 2483
8192            | 3083                    | 1918

When my training sequence length is within or equal to the sliding-window size (i.e. seqlen = 4096, window = 4096), it should fall back to letting the SDPA kernel handle the causal mask itself. I also don't see the computation savings at sequence length 8192 that sliding-window attention should bring, compared to having no windowed causal mask at all (i.e. computing attention across all 8192 tokens).
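
For context, below is a minimal sketch (my own illustration, not the actual transformers code) of the additive sliding-window causal mask that `_prepare_4d_causal_attention_mask_for_sdpa` now produces; the exact off-by-one convention of the window may differ. The point is that when the sequence length is at or below the window, the banded mask is identical to the plain causal mask, so in principle nothing needs to be passed to SDPA.

import torch

def sliding_window_causal_mask(seq_len, window, dtype=torch.float16, device="cpu"):
    # illustrative only: additive mask where query i may attend to keys j
    # with j <= i and i - j < window (the exact window convention may differ)
    i = torch.arange(seq_len, device=device).unsqueeze(-1)  # query positions
    j = torch.arange(seq_len, device=device).unsqueeze(0)   # key positions
    keep = (j <= i) & (j > i - window)
    mask = torch.zeros(seq_len, seq_len, dtype=dtype, device=device)
    mask.masked_fill_(~keep, torch.finfo(dtype).min)
    return mask[None, None, :, :]  # broadcastable to (bsz, heads, q_len, kv_len)

# with seq_len == window the band covers the full lower triangle,
# i.e. it is identical to the plain causal mask
assert torch.equal(
    sliding_window_causal_mask(16, 16),
    sliding_window_causal_mask(16, 10_000),
)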

Below is a dummy example showing that simply not passing the causal mask to PyTorch's SDPA function (letting the kernel handle causality itself), versus passing an explicit sliding-window mask, has a significant impact on the kernel's processing speed.

[Benchmark plots: "Causal/attention mask passed to Torch SDPA" vs. "Causal mask handled internally in Torch SDPA"]
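
For completeness, the dummy benchmark is roughly the following sketch (a single fp16 SDPA call on one GPU; the shapes and absolute timings are illustrative only):

import time
import torch
import torch.nn.functional as F

bsz, heads, seq_len, head_dim, window = 4, 32, 4096, 128, 4096
q = torch.randn(bsz, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# explicit additive sliding-window/causal mask, as sketched above
pos = torch.arange(seq_len, device="cuda")
keep = (pos[None, :] <= pos[:, None]) & (pos[None, :] > pos[:, None] - window)
mask = torch.zeros(seq_len, seq_len, dtype=torch.float16, device="cuda")
mask.masked_fill_(~keep, torch.finfo(torch.float16).min)
mask = mask[None, None]

def bench(fn, iters=50):
    for _ in range(5):  # warmup
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3  # ms per call

t_mask = bench(lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask))
t_causal = bench(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True))
print(f"explicit mask: {t_mask:.2f} ms/call, is_causal=True: {t_causal:.2f} ms/call")

On my setup the is_causal=True path is noticeably faster; my understanding is that SDPA can only dispatch to its flash-attention kernel when no explicit attn_mask is given, and otherwise falls back to the memory-efficient or math kernels.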

Is this slowdown something we should expect from using the SDPA module with the current fix?

I attached a simple script to reproduce the issue below.

System Info

- `transformers` version: 4.40.0
- Platform: Linux-4.18.0-372.71.1.el8_6.x86_64-x86_64-with-glibc2.31
- Python version: 3.10.8
- Huggingface_hub version: 0.22.2
- Safetensors version: 0.4.3
- Accelerate version: 0.29.3
- Accelerate config:    not found
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: True
- Using distributed or parallel set-up in script?: False

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Script to reproduce
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, __version__ as transformer_version
from datasets import load_dataset
from trl import SFTTrainer

print(f"transformers version: {transformer_version}")

dataset = load_dataset("yahma/alpaca-cleaned", split="train")
model_name = 'mistralai/Mistral-7B-v0.1'

config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"pad_token": tokenizer.unk_token})
tokenizer.pad_token = tokenizer.unk_token

model = AutoModelForCausalLM.from_config(
    config, torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
    attn_implementation="sdpa",
)

print(model.model._attn_implementation)

args = {
    'batch_size': 4,
    'gradient_accumulation_steps': 1,
    'use_gradient_checkpointing': True,
    'warmup_steps': 10,
    'lr': 2e-4,
    'logging_steps': 10,
    'output_dir': './results',
    'optimizer': 'adamw_torch',
    'weight_decay': 0.0,
    'lr_scheduler': 'linear',
    'seed': 42,
    'max_steps': 100,
    'context_length': 4096,
}

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n\n"
    ),
}

def formatting_prompts_func(example):
    # With packing=True, SFTTrainer expects a single formatted string per example
    if example.get("input", "") == "":
        prompt = PROMPT_DICT["prompt_no_input"].format_map(example)
    else:
        prompt = PROMPT_DICT["prompt_input"].format_map(example)
    return prompt + example["output"]

training_args = TrainingArguments(
    per_device_train_batch_size = args['batch_size'],
    gradient_accumulation_steps = args['gradient_accumulation_steps'],
    gradient_checkpointing=args['use_gradient_checkpointing'],
    warmup_steps = args['warmup_steps'],
    max_steps = args['max_steps'],
    learning_rate = args['lr'],
    logging_strategy = 'steps',
    logging_steps = args['logging_steps'],
    output_dir = args['output_dir'],
    optim = args['optimizer'],
    weight_decay = args['weight_decay'],
    lr_scheduler_type = args['lr_scheduler'],
    seed = args['seed'],
    include_tokens_per_second = True,
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    max_seq_length = args['context_length'],
    args = training_args,
    formatting_func = formatting_prompts_func,
    packing = True,
)

stats = trainer.train()

Expected behavior

  1. Throughput should remain the same for sequence lengths at or below the window size, since SDPA can still handle the causal mask internally (see the backend check sketched after this list)

  2. Throughput should be slightly higher for longer sequence lengths than with regular attention (no sliding window specified in the causal mask), thanks to the fewer computations of local attention
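
A quick way to see which SDPA backend actually gets used is to restrict dispatch to the flash kernel only (sketch for PyTorch 2.2, where `torch.backends.cuda.sdp_kernel` is the documented context manager; if my understanding of the dispatch rules is right, the flash kernel refuses an explicit attn_mask):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 4096, 128, dtype=torch.float16, device="cuda")
# plain additive causal mask (strict upper triangle masked out)
causal_mask = torch.full((4096, 4096), torch.finfo(torch.float16).min,
                         dtype=torch.float16, device="cuda").triu(1)[None, None]

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    F.scaled_dot_product_attention(q, q, q, is_causal=True)  # dispatches to the flash kernel
    try:
        F.scaled_dot_product_attention(q, q, q, attn_mask=causal_mask)
    except RuntimeError as err:
        # with only the flash backend enabled, an explicit mask leaves no usable kernel
        print("explicit mask cannot run on the flash backend:", err)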
