From a28528cc8bf6323403b2f2f413d324810d4a4572 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Fri, 2 Oct 2020 19:28:50 -0400 Subject: [PATCH] ref: remove weight loading hack for ddp_cpu (#3808) --- pl_examples/basic_examples/image_classifier.py | 3 ++- .../accelerators/ddp_cpu_spawn_backend.py | 18 ++---------------- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/pl_examples/basic_examples/image_classifier.py b/pl_examples/basic_examples/image_classifier.py index 04e965c0491f6..0968525e41197 100644 --- a/pl_examples/basic_examples/image_classifier.py +++ b/pl_examples/basic_examples/image_classifier.py @@ -118,7 +118,8 @@ def cli_main(): # ------------ # testing # ------------ - trainer.test(test_dataloaders=test_loader) + result = trainer.test(test_dataloaders=test_loader) + print(result) if __name__ == '__main__': diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py index 5cc1b2ff3a305..bc494147b905c 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py @@ -62,10 +62,9 @@ def train(self): # restore main state with best weights best_path = self.mp_queue.get() results = self.mp_queue.get() - last_path = self.mp_queue.get() # recover the weights of the processes trained in the children - self.__recover_child_process_weights(model, best_path, last_path) + self.__recover_child_process_weights(model, best_path) return results def ddp_train(self, process_idx, mp_queue, model): @@ -187,16 +186,10 @@ def get_device_ids(self): device_ids = None return device_ids - def __recover_child_process_weights(self, model, best_path, last_path): + def __recover_child_process_weights(self, model, best_path): # transfer back the best path to the trainer if self.trainer.checkpoint_callback: self.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also best score - - # load last weights - if last_path is not None and not self.trainer.testing: - ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt) self.trainer.model = model @@ -211,10 +204,3 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): # todo, pass complete checkpoint as state dictionary mp_queue.put(best_model_path) mp_queue.put(results) - - # save the last weights - last_path = None - if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - atomic_save(model.state_dict(), last_path) - mp_queue.put(last_path)