Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reproducible checkpoint #11582

Merged
merged 18 commits into from
May 4, 2021
1 change: 0 additions & 1 deletion examples/pytorch/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def test_run_ner(self):
run_ner.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
self.assertGreaterEqual(result["eval_precision"], 0.75)
stas00 marked this conversation as resolved.
Show resolved Hide resolved
self.assertLess(result["eval_loss"], 0.5)

def test_run_squad(self):
Expand Down
75 changes: 73 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import inspect
import math
import os
import random
import re
import shutil
import sys
Expand Down Expand Up @@ -127,6 +128,7 @@
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES


_is_torch_generator_available = False
_is_native_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
Expand All @@ -141,6 +143,7 @@
from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
_is_torch_generator_available = True
_is_native_amp_available = True
from torch.cuda.amp import autocast

Expand Down Expand Up @@ -525,6 +528,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if not isinstance(self.train_dataset, collections.abc.Sized):
return None

generator = None
if self.args.world_size <= 1 and _is_torch_generator_available:
generator = torch.Generator()
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm very late here, but shouldn't this use self.args.seed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is copied from PyTorch actually. Since torch has been seeded, this will be deterministic.


# Build the sampler.
if self.args.group_by_length:
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
Expand All @@ -538,7 +546,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
if self.args.world_size <= 1:
return LengthGroupedSampler(
self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
self.train_dataset,
self.args.train_batch_size,
lengths=lengths,
model_input_name=model_input_name,
generator=generator,
)
else:
return DistributedLengthGroupedSampler(
Expand All @@ -553,7 +565,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:

else:
if self.args.world_size <= 1:
return RandomSampler(self.train_dataset)
return RandomSampler(self.train_dataset, generator=generator)
stas00 marked this conversation as resolved.
Show resolved Hide resolved
elif (
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
and not self.args.dataloader_drop_last
Expand Down Expand Up @@ -1224,6 +1236,8 @@ def train(
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
Expand Down Expand Up @@ -1381,6 +1395,41 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

def _load_rng_state(self, checkpoint):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this method also print a warning in case we're on TPU as we don't expect reproducibility when on TPUs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm, maybe just a comment in the README where we document resuming from checkpoint? I don't really want to issue a warning for each run on TPU using a checkpoint.

# Load RNG states from `checkpoint`
if checkpoint is None:
return

local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank != -1:
rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
if not os.path.isfile(os.path.join(checkpoint, rng_file)):
logger.info(
f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
else:
rng_file = os.path.join(checkpoint, "rng_state.pth")
if not os.path.isfile(os.path.join(checkpoint, rng_file)):
logger.info(
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
"fashion, reproducibility is not guaranteed."
)
return

checkpoint_rng_state = torch.load(rng_file)
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
if torch.cuda.is_available():
if self.args.local_rank != -1:
torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
else:
torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
if is_torch_tpu_available():
xm.set_rng_state(checkpoint_rng_state["xla"])

def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
Expand Down Expand Up @@ -1459,6 +1508,28 @@ def _save_checkpoint(self, model, trial, metrics=None):
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

# Save RNG state in non-distributed training
rng_states = {
"python": random.getstate(),
"numpy": np.random.get_state(),
"cpu": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
if self.args.local_rank == -1:
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
else:
rng_states["cuda"] = torch.cuda.random.get_rng_state()

if is_torch_tpu_available():
rng_states["xla"] = xm.get_rng_state()

local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank == -1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))

def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
if checkpoint is None:
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def __init__(
batch_size: int,
lengths: Optional[List[int]] = None,
model_input_name: Optional[str] = None,
generator=None,
):
self.dataset = dataset
self.batch_size = batch_size
Expand All @@ -525,12 +526,13 @@ def __init__(
)
lengths = [len(feature[self.model_input_name]) for feature in dataset]
self.lengths = lengths
self.generator = generator

def __len__(self):
return len(self.lengths)

def __iter__(self):
indices = get_length_grouped_indices(self.lengths, self.batch_size)
indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
return iter(indices)


Expand Down