accelerate deepspeed and gradient accumulation integrate #23236

Merged 37 commits on May 31, 2023
Changes from 32 commits

Commits (37)
b3987a8
mixed precision support via accelerate
pacman100 May 4, 2023
862d04b
fix issues
pacman100 May 4, 2023
f2196be
fix for the sharded ddp case
pacman100 May 4, 2023
2339a48
fix flax and tf failing tests
pacman100 May 4, 2023
263b134
refactor the place to create `Accelerator` object
pacman100 May 4, 2023
a5bf517
move ddp prep to accelerate
pacman100 May 4, 2023
f00ce09
fix 😅
pacman100 May 4, 2023
254f9a4
resolving comments
pacman100 May 5, 2023
88e7350
move fsdp handling to accelerate
pacman100 May 5, 2023
b37ad2a
fixex
pacman100 May 5, 2023
ec73bf2
fix saving
pacman100 May 5, 2023
ed1a520
shift torch dynamo handling to accelerate
pacman100 May 5, 2023
f70ba13
shift deepspeed integration and save & load utils to accelerate
pacman100 May 9, 2023
d1cab6b
fix accelerate launcher support
pacman100 May 9, 2023
4cd9b70
oops
pacman100 May 9, 2023
0bee40f
fix 🐛
pacman100 May 9, 2023
63aa5ea
save ckpt fix
pacman100 May 9, 2023
b5e8129
Trigger CI
pacman100 May 9, 2023
d3a1e75
Merge branch 'main' into smangrul/accelerate-deepspeed-integrate
pacman100 May 10, 2023
b2d9946
Merge branch 'main' into smangrul/accelerate-dynamo-integrate
pacman100 May 10, 2023
34ed549
nasty 🐛 😅
pacman100 May 10, 2023
3412693
as deepspeed needs grad_acc fixes, transfer grad_acc to accelerate
pacman100 May 10, 2023
792d38b
make tests happy
pacman100 May 10, 2023
2853b6b
quality ✨
pacman100 May 10, 2023
bf3194c
loss tracked needs to account for grad_acc
pacman100 May 10, 2023
4dff061
fixing the deepspeed tests
pacman100 May 10, 2023
4d8ab41
quality ✨
pacman100 May 10, 2023
9b33a39
😅😅😅
pacman100 May 10, 2023
8891364
tests 😡
pacman100 May 10, 2023
5d8738b
Merge branch 'smangrul/accelerate-dynamo-integrate' into smangrul/acc…
pacman100 May 10, 2023
a1fbcc5
quality ✨
pacman100 May 10, 2023
fc81728
Trigger CI
pacman100 May 10, 2023
26051ed
resolve comments and fix the issue with the previous merge from branch
pacman100 May 13, 2023
9ee66e1
Trigger CI
pacman100 May 14, 2023
349fdd0
accelerate took over deepspeed integration
stas00 May 15, 2023
a0f3ac1
Merge branch 'main' into smangrul/accelerate-deepspeed-integrate
pacman100 May 31, 2023
dc689be
Merge branch 'main' into smangrul/accelerate-deepspeed-integrate
pacman100 May 31, 2023
2 changes: 1 addition & 1 deletion docs/source/en/main_classes/deepspeed.mdx
@@ -1700,7 +1700,7 @@ checkpoint), then you can finish the training by first saving the final model ex
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint

checkpoint_dir = os.path.join(trainer.args.output_dir, "checkpoint-final")
trainer.deepspeed.save_checkpoint(checkpoint_dir)
trainer.model_wrapped.save_checkpoint(checkpoint_dir)
fp32_model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
```

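Put together, the documented flow after this one-line change looks roughly like the following sketch, assuming a `trainer` that was launched with a DeepSpeed ZeRO config and has finished training:

```python
# Sketch of the documented save-then-convert flow after this change; assumes a
# `trainer` that was trained under DeepSpeed ZeRO and is still in scope.
import os

from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint

checkpoint_dir = os.path.join(trainer.args.output_dir, "checkpoint-final")

# The DeepSpeed engine is now reached through `trainer.model_wrapped` instead of
# the removed `trainer.deepspeed` attribute.
trainer.model_wrapped.save_checkpoint(checkpoint_dir)

# Consolidate the ZeRO-sharded checkpoint back into a single fp32 state dict.
fp32_model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
```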
95 changes: 42 additions & 53 deletions src/transformers/deepspeed.py
@@ -17,7 +17,6 @@

import importlib.util
import weakref
from copy import deepcopy
from functools import partialmethod

from .dependency_versions_check import dep_version_check
@@ -256,24 +255,26 @@ def deepspeed_config():
return None


def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps):
def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
"""
A convenience wrapper that deals with optimizer and lr scheduler configuration.
"""
from accelerate.utils import DummyOptim, DummyScheduler

config = hf_deepspeed_config.config

# Optimizer + Scheduler
# Currently supported combos:
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: Yes
# 3. DS scheduler + HF optimizer: Yes
# 4. HF scheduler + DS optimizer: Yes
# 4. HF scheduler + DS optimizer: No
#
# Unless Offload is enabled in which case it's:
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: Mostly*
# 3. DS scheduler + HF optimizer: Mostly*
# 4. HF scheduler + DS optimizer: Yes
# 4. HF scheduler + DS optimizer: No
#
# Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)

@@ -284,6 +285,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
"--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
"Only one optimizer can be configured."
)
optimizer = DummyOptim(params=model_parameters)
else:
if hf_deepspeed_config.is_offload():
logger.info(
@@ -297,21 +299,21 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
# To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
config["zero_allow_untested_optimizer"] = True

def _lr_scheduler_callable(optimizer):
return trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

lr_scheduler = None
if "scheduler" not in config:
if optimizer is None:
# Optimizer is not available, so use callable to defer lr_scheduler creation to DS init
lr_scheduler = _lr_scheduler_callable
else:
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
if "scheduler" in config:
lr_scheduler = DummyScheduler(optimizer)
else:
if isinstance(optimizer, DummyOptim):
raise ValueError(
"Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
"Please configure a scheduler in the DeepSpeed config."
)
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

return optimizer, lr_scheduler
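For reference, here is a minimal sketch (illustrative only, not part of this patch) of what the optimizer/scheduler combos listed above look like at the DeepSpeed-config level; the specific optimizer/scheduler types and `"auto"` placeholders are assumptions based on the usual HF Trainer + DeepSpeed conventions.

```python
# Illustrative DeepSpeed config fragments for the optimizer/scheduler combos above.
# Types and "auto" placeholders follow common HF Trainer + DeepSpeed conventions.

# Combo 1 - DS optimizer + DS scheduler: both keys present, so the Trainer hands back
# DummyOptim/DummyScheduler placeholders and DeepSpeed builds the real objects.
ds_optimizer_and_scheduler = {
    "optimizer": {"type": "AdamW", "params": {"lr": "auto"}},
    "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": "auto"}},
}

# Combo 2 - HF optimizer + HF scheduler: neither key present, the Trainer creates both.
hf_optimizer_and_scheduler = {}

# Combo 3 - HF optimizer + DS scheduler: only "scheduler" present; the Trainer builds
# the optimizer and DeepSpeed builds the scheduler (via the DummyScheduler placeholder).
hf_optimizer_ds_scheduler = {
    "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": "auto"}},
}

# Combo 4 - DS optimizer without a DeepSpeed scheduler: now rejected; with this patch
# deepspeed_optim_sched raises ValueError asking for a `scheduler` entry in the config.
ds_optimizer_without_scheduler = {
    "optimizer": {"type": "AdamW", "params": {"lr": "auto"}},
}
```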


def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
def deepspeed_init(trainer, num_training_steps, inference=False):
"""
Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.

@@ -323,28 +325,22 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
inference: launch in inference mode (no optimizer and no lr scheduler)

Returns: model, optimizer, lr_scheduler
Returns: optimizer, lr_scheduler

We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on:
https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it
can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612

"""
import deepspeed
from deepspeed.utils import logger as ds_logger

model = trainer.model
args = trainer.args

if hasattr(trainer, "hf_deepspeed_config_orig"):
hf_deepspeed_config = deepcopy(trainer.hf_deepspeed_config_orig)
else:
hf_deepspeed_config = args.hf_deepspeed_config
trainer.hf_deepspeed_config_orig = deepcopy(args.hf_deepspeed_config)
hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config

# resume config update - some bits like `model` and `num_training_steps` only become available during train
hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
config = hf_deepspeed_config.config

# set the Deepspeed log level consistent with the Trainer
ds_logger.setLevel(args.get_process_log_level())
@@ -361,40 +357,33 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
model_parameters = None
else:
trainer.optimizer = None # important for when deepspeed_init is used as re-init
optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer, lr_scheduler = deepspeed_optim_sched(
trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
)

# keep for quick debug:
# from pprint import pprint; pprint(config)

kwargs = {
"model": model,
"model_parameters": model_parameters,
"config_params": config,
"optimizer": optimizer,
"lr_scheduler": lr_scheduler,
}

deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)

if resume_from_checkpoint is not None:
# it's possible that the user is trying to resume from model_path, which doesn't necessarily
# contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
# a resume from a checkpoint and not just a local pretrained weight. So we check here if the
# path contains what looks like a deepspeed checkpoint
import glob

deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*"))

if len(deepspeed_checkpoint_dirs) > 0:
logger.info(f"Attempting to resume from {resume_from_checkpoint}")
# this magically updates self.optimizer and self.lr_scheduler
load_path, _ = deepspeed_engine.load_checkpoint(
resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
)
if load_path is None:
raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
else:
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
return optimizer, lr_scheduler


return deepspeed_engine, optimizer, lr_scheduler
def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
# it's possible that the user is trying to resume from model_path, which doesn't necessarily
# contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
# a resume from a checkpoint and not just a local pretrained weight. So we check here if the
# path contains what looks like a deepspeed checkpoint
import glob

deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*"))

if len(deepspeed_checkpoint_dirs) > 0:
logger.info(f"Attempting to resume from {checkpoint_path}")
# this magically updates self.optimizer and self.lr_scheduler
load_path, _ = deepspeed_engine.load_checkpoint(
checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True
)
if load_path is None:
raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
else:
raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
7 changes: 7 additions & 0 deletions src/transformers/testing_utils.py
@@ -99,6 +99,10 @@
)


if is_accelerate_available():
from accelerate.state import AcceleratorState, PartialState


SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
@@ -1318,6 +1322,9 @@ def tearDown(self):
for path in self.teardown_tmp_dirs:
shutil.rmtree(path, ignore_errors=True)
self.teardown_tmp_dirs = []
if is_accelerate_available():
AcceleratorState._reset_state()
PartialState._reset_state()


def mockenv(**kwargs):
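The reset added to `tearDown` matters because `AcceleratorState` and `PartialState` share process-wide state; below is a minimal sketch (assumed test code, not from this PR) of the pattern it protects against.

```python
# Minimal sketch: AcceleratorState/PartialState share process-wide state, so anything
# configured by one test (e.g. a DeepSpeed plugin) would otherwise leak into later
# tests. Assumed test code, not from this PR.
from accelerate import Accelerator
from accelerate.state import AcceleratorState, PartialState


def run_isolated_accelerator_check():
    accelerator = Accelerator()  # populates the shared accelerator state
    assert str(accelerator.device) == str(PartialState().device)  # state is shared

    # Mirror the new TestCasePlus.tearDown behaviour so the next test starts clean.
    AcceleratorState._reset_state()
    PartialState._reset_state()
```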