Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deepspeed resume from ckpt fixes and adding support for deepspeed optimizer and HF scheduler #25863

Merged
merged 13 commits into from
Sep 5, 2023
Merged
21 changes: 15 additions & 6 deletions src/transformers/integrations/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
if is_torch_available():
import torch

from ..optimization import get_scheduler

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -274,7 +276,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: Mostly*
# 3. DS scheduler + HF optimizer: Mostly*
# 4. HF scheduler + DS optimizer: No
# 4. HF scheduler + DS optimizer: Yes
#
# Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)

Expand Down Expand Up @@ -304,11 +306,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
lr_scheduler = DummyScheduler(optimizer)
else:
if isinstance(optimizer, DummyOptim):
raise ValueError(
"Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
"Please configure a scheduler in the DeepSpeed config."
)
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

def _lr_scheduler_callable(optimizer):
pacman100 marked this conversation as resolved.
Show resolved Hide resolved
return get_scheduler(
trainer.args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)

lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
else:
lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

return optimizer, lr_scheduler

Expand Down
18 changes: 16 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .dependency_versions_check import dep_version_check
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
Expand Down Expand Up @@ -212,6 +212,9 @@
save_fsdp_optimizer,
)

if is_deepspeed_available():
from accelerate.utils import DeepSpeedSchedulerWrapper


if TYPE_CHECKING:
import optuna
Expand Down Expand Up @@ -2362,7 +2365,14 @@ def _save_checkpoint(self, model, trial, metrics=None):
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

# Save SCHEDULER & SCALER
if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
self.lr_scheduler, DeepSpeedSchedulerWrapper
)
if (
self.args.should_save
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
and not is_torch_tpu_available()
):
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
Expand Down Expand Up @@ -2428,6 +2438,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):

if self.is_deepspeed_enabled:
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
return

checkpoint_file_exists = (
Expand Down
26 changes: 24 additions & 2 deletions tests/deepspeed/test_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ def get_launcher(distributed=False):
FP16 = "fp16"
BF16 = "bf16"

HF_OPTIM = "hf_optim"
HF_SCHEDULER = "hf_scheduler"
DS_OPTIM = "ds_optim"
DS_SCHEDULER = "ds_scheduler"

optims = [HF_OPTIM, DS_OPTIM]
schedulers = [HF_SCHEDULER, DS_SCHEDULER]

stages = [ZERO2, ZERO3]
if is_torch_bf16_gpu_available():
dtypes = [FP16, BF16]
Expand All @@ -153,6 +161,8 @@ def parameterized_custom_name_func(func, param_num, param):
# Cartesian-product of zero stages with models to test
params = list(itertools.product(stages, dtypes))

params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optims, schedulers))


@require_deepspeed
@require_torch_gpu
Expand Down Expand Up @@ -640,10 +650,16 @@ def test_can_resume_training_errors(self, stage, dtype):
"Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
)

@parameterized.expand(params, name_func=parameterized_custom_name_func)
def test_can_resume_training_normal(self, stage, dtype):
@parameterized.expand(params_with_optims_and_schedulers, name_func=parameterized_custom_name_func)
def test_can_resume_training_normal(self, stage, dtype, optim, scheduler):
# adapted from TrainerIntegrationTest.test_can_resume_training
# test normal resume for each stage separately, error-handling is tested in a different test

# ToDo: Currently, hf_optim + hf_scheduler resumes with the correct states and
# also has same losses for few steps but then slowly diverges. Need to figure it out.
if optim == HF_OPTIM and scheduler == HF_SCHEDULER:
return

output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
ds_config_dict = self.get_config_dict(stage)
if dtype == FP16:
Expand All @@ -652,6 +668,12 @@ def test_can_resume_training_normal(self, stage, dtype):
if stage == ZERO3:
ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True

if optim == HF_OPTIM:
del ds_config_dict["optimizer"]

if scheduler == HF_SCHEDULER:
del ds_config_dict["scheduler"]

kwargs = {
"output_dir": output_dir,
"train_len": 128,
Expand Down