[deepspeed] fix --load_best_model_at_end #14652

Merged (8 commits, Dec 7, 2021)
src/transformers/deepspeed.py (21 additions, 4 deletions)
@@ -357,6 +357,18 @@ def _lr_scheduler_callable(optimizer):
     return optimizer, lr_scheduler
 
 
+def deepspeed_reinit(trainer):
+    """
+    this is a temp hack based on: https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until
+    Deepspeed fixes a bug where it can't resume from a checkpoint after it did some stepping
+    https://github.com/microsoft/DeepSpeed/issues/1612
+    """
+    import deepspeed
+
+    deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**trainer.deepspeed_initialize_kwargs)
+    return deepspeed_engine, optimizer, lr_scheduler
+
+
 def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
     """
     Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
@@ -398,19 +410,24 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
         model_parameters = None
     else:
         optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
-        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+        model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
 
     # keep for quick debug:
     # from pprint import pprint; pprint(config)
 
-    model, optimizer, _, lr_scheduler = deepspeed.initialize(
+    kwargs = dict(
         model=model,
         model_parameters=model_parameters,
         config_params=config,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
     )
 
+    deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
+
+    # stash kwargs to enable a later deepspeed_reinit
+    trainer.deepspeed_initialize_kwargs = kwargs
+
     if resume_from_checkpoint is not None:
 
         # it's possible that the user is trying to resume from model_path, which doesn't necessarily
@@ -424,12 +441,12 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
         if len(deepspeed_checkpoint_dirs) > 0:
             logger.info(f"Attempting to resume from {resume_from_checkpoint}")
             # this magically updates self.optimizer and self.lr_scheduler
-            load_path, _ = model.load_checkpoint(
+            load_path, _ = deepspeed_engine.load_checkpoint(
                 resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
             )
             if load_path is None:
                 raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
         else:
             logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing")
 
-    return model, optimizer, lr_scheduler
+    return deepspeed_engine, optimizer, lr_scheduler
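
Note: the net effect of the deepspeed.py changes is that the exact `deepspeed.initialize` kwargs are stashed on the trainer, so a fresh engine can be re-created later instead of reusing one that has already stepped. A minimal usage sketch, assuming `deepspeed_init` has already run and stashed `trainer.deepspeed_initialize_kwargs`, with an illustrative `best_ckpt_dir`:

    # rebuild a fresh engine (works around the DeepSpeed resume bug linked above)
    deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(trainer)
    # then restore model/optimizer/scheduler states from the chosen checkpoint dir
    load_path, _ = deepspeed_engine.load_checkpoint(
        best_ckpt_dir, load_optimizer_states=True, load_lr_scheduler_states=True
    )
    if load_path is None:
        raise ValueError(f"[deepspeed] failed to load checkpoint from {best_ckpt_dir}")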
src/transformers/trainer.py (20 additions, 10 deletions)
Expand Up @@ -59,7 +59,7 @@
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .file_utils import (
CONFIG_NAME,
@@ -1434,21 +1434,28 @@ def train(
 
             best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
             if os.path.exists(best_model_path):
-                # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(best_model_path, map_location="cpu")
-                # If the model is on the GPU, it still works!
-                self._load_state_dict_in_model(state_dict)
+                if self.deepspeed:
+                    # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
+                    deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
+                    self.model = deepspeed_engine.module
+                    self.model_wrapped = deepspeed_engine
+                    self.deepspeed = deepspeed_engine
+                    self.optimizer = optimizer
+                    self.lr_scheduler = lr_scheduler
+                    self.deepspeed.load_checkpoint(
+                        self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
+                    )
+                else:
+                    # We load the model state dict on the CPU to avoid an OOM error.
+                    state_dict = torch.load(best_model_path, map_location="cpu")
+                    # If the model is on the GPU, it still works!
+                    self._load_state_dict_in_model(state_dict)
             else:
                 logger.warn(
                     f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                     "on multiple nodes, you should activate `--save_on_each_node`."
                 )
 
-            if self.deepspeed:
-                self.deepspeed.load_checkpoint(
-                    self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
-                )
-
         # add remaining tr_loss
         self._total_loss_scalar += tr_loss.item()
         train_loss = self._total_loss_scalar / self.state.global_step
@@ -1975,6 +1982,9 @@ def save_model(self, output_dir: Optional[str] = None):
             # This must be called on all ranks
             self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
 
+            # save a deepspeed checkpoint as well (this is very fast)
+            self.deepspeed.save_checkpoint(output_dir)
+
         elif self.args.should_save:
             self._save(output_dir)
 
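Note: taken together with the deepspeed.py changes, every save step now writes a full DeepSpeed checkpoint next to the fp16 weights, and `--load_best_model_at_end` restores the best checkpoint through a freshly rebuilt engine instead of a plain `torch.load`. A minimal sketch of that pairing, assuming a `trainer` that has already been through `deepspeed_init` (variable names are illustrative):

    # at save time (save_model), alongside the fp16 weights:
    trainer.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
    trainer.deepspeed.save_checkpoint(output_dir)  # full DS checkpoint, cheap to write

    # at the end of training (train), to load the best model:
    deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(trainer)
    deepspeed_engine.load_checkpoint(
        trainer.state.best_model_checkpoint,
        load_optimizer_states=True,
        load_lr_scheduler_states=True,
    )
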
tests/deepspeed/test_deepspeed.py (51 additions, 0 deletions)
@@ -602,6 +602,11 @@ def test_can_resume_training_normal(self, stage):
             self.assertEqual(b, b1)
             self.check_trainer_state_are_the_same(state, state1)
 
+        # Finally, should be able to resume with the same trainer/same deepspeed engine instance
+        # XXX: but currently this is not possible due to a DS bug: https://github.com/microsoft/DeepSpeed/issues/1612
+        # trainer.train(resume_from_checkpoint=checkpoint)
+        # a workaround needs to be used that re-creates the deepspeed engine
+
     @parameterized.expand(stages)
     def test_load_state_dict_from_zero_checkpoint(self, stage):
         # test that we can load fp32 weights directly from the zero checkpoint into the current model
@@ -968,3 +973,49 @@ def test_clm_from_config_zero3(self):
         with CaptureStderr() as cs:
             execute_subprocess_async(cmd, env=self.get_env())
         assert "Detected DeepSpeed ZeRO-3" in cs.err
+
+    @parameterized.expand(stages)
+    def test_load_best_model(self, stage):
+        # this test exercises --load_best_model_at_end - the key is being able to resume after some training
+
+        data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = f"""
+            --model_name_or_path {T5_TINY}
+            --tokenizer_name {T5_TINY}
+            --train_file {data_dir}/train.json
+            --validation_file {data_dir}/val.json
+            --output_dir {output_dir}
+            --overwrite_output_dir
+            --source_lang en
+            --target_lang ro
+            --do_train
+            --max_train_samples 3
+            --do_eval
+            --max_eval_samples 1
+            --logging_strategy steps
+            --logging_steps 1
+            --evaluation_strategy steps
+            --eval_steps 1
+            --save_strategy steps
+            --save_steps 1
+            --load_best_model_at_end
+            --per_device_train_batch_size 1
+            --per_device_eval_batch_size 1
+            --num_train_epochs 1
+            --fp16
+            --report_to none
+        """.split()
+        args.extend(["--source_prefix", "translate English to Romanian: "])
+
+        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
+        script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
+        launcher = get_launcher(distributed=False)
+
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+        with CaptureStderr() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+        # enough to test it didn't fail
+        assert "Detected DeepSpeed ZeRO-3" in cs.err
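
For reference, the command this new test assembles expands to roughly the following (a sketch only: the launcher flags and concrete paths are approximations, and `<T5_TINY>` / `<output_dir>` stand in for values resolved at runtime):

    deepspeed --num_nodes 1 --num_gpus 1 examples/pytorch/translation/run_translation.py \
        --deepspeed tests/deepspeed/ds_config_zero3.json \
        --model_name_or_path <T5_TINY> --tokenizer_name <T5_TINY> \
        --train_file tests/fixtures/tests_samples/wmt_en_ro/train.json \
        --validation_file tests/fixtures/tests_samples/wmt_en_ro/val.json \
        --output_dir <output_dir> --overwrite_output_dir \
        --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " \
        --do_train --max_train_samples 3 --do_eval --max_eval_samples 1 \
        --logging_strategy steps --logging_steps 1 --evaluation_strategy steps --eval_steps 1 \
        --save_strategy steps --save_steps 1 --load_best_model_at_end \
        --per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
        --num_train_epochs 1 --fp16 --report_to none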