Skip to content

Commit

Permalink
ref: remove weight loading hack for ddp_cpu (#3808)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Oct 2, 2020
1 parent afa4383 commit a28528c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 17 deletions.
3 changes: 2 additions & 1 deletion pl_examples/basic_examples/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def cli_main():
# ------------
# testing
# ------------
trainer.test(test_dataloaders=test_loader)
result = trainer.test(test_dataloaders=test_loader)
print(result)


if __name__ == '__main__':
Expand Down
18 changes: 2 additions & 16 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ def train(self):
# restore main state with best weights
best_path = self.mp_queue.get()
results = self.mp_queue.get()
last_path = self.mp_queue.get()

# recover the weights of the processes trained in the children
self.__recover_child_process_weights(model, best_path, last_path)
self.__recover_child_process_weights(model, best_path)
return results

def ddp_train(self, process_idx, mp_queue, model):
Expand Down Expand Up @@ -187,16 +186,10 @@ def get_device_ids(self):
device_ids = None
return device_ids

def __recover_child_process_weights(self, model, best_path, last_path):
def __recover_child_process_weights(self, model, best_path):
# transfer back the best path to the trainer
if self.trainer.checkpoint_callback:
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score

# load last weights
if last_path is not None and not self.trainer.testing:
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self.trainer.model = model

Expand All @@ -211,10 +204,3 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
# todo, pass complete checkpoint as state dictionary
mp_queue.put(best_model_path)
mp_queue.put(results)

# save the last weights
last_path = None
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
mp_queue.put(last_path)

0 comments on commit a28528c

Please sign in to comment.