Conversation

@deepcharm (Contributor)
This PR is a continuation of the efforts to improve DeepSpeed performance when using PyTorch compile.

Dynamo breaks the graph because `flat_tensor.requires_grad = False`:

  • Is a side-effecting operation on tensor metadata
  • Occurs in a context where Dynamo expects static tensor properties during tracing

The `flat_tensor.requires_grad = False` assignment is redundant and can be safely removed because:

  • The `_allgather_params()` function is already decorated with `@torch.no_grad()`, which ensures the desired property
  • `flat_tensor` is created with `torch.empty()`, which sets `requires_grad=False` by default

Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
This reverts commit 11612773b3d68aa5b8d72bad1de4b1714ea1193a.

Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
@loadams loadams enabled auto-merge March 20, 2025 19:46
@loadams loadams added this pull request to the merge queue Mar 22, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Mar 23, 2025
@tjruwase tjruwase added this pull request to the merge queue Mar 24, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Mar 24, 2025
@loadams loadams added this pull request to the merge queue Mar 24, 2025
Merged via the queue into deepspeedai:master with commit d40cf46 Mar 24, 2025
11 checks passed
loadams added a commit that referenced this pull request Mar 25, 2025

Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Mar 28, 2025
…epspeedai#7158)


Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
deepcharm added a commit to deepcharm/DeepSpeed that referenced this pull request Apr 29, 2025
This PR is a follow-up to PR (deepspeedai#7158), handling the same issue.

Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
github-merge-queue bot pushed a commit that referenced this pull request May 19, 2025
…7263)

This PR is a follow-up to [PR #7158](#7158), handling the same issue in another place. See [PR #7158](#7158) for details.

---------

Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
ys950902 pushed a commit to ys950902/DeepSpeed that referenced this pull request May 21, 2025
…epspeedai#7158)


Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Signed-off-by: yisheng <yi.sheng@intel.com>
deepcharm added a commit to deepcharm/DeepSpeed that referenced this pull request Jun 16, 2025
…eepspeedai#7263)


Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025
…eepspeedai#7263)


Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Labels: None yet
Projects: None yet
4 participants