docs/source/kernels_hub.md (8 changes: 4 additions & 4 deletions)

````diff
@@ -27,20 +27,20 @@ from transformers import AutoModelForCausalLM
 
 model = AutoModelForCausalLM.from_pretrained(
     "your-model-name",
-    attn_implementation="kernels-community/flash-attn" # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
+    attn_implementation="kernels-community/flash-attn2" # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
 )
 ```
 
 Or when running a TRL training script:
 
 ```bash
-python sft.py ... --attn_implementation kernels-community/flash-attn
+python sft.py ... --attn_implementation kernels-community/flash-attn2
 ```
 
 Or using the TRL CLI:
 
 ```bash
-trl sft ... --attn_implementation kernels-community/flash-attn
+trl sft ... --attn_implementation kernels-community/flash-attn2
 ```
 
 > [!TIP]
@@ -84,7 +84,7 @@ from trl import SFTConfig
 
 model = AutoModelForCausalLM.from_pretrained(
     "your-model-name",
-    attn_implementation="kernels-community/flash-attn" # choose the desired FlashAttention variant
+    attn_implementation="kernels-community/flash-attn2" # choose the desired FlashAttention variant
 )
 
 training_args = SFTConfig(
````
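For context, the documented flow after this change is to point `attn_implementation` at a Hub kernel repository and train as usual. A minimal sketch, assuming the `kernels` package is installed alongside `transformers` and TRL, and a supported GPU is available; the model and dataset names below are placeholders, not part of this PR:

```python
# Minimal sketch: load a model with a Hub-provided FlashAttention kernel
# and fine-tune it with TRL's SFTTrainer.
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # placeholder model
    attn_implementation="kernels-community/flash-attn2",  # kernel pulled from the Hub
)

dataset = load_dataset("trl-lib/Capybara", split="train")  # placeholder dataset

training_args = SFTConfig(output_dir="qwen-sft-flash-attn2")
trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
```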
trl/trainer/model_config.py (4 changes: 2 additions & 2 deletions)

````diff
@@ -43,8 +43,8 @@ class ModelConfig:
             be set to `True` for repositories you trust and in which you have read the code, as it will execute code
             present on the Hub on your local machine.
         attn_implementation (`str`, *optional*):
-            Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case
-            you must install this manually by running `pip install flash-attn --no-build-isolation`.
+            Which attention implementation to use. More information in the [Kernels Hub Integrations
+            Guide](kernels_hub).
         use_peft (`bool`, *optional*, defaults to `False`):
             Whether to use PEFT for training.
         lora_r (`int`, *optional*, defaults to `16`):
````
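In practice, scripts consume this field by forwarding it into `from_pretrained`. A rough sketch of that pattern, assuming the usual TRL script setup where `ModelConfig` is filled in by the argument parser; the parser wiring and command line below are illustrative, not taken from this PR:

```python
# Rough sketch of how a training script typically forwards
# ModelConfig.attn_implementation into the model constructor.
from transformers import AutoModelForCausalLM, HfArgumentParser
from trl import ModelConfig

parser = HfArgumentParser(ModelConfig)
(model_args,) = parser.parse_args_into_dataclasses()
# e.g. launched with:
#   python train.py --model_name_or_path your-model-name \
#       --attn_implementation kernels-community/flash-attn2

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
    attn_implementation=model_args.attn_implementation,  # may name a Hub kernel repo
)
```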
trl/trainer/sft_trainer.py (4 changes: 2 additions & 2 deletions)

````diff
@@ -72,9 +72,9 @@
 FLASH_ATTENTION_VARIANTS = {
     "flash_attention_2",
     "flash_attention_3",
-    "kernels-community/flash-attn",
-    "kernels-community/vllm-flash-attn3",
+    "kernels-community/flash-attn2",
+    "kernels-community/flash-attn3",
     "kernels-community/vllm-flash-attn3",
 }
````
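This set is what lets the trainer treat Hub kernels and the built-in implementations uniformly when deciding whether flash-attention-only behavior applies. A plausible sketch of that kind of membership check; the helper name and the padding-free framing are illustrative, not the trainer's actual code:

```python
# Illustrative use of a FLASH_ATTENTION_VARIANTS-style set: features such as
# padding-free batching only make sense when the model runs a FlashAttention
# implementation, whether built in or fetched from the Hub.
from typing import Optional

FLASH_ATTENTION_VARIANTS = {
    "flash_attention_2",
    "flash_attention_3",
    "kernels-community/flash-attn2",
    "kernels-community/flash-attn3",
    "kernels-community/vllm-flash-attn3",
}


def supports_padding_free(attn_implementation: Optional[str]) -> bool:
    """Return True when the configured attention backend is a FlashAttention variant.

    Hypothetical helper for illustration only; not TRL's actual API.
    """
    return attn_implementation in FLASH_ATTENTION_VARIANTS


print(supports_padding_free("kernels-community/flash-attn2"))  # True
print(supports_padding_free("sdpa"))                           # False
```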