[simplefsdp] fix simplefsdp gradient_divide_factor #1793
Conversation
> After enabling reduction_divide_factor, we will see FSDP(=2) + EP(=4) have identical loss:

From the pictures, it doesn't look like they are identical. Did you fix the seed when comparing FSDP2 vs. SimpleFSDP?
    torch.float32,
    torch.bfloat16,
), "only support reduce_dtype to be fp32/bf16"
pre_factor, post_factor = self.reduction_divide_factor, None
We should do a post-multiply, according to https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L738-L739
Hmm, according to the docs ("PREMUL_SUM multiplies inputs by a given scalar locally before reduction", see the link), I think we should use pre_factor instead of post_factor. Besides, this diff also does the division locally before calling reduce_scatter: https://www.internalfb.com/diff/D76546536
Hmm, why do we care about PREMUL_SUM and MTIA?
If you have a non-overflow dtype and a non-MTIA device, this function uses PREMUL_SUM: _get_gradient_divide_factors returns None for both pre_factor and post_factor.
I didn't use PREMUL_SUM for the reason mentioned in the PR description. Instead, I simulate what PREMUL_SUM does, following the MTIA path, by dividing by self.reduction_divide_factor locally and then calling reduce_scatter("sum").
Using either pre_factor or post_factor gives the same loss results, but I'm still concerned that if we do the SUM first and divide afterwards, the gathered values might overflow.
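For concreteness, here is a minimal single-process sketch of that overflow concern (plain torch, no distributed setup; the numbers are only illustrative):

```python
import torch

# Pretend 64 ranks each contribute a large bf16 partial gradient.
world_size = 64
shard = torch.full((4,), 1e37, dtype=torch.bfloat16)   # bf16 max is ~3.4e38

# SUM first, divide after: the intermediate sum overflows to inf.
summed = shard * world_size          # stands in for reduce_scatter("sum")
print(summed / world_size)           # inf everywhere

# Divide locally first (the pre_factor path), then SUM: stays finite.
print((shard / world_size) * world_size)   # ~1e37, no overflow
```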
    ac_mode=job_config.activation_checkpoint.mode,
    mp_policy=mp_policy,
    shard_dim=experts_shard_dim,
    reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
In the future we should probably deprecate this logic anyway. See the related PR #1803 (comment).
I will add a TODO for it.
          
Turns out I didn't enable deterministic mode; I've updated the figure.
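For reference, a minimal sketch of pinning down determinism in plain PyTorch when comparing the two runs (the exact torchtitan config knob is not shown here; this is just the generic PyTorch setup):

```python
import torch

def enable_determinism(seed: int = 0) -> None:
    # Use the same seed for both runs so parameter init and data order match.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic kernels (may be slower; ops without a
    # deterministic implementation will raise).
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
```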
    
> Using either pre_factor or post_factor gives the same loss results, but I'm still concerned that if we do the SUM first and divide afterwards, the gathered values might overflow.

I see your concern. I'm not sure how realistic it is, but OK to stick with this for now.
This PR is to unblock SimpleFSDP+`gradient_divide_factor` [here](pytorch/torchtitan#1793). We need to create a subclass of DTensor's `Partial` placement. When tracing `SimpleFSDPPartial`, I hit the assertion error that `SimpleFSDPPartial` is not in `ok_types`. I'm updating the code to check the placement type via `isinstance` instead of `type(val)`. Pull Request resolved: #164985. Approved by: https://github.com/ezyang, https://github.com/eellison
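Roughly, the shape of that change looks like the sketch below (`SimpleFSDPPartial` is the subclass named above; the `ok_types` check here is illustrative rather than the exact upstream code, and assumes `Partial` is importable from `torch.distributed.tensor.placement_types`):

```python
from torch.distributed.tensor.placement_types import Partial

class SimpleFSDPPartial(Partial):
    """Partial placement subclass carrying SimpleFSDP's custom reduction (illustrative)."""

ok_types = (Partial,)
val = SimpleFSDPPartial()

type(val) in ok_types      # False: type() rejects subclasses, hence the assertion error
isinstance(val, ok_types)  # True: isinstance accepts Partial and its subclasses
```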
This PR is a follow-up to the SimpleFSDP+EP [PR](pytorch#1529). Here, we add a `gradient_divide_factor` following FSDP2 to ensure that modules wrapped by (FSDP+EP) have the correct gradient reduction value.

- The original FSDP2 implementation is in this [PR](pytorch#1551).
- The `gradient_divide_factor` logic is [here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688).

We have two ways of handling `gradient_divide_factor` in `reduce_scatter`:

1. Use `ReduceOp.PREMUL_SUM` to handle the `gradient_divide_factor`. However, DTensor's `_reduce_shard_value` only accepts `reduce_op` as a str input ([here](https://github.com/pytorch/pytorch/blob/8f705d019a64b1ca882e043b3eb98559273a9e59/torch/distributed/tensor/placement_types.py#L177-L210)). To make `_reduce_shard_value` work correctly with `ReduceOp.PREMUL_SUM`, we would need to update DTensor's `_reduce_shard_tensor` and `torch.distributed._functional_collectives.reduce_scatter_tensor` so that the factor associated with `ReduceOp.PREMUL_SUM` can be passed in as an input.
2. Simulate `ReduceOp.PREMUL_SUM` with `ReduceOp.SUM`. The logic is in this [Diff](https://www.internalfb.com/diff/D76546536): it does a `div_` over the gradient before performing `ReduceOp.SUM` (see the sketch after this description).

Currently I'm following option 2 since it requires fewer changes to `_functional_collectives`.

After enabling `reduction_divide_factor`, we see that FSDP(=2) + EP(=4) have identical loss:

[Screenshot: loss curves for FSDP(=2) + EP(=4) match.](https://github.com/user-attachments/assets/aaf83109-8db8-4051-973d-c7b6950513de)
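As mentioned in option 2 above, the reduction can be expressed as a local division followed by a plain SUM reduce-scatter. A minimal sketch under simplified assumptions (a raw ProcessGroup and a dense gradient tensor rather than DTensor's `_reduce_shard_value` path; `reduce_scatter_grad` is a hypothetical helper, not the PR's actual code):

```python
import torch
import torch.distributed as dist

def reduce_scatter_grad(
    grad: torch.Tensor,
    group: dist.ProcessGroup,
    reduction_divide_factor: float,
) -> torch.Tensor:
    """Simulate ReduceOp.PREMUL_SUM: pre-divide locally, then reduce_scatter with SUM."""
    world_size = dist.get_world_size(group)
    # Local pre-division (the pre_factor path); keeps per-rank values small
    # before they are summed across ranks.
    grad = grad / reduction_divide_factor
    out = grad.new_empty(grad.shape[0] // world_size, *grad.shape[1:])
    dist.reduce_scatter_tensor(out, grad, op=dist.ReduceOp.SUM, group=group)
    return out
```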