Benchmark using PyTorch - speed up LoRA operation on batch #1
Conversation
import torch

def compute_lora_out_v2(X, A_list, B_list):
    # X: (n, d_in) batch of inputs; A_list/B_list hold one (A, B) LoRA pair per row.
    n = X.shape[0]
    # Output dimension is hardcoded to 4096 here.
    lora_out = torch.zeros((n, 4096), device="cuda")
    for i in range(n):
        x = X[i, :]
        A = A_list[i]
        B = B_list[i]
        # x @ A.T projects down to the LoRA rank, then @ B.T projects back up.
        lora_out[i, :] = torch.matmul(torch.matmul(x, A.T), B.T)
    return lora_out
One thing you should check: is this equivalent to lines 1881 to 1891 in 21839a6?
batch = list(zip(x, self.batch_lora_ids))
# rewrite as for loop
lora_out = torch.zeros_like(result)
for i, (x, lora_id) in enumerate(batch):
    if lora_id in self.lora_A.keys():
        lora_out[i] = self.scaling[lora_id] * self.lora_B[lora_id](
            self.lora_A[lora_id](self.lora_dropout[lora_id](x))
        )
result += lora_out
I've been thinking about how to further speed this up, and I ended up with a solution very similar to yours. Is there any way of parallelizing across the number of LoRA weights?
Yes @sidnb13, you can stack the LoRAs into a single tensor and broadcast slices over their corresponding batch elements.
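A minimal sketch of that idea, assuming all adapters share the same shapes and that a `lora_ids` tensor maps each batch row to its adapter index (the function and tensor names here are illustrative, not from the repo):

```python
import torch

def compute_lora_out_stacked(X, A_stack, B_stack, lora_ids):
    # X:        (n, d_in)               batch of inputs
    # A_stack:  (num_loras, r, d_in)    all LoRA A matrices stacked
    # B_stack:  (num_loras, d_out, r)   all LoRA B matrices stacked
    # lora_ids: (n,) long tensor, adapter index for each batch row
    A = A_stack[lora_ids]                      # (n, r, d_in): one adapter's A per row
    B = B_stack[lora_ids]                      # (n, d_out, r)
    h = torch.einsum("nd,nrd->nr", X, A)       # down-projection for every row at once
    return torch.einsum("nr,nor->no", h, B)    # up-projection for every row at once
```

The gather `A_stack[lora_ids]` materializes one adapter per row, so this trades extra memory for removing the Python loop.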
@sabetAI came across this function after asking for help in the PyG community: https://pyg-lib.readthedocs.io/en/latest/modules/ops.html#pyg_lib.ops.segment_matmul. It effectively vectorizes across the number of adapters. Some quick testing shows it's actually much slower than the looped approach, but maybe someone else can give it a go.
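For anyone who wants to try that path, a rough sketch of how `segment_matmul` might be wired up, assuming its documented `segment_matmul(inputs, ptr, other)` signature; `X_sorted`, `ptr`, `A_stack`, and `B_stack` are illustrative names, and the sort/scatter bookkeeping is omitted:

```python
import torch
from pyg_lib.ops import segment_matmul

# Assumed setup: rows of X have been sorted so that all rows sharing an
# adapter are contiguous, and ptr holds the CSR-style segment boundaries.
# A_stack: (num_loras, d_in, r)   per-adapter down-projections
#          (note: transposed vs. A above, since segment_matmul right-multiplies)
# B_stack: (num_loras, r, d_out)  per-adapter up-projections
h = segment_matmul(X_sorted, ptr, A_stack)          # (n, r)
lora_out_sorted = segment_matmul(h, ptr, B_stack)   # (n, d_out)
# The result still has to be scattered back into the original batch order.
```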
@sidnb13 nice try!
Bear with me... I'm learning!
I thought that the Python for loop would slow things down and that you could batch the operation instead. This is still a WIP; I'm trying to wrap my head around it.
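For concreteness, a minimal sketch of one way the loop could be batched, assuming every A and B in the lists has the same shape (`compute_lora_out_batched` is an illustrative name, not from this PR):

```python
import torch

def compute_lora_out_batched(X, A_list, B_list):
    # Stack the per-example LoRA weights so the whole batch is handled by
    # two batched matmuls instead of a Python loop.
    A = torch.stack(A_list)                  # (n, r, d_in)
    B = torch.stack(B_list)                  # (n, d_out, r)
    h = torch.bmm(A, X.unsqueeze(-1))        # (n, r, 1) down-projection
    return torch.bmm(B, h).squeeze(-1)       # (n, d_out) up-projection
```

This should match `compute_lora_out_v2` row for row, since `A @ x` equals `x @ A.T` for a 1-D `x`.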
https://github.com/sabetAI/bloras/blob/21839a61b883b1398b2418a7992f1c1175506874/blora_utils.py#L1884-L1891
Each one of these, as far as I can tell, is equivalent.
I had a hunch that it would be faster to do this in one go, so I put together a small benchmark and ran it on my system (an RTX 3090).
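The timing harness could look roughly like this (a sketch with made-up sizes, not the actual script from this PR; `compute_lora_out_batched` is the hypothetical batched variant sketched above):

```python
import torch

# Illustrative sizes only; the real benchmark may use different ones.
n, d_in, d_out, r = 32, 4096, 4096, 8
X = torch.randn(n, d_in, device="cuda")
A_list = [torch.randn(r, d_in, device="cuda") for _ in range(n)]
B_list = [torch.randn(d_out, r, device="cuda") for _ in range(n)]

def time_fn(fn, *args, iters=100):
    # CUDA kernels launch asynchronously, so time with events and synchronize.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):          # warmup
        fn(*args)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # ms per call

print("loop:    ", time_fn(compute_lora_out_v2, X, A_list, B_list))
print("batched: ", time_fn(compute_lora_out_batched, X, A_list, B_list))
```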
Please check that everything is equivalent; I'm quite new at this!
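One quick way to sanity-check equivalence, reusing the tensors from the timing sketch above, is to compare the outputs numerically:

```python
out_loop = compute_lora_out_v2(X, A_list, B_list)
out_batched = compute_lora_out_batched(X, A_list, B_list)
# Loose tolerance, since the two orderings of matmuls can differ slightly
# in floating point.
print(torch.allclose(out_loop, out_batched, atol=1e-5))
```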