FEAT: Adding exclude modules param (#2044) #2102

Open · JINO-ROHIT wants to merge 7 commits into main
Conversation

JINO-ROHIT commented

This draft PR adds an exclude_modules parameter to keep specific layers out of target_modules.

Addresses #2044

Changes made:

  1. Changed the LoraConfig class to take an exclude_modules param (usage sketched below).
  2. Added the exclusion logic within the check_target_module_exists method.
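
For illustration, the intended usage would look roughly like this (a hedged sketch: the module names are made up, and exclude_modules is assumed to accept a str or list of str with the same matching rules as target_modules):

from peft import LoraConfig

# Sketch: adapt the attention projections everywhere, but keep layer 0's
# projections out of the adapter. Names are illustrative only.
config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    exclude_modules=["model.layers.0.q_proj", "model.layers.0.v_proj"],
)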

@BenjaminBossan Can you have a look and see if this is the right direction? This is my first time contributing to peft, would love to have your guidance here :)

BenjaminBossan (Member) left a comment

Thanks for taking on this feature. I haven't done an in-depth review yet, but from a quick look, I saw this:

  1. Right now, you only added the option to LoraConfig. However, since the logic is implemented in the check_target_module_exists function, all methods can actually benefit from this feature (not the prompt learning ones, as they don't have target_modules, but all the rest like LoKr, BOFT, IA³, etc.). Let's add the argument to all relevant configs, i.e. the config of each method that has target_modules.
  2. We need tests for this new feature. Please check https://github.com/huggingface/peft/blob/main/tests/test_tuners_utils.py, which already contains a bunch of tests checking that the correct modules are targeted. Similar tests would need to be added for exclude_modules. Could you please take a look? Don't hesitate to ask if you have questions.

JINO-ROHIT (Author) commented
@BenjaminBossan thanks for the heads-up, I'll work on the points.

JINO-ROHIT (Author) commented
@BenjaminBossan I've added the test cases too, LMK what you think.

BenjaminBossan (Member) left a comment

Thanks a lot for the latest changes. The tests are already shaping up quite nicely, but let's improve the test coverage a bit more (see my suggestion).

Moreover, a few configs are still missing. These are all I could find, but please double-check:

  • AdaLoraConfig
  • BoftConfig
  • VBLoRAConfig
  • FourierFTConfig
  • LNTuningConfig
  • PolyConfig

@@ -415,6 +415,35 @@ def test_realistic_example(self):
        ]
        assert model.targeted_module_names == expected

class TestExcludedModuleNames(unittest.TestCase):
BenjaminBossan (Member) commented

Can we get additional tests:

  1. Let's add tests for exclude_modules being a list of str (see the sketch after this list).
  2. What happens if all targeted modules are excluded? I think a nice error message would be the right call in this situation.
  3. What happens if we have excluded modules that don't exist? I think it's fine if some of the entries from exclude_modules don't match, as we also allow that for target_modules. But what if none of the entries in exclude_modules matches? I think this should also result in an error. WDYT?
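
A minimal sketch of what point 1 could look like, in the style of test_tuners_utils.py (the toy model is made up, and exclude_modules is assumed to follow target_modules' matching rules):

import unittest

from torch import nn
from peft import LoraConfig, get_peft_model


class SmallModel(nn.Module):
    # toy model with two linear layers to target and exclude
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)
        self.lin1 = nn.Linear(10, 10)


class TestExcludedModuleNames(unittest.TestCase):
    def test_exclude_modules_list_of_str(self):
        config = LoraConfig(target_modules=["lin0", "lin1"], exclude_modules=["lin1"])
        model = get_peft_model(SmallModel(), config)
        # only lin0 should end up wrapped with LoRA
        assert model.targeted_module_names == ["lin0"]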

JINO-ROHIT (Author) commented

Yep, makes sense. Let me work through the tests and edge cases, plus the configs I missed.

JINO-ROHIT (Author) commented

Hey @BenjaminBossan, I managed to add cases 1 and 2. I'd like some help on how case 3 should be done. As I understand it, _check_target_module_exists is currently a private method that every PEFT method overrides to check whether the target modules exist. For exclude_modules, should we do the same? Or should I add a loop within inject_adapter and check the keys there?

BenjaminBossan (Member) commented

I don't think we need to go that far. If we look closer, we can see, e.g. for LoRA:

@staticmethod
def _check_target_module_exists(lora_config, key):
    return check_target_module_exists(lora_config, key)

And if we check where that's used, it all comes back to this line:

if not self._check_target_module_exists(peft_config, key):

So what would be an elegant way to implement some sanity checks here? I have one proposal, but maybe you have a better idea: Right now, check_target_module_exists simply returns a bool, i.e. True or False if the module matched. We could extend that and, say, return a special object if a module was excluded due to exclude_modules.

This special value could be a very simple object. Its __bool__ should return False (so that users who rely on checking if not self._check_target_module_exists(peft_config, key) don't suddenly get different results).

When we call _check_target_module_exists, we can collect the keys that did not match because of exclude_modules in one list and the keys that did not match for other reasons in another. When we exit the loop, if no key matched because they all hit exclude_modules, we can raise a nice error message saying so. Otherwise, we raise the existing error:

if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
    raise ValueError(
        f"Target modules {peft_config.target_modules} not found in the base model. "
        f"Please check the target modules and try again."
    )

On the other hand, if not a single key was excluded due to exclude_modules even though the user passed exclude_modules, we can give a nice warning, something like "You passed exclude_modules=[...] but no key matched it".
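
A minimal sketch of this proposal (the sentinel name, the standalone helper, and the exact error/warning conditions are illustrative, not the final implementation):

import warnings


class _ExcludedModule:
    # Falsy sentinel for keys skipped due to exclude_modules, so existing
    # `if not self._check_target_module_exists(...)` call sites keep working.
    def __bool__(self):
        return False


def inject_with_exclusion(check, peft_config, key_list):
    # `check` stands in for _check_target_module_exists; it may return
    # True/False or an _ExcludedModule instance.
    excluded_keys = []
    matched = False
    for key in key_list:
        result = check(peft_config, key)
        if isinstance(result, _ExcludedModule):
            excluded_keys.append(key)
        elif result:
            matched = True
            # ... create and inject the adapter module for `key` here ...
    if not matched and excluded_keys:
        raise ValueError("All modules matching target_modules were excluded by exclude_modules.")
    if getattr(peft_config, "exclude_modules", None) and not excluded_keys:
        warnings.warn(
            f"You passed exclude_modules={peft_config.exclude_modules} but no key matched it."
        )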

BenjaminBossan (Member) left a comment

Thanks for your further work on this PR. You raised a good question about how to implement the checks I mentioned. I made a proposal, but please LMK if you have a better idea.

@@ -914,6 +914,20 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
        `bool` | `re.Match[str]` | `None`: True or match object if key matches any target modules from config, False or
        None if no match found
    """
    if config.target_modules and config.exclude_modules:
        if config.exclude_modules == config.target_modules:
BenjaminBossan (Member) commented

I don't think that this is a sufficient check. Just to give some examples:

  • target_modules is a list of str and exclude_modules is a str, but both essentially target the same modules
  • target_modules is ["lin0", "lin1", "lin2"] and exclude_modules is ["lin0", "lin1"]. They look different, but if "lin2" doesn't exist, they're actually the same

We thus need to check which modules were actually targeted; we can't rely on the passed arguments. The sketch below illustrates both cases.
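
To illustrate the two cases (module names are made up, and the str form of exclude_modules is assumed to be regex-matched, mirroring target_modules):

from peft import LoraConfig

# Case 1: different types, same effective selection. If str entries are
# regex-matched, "lin.*" excludes everything that ["lin0", "lin1"] targets.
cfg1 = LoraConfig(target_modules=["lin0", "lin1"], exclude_modules="lin.*")

# Case 2: the lists look different, but if the model has no module named
# "lin2", every module that is actually targeted is also excluded.
cfg2 = LoraConfig(target_modules=["lin0", "lin1", "lin2"], exclude_modules=["lin0", "lin1"])

# In both cases `config.exclude_modules == config.target_modules` is False,
# yet no module would end up adapted.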
