[core] Fix safetensors serialization for shared tensors (#1101)
* fix st serialization

* add test

* add CI test

* add comment
younesbelkada authored Nov 9, 2023
1 parent c5d9485 commit b5641cc
Showing 3 changed files with 63 additions and 1 deletion.
26 changes: 26 additions & 0 deletions src/peft/peft_model.py
@@ -15,6 +15,7 @@

from __future__ import annotations

import collections
import inspect
import os
import warnings
@@ -31,6 +32,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.pytorch_utils import id_tensor_storage
from transformers.utils import PushToHubMixin

from . import __version__
@@ -168,6 +170,8 @@ def save_pretrained(
save_directory (`str`):
Directory where the adapter model and configuration files will be saved (will be created if it does not
exist).
safe_serialization (`bool`, *optional*):
Whether to save the adapter files in safetensors format.
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments passed along to the `push_to_hub` method.
"""
@@ -199,6 +203,28 @@ def save_pretrained(
os.makedirs(output_dir, exist_ok=True)

if safe_serialization:
# Section copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2111-L2134
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in output_state_dict.items():
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor):
ptrs[id_tensor_storage(tensor)].append(name)
else:
# In the non-tensor case, fall back to the pointer of the object itself
ptrs[id(tensor)].append(name)

# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

for _, names in shared_ptrs.items():
# Here we just clone the shared tensors to avoid tensor aliasing which is
# not supported in safetensors.
for shared_tensor_name in names[1:]:
output_state_dict[shared_tensor_name] = output_state_dict[shared_tensor_name].clone()

safe_save_file(
output_state_dict,
os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
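
Side note (not part of the diff): the sketch below illustrates why the clone above is needed. safetensors refuses to write a state dict in which two names alias the same storage, and cloning one of the aliases resolves it. The tensor names and file name here are illustrative assumptions, not values from the commit.

# Illustrative sketch only: shows the error safetensors raises for aliased
# tensors and how cloning one alias (as save_pretrained now does) avoids it.
import torch
from safetensors.torch import save_file

weight = torch.randn(4, 4)
# Two entries pointing at the same underlying storage
state_dict = {"lora_A.weight": weight, "tied_copy.weight": weight}

try:
    save_file(state_dict, "adapter_model.safetensors")
except RuntimeError as exc:
    # safetensors rejects tensors that share memory (exact message/exception
    # type may vary across safetensors versions)
    print(f"refused to save shared tensors: {exc}")

# Cloning breaks the aliasing, mirroring the fix in the diff above
state_dict["tied_copy.weight"] = state_dict["tied_copy.weight"].clone()
save_file(state_dict, "adapter_model.safetensors")  # now succeeds
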
16 changes: 16 additions & 0 deletions tests/test_common_gpu.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest

import pytest
@@ -22,6 +23,7 @@
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoTokenizer,
BitsAndBytesConfig,
LlamaForCausalLM,
@@ -33,6 +35,7 @@
IA3Config,
LoraConfig,
PeftModel,
TaskType,
get_peft_model,
prepare_model_for_kbit_training,
)
@@ -631,3 +634,16 @@ def test_4bit_merge_and_disable_lora(self):
self.assertTrue(isinstance(model, PeftModel))
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit))
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit))

@require_torch_gpu
@pytest.mark.single_gpu_tests
def test_serialization_shared_tensors(self):
model_checkpoint = "roberta-base"
peft_config = LoraConfig(
task_type=TaskType.TOKEN_CLS, inference_mode=False, r=16, lora_alpha=16, lora_dropout=0.1, bias="all"
)
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=11).to("cuda")
model = get_peft_model(model, peft_config)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
22 changes: 21 additions & 1 deletion tests/test_encoder_decoder_models.py
@@ -12,11 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import torch
from parameterized import parameterized
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoModelForSeq2SeqLM, AutoModelForTokenClassification

from peft import LoraConfig, TaskType, get_peft_model

from .testing_common import PeftCommonTester, PeftTestConfigManager

@@ -172,3 +175,20 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c
)
def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_disable_adapter(model_id, config_cls, config_kwargs)


class PeftEncoderDecoderCustomModelTester(unittest.TestCase):
"""
A custom class to hold any custom tests related to Enc-Dec models
"""

def test_save_shared_tensors(self):
model_id = "hf-internal-testing/tiny-random-RobertaModel"
peft_config = LoraConfig(
task_type=TaskType.TOKEN_CLS, inference_mode=False, r=16, lora_alpha=16, lora_dropout=0.1, bias="all"
)
model = AutoModelForTokenClassification.from_pretrained(model_id, num_labels=11)
model = get_peft_model(model, peft_config)
with tempfile.TemporaryDirectory() as tmp_dir:
# This should work fine
model.save_pretrained(tmp_dir, safe_serialization=True)
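
For completeness, a hedged follow-up sketch (not in the commit) showing how one might verify that the saved adapter round-trips. It assumes PEFT writes the adapter to "adapter_model.safetensors" when safe_serialization=True; treat the file name as an assumption.

# Hedged verification sketch: save the adapter with safe_serialization and
# load the resulting file back with safetensors.
import os
import tempfile

import torch
from safetensors.torch import load_file
from transformers import AutoModelForTokenClassification

from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForTokenClassification.from_pretrained(
    "hf-internal-testing/tiny-random-RobertaModel", num_labels=11
)
peft_config = LoraConfig(task_type=TaskType.TOKEN_CLS, r=16, lora_alpha=16, bias="all")
model = get_peft_model(model, peft_config)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, safe_serialization=True)
    # Assumed default adapter file name; adjust if PEFT changes it
    loaded = load_file(os.path.join(tmp_dir, "adapter_model.safetensors"))
    # After the clone fix, loading succeeds and every entry is a plain tensor
    assert all(isinstance(t, torch.Tensor) for t in loaded.values())
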
