
make torch.load a bit safer #27282

Merged: 2 commits merged into main on Dec 15, 2023

Conversation

@julien-c (Member) commented Nov 4, 2023:

No description provided.

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@amyeroberts (Collaborator) left a comment:

Nice - thanks for adding!

It seems this currently breaks resuming from a checkpoint with the Trainer, but the other changes LGTM.

cc @muellerzr to review the changes in trainer.py, as he knows more about the saved objects there.
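
A plausible cause, sketched below as an assumption rather than anything confirmed in the thread: the Trainer's RNG checkpoint also stores numpy's generator state, and rebuilding the ndarray inside it needs a numpy global that the weights_only allowlist rejects.

import random

import numpy as np
import torch

# Assumed shape of the Trainer's RNG checkpoint (illustrative, not the exact code):
rng_state = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),  # tuple containing an ndarray
    "cpu": torch.random.get_rng_state(),  # a ByteTensor, allowed by weights_only
}
torch.save(rng_state, "rng_state.pth")

# A plain torch.load restores this, but the restricted unpickler refuses the
# numpy global needed to rebuild the ndarray:
# torch.load("rng_state.pth", weights_only=True)  # raises UnpicklingError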

@@ -2305,7 +2305,7 @@ def _load_rng_state(self, checkpoint):
             )
             return

-        checkpoint_rng_state = torch.load(rng_file)
+        checkpoint_rng_state = torch.load(rng_file, weights_only=True)
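
For context, here is a minimal demonstration (a sketch, not part of the PR) of what weights_only=True buys: a pickle payload that would run arbitrary code during a plain torch.load is rejected by the restricted unpickler instead of executed.

import os
import tempfile

import torch


class Exploit:
    # __reduce__ lets a pickled object execute an arbitrary callable on load.
    def __reduce__(self):
        return (print, ("arbitrary code ran during torch.load!",))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "evil.pt")
    torch.save({"weights": torch.zeros(2), "obj": Exploit()}, path)

    # The unrestricted loader runs the payload while unpickling:
    torch.load(path, weights_only=False)
    try:
        # The restricted loader refuses the print global and raises instead:
        torch.load(path, weights_only=True)
    except Exception as err:
        print(type(err).__name__)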

@LysandreJik requested a review from @muellerzr on November 7, 2023 at 09:13.
github-actions bot commented Dec 5, 2023:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@@ -2466,7 +2466,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
         if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
             with warnings.catch_warnings(record=True) as caught_warnings:
-                self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+                self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True))
A contributor commented on this change:

Scheduler shouldn't have weights, right?
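
For reference, a scheduler state_dict typically holds no tensors, only plain Python values, and primitives plus containers of them are on the weights_only allowlist, so the stricter load still works here. A minimal sketch (not part of the PR):

import os
import tempfile

import torch
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10)

print(scheduler.state_dict())  # only ints, floats, and lists, no tensors

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "scheduler.pt")
    torch.save(scheduler.state_dict(), path)
    # The restricted unpickler accepts plain Python values and containers:
    scheduler.load_state_dict(torch.load(path, weights_only=True))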

@github-actions bot closed this on Dec 14, 2023.

@julien-c (Member, Author) commented:
cc @LysandreJik: is this stale, or has it been done elsewhere since?

@LysandreJik reopened this on Dec 15, 2023.

@amyeroberts (Collaborator) left a comment:

Thanks! 🔒

@LysandreJik (Member):

Thank you! :)

@LysandreJik merged commit dec84b3 into main on Dec 15, 2023 (21 checks passed).
@LysandreJik deleted the torch-load branch on December 15, 2023 at 15:01.
@julien-c (Member, Author):

yay, 1 more commit on the GOAT of codebases!!! happy :)

iantbutler01 pushed a commit to BismuthCloud/transformers that referenced this pull request on Dec 16, 2023:
* make torch.load a bit safer

* Fixes

---------

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
@hjenryin commented Jan 10, 2024:

Hi @julien-c, thanks for your work!

I was building from main to use some features that haven't been released yet, but I found that from_pretrained no longer works, and it seems related to this PR.

The code is as simple as this:

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

And it raises:

Traceback (most recent call last):
  File "***/lib/python3.9/site-packages/transformers/modeling_utils.py", line 520, in load_state_dict
    return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
  File "***/lib/python3.9/site-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "***/lib/python3.9/site-packages/torch/serialization.py", line 880, in _load
    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
TypeError: 'weights_only' is an invalid keyword argument for Unpickler()

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "***/lib/python3.9/site-packages/transformers/modeling_utils.py", line 524, in load_state_dict
    if f.read(7) == "version":
  File "***/lib/python3.9/codecs.py", line 322, in decode
    (result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ***, line 3, in <module>
    model = AutoModelForCausalLM.from_pretrained(
  File "***/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
    return model_class.from_pretrained(
  File "***/lib/python3.9/site-packages/transformers/modeling_utils.py", line 3430, in from_pretrained
    state_dict = load_state_dict(resolved_archive_file)
  File "***/lib/python3.9/site-packages/transformers/modeling_utils.py", line 536, in load_state_dict
    raise OSError(
OSError: Unable to load weights from pytorch checkpoint file for '/home/***/.cache/huggingface/hub/models--facebook--opt-125m/snapshots/27dcfa74d334bc871f3234de431e71c6eeba5dd6/pytorch_model.bin' at '/home/***/.cache/huggingface/hub/models--facebook--opt-125m/snapshots/27dcfa74d334bc871f3234de431e71c6eeba5dd6/pytorch_model.bin'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.

The TypeError seems to come from an incompatible PyTorch version; I'm using 1.10.1+cu111. Maybe it would be better to fall back to the original implementation on error and emit a one-time warning?
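
A minimal sketch of that suggested fallback (the function name is illustrative, not actual transformers code): try weights_only=True and fall back to the plain loader with a one-time warning when the keyword is unsupported.

import warnings

import torch


def safe_torch_load(path, map_location=None):
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        # torch < 1.13 does not accept the weights_only keyword at all.
        warnings.warn(
            "This PyTorch version does not support torch.load(weights_only=True); "
            "falling back to the less safe default loader.",
            UserWarning,
        )
        return torch.load(path, map_location=map_location)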

As for the UnicodeDecodeError, I don't really know where it comes from. It might be related to the codec or to pickle, but I'm not sure. I'm using Python 3.9.18. I also tried deleting the cache and downloading again, but it still didn't work.

I also tried from_tf=True as the error suggests, but that requires TensorFlow (which I don't have), so I don't think this is the problem. After all, everything worked fine with transformers 4.36.2 previously.

Fixing the TypeError should eliminate the other errors.

Thank you for your time! If you need any help from me, feel free to ask.

@julien-c (Member, Author):

torch 1.10 is quite old; is there any way you'd be able to upgrade to a more recent torch?

@hjenryin commented Jan 10, 2024:

Sure, but the point is that transformers claims to support torch 1.10 in its dependencies, while weights_only wasn't added to torch.load until 1.13 (see here). It might be better either to update the dependencies or to add backward compatibility.

@julien-c (Member, Author):

yep! cc @LysandreJik

@ArthurZucker (Collaborator):

#28207 will fix this 🤗

@hjenryin commented Jan 11, 2024:

but weights_only wasn't added to torch.load until 1.13 (see here)

#28207 only drops torch 1.10, but on torch 1.11 and torch 1.12 secure pickling will still break. You can take a moment to compare https://pytorch.org/docs/1.12/generated/torch.load.html and https://pytorch.org/docs/1.13/generated/torch.load.html.
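
One way to cover 1.11 and 1.12 as well (a sketch; the names are illustrative and this is not necessarily how #28207 does it) is to gate the keyword on the installed torch version rather than catching the TypeError:

import torch
from packaging import version

# weights_only landed in torch 1.13, so anything older takes the legacy path.
_TORCH_HAS_WEIGHTS_ONLY = version.parse(torch.__version__) >= version.parse("1.13")


def load_checkpoint(path, map_location="cpu"):
    if _TORCH_HAS_WEIGHTS_ONLY:
        return torch.load(path, map_location=map_location, weights_only=True)
    return torch.load(path, map_location=map_location)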

@LysandreJik (Member):

You're correct @hjenryin, we're taking a look at fixing this before the next release. Thanks for the report!

staghado pushed a commit to staghado/transformers that referenced this pull request on Jan 15, 2024:
* make torch.load a bit safer

* Fixes

---------

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>