Commit f1dfef0

docs: Expand training customization examples (#4427)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
1 parent eb76389 · commit f1dfef0

1 file changed: docs/source/customization.md (+75 additions, -52 deletions)

@@ -1,6 +1,9 @@
# Training customization

TRL is designed with modularity in mind so that users can efficiently customize the training loop for their needs. Below are examples of how you can apply and test different techniques.

> [!NOTE]
> Although these examples use the [`DPOTrainer`], these customization methods apply to most (if not all) trainers in TRL.

## Use different optimizers and schedulers

@@ -31,89 +34,109 @@ trainer.train()

### Add a learning rate scheduler

You can also add a learning rate scheduler by passing both the optimizer and the scheduler:

```python
from torch import optim

optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

trainer = DPOTrainer(..., optimizers=(optimizer, lr_scheduler))
```
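
For reference, a fully expanded version of this sketch might look as follows. It reuses the model and dataset that appear elsewhere in this guide and assumes a recent TRL release, where the tokenizer is passed as `processing_class` (older releases used `tokenizer`).

```python
from datasets import load_dataset
from torch import optim
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

# Custom optimizer and scheduler, passed to the trainer as a tuple
optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,  # `tokenizer=tokenizer` in older TRL releases
    optimizers=(optimizer, lr_scheduler),
)
trainer.train()
```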

## Memory efficient fine-tuning by sharing layers

Another way to make fine-tuning more memory efficient is to share layers between the reference model and the model you want to train.

```python
from trl import create_reference_model

ref_model = create_reference_model(model, num_shared_layers=6)

trainer = DPOTrainer(..., ref_model=ref_model)
```
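
Put together, a complete sketch (again assuming a recent TRL release and the same model and dataset as the other examples in this guide) might look like this:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer, create_reference_model

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# The first 6 layers are shared between the trained model and its reference copy
ref_model = create_reference_model(model, num_shared_layers=6)

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
```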

## Pass 8-bit reference models

Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory-efficient fine-tuning.

Read more about 8-bit model loading in the `transformers` guide [Load in 8bit or 4bit](https://huggingface.co/docs/transformers/en/peft).

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config=quantization_config)

trainer = DPOTrainer(..., ref_model=ref_model)
```
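
The same pattern extends to 4-bit loading, which the linked guide also covers. A minimal variant (assuming the `bitsandbytes` package and a CUDA GPU are available) could look like this, with the resulting model passed as `ref_model` exactly as above:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit loading reduces the reference model's memory footprint even further
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config=quantization_config)
```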

## Add custom callbacks

You can customize the training loop by adding callbacks for logging, monitoring, or early stopping. Callbacks allow you to execute custom code at specific points during training.

```python
from transformers import TrainerCallback


class CustomLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            print(f"Step {state.global_step}: {logs}")


trainer = DPOTrainer(..., callbacks=[CustomLoggingCallback()])
```
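
For instance, early stopping can be added with the built-in `EarlyStoppingCallback` from `transformers`. As a sketch, this assumes periodic evaluation is enabled, an `eval_loss` metric is available to track, and `load_best_model_at_end=True` is set:

```python
from transformers import EarlyStoppingCallback

training_args = DPOConfig(
    ...,
    eval_strategy="steps",
    eval_steps=100,
    load_best_model_at_end=True,        # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",  # metric used to decide when to stop
)

trainer = DPOTrainer(..., args=training_args, eval_dataset=eval_dataset, callbacks=[EarlyStoppingCallback(early_stopping_patience=3)])
```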

## Add custom evaluation metrics

You can define custom evaluation metrics to track during training. This is useful for monitoring model performance on specific tasks.

```python
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    # Add your metric computation here
    return {"custom_metric": 0.0}


training_args = DPOConfig(..., eval_strategy="steps", eval_steps=100)

trainer = DPOTrainer(..., eval_dataset=eval_dataset, compute_metrics=compute_metrics)
```
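
As a sketch of the full wiring, you can hold out an evaluation split and return whatever you compute as a dictionary of floats. What `eval_preds` contains depends on the trainer, so treat the metric body below as a placeholder, and note that the `"test"` split name is an assumption about the dataset:

```python
import numpy as np
from datasets import load_dataset

dataset = load_dataset("trl-lib/ultrafeedback_binarized")
eval_dataset = dataset["test"]  # assumes the dataset provides a "test" split


def compute_metrics(eval_preds):
    logits, labels = eval_preds
    # Placeholder metric: mean of the first returned array; replace with your own computation
    preds = logits[0] if isinstance(logits, (tuple, list)) else logits
    return {"mean_prediction": float(np.mean(preds))}


trainer = DPOTrainer(..., train_dataset=dataset["train"], eval_dataset=eval_dataset, compute_metrics=compute_metrics)
```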

## Use mixed precision training

Mixed precision training can significantly speed up training and reduce memory usage. You can enable it by setting `bf16=True` or `fp16=True` in the training config.

```python
# Use bfloat16 precision (recommended for modern GPUs)
training_args = DPOConfig(..., bf16=True)
```

> [!NOTE]
> Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs.
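
If you are unsure which precision your hardware supports, one option is to choose it at runtime. This sketch assumes a CUDA GPU is available:

```python
import torch
from trl import DPOConfig

# Prefer bf16 when the GPU supports it (Ampere or newer), otherwise fall back to fp16
use_bf16 = torch.cuda.is_bf16_supported()
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", bf16=use_bf16, fp16=not use_bf16)
```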

## Use gradient accumulation

When training with limited GPU memory, gradient accumulation allows you to simulate larger batch sizes by accumulating gradients over multiple steps before updating the weights.

```python
# Simulate a batch size of 32 with per_device_train_batch_size=4 and gradient_accumulation_steps=8
training_args = DPOConfig(
    ...,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
)
```
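
The effective batch size is `per_device_train_batch_size * gradient_accumulation_steps * number of processes`, so the configuration above behaves like a batch size of 32 on a single GPU:

```python
# 4 samples per device * 8 accumulation steps * 1 process = effective batch size of 32
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
num_processes = 1  # number of GPUs/processes used for training
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch_size)  # 32
```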

## Use a custom data collator

You can provide a custom data collator to handle special data preprocessing or padding strategies.

```python
from trl.trainer.dpo_trainer import DataCollatorForPreference

data_collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id)

trainer = DPOTrainer(..., data_collator=data_collator)
```
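
If you need behavior beyond padding, one option is to wrap an existing collator rather than reimplementing it. This sketch only assumes that a collator is called with a list of examples and returns a dict of tensors, as `transformers`-style collators do; `ShapeLoggingCollator` is a hypothetical name used for illustration:

```python
from trl.trainer.dpo_trainer import DataCollatorForPreference


class ShapeLoggingCollator:
    """Wraps a base collator and prints the padded batch shapes (illustrative only)."""

    def __init__(self, base_collator):
        self.base_collator = base_collator

    def __call__(self, examples):
        batch = self.base_collator(examples)
        shapes = {key: tuple(value.shape) for key, value in batch.items() if hasattr(value, "shape")}
        print(f"Collated batch shapes: {shapes}")
        return batch


data_collator = ShapeLoggingCollator(DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id))
trainer = DPOTrainer(..., data_collator=data_collator)
```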
