This repository was archived by the owner on Jan 6, 2023. It is now read-only.

Delete classy elastic trainer and dependency to pet checkpoint api #80

Closed
wants to merge 1 commit
54 changes: 0 additions & 54 deletions test/p2p/elastic_trainer_test_base.py
@@ -624,60 +624,6 @@ def test_sync_retryable_exception(self):
# are retryable / non-fatal
self.assertEqual([410, 410, 410, 410], sums)

def test_checkpoint(self):
"""
Test with 4 trainers:
- Save checkpoint every train_step
- Trainers suicide at 3rd step
- Restart training (from checkpoint)
"""

def process_crash():
log.warning("Suicide, pid:{}".format(os.getpid()))
os.kill(os.getpid(), signal.SIGKILL)

hooks = {"process_crash": process_crash}
run_id = self._generate_run_id()

nprocs = 4

# Before training, there is no checkpoint
checkpoint_manager = FileSystemCheckpointManager(self.test_dir.name)
self.assertEqual(0, len(checkpoint_manager.list_checkpoints()))

for _ in range(0, nprocs):
_, qout, qerr = self._spawn(
self._train_with_checkpoint, run_id, _train_step, hooks
)

# wait for all training processes to complete
# clean up for the next run
self._wait_all_and_clean()

# we run 2 steps before the crash, so expect two checkpoints to be saved
self.assertEqual(2, len(checkpoint_manager.list_checkpoints()))

qouts = []
qerrs = []
# start next run
for _ in range(0, nprocs):
_, qout, qerr = self._spawn(
self._train_with_checkpoint, run_id, _train_step, None
)
qouts.append(qout)
qerrs.append(qerr)

# Gather all nums and sums from the final states; they should match the input
sums = []
for i in range(0, nprocs):
state = _get_or_raise(qouts[i], qerrs[i])
# Everyone reads 3 samples after recovering from checkpoint:
self.assertEqual(3, len(state.nums))
sums.append(state.total_sum)

# The job should recover completely from the crashes via checkpoints:
self.assertEqual([410, 410, 410, 410], sums)

def test_process_crash(self):
"""
Test 4 trainers, 2 of which SIGKILL themselves and terminate.
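For context, a minimal sketch of the checkpoint bookkeeping the deleted test_checkpoint relied on. Only the FileSystemCheckpointManager constructor and list_checkpoints() calls are taken from the test above; the import path is an assumption based on the torchelastic.checkpoint import removed from train_loop.py below.

import tempfile

from torchelastic.checkpoint import FileSystemCheckpointManager  # assumed import path

with tempfile.TemporaryDirectory() as checkpoint_dir:
    # Before any training run there are no checkpoints; after two successful
    # train steps the deleted test expected exactly two.
    manager = FileSystemCheckpointManager(checkpoint_dir)
    assert len(manager.list_checkpoints()) == 0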
9 changes: 1 addition & 8 deletions torchelastic/train_loop.py
@@ -11,7 +11,6 @@
import warnings

import torchelastic
from torchelastic.checkpoint import CheckpointUtil
from torchelastic.coordinator import NonRetryableException, StopException
from torchelastic.metrics import get_elapsed_time_ms, publish_metric

@@ -69,8 +68,6 @@ def run_train(coordinator, train_step_gen, state):
failure_count = 0
rank = 0

checkpoint_util = CheckpointUtil(coordinator)

while not coordinator.should_stop_training():
# See: https://github.com/pytorch/elastic/issues/7
if failure_count >= MAX_FAILURES:
@@ -90,17 +87,14 @@ def run_train(coordinator, train_step_gen, state):
# does not sync.
coordinator.barrier()

# load checkpoint if necessary
state = checkpoint_util.load_checkpoint(state, rank)

state_sync_start_time = time.time()
state.sync(world_size, rank)
publish_metric(
"torchelastic",
"state_sync.duration.ms",
get_elapsed_time_ms(state_sync_start_time),
)
checkpoint_util.set_checkpoint_loaded()

coordinator.barrier()
log.info("Rank {0} synced state with other nodes".format(rank))
except StopException:
@@ -140,7 +134,6 @@ def run_train(coordinator, train_step_gen, state):

coordinator.monitor_progress(state, worker_stats)

checkpoint_util.save_checkpoint(state, rank)
if coordinator.should_rendezvous(state):
log.info("Rank {0} will re-rendezvous".format(rank))
# Executor told us, for whatever reason, to re-rendezvous.
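For reference, a minimal sketch of the shape run_train takes once the CheckpointUtil calls are gone. The _StubCoordinator and _StubState classes below are hypothetical stand-ins, not torchelastic APIs, exposing only the calls visible in the hunks above; the real loop also obtains rank and world_size from the coordinator's rendezvous and gets worker_stats from the train_step generator.

import logging
import time

log = logging.getLogger(__name__)
MAX_FAILURES = 100  # placeholder value; the real constant lives in train_loop.py


class _StubCoordinator:
    # Hypothetical stand-in exposing only the calls visible in the diff above.
    def __init__(self, steps=3):
        self._steps_left = steps

    def should_stop_training(self):
        return self._steps_left <= 0

    def barrier(self):
        pass

    def should_rendezvous(self, state):
        return False

    def monitor_progress(self, state, worker_stats):
        self._steps_left -= 1


class _StubState:
    # Hypothetical stand-in for a torchelastic State implementation.
    def sync(self, world_size, rank):
        pass


def run_train_sketch(coordinator, state, rank=0, world_size=1):
    failure_count = 0
    while not coordinator.should_stop_training():
        if failure_count >= MAX_FAILURES:
            raise RuntimeError("failure count exceeded MAX_FAILURES")

        # With the checkpoint dependency removed, each iteration reduces to:
        # barrier -> state.sync -> barrier -> train step(s) -> progress check.
        coordinator.barrier()

        state_sync_start_time = time.time()
        state.sync(world_size, rank)
        log.debug("state_sync took %d ms", (time.time() - state_sync_start_time) * 1000)

        coordinator.barrier()
        log.info("Rank {0} synced state with other nodes".format(rank))

        worker_stats = None  # produced by the train_step generator in the real loop
        coordinator.monitor_progress(state, worker_stats)

        if coordinator.should_rendezvous(state):
            log.info("Rank {0} will re-rendezvous".format(rank))
    return state


if __name__ == "__main__":
    run_train_sketch(_StubCoordinator(), _StubState())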