Skip to content

Commit

Permalink
FIX Bug in target module optimization if suffix (#2144)
Browse files Browse the repository at this point in the history
Solves the following bug:

huggingface/diffusers#9622 (comment)

The cause for the bug is as follows: When we have, say, a module called
"bar.0.query" that we want to target and another module called
"foo_bar.0.query" that we don't want to target, there was potential for
an error. This is not caused by _find_minimal_target_modules directly,
but rather the bug was inside of BaseTuner.inject_adapter and how the
names_no_target were chosen. Those used to be chosen based on suffix. In
our example, however, "bar.0.query" is a suffix of "foo_bar.0.query",
therefore "foo_bar.0.query" was *not* added to names_no_target when it
should have. As a consequence, during the optimization, it looks like
"query" is safe to use as target_modules because we don't see that it
wrongly matches "foo_bar.0.query".
  • Loading branch information
BenjaminBossan authored Oct 10, 2024
1 parent 0aa7e3a commit c925d0a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,9 @@ def inject_adapter(
and len(peft_config.target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION
):
names_no_target = [
name for name in key_list if not any(name.endswith(suffix) for suffix in peft_config.target_modules)
name
for name in key_list
if not any((name == suffix) or name.endswith("." + suffix) for suffix in peft_config.target_modules)
]
new_target_modules = _find_minimal_target_modules(peft_config.target_modules, names_no_target)
if len(new_target_modules) < len(peft_config.target_modules):
Expand Down
42 changes: 42 additions & 0 deletions tests/test_tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,3 +1400,45 @@ def test_suffix_is_substring_of_other_suffix(self):
expected = {"time_emb_proj", "proj", "proj_out"}
result = find_minimal_target_modules(target_modules, other_module_names)
assert result == expected

def test_get_peft_modules_module_name_is_suffix_of_another_module(self):
# Solves the following bug:
# https://github.com/huggingface/diffusers/pull/9622#issuecomment-2404789721

# The cause for the bug is as follows: When we have, say, a module called "bar.0.query" that we want to target
# and another module called "foo_bar.0.query" that we don't want to target, there was potential for an error.
# This is not caused by _find_minimal_target_modules directly, but rather the bug was inside of
# BaseTuner.inject_adapter and how the names_no_target were chosen. Those used to be chosen based on suffix. In
# our example, however, "bar.0.query" is a suffix of "foo_bar.0.query", therefore "foo_bar.0.query" was *not*
# added to names_no_target when it should have. As a consequence, during the optimization, it looks like "query"
# is safe to use as target_modules because we don't see that it wrongly matches "foo_bar.0.query".

# ensure that we have sufficiently many modules to trigger the optimization
n_layers = MIN_TARGET_MODULES_FOR_OPTIMIZATION + 1

class InnerModule(nn.Module):
def __init__(self):
super().__init__()
self.query = nn.Linear(10, 10)

class OuterModule(nn.Module):
def __init__(self):
super().__init__()
# note that "transformer_blocks" is a suffix of "single_transformer_blocks"
self.transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)])
self.single_transformer_blocks = nn.ModuleList([InnerModule() for _ in range(n_layers)])

# we want to match all "transformer_blocks" layers but not "single_transformer_blocks"
target_modules = [f"transformer_blocks.{i}.query" for i in range(n_layers)]
model = get_peft_model(OuterModule(), LoraConfig(target_modules=target_modules))

# sanity check: we should have n_layers PEFT layers in model.transformer_blocks
transformer_blocks = model.base_model.model.transformer_blocks
assert sum(isinstance(module, BaseTunerLayer) for module in transformer_blocks.modules()) == n_layers

# we should not have any PEFT layers in model.single_transformer_blocks
single_transformer_blocks = model.base_model.model.single_transformer_blocks
assert not any(isinstance(module, BaseTunerLayer) for module in single_transformer_blocks.modules())

# target modules should *not* be simplified to "query" as that would match "single_transformers_blocks" too
assert model.peft_config["default"].target_modules != {"query"}

0 comments on commit c925d0a

Please sign in to comment.