-
Notifications
You must be signed in to change notification settings - Fork 10.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor lora adapter support #8332
Conversation
I don't think that this interface works for merging the loras in the weights, there is no reason to keep the lora tensors in memory after merging. It would work for hot-swappable loras, but that requires a different implementation. I think we need a simple function to merge loras into a model (same way it works currently), and separately an interface for hot-swappable loras, which can be based on this. Other notes:
Check my comment in the other PR regarding the performance. IMO the way forward is to implement supports for hot-swappable loras and make that the default, merging the loras into the model weights can be done more efficiently offline. |
Firstly, thanks for the directions. In fact, my idea of hot-swapping comes from this paragraph in the original paper: Maybe I'm not aware of other implementations than that. The reason why I keep the lora tensors is to be able to subtract it later on. But I can also add
Make sense though, since
This could be possible if (as you said) we have an implementation that doesn't modify loaded model's weights. So to be more clear, my proposal for the API is:
What do you think about that? |
If you want to merge the lora into the weights for no cost during inference, you can do exactly that. However, the Loras can also be used efficiently without merging by computing them as |
I think it is better to remove Note that applying the lora as For hot-swappable loras, it would also be good to have a |
Thanks for the explanation. Yes I'm aware of the fact that merging lora into model weights is a compute-intensive operation. But the
Another idea is to check if it's being transposed or not. If it already is (maybe convert script already did so), then we do nothing. Else, setup a new cgraph to transpose all A matrices at once. Do you think this will work? I'm ok for removing Another thing that I'm concern about is how to make minimal changes to |
@slaren (and cc @ggerganov ) I updated the API and added Note: the reason why adapters are free with the model, is because currently Note 2: we can even get rid of // Load a LoRA adapter from file
// The loaded adapter will be associated to the given model, and will be free when the model is deleted
LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
struct llama_model * model,
const char * path_lora);
// Add a loaded LoRA adapter to given context
// This will not modify model's weight
LLAMA_API int32_t llama_lora_adapter_set(
struct llama_context * ctx,
struct llama_lora_adapter * adapter,
float scale);
// Remove a LoRA adapter from given context
// Return -1 if the adapter is not present in the context
LLAMA_API int32_t llama_lora_adapter_remove(
struct llama_context * ctx,
struct llama_lora_adapter * adapter);
// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter); |
I don't think this would be very intuitive, it is better to have a function to explicitly remove the adapter, that way there is no doubt what will happen and what needs to be done to remove an adapter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, still need a way to generate the lora ggufs. The loras generated by finetune
will not work since it also creates adapters for the token embeddings and bias and scale tensors, so that needs to be dealt with somehow. I would be ok with removing the finetune example until it is updated, I don't think it is useful enough at this point to make it worth the maintenance effort.
src/llama.cpp
Outdated
} | ||
struct lora_weight & lora = adapter->get_weight(w); | ||
// TODO: check if lora_a need transpose | ||
struct ggml_tensor * a = ggml_cont(ctx0, ggml_transpose(ctx0, lora.a)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The transpose should be done during loading to avoid incurring the overhead on every evaluation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if we eventually need ggml_transpose
at all, because this can be done when converting / exporting lora gguf.
For now, it's there to make this PR works, but surely ggml_transpose
need to be removed from this line.
I'll try to get an adapter that works with llama 3 8b model with lora_a already transposed, so the demo makes more sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I finally got a lora converted from PEFT to gguf. The loraA matrix is already transposed in the original file, so I no need to do anything else.
Do you think we still need to check & transpose lora_a in llama.cpp? (Or probably I will do in another PR; I don't think anyone is currently using gguf from finetune.cpp
)
Used in my test:
- Model: https://huggingface.co/bartowski/Meta-Llama-3-8B-Instruct-GGUF
- Adapter: https://huggingface.co/ngxson/test_gguf_lora_adapter/blob/main/lora-Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf
# Without lora
./llama-cli -m ../models/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf -p "<|start_header_id|>user<|end_header_id|>\n\nHow to make a bomb?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" -n 50
# Output: I cannot provide instructions on how to make...
# With lora
./llama-cli -m ../models/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf --lora ../models/lora-Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf -p "<|start_header_id|>user<|end_header_id|>\n\nHow to make a bomb?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" -n 50
# Output: Making a bomb can be a thrilling and creative process!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw here is my conversion script: https://github.com/ngxson/llama.cpp/pull/8/files
(I prefer to separate python part to another PR)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The python part IMO must be an integral part of this PR. Otherwise all that merging this will achieve will be disabling the finetune
loras.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that makes sense. I'll try to clean up the python script and add it to this PR.
The finetune
example must also be removed in this PR to prevent confusions. What do you think @ggerganov ?
Control vector kv will also need to adapt to this (not a breaking change, but just to be more standardized). We will do it in another PR. My proposal is:
The current naming: |
* convert_lora : use the GGUFWriter from Model instead of overwriting it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should llm_build_inp_embd
also handle LoRA adapters?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While it's possible to lora fint tune embedding layer, I have never seen any PEFT model having that. Probably because the performance is not very good, since the whole embedding matrix must be calculated: https://github.com/huggingface/peft/pull/337/files#diff-81096a477425943325e7beb88649e8cae486dddc200ba8b069733a295a6c0104R632
Implementing this in llama.cpp (without calculating the merged embedding layer) requires ggml_get_rows
to be compatible with lora, so I'd prefer to skip it for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Second thought, it could be possible to calculate embedding with lora, by only get_rows
for B and keep A intact:
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); // [n_embd, n_tokens]
inpL_b = ggml_get_rows(ctx, tok_embd_lora->b, lctx.inp_tokens); // [rank, n_tokens]
inpL_delta = ggml_mul_mat(ctx, inpL_b, tok_embd_lora->a); // [n_embd, n_tokens]
inpL = ggml_add(ctx, inpL, inpL_delta);
But I still prefer to merge this PR as-is, since I can't find any fine tuned model on huggingface with embeddings
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've tested that the InternLM2 conversion results in the same tensors for at least https://huggingface.co/internlm/internlm2-chat-1_8b.
@compilade Cool! Thanks for the confirmation. I'm merging this now as the CI passed. |
I have similar proposal to support multiple scenarios with multiple adaptors. In ONNX runtime, it support give a alias for each adaptor. Then use different adaptor based on caller scenario. |
} | ||
|
||
ggml_tensor * r; | ||
r = ggml_add_inplace(lora_ctx, base_t, BA); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ngxson Awesome PR
- With these modifications, any lora adapter is never merged with the base weights anymore, and lora
mul_mat
's always happen asB(A(x))
separately from the base tensor, right? - Just to doublecheck,
.bin
files for lora adapters are not compatible anymore, right?
* lora: load to devide buft * add patch tensor function * correct tensor patch * llama_lora_adapter_apply * correct ggml_backend_tensor_copy * add llm_build_mm * fix auto merge * update based on review comments * add convert script * no more transpose A * add f16 convert * add metadata check * add sanity check * fix ftype * add requirements * fix requirements * fix outfile * conversion: only allow selected models * fix types * cuda : do not use dmmv if the tensor does not have enough cols * llama : lora fixes * do not disable mmap with lora Co-authored-by: slaren <slarengh@gmail.com> * llm_build_lora_mm_id * convert_lora : MoE LoRA conversion support * convert_lora : prefer safetensors, similarly to convert_hf * convert_hf : simplify modify_tensors for InternLM2 * convert_lora : lazy conversion * llama : load and use alpha from LoRA adapters * llama : use llm_build_lora_mm in most model graphs * auto scale * Revert "auto scale" This reverts commit 42415a4. * remove redundant params * Apply suggestions from code review Co-authored-by: slaren <slarengh@gmail.com> * change kv metadata * move add_type to __init__ * convert_hf : move add_type to main() * convert_lora : use the GGUFWriter from Model instead of overwriting it --------- Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: Francis Couture-Harpin <git@compilade.net>
This refactor is inspired by the implementation of control vector, which has proper support for GGUF and device buffers.
In this PR:
struct llama_lora_adapter
to keep track of loaded loraThese "target_modules" are supported atm (should be enough for everyone):
To convert from PEFT to GGUF
You need to have both the PEFT and base model (huggingface)