Account for compute cost in collectives during redistribution #94
Conversation
This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement. Indeed, when performing S(1) -> R, we currently perform an all-gather on dim 0, followed by a full copy of the data. This wasn't modelled properly before (we just multiplied the comm cost by an arbitrary factor of 4); now it is taken into account properly.
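For context, a minimal sketch of the current S(1) -> R path (names and group handling are illustrative, not this PR's code). The all-gather stacks shards on dim 0, so a dim-1 shard needs an extra full-size chunk + cat afterwards; that copy is the data movement this PR adds to the cost model:

```python
import torch
import torch.distributed._functional_collectives as funcol

def shard1_to_replicate(local: torch.Tensor, world_size: int, group) -> torch.Tensor:
    # local has shape (m, n // world_size), sharded on dim 1
    gathered = funcol.all_gather_tensor(local, gather_dim=0, group=group)
    # gathered has shape (m * world_size, n // world_size): shards stacked on dim 0
    shards = torch.chunk(gathered, world_size, dim=0)
    # this cat materializes a second full-size tensor: the extra data movement
    return torch.cat(shards, dim=1)
```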
```python
elif src_plc.is_shard() and src_plc.dim != 0 and tgt_plc.is_replicate():
    # add cost of additional cat on full size
    # *2 because we need to count input and output reads
    read_write_bytes = (
```
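The snippet above is truncated in the page capture. A hedged reconstruction of what such a bandwidth-based cost term might look like (the constant names and the efficiency factor are assumptions, not the PR's actual code):

```python
# Assumed constants, not from the PR: peak memory bandwidth in bytes/s and
# the fraction of it a plain copy actually achieves.
MEM_BANDWIDTH_BYTES_PER_S = 2.0e12
COPY_EFFICIENCY = 0.5

def extra_copy_cost_s(full_tensor_bytes: int) -> float:
    # *2 because the cat reads the gathered input and writes the output
    read_write_bytes = 2 * full_tensor_bytes
    return read_write_bytes / (MEM_BANDWIDTH_BYTES_PER_S * COPY_EFFICIENCY)
```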
Ah, got it, thanks. The copy is clearly bandwidth-bound, so we can just use memory bandwidth to estimate its cost.
The problem with this PR is that it increases the solver runtime on a model with a single transformer block from 2.95 s to 12.43 s, because the …
OK, so the solver time goes back to a reasonable level if we assume 50% efficiency for the IO of the copy. I need to check whether that's a realistic assumption.
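As a rough sanity check on that assumption (numbers illustrative, not from the PR): a 1 GB tensor copied at 50% of a 2 TB/s peak bandwidth moves 2 GB (read + write) at an effective 1 TB/s, i.e. about 2 ms, versus roughly 1 ms at the 100% efficiency the unscaled model implies.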
Well, it's nice to clean up that hack, but to your earlier point, making the cost align perfectly with reality matters less than making the system work well on real models. I'm wondering: does this change produce a better solution in some case?
Yes, I actually worked on this because the view->mm->view PR exposed the issue this PR is trying to solve. The symptom was that, in principle, we should get the same solution for the view->mm->view and einsum formulations, but that wasn't the case; this PR is an attempt to fix that. I still want to test it more thoroughly before merging, and I'll only merge if I find it beneficial.
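For reference, a hedged illustration of the two formulations that should receive the same sharding solution (shapes made up; this is not code from either PR):

```python
import torch

x = torch.randn(8, 128, 512)   # (batch, seq, k)
w = torch.randn(512, 1024)     # (k, n)

# view -> mm -> view formulation
y1 = (x.view(-1, 512) @ w).view(8, 128, 1024)

# einsum formulation
y2 = torch.einsum("bsk,kn->bsn", x, w)

torch.testing.assert_close(y1, y2)
```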
See also pytorch/pytorch#161882
Subsumed by #125