Load checkpoint without re-creating the model #11318

Merged · 2 commits · Apr 20, 2021
src/transformers/configuration_utils.py (2 changes: 1 addition & 1 deletion)

@@ -271,7 +271,7 @@ def __init__(self, **kwargs):
         self._name_or_path = str(kwargs.pop("name_or_path", ""))

         # Drop the transformers version info
-        kwargs.pop("transformers_version", None)
+        self.transformers_version = kwargs.pop("transformers_version", None)

         # Additional attributes without default values
         for key, value in kwargs.items():
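With this one-line change, the version stamp that `save_pretrained` writes into `config.json` now survives loading as a `transformers_version` attribute instead of being silently dropped. A minimal sketch of the round trip (the checkpoint path is illustrative):

```python
from transformers import PretrainedConfig

# Assumes "checkpoint-5/config.json" was produced by save_pretrained, which
# stamps the transformers version that was installed at save time.
config = PretrainedConfig.from_json_file("checkpoint-5/config.json")

# Before this change, __init__ popped and discarded the field; now callers
# (like the Trainer hunk below) can read it back.
print(config.transformers_version)  # e.g. "4.5.1"
```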
src/transformers/trainer.py (31 changes: 20 additions & 11 deletions)

@@ -53,9 +53,12 @@
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, SequentialSampler

+from . import __version__
+from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .dependency_versions_check import dep_version_check
 from .file_utils import (
+    CONFIG_NAME,
     WEIGHTS_NAME,
     is_apex_available,
     is_datasets_available,
@@ -997,14 +1000,23 @@ def train(

logger.info(f"Loading model from {resume_from_checkpoint}).")

if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
checkpoint_version = config.transformers_version
if checkpoint_version is not None and checkpoint_version != __version__:
logger.warn(
f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
f"Transformers but your current version is {__version__}. This is not recommended and could "
"yield to errors or unwanted behaviors."
)

if self.deepspeed:
# will be resumed in deepspeed_init
pass
elif isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(resume_from_checkpoint)
model_reloaded = True
else:
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
self.model.load_state_dict(state_dict)

# If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1281,13 +1293,10 @@ def train(
             logger.info(
                 f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
             )
-            if isinstance(self.model, PreTrainedModel):
-                self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
-                if self.place_model_on_device:
-                    self.model = self.model.to(args.device)
-            else:
-                state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
-                self.model.load_state_dict(state_dict)
+            # We load the model state dict on the CPU to avoid an OOM error.
+            state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
+            # If the model is on the GPU, it still works!
+            self.model.load_state_dict(state_dict)

             if self.deepspeed:
                 self.deepspeed.load_checkpoint(
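Both trainer hunks replace the `from_pretrained` re-creation with an in-place `load_state_dict`, and deserialize the weights onto the CPU first. A standalone sketch of that loading pattern, using a plain `nn.Linear` and a throwaway file name rather than the Trainer's actual setup:

```python
import torch
from torch import nn

model = nn.Linear(10, 10)
if torch.cuda.is_available():
    model = model.to("cuda")

torch.save(model.state_dict(), "weights.bin")

# Without map_location, torch.load restores each tensor to the device it was
# saved from, briefly holding a second full copy of the weights on the GPU.
state_dict = torch.load("weights.bin", map_location="cpu")

# load_state_dict copies values into the existing parameters in place, so the
# model stays on its current device even though the state dict lives on CPU.
model.load_state_dict(state_dict)
```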
tests/test_trainer.py (40 changes: 40 additions & 0 deletions)

@@ -725,6 +725,46 @@ def test_resume_training_with_gradient_accumulation(self):
         self.assertEqual(b, b1)
         self.check_trainer_state_are_the_same(state, state1)

+    def test_resume_training_with_frozen_params(self):
+        if torch.cuda.device_count() > 2:
+            # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
+            # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
+            # won't be the same since the training dataloader is shuffled).
+            return
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                train_len=128,
+                per_device_train_batch_size=4,
+                save_steps=5,
+                learning_rate=0.1,
+            )
+            trainer.model.a.requires_grad_(False)
+            trainer.train()
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            state = dataclasses.asdict(trainer.state)
+
+            checkpoint = os.path.join(tmpdir, "checkpoint-5")
+
+            # Reinitialize trainer
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                train_len=128,
+                per_device_train_batch_size=4,
+                save_steps=5,
+                learning_rate=0.1,
+            )
+            trainer.model.a.requires_grad_(False)
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+
+            self.assertFalse(trainer.model.a.requires_grad)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
     def test_load_best_model_at_end(self):
         total = int(self.n_epochs * 64 / self.batch_size)
         with tempfile.TemporaryDirectory() as tmpdir:
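The new test pins down what the trainer change buys: attributes a user sets on the live model, such as freezing a parameter, now survive a resume because the module is no longer re-created. The same effect in miniature, with a plain `nn.Linear` standing in for the regression model:

```python
from torch import nn

model = nn.Linear(4, 4)
model.weight.requires_grad_(False)  # the user freezes a parameter

# Old behavior, schematically: from_pretrained built a fresh model, so the
# freeze did not survive -- new modules default to requires_grad=True.
fresh = nn.Linear(4, 4)
fresh.load_state_dict(model.state_dict())
print(fresh.weight.requires_grad)  # True

# New behavior: load into the existing module; requires_grad is untouched.
model.load_state_dict(model.state_dict())
print(model.weight.requires_grad)  # False
```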