
Making sure we can use safetensors to serialize all the time. #22437

Merged: 19 commits merged into huggingface:main from safe_serialization_always_valid on Mar 31, 2023

Conversation

@Narsil (Contributor) commented Mar 29, 2023

What does this PR do?

Making sure save_pretrained(..., safe_serialization=True) works in all
cases.

It seems _keys_to_ignore_on_load_missing was often the only attribute set (not _keys_to_ignore_on_save), so save_pretrained did not properly ignore those keys when saving.
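
For illustration, a minimal reproduction of the failure mode before this fix (checkpoint name and output path are just examples):

from transformers import AlbertForPreTraining

model = AlbertForPreTraining.from_pretrained("albert-base-v2")
# Prior to this PR, this call raised "Some tensors share memory..." because the tied
# decoder/embedding weights were still present in the state dict handed to safetensors.
model.save_pretrained("./albert-safetensors", safe_serialization=True)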

Status before the fix:

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED tests/models/albert/test_modeling_albert.py::AlbertModelTest::test_can_use_safetensors - Exception: Class AlbertForPreTraining cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'albert.embeddings.word_embeddings.weight', 'predictions.decoder.weight'}, {'predictions.decoder.bias', 'predictions.bias'}]
FAILED tests/models/bart/test_modeling_bart.py::BartModelTest::test_can_use_safetensors - Exception: Class BartModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'shared.weight', 'decoder.embed_tokens.weight', 'encoder.embed_tokens.weight'}]
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_can_use_safetensors - Exception: Class BertLMHeadModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.decoder.weight', 'bert.embeddings.word_embeddings.weight'}, {'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/bart/test_modeling_bart.py::BartStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class BartForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/bert_generation/test_modeling_bert_generation.py::BertGenerationEncoderTest::test_can_use_safetensors - Exception: Class BertGenerationDecoder cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'bert.embeddings.word_embeddings.weight', 'lm_head.decoder.weight'}, {'lm_head.decoder.bias', 'lm_head.bias'}]
FAILED tests/models/big_bird/test_modeling_big_bird.py::BigBirdModelTest::test_can_use_safetensors - Exception: Class BigBirdForPreTraining cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'bert.embeddings.word_embeddings.weight', 'cls.predictions.decoder.weight'}, {'cls.predictions.decoder.bias', 'cls.predictions.bias'}]
FAILED tests/models/biogpt/test_modeling_biogpt.py::BioGptModelTest::test_can_use_safetensors - Exception: Class BioGptForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'output_projection.weight', 'biogpt.embed_tokens.weight'}]
FAILED tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py::BigBirdPegasusModelTest::test_can_use_safetensors - Exception: Class BigBirdPegasusModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/blenderbot/test_modeling_blenderbot.py::BlenderbotModelTest::test_can_use_safetensors - Exception: Class BlenderbotModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/blenderbot_small/test_modeling_blenderbot_small.py::BlenderbotSmallModelTest::test_can_use_safetensors - Exception: Class BlenderbotSmallModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'shared.weight', 'decoder.embed_tokens.weight', 'encoder.embed_tokens.weight'}]
FAILED tests/models/blenderbot_small/test_modeling_blenderbot_small.py::BlenderbotSmallStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class BlenderbotSmallForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/blenderbot/test_modeling_blenderbot.py::BlenderbotStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class BlenderbotForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'model.decoder.embed_tokens.weight', 'lm_head.weight'}]
FAILED tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py::BigBirdPegasusStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class BigBirdPegasusForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/blip_2/test_modeling_blip_2.py::Blip2ForConditionalGenerationDecoderOnlyTest::test_can_use_safetensors - Exception: Class Blip2ForConditionalGeneration cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'language_model.lm_head.weight', 'language_model.model.decoder.embed_tokens.weight'}]
FAILED tests/models/bloom/test_modeling_bloom.py::BloomModelTest::test_can_use_safetensors - Exception: Class BloomForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'transformer.word_embeddings.weight'}]
FAILED tests/models/blip_2/test_modeling_blip_2.py::Blip2ModelTest::test_can_use_safetensors - Exception: Class Blip2ForConditionalGeneration cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'language_model.shared.weight', 'language_model.decoder.embed_tokens.weight', 'language_model.lm_head.weight', 'language_model.encoder.embed_tokens.weight'}]
FAILED tests/models/blip/test_modeling_blip.py::BlipTextImageModelTest::test_can_use_safetensors - Exception: Class BlipForConditionalGeneration cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'text_decoder.cls.predictions.bias', 'text_decoder.cls.predictions.decoder.bias'}]
FAILED tests/models/convbert/test_modeling_convbert.py::ConvBertModelTest::test_can_use_safetensors - Exception: Class ConvBertForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'generator_lm_head.weight', 'convbert.embeddings.word_embeddings.weight'}]
FAILED tests/models/cpm/test_tokenization_cpm.py::XLNetModelTest::test_can_use_safetensors - Exception: Class XLNetLMHeadModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_loss.weight', 'transformer.word_embedding.weight'}]
FAILED tests/models/ctrl/test_modeling_ctrl.py::CTRLModelTest::test_can_use_safetensors - Exception: Class CTRLLMHeadModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'transformer.w.weight', 'lm_head.weight'}]
FAILED tests/models/deberta/test_modeling_deberta.py::DebertaModelTest::test_can_use_safetensors - Exception: Class DebertaForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'deberta.embeddings.word_embeddings.weight', 'cls.predictions.decoder.weight'}, {'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/deberta_v2/test_modeling_deberta_v2.py::DebertaV2ModelTest::test_can_use_safetensors - Exception: Class DebertaV2ForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'deberta.embeddings.word_embeddings.weight', 'cls.predictions.decoder.weight'}, {'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/deformable_detr/test_modeling_deformable_detr.py::DeformableDetrModelTest::test_can_use_safetensors - Exception: Class DeformableDetrForObjectDetection cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'class_embed.0.weight', 'class_embed.1.weight'}, {'class_embed.1.bias', 'class_embed.0.bias'}, {'bbox_embed.0.layers.0.weight', 'bbox_embed.1.layers.0.weight'}, {'bbox_embed.1.layers.0.bias', 'bbox_embed.0.layers.0.bias'}, {'bbox_embed.0.layers.1.weight', 'bbox_embed.1.layers.1.weight'}, {'bbox_embed.1.layers.1.bias', 'bbox_embed.0.layers.1.bias'}, {'bbox_embed.1.layers.2.weight', 'bbox_embed.0.layers.2.weight'}, {'bbox_embed.0.layers.2.bias', 'bbox_embed.1.layers.2.bias'}]
FAILED tests/models/deta/test_modeling_deta.py::DetaModelTest::test_can_use_safetensors - Exception: Class DetaForObjectDetection cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'bbox_embed.0.layers.0.weight', 'model.decoder.bbox_embed.0.layers.0.weight'}, {'model.decoder.bbox_embed.0.layers.0.bias', 'bbox_embed.0.layers.0.bias'}, {'bbox_embed.0.layers.1.weight', 'model.decoder.bbox_embed.0.layers.1.weight'}, {'bbox_embed.0.layers.1.bias', 'model.decoder.bbox_embed.0.layers.1.bias'}, {'model.decoder.bbox_embed.0.layers.2.weight', 'bbox_embed.0.layers.2.weight'}, {'bbox_embed.0.layers.2.bias', 'model.decoder.bbox_embed.0.layers.2.bias'}, {'bbox_embed.1.layers.0.weight', 'model.decoder.bbox_embed.1.layers.0.weight'}, {'model.decoder.bbox_embed.1.layers.0.bias', 'bbox_embed.1.layers.0.bias'}, {'model.decoder.bbox_embed.1.layers.1.weight', 'bbox_embed.1.layers.1.weight'}, {'bbox_embed.1.layers.1.bias', 'model.decoder.bbox_embed.1.layers.1.bias'}, {'bbox_embed.1.layers.2.weight', 'model.decoder.bbox_embed.1.layers.2.weight'}, {'model.decoder.bbox_embed.1.layers.2.bias', 'bbox_embed.1.layers.2.bias'}]
FAILED tests/models/distilbert/test_modeling_distilbert.py::DistilBertModelTest::test_can_use_safetensors - Exception: Class DistilBertForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'distilbert.embeddings.word_embeddings.weight', 'vocab_projector.weight'}]
FAILED tests/models/electra/test_modeling_electra.py::ElectraModelTest::test_can_use_safetensors - Exception: Class ElectraForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'generator_lm_head.weight', 'electra.embeddings.word_embeddings.weight'}]
FAILED tests/models/ernie/test_modeling_ernie.py::ErnieModelTest::test_can_use_safetensors - Exception: Class ErnieForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'ernie.embeddings.word_embeddings.weight', 'cls.predictions.decoder.weight'}, {'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/esm/test_modeling_esm.py::EsmModelTest::test_can_use_safetensors - Exception: Class EsmForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'esm.embeddings.word_embeddings.weight', 'lm_head.decoder.weight'}]
FAILED tests/models/flaubert/test_modeling_flaubert.py::FlaubertModelTest::test_can_use_safetensors - Exception: Class FlaubertWithLMHeadModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'pred_layer.proj.weight', 'transformer.embeddings.weight'}]
FAILED tests/models/fnet/test_modeling_fnet.py::FNetModelTest::test_can_use_safetensors - Exception: Class FNetForPreTraining cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'fnet.embeddings.word_embeddings.weight', 'cls.predictions.decoder.weight'}, {'cls.predictions.decoder.bias', 'cls.predictions.bias'}]
FAILED tests/models/fsmt/test_modeling_fsmt.py::FSMTModelTest::test_can_use_safetensors - Exception: Class FSMTModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'decoder.embed_tokens.weight', 'decoder.output_projection.weight'}]
FAILED tests/models/funnel/test_modeling_funnel.py::FunnelModelTest::test_can_use_safetensors - Exception: Class FunnelForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'funnel.embeddings.word_embeddings.weight', 'lm_head.weight'}]
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_can_use_safetensors - Exception: Class GPT2LMHeadModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'transformer.wte.weight', 'lm_head.weight'}]
FAILED tests/models/flava/test_modeling_flava.py::FlavaForPreTrainingTest::test_can_use_safetensors - Exception: Class FlavaForPreTraining cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'mim_head.bias', 'mim_head.decoder.bias'}, {'mlm_head.decoder.bias', 'mlm_head.bias'}, {'mmm_image_head.decoder.bias', 'mmm_image_head.bias'}, {'mmm_text_head.decoder.bias', 'mmm_text_head.bias'}]
FAILED tests/models/gpt_neox_japanese/test_modeling_gpt_neox_japanese.py::GPTNeoXModelJapaneseTest::test_can_use_safetensors - Exception: Class GPTNeoXJapaneseForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'embed_out.weight', 'gpt_neox_japanese.embed_in.weight'}]
FAILED tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py::GPTSanJapaneseForConditionalGenerationTest::test_can_use_safetensors - Exception: Class GPTSanJapaneseForConditionalGeneration cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.embed_tokens.weight'}]
FAILED tests/models/ibert/test_modeling_ibert.py::IBertModelTest::test_can_use_safetensors - Exception: Class IBertForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.decoder.weight', 'ibert.embeddings.word_embeddings.weight'}, {'lm_head.bias', 'lm_head.decoder.bias'}]
FAILED tests/models/layoutlm/test_modeling_layoutlm.py::LayoutLMModelTest::test_can_use_safetensors - Exception: Class LayoutLMForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.decoder.weight', 'layoutlm.embeddings.word_embeddings.weight'}, {'cls.predictions.decoder.bias', 'cls.predictions.bias'}]
FAILED tests/models/led/test_modeling_led.py::LEDModelTest::test_can_use_safetensors - Exception: Class LEDModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/longformer/test_modeling_longformer.py::LongformerModelTest::test_can_use_safetensors - Exception: Class LongformerForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'longformer.embeddings.word_embeddings.weight', 'lm_head.decoder.weight'}, {'lm_head.decoder.bias', 'lm_head.bias'}]
FAILED tests/models/longt5/test_modeling_longt5.py::LongT5ModelTest::test_can_use_safetensors - Exception: Class LongT5Model cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/lxmert/test_modeling_lxmert.py::LxmertModelTest::test_can_use_safetensors - Exception: Class LxmertForPreTraining cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lxmert.embeddings.word_embeddings.weight', 'cls.predictions.decoder.weight'}]
FAILED tests/models/longt5/test_modeling_longt5.py::LongT5TGlobalModelTest::test_can_use_safetensors - Exception: Class LongT5Model cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/m2m_100/test_modeling_m2m_100.py::M2M100ModelTest::test_can_use_safetensors - Exception: Class M2M100Model cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'shared.weight', 'decoder.embed_tokens.weight', 'encoder.embed_tokens.weight'}]
FAILED tests/models/marian/test_modeling_marian.py::MarianModelTest::test_can_use_safetensors - Exception: Class MarianModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'shared.weight', 'decoder.embed_tokens.weight', 'encoder.embed_tokens.weight'}]
FAILED tests/models/longt5/test_modeling_longt5.py::LongT5EncoderOnlyModelTest::test_can_use_safetensors - Exception: Class LongT5EncoderModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight'}]
FAILED tests/models/longt5/test_modeling_longt5.py::LongT5EncoderOnlyTGlobalModelTest::test_can_use_safetensors - Exception: Class LongT5EncoderModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight'}]
FAILED tests/models/marian/test_modeling_marian.py::MarianStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class MarianForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/mbart/test_modeling_mbart.py::MBartModelTest::test_can_use_safetensors - Exception: Class MBartModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'shared.weight', 'decoder.embed_tokens.weight', 'encoder.embed_tokens.weight'}]
FAILED tests/models/mbart/test_modeling_mbart.py::MBartStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class MBartForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/megatron_bert/test_modeling_megatron_bert.py::MegatronBertModelTest::test_can_use_safetensors - Exception: Class MegatronBertForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'bert.embeddings.word_embeddings.weight', 'cls.predictions.decoder.weight'}, {'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/mobilebert/test_modeling_mobilebert.py::MobileBertModelTest::test_can_use_safetensors - Exception: Class MobileBertForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.decoder.weight', 'mobilebert.embeddings.word_embeddings.weight'}, {'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/mpnet/test_modeling_mpnet.py::MPNetModelTest::test_can_use_safetensors - Exception: Class MPNetForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'mpnet.embeddings.word_embeddings.weight', 'lm_head.decoder.weight'}, {'lm_head.decoder.bias', 'lm_head.bias'}]
FAILED tests/models/mvp/test_modeling_mvp.py::MvpModelTest::test_can_use_safetensors - Exception: Class MvpModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'shared.weight', 'decoder.embed_tokens.weight', 'encoder.embed_tokens.weight'}]
FAILED tests/models/nezha/test_modeling_nezha.py::NezhaModelTest::test_can_use_safetensors - Exception: Class NezhaForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'nezha.embeddings.word_embeddings.weight', 'cls.predictions.decoder.weight'}, {'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/mvp/test_modeling_mvp.py::MvpStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class MvpForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/nllb_moe/test_modeling_nllb_moe.py::NllbMoeModelTest::test_can_use_safetensors - Exception: Class NllbMoeModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'shared.weight', 'decoder.embed_tokens.weight', 'encoder.embed_tokens.weight'}]
FAILED tests/models/nystromformer/test_modeling_nystromformer.py::NystromformerModelTest::test_can_use_safetensors - Exception: Class NystromformerForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.decoder.weight', 'nystromformer.embeddings.word_embeddings.weight'}, {'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/openai/test_modeling_openai.py::OpenAIGPTModelTest::test_can_use_safetensors - Exception: Class OpenAIGPTLMHeadModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'transformer.tokens_embed.weight'}]
FAILED tests/models/opt/test_modeling_opt.py::OPTModelTest::test_can_use_safetensors - Exception: Class OPTForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/pegasus/test_modeling_pegasus.py::PegasusModelTest::test_can_use_safetensors - Exception: Class PegasusModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/pegasus/test_modeling_pegasus.py::PegasusStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class PegasusForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/pegasus_x/test_modeling_pegasus_x.py::PegasusXModelTest::test_can_use_safetensors - Exception: Class PegasusXModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/pix2struct/test_modeling_pix2struct.py::Pix2StructTextImageModelTest::test_can_use_safetensors - Exception: Class Pix2StructForConditionalGeneration cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'decoder.lm_head.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/plbart/test_modeling_plbart.py::PLBartModelTest::test_can_use_safetensors - Exception: Class PLBartModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/plbart/test_modeling_plbart.py::PLBartStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class PLBartForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/prophetnet/test_modeling_prophetnet.py::ProphetNetModelTest::test_can_use_safetensors - Exception: Class ProphetNetModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'word_embeddings.weight', 'encoder.word_embeddings.weight', 'decoder.word_embeddings.weight'}]
FAILED tests/models/realm/test_modeling_realm.py::RealmModelTest::test_can_use_safetensors - Exception: Class RealmEmbedder cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.decoder.bias', 'cls.predictions.bias'}]
FAILED tests/models/reformer/test_modeling_reformer.py::ReformerLocalAttnModelTest::test_can_use_safetensors - Exception: Class ReformerModelWithLMHead cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.decoder.bias', 'lm_head.bias'}]
FAILED tests/models/prophetnet/test_modeling_prophetnet.py::ProphetNetStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class ProphetNetForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'prophetnet.decoder.word_embeddings.weight', 'lm_head.weight'}]
FAILED tests/models/reformer/test_modeling_reformer.py::ReformerLSHAttnModelTest::test_can_use_safetensors - Exception: Class ReformerModelWithLMHead cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.decoder.bias', 'lm_head.bias'}]
FAILED tests/models/roc_bert/test_modeling_roc_bert.py::RoCBertModelTest::test_can_use_safetensors - Exception: Class RoCBertForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'roc_bert.embeddings.word_embeddings.weight', 'cls.predictions.decoder.weight'}, {'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/roformer/test_modeling_roformer.py::RoFormerModelTest::test_can_use_safetensors - Exception: Class RoFormerForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.decoder.weight', 'roformer.embeddings.word_embeddings.weight'}, {'cls.predictions.decoder.bias', 'cls.predictions.bias'}]
FAILED tests/models/speech_to_text/test_modeling_speech_to_text.py::Speech2TextModelTest::test_can_use_safetensors - Exception: Class Speech2TextForConditionalGeneration cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py::Speech2Text2StandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class Speech2Text2ForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/squeezebert/test_modeling_squeezebert.py::SqueezeBertModelTest::test_can_use_safetensors - Exception: Class SqueezeBertForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.decoder.weight', 'transformer.embeddings.word_embeddings.weight'}, {'cls.predictions.decoder.bias', 'cls.predictions.bias'}]
FAILED tests/models/speecht5/test_modeling_speecht5.py::SpeechT5ForSpeechToTextTest::test_can_use_safetensors - Exception: Class SpeechT5ForSpeechToText cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'speecht5.decoder.prenet.embed_tokens.weight', 'text_decoder_postnet.lm_head.weight'}]
FAILED tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_can_use_safetensors - Exception: Class SwitchTransformersModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'shared.weight', 'decoder.embed_tokens.weight', 'encoder.embed_tokens.weight'}]
FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_can_use_safetensors - Exception: Class T5Model cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight', 'decoder.embed_tokens.weight'}]
FAILED tests/models/t5/test_modeling_t5.py::T5EncoderOnlyModelTest::test_can_use_safetensors - Exception: Class T5EncoderModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'encoder.embed_tokens.weight', 'shared.weight'}]
FAILED tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersEncoderOnlyModelTest::test_can_use_safetensors - Exception: Class SwitchTransformersEncoderModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'shared.weight', 'encoder.embed_tokens.weight'}]
FAILED tests/models/tapas/test_modeling_tapas.py::TapasModelTest::test_can_use_safetensors - Exception: Class TapasForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.decoder.weight', 'tapas.embeddings.word_embeddings.weight'}, {'cls.predictions.decoder.bias', 'cls.predictions.bias'}]
FAILED tests/models/transfo_xl/test_modeling_transfo_xl.py::TransfoXLModelTest::test_can_use_safetensors - Exception: Class TransfoXLLMHeadModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'transformer.word_emb.emb_layers.0.weight', 'crit.out_layers.0.weight'}, {'crit.out_layers.1.weight', 'transformer.word_emb.emb_layers.1.weight'}, {'crit.out_layers.2.weight', 'transformer.word_emb.emb_layers.2.weight'}, {'crit.out_layers.3.weight', 'transformer.word_emb.emb_layers.3.weight'}, {'crit.out_projs.1', 'transformer.word_emb.emb_projs.1'}, {'crit.out_projs.2', 'transformer.word_emb.emb_projs.2'}, {'transformer.word_emb.emb_projs.3', 'crit.out_projs.3'}]
FAILED tests/models/trocr/test_modeling_trocr.py::TrOCRStandaloneDecoderModelTest::test_can_use_safetensors - Exception: Class TrOCRForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'output_projection.weight', 'model.decoder.embed_tokens.weight'}]
FAILED tests/models/vilt/test_modeling_vilt.py::ViltModelTest::test_can_use_safetensors - Exception: Class ViltForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'mlm_score.bias', 'mlm_score.decoder.bias'}]
FAILED tests/models/visual_bert/test_modeling_visual_bert.py::VisualBertModelTest::test_can_use_safetensors - Exception: Class VisualBertForRegionToPhraseAlignment cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.bias', 'cls.predictions.decoder.bias'}]
FAILED tests/models/xlm/test_modeling_xlm.py::XLMModelTest::test_can_use_safetensors - Exception: Class XLMWithLMHeadModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'pred_layer.proj.weight', 'transformer.embeddings.weight'}]
FAILED tests/models/xglm/test_modeling_xglm.py::XGLMModelTest::test_can_use_safetensors - Exception: Class XGLMForCausalLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_head.weight', 'model.embed_tokens.weight'}]
FAILED tests/models/xlnet/test_modeling_xlnet.py::XLNetModelTest::test_can_use_safetensors - Exception: Class XLNetLMHeadModel cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'lm_loss.weight', 'transformer.word_embedding.weight'}]
FAILED tests/models/yoso/test_modeling_yoso.py::YosoModelTest::test_can_use_safetensors - Exception: Class YosoForMaskedLM cannot be saved using safetensors: Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'cls.predictions.decoder.weight', 'yoso.embeddings.word_embeddings.weight'}, {'cls.predictions.decoder.bias', 'cls.predictions.bias'}]
== 90 failed, 10714 passed, 8750 skipped, 2172 warnings in 1202.20s (0:20:02) ==

Exited with code exit status 1

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev commented Mar 29, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 1721 to 1726
# Disable to see the damage.
if safe_serialization:
    if self._keys_to_ignore_on_load_missing is not None:
        for ignore_key in self._keys_to_ignore_on_load_missing:
            if ignore_key in state_dict.keys():
                del state_dict[ignore_key]
@Narsil (Contributor Author):

This is the core of the fix:

I can think of 3 ways to get this done:

  • Just add _keys_to_ignore_on_save for all affected models. No change in core modeling code. Will impact both the torch and safetensors saving paths (so technically a breaking change for the on-disk representation, should users reuse the dictionaries in non-transformers modeling code). Essentially we would have to maintain both _keys_to_ignore_on_save AND _keys_to_ignore_on_load_missing at the same time for all models.

  • [Proposed fix] Apply the ignore_on_load keys only for safetensors. The benefit is that it's not breaking for torch and still allows saving with safetensors for all models. However, it does introduce a difference in mechanism between the two on-disk representations.

  • Apply the ignore_on_load keys for both torch and safetensors. This has the same effect as the first fix, with the benefit of not having to maintain both key lists. Technically the two are not exactly the same, because models have sometimes been renamed, leaving old names in those keys that are no longer used. This is more true for the third key, _keys_to_ignore_extra, but could potentially exist for ignore_on_save and ignore_on_load too. I think this is ok, because the current code will just ignore keys not contained in the state dict.

I propose the second change since it's the least breaking now, even though I think in the long run it could cause more issues/confusion because of the dichotomy in treatment.

In any case, all proposed changes would only affect future models being created, and have no effect on currently existing models in the wild.

@sgugger (Collaborator) commented Mar 29, 2023

This all happens because of the design decision in safetensors to error out in case of tied weights. Nothing wrong will happen if users actually save those models and reload them, as the weights are re-tied by Transformers. I think that instead of forcing Transformers to change, safetensors should just adapt to its users and only issue a warning when asked to save a state dict with tied weights, or at the very least have an option to ignore tied weights.

@Narsil (Contributor Author) commented Mar 29, 2023

I'm not so sure. The tests are currently failing hard (incorrect reloaded tensors) because of incorrect configuration within some models:

  • Llama
  • ImageGPT
  • Blip2
  • Pix2struct

https://app.circleci.com/pipelines/github/huggingface/transformers/60854/workflows/4049cf2e-afe5-40cb-a218-889434cb0b80/jobs/746519

While it is not currently an issue, because the saved torch files keep the aliasing and it is unpacked during loading, I think all 4 (I only checked Llama so far) have an incorrect _keys_to_ignore_on_load or an incorrect tie_word_embeddings.

If we make the hard error a simple warning, that would just lead to wrong models being reloaded from safetensors. (The weight would get ignored with no warning to the user, and since the weights wouldn't be re-tied, the output head would be random.)

For Llama:

This is true in the pure torch world and has nothing to do with safetensors. It happens to be a minor issue because we currently save the alias.

Normally, we're saved because the conversion script will disallow this conversion (since the reloaded model is incorrect).
I am checking this right now.
We're also saved because save_pretrained(.., safe_serialization=True) will simply fail right now.
So afaik, we're not creating bogus files at the moment, and only users manually discarding the alias will see the issue, which seems highly unlikely.

For these 4 models, provided they share the same issue, we either need to fix the configuration and re-tie the weights (which would make the currently proposed fix just work), or actually remove the ignore_on_load entries and make sure the tensors are actually disjoint.

@sgugger (Collaborator) commented Mar 29, 2023

Yet we went from 90 failures to just 4 models. I'm not saying Transformers is perfect and does not need any fix at all. Even if safetensors enabled saving state dictionaries with tied weights, we should make sure we only save one of those weights to get the most efficient safetensors checkpoints.

I'm just highlighting that an API that is too rigid will never be broadly used, so I really think safetensors should add support for bypassing this hard error.

@sgugger (Collaborator) commented Mar 29, 2023

(Also, I can confirm that the embeddings and LM head are different tensors for Llama-7b at least, so the _keys_to_ignore_on_load_missing is just wrongly set.)
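
A quick way to check that claim (sketch; the checkpoint path is hypothetical and the attribute paths follow the Llama modeling code):

from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("path/to/llama-7b")  # hypothetical local checkpoint
embeddings = model.model.embed_tokens.weight
lm_head = model.lm_head.weight
print(embeddings.data_ptr() == lm_head.data_ptr())  # False: the two weights are not tied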

@Narsil (Contributor Author) commented Mar 29, 2023

Confirmed on ImageGPT it's the same.

It's funny for Llama though, because the model tester does share the weights...
No it doesn't, my bad.

I'm just highlighting that an API that is too rigid will never be broadly used, so I really think safetensors should add support for bypassing this hard error.

I respectfully disagree. You're not wrong, but I really think it's not the case here (simply allowing it just lets us shoot ourselves in the foot).
While it's definitely inconvenient, saving a file that will not get reloaded properly is far worse than preventing the save in the first place. And the biggest problem is that it might take a while before the issue in the file is found.

@sgugger (Collaborator) commented Mar 29, 2023

In any case, a fix will need to be different from what is suggested in the PR: _keys_to_ignore_on_save cannot always be set to ignore the decoder at the class level, because the option to break that connection exists in the config. So we may have GPT2 models, for instance, with an lm_head weight distinct from the embeddings. The example actually exists with the T5 architecture: the canonical T5 checkpoints have the decoder tied to the embeddings, but T0pp does not (see here).

So _keys_to_ignore_on_save can only contain the names of layers that are always tied (in the case of T5 it could contain "encoder.embed_tokens", which is always tied to the shared layer, for instance) or always generated at init.

Likewise you can't write code like in this PR that always deletes keys based on _keys_to_ignore_on_load_xxx for the same reasons.

What could be done instead is during saving with safetensors:

  • use Accelerate find_tied_parameters (Accelerate will soon be a dep on the torch side anyway, this might be the turning point to actually do it) to identify the tied parameters
  • delete all tied parameters but the first in each group found
  • then save the rest

We might need a new class attribute in XxxPreTrainedModel for the edge case where the main tied parameter is not the first one returned by find_tied_parameters, but I'm not even sure it's needed: tie_weights is done before the actual load_state_dict, so loading tensor data into any of the tied weights should automatically populate the data in all of them.

But the code needs to be dynamic (depending on the actual model seen), not static (relying on the class variables).
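
A rough sketch of that dynamic approach (assuming Accelerate's find_tied_parameters returns groups of tied parameter names; the helper name and integration point are illustrative, not the final implementation):

from accelerate.utils import find_tied_parameters
from safetensors.torch import save_file

def save_model_safetensors(model, filename):
    state_dict = model.state_dict()
    # Each group lists the names of parameters sharing the same storage.
    for group in find_tied_parameters(model):
        # Keep the first name of each tied group and drop the aliases, so that
        # safetensors never sees two entries pointing at the same memory.
        for name in group[1:]:
            state_dict.pop(name, None)
    save_file(state_dict, filename)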

@Narsil (Contributor Author) commented Mar 29, 2023

delete all tied parameters but the first in each group found

But then we need to know which ones are actually used to recreate the others. There's a main weight, and the others are deduced from it. At least to properly avoid a warning.

Likewise you can't write code like in this PR that always deletes keys based on _keys_to_ignore_on_load_xxx for the same reasons.

Couldn't it remove keys that are both in the _ignore_on_load AND shared pointers then? That allows untied weights to go through AND lets us know the name of the main weight (the only one which is not in those keys).

In general I'm confused about having weights untied at runtime, since if you untie them, save your model, and the tied weights get erased, then you would reload without a warning and get an erroneous model.
A very long stretch for sure, but it's the reason why I think ignore_on_load and ignore_on_save play relatively the same role.

find_tied_parameters

Nice, but I don't think a full function from a dependency is necessary for that:

# Checking the tensor sharing is correct
from collections import defaultdict

ptrs = defaultdict(list)
for k, v in model_tied.state_dict().items():
    ptrs[v.data_ptr()].append(k)

shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}

Is enough.

Not even sure it's needed as the tie_weights is done before the actual load_state_dict so loading tensor data in any of the tied weights should automatically populate the data in all of them.

I confirmed. It just erroneously raises a warning, but the underlying model is fine.
However, I still think having a consistent name for the main weight would be better in general.

@sgugger (Collaborator) commented Mar 29, 2023

Couldn't it remove keys that are both in the _ignore_on_load AND shared pointers then? That allows untied weights to go through AND lets us know the name of the main weight (the only one which is not in those keys).

You can take those names as suggestions, but you will still need to leave only one weight per group of tied parameters or risk getting an error from safetensors. While you are fine with the safetensors save function not working for some models, I am not fine with the same behavior in Transformers.

Nice, but I don't think a full function from a dependency is necessary for that

Like I said, Accelerate is becoming a torch dependency anyway (since the Trainer will be rewritten to use it), so I don't see how it's wrong to use it. Your snippet of code will not present the groups of shared parameters (T5 has 4 of them tied together) as nicely, and you'd need to add tests for it (whereas Accelerate already heavily tests its utils).

In general I'm confused about having weights untied at runtime, since if you untie them, save your model, and the tied weights get erased, then you would reload without a warning and get an erroneous model.

I have no idea what this means. Are you referring to the situation where a user breaks the tie-weights connection somehow without changing the model config, then saves the weights and reloads the model with from_pretrained? That would also fail in torch, and I don't think I have ever seen a user complain about it.

Comment on lines 1732 to 1737
for _, names in shared_ptrs.items():
    for name in names:
        for pat in self._keys_to_ignore_on_load_missing:
            if re.search(pat, name):
                if name in state_dict:
                    del state_dict[name]
@Narsil (Contributor Author):

This is roughly copy-pasted from from_pretrained, except we drop only the shared keys (which allows dynamically unlinked tensors to go through as-is).

@@ -1238,8 +1238,28 @@ def __init__(self, config: Blip2Config):
         # Initialize weights and apply final processing
         self.post_init()

-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def get_input_embeddings(self):
@Narsil (Contributor Author):

@younesbelkada As seen offline, but could you confirm it's OK?

It's copy-pasted from Blip2ForConditionalGeneration. Without those, the tie_weights seems broken after load (this isn't safetensors-specific and could be a separate PR).

@@ -244,7 +244,7 @@ class DetaObjectDetectionOutput(ModelOutput):


 def _get_clones(module, N):
-    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+    return nn.ModuleList([module for i in range(N)])
@Narsil (Contributor Author):

This enables the dynamic loading to work properly.

If this is left as-is, what happens is that during model init all n modules are created pointing to the same parameter, but after load_state_dict only the first layer gets updated.
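
A small illustration of the difference (toy module, purely illustrative):

import copy
import torch.nn as nn

layer = nn.Linear(4, 4)
cloned = nn.ModuleList([copy.deepcopy(layer) for _ in range(3)])  # three independent parameter sets
shared = nn.ModuleList([layer for _ in range(3)])                 # three references to the same parameters

print(cloned[0].weight is cloned[1].weight)  # False: loading only index 0 leaves the other clones random
print(shared[0].weight is shared[1].weight)  # True: loading into any index populates all of them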

@@ -1778,7 +1778,7 @@ def forward(
 )
 class DetaForObjectDetection(DetaPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    _keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*"]
+    _keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*", "model.decoder"]
@Narsil (Contributor Author):

model.decoder can share bbox_embed and class_embed.

It's not necessarily the case and depends on the config.
This is the sort of model which makes the dynamic nature of the PR crucial (drop only shared tensors).

The debt we're creating here is that if a file is missing some weights (like model.decoder.bbox_embed) but the config is set to not share (or was modified on disk), then the warning will not be shown even though the model keeps random weights.

I don't see any good way to solve this since _keys_to_ignore is not config-dependent.
However, since the links are properly made at init time (when the config allows it), the issue will only arise when users mix mismatched configs and weights.

Probably an acceptable choice.

@@ -630,8 +630,6 @@ def custom_forward(*inputs):


 class LlamaForCausalLM(LlamaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
@Narsil (Contributor Author):

This has to be removed since Llama doesn't share the embedding and the output.

@ArthurZucker for final confirmation?

@@ -357,9 +357,10 @@ def __init__(
         initializer_factor=1.0,
         initializer_range=0.02,
         is_vqa=False,
+        tie_word_embeddings=False,
@Narsil (Contributor Author):

@younesbelkada for confirmation.

Pix2Struct has a Vision, a Text, and a Global model (and therefore configs).

The text config properly sets tie_word_embeddings to False, but the global one didn't, and therefore the global model would forcefully tie the word embeddings even when it shouldn't.

I think this is also independent of this safetensors PR.
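
Roughly, the effect of the change can be seen on the config defaults (sketch; attribute names follow the Pix2Struct configs):

from transformers import Pix2StructConfig

config = Pix2StructConfig()
# Before the fix, the composite config defaulted to tie_word_embeddings=True (the
# PretrainedConfig default), overriding the text config and force-tying the decoder
# embeddings to the LM head; with the fix both now default to False.
print(config.tie_word_embeddings, config.text_config.tie_word_embeddings)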

@younesbelkada (Contributor) left a comment:

All good for the modifications about Blip2 & Pix2struct! Thanks a mile for double checking

@Narsil (Contributor Author) commented Mar 30, 2023

@sgugger All tests are now passing with relative minor code changes.

I think we could push out some fixes to their respective PRs (blip2, pix2struct, llama) since what this uncovered seems to really be affecting current models.

For DETA, since it's marked exotic, I think the proposed fix could work.

And just for the record, my insistence on disallowing aliasing doesn't come from nowhere.
Supporting aliasing like torch does instantly renders lazy loading buggy.

If you do:

from safetensors import safe_open

tensor1 = safe_open(filename, framework="pt").get_tensor("lm_head.weight")
tensor2 = safe_open(filename, framework="pt").get_tensor("wte.weight")

Then the tensors are necessarily not shared, while they would be if you did weights = load_file(filename) (if we respected the aliasing, which we probably should, since otherwise fine-tuning is broken).

So enabling aliasing forces safetensors to give up lazy loading. The bar to do that is pretty high in my mind since lazy loading is a very nice feature we get out of it.

Note: silently dropping tensors on save in safetensors would necessarily lead to bugs in transformers too, which is why I'm not considering it as an option (since the reloaded file would be wrong).

@sgugger (Collaborator) commented Mar 30, 2023

I'm not sure why you are ignoring the comments I made with respect to this PR and safetensors as it is now, and going back to defending your choice of API for safetensors (which I still think is wrong, but I'm done debating this).

So once again:

  • the changes in modeling_utils should only leave one of every group of tied weights, so that the save with safetensors does not fail. _keys_to_ignore_on_load_missing can inform which weight to drop, but if that variable is incomplete (like in DETA, or any other model that does not normally have tied weights but where a user chose to apply tie_weights for their purposes), we should still drop something.
  • the proposed DETA change cannot be accepted, as it will lead to silent bugs for users who do not have tied weights and end up with an incomplete state_dict.

@Narsil (Contributor Author) commented Mar 30, 2023

but if that variable is incomplete (like in DETA, or any other model that does not normally have tied weights but where a user chose to apply tie_weights for their purposes), we should still drop something.

I've done that, adding the other necessary piece, which is dropping missing keys on shared tensors regardless of _keys_to_ignore. That way we don't trigger the warning when loading from safetensors even when the key is not present (which we would otherwise).

Doing both removes the need for the DETA key modification. (We still need the deepcopy fix; again, nothing to do with safetensors, but the parameters are cloned rather than shared, so the tensors are not properly filled for layers > 1 without it.)
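
Putting the two pieces together, the save-time logic ends up roughly like the following (condensed sketch of the change inside save_pretrained, not the verbatim diff):

import re
from collections import defaultdict

if safe_serialization:
    # Group parameter names by underlying storage to find shared (tied) tensors.
    ptrs = defaultdict(list)
    for name, tensor in state_dict.items():
        ptrs[tensor.data_ptr()].append(name)
    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

    warn_names = set()
    for names in shared_ptrs.values():
        # First drop the aliases listed in _keys_to_ignore_on_load_missing, so the
        # name that is kept is the one `from_pretrained` expects.
        if self._keys_to_ignore_on_load_missing is not None:
            for name in names:
                matches = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
                if matches and name in state_dict:
                    del state_dict[name]
        # If duplicates remain (weights tied at runtime), keep one and drop the rest with a warning.
        found = 0
        for name in names:
            if name in state_dict:
                found += 1
                if found > 1:
                    del state_dict[name]
                    warn_names.add(name)
    if len(warn_names) > 0:
        logger.warning_once(f"Removed shared tensors {warn_names} while saving; they will be re-tied on load.")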

@sgugger (Collaborator) left a comment:

Thanks for iterating, I have a few more comments.

# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
warn_names = set()
for _, names in shared_ptrs.items():
@sgugger (Collaborator):

You should iterate on .values() if you don't want the keys.

# load. This allows to make sure the name which is kept is consistent.
if self._keys_to_ignore_on_load_missing is not None:
    for name in names:
        for pat in self._keys_to_ignore_on_load_missing:
@sgugger (Collaborator):

No need for the for loop and two if statements, just do:

if name in state_dict and any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing):
    del state_dict[name]

Contributor Author

I did, but split it into two lines for readability.
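Roughly, the two-line form reads as follows (illustrative paraphrase, not a verbatim quote of the diff):

found = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
if found and name in state_dict:
    del state_dict[name]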

Comment on lines 1741 to 1745
# When not all duplicates have been cleaned
# Still remove those keys, but put a clear warning
# Since if the link between tensors was done at runtime
# then `from_pretrained` will still not get the key back
# Leading to random tensor. With a proper warning.
Collaborator

You have a line width of 119 chars in Transformers, no need to take 5 lines for this comment.

src/transformers/modeling_utils.py: three more outdated review comments (resolved)
Narsil and others added 4 commits March 31, 2023 10:25
@Narsil
Contributor Author

Narsil commented Mar 31, 2023

@sgugger If you want to do a final check (maybe we want a global warn_once too).

Collaborator

@sgugger sgugger left a comment

Yes, the warn_once is already implemented in the Transformers logger; that is what I was suggesting you use.

del state_dict[name]
warn_names.add(name)
if len(warn_names) > 0:
    warn_once(
Collaborator

As suggested before, please use logger.warn_once ;-)

Contributor Author

It doesn't exist:

FAILED tests/models/deta/test_modeling_deta.py::DetaModelTest::test_can_use_safetensors - Exception: Class DetaForObjectDetection cannot be saved using safetensors: 'Logger' object has no attribute 'warn_once'

https://app.circleci.com/pipelines/github/huggingface/transformers/61009/workflows/8e414a0b-9677-4cfa-82b8-e8860df4835c/jobs/748839

Is the logger improperly set up here? Also, I couldn't find the symbol anywhere; is it recent?

Contributor Author

Ahh... warning_once. Thanks @younesbelkada :)

Collaborator

Oh sorry, mixed up the name.
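
For reference, a minimal usage sketch of the helper that was eventually used (the message text here is just illustrative):

from transformers.utils import logging

logger = logging.get_logger(__name__)

# warning_once caches the message, so it is emitted at most once per process.
logger.warning_once("Removed shared tensor aliases before saving with safe_serialization=True.")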

@Narsil Narsil merged commit d143087 into huggingface:main Mar 31, 2023
@Narsil Narsil deleted the safe_serialization_always_valid branch March 31, 2023 14:08
raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
…gface#22437)

* Making sure we can use safetensors to serialize all the time.

* Expanding the tests for increased coverage.

* Update the test.

* Getting current state of affairs.

* Tentative fix.

* Fixing black version.

* Fixing the worst offenders.

* Try to modify less files.

* Fixing blip_2 (Weird solution right now).

* Fixing deta.

* Fix blip ?

* Missing extra newline.

* No deta modification.

* Adding some comments.

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Addressing comments.

* Addressing comments.

* creating warn_once.

* Warning_once !

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@ydshieh
Collaborator

ydshieh commented Apr 7, 2023

Hey @Narsil, the doctest for DetaForObjectDetection fails after this PR. You can run the code snippet below.
Could you take a look? 🙏 Thanks.

Previous results:

{'scores': tensor([0.6831, 0.6826, 0.5684, 0.5464], grad_fn=<IndexBackward0>), 'labels': tensor([17, 17, 75, 75]), 'boxes': tensor([[345.8479,  23.6753, 639.8561, 372.8265],
        [  8.7996,  52.4945, 316.9348, 473.4509],
        [ 40.0171,  73.7522, 175.9579, 117.3332],
        [333.6797,  77.1251, 370.1172, 187.5138]], grad_fn=<IndexBackward0>)}

Now:

{'scores': tensor([], grad_fn=<IndexBackward0>), 'labels': tensor([], dtype=torch.int64), 'boxes': tensor([], size=(0, 4), grad_fn=<IndexBackward0>)}
from transformers import AutoImageProcessor, DetaForObjectDetection
from PIL import Image
import requests
import torch

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("jozhang97/deta-swin-large")
model = DetaForObjectDetection.from_pretrained("jozhang97/deta-swin-large")

inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)

# convert outputs (bounding boxes and class logits) to COCO API
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
print(results)

@Narsil Narsil mentioned this pull request Apr 7, 2023
5 tasks
@ydshieh
Collaborator

ydshieh commented Apr 7, 2023

Another one affected

tests/models/vit/test_modeling_vit.py::ViTModelIntegrationTest::test_inference_fp16
(line 136)  ValueError: weight is on the meta device, we need a value to put in on 1.

Full trace

self = <tests.models.vit.test_modeling_vit.ViTModelIntegrationTest testMethod=test_inference_fp16>

    @slow
    @require_accelerate
    @require_torch_gpu
    def test_inference_fp16(self):
        r"""
        A small test to make sure that inference work in half precision without any problem.
        """
>       model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto")

tests/models/vit/test_modeling_vit.py:324: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/modeling_utils.py:2760: in from_pretrained
    dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
/usr/local/lib/python3.8/dist-packages/accelerate/big_modeling.py:370: in dispatch_model
    attach_align_device_hook_on_blocks(
/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:478: in attach_align_device_hook_on_blocks
    add_hook_to_module(module, hook)
/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:155: in add_hook_to_module
    module = hook.init_hook(module)
/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:251: in init_hook
    set_module_tensor_to_device(module, name, self.execution_device)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

module = Linear(in_features=384, out_features=384, bias=True), tensor_name = 'weight', device = 0, value = None, dtype = None

    def set_module_tensor_to_device(
        module: nn.Module,
        tensor_name: str,
        device: Union[int, str, torch.device],
        value: Optional[torch.Tensor] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
    ):
        """
        A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
        `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).
    
        Args:
            module (`torch.nn.Module`):
                The module in which the tensor we want to move lives.
            param_name (`str`):
                The full name of the parameter/buffer.
            device (`int`, `str` or `torch.device`):
                The device on which to set the tensor.
            value (`torch.Tensor`, *optional*):
                The value of the tensor (useful when going from the meta device to any other device).
            dtype (`torch.dtype`, *optional*):
                If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to
                the dtype of the existing parameter in the model.
        """
        # Recurse if needed
        if "." in tensor_name:
            splits = tensor_name.split(".")
            for split in splits[:-1]:
                new_module = getattr(module, split)
                if new_module is None:
                    raise ValueError(f"{module} has no attribute {split}.")
                module = new_module
            tensor_name = splits[-1]
    
        if tensor_name not in module._parameters and tensor_name not in module._buffers:
            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
        is_buffer = tensor_name in module._buffers
        old_value = getattr(module, tensor_name)
    
        if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
>           raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
E           ValueError: weight is on the meta device, we need a `value` to put in on 0.

xloem pushed a commit to xloem/transformers that referenced this pull request Apr 9, 2023
xloem pushed a commit to xloem/transformers that referenced this pull request Apr 10, 2023
@Narsil
Contributor Author

Narsil commented Apr 11, 2023

#22656 (review)

ydshieh added a commit that referenced this pull request Apr 13, 2023
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
sgugger pushed a commit that referenced this pull request Apr 14, 2023
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
…ce#22750)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
stevhliu added a commit that referenced this pull request Feb 26, 2024
* [Pix2struct] Simplify generation (#22527)

* Add model to doc tests

* Remove generate and replace by prepare_inputs_for_generation

* More fixes

* Remove print statements

* Update integration tests

* Fix generate

* Remove model from auto mapping

* Use auto processor

* Fix integration tests

* Fix test

* Add inference code snippet

* Remove is_encoder_decoder

* Update docs

* Remove notebook link

* Release: v4.28.0

* Revert (for now) the change on `Deta` in #22437 (#22750)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Patch release: v4.28.1

* update zh chat template.

* Update docs/source/zh/chat_templating.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/zh/_toctree.yml

Co-authored-by: Michael <haifeng.yao@daocloud.io>

* Update docs/source/zh/chat_templating.md

Co-authored-by: Michael <haifeng.yao@daocloud.io>

* Update docs/source/zh/chat_templating.md

Co-authored-by: Michael <haifeng.yao@daocloud.io>

* Update docs/source/zh/chat_templating.md

Co-authored-by: Michael <haifeng.yao@daocloud.io>

* Update docs/source/zh/chat_templating.md

Co-authored-by: Michael <haifeng.yao@daocloud.io>

* Update docs/source/zh/chat_templating.md

Co-authored-by: Michael <haifeng.yao@daocloud.io>

* Update docs/source/zh/chat_templating.md

Co-authored-by: Michael <haifeng.yao@daocloud.io>

---------

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Michael <haifeng.yao@daocloud.io>
itazap pushed a commit that referenced this pull request May 14, 2024