
Backpropagation for GEMM #220

Closed · wants to merge 8 commits
Conversation

@s4rduk4r (Contributor) commented Nov 21, 2023

This PR implements backpropagation for the GEMM version. It adds:

  1. A CUDA kernel for GEMM backpropagation (a modification of the forward-pass kernel), named gemm_backward_cuda
  2. AWQ LoRA objects
  3. A runtime patch of PEFT to enable LoRA finetuning
  4. An example finetuning script, awq_autograd.py

Below is the result of a LoRA finetune on 200 entries from the OpenAssistant/oasst1 dataset; a rough sketch of how the pieces above fit together follows the screenshot.
[screenshot: training loss from the LoRA finetune]
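For orientation, here is a rough outline of how these pieces are meant to fit together. apply_awq_peft_patch is a placeholder for this PR's runtime PEFT patch, and the checkpoint path, LoRA hyperparameters, and from_quantized arguments are illustrative assumptions rather than the exact contents of awq_autograd.py.

```python
# Illustrative outline only -- names marked as placeholders are not the exact API of this PR.
from awq import AutoAWQForCausalLM
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, Trainer, TrainingArguments

quant_path = "TheBloke/Llama-2-7B-AWQ"                  # assumed example checkpoint
model = AutoAWQForCausalLM.from_quantized(quant_path)   # argument details may differ by version
tokenizer = AutoTokenizer.from_pretrained(quant_path)

apply_awq_peft_patch()  # placeholder: the runtime patch that teaches PEFT about WQLinear_GEMM

lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"],
                         task_type="CAUSAL_LM")
peft_model = get_peft_model(model.model, lora_config)   # .model is the underlying HF model

dataset = load_dataset("OpenAssistant/oasst1", split="train[:200]")  # 200 entries, as in the screenshot

trainer = Trainer(
    model=peft_model,
    args=TrainingArguments(output_dir="awq-lora-out", per_device_train_batch_size=1,
                           num_train_epochs=1, fp16=True),
    train_dataset=dataset,  # tokenization/collation omitted for brevity
)
trainer.train()
```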

@casper-hansen (Owner)

Wow, this looks super interesting!

  • Have you profiled how fast this is compared to GPTQ/BitsAndBytes?
  • What does the memory usage look like compared to GPTQ/BitsAndBytes?

@s4rduk4r (Contributor, Author)

> Wow, this looks super interesting!
>
>   • Have you profiled how fast this is compared to GPTQ/BitsAndBytes?
>   • What does the memory usage look like compared to GPTQ/BitsAndBytes?

Thank you. I haven't profiled performance or memory usage against GPTQ or BitsAndBytes yet. I'm not sure how to measure memory consumption, but I can compare training times against GPTQ. I should be able to measure them this weekend.
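For the memory side, a single training step can be profiled with plain PyTorch utilities; here is a minimal sketch (model and batch are placeholders for whichever backend and data pipeline are being compared):

```python
import time
import torch

def measure_train_step(model, batch):
    """Time one forward/backward step and report peak GPU memory in MB."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()

    loss = model(**batch).loss   # assumes the batch already contains labels
    loss.backward()

    torch.cuda.synchronize()
    step_time_s = time.perf_counter() - start
    peak_mem_mb = torch.cuda.max_memory_allocated() / 1024**2
    return step_time_s, peak_mem_mb
```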

@casper-hansen (Owner)

I looked into your kernel. It seems it is not computing any gradients, yet you call it backward. Additionally, I see that the loss in your screenshot is increasing. Could it be that you have not implemented the backward pass yet?

@s4rduk4r (Contributor, Author)

> I looked into your kernel. It seems it is not computing any gradients, yet you call it backward. Additionally, I see that the loss in your screenshot is increasing. Could it be that you have not implemented the backward pass yet?

You're right. This kernel doesn't compute gradients, because I rely on non-fused layers without the @torch.no_grad decorator, so the gradients are calculated by PyTorch and fed into the WQLinear_GEMM_Propagator.backward() method, which the trainer calls during the backward pass.
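In other words, the wiring is roughly a torch.autograd.Function whose forward runs the existing quantized GEMM kernel and whose backward only multiplies the incoming gradient by the transposed weight matrix. The sketch below illustrates that idea; it is not the PR's exact WQLinear_GEMM_Propagator code, the gemm_forward_cuda argument order is an assumption, and dequantize() is a hypothetical helper.

```python
import torch
import awq_ext  # AutoAWQ CUDA extension; the binding signature used below is an assumption


class QuantLinearFn(torch.autograd.Function):
    """Illustrative stand-in for the propagator described above."""

    @staticmethod
    def forward(ctx, x, qweight, qzeros, scales):
        ctx.save_for_backward(qweight, qzeros, scales)
        # Forward pass through the existing quantized GEMM kernel.
        return awq_ext.gemm_forward_cuda(x, qweight, scales, qzeros, 8)

    @staticmethod
    def backward(ctx, grad_output):
        qweight, qzeros, scales = ctx.saved_tensors
        # The quantized weights are frozen, so only the gradient w.r.t. the
        # activations is needed: grad_x = grad_y @ W^T.
        w = dequantize(qweight, qzeros, scales)  # hypothetical helper, W has shape (in, out)
        grad_input = grad_output @ w.t()
        return grad_input, None, None, None


# Usage: out = QuantLinearFn.apply(x, qweight, qzeros, scales)
```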

@casper-hansen (Owner) commented Nov 22, 2023

> I looked into your kernel. It seems it is not computing any gradients, yet you call it backward. Additionally, I see that the loss in your screenshot is increasing. Could it be that you have not implemented the backward pass yet?
>
> You're right. This kernel doesn't compute gradients, because I rely on non-fused layers without the @torch.no_grad decorator, so the gradients are calculated by PyTorch and fed into the WQLinear_GEMM_Propagator.backward() method, which the trainer calls during the backward pass.

Right, so the backward is just doing matrix multiplication. Do we need two new kernels for that, or could we reuse the existing ones?

For reference, there was another attempt at a backward pass. This one runs the matrix multiplications in FP16, though.

https://github.com/compressa-ai/AutoAWQ/blob/dev/awq/modules/linear.py#L60-L88

@s4rduk4r (Contributor, Author)

> I looked into your kernel. It seems it is not computing any gradients, yet you call it backward. Additionally, I see that the loss in your screenshot is increasing. Could it be that you have not implemented the backward pass yet?
>
> You're right. This kernel doesn't compute gradients, because I rely on non-fused layers without the @torch.no_grad decorator, so the gradients are calculated by PyTorch and fed into the WQLinear_GEMM_Propagator.backward() method, which the trainer calls during the backward pass.
>
> Right, so the backward is just doing matrix multiplication. Do we need two new kernels for that, or could we reuse the existing ones?
>
> For reference, there was another attempt at a backward pass. This one runs the matrix multiplications in FP16, though.
>
> https://github.com/compressa-ai/AutoAWQ/blob/dev/awq/modules/linear.py#L60-L88

I think we can add a matrix-transpose flag (default false) to the forward-pass kernels, which gemm_backward_cuda would then call with the flag set to true. But if someone wants to train fused layers in the future, I think two dedicated backward-pass kernels will be needed.
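A hypothetical sketch of what the flag would mean at the Python call site; the transpose argument does not exist in the current binding, and dequantize() only spells out the math the flag is supposed to cover inside the kernel.

```python
import awq_ext  # existing AutoAWQ extension; only the non-transposed path below is real today


def quantized_gemm(x, qweight, scales, qzeros, split_k_iters=8, transpose=False):
    """Sketch of the proposed flag; transpose=True is what gemm_backward_cuda would request."""
    if transpose:
        # Backward pass: grad_x = grad_y @ W^T. The proposal is to fold this transposed
        # read into the existing kernel behind the flag instead of shipping a second kernel.
        w = dequantize(qweight, qzeros, scales)  # hypothetical helper standing in for the kernel
        return x @ w.t()
    # Forward pass: y = x @ W, with W dequantized on the fly inside the kernel.
    return awq_ext.gemm_forward_cuda(x, qweight, scales, qzeros, split_k_iters)
```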

@casper-hansen (Owner)

> I looked into your kernel. It seems it is not computing any gradients, yet you call it backward. Additionally, I see that the loss in your screenshot is increasing. Could it be that you have not implemented the backward pass yet?
>
> You're right. This kernel doesn't compute gradients, because I rely on non-fused layers without the @torch.no_grad decorator, so the gradients are calculated by PyTorch and fed into the WQLinear_GEMM_Propagator.backward() method, which the trainer calls during the backward pass.
>
> Right, so the backward is just doing matrix multiplication. Do we need two new kernels for that, or could we reuse the existing ones?
> For reference, there was another attempt at a backward pass. This one runs the matrix multiplications in FP16, though.
> https://github.com/compressa-ai/AutoAWQ/blob/dev/awq/modules/linear.py#L60-L88
>
> I think we can add a matrix-transpose flag (default false) to the forward-pass kernels, which gemm_backward_cuda would then call with the flag set to true. But if someone wants to train fused layers in the future, I think two dedicated backward-pass kernels will be needed.

I would expect users to make use of axolotl, which has its own way of fusing/patching layers. I implemented fused layers for training before, but there is not much benefit when we are compute-bound during training.

@s4rduk4r (Contributor, Author) commented Dec 3, 2023

> I looked into your kernel. It seems it is not computing any gradients, yet you call it backward. Additionally, I see that the loss in your screenshot is increasing. Could it be that you have not implemented the backward pass yet?
>
> You're right. This kernel doesn't compute gradients, because I rely on non-fused layers without the @torch.no_grad decorator, so the gradients are calculated by PyTorch and fed into the WQLinear_GEMM_Propagator.backward() method, which the trainer calls during the backward pass.
>
> Right, so the backward is just doing matrix multiplication. Do we need two new kernels for that, or could we reuse the existing ones?
>
> For reference, there was another attempt at a backward pass. This one runs the matrix multiplications in FP16, though.
>
> https://github.com/compressa-ai/AutoAWQ/blob/dev/awq/modules/linear.py#L60-L88

The solution in the link works, but I constantly fail to reproduce it in CUDA code. Probably I'm not that good at this kind of task. So I think it would be better to either incorporate the linked solution as it is or just close this PR, because I won't be able to bring it into good shape :(

@casper-hansen (Owner)

What I wonder most about is loss stability. Can you train a model that works?

In your image here, you only trained for a few steps, but the model did not improve.

[screenshot: training loss over a few steps]

@s4rduk4r (Contributor, Author) commented Dec 4, 2023

> What I wonder most about is loss stability. Can you train a model that works?
>
> In your image here, you only trained for a few steps, but the model did not improve.
>
> [screenshot: training loss over a few steps]

You're right. The loss just waltzes around 2.6 and never decreases over a long period of time. But the solution by @compressa-ai does train the model. Here's a fresh screenshot from one epoch of training with @compressa-ai's solution; the loss can still increase slightly along the way, but overall it decreases toward zero. I was also able to produce a LoRA module with it.
[screenshot: awq_training_dequantize_fp16 — loss curve over 1 epoch]

I think this PR should be closed so that @compressa-ai's solution can be submitted as its own PR.

@casper-hansen (Owner)

If you could add his dequantize-weights code and remove the backward kernel, then we could get this PR merged. We might as well merge it here, now that you have done the other work with the training code, etc.
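For reference, the dequantize-based path boils down to something like the sketch below: unpack the 4-bit weights to FP16 and let ordinary matmuls carry both directions. This is an illustration of the approach rather than the code from the linked branch; dequantize() stands in for the dequantization kernel added to this PR.

```python
import torch


class DequantizedLinearFn(torch.autograd.Function):
    """Dequantize to FP16 and use plain matmuls for forward and backward."""

    @staticmethod
    def forward(ctx, x, qweight, qzeros, scales, bias):
        w = dequantize(qweight, qzeros, scales)  # placeholder for the dequant kernel, W is (in, out)
        ctx.save_for_backward(w)                 # keep the FP16 copy around for the backward matmul
        out = x @ w
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        # The quantized weights stay frozen; only the activation gradient is returned.
        grad_input = grad_output @ w.t()
        return grad_input, None, None, None, None


# Usage: out = DequantizedLinearFn.apply(x, qweight, qzeros, scales, None)
```

Saving the FP16 copy avoids dequantizing a second time in backward at the cost of extra memory; re-running dequantize() in backward would trade that memory back for compute.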

@s4rduk4r (Contributor, Author) commented Dec 6, 2023

@casper-hansen I removed the backward kernel and added a dequantization kernel. I also noticed that the runtime patch for PEFT doesn't work with recent versions (0.6.x), but it does work with peft 0.3.0. For now I can't find a way to patch it at runtime. Maybe a PR should be made to PEFT to include AutoAWQQuantLinear, like what happened with AutoGPTQQuantLinear.

@casper-hansen (Owner)

It looks great. I think we need a PEFT integration to make this fully work. They just need to import and replace modules similarly to how they are already doing it for GPTQ.
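As a rough picture of what such an integration has to do (independent of PEFT's internal dispatch machinery, which is not reproduced here): walk the model, find the frozen WQLinear_GEMM modules, and put a trainable low-rank branch alongside them. The wrapper below is a hand-rolled illustration, not PEFT's actual LoRA class.

```python
import torch
import torch.nn as nn
from awq.modules.linear import WQLinear_GEMM


class LoraWrappedAWQLinear(nn.Module):
    """Keep the frozen quantized linear and add a trainable low-rank branch."""

    def __init__(self, base: WQLinear_GEMM, r: int = 16, alpha: int = 32):
        super().__init__()
        self.base = base  # frozen AWQ layer (4-bit packed weights)
        self.lora_A = nn.Linear(base.in_features, r, bias=False, dtype=torch.float16)
        self.lora_B = nn.Linear(r, base.out_features, bias=False, dtype=torch.float16)
        nn.init.zeros_(self.lora_B.weight)  # adapter starts as a no-op
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.lora_B(self.lora_A(x)) * self.scaling


def wrap_awq_linears(model: nn.Module, target_names=("q_proj", "v_proj")):
    """Replace selected WQLinear_GEMM children with the LoRA-wrapped version."""
    # Collect first, then replace, so the module tree is not mutated while iterating.
    to_wrap = [
        (parent, name, child)
        for parent in model.modules()
        for name, child in parent.named_children()
        if isinstance(child, WQLinear_GEMM) and name in target_names
    ]
    for parent, name, child in to_wrap:
        setattr(parent, name, LoraWrappedAWQLinear(child))
    return model
```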

@casper-hansen (Owner)

I'm looking to release this in approximately v0.2.0 of AutoAWQ. Currently, I am first trying to add support for Mixtral before merging this and other PRs.

vLLM is working on a new Triton kernel that scales even better than the original GEMM kernel. Perhaps we can create a backward pass once the vLLM one is done.

https://github.com/vllm-project/vllm/blob/qmm/vllm/model_executor/layers/quantization/ops/awq.py

@s4rduk4r (Contributor, Author)

The Triton kernel sounds like a great idea. As for the PEFT integration, I haven't had time to look into PEFT yet. Maybe I'll be able to this weekend.

@casper-hansen (Owner)

@s4rduk4r did you forget to push your latest training script? The current one gives me a loss of zero after some modifications to use Mistral.

[screenshot: training loss stuck at zero]

@s4rduk4r (Contributor, Author)

> @s4rduk4r did you forget to push your latest training script? The current one gives me a loss of zero after some modifications to use Mistral.
>
> [screenshot: training loss stuck at zero]

Sorry for not replying sooner. The training script hasn't been updated. Truth be told, I only tested it against a Llama 2 model, and it worked at the time. Maybe something has to be done differently there. Unfortunately, I have almost no time right now and won't be able to dig into this issue any deeper. I'm sorry.
