This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Minor FP16 fixes when converting old checkpoints #3514

Merged: 4 commits, Mar 22, 2021
1 change: 1 addition & 0 deletions parlai/agents/bart/convert_fairseq_to_parlai.py
@@ -168,6 +168,7 @@ def get_parlai_opt(self) -> Opt:
         # 3. Map other options
         transformer_common_config.update(
             {
+                'datatype': 'valid',  # don't save an optimizer
Contributor:
This is a little odd: it isn't obvious that setting datatype to valid will always exclude only the optimizer, or that it won't have other side effects.

Is it possible to just explicitly remove the optimizer in load_fairseq_checkpoint?

Contributor Author:

I'll do a post vacuum

                 'model': self.opt['model'],
                 # number of layers
                 'n_encoder_layers': fairseq_args['encoder_layers'],
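As a hedged illustration of the reviewer's suggestion above, stripping the optimizer state directly from the loaded checkpoint could look roughly like the sketch below. The helper name and the keys ('optimizer', 'optimizer_type', 'lr_scheduler') are assumptions for illustration only, not code from this PR or from ParlAI's load_fairseq_checkpoint.

```python
import torch


def strip_optimizer_state(checkpoint_path: str, output_path: str) -> None:
    """Drop optimizer-related entries so the converted model file stays small."""
    # Key names are guessed for illustration; a real checkpoint may differ.
    states = torch.load(checkpoint_path, map_location='cpu')
    for key in ('optimizer', 'optimizer_type', 'lr_scheduler'):
        states.pop(key, None)  # silently skip keys that are not present
    torch.save(states, output_path)
```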
6 changes: 5 additions & 1 deletion parlai/utils/fp16.py
@@ -152,7 +152,11 @@ def load_state_dict(self, state_dict):
         (e.g., learning rate) over that found in the state_dict. This allows us to
         resume training from a checkpoint using a new set of optimizer args.
         """
-        if 'loss_scaler' in state_dict and self.scaler is not None:
+        if (
+            'loss_scaler' in state_dict
+            and self.scaler is not None
+            and isinstance(state_dict['loss_scaler'], float)
+        ):
             self.scaler.loss_scale = state_dict['loss_scaler']
         self.optimizer.load_state_dict(state_dict)

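For context, here is a minimal standalone sketch of the guard added above. The assumption (mine, not stated in the diff) is that some older checkpoints store something other than a plain float under 'loss_scaler', e.g. a dict or serialized scaler object, in which case the value is now ignored rather than breaking the load:

```python
def restore_loss_scale(scaler, state_dict: dict) -> None:
    """Restore the loss scale only when the checkpoint stored it as a float."""
    value = state_dict.get('loss_scaler')
    if scaler is not None and isinstance(value, float):
        scaler.loss_scale = value
    # otherwise: keep the scaler's current (default) loss scale


class _ToyScaler:
    loss_scale = 2.0 ** 15


scaler = _ToyScaler()
restore_loss_scale(scaler, {'loss_scaler': {'scale': 4.0}})  # non-float: skipped
restore_loss_scale(scaler, {'loss_scaler': 128.0})           # float: restored
assert scaler.loss_scale == 128.0
```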
4 changes: 3 additions & 1 deletion parlai/zoo/bart/build.py
@@ -46,7 +46,9 @@
 }


-def download(datapath, version='v1.0'):
+def download(datapath, version='v1.1'):
+    # v1.0: initial release
+    # v1.1: change the datatype in conversion for a lighter model file
     dpath = os.path.join(datapath, 'models', 'bart')

     if not build_data.built(dpath, version):
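The version bump matters because the download is gated on a build marker that records the version it was built with; changing 'v1.0' to 'v1.1' makes that check fail and forces a fresh download of the re-converted, lighter model file. Below is a generic sketch of that pattern, not ParlAI's actual build_data implementation.

```python
import os


def built(dpath: str, version: str) -> bool:
    """Return True if dpath was previously built with exactly this version."""
    marker = os.path.join(dpath, '.built')
    if not os.path.isfile(marker):
        return False
    with open(marker) as f:
        return f.read().strip() == version


def mark_done(dpath: str, version: str) -> None:
    """Record the version so future calls to built() can skip the download."""
    os.makedirs(dpath, exist_ok=True)
    with open(os.path.join(dpath, '.built'), 'w') as f:
        f.write(version)
```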