Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Bug in target module optimization if child module name is suffix of parent module name #2144

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,9 @@ def inject_adapter(
and len(peft_config.target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION
):
names_no_target = [
name for name in key_list if not any(name.endswith(suffix) for suffix in peft_config.target_modules)
name
for name in key_list
if not any((name == suffix) or name.endswith("." + suffix) for suffix in peft_config.target_modules)
]
new_target_modules = _find_minimal_target_modules(peft_config.target_modules, names_no_target)
if len(new_target_modules) < len(peft_config.target_modules):
Expand Down
42 changes: 42 additions & 0 deletions tests/test_tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,3 +1400,45 @@ def test_suffix_is_substring_of_other_suffix(self):
expected = {"time_emb_proj", "proj", "proj_out"}
result = find_minimal_target_modules(target_modules, other_module_names)
assert result == expected

def test_get_peft_modules_module_name_is_suffix_of_another_module(self):
# Solves the following bug:
# https://github.com/huggingface/diffusers/pull/9622#issuecomment-2404789721

# The cause for the bug is as follows: When we have, say, a module called "bar.0.query" that we want to target
# and another module called "foo_bar.0.query" that we don't want to target, there was potential for an error.
# This is not caused by _find_minimal_target_modules directly, but rather the bug was inside of
# BaseTuner.inject_adapter and how the names_no_target were chosen. Those used to be chosen based on suffix. In
# our example, however, "bar.0.query" is a suffix of "foo_bar.0.query", therefore "foo_bar.0.query" was *not*
# added to names_no_target when it should have. As a consequence, during the optimization, it looks like "query"
# is safe to use as target_modules because we don't see that it wrongly matches "foo_bar.0.query".

# ensure that we have sufficiently many modules to trigger the optimization
n_layers = MIN_TARGET_MODULES_FOR_OPTIMIZATION + 1

class InnerModule(nn.Module):
def __init__(self):
super().__init__()
self.query = nn.Linear(10, 10)

class OuterModule(nn.Module):
def __init__(self):
super().__init__()
# note that "transformer_blocks" is a suffix of "single_transformer_blocks"
self.transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)])
self.single_transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)])

# we want to match all "transformer_blocks" layers but not "single_transformer_blocks"
target_modules = [f"transformer_blocks.{i}.query" for i in range(n_layers)]
model = get_peft_model(OuterModule(), LoraConfig(target_modules=target_modules))

# sanity check: we should have n_layers PEFT layers in model.transformer_blocks
transformer_blocks = model.base_model.model.transformer_blocks
assert sum(isinstance(module, BaseTunerLayer) for module in transformer_blocks.modules()) == n_layers

# we should not have any PEFT layers in model.single_transformer_blocks
single_transformer_blocks = model.base_model.model.single_transformer_blocks
assert not any(isinstance(module, BaseTunerLayer) for module in single_transformer_blocks.modules())

# target modules should *not* be simplified to "query" as that would match "single_transformers_blocks" too
assert model.peft_config["default"].target_modules != {"query"}
Loading