Skip to content

Commit

Permalink
Revert "[fix] oss dict load (#383)" (#384)
Browse files Browse the repository at this point in the history
This reverts commit 8be9d93.
  • Loading branch information
blefaudeux authored Feb 12, 2021
1 parent 8be9d93 commit b666d6a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions fairscale/optim/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,16 +391,16 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

# NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
# we work around that here by using the fact that the params are ordered as in the param_groups
pytorch15_index_redirect = {k: i for i, k in enumerate(state_dict["state"].keys())}

for key, value in state_dict["state"].items():
param = self.index_to_param[pytorch15_index_redirect[key]]
for i_param, (key, value) in enumerate(state_dict["state"].items()):
param = self.index_to_param[i_param]

# Populate the sharded optimizer state on the fly
if self.param_to_rank[param] != self.rank:
state_dict["state"][key] = None

else:
if key in self.index_to_param:
param = self.index_to_param[i_param]

# Only add this state to the sharded optimizer if it owns this param
for pg in self.optim.param_groups:
Expand Down

0 comments on commit b666d6a

Please sign in to comment.