load_spawn_weights only in proc rank 0 (Lightning-AI#1385)
Co-authored-by: Alexander Reshytko <areshytko@Alexanders-MacBook-Pro.local>
2 people authored and tullie committed May 6, 2020
1 parent 73ebab5 commit 478c8f4
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -363,15 +363,19 @@ def load_spawn_weights(self, original_model):
         :param model:
         :return:
         """
-        # load weights saved in ddp
-        path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
-        loaded_model = original_model.__class__.load_from_checkpoint(path)
 
-        # copy loaded weights to old model
-        original_model.load_state_dict(loaded_model.state_dict())
+        loaded_model = original_model
 
-        # remove ddp weights
-        os.remove(path)
+        if self.proc_rank == 0:
+            # load weights saved in ddp
+            path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
+            loaded_model = original_model.__class__.load_from_checkpoint(path)
+
+            # copy loaded weights to old model
+            original_model.load_state_dict(loaded_model.state_dict())
+
+            # remove ddp weights
+            os.remove(path)
 
         return loaded_model

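What the change does, for readers skimming the diff: after this commit only the process with proc_rank == 0 reads the temporary __temp_weight_ddp_end.ckpt checkpoint, copies its weights into original_model, and deletes the file; every other rank simply returns original_model unchanged. Below is a minimal, standalone sketch of the same rank-0 gating pattern. The TinyModel class and the save_weights/load_weights helpers are hypothetical illustrations, not part of the PyTorch Lightning API.

# Minimal sketch of the rank-0 gating pattern (hypothetical names, not the
# Lightning API): in this sketch only rank 0 ever writes, reads, or deletes
# the temporary checkpoint; all other ranks return the model untouched.
import os
import tempfile

import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)


def save_weights(model: nn.Module, path: str, rank: int) -> None:
    # Only rank 0 writes the temporary checkpoint in this sketch.
    if rank == 0:
        torch.save(model.state_dict(), path)


def load_weights(model: nn.Module, path: str, rank: int) -> nn.Module:
    # Mirrors the guarded load_spawn_weights above: non-zero ranks are a no-op.
    if rank == 0:
        model.load_state_dict(torch.load(path))
        os.remove(path)
    return model


if __name__ == "__main__":
    ckpt = os.path.join(tempfile.gettempdir(), "__temp_weight_demo.ckpt")
    model = TinyModel()
    save_weights(model, ckpt, rank=0)
    model = load_weights(model, ckpt, rank=0)   # rank 0: load and clean up
    model = load_weights(model, ckpt, rank=1)   # other ranks: unchanged model
    print("temp checkpoint removed:", not os.path.exists(ckpt))

The sketch assumes, as the guarded code in the diff appears to, that only rank 0 needs the restored weights and that the temporary checkpoint may not exist (or should not be touched) on the other ranks.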
