
Conversation

@kgreenewald (Contributor)

Source: https://research.ibm.com/blog/inference-friendly-aloras-lora
Original GitHub: https://github.com/IBM/activated-lora
Paper: https://arxiv.org/pdf/2504.12397

This PR migrates Activated LoRA (aLoRA) support from a standalone GitHub repository (see above) into PEFT itself.

Note there is also an active PR for vLLM inference support for Activated LoRA: vllm-project/vllm#19710. There are also collections of aLoRA models on Hugging Face (in the ibm-granite org); these preexisting models currently run off the standalone GitHub repo and will be updated to work with this new PEFT feature if merged.

Description of changes: Activated LoRA is a modification of the LoRA architecture that "activates" the adapter weights only on tokens coming after a specified invocation_string. As a result, the KV values for the portion of the input before the activation point match the KV values of the base model. This makes the KV cache for the input interchangeable between the base model and the adapter model, and allows for major speedups in inference pipelines (e.g. agentic pipelines) that want to use both base models and adapter models. See the paper for a detailed exploration of use cases and further elaboration.

We have tried to make the changes as non-intrusive as possible (but are happy to hear any suggestions for how the PR can be improved). The vast majority of changes are in a new folder tuners/alora. Some changes were required to the universally used peft_model.py in order to handle processing of the invocation_string and determination of the activation point; we would be happy to hear any and all suggestions on how to best structure the changes in this file to keep it as non-intrusive as possible.

Other notes:

  • The crux of the changes is really in layer.py. Everything else simply manages the alora_offsets quantity, which defines where the weights start to be activated. This is determined by scanning input strings for the invocation_string defined in the aLoraConfig (a rough sketch of this scan is included after this list).
  • I believe that aLoRA really only makes sense for CausalLMs, hence I've only implemented this for that model type.
  • Merging doesn't make sense for aLoRA adapters since the weights are not universally applied to all tokens.
  • I used the LoRA code as a starting point, but did not implement various seemingly extra features in that code. If I missed something important definitely let me know!
  • As of now, invocation_string should probably start and end with special tokens, to avoid tokenizer issues at the boundary. Open to suggestions on how to make this more general if needed.
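To make the scan concrete, here is a rough, hypothetical sketch of how the offset computation could look. compute_alora_offsets and its exact offset convention are illustrative, not the PR's actual helper:

import torch

def compute_alora_offsets(input_ids, invocation_tokens):
    """Illustrative only: find the last occurrence of the invocation token
    sequence in each row and report how many trailing tokens the adapter
    should cover (None if the invocation is absent)."""
    offsets = []
    n = len(invocation_tokens)
    needle = torch.tensor(invocation_tokens, dtype=input_ids.dtype, device=input_ids.device)
    for row in input_ids:  # row shape: (seq_len,)
        T = row.shape[0]
        found = None
        for start in range(T - n, -1, -1):  # scan backwards so the last occurrence wins
            if torch.equal(row[start : start + n], needle):
                # number of positions, counted from the end of the sequence, on which
                # the adapter weights are active (the real implementation's exact
                # convention may differ slightly)
                found = T - start
                break
        offsets.append(found)
    return offsets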


@githubnemo (Collaborator) left a comment

Hey @kgreenewald, thank you so much for your PR!

I've skimmed the PR and have a few general questions to steer the overall implementation before we dig deeper. As far as I understand, there are fundamentally two things that we need:

  1. a forward method that understands when to activate the adapter and when not to (according to index or binary values, i.e. a mask)
  2. a way to generate the aforementioned mask from input ids

For (2) we need to hook into the relevant PeftModel instances; there's no way around that. For (1), however, I don't think we need a full-fledged PEFT method: we can probably manage with a LoRA variant that just implements the forward call for the Linear layer. This way we could reduce the number of changes in this PR by a lot.
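For illustration, the variant idea for (1) could boil down to something like the following. This is a conceptual sketch only, not PEFT's actual LoraVariant interface, and all names here are made up:

import torch

def alora_linear_forward(x, base_linear, lora_A, lora_B, scaling, alora_mask):
    # x: (batch, seq_len, in_features); alora_mask: (batch, seq_len), 1 where the
    # adapter is active and 0 where the output must stay identical to the base model
    result = base_linear(x)
    delta = lora_B(lora_A(x)) * scaling
    return result + delta * alora_mask.unsqueeze(-1).to(delta.dtype)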

Secondly, I think it would be best to use token ids as the invocation sequence, not strings: 1) we would not have to pass the tokenizer around and do tokenization at unexpected places, and 2) we would determine the correct token ids once at training time instead of generating them during inference, which can hide a mixup when a newer/different tokenizer is used.
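(For example, determining the ids once at adapter-creation time could be as simple as the following; the model id and invocation string are just placeholders.)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.2-8b-instruct")  # placeholder model
invocation_string = "<|start_of_role|>certainty<|end_of_role|>"  # illustrative invocation
invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)
# the token ids, not the string, would then be stored with the adapter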

WDYT?

Regardless of these comments, I think it would be good to rebase onto main, since #2579 fixes the CI, and to implement a few tests. We can start by adding a test case in tests/test_custom_models.py, which already gives good test coverage and is generally quite fast to iterate on.

@kgreenewald (Contributor, Author)

Hey @kgreenewald, thank you so much for your PR!

I've skimmed the PR and have a few general questions to steer the general implementation before we dig deeper. As far as I understand there are fundamentally two things that we need

1. a `forward` method that understands when to activate the adapter and when not to (according to index or binary values, i.e. a mask)

2. a way to generate the aforementioned mask from input ids

for (2) we need to hook into the relevant PeftModel instances, there's no way around that but I think that for (1) we don't need a full-fledged PEFT method. I think we can manage with a LoRA variant that just implements the forward call for the Linear layer. This way we could reduce the number of changes in this PR by a lot.

Secondly, I think that it would be best to use the token ids as invocation sequence, not strings. 1) because we would not have to pass the tokenizer around and do tokenization at unexpected places but also because, 2) we will determine the correct token ids at training time once instead of generating them during inference and hiding a possible mixup when using a newer/different tokenizer.

WDYT?

Regardless of these comments, I think it would be good to rebase onto main, since #2579 fixes the CI, and to implement a few tests. We can start by adding a test case in tests/test_custom_models.py, which already gives good test coverage and is generally quite fast to iterate on.

Thanks @githubnemo ! Responses below:

  1. These parts 1 & 2 are correct, and the LoraVariant concept is very cool! I am worried that this might actually be more intrusive, however -- let me know if you disagree and think this is actually fine! (A) I'd need to change many of the main LoRA files in order for the created mask to actually get passed up the chain to the forward method, since (a) the LoraVariant has hardcoded arguments and (b) I would need to rewrite the forward hooks. (B) The invocation_tokens need to be stored in adapter_config.json; I imagine we might not want to create this field for all LoRA models?
    --- On the other hand, if you believe the masking procedure might be of general interest (even other masking strategies besides the one we currently use, possibly giving users the ability to specify their own masking functions), such that the above downsides are worth dealing with, I'm happy to make the above changes!

  2. "token ids as invocation sequence, not strings": Sounds good, I'll make this change.

  3. rebase & tests: will do!

@githubnemo (Collaborator)

Cool!

  1. These parts 1 & 2 are correct, and the LoraVariant concept is very cool! I actually am worried that this might be more intrusive however -- let me know if you disagree and think this is actually fine! (A) I'd need to change many of the main Lora files in order for the created mask to actually get passed up the chain to the forward method (since (a) the LoraVariant has hardcoded arguments and (b) would need to rewrite the forward hooks).

Right, LoraVariant.forward doesn't have a **kwargs parameter right now. Feel free to add it; it might be useful in the future anyway. There would need to be a specific forward hook for temporarily adding the aLoRA mask parameter to the forward calls, yes. The current implementation of _enable_peft_forward_hooks doesn't regard variants, and while it would not be a lot harder to build support for variants, I'd opt for implementing a special case for handling the alora offsets there. Something like

    @contextmanager
    def _enable_peft_forward_hooks(self, *args, **kwargs):
        hook_handles = []

        # If adapter_names is passed as an argument, we inject it into the forward arguments.
        adapter_names = kwargs.pop("adapter_names", None)
        if adapter_names:
            if self.training:
                raise ValueError("...")
            # [rest of adapter_names handling in _enable_peft_forward_hooks]

        for config in self.peft_config.values():
            if config.use_alora:
                ...  # [extend hook_handles with the alora hook handlers]

        yield

        for handle in hook_handles:
            handle.remove()

(B) The invocation_tokens need to be stored in the adapter_config.json, I imagine we might not want to create this field for all Lora models?
---On the other hand, if you believe the masking procedure might be of general interest (even other masking strategies besides the one we currently do, possibly giving users the ability to specify their own masking functions), such that the above downsides are worth dealing with, I'm happy to make the above changes!

I think it is fair to have an invocation tokens config field that defaults to None. This could also double as the signal for when to use Activated LoRA.
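Roughly along these lines -- a sketch only, reusing the alora_invocation_tokens name that appears later in this thread; the real config class is PEFT's LoRA config:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ConfigSketch:
    r: int = 8
    lora_alpha: int = 16
    # None (default) -> plain LoRA; a list of token ids -> Activated LoRA
    alora_invocation_tokens: Optional[List[int]] = None

    @property
    def use_alora(self) -> bool:
        return self.alora_invocation_tokens is not None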

@kgreenewald (Contributor, Author) commented Sep 2, 2025

I think except for a few missing / incomplete tests we're almost done.

Since this PR also provides a bitsandbytes implementation, make sure to add a test in test_gpu_examples.py (similar to the Eva test, roughly testing inference with bitsandbytes on one of the smaller models there, like OPT). Once these tests are done I'd let the CI run and if everything's green we can merge this :)

Make sure to keep up with main and run make style.

@githubnemo Great!! I went through and did these changes, including going ahead and doing the "mask" refactor for the alora forward in variants.py.

EDIT - how should I get these CI checks working right on github? Seems there's some token etc I need to provide? Sorry if I'm missing something obvious :)

@githubnemo (Collaborator) commented Sep 2, 2025

EDIT - how should I get these CI checks working right on github? Seems there's some token etc I need to provide? Sorry if I'm missing something obvious :)

No worries, the CI is configured so that the tests only run after a maintainer review (for security reasons). You can run these tests locally as well:

  • make style from the root directory for linting / formatting
  • pytest tests/ to run all tests
  • pytest -k alora tests/ to run only the tests with 'alora' in their name (might not be exhaustive)

I'll trigger a CI run now. If it runs through, I'll do a review; otherwise I'll give you a ping :)

edit: CI run looks OK, doing a review now

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@githubnemo (Collaborator) left a comment

Good work! Three little comments, everything else looks fine now.

One thing I'm still unsure about is caching. If I understand correctly - correct me if I'm wrong - we're not supposed to cache the KV values that occur while the adapter is active. This would only matter for caches that are shared between generate invocations, since we don't return to the base model after adapter invocation.

I think we would still need to handle the case in PeftModel.generate of receiving a Cache class in kwargs['past_key_values']. I think we'd need to truncate the KV cache after the call to base_model.generate to the point before the adapter activation. But I'm not yet sure how to do this properly. If you've got any ideas, please let me know. I'd be fine to raise an error instead for the time being that external caching + aLoRA is not supported right now.
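(One possible direction, assuming the installed transformers version exposes DynamicCache.crop; this is only a sketch of the truncation idea, not something the PR implements:)

from transformers import DynamicCache

def truncate_shared_cache(cache, prefix_length):
    # cut the shared cache back to the last position produced with base-model
    # weights before handing it to another generate() call
    if isinstance(cache, DynamicCache):
        cache.crop(prefix_length)
        return cache
    raise NotImplementedError("truncation is only sketched here for DynamicCache")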

Comment on lines 511 to 517
"""
This is a helper function for Activated LoRA (aLoRA) that searches each input token sequence for the last occurrence
of the appropriate "alora_invocation_tokens" invocation sequence. If adapter_names is passed, then each input uses
the appropriate invocation sequence for the specified adapter for that row. Logic is provided to handle mixed
collections of adapters for which not all are aLoRAs (e.g. some base model, some LoRA). If the invocation sequence
is not present, the corresponding alora_offset is set to None and a warning is printed.
"""
githubnemo (Collaborator):

Let's remove the outdated warning portion in this doc string.

Let's also document the nature of the offset (i.e. that it is starting from the end of the sequence). I don't think this fact is documented yet.

# input. As a result, the weights should not be activated anywhere (equivalent to base model).
# Convert None -> 0 and clip to [0, T]
offsets = torch.tensor(
[0 if o is None else max(1, min(int(o), T)) for o in alora_offsets],
githubnemo (Collaborator):

I don't think that having the min/max clip here is actually helping. If an invalid value is passed we're effectively randomly activating the adapter instead of raising an error. WDYT about removing the clipping?

kgreenewald (Contributor, Author):

I got rid of the low-side clip. The high-side clip is needed for generation: if use_cache=True (the default), then for the 2nd and following generated tokens you'll just have T = 1 while alora_offsets remain larger.


Activated LoRA (aLoRA) is a low-rank adapter architecture for Causal LMs that allows reusing the existing base model KV cache for more efficient inference. This approach is best suited for inference pipelines which rely on the base model for most tasks/generations but use aLoRA adapter(s) to perform specialized task(s) within the chain, for example checking or correcting generated outputs of the base model. In these settings, inference times can be sped up by an order of magnitude or more. For more information on aLoRA and many example use cases, see https://huggingface.co/papers/2504.12397.

This technique scans for the last occurrence of an invocation sequence (`alora_invocation_tokens`) in each input (this can be as short as 1 token), and activates the adapter weights on tokens starting 1 token after the beginning of the invocation sequence. Weights on prior tokens are left un-adapted -- making the cache for those tokens interchangeable with base model cache due to the causal attention mask in Causal LMs. Usage is very similar to standard LoRA, with the key difference that this invocation sequence must be specified when the adapter is created:
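The excerpt leads into a creation example along these lines. This is a hedged sketch: apart from alora_invocation_tokens, whose name appears in this PR, treat the exact call signature as an assumption rather than a quote of the merged docs, and the model id and invocation string as placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_id = "ibm-granite/granite-3.2-8b-instruct"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

invocation_string = "<|start_of_role|>check<|end_of_role|>"  # illustrative invocation
alora_invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)

config = LoraConfig(
    r=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    alora_invocation_tokens=alora_invocation_tokens,  # marks this adapter as an aLoRA
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, config)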
githubnemo (Collaborator):

Since it's possible to have input following the invocation sequence, it would be good to know what happens to that input (i.e. that it will be processed by the adapter).

kgreenewald (Contributor, Author):

Added this!

@kgreenewald (Contributor, Author)

Good work! Three little comments, everything else looks fine now.

One thing I'm still unsure about is caching. If I understand correctly - correct me if I'm wrong - we're not supposed to cache the KV values that occur when the adapter is active. This would only matter caches that are shared between generate invocations since we don't return to the base model after adapter invocation.

I think we would still need to handle the case in PeftModel.generate of receiving a Cache class in kwargs['past_key_values']. I think we'd need to truncate the KV cache after the call to base_model.generate to the point before the adapter activation. But I'm not yet sure how to do this properly. If you've got any ideas, please let me know. I'd be fine to raise an error instead for the time being that external caching + aLoRA is not supported right now.

@githubnemo This is an interesting and subtle question - it is definitely possible (and desired) to use the cache for generation, but as you point out there are a lot of ways to do things incorrectly. Here's a rough example of a multi-step pipeline: first a "safety" adapter is called on the input, second the base model generates from the input, and then two adapters check that same input. Note that each step re-uses cache. For the last checks, feeding in the cache from the base model generation just works and is straightforward. But for the first "safety" turn, the base model hasn't done anything yet, even though we want to create cache that the base model call can later use. So I first do a prefill operation on the input using the base model to create an input cache, which is then used both by "safety" and "base".

NOTE - this code snippet is from a prior implementation, so you might notice some slight differences, like passing in alora_offsets directly. The cache handling should still apply, though.

# Assumed already defined elsewhere (prior-implementation helpers): model_alora,
# tokenizer, device, system_prompt, question, safety_prompt, tokenize_alora.
# The "safety" aLoRA is assumed to be the currently active adapter at the start.
import copy

import torch
from transformers import DynamicCache

question_chat = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": question},
]

# Prefill a shared input cache with the *base* model (adapter disabled) so it can
# be reused both by the "safety" aLoRA and by the base generation below.
prompt_cache = DynamicCache()
input_text = tokenizer.apply_chat_template(question_chat, tokenize=False, add_generation_prompt=False)  # + safety_prompt
inputs = tokenizer(input_text, return_tensors="pt")
with model_alora.disable_adapter():
    with torch.no_grad():
        prompt_cache = model_alora(
            inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            past_key_values=prompt_cache,
        ).past_key_values

# Generate safety exception with the "safety" aLoRA, reusing a copy of the input cache
input_safety, alora_offsets = tokenize_alora(tokenizer, input_text, safety_prompt)
past_key_values = copy.deepcopy(prompt_cache)
output = model_alora.generate(
    input_safety["input_ids"].to(device),
    attention_mask=input_safety["attention_mask"].to(device),
    use_cache=True,
    max_new_tokens=10,
    return_dict_in_generate=True,
    past_key_values=past_key_values,
    alora_offsets=alora_offsets,
)
output_text = tokenizer.decode(output.sequences[0])
answer = output_text.split(safety_prompt)[-1]
print("Safety: " + answer)

# Generate answer with the base model (adapter disabled), reusing and extending the input cache
input_text = tokenizer.apply_chat_template(question_chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(input_text, return_tensors="pt")
with model_alora.disable_adapter():
    output = model_alora.generate(
        inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        max_new_tokens=160,
        past_key_values=prompt_cache,
        return_dict_in_generate=True,
    )
output_text = tokenizer.decode(output.sequences[0])
prompt_cache = output.past_key_values
answer = output_text.split("assistant<|end_of_role|>")[-1]
print("Answer: " + answer)

# Generate certainty score with the "certainty" aLoRA, reusing a copy of the base model's cache
uq_generation_prompt = "<|start_of_role|>certainty<|end_of_role|>"
uq_chat = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": question},
    {"role": "assistant", "content": answer},
]
uq_text = tokenizer.apply_chat_template(uq_chat, tokenize=False)  # + uq_generation_prompt
inputs, alora_offsets = tokenize_alora(tokenizer, uq_text, uq_generation_prompt)
model_alora.set_adapter("certainty")
answer_KV = copy.deepcopy(prompt_cache)
output = model_alora.generate(
    inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
    max_new_tokens=6,
    past_key_values=answer_KV,
    return_dict_in_generate=True,
    alora_offsets=alora_offsets,
)
output_text = tokenizer.decode(output.sequences[0])
print("Certainty: " + output_text.split("certainty<|end_of_role|>")[-1])

# Hallucination check with the "hallucination" aLoRA, again on a copy of the base model's cache
model_alora.set_adapter("hallucination")
hall_prompt = "<|start_of_role|>hallucination<|end_of_role|>"
uq_text = tokenizer.apply_chat_template(uq_chat, tokenize=False)
inputs, alora_offsets = tokenize_alora(tokenizer, uq_text, hall_prompt)
output = model_alora.generate(
    inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
    max_new_tokens=5,
    past_key_values=copy.deepcopy(prompt_cache),
    return_dict_in_generate=True,
    alora_offsets=alora_offsets,
)
output_text = tokenizer.decode(output.sequences[0])
print("Hallucination: " + output_text.split("hallucination<|end_of_role|>")[-1])

Regarding truncating the created cache in .generate, I agree that this is probably the right thing to do in most cases, but from past attempts I recall that doing the truncation is not terribly well supported. I've done it, but it involved converting to the legacy cache format and doing operations on the data itself. That hack probably wouldn't be general enough, e.g. with Mamba caches and whatnot.

For how to handle this - I think it's pretty important to still allow a cache to be provided as an argument. Some possible solutions:

  1. Leave as-is but document this point.
  2. Deepcopy the input cache and just return that (and document this).
  3. Warn whenever a cache is passed in that the resulting cache will be hybrid (possibly becoming spammy).
  4. (Hardest, likely messy) Try to figure out how to truncate the cache.

@githubnemo (Collaborator)

@githubnemo This is an interesting and subtle question - it is definitely possible (and desired) to use the cache for generation, but as you point out there are a lot of ways to do things incorrectly. [...]

For how to handle this - I think it's pretty important to still allow a cache to be provided as an argument. Some possible solutions:

  1. Leave as-is but document this point.
  2. Deepcopy the input cache and just return that (and document this).
  3. Warn whenever a cache is passed in that the resulting cache will be hybrid (possibly becoming spammy).
  4. (Hardest, likely messy) Try to figure out how to truncate the cache.

Thanks for the detailed answer.

I agree that (4) is definitely the hardest; I spent quite some time trying to find a good, common way and wasn't able to. While I like (2), I'm a bit wary of possible side effects, so (1) - documenting, possibly with a slightly cleaned/compacted version of the code snippet you provided - seems to me the best option at the moment.

@kgreenewald (Contributor, Author)


Thanks for the detailed answer.

I agree that (4) is definitely the hardest; I spent quite some time trying to find a good, common way and wasn't able to. While I like (2), I'm a bit wary of possible side effects, so (1) - documenting, possibly with a slightly cleaned/compacted version of the code snippet you provided - seems to me the best option at the moment.

@githubnemo Ok, done! I put a discussion in lora.md in the developer guides. I tried to make the code snippets short, but definitely let me know if the style should be improved. I also added a sentence there mentioning that there are PRs to vLLM and llama.cpp that avoid having to do all this manual cache handling - hope that's ok!

@githubnemo (Collaborator) left a comment

Cool! I left a few suggestions regarding the cache example. Otherwise the PR LGTM.

Once that section is done we can merge :)

kgreenewald and others added 6 commits September 3, 2025 07:54
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
@kgreenewald (Contributor, Author)

Cool! I left a few suggestions regarding the cache example. Otherwise the PR LGTM.

Once that section is done we can merge :)

@githubnemo Thanks! Implemented these changes!

@githubnemo (Collaborator)

This looks good to go to me. Once the CI is green, I'll merge :)

Thank you for seeing this implementation through!

@githubnemo githubnemo merged commit 293aea5 into huggingface:main Sep 3, 2025
14 checks passed
githubnemo added a commit that referenced this pull request Sep 3, 2025
There was a minor typo in a suggestion from PR #2609 which broke code formatting for one code sample.

This is a simple fix for that.