Making sure we can use safetensors to serialize all the time. #22437
Conversation
src/transformers/modeling_utils.py
Outdated
# Disable to see the damage.
if safe_serialization:
    if self._keys_to_ignore_on_load_missing is not None:
        for ignore_key in self._keys_to_ignore_on_load_missing:
            if ignore_key in state_dict.keys():
                del state_dict[ignore_key]
This is the core of the fix:
I think of 3 ways to get this done:

- Just add `_keys_to_ignore_on_save` for all affected models. No change in core modeling code. Will impact both torch and safetensors saving versions (so technically a breaking change for the on-disk representation, should users reuse the dictionaries in non-transformers modeling code). Essentially we will have to maintain both `_keys_to_ignore_on_save` AND `_keys_to_ignore_on_load_missing` at the same time for all models (see the sketch after this comment).
- [Proposed fix] Ignore the `ignore_on_load` keys only for safetensors. The benefit is that it's not breaking for torch, and still allows saving with safetensors for all models. However it does introduce a difference in mechanism between the two on-disk representations.
- Ignore the `ignore_on_load` keys for both torch and safetensors. Will have the same effect as the first proposed fix, with the benefit of not having to maintain both keys. Technically both are not exactly the same because sometimes models have been renamed, leading to old names being in those keys while not being used anymore. This is more true for the third key `_keys_to_ignore_extra`, but could still potentially exist for `ignore_on_save` and `ignore_on_load`. I think this is ok, because the current code will just ignore keys not contained in the state dict.

I propose the second change since it's the least breaking now, even though I think in the long run it could cause more issues/confusion because of the dichotomy in treatment.
In any case, all proposed changes would only affect future models being created, and have no effect on currently existing models in the wild.
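For illustration, option 1 would amount to duplicating the patterns on each affected model class; a minimal sketch (the class name and pattern below are hypothetical, not a specific model):

```python
# Hypothetical sketch of option 1: keep both attributes in sync on the model class.
# `SomeModelForCausalLM` and the pattern are placeholders, not an actual model.
from transformers import PreTrainedModel


class SomeModelForCausalLM(PreTrainedModel):
    # keys skipped when reporting missing weights at load time
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
    # the same keys would now also need to be dropped from the state dict on save
    _keys_to_ignore_on_save = [r"lm_head.weight"]
```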
This all happens because of the design decision in
I'm not so sure. The tests are currently failing hard (incorrect reloaded tensors) because of incorrect configuration within some models:
While it is not currently an issue because the saved torch files are creating the aliasing, and so it is actually unpacked during loading, I think all 4 (only checked llama for now) have an incorrect `_keys_to_ignore_on_load_missing`. If we make the hard error a simple warning, that would just lead to wrong models reloaded from safetensors. (The weight will get ignored, so no warning to the user, and yet the weights won't be tied, so the output head will be random.) For Llama:
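For reference, a quick way to check whether the weights are actually tied (a sketch; the checkpoint path below is a placeholder, not an actual hub id):

```python
# Sketch: compare storage pointers of the input embeddings and the LM head.
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("path/to/llama-7b")
tied = model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr()
print(tied)  # False here would mean the `lm_head.weight` ignore pattern is wrong
```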
This is true in the pure torch world and has nothing to do with safetensors. It happens to be a minor issue because we currently save the alias. Normally, we're saved because the conversion script will disallow this conversion (since the reloaded model is incorrect). For these 4 models, provided they have the same issue, either we need to fix the configuration and retie the weights (which would make the current proposed fix just work) or actually remove the
Yet we went from 90 failures to just 4 models. I'm not saying Transformers is perfect and does not need any fix at all. Even with safetensors enabling the save of state dictionaries having tied weights, we should make sure we only save one of those weights to have the most efficient safetensors checkpoints. I'm just highlighting that an API that is too rigid will never be broadly used, so I really think safetensors should add support for bypassing this hard error.
(Also can confirm that the embeddings and LM head are different tensors for Llama-7b at least, so the `_keys_to_ignore_on_load_missing` is just wrongly set)
Confirmed on ImageGPT it's the same.
I respectfully disagree. You're not wrong, but I really think it's not the case here (simply allowing it is just allowing ourselves to shoot ourselves in the foot).
In any case, a fix will need to be different from what is suggested in the PR. Likewise, you can't write code like in this PR that always deletes keys based on `_keys_to_ignore_on_load_missing`. What could be done instead during saving with safetensors:

We might need a new class attribute in XxxPreTrainedModel for the edge case where the main tied parameter is not the first one as returned by the helper. But the code needs to be dynamic (depending on the actual model seen), not static (in the sense that it uses the class variables).
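A minimal sketch of the dynamic approach being described, assuming an already loaded `model` (the choice of which alias to keep is illustrative, not the final implementation):

```python
# Sketch: group state_dict entries by storage pointer and keep only one name
# per group of tied tensors before handing the dict to safetensors.
from collections import defaultdict

state_dict = model.state_dict()
ptrs = defaultdict(list)
for name, tensor in state_dict.items():
    ptrs[tensor.data_ptr()].append(name)

for names in ptrs.values():
    if len(names) > 1:
        # keep the first alias, drop the rest so safetensors accepts the dict
        for name in names[1:]:
            del state_dict[name]
```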
But we need to know which ones are actually used to recreate the others then. There's a main weight, and the others are deduced from it. At least to properly not get a warning.
Couldn't it remove keys that are both in the

In general I'm confused about having weights untied at runtime, since if you untie them, save your model, and erase the tied weights, then you would reload without a warning and get an erroneous model.
Nice, but I don't think a full function from a dependency is necessary for that:

from collections import defaultdict

# Checking the tensor sharing is correct
ptrs = defaultdict(list)
for k, v in model_tied.state_dict().items():
    ptrs[v.data_ptr()].append(k)
shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}

is enough.
I confirmed. It just erroneously raises a warning, but the underlying model is fine.
You can take those names as suggestions, but you will still need to leave only one weight per group of tied parameters or risk getting an error from safetensors. While you are fine with
Like I said, Accelerate is becoming a torch dependency anyway (since the Trainer will be rewritten to use it), so I don't see how it's wrong to use it. Your snippet of code will not present the groups of shared parameters (T5 has 4 of them tied together) as nicely, and you'd need to add tests for it (whereas Accelerate already heavily tests its utils).
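For example, something along these lines (a sketch; the exact return format of `find_tied_parameters` has varied across Accelerate versions):

```python
# Sketch: Accelerate's helper groups tied parameters together by name.
from accelerate.utils import find_tied_parameters
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")
print(find_tied_parameters(model))
# For T5 this should show one group containing the shared embedding,
# the encoder/decoder embeddings and the LM head.
```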
I have no idea what this means. Are you referring to the situation where a user breaks the tied weights connection somehow without changing the model config, then saves the weights and reloads the model with
src/transformers/modeling_utils.py
Outdated
for _, names in shared_ptrs.items():
    for name in names:
        for pat in self._keys_to_ignore_on_load_missing:
            if re.search(pat, name):
                if name in state_dict:
                    del state_dict[name]
This is roughly copy-pasted from `from_pretrained`, except we drop only the shared keys (which allows dynamically unlinked tensors to go through as-is).
@@ -1238,8 +1238,28 @@ def __init__(self, config: Blip2Config):
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding
    def get_input_embeddings(self):
@younesbelkada As seen offline, but could you confirm it's OK ?
It's copy-pasted from `Blip2ForConditionalGeneration`. Without those, the tie_weights seems broken after load (this isn't safetensors specific and could be a separate PR).
@@ -244,7 +244,7 @@ class DetaObjectDetectionOutput(ModelOutput):

 def _get_clones(module, N):
-    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+    return nn.ModuleList([module for i in range(N)])
This enables the dynamic loading to work properly.
If this is left as-is, what happens is that during model init, all n modules are created pointing to the same parameter, but after `load_state_dict` only the first layer gets updated.
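A minimal sketch of the difference, independent of Deta itself (plain torch modules, illustrative only):

```python
# With a shared module, every entry of the ModuleList aliases the same storage,
# so filling layer 0 from a checkpoint effectively fills all layers.
# With deepcopy, each clone owns its storage and stays random if its key is
# skipped at load time.
import copy
from torch import nn

layer = nn.Linear(4, 4)
shared = nn.ModuleList([layer for _ in range(3)])
cloned = nn.ModuleList([copy.deepcopy(layer) for _ in range(3)])

print(len({t.data_ptr() for t in shared.state_dict().values()}))  # 2: one weight, one bias
print(len({t.data_ptr() for t in cloned.state_dict().values()}))  # 6: independent storages
```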
@@ -1778,7 +1778,7 @@ def forward(
     )
 class DetaForObjectDetection(DetaPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    _keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*"]
+    _keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*", "model.decoder"]
`model.decoder` can share `bbox_embed` and `class_embed`. It's not necessarily the case and depends on the config.
This is the sort of model which makes the dynamic nature of the PR crucial (drop only shared tensors).
The debt we're creating here is that if a file is missing some weights (like model.decoder.bbox_embed) but the config is set to not share (or is modified on disk), then the warning will not be shown, yet random weights will be kept on the model.
I don't see any good way to solve this since `_keys_to_ignore` is not config dependent.
However, since the links are properly made at init time (so when the config allows it), the issue will only arise when users use mismatched config and weights.
Probably an acceptable choice.
@@ -630,8 +630,6 @@ def custom_forward(*inputs):

 class LlamaForCausalLM(LlamaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
This has to be removed since Llama doesn't share the embedding and the output.
@ArthurZucker for final confirmation?
@@ -357,9 +357,10 @@ def __init__(
         initializer_factor=1.0,
         initializer_range=0.02,
         is_vqa=False,
+        tie_word_embeddings=False,
@younesbelkada for confirmation.
Pix2Struct has a Vision, Text, and Global model (and therefore config).
The text config properly sets tie_word_embeddings to False, but the global one didn't, and therefore the global model would forcefully set the word embeddings to tied even when it shouldn't.
I think this is also independent from this safetensors PR.
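If that's right, the effect on the top-level config would roughly be (a sketch; class and attribute names per the diff above):

```python
# Sketch: after the change the composite Pix2Struct config defaults to untied
# word embeddings, matching what the text sub-config already declared.
from transformers import Pix2StructConfig

config = Pix2StructConfig()
print(config.tie_word_embeddings)  # expected: False
```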
All good for the modifications about Blip2 & Pix2struct! Thanks a mile for double checking
@sgugger All tests are now passing with relatively minor code changes. I think we could push out some fixes to their respective PRs (blip2, pix2struct, llama) since what this uncovered seems to really be affecting current models. For deta, I think since it's marked exotic the proposed fix could work. And just for the record, my insistence on disallowing aliasing doesn't come from nowhere. If you do:
Then necessarily the tensors aren't shared, while they are if you did. So enabling aliasing forces safetensors to give up lazy loading. The bar to do that is pretty high in my mind since lazy loading is a very nice feature we get out of it. Note: silently dropping tensors on save in safetensors will necessarily lead to bugs in transformers too, which is why I'm not considering it as an option (since the reloaded file will be wrong).
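To make the lazy-loading point concrete, a small sketch in plain torch (the file name is illustrative):

```python
# torch.save materializes the whole pickle on load, so aliased tensors come
# back sharing storage; a lazy, per-tensor loader cannot reproduce that
# sharing without reading everything eagerly.
import torch

a = torch.zeros(2, 2)
torch.save({"a": a, "b": a}, "tied.bin")

reloaded = torch.load("tied.bin")
print(reloaded["a"].data_ptr() == reloaded["b"].data_ptr())  # True: aliasing survives
```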
I'm not sure why you are ignoring the comments I made with respect to this PR and safetensors as it is now, and go back to defending your choice of API for safetensors (which I still think is wrong but I'm done debating this). So once again:
I've done that. Adding the necessary other piece, which is dropping missing keys on shared tensors regardless of the

Doing both makes it possible to remove the need for the deta key modification. (Still need to fix the deepcopy, again nothing to do with safetensors, but the parameters are cloned and not shared, and so the tensors are not properly filled for layers > 1 without the fix.)
Thanks for iterating, I have a few more comments.
src/transformers/modeling_utils.py
Outdated
# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
warn_names = set()
for _, names in shared_ptrs.items():
You should iterate on `.values()` if you don't want the keys.
src/transformers/modeling_utils.py
Outdated
# load. This allows to make sure the name which is kept is consistent.
if self._keys_to_ignore_on_load_missing is not None:
    for name in names:
        for pat in self._keys_to_ignore_on_load_missing:
No need for the for loop and two if statements, just do:
if name in state_dict and any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing):
del state_dict[name]
I did, but split it into two lines for readability.
src/transformers/modeling_utils.py
Outdated
# When not all duplicates have been cleaned
# Still remove those keys, but put a clear warning
# Since if the link between tensors was done at runtime
# then `from_pretrained` will still not get the key back
# Leading to random tensor. With a proper warning.
You have a line width of 119 chars in Transformers, no need to take 5 lines for this comment.
@sgugger If you want to do a final check (maybe we want a global `warn_once`?)
Yes, the warn_once is already implemented in the Transformers logger, that is what I was suggesting you use.
src/transformers/modeling_utils.py
Outdated
del state_dict[name]
warn_names.add(name)
if len(warn_names) > 0:
    warn_once(
As suggested before, please use `logger.warn_once` ;-)
It doesn't exist:

FAILED tests/models/deta/test_modeling_deta.py::DetaModelTest::test_can_use_safetensors - Exception: Class DetaForObjectDetection cannot be saved using safetensors: 'Logger' object has no attribute 'warn_once'

Is the logger improperly set up here? Also, I couldn't find the symbol anywhere, is it recent?
Ahh... `warning_once`. Thanks @younesbelkada :)
Oh sorry, mixed up the name.
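For completeness, the logger usage being discussed looks roughly like this (a sketch; the message text is illustrative):

```python
# Sketch: Transformers' logger exposes a cached `warning_once` so repeated
# saves only emit the message a single time.
from transformers.utils import logging

logger = logging.get_logger(__name__)
logger.warning_once("Removed shared tensors while saving with safetensors.")
```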
…gface#22437) * Making sure we can use safetensors to serialize all the time. * Expanding the tests for increased coverage. * Update the test. * Getting current state of affairs. * Tentative fix. * Fixing black version. * Fixing the worst offenders. * Try to modify less files. * Fixing blip_2 (Weird solution right now). * Fixing deta. * Fix blip ? * Missing extra newline. * No deta modification. * Adding some comments. * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Addressing comments. * Addressing comments. * creating warn_once. * Warning_once ! --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Hey @Narsil The doctest for

### previous results

{'scores': tensor([0.6831, 0.6826, 0.5684, 0.5464], grad_fn=<IndexBackward0>), 'labels': tensor([17, 17, 75, 75]), 'boxes': tensor([[345.8479,  23.6753, 639.8561, 372.8265],
        [  8.7996,  52.4945, 316.9348, 473.4509],
        [ 40.0171,  73.7522, 175.9579, 117.3332],
        [333.6797,  77.1251, 370.1172, 187.5138]], grad_fn=<IndexBackward0>)}

### now

{'scores': tensor([], grad_fn=<IndexBackward0>), 'labels': tensor([], dtype=torch.int64), 'boxes': tensor([], size=(0, 4), grad_fn=<IndexBackward0>)}
Another one affected: tests/models/vit/test_modeling_vit.py::ViTModelIntegrationTest::test_inference_fp16 (line 136)

ValueError: weight is on the meta device, we need a `value` to put in on 1.

Full trace:
self = <tests.models.vit.test_modeling_vit.ViTModelIntegrationTest testMethod=test_inference_fp16>
@slow
@require_accelerate
@require_torch_gpu
def test_inference_fp16(self):
r"""
A small test to make sure that inference work in half precision without any problem.
"""
> model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto")
tests/models/vit/test_modeling_vit.py:324:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/modeling_utils.py:2760: in from_pretrained
dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
/usr/local/lib/python3.8/dist-packages/accelerate/big_modeling.py:370: in dispatch_model
attach_align_device_hook_on_blocks(
/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:478: in attach_align_device_hook_on_blocks
add_hook_to_module(module, hook)
/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:155: in add_hook_to_module
module = hook.init_hook(module)
/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:251: in init_hook
set_module_tensor_to_device(module, name, self.execution_device)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
module = Linear(in_features=384, out_features=384, bias=True), tensor_name = 'weight', device = 0, value = None, dtype = None
def set_module_tensor_to_device(
module: nn.Module,
tensor_name: str,
device: Union[int, str, torch.device],
value: Optional[torch.Tensor] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
):
"""
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).
Args:
module (`torch.nn.Module`):
The module in which the tensor we want to move lives.
param_name (`str`):
The full name of the parameter/buffer.
device (`int`, `str` or `torch.device`):
The device on which to set the tensor.
value (`torch.Tensor`, *optional*):
The value of the tensor (useful when going from the meta device to any other device).
dtype (`torch.dtype`, *optional*):
If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to
the dtype of the existing parameter in the model.
"""
# Recurse if needed
if "." in tensor_name:
splits = tensor_name.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
tensor_name = splits[-1]
if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
is_buffer = tensor_name in module._buffers
old_value = getattr(module, tensor_name)
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
> raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
E   ValueError: weight is on the meta device, we need a `value` to put in on 0.
What does this PR do?
Making sure `save_pretrained(..., safe_serialization=True)` works in all cases.

It seems `_keys_to_ignore_on_load_missing` was the only one to be set, and so `save_pretrained` does not properly ignore those keys on saving.

Status before the fix:
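For context, a minimal usage sketch of what the PR is meant to guarantee (the checkpoint name is just an example):

```python
# Sketch: saving with safe_serialization=True should now work for any model
# and produce a model.safetensors file that reloads cleanly.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.save_pretrained("local-checkpoint", safe_serialization=True)
reloaded = AutoModelForCausalLM.from_pretrained("local-checkpoint")
```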
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.