Skip to content

Commit

Permalink
deepspeed resume from ckpt fixes and adding support for deepspeed opt…
Browse files Browse the repository at this point in the history
…imizer and HF scheduler (#25863)

* Add support for deepspeed optimizer and HF scheduler

* fix bug

* fix the import

* fix issue with deepspeed scheduler saving for hf optim + hf scheduler scenario

* fix loading of hf scheduler when loading deepspeed checkpoint

* fix import of `DeepSpeedSchedulerWrapper`

* add tests

* add the comment and skip the failing tests

* address comment
  • Loading branch information
pacman100 authored Sep 5, 2023
1 parent 1110b56 commit 6bc517c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 10 deletions.
21 changes: 15 additions & 6 deletions src/transformers/integrations/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
if is_torch_available():
import torch

from ..optimization import get_scheduler

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -274,7 +276,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: Mostly*
# 3. DS scheduler + HF optimizer: Mostly*
# 4. HF scheduler + DS optimizer: No
# 4. HF scheduler + DS optimizer: Yes
#
# Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)

Expand Down Expand Up @@ -304,11 +306,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
lr_scheduler = DummyScheduler(optimizer)
else:
if isinstance(optimizer, DummyOptim):
raise ValueError(
"Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
"Please configure a scheduler in the DeepSpeed config."
)
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

def _lr_scheduler_callable(optimizer):
return get_scheduler(
trainer.args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)

lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
else:
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

return optimizer, lr_scheduler

Expand Down
18 changes: 16 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .dependency_versions_check import dep_version_check
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
Expand Down Expand Up @@ -212,6 +212,9 @@
save_fsdp_optimizer,
)

if is_deepspeed_available():
from accelerate.utils import DeepSpeedSchedulerWrapper


if TYPE_CHECKING:
import optuna
Expand Down Expand Up @@ -2362,7 +2365,14 @@ def _save_checkpoint(self, model, trial, metrics=None):
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

# Save SCHEDULER & SCALER
if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
self.lr_scheduler, DeepSpeedSchedulerWrapper
)
if (
self.args.should_save
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
and not is_torch_tpu_available()
):
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
Expand Down Expand Up @@ -2428,6 +2438,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):

if self.is_deepspeed_enabled:
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
return

checkpoint_file_exists = (
Expand Down
26 changes: 24 additions & 2 deletions tests/deepspeed/test_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ def get_launcher(distributed=False):
FP16 = "fp16"
BF16 = "bf16"

HF_OPTIM = "hf_optim"
HF_SCHEDULER = "hf_scheduler"
DS_OPTIM = "ds_optim"
DS_SCHEDULER = "ds_scheduler"

optims = [HF_OPTIM, DS_OPTIM]
schedulers = [HF_SCHEDULER, DS_SCHEDULER]

stages = [ZERO2, ZERO3]
if is_torch_bf16_gpu_available():
dtypes = [FP16, BF16]
Expand All @@ -153,6 +161,8 @@ def parameterized_custom_name_func(func, param_num, param):
# Cartesian-product of zero stages with models to test
params = list(itertools.product(stages, dtypes))

params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optims, schedulers))


@require_deepspeed
@require_torch_gpu
Expand Down Expand Up @@ -640,10 +650,16 @@ def test_can_resume_training_errors(self, stage, dtype):
"Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
)

@parameterized.expand(params, name_func=parameterized_custom_name_func)
def test_can_resume_training_normal(self, stage, dtype):
@parameterized.expand(params_with_optims_and_schedulers, name_func=parameterized_custom_name_func)
def test_can_resume_training_normal(self, stage, dtype, optim, scheduler):
# adapted from TrainerIntegrationTest.test_can_resume_training
# test normal resume for each stage separately, error-handling is tested in a different test

# ToDo: Currently, hf_optim + hf_scheduler resumes with the correct states and
# also has same losses for few steps but then slowly diverges. Need to figure it out.
if optim == HF_OPTIM and scheduler == HF_SCHEDULER:
return

output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
ds_config_dict = self.get_config_dict(stage)
if dtype == FP16:
Expand All @@ -652,6 +668,12 @@ def test_can_resume_training_normal(self, stage, dtype):
if stage == ZERO3:
ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True

if optim == HF_OPTIM:
del ds_config_dict["optimizer"]

if scheduler == HF_SCHEDULER:
del ds_config_dict["scheduler"]

kwargs = {
"output_dir": output_dir,
"train_len": 128,
Expand Down

0 comments on commit 6bc517c

Please sign in to comment.