
Conversation

@fmassa (Contributor) commented Aug 13, 2025

This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement.

Indeed, when performing S(1) -> R, we currently perform an all-gather on dim 0 and then a full copy of the data. This wasn't modelled properly before (we just multiplied the comm cost by an arbitrary factor of 4); now it is properly taken into account.
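The extra data movement can be illustrated with a small sketch (illustrative only: the 2-rank mesh, the 4x4 tensor, and the list operations standing in for the actual collective are all made up, not from this PR's code):

```python
# Hypothetical 2-rank mesh; a 4x4 "tensor" as nested lists, sharded on dim 1
# (each rank holds a 4x2 shard).
full = [[4 * r + c for c in range(4)] for r in range(4)]
shard0 = [row[:2] for row in full]   # rank 0's S(1) shard
shard1 = [row[2:] for row in full]   # rank 1's S(1) shard

# An all-gather concatenates shards along dim 0, so the raw result stacks the
# shards on the wrong dimension: shape (8, 2) instead of (4, 4).
gathered = shard0 + shard1

# Recovering R requires a second, full-size copy: re-interleave the gathered
# rows back into a (4, 4) layout (a "cat" along dim 1).
replicated = [gathered[r] + gathered[r + 4] for r in range(4)]
assert replicated == full

# By contrast, S(0) -> R needs no extra copy: gathering dim-0 shards along
# dim 0 already yields the replicated tensor.
top, bottom = full[:2], full[2:]
assert top + bottom == full
```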

@fmassa fmassa requested review from bdhirsh and wconstab August 13, 2025 18:54
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 13, 2025
@fmassa fmassa changed the title Account for compute cost in collectives as well Account for compute cost in collectives during redistribution Aug 13, 2025
elif src_plc.is_shard() and src_plc.dim != 0 and tgt_plc.is_replicate():
# add cost of additional cat on full size
# *2 because we need to count input and output reads
read_write_bytes = (

Ah, got it, thanks. The copy is clearly bandwidth-bound, so we can just use memory bandwidth to estimate its cost.
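A minimal sketch of such a bandwidth-bound estimate (the function name, the bandwidth figure, and the efficiency knob are hypothetical, not the PR's actual code):

```python
def copy_cost_seconds(tensor_bytes, mem_bandwidth_bytes_per_s, efficiency=1.0):
    """Estimate the time of a full-tensor copy as a memory-bound operation."""
    # *2 because the copy reads the whole input and writes the whole output.
    read_write_bytes = 2 * tensor_bytes
    return read_write_bytes / (mem_bandwidth_bytes_per_s * efficiency)

# e.g. a 1 GiB tensor on a device with a (hypothetical) 2 TB/s peak memory
# bandwidth; assuming only 50% achieved efficiency doubles the estimate.
full_eff = copy_cost_seconds(1 << 30, 2e12)
half_eff = copy_cost_seconds(1 << 30, 2e12, efficiency=0.5)
assert half_eff == 2 * full_eff
```

Lowering the assumed efficiency inflates the modelled copy cost, which (as discussed below) also affects how aggressively the solver can prune expensive redistributions.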

@fmassa (Contributor, Author) commented Aug 13, 2025

The problem with this PR is that it increases the solver runtime on a model with a single transformer block from 2.95 s to 12.43 s: the *4 in the previous heuristic made some cases look far more expensive than they should, so the solver had an easier time pruning them... :-/

@fmassa (Contributor, Author) commented Aug 13, 2025

OK, so the solver time goes back to a reasonable level if we assume 50% I/O efficiency for the copy. I need to check whether that assumption is realistic.

@wconstab (Contributor)

Well, it's nice to clean up that hack, but to your earlier point, making the cost model align perfectly with reality matters less than making the system work well on real models. I'm wondering: does this change produce a better solution in some case?

@fmassa (Contributor, Author) commented Aug 16, 2025

Yes, I actually worked on this because the view->mm->view PR exposed the issue this PR is trying to solve.

The symptom was that the view->mm->view formulation and the einsum formulation should in principle yield the same solution, but they didn't; this PR is an attempt to fix that.

I still want to test this more thoroughly before merging, and I'll only merge it if I find it beneficial.

@ezyang (Contributor) commented Sep 2, 2025

See also pytorch/pytorch#161882

@fmassa (Contributor, Author) commented Sep 10, 2025

Subsumed by #125

@fmassa fmassa closed this Sep 10, 2025
@fmassa fmassa deleted the fmassa/compute_cost_in_comms branch September 10, 2025 15:15