
[Deepspeed] ZeRO-Infinity integration plus config revamp #11418

Merged: 14 commits, Apr 26, 2021
704 changes: 496 additions & 208 deletions docs/source/main_classes/trainer.rst

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
@@ -90,7 +90,7 @@
"cookiecutter==1.7.2",
"dataclasses",
"datasets",
"deepspeed>=0.3.14",
"deepspeed>=0.3.15",
"docutils==0.16.0",
"fairscale>0.3",
"faiss-cpu",
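
For reference, the bumped minimum can also be verified at runtime; a small sketch, not part of this PR:

# Optional sanity check mirroring the new pin in setup.py.
from packaging import version
import deepspeed

assert version.parse(deepspeed.__version__) >= version.parse("0.3.15"), (
    "the revamped DeepSpeed config handling needs deepspeed>=0.3.15"
)
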
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -7,7 +7,7 @@
"cookiecutter": "cookiecutter==1.7.2",
"dataclasses": "dataclasses",
"datasets": "datasets",
"deepspeed": "deepspeed>=0.3.14",
"deepspeed": "deepspeed>=0.3.15",
"docutils": "docutils==0.16.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
354 changes: 183 additions & 171 deletions src/transformers/integrations.py

Large diffs are not rendered by default.
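
The revamped integrations.py itself is not rendered above. A rough, hypothetical sketch of its overall shape, inferred only from the names this PR imports elsewhere (DeepSpeedConfigHF, deepspeed_config, is_deepspeed_zero3_enabled); the field names and the "auto" resolution shown are assumptions, not the actual implementation:

# Hypothetical sketch only -- the real code lives in the unrendered
# src/transformers/integrations.py diff. Attribute and key names are assumptions.
import json

_deepspeed_config_hf = None  # set when TrainingArguments parses a --deepspeed argument


class DeepSpeedConfigHF:
    def __init__(self, args):
        # args is a TrainingArguments instance; args.deepspeed is assumed here to be a
        # path to a json config (it may also accept a ready dict, omitted for brevity)
        with open(args.deepspeed) as f:
            self.config = json.load(f)
        # resolve "auto" placeholders from the corresponding TrainingArguments values
        optim_params = self.config.get("optimizer", {}).get("params", {})
        if optim_params.get("lr") == "auto":
            optim_params["lr"] = args.learning_rate
        global _deepspeed_config_hf
        _deepspeed_config_hf = self


def deepspeed_config():
    # returns the DeepSpeed config dict if one was configured, else None
    return _deepspeed_config_hf.config if _deepspeed_config_hf is not None else None


def is_deepspeed_zero3_enabled():
    cfg = deepspeed_config()
    return bool(cfg) and cfg.get("zero_optimization", {}).get("stage") == 3
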

9 changes: 6 additions & 3 deletions src/transformers/modeling_utils.py
@@ -41,7 +41,7 @@
replace_return_docstrings,
)
from .generation_utils import GenerationMixin
from .integrations import is_deepspeed_zero3_enabled
from .integrations import deepspeed_config, is_deepspeed_zero3_enabled
from .utils import logging


@@ -1084,8 +1084,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model to avoid the overhead in time and memory copying it on CPU or each GPU first
with deepspeed.zero.Init():
# this immediately partitions the model across all gpus, to avoid the time and
# memory overhead of first copying it onto CPU or onto each GPU

# XXX: param_dict will shortly be replaced by deepspeed_config
with deepspeed.zero.Init(param_dict=deepspeed_config()):
model = cls(config, *model_args, **model_kwargs)
else:
model = cls(config, *model_args, **model_kwargs)
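
Taken together, the branch above follows the pattern below; a minimal, self-contained sketch in which MyModel is a stand-in for a transformers model class, and the param_dict keyword simply mirrors the diff above:

import deepspeed
import torch
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled


class MyModel(torch.nn.Module):  # stand-in for a real transformers model class
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)


if is_deepspeed_zero3_enabled():
    # zero.Init partitions parameters across all gpus as they are created, so the
    # full model never has to be materialized on CPU or on a single GPU first
    with deepspeed.zero.Init(param_dict=deepspeed_config()):
        model = MyModel()
else:
    model = MyModel()
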
1 change: 1 addition & 0 deletions src/transformers/trainer.py
@@ -395,6 +395,7 @@ def __init__(

self._signature_columns = None

# XXX: can move this back to where it was
# Mixed precision setup
self.use_apex = False
self.use_amp = False
11 changes: 8 additions & 3 deletions src/transformers/training_args.py
@@ -70,9 +70,6 @@ class TrainingArguments:
<https://docs.python.org/3/library/argparse.html#module-argparse>`__ arguments that can be specified on the command
line.




Parameters:
output_dir (:obj:`str`):
The output directory where the model predictions and checkpoints will be written.
@@ -618,6 +615,14 @@ def __post_init__(self):
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")

if self.deepspeed:
# - must be run very last in arg parsing, since it will use a lot of these settings.
# - must be run before the model is created.
from transformers.integrations import DeepSpeedConfigHF

# will be used later by the Trainer (self.deepspeed itself is left unmodified in case a user relies on it not being changed)
self.deepspeed_config_hf = DeepSpeedConfigHF(self)

def __repr__(self):
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
# those deprecated arguments are removed from TrainingArguments. (TODO: v5)
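
A small usage sketch of the hook added above; the config path is a placeholder and deepspeed_config_hf is the attribute set in __post_init__:

# Sketch: passing deepspeed= (here, a placeholder path) makes __post_init__ build the
# HF-side config wrapper right away, before any Trainer or model exists.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",
    fp16=True,
    deepspeed="ds_config_zero2.json",  # placeholder path to a DeepSpeed config file
)

print(type(args.deepspeed_config_hf).__name__)  # DeepSpeedConfigHF
print(args.deepspeed)  # left unmodified: still the original path string
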
42 changes: 23 additions & 19 deletions tests/deepspeed/ds_config_zero2.json
@@ -1,43 +1,47 @@
{
"fp16": {
"enabled": true,
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},

"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"cpu_offload": true
},

"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [0.8, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},

"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},

"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"cpu_offload": true
},

"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
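
The "auto" values above are placeholders that the integration fills in from the matching Trainer settings, so the command line remains the single source of truth. An illustrative summary of the expected pairings follows; it is a documentation aid, not the implementation, and the authoritative mapping lives in the updated trainer.rst:

# Which TrainingArguments value each "auto" key is expected to inherit (illustrative).
auto_value_sources = {
    "fp16.enabled": "args.fp16",
    "optimizer.params.lr": "args.learning_rate",
    "optimizer.params.betas": "(args.adam_beta1, args.adam_beta2)",
    "optimizer.params.eps": "args.adam_epsilon",
    "optimizer.params.weight_decay": "args.weight_decay",
    "scheduler.params.warmup_max_lr": "args.learning_rate",
    "scheduler.params.warmup_num_steps": "args.warmup_steps",
    "gradient_accumulation_steps": "args.gradient_accumulation_steps",
    "gradient_clipping": "args.max_grad_norm",
    "train_micro_batch_size_per_gpu": "args.per_device_train_batch_size",
    "train_batch_size": "args.world_size * per_device_train_batch_size * gradient_accumulation_steps",
}
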
57 changes: 33 additions & 24 deletions tests/deepspeed/ds_config_zero3.json
@@ -1,48 +1,57 @@
{
"fp16": {
"enabled": true,
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},

"zero_optimization": {
"stage": 3,
"cpu_offload": true,
"cpu_offload_params": true,
"cpu_offload_use_pin_memory" : true,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e14,
"reduce_bucket_size": 0,
"stage3_prefetch_bucket_size": 0,
"stage3_param_persistence_threshold": 0,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
},

"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [0.8, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},

"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},

"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e14,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
},

"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
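
ZeRO-Infinity, the headline feature of this PR, lets the cpu offload targets above point at NVMe storage instead. A hedged sketch of adapting this config in code; the nvme_path is a placeholder mount point and the aio tuning section is omitted:

# Sketch: switch ZeRO-3 offload from CPU to NVMe (ZeRO-Infinity). "nvme" and
# "nvme_path" are standard DeepSpeed offload options; the path is a placeholder.
import json

with open("tests/deepspeed/ds_config_zero3.json") as f:
    ds_config = json.load(f)

for section in ("offload_param", "offload_optimizer"):
    ds_config["zero_optimization"][section] = {
        "device": "nvme",
        "nvme_path": "/local_nvme",  # placeholder: a fast local NVMe mount
        "pin_memory": True,
    }
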