fp16 models getting auto converted to fp32 in .from_pretrained() #12062
Comments
cc @stas00
Oh, do you mean that your model was already in fp16 to start with? This is a combination I haven't tried yet. First, when reporting DeepSpeed problems, please always share the DeepSpeed config file and the TrainingArguments, and then we can look at sorting it out.
Yes, the saved model was already in fp16. Apologies, here are the needed files:
A) DeepSpeed config file:
{
"zero_allow_untested_optimizer": true,
"optimizer": {
"type": "AdamW",
"params": {
"lr":3e-5,
"betas": [
0.9,
0.999
],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
}
},
"train_batch_size": 24,
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
}
}
B) Training Arguments: TrainingArguments(output_dir=/data/dps_finetune_16_wikitext, overwrite_output_dir=True, do_train=True, do_eval=True, do_predict=False, evaluation_strategy=IntervalStrategy.STEPS, prediction_loss_only=False, per_device_train_batch_size=8, per_device_eval_batch_size=8, gradient_accumulation_steps=1, eval_accumulation_steps=None, learning_rate=5e-05, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=10.0, max_steps=-1, lr_scheduler_type=SchedulerType.LINEAR, warmup_ratio=0.0, warmup_steps=0, logging_dir=runs/Jun08_18-02-30_jp3-g-31374-37031-i-2p4p2, logging_strategy=IntervalStrategy.STEPS, logging_first_step=False, logging_steps=500, save_strategy=IntervalStrategy.STEPS, save_steps=100, save_total_limit=5, no_cuda=False, seed=42, fp16=False, fp16_opt_level=O1, fp16_backend=auto, fp16_full_eval=False, local_rank=0, tpu_num_cores=None, tpu_metrics_debug=False, debug=[], dataloader_drop_last=False, eval_steps=10, dataloader_num_workers=0, past_index=-1, run_name=/data/dps_finetune_16_wikitext, disable_tqdm=False, remove_unused_columns=True, label_names=None, load_best_model_at_end=False, metric_for_best_model=None, greater_is_better=None, ignore_data_skip=False, sharded_ddp=[], deepspeed=/data/config_fine_tune_bert.json, label_smoothing_factor=0.0, adafactor=False, group_by_length=False, length_column_name=length, report_to=['mlflow', 'tensorboard'], ddp_find_unused_parameters=None, dataloader_pin_memory=False, skip_memory_metrics=False, use_legacy_prediction_loop=False, push_to_hub=False, resume_from_checkpoint=None, _n_gpu=1, mp_parameters=)
fp16 is set to False. I have also tried with fp16=True, but no difference in behaviour was observed. I also tested by loading the saved fp16 state_dict separately using torch.load() and then used it to initialize BertForMaskedLM as follows:

import torch
from transformers import BertConfig, BertForMaskedLM

state_dict = torch.load(model_path + "pytorch_model.bin")
config = BertConfig.from_json_file(model_path + "config.json")
model = BertForMaskedLM.from_pretrained(None, config=config, state_dict=state_dict)
model.dtype

model.dtype still outputs torch.float32. The config.json file above (the saved model's config file) is as follows:
{
"_name_or_path": "/data/bert-base-cased/",
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.6.1",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 28996
}
The _name_or_path points to the location of the pre-finetuning fp32 model. However, changing its value to the post-finetuning fp16 model also does not lead to any change in the model.dtype output. Please let me know if there are any checks I could run or files I could provide.
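(One quick check along those lines - a sketch, not from the original thread, with the checkpoint directory assumed from the snippet above - is to look at the dtypes stored in the checkpoint itself, which takes from_pretrained() out of the equation entirely:)

import torch

model_path = "/data/dps_finetune_16_wikitext/"  # checkpoint directory (assumed)
state_dict = torch.load(model_path + "pytorch_model.bin", map_location="cpu")
# distinct dtypes of the saved tensors; a pure fp16 checkpoint should show only torch.float16
print({tensor.dtype for tensor in state_dict.values()})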
Thank you for sharing these details. So indeed this looks like a case I haven't run into, and this is not an integration issue. Under zero3 from_pretrained works differently, so most likely this is something Deepspeed core will have to solve. This use case is probably new to them too. So please kindly use https://github.com/microsoft/DeepSpeed/issues/new to post the same details (Edit -> Copy-n-Paste) there. Thank you, @asit2898
Hi @asit2898, thanks for reporting your issue. I can help look at things from DeepSpeed's side. Was the model fine-tuned with ZeRO enabled? From the DS config above it seems not, unless it is enabled somewhere on the HF side of things. @stas00, does the HF integration enable any ZeRO stage by default? To start, I did a quick experiment with DeepSpeed (without ZeRO) and examined the model parameter dtypes before and after engine initialization.
As I posted above, under zero3 from_pretrained works differently. The key is zero3.
@ShadenSmith @stas00 Thanks for the replies! I did not enable any stage of ZeRO and just ran DeepSpeed using pure data parallelism. It is only after I load the saved model using the .from_pretrained() method that the weights get auto-converted to 32 bits... I am not very familiar with the HF source code, but given that .from_pretrained() takes only the state_dict and the model configuration as arguments, especially in the following case that I mentioned:

import torch
from transformers import BertConfig, BertForMaskedLM

state_dict = torch.load(model_path + "pytorch_model.bin")
config = BertConfig.from_json_file(model_path + "config.json")
model = BertForMaskedLM.from_pretrained(None, config=config, state_dict=state_dict)
model.dtype

the HF object's behaviour should be independent of whether or not the model was trained with DeepSpeed, right 🤔
Thanks for the clarification @asit2898 / @stas00. @stas00, I don't yet understand the conclusion that the issue is in core DeepSpeed. Since ZeRO-3 is not enabled, is HF expecting DeepSpeed to touch the model dtype here at all? The only model dtype transformations that we should be making are converting to FP16 when that is enabled. This issue is going in the opposite direction and I am not sure where the FP32 conversion would happen.
OK, let me try to reproduce this first and then it'd be much easier to discuss this further. For some reason I was under the impression that zero3 was enabled! But reviewing the config posted by @asit2898, it's not. I will make an fp16 model, try to reproduce the problem and then follow up.
OK, this doesn't seem to have anything to do with Deepspeed. Observe: an fp16 model saved with save_pretrained() and then reloaded with from_pretrained() prints torch.float32, with no DeepSpeed in the picture. I will look next at why this bug is happening.
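(A minimal round-trip sketch of that kind of check - not the original snippet, which was lost in formatting; the model name and the local directory are illustrative:)

from transformers import BertForMaskedLM

# make an fp16 model and save it
model = BertForMaskedLM.from_pretrained("bert-base-cased").half()
model.save_pretrained("bert-fp16")

# reload it - the weights come back in fp32
reloaded = BertForMaskedLM.from_pretrained("bert-fp16")
print(reloaded.dtype)  # torch.float32, the behaviour reported in this issue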
OK, so it's not even from_pretrained - a standalone torch example shows the same thing: loading an fp16 state_dict into a freshly constructed (fp32) module with load_state_dict() still prints torch.float32, because the module keeps its own parameter dtype instead of adopting the checkpoint's.
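(A minimal sketch of such a standalone demonstration - again not the original snippet, just the same idea in pure torch:)

from torch import nn

# a freshly constructed module is fp32 by default
model = nn.Linear(4, 4)

# build an fp16 copy of its weights, standing in for an fp16 checkpoint
fp16_state = {name: tensor.half() for name, tensor in model.state_dict().items()}

# load_state_dict copies the fp16 tensors into the existing fp32 parameters,
# so the module keeps its own dtype instead of adopting the checkpoint's
model.load_state_dict(fp16_state)
print(model.weight.dtype)  # torch.float32, not torch.float16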
Thinking more about it, I think from_pretrained itself needs to handle this. Since the user can't access the model until after from_pretrained has already loaded the weights, they have no opportunity to change its dtype beforehand - the model would need to be put into the right dtype as soon as it's instantiated. Two possible ways to do that: 1. add an explicit dtype argument to from_pretrained; 2. auto-detect the dtype from the saved weights. Of course, the user could do model.half() on the returned model. @sgugger, @LysandreJik, @patrickvonplaten - what do you think?
I'm okay with having an explicit dtype argument for this.
I edited the proposal just now to also offer automatic detection - item 2.
@asit2898, until we sort it out please use model.half() on the model returned by from_pretrained().
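(A minimal sketch of that interim workaround, with the checkpoint path assumed from earlier in the thread:)

from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("/data/dps_finetune_16_wikitext")  # loads as fp32
model = model.half()  # cast the weights back down to fp16
print(model.dtype)  # torch.float16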
I'm fine with having a dtype argument. I would also be fine with a configuration attribute that would identify between fp32/fp16/bfloat16, as users have been surprised in the past that models weighing ~500MB on the Hub ended up taking much more RAM and much more disk space on their machines (automatic detection would be better than having another configuration attribute).
Ah yes, this is definitely something that could be stored in the configuration!
Which also connects to my proposal from 2 months ago: #11209, though it's slightly different since a model could be pre-trained in mixed precision and saved in fp32. The thing is - if you have the weights of the model, it doesn't take long to get the dtype of the tensors it contains in its saved state_dict.
Specific discussion on auto-detection: to do auto-detection we need to know the dtype of the saved weights before the model is instantiated. So the protocol would be roughly: peek at the dtype of the tensors in the saved state_dict, instantiate the model in that dtype, and then load the weights into it.
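(A minimal sketch of that kind of detection, assuming a standard pytorch_model.bin checkpoint - illustrative only, not the code that eventually landed:)

import torch

state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # path assumed
# use the dtype of the first floating-point tensor as the checkpoint's dtype
detected = next(t.dtype for t in state_dict.values() if t.is_floating_point())
print(detected)  # torch.float16 for the checkpoint discussed in this issue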
And if we choose to implement this for PyTorch, what do we do with TF and Flax?
@stas00 Thanks a lot for addressing the issue! I really did not expect the issue to lie in the way PyTorch loads the model. I'll continue using model.half() and would be happy to help in any way I can...
@Rocketknight1, do you have an idea of how that is/would be managed with TensorFlow? @patrickvonplaten @patil-suraj, do you have an idea of how that is/would be managed with JAX/Flax?
@LysandreJik Keras is quite opinionated about this - it has plenty of support for mixed-precision training (like PyTorch AMP) using a global dtype policy.
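(For context, a minimal sketch of the Keras mechanism being referred to - the standard tf.keras.mixed_precision API, nothing transformers-specific:)

import tensorflow as tf

# a global mixed-precision policy: layers compute in float16 but keep float32 variables
tf.keras.mixed_precision.set_global_policy("mixed_float16")

layer = tf.keras.layers.Dense(8)
layer.build(input_shape=(None, 4))
print(layer.compute_dtype)  # float16 - the dtype used for computation
print(layer.dtype)          # float32 - the dtype of the stored variables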
Looks like we lost momentum on this one. Please answer the following 2 questions with 1x and 2x (e.g. 1c, 2b - multiple versions are ok too if you're flexible).
Thank you! p.s. if we add
I'd vote for 1a, overridden by a configuration attribute (1d?) rather than the
Agreed with Lysandre: use a config attribute (which defaults to None or "auto") and switch back to auto-discovery if this attribute is not set to a specific value.
Update: added 1d and 1e options as proposed. So if we go with 1e - question: could it be possible that the model will have some weights that use a different dtype than the rest of the model?
Yes,
@asit2898, the PR is almost done, and once merged you will need to use one of the new torch_dtype options at from_pretrained time to meet your needs.
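(A sketch of the two usages as the torch_dtype argument exists in transformers today, with the checkpoint path assumed from earlier in the thread:)

import torch
from transformers import BertForMaskedLM

model_path = "/data/dps_finetune_16_wikitext"  # fp16 checkpoint directory (assumed)

# ask for a specific dtype explicitly...
model = BertForMaskedLM.from_pretrained(model_path, torch_dtype=torch.float16)

# ...or let the dtype be derived from the saved weights
model = BertForMaskedLM.from_pretrained(model_path, torch_dtype="auto")
print(model.dtype)  # torch.float16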
@stas00 Hi, this might be off-topic, but I want to ask: if I specify torch_dtype=torch.float16 but am actually loading a float32 model, will it result in a correct automatic conversion of the weights?
That's correct. It'd be the equivalent of calling model.half() after loading. If your model was saved in fp32, I think future versions of
@stas00 I found that if I specify torch_dtype=float16, the loaded model is indeed fp16 even though the saved weights are fp32. But this model is a LoRA model, and when I merge the LoRA weights into the base model, the prediction results are not correct. How can I verify whether this conversion is correct or not? (It might also be because of DeepSpeed, since I am using it.) I am confused now about why the model loaded with the LoRA is not right (but the LoRA saved in fp16 after training is done is correct).
I suggest you start a new issue, @lucasjinreal - and please tag me there. Also please make sure you have the latest deepspeed version - until recently it wasn't dealing correctly with frozen weights - it wasn't saving those in the checkpoint. I think around 0.9.2 is when it was fixed (or 0.9.3).
@stas00 I was using 0.8.3 and am trying 0.9.4. Should transformers be at 4.30 to completely deal with this problem?
I can't say until you try, but ds 0.8.3 is definitely not doing the right thing. You can easily test the outcome from each version by comparing what ends up in the saved checkpoint.
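(One generic way to do that kind of comparison - a sketch assuming two checkpoint files saved by the two DeepSpeed versions; the file names are placeholders:)

import torch

# checkpoints saved by the two DeepSpeed versions (file names are hypothetical)
old_sd = torch.load("pytorch_model_ds_0_8_3.bin", map_location="cpu")
new_sd = torch.load("pytorch_model_ds_0_9_4.bin", map_location="cpu")

# keys missing from the old checkpoint point at frozen weights that were never saved
print("missing in old:", sorted(set(new_sd) - set(old_sd)))

# for the shared keys, check whether the stored tensors actually match
for name in sorted(set(old_sd) & set(new_sd)):
    if not torch.allclose(old_sd[name].float(), new_sd[name].float(), atol=1e-3):
        print("mismatch:", name)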
stas00 edited: this Issue has nothing to do with Deepspeed, but pure transformers
Environment info
transformers version: 4.6.1
Who can help
@LysandreJik @sgugger
Information
Model I am using (Bert, XLNet ...): BertForMaskedLM
The problem arises when using:
The tasks I am working on is:
Masked LM
To reproduce
Steps to reproduce the behavior: load a checkpoint whose weights were saved in fp16 using .from_pretrained() and inspect model.dtype.
Expected behavior
Outputs torch.float32 instead of the expected torch.float16. I was able to recover the original weights using model.half()
I think it would be helpful to highlight this behaviour of forced auto-conversion, either as a warning or as part of the from_pretrained() method's documentation, or to provide an additional argument to help retain fp16 weights. I am willing to pick this issue up. Please let me know what would be the most appropriate fix.