
[BUG] DeepSpeed accuracy issue for torch.compile if activation checkpoint function not compiler disabled #6718

Open
jerrychenhf opened this issue Nov 6, 2024 · 0 comments
Labels: bug (Something isn't working), training


jerrychenhf commented Nov 6, 2024

Describe the bug
When training (fine-tuning) a LLAMA2 model with DeepSpeed, torch.compile, and activation checkpointing, we observed an accuracy drop if the DeepSpeed activation checkpoint function is not decorated with compiler.disable.

DeepSpeed implemented a workaround (WA) for this accuracy issue: #5590

However, the root cause was never identified. A proper fix that addresses the root cause should replace the workaround.

To Reproduce
The following conditions must all be met to reproduce the issue:

  1. The LLAMA2 7B model is used.
  2. torch.compile is used to compile the model being trained.
  3. PEFT LoRA is applied to the model.
  4. DeepSpeed ZeRO-3 training is used.
  5. Activation checkpointing is enabled.
  6. The compiler.disable decorator is removed from the DeepSpeed checkpoint function.
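Conditions 4 and 5 map onto the DeepSpeed configuration. A minimal sketch, assuming bf16 and a micro-batch size of 1 (illustrative values, not taken from the report). The dict would be passed as `config=` to `deepspeed.initialize`, and the model's layers still have to call `deepspeed.checkpointing.checkpoint` for DeepSpeed activation checkpointing to take effect:

```python
import json

# Illustrative DeepSpeed config; batch size and precision are assumed values.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    # Condition 4: ZeRO stage 3 partitions parameters across ranks.
    "zero_optimization": {"stage": 3},
    # Condition 5: options for DeepSpeed's activation checkpointing module.
    "activation_checkpointing": {
        "partition_activations": False,
        "contiguous_memory_optimization": False,
    },
}

print(json.dumps(ds_config, indent=2))
```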

Expected behavior
Training accuracy should be the same with or without the compiler.disable decorator on the DeepSpeed checkpoint function.
In practice, without the decorator, the final accuracy is much lower.

Investigation
We recently carried out an in-depth investigation of this DeepSpeed issue and identified its root cause. This bug entry tracks what we found behind the scenes, so that the PR we provide to fix it can be understood and accepted.

What is happening:

  1. With LoRA and activation checkpointing in use, the line `x = x.to(lora_A.weight.dtype)` in the LoRA layer's forward triggers a special path in ZeRO-3 parameter offloading: the `ZeROOrderedDict.__getitem__` method.
```python
# Excerpt from the PEFT LoRA layer's forward
result = self.base_layer(x, *args, **kwargs)

lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)  # accessing lora_A.weight goes through ZeROOrderedDict.__getitem__
result = result + lora_B(lora_A(dropout(x))) * scaling
```
  2. `ZeROOrderedDict.__getitem__` calls `param.all_gather()` if the parameter's data is not available:
```python
# Excerpt from ZeROOrderedDict.__getitem__ (DeepSpeed ZeRO-3)
if hasattr(param, "ds_status") and getattr(param, "ds_status") == ZeroParamStatus.NOT_AVAILABLE:
    if self._parent_module._parameters._in_forward:
        register_external_parameter(FWD_MODULE_STACK[-1], param)
        param.all_gather()
        print_rank_0(f'Registering external parameter from getter {key} ds_id = {param.ds_id}', force=False)
```
  3. In a torch.compile context, the `_allgather_params` function called by `param.all_gather()` ends up being compiled as well.
  4. In particular, torch.compile compiles the following line in `_allgather_params`:

```python
param.data = replicated_tensor.data
```
  5. Dynamo traces the assignment `param.data = replicated_tensor.data` as `param.set_(replicated_tensor.detach())`.
  6. When the in-place `set_` is called on `param`, the DeepSpeed hooks on its AccumulateGrad node no longer work as expected. We demonstrated this behavior in the PyTorch issue we submitted, "Hooks on param AccumulateGrad are not called when the param was called with set_" (pytorch/pytorch#139742). DeepSpeed registers these hooks as follows:
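```python
# Excerpt from DeepSpeed ZeRO-3 gradient hook registration (nested in a method)
def wrapper(param):
    # expand_as creates a view whose grad_fn points back to param's AccumulateGrad node
    param_tmp = param.expand_as(param)
    grad_acc = param_tmp.grad_fn.next_functions[0][0]

    @instrument_w_nvtx
    def reduce_partition_and_remove_grads(*notneeded):
        self.reduce_ready_partitions_and_remove_grads(param)

    # The hook fires after the AccumulateGrad node has run for this param
    self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
    # Keep the node alive: the param only holds a weak reference to it
    self.grad_accs.append(grad_acc)
```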
  7. When `reduce_partition_and_remove_grads` is not called for the parameters that went through `set_`, their gradient data is wrong, which causes the accuracy issue.
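Steps 1 to 5 above can be reproduced in miniature without DeepSpeed. The sketch below is a simplified, hypothetical stand-in: `LazyGatherParams`, `gather_full_weight`, and the status strings are illustrative names, not DeepSpeed's API. It relies on `nn.Module.__getattr__` resolving `module.weight` through `module._parameters[...]`, so a dict subclass installed there intercepts the access, much like `ZeROOrderedDict` does:

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical stand-ins for ZeroParamStatus values.
NOT_AVAILABLE, AVAILABLE = "NOT_AVAILABLE", "AVAILABLE"


def gather_full_weight(param: nn.Parameter, replicated_tensor: torch.Tensor) -> None:
    # Mirrors the line in _allgather_params (step 4). In eager mode this swaps the
    # parameter's data; under torch.compile, Dynamo traces it as
    # param.set_(replicated_tensor.detach()) (step 5).
    param.data = replicated_tensor.data
    param.ds_status = AVAILABLE


class LazyGatherParams(OrderedDict):
    """Hypothetical ZeROOrderedDict-like container that gathers on item access."""

    def __init__(self, full_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.full_weights = full_weights  # name -> full tensor, simulating other ranks
        self.gathered = []  # names gathered so far, for inspection

    def __getitem__(self, key):
        param = super().__getitem__(key)
        if getattr(param, "ds_status", None) == NOT_AVAILABLE:
            gather_full_weight(param, self.full_weights[key])
            self.gathered.append(key)
        return param


torch.manual_seed(0)
lora_A = nn.Linear(4, 2, bias=False)
full_weights = {"weight": lora_A.weight.detach().clone()}

# Simulate a ZeRO-3 partitioned parameter: no local data until gathered.
lora_A.weight.data = torch.empty(0)
lora_A.weight.ds_status = NOT_AVAILABLE
# nn.Module.__getattr__ resolves `lora_A.weight` via `lora_A._parameters[...]`.
lora_A._parameters = LazyGatherParams(full_weights, lora_A._parameters)

x = torch.randn(3, 4, dtype=torch.float64)
x = x.to(lora_A.weight.dtype)  # step 1: this attribute access triggers the gather
print(lora_A._parameters.gathered)  # ['weight']
print(lora_A.weight.shape)  # torch.Size([2, 4])

# Step 5 in eager form: the traced equivalent of the assignment above.
p = nn.Parameter(torch.empty(0))
with torch.no_grad():
    p.set_(full_weights["weight"].detach())
print(p.shape)  # torch.Size([2, 4])
```

In eager mode both forms leave a usable leaf parameter; the problem reported here only shows up in how the AccumulateGrad hooks behave once the `set_` form is what actually runs.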

Based on the PyTorch community discussion in pytorch/pytorch#139742, `register_post_accumulate_grad_hook` is the robust API for this work, rather than hooks on the AccumulateGrad node.

I have verified that `register_post_accumulate_grad_hook` works and solves the accuracy issue in the original environment where it was observed.

jerrychenhf added the bug and training labels on Nov 6, 2024
tohtana self-assigned this on Nov 8, 2024
deepcharm added a commit to deepcharm/DeepSpeed that referenced this issue on Nov 21, 2024
* This commit addresses an issue reported in microsoft#6718.
* The existing code has been using the grad_acc node hook to reduce params grads.
  Constructs such as `param.data = replicated_tensor.data` used in
  `allgather_params(..)` are compiled into `param.set_()`, so the hook assigned
  to the grad_acc node is not called.
* This is a known torch issue: pytorch/pytorch#139742.
* The above caused accuracy issues, which could be temporarily solved by simply
  disabling torch.compile when activation checkpointing is used.
* This commit provides a clean solution by replacing the hook on the grad_acc node
  with a hook that uses a new and robust hook API on the param itself:
  `param.register_post_accumulate_grad_hook(..)`.
loadams added a commit that referenced this issue on Jan 3, 2025
* This commit addresses DeepSpeed issue #6718.
* The existing code has been using the grad_acc node hook to reduce params grads.
  Constructs such as `param.data = replicated_tensor.data` used in
  `allgather_params(..)` are compiled into `param.set_()`, so the hook assigned
  to the grad_acc node is not called.
* Starting from PyTorch 2.1, there is a new and robust hook API on the param
  itself: `param.register_post_accumulate_grad_hook(..)`.
* This commit uses the appropriate API depending on the PyTorch version.
* It also disables compile for PyTorch versions < 2.1.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
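The version-dependent choice described in this commit can be sketched as follows. This is an illustration under stated assumptions, not the merged DeepSpeed code: `torch_version_at_least`, `register_grad_ready_hook`, and `compile_allowed` are hypothetical helpers, and only the 2.1 threshold comes from the commit message.

```python
import re


def torch_version_at_least(version: str, major: int, minor: int) -> bool:
    """Compare the leading 'major.minor' of a version string such as '2.1.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= (major, minor)


def register_grad_ready_hook(param, callback, torch_version: str):
    """Hypothetical helper: call callback(param) once param's gradient is accumulated.

    Returns (handle, keepalive); on the old path, keepalive must stay referenced.
    """
    if torch_version_at_least(torch_version, 2, 1):
        def post_acc_hook(p):
            callback(p)  # post-accumulate-grad hooks are expected to return None

        return param.register_post_accumulate_grad_hook(post_acc_hook), None

    # Fallback for PyTorch < 2.1: hook the AccumulateGrad node (affected by set_).
    grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

    def node_hook(*unused):
        callback(param)

    return grad_acc.register_hook(node_hook), grad_acc


def compile_allowed(torch_version: str) -> bool:
    # Per the commit message, compile is disabled for PyTorch < 2.1.
    return torch_version_at_least(torch_version, 2, 1)


print(torch_version_at_least("2.5.1+cu121", 2, 1))  # True
print(torch_version_at_least("2.0.1", 2, 1))  # False
```

Parsing only the leading `major.minor` keeps local-version suffixes such as `+cu121` from affecting the comparison.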