Reproducible checkpoint #11582

Merged
18 commits merged on May 4, 2021
src/transformers/trainer.py: 75 changes (72 additions, 3 deletions)
@@ -182,6 +182,16 @@
logger = logging.get_logger(__name__)


def recursive_print(state_dict, prefix=""):
    for key, value in state_dict.items():
        if isinstance(value, dict):
            recursive_print(value, prefix=key)
        elif isinstance(value, torch.Tensor):
            print(f"{prefix}/{key}: {value.shape}, {value.view(-1,).tolist()[:10]}")
        else:
            print(f"{prefix}/{key}: {value}")


class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
@@ -537,8 +547,19 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                if version.parse(torch.__version__) < version.parse("1.6.0"):
                    generator = None
                else:
                    # Torch generators were introduced in PyTorch 1.6.0.
                    generator = torch.Generator()
                    generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))

                return LengthGroupedSampler(
                    self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
                    self.train_dataset,
                    self.args.train_batch_size,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return DistributedLengthGroupedSampler(
@@ -553,7 +574,13 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:

        else:
            if self.args.world_size <= 1:
                return RandomSampler(self.train_dataset)
                if version.parse(torch.__version__) < version.parse("1.6.0"):
                    return RandomSampler(self.train_dataset)

                # Torch generators were introduced in PyTorch 1.6.0.
                generator = torch.Generator()
                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
                return RandomSampler(self.train_dataset, generator=generator)
Contributor

Suggested change, replacing:

                if version.parse(torch.__version__) < version.parse("1.6.0"):
                    return RandomSampler(self.train_dataset)

                # Torch generators were introduced in PyTorch 1.6.0.
                generator = torch.Generator()
                generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
                return RandomSampler(self.train_dataset, generator=generator)

with:

                kwargs = {}
                if has_torch_generator():
                    generator = torch.Generator()
                    generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
                    kwargs.update(dict(generator=generator))
                return RandomSampler(self.train_dataset, **kwargs)

Here is another possible way (using the has_torch_generator helper suggested earlier, which doesn't currently exist).

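For reference, a minimal sketch (an editor's illustration, not part of this PR) of what such a has_torch_generator helper could look like, mirroring the version check used in the diff above:

import torch
from packaging import version


def has_torch_generator():
    # Hypothetical helper: per the comment in this diff, torch generators for
    # samplers are usable from PyTorch 1.6.0 onward.
    return version.parse(torch.__version__) >= version.parse("1.6.0")
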
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
@@ -1166,6 +1193,13 @@ def train(
                steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                steps_trained_progress_bar.set_description("Skipping the first batches")

        # RNG states
        checkpoint_rng_state = None
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, "rng_state.pth")
        ):
            checkpoint_rng_state = torch.load(os.path.join(resume_from_checkpoint, "rng_state.pth"))

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
@@ -1224,6 +1258,28 @@ def train(
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0 and checkpoint_rng_state is not None:
                        # We're finished skipping, so set the RNG states to be exactly as they were at
                        # checkpoint time.
                        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
                        if torch.cuda.is_available():
                            if args.local_rank != -1:
                                if f"cuda_{args.local_rank}" not in checkpoint_rng_state:
                                    logger.warn(
                                        "You are resuming, in a distributed fashion, a training that was launched "
                                        "in a non-distributed way. Reproducibility cannot be guaranteed."
                                    )
                                else:
                                    torch.cuda.random.set_rng_state(checkpoint_rng_state[f"cuda_{args.local_rank}"])
                            else:
                                if "cuda" not in checkpoint_rng_state:
                                    logger.warn(
                                        "You are resuming, in a non-distributed fashion with GPUs, a training that "
                                        "was launched either in a distributed fashion or without GPUs. "
                                        "Reproducibility cannot be guaranteed."
                                    )
                                else:
                                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
@@ -1459,6 +1515,19 @@ def _save_checkpoint(self, model, trial, metrics=None):
        if self.is_world_process_zero():
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

        # Save RNG state in non-distributed training
        if self.is_local_process_zero():
            rng_states = {"cpu": torch.random.get_rng_state()}
Contributor (@stas00), May 4, 2021:

I'm just wondering whether you checked that different processes in dist env will have the same cpu RNG state?

Contributor:

And another question - what about python's main RNG?

Collaborator Author:

In distributed training, the torch generator should have the same state on all processes since it's set with the same seed initially (then all processes execute the same code), unless someone goes out of their way to only execute a random instruction on one process.

The code has been tested on one GPU, on 2 GPUs with DataParallel, and on 2 GPUs distributed, and the same results are obtained with a full training and with resuming from the last checkpoint.

As for Python's main RNG and numpy's main RNG, I have no way of extracting the seed (I can set it but not easily get the current state), so any code that wants to be 100% reproducible when resuming from a checkpoint needs to use torch RNGs only.

Contributor:

> As for Python's main RNG and numpy's main RNG, I have no way of extracting the seed (I can set it but not easily get the current state)

Won't these work?

py_rng_state = random.getstate()
random.setstate(py_rng_state)

np_rng_state = numpy.random.get_state()
numpy.random.set_state(np_rng_state)

> so any code that wants to be 100% reproducible when resuming from a checkpoint needs to use torch RNGs only.

Do you mean to say that they should stick to torch's random APIs and convert to other formats from there?

Collaborator Author:

Oh TIL! I didn't think it was possible to extract and set those. Will add that to the PR.
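
A minimal sketch (an editor's illustration, not the PR's final code) of how the Python and numpy RNG states discussed above could be saved next to the torch states in rng_state.pth and restored on resume, using the APIs quoted in the previous comment (the key names are illustrative):

import os
import random

import numpy as np
import torch


def save_rng_states(output_dir):
    # Capture every RNG the training loop may depend on at checkpoint time.
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
    torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))


def load_rng_states(checkpoint_dir):
    # Put every RNG back exactly where it was when the checkpoint was written.
    rng_states = torch.load(os.path.join(checkpoint_dir, "rng_state.pth"))
    random.setstate(rng_states["python"])
    np.random.set_state(rng_states["numpy"])
    torch.random.set_rng_state(rng_states["cpu"])
    if torch.cuda.is_available() and "cuda" in rng_states:
        torch.cuda.random.set_rng_state_all(rng_states["cuda"])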

            if torch.cuda.is_available():
                if self.args.local_rank == -1:
                    # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                    rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
                else:
                    # In distributed, we save the CUDA RNG states individually.
                    for i in range(torch.cuda.device_count()):
                        rng_states[f"cuda_{i}"] = torch.cuda.random.get_rng_state(i)
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        if checkpoint is None:
@@ -2350,7 +2419,7 @@ def push_to_hub(
        with tempfile.TemporaryDirectory() as tmp_dir:
            for f in os.listdir(save_directory):
                fname = os.path.join(save_directory, f)
                if os.path.isfile(fname):
                if os.path.isfile(fname) and fname != "rng_state.pth":
                    shutil.copy(fname, os.path.join(tmp_dir, f))

        return unwrap_model(self.model)._push_to_hub(
src/transformers/trainer_pt_utils.py: 4 changes (3 additions, 1 deletion)
@@ -510,6 +510,7 @@ def __init__(
        batch_size: int,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
        generator=None,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
@@ -525,12 +526,13 @@
                )
            lengths = [len(feature[self.model_input_name]) for feature in dataset]
        self.lengths = lengths
        self.generator = generator

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        indices = get_length_grouped_indices(self.lengths, self.batch_size)
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
        return iter(indices)


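As a usage sketch of the generator change above (an editor's illustration, assuming the LengthGroupedSampler signature shown in this diff and pre-computed lengths), seeding the generator the same way yields the same length-grouped order on every run:

import torch

from transformers.trainer_pt_utils import LengthGroupedSampler

lengths = [5, 12, 7, 3, 9, 11, 2, 8]  # toy, pre-computed sequence lengths
dataset = list(range(len(lengths)))   # placeholder dataset; lengths are passed explicitly


def make_sampler(seed):
    generator = torch.Generator()
    generator.manual_seed(seed)
    return LengthGroupedSampler(dataset, batch_size=2, lengths=lengths, generator=generator)


# Identical seeds produce identical shuffled, length-grouped index orders,
# which is what makes the sampling order reproducible when resuming.
assert list(iter(make_sampler(42))) == list(iter(make_sampler(42)))

In the Trainer itself the seed comes from a draw on the global torch RNG, which is what lets restoring the saved CPU RNG state on resume reproduce the same sampling order.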