
Conversation

@tianyu-l
Contributor

Issue pointed out in
#1534 (comment) and
pytorch/pytorch#160285.

Solution given by @rakkit in #1534 (comment).

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 11, 2025
@tianyu-l tianyu-l merged commit 59e57a4 into main Aug 11, 2025
7 checks passed
@tianyu-l tianyu-l deleted the fix branch August 11, 2025 23:54
@tianyu-l tianyu-l linked an issue Aug 12, 2025 that may be closed by this pull request
ruisizhang123 added a commit that referenced this pull request Oct 9, 2025
This PR is a follow-up to the SimpleFSDP+EP
[PR](#1529). Here, we add a
`gradient_divide_factor` following FSDP2 to ensure that modules wrapped by
(FSDP+EP) have the correct gradient reduction value.

- The original FSDP2 implementation is in this
[PR](#1551).
- The `gradient_divide_factor` logic is
[here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688).
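
As a rough illustration of what such a divide factor does (a sketch of the FSDP-style scheme, not the linked implementation; the helper names below are invented), the division by world size can be split into a pre-reduce and a post-reduce factor so that low-precision gradients are less likely to overflow or underflow while being summed:

```python
def get_pre_post_divide_factors(world_size: int):
    # Split the division by world_size into two roughly equal power-of-2
    # factors: one applied to each local gradient before the reduction,
    # one applied to the reduced result after (hypothetical helper).
    pre = 1
    while world_size % pre == 0 and world_size / pre > pre:
        pre *= 2
    return pre, world_size / pre

def averaged_gradient(local_grads):
    # sum(g / pre) / post == mean(g), since pre * post == world_size.
    world_size = len(local_grads)
    pre, post = get_pre_post_divide_factors(world_size)
    return sum(g / pre for g in local_grads) / post

grads = [1.0, 2.0, 3.0, 4.0]
assert averaged_gradient(grads) == sum(grads) / len(grads)
```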

We have two ways of handling `gradient_divide_factor` in
`reduce_scatter`:

1. The first is to use `ReduceOp.PREMUL_SUM` to apply the
`gradient_divide_factor`. However, DTensor's `_reduce_shard_value` only
accepts `reduce_op` as a `str` input
([here](https://github.com/pytorch/pytorch/blob/8f705d019a64b1ca882e043b3eb98559273a9e59/torch/distributed/tensor/placement_types.py#L177-L210)).

To make `_reduce_shard_value` work correctly with `ReduceOp.PREMUL_SUM`,
we would need to update DTensor's `_reduce_shard_tensor` and
`torch.distributed._functional_collectives.reduce_scatter_tensor` so
that they can pass the factor associated with `ReduceOp.PREMUL_SUM` as an
input.



2. Another way is to simulate `ReduceOp.PREMUL_SUM` with `ReduceOp.SUM`.
The logic is in this [Diff](https://www.internalfb.com/diff/D76546536).
It applies `div_` to the gradient before performing `ReduceOp.SUM`.

Currently I'm following option 2, since it requires fewer changes to
`_functional_collectives`.
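
The equivalence that option 2 relies on can be checked with plain numbers (a minimal sketch, not the actual Diff; the factor value is just illustrative):

```python
def premul_sum(local_grads, factor):
    # Reference semantics of ReduceOp.PREMUL_SUM: sum_i(factor * g_i).
    return sum(factor * g for g in local_grads)

def simulated_premul_sum(local_grads, factor):
    # Option 2: scale (divide) each local gradient first, then do a
    # plain SUM, mirroring the div_-before-reduce_scatter trick.
    scaled = [g * factor for g in local_grads]
    return sum(scaled)

grad_divide_factor = 4  # e.g. the number of shards contributing to the sum
grads = [1.0, 2.0, 3.0, 4.0]
assert premul_sum(grads, 1 / grad_divide_factor) == \
    simulated_premul_sum(grads, 1 / grad_divide_factor)
```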


After enabling `reduction_divide_factor`, we see that FSDP (degree 2) + EP (degree 4)
has identical loss:

<img width="1194" height="780" alt="Screenshot 2025-10-08 at 5 27 24 PM"
src="https://github.com/user-attachments/assets/aaf83109-8db8-4051-973d-c7b6950513de"
/>
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 13, 2025

Successfully merging this pull request may close these issues.

Wrong-size gradients in Expert Parallel MoE

3 participants