fp16 models getting auto converted to fp32 in .from_pretrained() #12062

Closed · asit2898 opened this issue Jun 8, 2021 · 36 comments · Fixed by #12316
@asit2898 commented Jun 8, 2021

stas00 edited: this issue has nothing to do with DeepSpeed; it is purely a transformers issue


Environment info

  • transformers version: 4.6.1
  • Platform: Linux-3.10.0-1127.13.1.el7.x86_64-x86_64-with-centos-7.7.1908-Core
  • Python version: 3.6.8
  • PyTorch version (GPU?): 1.6.0+cu92 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: Yes (not essential)
  • Using distributed or parallel set-up in script?: Yes (not essential)

Who can help

@LysandreJik @sgugger

Information

Model I am using (Bert, XLNet ...): BertForMaskedLM

The problem arises when using:

  • my own modified scripts: (give details below)

The task I am working on is:

  • my own task or dataset: (give details below)
    Masked LM

To reproduce

Steps to reproduce the behavior:

  1. Finetune a 16-bit low precision BertForMaskedLM model on any dataset using DeepSpeed and Trainer
  2. Load the model and check the dtype using:
from transformers import BertTokenizer, BertForMaskedLM
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
model = BertForMaskedLM.from_pretrained(model_path)
print(model.dtype)

Expected behavior

The script outputs torch.float32 instead of the expected torch.float16. I was able to recover the original weights using model.half().

I think it would be helpful to highlight this forced auto-conversion behaviour, either as a warning or in the from_pretrained() method's documentation, or to provide an additional argument that helps retain fp16 weights. I am willing to pick this issue up. Please let me know what the most appropriate fix would be.

@sgugger (Collaborator) commented Jun 8, 2021

cc @stas00

@stas00 (Contributor) commented Jun 8, 2021

Oh, do you mean that your model was already in fp16 to start with? That's a combination I haven't tried yet.

First, when reporting DeepSpeed problems, please always share the DeepSpeed config file and the TrainingArguments; then we can look at sorting it out.

stas00 self-assigned this Jun 8, 2021
@asit2898 (Author) commented Jun 8, 2021

Yes, the saved model was already in fp16. Apologies, here are the needed files:

A) DeepSpeed config file:

{"zero_allow_untested_optimizer": true,
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr":3e-5,
      "betas": [
        0.9,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  },
"scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 500
        }
    },
"train_batch_size": 24,
"fp16": {
      "enabled": true,
      "loss_scale": 0,
      "initial_scale_power": 16
    }

}

B) Training Arguments:

TrainingArguments(output_dir=/data/dps_finetune_16_wikitext, overwrite_output_dir=True, do_train=True, do_eval=True, do_predict=False, evaluation_strategy=IntervalStrategy.STEPS, prediction_loss_only=False, per_device_train_batch_size=8, per_device_eval_batch_size=8, gradient_accumulation_steps=1, eval_accumulation_steps=None, learning_rate=5e-05, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=10.0, max_steps=-1, lr_scheduler_type=SchedulerType.LINEAR, warmup_ratio=0.0, warmup_steps=0, logging_dir=runs/Jun08_18-02-30_jp3-g-31374-37031-i-2p4p2, logging_strategy=IntervalStrategy.STEPS, logging_first_step=False, logging_steps=500, save_strategy=IntervalStrategy.STEPS, save_steps=100, save_total_limit=5, no_cuda=False, seed=42, fp16=False, fp16_opt_level=O1, fp16_backend=auto, fp16_full_eval=False, local_rank=0, tpu_num_cores=None, tpu_metrics_debug=False, debug=[], dataloader_drop_last=False, eval_steps=10, dataloader_num_workers=0, past_index=-1, run_name=/data/dps_finetune_16_wikitext, disable_tqdm=False, remove_unused_columns=True, label_names=None, load_best_model_at_end=False, metric_for_best_model=None, greater_is_better=None, ignore_data_skip=False, sharded_ddp=[], deepspeed=/data/config_fine_tune_bert.json, label_smoothing_factor=0.0, adafactor=False, group_by_length=False, length_column_name=length, report_to=['mlflow', 'tensorboard'], ddp_find_unused_parameters=None, dataloader_pin_memory=False, skip_memory_metrics=False, use_legacy_prediction_loop=False, push_to_hub=False, resume_from_checkpoint=None, _n_gpu=1, mp_parameters=)

fp16 is set to False. I have also tried with fp16=True but no difference in behaviour was observed.

I also tested by loading the saved fp16 state_dict separately using torch.load() and then used it to initialize the BertForMaskedLM as follows:

import torch
from transformers import BertConfig, BertForMaskedLM

state_dict = torch.load(model_path + "pytorch_model.bin")
config = BertConfig.from_json_file(model_path + "config.json")
model = BertForMaskedLM.from_pretrained(None, config=config, state_dict=state_dict)
print(model.dtype)

model.dtype still outputs torch.float32.

The config.json file above (saved model's config file) is as follows:

{
  "_name_or_path": "/data/bert-base-cased/",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.6.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}

The _name_or_path points to the location of the pre-finetuning fp32 model. However, changing its value to the post-finetuning fp16 model also does not lead to any change in model.dtype output. Please let me know if there are any checks I could run or files I could provide.
Thanks!

@stas00 (Contributor) commented Jun 9, 2021

Thank you for sharing these details. So indeed this looks like a case I haven't run into, and this is not an integration issue.

So under zero3, from_pretrained calls zero.Init(), which prepares the model for DeepSpeed's stage-3 work, and it also gathers/scatters the model pieces across the GPUs during state_dict loading. So this is the doing of one of these two. But they are needed in order to use the DeepSpeed optimizer, which works either in fp32 or mixed-precision mode (DeepSpeed's fp16.enabled == mixed precision). As far as I know they currently don't have a non-mixed fp16 mode, but clearly there is a need for that.
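For illustration, under ZeRO-3 the model construction is wrapped roughly like the sketch below (a simplified sketch, not the actual transformers code; zero.Init's arguments and the surrounding plumbing are omitted):

import deepspeed
from transformers import BertConfig, BertForMaskedLM

config = BertConfig()

# Under ZeRO-3, parameters are partitioned across GPUs as each submodule is
# constructed, instead of being fully materialized on every rank.
with deepspeed.zero.Init():
    model = BertForMaskedLM(config)

# The partitioned pieces are then gathered/scattered while the state_dict
# weights are copied in.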

Most likely this is something DeepSpeed core will have to solve. This use case is probably new to them too.

So please kindly use https://github.com/microsoft/DeepSpeed/issues/new to post the same details there (Edit -> Copy-n-Paste),
and please tag me so that I can track the outcome and adjust things on our side if need be.

Thank you, @asit2898

@ShadenSmith commented:

Hi @asit2898, thanks for reporting your issue. I can help look at things from DeepSpeed's side.

Was the model fine-tuned with ZeRO enabled? From the DS config above it seems not, unless it is enabled somewhere on the HF side of things.

@stas00 , does the from_pretrained codepath go through DeepSpeed's load_checkpoint(), or is the checkpoint logic all on HF's side?

To start, I did a quick experiment with DeepSpeed (without ZeRO) and examined model parameter dtypes before and after deepspeed.initialize() (a sketch of that kind of check follows the list below). So far I haven't reproduced the issue:

  • When FP16 is not enabled, the model's dtype is unchanged (e.g., fp32 stays fp32 and fp16 stays fp16).
  • When FP16 is enabled, the model weights are fp16 after deepspeed.initialize(), no matter whether the initial dtype was fp32 or fp16.
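For reference, a rough sketch of that kind of check (illustrative only: it uses the public prajjwal1/bert-tiny checkpoint, a minimal config, and the exact deepspeed.initialize keyword set varies a bit across DeepSpeed versions; it would need to be run under the deepspeed launcher):

import deepspeed
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny").half()
print(next(model.parameters()).dtype)  # before: torch.float16

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "AdamW", "params": {"lr": 3e-5}},
    "fp16": {"enabled": True},
}

# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler)
engine, _, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)
print(next(engine.module.parameters()).dtype)  # after: torch.float16 when fp16 is enabled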

@stas00 (Contributor) commented Jun 9, 2021

@stas00 , does the from_pretrained codepath go through DeepSpeed's load_checkpoint(), or is the checkpoint logic all on HF's side?

As I posted above, under zero3 from_pretrained:

  1. calls zero.Init(), which prepares the model for DeepSpeed's stage-3 work, and
  2. gathers/scatters the model pieces across the GPUs during state_dict loading.

I did a quick experiment with DeepSpeed (without ZeRO)

The key is zero3. from_pretrained doesn't do anything deepspeed-wise unless it's zero3.

@asit2898 (Author) commented Jun 9, 2021

@ShadenSmith @stas00 Thanks for the replies! I did not enable any stage of ZeRO and just ran DeepSpeed with pure data parallelism.
The saved model was in fp16 at the end of DeepSpeed fine-tuning with the HF Trainer, which I think is in accordance with the experiments you carried out...

It is only after I load the saved model using the .from_pretrained() method that the weights get auto-converted to 32 bits...

I am not very familiar with the HF source code, but given that .from_pretrained() takes only the state_dict and model configuration as arguments, especially in the following case that I mentioned:

import torch
from transformers import BertConfig, BertForMaskedLM

state_dict = torch.load(model_path + "pytorch_model.bin")
config = BertConfig.from_json_file(model_path + "config.json")
model = BertForMaskedLM.from_pretrained(None, config=config, state_dict=state_dict)
print(model.dtype)

the HF model's behaviour should be independent of whether or not it was trained with DeepSpeed, right? 🤔
Let me know if there are any experiments that could help isolate the effects of DeepSpeed from those of HF.

@ShadenSmith commented:

Thanks for the clarification, @asit2898 / @stas00.

@stas00, I don't yet understand the conclusion that the issue is in core DeepSpeed. Since ZeRO-3 is not enabled, is HF expecting Init() to do something else? It should just be a no-op as long as Z3 is not enabled. Is the expectation on HF's side that there are fp32 weights that should be converted to fp16 in this instance? Or is the thought that Init() is still executing, and the weights are bumped to fp32 there when scattering?

The only model dtype transformations that we should be making are converting to FP16 when that is enabled. This issue is going in the opposite direction and I am not sure where the FP32 conversion would happen.

@stas00 (Contributor) commented Jun 10, 2021

OK, let me try to reproduce this first; then it'll be much easier to discuss further.

For some reason I was under the impression that zero3 was enabled! But reviewing the config posted by @asit2898, it's not.

I will make an fp16 model, try to reproduce the problem and then follow up.

@stas00 (Contributor) commented Jun 10, 2021

OK, this doesn't seem to have anything to do with Deepspeed.

Observe:

import torch
from transformers import BertForMaskedLM

mname = "prajjwal1/bert-tiny"
model = BertForMaskedLM.from_pretrained(mname)
model = model.half()
print(model.dtype)

model_path = "/tmp/bert-fp16"
model.save_pretrained(model_path)

model = BertForMaskedLM.from_pretrained(model_path)
print(model.dtype)

prints:

torch.float16
torch.float32

I will look next at why this bug is happening.

@stas00 (Contributor) commented Jun 11, 2021

OK, so it's not even transformers; it's pytorch that does this in load_state_dict: pytorch/pytorch#39428

Here is a standalone torch example:

import torch
from torch import nn

model = nn.Linear(1,1)
model = model.half()
print(model.weight.dtype)
torch.save(model.state_dict(), 'model.pkl')

model = nn.Linear(1,1)
model.load_state_dict(torch.load('model.pkl'))
print(model.weight.dtype)

prints

torch.float16
torch.float32
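
And the counterpart check, reusing 'model.pkl' saved above: if the destination module is already fp16, the loaded weights stay fp16, i.e. load_state_dict casts to the module's dtype, not the checkpoint's:

import torch
from torch import nn

model = nn.Linear(1, 1).half()                  # destination module is fp16
model.load_state_dict(torch.load('model.pkl'))  # 'model.pkl' was saved in fp16 above
print(model.weight.dtype)                       # torch.float16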

@stas00 (Contributor) commented Jun 11, 2021

Thinking more about it, I think load_state_dict does the right thing: it adjusts the weights to the dtype of the model.

Since the user can't access the model until after from_pretrained returns, they have no chance to choose its dtype.

  1. So one possible solution here is to add an optional dtype arg to from_pretrained and, if it's passed, do:
model.to(dtype=dtype)

as soon as the model is instantiated.

  2. An alternative approach is to sample the weights' dtype and convert the model automatically to that type. Is it ever possible that the weights could be of different dtypes? If not, this might be the transparent solution.

Of course, the user could do model.half() immediately after from_pretrained, but the problem is that it will require 2x RAM, which the user might not have, so the switching should occur before the weights are loaded.

@sgugger, @LysandreJik, @patrickvonplaten - what do you think?

@sgugger (Collaborator) commented Jun 11, 2021

I'm okay with having a dtype argument to from_pretrained, personally.

@stas00 (Contributor) commented Jun 11, 2021

I just edited my comment above to offer automatic detection (item 2).

@stas00 (Contributor) commented Jun 11, 2021

@asit2898, until we sort it out please use model.half() after from_pretrained as a workaround.
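
For example, with the /tmp/bert-fp16 checkpoint from the repro above, the interim workaround looks like this:

from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("/tmp/bert-fp16")  # currently comes back as fp32
model = model.half()                                       # cast back down to fp16
print(model.dtype)                                         # torch.float16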

@LysandreJik (Member) commented:

I'm fine with having a dtype argument to from_pretrained as well, and if possible an automatic detection would be even better.

I would also be fine with a configuration attribute that identifies fp32/fp16/bfloat16, as users have been surprised in the past that models weighing ~500MB on the hub ended up taking up much more RAM and disk space on their machines (though automatic detection would be better than having another configuration attribute).

@sgugger (Collaborator) commented Jun 11, 2021

Ah yes, this is definitely something that could be stored in the configuration!

@stas00 (Contributor) commented Jun 11, 2021

This also connects to my proposal from 2 months ago (#11209), though it's slightly different, since a model could be pre-trained in mixed precision and saved in fp32.

The thing is, if you have the weights of the model, it doesn't take long to get the dtype of the tensors in its saved state_dict (pytorch). One question: is it guaranteed they are always of the same dtype, so that it's enough to check one of them, or should all of them be checked and the highest used if they are mixed?

@stas00 (Contributor) commented Jun 11, 2021

Specific discussion on auto-detection:

To do auto-detection, torch.load() needs to be moved before model instantiation.
Then we need to set the default dtype:
https://pytorch.org/docs/stable/generated/torch.set_default_tensor_type.html

So the protocol would be (see the sketch after this list):

  1. torch.load (which would need to be moved up), or use state_dict if it was passed to from_pretrained
  2. read the dtype of one (all?) of the weights
  3. set torch.set_default_tensor_type(dtype)
  4. instantiate the model
  5. restore torch.set_default_tensor_type to its previous value (so this could be a context manager)
  6. _load_from_state_dict
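
Putting the protocol together, here is a rough sketch (not the eventual transformers implementation; it uses torch.set_default_dtype rather than set_default_tensor_type for simplicity, reuses the /tmp/bert-fp16 checkpoint from the repro above, and assumes all floating-point weights share one dtype):

import contextlib
import torch
from transformers import BertConfig, BertForMaskedLM

@contextlib.contextmanager
def default_dtype(dtype):
    saved = torch.get_default_dtype()
    torch.set_default_dtype(dtype)      # step 3: new float tensors default to this dtype
    try:
        yield
    finally:
        torch.set_default_dtype(saved)  # step 5: restore the previous default

state_dict = torch.load("/tmp/bert-fp16/pytorch_model.bin", map_location="cpu")        # step 1
dtype = next(t.dtype for t in state_dict.values() if t.is_floating_point())            # step 2
config = BertConfig.from_json_file("/tmp/bert-fp16/config.json")

with default_dtype(dtype):              # steps 3 and 5
    model = BertForMaskedLM(config)     # step 4: params are created in the detected dtype

model.load_state_dict(state_dict, strict=False)  # step 6 (the real code uses _load_from_state_dict)
print(model.dtype)                               # torch.float16 for an fp16 checkpoint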

@stas00 (Contributor) commented Jun 11, 2021

And if we choose to implement this for pytorch what do we do with tf and flax?

stas00 removed the DeepSpeed label Jun 12, 2021
@asit2898 (Author) commented:

@stas00 Thanks a lot for addressing the issue! I really did not expect the issue to lie in the way PyTorch loads the model. I'll continue using model.half() and would be happy to help in any way I can...

@LysandreJik (Member) commented:

@Rocketknight1, do you have an idea of how that is/would be managed with TensorFlow?

@patrickvonplaten @patil-suraj, do you have an idea of how that is/would be managed with JAX/Flax?

@Rocketknight1 (Member) commented:

@LysandreJik Keras is quite opinionated about this - it has plenty of support for mixed-precision training (like PyTorch AMP) using a Policy object, but I don't know of many people doing true full float16/bfloat16 training; I think you'd have to do that in pure TF or use backend functions like K.set_floatx. I also think it has weird side effects and breaks some layers.
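
For reference, the two Keras knobs mentioned above look roughly like this (illustrative only; the full-fp16 backend switch is exactly the kind of global setting that can have weird side effects):

import tensorflow as tf

# Mixed-precision training via a Policy (the closest analogue to PyTorch AMP):
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# "True" full float16 weights via the global backend switch (K.set_floatx):
tf.keras.backend.set_floatx("float16")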

@stas00 (Contributor) commented Jun 18, 2021

Looks like we lost momentum on this one.

Please answer the following 2 questions with a 1x and a 2x choice (e.g. 1c, 2b; multiple choices are OK too if you're flexible):

  1. dtype setting mechanism:
    a. do we autodiscover the dtype from the state_dict
    b. do we pass an explicit dtype argument to from_pretrained
    c. a+b - with the dtype argument overriding autodiscovery
    d. using model config attribute - need to change save_pretrained to save this attribute
    e. a+d - with d overriding autodiscovery

  2. Scope of the solution:
    a. do we try to solve this for all 3 frameworks,
    b. just pytorch for now - will be documented as such

Thank you!

p.s. if we add from_pretrained(..., dtype) should we do the same for from_config(..., dtype) so that the behavior is the same?

@LysandreJik (Member) commented:

I'd vote for 1a, overridden by a configuration attribute (1d?) rather than the from_pretrained argument, and 2b.

@sgugger (Collaborator) commented Jun 21, 2021

Agreed with Lysandre: use a config attribute (which defaults to None or "auto") and switch back to autodiscovery if this attribute is not set to a specific value.

@stas00 (Contributor) commented Jun 21, 2021

Update: I've added the 1d and 1e options as proposed.

So if we go with 1e, from_config is then just 1d, right? Since there is no model to do autodiscovery from.

Question: could it be possible that the model will have some weights that use a different dtype than the rest of the model?

@sgugger (Collaborator) commented Jun 21, 2021

Yes, from_config uses just 1d.
For your question, I'm not aware of such a situation existing.

@stas00 (Contributor) commented Jun 23, 2021

@asit2898, please give this PR a try: #12316 - it should do the right thing automatically, as requested.

@stas00 (Contributor) commented Jun 29, 2021

@asit2898, the PR is almost done, and once merged you will need to use one of:

    model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype=torch.float16)
    model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype="auto")

to meet your needs.

@lucasjinreal commented:

@stas00 Hi, this might be off-topic, but I want to ask: if I specify torch_dtype=torch.float16 but the model being loaded is actually a float32 model, will the weights be converted correctly and automatically?

@stas00 (Contributor) commented Jun 20, 2023

That's correct. It'd be the equivalent of weight.to(torch_dtype).

If your model was saved in fp32, torch.load will still allocate the fp32 weights as fp32, and only then will they be downcast. So if you don't have enough memory, you might want to pre-save the model in your target dtype.

I think future versions of torch.load should be able to automatically load in the target dtype, but it's not the case today.
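
For example, a minimal sketch of pre-saving into the target dtype (model name and paths are placeholders):

import torch
from transformers import BertForMaskedLM

# one-time conversion: load (fp32 allocation + downcast), then save in fp16
model = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", torch_dtype=torch.float16)
model.save_pretrained("/tmp/bert-tiny-fp16")

# later loads can pick the dtype up from the saved checkpoint directly
model = BertForMaskedLM.from_pretrained("/tmp/bert-tiny-fp16", torch_dtype="auto")
print(model.dtype)  # torch.float16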

@lucasjinreal commented:

@stas00 I found that if I specify torch_dtype as float16, the loaded model is indeed fp16 even though the saved weights are fp32.

But this model is a LoRA, and when I merge the LoRA into the base model, the prediction results are not correct. How can I verify whether this conversion is correct or not?

(It might also be because of DeepSpeed, since I am using zero_to_fp32.py to convert the DeepSpeed state dict to a float32 adapter model.)

I am confused now about why the model loaded with the LoRA is not right.

(The LoRA saved in fp16 after training finished is correct, though.)

@stas00 (Contributor) commented Jun 20, 2023

I suggest you start a new issue, @lucasjinreal - and please tag me there.

Also, please make sure you have the latest deepspeed version: until recently it wasn't dealing correctly with frozen weights (it wasn't saving them in the checkpoint). I think it was fixed around 0.9.2 (or 0.9.3).

@lucasjinreal commented:

@stas00 I was using 0.8.3 and am now trying 0.9.4. Should I use transformers 4.30 to completely deal with this problem?

@stas00 (Contributor) commented Jun 21, 2023

I can't say until you try, but ds 0.8.3 is definitely not doing the right thing.

The transformers version has nothing to do with the issue; it's really about deepspeed not saving params that aren't in the optimizer, so zero_to_fp32.py doesn't get those, and random values are then loaded when you try to load the model.

You can easily test the output of zero_to_fp32.py: if it's correct, the checkpoint should be 4 bytes per param (fp32), e.g. a 10B-parameter model should produce a ~40GB checkpoint file.
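
For example, a quick sanity check along those lines (the file name is a placeholder for the checkpoint produced by zero_to_fp32.py):

import os
import torch

ckpt = "pytorch_model.bin"  # output of zero_to_fp32.py
sd = torch.load(ckpt, map_location="cpu")
n_params = sum(t.numel() for t in sd.values())
expected_gib = n_params * 4 / 2**30  # 4 bytes per param in fp32
actual_gib = os.path.getsize(ckpt) / 2**30
print(f"params: {n_params/1e9:.2f}B  expected: ~{expected_gib:.1f} GiB  on disk: {actual_gib:.1f} GiB")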
