Commit 3bb5d76

fix+docs: device_map=None for DeepSpeed and add ZeRO paper (1910.02054) to Paper Index (#4551)

Parent: 375b3eb

3 files changed: +39 additions, −3 deletions


docs/source/paper_index.md (24 additions, 0 deletions)

````diff
@@ -697,3 +697,27 @@ trainer.train()
 ```

 For more details, see the [MiniLLM Trainer documentation](minillm) documentation.
+
+## Distributed Training
+
+### ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
+
+**📜 Paper**: https://huggingface.co/papers/1910.02054
+
+ZeRO (Zero Redundancy Optimizer) eliminates memory redundancies in data- and model-parallel training by partitioning optimizer states, gradients, and parameters across devices while retaining low communication volume and high computational granularity. This allows for the efficient training of large models that would otherwise not fit in GPU memory.
+
+TRL supports ZeRO via the [DeepSpeed integration](deepspeed_integration). To use it, provide a DeepSpeed configuration file with your desired settings:
+
+```yaml
+# config.yaml
+distributed_type: DEEPSPEED
+num_processes: 2
+deepspeed_config:
+  zero_stage: 3
+```
+
+and launch the training script with `accelerate launch --config_file config.yaml`:
+
+```sh
+accelerate launch --config_file config.yaml train.py
+```
````

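The memory savings the partitioning buys can be estimated with the paper's own accounting for mixed-precision Adam training: fp16 parameters and gradients cost 2 bytes per parameter each, and the optimizer states (fp32 parameter copy, momentum, and variance) cost 12 bytes per parameter. A back-of-envelope sketch, not a DeepSpeed API:

```python
# Approximate per-device memory for mixed-precision Adam under ZeRO
# stages 0-3, following the accounting in the ZeRO paper
# (https://huggingface.co/papers/1910.02054). Illustrative only.

def zero_memory_per_device(num_params: int, num_devices: int, stage: int) -> float:
    """Approximate bytes per device for ZeRO stage 0 (baseline) through 3."""
    params = 2 * num_params   # fp16 parameters
    grads = 2 * num_params    # fp16 gradients
    optim = 12 * num_params   # fp32 params + Adam momentum and variance
    if stage >= 1:            # ZeRO-1: partition optimizer states
        optim /= num_devices
    if stage >= 2:            # ZeRO-2: also partition gradients
        grads /= num_devices
    if stage >= 3:            # ZeRO-3: also partition parameters
        params /= num_devices
    return params + grads + optim

# The paper's running example: a 7.5B-parameter model on 64 GPUs.
n, gpus = 7_500_000_000, 64
baseline = zero_memory_per_device(n, gpus, 0)  # 120e9 bytes, ~120 GB per GPU
stage3 = zero_memory_per_device(n, gpus, 3)    # ~1.9e9 bytes, ~1.9 GB per GPU
```

With `zero_stage: 3` as in the config above, all three state categories are sharded, which is why the per-device footprint shrinks roughly linearly in the number of devices.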
trl/trainer/dpo_trainer.py (10 additions, 2 deletions)

```diff
@@ -296,7 +296,11 @@ def __init__(

         # Model and reference model
         if isinstance(model, str):
-            model = create_model_from_path(model, **args.model_init_kwargs or {})
+            model_init_kwargs = args.model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            model = create_model_from_path(model, **model_init_kwargs)
         else:
             if args.model_init_kwargs is not None:
                 logger.warning(
@@ -305,7 +309,11 @@ def __init__(
             )
         model_id = get_config_model_id(model.config)
         if isinstance(ref_model, str):
-            ref_model = create_model_from_path(ref_model, **args.ref_model_init_kwargs or {})
+            model_init_kwargs = args.ref_model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            ref_model = create_model_from_path(ref_model, **model_init_kwargs)
         else:
             if args.ref_model_init_kwargs is not None:
                 logger.warning(
```
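The same guard appears in both the model and reference-model branches; its logic can be shown in isolation. A minimal sketch with a hypothetical helper name (`resolve_device_map` is not a TRL function), assuming only what the diff states: DeepSpeed manages model placement itself, so an explicit `device_map` such as `"auto"` must be dropped before the model is instantiated.

```python
# Hypothetical helper mirroring the guard added in the diff above.

def resolve_device_map(distributed_type: str, requested):
    """Return the device_map to pass to a from_pretrained-style loader."""
    if distributed_type == "DEEPSPEED":
        return None  # DeepSpeed requires device_map=None ("auto" fails)
    return requested
```

Centralizing the check this way would avoid repeating the three-line guard in each trainer, but the commit keeps the guard inline at each call site.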

trl/trainer/sft_trainer.py (5 additions, 1 deletion)

```diff
@@ -603,7 +603,11 @@ def __init__(

         # Model
         if isinstance(model, str):
-            model = create_model_from_path(model, **args.model_init_kwargs or {})
+            model_init_kwargs = args.model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            model = create_model_from_path(model, **model_init_kwargs)
         else:
             if args.model_init_kwargs is not None:
                 logger.warning(
```
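One Python subtlety in the `args.model_init_kwargs or {}` pattern used above: when the user passed a non-empty dict, `or` returns that same dict object, so the subsequent `device_map` override also writes through into `args.model_init_kwargs`. A minimal sketch of the semantics (plain Python behavior, not TRL code; the dict contents are illustrative):

```python
# `d or {}` returns d itself when d is truthy, so mutating the result
# mutates the original dict; a shallow copy would break the alias.

user_kwargs = {"torch_dtype": "bfloat16"}  # illustrative caller-supplied kwargs

aliased = user_kwargs or {}
aliased["device_map"] = None   # writes through to user_kwargs as well

copied = dict(user_kwargs)     # shallow copy: no write-through
copied["device_map"] = "auto"  # user_kwargs["device_map"] stays None
```

Whether the write-through matters here depends on whether the same `args` object is reused after trainer construction; copying with `dict(...)` is the defensive variant.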
