Add option to save on each training node #12421

Merged · 3 commits · Jun 30, 2021
55 changes: 34 additions & 21 deletions src/transformers/trainer.py
@@ -393,7 +393,7 @@ def __init__(
# Create clone of distant repo and output directory if needed
if self.args.push_to_hub:
self.init_git_repo()
- if self.is_world_process_zero():
+ if self.args.should_save:
os.makedirs(self.args.output_dir, exist_ok=True)

if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
@@ -899,7 +899,7 @@ def _tune_save_checkpoint(self):
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir)
- if self.is_world_process_zero():
+ if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
@@ -1357,10 +1357,18 @@ def train(
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)
- # We load the model state dict on the CPU to avoid an OOM error.
- state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
- # If the model is on the GPU, it still works!
- self._load_state_dict_in_model(state_dict)

+ best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+ if os.path.exists(best_model_path):
+ # We load the model state dict on the CPU to avoid an OOM error.
+ state_dict = torch.load(best_model_path, map_location="cpu")
+ # If the model is on the GPU, it still works!
+ self._load_state_dict_in_model(state_dict)
+ else:
+ logger.warn(
+ f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+ "on multiple nodes, you should activate `save_on_each_node=True`."
+ )

if self.deepspeed:
self.deepspeed.load_checkpoint(
@@ -1500,14 +1508,14 @@ def _save_checkpoint(self, model, trial, metrics=None):
# Consolidate the state dict on all processed of dp_rank 0
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
- if self.is_world_process_zero():
+ if self.args.should_save:
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
if self.use_amp:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
- elif self.is_world_process_zero() and not self.deepspeed:
+ elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
@@ -1533,7 +1541,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
self.state.best_model_checkpoint = output_dir

# Save the Trainer state
- if self.is_world_process_zero():
+ if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

# Save RNG state in non-distributed training
@@ -1562,7 +1570,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))

# Maybe delete some older checkpoints.
- if self.is_world_process_zero():
+ if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

def _load_optimizer_and_scheduler(self, checkpoint):
@@ -1831,27 +1839,27 @@ def save_model(self, output_dir: Optional[str] = None):
elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes.
state_dict = self.model_wrapped.state_dict()
- if self.is_world_process_zero():
+ if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
):
state_dict = self.model.state_dict()

- if self.is_world_process_zero():
+ if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif self.deepspeed:

# this takes care of everything as long as we aren't under zero3
- if self.is_world_process_zero():
+ if self.args.should_save:
self._save(output_dir)

if is_deepspeed_zero3_enabled():
# It's too complicated to try to override different places where the weights dump gets
# saved, so since under zero3 the file is bogus, simply delete it. The user should
# either user deepspeed checkpoint to resume or to recover full weights use
# zero_to_fp32.py stored in the checkpoint.
- if self.is_world_process_zero():
+ if self.args.should_save:
file = os.path.join(output_dir, WEIGHTS_NAME)
if os.path.isfile(file):
# logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
@@ -1862,7 +1870,7 @@ def save_model(self, output_dir: Optional[str] = None):
# This must be called on all ranks
self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)

- elif self.is_world_process_zero():
+ elif self.args.should_save:
self._save(output_dir)

def _save_tpu(self, output_dir: Optional[str] = None):
@@ -1880,7 +1888,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained(
output_dir,
- save_config=self.is_world_process_zero(),
+ save_config=self.args.should_save,
state_dict=self.model.state_dict(),
save_function=xm.save,
)
@@ -1889,8 +1897,8 @@ def _save_tpu(self, output_dir: Optional[str] = None):
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
- self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
- if self.tokenizer is not None and self.is_world_process_zero():
+ self.model.save_pretrained(output_dir, save_config=self.args.should_save, save_function=xm.save)
+ if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir)

def _save(self, output_dir: Optional[str] = None, state_dict=None):
@@ -1960,7 +1968,7 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
if len(checkpoints_sorted) <= self.args.save_total_limit:
return

- # If save_total_limit=1 with load_best_mode_at_end=True, we could end up deleting the last checkpoint, which
+ # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
# we don't do to allow resuming.
save_total_limit = self.args.save_total_limit
if (
@@ -2436,7 +2444,7 @@ def init_git_repo(self):
"""
Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
"""
- if not self.is_world_process_zero():
+ if not self.args.should_save:
return
use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
repo_url = PushToHubMixin._get_repo_url_from_name(
@@ -2494,11 +2502,16 @@ def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) ->
Returns:
The url of the commit of your model in the given repository.
"""
- if not self.is_world_process_zero():
+ if not self.args.should_save:
return

self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
self.save_model()

+ # Only push from one node.
+ if not self.is_world_process_zero():
+ return

return self.repo.push_to_hub(commit_message=commit_message)

#
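For context, a minimal usage sketch (not part of the diff above): the flag the new warning message points to is set on TrainingArguments, and the new should_save property added in training_args.py below is what now gates every write to disk. The output_dir value here is a hypothetical placeholder.

# Minimal sketch (assumed usage, not part of this PR's diff): with
# save_on_each_node=True the main process of every node writes checkpoints,
# which is what the warning above recommends when nodes do not share a filesystem.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="outputs",       # hypothetical path
    save_on_each_node=True,     # option introduced by this PR
)

# should_save is the new property gating all writes to disk (see the
# training_args.py diff below); in a single-process run it is simply True.
print(args.should_save)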
22 changes: 22 additions & 0 deletions src/transformers/training_args.py
@@ -183,6 +183,9 @@ class TrainingArguments:
save_total_limit (:obj:`int`, `optional`):
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
:obj:`output_dir`.
+ save_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ In multinode distributed training, whether save models or checkpoints on each node, or only on the main
+ node.
no_cuda (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to not use CUDA even when it is available or not.
seed (:obj:`int`, `optional`, defaults to 42):
@@ -456,6 +459,12 @@ class TrainingArguments:
)
},
)
+ save_on_each_node: bool = field(
+ default=False,
+ metadata={
+ "help": "When doing a multinode distributed training, whether to save once per node or just once on the main node."
+ },
+ )
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})

@@ -937,6 +946,19 @@ def should_log(self):
else:
return self.process_index == 0

+ @property
+ def should_save(self):
+ """
+ Whether or not the current process should save or more generally write to disk.
"""
if self.save_on_each_node:
return self.local_process_index == 0
else:
if is_sagemaker_mp_enabled():
return smp.rank() == 0
else:
return self.process_index == 0

def get_process_log_level(self):
"""
Returns the log level to be used depending on whether this process is the main process of node 0, main process
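To make the semantics of should_save concrete, here is a small standalone sketch (a hypothetical helper mirroring the property's logic, not part of the PR) for a job with 2 nodes and 4 processes per node; the SageMaker model-parallel branch is omitted for brevity.

# Hypothetical illustration of the should_save decision above, assuming a
# 2-node x 4-process job (SageMaker MP branch omitted).
def should_save(process_index: int, local_process_index: int, save_on_each_node: bool) -> bool:
    if save_on_each_node:
        # the main process of every node writes to disk
        return local_process_index == 0
    # otherwise only the global main process (node 0, local rank 0) writes
    return process_index == 0

for node in range(2):
    for local_rank in range(4):
        global_rank = node * 4 + local_rank
        print(
            f"node={node} local_rank={local_rank}: "
            f"default={should_save(global_rank, local_rank, False)}, "
            f"save_on_each_node=True -> {should_save(global_rank, local_rank, True)}"
        )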