Backpropagation for GEMM #220
Conversation
1. CUDA kernel for GEMM backpropagation
2. AWQ_LoRA objects
3. Runtime patch of PEFT
Wow, this looks super interesting!
Thank you. I haven't profiled performance and memory usage against GPTQ or BitsAndBytes. I'm not sure how to measure memory consumption, but comparing training times against GPTQ will do. I'll probably be able to measure the times this weekend.
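A minimal sketch of one way to measure peak GPU memory around a training run, using PyTorch's built-in CUDA statistics (not taken from this thread; the tiny matmul stands in for the actual training loop):

```python
import torch

# Reset the peak-memory counters so the statistics reflect only this run.
torch.cuda.reset_peak_memory_stats()

# ... the training loop would go here; a small GPU workload stands in for it ...
x = torch.randn(1024, 1024, device="cuda")
y = x @ x

# Peak memory allocated by tensors vs. reserved by the caching allocator, in GiB.
peak_alloc = torch.cuda.max_memory_allocated() / 1024**3
peak_reserved = torch.cuda.max_memory_reserved() / 1024**3
print(f"peak allocated: {peak_alloc:.2f} GiB, peak reserved: {peak_reserved:.2f} GiB")
```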
I looked into your kernel. It seems it is not computing any gradients, yet you call it backward. Additionally, I see the loss in your screenshot is increasing. Could it be that you have not implemented the backward pass yet?
You're right. This kernel doesn't compute gradients, because I rely on non-fused layers without
Right, so the backward is just doing matrix multiplication. Do we need two new kernels for that, or could we reuse the existing ones? For reference, there was another attempt at a backward pass; this one runs the matrix multiplications in FP16, though. https://github.com/compressa-ai/AutoAWQ/blob/dev/awq/modules/linear.py#L60-L88
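A minimal sketch of what such an FP16 fallback backward could look like as a `torch.autograd.Function`; the `dequantize` helper and the packed-weight layout are assumptions, not the linked implementation or the AutoAWQ API:

```python
import torch

def dequantize(qweight_packed):
    # Placeholder: the real kernel would unpack 4-bit weights and apply
    # scales/zeros; here the "packed" tensor is simply cast to FP16.
    return qweight_packed.half()

class DequantLinearFn(torch.autograd.Function):
    """The forward may use the quantized GEMM kernel, but the backward only
    needs the dequantized weight and a plain FP16 matmul."""

    @staticmethod
    def forward(ctx, x, qweight_packed):
        weight = dequantize(qweight_packed)        # assumed shape [in, out]
        ctx.save_for_backward(weight)
        return x @ weight

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        grad_input = grad_output @ weight.t()      # grad_x = grad_y @ W^T
        # The quantized base weight stays frozen (LoRA-style training),
        # so no gradient is returned for it.
        return grad_input, None

x = torch.randn(2, 8, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(8, 16, device="cuda", dtype=torch.float16)
DequantLinearFn.apply(x, w).sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])
```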
I think we can introduce a matrix transpose flag (default =
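To make the "just matrix multiplication" point concrete: for Y = XW, both backward products are ordinary GEMMs with one transposed operand, which is why a transpose flag on the existing kernel could cover them. A small self-contained check:

```python
import torch

# Shapes: X is [batch, in], W is [in, out], Y = X @ W is [batch, out].
X = torch.randn(4, 8, requires_grad=True)
W = torch.randn(8, 16, requires_grad=True)
Y = X @ W
grad_Y = torch.randn_like(Y)

# Manual backward: two GEMMs, each with one operand transposed.
grad_X = grad_Y @ W.t()   # [batch, in]
grad_W = X.t() @ grad_Y   # [in, out]

# Check against autograd.
Y.backward(grad_Y)
assert torch.allclose(grad_X, X.grad) and torch.allclose(grad_W, W.grad)
```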
I would expect users to make use of axolotl, which has its own way of fusing/patching layers. I implemented fused layers for training before, but there is not much of a benefit when we are compute-bound during training.
The solution in the link works, but I constantly fail to reproduce it in CUDA code. Probably I'm not that good at this kind of task, so I think it will be better to either incorporate the solution from the link as it is or just close this PR, because I won't be able to bring it into good shape :(
You're right. The loss just hovers around 2.6 and never decreases over a long period of time. But the solution by @compressa-ai does train the model. Here's a fresh screenshot produced with @compressa-ai's solution for 1 epoch. The loss can also increase slightly in the process, but overall it decreases towards zero. I was also able to produce a LoRA module with it. I think this exact PR has to be closed so that @compressa-ai's solution can be submitted as its own PR.
If you could add his dequantized-weights code and remove the backward kernel, then we could get this PR merged. Might as well merge it now that you have done the other work with the training code, etc.
@casper-hansen Removed the backward kernel and added the dequantization kernel. Also noticed that the runtime patch to
It looks great. I think we need a PEFT integration to make this fully work. They just need to import and replace modules similarly to how they are already doing it for GPTQ.
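A hedged sketch of what that module replacement could look like at runtime; `make_lora_layer` is a hypothetical factory, not part of PEFT's actual API, and a real integration would hook into PEFT's own dispatch logic instead of walking the model by hand:

```python
import torch.nn as nn

def replace_awq_linears(module: nn.Module, make_lora_layer):
    """Recursively swap AWQ quantized linears (WQLinear_GEMM) for a
    LoRA-aware wrapper, in the spirit of how quantized GPTQ layers are
    special-cased. `make_lora_layer` wraps one AWQ linear and returns
    a module exposing the same forward signature."""
    for name, child in module.named_children():
        if child.__class__.__name__ == "WQLinear_GEMM":
            setattr(module, name, make_lora_layer(child))
        else:
            replace_awq_linears(child, make_lora_layer)
    return module
```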
I'm looking to release this approximately in v0.2.0 of AutoAWQ. Currently, I am first trying to add support for Mixtral before merging this and other PRs. vLLM is working on a new Triton kernel that scales even better than the original GEMM kernel. Perhaps we can create a backward pass once the vLLM one is done.
The Triton kernel sounds like a great idea. As for integration with
@s4rduk4r did you forget to push your latest training script? The current one gives me a loss of zero after some modifications to use Mistral.
Sorry for not replying sooner. The training script hasn't been updated. Truth be told, I only tested it against a Llama 2 model, and it worked at the time. Maybe something there has to be done differently. Additionally, I have almost no time right now and won't be able to delve deeper into this issue. I'm sorry.
This PR implements backpropagation for the GEMM version:
- `gemm_backward_cuda` CUDA kernel
- `awq_autograd.py`

Below is the result of a LoRA finetune on 200 entries from the `OpenAssistant/oasst1` dataset.
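A rough sketch of the kind of finetune setup described above; the quantized model path and LoRA hyperparameters are placeholders, and whether `get_peft_model` works directly on the AWQ layers depends on the PEFT patching discussed in this PR:

```python
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer
from awq import AutoAWQForCausalLM

# 200 entries from the OpenAssistant/oasst1 dataset, as in the description.
dataset = load_dataset("OpenAssistant/oasst1", split="train[:200]")

# Placeholder AWQ-quantized checkpoint; fused layers are disabled for training.
model_path = "TheBloke/Llama-2-7B-AWQ"
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Standard LoRA config targeting the attention projections (placeholder values).
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model.model, lora_config)
peft_model.print_trainable_parameters()
```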