Avoid graph breaks in torch.compile caused by inner classes in the backward hooks #7062

Open

deepcharm wants to merge 16 commits into master
Conversation

deepcharm (Contributor) commented:
This PR is part of the effort to improve DeepSpeed performance when using PyTorch compile.

There is a known bug in torch.compile (pytorch/pytorch#128942) which causes a graph break when an inner class is defined within a method that is being compiled. The following then appears in the log:

[__graph_breaks] torch._dynamo.exc.Unsupported: missing: LOAD_BUILD_CLASS

This is the case with the inner classes PreBackwardFunctionForModule and PostBackwardFunctionModule.

While there is an open PyTorch PR for this (pytorch/pytorch#133805), we can solve the issue by moving the inner classes into the initialization code.

No graph breaks occur, and the corresponding log messages are no longer produced.
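For illustration, a minimal sketch of the pattern (the wrapper class and hook body here are simplified placeholders, not the exact DeepSpeed code):

```python
import torch

class BackwardHookSetup:

    def __init__(self):
        # Defining the autograd.Function once, at initialization time, keeps
        # the LOAD_BUILD_CLASS bytecode out of any method torch.compile traces.
        class PreBackwardFunctionForModule(torch.autograd.Function):

            @staticmethod
            def forward(ctx, module, output):
                ctx.module = module
                return output

            @staticmethod
            def backward(ctx, *grad_outputs):
                # Pre-backward bookkeeping for ctx.module would run here.
                return (None,) + grad_outputs

        self._pre_backward_function = PreBackwardFunctionForModule

    def apply_pre_backward(self, module, output):
        # This method may be compiled; it only references the prebuilt class
        # instead of defining it inline, so no graph break occurs here.
        return self._pre_backward_function.apply(module, output)
```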

tjruwase (Contributor) commented:

@deepcharm, gentle ping on DCO blocker

tjruwase and others added 14 commits February 26, 2025 16:08
Propagate API change.

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
- add zero2 test
- minor fix with transformer version update & ds master merge.

Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
BF16 with MoE: refreshing optimizer state from a BF16 checkpoint raises
IndexError: list index out of range

Signed-off-by: shaomin <wukon1992@gmail.com>
Co-authored-by: shaomin <wukon1992@gmail.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
**Auto-generated PR to update version.txt after a DeepSpeed release**
Released version - 0.16.4
Author           - @loadams

Co-authored-by: loadams <loadams@users.noreply.github.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
@jeffra and I fixed this many years ago, so bringing this doc to a
correct state.

---------

Signed-off-by: Stas Bekman <stas@stason.org>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Description
This PR adds Tecorigin SDAA accelerator support.
With this PR, DeepSpeed supports SDAA as a backend for training tasks.

---------

Signed-off-by: siqi <siqi@tecorigin.com>
Co-authored-by: siqi <siqi@tecorigin.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
More information on libuv in pytorch:
https://pytorch.org/tutorials/intermediate/TCPStore_libuv_backend.html
Issue tracking the prevalence of the error on Windows (unresolved at the
time of this PR): pytorch/pytorch#139990
LibUV GitHub: https://github.com/libuv/libuv

Windows error:
```
  File "C:\hostedtoolcache\windows\Python\3.12.7\x64\Lib\site-packages\torch\distributed\rendezvous.py", line 189, in _create_c10d_store
    return TCPStore(
           ^^^^^^^^^
RuntimeError: use_libuv was requested but PyTorch was build without libuv support
```

use_libuv isn't well supported on Windows in pytorch <2.4, so we need to
guard against this case.
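For illustration, one way such a guard can be expressed (a hypothetical helper, not necessarily the exact change in this commit):

```python
import sys

import torch
from packaging import version

def libuv_supported() -> bool:
    # use_libuv is not reliably supported on Windows before torch 2.4,
    # so fall back to the non-libuv TCPStore code path there.
    if sys.platform == "win32":
        return version.parse(torch.__version__).release >= (2, 4)
    return True
```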

---------

Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
@fukun07 and I discovered a bug when using the `offload_states` and
`reload_states` APIs of the Zero3 optimizer. When using grouped
parameters (for example, in weight decay or grouped lr scenarios), the
order of the parameters mapping in `reload_states`
([here](https://github.com/deepspeedai/DeepSpeed/blob/14b3cce4aaedac69120d386953e2b4cae8c2cf2c/deepspeed/runtime/zero/stage3.py#L2953))
does not correspond with the initialization of `self.lp_param_buffer`
([here](https://github.com/deepspeedai/DeepSpeed/blob/14b3cce4aaedac69120d386953e2b4cae8c2cf2c/deepspeed/runtime/zero/stage3.py#L731)),
which leads to misaligned parameter loading. This issue was overlooked
by the corresponding unit tests
([here](https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/runtime/zero/test_offload_states.py)),
so we fixed the bug in our PR and added the corresponding unit tests.
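For illustration, a standalone sketch (plain PyTorch, hypothetical model, not the DeepSpeed code) showing how grouped parameter groups reorder parameters relative to module order, which is the mismatch described above:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

# Grouped parameters, e.g. for per-group weight decay.
decay = [p for n, p in model.named_parameters() if not n.endswith("bias")]
no_decay = [p for n, p in model.named_parameters() if n.endswith("bias")]
optimizer = torch.optim.SGD(
    [{"params": decay, "weight_decay": 0.01},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=0.1,
)

# Iterating group-by-group yields [w0, w1, b0, b1] ...
grouped_order = [p for g in optimizer.param_groups for p in g["params"]]
# ... while the module yields [w0, b0, w1, b1].
module_order = list(model.parameters())

# Indexing a flat buffer built in one order with the other misaligns loads.
print([id(p) for p in grouped_order] == [id(p) for p in module_order])  # False
```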

---------

Signed-off-by: Wei Wu <wuwei211x@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Following changes in PyTorch trace rules, my previous PR to avoid graph
breaks caused by logger calls is no longer relevant. Instead, I've added
this functionality to torch dynamo:
pytorch/pytorch@16ea0dd
This commit allows the user to configure torch to ignore logger methods
and avoid the associated graph breaks.

To ignore all logger methods:
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
To ignore all logger methods except specific ones (for example, info and
isEnabledFor), additionally set:
os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "info,isEnabledFor"

Signed-off-by: ShellyNR <shelly.nahir@live.biu.ac.il>
Co-authored-by: snahir <snahir@habana.ai>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
The partition tensor doesn't need to move to the current device when
meta load is used.

Signed-off-by: Lai, Yejing <yejing.lai@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
…t` (deepspeedai#7069)

With future changes coming to pip/python/etc., we need to stop calling
`python setup.py ...` and replace it, as described here:
https://packaging.python.org/en/latest/guides/modernize-setup-py-project/#should-setup-py-be-deleted

![image](https://github.com/user-attachments/assets/ea39ef7b-3cbe-4916-86f0-bc46a5fce96d)

This means we need to install the `build` package, which is added here as
well.

Additionally, we pass the `--sdist` flag to build only the sdist rather
than the wheel.

---------

Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
…eepspeedai#7076)

This reverts commit 8577bd2.

Fixes: deepspeedai#7072
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
deepcharm force-pushed the avoid-graph-break-caused-by-inner-classes branch from 561bd6e to d4efb6a (February 26, 2025 14:08)
deepcharm (Contributor, Author) commented:

> @deepcharm, gentle ping on DCO blocker

@tjruwase DCO should be fine now, sorry for the delay

loadams enabled auto-merge February 26, 2025 15:48