Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve documentation for IA³ #984

Merged
merged 11 commits into from
Nov 7, 2023
26 changes: 19 additions & 7 deletions docs/source/conceptual_guides/ia3.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ Being similar to LoRA, IA3 carries many of the same advantages:
* Performance of models fine-tuned using IA3 is comparable to the performance of fully fine-tuned models.
* IA3 does not add any inference latency because adapter weights can be merged with the base model.

In principle, IA3 can be applied to any subset of weight matrices in a neural network to reduce the number of trainable
parameters. Following the authors' implementation, IA3 weights are added to the key, value and feedforward layers
of a Transformer model. Given the target layers for injecting IA3 parameters, the number of trainable parameters
can be determined based on the size of the weight matrices.
In principle, IA3 can be applied to any subset of weight matrices in a neural network to reduce the number of trainable
parameters. Following the authors' implementation, IA3 weights are added to the key, value and feedforward layers
of a Transformer model. To be specific, for transformer models, IA3 weights are added to the outputs of key and value layers, and to the input of the second feedforward layer
in each transformer block.

Given the target layers for injecting IA3 parameters, the number of trainable parameters
can be determined based on the size of the weight matrices.


## Common IA3 parameters in PEFT
Expand All @@ -43,10 +46,19 @@ As with other methods supported by PEFT, to fine-tune a model using IA3, you nee
3. Wrap the base model with `get_peft_model()` to get a trainable `PeftModel`.
4. Train the `PeftModel` as you normally would train the base model.

`IA3Config` allows you to control how IA3 is applied to the base model through the following parameters:
`IA3Config` allows you to control how IA3 is applied to the base model through the following parameters:

- `target_modules`: The modules (for example, attention blocks) to apply the IA3 vectors.
- `feedforward_modules`: The list of modules to be treated as feedforward layers in `target_modules`. While learned vectors are multiplied with
the output activation for attention blocks, the vectors are multiplied with the input for classic feedforward layers.
- `feedforward_modules`: The list of modules to be treated as feedforward layers in `target_modules`. While learned vectors are multiplied with
the output activation for attention blocks, the vectors are multiplied with the input for classic feedforward layers. Note that `feedforward_modules` must be a subset of `target_modules`.
- `modules_to_save`: List of modules apart from IA3 layers to be set as trainable and saved in the final checkpoint. These typically include model's custom head that is randomly initialized for the fine-tuning task.

## Example Usage

For the task of sequence classification, one can initialize the IA3 config for a Llama model as follows:

```py
peft_config = IA3Config(
task_type=TaskType.SEQ_CLS, target_modules=["k_proj", "v_proj", "down_proj"], feedforward_modules=["down_proj"]
)
```
9 changes: 8 additions & 1 deletion src/peft/tuners/ia3/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ class IA3Config(PeftConfig):
target_modules (`Union[List[str],str]`):
The names of the modules to apply (IA)^3 to.
feedforward_modules (`Union[List[str],str]`):
The names of the modules to be treated as feedforward modules, as in the original paper.
The names of the modules to be treated as feedforward modules, as in the original paper. These modules will
have (IA)^3 vectors multiplied to the input, instead of the output. feedforward_modules must be a name or a
subset of names present in target_modules.
fan_in_fan_out (`bool`):
Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses
`Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
Expand Down Expand Up @@ -78,3 +80,8 @@ def __post_init__(self):
self.feedforward_modules = (
set(self.feedforward_modules) if isinstance(self.feedforward_modules, list) else self.feedforward_modules
)

# check if feedforward_modules is a subset of target_modules. run the check only if both are sets
if isinstance(self.feedforward_modules, set) and isinstance(self.target_modules, set):
if not self.feedforward_modules.issubset(self.target_modules):
raise ValueError("`feedforward_modules` should be a subset of `target_modules`")
28 changes: 28 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,31 @@ def test_regex_with_layer_indexing_lora(self):

# should run without errors
LoraConfig(**valid_config)

def test_ia3_is_feedforward_subset_invalid_config(self):
# This test checks that the IA3 config raises a value error if the feedforward_modules argument
# is not a subset of the target_modules argument

# an example invalid config
invalid_config = {"target_modules": ["k", "v"], "feedforward_modules": ["q"]}

with self.assertRaisesRegex(
ValueError, expected_regex="^`feedforward_modules` should be a subset of `target_modules`$"
):
IA3Config(**invalid_config)

def test_ia3_is_feedforward_subset_valid_config(self):
# This test checks that the IA3 config is created without errors with valid arguments.
# feedforward_modules should be a subset of target_modules if both are lists

# an example valid config with regex expressions.
valid_config_regex_exp = {
"target_modules": ".*.(SelfAttention|EncDecAttention|DenseReluDense).*(q|v|wo)$",
"feedforward_modules": ".*.DenseReluDense.wo$",
}
# an example valid config with module lists.
valid_config_list = {"target_modules": ["k", "v", "wo"], "feedforward_modules": ["wo"]}

# should run without errors
IA3Config(**valid_config_regex_exp)
IA3Config(**valid_config_list)
Loading