Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Minor FP16 fixes when converting old checkpoints (#3514)
Browse files (browse the repository at this point in the history)
* Be robust to loading old checkpoints with apex scalers

* Minor fp16 fixes

* Vacuum the converted model file instead (removes the optimizer from the saved output)
  • Loading branch information
stephenroller authored Mar 22, 2021
1 parent 4a55d3d commit 35b3156
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
3 changes: 3 additions & 0 deletions parlai/agents/bart/convert_fairseq_to_parlai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from parlai.core.params import ParlaiParser
from parlai.core.script import ParlaiScript
from parlai.utils.io import PathManager
from parlai.scripts.vacuum import Vacuum


TRANSFORMER_PARAMETER_MAPPING = {
Expand Down Expand Up @@ -133,6 +134,8 @@ def run(self):
self.agent.model.load_state_dict(converted, True)
self.agent.opt.pop('converting', None)
self.agent.save(self.opt['output'])
# kill the optimizer
Vacuum.main(model_file=self.opt['output'], no_backup=True)
# 4. enjoy!
self.print_agent_act()

Expand Down
6 changes: 5 additions & 1 deletion parlai/utils/fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def load_state_dict(self, state_dict):
(e.g., learning rate) over that found in the state_dict. This allows us to
resume training from a checkpoint using a new set of optimizer args.
"""
- if 'loss_scaler' in state_dict and self.scaler is not None:
+ if (
+ 'loss_scaler' in state_dict
+ and self.scaler is not None
+ and isinstance(state_dict['loss_scaler'], float)
+ ):
self.scaler.loss_scale = state_dict['loss_scaler']
self.optimizer.load_state_dict(state_dict)

Expand Down
4 changes: 3 additions & 1 deletion parlai/zoo/bart/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
}


- def download(datapath, version='v1.0'):
+ def download(datapath, version='v1.1'):
+ # v1.0: initial release
+ # v1.1: change the datatype in conversion for a lighter model file
dpath = os.path.join(datapath, 'models', 'bart')

if not build_data.built(dpath, version):
Expand Down

0 comments on commit 35b3156

Please sign in to comment.