[PEFT] Add warning for missing key in LoRA adapter #34068

Merged
24 changes: 20 additions & 4 deletions src/transformers/integrations/peft.py
@@ -235,13 +235,29 @@ def load_adapter(
         )
 
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            err_msg = ""
+            origin_name = peft_model_id if peft_model_id is not None else "state_dict"
+            # Check for unexpected keys.
             if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
-                logger.warning(
-                    f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
-                    f" {incompatible_keys.unexpected_keys}. "
+                err_msg = (
+                    f"Loading adapter weights from {origin_name} led to unexpected keys not found in the model: "
+                    f"{', '.join(incompatible_keys.unexpected_keys)}. "
                 )
+
+            # Check for missing keys.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                # Filter missing keys specific to the current adapter, as missing base model keys are expected.
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    err_msg += (
+                        f"Loading adapter weights from {origin_name} led to missing keys in the model: "
+                        f"{', '.join(lora_missing_keys)}"
+                    )
+
+            if err_msg:
+                logger.warning(err_msg)
 
         # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
         if (
             (getattr(self, "hf_device_map", None) is not None)
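For context, a minimal sketch of the filtering behaviour introduced above, using hypothetical key names: only missing keys that belong to the adapter being loaded (i.e. containing both `lora_` and the adapter name) are reported, since missing base-model keys are expected when loading an adapter-only checkpoint.

```python
# Minimal sketch of the missing-key filtering above (hypothetical key names).
missing_keys = [
    "model.decoder.layers.0.self_attn.q_proj.weight",                  # base model weight -> ignored
    "model.decoder.layers.0.self_attn.q_proj.lora_A.default.weight",   # adapter weight -> reported
]
adapter_name = "default"
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
assert lora_missing_keys == ["model.decoder.layers.0.self_attn.q_proj.lora_A.default.weight"]
```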
78 changes: 76 additions & 2 deletions tests/peft_integration/test_peft_integration.py
@@ -20,8 +20,9 @@
 from huggingface_hub import hf_hub_download
 from packaging import version
 
-from transformers import AutoModelForCausalLM, OPTForCausalLM
+from transformers import AutoModelForCausalLM, OPTForCausalLM, logging
 from transformers.testing_utils import (
+    CaptureLogger,
     require_bitsandbytes,
     require_peft,
     require_torch,
@@ -72,9 +73,15 @@ def test_peft_from_pretrained(self):
         This checks if we pass a remote folder that contains an adapter config and adapter weights, it
         should correctly load a model that has adapters injected on it.
         """
+        logger = logging.get_logger("transformers.integrations.peft")
+
         for model_id in self.peft_test_model_ids:
             for transformers_class in self.transformers_test_model_classes:
-                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with CaptureLogger(logger) as cl:
+                    peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                # ensure that under normal circumstances, there are no warnings about keys
+                self.assertNotIn("unexpected keys", cl.out)
+                self.assertNotIn("missing keys", cl.out)
 
                 self.assertTrue(self._check_lora_correctly_converted(peft_model))
                 self.assertTrue(peft_model._hf_peft_config_loaded)
@@ -548,3 +555,70 @@ def test_peft_from_pretrained_hub_kwargs(self):
 
         model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
         self.assertTrue(self._check_lora_correctly_converted(model))
+
+    def test_peft_from_pretrained_unexpected_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with unexpected keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # add unexpected key
+                dummy_state_dict["foobar"] = next(iter(dummy_state_dict.values()))
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
+                    )
+
+                msg = "Loading adapter weights from state_dict led to unexpected keys not found in the model: foobar"
+                self.assertIn(msg, cl.out)
+
+    def test_peft_from_pretrained_missing_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with missing keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # remove a key so that we have missing keys
+                key = next(iter(dummy_state_dict.keys()))
+                del dummy_state_dict[key]
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict,
+                        peft_config=peft_config,
+                        low_cpu_mem_usage=False,
+                        adapter_name="other",
+                    )
+
+                # Here we need to adjust the key name a bit to account for PEFT-specific naming.
+                # 1. Remove PEFT-specific prefix
+                # If merged after dropping Python 3.8, we can use: key = key.removeprefix(peft_prefix)
+                peft_prefix = "base_model.model."
+                key = key[len(peft_prefix) :]
+                # 2. Insert adapter name
+                prefix, _, suffix = key.rpartition(".")
+                key = f"{prefix}.other.{suffix}"
+
+                msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
+                self.assertIn(msg, cl.out)
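As a concrete illustration of the key adjustment performed at the end of this test, here is a short sketch assuming a hypothetical checkpoint key:

```python
# Hypothetical checkpoint key, walked through the same adjustment as in the test above.
key = "base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_A.weight"
peft_prefix = "base_model.model."
key = key[len(peft_prefix) :]        # 1. drop the PEFT-specific prefix
prefix, _, suffix = key.rpartition(".")
key = f"{prefix}.other.{suffix}"     # 2. insert the adapter name ("other")
assert key == "model.decoder.layers.0.self_attn.q_proj.lora_A.other.weight"
```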