Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmark using pytorch - speed up lora operation on batch #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

yacineMTB
Copy link
Collaborator

@yacineMTB yacineMTB commented Aug 29, 2023

Bear with me.. I'm learning!

I thought that the python for loop would slow things down, instead you could batch the operation. This is still WIP, I'm trying to wrap my head around this.

https://github.com/sabetAI/bloras/blob/21839a61b883b1398b2418a7992f1c1175506874/blora_utils.py/#L1884-L1891

Each one of these, as far as I can tell, is equivalent

image

I had a hunch that it would be faster to do this in one go, so I fashioned a small benchmark. I ran this on my system, RTX 3090

Loop Mean: 0.00021034269332885743
Batched Mean: 7.178418636322021e-05
Loop Median: 0.0001556873321533203
Batched Median: 6.985664367675781e-05
loop_sum: 2.103426933288574
batched_sum: 0.7178418636322021

Please check that everything is equivalent, I'm quite new at this!

Comment on lines +5 to +14
def compute_lora_out_v2(X, A_list, B_list):
n = X.shape[0]
lora_out = torch.zeros((n, 4096), device="cuda")
for i in range(n):
x = X[i, :]
A = A_list[i]
B = B_list[i]
lora_out[i, :] = torch.matmul(torch.matmul(x, A.T), B.T)

return lora_out
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thing you should be checking:
Is this equivalent to

BLoRA/blora_utils.py

Lines 1881 to 1891 in 21839a6

batch = list(zip(x, self.batch_lora_ids))
# rewrite as for loop
lora_out = torch.zeros_like(result)
for i, (x, lora_id) in enumerate(batch):
if lora_id in self.lora_A.keys():
lora_out[i] = self.scaling[lora_id] * self.lora_B[lora_id](
self.lora_A[lora_id](self.lora_dropout[lora_id](x))
)
result += lora_out

@sidnb13
Copy link

sidnb13 commented Sep 13, 2023

Been thinking about how to further speed this up, I ended up coming up with a solution very similar to yours. Is there any way of parallelizing across the number of lora weights n, maybe by writing a custom CUDA kernel?

@sabetAI
Copy link
Owner

sabetAI commented Sep 13, 2023

Yes @sidnb13 you can stack the LoRAs into a single tensor, and broadcast slices over their corresponding batch elements.

@sidnb13
Copy link

sidnb13 commented Sep 15, 2023

@sabetAI came across this function after asking for help in the PyG community: https://pyg-lib.readthedocs.io/en/latest/modules/ops.html#pyg_lib.ops.segment_matmul. It effectively vectorizes across the number of adapters. Some quick testing shows it's actually much slower than the looped approach, but maybe someone else can give it a go.

@sabetAI
Copy link
Owner

sabetAI commented Sep 16, 2023

@sidnb13 nice try! segment_matmul is the perfect function for a blora op, kernel's probably not optimized though. I also attempted parallelizing the blora op through matrix reshapes and stacking, seemed to take longer than a simple loop unfortunately.

Repository owner deleted a comment from nozelle Feb 20, 2024
Repository owner deleted a comment from nozelle Feb 20, 2024
Repository owner deleted a comment from SourceAura Feb 23, 2024
Repository owner deleted a comment from pavana21 Feb 23, 2024
Repository owner deleted a comment from pavana21 Feb 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants