
Commit cbe38d7

docs: Address PR feedback for PEFT integration guide
- Add comprehensive learning rate section with table and blog links
- Add learning_rate parameters to all code examples (SFT, DPO, GRPO, QLoRA, Prompt Tuning)
- Remove Full Training (No PEFT) sections for cleaner focus
- Remove Troubleshooting section as requested
- Document three methods of PEFT configuration (CLI, peft_config, get_peft_model)
- Enhance Resources section with TRL notebooks, examples, and Cookbook
- Simplify Python examples using ellipsis for non-PEFT configs
- Fix import order (standard library before third-party)
1 parent 7ff168e commit cbe38d7


docs/source/peft_integration.md

Lines changed: 155 additions & 64 deletions
@@ -49,6 +49,12 @@ peft_config = LoraConfig(
     task_type="CAUSAL_LM",
 )
 
+# Configure training - note the higher learning rate for LoRA (10x base rate)
+training_args = SFTConfig(
+    learning_rate=2.0e-4,  # 10x the base rate (2.0e-5) for LoRA
+    ...
+)
+
 # Create trainer with PEFT
 trainer = SFTTrainer(
     model=model,
@@ -58,6 +64,107 @@ trainer = SFTTrainer(
 )
 ```
 
+## Three Ways to Configure PEFT
+
+TRL provides three different methods to configure PEFT, each suited for different use cases:
+
+### 1. Using CLI Flags (Simplest)
+
+The easiest way to enable PEFT is using the `--use_peft` flag with the command-line interface. This method is ideal for quick experiments and standard configurations:
+
+```bash
+python trl/scripts/sft.py \
+    --model_name_or_path Qwen/Qwen2-0.5B \
+    --dataset_name trl-lib/Capybara \
+    --use_peft \
+    --lora_r 32 \
+    --lora_alpha 16 \
+    --lora_dropout 0.05 \
+    --output_dir Qwen2-0.5B-SFT-LoRA
+```
+
+**Pros**: Quick setup, no code required
+**Cons**: Limited to LoRA, fewer customization options
+
+### 2. Passing peft_config to Trainer (Recommended)
+
+For more control, pass a PEFT configuration directly to the trainer. This is the recommended approach for most use cases:
+
+```python
+from peft import LoraConfig
+from trl import SFTConfig, SFTTrainer
+
+peft_config = LoraConfig(
+    r=32,
+    lora_alpha=16,
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM",
+    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
+)
+
+trainer = SFTTrainer(
+    model=model,
+    args=training_args,
+    train_dataset=dataset,
+    peft_config=peft_config,  # Pass config here
+)
+```
+
+**Pros**: Full control, supports all PEFT methods (LoRA, Prompt Tuning, etc.)
+**Cons**: Requires Python code
+
+### 3. Applying PEFT to Model Directly (Advanced)
+
+For maximum flexibility, you can apply PEFT to your model before passing it to the trainer:
+
+```python
+from peft import LoraConfig, get_peft_model
+from transformers import AutoModelForCausalLM
+from trl import SFTConfig, SFTTrainer
+
+# Load base model
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
+
+# Apply PEFT configuration
+peft_config = LoraConfig(
+    r=32,
+    lora_alpha=16,
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM",
+)
+model = get_peft_model(model, peft_config)
+
+# Pass PEFT-wrapped model to trainer
+trainer = SFTTrainer(
+    model=model,  # Already has PEFT applied
+    args=training_args,
+    train_dataset=dataset,
+    # Note: no peft_config needed here
+)
+```
+
+**Pros**: Maximum control, useful for custom model architectures or complex setups
+**Cons**: More verbose, requires understanding of PEFT internals
+
+## Learning Rate Considerations
+
+When using LoRA or other PEFT methods, you typically need to use a **higher learning rate** (approximately 10x) compared to full fine-tuning. This is because PEFT methods train only a small fraction of parameters, requiring a larger learning rate to achieve similar parameter updates.
+
+**Recommended learning rates:**
+
+| Trainer | Full Fine-Tuning | With LoRA (10x) |
+|---------|------------------|-----------------|
+| **SFT** | `2.0e-5` | `2.0e-4` |
+| **DPO** | `5.0e-7` | `5.0e-6` |
+| **GRPO** | `1.0e-6` | `1.0e-5` |
+| **Prompt Tuning** | N/A | `1.0e-2` to `3.0e-2` |
+
+> **Why 10x?** LoRA adapters have significantly fewer trainable parameters than the full model. A higher learning rate compensates for this reduced parameter count, ensuring effective training. For a detailed explanation, see [this blog post](https://thinkingmachines.ai/blog/lora/).
+
+For additional best practices on using LoRA effectively, refer to the [LoRA Without Regret](lora_without_regret) documentation.
+
 ## PEFT with Different Trainers
 
 TRL's trainers support PEFT configurations for various training paradigms. Below are detailed examples for each major trainer.
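
The 10x guidance above can be sanity-checked by looking at how few parameters LoRA actually trains. A minimal sketch, assuming the same `Qwen/Qwen2-0.5B` checkpoint used throughout this guide; the 10x multiplier is applied by hand here as a heuristic taken from the table, not a TRL setting:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from trl import SFTConfig

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
peft_config = LoraConfig(r=32, lora_alpha=16, task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)

# Prints the trainable vs. total parameter counts; with LoRA only a small
# fraction of the weights receives gradients.
model.print_trainable_parameters()

# Heuristic from the table above: scale the full fine-tuning rate by ~10x for LoRA.
base_lr = 2.0e-5  # full fine-tuning SFT rate
training_args = SFTConfig(output_dir="Qwen2-0.5B-SFT-LoRA", learning_rate=10 * base_lr)
```
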
@@ -69,19 +176,6 @@ TRL's trainers support PEFT configurations for various training paradigms. Below
 
 The `SFTTrainer` is used for supervised fine-tuning on instruction datasets.
 
-#### Full Training (No PEFT)
-
-```bash
-python trl/scripts/sft.py \
-    --model_name_or_path Qwen/Qwen2-0.5B \
-    --dataset_name trl-lib/Capybara \
-    --learning_rate 2.0e-5 \
-    --num_train_epochs 1 \
-    --per_device_train_batch_size 2 \
-    --gradient_accumulation_steps 8 \
-    --output_dir Qwen2-0.5B-SFT
-```
-
 #### With LoRA
 
 ```bash
@@ -114,6 +208,12 @@ peft_config = LoraConfig(
     target_modules=["q_proj", "v_proj"],  # Optional: specify target modules
 )
 
+# Configure training with higher learning rate for LoRA
+training_args = SFTConfig(
+    learning_rate=2.0e-4,  # 10x the base rate for LoRA
+    ...
+)
+
 # Create trainer with PEFT config
 trainer = SFTTrainer(
     model=model,
@@ -132,18 +232,6 @@ trainer.train()
 
 The `DPOTrainer` implements preference learning from human feedback.
 
-#### Full Training (No PEFT)
-
-```bash
-python trl/scripts/dpo.py \
-    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
-    --dataset_name trl-lib/ultrafeedback_binarized \
-    --learning_rate 5.0e-7 \
-    --per_device_train_batch_size 2 \
-    --gradient_accumulation_steps 8 \
-    --output_dir Qwen2-0.5B-DPO
-```
-
 #### With LoRA
 
 ```bash
@@ -174,6 +262,12 @@ peft_config = LoraConfig(
     task_type="CAUSAL_LM",
 )
 
+# Configure training with higher learning rate for LoRA
+training_args = DPOConfig(
+    learning_rate=5.0e-6,  # 10x the base rate for DPO with LoRA
+    ...
+)
+
 # Create trainer with PEFT config
 trainer = DPOTrainer(
     model=model,
@@ -195,17 +289,6 @@ trainer.train()
 
 The `GRPOTrainer` optimizes policies using group-based rewards.
 
-#### Full Training (No PEFT)
-
-```bash
-python trl/scripts/grpo.py \
-    --model_name_or_path Qwen/Qwen2-0.5B \
-    --dataset_name trl-lib/math-reasoning \
-    --learning_rate 1.0e-6 \
-    --per_device_train_batch_size 2 \
-    --output_dir Qwen2-0.5B-GRPO
-```
-
 #### With LoRA
 
 ```bash
@@ -235,6 +318,12 @@ peft_config = LoraConfig(
     task_type="CAUSAL_LM",
 )
 
+# Configure training with higher learning rate for LoRA
+training_args = GRPOConfig(
+    learning_rate=1.0e-5,  # 10x the base rate for GRPO with LoRA
+    ...
+)
+
 # Create trainer with PEFT config
 trainer = GRPOTrainer(
     model="Qwen/Qwen2-0.5B",  # Can pass model name or loaded model
@@ -282,10 +371,11 @@ python trl/scripts/sft.py \
 #### Python Example
 
 ```python
-from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+import torch
+
 from peft import LoraConfig
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
 from trl import SFTConfig, SFTTrainer
-import torch
 
 # Configure 4-bit quantization
 bnb_config = BitsAndBytesConfig(
@@ -311,6 +401,12 @@ peft_config = LoraConfig(
     task_type="CAUSAL_LM",
 )
 
+# Configure training with higher learning rate for LoRA
+training_args = SFTConfig(
+    learning_rate=2.0e-4,  # 10x the base rate for QLoRA
+    ...
+)
+
 # Create trainer with PEFT config
 trainer = SFTTrainer(
     model=model,
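
To see the effect of the 4-bit loading used in the QLoRA example above, a minimal sketch; it assumes `bitsandbytes` is installed, reuses the same checkpoint, and relies on `get_memory_footprint()`, a standard `transformers` model method:

```python
import torch

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization, as configured in the QLoRA example above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B",
    quantization_config=bnb_config,
)

# Weight memory of the quantized model, in GB; compare with a plain bf16 load
print(f"4-bit footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
```
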
@@ -327,9 +423,10 @@ trainer.train()
 The `BitsAndBytesConfig` provides several options to optimize memory and performance:
 
 ```python
-from transformers import BitsAndBytesConfig
 import torch
 
+from transformers import BitsAndBytesConfig
+
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",  # or "fp4"
@@ -396,6 +493,12 @@ peft_config = PromptTuningConfig(
     tokenizer_name_or_path="Qwen/Qwen2-0.5B",
 )
 
+# Configure training with higher learning rate for Prompt Tuning
+training_args = SFTConfig(
+    learning_rate=2.0e-2,  # Prompt Tuning typically uses 1e-2 to 3e-2
+    ...
+)
+
 # Create trainer with PEFT config
 trainer = SFTTrainer(
     model=model,
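
For context, a minimal Prompt Tuning setup consistent with the snippet above; the `num_virtual_tokens` value and the init text are illustrative choices, not values prescribed by this guide:

```python
from peft import PromptTuningConfig, PromptTuningInit, TaskType
from trl import SFTConfig

peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Answer the following question concisely:",  # illustrative
    num_virtual_tokens=8,  # illustrative; tune for your task
    tokenizer_name_or_path="Qwen/Qwen2-0.5B",
)

# Only the virtual prompt embeddings are trained, which is why the much
# higher learning rate (1e-2 to 3e-2) is appropriate here.
training_args = SFTConfig(
    output_dir="Qwen2-0.5B-SFT-PromptTuning",
    learning_rate=2.0e-2,
)
```
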
@@ -584,34 +687,22 @@ accelerate launch trl/scripts/sft.py \
     --lora_r 32
 ```
 
-## Troubleshooting
-
-### Out of Memory Errors
-
-If you encounter OOM errors:
-
-1. Enable QLoRA: `--load_in_4bit`
-2. Reduce batch size: `--per_device_train_batch_size 1`
-3. Increase gradient accumulation: `--gradient_accumulation_steps 16`
-4. Enable gradient checkpointing: `--gradient_checkpointing`
-5. Reduce LoRA rank: `--lora_r 8`
-6. Reduce target modules: `--lora_target_modules q_proj v_proj`
-
-### Slow Training
+## Resources
 
-If training is slow:
+### TRL Examples and Notebooks
 
-1. Increase batch size (if memory allows)
-2. Use Flash Attention 2: `--attn_implementation flash_attention_2`
-3. Use bf16: `--bf16`
-4. Reduce gradient checkpointing frequency
+- **[SFT with LoRA/QLoRA Notebook](https://github.com/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)** - Complete working example showing both LoRA and QLoRA implementations
+- **[TRL Examples Directory](https://github.com/huggingface/trl/tree/main/examples)** - Collection of training scripts demonstrating PEFT with different trainers
+- **[TRL Cookbook Recipes](https://github.com/huggingface/cookbook/tree/main/notebooks/transformers)** - Step-by-step guides for common PEFT training scenarios
 
+### Documentation
 
+- [PEFT Documentation](https://huggingface.co/docs/peft) - Official PEFT library documentation
+- [TRL Documentation](https://huggingface.co/docs/trl) - Complete TRL documentation with trainer guides
+- [LoRA Without Regret](lora_without_regret) - Best practices for using LoRA effectively
 
-## Resources
+### Research Papers
 
-- [PEFT Documentation](https://huggingface.co/docs/peft)
-- [LoRA Paper](https://arxiv.org/abs/2106.09685)
-- [QLoRA Paper](https://arxiv.org/abs/2305.14314)
-- [Prompt Tuning Paper](https://arxiv.org/abs/2104.08691)
-- [TRL Documentation](https://huggingface.co/docs/trl)
+- [LoRA Paper](https://arxiv.org/abs/2106.09685) - Original LoRA methodology and results
+- [QLoRA Paper](https://arxiv.org/abs/2305.14314) - Efficient finetuning with 4-bit quantization
+- [Prompt Tuning Paper](https://arxiv.org/abs/2104.08691) - The Power of Scale for Parameter-Efficient Prompt Tuning
