Skip to content

Commit

Permalink
Reproducible checkpoint (#11582)
Browse files Browse the repository at this point in the history
* Set generator in dataloader

* Use generator in all random samplers

* Checkpoint all RNG states

* Final version

* Quality

* Test

* Address review comments

* Quality

* Remove debug util

* Add python and numpy RNGs

* Split states in different files in distributed

* Quality

* local_rank for TPUs

* Only use generator when accepted

* Add test

* Set seed to avoid flakiness

* Make test less flaky

* Quality
  • Loading branch information
sgugger authored May 4, 2021
1 parent 0afe4a9 commit 6b241e0
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 3 deletions.
1 change: 0 additions & 1 deletion examples/pytorch/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def test_run_ner(self):
run_ner.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
self.assertGreaterEqual(result["eval_precision"], 0.75)
self.assertLess(result["eval_loss"], 0.5)

def test_run_squad(self):
Expand Down
75 changes: 74 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import inspect
import math
import os
import random
import re
import shutil
import sys
Expand Down Expand Up @@ -127,6 +128,7 @@
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES


_is_torch_generator_available = False
_is_native_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
Expand All @@ -141,6 +143,7 @@
from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
_is_torch_generator_available = True
_is_native_amp_available = True
from torch.cuda.amp import autocast

Expand Down Expand Up @@ -525,6 +528,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if not isinstance(self.train_dataset, collections.abc.Sized):
return None

generator = None
if self.args.world_size <= 1 and _is_torch_generator_available:
generator = torch.Generator()
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))

# Build the sampler.
if self.args.group_by_length:
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
Expand All @@ -538,7 +546,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
if self.args.world_size <= 1:
return LengthGroupedSampler(
self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
self.train_dataset,
self.args.train_batch_size,
lengths=lengths,
model_input_name=model_input_name,
generator=generator,
)
else:
return DistributedLengthGroupedSampler(
Expand All @@ -553,6 +565,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:

else:
if self.args.world_size <= 1:
if _is_torch_generator_available:
return RandomSampler(self.train_dataset, generator=generator)
return RandomSampler(self.train_dataset)
elif (
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
Expand Down Expand Up @@ -1224,6 +1238,8 @@ def train(
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
Expand Down Expand Up @@ -1381,6 +1397,41 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`
if checkpoint is None:
return

local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank != -1:
rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
if not os.path.isfile(os.path.join(checkpoint, rng_file)):
logger.info(
f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
else:
rng_file = os.path.join(checkpoint, "rng_state.pth")
if not os.path.isfile(os.path.join(checkpoint, rng_file)):
logger.info(
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
"fashion, reproducibility is not guaranteed."
)
return

checkpoint_rng_state = torch.load(rng_file)
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
if torch.cuda.is_available():
if self.args.local_rank != -1:
torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
else:
torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
if is_torch_tpu_available():
xm.set_rng_state(checkpoint_rng_state["xla"])

def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
Expand Down Expand Up @@ -1460,6 +1511,28 @@ def _save_checkpoint(self, model, trial, metrics=None):
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

# Save RNG state in non-distributed training
rng_states = {
"python": random.getstate(),
"numpy": np.random.get_state(),
"cpu": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
if self.args.local_rank == -1:
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
else:
rng_states["cuda"] = torch.cuda.random.get_rng_state()

if is_torch_tpu_available():
rng_states["xla"] = xm.get_rng_state()

local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank == -1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))

def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
if checkpoint is None:
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def __init__(
batch_size: int,
lengths: Optional[List[int]] = None,
model_input_name: Optional[str] = None,
generator=None,
):
self.dataset = dataset
self.batch_size = batch_size
Expand All @@ -525,12 +526,13 @@ def __init__(
)
lengths = [len(feature[self.model_input_name]) for feature in dataset]
self.lengths = lengths
self.generator = generator

def __len__(self):
return len(self.lengths)

def __iter__(self):
indices = get_length_grouped_indices(self.lengths, self.batch_size)
indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
return iter(indices)


Expand Down
52 changes: 52 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import dataclasses
import gc
import math
import os
import random
import re
import tempfile
import unittest
Expand Down Expand Up @@ -195,6 +197,28 @@ def forward(self, input_x, labels=None, **kwargs):
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)

class RegressionRandomPreTrainedModel(PreTrainedModel):
config_class = RegressionModelConfig
base_model_prefix = "regression"

def __init__(self, config):
super().__init__(config)
self.a = torch.nn.Parameter(torch.tensor(config.a).float())
self.b = torch.nn.Parameter(torch.tensor(config.b).float())

def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
torch_rand = torch.randn(1).squeeze()
np_rand = np.random.rand()
rand_rand = random.random()

y += 0.05 * torch_rand + 0.05 * torch.tensor(np_rand + rand_rand)

if labels is None:
return (y,)
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y)

class TstLayer(torch.nn.Module):
def __init__(self, hidden_size):
super().__init__()
Expand Down Expand Up @@ -699,6 +723,34 @@ def test_can_resume_training(self):
trainer.train(resume_from_checkpoint=True)
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))

def test_resume_training_with_randomness(self):
if torch.cuda.device_count() >= 2:
# This test will fail flakily for more than 2 GPUs since the result will be slightly more different.
return

if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
train_dataset = RegressionDataset(length=128)
eval_dataset = RegressionDataset()

config = RegressionModelConfig(a=0, b=2)
model = RegressionRandomPreTrainedModel(config)

tmp_dir = self.get_auto_remove_tmp_dir()
args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)

trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()

model = RegressionRandomPreTrainedModel(config)
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()

self.assertTrue(math.isclose(a, a1, rel_tol=1e-8))
self.assertTrue(math.isclose(b, b1, rel_tol=1e-8))

def test_resume_training_with_gradient_accumulation(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
Expand Down

0 comments on commit 6b241e0

Please sign in to comment.