24 changes: 24 additions & 0 deletions docs/source/paper_index.md
@@ -697,3 +697,27 @@ trainer.train()
```

For more details, see the [MiniLLM Trainer documentation](minillm).

## Distributed Training

### ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

**📜 Paper**: https://huggingface.co/papers/1910.02054

ZeRO (Zero Redundancy Optimizer) eliminates memory redundancies in data- and model-parallel training by partitioning optimizer states, gradients, and parameters across devices while retaining low communication volume and high computational granularity. This allows for the efficient training of large models that would otherwise not fit in GPU memory.

TRL supports ZeRO through its [DeepSpeed integration](deepspeed_integration). To use it, provide an Accelerate configuration file that enables DeepSpeed with your desired ZeRO settings,

```yaml
# config.yaml
distributed_type: DEEPSPEED
num_processes: 2
deepspeed_config:
zero_stage: 3
```

and launch the training script with `accelerate launch`:

```sh
accelerate launch --config_file config.yaml train.py
```
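
For reference, here is a minimal sketch of what `train.py` might contain, using `SFTTrainer`; the model checkpoint and dataset below are illustrative assumptions, not part of this PR:

```python
# train.py -- minimal sketch; checkpoint and dataset are illustrative choices
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Any dataset in a format SFTTrainer accepts works here
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",  # passing a string lets TRL create the model internally
    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"),
    train_dataset=dataset,
)
trainer.train()
```

With the ZeRO stage 3 config above, Accelerate spawns two processes and DeepSpeed partitions the parameters, gradients, and optimizer states between them.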
12 changes: 10 additions & 2 deletions trl/trainer/dpo_trainer.py
@@ -296,7 +296,11 @@ def __init__(

         # Model and reference model
         if isinstance(model, str):
-            model = create_model_from_path(model, **args.model_init_kwargs or {})
+            model_init_kwargs = args.model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            model = create_model_from_path(model, **model_init_kwargs)
         else:
             if args.model_init_kwargs is not None:
                 logger.warning(
@@ -305,7 +309,11 @@ def __init__(
                 )
         model_id = get_config_model_id(model.config)
         if isinstance(ref_model, str):
-            ref_model = create_model_from_path(ref_model, **args.ref_model_init_kwargs or {})
+            model_init_kwargs = args.ref_model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            ref_model = create_model_from_path(ref_model, **model_init_kwargs)
         else:
             if args.ref_model_init_kwargs is not None:
                 logger.warning(
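With this change, the string code path should work out of the box under DeepSpeed, since `device_map` is forced to `None` before `create_model_from_path` is called. A sketch of the code path this diff touches (model and dataset names are illustrative):

```python
# dpo_train.py -- sketch of the code path this diff touches; names are illustrative
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = DPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",      # str: resolved via create_model_from_path
    ref_model="Qwen/Qwen2.5-0.5B-Instruct",  # under DeepSpeed, device_map becomes None
    args=DPOConfig(output_dir="Qwen2.5-0.5B-DPO"),
    train_dataset=dataset,
)
trainer.train()
```

Launched with `accelerate launch --config_file config.yaml dpo_train.py`, both the policy and the reference model go through the DeepSpeed special case.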
6 changes: 5 additions & 1 deletion trl/trainer/sft_trainer.py
@@ -603,7 +603,11 @@ def __init__(

         # Model
         if isinstance(model, str):
-            model = create_model_from_path(model, **args.model_init_kwargs or {})
+            model_init_kwargs = args.model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            model = create_model_from_path(model, **model_init_kwargs)
         else:
             if args.model_init_kwargs is not None:
                 logger.warning(