Reproducible checkpoint #11582

Merged · 18 commits · May 4, 2021
1 change: 0 additions & 1 deletion examples/pytorch/test_examples.py
@@ -204,7 +204,6 @@ def test_run_ner(self):
run_ner.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
self.assertGreaterEqual(result["eval_precision"], 0.75)
self.assertLess(result["eval_loss"], 0.5)

def test_run_squad(self):
68 changes: 65 additions & 3 deletions src/transformers/trainer.py
@@ -127,6 +127,7 @@
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES


_is_torch_generator_available = False
_is_native_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
@@ -141,6 +142,7 @@
from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
_is_torch_generator_available = True
_is_native_amp_available = True
from torch.cuda.amp import autocast

@@ -182,6 +184,16 @@
logger = logging.get_logger(__name__)


def recursive_print(state_dict, prefix=""):
for key, value in state_dict.items():
if isinstance(value, dict):
recursive_print(value, prefix=key)
elif isinstance(value, torch.Tensor):
print(f"{prefix}/{key}: {value.shape}, {value.view(-1,).tolist()[:10]}")
else:
print(f"{prefix}/{key}: {value}")


class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
@@ -525,6 +537,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if not isinstance(self.train_dataset, collections.abc.Sized):
return None

generator = None
if self.args.world_size <= 1 and _is_torch_generator_available:
generator = torch.Generator()
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
Contributor:
I'm very late here, but shouldn't this use `self.args.seed`?

Collaborator (author):
No, this is copied from PyTorch actually. Since torch has been seeded, this will be deterministic.
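
A minimal standalone sketch of that point (not code from the PR): since torch's global RNG has already been seeded from `self.args.seed`, the seed drawn for the sampler's `Generator` is itself deterministic, and so is the resulting shuffle order.

```python
import torch

torch.manual_seed(42)  # stands in for Trainer seeding torch from self.args.seed

def make_sampler_generator() -> torch.Generator:
    # Same trick as the PR: draw a fresh 64-bit seed from the already-seeded global RNG.
    generator = torch.Generator()
    generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
    return generator

perm1 = torch.randperm(8, generator=make_sampler_generator())
torch.manual_seed(42)
perm2 = torch.randperm(8, generator=make_sampler_generator())
assert torch.equal(perm1, perm2)  # same global seed -> same sampler shuffle order
```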


# Build the sampler.
if self.args.group_by_length:
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
@@ -538,7 +555,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
if self.args.world_size <= 1:
return LengthGroupedSampler(
- self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
+ self.train_dataset,
+ self.args.train_batch_size,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ generator=generator,
)
else:
return DistributedLengthGroupedSampler(
@@ -553,7 +574,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:

else:
if self.args.world_size <= 1:
- return RandomSampler(self.train_dataset)
+ return RandomSampler(self.train_dataset, generator=generator)
elif (
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
and not self.args.dataloader_drop_last
@@ -1224,6 +1245,8 @@ def train(
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
@@ -1381,6 +1404,32 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

def _load_rng_state(self, checkpoint):
Member:
Could this method also print a warning in case we're on TPU, as we don't expect reproducibility on TPUs?

Collaborator (author):
Mmm, maybe just a comment in the README where we document resuming from a checkpoint? I don't really want to issue a warning for each run on TPU using a checkpoint.

# Load RNG states from `checkpoint`
if checkpoint is None or not os.path.isfile(os.path.join(checkpoint, "rng_state.pth")):
return

checkpoint_rng_state = torch.load(os.path.join(checkpoint, "rng_state.pth"))
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
if torch.cuda.is_available():
if self.args.local_rank != -1:
if f"cuda_{self.args.local_rank}" not in checkpoint_rng_state:
logger.warn(
"You are resuming a training that was launched in a distributed fashion in a "
"non-distributed way. Reproducibility cannot be guaranteed."
)
else:
torch.cuda.random.set_rng_state(checkpoint_rng_state[f"cuda_{self.args.local_rank}"])
else:
if "cuda" not in checkpoint_rng_state:
logger.warn(
"You are resuming a training that was launched in a non-distributed fashion "
"with GPUs on either in a distributed fashion or not on GPUs. Reproducibility "
"cannot be guaranteed."
)
else:
torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])

def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
@@ -1459,6 +1508,19 @@ def _save_checkpoint(self, model, trial, metrics=None):
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

# Save RNG state in non-distributed training
if self.is_local_process_zero():
rng_states = {"cpu": torch.random.get_rng_state()}
stas00 (Contributor), May 4, 2021:
I'm just wondering whether you checked that different processes in a dist env will have the same cpu RNG state?

Contributor:
And another question: what about Python's main RNG?

Collaborator (author):
In distributed training, the torch generator should have the same state on all processes since it's set with the same seed initially (then all processes execute the same code), unless someone goes out of their way to only execute a random instruction on one process.

The code has been tested on one GPU, on 2 GPUs with DataParallel, and on 2 GPUs distributed, and the same results are obtained with a full training and with resuming from the last checkpoint.

As for Python's main RNG and NumPy's main RNG, I have no way of extracting the seed (I can set it but not easily get the current state), so any code that wants to be 100% reproducible when resuming from a checkpoint needs to use torch RNGs only.

Contributor:
> As for python main RNG and numpy main RNG I have no way of extracting the seed (I can set it but not get the current state easily)

Won't these work?

```python
py_rng_state = random.getstate()
random.setstate(py_rng_state)

np_rng_state = numpy.random.get_state()
numpy.random.set_state(np_rng_state)
```

> so any code that wants to be 100% reproducible when resuming from a checkpoint needs to use torch RNGs only.

Do you mean to say that they should stick to torch's random APIs and convert to other formats from there?

Collaborator (author):
Oh, TIL! I didn't think it was possible to extract and set those. Will add that to the PR.
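
A sketch of what capturing those additional states could look like (illustrative only; the key names are assumptions, not necessarily what the PR ended up adding):

```python
import random

import numpy as np
import torch

# Capture Python's and NumPy's global RNG states next to the torch one.
rng_states = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "cpu": torch.random.get_rng_state(),
}

# ...e.g. torch.save(rng_states, "rng_state.pth") at checkpoint time, then
# restore them symmetrically when resuming:
random.setstate(rng_states["python"])
np.random.set_state(rng_states["numpy"])
torch.random.set_rng_state(rng_states["cpu"])
```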

if torch.cuda.is_available():
if self.args.local_rank == -1:
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
else:
# In distributed, we save the CUDA RNG states individually.
for i in range(torch.cuda.device_count()):
rng_states[f"cuda_{i}"] = torch.cuda.random.get_rng_state(i)
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
if checkpoint is None:
@@ -2350,7 +2412,7 @@ def push_to_hub(
with tempfile.TemporaryDirectory() as tmp_dir:
for f in os.listdir(save_directory):
fname = os.path.join(save_directory, f)
- if os.path.isfile(fname):
+ if os.path.isfile(fname) and fname != "rng_state.pth":
shutil.copy(fname, os.path.join(tmp_dir, f))

return unwrap_model(self.model)._push_to_hub(
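
Taken together, the trainer.py changes implement a simple round trip: dump the RNG states into `rng_state.pth` when a checkpoint is saved, and restore them once the dataloader has been fast-forwarded back to the checkpointed step. A condensed, single-process sketch of that contract (the directory name is illustrative):

```python
import os

import torch

output_dir = "checkpoint-500"  # illustrative checkpoint directory
os.makedirs(output_dir, exist_ok=True)

# Save side (mirrors _save_checkpoint): capture the CPU and, if present, CUDA RNG states.
rng_states = {"cpu": torch.random.get_rng_state()}
if torch.cuda.is_available():
    rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

# Resume side (mirrors _load_rng_state): restore the states before training continues,
# so dropout masks and shuffling draw the same random numbers as the original run.
checkpoint_rng_state = torch.load(os.path.join(output_dir, "rng_state.pth"))
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
if torch.cuda.is_available() and "cuda" in checkpoint_rng_state:
    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
```
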
4 changes: 3 additions & 1 deletion src/transformers/trainer_pt_utils.py
@@ -510,6 +510,7 @@ def __init__(
batch_size: int,
lengths: Optional[List[int]] = None,
model_input_name: Optional[str] = None,
generator=None,
):
self.dataset = dataset
self.batch_size = batch_size
@@ -525,12 +526,13 @@ def __init__(
)
lengths = [len(feature[self.model_input_name]) for feature in dataset]
self.lengths = lengths
self.generator = generator

def __len__(self):
return len(self.lengths)

def __iter__(self):
- indices = get_length_grouped_indices(self.lengths, self.batch_size)
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
return iter(indices)
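
A usage sketch of the extended sampler (assumes a transformers version that includes this change; the toy lengths and placeholder dataset are made up, and when `lengths` is passed explicitly the dataset itself is only stored, not inspected):

```python
import torch

from transformers.trainer_pt_utils import LengthGroupedSampler

lengths = [5, 12, 7, 3, 9, 11, 4, 8]  # toy per-example lengths
sampler = LengthGroupedSampler(
    list(range(len(lengths))),  # placeholder dataset; only the lengths above are used
    batch_size=4,
    lengths=lengths,
    generator=torch.Generator().manual_seed(42),
)
print(list(sampler))  # same length-grouped ordering on every run for a fixed seed
```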

