Skip to content

Commit

Permalink
Merge pull request #23 from moskomule/dev
Browse files Browse the repository at this point in the history
fix weight save
  • Loading branch information
moskomule authored Jul 2, 2019
2 parents 7e78f45 + 22af20c commit be11001
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions homura/callbacks/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,16 @@ def save(self,
data: Mapping,
file_name: str):
try:
# scheduler is not a must
scheduler_state_dict = data.get(SCHEDULER)
if scheduler_state_dict is not None:
scheduler_state_dict = scheduler_state_dict.state_dict()

torch.save({"git": get_git_hash(),
"args": get_args(),
MODEL: data[MODEL].state_dict(),
OPTIMIZER: data[OPTIMIZER].state_dict(),
SCHEDULER: scheduler_state_dict,
EPOCH: self._epoch,
STEP: self._step},
self.save_path / file_name)
Expand Down

0 comments on commit be11001

Please sign in to comment.