
Commit 293aea5

Support for Activated LoRA (#2609)
This PR migrates Activated LoRA (aLoRA) support from a standalone GitHub repository (see above) to PEFT itself. Note there is also an active PR for vLLM inference support for Activated LoRA: vllm-project/vllm#19710. There are also collections of aLoRA models on the Hugging Face Hub (in the ibm-granite org); note that these preexisting models run off the standalone GitHub repo and will be updated to work with this new PEFT feature if merged.

Description of changes: Activated LoRA is a modification of the LoRA architecture that "activates" the adapter weights only on tokens coming after a specified invocation_string. As a result, the KV values for the string coming before the activation match the KV values of the base model. This makes the KV cache for the input interchangeable between the base model and the adapter model, and allows for major speedups in inference pipelines (e.g. agentic pipelines) that want to use both base models and adapter models. See the paper for a detailed exploration of use cases and further elaboration.

Other notes: The crux of the changes is in layer.py. Everything else simply manages the alora_offsets quantity, which defines where the weights start to be activated. This is determined by scanning input strings for the invocation_string defined in the aLoraConfig. I believe that aLoRA really only makes sense for CausalLMs, hence I've only implemented this for that model type. Merging doesn't make sense for aLoRA adapters since the weights are not universally applied to all tokens. I used the LoRA code as a starting point, but did not implement various seemingly extra features in that code. As of now, invocation_string should probably start and end with special tokens, to avoid tokenizer issues at the boundary. Open to suggestions on how to make this more general if needed.

---------

Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
1 parent a3197b1 commit 293aea5

14 files changed: +1188, -101 lines

docs/source/developer_guides/lora.md

Lines changed: 102 additions & 0 deletions
@@ -173,6 +173,108 @@ from peft import LoraConfig
config = LoraConfig(use_rslora=True, ...)
```

### Activated LoRA (aLoRA)

Activated LoRA (aLoRA) is a low-rank adapter architecture for Causal LMs that allows reusing the existing base model KV cache for more efficient inference. This approach is best suited for inference pipelines which rely on the base model for most tasks/generations, but use aLoRA adapter(s) to perform specialized task(s) within the chain, for example checking or correcting the generated outputs of the base model. In these settings, inference times can be sped up by an order of magnitude or more. For more information on aLoRA and many example use cases, see https://huggingface.co/papers/2504.12397.

This technique scans for the last occurrence of an invocation sequence (`alora_invocation_tokens`) in each input (this can be as short as 1 token), and activates the adapter weights on tokens starting from the beginning of the invocation sequence (any inputs after the invocation sequence are also adapted, and all generated tokens will use the adapted weights). Weights on prior tokens are left un-adapted, which makes the cache for those tokens interchangeable with the base model cache due to the causal attention mask in Causal LMs. Usage is very similar to standard LoRA, with the key difference that this invocation sequence must be specified when the adapter is created:

```py
from peft import LoraConfig

config = LoraConfig(alora_invocation_tokens=alora_invocation_tokens, task_type="CAUSAL_LM", ...)
```

where `alora_invocation_tokens` is a list of integer token ids. Given a desired invocation string, this can be obtained as

```py
invocation_string = "placeholder"
alora_invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)
```

where the tokenizer is the tokenizer for the base model. Note that we pass `add_special_tokens=False` to avoid adding BOS/EOS tokens to the search string, which would most likely cause it not to be found.
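To build intuition for which tokens get adapted, here is a minimal sketch (not PEFT's internal implementation; `find_alora_offset`, `prompt`, and the surrounding variables are illustrative) of how the activation point can be located by scanning a tokenized input for the last occurrence of `alora_invocation_tokens`:

```py
def find_alora_offset(input_ids, invocation_tokens):
    """Return the start index of the last occurrence of invocation_tokens, or None if absent."""
    n = len(invocation_tokens)
    for start in range(len(input_ids) - n, -1, -1):  # scan backwards to find the last match
        if input_ids[start : start + n] == invocation_tokens:
            return start
    return None

input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0].tolist()
offset = find_alora_offset(input_ids, alora_invocation_tokens)
# Tokens before `offset` keep the base model weights (their KV cache matches the base model);
# tokens from `offset` onward, and all generated tokens, use the adapted weights.
```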
**Notes**
* aLoRA is only supported for `task_type=CAUSAL_LM` tasks due to its focus on cache reuse.
* Since the weights are adapted on fewer tokens, aLoRA often (though not always) requires a higher rank (`r`) than LoRA. `r=32` can be a good starting point.
* aLoRA weights cannot be merged into the base model by definition, since the adapter weights are selectively applied to a subset of tokens. Attempts to merge will throw errors.
* Beam search is not yet supported.
* It is generally not recommended to add new tokens to the tokenizer that are not present in the base model, as this can complicate the target use case of having both the base model and the adapter model operate on overlapping context. That said, there is a possible workaround: first efficiently add [trainable tokens](https://huggingface.co/docs/peft/en/package_reference/trainable_tokens) to the base model prior to training the adapter.

#### Choice of invocation sequence and SFT design

Each input must contain the `alora_invocation_tokens` sequence; it is not added automatically. To maximize model performance without compromising cache reuse, it is recommended to activate the adapter weights early, i.e. at the start of any adapter-specific prompting, but after any long inputs such as prior generations or documents. As with any model, formatting should be consistent between train and test.

Consider the following example, where the base model has a chat template, and the goal is to train the adapter to generate a desired output.

* Option 1: If there is no task-specific prompt, i.e. the input is a chat history ending with the `assistant` prompt, then the chat template's `assistant` prompt (e.g. `<|start_of_role|>assistant<|end_of_role|>`) is a natural choice for the invocation string. See the model's chat template to find the prompt for the model (a sketch for extracting these tokens follows this list).
* Option 2: If there is a task-specific prompt for the adapter that describes the task the adapter is learning, and that prompt is put as a `user` turn immediately prior to the generation, then the chat template's `user` prompt (e.g. `<|start_of_role|>user<|end_of_role|>`) is a natural choice for the invocation string.
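For Option 1, instead of hard-coding the assistant prompt string, its token ids can be recovered from the chat template itself. The following is a minimal sketch, assuming the template appends the generation prompt at the very end and using a placeholder `messages` list:

```py
messages = [{"role": "user", "content": "placeholder"}]

# Tokenize the chat twice: once without and once with the generation (assistant) prompt.
ids_without = tokenizer.apply_chat_template(messages, add_generation_prompt=False)
ids_with = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# The trailing tokens added by `add_generation_prompt=True` are the assistant prompt,
# which can be used directly as the invocation sequence.
alora_invocation_tokens = ids_with[len(ids_without):]
```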
Once you have decided on an invocation string, get the model tokenizer and obtain `alora_invocation_tokens` as

```py
alora_invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)
```

An example inference setup is at [alora finetuning](https://github.com/huggingface/peft/blob/main/examples/alora_finetuning/alora_finetuning.py).
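For quick experimentation, a minimal load-and-generate sketch is shown below, where the base model name, the adapter path, and `INVOCATION_STRING` are placeholders (the invocation string must decode to the adapter's `alora_invocation_tokens`):

```py
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = AutoModelForCausalLM.from_pretrained("BASE_MODEL_NAME", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("BASE_MODEL_NAME")
model_alora = PeftModel.from_pretrained(base_model, "PATH_TO_ALORA_ADAPTER")

# The invocation sequence must appear in the input, otherwise the adapter never activates.
prompt = "Some context or chat history..." + INVOCATION_STRING
inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
with torch.no_grad():
    output = model_alora.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```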
**Note** If using custom strings for the invocation string, make sure that the start and end of the string are special tokens to avoid issues with tokenization at the boundaries.

To see why, imagine that 'a', 'b', 'c', and 'ab' are tokens in your tokenizer (with ids 1, 2, 3, 4 respectively). Suppose that your `alora_invocation_tokens = [2, 3]`. Now imagine your input string is "abc". Because "ab" is a token, this will get tokenized as [4, 3]. So the `alora_invocation_tokens` will fail to be found, despite the string "bc" being present in the input. If the start and end of the invocation string are special tokens, however, this failure case will never happen, since special tokens are never merged into the same token with other characters.
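As a concrete illustration of this failure mode, using purely hypothetical token ids:

```py
# Hypothetical vocabulary: "a" -> 1, "b" -> 2, "c" -> 3, "ab" -> 4
alora_invocation_tokens = [2, 3]  # intended invocation string "bc"
input_ids = [4, 3]                # "abc" tokenizes as ["ab", "c"], not ["a", "b", "c"]

found = any(
    input_ids[i : i + len(alora_invocation_tokens)] == alora_invocation_tokens
    for i in range(len(input_ids))
)
print(found)  # False: the invocation sequence is never matched, so the adapter never activates
```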
#### Using (and reusing) cache for generation
The main purpose of Activated LoRA is to make the KV cache interchangeable between the base model and aLoRA adapter models **prior to the invocation sequence**, since base and adapted KV values are not compatible. Specifically, keys and values stored during one model generation can be reused in subsequent generations to avoid expensive prefill operations for context tokens. When sharing cache between the base model and aLoRA adapters, there are two main patterns:
1. The base model has generated something, and an aLoRA adapter is then called to do a followup generation. Example: the base model answers a question, and an aLoRA trained to detect hallucinations checks the base model response.
2. An aLoRA adapter has generated something, and the base model or a different aLoRA adapter is called to do a followup generation where there is partial context overlap with the original aLoRA. Example: the user provides a query, and an aLoRA rewrites the query to be more self-contained and improve retrieval in a RAG system. Then, documents are retrieved and loaded into context; an aLoRA checks if these documents are indeed relevant to the question, and then the base model generates an answer.

To demonstrate the above behaviors when using caching, we're using [DynamicCache](https://huggingface.co/docs/transformers/en/kv_cache) from `transformers`. Care must be taken to ensure that adapted cache values are not mixed with base cache values. In particular, an extra step is required for sharing the cache when there is partial context overlap (pattern 2).

**Pattern 1: Base model followed by aLoRA** Here, the entire input and generation from the base model are passed to the aLoRA adapter, along with the invocation sequence:
```py
from transformers import DynamicCache
...
cache = DynamicCache()
inputs_base = tokenizer(prompt_base, return_tensors="pt").to(device)

# Generate from base model and save cache
with model_alora.disable_adapter():
    output = model_alora.generate(
        inputs_base["input_ids"],
        attention_mask=inputs_base["attention_mask"],
        past_key_values=cache,
        return_dict_in_generate=True,
    )
output_text_base = tokenizer.decode(output.sequences[0])
cache = output.past_key_values

# Generate with aLoRA adapter from cache
prompt_alora = output_text_base + INVOCATION_STRING
inputs_alora = tokenizer(prompt_alora, return_tensors="pt").to(device)
output = model_alora.generate(**inputs_alora, past_key_values=cache)
output_text_alora = tokenizer.decode(output[0])

# Note: cache is now tainted with adapter values and cannot be used in base model from here on!
```
**Pattern 2: aLoRA generation followed by base model (or another aLoRA) with partial context overlap** Here, we prefill the shared context using the base model, and then generate.
```py
from transformers import DynamicCache
import copy
...
cache = DynamicCache()
inputs_shared = tokenizer(prompt_shared, return_tensors="pt").to(device)

# Prefill from base model and save cache
with model_alora.disable_adapter():
    with torch.no_grad():
        model_alora(**inputs_shared, past_key_values=cache)
cache_copy = copy.deepcopy(cache)

# Generate from aLoRA using prefilled cache
prompt_alora = prompt_shared + INVOCATION_STRING
inputs_alora = tokenizer(prompt_alora, return_tensors="pt").to(device)
output = model_alora.generate(**inputs_alora, past_key_values=cache)
output_text_alora = tokenizer.decode(output[0])

# Generate from base model using saved cache not tainted by aLoRA KV values
prompt_base = prompt_shared
inputs_base = tokenizer(prompt_base, return_tensors="pt").to(device)
with model_alora.disable_adapter():
    output = model_alora.generate(**inputs_base, past_key_values=cache_copy)
output_text_base = tokenizer.decode(output[0])
```
### Weight-Decomposed Low-Rank Adaptation (DoRA)
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# Activated LoRA (aLoRA)

## Introduction
Activated LoRA (aLoRA) is an adapter that selectively activates its weights only after a given invocation sequence, ensuring that hidden states match the base model prior to this point. This allows reusing the base model KVs (stored in the KV cache) for tokens before the invocation, enabling much faster real-world inference (e.g. with vLLM) when switching between generation with the base model and generation with adapters.
See the [paper](https://huggingface.co/papers/2504.12397) for more details.

## Quick start (shown for Mistral 7B)
```python
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
dataset = load_dataset("Lots-of-LoRAs/task1660_super_glue_question_generation", split="train")

invocation_string = "[/INST]"  # End of user turn in Mistral chat template
invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    alora_invocation_tokens=invocation_tokens,
    r=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
)

peft_model = get_peft_model(model, lora_config)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# dataset_text_field/max_seq_length are SFT options, so trl's SFTTrainer is used here
# instead of the plain transformers Trainer, which does not accept these arguments.
training_args = SFTConfig(output_dir="alora-mistral-7b", dataset_text_field="text", max_seq_length=2048)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
)
trainer.train()
peft_model.save_pretrained("alora-mistral-7b")
```
### Use the training example script directly
Pass the invocation string with `--invocation_string` when running the training example script. For Mistral 7B, do:
```bash
python examples/alora_finetuning/alora_finetuning.py --base_model mistralai/Mistral-7B-Instruct-v0.3 --data_path Lots-of-LoRAs/task1660_super_glue_question_generation --invocation_string "[/INST]"
```
and similarly for Llama-3.2-3B-Instruct:
```bash
python examples/alora_finetuning/alora_finetuning.py --base_model meta-llama/Llama-3.2-3B-Instruct --data_path Lots-of-LoRAs/task1660_super_glue_question_generation --invocation_string "<|start_header_id|>assistant<|end_header_id|>"
```
### Full example of the script
```bash
python alora_finetuning.py \
    --base_model "PATH_TO_MODEL" \
    --data_path "PATH_TO_DATASET" \
    --output_dir "PATH_TO_OUTPUT_DIR" \
    --batch_size 1 \
    --num_epochs 3 \
    --learning_rate 3e-4 \
    --cutoff_len 512 \
    --val_set_size 500 \
    --invocation_string "[/INST]" \
    --quantize \
    --eval_step 10 \
    --save_step 100 \
    --device "cuda:0" \
    --lora_r 32 \
    --lora_alpha 32 \
    --lora_dropout 0.05 \
    --lora_target_modules "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj" \
    --hub_model_id "YOUR_HF_REPO" \
    --push_to_hub
```
