Representing DTensor in thunder traces #1907
Status: Open. kshitij12345 wants to merge 41 commits into Lightning-AI:main from kshitij12345:dtensor-init-support (base: main).
+675 −6
Commits (41), showing changes from all commits:
5873742 dtensor support (kshitij12345)
377125a add comment (kshitij12345)
7ab82f6 add more comments (kshitij12345)
e6aa8d3 update comment (kshitij12345)
e76fc17 add test for execpted failing cases (kshitij12345)
eaac9f7 support for method (kshitij12345)
94ef69d update failing case test (kshitij12345)
5d81851 remove generated traces (kshitij12345)
7277753 undo pre-commit change (kshitij12345)
a8c58e4 undo debug changes (kshitij12345)
d87b103 update failing test to use thunder.jit (kshitij12345)
b101161 update registration helper (kshitij12345)
b551cb8 Apply suggestions from code review (kshitij12345)
1c75a80 Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into… (kshitij12345)
5854c86 address review and upadte (kshitij12345)
a778830 update dtensor proxy repr (kshitij12345)
41990d0 Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into… (kshitij12345)
eda0277 Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into… (kshitij12345)
8abf040 Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into… (kshitij12345)
225f2e3 update jit_ext access to torchfn_to_thunder registry : test (kshitij12345)
2b85b31 empty commit (kshitij12345)
5d0296f Revert "update jit_ext access to torchfn_to_thunder registry : test" (kshitij12345)
dedab03 temp commit (kshitij12345)
efaae1d Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into… (kshitij12345)
ddcf208 update to manual decomp (kshitij12345)
6a6bf11 add manual grad rule (kshitij12345)
2a8ea02 update (kshitij12345)
9490cac update - clean-up (kshitij12345)
bd1ecbb update attrs on DTensorProxy (kshitij12345)
5d1f20b Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into… (kshitij12345)
255c82d remove debug change (kshitij12345)
dba02d2 remove unused imports (kshitij12345)
e8f6d0b remove unused import (kshitij12345)
83ae80a update function name (kshitij12345)
4cd9bec cotangent metadata check initial support (kshitij12345)
b49df80 address review : p1 (kshitij12345)
b206632 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
b70238a address review (kshitij12345)
1ff76b3 Merge branch 'dtensor-init-support' of https://github.com/kshitij1234… (kshitij12345)
3e295da update and refactor (kshitij12345)
8f4a029 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
New file (+115 lines), DTensor tests:
import unittest

import pytest
import torch

if not torch.distributed.is_available():
    pytest.skip(allow_module_level=True)

from thunder.dynamo import thunderfx
import thunder

from thunder.tests.distributed.helper import DistributedParallelTestCase
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from torch.distributed.tensor.placement_types import Placement, Shard, Replicate


@unittest.skipUnless(
    torch.cuda.is_available() and torch.distributed.is_nccl_available(),
    "DTensor test requires CUDA and NCCL `torch.distributed` backend",
)
class DTensorTest(DistributedParallelTestCase):
    def test_dtensor_basic_op(self):
        num_devices = self.world_size
        mesh = DeviceMesh("cuda", list(range(num_devices)))

        dim_size = 16

        def _helper(fn, in_dtensor, w_dtensor):
            expected = torch.compile(fn)(in_dtensor, w_dtensor)
            tmodel = thunder.jit(fn)
            actual = tmodel(in_dtensor, w_dtensor)

            torch.testing.assert_close(actual, expected)

            g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(0)])
            expected_g = torch.autograd.grad(
                expected,
                (in_dtensor, w_dtensor),
                g_o,
            )
            actual_g = torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o)

            torch.testing.assert_close(actual_g, expected_g)

        w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

        # Verify torch API works
        _helper(lambda x, w: torch.mul(x, w), in_dtensor, w_dtensor)

        # Verify calling method works
        _helper(lambda x, w: torch.Tensor.mul(x, w), in_dtensor, w_dtensor)

        # Verify calling special method works
        _helper(lambda x, w: x * w, in_dtensor, w_dtensor)

    def test_dtensor_unsupported(self):
        num_devices = self.world_size
        mesh = DeviceMesh("cuda", list(range(num_devices)))

        dim_size = 16

        w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

        def fn(x, w):
            return torch.div(x, w)

        tmodel = thunder.jit(fn)
        with pytest.raises(AssertionError):
            tmodel(in_dtensor, w_dtensor)

        def fn(x, w):
            return x / w

        tmodel = thunder.jit(fn)
        with pytest.raises(AssertionError):
            tmodel(in_dtensor, w_dtensor)

    def test_dtensor_unsupported_mixed_input(self):
        num_devices = self.world_size
        mesh = DeviceMesh("cuda", list(range(num_devices)))

        dim_size = 16

        def fn(x, w):
            return torch.div(x, w)

        w = torch.randn(dim_size, dim_size, requires_grad=True)

        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

        tmodel = thunder.jit(fn)
        with pytest.raises(AssertionError):
            tmodel(in_dtensor, w)

    def test_dtensor_incorrect_cotangent(self):
        num_devices = self.world_size
        mesh = DeviceMesh("cuda", list(range(num_devices)))

        dim_size = 16

        w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

        def fn(x, w):
            return torch.mul(x, w)

        tmodel = thunder.jit(fn)
        actual = tmodel(in_dtensor, w_dtensor)
        g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(1)])

        with pytest.raises(RuntimeError, match="has changed for cotangent between tracing and runtime"):
            torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o)
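
For readers who want to try the supported path outside the test harness, a minimal sketch of the same flow follows. It assumes a NCCL process group has already been initialized (e.g. the script is launched with torchrun) and that the DTensor support added in this PR is available; the two-GPU launch command is illustrative.

    # Launched with, for example: torchrun --nproc-per-node=2 dtensor_mul.py
    import torch
    import thunder
    from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

    # Assumes torch.distributed has already been initialized by the launcher.
    world_size = torch.distributed.get_world_size()
    mesh = DeviceMesh("cuda", list(range(world_size)))

    # Shard both operands along dim 0 across the mesh, as in the test above.
    x = distribute_tensor(torch.randn(16, 16, requires_grad=True), mesh, [Shard(0)])
    w = distribute_tensor(torch.randn(16, 16, requires_grad=True), mesh, [Shard(0)])

    # thunder.jit traces the function with DTensor inputs and returns a DTensor output.
    jitted = thunder.jit(lambda a, b: torch.mul(a, b))
    out = jitted(x, w)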
New file (+31 lines), DTensorSpec printing helpers:
from typing import Any
from torch.distributed.tensor._dtensor_spec import DTensorSpec, DeviceMesh, TensorMeta
from torch.distributed.tensor import DeviceMesh, Partial, Placement, Replicate, Shard


def populate_object_ctx_for_dtensor_spec(x: Any, object_ctx: dict[str, Any]) -> bool:
    """
    Populate object context for DTensorSpec.

    .. note::
        This function will mutate `object_ctx`.

    Returns:
        bool: True if `x` is a DTensorSpec (and `object_ctx` was updated), otherwise False.
    """
    if isinstance(x, DTensorSpec):
        object_ctx["DTensorSpec"] = DTensorSpec
        object_ctx["DeviceMesh"] = DeviceMesh
        object_ctx["Placement"] = Placement
        object_ctx["Replicate"] = Replicate
        object_ctx["Shard"] = Shard
        object_ctx["Partial"] = Partial
        object_ctx["TensorMeta"] = TensorMeta
        return True
    return False


def prettyprint_dtensor_spec(x):
    if isinstance(x, DTensorSpec):
        return x.__repr__()
    return ""
Review comment (nit-picking): potential necessity of a torch.distributed check. Probably I just don't understand the implementation correctly, but I'm not convinced of this registration's "safety" when torch.distributed is not available, because IIUC this registration really uses DTensor and runs some "meta" computation. Though in most cases the submodule itself should be available, so I'm not that strongly concerned.
Reply: I think this is true for other files as well. Will have a look later, thanks!
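
To make the reviewer's concern concrete, a guard along these lines could keep the registration inert when torch.distributed is unavailable (a hypothetical sketch with an illustrative helper name, not code from this PR):

    import torch

    def maybe_register_dtensor_support() -> None:
        # Hypothetical guard: bail out before importing or touching DTensor when
        # torch.distributed is not built into this torch installation.
        if not torch.distributed.is_available():
            return
        # Only import DTensor types once we know distributed support exists.
        from torch.distributed.tensor import DTensor  # noqa: F401
        # ... perform the actual DTensor registration here (PR-specific, omitted) ...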