Support for Activated LoRA (Issue https://github.com/huggingface/peft/issues/2523) #2609
Conversation
Hey @kgreenewald, thank you so much for your PR!
I've skimmed the PR and have a few general questions to steer the implementation before we dig deeper. As far as I understand, there are fundamentally two things we need:
1. a `forward` method that understands when to activate the adapter and when not to (according to indices or binary values, i.e. a mask)
2. a way to generate the aforementioned mask from input ids
For (2) we need to hook into the relevant PeftModel instances; there's no way around that. But for (1) I don't think we need a full-fledged PEFT method. I think we can manage with a LoRA variant that just implements the forward call for the Linear layer. This way we could reduce the number of changes in this PR by a lot.
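For illustration, a minimal sketch (not the PR's actual code) of what such a masked variant forward could look like; `alora_mask`, `lora_A`, `lora_B`, and `scaling` are illustrative names:

```python
import torch

def alora_linear_forward(base_out, x, lora_A, lora_B, scaling, alora_mask):
    # base_out: output of the underlying Linear layer, shape (batch, seq_len, out_features)
    # alora_mask: bool tensor of shape (batch, seq_len), True where the adapter should be active
    delta = lora_B(lora_A(x)) * scaling  # standard LoRA update
    return base_out + delta * alora_mask.unsqueeze(-1).to(delta.dtype)
```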
Secondly, I think it would be best to use token ids as the invocation sequence, not strings: 1) we would not have to pass the tokenizer around and do tokenization in unexpected places, and 2) we would determine the correct token ids once at training time instead of generating them during inference and hiding a possible mixup when using a newer/different tokenizer.
WDYT?
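As a rough sketch of that suggestion, assuming the config field ends up being called `alora_invocation_tokens` (as in the docs excerpt later in this thread); the invocation string, target modules, and the `tokenizer`/`base_model` objects are placeholders:

```python
from peft import LoraConfig, get_peft_model

invocation_string = "<|start_of_role|>check<|end_of_role|>"  # placeholder invocation string
# resolve the token ids once, at training time, with the training tokenizer
alora_invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)

config = LoraConfig(
    r=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    alora_invocation_tokens=alora_invocation_tokens,  # token ids, not the raw string
)
peft_model = get_peft_model(base_model, config)
```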
Regardless of these comments, I think it would be good to rebase onto main, since #2579 fixes the CI, and to implement a few tests. We can start by adding a test case in tests/test_custom_models.py, which already gives good test coverage and is generally quite fast to iterate on.
Thanks @githubnemo ! Responses below:
Cool!
Right, the hooks are added in the `_enable_peft_forward_hooks` context manager:

```python
@contextmanager
def _enable_peft_forward_hooks(self, *args, **kwargs):
    hook_handles = []

    # If adapter_names is passed as an argument, we inject it into the forward arguments.
    adapter_names = kwargs.pop("adapter_names", None)
    if adapter_names:
        if self.training:
            raise ValueError(...)
        # [rest of adapter_names handling in _enable_peft_forward_hooks]

    for config in self.peft_config.values():
        if config.use_alora:
            ...  # [extending hook_handles with alora hook handlers]

    yield

    for handle in hook_handles:
        handle.remove()
```
I think it is fair to have an invocation tokens config field that defaults to …
Commit: …-tuners/lora/variants.py — Refactor aLoRA to use variant
@githubnemo Great!! I went through and did these changes, including going ahead and doing the "mask" refactor for the alora forward in variants.py. EDIT - how should I get these CI checks working right on github? Seems there's some token etc I need to provide? Sorry if I'm missing something obvious :)
No worries, the CI is configured so that the tests only run after a maintainer review (for security reasons). You can run these tests locally as well (e.g. with `pytest tests/test_custom_models.py`).
I'll trigger a CI run now. If it runs through I'll do a review, otherwise I'll give you a ping :) edit: the CI run looks OK, doing a review now
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Good work! Three little comments, everything else looks fine now.
One thing I'm still unsure about is caching. If I understand correctly - correct me if I'm wrong - we're not supposed to cache the KV values that occur while the adapter is active. This would only matter for caches that are shared between generate invocations, since we don't return to the base model after adapter invocation.
I think we would still need to handle the case in PeftModel.generate of receiving a Cache class in kwargs['past_key_values']. I think we'd need to truncate the KV cache after the call to base_model.generate to the point before the adapter activation, but I'm not yet sure how to do this properly. If you've got any ideas, please let me know. For the time being I'd also be fine with raising an error stating that external caching + aLoRA is not supported right now.
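For reference, a hedged sketch of what such truncation could look like when the external cache is a `transformers` `DynamicCache` (which exposes a `crop` method in recent versions); other cache implementations may not support this, which is part of the problem:

```python
from transformers import DynamicCache

def truncate_to_pre_adapter(cache: DynamicCache, pre_adapter_len: int) -> DynamicCache:
    # Drop all KV entries past the position where the adapter became active,
    # so the remaining cache again matches the base model.
    cache.crop(pre_adapter_len)
    return cache
```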
| """ | ||
| This is a helper function for Activated LoRA (aLoRA) that searches each input token sequence for the last occurence | ||
| of the appropriate "alora_invocation_tokens" invocation sequence. If adapter_names is passed, then each input uses | ||
| the appropriate invocation sequence for the specified adapter for that row. Logic is provided to handle mixed | ||
| collections of adapters for which not all are aLoRAs (e.g. some base model, some LoRA). If the invocation sequence | ||
| is not present, the corresponding alora_offset is set to None and a warning is printed. | ||
| """ |
Let's remove the outdated warning portion in this doc string.
Let's also document the nature of the offset (i.e. that it is starting from the end of the sequence). I don't think this fact is documented yet.
src/peft/tuners/lora/variants.py (outdated):
```python
# input. As a result, the weights should not be activated anywhere (equivalent to base model).
# Convert None -> 0 and clip to [0, T]
offsets = torch.tensor(
    [0 if o is None else max(1, min(int(o), T)) for o in alora_offsets],
```
I don't think that having the min/max clip here is actually helping. If an invalid value is passed we're effectively randomly activating the adapter instead of raising an error. WDYT about removing the clipping?
I got rid of the low-side clip. The high-side clip is needed for generation: with use_cache=True (the default), for the 2nd and following generated tokens you'll just have T = 1 while the alora_offsets remain larger.
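For context, an illustrative (not verbatim) sketch of how the per-row offsets, counted from the end of the sequence, translate into an activation mask; entries of `None` mean the adapter stays inactive for that row:

```python
import torch

def offsets_to_mask(alora_offsets, T):
    # offset k means: activate the adapter on the last k positions of the row
    offsets = torch.tensor([0 if o is None else min(int(o), T) for o in alora_offsets])
    positions = torch.arange(T).unsqueeze(0)         # shape (1, T)
    return positions >= (T - offsets).unsqueeze(1)   # shape (batch, T), True where the adapter is active
```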
docs/source/developer_guides/lora.md (outdated):

> Activated LoRA (aLoRA) is a low rank adapter architecture for Causal LMs that allows for reusing existing base model KV cache for more efficient inference. This approach is best suited for inference pipelines which rely on the base model for most tasks/generations, but use aLoRA adapter(s) to perform specialized task(s) within the chain. For example, checking or correcting generated outputs of the base model. In these settings, inference times can be sped up by an order of magnitude or more. For more information on aLoRA and many example use cases, see https://huggingface.co/papers/2504.12397.
>
> This technique scans for the last occurrence of an invocation sequence (`alora_invocation_tokens`) in each input (this can be as short as 1 token), and activates the adapter weights on tokens starting 1 token after the beginning of the invocation sequence. Weights on prior tokens are left un-adapted -- making the cache for those tokens interchangeable with base model cache due to the causal attention mask in Causal LMs. Usage is very similar to standard LoRA, with the key difference that this invocation sequence must be specified when the adapter is created:
Since it's possible to have input following the invocation sequence it would be good to know what happens to that input (i.e. that it will be processed by the adapter).
Added this!
@githubnemo This is an interesting and subtle question - it is definitely possible/desired to use the cache for generation, but as you point out there are a lot of ways to do things incorrectly. Here's a rough example of how you can do things in a multi-step pipeline where first a "safety" adapter is called on the input, second the base model generates from the input, then two adapters check that same input. Note that each step re-uses cache. For the last checks, feeding in the cache from the base model generation just works and is straightforward. But for the first "safety" turn, the base model hasn't done anything yet, yet we want to create cache that the base model call can use. So I first do a prefill operation on the input using the base model to create an input cache, which is then used both by "safety" and "base". NOTE - this code snippet is from a prior implementation, you might notice some slight differences like passing in alora_offsets directly. The cache handling should still apply, though.

```python
import copy

import torch
from transformers import DynamicCache

# Assumes model_alora, tokenizer, device, question_chat, system_prompt, question,
# safety_prompt, and tokenize_alora are already defined earlier in the script.

prompt_cache = DynamicCache()
input_text = tokenizer.apply_chat_template(question_chat, tokenize=False, add_generation_prompt=False)
inputs = tokenizer(input_text, return_tensors="pt")

# Prefill the input with the base model to create an input cache shared by "safety" and "base"
with model_alora.disable_adapter():
    with torch.no_grad():
        prompt_cache = model_alora(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device), past_key_values=prompt_cache).past_key_values

# Generate safety exception
input_safety, alora_offsets = tokenize_alora(tokenizer, input_text, safety_prompt)
past_key_values = copy.deepcopy(prompt_cache)
output = model_alora.generate(input_safety["input_ids"].to(device), attention_mask=input_safety["attention_mask"].to(device), use_cache=True, max_new_tokens=10, return_dict_in_generate=True, past_key_values=past_key_values, alora_offsets=alora_offsets)
output_text = tokenizer.decode(output.sequences[0])
answer = output_text.split(safety_prompt)[-1]
print("Safety: " + answer)

question_chat = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": question},
]

# Generate answer with base
input_text = tokenizer.apply_chat_template(question_chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(input_text, return_tensors="pt")
with model_alora.disable_adapter():
    output = model_alora.generate(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device), max_new_tokens=160, past_key_values=prompt_cache, return_dict_in_generate=True)
output_text = tokenizer.decode(output.sequences[0])
prompt_cache = output.past_key_values
answer = output_text.split("assistant<|end_of_role|>")[-1]
print("Answer: " + answer)

# Generate certainty score
uq_generation_prompt = "<|start_of_role|>certainty<|end_of_role|>"
uq_chat = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": question},
    {"role": "assistant", "content": answer},
]
uq_text = tokenizer.apply_chat_template(uq_chat, tokenize=False)
inputs, alora_offsets = tokenize_alora(tokenizer, uq_text, uq_generation_prompt)
model_alora.set_adapter("certainty")
answer_KV = copy.deepcopy(prompt_cache)
output = model_alora.generate(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device), max_new_tokens=6, past_key_values=answer_KV, return_dict_in_generate=True, alora_offsets=alora_offsets)
output_text = tokenizer.decode(output.sequences[0])
print("Certainty: " + output_text.split("certainty<|end_of_role|>")[-1])

# Hallucination
model_alora.set_adapter("hallucination")
hall_prompt = "<|start_of_role|>hallucination<|end_of_role|>"
uq_text = tokenizer.apply_chat_template(uq_chat, tokenize=False)
inputs, alora_offsets = tokenize_alora(tokenizer, uq_text, hall_prompt)
output = model_alora.generate(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device), max_new_tokens=5, past_key_values=copy.deepcopy(prompt_cache), return_dict_in_generate=True, alora_offsets=alora_offsets)
output_text = tokenizer.decode(output.sequences[0])
print("Hallucination: " + output_text.split("hallucination<|end_of_role|>")[-1])
```

Regarding truncating the created cache in .generate, I agree that this is probably the right thing to do in most cases, but from past attempts I recall that doing the truncation is not terribly well supported. I've done it, but it involved converting to a legacy cache and doing operations on the data itself. This hack probably wouldn't be general enough, e.g. with Mamba caches and whatnot. As for how to handle this - I think it's pretty important to still allow a cache to be provided as an argument. Maybe some possible solutions:

1. Leave as-is but document this point.
2. Deepcopy the input cache and just return that (and document this).
3. Give warnings for either of the above cases whenever a cache is passed in, stating that the resulting cache will be hybrid (possibly becoming spammy).
4. (Hardest, likely messy) Try to figure out how to truncate the cache.
Thanks for the detailed answer. I agree that (4) is definitely the hardest; I spent quite some time trying to find a good, common way and wasn't able to. While I like (2), I'm a bit wary of possible side-effects, so (1) - documenting, possibly with a slightly cleaned/compacted version of the code snippet you provided - seems to me the best option at the moment.
@githubnemo Ok, done! Put a discussion in lora.md in the developer guides. Tried to make the code snippets short, but definitely let me know if the style of these should be improved. I also have a sentence here mentioning that there are PRs to vLLM and llama.cpp that avoid having to do all this manual cache handling, hope that's ok!
Cool! I left a few suggestions regarding the cache example. Otherwise the PR LGTM.
Once that section is done we can merge :)
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
@githubnemo Thanks! Implemented these changes!
This looks good to go to me. Once the CI is green, I'll merge :) Thank you for seeing this implementation through!
There was a minor typo in a suggestion on PR #2609 which broke code formatting for one code sample. This is a simple fix for that.
Source: https://research.ibm.com/blog/inference-friendly-aloras-lora
Original Github: https://github.com/IBM/activated-lora
Paper: https://arxiv.org/pdf/2504.12397
This PR migrates Activated LoRA (aLoRA) support from a standalone GitHub repository (see above) into PEFT itself.
Note there is also an active PR for vLLM inference support for Activated LoRA: vllm-project/vllm#19710. There are also collections of aLoRA models on Hugging Face (in the ibm-granite org); note that these preexisting models run off of the standalone GitHub repo and will be updated to work with this new PEFT feature if merged.
Description of changes: Activated LoRA is a modification of the LoRA architecture that "activates" the adapter weights only on tokens coming after a specified invocation_string. As a result, the KV values for the string coming before the activation match the KV values of the base model. This allows the KV cache for the input to be interchangeable between the base model and the adapter model, and allows for major speedups in inference pipelines (e.g. agentic pipelines) that want to use both base models and adapter models. See the paper for a detailed exploration of use cases and further elaboration.
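To illustrate the mechanism, a simplified sketch (not the PR's actual helper): the invocation sequence is located by scanning the input for its last occurrence, and the adapter weights become active one token after its start.

```python
def find_activation_start(input_ids: list[int], invocation_tokens: list[int]) -> int | None:
    """Return the index from which adapter weights are active, or None if the invocation is absent."""
    n, m = len(input_ids), len(invocation_tokens)
    for start in range(n - m, -1, -1):  # scan backwards for the last occurrence
        if input_ids[start:start + m] == invocation_tokens:
            return start + 1            # active starting 1 token after the invocation begins
    return None
```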
We have tried to make the changes as non-intrusive as possible. The vast majority of changes are in a new folder, tuners/alora. Some changes were required to the universally used peft_model.py in order to handle processing of the invocation_string and determination of the activation point. We would be happy to hear any and all suggestions on how to best structure the changes, in this file especially, to keep the PR as non-intrusive as possible.
Other notes: