Describe the bug
When running DeepSpeed with torch.compile and activation checkpointing for LLAMA2 model training (fine-tuning), we found accuracy issues if the activation checkpoint function is not excluded from compilation with the compiler.disable decorator.
DeepSpeed implemented a workaround (WA) for this accuracy issue: #5590
However, the root cause was never identified, and a proper fix that addresses the root cause should be provided instead of a WA.
To Reproduce
The following conditions need to be satisfied to reproduce:
The LLAMA2 7B model is used
torch.compile is used to compile the model being trained
PEFT LoRA is applied to the model
DeepSpeed ZeRO3 is used for training
Activation checkpointing is enabled
The compiler.disable decorator is removed from the DeepSpeed checkpoint function
Expected behavior
The accuracy should be the same with or without the compiler.disable decorator on the DeepSpeed checkpoint function.
However, without the compiler.disable decorator on the DeepSpeed checkpoint function, the final accuracy is much lower.
Investigation
We recently conducted an in-depth investigation of this DeepSpeed issue and found its root cause. This bug entry tracks the problem we found so that we can provide a fix PR that is easy to understand and accept.
What is happening:
When LoRA and activation checkpointing are in use, the line `x = x.to(lora_A.weight.dtype)` triggers a special path in ZeRO3 parameter offloading: the `ZeROOrderedDict` `__getitem__` method.
```python
# Excerpt from the PEFT LoRA layer's forward pass
result = self.base_layer(x, *args, **kwargs)
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)  # reading lora_A.weight goes through ZeROOrderedDict.__getitem__
result = result + lora_B(lora_A(dropout(x))) * scaling
```
`ZeROOrderedDict.__getitem__` calls `param.all_gather()` if the param data is not available:
```python
if hasattr(param, "ds_status") and getattr(param, "ds_status") == ZeroParamStatus.NOT_AVAILABLE:
    if self._parent_module._parameters._in_forward:
        register_external_parameter(FWD_MODULE_STACK[-1], param)
        param.all_gather()
        print_rank_0(f'Registering external parameter from getter {key} ds_id = {param.ds_id}', force=False)
```
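To make the trigger concrete: `nn.Module.__getattr__` resolves `lora_A.weight` by looking up `self._parameters["weight"]`, so reading even just `.dtype` goes through the `__getitem__` shown above. The pure-Python toy below sketches that behavior under simplifying assumptions: the names `ParamStatus`, `ToyParam` and `LazyGatherDict` are hypothetical, and it omits `register_external_parameter` and the `_in_forward` check. A plain dictionary read gathers a partitioned parameter the first time it is accessed.

```python
from collections import OrderedDict
from enum import Enum


class ParamStatus(Enum):
    # Hypothetical mirror of the status check in the excerpt above.
    AVAILABLE = 1
    NOT_AVAILABLE = 2


class ToyParam:
    """Hypothetical partitioned parameter: the full value is held back until gathered."""

    def __init__(self, full_value):
        self._full_value = full_value
        self.value = None
        self.ds_status = ParamStatus.NOT_AVAILABLE
        self.gather_count = 0

    def all_gather(self):
        # Stand-in for the real collective: materialize the full value locally.
        self.value = self._full_value
        self.ds_status = ParamStatus.AVAILABLE
        self.gather_count += 1


class LazyGatherDict(OrderedDict):
    """Gathers a parameter on read if it is not available yet."""

    def __getitem__(self, key):
        param = super().__getitem__(key)
        if getattr(param, "ds_status", None) == ParamStatus.NOT_AVAILABLE:
            param.all_gather()
        return param


params = LazyGatherDict(weight=ToyParam([1.0, 2.0]))
w = params["weight"]        # first read triggers the gather
w_again = params["weight"]  # already available, so no second gather
```

In the real code this gather runs inside the forward pass being compiled, which is how `_allgather_params` ends up traced by torch.compile.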
When this happens in a torch.compile context, the `_allgather_params` function called by `param.all_gather` is eventually compiled as well, including the following line in `_allgather_params`: `param.data = replicated_tensor.data`
Dynamo traces this `param.data = replicated_tensor.data` assignment to `param.set_(replicated_tensor.detach())`.
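DeepSpeed reduces parameter gradients through a hook registered on each parameter's gradient accumulation (grad_acc) node. For the parameters that went through `set_`, that hook is not called, so `reduce_partition_and_remove_grads` is never invoked for them; their gradient data is therefore wrong, which causes the accuracy issue.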
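The failure mode can be pictured with a toy model. This is not PyTorch's implementation: `ToyAccumulateNode`, `ToyParam`, and the assumption that the traced `set_()` leaves gradients flowing through a different accumulation node are all hypothetical, chosen only to reproduce the symptom described above (the hook on the grad_acc node stops firing) next to a hook registered on the parameter itself, which is the approach the fix takes.

```python
from typing import Callable, List, Optional


class ToyAccumulateNode:
    """Hypothetical stand-in for an autograd gradient-accumulation node."""

    def __init__(self) -> None:
        self.hooks: List[Callable[[], None]] = []

    def accumulate(self, param: "ToyParam", grad: float) -> None:
        param.grad = grad if param.grad is None else param.grad + grad
        for hook in self.hooks:
            hook()


class ToyParam:
    """Hypothetical parameter that owns its post-accumulate hooks."""

    def __init__(self, data: float) -> None:
        self.data = data
        self.grad: Optional[float] = None
        self.acc_node = ToyAccumulateNode()
        self.post_accumulate_hooks: List[Callable[["ToyParam"], None]] = []

    def set_(self, new_data: float) -> None:
        # ASSUMPTION (toy model only): after the traced set_(), gradients are
        # accumulated through a node other than the one the old hook is on.
        self.data = new_data
        self.acc_node = ToyAccumulateNode()

    def receive_grad(self, grad: float) -> None:
        # Accumulate, then run the hooks registered on the parameter itself.
        self.acc_node.accumulate(self, grad)
        for hook in self.post_accumulate_hooks:
            hook(self)


calls = {"node_hook": 0, "param_hook": 0}


def node_hook() -> None:
    calls["node_hook"] += 1


def param_hook(param: ToyParam) -> None:
    calls["param_hook"] += 1


p = ToyParam(0.0)
p.acc_node.hooks.append(node_hook)          # old approach: hook on the node
p.post_accumulate_hooks.append(param_hook)  # new approach: hook on the param

p.set_(1.0)          # models param.data = replicated_tensor.data under compile
p.receive_grad(0.5)  # backward reaches the parameter
```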
Based on the PyTorch community discussion in pytorch/pytorch#139742, `register_post_accumulate_grad_hook` is the robust API for this purpose, rather than hooks on the AccumulateGrad node (AccuNode hooks).
I have verified that `register_post_accumulate_grad_hook` works and resolves the accuracy issue in the original environment where it was observed.
* This commit addresses an issue reported in microsoft#6718.
* The existing code has been using the grad_acc node hook to reduce param grads.
Constructs such as param.data = replicated_tensor.data used in
allgather_params(..) are compiled into param.set_(), so the hook assigned
to the grad_acc node is not called.
* This is a known torch issue pytorch/pytorch#139742.
* This caused accuracy issues, which were temporarily worked around by
disabling torch.compile when activation checkpointing is used.
* This commit provides a clean solution by replacing the hook on the grad_acc node
with a hook registered through a new and robust API on the param itself:
param.register_post_accumulate_grad_hook(..)
* This commit addresses DeepSpeed issue
[#6718](#6718)
* The existing code has been using the grad_acc node hook to reduce
param grads.
Constructs such as `param.data = replicated_tensor.data` used in
`allgather_params(..)`
are compiled into `param.set_()`, so the hook assigned to the
grad_acc node is not called.
* Starting with PyTorch 2.1, there is a new and robust hook API on the
param itself: `param.register_post_accumulate_grad_hook(..)`
* This commit will make use of the proper API depending on the PyTorch
version
* It will also disable compile for PyTorch versions < 2.1
---------
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
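A minimal sketch of a version gate like the one described in this commit, assuming the version string follows PyTorch's `major.minor.patch[+local]` form. The helper name `supports_post_accumulate_grad_hook` is hypothetical and not DeepSpeed's code. Comparing integer tuples rather than strings keeps versions such as `2.10` ordered correctly.

```python
def supports_post_accumulate_grad_hook(torch_version: str) -> bool:
    """Return True if the PyTorch version string is 2.1 or newer (hypothetical helper).

    Local build suffixes such as '+cu121' are ignored; only major.minor is compared.
    """
    release = torch_version.split("+")[0]
    major, minor = (int(part) for part in release.split(".")[:2])
    return (major, minor) >= (2, 1)


# Choose the hook strategy for a given version string.
for version in ("1.13.1", "2.0.1", "2.1.0", "2.4.0+cu121"):
    path = ("post_accumulate_grad_hook"
            if supports_post_accumulate_grad_hook(version)
            else "grad_acc node hook, compile disabled")
    print(version, "->", path)
```

In practice the string would come from `torch.__version__`; feature detection with `hasattr(torch.Tensor, "register_post_accumulate_grad_hook")` is an alternative that avoids parsing altogether.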