Megatron conversion code converts some weights in fp16 to fp32 (or uint8). #13193

Closed
hwijeen opened this issue Aug 20, 2021 · 0 comments · Fixed by #13194
hwijeen (Contributor) commented Aug 20, 2021

Environment info

transformers version: 4.9.2
Platform: Linux-4.18.0-25-generic-x86_64-with-glibc2.10
Python version: 3.8.5
PyTorch version (GPU?): 1.8.0a0+52ea372 (True)
Tensorflow version (GPU?): not installed (NA)
Flax version (CPU?/GPU?/TPU?): not installed (NA)
Jax version: not installed
JaxLib version: not installed
Using GPU in script?:
Using distributed or parallel set-up in script?: No

Who can help

@novatig @jdemouth @LysandreJik

Information

Model I am using (Bert, XLNet ...):

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The tasks I am working on are:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

  1. Check the data types of the original Megatron checkpoint. They are all fp16.
wget --content-disposition  https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O checkpoint.zip
unzip checkpoint.zip
python -c "import torch; sd = torch.load('./release/mp_rank_00/model_optim_rng.pt'); print({p.dtype for p in sd['model']['language_model']['transformer'].values()})"
# {torch.float16}
  2. But the current conversion script converts some weights into float32 and uint8. The resulting model's data types are not faithful to the original, which is potentially a problem, as discussed in "respect dtype of the model when instantiating not working" #13076. (A script after these steps groups the converted parameter names by dtype to show exactly which tensors are affected.)
python3 /hf/transformers-master/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py checkpoint.zip
python -c "import torch; sd = torch.load('pytorch_model.bin'); print({p.dtype for p in sd.values()})"
# {torch.float16, torch.float32, torch.uint8}
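
To see which parameters are affected (not just which dtypes appear), the checks above can be extended to group the converted parameter names by dtype. This is only a diagnostic sketch, run against the same pytorch_model.bin produced by the conversion step:

import torch
from collections import defaultdict

# Group converted parameter names by dtype to see exactly which tensors
# the conversion script promoted to fp32 or stored as uint8.
sd = torch.load("pytorch_model.bin")
by_dtype = defaultdict(list)
for name, tensor in sd.items():
    by_dtype[tensor.dtype].append(name)
for dtype, names in sorted(by_dtype.items(), key=lambda kv: str(kv[0])):
    print(dtype, f"({len(names)} tensors)")
    for name in names[:5]:  # print a few example names per dtype
        print("   ", name)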

Expected behavior

The converted checkpoint should have the same data types as the original one.

I will open a new PR to address this :)
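
Until the fix lands, a rough post-hoc workaround is sketched below. It assumes every floating-point weight should be fp16, like the source checkpoint, and it leaves the uint8 entries (which look like mask buffers emitted by the conversion script) untouched:

import torch

# Cast every floating-point tensor in the converted checkpoint back to fp16
# so the dtypes match the original Megatron weights. Non-floating-point
# tensors (e.g. the uint8 entries) are deliberately left as-is here.
sd = torch.load("pytorch_model.bin")
sd = {k: v.half() if v.is_floating_point() else v for k, v in sd.items()}
torch.save(sd, "pytorch_model.bin")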
