diff --git a/test/p2p/elastic_trainer_test_base.py b/test/p2p/elastic_trainer_test_base.py
index e927c354..4645c3a1 100644
--- a/test/p2p/elastic_trainer_test_base.py
+++ b/test/p2p/elastic_trainer_test_base.py
@@ -624,60 +624,6 @@ def test_sync_retryable_exception(self):
         # are retryable / non-fatal
         self.assertEqual([410, 410, 410, 410], sums)
 
-    def test_checkpoint(self):
-        """
-        Test with 4 trainers:
-        - Save checkpoint every train_step
-        - Trainers suicide at 3rd step
-        - Restart training (from checkpoint)
-        """
-
-        def process_crash():
-            log.warning("Suicide, pid:{}".format(os.getpid()))
-            os.kill(os.getpid(), signal.SIGKILL)
-
-        hooks = {"process_crash": process_crash}
-        run_id = self._generate_run_id()
-
-        nprocs = 4
-
-        # Before training, there is no checkpoint
-        checkpoint_manager = FileSystemCheckpointManager(self.test_dir.name)
-        self.assertEqual(0, len(checkpoint_manager.list_checkpoints()))
-
-        for _ in range(0, nprocs):
-            _, qout, qerr = self._spawn(
-                self._train_with_checkpoint, run_id, _train_step, hooks
-            )
-
-        # wait all training process complete
-        # clean up for next run
-        self._wait_all_and_clean()
-
-        # we run 2 steps before suicide, expect two checkpoints be saved
-        self.assertEqual(2, len(checkpoint_manager.list_checkpoints()))
-
-        qouts = []
-        qerrs = []
-        # start next run
-        for _ in range(0, nprocs):
-            _, qout, qerr = self._spawn(
-                self._train_with_checkpoint, run_id, _train_step, None
-            )
-            qouts.append(qout)
-            qerrs.append(qerr)
-
-        # Gather all nums and sums from final states, they should match the input
-        sums = []
-        for i in range(0, nprocs):
-            state = _get_or_raise(qouts[i], qerrs[i])
-            # Everyone reads 3 samples after recovering from checkpoint:
-            self.assertEqual(3, len(state.nums))
-            sums.append(state.total_sum)
-
-        # The job should be completely recovered through checkpoints / crashes:
-        self.assertEqual([410, 410, 410, 410], sums)
-
     def test_process_crash(self):
         """
         Test 4 trainers, 2 of which SIGKILL themselves and terminate.
diff --git a/torchelastic/train_loop.py b/torchelastic/train_loop.py
index 9f43ef34..07929b79 100644
--- a/torchelastic/train_loop.py
+++ b/torchelastic/train_loop.py
@@ -11,7 +11,6 @@
 import warnings
 
 import torchelastic
-from torchelastic.checkpoint import CheckpointUtil
 from torchelastic.coordinator import NonRetryableException, StopException
 from torchelastic.metrics import get_elapsed_time_ms, publish_metric
 
@@ -69,8 +68,6 @@ def run_train(coordinator, train_step_gen, state):
     failure_count = 0
     rank = 0
 
-    checkpoint_util = CheckpointUtil(coordinator)
-
    while not coordinator.should_stop_training():
         # See: https://github.com/pytorch/elastic/issues/7
         if failure_count >= MAX_FAILURES:
@@ -90,9 +87,6 @@ def run_train(coordinator, train_step_gen, state):
             # does not sync.
             coordinator.barrier()
 
-            # load checkpoint if necessary
-            state = checkpoint_util.load_checkpoint(state, rank)
-
             state_sync_start_time = time.time()
             state.sync(world_size, rank)
             publish_metric(
@@ -100,7 +94,7 @@
                 "state_sync.duration.ms",
                 get_elapsed_time_ms(state_sync_start_time),
             )
-            checkpoint_util.set_checkpoint_loaded()
+            coordinator.barrier()
             log.info("Rank {0} synced state with other nodes".format(rank))
 
         except StopException:
@@ -140,7 +134,6 @@ def run_train(coordinator, train_step_gen, state):
 
         coordinator.monitor_progress(state, worker_stats)
 
-        checkpoint_util.save_checkpoint(state, rank)
         if coordinator.should_rendezvous(state):
             log.info("Rank {0} will re-rendezvous".format(rank))
             # Executor told us, for whatever reason, to re-rendezvous.