
Commit 7ff168e

docs: Address review feedback on PEFT integration guide
Applied all requested changes from the PR review:

1. Added a reference link to the example SFT LoRA/QLoRA notebook.
2. Implemented `<hfoptions>` tabs to organize the SFT/DPO/GRPO examples.
3. Simplified the Python code examples by removing non-PEFT boilerplate.

The documentation now focuses more clearly on PEFT-specific configuration while maintaining all essential information.
1 parent a0a1c96 commit 7ff168e

File tree: 1 file changed (+25, -80 lines)

docs/source/peft_integration.md

Lines changed: 25 additions & 80 deletions
@@ -4,6 +4,8 @@ TRL supports [PEFT](https://github.com/huggingface/peft) (Parameter-Efficient Fi

This guide covers how to use PEFT with different TRL trainers, including LoRA, QLoRA, and prompt tuning techniques.

+For a complete working example, see the [SFT with LoRA/QLoRA notebook](https://github.com/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb).
+
## Installation

To use PEFT with TRL, install the required dependencies:
@@ -60,6 +62,9 @@ trainer = SFTTrainer(

TRL's trainers support PEFT configurations for various training paradigms. Below are detailed examples for each major trainer.

+<hfoptions id="trainer-type">
+<hfoption id="sft">
+
### Supervised Fine-Tuning (SFT)

The `SFTTrainer` is used for supervised fine-tuning on instruction datasets.
@@ -96,18 +101,9 @@ python trl/scripts/sft.py \
#### Python Example

```python
-from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

-# Load model and tokenizer
-model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
-
-# Load dataset
-dataset = load_dataset("trl-lib/Capybara", split="train")
-
# Configure LoRA
peft_config = LoraConfig(
    r=32,
@@ -118,26 +114,20 @@ peft_config = LoraConfig(
    target_modules=["q_proj", "v_proj"],  # Optional: specify target modules
)

-# Training arguments
-training_args = SFTConfig(
-    output_dir="./Qwen2-0.5B-SFT-LoRA",
-    learning_rate=2.0e-4,
-    per_device_train_batch_size=2,
-    num_train_epochs=1,
-)
-
-# Create trainer
+# Create trainer with PEFT config
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
-    peft_config=peft_config,
+    peft_config=peft_config,  # Pass PEFT config here
)

-# Train
trainer.train()
```

+</hfoption>
+<hfoption id="dpo">
+
### Direct Preference Optimization (DPO)

The `DPOTrainer` implements preference learning from human feedback.
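The simplified SFT snippet above still refers to `model`, `dataset`, and `training_args`, which the guide now defines elsewhere. As a point of reference, a minimal self-contained sketch can be reassembled from the boilerplate this commit removes (model, dataset, and hyperparameter values are simply the ones deleted above, not tuned recommendations):

```python
# Reassembled SFT + LoRA sketch: the PEFT-specific lines kept in the docs plus
# the setup removed in this commit. Values are illustrative only.
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# Base model and tokenizer (from the removed lines); the tokenizer is loaded
# for completeness and is not passed to the trainer below.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

# Instruction dataset used by the original example
dataset = load_dataset("trl-lib/Capybara", split="train")

# LoRA adapter configuration as kept in the docs; task_type is an assumption
# carried over from the DPO/GRPO examples in the same diff.
peft_config = LoraConfig(
    r=32,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
)

# Training arguments removed from the docs but still needed to run the example
training_args = SFTConfig(
    output_dir="./Qwen2-0.5B-SFT-LoRA",
    learning_rate=2.0e-4,
    per_device_train_batch_size=2,
    num_train_epochs=1,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
)
trainer.train()
```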
@@ -172,18 +162,9 @@ python trl/scripts/dpo.py \
#### Python Example

```python
-from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import DPOConfig, DPOTrainer

-# Load model and tokenizer
-model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-
-# Load dataset
-dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
-
# Configure LoRA
peft_config = LoraConfig(
    r=32,
@@ -193,29 +174,23 @@ peft_config = LoraConfig(
    task_type="CAUSAL_LM",
)

-# Training arguments
-training_args = DPOConfig(
-    output_dir="./Qwen2-0.5B-DPO-LoRA",
-    learning_rate=5.0e-6,
-    per_device_train_batch_size=2,
-)
-
-# Create trainer
-# When using PEFT, ref_model is automatically handled and set to None
+# Create trainer with PEFT config
trainer = DPOTrainer(
    model=model,
    ref_model=None,  # Not needed when using PEFT
    args=training_args,
    train_dataset=dataset,
-    peft_config=peft_config,
+    peft_config=peft_config,  # Pass PEFT config here
)

-# Train
trainer.train()
```

**Note:** When using PEFT with DPO, you don't need to provide a separate reference model (`ref_model`). The trainer automatically uses the frozen base model as the reference.

+</hfoption>
+<hfoption id="grpo">
+
### Group Relative Policy Optimization (GRPO)

The `GRPOTrainer` optimizes policies using group-based rewards.
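Likewise, the trimmed DPO example assumes `model`, `dataset`, and `training_args` are already defined. A minimal sketch of that setup, using only names and values taken from the lines removed above:

```python
# Setup assumed by the simplified DPO example; values come from the removed
# boilerplate and are illustrative only.
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import DPOConfig

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = DPOConfig(
    output_dir="./Qwen2-0.5B-DPO-LoRA",
    learning_rate=5.0e-6,
    per_device_train_batch_size=2,
)
```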
@@ -248,14 +223,9 @@ python trl/scripts/grpo.py \
#### Python Example

```python
-from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

-# Load dataset
-dataset = load_dataset("trl-lib/math-reasoning", split="train")
-
# Configure LoRA
peft_config = LoraConfig(
    r=32,
@@ -265,25 +235,20 @@ peft_config = LoraConfig(
    task_type="CAUSAL_LM",
)

-# Training arguments
-training_args = GRPOConfig(
-    output_dir="./Qwen2-0.5B-GRPO-LoRA",
-    learning_rate=1.0e-5,
-    per_device_train_batch_size=2,
-)
-
-# Create trainer
+# Create trainer with PEFT config
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B",  # Can pass model name or loaded model
    args=training_args,
    train_dataset=dataset,
-    peft_config=peft_config,
+    peft_config=peft_config,  # Pass PEFT config here
)

-# Train
trainer.train()
```

+</hfoption>
+</hfoptions>
+
## QLoRA: Quantized Low-Rank Adaptation

QLoRA combines 4-bit quantization with LoRA to enable fine-tuning of very large models on consumer hardware. This technique can reduce memory requirements by up to 4x compared to standard LoRA.
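The GRPO example follows the same pattern: `dataset` and `training_args` come from the removed boilerplate. A sketch of that setup (note that `GRPOTrainer` also expects reward functions via its `reward_funcs` argument, which this diff does not touch):

```python
# Setup assumed by the simplified GRPO example; names and values are taken
# from the removed lines. Reward functions are not shown in this diff.
from datasets import load_dataset
from trl import GRPOConfig

dataset = load_dataset("trl-lib/math-reasoning", split="train")

training_args = GRPOConfig(
    output_dir="./Qwen2-0.5B-GRPO-LoRA",
    learning_rate=1.0e-5,
    per_device_train_batch_size=2,
)
```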
@@ -330,7 +295,7 @@ bnb_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=True,
)

-# Load model in 4-bit
+# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
@@ -346,15 +311,7 @@ peft_config = LoraConfig(
    task_type="CAUSAL_LM",
)

-# Training arguments
-training_args = SFTConfig(
-    output_dir="./Llama-2-7b-QLoRA",
-    per_device_train_batch_size=1,
-    gradient_accumulation_steps=16,
-    learning_rate=2.0e-4,
-)
-
-# Create trainer
+# Create trainer with PEFT config
trainer = SFTTrainer(
    model=model,
    args=training_args,
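The QLoRA example keeps the quantized model load and the LoRA config but drops its `SFTConfig`. For reference, a sketch of the removed training arguments (small per-device batch with gradient accumulation, as typically used for a 7B model in 4-bit):

```python
# SFTConfig removed from the QLoRA example; these are the previously
# documented values, not tuned recommendations.
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="./Llama-2-7b-QLoRA",
    per_device_train_batch_size=1,    # small batch to fit a 7B model in 4-bit
    gradient_accumulation_steps=16,   # recover a larger effective batch size
    learning_rate=2.0e-4,
)
```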
@@ -427,13 +384,9 @@ Prompt tuning is another PEFT technique that learns soft prompts (continuous emb
### Using Prompt Tuning with TRL

```python
-from transformers import AutoModelForCausalLM
-from peft import PromptTuningConfig, PromptTuningInit, get_peft_model, TaskType
+from peft import PromptTuningConfig, PromptTuningInit, TaskType
from trl import SFTConfig, SFTTrainer

-# Load base model
-model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
-
# Configure Prompt Tuning
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
@@ -443,20 +396,12 @@ peft_config = PromptTuningConfig(
    tokenizer_name_or_path="Qwen/Qwen2-0.5B",
)

-# Training arguments
-training_args = SFTConfig(
-    output_dir="./Qwen2-0.5B-PromptTuning",
-    per_device_train_batch_size=8,
-    learning_rate=3e-2,  # Prompt tuning typically uses higher learning rates
-    num_train_epochs=5,
-)
-
-# Create trainer
+# Create trainer with PEFT config
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
-    peft_config=peft_config,
+    peft_config=peft_config,  # Pass PEFT config here
)

trainer.train()
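Finally, the prompt-tuning snippet also assumes `model`, `dataset`, and `training_args`. A sketch reassembled from the removed lines; the dataset is an assumption, since none of the hunks show which dataset this section uses:

```python
# Setup assumed by the prompt-tuning example. Model and SFTConfig values come
# from the removed lines; the dataset choice is an assumption.
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTConfig

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
dataset = load_dataset("trl-lib/Capybara", split="train")  # assumed dataset

training_args = SFTConfig(
    output_dir="./Qwen2-0.5B-PromptTuning",
    per_device_train_batch_size=8,
    learning_rate=3e-2,   # prompt tuning typically uses higher learning rates
    num_train_epochs=5,
)
```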
