🚨 Fully revert atomic checkpointing 🚨 #29370

Merged (1 commit) on Mar 4, 2024
src/transformers/trainer.py (53 changes: 11 additions & 42 deletions)

@@ -2491,21 +2491,13 @@ def _save_checkpoint(self, model, trial, metrics=None):
 
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
-        if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0:
-            logger.warning(
-                f"Checkpoint destination directory {output_dir} already exists and is non-empty. "
-                "Saving will proceed but saved results may be invalid."
-            )
-            staging_output_dir = output_dir
-        else:
-            staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
-        self.save_model(staging_output_dir, _internal_call=True)
+        self.save_model(output_dir, _internal_call=True)
 
         if not self.args.save_only_model:
             # Save optimizer and scheduler
-            self._save_optimizer_and_scheduler(staging_output_dir)
+            self._save_optimizer_and_scheduler(output_dir)
             # Save RNG state
-            self._save_rng_state(staging_output_dir)
+            self._save_rng_state(output_dir)
 
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2525,39 +2517,16 @@ def _save_checkpoint(self, model, trial, metrics=None):
 
         # Save the Trainer state
         if self.args.should_save:
-            self.state.save_to_json(os.path.join(staging_output_dir, TRAINER_STATE_NAME))
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
 
         if self.args.push_to_hub:
-            self._push_from_checkpoint(staging_output_dir)
-
-        # Place checkpoint in final location after all saving is finished.
-        # First wait for everyone to finish writing
-        self.args.distributed_state.wait_for_everyone()
-
-        # Then go through the rewriting process, only renaming and rotating from main process(es)
-        if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero():
-            if staging_output_dir != output_dir:
-                if os.path.exists(staging_output_dir):
-                    os.rename(staging_output_dir, output_dir)
-
-                    # Ensure rename completed in cases where os.rename is not atomic
-                    # And can only happen on non-windows based systems
-                    if os.name != "nt":
-                        fd = os.open(output_dir, os.O_RDONLY)
-                        os.fsync(fd)
-                        os.close(fd)
-
-            # Maybe delete some older checkpoints.
-            if self.args.should_save:
-                # Solely rely on numerical checkpoint id for rotation.
-                # mtime is not reliable especially on some fuse fs in cloud environments.
-                self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
-        elif self.is_local_process_zero():
-            # Clean up the remaining staging checkpoint folders on other nodes
-            if staging_output_dir != output_dir and os.path.exists(staging_output_dir):
-                shutil.rmtree(staging_output_dir)
-
-        self.args.distributed_state.wait_for_everyone()
+            self._push_from_checkpoint(output_dir)
+
+        # Maybe delete some older checkpoints.
+        if self.args.should_save:
+            # Solely rely on numerical checkpoint id for rotation.
+            # mtime is not reliable especially on some fuse fs in cloud environments.
+            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
 
     def _save_rng_state(self, output_dir):
         # Save RNG state in non-distributed training
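For context, the reverted code staged every checkpoint in a temporary directory and only renamed it into place once all writes had finished; this PR restores the earlier behavior of writing straight into the final checkpoint-<step> folder. A simplified sketch of the two flows (not the Trainer implementation; save_everything stands in for the model/optimizer/RNG/state saves shown above):

import os

def save_checkpoint_direct(run_dir, step, save_everything):
    # Restored behavior: write directly into the final folder.
    output_dir = os.path.join(run_dir, f"checkpoint-{step}")
    save_everything(output_dir)

def save_checkpoint_staged(run_dir, step, save_everything):
    # Reverted behavior: write into a staging folder, then rename it into place,
    # so an interrupted save never leaves a half-written checkpoint-<step> folder.
    output_dir = os.path.join(run_dir, f"checkpoint-{step}")
    staging_dir = os.path.join(run_dir, f"tmp-checkpoint-{step}")
    save_everything(staging_dir)
    os.rename(staging_dir, output_dir)

The deleted lines implemented the staged variant, plus process-synchronization barriers and an fsync of the renamed directory for distributed runs; the restored variant is the direct save.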
tests/trainer/test_trainer.py (16 changes: 1 addition & 15 deletions)

@@ -84,8 +84,7 @@
     slow,
     torch_device,
 )
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, get_last_checkpoint
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
 from transformers.training_args import OptimizerNames
 from transformers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -1406,19 +1405,6 @@ def test_save_checkpoints(self):
             trainer.train()
             self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
 
-    def test_save_checkpoints_is_atomic(self):
-        class UnsaveableTokenizer(PreTrainedTokenizerBase):
-            def save_pretrained(self, *args, **kwargs):
-                raise OSError("simulated file write error")
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
-            # Attach unsaveable tokenizer to partially fail checkpointing
-            trainer.tokenizer = UnsaveableTokenizer()
-            with self.assertRaises(OSError) as _context:
-                trainer.train()
-            assert get_last_checkpoint(tmpdir) is None
-
     @require_safetensors
     def test_safe_checkpoints(self):
         for save_safetensors in [True, False]:
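The removed test asserted that a save which fails partway through leaves no checkpoint behind (get_last_checkpoint returning None). With the revert, a failed save can leave a partially written checkpoint-<step> folder, so a caller who wants the old guarantee has to check for completeness themselves. A hypothetical helper for that, not part of the library, assuming TRAINER_STATE_NAME is "trainer_state.json" and treating its presence as the completeness marker since it is written near the end of the flow above:

# Hypothetical helper, not part of transformers.
import os

from transformers.trainer_utils import get_last_checkpoint

def last_complete_checkpoint(output_dir):
    """Return the newest checkpoint-* folder only if its Trainer state file was written."""
    last = get_last_checkpoint(output_dir)
    if last is not None and os.path.isfile(os.path.join(last, "trainer_state.json")):
        return last
    return None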
tests/trainer/test_trainer_distributed.py (15 changes: 0 additions & 15 deletions)

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
 from typing import Dict
 
 import numpy as np
@@ -237,20 +236,6 @@ def compute_metrics(p: EvalPrediction) -> Dict:
 
     trainer.args.eval_accumulation_steps = None
 
-    # Check that saving does indeed work with temp dir rotation
-    # If this fails, will see a FileNotFoundError
-    model = RegressionModel()
-    training_args.max_steps = 1
-    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda x: 1)
-    trainer = Trainer(
-        model, training_args, optimizers=(opt, sched), data_collator=DummyDataCollator(), eval_dataset=dataset
-    )
-    trainer._save_checkpoint(model=None, trial=None)
-    # Check that the temp folder does not exist
-    assert not (Path(training_args.output_dir) / "tmp-checkpoint-0").exists()
-    assert (Path(training_args.output_dir) / "checkpoint-0").exists()
-
     # Check that `dispatch_batches=False` will work on a finite iterable dataset
 
     train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
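The distributed test deleted here asserted that the staging folder (tmp-checkpoint-0) was gone and only checkpoint-0 remained after a save. Since the revert drops the tmp- staging directories entirely, the only place they can still appear is in output directories produced by runs on the previous code. A small, hypothetical cleanup sketch for such leftovers (the tmp-checkpoint- prefix comes from the reverted f"tmp-{checkpoint_folder}" naming above):

# Hypothetical cleanup, not part of this PR: remove staging folders left over
# from runs that used the reverted tmp-directory checkpoint flow.
import glob
import os
import shutil

def remove_stale_staging_dirs(run_dir):
    for path in glob.glob(os.path.join(run_dir, "tmp-checkpoint-*")):
        if os.path.isdir(path):
            shutil.rmtree(path)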